diff --git a/.circleci/config.yml b/.circleci/config.yml index 60e586934b..7a12d3c07d 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -2,15 +2,22 @@ version: 2.1 setup: true +on_main_or_tag_filter: &on_main_or_tag_filter + filters: + branches: + only: main + tags: + only: /^v\d+\.\d+\.\d+/ + on_tag_filter: &on_tag_filter filters: branches: ignore: /.*/ tags: - only: /^v.+/ + only: /^v\d+\.\d+\.\d+/ orbs: - path-filtering: circleci/path-filtering@0.1.3 + path-filtering: circleci/path-filtering@1.2.0 jobs: publish: @@ -32,7 +39,7 @@ jobs: command: unset TWINE_USERNAME TWINE_PASSWORD && make publish-tests gh-release: docker: - - image: cimg/node:16.14 + - image: cimg/node:20.19.0 resource_class: small steps: - run: @@ -47,40 +54,41 @@ jobs: ui-build: docker: - - image: cimg/node:19.8 + - image: cimg/node:20.19.0 resource_class: medium steps: - checkout - run: - name: Install packages - command: npm --prefix web/client ci + name: Install Dependencies + command: | + pnpm install - run: name: Build UI - command: npm --prefix web/client run build + command: pnpm --prefix web/client run build - persist_to_workspace: root: web/client paths: - dist trigger_private_renovate: - docker: - - image: cimg/base:2021.11 - resource_class: small - steps: - - run: - name: Trigger private renovate - command: | - curl --request POST \ - --url $TOBIKO_PRIVATE_CIRCLECI_URL \ - --header "Circle-Token: $TOBIKO_PRIVATE_CIRCLECI_KEY" \ - --header "content-type: application/json" \ - --data '{ - "branch":"main", - "parameters":{ - "run_main_pr":false, - "run_sqlmesh_commit":false, - "run_renovate":true - } - }' + docker: + - image: cimg/base:2021.11 + resource_class: small + steps: + - run: + name: Trigger private renovate + command: | + curl --request POST \ + --url $TOBIKO_PRIVATE_CIRCLECI_URL \ + --header "Circle-Token: $TOBIKO_PRIVATE_CIRCLECI_KEY" \ + --header "content-type: application/json" \ + --data '{ + "branch":"main", + "parameters":{ + "run_main_pr":false, + 
"run_sqlmesh_commit":false, + "run_renovate":true + } + }' workflows: setup-workflow: @@ -89,20 +97,19 @@ workflows: mapping: | web/client/.* client true (sqlmesh|tests|examples|web/server)/.* python true - pytest.ini|setup.cfg|setup.py python true + pytest.ini|setup.cfg|setup.py|pyproject.toml python true \.circleci/.*|Makefile|\.pre-commit-config\.yaml common true - + vscode/extensions/.* vscode true + tag: "3.9" - gh-release: <<: *on_tag_filter - ui-build: - <<: *on_tag_filter - requires: - - gh-release + <<: *on_main_or_tag_filter - publish: - <<: *on_tag_filter + <<: *on_main_or_tag_filter requires: - ui-build - trigger_private_renovate: <<: *on_tag_filter requires: - - publish \ No newline at end of file + - publish diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml index 7c7411e9c4..bf27e03f47 100644 --- a/.circleci/continue_config.yml +++ b/.circleci/continue_config.yml @@ -1,8 +1,5 @@ version: 2.1 -orbs: - python: circleci/python@1.5.0 - parameters: client: type: boolean @@ -14,6 +11,9 @@ parameters: type: boolean default: false +orbs: + windows: circleci/windows@5.0 + commands: halt_unless_core: steps: @@ -37,6 +37,21 @@ commands: - run: circleci-agent step halt jobs: + vscode_test: + docker: + - image: cimg/node:20.19.1-browsers + resource_class: small + steps: + - checkout + - run: + name: Install Dependencies + command: | + pnpm install + - run: + name: Run VSCode extension CI + command: | + cd vscode/extension + pnpm run ci doc_tests: docker: - image: cimg/python:3.10 @@ -51,7 +66,7 @@ jobs: name: Run doc tests command: make doc-test - style_and_slow_tests: + style_and_cicd_tests: parameters: python_version: type: string @@ -70,42 +85,58 @@ jobs: name: Install ODBC command: sudo apt-get install unixodbc-dev - run: - name: Install SQLMesh and dbt adapter dependencies - command: make install-cicd-test + name: Install SQLMesh dev dependencies + command: make install-dev + - run: + name: Fix Git URL override + command: git config 
--global --unset url."ssh://git@github.com".insteadOf - run: name: Run linters and code style checks command: make py-style + - unless: + condition: + equal: ["3.9", << parameters.python_version >>] + steps: + - run: + name: Exercise the benchmarks + command: make benchmark-ci - run: - name: Run slow tests + name: Run cicd tests command: make cicd-test + - store_test_results: + path: test-results - style_and_slow_tests_pydantic_v1: - docker: - - image: cimg/python:3.10 - resource_class: large - environment: - PYTEST_XDIST_AUTO_NUM_WORKERS: 8 + cicd_tests_windows: + executor: + name: windows/default + size: large steps: - halt_unless_core - - checkout - run: - name: Install OpenJDK - command: sudo apt-get update && sudo apt-get install default-jdk - - run: - name: Install ODBC - command: sudo apt-get install unixodbc-dev - - run: - name: Install SQLMesh and dbt adapter dependencies - command: make install-cicd-test + name: Enable symlinks in git config + command: git config --global core.symlinks true + - checkout - run: - name: Install Pydantic v1 - command: pip install --upgrade "pydantic<2.0.0" && pip uninstall pydantic_core -y + name: Install System Dependencies + command: | + choco install make which -y + refreshenv - run: - name: Run linters and code style checks - command: make py-style + name: Install SQLMesh dev dependencies + command: | + python -m venv venv + . ./venv/Scripts/activate + python.exe -m pip install --upgrade pip + make install-dev - run: - name: Run slow tests - command: make cicd-test + name: Run fast unit tests + command: | + . 
./venv/Scripts/activate + which python + python --version + make fast-test + - store_test_results: + path: test-results migration_test: docker: @@ -117,172 +148,184 @@ jobs: - halt_unless_core - checkout - run: - name: Run the migration test - command: ./.circleci/test_migration.sh + name: Run the migration test - sushi + command: ./.circleci/test_migration.sh sushi "--gateway duckdb_persistent" + - run: + name: Run the migration test - sushi_dbt + command: ./.circleci/test_migration.sh sushi_dbt "--config migration_test_config" ui_style: docker: - - image: cimg/python:3.8 + - image: cimg/node:20.19.0 resource_class: small steps: - - halt_unless_client - checkout - - run: - command: | - cp .pre-commit-config.yaml pre-commit-cache-key.txt - python --version --version >> pre-commit-cache-key.txt - restore_cache: + name: Restore pnpm Package Cache keys: - - v1-pc-cache-{{ checksum "pre-commit-cache-key.txt" }} - - run: - name: Install pre-commit - command: pip install pre-commit + - pnpm-packages-{{ checksum "pnpm-lock.yaml" }} - run: - name: Run linters and code style checks - command: make ui-style + name: Install Dependencies + command: | + pnpm install - save_cache: - key: v1-pc-cache-{{ checksum "pre-commit-cache-key.txt" }} + name: Save pnpm Package Cache + key: pnpm-packages-{{ checksum "pnpm-lock.yaml" }} paths: - - ~/.cache/pre-commit + - .pnpm-store + - run: + name: Run linters and code style checks + command: pnpm run lint ui_test: docker: - - image: mcr.microsoft.com/playwright:v1.40.1-jammy + - image: mcr.microsoft.com/playwright:v1.54.1-jammy resource_class: medium steps: - halt_unless_client - checkout - restore_cache: + name: Restore pnpm Package Cache keys: - - v1-nm-cache-{{ checksum "web/client/package-lock.json" }} + - pnpm-packages-{{ checksum "pnpm-lock.yaml" }} + - run: + name: Install pnpm package manager + command: | + npm install --global corepack@latest + corepack enable + corepack prepare pnpm@latest-10 --activate + pnpm config set 
store-dir .pnpm-store - run: - name: Install packages - command: npm --prefix web/client ci + name: Install Dependencies + command: | + pnpm install - save_cache: - key: v1-nm-cache-{{ checksum "web/client/package-lock.json" }} + name: Save pnpm Package Cache + key: pnpm-packages-{{ checksum "pnpm-lock.yaml" }} paths: - - /root/.npm + - .pnpm-store - run: name: Run tests command: npm --prefix web/client run test - airflow_docker_tests: + engine_tests_docker: + parameters: + engine: + type: string machine: - image: ubuntu-2204:2022.10.2 + image: ubuntu-2404:2024.05.1 docker_layer_caching: true resource_class: large environment: - PYTEST_XDIST_AUTO_NUM_WORKERS: 8 SQLMESH__DISABLE_ANONYMIZED_ANALYTICS: "1" steps: + - halt_unless_core - checkout - run: - name: Install envsubst - command: sudo apt-get update && sudo apt-get install gettext-base - - run: - name: Install ruamel.yaml - command: pip3 install ruamel.yaml==0.16.0 + name: Install OS-level dependencies + command: ./.circleci/install-prerequisites.sh "<< parameters.engine >>" - run: - name: Run Airflow slow tests - command: make airflow-docker-test-with-env - no_output_timeout: 15m - - run: - name: Collect Airflow logs - command: | - tar -czf ./airflow_logs.tgz -C ./examples/airflow/logs . 
- mkdir -p /tmp/airflow_logs - cp ./airflow_logs.tgz /tmp/airflow_logs/ - when: on_fail - - store_artifacts: - path: /tmp/airflow_logs + name: Run tests + command: make << parameters.engine >>-test + no_output_timeout: 20m + - store_test_results: + path: test-results - engine_adapter_docker_tests: - machine: - image: ubuntu-2204:2022.10.2 - docker_layer_caching: true - resource_class: large + engine_tests_cloud: + parameters: + engine: + type: string + docker: + - image: cimg/python:3.12 + resource_class: medium environment: - PYTEST_XDIST_AUTO_NUM_WORKERS: 8 + PYTEST_XDIST_AUTO_NUM_WORKERS: 4 SQLMESH__DISABLE_ANONYMIZED_ANALYTICS: "1" steps: + - halt_unless_core - checkout - run: - name: Install pg_config - command: sudo apt-get update && sudo apt-get install libpq-dev + name: Install OS-level dependencies + command: ./.circleci/install-prerequisites.sh "<< parameters.engine >>" - run: - name: Install dependencies - command: make install-engine-test - - run: - name: Bring up Dockerized Engines - command: make engine-up + name: Generate database name + command: | + UUID=`cat /proc/sys/kernel/random/uuid` + TEST_DB_NAME="circleci_${UUID:0:8}" + echo "export TEST_DB_NAME='$TEST_DB_NAME'" >> "$BASH_ENV" + echo "export SNOWFLAKE_DATABASE='$TEST_DB_NAME'" >> "$BASH_ENV" + echo "export DATABRICKS_CATALOG='$TEST_DB_NAME'" >> "$BASH_ENV" + echo "export REDSHIFT_DATABASE='$TEST_DB_NAME'" >> "$BASH_ENV" + echo "export GCP_POSTGRES_DATABASE='$TEST_DB_NAME'" >> "$BASH_ENV" + echo "export FABRIC_DATABASE='$TEST_DB_NAME'" >> "$BASH_ENV" + + # Make snowflake private key available + echo $SNOWFLAKE_PRIVATE_KEY_RAW | base64 -d > /tmp/snowflake-keyfile.p8 + echo "export SNOWFLAKE_PRIVATE_KEY_FILE='/tmp/snowflake-keyfile.p8'" >> "$BASH_ENV" - run: - name: Make sure DBs are ready - command: sleep 60 + name: Create test database + command: ./.circleci/manage-test-db.sh << parameters.engine >> "$TEST_DB_NAME" up - run: name: Run tests - command: make engine-docker-test - 
no_output_timeout: 30m - - trigger_private_tests: - docker: - - image: cimg/base:2021.11 - resource_class: small - steps: - - checkout - - run: - name: Trigger private tests - command: | - echo 'export COMMIT_MESSAGE="$(git log --format=%s -n 1 $CIRCLE_SHA1)"' >> "$BASH_ENV" - echo 'export FORMATTED_COMMIT_MESSAGE="${COMMIT_MESSAGE//\"/\\\"}"' >> "$BASH_ENV" - source "$BASH_ENV" - curl --request POST \ - --url $TOBIKO_PRIVATE_CIRCLECI_URL \ - --header "Circle-Token: $TOBIKO_PRIVATE_CIRCLECI_KEY" \ - --header "content-type: application/json" \ - --data '{ - "branch":"main", - "parameters":{ - "run_main_pr":false, - "run_sqlmesh_commit":true, - "sqlmesh_branch":"'$CIRCLE_BRANCH'", - "sqlmesh_commit_author":"'$CIRCLE_USERNAME'", - "sqlmesh_commit_hash":"'$CIRCLE_SHA1'", - "sqlmesh_commit_message":"'"$FORMATTED_COMMIT_MESSAGE"'" - } - }' + command: | + make << parameters.engine >>-test + no_output_timeout: 20m + - run: + name: Tear down test database + command: ./.circleci/manage-test-db.sh << parameters.engine >> "$TEST_DB_NAME" down + when: always + - store_test_results: + path: test-results workflows: main_pr: jobs: - doc_tests - - style_and_slow_tests: + - style_and_cicd_tests: matrix: parameters: python_version: - ["3.8", "3.9", "3.10", "3.11", "3.12"] - - style_and_slow_tests_pydantic_v1 - - airflow_docker_tests: - requires: - - style_and_slow_tests - filters: - branches: - only: - - main - - engine_adapter_docker_tests: - context: engine_adapter_slow - requires: - - style_and_slow_tests - filters: - branches: - only: - - main - - trigger_private_tests: + - "3.9" + - "3.10" + - "3.11" + - "3.12" + - "3.13" + - cicd_tests_windows + - engine_tests_docker: + name: engine_<< matrix.engine >> + matrix: + parameters: + engine: + - duckdb + - postgres + - mysql + - mssql + - trino + - spark + - clickhouse + - risingwave + - engine_tests_cloud: + name: cloud_engine_<< matrix.engine >> + context: + - sqlmesh_cloud_database_integration requires: - - style_and_slow_tests + 
- engine_tests_docker + matrix: + parameters: + engine: + - snowflake + - databricks + - redshift + - bigquery + - clickhouse-cloud + - athena + - fabric + - gcp-postgres filters: branches: only: - main - ui_style - ui_test + - vscode_test - migration_test diff --git a/.circleci/install-prerequisites.sh b/.circleci/install-prerequisites.sh new file mode 100755 index 0000000000..446221dba6 --- /dev/null +++ b/.circleci/install-prerequisites.sh @@ -0,0 +1,42 @@ +#!/bin/bash + +# This script is intended to be run by an Ubuntu build agent on CircleCI +# The goal is to install OS-level dependencies that are required before trying to install Python dependencies + +set -e + +if [ -z "$1" ]; then + echo "USAGE: $0 " + exit 1 +fi + +ENGINE="$1" + +COMMON_DEPENDENCIES="libpq-dev netcat-traditional unixodbc-dev" +ENGINE_DEPENDENCIES="" + +if [ "$ENGINE" == "spark" ]; then + ENGINE_DEPENDENCIES="default-jdk" +elif [ "$ENGINE" == "fabric" ]; then + echo "Installing Microsoft package repository" + + # ref: https://learn.microsoft.com/en-us/sql/connect/odbc/linux-mac/installing-the-microsoft-odbc-driver-for-sql-server + curl -sSL -O https://packages.microsoft.com/config/ubuntu/$(grep VERSION_ID /etc/os-release | cut -d '"' -f 2)/packages-microsoft-prod.deb + sudo dpkg -i packages-microsoft-prod.deb + rm packages-microsoft-prod.deb + + ENGINE_DEPENDENCIES="msodbcsql18" +fi + +ALL_DEPENDENCIES="$COMMON_DEPENDENCIES $ENGINE_DEPENDENCIES" + +echo "Installing OS-level dependencies: $ALL_DEPENDENCIES" + +sudo apt-get clean && sudo apt-get -y update && sudo ACCEPT_EULA='Y' apt-get -y install $ALL_DEPENDENCIES + +if [ "$ENGINE" == "spark" ]; then + echo "Using Java version for spark:" + java -version +fi + +echo "All done" \ No newline at end of file diff --git a/.circleci/manage-test-db.sh b/.circleci/manage-test-db.sh new file mode 100755 index 0000000000..b6e9c265c9 --- /dev/null +++ b/.circleci/manage-test-db.sh @@ -0,0 +1,170 @@ +#!/bin/bash + +# The purpose of this script is to 
create and destroy temporary test databases on the cloud engines +# The idea is that a database is created, the integration tests are run on that database and then the database is dropped +# This allows builds for multiple PR's to run concurrently without the tests clobbering each other and also gives each set of tests a fresh environment + +# Note: It is expected that the environment variables defined in 'tests/core/engine_adapter/config.yaml' for each cloud engine are set + +set -e + +if [ -z "$1" ] || [ -z "$2" ] || [ -z "$3" ]; then + echo "USAGE: $0 " + exit 1 +fi + +ENGINE="$1" +DB_NAME="$2" +DIRECTION="$3" + +function_exists() { + declare -f -F $1 > /dev/null + return $? +} + +# Snowflake +snowflake_init() { + echo "Installing Snowflake CLI" + pip install "snowflake-cli" +} + +snowflake_up() { + snow sql -q "create database if not exists $1" --temporary-connection +} + +snowflake_down() { + snow sql -q "drop database if exists $1" --temporary-connection +} + +# Databricks +databricks_init() { + echo "Installing Databricks CLI" + curl -fsSL https://raw.githubusercontent.com/databricks/setup-cli/main/install.sh | sudo sh || true +} + +databricks_up() { + databricks catalogs create $1 || true +} + +databricks_down() { + databricks catalogs delete $1 --force || true +} + +# Redshift +redshift_init() { + psql --version +} + +redshift_exec() { + PGPASSWORD=$REDSHIFT_PASSWORD psql -h $REDSHIFT_HOST -p $REDSHIFT_PORT -U $REDSHIFT_USER -c "$1" dev +} + +redshift_up() { + redshift_exec "create database $1" +} + +redshift_down() { + # try to prevent a "database is being accessed by other users" error when running DROP DATABASE + EXIT_CODE=1 + ATTEMPTS=0 + while [ $EXIT_CODE -ne 0 ] && [ $ATTEMPTS -lt 5 ]; do + # note: sometimes this pg_terminate_backend() call can randomly fail with: ERROR: Insufficient privileges + # if it does, let's proceed with the drop anyway rather than aborting and never attempting the drop + redshift_exec "select pg_terminate_backend(procpid) 
from pg_stat_activity where datname = '$1'" || true + + # perform drop + redshift_exec "drop database $1;" && EXIT_CODE=$? || EXIT_CODE=$? + if [ $EXIT_CODE -ne 0 ]; then + echo "Unable to drop database; retrying..." + ATTEMPTS=$((ATTEMPTS + 1)) + sleep 5 + fi + done +} + +# BigQuery +bigquery_init() { + # Write out the keyfile for the integration tests to pick up + echo "Writing out keyfile to $BIGQUERY_KEYFILE" + echo "$BIGQUERY_KEYFILE_CONTENTS" > $BIGQUERY_KEYFILE +} + + +# Clickhouse cloud +clickhouse-cloud_init() { + # note: the ping endpoint doesnt seem to need any API keys + until curl https://$CLICKHOUSE_CLOUD_HOST:8443/ping + do + echo "Pinging Clickhouse Cloud service to ensure it's not in idle mode..." + sleep 5 + done + echo "Clickhouse Cloud instance $CLICKHOUSE_CLOUD_HOST is up and running" +} + +# GCP Postgres +gcp-postgres_init() { + # Download and start Cloud SQL Proxy + curl -fsSL -o cloud-sql-proxy https://storage.googleapis.com/cloud-sql-connectors/cloud-sql-proxy/v2.18.0/cloud-sql-proxy.linux.amd64 + chmod +x cloud-sql-proxy + echo "$GCP_POSTGRES_KEYFILE_JSON" > /tmp/keyfile.json + ./cloud-sql-proxy --credentials-file /tmp/keyfile.json $GCP_POSTGRES_INSTANCE_CONNECTION_STRING & + + # Wait for proxy to start + sleep 5 +} + +gcp-postgres_exec() { + PGPASSWORD=$GCP_POSTGRES_PASSWORD psql -h 127.0.0.1 -U $GCP_POSTGRES_USER -c "$1" postgres +} + +gcp-postgres_up() { + gcp-postgres_exec "create database $1" +} + +gcp-postgres_down() { + gcp-postgres_exec "drop database $1" +} + +# Fabric +fabric_init() { + python --version #note: as at 2025-08-20, ms-fabric-cli is pinned to Python >= 3.10, <3.13 + pip install ms-fabric-cli + + # to prevent the '[EncryptionFailed] An error occurred with the encrypted cache.' 
error + # ref: https://microsoft.github.io/fabric-cli/#switch-to-interactive-mode-optional + fab config set encryption_fallback_enabled true + + echo "Logging in to Fabric" + fab auth login -u $FABRIC_CLIENT_ID -p $FABRIC_CLIENT_SECRET --tenant $FABRIC_TENANT_ID +} + +fabric_up() { + fab create "SQLMesh CircleCI.Workspace/$1.Warehouse" +} + +fabric_down() { + fab rm -f "SQLMesh CircleCI.Workspace/$1.Warehouse" || true +} + +INIT_FUNC="${ENGINE}_init" +UP_FUNC="${ENGINE}_up" +DOWN_FUNC="${ENGINE}_down" + +# If called with an unimplemented / unsupported engine, just exit +if ! function_exists $INIT_FUNC ; then + echo "WARN: $INIT_FUNC not implemeted; exiting" + exit 0 +fi + +echo "Initializing $ENGINE" +$INIT_FUNC + +if [ "$DIRECTION" == "up" ] && function_exists $UP_FUNC; then + echo "Creating database $DB_NAME" + $UP_FUNC $DB_NAME +elif [ "$DIRECTION" == "down" ] && function_exists $DOWN_FUNC; then + echo "Dropping database $DB_NAME" + $DOWN_FUNC $DB_NAME +fi + +echo "All done" diff --git a/.circleci/test_migration.sh b/.circleci/test_migration.sh index fc869eb439..bb1776550a 100755 --- a/.circleci/test_migration.sh +++ b/.circleci/test_migration.sh @@ -1,11 +1,6 @@ #!/usr/bin/env bash set -ex -CONFIG_NAME="local_config" -TMP_DIR=$(mktemp -d) -SUSHI_DIR="$TMP_DIR/sushi" - - if [[ -z $(git tag --points-at HEAD) ]]; then # If the current commit is not tagged, we need to find the last tag LAST_TAG=$(git describe --tags --abbrev=0) @@ -14,28 +9,49 @@ else LAST_TAG=$(git tag --sort=-creatordate | head -n 2 | tail -n 1) fi +if [ "$1" == "" ]; then + echo "Usage: $0 " + echo "eg $0 sushi '--gateway duckdb_persistent'" + exit 1 +fi + + +TMP_DIR=$(mktemp -d) +EXAMPLE_NAME="$1" +SQLMESH_OPTS="$2" +EXAMPLE_DIR="./examples/$EXAMPLE_NAME" +TEST_DIR="$TMP_DIR/$EXAMPLE_NAME" + +echo "Running migration test for '$EXAMPLE_NAME' in '$TEST_DIR' for example project '$EXAMPLE_DIR' using options '$SQLMESH_OPTS'" + +# Copy the example project from the *current* checkout so it's stable 
across old/new SQLMesh versions +cp -r "$EXAMPLE_DIR" "$TEST_DIR" + git checkout $LAST_TAG # Install dependencies from the previous release. make install-dev -cp -r ./examples/sushi $TMP_DIR +# this is only needed temporarily until the released tag for $LAST_TAG includes this config +if [ "$EXAMPLE_NAME" == "sushi_dbt" ]; then + echo 'migration_test_config = sqlmesh_config(Path(__file__).parent, dbt_target_name="duckdb")' >> $TEST_DIR/config.py +fi # Run initial plan -pushd $SUSHI_DIR +pushd $TEST_DIR rm -rf ./data/* -sqlmesh --config $CONFIG_NAME plan --no-prompts --auto-apply +sqlmesh $SQLMESH_OPTS plan --no-prompts --auto-apply +rm -rf .cache popd -# Switch back to the starting state of the repository +# Switch back to the starting state of the repository git checkout - # Install updated dependencies. make install-dev # Migrate and make sure the diff is empty -pushd $SUSHI_DIR -sqlmesh --config $CONFIG_NAME migrate -sqlmesh --config $CONFIG_NAME diff prod +pushd $TEST_DIR +sqlmesh $SQLMESH_OPTS migrate +sqlmesh $SQLMESH_OPTS diff prod popd - diff --git a/.circleci/wait-for-db.sh b/.circleci/wait-for-db.sh new file mode 100755 index 0000000000..a313320279 --- /dev/null +++ b/.circleci/wait-for-db.sh @@ -0,0 +1,83 @@ +#!/bin/bash + +# The purpose of this script is to be called after `docker compose up -d` has been run for a given database +# The idea is to block until the database is available to serve requests. Once the database can serve requests, +# the integration tests can be run. +# Therefore, the ports etc are tightly coupled with the compose.yml files under tests/core/engine_adapter/docker/ +# +# Note that if the docker daemon is not running `localhost`, you can set the DOCKER_HOSTNAME environment variable to the +# correct host Docker is running on + +set -e + +if [ -z "$1" ]; then + echo "USAGE: $0 " + exit 1 +fi + +ENGINE="$1" + +function_exists() { + declare -f -F $1 > /dev/null + return $? 
+} + +probe_port() { + HOSTNAME=${DOCKER_HOSTNAME:-localhost} + echo "Probing '$HOSTNAME' on port $1" + while ! nc -z $HOSTNAME $1; do + sleep 1 + done +} + +clickhouse_ready() { + probe_port 8123 +} + +postgres_ready() { + probe_port 5432 +} + +mssql_ready() { + probe_port 1433 +} + +mysql_ready() { + probe_port 3306 +} + +spark_ready() { + probe_port 15002 +} + +trino_ready() { + # Trino has a built-in healthcheck script, just call that + docker compose -f tests/core/engine_adapter/integration/docker/compose.trino.yaml exec trino /bin/bash -c '/usr/lib/trino/bin/health-check' +} + +risingwave_ready() { + probe_port 4566 +} + +echo "Waiting for $ENGINE to be ready..." + +READINESS_FUNC="${ENGINE}_ready" + +# If called with an unimplemented / unsupported engine, just exit +if ! function_exists $READINESS_FUNC ; then + echo "WARN: $READINESS_FUNC not implemeted; exiting" + exit 0 +fi + +EXIT_CODE=1 + +while [ $EXIT_CODE -ne 0 ]; do + echo "Checking $ENGINE" + $READINESS_FUNC && EXIT_CODE=$? || EXIT_CODE=$? + if [ $EXIT_CODE -ne 0 ]; then + echo "$ENGINE not ready; sleeping" + sleep 5 + fi +done + +echo "$ENGINE is ready!" \ No newline at end of file diff --git a/.claude/agents/code-reviewer.md b/.claude/agents/code-reviewer.md new file mode 100644 index 0000000000..85ab5be3dc --- /dev/null +++ b/.claude/agents/code-reviewer.md @@ -0,0 +1,73 @@ +--- +name: code-reviewer +description: Use this agent PROACTIVELY when you need expert code review after writing or modifying code. This agent should be called after completing any coding task to ensure quality, architectural compliance, and catch potential issues. Examples: Context: The user has just implemented a new feature for processing SQLMesh snapshots. 
user: 'I just added a new method to handle snapshot fingerprinting in the Context class' assistant: 'Let me use the code-reviewer agent to analyze this implementation for potential issues and architectural compliance' Since code was just written, use the code-reviewer agent to review the implementation for quality, edge cases, and adherence to SQLMesh patterns. Context: An agent just generated a database migration script. user: 'Here's the migration I created for adding a new state table' assistant: 'Now I'll have the code-reviewer agent examine this migration for safety and best practices' Since a migration was created, use the code-reviewer agent to ensure it follows SQLMesh migration patterns and handles edge cases safely. +tools: Glob, Grep, LS, Read, NotebookRead, WebFetch, TodoWrite, WebSearch, Bash +model: sonnet +color: blue +--- + +You are an Expert Code Reviewer, a senior software engineer with deep expertise in code quality, architecture, and best practices. You NEVER write code yourself - your sole focus is providing thorough, insightful code reviews that catch issues other engineers might miss. + +Your core responsibilities: + +## Analysis Approach + +- Examine code for architectural alignment with established patterns and principles +- Identify potential edge cases, race conditions, and error scenarios +- Evaluate performance implications and scalability concerns +- Check for security vulnerabilities and data safety issues +- Assess maintainability, readability, and documentation quality +- Verify adherence to project-specific coding standards and conventions + +## Review Methodology + +- **Architectural Review**: Does the code follow established patterns? Does it fit well within the existing codebase structure? +- **Logic Analysis**: Are there logical flaws, edge cases, or scenarios that could cause failures? +- **Error Handling**: Is error handling comprehensive and appropriate? Are failure modes considered? 
+- **Performance Review**: Are there performance bottlenecks, inefficient algorithms, or resource leaks? +- **Security Assessment**: Are there potential security vulnerabilities or data exposure risks? +- **Maintainability Check**: Is the code readable, well-structured, and properly documented? + +### Standard Code Review Checklist + +- Code is simple and readable +- Functions, classes, and variables are well-named +- No duplicated code +- Proper error handling with specific error types +- No exposed secrets, API keys, or credentials +- Input validation and sanitization implemented +- Good test coverage including edge cases +- Performance considerations addressed +- Security best practices followed +- Documentation updated for significant changes + +## Feedback Structure + +Organize your reviews into clear categories: + +- **Critical Issues**: Problems that could cause failures, security issues, or data corruption +- **Architectural Concerns**: Deviations from established patterns or design principles +- **Edge Cases**: Scenarios that might not be handled properly +- **Performance Considerations**: Potential bottlenecks or inefficiencies +- **Maintainability Improvements**: Suggestions for better code organization or documentation +- **Documentation**: Suggestions to update documentation for significant changes + +## Communication Style + +- Be constructive and specific in your feedback +- Explain the 'why' behind your suggestions, not just the 'what' +- Prioritize issues by severity and impact +- Acknowledge good practices when you see them +- Provide context for your recommendations +- Ask clarifying questions when code intent is unclear + +## Important Constraints + +- You NEVER write, modify, or suggest specific code implementations +- You focus purely on analysis and high-level guidance +- You always consider the broader system context and existing codebase patterns +- You escalate concerns about fundamental architectural decisions +- You validate that 
solutions align with project requirements and constraints + +When reviewing code, assume you're looking at recently written code unless explicitly told otherwise. Focus on providing actionable insights that help improve code quality while respecting the existing architectural decisions and project constraints. + diff --git a/.claude/agents/developer.md b/.claude/agents/developer.md new file mode 100644 index 0000000000..3a9f32d6c4 --- /dev/null +++ b/.claude/agents/developer.md @@ -0,0 +1,110 @@ +--- +name: developer +description: Use this agent PROACTIVELY when you need to understand the user's task, read GitHub issues, implement new features, write comprehensive tests, refactor existing code, fix bugs, or make any code changes that require deep understanding of the project's architecture and coding standards. Examples: Context: User wants to add a new SQL dialect adapter to SQLMesh. user: 'I need to implement support for Oracle database in SQLMesh' assistant: 'I'll use the software-engineer agent to implement the Oracle adapter following SQLMesh's engine adapter patterns' Since this requires implementing a new feature with proper architecture understanding, use the software-engineer agent. Context: User discovers a bug in the migration system. user: 'The migration v0084 is failing on MySQL due to field size limits' assistant: 'Let me use the software-engineer agent to investigate and fix this migration issue' This requires debugging and fixing code while understanding SQLMesh's migration patterns, so use the software-engineer agent. Context: User needs comprehensive tests for a new feature. user: 'I just implemented a new snapshot fingerprinting algorithm and need tests' assistant: 'I'll use the software-engineer agent to write comprehensive tests following SQLMesh's testing patterns' Writing thorough tests requires understanding the codebase architecture and testing conventions, so use the software-engineer agent. 
+model: sonnet +color: red +--- + +You are an expert software engineer with deep expertise in Python, SQL, data engineering, and modern software development practices. You specialize in working with complex codebases like SQLMesh, understanding architectural patterns, and implementing robust, well-tested solutions. + +Your core responsibilities: + +# Project-Specific Expertise + +- Understand SQLMesh's core concepts: virtual environments, fingerprinting, snapshots, plans. You can find documentation in the ./docs folder +- Implement engine adapters following the established 16+ engine pattern +- Handle state sync and migration patterns correctly +- Support dbt integration requirements when relevant + +# Problem-Solving Approach + +1. Analyze the existing codebase to understand patterns and conventions +2. Come up with an implementation plan; identify edge cases and trade-offs; request feedback and ask clarifying questions +3. IMPORTANT: Write comprehensive tests covering normal and edge cases BEFORE you write any implementation code. It's expected for these tests to fail at first, the implementation should then ensure that the tests are passing +4. Confirm that the written tests cover the full scope of the work that has been requested +5. Identify the most appropriate location for new code based on architecture +6. Study similar existing implementations as reference +7. Implement following established patterns and best practices +8. Validate code quality with style checks +9. 
Consider backward compatibility and migration needs especially when the persistent state + +# Implementation Best Practices + +## Code Implementation + +- Write clean, maintainable, and performant code following established patterns +- Implement new features by studying existing similar implementations first +- Follow the project's architectural principles and design patterns +- Use appropriate abstractions and avoid code duplication +- Ensure cross-platform compatibility (Windows/Linux/macOS) + +## Testing Best Practices + +- Write comprehensive tests using pytest with appropriate markers (fast/slow/engine-specific) +- Follow the project's testing philosophy: fast tests for development, comprehensive coverage for CI +- Use existing test utilities `assert_exp_eq` and others for validation when appropriate +- Test edge cases, error conditions, and cross-engine compatibility +- Use existing tests in the same module as a reference for new tests +- Write an integration test(s) that runs against the `sushi` project when the scope of feature touches multiple decoupled components +- Only add tests within the `tests/` folder. 
Prefer adding tests to existing modules over creating new files
+- Tests are marked with pytest markers:
+  - **Type markers**: `fast`, `slow`, `docker`, `remote`, `cicdonly`, `isolated`, `registry_isolation`
+  - **Domain markers**: `cli`, `dbt`, `github`, `jupyter`, `web`
+  - **Engine markers**: `engine`, `athena`, `bigquery`, `clickhouse`, `databricks`, `duckdb`, `motherduck`, `mssql`, `mysql`, `postgres`, `redshift`, `snowflake`, `spark`, `trino`, `risingwave`
+- Default to `fast` tests during development
+- Engine tests use real connections when available, mocks otherwise
+- The `sushi` example project is used extensively in tests
+- Use `DuckDBMetadata` helper for validating table metadata in tests
+
+## Code Quality Standards
+
+- Python: Black formatting, isort for imports, mypy for type checking, Ruff for linting
+- TypeScript/React: ESLint + Prettier configuration
+- All style checks run via `make style`
+- Pre-commit hooks enforce all style rules automatically
+- Important: Some modules (duckdb, numpy, pandas) are banned at module level to prevent import-time side effects
+- Write clear docstrings and comments for complex logic but avoid comments that are too frequent or state overly obvious details
+- Make sure there are no trailing whitespaces in edited files
+
+## Writing Functions / Methods Best Practices
+
+When evaluating whether a function you implemented is good or not, use this checklist:
+
+1. Can you read the function and easily follow what it's doing? If yes, then stop here
+2. Does the function have very high cyclomatic complexity? (number of independent paths, or, in a lot of cases, number of nesting if if-else as a proxy). If it does, then it likely needs to be rewritten
+3. Are the arguments and return values annotated with the correct types?
+4. Are there any common data structures and algorithms that would make this function much easier to follow and more robust?
+5. Are there any unused parameters in the function?
+6. Are there any unnecessary type casts that can be moved to function arguments?
+7. Is the function easily testable without mocking core features? If not, can this function be tested as part of an integration test?
+8. Does it have any hidden untested dependencies or any values that can be factored out into the arguments instead? Only care about non-trivial dependencies that can actually change or affect the function
+9. Brainstorm 3 better function names and see if the current name is the best, consistent with rest of codebase
+
+IMPORTANT: you SHOULD NOT refactor out a separate function unless there is a compelling need, such as:
+- the refactored function is used in more than one place
+- the refactored function is easily unit testable while the original function is not AND you can't test it any other way
+- the original function is extremely hard to follow and you resort to putting comments everywhere just to explain it
+
+## Using Git
+
+- Use Conventional Commits format when writing commit messages: https://www.conventionalcommits.org/en/v1.0.0
+
+# Communication
+
+- Be concise and to the point
+- Explain your architectural decisions and reasoning
+- Highlight any potential breaking changes or migration requirements
+- Suggest related improvements or refactoring opportunities
+- Document complex algorithms or business logic clearly
+
+# Common Pitfalls
+
+1. **Engine Tests**: Many tests require specific database credentials or Docker. Check test markers before running.
+2. **Path Handling**: Be careful with Windows paths - use `pathlib.Path` for cross-platform compatibility.
+3. **State Management**: Understanding the state sync mechanism is crucial for debugging environment issues.
+4. **Snapshot Versioning**: Changes to model logic create new versions - this is by design for safe deployments.
+5. 
**Module Imports**: Avoid importing duckdb, numpy, or pandas at module level - these are banned by Ruff to prevent long load times in cases where the libraries aren't used. +6. **Import And Attribute Errors**: If the code raises `ImportError` or `AttributeError` try running the `make install-dev` command first to make sure all dependencies are up to date + +When implementing features, always consider the broader impact on the system, ensure proper error handling, and maintain the high code quality standards established in the project. Your implementations should be production-ready and align with SQLMesh's philosophy of safe, reliable data transformations. + diff --git a/.claude/agents/qa-reviewer.md b/.claude/agents/qa-reviewer.md new file mode 100644 index 0000000000..b1f6842f32 --- /dev/null +++ b/.claude/agents/qa-reviewer.md @@ -0,0 +1,106 @@ +--- +name: qa-reviewer +description: Use this agent PROACTIVELY when you need to analyze a PR or code changes to provide structured QA testing guidance for human QA testers. This agent reviews PRs and provides specific testing scenarios, example projects to use, commands to run, and validation steps. Examples: Context: A developer just implemented virtual environment isolation for SQLMesh. user: 'I just added support for isolated virtual environments in SQLMesh' assistant: 'Let me use the qa-reviewer agent to create comprehensive QA testing instructions for this feature' Since a significant feature was implemented, use the qa-reviewer agent to provide structured testing guidance for QA. Context: A PR adds a new SQL engine adapter. user: 'Here's the PR that adds BigQuery support to SQLMesh' assistant: 'I'll use the qa-reviewer agent to analyze this change and create QA test scenarios' Since a new engine adapter was added, use the qa-reviewer agent to provide testing guidance specific to engine adapters. 
+tools: Glob, Grep, LS, Read, NotebookRead, WebFetch, TodoWrite, WebSearch, Bash +model: sonnet +color: green +--- + +You are a QA Test Specialist with deep expertise in SQLMesh's architecture, testing methodologies, and quality assurance practices. You specialize in analyzing code changes and providing comprehensive, structured testing guidance for human QA testers. + +Your core responsibilities: + +## Analysis Approach + +- Review PRs and code changes to understand the scope and impact of modifications +- Identify all components, features, and workflows that could be affected by the changes +- Consider edge cases, integration points, and potential failure scenarios +- Map changes to existing example projects and testing workflows +- Provide specific, actionable testing instructions that non-developers can follow +- MUST write full instructions to the `plans/` folder with the filename of `_.md` so they can be reviewed and executed by QA testers + +## QA Test Plan Structure + +Organize your QA recommendations into clear, actionable sections: + +### **Change Summary** +- Brief description of what was changed and why +- Key components and files modified +- Potential impact areas and affected workflows + +### **Test Environment Setup** +- Which example project(s) to use for testing (e.g., `examples/sushi/`, `examples/sushi_dbt/`) +- Any necessary environment configuration or setup steps +- Required tools, databases, or dependencies + +### **Core Test Scenarios** +- Step-by-step testing procedures with specific commands +- Expected results and success criteria for each test +- Validation commands to confirm expected behavior +- Screenshots or output examples where helpful + +### **Edge Case Testing** +- Boundary conditions and error scenarios to test +- Negative test cases and expected failure modes +- Cross-platform considerations (Windows/Linux/macOS) +- Performance and scalability considerations + +### **Regression Testing** +- Existing functionality that should be 
retested +- Critical workflows that must continue working +- Backward compatibility scenarios + +### **Integration Testing** +- Cross-component testing scenarios +- Multi-engine testing when relevant +- dbt integration testing if applicable +- UI/CLI integration points + +## Example Project Guidance + +Provide specific guidance on: +- Which `examples/` project best demonstrates the feature +- How to modify example projects for comprehensive testing +- Custom test scenarios using real-world-like data +- Commands to set up test scenarios and validate results + +## Command Examples + +Always provide: +- Exact CLI commands to run tests +- Configuration file modifications needed +- Environment variable settings +- Database setup commands when applicable +- Validation queries or commands to check results + +## Testing Best Practices + +- Focus on user-facing functionality and workflows +- Include both happy path and error scenarios +- Provide clear success/failure criteria +- Consider different user personas (data analysts, engineers, platform teams) +- If the change doesn't have engine specific logic in it, prefer to test against duckdb since that is easiest +- Include performance and scalability considerations +- DO NOT have a step which is running an existing test - these tests are automatically run in CI and should not be duplicated in manual testing instructions +- Assume all example projects are already tested as is and don't suggest doing a test which is running them again +- All tests MUST just use `sqlmesh` cli commands - do not use the Python API. The goal is to run tests that mimic what an actual user would do, which is using the CLI. +- A common pattern could be using the `sqlmesh` cli command and then running a Python script to validate the database or state is in an expected state, but the Python script should not be a test itself, just a validation step. 
+ +## Communication Style + +- Use clear, numbered steps for testing procedures +- Provide exact commands that can be copy-pasted +- Include expected outputs and how to interpret results +- Explain the "why" behind each test scenario +- Use language accessible to QA testers who may not be developers +- Organize content with clear headings and bullet points + +## Important Constraints + +- You NEVER write or modify code - you only analyze and provide testing guidance +- You focus on user-facing functionality and workflows +- You always provide specific, actionable testing steps +- You consider the full user journey and realistic usage scenarios +- You validate that your recommendations align with SQLMesh's architecture and patterns + +When analyzing changes, assume you're looking at a recent PR or set of code modifications. Focus on providing comprehensive testing guidance that ensures the changes work correctly, don't break existing functionality, and provide a good user experience across different scenarios and environments. \ No newline at end of file diff --git a/.claude/agents/technical-writer.md b/.claude/agents/technical-writer.md new file mode 100644 index 0000000000..7e8be9b928 --- /dev/null +++ b/.claude/agents/technical-writer.md @@ -0,0 +1,56 @@ +--- +name: technical-writer +description: Use this agent PROACTIVELY when you need to create, update, or maintain technical documentation for SQLMesh. Examples include: writing user guides for virtual environments, creating API documentation for new features, updating existing docs after code changes, writing deep-dive technical explanations of core concepts like fingerprinting or state sync, creating migration guides for users upgrading between versions, or documenting new engine adapter implementations. This agent should be used proactively when code changes affect user-facing functionality or when new features need documentation. 
+model: sonnet +color: white +--- + +You are a Technical Documentation Specialist with deep expertise in SQLMesh's architecture, concepts, and codebase. You possess comprehensive knowledge of data transformation frameworks, SQL engines, and developer tooling, combined with exceptional technical writing skills. + +Your core responsibilities: + +## Documentation Maintenance & Creation + +- Maintain existing documentation by identifying outdated content, broken links, and missing information +- Create new documentation pages that align with SQLMesh's documentation structure and style +- Ensure all documentation follows consistent formatting, terminology, and organizational patterns +- Update documentation proactively when code changes affect user-facing functionality + +### Editing + +- When editing files make sure to not leave any whitespaces + +## Multi-Audience Writing + +- Write clear, accessible guides for less technical users (data analysts, business users) focusing on practical workflows and concepts +- Create comprehensive deep-dives for technical users (data engineers, platform engineers) covering architecture, implementation details, and advanced configurations +- Adapt your writing style, depth, and examples based on the target audience's technical expertise + +## SQLMesh Expertise + +- Demonstrate deep understanding of SQLMesh's core concepts: virtual environments, fingerprinting, state sync, plan/apply workflows, incremental processing, and multi-dialect support +- Accurately explain complex technical concepts like model versioning, virtual data environments, state migration, and data intervals +- Reference appropriate code examples from the codebase when illustrating concepts +- Understand the relationship between SQLMesh components and how they work together + +## Quality Standards + +- Ensure technical accuracy by cross-referencing code implementation and existing documentation +- Include practical examples, code snippets, and real-world use cases +- 
Structure content with clear headings, bullet points, and logical flow +- Provide troubleshooting guidance and common pitfall warnings where relevant +- Include relevant CLI commands, configuration examples, and best practices + +## Documentation Types You Excel At + +- User guides and tutorials for specific workflows +- API documentation and reference materials +- Architecture explanations and system overviews +- Migration guides and upgrade instructions +- Troubleshooting guides and FAQ sections +- Integration guides for external tools and systems + +When creating documentation, always consider the user's journey and provide the right level of detail for their needs. For less technical users, focus on what they need to accomplish and provide step-by-step guidance. For technical users, include implementation details, configuration options, and architectural context. Always validate technical accuracy against the actual codebase and existing documentation patterns. + +IMPORTANT: You SHOULD NEVER edit any code. Make sure you only change files in the `docs/` folder. 
+ diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 0000000000..5acbdac5d3 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,10 @@ +version: 2 +updates: + - package-ecosystem: 'npm' + directory: '/' + schedule: + interval: 'weekly' + - package-ecosystem: 'github-actions' + directory: '/' + schedule: + interval: 'weekly' diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md new file mode 100644 index 0000000000..7585f0ce10 --- /dev/null +++ b/.github/pull_request_template.md @@ -0,0 +1,16 @@ +## Description + + + +## Test Plan + + + +## Checklist + +- [ ] I have run `make style` and fixed any issues +- [ ] I have added tests for my changes (if applicable) +- [ ] All existing tests pass (`make fast-test`) +- [ ] My commits are signed off (`git commit -s`) per the [DCO](DCO) + + diff --git a/.github/scripts/get_scm_version.py b/.github/scripts/get_scm_version.py new file mode 100644 index 0000000000..79dfee9e5d --- /dev/null +++ b/.github/scripts/get_scm_version.py @@ -0,0 +1,4 @@ +from setuptools_scm import get_version + +version = get_version(root='../../', relative_to=__file__) +print(version.split('+')[0]) diff --git a/.github/workflows/dco.yml b/.github/workflows/dco.yml new file mode 100644 index 0000000000..a1c4e07300 --- /dev/null +++ b/.github/workflows/dco.yml @@ -0,0 +1,17 @@ +name: Sanity check +on: [pull_request] + +jobs: + commits_check_job: + runs-on: ubuntu-latest + name: Commits Check + steps: + - name: Get PR Commits + id: 'get-pr-commits' + uses: tim-actions/get-pr-commits@master + with: + token: ${{ secrets.GITHUB_TOKEN }} + - name: DCO Check + uses: tim-actions/dco@master + with: + commits: ${{ steps.get-pr-commits.outputs.commits }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml new file mode 100644 index 0000000000..69e93635dc --- /dev/null +++ b/.github/workflows/pr.yaml @@ -0,0 +1,148 @@ +on: + push: + branches: + - main + pull_request: + branches: + - 
main +concurrency: + group: 'pr-${{ github.event.pull_request.number }}' + cancel-in-progress: true +permissions: + contents: read +jobs: + test-vscode: + env: + PLAYWRIGHT_SKIP_BROWSER_DOWNLOAD: 1 + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v5 + - uses: actions/setup-node@v6 + with: + node-version: '22' + - uses: pnpm/action-setup@v4 + with: + version: latest + - name: Install dependencies + run: pnpm install + - name: Run CI + run: pnpm run ci + test-vscode-e2e: + runs-on: + labels: [ubuntu-2204-8] + # As at 2026-01-12 this job flakes 100% of the time. It needs investigation + if: false + steps: + - uses: actions/checkout@v5 + - uses: actions/setup-node@v6 + with: + node-version: '22' + - uses: pnpm/action-setup@v4 + with: + version: latest + - name: Install dependencies + run: pnpm install + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.12' + - name: Install uv + uses: astral-sh/setup-uv@v7 + - name: Install python dependencies + run: | + python -m venv .venv + source .venv/bin/activate + make install-dev + - name: Install code-server + run: curl -fsSL https://code-server.dev/install.sh | sh + - name: Install Playwright browsers + working-directory: ./vscode/extension + run: pnpm exec playwright install + - name: Run e2e tests + working-directory: ./vscode/extension + timeout-minutes: 30 + run: | + source ../../.venv/bin/activate + pnpm run test:e2e + - uses: actions/upload-artifact@v5 + if: ${{ !cancelled() }} + with: + name: playwright-report + path: vscode/extension/playwright-report/ + retention-days: 30 + test-dbt-versions: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + dbt-version: ['1.3', '1.4', '1.5', '1.6', '1.7', '1.8', '1.9', '1.10'] + steps: + - uses: actions/checkout@v5 + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.10' + - name: Install uv + uses: astral-sh/setup-uv@v7 + - name: Install SQLMesh dev dependencies + run: | + uv venv .venv + 
source .venv/bin/activate + UV=1 make install-dev-dbt-${{ matrix.dbt-version }} + - name: Run dbt tests + # We can't run slow tests across all engines due to tests requiring DuckDB and old versions + # of DuckDB require a version of DuckDB we no longer support + run: | + source .venv/bin/activate + + # Remove semantic_models and metrics sections for DBT versions < 1.6.0 + # Using explicit list to avoid version comparison issues + if [[ "${{ matrix.dbt-version }}" == "1.3" ]] || \ + [[ "${{ matrix.dbt-version }}" == "1.4" ]] || \ + [[ "${{ matrix.dbt-version }}" == "1.5" ]]; then + + echo "DBT version is ${{ matrix.dbt-version }} (< 1.6.0), removing semantic_models and metrics sections..." + + schema_file="tests/fixtures/dbt/sushi_test/models/schema.yml" + if [[ -f "$schema_file" ]]; then + echo "Modifying $schema_file..." + + # Create a temporary file + temp_file=$(mktemp) + + # Use awk to remove semantic_models and metrics sections + awk ' + /^semantic_models:/ { in_semantic=1; next } + /^metrics:/ { in_metrics=1; next } + /^[^ ]/ && (in_semantic || in_metrics) { + in_semantic=0; + in_metrics=0 + } + !in_semantic && !in_metrics { print } + ' "$schema_file" > "$temp_file" + + # Move the temp file back + mv "$temp_file" "$schema_file" + + echo "Successfully removed semantic_models and metrics sections" + else + echo "Schema file not found at $schema_file, skipping..." + fi + else + echo "DBT version is ${{ matrix.dbt-version }} (>= 1.6.0), keeping semantic_models and metrics sections" + fi + + make dbt-fast-test + - name: Test SQLMesh info in sushi_dbt + working-directory: ./examples/sushi_dbt + run: | + source ../../.venv/bin/activate + sed -i 's/target: in_memory/target: postgres/g' profiles.yml + if [[ $(echo -e "${{ matrix.dbt-version }}\n1.5.0" | sort -V | head -n1) == "${{ matrix.dbt-version }}" ]] && [[ "${{ matrix.dbt-version }}" != "1.5.0" ]]; then + echo "DBT version is ${{ matrix.dbt-version }} (< 1.5.0), removing version parameters..." 
+ sed -i -e 's/, version=1) }}/) }}/g' -e 's/, v=1) }}/) }}/g' models/top_waiters.sql + else + echo "DBT version is ${{ matrix.dbt-version }} (>= 1.5.0), keeping version parameters" + fi + + sqlmesh info --skip-connection diff --git a/.github/workflows/private-repo-test.yaml b/.github/workflows/private-repo-test.yaml new file mode 100644 index 0000000000..9b2365f48a --- /dev/null +++ b/.github/workflows/private-repo-test.yaml @@ -0,0 +1,97 @@ +name: Private Repo Testing + +on: + pull_request_target: + branches: + - main + +concurrency: + group: 'private-test-${{ github.event.pull_request.number }}' + cancel-in-progress: true + +permissions: + contents: read + +jobs: + trigger-private-test: + runs-on: ubuntu-latest + steps: + - name: Checkout code + uses: actions/checkout@v5 + with: + fetch-depth: 0 + ref: ${{ github.event.pull_request.head.sha || github.ref }} + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: '3.12' + - name: Install uv + uses: astral-sh/setup-uv@v7 + - name: Set up Node.js for UI build + uses: actions/setup-node@v6 + with: + node-version: '20' + - name: Install pnpm + uses: pnpm/action-setup@v4 + with: + version: latest + - name: Install UI dependencies + run: pnpm install + - name: Build UI + run: pnpm --prefix web/client run build + - name: Install Python dependencies + run: | + python -m venv .venv + source .venv/bin/activate + pip install build twine setuptools_scm + - name: Generate development version + id: version + run: | + source .venv/bin/activate + # Generate a PEP 440 compliant unique version including run attempt + BASE_VERSION=$(python .github/scripts/get_scm_version.py) + COMMIT_SHA=$(git rev-parse --short HEAD) + # Use PEP 440 compliant format: base.devN+pr.sha.attempt + UNIQUE_VERSION="${BASE_VERSION}+pr${{ github.event.pull_request.number }}.${COMMIT_SHA}.run${{ github.run_attempt }}" + echo "version=$UNIQUE_VERSION" >> $GITHUB_OUTPUT + echo "Generated unique version with run attempt: 
$UNIQUE_VERSION" + - name: Build package + env: + SETUPTOOLS_SCM_PRETEND_VERSION: ${{ steps.version.outputs.version }} + run: | + source .venv/bin/activate + python -m build + - name: Configure PyPI for private repository + env: + TOBIKO_PRIVATE_PYPI_URL: ${{ secrets.TOBIKO_PRIVATE_PYPI_URL }} + TOBIKO_PRIVATE_PYPI_KEY: ${{ secrets.TOBIKO_PRIVATE_PYPI_KEY }} + run: ./.circleci/update-pypirc.sh + - name: Publish to private PyPI + run: | + source .venv/bin/activate + python -m twine upload -r tobiko-private dist/* + - name: Publish Python Tests package + env: + SETUPTOOLS_SCM_PRETEND_VERSION: ${{ steps.version.outputs.version }} + run: | + source .venv/bin/activate + unset TWINE_USERNAME TWINE_PASSWORD && make publish-tests + - name: Get GitHub App token + id: get_token + uses: actions/create-github-app-token@v2 + with: + private-key: ${{ secrets.TOBIKO_RENOVATE_BOT_PRIVATE_KEY }} + app-id: ${{ secrets.TOBIKO_RENOVATE_BOT_APP_ID }} + owner: ${{ secrets.PRIVATE_REPO_OWNER }} + - name: Trigger private repository workflow + uses: convictional/trigger-workflow-and-wait@v1.6.5 + with: + owner: ${{ secrets.PRIVATE_REPO_OWNER }} + repo: ${{ secrets.PRIVATE_REPO_NAME }} + github_token: ${{ steps.get_token.outputs.token }} + workflow_file_name: ${{ secrets.PRIVATE_WORKFLOW_FILE }} + client_payload: | + { + "package_version": "${{ steps.version.outputs.version }}", + "pr_number": "${{ github.event.pull_request.number }}" + } diff --git a/.github/workflows/release_extension.yaml b/.github/workflows/release_extension.yaml new file mode 100644 index 0000000000..ed46d40d47 --- /dev/null +++ b/.github/workflows/release_extension.yaml @@ -0,0 +1,56 @@ +name: Release VSCode Extension +on: + workflow_dispatch: + inputs: + version: + description: 'Version to release (e.g., 1.0.0)' + required: true + type: string +jobs: + release: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v5 + - name: Check branch is main + run: | + if [[ "${{ github.ref 
}}" != "refs/heads/main" ]]; then + echo "Error: This workflow can only be run from the main branch" + exit 1 + fi + echo "Branch check passed: running from main branch" + - name: Validate version format + run: | + version="${{ github.event.inputs.version }}" + if ! [[ $version =~ ^[0-9]+\.[0-9]+\.[0-9]+(-[0-9A-Za-z-]+(\.[0-9A-Za-z-]+)*)?(\+[0-9A-Za-z-]+(\.[0-9A-Za-z-]+)*)?$ ]]; then + echo "Error: Version must be a valid semantic version (e.g., 1.0.0, 1.0.0-beta.1, 1.0.0+build.1)" + exit 1 + fi + echo "Version format is valid: $version" + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: '20' + - name: Install pnpm + uses: pnpm/action-setup@v4 + with: + version: 10 + - name: Install dependencies + run: pnpm install --frozen-lockfile + - name: Update package.json version + working-directory: vscode/extension + run: | + npm version ${{ github.event.inputs.version }} --no-git-tag-version + - name: Build extension + working-directory: vscode/extension + run: pnpm run vscode:package + - name: Upload extension to Marketplace + working-directory: vscode/extension + run: | + pnpx vsce publish --packagePath sqlmesh-${{ github.event.inputs.version }}.vsix + env: + VSCE_PAT: ${{ secrets.VSCE_PAT }} + - name: Upload extension to OpenVSX + working-directory: vscode/extension + run: | + pnpx ovsx publish -p ${{ secrets.OPEN_VSX_TOKEN }} sqlmesh-${{ github.event.inputs.version }}.vsix diff --git a/.github/workflows/release_shared_js.yaml b/.github/workflows/release_shared_js.yaml new file mode 100644 index 0000000000..eb68163739 --- /dev/null +++ b/.github/workflows/release_shared_js.yaml @@ -0,0 +1,58 @@ +name: Release web common code +on: + workflow_dispatch: + inputs: + version: + description: 'Version to release (e.g., 1.0.0)' + required: true + type: string +permissions: + id-token: write + contents: read +jobs: + release: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v5 + - name: Check branch is main + 
run: | + if [[ "${{ github.ref }}" != "refs/heads/main" ]]; then + echo "Error: This workflow can only be run from the main branch" + exit 1 + fi + echo "Branch check passed: running from main branch" + - name: Validate version format + run: | + version="${{ github.event.inputs.version }}" + if ! [[ $version =~ ^[0-9]+\.[0-9]+\.[0-9]+(-[0-9A-Za-z-]+(\.[0-9A-Za-z-]+)*)?(\+[0-9A-Za-z-]+(\.[0-9A-Za-z-]+)*)?$ ]]; then + echo "Error: Version must be a valid semantic version (e.g., 1.0.0, 1.0.0-beta.1, 1.0.0+build.1)" + exit 1 + fi + echo "Version format is valid: $version" + - name: Setup Node.js + uses: actions/setup-node@v6 + with: + node-version: '20' + registry-url: 'https://registry.npmjs.org' + - name: Update npm + run: npm install -g npm@latest + - name: Print npm version + run: npm --version + - name: Install pnpm + uses: pnpm/action-setup@v4 + with: + version: 10 + - name: Install dependencies + run: pnpm install --frozen-lockfile + - name: Update package.json version + working-directory: web/common + run: | + npm version ${{ github.event.inputs.version }} --no-git-tag-version + - name: Build package + working-directory: web/common + run: pnpm run build + - name: Publish to npm + working-directory: web/common + run: | + npm publish diff --git a/.gitignore b/.gitignore index 6251d93e8b..16593984dd 100644 --- a/.gitignore +++ b/.gitignore @@ -50,6 +50,7 @@ coverage.xml *.py,cover .hypothesis/ .pytest_cache/ +test-results/ # Translations *.mo @@ -107,6 +108,7 @@ venv/ ENV/ env.bak/ venv.bak/ +venv*/ # Spyder project settings .spyderproject @@ -136,14 +138,11 @@ dmypy.json *~ *# -# Airflow example -examples/airflow/Dockerfile -examples/airflow/docker-compose.yaml -examples/airflow/airflow.sh -examples/airflow/.env -examples/airflow/logs -examples/airflow/plugins -examples/airflow/warehouse +# Vim +*.swp +*.swo +.null-ls* + *.duckdb *.duckdb.wal @@ -162,3 +161,7 @@ tests/_version.py # spark metastore_db/ spark-warehouse/ + +# claude +.claude/ + diff --git a/.nvmrc 
b/.nvmrc new file mode 100644 index 0000000000..209e3ef4b6 --- /dev/null +++ b/.nvmrc @@ -0,0 +1 @@ +20 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3057e365ec..bb63cf1be1 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,7 +7,7 @@ repos: language: python types_or: [python, pyi] require_serial: true - files: &files ^(sqlmesh/|tests/|web/|examples/|setup.py) + files: &files ^(sqlmesh/|sqlmesh_dbt/|tests/|web/|examples/|setup.py) - id: ruff-format name: ruff-format entry: ruff format --force-exclude --line-length 100 @@ -23,35 +23,8 @@ repos: files: *files require_serial: true exclude: ^(tests/fixtures) - - repo: https://github.com/pre-commit/mirrors-prettier - rev: "fc26039" - hooks: - - id: prettier - name: prettier - files: ^(web/client) - entry: prettier --write --ignore-path web/client/.prettierignore - exclude: ^(web/client/node_modules) - require_serial: true - language: node - - repo: https://github.com/pre-commit/mirrors-eslint - rev: "4620ec5" - hooks: - - id: eslint - name: eslint - files: ^(web/client) - exclude: ^(web/client/node_modules) - entry: eslint --fix - additional_dependencies: - [ - "@typescript-eslint/eslint-plugin@6.5.0", - "@typescript-eslint/parser@6.5.0", - eslint@8.48.0, - eslint-config-prettier@9.0.0, - eslint-config-standard-with-typescript@39.0.0, - eslint-plugin-import@2.28.1, - eslint-plugin-n@16.0.2, - eslint-plugin-promise@6.1.1, - eslint-plugin-react@7.33.2, - ] - require_serial: true - language: node + - id: valid migrations + name: valid migrations + entry: tooling/validating_migration_numbers.sh + language: system + pass_filenames: false diff --git a/.prettierignore b/.prettierignore new file mode 100644 index 0000000000..78cf56de58 --- /dev/null +++ b/.prettierignore @@ -0,0 +1,44 @@ +web/client/**/*.py +web/client/.prettierignore +web/client/.gitignore +web/client/node_modules/ +web/client/test-results/ +web/client/playwright-report/ +web/client/playwright/.cache/ 
+web/client/dist +web/client/public/favicons/ +web/client/public/fonts/ +web/client/src/styles/fonts/ +web/client/src/assets/fonts/ +web/client/tsconfig.tsbuildinfo +web/client/src/utils/tbk-components.js + +node_modules/ +vscode/extension/node_modules/ +vscode/extension/dist +vscode/extension/out +vscode/extension/src_react +vscode/extension/tsconfig.tsbuildinfo +vscode/extension/.vscode-test/ +vscode/extension/playwright-report/ +vscode/extension/test-results/ +vscode/extension/.test_setup + +sqlmesh +docs +/tests/** +examples +posts +.circleci +README.md +mkdocs.yml +.readthedocs.yaml +.pre-commit-config.yaml +package-lock.json +**/*.md +.ruff_cache +.pytest_cache +.venv +.vscode +build +pnpm-lock.yaml \ No newline at end of file diff --git a/web/client/.prettierrc.js b/.prettierrc.cjs similarity index 100% rename from web/client/.prettierrc.js rename to .prettierrc.cjs diff --git a/.readthedocs.yaml b/.readthedocs.yaml index dfdc4ce507..68d856c589 100644 --- a/.readthedocs.yaml +++ b/.readthedocs.yaml @@ -3,10 +3,10 @@ version: 2 build: os: ubuntu-22.04 tools: - python: "3.8" + python: "3.10" jobs: pre_build: - - pip install -e . + - pip install -e ".[athena,azuresql,bigframes,bigquery,clickhouse,databricks,dbt,dlt,gcppostgres,github,llm,mssql,mysql,mwaa,postgres,redshift,slack,snowflake,trino,web,risingwave]" - make api-docs mkdocs: diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000000..a7f86098d1 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,354 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Agent-Based Development Workflow + +Every time the user requests a feature or bug fix, you MUST follow the process below: + +### Development Process + +1. **Understanding The Task**: Use the `developer` agent to understand what the user is asking for and to read GitHub issues +2. 
**Feature Development & Bug Fixes**: Use the `developer` agent for implementing features and fixing bugs. IMPORTANT: Always begin by writing a failing test (or tests) that reflects the expected behavior +3. **Code Review**: After development work, invoke the `code-reviewer` agent to review the implementation +4. **Iteration**: Use the `developer` agent again to address feedback from the code reviewer +5. **Repeat**: Continue the developer → code-reviewer cycle until no more feedback remains +6. **Documentation**: If the feature or bug fix requires documentation updates, invoke the `technical-writer` agent + +IMPORTANT: Make sure to share the project overview, architecture overview, and other concepts outlined below with the agent when it is invoked. + +### Agent Responsibilities + +**Developer Agent**: +- Understands a feature request or a reported issue +- Implements new features following SQLMesh's architecture patterns +- Fixes bugs with proper understanding of the codebase +- Writes comprehensive tests following SQLMesh's testing conventions +- Follows established code style and conventions + +**Code-Reviewer Agent**: +- Reviews implementation for quality and architectural compliance +- Identifies potential issues, edge cases, and improvements +- Ensures adherence to SQLMesh patterns and best practices +- Validates test coverage and quality + +**Technical-Writer Agent**: +- Creates and updates user-facing documentation +- Writes API documentation for new features +- Updates existing docs after code changes +- Creates migration guides and deep-dive technical explanations + +## Project Overview + +SQLMesh is a next-generation data transformation framework that enables: +- Virtual data environments for isolated development without warehouse costs +- Plan/apply workflow (like Terraform) for safe deployments +- Multi-dialect SQL support with automatic transpilation +- Incremental processing to run only necessary transformations +- Built-in testing and CI/CD 
integration + +**Requirements**: Python >= 3.9 (Note: Python 3.13+ is not yet supported) + +## Essential Commands + +### Environment setup +```bash +# Create and activate a Python virtual environment (Python >= 3.9, < 3.13) +python -m venv .venv +source ./.venv/bin/activate # On Windows: .venv\Scripts\activate + +# Install development dependencies +make install-dev + +# Setup pre-commit hooks (important for code quality) +make install-pre-commit +``` + +### Common Development Tasks +```bash +# Run linters and formatters (ALWAYS run before committing) +make style + +# Fast tests for quick feedback during development +make fast-test + +# Slow tests for comprehensive coverage +make slow-test + +# Run specific test file +pytest tests/core/test_context.py -v + +# Run tests with specific marker +pytest -m "not slow and not docker" -v + +# Build package +make package + +# Serve documentation locally +make docs-serve +``` + +### Engine-Specific Testing +```bash +# DuckDB (default, no setup required) +make duckdb-test + +# Other engines require credentials/Docker +make snowflake-test # Needs SNOWFLAKE_* env vars +make bigquery-test # Needs GOOGLE_APPLICATION_CREDENTIALS +make databricks-test # Needs DATABRICKS_* env vars +``` + +### UI Development +```bash +# In web/client directory +pnpm run dev # Start development server +pnpm run build # Production build +pnpm run test # Run tests + +# Docker-based UI +make ui-up # Start UI in Docker +make ui-down # Stop UI +``` + +## Architecture Overview + +### Core Components + +**sqlmesh/core/context.py**: The main Context class orchestrates all SQLMesh operations. This is the entry point for understanding how models are loaded, plans are created, and executions happen. + +**sqlmesh/core/model/**: Model definitions and kinds (FULL, INCREMENTAL_BY_TIME_RANGE, SCD_TYPE_2, etc.). Each model kind has specific behaviors for how data is processed. + +**sqlmesh/core/snapshot/**: The versioning system. 
Snapshots are immutable versions of models identified by fingerprints. Understanding snapshots is crucial for how SQLMesh tracks changes. + +**sqlmesh/core/plan/**: Plan building and evaluation logic. Plans determine what changes need to be applied and in what order. + +**sqlmesh/core/engine_adapter/**: Database engine adapters provide a unified interface across 16+ SQL engines. Each adapter handles engine-specific SQL generation and execution. + +### Key Concepts + +1. **Virtual Environments**: Lightweight branches that share unchanged data between environments, reducing storage costs and deployment time. + +2. **Fingerprinting**: Models are versioned using content-based fingerprints. Any change to a model's logic creates a new version. + +3. **State Sync**: Manages metadata across different backends (can be stored in the data warehouse or external databases). + +4. **Intervals**: Time-based partitioning system for incremental models, tracking what data has been processed. + +## Important Files + +- `sqlmesh/core/context.py`: Main orchestration class +- `examples/sushi/`: Reference implementation used in tests +- `web/server/main.py`: Web UI backend entry point +- `web/client/src/App.tsx`: Web UI frontend entry point +- `vscode/extension/src/extension.ts`: VSCode extension entry point + +## GitHub CI/CD Bot Architecture + +SQLMesh includes a GitHub CI/CD bot integration that automates data transformation workflows. The implementation is located in `sqlmesh/integrations/github/` and follows a clean architectural pattern. 
+ +### Code Organization + +**Core Integration Files:** +- `sqlmesh/cicd/bot.py`: Main CLI entry point (`sqlmesh_cicd` command) +- `sqlmesh/integrations/github/cicd/controller.py`: Core bot orchestration logic +- `sqlmesh/integrations/github/cicd/command.py`: Individual command implementations +- `sqlmesh/integrations/github/cicd/config.py`: Configuration classes and validation + +### Architecture Pattern + +The bot follows a **Command Pattern** architecture: + +1. **CLI Layer** (`bot.py`): Handles argument parsing and delegates to controllers +2. **Controller Layer** (`controller.py`): Orchestrates workflow execution and manages state +3. **Command Layer** (`command.py`): Implements individual operations (test, deploy, plan, etc.) +4. **Configuration Layer** (`config.py`): Manages bot configuration and validation + +### Key Components + +**GitHubCICDController**: Main orchestrator that: +- Manages GitHub API interactions via PyGithub +- Coordinates workflow execution across different commands +- Handles error reporting through GitHub Check Runs +- Manages PR comment interactions and status updates + +**Command Implementations**: +- `run_tests()`: Executes unit tests with detailed reporting +- `update_pr_environment()`: Creates/updates virtual PR environments +- `gen_prod_plan()`: Generates production deployment plans +- `deploy_production()`: Handles production deployments +- `check_required_approvers()`: Validates approval requirements + +**Configuration Management**: +- Uses Pydantic models for type-safe configuration +- Supports both YAML config files and environment variables +- Validates bot settings and user permissions +- Handles approval workflows and deployment triggers + +### Integration with Core SQLMesh + +The bot leverages core SQLMesh components: +- **Context**: Uses SQLMesh Context for project operations +- **Plan/Apply**: Integrates with SQLMesh's plan generation and application +- **Virtual Environments**: Creates isolated PR environments using 
SQLMesh's virtual data environments +- **State Sync**: Manages metadata synchronization across environments +- **Testing Framework**: Executes SQLMesh unit tests and reports results + +### Error Handling and Reporting + +- **GitHub Check Runs**: Creates detailed status reports for each workflow step +- **PR Comments**: Provides user-friendly feedback on failures and successes +- **Structured Logging**: Uses SQLMesh's logging framework for debugging +- **Exception Handling**: Graceful handling of GitHub API failures and SQLMesh errors + +## Environment Variables for Engine Testing + +When running engine-specific tests, these environment variables are required: + +- **Snowflake**: `SNOWFLAKE_ACCOUNT`, `SNOWFLAKE_WAREHOUSE`, `SNOWFLAKE_DATABASE`, `SNOWFLAKE_USER`, `SNOWFLAKE_PASSWORD` +- **BigQuery**: `BIGQUERY_KEYFILE` or `GOOGLE_APPLICATION_CREDENTIALS` +- **Databricks**: `DATABRICKS_CATALOG`, `DATABRICKS_SERVER_HOSTNAME`, `DATABRICKS_HTTP_PATH`, `DATABRICKS_ACCESS_TOKEN`, `DATABRICKS_CONNECT_VERSION` +- **Redshift**: `REDSHIFT_HOST`, `REDSHIFT_USER`, `REDSHIFT_PASSWORD`, `REDSHIFT_DATABASE` +- **Athena**: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `ATHENA_S3_WAREHOUSE_LOCATION` +- **ClickHouse Cloud**: `CLICKHOUSE_CLOUD_HOST`, `CLICKHOUSE_CLOUD_USERNAME`, `CLICKHOUSE_CLOUD_PASSWORD` + +## Migrations System + +SQLMesh uses a migration system to evolve its internal state database schema and metadata format. The migrations handle changes to SQLMesh's internal structure, not user data transformations. 
+ +### Migration Structure + +**Location**: `sqlmesh/migrations/` - Contains 80+ migration files from v0001 to v0083+ + +**Naming Convention**: `v{XXXX}_{descriptive_name}.py` (e.g., `v0001_init.py`, `v0083_use_sql_for_scd_time_data_type_data_hash.py`) + +**Core Infrastructure**: +- `sqlmesh/core/state_sync/db/migrator.py`: Main migration orchestrator +- `sqlmesh/utils/migration.py`: Cross-database compatibility utilities +- `sqlmesh/core/state_sync/base.py`: Auto-discovery and loading logic + +### Migration Categories + +**Schema Evolution**: +- State table creation/modification (snapshots, environments, intervals) +- Column additions/removals and index management +- Database engine compatibility fixes (MySQL/MSSQL field size limits) + +**Data Format Migrations**: +- JSON metadata structure updates (snapshot serialization changes) +- Path normalization (Windows compatibility) +- Fingerprint recalculation when SQLGlot parsing changes + +**Cleanup Operations**: +- Removing obsolete tables and unused data +- Metadata optimization and attribute cleanup + +### Key Migration Patterns + +```python +# Standard migration function signature +def migrate(state_sync, **kwargs): # type: ignore + engine_adapter = state_sync.engine_adapter + schema = state_sync.schema + # Migration logic here + +# Common operations +engine_adapter.create_state_table(table_name, columns_dict) +engine_adapter.alter_table(alter_expression) +engine_adapter.drop_table(table_name) +``` + +### State Management Integration + +**Core State Tables**: +- `_snapshots`: Model version metadata (most frequently migrated) +- `_environments`: Environment definitions +- `_versions`: Schema/SQLGlot/SQLMesh version tracking +- `_intervals`: Incremental processing metadata + +**Migration Safety**: +- Automatic backups before migration (unless `skip_backup=True`) +- Atomic database transactions for consistency +- Snapshot count validation before/after migrations +- Automatic rollback on failures + +### Migration 
Execution + +**Auto-Discovery**: Migrations are automatically loaded using `pkgutil.iter_modules()` + +**Triggers**: Migrations run automatically when: +- Schema version mismatch detected +- SQLGlot version changes require fingerprint recalculation +- Manual `sqlmesh migrate` command execution + +**Execution Flow**: +1. Version comparison (local vs remote schema) +2. Backup creation of state tables +3. Sequential migration execution (numerical order) +4. Snapshot fingerprint recalculation if needed +5. Environment updates with new snapshot references + +## dbt Integration + +SQLMesh provides native support for dbt projects, allowing users to run existing dbt projects while gaining access to SQLMesh's advanced features like virtual environments and plan/apply workflows. + +### Core dbt Integration + +**Location**: `sqlmesh/dbt/` - Complete dbt integration architecture + +**Key Components**: +- `sqlmesh/dbt/loader.py`: Main dbt project loader extending SQLMesh's base loader +- `sqlmesh/dbt/manifest.py`: dbt manifest parsing and project discovery +- `sqlmesh/dbt/adapter.py`: dbt adapter system for SQL execution and schema operations +- `sqlmesh/dbt/model.py`: dbt model configurations and materialization mapping +- `sqlmesh/dbt/context.py`: dbt project context and environment management + +### Project Conversion + +**dbt Converter**: `sqlmesh/dbt/converter/` - Tools for migrating dbt projects to SQLMesh + +**Key Features**: +- `convert.py`: Main conversion orchestration +- `jinja.py` & `jinja_transforms.py`: Jinja template and macro conversion +- Full support for dbt assets (models, seeds, sources, tests, snapshots, macros) + +**CLI Commands**: +```bash +# Initialize SQLMesh in existing dbt project +sqlmesh init -t dbt + +# Convert dbt project to SQLMesh format +sqlmesh dbt convert +``` + +### Supported dbt Features + +**Project Structure**: +- Full dbt project support (models, seeds, sources, tests, snapshots, macros) +- dbt package dependencies and version management 
+- Profile integration using existing `profiles.yml` for connections + +**Materializations**: +- All standard dbt materializations (table, view, incremental, ephemeral) +- Incremental model strategies (delete+insert, merge, insert_overwrite) +- SCD Type 2 support and snapshot strategies + +**Advanced Features**: +- Jinja templating with full macro support +- Runtime variable passing and configuration +- dbt test integration and execution +- Cross-database compatibility with SQLMesh's multi-dialect support + +### Example Projects + +**sushi_dbt**: `examples/sushi_dbt/` - Complete dbt project running with SQLMesh +**Test Fixtures**: `tests/fixtures/dbt/sushi_test/` - Comprehensive test dbt project with all asset types + +### Integration Benefits + +When using dbt with SQLMesh, you gain: +- **Virtual Environments**: Isolated development without warehouse costs +- **Plan/Apply Workflow**: Safe deployments with change previews +- **Multi-Dialect Support**: Run the same dbt project across different SQL engines +- **Advanced Testing**: Enhanced testing capabilities beyond standard dbt tests +- **State Management**: Sophisticated metadata and versioning system diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000..287a87dab5 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,5 @@ +# Code of Conduct + +SQLMesh follows the [LF Projects Code of Conduct](https://lfprojects.org/policies/code-of-conduct/). All participants in the project are expected to abide by it. + +If you believe someone is violating the code of conduct, please report it by following the instructions in the [LF Projects Code of Conduct](https://lfprojects.org/policies/code-of-conduct/). diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000..0e1d8e1c6e --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,90 @@ +# Contributing to SQLMesh + +## Welcome + +SQLMesh is a project of the Linux Foundation. 
We welcome contributions from anyone — whether you're fixing a bug, improving documentation, or proposing a new feature. + +## Technical Steering Committee (TSC) + +The TSC is responsible for technical oversight of the SQLMesh project, including coordinating technical direction, approving contribution policies, and maintaining community norms. + +Initial TSC voting members are the project's Maintainers: + +| Name | GitHub Handle | Affiliation | Role | +|---------------------|---------------|----------------|------------| +| Alexander Butler | z3z1ma | Harness | TSC Member | +| Alexander Filipchik | afilipchik | Cloud Kitchens | TSC Member | +| Reid Hooper | rhooper9711 | Benzinga | TSC Member | +| Yuki Kakegawa | StuffbyYuki | Jump.ai | TSC Member | +| Toby Mao | tobymao | Fivetran | TSC Chair | +| Alex Wilde | alexminerv | Minerva | TSC Member | + + +## Roles + +**Contributors**: Anyone who contributes code, documentation, or other technical artifacts to the project. + +**Maintainers**: Contributors who have earned the ability to modify source code, documentation, or other technical artifacts. A Contributor may become a Maintainer by majority approval of the TSC. A Maintainer may be removed by majority approval of the TSC. + +## How to Contribute + +1. Fork the repository on GitHub +2. Create a branch for your changes +3. Make your changes and commit them with a sign-off (see DCO section below) +4. Submit a pull request against the `main` branch + +File issues at [github.com/sqlmesh/sqlmesh/issues](https://github.com/sqlmesh/sqlmesh/issues). + +## Developer Certificate of Origin (DCO) + +All contributions must include a `Signed-off-by` line in the commit message per the [Developer Certificate of Origin](DCO). This certifies that you wrote the contribution or have the right to submit it under the project's open source license. 
+ +Use `git commit -s` to add the sign-off automatically: + +```bash +git commit -s -m "Your commit message" +``` + +To fix a commit that is missing the sign-off: + +```bash +git commit --amend -s +``` + +To add a sign-off to multiple commits: + +```bash +git rebase HEAD~N --signoff +``` + +## Development Setup + +See [docs/development.md](docs/development.md) for full setup instructions. Key commands: + +```bash +python -m venv .venv +source .venv/bin/activate +make install-dev +make style # Run before submitting +make fast-test # Quick test suite +``` + +## Coding Standards + +- Run `make style` before submitting a pull request +- Follow existing code patterns and conventions in the codebase +- New files should include an SPDX license header: + ```python + # SPDX-License-Identifier: Apache-2.0 + ``` + +## Pull Request Process + +- Describe your changes clearly in the pull request description +- Ensure all CI checks pass +- Include a DCO sign-off on all commits (`git commit -s`) +- Be responsive to review feedback from maintainers + +## Licensing + +Code contributions are licensed under the [Apache License 2.0](LICENSE). Documentation contributions are licensed under [Creative Commons Attribution 4.0 International (CC-BY-4.0)](https://creativecommons.org/licenses/by/4.0/). See the LICENSE file and the [technical charter](sqlmesh-technical-charter.pdf) for details. diff --git a/DCO b/DCO new file mode 100644 index 0000000000..49b8cb0549 --- /dev/null +++ b/DCO @@ -0,0 +1,34 @@ +Developer Certificate of Origin +Version 1.1 + +Copyright (C) 2004, 2006 The Linux Foundation and its contributors. + +Everyone is permitted to copy and distribute verbatim copies of this +license document, but changing it is not allowed. 
+ + +Developer's Certificate of Origin 1.1 + +By making a contribution to this project, I certify that: + +(a) The contribution was created in whole or in part by me and I + have the right to submit it under the open source license + indicated in the file; or + +(b) The contribution is based upon previous work that, to the best + of my knowledge, is covered under an appropriate open source + license and I have the right under that license to submit that + work with modifications, whether created in whole or in part + by me, under the same open source license (unless I am + permitted to submit under a different license), as indicated + in the file; or + +(c) The contribution was provided directly to me by some other + person who certified (a), (b) or (c) and I have not modified + it. + +(d) I understand and agree that this project and the contribution + are public and that a record of the contribution (including all + personal information I submit with it, including my sign-off) is + maintained indefinitely and may be redistributed consistent with + this project or the open source license(s) involved. diff --git a/GOVERNANCE.md b/GOVERNANCE.md new file mode 100644 index 0000000000..44b6bc9947 --- /dev/null +++ b/GOVERNANCE.md @@ -0,0 +1,62 @@ +# SQLMesh Project Governance + +## Overview + +SQLMesh is a Series of LF Projects, LLC. The project is governed by its [Technical Charter](sqlmesh-technical-charter.pdf) and overseen by the Technical Steering Committee (TSC). SQLMesh is a project of the [Linux Foundation](https://www.linuxfoundation.org/). 
+ +## Technical Steering Committee + +The TSC is responsible for all technical oversight of the project, including: + +- Coordinating the technical direction of the project +- Approving project or system proposals +- Organizing sub-projects and removing sub-projects +- Creating sub-committees or working groups to focus on cross-project technical issues +- Appointing representatives to work with other open source or open standards communities +- Establishing community norms, workflows, issuing releases, and security vulnerability reports +- Approving and implementing policies for contribution requirements +- Coordinating any marketing, events, or communications regarding the project + +## TSC Composition + +TSC voting members are initially the project's Maintainers as listed in [CONTRIBUTING.md](CONTRIBUTING.md). The TSC may elect a Chair from among its voting members. The Chair presides over TSC meetings and serves as the primary point of contact with the Linux Foundation. + +## Decision Making + +The project operates as a consensus-based community. When a formal vote is required: + +- Each voting TSC member receives one vote +- A quorum of 50% of voting members is required to conduct a vote +- Decisions are made by a majority of those present when quorum is met +- Electronic votes (e.g., via GitHub issues or mailing list) require a majority of all voting members to pass +- Votes that do not meet quorum or remain unresolved may be referred to the Series Manager for resolution + +## Charter Amendments + +The technical charter may be amended by a two-thirds vote of the entire TSC, subject to approval by LF Projects, LLC. + +## Reference + +The full technical charter is available at [sqlmesh-technical-charter.pdf](sqlmesh-technical-charter.pdf). 
+ +# TSC Meeting Minutes + +## 2026-03-10 — Initial TSC Meeting + +**Members present:** Toby Mao (tobymao) + +### Vote 1: Elect Toby Mao as TSC Chair +- **Motion by:** Toby Mao +- **Votes:** Toby Mao: Yes +- **Result:** Approved (1-0-0, yes-no-abstain) + +### Vote 2: Elect TSC founding members +- **Question:** Shall the following members be added to the TSC? + - Alexander Butler (z3z1ma) + - Alexander Filipchik (afilipchik) + - Reid Hooper (rhooper9711) + - Yuki Kakegawa (StuffbyYuki) + - Alex Wilde (alexminerv) +- **Motion by:** Toby Mao +- **Votes:** Toby Mao: Yes +- **Result:** Approved (1-0-0, yes-no-abstain) diff --git a/LICENSE b/LICENSE index eabfad022a..7e95724816 100644 --- a/LICENSE +++ b/LICENSE @@ -186,7 +186,7 @@ same "printed page" as the copyright notice for easier identification within third-party archives. - Copyright 2024 Tobiko Data Inc. + Copyright Contributors to the SQLMesh project Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
diff --git a/MANIFEST.in b/MANIFEST.in new file mode 100644 index 0000000000..7ecb7896bd --- /dev/null +++ b/MANIFEST.in @@ -0,0 +1,2 @@ +prune docs/ +prune posts/ diff --git a/Makefile b/Makefile index 4d1d4b2d01..e7a78de472 100644 --- a/Makefile +++ b/Makefile @@ -1,20 +1,71 @@ .PHONY: docs -install-dev: - pip3 install -e ".[dev,web,slack]" +ifdef UV + PIP := uv pip +else + PIP := pip3 +endif + +UNAME_S := $(shell uname -s) +ifeq ($(UNAME_S),Darwin) + SED_INPLACE = sed -i '' +else + SED_INPLACE = sed -i +endif -install-cicd-test: - pip3 install -e ".[dev,web,slack,cicdtest]" +install-dev: + $(PIP) install -e ".[dev,web,slack,dlt,lsp]" ./examples/custom_materializations install-doc: - pip3 install -r ./docs/requirements.txt - -install-engine-test: - pip3 install -e ".[dev,web,slack,mysql,postgres,databricks,redshift,bigquery,snowflake,trino,mssql]" + $(PIP) install -r ./docs/requirements.txt install-pre-commit: pre-commit install +install-dev-dbt-%: + @version="$*"; \ + period_count=$$(echo "$$version" | tr -cd '.' | wc -c); \ + if [ "$$period_count" -eq 0 ]; then \ + version="$${version:0:1}.$${version:1}"; \ + elif [ "$$period_count" -eq 1 ]; then \ + version="$$version.0"; \ + fi; \ + echo "Installing dbt version: $$version"; \ + cp pyproject.toml pyproject.toml.backup; \ + $(SED_INPLACE) 's/"pydantic>=2.0.0"/"pydantic"/g' pyproject.toml; \ + if [ "$$version" = "1.10.0" ]; then \ + echo "Applying special handling for dbt 1.10.0"; \ + $(SED_INPLACE) -E 's/"(dbt-core)[^"]*"/"\1~='"$$version"'"/g' pyproject.toml; \ + $(SED_INPLACE) -E 's/"(dbt-(bigquery|duckdb|snowflake|athena-community|clickhouse|redshift|trino))[^"]*"/"\1"/g' pyproject.toml; \ + $(SED_INPLACE) -E 's/"(dbt-databricks)[^"]*"/"\1~='"$$version"'"/g' pyproject.toml; \ + else \ + echo "Applying version $$version to all dbt packages"; \ + $(SED_INPLACE) -E 's/"(dbt-[^"><=~!]+)[^"]*"/"\1~='"$$version"'"/g' pyproject.toml; \ + fi; \ + if printf '%s\n' "$$version" | awk -F. 
'{ if ($$1 == 1 && (($$2 >= 3 && $$2 <= 5) || $$2 == 10)) exit 0; exit 1 }'; then \ + echo "Applying numpy<2 constraint for dbt $$version"; \ + $(SED_INPLACE) 's/"numpy"/"numpy<2"/g' pyproject.toml; \ + fi; \ + $(MAKE) install-dev; \ + if [ "$$version" = "1.6.0" ]; then \ + echo "Applying overrides for dbt 1.6.0"; \ + $(PIP) install 'pydantic>=2.0.0' 'google-cloud-bigquery==3.30.0' 'databricks-sdk==0.28.0' --reinstall; \ + fi; \ + if [ "$$version" = "1.7.0" ]; then \ + echo "Applying overrides for dbt 1.7.0"; \ + $(PIP) install 'databricks-sdk==0.28.0' --reinstall; \ + fi; \ + if [ "$$version" = "1.5.0" ]; then \ + echo "Applying overrides for dbt 1.5.0"; \ + $(PIP) install 'dbt-databricks==1.5.6' 'numpy<2' --reinstall; \ + fi; \ + if [ "$$version" = "1.3.0" ]; then \ + echo "Applying overrides for dbt $$version - upgrading google-cloud-bigquery"; \ + $(PIP) install 'google-cloud-bigquery>=3.0.0' --upgrade; \ + fi; \ + mv pyproject.toml.backup pyproject.toml; \ + echo "Restored original pyproject.toml" + style: pre-commit run --all-files @@ -22,43 +73,22 @@ py-style: SKIP=prettier,eslint pre-commit run --all-files ui-style: - SKIP=ruff,ruff-format,mypy pre-commit run --all-files + pnpm run lint doc-test: - PYTEST_PLUGINS=tests.common_fixtures pytest --doctest-modules sqlmesh/core sqlmesh/utils + python -m pytest --doctest-modules sqlmesh/core sqlmesh/utils package: - pip3 install wheel && python3 setup.py sdist bdist_wheel + $(PIP) install build && python3 -m build publish: package - pip3 install twine && python3 -m twine upload dist/* + $(PIP) install twine && python3 -m twine upload dist/* package-tests: - pip3 install wheel && python3 tests/setup.py sdist bdist_wheel + $(PIP) install build && cp pyproject.toml tests/sqlmesh_pyproject.toml && python3 -m build tests/ publish-tests: package-tests - pip3 install twine && python3 -m twine upload -r tobiko-private tests/dist/* - -develop: - python3 setup.py develop - -airflow-init: - export 
AIRFLOW_ENGINE_OPERATOR=spark && make -C ./examples/airflow init - -airflow-run: - make -C ./examples/airflow run - -airflow-stop: - make -C ./examples/airflow stop - -airflow-clean: - make -C ./examples/airflow clean - -airflow-psql: - make -C ./examples/airflow psql - -airflow-spark-sql: - make -C ./examples/airflow spark-sql + $(PIP) install twine && python3 -m twine upload -r tobiko-private tests/dist/* docs-serve: mkdocs serve @@ -70,59 +100,43 @@ api-docs-serve: python pdoc/cli.py ui-up: - docker-compose -f ./web/docker-compose.yml up --build -d && $(if $(shell which open), open http://localhost:8001, echo "Open http://localhost:8001 in your browser.") + docker compose -f ./web/docker-compose.yml up --build -d && $(if $(shell which open), open http://localhost:8001, echo "Open http://localhost:8001 in your browser.") ui-down: - docker-compose -f ./web/docker-compose.yml down + docker compose -f ./web/docker-compose.yml down ui-build: - docker-compose -f ./web/docker-compose.yml -f ./web/docker-compose.build.yml run app + docker compose -f ./web/docker-compose.yml -f ./web/docker-compose.build.yml run app clean-build: rm -rf build/ && rm -rf dist/ && rm -rf *.egg-info +clear-caches: + find . 
-type d -name ".cache" -exec rm -rf {} + 2>/dev/null && echo "Successfully removed all .cache directories" + dev-publish: ui-build clean-build publish jupyter-example: jupyter lab tests/slows/jupyter/example_outputs.ipynb -engine-up: - docker-compose -f ./tests/core/engine_adapter/docker-compose.yaml up -d +engine-up: engine-clickhouse-up engine-mssql-up engine-mysql-up engine-postgres-up engine-spark-up engine-trino-up -engine-down: - docker-compose -f ./tests/core/engine_adapter/docker-compose.yaml down +engine-down: engine-clickhouse-down engine-mssql-down engine-mysql-down engine-postgres-down engine-spark-down engine-trino-down fast-test: - pytest -n auto -m "fast and not cicdonly" + pytest -n auto -m "fast and not cicdonly" --junitxml=test-results/junit-fast-test.xml && pytest -m "isolated" && pytest -m "registry_isolation" && pytest -m "dialect_isolated" slow-test: - pytest -n auto -m "(fast or slow) and not cicdonly" + pytest -n auto -m "(fast or slow) and not cicdonly" && pytest -m "isolated" && pytest -m "registry_isolation" && pytest -m "dialect_isolated" cicd-test: - pytest -n auto -m "fast or slow" + pytest -n auto -m "fast or slow" --junitxml=test-results/junit-cicd.xml && pytest -m "isolated" && pytest -m "registry_isolation" && pytest -m "dialect_isolated" core-fast-test: - pytest -n auto -m "fast and not web and not github and not dbt and not airflow and not jupyter" + pytest -n auto -m "fast and not web and not github and not dbt and not jupyter" core-slow-test: - pytest -n auto -m "(fast or slow) and not web and not github and not dbt and not airflow and not jupyter" - -airflow-fast-test: - pytest -n auto -m "fast and airflow" - -airflow-test: - pytest -n auto -m "(fast or slow) and airflow" - -airflow-local-test: - export AIRFLOW__DATABASE__SQL_ALCHEMY_CONN=postgresql+psycopg2://airflow:airflow@localhost/airflow && \ - pytest -n 1 -m "docker and airflow" - -airflow-docker-test: - make -C ./examples/airflow docker-test - 
-airflow-local-test-with-env: develop airflow-clean airflow-init airflow-run airflow-local-test airflow-stop - -airflow-docker-test-with-env: develop airflow-clean airflow-init airflow-run airflow-docker-test airflow-stop + pytest -n auto -m "(fast or slow) and not web and not github and not dbt and not jupyter" engine-slow-test: pytest -n auto -m "(fast or slow) and engine" @@ -139,6 +153,9 @@ engine-test: dbt-test: pytest -n auto -m "dbt and not cicdonly" +dbt-fast-test: + pytest -n auto -m "dbt and fast" --reruns 3 + github-test: pytest -n auto -m "github" @@ -148,35 +165,91 @@ jupyter-test: web-test: pytest -n auto -m "web" -bigquery-test: - pytest -n auto -m "bigquery" +guard-%: + @ if [ "${${*}}" = "" ]; then \ + echo "Environment variable $* not set"; \ + exit 1; \ + fi + +engine-%-install: + $(PIP) install -e ".[dev,web,slack,lsp,${*}]" ./examples/custom_materializations + +engine-docker-%-up: + docker compose -f ./tests/core/engine_adapter/integration/docker/compose.${*}.yaml up -d + ./.circleci/wait-for-db.sh ${*} + +engine-%-up: engine-%-install engine-docker-%-up + @echo "Engine '${*}' is up and running" + +engine-%-down: + docker compose -f ./tests/core/engine_adapter/integration/docker/compose.${*}.yaml down -v + +################## +# Docker Engines # +################## + +clickhouse-test: engine-clickhouse-up + pytest -n auto -m "clickhouse" --reruns 3 --junitxml=test-results/junit-clickhouse.xml + +duckdb-test: engine-duckdb-install + pytest -n auto -m "duckdb" --reruns 3 --junitxml=test-results/junit-duckdb.xml + +mssql-test: engine-mssql-up + pytest -n auto -m "mssql" --reruns 3 --junitxml=test-results/junit-mssql.xml + +mysql-test: engine-mysql-up + pytest -n auto -m "mysql" --reruns 3 --junitxml=test-results/junit-mysql.xml + +postgres-test: engine-postgres-up + pytest -n auto -m "postgres" --reruns 3 --junitxml=test-results/junit-postgres.xml + +spark-test: engine-spark-up + pytest -n auto -m "spark" --reruns 3 
--junitxml=test-results/junit-spark.xml && pytest -n auto -m "pyspark" --reruns 3 --junitxml=test-results/junit-pyspark.xml + +trino-test: engine-trino-up + pytest -n auto -m "trino" --reruns 3 --junitxml=test-results/junit-trino.xml + +risingwave-test: engine-risingwave-up + pytest -n auto -m "risingwave" --reruns 3 --junitxml=test-results/junit-risingwave.xml + +################# +# Cloud Engines # +################# + +snowflake-test: guard-SNOWFLAKE_ACCOUNT guard-SNOWFLAKE_WAREHOUSE guard-SNOWFLAKE_DATABASE guard-SNOWFLAKE_USER engine-snowflake-install + pytest -n auto -m "snowflake" --reruns 3 --junitxml=test-results/junit-snowflake.xml -databricks-test: - pytest -n auto -m "databricks" +bigquery-test: guard-BIGQUERY_KEYFILE engine-bigquery-install + $(PIP) install -e ".[bigframes]" + pytest -n auto -m "bigquery" --reruns 3 --junitxml=test-results/junit-bigquery.xml -duckdb-test: - pytest -n auto -m "duckdb" +databricks-test: guard-DATABRICKS_CATALOG guard-DATABRICKS_SERVER_HOSTNAME guard-DATABRICKS_HTTP_PATH guard-DATABRICKS_CONNECT_VERSION engine-databricks-install + $(PIP) install 'databricks-connect==${DATABRICKS_CONNECT_VERSION}' + pytest -n auto -m "databricks" --reruns 3 --junitxml=test-results/junit-databricks.xml -mssql-test: - pytest -n auto -m "mssql" +redshift-test: guard-REDSHIFT_HOST guard-REDSHIFT_USER guard-REDSHIFT_PASSWORD guard-REDSHIFT_DATABASE engine-redshift-install + pytest -n auto -m "redshift" --reruns 3 --junitxml=test-results/junit-redshift.xml -mysql-test: - pytest -n auto -m "mysql" +clickhouse-cloud-test: guard-CLICKHOUSE_CLOUD_HOST guard-CLICKHOUSE_CLOUD_USERNAME guard-CLICKHOUSE_CLOUD_PASSWORD engine-clickhouse-install + pytest -n 1 -m "clickhouse_cloud" --reruns 3 --junitxml=test-results/junit-clickhouse-cloud.xml -postgres-test: - pytest -n auto -m "postgres" +athena-test: guard-AWS_ACCESS_KEY_ID guard-AWS_SECRET_ACCESS_KEY guard-ATHENA_S3_WAREHOUSE_LOCATION engine-athena-install + pytest -n auto -m "athena" --reruns 3 
--junitxml=test-results/junit-athena.xml -redshift-test: - pytest -n auto -m "redshift" +fabric-test: guard-FABRIC_HOST guard-FABRIC_CLIENT_ID guard-FABRIC_CLIENT_SECRET guard-FABRIC_DATABASE engine-fabric-install + pytest -n auto -m "fabric" --reruns 3 --junitxml=test-results/junit-fabric.xml -snowflake-test: - pytest -n auto -m "snowflake" +gcp-postgres-test: guard-GCP_POSTGRES_INSTANCE_CONNECTION_STRING guard-GCP_POSTGRES_USER guard-GCP_POSTGRES_PASSWORD guard-GCP_POSTGRES_KEYFILE_JSON engine-gcppostgres-install + pytest -n auto -m "gcp_postgres" --reruns 3 --junitxml=test-results/junit-gcp-postgres.xml -spark-test: - pytest -n auto -m "spark" +vscode_settings: + mkdir -p .vscode + cp -r ./tooling/vscode/*.json .vscode/ -spark-pyspark-test: - pytest -n auto -m "spark_pyspark" +vscode-generate-openapi: + python3 web/server/openapi.py --output vscode/openapi.json + pnpm run fmt + cd vscode/react && pnpm run generate:api -trino-test: - pytest -n auto -m "trino or trino_iceberg or trino_delta" +benchmark-ci: + python benchmarks/lsp_render_model_bench.py --debug-single-value diff --git a/README.md b/README.md index f652d5de6c..41f78cc138 100644 --- a/README.md +++ b/README.md @@ -1,41 +1,193 @@ -![SQLMesh logo](sqlmesh.png) +

+ SQLMesh logo +

+

SQLMesh is a project of the Linux Foundation.

-SQLMesh is a next-generation data transformation and modeling framework that is backwards compatible with dbt. It aims to be easy to use, correct, and efficient. +SQLMesh is a next-generation data transformation framework designed to ship data quickly, efficiently, and without error. Data teams can run and deploy data transformations written in SQL or Python with visibility and control at any size. -SQLMesh enables data practitioners to efficiently run and deploy data transformations written in SQL or Python. +It is more than just a [dbt alternative](https://tobikodata.com/reduce_costs_with_cron_and_partitions.html). -Although SQLMesh will make your dbt projects more efficient, reliable, and maintainable, it is more than just a [dbt alternative](https://tobikodata.com/sqlmesh_for_dbt_1.html). +

+ Architecture Diagram +

-## Select Features -* [Semantic Understanding of SQL](https://tobikodata.com/semantic-understanding-of-sql.html) - * Compile time error checking (for 10 different SQL dialects!) - * Definitions using [simply SQL](https://sqlmesh.readthedocs.io/en/stable/concepts/models/sql_models/#sql-based-definition) (no need for redundant and confusing Jinja + YAML) - * [Self documenting queries](https://tobikodata.com/metadata-everywhere.html) using native SQL Comments -* Efficiency - * Never builds a table [more than once](https://tobikodata.com/simplicity-or-efficiency-how-dbt-makes-you-choose.html) - * Partition-based [incremental models](https://tobikodata.com/correctly-loading-incremental-data-at-scale.html) -* Confidence - * Plan / Apply workflow like [Terraform](https://www.terraform.io/) to understand potential impact of changes - * Easy to use [CI/CD bot](https://sqlmesh.readthedocs.io/en/stable/integrations/github/) - * Automatic [column level lineage](https://tobikodata.com/automatically-detecting-breaking-changes-in-sql-queries.html) and data contracts - * [Unit tests](https://tobikodata.com/we-need-even-greater-expectations.html) and audits +## Core Features -For more information, check out the [website](https://sqlmesh.com) and [documentation](https://sqlmesh.readthedocs.io/en/stable/). +SQLMesh Plan Mode + +> Get instant SQL impact and context of your changes, both in the CLI and in the [SQLMesh VSCode Extension](https://sqlmesh.readthedocs.io/en/latest/guides/vscode/?h=vs+cod) + +
+ Virtual Data Environments + + * See a full diagram of how [Virtual Data Environments](https://whimsical.com/virtual-data-environments-MCT8ngSxFHict4wiL48ymz) work + * [Watch this video to learn more](https://www.youtube.com/watch?v=weJH3eM0rzc) + +
+ + * Create isolated development environments without data warehouse costs + * Plan / Apply workflow like [Terraform](https://www.terraform.io/) to understand potential impact of changes + * Easy to use [CI/CD bot](https://sqlmesh.readthedocs.io/en/stable/integrations/github/) for true blue-green deployments + +
+Efficiency and Testing + +Running this command will generate a unit test file in the `tests/` folder: `test_stg_payments.yaml` + +Runs a live query to generate the expected output of the model + +```bash +sqlmesh create_test tcloud_demo.stg_payments --query tcloud_demo.seed_raw_payments "select * from tcloud_demo.seed_raw_payments limit 5" + +# run the unit test +sqlmesh test +``` + +```sql +MODEL ( + name tcloud_demo.stg_payments, + cron '@daily', + grain payment_id, + audits (UNIQUE_VALUES(columns = ( + payment_id + )), NOT_NULL(columns = ( + payment_id + ))) +); + +SELECT + id AS payment_id, + order_id, + payment_method, + amount / 100 AS amount, /* `amount` is currently stored in cents, so we convert it to dollars */ + 'new_column' AS new_column, /* non-breaking change example */ +FROM tcloud_demo.seed_raw_payments +``` + +```yaml +test_stg_payments: +model: tcloud_demo.stg_payments +inputs: + tcloud_demo.seed_raw_payments: + - id: 66 + order_id: 58 + payment_method: coupon + amount: 1800 + - id: 27 + order_id: 24 + payment_method: coupon + amount: 2600 + - id: 30 + order_id: 25 + payment_method: coupon + amount: 1600 + - id: 109 + order_id: 95 + payment_method: coupon + amount: 2400 + - id: 3 + order_id: 3 + payment_method: coupon + amount: 100 +outputs: + query: + - payment_id: 66 + order_id: 58 + payment_method: coupon + amount: 18.0 + new_column: new_column + - payment_id: 27 + order_id: 24 + payment_method: coupon + amount: 26.0 + new_column: new_column + - payment_id: 30 + order_id: 25 + payment_method: coupon + amount: 16.0 + new_column: new_column + - payment_id: 109 + order_id: 95 + payment_method: coupon + amount: 24.0 + new_column: new_column + - payment_id: 3 + order_id: 3 + payment_method: coupon + amount: 1.0 + new_column: new_column +``` +
+ +* Never build a table [more than once](https://tobikodata.com/simplicity-or-efficiency-how-dbt-makes-you-choose.html) +* Track what data’s been modified and run only the necessary transformations for [incremental models](https://tobikodata.com/correctly-loading-incremental-data-at-scale.html) +* Run [unit tests](https://tobikodata.com/we-need-even-greater-expectations.html) for free and configure automated audits +* Run [table diffs](https://sqlmesh.readthedocs.io/en/stable/examples/sqlmesh_cli_crash_course/?h=crash#run-data-diff-against-prod) between prod and dev based on tables/views impacted by a change + +
+Level Up Your SQL +Write SQL in any dialect and SQLMesh will transpile it to your target SQL dialect on the fly before sending it to the warehouse. +Transpile Example +
+ +* Debug transformation errors *before* you run them in your warehouse in [10+ different SQL dialects](https://sqlmesh.readthedocs.io/en/stable/integrations/overview/#execution-engines) +* Definitions using [simply SQL](https://sqlmesh.readthedocs.io/en/stable/concepts/models/sql_models/#sql-based-definition) (no need for redundant and confusing `Jinja` + `YAML`) +* See impact of changes before you run them in your warehouse with column-level lineage + +For more information, check out the [documentation](https://sqlmesh.readthedocs.io/en/stable/). ## Getting Started Install SQLMesh through [pypi](https://pypi.org/project/sqlmesh/) by running: -```pip install sqlmesh``` +```bash +mkdir sqlmesh-example +cd sqlmesh-example +python -m venv .venv +source .venv/bin/activate +pip install 'sqlmesh[lsp]' # install the sqlmesh package with extensions to work with VSCode +source .venv/bin/activate # reactivate the venv to ensure you're using the right installation +sqlmesh init # follow the prompts to get started (choose DuckDB) +``` + + + +> Note: You may need to run `python3` or `pip3` instead of `python` or `pip`, depending on your python installation. + +
+Windows Installation + +```bash +mkdir sqlmesh-example +cd sqlmesh-example +python -m venv .venv +.\.venv\Scripts\Activate.ps1 +pip install 'sqlmesh[lsp]' # install the sqlmesh package with extensions to work with VSCode +.\.venv\Scripts\Activate.ps1 # reactivate the venv to ensure you're using the right installation +sqlmesh init # follow the prompts to get started (choose DuckDB) +``` +
+ + +Follow the [quickstart guide](https://sqlmesh.readthedocs.io/en/stable/quickstart/cli/) to learn how to use SQLMesh. You already have a head start! + +Follow the [crash course](https://sqlmesh.readthedocs.io/en/stable/examples/sqlmesh_cli_crash_course/) to learn the core movesets and use the easy to reference cheat sheet. + +Follow this [example](https://sqlmesh.readthedocs.io/en/stable/examples/incremental_time_full_walkthrough/) to learn how to use SQLMesh in a full walkthrough. + +## Join Our Community +Connect with us in the following ways: -Follow the [tutorial](https://sqlmesh.readthedocs.io/en/stable/quick_start/) to learn how to use SQLMesh. +* Join the [Tobiko Slack Community](https://tobikodata.com/slack) to ask questions, or just to say hi! +* File an issue on our [GitHub](https://github.com/SQLMesh/sqlmesh/issues/new) +* Send us an email at [hello@tobikodata.com](mailto:hello@tobikodata.com) with your questions or feedback +* Read our [blog](https://tobikodata.com/blog) -## Join our community -We'd love to join you on your data journey. Connect with us in the following ways: +## Contributing +We welcome contributions! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for guidelines on how to contribute, including our DCO sign-off requirement. -* Join the [Tobiko Slack community](https://tobikodata.com/slack) to ask questions, or just to say hi! -* File an issue on our [GitHub](https://github.com/TobikoData/sqlmesh/issues/new). -* Send us an email at [hello@tobikodata.com](mailto:hello@tobikodata.com) with your questions or feedback. +Please review our [Code of Conduct](CODE_OF_CONDUCT.md) and [Governance](GOVERNANCE.md) documents. -## Contribution -Contributions in the form of issues or pull requests are greatly appreciated. [Read more](https://sqlmesh.readthedocs.io/en/stable/development/) about how to develop for SQLMesh. +[Read more](https://sqlmesh.readthedocs.io/en/stable/development/) on how to set up your development environment. 
+## License +This project is licensed under the [Apache License 2.0](LICENSE). Documentation is licensed under [CC-BY-4.0](https://creativecommons.org/licenses/by/4.0/). diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000000..2ffffacea3 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,17 @@ +# Security Policy + +## Reporting a Vulnerability + +If you discover a security vulnerability in SQLMesh, please report it through [GitHub Security Advisories](https://github.com/sqlmesh/sqlmesh/security/advisories/new). Do not file a public issue for security vulnerabilities. + +## Response + +We will acknowledge receipt of your report within 72 hours and aim to provide an initial assessment within one week. + +## Disclosure + +We follow a coordinated disclosure process. We will work with you to understand and address the issue before any public disclosure. + +## Supported Versions + +Security fixes are generally applied to the latest release. Critical vulnerabilities may be backported to recent prior releases at the discretion of the maintainers. 
diff --git a/benchmarks/lsp_render_model_bench.py b/benchmarks/lsp_render_model_bench.py new file mode 100644 index 0000000000..f41f5f2d22 --- /dev/null +++ b/benchmarks/lsp_render_model_bench.py @@ -0,0 +1,118 @@ +#!/usr/bin/env python + +import asyncio +import pyperf +import os +import logging +from pathlib import Path +from lsprotocol import types + +from sqlmesh.lsp.custom import RenderModelRequest, RENDER_MODEL_FEATURE +from sqlmesh.lsp.uri import URI +from pygls.client import JsonRPCClient + +# Suppress debug logging during benchmark +logging.getLogger().setLevel(logging.WARNING) + + +class LSPClient(JsonRPCClient): + """A custom LSP client for benchmarking.""" + + def __init__(self): + super().__init__() + self.render_model_result = None + self.initialized = asyncio.Event() + + # Register handlers for notifications we expect from the server + @self.feature(types.WINDOW_SHOW_MESSAGE) + def handle_show_message(_): + # Silently ignore show message notifications during benchmark + pass + + @self.feature(types.WINDOW_LOG_MESSAGE) + def handle_log_message(_): + # Silently ignore log message notifications during benchmark + pass + + async def initialize_server(self): + """Send initialization request to server.""" + # Get the sushi example directory + sushi_dir = Path(__file__).parent.parent / "examples" / "sushi" + + response = await self.protocol.send_request_async( + types.INITIALIZE, + types.InitializeParams( + process_id=os.getpid(), + root_uri=URI.from_path(sushi_dir).value, + capabilities=types.ClientCapabilities(), + workspace_folders=[ + types.WorkspaceFolder( + uri=URI.from_path(sushi_dir).value, + name="sushi" + ) + ] + ) + ) + + # Send initialized notification + self.protocol.notify(types.INITIALIZED, types.InitializedParams()) + self.initialized.set() + return response + + +async def benchmark_render_model_async(client: LSPClient, model_path: Path): + """Benchmark the render_model request.""" + uri = URI.from_path(model_path).value + + # Send 
render_model request + result = await client.protocol.send_request_async( + RENDER_MODEL_FEATURE, + RenderModelRequest(textDocumentUri=uri) + ) + + return result + + +def benchmark_render_model(loops): + """Synchronous wrapper for the benchmark.""" + async def run(): + # Create client + client = LSPClient() + + # Start the SQLMesh LSP server as a subprocess + await client.start_io("python", "-m", "sqlmesh.lsp.main") + + # Initialize the server + await client.initialize_server() + + # Get a model file to test with + sushi_dir = Path(__file__).parent.parent / "examples" / "sushi" + model_path = sushi_dir / "models" / "customers.sql" + + # Warm up + await benchmark_render_model_async(client, model_path) + + # Run benchmark + t0 = pyperf.perf_counter() + for _ in range(loops): + await benchmark_render_model_async(client, model_path) + dt = pyperf.perf_counter() - t0 + + # Clean up + await client.stop() + + return dt + + return asyncio.run(run()) + + +def main(): + runner = pyperf.Runner() + runner.bench_time_func( + "lsp_render_model", + benchmark_render_model + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/docs/HOWTO.md b/docs/HOWTO.md new file mode 100644 index 0000000000..edd7c9833f --- /dev/null +++ b/docs/HOWTO.md @@ -0,0 +1,465 @@ +# SQLMesh Docs: Editing Guide + +You have been asked/told to work on SQLMesh's docs - congratulations! + +This document will get you set up to modify or create new SQLMesh documentation. It describes: + +- The workflow for modifying or adding docs +- How we approach writing style for the docs +- The tools used to work with docs +- How to write docs in markdown +- Configuring the docs site +- Hosting the docs on readthedocs.io + +From a technical perspective, docs modifications are just like modifications to any other code file. Therefore, they are made and approved via pull requests in the SQLMesh Github repo. 
+ +## Workflow + +When modifying or adding the docs, you will generally follow these steps: + +1. Clone the latest version of the SQLMesh Github repo +2. Locate the file to edit in the repo's `/docs` directory (or create a new file) +3. Ensure the docs tools are [set up](#setup) and working +4. Start a local version of the docs site +5. Make your changes, examining them in the local docs site +6. Create a new git branch +7. Commit your changes to the new branch +8. Push the branch to Github +9. Open a pull request + +Depending on the scale/complexity of the changes, reviews may happen in one of two ways: + +For larger changes or new pages, Trey will do a full review and editing pass. He will make edits directly in the doc file, create a new git branch, and make a PR **against your PR branch** (NOT against SQLMesh main). + +You will review the changes and provide feedback, Trey will update the doc, and you will approve the PR when you are satisfied. + +!!! important "Trey edits first" + + Because Trey will make large changes, his review and editing pass should occur BEFORE other team members spend time reviewing. + +After Trey's PR has been merged into your branch, you will receive comments/feedback from other team members in the Github PR interface. You will then make the requested changes and push them to the branch, the PR will be approved by Trey or another team member, and it can be merged. + +If your changes are smaller, Trey will not do a full edit and will provide comments/feedback in the Github PR interface like everyone else. + +### New docs + +Brand new docs pages usually require a significant amount of editing. Therefore, when drafting a new page your main focus is ensuring all the content is present, accurate, and ordered/structured sensibly. + +If you built the feature being documented, you have the most knowledge about how it works and which parts are important. Your opinion and context are critical. 
+ +Do not spend too much time wordsmithing and styling. Because so much editing will happen, language you work hard on may be removed or altered. That's demoralizing, even if it's replaced by something you agree is better (and especially if it's replaced with something worse). + +Your wordsmithing and style are important, but they should be the last step of the writing process. Doing them on the first draft does not provide a good ROI. + +## Writing style + +We do not have a written style guide, but we try to follow a few stylistic conventions. + +At a high level, think "simpler is better." + +Data engineering is complex, so SQLMesh is complex. We must focus on minimizing cognitive load for readers, while ensuring all the content is present and accurate. This is a difficult balance. + +The most important specific stylistic conventions are: + +1. Use second person voice when providing instructions + - DO: "Add an audit to the model." + - DO: "You can partition the table if necessary." +2. Use first person plural when describing actions but not providing instructions (e.g., extended example) + - DO: "First, we create a new Python environment." + - DO: "After running the plan command, we see the following output." +3. Use active voice + - DO: "SQLMesh automatically infers the table schema." + - DO NOT: "The table schema is inferred automatically by SQLMesh." +4. Prefer short sentences to long + - Not dogmatic, use your judgment +5. Liberally use code examples and graphics + - Abstract discussion is boring and difficult for people to follow +6. Liberally use headers to structure content within a page + - But don't go overboard + +## Tools + +SQLMesh docs are built with the [`MkDocs` library](https://www.mkdocs.org/). + +`MkDocs` is a static site generator that converts the files in our `/docs` directory into website files. 
When SQLMesh's Github repo has a PR merged to main, a build is triggered that converts and uploads the files to `readthedocs.io`, which then serves them to end users. + +`MkDocs` is configured in the `mkdocs.yml` configuration file, which specifies the site page hierarchy, color theme, and MkDocs plugins used (e.g., [Material for MkDocs](https://squidfunk.github.io/mkdocs-material/)). + +### Setup + +You will work on the docs in a local copy of the sqlmesh git repository. + +If you don't have a copy of the repo on your machine, open a terminal and clone it into a `sqlmesh` directory by executing: + +``` bash +git clone https://github.com/SQLMesh/sqlmesh.git +``` + +And navigate to the directory: + +``` bash +cd sqlmesh +``` + +`MkDocs` is a Python library, so we first create and activate a new virtual environment: + +```bash +python -m venv .venv +source .venv/bin/activate +``` + +We will now run three separate installation commands. + +First, we install the "pre-commit" tools that automatically validate files when the `git commit` command is run: + +```bash +make install-pre-commit +``` + +Next, we install the core SQLMesh dev dependencies: + +```bash +make install-dev +``` + +And, finally, we install `MkDocs` and other docs dependencies: + +```bash +make install-doc +``` + +The docs requirements file pins library versions, which can sometimes cause unresolvable conflicts. If you receive a "cannot find compatible versions" error for the final command, run this instead: + +```bash +pip install mkdocs mkdocs-include-markdown-plugin mkdocs-material mkdocs-material-extensions mkdocs-glightbox pdoc +``` + +### Usage + +It is helpful to run a local version of the docs site while editing or adding docs. That way you can preview how your changes will look on the SQLMesh hosted docs. + +Navigate to the `sqlmesh` directory we created before and run `mkdocs serve`: + +``` bash +> mkdocs serve + +INFO - Building documentation... 
+INFO - Cleaning site directory +INFO - Documentation built in 3.63 seconds +INFO - [16:02:59] Watching paths for changes: 'docs', 'mkdocs.yml' +INFO - [16:02:59] Serving on http://127.0.0.1:8000/ +``` + +View the docs site by navigating to `http://127.0.0.1:8000/` in a web browser. To view the HOWTO doc, navigate to `http://127.0.0.1:8000/HOWTO/`. + +The command will block the terminal in which it is run, so you must open a new terminal to do anything on the command line. + +The docs site will update in real time as you edit and save changes to the underlying files. + +## Docs markdown + +We use `MkDocs` so we can control almost all of the site's appearance and behavior with markdown. That makes it simple to maintain the docs as text files in the SQLMesh Github repo. + +This section discusses the different ways we use markdown to control the appearance and behavior of the docs site. + +### Document structure + +A docs page's structure (headers and within-page navigation) is defined by the use of markdown headers. + +A markdown header is a line that begins with between one and four hash marks `#`. The number of hash marks determines the "level" of the header, with one hash mark being the highest level. + +Every docs page must begin with a top-level header (one hash mark). This header is used as the page's title in the navigation bar. + +!!! important + + The page may only have one top-level header that begins with a single hash mark! + +Subsequent headers are used to divide the page into sections, with each level down nested within its parent (e.g., three-level `###` headers are nested within two-level `##` headers). + +A within-page table of contents bar is automatically generated from the headers of the page and displayed on the right side. 
+ +For example, the [Configuration guide's](./guides/configuration.md) navigation bar uses multiple header levels to group content: + +![Configuration guide within-page navigation bar](./readme/docs-site_within-page-nav_config-guide.png) + +### Lists + +We do not want pages that are a "wall of text," which is difficult to read and understand. Instead, use lists to break up a page and more effectively communicate its content. + +For example, if we are describing a process with multiple steps, it is clearer to use a numbered list of those steps than a separate sentence for each step. + +Similarly, any time a sentence contains a long list of items, you should consider using a bulleted list instead. + +Lists are useful for breaking up a page, but that visual distinction draws people's attention. Be careful not to use so many lists that you tip the balance from "too much text" to "too little text." + +To specify a list, put each element on its own line. Start the line with: + +- A dash `-` or asterisk `*` for a bullet list +- A number and period `1.` for a numbered list +- A letter and period `a.` for a lettered list + +!!! important "Empty line before list!" + + You must put an empty line before the first list element, or it will not render. + +We can specify a simple bullet list like this: + +``` +Here's a bullet list! + +- First item +- Second item +``` + +And it renders to this: + +Here's a bullet list! + +- First item +- Second item + +
+Or a numbered list: + +``` +1. First item +2. Second item +``` + +1. First item +2. Second item + +
+ +You can nest list items by adding 4 spaces of indentation: + +``` +- First item + - First subitem + - Second subitem +- Second item +``` + +- First item + - First subitem + - Second subitem +- Second item + +### Inline code + +Sometimes we need to display a simple code snippet inline with regular text. + +For example, we might be describing the `sqlmesh plan` command and want to differentiate the words "sqlmesh plan" from the other words. + +Do this by wrapping the code in single backticks: + +``` +I want to make sure `sqlmesh plan` looks different than the other words! +``` + +### Code blocks + +The SQLMesh docs include many examples of code or command output. These examples are displayed in special "code blocks" that display and highlight the code. + +Code blocks begin and end with three backticks ```. The code to display goes between the first and second set of backticks. + +Specify the code language next to the first set of backticks to ensure proper syntax highlighting. For example, we could specify Python highlighting like this: + +``` + ``` python + + my_result = 1 + 1 + + ``` +``` + +For terminal commands and output, specify the language as `bash`. + +Code blocks have a number of options for display, the most important of which are line numbers and highlighted lines. + +Line numbers are important for larger code blocks, making it easier for the text to reference specific parts of the code. + +Highlighted lines provide an even more direct way to draw attention to specific parts of the code. + +This figure shows examples of the different code block options: + +![Code block options](./readme/docs-site_code-block-options.png) + +### Callouts + +Callouts are used to draw attention to important points or to highlight important information. + +Use them to ensure that readers notice key points. They are particularly useful if the important point is embedded in a large section of text. 
+ +We use the "admonitions" library for callouts, and [they have 12 built-in types](https://squidfunk.github.io/mkdocs-material/reference/admonitions/#supported-types) with different icons and styles: + +- `note` +- `abstract` +- `info` +- `tip` +- `example` +- `quote` +- `success` +- `question` +- `warning` +- `failure` +- `danger` +- `bug` + +Create a callout by starting a line with three exclamation marks `!!!` and the name of the callout type you want to use. For example: + +``` +!!! note + This creates a note callout! + +``` + +And this is what that callout looks like: + +!!! note + This creates a note callout! + +By default, the callout title is its type. You can change the title by adding it in quotes after the callout type: + +``` +!!! important "Custom title" + This creates an important callout with a custom title! +``` + +!!! important "Custom title" + This creates an important callout with a custom title! + +You can make a callout collapsible by using three question marks `???` instead of exclamation marks: + +``` +??? tip + This creates a collapsible tip callout! +``` + +??? tip + This creates a collapsible tip callout! + +You can make the collapsible open by default by adding a plus sign `+` to the three question marks: + +``` +???+ warning + This creates a collapsible warning callout that is open by default! +``` + +???+ warning + This creates a collapsible warning callout that is open by default! + +### Images + +The SQLMesh docs use screenshots of output, graphics, and other images to supplement the text. + +To add an image, first create it and save it in PNG format. Save it in a folder in the directory where its doc's markdown file is located. 
+ +Add the image to a page with this markdown: + +``` +![Image's alt text](./relative/path/to/image.png) +``` + +Note that: +- The line starts with an exclamation point `!` +- Brackets containing the image's alt text come next +- The relative path to the image follows the brackets + +There may not be spaces between the exclamation point, brackets, and path. + +Specify alt text for all images. + +### Custom CSS and inline HTML + +Sometimes markdown just doesn't cut it. + +`MkDocs` supports custom CSS and inline HTML, both of which we use as necessary (but sparingly). + +For example, by default you can only link to navigation elements within a page (like section titles). We sometimes want to link to individual pieces of content, so we use inline HTML to create a custom anchor link. + +For example, [in the FAQ](./faq/faq.md#schema-question) we make a link to the "schema question" with the inline HTML ``. + +## Configuring docs + +Docs are configured in the `mkdocs.yml` file. + +The first section of the file defines high-level information about SQLMesh, such as the docs site's name and our Github repo URL/name. + +![SQLMesh mkdocs.yml configuration file](./readme/mkdocs-file.png) + +We describe subsequent sections below. + +### Site layout and navigation + +The bulk of the file defines the structure/layout of the docs site's pages under the `nav` key. + +It defines a hierarchy of pages and subpages that is reflected in the site's navigation elements (e.g., top menu, left sidebar). + +As with all YAML files, indentation plays a key role. Each level of indentation generates a new level down in the hierarchy. + +One indentation below `nav` corresponds to top-level navigation elements like the menu bar links. 
+ +Here we see links generated from the first three top-level `nav` entries `Overview`, `Get started`, and `Guides`: + +![SQLMesh docs site top-level nav](./readme/docs-site_top-level-nav.png) + +As we continue downward in the hierarchy, we may either add specific pages or new section(s) that contain subpages. SQLMesh docs use both these approaches in different places. + +For example, the `Get started` section contains 6 subpages, while the `Guides` section contains 4 sections that each specify their own subpages. + +Here we see the 6 subpages specified directly under the `Get started` entry in the lefthand navbar: + +![SQLMesh docs site 2nd level nav - Get Started section](./readme/docs-site_2nd-level-nav_get-started.png) + +And here we see the first three subsections `Project structure`, `Project setup`, and `Project content` (and their subpages) specified under the `Guides` entry: + +![SQLMesh docs site 2nd level nav - Guides section](./readme/docs-site_2nd-level-nav_guides.png) + +You may continue to add subsections as needed. At the time of writing, only the `Tobiko Cloud` section uses a 3rd level of nested sections. + +### Theme and colors + +The `theme` section defines the appearance of the docs site. It is rarely modified. + +It specifies the theme name, logo, and color palette, and configures features like the navigation bar and sidebar. + +![SQLMesh mkdocs.yml theme section](./readme/docs-site_mkdocs-theme.png) + +### Plugins/extensions + +The `plugins` and `markdown_extensions` sections specify different plugins and extensions we use to add functionality to the docs site. It is rarely modified. 
+ +Some examples: +- Plugin `glightbox` allows users to expand and zoom on images +- Markdown extension `pymdownx.tabbed` specifies how tabbed content is displayed +- Markdown extension `admonition` allows us to add callout boxes + +![SQLMesh mkdocs.yml plugins section](./readme/docs-site_mkdocs-plugins.png) + +### Extra + +The final sections of the `mkdocs.yml` file define assorted metadata about the site. + +The `extra_css` key specifies the location of the file containing custom CSS. We use this to add custom colors to some elements. It is rarely modified. + +The `extra` section specifies links embedded in the site footer and our Google Analytics ID. It is rarely modified. + +![SQLMesh mkdocs.yml extra section](./readme/docs-site_mkdocs-extra.png) + +## Docs hosting + +Our docs are built with `MkDocs`, but they are hosted on `readthedocs.io`. + +When a PR is merged to main, it triggers a docs build and deployment. That process is configured in the `.readthedocs.yaml` file. + +Readthedocs supports multiple versions of the docs, with two important versions: `stable` and `latest`. + +The `stable` docs are built from the latest Github release tag, while the `latest` docs are built from the latest commit on main. + +We have hidden the interface for accessing `latest`, so users will generally not be able to access it. 
However, you may access it by replacing the word "stable" with "latest" in a URL: + +For example, the Getting started page is at: + +- Stable: `https://sqlmesh.readthedocs.io/en/stable/quick_start/` +- Latest: `https://sqlmesh.readthedocs.io/en/latest/quick_start/` \ No newline at end of file diff --git a/docs/_readthedocs/html/favicon.svg b/docs/_readthedocs/html/favicon.svg new file mode 100644 index 0000000000..cbf6e39228 --- /dev/null +++ b/docs/_readthedocs/html/favicon.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/docs/cloud/cloud_index.md b/docs/cloud/cloud_index.md new file mode 100644 index 0000000000..aedd918a2c --- /dev/null +++ b/docs/cloud/cloud_index.md @@ -0,0 +1,46 @@ + +# Welcome to Tobiko Cloud + +[Tobiko Cloud](https://tobikodata.com/product.html) is a data transformation platform that enhances the ease and efficiency of managing data pipelines with SQLMesh. + +Tobiko Cloud is designed for companies who want to: + +- Host SQLMesh on a robust, reliable platform without building and maintaining it themselves +- Understand the status, activity, and performance of data pipelines at a glance +- Rapidly detect and debug problems with their pipelines +- Monitor cloud costs over time, by model (BigQuery and Snowflake engines only) + +![Tobiko Cloud](./cloud_index/tobiko-cloud.png) + +## How is Tobiko Cloud different from SQLMesh? + +Tobiko Cloud complements SQLMesh, supporting companies that need enterprise-level features like scalability, observability, and cost optimization. + +Here’s a comparison: + +1. **Deployment**: Tobiko Cloud simplifies SQLMesh deployment by hosting it on our infrastructure. + + It provides enterprise-grade hosting and scalability for complex data transformations, freeing teams from managing infrastructure themselves. + +2. **Observability and Insights**: Tobiko Cloud integrates deeply with SQLMesh, providing instant visibility into pipeline versions, code changes, and errors. 
+ + This allows teams to monitor their pipelines, detect changes in pipeline behavior, and rapidly trace the root causes of data issues. + +3. **Efficiency**: SQLMesh's built-in features like virtual data environments and automatic change classification reduce computational costs and improve processing speeds. + + Tobiko Cloud's enhanced change classification identifies even more scenarios where code changes don't require rerunning downstream models. + +4. **Cost monitoring**: Tobiko Cloud automatically tracks costs per model execution for BigQuery and Snowflake. + + This allows teams to rapidly detect anomalous spending and to identify the models driving cloud costs. + +## Learn more + +Ready to unlock a faster, smarter, and more efficient way to manage your data pipelines? Book a call with the Tobiko Cloud team today! + +Discover how Tobiko's managed SQLMesh platform will empower your team to scale effortlessly, optimize costs, and deliver accurate data faster — all while freeing your team from infrastructure headaches. + +Whether you're a data engineer or a decision-maker, Tobiko Cloud gives you data transformation without the waste. Let's talk! + +
+ diff --git a/docs/cloud/cloud_index/tobiko-cloud.png b/docs/cloud/cloud_index/tobiko-cloud.png new file mode 100644 index 0000000000..ed2ed69d95 Binary files /dev/null and b/docs/cloud/cloud_index/tobiko-cloud.png differ diff --git a/docs/cloud/features/alerts_notifications.md b/docs/cloud/features/alerts_notifications.md new file mode 100644 index 0000000000..f8c4d0e0fc --- /dev/null +++ b/docs/cloud/features/alerts_notifications.md @@ -0,0 +1,119 @@ +# Alerts + +Nobody likes learning about a data problem from stakeholders' angry messages about broken dashboards. If something goes wrong, you want to be the first to know! + +Tobiko Cloud makes sure you hear about problems first, alerting the right people immediately when a problem occurs. + +## Configuring Alerts + +Configure alerts in the Tobiko Cloud Settings section. + +To begin, navigate to Settings from the Home screen by clicking the `Settings` link in the top left navigation menu. + +![Image highlighting location of the Setting section link](./alerts_notifications/settings_section_link.png) + +In the Settings section, navigate to the Alerts page by clicking the `Alerts` link in the top left navigation menu. + +Then add a new alert by clicking the `Add Alert` button in the top right. + +![Image highlighting location of the Add Alert button](./alerts_notifications/add_alert_button.png) + +This opens the Add New Alert configuration page. + +Specify an informative name for the alert in the Name field, and click the drop downs for when you want this to run. This is a simple `event` alert, but we'll go into more options below. + +After you're finished configuring the new alert, save it by clicking the `Save` button in the bottom right. + +![Image showing the add Alert page](./alerts_notifications/add_alert_page.png) + +## Alerts + +Tobiko Cloud sends an alert based on a *trigger*. There are two types of triggers: [events](#event-triggers) and [measures](#measure-triggers). 
+ +Events are tied to steps in the SQLMesh `plan` and `run` processes. For example, you could alert whenever a `plan` succeeded or a `run` failed. + +Choose whether the alert will be triggered by a Measure or Event in the alert's Trigger Type field. + +![Image showing the add Alert page trigger type field](./alerts_notifications/add_alert_trigger_type.png) + +### Event triggers + +Tobiko Cloud Alerts can be triggered by the following events: + +- Plan start +- Plan end +- Plan failure +- Run start +- Run end +- Run failure + +Specify an event trigger by first choosing whether it is tied to a `plan` or `run` Artifact. + +![Image showing the Artifact dropdown](./alerts_notifications/add_event_artifact.png) + +Next, choose the notification Event type: Start, Failure, or End. + +![Image showing the Event dropdown](./alerts_notifications/add_event_event.png) + +Finally, choose a Notification Target where the alert should be sent (described [below](#notification-targets)) and click the Save button in the bottom right. + +![Image showing the add event Alert page](./alerts_notifications/add_event_alert_page.png) + +### Measure triggers + +Tobiko Cloud Alerts can be triggered when a measure exceeds a threshold or meets a condition. + +To configure a measure alert, first build the condition that triggers the measure. Choose the measure of interest, the comparison operator, and a threshold value. + +![Image showing the add measure Alert page condition section](./alerts_notifications/add_measure_alert_condition.png) + +Now specify the alert Artifact field. + +Some measures, like run time, are most useful when accumulated over an entire `plan` or `run`. For example, you might want to alert whenever a `run`'s total run time is longer than four hours. + +Configure a cumulative measure alert by choosing an Artifact type of Plan or Run. + +Configure a non-cumulative measure alert by choosing an Artifact type of Measure. 
+ +![Image showing the add measure Alert page Artifact field](./alerts_notifications/add_measure_alert_artifact.png) + +To prevent alert fatigue, you can limit measure-based alerts to a specific environment or model in the optional Environment and Model fields. + +![Image showing the add measure Alert page](./alerts_notifications/add_measure_alert_page.png) + +## Notification Targets + +Each alert is sent to one or more notification targets. + +A notification target is a way for alerts to contact you. A target can be used in multiple alerts, so you only have to configure them once. + +### Notification Target Configuration + +Configure Notification targets in the Tobiko Cloud Settings section. + +To add a new notification target, navigate to the Notification Targets page and click the Add Notification Target button in the top right. + +![Image highlighting location of the Add Notification Target button](./alerts_notifications/add_notification_target_button.png) + +Then enter a descriptive name for the new notification target, select its type (described [below](#notification-target-types)), fill in the configuration information, and click Save. + +![Image showing the add Notification Target page](./alerts_notifications/add_notification_target_page.png) + +### Notification target types + +Tobiko Cloud supports the following notification target types, which require you to provide different pieces of configuration information. 
+ +- Slack API + - API Token + - Format: `xoxb-[13 digits]-[13 digits]-[24 alphanumeric characters]` + - Channel ID + - Format: `T[10 capital letter or numeric characters]` + - Example: T139Z25G8F4 +- Slack Webhook + - Webhook URL + - Format: Web URL + - Example: https://hooks.slack.com/services/T00000000/B00000000/XXXXXXXXXXXXXXXXXXXXXXXX +- PagerDuty + - Routing Key + - Format: `[32 alphanumeric characters]` + - Example: j16lxprdvoy21paigybthal0llk51kh5k diff --git a/docs/cloud/features/alerts_notifications/add_alert_button.png b/docs/cloud/features/alerts_notifications/add_alert_button.png new file mode 100644 index 0000000000..14bd2700d9 Binary files /dev/null and b/docs/cloud/features/alerts_notifications/add_alert_button.png differ diff --git a/docs/cloud/features/alerts_notifications/add_alert_page.png b/docs/cloud/features/alerts_notifications/add_alert_page.png new file mode 100644 index 0000000000..8681adf5c5 Binary files /dev/null and b/docs/cloud/features/alerts_notifications/add_alert_page.png differ diff --git a/docs/cloud/features/alerts_notifications/add_alert_trigger_type.png b/docs/cloud/features/alerts_notifications/add_alert_trigger_type.png new file mode 100644 index 0000000000..1e29fe83c4 Binary files /dev/null and b/docs/cloud/features/alerts_notifications/add_alert_trigger_type.png differ diff --git a/docs/cloud/features/alerts_notifications/add_event_alert_page.png b/docs/cloud/features/alerts_notifications/add_event_alert_page.png new file mode 100644 index 0000000000..5521c5cc29 Binary files /dev/null and b/docs/cloud/features/alerts_notifications/add_event_alert_page.png differ diff --git a/docs/cloud/features/alerts_notifications/add_event_artifact.png b/docs/cloud/features/alerts_notifications/add_event_artifact.png new file mode 100644 index 0000000000..5aa55412f0 Binary files /dev/null and b/docs/cloud/features/alerts_notifications/add_event_artifact.png differ diff --git a/docs/cloud/features/alerts_notifications/add_event_event.png 
b/docs/cloud/features/alerts_notifications/add_event_event.png new file mode 100644 index 0000000000..5dbc5b2f83 Binary files /dev/null and b/docs/cloud/features/alerts_notifications/add_event_event.png differ diff --git a/docs/cloud/features/alerts_notifications/add_measure_alert_artifact.png b/docs/cloud/features/alerts_notifications/add_measure_alert_artifact.png new file mode 100644 index 0000000000..18bc08a03c Binary files /dev/null and b/docs/cloud/features/alerts_notifications/add_measure_alert_artifact.png differ diff --git a/docs/cloud/features/alerts_notifications/add_measure_alert_condition.png b/docs/cloud/features/alerts_notifications/add_measure_alert_condition.png new file mode 100644 index 0000000000..647cf0ad96 Binary files /dev/null and b/docs/cloud/features/alerts_notifications/add_measure_alert_condition.png differ diff --git a/docs/cloud/features/alerts_notifications/add_measure_alert_page.png b/docs/cloud/features/alerts_notifications/add_measure_alert_page.png new file mode 100644 index 0000000000..21876f0d22 Binary files /dev/null and b/docs/cloud/features/alerts_notifications/add_measure_alert_page.png differ diff --git a/docs/cloud/features/alerts_notifications/add_notification_target_button.png b/docs/cloud/features/alerts_notifications/add_notification_target_button.png new file mode 100644 index 0000000000..e823a89e32 Binary files /dev/null and b/docs/cloud/features/alerts_notifications/add_notification_target_button.png differ diff --git a/docs/cloud/features/alerts_notifications/add_notification_target_page.png b/docs/cloud/features/alerts_notifications/add_notification_target_page.png new file mode 100644 index 0000000000..42ec022c0f Binary files /dev/null and b/docs/cloud/features/alerts_notifications/add_notification_target_page.png differ diff --git a/docs/cloud/features/alerts_notifications/settings_section_link.png b/docs/cloud/features/alerts_notifications/settings_section_link.png new file mode 100644 index 
0000000000..5740bf566b Binary files /dev/null and b/docs/cloud/features/alerts_notifications/settings_section_link.png differ diff --git a/docs/cloud/features/costs_savings.md b/docs/cloud/features/costs_savings.md new file mode 100644 index 0000000000..a8e8b3c94e --- /dev/null +++ b/docs/cloud/features/costs_savings.md @@ -0,0 +1,46 @@ +# Data Warehouse Costs and Savings with Tobiko + +Understanding and managing data warehouse costs is challenging. Tobiko Cloud helps by tracking your data warehouse costs and integrating them into the Tobiko Cloud UI. + +Tobiko Cloud tracks data warehouse cost estimates per model for BigQuery and Snowflake projects. It also estimates how much money Tobiko Cloud has saved you by skipping unnecessary model reruns. + +## Supported Data Warehouse Pricing Plans + +Tobiko Cloud supports costs and savings data for these data warehouse pricing plans: + +- BigQuery On Demand +- Snowflake Credits + +## Data Warehouse Cost Configuration + +If you use a supported pricing plan, visit Settings to configure Tobiko Cloud's cost estimates. + +![Image highlighting location of the Settings link in the left site navigation](./costs_savings/costs-navigation.png) + +On the General settings page (1), select your pricing plan (2), enter your costs, and then save (3). + +![Annotated image showing locations of the general settings link, pricing plan form fields, and save button](./costs_savings/costs-steps.png) + +## Where to find cost and savings information + +Estimated costs and savings are displayed on the homepage, production environment page, runs and plans pages, and individual model pages. + +Cost information on each page will look similar to this: + +![Example of costs and savings data as seen on the Tobiko Cloud homepage](./costs_savings/costs-example.png) + +### Savings Categories + +When calculating your data warehouse costs, we also calculate how much you saved by using Tobiko! 
+ +Tobiko Cloud comes with even more change categorization capabilities than open-source SQLMesh, such as advanced column-level impact analysis. + +Cost savings are broken up into three main categories: + +- **Prevented Reruns**: If SQLMesh already executed a change in one environment, we won't rerun it in another environment (backfills from development environments are reused when it is safe to do so). +- **Unaffected Downstream**: SQLMesh understands SQL, so we skip re-execution if a downstream model is not affected by an upstream change. +- **Virtual Environments**: With Virtual data environments, new environments can be created without running any computations at all. + +### Where to find cost savings information + +Cost savings are included in most places costs are displayed. Find how much you've saved using Tobiko by viewing the homepage, production environment page, or individual model pages. diff --git a/docs/cloud/features/costs_savings/costs-example.png b/docs/cloud/features/costs_savings/costs-example.png new file mode 100644 index 0000000000..20dd443e0b Binary files /dev/null and b/docs/cloud/features/costs_savings/costs-example.png differ diff --git a/docs/cloud/features/costs_savings/costs-navigation.png b/docs/cloud/features/costs_savings/costs-navigation.png new file mode 100644 index 0000000000..8978669c1c Binary files /dev/null and b/docs/cloud/features/costs_savings/costs-navigation.png differ diff --git a/docs/cloud/features/costs_savings/costs-steps.png b/docs/cloud/features/costs_savings/costs-steps.png new file mode 100644 index 0000000000..037baf699d Binary files /dev/null and b/docs/cloud/features/costs_savings/costs-steps.png differ diff --git a/docs/cloud/features/data_catalog.md b/docs/cloud/features/data_catalog.md new file mode 100644 index 0000000000..4e623b7518 --- /dev/null +++ b/docs/cloud/features/data_catalog.md @@ -0,0 +1,7 @@ +# Data Catalog + +Tobiko Cloud serves a hosted version of your SQLMesh Data Catalog! 
+ +Anyone with access to your Tobiko Cloud instance can use it to explore your production environment models at any time, without needing to run SQLMesh UI on their personal computer. + +![Tobiko Cloud's hosted Data Catalog](./data_catalog/data-catalog-model.png) diff --git a/docs/cloud/features/data_catalog/data-catalog-model.png b/docs/cloud/features/data_catalog/data-catalog-model.png new file mode 100644 index 0000000000..9294aff0e9 Binary files /dev/null and b/docs/cloud/features/data_catalog/data-catalog-model.png differ diff --git a/docs/cloud/features/debugger_view.md b/docs/cloud/features/debugger_view.md new file mode 100644 index 0000000000..f2be89e256 --- /dev/null +++ b/docs/cloud/features/debugger_view.md @@ -0,0 +1,82 @@ +# Debugger View + +
+ +This view is used to help you debug production run issues with your SQLMesh models in Tobiko Cloud. + +Fixing data pipelines in production is a stressful, time-consuming process, so we're here to make it easier with a few visuals/clicks. + +> Note: the debugger view is only available for models that have been executed in your data warehouse via the `tcloud sqlmesh plan` or `tcloud sqlmesh run` commands. + +## Using the Debugger View + +Step 1: On the Tobiko Cloud home page, click any bar in the `Runs Daily` chart to open the debugger view. It doesn't matter whether the bar is green or red. + +![debugger_view_step_1](./debugger_view/debugger_view_step_1.png) + +

+Step 2: Click on the "Explore Executions" tab (bubble 1) to see specific execution details about the run, by model. + +Then choose a model to view. We clicked on the `orders` model (bubble 2) - notice that it shows you a focused view of the DAG (think: lineage) centered on whichever model you clicked. + +![debugger_view_step_2](./debugger_view/debugger_view_step_2.png) + +From here, you can explore the execution details of the run with a model's focused tabs (bubble 3). + +The rest of this page describes those tabs. + +> Pro tip: you can toggle whether timestamps are in UTC or your local timezone in the page's upper right corner. + +## Debugger View Tabs + +### Overview + +See a summary of the model's characteristics and behavior during current and historical runs. + +- You'll see a high-level overview of the model's characteristics, including the execution time, duration, completion status, and next scheduled run. +- View the past 5 run and plan activities to see the model's historical behavior. This is useful to get a pulse on how often it succeeds or fails. If you notice it's failing often, this is a good model to investigate further. +- Click on the "Previous Run" tile to explore the details of the previous run. This is useful if you want to compare the previous run to the current one if you notice duration is shorter or longer than expected. +- Click on the "Last Plan" tile to explore the details of the last plan that was applied. This is useful to see if the model's code was changed in a way that sped up or elongated the duration of the run. It's also helpful to verify if the schema changed in a way that might be causing an issue. + +![overview](./debugger_view/overview.png) + +### Impact + +See the current model's downstream and upstream models in a list format. + +This is useful if lots of models are in the DAG view and you want to see the model's full impact at a glance. 
+ +![impact](./debugger_view/impact.png) + + +### Definition + +See the exact code that was executed during the run. + +This is useful if you want to determine whether the code changed in a way that might be causing an issue. + +![definition](./debugger_view/definition.png) + +### Schema + +See the current model's schema. + +This is useful to determine whether the schema changed in a way that might be causing an issue. + +![schema](./debugger_view/schema.png) + +### Intervals + +See the specific time intervals that were processed during the run. + +This is useful to see which exact time intervals succeeded or failed. Also, it's useful to determine whether time intervals changed in a way that might be causing an issue such as longer run duration. + +![intervals](./debugger_view/intervals.png) + +### Log + +See all the SQLMesh logs from the run. + +You can filter for multiple levels of logs: `info`, `warning`, `error`, etc. + +![log](./debugger_view/log.png) \ No newline at end of file diff --git a/docs/cloud/features/debugger_view/debugger_view_step_1.png b/docs/cloud/features/debugger_view/debugger_view_step_1.png new file mode 100644 index 0000000000..275aecc5f3 Binary files /dev/null and b/docs/cloud/features/debugger_view/debugger_view_step_1.png differ diff --git a/docs/cloud/features/debugger_view/debugger_view_step_2.png b/docs/cloud/features/debugger_view/debugger_view_step_2.png new file mode 100644 index 0000000000..ce9a00f820 Binary files /dev/null and b/docs/cloud/features/debugger_view/debugger_view_step_2.png differ diff --git a/docs/cloud/features/debugger_view/definition.png b/docs/cloud/features/debugger_view/definition.png new file mode 100644 index 0000000000..fb78537147 Binary files /dev/null and b/docs/cloud/features/debugger_view/definition.png differ diff --git a/docs/cloud/features/debugger_view/impact.png b/docs/cloud/features/debugger_view/impact.png new file mode 100644 index 0000000000..30d8f547f0 Binary files /dev/null and 
b/docs/cloud/features/debugger_view/impact.png differ diff --git a/docs/cloud/features/debugger_view/intervals.png b/docs/cloud/features/debugger_view/intervals.png new file mode 100644 index 0000000000..1fd456d204 Binary files /dev/null and b/docs/cloud/features/debugger_view/intervals.png differ diff --git a/docs/cloud/features/debugger_view/log.png b/docs/cloud/features/debugger_view/log.png new file mode 100644 index 0000000000..4fc1f12302 Binary files /dev/null and b/docs/cloud/features/debugger_view/log.png differ diff --git a/docs/cloud/features/debugger_view/overview.png b/docs/cloud/features/debugger_view/overview.png new file mode 100644 index 0000000000..98967f7ad5 Binary files /dev/null and b/docs/cloud/features/debugger_view/overview.png differ diff --git a/docs/cloud/features/debugger_view/schema.png b/docs/cloud/features/debugger_view/schema.png new file mode 100644 index 0000000000..37745d12df Binary files /dev/null and b/docs/cloud/features/debugger_view/schema.png differ diff --git a/docs/cloud/features/incident_reporting.md b/docs/cloud/features/incident_reporting.md new file mode 100644 index 0000000000..1ff062e06b --- /dev/null +++ b/docs/cloud/features/incident_reporting.md @@ -0,0 +1,46 @@ +# Incident Reporting + +We monitor Tobiko Cloud 24/7 to ensure your projects are running smoothly. + +If you encounter any issues, however, you can report incidents directly in Tobiko Cloud itself. + +This will notify our support team, who will investigate and resolve the issue as quickly as possible. + +### Reporting an incident + +Follow these steps to report an incident in Tobiko Cloud: + +1. Visit the [Tobiko Cloud Incident Reporting Page](https://incidents.tobikodata.com/) +2. Select one of the three severity levels for your incident +3. Enter the project name the incident is related to + * The project name is displayed after your organization name in the Cloud UI +4. 
Write a detailed description of the incident + * Include all relevant information that will help our support team understand and resolve the issue +5. Click the `Submit` button to send your incident report +6. You will receive a confirmation message indicating that your incident has been reported successfully +7. You will hear from our support team after submitting the incident report + +![Tobiko Cloud incident reporting page](./incident_reporting/incident_reporting.png) + +### Reporting an incident when SSO is unavailable + +Single Sign-On (SSO) is the default way to log in to Tobiko Cloud. However, SSO could be down or not working when you need to report an incident. + +Tobiko Cloud provides a standalone page that doesn't require SSO so you can report an incident when SSO is not working. The page is unique to your organization. + +The standalone URL is available in the incident reporting page when you log in with SSO. Because accessing the standalone URL does not require SSO, you should only share it with staff authorized to report incidents. + +To store your standalone incident reporting URL: + +1. Visit the [Tobiko Cloud Incident Reporting Page](https://incidents.tobikodata.com/) +2. Click the `Copy Standalone URL` button below the incident reporting section +3. Save this URL in an easily accessible location in case you need to report an incident when SSO is not working + +!!! note "Don't wait!" + We recommend copying this URL *right now* so your organization is protected from difficulty reporting an incident. + +### SSO not enabled for your organization + +SSO login is required for accessing the standalone incident reporting URL. + +SSO is enabled by default in Tobiko Cloud. If it is not enabled for your organization, contact your solution architect and ask them to provide you with a standalone incident reporting URL. 
diff --git a/docs/cloud/features/incident_reporting/incident_reporting.png b/docs/cloud/features/incident_reporting/incident_reporting.png new file mode 100644 index 0000000000..542f1aa883 Binary files /dev/null and b/docs/cloud/features/incident_reporting/incident_reporting.png differ diff --git a/docs/cloud/features/observability/development_environment.md b/docs/cloud/features/observability/development_environment.md new file mode 100644 index 0000000000..2e8953f548 --- /dev/null +++ b/docs/cloud/features/observability/development_environment.md @@ -0,0 +1,75 @@ +# Development Environment + +Tobiko Cloud extends the SQLMesh CLI to advance your development workflow. Instead of relying on a static terminal output isolated to your local machine when running `tcloud sqlmesh plan dev`, Tobiko Cloud tracks development history automatically displayed in a rich user interface. We want mental load at a minimum so you can focus on your most important work. + +At its core, this transforms development from a single-player to a multi-player experience. Instead of sharing screenshots and scrolling through terminal history, all you have to do now is share a link to your work. + +### When you might use this + +**Team Collaboration** + +The platform helps foster team collaboration by providing clear visibility into team activities. Developers can easily see who is working on specific models, prevent workflow conflicts, and avoid duplicate efforts. This creates a multiplayer development experience. + +**Performance Tracking** + +You can monitor changes over time, review recent activities including successes and failures, and gain detailed insights into specific plan execution outcomes to get a better sense of trends. 
Check out the example image of a development environment page with [multiple plan changes in a day.](#plan-history-image) + +**Simplified Communication and Team Alignment** + +Eliminate friction in sharing complex development context through manual pull requests or direct messages. These URLs serve as comprehensive summaries, displaying last run times, data intervals for incremental models, and detailed change information such as metadata modifications and model removals. + +![The feel and speed of sharing links in Tobiko Cloud](./development_environment/link_sharing_feel.gif) + + +## Using the Environments Tab +The Environments page shows an overview of all the environments that exist in your project (both yours and any your teammates have created). + +![tcloud environment page](./development_environment/environments.png) + +The page's table includes a link to each environment's page, along with the environment's creation date, the date it was last updated, and the date it will expire if not updated again. Clicking an environment's name from the main environments page takes you to its individual page. + +![tcloud development environment](./development_environment/tcloud_development_environment.png) + +## Individual Environment page +The page begins with an at-a-glance summary of the most recent plan applied to the environment. + +![tcloud environment page layout](./development_environment/tcloud_dev_env_labelled.png) + +1. Its completion status and time of the last plan applied +2. The latest time interval backfilled by the plan +3. Count of models present in the environment +4. An interactive visualization that summarizes the differences between the environment's models and the `prod` environment's models + - The count of directly modified models is represented in blue + - The count of added models is green + - The count of removed models is red + +??? 
"ProTip:" + + If a stakeholder or anyone else on your team is looking to understand an environment you own and are working on, you can share the link with them and they will be able to access and see all of the information about your environment. + + It's a great place to start to have open conversations about what was recently added, removed or changed in an environment! + + +## Differences from Prod section + +Development environments are used to prepare and test changes before deploying them to `prod`, with separate tabs for each type of change (directly modified, indirectly modified, metadata-only changes, added, removed). Below is a screenshot from an environment version that shows all these tab options. + + +![Prod Differences section with all options](./development_environment/dev_env_comprehensive.png) + +In the summary, each model's name is a link to [its model page](./model.md). This links to the information about the version of the model used in _this environment_, not the overall prod model. This means that you can get insight into what you're working on in dev instead of the "stale" version in prod ("stale" relative to your work). + +## Plan history information + +The plan applications chart is a calendar visualization of all plans that have been applied to the environment in the previous 2 weeks. + +![Plan History Information](./development_environment/plan_history.png) + + +The chart represents days on its `x-axis` (each column is a day with the corresponding date across the top) and the time of day on its `y-axis` (each day begins at the top and ends at the bottom). + +Each day displays zero or more horizontal bars representing `plan` duration. If no `plans` occurred on a day, no bars will be displayed. If multiple `plans` occurred on the same day, their horizontal bars will be stacked. + +The chart uses color to convey the status of a `plan` at a glance. Green is completed, grey is in progress, red is failed.
+ +Hovering over a bar reveals summary information about the `plan`, including its completion status, start time, end time, total duration, and change summary. The summary includes a link to [the `plan`'s page](./plan.md). \ No newline at end of file diff --git a/docs/cloud/features/observability/development_environment/dev_env_comprehensive.png b/docs/cloud/features/observability/development_environment/dev_env_comprehensive.png new file mode 100644 index 0000000000..c5db1527cb Binary files /dev/null and b/docs/cloud/features/observability/development_environment/dev_env_comprehensive.png differ diff --git a/docs/cloud/features/observability/development_environment/environments.png b/docs/cloud/features/observability/development_environment/environments.png new file mode 100644 index 0000000000..c0545c13ac Binary files /dev/null and b/docs/cloud/features/observability/development_environment/environments.png differ diff --git a/docs/cloud/features/observability/development_environment/link_sharing_feel.gif b/docs/cloud/features/observability/development_environment/link_sharing_feel.gif new file mode 100644 index 0000000000..29e3a6cd41 Binary files /dev/null and b/docs/cloud/features/observability/development_environment/link_sharing_feel.gif differ diff --git a/docs/cloud/features/observability/development_environment/plan_history.png b/docs/cloud/features/observability/development_environment/plan_history.png new file mode 100644 index 0000000000..d0116f3e2d Binary files /dev/null and b/docs/cloud/features/observability/development_environment/plan_history.png differ diff --git a/docs/cloud/features/observability/development_environment/tcloud_dev_env_labelled.png b/docs/cloud/features/observability/development_environment/tcloud_dev_env_labelled.png new file mode 100644 index 0000000000..96fa02add2 Binary files /dev/null and b/docs/cloud/features/observability/development_environment/tcloud_dev_env_labelled.png differ diff --git 
a/docs/cloud/features/observability/development_environment/tcloud_development_environment.png b/docs/cloud/features/observability/development_environment/tcloud_development_environment.png new file mode 100644 index 0000000000..b2c5a7969d Binary files /dev/null and b/docs/cloud/features/observability/development_environment/tcloud_development_environment.png differ diff --git a/docs/cloud/features/observability/measures_dashboards.md b/docs/cloud/features/observability/measures_dashboards.md new file mode 100644 index 0000000000..6960c76b6d --- /dev/null +++ b/docs/cloud/features/observability/measures_dashboards.md @@ -0,0 +1,7 @@ +# Measures + +Coming Soon! + +# Dashboards + +Coming Soon! \ No newline at end of file diff --git a/docs/cloud/features/observability/model.md b/docs/cloud/features/observability/model.md new file mode 100644 index 0000000000..bd14d0c88b --- /dev/null +++ b/docs/cloud/features/observability/model.md @@ -0,0 +1,62 @@ +# Models + +The model overview page provides comprehensive observability features that let you explore detailed information about a model. This centralized view gives you quick access to critical metrics and performance data, providing a window into the model's health and status. + +Model owners typically use this page to monitor and check their models. It provides essential information in an easy-to-scan format, eliminating the need to debug issues through the command line interface. From this page you can quickly diagnose: + +1. Model anomalies + 1. Did the model suddenly take a really long time to run? + 2. Is the model repeatedly failing due to audits or schema evolution? +2. Downstream impacts + 1. If the model fails to run, lineage lets you immediately see what other models are affected +3. Which version introduced errors + 1. Use the model's version history to identify which changes caused a problem + + +## Navigate to a model + +There are a number of ways you can navigate to a model's page. 
This method shows you how to find your model directly from the Environments page. + +1. Select "Environments" from the left hand menu +2. Click the environment you want to explore from the list. + ![Tobiko Cloud environment page](./model/tcloud_environments.png) +3. Navigate to the Models section and click "Explore" to view available models + ![Tobiko Cloud environment page explore models link](./model/tcloud_environment_explore-models.png) +4. Browse through the model list and select a model to access its detailed information + ![Tobiko Cloud environment models list](./model/tcloud_model_list.png) + +## Model page information + +Each model page presents a comprehensive summary that includes the key components and metrics used to monitor model behavior. + +From here, you can identify anomalies in the model's run time based on historical run times and how they have been changing over time (or not!). + +You can also check other critical information, like the model's source code, its lineage relative to other models, its contents in previous versions, and even an approximation of how much it costs (if you have [cost savings set up](../costs_savings.md)). + +The following detailed information outlines the different sections: + +![Tobiko Cloud model status and metadata](./model/tcloud_model_status-metadata.png) + +- Current status graphs: Provide visual representations of model health through freshness indicators and detailed daily execution graphs + - Freshness indicator: Shows the current status of the model and the percentage of up-to-date models in production (as long as this is green, you have nothing to worry about in your production environment) + - Historical Freshness graph: Gives an at-a-glance picture of the history of the model's freshness. 
+ - Green means it's up to date and has run smoothly for every interval + - Orange means that one interval is pending and will be processed on the next run + - Red means the model has more than one interval waiting to be processed because an interval was not processed during a previous run + - Daily executions: tells you the length of time it took the model to run on each day. This is a great place to quickly identify anomalies in the model's run time (both running too long *or* too short). +- Model details: Features tabs that display summary statistics, model source code, and interactive model lineage visualizations + +![Tobiko Cloud model version history](./model/tcloud_model_2.png) + +- Version history: Delivers a comprehensive chronological view of all model versions, with detailed information including: + - Precise timestamp of version promotion + - Clear indication of change impact (breaking or non-breaking modifications) + - Direct access to the complete implementation plan code +- Data Warehouse costs: estimates the cost of the model as set up by your team in [cost savings](../costs_savings.md) + +![Tobiko Cloud model version history](./model/tcloud_model_3.png) + +- Loaded intervals: these periods represent the time spans processed during each job execution, which generally consist of the time between one job and the next. These intervals are crucial for understanding the boundaries of data processing cycles, which may correspond to the start of anomalous model behavior. 
+ - The table displays the specific model version in effect during that job execution, enabling precise tracking of version-specific outputs + - Helps track forward-only model changes by maintaining a clear chronological record of modifications, ensuring data consistency and preventing retroactive alterations +- Recent activity: Maintains a detailed log of version executions and comprehensive version audits \ No newline at end of file diff --git a/docs/cloud/features/observability/model/tcloud_environment_explore-models.png b/docs/cloud/features/observability/model/tcloud_environment_explore-models.png new file mode 100644 index 0000000000..e4c1991e90 Binary files /dev/null and b/docs/cloud/features/observability/model/tcloud_environment_explore-models.png differ diff --git a/docs/cloud/features/observability/model/tcloud_environments.png b/docs/cloud/features/observability/model/tcloud_environments.png new file mode 100644 index 0000000000..8233e39a09 Binary files /dev/null and b/docs/cloud/features/observability/model/tcloud_environments.png differ diff --git a/docs/cloud/features/observability/model/tcloud_model_2.png b/docs/cloud/features/observability/model/tcloud_model_2.png new file mode 100644 index 0000000000..32a7460979 Binary files /dev/null and b/docs/cloud/features/observability/model/tcloud_model_2.png differ diff --git a/docs/cloud/features/observability/model/tcloud_model_3.png b/docs/cloud/features/observability/model/tcloud_model_3.png new file mode 100644 index 0000000000..8fea306556 Binary files /dev/null and b/docs/cloud/features/observability/model/tcloud_model_3.png differ diff --git a/docs/cloud/features/observability/model/tcloud_model_list.png b/docs/cloud/features/observability/model/tcloud_model_list.png new file mode 100644 index 0000000000..8ac377a4d9 Binary files /dev/null and b/docs/cloud/features/observability/model/tcloud_model_list.png differ diff --git a/docs/cloud/features/observability/model/tcloud_model_status-metadata.png 
b/docs/cloud/features/observability/model/tcloud_model_status-metadata.png new file mode 100644 index 0000000000..b7cc1d86ed Binary files /dev/null and b/docs/cloud/features/observability/model/tcloud_model_status-metadata.png differ diff --git a/docs/cloud/features/observability/model_freshness.md b/docs/cloud/features/observability/model_freshness.md new file mode 100644 index 0000000000..67389f8b6c --- /dev/null +++ b/docs/cloud/features/observability/model_freshness.md @@ -0,0 +1,52 @@ +# Model Freshness + +Model freshness indicators on the homepage allow you to immediately determine whether the production environment is correct and up to date. + +Additional information on the page, such as lists of models and their current status, helps you investigate any freshness issues, identify problematic models, and check if CI/CD processes have stopped running. + +![tcloud model freshness](./model_freshness/find_model_freshness.png) + +## When you might use this + +The model freshness chart answers the question "how is the production environment right now?" It summarizes the recent history of production models and whether they were backfilled on time. + +When the chart is all green, everything is running smoothly and you're good to go! + +Red indicators in the past don't require immediate action, but they may provide lessons that can help prevent similar issues in the future. + +Red indicators now mean it's time to take action and debug the issue. + +## Finding the model freshness chart + +The model freshness chart is near the top of the Tobiko Cloud homepage. + +![tcloud model freshness](./model_freshness/find_model_freshness.png) + + +## Model freshness indicators + +Model freshness is the timeliness of the data most recently processed by a model. In other words, it measures how up-to-date each model is relative to its `cron`. + +The chart displays historical data, showing the percentage of models that were fresh (y-axis) across time (x-axis). 
+ +This historical view helps when troubleshooting data issues — you can quickly check if the issue is associated with delayed model runs. + +![tcloud model freshness](./model_freshness/tcloud_model_freshness.png) + +The chart uses color to show the percentage of models in different states: + +1. Models that have run for all previous cron periods are "complete" (green). + - All green indicates the data warehouse is fully up-to-date +2. Models that haven't run for the most recent cron period are "pending" (yellow). +3. Models that haven't run for multiple previous cron periods are "behind" (red). + - Red signals potential issues that need investigation + +Keep in mind that if a model shows red (behind) in the past, that doesn't necessarily reflect its current status. It may have caught up by now! + +The chart is interactive — hovering reveals the distribution of model freshness at a specific time point. + +![Tobiko Cloud model freshness chart tooltip](./model_freshness/tcloud_model-freshness_tooltip.png) + +Click a time point to open a list of the models that were complete, pending, or behind at that time. 
+ +![Tobiko Cloud model freshness list](./model_freshness/tcloud_model_freshness_list.png) diff --git a/docs/cloud/features/observability/model_freshness/find_model_freshness.png b/docs/cloud/features/observability/model_freshness/find_model_freshness.png new file mode 100644 index 0000000000..d5c9ca8763 Binary files /dev/null and b/docs/cloud/features/observability/model_freshness/find_model_freshness.png differ diff --git a/docs/cloud/features/observability/model_freshness/tcloud_model-freshness_tooltip.png b/docs/cloud/features/observability/model_freshness/tcloud_model-freshness_tooltip.png new file mode 100644 index 0000000000..c4072e97b8 Binary files /dev/null and b/docs/cloud/features/observability/model_freshness/tcloud_model-freshness_tooltip.png differ diff --git a/docs/cloud/features/observability/model_freshness/tcloud_model_freshness.png b/docs/cloud/features/observability/model_freshness/tcloud_model_freshness.png new file mode 100644 index 0000000000..a7d8c00983 Binary files /dev/null and b/docs/cloud/features/observability/model_freshness/tcloud_model_freshness.png differ diff --git a/docs/cloud/features/observability/model_freshness/tcloud_model_freshness_list.png b/docs/cloud/features/observability/model_freshness/tcloud_model_freshness_list.png new file mode 100644 index 0000000000..794c8d2d57 Binary files /dev/null and b/docs/cloud/features/observability/model_freshness/tcloud_model_freshness_list.png differ diff --git a/docs/cloud/features/observability/overview.md b/docs/cloud/features/observability/overview.md new file mode 100644 index 0000000000..fefcc52cbc --- /dev/null +++ b/docs/cloud/features/observability/overview.md @@ -0,0 +1,49 @@ +# Overview + +Fixing problems with data pipelines is challenging because there are so many potential causes. + +For transformation pipelines, those range from upstream source timeouts to SQL query errors to Python library conflicts (and more!). 
+ +![Data Ops Observability](./overview/data-ops-light.png) + +Tobiko Cloud makes it easy to detect and respond to changes in your pipelines: + +- Did a problem occur? + - **Alerts notify you immediately.** +- When did the problem occur? + - **Historical pipeline information reveals the moment.** +- Where is the problem coming from? + - **Easy navigation through pipeline components lets you pinpoint the source.** +- What is causing the problem? + - **Centralized logs and errors have all the details.** + +## How it works + +Tobiko Cloud captures detailed metadata throughout your data project's lifecycle. + +During the execution of plans and runs, it collects information about model performance and system health to give you complete visibility into your system's operations. + +This information allows you to: + +- Monitor the health and performance of your data pipelines +- Track the status of current and historical runs +- Review a detailed version history of your models and transformations +- Create custom visualizations and metrics +- Troubleshoot problems and optimize inefficient operations + +Observability features are seamlessly integrated into Tobiko Cloud, making it simple to monitor and understand your project's behavior. For example, on the Tobiko Cloud Homepage, there is run history, plan executions, and freshness displayed for the production environment: + +![Observability on the Homepage](./overview/observability_section_home.png) + +Instead of digging through complex logs or piecing together information from multiple sources, you can quickly access the relevant information from any part of your project. 
+ + \ No newline at end of file diff --git a/docs/cloud/features/observability/overview/data-ops-light.png b/docs/cloud/features/observability/overview/data-ops-light.png new file mode 100644 index 0000000000..08a6988245 Binary files /dev/null and b/docs/cloud/features/observability/overview/data-ops-light.png differ diff --git a/docs/cloud/features/observability/overview/observability_section_home.png b/docs/cloud/features/observability/overview/observability_section_home.png new file mode 100644 index 0000000000..815c971074 Binary files /dev/null and b/docs/cloud/features/observability/overview/observability_section_home.png differ diff --git a/docs/cloud/features/observability/plan.md b/docs/cloud/features/observability/plan.md new file mode 100644 index 0000000000..5c8d748c9a --- /dev/null +++ b/docs/cloud/features/observability/plan.md @@ -0,0 +1,91 @@ +# Plans + +Plan pages provide comprehensive, detailed insights into each plan executed across your SQLMesh environments. These pages act as a central hub where team members can monitor and understand all aspects of plan execution, from start to finish. + +In open-source SQLMesh, information about plans is stored locally by default, so team members only have immediate visibility into the plans they have executed themselves. + +To address this limitation, we've created a comprehensive plan page that serves two essential purposes: + +1. Provide a centralized place where every team member can view, track, and understand all plans and their current status. + - Benefit: increase transparency and improve collaboration across the team +2. Maintains detailed historical records, providing a reference of all a projects' changes, when each change was implemented, and how your project has evolved over time. 
+ - Benefit: ensure nothing gets lost or forgotten as teams evolve over time + +![tcloud plan](./plan/plan.png) + +## When you might use this + +**Team Collaboration** + +Improves team collaboration through an easy-to-understand view of everyone's changes, so the entire team can see the latest updates made to an environment. + +**Monitoring** + +Tells you exactly what's happening by monitoring plan execution status. Instantly identify plans that are currently running, have completed successfully, or have encountered any issues that need attention. + +If you do encounter any issues, this page serves as an ideal starting point for debugging: + +- You no longer need to spend time searching through log files trying to locate specific model changes or modifications you've made - everything is organized and easily accessible +- We've carefully curated a log that captures everything that occurred during the plan execution. This provides a consolidated location where you can examine any plan or model and access its relevant logs, eliminating the need to parse CLI output to find what you're interested in + 1. For storage optimization, logs are retained for one week before being automatically cleaned up from the system +- Share monitoring information with teammates via links to plan pages, not screenshots of terminal output + +**Change clarification** + +Delivers a visualization of a plan's model changes, making it simple to share sets of modifications with team members who do not have direct access to the local development environment. + +Summary information at the top of the page provides context and assistance in understanding what might have gone wrong, making troubleshooting more efficient and systematic. + +- For example, share changes with teammates without opening a pull request (which could trigger an unwanted CI/CD pipeline) + +## Navigating to a Plan page +Every SQLMesh `plan` is applied to a specific environment. 
To locate a `plan`, first navigate to its [Environment page](./development_environment.md). + +The environment page's Recent Activity table includes a list of every recent `plan` and `run`. To visit a `plan`'s page, locate the `plan` by application date and click on its blue ID link in the table's final column. + +![tcloud plan information](./plan/plan_info.png) + +Clicking the link opens the detailed plan page: + +![tcloud plan](./plan/plan.png) + +## Plan summary + +The top section provides an at-a-glance overview of the plan, including: + +![tcloud plan](./plan/plan_top_section.png) + +- `Status`: the plan's completion status (possible values: complete, in progress, failed) +- `When`: the times when the plan started and completed +- `Plan Type`: the plan's type classification. Possible values: + - `Environment update`: the plan includes a modified model + - `Restatement`: the plan includes a restated model + - `System`: the Tobiko Cloud team has made an upgrade to your system (no models or data were affected) +- `Backfill Dates`: dates for which the model was backfilled +- `Changes`: chart displaying counts of model change types (directly modified model count in blue, added models in green, removed models in red) + +## Plan changes + +The middle section presents a detailed summary of all plan changes. + +![plan example](./plan/plan_changes.png) + +Each change category has its own tab on the left side: `added` models, `directly modified` models, `metadata-only modified` models, `indirectly modified` models, and `removed` models. + +Clicking a model name takes you to its [individual model page](./model.md). 
+ + +## Updates and Executions section + +The final section displays the different actions SQLMesh took when executing the plan, where each type of action has its own tab across the top: + +- `Physical Layer Updates` (creating physical tables) +- `Model Executions` (executing model queries) +- `Audits` (running model audits) +- `Virtual Updates` (updating environment views) + +![tcloud plan audits section](./plan/plan_tabs.png) + +The number of models in each category is included in the tab title. + +Each tab contains a table with detailed information on and links to the model(s) that have been updated. diff --git a/docs/cloud/features/observability/plan/plan.png b/docs/cloud/features/observability/plan/plan.png new file mode 100644 index 0000000000..66996ac56a Binary files /dev/null and b/docs/cloud/features/observability/plan/plan.png differ diff --git a/docs/cloud/features/observability/plan/plan_changes.png b/docs/cloud/features/observability/plan/plan_changes.png new file mode 100644 index 0000000000..054b8437a6 Binary files /dev/null and b/docs/cloud/features/observability/plan/plan_changes.png differ diff --git a/docs/cloud/features/observability/plan/plan_info.png b/docs/cloud/features/observability/plan/plan_info.png new file mode 100644 index 0000000000..96e301d5cd Binary files /dev/null and b/docs/cloud/features/observability/plan/plan_info.png differ diff --git a/docs/cloud/features/observability/plan/plan_tabs.png b/docs/cloud/features/observability/plan/plan_tabs.png new file mode 100644 index 0000000000..d8848c591c Binary files /dev/null and b/docs/cloud/features/observability/plan/plan_tabs.png differ diff --git a/docs/cloud/features/observability/plan/plan_top_section.png b/docs/cloud/features/observability/plan/plan_top_section.png new file mode 100644 index 0000000000..1376a06bcd Binary files /dev/null and b/docs/cloud/features/observability/plan/plan_top_section.png differ diff --git a/docs/cloud/features/observability/prod_environment.md 
b/docs/cloud/features/observability/prod_environment.md new file mode 100644 index 0000000000..71fc97ddaf --- /dev/null +++ b/docs/cloud/features/observability/prod_environment.md @@ -0,0 +1,86 @@ +# Prod Environment + +A data transformation system's most important component is the production environment, which provides the data your business runs on. + +When you first log in to Tobiko Cloud, you'll see the production environment page. This page shows you at a glance if your data systems are working properly. + +It helps data teams quickly check their work without having to dig through complicated logs - just look at the visual dashboard, and you'll know if everything is running smoothly. + +![tcloud prod env](./prod_environment/tcloud_prod_environment.png) + +## When you might use this + +**After a production update** + +The dashboard helps you check if your recent updates to production are working correctly. It uses a simple color system to show you what's happening: green means everything is good, and red shows where there might be problems. + +If you see red in your current run, plan or freshness, it means there's a problem that needs your attention. Don't worry about red marks from the past (in the historical and previous runs/plans) - these are old issues that have already been fixed. + +Best part? You can check all of this in about 5-10 seconds. + +**Quick cost check** + +The homepage also displays cost metrics for your production environment, a feature exclusive to production (not available in development environments). This allows you to quickly understand and monitor your team's model execution costs without diving into detailed reports. + +## Observing production + +Tobiko Cloud makes it easy to understand your production environment, embedding four observability features directly on your project's homepage: + +1. [Model Freshness chart](./model_freshness.md) +2. Runs and plans chart +3. Recent activity table +4. 
Warehouse costs overview + +![tcloud prod env](./prod_environment/tcloud_prod_environment_labelled.png) + +!!! Note + + Model freshness has its own feature page - learn more [here](./model_freshness.md)! + +### Runs and Plans Chart + +SQLMesh performs two primary actions: running the project's models on a cadence and applying plans to update the project's content/behavior. + +The Runs and Plans Chart displays a summary of all `run`s and `plans` that occurred over the previous two weeks. It shows when they occurred and how long they took to execute. + +![tcloud weekly runs](./prod_environment/weekly_runs.png) + +The chart uses color to convey `run` status at a glance: bars representing `run`s that successfully completed are marked in green, failed `run`s are red, and `run`s currently in progress are gray. `plan`s are always displayed in purple. + +The chart represents time on its `x-axis`, where each entry represents one day. The date corresponding to each day is displayed at the top of the chart. + +Each day displays zero or more vertical bars representing `run` duration. If no `run`s occurred on a day, no vertical bars will be displayed. If multiple `run`s occurred on the same day, their vertical bars will be stacked. + +The chart's `y-axis` represents `run` duration. The height of each `run`'s bar corresponds to its duration, allowing you to quickly assess execution times. 
+ +For example, consider the leftmost entry in the figure above: + +- The label at the top of the chart shows that it represents November 26 +- The entry consists of a single green bar, which tells us that one successful `run` occurred +- The bottom of the bar begins at 0 seconds on the `y-axis`, and the top of the bar ends at 20 seconds, telling us the `run` took 20 seconds to execute + +In contrast, consider the rightmost entry in the figure above: + +- The label at the top of the chart shows that it represents December 9 +- The entry contains two green bars, which tells us that two successful `run`s occurred +- The lower bar begins at 0 seconds on the `y-axis` and reaches up to 13 seconds, telling us the `run` took 13 seconds to execute +- The upper bar begins at 13 seconds on the `y-axis` and reaches up to 22 seconds, telling us that the `run` took 22 - 13 = 9 seconds to execute + +Learn more about a `run` or `plan` by hovering over its bar, which displays a link to its page, its start and end times, and its duration. + +### Recent Activity Table + +The recent activity table provides comprehensive information about recent project activities, displaying both `run`s and `plan`s in chronological order. This provides a more granular view than the runs and plans chart. + +For each activity entry, you can view its completion status, estimated cost of execution (BigQuery and Snowflake engines only), total duration from start to finish, start and completion times, and a unique identification hash for reference purposes. + +![tcloud recent activity](./prod_environment/recent_activity.png) + +The table provides the ability to filter which rows are displayed by typing into the text box in the top right. This helps you locate specific information within the activity log, making it easier to find and analyze particular events or patterns in your system's operational history. + +### Warehouse Costs Overview +Managing data warehouse costs can be complex. 
Tobiko Cloud simplifies this by monitoring costs directly. For BigQuery and Snowflake projects, it tracks cost estimates per model and calculates savings from avoided model reruns. + +The costs and savings summary information and chart display the costs to run and host all the models in your production environment over the last 30 days. This provides a great way to quickly see increases and decreases in daily running costs. To learn more, [check out the cost savings docs](../costs_savings.md). + +![tcloud recent activity](./prod_environment/costs.png) \ No newline at end of file diff --git a/docs/cloud/features/observability/prod_environment/costs.png b/docs/cloud/features/observability/prod_environment/costs.png new file mode 100644 index 0000000000..469d802ad5 Binary files /dev/null and b/docs/cloud/features/observability/prod_environment/costs.png differ diff --git a/docs/cloud/features/observability/prod_environment/recent_activity.png b/docs/cloud/features/observability/prod_environment/recent_activity.png new file mode 100644 index 0000000000..727bfe21f5 Binary files /dev/null and b/docs/cloud/features/observability/prod_environment/recent_activity.png differ diff --git a/docs/cloud/features/observability/prod_environment/tcloud_prod_environment.png b/docs/cloud/features/observability/prod_environment/tcloud_prod_environment.png new file mode 100644 index 0000000000..f1f6104f86 Binary files /dev/null and b/docs/cloud/features/observability/prod_environment/tcloud_prod_environment.png differ diff --git a/docs/cloud/features/observability/prod_environment/tcloud_prod_environment_labelled.png b/docs/cloud/features/observability/prod_environment/tcloud_prod_environment_labelled.png new file mode 100644 index 0000000000..ca52c49f0e Binary files /dev/null and b/docs/cloud/features/observability/prod_environment/tcloud_prod_environment_labelled.png differ diff --git a/docs/cloud/features/observability/prod_environment/weekly_runs.png 
b/docs/cloud/features/observability/prod_environment/weekly_runs.png new file mode 100644 index 0000000000..46056888da Binary files /dev/null and b/docs/cloud/features/observability/prod_environment/weekly_runs.png differ diff --git a/docs/cloud/features/observability/run.md b/docs/cloud/features/observability/run.md new file mode 100644 index 0000000000..406a9d90a9 --- /dev/null +++ b/docs/cloud/features/observability/run.md @@ -0,0 +1,55 @@ +# Runs + +Run pages, like [plan pages](./plan.md), serve as centralized information sources that provide detailed insights into individual runs executed across your various environments. + +They were created with the same philosophy as the plan pages, providing a consistent user experience and navigation pattern. + +These pages act as a central hub where team members can monitor and understand all aspects of a run’s execution, from start to finish. Additionally, they can serve as a jumping off point for investigating run-related errors or unexpected behavior. + +![tcloud run](./run/tcloud_run.png) + +## When you might use this + +If you're monitoring data pipelines, a common activity is verifying the status of the most recent run. + +The run page provides a quick way to check whether a run has succeeded or failed and when exactly it was executed. The page includes a comprehensive view of all model executions and audits that were included in the run. + +If you need deeper insights, the [Debugger View](../debugger_view.md) offers advanced analysis capabilities. This powerful tool allows teams to investigate which models are taking the longest time to update, helping identify potential performance bottlenecks in their data pipelines. + +## Navigating to a Run page + +Every SQLMesh `run` is applied to a specific environment. To locate a `run`, first navigate to its [Environment page](./development_environment.md). + +The environment page's Recent Activity table includes a list of every recent `plan` and `run`. 
`Audits`: list of audit execution statuses, including completion status, whether the audit is blocking, and links to detailed audit logs for verification
/dev/null and b/docs/cloud/features/observability/run/run_explore_executions.png differ diff --git a/docs/cloud/features/observability/run/run_info.png b/docs/cloud/features/observability/run/run_info.png new file mode 100644 index 0000000000..674a7075bb Binary files /dev/null and b/docs/cloud/features/observability/run/run_info.png differ diff --git a/docs/cloud/features/observability/run/run_model_executions.png b/docs/cloud/features/observability/run/run_model_executions.png new file mode 100644 index 0000000000..0989cda5f3 Binary files /dev/null and b/docs/cloud/features/observability/run/run_model_executions.png differ diff --git a/docs/cloud/features/observability/run/tcloud_run.png b/docs/cloud/features/observability/run/tcloud_run.png new file mode 100644 index 0000000000..79a045f68c Binary files /dev/null and b/docs/cloud/features/observability/run/tcloud_run.png differ diff --git a/docs/cloud/features/observability/run/tcloud_run_summary.png b/docs/cloud/features/observability/run/tcloud_run_summary.png new file mode 100644 index 0000000000..dc2a931fd5 Binary files /dev/null and b/docs/cloud/features/observability/run/tcloud_run_summary.png differ diff --git a/docs/cloud/features/scheduler/airflow.md b/docs/cloud/features/scheduler/airflow.md new file mode 100644 index 0000000000..11e82769dd --- /dev/null +++ b/docs/cloud/features/scheduler/airflow.md @@ -0,0 +1,244 @@ +# Airflow + +Tobiko Cloud's Airflow integration allows you to combine Airflow system monitoring with the powerful debugging tools in Tobiko Cloud. + +
+ +## How it works + +Tobiko Cloud uses a custom approach to Airflow integration. + +The Airflow DAG task mirrors the progress of the Tobiko Cloud scheduler run. Each local task reflects the outcome of its corresponding remote task. + +This allows you to observe at a glance how your data pipeline is progressing, displayed alongside your other pipelines in Airflow. No need to context switch to Tobiko Cloud! + +### Why a custom approach? + +Tobiko Cloud's scheduler performs multiple optimizations to ensure that your pipelines run correctly and efficiently. Those optimizations are only possible within our SQLMesh-aware scheduler. + +Our approach allows you to benefit from those optimizations while retaining the flexibility to attach extra tasks or logic to the DAG in your broader pipeline orchestration context. + +Because `run`s are still triggered by the Tobiko Cloud scheduler and tasks in the local DAG just reflect their remote equivalent in Tobiko Cloud, we call our custom approach a *facade*. + +## Setup + +Your SQLMesh project must be configured and connected to Tobiko Cloud before using the Airflow integration. + +Learn more about connecting to Tobiko Cloud in the [Getting Started page](../../tcloud_getting_started.md). + +### Install libraries + +After connecting your project to Tobiko Cloud, you're ready to set up the Airflow integration. + +Start by installing the `tobiko-cloud-scheduler-facade` library in your Airflow runtime environment. + +Make sure to include the `[airflow]` extra in the installation command: + +``` bash +$ pip install tobiko-cloud-scheduler-facade[airflow] +``` + +!!! 
info "Mac Users" + + On Mac OS, you may get the following error: + + `zsh: no matches found: tobiko-cloud-scheduler-facade[airflow]` + + In which case, the argument to `pip install` needs to be quoted like so: + + ``` + $ pip install 'tobiko-cloud-scheduler-facade[airflow]' + ``` + +### Connect Airflow to Tobiko Cloud + +First, provision an OAuth Client for Airflow to use by following the guide on how to [provision client credentials](../security/single_sign_on.md#provisioning-client-credentials). + +After provisioning the credentials, you can obtain the `Client ID` and `Client Secret` values for Airflow to use to connect to Tobiko Cloud. + +Next, add an Airflow [connection](https://airflow.apache.org/docs/apache-airflow/stable/howto/connection.html#creating-a-connection-with-the-ui) containing your Tobiko Cloud credentials. + +Specify these fields when adding the connection: + +- **Connection ID**: connection name of your choice + - May not contain spaces, single quotes `'`, or double quotes `"` +- **Connection Type**: always HTTP +- **Host**: URL for your Tobiko Cloud project +- **Login**: OAuth `Client ID` for Airflow +- **Password**: OAuth `Client Secret` for Airflow + +It is convenient to specify the connection in the Airflow UI, as in this example with the name `tobiko_cloud`: + +![Add a connection in the Airflow UI](./airflow/add_connection.png) + +If the connection is successful, it will appear in the connection list: + +![List of connections in the Airflow UI](./airflow/connection_list.png) + +!!! info "Remember the connection name!" + + Name the connection whatever you like, but remember that name because it's used for the `conn_id` parameter below. + +## Create a DAG + +You are now ready to create an Airflow DAG that connects to Tobiko Cloud. 
- Creating a `SQLMeshEnterpriseAirflow` instance with your Airflow connection id (the name from [above](#connect-airflow-to-tobiko-cloud))
+ +This section describes how to extend your local DAG and demonstrates some simple extensions. + +### Base DAG structure + +The local DAG represents your SQLMesh project's models and their activity in Tobiko Cloud. This section describes how the DAG is structured. + +The DAG is composed of SQLMesh models, but there must be a boundary around those models to separate them from your broader Airflow pipeline. The boundary consists of two tasks that serve as entry and exit nodes for the entire Tobiko Cloud run. + +The first and last tasks in the DAG are the boundary tasks. The tasks are the same in every local DAG instance: + +- First task: `Sensor` task that synchronizes with Tobiko Cloud +- Last task: `DummyOperator` task that ensures all models without downstream dependencies have completed before declaring the DAG completed + +![Airflow DAG boundary tasks](./airflow/boundary_tasks.png) + +### Using `create_cadence_dag()` + +The local DAG is extended at the time of creation via arguments to the `create_cadence_dag()` method. + +Each DAG corresponds to a specific SQLMesh project environment (`prod` by default). Specify another environment by passing its name to `create_cadence_dag()`'s `environment` argument. + +The `create_cadence_dag()` method returns a tuple of references: + +- `first_task` - a reference to the first task in the DAG (always the `Sensor` boundary task) +- `last_task` - a reference to the last task in the DAG (always the `DummyOperator` boundary task) +- `dag` - a reference to the Airflow `DAG` object + +Use these references to manipulate the DAG and attach extra behavior. 
Attach a task to the `first_task` to send a Slack notification when a `run` begins:
View a task's `task_id` by hovering over its node in the Airflow DAG view. + + Each model's display name in the Airflow DAG view is just the *table* portion of the fully qualified model name. For example, a SQLMesh model named `foo.model_a` will be labeled `model_a` in the Airflow DAG view. + +## Configuration + +### `SQLMeshEnterpriseAirflow` parameters + +| Option | Description | Type | Required | +|-----------|--------------------------------------------------------------------------|:----:|:--------:| +| `conn_id` | The Airflow connection ID containing the Tobiko Cloud connection details | str | Y | + +### `create_cadence_dag()` parameters + +| Option | Description | Type | Required | +|----------------------|----------------------------------------------------------------------------------------|:----:|:--------:| +| `environment` | Which SQLMesh environment to target. Default: `prod` | str | N | +| `dag_kwargs` | A dict of arguments to pass to the Airflow DAG object when it is created. 
| dict | N | +| `common_task_kwargs` | A dict of kwargs to pass to all task operators in the DAG | dict | N | +| `sensor_task_kwargs` | A dict of kwargs to pass to just the sensor task operators in the DAG | dict | N | +| `report_task_kwargs` | A dict of kwargs to pass to just the model / progress report task operators in the DAG | dict | N | \ No newline at end of file diff --git a/docs/cloud/features/scheduler/airflow/add_connection.png b/docs/cloud/features/scheduler/airflow/add_connection.png new file mode 100644 index 0000000000..e73e6ef2e3 Binary files /dev/null and b/docs/cloud/features/scheduler/airflow/add_connection.png differ diff --git a/docs/cloud/features/scheduler/airflow/add_task_after_specific_model.png b/docs/cloud/features/scheduler/airflow/add_task_after_specific_model.png new file mode 100644 index 0000000000..00c527803f Binary files /dev/null and b/docs/cloud/features/scheduler/airflow/add_task_after_specific_model.png differ diff --git a/docs/cloud/features/scheduler/airflow/add_task_at_end.png b/docs/cloud/features/scheduler/airflow/add_task_at_end.png new file mode 100644 index 0000000000..c8c4753799 Binary files /dev/null and b/docs/cloud/features/scheduler/airflow/add_task_at_end.png differ diff --git a/docs/cloud/features/scheduler/airflow/add_task_at_start.png b/docs/cloud/features/scheduler/airflow/add_task_at_start.png new file mode 100644 index 0000000000..e3630ce915 Binary files /dev/null and b/docs/cloud/features/scheduler/airflow/add_task_at_start.png differ diff --git a/docs/cloud/features/scheduler/airflow/boundary_tasks.png b/docs/cloud/features/scheduler/airflow/boundary_tasks.png new file mode 100644 index 0000000000..5bc69990d5 Binary files /dev/null and b/docs/cloud/features/scheduler/airflow/boundary_tasks.png differ diff --git a/docs/cloud/features/scheduler/airflow/cloud_debugger.png b/docs/cloud/features/scheduler/airflow/cloud_debugger.png new file mode 100644 index 0000000000..ec62f6b3bb Binary files /dev/null and 
b/docs/cloud/features/scheduler/airflow/cloud_debugger.png differ diff --git a/docs/cloud/features/scheduler/airflow/connection_list.png b/docs/cloud/features/scheduler/airflow/connection_list.png new file mode 100644 index 0000000000..37d0e85dde Binary files /dev/null and b/docs/cloud/features/scheduler/airflow/connection_list.png differ diff --git a/docs/cloud/features/scheduler/airflow/dag_list.png b/docs/cloud/features/scheduler/airflow/dag_list.png new file mode 100644 index 0000000000..5cfbbbf2e0 Binary files /dev/null and b/docs/cloud/features/scheduler/airflow/dag_list.png differ diff --git a/docs/cloud/features/scheduler/airflow/dag_view.png b/docs/cloud/features/scheduler/airflow/dag_view.png new file mode 100644 index 0000000000..535f31602d Binary files /dev/null and b/docs/cloud/features/scheduler/airflow/dag_view.png differ diff --git a/docs/cloud/features/scheduler/airflow/task_logs.png b/docs/cloud/features/scheduler/airflow/task_logs.png new file mode 100644 index 0000000000..5036c63c48 Binary files /dev/null and b/docs/cloud/features/scheduler/airflow/task_logs.png differ diff --git a/docs/cloud/features/scheduler/dagster.md b/docs/cloud/features/scheduler/dagster.md new file mode 100644 index 0000000000..338bed2572 --- /dev/null +++ b/docs/cloud/features/scheduler/dagster.md @@ -0,0 +1,389 @@ +# Dagster + +Tobiko Cloud's Dagster integration allows you to combine Dagster system monitoring with the powerful debugging tools in Tobiko Cloud. + +
+ +## How it works + +Tobiko Cloud uses a custom approach to Dagster integration. + +The `mirror` job mirrors the progress of the Tobiko Cloud scheduler run. Each local task reflects the outcome of its corresponding remote task. If an asset is materialized remotely, the job emits a Dagster materialization event. + +This allows you to observe at a glance how your data pipeline is progressing, displayed alongside your other pipelines in Dagster. No need to context switch to Tobiko Cloud! + +### Why a custom approach? + +Tobiko Cloud's scheduler performs multiple optimizations to ensure that your pipelines run correctly and efficiently. Those optimizations are only possible within our SQLMesh-aware scheduler. + +Our approach allows you to benefit from those optimizations while retaining the flexibility to attach extra tasks or logic to the Dagster Assets created by Tobiko Cloud. + +Because `run`s are still triggered by the Tobiko Cloud scheduler and tasks in the local DAG just reflect their remote equivalent in Tobiko Cloud, we call our custom approach a *facade*. + +## Setup + +Your SQLMesh project must be configured and connected to Tobiko Cloud before using the Dagster integration. + +Learn more about connecting to Tobiko Cloud in the [Getting Started page](../../tcloud_getting_started.md). + +!!! info "Supported Dagster versions" + This integration is supported on Dagster 1.9.1 or later. Earlier versions may work but they are not tested. + +### Configure Dagster project + +After connecting your project to Tobiko Cloud, you're ready to set up the Dagster integration. + +First, navigate to your Dagster project or [create a new one](https://docs.dagster.io/guides/build/projects/creating-a-new-project). 
```python title="pyproject.toml" hl_lines="4"
[project]
dependencies = [
    "dagster",
    "tobiko-cloud-scheduler-facade[dagster]"
]
```
+ +This example code demonstrates the creation process, which requires: + +- Importing the `SQLMeshEnterpriseDagster` class from the `tobikodata` Python library +- Creating a `SQLMeshEnterpriseDagster` instance configured with the environment variables from the project's `.env` file +- Creating a `Definitions` object with the instance's `create_definitions()` method + +In your Dagster project's `definitions.py` file, insert the following: + +```python title="definitions.py" linenums="1" +from tobikodata.scheduler_facades.dagster import SQLMeshEnterpriseDagster +from dagster import EnvVar # for accessing variables in .env file + +# create and configure SQLMeshEnterpriseDagster instance named `sqlmesh` +sqlmesh = SQLMeshEnterpriseDagster( + url=EnvVar("TCLOUD_BASE_URL").get_value(), # environment variable from .env file + oauth_client_id=EnvVar("TCLOUD_CLIENT_ID").get_value(), # environment variable from .env file + oauth_client_secret=EnvVar("TCLOUD_CLIENT_SECRET").get_value(), # environment variable from .env file +) + +# create Definitions object with `sqlmesh` object's `create_definitions()` method +tobiko_cloud_definitions = sqlmesh.create_definitions(environment="prod") +``` + +!!! info + If there is an existing definitions object already declared in your Dagster project, merge in the Tobiko Cloud definitions like this: + + ```python + defs = Definitions(...) # existing Definitions object + + defs = Definitions.merge(defs, sqlmesh.create_definitions(environment="prod")) + ``` + +This is all that's needed to integrate with Tobiko Cloud! + +Once Dagster loads your project, the new SQLMesh objects will be available. + +## Available Dagster objects + +The Tobiko Cloud Dagster integration exports the following objects to Dagster: + +- An `Asset` object for every SQLMesh Model + ![Dagster UI Asset Lineage](./dagster/asset_lineage.png) +

+ +- An `AssetCheck` object attached to the relevant `Asset`'s for every SQLMesh Audit + ![Dagster UI Asset Checks](./dagster/asset_check_list.png) +

- Two `Jobs`:
    - A `sync` job to synchronize the current state of all Assets and Asset Checks from Tobiko Cloud to Dagster
    - A `mirror` job that tracks a Cloud Scheduler run and mirrors the results to Dagster

+ +- A `Sensor` to monitor Tobiko Cloud for new cadence runs and trigger the `mirror` job when one is detected + ![Dagster UI Sensor List](./dagster/sensor_list.png) + +Once your Definitions are loaded by Dagster, these objects will be available in the Dagster UI. + +## Monitor Tobiko Cloud actions + +Dagster retrieves information from Tobiko Cloud with a [Sensor](https://docs.dagster.io/guides/automate/sensors). + +To start monitoring Tobiko Cloud actions, enable the Sensor [in the Dagster UI](https://docs.dagster.io/guides/automate/sensors/monitoring-sensors-in-the-dagster-ui): + +![Enable the track sensor](./dagster/enable_sensor.png) + +The Sensor is configured to run every 30 seconds. It does the following: + +- On the first run, it triggers the `sync` job. This synchronizes the materialization status of the Dagster assets with the Models and Audits from Tobiko Cloud. +- On subsequent runs, it checks if a new Cloud Scheduler run has occurred. If so, it triggers the `mirror` job to mirror the outcome of that run in Dagster. + +![Job run records in the Dagster UI](./dagster/job_run_records.png) + +!!! question "Why are there two jobs?" + The Tobiko Cloud scheduler does everything it can to prevent unnecessary work, such as only reporting materialization information for the models that were updated in a run. + + Therefore, Dagster does not receive materialization information for excluded models or objects that are never part of a cadence run (such as [seeds](../../../concepts/models/seed_models.md)). + + The `sync` jobs addresses this by copying the current state of the entire project. The `mirror` job then updates that information based on what happens during a specific cadence run. + +To manually refresh materialization information for all models, run the `sync` job manually from the Dagster UI: + +![Manually run the sync job](./dagster/manual_sync_run.png) + +## Debugging + +When something goes wrong, the first priority is getting more information. 
+ +Tobiko Cloud makes it easy to access that information from Dagster via links to each object's corresponding remote task in Tobiko Cloud. + +In the Dagster UI, the links are available in the job's Logs page: + +![Dagster Job Logs](./dagster/job_logs.png) + +Alternatively, in the Asset Catalog, the link is included in the last evaluation's logs as Metadata: + +![Dagster Asset Metadata](./dagster/asset_latest_materialization_metadata.png) + +Clicking the link opens the remote task in the Tobiko Cloud [Debugger View](../debugger_view.md), which provides information and tools to aid debugging: + +![Tobiko Cloud UI debugger view](./airflow/cloud_debugger.png) + +## Picking up new Models + +Dagster does not automatically reload the Asset `Definitions` defined in the code [above](#create-dagster-objects). This means that models added or removed during a `plan` will not automatically appear in Dagster. + +This section describes two methods for refreshing Dagster and picking up those models. + +### Automatic method + +Dagster runs user code in an isolated sandbox for security purposes, which complicates automatic reloading of Asset `Definitions`. + +Specifically, we must use Dagster's GraphQL API, which is not enabled by default. To enable it, specify a GraphQL host and port when creating the `SQLMeshEnterpriseDagster` instance: + +```python title="definitions.py" linenums="1" hl_lines="4 5" +sqlmesh = SQLMeshEnterpriseDagster( + url=EnvVar("TCLOUD_BASE_URL").get_value(), + #...SNIP..., + dagster_graphql_host="localhost", # Example GraphQL host (could be passed in an environment variable instead) + dagster_graphql_port=3000 # Example GraphQL port (could be passed in an environment variable instead) +) +``` + +The GraphQL host and port above reflect the specific Dagster deployment used for this example. (A Dagster deployment in local development mode that was started with `dagster dev` typically uses hostname `localhost` and port `3000`.) 
+ +The values you should specify for `dagster_graphql_host` and `dagster_graphql_port` depend on the GraphQL hostname and port in your Dagster deployment. + +The `mirror` job's Sensor automatically picks up the new/removed assets by issuing a GraphQL request to reload the Code Location. It then executes the `mirror` job as usual. + +### Manual method + +At any time, you can update Dagster's asset information by clicking the "Reload" Code Location button: + +![Dagster reload code location](./dagster/reload_code_location.png) + +## Attaching custom logic + +Dagster includes a robust events system that lets you detect and respond to events issued by your Assets or Jobs. + +Tobiko Cloud's Dagster integration lets you run your own custom logic in response to events emitted by Tobiko Cloud assets. + +To listen for materialization events on Assets, use an [Asset Sensor](https://docs.dagster.io/concepts/partitions-schedules-sensors/asset-sensors). + +To listen for job runs, use a [Run Status Sensor](https://docs.dagster.io/concepts/partitions-schedules-sensors/sensors#run-status-sensors). + +Dagster also provides a framework called [Declarative Automation](https://docs.dagster.io/concepts/automation/declarative-automation) that builds on top of these sensors. + +### Examples + +Here are some examples of running custom logic in response to Tobiko Cloud events. + +Note that Dagster has a lot of flexibility in how it can be configured, and the methods we describe below aren't necessarily the right choice for every configuration. + +We recommend familiarizing yourself with Dagster's [Automation](https://docs.dagster.io/concepts/automation) features to get the most out of your Tobiko Cloud deployment with Dagster. 
+ +#### Respond to run status + +To listen for Tobiko Cloud run events, create a [Run Status Sensor](https://docs.dagster.io/concepts/partitions-schedules-sensors/sensors#run-status-sensors) that listens for events on the `mirror` job and triggers your custom job in response. + +![Dagster run status sensor](./dagster/run_status_sensor.png) + +Creating the Run Status Sensor has three steps: defining a custom job, detecting Tobiko Cloud events, and creating a Sensor that executes your custom job when events are detected. + +Step 1: Define a custom job + +Your custom job has full access to Python and any libraries installed in your Dagster environment, so you can implement any logic you like. + +Define a function executing your logic and decorate it with `@op`. Then define a function calling the logic function and decorate it with `@job` to group it into a Job: + +``` python linenums="1" +# function that implements custom logic +@op +def send_email(): + import smtplib + + with smtplib.SMTP("smtp.yourdomain.com") as server: + server.sendmail(...) + +# function that creates job to execute custom logic function +@job +def send_email_job(): + send_email() +``` + +Step 2: Detect Tobiko Cloud events + +There are two approaches to detecting Tobiko Cloud events. In the examples below, both approaches create a `mirror_job` object used by the Run Status Sensor. + +The reference approach detects events based on a reference to the Tobiko Cloud mirror job, which is always named `tobiko_cloud_mirror_run_prod`. 
It extracts the reference from the Definitions object we created above: + +``` python +mirror_job = tobiko_cloud_definitions.get_job_def("tobiko_cloud_mirror_run_prod") +``` + +Alternatively, use the [JobSelector](https://docs.dagster.io/concepts/partitions-schedules-sensors/sensors#cross-code-location-run-status-sensors) approach if the Tobiko Cloud Definitions are in their own Code Location and not directly accessible from your job code: + +``` python +from dagster import JobSelector + +mirror_job = JobSelector(job_name="tobiko_cloud_mirror_run_prod") +``` + +Step 3: Create a Run Status Sensor + +With our `mirror_job` object in hand, we are ready to create a `@run_status_sensor` that listens to the mirror job and triggers your custom job when the mirror job is complete: + +```python +@run_status_sensor( + run_status=DagsterRunStatus.SUCCESS, + monitored_jobs=[mirror_job], # Sensor should listen to `mirror_job` + request_job=send_email_job # Sensor should execute `send_email_job` when `mirror_job` is complete +) +def on_tobiko_cloud_start_run(context: RunStatusSensorContext): + return RunRequest() +``` + +You can adjust the decorator's `run_status` argument to listen for different statuses, the `monitored_jobs` argument to monitor other Tobiko Cloud jobs, and the `request_job` argument to trigger a different custom job when an event is detected. 
+ +Here's an example that triggers a Slack notification when a new run starts: + +```python title="Sensor sends Slack notification when a new run starts" linenums="1" +from dagster import run_status_sensor, job, DagsterRunStatus, EnvVar, RunRequest, RunStatusSensorContext +from dagster_slack import SlackResource + +# get reference to mirror job +mirror_job = tobiko_cloud_definitions.get_job_def("tobiko_cloud_mirror_run_prod") + +# define custom logic function +@op +def slack_op(slack: SlackResource): + # see the dagster-slack docs here: https://docs.dagster.io/_apidocs/libraries/dagster-slack + slack.get_client().chat_postMessage(channel="#notifications", ...) + +# define job function that calls custom logic function +@job +def notify_slack(): + slack_op() + +# define Sensor +@run_status_sensor( + run_status=DagsterRunStatus.STARTED, # Listens for STARTED runs + monitored_jobs=[mirror_job], + request_job=notify_slack +) +def on_tobiko_cloud_start_run(context: RunStatusSensorContext): + return RunRequest() +``` + +#### Respond to Asset Materialization + +When Tobiko Cloud refreshes or adds new data to a model, a Materialization event occurs for its corresponding Asset in Dagster. The Materialization event provides a hook we can use to run custom logic. + +![Dagster asset sensor](./dagster/asset_sensor.png) + +As before, the custom logic can do anything you want, such as triggering the materialization of another Asset fully managed by Dagster or running some custom task. Triggering the materialization of Tobiko Cloud Assets will not work correctly, as they simply reflect the operations performed by Tobiko Cloud. + +To listen for Asset Materialization events, create an [Asset Sensor](https://docs.dagster.io/concepts/partitions-schedules-sensors/asset-sensors). + +For example, let's say your Tobiko Cloud project has a model called `postgres.crm.customers`, and it's showing in the Dagster Asset Catalog under "postgres / crm / customers". 
+ +Define an Asset Sensor to respond to this model's materialization events like this: + +```python +from dagster import AssetKey, SensorEvaluationContext, EventLogEntry + +@job +def internal_customers_pipeline(): + # custom logic goes here + pass + +@asset_sensor( + asset_key=AssetKey(["postgres", "crm", "customers"]), # Asset key found in Dagster Asset Catalog + job=internal_customers_pipeline +) +def on_crm_customers_updated(context: SensorEvaluationContext, asset_event: EventLogEntry): + yield RunRequest() +``` + +The sensor will trigger every time the Asset with the key `postgres / crm / customers` is materialized. + +To identify the `AssetKey`'s of your Assets, check Dagster's Asset Catalog. Each part of the path is a segment of the Asset Key. + +![Dagster asset keys](./dagster/asset_keys.png) + +These `AssetKey` values correspond to the models in the screenshot above: + +```python +from dagster import AssetKey + +active_customers = AssetKey(["postgres", "sushi", "active_customers"]) +customer_revenue_by_day = AssetKey(["postgres", "sushi", "customer_revenue_by_day"]) +``` + +## Configuration + +### `SQLMeshEnterpriseDagster` parameters + +| Option | Description | Type | Required | +|----------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----:|:--------:| +| `url` | The Base URL to your Tobiko Cloud instance | str | Y | +| `oauth_client_id` | OAuth Client ID of the credentials you [provisioned](../security/single_sign_on.md#provisioning-client-credentials) for Dagster | str | N | +| `oauth_client_secret` | OAuth Client Secret of the credentials you [provisioned](../security/single_sign_on.md#provisioning-client-credentials) for Dagster | str | N | +| `dagster_graphql_host` | Hostname of the Dagster Webserver GraphQL endpoint | str | N | +| `dagster_graphql_port` | Port of the Dagster Webserver GraphQL 
endpoint | int | N | +| `dagster_graphql_kwargs` | Extra args to pass to the [DagsterGraphQLClient](https://docs.dagster.io/api/python-api/libraries/dagster-graphql#dagster_graphql.DagsterGraphQLClient) class when it is instantiated | dict | N | + +### `create_definitions()` parameters + +| Option | Description | Type | Required | +|----------------------------|----------------------------------------------------------------------------------------|:----:|:--------:| +| `environment` | Which SQLMesh environment to target. Default: `prod` | str | N | +| `asset_prefix` | Top-level category to nest Tobiko Cloud assets under | str | N | +| `enable_sensor_by_default` | Whether the Sensor that polls for new runs should be enabled by default. Default: True | bool | N | diff --git a/docs/cloud/features/scheduler/dagster/asset_check_list.png b/docs/cloud/features/scheduler/dagster/asset_check_list.png new file mode 100644 index 0000000000..9e01ea0352 Binary files /dev/null and b/docs/cloud/features/scheduler/dagster/asset_check_list.png differ diff --git a/docs/cloud/features/scheduler/dagster/asset_keys.png b/docs/cloud/features/scheduler/dagster/asset_keys.png new file mode 100644 index 0000000000..217d6c8a31 Binary files /dev/null and b/docs/cloud/features/scheduler/dagster/asset_keys.png differ diff --git a/docs/cloud/features/scheduler/dagster/asset_latest_materialization_metadata.png b/docs/cloud/features/scheduler/dagster/asset_latest_materialization_metadata.png new file mode 100644 index 0000000000..c3c39f2f6c Binary files /dev/null and b/docs/cloud/features/scheduler/dagster/asset_latest_materialization_metadata.png differ diff --git a/docs/cloud/features/scheduler/dagster/asset_lineage.png b/docs/cloud/features/scheduler/dagster/asset_lineage.png new file mode 100644 index 0000000000..26cd22a386 Binary files /dev/null and b/docs/cloud/features/scheduler/dagster/asset_lineage.png differ diff --git a/docs/cloud/features/scheduler/dagster/asset_sensor.png 
b/docs/cloud/features/scheduler/dagster/asset_sensor.png new file mode 100644 index 0000000000..fa38d8cefd Binary files /dev/null and b/docs/cloud/features/scheduler/dagster/asset_sensor.png differ diff --git a/docs/cloud/features/scheduler/dagster/enable_sensor.png b/docs/cloud/features/scheduler/dagster/enable_sensor.png new file mode 100644 index 0000000000..8c4fdc3557 Binary files /dev/null and b/docs/cloud/features/scheduler/dagster/enable_sensor.png differ diff --git a/docs/cloud/features/scheduler/dagster/job_list.png b/docs/cloud/features/scheduler/dagster/job_list.png new file mode 100644 index 0000000000..88786930fd Binary files /dev/null and b/docs/cloud/features/scheduler/dagster/job_list.png differ diff --git a/docs/cloud/features/scheduler/dagster/job_logs.png b/docs/cloud/features/scheduler/dagster/job_logs.png new file mode 100644 index 0000000000..6a7dcb88e2 Binary files /dev/null and b/docs/cloud/features/scheduler/dagster/job_logs.png differ diff --git a/docs/cloud/features/scheduler/dagster/job_run_records.png b/docs/cloud/features/scheduler/dagster/job_run_records.png new file mode 100644 index 0000000000..690186e332 Binary files /dev/null and b/docs/cloud/features/scheduler/dagster/job_run_records.png differ diff --git a/docs/cloud/features/scheduler/dagster/manual_sync_run.png b/docs/cloud/features/scheduler/dagster/manual_sync_run.png new file mode 100644 index 0000000000..64be419ead Binary files /dev/null and b/docs/cloud/features/scheduler/dagster/manual_sync_run.png differ diff --git a/docs/cloud/features/scheduler/dagster/reload_code_location.png b/docs/cloud/features/scheduler/dagster/reload_code_location.png new file mode 100644 index 0000000000..e4a28982e6 Binary files /dev/null and b/docs/cloud/features/scheduler/dagster/reload_code_location.png differ diff --git a/docs/cloud/features/scheduler/dagster/run_status_sensor.png b/docs/cloud/features/scheduler/dagster/run_status_sensor.png new file mode 100644 index 0000000000..e709094780 
Binary files /dev/null and b/docs/cloud/features/scheduler/dagster/run_status_sensor.png differ diff --git a/docs/cloud/features/scheduler/dagster/sensor_list.png b/docs/cloud/features/scheduler/dagster/sensor_list.png new file mode 100644 index 0000000000..f9461f3879 Binary files /dev/null and b/docs/cloud/features/scheduler/dagster/sensor_list.png differ diff --git a/docs/cloud/features/scheduler/hybrid_executors/hybrid-executors_standard-hybrid-deployment.png b/docs/cloud/features/scheduler/hybrid_executors/hybrid-executors_standard-hybrid-deployment.png new file mode 100644 index 0000000000..2877933d1f Binary files /dev/null and b/docs/cloud/features/scheduler/hybrid_executors/hybrid-executors_standard-hybrid-deployment.png differ diff --git a/docs/cloud/features/scheduler/hybrid_executors_docker_compose.md b/docs/cloud/features/scheduler/hybrid_executors_docker_compose.md new file mode 100644 index 0000000000..8f8f323139 --- /dev/null +++ b/docs/cloud/features/scheduler/hybrid_executors_docker_compose.md @@ -0,0 +1,121 @@ +# Tobiko Cloud Hybrid Executors - Docker Compose Setup + +
+ +This Docker Compose configuration allows you to run Tobiko Cloud hybrid executors locally or on any server that supports Docker Compose. + +Hybrid executors allow you to run operations on your own infrastructure while leveraging Tobiko Cloud for orchestration. + +## What this setup provides + +This setup deploys two hybrid executors that pass work tasks from Tobiko Cloud to your data warehouse in a secure way: + +- **Apply Executor**: Handles applying changes to the data warehouse +- **Run Executor**: Handles scheduled model execution + +Both executors must be properly configured with environment variables to connect to Tobiko Cloud and your data warehouse. + +## Prerequisites + +- Access to a [data warehouse supported by Tobiko Cloud](../../../integrations/overview.md#execution-engines) (e.g., Postgres, Snowflake, BigQuery) +- Docker and Docker Compose +- A Tobiko Cloud account with [client ID and client secret](../security/single_sign_on.md#provisioning-client-credentials) + +## Quick start guide + +1. **Get docker-compose file**: + + Download the [docker-compose.yml](https://raw.githubusercontent.com/SQLMesh/sqlmesh/refs/heads/main/docs/cloud/features/scheduler/scheduler/docker-compose.yml) and [.env.example](https://raw.githubusercontent.com/SQLMesh/sqlmesh/refs/heads/main/docs/cloud/features/scheduler/scheduler/.env.example) files to a local directory. + +2. **Create your environment file**: + + Copy the downloaded example environment file into a new `.env` file: + + ```bash + cp .env.example .env + ``` + +3. **Edit the .env file** with your project's configuration: + + - Set your Tobiko Cloud organization, project, client ID, and client secret + - Configure your gateway connection details + - Adjust resource limits if needed + +4. **Start the executors**: + + ```bash + docker compose up -d + ``` + +5. 
**Check the logs**: + + ```bash + docker compose logs -f + ``` + +## Configuration options + +### Gateway configuration + +The default configuration in the `docker-compose.yml` file uses Postgres, but you can use [any supported SQL engine](../../../integrations/overview.md#execution-engines) by adjusting the connection parameters in your `.env` file. + +#### Multiple gateways + +To configure multiple gateways, add additional environment variables for each gateway in the `docker-compose.yml` file: + +```yaml +environment: + # First gateway + SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__TYPE: ${DB_TYPE:-postgres} + # ... other GATEWAY_A configuration ... + + # Second gateway + SQLMESH__GATEWAYS__GATEWAY_B__CONNECTION__TYPE: snowflake + SQLMESH__GATEWAYS__GATEWAY_B__CONNECTION__ACCOUNT: ${SNOWFLAKE_ACCOUNT} + # ... other GATEWAY_B configuration ... +``` + +## Health checking + +Verify the health of your executors by running these commands: + +```bash +docker compose exec apply-executor /app/pex executor apply --check +docker compose exec run-executor /app/pex executor run --check +``` + +Example successful output: + +```bash +> docker compose exec apply-executor /app/pex executor apply --check +2025-04-09 21:24:49,873 - MainThread - httpx - INFO - HTTP Request: GET https://cloud.tobikodata.com/sqlmesh///api/state-sync/enterprise-version/upgrade "HTTP/1.1 200 OK" (_client.py:1025) +2025-04-09 21:24:49,889 - MainThread - tobikodata.tcloud.installer - INFO - Executor is installed (installer.py:180) +``` + +In addition, ensure the executors are healthy by running `echo $?` to confirm the check command returned exit code 0. + +## Stopping the executors + +To stop the executors: + +```bash +docker compose down +``` + +## Troubleshooting + +If you encounter issues: + +1. Check the logs: `docker compose logs -f` +2. Verify your connection settings in the `.env` file +3. Ensure your client ID and client secret are correct +4. 
Check that your SQL engine is accessible from the Docker containers + +## Security considerations + +!!! warning "Never commit .env to version control" + + The `.env` file contains sensitive information. Never commit it to version control. + +- Consider using Docker secrets or a secrets management solution in production environments. +- For production deployments, consider using the Kubernetes Helm chart instead, which offers more robust reliability and secret management options. \ No newline at end of file diff --git a/docs/cloud/features/scheduler/hybrid_executors_helm.md b/docs/cloud/features/scheduler/hybrid_executors_helm.md new file mode 100644 index 0000000000..b945ad6bd6 --- /dev/null +++ b/docs/cloud/features/scheduler/hybrid_executors_helm.md @@ -0,0 +1,328 @@ +# Tobiko Cloud Hybrid Executors Helm Chart + +This Helm chart deploys Tobiko Cloud hybrid executors, enabling your on-premise Kubernetes cluster to connect to Tobiko Cloud for operations. + +Hybrid executors allow you to run operations on your own infrastructure while leveraging Tobiko Cloud for orchestration. + +## What this chart does + +This chart deploys two hybrid executors that pass work tasks from Tobiko Cloud to your data warehouse in a secure way: + +- **Apply Executor**: Handles applying changes to the data warehouse +- **Run Executor**: Handles scheduled model execution + +Both executors must be properly configured with environment variables to connect to Tobiko Cloud and your data warehouse. + +## Prerequisites + +- Access to a [data warehouse supported by Tobiko Cloud](../../../integrations/overview.md#execution-engines) (e.g., Postgres, Snowflake, BigQuery) +- Helm 3.8+ +- A Tobiko Cloud account with [client ID and client secret](../security/single_sign_on.md#provisioning-client-credentials) + +## Quick start guide + +Create a `values.yaml` file with your Tobiko Cloud configuration. 
+ +```bash +# Create a values file +cat > my-values.yaml << EOF +global: + cloud: + org: "your-organization" + project: "your-project" + clientId: "your-client-id" + clientSecret: "your-client-secret" + sqlmesh: + gateways: + gateway_a: + connection: + type: postgres + host: "your-database-host" + port: 5432 + database: "your-database" + user: "your-database-user" +EOF +``` + +### Option 1: Install Directly with Helm + +```bash +# Install the chart from local directory +helm install executors oci://registry-1.docker.io/tobikodata/hybrid-executors -f my-values.yaml +``` + +### Option 2: Generate YAML files without installing + +If you prefer to review and apply the Kubernetes YAML files manually: + +```bash +# Generate YAML files without installing +helm template executors oci://registry-1.docker.io/tobikodata/hybrid-executors -f my-values.yaml > generated-manifests.yaml +``` + +```bash +# Review the generated files +cat generated-manifests.yaml +``` + +```bash +# Apply when ready +kubectl apply -f generated-manifests.yaml +``` + +## Basic configuration + +The most important configuration values are: + +| Parameter | Description | Required | +|-----------------------------|-------------------------------------|--------------------| +| `global.cloud.org` | Your Tobiko Cloud organization name | Yes | +| `global.cloud.project` | Your Tobiko Cloud project name | Yes | +| `global.cloud.clientId` | Your Tobiko Cloud client ID | Yes | +| `global.cloud.clientSecret` | Your Tobiko Cloud client secret | Yes | +| `global.sqlmesh.gateways` | Database connections configuration | Yes | + +### Gateway configuration + +Configure your gateway's SQL engine connection in the `global.sqlmesh.gateways` section: + +```yaml +global: + sqlmesh: + gateways: + gateway_a: # Put the default gateway first if defining multiple gateways + connection: + type: postgres # Or snowflake, bigquery, etc. 
+ host: "your-db-host" + port: 5432 + database: "sqlmesh" + user: "sqlmesh_user" + # Password should be managed as a secret (see below) +``` + +## Secret management options + +The chart provides multiple options for managing secrets. Use the one most aligned with your security requirements and deployment patterns. + +### Context: Helm's dynamic secret detection + +The chart automatically treats `global.cloud.clientSecret` and any gateway connection parameter with keywords `password`, `secret`, or `token` in its name as a secret: + +```yaml +global: + cloud: + clientId: "your-client-id" # Not a secret + clientSecret: "your-client-secret" # Automatically treated as a secret (contains keyword "secret") + sqlmesh: + gateways: + gateway_a: + connection: + type: postgres # Not a secret + host: "my-db-host" # Not a secret + password: "p@ssw0rd" # Automatically treated as a secret (contains keyword "password") + client_secret: "xyz123" # Automatically treated as a secret (contains keyword "secret") + api_token: "abc456" # Automatically treated as a secret (contains keyword "token") + access_key: "key123" # Not a secret +``` + +Use the `secretParams` key to force parameters to be treated as secrets (even if their name doesn't contain a secret keyword): + +```yaml +global: + sqlmesh: + # Force these parameters to be treated as secrets regardless of name + secretParams: ["access_key", "certificate"] + # Force these parameters to NOT be treated as secrets even if they contain secret keywords + nonSecretParams: ["token_endpoint", "password_policy"] +``` + +### Option 1: Secrets directly in values.yaml (development only) + +!!! warning "Development only" + + This approach is only recommended for development environments and testing. + + Never store secrets in plain text in version control. 
+ +Define secrets directly in the values file: + +```yaml +global: + cloud: + clientId: "your-tobiko-cloud-client-id" # Not a secret + clientSecret: "your-tobiko-cloud-client-secret" # Automatically treated as a secret +``` + +### Option 2: Existing Kubernetes Secrets + +Reference an existing Kubernetes Secret: + +```yaml +secrets: + existingSecret: "my-existing-secret" +``` + +The existing secret must contain the required keys: + +- `TCLOUD_CLIENT_SECRET` +- `SQLMESH__GATEWAYS__<GATEWAY_NAME>__CONNECTION__<PARAM_NAME>` for each gateway secret + +If your executors use different secrets, you can specify secrets at the executor level: + +```yaml +apply: + envFromSecret: "apply-executor-secrets" +run: + envFromSecret: "run-executor-secrets" +``` + +### Option 3: External Secrets Operator + +If you're using [External Secrets Operator](https://external-secrets.io/), you can pull secrets from your secret store: + +```yaml +secrets: + externalSecrets: + enabled: true + secretStore: "aws-secretsmanager" + keyPrefix: "sqlmesh/" +``` + +This will create an ExternalSecret resource that pulls secrets from your configured secret provider. 
+ +### Example: Creating a secret manually + +Create a secret with all required credentials: + +```bash +kubectl create secret generic my-sqlmesh-secrets \ + --from-literal=TCLOUD_CLIENT_SECRET=your-client-secret \ + --from-literal=SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__PASSWORD=your-password \ + --from-literal=SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__API_TOKEN=your-token \ + --from-literal=SQLMESH__GATEWAYS__GATEWAY_B__CONNECTION__CLIENT_SECRET=another-secret +``` + +Then reference it in your values: + +```yaml +secrets: + existingSecret: "my-sqlmesh-secrets" + +global: + cloud: + clientId: "your-client-id" # Not a secret + # Secret values will be loaded from the secret + # clientSecret: (loaded from secret) + sqlmesh: + gateways: + gateway_a: + connection: + # Define all non-secret parameters here + type: postgres + host: "my-db-host" + port: 5432 + database: "sqlmesh" + user: "sqlmesh_user" + # Secret values will be loaded from the secret + # password: (loaded from secret) + # api_token: (loaded from secret) + gateway_b: + connection: + type: snowflake + # Other non-secret parameters + # client_secret: (loaded from secret) +``` + +### Example: Using an existing secret + +```yaml +# values.yaml +secrets: + existingSecret: "my-sqlmesh-secrets" + +global: + image: + repository: tobikodata/tcloud + tag: latest + cloud: + org: "my-organization" + project: "my-project" + clientId: "your-client-id" + sqlmesh: + gateways: + gateway_a: + connection: + type: postgres + host: "my-db-host" + port: 5432 + database: "sqlmesh" + user: "sqlmesh_user" + +apply: + replicaCount: 1 + +run: + replicaCount: 2 +``` + +## Defining Custom Environment Variables + +If there are additional environment variables that are required to run your project, you will want to define them for both the apply and run executors. 
+ +```yaml +apply: + extraEnvVars: + - name: MY_CUSTOM_ENV_VAR + value: "my_value" +run: + extraEnvVars: + - name: MY_CUSTOM_ENV_VAR + value: "my_value" +``` + +## Customizing resources + +You can customize CPU, memory, and ephemeral-storage for each executor. This sets the resources for the `apply` executor: + +```yaml +apply: + resources: + requests: + memory: "2Gi" + cpu: "1" + ephemeral-storage: "10Gi" + limits: + memory: "4Gi" + cpu: "2" + ephemeral-storage: "10Gi" +``` + +## Verifying the installation + +After installation, check that the executors are running: + +```bash +kubectl get pods -l app.kubernetes.io/instance=my-executors +``` + +You should see pods for both apply and run executors: +``` +NAME READY STATUS RESTARTS AGE +my-executors-apply-7b6c9d8f9-abc12 1/1 Running 0 1m +my-executors-run-6d5b8c7e8-def34 1/1 Running 0 1m +``` + +## Troubleshooting + +If your executors aren't starting properly, check the logs: + +```bash +kubectl logs -l app.kubernetes.io/instance=my-executors,app.kubernetes.io/component=apply-executor +kubectl logs -l app.kubernetes.io/instance=my-executors,app.kubernetes.io/component=run-executor +``` + +Common issues: + +- Incorrect client ID or client secret +- SQL engine connection issues +- Insufficient permissions \ No newline at end of file diff --git a/docs/cloud/features/scheduler/hybrid_executors_overview.md b/docs/cloud/features/scheduler/hybrid_executors_overview.md new file mode 100644 index 0000000000..ae07cfc364 --- /dev/null +++ b/docs/cloud/features/scheduler/hybrid_executors_overview.md @@ -0,0 +1,172 @@ +# Overview + +In a standard deployment, Tobiko Cloud securely manages your data warehouse connections so it can run your project. + +However, you may prefer not to share your data warehouse credentials or want to bring the execution closer to your data. To support this, Tobiko Cloud offers hybrid deployments where we host the scheduler and you host the executors that perform the scheduled actions. 
+ +With this approach, Tobiko Cloud uses project metadata to manage SQLMesh user access control, schedule and trigger runs, and apply plans, but all data access and query execution occurs within your infrastructure. Tobiko Cloud has no access to your data or warehouse credentials. + +This gives you complete control over data security and network access while still benefiting from Tobiko Cloud's powerful scheduling capabilities. + +## How it works + +Tobiko Cloud has three primary tasks: determine what should happen when (scheduling), make those things happen (executing), and monitor everything that happens (observing). + +In a standard deployment, all three of these occur within the Tobiko Cloud environment. You configure a gateway in Tobiko Cloud, and Tobiko Cloud uses it to execute work tasks (such as a `plan` or `run`). + +In a hybrid deployment, Tobiko Cloud does not execute tasks directly with the engine. Instead, it passes tasks to the executors hosted in your environment, which then execute the tasks with the engine. This extra layer between Tobiko Cloud and your SQL engine means Tobiko Cloud has no knowledge of your credentials. + +Executors are Docker containers that connect to both Tobiko Cloud and your SQL engine. They pull work tasks from the Tobiko Cloud scheduler and execute them with your SQL engine. + +![Architecture for standard and hybrid deployments](./hybrid_executors/hybrid-executors_standard-hybrid-deployment.png) + +## Deployment Options + +You can deploy the executor containers using any method that works for your infrastructure and operational requirements. + +The executors are standard Docker containers that can be deployed in any container environment as long as they're configured with the required environment variables. + +We provide two reference implementations: + +1. 
[**Kubernetes with Helm Chart**](./hybrid_executors_helm.md): For production environments, we provide a [Helm chart](./hybrid_executors_helm.md) that includes robust configurability, secret management, and scaling options. + +2. [**Docker Compose**](./hybrid_executors_docker_compose.md): For simpler environments or testing, we offer a [Docker Compose setup](./hybrid_executors_docker_compose.md) to quickly deploy executors on any machine with Docker. + +You're free to adapt these reference implementations or create your own deployment method that fits your specific needs. + +As described below, two executor instances must be running and properly configured at all times (one executor for `run` operations and one for `apply` operations). + +## Configuration + +This section describes basic configuration concepts for hybrid executors. For detailed configuration options, refer to the documentation for your chosen deployment method above. + +Tobiko Cloud requires 2 executor instances to be running at all times: + +1. **Run Executor**: Handles scheduled model execution +2. **Apply Executor**: Handles applying changes to the data warehouse + +Both executors need to be properly configured with environment variables for connecting to Tobiko Cloud and your data warehouse. + +### Environment Variables + +Executors require different types of information to connect to Tobiko Cloud and your data warehouse. Provide that information via environment variables. + +#### TCLOUD variables + +One important type of environment variable is the `TCLOUD` variables used for connecting to Tobiko Cloud. + +The first required `TCLOUD` variable is a unique Tobiko Cloud URL for your project, which your Solutions Architect will provide after your project is created. + +You also need the Client ID and Client Secret variables, which are generated when you [create an OAuth Client](../security/single_sign_on.md#provisioning-client-credentials) in the Tobiko Cloud UI. 
+ +Specify the URL, Client ID, and Client Secret in these environment variables: + +``` bash +TCLOUD_URL={your Tobiko Cloud project URL} +TCLOUD_CLIENT_ID={your Client ID} +TCLOUD_CLIENT_SECRET={your Client Secret} +``` + +!!! important "Set TCLOUD variables on the Docker container" + + Environment variables used for connecting to Tobiko Cloud, such as `TCLOUD_URL`, `TCLOUD_CLIENT_ID`, and `TCLOUD_CLIENT_SECRET`, must be set on the executor's Docker container. + +#### Other environment variables + +The executors also require configuration parameters for other aspects of your project. + +For example, your executor must know how to connect to your SQL engine, so you must configure a [gateway via environment variables](../../../guides/configuration.md#overrides). + +This example specifies a Postgres gateway named "GATEWAY_A" and sets it as the default gateway: + +```env +SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__TYPE=postgres +SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__HOST=10.10.10.10 +SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__PORT=5432 +SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__DATABASE=example_db +SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__USER=example_user +SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__PASSWORD=example_password + +# make it the default gateway +SQLMESH__DEFAULT_GATEWAY=GATEWAY_A +``` + +**Note**: If your project uses multiple gateways, each gateway requires its own set of environment variables. + +For example, we might add a second gateway named "GATEWAY_B" like this. Note that the gateway names `GATEWAY_A` and `GATEWAY_B` are embedded in the environment variable name: + +```env +SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__TYPE=<gateway A engine type> +# <other GATEWAY_A connection variables> + +SQLMESH__GATEWAYS__GATEWAY_B__CONNECTION__TYPE=<gateway B engine type> +# <other GATEWAY_B connection variables> +``` + +For more configuration details, including secure secret management options, refer to the [Helm chart](./hybrid_executors_helm.md) or [Docker Compose](./hybrid_executors_docker_compose.md) deployment documentation. 
+ +After the executors are configured and running, they will connect to Tobiko Cloud. Once connected, they will appear in the cloud UI and be ready to apply `plan`s and execute scheduled `run`s. + +![executors](./scheduler/executors.png) + +We strongly recommend setting up a monitoring system for the executor containers to ensure they run smoothly and to help troubleshoot issues. Monitoring should include logs and system metrics like memory and CPU usage. + +#### .env file + +As mentioned above, `TCLOUD` environment variables **must** be set on the executor's Docker container. However, other environment variables can be set later because they are only needed after the executor is connected to Tobiko Cloud. + +Instead of setting a variable on the container, you can use a `.env` file. This can be useful for environment variables that require frequent updating, such as short-lived API tokens. + +!!! note "Only if necessary" + + Using a `.env` file increases the complexity of your deployment, so only use it if you have environment variables that require frequent updates. + +Create a `.env` file and define your environment variables in it. Then, mount the `.env` file into the docker image. Any external process can update the file with a new variable value, and the executor will automatically pick up the changes. + +Tell the executor where to find the `.env` file by specifying the file's full path in the `TCLOUD` environment variable `TCLOUD_ENV_FILE`. + +**Note**: the `TCLOUD_ENV_FILE` environment variable **must** be set on the executor's Docker container, just like the `TCLOUD_URL`, `TCLOUD_CLIENT_ID`, and `TCLOUD_CLIENT_SECRET` variables. + +### Network Configuration + +Tobiko Cloud never pushes information into your network, so it doesn't need inbound access. + +Instead, the executors you host make outbound requests polling Tobiko Cloud for work tasks and other information. 
Executors are only required to connect to Tobiko Cloud and your SQL engine, with all information flowing via outbound requests. + +If the executors are running in a network without public internet access, configure the network to allow executor and local user access to Tobiko Cloud on these IP addresses: + +```bash +34.28.17.91 +34.136.27.153 +34.136.131.201 +``` + +### Project Configuration + +Project configuration is the same for hybrid and standard Tobiko Cloud deployments. + +See [connection configuration](./scheduler.md#connection-configuration) for details. + +## Required system specs + +The exact system requirements for executors vary depending on your project and work processes. + +In general, we recommend a minimum of 2GB of RAM and 1 vCPU for each executor. + +Complex or resource-intensive Python models may require executors with more resources. + +## Health checks + +In production settings, we recommend setting up health checks to monitor the status of your executors. + +Health checks help ensure your executors are operating correctly and can identify issues before they impact your workflows. + +For detailed information on implementing health checks: + +- **Kubernetes/Helm**: see the [Hybrid Executors Helm Chart documentation](./hybrid_executors_helm.md#verifying-the-installation) for information on health check configuration in Kubernetes. + +- **Docker Compose**: see the [Docker Compose setup documentation](./hybrid_executors_docker_compose.md#health-checking) for health check implementation with Docker Compose. + +Both executor types (run and apply) should have appropriate health checks to ensure proper system monitoring and reliability. + +**Note**: When configuring health checks, ensure timeouts are set appropriately based on your executors' resources. Default timeouts can sometimes be too short. 
diff --git a/docs/cloud/features/scheduler/scheduler.md b/docs/cloud/features/scheduler/scheduler.md new file mode 100644 index 0000000000..5d28a3be50 --- /dev/null +++ b/docs/cloud/features/scheduler/scheduler.md @@ -0,0 +1,243 @@ +# Scheduler + +Tobiko Cloud offers scheduling capabilities that have several advantages over the scheduler built into the open source version of SQLMesh. + +## Cloud scheduler benefits + +This section describes the specific advantages of using the Tobiko Cloud scheduler. + +### Schedule executions + +With Tobiko Cloud, users don't need to configure a cron job that periodically runs the `sqlmesh run` command. + +Instead, Tobiko Cloud automatically schedules model execution based on the cron expressions in the project's model definitions. + +### Concurrent runs + +Unlike the built-in scheduler, Tobiko Cloud parallelizes both model executions and run jobs. + +This means that if one run job is blocked by a long-running model, other independent models can still execute concurrently in separate run jobs. + +### Run pausing + +Tobiko Cloud allows you to pause and resume model execution at both the environment and individual model level. + +This granular control helps prevent problems during maintenance windows and troubleshoot issues. + +### Isolated Python environments + +Tobiko Cloud automatically manages Python dependencies of your Python macros and models. + +Each virtual environment has its own isolated Python environment and set of dependencies, ensuring that changes in one environment won't affect other environments. + +### Improved concurrency control + +The cloud scheduler ensures that plans targeting the same environment are applied sequentially, preventing race conditions and ensuring correct results. + +### Access control + +Tobiko Cloud manages your data warehouse connection. This allows users to execute `run` and `plan` commands without needing local access to warehouse credentials. 
+ +Tobiko Cloud also provides fine-grained access control for the `run` and `plan` commands. User permissions may be limited to specific environments or models (coming soon). + +## Using the Cloud scheduler + +This section describes how to configure and use the Tobiko Cloud scheduler. + +### Connection configuration + +To start using the cloud scheduler, configure the connection to your data warehouse in the Tobiko Cloud UI. +

+ +**Step 1**: Click on the "Settings" tab in the sidebar. + +![settings_tab](./scheduler/settings_tab.png) +

+ +**Step 2**: Navigate to the "Connections" tab (1) and click on the "Add Connection" button (2). + +![add_connection](./scheduler/add_connection.png) +

+ +**Step 3**: Enter the name of the gateway in the Gateway field and the connection configuration in YAML format in the YAML Configuration field. + +The format follows the [connection configuration](../../../guides/configuration.md#connections) in the SQLMesh Connections guide. + +!!! warning "Gateway names must match" + + This gateway name must match the name of the gateway specified in the project's `config.yaml` file. + +![add_connection_form](./scheduler/add_connection_form.png) +

+ +**Step 4**: Click the "Save" button to add the connection. + +The connection will be tested and only saved if the connection is successful. The configuration is stored in encrypted form using AES-256 encryption and is only decrypted for execution purposes. + +![add_connection_success](./scheduler/add_connection_success.png) +

+ +**Step 5**: Switch to the Cloud scheduler in the project's configuration file. + +Update your project's `config.yaml` file to specify a scheduler of type `cloud`. + +!!! warning "Gateway name must match" + + This gateway name must match the name of the gateway that was specified when adding the connection in the Tobiko Cloud UI. + +=== "YAML" + + ```yaml linenums="1" hl_lines="3 4" + gateways: + gateway_a: + scheduler: + type: cloud + + default_gateway: gateway_a + ``` + +=== "Python" + + ```python linenums="1" hl_lines="8" + from sqlmesh.core.config import GatewayConfig + + from tobikodata.sqlmesh_enterprise.config import EnterpriseConfig, RemoteCloudSchedulerConfig + + config = EnterpriseConfig( + gateways={ + "gateway_a": GatewayConfig( + scheduler=RemoteCloudSchedulerConfig() + ), + }, + default_gateway="gateway_a", + ) + ``` + +### Pausing model executions + +Temporarily pausing model execution can be useful when troubleshooting issues or during maintenance windows. + +Tobiko Cloud allows you to pause and resume model execution at both the environment and individual model level. + +#### Pausing all models in an environment + +To pause all models in an environment, navigate to the environment's page and click the "Pause" button. + +This will pause **all** model executions in this environment. + +![pause_environment](./scheduler/pause_environment.png) + +To resume the environment, click the "Resume" button. + +![resume_environment](./scheduler/resume_environment.png) +

+ +#### Pausing a model + +To pause a model in a specific environment, navigate to the environment's page and click "See all pauses" (located next to the "Pause" button). + +![see_all_pauses](./scheduler/see_all_pauses.png) + +In that page, click the "Create Pause" button. + +![create_pause](./scheduler/create_pause.png) + +Select the model you want to pause and provide a reason for pausing it (optional). + +![create_pause_form](./scheduler/create_pause_form.png) + +Click the "Create" button in the bottom right. + +The target model and its downstream dependencies will not be run in this environment. + +!!! note "Paused models included in plans" + + Paused models will not execute during a `run`, but they will execute during a `plan` application (if affected by the plan's changes). +

+ +#### Resuming a model + +To resume a model, navigate to an environment's pauses page and click the "Delete" button for that model's pause. + +![delete_pause](./scheduler/delete_pause.png) + +## Python Dependencies + +Tobiko Cloud automatically manages Python dependencies of your Python macros and models. Each virtual environment has its own isolated Python environment where relevant libraries are installed, ensuring that changes in one environment won't affect other environments. + +SQLMesh automatically infers which Python libraries are used by statically analyzing the code of your models and macros. + +For fine-grained control, dependencies can be specified, pinned, or excluded using the `sqlmesh-requirements.lock` file. See the [Python library dependencies](../../../guides/configuration.md#python-library-dependencies) section in the SQLMesh configuration guide for more information. + +## Secret Manager + +Tobiko Cloud provides a secrets manager where you can define environment variables for your project's Python models. + +These variables are most commonly used to provide sensitive information to Python models, such as API keys or other credentials. + +Secret values are encrypted at rest and only available in the environment of your running Python models. + +!!! note "Cloud Scheduler Only" + + Secrets from the secret manager do not load into hybrid executors. They are only used for cloud scheduler executors. + +Secret names have two restrictions - they must: + +- Start with a letter or an underscore +- Only include letters, numbers, and underscores (no spaces or other symbols) + +Secret values have no limits or restrictions. We recommend base64 encoding any secrets that contain binary data. + +### Defining secrets + +Define a secret on the Secrets page, accessible via the Settings section in Tobiko Cloud's left side navigation bar. 
+ +The Secrets page has a single panel you use to create a new secret, edit the value of an existing secret, or remove an existing secret. You cannot view the value of any existing secret. + +In this example, only one secret has been defined: `MY_SECRET`. Update its value by entering a new value in the Secret field and clicking the `Update` button, or delete it by clicking the `Remove` button. + +![secrets_panel](./scheduler/secrets.png) + + +### Python Model Example + +This Python model demonstrates how to read the `MY_SECRET` secret from an environment variable. + +!!! danger "Protecting Secrets" + + Only read environment variables from inside a Python model's `execute` function definition (not in the global scope). + + If the variable is read in the global scope, SQLMesh will load the value from *your local system* when it renders the Python model instead of loading it at runtime on our executors. + + This could expose sensitive information or embed an incorrect local value in the rendered model. + +```python linenums="1" +import os +import pandas as pd +import typing as t +from datetime import datetime + +from sqlmesh import ExecutionContext, model + +# DO NOT read environment variables here. +# Only inside the `execute` function definition! + +@model( + "my_model.name", + columns={ + "column_name": "int", + }, +) +def execute( + context: ExecutionContext, + start: datetime, + end: datetime, + execution_time: datetime, + **kwargs: t.Any, +) -> pd.DataFrame: + + # Read a secret from the MY_SECRET environment variable + my_secret = os.environ["MY_SECRET"] + + ... 
+``` diff --git a/docs/cloud/features/scheduler/scheduler/.env.example b/docs/cloud/features/scheduler/scheduler/.env.example new file mode 100644 index 0000000000..c71357553b --- /dev/null +++ b/docs/cloud/features/scheduler/scheduler/.env.example @@ -0,0 +1,24 @@ +# Tobiko Cloud Configuration +ORGANIZATION=your-organization +PROJECT=your-project +TCLOUD_CLIENT_ID=your-client-id +TCLOUD_CLIENT_SECRET=your-client-secret + +# Database Configuration +DEFAULT_GATEWAY=GATEWAY_A +DB_TYPE=postgres +DB_HOST=your-database-host +DB_PORT=5432 +DB_NAME=your-database-name +DB_USER=your-database-user +DB_PASSWORD=your-database-password + +# Optional: Resource Limits +APPLY_MEMORY_LIMIT=4g +APPLY_CPU_LIMIT=2 +APPLY_MEMORY_REQUEST=2g +APPLY_CPU_REQUEST=1 +PLAN_MEMORY_LIMIT=4g +PLAN_CPU_LIMIT=2 +PLAN_MEMORY_REQUEST=2g +PLAN_CPU_REQUEST=1 \ No newline at end of file diff --git a/docs/cloud/features/scheduler/scheduler/add_connection.png b/docs/cloud/features/scheduler/scheduler/add_connection.png new file mode 100644 index 0000000000..c11930a4d5 Binary files /dev/null and b/docs/cloud/features/scheduler/scheduler/add_connection.png differ diff --git a/docs/cloud/features/scheduler/scheduler/add_connection_form.png b/docs/cloud/features/scheduler/scheduler/add_connection_form.png new file mode 100644 index 0000000000..f2b49d2052 Binary files /dev/null and b/docs/cloud/features/scheduler/scheduler/add_connection_form.png differ diff --git a/docs/cloud/features/scheduler/scheduler/add_connection_success.png b/docs/cloud/features/scheduler/scheduler/add_connection_success.png new file mode 100644 index 0000000000..7bf74c4bb6 Binary files /dev/null and b/docs/cloud/features/scheduler/scheduler/add_connection_success.png differ diff --git a/docs/cloud/features/scheduler/scheduler/add_oath_client.png b/docs/cloud/features/scheduler/scheduler/add_oath_client.png new file mode 100644 index 0000000000..4e255f775a Binary files /dev/null and 
b/docs/cloud/features/scheduler/scheduler/add_oath_client.png differ diff --git a/docs/cloud/features/scheduler/scheduler/create_pause.png b/docs/cloud/features/scheduler/scheduler/create_pause.png new file mode 100644 index 0000000000..08fdc40bd4 Binary files /dev/null and b/docs/cloud/features/scheduler/scheduler/create_pause.png differ diff --git a/docs/cloud/features/scheduler/scheduler/create_pause_form.png b/docs/cloud/features/scheduler/scheduler/create_pause_form.png new file mode 100644 index 0000000000..56c5758144 Binary files /dev/null and b/docs/cloud/features/scheduler/scheduler/create_pause_form.png differ diff --git a/docs/cloud/features/scheduler/scheduler/delete_pause.png b/docs/cloud/features/scheduler/scheduler/delete_pause.png new file mode 100644 index 0000000000..4c7e1be36f Binary files /dev/null and b/docs/cloud/features/scheduler/scheduler/delete_pause.png differ diff --git a/docs/cloud/features/scheduler/scheduler/docker-compose.yml b/docs/cloud/features/scheduler/scheduler/docker-compose.yml new file mode 100644 index 0000000000..0cbf2426eb --- /dev/null +++ b/docs/cloud/features/scheduler/scheduler/docker-compose.yml @@ -0,0 +1,72 @@ +services: + apply-executor: + image: tobikodata/tcloud:latest + platform: linux/amd64 + command: executor apply + restart: unless-stopped + environment: + # Tobiko Cloud connection + TCLOUD_URL: https://internal.cloud.tobikodata.com/sqlmesh/${ORGANIZATION}/${PROJECT} + TCLOUD_CLIENT_ID: ${TCLOUD_CLIENT_ID} + TCLOUD_CLIENT_SECRET: ${TCLOUD_CLIENT_SECRET} + + # SQLMesh configuration + SQLMESH__DEFAULT_GATEWAY: ${DEFAULT_GATEWAY:-GATEWAY_A} + + # Example database configuration (adjust for your database) + # All database parameters below should be customized for your specific setup + SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__TYPE: ${DB_TYPE} + SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__HOST: ${DB_HOST} + SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__PORT: ${DB_PORT} + 
SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__DATABASE: ${DB_NAME} + SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__USER: ${DB_USER} + SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__PASSWORD: ${DB_PASSWORD} + volumes: + # Optional volume for persistent storage if needed + - apply-executor-data:/app/data + deploy: + resources: + limits: + memory: ${APPLY_MEMORY_LIMIT:-4g} + cpus: ${APPLY_CPU_LIMIT:-2} + reservations: + memory: ${APPLY_MEMORY_REQUEST:-2g} + cpus: ${APPLY_CPU_REQUEST:-1} + + run-executor: + image: tobikodata/tcloud:latest + platform: linux/amd64 + command: executor run + restart: unless-stopped + environment: + # Tobiko Cloud connection + TCLOUD_URL: https://internal.cloud.tobikodata.com/sqlmesh/${ORGANIZATION}/${PROJECT} + TCLOUD_CLIENT_ID: ${TCLOUD_CLIENT_ID} + TCLOUD_CLIENT_SECRET: ${TCLOUD_CLIENT_SECRET} + + # SQLMesh configuration + SQLMESH__DEFAULT_GATEWAY: ${DEFAULT_GATEWAY:-GATEWAY_A} + + # Example database configuration (adjust for your database) + # All database parameters below should be customized for your specific setup + SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__TYPE: ${DB_TYPE} + SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__HOST: ${DB_HOST} + SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__PORT: ${DB_PORT} + SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__DATABASE: ${DB_NAME} + SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__USER: ${DB_USER} + SQLMESH__GATEWAYS__GATEWAY_A__CONNECTION__PASSWORD: ${DB_PASSWORD} + volumes: + # Optional volume for persistent storage if needed + - run-executor-data:/app/data + deploy: + resources: + limits: + memory: ${PLAN_MEMORY_LIMIT:-4g} + cpus: ${PLAN_CPU_LIMIT:-2} + reservations: + memory: ${PLAN_MEMORY_REQUEST:-2g} + cpus: ${PLAN_CPU_REQUEST:-1} + +volumes: + apply-executor-data: {} + run-executor-data: {} diff --git a/docs/cloud/features/scheduler/scheduler/executors.png b/docs/cloud/features/scheduler/scheduler/executors.png new file mode 100644 index 0000000000..b3df53da31 Binary files /dev/null and 
b/docs/cloud/features/scheduler/scheduler/executors.png differ diff --git a/docs/cloud/features/scheduler/scheduler/pause_environment.png b/docs/cloud/features/scheduler/scheduler/pause_environment.png new file mode 100644 index 0000000000..1027c879b2 Binary files /dev/null and b/docs/cloud/features/scheduler/scheduler/pause_environment.png differ diff --git a/docs/cloud/features/scheduler/scheduler/resume_environment.png b/docs/cloud/features/scheduler/scheduler/resume_environment.png new file mode 100644 index 0000000000..4d55bf4f18 Binary files /dev/null and b/docs/cloud/features/scheduler/scheduler/resume_environment.png differ diff --git a/docs/cloud/features/scheduler/scheduler/secrets.png b/docs/cloud/features/scheduler/scheduler/secrets.png new file mode 100644 index 0000000000..7873bf8b77 Binary files /dev/null and b/docs/cloud/features/scheduler/scheduler/secrets.png differ diff --git a/docs/cloud/features/scheduler/scheduler/see_all_pauses.png b/docs/cloud/features/scheduler/scheduler/see_all_pauses.png new file mode 100644 index 0000000000..549258de9a Binary files /dev/null and b/docs/cloud/features/scheduler/scheduler/see_all_pauses.png differ diff --git a/docs/cloud/features/scheduler/scheduler/settings_tab.png b/docs/cloud/features/scheduler/scheduler/settings_tab.png new file mode 100644 index 0000000000..6405f22830 Binary files /dev/null and b/docs/cloud/features/scheduler/scheduler/settings_tab.png differ diff --git a/docs/cloud/features/security/security.md b/docs/cloud/features/security/security.md new file mode 100644 index 0000000000..59b2149432 --- /dev/null +++ b/docs/cloud/features/security/security.md @@ -0,0 +1,79 @@ +# Security Overview + + +At Tobiko, we treat security as a first-class citizen because we know how valuable your data assets are. Our team follows and executes security best practices across each layer of our product. 
+ +## Tobiko Cloud Standard Deployment + +Our standard Tobiko Cloud deployment consists of several components that are each responsible for different parts of the product. + +Below is a diagram of the components along with their descriptions. + +![tobiko_cloud_standard_deployment](./security/tcloud_standard_deployment.png){ width=80% height=60% style="display: block; margin: 0 auto" } + +- **Scheduler**: Orchestrates schedule cadence and hosts state metadata (code versions, logs, cost) +- **Executor**: Applies code changes and runs SQL queries (actual data processing in SQL Engine) and Python models in proper DAG order. +- **Gateway**: Stores authentication credentials for SQL Engine. Secured through encryption. +- **SQL Engine**: Processes and stores data based on the above instructions within the **customer’s** environment. + +## Tobiko Cloud Hybrid Deployment + +For some customers, our hybrid deployment option is a great fit. It provides a seamless experience with Tobiko Cloud but within your own VPC and infrastructure. + +In a hybrid deployment, Tobiko Cloud does not execute tasks directly with the engine. Instead, it passes tasks to the executors hosted in your environment, which then execute the tasks with the engine. + +Executors are Docker containers that connect to both Tobiko Cloud and your SQL engine. They pull work tasks from the Tobiko Cloud scheduler and execute them with your SQL engine. This is a pull-only mechanism authenticated through an OAuth Client ID/Secret. Whitelist IPs in your network to allow reaching Tobiko Cloud IPs from the executor: 34.28.17.91, 34.136.27.153, 34.136.131.201. + +Below is a diagram of the components along with their descriptions. + +![tobiko_cloud_hybrid_deployment](./security/tcloud_hybrid_deployment.png){ width=80% height=60% style="display: block; margin: 0 auto" } + +- **Scheduler**: Orchestrates schedule cadence and hosts state metadata (code versions, logs, cost). **Never pushes** instructions to executor. 
+- **Executor**: Applies code changes and runs SQL queries and Python models in proper DAG order (actual data processing in SQL Engine) +- **Gateway**: Stores authentication credentials for SQL Engine. Secured through your secrets manager or Kubernetes Secrets. +- **SQL Engine**: Processes and stores data based on the above instructions +- **Executor -> Scheduler**: A pull-only mechanism for obtaining work tasks. +- **Helm Chart**: For production environments, we provide a [Helm chart](../scheduler/hybrid_executors_helm.md) that includes robust configurability, secret management, and scaling options. +- **Docker Compose**: For simpler environments or testing, we offer a [Docker Compose setup](../scheduler/hybrid_executors_docker_compose.md) to quickly deploy executors on any machine with Docker. + + + +## Internal Code Practices + +We enforce coding standards throughout Tobiko to write, maintain, and collaborate on code effectively. These practices ensure consistency, maintainability, reliability, and most importantly, trust. + +A few key components of our internal code requirements: + +- We use signed Git commits, required approvers, and signed Docker artifacts. +- Each commit to a `main` branch must be approved by someone other than the author. +- We sign commits and register the key with GitHub ([Github Docs](https://docs.github.com/en/authentication/managing-commit-signature-verification/signing-commits)). +- Binaries are signed using cosign and OIDC for keyless ([Signing docs](https://docs.sigstore.dev/cosign/signing/overview/)). +- Attestations are created to certify an image, enforced with GCP Binary Authorization ([Attestation docs](https://cloud.google.com/binary-authorization/docs/key-concepts#attestations)). +- Encryption is a key feature of our security posture and is enforced at each stage of access. For example, the state database automatically encrypts all data. Credentials are also securely encrypted and stored. 
+- We back up each state database nightly and before upgrades. These backups are stored for 14 days. + +## Penetration Testing + +At least once a year, Tobiko engages a third-party security firm to perform a penetration test. This test evaluates our systems by identifying and attempting to exploit known vulnerabilities, focusing on critical external and/or internal assets. A detailed report is available upon request. + + +## Asset and Access Management + +### How do we protect PGP keys? + +If an employee loses their laptop, we don't need to get the old PGP key back because we can invalidate the key directly. + +We use GitHub to sign code commits. At the time the code was committed, the PGP key was valid. When an employee loses their laptop, we will invalidate it, and they will regenerate a new key to use in future commits. The old commits are still valid because the PGP key was valid at the time the commit was made. + +### How do we invalidate PGP keys if someone did steal it and could potentially use it? + +We would revoke access for the GitHub user account associated with the compromised key and not give it access again until the old PGP key is deprecated and a new key issued. + +### If someone steals a laptop, what's our continuity plan in protecting code? + +- All employee devices are monitored for proper encryption and password policies. +- Laptop protection is enforced through file encryption via Vanta. +- Mandatory lock screen after a timeout. +- We follow a formal IT asset disposal procedure to prevent key compromise through improper hardware disposal. +- See above for PGP key protection. +- Binaries are signed using Cosign and OIDC for keyless signing. 
diff --git a/docs/cloud/features/security/security/tcloud_hybrid_deployment.png b/docs/cloud/features/security/security/tcloud_hybrid_deployment.png new file mode 100644 index 0000000000..6573342f60 Binary files /dev/null and b/docs/cloud/features/security/security/tcloud_hybrid_deployment.png differ diff --git a/docs/cloud/features/security/security/tcloud_standard_deployment.png b/docs/cloud/features/security/security/tcloud_standard_deployment.png new file mode 100644 index 0000000000..5b79a3ceba Binary files /dev/null and b/docs/cloud/features/security/security/tcloud_standard_deployment.png differ diff --git a/docs/cloud/features/security/single_sign_on.md b/docs/cloud/features/security/single_sign_on.md new file mode 100644 index 0000000000..716fb10589 --- /dev/null +++ b/docs/cloud/features/security/single_sign_on.md @@ -0,0 +1,220 @@ +# SSO (Single Sign-On) + +## Overview + +Tobiko Cloud supports single sign-on (SSO) through OpenID and SAML 2.0 providers. + +This makes it easy to provision access to users and simplifies authentication. + + +## Setup & Prerequsites + +You must have an active Tobiko Cloud instance with SSO enabled. Please contact your account team to ensure this is enabled. + +If your Tobiko Cloud instance is setup to require SSO, then you won't need to provide a token in your `tcloud.yml` configuration. + +Below is an example of a `tcloud.yml` configuration: +```yaml +projects: + : + url: + token: # you won't need this anymore + gateway: + extras: + pip_executable: +default_project: +``` + +## Identity Providers + +Tobiko Cloud currently supports OpenID and SAML 2.0. + +### OpenID + +This provider implements [OpenID Connect Core +1.0](https://openid.net/specs/openid-connect-core-1_0.html) in order to allow us +to login with most OAuth2 login providers. + +There are two ways to use OpenID Providers. The first is a +if you use a shared provider like Google, Github, +Microsoft, etc. 
+ +#### Google OAuth + +To enable Google OAuth, all we need is your domain (ex: `yourname@companyname.com`, `companyname.com` is the domain). From here, we can switch SSO on with Google OAuth. + +The login flow will look like the following if you access [cloud.tobikodata.com/auth/login](https://cloud.tobikodata.com/auth/login) directly from your browser. If authenticating through CLI see [here](../security/single_sign_on.md#status) for more details. + +
+ +#### Other OAuth Providers + +If you use Okta and other custom OpenID/OAuth2 providers you need to add us +as an Application or Client (terms differ across providers). + +You will need the following information to do this: + +| Name | Purpose | Value | +|--------------|--------------------------------------------------------------------------------------------------------------------------------------|--------------------------------------------------------| +| Redirect URI | Where the OAuth provider should redirect users to after a successfull login. Can also be called "Callback URL" or something similar. | `https://cloud.tobikodata.com/auth/handler/` | +| Logout URL | Where users can go to log out of our system | `https://cloud.tobikodata.com/auth/logout` | +| Web Origin | Which host names our OAuth service uses | `https://cloud.tobikodata.com` | + +Often only a Redirect URI is required, but some providers like the additional +information as well. + +We will need the following information from you once you set us up: + +| Name | Purpose | Example | +|---------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------| +| Client ID | The random ID we use to communicate with their OAuth service | `` | +| Client Secret | The random secret we use to authentication with their OAuth service | `` | +| Open ID Configuration URL | This is the URL we use to gather the rest of their OpenID Configuration. We can often find this on our own and don't need to request it from them, check with the onboarding engineer to make sure we know this. | + +Once we have the above information, we can enable SSO on your account. You will then follow the login flow through your provider such as logging in through Okta. 
+ +### SAML V2.0 + +This provider uses [python3-saml](https://github.com/SAML-Toolkits/python3-saml) +to support SAML V2.0 authentication. + +#### Requirements + +If you are using a SAML provider we need to receive three pieces of +information from you below: + +| Name | Purpose | Example | +|-------------|----------------------------------------------------|-------------------------------------------------------------------------------------------------------------------| +| Entity ID | This is the providers Entity ID | `https://saml.example.com/entityid` | +| SSO URL | This is the URL to use for SSO | `https://mocksaml.com/api/saml/sso` | +| Certificate | The certificate of the SAML Provider in PEM format | [PEM Certificates](https://www.ssl.com/guide/pem-der-crt-and-cer-x-509-encodings-and-conversions/#ftoc-heading-1) | + +We will provide a similar set of information below: + +| Name | Purpose | Example | +|-------------|---------------------------------------|--------------------------------------------------------------| +| Metadata URL| TThis URL contains all of this information | `https://cloud.tobikodata.com/auth/saml/metadata/` | +| Entity ID | This is our Entity ID | `https://cloud.tobikodata.com` | +| SSO URL | This is our HTTP-Redirect Binding URL | `https://cloud.tobikodata.com/auth/saml/callback/` | +| Certificate | Our SAML Certificate | **TBD** | + +All data except the certificate will change per provider. For example if we had +a provider named `acme`: + +- **Metadata URL**: `https://cloud.tobikodata.com/auth/saml/metadata/acme` +- **Entity ID**: `https://cloud.tobikodata.com/auth/saml/metadata/acme` +- **SSO URL**: `https://cloud.tobikodata.com/auth/saml/callback/acme` + +### Okta Integration + +The following instructions will walk you through configuring Okta as your identity provider. +Log into your Okta account. Navigate to Application and create a new app. 
You will want to select SAML 2.0 + +![okta_setup_1](./single_sign_on/okta_setup_1.png) + +Next, name your app "Tobiko Cloud". You can add the app logo by downloading the image [here](https://avatars.githubusercontent.com/u/113925670?s=200&v=4). + +![okta_setup_2](./single_sign_on/okta_setup_2.png) + +#### SAML Configurations and Settings + +1. We now need to fill in the SAML Settings. Please enter the following values: + + + - **Single sign-on URL**: `https://cloud.tobikodata.com/auth/saml/callback/acme` + - **Audience URI (SP Entity ID)**: `https://cloud.tobikodata.com/auth/saml/metadata/acme` + + ![okta_setup_3](./single_sign_on/okta_setup_3.png) + +2. Fill in the Attribute Statements section with email, firstName, and lastName: These are required to properly map to your users. + + ![okta_setup_4](./single_sign_on/okta_setup_4.png) + +3. Click next and now you are on the last step. Check off the box `Contact app vendor` and hit `Finish`. Now you're all set! + + ![okta_setup_5](./single_sign_on/okta_setup_5.png) + +Here is what you will see if you are accessing Tobiko Cloud via Okta. Click on the Tobiko Cloud icon to be redirected to the application. + +![sso_okta](./single_sign_on/sso_okta.png) + +## Authentication Workflow + +### Status + +You can see what the status of your session is with the `status` command: + +``` bash +$ tcloud auth status +``` + + +![tcloud_auth](./single_sign_on/tcloud_auth.png) + +### Login + +Run the `login` command to begin the login process: + +``` bash +$ tcloud auth login +``` + +![tcloud_login](./single_sign_on/tcloud_login.png) + +At this point your system browser should open and allow you to log in. If you are already logged in, this should be a very quick process. 
It will look like the below: + +![tcloud_auth_browser_login](./single_sign_on/tcloud_auth_browser_login.png) + +![tcloud_auth_browser_success](./single_sign_on/tcloud_auth_browser_success.png) + + + +After you have authenticated, you will be prompted with a success message in your browser and a message telling you that it's safe to close your browser window. Your terminal will then have the following result: + +``` bash +Success! ✅ + +Current Tobiko Cloud SSO session expires in 1439 minutes +``` + + +### Logging Out + +In order to delete your session information you can use the log out command: + +``` bash +> tcloud auth logout +Logged out of Tobiko Cloud + +> tcloud auth status +Not currently authenticated +``` + +![tcloud_logout](./single_sign_on/tcloud_logout.png) + +Otherwise, you will be logged out automatically when the SSO session expires (every 24 hours). + +## OAuth Clients + +Sometimes, you want to grant an external service access to your Tobiko Cloud project. For example, the external service could be the [CICD bot](../../../integrations/github.md) or a [scheduler integration](../scheduler/airflow.md). + +These services take `Client ID` and `Client Secret` credentials. + +!!! Info "One set of credentials per service" + It's best practice to provision a separate set of credentials for each service that you wish to connect to Tobiko Cloud. This gives you the flexibility to revoke credentials for a specific service without affecting access for other services. + +### Provisioning client credentials + +To provision OAuth credentials for a new service, browse to `Settings -> OAuth Clients` in the lefthand navigation menu. 
+ +In the page's Create new Client section, enter a client name and human readable description: + +![Add new OAuth Client](./single_sign_on/oauth_client_1.png) + +Once you click `Save`, the client will be added to the list: + +![OAuth Client List](./single_sign_on/oauth_client_2.png) + +To fetch the Client ID or Client Secret, click `Copy ID` or `Copy Secret`. The values will be copied to the system clipboard. + +Paste these values into an external service's authentication configuration so it can connect to your Tobiko Cloud project. \ No newline at end of file diff --git a/docs/cloud/features/security/single_sign_on/oauth_client_1.png b/docs/cloud/features/security/single_sign_on/oauth_client_1.png new file mode 100644 index 0000000000..81e01c230b Binary files /dev/null and b/docs/cloud/features/security/single_sign_on/oauth_client_1.png differ diff --git a/docs/cloud/features/security/single_sign_on/oauth_client_2.png b/docs/cloud/features/security/single_sign_on/oauth_client_2.png new file mode 100644 index 0000000000..53b93580d0 Binary files /dev/null and b/docs/cloud/features/security/single_sign_on/oauth_client_2.png differ diff --git a/docs/cloud/features/security/single_sign_on/okta_setup_1.png b/docs/cloud/features/security/single_sign_on/okta_setup_1.png new file mode 100644 index 0000000000..79f8a18229 Binary files /dev/null and b/docs/cloud/features/security/single_sign_on/okta_setup_1.png differ diff --git a/docs/cloud/features/security/single_sign_on/okta_setup_2.png b/docs/cloud/features/security/single_sign_on/okta_setup_2.png new file mode 100644 index 0000000000..fe7df25e66 Binary files /dev/null and b/docs/cloud/features/security/single_sign_on/okta_setup_2.png differ diff --git a/docs/cloud/features/security/single_sign_on/okta_setup_3.png b/docs/cloud/features/security/single_sign_on/okta_setup_3.png new file mode 100644 index 0000000000..583faf50a4 Binary files /dev/null and b/docs/cloud/features/security/single_sign_on/okta_setup_3.png differ 
diff --git a/docs/cloud/features/security/single_sign_on/okta_setup_4.png b/docs/cloud/features/security/single_sign_on/okta_setup_4.png new file mode 100644 index 0000000000..e11e4111a2 Binary files /dev/null and b/docs/cloud/features/security/single_sign_on/okta_setup_4.png differ diff --git a/docs/cloud/features/security/single_sign_on/okta_setup_5.png b/docs/cloud/features/security/single_sign_on/okta_setup_5.png new file mode 100644 index 0000000000..f4d2a32c27 Binary files /dev/null and b/docs/cloud/features/security/single_sign_on/okta_setup_5.png differ diff --git a/docs/cloud/features/security/single_sign_on/sso_okta.png b/docs/cloud/features/security/single_sign_on/sso_okta.png new file mode 100644 index 0000000000..7656a91584 Binary files /dev/null and b/docs/cloud/features/security/single_sign_on/sso_okta.png differ diff --git a/docs/cloud/features/security/single_sign_on/tcloud_auth.png b/docs/cloud/features/security/single_sign_on/tcloud_auth.png new file mode 100644 index 0000000000..18a3d75a78 Binary files /dev/null and b/docs/cloud/features/security/single_sign_on/tcloud_auth.png differ diff --git a/docs/cloud/features/security/single_sign_on/tcloud_auth_browser_login.png b/docs/cloud/features/security/single_sign_on/tcloud_auth_browser_login.png new file mode 100644 index 0000000000..0d9f483cf9 Binary files /dev/null and b/docs/cloud/features/security/single_sign_on/tcloud_auth_browser_login.png differ diff --git a/docs/cloud/features/security/single_sign_on/tcloud_auth_browser_success.png b/docs/cloud/features/security/single_sign_on/tcloud_auth_browser_success.png new file mode 100644 index 0000000000..429e178815 Binary files /dev/null and b/docs/cloud/features/security/single_sign_on/tcloud_auth_browser_success.png differ diff --git a/docs/cloud/features/security/single_sign_on/tcloud_login.png b/docs/cloud/features/security/single_sign_on/tcloud_login.png new file mode 100644 index 0000000000..285328b8cf Binary files /dev/null and 
b/docs/cloud/features/security/single_sign_on/tcloud_login.png differ diff --git a/docs/cloud/features/security/single_sign_on/tcloud_logout.png b/docs/cloud/features/security/single_sign_on/tcloud_logout.png new file mode 100644 index 0000000000..fb581e7973 Binary files /dev/null and b/docs/cloud/features/security/single_sign_on/tcloud_logout.png differ diff --git a/docs/cloud/features/upgrade/upgrade-ui-available.png b/docs/cloud/features/upgrade/upgrade-ui-available.png new file mode 100644 index 0000000000..a9ea90fa08 Binary files /dev/null and b/docs/cloud/features/upgrade/upgrade-ui-available.png differ diff --git a/docs/cloud/features/upgrade/upgrade-ui-custom-version.png b/docs/cloud/features/upgrade/upgrade-ui-custom-version.png new file mode 100644 index 0000000000..810861a741 Binary files /dev/null and b/docs/cloud/features/upgrade/upgrade-ui-custom-version.png differ diff --git a/docs/cloud/features/upgrade/upgrade-ui-latest.png b/docs/cloud/features/upgrade/upgrade-ui-latest.png new file mode 100644 index 0000000000..c6e34cb534 Binary files /dev/null and b/docs/cloud/features/upgrade/upgrade-ui-latest.png differ diff --git a/docs/cloud/features/upgrade/upgrade-ui-progress.png b/docs/cloud/features/upgrade/upgrade-ui-progress.png new file mode 100644 index 0000000000..78282628ec Binary files /dev/null and b/docs/cloud/features/upgrade/upgrade-ui-progress.png differ diff --git a/docs/cloud/features/upgrade/upgrade-ui-up-to-date.png b/docs/cloud/features/upgrade/upgrade-ui-up-to-date.png new file mode 100644 index 0000000000..e044ad1d8d Binary files /dev/null and b/docs/cloud/features/upgrade/upgrade-ui-up-to-date.png differ diff --git a/docs/cloud/features/upgrades.md b/docs/cloud/features/upgrades.md new file mode 100644 index 0000000000..c6ee00d713 --- /dev/null +++ b/docs/cloud/features/upgrades.md @@ -0,0 +1,75 @@ +# Upgrading Tobiko Cloud + +Tobiko regularly releases new versions of Tobiko Cloud that add features and improve reliability. 
+ +This page describes how to upgrade your Tobiko Cloud projects to a newer version. + +## Upgrade availability + +Navigate to `Settings > Upgrade` in the Tobiko Cloud UI to determine whether a new version of Tobiko Cloud is available for your project. + +If your project is already up to date, you will see a grey message: + +![Tobiko Cloud Upgrade Already Up-to-Date](./upgrade/upgrade-ui-up-to-date.png) + +If a new version is available for your project, the page will include a notification box, version, and blue Upgrade Now button: + +![Tobiko Cloud Upgrade Page](./upgrade/upgrade-ui-available.png) + +## Upgrading a project + +On the Upgrade page, you can choose to upgrade to the latest version or specify a custom version. + +!!! info "Upgrade Permissions" + Only users with Tobiko Cloud `Admin` permissions can perform upgrades. + +!!! danger "Upgrade Side Effects" + The upgrade process may take a few minutes to complete. During this time, your Tobiko Cloud project will be unavailable. + + Any in-progress plans and runs will be aborted: + + - Aborted plans will be stopped, and you must **manually** start them again. + - Aborted runs will be automatically resumed shortly after the upgrade completes. + + To avoid unexpected interruptions, please notify your team before starting the upgrade. + +### Latest Version + +Click the **Upgrade Now** button and confirm to begin upgrading your project to the latest version. + +![Tobiko Cloud Upgrade Page](./upgrade/upgrade-ui-latest.png) + +### Custom Version + +We recommend upgrading your Tobiko Cloud project to the latest version, but you may prefer to upgrade to a specific version. + +For example, consider a team that has separate staging and production Tobiko Cloud projects. They upgrade the staging project first, run tests, and only upgrade the production project after verifying that staging works as expected. 
+ +If a new version of Tobiko Cloud is released during this testing period, the latest available version will not match the version tested in staging. The team can specify a custom Tobiko Cloud version to upgrade production to the specific version that was already tested in staging. + +To specify a custom version, select the **Custom** tab on the Upgrade page and enter your desired version in the text box. + +![Tobiko Cloud Upgrade Custom Version](./upgrade/upgrade-ui-custom-version.png) + +Make sure you are entering a valid custom version by: + + - Entering the custom version **without** the leading `v` + - Confirming that the version is valid and later than the current version of the project + +If your custom version is not valid, Tobiko Cloud will display an error message. + +After entering the custom version, click the **Upgrade Now** button and confirm to begin the upgrade process. + +## Upgrade Progress + +Tobiko Cloud will display a progress page while the upgrade is in progress: + +![Tobiko Cloud Upgrade Progress](./upgrade/upgrade-ui-progress.png) + +Once the upgrade is complete, Tobiko Cloud will automatically redirect you back to your upgraded project. + +## Upgrade Support + +If you encounter an issue during the upgrade process, please [report an incident](./incident_reporting.md). Our support team will follow up as soon as possible. + +For the quickest response, we recommend upgrading Monday through Friday between 9am and 5pm PST. \ No newline at end of file diff --git a/docs/cloud/features/xdb_diffing.md b/docs/cloud/features/xdb_diffing.md new file mode 100644 index 0000000000..fbdeb52ca5 --- /dev/null +++ b/docs/cloud/features/xdb_diffing.md @@ -0,0 +1,144 @@ +# Cross-database Table Diffing + +Tobiko Cloud extends SQLMesh's [within-database table diff tool](../../guides/tablediff.md) to support comparison of tables or views across different database systems. 
+ +It provides a method of validating models that can be used along with [evaluating a model](../../guides/models.md#evaluating-a-model) and [testing a model with unit tests](../../guides/testing.md#testing-changes-to-models). + +!!! tip "Learn more about table diffing" + + Learn more about using the table diff tool in the SQLMesh [table diff guide](../../guides/tablediff.md). + +## Diffing tables or views across gateways + +SQLMesh executes a project's models with a single database system, specified as a [gateway](../../guides/connections.md) in the project configuration. + +The within-database table diff tool described above compares tables or environments within such a system. Sometimes, however, you might want to compare tables that reside in two different data systems. + +For example, you might migrate your data transformations from an on-premises SQL engine to a cloud SQL engine while setting up your SQLMesh project. To demonstrate equivalence between the systems you could run the transformations in both and compare the new tables to the old tables. + +The [within-database table diff](../../guides/tablediff.md) tool cannot make those comparisons, for two reasons: + +1. It must join the two tables being diffed, but with two systems no single database engine can access both tables. +2. It assumes that data values can be compared across tables without modification. However, the diff must account for differences in data types across the two SQL engines (e.g., whether timestamps should include time zone information). + +SQLMesh's cross-database table diff tool is built for just this scenario. Its comparison algorithm efficiently diffs tables without moving them from one system to the other and automatically addresses differences in data types. + +## Configuration and syntax + +To diff tables across systems, first configure a [gateway](../../reference/configuration.md#gateway) for each database system in your SQLMesh configuration file. 
+ +This example configures `bigquery` and `snowflake` gateways: + +```yaml linenums="1" +gateways: + bigquery: + connection: + type: bigquery + [other connection parameters] + + snowflake: + connection: + type: snowflake + [other connection parameters] +``` + +Then, specify each table's gateway in the `table_diff` command with this syntax: `[source_gateway]|[source table]:[target_gateway]|[target table]`. + +For example, we could diff the `landing.table` table across `bigquery` and `snowflake` gateways like this: + +```sh +$ tcloud sqlmesh table_diff 'bigquery|landing.table:snowflake|landing.table' +``` + +This syntax tells SQLMesh to use the cross-database diffing algorithm instead of the normal within-database diffing algorithm. + +After adding gateways to the table names, use `table_diff` as described in the [SQLMesh table diff guide](../../guides/tablediff.md) - the same options apply for specifying the join keys, decimal precision, etc. See `tcloud sqlmesh table_diff --help` for a [full list of options](../../reference/cli.md#table_diff). + +!!! warning + + Cross-database diff works for data objects (tables / views). + + Diffing _models_ is not supported because we do not assume that both the source and target databases are managed by SQLMesh. + +## Example output + +A cross-database diff is broken up into two stages. + +The first stage is a schema diff. This example shows that differences in column name case across the two tables are identified as schema differences: + +```bash +$ tcloud sqlmesh table_diff 'bigquery|sqlmesh_example.full_model:snowflake|sqlmesh_example.full_model' --on item_id --show-sample + +Schema Diff Between 'BIGQUERY|SQLMESH_EXAMPLE.FULL_MODEL' and 'SNOWFLAKE|SQLMESH_EXAMPLE.FULL_MODEL': +├── Added Columns: +│ ├── ITEM_ID (DECIMAL(38, 0)) +│ └── NUM_ORDERS (DECIMAL(38, 0)) +└── Removed Columns: + ├── item_id (BIGINT) + └── num_orders (BIGINT) +Schema has differences; continue comparing rows? 
[y/n]: +``` + +SQLMesh prompts you before comparing data values across table rows. The prompt provides an opportunity to discontinue the comparison if the schemas are vastly different (potentially indicating a mistake) or you need to exclude columns from the diff because you know they won't match. + +The second stage of the diff is comparing data values across tables. Within each system, SQLMesh divides the data into chunks, evaluates each chunk, and compares the outputs across systems. If a difference is found, it performs a row-level diff on that chunk by reading a sample of mismatched rows from each system. + +This example shows that 2 rows were present in each system but had different values, one row was in Bigquery only, and one row was in Snowflake only: + +```bash +Dividing source dataset into 10 chunks (based on 10947709 total records) +Checking chunks against target dataset +Chunk 1 hash mismatch! +Starting row-level comparison for the range (1 -> 3) +Identifying individual record hashes that don't match +Comparing + +Row Counts: +├── PARTIAL MATCH: 2 rows (66.67%) +├── BIGQUERY ONLY: 1 rows (16.67%) +└── SNOWFLAKE ONLY: 1 rows (16.67%) + +COMMON ROWS column comparison stats: + pct_match +num_orders 0.0 + + +COMMON ROWS sample data differences: +Column: num_orders +┏━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━┓ +┃ item_id ┃ BIGQUERY ┃ SNOWFLAKE ┃ +┡━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━┩ +│ 1 │ 5 │ 7 │ +│ 2 │ 1 │ 2 │ +└─────────┴──────────┴───────────┘ + +BIGQUERY ONLY sample rows: +item_id num_orders + 7 4 + + +SNOWFLAKE ONLY sample rows: +item_id num_orders + 4 6 +``` + +If there are no differences found between chunks, the source and target datasets can be considered equal: + +```bash +Chunk 1 (1094771 rows) matches! +Chunk 2 (1094771 rows) matches! +... +Chunk 10 (1094770 rows) matches! + +All 10947709 records match between 'bigquery|sqlmesh_example.full_model' and 'snowflake|TEST.SQLMESH_EXAMPLE.FULL_MODEL' +``` + +!!! 
info + + Don't forget to specify the `--show-sample` option if you'd like to see a sample of the actual mismatched data! + + Otherwise, only high level statistics for the mismatched rows will be printed. + +### Supported engines + +Cross-database diffing is supported on all execution engines that [SQLMesh supports](../../integrations/overview.md#execution-engines). \ No newline at end of file diff --git a/docs/cloud/tcloud_getting_started.md b/docs/cloud/tcloud_getting_started.md new file mode 100644 index 0000000000..00ad8a3c25 --- /dev/null +++ b/docs/cloud/tcloud_getting_started.md @@ -0,0 +1,306 @@ +# Tobiko Cloud: Getting Started + +Tobiko Cloud is a data platform that extends SQLMesh to make it easy to manage data at scale without the waste. + +We're here to make it easy to get started and feel confident that everything is working as expected. After you've completed the steps below, you'll have achieved the following: + +- Log in to Tobiko Cloud via the browser +- Connect Tobiko Cloud to your local machine via the CLI +- Connect Tobiko Cloud to your data warehouse +- Verify that Tobiko Cloud interacts with your data warehouse as expected + +## Prerequisites + +Before you start, the Tobiko team must complete a few steps. 
+ +Your Tobiko Solutions Architect will: + +- Set up a 1 hour meeting with you to fully onboard +- Request that a new Tobiko Cloud account be created for you (single tenant by default) +- Share a temporary password link that expires in 7 days +- Make sure you save the password in your own password manager + +To prepare for the meeting, ensure you or another attendee have data warehouse administrator rights to: + +- Update warehouse user and object permissions +- Create new users and grant them create/update/delete permissions on a specific database (ex: `database.schema.table`) + +For migrations from SQLMesh (open source) to Tobiko Cloud only: + +- Your Tobiko Solutions Architect will send you a script to extract your current state +- You send that state to the Tobiko Cloud engineers to validate before the migration occurs +- After validation, Tobiko Solutions Architect will schedule a migration date and meeting to move your state to Tobiko Cloud. There will be some downtime if you are running SQLMesh in a production environment. + +> Note: if you must be on VPN to access your data warehouse or have specific security requirements, please let us know and we can discuss options to ensure Tobiko Cloud can securely connect. + +Technical Requirements: + +- Tobiko Cloud requires Python version between 3.9 and 3.12 + +!!! note + If you don't have a supported Python version installed, you can use [uv](https://docs.astral.sh/uv/getting-started/installation/#installation-methods) to install it. + At the time of writing, these are the suggested commands to install uv and Python: + + === "macOS and Linux" + + ```bash + curl -LsSf https://astral.sh/uv/install.sh | sh + ``` + + === "Windows" + + ```powershell + powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex" + ``` + + ```bash + uv python install 3.12 + ``` + + +## Log in to Tobiko Cloud + +The first step to setting up Tobiko Cloud is logging in to the web interface: + +1. 
We will authenticate into your Tobiko Cloud instance. If it is your first time going through this flow, your Solutions Architect will guide you on [how to get SSO configured](https://sqlmesh.readthedocs.io/en/stable/cloud/features/single_sign_on/). Open the url below. + ```bash + https://cloud.tobikodata.com/auth/login + ``` +2. Once logged in, you should see the home page. If you are not redirected, then input your Tobiko Cloud URL in the browser (ex: +https://cloud.tobikodata.com/sqlmesh/tobiko/public-demo/observer/) + + Your view should be empty, but the figure below shows a populated example with Tobiko Cloud running in production: + +

+![tcloud home page](./tcloud_getting_started/tcloud_home_page.png) + +## Install the `tcloud` CLI + +Now we need to configure the `tcloud` command line interface tool. + +First, open a terminal within your terminal/IDE (ex: VSCode). Then follow the following steps to install the `tcloud` CLI: + +1. Create a new project directory, or an existing SQLMesh project, and navigate into it: + + ```bash + mkdir tcloud_project + cd tcloud_project + ``` + +2. Create a new file called `requirements.txt` and add `tcloud` to it: + + ```bash + echo 'tcloud' > requirements.txt + ``` + + > Pypi source: [tcloud](https://pypi.org/project/tcloud/) + + > Note: your Tobiko Solutions Architect will provide you a pinned version of `tcloud` + +3. Create a Python virtual environment in the project directory and install `tcloud`. The following demonstrates how to do this using [uv](https://docs.astral.sh/uv/pip/environments/#creating-a-virtual-environment) ([installation instructions](#prerequisites)): + + ```bash linenums="1" + uv venv --python 3.12 --seed # create a virtual environment inside the project directory + source .venv/bin/activate # activate the virtual environment + uv pip install -r requirements.txt # install the tcloud CLI + which tcloud # verify the tcloud CLI is installed in the venv in the path above + ``` + +!!! note + You may need to run `python3` or `pip3` instead of `python` or `pip`, depending on your python installation. + + If you do not see `tcloud` in the virtual environment path above, you may need to reactivate the venv: + + ```bash + source .venv/bin/activate + which tcloud + # expected path: /Users/person/Desktop/git_repos/tobiko-cloud-demo/.venv/bin/tcloud + ``` + +- Create an alias to ensure use of `tcloud`: + + We recommend using a command line alias to ensure all `sqlmesh` commands run on Tobiko Cloud. + + Set the alias in the terminal by running `alias sqlmesh='tcloud sqlmesh'` in every session. 
+ + Or add this to your shell profile file (ex: `~/.zshrc` or `~/.bashrc`) so you don't have to run the command every time: + + ```bash + alias sqlmesh='tcloud sqlmesh' + ``` + + Note: the rest of the commands in this document will NOT use the alias to avoid confusion with the open source SQLMesh CLI. + +## Connect Tobiko Cloud to Data Warehouse + +Now we're ready to connect your data warehouse to Tobiko Cloud: + +1. Create a new file called `tcloud.yaml` and add the project configuration below, substituting the appropriate values for your project: + + ```yaml + projects: + public-demo: # TODO: update this for the project name in the URL + url: https://cloud.tobikodata.com/sqlmesh/tobiko/public-demo/ # TODO: update for your unique URL + gateway: tobiko_cloud + extras: bigquery,web,github # TODO: update bigquery for your data warehouse + pip_executable: uv pip + default_project: public-demo # TODO: update this for the project name in the URL + ``` + +2. If you are going through the SSO flow then, run the following command: + ``` bash + tcloud auth login + ``` + This will fire off the SSO flow and open a link in your browser to authenticate. + + Once authenticated, you will see the following screen. + + ![tcloud_auth_success](./tcloud_getting_started/tcloud_auth_success.png) + +3. Initialize a new SQLMesh project: + + ```bash + tcloud sqlmesh init + ``` + +4. Update your project's `config.yaml` with your data warehouse connection information: + + Your new SQLMesh project will contain a configuration file named `config.yaml` that includes a DuckDB connection. + + Replace the DuckDB connection information with your data warehouse's information. + + This example shows a Bigquery warehouse connection; see more examples [here](../integrations/overview.md). 
+ + ```yaml linenums="1" + gateways: + tobiko_cloud: # this will use the config in tcloud.yaml for state_connection + scheduler: # TODO: add the connection in the Tobiko Cloud Connections Page with the credentials for your data warehouse + type: cloud + + default_gateway: tobiko_cloud + + model_defaults: + dialect: bigquery # TODO: update for your data warehouse + start: 2024-08-19 # TODO: I recommend updating this to an earlier date representing the historical data you want to backfill + + # make Tobiko Cloud only allow deploying to dev environments, use env var to override in CI/CD + # allow_prod_deploy: {{ env_var('ALLOW_PROD_DEPLOY', 'false') }} + + # enables synchronized deployments to prod when a pull request gets a `/deploy` command or is approved by a required approver + cicd_bot: + type: github + merge_method: squash + skip_pr_backfill: false + enable_deploy_command: true + auto_categorize_changes: + external: full + python: full + sql: full + seed: full + + # preview data for forward only models + plan: + enable_preview: true + + # list of users that are allowed to approve PRs for synchronized deployments + users: + - username: sung_tcloud_demo + github_username: sungchun12 + roles: + - required_approver + ``` + +5. Create a `tcloud` user in the warehouse + + During your onboarding call, we will walk through instructions live to create a new `tcloud` data warehouse user with the necessary permissions. + + SQLMesh will run as this user to create, update, and delete tables in your data warehouse. You can scope the user permissions to a specific database if needed. + + Find additional data warehouse specific instructions here: [Data Warehouse Integrations](../integrations/overview.md). + + +6. Verify the connection between Tobiko Cloud and data warehouse: + + Now we're ready to verify that the connection between Tobiko Cloud and the data warehouse is working properly. 
+ + Run the `info` command from your terminal: + + ```bash + tcloud sqlmesh info + ``` + + It will return output similar to this: + + ```bash + (.venv) ➜ tcloud_project git:(main) ✗ tcloud sqlmesh info + Models: 3 + Macros: 0 + Data warehouse connection succeeded + State backend connection succeeded + ``` + +## Verify SQLMesh functionality + +Let's run a `plan` to verify that SQLMesh is working correctly. + +Run `tcloud sqlmesh plan` in your terminal and enter `y` at the prompt to apply the changes. + +```bash +tcloud sqlmesh plan +``` + +It will return output similar to this: + +```bash +(.venv) ➜ tcloud_project git:(main) ✗ tcloud sqlmesh plan +====================================================================== +Successfully Ran 1 tests against duckdb +---------------------------------------------------------------------- +New environment `prod` will be created from `prod` +Summary of differences against `prod`: +Models: +└── Added: + ├── sqlmesh_example.full_model + ├── sqlmesh_example.incremental_model + └── sqlmesh_example.seed_model +Models needing backfill (missing dates): +├── sqlmesh_example.full_model: 2024-11-24 - 2024-11-24 +├── sqlmesh_example.incremental_model: 2020-01-01 - 2024-11-24 +└── sqlmesh_example.seed_model: 2024-11-24 - 2024-11-24 +Apply - Backfill Tables [y/n]: y + +[1/1] sqlmesh_example.seed_model evaluated in 0.00s +[1/1] sqlmesh_example.incremental_model evaluated in 0.01s +[1/1] sqlmesh_example.full_model evaluated in 0.01s +Evaluating models ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 + + +All model batches have been executed successfully + +Virtually Updating 'prod' ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 0:00:00 + +The target environment has been updated successfully +``` + +Tobiko Cloud and SQLMesh are working! 
+ +## Next steps + +Your `tcloud` project directory should look and feel like this: + +![tcloud project directory](./tcloud_getting_started/tcloud_project_dir.png) + +From here, if you have an existing SQLMesh project, you can copy over your existing models and macros to the `models` and `macros` directories (along with other files as needed). + +You are now fully onboarded with Tobiko Cloud. We recommend reviewing the helpful links below to get familiar with SQLMesh and Tobiko Cloud. + +Here's to data transformation without the waste! + +### Helpful Links +- [Walkthrough Example](../examples/incremental_time_full_walkthrough.md) +- [Quickstart](../quick_start.md) +- [Project Guide and getting setup](../guides/projects.md) +- [Models Guide](../guides/models.md) +- [GitHub Actions CI/CD bot](../integrations/github.md) +- [Testing Models](../concepts/tests.md) +- [SQLMesh Macros](../concepts/macros/sqlmesh_macros.md) \ No newline at end of file diff --git a/docs/cloud/tcloud_getting_started/tcloud_auth_success.png b/docs/cloud/tcloud_getting_started/tcloud_auth_success.png new file mode 100644 index 0000000000..2989d976c8 Binary files /dev/null and b/docs/cloud/tcloud_getting_started/tcloud_auth_success.png differ diff --git a/docs/cloud/tcloud_getting_started/tcloud_home_page.png b/docs/cloud/tcloud_getting_started/tcloud_home_page.png new file mode 100644 index 0000000000..03fa0c5a27 Binary files /dev/null and b/docs/cloud/tcloud_getting_started/tcloud_home_page.png differ diff --git a/docs/cloud/tcloud_getting_started/tcloud_project_dir.png b/docs/cloud/tcloud_getting_started/tcloud_project_dir.png new file mode 100644 index 0000000000..8b3e2d1033 Binary files /dev/null and b/docs/cloud/tcloud_getting_started/tcloud_project_dir.png differ diff --git a/docs/comparisons.md b/docs/comparisons.md index fef5b9bc65..ef6049acd6 100644 --- a/docs/comparisons.md +++ b/docs/comparisons.md @@ -37,7 +37,6 @@ SQLMesh aims to be dbt format-compatible. 
Importing existing dbt projects with m | `Virtual Data Environments` | ❌ | [✅](../concepts/environments) | `Open-source CI/CD bot` | ❌ | [✅](../integrations/github) | `Data consistency enforcement` | ❌ | ✅ -| `Native Airflow integration` | ❌ | [✅](../integrations/airflow) | Interfaces | `CLI` | ✅ | [✅](../reference/cli) | `Paid UI` | ✅ | ❌ diff --git a/docs/concepts/audits.md b/docs/concepts/audits.md index 61643803dc..a5a9fccc49 100644 --- a/docs/concepts/audits.md +++ b/docs/concepts/audits.md @@ -7,10 +7,36 @@ By default, SQLMesh will halt plan application when an audit fails so potentiall A comprehensive suite of audits can identify data issues upstream, whether they are from your vendors or other teams. Audits also empower your data engineers and analysts to work with confidence by catching problems early as they work on new features or make updates to your models. -**NOTE**: For incremental models, audits are only applied to intervals being processed - not for the entire underlying table. +**NOTE**: For incremental by time range models, audits are only applied to intervals being processed - not for the entire underlying table. + +## Blocking audits +A failed blocking audit halts the execution of a `plan` or `run` to prevent invalid data from propagating to downstream models. The impact of a failure depends on whether you are running a `plan` or a `run`. + +SQLMesh's blocking audit process is: + +1. Evaluate the model (e.g., insert new data or rebuild the table) +2. Run the audit query against the newly updated model table. For incremental models, the audit only runs on the processed time intervals. +3. If the query returns any rows, the audit fails, halting the `plan` or `run`. + +### Plan vs. Run + +The key difference is when the model's data is promoted to the production environment: + +* **`plan`**: SQLMesh evaluates and audits all modified models *before* promoting them to production. 
If an audit fails, the `plan` stops, and the production table is untouched. Invalid data is contained in an isolated table and never reaches the production environment. + +* **`run`**: SQLMesh evaluates and audits models directly against the production environment. If an audit fails, the `run` stops, but the invalid data *is already present* in the production table. The "blocking" action prevents this bad data from being used to build other downstream models. + +### Fixing a Failed Audit + +If a blocking audit fails during a `run`, you must fix the invalid data in the production table. To do so: + +1. **Find the root cause**: examine upstream models and data sources +2. **Fix the source** + * If the cause is an **external data source**, fix it there. Then, run a [restatement plan](./plans.md#restatement-plans) on the first SQLMesh model that ingests the source data. This will restate all downstream models, including the one with the failed audit. + * If the cause is a **SQLMesh model**, update the model's logic. Then apply the change with a `plan`, which will automatically re-evaluate all downstream models. ## User-Defined Audits -In SQLMesh, user-defined audits are defined in `.sql` files in an `audit` directory in your SQLMesh project. Multiple audits can be defined in a single file, so you can organize them to your liking. Alternatively, audits can be defined inline within the model definition itself. +In SQLMesh, user-defined audits are defined in `.sql` files in an `audits` directory in your SQLMesh project. Multiple audits can be defined in a single file, so you can organize them to your liking. Alternatively, audits can be defined inline within the model definition itself. Audits are SQL queries that should not return any rows; in other words, they query for bad data, so returned rows indicates that something is wrong. @@ -75,6 +101,28 @@ Notice how `column` and `threshold` parameters have been set. 
These values will Note that the same audit can be applied more than once to the a model using different sets of parameters. +Generic audits can define default values as follows: +```sql linenums="1" +AUDIT ( + name does_not_exceed_threshold, + defaults ( + threshold = 10, + column = id + ) +); +SELECT * FROM @this_model +WHERE @column >= @threshold; +``` + +Alternatively, you can apply specific audits globally by including them in the model defaults configuration: + +```sql linenums="1" +model_defaults: + audits: + - assert_positive_order_ids + - does_not_exceed_threshold(column := id, threshold := 1000) +``` + ### Naming We recommended avoiding SQL keywords when naming audit parameters. Quote any audit argument that is also a SQL keyword. @@ -99,7 +147,7 @@ MODEL ( name sushi.items, audits(does_not_exceed_threshold(column := id, threshold := 1000), price_is_not_null) ); -SELECT id, price +SELECT id, price FROM sushi.seed; AUDIT (name does_not_exceed_threshold); @@ -246,7 +294,8 @@ MODEL ( #### accepted_values, accepted_values_non_blocking Ensures that all rows of the specified column contain one of the accepted values. -NOTE: rows with `NULL` values for the column will pass this audit in most databases/engines. Use the [`not_null` audit](#not_null) to ensure there are no `NULL` values present in a column. +!!! note + Rows with `NULL` values for the column will pass this audit in most databases/engines. Use the [`not_null` audit](#not_null) to ensure there are no `NULL` values present in a column. 
This example asserts that column `name` has a value of 'Hamachi', 'Unagi', or 'Sake': @@ -254,7 +303,7 @@ This example asserts that column `name` has a value of 'Hamachi', 'Unagi', or 'S MODEL ( name sushi.items, audits ( - accepted_values(column := name, is_in=('Hamachi', 'Unagi', 'Sake')) + accepted_values(column := name, is_in := ('Hamachi', 'Unagi', 'Sake')) ) ); ``` @@ -262,7 +311,8 @@ MODEL ( #### not_accepted_values, not_accepted_values_non_blocking Ensures that no rows of the specified column contain one of the not accepted values. -NOTE: this audit does not support rejecting `NULL` values. Use the [`not_null` audit](#not_null) to ensure there are no `NULL` values present in a column. +!!! note + This audit does not support rejecting `NULL` values. Use the [`not_null` audit](#not_null) to ensure there are no `NULL` values present in a column. This example asserts that column `name` is not one of 'Hamburger' or 'French fries': @@ -337,7 +387,8 @@ MODEL ( These audits concern the characteristics of values in character/string columns. -NOTE: databases/engines may exhibit different behavior for different character sets or languages. +!!! warning + Databases/engines may exhibit different behavior for different character sets or languages. #### not_empty_string, not_empty_string_non_blocking Ensures that no rows of a column contain an empty string value `''`. @@ -353,7 +404,7 @@ MODEL ( ); ``` -#### string_length_equal_audit, string_length_equal_audit_non_blocking +#### string_length_equal, string_length_equal_non_blocking Ensures that all rows of a column contain a string with the specified number of characters. 
This example asserts that all `zip` values are 5 characters long: @@ -362,12 +413,12 @@ This example asserts that all `zip` values are 5 characters long: MODEL ( name sushi.customers, audits ( - string_length_equal_audit(column := zip, v := 5) + string_length_equal(column := zip, v := 5) ) ); ``` -#### string_length_between_audit, string_length_between_audit_non_blocking +#### string_length_between, string_length_between_non_blocking Ensures that all rows of a column contain a string with number of characters in the specified range. Range is inclusive by default, such that values equal to the range boundaries will pass the audit. This example asserts that all `name` values have 5 or more and 50 or fewer characters: @@ -376,7 +427,7 @@ This example asserts that all `name` values have 5 or more and 50 or fewer chara MODEL ( name sushi.customers, audits ( - string_length_between_audit(column := name, min_v := 5, max_v := 50) + string_length_between(column := name, min_v := 5, max_v := 50) ) ); ``` @@ -387,7 +438,7 @@ This example specifies the `inclusive := false` argument to assert that all rows MODEL ( name sushi.customers, audits ( - string_length_between_audit(column := zip, min_v := 4, max_v := 60, inclusive := false) + string_length_between(column := zip, min_v := 4, max_v := 60, inclusive := false) ) ); ``` @@ -509,7 +560,9 @@ MODEL ( These audits concern the statistical distributions of numeric columns. -NOTE: audit thresholds will likely require fine-tuning via trial and error for each column being audited. +!!! note + + Audit thresholds will likely require fine-tuning via trial and error for each column being audited. #### mean_in_range, mean_in_range_non_blocking Ensures that a numeric column's mean is in the specified range. Range is inclusive by default, such that values equal to the range boundaries will pass the audit. 
@@ -612,7 +665,7 @@ MODEL ( You can execute audits with the `sqlmesh audit` command as follows: ```bash -$ sqlmesh -p project audit -start 2022-01-01 -end 2022-01-02 +$ sqlmesh -p project audit --start 2022-01-01 --end 2022-01-02 Found 1 audit(s). assert_item_price_is_not_null FAIL. diff --git a/docs/concepts/glossary.md b/docs/concepts/glossary.md index 9add9aee0e..15ffbc2de9 100644 --- a/docs/concepts/glossary.md +++ b/docs/concepts/glossary.md @@ -57,6 +57,9 @@ Combining data from various sources (such as from a data warehouse) into one uni ## Lineage The lineage of your data is a visualization of the life cycle of your data as it flows from data sources downstream to consumption. +## Physical Layer +The physical layer is where SQLMesh stores and manages data in database tables and materialized views. It is the concrete data storage layer of the SQL engine, in contrast to the [SQLMesh virtual layer's](#virtual-layer) views. SQLMesh handles the management and maintenance of the physical layer automatically, and users should rarely interact with it directly. + ## Plan Summaries An upcoming feature that allows users to see a summary of changes applied to a given environment. @@ -78,6 +81,9 @@ A view is the result of a SQL query on a database. ## Virtual Environments SQLMesh's unique approach to environment that allows it to provide both environment isolation and the ability to share tables across environments. This is done in a way to ensure data consistency and accuracy. See [plan application](plans.md#plan-application) for more information. +## Virtual Layer +The virtual layer is SQLMesh's abstraction layer over the [physical layer and physical data storage](#physical-layer). While the physical layer consists of tables where data is actually stored, the virtual layer consists of views that expose tables in the underlying physical layer. Most users should only interact with the virtual layer when building models or querying data. 
+ ## Virtual Update Term used to describe a plan that can be applied without having to load any additional data or build any additional tables. See [Virtual Update](plans.md#virtual-update) for more information. diff --git a/docs/concepts/macros/jinja_macros.md b/docs/concepts/macros/jinja_macros.md index 36b751113b..49b5f81912 100644 --- a/docs/concepts/macros/jinja_macros.md +++ b/docs/concepts/macros/jinja_macros.md @@ -50,6 +50,30 @@ JINJA_STATEMENT_BEGIN; JINJA_END; ``` +## SQLMesh predefined variables + +SQLMesh provides multiple [predefined macro variables](./macro_variables.md) you may reference in jinja code. + +Some predefined variables provide information about the SQLMesh project itself, like the [`runtime_stage`](./macro_variables.md#runtime-variables) and [`this_model`](./macro_variables.md#runtime-variables) variables. + +Other predefined variables are [temporal](./macro_variables.md#temporal-variables), like `start_ds` and `execution_date`. They are used to build incremental model queries and are only available in incremental model kinds. + +Access predefined macro variables by passing their unquoted name in curly braces. For example, this demonstrates how to access the `start_ds` and `end_ds` variables: + +```sql linenums="1" +JINJA_QUERY_BEGIN; + +SELECT * +FROM table +WHERE time_column BETWEEN '{{ start_ds }}' and '{{ end_ds }}'; + +JINJA_END; +``` + +Because the two macro variables return string values, we must surround the curly braces with single quotes `'`. Other macro variables, such as `start_epoch`, return numeric values and do not require the single quotes. + +The `gateway` variable uses a slightly different syntax than other predefined variables because it is a function call. Instead of the bare name `{{ gateway }}`, it must include parentheses: `{{ gateway() }}`. + ## User-defined variables SQLMesh supports two kinds of user-defined macro variables: global and local. 
@@ -90,6 +114,39 @@ WHERE some_value = {{ var('missing_var', 0) }}; JINJA_END; ``` +### Gateway variables + +Like global variables, gateway variables are defined in the project configuration file. However, they are specified in a specific gateway's `variables` key. Learn more about defining gateway variables in the [SQLMesh macros documentation](./sqlmesh_macros.md#gateway-variables). + +Access gateway variables in models using the same methods as [global variables](#global-variables). + +Gateway-specific variable values take precedence over variables with the same name specified in the configuration file's root `variables` key. + +### Blueprint variables + +Blueprint variables are defined as a property of the `MODEL` statement, and serve as a mechanism for [creating model templates](../models/sql_models.md): + +```sql linenums="1" +MODEL ( + name @customer.some_table, + kind FULL, + blueprints ( + (customer := customer1, field_a := x, field_b := y), + (customer := customer2, field_a := z) + ) +); + +JINJA_QUERY_BEGIN; +SELECT + {{ blueprint_var('field_a') }} + {{ blueprint_var('field_b', 'default_b') }} AS field_b +FROM {{ blueprint_var('customer') }}.some_source +JINJA_END; +``` + +Blueprint variables can be accessed using the `{{ blueprint_var() }}` macro function, which also supports specifying default values in case the variable is undefined (similar to `{{ var() }}`). + + ### Local variables Define your own variables with the Jinja statement `{% set ... %}`. For example, we could specify the name of the `num_orders` column in the `sqlmesh_example.full_model` like this: diff --git a/docs/concepts/macros/macro_variables.md b/docs/concepts/macros/macro_variables.md index e72ede481d..398117b3a9 100644 --- a/docs/concepts/macros/macro_variables.md +++ b/docs/concepts/macros/macro_variables.md @@ -1,10 +1,20 @@ # Macro variables -The most common use case for macros is variable substitution. 
For example, you might have a SQL query that filters by date in the `WHERE` clause. +Macro variables are placeholders whose values are substituted in when the macro is rendered. + +They enable dynamic macro behavior - for example, a date parameter's value might be based on when the macro was run. + +!!! note + + This page discusses SQLMesh's built-in macro variables. Learn more about custom, user-defined macro variables on the [SQLMesh macros page](./sqlmesh_macros.md#user-defined-variables). + +## Example + +Consider a SQL query that filters by date in the `WHERE` clause. Instead of manually changing the date each time the model is run, you can use a macro variable to make the date dynamic. With the dynamic approach, the date changes automatically based on when the query is run. -Consider this query that filters for rows where column `my_date` is after '2023-01-01': +This query filters for rows where column `my_date` is after '2023-01-01': ```sql linenums="1" SELECT * @@ -34,42 +44,56 @@ This example used one of SQLMesh's predefined variables, but you can also define We describe SQLMesh's predefined variables below; user-defined macro variables are discussed in the [SQLMesh macros](./sqlmesh_macros.md#user-defined-variables) and [Jinja macros](./jinja_macros.md#user-defined-variables) pages. -## Predefined Variables +## Predefined variables SQLMesh comes with predefined variables that can be used in your queries. They are automatically set by the SQLMesh runtime. -Most predefined variables are related to time and use a combination of prefixes (start, end, execution) and postfixes (date, ds, ts, epoch, millis). They are described in the next section; [other predefined variables](#runtime-variables) are discussed in the following section. +Most predefined variables are related to time and use a combination of prefixes (start, end, etc.) and postfixes (date, ds, ts, etc.). 
They are described in the next section; [other predefined variables](#runtime-variables) are discussed in the following section. ### Temporal variables -SQLMesh uses the python [datetime module](https://docs.python.org/3/library/datetime.html) for handling dates and times. It uses the standard [Unix epoch](https://en.wikipedia.org/wiki/Unix_time) start of 1970-01-01. *All predefined variables with a time component use the [UTC time zone](https://en.wikipedia.org/wiki/Coordinated_Universal_Time).* +SQLMesh uses the python [datetime module](https://docs.python.org/3/library/datetime.html) for handling dates and times. It uses the standard [Unix epoch](https://en.wikipedia.org/wiki/Unix_time) start of 1970-01-01. + +!!! tip "Important" + + Predefined variables with a time component always use the [UTC time zone](https://en.wikipedia.org/wiki/Coordinated_Universal_Time). + + Learn more about timezones and incremental models [here](../models/model_kinds.md#timezones). Prefixes: -* start - The inclusive starting interval of a model run. -* end - The inclusive end interval of a model run. -* execution - The timestamp of when the execution started. +* start - The inclusive starting interval of a model run +* end - The inclusive end interval of a model run +* execution - The timestamp of when the execution started Postfixes: -* date - A python date object that converts into a native SQL Date. +* dt - A python datetime object that converts into a native SQL `TIMESTAMP` (or SQL engine equivalent) +* dtntz - A python datetime object that converts into a native SQL `TIMESTAMP WITHOUT TIME ZONE` (or SQL engine equivalent) +* date - A python date object that converts into a native SQL `DATE` * ds - A date string with the format: '%Y-%m-%d' -* ts - An ISO 8601 datetime formatted string: '%Y-%m-%d %H:%M:%S'. -* tstz - An ISO 8601 datetime formatted string with timezone: '%Y-%m-%d %H:%M:%S%z'. -* epoch - An integer representing seconds since Unix epoch. 
-* millis - An integer representing milliseconds since Unix epoch. +* ts - An ISO 8601 datetime formatted string: '%Y-%m-%d %H:%M:%S' +* tstz - An ISO 8601 datetime formatted string with timezone: '%Y-%m-%d %H:%M:%S%z' +* hour - An integer representing the hour of the day, with values 0-23 +* epoch - An integer representing seconds since Unix epoch +* millis - An integer representing milliseconds since Unix epoch All predefined temporal macro variables: +* dt + * @start_dt + * @end_dt + * @execution_dt + +* dtntz + * @start_dtntz + * @end_dtntz + * @execution_dtntz + * date * @start_date * @end_date * @execution_date -* datetime - * @start_dt - * @end_dt - * @execution_dt - * ds * @start_ds * @end_ds @@ -85,6 +109,11 @@ All predefined temporal macro variables: * @end_tstz * @execution_tstz +* hour + * @start_hour + * @end_hour + * @execution_hour + * epoch * @start_epoch * @end_epoch @@ -97,19 +126,36 @@ All predefined temporal macro variables: ### Runtime variables -SQLMesh provides two other predefined variables used to modify model behavior based on information available at runtime. +SQLMesh provides additional predefined variables used to modify model behavior based on information available at runtime. + +* @runtime_stage - A string value denoting the current stage of the SQLMesh runtime. Typically used in models to conditionally execute pre/post-statements (learn more [here](../models/sql_models.md#optional-prepost-statements)). It returns one of these values: + * 'loading' - The project is being loaded into SQLMesh's runtime context. + * 'creating' - The model tables are being created for the first time. The data may be inserted during table creation. + * 'evaluating' - The model query logic is evaluated, and the data is inserted into the existing model table. + * 'promoting' - The model is being promoted in the target environment (view created during virtual layer update). 
+ * 'demoting' - The model is being demoted in the target environment (view dropped during virtual layer update). + * 'auditing' - The audit is being run. + * 'testing' - The model query logic is being evaluated in the context of a unit test. +* @gateway - A string value containing the name of the current [gateway](../../guides/connections.md). +* @this_model - The physical table name that the model's view selects from. Typically used to create [generic audits](../audits.md#generic-audits). When used in [on_virtual_update statements](../models/sql_models.md#optional-on-virtual-update-statements), it contains the qualified view name instead. +* @model_kind_name - A string value containing the name of the current model kind. Intended to be used in scenarios where you need to control the [physical properties in model defaults](../../reference/model_configuration.md#model-defaults). + +!!! note "Embedding variables in strings" + + Macro variable references sometimes use the curly brace syntax `@{variable}`, which serves a different purpose than the regular `@variable` syntax. + + The curly brace syntax tells SQLMesh that the rendered string should be treated as an identifier, instead of simply replacing the macro variable value. + + For example, if `variable` is defined as `@DEF(`variable`, foo.bar)`, then `@variable` produces `foo.bar`, while `@{variable}` produces `"foo.bar"`. This is because SQLMesh converts `foo.bar` into an identifier, using double quotes to correctly include the `.` character in the identifier name. -* @runtime_stage - A string value that denotes the current stage of the SQLMesh runtime. It can take one of the following values: - * 'loading' - The project is currently being loaded into SQLMesh's runtime context. - * 'creating' - The model tables are being created. - * 'evaluating' - The models' logic is being evaluated. - * 'testing' - The models' logic is being evaluated in the context of a unit test. 
-* @gateway - A string value that represents the name of the selected [gateway](../../guides/connections.md). + In practice, `@{variable}` is most commonly used to interpolate a value within an identifier, e.g., `@{variable}_suffix`, whereas `@variable` is used to do plain substitutions for string literals. -### Audit-only variables + Learn more [above](#embedding-variables-in-strings). -Some predefined variables are only supported in [SQLMesh audit definitions](../audits.md). +#### Before all and after all variables -* @this_model - used to create [generic audits](../audits.md#generic-audits) +The following variables are also available in [`before_all` and `after_all` statements](../../guides/configuration.md#before_all-and-after_all-statements), as well as in macros invoked within them. -The `{{ this_model }}` Jinja macro variable may be used in model definitions for the rare cases when SQLGlot cannot fully parse a statement and you need to reference the model's underlying physical table directly. We recommend against using it unless absolutely required. +* @this_env - A string value containing the name of the current [environment](../environments.md). +* @schemas - A list of the schema names of the [virtual layer](../../concepts/glossary.md#virtual-layer) of the current environment. +* @views - A list of the view names of the [virtual layer](../../concepts/glossary.md#virtual-layer) of the current environment. \ No newline at end of file diff --git a/docs/concepts/macros/sqlmesh_macros.md b/docs/concepts/macros/sqlmesh_macros.md index 5e3557ca38..c7d967b12c 100644 --- a/docs/concepts/macros/sqlmesh_macros.md +++ b/docs/concepts/macros/sqlmesh_macros.md @@ -38,14 +38,67 @@ It uses the following five step approach to accomplish this: 5. Modify the semantic representation of the SQL query with the substituted variable values from (3) and functions from (4). 
+### Embedding variables in strings + +SQLMesh always incorporates macro variable values into the semantic representation of a SQL query (step 5 above). To do that, it infers the role each macro variable value plays in the query. + +For context, two commonly used types of string in SQL are: + +- String literals, which represent text values and are surrounded by single quotes, such as `'the_string'` +- Identifiers, which reference database objects like column, table, alias, and function names + - They may be unquoted or quoted with double quotes, backticks, or brackets, depending on the SQL dialect + +In a normal query, SQLMesh can easily determine which role a given string is playing. However, it is more difficult if a macro variable is embedded directly into a string - especially if the string is in the `MODEL` block (and not the query itself). + +For example, consider a project that defines a [gateway variable](#gateway-variables) named `gateway_var`. The project includes a model that references `@gateway_var` as part of the schema in the model's `name`, which is a SQL *identifier*. + +This is how we might try to write the model: + +``` sql title="Incorrectly rendered to string literal" +MODEL ( + name the_@gateway_var_schema.table +); +``` + +From SQLMesh's perspective, the model schema is the combination of three sub-strings: `the_`, the value of `@gateway_var`, and `_schema`. + +SQLMesh will concatenate those strings, but it does not have the context to know that it is building a SQL identifier and will return a string literal. + +To provide the context SQLMesh needs, you must add curly braces to the macro variable reference: `@{gateway_var}` instead of `@gateway_var`: + +``` sql title="Correctly rendered to identifier" +MODEL ( + name the_@{gateway_var}_schema.table +); +``` + +The curly braces let SQLMesh know that it should treat the string as a SQL identifier, which it will then quote based on the SQL dialect's quoting rules. 
While the most common use of the curly brace syntax is embedding macro variables into strings, it can also be used to differentiate string literals and identifiers in SQL queries.
For example, the value of a local variable takes precedence over the value of a gateway variable with the same name, and the value of a gateway variable takes precedence over the value of a global variable. +Macro variables with the same name may be specified at any or all of the global, gateway, blueprint and local levels. When variables are specified at multiple levels, the value of the most specific level takes precedence. For example, the value of a local variable takes precedence over the value of a blueprint or gateway variable with the same name, and the value of a gateway variable takes precedence over the value of a global variable. ### Global variables @@ -57,17 +110,37 @@ Access global variable values in a model definition using the `@` macr For example, this SQLMesh configuration key defines six variables of different data types: -```yaml linenums="1" -variables: - int_var: 1 - float_var: 2.0 - bool_var: true - str_var: "cat" - list_var: [1, 2, 3] - dict_var: - key1: 1 - key2: 2 -``` +=== "YAML" + + ```yaml linenums="1" + variables: + int_var: 1 + float_var: 2.0 + bool_var: true + str_var: "cat" + list_var: [1, 2, 3] + dict_var: + key1: 1 + key2: 2 + ``` + +=== "Python" + + ``` python linenums="1" + variables = { + "int_var": 1, + "float_var": 2.0, + "bool_var": True, + "str_var": "cat", + "list_var": [1, 2, 3], + "dict_var": {"key1": 1, "key2": 2}, + } + + config = Config( + variables=variables, + ... # other Config arguments + ) + ``` A model definition could access the `int_var` value in a `WHERE` clause like this: @@ -101,21 +174,83 @@ A similar API is available for [Python macro functions](#accessing-global-variab Like global variables, gateway variables are defined in the project configuration file. However, they are specified in a specific gateway's `variables` key: -```yaml linenums="1" -gateways: - my_gateway: - variables: - int_var: 1 - ... -``` +=== "YAML" + + ```yaml linenums="1" + gateways: + my_gateway: + variables: + int_var: 1 + ... 
+ ``` + +=== "Python" + + ``` python linenums="1" + gateway_variables = { + "int_var": 1 + } + + config = Config( + gateways={ + "my_gateway": GatewayConfig( + variables=gateway_variables + ... # other GatewayConfig arguments + ), + } + ) + ``` Access them in models using the same methods as [global variables](#global-variables). Gateway-specific variable values take precedence over variables with the same name specified in the root `variables` key. +### Blueprint variables + +Blueprint macro variables are defined in a model. Blueprint variable values take precedence over [global](#global-variables) or [gateway-specific](#gateway-variables) variables with the same name. + +Blueprint variables are defined as a property of the `MODEL` statement, and serve as a mechanism for [creating model templates](../models/sql_models.md): + +```sql linenums="1" +MODEL ( + name @customer.some_table, + kind FULL, + blueprints ( + (customer := customer1, field_a := x, field_b := y, field_c := 'foo'), + (customer := customer2, field_a := z, field_b := w, field_c := 'bar') + ) +); + +SELECT + @field_a, + @{field_b} AS field_b, + @field_c AS @{field_c} +FROM @customer.some_source + +/* +When rendered for customer1.some_table: +SELECT + x, + y AS field_b, + 'foo' AS foo +FROM customer1.some_source + +When rendered for customer2.some_table: +SELECT + z, + w AS field_b, + 'bar' AS bar +FROM customer2.some_source +*/ +``` + +Note the use of both regular `@field_a` and curly brace syntax `@{field_b}` macro variable references in the model query. Both of these will be rendered as identifiers. In the case of `field_c`, which in the blueprints is a string, it would be rendered as a string literal when used with the regular macro syntax `@field_c` and if we want to use the string as an identifier then we use the curly braces `@{field_c}`. 
Learn more [above](#embedding-variables-in-strings).
+
+Blueprint variables can be accessed
For example, SQLMesh doesn't evaluate the condition `my_column > @my_value` until it has first substituted the number `@my_value` represents. + +Your macro might do things besides returning a value, such as printing a message or executing a statement (i.e., the macro "has side effects"). The side effect code will always run during the rendering step. To prevent this, modify the macro code to condition the side effects on the evaluation stage. + #### Pre/post-statements `@IF` may be used to conditionally execute pre/post-statements: @@ -607,7 +752,7 @@ If the column data types are known, the resulting query `CAST`s columns to their **NOTE**: the `exclude` argument used to be named `except_`. The latter is still supported but we discourage its use because it will be deprecated in the future. -Like all SQLMesh macro functions, omitting an argument when calling `@STAR` requires passing all subsequent arguments with their name and the special `:=` keyword operator. For example, we might omit the `alias` argument with `@STAR(foo, exclude := [c])`. Learn more about macro function arguments [below](#positional-and-keyword-arguments). +Like all SQLMesh macro functions, omitting an argument when calling `@STAR` requires passing subsequent arguments with their name and the special `:=` keyword operator. For example, we might omit the `alias` argument with `@STAR(foo, exclude := [c])`. Learn more about macro function arguments [below](#positional-and-keyword-arguments). As a `@STAR` example, consider the following query: @@ -618,6 +763,7 @@ FROM foo AS bar ``` The arguments to `@STAR` are: + 1. The name of the table `foo` (from the query's `FROM foo`) 2. The table alias `bar` (from the query's `AS bar`) 3. 
A list of columns to exclude from the selection, containing one column `c` @@ -635,6 +781,7 @@ FROM foo AS bar ``` Note these aspects of the rendered query: + - Each column is `CAST` to its data type in the table `foo` (e.g., `a` to `TEXT`) - Each column selection uses the alias `bar` (e.g., `"bar"."a"`) - Column `c` is not present because it was passed to `@STAR`'s `exclude` argument @@ -662,13 +809,14 @@ FROM foo AS bar ``` Note these aspects of the rendered query: + - Columns `a` and `b` have the prefix `"ab_pre_"` , while column `d` has the prefix `"d_pre_"` - Column `c` is not present because it was passed to the `exclude` argument in both `@STAR` calls - `my_column` is present in the query ### @GENERATE_SURROGATE_KEY -`@GENERATE_SURROGATE_KEY` generates a surrogate key from a set of columns. The surrogate key is a sequence of alphanumeric digits returned by the [`MD5` hash function](https://en.wikipedia.org/wiki/MD5) on the concatenated column values. +`@GENERATE_SURROGATE_KEY` generates a surrogate key from a set of columns. The surrogate key is a sequence of alphanumeric digits returned by a hash function, such as [`MD5`](https://en.wikipedia.org/wiki/MD5), on the concatenated column values. The surrogate key is created by: 1. 
`CAST`ing each column's value to `TEXT` (or the SQL engine's equivalent type) @@ -680,7 +828,7 @@ For example, the following query: ```sql linenums="1" SELECT - @GENERATE_SURROGATE_KEY(a, b, c) + @GENERATE_SURROGATE_KEY(a, b, c) AS col FROM foo ``` @@ -690,16 +838,40 @@ would be rendered as: SELECT MD5( CONCAT( - COALESCE(CAST(a AS TEXT), '_sqlmesh_surrogate_key_null_'), + COALESCE(CAST("a" AS TEXT), '_sqlmesh_surrogate_key_null_'), '|', - COALESCE(CAST(b AS TEXT), '_sqlmesh_surrogate_key_null_'), + COALESCE(CAST("b" AS TEXT), '_sqlmesh_surrogate_key_null_'), '|', - COALESCE(CAST(c AS TEXT), '_sqlmesh_surrogate_key_null_') + COALESCE(CAST("c" AS TEXT), '_sqlmesh_surrogate_key_null_') ) - ) + ) AS "col" +FROM "foo" AS "foo" +``` + +By default, the `MD5` function is used, but this behavior can change by setting the `hash_function` argument as follows: + +```sql linenums="1" +SELECT + @GENERATE_SURROGATE_KEY(a, b, c, hash_function := 'SHA256') AS col FROM foo ``` +This query will similarly be rendered as: + +```sql linenums="1" +SELECT + SHA256( + CONCAT( + COALESCE(CAST("a" AS TEXT), '_sqlmesh_surrogate_key_null_'), + '|', + COALESCE(CAST("b" AS TEXT), '_sqlmesh_surrogate_key_null_'), + '|', + COALESCE(CAST("c" AS TEXT), '_sqlmesh_surrogate_key_null_') + ) + ) AS "col" +FROM "foo" AS "foo" +``` + ### @SAFE_ADD `@SAFE_ADD` adds two or more operands, substituting `NULL`s with `0`s. It returns `NULL` if all operands are `NULL`. @@ -761,7 +933,9 @@ FROM foo `@UNION` returns a `UNION` query that selects all columns with matching names and data types from the tables. -Its first argument is the `UNION` "type", `'DISTINCT` (removing duplicated rows) or `'ALL'` (returning all rows). Subsequent arguments are the tables to be combined. +Its first argument can be either a condition or the `UNION` "type". If the first argument evaluates to a boolean (`TRUE` or `FALSE`), it's treated as a condition. If the condition is `FALSE`, only the first table is returned. 
If it's `TRUE`, the union operation is performed. + +If the first argument is not a boolean condition, it's treated as the `UNION` "type": either `'DISTINCT'` (removing duplicated rows) or `'ALL'` (returning all rows). Subsequent arguments are the tables to be combined. Let's assume that: @@ -788,6 +962,47 @@ SELECT FROM bar ``` +If the union type is omitted, `'ALL'` is used as the default. So the following expression: + +```sql linenums="1" +@UNION(foo, bar) +``` + +would be rendered as: + +```sql linenums="1" +SELECT + CAST(a AS INT) AS a, + CAST(c AS TEXT) AS c +FROM foo +UNION ALL +SELECT + CAST(a AS INT) AS a, + CAST(c AS TEXT) AS c +FROM bar +``` + +You can also use a condition to control whether the union happens: + +```sql linenums="1" +@UNION(1 > 0, 'all', foo, bar) +``` + +This would render the same as above. However, if the condition is `FALSE`: + +```sql linenums="1" +@UNION(1 > 2, 'all', foo, bar) +``` + +Only the first table would be selected: + +```sql linenums="1" +SELECT + CAST(a AS INT) AS a, + CAST(c AS TEXT) AS c +FROM foo +``` + ### @HAVERSINE_DISTANCE `@HAVERSINE_DISTANCE` returns the [haversine distance](https://en.wikipedia.org/wiki/Haversine_formula) between two geographic points. 
@@ -826,7 +1041,7 @@ It supports the following arguments, in this order: - `column`: The column to pivot - `values`: The values to use for pivoting (one column is created for each value in `values`) -- `alias`: Whether to create aliases for the resulting columns, defaults to true +- `alias` (optional): Whether to create aliases for the resulting columns, defaults to true - `agg` (optional): The aggregation function to use, defaults to `SUM` - `cmp` (optional): The comparison operator to use for comparing the column values, defaults to `=` - `prefix` (optional): A prefix to use for all aliases @@ -836,7 +1051,7 @@ It supports the following arguments, in this order: - `quote` (optional): Whether to quote the resulting aliases, defaults to true - `distinct` (optional): Whether to apply a `DISTINCT` clause for the aggregation function, defaults to false -SQLMesh macro operators do not accept named arguments. For example, `@PIVOT(column=column_to_pivot)` will error. +Like all SQLMesh macro functions, omitting an argument when calling `@PIVOT` requires passing subsequent arguments with their name and the special `:=` keyword operator. For example, we might omit the `agg` argument with `@PIVOT(status, ['cancelled', 'completed'], cmp := '<')`. Learn more about macro function arguments [below](#positional-and-keyword-arguments). For example, the following query: @@ -859,6 +1074,120 @@ FROM rides GROUP BY 1 ``` +### @DEDUPLICATE + +`@DEDUPLICATE` is used to deduplicate rows in a table based on the specified partition and order columns with a window function. 
+ +It supports the following arguments, in this order: + +- `relation`: The table or CTE name to deduplicate +- `partition_by`: column names, or expressions to use to identify a window of rows out of which to select one as the deduplicated row +- `order_by`: A list of strings representing the ORDER BY clause, optional - you can add nulls ordering like this: ['event_date desc nulls last'] + +For example, the following query: +```sql linenums="1" +with raw_data as ( +@deduplicate(my_table, [id, cast(event_date as date)], ['event_date DESC', 'status ASC']) +) + +select * from raw_data +``` + +would be rendered as: + +```sql linenums="1" +WITH "raw_data" AS ( +  SELECT +    * +  FROM "my_table" AS "my_table" +  QUALIFY +    ROW_NUMBER() OVER (PARTITION BY "id", CAST("event_date" AS DATE) ORDER BY "event_date" DESC, "status" ASC) = 1 +) +SELECT +  * +FROM "raw_data" AS "raw_data" +``` + +### @DATE_SPINE + +`@DATE_SPINE` returns the SQL required to build a date spine. The spine will include the start_date (if it is aligned to the datepart), AND it will include the end_date. This is different from the [`date_spine`](https://github.com/dbt-labs/dbt-utils?tab=readme-ov-file#date_spine-source) macro in `dbt-utils` which will NOT include the end_date. It's typically used to join unique, hard-coded date ranges with other tables/views, so people don't have to constantly adjust date ranges in `where` clauses across many SQL models. 
+ +It supports the following arguments, in this order: + +- `datepart`: The datepart to use for the date spine - day, week, month, quarter, year +- `start_date`: The start date for the date spine in format YYYY-MM-DD +- `end_date`: The end date for the date spine in format YYYY-MM-DD + +For example, the following query: +```sql linenums="1" +WITH discount_promotion_dates AS ( + @date_spine('day', '2024-01-01', '2024-01-16') +) + +SELECT * FROM discount_promotion_dates +``` + +would be rendered as: + +```sql linenums="1" +WITH "discount_promotion_dates" AS ( + SELECT + "_exploded"."date_day" AS "date_day" + FROM UNNEST(CAST(GENERATE_SERIES(CAST('2024-01-01' AS DATE), CAST('2024-01-16' AS DATE), INTERVAL '1' DAY) AS +DATE[])) AS "_exploded"("date_day") +) +SELECT + "discount_promotion_dates"."date_day" AS "date_day" +FROM "discount_promotion_dates" AS "discount_promotion_dates" +``` + +Note: This is DuckDB SQL and other dialects will be transpiled accordingly. +- Recursive CTEs (common table expressions) will be used for `Redshift / MySQL / MSSQL`. +- For `MSSQL` in particular, there's a recursion limit of approximately 100. If this becomes a problem, you can add an `OPTION (MAXRECURSION 0)` clause after the date spine macro logic to remove the limit. This applies for long date ranges. + +### @RESOLVE_TEMPLATE + +`@resolve_template` is a helper macro intended to be used in situations where you need to gain access to the *components* of the physical object name. It's intended for use in the following situations: + +- Providing explicit control over table locations on a per-model basis for engines that decouple storage and compute (such as Athena, Trino, Spark etc) +- Generating references to engine-specific metadata tables that are derived from the physical table name, such as the [`$properties`](https://trino.io/docs/current/connector/iceberg.html#metadata-tables) metadata table in Trino. 
+ +Under the hood, it uses the `@this_model` variable so it can only be used during the `creating` and `evaluation` [runtime stages](./macro_variables.md#runtime-variables). Attempting to use it at the `loading` runtime stage will result in a no-op. + +The `@resolve_template` macro supports the following arguments: + + - `template` - The string template to render into an AST node + - `mode` - What type of SQLGlot AST node to return after rendering the template. Valid values are `literal` or `table`. Defaults to `literal`. + +The `template` can contain the following placeholders that will be substituted: + + - `@{catalog_name}` - The name of the catalog, eg `datalake` + - `@{schema_name}` - The name of the physical schema that SQLMesh is using for the model version table, eg `sqlmesh__landing` + - `@{table_name}` - The name of the physical table that SQLMesh is using for the model version, eg `landing__customers__2517971505` + +Note the use of the curly brace syntax `@{}` in the template placeholders - learn more [above](#embedding-variables-in-strings). + +The `@resolve_template` macro can be used in a `MODEL` block: + +```sql linenums="1" hl_lines="5" +MODEL ( + name datalake.landing.customers, + ... + physical_properties ( + location = @resolve_template('s3://warehouse-data/@{catalog_name}/prod/@{schema_name}/@{table_name}') + ) +); +-- CREATE TABLE "datalake"."sqlmesh__landing"."landing__customers__2517971505" ... +-- WITH (location = 's3://warehouse-data/datalake/prod/sqlmesh__landing/landing__customers__2517971505') +``` + +And also within a query, using `mode := 'table'`: + +```sql linenums="1" +SELECT * FROM @resolve_template('@{catalog_name}.@{schema_name}.@{table_name}$properties', mode := 'table') +-- SELECT * FROM "datalake"."sqlmesh__landing"."landing__customers__2517971505$properties" +``` + ### @AND `@AND` combines a sequence of operands using the `AND` operator, filtering out any NULL expressions. 
@@ -1262,7 +1591,9 @@ If an argument has a default value, the value is not parsed by SQLGlot before th #### Positional and keyword arguments -In a macro call, the arguments may be provided by position if none are skipped. For example, consider the `add_args()` function - it has three arguments with default values provided in the function definition: +In a macro call, the arguments may be provided by position if none are skipped. + +For example, consider the `add_args()` function - it has three arguments with default values provided in the function definition: ```python linenums="1" from sqlmesh import macro @@ -1279,7 +1610,7 @@ def add_args( An `@add_args` call providing values for all arguments accepts positional arguments like this: `@add_args(5, 6, 7)` (which returns 5 + 6 + 7 = `18`). A call omitting and using the default value for the the final `argument_3` can also use positional arguments: `@add_args(5, 6)` (which returns 5 + 6 + 3 = `14`). -However, skipping an argument requires providing all subsequent argument names (i.e., using "keyword arguments"). For example, skipping the second argument above by just omitting it - `@add_args(5, , 7)` - results in an error. +However, skipping an argument requires specifying the names of subsequent arguments (i.e., using "keyword arguments"). For example, skipping the second argument above by just omitting it - `@add_args(5, , 7)` - results in an error. Unlike Python, SQLMesh keyword arguments must use the special operator `:=`. To skip and use the default value for the second argument above, the call must name the third argument: `@add_args(5, argument_3 := 8)` (which returns 5 + 2 + 8 = `15`). @@ -1430,6 +1761,73 @@ def some_macro(evaluator): ... ``` +#### Accessing model, physical table, and virtual layer view names + +All SQLMesh models have a name in their `MODEL` specification. We refer to that as the model's "unresolved" name because it may not correspond to any specific object in the SQL engine. 
+ +When SQLMesh renders and executes a model, it converts the model name into three forms at different stages: + +1. The *fully qualified* name + + - If the model name is of the form `schema.table`, SQLMesh determines the correct catalog and adds it, like `catalog.schema.table` + - SQLMesh quotes each component of the name using the SQL engine's quoting and case-sensitivity rules, like `"catalog"."schema"."table"` + +2. The *resolved* physical table name + + - The qualified name of the model's underlying physical table + +3. The *resolved* virtual layer view name + + - The qualified name of the model's virtual layer view in the environment where the model is being executed + +You can access any of these three forms in a Python macro through properties of the `evaluator` context object. + +Access the unresolved, fully-qualified name through the `this_model_fqn` property. + +```python linenums="1" +from sqlmesh.core.macros import macro + +@macro() +def some_macro(evaluator): +    # Example: +    # Name in model definition: landing.customers +    # Value returned here: '"datalake"."landing"."customers"' +    unresolved_model_fqn = evaluator.this_model_fqn +    ... +``` + +Access the resolved physical table and virtual layer view names through the `this_model` property. 
+ +The `this_model` property returns different names depending on the runtime stage: + +- `promoting` runtime stage: `this_model` resolves to the virtual layer view name + + - Example + - Model name is `db.test_model` + - `plan` is running in the `dev` environment + - `this_model` resolves to `"catalog"."db__dev"."test_model"` (note the `__dev` suffix in the schema name) + +- All other runtime stages: `this_model` resolves to the physical table name + + - Example + - Model name is `db.test_model` + - `plan` is running in any environment + - `this_model` resolves to `"catalog"."sqlmesh__project"."project__test_model__684351896"` + +```python linenums="1" +from sqlmesh.core.macros import macro + +@macro() +def some_macro(evaluator): + if evaluator.runtime_stage == "promoting": + # virtual layer view name '"catalog"."db__dev"."test_model"' + resolved_name = evaluator.this_model + else: + # physical table name '"catalog"."sqlmesh__project"."project__test_model__684351896"' + resolved_name = evaluator.this_model + ... +``` + #### Accessing model schemas Model schemas can be accessed within a Python macro function through its evaluation context's `column_to_types()` method, if the column types can be statically determined. For instance, a schema of an [external model](../models/external_models.md) can be accessed only after the `sqlmesh create_external_models` command has been executed. @@ -1481,6 +1879,8 @@ Accessing the schema of an upstream model can be useful for various reasons. For Thus, leveraging `columns_to_types` can also enable one to write code according to the [DRY](https://en.wikipedia.org/wiki/Don%27t_repeat_yourself) principle, as a single macro function can implement the transformations instead of creating a different macro for each model of interest. +Note: there may be models whose schema is not available when the project is being loaded, in which case a special placeholder column will be returned, aptly named: `__schema_unavailable_at_load__`. 
In some cases, the macro's implementation will need to account for this placeholder in order to avoid issues due to the schema being unavailable. + #### Accessing snapshots After a SQLMesh project has been successfully loaded, its snapshots can be accessed in Python macro functions and Python models that generate SQL through the `get_snapshot` method of `MacroEvaluator`. @@ -1545,10 +1945,17 @@ The methods are available because the `column` argument is parsed as a SQLGlot [ Column expressions are sub-classes of the [Condition class](https://sqlglot.com/sqlglot/expressions.html#Condition), so they have builder methods like [`between`](https://sqlglot.com/sqlglot/expressions.html#Condition.between) and [`like`](https://sqlglot.com/sqlglot/expressions.html#Condition.like). -#### Metadata only macros as model pre/post-statements -When you first use your macro functions in your models as pre/post-statements, SQLMesh will identify those models as directly modified the next time you create a plan. These models will then need backfills. The same thing applies when you edit or remove these pre/post-statements. If your macro does not have any effect on your models' data and you do not want it to trigger backfills, you can configure your macro to be part of a model's metadata. That way, SQLMesh can still detect changes and create new snapshots for your models when you add, edit, or delete your macro pre/post-statements. To do this, pass in True to the `metadata_only` parameter of the `@macro()` decorator. +#### Macro pre/post-statements -```python linenums="1" +Macro functions may be used to generate pre/post-statements in a model. + +By default, when you first add the pre/post-statement macro functions to a model, SQLMesh will treat those models as directly modified and require a backfill in the next plan. SQLMesh will also treat edits to or removals of pre/post-statement macros as a breaking change. 
+ +If your macro does not affect the data returned by a model and you do not want its addition/editing/removal to trigger a backfill, you can specify in the macro definition that it only affects the model's metadata. SQLMesh will still detect changes and create new snapshots for a model when you add/edit/remove the macro, but it will not view the change as breaking and require a backfill. + +Specify that a macro only affects a model's metadata by setting the `@macro()` decorator's `metadata_only` argument to `True`. For example: + +```python linenums="1" hl_lines="3" from sqlmesh import macro @macro(metadata_only=True) @@ -1574,11 +1981,15 @@ Typed macros in SQLMesh use Python's type hints. Here's a simple example of a ty from sqlmesh import macro @macro() -def repeat_string(evaluator, text: str, count: int) -> str: +def repeat_string(evaluator, text: str, count: int): return text * count ``` -Usage in SQLMesh: +This macro takes two arguments: `text` of type `str` and `count` of type `int`, and it returns a string. + +Without type hints, the inputs are two SQLGlot `exp.Literal` objects you would need to manually convert to Python `str` and `int` types. With type hints, you can work with them as string and integer types directly. + +Let's try to use the macro in a SQLMesh model: ```sql linenums="1" SELECT @@ -1586,7 +1997,44 @@ SELECT FROM some_table; ``` -This macro takes two arguments: `text` of type `str` and `count` of type `int`, and it returns a string. Without type hints, the inputs to the macro would have been two `exp.Literal` objects you would have had to convert to strings and integers manually. +Unfortunately, this model generates an error when rendered: + +``` +Error: Invalid expression / Unexpected token. Line 1, Col: 23. + SQLMesh SQLMesh SQLMesh +``` + +Why? 
The macro returned `SQLMesh SQLMesh SQLMesh` as expected, but that string is not valid SQL in the rendered query: + +```sql linenums="1" hl_lines="2" +SELECT + SQLMesh SQLMesh SQLMesh as repeated_string ### invalid SQL code +FROM some_table; +``` + +The problem is a mismatch between our macro's Python return type `str` and the type expected by the parsed SQL query. + +Recall that SQLMesh macros work by modifying the query's semantic representation. In that representation, a SQLGlot string literal type is expected. SQLMesh will do its best to return the type expected by the query's semantic representation, but that is not possible in all scenarios. + +Therefore, we must explicitly convert the output with SQLGlot's `exp.Literal.string()` method: + +```python linenums="1" hl_lines="5" +from sqlmesh import macro + +@macro() +def repeat_string(evaluator, text: str, count: int): + return exp.Literal.string(text * count) +``` + +Now the query will render with a valid single-quoted string literal: + +```sql linenums="1" +SELECT + 'SQLMesh SQLMesh SQLMesh ' AS "repeated_string" +FROM "some_table" AS "some_table" +``` + +Typed macros coerce the **inputs** to a macro function, but the macro code is responsible for coercing the **output** to the type expected by the query's semantic representation. #### Supported Types @@ -1596,10 +2044,12 @@ SQLMesh supports common Python types for typed macros including: - `int` - `float` - `bool` +- `datetime.datetime` +- `datetime.date` - `SQL` -- When you want the SQL string representation of the argument that's passed in -- `List[T]` - where `T` is any supported type including sqlglot expressions -- `Tuple[T]` - where `T` is any supported type including sqlglot expressions -- `Union[T1, T2, ...]` - where `T1`, `T2`, etc. 
are any supported types including sqlglot expressions +- `list[T]` - where `T` is any supported type including sqlglot expressions +- `tuple[T]` - where `T` is any supported type including sqlglot expressions +- `T1 | T2 | ...` - where `T1`, `T2`, etc. are any supported types including sqlglot expressions We also support SQLGlot expressions as type hints, allowing you to ensure inputs are coerced to the desired SQL AST node your intending on working with. Some useful examples include: @@ -1661,7 +2111,7 @@ FROM some_table; Generics can be nested and are resolved recursively allowing for fairly robust type hinting. -See examples of the coercion function in action in the test suite [here](https://github.com/TobikoData/sqlmesh/blob/main/tests/core/test_macros.py). +See examples of the coercion function in action in the test suite [here](https://github.com/SQLMesh/sqlmesh/blob/main/tests/core/test_macros.py). #### Conclusion diff --git a/docs/concepts/models/external_models.md b/docs/concepts/models/external_models.md index a8557813bc..ef2b39a10c 100644 --- a/docs/concepts/models/external_models.md +++ b/docs/concepts/models/external_models.md @@ -56,6 +56,8 @@ If SQLMesh does not have access to an external table's metadata, the table will In some use-cases such as [isolated systems with multiple gateways](../../guides/isolated_systems.md#multiple-gateways), there are external models that only exist on a certain gateway. +**Gateway names are case-insensitive in external model configurations.** You can specify the gateway name using any case (e.g., `gateway: dev`, `gateway: DEV`, `gateway: Dev`) and SQLMesh will handle the matching correctly. + Consider the following model that queries an external table with a dynamic database based on the current gateway: ``` @@ -70,7 +72,9 @@ FROM @{gateway}_db.external_table; ``` -This table will be named differently depending on which `--gateway` SQLMesh is run with. 
For example: +This table will be named differently depending on which `--gateway` SQLMesh is run with (learn more about the curly brace `@{gateway}` syntax [here](../../concepts/macros/sqlmesh_macros.md#embedding-variables-in-strings)). + +For example: - `sqlmesh --gateway dev plan` - SQLMesh will try to query `dev_db.external_table` - `sqlmesh --gateway prod plan` - SQLMesh will try to query `prod_db.external_table` @@ -98,7 +102,7 @@ This example demonstrates the structure of a `external_models.yaml` file: column_d: float - name: external_db.gateway_specific_external_table description: Another external table that only exists when the gateway is set to "test" - gateway: test + gateway: test # Case-insensitive - could also be "TEST", "Test", etc. columns: column_e: int column_f: varchar diff --git a/docs/concepts/models/managed_models.md b/docs/concepts/models/managed_models.md index 5e167e80f5..786c6aa89d 100644 --- a/docs/concepts/models/managed_models.md +++ b/docs/concepts/models/managed_models.md @@ -7,10 +7,14 @@ For supported engines, we expose this functionality through Managed models. This Due to this, managed models would typically be built off an [External Model](./external_models.md) rather than another SQLMesh model. Since SQLMesh already ensures that models it's tracking are kept up to date, the main benefit of managed models comes when they read from external tables that arent tracked by SQLMesh. +!!! warning "Not supported in Python models" + + Python models do not support the `MANAGED` [model kind](./model_kinds.md) - use a SQL model instead. + ## Difference from materialized views The difference between an Managed model and a materialized view is down to semantics and in some engines there is no difference. -SQLMesh has support for [materialized views](./model_kinds#materialized-views) already. 
However, depending on the engine, these are subject to some limitations, such as: +SQLMesh has support for [materialized views](../model_kinds#materialized-views) already. However, depending on the engine, these are subject to some limitations, such as: - A Materialized View query can only be derived from a single base table - The Materialized View is not automatically maintained by the engine. To refresh the data, a `REFRESH MATERIALIZED VIEW` or equivalent command must be issued @@ -34,6 +38,11 @@ However, there is usually extra vendor-imposed costs associated with Managed mod Therefore, we try to not create managed tables unnecessarily. For example, in [forward-only plans](../plans.md#forward-only-change) we just create a normal table to preview the changes and only re-create the managed table on deployment to prod. +!!! warning + Due to the use of normal tables for dev previews, it is possible to write a query that uses features that are available to normal tables in the target engine but not managed tables. This could result in a scenario where a plan works in a dev environment but fails when deployed to production. + + We believe the cost savings are worth it, however please [reach out](https://tobikodata.com/slack) if this causes problems for you. + ## Supported Engines SQLMesh supports managed models in the following database engines: @@ -79,9 +88,9 @@ AS SELECT FROM raw_events ``` -!!! info +!!! note - Note that SQLMesh will not create intervals and run this model for each interval, so there is no need to add a WHERE clause with date filters like you would for a normal incremental model. How the data in this model is refreshed is completely up to Snowflake. + SQLMesh will not create intervals and run this model for each interval, so there is no need to add a WHERE clause with date filters like you would for a normal incremental model. How the data in this model is refreshed is completely up to Snowflake. 
#### Table properties diff --git a/docs/concepts/models/model_kinds.md b/docs/concepts/models/model_kinds.md index d529a3de64..d01cc738a6 100644 --- a/docs/concepts/models/model_kinds.md +++ b/docs/concepts/models/model_kinds.md @@ -2,6 +2,8 @@ This page describes the kinds of [models](./overview.md) SQLMesh supports, which determine how the data for a model is loaded. +Find information about all model kind configuration parameters in the [model configuration reference page](../../reference/model_configuration.md). + ## INCREMENTAL_BY_TIME_RANGE Models of the `INCREMENTAL_BY_TIME_RANGE` kind are computed incrementally based on a time range. This is an optimal choice for datasets in which records are captured over time and represent immutable facts such as events, logs, or transactions. Using this kind for appropriate datasets typically results in significant cost and time savings. @@ -10,7 +12,7 @@ Only missing time intervals are processed during each execution for `INCREMENTAL An `INCREMENTAL_BY_TIME_RANGE` model has two requirements that other models do not: it must know which column contains the time data it will use to filter the data by time range, and it must contain a `WHERE` clause that filters the upstream data by time. -The name of the column containing time data is specified in the model's `MODEL` DDL. It is specified ih the DDL `kind` specification's `time_column` key. This example shows the `MODEL` DDL for an `INCREMENTAL_BY_TIME_RANGE` model that stores time data in the "event_date" column: +The name of the column containing time data is specified in the model's `MODEL` DDL. It is specified in the DDL `kind` specification's `time_column` key. 
This example shows the `MODEL` DDL for an `INCREMENTAL_BY_TIME_RANGE` model that stores time data in the "event_date" column: ```sql linenums="1" MODEL ( @@ -21,8 +23,308 @@ MODEL ( ); ``` + In addition to specifying a time column in the `MODEL` DDL, the model's query must contain a `WHERE` clause that filters the upstream records by time range. SQLMesh provides special macros that represent the start and end of the time range being processed: `@start_date` / `@end_date` and `@start_ds` / `@end_ds`. Refer to [Macros](../macros/macro_variables.md) for more information. +??? "Example SQL sequence when applying this model kind (ex: BigQuery)" + This is borrowed from the full walkthrough: [Incremental by Time Range](../../examples/incremental_time_full_walkthrough.md) + + Create a model with the following definition and run `sqlmesh plan dev`: + + ```sql + MODEL ( + name demo.incrementals_demo, + kind INCREMENTAL_BY_TIME_RANGE ( + -- How does this model kind behave? + -- DELETE by time range, then INSERT + time_column transaction_date, + + -- How do I handle late-arriving data? + -- Handle late-arriving events for the past 2 (2*1) days based on cron + -- interval. Each time it runs, it will process today, yesterday, and + -- the day before yesterday. + lookback 2, + ), + + -- Don't backfill data before this date + start '2024-10-25', + + -- What schedule should I run these at? + -- Daily at Midnight UTC + cron '@daily', + + -- Good documentation for the primary key + grain transaction_id, + + -- How do I test this data? + -- Validate that the `transaction_id` primary key values are both unique + -- and non-null. Data audit tests only run for the processed intervals, + -- not for the entire table. + -- audits ( + -- UNIQUE_VALUES(columns = (transaction_id)), + -- NOT_NULL(columns = (transaction_id)) + -- ) + ); + + WITH sales_data AS ( + SELECT + transaction_id, + product_id, + customer_id, + transaction_amount, + -- How do I account for UTC vs. 
PST (California baby) timestamps? + -- Make sure all time columns are in UTC and convert them to PST in the + -- presentation layer downstream. + transaction_timestamp, + payment_method, + currency + FROM sqlmesh-public-demo.tcloud_raw_data.sales -- Source A: sales data + -- How do I make this run fast and only process the necessary intervals? + -- Use our date macros that will automatically run the necessary intervals. + -- Because SQLMesh manages state, it will know what needs to run each time + -- you invoke `sqlmesh run`. + WHERE transaction_timestamp BETWEEN @start_dt AND @end_dt + ), + + product_usage AS ( + SELECT + product_id, + customer_id, + last_usage_date, + usage_count, + feature_utilization_score, + user_segment + FROM sqlmesh-public-demo.tcloud_raw_data.product_usage -- Source B + -- Include usage data from the 30 days before the interval + WHERE last_usage_date BETWEEN DATE_SUB(@start_dt, INTERVAL 30 DAY) AND @end_dt + ) + + SELECT + s.transaction_id, + s.product_id, + s.customer_id, + s.transaction_amount, + -- Extract the date from the timestamp to partition by day + DATE(s.transaction_timestamp) as transaction_date, + -- Convert timestamp to PST using a SQL function in the presentation layer for end users + DATETIME(s.transaction_timestamp, 'America/Los_Angeles') as transaction_timestamp_pst, + s.payment_method, + s.currency, + -- Product usage metrics + p.last_usage_date, + p.usage_count, + p.feature_utilization_score, + p.user_segment, + -- Derived metrics + CASE + WHEN p.usage_count > 100 AND p.feature_utilization_score > 0.8 THEN 'Power User' + WHEN p.usage_count > 50 THEN 'Regular User' + WHEN p.usage_count IS NULL THEN 'New User' + ELSE 'Light User' + END as user_type, + -- Time since last usage + DATE_DIFF(s.transaction_timestamp, p.last_usage_date, DAY) as days_since_last_usage + FROM sales_data s + LEFT JOIN product_usage p + ON s.product_id = p.product_id + AND s.customer_id = p.customer_id + ``` + + SQLMesh will execute this SQL to 
create a versioned table in the physical layer. Note that the table's version fingerprint, `50975949`, is part of the table name. + + ```sql + CREATE TABLE IF NOT EXISTS `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__50975949` ( + `transaction_id` STRING, + `product_id` STRING, + `customer_id` STRING, + `transaction_amount` NUMERIC, + `transaction_date` DATE OPTIONS (description='We extract the date from the timestamp to partition by day'), + `transaction_timestamp_pst` DATETIME OPTIONS (description='Convert this to PST using a SQL function'), + `payment_method` STRING, + `currency` STRING, + `last_usage_date` TIMESTAMP, + `usage_count` INT64, + `feature_utilization_score` FLOAT64, + `user_segment` STRING, + `user_type` STRING OPTIONS (description='Derived metrics'), + `days_since_last_usage` INT64 OPTIONS (description='Time since last usage') + ) + PARTITION BY `transaction_date` + ``` + + SQLMesh will validate the SQL before processing data (note the `WHERE FALSE LIMIT 0` and the placeholder timestamps). 
+ + ```sql + WITH `sales_data` AS ( + SELECT + `sales`.`transaction_id` AS `transaction_id`, + `sales`.`product_id` AS `product_id`, + `sales`.`customer_id` AS `customer_id`, + `sales`.`transaction_amount` AS `transaction_amount`, + `sales`.`transaction_timestamp` AS `transaction_timestamp`, + `sales`.`payment_method` AS `payment_method`, + `sales`.`currency` AS `currency` + FROM `sqlmesh-public-demo`.`tcloud_raw_data`.`sales` AS `sales` + WHERE ( + `sales`.`transaction_timestamp` <= CAST('1970-01-01 23:59:59.999999+00:00' AS TIMESTAMP) AND + `sales`.`transaction_timestamp` >= CAST('1970-01-01 00:00:00+00:00' AS TIMESTAMP)) AND + FALSE + ), + `product_usage` AS ( + SELECT + `product_usage`.`product_id` AS `product_id`, + `product_usage`.`customer_id` AS `customer_id`, + `product_usage`.`last_usage_date` AS `last_usage_date`, + `product_usage`.`usage_count` AS `usage_count`, + `product_usage`.`feature_utilization_score` AS `feature_utilization_score`, + `product_usage`.`user_segment` AS `user_segment` + FROM `sqlmesh-public-demo`.`tcloud_raw_data`.`product_usage` AS `product_usage` + WHERE ( + `product_usage`.`last_usage_date` <= CAST('1970-01-01 23:59:59.999999+00:00' AS TIMESTAMP) AND + `product_usage`.`last_usage_date` >= CAST('1969-12-02 00:00:00+00:00' AS TIMESTAMP) + ) AND + FALSE + ) + + SELECT + `s`.`transaction_id` AS `transaction_id`, + `s`.`product_id` AS `product_id`, + `s`.`customer_id` AS `customer_id`, + CAST(`s`.`transaction_amount` AS NUMERIC) AS `transaction_amount`, + DATE(`s`.`transaction_timestamp`) AS `transaction_date`, + DATETIME(`s`.`transaction_timestamp`, 'America/Los_Angeles') AS `transaction_timestamp_pst`, + `s`.`payment_method` AS `payment_method`, + `s`.`currency` AS `currency`, + `p`.`last_usage_date` AS `last_usage_date`, + `p`.`usage_count` AS `usage_count`, + `p`.`feature_utilization_score` AS `feature_utilization_score`, + `p`.`user_segment` AS `user_segment`, + CASE + WHEN `p`.`feature_utilization_score` > 0.8 AND 
`p`.`usage_count` > 100 THEN 'Power User' + WHEN `p`.`usage_count` > 50 THEN 'Regular User' + WHEN `p`.`usage_count` IS NULL THEN 'New User' + ELSE 'Light User' + END AS `user_type`, + DATE_DIFF(`s`.`transaction_timestamp`, `p`.`last_usage_date`, DAY) AS `days_since_last_usage` + FROM `sales_data` AS `s` + LEFT JOIN `product_usage` AS `p` + ON `p`.`customer_id` = `s`.`customer_id` AND + `p`.`product_id` = `s`.`product_id` + WHERE FALSE + LIMIT 0 + ``` + + SQLMesh will merge data into the empty table. + + ```sql + MERGE INTO `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__50975949` AS `__MERGE_TARGET__` USING ( + WITH `sales_data` AS ( + SELECT + `transaction_id`, + `product_id`, + `customer_id`, + `transaction_amount`, + `transaction_timestamp`, + `payment_method`, + `currency` + FROM `sqlmesh-public-demo`.`tcloud_raw_data`.`sales` AS `sales` + WHERE `transaction_timestamp` BETWEEN CAST('2024-10-25 00:00:00+00:00' AS TIMESTAMP) AND CAST('2024-11-04 23:59:59.999999+00:00' AS TIMESTAMP) + ), + `product_usage` AS ( + SELECT + `product_id`, + `customer_id`, + `last_usage_date`, + `usage_count`, + `feature_utilization_score`, + `user_segment` + FROM `sqlmesh-public-demo`.`tcloud_raw_data`.`product_usage` AS `product_usage` + WHERE `last_usage_date` BETWEEN DATE_SUB(CAST('2024-10-25 00:00:00+00:00' AS TIMESTAMP), INTERVAL '30' DAY) AND CAST('2024-11-04 23:59:59.999999+00:00' AS TIMESTAMP) + ) + + SELECT + `transaction_id`, + `product_id`, + `customer_id`, + `transaction_amount`, + `transaction_date`, + `transaction_timestamp_pst`, + `payment_method`, + `currency`, + `last_usage_date`, + `usage_count`, + `feature_utilization_score`, + `user_segment`, + `user_type`, + `days_since_last_usage` + FROM ( + SELECT + `s`.`transaction_id` AS `transaction_id`, + `s`.`product_id` AS `product_id`, + `s`.`customer_id` AS `customer_id`, + `s`.`transaction_amount` AS `transaction_amount`, + DATE(`s`.`transaction_timestamp`) AS `transaction_date`, + 
DATETIME(`s`.`transaction_timestamp`, 'America/Los_Angeles') AS `transaction_timestamp_pst`, + `s`.`payment_method` AS `payment_method`, + `s`.`currency` AS `currency`, + `p`.`last_usage_date` AS `last_usage_date`, + `p`.`usage_count` AS `usage_count`, + `p`.`feature_utilization_score` AS `feature_utilization_score`, + `p`.`user_segment` AS `user_segment`, + CASE + WHEN `p`.`usage_count` > 100 AND `p`.`feature_utilization_score` > 0.8 THEN 'Power User' + WHEN `p`.`usage_count` > 50 THEN 'Regular User' + WHEN `p`.`usage_count` IS NULL THEN 'New User' + ELSE 'Light User' + END AS `user_type`, + DATE_DIFF(`s`.`transaction_timestamp`, `p`.`last_usage_date`, DAY) AS `days_since_last_usage` + FROM `sales_data` AS `s` + LEFT JOIN `product_usage` AS `p` + ON `s`.`product_id` = `p`.`product_id` + AND `s`.`customer_id` = `p`.`customer_id` + ) AS `_subquery` + WHERE `transaction_date` BETWEEN CAST('2024-10-25' AS DATE) AND CAST('2024-11-04' AS DATE) + ) AS `__MERGE_SOURCE__` + ON FALSE + WHEN NOT MATCHED BY SOURCE AND `transaction_date` BETWEEN CAST('2024-10-25' AS DATE) AND CAST('2024-11-04' AS DATE) THEN DELETE + WHEN NOT MATCHED THEN + INSERT ( + `transaction_id`, `product_id`, `customer_id`, `transaction_amount`, `transaction_date`, `transaction_timestamp_pst`, + `payment_method`, `currency`, `last_usage_date`, `usage_count`, `feature_utilization_score`, `user_segment`, `user_type`, + `days_since_last_usage` + ) + VALUES ( + `transaction_id`, `product_id`, `customer_id`, `transaction_amount`, `transaction_date`, `transaction_timestamp_pst`, + `payment_method`, `currency`, `last_usage_date`, `usage_count`, `feature_utilization_score`, `user_segment`, `user_type`, + `days_since_last_usage` + ) + ``` + + SQLMesh will create a suffixed `__dev` schema based on the name of the plan environment. 
+    SQLMesh will create a view in the virtual layer pointing to the versioned table in the physical layer.
+ The time column is used to determine which records will be overwritten during data [restatement](../plans.md#restatement-plans) and provides a partition key for engines that support partitioning (such as Apache Spark). The name of the time column is specified in the `MODEL` DDL `kind` specification: ```sql linenums="1" hl_lines="4" @@ -64,7 +370,10 @@ MODEL ( ) ); ``` -**Note:** The time format should be defined using the same SQL dialect as the one used to define the model's query. + +!!! note + + The time format should be defined using the same SQL dialect as the one used to define the model's query. SQLMesh also uses the time column to automatically append a time range filter to the model's query at runtime, which prevents records that are not part of the target interval from being stored. This is a safety mechanism that prevents unintentionally overwriting unrelated records when handling late-arriving data. @@ -99,10 +408,29 @@ WHERE AND event_date BETWEEN @start_ds AND @end_ds; -- `event_date` time column filter automatically added by SQLMesh ``` +### Partitioning + +By default, we ensure that the `time_column` is part of the [partitioned_by](./overview.md#partitioned_by) property of the model so that it forms part of the partition key and allows the database engine to do partition pruning. If it is not explicitly listed in the Model definition, we will automatically add it. + +However, this may be undesirable if you want to exclusively partition on another column or you want to partition on something like `month(time_column)` but the engine you're using doesnt support partitioning based on expressions. 
+To opt out of this behavior, you can set `partition_by_time_column false` like so:
- -This model kind is designed for the scenario where data rows should be loaded and updated as a group based on their shared value for the partitioning key. This kind may be used with any SQL engine; SQLMesh will automatically create partitioned tables on engines that support explicit table partitioning (e.g., [BigQuery](https://cloud.google.com/bigquery/docs/creating-partitioned-tables), [Databricks](https://docs.databricks.com/en/sql/language-manual/sql-ref-partition.html)). - -If a partitioning key in newly loaded data is not present in the model table, the new partitioning key and its data rows are inserted. If a partitioning key in newly loaded data is already present in the model table, **all the partitioning key's existing data rows in the model table are replaced** with the partitioning key's data rows in the newly loaded data. If a partitioning key is present in the model table but not present in the newly loaded data, the partitioning key's existing data rows are not modified and remain in the model table. - -This kind is a good fit for datasets that have the following traits: - -* The dataset's records can be grouped by a partitioning key. -* Each record has a partitioning key associated with it. -* It is appropriate to upsert records, so existing records can be overwritten by new arrivals when their partitioning keys match. -* All existing records associated with a given partitioning key can be removed or overwritten when any new record has the partitioning key value. - -The column defining the partitioning key is specified in the model's `MODEL` DDL `partitioned_by` key. 
This example shows the `MODEL` DDL for an `INCREMENTAL_BY_PARTITION` model whose partition key is the row's value for the `region` column: - -```sql linenums="1" hl_lines="4" -MODEL ( - name db.events, - kind INCREMENTAL_BY_PARTITION, - partitioned_by region, -); -``` - -Compound partition keys are also supported, such as `region` and `department`: - -```sql linenums="1" hl_lines="4" -MODEL ( - name db.events, - kind INCREMENTAL_BY_PARTITION, - partitioned_by (region, department), -); -``` - -Date and/or timestamp column expressions are also supported (varies by SQL engine). This BigQuery example's partition key is based on the month each row's `event_date` occurred: - -```sql linenums="1" hl_lines="4" -MODEL ( - name db.events, - kind INCREMENTAL_BY_PARTITION, - partitioned_by DATETIME_TRUNC(event_date, MONTH) -); -``` +## INCREMENTAL_BY_UNIQUE_KEY -**Note**: Partial data [restatement](../plans.md#restatement-plans) is not supported for this model kind, which means that the entire table will be recreated from scratch if restated. This may lead to data loss, so data restatement is disabled for models of this kind by default. +Models of the `INCREMENTAL_BY_UNIQUE_KEY` kind are computed incrementally based on a key. 
-### Materialization strategy -Depending on the target engine, models of the `INCREMENTAL_BY_PARTITION` kind are materialized using the following strategies: +They insert or update rows based on these rules: -| Engine | Strategy | -|------------|-----------------------------------------| -| Databricks | REPLACE WHERE by partitioning key | -| Spark | INSERT OVERWRITE by partitioning key | -| Snowflake | DELETE by partitioning key, then INSERT | -| BigQuery | DELETE by partitioning key, then INSERT | -| Redshift | DELETE by partitioning key, then INSERT | -| Postgres | DELETE by partitioning key, then INSERT | -| DuckDB | DELETE by partitioning key, then INSERT | +- If a key in newly loaded data is not present in the model table, the new data row is inserted. +- If a key in newly loaded data is already present in the model table, the existing row is updated with the new data. +- If a key is present in the model table but not present in the newly loaded data, its row is not modified and remains in the model table. -## INCREMENTAL_BY_UNIQUE_KEY +!!! important "Prevent duplicated keys" -Models of the `INCREMENTAL_BY_UNIQUE_KEY` kind are computed incrementally based on a key that is unique for each data row. + If you do not want duplicated keys in the model table, you must ensure the model query does not return rows with duplicate keys. -If a key in newly loaded data is not present in the model table, the new data row is inserted. If a key in newly loaded data is already present in the model table, the existing row is updated with the new data. If a key is present in the model table but not present in the newly loaded data, its row is not modified and remains in the model table. + SQLMesh does not automatically detect or prevent duplicates. 
This kind is a good fit for datasets that have the following traits: @@ -217,7 +495,7 @@ MODEL ( ); ``` -`INCREMENTAL_BY_UNIQUE_KEY` model kinds can also filter upstream records by time range using a SQL `WHERE` clause and the `@start_date`, `@end_date` or other macros (similar to the [INCREMENTAL_BY_TIME_RANGE](#incremental_by_time_range) kind): +`INCREMENTAL_BY_UNIQUE_KEY` model kinds can also filter upstream records by time range using a SQL `WHERE` clause and the `@start_date`, `@end_date` or other macro variables (similar to the [INCREMENTAL_BY_TIME_RANGE](#incremental_by_time_range) kind). Note that SQLMesh macro time variables are in the UTC time zone. ```sql linenums="1" hl_lines="6-7" SELECT name::TEXT as name, @@ -228,6 +506,66 @@ WHERE event_date BETWEEN @start_date AND @end_date; ``` +??? "Example SQL sequence when applying this model kind (ex: BigQuery)" + + Create a model with the following definition and run `sqlmesh plan dev`: + + ```sql + MODEL ( + name demo.incremental_by_unique_key_example, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key id + ), + start '2020-01-01', + cron '@daily', + ); + + SELECT + id, + item_id, + event_date + FROM demo.seed_model + WHERE + event_date BETWEEN @start_date AND @end_date + ``` + + SQLMesh will execute this SQL to create a versioned table in the physical layer. Note that the table's version fingerprint, `1161945221`, is part of the table name. + + ```sql + CREATE TABLE IF NOT EXISTS `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incremental_by_unique_key_example__1161945221` (`id` INT64, `item_id` INT64, `event_date` DATE) + ``` + + SQLMesh will validate the model's query before processing data (note the `FALSE LIMIT 0` in the `WHERE` statement and the placeholder dates). 
+ + ```sql + SELECT `seed_model`.`id` AS `id`, `seed_model`.`item_id` AS `item_id`, `seed_model`.`event_date` AS `event_date` + FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__seed_model__2834544882` AS `seed_model` + WHERE (`seed_model`.`event_date` <= CAST('1970-01-01' AS DATE) AND `seed_model`.`event_date` >= CAST('1970-01-01' AS DATE)) AND FALSE LIMIT 0 + ``` + + SQLMesh will create a versioned table in the physical layer. + + ```sql + CREATE OR REPLACE TABLE `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incremental_by_unique_key_example__1161945221` AS + SELECT CAST(`id` AS INT64) AS `id`, CAST(`item_id` AS INT64) AS `item_id`, CAST(`event_date` AS DATE) AS `event_date` + FROM (SELECT `seed_model`.`id` AS `id`, `seed_model`.`item_id` AS `item_id`, `seed_model`.`event_date` AS `event_date` + FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__seed_model__2834544882` AS `seed_model` + WHERE `seed_model`.`event_date` <= CAST('2024-10-30' AS DATE) AND `seed_model`.`event_date` >= CAST('2020-01-01' AS DATE)) AS `_subquery` + ``` + + SQLMesh will create a suffixed `__dev` schema based on the name of the plan environment. + + ```sql + CREATE SCHEMA IF NOT EXISTS `sqlmesh-public-demo`.`demo__dev` + ``` + + SQLMesh will create a view in the virtual layer pointing to the versioned table in the physical layer. + + ```sql + CREATE OR REPLACE VIEW `sqlmesh-public-demo`.`demo__dev`.`incremental_by_unique_key_example` AS + SELECT * FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incremental_by_unique_key_example__1161945221` + ``` + **Note:** Models of the `INCREMENTAL_BY_UNIQUE_KEY` kind are inherently [non-idempotent](../glossary.md#idempotency), which should be taken into consideration during data [restatement](../plans.md#restatement-plans). As a result, partial data restatement is not supported for this model kind, which means that the entire table will be recreated from scratch if restated. 
### Unique Key Expressions @@ -252,21 +590,73 @@ MODEL ( name db.employees, kind INCREMENTAL_BY_UNIQUE_KEY ( unique_key name, - when_matched WHEN MATCHED THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary) + when_matched ( + WHEN MATCHED THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary) + ) ) ); ``` The `source` and `target` aliases are required when using the `when_matched` expression in order to distinguish between the source and target columns. +Multiple `WHEN MATCHED` expressions can also be provided. Ex: + +```sql linenums="1" hl_lines="5-6" +MODEL ( + name db.employees, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key name, + when_matched ( + WHEN MATCHED AND source.value IS NULL THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary) + WHEN MATCHED THEN UPDATE SET target.title = COALESCE(source.title, target.title) + ) + ) +); +``` + **Note**: `when_matched` is only available on engines that support the `MERGE` statement. Currently supported engines include: * BigQuery * Databricks * Postgres +* Redshift * Snowflake * Spark +In Redshift's case, to enable the use of the native `MERGE` statement, you need to pass the `enable_merge` flag in the connection and set it to `true`. It is disabled by default. + +```yaml linenums="1" +gateways: + redshift: + connection: + type: redshift + enable_merge: true +``` + +Redshift supports only the `UPDATE` or `DELETE` actions for the `WHEN MATCHED` clause and does not allow multiple `WHEN MATCHED` expressions. For further information, refer to the [Redshift documentation](https://docs.aws.amazon.com/redshift/latest/dg/r_MERGE.html#r_MERGE-parameters). + +### Merge Filter Expression + +The `MERGE` statement typically induces a full table scan of the existing table, which can be problematic with large data volumes. + +Prevent a full table scan by passing filtering conditions to the `merge_filter` parameter. 
+ +The `merge_filter` accepts a single or a conjunction of predicates to be used in the `ON` clause of the `MERGE` operation: + +```sql linenums="1" hl_lines="5" +MODEL ( + name db.employee_contracts, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key id, + merge_filter source._operation IS NULL AND target.contract_date > dateadd(day, -7, current_date) + ) +); +``` + +Similar to `when_matched`, the `source` and `target` aliases are used to distinguish between the source and target tables. + +If an existing dbt project uses the [incremental_predicates](https://docs.getdbt.com/docs/build/incremental-strategy#about-incremental_predicates) functionality, SQLMesh will automatically convert them into the equivalent `merge_filter` specification. + ### Materialization strategy Depending on the target engine, models of the `INCREMENTAL_BY_UNIQUE_KEY` kind are materialized using the following strategies: @@ -301,6 +691,64 @@ FROM db.employees GROUP BY title; ``` +??? "Example SQL sequence when applying this model kind (ex: BigQuery)" + + Create a model with the following definition and run `sqlmesh plan dev`: + + ```sql + MODEL ( + name demo.full_model_example, + kind FULL, + cron '@daily', + grain item_id, + ); + + SELECT + item_id, + COUNT(DISTINCT id) AS num_orders + FROM demo.incremental_model + GROUP BY + item_id + ``` + + SQLMesh will execute this SQL to create a versioned table in the physical layer. Note that the table's version fingerprint, `2345651858`, is part of the table name. + + ```sql + CREATE TABLE IF NOT EXISTS `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__full_model_example__2345651858` (`item_id` INT64, `num_orders` INT64) + ``` + + SQLMesh will validate the model's query before processing data (note the `WHERE FALSE` and `LIMIT 0`). 
+ + ```sql + SELECT `incremental_model`.`item_id` AS `item_id`, COUNT(DISTINCT `incremental_model`.`id`) AS `num_orders` + FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incremental_model__89556012` AS `incremental_model` + WHERE FALSE + GROUP BY `incremental_model`.`item_id` LIMIT 0 + ``` + + SQLMesh will create a versioned table in the physical layer. + + ```sql + CREATE OR REPLACE TABLE `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__full_model_example__2345651858` AS + SELECT CAST(`item_id` AS INT64) AS `item_id`, CAST(`num_orders` AS INT64) AS `num_orders` + FROM (SELECT `incremental_model`.`item_id` AS `item_id`, COUNT(DISTINCT `incremental_model`.`id`) AS `num_orders` + FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incremental_model__89556012` AS `incremental_model` + GROUP BY `incremental_model`.`item_id`) AS `_subquery` + ``` + + SQLMesh will create a suffixed `__dev` schema based on the name of the plan environment. + + ```sql + CREATE SCHEMA IF NOT EXISTS `sqlmesh-public-demo`.`demo__dev` + ``` + + SQLMesh will create a view in the virtual layer pointing to the versioned table in the physical layer. + + ```sql + CREATE OR REPLACE VIEW `sqlmesh-public-demo`.`demo__dev`.`full_model_example` AS + SELECT * FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__full_model_example__2345651858` + ``` + ### Materialization strategy Depending on the target engine, models of the `FULL` kind are materialized using the following strategies: @@ -321,8 +769,11 @@ The `VIEW` kind is different, because no data is actually written during model e **Note:** `VIEW` is the default model kind if kind is not specified. +**Note:** Python models do not support the `VIEW` model kind - use a SQL model instead. + **Note:** With this kind, the model's query is evaluated every time the model is referenced in a downstream query. 
This may incur undesirable compute cost and time in cases where the model's query is compute-intensive, or when the model is referenced in many downstream queries. + This example specifies a `VIEW` model kind: ```sql linenums="1" hl_lines="3" MODEL ( @@ -335,6 +786,42 @@ SELECT FROM db.employees; ``` +??? "Example SQL sequence when applying this model kind (ex: BigQuery)" + + Create a model with the following definition and run `sqlmesh plan dev`: + + ```sql + MODEL ( + name demo.example_view, + kind VIEW, + cron '@daily', + ); + + SELECT + 'hello there' as a_column + ``` + + SQLMesh will execute this SQL to create a versioned view in the physical layer. Note that the view's version fingerprint, `1024042926`, is part of the view name. + + ```sql + CREATE OR REPLACE VIEW `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__example_view__1024042926` + (`a_column`) AS SELECT 'hello there' AS `a_column` + ``` + + SQLMesh will create a suffixed `__dev` schema based on the name of the plan environment. + + ```sql + CREATE SCHEMA IF NOT EXISTS `sqlmesh-public-demo`.`demo__dev` + ``` + + SQLMesh will create a view in the virtual layer pointing to the versioned view in the physical layer. + + ```sql + CREATE OR REPLACE VIEW `sqlmesh-public-demo`.`demo__dev`.`example_view` AS + SELECT * FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__example_view__1024042926` + ``` + + ### Materialized Views The `VIEW` model kind can be configured to represent a materialized view by setting the `materialized` flag to `true`: ```sql linenums="1" hl_lines="4" @@ -357,7 +844,9 @@ During the evaluation of a model of this kind, the view will be replaced or recr ## EMBEDDED Embedded models are a way to share common logic between different models of other kinds. -There are no data assets (tables or views) associated with `EMBEDDED` models in the data warehouse. Instead, an `EMBEDDED` model's query is injected directly into the query of each downstream model that references it. 
+There are no data assets (tables or views) associated with `EMBEDDED` models in the data warehouse. Instead, an `EMBEDDED` model's query is injected directly into the query of each downstream model that references it, as a subquery. + +**Note:** Python models do not support the `EMBEDDED` model kind - use a SQL model instead. This example specifies a `EMBEDDED` model kind: ```sql linenums="1" hl_lines="3" @@ -374,6 +863,70 @@ FROM db.employees; ## SEED The `SEED` model kind is used to specify [seed models](./seed_models.md) for using static CSV datasets in your SQLMesh project. +**Notes:** + +- Seed models are loaded only once unless the SQL model and/or seed file is updated. +- Python models do not support the `SEED` model kind - use a SQL model instead. + +??? "Example SQL sequence when applying this model kind (ex: BigQuery)" + + Create a model with the following definition and run `sqlmesh plan dev`: + + ```sql + MODEL ( + name demo.seed_example, + kind SEED ( + path '../../seeds/seed_example.csv' + ), + columns ( + id INT64, + item_id INT64, + event_date DATE + ), + grain (id, event_date) + ) + ``` + + SQLMesh will execute this SQL to create a versioned table in the physical layer. Note that the table's version fingerprint, `3038173937`, is part of the table name. + + ```sql + CREATE TABLE IF NOT EXISTS `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__seed_example__3038173937` (`id` INT64, `item_id` INT64, `event_date` DATE) + ``` + + SQLMesh will upload the seed as a temp table in the physical layer. + + ```sql + sqlmesh-public-demo.sqlmesh__demo.__temp_demo__seed_example__3038173937_9kzbpld7 + ``` + + SQLMesh will create a versioned table in the physical layer from the temp table. 
+ + ```sql + CREATE OR REPLACE TABLE `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__seed_example__3038173937` AS + SELECT CAST(`id` AS INT64) AS `id`, CAST(`item_id` AS INT64) AS `item_id`, CAST(`event_date` AS DATE) AS `event_date` + FROM (SELECT `id`, `item_id`, `event_date` + FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`__temp_demo__seed_example__3038173937_9kzbpld7`) AS `_subquery` + ``` + + SQLMesh will drop the temp table in the physical layer. + + ```sql + DROP TABLE IF EXISTS `sqlmesh-public-demo`.`sqlmesh__demo`.`__temp_demo__seed_example__3038173937_9kzbpld7` + ``` + + SQLMesh will create a suffixed `__dev` schema based on the name of the plan environment. + + ```sql + CREATE SCHEMA IF NOT EXISTS `sqlmesh-public-demo`.`demo__dev` + ``` + + SQLMesh will create a view in the virtual layer pointing to the versioned table in the physical layer. + + ```sql + CREATE OR REPLACE VIEW `sqlmesh-public-demo`.`demo__dev`.`seed_example` AS + SELECT * FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__seed_example__3038173937` + ``` + ## SCD Type 2 SCD Type 2 is a model kind that supports [slowly changing dimensions](https://en.wikipedia.org/wiki/Slowly_changing_dimension#Type_2:_add_new_row) (SCDs) in your SQLMesh project. SCDs are a common pattern in data warehousing that allow you to track changes to records over time. @@ -388,7 +941,7 @@ There are two ways to tracking changes: By Time (Recommended) or By Column. ### SCD Type 2 By Time (Recommended) -SCD Type 2 By Time supports sourcing from tables that have an "Updated At" timestamp defined in the table that tells you when a given was last updated. +SCD Type 2 By Time supports sourcing from tables that have an "Updated At" timestamp defined in the table that tells you when a given record was last updated. This is the recommended way since this "Updated At" gives you a precise time when the record was last updated and therefore improves the accuracy of the SCD Type 2 table that is produced. 
This example specifies a `SCD_TYPE_2_BY_TIME` model kind: @@ -516,12 +1069,7 @@ TABLE db.menu_items ( A hard delete is when a record no longer exists in the source table. When this happens, -If `invalidate_hard_deletes` is set to `true` (default): - -* `valid_to` column will be set to the time when the SQLMesh run started that detected the missing record (called `execution_time`). -* If the record is added back, then the `valid_to` column will remain unchanged. - -If `invalidate_hard_deletes` is set to `false`: +If `invalidate_hard_deletes` is set to `false` (default): * `valid_to` column will continue to be set to `NULL` (therefore still considered "valid") * If the record is added back, then the `valid_to` column will be set to the `valid_from` of the new record. @@ -531,13 +1079,18 @@ When a record is added back, the new record will be inserted into the table with * SCD_TYPE_2_BY_TIME: the largest of either the `updated_at` timestamp of the new record or the `valid_from` timestamp of the deleted record in the SCD Type 2 table * SCD_TYPE_2_BY_COLUMN: the `execution_time` when the record was detected again -One way to think about `invalidate_hard_deletes` is that, if enabled, deletes are most accurately tracked in the SCD Type 2 table since it records when the delete occurred. +If `invalidate_hard_deletes` is set to `true`: + +* `valid_to` column will be set to the time when the SQLMesh run started that detected the missing record (called `execution_time`). +* If the record is added back, then the `valid_to` column will remain unchanged. + +One way to think about `invalidate_hard_deletes` is that, if `invalidate_hard_deletes` is set to `true`, deletes are most accurately tracked in the SCD Type 2 table since it records when the delete occurred. As a result though, you can have gaps between records if the there is a gap of time between when it was deleted and added back. 
+If you would prefer to not have gaps, and as a result consider missing records in source as still "valid", then you can leave the default value or set `invalidate_hard_deletes` to `false`.
 
 ### Example of SCD Type 2 By Time in Action
 
-Lets say that you started with the following data in your source table:
+Let's say that you started with the following data in your source table and `invalidate_hard_deletes` is set to `true`:
Default: `true` | bool | +| Name | Description | Type | +|-------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------| +| unique_key | Unique key used for identifying rows between source and target | List of strings or string | +| valid_from_name | The name of the `valid_from` column to create in the target table. Default: `valid_from` | string | +| valid_to_name | The name of the `valid_to` column to create in the target table. Default: `valid_to` | string | +| invalidate_hard_deletes | If set to `true`, when a record is missing from the source table it will be marked as invalid. Default: `false` | bool | +| batch_size | The maximum number of intervals that can be evaluated in a single backfill task. If this is `None`, all intervals will be processed as part of a single task. See [Processing Source Table with Historical Data](#processing-source-table-with-historical-data) for more info on this use case. (Default: `None`) | int | + +!!! tip "Important" + + If using BigQuery, the default data type of the valid_from/valid_to columns is DATETIME. If you want to use TIMESTAMP, you can specify the data type in the model definition. + + ```sql linenums="1" hl_lines="5" + MODEL ( + name db.menu_items, + kind SCD_TYPE_2_BY_TIME ( + unique_key id, + time_data_type TIMESTAMP + ) + ); + ``` + + This could likely be used on other engines to change the expected data type but has only been tested on BigQuery. 
### SCD Type 2 By Time Configuration Options @@ -704,10 +1274,66 @@ This is the most accurate representation of the menu based on the source data pr ### SCD Type 2 By Column Configuration Options -| Name | Description | Type | -|------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------| -| columns | The name of the columns to check for changes. `*` to represent that all columns should be checked. | List of strings or string | -| execution_time_as_valid_from | By default, for new rows `valid_from` is set to `1970-01-01 00:00:00`. This changes the behavior to set it to the `execution_time` of when the pipeline ran. Default: `false` | bool | +| Name | Description | Type | +|------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|---------------------------| +| columns | The name of the columns to check for changes. `*` to represent that all columns should be checked. | List of strings or string | +| execution_time_as_valid_from | By default, when the model is first loaded `valid_from` is set to `1970-01-01 00:00:00` and future new rows will have `execution_time` of when the pipeline ran. This changes the behavior to always use `execution_time`. Default: `false` | bool | +| updated_at_name | If sourcing from a table that includes as timestamp to use as valid_from, set this property to that column. See [Processing Source Table with Historical Data](#processing-source-table-with-historical-data) for more info on this use case. 
(Default: `None`) | string |
+
+
+### Processing Source Table with Historical Data
+
+The most common case for SCD Type 2 is creating history for a table that doesn't already have it.
+In the example of the restaurant menu, the menu just tells you what is offered right now, but you want to know what was offered over time.
+In this case, the default setting of `None` for `batch_size` is the best option.
+
+Another use case though is processing a source table that already has history in it.
+A common example of this is a "daily snapshot" table that is created by a source system that takes a snapshot of the data at the end of each day.
+If your source table has historical records, like a "daily snapshot" table, then set `batch_size` to `1` to process each interval (each day if a `@daily` cron) in sequential order.
+That way the historical records will be properly captured in the SCD Type 2 table.
+
+#### Example - Source from Daily Snapshot Table
+
+```sql linenums="1"
+MODEL (
+  name db.table,
+  kind SCD_TYPE_2_BY_COLUMN (
+    unique_key id,
+    columns [some_value],
+    updated_at_name ds,
+    batch_size 1
+  ),
+  start '2025-01-01',
+  cron '@daily'
+);
+SELECT
+  id,
+  some_value,
+  ds
+FROM
+  source_table
+WHERE
+  ds between @start_ds and @end_ds
+```
+
+This will process each day of the source table in sequential order (if more than one day to process), checking `some_value` column to see if it changed. If it did change, `valid_from` will be set to match the `ds` column (except for first value which would be `1970-01-01 00:00:00`).
+
+If the source data was the following:
+
+| id | some_value | ds |
+|----|------------|:-----------:|
+| 1 | 1 | 2025-01-01 |
+| 1 | 2 | 2025-01-02 |
+| 1 | 3 | 2025-01-03 |
+| 1 | 3 | 2025-01-04 |
+
+Then the resulting SCD Type 2 table would be:
+
+| id | some_value | ds | valid_from | valid_to |
+|----|------------|:-----------:|:-------------------:|:-------------------:|
+| 1 | 1 | 2025-01-01 | 1970-01-01 00:00:00 | 2025-01-02 00:00:00 |
+| 1 | 2 | 2025-01-02 | 2025-01-02 00:00:00 | 2025-01-03 00:00:00 |
+| 1 | 3 | 2025-01-03 | 2025-01-03 00:00:00 | NULL |

 ### Querying SCD Type 2 Models

@@ -807,6 +1433,46 @@
 GROUP BY id
 ```

+### Reset SCD Type 2 Model (clearing history)
+
+SCD Type 2 models are designed by default to protect the data that has been captured because it is not possible to recreate the history once it has been lost.
+However, there are cases where you may want to clear the history and start fresh.
+For this use case you will want to start by setting `disable_restatement` to `false` in the model definition.
+
+```sql linenums="1" hl_lines="5"
+MODEL (
+  name db.menu_items,
+  kind SCD_TYPE_2_BY_TIME (
+    unique_key id,
+    disable_restatement false
+  )
+);
+```
+
+Plan/apply this change to production.
+Then you will want to [restate the model](../plans.md#restatement-plans).
+
+```bash
+sqlmesh plan --restate-model db.menu_items
+```
+
+!!! warning
+
+    This will remove the historical data on the model which in most situations cannot be recovered.
+
+Once complete you will want to remove `disable_restatement` on the model definition which will set it back to `true` and prevent accidental data loss.
+
+```sql linenums="1"
+MODEL (
+  name db.menu_items,
+  kind SCD_TYPE_2_BY_TIME (
+    unique_key id,
+  )
+);
+```
+
+Plan/apply this change to production.
+
 ## EXTERNAL

 The EXTERNAL model kind is used to specify [external models](./external_models.md) that store metadata about external tables.
External models are special; they are not specified in .sql files like the other model kinds. They are optional but useful for propagating column and type information for external tables queried in your SQLMesh project. @@ -817,9 +1483,11 @@ The EXTERNAL model kind is used to specify [external models](./external_models.m Managed models are still under development and the API / semantics may change as support for more engines is added +**Note:** Python models do not support the `MANAGED` model kind - use a SQL model instead. + The `MANAGED` model kind is used to create models where the underlying database engine manages the data lifecycle. -These models dont get updated with new intervals or refreshed when `sqlmesh run` is called. Responsibility for keeping the *data* up to date falls on the engine. +These models don't get updated with new intervals or refreshed when `sqlmesh run` is called. Responsibility for keeping the *data* up to date falls on the engine. You can control how the engine creates the managed model by using the [`physical_properties`](../overview#physical_properties-previously-table_properties) to pass engine-specific parameters for adapter to use when issuing commands to the underlying database. @@ -827,4 +1495,164 @@ Due to there being no standard, each vendor has a different implementation with We would recommend using standard SQLMesh model types in the first instance. However, if you do need to use Managed models, you still gain other SQLMesh benefits like the ability to use them in [virtual environments](../../concepts/overview#build-a-virtual-environment). -See [Managed Models](./managed_models.md) for more information on which engines are supported and which properties are available. \ No newline at end of file +See [Managed Models](./managed_models.md) for more information on which engines are supported and which properties are available. 
+ +## INCREMENTAL_BY_PARTITION + +Models of the `INCREMENTAL_BY_PARTITION` kind are computed incrementally based on partition. A set of columns defines the model's partitioning key, and a partition is the group of rows with the same partitioning key value. + +!!! question "Should you use this model kind?" + + Any model kind can use a partitioned **table** by specifying the [`partitioned_by` key](../models/overview.md#partitioned_by) in the `MODEL` DDL. + + The "partition" in `INCREMENTAL_BY_PARTITION` is about how the data is **loaded** when the model runs. + + `INCREMENTAL_BY_PARTITION` models are inherently [non-idempotent](../glossary.md#idempotency), so restatements and other actions can cause data loss. This makes them more complex to manage than other model kinds. + + In most scenarios, an `INCREMENTAL_BY_TIME_RANGE` model can meet your needs and will be easier to manage. The `INCREMENTAL_BY_PARTITION` model kind should only be used when the data must be loaded by partition (usually for performance reasons). + +This model kind is designed for the scenario where data rows should be loaded and updated as a group based on their shared value for the partitioning key. + +It may be used with any SQL engine. SQLMesh will automatically create partitioned tables on engines that support explicit table partitioning (e.g., [BigQuery](https://cloud.google.com/bigquery/docs/creating-partitioned-tables), [Databricks](https://docs.databricks.com/en/sql/language-manual/sql-ref-partition.html)). + +New rows are loaded based on their partitioning key value: + +- If a partitioning key in newly loaded data is not present in the model table, the new partitioning key and its data rows are inserted. +- If a partitioning key in newly loaded data is already present in the model table, **all the partitioning key's existing data rows in the model table are replaced** with the partitioning key's data rows in the newly loaded data. 
+- If a partitioning key is present in the model table but not present in the newly loaded data, the partitioning key's existing data rows are not modified and remain in the model table. + +This kind should only be used for datasets that have the following traits: + +* The dataset's records can be grouped by a partitioning key. +* Each record has a partitioning key associated with it. +* It is appropriate to upsert records, so existing records can be overwritten by new arrivals when their partitioning keys match. +* All existing records associated with a given partitioning key can be removed or overwritten when any new record has the partitioning key value. + +The column defining the partitioning key is specified in the model's `MODEL` DDL `partitioned_by` key. This example shows the `MODEL` DDL for an `INCREMENTAL_BY_PARTITION` model whose partition key is the row's value for the `region` column: + +```sql linenums="1" hl_lines="4" +MODEL ( + name db.events, + kind INCREMENTAL_BY_PARTITION, + partitioned_by region, +); +``` + +Compound partition keys are also supported, such as `region` and `department`: + +```sql linenums="1" hl_lines="4" +MODEL ( + name db.events, + kind INCREMENTAL_BY_PARTITION, + partitioned_by (region, department), +); +``` + +Date and/or timestamp column expressions are also supported (varies by SQL engine). This BigQuery example's partition key is based on the month each row's `event_date` occurred: + +```sql linenums="1" hl_lines="4" +MODEL ( + name db.events, + kind INCREMENTAL_BY_PARTITION, + partitioned_by DATETIME_TRUNC(event_date, MONTH) +); +``` + +!!! warning "Only full restatements supported" + + Partial data [restatements](../plans.md#restatement-plans) are used to reprocess part of a table's data (usually a limited time range). + + Partial data restatement is not supported for `INCREMENTAL_BY_PARTITION` models. If you restate an `INCREMENTAL_BY_PARTITION` model, its entire table will be recreated from scratch. 
+ + Restating `INCREMENTAL_BY_PARTITION` models may lead to data loss and should be performed with care. + +### Example + +This is a fuller example of how you would use this model kind in practice. It limits the number of partitions to backfill based on time range in the `partitions_to_update` CTE. + +```sql linenums="1" +MODEL ( + name demo.incremental_by_partition_demo, + kind INCREMENTAL_BY_PARTITION, + partitioned_by user_segment, +); + +-- This is the source of truth for what partitions need to be updated and will join to the product usage data +-- This could be an INCREMENTAL_BY_TIME_RANGE model that reads in the user_segment values last updated in the past 30 days to reduce scope +-- Use this strategy to reduce full restatements +WITH partitions_to_update AS ( + SELECT DISTINCT + user_segment + FROM demo.incremental_by_time_range_demo -- upstream table tracking which user segments to update + WHERE last_updated_at BETWEEN DATE_SUB(@start_dt, INTERVAL 30 DAY) AND @end_dt +), + +product_usage AS ( + SELECT + product_id, + customer_id, + last_usage_date, + usage_count, + feature_utilization_score, + user_segment + FROM sqlmesh-public-demo.tcloud_raw_data.product_usage + WHERE user_segment IN (SELECT user_segment FROM partitions_to_update) -- partition filter applied here +) + +SELECT + product_id, + customer_id, + last_usage_date, + usage_count, + feature_utilization_score, + user_segment, + CASE + WHEN usage_count > 100 AND feature_utilization_score > 0.7 THEN 'Power User' + WHEN usage_count > 50 THEN 'Regular User' + WHEN usage_count IS NULL THEN 'New User' + ELSE 'Light User' + END as user_type +FROM product_usage +``` + +**Note**: Partial data [restatement](../plans.md#restatement-plans) is not supported for this model kind, which means that the entire table will be recreated from scratch if restated. This may lead to data loss. 
+
+### Materialization strategy
+Depending on the target engine, models of the `INCREMENTAL_BY_PARTITION` kind are materialized using the following strategies:
+
+| Engine | Strategy |
+|------------|-----------------------------------------|
+| Databricks | REPLACE WHERE by partitioning key |
+| Spark | INSERT OVERWRITE by partitioning key |
+| Snowflake | DELETE by partitioning key, then INSERT |
+| BigQuery | DELETE by partitioning key, then INSERT |
+| Redshift | DELETE by partitioning key, then INSERT |
+| Postgres | DELETE by partitioning key, then INSERT |
+| DuckDB | DELETE by partitioning key, then INSERT |
+
+## INCREMENTAL_UNMANAGED
+
+The `INCREMENTAL_UNMANAGED` model kind exists to support append-only tables. It's "unmanaged" in the sense that SQLMesh doesn't try to manage how the data is loaded. SQLMesh will just run your query on the configured cadence and append whatever it gets into the table.
+
+!!! question "Should you use this model kind?"
+
+    Some patterns for data management, such as Data Vault, may rely on append-only tables. In this situation, `INCREMENTAL_UNMANAGED` is the correct type to use.
+
+    In most other situations, you probably want `INCREMENTAL_BY_TIME_RANGE` or `INCREMENTAL_BY_UNIQUE_KEY` because they give you much more control over how the data is loaded.
+
+Usage of the `INCREMENTAL_UNMANAGED` model kind is straightforward:
+
+```sql linenums="1" hl_lines="3"
+MODEL (
+  name db.events,
+  kind INCREMENTAL_UNMANAGED,
+);
+```
+
+Since it's unmanaged, it doesn't support the `batch_size` and `batch_concurrency` properties to control how data is loaded like the other incremental model types do.
+
+!!! warning "Only full restatements supported"
+
+    Similar to `INCREMENTAL_BY_PARTITION`, attempting to [restate](../plans.md#restatement-plans) an `INCREMENTAL_UNMANAGED` model will trigger a full restatement. That is, the model will be rebuilt from scratch rather than from a time slice you specify.
+ + This is because an append-only table is inherently non-idempotent. Restating `INCREMENTAL_UNMANAGED` models may lead to data loss and should be performed with care. diff --git a/docs/concepts/models/overview.md b/docs/concepts/models/overview.md index f29f7fda79..d6356462b4 100644 --- a/docs/concepts/models/overview.md +++ b/docs/concepts/models/overview.md @@ -8,8 +8,8 @@ SQLMesh will automatically determine the relationships among and lineage of your The following is an example of a model defined in SQL. Note the following aspects: - Models can include descriptive information as comments, such as the first line. - - The first non-comment statement of a `model.sql` file is the `MODEL` DDL. - - The last non-comment statement should be a `SELECT` statement that defines the logic needed to create the table + - The first non-comment statement in the file is the `MODEL` DDL. + - The last non-comment statement is a `SELECT` query containing the logic that transforms the data. ```sql linenums="1" -- Customer revenue computed and stored daily. @@ -177,7 +177,9 @@ This table lists each engine's support for `TABLE` and `VIEW` object comments: | Engine | `TABLE` comments | `VIEW` comments | | ------------- | ---------------- | --------------- | +| Athena | N | N | | BigQuery | Y | Y | +| ClickHouse | Y | Y | | Databricks | Y | Y | | DuckDB <=0.9 | N | N | | DuckDB >=0.10 | Y | Y | @@ -197,139 +199,372 @@ The `MODEL` DDL statement takes various properties, which are used for both meta Learn more about these properties and their default values in the [model configuration reference](../../reference/model_configuration.md#general-model-properties). ### name -- `name` specifies the name of the model. This name represents the production view name that the model outputs, so it generally takes the form of `"schema"."view_name"`. The name of a model must be unique in a SQLMesh project.

-When models are used in non-production environments, SQLMesh automatically prefixes the names. For example, consider a model named `"sushi"."customers"`. In production its view is named `"sushi"."customers"`, and in dev its view is named `"sushi__dev"."customers"`.

-Name is ***required*** and must be ***unique***, unless [name inference](../../reference/model_configuration.md#model-naming) is enabled. +: Name specifies the name of the model. This name represents the production view name that the model outputs, so it generally takes the form of `"schema"."view_name"`. The name of a model must be unique in a SQLMesh project. + + When models are used in non-production environments, SQLMesh automatically prefixes the names. For example, consider a model named `"sushi"."customers"`. In production its view is named `"sushi"."customers"`, and in dev its view is named `"sushi__dev"."customers"`. + + Name is ***required*** and must be ***unique***, unless [name inference](../../reference/model_configuration.md#model-naming) is enabled. + +### project +: Project specifies the name of the project the model belongs to. Used in multi-repo SQLMesh deployments. ### kind -- Kind specifies what [kind](model_kinds.md) a model is. A model's kind determines how it is computed and stored. The default kind is `VIEW`, which means a view is created and your query is run each time that view is accessed. See [below](#incremental-model-properties) for properties that apply to incremental model kinds. +: Kind specifies what [kind](model_kinds.md) a model is. A model's kind determines how it is computed and stored. The default kind is `VIEW` for SQL models, which means a view is created and your query is run each time that view is accessed. On the other hand, the default kind for Python models is `FULL`, which means that a table is created and the Python code is executed each time the model is evaluated. See [below](#incremental-model-properties) for properties that apply to incremental model kinds. + +### audits +: Audits specifies which [audits](../audits.md) should run after the model is evaluated. ### dialect -- Dialect defines the SQL dialect of the model. 
By default, this uses the dialect in the [configuration file `model_defaults` `dialect` key](../../reference/configuration.md#model-configuration). All SQL dialects [supported by the SQLGlot library](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) are allowed. +: Dialect defines the SQL dialect of the model. By default, this uses the dialect in the [configuration file `model_defaults` `dialect` key](../../reference/configuration.md#model-configuration). All SQL dialects [supported by the SQLGlot library](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/__init__.py) are allowed. ### owner -- Owner specifies who the main point of contact is for the model. It is an important field for organizations that have many data collaborators. +: Owner specifies who the main point of contact is for the model. It is an important field for organizations that have many data collaborators. ### stamp -- An optional arbitrary string sequence used to create new model versions without making changes to any of the functional components of the definition. +: An optional arbitrary string sequence used to create a new model version without changing the functional components of the definition. -### start -- Start is used to determine the earliest time needed to process the model. It can be an absolute date/time (`2022-01-01`), or a relative one (`1 year ago`). - -### end -- End is used to determine the latest time needed to process the model. It can be an absolute date/time (`2022-01-01`), or a relative one (`1 year ago`). +### tags +: Tags are one or more labels used to organize your models. ### cron -- Cron is used to schedule your model to process or refresh at a certain interval. It accepts a [cron expression](https://en.wikipedia.org/wiki/Cron) or any of `@hourly`, `@daily`, `@weekly`, or `@monthly`. +: Cron is used to schedule when your model processes or refreshes data. 
It accepts a [cron expression](https://en.wikipedia.org/wiki/Cron) or any of `@hourly`, `@daily`, `@weekly`, or `@monthly`. All times are assumed to be UTC timezone by default. + +### cron_tz +: Cron timezone is used to specify the timezone of the cron. This is only used for scheduling and does not affect the intervals processed in an incremental model. For example, if a model is `@daily` with cron_tz `America/Los_Angeles`, it will run every day 12AM pacific time, however the `start` and `end` variables passed to the incremental model will represent the UTC date boundaries. ### interval_unit -- Interval unit determines the granularity of data intervals for this model. By default the interval unit is automatically derived from the `cron` expression. Supported values are: `year`, `month`, `day`, `hour`, `half_hour`, `quarter_hour`, and `five_minute`. +: Interval unit determines the temporal granularity with which time intervals are calculated for the model. -### tags -- Tags are one or more labels used to organize your models. + By default, the interval unit is automatically derived from the [`cron`](#cron) expression and does not need to be specified. + + Supported values are: `year`, `month`, `day`, `hour`, `half_hour`, `quarter_hour`, and `five_minute`. + + #### Relationship to [`cron`](#cron) + + The SQLMesh scheduler needs two temporal pieces of information from a model: specific times when the model should run and the finest temporal granularity with which the data is processed or stored. The `interval_unit` specifies that granularity. + + If a model's `cron` parameter is a frequency like `@daily`, the run times and `interval_unit` are simple to determine: the model is ready to run at the start of the day, and its `interval_unit` is `day`. Similarly, a `cron` of `@hourly` is ready to run at the start of each hour, and its `interval_unit` is `hour`. 
+ + If [`cron`](#cron) is specified with a cron expression, however, SQLMesh uses a more complex approach to derive the `interval_unit`. + + A [cron expression](https://en.wikipedia.org/wiki/Cron) can generate complex time intervals, so SQLMesh does not parse it directly. Instead, it: + + 1. Generates the next five run times from the cron expression (relative to the time of calculation) + 2. Calculates the duration of the intervals between those five values + 3. Determines the model's `interval_unit` as the largest interval unit value that is less than or equal to the minimum duration from (2) + + For example, consider a cron expression corresponding to "run every 43 minutes." Its `interval_unit` is `half_hour` because that is the largest `interval_unit` value *shorter* than 43 minutes. If the cron expression is "run every 67 minutes", its `interval_unit` is `hour` given the same logic. + + However, `interval_unit` does not have to be inferred from [`cron`](#cron) - you can specify it explicitly to customize how your backfill occurs. + + #### Specifying `interval_unit` + + Models often run on a regular cadence, where the same amount of time passes between each run and the same time length of data is processed in each run. + + For example, a model might run at midnight every day (1 run per day) to process the previous day's data (1 day's worth of data per run). The length of time between runs and the time length of data processed in each run are both 1 day (or both 2 days if you miss a run). + + However, the run cadence length and processed data length do not have to be the same. + + Consider a model that runs every day at 7:30am and processes data up until 7am today. The model's `cron` is a cron expression for "run every day at 7:30am," from which SQLMesh infers an `interval_unit` of `day`. + + What will happen when this model runs? First, SQLMesh will identify the most recent completed interval. 
The `interval_unit` was inferred to be `day`, so the last complete interval was yesterday. SQLMesh will not include any of today's data between 12:00am and 7:00am in the run. + + To include today's data, manually specify an `interval_unit` of `hour`. When the model runs at 7:30am, SQLMesh will identify the most recent completed `hour` interval as 6:00-7:00am and include data through that interval in the backfill. + + ```sql + MODEL ( + name sqlmesh_example.up_until_7, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column date_column, + ), + start '2024-11-01', + cron '30 7 * * *', -- cron expression for "every day at 7:30am" + interval_unit 'hour', -- backfill up until the most recently completed hour (rather than day) + ); + ``` + + !!! warning "Caution: complex use case" + + The example below is a complex use case that uses the `allow_partials` configuration option. We recommend that you do **NOT** use this option unless absolutely necessary. + + When partials are allowed, you will not be able to determine the cause of missing data. A pipeline problem and a correctly executed partial backfill both result in missing data, so you may not be able to differentiate the two. + + Overall, you risk sharing incomplete/incorrect data even when SQLMesh runs successfully. Learn more on the [Tobiko blog](https://tobikodata.com/data-completeness.html). + + This section configures a model that: + + - Runs every hour + - Processes data for the last two days on every run + - Processes the data that has accumulated so far today on every run + + Configuring this model requires letting SQLMesh process partially completed intervals by setting the model configuration `allow_partials True`. + + The data for partial intervals is only temporary - SQLMesh will reprocess the entire interval once it is complete. 
+ + ```sql + MODEL ( + name sqlmesh_example.demo, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column date_column, + lookback 2, -- 2 days of late-arriving data to backfill + ), + start '2024-11-01', + cron '@hourly', -- run model hourly, not tied to the interval_unit + allow_partials true, -- allow partial intervals so today's data is processed in each run + interval_unit 'day', -- finest granularity of data to be time bucketed + ); + ``` + + The `lookback` is calculated in days because the model's `interval_unit` is specified as `day`. + +### start +: Start is used to determine the earliest time needed to process the model. It can be an absolute date/time (`2022-01-01`), or a relative one (`1 year ago`). + +### end +: End is used to determine the latest time needed to process the model. It can be an absolute date/time (`2022-01-01`), or a relative one (`1 year ago`). + +### description +: Optional description of the model. Automatically registered as a table description/comment with the underlying SQL engine (if supported by the engine). + +### column_descriptions +: Optional dictionary of [key/value pairs](#explicit-column-comments). Automatically registered as column descriptions/comments with the underlying SQL engine (if supported by the engine). If not present, [inline comments](#inline-column-comments) will automatically be registered. ### grain -- A model's grain is the column or combination of columns that uniquely identify a row in the results returned by the model's query. If the grain is set, SQLMesh tools like `table_diff` are simpler to run because they automatically use the model grain for parameters that would otherwise need to be specified manually. +: A model's grain is the column or combination of columns that uniquely identify a row in the results returned by the model's query. 
If the grain is set, SQLMesh tools like `table_diff` are simpler to run because they automatically use the model grain for parameters that would otherwise need to be specified manually. ### grains -- A model can define multiple grains if it has more than one unique key or combination of keys. +: A model can define multiple grains if it has more than one unique key or combination of keys. ### references -- References are non-unique columns or combinations of columns that identify a join relationship to an entity. For example, a model could define a reference `account_id`, which would indicate that it can now automatically join to any model with an `account_id` grain. It cannot safely join to a table with an `account_id` reference because references are not unique and doing so would constitute a many-to-many join. Sometimes columns are named differently, in that case you can alias column names to a common entity name. For example `guest_id AS account_id` would allow a model with the column guest\_id to join to a model with the grain account\_id. +: References are non-unique columns or combinations of columns that identify a join relationship to another model. + + For example, a model could define a reference `account_id`, which would indicate that it can now automatically join to any model with an `account_id` grain. It cannot safely join to a table with an `account_id` reference because references are not unique and doing so would constitute a many-to-many join. + + Sometimes columns are named differently, in that case you can alias column names to a common entity name. For example `guest_id AS account_id` would allow a model with the column guest\_id to join to a model with the grain account\_id. + +### depends_on +: Depends on explicitly specifies the models on which the model depends, in addition to the ones automatically inferred by from the model code. 
+ +### table_format +: Table format is an optional property for engines that support table formats like `iceberg` and `hive` where the physical file format is configurable. The intention is to define the table type using `table_format` and then the on-disk format of the files within the table using `storage_format`. + + Note that this property is only implemented for engines that allow the `table_format` to be configured independently of the `storage_format`. ### storage_format -- Storage format is a property for engines such as Spark or Hive that support storage formats such as `parquet` and `orc`. +: Storage format is a property for engines such as Spark or Hive that support storage formats such as `parquet` and `orc`. Note that some engines don't make a distinction between `table_format` and `storage_format`, in which case `storage_format` is used and `table_format` is ignored. ### partitioned_by -- Partitioned by plays two roles. For most model kinds, it is an optional property for engines that support table partitioning such as Spark or BigQuery. For the [`INCREMENTAL_BY_PARTITION` model kind](./model_kinds.md#incremental_by_partition), it defines the partition key used to incrementally load data. It can specify a multi-column partition key or modify a date column for partitioning. For example, in BigQuery you could partition by day by extracting the day component of a timestamp column `event_ts` with `partitioned_by TIMESTAMP_TRUNC(event_ts, DAY)`. +: Partitioned by plays two roles. For most model kinds, it is an optional property for engines that support table partitioning such as Spark or BigQuery. + + For the [`INCREMENTAL_BY_PARTITION` model kind](./model_kinds.md#incremental_by_partition), it defines the partition key used to incrementally load data. + + It can specify a multi-column partition key or modify a date column for partitioning. 
For example, in BigQuery you could partition by day by extracting the day component of a timestamp column `event_ts` with `partitioned_by TIMESTAMP_TRUNC(event_ts, DAY)`. ### clustered_by -- Clustered by is an optional property for engines such as Bigquery that support clustering. +: Clustered by is an optional property for engines such as Bigquery that support clustering. ### columns -- By default, SQLMesh [infers a model's column names and types](#conventions) from its SQL query. Disable that behavior by manually specifying all column names and data types in the model's `columns` property. -- **WARNING**: SQLMesh may exhibit unexpected behavior if the `columns` property includes columns not returned by the query, omits columns returned by the query, or specifies data types other than the ones returned by the query. -- NOTE: Specifying column names and data types is required for [Python models](../models/python_models.md) that return DataFrames. +: By default, SQLMesh [infers a model's column names and types](#conventions) from its SQL query. Disable that behavior by manually specifying all column names and data types in the model's `columns` property. -For example, this shows a seed model definition that includes the `columns` key. It specifies the data types for all columns in the file: the `holiday_name` column is data type `VARCHAR` and the `holiday_date` column is data type `DATE`. + **WARNING**: SQLMesh may exhibit unexpected behavior if the `columns` property includes columns not returned by the query, omits columns returned by the query, or specifies data types other than the ones returned by the query. -```sql linenums="1" hl_lines="6-9" -MODEL ( - name test_db.national_holidays, - kind SEED ( - path 'national_holidays.csv' - ), - columns ( - holiday_name VARCHAR, - holiday_date DATE - ) -); -``` + For example, this shows a seed model definition that includes the `columns` key. 
It specifies the data types for all columns in the file: the `holiday_name` column is data type `VARCHAR` and the `holiday_date` column is data type `DATE`. -### description -- Optional description of the model. Automatically registered as a table description/comment with the underlying SQL engine (if supported by the engine). + ```sql linenums="1" hl_lines="6-9" + MODEL ( + name test_db.national_holidays, + kind SEED ( + path 'national_holidays.csv' + ), + columns ( + holiday_name VARCHAR, + holiday_date DATE + ) + ); + ``` -### column_descriptions -- Optional dictionary of [key/value pairs](#explicit-column-comments). Automatically registered as column descriptions/comments with the underlying SQL engine (if supported by the engine). If not present, [inline comments](#inline-column-comments) will automatically be registered. + NOTE: Specifying column names and data types is required for [Python models](../models/python_models.md) that return DataFrames. -### physical_properties (previously table_properties) -- A key-value mapping of arbitrary properties specific to the target engine that are applied to the model table / view in the physical layer. For example: +### physical_properties +: Previously named `table_properties` -```sql linenums="1" -MODEL ( - ..., - physical_properties ( - partition_expiration_days = 7, - require_partition_filter = true - ) -); + Physical properties is a key-value mapping of arbitrary properties that are applied to the model table / view in the physical layer. Note the partitioning details and `creatable_type` which overrides the kind of model/view created. In this case it creates a `TRANSIENT TABLE`. While `creatable_type` is generic, other properties are adapter specific so check the engine documentation for those. 
For example: -``` + ```sql linenums="1" + MODEL ( + ..., + physical_properties ( + partition_expiration_days = 7, + require_partition_filter = true, + creatable_type = TRANSIENT + ) + ); + + ``` ### virtual_properties -- A key-value mapping of arbitrary properties specific to the target engine that are applied to the model view in the virtual layer. For example: +: Virtual properties is a key-value mapping of arbitrary properties that are applied to the model view in the virtual layer. Note the `creatable_type` property which overrides the kind of model/view created. In this case it creates a `SECURE VIEW`. While `creatable_type` is generic, other properties are adapter specific so check the engine documentation for those. For example: -```sql linenums="1" -MODEL ( - ..., - virtual_properties ( - labels = [('test-label', 'label-value')] - ) -); + ```sql linenums="1" + MODEL ( + ..., + virtual_properties ( + creatable_type = SECURE, + labels = [('test-label', 'label-value')] + ) + ); -``` + ``` + +### session_properties +: Session properties is a key-value mapping of arbitrary properties specific to the target engine that are applied to the engine session. ### allow_partials -- Indicates that this model can be executed for partial (incomplete) data intervals. By default, each model processes only complete intervals to prevent common mistakes caused by partial data. The size of the interval is determined by the model's [interval_unit](#interval_unit). Setting `allow_partials` to `true` overrides this behavior, indicating that the model may process a segment of input data that is missing some of the data points. Please note that setting this attribute to `true` results in the disregard of the [cron](#cron) attribute. +: Indicates that this model can be executed for partial (incomplete) data intervals. + + By default, each model processes only complete intervals to prevent common errors caused by partial data. 
The size of the interval is determined by the model's [interval_unit](#interval_unit). + + Setting `allow_partials` to `true` overrides this behavior, indicating that the model may process a segment of input data that is missing some of the data points. + + NOTE: To force the model to run every time, set `allow_partials` to `true` and use the `--ignore-cron` argument: `sqlmesh run --ignore-cron`. Simply setting `allow_partials` to `true` does not guarantee that the model will run on every `sqlmesh run` command invocation. The model’s configured `cron` schedule is still respected, even when partial intervals are allowed. + + Similarly, using `--ignore-cron` without setting `allow_partials` to `true` does not guarantee the model will run every time. Depending on the time of day, the interval might not be complete and ready for execution, even when ignoring the `cron` schedule. Therefore, both are required to ensure that the model runs on every `sqlmesh run` invocation. ### enabled -- Whether the model is enabled. This attribute is `true` by default. Setting it to `false` causes SQLMesh to ignore this model when loading the project. +: Whether the model is enabled. This attribute is `true` by default. Setting it to `false` causes SQLMesh to ignore this model when loading the project. + +### physical_version +: Pins the version of this model's physical table to the given value. + + NOTE: This can only be set for forward-only models. + +### gateway +: Specifies the gateway to use for the execution of this model. When not specified, the default gateway is used. + +### optimize_query +: Whether the model's query should be optimized. All SQL models are optimized by default. Setting this +to `false` causes SQLMesh to disable query canonicalization & simplification. This should be turned off only if the optimized query leads to errors such as surpassing text limit. + +!!! 
warning + Turning off the optimizer may prevent column-level lineage from working for the affected model and its descendants, unless all columns in the model's query are qualified and it contains no star projections (e.g. `SELECT *`). + +### validate_query +: Whether the model's query will be validated at compile time. This attribute is `false` by default. Setting it to `true` causes SQLMesh to raise an error instead of emitting warnings. This will display invalid columns in your SQL statements along with models containing `SELECT *` that cannot be automatically expanded to list out all columns. This ensures SQL is verified locally before time and money are spent running the SQL in your data warehouse. + +!!! warning + This flag is deprecated as of v.0.159.7+ in favor of the [linter](../../guides/linter.md). To preserve validation during compilation, the [built-in rules](../../guides/linter.md#built-in-rules) that check for correctness should be [configured](../../guides/linter.md#rule-violation-behavior) to error severity. + +### ignored_rules +: Specifies which linter rules should be ignored/excluded for this model. + +### formatting +: Whether the model will be formatted. All models are formatted by default. Setting this to `false` causes SQLMesh to ignore this model during `sqlmesh format`. ## Incremental Model Properties -For models that are incremental, the following parameters can be specified in the `kind`'s definition. +These properties can be specified in an incremental model's `kind` definition. + +Some properties are only available in specific model kinds - see the [model configuration reference](../../reference/model_configuration.md#incremental-models) for more information and a complete list of each `kind`'s properties. ### time_column -- Time column is a required property for incremental models. It is used to determine which records to overwrite when doing an incremental insert. 
Time column can have an optional format string specified in the SQL dialect of the model. -- Engines that support partitioning, such as Spark and BigQuery, use the time column as the model's partition key. Multi-column partitions or modifications to columns can be specified with the [`partitioned_by` property](#partitioned_by). +: Time column is a required property for incremental models. It is used to determine which records to overwrite when doing an incremental insert. Time column can have an optional format string specified in the SQL dialect of the model. -### lookback -- Lookback is used with [incremental by time range](model_kinds.md#incremental_by_time_range) models to capture late-arriving data. It must be a positive integer and specifies the number of interval time units prior to the current interval the model should include. For example, a model with cron `@daily` and `lookback` of 7 would include the previous 7 days each time it ran, while a model with cron `@weekly` and `lookback` of 7 would include the previous 7 weeks each time it ran. + Engines that support partitioning, such as Spark and BigQuery, use the time column as the model's partition key. Multi-column partitions or modifications to columns can be specified with the [`partitioned_by` property](#partitioned_by). -### on_destructive_change -- What should happen when a change to a [forward-only model](../../guides/incremental_time.md#forward-only-models) or incremental model in a [forward-only plan](../plans.md#forward-only-plans) causes a destructive modification to the table schema (i.e., requires dropping an existing column). SQLMesh checks for destructive changes at plan time based on the model definition and run time based on the model's underlying physical tables. Must be one of the following values: `allow`, `warn`, or `error` (default). + !!! tip "Important" + + The `time_column` variable should be in the UTC time zone - learn more [here](./model_kinds.md#timezones). 
### batch_size -- Batch size is used to optimize backfilling incremental data. It determines the maximum number of intervals to run in a single job. For example, if a model specifies a cron of `@hourly` and a batch_size of `12`, when backfilling 3 days of data, the scheduler will spawn 6 jobs. (3 days * 24 hours/day = 72 hour intervals to fill. 72 intervals / 12 intervals per job = 6 jobs.) +: Batch size is used to backfill incremental data when the number of intervals to backfill is too large for the engine to execute in a single pass. It allows you to process sets of intervals in batches small enough to execute on your system. The `batch_size` parameter determines the maximum number of [`interval_unit`s](#interval_unit) of data to run in a single job. + + For example, consider a model with an `@hourly` [`cron`](#cron) that has not run in 3 days. Because its [`cron`](#cron) is `@hourly`, its [`interval_unit`](#interval_unit) is `hour`. + + First, let's calculate the total number of outstanding intervals to backfill: 3 days of unprocessed data * 24 hours/day = 72 `hour` intervals. + + Now we can calculate the number of jobs for different `batch_size` values with this formula: + + Number of Intervals / `batch_size` = Number of jobs to run + + Let's look at the number of jobs for a few different `batch_size` values: + - `batch_size` not specified: scheduler will spawn 1 job that processes all 72 intervals (SQLMesh's default behavior) + - `batch_size` of 1: scheduler will spawn [72 `hour` intervals / 1 interval per job] = 72 jobs + - `batch_size` of 12: scheduler will spawn [72 `hour` intervals / 12 intervals per job] = 6 jobs ### batch_concurrency -- The maximum number of [batches](#batch_size) that can run concurrently for this model. If not specified, the concurrency is only constrained by the number of concurrent tasks set in the connection settings. +: The maximum number of [batches](#batch_size) that can run concurrently for this model. 
If not specified, the concurrency is only constrained by the number of concurrent tasks set in the connection settings. + +### lookback +: Lookback is used with [incremental by time range](model_kinds.md#incremental_by_time_range) and [incremental by unique key](model_kinds.md#incremental_by_unique_key) models to capture late-arriving data. It allows the model to access data points not in the time interval currently being processed. + + It must be a positive integer and specifies how many [`interval_unit`s](#interval_unit) intervals before the current interval the model should include. + + For example, consider a model with cron `@daily` ([`interval_unit`](#interval_unit) `day`). If the model specified a `lookback` of 7, SQLMesh would include the 7 days prior to the time interval being processed. A model with cron `@weekly` and `lookback` of 7 would include the 7 weeks prior to the time interval being processed. + + Or consider a model whose cron expression is "run every 6 hours" (`0 */6 * * *`). SQLMesh calculates its [`interval_unit`](#interval_unit) as `hour`. The `lookback` value is calculated in `interval_units`, so a `lookback` of 1 would include the 1 hour prior to the time interval being processed. ### forward_only -- Set this to true to indicate that all changes to this model should be [forward-only](../plans.md#forward-only-plans). +: Set this to true to indicate that all changes to this model should be [forward-only](../plans.md#forward-only-plans). + +### on_destructive_change +: What should happen when a change to a [forward-only model](../../guides/incremental_time.md#forward-only-models) or incremental model in a [forward-only plan](../plans.md#forward-only-plans) causes a destructive modification to the table schema (i.e., requires dropping an existing column or modifying column constraints in ways that could cause data loss). 
+ + SQLMesh checks for destructive changes at plan time based on the model definition and run time based on the model's underlying physical tables. + + Must be one of the following values: `allow`, `warn`, `error` (default), or `ignore`. + +### on_additive_change +: What should happen when a change to a [forward-only model](../../guides/incremental_time.md#forward-only-models) or incremental model in a [forward-only plan](../plans.md#forward-only-plans) causes an additive modification to the table schema (i.e., adding new columns, modifying column data types in compatible ways, etc.). + + SQLMesh checks for additive changes at plan time based on the model definition and run time based on the model's underlying physical tables. + + Must be one of the following values: `allow` (default), `warn`, `error`, or `ignore`. ### disable_restatement -- Set this to true to indicate that [data restatement](../plans.md#restatement-plans) is disabled for this model. +: Set this to true to indicate that [data restatement](../plans.md#restatement-plans) is disabled for this model. + +### auto_restatement_cron +: A cron expression that determines when SQLMesh should automatically restate this model. Restatement means re-evaluating either a number of last intervals (controlled by [`auto_restatement_intervals`](#auto_restatement_intervals)) for model kinds that support it or the entire model for model kinds that don't. Downstream models that depend on this model will also be restated. The auto-restatement is only applied when running the `sqlmesh run` command against the production environment. + + A common use case for auto-restatement is to periodically re-evaluate a model (less frequently than the model's cron) to account for late-arriving data or dimension changes. However, relying on this feature is generally not recommended, as it often indicates an underlying issue with the data model or dependency chain. 
Instead, users should prefer setting the [`lookback`](#lookback) property to handle late-arriving data more effectively. + + Unlike the [`lookback`](#lookback) property, which only controls the time range of data scanned, auto-restatement rewrites all previously processed data for this model in the target table. + + For model kinds that don't support [`auto_restatement_intervals`](#auto_restatement_intervals) the table will be re-created from scratch. + + Models with [`disable_restatement`](#disable_restatement) set to `true` will not be restated automatically even if this property is set. + + **NOTE**: Models with this property set can only be [previewed](../plans.md#data-preview-for-forward-only-changes) in development environments, which means that the data computed in those environments will not be reused in production. + + ```sql linenums="1" hl_lines="6" + MODEL ( + name test_db.national_holidays, + cron '@daily', + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key key, + auto_restatement_cron '@weekly', + ) + ); + ``` + +### auto_restatement_intervals +: The number of last intervals to restate automatically. This is only applied in conjunction with [`auto_restatement_cron`](#auto_restatement_cron). + + If not specified, the entire model will be restated. + + This property is only supported for the `INCREMENTAL_BY_TIME_RANGE` model kind. + + ```sql linenums="1" hl_lines="7" + MODEL ( + name test_db.national_holidays, + cron '@daily', + kind INCREMENTAL_BY_TIME_RANGE ( + time_column event_ts, + auto_restatement_cron '@weekly', + auto_restatement_intervals 7, -- automatically restate the last 7 days of data + ) + ); + ``` ## Macros Macros can be used for passing in parameterized arguments such as dates, as well as for making SQL less repetitive. By default, SQLMesh provides several predefined macro variables that can be used. Macros are used by prefixing with the `@` symbol. For more information, refer to [macros](../macros/overview.md). 
@@ -367,34 +602,3 @@ FROM y; -- Cleanup statements DROP TABLE temp_table; ``` - -## Time column -Models that are loaded incrementally require a time column to partition data. - -A time column is a column in a model with an optional format string in the dialect of the model; for example, `'%Y-%m-%d'` for DuckDB or `'yyyy-mm-dd'` for Snowflake. For more information, refer to [time column](./model_kinds.md#time-column). - -### Advanced usage -The column used as your model's time column is not limited to a text or date type. In the following example, the time column, `di`, is an integer: - -```sql linenums="1" hl_lines="5" --- Orders are partitioned by the di int column -MODEL ( - name sushi.orders, - dialect duckdb, - kind INCREMENTAL_BY_TIME_RANGE ( - time_column (order_date_int, '%Y%m%d') - ), -); - -SELECT - id::INT AS id, -- Primary key - customer_id::INT AS customer_id, -- Id of customer who made the order - waiter_id::INT AS waiter_id, -- Id of waiter who took the order - start_ts::TEXT AS start_ts, -- Start timestamp - end_ts::TEXT AS end_ts, -- End timestamp - di::INT AS order_date_int -- Date of order -FROM raw.orders -WHERE - order_date_int BETWEEN @start_ds AND @end_ds -``` -SQLMesh will handle casting the start and end dates to the type of your time column. The format is reflected in the time column format string. diff --git a/docs/concepts/models/python_models.md b/docs/concepts/models/python_models.md index c1e5b39cf4..10884ecedf 100644 --- a/docs/concepts/models/python_models.md +++ b/docs/concepts/models/python_models.md @@ -4,6 +4,16 @@ Although SQL is a powerful tool, some use cases are better handled by Python. Fo SQLMesh has first-class support for models defined in Python; there are no restrictions on what can be done in the Python model as long as it returns a Pandas or Spark DataFrame instance. + +!!! info "Unsupported model kinds" + + Python models do not support these [model kinds](./model_kinds.md) - use a SQL model instead. 
+ + * `VIEW` + * `SEED` + * `MANAGED` + * `EMBEDDED` + ## Definition To create a Python model, add a new file with the `*.py` extension to the `models/` directory. Inside the file, define a function named `execute`. For example: @@ -33,7 +43,7 @@ The `execute` function is wrapped with the `@model` [decorator](https://wiki.pyt Because SQLMesh creates tables before evaluating models, the schema of the output DataFrame is a required argument. The `@model` argument `columns` contains a dictionary of column names to types. -The function takes an `ExecutionContext` that is able to run queries and to retrieve the current time interval that is being processed, along with arbitrary key-value arguments passed in at runtime. The function can either return a Pandas, PySpark, or Snowpark Dataframe instance. +The function takes an `ExecutionContext` that is able to run queries and to retrieve the current time interval that is being processed, along with arbitrary key-value arguments passed in at runtime. The function can either return a Pandas, PySpark, Bigframe, or Snowpark Dataframe instance. If the function output is too large, it can also be returned in chunks using Python generators. @@ -52,9 +62,12 @@ Supported `kind` dictionary `name` values are: - `ModelKindName.SEED` - `ModelKindName.INCREMENTAL_BY_TIME_RANGE` - `ModelKindName.INCREMENTAL_BY_UNIQUE_KEY` +- `ModelKindName.INCREMENTAL_BY_PARTITION` - `ModelKindName.SCD_TYPE_2_BY_TIME` - `ModelKindName.SCD_TYPE_2_BY_COLUMN` - `ModelKindName.EMBEDDED` +- `ModelKindName.CUSTOM` +- `ModelKindName.MANAGED` - `ModelKindName.EXTERNAL` This example demonstrates how to specify an incremental by time range model kind in Python: @@ -87,7 +100,56 @@ Optional pre/post-statements allow you to execute SQL commands before and after For example, pre/post-statements might modify settings or create indexes. 
However, be careful not to run any statement that could conflict with the execution of another statement if models run concurrently, such as creating a physical table. -Pre- and post-statements are issued with the SQLMesh [`fetchdf` method](../../reference/cli.md#fetchdf) [described above](#execution-context). +You can set the `pre_statements` and `post_statements` arguments to a list of SQL strings, SQLGlot expressions, or macro calls to define the model's pre/post-statements. + +**Project-level defaults:** You can also define pre/post-statements at the project level using `model_defaults` in your configuration. These will be applied to all models in your project and merged with any model-specific statements. Default statements are executed first, followed by model-specific statements. Learn more about this in the [model configuration reference](../../reference/model_configuration.md#model-defaults). + +``` python linenums="1" hl_lines="8-12" +@model( + "db.test_model", + kind="full", + columns={ + "id": "int", + "name": "text", + }, + pre_statements=[ + "SET GLOBAL parameter = 'value';", + exp.Cache(this=exp.table_("x"), expression=exp.select("1")), + ], + post_statements=["@CREATE_INDEX(@this_model, id)"], +) +def execute( + context: ExecutionContext, + start: datetime, + end: datetime, + execution_time: datetime, + **kwargs: t.Any, +) -> pd.DataFrame: + + return pd.DataFrame([ + {"id": 1, "name": "name"} + ]) + +``` + +The previous example's `post_statements` called user-defined SQLMesh macro `@CREATE_INDEX(@this_model, id)`. + +We could define the `CREATE_INDEX` macro in the project's `macros` directory like this. The macro creates a table index on a single column, conditional on the [runtime stage](../macros/macro_variables.md#runtime-variables) being `creating` (table creation time). 
+ + +``` python linenums="1" +@macro() +def create_index( + evaluator: MacroEvaluator, + model_name: str, + column: str, +): + if evaluator.runtime_stage == "creating": + return f"CREATE INDEX idx ON {model_name}({column});" + return None +``` + +Alternatively, pre- and post-statements can be issued with the SQLMesh [`fetchdf` method](../../reference/cli.md#fetchdf) [described above](#execution-context). Pre-statements may be specified anywhere in the function body before it `return`s or `yield`s. Post-statements must execute after the function completes, so instead of `return`ing a value the function must `yield` the value. The post-statement must be specified after the `yield`. @@ -103,7 +165,7 @@ def execute( ) -> pd.DataFrame: # pre-statement - context.fetchdf("SET GLOBAL parameter = 'value';") + context.engine_adapter.execute("SET GLOBAL parameter = 'value';") # post-statement requires using `yield` instead of `return` yield pd.DataFrame([ @@ -111,19 +173,56 @@ def execute( ]) # post-statement - context.fetchdf("CREATE INDEX idx ON example.pre_post_statements (id);") + context.engine_adapter.execute("CREATE INDEX idx ON example.pre_post_statements (id);") +``` + +## Optional on-virtual-update statements + +The optional on-virtual-update statements allow you to execute SQL commands after the completion of the [Virtual Update](#virtual-update). + +These can be used, for example, to grant privileges on views of the virtual layer. + +Similar to pre/post-statements you can set the `on_virtual_update` argument in the `@model` decorator to a list of SQL strings, SQLGlot expressions, or macro calls. + +**Project-level defaults:** You can also define on-virtual-update statements at the project level using `model_defaults` in your configuration. These will be applied to all models in your project (including Python models) and merged with any model-specific statements. Default statements are executed first, followed by model-specific statements. 
Learn more about this in the [model configuration reference](../../reference/model_configuration.md#model-defaults). + +``` python linenums="1" hl_lines="8" +@model( + "db.test_model", + kind="full", + columns={ + "id": "int", + "name": "text", + }, + on_virtual_update=["GRANT SELECT ON VIEW @this_model TO ROLE dev_role"], +) +def execute( + context: ExecutionContext, + start: datetime, + end: datetime, + execution_time: datetime, + **kwargs: t.Any, +) -> pd.DataFrame: + + return pd.DataFrame([ + {"id": 1, "name": "name"} + ]) ``` +!!! note + + Table resolution for these statements occurs at the virtual layer. This means that table names, including `@this_model` macro, are resolved to their qualified view names. For instance, when running the plan in an environment named `dev`, `db.test_model` and `@this_model` would resolve to `db__dev.test_model` and not to the physical table name. ## Dependencies -In order to fetch data from an upstream model, you first get the table name using `context`'s `table` method. This returns the appropriate table name for the current runtime [environment](../environments.md): + +In order to fetch data from an upstream model, you first get the table name using `context`'s `resolve_table` method. This returns the appropriate table name for the current runtime [environment](../environments.md): ```python linenums="1" -table = context.table("docs_example.upstream_model") +table = context.resolve_table("docs_example.upstream_model") df = context.fetchdf(f"SELECT * FROM {table}") ``` -The `table` method will automatically add the referenced model to the Python model's dependencies. +The `resolve_table` method will automatically add the referenced model to the Python model's dependencies. The only other way to set dependencies of models in Python models is to define them explicitly in the `@model` decorator using the keyword `depends_on`. 
The dependencies defined in the model decorator take precedence over any dynamic references inside the function. @@ -143,15 +242,52 @@ def execute( ) -> pd.DataFrame: # ignored due to @model dependency "upstream_dependency" - context.table("docs_example.another_dependency") + context.resolve_table("docs_example.another_dependency") +``` + +User-defined [global variables](global-variables) or [blueprint variables](#python-model-blueprinting) can also be used in `resolve_table` calls, as shown in the following example (similarly for `blueprint_var()`): + +```python linenums="1" +@model( + "@schema_name.test_model2", + kind="FULL", + columns={"id": "INT"}, +) +def execute(context, **kwargs): + table = context.resolve_table(f"{context.var('schema_name')}.test_model1") + select_query = exp.select("*").from_(table) + return context.fetchdf(select_query) ``` +## Returning empty dataframes -## Global variables +Python models may not return an empty dataframe. -[User-defined global variables](../../reference/configuration.md#variables) can be accessed from within the Python model using function arguments, where the name of the argument represents a variable key. For example: +If your model could possibly return an empty dataframe, conditionally `yield` the dataframe or an empty generator instead of `return`ing: -```python linenums="1" hl_lines="9" +```python linenums="1" hl_lines="10-13" +@model( + "my_model.empty_df" +) +def execute( + context: ExecutionContext, +) -> pd.DataFrame: + + [...code creating df...] + + if df.empty: + yield from () + else: + yield df +``` + +## User-defined variables + +[User-defined global variables](../../reference/configuration.md#variables) can be accessed from within the Python model with the `context.var` method. + +For example, this model access the user-defined variables `var` and `var_with_default`. It specifies a default value of `default_value` if `variable_with_default` resolves to a missing value. 
+ +```python linenums="1" hl_lines="11 12" @model( "my_model.name", ) @@ -160,16 +296,18 @@ def execute( start: datetime, end: datetime, execution_time: datetime, - my_var: Optional[str] = None, **kwargs: t.Any, ) -> pd.DataFrame: + var_value = context.var("var") + var_with_default_value = context.var("var_with_default", "default_value") ... ``` -Make sure to assign a default value to such arguments if you anticipate a missing variable key. Please note that arguments must be specified explicitly; in other words, variables can be accessed using `kwargs`. +Alternatively, you can access global variables via `execute` function arguments, where the name of the argument corresponds to the name of a variable key. -Alternatively, variables can be accessed using the `context.var` method. For example: -```python linenums="1" hl_lines="11 12" +For example, this model specifies `my_var` as an argument to the `execute` method. The model code can reference the `my_var` object directly: + +```python linenums="1" hl_lines="9 12" @model( "my_model.name", ) @@ -178,12 +316,117 @@ def execute( start: datetime, end: datetime, execution_time: datetime, + my_var: Optional[str] = None, **kwargs: t.Any, ) -> pd.DataFrame: - var_value = context.var("") - another_var_value = context.var("", "default_value") + my_var_plus1 = my_var + 1 ... ``` + +Make sure the argument has a default value if it's possible for the variable to be missing. + +Note that arguments must be specified explicitly - variables cannot be accessed using `kwargs`. + +## Python model blueprinting + +A Python model can also serve as a template for creating multiple models, or _blueprints_, by specifying a list of key-value dicts in the `blueprints` property. In order to achieve this, the model's name must be parameterized with a variable that exists in this mapping. 
For instance, the following model will result in two new models, each using the corresponding mapping in the `blueprints` property:
+``` + +## Using macros in model properties + +Python models support macro variables in model properties. However, special care must be taken when the macro variable appears within a string. + +For example when using macro variables inside cron expressions, you need to wrap the entire expression in quotes and prefix it with `@` to ensure proper parsing: + +```python linenums="1" +# Correct: Wrap the cron expression containing a macro variable +@model( + "my_model", + cron="@'*/@{mins} * * * *'", # Note the @'...' syntax + ... +) + +# This also works with blueprint variables +@model( + "@{customer}.scheduled_model", + cron="@'0 @{hour} * * *'", + blueprints=[ + {"customer": "customer_1", "hour": 2}, # Runs at 2 AM + {"customer": "customer_2", "hour": 8}, # Runs at 8 AM + ], + ... +) + +``` + +This is necessary because cron expressions often use `@` for aliases (like `@daily`, `@hourly`), which can conflict with SQLMesh's macro syntax. + ## Examples ### Basic The following is an example of a Python model returning a static Pandas DataFrame. 
@@ -195,6 +438,7 @@ import typing as t from datetime import datetime import pandas as pd +from sqlglot.expressions import to_column from sqlmesh import ExecutionContext, model @model( @@ -210,7 +454,7 @@ from sqlmesh import ExecutionContext, model "name": "Name corresponding to the ID", }, audits=[ - ("not_null", {"columns": ["id"]}), + ("not_null", {"columns": [to_column("id")]}), ], ) def execute( @@ -251,7 +495,7 @@ def execute( **kwargs: t.Any, ) -> pd.DataFrame: # get the upstream model's name and register it as a dependency - table = context.table("upstream_model") + table = context.resolve_table("upstream_model") # fetch data from the model as a pandas DataFrame # if the engine is spark, this returns a spark DataFrame @@ -290,7 +534,7 @@ def execute( **kwargs: t.Any, ) -> DataFrame: # get the upstream model's name and register it as a dependency - table = context.table("upstream_model") + table = context.resolve_table("upstream_model") # use the spark DataFrame api to add the country column df = context.spark.table(table).withColumn("country", functions.lit("USA")) @@ -333,6 +577,57 @@ def execute( return df ``` +### Bigframe +This example demonstrates using the [Bigframe](https://cloud.google.com/bigquery/docs/use-bigquery-dataframes#pandas-examples) DataFrame API. If you use Bigquery, the Bigframe API is preferred to Pandas as all computation is done in Bigquery. 
+ +```python linenums="1" +import typing as t +from datetime import datetime + +from bigframes.pandas import DataFrame + +from sqlmesh import ExecutionContext, model + + +def get_bucket(num: int): + if not num: + return "NA" + boundary = 10 + return "at_or_above_10" if num >= boundary else "below_10" + + +@model( + "mart.wiki", + columns={ + "title": "text", + "views": "int", + "bucket": "text", + }, +) +def execute( + context: ExecutionContext, + start: datetime, + end: datetime, + execution_time: datetime, + **kwargs: t.Any, +) -> DataFrame: + # Create a remote function to be used in the Bigframe DataFrame + remote_get_bucket = context.bigframe.remote_function([int], str)(get_bucket) + + # Returns the Bigframe DataFrame handle, no data is computed locally + df = context.bigframe.read_gbq("bigquery-samples.wikipedia_pageviews.200809h") + + df = ( + # This runs entirely on the BigQuery engine lazily + df[df.title.str.contains(r"[Gg]oogle")] + .groupby(["title"], as_index=False)["views"] + .sum(numeric_only=True) + .sort_values("views", ascending=False) + ) + + return df.assign(bucket=df["views"].apply(remote_get_bucket)) +``` + ### Batching If the output of a Python model is very large and you cannot use Spark, it may be helpful to split the output into multiple batches. @@ -355,7 +650,7 @@ def execute( **kwargs: t.Any, ) -> pd.DataFrame: # get the upstream model's table name - table = context.table("upstream_model") + table = context.resolve_table("upstream_model") for i in range(3): # run 3 queries to get chunks of data and not run out of memory diff --git a/docs/concepts/models/seed_models.md b/docs/concepts/models/seed_models.md index bcfd25eca5..6f14960182 100644 --- a/docs/concepts/models/seed_models.md +++ b/docs/concepts/models/seed_models.md @@ -14,6 +14,10 @@ Seed models are a good fit for static datasets that change infrequently or not a * Names of national holidays and their dates * A static list of identifiers that should be excluded +!!! 
warning "Not supported in Python models" + + Python models do not support the `SEED` [model kind](./model_kinds.md) - use a SQL model instead. + ## Creating a seed model Similar to [SQL models](./sql_models.md), `SEED` models are defined in files with the `.sql` extension in the `models/` directory of the SQLMesh project. @@ -94,13 +98,15 @@ Christmas,2023-12-25 ``` When we run the `sqlmesh plan` command, the new seed model is automatically detected: -```bash hl_lines="6-7" +```bash hl_lines="8-9" $ sqlmesh plan ====================================================================== Successfully Ran 0 tests against duckdb ---------------------------------------------------------------------- -Summary of differences against `prod`: -└── Added Models: +`prod` environment will be initialized + +Models +└── Added: └── test_db.national_holidays Models needing backfill (missing dates): └── test_db.national_holidays: (2023-02-16, 2023-02-16) @@ -129,7 +135,9 @@ $ sqlmesh plan ====================================================================== Successfully Ran 0 tests against duckdb ---------------------------------------------------------------------- -Summary of differences against `prod`: +Differences from the `prod` environment: + +Models: └── Directly Modified: └── test_db.national_holidays --- @@ -190,3 +198,34 @@ ALTER SESSION SET TIMEZONE = 'UTC'; -- These are post-statements ALTER SESSION SET TIMEZONE = 'PST'; ``` + +## On-virtual-update statements + +Seed models also support on-virtual-update statements, which are executed after the completion of the [Virtual Update](#virtual-update). + +**Project-level defaults:** You can also define on-virtual-update statements at the project level using `model_defaults` in your configuration. These will be applied to all models in your project (including seed models) and merged with any model-specific statements. Default statements are executed first, followed by model-specific statements. 
Learn more about this in the [model configuration reference](../../reference/model_configuration.md#model-defaults). + +These must be enclosed within an `ON_VIRTUAL_UPDATE_BEGIN;` ...; `ON_VIRTUAL_UPDATE_END;` block: + +```sql linenums="1" hl_lines="8-13" +MODEL ( + name test_db.national_holidays, + kind SEED ( + path 'national_holidays.csv' + ) +); + +ON_VIRTUAL_UPDATE_BEGIN; +GRANT SELECT ON VIEW @this_model TO ROLE dev_role; +JINJA_STATEMENT_BEGIN; +GRANT SELECT ON VIEW {{ this_model }} TO ROLE admin_role; +JINJA_END; +ON_VIRTUAL_UPDATE_END; +``` + + +[Jinja expressions](../macros/jinja_macros.md) can also be used within them, as demonstrated in the example above. These expressions must be properly nested within a `JINJA_STATEMENT_BEGIN;` and `JINJA_END;` block. + +!!! note + + Table resolution for these statements occurs at the virtual layer. This means that table names, including `@this_model` macro, are resolved to their qualified view names. For instance, when running the plan in an environment named `dev`, `db.customers` and `@this_model` would resolve to `db__dev.customers` and not to the physical table name. \ No newline at end of file diff --git a/docs/concepts/models/sql_models.md b/docs/concepts/models/sql_models.md index bd5900f149..217cd7a6a2 100644 --- a/docs/concepts/models/sql_models.md +++ b/docs/concepts/models/sql_models.md @@ -10,6 +10,7 @@ The SQL-based definition of SQL models is the most common one, and consists of t * Optional pre-statements * A single query * Optional post-statements +* Optional on-virtual-update-statements These models are designed to look and feel like you're simply using SQL, but they can be customized for advanced use cases. @@ -62,19 +63,160 @@ Refer to `MODEL` [properties](./overview.md#properties) for the full list of all Optional pre/post-statements allow you to execute SQL commands before and after a model runs, respectively. -For example, post/post-statements might modify settings or create indexes. 
However, be careful not to run any statement that could conflict with the execution of another statement if the models run concurrently, such as creating a physical table. +For example, pre/post-statements might modify settings or create a table index. However, be careful not to run any statement that could conflict with the execution of another model if they are run concurrently, such as creating a physical table. -Pre/post-statements are evaluated twice: when a model's table is created and when its query logic is evaluated. Since executing such statements more than once can have unintended side-effects, it is also possible to [conditionally execute](../macros/sqlmesh_macros.md#if) them depending on SQLMesh's [runtime stage](../macros/macro_variables.md#predefined-variables). +Pre/post-statements are just standard SQL commands located before/after the model query. They must end with a semi-colon, and the model query must end with a semi-colon if a post-statement is present. The [example above](#example) contains both pre- and post-statements. + +**Project-level defaults:** You can also define pre/post-statements at the project level using `model_defaults` in your configuration. These will be applied to all models in your project and merged with any model-specific statements. Default statements are executed first, followed by model-specific statements. Learn more about this in the [model configuration reference](../../reference/model_configuration.md#model-defaults). + +!!! warning + + Pre/post-statements are evaluated twice: when a model's table is created and when its query logic is evaluated. Executing statements more than once can have unintended side-effects, so you can [conditionally execute](../macros/sqlmesh_macros.md#prepost-statements) them based on SQLMesh's [runtime stage](../macros/macro_variables.md#runtime-variables). + +The pre/post-statements in the [example above](#example) will run twice because they are not conditioned on runtime stage. 
+ +We can condition the post-statement to only run after the model query is evaluated using the [`@IF` macro operator](../macros/sqlmesh_macros.md#if) and [`@runtime_stage` macro variable](../macros/macro_variables.md#runtime-variables) like this: + +```sql linenums="1" hl_lines="8-11" +MODEL ( + name db.customers, + kind FULL, +); + +[...same as example above...] + +@IF( + @runtime_stage = 'evaluating', + UNCACHE TABLE countries +); +``` + +Note that the SQL command `UNCACHE TABLE countries` inside the `@IF()` macro does **not** end with a semi-colon. Instead, the semi-colon comes after the `@IF()` macro's closing parenthesis. + +### Optional on-virtual-update statements + +The optional on-virtual-update statements allow you to execute SQL commands after the completion of the [Virtual Update](#virtual-update). + +These can be used, for example, to grant privileges on views of the virtual layer. + +**Project-level defaults:** You can also define on-virtual-update statements at the project level using `model_defaults` in your configuration. These will be applied to all models in your project and merged with any model-specific statements. Default statements are executed first, followed by model-specific statements. Learn more about this in the [model configuration reference](../../reference/model_configuration.md#model-defaults). + +These SQL statements must be enclosed within an `ON_VIRTUAL_UPDATE_BEGIN;` ...; `ON_VIRTUAL_UPDATE_END;` block like this: + +```sql linenums="1" hl_lines="10-15" +MODEL ( + name db.customers, + kind FULL +); + +SELECT + r.id::INT +FROM raw.restaurants AS r; + +ON_VIRTUAL_UPDATE_BEGIN; +GRANT SELECT ON VIEW @this_model TO ROLE role_name; +JINJA_STATEMENT_BEGIN; +GRANT SELECT ON VIEW {{ this_model }} TO ROLE admin; +JINJA_END; +ON_VIRTUAL_UPDATE_END; +``` + +[Jinja expressions](../macros/jinja_macros.md) can also be used within them, as demonstrated in the example above. 
These expressions must be properly nested within a `JINJA_STATEMENT_BEGIN;` and `JINJA_END;` block. + +!!! note + + Table resolution for these statements occurs at the virtual layer. This means that table names, including `@this_model` macro, are resolved to their qualified view names. For instance, when running the plan in an environment named `dev`, `db.customers` and `@this_model` would resolve to `db__dev.customers` and not to the physical table name. ### The model query The model must contain a standalone query, which can be a single `SELECT` expression, or multiple `SELECT` expressions combined with the `UNION`, `INTERSECT`, or `EXCEPT` operators. The result of this query will be used to populate the model's table or view. +### SQL model blueprinting + +A SQL model can also serve as a template for creating multiple models, or _blueprints_, by specifying a list of key-value mappings in the `blueprints` property. In order to achieve this, the model's name must be parameterized with a variable that exists in this mapping. 
For instance, the following model will result in two new models, each using the corresponding mapping in the `blueprints` property:
+ +For example, the definition of the `gen_blueprints` may look like this: + +```python linenums="1" +from sqlmesh import macro + +@macro() +def gen_blueprints(evaluator): + return ( + "((customer := customer1, field_a := x, field_b := y)," + " (customer := customer2, field_a := z, field_b := w))" + ) +``` + +It's also possible to use the `@EACH` macro, combined with a global list variable (`@values`): + +```sql linenums="1" +MODEL ( + name @customer.some_table, + kind FULL, + blueprints @EACH(@values, x -> (customer := schema_@x)), +); + +SELECT + 1 AS c +``` + ## Python-based definition The Python-based definition of SQL models consists of a single python function, decorated with SQLMesh's `@model` [decorator](https://wiki.python.org/moin/PythonDecorators). The decorator is required to have the `is_sql` keyword argument set to `True` to distinguish it from [Python models](./python_models.md) that return DataFrame instances. -This function's return value serves as the model's query, and it must be either a SQL string or a [SQLGlot expression](https://github.com/tobymao/sqlglot/blob/main/sqlglot/expressions.py). The `@model` decorator is used to define the model's [metadata](#MODEL-DDL) and, optionally its pre/post-statements that are also in the form of SQL strings or SQLGlot expressions. +This function's return value serves as the model's query, and it must be either a SQL string or a [SQLGlot expression](https://github.com/tobymao/sqlglot/blob/main/sqlglot/expressions.py). The `@model` decorator is used to define the model's [metadata](#MODEL-DDL) and, optionally its pre/post-statements or on-virtual-update-statements that are also in the form of SQL strings or SQLGlot expressions. Defining a SQL model using Python can be beneficial in cases where its query is too complex to express cleanly in SQL, for example due to having many dynamic components that would require heavy use of [macros](../macros/overview/). 
Since Python-based models generate SQL, they support the same features as regular SQL models, such as column-level [lineage](../glossary/#lineage). @@ -88,7 +230,7 @@ The following example demonstrates how the above `db.customers` model can be def from sqlglot import exp from sqlmesh.core.model import model -from sqlmesh.core.macro import MacroEvaluator +from sqlmesh.core.macros import MacroEvaluator @model( "db.customers", @@ -96,6 +238,7 @@ from sqlmesh.core.macro import MacroEvaluator kind="FULL", pre_statements=["CACHE TABLE countries AS SELECT * FROM raw.countries"], post_statements=["UNCACHE TABLE countries"], + on_virtual_update=["GRANT SELECT ON VIEW @this_model TO ROLE dev_role"], ) def entrypoint(evaluator: MacroEvaluator) -> str | exp.Expression: return ( @@ -107,13 +250,78 @@ def entrypoint(evaluator: MacroEvaluator) -> str | exp.Expression: One could also define this model by simply returning a string that contained the SQL query of the SQL-based example. Strings used as pre/post-statements or return values in Python-based models will be parsed into SQLGlot expressions, which means that SQLMesh will still be able to understand them semantically and thus provide information such as column-level lineage. -**Note:** Since python models have access to the macro evaluation context (`MacroEvaluator`), they can also [access model schemas](../macros/sqlmesh_macros.md#accessing-model-schemas) through its `columns_to_types` method. +!!! note + + Since python models have access to the macro evaluation context (`MacroEvaluator`), they can also [access model schemas](../macros/sqlmesh_macros.md#accessing-model-schemas) through its `columns_to_types` method. ### `@model` decorator -The `@model` decorator is the Python equivalent of the `MODEL` DDL. 
In addition to model metadata and configuration information, one can also set the keyword arguments `pre_statements` and `post_statements` to a list of SQL strings and/or SQLGlot expressions to define the pre/post-statements of the model, respectively. +The `@model` decorator is the Python equivalent of the `MODEL` DDL. + +In addition to model metadata and configuration information, one can also set the keyword arguments `pre_statements`, `post_statements` and `on_virtual_update` to a list of SQL strings and/or SQLGlot expressions to define the pre/post-statements and on-virtual-update-statements of the model, respectively. + +!!! note + + All of the [metadata property](./overview.md#model-properties) field names are the same as those in the `MODEL` DDL. + +### Python model blueprinting + +A Python-based SQL model can also serve as a template for creating multiple models, or _blueprints_, by specifying a list of key-value dicts in the `blueprints` property. In order to achieve this, the model's name must be parameterized with a variable that exists in this mapping. + +For instance, the following model will result into two new models, each using the corresponding mapping in the `blueprints` property: + +```python linenums="1" +from sqlglot import exp + +from sqlmesh.core.model import model +from sqlmesh.core.macros import MacroEvaluator + +@model( + "@{customer}.some_table", + is_sql=True, + kind="FULL", + blueprints=[ + {"customer": "customer1", "field_a": "x", "field_b": "y"}, + {"customer": "customer2", "field_a": "z", "field_b": "w"}, + ], +) +def entrypoint(evaluator: MacroEvaluator) -> str | exp.Expression: + field_a = evaluator.blueprint_var("field_a") + field_b = evaluator.blueprint_var("field_b") + customer = evaluator.blueprint_var("customer") + + return exp.select(field_a, field_b).from_(f"{customer}.some_source") +``` + +The two models produced from this template are the same as in the [example](#SQL-model-blueprinting) for SQL-based blueprinting. 
+ +Blueprint variable mappings can also be constructed dynamically, e.g., by using a macro: `blueprints="@gen_blueprints()"`. This is useful in cases where the `blueprints` list needs to be sourced from external sources, such as CSV files. + +For example, the definition of the `gen_blueprints` may look like this: -**Note:** All of the [metadata](./overview.md#properties) field names are the same as those in the `MODEL` DDL. +```python linenums="1" +from sqlmesh import macro + +@macro() +def gen_blueprints(evaluator): + return ( + "((customer := customer1, field_a := x, field_b := y)," + " (customer := customer2, field_a := z, field_b := w))" + ) +``` + +It's also possible to use the `@EACH` macro, combined with a global list variable (`@values`): + +```python linenums="1" + +@model( + "@{customer}.some_table", + is_sql=True, + blueprints="@EACH(@values, x -> (customer := schema_@x))", + ... +) +... +``` ## Automatic dependencies @@ -130,7 +338,7 @@ JOIN countries SQLMesh will detect that the model depends on both `employees` and `countries`. When executing this model, it will ensure that `employees` and `countries` are executed first. -External dependencies not defined in SQLMesh are also supported. SQLMesh can either depend on them implicitly through the order in which they are executed, or through signals if you are using [Airflow](../../integrations/airflow.md). +External dependencies not defined in SQLMesh are also supported. SQLMesh can either depend on them implicitly through the order in which they are executed, or through [signals](../../guides/signals.md). Although automatic dependency detection works most of the time, there may be specific cases for which you want to define dependencies manually. You can do so in the `MODEL` DDL with the [dependencies property](./overview.md#properties). 
diff --git a/docs/concepts/overview.md b/docs/concepts/overview.md index 14dab2c4ec..fcfd8cf248 100644 --- a/docs/concepts/overview.md +++ b/docs/concepts/overview.md @@ -61,13 +61,11 @@ You create audits by writing SQL queries that should return 0 rows. For example, Audits are flexible — they can be tied to a specific model's contents, or you can use [macros](./macros/overview.md) to create audits that are usable by multiple models. SQLMesh also includes pre-made audits for common use cases, such as detecting NULL or duplicated values. -You specify which audits should run for a model by including them in the model's metadata properties. +You specify which audits should run for a model by including them in the model's metadata properties. To apply them globally across your project, include them in the model defaults configuration. SQLMesh automatically runs audits when you apply a `plan` to an environment, or you can run them on demand with the [`audit` command](../reference/cli.md#audit). ## Infrastructure and orchestration Every company's data infrastructure is different. SQLMesh is flexible with regard to which engines and orchestration frameworks you use — its only requirement is access to the target SQL/analytics engine. -SQLMesh keeps track of model versions and processed data intervals using your existing infrastructure. If SQLMesh is configured without an external orchestrator (such as Airflow), it automatically creates a `sqlmesh` database in your data warehouse for its internal metadata. - -If SQLMesh is configured with Airflow, then it will store all its metadata in the Airflow database. Read more about how [SQLMesh integrates with Airflow](../integrations/airflow.md). +SQLMesh keeps track of model versions and processed data intervals using your existing infrastructure. SQLMesh it automatically creates a `sqlmesh` schema in your data warehouse for its internal metadata. 
diff --git a/docs/concepts/plans.md b/docs/concepts/plans.md index 1063c91953..defcd06c0d 100644 --- a/docs/concepts/plans.md +++ b/docs/concepts/plans.md @@ -39,14 +39,9 @@ Choose this option when a change has been made to a model's logic that has a fun ### Non-breaking change A directly-modified model that is classified as non-breaking will be backfilled, but its downstream dependencies will not. -This is a common choice in scenarios such as an addition of a new column, an action which doesn't affect downstream models, as new columns can't be used by downstream models without modifying them directly to select the column. If any downstream models contain a `select *` from the model, SQLMesh attempts to infer breaking status on a best-effort basis. We recommend explicitly specifying a query's columns to avoid unnecessary recomputation. +This is a common choice in scenarios such as an addition of a new column, an action which doesn't affect downstream models, as new columns can't be used by downstream models without modifying them directly to select the column. -### Forward-only change -A modified (either directly or indirectly) model that is categorized as forward-only will continue to use the existing physical table once the change is deployed to production (the `prod` environment). This means that no backfill will take place. - -While iterating on forward-only changes in the development environment, the model's output will be stored in either a temporary table or a shallow clone of the production table if supported by the engine. In either case the data produced this way in the development environment can only be used for preview and will **not** be reused once the change is deployed to production. See [Forward-only Plans](#forward-only-plans) for more details. - -This category is assigned by SQLMesh automatically either when a user opts into using a [forward-only plan](#forward-only-plans) or when a model is explicitly configured to be forward-only. 
+If any downstream models contain a `select *` from the model, SQLMesh attempts to infer breaking status on a best-effort basis. We recommend explicitly specifying a query's columns to avoid unnecessary recomputation. ### Summary @@ -55,7 +50,17 @@ This category is assigned by SQLMesh automatically either when a user opts into | [Breaking](#breaking-change) | [Direct](glossary.md#direct-modification) or [Indirect](glossary.md#indirect-modification) | [Backfill](glossary.md#backfill) | | [Non-breaking](#non-breaking-change) | [Direct](glossary.md#direct-modification) | [Backfill](glossary.md#backfill) | | [Non-breaking](#non-breaking-change) | [Indirect](glossary.md#indirect-modification) | [No Backfill](glossary.md#backfill) | -| [Forward-only](#forward-only-change) | [Direct](glossary.md#direct-modification) or [Indirect](glossary.md#indirect-modification) | [No Backfill](glossary.md#backfill), schema change | + +## Forward-only change +In addition to categorizing a change as breaking or non-breaking, it can also be classified as forward-only. + +A model change classified as forward-only will continue to use the existing physical table once the change is deployed to production (the `prod` environment). This means that no backfill will take place. + +While iterating on forward-only changes in the development environment, the model's output will be stored in either a temporary table or a shallow clone of the production table if supported by the engine. + +In either case the data produced this way in the development environment can only be used for preview and will **not** be reused once the change is deployed to production. See [Forward-only Plans](#forward-only-plans) for more details. + +This category is assigned by SQLMesh automatically either when a user opts into using a [forward-only plan](#forward-only-plans) or when a model is explicitly configured to be forward-only. 
## Plan application Once a plan has been created and reviewed, it is then applied to the target [environment](environments.md) in order for its changes to take effect. @@ -68,12 +73,18 @@ When a plan is applied to an environment, the environment gets associated with t *Each model variant gets its own physical table while environments only contain references to these tables.* -This unique approach to understanding and applying changes is what enables SQLMesh's Virtual Environments. This technology allows SQLMesh to ensure complete isolation between environments while allowing it to share physical data assets between environments when appropriate and safe to do so. Additionally, since each model change is captured in a separate physical table, reverting to a previous version becomes a simple and quick operation (refer to [Virtual Update](#virtual-update)) as long as its physical table hasn't been garbage collected by the janitor process. SQLMesh makes it easy to be correct and really hard to accidentally and irreversibly break things. +This unique approach to understanding and applying changes is what enables SQLMesh's Virtual Environments. It allows SQLMesh to ensure complete isolation between environments while allowing it to share physical data assets between environments when appropriate and safe to do so. + +Additionally, since each model change is captured in a separate physical table, reverting to a previous version becomes a simple and quick operation (refer to [Virtual Update](#virtual-update)) as long as its physical table hasn't been garbage collected by the janitor process. + +SQLMesh makes it easy to be correct and really hard to accidentally and irreversibly break things. ### Backfilling -Despite all the benefits, the approach described above is not without trade-offs. When a new model version is just created, a physical table assigned to it is empty. 
Therefore, SQLMesh needs to re-apply the logic of the new model version to the entire date range of this model in order to populate the new version's physical table. This process is called backfilling. +Despite all the benefits, the approach described above is not without trade-offs. -At the moment, we are using the term backfilling broadly to describe any situation in which a model is updated. That includes these operations: +When a new model version is just created, a physical table assigned to it is empty. Therefore, SQLMesh needs to re-apply the logic of the new model version to the entire date range of this model in order to populate the new version's physical table. This process is called backfilling. + +We use the term backfilling broadly to describe any situation in which a model is updated. That includes these operations: * When a VIEW model is created * When a FULL model is built @@ -81,16 +92,224 @@ At the moment, we are using the term backfilling broadly to describe any situati * When an INCREMENTAL model has recent data appended to it * When an INCREMENTAL model has older data inserted (i.e., resolving a data gap or prepending historical data) -We will be iterating on terminology to better capture the nuances of each type in future versions. +Note for incremental models: despite the fact that backfilling can happen incrementally (see `batch_size` parameter on models), there is an extra cost associated with this operation due to additional runtime involved. If the runtime cost is a concern, use a [forward-only plan](#forward-only-plans) instead. + +### Virtual Update +A benefit of SQLMesh's approach is that data for a new model version can be fully pre-built while still in a development environment. That way all changes and their downstream dependencies can be fully previewed before they are promoted to the production environment. + +With this approach, the process of promoting a change to production is reduced to reference swapping. 
+ +If during plan creation no data gaps have been detected and only references to new model versions need to be updated, then the update is referred to as a Virtual Update. Virtual Updates impose no additional runtime overhead or cost. + +### Start and end dates + +The `plan` command provides two temporal options: `--start` and `--end`. These options are only applicable to plans for non-prod environments. + +For context, every model has a start date. The start can be specified in [the model definition](./models/overview.md#start), in the [project configuration's `model_defaults`](../guides/configuration.md#model-defaults), or by SQLMesh's default value of yesterday. + +Because the prod environment supports business operations, prod plans ensure every model is backfilled from its start date until the most recent completed time interval. Due to that restriction, the `plan` command's `--start` and `--end` options are not supported for regular plans against prod. The options are supported for [restatement plans](#restatement-plans) against prod to allow re-processing a subset of existing data. + +Non-prod plans are typically used for development, so their models can optionally be backfilled for any date range with the `--start` and `--end` options. Limiting the date range makes backfills faster and development more efficient, especially for incremental models using large tables. + +#### Model kind limitations + +Some model kinds do not support backfilling a limited date range. + +For context, SQLMesh strives to make models _idempotent_, meaning that if we ran them multiple times we would get the same correct result every time. 
+ +However, some model kinds are inherently non-idempotent: + +- [INCREMENTAL_BY_UNIQUE_KEY](models/model_kinds.md#incremental_by_unique_key) +- [INCREMENTAL_BY_PARTITION](models/model_kinds.md#incremental_by_partition) +- [SCD_TYPE_2_BY_TIME](models/model_kinds.md#scd-type-2-by-time-recommended) +- [SCD_TYPE_2_BY_COLUMN](models/model_kinds.md#scd-type-2-by-column) +- Any model whose query is self-referential (i.e., the contents of new data rows are affected by the data rows already present in the table) + +Those model kinds will behave as follows in a non-prod plan that specifies a limited date range: + +- If the `--start` option date is the same as or before the model's start date, the model is fully refreshed for all of time +- If the `--start` option date is after the model's start date, only a preview is computed for this model which can't be reused when deploying to production + +#### Example + +Consider a SQLMesh project with a default start date of 2024-09-20. + +It contains the following `INCREMENTAL_BY_UNIQUE_KEY` model that specifies an explicit start date of 2024-09-23: + +```sql linenums="1" hl_lines="6" +MODEL ( + name sqlmesh_example.start_end_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key item_id + ), + start '2024-09-23' +); + +SELECT + item_id, + num_orders +FROM + sqlmesh_example.full_model +``` + +When we run the project's first plan, we see that SQLMesh correctly detected a different start date for our `start_end_model` than the other models (which have the project default start of 2024-09-20): + +```bash linenums="1" hl_lines="17" +❯ sqlmesh plan +====================================================================== +Successfully Ran 1 tests against duckdb +---------------------------------------------------------------------- +`prod` environment will be initialized + +Models: +└── Added: + ├── sqlmesh_example.full_model + ├── sqlmesh_example.incremental_model + ├── sqlmesh_example.seed_model + └── sqlmesh_example.start_end_model 
+Models needing backfill (missing dates): +├── sqlmesh_example.full_model: 2024-09-20 - 2024-09-26 +├── sqlmesh_example.incremental_model: 2024-09-20 - 2024-09-26 +├── sqlmesh_example.seed_model: 2024-09-20 - 2024-09-26 +└── sqlmesh_example.start_end_model: 2024-09-23 - 2024-09-26 +Apply - Backfill Tables [y/n]: +``` + +After executing that plan, we add columns to both the `incremental_model` and `start_end_model` queries. + +We then execute `sqlmesh plan dev` to create the new `dev` environment: + +```bash linenums="1" hl_lines="23-26" + +❯ sqlmesh plan dev +====================================================================== +Successfully Ran 1 tests against duckdb +---------------------------------------------------------------------- +New environment `dev` will be created from `prod` + +Differences from the `prod` environment: + +Models: +├── Directly Modified: +│ ├── sqlmesh_example__dev.start_end_model +│ └── sqlmesh_example__dev.incremental_model +└── Indirectly Modified: + └── sqlmesh_example__dev.full_model + +[...model diff omitted...] + +Directly Modified: sqlmesh_example__dev.incremental_model (Non-breaking) +└── Indirectly Modified Children: + └── sqlmesh_example__dev.full_model (Indirect Non-breaking) + +[...model diff omitted...] + +Directly Modified: sqlmesh_example__dev.start_end_model (Non-breaking) +Models needing backfill (missing dates): +├── sqlmesh_example__dev.incremental_model: 2024-09-20 - 2024-09-26 +└── sqlmesh_example__dev.start_end_model: 2024-09-23 - 2024-09-26 +Enter the backfill start date (eg. '1 year', '2020-01-01') or blank to backfill from the beginning of history: +``` + +Note two things about the output: -Note for incremental models: despite the fact that backfilling can happen incrementally (see `batch_size` parameter on models), there is an extra cost associated with this operation due to additional runtime involved. If the runtime cost is a concern, a [forward-only plan](#forward-only-plans) can be used instead. +1. 
As before, SQLMesh displays the complete backfill time range for each model, using the project default start of 2024-09-20 for `incremental_model` and 2024-09-23 for `start_end_model` +2. SQLMesh prompted us for a backfill start date because we didn't pass the `--start` option to the `sqlmesh plan dev` command + +Let's cancel that plan and start a new one, passing a start date of 2024-09-24. + +The `start_end_model` is of kind `INCREMENTAL_BY_UNIQUE_KEY`, which is non-idempotent and cannot be backfilled for a limited time range. + +Because the command's `--start` of 2024-09-24 is after `start_end_model`'s start date 2024-09-23, `start_end_model` is marked as preview: + +``` bash linenums="1" hl_lines="12-13 20-21" +❯ sqlmesh plan dev --start 2024-09-24 +====================================================================== +Successfully Ran 1 tests against duckdb +---------------------------------------------------------------------- +New environment `dev` will be created from `prod` + +Differences from the `prod` environment: + +Models: +├── Directly Modified: +│ ├── sqlmesh_example__dev.start_end_model +│ └── sqlmesh_example__dev.incremental_model +└── Indirectly Modified: + └── sqlmesh_example__dev.full_model + +[...model diff omitted...] + +Directly Modified: sqlmesh_example__dev.start_end_model (Non-breaking) +Models needing backfill (missing dates): +├── sqlmesh_example__dev.incremental_model: 2024-09-24 - 2024-09-26 +└── sqlmesh_example__dev.start_end_model: 2024-09-24 - 2024-09-26 (preview) +Enter the backfill end date (eg. '1 month ago', '2020-01-01') or blank to backfill up until '2024-09-27 00:00:00': +``` + +#### Minimum intervals + +When you run a plan with a fixed `--start` or `--end` date, you create a virtual data environment with a limited subset of data. However, if the time range specified is less than the size of an interval on one of your models, that model will be skipped by default. 
+ +For example, if you have a model like so: + +```sql +MODEL( + name sqlmesh_example.monthly_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column month + ), + cron '@monthly' +); + +SELECT SUM(a) AS sum_a, MONTH(day) AS month +FROM sqlmesh_example.upstream_model +WHERE day BETWEEN @start_ds AND @end_ds +``` -#### Data preview -As mentioned earlier, the data output produced by [forward-only changes](#forward-only-change) in the development environment can only be used for preview and will not be reused upon deployment to production. +make a change to it and run the following: + +```bash linenums="1" hl_lines="8" +$ sqlmesh plan dev --start '1 day ago' + +Models: +└── Added: + └── sqlmesh_example__dev.monthly_model +Apply - Virtual Update [y/n]: y + +SKIP: No model batches to execute +``` + +No data will be backfilled because `1 day ago` does not contain a complete month. However, you can use the `--min-intervals` option to override this behaviour like so: + +```bash linenums="1" hl_lines="11" +$ sqlmesh plan dev --start '1 day ago' --min-intervals 1 + +Models: +└── Added: + └── sqlmesh_example__dev.monthly_model +Apply - Virtual Update [y/n]: y + +[1/1] sqlmesh_example__dev.monthly_model [insert 2025-06-01 - 2025-06-30] 0.08s +Executing model batches ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 1/1 • 0:00:00 + +✔ Model batches executed +``` + +This will ensure that regardless of the plan `--start` date, all added or modified models will have at least `--min-intervals` intervals considered for backfill. + +!!! info + + If you are running plans manually you can just adjust the `--start` date to be wide enough to cover the models in question. + + The `--min-intervals` option is primarily intended for [automation scenarios](../integrations/github.md) where the plan is always run with a default relative start date and you always want (for example) "2 weeks worth of data" in the target environment. 
+ +### Data preview for forward-only changes +As mentioned earlier, the data output produced by [forward-only changes](#forward-only-change) in a development environment can only be used for preview and will not be reused in production. The same holds true for any subsequent changes that depend on undeployed forward-only changes - data can be previewed but can't be reused in production. -Backfills that are exclusively for preview purposes and will not be reused upon deployment to production are explicitly labeled as such in the plan summary: +Backfills that are exclusively for preview purposes and will not be reused upon deployment to production are explicitly labeled with `(preview)` in the plan summary: ```bash Models needing backfill (missing dates): ├── sushi__dev.customers: 2023-12-22 - 2023-12-28 (preview) @@ -99,19 +318,18 @@ Models needing backfill (missing dates): └── sushi__dev.waiter_as_customer_by_day: 2023-12-22 - 2023-12-28 (preview) ``` -### Virtual Update -Another benefit of the SQLMesh approach is that data for a new model version can be fully pre-built while still in a development environment. That way all changes and their downstream dependencies can be fully previewed before they are promoted to the production environment. - -With this approach, the process of promoting a change to production is reduced to reference swapping. If during plan creation no data gaps have been detected and only references to new model versions need to be updated, then the update is referred to as a Virtual Update. Virtual Updates impose no additional runtime overhead or cost. - ## Forward-only plans Sometimes the runtime cost associated with rebuilding an entire physical table is too high and outweighs the benefits a separate table provides. This is when a forward-only plan comes in handy. -When a forward-only plan is applied to the `prod` environment, none of the plan's changed models will have new physical tables created for them. 
Instead, physical tables from previous model versions are reused. The benefit of this is that no backfilling is required, so there is no runtime overhead or cost. The drawback is that reverting to a previous version is no longer simple and requires a combination of additional forward-only changes and [restatements](#restatement-plans). +When a forward-only plan is applied to the `prod` environment, none of the plan's changed models will have new physical tables created for them. Instead, physical tables from previous model versions are reused. + +The benefit of this is that no backfilling is required, so there is no runtime overhead or cost. The drawback is that reverting to a previous version is no longer simple and requires a combination of additional forward-only changes and [restatements](#restatement-plans). Note that once a forward-only change is applied to `prod`, all development environments that referred to the previous versions of the updated models will be impacted. -A core component of the development process is to execute code and verify its behavior. To enable this while preserving isolation between environments, `sqlmesh plan [environment name]` evaluates code in non-`prod` environments while targeting shallow (a.k.a. "zero-copy") clones of production tables for engines that support them or newly created temporary physical tables for engines that don't. This means that only a limited preview of changes is available in the development environment before the change is promoted to `prod`. The date range of the preview is provided as part of plan creation command. +A core component of the development process is to execute code and verify its behavior. To enable this while preserving isolation between environments, `sqlmesh plan [environment name]` evaluates code in non-`prod` environments while targeting shallow (a.k.a. 
"zero-copy") clones of production tables for engines that support them or newly created temporary physical tables for engines that don't. + +This means that only a limited preview of changes is available in the development environment before the change is promoted to `prod`. The date range of the preview is provided as part of plan creation command. Engines for which table cloning is supported include: @@ -126,7 +344,10 @@ To create a forward-only plan, add the `--forward-only` option to the `plan` com sqlmesh plan [environment name] --forward-only ``` -**Note:** The `--forward-only` flag is not required when applying changes to models that have been explicitly configured as [forward-only](models/overview.md#forward_only). Use it only if you need to provide a time range for the preview window or the [effective date](#effective-date). +!!! note + The `--forward-only` flag is not required when applying changes to models that have been explicitly configured as [forward-only](models/overview.md#forward_only). + + Use it only if you need to provide a time range for the preview window or the [effective date](#effective-date). ### Destructive changes @@ -134,30 +355,71 @@ Some model changes destroy existing data in a table. SQLMesh automatically detec Forward-only plans treats all of the plan's model changes as forward-only. In these plans, SQLMesh will check all modified incremental models for destructive schema changes, not just forward-only models. -SQLMesh determines what to do for each model based on this setting hierarchy: the [model's `on_destructive_change` value](../guides/incremental_time.md#destructive-changes) (if present), the `on_destructive_change` [model defaults](../reference/model_configuration.md#model-defaults) value (if present), and the SQLMesh global default of `error`. 
+SQLMesh determines what to do for each model based on this setting hierarchy: + +- **For destructive changes**: the [model's `on_destructive_change` value](../guides/incremental_time.md#schema-changes) (if present), the `on_destructive_change` [model defaults](../reference/model_configuration.md#model-defaults) value (if present), and the SQLMesh global default of `error` +- **For additive changes**: the [model's `on_additive_change` value](../guides/incremental_time.md#schema-changes) (if present), the `on_additive_change` [model defaults](../reference/model_configuration.md#model-defaults) value (if present), and the SQLMesh global default of `allow` + +If you want to temporarily allow destructive changes to models that don't allow them, use the `plan` command's `--allow-destructive-model` selector to specify which models. +Similarly, if you want to temporarily allow additive changes to models configured with `on_additive_change=error`, use the `--allow-additive-model` selector. + +For example, to allow destructive changes to all models in the `analytics` schema: +```bash +sqlmesh plan --forward-only --allow-destructive-model "analytics.*" +``` + +Or to allow destructive changes to multiple specific models: +```bash +sqlmesh plan --forward-only --allow-destructive-model "sales.revenue_model" --allow-destructive-model "marketing.campaign_model" +``` -If you want to temporarily allow destructive changes to models that don't allow them, use the `plan` command's `--allow-destructive-change` selector to specify which models. Learn more about model selectors [here](../guides/model_selection.md). +Learn more about model selectors [here](../guides/model_selection.md). 
### Effective date Changes that are part of the forward-only plan can also be applied retroactively to the production environment by specifying the effective date: + ```bash sqlmesh plan --forward-only --effective-from 2023-01-01 ``` + This way SQLMesh will know to recompute data intervals starting from the specified date once forward-only changes are deployed to production. ## Restatement plans -There are cases when models need to be re-evaluated for a given time range, even though changes may not have been made to those model definitions. This could be due to an upstream issue with a dataset defined outside the SQLMesh platform, or when a [forward-only plan](#forward-only-plans) change needs to be applied retroactively to a bounded interval of historical data. -For this reason, the `plan` command supports the `--restate-model`, which allows users to specify one or more names of a model or model tag (using `tag:` syntax) to be reprocessed. These can also refer to an external table defined outside SQLMesh. +Models sometimes need to be re-evaluated for a given time range, even though the model definition has not changed. -Application of a plan will trigger a cascading backfill for all specified models (other than external tables), as well as all models downstream from them. The plan's date range determines the data intervals that will be affected. +For example, these scenarios all require re-evaluating model data that already exists: -Please note that models of kinds [INCREMENTAL_BY_UNIQUE_KEY](models/model_kinds.md#INCREMENTAL_BY_UNIQUE_KEY), [SCD_TYPE_2_BY_TIME](models/model_kinds.md#scd-type-2), and [SCD_TYPE_2_BY_COLUMN](models/model_kinds.md#scd-type-2) cannot be partially restated. Therefore, such models will be fully refreshed regardless of the start/end dates provided by a user in the plan. 
+- Correcting an upstream data issue by reprocessing some of a model's existing data +- Retroactively applying a [forward-only plan](#forward-only-plans) change to some historical data +- Fully refreshing a model -To prevent models from ever being restated, set the [disable_restatement](models/overview.md#disable_restatement) attribute to `true`. +In SQLMesh, reprocessing existing data is called a "restatement." + +Restate one or more models' data with the `plan` command's `--restate-model` selector. The [selector](../guides/model_selection.md) lets you specify which models to restate by name, wildcard, or tag (syntax [below](#restatement-examples)). + +!!! warning "No changes allowed" -See examples below for how to restate both based on model names and model tags. + Unlike regular plans, restatement plans ignore changes to local files. They can only restate the model versions already in the target environment. + You cannot restate a new model - it must already be present in the target environment. If it's not, add it first by running `sqlmesh plan` without the `--restate-model` option. + +Applying a restatement plan will trigger a cascading backfill for all selected models, as well as all models downstream from them. Models with restatement disabled will be skipped and not backfilled. + +You may restate external models. An [external model](./models/external_models.md) is just metadata about an external table, so the model does not actually reprocess anything. Instead, it triggers a cascading backfill of all downstream models. + +The plan's `--start` and `--end` date options determine which data intervals will be reprocessed. Some model kinds cannot be backfilled for limited date ranges, though - learn more [below](#model-kind-limitations). + +!!! info "Just catching up" + + Restatement plans "catch models up" to the latest time interval already processed in the environment. 
They cannot process additional intervals because the required data has not yet been processed upstream. + + If you pass an `--end` date later than the environment's most recent time interval, SQLMesh will just catch up to the environment and will ignore any additional intervals. + +To prevent models from ever being restated, set the [disable_restatement](models/overview.md#disable_restatement) attribute to `true`. + + +These examples demonstrate how to select which models to restate based on model names or model tags. === "Names Only" @@ -169,7 +431,7 @@ See examples below for how to restate both based on model names and model tags. ```bash # All selected models (including upstream models) will also include their downstream models - sqlmesh plan --restate-model "+db.model_a" --restate-model "tag:+expensive" + sqlmesh plan --restate-model "+db.model_a" --restate-model "+tag:expensive" ``` === "Wildcards" @@ -181,5 +443,50 @@ See examples below for how to restate both based on model names and model tags. === "Upstream + Wildcards" ```bash - sqlmesh plan --restate-model "+db*" --restate-model "tag:+exp*" + sqlmesh plan --restate-model "+db*" --restate-model "+tag:exp*" + ``` + +=== "Specific Date Range" + + ```bash + sqlmesh plan --restate-model "db.model_a" --start "2024-01-01" --end "2024-01-10" ``` + +### Restating production vs development + +Restatement plans behave differently depending on if you're targeting the `prod` environment or a [development environment](./environments.md#how-to-use-environments). + +If you target a development environment by including an environment name like `dev`: + +```bash +sqlmesh plan dev --restate-model "db.model_a" --start "2024-01-01" --end "2024-01-10" +``` + +the restatement plan will restate the requested intervals for the specified model in the `dev` environment. In other environments, the model will be unaffected. 
+ +However, if you target the `prod` environment by omitting an environment name: + +```bash +sqlmesh plan --restate-model "db.model_a" --start "2024-01-01" --end "2024-01-10" +``` + +the restatement plan will restate the intervals in the `prod` table *and clear the model's time intervals from state in every other environment*. + +The next time you do a run in `dev`, the intervals already reprocessed in `prod` are reprocessed in `dev` as well. This is to prevent old data from getting promoted to `prod` in the future. + +This behavior also clears the affected intervals for downstream tables that only exist in development environments. Consider the following example: + + - Table `A` exists in `prod` + - A virtual environment `dev` is created with new tables `B` and `C` downstream of `A` + - the DAG in `prod` looks like `A` + - the DAG in `dev` looks like `A <- B <- C` + - A restatement plan is executed against table `A` in `prod` + - SQLMesh will clear the affected intervals for `B` and `C` in `dev` even though those tables do not exist in `prod` + +!!! info "Bringing development environments up to date" + + A restatement plan against `prod` clears time intervals from state for models in development environments, but it does not trigger a run to reprocess those intervals. + + Execute `sqlmesh run <environment>` to trigger reprocessing in the development environment. + + This is necessary because a `prod` restatement plan only does work in the `prod` environment for speed and efficiency. \ No newline at end of file diff --git a/docs/concepts/state.md b/docs/concepts/state.md new file mode 100644 index 0000000000..ea5391ec20 --- /dev/null +++ b/docs/concepts/state.md @@ -0,0 +1,279 @@ +# State + +SQLMesh stores information about your project in a state database that is usually separate from your main warehouse. 
+ +The SQLMesh state database contains: + +- Information about every [Model Version](./models/overview.md) in your project (query, loaded intervals, dependencies) +- A list of every [Virtual Data Environment](./environments.md) in the project +- Which model versions are [promoted](./plans.md#plan-application) into each [Virtual Data Environment](./environments.md) +- Information about any [auto restatements](./models/overview.md#auto_restatement_cron) present in your project +- Other metadata about your project such as current SQLMesh / SQLGlot version + +The state database is how SQLMesh "remembers" what it's done before so it can compute a minimum set of operations to apply changes instead of rebuilding everything every time. It's also how SQLMesh tracks what historical data has already been backfilled for [incremental models](./models/model_kinds.md#incremental_by_time_range) so you don't need to add branching logic into the model query to handle this. + +!!! info "State database performance" + + The workload against the state database is an OLTP workload that requires transaction support in order to work correctly. + + For the best experience, we recommend [Tobiko Cloud](../cloud/cloud_index.md) or databases designed for OLTP workloads such as [PostgreSQL](../integrations/engines/postgres.md). + + Using your warehouse OLAP database to store state is supported for proof-of-concept projects but is not suitable for production and **will** lead to poor performance and consistency. + + For more information on engines suitable for the SQLMesh state database, see the [configuration guide](../guides/configuration.md#state-connection). + +## Exporting / Importing State + +SQLMesh supports exporting the state database to a `.json` file. From there, you can inspect the file with any tool that can read text files. You can also pass the file around and import it back into a SQLMesh project running elsewhere. 
+ +### Exporting state + +SQLMesh can export the state database to a file like so: + +```bash +$ sqlmesh state export -o state.json +Exporting state to 'state.json' from the following connection: + +Gateway: dev +State Connection: +├── Type: postgres +├── Catalog: sushi_dev +└── Dialect: postgres + +Continue? [y/n]: y + + Exporting versions ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 + Exporting snapshots ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 17/17 • 0:00:00 +Exporting environments ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 1/1 • 0:00:00 + +State exported successfully to 'state.json' +``` + +This will produce a file `state.json` in the current directory containing the SQLMesh state. + +The state file is a simple `json` file that looks like: + +```json +{ + /* State export metadata */ + "metadata": { + "timestamp": "2025-03-16 23:09:00+00:00", /* UTC timestamp of when the file was produced */ + "file_version": 1, /* state export file format version */ + "importable": true /* whether or not this file can be imported with `sqlmesh state import` */ + }, + /* Library versions used to produce this state export file */ + "versions": { + "schema_version": 76 /* sqlmesh state database schema version */, + "sqlglot_version": "26.10.1" /* version of SQLGlot used to produce the state file */, + "sqlmesh_version": "0.165.1" /* version of SQLMesh used to produce the state file */, + }, + /* array of objects containing every Snapshot (physical table) tracked by the SQLMesh project */ + "snapshots": [ + { "name": "..." } + ], + /* object for every Virtual Data Environment in the project. key = environment name, value = environment details */ + "environments": { + "prod": { + /* information about the environment itself */ + "environment": { + "..." + }, + /* information about any before_all / after_all statements for this environment */ + "statements": [ + "..." 
+      ]
+    }
+  }
+}
+```
+
+#### Specific environments
+
+You can export a specific environment like so:
+
+```sh
+$ sqlmesh state export --environment my_dev -o my_dev_state.json
+```
+
+Note that every snapshot that is part of the environment will be exported, not just the differences from `prod`. The reason for this is so that the environment can be fully imported elsewhere without any assumptions about which snapshots are already present in state.
+
+#### Local state
+
+You can export local state like so:
+
+```bash
+$ sqlmesh state export --local -o local_state.json
+```
+
+This essentially just exports the state of the local context which includes local changes that have not been applied to any virtual data environments.
+
+Therefore, a local state export will only have `snapshots` populated. `environments` will be empty because virtual data environments are only present in the warehouse / remote state. In addition, the file is marked as **not importable** so it cannot be used with a subsequent `sqlmesh state import` command.
+
+### Importing state
+
+!!! warning "Back up your state database first!"
+
+    Please ensure you have created an independent backup of your state database in case something goes wrong during the state import.
+
+    SQLMesh tries to wrap the state import in a transaction but some database engines do not support transactions against DDL which means
+    an import error has the potential to leave the state database in an inconsistent state.
+
+SQLMesh can import a state file into the state database like so:
+
+```bash
+$ sqlmesh state import -i state.json --replace
+Loading state from 'state.json' into the following connection:
+
+Gateway: dev
+State Connection:
+├── Type: postgres
+├── Catalog: sushi_dev
+└── Dialect: postgres
+
+[WARNING] This destructive operation will delete all existing state against the 'dev' gateway
+and replace it with what's in the 'state.json' file.
+
+Are you sure? 
[y/n]: y
+
+State File Information:
+├── Creation Timestamp: 2025-03-31 02:15:00+00:00
+├── File Version: 1
+├── SQLMesh version: 0.170.1.dev0
+├── SQLMesh migration version: 76
+└── SQLGlot version: 26.12.0
+
+    Importing versions ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00
+   Importing snapshots ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 17/17 • 0:00:00
+Importing environments ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 1/1 • 0:00:00
+
+State imported successfully from 'state.json'
+```
+
+Note that the state database structure needs to be present and up to date, so run `sqlmesh migrate` before running `sqlmesh state import` if you get a version mismatch error.
+
+If you have a partial state export, perhaps for a single environment, you can merge it in by omitting the `--replace` parameter:
+
+```bash
+$ sqlmesh state import -i state.json
+...
+
+[WARNING] This operation will merge the contents of the state file to the state located at the 'dev' gateway.
+Matching snapshots or environments will be replaced.
+Non-matching snapshots or environments will be ignored.
+
+Are you sure? [y/n]: y
+
+...
+State imported successfully from 'state.json'
+```
+
+
+### Specific gateways
+
+If your project has [multiple gateways](../guides/configuration.md#gateways) with different state connections per gateway, you can target the [state_connection](../guides/configuration.md#state-connection) of a specific gateway like so:
+
+```bash
+# state export
+$ sqlmesh --gateway <gateway> state export -o state.json
+
+# state import
+$ sqlmesh --gateway <gateway> state import -i state.json
+```
+
+## Version Compatibility
+
+When importing state, the state file must have been produced with the same major and minor version of SQLMesh that is being used to import it.
+
+If you attempt to import state with an incompatible version, you will get the following error:
+
+```bash
+$ sqlmesh state import -i state.json
+...SNIP...
+
+State import failed! 
+Error: SQLMesh version mismatch. You are running '0.165.1' but the state file was created with '0.164.1'. +Please upgrade/downgrade your SQLMesh version to match the state file before performing the import. +``` + +### Upgrading a state file + +You can upgrade a state file produced by an old SQLMesh version to be compatible with a newer SQLMesh version by: + +- Loading it into a local database using the older SQLMesh version +- Installing the newer SQLMesh version +- Running `sqlmesh migrate` to upgrade the state within the local database +- Running `sqlmesh state export` to export it back out again. The new export is now compatible with the newer version of SQLMesh. + +Below is an example of how to upgrade a state file created with SQLMesh `0.164.1` to be compatible with SQLMesh `0.165.1`. + +First, create and activate a virtual environment to isolate the SQLMesh versions from your main environment: + +```bash +$ python -m venv migration-env + +$ . ./migration-env/bin/activate + +(migration-env)$ +``` + +Install the SQLMesh version compatible with your state file. The correct version to use is printed in the error message, eg `the state file was created with '0.164.1'` means you need to install SQLMesh `0.164.1`: + +```bash +(migration-env)$ pip install "sqlmesh==0.164.1" +``` + +Add a gateway to your `config.yaml` like so: + +```yaml +gateways: + migration: + connection: + type: duckdb + database: ./state-migration.duckdb +``` + +The goal here is to define just enough config for SQLMesh to be able to use a local database to run the state export/import commands. SQLMesh still needs to inherit things like the `model_defaults` from your project in order to migrate state correctly which is why we have not used an isolated directory. + +!!! 
warning + + From here on, be sure to specify `--gateway migration` to all SQLMesh commands or you run the risk of accidentally clobbering any state on your main gateway + +You can now import your state export using the same version of SQLMesh it was created with: + +```bash +(migration-env)$ sqlmesh --gateway migration migrate + +(migration-env)$ sqlmesh --gateway migration state import -i state.json +... +State imported successfully from 'state.json' +``` + +Now we have the state imported, we can upgrade SQLMesh and export the state from the new version. +The new version was printed in the original error message, eg `You are running '0.165.1'` + +To upgrade SQLMesh, simply install the new version: + +```bash +(migration-env)$ pip install --upgrade "sqlmesh==0.165.1" +``` + +Migrate the state to the new version: + +```bash +(migration-env)$ sqlmesh --gateway migration migrate +``` + +And finally, create a new state file which is now compatible with the new SQLMesh version: + +```bash + (migration-env)$ sqlmesh --gateway migration state export -o state-migrated.json +``` + +The `state-migrated.json` file is now compatible with the newer version of SQLMesh. +You can then transfer it to the place you originally needed it and import it in: + +```bash +$ sqlmesh state import -i state-migrated.json +... +State imported successfully from 'state-migrated.json' +``` \ No newline at end of file diff --git a/docs/concepts/tests.md b/docs/concepts/tests.md index c4b2adb42d..c1714ea982 100644 --- a/docs/concepts/tests.md +++ b/docs/concepts/tests.md @@ -265,7 +265,7 @@ test_parameterized_model: ... 
``` -For example, assuming `gold` is a [config variable](../reference/configuration/#variables) with value `gold_db`, the above test would be rendered as: +For example, assuming `gold` is a [config variable](../reference/configuration.md#variables) with value `gold_db`, the above test would be rendered as: ```yaml linenums="1" test_parameterized_model: @@ -486,6 +486,9 @@ These fixtures are dropped by default after the execution completes, but it is p This can be helpful when debugging a test failure, because for example it's possible to query the fixture views directly and verify that they are defined correctly. +!!! note + By default, the views that are necessary to run a unit test are created within a new, unique schema, whose name looks like `sqlmesh_test_`. To specify a custom name for this schema, set the [`.schema`](#test_nameschema) test attribute. + ### Type mismatches It's not always possible to correctly interpret certain values in a unit test without additional context. For example, a YAML dictionary can be used to represent both a `STRUCT` and a `MAP` value in SQL. @@ -512,6 +515,10 @@ The name of the model being tested. This model must be defined in the project's An optional description of the test, which can be used to provide additional context. +### `.schema` + +The name of the schema that will contain the views that are necessary to run this unit test. + ### `.gateway` The gateway whose `test_connection` will be used to run this test. If not specified, the default gateway is used. @@ -598,7 +605,7 @@ An optional dictionary that maps columns to their types: ```yaml linenums="1" : columns: - - : + : ... ``` diff --git a/docs/development.md b/docs/development.md index 3ec5ff2c00..ff8b250d87 100644 --- a/docs/development.md +++ b/docs/development.md @@ -1,42 +1,103 @@ # Contribute to development -SQLMesh is licensed under [Apache 2.0](https://github.com/TobikoData/sqlmesh/blob/main/LICENSE). 
We encourage community contribution and would love for you to get involved. + +SQLMesh is licensed under [Apache 2.0](https://github.com/SQLMesh/sqlmesh/blob/main/LICENSE). We encourage community contribution and would love for you to get involved. The following document outlines the process to contribute to SQLMesh. ## Prerequisites + +Before you begin, ensure you have the following installed on your machine. Exactly how to install these is dependent on your operating system. + * Docker * Docker Compose V2 * OpenJDK >= 11 +* Python >= 3.9 < 3.13 -## Commands reference +## Virtual environment setup + +We do recommend using a virtual environment to develop SQLMesh. + +```bash +python -m venv .venv +source .venv/bin/activate +``` + +Once you have activated your virtual environment, you can install the dependencies by running the following command. -Install dev dependencies: ```bash make install-dev ``` + +Optionally, you can use pre-commit to automatically run linters/formatters: + +```bash +make install-pre-commit +``` + +## Python development + Run linters and formatters: + ```bash make style ``` + Run faster tests for quicker local feedback: + ```bash make fast-test ``` + Run more comprehensive tests that run on each commit: + ```bash make slow-test ``` -Run Airflow tests that will run when PR is merged to main: + +## Documentation + +In order to run the documentation server, you will need to install the dependencies by running the following command. + ```bash -make airflow-docker-test-with-env +make install-doc ``` -Run docs server: + +Once you have installed the dependencies, you can run the documentation server by running the following command. + ```bash make docs-serve ``` + +Run docs tests: + +```bash +make doc-test +``` + +## UI development + +In addition to the Python development, you can also develop the UI. + +The UI is built using React and Typescript. To run the UI, you will need to install the dependencies by running the following command. 
+ +```bash +pnpm install +``` + Run ide: + ```bash make ui-up ``` -(Optional) Use pre-commit to automatically run linters/formatters: + +## Developing the VSCode extension + +Similar to UI development, you can also develop the VSCode extension. To do so, make sure you have the dependencies installed by running the following command inside the `vscode/extension` directory. + ```bash -make install-pre-commit +pnpm install +``` + +Once that is done, developing the VSCode extension is most easily done by launching the `Run Extensions` debug task from a Visual Studio Code workspace opened at the root of the SQLMesh repository. By default, the VSCode extension will run the SQLMesh server locally and open a new Visual Studio Code window that allows you to try out the SQLMesh IDE. It opens the `examples/sushi` project by default. To set up Visual Studio Code to run the `Run Extensions` debug task, you can run the following command which will copy the `launch.json` and `tasks.json` files to the `.vscode` directory. + +```bash +make vscode_settings ``` diff --git a/docs/examples/incremental_time/column_level_audit_trail.png b/docs/examples/incremental_time/column_level_audit_trail.png new file mode 100644 index 0000000000..f715f3eac1 Binary files /dev/null and b/docs/examples/incremental_time/column_level_audit_trail.png differ diff --git a/docs/examples/incremental_time/node_level_audit_trail.png b/docs/examples/incremental_time/node_level_audit_trail.png new file mode 100644 index 0000000000..8f023085d3 Binary files /dev/null and b/docs/examples/incremental_time/node_level_audit_trail.png differ diff --git a/docs/examples/incremental_time_full_walkthrough.md b/docs/examples/incremental_time_full_walkthrough.md new file mode 100644 index 0000000000..ffa9def911 --- /dev/null +++ b/docs/examples/incremental_time_full_walkthrough.md @@ -0,0 +1,1456 @@ +# Incremental by Time Range + +
+ +SQLMesh incremental models are a powerful feature that come in many flavors and configurations so you can fine tune your query performance and scheduled runs **exactly** how you want with a plethora of guardrails. + +However, we recognize with all this power comes a responsibility to make sure you’re equipped to succeed confidently. + +We’re going to walk you through a clear story problem step by step. The end outcome is for you to feel confident with this new workflow to: + +- Build a mental model for how to solve data transformation problems with SQLMesh incremental models +- Know which configs to update and why +- Run a sequence of `sqlmesh` commands and know exactly what’s running and why +- Understand the tradeoffs between different approaches and make the right decisions for your use case +- Save precious time and money running your data transformation pipelines + +## Story Problem + +I am a data engineer working for a company selling software directly to customers. I have sales data with millions of transactions per day, and I want to add dimensions from other raw sources to better understand what sales/product trends are happening. + +So I have two raw data sources like this: + +- Source A: raw sales data is extracted and loaded into my data warehouse (think: BigQuery, Snowflake, Databricks, etc.) hourly +- Source B: product usage data from a backend database (think: Postgres) is extracted and loaded into my data warehouse daily + +On first impression, this looks like a piece of cake. However, as I reflect on what success looks like for this to be built AND maintained well, there’s a lot of problems to solve for. Don’t worry, we answer all these questions at the end. + +- How do I handle late-arriving data? +- How do I account for UTC vs. PST (California) timestamps, do I convert them? +- What schedule should I run these at? +- How do I test this data? +- How do I make this run fast and only the intervals necessary (read: partitions)? 
+- How do I make patch changes when an edge case error occurs with incorrect data from months ago? +- What do unit tests look and feel like for this? +- How do I prevent data gaps with unprocessed or incomplete intervals? +- Am I okay processing incomplete intervals (think: allow partials)? +- What tradeoffs am I willing to make for fresh data? +- How to make this not feel so complex during development? +- How do I know SQLMesh is behaving how I want it to behave? + +## Development Workflow + +You’ll be following this general sequence of actions when working with SQLMesh: + +1. `sqlmesh plan dev`: create a dev environment for your new SQL model +2. `sqlmesh fetchdf`: preview data in dev +3. `sqlmesh create_external_models`: automatically generate documentation for raw source tables' column-level lineage +4. `sqlmesh plan`: promote model from dev to prod +5. `sqlmesh plan dev --forward-only`: make more code changes and only process new data going forward with those code changes; leave historical data alone +6. `sqlmesh fetchdf`: preview data in dev +7. `sqlmesh create_test`: automatically generate unit tests +8. `sqlmesh test`: run those unit tests +9. `sqlmesh plan`: promote dev to prod + +> Note: If this is the first time you're running SQLMesh, I recommend following the [CLI Quickstart](../quickstart/cli.md) first and then coming back to this example. + +## Setup + +Let’s start with some demo data coupled with an existing SQLMesh project with models already in production. + +I recommend not reading too much into the exact contents of this data outside of timestamps and primary/foreign keys. All of this is fabricated for the purposes of this guide. + +We have data like the below that gets ingested into our data warehouse on a daily basis. + +??? 
"Raw product usage data" + + | product_id | customer_id | last_usage_date | usage_count | feature_utilization_score | user_segment | + | ---------- | ----------- | ------------------------- | ----------- | ------------------------- | ------------ | + | PROD-101 | CUST-001 | 2024-10-25 23:45:00+00:00 | 120 | 0.85 | enterprise | + | PROD-103 | CUST-001 | 2024-10-27 12:30:00+00:00 | 95 | 0.75 | enterprise | + | PROD-102 | CUST-002 | 2024-10-25 15:15:00+00:00 | 150 | 0.92 | enterprise | + | PROD-103 | CUST-002 | 2024-10-26 14:20:00+00:00 | 80 | 0.68 | enterprise | + | PROD-101 | CUST-003 | 2024-10-25 18:30:00+00:00 | 45 | 0.45 | professional | + | PROD-102 | CUST-003 | 2024-10-27 19:45:00+00:00 | 30 | 0.35 | professional | + | PROD-103 | CUST-004 | 2024-10-25 21:20:00+00:00 | 15 | 0.25 | starter | + | PROD-102 | CUST-005 | 2024-10-25 23:10:00+00:00 | 5 | 0.15 | starter | + | PROD-102 | CUST-006 | 2024-10-26 15:30:00+00:00 | 110 | 0.88 | enterprise | + | PROD-101 | CUST-007 | 2024-10-26 17:45:00+00:00 | 60 | 0.55 | professional | + | PROD-103 | CUST-008 | 2024-10-26 22:20:00+00:00 | 25 | 0.30 | starter | + | PROD-101 | CUST-009 | 2024-10-27 05:15:00+00:00 | 75 | 0.65 | professional | + | PROD-102 | CUST-010 | 2024-10-27 08:40:00+00:00 | 3 | 0.10 | starter | + +??? 
"Raw sales data" + + | transaction_id | product_id | customer_id | transaction_amount | transaction_timestamp | payment_method | currency | + | -------------- | ---------- | ----------- | ------------------ | ------------------------- | -------------- | -------- | + | TX-001 | PROD-101 | CUST-001 | 99.99 | 2024-10-25 08:30:00+00:00 | credit_card | USD | + | TX-002 | PROD-102 | CUST-002 | 149.99 | 2024-10-25 09:45:00+00:00 | paypal | USD | + | TX-003 | PROD-101 | CUST-003 | 99.99 | 2024-10-25 15:20:00+00:00 | credit_card | USD | + | TX-004 | PROD-103 | CUST-004 | 299.99 | 2024-10-25 18:10:00+00:00 | credit_card | USD | + | TX-005 | PROD-102 | CUST-005 | 149.99 | 2024-10-25 21:30:00+00:00 | debit_card | USD | + | TX-006 | PROD-101 | CUST-001 | 99.99 | 2024-10-26 03:15:00+00:00 | credit_card | USD | + | TX-007 | PROD-103 | CUST-002 | 299.99 | 2024-10-26 07:45:00+00:00 | paypal | USD | + | TX-008 | PROD-102 | CUST-006 | 149.99 | 2024-10-26 11:20:00+00:00 | credit_card | USD | + | TX-009 | PROD-101 | CUST-007 | 99.99 | 2024-10-26 14:30:00+00:00 | debit_card | USD | + | TX-010 | PROD-103 | CUST-008 | 299.99 | 2024-10-26 19:45:00+00:00 | credit_card | USD | + | TX-011 | PROD-101 | CUST-009 | 99.99 | 2024-10-27 02:30:00+00:00 | paypal | USD | + | TX-012 | PROD-102 | CUST-010 | 149.99 | 2024-10-27 05:15:00+00:00 | credit_card | USD | + | TX-013 | PROD-103 | CUST-001 | 299.99 | 2024-10-27 08:40:00+00:00 | credit_card | USD | + | TX-014 | PROD-101 | CUST-002 | 99.99 | 2024-10-27 13:25:00+00:00 | debit_card | USD | + | TX-015 | PROD-102 | CUST-003 | 149.99 | 2024-10-27 16:50:00+00:00 | credit_card | USD | + +??? "Code to load the data into BigQuery" + + If you want to follow along, here are BigQuery SQL queries to make it easier for you! Just run them directly in the query console. Feel free to adjust for your data warehouse. 
+ + ```sql + -- Create the product_usage table with appropriate schema + CREATE OR REPLACE TABLE `sqlmesh-public-demo.tcloud_raw_data.product_usage` ( + product_id STRING NOT NULL, + customer_id STRING NOT NULL, + last_usage_date TIMESTAMP NOT NULL, + usage_count INT64 NOT NULL, + feature_utilization_score FLOAT64 NOT NULL, + user_segment STRING NOT NULL, + ); + + -- Insert the data + INSERT INTO `sqlmesh-public-demo.tcloud_raw_data.product_usage` + (product_id, customer_id, last_usage_date, usage_count, feature_utilization_score, user_segment) + VALUES + ('PROD-101', 'CUST-001', TIMESTAMP '2024-10-25 23:45:00+00:00', 120, 0.85, 'enterprise'), + ('PROD-103', 'CUST-001', TIMESTAMP '2024-10-27 12:30:00+00:00', 95, 0.75, 'enterprise'), + ('PROD-102', 'CUST-002', TIMESTAMP '2024-10-25 15:15:00+00:00', 150, 0.92, 'enterprise'), + ('PROD-103', 'CUST-002', TIMESTAMP '2024-10-26 14:20:00+00:00', 80, 0.68, 'enterprise'), + ('PROD-101', 'CUST-003', TIMESTAMP '2024-10-25 18:30:00+00:00', 45, 0.45, 'professional'), + ('PROD-102', 'CUST-003', TIMESTAMP '2024-10-27 19:45:00+00:00', 30, 0.35, 'professional'), + ('PROD-103', 'CUST-004', TIMESTAMP '2024-10-25 21:20:00+00:00', 15, 0.25, 'starter'), + ('PROD-102', 'CUST-005', TIMESTAMP '2024-10-25 23:10:00+00:00', 5, 0.15, 'starter'), + ('PROD-102', 'CUST-006', TIMESTAMP '2024-10-26 15:30:00+00:00', 110, 0.88, 'enterprise'), + ('PROD-101', 'CUST-007', TIMESTAMP '2024-10-26 17:45:00+00:00', 60, 0.55, 'professional'), + ('PROD-103', 'CUST-008', TIMESTAMP '2024-10-26 22:20:00+00:00', 25, 0.30, 'starter'), + ('PROD-101', 'CUST-009', TIMESTAMP '2024-10-27 05:15:00+00:00', 75, 0.65, 'professional'), + ('PROD-102', 'CUST-010', TIMESTAMP '2024-10-27 08:40:00+00:00', 3, 0.10, 'starter'); + + ``` + + ```sql + --Create the sales table with appropriate schema + CREATE OR REPLACE TABLE `sqlmesh-public-demo.tcloud_raw_data.sales` ( + transaction_id STRING NOT NULL, + product_id STRING NOT NULL, + customer_id STRING NOT NULL, + transaction_amount 
NUMERIC(10,2) NOT NULL, + transaction_timestamp TIMESTAMP NOT NULL, + payment_method STRING, + currency STRING, + ); + + -- Then, insert the data + INSERT INTO `sqlmesh-public-demo.tcloud_raw_data.sales` + (transaction_id, product_id, customer_id, transaction_amount, transaction_timestamp, payment_method, currency) + VALUES + ('TX-001', 'PROD-101', 'CUST-001', 99.99, TIMESTAMP '2024-10-25 08:30:00+00:00', 'credit_card', 'USD'), + ('TX-002', 'PROD-102', 'CUST-002', 149.99, TIMESTAMP '2024-10-25 09:45:00+00:00', 'paypal', 'USD'), + ('TX-003', 'PROD-101', 'CUST-003', 99.99, TIMESTAMP '2024-10-25 15:20:00+00:00', 'credit_card', 'USD'), + ('TX-004', 'PROD-103', 'CUST-004', 299.99, TIMESTAMP '2024-10-25 18:10:00+00:00', 'credit_card', 'USD'), + ('TX-005', 'PROD-102', 'CUST-005', 149.99, TIMESTAMP '2024-10-25 21:30:00+00:00', 'debit_card', 'USD'), + ('TX-006', 'PROD-101', 'CUST-001', 99.99, TIMESTAMP '2024-10-26 03:15:00+00:00', 'credit_card', 'USD'), + ('TX-007', 'PROD-103', 'CUST-002', 299.99, TIMESTAMP '2024-10-26 07:45:00+00:00', 'paypal', 'USD'), + ('TX-008', 'PROD-102', 'CUST-006', 149.99, TIMESTAMP '2024-10-26 11:20:00+00:00', 'credit_card', 'USD'), + ('TX-009', 'PROD-101', 'CUST-007', 99.99, TIMESTAMP '2024-10-26 14:30:00+00:00', 'debit_card', 'USD'), + ('TX-010', 'PROD-103', 'CUST-008', 299.99, TIMESTAMP '2024-10-26 19:45:00+00:00', 'credit_card', 'USD'), + ('TX-011', 'PROD-101', 'CUST-009', 99.99, TIMESTAMP '2024-10-27 02:30:00+00:00', 'paypal', 'USD'), + ('TX-012', 'PROD-102', 'CUST-010', 149.99, TIMESTAMP '2024-10-27 05:15:00+00:00', 'credit_card', 'USD'), + ('TX-013', 'PROD-103', 'CUST-001', 299.99, TIMESTAMP '2024-10-27 08:40:00+00:00', 'credit_card', 'USD'), + ('TX-014', 'PROD-101', 'CUST-002', 99.99, TIMESTAMP '2024-10-27 13:25:00+00:00', 'debit_card', 'USD'), + ('TX-015', 'PROD-102', 'CUST-003', 149.99, TIMESTAMP '2024-10-27 16:50:00+00:00', 'credit_card', 'USD'); + ``` + +## Model Configuration + +I can answer some of the questions above by walking 
through the model's config, coupled with the business logic/code I prepared ahead of time. + +You can see this code in a SQLMesh project context [here](https://github.com/sungchun12/sqlmesh-demos/blob/incremental-demo/models/examples/incremental_model.sql). + +```sql +MODEL ( + name demo.incrementals_demo, + kind INCREMENTAL_BY_TIME_RANGE ( + -- How does this model kind behave? + -- DELETE by time range, then INSERT + time_column transaction_date, + + -- How do I handle late-arriving data? + -- Handle late-arriving events for the past 2 (2*1) days based on cron + -- interval. Each time it runs, it will process today, yesterday, and + -- the day before yesterday. + lookback 2, + ), + + -- Don't backfill data before this date + start '2024-10-25', + + -- What schedule should I run these at? + -- Daily at Midnight UTC + cron '@daily', + + -- Good documentation for the primary key + grain transaction_id, + + -- How do I test this data? + -- Validate that the `transaction_id` primary key values are both unique + -- and non-null. Data audit tests only run for the processed intervals, + -- not for the entire table. + audits ( + UNIQUE_VALUES(columns = (transaction_id)), + NOT_NULL(columns = (transaction_id)) + ) +); + +WITH sales_data AS ( + SELECT + transaction_id, + product_id, + customer_id, + transaction_amount, + -- How do I account for UTC vs. PST (California baby) timestamps? + -- Make sure all time columns are in UTC and convert them to PST in the + -- presentation layer downstream. + transaction_timestamp, + payment_method, + currency + FROM sqlmesh-public-demo.tcloud_raw_data.sales -- Source A: sales data + -- How do I make this run fast and only process the necessary intervals? + -- Use our date macros that will automatically run the necessary intervals. + -- Because SQLMesh manages state, it will know what needs to run each time + -- you invoke `sqlmesh run`. 
+ WHERE transaction_timestamp BETWEEN @start_dt AND @end_dt +), + +product_usage AS ( + SELECT + product_id, + customer_id, + last_usage_date, + usage_count, + feature_utilization_score, + user_segment + FROM sqlmesh-public-demo.tcloud_raw_data.product_usage -- Source B + -- Include usage data from the 30 days before the interval + WHERE last_usage_date BETWEEN DATE_SUB(@start_dt, INTERVAL 30 DAY) AND @end_dt +) + +SELECT + s.transaction_id, + s.product_id, + s.customer_id, + s.transaction_amount, + -- Extract the date from the timestamp to partition by day + DATE(s.transaction_timestamp) as transaction_date, + -- Convert timestamp to PST using a SQL function in the presentation layer for end users + DATETIME(s.transaction_timestamp, 'America/Los_Angeles') as transaction_timestamp_pst, + s.payment_method, + s.currency, + -- Product usage metrics + p.last_usage_date, + p.usage_count, + p.feature_utilization_score, + p.user_segment, + -- Derived metrics + CASE + WHEN p.usage_count > 100 AND p.feature_utilization_score > 0.8 THEN 'Power User' + WHEN p.usage_count > 50 THEN 'Regular User' + WHEN p.usage_count IS NULL THEN 'New User' + ELSE 'Light User' + END as user_type, + -- Time since last usage + DATE_DIFF(s.transaction_timestamp, p.last_usage_date, DAY) as days_since_last_usage +FROM sales_data s +LEFT JOIN product_usage p + ON s.product_id = p.product_id + AND s.customer_id = p.customer_id +``` + +## Creating the model + +I’m creating this model for the first time against an existing SQLMesh project that already has data in production. So let’s run this in a `dev` environment. + +Run this command to add this incremental model to a `dev` environment: + +```bash +sqlmesh plan dev +``` + +*Note: Using `sqlmesh` version `0.132.1` at the time of writing* + +Keep pressing enter on the date prompts, as we want to backfill all of history since 2024-10-25. 
+ +```bash +(venv) ✗ sqlmesh plan dev +====================================================================== +Successfully Ran 2 tests against duckdb +---------------------------------------------------------------------- +New environment `dev` will be created from `prod` + +Differences from the `prod` environment: + +Models: +└── Added: + └── demo__dev.incrementals_demo +Models needing backfill (missing dates): +└── demo__dev.incrementals_demo: 2024-10-25 - 2024-11-04 +Enter the backfill start date (eg. '1 year', '2020-01-01') or blank to backfill from the beginning of history: +Enter the backfill end date (eg. '1 month ago', '2020-01-01') or blank to backfill up until now: +Apply - Backfill Tables [y/n]: y +[1/1] demo__dev.incrementals_demo evaluated in 6.97s +Evaluating models ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 1/1 • 0:00:06 + + +All model batches have been executed successfully + +Virtually Updating 'dev' ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 0:00:01 + +The target environment has been updated successfully +``` + +Now I’m thinking to myself "what exact SQL queries are running to make sure this is behaving as I expect?" + +This sequence of queries is exactly what’s happening in the query engine. Click on the toggles to see the SQL queries. + +??? 
"Create an empty table with the proper schema that’s also versioned (ex: `__50975949`)" + + ```sql + CREATE TABLE IF NOT EXISTS `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__50975949` ( + `transaction_id` STRING, + `product_id` STRING, + `customer_id` STRING, + `transaction_amount` NUMERIC, + `transaction_date` DATE OPTIONS (description='We extract the date from the timestamp to partition by day'), + `transaction_timestamp_pst` DATETIME OPTIONS (description='Convert this to PST using a SQL function'), + `payment_method` STRING, + `currency` STRING, + `last_usage_date` TIMESTAMP, + `usage_count` INT64, + `feature_utilization_score` FLOAT64, + `user_segment` STRING, + `user_type` STRING OPTIONS (description='Derived metrics'), + `days_since_last_usage` INT64 OPTIONS (description='Time since last usage') + ) + PARTITION BY `transaction_date` + ``` + +??? "Validate the SQL before processing data (note the `WHERE FALSE LIMIT 0` and the placeholder timestamps)" + + ```sql + WITH `sales_data` AS ( + SELECT + `sales`.`transaction_id` AS `transaction_id`, + `sales`.`product_id` AS `product_id`, + `sales`.`customer_id` AS `customer_id`, + `sales`.`transaction_amount` AS `transaction_amount`, + `sales`.`transaction_timestamp` AS `transaction_timestamp`, + `sales`.`payment_method` AS `payment_method`, + `sales`.`currency` AS `currency` + FROM `sqlmesh-public-demo`.`tcloud_raw_data`.`sales` AS `sales` + WHERE ( + `sales`.`transaction_timestamp` <= CAST('1970-01-01 23:59:59.999999+00:00' AS TIMESTAMP) AND + `sales`.`transaction_timestamp` >= CAST('1970-01-01 00:00:00+00:00' AS TIMESTAMP)) AND + FALSE + ), + `product_usage` AS ( + SELECT + `product_usage`.`product_id` AS `product_id`, + `product_usage`.`customer_id` AS `customer_id`, + `product_usage`.`last_usage_date` AS `last_usage_date`, + `product_usage`.`usage_count` AS `usage_count`, + `product_usage`.`feature_utilization_score` AS `feature_utilization_score`, + `product_usage`.`user_segment` AS 
`user_segment` + FROM `sqlmesh-public-demo`.`tcloud_raw_data`.`product_usage` AS `product_usage` + WHERE ( + `product_usage`.`last_usage_date` <= CAST('1970-01-01 23:59:59.999999+00:00' AS TIMESTAMP) AND + `product_usage`.`last_usage_date` >= CAST('1969-12-02 00:00:00+00:00' AS TIMESTAMP) + ) AND + FALSE + ) + + SELECT + `s`.`transaction_id` AS `transaction_id`, + `s`.`product_id` AS `product_id`, + `s`.`customer_id` AS `customer_id`, + CAST(`s`.`transaction_amount` AS NUMERIC) AS `transaction_amount`, + DATE(`s`.`transaction_timestamp`) AS `transaction_date`, + DATETIME(`s`.`transaction_timestamp`, 'America/Los_Angeles') AS `transaction_timestamp_pst`, + `s`.`payment_method` AS `payment_method`, + `s`.`currency` AS `currency`, + `p`.`last_usage_date` AS `last_usage_date`, + `p`.`usage_count` AS `usage_count`, + `p`.`feature_utilization_score` AS `feature_utilization_score`, + `p`.`user_segment` AS `user_segment`, + CASE + WHEN `p`.`feature_utilization_score` > 0.8 AND `p`.`usage_count` > 100 THEN 'Power User' + WHEN `p`.`usage_count` > 50 THEN 'Regular User' + WHEN `p`.`usage_count` IS NULL THEN 'New User' + ELSE 'Light User' + END AS `user_type`, + DATE_DIFF(`s`.`transaction_timestamp`, `p`.`last_usage_date`, DAY) AS `days_since_last_usage` + FROM `sales_data` AS `s` + LEFT JOIN `product_usage` AS `p` + ON `p`.`customer_id` = `s`.`customer_id` AND + `p`.`product_id` = `s`.`product_id` + WHERE FALSE + LIMIT 0 + ``` + +??? 
"Merge data into empty table" + + ```sql + MERGE INTO `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__50975949` AS `__MERGE_TARGET__` USING ( + WITH `sales_data` AS ( + SELECT + `transaction_id`, + `product_id`, + `customer_id`, + `transaction_amount`, + `transaction_timestamp`, + `payment_method`, + `currency` + FROM `sqlmesh-public-demo`.`tcloud_raw_data`.`sales` AS `sales` + WHERE `transaction_timestamp` BETWEEN CAST('2024-10-25 00:00:00+00:00' AS TIMESTAMP) AND CAST('2024-11-04 23:59:59.999999+00:00' AS TIMESTAMP) + ), + `product_usage` AS ( + SELECT + `product_id`, + `customer_id`, + `last_usage_date`, + `usage_count`, + `feature_utilization_score`, + `user_segment` + FROM `sqlmesh-public-demo`.`tcloud_raw_data`.`product_usage` AS `product_usage` + WHERE `last_usage_date` BETWEEN DATE_SUB(CAST('2024-10-25 00:00:00+00:00' AS TIMESTAMP), INTERVAL '30' DAY) AND CAST('2024-11-04 23:59:59.999999+00:00' AS TIMESTAMP) + ) + + SELECT + `transaction_id`, + `product_id`, + `customer_id`, + `transaction_amount`, + `transaction_date`, + `transaction_timestamp_pst`, + `payment_method`, + `currency`, + `last_usage_date`, + `usage_count`, + `feature_utilization_score`, + `user_segment`, + `user_type`, + `days_since_last_usage` + FROM ( + SELECT + `s`.`transaction_id` AS `transaction_id`, + `s`.`product_id` AS `product_id`, + `s`.`customer_id` AS `customer_id`, + `s`.`transaction_amount` AS `transaction_amount`, + DATE(`s`.`transaction_timestamp`) AS `transaction_date`, + DATETIME(`s`.`transaction_timestamp`, 'America/Los_Angeles') AS `transaction_timestamp_pst`, + `s`.`payment_method` AS `payment_method`, + `s`.`currency` AS `currency`, + `p`.`last_usage_date` AS `last_usage_date`, + `p`.`usage_count` AS `usage_count`, + `p`.`feature_utilization_score` AS `feature_utilization_score`, + `p`.`user_segment` AS `user_segment`, + CASE + WHEN `p`.`usage_count` > 100 AND `p`.`feature_utilization_score` > 0.8 THEN 'Power User' + WHEN `p`.`usage_count` > 50 THEN 
'Regular User' + WHEN `p`.`usage_count` IS NULL THEN 'New User' + ELSE 'Light User' + END AS `user_type`, + DATE_DIFF(`s`.`transaction_timestamp`, `p`.`last_usage_date`, DAY) AS `days_since_last_usage` + FROM `sales_data` AS `s` + LEFT JOIN `product_usage` AS `p` + ON `s`.`product_id` = `p`.`product_id` + AND `s`.`customer_id` = `p`.`customer_id` + ) AS `_subquery` + WHERE `transaction_date` BETWEEN CAST('2024-10-25' AS DATE) AND CAST('2024-11-04' AS DATE) + ) AS `__MERGE_SOURCE__` + ON FALSE + WHEN NOT MATCHED BY SOURCE AND `transaction_date` BETWEEN CAST('2024-10-25' AS DATE) AND CAST('2024-11-04' AS DATE) THEN DELETE + WHEN NOT MATCHED THEN + INSERT ( + `transaction_id`, `product_id`, `customer_id`, `transaction_amount`, `transaction_date`, `transaction_timestamp_pst`, + `payment_method`, `currency`, `last_usage_date`, `usage_count`, `feature_utilization_score`, `user_segment`, `user_type`, + `days_since_last_usage` + ) + VALUES ( + `transaction_id`, `product_id`, `customer_id`, `transaction_amount`, `transaction_date`, `transaction_timestamp_pst`, + `payment_method`, `currency`, `last_usage_date`, `usage_count`, `feature_utilization_score`, `user_segment`, `user_type`, + `days_since_last_usage` + ) + ``` + +??? 
"Run data audits to test if `transaction_id` is unique and not null (SQL is automatically generated)" + + `UNIQUE_VALUES()` audit + ```sql + SELECT + COUNT(*) + FROM ( + SELECT * + FROM ( + SELECT + ROW_NUMBER() OVER ( + PARTITION BY (`transaction_id`) O + RDER BY (`transaction_id`) + ) AS `rank_` + FROM ( + SELECT * + FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__50975949` AS `demo__incrementals_demo__50975949` + WHERE `transaction_date` BETWEEN CAST('2024-10-25' AS DATE) AND CAST('2024-11-05' AS DATE) + ) AS `_q_0` + WHERE TRUE + ) AS `_q_1` + WHERE `rank_` > 1 + ) AS `audit` + ``` + + `NOT_NULL()` audit + ```sql + SELECT + COUNT(*) + FROM ( + SELECT * + FROM ( + SELECT * + FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__50975949` AS `demo__incrementals_demo__50975949` + WHERE `transaction_date` BETWEEN CAST('2024-10-25' AS DATE) AND CAST('2024-11-05' AS DATE) + ) AS `_q_0` + WHERE + `transaction_id` IS NULL + AND TRUE + ) AS `audit` + ``` + +??? "Create development schema based on the name of the plan dev environment" + + ```sql + CREATE SCHEMA IF NOT EXISTS `sqlmesh-public-demo`.`demo__dev` + ``` + +??? "Create a view in the virtual layer to officially query this new table." + + Don’t worry, you won’t get view performance penalties - modern query engines employ pushdown predicate to query the base table directly [example](https://docs.snowflake.com/en/developer-guide/pushdown-optimization). + + ```sql + CREATE OR REPLACE VIEW `sqlmesh-public-demo`.`demo__dev`.`incrementals_demo` AS + SELECT * + FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__50975949` + ``` + +Now let’s make sure the look and feel is what I want. Let’s query the new `dev` table: + +```bash +sqlmesh fetchdf "select * from demo__dev.incrementals_demo limit 5" +``` + +```bash +(.venv) ✗ sqlmesh fetchdf "select * from demo__dev.incrementals_demo limit 5" + + transaction_id product_id customer_id transaction_amount transaction_date ... 
usage_count feature_utilization_score user_segment user_type days_since_last_usage +0 TX-010 PROD-103 CUST-008 299.990000000 2024-10-26 ... 25 0.30 starter Light User 0 +1 TX-008 PROD-102 CUST-006 149.990000000 2024-10-26 ... 110 0.88 enterprise Power User 0 +2 TX-006 PROD-101 CUST-001 99.990000000 2024-10-26 ... 120 0.85 enterprise Power User 0 +3 TX-009 PROD-101 CUST-007 99.990000000 2024-10-26 ... 60 0.55 professional Regular User 0 +4 TX-007 PROD-103 CUST-002 299.990000000 2024-10-26 ... 80 0.68 enterprise Regular User 0 + +[5 rows x 14 columns] +``` + +## Track Column Level Lineage + +Now that I have a solid start to my development, I want to document and visualize how this transformation logic works without manually writing a bunch of `yaml` for the next hour. + +Thankfully, I don’t have to. I’ll get an automatically generated `external_models.yaml` file that will parse my `incrementals_demo.sql` model and query BigQuery metadata to get all columns AND their data types. All of it neatly formatted. + +Run this command: + +```bash +sqlmesh create_external_models +``` + +```yaml +# external_models.yaml +- name: '`sqlmesh-public-demo`.`tcloud_raw_data`.`product_usage`' + columns: + product_id: STRING + customer_id: STRING + last_usage_date: TIMESTAMP + usage_count: INT64 + feature_utilization_score: FLOAT64 + user_segment: STRING +- name: '`sqlmesh-public-demo`.`tcloud_raw_data`.`sales`' + columns: + transaction_id: STRING + product_id: STRING + customer_id: STRING + transaction_amount: NUMERIC(10,2) + transaction_timestamp: TIMESTAMP + payment_method: STRING + currency: STRING +``` + +Now, when I run the command below in my terminal and click on the link it will open up my browser to show the column level lineage I know and love. + +```bash +sqlmesh ui +``` + +```bash +(venv) ✗ sqlmesh ui +INFO: Started server process [89705] +INFO: Waiting for application startup. +INFO: Application startup complete. 
+INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit) +``` + +![image.png](./incremental_time/node_level_audit_trail.png) + +When I click on a column in `demo.incrementals_demo`, it will trace the column to the source! + +![image.png](./incremental_time/column_level_audit_trail.png) + +Now, typically, I will promote these changes to production using SQLMesh’s open source GitHub CICD bot as shown in [this demo pull request](https://github.com/TobikoData/tobiko-cloud-demo/pull/4), but to keep this guide simpler, let’s run `sqlmesh plan` directly. + +This is where I feel the claim “data transformation without the waste” feels tangible. I did all this great work in my dev environment, and I’m used to reprocessing and duplicating storage in production. However, by default SQLMesh will bypass all that and create new views to point to the same physical tables created in `dev`! You can see for yourself in the query history. + +```bash +(venv) ✗ sqlmesh plan +====================================================================== +Successfully Ran 2 tests against duckdb +---------------------------------------------------------------------- +Differences from the `prod` environment: + +Models: +├── Added: + ├── demo.incrementals_demo + ├── tcloud_raw_data.product_usage + └── tcloud_raw_data.sales +Apply - Virtual Update [y/n]: y + +SKIP: No physical layer updates to perform + +SKIP: No model batches to execute + +Virtually Updating 'prod' ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 0:00:02 + +The target environment has been updated successfully +``` + +??? "Create production schema if it does not exist" + + ```sql + CREATE SCHEMA IF NOT EXISTS `sqlmesh-public-demo`.`demo` + ``` + +??? "Create a production version of the view. This is where SQLMesh reuses the hard work you’ve already done." 
+ + ```sql + CREATE OR REPLACE VIEW `sqlmesh-public-demo`.`demo`.`incrementals_demo` AS + SELECT * + FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__3076101542` + ``` + +??? "Run data audits to test if `transaction_id` is unique and not null (SQL is automatically generated)" + + **Made you look! No need to rerun audits we already passed in dev.** + +## Making Changes + +Alright, it feels pretty neat to go through this workflow, but now comes the part that represents the majority of my job as a data engineer: + +- Making changes +- Testing those changes +- Promoting those changes safely and confidently to production + +Let’s say I want to change my code's definition of a power user but ONLY going forward because we want to broaden our definition. However, I still want to retain how we defined power users historically. + +At first glance, this is a very surgical operation that can feel intimidating with custom `DML` operations, but thankfully SQLMesh has a native way to solve this problem. + +First, I make the change to decrease the threshold in my SQL logic: + +```sql +CASE + WHEN p.usage_count > 50 AND p.feature_utilization_score > 0.5 THEN 'Power User' +``` + +Unlike last time, I run `sqlmesh plan dev --forward-only` with the `--forward-only` flag, which tells SQLMesh it should not run the changed model against all the existing data. + +In the terminal output, I can see the change displayed like before, but I see some new date prompts. + +I leave the [effective date](../concepts/plans.md#effective-date) prompt blank because I do not want to reprocess historical data in `prod` - I only want to apply this new business logic going forward. + +However, I do want to preview the new business logic in my `dev` environment before pushing to `prod`. 
Because I have [configured SQLMesh to create previews](https://github.com/SQLMesh/sqlmesh-demos/blob/e0e3899e173cf7b8447ae707402a9df59911d1c0/config.yaml#L42) for forward-only models in my `config.yaml` file, SQLMesh has created a temporary copy of the `prod` table in my `dev` environment, so I can test the new logic on historical data. + +I specify the beginning of the preview's historical data window as `2024-10-27` in the preview start date prompt, and I specify the end of the window as now by leaving the preview end date prompt blank. + +```bash +sqlmesh plan dev --forward-only +``` + +```bash +(venv) ➜ sqlmesh-demos git:(incremental-demo) ✗ sqlmesh plan dev --forward-only +====================================================================== +Successfully Ran 2 tests against duckdb +---------------------------------------------------------------------- +Differences from the `dev` environment: + +Models: +└── Directly Modified: + └── demo__dev.incrementals_demo +--- + ++++ + +@@ -57,7 +57,7 @@ + + p.feature_utilization_score, + p.user_segment, + CASE +- WHEN p.usage_count > 100 AND p.feature_utilization_score > 0.6 ++ WHEN p.usage_count > 50 AND p.feature_utilization_score > 0.5 + THEN 'Power User' + WHEN p.usage_count > 50 + THEN 'Regular User' +Directly Modified: demo__dev.incrementals_demo (Forward-only) +Enter the effective date (eg. '1 year', '2020-01-01') to apply forward-only changes retroactively or blank to only apply them going forward once changes +are deployed to prod: +Models needing backfill (missing dates): +└── demo__dev.incrementals_demo: 2024-11-07 - 2024-11-07 (preview) +Enter the preview start date (eg. '1 year', '2020-01-01') or blank to backfill to preview starting from yesterday: 2024-10-27 +Enter the preview end date (eg. 
'1 month ago', '2020-01-01') or blank to preview up until '2024-11-08 00:00:00': +Apply - Preview Tables [y/n]: y +[1/1] demo__dev.incrementals_demo evaluated in 6.18s +Evaluating models ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 1/1 • 0:00:06 + + +All model batches have been executed successfully + +Virtually Updating 'dev' ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 0:00:01 + +The target environment has been updated successfully + +``` + +??? "Create another empty table with the proper schema that’s also versioned (ex: `__2896326998__dev__schema_migration_source`)." + + ```sql + CREATE TABLE IF NOT EXISTS `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__2896326998__dev__schema_migration_source` ( + `transaction_id` STRING, `product_id` STRING, `customer_id` STRING, `transaction_amount` NUMERIC, `transaction_date` DATE, + `transaction_timestamp_pst` DATETIME, `payment_method` STRING, `currency` STRING, `last_usage_date` TIMESTAMP, `usage_count` INT64, + `feature_utilization_score` FLOAT64, `user_segment` STRING, `user_type` STRING, `days_since_last_usage` INT64 + ) + PARTITION BY `transaction_date` + ``` + + +??? 
"Validate new SQL (note the `WHERE FALSE LIMIT 0` and the placeholder timestamps)" + + ```sql + WITH `sales_data` AS ( + SELECT + `sales`.`transaction_id` AS `transaction_id`, + `sales`.`product_id` AS `product_id`, + `sales`.`customer_id` AS `customer_id`, + `sales`.`transaction_amount` AS `transaction_amount`, + `sales`.`transaction_timestamp` AS `transaction_timestamp`, + `sales`.`payment_method` AS `payment_method`, + `sales`.`currency` AS `currency` + FROM `sqlmesh-public-demo`.`tcloud_raw_data`.`sales` AS `sales` + WHERE ( + `sales`.`transaction_timestamp` <= CAST('1970-01-01 23:59:59.999999+00:00' AS TIMESTAMP) + AND `sales`.`transaction_timestamp` >= CAST('1970-01-01 00:00:00+00:00' AS TIMESTAMP)) + AND FALSE + ), + `product_usage` AS ( + SELECT + `product_usage`.`product_id` AS `product_id`, + `product_usage`.`customer_id` AS `customer_id`, + `product_usage`.`last_usage_date` AS `last_usage_date`, + `product_usage`.`usage_count` AS `usage_count`, + `product_usage`.`feature_utilization_score` AS `feature_utilization_score`, + `product_usage`.`user_segment` AS `user_segment` + FROM `sqlmesh-public-demo`.`tcloud_raw_data`.`product_usage` AS `product_usage` + WHERE ( + `product_usage`.`last_usage_date` <= CAST('1970-01-01 23:59:59.999999+00:00' AS TIMESTAMP) + AND `product_usage`.`last_usage_date` >= CAST('1969-12-02 00:00:00+00:00' AS TIMESTAMP)) + AND FALSE + ) + SELECT + `s`.`transaction_id` AS `transaction_id`, + `s`.`product_id` AS `product_id`, + `s`.`customer_id` AS `customer_id`, + CAST(`s`.`transaction_amount` AS NUMERIC) AS `transaction_amount`, + DATE(`s`.`transaction_timestamp`) AS `transaction_date`, + DATETIME(`s`.`transaction_timestamp`, 'America/Los_Angeles') AS `transaction_timestamp_pst`, + `s`.`payment_method` AS `payment_method`, + `s`.`currency` AS `currency`, + `p`.`last_usage_date` AS `last_usage_date`, + `p`.`usage_count` AS `usage_count`, + `p`.`feature_utilization_score` AS `feature_utilization_score`, + `p`.`user_segment` AS 
`user_segment`, + CASE + WHEN `p`.`feature_utilization_score` > 0.5 AND `p`.`usage_count` > 50 THEN 'Power User' + WHEN `p`.`usage_count` > 50 THEN 'Regular User' + WHEN `p`.`usage_count` IS NULL THEN 'New User' + ELSE 'Light User' + END AS `user_type`, + DATE_DIFF(`s`.`transaction_timestamp`, `p`.`last_usage_date`, DAY) AS `days_since_last_usage` + FROM `sales_data` AS `s` + LEFT JOIN `product_usage` AS `p` ON + `p`.`customer_id` = `s`.`customer_id` + AND `p`.`product_id` = `s`.`product_id` + WHERE FALSE + LIMIT 0 + ``` + +??? "Create a **CLONE** of the table in the `preview` process so that we work with physical data for these specific backfill date ranges." + + This will NOT be reused when deployed to prod. + + ```sql + CREATE OR REPLACE TABLE `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__2896326998__dev` + CLONE `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__843752089` + ``` + +??? "Inspect metadata for this newly versioned table we’re creating, so we can properly track it from its journey from dev to prod eventually." + + This query examines the table's `INFORMATION_SCHEMA` metadata about column names and types to confirm for SQLMesh’s state that objects exist as expected. + + Since other actors could hypothetically touch/modify the project's tables, SQLMesh doesn’t ever reuse this info because it could have changed. That’s why we see this query executed so many times in the logs. 
+ + ```sql + WITH `clustering_info` AS ( + SELECT + `table_catalog`, + `table_schema`, + `table_name`, + STRING_AGG(`column_name` ORDER BY `clustering_ordinal_position`) AS `clustering_key` + FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`INFORMATION_SCHEMA`.`COLUMNS` + WHERE `clustering_ordinal_position` IS NOT NULL + GROUP BY 1, 2, 3 + ) + SELECT + `table_catalog` AS `catalog`, + `table_name` AS `name`, + `table_schema` AS `schema_name`, + CASE + WHEN `table_type` = 'BASE TABLE' THEN 'TABLE' + WHEN `table_type` = 'CLONE' THEN 'TABLE' + WHEN `table_type` = 'EXTERNAL' THEN 'TABLE' + WHEN `table_type` = 'SNAPSHOT' THEN 'TABLE' + WHEN `table_type` = 'VIEW' THEN 'VIEW' + WHEN `table_type` = 'MATERIALIZED VIEW' THEN 'MATERIALIZED_VIEW' + ELSE `table_type` END + AS `type`, + `ci`.`clustering_key` AS `clustering_key` + FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`INFORMATION_SCHEMA`.`TABLES` + LEFT JOIN `clustering_info` AS `ci` USING (`table_catalog`, `table_schema`, `table_name`) + WHERE `table_name` IN ('demo__incrementals_demo__2896326998__dev') + ``` + +??? 
"Inspect metadata to track journey for the migration source schema" + + ```sql + WITH `clustering_info` AS ( + SELECT + `table_catalog`, + `table_schema`, + `table_name`, + STRING_AGG(`column_name` ORDER BY `clustering_ordinal_position`) AS `clustering_key` + FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`INFORMATION_SCHEMA`.`COLUMNS` + WHERE `clustering_ordinal_position` IS NOT NULL + GROUP BY 1, 2, 3 + ) + SELECT + `table_catalog` AS `catalog`, + `table_name` AS `name`, + `table_schema` AS `schema_name`, + CASE + WHEN `table_type` = 'BASE TABLE' THEN 'TABLE' + WHEN `table_type` = 'CLONE' THEN 'TABLE' + WHEN `table_type` = 'EXTERNAL' THEN 'TABLE' + WHEN `table_type` = 'SNAPSHOT' THEN 'TABLE' + WHEN `table_type` = 'VIEW' THEN 'VIEW' + WHEN `table_type` = 'MATERIALIZED VIEW' THEN 'MATERIALIZED_VIEW' + ELSE `table_type` + END + AS `type`, + `ci`.`clustering_key` AS `clustering_key` + FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`INFORMATION_SCHEMA`.`TABLES` + LEFT JOIN `clustering_info` AS `ci` USING (`table_catalog`, `table_schema`, `table_name`) + WHERE `table_name` IN ('demo__incrementals_demo__2896326998__dev__schema_migration_source') + ``` + +??? "Drop the migration source table because we have the metadata we need now for proper state tracking" + + ```sql + DROP TABLE IF EXISTS `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__2896326998__dev__schema_migration_source` + ``` + +??? 
"Merge data into empty table for only the intervals I care about: 2024-10-27 to 'up until now'" + + ```sql + MERGE INTO `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__2896326998__dev` AS `__MERGE_TARGET__` USING ( + WITH `sales_data` AS ( + SELECT + `sales`.`transaction_id` AS `transaction_id`, + `sales`.`product_id` AS `product_id`, + `sales`.`customer_id` AS `customer_id`, + `sales`.`transaction_amount` AS `transaction_amount`, + `sales`.`transaction_timestamp` AS `transaction_timestamp`, + `sales`.`payment_method` AS `payment_method`, + `sales`.`currency` AS `currency` + FROM `sqlmesh-public-demo`.`tcloud_raw_data`.`sales` AS `sales` + WHERE + `sales`.`transaction_timestamp` <= CAST('2024-11-07 23:59:59.999999+00:00' AS TIMESTAMP) + AND `sales`.`transaction_timestamp` >= CAST('2024-10-27 00:00:00+00:00' AS TIMESTAMP) + ), + `product_usage` AS ( + SELECT + `product_usage`.`product_id` AS `product_id`, + `product_usage`.`customer_id` AS `customer_id`, + `product_usage`.`last_usage_date` AS `last_usage_date`, + `product_usage`.`usage_count` AS `usage_count`, + `product_usage`.`feature_utilization_score` AS `feature_utilization_score`, + `product_usage`.`user_segment` AS `user_segment` + FROM `sqlmesh-public-demo`.`tcloud_raw_data`.`product_usage` AS `product_usage` + WHERE + `product_usage`.`last_usage_date` <= CAST('2024-11-07 23:59:59.999999+00:00' AS TIMESTAMP) + AND `product_usage`.`last_usage_date` >= CAST('2024-09-27 00:00:00+00:00' AS TIMESTAMP) + ) + SELECT + `transaction_id`, + `product_id`, + `customer_id`, + `transaction_amount`, + `transaction_date`, + `transaction_timestamp_pst`, + `payment_method`, + `currency`, + `last_usage_date`, + `usage_count`, + `feature_utilization_score`, + `user_segment`, + `user_type`, + `days_since_last_usage` + FROM ( + SELECT + `s`.`transaction_id` AS `transaction_id`, + `s`.`product_id` AS `product_id`, + `s`.`customer_id` AS `customer_id`, + CAST(`s`.`transaction_amount` AS NUMERIC) AS 
`transaction_amount`,
+          DATE(`s`.`transaction_timestamp`) AS `transaction_date`,
+          DATETIME(`s`.`transaction_timestamp`, 'America/Los_Angeles') AS `transaction_timestamp_pst`,
+          `s`.`payment_method` AS `payment_method`,
+          `s`.`currency` AS `currency`,
+          `p`.`last_usage_date` AS `last_usage_date`,
+          `p`.`usage_count` AS `usage_count`,
+          `p`.`feature_utilization_score` AS `feature_utilization_score`,
+          `p`.`user_segment` AS `user_segment`,
+          CASE
+            WHEN `p`.`feature_utilization_score` > 0.5 AND `p`.`usage_count` > 50 THEN 'Power User'
+            WHEN `p`.`usage_count` > 50 THEN 'Regular User'
+            WHEN `p`.`usage_count` IS NULL THEN 'New User'
+            ELSE 'Light User'
+          END
+          AS `user_type`,
+          DATE_DIFF(`s`.`transaction_timestamp`, `p`.`last_usage_date`, DAY) AS `days_since_last_usage`
+        FROM `sales_data` AS `s`
+        LEFT JOIN `product_usage` AS `p` ON
+          `p`.`customer_id` = `s`.`customer_id`
+          AND `p`.`product_id` = `s`.`product_id`
+      ) AS `_subquery`
+      WHERE `transaction_date` BETWEEN CAST('2024-10-27' AS DATE) AND CAST('2024-11-07' AS DATE)
+    ) AS `__MERGE_SOURCE__`
+    ON FALSE
+    WHEN NOT MATCHED BY SOURCE AND `transaction_date` BETWEEN CAST('2024-10-27' AS DATE) AND CAST('2024-11-07' AS DATE) THEN DELETE
+    WHEN NOT MATCHED THEN INSERT (
+      `transaction_id`, `product_id`, `customer_id`, `transaction_amount`, `transaction_date`, `transaction_timestamp_pst`,
+      `payment_method`, `currency`, `last_usage_date`, `usage_count`, `feature_utilization_score`, `user_segment`, `user_type`,
+      `days_since_last_usage`)
+    VALUES (`transaction_id`, `product_id`, `customer_id`, `transaction_amount`, `transaction_date`, `transaction_timestamp_pst`,
+      `payment_method`, `currency`, `last_usage_date`, `usage_count`, `feature_utilization_score`, `user_segment`, `user_type`,
+      `days_since_last_usage`)
+    ```
+
+??? "Run data audits to test if `transaction_id` is unique and not null."
+
+    SQL is automatically generated for the preview data range in scope: 2024-10-27 to “up until now”. 
+ + `UNIQUE_VALUES()` audit + ```sql + SELECT + COUNT(*) + FROM ( + SELECT * + FROM ( + SELECT ROW_NUMBER() OVER ( + PARTITION BY (`transaction_id`) + ORDER BY (`transaction_id`) + ) AS `rank_` + FROM ( + SELECT * + FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__2896326998__dev` AS `demo__incrementals_demo__2896326998__dev` + WHERE `transaction_date` BETWEEN CAST('2024-10-27' AS DATE) AND CAST('2024-11-08' AS DATE) + ) AS `_q_0` + WHERE TRUE + ) AS `_q_1` + WHERE `rank_` > 1 + ) AS `audit` + ``` + + `NOT_NULL()` audit + ```sql + SELECT + COUNT(*) + FROM ( + SELECT * + FROM ( + SELECT * + FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__2896326998__dev` AS `demo__incrementals_demo__2896326998__dev` + WHERE `transaction_date` BETWEEN CAST('2024-10-27' AS DATE) AND CAST('2024-11-08' AS DATE) + ) AS `_q_0` + WHERE + (`transaction_id`) IS NULL + AND TRUE + ) AS `audit` + ``` + +??? "Create development schema based on the name of the plan dev environment" + + ```sql + CREATE SCHEMA IF NOT EXISTS `sqlmesh-public-demo`.`demo__dev` + ``` + +??? "Create a view in the virtual layer to officially query this new table version" + + ```sql + CREATE OR REPLACE VIEW `sqlmesh-public-demo`.`demo__dev`.`incrementals_demo` AS + SELECT * FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__2896326998__dev` + ``` + +Now I’m getting exactly what I expect when I preview the data. + +- Backfill (reprocess) the new definition of power user on and after 2024-10-27 in `dev` only +- See the new power user definition apply from 2024-10-27 to now +- Retain the old definition of power user before 2024-10-27 to preview the comparison + +An experience I’d manually do outside of my transformation workflow with a cobbling of python scripts and ad hoc SQL is now both clear and predictable and tracked in SQLMesh’s state history. 
+ +```bash +sqlmesh fetchdf "select * from demo__dev.incrementals_demo where usage_count>=50" +``` + +```bash +(venv) ✗ sqlmesh fetchdf "select * from demo__dev.incrementals_demo where usage_count>=50" + + transaction_id product_id customer_id transaction_amount ... feature_utilization_score user_segment user_type days_since_last_usage +0 TX-002 PROD-102 CUST-002 149.990000000 ... 0.92 enterprise Power User 0 +1 TX-001 PROD-101 CUST-001 99.990000000 ... 0.85 enterprise Power User 0 +2 TX-008 PROD-102 CUST-006 149.990000000 ... 0.88 enterprise Power User 0 +3 TX-006 PROD-101 CUST-001 99.990000000 ... 0.85 enterprise Power User 0 +4 TX-007 PROD-103 CUST-002 299.990000000 ... 0.68 enterprise Regular User 0 +5 TX-009 PROD-101 CUST-007 99.990000000 ... 0.55 professional Regular User 0 +6 TX-011 PROD-101 CUST-009 99.990000000 ... 0.65 professional Power User 0 +7 TX-013 PROD-103 CUST-001 299.990000000 ... 0.75 enterprise Power User 0 + +[8 rows x 14 columns] +``` + +Now, here, I may think through this question during development: + +- What if I don’t like the data results during the preview part of my `sqlmesh plan dev --forward-only`? + - I update my code changes, go through the above workflow again, and preview data for a specific date range whether for a regular `sqlmesh plan dev` or `sqlmesh plan dev --forward-only` + +## Adding Unit Tests + +Data audits are great, but they only verify basic things like primary key integrity. They don’t validate my SQL logic is doing exactly what I want. + +I know SQLMesh has unit tests, but the quiet part out loud is that I dislike writing so much `yaml` by hand. Thankfully, I don’t have to. + +I can use the `sqlmesh create_test` command to generate the unit test configuration file for me, using SQL queries to select and store the data the tests will run on. 
+ +```bash +sqlmesh create_test demo.incrementals_demo \ +--query sqlmesh-public-demo.tcloud_raw_data.product_usage "select * from sqlmesh-public-demo.tcloud_raw_data.product_usage where customer_id='CUST-001'" \ +--query sqlmesh-public-demo.tcloud_raw_data.sales "select * from sqlmesh-public-demo.tcloud_raw_data.sales where customer_id='CUST-001'" \ +--var start_dt '2024-10-25' \ +--var end_dt '2024-10-27' +``` + +It’ll create a unit test configuration file automatically like the below based on live queried data called `test_incrementals_demo.yaml`. I can then modify this file to my liking. + +??? "Unit test configuration file" + + ```yaml + test_incrementals_demo: + model: demo.incrementals_demo + inputs: + '`sqlmesh-public-demo`.`tcloud_raw_data`.`product_usage`': + - product_id: PROD-101 + customer_id: CUST-001 + last_usage_date: 2024-10-25 23:45:00+00:00 + usage_count: 120 + feature_utilization_score: 0.85 + user_segment: enterprise + - product_id: PROD-103 + customer_id: CUST-001 + last_usage_date: 2024-10-27 12:30:00+00:00 + usage_count: 95 + feature_utilization_score: 0.75 + user_segment: enterprise + '`sqlmesh-public-demo`.`tcloud_raw_data`.`sales`': + - transaction_id: TX-013 + product_id: PROD-103 + customer_id: CUST-001 + transaction_amount: '299.990000000' + transaction_timestamp: 2024-10-27 08:40:00+00:00 + payment_method: credit_card + currency: USD + - transaction_id: TX-006 + product_id: PROD-101 + customer_id: CUST-001 + transaction_amount: '99.990000000' + transaction_timestamp: 2024-10-26 03:15:00+00:00 + payment_method: credit_card + currency: USD + - transaction_id: TX-001 + product_id: PROD-101 + customer_id: CUST-001 + transaction_amount: '99.990000000' + transaction_timestamp: 2024-10-25 08:30:00+00:00 + payment_method: credit_card + currency: USD + outputs: + query: + - transaction_id: TX-006 + product_id: PROD-101 + customer_id: CUST-001 + transaction_amount: 99.99 + transaction_date: 2024-10-25 + transaction_timestamp_pst: 2024-10-25 
20:15:00 + payment_method: credit_card + currency: USD + last_usage_date: 2024-10-25 16:45:00-07:00 + usage_count: 120 + feature_utilization_score: 0.85 + user_segment: enterprise + user_type: Power User + days_since_last_usage: 0 + - transaction_id: TX-001 + product_id: PROD-101 + customer_id: CUST-001 + transaction_amount: 99.99 + transaction_date: 2024-10-25 + transaction_timestamp_pst: 2024-10-25 01:30:00 + payment_method: credit_card + currency: USD + last_usage_date: 2024-10-25 16:45:00-07:00 + usage_count: 120 + feature_utilization_score: 0.85 + user_segment: enterprise + user_type: Power User + days_since_last_usage: 0 + vars: + start_dt: '2024-10-25' + end_dt: '2024-10-27' + ``` + +Now, when I run `sqlmesh test` I run all my unit tests for free on my local machine. + +SQLMesh runs these unit test fixtures directly in [duckdb](https://duckdb.org/) in-memory by transpiling your specific database’s SQL syntax into the same meaning via [SQLGlot](https://github.com/tobymao/sqlglot). That’s why it runs so fast! + +??? "I can also run my unit tests against my main query engine to test things like UDFs or if there’s very specific SQL functions that do not neatly transpile to duckdb. Example test connection." + + ```yaml + gateways: + bigquery: + connection: + concurrent_tasks: 24 + register_comments: true + type: bigquery + method: service-account-json + keyfile_json: {{ env_var('GOOGLE_SQLMESH_CREDENTIALS') }} + project: sqlmesh-public-demo + test_connection: + concurrent_tasks: 24 + register_comments: true + type: bigquery + method: service-account-json + keyfile_json: {{ env_var('GOOGLE_SQLMESH_CREDENTIALS') }} + project: sqlmesh-public-demo + ``` + +```sql +(venv) ✗ sqlmesh test +... +---------------------------------------------------------------------- +Ran 3 tests in 0.090s + +OK +``` + +## Promoting Changes to Production + +Now that I’ve done all this great work, how do I get this promoted into production? 
+ +Typically, I will open a pull request combined with the [SQLMesh GitHub CI/CD bot](../integrations/github.md) as I mentioned earlier in this guide. But to keep it simple, I’ll run `sqlmesh plan` as I did above. + +This time because it’s promoting a forward-only dev model into prod, it’s a virtual update to the SQL definition. + +We run a bunch of metadata queries to version tables. More queries (read: 15/15 in the progress bar) are run in this forward-only model promotion to track schema evolution, if it appears, between the old and new schema. + +Next time it’s run, it’ll backfill new data with this new definition of ‘Power User’. + +```bash +sqlmesh plan +``` + +```bash +(venv) ➜ sqlmesh-demos git:(incremental-demo) ✗ sqlmesh plan +====================================================================== +Successfully Ran 3 tests against duckdb +---------------------------------------------------------------------- +Differences from the `prod` environment: + +Models: +└── Directly Modified: + └── demo.incrementals_demo +--- + ++++ + +@@ -57,7 +57,7 @@ + + p.feature_utilization_score, + p.user_segment, + CASE +- WHEN p.usage_count > 100 AND p.feature_utilization_score > 0.6 ++ WHEN p.usage_count > 50 AND p.feature_utilization_score > 0.5 + THEN 'Power User' + WHEN p.usage_count > 50 + THEN 'Regular User' +Directly Modified: demo.incrementals_demo (Forward-only) +Apply - Virtual Update [y/n]: y + +SKIP: No physical layer updates to perform + +SKIP: No model batches to execute + +Virtually Updating 'prod' ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 0:00:02 + +The target environment has been updated successfully +``` + +??? "Create production schema if it does not exist" + + ```sql + CREATE SCHEMA IF NOT EXISTS `sqlmesh-public-demo`.`demo` + ``` + +??? "Create a production version of the view. This is where SQLMesh reuses the hard work you’ve already done. No need to rerun audits." 
+ + ```sql + CREATE OR REPLACE VIEW `sqlmesh-public-demo`.`demo`.`incrementals_demo` AS + SELECT * FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__2896326998` + ``` + +Now, when a cron job runs on a schedule SQLMesh will track that midnight UTC has passed for a full day before running new intervals to backfill in this SQL model. Note: it will skip backfilling this model if a full day interval has not passed. + +The run will look and feel like the below as an example. + +```bash +sqlmesh run --select-model "demo.incrementals_demo" +``` + +```bash +(venv) ✗ sqlmesh run --select-model "demo.incrementals_demo" +[1/1] demo.incrementals_demo evaluated in 8.40s +Evaluating models ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 1/1 • 0:00:08 + + +All model batches have been executed successfully + +Run finished for environment 'prod' +``` + +??? "Merge data into empty table for only the intervals I have not backfilled since last running this command" + + ```sql + MERGE INTO `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__922005762` AS `__MERGE_TARGET__` USING ( + WITH `sales_data` AS ( + SELECT + `sales`.`transaction_id` AS `transaction_id`, + `sales`.`product_id` AS `product_id`, + `sales`.`customer_id` AS `customer_id`, + `sales`.`transaction_amount` AS `transaction_amount`, + `sales`.`transaction_timestamp` AS `transaction_timestamp`, + `sales`.`payment_method` AS `payment_method`, + `sales`.`currency` AS `currency` + FROM `sqlmesh-public-demo`.`tcloud_raw_data`.`sales` AS `sales` + WHERE + `sales`.`transaction_timestamp` <= CAST('2024-11-07 23:59:59.999999+00:00' AS TIMESTAMP) + AND `sales`.`transaction_timestamp` >= CAST('2024-11-03 00:00:00+00:00' AS TIMESTAMP) + ), + `product_usage` AS ( + SELECT + `product_usage`.`product_id` AS `product_id`, + `product_usage`.`customer_id` AS `customer_id`, + `product_usage`.`last_usage_date` AS `last_usage_date`, + `product_usage`.`usage_count` AS `usage_count`, + 
`product_usage`.`feature_utilization_score` AS `feature_utilization_score`, + `product_usage`.`user_segment` AS `user_segment` + FROM `sqlmesh-public-demo`.`tcloud_raw_data`.`product_usage` AS `product_usage` + WHERE + `product_usage`.`last_usage_date` <= CAST('2024-11-07 23:59:59.999999+00:00' AS TIMESTAMP) + AND `product_usage`.`last_usage_date` >= CAST('2024-10-04 00:00:00+00:00' AS TIMESTAMP) + ) + SELECT + `transaction_id`, + `product_id`, + `customer_id`, + `transaction_amount`, + `transaction_date`, + `transaction_timestamp_pst`, + `payment_method`, + `currency`, + `last_usage_date`, + `usage_count`, + `feature_utilization_score`, + `user_segment`, + `user_type`, + `days_since_last_usage` + FROM ( + SELECT + `s`.`transaction_id` AS `transaction_id`, + `s`.`product_id` AS `product_id`, + `s`.`customer_id` AS `customer_id`, + `s`.`transaction_amount` AS `transaction_amount`, + DATE(`s`.`transaction_timestamp`) AS `transaction_date`, + DATETIME(`s`.`transaction_timestamp`, 'America/Los_Angeles') AS `transaction_timestamp_pst`, + `s`.`payment_method` AS `payment_method`, + `s`.`currency` AS `currency`, + `p`.`last_usage_date` AS `last_usage_date`, + `p`.`usage_count` AS `usage_count`, + `p`.`feature_utilization_score` AS `feature_utilization_score`, + `p`.`user_segment` AS `user_segment`, + CASE + WHEN `p`.`feature_utilization_score` > 0.6 AND `p`.`usage_count` > 60 THEN 'Power User' + WHEN `p`.`usage_count` > 50 THEN 'Regular User' + WHEN `p`.`usage_count` IS NULL THEN 'New User' + ELSE 'Light User' + END + AS `user_type`, + DATE_DIFF(`s`.`transaction_timestamp`, `p`.`last_usage_date`, DAY) AS `days_since_last_usage` + FROM `sales_data` AS `s` + LEFT JOIN `product_usage` AS `p` ON + `p`.`customer_id` = `s`.`customer_id` + AND `p`.`product_id` = `s`.`product_id` + ) AS `_subquery` + WHERE + `transaction_date` BETWEEN CAST('2024-11-03' AS DATE) + AND CAST('2024-11-07' AS DATE) + ) AS `__MERGE_SOURCE__` + ON FALSE + WHEN NOT MATCHED BY SOURCE AND 
`transaction_date` BETWEEN CAST('2024-11-03' AS DATE) AND CAST('2024-11-07' AS DATE) THEN DELETE + WHEN NOT MATCHED THEN INSERT ( + `transaction_id`, `product_id`, `customer_id`, `transaction_amount`, `transaction_date`, `transaction_timestamp_pst`, + `payment_method`, `currency`, `last_usage_date`, `usage_count`, `feature_utilization_score`, `user_segment`, `user_type`, + `days_since_last_usage`) + VALUES (`transaction_id`, `product_id`, `customer_id`, `transaction_amount`, `transaction_date`, `transaction_timestamp_pst`, + `payment_method`, `currency`, `last_usage_date`, `usage_count`, `feature_utilization_score`, `user_segment`, `user_type`, + `days_since_last_usage`) + ``` + +??? "Run data audits to test if transaction_id is unique and not null. SQL is automatically generated." + + `UNIQUE_VALUES()` audit + ```sql + SELECT + COUNT(*) + FROM ( + SELECT * + FROM ( + SELECT + ROW_NUMBER() OVER ( + PARTITION BY (`transaction_id`) + ORDER BY (`transaction_id`) + ) AS `rank_` + FROM ( + SELECT * + FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__922005762` AS `demo__incrementals_demo__922005762` + WHERE `transaction_date` BETWEEN CAST('2024-11-03' AS DATE) AND CAST('2024-11-08' AS DATE) + ) AS `_q_0` + WHERE TRUE + ) AS `_q_1` + WHERE `rank_` > 1 + ) AS `audit` + ``` + + `NOT_NULL()` audit + ```sql + SELECT + COUNT(*) + FROM ( + SELECT * + FROM ( + SELECT * + FROM `sqlmesh-public-demo`.`sqlmesh__demo`.`demo__incrementals_demo__922005762` AS `demo__incrementals_demo__922005762` + WHERE `transaction_date` BETWEEN CAST('2024-11-03' AS DATE) AND CAST('2024-11-08' AS DATE) + ) AS `_q_0` + WHERE + (`transaction_id`) IS NULL + AND TRUE + ) AS `audit` + ``` + +## Summary + +I went through a full workflow for an intimidating problem, and it feels really good knowing what goes on behind the scenes when I run these SQLMesh commands. For those coming from other transformation frameworks like dbt, this is a new way to work. 
+ +It respects data as infrastructure vs. things to rebuild many times over each time you change something. I hope you feel equipped AND confident to start using SQLMesh and especially incremental models today! + +I’ll make it convenient for you in making sure we answered all the pertinent questions. + +- How do I handle late arriving data? + - Use the `lookback` config. +- How do I account for UTC vs. PST (California baby) timestamps, do I convert them? + - See the SQL logic for how everything is in UTC by default for safe and reliable processing and then convert the presentation timestamp to PST for downstream tools (ex: business intelligence, data sharing) +- What schedule should I run these at? + - Daily is a default as we don’t want to show incomplete intervals when merging product and sales information. You can go as low as 5 minutes. +- How do I test this data? + - Unit tests for code, audits for data +- How do I make this run fast and only the intervals necessary (read: partitions)? + - SQLMesh macros work by default to run and test only the intervals necessary because it manages state. + - No `max(timestamp)` acrobatics. +- How do I make patch changes when an edge case error occurs with incorrect data for specific time ranges? + - Make code changes and backfill only what’s necessary safely in dev before promoting to prod. + - Retain history AND correct changes for specific time ranges. + - Check out the forward-only example above and notice you can make **retroactive** changes to prod. +- What do unit tests look and feel like for this? + - See automatic unit test creation above. No manual `yaml` handwriting! +- How do I prevent data gaps with unprocessed or incomplete intervals? + - SQLMesh manages state, so it will track which intervals were backfilled vs. not. + - Even if an interval failed during a scheduled `sqlmesh run`, it will recognize that the next time this command is run and attempt to backfill that previously failed interval. 
+ - No `max(timestamp)` acrobatics. +- Am I okay processing incomplete intervals (think: allow partials)? + - I'm only okay with allowing partial intervals to be processed for things like logging event data, but for sales and product data, I want to make sure complete intervals are processed so end users don't confuse incomplete data with incorrect data. +- What tradeoffs am I willing to make for fresh data? + - I prefer complete data over fresh data for its own sake. Correctness matters when viewing revenue data. +- How do I make this not feel so complex during development? + - Hopefully this guide helps ;) +- How do I know SQLMesh is behaving how I want it to behave? + - See the queries run by SQLMesh above. They’re listed out exactly as listed in the query history. + - I skip listing out basic metadata queries and test connection queries like `SELECT 1` as those are more background tasks than core logic tasks. +- Bonus question: How does this compare to dbt’s way of handling incrementals? + - [See here for a complete comparison](https://tobikodata.com/dbt-incremental-but-incomplete.html) diff --git a/docs/examples/overview.md b/docs/examples/overview.md new file mode 100644 index 0000000000..e7dbc1916d --- /dev/null +++ b/docs/examples/overview.md @@ -0,0 +1,42 @@ +# Overview + +Realistic examples are a fantastic way to understand SQLMesh better. + +They allow you to tinker with a project's code and data, issuing different SQLMesh commands to see what happens. + +You can reset the examples at any time, so if things get turned around you can just start over! 
+ +This page links to a few different types of examples: + +- **Walkthroughs** pose a specific story or task, and you follow along as we work through the story + - Walkthroughs **do not** require running code, although the code is available if you would like to + - Different walkthroughs use different SQL engines, so if you want to run the code you might need to update it for your SQL engine +- **Projects** are self-contained SQLMesh projects and datasets + - Projects generally use DuckDB so you can run them locally without installing or accessing a separate SQL engine + +!!! tip + + If you haven't tried out SQLMesh before, we recommend working through the [SQLMesh Quickstart](../quick_start.md) before trying these examples! + +## Walkthroughs + +Walkthroughs are easy to follow and provide lots of information in a self-contained format. + +- Get the SQLMesh workflow under your fingers with the [SQLMesh CLI Crash Course](./sqlmesh_cli_crash_course.md) +- See the end-to-end workflow in action with the [Incremental by Time Range: Full Walkthrough](./incremental_time_full_walkthrough.md) (BigQuery SQL engine) + +## Projects + +SQLMesh example projects are stored in the [sqlmesh-examples GitHub repository](https://github.com/SQLMesh/sqlmesh-examples). The repository's front page includes additional information about how to download the files and set up the projects. + +The two most comprehensive example projects use the SQLMesh `sushi` data, based on a fictional sushi restaurant. ("Tobiko" is the Japanese word for flying fish roe, commonly used in sushi.) + +The `sushi` data is described in an [overview notebook](https://github.com/SQLMesh/sqlmesh-examples/blob/main/001_sushi/sushi-overview.ipynb) in the repository. 
+ +The example repository includes two versions of the `sushi` project, at different levels of complexity: + +- The [`simple` project](https://github.com/SQLMesh/sqlmesh-examples/tree/main/001_sushi/1_simple) contains four `VIEW` and one `SEED` model + - The `VIEW` model kind refreshes every run, making it easy to reason about SQLMesh's behavior +- The [`moderate` project](https://github.com/SQLMesh/sqlmesh-examples/tree/main/001_sushi/2_moderate) contains five `INCREMENTAL_BY_TIME_RANGE`, one `FULL`, one `VIEW`, and one `SEED` model + - The incremental models allow you to observe how and when new data is transformed by SQLMesh + - Some models, like `customer_revenue_lifetime`, demonstrate more advanced incremental queries like customer lifetime value calculation diff --git a/docs/examples/sqlmesh_cli_crash_course.md b/docs/examples/sqlmesh_cli_crash_course.md new file mode 100644 index 0000000000..0bf5780f12 --- /dev/null +++ b/docs/examples/sqlmesh_cli_crash_course.md @@ -0,0 +1,1257 @@ +# SQLMesh CLI Crash Course + +
+ +This doc is designed to get you intimate with a **majority** of the SQLMesh workflows you’ll use to build *and* maintain transformation data pipelines. The goal is to get SQLMesh into muscle memory in 30 minutes or less. + +This doc is inspired by community observations, face-to-face conversations, live screenshares, and debugging sessions. This is *not* an exhaustive list but is rooted in lived experience. + +You can follow along in the [open source GitHub repo](https://github.com/sungchun12/sqlmesh-cli-crash-course). + +If you're new to how SQLMesh uses virtual data environments, [watch this quick explainer](https://www.loom.com/share/216835d64b3a4d56b2e061fa4bd9ee76?sid=88b3289f-e19b-4ccc-8b88-3faf9d7c9ce3). + +!!! tip + + Put this page on your second monitor or in a side by side window to swiftly copy/paste into your terminal. + +## Development Workflow + +You’ll use these commands 80% of the time because this is how you apply the changes you make to models. The workflow is: + +1. Make changes to your models directly in SQL and python files (pre-made in examples below) +2. Plan the changes in your dev environment +3. Apply the changes to your dev environment +4. Audit the changes (test data quality) +5. Run data diff against prod +6. Apply the changes to prod + +### Preview, Apply, and Audit Changes in `dev` + +You can make changes quickly and confidently through one simple command: `sqlmesh plan dev` + +- Plan the changes in your dev environment. +- Apply the changes to your dev environment by entering `y` at the prompt. +- Audit the changes (test data quality). This happens automatically when you apply the changes to dev. + +Note: If you run this without making any changes, SQLMesh will prompt you to make changes or use the `--include-unmodified` flag like this `sqlmesh plan dev --include-unmodified`. We recommend you make changes first before running this command to avoid creating a lot of noise in your dev environment with extraneous virtual layer views. 
+ +=== "SQLMesh" + + ```bash + sqlmesh plan dev + ``` + + ```bash + sqlmesh plan + ``` + + If you want to move faster, you can add the `--auto-apply` flag to skip the manual prompt and apply the plan. You should do this when you're familiar with the plan output, and don't need to see tiny changes in the diff output before applying the plan. + + ```bash + sqlmesh plan --auto-apply + ``` + +=== "Tobiko Cloud" + + ```bash + tcloud sqlmesh plan dev + ``` + + ```bash + tcloud sqlmesh plan + ``` + + If you want to move faster, you can add the `--auto-apply` flag to skip the manual prompt and apply the plan. You should do this when you're familiar with the plan output, and don't need to see tiny changes in the diff output before applying the plan. + + ```bash + tcloud sqlmesh plan --auto-apply + ``` + +??? "Example Output" + I made a breaking change to `incremental_model` and `full_model`. + + SQLMesh: + + - Showed me the models impacted by the changes. + - Showed me the changes that will be made to the models. + - Showed me the models that need to be backfilled. + - Prompted me to apply the changes to `dev`. + - Showed me the audit failures that raise as warnings. + - Updated the physical layer to validate the SQL. + - Executed the model batches by inserting the data into the physical layer. + - Updated the virtual layer's view pointers to reflect the changes. 
+ + ```bash + > sqlmesh plan dev + Differences from the `dev` environment: + + Models: + ├── Directly Modified: + │ ├── sqlmesh_example__dev.incremental_model + │ └── sqlmesh_example__dev.full_model + └── Indirectly Modified: + └── sqlmesh_example__dev.view_model + + --- + + +++ + + @@ -9,7 +9,8 @@ + + SELECT + item_id, + COUNT(DISTINCT id) AS num_orders, + - 6 AS new_column + + new_column + FROM sqlmesh_example.incremental_model + GROUP BY + - item_id + + item_id, + + new_column + + Directly Modified: sqlmesh_example__dev.full_model (Breaking) + + --- + + +++ + + @@ -15,7 +15,7 @@ + + id, + item_id, + event_date, + - 5 AS new_column + + 7 AS new_column + FROM sqlmesh_example.seed_model + WHERE + event_date BETWEEN @start_date AND @end_date + + Directly Modified: sqlmesh_example__dev.incremental_model (Breaking) + └── Indirectly Modified Children: + └── sqlmesh_example__dev.view_model (Indirect Breaking) + Models needing backfill: + ├── sqlmesh_example__dev.full_model: [full refresh] + ├── sqlmesh_example__dev.incremental_model: [2020-01-01 - 2025-04-16] + └── sqlmesh_example__dev.view_model: [recreate view] + Apply - Backfill Tables [y/n]: y + + Updating physical layer ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 2/2 • 0:00:00 + + ✔ Physical layer updated + + [1/1] sqlmesh_example__dev.incremental_model [insert 2020-01-01 - 2025-04-16] 0.03s + Executing model batches ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 0.0% • pending • 0:00:00 + sqlmesh_example__dev.incremental_model . + [WARNING] sqlmesh_example__dev.full_model: 'assert_positive_order_ids' audit error: 2 rows failed. Learn more in logs: + /Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/logs/sqlmesh_2025_04_18_10_33_43.log + [1/1] sqlmesh_example__dev.full_model [full refresh, audits ❌1] 0.01s + Executing model batches ━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━━━━ 33.3% • 1/3 • 0:00:00 + sqlmesh_example__dev.full_model . 
+ [WARNING] sqlmesh_example__dev.view_model: 'assert_positive_order_ids' audit error: 2 rows failed. Learn more in logs: + /Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/logs/sqlmesh_2025_04_18_10_33_43.log + [1/1] sqlmesh_example__dev.view_model [recreate view, audits ✔2 ❌1] 0.01s + Executing model batches ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 + + ✔ Model batches executed + + Updating virtual layer ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 + + ✔ Virtual layer updated + ``` + +### Run Data Diff Against Prod + +
+ +Run data diff against prod. This is a good way to verify the changes are behaving as expected **after** applying them to `dev`. + +To make this easier and faster, you can run data diff against all models in the environment impacted by plan changes applied using the `-m '*'` flag example below. No need to specify the model name! Read more about options [here](../guides/tablediff.md). + +=== "SQLMesh" + + ```bash + sqlmesh table_diff prod:dev sqlmesh_example.full_model --show-sample + ``` + + ```bash + sqlmesh table_diff : --show-sample + ``` + + ```bash + sqlmesh table_diff prod:dev -m '*' --show-sample + ``` + +=== "Tobiko Cloud" + + ```bash + tcloud sqlmesh table_diff prod:dev sqlmesh_example.full_model --show-sample + ``` + + ```bash + tcloud sqlmesh table_diff : --show-sample + ``` + + ```bash + tcloud sqlmesh table_diff prod:dev -m '*' --show-sample + ``` + +??? "Example Output" + I compare the `prod` and `dev` environments for `sqlmesh_example.full_model`. + + - Verified environments and models to diff along with the join on grain configured. + - Showed me schema diffs between the environments. + - Showed me row count diffs between the environments. + - Showed me common rows stats between the environments. + - Showed me sample data differences between the environments. + - This is where your human judgement comes in to verify the changes are behaving as expected. + + Model definition: + ```sql linenums="1" hl_lines="6" + -- models/full_model.sql + MODEL ( + name sqlmesh_example.full_model, + kind FULL, + cron '@daily', + grain item_id, -- grain is optional BUT necessary for table diffs to work correctly. It's your primary key that is unique and not null. 
+ audits (assert_positive_order_ids), + ); + + SELECT + item_id, + COUNT(DISTINCT id) AS num_orders, + new_column + FROM + sqlmesh_example.incremental_model + GROUP BY item_id, new_column + ``` + + Table diff: + ```bash + > sqlmesh table_diff prod:dev sqlmesh_example.full_model --show-sample + Table Diff + ├── Model: + │ └── sqlmesh_example.full_model + ├── Environment: + │ ├── Source: prod + │ └── Target: dev + ├── Tables: + │ ├── Source: db.sqlmesh_example.full_model + │ └── Target: db.sqlmesh_example__dev.full_model + └── Join On: + └── item_id + + Schema Diff Between 'PROD' and 'DEV' environments for model 'sqlmesh_example.full_model': + └── Schemas match + + + Row Counts: + └── PARTIAL MATCH: 5 rows (100.0%) + + COMMON ROWS column comparison stats: + pct_match + num_orders 100.0 + new_column 0.0 + + + COMMON ROWS sample data differences: + Column: new_column + ┏━━━━━━━━━┳━━━━━━┳━━━━━┓ + ┃ item_id ┃ PROD ┃ DEV ┃ + ┡━━━━━━━━━╇━━━━━━╇━━━━━┩ + │ -11 │ 5 │ 7 │ + │ -3 │ 5 │ 7 │ + │ 1 │ 5 │ 7 │ + │ 3 │ 5 │ 7 │ + │ 9 │ 5 │ 7 │ + └─────────┴──────┴─────┘ + ``` + +### Apply Changes to Prod + +After you feel confident about the changes, apply them to `prod`. + +!!! warning "Apply the changes to prod" + We recommend only applying changes to `prod` [**using CI/CD**](../integrations/github.md) as best practice. + For learning purposes and hot fixes, you can manually apply the changes to prod by entering `y` at the prompt. + +=== "SQLMesh" + + ```bash + sqlmesh plan + ``` + +=== "Tobiko Cloud" + + ```bash + tcloud sqlmesh plan + ``` + +??? "Example Output" + After I feel confident about the changes, I apply them to `prod`. + + SQLMesh: + + - Showed me the models impacted by the changes. + - Showed me the changes that will be made to the models. + - Showed me the models that need to be backfilled. None in this case as it was already backfilled earlier in `dev`. + - Prompted me to apply the changes to `prod`. 
+ - Showed me physical layer and execution steps are skipped as the changes were already applied to `dev`. + - Updated the virtual layer view pointers to reflect the changes. + + ```bash + > sqlmesh plan + Differences from the `prod` environment: + + Models: + ├── Directly Modified: + │ ├── sqlmesh_example.full_model + │ └── sqlmesh_example.incremental_model + └── Indirectly Modified: + └── sqlmesh_example.view_model + + --- + + +++ + + @@ -9,7 +9,8 @@ + + SELECT + item_id, + COUNT(DISTINCT id) AS num_orders, + - 5 AS new_column + + new_column + FROM sqlmesh_example.incremental_model + GROUP BY + - item_id + + item_id, + + new_column + + Directly Modified: sqlmesh_example.full_model (Breaking) + + --- + + +++ + + @@ -15,7 +15,7 @@ + + id, + item_id, + event_date, + - 5 AS new_column + + 7 AS new_column + FROM sqlmesh_example.seed_model + WHERE + event_date BETWEEN @start_date AND @end_date + + Directly Modified: sqlmesh_example.incremental_model (Breaking) + └── Indirectly Modified Children: + └── sqlmesh_example.view_model (Indirect Breaking) + Apply - Virtual Update [y/n]: y + + SKIP: No physical layer updates to perform + + SKIP: No model batches to execute + + Updating virtual layer ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 + + ✔ Virtual layer updated + ``` + +--- + +## Enhanced Testing Workflow + +You'll use these commands to validate your changes are behaving as expected. Audits (data tests) are a great first step, and you'll want to grow from there to feel confident about your pipelines. The workflow is as follows: + +1. Create and audit external models outside of SQLMesh's control (ex: data loaded in by Fivetran, Airbyte, etc.) +2. Automatically generate unit tests for your models +3. Ad hoc query the data directly in the CLI +4. Lint your models to catch known syntax errors + +--- + +### Create and Audit External Models + +Sometimes models `SELECT` from tables/views that are outside of SQLMesh's control. 
SQLMesh can automatically parse their fully qualified names from model definitions (ex: `bigquery-public-data`.`ga4_obfuscated_sample_ecommerce`.`events_20210131`) and determine their full schemas and column data types. + +These "external model" schemas are used for column level lineage. You can also add audits to test data quality. If an audit fails, SQLMesh prevents downstream models from wastefully running. + +=== "SQLMesh" + + ```bash + sqlmesh create_external_models + ``` + +=== "Tobiko Cloud" + + ```bash + tcloud sqlmesh create_external_models + ``` + +??? "Example Output" + Note: this is an example from a separate Tobiko Cloud project, so you can't follow along in the Github repo. + + - Generated external models from the `bigquery-public-data`.`ga4_obfuscated_sample_ecommerce`.`events_20210131` table parsed in the model's SQL. + - Added an audit to the external model to ensure `event_date` is not NULL. + - Viewed a plan preview of the changes that will be made for the external model. 
+ + ```sql linenums="1" hl_lines="29" title="models/external_model_example.sql" + MODEL ( + name tcloud_demo.external_model + ); + + SELECT + event_date, + event_timestamp, + event_name, + event_params, + event_previous_timestamp, + event_value_in_usd, + event_bundle_sequence_id, + event_server_timestamp_offset, + user_id, + user_pseudo_id, + privacy_info, + user_properties, + user_first_touch_timestamp, + user_ltv, + device, + geo, + app_info, + traffic_source, + stream_id, + platform, + event_dimensions, + ecommerce + /* items */ + FROM bigquery-public-data.ga4_obfuscated_sample_ecommerce.events_20210131 -- I fully qualified the external table name and sqlmesh will automatically create the external model + ``` + + `sqlmesh create_external_models` output file: + + ```yaml linenums="1" hl_lines="2 3 4" title="external_models.yaml" + - name: '`bigquery-public-data`.`ga4_obfuscated_sample_ecommerce`.`events_20210131`' + audits: # I added this audit manually to the external model YAML file + - name: not_null + columns: "[event_date]" + columns: + event_date: STRING + event_timestamp: INT64 + event_name: STRING + event_params: ARRAY>> + event_previous_timestamp: INT64 + event_value_in_usd: FLOAT64 + event_bundle_sequence_id: INT64 + event_server_timestamp_offset: INT64 + user_id: STRING + user_pseudo_id: STRING + privacy_info: STRUCT + user_properties: ARRAY>> + user_first_touch_timestamp: INT64 + user_ltv: STRUCT + device: STRUCT> + geo: STRUCT + app_info: STRUCT + traffic_source: STRUCT + stream_id: INT64 + platform: STRING + event_dimensions: STRUCT + ecommerce: STRUCT + items: ARRAY> + gateway: public-demo + ``` + + ```bash + > sqlmesh plan dev_sung + Differences from the `dev_sung` environment: + + Models: + └── Metadata Updated: + └── "bigquery-public-data".ga4_obfuscated_sample_ecommerce__dev_sung.events_20210131 + + --- + + +++ + + @@ -29,5 +29,6 @@ + + ecommerce STRUCT, + items ARRAY> + ), + + audits (not_null('columns' = [event_date])), + gateway 
`public-demo` + ) + + Metadata Updated: "bigquery-public-data".ga4_obfuscated_sample_ecommerce__dev_sung.events_20210131 + Models needing backfill: + └── "bigquery-public-data".ga4_obfuscated_sample_ecommerce__dev_sung.events_20210131: [full refresh] + Apply - Backfill Tables [y/n]: + ``` + +### Automatically Generate Unit Tests + +You can ensure business logic is working as expected by running your models against static sample data. + +Unit tests run *before* a plan is applied automatically. This is great for testing complex business logic (ex: `CASE WHEN` conditions) *before* you backfill data. No need to write them manually, either! + +=== "SQLMesh" + + Create a unit test based on 5 rows from the upstream `sqlmesh_example.incremental_model`. + + ```bash + sqlmesh create_test sqlmesh_example.full_model \ + --query sqlmesh_example.incremental_model \ + "select * from sqlmesh_example.incremental_model limit 5" + ``` + + ```bash + sqlmesh create_test \ + --query \ + "select * from limit 5" + ``` + + +=== "Tobiko Cloud" + + ```bash + tcloud sqlmesh create_test demo.stg_payments \ + --query demo.seed_raw_payments \ + "select * from demo.seed_raw_payments limit 5" + ``` + + ```bash + tcloud sqlmesh create_test \ + --query \ + "select * from limit 5" + ``` + +??? "Example Output" + + SQLMesh: + + - Generated unit tests for the `sqlmesh_example.full_model` model by live querying the data. + - Ran the tests and they passed locally in DuckDB. + - If you're using a cloud data warehouse, this will transpile your SQL syntax to its equivalent in duckdb. + - This runs fast and free on your local machine. 
+ + Generated test definition file: + + ```yaml linenums="1" title="tests/test_full_model.yaml" + test_full_model: + model: '"db"."sqlmesh_example"."full_model"' + inputs: + '"db"."sqlmesh_example"."incremental_model"': + - id: -11 + item_id: -11 + event_date: 2020-01-01 + new_column: 7 + - id: 1 + item_id: 1 + event_date: 2020-01-01 + new_column: 7 + - id: 3 + item_id: 3 + event_date: 2020-01-03 + new_column: 7 + - id: 4 + item_id: 1 + event_date: 2020-01-04 + new_column: 7 + - id: 5 + item_id: 1 + event_date: 2020-01-05 + new_column: 7 + outputs: + query: + - item_id: 3 + num_orders: 1 + new_column: 7 + - item_id: 1 + num_orders: 3 + new_column: 7 + - item_id: -11 + num_orders: 1 + new_column: 7 + ``` + + Manually execute tests with `sqlmesh test`: + + ```bash + (demo) ➜ demo git:(main) ✗ sqlmesh test + . + ---------------------------------------------------------------------- + Ran 1 test in 0.053s + + OK + ``` + + ```bash + # what do we see if the test fails? + (demo) ➜ demo git:(main) ✗ sqlmesh test + F + ====================================================================== + FAIL: test_full_model (/Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/tests/test_full_model.yaml) + None + ---------------------------------------------------------------------- + AssertionError: Data mismatch (exp: expected, act: actual) + + new_column + exp act + 0 0.0 7.0 + + ---------------------------------------------------------------------- + Ran 1 test in 0.020s + + FAILED (failures=1) + ``` + +### Run Ad-Hoc Queries + +You can run live queries directly from the CLI. This is great to validate the look and feel of your changes without context switching to your query console. + +Pro tip: run this after `sqlmesh table_diff` to get a full picture of your changes. + +=== "SQLMesh" + + ```bash + sqlmesh fetchdf "select * from sqlmesh_example__dev.full_model limit 5" + ``` + + ```bash + # construct arbitrary query + sqlmesh fetchdf "select * from . 
limit 5" # double underscore in schema name is important. Not needed for prod. + ``` + +=== "Tobiko Cloud" + + ```bash + tcloud sqlmesh fetchdf "select * from sqlmesh_example__dev.full_model limit 5" + ``` + + ```bash + # construct arbitrary query + tcloud sqlmesh fetchdf "select * from . limit 5" # double underscore in schema name is important. Not needed for prod. + ``` + +??? "Example Output" + ```bash + item_id num_orders new_column + 0 9 1 7 + 1 -11 1 7 + 2 3 1 7 + 3 -3 1 7 + 4 1 4 7 + ``` + +### Linting + +If enabled, linting runs automatically during development. The linting rules can be overridden per model, too. + +This is a great way to catch SQL issues before wasting runtime in your data warehouse. It runs automatically, or you can run it manually to proactively check for any issues. + +=== "SQLMesh" + + ```bash + sqlmesh lint + ``` + +=== "Tobiko Cloud" + + ```bash + tcloud sqlmesh lint + ``` + +??? "Example Output" + + You add linting rules in your `config.yaml` file. + + ```yaml linenums="1" hl_lines="13-17" title="config.yaml" + gateways: + duckdb: + connection: + type: duckdb + database: db.db + + default_gateway: duckdb + + model_defaults: + dialect: duckdb + start: 2025-03-26 + + linter: + enabled: true + rules: ["ambiguousorinvalidcolumn", "invalidselectstarexpansion"] # raise errors for these rules + warn_rules: ["noselectstar", "nomissingaudits"] + # ignored_rules: ["noselectstar"] + ``` + + ```bash + > sqlmesh lint + [WARNING] Linter warnings for /Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/models/lint_warn.sql: + - noselectstar: Query should not contain SELECT * on its outer most projections, even if it can be + expanded. + - nomissingaudits: Model `audits` must be configured to test data quality. + [WARNING] Linter warnings for + /Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/models/incremental_by_partition.sql: + - nomissingaudits: Model `audits` must be configured to test data quality. 
+ [WARNING] Linter warnings for /Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/models/seed_model.sql: + - nomissingaudits: Model `audits` must be configured to test data quality. + [WARNING] Linter warnings for + /Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/models/incremental_by_unique_key.sql: + - nomissingaudits: Model `audits` must be configured to test data quality. + [WARNING] Linter warnings for + /Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/models/incremental_model.sql: + - nomissingaudits: Model `audits` must be configured to test data quality. + ``` + +## Debugging Workflow + +You'll use these commands as needed to validate that your changes are behaving as expected. This is great to get more details beyond the defaults above. The workflow is as follows: + +1. Render the model to verify the SQL is looking as expected. +2. Run SQLMesh in verbose mode so you can verify its behavior. +3. View the logs easily in your terminal. + +### Render your SQL Changes + +This is a great way to verify that your model's SQL is looking as expected before applying the changes. It is especially important if you're migrating from one query engine to another (ex: postgres to databricks). + +=== "SQLMesh" + + ```bash + sqlmesh render sqlmesh_example.incremental_model + ``` + + ```bash + sqlmesh render sqlmesh_example.incremental_model --dialect databricks + ``` + + ```bash + sqlmesh render --dialect + ``` + +=== "Tobiko Cloud" + + ```bash + tcloud sqlmesh render sqlmesh_example.incremental_model + ``` + + ```bash + tcloud sqlmesh render sqlmesh_example.incremental_model --dialect databricks + ``` + + ```bash + tcloud sqlmesh render --dialect + ``` + +??? 
"Example Output" + + Model definition: + + ```sql linenums="1" title="models/incremental_model.sql" + MODEL ( + name sqlmesh_example.incremental_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column event_date + ), + start '2020-01-01', + cron '@daily', + grain (id, event_date) + ); + + SELECT + id, + item_id, + event_date, + 7 as new_column + FROM + sqlmesh_example.seed_model + WHERE + event_date BETWEEN @start_date AND @end_date + ``` + + SQLMesh returns the full SQL code in the default or target dialect. + + ```sql hl_lines="11" + > sqlmesh render sqlmesh_example.incremental_model + -- rendered sql in default dialect + SELECT + "seed_model"."id" AS "id", + "seed_model"."item_id" AS "item_id", + "seed_model"."event_date" AS "event_date", + 7 AS "new_column" + FROM "db"."sqlmesh__sqlmesh_example"."sqlmesh_example__seed_model__3294646944" AS "seed_model" /* + db.sqlmesh_example.seed_model */ + WHERE + "seed_model"."event_date" <= CAST('1970-01-01' AS DATE) -- placeholder dates for date macros + AND "seed_model"."event_date" >= CAST('1970-01-01' AS DATE) + ``` + + ```sql + > sqlmesh render sqlmesh_example.incremental_model --dialect databricks + -- rendered sql in databricks dialect + SELECT + `seed_model`.`id` AS `id`, + `seed_model`.`item_id` AS `item_id`, + `seed_model`.`event_date` AS `event_date`, + 7 AS `new_column` + FROM `db`.`sqlmesh__sqlmesh_example`.`sqlmesh_example__seed_model__3294646944` AS `seed_model` /* + db.sqlmesh_example.seed_model */ + WHERE + `seed_model`.`event_date` <= CAST('1970-01-01' AS DATE) + AND `seed_model`.`event_date` >= CAST('1970-01-01' AS DATE) + ``` + +### Apply Plan Changes in Verbose Mode + +Verbose mode lets you see detailed operations in the physical and virtual layers. This is useful to see exactly what SQLMesh is doing every step. After, you can copy/paste the fully qualified table/view name into your query console to validate the data (if that's your preference). 
+ +=== "SQLMesh" + + ```bash + sqlmesh plan dev -vv + ``` + + ```bash + sqlmesh plan -vv + ``` + +=== "Tobiko Cloud" + + ```bash + tcloud sqlmesh plan dev -vv + ``` + + ```bash + tcloud sqlmesh plan -vv + ``` + +??? "Example Output" + + ```bash hl_lines="48-50" + > sqlmesh plan dev -vv + [WARNING] Linter warnings for + /Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/models/incremental_by_partition.sql: + - nomissingaudits: Model `audits` must be configured to test data quality. + [WARNING] Linter warnings for /Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/models/seed_model.sql: + - nomissingaudits: Model `audits` must be configured to test data quality. + [WARNING] Linter warnings for + /Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/models/incremental_by_unique_key.sql: + - nomissingaudits: Model `audits` must be configured to test data quality. + [WARNING] Linter warnings for + /Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/models/incremental_model.sql: + - nomissingaudits: Model `audits` must be configured to test data quality. + + Differences from the `dev` environment: + + Models: + ├── Directly Modified: + │ └── db.sqlmesh_example__dev.incremental_model + └── Indirectly Modified: + ├── db.sqlmesh_example__dev.full_model + └── db.sqlmesh_example__dev.view_model + + --- + + +++ + + @@ -15,7 +15,7 @@ + + id, + item_id, + event_date, + - 9 AS new_column + + 7 AS new_column + FROM sqlmesh_example.seed_model + WHERE + event_date BETWEEN @start_date AND @end_date + + Directly Modified: db.sqlmesh_example__dev.incremental_model (Breaking) + └── Indirectly Modified Children: + ├── db.sqlmesh_example__dev.full_model (Breaking) + └── db.sqlmesh_example__dev.view_model (Indirect Breaking) + Apply - Virtual Update [y/n]: y + + SKIP: No physical layer updates to perform + + SKIP: No model batches to execute + + db.sqlmesh_example__dev.incremental_model updated # you'll notice that it's updated vs. 
promoted because we changed the existing view definition + db.sqlmesh_example__dev.full_model updated + db.sqlmesh_example__dev.view_model updated + Updating virtual layer ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 + + ✔ Virtual layer updated + ``` + +### View Logs Easily + +Each time you perform a SQLMesh command, it creates a log file in the `logs` directory. You can view them by manually navigating to the correct file name with latest timestamp or with this simple shell command. + +This is useful to see the exact queries that were executed to apply your changes. Admittedly, this is outside of native functionality, but it's a quick and easy way to view logs. + +```bash +# install this open source tool that enhances the default `cat` command +# https://github.com/sharkdp/bat +brew install bat # installation command if using homebrew +``` + +```bash +bat --theme='ansi' $(ls -t logs/ | head -n 1 | sed 's/^/logs\//') +``` + +- In simple terms this command works like this: "Show me the contents of the newest log file in the `logs/` directory, with nice formatting and syntax highlighting.” +- press `q` to quit out of big files in the terminal + +??? "Example Output" + + This is the log file for the `sqlmesh plan dev` command. If you want to see the log file directly, you can click on the file path in the output to open it in your code editor. 
+ + ```bash + ──────┬────────────────────────────────────────────────────────────────────────────────────────────── + │ File: logs/sqlmesh_2025_04_18_12_34_35.log + ──────┼────────────────────────────────────────────────────────────────────────────────────────────── + 1 │ 2025-04-18 12:34:35,715 - MainThread - sqlmesh.core.config.connection - INFO - Creating new D + │ uckDB adapter for data files: {'db.db'} (connection.py:319) + 2 │ 2025-04-18 12:34:35,951 - MainThread - sqlmesh.core.console - WARNING - Linter warnings for / + │ Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/models/incremental_by_partition.sql: + 3 │ - nomissingaudits: Model `audits` must be configured to test data quality. (console.py:1848) + 4 │ 2025-04-18 12:34:35,953 - MainThread - sqlmesh.core.console - WARNING - Linter warnings for / + │ Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/models/seed_model.sql: + 5 │ - nomissingaudits: Model `audits` must be configured to test data quality. (console.py:1848) + 6 │ 2025-04-18 12:34:35,953 - MainThread - sqlmesh.core.console - WARNING - Linter warnings for / + │ Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/models/incremental_by_unique_key.sql: + 7 │ - nomissingaudits: Model `audits` must be configured to test data quality. (console.py:1848) + 8 │ 2025-04-18 12:34:35,953 - MainThread - sqlmesh.core.console - WARNING - Linter warnings for / + │ Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/models/incremental_model.sql: + 9 │ - nomissingaudits: Model `audits` must be configured to test data quality. 
 (console.py:1848)
+      10   │ 2025-04-18 12:34:35,954 - MainThread - sqlmesh.core.config.connection - INFO - Using existing
+           │ DuckDB adapter due to overlapping data file: db.db (connection.py:309)
+      11   │ 2025-04-18 12:34:37,071 - MainThread - sqlmesh.core.snapshot.evaluator - INFO - Listing data
+           │ objects in schema db.sqlmesh__sqlmesh_example (evaluator.py:338)
+      12   │ 2025-04-18 12:34:37,072 - MainThread - sqlmesh.core.engine_adapter.base - INFO - Executing SQ
+           │ L: SELECT CURRENT_CATALOG() (base.py:2128)
+      13   │ 2025-04-18 12:34:37,072 - MainThread - sqlmesh.core.engine_adapter.base - INFO - Executing SQ
+           │ L: SELECT CURRENT_CATALOG() (base.py:2128)
+    ```
+
+## Run on Production Schedule
+
+SQLMesh schedules your transformation on a per-model basis in proper DAG order. This makes it easy to configure how often each step in your pipeline runs to backfill data.
+
+SQLMesh won't schedule models whose upstream models are late or failed, and they will rerun from point of failure by default!
+
+Example scenario and model DAG:
+
+`stg_transactions`(cron: `@hourly`) -> `fct_transactions`(cron: `@daily`). All times in UTC.
+
+1. `stg_transactions` runs hourly
+2. `fct_transactions` runs at 12am UTC if `stg_transactions` is fresh and updated since its most recent hour interval
+3. If `stg_transactions` failed from 11pm-11:59:59pm, it will prevent `fct_transactions` from running and put it in a `pending` state
+4. If `fct_transactions` is `pending` past its full interval (1 full day), it will be put in a `late` state
+5. Once `stg_transactions` runs successfully either from a retry or a fix from a pull request, `fct_transactions` will rerun from the point of failure. This is true even if `fct_transactions` has been `late` for several days.
+
+Note: `pending` and `late` states are only supported in Tobiko Cloud. In SQLMesh, it will only understand if the model is ready or not ready to execute without mention of these states.
 +
+If you're using open source SQLMesh, you can run this command in your orchestrator (ex: Dagster, GitHub Actions, etc.) every 5 minutes or at your lowest model cron schedule (ex: every 1 hour). Don't worry! It will only run executions that need to be run.
+
+If you're using Tobiko Cloud, this configures automatically without additional configuration.
+
+### Run Models
+
+This command is intended to be run on a schedule. It will skip the physical and virtual layer updates and simply execute the model batches.
+
+=== "SQLMesh"
+
+    ```bash
+    sqlmesh run
+    ```
+
+=== "Tobiko Cloud"
+
+    ```bash
+    tcloud sqlmesh run
+    ```
+
+??? "Example Output"
+
+    This is what it looks like if models are ready to run.
+
+    ```bash
+    > sqlmesh run
+    [1/1] sqlmesh_example.incremental_model [insert 2025-04-17 - 2025-04-17]
+    0.01s
+    [1/1] sqlmesh_example.incremental_unique_model [insert/update rows]
+    0.01s
+    [1/1] sqlmesh_example_v3.incremental_partition_model [insert partitions]
+    0.01s
+    Executing model batches ━━━━━━━━━━━━━━━━╺━━━━━━━━━━━━━━━━━━━━━━━ 40.0% • 2/5 • 0:00:00
+    sqlmesh_example_v3.incremental_partition_model .
+    [WARNING] sqlmesh_example.full_model: 'assert_positive_order_ids' audit error: 2 rows failed. Learn
+    more in logs: /Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/logs/sqlmesh_2025_04_18_12_48_35.log
+    [1/1] sqlmesh_example.full_model [full refresh, audits ❌1]
+    0.01s
+    Executing model batches ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╺━━━━━━━ 80.0% • 4/5 • 0:00:00
+    sqlmesh_example.view_model .
+    [WARNING] sqlmesh_example.view_model: 'assert_positive_order_ids' audit error: 2 rows failed.
Learn + more in logs: /Users/sung/Desktop/git_repos/sqlmesh-cli-revamp/logs/sqlmesh_2025_04_18_12_48_35.log + [1/1] sqlmesh_example.view_model [recreate view, audits ✔2 ❌1] + 0.01s + Executing model batches ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 5/5 • 0:00:00 + + ✔ Model batches executed + + Run finished for environment 'prod' + ``` + + This is what it looks like if no models are ready to run. + + ```bash + > sqlmesh run + No models are ready to run. Please wait until a model `cron` interval has elapsed. + + Next run will be ready at 2025-04-18 05:00PM PDT (2025-04-19 12:00AM UTC). + ``` + +### Run Models with Incomplete Intervals (Warning) + +You can run models that execute backfills each time you invoke a `run`, whether ad hoc or on a schedule. + +!!! warning "Run Models with Incomplete Intervals" + This only applies to incremental models that have `allow_partials` set to `true`. + This is generally not recommended for production environments as you risk shipping incomplete data which will be perceived as broken data. + +=== "SQLMesh" + + ```bash + sqlmesh run --ignore-cron + ``` + +=== "Tobiko Cloud" + + ```bash + tcloud sqlmesh run --ignore-cron + ``` + +??? 
"Example Output" + + Model definition: + ```sql linenums="1" hl_lines="15" title="models/incremental_model.sql" + MODEL ( + name sqlmesh_example.incremental_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column event_date + ), + start '2020-01-01', + cron '@daily', + grain (id, event_date), + audits( UNIQUE_VALUES(columns = ( + id, + )), NOT_NULL(columns = ( + id, + event_date + ))), + allow_partials true + ); + + SELECT + id, + item_id, + event_date, + 16 as new_column + FROM + sqlmesh_example.seed_model + WHERE + event_date BETWEEN @start_date AND @end_date + ``` + + ```bash + > sqlmesh run --ignore-cron + [1/1] sqlmesh_example.incremental_model [insert 2025-04-19 - 2025-04-19, audits ✔2] 0.05s + Executing model batches ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 1/1 • 0:00:00 + + ✔ Model batches executed + + Run finished for environment 'prod' + ``` + +## Forward-Only Development Workflow + +This is an advanced workflow and specifically designed for large incremental models (ex: > 200 million rows) that take a long time to run even during development. It solves for: + +- Transforming data with schema evolution in `struct` and nested `array` data types. +- Retaining history of a calculated column and applying a new calculation to new rows going forward. +- Retain history of a column with complex conditional `CASE WHEN` logic and apply new conditions to new rows going forward. + +When you modify a forward-only model and apply the plan to `prod` after the dev workflow, it will NOT backfill historical data. It will only execute model batches for new intervals **going forward in time** (i.e., only for new rows). + +If you want to see a full walkthrough, [go here](incremental_time_full_walkthrough.md). + +=== "SQLMesh" + + ```bash + sqlmesh plan dev --forward-only + ``` + + ```bash + sqlmesh plan --forward-only + ``` + +=== "Tobiko Cloud" + + ```bash + tcloud sqlmesh plan dev --forward-only + ``` + + ```bash + tcloud sqlmesh plan --forward-only + ``` + +??? 
"Example Output" + + - I applied a change to a new column + - It impacts 2 downstream models + - I enforced a forward-only plan to avoid backfilling historical data for the incremental model (ex: `preview` language in the CLI output) + - I previewed the changes in a clone of the incremental impacted (clones will NOT be reused in production) along with the full and view models (these are NOT clones). + + ```bash + > sqlmesh plan dev + Differences from the `dev` environment: + + Models: + ├── Directly Modified: + │ └── sqlmesh_example__dev.incremental_model + └── Indirectly Modified: + ├── sqlmesh_example__dev.view_model + └── sqlmesh_example__dev.full_model + + --- + + +++ + + @@ -16,7 +16,7 @@ + + id, + item_id, + event_date, + - 9 AS new_column + + 10 AS new_column + FROM sqlmesh_example.seed_model + WHERE + event_date BETWEEN @start_date AND @end_date + + Directly Modified: sqlmesh_example__dev.incremental_model (Forward-only) + └── Indirectly Modified Children: + ├── sqlmesh_example__dev.full_model (Forward-only) + └── sqlmesh_example__dev.view_model (Forward-only) + Models needing backfill: + ├── sqlmesh_example__dev.full_model: [full refresh] (preview) + ├── sqlmesh_example__dev.incremental_model: [2025-04-17 - 2025-04-17] (preview) + └── sqlmesh_example__dev.view_model: [recreate view] (preview) + Apply - Preview Tables [y/n]: y + + Updating physical layer ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 + + ✔ Physical layer updated + + [1/1] sqlmesh_example__dev.incremental_model [insert 2025-04-17 - 2025-04-17] 0.01s + [1/1] sqlmesh_example__dev.full_model [full refresh, audits ✔1] 0.01s + [1/1] sqlmesh_example__dev.view_model [recreate view, audits ✔3] 0.01s + Executing model batches ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 + + ✔ Model batches executed + + Updating virtual layer ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 + + ✔ Virtual layer updated + ``` + + When the plan is applied to `prod`, it 
will only execute model batches for new intervals (new rows). This will NOT re-use `preview` models (backfilled data) in development. + + ```bash + > sqlmesh plan + Differences from the `prod` environment: + + Models: + ├── Directly Modified: + │ └── sqlmesh_example.incremental_model + └── Indirectly Modified: + ├── sqlmesh_example.view_model + └── sqlmesh_example.full_model + + --- + + +++ + + @@ -9,13 +9,14 @@ + + disable_restatement FALSE, + on_destructive_change 'ERROR' + ), + - grains ((id, event_date)) + + grains ((id, event_date)), + + allow_partials TRUE + ) + SELECT + id, + item_id, + event_date, + - 7 AS new_column + + 10 AS new_column + FROM sqlmesh_example.seed_model + WHERE + event_date BETWEEN @start_date AND @end_date + + Directly Modified: sqlmesh_example.incremental_model (Forward-only) + └── Indirectly Modified Children: + ├── sqlmesh_example.full_model (Forward-only) + └── sqlmesh_example.view_model (Forward-only) + Apply - Virtual Update [y/n]: y + + SKIP: No physical layer updates to perform + + SKIP: No model batches to execute + + Updating virtual layer ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 + + ✔ Virtual layer updated + ``` + + +## Miscellaneous + +If you notice you have a lot of old development schemas/data, you can clean them up with the following command. This process runs automatically during the `sqlmesh run` command. This defaults to deleting data older than 7 days. + +=== "SQLMesh" + + ```bash + sqlmesh janitor + ``` + +=== "Tobiko Cloud" + + ```bash + tcloud sqlmesh janitor + ``` \ No newline at end of file diff --git a/docs/faq/faq.md b/docs/faq/faq.md index 74014753c2..b4a0d7e4d4 100644 --- a/docs/faq/faq.md +++ b/docs/faq/faq.md @@ -57,7 +57,7 @@ ??? question "What is semantic understanding of SQL?" Semantic understanding is the result of analyzing SQL code to determine what it does at a granular level. 
SQLMesh uses the free, open-source Python library [SQLGlot](https://github.com/tobymao/sqlglot) to parse the SQL code and build the semantic understanding. - Semantic understanding allows SQLMesh to do things like transpilation (executing one SQL dialect on an engine running another dialect) and protecting incremental loading queries from duplicating data. + Semantic understanding allows SQLMesh to do things like transpilation (executing one SQL dialect on an engine running another dialect) and preventing incremental loading queries from duplicating data. ??? question "Does SQLMesh work like Terraform?" SQLMesh was inspired by Terraform, but its commands are not equivalent. @@ -74,7 +74,7 @@ SQLMesh is a Python library. After ensuring you have [an appropriate Python runtime](../prerequisites.md), install it [with `pip`](../installation.md). ??? question "How do I use SQLMesh?" - SQLMesh has three interfaces: [command line](../reference/cli.md), [Jupyter or Databricks notebook](../reference/notebook.md), and graphical user interface. + SQLMesh has three interfaces: [command line](../reference/cli.md), [Jupyter or Databricks notebook](../reference/notebook.md), and [graphical user interface](../guides/ui.md). The [quickstart guide](../quick_start.md) demonstrates an example project in each of the interfaces. @@ -85,7 +85,9 @@ SQLMesh creates schemas for two reasons: - SQLMesh stores state/metadata information about a project in the `sqlmesh` schema. This schema is created in the project's default gateway, or you can [specify a different location](../reference/configuration.md#state-connection). - - SQLMesh uses [Virtual Data Environments](https://tobikodata.com/virtual-data-environments.html) to prevent duplicative computation whenever possible. + - SQLMesh uses [Virtual Data Environments](https://tobikodata.com/virtual-data-environments.html) to prevent duplicative computation whenever possible, and stores environment-specific objects in separate schemas by default. 
+ + How Virtual Data Environments work: Virtual Data Environments work by maintaining a *virtual layer* of views that users interact with when building models and a *physical layer* of tables that stores the actual data. @@ -102,8 +104,12 @@ ??? question "What's the difference between a `test` and an `audit`?" A SQLMesh [`test`](../concepts/tests.md) is analogous to a "unit test" in software engineering. It tests *code* based on known inputs and outputs. In SQLMesh, the inputs and outputs are specified in a YAML file, and SQLMesh automatically runs them when `sqlmesh plan` is executed. + Writing YAML is annoying and error-prone, so SQLMesh's [`create_test` command](../concepts/tests.md#automatic-test-generation) allows you to automatically generate YAML test files based on queries of existing data tables. + A SQLMesh [`audit`](../concepts/audits.md) validates that transformed *data* meet some criteria. For example, an `audit` might verify that a column contains no `NULL` values or has no duplicated values. SQLMesh automatically runs audits when a `sqlmesh plan` is executed and the plan is applied or when `sqlmesh run` is executed. + When the `sqlmesh plan` command is executed, SQLMesh `test`s run _before_ any model's code is executed. A SQLMesh model's `audit`s run _after_ the model's code is executed to validate the data output by the model. + ??? question "How does a model know when to run?" A SQLMesh model determines when to run based on its [`cron`](#cron-question) parameter and how much time has elapsed since its previous run. @@ -122,7 +128,11 @@ SQLMesh’s `plan` command is the primary tool for understanding the effects of changes you make to your project. If your project files have changed or are different from the state of an environment, you execute `sqlmesh plan [environment name]` to synchronize the environment's state with your project files. 
`sqlmesh plan` will generate a summary of the actions needed to implement the changes, automatically run unit tests, and prompt you to `apply` the plan and implement the changes. - If your project files have not changed, you execute `sqlmesh run` to run your project's models and audits. You can execute `sqlmesh run` yourself or with the native [Airflow integration](../integrations/airflow.md). If running it yourself, a sensible approach is to use Linux’s `cron` tool to execute `sqlmesh run` on a cadence at least as frequent as your briefest SQLMesh model `cron` parameter. For example, if your most frequent model’s `cron` is hour, your `cron` tool should execute `sqlmesh run` at least every hour. + If your project files have not changed, you execute `sqlmesh run` to run your project's models and audits. + + `sqlmesh run` does not use models, macros, or audits from your local project files. Everything it executes is based on the model, macro, and audit versions currently promoted in the target environment. Those versions are stored in the metadata SQLMesh captures about the state of your environment. + + A sensible approach to executing `sqlmesh run` is to use Linux’s `cron` tool to execute `sqlmesh run` on a cadence at least as frequent as your briefest SQLMesh model `cron` parameter. For example, if your most frequent model’s `cron` is hour, your `cron` tool should execute `sqlmesh run` at least every hour. ??? question "What are start date and end date for?" SQLMesh uses the ["intervals" approach](https://tobikodata.com/data_load_patterns_101.html) to determine the date ranges that should be included in an incremental by time model query. It divides time into disjoint intervals and tracks which intervals have ever been processed. @@ -143,6 +153,11 @@ You can retroactively apply the forward-only plan's changes to existing data in the production environment with [`plan`'s `--effective-from` option](../reference/cli.md#plan). +??? 
question "How can I force a model to run now?" + Ensure that the model's `allow_partials` attribute is set to `true` and execute the `run` command with the `--ignore-cron` option: `sqlmesh run --ignore-cron`. + + See the documentation for [allow_partials](../concepts/models/overview.md#allow_partials) to understand the rationale behind this. + ## Databases/Engines @@ -152,14 +167,17 @@ ## Scheduling ??? question "How do I run SQLMesh models on a schedule?" - You can run SQLMesh models using the [built-in scheduler](../guides/scheduling.md#built-in-scheduler) or with the native [Airflow integration](../integrations/airflow.md). + You can run SQLMesh models using the [built-in scheduler](../guides/scheduling.md#built-in-scheduler) or using [Tobiko Cloud](../cloud/features/scheduler/scheduler.md) Both approaches use each model's `cron` parameter to determine when the model should run - see the [question about `cron` above](#cron-question) for more information. The built-in scheduler works by executing the command `sqlmesh run`. A sensible approach to running on your project on a schedule is to use Linux’s `cron` tool to execute `sqlmesh run` on a cadence at least as frequent as your briefest SQLMesh model `cron` parameter. For example, if your most frequent model’s `cron` is hour, the `cron` tool should execute `sqlmesh run` at least every hour. ??? question "How do I use SQLMesh with Airflow?" - SQLMesh has first-class support for Airflow - learn more [here](../integrations/airflow.md). + Tobiko Cloud offers first-class support for Airflow - learn more [here](../cloud/features/scheduler/airflow.md) + +??? question "How do I use SQLMesh with Dagster?" + Tobiko Cloud offers first-class support for Dagster - learn more [here](../cloud/features/scheduler/dagster.md) ## Warnings and Errors @@ -221,14 +239,14 @@ SQLMesh always maintains state about the project structure, contents, and past runs. 
State information enables powerful SQLMesh features like virtual data environments and easy incremental loads. - State information is stored by default - you do not need to take any action to maintain or to use it when executing models. As the dbt caveats page says, state information is powerful but complex. SQLMesh handles that complexity for you so you don't need to learn about or understand the underlying mechanics. + State information is stored by default - you do not need to take any action to maintain or to use it when executing models. As the dbt caveats page says, state information is powerful but complex. SQLMesh handles that complexity for you so you don't need to worry about the underlying mechanics. - SQLMesh stores state information in database tables. By default, it stores this information in the same [database/connection where your project models run](../reference/configuration.md#gateways). You can specify a [different database/connection](../reference/configuration.md#state-connection) if you would prefer to store state information somewhere else. + SQLMesh stores state information in database tables. By default, it stores this information in the same [database/connection where your project models run](../reference/configuration.md#gateways). You can specify a [different database/connection](../reference/configuration.md#state-connection) if you would prefer to store state information somewhere else. We recommend using a separate connection for storing state in production deployments. SQLMesh adds information to the state tables via transactions, and some databases like BigQuery are not optimized to execute transactions. Changing the state connection to another database like PostgreSQL can alleviate performance issues you may encounter due to state transactions. ??? question "How do I get column-level lineage for my dbt project?" - SQLMesh can run dbt projects with its [dbt adapter](../integrations/dbt.md). 
After configuring the dbt project to work with SQLMesh, you can view the column-level lineage in the SQLMesh browser UI: + SQLMesh can run dbt projects with its [dbt adapter](../integrations/dbt.md). After configuring the dbt project to work with SQLMesh, you can view the column-level lineage in the [SQLMesh browser UI](../guides/ui.md): ![SQLMesh UI displaying column-level lineage for Sushi dbt example](./faq/ui-colum-lineage_sushi-dbt.png) @@ -240,7 +258,7 @@ ??? question "How do incremental models determine which dates to ingest?" dbt uses the "most recent record" approach to determine which dates should be included in an incremental load. It works by querying the existing data for the most recent date it contains, then ingesting all records after that date from the source system in a single query. - SQLMesh uses the "intervals" approach instead. It divides time into disjoint intervals based on a model's `cron` parameter then records which intervals have ever been processed. It ingests source records from only unprocessed intervals. The intervals approach enables features like loading in batches. + SQLMesh uses the ["intervals" approach](https://tobikodata.com/data_load_patterns_101.html) instead. It divides time into disjoint intervals based on a model's `cron` parameter then records which intervals have ever been processed. It ingests source records from only unprocessed intervals. The intervals approach enables features like loading in batches. ??? question "How do I run an append only model in SQLMesh?" SQLMesh does not support append-only models as implemented in dbt. You can achieve a similar outcome by defining a time column and using an [incremental by time range](../concepts/models/model_kinds.md#incremental_by_time_range) model or specifying a unique key and using an [incremental by unique key](../concepts/models/model_kinds.md#incremental_by_unique_key) model. @@ -249,7 +267,7 @@ ??? question "How does Tobiko Data make money?" 
- - Model execution observability and monitoring tools (in development) + - Tobiko Cloud: learn more [here](https://tobikodata.com/product.html) - Enterprise Github Actions CI/CD App (in development) - Advanced version of [open source CI/CD bot](../integrations/github.md) - Providing hands-on support for companies' SQLMesh projects diff --git a/docs/guides/configuration.md b/docs/guides/configuration.md index 63a6b32319..d6d4f20c11 100644 --- a/docs/guides/configuration.md +++ b/docs/guides/configuration.md @@ -21,6 +21,9 @@ The sources have the following order of precedence: 2. `config.yaml` or `config.py` in the `~/.sqlmesh` folder. 3. `config.yaml` or `config.py` in a project folder. [LOWEST PRECEDENCE] +!!! note + To relocate the `.sqlmesh` folder, set the `SQLMESH_HOME` environment variable to your preferred directory path. + ### File type You can specify a SQLMesh configuration in either YAML or Python. @@ -98,7 +101,52 @@ All software runs within a system environment that stores information as "enviro SQLMesh can access environment variables during configuration, which enables approaches like storing passwords/secrets outside the configuration file and changing configuration parameters dynamically based on which user is running SQLMesh. -You can use environment variables in two ways: specifying them in the configuration file or creating properly named variables to override configuration file values. +You can specify environment variables in the configuration file or by storing them in a `.env` file. + +### .env files + +SQLMesh automatically loads environment variables from a `.env` file in your project directory. This provides a convenient way to manage environment variables without having to set them in your shell. 
+ +Create a `.env` file in your project root with key-value pairs: + +```bash +# .env file +SNOWFLAKE_PW=my_secret_password +S3_BUCKET=s3://my-data-bucket/warehouse +DATABASE_URL=postgresql://user:pass@localhost/db + +# Override specific SQLMesh configuration values +SQLMESH__DEFAULT_GATEWAY=production +SQLMESH__MODEL_DEFAULTS__DIALECT=snowflake +``` + +See the [overrides](#overrides) section for a detailed explanation of how these are defined. + +The rest of the `.env` file variables can be used in your configuration files with `{{ env_var('VARIABLE_NAME') }}` syntax in YAML or accessed via `os.environ['VARIABLE_NAME']` in Python. + +#### Custom dot env file location and name + +By default, SQLMesh loads `.env` files from each project directory. However, you can specify a custom path using the `--dotenv` CLI flag directly when running a command: + +```bash +sqlmesh --dotenv /path/to/custom/.env plan +``` + +!!! note + The `--dotenv` flag is a global option and must be placed **before** the subcommand (e.g. `plan`, `run`), not after. + +Alternatively, you can export the `SQLMESH_DOTENV_PATH` environment variable once, to persist a custom path across all subsequent commands in your shell session: + +```bash +export SQLMESH_DOTENV_PATH=/path/to/custom/.custom_env +sqlmesh plan +sqlmesh run +``` + +**Important considerations:** +- Add `.env` to your `.gitignore` file to avoid committing sensitive information +- SQLMesh will only load the `.env` file if it exists in the project directory (unless a custom path is specified) +- When using a custom path, that specific file takes precedence over any `.env` file in the project directory. ### Configuration file @@ -151,6 +199,55 @@ The examples specify a Snowflake connection whose password is stored in an envir ) ``` +#### Default target environment + +The SQLMesh `plan` command acts on the `prod` environment by default (i.e., `sqlmesh plan` is equivalent to `sqlmesh plan prod`). 
+ +In some organizations, users never run plans directly against `prod` - they do all SQLMesh work in a development environment unique to them. In a standard SQLMesh configuration, this means they need to include their development environment name every time they issue the `plan` command (e.g., `sqlmesh plan dev_tony`). + +If your organization works like this, it may be convenient to change the `plan` command's default environment from `prod` to each user's development environment. That way people can issue `sqlmesh plan` without typing the environment name every time. + +The SQLMesh configuration `user()` function returns the name of the user currently logged in and running SQLMesh. It retrieves the username from system environment variables like `USER` on MacOS/Linux or `USERNAME` on Windows. + +Call `user()` inside Jinja curly braces with the syntax `{{ user() }}`, which allows you to combine the user name with a prefix or suffix. + +The example configuration below constructs the environment name by appending the username to the end of the string `dev_`. If the user running SQLMesh is `tony`, the default target environment when they run SQLMesh will be `dev_tony`. In other words, `sqlmesh plan` will be equivalent to `sqlmesh plan dev_tony`. + +=== "YAML" + + Default target environment is `dev_` combined with the username running SQLMesh. + + ```yaml + default_target_environment: dev_{{ user() }} + ``` + +=== "Python" + + Default target environment is `dev_` combined with the username running SQLMesh. + + Retrieve the username with the `getpass.getuser()` function, and combine it with `dev_` in a Python f-string. 
DuckDBConnectionConfig
-The default SQLMesh behavior described in the FAQ is appropriate for most deployments, but you can override where SQLMesh creates physical tables and views with the `physical_schema_override`, `environment_suffix_target`, and `environment_catalog_mapping` configuration options. These options are in the [environments](../reference/configuration.md#environments) section of the configuration reference page. +The default SQLMesh behavior described in the FAQ is appropriate for most deployments, but you can override *where* SQLMesh creates physical tables and views with the `physical_schema_mapping`, `environment_suffix_target`, and `environment_catalog_mapping` configuration options. + +You can also override *what* the physical tables are called by using the `physical_table_naming_convention` option. + +These options are in the [environments](../reference/configuration.md#environments) section of the configuration reference page. #### Physical table schemas -By default, SQLMesh creates physical tables for a model with a naming convention of `sqlmesh__[model schema]`. +By default, SQLMesh creates physical schemas for a model with a naming convention of `sqlmesh__[model schema]`. -This can be overridden on a per-schema basis using the `physical_schema_override` option, which removes the `sqlmesh__` prefix and uses the name you provide. +This can be overridden on a per-schema basis using the `physical_schema_mapping` option, which removes the `sqlmesh__` prefix and uses the [regex pattern](https://docs.python.org/3/library/re.html#regular-expression-syntax) you provide to map the schemas defined in your model to their corresponding physical schemas. 
-This example configuration overrides the default physical schemas for the `my_schema` model schema: +This example configuration overrides the default physical schemas for the `my_schema` model schema and any model schemas starting with `dev`: === "YAML" ```yaml linenums="1" - physical_schema_override: - my_schema: my_new_schema + physical_schema_mapping: + '^my_schema$': my_new_schema, + '^dev.*': development ``` === "Python" @@ -221,19 +351,31 @@ This example configuration overrides the default physical schemas for the `my_sc config = Config( model_defaults=ModelDefaultsConfig(dialect=), - physical_schema_override={"my_schema":"my_new_schema"}, + physical_schema_mapping={ + "^my_schema$": "my_new_schema", + '^dev.*': "development" + }, ) ``` -If you had a model name of `my_schema.table`, the physical table would be created as `my_new_schema.table_` instead of the default behavior of `sqlmesh__my_schema.table_`. +This config causes the following mapping behaviour: + +| Model name | Default physical location | Resolved physical location +| --------------------- | ----------------------------------------- | ------------------------------------ | +| `my_schema.my_table` | `sqlmesh__my_schema.table_` | `my_new_schema.table_` | +| `dev_schema.my_table` | `sqlmesh__dev_schema.table_` | `development.table_` | +| `other.my_table` | `sqlmesh__other.table_` | `sqlmesh__other.table_` | -This key only applies to the _physical tables_ that SQLMesh creates - the views are still created in `my_schema` (prod) or `my_schema__`. + +This only applies to the _physical tables_ that SQLMesh creates - the views are still created in `my_schema` (prod) or `my_schema__`. #### Disable environment-specific schemas SQLMesh stores `prod` environment views in the schema in a model's name - for example, the `prod` views for a model `my_schema.users` will be located in `my_schema`. 
-By default, for non-prod environments SQLMesh creates a new schema that appends the environment name to the model name's schema. For example, by default the view for a model `my_schema.users` in a SQLMesh environment named `dev` will be located in the schema `my_schema__dev`. +By default, for non-prod environments SQLMesh creates a new schema that appends the environment name to the model name's schema. For example, by default the view for a model `my_schema.users` in a SQLMesh environment named `dev` will be located in the schema `my_schema__dev` as `my_schema__dev.users`. + +##### Show at the table level instead This behavior can be changed to append a suffix at the end of a _table/view_ name instead. Appending the suffix to a table/view name means that non-prod environment views will be created in the same schema as the `prod` environment. The prod and non-prod views are differentiated by non-prod view names ending with `__`. @@ -249,7 +391,7 @@ Config example: === "Python" - The Python `environment_suffix_target` argument takes an `EnvironmentSuffixTarget` enumeration with a value of `EnvironmentSuffixTarget.TABLE` or `EnvironmentSuffixTarget.SCHEMA` (default). + The Python `environment_suffix_target` argument takes an `EnvironmentSuffixTarget` enumeration with a value of `EnvironmentSuffixTarget.TABLE`, `EnvironmentSuffixTarget.CATALOG` or `EnvironmentSuffixTarget.SCHEMA` (default). ```python linenums="1" from sqlmesh.core.config import Config, ModelDefaultsConfig, EnvironmentSuffixTarget @@ -260,16 +402,194 @@ Config example: ) ``` -The default behavior of appending the suffix to schemas is recommended because it leaves production with a single clean interface for accessing the views. However, if you are deploying SQLMesh in an environment with tight restrictions on schema creation then this can be a useful way of reducing the number of schemas SQLMesh uses. +!!! 
info "Default behavior" + The default behavior of appending the suffix to schemas is recommended because it leaves production with a single clean interface for accessing the views. However, if you are deploying SQLMesh in an environment with tight restrictions on schema creation then this can be a useful way of reducing the number of schemas SQLMesh uses. + +##### Show at the catalog level instead + +If neither the schema (default) nor the table level are sufficient for your use case, you can indicate the environment at the catalog level instead. + +This can be useful if you have downstream BI reporting tools and you would like to point them at a development environment to test something out without renaming all the table / schema references within the report query. + +In order to achieve this, you can configure [environment_suffix_target](../reference/configuration.md#environments) like so: + +=== "YAML" + + ```yaml linenums="1" + environment_suffix_target: catalog + ``` + +=== "Python" + + The Python `environment_suffix_target` argument takes an `EnvironmentSuffixTarget` enumeration with a value of `EnvironmentSuffixTarget.TABLE`, `EnvironmentSuffixTarget.CATALOG` or `EnvironmentSuffixTarget.SCHEMA` (default). + + ```python linenums="1" + from sqlmesh.core.config import Config, ModelDefaultsConfig, EnvironmentSuffixTarget + + config = Config( + model_defaults=ModelDefaultsConfig(dialect=), + environment_suffix_target=EnvironmentSuffixTarget.CATALOG, + ) + ``` + +Given the example of a model called `my_schema.users` with a default catalog of `warehouse` this will cause the following behavior: + +- For the `prod` environment, the default catalog as configured in the gateway will be used. So the view will be created at `warehouse.my_schema.users` +- For any other environment, eg `dev`, the environment name will be appended to the default catalog. 
So the view will be created at `warehouse__dev.my_schema.users` +- If a model is fully qualified with a catalog already, eg `finance_mart.my_schema.users`, then the environment catalog will be based off the model catalog and not the default catalog. In this example, the view will be created at `finance_mart__dev.my_schema.users` + + +!!! warning "Caveats" + - Using `environment_suffix_target: catalog` only works on engines that support querying across different catalogs. If your engine does not support cross-catalog queries then you will need to use `environment_suffix_target: schema` or `environment_suffix_target: table` instead. + - Automatic catalog creation is not supported on all engines even if they support cross-catalog queries. For engines where it is not supported, the catalogs must be managed externally from SQLMesh and exist prior to invoking SQLMesh. + +#### Physical table naming convention + +Out of the box, SQLMesh has the following defaults set: + + - `environment_suffix_target: schema` + - `physical_table_naming_convention: schema_and_table` + - no `physical_schema_mapping` overrides, so a `sqlmesh__` physical schema will be created for each model schema + +This means that given a catalog of `warehouse` and a model named `finance_mart.transaction_events_over_threshold`, SQLMesh will create physical tables using the following convention: + +``` +# .sqlmesh__.__
__ + +warehouse.sqlmesh__finance_mart.finance_mart__transaction_events_over_threshold__ +``` + +This deliberately contains some redundancy with the *model* schema as it's repeated at the physical layer in both the physical schema name as well as the physical table name. + +This default exists to make the physical table names portable between different configurations. If you were to define a `physical_schema_mapping` that maps all models to the same physical schema, since the model schema is included in the table name as well, there are no naming conflicts. + +##### Table only + +Some engines have object name length limitations which cause them to [silently truncate](https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS) table and view names that exceed this limit. This behaviour breaks SQLMesh, so we raise a runtime error if we detect the engine would silently truncate the name of the table we are trying to create. + +Having redundancy in the physical table names does reduce the number of characters that can be utilised in model names. To increase the number of characters available to model names, you can use `physical_table_naming_convention` like so: + +=== "YAML" + + ```yaml linenums="1" + physical_table_naming_convention: table_only + ``` + +=== "Python" + + ```python linenums="1" + from sqlmesh.core.config import Config, ModelDefaultsConfig, TableNamingConvention + + config = Config( + model_defaults=ModelDefaultsConfig(dialect=), + physical_table_naming_convention=TableNamingConvention.TABLE_ONLY, + ) + ``` + +This will cause SQLMesh to omit the model schema from the table name and generate physical names that look like (using the above example): +``` +# .sqlmesh__.
__ + +warehouse.sqlmesh__finance_mart.transaction_events_over_threshold__ +``` + +Notice that the model schema name is no longer part of the physical table name. This allows for slightly longer model names on engines with low identifier length limits, which may be useful for your project. + +In this configuration, it is your responsibility to ensure that any schema overrides in `physical_schema_mapping` result in each model schema getting mapped to a unique physical schema. + +For example, the following configuration will cause **data corruption**: + +```yaml +physical_table_naming_convention: table_only +physical_schema_mapping: + '.*': sqlmesh +``` + +This is because every model schema is mapped to the same physical schema but the model schema name is omitted from the physical table name. + +##### MD5 hash + +If you *still* need more characters, you can set `physical_table_naming_convention: hash_md5` like so: + +=== "YAML" + + ```yaml linenums="1" + physical_table_naming_convention: hash_md5 + ``` + +=== "Python" + + ```python linenums="1" + from sqlmesh.core.config import Config, ModelDefaultsConfig, TableNamingConvention + + config = Config( + model_defaults=ModelDefaultsConfig(dialect=), + physical_table_naming_convention=TableNamingConvention.HASH_MD5, + ) + ``` + +This will cause SQLMesh generate physical names that are always 45-50 characters in length and look something like: + +``` +# sqlmesh_md5__ + +sqlmesh_md5__d3b07384d113edec49eaa6238ad5ff00 + +# or, for a dev preview +sqlmesh_md5__d3b07384d113edec49eaa6238ad5ff00__dev +``` + +This has a downside that now it's much more difficult to determine which table corresponds to which model by just looking at the database with a SQL client. However, the table names have a predictable length so there are no longer any surprises with identfiers exceeding the max length at the physical layer. 
+ +#### Virtual Data Environment Modes + +By default, Virtual Data Environments (VDE) are applied across both development and production environments. This allows SQLMesh to reuse physical tables when appropriate, even when promoting from development to production. + +However, users may prefer their production environment to be non-virtual. The non-exhaustive list of reasons may include: + +- Integration with third-party tools and platforms, such as data catalogs, may not work well with the virtual view layer that SQLMesh imposes by default +- A desire to rely on time travel features provided by cloud data warehouses such as BigQuery, Snowflake, and Databricks + +To mitigate this, SQLMesh offers an alternative 'dev-only' mode for using VDE. It can be enabled in the project configuration like so: + +=== "YAML" + + ```yaml linenums="1" + virtual_environment_mode: dev_only + ``` + +=== "Python" + + ```python linenums="1" + from sqlmesh.core.config import Config + + config = Config( + virtual_environment_mode="dev_only", + ) + ``` + +'dev-only' mode means that VDE is applied only in development environments. While in production, model tables and views are updated directly and bypass the virtual layer. This also means that physical tables in production will be created using the original, **unversioned** model names. Users will still benefit from VDE and data reuse across development environments. + +Please note the following tradeoffs when enabling this mode: + +- All data inserted in development environments is used only for [preview](../concepts/plans.md#data-preview-for-forward-only-changes) and will **not** be reused in production +- Reverting a model to a previous version will be applied going forward and may require an explicit data restatement + +!!! warning + Switching the mode for an existing project will result in a **complete rebuild** of all models in the project. 
Refer to the [Table Migration Guide](./table_migration.md) to migrate existing tables without rebuilding them from scratch. + #### Environment view catalogs By default, SQLMesh creates an environment view in the same [catalog](../concepts/glossary.md#catalog) as the physical table the view points to. The physical table's catalog is determined by either the catalog specified in the model name or the default catalog defined in the connection. -Some companies fully segregate `prod` and non-prod environment objects by catalog. For example, they might have a "prod" catalog that contains all `prod` environment physical tables and views and a separate "dev" catalog that contains all `dev` environment physical tables and views. +It can be desirable to create `prod` and non-prod virtual layer objects in separate catalogs instead. For example, there might be a "prod" catalog that contains all `prod` environment views and a separate "dev" catalog that contains all `dev` environment views. Separate prod and non-prod catalogs can also be useful if you have a CI/CD pipeline that creates environments, like the [SQLMesh Github Actions CI/CD Bot](../integrations/github.md). You might want to store the CI/CD environment objects in a dedicated catalog since there can be many of them. +!!! info "Virtual layer only" + Note that the following setting only affects the [virtual layer](../concepts/glossary.md#virtual-layer). If you need full segregation by catalog between environments in the [physical layer](../concepts/glossary.md#physical-layer) as well, see the [Isolated Systems Guide](../guides/isolated_systems.md). + To configure separate catalogs, provide a mapping from [regex patterns](https://en.wikipedia.org/wiki/Regular_expression) to catalog names. SQLMesh will compare the name of an environment to the regex patterns; when it finds a match it will store the environment's objects in the corresponding catalog. 
SQLMesh evaluates the regex patterns in the order defined in the configuration; it uses the catalog for the first matching pattern. If no match is found, the catalog defined in the model or the default catalog defined on the connection will be used. @@ -306,6 +626,9 @@ With the example configuration above, SQLMesh would evaluate environment names a * If the environment name starts with `dev`, the catalog will be `dev`. * If the environment name starts with `analytics_repo`, the catalog will be `cicd`. +!!! warning + This feature is mutually exclusive with `environment_suffix_target: catalog` in order to prevent ambiguous mappings from being defined. Attempting to specify both `environment_catalog_mapping` and `environment_suffix_target: catalog` will raise an error on project load + *Note:* This feature is only available for engines that support querying across catalogs. At the time of writing, the following engines are **NOT** supported: * [MySQL](../integrations/engines/mysql.md) @@ -323,7 +646,7 @@ With the example configuration above, SQLMesh would evaluate environment names a SQLMesh compares the current state of project files to an environment when `sqlmesh plan` is run. It detects changes to models, which can be classified as breaking or non-breaking. -SQLMesh can attempt to automatically [categorize](../concepts/plans.md#change-categories) the changes it detects. The `plan.auto_categorize_changes` option determines whether SQLMesh should attempt automatic change categorization. This option is in the [environments](../reference/configuration.md#environments) section of the configuration reference page. +SQLMesh can attempt to automatically [categorize](../concepts/plans.md#change-categories) the changes it detects. The `plan.auto_categorize_changes` option determines whether SQLMesh should attempt automatic change categorization. This option is in the [plan](../reference/configuration.md#plan) section of the configuration reference page. 
Supported values: @@ -370,11 +693,124 @@ Example showing default values: ) ``` + +### Always comparing against production + +By default, SQLMesh compares the current state of project files to the target `` environment when `sqlmesh plan ` is run. However, a common expectation is that local changes should always be compared to the production environment. + +The `always_recreate_environment` boolean plan option can alter this behavior. When enabled, SQLMesh will always attempt to compare against the production environment by recreating the target environment; If `prod` does not exist, SQLMesh will fall back to comparing against the target environment. + +**NOTE:**: Upon succesfull plan application, changes are still promoted to the target `` environment. + +=== "YAML" + + ```yaml linenums="1" + plan: + always_recreate_environment: True + ``` + +=== "Python" + + ```python linenums="1" + from sqlmesh.core.config import ( + Config, + ModelDefaultsConfig, + PlanConfig, + ) + + config = Config( + model_defaults=ModelDefaultsConfig(dialect=), + plan=PlanConfig( + always_recreate_environment=True, + ), + ) + ``` + +#### Change Categorization Example + +Consider this scenario with `always_recreate_environment` enabled: + +1. Initial state in `prod`: +```sql +MODEL (name sqlmesh_example.test_model, kind FULL); +SELECT 1 AS col +``` + +1. First (breaking) change in `dev`: +```sql +MODEL (name sqlmesh_example__dev.test_model, kind FULL); +SELECT 2 AS col +``` + +??? "Output plan example #1" + + ```bash + New environment `dev` will be created from `prod` + + Differences from the `prod` environment: + + Models: + └── Directly Modified: + └── sqlmesh_example__dev.test_model + + --- + +++ + + + kind FULL + ) + SELECT + - 1 AS col + + 2 AS col + ``` + +3. Second (metadata) change in `dev`: +```sql +MODEL (name sqlmesh_example__dev.test_model, kind FULL, owner 'John Doe'); +SELECT 5 AS col +``` + +??? 
"Output plan example #2" + + ```bash + New environment `dev` will be created from `prod` + + Differences from the `prod` environment: + + Models: + └── Directly Modified: + └── sqlmesh_example__dev.test_model + + --- + + +++ + + @@ -1,8 +1,9 @@ + + MODEL ( + name sqlmesh_example.test_model, + + owner "John Doe", + kind FULL + ) + SELECT + - 1 AS col + + 2 AS col + + Directly Modified: sqlmesh_example__dev.test_model (Breaking) + Models needing backfill: + └── sqlmesh_example__dev.test_model: [full refresh] + ``` + +Even though the second change should have been a metadata change (thus not requiring a backfill), it will still be classified as a breaking change because the comparison is against production instead of the previous development state. This is intentional and may cause additional backfills as more changes are accumulated. + + ### Gateways The `gateways` configuration defines how SQLMesh should connect to the data warehouse, state backend, and scheduler. These options are in the [gateway](../reference/configuration.md#gateway) section of the configuration reference page. -Each gateway key represents a unique gateway name and configures its connections. For example, this configures the `my_gateway` gateway: +Each gateway key represents a unique gateway name and configures its connections. **Gateway names are case-insensitive** - SQLMesh automatically normalizes gateway names to lowercase during configuration validation. This means you can use any case in your configuration files (e.g., `MyGateway`, `mygateway`, `MYGATEWAY`) and they will all work correctly. + +For example, this configures the `my_gateway` gateway: === "YAML" @@ -472,9 +908,11 @@ Example snowflake connection configuration: These pages describe the connection configuration options for each execution engine. 
+* [Athena](../integrations/engines/athena.md) * [BigQuery](../integrations/engines/bigquery.md) * [Databricks](../integrations/engines/databricks.md) * [DuckDB](../integrations/engines/duckdb.md) +* [Fabric](../integrations/engines/fabric.md) * [MotherDuck](../integrations/engines/motherduck.md) * [MySQL](../integrations/engines/mysql.md) * [MSSQL](../integrations/engines/mssql.md) @@ -488,35 +926,49 @@ These pages describe the connection configuration options for each execution eng #### State connection Configuration for the state backend connection if different from the data warehouse connection. -**Using the same connection for data warehouse and state is only recommended for non-production deployments of SQLMesh.** -Unlike data transformations, storing state information requires database transactions. Data warehouses aren’t optimized for executing transactions, so storing state information in them can slow down your project. -Even worse data corruption can occur with simultaneous writes to the same table. -Therefore, using your data warehouse is fine for testing but once you start running SQLMesh in production, you should use a dedicated state connection. -Recommended state backend engines for production deployments: +The data warehouse connection is used to store SQLMesh state if the `state_connection` key is not specified. + +Unlike data transformations, storing state information requires database transactions. Data warehouses aren’t optimized for executing transactions, and storing state information in them can slow down your project or produce corrupted data due to simultaneous writes to the same table. Therefore, production SQLMesh deployments should use a dedicated state connection. + +!!! note + Using the same connection for data warehouse and state is not recommended for production deployments of SQLMesh. + +The easiest and most reliable way to manage your state connection is for [Tobiko Cloud](https://tobikodata.com/product.html) to do it for you. 
If you'd rather handle it yourself, we list recommended and unsupported state engines below. + +Recommended state engines for production deployments: * [Postgres](../integrations/engines/postgres.md) * [GCP Postgres](../integrations/engines/gcp-postgres.md) -Other supported state backend engines (less tested than recommended): +Other state engines with fast and reliable database transactions (less tested than the recommended engines): +* [DuckDB](../integrations/engines/duckdb.md) + * With the caveat that it's a [single user](https://duckdb.org/docs/connect/concurrency.html#writing-to-duckdb-from-multiple-processes) database so will not scale to production usage * [MySQL](../integrations/engines/mysql.md) +* [MSSQL](../integrations/engines/mssql.md) -Ineligible state backends even for development: +Unsupported state engines, even for development: +* [ClickHouse](../integrations/engines/clickhouse.md) * [Spark](../integrations/engines/spark.md) * [Trino](../integrations/engines/trino.md) -The data warehouse connection is used if the `state_connection` key is not specified, unless the configuration uses an Airflow or Google Cloud Composer scheduler. If using one of those schedulers and no state connection is specified, the state connection defaults to the scheduler's database. 
- -Example postgres state connection configuration: +This example gateway configuration uses Snowflake for the data warehouse connection and Postgres for the state backend connection: === "YAML" ```yaml linenums="1" gateways: my_gateway: + connection: + # snowflake credentials here + type: snowflake + user: + password: + account: state_connection: + # postgres credentials here type: postgres host: port: @@ -534,13 +986,21 @@ Example postgres state connection configuration: Config, ModelDefaultsConfig, GatewayConfig, - PostgresConnectionConfig + PostgresConnectionConfig, + SnowflakeConnectionConfig ) config = Config( model_defaults=ModelDefaultsConfig(dialect=), gateways={ "my_gateway": GatewayConfig( + # snowflake credentials here + connection=SnowflakeConnectionConfig( + user=, + password=, + account=, + ), + # postgres credentials here state_connection=PostgresConnectionConfig( host=, port=, @@ -641,7 +1101,7 @@ Configuration for a connection used to run unit tests. An in-memory DuckDB datab ### Scheduler -Identifies which scheduler backend to use. The scheduler backend is used both for storing metadata and for executing [plans](../concepts/plans.md). By default, the scheduler type is set to `builtin`, which uses the existing SQL engine to store metadata. Use the `airflow` type integrate with Airflow. +Identifies which scheduler backend to use. The scheduler backend is used both for storing metadata and for executing [plans](../concepts/plans.md). By default, the scheduler type is set to `builtin`, which uses the existing SQL engine to store metadata. These options are in the [scheduler](../reference/configuration.md#scheduler) section of the configuration reference page. @@ -682,89 +1142,6 @@ Example configuration: No additional configuration options are supported by this scheduler type. 
-#### Airflow - -Example configuration: - -=== "YAML" - - ```yaml linenums="1" - gateways: - my_gateway: - scheduler: - type: airflow - airflow_url: - username: - password: - ``` - -=== "Python" - - An Airflow scheduler is specified with an `AirflowSchedulerConfig` object. - - ```python linenums="1" - from sqlmesh.core.config import ( - Config, - ModelDefaultsConfig, - GatewayConfig, - AirflowSchedulerConfig, - ) - - config = Config( - model_defaults=ModelDefaultsConfig(dialect=), - gateways={ - "my_gateway": GatewayConfig( - scheduler=AirflowSchedulerConfig( - airflow_url=, - username=, - password=, - ), - ), - } - ) - ``` - -See [Airflow Integration Guide](../integrations/airflow.md) for information about how to integrate Airflow with SQLMesh. See the [configuration reference page](../reference/configuration.md#airflow) for a list of all parameters. - -#### Cloud Composer - -The Google Cloud Composer scheduler type shares the same configuration options as the `airflow` type, except for `username` and `password`. Cloud Composer relies on `gcloud` authentication, so the `username` and `password` options are not required. - -Example configuration: - -=== "YAML" - - ```yaml linenums="1" - gateways: - my_gateway: - scheduler: - type: cloud_composer - airflow_url: - ``` - -=== "Python" - - An Google Cloud Composer scheduler is specified with an `CloudComposerSchedulerConfig` object. 
- - ```python linenums="1" - from sqlmesh.core.config import ( - Config, - ModelDefaultsConfig, - GatewayConfig, - CloudComposerSchedulerConfig, - ) - - config = Config( - model_defaults=ModelDefaultsConfig(dialect=), - gateways={ - "my_gateway": GatewayConfig( - scheduler=CloudComposerSchedulerConfig( - airflow_url=, - ), - ), - } - ) - ``` ### Gateway/connection defaults @@ -914,6 +1291,39 @@ This may be useful in cases where the name casing needs to be preserved, since t See [here](https://sqlglot.com/sqlglot/dialects/dialect.html#NormalizationStrategy) to learn more about the supported normalization strategies. +##### Gateway-specific model defaults + +You can also define gateway specific `model_defaults` in the `gateways` section, which override the global defaults for that gateway. + +```yaml linenums="1" hl_lines="6 14" +gateways: + redshift: + connection: + type: redshift + model_defaults: + dialect: "snowflake,normalization_strategy=case_insensitive" + snowflake: + connection: + type: snowflake + +default_gateway: snowflake + +model_defaults: + dialect: snowflake + start: 2025-02-05 +``` + +This allows you to tailor the behavior of models for each gateway without affecting the global `model_defaults`. + +For example, in some SQL engines identifiers like table and column names are case-sensitive, but they are case-insensitive in other engines. By default, a project that uses both types of engines would need to ensure the models for each engine aligned with the engine's normalization behavior, which makes project maintenance and debugging more challenging. + +Gateway-specific `model_defaults` allow you to change how SQLMesh performs identifier normalization *by engine* to align the different engines' behavior. + +In the example above, the project's default dialect is `snowflake` (line 14). The `redshift` gateway configuration overrides that global default dialect with `"snowflake,normalization_strategy=case_insensitive"` (line 6). 
+ +That value tells SQLMesh that the `redshift` gateway's models will be written in the Snowflake SQL dialect (so need to be transpiled from Snowflake to Redshift), but that the resulting Redshift SQL should treat identifiers as case-insensitive to match Snowflake's behavior. + + #### Model Kinds Model kinds are required in each model file's `MODEL` DDL statement. They may optionally be used to specify a default kind in the model defaults configuration key. @@ -933,7 +1343,7 @@ The `VIEW`, `FULL`, and `EMBEDDED` model kinds are specified by name only, while ); ``` - `INCREMENTAL_BY_TIME_RANGE` requires an array specifying the model's `time_column`: + `INCREMENTAL_BY_TIME_RANGE` requires an array specifying the model's `time_column` (which should be in the UTC time zone): ```sql linenums="1" MODEL( @@ -993,6 +1403,83 @@ Example enabling name inference: ) ``` +### Before_all and after_all Statements + +The `before_all` and `after_all` statements are executed at the start and end, respectively, of the `sqlmesh plan` and `sqlmesh run` commands. + +These statements can be defined in the configuration file under the `before_all` and `after_all` keys, either as a list of SQL statements or by using SQLMesh macros: + +=== "YAML" + + ```yaml linenums="1" + before_all: + - CREATE TABLE IF NOT EXISTS analytics (table VARCHAR, eval_time VARCHAR) + after_all: + - "@grant_select_privileges()" + - "@IF(@this_env = 'prod', @grant_schema_usage())" + ``` + +=== "Python" + + ```python linenums="1" + from sqlmesh.core.config import Config + + config = Config( + before_all = [ + "CREATE TABLE IF NOT EXISTS analytics (table VARCHAR, eval_time VARCHAR)" + ], + after_all = [ + "@grant_select_privileges()", + "@IF(@this_env = 'prod', @grant_schema_usage())" + ], + ) + ``` + +#### Examples + +These statements allow for actions to be executed before all individual model statements or after all have run, respectively. They can also simplify tasks such as granting privileges. 
+
+##### Example: Granting Select Privileges
+
+For example, rather than using an `on_virtual_update` statement in each model to grant privileges on the views of the virtual layer, a single macro can be defined and used at the end of the plan:
+
+```python linenums="1"
+from sqlmesh.core.macros import macro
+
+@macro()
+def grant_select_privileges(evaluator):
+    if evaluator.views:
+        return [
+            f"GRANT SELECT ON VIEW {view_name} /* sqlglot.meta replace=false */ TO ROLE admin_role;"
+            for view_name in evaluator.views
+        ]
+```
+
+By including the comment `/* sqlglot.meta replace=false */`, you further ensure that the evaluator does not replace the view name with the physical table name during rendering.
+
+##### Example: Granting Schema Privileges
+
+Similarly, you can define a macro to grant schema usage privileges and, as demonstrated in the configuration above, use the `this_env` macro to conditionally execute it only in the production environment.
+
+```python linenums="1"
+from sqlmesh import macro
+
+@macro()
+def grant_schema_usage(evaluator):
+    if evaluator.this_env == "prod" and evaluator.schemas:
+        return [
+            f"GRANT USAGE ON SCHEMA {schema} TO admin_role;"
+            for schema in evaluator.schemas
+        ]
+```
+
+As demonstrated in these examples, the `schemas` and `views` are available within the macro evaluator for macros invoked within the `before_all` and `after_all` statements. Additionally, the macro `this_env` provides access to the current environment name, which can be helpful for more advanced use cases that require fine-grained control over their behavior.
+
+### Linting
+
+SQLMesh provides a linter that checks for potential issues in your models' code. Enable it and specify which linting rules to apply in the configuration file's `linter` key.
+
+Learn more about linting configuration in the [linting guide](./linter.md).
### Debug mode
 
@@ -1021,3 +1508,27 @@ Example enabling debug mode for the CLI command `sqlmesh plan`:
 C:\> set SQLMESH_DEBUG=1
 C:\> sqlmesh plan
 ```
+
+
+### Python library dependencies
+SQLMesh enables you to write Python models and macros which depend on third-party libraries. To ensure each run / evaluation uses the same version, you can specify versions in a `sqlmesh-requirements.lock` file in the root of your project.
+
+The `sqlmesh-requirements.lock` file must be of the format `dep==version`. Only `==` is supported.
+
+For example:
+
+```
+numpy==2.1.2
+pandas==2.2.3
+```
+
+This feature is only available in [Tobiko Cloud](https://tobikodata.com/product.html).
+
+#### Excluding dependencies
+
+You can exclude dependencies by prefixing the dependency with a `^`. For example:
+
+```
+^numpy
+pandas==2.2.3
+```
diff --git a/docs/guides/connections.md b/docs/guides/connections.md
index 166c64eb56..e0dca0f7a4 100644
--- a/docs/guides/connections.md
+++ b/docs/guides/connections.md
@@ -2,8 +2,6 @@
 
 ## Overview
 
-**Note:** The following guide only applies when using the built-in scheduler. Connections are configured differently when using an external scheduler such as Airflow. See the [Scheduling guide](scheduling.md) for more details.
-
 In order to deploy models and to apply changes to them, you must configure a connection to your Data Warehouse and, optionally, connection to the database where the SQLMesh state is stored. This can be done in either the `config.yaml` file in your project folder, or the one in `~/.sqlmesh`. Each connection is configured as part of a gateway which has a unique name associated with it. The gateway name can be used to select a specific combination of connection settings when using the CLI. For example:
 
@@ -23,7 +21,7 @@ sqlmesh --gateway local_db plan
 
 ## State connection
 
-By default, the data warehouse connection is also used to store the SQLMesh state, unless the configuration uses an Airflow or Google Cloud Composer scheduler. 
If using one of those schedulers, the state connection defaults to the scheduler's database. +By default, the data warehouse connection is also used to store the SQLMesh state. The state connection can be changed by providing different connection settings in the `state_connection` key of the gateway configuration: diff --git a/docs/guides/custom_materializations.md b/docs/guides/custom_materializations.md index 03c4da3551..905a3d017e 100644 --- a/docs/guides/custom_materializations.md +++ b/docs/guides/custom_materializations.md @@ -24,13 +24,13 @@ A custom materialization must: - Be written in Python code - Be a Python class that inherits the SQLMesh `CustomMaterialization` base class -- Use or override the `insert` method from the SQLMesh [`MaterializableStrategy`](https://github.com/TobikoData/sqlmesh/blob/034476e7f64d261860fd630c3ac56d8a9c9f3e3a/sqlmesh/core/snapshot/evaluator.py#L1146) class/subclasses +- Use or override the `insert` method from the SQLMesh [`MaterializableStrategy`](https://github.com/SQLMesh/sqlmesh/blob/034476e7f64d261860fd630c3ac56d8a9c9f3e3a/sqlmesh/core/snapshot/evaluator.py#L1146) class/subclasses - Be loaded or imported by SQLMesh at runtime A custom materialization may: -- Use or override methods from the SQLMesh [`MaterializableStrategy`](https://github.com/TobikoData/sqlmesh/blob/034476e7f64d261860fd630c3ac56d8a9c9f3e3a/sqlmesh/core/snapshot/evaluator.py#L1146) class/subclasses -- Use or override methods from the SQLMesh [`EngineAdapter`](https://github.com/TobikoData/sqlmesh/blob/034476e7f64d261860fd630c3ac56d8a9c9f3e3a/sqlmesh/core/engine_adapter/base.py#L67) class/subclasses +- Use or override methods from the SQLMesh [`MaterializableStrategy`](https://github.com/SQLMesh/sqlmesh/blob/034476e7f64d261860fd630c3ac56d8a9c9f3e3a/sqlmesh/core/snapshot/evaluator.py#L1146) class/subclasses +- Use or override methods from the SQLMesh 
[`EngineAdapter`](https://github.com/SQLMesh/sqlmesh/blob/034476e7f64d261860fd630c3ac56d8a9c9f3e3a/sqlmesh/core/engine_adapter/base.py#L67) class/subclasses - Execute arbitrary SQL code and fetch results with the engine adapter `execute` and related methods A custom materialization may perform arbitrary Python processing with Pandas or other libraries, but in most cases that logic should reside in a [Python model](../concepts/models/python_models.md) instead of the materialization. @@ -64,6 +64,7 @@ class CustomFullMaterialization(CustomMaterialization): query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: self.adapter.replace_query(table_name, query_or_df) @@ -78,6 +79,7 @@ Let's unpack this materialization: * `query_or_df` - a query (of SQLGlot expression type) or DataFrame (Pandas, PySpark, or Snowpark) instance to be inserted * `model` - the model definition object used to access model parameters and user-specified materialization arguments * `is_first_insert` - whether this is the first insert for the current version of the model (used with batched or multi-step inserts) + * `render_kwargs` - a dictionary of arguments used to render the model query * `kwargs` - additional and future arguments * The `self.adapter` instance is used to interact with the target engine. It comes with a set of useful high-level APIs like `replace_query`, `columns`, and `table_exists`, but also supports executing arbitrary SQL expressions with its `execute` method. @@ -150,13 +152,108 @@ class CustomFullMaterialization(CustomMaterialization): query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: config_value = model.custom_materialization_properties["config_key"] # Proceed with implementing the insertion logic. 
- # Example existing materialization for look and feel: https://github.com/TobikoData/sqlmesh/blob/main/sqlmesh/core/snapshot/evaluator.py
+ # Example existing materialization for look and feel: https://github.com/SQLMesh/sqlmesh/blob/main/sqlmesh/core/snapshot/evaluator.py
+```
+
+## Extending `CustomKind`
+
+!!! warning
+    This is even lower-level usage that contains a bunch of extra complexity and relies on knowledge of the SQLMesh internals.
+    If you don't need this level of complexity, stick with the method described above.
+
+In many cases, the above usage of a custom materialization will suffice.
+
+However, you may still want tighter integration with SQLMesh's internals:
+
+- You may want to validate custom properties are correct before any database connections are made
+- You may want to leverage existing functionality of SQLMesh that relies on specific properties being present
+
+In this case, you can provide a subclass of `CustomKind` for SQLMesh to use instead of `CustomKind` itself.
+During project load, SQLMesh will instantiate your *subclass* instead of `CustomKind`.
+
+This allows you to run custom validators at load time rather than having to perform extra validation when `insert()` is invoked on your `CustomMaterialization`.
+
+You can also define standard Python `@property` methods to "hoist" properties declared inside `materialization_properties` to the top level on your `Kind` object. This can make using them from within your custom materialization easier. 
+
+To extend `CustomKind`, first you define a subclass like so:
+
+```python linenums="1" hl_lines="7"
+from typing_extensions import Self
+from pydantic import model_validator
+from sqlmesh import CustomKind
+from sqlmesh.utils.pydantic import list_of_fields_validator
+from sqlmesh.utils.errors import ConfigError
+
+class MyCustomKind(CustomKind):
+
+    _primary_key: t.List[exp.Expression]
+
+    @model_validator(mode="after")
+    def _validate_model(self) -> Self:
+        self._primary_key = list_of_fields_validator(
+            self.materialization_properties.get("primary_key"),
+            { "dialect": self.dialect }
+        )
+        if not self.primary_key:
+            raise ConfigError("primary_key must be specified")
+        return self
+
+    @property
+    def primary_key(self) -> t.List[exp.Expression]:
+        return self._primary_key
+
+```
+
+To use it within a model, we can do something like:
+
+```sql linenums="1" hl_lines="4"
+MODEL (
+  name my_db.my_model,
+  kind CUSTOM (
+    materialization 'my_custom_full',
+    materialization_properties (
+      primary_key = (col1, col2)
+    )
+  )
+);
+```
+
+To indicate to SQLMesh that it should use the `MyCustomKind` subclass instead of `CustomKind`, specify it as a generic type parameter on your custom materialization class like so:
+
+```python linenums="1" hl_lines="1 16"
+class CustomFullMaterialization(CustomMaterialization[MyCustomKind]):
+    NAME = "my_custom_full"
+
+    def insert(
+        self,
+        table_name: str,
+        query_or_df: QueryOrDF,
+        model: Model,
+        is_first_insert: bool,
+        render_kwargs: t.Dict[str, t.Any],
+        **kwargs: t.Any,
+    ) -> None:
+        assert isinstance(model.kind, MyCustomKind)
+
+        self.adapter.merge(
+            ...,
+            unique_key=model.kind.primary_key
+        )
+```
+
+When SQLMesh loads your custom materialization, it will inspect the Python type signature for generic parameters that are subclasses of `CustomKind`. If it finds one, it will instantiate your subclass when building `model.kind` instead of using the default `CustomKind` class. 
+ +In this example, this means that: + +- Validation for `primary_key` happens at load time instead of evaluation time. So if there is an issue, you can abort early rather than halfway through applying a plan. +- When your custom materialization is called to load data into tables, `model.kind` will resolve to your custom kind object so you can access the extra properties you defined without first needing to validate them / coerce them to a usable type. + + ## Sharing custom materializations ### Copying files @@ -195,4 +292,4 @@ setup( ) ``` -Refer to the SQLMesh Github [custom_materializations](https://github.com/TobikoData/sqlmesh/tree/main/examples/custom_materializations) example for more details on Python packaging. +Refer to the SQLMesh Github [custom_materializations](https://github.com/SQLMesh/sqlmesh/tree/main/examples/custom_materializations) example for more details on Python packaging. diff --git a/docs/guides/customizing_sqlmesh.md b/docs/guides/customizing_sqlmesh.md new file mode 100644 index 0000000000..3b95b6ba82 --- /dev/null +++ b/docs/guides/customizing_sqlmesh.md @@ -0,0 +1,74 @@ +# Customizing SQLMesh + +SQLMesh supports the workflows used by the vast majority of data engineering teams. However, your company may have bespoke processes or tools that require special integration with SQLMesh. + +Fortunately, SQLMesh is an open-source Python library, so you can view its underlying code and customize it for your needs. + +Customization generally involves subclassing SQLMesh classes to extend or modify their functionality. + +!!! danger "Caution" + + Customize SQLMesh with extreme caution. Errors may cause SQLMesh to produce unexpected results. + +## Custom loader + +Loading is the process of reading project files and converting their contents into SQLMesh's internal Python objects. 
+ +The loading stage is a convenient place to customize SQLMesh behavior because you can access a project's objects after they've been ingested from file but before SQLMesh uses them. + +SQLMesh's `SqlMeshLoader` class handles the loading process - customize it by subclassing it and overriding its methods. + +!!! note "Python configuration only" + + Custom loaders require using the [Python configuration format](./configuration.md#python) (YAML is not supported). + +### Modify every model + +One reason to customize the loading process is to do something to every model. For example, you might want to add a post-statement to every model. + +The loading process parses all model SQL statements, so new or modified SQL must be parsed by SQLGlot before being passed to a model object. + +This custom loader example adds a post-statement to every model: + +``` python linenums="1" title="config.py" +from sqlmesh.core.loader import SqlMeshLoader +from sqlmesh.utils import UniqueKeyDict +from sqlmesh.core.dialect import parse_one +from sqlmesh.core.config import Config + +# New `CustomLoader` class subclasses `SqlMeshLoader` +class CustomLoader(SqlMeshLoader): + # Override SqlMeshLoader's `_load_models` method to access every model + def _load_models( + self, + macros: "MacroRegistry", + jinja_macros: "JinjaMacroRegistry", + gateway: str | None, + audits: UniqueKeyDict[str, "ModelAudit"], + signals: UniqueKeyDict[str, "signal"], + ) -> UniqueKeyDict[str, "Model"]: + # Call SqlMeshLoader's normal `_load_models` method to ingest models from file and parse model SQL + models = super()._load_models(macros, jinja_macros, gateway, audits, signals) + + new_models = {} + # Loop through the existing model names/objects + for model_name, model in models.items(): + # Create list of existing and new post-statements + new_post_statements = [ + # Existing post-statements from model object + *model.post_statements, + # New post-statement is raw SQL, so we parse it with SQLGlot's `parse_one` 
function. + # Make sure to specify the SQL dialect if different from the project default. + parse_one(f"VACUUM @this_model"), + ] + # Create a copy of the model with the `post_statements_` field updated + new_models[model_name] = model.copy(update={"post_statements_": new_post_statements}) + + return new_models + +# Pass the CustomLoader class to the SQLMesh configuration object +config = Config( + # < your configuration parameters here >, + loader=CustomLoader, +) +``` \ No newline at end of file diff --git a/docs/guides/incremental_time.md b/docs/guides/incremental_time.md index f610a52850..8663ae9926 100644 --- a/docs/guides/incremental_time.md +++ b/docs/guides/incremental_time.md @@ -109,6 +109,10 @@ The model configuration specifies that the column `model_time_column` represents The `WHERE` clause uses the [SQLMesh predefined macro variables](../concepts/macros/macro_variables.md#predefined-variables) `@start_ds` and `@end_ds` to specify the date range. SQLMesh automatically substitutes in the correct dates based on which intervals are being processed in a job. +!!! tip "Important" + + The `time_column` should be in the [UTC time zone](https://en.wikipedia.org/wiki/Coordinated_Universal_Time) to ensure correct interaction with SQLMesh's scheduler and predefined macro variables. + In addition to the query `WHERE` clause, SQLMesh prevents data leakage by automatically wrapping the query in another time-filtering `WHERE` clause using the time column in the model's configuration. This raises a question: if SQLMesh automatically adds a time filtering `WHERE` clause, why do you need to include one in the query? Because the two filters play different roles: @@ -155,19 +159,49 @@ WHERE Alternatively, all the changes contained in a *specific plan* can be classified as forward-only with a flag: `sqlmesh plan --forward-only`. A subsequent plan that did not include the forward-only flag would fully refresh the model's physical table. 
Learn more about forward-only plans [here](../concepts/plans.md#forward-only-plans).
 
-### Destructive changes
+### Schema changes
+
+When SQLMesh processes forward-only changes to incremental models, it compares the model's new schema with the existing physical table schema to detect potential data loss or compatibility issues. SQLMesh categorizes schema changes into two types:
+
+#### Destructive changes
 
-Some model changes destroy existing data in a table. Dropping a column from the model is the most direct cause, but changing a column's data type (such as casting a column from a `STRING` to `INTEGER`) can also require a drop. (Whether or not a specific change requires dropping a column may differ across SQL engines.)
+Some model changes destroy existing data in a table. Examples include:
 
-Forward-only models are used to retain existing data. Before executing forward-only changes to incremental models, SQLMesh performs a check to determine if existing data will be destroyed.
+- **Dropping a column** from the model
+- **Renaming a column**
+- **Modifying a column data type** in a way that could cause data loss
 
-The check is performed at plan time based on the model definition. SQLMesh may not be able to resolve all of a model's column data types and complete the check, so the check is performed again at run time based on the physical tables underlying the model.
+Whether a specific change is destructive may differ across SQL engines based on their schema evolution capabilities.
+
+#### Additive changes
+
+Additive changes are any changes to the table's columns that aren't categorized as destructive. A simple example would be adding a column to a table, but another would be changing a column data type to a type that is compatible (ex: INT -> STRING).
+
+SQLMesh performs schema change detection at plan time based on the model definition. 
If SQLMesh cannot resolve all of a model's column data types at plan time, the check is performed again at run time based on the physical tables underlying the model. #### Changes to forward-only models -A model's `on_destructive_change` [configuration setting](../reference/model_configuration.md#incremental-models) determines what happens when SQLMesh detects a destructive change. +SQLMesh provides two configuration settings to control how schema changes are handled: + +- **`on_destructive_change`** - Controls behavior for destructive schema changes +- **`on_additive_change`** - Controls behavior for additive schema changes -By default, SQLMesh will error so no data is lost. You can set `on_destructive_change` to `warn` or `allow` in the model's `MODEL` block to allow destructive changes. +##### Configuration options + +Both properties support four values: + +- **`error`** (default for `on_destructive_change`): Stop execution and raise an error +- **`warn`**: Log a warning but proceed with the change +- **`allow`** (default for `on_additive_change`): Silently proceed with the change +- **`ignore`**: Skip the schema change check entirely for this change type + +!!! warning "Ignore is Dangerous" + +`ignore` is dangerous since it can result in error or data loss. It likely should never be used but could be useful as an "escape-hatch" or a way to workaround unexpected behavior. + +##### Destructive change handling + +The `on_destructive_change` [configuration setting](../reference/model_configuration.md#incremental-models) determines what happens when SQLMesh detects a destructive change. By default, SQLMesh will error so no data is lost. This example configures a model to silently `allow` destructive changes: @@ -182,12 +216,93 @@ MODEL ( ); ``` -A default `on_destructive_change` value can be set for all incremental models that do not specify it themselves in the [model defaults configuration](../reference/model_configuration.md#model-defaults). 
+##### Additive change handling + +The `on_additive_change` configuration setting determines what happens when SQLMesh detects an additive change like adding new columns. By default, SQLMesh allows these changes since they don't destroy existing data. + +This example configures a model to raise an error for additive changes (useful for strict schema control): + +``` sql linenums="1" +MODEL ( + name sqlmesh_example.new_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column model_time_column, + forward_only true, + on_additive_change error + ), +); +``` + +##### Combining both settings + +You can configure both settings together to have fine-grained control over schema evolution: + +``` sql linenums="1" +MODEL ( + name sqlmesh_example.new_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column model_time_column, + forward_only true, + on_destructive_change warn, -- Warn but allow destructive changes + on_additive_change allow -- Silently allow new columns + ), +); +``` + +##### Model defaults + +Default values for both `on_destructive_change` and `on_additive_change` can be set for all incremental models in the [model defaults configuration](../reference/model_configuration.md#model-defaults). 
+ +##### Common use cases + +Here are some common patterns for configuring schema change handling: + +**Strict schema control** - Prevent any schema changes: +```sql linenums="1" +MODEL ( + name sqlmesh_example.strict_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column event_date, + forward_only true, + on_destructive_change error, -- Block destructive changes + on_additive_change error -- Block even new columns + ), +); +``` + +**Permissive development model** - Allow all schema changes: +```sql linenums="1" +MODEL ( + name sqlmesh_example.dev_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column event_date, + forward_only true, + on_destructive_change allow, -- Allow dropping columns + on_additive_change allow -- Allow new columns (`allow` is the default value for this setting, so it can be omitted here) + ), +); +``` + +**Production safety** - Allow safe changes, warn about risky ones: +```sql linenums="1" +MODEL ( + name sqlmesh_example.production_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column event_date, + forward_only true, + on_destructive_change warn, -- Warn about destructive changes + on_additive_change allow -- Allow new columns (`allow` is the default value for this setting, so it can be omitted here) + ), +); +``` #### Changes in forward-only plans -The SQLMesh `plan` [`--forward-only` option](../concepts/plans.md#forward-only-plans) treats all the plan's model changes as forward-only. When this option is specified, SQLMesh will check all modified incremental models for destructive schema changes, not just models configured with `forward_only true`. +The SQLMesh `plan` [`--forward-only` option](../concepts/plans.md#forward-only-plans) treats all the plan's model changes as forward-only. When this option is specified, SQLMesh will check all modified incremental models for both destructive and additive schema changes, not just models configured with `forward_only true`. 
+ +SQLMesh determines what to do for each model based on this setting hierarchy: -SQLMesh determines what to do for each model based on this setting hierarchy: the model's `on_destructive_change` value (if present), the `on_destructive_change` [model defaults](../reference/model_configuration.md#model-defaults) value (if present), and the SQLMesh global default of `error`. +- **For destructive changes**: the model's `on_destructive_change` value (if present), the `on_destructive_change` [model defaults](../reference/model_configuration.md#model-defaults) value (if present), and the SQLMesh global default of `error` +- **For additive changes**: the model's `on_additive_change` value (if present), the `on_additive_change` [model defaults](../reference/model_configuration.md#model-defaults) value (if present), and the SQLMesh global default of `allow` -If you want to temporarily allow destructive changes to models that don't allow them, use the `plan` command's [`--allow-destructive-change` selector](../concepts/plans.md#destructive-changes) to specify which models. Learn more about model selectors [here](../guides/model_selection.md). +If you want to temporarily allow destructive changes to models that don't allow them, use the `plan` command's [`--allow-destructive-model` selector](../concepts/plans.md#destructive-changes) to specify which models. Similarly, if you want to temporarily allow additive changes to models configured with `on_additive_change=error`, use the [`--allow-additive-model` selector](../concepts/plans.md#destructive-changes). Learn more about model selectors [here](../guides/model_selection.md). 
diff --git a/docs/guides/isolated_systems.md b/docs/guides/isolated_systems.md index 462e761534..a032675653 100644 --- a/docs/guides/isolated_systems.md +++ b/docs/guides/isolated_systems.md @@ -70,7 +70,7 @@ MODEL ( ) ``` -To embed the gateway name directly in the schema name, use the `@{gateway}` syntax: +To embed the gateway name directly in the schema name, use the curly brace `@{gateway}` syntax: ```sql linenums="1" MODEL ( @@ -78,6 +78,8 @@ MODEL ( ) ``` +Learn more about the curly brace `@{}` syntax [here](../concepts/macros/sqlmesh_macros.md#embedding-variables-in-strings). + ## Workflow ### Linking systems diff --git a/docs/guides/linter.md b/docs/guides/linter.md new file mode 100644 index 0000000000..6cdac167ec --- /dev/null +++ b/docs/guides/linter.md @@ -0,0 +1,261 @@ +# Linter guide + +![Linter](./linter/linter_example.png) + +Linting is a powerful tool for improving code quality and consistency. It enables you to automatically validate model definition, ensuring they adhere to your team's best practices. + +When a SQLMesh plan is created, each model's code is checked for compliance with a set of rules you choose. + +SQLMesh provides built-in rules, and you can define custom rules. This improves code quality and helps detect issues early in the development cycle when they are simpler to debug. + +## Rules + +Each linting rule is responsible for identifying a pattern in a model's code. + +Some rules validate that a pattern is *not* present, such as not allowing `SELECT *` in a model's outermost query. Other rules validate that a pattern *is* present, like ensuring that every model's `owner` field is specified. We refer to both of these below as "validating a pattern". + +Rules are defined in Python. Each rule is an individual Python class that inherits from SQLMesh's `Rule` base class and defines the logic for validating a pattern. 
+ +We display a portion of the `Rule` base class's code below ([full source code](https://github.com/SQLMesh/sqlmesh/blob/main/sqlmesh/core/linter/rule.py)). Its methods and properties illustrate the most important components of the subclassed rules you define. + +Each rule class you create has four vital components: + +1. Name: the class's name is used as the rule's name. +2. Description: the class should define a docstring that provides a short explanation of the rule's purpose. +3. Pattern validation logic: the class should define a `check_model()` method containing the core logic that validates the rule's pattern. The method can access any `Model` attribute. +4. Rule violation logic: if a rule's pattern is not validated, the rule is "violated" and the class should return a `RuleViolation` object. The `RuleViolation` object should include the contextual information a user needs to understand and fix the problem. + +``` python linenums="1" +# Class name used as rule's name +class Rule: + # Docstring provides rule's description + """The base class for a rule.""" + + # Pattern validation logic goes in `check_model()` method + @abc.abstractmethod + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + """The evaluation function that checks for a violation of this rule.""" + + # Rule violation object returned by `violation()` method + def violation(self, violation_msg: t.Optional[str] = None) -> RuleViolation: + """Return a RuleViolation instance if this rule is violated""" + return RuleViolation(rule=self, violation_msg=violation_msg or self.summary) +``` + +### Built-in rules + +SQLMesh includes a set of predefined rules that check for potential SQL errors or enforce code style. + +An example of the latter is the `NoSelectStar` rule, which prohibits a model from using `SELECT *` in its query's outer-most select statement. 
+ +Here is code for the built-in `NoSelectStar` rule class, with the different components annotated: + +``` python linenums="1" +# Rule's name is the class name `NoSelectStar` +class NoSelectStar(Rule): + # Docstring explaining rule + """Query should not contain SELECT * on its outer most projections, even if it can be expanded.""" + + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + # If this model does not contain a SQL query, there is nothing to validate + if not isinstance(model, SqlModel): + return None + + # Use the query's `is_star` property to detect the `SELECT *` pattern. + # If present, call the `violation()` method to return a `RuleViolation` object. + return self.violation() if model.query.is_star else None +``` + +Here are all of SQLMesh's built-in linting rules: + +| Name | Check type | Explanation | +| -------------------------- | ----------- | ------------------------------------------------------------------------------------------------------------------------ | +| `ambiguousorinvalidcolumn` | Correctness | SQLMesh found duplicate columns or was unable to determine whether a column is duplicated or not | +| `invalidselectstarexpansion` | Correctness | The query's top-level selection may be `SELECT *`, but only if SQLMesh can expand the `SELECT *` into individual columns | +| `noselectstar` | Stylistic | The query's top-level selection may not be `SELECT *`, even if SQLMesh can expand the `SELECT *` into individual columns | +| `nomissingaudits` | Governance | SQLMesh did not find any `audits` in the model's configuration to test data quality. | + +### User-defined rules + +You may define custom rules to implement your team's best practices. 
+ +For instance, you could ensure all models have an `owner` by defining the following linting rule: + +``` python linenums="1" title="linter/user.py" +import typing as t + +from sqlmesh.core.linter.rule import Rule, RuleViolation +from sqlmesh.core.model import Model + +class NoMissingOwner(Rule): + """Model owner should always be specified.""" + + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + # Rule violated if the model's owner field (`model.owner`) is not specified + return self.violation() if not model.owner else None + +``` + +Place a rule's code in the project's `linter/` directory. SQLMesh will load all subclasses of `Rule` from that directory. + +If the rule is specified in the project's [configuration file](#applying-linting-rules), SQLMesh will run it when: +- A plan is created during `sqlmesh plan` +- The command `sqlmesh lint` is ran + +SQLMesh will error if a model violates the rule, informing you which model(s) violated the rule. In this example, `full_model.sql` violated the `NoMissingOwner` rule, essentially halting execution: + +``` bash +$ sqlmesh plan + +Linter errors for .../models/full_model.sql: + - nomissingowner: Model owner should always be specified. + +Error: Linter detected errors in the code. Please fix them before proceeding. +``` + +Or through the standalone command, for faster iterations: + +``` bash +$ sqlmesh lint + +Linter errors for .../models/full_model.sql: + - nomissingowner: Model owner should always be specified. + +Error: Linter detected errors in the code. Please fix them before proceeding. +``` + +Use `sqlmesh lint --help` for more information. + + +## Applying linting rules + +Specify which linting rules a project should apply in the project's [configuration file](./configuration.md). + +Rules are specified as lists of rule names under the `linter` key. Globally enable or disable linting with the `enabled` key, which is `false` by default. 
NOTE: you **must** set the `enabled` key to `true` to apply
+ +This example specifies that the model `docs_example.full_model` should not run the `invalidselectstarexpansion` rule: + +```sql linenums="1" +MODEL( + name docs_example.full_model, + ignored_rules ["invalidselectstarexpansion"] # or "ALL" to turn off linting completely +); +``` + +### Rule violation behavior + +Linting rule violations raise an error by default, preventing the project from running until the violation is addressed. + +You may specify that a rule's violation should not error and only log a warning by specifying it in the `warn_rules` key instead of the `rules` key. + +=== "YAML" + + ```yaml linenums="1" + linter: + enabled: True + # error if `ambiguousorinvalidcolumn` rule violated + rules: ["ambiguousorinvalidcolumn"] + # but only warn if "invalidselectstarexpansion" is violated + warn_rules: ["invalidselectstarexpansion"] + ``` + +=== "Python" + + ```python linenums="1" + from sqlmesh.core.config import Config, LinterConfig + + config = Config( + linter=LinterConfig( + enabled=True, + # error if `ambiguousorinvalidcolumn` rule violated + rules=["ambiguousorinvalidcolumn"], + # but only warn if "invalidselectstarexpansion" is violated + warn_rules=["invalidselectstarexpansion"], + ) + ) + ``` + +SQLMesh will raise an error if the same rule is included in more than one of the `rules`, `warn_rules`, and `ignored_rules` keys since they should be mutually exclusive. \ No newline at end of file diff --git a/docs/guides/linter/linter_example.png b/docs/guides/linter/linter_example.png new file mode 100644 index 0000000000..d88ea1bac9 Binary files /dev/null and b/docs/guides/linter/linter_example.png differ diff --git a/docs/guides/migrations.md b/docs/guides/migrations.md index 2847e1b3af..f65a34460a 100644 --- a/docs/guides/migrations.md +++ b/docs/guides/migrations.md @@ -36,8 +36,3 @@ Migrations should ideally run when no one will be running plan/apply. Migrations should not be run in parallel. 
Due to these constraints, it is better for a person responsible for managing SQLMesh to manually issue migrations. Therefore, it is not recommended to issue migrations from CI/CD pipelines. - -### Airflow Scheduler Migrations - -If using Airflow, migrations are automatically run after the SQLMesh version is upgraded and cluster is restarted. -Therefore, migrations **should not** be run manually. diff --git a/docs/guides/model_selection.md b/docs/guides/model_selection.md index 109f2bc8d3..79fd17a18c 100644 --- a/docs/guides/model_selection.md +++ b/docs/guides/model_selection.md @@ -2,7 +2,7 @@ This guide describes how to select specific models to include in a SQLMesh plan, which can be useful when modifying a subset of the models in a SQLMesh project. -Note: the selector syntax described below is also used for the SQLMesh `plan` [`--allow-destructive-model` selector](../concepts/plans.md#destructive-changes). +Note: the selector syntax described below is also used for the SQLMesh `plan` [`--allow-destructive-model` and `--allow-additive-model` selectors](../concepts/plans.md#destructive-changes) and for the `table_diff` command to [diff a selection of models](./tablediff.md#diffing-multiple-models-across-environments). ## Background @@ -62,7 +62,7 @@ The upstream/downstream indicator may be combined with the wildcard operator. Fo The combination of the upstream/downstream indicator, wildcards, and multiple `--select-model` arguments enables granular and complex model selections for a plan. -Upstream/downstream indicators also apply to tags. For example, `--select-model "tag:+reporting*"` would select all models with tags that start with `reporting` and their upstream models. +Upstream/downstream indicators also apply to tags. For example, `--select-model "+tag:reporting*"` would select all models with tags that start with `reporting` and their upstream models. 
## Backfill @@ -78,7 +78,7 @@ NOTE: the `--backfill-model` argument can only be used in development environmen ## Examples -We now demonstrate the use of `--select-model` and `--backfill-model` with the SQLMesh `sushi` example project, available in the `examples/sushi` directory of the [SQLMesh Github repository](https://github.com/TobikoData/sqlmesh). +We now demonstrate the use of `--select-model` and `--backfill-model` with the SQLMesh `sushi` example project, available in the `examples/sushi` directory of the [SQLMesh Github repository](https://github.com/SQLMesh/sqlmesh). ### sushi @@ -100,7 +100,10 @@ If we run a `plan` without selecting specific models, SQLMesh includes the two d ```bash ❯ sqlmesh plan dev -Summary of differences against `dev`: +New environment `dev` will be created from `prod` + +Differences from the `prod` environment: + Models: ├── Directly Modified: │ ├── sushi.order_items @@ -118,7 +121,10 @@ If we specify the `--select-model` option to select `"sushi.order_items"`, the d ```bash ❯ sqlmesh plan dev --select-model "sushi.order_items" -Summary of differences against `dev`: +New environment `dev` will be created from `prod` + +Differences from the `prod` environment: + Models: ├── Directly Modified: │ └── sushi.order_items @@ -135,7 +141,10 @@ If we specify the `--select-model` option with the upstream `+` to select `"+sus ```bash ❯ sqlmesh plan dev --select-model "+sushi.order_items" -Summary of differences against `dev`: +New environment `dev` will be created from `prod` + +Differences from the `prod` environment: + Models: ├── Directly Modified: │ ├── sushi.items @@ -153,9 +162,12 @@ If we specify the `--select-model` option to select `"sushi.items"`, SQLMesh doe However, it does classify `sushi.order_items` as indirectly modified. 
Its direct modification is excluded by the model selection, but it is indirectly modified by being downstream of the selected `sushi.items` model: -```bash hl_lines="7" +```bash hl_lines="10" ❯ sqlmesh plan dev --select-model "sushi.items" -Summary of differences against `dev`: +New environment `dev` will be created from `prod` + +Differences from the `prod` environment: + Models: ├── Directly Modified: │ └── sushi.items @@ -173,7 +185,10 @@ If we specify the `--select-model` option with the downstream `+` to select `"su ```bash ❯ sqlmesh plan dev --select-model "sushi.items+" -Summary of differences against `dev`: +New environment `dev` will be created from `prod` + +Differences from the `prod` environment: + Models: ├── Directly Modified: │ ├── sushi.items @@ -191,7 +206,10 @@ If we specify the `--select-model` option with the wildcard `*` to select `"sush ```bash ❯ sqlmesh plan dev --select-model "sushi.*items" -Summary of differences against `dev`: +New environment `dev` will be created from `prod` + +Differences from the `prod` environment: + Models: ├── Directly Modified: │ ├── sushi.order_items @@ -203,6 +221,82 @@ Models: └── sushi.customer_revenue_lifetime ``` +#### Select with tags + +If we specify the `--select-model` option with a tag selector like `"tag:reporting"`, all models with the "reporting" tag will be selected. Tags are case-insensitive and support wildcards: + +```bash +❯ sqlmesh plan dev --select-model "tag:reporting*" +New environment `dev` will be created from `prod` + +Differences from the `prod` environment: + +Models: +├── Directly Modified: +│ ├── sushi.daily_revenue +│ └── sushi.monthly_revenue +└── Indirectly Modified: + └── sushi.revenue_dashboard +``` + +#### Select with git changes + +The git-based selector allows you to select models whose files have changed compared to a target branch (default: main). 
This includes: + +- Untracked files (new files not in git) +- Uncommitted changes in working directory (both staged and unstaged) +- Committed changes different from the target branch + +For example: + +```bash +❯ sqlmesh plan dev --select-model "git:feature" +New environment `dev` will be created from `prod` + +Differences from the `prod` environment: + +Models: +├── Directly Modified: +│ └── sushi.items # Changed in feature branch +└── Indirectly Modified: + ├── sushi.order_items + └── sushi.daily_revenue +``` + +You can also combine git selection with upstream/downstream indicators: + +```bash +❯ sqlmesh plan dev --select-model "git:feature+" +# Selects changed models and their downstream dependencies + +❯ sqlmesh plan dev --select-model "+git:feature" +# Selects changed models and their upstream dependencies +``` + +#### Complex selections with logical operators + +The model selector supports combining multiple conditions using logical operators: + +- `&` (AND): Both conditions must be true +- `|` (OR): Either condition must be true +- `^` (NOT): Negates a condition + +For example: + +```bash +❯ sqlmesh plan dev --select-model "(tag:finance & ^tag:deprecated)" +# Selects models with finance tag that don't have deprecated tag + +❯ sqlmesh plan dev --select-model "(+model_a | model_b+)" +# Selects model_a and its upstream deps OR model_b and its downstream deps + +❯ sqlmesh plan dev --select-model "(tag:finance & git:main)" +# Selects changed models that also have the finance tag + +❯ sqlmesh plan dev --select-model "^(tag:test) & metrics.*" +# Selects models in metrics schema that don't have the test tag +``` + ### Backfill examples #### No backfill selection diff --git a/docs/guides/models.md b/docs/guides/models.md index 94e99a9ee2..e3b4ab1cfa 100644 --- a/docs/guides/models.md +++ b/docs/guides/models.md @@ -60,12 +60,16 @@ To preview changes using `plan`: 1. Enter the `sqlmesh plan ` command. 2. 
Enter `1` to classify the changes as `Breaking`, or enter `2` to classify the changes as `Non-Breaking`. In this example, the changes are classified as `Non-Breaking`: -```hl_lines="23 24" +```bash linenums="1" hl_lines="27-28" $ sqlmesh plan dev ====================================================================== Successfully Ran 1 tests against duckdb ---------------------------------------------------------------------- -Summary of differences against `dev`: +New environment `dev` will be created from `prod` + +Differences from the `prod` environment: + +Models ├── Directly Modified: │ └── sqlmesh_example.incremental_model └── Indirectly Modified: @@ -115,12 +119,14 @@ To revert your change: 1. Open the model file you wish to edit in your preferred editor, and undo a change you made earlier. For this example, we'll remove the column we added in the [quickstart](../quick_start.md) example. 2. Run `sqlmesh plan` and apply your changes. Enter `y` to run a Virtual Update. -```hl_lines="24" +```bash linenums="1" hl_lines="26" $ sqlmesh plan dev ====================================================================== Successfully Ran 1 tests against duckdb ---------------------------------------------------------------------- -Summary of differences against `dev`: +Differences from the `dev` environment: + +Models ├── Directly Modified: │ └── sqlmesh_example.incremental_model └── Indirectly Modified: @@ -187,7 +193,9 @@ To delete a model: ====================================================================== Successfully Ran 0 tests against duckdb ---------------------------------------------------------------------- - Summary of differences against `dev`: + Differences from the `dev` environment: + + Models └── Removed Models: └── sqlmesh_example.full_model Apply - Virtual Update [y/n]: y @@ -203,12 +211,14 @@ To delete a model: 3. Plan and apply your changes to production, and enter `y` for the Virtual Update. 
By default, the `sqlmesh plan` command targets your production environment: - ``` + ```bash linenums="1" $ sqlmesh plan ====================================================================== Successfully Ran 0 tests against duckdb ---------------------------------------------------------------------- - Summary of differences against `prod`: + Differences from the `prod` environment: + + Models └── Removed Models: └── sqlmesh_example.full_model Apply - Virtual Update [y/n]: y diff --git a/docs/guides/multi_engine.md b/docs/guides/multi_engine.md new file mode 100644 index 0000000000..f2ccd31394 --- /dev/null +++ b/docs/guides/multi_engine.md @@ -0,0 +1,312 @@ +# Multi-Engine guide + +Organizations typically connect to a data warehouse through a single engine to ensure data consistency. However, there are cases where the processing capabilities of one engine may be better suited to specific tasks than another. + +Companies are increasingly decoupling how/where data is stored from the how computations are run on the data, requiring interoperability across platforms and tools. Open table formats like Apache Iceberg, Delta Lake, and Hive provide a common storage format that can be used by multiple SQL engines. + +SQLMesh enables this decoupling by supporting multiple engine adapters within a single project, giving you the flexibility to choose the best engine for each computational task. You can specify the engine each model uses, based on what computations the model performs or other organization-specific considerations. + +## Configuring a Project with Multiple Engines + +Configuring your project to use multiple engines follows a simple process: + +- Include all required [gateway connections](../reference/configuration.md#connection) in your configuration. +- Specify the `gateway` to be used for execution in the `MODEL` DDL. + +If no gateway is explicitly defined for a model, the [default_gateway](../reference/configuration.md#default-gateway) of the project is used. 
+ +By default, virtual layer views are created in the `default_gateway`. This approach requires that all engines can read from and write to the same shared catalog, so a view in the `default_gateway` can access a table in another gateway. + +Alternatively, each gateway can create the virtual layer views for the models it runs. Use this approach by setting the [gateway_managed_virtual_layer](#gateway-managed-virtual-layer) flag to `true` in your project configuration. + +### Shared Virtual Layer + +To dive deeper, in SQLMesh the [physical layer](../concepts/glossary.md#physical-layer) is the concrete data storage layer, where it stores and manages data in database tables and materialized views. + +While, the [virtual layer](../concepts/glossary.md#virtual-layer) consists of views, one for each model, each pointing to a snapshot table in the physical layer. + +In a multi-engine project with a shared data catalog, the model-specific gateway is responsible for the physical layer, while the default gateway is used for managing the virtual layer. + +#### Example: DuckDB + PostgreSQL + +Below is a simple example of setting up a project with connections to both DuckDB and PostgreSQL. + +In this setup, the PostgreSQL engine is set as the default, so it will be used to manage views in the virtual layer. Meanwhile, DuckDB's [attach](https://duckdb.org/docs/sql/statements/attach.html) feature enables read-write access to the PostgreSQL catalog's physical tables. 
+ +=== "YAML" + + ```yaml linenums="1" + gateways: + duckdb: + connection: + type: duckdb + catalogs: + main_db: + type: postgres + path: 'dbname=main_db user=postgres host=127.0.0.1' + extensions: + - name: iceberg + postgres: + connection: + type: postgres + database: main_db + user: user + password: password + host: 127.0.0.1 + port: 5432 + default_gateway: postgres + ``` + +=== "Python" + + ```python linenums="1" + from sqlmesh.core.config import ( + Config, + ModelDefaultsConfig, + GatewayConfig, + DuckDBConnectionConfig, + PostgresConnectionConfig + ) + from sqlmesh.core.config.connection import DuckDBAttachOptions + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="postgres"), + gateways={ + "duckdb": GatewayConfig( + connection=DuckDBConnectionConfig( + catalogs={ + "main_db": DuckDBAttachOptions( + type="postgres", + path="dbname=main_db user=postgres host=127.0.0.1" + ), + }, + extensions=["iceberg"], + ) + ), + "postgres": GatewayConfig( + connection=PostgresConnectionConfig( + host="127.0.0.1", + port=5432, + user="postgres", + password="password", + database="main_db", + ) + ), + }, + default_gateway="postgres", + ) + ``` + +Given this configuration, when a model’s gateway is set to DuckDB, the DuckDB engine will perform the calculations before materializing the physical table in the PostgreSQL `main_db` catalog. + +```sql linenums="1" +MODEL ( + name orders.order_ship_date, + kind FULL, + gateway duckdb, +); + +SELECT + l_orderkey, + l_shipdate +FROM + iceberg_scan('data/bucket/lineitem_iceberg', allow_moved_paths = true); +``` + +The `order_ship_date` model specifies the DuckDB engine, which will perform the computations used to create the physical table in the PostgreSQL database. + +This allows you to efficiently scan data from an Iceberg table, or even query tables directly from S3 when used with the [HTTPFS](https://duckdb.org/docs/stable/extensions/httpfs/overview.html) extension. 
+ +![Figure 1: PostgreSQL + DuckDB](./multi_engine/postgres_duckdb.png) +*Figure 1: The gateways denote the execution engine, while both the virtual layer’s views and the physical layer's tables reside in Postgres* + +In models where no gateway is specified, such as the `customer_orders` model, the default PostgreSQL engine will both create the physical table and the views in the virtual layer. + +### Gateway-Managed Virtual Layer + +By default, all virtual layer views are created in the project's default gateway. + +If your project's engines don’t have a mutually accessible catalog or your raw data is located in different engines, you may prefer for each model's virtual layer view to exist in the gateway that ran the model. This allows a single SQLMesh project to manage isolated sets of models in different gateways, which is sometimes necessary for data governance or security concerns. + +To enable this, set `gateway_managed_virtual_layer` to `true` in your configuration. By default, this flag is set to false. + +#### Example: Redshift + Athena + Snowflake + +Consider a scenario where you need to create a project with models in Redshift, Athena and Snowflake, where each engine hosts its models' virtual layer views. 
+ +First, add the connections to your configuration and set the `gateway_managed_virtual_layer` flag to `true`: + +=== "YAML" + + ```yaml linenums="1" hl_lines="30" + gateways: + redshift: + connection: + type: redshift + user: + password: + host: + database: + variables: + gw_var: 'redshift' + athena: + connection: + type: athena + aws_access_key_id: + aws_secret_access_key: + s3_warehouse_location: + variables: + gw_var: 'athena' + snowflake: + connection: + type: snowflake + account: + user: + database: + warehouse: + variables: + gw_var: 'snowflake' + + default_gateway: redshift + gateway_managed_virtual_layer: true + + variables: + gw_var: 'global' + global_var: 5 + ``` + +=== "Python" + + ```python linenums="1" hl_lines="48" + from sqlmesh.core.config import ( + Config, + ModelDefaultsConfig, + GatewayConfig, + RedshiftConnectionConfig, + AthenaConnectionConfig, + SnowflakeConnectionConfig, + ) + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="redshift"), + gateways={ + "redshift": GatewayConfig( + connection=RedshiftConnectionConfig( + user="", + password="", + host="", + database="", + ), + variables={ + "gw_var": "redshift" + }, + ), + "athena": GatewayConfig( + connection=AthenaConnectionConfig( + aws_access_key_id="", + aws_secret_access_key="", + region_name="", + s3_warehouse_location="", + ), + variables={ + "gw_var": "athena" + }, + ), + "snowflake": GatewayConfig( + connection=SnowflakeConnectionConfig( + account="", + user="", + database="", + warehouse="", + ), + variables={ + "gw_var": "snowflake" + }, + ), + }, + default_gateway="redshift", + gateway_managed_virtual_layer=True, + variables={ + "gw_var": "global", + "global_var": 5, + }, + ) + ``` + +Note that gateway-specific variables take precedence over global ones. In the example above, the `gw_var` used in a model will resolve to the value specified in the model's gateway. 
+ +For further customization, you can also enable [gateway-specific model defaults](../guides/configuration.md#gateway-specific-model-defaults). This allows you to define custom behaviors, such as specifying a dialect with case-insensitivity normalization. + +In the example configuration above the default gateway is `redshift`, so all models without a `gateway` specification will run on redshift, as in this `order_dates` model: + +```sql linenums="1" +MODEL ( + name redshift_schema.order_dates, + table_format iceberg, +); + +SELECT + order_date, + order_id +FROM + bucket.raw_data; +``` + +For the `athena_schema.order_status` model, we explicitly specify the `athena` gateway: + +```sql linenums="1" hl_lines="4" +MODEL ( + name athena_schema.order_status, + table_format iceberg, + gateway athena, +); + +SELECT + order_id, + status +FROM + bucket.raw_data; +``` + +Finally, specifying the `snowflake` gateway for the `customer_orders` model ensures it is isolated from the rest and reads from a table within the Snowflake database: + +```sql linenums="1" hl_lines="4" +MODEL ( + name snowflake_schema.customer_orders, + table_format iceberg, + gateway snowflake +); + +SELECT + customer_id, + orders +FROM + bronze_schema.customer_data; +``` + + +![Figure 2: Athena + Redshift + Snowflake](./multi_engine/athena_redshift_snowflake.png) +*Figure 2: The gateways represent the execution engine and indicate where the virtual layer’s views and the physical layer's tables reside* + +When you run the plan, the catalogs for each model will be set automatically based on the gateway’s connection and each corresponding model will be executed by the specified engine: + +```bash +❯ sqlmesh plan + +`prod` environment will be initialized + +Models: +└── Added: + ├── awsdatacatalog.athena_schema.order_status # each model uses its gateway's catalog and schema + ├── redshift_schema.order_dates + └── silver.snowflake_schema.customers +Models needing backfill: +├── 
awsdatacatalog.athena_schema.order_status: [full refresh] +├── redshift_schema.order_dates: [full refresh] +└── silver.snowflake_schema.customer_orders: [full refresh] +Apply - Backfill Tables [y/n]: y +``` + +The views of the virtual layer will also be created by each corresponding engine. + +This approach provides isolation between your models, while maintaining centralized control over your project. diff --git a/docs/guides/multi_engine/athena_redshift_snowflake.png b/docs/guides/multi_engine/athena_redshift_snowflake.png new file mode 100644 index 0000000000..db2cff2d17 Binary files /dev/null and b/docs/guides/multi_engine/athena_redshift_snowflake.png differ diff --git a/docs/guides/multi_engine/postgres_duckdb.png b/docs/guides/multi_engine/postgres_duckdb.png new file mode 100644 index 0000000000..0afcd500ca Binary files /dev/null and b/docs/guides/multi_engine/postgres_duckdb.png differ diff --git a/docs/guides/multi_repo.md b/docs/guides/multi_repo.md index 19e8d2d4b3..4dae4de57e 100644 --- a/docs/guides/multi_repo.md +++ b/docs/guides/multi_repo.md @@ -1,9 +1,11 @@ # Multi-Repo guide -Although mono repos are convenient and easy to use, sometimes your organization may choose to use multiple repos. SQLMesh provides native support for multiple repos and makes it easy to maintain data consistency and correctness even with multiple repos. +Although mono repos are convenient and easy to use, sometimes your organization may choose to use multiple repos. +SQLMesh provides native support for multiple repos and makes it easy to maintain data consistency and correctness even with multiple repos. +If you are wanting to separate your systems/data and provide isolation, checkout the [isolated systems guide](https://sqlmesh.readthedocs.io/en/stable/guides/isolated_systems/?h=isolated). ## Bootstrapping multiple projects -Setting up SQLMesh with multiple repos is quite simple. 
Copy the contents of this example [multi-repo project](https://github.com/TobikoData/sqlmesh/tree/main/examples/multi). +Setting up SQLMesh with multiple repos is quite simple. Copy the contents of this example [multi-repo project](https://github.com/SQLMesh/sqlmesh/tree/main/examples/multi). To bootstrap the project, you can point SQLMesh at both projects. @@ -12,9 +14,10 @@ $ sqlmesh -p examples/multi/repo_1 -p examples/multi/repo_2/ plan ====================================================================== Successfully Ran 0 tests against duckdb ---------------------------------------------------------------------- -New environment `prod` will be created from `prod` -Summary of differences against `prod`: -└── Added Models: +`prod` environment will be initialized + +Models +└── Added: ├── silver.d ├── bronze.a ├── bronze.b @@ -62,7 +65,9 @@ $ sqlmesh -p examples/multi/repo_1 plan ====================================================================== Successfully Ran 0 tests against duckdb ---------------------------------------------------------------------- -Summary of differences against `prod`: +Differences from the `prod` environment: + +Models ├── Directly Modified: │ └── bronze.a └── Indirectly Modified: @@ -121,7 +126,9 @@ $ sqlmesh -p examples/multi/repo_1 plan ====================================================================== Successfully Ran 0 tests against duckdb ---------------------------------------------------------------------- -Summary of differences against `prod`: +Differences from the `prod` environment: + +Models ├── Directly Modified: │ └── bronze.a └── Indirectly Modified: @@ -166,7 +173,7 @@ SQLMesh correctly detects a breaking change and allows you to perform a multi-re ## Configuring projects with multiple repositories -To add support for multiple repositories, add a `project` key to the config file in each of the respective repos. 
+To add support for multiple repositories, add a `project` key to the config file in each of the respective repos. ```yaml project: repo_1 @@ -177,3 +184,32 @@ gateways: Even if you do not have a need for multiple repos now, consider adding a `project` key so that you can easily support multiple repos in the future. +## Running migrations with multiple repositories + +When doing a [migration](./migrations.md), pass in a single repo path using the `-p` flag. It doesn't matter which repo you choose. + +``` +$ sqlmesh -p examples/multi/repo_1 migrate +``` + +## Multi-Repo dbt projects + +SQLMesh also supports multiple repos for dbt projects, allowing it to correctly detect changes and orchestrate backfills even when changes span multiple dbt projects. + +You can watch a [quick demo](https://www.loom.com/share/69c083428bb348da8911beb2cd4d30b2) of this setup or experiment with the [multi-repo dbt example](https://github.com/SQLMesh/sqlmesh/tree/main/examples/multi_dbt) yourself. + +## Multi-repo mixed projects + +Native SQLMesh projects can be used alongside dbt projects in a multi-repo setup. + +This allows managing and sourcing tables from either project type within the same multi-repo project and facilitates a gradual migration from dbt to SQLMesh. + +Use the same syntax as SQLMesh-only multi-repo projects to execute a multi-repo project with either dbt or a combination of dbt and SQLMesh projects: + +``` +$ sqlmesh -p examples/multi_hybrid/dbt_repo -p examples/multi_hybrid/sqlmesh_repo plan +``` + +SQLMesh will automatically detect dependencies and lineage across both SQLMesh and dbt projects, even when models are sourcing from different project types. + +For an example of this setup, refer to the [mixed SQLMesh and dbt example](https://github.com/SQLMesh/sqlmesh/tree/main/examples/multi_hybrid). 
diff --git a/docs/guides/notifications.md b/docs/guides/notifications.md index 85beae6c3b..749a71c842 100644 --- a/docs/guides/notifications.md +++ b/docs/guides/notifications.md @@ -130,7 +130,7 @@ This example stops all notifications other than those for `User1`: SQLMesh notifications are triggered by events. The events that should trigger a notification are specified in the notification target's `notify_on` field. -Notifications are support for [`plan` application](../concepts/plans.md) start/end/failure, [`run`](../reference/cli.md#run) start/end/failure, and [`audit`](../concepts/audits.md) failures. +Notifications are supported for [`plan` application](../concepts/plans.md) start/end/failure, [`run`](../reference/cli.md#run) start/end/failure, and [`audit`](../concepts/audits.md) failures. For `plan` and `run` start/end, the target environment name is included in the notification message. For failures, the Python exception or error text is included in the notification message. @@ -256,7 +256,7 @@ This example shows an email notification target, where `sushi@example.com` email In Python configuration files, new notification targets can be configured to send custom messages. -To customize a notification, create a new notification target class as a subclass of one of the three target classes described above (`SlackWebhookNotificationTarget`, `SlackApiNotificationTarget`, or `BasicSMTPNotificationTarget`). See the definitions of these classes on Github [here](https://github.com/TobikoData/sqlmesh/blob/main/sqlmesh/core/notification_target.py). +To customize a notification, create a new notification target class as a subclass of one of the three target classes described above (`SlackWebhookNotificationTarget`, `SlackApiNotificationTarget`, or `BasicSMTPNotificationTarget`). See the definitions of these classes on Github [here](https://github.com/SQLMesh/sqlmesh/blob/main/sqlmesh/core/notification_target.py). 
Each of those notification target classes is a subclass of `BaseNotificationTarget`, which contains a `notify` function corresponding to each event type. This table lists the notification functions, along with the contextual information available to them at calling time (e.g., the environment name for start/end events): diff --git a/docs/guides/observer.md b/docs/guides/observer.md deleted file mode 100644 index e3ec9c5ebb..0000000000 --- a/docs/guides/observer.md +++ /dev/null @@ -1,308 +0,0 @@ -# SQLMesh Observer - -Data pipelines break. Upstream sources change without warning, buggy code gets merged, and cloud services randomly time out. These problems are ubiquitous, and someone is responsible for fixing them (probably you if you're reading this). - -SQLMesh Observer provides the information you need to rapidly detect, understand, and remedy problems with SQLMesh data transformation pipelines. - -This page describes how to install, run, and use SQLMesh Observer. - -## Context - -### The Challenge - -Remediating problems with data pipelines is challenging because there are so many potential causes. For transformation pipelines, those range from upstream source timeouts to SQL query errors to Python library conflicts (and more!). - -A useful observation tool should enable answering the following questions: - -- Did a problem occur? -- When did it occur? -- What type of problem is it? -- Where is the problem coming from? -- What is causing the problem? - -SQLMesh Observer supports answering these questions in four ways: - -1. Automatically [notifying users](./notifications.md) if a problem occurs -2. Capturing, storing, and displaying historical measures to reveal when a problem occurred -3. Enabling easy navigation from aggregated to granular information about pipeline components to identify the problem source -4. 
Centralizing error information from multiple sources to debug the problem - -### Measures - -SQLMesh Observer automatically captures and stores measures from all SQLMesh actions. We now briefly review the SQLMesh workflow before describing the different measures Observer captures. - -#### SQLMesh workflow - -The core of a SQLMesh project is its **models**. Roughly, each model consists of one SQL query and metadata that tells SQLMesh about how the model should be processed. - -Each model may have **audits** that validate the data returned by a model (e.g., verifying that a column contains no `NULL` values). By default, SQLMesh will stop running a project if an audit fails for any of its models. - -When you run a project on a SQL engine, you must choose an **environment** in which to run it. Environments allow people to modify projects in an isolated space that won't interfere with anyone else (or the version of the project running in production). - -SQLMesh stores a unique fingerprint of the project's content on each run so it can determine if any of that content has changed the next time you run it in that environment. - -When a project's content has changed, an environment is updated to reflect those changes with a SQLMesh **plan**. The plan identifies all the changes and determines which data will be affected by them so it only has to re-run the relevant models. - -After changes have been applied with a plan, the project is **run** on a schedule to process new data that has arrived since the previous run. - -The five entities in bold - models, audits, environments, runs, and plans - provide the information SQLMesh Observer captures to help you efficiently identify and remediate problems with your transformation pipeline. - -#### Data - -We now describe the specific measures SQLMesh captures about each entity. - -SQLMesh performs its primary actions during **plans** and **runs**, so most measures are generated when they occur. 
Both plans and runs are executed in a specific **environment**, so all of their measures are environment-specific. - -These measures are recorded and stored for each plan or run in a specific environment: - -- When it began and ended -- Total run time -- Whether it failed -- Whether and how any model audits failed -- The model versions evaluated during the plan/run -- Each model's run time - -Additionally, you can define [custom measures](#custom-measures) that will be captured for each model. - -## Installation - -SQLMesh Observer is part of the `sqlmesh-enterprise` Python library and is installed via `pip`. - -Installation requires a license key provided by Tobiko Data. You include the license key in the `pip` install command executed from the command line. It is quite long, so we recommend placing it in a file that the installation command reads. In this example, we have stored the key in a `txt` file: - -![SQLMesh Enterprise key stored in txt file](./observer/observer_key-file.png) - -Run the installation command and read the key file with the following command. The key is passed to the `--extra-index-url` argument, either directly by pasting the key into the command or by reading the key from file with an embedded `cat` command. You should replace `` with the path to your key file: - -``` bash -> pip install "sqlmesh-enterprise" --extra-index-url "$(cat )" -``` - -`sqlmesh-enterprise` works by overriding components of `sqlmesh` open source, and installing `sqlmesh-enterprise` will automatically install open-source `sqlmesh`. - -SQLMesh extras, such as SQL engine drivers, can be passed directly to the `sqlmesh-enterprise` installation command. This example installs the SQLMesh Slack notification and Snowflake engine driver extras: - -``` bash -> pip install "sqlmesh-enterprise[slack,snowflake]" --extra-index-url "$(cat )" -``` - -NOTE: `sqlmesh-enterprise` will not function properly if open-source `sqlmesh` is installed after it. 
- -## Startup - -As with the open-source [SQLMesh Browser UI](../quickstart/ui.md), SQLMesh Observer is initiated from the command line then opened in a web browser. - -First, navigate to your project directory in the CLI. Then start Observer by running the `sqlmesh observe` command: - -```bash -sqlmesh observe -``` - -After starting up, SQLMesh Observer is served at `http://127.0.0.1:8000` by default: - -![SQLMesh Observer startup on CLI](./observer/observer_cli.png) - -Navigate to the URL by clicking the link in your terminal (if supported) or copy-pasting it into your web browser: - -![SQLMesh Observer dashboard interface](./observer/observer_dashboard.png) - -## Interface - -We now describe the components of the SQLMesh Observer user interface. - -### Dashboard - -The "Dashboard" page is displayed when Observer starts - it consists of the following components: - -1. Links to the other two pages, "Environments" and "Plan Applications," in the top left -2. Counts and links to key information about environments, models, and plans in the top center -3. Interactive chart of historical `run` run times in the middle center -4. Interactive chart of historical audit failure counts in the bottom left -5. Interactive chart of historical `run` failures in the bottom right - -![SQLMesh Observer dashboard](./observer/observer_dashboard-components.png) - -### Charts - -Observer presents historical information via charts and tables. Most charts represent time on the x-axis and share the same appearance and user options. - -In a chart's top left corner is the `Time` selector, which sets the range of the x-axis. For example, the first chart displays 1 week of data, from November 27 through December 4. 
The second chart displays the same data but includes 3 months of historical data beginning on September 4: - -![SQLMesh Observer chart x-axis time selector](./observer/observer_chart-time-selector.png) - -In a chart's top right corner is the `Scale` selector, which toggles between a linear and log y-axis scale. A log scale may be helpful for comparing highly variable data series over time. This example displays the data from the second chart in the previous figure with a log y-axis scale: - -![SQLMesh Observer chart y-axis scale selector](./observer/observer_chart-scale-selector.png) - -Charts also display the data underlying a specific data point when the mouse hovers over it: - -![SQLMesh Observer chart mouse hover](./observer/observer_chart-hover.png) - -Many charts display purple `Plan` markers, which provide contextual information about when changes to the project occurred. Clicking on the marker will open a page containing [more information about the plan](#plan-applications). - -Some Observer tables include a button that toggles a chart of the measures in the table: - -![SQLMesh Observer table chart toggle](./observer/observer_table-chart-toggle.png) - - -### Environments - -Access the `Environments` landing page via the navigation links in the dashboard's top left. It displays a table listing each SQLMesh environment, the date it was created, the date it was last updated, and the date it expires (after which the SQLMesh janitor will delete it). The `prod` environment is always present and has no expiration date. - -![SQLMesh Observer environment landing page](./observer/observer_environments-landing.png) - -Clicking an environment's name in the table open's the environment's information page. 
The page begins with historical charts of run time, audit failures, and evaluation failures: - -![SQLMesh Observer environment information page](./observer/observer_environments-info-1.png) - -The page continues with lists of recent audit failures, evaluation failure, and model evaluations: - -![SQLMesh Observer environment information: recent occurrences](./observer/observer_environments-info-2.png) - -The page finishes with a list of models that differ from those currently in the `prod` environment, a list of the audits that have historically failed most frequently, a list of the models that have historically failed most frequently, and a list of the models with the longest run times: - -![SQLMesh Observer environment information: historical outliers](./observer/observer_environments-info-3.png) - -Each model differing from the `prod` environment may be expanded to view the text diff between the two. The models are listed separately based on whether the plan directly or indirectly modified them, and breaking changes are indicated with an orange "Breaking" label: - -![SQLMesh Observer environment information: model text diff](./observer/observer_environments-info-prod-diff.png) - -### Plan Applications - -Access the `Plan Applications` landing page via the navigation links in the dashboard's top left. 
It displays a table listing each SQLMesh project plan that has been applied and includes the following information about each: - -- Plan ID -- Previous plan ID (most recent plan executed prior) -- Environment to which the plan was applied (with link to environment information page) -- A count of models in the plan (with link to the plan's models) -- Whether the plan included model restatements -- Whether the plan was in forward-only mode -- The start and end dates of the time interval covered by the plan -- The start and end times of the plan application - -![SQLMesh Observer plans list](./observer/observer_plans-list.png) - -Clicking a Plan ID opens its information page, which lists the information included in the landing page table and links to models added or modified by the plan: - -![SQLMesh Observer plan information page](./observer/observer_plans-information.png) - -Modified models can be expanded to display a text diff of the change: - -![SQLMesh Observer plan text diff](./observer/observer_plans-text-diff.png) - -### Models - -A model can change over time, so its information is associated with a specific SQLMesh environment and plan. Access a model's page via links in a plan or environment page. 
- -The model information page begins with historical charts of model run time, audit failures, and evaluation failures: - -![SQLMesh Observer model charts](./observer/observer_model-information-1.png) - -It continues with details about the model, including its metadata (e.g., model dialect and kind), model text, and list of previous model versions and text diffs: - -![SQLMesh Observer model details](./observer/observer_model-information-2.png) - -Next, the Loaded Intervals section displays the time intervals that have been loaded and are currently present in the model's physical table, and the Recent Model Evaluations section lists the time interval each evaluation processed and the evaluation's start and end times: - -![SQLMesh Observer model time intervals](./observer/observer_model-information-3.png) - -The model information page concludes with a list of most frequent audits the model has failed, the most frequent time intervals that failed, and the largest historical model run times: - -![SQLMesh Observer historical outliers](./observer/observer_model-information-4.png) - -## Custom measures - -SQLMesh Observer allows you to calculate and track custom measures in addition to the ones it [automatically calculates](#data). - -### Definition - -Each custom measure is associated with a model and is defined by a SQL query in the model file. - -The `@measure` macro is used to define custom measures. The body of the `@measure` macro is the query, and each column in the query defines a separate measure. - -A measure's name is the name of the column that defined it. Measure names must be unique within a model, but a name may be used in multiple models. - -A model may contain more than one `@measure` macro specification. The `@measure` macros must be specified after the model's primary query. They will be executed during a SQLMesh `plan` or `run` after the primary model query is executed. 
- -This example shows a model definition that includes a measure query defining two measures: `row_count` (the total number of rows in the table) and `num_col_avg` (the average value of the model's `numeric_col` column). - -```sql -MODEL ( - name custom_measure.example, - kind FULL -); - -SELECT - numeric_col -FROM - custom_measure.upstream; - -@measure( -- Measure query specified in the `@measure` macro - SELECT - COUNT(*) AS row_count, -- Table's row count - AVG(numeric_col) AS num_col_avg -- Average value of `numeric_col` - FROM custom_measure.example -- Select FROM the name of the model -); -``` - -Every time the `custom_measure.example` model is executed, Observer will execute the measure query and store the value it returns. - -By default, the measure's timestamp will be the execution time of the `plan`/`run` that captured the measure. [Incremental by time range](../concepts/models/model_kinds.md#incremental_by_time_range) models may specify [custom timestamps](#custom-time-column). - -An Observer chart allows you to select which measure to display. The chart displays the value of the selected measure on the y-axis and the execution time of the associated `plan`/`run` on the x-axis, allowing you to monitor whether the value has meaningfully changed since the previous execution. - -### Incremental by time models - -#### Custom time column - -In the previous example, Observer automatically associated each measure value with the execution time of the `plan` or `run` that executed it. - -For [incremental by time range models](../concepts/models/model_kinds.md#incremental_by_time_range), you can customize how measures are associated with time by including your own time column in the measure query. - -The time column must be named `ts` and may be of any datetime data type (e.g., date string, `DATE`, `TIMESTAMP`, etc.). Custom times are typically derived from a datetime column in the model data and are most useful when the measure groups by the datetime. 
- -For example, this incremental model stores the date of each data point in the `event_datestring` column. We could measure each day's row count and numeric column average with this measure query: - -```sql -MODEL ( - name custom_measure.incremental_example - kind INCREMENTAL_BY_TIME_RANGE ( - time_column event_datestring - ) -); - -SELECT - event_datestring, - numeric_col -FROM - custom_measure.upstream -WHERE - event_datestring BETWEEN @start_ds AND @end_ds; - -@measure( - SELECT - event_datestring AS ts, -- Custom measure time column `ts` - COUNT(*) AS daily_row_count, -- Daily row count - AVG(numeric_col) AS daily_num_col_avg -- Daily average value of `numeric_col` - FROM custom_measure.incremental_example - WHERE event_datestring BETWEEN @start_ds AND @end_ds -- Filter measure on time - GROUP BY event_datestring -- Group measure by time -); -``` - -The measure query both filters and groups the data based on the model's time column `event_datestring`. The filtering and grouping ensures that only one measure value is ever calculated for a specific day of data. - -NOTE: the custom time column approach will not work correctly if the model's [`lookback` argument](../concepts/models/overview.md#lookback) is specified because a given day's data will be processed every time it is in the lookback window. - -#### Execution and custom times - -A model may contain multiple measure queries, so both execution time and custom time measures may be specified for the same model. - -These two measure types help answer different questions: - -1. Execution time: has something meaningfully changed **on this `plan`/`run`** compared to previous plans/runs? -2. Custom time: has something meaningfully changed **in a specific time point's data** compared to other time points? - -If multiple time points of data are processed during each model execution, an anomaly at a specific time may not be detectable from an execution time measure alone. 
- -Custom time measures enable monitoring at the temporal granularity of the data itself. diff --git a/docs/guides/observer/observer_chart-hover.png b/docs/guides/observer/observer_chart-hover.png deleted file mode 100644 index a06bc605e0..0000000000 Binary files a/docs/guides/observer/observer_chart-hover.png and /dev/null differ diff --git a/docs/guides/observer/observer_chart-scale-selector.png b/docs/guides/observer/observer_chart-scale-selector.png deleted file mode 100644 index fd603301e5..0000000000 Binary files a/docs/guides/observer/observer_chart-scale-selector.png and /dev/null differ diff --git a/docs/guides/observer/observer_chart-time-selector.png b/docs/guides/observer/observer_chart-time-selector.png deleted file mode 100644 index f0baf27adb..0000000000 Binary files a/docs/guides/observer/observer_chart-time-selector.png and /dev/null differ diff --git a/docs/guides/observer/observer_cli.png b/docs/guides/observer/observer_cli.png deleted file mode 100644 index c1237d098b..0000000000 Binary files a/docs/guides/observer/observer_cli.png and /dev/null differ diff --git a/docs/guides/observer/observer_dashboard-components.png b/docs/guides/observer/observer_dashboard-components.png deleted file mode 100644 index ebc9479808..0000000000 Binary files a/docs/guides/observer/observer_dashboard-components.png and /dev/null differ diff --git a/docs/guides/observer/observer_dashboard.png b/docs/guides/observer/observer_dashboard.png deleted file mode 100644 index 9980e73570..0000000000 Binary files a/docs/guides/observer/observer_dashboard.png and /dev/null differ diff --git a/docs/guides/observer/observer_environments-info-1.png b/docs/guides/observer/observer_environments-info-1.png deleted file mode 100644 index 16d914d689..0000000000 Binary files a/docs/guides/observer/observer_environments-info-1.png and /dev/null differ diff --git a/docs/guides/observer/observer_environments-info-2.png b/docs/guides/observer/observer_environments-info-2.png deleted file 
mode 100644 index 63bba899f2..0000000000 Binary files a/docs/guides/observer/observer_environments-info-2.png and /dev/null differ diff --git a/docs/guides/observer/observer_environments-info-3.png b/docs/guides/observer/observer_environments-info-3.png deleted file mode 100644 index 1cd5c8918c..0000000000 Binary files a/docs/guides/observer/observer_environments-info-3.png and /dev/null differ diff --git a/docs/guides/observer/observer_environments-info-prod-diff.png b/docs/guides/observer/observer_environments-info-prod-diff.png deleted file mode 100644 index e7807be778..0000000000 Binary files a/docs/guides/observer/observer_environments-info-prod-diff.png and /dev/null differ diff --git a/docs/guides/observer/observer_environments-landing.png b/docs/guides/observer/observer_environments-landing.png deleted file mode 100644 index d0315b8a57..0000000000 Binary files a/docs/guides/observer/observer_environments-landing.png and /dev/null differ diff --git a/docs/guides/observer/observer_key-file.png b/docs/guides/observer/observer_key-file.png deleted file mode 100644 index 541f1ede0d..0000000000 Binary files a/docs/guides/observer/observer_key-file.png and /dev/null differ diff --git a/docs/guides/observer/observer_model-information-1.png b/docs/guides/observer/observer_model-information-1.png deleted file mode 100644 index 731233e256..0000000000 Binary files a/docs/guides/observer/observer_model-information-1.png and /dev/null differ diff --git a/docs/guides/observer/observer_model-information-2.png b/docs/guides/observer/observer_model-information-2.png deleted file mode 100644 index 1b77b7c323..0000000000 Binary files a/docs/guides/observer/observer_model-information-2.png and /dev/null differ diff --git a/docs/guides/observer/observer_model-information-3.png b/docs/guides/observer/observer_model-information-3.png deleted file mode 100644 index 6e4a45199b..0000000000 Binary files a/docs/guides/observer/observer_model-information-3.png and /dev/null differ diff 
--git a/docs/guides/observer/observer_model-information-4.png b/docs/guides/observer/observer_model-information-4.png deleted file mode 100644 index ffe492c19c..0000000000 Binary files a/docs/guides/observer/observer_model-information-4.png and /dev/null differ diff --git a/docs/guides/observer/observer_plans-information.png b/docs/guides/observer/observer_plans-information.png deleted file mode 100644 index ddc341ce68..0000000000 Binary files a/docs/guides/observer/observer_plans-information.png and /dev/null differ diff --git a/docs/guides/observer/observer_plans-list.png b/docs/guides/observer/observer_plans-list.png deleted file mode 100644 index cbded1fe44..0000000000 Binary files a/docs/guides/observer/observer_plans-list.png and /dev/null differ diff --git a/docs/guides/observer/observer_plans-text-diff.png b/docs/guides/observer/observer_plans-text-diff.png deleted file mode 100644 index 096c40135f..0000000000 Binary files a/docs/guides/observer/observer_plans-text-diff.png and /dev/null differ diff --git a/docs/guides/observer/observer_table-chart-toggle.png b/docs/guides/observer/observer_table-chart-toggle.png deleted file mode 100644 index f4af75681a..0000000000 Binary files a/docs/guides/observer/observer_table-chart-toggle.png and /dev/null differ diff --git a/docs/guides/projects.md b/docs/guides/projects.md index 9c78dee3f2..e4dabd76cc 100644 --- a/docs/guides/projects.md +++ b/docs/guides/projects.md @@ -27,25 +27,27 @@ To create a project from the command line, follow these steps: 1. To scaffold a project, it is recommended that you use a python virtual environment by running the following commands: ```bash - python -m venv .env + python -m venv .venv ``` ```bash - source .env/bin/activate + source .venv/bin/activate ``` ```bash pip install sqlmesh ``` - **Note:** When using a python virtual environment, you must ensure that it is activated first. 
You should see `(.env)` in your command line; if you don't, run `source .env/bin/activate` from your project directory to activate your environment. + **Note:** When using a python virtual environment, you must ensure that it is activated first. You should see `(.venv)` in your command line; if you don't, run `source .venv/bin/activate` from your project directory to activate your environment. 1. Once you have activated your environment, run the following command and SQLMesh will build out your project: ```bash - sqlmesh init + sqlmesh init [SQL_DIALECT] ``` + In the command above, you can use any [SQL dialect supported by sqlglot](https://sqlglot.com/sqlglot/dialects.html), for example "duckdb". + The following directories and files will be created that you can use to organize your SQLMesh project: - config.py (database configuration file) diff --git a/docs/guides/scheduling.md b/docs/guides/scheduling.md index ebc707e8a0..80d58db366 100644 --- a/docs/guides/scheduling.md +++ b/docs/guides/scheduling.md @@ -2,8 +2,8 @@ SQLMesh currently offers two ways of scheduling model evaluation: -* Using the [built-in scheduler](#built-in-scheduler) -* By [integrating with Airflow](#integrating-with-airflow) +* Using [SQLMesh's built-in scheduler](#built-in-scheduler) +* Using [Tobiko Cloud](../cloud/features/scheduler/scheduler.md) ## Built-in scheduler @@ -29,85 +29,3 @@ sqlmesh_example.example_incremental_model ━━━━━━━━━━━━ ``` **Note:** The `sqlmesh run` command performs model evaluation based on the missing data intervals identified at the time of running. It does not run continuously, and will exit once evaluation is complete. You must run this command periodically with a cron job, a CI/CD tool like Jenkins, or in a similar fashion. - - -## Integrating with Airflow - -### Configuring the Airflow cluster - -SQLMesh natively integrates with the popular open source workflow orchestrator [Apache Airflow](https://airflow.apache.org/), both self-hosted and managed (e.g. 
Google Cloud Composer, Amazon MWAA, Astronomer). - -To integrate with [Airflow](../integrations/airflow.md), ensure that you meet the [prerequisites](/prerequisites), then perform the following: - -1. Install the SQLMesh Python package on all nodes of the Airflow cluster using the following command: - - pip install sqlmesh - - **Note:** The Airflow webserver must be restarted after installation. - -2. Within the Airflow `dags/` folder, create a file called `sqlmesh.py`. - -3. Within that file add the following, making sure to replace "spark" with your engine and `spark_catalog` with your default catalog: - - from sqlmesh.schedulers.airflow.integration import SQLMeshAirflow - - sqlmesh_airflow = SQLMeshAirflow("spark", default_catalog="spark_catalog") - - for dag in sqlmesh_airflow.dags: - globals()[dag.dag_id] = dag - - The example above uses `spark` as the engine of choice. Other engines can be configured instead by providing a corresponding string as an argument to the `SQLMeshAirflow` constructor. Supported strings are `"spark"`, `"databricks"`, `"snowflake"`, `"bigquery"`, `"redshift"`, `"trino"`, `"mssql"` and `"mysql"`. See the [Airflow Cluster Configuration](../integrations/airflow.md#airflow-cluster-configuration) for full list of arguments and their descriptions. 
- -After setup is completed, the `sqlmesh_janitor_dag` DAG should become available in the Airflow UI when filtered by the `sqlmesh` tag: - -![Airflow UI after successful setup](scheduling/airflow_successful_setup.png) - -### Configuring the client - -On the client side, you must configure the connection to your Airflow cluster in the `config.yaml` file as follows: - - default_scheduler: - type: airflow - airflow_url: http://localhost:8080/ - username: airflow - password: airflow - -Alternatively, the configuration above can be generated automatically as part of the project initialization using the `airflow` template: -```bash -sqlmesh init [PROJECT SQL DIALECT] -t airflow -``` - -For Airflow configuration types specific to Google Cloud Composer, configure the file as follows: - - default_scheduler: - type: cloud_composer - airflow_url: https:/XXXXXXXX.composer.googleusercontent.com/ - -**Note:** Guidelines for integrating with managed offerings other than Google Cloud Composer will be added later. - -### Running the `plan` command - -Run the `sqlmesh plan` command to apply all changes on the target Airflow cluster. - -Below is example output from running the `sqlmesh plan` command in the example project generated by the `sqlmesh init` command: -```bash -$ sqlmesh plan -====================================================================== -Successfully Ran 1 tests against duckdb ----------------------------------------------------------------------- -Summary of differences against `prod`: -└── Added Models: - ├── sqlmesh_example.example_incremental_model - └── sqlmesh_example.example_full_model -Models needing backfill (missing dates): -├── sqlmesh_example.example_incremental_model: (2020-01-01, 2023-02-13) -└── sqlmesh_example.example_full_model: (2023-02-13, 2023-02-13) -Enter the backfill start date (eg. 
'1 year', '2020-01-01') or blank for the beginning of history: 2023-02-13 -Apply - Backfill Tables [y/n]: y -Waiting for the plan application DAG 'sqlmesh_plan_application__prod__fb88a0c6_16f9_4a3e_93ec_7f8026bc878c' to be provisioned on Airflow -Track plan application progress using link -``` - -Once the command runs, the following DAGs will become available within the Airflow UI: - -![Airflow UI after successful plan application](scheduling/airflow_successful_plan_apply.png) diff --git a/docs/guides/signals.md b/docs/guides/signals.md new file mode 100644 index 0000000000..4c678d729b --- /dev/null +++ b/docs/guides/signals.md @@ -0,0 +1,153 @@ +# Signals guide + +SQLMesh's [built-in scheduler](./scheduling.md#built-in-scheduler) controls which models are evaluated when the `sqlmesh run` command is executed. + +It determines whether to evaluate a model based on whether the model's [`cron`](../concepts/models/overview.md#cron) has elapsed since the previous evaluation. For example, if a model's `cron` was `@daily`, the scheduler would evaluate the model if its last evaluation occurred on any day before today. + +Unfortunately, the world does not always accommodate our data system's schedules. Data may land in our system _after_ downstream daily models already ran. The scheduler did its job correctly, but today's late data will not be processed until tomorrow's scheduled run. + +You can use signals to prevent this problem. + +## What is a signal? + +The scheduler uses two criteria to determine whether a model should be evaluated: whether its `cron` elapsed since the last evaluation and whether its upstream dependencies' runs have completed. + +Signals allow you to specify additional criteria that must be met before the scheduler evaluates the model. + +A signal definition is simply a function that checks whether a criterion is met. Before describing the checking function, we provide some background information about how the scheduler works. 
+ +The scheduler doesn't actually evaluate "a model" - it evaluates a model over a specific time interval. This is clearest for incremental models, where only rows in the time interval are ingested during an evaluation. However, evaluation of non-temporal model kinds like `FULL` and `VIEW` is also based on a time interval: the model's `cron` frequency. + +The scheduler's decisions are based on these time intervals. For each model, the scheduler examines a set of candidate intervals and identifies the ones that are ready for evaluation. + +It then divides those into _batches_ (configured with the model's [batch_size](../concepts/models/overview.md#batch_size) parameter). For incremental models, it evaluates the model once for each batch. For non-incremental models, it evaluates the model once if any batch contains an interval. + +Signal checking functions examine a batch of time intervals. The function is always called with a batch of time intervals (DateTimeRanges). It can also optionally be called with keyword arguments. It may return `True` if all intervals are ready for evaluation, `False` if no intervals are ready, or the time intervals themselves if only some are ready. A checking function is defined with the `@signal` decorator. + +!!! note "One model, multiple signals" + + Multiple signals may be specified for a model. SQLMesh categorizes a candidate interval as ready for evaluation if **all** the signal checking functions determine it is ready. + +## Defining a signal + +To define a signal, create a `signals` directory in your project folder. Define your signal in a file named `__init__.py` in that directory (you can add additional Python files as well). + +A signal is a function that accepts a batch (`DateTimeRanges: t.List[t.Tuple[datetime, datetime]]`) and returns a batch or a boolean. It needs to use the `@signal` decorator. + +We now demonstrate signals of varying complexity. 
+ +### Simple example + +This example defines a `random_signal()` function. + +The function returns `True` (indicating that all intervals are ready for evaluation) if a random number is greater than a threshold specified in the model definition: + +```python linenums="1" +import random +import typing as t +from sqlmesh import signal, DatetimeRanges + + +@signal() +def random_signal(batch: DatetimeRanges, threshold: float) -> t.Union[bool, DatetimeRanges]: + return random.random() > threshold +``` + +Note that the `random_signal()` takes a mandatory user-defined `threshold` argument. + +The `random_signal()` function extracts the threshold metadata and compares a random number to it. The type is inferred based on the same [rules as SQLMesh Macros](../concepts/macros/sqlmesh_macros.md#typed-macros). + +Now that we have a working signal, we need to specify that a model should use the signal by passing metadata to the model DDL's `signals` key. + +The `signals` key accepts an array delimited by brackets `[]`. Each function in the list should contain the metadata needed for one signal evaluation. + +This example specifies that the `random_signal()` should evaluate once with a threshold of 0.5: + +```sql linenums="1" hl_lines="4-6" +MODEL ( + name example.signal_model, + kind FULL, + signals ( + random_signal(threshold := 0.5), # specify threshold value + ) +); + +SELECT 1 +``` + +The next time `sqlmesh run` is executed for this project, our signal will metaphorically flip a coin to determine whether the model should be evaluated. 
+ +### Advanced Example + +This example demonstrates more advanced use of signals: a signal returning a subset of intervals from a batch (rather than a single `True`/`False` value for all intervals in the batch) + +```python +import typing as t + +from sqlmesh import signal, DatetimeRanges +from sqlmesh.utils.date import to_datetime + + +# signal that returns only intervals that are <= 1 week ago +@signal() +def one_week_ago(batch: DatetimeRanges) -> t.Union[bool, DatetimeRanges]: + dt = to_datetime("1 week ago") + + return [ + (start, end) + for start, end in batch + if start <= dt + ] +``` + +Instead of returning a single `True`/`False` value for whether a batch of intervals is ready for evaluation, the `one_week_ago()` function returns specific intervals from the batch. + +It generates a datetime argument, to which it compares the beginning of each interval in the batch. If the interval start is before that argument, the interval is ready for evaluation and included in the returned list. +These signals can be added to a model like so. + +```sql linenums="1" hl_lines="7-10" +MODEL ( + name example.signal_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + ), + start '2 week ago', + signals ( + one_week_ago(), + ) +); + + +SELECT @start_ds AS ds +``` + +### Accessing execution context / engine adapter +It is possible to access the execution context in a signal and access the engine adapter (warehouse connection). + +```python +import typing as t + +from sqlmesh import signal, DatetimeRanges, ExecutionContext + + +# add the context argument to your function +@signal() +def one_week_ago(batch: DatetimeRanges, context: ExecutionContext) -> t.Union[bool, DatetimeRanges]: + return len(context.engine_adapter.fetchdf("SELECT 1")) > 1 +``` + +### Testing Signals +Signals only evaluate on `run` or with `check_intervals`. + +To test signals with the [check_intervals](../reference/cli.md#check_intervals) command: + +1. 
Deploy your changes to an environment with `sqlmesh plan my_dev`. +2. Run `sqlmesh check_intervals my_dev`. + + * To check a subset of models use the --select-model flag. + * To turn off signals and just check missing intervals, use the --no-signals flag. + +3. To iterate, make changes to the signal, and redeploy with step 1. + +!!! note + `check_intervals` only works on remote models in an environment. Local signal changes are never run. diff --git a/docs/guides/table_migration.md b/docs/guides/table_migration.md index cb57abd359..351a704ac3 100644 --- a/docs/guides/table_migration.md +++ b/docs/guides/table_migration.md @@ -129,9 +129,9 @@ Consider an existing table named `my_schema.existing_table`. Migrating this tabl b. Specify the start of the first time interval SQLMesh should track in the `MODEL` DDL `start` key (example uses "2024-01-01") - c. Create the model in the SQLMesh project without backfilling any data by running `sqlmesh plan [environment name] --skip-backfill --start 2024-01-01`, replacing "[environment name]" with an environment name other than `prod` and using the same start date from the `MODEL` DDL in step 3b. + c. Create the model in the SQLMesh project without backfilling any data by running `sqlmesh plan [environment name] --empty-backfill --start 2024-01-01`, replacing "[environment name]" with an environment name other than `prod` and using the same start date from the `MODEL` DDL in step 3b. -4. Determine the name of the model's snapshot physical table by running `sqlmesh table_name my_schema.existing_table`. For example, it might return `sqlmesh__my_schema.existing_table_123456`. +4. Determine the name of the model's snapshot physical table by running `sqlmesh table_name --env [environment name] --prod my_schema.existing_table`. For example, it might return `sqlmesh__my_schema.existing_table_123456`. 5. 
Rename the original table `my_schema.existing_table_temp` to `sqlmesh__my_schema.existing_table_123456` The model would have code similar to: diff --git a/docs/guides/tablediff.md b/docs/guides/tablediff.md index 0300d93c9f..6d649b3e93 100644 --- a/docs/guides/tablediff.md +++ b/docs/guides/tablediff.md @@ -118,6 +118,62 @@ Grain should have unique and not-null audits for accurate results. ``` +Under the hood, SQLMesh stores temporary data in the database to perform the comparison. +The default schema for these temporary tables is `sqlmesh_temp` but can be changed with the `--temp-schema` option. +The schema can be specified as a `CATALOG.SCHEMA` or `SCHEMA`. + + +## Diffing multiple models across environments + +SQLMesh allows you to compare multiple models across environments at once using model selection expressions. This is useful when you want to validate changes across a set of related models or the entire project. + +To diff multiple models, use the `--select-model` (or `-m` for short) option with the table diff command: + +```bash +sqlmesh table_diff prod:dev --select-model "sqlmesh_example.*" +``` + +When diffing multiple models, SQLMesh will: + +1. Show the models returned by the selector that exist in both environments and have differences +2. Compare these models and display the data diff of each model + +> Note: Models will only be data diffed if there's a breaking change that impacts them. + +The `--select-model` option supports a powerful selection syntax that lets you choose models using patterns, tags, dependencies and git status. For complete details, see the [model selection guide](./model_selection.md). + +> Note: Surround your selection pattern in single or double quotes. 
Ex: `'*'`, `"sqlmesh_example.*"` + +Here are some common examples: + +```bash +# Select all models in a schema +sqlmesh table_diff prod:dev -m "sqlmesh_example.*" + +# Select a model and its dependencies +sqlmesh table_diff prod:dev -m "+model_name" # include upstream deps +sqlmesh table_diff prod:dev -m "model_name+" # include downstream deps + +# Select models by tag +sqlmesh table_diff prod:dev -m "tag:finance" + +# Select models with git changes +sqlmesh table_diff prod:dev -m "git:feature" + +# Use logical operators for complex selections +sqlmesh table_diff prod:dev -m "(metrics.* & ^tag:deprecated)" # models in the metrics schema that aren't deprecated + +# Combine multiple selectors +sqlmesh table_diff prod:dev -m "tag:finance" -m "metrics.*_daily" +``` + +When multiple selectors are provided, they are combined with OR logic, meaning a model matching any of the selectors will be included. + +!!! note + All models being compared must have a `grain` defined that is unique and not null, as this is used to perform the join between the tables in the two environments. + + If the `--warn-grain-check` option is used, this requirement is not enforced. Instead of raising an error, a warning is displayed for the models without a defined grain and diffs are computed for the remaining models. + ## Diffing tables or views Compare specific tables or views with the SQLMesh CLI interface by using the command `sqlmesh table_diff [source table]:[target table]`. @@ -153,3 +209,24 @@ SQLMESH_EXAMPLE.INCREMENTAL_MODEL ONLY sample rows: ``` The output matches, with the exception of the column labels in the `COMMON ROWS sample data differences`. The underlying table for each column is indicated by `s__` for "source" table (first table in the command's colon operator `:`) and `t__` for "target" table (second table in the command's colon operator `:`). + +## Diffing tables or views across gateways + +!!! 
info "Tobiko Cloud Feature" + + Cross-database table diffing is available in [Tobiko Cloud](../cloud/features/xdb_diffing.md). + +SQLMesh executes a project's models with a single database system, specified as a [gateway](../guides/connections.md#overview) in the project configuration. + +The within-database table diff tool described above compares tables or environments within such a system. Sometimes, however, you might want to compare tables that reside in two different data systems. + +For example, you might migrate your data transformations from an on-premises SQL engine to a cloud SQL engine while setting up your SQLMesh project. To demonstrate equivalence between the systems you could run the transformations in both and compare the new tables to the old tables. + +The [within-database table diff](#diffing-models-across-environments) tool cannot make those comparisons, for two reasons: + +1. It must join the two tables being diffed, but with two systems no single database engine can access both tables. +2. It assumes that data values can be compared across tables without modification. If the systems use different SQL engines, however, the diff must account for differences in the engines' data types (e.g., whether timestamps should include time zone information). + +SQLMesh's cross-database table diff tool is built for just this scenario. Its comparison algorithm efficiently diffs tables without moving them from one system to the other and automatically addresses differences in data types. + +Learn more about cross-database table diffing in our [Tobiko Cloud docs](../cloud/features/xdb_diffing.md). diff --git a/docs/guides/ui.md b/docs/guides/ui.md index bf5bfe3c8e..29bb204988 100644 --- a/docs/guides/ui.md +++ b/docs/guides/ui.md @@ -1,5 +1,10 @@ # Browser UI guide +!!! warning + + Browser UI is deprecated. Please use the [VSCode extension](vscode.md) instead. 
+ + SQLMesh's free, open-source browser user interface (UI) makes it easy to understand, explore, and modify your SQLMesh project. This page describes the UI's components and how they work. @@ -24,7 +29,7 @@ For development work, we recommend using the SQLMesh UI alongside an IDE. The UI Before beginning, ensure that you meet all the [prerequisites](../prerequisites.md) for using SQLMesh. The SQLMesh browser UI requires additional Python libraries not included in the base SQLMesh installation. -To use the UI, install SQLMesh with the `web` add-on. First, if using a python virtual environment, ensure it's activated by running `source .env/bin/activate` command from the folder used during [installation](../installation.md). +To use the UI, install SQLMesh with the `web` add-on. First, if using a python virtual environment, ensure it's activated by running `source .venv/bin/activate` command from the folder used during [installation](../installation.md). Next, install the UI with `pip`: @@ -42,11 +47,11 @@ sqlmesh ui After starting up, the SQLMesh web UI is served at `http://127.0.0.1:8000` by default: -![SQLMesh web UI startup on CLI](./ui/ui-quickstart_cli.png) +![SQLMesh web UI startup on CLI](./ui/ui-quickstart_cli.png){ loading=lazy } Navigate to the URL by clicking the link in your terminal (if supported) or copy-pasting it into your web browser: -![SQLMesh web UI startup in browser](./ui/ui-quickstart_ui-startup.png) +![SQLMesh web UI startup in browser](./ui/ui-quickstart_ui-startup.png){ loading=lazy } ## Modules @@ -56,7 +61,7 @@ The UI modules are: - [Code editor](#editor-module) - [Plan builder](#plan-module) -- [Project documentation](#docs-module) +- [Data catalog](#data-catalog-module) - [Table and column lineage](#lineage-module) The screenshots in most examples below use the default `editor` mode. @@ -71,11 +76,11 @@ The `editor` module will appear by default if the UI is started without specifyi 4. 
Inspector provides settings and information based on recent actions and the currently active pane. (Note: inspector pane is collapsed by default. Expand it by clicking the hamburger button at the top of the collapsed pane - see previous image.) 5. Details displays column-level lineage for models open in the editor and results of queries. (Note: details pane is collapsed by default. It will automatically expand upon opening a model in the editor or running a query.) -![SQLMesh browser UI panes](./ui/ui-quickstart_ui-startup-panes.png) +![SQLMesh browser UI panes](./ui/ui-quickstart_ui-startup-panes.png){ loading=lazy } It also contains nine buttons: -1. Toggle Editor/Docs/Errors/Plan toggles among the editor module (default), docs module, errors view, and plan module. Errors view is only available if an error has occurred. +1. Toggle Editor/Data Catalog/Errors/Plan toggles among the editor module (default), data catalog module, errors view, and plan module. Errors view is only available if an error has occurred. 2. History navigation returns to previous views, similar to the back button in a web browser. 3. Add new tab opens a new code editor window. 4. Plan opens the plan module. @@ -85,7 +90,7 @@ It also contains nine buttons: 8. Format SQL query reformats a SQL query using SQLGlot's pretty layout. 9. Change SQL dialect specifies the SQL dialect of the current tab for custom SQL queries. It does not affect the SQL dialect for the project. -![SQLMesh browser UI buttons](./ui/ui-guide_ui-startup-buttons.png) +![SQLMesh browser UI buttons](./ui/ui-guide_ui-startup-buttons.png){ loading=lazy } And it contains four status indicators: @@ -94,7 +99,7 @@ And it contains four status indicators: 3. Change indicator displays a summary of the changes in the project files relative to the most recently run SQLMesh plan in the selected environment. 4. Error indicator displays the count of errors in the project. 
-![SQLMesh web UI status indicators](./ui/ui-quickstart_ui-startup-status.png) +![SQLMesh web UI status indicators](./ui/ui-quickstart_ui-startup-status.png){ loading=lazy } #### Edit models @@ -102,11 +107,11 @@ Open a model in a new tab by clicking its file name in the left-hand project dir The tab will show the model definition, and the details pane at the bottom will display the model in the project's table and column lineage. -![Incremental model open in editor](./ui/ui-quickstart_incremental-model.png) +![Incremental model open in editor](./ui/ui-quickstart_incremental-model.png){ loading=lazy } The lineage display will update as model modifications are saved. For example, you might modify the incremental SQL model by adding a new column to the query. Press `Cmd + S` (`Ctrl + S` on Windows) to save the modified model file and display the updated lineage: -![Incremental model modified in editor](./ui/ui-quickstart_incremental-model-modified.png) +![Incremental model modified in editor](./ui/ui-quickstart_incremental-model-modified.png){ loading=lazy } The `Changes` indicator in the top right now shows blue and orange circles that reflect our model update. @@ -116,11 +121,11 @@ Run SQL queries by executing them from custom SQL editor tabs. For example, we might add a SQL query `select * from sqlmesh_example.incremental_model` to the Custom SQL 1 tab. 
To run the query, first click the hamburger icon to open the explorer pane: -![Querying `dev` incremental model with SQL query in editor](./ui/ui-guide_fetchdf-prod-query.png) +![Querying `dev` incremental model with SQL query in editor](./ui/ui-guide_fetchdf-prod-query.png){ loading=lazy } Then click the `Run Query` button in the bottom right to execute the query: -![Results from querying dev incremental model with SQL query in editor](./ui/ui-guide_fetchdf-prod.png) +![Results from querying dev incremental model with SQL query in editor](./ui/ui-guide_fetchdf-prod.png){ loading=lazy } The results appear in an interactive table in the details pane below the editor. @@ -143,7 +148,7 @@ When you open the plan module, it contains multiple pieces of information about - The `Changes` section shows that SQLMesh detected three models added relative to the current empty environment. - The `Backfills` section shows that backfills will occur for all three of the added models. -![Plan module - new project](./ui/ui-quickstart_run-plan.png) +![Plan module - new project](./ui/ui-quickstart_run-plan.png){ loading=lazy } SQLMesh will apply the plan and initiate backfill when you click the blue button labeled `Apply Changes And Backfill`. @@ -155,7 +160,7 @@ The `Snapshot Tables Created` indicates that [snapshots](../concepts/architectur The `Backfilled` section shows progress indicators for the backfill operations. The first progress indicator shows the total number of tasks and completion percentage for the entire backfill operation. The remaining progress bars show completion percentage and run time for each model (very fast in this simple example). 
-![Plan module - plan applied](./ui/ui-quickstart_apply-plan.png) +![Plan module - plan applied](./ui/ui-quickstart_apply-plan.png){ loading=lazy } #### New environment @@ -163,21 +168,21 @@ To create a new environment, open the environment menu by clicking the drop-down To create an environment named "dev," type `dev` into the Environment field and click the blue `Add` button. -![Open environment menu](./ui/ui-quickstart_create-dev.png) +![Open environment menu](./ui/ui-quickstart_create-dev.png){ loading=lazy } The drop-down now shows that the SQLMesh UI is working in the `dev` environment: -![Working in dev environment](./ui/ui-quickstart_plan-dev.png) +![Working in dev environment](./ui/ui-quickstart_plan-dev.png){ loading=lazy } To populate the environment with views of the production environment, click the green `Plan` button to open the plan module: -![Run plan on dev pane](./ui/ui-quickstart_run-plan-dev.png) +![Run plan on dev pane](./ui/ui-quickstart_run-plan-dev.png){ loading=lazy } The output section does not list any added/modified models or backfills because `dev` is being created from the existing `prod` environment without modification. Clicking the blue `Apply Virtual Update` button applies the new plan: -![Run plan on dev pane output](./ui/ui-quickstart_run-plan-dev-output.png) +![Run plan on dev pane output](./ui/ui-quickstart_run-plan-dev-output.png){ loading=lazy } #### Existing environment @@ -186,27 +191,27 @@ If you modify the project files, you will want to apply the changes to an existi The plan module will summarize the changes when you open it: -![Plan pane after opening plan module with modified model](./ui/ui-quickstart_run-plan-dev-modified.png) +![Plan pane after opening plan module with modified model](./ui/ui-quickstart_run-plan-dev-modified.png){ loading=lazy } The `Changes` section detects that `incremental_model` was directly modified and that `full_model` was indirectly modified because it selects from the incremental model. 
Click the blue `Apply Changes And Backfill` button to apply the plan and execute the backfill: -![Plan after applying updated plan with modified model](./ui/ui-quickstart_apply-plan-dev-modified.png) +![Plan after applying updated plan with modified model](./ui/ui-quickstart_apply-plan-dev-modified.png){ loading=lazy } -### Docs module +### Data Catalog module -The docs module displays information about all your project's models in one interface. +The data catalog module displays information about all your project's models in one interface. A list of all models is displayed in the left-hand pane. You can filter models by name by typing in the field at the top of the pane. When you choose a model, its query, lineage, and attributes are displayed. This example shows information from the [quickstart project](../quick_start.md) incremental model: -![Docs with incremental model selected](./ui/ui-guide_docs.png) +![Data Catalog with incremental model selected](./ui/ui-guide_docs.png){ loading=lazy } By default, the model definition source code is displayed. If you toggle to `Compiled Query`, it will display an example of the model query rendered with macro values substituted: -![Docs with compiled incremental model query](./ui/ui-guide_docs-compiled-query.png) +![Data Catalog with compiled incremental model query](./ui/ui-guide_docs-compiled-query.png){ loading=lazy } ### Lineage module @@ -214,15 +219,15 @@ The lineage module displays a graphical representation of the project's table an Click a model in the left-hand pane to view its lineage. By default, only the model's upstream parents and downstream children are displayed: -![Lineage module](./ui/ui-guide_lineage.png) +![Lineage module](./ui/ui-guide_lineage.png){ loading=lazy } You may include all a project's models by clicking `All` in the Show drop-down on the upper right. 
In this example, two additional models appear: -![Lineage module - all models](./ui/ui-guide_lineage-all.png) +![Lineage module - all models](./ui/ui-guide_lineage-all.png){ loading=lazy } Click `Connected` in the Show drop-down menu to highlight edges between upstream parents and downstream children in blue. This may be helpful when when a project contains many models: -![Lineage module - all models, connected edges](./ui/ui-guide_lineage-all-connected.png) +![Lineage module - all models, connected edges](./ui/ui-guide_lineage-all-connected.png){ loading=lazy } ## Modes @@ -232,9 +237,9 @@ You may specify the UI mode as an option when you [start the UI on the command l The UI modes contain these modules: -- `editor`: code editor, plan builder, project documentation, table and column lineage -- `plan`: plan builder, project documentation, table and column lineage -- `docs`: project documentation, table and column lineage +- `editor`: code editor, plan builder, data catalog, table and column lineage +- `plan`: plan builder, data catalog, table and column lineage +- `catalog`: data catalog, table and column lineage ### Working with an IDE @@ -248,31 +253,31 @@ To use this workflow, first open a terminal in VSCode and navigate to your proje 1. Start the browser UI in `plan` mode with the command `sqlmesh ui --mode plan`: -![VSCode - start the UI](./ui/ui-guide_vscode-start-ui.png) +![VSCode - start the UI](./ui/ui-guide_vscode-start-ui.png){ loading=lazy }

2. In VSCode, type the shortcut `cmd+shift+p` to open the search menu: -![VSCode - open search menu](./ui/ui-guide_vscode-open-search.png) +![VSCode - open search menu](./ui/ui-guide_vscode-open-search.png){ loading=lazy }

3. Type `simple browser` into the search menu and click the entry `Simple browser: Show`: -![VSCode - open simple browser](./ui/ui-guide_vscode-start-browser.png) +![VSCode - open simple browser](./ui/ui-guide_vscode-start-browser.png){ loading=lazy }

4. Copy the web address printed by the command output (`http://127.0.0.1:8000` by default), paste it into the menu, and click enter: -![VSCode - navigate to UI](./ui/ui-guide_vscode-browser-url.png) +![VSCode - navigate to UI](./ui/ui-guide_vscode-browser-url.png){ loading=lazy }

5. The UI will now appear in a VSCode tab: -![VSCode - UI in browser tab](./ui/ui-guide_vscode-browser-ui.png) +![VSCode - UI in browser tab](./ui/ui-guide_vscode-browser-ui.png){ loading=lazy }

6. Split the VSCode window to open a code editor alongside the UI. As you update models, the UI plan and lineage interfaces will update to reflect the changes in real time: -![VSCode - update model plan](./ui/ui-guide_vscode-update-plan.png) +![VSCode - update model plan](./ui/ui-guide_vscode-update-plan.png){ loading=lazy } -![VSCode - update model lineage](./ui/ui-guide_vscode-update-lineage.png) +![VSCode - update model lineage](./ui/ui-guide_vscode-update-lineage.png){ loading=lazy } diff --git a/docs/guides/vscode.md b/docs/guides/vscode.md new file mode 100644 index 0000000000..151e630f27 --- /dev/null +++ b/docs/guides/vscode.md @@ -0,0 +1,208 @@ +# Visual Studio Code Extension + +
+ +!!! danger "Preview" + + The SQLMesh Visual Studio Code extension is in preview and undergoing active development. You may encounter bugs or API incompatibilities with the SQLMesh version you are running. + + We encourage you to try the extension and [create Github issues](https://github.com/tobikodata/sqlmesh/issues) for any problems you encounter. + +In this guide, you'll set up the SQLMesh extension in the Visual Studio Code IDE software (which we refer to as "VSCode"). + +We'll show you the capabilities of the extension and how to troubleshoot common issues. + +## Installation + +### VSCode extension + +Install the extension through the official Visual Studio [marketplace website](https://marketplace.visualstudio.com/items?itemName=tobikodata.sqlmesh) or by searching for `SQLMesh` in the VSCode "Extensions" tab. + +Learn more about installing VSCode extensions in the [official documentation](https://code.visualstudio.com/docs/configure/extensions/extension-marketplace#_install-an-extension). + +### Python setup + +While installing the extension is simple, setting up and configuring a Python environment in VSCode is a bit more involved. + +We recommend using a dedicated *Python virtual environment* to install SQLMesh. Visit the [Python documentation](https://docs.python.org/3/library/venv.html) for more information about virtual environments. + +We describe the steps to create and activate a virtual environment below, but additional information is available on the [SQLMesh installation page](../installation.md). + +We first install the SQLMesh library, which is required by the extension. 
+ +Open a terminal instance in your SQLMesh project's directory and issue this command to create a virtual environment in the `.venv` directory: + +```bash +python -m venv .venv +``` + +Next, activate the virtual environment: + +```bash +source .venv/bin/activate +``` + +#### Open-source SQLMesh + +If you are using open-source SQLMesh, install SQLMesh with the `lsp` extra that enables the VSCode extension (learn more about SQLMesh extras [here](../installation.md#install-extras)): + +```bash +pip install 'sqlmesh[lsp]' +``` + +#### Tobiko Cloud + +If you are using Tobiko Cloud, the `tcloud` library will install SQLMesh for you. + +First, follow the [Python setup](#python-setup) steps above to create and activate a Python environment. Next, install `tcloud`: + +```bash +pip install tcloud # always make sure to install the latest version of tcloud +``` + +Finally, add the `lsp` extra to your `tcloud.yml` configuration file, as described [here](../cloud/tcloud_getting_started.md#connect-tobiko-cloud-to-data-warehouse). + +### VSCode Python interpreter + +A Python virtual environment contains its own copy of Python (the "Python interpreter"). + +We need to make sure VSCode is using your virtual environment's interpreter rather than a system-wide or other interpreter that does not have access to the SQLMesh library we just installed. + +Confirm that VSCode is using the correct interpreter by going to the [command palette](https://code.visualstudio.com/docs/getstarted/userinterface#_command-palette) and clicking `Python: Select Interpreter`. Select the Python executable that's in the virtual environment's directory `.venv`. + +![Select interpreter](./vscode/select_interpreter.png) + +Once that's done, validate that everything is working correctly by checking the `sqlmesh` channel in the [output panel](https://code.visualstudio.com/docs/getstarted/userinterface#_output-panel). 
It displays the Python interpreter path and details of your SQLMesh installation: + +![Output panel](./vscode/interpreter_details.png) + +## Features + +SQLMesh's VSCode extension makes it easy to edit and understand your SQLMesh project with these features: + +- Lineage + - Interactive view of model lineage +- Editor + - Auto-completion for model names and SQLMesh keywords + - Model summaries when hovering over model references + - Links to open model files from model references + - Inline SQLMesh linter diagnostics +- VSCode commands + - Format SQLMesh project files + - Sign in/out of Tobiko Cloud (Tobiko Cloud users only) + +### Lineage + +The extension adds a lineage view to SQLMesh models. To view the lineage of a model, go to the `Lineage` tab in the panel: + +![Lineage view](./vscode/lineage.png) + +### Render + +The extension allows you to render a model with the macros resolved. You can invoke it either with the command palette `Render SQLMesh Model` or by clicking the preview button in the top right. + +### Editor + +The SQLMesh VSCode extension includes several features that make editing SQLMesh models easier and quicker: + +**Completion** + +See auto-completion suggestions when writing SQL models, keywords, or model names. + +![Completion](./vscode/autocomplete.png) + +**Go to definition and hover information** + +Hovering over a model name shows a tooltip with the model description. + +In addition to hover information, you can go to a definition of the following objects in a SQL file by either right-clicking and choosing "Go to definition" or by `Command/Control + Click` on the respective reference. This currently works for: + +- Model references in a SQL file like `FROM my_model` +- CTE reference in a SQL file like `WITH my_cte AS (...) ... FROM my_cte` +- Python macros in a SQL file like `SELECT @my_macro(...)` + +**Diagnostics** + +If you have the [SQLMesh linter](../guides/linter.md) enabled, issues are reported directly in your editor. 
This works for both SQLMesh's built-in linter rules and custom linter rules. + +![Diagnostics](./vscode/diagnostics.png) + +**Formatting** + +SQLMesh's model formatting tool is integrated directly into the editor, so it's easy to format models consistently. + +### Commands + +The SQLMesh VSCode extension provides the following commands in the VSCode command palette: + +- `Format SQLMesh project` +- `Sign in to Tobiko Cloud` (Tobiko Cloud users only) +- `Sign out of Tobiko Cloud` (Tobiko Cloud users only) + +## Troubleshooting + +### DuckDB concurrent access + +If your SQLMesh project uses DuckDB to store its state, you will likely encounter problems. + +SQLMesh can create multiple connections to the state database, but DuckDB's local database file does not support concurrent access. + +Because the VSCode extension establishes a long-running process connected to the database, access conflicts are more likely than with standard SQLMesh usage from the CLI. + +Therefore, we do not recommend using DuckDB as a state store with the VSCode extension. + +### Environment variables + +The VSCode extension is based on a [language server](https://en.wikipedia.org/wiki/Language_Server_Protocol) that runs in the background as a separate process. When the VSCode extension starts the background language server, the server inherits environment variables from the environment where you started VSCode. The server does *not* inherit environment variables from your terminal instance in VSCode, so it may not have access to variables you use when calling SQLMesh from the CLI. + +If you have environment variables that are needed by the context and the language server, you can use one of these approaches to pass variables to the language server: + +- Open VSCode from a terminal that has the variables set already. + - If you have `export ENV_VAR=value` in your shell configuration file (e.g. 
`.zshrc` or `.bashrc`), the variables are set whenever a terminal initializes, and the language server will pick them up if VSCode is opened from that terminal.
+- Use environment variables pulled from somewhere else dynamically in your `config.py`, for example by connecting to a secret store
+- By default, a `.env` file in your root project directory will automatically be picked up by the language server through the Python environment that the extension uses. For exact details on how to set the environment variables in the Python environment that the extension uses, see [here](https://code.visualstudio.com/docs/python/environments#_environment-variables)
+
+You can verify that the environment variables are being passed to the language server by printing them in your terminal.
+
+1. `Cmd + Shift + P` (`Ctrl + Shift + P` on Windows) to start the VSCode command bar
+   ![print_env_vars](./vscode/print_env_vars.png)
+2. Select the option: `SQLMesh: Print Environment Variables`
+3. You should see the environment variables printed in the terminal
+   ![terminal_env_vars](./vscode/terminal_env_vars.png)
+
+If you change your setup during development (e.g., add variables to your shell config), you must restart the language server for the changes to take effect. You can do this by running the following command in the terminal:
+
+1. `Cmd + Shift + P` (`Ctrl + Shift + P` on Windows) to start the VSCode command bar
+2. Select the option: `SQLMesh: Restart Servers`
+   ![restart_servers](./vscode/restart_servers.png)
+   ![loaded](./vscode/loaded.png)
+
+   > This loaded message will appear in the lower left corner of the VSCode window.
+
+3. Print the environment variables based on the instructions above to verify the changes have taken effect.
+
+### Python environment issues
+
+The most common problem is the extension not using the correct Python interpreter. 
+
+Follow the [setup process described above](#vscode-python-interpreter) to ensure that the extension is using the correct Python interpreter.
+
+If you have checked the VSCode `sqlmesh` output channel and the extension is still not using the correct Python interpreter, please raise an issue [here](https://github.com/tobikodata/sqlmesh/issues).
+
+### Missing Python dependencies
+
+When installing SQLMesh, some dependencies required by the VSCode extension are not installed unless you specify the `lsp` "extra".
+
+If you are using open-source SQLMesh, install the `lsp` extra by running this command in your terminal:
+
+```bash
+pip install 'sqlmesh[lsp]'
+```
+
+If you are using Tobiko Cloud, make sure `lsp` is included in the list of extras specified in the [`tcloud.yaml` configuration file](../cloud/tcloud_getting_started.md#connect-tobiko-cloud-to-data-warehouse).
+
+### SQLMesh compatibility
+
+While the SQLMesh VSCode extension is in preview, its APIs to the underlying SQLMesh version are not stable, and we do not guarantee compatibility between the extension and the SQLMesh version you are using.
+
+If you encounter a problem, please raise an issue [here](https://github.com/tobikodata/sqlmesh/issues). 
\ No newline at end of file diff --git a/docs/guides/vscode/autocomplete.png b/docs/guides/vscode/autocomplete.png new file mode 100644 index 0000000000..3c7c9fa08c Binary files /dev/null and b/docs/guides/vscode/autocomplete.png differ diff --git a/docs/guides/vscode/diagnostics.png b/docs/guides/vscode/diagnostics.png new file mode 100644 index 0000000000..1a8148cd66 Binary files /dev/null and b/docs/guides/vscode/diagnostics.png differ diff --git a/docs/guides/vscode/interpreter_details.png b/docs/guides/vscode/interpreter_details.png new file mode 100644 index 0000000000..09b0f0996a Binary files /dev/null and b/docs/guides/vscode/interpreter_details.png differ diff --git a/docs/guides/vscode/lineage.png b/docs/guides/vscode/lineage.png new file mode 100644 index 0000000000..c2435da845 Binary files /dev/null and b/docs/guides/vscode/lineage.png differ diff --git a/docs/guides/vscode/loaded.png b/docs/guides/vscode/loaded.png new file mode 100644 index 0000000000..efc38522be Binary files /dev/null and b/docs/guides/vscode/loaded.png differ diff --git a/docs/guides/vscode/print_env_vars.png b/docs/guides/vscode/print_env_vars.png new file mode 100644 index 0000000000..5ea4dca7f1 Binary files /dev/null and b/docs/guides/vscode/print_env_vars.png differ diff --git a/docs/guides/vscode/restart_servers.png b/docs/guides/vscode/restart_servers.png new file mode 100644 index 0000000000..c8052f1718 Binary files /dev/null and b/docs/guides/vscode/restart_servers.png differ diff --git a/docs/guides/vscode/select_interpreter.png b/docs/guides/vscode/select_interpreter.png new file mode 100644 index 0000000000..9224f73265 Binary files /dev/null and b/docs/guides/vscode/select_interpreter.png differ diff --git a/docs/guides/vscode/terminal_env_vars.png b/docs/guides/vscode/terminal_env_vars.png new file mode 100644 index 0000000000..f4a567634e Binary files /dev/null and b/docs/guides/vscode/terminal_env_vars.png differ diff --git a/docs/index.md b/docs/index.md index 
e5ecc7f8f3..83c1b0a431 100644 --- a/docs/index.md +++ b/docs/index.md @@ -1,76 +1,165 @@ -# SQLMesh - -[SQLMesh](https://sqlmesh.com) is an [open source](https://github.com/TobikoData/sqlmesh) data transformation framework that brings the best practices of DevOps to data teams. It enables data scientists, analysts, and engineers to efficiently run and deploy data transformations written in SQL or Python. It is created and maintained by [Tobiko Data](https://tobikodata.com/), a company founded by data leaders from Airbnb, Apple, and Netflix. - -## Why SQLMesh? - -The experience of developing and deploying data pipelines is more uncertain and manual when compared to developing applications. This is partially due to the lack of tooling revolving around the testing and deployment of data pipelines. With DevOps, software engineers are able to seamlessly confirm logic with unit tests, validate systems with containerized environments, and transition to prod with confidence. SQLMesh aims to give data teams the same confidence as their peers. - -Here are some challenges that data teams run into, especially when data sizes increase or the number of data users expands: - -1. Data pipelines are fragmented and fragile - * Data pipelines generally consist of Python or SQL scripts that implicitly depend upon each other through tables. Changes to upstream scripts that break downstream consumers are usually only detected at run time. - -1. Data quality checks are not sufficient - * The data community has settled on data quality checks as the "solution" for testing data pipelines. Although data quality checks are great for detecting large unexpected data changes, they are expensive to run, and they have trouble validating exact logic. - -1. It's too hard and too costly to build staging environments for data - * Validating changes to data pipelines before deploying to production is an uncertain and sometimes expensive process. 
Although branches can be deployed to environments, when merged to production, the code is re-run. This is wasteful and generates uncertainty because the data is regenerated. - -1. Silos transform data lakes to data swamps - * The difficulty and cost of making changes to core pipelines can lead to duplicate pipelines with minor customizations. The inability to easily make and validate changes causes contributors to follow the "path of least resistance". The proliferation of similar tables leads to additional costs, inconsistencies, and maintenance burden. - -## What is SQLMesh? -SQLMesh consists of a CLI, a Python API, and a Web UI to make data pipeline development and deployment easy, efficient, and safe. - -### Core principles -SQLMesh was built on three core principles: - -1. Correctness is non-negotiable - * Bad data is worse than no data. SQLMesh guarantees that your data will be consistent even in heavily collaborative environments. - -1. Change with confidence - * SQLMesh summarizes the impact of changes and provides automated guardrails empowering everyone to safely and quickly contribute. - -1. Efficiency without complexity - * SQLMesh automatically optimizes your workloads by reusing tables and minimizing computation saving you time and money. - -### Key features -* Efficient dev/staging environments - * SQLMesh builds a Virtual Data Environment using views, which allows you to seamlessly rollback or roll forward your changes. Any data computation you run for validation purposes is actually not wasted — with a cheap pointer swap, you re-use your “staging” data in production. This means you get unlimited copy-on-write environments that make data exploration and preview of changes fun and safe. - -* Automatic DAG generation by semantically parsing and understanding SQL or Python scripts - * No need to manually tag dependencies — SQLMesh was built with the ability to understand your entire data warehouse’s dependency graph. 
- -* Informative change summaries - * Before making changes, SQLMesh will determine what has changed and show the entire graph of affected jobs. - -* CI-Runnable Unit and Integration tests - * Can be easily defined in YAML and run in CI. SQLMesh can optionally transpile your queries to DuckDB so that your tests can be self-contained. - -* Smart change categorization - * Column-level lineage automatically determines whether changes are “breaking” or “non-breaking”, allowing you to correctly categorize changes and to skip expensive backfills. - -* Easy incremental loads - * Loading tables incrementally is as easy as a full refresh. SQLMesh transparently handles the complexity of tracking which intervals need loading, so all you have to do is specify a date filter. - -* Integrated with Airflow - * You can schedule jobs with our built-in scheduler or use your existing Airflow cluster. SQLMesh can dynamically generate and push Airflow DAGs. We aim to support other schedulers like Dagster and Prefect in the future. - -* Notebook / CLI - * Interact with SQLMesh with whatever tool you’re comfortable with. - -* Web based IDE - * Edit, run, and visualize queries in your browser. - -* Github CI/CD bot - * A bot to tie your code directly to your data. - -* Table/Column level lineage visualizations - * Quickly understand the full lineage and sequence of transformation of any column. - -## Next steps -* [Jump right in with the quickstart](quick_start.md) -* [Check out the FAQ](faq/faq.md) -* [Learn more about SQLMesh concepts](concepts/overview.md) -* [Join our Slack community](https://tobikodata.com/slack) +# + +

+ SQLMesh logo +

+ +SQLMesh is a next-generation data transformation framework designed to ship data quickly, efficiently, and without error. Data teams can efficiently run and deploy data transformations written in SQL or Python with visibility and control at any size. + +It is more than just a [dbt alternative](https://tobikodata.com/reduce_costs_with_cron_and_partitions.html). + +

+ Architecture Diagram +

+ +## Core Features +SQLMesh Plan Mode + +> Get instant SQL impact analysis of your changes, whether in the CLI or in [SQLMesh Plan Mode](https://sqlmesh.readthedocs.io/en/stable/guides/ui/?h=modes#working-with-an-ide) + +??? tip "Virtual Data Environments" + + - See a full diagram of how [Virtual Data Environments](https://whimsical.com/virtual-data-environments-MCT8ngSxFHict4wiL48ymz) work + - [Watch this video to learn more](https://www.youtube.com/watch?v=weJH3eM0rzc) + +* Create isolated development environments without data warehouse costs +* Plan / Apply workflow like [Terraform](https://www.terraform.io/) to understand potential impact of changes +* Easy to use [CI/CD bot](https://sqlmesh.readthedocs.io/en/stable/integrations/github/) for true blue-green deployments + +??? tip "Efficiency and Testing" + + Running this command will generate a unit test file in the `tests/` folder: `test_stg_payments.yaml` + + Runs a live query to generate the expected output of the model + + ```bash + sqlmesh create_test tcloud_demo.stg_payments --query tcloud_demo.seed_raw_payments "select * from tcloud_demo.seed_raw_payments limit 5" + + # run the unit test + sqlmesh test + ``` + + ```sql + MODEL ( + name tcloud_demo.stg_payments, + cron '@daily', + grain payment_id, + audits (UNIQUE_VALUES(columns = ( + payment_id + )), NOT_NULL(columns = ( + payment_id + ))) + ); + + SELECT + id AS payment_id, + order_id, + payment_method, + amount / 100 AS amount, /* `amount` is currently stored in cents, so we convert it to dollars */ + 'new_column' AS new_column, /* non-breaking change example */ + FROM tcloud_demo.seed_raw_payments + ``` + + ```yaml + test_stg_payments: + model: tcloud_demo.stg_payments + inputs: + tcloud_demo.seed_raw_payments: + - id: 66 + order_id: 58 + payment_method: coupon + amount: 1800 + - id: 27 + order_id: 24 + payment_method: coupon + amount: 2600 + - id: 30 + order_id: 25 + payment_method: coupon + amount: 1600 + - id: 109 + order_id: 95 + payment_method: 
coupon + amount: 2400 + - id: 3 + order_id: 3 + payment_method: coupon + amount: 100 + outputs: + query: + - payment_id: 66 + order_id: 58 + payment_method: coupon + amount: 18.0 + new_column: new_column + - payment_id: 27 + order_id: 24 + payment_method: coupon + amount: 26.0 + new_column: new_column + - payment_id: 30 + order_id: 25 + payment_method: coupon + amount: 16.0 + new_column: new_column + - payment_id: 109 + order_id: 95 + payment_method: coupon + amount: 24.0 + new_column: new_column + - payment_id: 3 + order_id: 3 + payment_method: coupon + amount: 1.0 + new_column: new_column + ``` + +* Never build a table [more than once](https://tobikodata.com/simplicity-or-efficiency-how-dbt-makes-you-choose.html) +* Track what data’s been modified and run only the necessary transformations for [incremental models](https://tobikodata.com/correctly-loading-incremental-data-at-scale.html) +* Run [unit tests](https://tobikodata.com/we-need-even-greater-expectations.html) for free and configure automated audits + +??? tip "Level Up Your SQL" + + Write SQL in any dialect and SQLMesh will transpile it to your target SQL dialect on the fly before sending it to the warehouse. + Transpile Example + +* Debug transformation errors *before* you run them in your warehouse in [10+ different SQL dialects](https://sqlmesh.readthedocs.io/en/stable/integrations/overview/#execution-engines) +* Definitions using [simply SQL](https://sqlmesh.readthedocs.io/en/stable/concepts/models/sql_models/#sql-based-definition) (no need for redundant and confusing `Jinja` + `YAML`) +* See impact of changes before you run them in your warehouse with column-level lineage + +For more information, check out the [website](https://sqlmesh.com) and [documentation](https://sqlmesh.readthedocs.io/en/stable/). 
+ +## Getting Started +Install SQLMesh through [pypi](https://pypi.org/project/sqlmesh/) by running: + +```bash +mkdir sqlmesh-example +cd sqlmesh-example +python -m venv .venv +source .venv/bin/activate +pip install sqlmesh +source .venv/bin/activate # reactivate the venv to ensure you're using the right installation +sqlmesh init duckdb # get started right away with a local duckdb instance +sqlmesh plan # see the plan for the changes you're making +``` + +> Note: You may need to run `python3` or `pip3` instead of `python` or `pip`, depending on your python installation. + +Follow the [quickstart guide](https://sqlmesh.readthedocs.io/en/stable/quickstart/cli/#1-create-the-sqlmesh-project) to learn how to use SQLMesh. You already have a head start! + +Follow this [example](https://sqlmesh.readthedocs.io/en/stable/examples/incremental_time_full_walkthrough/) to learn how to use SQLMesh in a full walkthrough. + +## Join Our Community +Together, we want to build data transformation without the waste. Connect with us in the following ways: + +* Join the [Tobiko Slack Community](https://tobikodata.com/slack) to ask questions, or just to say hi! +* File an issue on our [GitHub](https://github.com/SQLMesh/sqlmesh/issues/new) +* Send us an email at [hello@tobikodata.com](mailto:hello@tobikodata.com) with your questions or feedback +* Read our [blog](https://tobikodata.com/blog) + +## Contribution +Contributions in the form of issues or pull requests are greatly appreciated. + +[Read more](https://sqlmesh.readthedocs.io/en/stable/development/) on how to contribute to SQLMesh open source. + +[Watch this video walkthrough](https://www.loom.com/share/2abd0d661c12459693fa155490633126?sid=b65c1c0f-8ef7-4036-ad19-3f85a3b87ff2) to see how our team contributes a feature to SQLMesh. 
diff --git a/docs/installation.md b/docs/installation.md index 250cd057f5..f12ec566e2 100644 --- a/docs/installation.md +++ b/docs/installation.md @@ -8,12 +8,12 @@ It is recommended, but not required, that you use a python virtual environment w First, create the virtual environment: ```bash -python -m venv .env +python -m venv .venv ``` Then activate it: ```bash -source .env/bin/activate +source .venv/bin/activate ``` ## Install SQLMesh core @@ -24,50 +24,46 @@ pip install sqlmesh ``` ## Install extras -Some SQLMesh functionality requires additional Python libraries. +Some SQLMesh functionality requires additional Python libraries, which are bundled with SQLMesh via "extras". -`pip` will automatically install them for you if you specify the relevant name in brackets. For example, you install the SQLMesh browser UI extras with `pip install "sqlmesh[web]"`. +In your `pip` command, specify the extra's name in brackets to automatically install the additional libraries. For example, you install the SQLMesh Github CI/CD bot extras with `pip install "sqlmesh[github]"`. -Some extras add features, like the SQLMesh browser UI or Github CI/CD bot: +There are two types of extras. + +Some extras add features, like the SQLMesh VSCode extension or Github CI/CD bot: ??? info "Feature extras commands" | Feature | `pip` command | | ------------------- | ------------------------------- | - | Browser UI | `pip install "sqlmesh[web]"` | - | dbt projects | `pip install "sqlmesh[dbt]"` | + | VSCode extension | `pip install "sqlmesh[lsp]"` | | Github CI/CD bot | `pip install "sqlmesh[github]"` | + | dbt projects | `pip install "sqlmesh[dbt]"` | + | dlt projects | `pip install "sqlmesh[dlt]"` | | Slack notifications | `pip install "sqlmesh[slack]"` | | Development setup | `pip install "sqlmesh[dev]"` | + | Browser UI | `pip install "sqlmesh[web]"` | | LLM SQL prompt | `pip install "sqlmesh[llm]"` | Other extras are required to use specific SQL engines, like Bigquery or Postgres: ??? 
info "SQL engine extras commands" | SQL engine | `pip` command | - |---------------|--------------------------------------| + | ------------- | ------------------------------------ | + | Athena | `pip install "sqlmesh[athena]"` | + | Azure SQL | `pip install "sqlmesh[azuresql]"` | | Bigquery | `pip install "sqlmesh[bigquery]"` | + | ClickHouse | `pip install "sqlmesh[clickhouse]"` | | Databricks | `pip install "sqlmesh[databricks]"` | | GCP Postgres | `pip install "sqlmesh[gcppostgres]"` | | MS SQL Server | `pip install "sqlmesh[mssql]"` | | MySQL | `pip install "sqlmesh[mysql]"` | | Postgres | `pip install "sqlmesh[postgres]"` | | Redshift | `pip install "sqlmesh[redshift]"` | + | RisingWave | `pip install "sqlmesh[risingwave]"` | | Snowflake | `pip install "sqlmesh[snowflake]"` | + | Trino | `pip install "sqlmesh[trino]"` | -Multiple extras can be installed at once, as in `pip install "sqlmesh[web,slack]"`. - -## Pydantic v2 -SQLMesh supports Pydantic v2, but since v2 is relatively new, v1 is the version installed by default. If you would like to use Pydantic v2, you can by installing it after installing SQLMesh. - -```bash -pip install --upgrade pydantic -``` - -Pip may issue a warning about dependency conflicts, but SQLMesh should still function fine. Furthermore, if you are using the SQLMesh UI, you will also need to install pydantic-settings. - -```bash -pip install --upgrade pydantic-settings -``` +Multiple extras can be installed at once, as in `pip install "sqlmesh[github,slack]"`. ## Next steps diff --git a/docs/integrations/airflow.md b/docs/integrations/airflow.md deleted file mode 100644 index 5275728a97..0000000000 --- a/docs/integrations/airflow.md +++ /dev/null @@ -1,156 +0,0 @@ -# Airflow - -SQLMesh provides first-class support for Airflow with the following capabilities: - -* A Directed Acyclic Graph (DAG) generated dynamically for each model version. 
Each DAG accounts for all its upstream dependencies defined within SQLMesh, and only runs after upstream DAGs succeed for the time period being processed. -* Each plan application leads to the creation of a dynamically-generated DAG dedicated specifically to that Plan. -* The Airflow [Database Backend](https://airflow.apache.org/docs/apache-airflow/stable/howto/set-up-database.html) is used for persistence of the SQLMesh state, meaning no external storage or additional configuration is required for SQLMesh to work. -* The janitor DAG runs periodically and automatically to clean up DAGs and other SQLMesh artifacts that are no longer needed. -* Support for any SQL engine can be added by providing a custom Airflow Operator. - -## Airflow cluster configuration -To enable SQLMesh support on a target Airflow cluster, the SQLMesh package should first be installed on that cluster. Ensure it is installed with the extras for your engine if needed; for example: `sqlmesh[databricks]` for Databricks. Check [setup.py](https://github.com/TobikoData/sqlmesh/blob/main/setup.py) for a list of extras. - -**Note:** The Airflow Webserver instance(s) must be restarted after **installation** and every time the SQLMesh package is **upgraded**. - -Once the package is installed, the following Python module must be created in the `dags/` folder of the target DAG repository with the following contents: - -```python linenums="1" -from sqlmesh.schedulers.airflow.integration import SQLMeshAirflow - -sqlmesh_airflow = SQLMeshAirflow("spark", default_catalog="spark_catalog") - -for dag in sqlmesh_airflow.dags: - globals()[dag.dag_id] = dag -``` -The name of the module file can be arbitrary, but we recommend something descriptive such as `sqlmesh.py` or `sqlmesh_integration.py`. - -`SQLMeshAirflow` has two required arguments (`engine_operator` and `default_catalog`). 
Details on these and additional optional arguments below: - -| Argument | Description | Type | Required | -|---------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----------------------:|:--------:| -| `engine_operator` | Name or operator to use for creating models. See [Engine Support](#engine-support) for list of options | string or BaseOperator | Y | -| `default_catalog` | The default catalog (also called "database" in other engines) to use when models are defined that do not contain a catalog in their name. This should match the default catalog applied by the connection. | string | Y | -| `engine_operator_args` | The dictionary of arguments that will be passed into the evaluate engine operator during its construction. This can be used to customize parameters such as connection ID. | dict | N | -| `ddl_engine_operator` | The type of the Airflow operator that will be used for environment management. These operations are SQL only. `engine_operator` is used if not provided | string or BaseOperator | N | -| `ddl_engine_operator_args` | Args to be passed into just the environment management operator. This can be used to customize parameters such as connection ID. | dict | N | -| `janitor_interval` | Defines how often the janitor DAG runs. The janitor DAG removes platform-managed DAG instances that are pending deletion from Airflow. Default: 1 hour. | timedelta | N | -| `plan_application_dag_ttl` | Determines the time-to-live period for finished plan application DAGs. Once this period is exceeded, finished plan application DAGs are deleted by the janitor. Default: 2 days. | timedelta | N | -| `external_table_sensor_factory` | A factory function that creates a sensor operator for a given signal payload. 
See [External signals](#external-signals) for more info | function | N | -| `sensor_mode` | The mode to use for SQLMesh sensors. Supported values are "poke" and "reschedule". See https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/sensors.html for more details. Default: "reschedule" | string | N | -| `high_water_mark_sensor_args` | The dictionary of arguments that will be passed into the high water mark sensor during its construction. | dict | N | -| `external_sensor_args` | The dictionary of arguments that will be passed into the external sensor during its construction. | dict | N | -| `generate_cadence_dags` | Whether to generate cadence DAGs for model versions that are currently deployed to production. | bool | N | - - -### State connection - -By default, SQLMesh uses the Airflow's database connection to read and write its state. - -To configure a different storage backend for the SQLMesh state you need to create a new [Airflow Connection](https://airflow.apache.org/docs/apache-airflow/stable/howto/connection.html) with ID `sqlmesh_state_db` and type `Generic`. The configuration should be provided in the connection's `extra` field in JSON format. - -![SQLMesh state connection](airflow/airflow_sqlmesh_state_connection.png) - -Refer to the [Connection Configuration](../reference/configuration.md#connection) for supported fields. - -## SQLMesh client configuration -In your SQLMesh repository, create the following configuration within config.yaml: -```yaml linenums="1" -default_scheduler: - type: airflow - airflow_url: https://:/ - username: - password: -``` - -## External signals - -Sometimes there is a need to postpone the model evaluation until certain external conditions are met. - -For example, a model might refer to an external table and should only be evaluated when the data actually lands upstream. This can be achieved using external signals. - -Signals are defined as part of the model's definition using arbitrary key-value pairs. 
Additionally, `@start_*` and `@end_*` [macros](../concepts/macros/macro_variables.md) can be used within these values. The macro values will be resolved accordingly at the time of evaluation. - -```sql linenums="1" -MODEL ( - name test_db.test_name, - signals [ - ( - table_name = 'upstream_table_a', - ds = @end_ds, - ), - ( - table_name = 'upstream_table_b', - ds = @end_ds, - hour = @end_hour, - ), - ], -) -``` - -Note that in the example above, `table_name`, `ds`, and `hour` are arbitrary keys defined by the user. - -Now, as part of the SQLMesh integration module, a function needs to be passed into the `SQLMeshAirflow` constructor. This function should accept signal payload and return an Airflow Sensor instance representing this signal. - -```python linenums="1" -import typing as t -from airflow.sensors.base import BaseSensorOperator -from sqlmesh.schedulers.airflow.integration import SQLMeshAirflow - - -def create_external_sensor(signal: t.Dict[str, t.Any]) -> BaseSensorOperator: - table_name = signal["table_name"] - ds = signal["ds"] - hour = signal["hour"] - return MyCustomSensor(partition=f"{table_name}/ds={ds}/hour={hour:02}") - - -sqlmesh_airflow = SQLMeshAirflow( - "spark", - default_catalog="spark_catalog", - external_table_sensor_factory=create_external_sensor, -) -``` - -The `create_external_sensor` function in the example above takes the `signal` dictionary as an argument and returns an instance of `BaseSensorOperator`. The keys in the signal dictionary match the keys provided in the model definition. - -## Engine support -SQLMesh supports a variety of engines in Airflow. Support for each engine is provided by a custom Airflow operator implementation. Below is a list of links to operators supported out of the box with information on how to configure them. 
- -* [BigQuery](engines/bigquery.md#airflow-scheduler) -* [Databricks](engines/databricks.md#airflow-scheduler) -* [MSSQL](engines/mssql.md#airflow-scheduler) -* [Postgres](engines/postgres.md#airflow-scheduler) -* [Redshift](engines/redshift.md#airflow-scheduler) -* [Snowflake](engines/snowflake.md#airflow-scheduler) -* [Spark](engines/spark.md#airflow-scheduler) -* [Trino](engines/trino.md#airflow-scheduler) -* [MySQL](engines/mysql.md#airflow-scheduler) - -## Managed Airflow instances - -Multiple companies offer managed Airflow instances that integrate with their products. This section describes SQLMesh support for some of the options. - -### Google Cloud Composer - -SQLMesh fully supports Airflow hosted on [Google Cloud Composer](https://cloud.google.com/composer/docs/composer-2/composer-overview) - see the [configuration reference page](../reference/configuration.md#cloud-composer) for more information. - -### Astronomer - -Astronomer provides [managed Airflow instances](https://www.astronomer.io/product/) running on AWS, GCP, and Azure. SQLMesh fully supports Airflow hosted by Astronomer. - -### AWS MWAA - -Due to MWAA not supporting the Airflow REST API, users are required to configure an external state connection for both the [client](../guides/connections.md#state-connection) and [Airflow cluster](#state-connection) to point to the same database. - -Additional dependencies need to be installed: -```bash -pip install "sqlmesh[mwaa]" -``` - -Additionally, the scheduler needs to be configured accordingly: -```yaml linenums="1" -default_scheduler: - type: mwaa - environment: -``` diff --git a/docs/integrations/dbt.md b/docs/integrations/dbt.md index 9b688f1a2b..5854236aa2 100644 --- a/docs/integrations/dbt.md +++ b/docs/integrations/dbt.md @@ -2,7 +2,40 @@ SQLMesh has native support for running dbt projects with its dbt adapter. +!!! tip + + If you've never used SQLMesh before, learn the basics of how it works in the [SQLMesh Quickstart](../quick_start.md)! 
+ ## Getting started + +### Installing SQLMesh + +SQLMesh is a Python library you install with the `pip` command. We recommend running your SQLMesh projects in a [Python virtual environment](../installation.md#python-virtual-environment), which must be created and activated before running any `pip` commands. + +Most people do not use all of SQLMesh's functionality. For example, most projects only run on one [SQL execution engine](../integrations/overview.md#execution-engines). + +Therefore, SQLMesh is packaged with multiple "extras," which you may optionally install based on the functionality your project needs. You may specify all your project's extras in a single `pip` call. + +At minimum, using the SQLMesh dbt adapter requires installing the dbt extra: + +```bash +> pip install "sqlmesh[dbt]" +``` + +If your project uses any SQL execution engine other than DuckDB, you must install the extra for that engine. For example, if your project runs on the Postgres SQL engine: + +```bash +> pip install "sqlmesh[dbt,postgres]" +``` + +If you would like to use the [SQLMesh Browser UI](../guides/ui.md) to view column-level lineage, include the `web` extra: + +```bash +> pip install "sqlmesh[dbt,web]" +``` + +Learn more about [SQLMesh installation and extras here](../installation.md#install-extras). + ### Reading a dbt project Prepare an existing dbt project to be run by SQLMesh by executing the `sqlmesh init` command *within the dbt project root directory* and with the `dbt` template option: @@ -11,60 +44,128 @@ Prepare an existing dbt project to be run by SQLMesh by executing the `sqlmesh i $ sqlmesh init -t dbt ``` -SQLMesh will use the data warehouse connection target in your dbt project `profiles.yml` file. The target can be changed at any time. +This will create a file called `sqlmesh.yaml` containing the [default model start date](../reference/model_configuration.md#model-defaults). 
This configuration file is a minimum starting point for enabling SQLMesh to work with your DBT project. + +As you become more comfortable with running your project under SQLMesh, you may specify additional SQLMesh [configuration](../reference/configuration.md) as required to unlock more features. + +!!! note "profiles.yml" + + SQLMesh will use the existing data warehouse connection target from your dbt project's `profiles.yml` file so the connection configuration does not need to be duplicated in `sqlmesh.yaml`. You may change the target at any time in the dbt config and SQLMesh will pick up the new target. ### Setting model backfill start dates -Models **require** a start date for backfilling data through use of the `start` configuration parameter. `start` can be defined individually for each model in its `config` block or globally in the `dbt_project.yml` file as follows: +Models **require** a start date for backfilling data through use of the `start` configuration parameter. `start` can be defined individually for each model in its `config` block or globally in the `sqlmesh.yaml` file as follows: -``` -> models: -> +start: Jan 1 2000 -``` +=== "sqlmesh.yaml" + + ```yaml + model_defaults: + start: '2000-01-01' + ``` + +=== "dbt Model" + + ```jinja + {{ + config( + materialized='incremental', + start='2000-01-01', + ... + ) + }} + ``` ### Configuration -SQLMesh determines a project's configuration settings from its dbt configuration files. +SQLMesh derives a project's configuration from its dbt configuration files. This section outlines additional settings specific to SQLMesh that can be defined. -This section describes using runtime variables to create multiple configurations and how to disable SQLMesh's automatic model description and comment registration. 
+#### Selecting a different state connection -#### Runtime vars +[Certain engines](https://sqlmesh.readthedocs.io/en/stable/guides/configuration/?h=unsupported#state-connection), like Trino, cannot be used to store SQLMesh's state. -dbt supports passing variable values at runtime with its [CLI `vars` option](https://docs.getdbt.com/docs/build/project-variables#defining-variables-on-the-command-line). +In addition, even if your warehouse is supported for state, you may find that you get better performance by using a [traditional database](../concepts/state.md) to store state as these are a better fit for the state workload than a warehouse optimized for analytics workloads. -In SQLMesh, these variables are passed via configurations. When you initialize a dbt project with `sqlmesh init`, a file `config.py` is created in your project directory. +In these cases, we recommend specifying a [supported production state engine](../concepts/state.md#state) using the `state_connection` configuration. -The file creates a SQLMesh `config` object pointing to the project directory: +This involves updating `sqlmesh.yaml` to add a gateway configuration for the state connection: -```python -config = sqlmesh_config(Path(__file__).parent) +```yaml +gateways: + "": # "" (empty string) is the default gateway + state_connection: + type: postgres + ... + +model_defaults: + start: '2000-01-01' ``` -Specify runtime variables by adding a Python dictionary to the `sqlmesh_config()` `variables` argument. +Or, for a specific dbt profile defined in `profiles.yml`, eg `dev`: + +```yaml +gateways: + dev: # must match the target dbt profile name + state_connection: + type: postgres + ... + +model_defaults: + start: '2000-01-01' +``` + +Learn more about how to configure state connections [here](https://sqlmesh.readthedocs.io/en/stable/guides/configuration/#state-connection). 
+ +#### Runtime vars + +dbt supports passing variable values at runtime with its [CLI `vars` option](https://docs.getdbt.com/docs/build/project-variables#defining-variables-on-the-command-line). + +In SQLMesh, these variables are passed via configurations. When you initialize a dbt project with `sqlmesh init`, a file `sqlmesh.yaml` is created in your project directory. + +You may define global variables in the same way as a native project by adding a `variables` section to the config. For example, we could specify the runtime variable `is_marketing` and its value `no` as: -```python -config = sqlmesh_config( - Path(__file__).parent, - variables={"is_marketing": "no"} - ) +```yaml +variables: + is_marketing: no + +model_defaults: + start: '2000-01-01' ``` +Variables can also be set at the gateway/profile level which override variables set at the project level. See the [variables documentation](../concepts/macros/sqlmesh_macros.md#gateway-variables) to learn more about how to specify them at different levels. + +#### Combinations + Some projects use combinations of runtime variables to control project behavior. Different combinations can be specified in different `sqlmesh_config` objects, with the relevant configuration passed to the SQLMesh CLI command. +!!! info "Python config" + + Switching between different config objects requires the use of [Python config](../guides/configuration.md#python) instead of the default YAML config. + + You will need to create a file called `config.py` in the root of your project with the following contents: + + ```py + from pathlib import Path + from sqlmesh.dbt.loader import sqlmesh_config + + config = sqlmesh_config(Path(__file__).parent) + ``` + + Note that any config from `sqlmesh.yaml` will be overlayed on top of the active Python config so you dont need to remove the `sqlmesh.yaml` file + For example, consider a project with a special configuration for the `marketing` department. 
We could create separate configurations to pass at runtime like this: ```python config = sqlmesh_config( - Path(__file__).parent, - variables={"is_marketing": "no", "include_pii": "no"} - ) + Path(__file__).parent, + variables={"is_marketing": "no", "include_pii": "no"} +) marketing_config = sqlmesh_config( - Path(__file__).parent, - variables={"is_marketing": "yes", "include_pii": "yes"} - ) + Path(__file__).parent, + variables={"is_marketing": "yes", "include_pii": "yes"} +) ``` By default, SQLMesh will use the configuration object named `config`. Use a different configuration by passing the object name to SQLMesh CLI commands with the `--config` option. For example, we could run a `plan` with the marketing configuration like this: @@ -118,7 +219,7 @@ This section describes how to adapt dbt's incremental models to run on sqlmesh a SQLMesh supports two approaches to implement [idempotent](../concepts/glossary.md#idempotency) incremental loads: * Using merge (with the sqlmesh [`INCREMENTAL_BY_UNIQUE_KEY` model kind](../concepts/models/model_kinds.md#incremental_by_unique_key)) -* Using insert-overwrite/delete+insert (with the sqlmesh [`INCREMENTAL_BY_TIME_RANGE` model kind](../concepts/models/model_kinds.md#incremental_by_time_range)) +* Using [`INCREMENTAL_BY_TIME_RANGE` model kind](../concepts/models/model_kinds.md#incremental_by_time_range) #### Incremental by unique key @@ -132,28 +233,22 @@ To enable incremental_by_unique_key incrementality, the model configuration shou #### Incremental by time range -To enable incremental_by_time_range incrementality, the model configuration should contain: +To enable incremental_by_time_range incrementality, the model configuration must contain: -* The `time_column` key with the model's time column field name as the value (see [`time column`](../concepts/models/model_kinds.md#time-column) for details) * The `materialized` key with value `'incremental'` -* Either: - * The `incremental_strategy` key with value 
`'insert_overwrite'` or - * The `incremental_strategy` key with value `'delete+insert'` - * Note: in this context, these two strategies are synonyms. Regardless of which one is specified SQLMesh will use the [`best incremental strategy`](../concepts/models/model_kinds.md#materialization-strategy) for the target engine. +* The `incremental_strategy` key with the value `incremental_by_time_range` +* The `time_column` key with the model's time column field name as the value (see [`time column`](../concepts/models/model_kinds.md#time-column) for details) ### Incremental logic -SQLMesh requires a new jinja block gated by `{% if sqlmesh_incremental is defined %}`. The new block should supersede the existing `{% if is_incremental() %}` block and contain the `WHERE` clause selecting the time interval. +Unlike dbt incremental strategies, SQLMesh does not require the use of `is_incremental` jinja blocks to implement incremental logic. +Instead, SQLMesh provides predefined time macro variables that can be used in the model's SQL to filter data based on the time column. For example, the SQL `WHERE` clause with the "ds" column goes in a new jinja block gated by `{% if sqlmesh_incremental is defined %}` as follows: ```bash -> {% if sqlmesh_incremental is defined %} > WHERE > ds BETWEEN '{{ start_ds }}' AND '{{ end_ds }}' -> {% elif is_incremental() %} -> ; < your existing is_incremental block > -> {% endif %} ``` `{{ start_ds }}` and `{{ end_ds }}` are the jinja equivalents of SQLMesh's `@start_ds` and `@end_ds` predefined time macro variables. See all [predefined time variables](../concepts/macros/macro_variables.md) available in jinja. @@ -162,26 +257,26 @@ For example, the SQL `WHERE` clause with the "ds" column goes in a new jinja blo SQLMesh provides configuration parameters that enable control over how incremental computations occur. These parameters are set in the model's `config` block. 
-The [`batch_size` parameter](../concepts/models/overview.md#batch_size) determines the maximum number of time intervals to run in a single job. - -The [`lookback` parameter](../concepts/models/overview.md#lookback) is used to capture late arriving data. It sets the number of units of late arriving data the model should expect and must be a positive integer. +See [Incremental Model Properties](../concepts/models/overview.md#incremental-model-properties) for the full list of incremental model configuration parameters. **Note:** By default, all incremental dbt models are configured to be [forward-only](../concepts/plans.md#forward-only-plans). However, you can change this behavior by setting the `forward_only: false` setting either in the configuration of an individual model or globally for all models in the `dbt_project.yaml` file. The [forward-only](../concepts/plans.md#forward-only-plans) mode aligns more closely with the typical operation of dbt and therefore better meets user's expectations. +Similarly, the [allow_partials](../concepts/models/overview.md#allow_partials) parameter is set to `true` by default unless the `allow_partials` parameter is explicitly set to `false` in the model configuration. + #### on_schema_change -SQLMesh automatically detects destructive schema changes to [forward-only incremental models](../guides/incremental_time.md#forward-only-models) and to all incremental models in [forward-only plans](../concepts/plans.md#destructive-changes). +SQLMesh automatically detects both destructive and additive schema changes to [forward-only incremental models](../guides/incremental_time.md#forward-only-models) and to all incremental models in [forward-only plans](../concepts/plans.md#destructive-changes). -A model's [`on_destructive_change` setting](../guides/incremental_time.md#destructive-changes) determines whether it errors (default), warns, or silently allows the changes. 
SQLMesh always allows non-destructive forward-only schema changes, such as adding or casting a column in place. +A model's [`on_destructive_change` and `on_additive_change` settings](../guides/incremental_time.md#schema-changes) determine whether it errors, warns, silently allows, or ignores the changes. SQLMesh provides fine-grained control over both destructive changes (like dropping columns) and additive changes (like adding new columns). -`on_schema_change` configuration values are mapped to these SQLMesh `on_destructive_change` values: +`on_schema_change` configuration values are mapped to these SQLMesh settings: -| `on_schema_change` | SQLMesh `on_destructive_change` | -| ------------------ | ------------------------------- | -| ignore | warn | -| append_new_columns | warn | -| sync_all_columns | allow | -| fail | error | +| `on_schema_change` | SQLMesh `on_destructive_change` | SQLMesh `on_additive_change` | +|--------------------|---------------------------------|------------------------------| +| ignore | ignore | ignore | +| fail | error | error | +| append_new_columns | ignore | allow | +| sync_all_columns | allow | allow | ## Snapshot support @@ -202,9 +297,26 @@ SQLMesh parses seed CSV files using [Panda's `read_csv` utility](https://pandas. dbt parses seed CSV files using [agate's csv reader](https://agate.readthedocs.io/en/latest/api/csv.html#csv-reader-and-writer) and [customizes agate's default type inference](https://github.com/dbt-labs/dbt-common/blob/ae8ffe082926fdb3ef2a15486588f40c7739aea9/dbt_common/clients/agate_helper.py#L59). -If SQLMesh and dbt infer different column types for a seed CSV file, you may specify your desired data types in a [seed properties configuration file](https://docs.getdbt.com/reference/seed-properties). 
+If SQLMesh and dbt infer different column types for a seed CSV file, you may specify a [column_types](https://docs.getdbt.com/reference/resource-configs/column_types) dictionary in your `dbt_project.yml` file, where the keys define the column names and the values the data types. + +``` yaml +seeds: + + +column_types: + : +``` + +Alternatively, you can define this dictionary in the seed [seed properties configuration file](https://docs.getdbt.com/reference/seed-properties). + +``` yaml +seeds: + - name: + config: + column_types: + : +``` -Specify a column's SQL data type in its `data_type` key, as shown below. The file must list all columns present in the CSV file; SQLMesh's default type inference will be used for columns that do not specify the `data_type` key. +You may also specify a column's SQL data type in its `data_type` key, as shown below. The file must list all columns present in the CSV file; SQLMesh's default type inference will be used for columns that do not specify the `data_type` key. ``` yaml seeds: @@ -220,49 +332,20 @@ SQLMesh does not have its own package manager; however, SQLMesh's dbt adapter is ## Documentation Model documentation is available in the [SQLMesh UI](../quickstart/ui.md#2-open-the-sqlmesh-web-ui). -## Using Airflow -To use SQLMesh and dbt projects with Airflow, first configure SQLMesh to use Airflow as described in the [Airflow integrations documentation](./airflow.md). - -Then, install dbt-core within airflow. 
- -Finally, replace the contents of `config.py` with: - -```bash -> from pathlib import Path -> -> from sqlmesh.core.config import AirflowSchedulerConfig -> from sqlmesh.dbt.loader import sqlmesh_config -> -> config = sqlmesh_config( -> Path(__file__).parent, -> default_scheduler=AirflowSchedulerConfig( -> airflow_url="https://:/", -> username="", -> password="", -> ) -> ) -``` - -See the [Airflow configuration documentation](https://airflow.apache.org/docs/apache-airflow/2.1.0/configurations-ref.html) for a list of all AirflowSchedulerConfig configuration options. Note: only the python config file format is supported for dbt at this time. - -The project is now configured to use airflow. Going forward, this also means that the engine configured in airflow will be used instead of the target engine specified in profiles.yml. - ## Supported dbt jinja methods SQLMesh supports running dbt projects using the majority of dbt jinja methods, including: -| Method | Method | Method | Method | -| ----------- | -------------- | ------------ | ------- | -| adapter (*) | env_var | project_name | target | -| as_bool | exceptions | ref | this | -| as_native | from_yaml | return | to_yaml | -| as_number | is_incremental | run_query | var | -| as_text | load_result | schema | zip | -| api | log | set | | -| builtins | modules | source | | -| config | print | statement | | - -\* `adapter.rename_relation` and `adapter.expand_target_column_types` are not currently supported. 
+| Method    | Method         | Method       | Method  |
+| --------- | -------------- | ------------ | ------- |
+| adapter   | env_var        | project_name | target  |
+| as_bool   | exceptions     | ref          | this    |
+| as_native | from_yaml      | return       | to_yaml |
+| as_number | is_incremental | run_query    | var     |
+| as_text   | load_result    | schema       | zip     |
+| api       | log            | set          |         |
+| builtins  | modules        | source       |         |
+| config    | print          | statement    |         |
 
 ## Unsupported dbt jinja methods
 
 The dbt jinja methods that are not currently supported are:
 
 * debug
 * selected_sources
-* adapter.expand_target_column_types
-* adapter.rename_relation
-* schemas
 * graph.nodes.values
 * graph.metrics.values
-* version - learn more about why SQLMesh doesn't support model versions at the [Tobiko Data blog](https://tobikodata.com/the-false-promise-of-dbt-contracts.html)
 
 ## Missing something you need?
 
-Submit an [issue](https://github.com/TobikoData/sqlmesh/issues), and we'll look into it!
+Submit an [issue](https://github.com/SQLMesh/sqlmesh/issues), and we'll look into it!
diff --git a/docs/integrations/dlt.md b/docs/integrations/dlt.md
new file mode 100644
index 0000000000..7125510de9
--- /dev/null
+++ b/docs/integrations/dlt.md
@@ -0,0 +1,118 @@
+# dlt
+
+SQLMesh enables effortless project generation using data ingested through [dlt](https://github.com/dlt-hub/dlt). This involves creating a baseline project scaffolding, generating incremental models to process the data from the pipeline's tables by inspecting its schema and configuring the gateway connection using the pipeline's credentials.
+
+## Getting started
+### Reading from a dlt pipeline
+
+To load data from a dlt pipeline into SQLMesh, ensure the dlt pipeline has been run or restored locally. 
Then simply execute the sqlmesh `init` command *within the dlt project root directory* using the `dlt` template option and specifying the pipeline's name with the `dlt-pipeline` option: + +```bash +$ sqlmesh init -t dlt --dlt-pipeline dialect +``` + +This will create the configuration file and directories, which are found in all SQLMesh projects: + +- config.yaml + - The file for project configuration. Refer to [configuration](../reference/configuration.md). +- ./models + - SQL and Python models. Refer to [models](../concepts/models/overview.md). +- ./seeds + - Seed files. Refer to [seeds](../concepts/models/seed_models.md). +- ./audits + - Shared audit files. Refer to [auditing](../concepts/audits.md). +- ./tests + - Unit test files. Refer to [testing](../concepts/tests.md). +- ./macros + - Macro files. Refer to [macros](../concepts/macros/overview.md). + +SQLMesh will also automatically generate models to ingest data from the pipeline incrementally. Incremental loading is ideal for large datasets where recomputing entire tables is resource-intensive. In this case utilizing the [`INCREMENTAL_BY_TIME_RANGE` model kind](../concepts/models/model_kinds.md#incremental_by_time_range). However, these model definitions can be customized to meet your specific project needs. + +#### Specify the path to the pipelines directory + +The default location for dlt pipelines is `~/.dlt/pipelines/`. If your pipelines are in a [different directory](https://dlthub.com/docs/general-usage/pipeline#separate-working-environments-with-pipelines_dir), use the `--dlt-path` argument to specify the path explicitly: + +```bash +$ sqlmesh init -t dlt --dlt-pipeline --dlt-path dialect +``` + +### Generating models on demand + +To update the models in your SQLMesh project on demand, use the `dlt_refresh` command. This allows you to either specify individual tables to generate incremental models from or update all models at once. 
+
+- **Generate all missing tables**:
+
+```bash
+$ sqlmesh dlt_refresh
+```
+
+- **Generate all missing tables and overwrite existing ones** (use with `--force` or `-f`):
+
+```bash
+$ sqlmesh dlt_refresh --force
+```
+
+- **Generate specific dlt tables** (using `--table` or `-t`):
+
+```bash
+$ sqlmesh dlt_refresh --table 
+```
+
+- **Provide the explicit path to the pipelines directory** (using `--dlt-path`):
+
+```bash
+$ sqlmesh dlt_refresh --dlt-path 
+```
+
+#### Configuration
+
+SQLMesh will retrieve the data warehouse connection credentials from your dlt project to configure the `config.yaml` file. This configuration can be modified or customized as needed. For more details, refer to the [configuration guide](../guides/configuration.md).
+
+### Example
+
+Generating a SQLMesh project with dlt is quite simple. In this example, we'll use the example `sushi_pipeline.py` from the [sushi-dlt project](https://github.com/SQLMesh/sqlmesh/tree/main/examples/sushi_dlt).
+
+First, run the pipeline within the project directory:
+
+```bash
+$ python sushi_pipeline.py
+Pipeline sushi load step completed in 2.09 seconds
+Load package 1728074157.660565 is LOADED and contains no failed jobs
+```
+
+After the pipeline has run, generate a SQLMesh project by executing:
+
+```bash
+$ sqlmesh init -t dlt --dlt-pipeline sushi duckdb
+```
+
+Then the SQLMesh project is all set up. 
You can then proceed to run the SQLMesh `plan` command to ingest the dlt pipeline data and populate the SQLMesh tables: + +```bash +$ sqlmesh plan +`prod` environment will be initialized + +Models: +└── Added: + ├── sushi_dataset_sqlmesh.incremental__dlt_loads + ├── sushi_dataset_sqlmesh.incremental_sushi_types + └── sushi_dataset_sqlmesh.incremental_waiters +Models needing backfill (missing dates): +├── sushi_dataset_sqlmesh.incremental__dlt_loads: 2024-10-03 - 2024-10-03 +├── sushi_dataset_sqlmesh.incremental_sushi_types: 2024-10-03 - 2024-10-03 +└── sushi_dataset_sqlmesh.incremental_waiters: 2024-10-03 - 2024-10-03 +Apply - Backfill Tables [y/n]: y +[1/1] sushi_dataset_sqlmesh.incremental__dlt_loads evaluated in 0.01s +[1/1] sushi_dataset_sqlmesh.incremental_sushi_types evaluated in 0.00s +[1/1] sushi_dataset_sqlmesh.incremental_waiters evaluated in 0.01s +Evaluating models ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 + + +All model batches have been executed successfully + +Virtually Updating 'prod' ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 0:00:00 + +The target environment has been updated successfully +``` + +Once the models are planned and applied, you can continue as with any SQLMesh project, generating and applying [plans](../concepts/overview.md#make-a-plan), running [tests](../concepts/overview.md#tests) or [audits](../concepts/overview.md#audits), and executing models with a [scheduler](../guides/scheduling.md) if desired. diff --git a/docs/integrations/engines/athena.md b/docs/integrations/engines/athena.md new file mode 100644 index 0000000000..1c39ecbd94 --- /dev/null +++ b/docs/integrations/engines/athena.md @@ -0,0 +1,73 @@ +# Athena + +## Installation + +``` +pip install "sqlmesh[athena]" +``` + +## Connection options + +### PyAthena connection options + +SQLMesh leverages the [PyAthena](https://github.com/laughingman7743/PyAthena) DBAPI driver to connect to Athena. 
Therefore, the connection options relate to the PyAthena connection options. +Note that PyAthena uses [boto3](https://boto3.amazonaws.com/v1/documentation/api/latest/index.html) under the hood so you can also use [boto3 environment variables](https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html#using-environment-variables) for configuration. + +| Option | Description | Type | Required | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:| +| `type` | Engine type name - must be `athena` | string | Y | +| `aws_access_key_id` | The access key for your AWS user | string | N | +| `aws_secret_access_key` | The secret key for your AWS user | string | N | +| `role_arn` | The ARN of a role to assume once authenticated | string | N | +| `role_session_name` | The session name to use when assuming `role_arn` | string | N | +| `region_name` | The AWS region to use | string | N | +| `work_group` | The Athena [workgroup](https://docs.aws.amazon.com/athena/latest/ug/workgroups-manage-queries-control-costs.html) to send queries to | string | N | +| `s3_staging_dir` | The S3 location for Athena to write query results. Only required if not using `work_group` OR the configured `work_group` doesnt have a results location set | string | N | +| `schema_name` | The default schema to place objects in if a schema isnt specified. Defaults to `default` | string | N | +| `catalog_name` | The default catalog to place schemas in. 
Defaults to `AwsDataCatalog` | string | N | + +### SQLMesh connection options + +These options are specific to SQLMesh itself and are not passed to PyAthena + +| Option | Description | Type | Required | +|-------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------|----------| +| `s3_warehouse_location` | Set the base path in S3 where SQLMesh will instruct Athena to place table data. Only required if you arent specifying the location in the model itself. See [S3 Locations](#s3-locations) below. | string | N | + +## Model properties + +The Athena adapter utilises the following model top-level [properties](../../concepts/models/overview.md#model-properties): + +| Name | Description | Type | Required | +|------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------|----------| +| `table_format` | Sets the [table_type](https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties) Athena uses when creating the table. Valid values are `hive` or `iceberg`. | string | N | +| `storage_format` | Configures the file format to be used by the `table_format`. For Hive tables, this sets the [STORED AS](https://docs.aws.amazon.com/athena/latest/ug/create-table.html#parameters) option. For Iceberg tables, this sets [format](https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties) property. 
| string | N | + +The Athena adapter recognises the following model [physical_properties](../../concepts/models/overview.md#physical_properties): + +| Name | Description | Type | Default | +|-------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|--------|---------| +| `s3_base_location`| `s3://` base URI of where the snapshot tables for this model should be written. Overrides `s3_warehouse_location` if one is configured. | string | | + + +## S3 Locations +When creating tables, Athena needs to know where in S3 the table data is located. You cannot issue a `CREATE TABLE` statement without specifying a `LOCATION` for the table data. + +In addition, unlike other engines such as Trino, Athena will not infer a table location if you set a _schema_ location via `CREATE SCHEMA LOCATION 's3://schema/location'`. + +Therefore, in order for SQLMesh to issue correct `CREATE TABLE` statements to Athena, you need to configure where the tables should be stored. There are two options for this: + +- **Project-wide:** set `s3_warehouse_location` in the connection config. SQLMesh will set the table `LOCATION` to be `//` when it creates a snapshot of your model. +- **Per-model:** set `s3_base_location` in the model `physical_properties`. SQLMesh will set the table `LOCATION` to be `/` every time it creates a snapshot of your model. This takes precedence over any `s3_warehouse_location` set in the connection config. + + +## Limitations +Athena was initially designed to read data stored in S3 and to do so without changing that data. This means that it does not have good support for mutating tables. In particular, it will not delete data from Hive tables. 
+ +Consequently, [forward only changes](../../concepts/plans.md#forward-only-change) that mutate the schemas of existing tables have a high chance of failure because Athena supports very limited schema modifications on Hive tables. + +However, Athena does support [Apache Iceberg](https://docs.aws.amazon.com/athena/latest/ug/querying-iceberg.html) tables which allow a full range of operations. These can be used for more complex model types such as [`INCREMENTAL_BY_UNIQUE_KEY`](../../concepts/models/model_kinds.md#incremental_by_unique_key) and [`SCD_TYPE_2`](../../concepts/models/model_kinds.md#scd-type-2). + +To use an Iceberg table for a model, set `table_format iceberg` in the model [properties](../../concepts/models/overview.md#model-properties). + +In general, Iceberg tables offer the most flexibility and you'll run into the least SQLMesh limitations when using them. However, we create Hive tables by default because Athena creates Hive tables by default, so Iceberg tables are opt-in rather than opt-out. diff --git a/docs/integrations/engines/azuresql.md b/docs/integrations/engines/azuresql.md new file mode 100644 index 0000000000..5b54ffa9c6 --- /dev/null +++ b/docs/integrations/engines/azuresql.md @@ -0,0 +1,36 @@ +# Azure SQL + +[Azure SQL](https://azure.microsoft.com/en-us/products/azure-sql) is "a family of managed, secure, and intelligent products that use the SQL Server database engine in the Azure cloud." 
+ +## Local/Built-in Scheduler +**Engine Adapter Type**: `azuresql` + +### Installation +#### User / Password Authentication: +``` +pip install "sqlmesh[azuresql]" +``` +#### Microsoft Entra ID / Azure Active Directory Authentication: +``` +pip install "sqlmesh[azuresql-odbc]" +``` + +### Connection options + +| Option | Description | Type | Required | +| ----------------- | ---------------------------------------------------------------- | :----------: | :------: | +| `type` | Engine type name - must be `azuresql` | string | Y | +| `host` | The hostname of the Azure SQL server | string | Y | +| `user` | The username / client ID to use for authentication with the Azure SQL server | string | N | +| `password` | The password / client secret to use for authentication with the Azure SQL server | string | N | +| `port` | The port number of the Azure SQL server | int | N | +| `database` | The target database | string | N | +| `charset` | The character set used for the connection | string | N | +| `timeout` | The query timeout in seconds. Default: no timeout | int | N | +| `login_timeout` | The timeout for connection and login in seconds. Default: 60 | int | N | +| `appname` | The application name to use for the connection | string | N | +| `conn_properties` | The list of connection properties | list[string] | N | +| `autocommit` | Is autocommit mode enabled. Default: false | bool | N | +| `driver` | The driver to use for the connection. Default: pymssql | string | N | +| `driver_name` | The driver name to use for the connection. E.g., *ODBC Driver 18 for SQL Server* | string | N | +| `odbc_properties` | The dict of ODBC connection properties. E.g., authentication: ActiveDirectoryServicePrincipal. See more [here](https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute?view=sql-server-ver16). 
| dict | N | \ No newline at end of file diff --git a/docs/integrations/engines/bigquery.md b/docs/integrations/engines/bigquery.md index 9acf2bf925..b93d6837ed 100644 --- a/docs/integrations/engines/bigquery.md +++ b/docs/integrations/engines/bigquery.md @@ -1,81 +1,171 @@ # BigQuery -## Local/Built-in Scheduler +## Introduction -**Engine Adapter Type**: `bigquery` +This guide provides step-by-step instructions on how to connect SQLMesh to the BigQuery SQL engine. -### Installation -``` -pip install "sqlmesh[bigquery]" +It will walk you through the steps of installing SQLMesh and BigQuery connection libraries locally, configuring the connection in SQLMesh, and running the [quickstart project](../../quick_start.md). + +## Prerequisites + +This guide assumes the following about the BigQuery project being used with SQLMesh: + +- The project already exists +- Project [CLI/API access is enabled](https://cloud.google.com/endpoints/docs/openapi/enable-api) +- Project [billing is configured](https://cloud.google.com/billing/docs/how-to/manage-billing-account) (i.e. it's not a sandbox project) +- SQLMesh can authenticate using an account with permissions to execute commands against the project + +## Installation + +Follow the [quickstart installation guide](../../installation.md) up to the step that [installs SQLMesh](../../installation.md#install-sqlmesh-core), where we deviate to also install the necessary BigQuery libraries. 
+ +Instead of installing just SQLMesh core, we will also include the BigQuery engine libraries: + +```bash +> pip install "sqlmesh[bigquery]" ``` -### Connection options +### Install Google Cloud SDK -| Option | Description | Type | Required | -|---------------------------------|--------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:| -| `type` | Engine type name - must be `bigquery` | string | Y | -| `method` | Connection methods - see [allowed values below](#connection-methods). Default: `oauth`. | string | N | -| `project` | The name of the GCP project | string | N | -| `location` | The location of for the datasets (can be regional or multi-regional) | string | N | -| `execution_project` | The name of the GCP project to bill for the execution of the models. If not set, the project associated with the model will be used. | string | N | -| `keyfile` | Path to the keyfile to be used with service-account method | string | N | -| `keyfile_json` | Keyfile information provided inline (not recommended) | dict | N | -| `token` | OAuth 2.0 access token | string | N | -| `refresh_token` | OAuth 2.0 refresh token | string | N | -| `client_id` | OAuth 2.0 client ID | string | N | -| `client_secret` | OAuth 2.0 client secret | string | N | -| `token_uri` | OAuth 2.0 authorization server's toke endpoint URI | string | N | -| `scopes` | The scopes used to obtain authorization | list | N | -| `job_creation_timeout_seconds` | The maximum amount of time, in seconds, to wait for the underlying job to be created. | int | N | -| `job_execution_timeout_seconds` | The maximum amount of time, in seconds, to wait for the underlying job to complete. | int | N | -| `job_retries` | The number of times to retry the underlying job if it fails. (Default: `1`) | int | N | -| `priority` | The priority of the underlying job. 
(Default: `INTERACTIVE`) | string | N | -| `maximum_bytes_billed` | The maximum number of bytes to be billed for the underlying job. | int | N | - -## Airflow Scheduler -**Engine Name:** `bigquery` - -In order to share a common implementation across local and Airflow, SQLMesh BigQuery implements its own hook and operator. +SQLMesh connects to BigQuery via the Python [`google-cloud-bigquery` library](https://pypi.org/project/google-cloud-bigquery/), which uses the [Google Cloud SDK `gcloud` tool](https://cloud.google.com/sdk/docs) for [authenticating with BigQuery](https://googleapis.dev/python/google-api-core/latest/auth.html). -### Installation +Follow these steps to install and configure the Google Cloud SDK on your computer: + +- Download the appropriate installer for your system from the [Google Cloud installation guide](https://cloud.google.com/sdk/docs/install) +- Unpack the downloaded file with the `tar` command: + + ```bash + > tar -xzvf google-cloud-cli-{SYSTEM_SPECIFIC_INFO}.tar.gz + ``` + +- Run the installation script: -To enable support for this operator, the Airflow BigQuery provider package should be installed on the target Airflow cluster along with SQLMesh with the BigQuery extra: + ```bash + > ./google-cloud-sdk/install.sh + ``` + +- Reload your shell profile (e.g., for zsh): + + ```bash + > source $HOME/.zshrc + ``` + +- Run [`gcloud init` to setup authentication](https://cloud.google.com/sdk/gcloud/reference/init) + +## Configuration + +### Configure SQLMesh for BigQuery + +Add the following gateway specification to your SQLMesh project's `config.yaml` file: + +```yaml +bigquery: + connection: + type: bigquery + project: + +default_gateway: bigquery ``` -pip install "apache-airflow-providers-google" -pip install "sqlmesh[bigquery]" + +This creates a gateway named `bigquery` and makes it your project's default gateway. 
+ +It uses the [`oauth` authentication method](#authentication-methods), which does not specify a username or other information directly in the connection configuration. Other authentication methods are [described below](#authentication-methods). + +In BigQuery, navigate to the dashboard and select the BigQuery project your SQLMesh project will use. From the Google Cloud dashboard, use the arrow to open the pop-up menu: + +![BigQuery Dashboard](./bigquery/bigquery-1.png) + +Now we can identify the project ID needed in the `config.yaml` gateway specification above. Select the project that you want to work with, the project ID that you need to add to your yaml file is the ID label from the pop-up menu. + +![BigQuery Dashboard: selecting your project](./bigquery/bigquery-2.png) + +For this guide, the Docs-Demo is the one we will use, thus the project ID for this example is `healthy-life-440919-s0`. + +## Usage + +### Test the connection + +Run the following command to verify that SQLMesh can connect to BigQuery: + +```bash +> sqlmesh info ``` -### Connection info +The output will look something like this: + +![Terminal Output](./bigquery/bigquery-3.png) + +- **Set quota project (optional)** + + You may see warnings like this when you run `sqlmesh info`: + + ![Terminal Output with warnings](./bigquery/bigquery-4.png) + + You can avoid these warnings about quota projects by running: + + ```bash + > gcloud auth application-default set-quota-project + > gcloud config set project + ``` -The operator requires an [Airflow connection](https://airflow.apache.org/docs/apache-airflow/stable/howto/connection.html) to determine the target BigQuery account. 
Please see [GoogleBaseHook](https://airflow.apache.org/docs/apache-airflow-providers-google/stable/_api/airflow/providers/google/common/hooks/base_google/index.html#airflow.providers.google.common.hooks.base_google.GoogleBaseHook) and [GCP connection](https://airflow.apache.org/docs/apache-airflow-providers-google/stable/connections/gcp.html)for more details. Use the `sqlmesh_google_cloud_bigquery_default` (by default) connection ID instead of the `google_cloud_default` one in the Airflow guide. -By default, the connection ID is set to `sqlmesh_google_cloud_bigquery_default`, but it can be overridden using the `engine_operator_args` parameter to the `SQLMeshAirflow` instance as in the example below: -```python linenums="1" -sqlmesh_airflow = SQLMeshAirflow( - "bigquery", - default_catalog="", - engine_operator_args={ - "bigquery_conn_id": "" - }, -) +### Create and run a plan + +We've verified our connection, so we're ready to create and execute a plan in BigQuery: + +```bash +> sqlmesh plan ``` -#### Optional Arguments +### View results in BigQuery Console + +Let's confirm that our project models are as expected. + +First, navigate to the BigQuery Studio Console: + +![Steps to the Studio](./bigquery/bigquery-5.png) + +Then use the left sidebar to find your project and the newly created models: + +![New Models](./bigquery/bigquery-6.png) + +We have confirmed that our SQLMesh project is running properly in BigQuery! -* `location`: Sets the default location for datasets and tables. If not set, BigQuery defaults to US for new datasets. See `location` in [Connection options](#connection-options) for more details. 
+## Local/Built-in Scheduler + +**Engine Adapter Type**: `bigquery` -```python linenums="1" -sqlmesh_airflow = SQLMeshAirflow( - "bigquery", - default_catalog="", - engine_operator_args={ - "bigquery_conn_id": "", - "location": "" - }, -) +### Installation ``` +pip install "sqlmesh[bigquery]" +``` + +### Connection options -## Connection Methods +| Option | Description | Type | Required | +|---------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:| +| `type` | Engine type name - must be `bigquery` | string | Y | +| `method` | Connection methods - see [allowed values below](#authentication-methods). Default: `oauth`. | string | N | +| `project` | The ID of the GCP project | string | N | +| `location` | The location for the datasets (can be regional or multi-regional) | string | N | +| `execution_project` | The name of the GCP project to bill for the execution of the models. If not set, the project associated with the model will be used. | string | N | +| `quota_project` | The name of the GCP project used for the quota. If not set, the `quota_project_id` set within the credentials of the account is used to authenticate to BigQuery. 
| string | N | +| `keyfile` | Path to the keyfile to be used with service-account method | string | N | +| `keyfile_json` | Keyfile information provided inline (not recommended) | dict | N | +| `token` | OAuth 2.0 access token | string | N | +| `refresh_token` | OAuth 2.0 refresh token | string | N | +| `client_id` | OAuth 2.0 client ID | string | N | +| `client_secret` | OAuth 2.0 client secret | string | N | +| `token_uri` | OAuth 2.0 authorization server's token endpoint URI | string | N | +| `scopes` | The scopes used to obtain authorization | list | N | +| `impersonated_service_account` | If set, SQLMesh will attempt to impersonate this service account | string | N | +| `job_creation_timeout_seconds` | The maximum amount of time, in seconds, to wait for the underlying job to be created. | int | N | +| `job_execution_timeout_seconds` | The maximum amount of time, in seconds, to wait for the underlying job to complete. | int | N | +| `job_retries` | The number of times to retry the underlying job if it fails. (Default: `1`) | int | N | +| `priority` | The priority of the underlying job. (Default: `INTERACTIVE`) | string | N | +| `maximum_bytes_billed` | The maximum number of bytes to be billed for the underlying job. | int | N | + +## Authentication Methods - [oauth](https://google-auth.readthedocs.io/en/master/reference/google.auth.html#google.auth.default) (default) - Related Credential Configuration: - `scopes` (Optional) @@ -96,7 +186,32 @@ sqlmesh_airflow = SQLMeshAirflow( - `keyfile_json` (Required) - `scopes` (Optional) +If the `impersonated_service_account` argument is set, SQLMesh will: + +1. Authenticate user account credentials with one of the methods above +2. Attempt to impersonate the service account with those credentials + +The user account must have [sufficient permissions to impersonate the service account](https://cloud.google.com/docs/authentication/use-service-account-impersonation). 
+ +## Query Label + +BigQuery supports a `query_label` session variable which is attached to query jobs and can be used for auditing / attribution. + +SQLMesh supports setting it via `session_properties.query_label` on a model, as an array (or tuple) of key/value tuples. + +Example: +```sql +MODEL ( + name my_project.my_dataset.my_model, + dialect 'bigquery', + session_properties ( + query_label = [('team', 'data_platform'), ('env', 'prod')] + ) +); +``` + ## Permissions Required With any of the above connection methods, ensure these BigQuery permissions are enabled to allow SQLMesh to work correctly. -- [`BigQuery Data Editor`](https://cloud.google.com/bigquery/docs/access-control#bigquery.dataEditor) + +- [`BigQuery Data Owner`](https://cloud.google.com/bigquery/docs/access-control#bigquery.dataOwner) - [`BigQuery User`](https://cloud.google.com/bigquery/docs/access-control#bigquery.user) diff --git a/docs/integrations/engines/bigquery/bigquery-1.png b/docs/integrations/engines/bigquery/bigquery-1.png new file mode 100644 index 0000000000..8cf7e4933f Binary files /dev/null and b/docs/integrations/engines/bigquery/bigquery-1.png differ diff --git a/docs/integrations/engines/bigquery/bigquery-2.png b/docs/integrations/engines/bigquery/bigquery-2.png new file mode 100644 index 0000000000..d7e7b065dd Binary files /dev/null and b/docs/integrations/engines/bigquery/bigquery-2.png differ diff --git a/docs/integrations/engines/bigquery/bigquery-3.png b/docs/integrations/engines/bigquery/bigquery-3.png new file mode 100644 index 0000000000..ff121685b0 Binary files /dev/null and b/docs/integrations/engines/bigquery/bigquery-3.png differ diff --git a/docs/integrations/engines/bigquery/bigquery-4.png b/docs/integrations/engines/bigquery/bigquery-4.png new file mode 100644 index 0000000000..cfa14187fd Binary files /dev/null and b/docs/integrations/engines/bigquery/bigquery-4.png differ diff --git a/docs/integrations/engines/bigquery/bigquery-5.png 
b/docs/integrations/engines/bigquery/bigquery-5.png new file mode 100644 index 0000000000..0fb6851e41 Binary files /dev/null and b/docs/integrations/engines/bigquery/bigquery-5.png differ diff --git a/docs/integrations/engines/bigquery/bigquery-6.png b/docs/integrations/engines/bigquery/bigquery-6.png new file mode 100644 index 0000000000..6af27c461f Binary files /dev/null and b/docs/integrations/engines/bigquery/bigquery-6.png differ diff --git a/docs/integrations/engines/clickhouse.md b/docs/integrations/engines/clickhouse.md new file mode 100644 index 0000000000..14e931b046 --- /dev/null +++ b/docs/integrations/engines/clickhouse.md @@ -0,0 +1,449 @@ +# ClickHouse + +This page describes SQLMesh support for the ClickHouse engine, including configuration options specific to ClickHouse. + +!!! note + ClickHouse may not be used for the SQLMesh [state connection](../../reference/configuration.md#connections). + +## Background + +[ClickHouse](https://clickhouse.com/) is a distributed, column-oriented SQL engine designed to rapidly execute analytical workloads. + +It provides users fine-grained control of its behavior, but that control comes at the cost of complex configuration. + +This section provides background information about ClickHouse, providing context for how to use SQLMesh with the ClickHouse engine. + +### Object naming + +Most SQL engines use a three-level hierarchical naming scheme: tables/views are nested within _schemas_, and schemas are nested within _catalogs_. For example, the full name of a table might be `my_catalog.my_schema.my_table`. + +ClickHouse instead uses a two-level hierarchical naming scheme that has no counterpart to _catalog_. In addition, it calls the second level in the hierarchy "databases." SQLMesh and its documentation refer to this second level as "schemas." + +SQLMesh fully supports ClickHouse's two-level naming scheme without user action. 
+ +### Table engines + +Every ClickHouse table is created with a ["table engine" that controls how the table's data is stored and queried](https://clickhouse.com/docs/en/engines/table-engines). ClickHouse's (and SQLMesh's) default table engine is `MergeTree`. + +The [`MergeTree` engine family](https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/mergetree) requires that every table be created with an `ORDER BY` clause. + +SQLMesh will automatically inject an empty `ORDER BY` clause into every `MergeTree` family table's `CREATE` statement, or you can specify the columns/expressions by which the table should be ordered. + +### ClickHouse modes of operation + +Conceptually, it may be helpful to view ClickHouse as having three modes of operation: single server, cluster, and ClickHouse Cloud. SQLMesh supports all three modes. + +#### Single server mode + +Single server mode is similar to other SQL engines: aside from choosing each table's engine, you do not need to worry about how computations are executed. You issue standard SQL commands/queries, and ClickHouse executes them. + +#### Cluster mode + +Cluster mode allows you to scale your ClickHouse engine to any number of networked servers. This enables massive workloads, but requires that you specify how computations are executed by the networked servers. + +ClickHouse coordinates the computations on the networked servers with [ClickHouse Keeper](https://clickhouse.com/docs/en/architecture/horizontal-scaling) (it also supports [Apache ZooKeeper](https://zookeeper.apache.org/)). + +You specify named virtual clusters of servers in the Keeper configuration, and those clusters provide namespaces for data objects and computations. For example, you might include all networked servers in the cluster you name `MyCluster`. + +In general, you must be connected to a ClickHouse server to execute commands. By default, each command you execute runs in single-server mode on the server you are connected to. 
+ +To associate an object with a cluster, DDL commands that create or modify it must include the text `ON CLUSTER [your cluster name]`. + +If you provide a cluster name in your SQLMesh connection configuration, SQLMesh will automatically inject the `ON CLUSTER` statement into the DDL commands for all objects created while executing the project. We provide more information about clusters in SQLMesh [below](#cluster-specification). + +#### ClickHouse Cloud mode + +[ClickHouse Cloud](https://clickhouse.com/cloud) is a managed ClickHouse platform. It allows you to scale ClickHouse without administering a cluster yourself or modifying your SQL commands to run on the cluster. + +ClickHouse Cloud automates ClickHouse's cluster controls, which sometimes constrains ClickHouse's flexibility or how you execute SQL commands. For example, creating a table with a `SELECT` command must [occur in two steps on ClickHouse Cloud](https://clickhouse.com/docs/en/sql-reference/statements/create/table#from-select-query). SQLMesh handles this limitation for you. + +Aside from those constraints, ClickHouse Cloud mode is similar to single server mode - you run standard SQL commands/queries, and ClickHouse Cloud executes them. + +## Permissions + +In the default SQLMesh configuration, users must have sufficient permissions to create new ClickHouse databases. + +Alternatively, you can configure specific databases where SQLMesh should create table and view objects. 
+ +### Environment views + +Use the [`environment_suffix_target` key in your project configuration](../../guides/configuration.md#disable-environment-specific-schemas) to specify that environment views should be created within the model's database instead of in a new database: + +``` yaml +environment_suffix_target: table +``` + +### Physical tables + +Use the [`physical_schema_mapping` key in your project configuration](../../guides/configuration.md#physical-table-schemas) to specify the databases where physical tables should be created. + +The key accepts a dictionary of regular expressions that map model database names to the corresponding databases where physical tables should be created. + +SQLMesh will compare a model's database name to each regular expression and use the first match to determine which database a physical table should be created in. + +For example, this configuration places every model's physical table in the `model_physical_tables` database because the regular expression `.*` matches any database name: + +``` yaml +physical_schema_mapping: + '.*': model_physical_tables +``` + +## Cluster specification + +A ClickHouse cluster allows multiple networked ClickHouse servers to operate on the same data object. Every cluster must be named in the ClickHouse configuration files, and that name is passed to a table's DDL statements in the `ON CLUSTER` clause. + +For example, we could create a table `my_schema.my_table` on cluster `TheCluster` like this: `CREATE TABLE my_schema.my_table ON CLUSTER TheCluster (col1 Int8)`. + +To create SQLMesh objects on a cluster, provide the cluster name to the `cluster` key in the SQLMesh connection definition (see all connection parameters [below](#localbuilt-in-scheduler)). + +SQLMesh will automatically inject the `ON CLUSTER` clause and cluster name you provide into all project DDL statements. 
+ +## Model definition + +This section describes how you control a table's engine and other ClickHouse-specific functionality in SQLMesh models. + +### Table engine + +SQLMesh uses the `MergeTree` table engine with an empty `ORDER BY` clause by default. + +Specify a different table engine by passing the table engine definition to the model DDL's `storage_format` parameter. For example, you could specify the `Log` table engine like this: + +``` sql linenums="1" hl_lines="4" +MODEL ( + name my_schema.my_log_table, + kind full, + storage_format Log, +); + +select + * +from other_schema.other_table; +``` + +You may also specify more complex table engine definitions. For example: + +``` sql linenums="1" hl_lines="4" +MODEL ( + name my_schema.my_rep_table, + kind full, + storage_format ReplicatedMergeTree('/clickhouse/tables/{shard}/table_name', '{replica}', ver), +); + +select + * +from other_schema.other_table; +``` + +#### ORDER BY + +`MergeTree` family engines require that a table's `CREATE` statement include the `ORDER BY` clause. + +SQLMesh will automatically inject an empty `ORDER BY ()` when creating a table with an engine in the `MergeTree` family. This creates the table without any ordering. + +You may specify columns/expressions to `ORDER BY` by passing them to the model `physical_properties` dictionary's `order_by` key. + +For example, you could order by columns `col1` and `col2` like this: + +``` sql linenums="1" hl_lines="4-6" +MODEL ( + name my_schema.my_log_table, + kind full, + physical_properties ( + order_by = (col1, col2) + ) +); + +select + * +from other_schema.other_table; +``` + +Note that there is an `=` between the `order_by` key name and value `(col1, col2)`. + +Complex `ORDER BY` expressions may need to be passed in single quotes, with interior single quotes escaped by the `\` character. + +#### PRIMARY KEY + +Table engines may also accept a `PRIMARY KEY` specification. 
Similar to `ORDER BY`, specify a primary key in the model DDL's `physical_properties` dictionary. For example: + +``` sql linenums="1" hl_lines="6" +MODEL ( + name my_schema.my_log_table, + kind full, + physical_properties ( + order_by = (col1, col2), + primary_key = col1 + ) +); + +select + * +from other_schema.other_table; +``` + +Note that there is an `=` between the `primary_key` key name and value `col1`. + +### TTL + +ClickHouse tables accept a [TTL expression that triggers actions](https://clickhouse.com/docs/en/guides/developer/ttl) like deleting rows after a certain amount of time has passed. + +Similar to `ORDER_BY` and `PRIMARY_KEY`, specify a TTL key in the model DDL's `physical_properties` dictionary. For example: + +``` sql linenums="1" hl_lines="6" +MODEL ( + name my_schema.my_log_table, + kind full, + physical_properties ( + order_by = (col1, col2), + primary_key = col1, + ttl = timestamp + INTERVAL 1 WEEK + ) +); + +select + * +from other_schema.other_table; +``` + +Note that there is an `=` between the `ttl` key name and value `timestamp + INTERVAL 1 WEEK`. + +### Partitioning + +Some ClickHouse table engines support partitioning. Specify the partitioning columns/expressions in the model DDL's `partitioned_by` key. + +For example, you could partition by columns `col1` and `col2` like this: + +``` sql linenums="1" hl_lines="4" +MODEL ( + name my_schema.my_log_table, + kind full, + partitioned_by (col1, col2), +); + +select + * +from other_schema.other_table; +``` + +Learn more below about how SQLMesh uses [partitioned tables to improve performance](#performance-considerations). + +## Settings + +ClickHouse supports an [immense number of settings](https://clickhouse.com/docs/en/operations/settings), many of which can be altered in multiple places: ClickHouse configuration files, Python client connection arguments, DDL statements, SQL queries, and others. + +This section discusses how to control ClickHouse settings in SQLMesh. 
+ +### Connection settings + +SQLMesh connects to Python with the [`clickhouse-connect` library](https://clickhouse.com/docs/en/integrations/python). Its connection method accepts a dictionary of arbitrary settings that are passed to ClickHouse. + +Specify these settings in the `connection_settings` key. This example demonstrates how to set the `distributed_ddl_task_timeout` setting to `300`: + +``` yaml linenums="1" hl_lines="8-9" +clickhouse_gateway: + connection: + type: clickhouse + host: localhost + port: 8123 + username: user + password: pw + connection_settings: + distributed_ddl_task_timeout: 300 + state_connection: + type: duckdb +``` + +### DDL settings + +ClickHouse settings may also be specified in DDL commands like `CREATE`. + +Specify these settings in a model DDL's [`physical_properties` key](https://sqlmesh.readthedocs.io/en/stable/concepts/models/overview/?h=physical#physical_properties) (where the [`order_by`](#order-by) and [`primary_key`](#primary-key) values are specified, if present). + +This example demonstrates how to set the `index_granularity` setting to `128`: + +``` sql linenums="1" hl_lines="4-6" +MODEL ( + name my_schema.my_log_table, + kind full, + physical_properties ( + index_granularity = 128 + ) +); + +select + * +from other_schema.other_table; +``` + +Note that there is an `=` between the `index_granularity` key name and value `128`. + +### Query settings + +ClickHouse settings may be specified directly in a model's query with the `SETTINGS` keyword. + +This example demonstrates setting the `join_use_nulls` setting to `1`: + +``` sql linenums="1" hl_lines="9" +MODEL ( + name my_schema.my_log_table, + kind full, +); + +select + * +from other_schema.other_table +SETTINGS join_use_nulls = 1; +``` + +Multiple settings may be specified in a query with repeated use of the `SETTINGS` keyword: `SELECT * FROM other_table SETTINGS first_setting = 1 SETTINGS second_setting = 2;`. 
+ +#### Usage by SQLMesh + +The ClickHouse setting `join_use_nulls` affects the behavior of SQLMesh SCD models and table diffs. This section describes how SQLMesh uses query settings to control that behavior. + +^^Background^^ + +In general, table `JOIN`s can return empty cells for rows not present in both tables. + +For example, consider `LEFT JOIN`ing two tables `left` and `right`, where the column `right_column` is only present in the `right` table. Any rows only present in the `left` table will have no value for `right_column` in the joined table. + +In other SQL engines, those empty cells are filled with `NULL`s. + +In contrast, ClickHouse fills the empty cells with data type-specific default values (e.g., 0 for integer column types). It will instead fill the cells with `NULL`s if you set the `join_use_nulls` setting to `1`. + +^^SQLMesh^^ + +SQLMesh automatically generates SQL queries for both SCD Type 2 models and table diff comparisons. These queries include table `JOIN`s and calculations based on the presence of `NULL` values. + +Because those queries expect `NULL` values in empty cells, SQLMesh automatically adds `SETTINGS join_use_nulls = 1` to the generated SCD and table diff SQL code. + +The SCD model definition query is embedded as a CTE in the full SQLMesh-generated query. If run alone, the model definition query would use the ClickHouse server's current `join_use_nulls` value. + +If that value is not `1`, the SQLMesh setting on the outer query would override the server value and produce incorrect results. 
+ +Therefore, SQLMesh uses the following procedure to ensure the model definition query runs with the correct `join_use_nulls` value: + +- If the model query sets `join_use_nulls` itself, do nothing +- If the model query does not set `join_use_nulls` and the current server `join_use_nulls` value is `1`, do nothing +- If the model query does not set `join_use_nulls` and the current server `join_use_nulls` value is `0`, add `SETTINGS join_use_nulls = 0` to the CTE model query + - All other CTEs and the outer query will still execute with a `join_use_nulls` value of `1` + +## Performance considerations + +ClickHouse is optimized for writing/reading records, so deleting/replacing records can be extremely slow. + +This section describes why SQLMesh needs to delete/replace records and how the ClickHouse engine adapter works around the limitations. + +### Why delete or replace? + +SQLMesh "materializes" model kinds in a number of ways, such as: + +- Replacing an entire table ([`FULL` models](../../concepts/models/model_kinds.md#full)) +- Replacing records in a specific time range ([`INCREMENTAL_BY_TIME_RANGE` models](../../concepts/models/model_kinds.md#incremental_by_time_range)) +- Replacing records with specific key values ([`INCREMENTAL_BY_UNIQUE_KEY` models](../../concepts/models/model_kinds.md#incremental_by_unique_key)) +- Replacing records in specific partitions ([`INCREMENTAL_BY_PARTITION` models](../../concepts/models/model_kinds.md#incremental_by_partition)) + +Different SQL engines provide different methods for performing record replacement. + +Some engines natively support updating or inserting ("upserting") records. For example, in some engines you can `merge` a new table into an existing table based on a key. Records in the new table whose keys are already in the existing table will update/replace the existing records. Records in the new table without keys in the existing table will be inserted into the existing table. 
+ +Other engines do not natively support upserts, so SQLMesh replaces records in two steps: delete the records to update/replace from the existing table, then insert the new records. + +ClickHouse does not support upserts, and it performs the two step delete/insert operation so slowly as to be unusable. Therefore, SQLMesh uses a different method for replacing records. + +### Temp table swap + +SQLMesh uses what we call the "temp table swap" method of replacing records in ClickHouse. + +Because ClickHouse is optimized for writing and reading records, it is often faster to copy most of a table than to delete a small portion of its records. That is the approach used by the temp table swap method (with optional performance improvements [for partitioned tables](#partition-swap)). + +The temp table swap has four steps: + +1. Make an empty temp copy of the existing table that has the same structure (columns, data types, table engine, etc.) +2. Insert new records into the temp table +3. Insert the existing records that should be **kept** into the temp table +4. Swap the table names, such that the temp table now has the existing table's name + +Figure 1 illustrates these four steps: +

+ +![ClickHouse table swap steps](./clickhouse/clickhouse_table-swap-steps.png){ loading=lazy } +_Figure 1: steps to execute a temp table swap_ +

+ +The weakness of this method is that it requires copying all existing rows to keep (step three), which can be problematic for large tables. + +To address this weakness, SQLMesh instead uses *partition* swapping if a table is partitioned. + +### Partition swap + +ClickHouse supports *partitioned* tables, which store groups of records in separate files, or "partitions." + +A table is partitioned based on a table column or SQL expression - the "partitioning key." All records with the same value for the partitioning key are stored together in a partition. + +For example, consider a table containing each record's creation date in a datetime column. If we partition the table by month, all the records whose timestamp was in January will be stored in one partition, records from February in another partition, and so on. + +Table partitioning provides a major benefit for improving swap performance: records can be inserted, updated, or deleted in individual partitions. + +SQLMesh leverages this to avoid copying large numbers of existing records into a temp table. Instead, it only copies the records that are in partitions affected by a load's newly ingested records. + +SQLMesh automatically uses partition swapping for any incremental model that specifies the [`partitioned_by`](../../concepts/models/overview.md#partitioned_by) key. + +#### Choosing a partitioning key + +The first step of partitioning a table is choosing its partitioning key (columns or expression). The primary consideration for a key is the total number of partitions it will generate, which affects table performance. + +Too many partitions can drastically decrease performance because the overhead of handling partition files swamps the benefits of copying fewer records. Too few partitions decreases swap performance because many existing records must still be copied in each incremental load. + +!!! question "How many partitions is too many?" 
+ + ClickHouse's documentation [specifically warns against tables having too many partitions](https://clickhouse.com/docs/en/engines/table-engines/mergetree-family/custom-partitioning-key), suggesting a maximum of 1000. + +The total number of partitions in a table is determined by the actual data in the table, not by the partition column/expression alone. + +For example, consider a table partitioned by date. If we insert records created on `2024-10-23`, the table will have one partition. If we then insert records from `2024-10-24`, the table will have two partitions. One partition is created for each unique value of the key. + +For each partitioned table in your project, carefully consider the number of partitions created by the combination of your partitioning expression and the characteristics of your data. + +#### Incremental by time models + +`INCREMENTAL_BY_TIME_RANGE` kind models must be partitioned by time. If the model's `time_column` is not present in any `partitioned_by` expression, SQLMesh will automatically add it as the first partitioning expression. + +By default, `INCREMENTAL_BY_TIME_RANGE` models partition by week, so the maximum recommended 1000 partitions corresponds to about 19 years of data. SQLMesh projects have widely varying time ranges and data sizes, so you should choose a model's partitioning key based on the data your system will process. + +If a model has many records in each partition, you may see additional performance benefits by including the time column in the model's [`ORDER_BY` expression](#order-by). + +!!! info "Partitioning by time" + `INCREMENTAL_BY_TIME_RANGE` models must be partitioned by time. + + SQLMesh will automatically partition them by **week** unless the `partitioned_by` configuration key includes the time column or an expression based on it. + + Choose a model's time partitioning granularity based on the characteristics of the data it will process, making sure the total number of partitions is 1000 or fewer. 
+ +## Local/Built-in Scheduler + +**Engine Adapter Type**: `clickhouse` + +| Option | Description | Type | Required | +| ------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :----: | :------: | +| `type` | Engine type name - must be `clickhouse` | string | Y | +| `host` | ClickHouse server hostname or IP address | string | Y | +| `username` | ClickHouse user name | string | Y | +| `password` | ClickHouse user password | string | N | +| `port` | The ClickHouse HTTP or HTTPS port (Default: `8123`) | int | N | +| `cluster` | ClickHouse cluster name | string | N | +| `connect_timeout` | Connection timeout in seconds (Default: `10`) | int | N | +| `send_receive_timeout` | Send/receive timeout in seconds (Default: `300`) | int | N | +| `query_limit` | Query result limit (Default: `0` - no limit) | int | N | +| `use_compression` | Whether to use compression (Default: `True`) | bool | N | +| `compression_method` | Compression method to use | string | N | +| `http_proxy` | HTTP proxy address (equivalent to setting the HTTP_PROXY environment variable) | string | N | +| `verify` | Verify server TLS/SSL certificate (Default: `True`) | bool | N | +| `ca_cert` | Ignored if verify is `False`. If verify is `True`, the file path to Certificate Authority root to validate ClickHouse server certificate, in .pem format. Not necessary if the ClickHouse server certificate is a globally trusted root as verified by the operating system. | string | N | +| `client_cert` | File path to a TLS Client certificate in .pem format (for mutual TLS authentication). The file should contain a full certificate chain, including any intermediate certificates. | string | N | +| `client_cert_key` | File path to the private key for the Client Certificate. 
Required if the private key is not included in the Client Certificate key file.
-### Installation -``` -pip install "sqlmesh[databricks]" -``` +After that is a [Connection Quickstart](#connection-quickstart) that demonstrates how to connect to Databricks, or you can skip directly to information about using Databricks with the [built-in](#localbuilt-in-scheduler). -### Connection info +## Databricks connection methods -If you are always running SQLMesh commands directly on a Databricks Cluster (like in a Databricks Notebook using the [notebook magic commands](../../reference/notebook.md)) then the only relevant configuration is `catalog` and it is optional. -The SparkSession provided by Databricks will be used to execute all SQLMesh commands. +Databricks provides multiple computing options and connection methods. This section describes the three methods for connecting with SQLMesh. -Otherwise SQLMesh's Databricks implementation uses the [Databricks SQL Connector](https://docs.databricks.com/dev-tools/python-sql-connector.html) to connect to Databricks by default. -If your project contains PySpark DataFrames in Python models then it will use [Databricks Connect](https://docs.databricks.com/dev-tools/databricks-connect.html) to connect to Databricks. -SQLMesh's Databricks Connect implementation supports Databricks Runtime 13.0 or higher. If SQLMesh detects you have Databricks Connect installed then it will use it for all Python models (so both Pandas and PySpark DataFrames). +### Databricks SQL Connector -Databricks connect execution can be routed to a different cluster than the SQL Connector by setting the `databricks_connect_*` properties. -For example this allows SQLMesh to be configured to run SQL on a [Databricks SQL Warehouse](https://docs.databricks.com/sql/admin/create-sql-warehouse.html) while still routing DataFrame operations to a normal Databricks Cluster. +SQLMesh connects to Databricks with the [Databricks SQL Connector](https://docs.databricks.com/dev-tools/python-sql-connector.html) library by default. 
-Note: If using Databricks Connect please note the [requirements](https://docs.databricks.com/dev-tools/databricks-connect.html#requirements) and [limitations](https://docs.databricks.com/dev-tools/databricks-connect.html#limitations) +The SQL Connector is bundled with SQLMesh and automatically installed when you include the `databricks` extra in the command `pip install "sqlmesh[databricks]"`. -### Connection options +The SQL Connector has all the functionality needed for SQLMesh to execute SQL models on Databricks and Python models that do not return PySpark DataFrames. + +If you have Python models returning PySpark DataFrames, check out the [Databricks Connect](#databricks-connect-1) section. + +### Databricks Connect + +If you want Databricks to process PySpark DataFrames in SQLMesh Python models, then SQLMesh must use the [Databricks Connect](https://docs.databricks.com/dev-tools/databricks-connect.html) library to connect to Databricks (instead of the Databricks SQL Connector library). + +SQLMesh **DOES NOT** include/bundle the Databricks Connect library. You must [install the version of Databricks Connect](https://docs.databricks.com/en/dev-tools/databricks-connect/python/install.html) that matches the Databricks Runtime used in your Databricks cluster. + +Find [more configuration details below](#databricks-connect-1). + +### Databricks notebook interface + +If you are always running SQLMesh commands directly in a Databricks Cluster interface (like in a Databricks Notebook using the [notebook magic commands](../../reference/notebook.md)), the SparkSession provided by Databricks is used to execute all SQLMesh commands. + +Find [more configuration details below](#databricks-notebook-interface-1). + +## Connection quickstart + +Connecting to cloud warehouses involves a few steps, so this connection quickstart provides the info you need to get up and running with Databricks. 
+ +It demonstrates connecting to a Databricks [All-Purpose Compute](https://docs.databricks.com/en/compute/index.html) instance with the `databricks-sql-connector` Python library bundled with SQLMesh. + +!!! tip + This quickstart assumes you are familiar with basic SQLMesh commands and functionality. + + If you're not, work through the [SQLMesh Quickstart](../../quick_start.md) before continuing! + +### Prerequisites + +Before working through this connection quickstart, ensure that: + +1. You have a Databricks account with access to an appropriate Databricks Workspace + - The Workspace must support authenticating with [personal access tokens](https://docs.databricks.com/en/dev-tools/auth/pat.html) (Databricks [Community Edition workspaces do not](https://docs.databricks.com/en/admin/access-control/tokens.html)) + - Your account must have Workspace Access and Create Compute permissions (these permissions are enabled by default) +2. Your Databricks compute resources have [Unity Catalog](https://docs.databricks.com/aws/en/data-governance/unity-catalog/) activated +3. Your computer has [SQLMesh installed](../../installation.md) with the [Databricks extra available](../../installation.md#install-extras) + - Install from the command line with the command `pip install "sqlmesh[databricks]"` +4. You have initialized a [SQLMesh example project](../../quickstart/cli#1-create-the-sqlmesh-project) on your computer + - Open a command line interface and navigate to the directory where the project files should go + - Initialize the project with the command `sqlmesh init duckdb` + +!!! important "Unity Catalog required" + + Databricks compute resources used by SQLMesh must have [Unity Catalog](https://docs.databricks.com/aws/en/data-governance/unity-catalog/) activated. + +### Get connection info + +The first step to configuring a Databricks connection is gathering the necessary information from your Databricks compute instance. 
+ +#### Create Compute + +We must have something to connect to, so we first create and activate a Databricks compute instance. If you already have one running, skip to the [next section](#get-jdbcodbc-info). + +We begin in the default view for our Databricks Workspace. Access the Compute view by clicking the `Compute` entry in the left-hand menu: + +![Databricks Workspace default view](./databricks/db-guide_workspace.png){ loading=lazy } + +In the Compute view, click the `Create compute` button: + +![Databricks Compute default view](./databricks/db-guide_compute.png){ loading=lazy } + +Modify compute cluster options if desired and click the `Create compute` button: + +![Databricks Create Compute view](./databricks/db-guide_compute-create.png){ loading=lazy } + +#### Get JDBC/ODBC info + +Scroll to the bottom of the view and click the open the `Advanced Options` view: + +![Databricks Compute Advanced Options link](./databricks/db-guide_compute-advanced-options-link.png){ loading=lazy } + +Click the `JDBC/ODBC` tab: + +![Databricks Compute Advanced Options JDBC/ODBC tab](./databricks/db-guide_advanced-options.png){ loading=lazy } + +Open your project's `config.yaml` configuration file in a text editor and add a new gateway named `databricks` below the existing `local` gateway: + +![Project config.yaml databricks gateway](./databricks/db-guide_config-yaml.png){ loading=lazy } + +Copy the `server_hostname` and `http_path` connection values from the Databricks JDBC/ODBC tab to the `config.yaml` file: + +![Copy server_hostname and http_path to config.yaml](./databricks/db-guide_copy-server-http.png){ loading=lazy } + +#### Get personal access token + +The final piece of information we need for the `config.yaml` file is your personal access token. + +!!! 
warning + **Do not share your personal access token with anyone.** + + Best practice for storing secrets like access tokens is placing them in [environment variables that the configuration file loads dynamically](../../guides/configuration.md#environment-variables). For simplicity, this guide instead places the value directly in the configuration file. + + This code demonstrates how to use the environment variable `DATABRICKS_ACCESS_TOKEN` for the configuration's `access_token` parameter: + + ```yaml linenums="1" + gateways: + databricks: + connection: + type: databricks + access_token: {{ env_var('DATABRICKS_ACCESS_TOKEN') }} + ``` + +

+To create a personal access token, click on your profile logo and go to your profile's `Settings` page: -| Option | Description | Type | Required | -|--------------------------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:| -| `type` | Engine type name - must be `databricks` | string | Y | -| `server_hostname` | Databricks instance host name | string | N | -| `http_path` | HTTP path, either to a DBSQL endpoint (such as `/sql/1.0/endpoints/1234567890abcdef`) or to an All-Purpose cluster (such as `/sql/protocolv1/o/1234567890123456/1234-123456-slid123`) | string | N | -| `access_token` | HTTP Bearer access token, such as Databricks Personal Access Token | string | N | -| `catalog` | Spark 3.4+ Only if not using SQL Connector. The name of the catalog to use for the connection. [Defaults to use Databricks cluster default](https://docs.databricks.com/en/data-governance/unity-catalog/create-catalogs.html#the-default-catalog-configuration-when-unity-catalog-is-enabled). | string | N | -| `http_headers` | SQL Connector Only: An optional dictionary of HTTP headers that will be set on every request | dict | N | -| `session_configuration` | SQL Connector Only: An optional dictionary of Spark session parameters. Execute the SQL command `SET -v` to get a full list of available commands. | dict | N | -| `databricks_connect_server_hostname` | Databricks Connect Only: Databricks Connect server hostname. Uses `server_hostname` if not set. | string | N | -| `databricks_connect_access_token` | Databricks Connect Only: Databricks Connect access token. Uses `access_token` if not set. | string | N | -| `databricks_connect_cluster_id` | Databricks Connect Only: Databricks Connect cluster ID. 
Uses `http_path` if not set. Cannot be a Databricks SQL Warehouse. | string | N | -| `force_databricks_connect` | When running locally, force the use of Databricks Connect for all model operations (so don't use SQL Connector for SQL models) | bool | N | -| `disable_databricks_connect` | When running locally, disable the use of Databricks Connect for all model operations (so use SQL Connector for all models) | bool | N | -| `disable_spark_session` | Do not use SparkSession if it is available (like when running in a notebook). | bool | N | +![Navigate to profile Settings page](./databricks/db-guide_profile-settings-link.png){ loading=lazy } -## Airflow Scheduler -**Engine Name:** `databricks` / `databricks-submit` / `databricks-sql`. +Go to the `Developer` view in the User menu. Depending on your account's role, your page may not display the Workspace Admin section of the page. -Databricks has multiple operators to help differentiate running a SQL query vs. running a Python script. +![Navigate to User Developer view](./databricks/db-guide_profile-settings-developer.png){ loading=lazy } -### Engine: `databricks` (Recommended) +Click the `Manage` button in the Access Tokens section: -When evaluating models, the SQLMesh Databricks integration implements the [DatabricksSubmitRunOperator](https://airflow.apache.org/docs/apache-airflow-providers-databricks/1.0.0/operators.html). This is needed to be able to run either SQL or Python scripts on the Databricks cluster. 
+![Navigate to Access Tokens management](./databricks/db-guide_access-tokens-link.png){ loading=lazy } -When performing environment management operations, the SQLMesh Databricks integration is similar to the [DatabricksSqlOperator](https://airflow.apache.org/docs/apache-airflow-providers-databricks/stable/operators/sql.html#databrickssqloperator), and relies on the same [DatabricksSqlHook](https://airflow.apache.org/docs/apache-airflow-providers-databricks/stable/_api/airflow/providers/databricks/hooks/databricks_sql/index.html#airflow.providers.databricks.hooks.databricks_sql.DatabricksSqlHook) implementation. -All environment management operations are SQL-based, and the overhead of submitting jobs can be avoided. +Click the `Generate new token` button: -### Engine: `databricks-submit` +![Open the token generation menu](./databricks/db-guide_access-tokens-generate-button.png){ loading=lazy } -Whether evaluating models or performing environment management operations, the SQLMesh Databricks integration implements the [DatabricksSubmitRunOperator](https://airflow.apache.org/docs/apache-airflow-providers-databricks/1.0.0/operators.html). +Name your token in the `Comment` field, and click the `Generate` button: -### Engine: `databricks-sql` +![Generate a new token](./databricks/db-guide_access-tokens-generate.png){ loading=lazy } -Forces the SQLMesh Databricks integration to use the operator based on the [DatabricksSqlOperator](https://airflow.apache.org/docs/apache-airflow-providers-databricks/stable/operators/sql.html#databrickssqloperator) for all operations. If your project is pure SQL operations, then this is an option. 
+Click the copy button and paste the token into the `access_token` key: -To enable support for this operator, the Airflow Databricks provider package should be installed on the target Airflow cluster along with the SQLMesh package with databricks extra as follows: +![Copy token to config.yaml access_token key](./databricks/db-guide_copy-token.png){ loading=lazy } + +!!! warning + **Do not share your personal access token with anyone.** + + Best practice for storing secrets like access tokens is placing them in [environment variables that the configuration file loads dynamically](../../guides/configuration.md#environment-variables). For simplicity, this guide instead places the value directly in the configuration file. + + This code demonstrates how to use the environment variable `DATABRICKS_ACCESS_TOKEN` for the configuration's `access_token` parameter: + + ```yaml linenums="1" + gateways: + databricks: + connection: + type: databricks + access_token: {{ env_var('DATABRICKS_ACCESS_TOKEN') }} + ``` + +### Check connection + +We have now specified the `databricks` gateway connection information, so we can confirm that SQLMesh is able to successfully connect to Databricks. We will test the connection with the `sqlmesh info` command. + +First, open a command line terminal. Now enter the command `sqlmesh --gateway databricks info`. + +We manually specify the `databricks` gateway because it is not our project's default gateway: + +![Run sqlmesh info command in CLI](./databricks/db-guide_sqlmesh-info.png){ loading=lazy } + +The output shows that our data warehouse connection succeeded: + +![Successful data warehouse connection](./databricks/db-guide_sqlmesh-info-succeeded.png){ loading=lazy } + +However, the output includes a `WARNING` about using the Databricks SQL engine for storing SQLMesh state: + +![Databricks state connection warning](./databricks/db-guide_sqlmesh-info-warning.png){ loading=lazy } + +!!! 
warning + Databricks is not designed for transactional workloads and should not be used to store SQLMesh state even in testing deployments. + + Learn more about storing SQLMesh state [here](../../guides/configuration.md#state-connection). + +### Specify state connection + +We can store SQLMesh state in a different SQL engine by specifying a `state_connection` in our `databricks` gateway. + +This example uses the DuckDB engine to store state in the local `databricks_state.db` file: + +![Specify DuckDB state connection](./databricks/db-guide_state-connection.png){ loading=lazy } + +Now we no longer see the warning when running `sqlmesh --gateway databricks info`, and we see a new entry `State backend connection succeeded`: + +![No state connection warning](./databricks/db-guide_sqlmesh-info-no-warning.png){ loading=lazy } + +### Run a `sqlmesh plan` + +For convenience, we can omit the `--gateway` option from our CLI commands by specifying `databricks` as our project's `default_gateway`: + +![Specify databricks as default gateway](./databricks/db-guide_default-gateway.png){ loading=lazy } + +And run a `sqlmesh plan` in Databricks: + +![Run sqlmesh plan in databricks](./databricks/db-guide_sqlmesh-plan.png){ loading=lazy } + +And confirm that our schemas and objects exist in the Databricks catalog: + +![Sqlmesh plan objects in databricks](./databricks/db-guide_sqlmesh-plan-objects.png){ loading=lazy } + +Congratulations - your SQLMesh project is up and running on Databricks! + +!!! tip + SQLMesh connects to your Databricks Cluster's default catalog by default. Connect to a different catalog by specifying its name in the connection configuration's [`catalog` parameter](#connection-options). 
+ +## Local/Built-in Scheduler +**Engine Adapter Type**: `databricks` + +### Installation ``` -pip install apache-airflow-providers-databricks -sqlmesh[databricks] +pip install "sqlmesh[databricks]" ``` -The operator requires an [Airflow connection](https://airflow.apache.org/docs/apache-airflow/stable/howto/connection.html) to determine the target Databricks cluster. Refer to [Databricks connection](https://airflow.apache.org/docs/apache-airflow-providers-databricks/stable/connections/databricks.html) for more details. SQLMesh requires that `http_path` be defined in the connection since it uses this to determine the cluster for both SQL and submit operators. +### Connection method details -Example format: `databricks://?token=&http_path=` +Databricks provides multiple computing options and connection methods. The [section above](#databricks-connection-methods) explains how to use them with SQLMesh, and this section provides additional configuration details. -By default, the connection ID is set to `databricks_default`, but it can be overridden using both the `engine_operator_args` and the `ddl_engine_operator_args` parameters to the `SQLMeshAirflow` instance. -In addition, one special configuration that the SQLMesh Airflow evaluation operator requires is a dbfs path to store an application to load a given SQLMesh model. Also, a payload is stored that contains the information required for SQLMesh to do the loading. This must be defined in the `evaluate_engine_operator_args` parameter. Example of defining both: +#### Databricks SQL Connector -```python linenums="1" -from sqlmesh.schedulers.airflow.integration import SQLMeshAirflow +SQLMesh uses the [Databricks SQL Connector](https://docs.databricks.com/dev-tools/python-sql-connector.html) to connect to Databricks by default. Learn [more above](#databricks-sql-connector). 
-sqlmesh_airflow = SQLMeshAirflow( - "databricks", - default_catalog="", - engine_operator_args={ - "databricks_conn_id": "", - "dbfs_location": "dbfs:/FileStore/sqlmesh", - }, - ddl_engine_operator_args={ - "databricks_conn_id": "", - } -) +#### Databricks Connect -for dag in sqlmesh_airflow.dags: - globals()[dag.dag_id] = dag -``` +If you want Databricks to process PySpark DataFrames in SQLMesh Python models, then SQLMesh needs to use the [Databricks Connect](https://docs.databricks.com/dev-tools/databricks-connect.html) to connect to Databricks (instead of the Databricks SQL Connector). -**Note:** If your Databricks connection is configured to run on serverless [DBSQL](https://www.databricks.com/product/databricks-sql), then you need to define `existing_cluster_id` or `new_cluster` in your `engine_operator_args`. Example: -```python linenums="1" -sqlmesh_airflow = SQLMeshAirflow( - "databricks", - default_catalog="", - engine_operator_args={ - "dbfs_location": "dbfs:/FileStore/sqlmesh", - "existing_cluster_id": "1234-123456-slid123", - } -) -``` +SQLMesh **DOES NOT** include/bundle the Databricks Connect library. You must [install the version of Databricks Connect](https://docs.databricks.com/en/dev-tools/databricks-connect/python/install.html) that matches the Databricks Runtime used in your Databricks cluster. + +If SQLMesh detects that you have Databricks Connect installed, then it will automatically configure the connection and use it for all Python models that return a Pandas or PySpark DataFrame. + +To have databricks-connect installed but ignored by SQLMesh, set `disable_databricks_connect` to `true` in the connection configuration. + +Databricks Connect can execute SQL and DataFrame operations on different clusters by setting the SQLMesh `databricks_connect_*` connection options. 
For example, these options could configure SQLMesh to run SQL on a [Databricks SQL Warehouse](https://docs.databricks.com/sql/admin/create-sql-warehouse.html) while still routing DataFrame operations to a normal Databricks Cluster. + +!!! note + If using Databricks Connect, make sure to learn about the Databricks [requirements](https://docs.databricks.com/dev-tools/databricks-connect.html#requirements) and [limitations](https://docs.databricks.com/dev-tools/databricks-connect.html#limitations). + +#### Databricks notebook interface + +If you are always running SQLMesh commands directly on a Databricks Cluster (like in a Databricks Notebook using the [notebook magic commands](../../reference/notebook.md)), the SparkSession provided by Databricks is used to execute all SQLMesh commands. + +The only relevant SQLMesh configuration parameter is the optional `catalog` parameter. + +### Connection options + +| Option | Description | Type | Required | +|--------------------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:| +| `type` | Engine type name - must be `databricks` | string | Y | +| `server_hostname` | Databricks instance host name | string | N | +| `http_path` | HTTP path, either to a DBSQL endpoint (such as `/sql/1.0/endpoints/1234567890abcdef`) or to an All-Purpose cluster (such as `/sql/protocolv1/o/1234567890123456/1234-123456-slid123`) | string | N | +| `access_token` | HTTP Bearer access token, such as Databricks Personal Access Token | string | N | +| `catalog` | The name of the catalog to use for the connection. [Defaults to use Databricks cluster default](https://docs.databricks.com/en/data-governance/unity-catalog/create-catalogs.html#the-default-catalog-configuration-when-unity-catalog-is-enabled). 
| string | N | +| `auth_type` | SQL Connector Only: Set to 'databricks-oauth' or 'azure-oauth' to trigger OAuth (or dont set at all to use `access_token`) | string | N | +| `oauth_client_id` | SQL Connector Only: Optional [M2M](https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication) OAuth Client ID to use when `auth_type` is set | string | N | +| `oauth_client_secret` | SQL Connector Only: Optional [M2M](https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication) OAuth Client Secret to use when `auth_type` is set | string | N | +| `http_headers` | SQL Connector Only: An optional dictionary of HTTP headers that will be set on every request | dict | N | +| `session_configuration` | SQL Connector Only: An optional dictionary of Spark session parameters. Execute the SQL command `SET -v` to get a full list of available commands. | dict | N | +| `databricks_connect_server_hostname` | Databricks Connect Only: Databricks Connect server hostname. Uses `server_hostname` if not set. | string | N | +| `databricks_connect_access_token` | Databricks Connect Only: Databricks Connect access token. Uses `access_token` if not set. | string | N | +| `databricks_connect_cluster_id` | Databricks Connect Only: Databricks Connect cluster ID. Uses `http_path` if not set. Cannot be a Databricks SQL Warehouse. | string | N | +| `databricks_connect_use_serverless` | Databricks Connect Only: Use a serverless cluster for Databricks Connect instead of `databricks_connect_cluster_id`. 
| bool | N | +| `force_databricks_connect` | When running locally, force the use of Databricks Connect for all model operations (so don't use SQL Connector for SQL models) | bool | N | +| `disable_databricks_connect` | When running locally, disable the use of Databricks Connect for all model operations (so use SQL Connector for all models) | bool | N | +| `disable_spark_session` | Do not use SparkSession if it is available (like when running in a notebook). | bool | N | ## Model table properties to support altering tables diff --git a/docs/integrations/engines/databricks/db-guide_access-tokens-generate-button.png b/docs/integrations/engines/databricks/db-guide_access-tokens-generate-button.png new file mode 100644 index 0000000000..c9f76a2e7f Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_access-tokens-generate-button.png differ diff --git a/docs/integrations/engines/databricks/db-guide_access-tokens-generate.png b/docs/integrations/engines/databricks/db-guide_access-tokens-generate.png new file mode 100644 index 0000000000..5cc5ff93cf Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_access-tokens-generate.png differ diff --git a/docs/integrations/engines/databricks/db-guide_access-tokens-link.png b/docs/integrations/engines/databricks/db-guide_access-tokens-link.png new file mode 100644 index 0000000000..112823c056 Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_access-tokens-link.png differ diff --git a/docs/integrations/engines/databricks/db-guide_advanced-options.png b/docs/integrations/engines/databricks/db-guide_advanced-options.png new file mode 100644 index 0000000000..fb748eef99 Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_advanced-options.png differ diff --git a/docs/integrations/engines/databricks/db-guide_compute-advanced-options-link.png b/docs/integrations/engines/databricks/db-guide_compute-advanced-options-link.png new file mode 100644 
index 0000000000..9f12db01c7 Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_compute-advanced-options-link.png differ diff --git a/docs/integrations/engines/databricks/db-guide_compute-create.png b/docs/integrations/engines/databricks/db-guide_compute-create.png new file mode 100644 index 0000000000..7d39900647 Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_compute-create.png differ diff --git a/docs/integrations/engines/databricks/db-guide_compute.png b/docs/integrations/engines/databricks/db-guide_compute.png new file mode 100644 index 0000000000..3f696126b5 Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_compute.png differ diff --git a/docs/integrations/engines/databricks/db-guide_config-yaml.png b/docs/integrations/engines/databricks/db-guide_config-yaml.png new file mode 100644 index 0000000000..177ad1441d Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_config-yaml.png differ diff --git a/docs/integrations/engines/databricks/db-guide_copy-server-http.png b/docs/integrations/engines/databricks/db-guide_copy-server-http.png new file mode 100644 index 0000000000..dfd4a8984f Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_copy-server-http.png differ diff --git a/docs/integrations/engines/databricks/db-guide_copy-token.png b/docs/integrations/engines/databricks/db-guide_copy-token.png new file mode 100644 index 0000000000..d302294033 Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_copy-token.png differ diff --git a/docs/integrations/engines/databricks/db-guide_default-gateway.png b/docs/integrations/engines/databricks/db-guide_default-gateway.png new file mode 100644 index 0000000000..cc400b285f Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_default-gateway.png differ diff --git a/docs/integrations/engines/databricks/db-guide_profile-settings-developer.png 
b/docs/integrations/engines/databricks/db-guide_profile-settings-developer.png new file mode 100644 index 0000000000..3feb727d08 Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_profile-settings-developer.png differ diff --git a/docs/integrations/engines/databricks/db-guide_profile-settings-link.png b/docs/integrations/engines/databricks/db-guide_profile-settings-link.png new file mode 100644 index 0000000000..dd0c66dda2 Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_profile-settings-link.png differ diff --git a/docs/integrations/engines/databricks/db-guide_sqlmesh-info-no-warning.png b/docs/integrations/engines/databricks/db-guide_sqlmesh-info-no-warning.png new file mode 100644 index 0000000000..3a72f60d6c Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_sqlmesh-info-no-warning.png differ diff --git a/docs/integrations/engines/databricks/db-guide_sqlmesh-info-succeeded.png b/docs/integrations/engines/databricks/db-guide_sqlmesh-info-succeeded.png new file mode 100644 index 0000000000..479c7e2a2d Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_sqlmesh-info-succeeded.png differ diff --git a/docs/integrations/engines/databricks/db-guide_sqlmesh-info-warning.png b/docs/integrations/engines/databricks/db-guide_sqlmesh-info-warning.png new file mode 100644 index 0000000000..82566cba8b Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_sqlmesh-info-warning.png differ diff --git a/docs/integrations/engines/databricks/db-guide_sqlmesh-info.png b/docs/integrations/engines/databricks/db-guide_sqlmesh-info.png new file mode 100644 index 0000000000..d257a569f2 Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_sqlmesh-info.png differ diff --git a/docs/integrations/engines/databricks/db-guide_sqlmesh-plan-objects.png b/docs/integrations/engines/databricks/db-guide_sqlmesh-plan-objects.png new file mode 100644 index 
0000000000..2756c54ba7 Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_sqlmesh-plan-objects.png differ diff --git a/docs/integrations/engines/databricks/db-guide_sqlmesh-plan.png b/docs/integrations/engines/databricks/db-guide_sqlmesh-plan.png new file mode 100644 index 0000000000..7cd60cd816 Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_sqlmesh-plan.png differ diff --git a/docs/integrations/engines/databricks/db-guide_state-connection.png b/docs/integrations/engines/databricks/db-guide_state-connection.png new file mode 100644 index 0000000000..b5c60f735e Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_state-connection.png differ diff --git a/docs/integrations/engines/databricks/db-guide_workspace.png b/docs/integrations/engines/databricks/db-guide_workspace.png new file mode 100644 index 0000000000..70ad286dde Binary files /dev/null and b/docs/integrations/engines/databricks/db-guide_workspace.png differ diff --git a/docs/integrations/engines/duckdb.md b/docs/integrations/engines/duckdb.md index a9b6b74ef5..5f63a4688d 100644 --- a/docs/integrations/engines/duckdb.md +++ b/docs/integrations/engines/duckdb.md @@ -1,17 +1,24 @@ # DuckDB +!!! warning "DuckDB state connection limitations" + DuckDB is a [single user](https://duckdb.org/docs/connect/concurrency.html#writing-to-duckdb-from-multiple-processes) database. Using it for a state connection in your SQLMesh project limits you to a single workstation. This means your project cannot be shared amongst your team members or your CI/CD infrastructure. This is usually fine for proof of concept or test projects but it will not scale to production usage. + + For production projects, use [Tobiko Cloud](https://tobikodata.com/product.html) or a more robust state database such as [Postgres](./postgres.md). 
+ ## Local/Built-in Scheduler **Engine Adapter Type**: `duckdb` ### Connection options -| Option | Description | Type | Required | -|--------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:| -| `type` | Engine type name - must be `duckdb` | string | Y | -| `database` | The optional database name. If not specified, the in-memory database is used. Cannot be defined if using `catalogs`. | string | N | -| `catalogs` | Mapping to define multiple catalogs. Can [attach DuckDB catalogs](#duckdb-catalogs-example) or [catalogs for other connections](#other-connection-catalogs-example). First entry is the default catalog. Cannot be defined if using `database`. | dict | N | -| `extensions` | Extension to load into duckdb. Only autoloadable extensions are supported. | list | N | -| `connector_config` | Configuration to pass into the duckdb connector. | dict | N | +| Option | Description | Type | Required | +|--------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:---------:|:--------:| +| `type` | Engine type name - must be `duckdb` | string | Y | +| `database` | The optional database name. If not specified, the in-memory database is used. Cannot be defined if using `catalogs`. | string | N | +| `catalogs` | Mapping to define multiple catalogs. Can [attach DuckDB catalogs](#duckdb-catalogs-example) or [catalogs for other connections](#other-connection-catalogs-example). First entry is the default catalog. Cannot be defined if using `database`. | dict | N | +| `extensions` | Extension to load into duckdb. Only autoloadable extensions are supported. 
| list | N | +| `connector_config` | Configuration to pass into the duckdb connector. | dict | N | +| `secrets` | Configuration for authenticating external sources (e.g., S3) using DuckDB secrets. Can be a list of secret configurations or a dictionary with custom secret names. | list/dict | N | +| `filesystems` | Configuration for registering `fsspec` filesystems to the DuckDB connection. | dict | N | #### DuckDB Catalogs Example @@ -58,11 +65,70 @@ SQLMesh will place models with the explicit catalog "ephemeral", such as `epheme ) ``` +#### DuckLake Catalog Example + +=== "YAML" + + ```yaml linenums="1" + gateways: + my_gateway: + connection: + type: duckdb + catalogs: + ducklake: + type: ducklake + path: 'catalog.ducklake' + data_path: data/ducklake + encrypted: True + data_inlining_row_limit: 10 + metadata_schema: main + ``` + +=== "Python" + + ```python linenums="1" + from sqlmesh.core.config import ( + Config, + ModelDefaultsConfig, + GatewayConfig, + DuckDBConnectionConfig + ) + from sqlmesh.core.config.connection import DuckDBAttachOptions + + config = Config( + model_defaults=ModelDefaultsConfig(dialect=), + gateways={ + "my_gateway": GatewayConfig( + connection=DuckDBConnectionConfig( + catalogs={ + "ducklake": DuckDBAttachOptions( + type="ducklake", + path="catalog.ducklake", + data_path="data/ducklake", + encrypted=True, + data_inlining_row_limit=10, + metadata_schema="main", + ), + } + ) + ), + } + ) + ``` + +**DuckLake Configuration Options:** + +- `path`: Path to the DuckLake catalog file +- `data_path`: Path where DuckLake data files are stored +- `encrypted`: Whether to enable encryption for the catalog (default: `False`) +- `data_inlining_row_limit`: Maximum number of rows to inline in the catalog (default: `0`) +- `metadata_schema`: The schema in the catalog server in which to store the DuckLake metadata tables (default: `main`) + #### Other Connection Catalogs Example Catalogs can also be defined to connect to anything that [DuckDB can be attached 
to](https://duckdb.org/docs/sql/statements/attach.html). -Below are examples of connecting to a SQLite database and a PostgreSQL database. +Below are examples of connecting to a SQLite database and a PostgreSQL database. The SQLite database is read-write, while the PostgreSQL database is read-only. === "YAML" @@ -102,12 +168,12 @@ The SQLite database is read-write, while the PostgreSQL database is read-only. catalogs={ "memory": ":memory:", "sqlite": DuckDBAttachOptions( - type="sqlite", + type="sqlite", path="test.db" ), "postgres": DuckDBAttachOptions( - type="postgres", - path="dbname=postgres user=postgres host=127.0.0.1", + type="postgres", + path="dbname=postgres user=postgres host=127.0.0.1", read_only=True ), } @@ -117,6 +183,10 @@ The SQLite database is read-write, while the PostgreSQL database is read-only. ) ``` +##### Catalogs for PostgreSQL + +In PostgreSQL, the catalog name must match the actual catalog name it is associated with, as shown in the example above, where the database name (`dbname` in the path) is the same as the catalog name. + ##### Connectors without schemas Some connections, like SQLite, do not support schema names and therefore objects will be attached under the default schema name of `main`. @@ -125,16 +195,186 @@ Example: mounting a SQLite database with the name `sqlite` that has a table `exa ##### Sensitive fields in paths -If a connector, like Postgres, requires sensitive information in the path, it might support defining environment variables instead. +If a connector, like Postgres, requires sensitive information in the path, it might support defining environment variables instead. [See DuckDB Documentation for more information](https://duckdb.org/docs/extensions/postgres#configuring-via-environment-variables). #### Cloud service authentication DuckDB can read data directly from cloud services via extensions (e.g., [httpfs](https://duckdb.org/docs/extensions/httpfs/s3api), [azure](https://duckdb.org/docs/extensions/azure)). 
-Loading credentials at runtime using `load_aws_credentials()` or similar functions may fail when using SQLMesh. +The `secrets` option allows you to configure DuckDB's [Secrets Manager](https://duckdb.org/docs/configuration/secrets_manager.html) to authenticate with external services like S3. This is the recommended approach for cloud storage authentication in DuckDB v0.10.0 and newer, replacing the [legacy authentication method](https://duckdb.org/docs/stable/extensions/httpfs/s3api_legacy_authentication.html) via variables. + +##### Secrets Configuration + +The `secrets` option supports two formats: + +1. **List format** (default secrets): A list of secret configurations where each secret uses DuckDB's default naming +2. **Dictionary format** (named secrets): A dictionary where keys are custom secret names and values are the secret configurations + +This flexibility allows you to organize multiple secrets of the same type or reference specific secrets by name in your SQL queries. + +##### List Format Example (Default Secrets) + +Using a list creates secrets with DuckDB's default naming: + +=== "YAML" + + ```yaml linenums="1" + gateways: + duckdb: + connection: + type: duckdb + catalogs: + local: local.db + remote: "s3://bucket/data/remote.duckdb" + extensions: + - name: httpfs + secrets: + - type: s3 + region: "YOUR_AWS_REGION" + key_id: "YOUR_AWS_ACCESS_KEY" + secret: "YOUR_AWS_SECRET_KEY" + ``` + +=== "Python" + + ```python linenums="1" + from sqlmesh.core.config import ( + Config, + ModelDefaultsConfig, + GatewayConfig, + DuckDBConnectionConfig + ) + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + gateways={ + "duckdb": GatewayConfig( + connection=DuckDBConnectionConfig( + catalogs={ + "local": "local.db", + "remote": "s3://bucket/data/remote.duckdb" + }, + extensions=[ + {"name": "httpfs"}, + ], + secrets=[ + { + "type": "s3", + "region": "YOUR_AWS_REGION", + "key_id": "YOUR_AWS_ACCESS_KEY", + "secret": "YOUR_AWS_SECRET_KEY" + } + 
] + ) + ), + } + ) + ``` + +##### Dictionary Format Example (Named Secrets) + +Using a dictionary allows you to assign custom names to your secrets for better organization and reference: + +=== "YAML" + + ```yaml linenums="1" + gateways: + duckdb: + connection: + type: duckdb + catalogs: + local: local.db + remote: "s3://bucket/data/remote.duckdb" + extensions: + - name: httpfs + secrets: + my_s3_secret: + type: s3 + region: "YOUR_AWS_REGION" + key_id: "YOUR_AWS_ACCESS_KEY" + secret: "YOUR_AWS_SECRET_KEY" + my_azure_secret: + type: azure + account_name: "YOUR_AZURE_ACCOUNT" + account_key: "YOUR_AZURE_KEY" + ``` + +=== "Python" + + ```python linenums="1" + from sqlmesh.core.config import ( + Config, + ModelDefaultsConfig, + GatewayConfig, + DuckDBConnectionConfig + ) + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + gateways={ + "duckdb": GatewayConfig( + connection=DuckDBConnectionConfig( + catalogs={ + "local": "local.db", + "remote": "s3://bucket/data/remote.duckdb" + }, + extensions=[ + {"name": "httpfs"}, + ], + secrets={ + "my_s3_secret": { + "type": "s3", + "region": "YOUR_AWS_REGION", + "key_id": "YOUR_AWS_ACCESS_KEY", + "secret": "YOUR_AWS_SECRET_KEY" + }, + "my_azure_secret": { + "type": "azure", + "account_name": "YOUR_AZURE_ACCOUNT", + "account_key": "YOUR_AZURE_KEY" + } + } + ) + ), + } + ) + ``` + +After configuring the secrets, you can directly reference S3 paths in your catalogs or in SQL queries without additional authentication steps. + +Refer to the official DuckDB documentation for the full list of [supported S3 secret parameters](https://duckdb.org/docs/stable/extensions/httpfs/s3api.html#overview-of-s3-secret-parameters) and for more information on the [Secrets Manager configuration](https://duckdb.org/docs/configuration/secrets_manager.html). + +> Note: Loading credentials at runtime using `load_aws_credentials()` or similar deprecated functions may fail when using SQLMesh. 
+ +##### File system configuration example for Microsoft Onelake + +The `filesystems` accepts a list of file systems to register in the DuckDB connection. This is especially useful for Azure Storage Accounts, as it adds write support for DuckDB which is not natively supported by DuckDB (yet). + + +=== "YAML" + + ```yaml linenums="1" + gateways: + ducklake: + connection: + type: duckdb + catalogs: + ducklake: + type: ducklake + path: myducklakecatalog.duckdb + data_path: abfs://MyFabricWorkspace/MyFabricLakehouse.Lakehouse/Files/DuckLake.Files + extensions: + - ducklake + filesystems: + - fs: abfs + account_name: onelake + account_host: onelake.blob.fabric.microsoft.com + client_id: {{ env_var('AZURE_CLIENT_ID') }} + client_secret: {{ env_var('AZURE_CLIENT_SECRET') }} + tenant_id: {{ env_var('AZURE_TENANT_ID') }} + # anon: False # To use azure.identity.DefaultAzureCredential authentication + ``` -Instead, create persistent and automatically used authentication credentials with the [DuckDB secrets manager](https://duckdb.org/docs/configuration/secrets_manager.html) (available in DuckDB v0.10.0 or greater). -## Airflow Scheduler -DuckDB only works when running locally; therefore it does not support Airflow. +Refer to the documentation for `fsspec` [fsspec.filesystem](https://filesystem-spec.readthedocs.io/en/latest/api.html#fsspec.filesystem) and `adlfs` [adlfs.AzureBlobFileSystem](https://fsspec.github.io/adlfs/api/#api-reference) for a full list of storage options. diff --git a/docs/integrations/engines/fabric.md b/docs/integrations/engines/fabric.md new file mode 100644 index 0000000000..90ac3234fc --- /dev/null +++ b/docs/integrations/engines/fabric.md @@ -0,0 +1,37 @@ +# Fabric + +!!! info + The Fabric engine adapter is a community contribution. Due to this, only limited community support is available. 
+ +## Local/Built-in Scheduler +**Engine Adapter Type**: `fabric` + +NOTE: Fabric Warehouse is not recommended to be used for the SQLMesh [state connection](../../reference/configuration.md#connections). + +### Installation +#### Microsoft Entra ID / Azure Active Directory Authentication: +``` +pip install "sqlmesh[fabric]" +``` + +### Connection options + +| Option | Description | Type | Required | +| ----------------- | ------------------------------------------------------------ | :----------: | :------: | +| `type` | Engine type name - must be `fabric` | string | Y | +| `host` | The hostname of the Fabric Warehouse server | string | Y | +| `user` | The client id to use for authentication with the Fabric Warehouse server | string | N | +| `password` | The client secret to use for authentication with the Fabric Warehouse server | string | N | +| `port` | The port number of the Fabric Warehouse server | int | N | +| `database` | The target database | string | N | +| `charset` | The character set used for the connection | string | N | +| `timeout` | The query timeout in seconds. Default: no timeout | int | N | +| `login_timeout` | The timeout for connection and login in seconds. Default: 60 | int | N | +| `appname` | The application name to use for the connection | string | N | +| `conn_properties` | The list of connection properties | list[string] | N | +| `autocommit` | Is autocommit mode enabled. Default: false | bool | N | +| `driver` | The driver to use for the connection. Default: pyodbc | string | N | +| `driver_name` | The driver name to use for the connection. E.g., *ODBC Driver 18 for SQL Server* | string | N | +| `tenant_id` | The Azure / Entra tenant UUID | string | Y | +| `workspace_id` | The Fabric workspace UUID. The preferred way to retrieve it is by running `notebookutils.runtime.context.get("currentWorkspaceId")` in a python notebook. | string | Y | +| `odbc_properties` | The dict of ODBC connection properties. 
E.g., authentication: ActiveDirectoryServicePrincipal. See more [here](https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute?view=sql-server-ver16). | dict | N | diff --git a/docs/integrations/engines/gcp-postgres.md b/docs/integrations/engines/gcp-postgres.md index 60701c8ac2..ca0bd9ded2 100644 --- a/docs/integrations/engines/gcp-postgres.md +++ b/docs/integrations/engines/gcp-postgres.md @@ -1,7 +1,7 @@ # GCP Postgres ## Local/Built-in Scheduler -**Engine Adapter Type**: `postgres` +**Engine Adapter Type**: `gcp_postgres` ### Installation ``` @@ -10,11 +10,17 @@ pip install "sqlmesh[gcppostgres]" ### Connection options -| Option | Description | Type | Required | -|---------------------------|-------------------------------------------------------------------------------------|:-------:|:--------:| -| `type` | Engine type name - must be `postgres` | string | Y | -| `instance_connection_str` | Connection name for the postgres instance | string | Y | -| `user` | The username (posgres or IAM) to use for authentication | string | Y | -| `password` | The password to use for authentication. Required when connecting as a Postgres user | string | N | -| `enable_iam_auth` | Enables IAM authentication. Required when connecting as an IAM user | boolean | N | -| `db` | The name of the database instance to connect to | string | Y | +| Option | Description | Type | Required | +|------------------------------|--------------------------------------------------------------------------------------------------------|:----------:|:--------:| +| `type` | Engine type name - must be `gcp_postgres` | string | Y | +| `instance_connection_string` | Connection name for the postgres instance | string | Y | +| `user` | The username (postgres or IAM) to use for authentication | string | Y | +| `password` | The password to use for authentication. Required when connecting as a Postgres user | string | N | +| `enable_iam_auth` | Enables IAM authentication. 
Required when connecting as an IAM user | boolean | N | +| `keyfile` | Path to the keyfile to be used with enable_iam_auth instead of ADC | string | N | +| `keyfile_json` | Keyfile information provided inline (not recommended) | dict | N | +| `db` | The name of the database instance to connect to | string | Y | +| `ip_type` | The IP type to use for the connection. Must be one of `public`, `private`, or `psc`. Default: `public` | string | N | +| `timeout` | The connection timeout in seconds. Default: `30` | integer | N | +| `scopes` | The scopes to use for the connection. Default: `(https://www.googleapis.com/auth/sqlservice.admin,)` | tuple[str] | N | +| `driver` | The driver to use for the connection. Default: `pg8000`. Note: only `pg8000` is tested | string | N | diff --git a/docs/integrations/engines/motherduck.md b/docs/integrations/engines/motherduck.md index 04759cc8b2..caa5541d3d 100644 --- a/docs/integrations/engines/motherduck.md +++ b/docs/integrations/engines/motherduck.md @@ -1,6 +1,98 @@ # MotherDuck +This page provides information about how to use SQLMesh with MotherDuck. + +It begins with a [Connection Quickstart](#connection-quickstart) that demonstrates how to connect to MotherDuck, or you can skip directly to information about using MotherDuck with the built-in scheduler. + +## Connection quickstart + +Connecting to cloud warehouses involves a few steps, so this connection quickstart provides the info you need to get up and running with MotherDuck. + +It demonstrates connecting to MotherDuck with the `duckdb` library bundled with SQLMesh. + +MotherDuck provides a single way to authorize a connection. This quickstart demonstrates authenticating with a token. + +!!! tip + This quick start assumes you are familiar with basic SQLMesh commands and functionality. + + If you’re not familiar, work through the [SQLMesh Quickstart](../../quick_start.md) before continuing. 
+ +### Prerequisites + +Before working through this quickstart guide, ensure that: + +1. You have a motherduck account and an access token. +2. Your computer has SQLMesh installed with the DuckDB extra available. + 1. Install from command line with the command `pip install “sqlmesh[duckdb]”` +3. You have initialized a SQLMesh example project on your computer + 1. Open a command line interface and navigate to the directory where the project files should go. + 2. Initialize the project with the command `sqlmesh init duckdb`, since `duckdb` is the dialect. + +#### Access control permissions + +SQLMesh must have sufficient permissions to create and access your MotherDuck databases. Since permission is granted to specific databases for a specific user, you should create a service account for SQLMesh that will contain the credentials for writing to MotherDuck. + +### Configure the connection + +We now have what is required to configure SQLMesh’s connection to MotherDuck. + +We start the configuration by adding a gateway named `motherduck` to our example project’s config.yaml file and making it our `default gateway`, as well as adding our token, persistent, and ephemeral catalogs. + +```yaml +gateways: + motherduck: + connection: + type: motherduck + catalogs: + persistent: "md:" + ephemeral: ":memory:" + token: + +default_gateway: motherduck +``` + +Catalogs can be defined to connect to anything that [DuckDB can be attached to](./duckdb.md#other-connection-catalogs-example). + +!!! warning + Best practice for storing secrets like tokens is placing them in [environment variables that the configuration file loads dynamically](../../guides/configuration.md#environment-variables). For simplicity, this guide instead places the value directly in the configuration file. 
+ + This code demonstrates how to use the environment variable `MOTHERDUCK_TOKEN` for the configuration's `token` parameter: + + ```yaml linenums="1" hl_lines="5" + gateways: + motherduck: + connection: + type: motherduck + token: {{ env_var('MOTHERDUCK_TOKEN') }} + ``` + +### Check connection + +We have now specified the `motherduck` gateway connection information, so we can confirm that SQLMesh is able to successfully connect to MotherDuck. We will test the connection with the `sqlmesh info` command. + +First, open a command line terminal. Now enter the command `sqlmesh info`: + +![](./motherduck/sqlmesh_info.png) + +The output shows that our data warehouse connection succeeded: + +![](./motherduck/info_output.png) + +### Run a `sqlmesh plan` + +Now we're ready to run a `sqlmesh plan` in MotherDuck: + +![](./motherduck/sqlmesh_plan.png) + +And confirm that our schemas and objects exist in the MotherDuck catalog: + +![](./motherduck/motherduck_ui.png) + +Congratulations \- your SQLMesh project is up and running on MotherDuck\! + + ## Local/Built-in Scheduler + **Engine Adapter Type**: `motherduck` ### Connection options @@ -12,3 +104,4 @@ | `token` | The optional MotherDuck token. If not specified, the user will be prompted to login with their web browser. | string | N | | `extensions` | Extension to load into duckdb. Only autoloadable extensions are supported. | list | N | | `connector_config` | Configuration to pass into the duckdb connector. | dict | N | +| `secrets` | Configuration for authenticating external sources (e.g. S3) using DuckDB secrets. 
| dict | N | \ No newline at end of file diff --git a/docs/integrations/engines/motherduck/info_output.png b/docs/integrations/engines/motherduck/info_output.png new file mode 100644 index 0000000000..1d37418a81 Binary files /dev/null and b/docs/integrations/engines/motherduck/info_output.png differ diff --git a/docs/integrations/engines/motherduck/motherduck_ui.png b/docs/integrations/engines/motherduck/motherduck_ui.png new file mode 100644 index 0000000000..3cc51da7fc Binary files /dev/null and b/docs/integrations/engines/motherduck/motherduck_ui.png differ diff --git a/docs/integrations/engines/motherduck/sqlmesh_info.png b/docs/integrations/engines/motherduck/sqlmesh_info.png new file mode 100644 index 0000000000..92800aaedc Binary files /dev/null and b/docs/integrations/engines/motherduck/sqlmesh_info.png differ diff --git a/docs/integrations/engines/motherduck/sqlmesh_plan.png b/docs/integrations/engines/motherduck/sqlmesh_plan.png new file mode 100644 index 0000000000..21a0045de8 Binary files /dev/null and b/docs/integrations/engines/motherduck/sqlmesh_plan.png differ diff --git a/docs/integrations/engines/mssql.md b/docs/integrations/engines/mssql.md index 32dbc0191d..4c68219dd2 100644 --- a/docs/integrations/engines/mssql.md +++ b/docs/integrations/engines/mssql.md @@ -1,50 +1,65 @@ # MSSQL -## Local/Built-in Scheduler -**Engine Adapter Type**: `mssql` +## Installation -### Installation +### User / Password Authentication: ``` pip install "sqlmesh[mssql]" ``` +### Microsoft Entra ID / Azure Active Directory Authentication: +``` +pip install "sqlmesh[mssql-odbc]" +``` -### Connection options +## Incremental by unique key `MERGE` -| Option | Description | Type | Required | -| ----------------- | ------------------------------------------------------------ | :----------: | :------: | -| `type` | Engine type name - must be `mssql` | string | Y | -| `host` | The hostname of the MSSQL server | string | Y | -| `user` | The username to use for authentication with 
the MSSQL server | string | N | -| `password` | The password to use for authentication with the MSSQL server | string | N | -| `port` | The port number of the MSSQL server | int | N | -| `database` | The target database | string | N | -| `charset` | The character set used for the connection | string | N | -| `timeout` | The query timeout in seconds. Default: no timeout | int | N | -| `login_timeout` | The timeout for connection and login in seconds. Default: 60 | int | N | -| `appname` | The application name to use for the connection | string | N | -| `conn_properties` | The list of connection properties | list[string] | N | -| `autocommit` | Is autocommit mode enabled. Default: false | bool | N | - -## Airflow Scheduler -**Engine Name:** `mssql` - -The SQLMesh MsSql Operator is similar to the [MsSqlOperator](https://airflow.apache.org/docs/apache-airflow-providers-microsoft-mssql/stable/_api/airflow/providers/microsoft/mssql/operators/mssql/index.html), and relies on the same [MsSqlHook](https://airflow.apache.org/docs/apache-airflow-providers-microsoft-mssql/stable/_api/airflow/providers/microsoft/mssql/hooks/mssql/index.html) implementation. - -To enable support for this operator, the Airflow Microsoft MSSQL provider package should be installed on the target Airflow cluster along with SQLMesh with the mssql extra: -``` -pip install "apache-airflow-providers-microsoft-mssql" -pip install "sqlmesh[mssql]" +SQLMesh executes a `MERGE` statement to insert rows for [incremental by unique key](../../concepts/models/model_kinds.md#incremental_by_unique_key) model kinds. + +By default, the `MERGE` statement updates all non-key columns of an existing row when a new row with the same key values is inserted. If all column values match between the two rows, those updates are unnecessary. + +SQLMesh provides an optional performance optimization that skips unnecessary updates by comparing column values with the `EXISTS` and `EXCEPT` operators. 
+ +Enable the optimization by setting the `mssql_merge_exists` key to `true` in the [`physical_properties`](../../concepts/models/overview.md#physical_properties) section of the `MODEL` statement. + +For example: + +```sql linenums="1" hl_lines="7-9" +MODEL ( + name sqlmesh_example.unique_key, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key id + ), + cron '@daily', + physical_properties ( + mssql_merge_exists = true + ) +); ``` -The operator requires an [Airflow connection](https://airflow.apache.org/docs/apache-airflow/stable/howto/connection.html) to determine the target MSSQL account. Refer to [MSSQL connection](https://airflow.apache.org/docs/apache-airflow-providers-microsoft-mssql/stable/connections/mssql.html) for more details. - -By default, the connection ID is set to `mssql_default`, but can be overridden using the `engine_operator_args` parameter to the `SQLMeshAirflow` instance as in the example below: -```python linenums="1" -sqlmesh_airflow = SQLMeshAirflow( - "mssql", - default_catalog="", - engine_operator_args={ - "mssql_conn_id": "" - }, -) -``` \ No newline at end of file +!!! warning "Not all column types supported" + The `mssql_merge_exists` optimization is not supported for all column types, including `GEOMETRY`, `XML`, `TEXT`, `NTEXT`, `IMAGE`, and most user-defined types. + + Learn more in the [MSSQL `EXCEPT` statement documentation](https://learn.microsoft.com/en-us/sql/t-sql/language-elements/set-operators-except-and-intersect-transact-sql?view=sql-server-ver17#arguments). 
+ +## Local/Built-in Scheduler +**Engine Adapter Type**: `mssql` + +### Connection options + +| Option | Description | Type | Required | +| ----------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :----------: | :------: | +| `type` | Engine type name - must be `mssql` | string | Y | +| `host` | The hostname of the MSSQL server | string | Y | +| `user` | The username / client id to use for authentication with the MSSQL server | string | N | +| `password` | The password / client secret to use for authentication with the MSSQL server | string | N | +| `port` | The port number of the MSSQL server | int | N | +| `database` | The target database | string | N | +| `charset` | The character set used for the connection | string | N | +| `timeout` | The query timeout in seconds. Default: no timeout | int | N | +| `login_timeout` | The timeout for connection and login in seconds. Default: 60 | int | N | +| `appname` | The application name to use for the connection | string | N | +| `conn_properties` | The list of connection properties | list[string] | N | +| `autocommit` | Is autocommit mode enabled. Default: false | bool | N | +| `driver` | The driver to use for the connection. Default: pymssql | string | N | +| `driver_name` | The driver name to use for the connection (e.g., *ODBC Driver 18 for SQL Server*). | string | N | +| `odbc_properties` | ODBC connection properties (e.g., *authentication: ActiveDirectoryServicePrincipal*). See more [here](https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute?view=sql-server-ver16). 
| dict | N | \ No newline at end of file diff --git a/docs/integrations/engines/mysql.md b/docs/integrations/engines/mysql.md index 77bd96d42b..e8426a3f5a 100644 --- a/docs/integrations/engines/mysql.md +++ b/docs/integrations/engines/mysql.md @@ -19,31 +19,3 @@ pip install "sqlmesh[mysql]" | `port` | The port number of the MySQL server | int | N | | `charset` | The character set used for the connection | string | N | | `ssl_disabled` | Is SSL disabled | bool | N | - -## Airflow Scheduler -**Engine Name:** `mysql` - -The SQLMesh MySQL Operator is similar to the [MySQLOperator](https://airflow.apache.org/docs/apache-airflow-providers-mysql/stable/index.html), and relies on the same [MySqlHook](https://airflow.apache.org/docs/apache-airflow-providers-mysql/1.0.0/_api/airflow/providers/mysql/hooks/mysql/index.html) implementation. - -To enable support for this operator, the Airflow MySQL provider package should be installed on the target Airflow cluster along with SQLMesh with the mysql extra: -``` -pip install "apache-airflow-providers-mysql" -pip install "sqlmesh[mysql]" -``` - -The operator requires an [Airflow connection](https://airflow.apache.org/docs/apache-airflow/stable/howto/connection.html) to determine the target MySQL account. Refer to [MySQL connection](https://airflow.apache.org/docs/apache-airflow-providers-mysql/stable/connections/mysql.html) for more details. - -By default, the connection ID is set to `mysql_default`, but can be overridden using the `engine_operator_args` parameter to the `SQLMeshAirflow` instance as in the example below: -```python linenums="1" -from sqlmesh.schedulers.airflow import NO_DEFAULT_CATALOG - -sqlmesh_airflow = SQLMeshAirflow( - "mysql", - default_catalog=NO_DEFAULT_CATALOG, - engine_operator_args={ - "mysql_conn_id": "" - }, -) -``` - -Note: `NO_DEFAULT_CATALOG` is required for MySQL since MySQL doesn't support catalogs. 
\ No newline at end of file diff --git a/docs/integrations/engines/postgres.md b/docs/integrations/engines/postgres.md index 5867c26494..cf1d3e4ce8 100644 --- a/docs/integrations/engines/postgres.md +++ b/docs/integrations/engines/postgres.md @@ -10,39 +10,16 @@ pip install "sqlmesh[postgres]" ### Connection options -| Option | Description | Type | Required | -|-------------------|---------------------------------------------------------------------------------|:------:|:--------:| -| `type` | Engine type name - must be `postgres` | string | Y | -| `host` | The hostname of the Postgres server | string | Y | -| `user` | The username to use for authentication with the Postgres server | string | Y | -| `password` | The password to use for authentication with the Postgres server | string | Y | -| `port` | The port number of the Postgres server | int | Y | -| `database` | The name of the database instance to connect to | string | Y | -| `keepalives_idle` | The number of seconds between each keepalive packet sent to the server. | int | N | -| `connect_timeout` | The number of seconds to wait for the connection to the server. (Default: `10`) | int | N | -| `role` | The role to use for authentication with the Postgres server | string | N | -| `sslmode` | The security of the connection to the Postgres server | string | N | - -## Airflow Scheduler -**Engine Name:** `postgres` - -The SQLMesh Postgres Operator is similar to the [PostgresOperator](https://airflow.apache.org/docs/apache-airflow-providers-postgres/stable/_api/airflow/providers/postgres/operators/postgres/index.html), and relies on the same [PostgresHook](https://airflow.apache.org/docs/apache-airflow-providers-postgres/stable/_api/airflow/providers/postgres/hooks/postgres/index.html) implementation. 
- -To enable support for this operator, the Airflow Postgres provider package should be installed on the target Airflow cluster along with SQLMesh with the Postgres extra: -``` -pip install "apache-airflow-providers-postgres" -pip install "sqlmesh[postgres]" -``` - -The operator requires an [Airflow connection](https://airflow.apache.org/docs/apache-airflow/stable/howto/connection.html) to determine the target Postgres account. Refer to [Postgres connection](https://airflow.apache.org/docs/apache-airflow-providers-postgres/stable/connections/postgres.html) for more details. - -By default, the connection ID is set to `postgres_default`, but can be overridden using the `engine_operator_args` parameter to the `SQLMeshAirflow` instance as in the example below: -```python linenums="1" -sqlmesh_airflow = SQLMeshAirflow( - "postgres", - default_catalog="", - engine_operator_args={ - "postgres_conn_id": "" - }, -) -``` \ No newline at end of file +| Option | Description | Type | Required | +|--------------------|---------------------------------------------------------------------------------|:------:|:--------:| +| `type` | Engine type name - must be `postgres` | string | Y | +| `host` | The hostname of the Postgres server | string | Y | +| `user` | The username to use for authentication with the Postgres server | string | Y | +| `password` | The password to use for authentication with the Postgres server | string | Y | +| `port` | The port number of the Postgres server | int | Y | +| `database` | The name of the database instance to connect to | string | Y | +| `keepalives_idle` | The number of seconds between each keepalive packet sent to the server. | int | N | +| `connect_timeout` | The number of seconds to wait for the connection to the server. 
(Default: `10`) | int | N | +| `role` | The role to use for authentication with the Postgres server | string | N | +| `sslmode` | The security of the connection to the Postgres server | string | N | +| `application_name` | The name of the application to use for the connection | string | N | diff --git a/docs/integrations/engines/redshift.md b/docs/integrations/engines/redshift.md index b4c461aa16..0b853dfee1 100644 --- a/docs/integrations/engines/redshift.md +++ b/docs/integrations/engines/redshift.md @@ -32,27 +32,38 @@ pip install "sqlmesh[redshift]" | `is_serverless` | If the Amazon Redshift cluster is serverless (Default: `False`) | bool | N | | `serverless_acct_id` | The account ID of the serverless cluster | string | N | | `serverless_work_group` | The name of work group for serverless end point | string | N | +| `enable_merge` | Whether the incremental_by_unique_key model kind will use the native Redshift MERGE operation or SQLMesh's logical merge. (Default: `False`) | bool | N | -## Airflow Scheduler -**Engine Name:** `redshift` +## Performance Considerations -In order to share a common implementation across local and Airflow, SQLMesh's Redshift engine implements its own hook and operator. +### Timestamp Macro Variables and Sort Keys -To enable support for this operator, the Airflow Redshift provider package should be installed on the target Airflow cluster along with SQLMesh with the Redshift extra: -``` -pip install "apache-airflow-providers-amazon" -pip install "sqlmesh[redshift]" -``` +When working with Redshift tables that have a `TIMESTAMP` sort key, using the standard `@start_dt` and `@end_dt` macro variables may lead to performance issues. These macros render as `TIMESTAMP WITH TIME ZONE` values in SQL queries, which prevents Redshift from performing efficient pruning when filtering against `TIMESTAMP` (without timezone) sort keys. + +This can result in full table scans instead, causing significant performance degradation. 
+ +**Solution**: Use the `_dtntz` (datetime no timezone) variants of macro variables: + +- `@start_dtntz` instead of `@start_dt` +- `@end_dtntz` instead of `@end_dt` -The operator requires an [Airflow connection](https://airflow.apache.org/docs/apache-airflow/stable/howto/connection.html) to determine the target Redshift account. Refer to [AmazonRedshiftConnection](https://airflow.apache.org/docs/apache-airflow-providers-amazon/stable/connections/redshift.html#authenticating-to-amazon-redshift) for details on how to define a connection string. - -By default, the connection ID is set to `sqlmesh_redshift_default`, but it can be overridden using the `engine_operator_args` parameter to the `SQLMeshAirflow` instance as in the example below: -```python linenums="1" -sqlmesh_airflow = SQLMeshAirflow( - "redshift", - default_catalog="", - engine_operator_args={ - "redshift_conn_id": "" - }, -) -``` \ No newline at end of file +These variants render as `TIMESTAMP WITHOUT TIME ZONE`, allowing Redshift to properly utilize sort key optimizations. + +**Example**: + +```sql linenums="1" +-- Inefficient: May cause full table scan +SELECT * FROM my_table +WHERE timestamp_column >= @start_dt + AND timestamp_column < @end_dt + +-- Efficient: Uses sort key optimization +SELECT * FROM my_table +WHERE timestamp_column >= @start_dtntz + AND timestamp_column < @end_dtntz + +-- Alternative: Cast to timestamp +SELECT * FROM my_table +WHERE timestamp_column >= @start_ts::timestamp + AND timestamp_column < @end_ts::timestamp +``` diff --git a/docs/integrations/engines/risingwave.md b/docs/integrations/engines/risingwave.md new file mode 100644 index 0000000000..029cf6b1a1 --- /dev/null +++ b/docs/integrations/engines/risingwave.md @@ -0,0 +1,74 @@ +# RisingWave + +This page provides information about how to use SQLMesh with the [RisingWave](https://risingwave.com/) streaming database engine. + +!!! info + The RisingWave engine adapter is a community contribution. 
Due to this, only limited community support is available. + +## Local/Built-in Scheduler + +**Engine Adapter Type**: `risingwave` + +### Installation + +``` +pip install "sqlmesh[risingwave]" +``` + +## Connection options + +RisingWave is based on Postgres and uses the same `psycopg2` connection library. Therefore, the connection parameters are very similar to [Postgres](./postgres.md). + +| Option | Description | Type | Required | +|----------------|-------------------------------------------------------------------|:------:|:--------:| +| `type` | Engine type name - must be `risingwave` | string | Y | +| `host` | The hostname of the RisingWave server | string | Y | +| `user` | The username to use for authentication with the RisingWave server | string | Y | +| `password` | The password to use for authentication with the RisingWave server | string | N | +| `port` | The port number of the RisingWave engine server | int | Y | +| `database` | The name of the database instance to connect to | string | Y | +| `role` | The role to use for authentication with the RisingWave server | string | N | +| `sslmode` | The security of the connection to the RisingWave server | string | N | + +## Extra Features + +As a streaming database engine, RisingWave contains some extra features tailored specifically to streaming usecases. + +Primarily, these are: + - [Sources](https://docs.risingwave.com/sql/commands/sql-create-source) which are used to stream records into RisingWave from streaming sources like Kafka + - [Sinks](https://docs.risingwave.com/sql/commands/sql-create-sink) which are used to write the results of data processed by RisingWave to an external target, such as an Apache Iceberg table in object storage. + +RisingWave exposes these features via normal SQL statements, namely `CREATE SOURCE` and `CREATE SINK`. To utilize these in SQLMesh, you can use them in [pre / post statements](../../concepts/models/sql_models.md#optional-prepost-statements). 
+ +Here is an example of creating a Sink from a SQLMesh model using a post statement: + +```sql +MODEL ( + name sqlmesh_example.view_model, + kind VIEW ( + materialized true + ) +); + +SELECT + item_id, + COUNT(DISTINCT id) AS num_orders, +FROM + sqlmesh_example.incremental_model +GROUP BY item_id; + +CREATE + SINK IF NOT EXISTS kafka_sink +FROM + @this_model +WITH ( + connector='kafka', + "properties.bootstrap.server"='localhost:9092', + topic='test1', +) +FORMAT PLAIN +ENCODE JSON (force_append_only=true); +``` + +!!! info "@this_model" + The `@this_model` macro resolves to the physical table for the current version of the model. See [here](../../concepts/macros/macro_variables.md#runtime-variables) for more information. diff --git a/docs/integrations/engines/snowflake.md b/docs/integrations/engines/snowflake.md index a0a44655d0..fc2ccbd6bb 100644 --- a/docs/integrations/engines/snowflake.md +++ b/docs/integrations/engines/snowflake.md @@ -1,5 +1,263 @@ # Snowflake +This page provides information about how to use SQLMesh with the Snowflake SQL engine. + +It begins with a [Connection Quickstart](#connection-quickstart) that demonstrates how to connect to Snowflake, or you can skip directly to information about using Snowflake with the [built-in](#localbuilt-in-scheduler). + +## Connection quickstart + +Connecting to cloud warehouses involves a few steps, so this connection quickstart provides the info you need to get up and running with Snowflake. + +It demonstrates connecting to Snowflake with the `snowflake-connector-python` library bundled with SQLMesh. + +Snowflake provides multiple methods of authorizing a connection (e.g., password, SSO, etc.). This quickstart demonstrates authorizing with a password, but configurations for other methods are [described below](#snowflake-authorization-methods). + +!!! tip + This quickstart assumes you are familiar with basic SQLMesh commands and functionality. 
+ + If you're not, work through the [SQLMesh Quickstart](../../quick_start.md) before continuing! + +### Prerequisites + +Before working through this connection quickstart, ensure that: + +1. You have a Snowflake account and know your username and password +2. Your Snowflake account has at least one [warehouse](https://docs.snowflake.com/en/user-guide/warehouses-overview) available for running computations +3. Your computer has [SQLMesh installed](../../installation.md) with the [Snowflake extra available](../../installation.md#install-extras) + - Install from the command line with the command `pip install "sqlmesh[snowflake]"` +4. You have initialized a [SQLMesh example project](../../quickstart/cli#1-create-the-sqlmesh-project) on your computer + - Open a command line interface and navigate to the directory where the project files should go + - Initialize the project with the command `sqlmesh init snowflake` + +### Access control permissions + +SQLMesh must have sufficient permissions to create and access different types of database objects. + +SQLMesh's core functionality requires relatively broad permissions, including: + +1. Ability to create and delete schemas in a database +2. Ability to create, modify, delete, and query tables and views in the schemas it creates + +If your project uses materialized views or dynamic tables, SQLMesh will also need permissions to create, modify, delete, and query those object types. + +We now describe how to grant SQLMesh appropriate permissions. + +#### Snowflake roles + +Snowflake allows you to grant permissions directly to a user, or you can create and assign permissions to a "role" that you then grant to the user. + +Roles provide a convenient way to bundle sets of permissions and provide them to multiple users. We create and use a role to grant our user permissions in this quickstart. + +The role must be granted `USAGE` on a warehouse so it can execute computations. We describe other permissions below. 
+ +#### Database permissions +The top-level object container in Snowflake is a "database" (often called a "catalog" in other engines). SQLMesh does not need permission to create databases; it may use an existing one. + +The simplest way to grant SQLMesh sufficient permissions for a database is to give it `OWNERSHIP` of the database, which includes all the necessary permissions. + +Alternatively, you may grant SQLMesh granular permissions for all the actions and objects it will work with in the database. + +#### Granting the permissions + +This section provides example code for creating a `sqlmesh` role, granting it sufficient permissions, and granting it to a user. + +The code must be executed by a user with `USERADMIN` level permissions or higher. We provide two versions of the code, one that grants database `OWNERSHIP` to the role and another that does not. + +Both examples create a role named `sqlmesh`, grant it usage of the warehouse `compute_wh`, create a database named `demo_db`, and assign the role to the user `demo_user`. The step that creates the database can be omitted if the database already exists. 
+ +=== "With database ownership" + + ```sql linenums="1" + USE ROLE useradmin; -- This code requires USERADMIN privileges or higher + + CREATE ROLE sqlmesh; -- Create role for permissions + GRANT USAGE ON WAREHOUSE compute_wh TO ROLE sqlmesh; -- Can use warehouse + + CREATE DATABASE demo_db; -- Create database for SQLMesh to use (omit if database already exists) + GRANT OWNERSHIP ON DATABASE demo_db TO ROLE sqlmesh; -- Role owns database + + GRANT ROLE sqlmesh TO USER demo_user; -- Grant role to user + ALTER USER demo_user SET DEFAULT ROLE = sqlmesh; -- Make role user's default role + ``` + +=== "Without database ownership" + + ```sql linenums="1" + USE ROLE useradmin; -- This code requires USERADMIN privileges or higher + + CREATE ROLE sqlmesh; -- Create role for permissions + CREATE DATABASE demo_db; -- Create database for SQLMesh to use (omit if database already exists) + + GRANT USAGE ON WAREHOUSE compute_wh TO ROLE sqlmesh; -- Can use warehouse + GRANT USAGE ON DATABASE demo_db TO ROLE sqlmesh; -- Can use database + + GRANT CREATE SCHEMA ON DATABASE demo_db TO ROLE sqlmesh; -- Can create SCHEMAs in database + GRANT USAGE ON FUTURE SCHEMAS IN DATABASE demo_db TO ROLE sqlmesh; -- Can use schemas it creates + GRANT CREATE TABLE ON FUTURE SCHEMAS IN DATABASE demo_db TO ROLE sqlmesh; -- Can create TABLEs in schemas + GRANT CREATE VIEW ON FUTURE SCHEMAS IN DATABASE demo_db TO ROLE sqlmesh; -- Can create VIEWs in schemas + GRANT SELECT, INSERT, TRUNCATE, UPDATE, DELETE ON FUTURE TABLES IN DATABASE demo_db TO ROLE sqlmesh; -- Can SELECT and modify TABLEs in schemas + GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE demo_db TO ROLE sqlmesh; -- Can SELECT and modify VIEWs in schemas + + GRANT ROLE sqlmesh TO USER demo_user; -- Grant role to user + ALTER USER demo_user SET DEFAULT ROLE = sqlmesh; -- Make role user's default role + ``` + +### Get connection info + +Now that our user has sufficient access permissions, we're ready to gather the information needed to 
configure the SQLMesh connection. + +#### Account name + +Snowflake connection configurations require the `account` parameter that identifies the Snowflake account SQLMesh should connect to. + +Snowflake account identifiers have two components: your organization name and your account name. Both are embedded in your Snowflake web interface URL, separated by a `/`. + +This shows the default view when you log in to your Snowflake account, where we can see the two components of the account identifier: + +![Snowflake account info in web URL](./snowflake/snowflake_db-guide_account-url.png){ loading=lazy } + +In this example, our organization name is `idapznw`, and our account name is `wq29399`. + +We concatenate the two components, separated by a `-`, for the SQLMesh `account` parameter: `idapznw-wq29399`. + +#### Warehouse name + +Your Snowflake account may have more than one warehouse available - any will work for this quickstart, which runs very few computations. + +Some Snowflake user accounts may have a default warehouse they automatically use when connecting. + +The connection configuration's `warehouse` parameter is not required, but we recommend specifying the warehouse explicitly in the configuration to ensure SQLMesh's behavior doesn't change if the user's default warehouse changes. + +#### Database name + +Snowflake user accounts may have a "Default Namespace" that includes a default database they automatically use when connecting. + +The connection configuration's `database` parameter is not required, but we recommend specifying the database explicitly in the configuration to ensure SQLMesh's behavior doesn't change if the user's default namespace changes. + +### Configure the connection + +We now have the information we need to configure SQLMesh's connection to Snowflake. 
+ +We start the configuration by adding a gateway named `snowflake` to our example project's config.yaml file and making it our `default_gateway`: + +```yaml linenums="1" hl_lines="2-6" +gateways: + snowflake: + connection: + type: snowflake + +default_gateway: snowflake + +model_defaults: + dialect: snowflake + start: 2024-07-24 +``` + +And we specify the `account`, `user`, `password`, `database`, and `warehouse` connection parameters using the information from above: + +```yaml linenums="1" hl_lines="5-9" +gateways: + snowflake: + connection: + type: snowflake + account: idapznw-wq29399 + user: DEMO_USER + password: << password here >> + database: DEMO_DB + warehouse: COMPUTE_WH + +default_gateway: snowflake + +model_defaults: + dialect: snowflake + start: 2024-07-24 +``` + +!!! warning + Best practice for storing secrets like passwords is placing them in [environment variables that the configuration file loads dynamically](../../guides/configuration.md#environment-variables). For simplicity, this guide instead places the value directly in the configuration file. + + This code demonstrates how to use the environment variable `SNOWFLAKE_PASSWORD` for the configuration's `password` parameter: + + ```yaml linenums="1" hl_lines="5" + gateways: + snowflake: + connection: + type: snowflake + password: {{ env_var('SNOWFLAKE_PASSWORD') }} + ``` + +### Check connection + +We have now specified the `snowflake` gateway connection information, so we can confirm that SQLMesh is able to successfully connect to Snowflake. We will test the connection with the `sqlmesh info` command. + +First, open a command line terminal. 
Now enter the command `sqlmesh info`: + +![Run sqlmesh info command in CLI](./snowflake/snowflake_db-guide_sqlmesh-info.png){ loading=lazy } + +The output shows that our data warehouse connection succeeded: + +![Successful data warehouse connection](./snowflake/snowflake_db-guide_sqlmesh-info-succeeded.png){ loading=lazy } + +However, the output includes a `WARNING` about using the Snowflake SQL engine for storing SQLMesh state: + +![Snowflake state connection warning](./snowflake/snowflake_db-guide_sqlmesh-info-warning.png){ loading=lazy } + +!!! warning + Snowflake is not designed for transactional workloads and should not be used to store SQLMesh state even in testing deployments. + + Learn more about storing SQLMesh state [here](../../guides/configuration.md#state-connection). + +### Specify state connection + +We can store SQLMesh state in a different SQL engine by specifying a `state_connection` in our `snowflake` gateway. + +This example uses the DuckDB engine to store state in the local `snowflake_state.db` file: + +```yaml linenums="1" hl_lines="10-12" +gateways: + snowflake: + connection: + type: snowflake + account: idapznw-wq29399 + user: DEMO_USER + password: << your password here >> + database: DEMO_DB + warehouse: COMPUTE_WH + state_connection: + type: duckdb + database: snowflake_state.db + +default_gateway: snowflake + +model_defaults: + dialect: snowflake + start: 2024-07-24 +``` + +Now we no longer see the warning when running `sqlmesh info`, and we see a new entry `State backend connection succeeded`: + +![No state connection warning](./snowflake/snowflake_db-guide_sqlmesh-info-no-warning.png){ loading=lazy } + +### Run a `sqlmesh plan` + +Now we're ready to run a `sqlmesh plan` in Snowflake: + +![Run sqlmesh plan in snowflake](./snowflake/snowflake_db-guide_sqlmesh-plan.png){ loading=lazy } + +And confirm that our schemas and objects exist in the Snowflake catalog: + +![Sqlmesh plan objects in 
snowflake](./snowflake/snowflake_db-guide_sqlmesh-plan-objects.png){ loading=lazy } + +Congratulations - your SQLMesh project is up and running on Snowflake! + +### Where are the row counts? + +SQLMesh reports the number of rows processed by each model in its `plan` and `run` terminal output. + +However, due to limitations in the Snowflake Python connector, row counts cannot be determined for `CREATE TABLE AS` statements. Therefore, SQLMesh does not report row counts for certain model kinds, such as `FULL` models. + +Learn more about the connector limitation [on Github](https://github.com/snowflakedb/snowflake-connector-python/issues/645). + ## Local/Built-in Scheduler **Engine Adapter Type**: `snowflake` @@ -13,10 +271,10 @@ pip install "sqlmesh[snowflake]" | Option | Description | Type | Required | |--------------------------|--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:| | `type` | Engine type name - must be `snowflake` | string | Y | +| `account` | The Snowflake account name | string | Y | | `user` | The Snowflake username | string | N | | `password` | The Snowflake password | string | N | | `authenticator` | The Snowflake authenticator method | string | N | -| `account` | The Snowflake account name | string | Y | | `warehouse` | The Snowflake warehouse name | string | N | | `database` | The Snowflake database name | string | N | | `role` | The Snowflake role name | string | N | @@ -27,9 +285,11 @@ pip install "sqlmesh[snowflake]" | `session_parameters` | The optional session parameters to set for the connection. | dict | N | -#### Lowercase object names +### Lowercase object names -Snowflake object names are case-insensitive by default. If you have intentionally created an object with a case-sensitive lowercase name, specify it with outer single and inner double quotes. 
+Snowflake object names are case-insensitive by default, and Snowflake automatically normalizes them to uppercase. For example, the command `CREATE SCHEMA sqlmesh` will generate a schema named `SQLMESH` in Snowflake. + +If you need to create an object with a case-sensitive lowercase name, the name must be double-quoted in SQL code. In the SQLMesh configuration file, it also requires outer single quotes. For example, a connection to the database `"my_db"` would include: @@ -37,10 +297,16 @@ For example, a connection to the database `"my_db"` would include: connection: type: snowflake - database: '"my_db"' + database: '"my_db"' # outer single and inner double quotes ``` -### Snowflake SSO Authorization +### Snowflake authorization methods + +The simplest (but arguably least secure) method of authorizing a connection with Snowflake is with a username and password. + +This section describes how to configure other authorization methods. + +#### Snowflake SSO Authorization SQLMesh supports Snowflake SSO authorization connections using the `externalbrowser` authenticator method. For example: @@ -57,7 +323,7 @@ gateways: role: ************ ``` -### Snowflake OAuth Authorization +#### Snowflake OAuth Authorization SQLMesh supports Snowflake OAuth authorization connections using the `oauth` authenticator method. For example: @@ -92,11 +358,13 @@ SQLMesh supports Snowflake OAuth authorization connections using the `oauth` aut ) ``` -### Snowflake Private Key Authorization +#### Snowflake Private Key Authorization + +SQLMesh supports Snowflake private key authorization connections by providing the private key as a path, Base64-encoded DER format (representing the key bytes), a plain-text PEM format, or as bytes (Python Only). -SQLMesh supports Snowflake private key authorization connections by providing the private key as a path, Base64-encoded DER format (representing the key bytes), a plain-text PEM format, or as bytes (Python Only). `account` and `user` are required. 
For example: +The `account` and `user` parameters are required for each of these methods. -#### Private Key Path +__Private Key Path__ Note: `private_key_passphrase` is only needed if the key was encrypted with a passphrase. @@ -132,7 +400,7 @@ Note: `private_key_passphrase` is only needed if the key was encrypted with a pa ``` -#### Private Key PEM +__Private Key PEM__ Note: `private_key_passphrase` is only needed if the key was encrypted with a passphrase. @@ -174,7 +442,7 @@ Note: `private_key_passphrase` is only needed if the key was encrypted with a pa ``` -#### Private Key Base64 +__Private Key Base64__ Note: This is base64 encoding of the bytes of the key itself and not the PEM file contents. @@ -207,7 +475,7 @@ Note: This is base64 encoding of the bytes of the key itself and not the PEM fil ) ``` -#### Private Key Bytes +__Private Key Bytes__ === "YAML" @@ -222,21 +490,21 @@ Note: This is base64 encoding of the bytes of the key itself and not the PEM fil ModelDefaultsConfig, SnowflakeConnectionConfig, ) - + from cryptography.hazmat.primitives import serialization - + key = """-----BEGIN PRIVATE KEY----- ... -----END PRIVATE KEY-----""".encode() - + p_key= serialization.load_pem_private_key(key, password=None) - + pkb = p_key.private_bytes( encoding=serialization.Encoding.DER, format=serialization.PrivateFormat.PKCS8, encryption_algorithm=serialization.NoEncryption(), ) - + config = Config( model_defaults=ModelDefaultsConfig(dialect="snowflake"), gateways={ @@ -253,38 +521,106 @@ Note: This is base64 encoding of the bytes of the key itself and not the PEM fil The authenticator method is assumed to be `snowflake_jwt` when `private_key` is provided, but it can also be explicitly provided in the connection configuration. 
-## Airflow Scheduler -**Engine Name:** `snowflake` +## Configuring Virtual Warehouses -The SQLMesh Snowflake Operator is similar to the [SnowflakeOperator](https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/operators/snowflake.html), and relies on the same [SnowflakeHook](https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/_api/airflow/providers/snowflake/hooks/snowflake/index.html) implementation. +The Snowflake Virtual Warehouse a model should use can be specified in the `session_properties` attribute of the model definition: -To enable support for this operator, the Airflow Snowflake provider package should be installed on the target Airflow cluster along with SQLMesh with the Snowflake extra: +```sql linenums="1" +MODEL ( + name schema_name.model_name, + session_properties ( + 'warehouse' = TEST_WAREHOUSE, + ), +); ``` -pip install "apache-airflow-providers-snowflake[common.sql]" -pip install "sqlmesh[snowflake]" + +## Custom View and Table types + +SQLMesh supports custom view and table types for Snowflake models. You can apply these modifiers to either the physical layer or virtual layer of a model using the `physical_properties` and `virtual_properties` attributes respectively. For example: + +### Secure Views + +A table can be exposed through a `SECURE` view in the virtual layer by specifying the `creatable_type` property and setting it to `SECURE`: + +```sql linenums="1" +MODEL ( + name schema_name.model_name, + virtual_properties ( + creatable_type = SECURE + ) +); + +SELECT a FROM schema_name.model_b; ``` -The operator requires an [Airflow connection](https://airflow.apache.org/docs/apache-airflow/stable/howto/connection.html) to determine the target Snowflake account. Refer to [Snowflake connection](https://airflow.apache.org/docs/apache-airflow-providers-snowflake/stable/connections/snowflake.html) for more details. 
- -By default, the connection ID is set to `snowflake_default`, but can be overridden using the `engine_operator_args` parameter to the `SQLMeshAirflow` instance as in the example below: -```python linenums="1" -sqlmesh_airflow = SQLMeshAirflow( - "snowflake", - default_catalog="", - engine_operator_args={ - "snowflake_conn_id": "" - }, -) +### Transient Tables + +A model can use a `TRANSIENT` table in the physical layer by specifying the `creatable_type` property and setting it to `TRANSIENT`: + +```sql linenums="1" +MODEL ( + name schema_name.model_name, + physical_properties ( + creatable_type = TRANSIENT + ) +); + +SELECT a FROM schema_name.model_b; ``` -## Configuring Virtual Warehouses +### Iceberg Tables -The Snowflake Virtual Warehouse can be specified on a per-model basis using the `session_properties` attribute of the model definition: -```sql +In order for Snowflake to be able to create an Iceberg table, there must be an [External Volume](https://docs.snowflake.com/en/user-guide/tables-iceberg-configure-external-volume) configured to store the Iceberg table data on. 
+ +Once that is configured, you can create a model backed by an Iceberg table by using `table_format iceberg` like so: + +```sql linenums="1" hl_lines="4 6-7" MODEL ( - name model_name, - session_properties ( - 'warehouse' = TEST_WAREHOUSE, - ), + name schema_name.model_name, + kind FULL, + table_format iceberg, + physical_properties ( + catalog = 'snowflake', + external_volume = '' + ) ); ``` + +To prevent having to specify `catalog = 'snowflake'` and `external_volume = ''` on every model, see the Snowflake documentation for: + + - [Configuring a default Catalog](https://docs.snowflake.com/en/user-guide/tables-iceberg-configure-catalog-integration#set-a-default-catalog-at-the-account-database-or-schema-level) + - [Configuring a default External Volume](https://docs.snowflake.com/en/user-guide/tables-iceberg-configure-external-volume#set-a-default-external-volume-at-the-account-database-or-schema-level) + +Alternatively you can also use [model defaults](../../guides/configuration.md#model-defaults) to set defaults at the SQLMesh level instead. + +To utilize the wide variety of [optional properties](https://docs.snowflake.com/en/sql-reference/sql/create-iceberg-table-snowflake#optional-parameters) that Snowflake makes available for Iceberg tables, simply specify them as `physical_properties`: + +```sql linenums="1" hl_lines="8" +MODEL ( + name schema_name.model_name, + kind FULL, + table_format iceberg, + physical_properties ( + catalog = 'snowflake', + external_volume = 'my_external_volume', + base_location = 'my/product_reviews/' + ) +); +``` + +!!! warning "External catalogs" + + Setting `catalog = 'snowflake'` to use Snowflake's internal catalog is a good default because SQLMesh needs to be able to write to the tables it's managing and Snowflake [does not support](https://docs.snowflake.com/en/user-guide/tables-iceberg#catalog-options) writing to Iceberg tables configured under external catalogs. 
+ + You can however still reference a table from an external catalog in your model as a normal [external table](../../concepts/models/external_models.md). + +## Troubleshooting + +### Frequent Authentication Prompts + +When using Snowflake with security features like Multi-Factor Authentication (MFA), you may experience repeated prompts for authentication while running SQLMesh commands. This typically occurs when your Snowflake account isn't configured to issue short-lived tokens. + +To reduce authentication prompts, you can enable token caching in your Snowflake connection configuration: + +- For general authentication, see [Connection Caching Documentation](https://docs.snowflake.com/en/user-guide/admin-security-fed-auth-use#using-connection-caching-to-minimize-the-number-of-prompts-for-authentication-optional) +- For MFA specifically, see [MFA Token Caching Documentation](https://docs.snowflake.com/en/user-guide/security-mfa#using-mfa-token-caching-to-minimize-the-number-of-prompts-during-authentication-optional). 
diff --git a/docs/integrations/engines/snowflake/snowflake_db-guide_account-url.png b/docs/integrations/engines/snowflake/snowflake_db-guide_account-url.png new file mode 100644 index 0000000000..9ad93e4b6b Binary files /dev/null and b/docs/integrations/engines/snowflake/snowflake_db-guide_account-url.png differ diff --git a/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-info-no-warning.png b/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-info-no-warning.png new file mode 100644 index 0000000000..8cac812f05 Binary files /dev/null and b/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-info-no-warning.png differ diff --git a/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-info-succeeded.png b/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-info-succeeded.png new file mode 100644 index 0000000000..30f5e6b8ad Binary files /dev/null and b/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-info-succeeded.png differ diff --git a/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-info-warning.png b/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-info-warning.png new file mode 100644 index 0000000000..12d886a0ee Binary files /dev/null and b/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-info-warning.png differ diff --git a/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-info.png b/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-info.png new file mode 100644 index 0000000000..27fde273ab Binary files /dev/null and b/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-info.png differ diff --git a/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-plan-objects.png b/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-plan-objects.png new file mode 100644 index 0000000000..92cf290ece Binary files /dev/null and 
b/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-plan-objects.png differ diff --git a/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-plan.png b/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-plan.png new file mode 100644 index 0000000000..dedce422b9 Binary files /dev/null and b/docs/integrations/engines/snowflake/snowflake_db-guide_sqlmesh-plan.png differ diff --git a/docs/integrations/engines/spark.md b/docs/integrations/engines/spark.md index b9d22fdc4a..652d26a614 100644 --- a/docs/integrations/engines/spark.md +++ b/docs/integrations/engines/spark.md @@ -14,36 +14,6 @@ NOTE: Spark may not be used for the SQLMesh [state connection](../../reference/c | `catalog` | The catalog to use when issuing commands. See [Catalog Support](#catalog-support) for details | string | N | | `config` | Key/value pairs to set for the Spark Configuration. | dict | N | -## Airflow Scheduler -**Engine Name:** `spark` - -The SQLMesh Spark operator is very similar to the Airflow [SparkSubmitOperator](https://airflow.apache.org/docs/apache-airflow-providers-apache-spark/stable/operators.html#sparksubmitoperator), and relies on the same [SparkSubmitHook](https://airflow.apache.org/docs/apache-airflow-providers-apache-spark/stable/_api/airflow/providers/apache/spark/hooks/spark_submit/index.html#airflow.providers.apache.spark.hooks.spark_submit.SparkSubmitHook) implementation. - -To enable support for this operator, the Airflow Spark provider package should be installed on the target Airflow cluster as follows: -``` -pip install apache-airflow-providers-apache-spark -``` - -The operator requires an [Airflow connection](https://airflow.apache.org/docs/apache-airflow/stable/howto/connection.html) to determine the target cluster, queue, and deploy mode in which the Spark Job should be submitted. Refer to [Apache Spark connection](https://airflow.apache.org/docs/apache-airflow-providers-apache-spark/stable/connections/spark.html) for more details. 
- -By default, the connection ID is set to `spark_default`, but it can be overridden using the `engine_operator_args` parameter to the `SQLMeshAirflow` instance as in the example below: -```python linenums="1" -sqlmesh_airflow = SQLMeshAirflow( - "spark", - default_catalog="", - engine_operator_args={ - "connection_id": "" - }, -) -``` -Similarly, the `engine_operator_args` parameter can be used to override other job submission parameters, such as number of allocated cores, executors, and so forth. The full list of parameters that can be overridden can be found in `sqlmesh.schedulers.airflow.operators.spark_submit.SQLMeshSparkSubmitOperator`. - -**Cluster mode** -

- -Each Spark job submitted by SQLMesh is a PySpark application that depends on the SQLMesh library in its Driver process (but not in Executors). This means that if the Airflow connection is configured to submit jobs in `cluster` mode as opposed to `client` mode, the user must ensure that the SQLMesh Python library is installed on each node of a cluster where Spark jobs are submitted. This is because there is no way to know in advance which specific node to which a Driver process will be scheduled. No additional configuration is required if the deploy mode is set to `client`. - - ## Catalog Support SQLMesh's Spark integration is only designed/tested with a single catalog usage in mind. diff --git a/docs/integrations/engines/trino.md b/docs/integrations/engines/trino.md index 0f618734c6..db732f0cc1 100644 --- a/docs/integrations/engines/trino.md +++ b/docs/integrations/engines/trino.md @@ -47,7 +47,11 @@ iceberg.catalog.type=hive_metastore **Note**: The Trino Iceberg Connector must be configured with an `iceberg.catalog.type` that supports views. At the time of this writing, this is `hive_metastore`, `glue`, and `rest`. -The `jdbc` and `nessie` catalogs do not support views and are thus incompatible with SQLMesh. +The `jdbc` and `nessie` iceberg catalog types do not support views and are thus incompatible with SQLMesh. + +!!! info "Nessie" + Nessie is supported when used as an Iceberg REST Catalog (`iceberg.catalog.type=rest`). + For more information on how to configure the Trino Iceberg connector for this, see the [Nessie documentation](https://projectnessie.org/nessie-latest/trino/). #### Delta Lake Connector Configuration @@ -60,55 +64,169 @@ hive.metastore.uri=thrift://example.net:9083 delta.hive-catalog-name=datalake_delta # example catalog name, can be any valid string ``` +#### AWS Glue + +[AWS Glue](https://aws.amazon.com/glue/) provides an implementation of the Hive metastore catalog. 
+ +Your Trino project's physical data objects are stored in a specific location, such as an [AWS S3](https://aws.amazon.com/s3/) bucket. Hive provides a default location, which you can override in its configuration file. + +Set the default location for your project's tables in the Hive catalog configuration's [`hive.metastore.glue.default-warehouse-dir` parameter](https://trino.io/docs/current/object-storage/metastores.html#aws-glue-catalog-configuration-properties). + +For example: + +```linenums="1" +hive.metastore=glue +hive.metastore.glue.default-warehouse-dir=s3://my-bucket/ +``` + ### Connection options -| Option | Description | Type | Required | -|----------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:| -| `type` | Engine type name - must be `trino` | string | Y | -| `user` | The username (of the account) to log in to your cluster. When connecting to Starburst Galaxy clusters, you must include the role of the user as a suffix to the username. | string | Y | -| `host` | The hostname of your cluster. Don't include the `http://` or `https://` prefix. | string | Y | -| `catalog` | The name of a catalog in your cluster. | string | Y | -| `http_scheme` | The HTTP scheme to use when connecting to your cluster. By default, it's `https` and can only be `http` for no-auth or basic auth. | string | N | -| `port` | The port to connect to your cluster. By default, it's `443` for `https` scheme and `80` for `http` | int | N | -| `roles` | Mapping of catalog name to a role | dict | N | -| `http_headers` | Additional HTTP headers to send with each request. | dict | N | -| `session_properties` | Trino session properties. Run `SHOW SESSION` to see all options. | dict | N | -| `retries` | Number of retries to attempt when a request fails. Default: `3` | int | N | -| `timezone` | Timezone to use for the connection. 
Default: client-side local timezone | string | N | - -## Airflow Scheduler -**Engine Name:** `trino` - -The SQLMesh Trino Operator is similar to the [TrinoOperator](https://airflow.apache.org/docs/apache-airflow-providers-trino/stable/operators/trino.html), and relies on the same [TrinoHook](https://airflow.apache.org/docs/apache-airflow-providers-trino/stable/_api/airflow/providers/trino/hooks/trino/index.html) implementation. - -To enable support for this operator, the Airflow Trino provider package should be installed on the target Airflow cluster along with SQLMesh with the Trino extra: +| Option | Description | Type | Required | +|---------------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:------:|:--------:| +| `type` | Engine type name - must be `trino` | string | Y | +| `user` | The username (of the account) to log in to your cluster. When connecting to Starburst Galaxy clusters, you must include the role of the user as a suffix to the username. | string | Y | +| `host` | The hostname of your cluster. Don't include the `http://` or `https://` prefix. | string | Y | +| `catalog` | The name of a catalog in your cluster. | string | Y | +| `http_scheme` | The HTTP scheme to use when connecting to your cluster. By default, it's `https` and can only be `http` for no-auth or basic auth. | string | N | +| `port` | The port to connect to your cluster. By default, it's `443` for `https` scheme and `80` for `http` | int | N | +| `roles` | Mapping of catalog name to a role | dict | N | +| `source` | Value to send as Trino's `source` field for query attribution / auditing. Default: `sqlmesh`. | string | N | +| `http_headers` | Additional HTTP headers to send with each request. | dict | N | +| `session_properties` | Trino session properties. Run `SHOW SESSION` to see all options. 
| dict | N | +| `retries` | Number of retries to attempt when a request fails. Default: `3` | int | N | +| `timezone` | Timezone to use for the connection. Default: client-side local timezone | string | N | +| `schema_location_mapping` | A mapping of regex patterns to S3 locations to use for the `LOCATION` property when creating schemas. See [Table and Schema locations](#table-and-schema-locations) for more details. | dict | N | +| `catalog_type_overrides` | A mapping of catalog names to their connector type. This is used to enable/disable connector specific behavior. See [Catalog Type Overrides](#catalog-type-overrides) for more details. | dict | N | + +## Table and Schema locations + +When using connectors that are decoupled from their storage (such as the Iceberg, Hive or Delta connectors), when creating new tables Trino needs to know the location in the physical storage it should write the table data to. + +This location gets stored against the table in the metastore so that any engine trying to read the data knows where to look. + +### Default behaviour + +Trino allows you to optionally configure a `default-warehouse-dir` property at the [Metastore](https://trino.io/docs/current/object-storage/metastores.html) level. When creating objects, Trino will infer schema locations to be `/` and table locations to be `//
`.
+
+However, if you don't set this property, Trino can still infer table locations if a *schema* location is explicitly set.
+
+For example, if you specify the `LOCATION` property when creating a schema like so:
+
+```sql
+CREATE SCHEMA staging_data
+WITH (LOCATION = 's3://warehouse/production/staging_data')
```
-pip install "apache-airflow-providers-trino"
-pip install "sqlmesh[trino]"
+
+Then any tables created under that schema will have their location inferred as `<schema_location>/<table_name>
`. + +If you specify neither a `default-warehouse-dir` in the metastore config nor a schema location when creating the schema, you must specify an explicit table location when creating the table or Trino will produce an error. + +Creating a table in a specific location is very similar to creating a schema in a specific location: + +```sql +CREATE TABLE staging_data.customers (customer_id INT) +WITH (LOCATION = 's3://warehouse/production/staging_data/customers') ``` -The operator requires an [Airflow connection](https://airflow.apache.org/docs/apache-airflow/stable/howto/connection.html) to determine the target Trino account. Refer to [Trino connection](https://airflow.apache.org/docs/apache-airflow-providers-trino/stable/connections.html) for more details. - -By default, the connection ID is set to `trino_default`, but can be overridden using the `engine_operator_args` parameter to the `SQLMeshAirflow` instance as in the example below: -```python linenums="1" -sqlmesh_airflow = SQLMeshAirflow( - "trino", - default_catalog="", - engine_operator_args={ - "trino_conn_id": "" - }, -) +### Configuring in SQLMesh + +Within SQLMesh, you can configure the value to use for the `LOCATION` property when SQLMesh creates tables and schemas. This overrides what Trino would have inferred based on the cluster configuration. + +#### Schemas + +To configure the `LOCATION` property that SQLMesh will specify when issuing `CREATE SCHEMA` statements, you can use the `schema_location_mapping` connection property. This applies to all schemas that SQLMesh creates, including its internal ones. + +The simplest example is to emulate a `default-warehouse-dir`: + +```yaml title="config.yaml" +gateways: + trino: + connection: + type: trino + ... + schema_location_mapping: + '.*': 's3://warehouse/production/@{schema_name}' +``` + +This will cause all schemas to get created with their location set to `s3://warehouse/production/`. 
The table locations will be inferred by Trino as `s3://warehouse/production/<schema_name>/<table_name>
` so all objects will effectively be created under `s3://warehouse/production/`.
+
+It's worth mentioning that if your models are using fully qualified three part names, eg `<catalog>.<schema>.<table>` then the string being matched against the `schema_location_mapping` regex will be `<catalog>.<schema>` and not just the `<schema>` itself. This allows you to set different locations for the same schema name if that schema name is used across multiple catalogs.
+
+If your models are using two part names, eg `<schema>.<table>
` then only the `` part will be matched against the regex. + +Here's an example: + +```yaml title="config.yaml" +gateways: + trino: + connection: + type: trino + ... + schema_location_mapping: + '^utils$': 's3://utils-bucket/@{schema_name}' + '^landing\..*$': 's3://raw-data/@{catalog_name}/@{schema_name}' + '^staging.*$': 's3://bucket/@{schema_name}_dev' + '^sqlmesh.*$': 's3://sqlmesh-internal/dev/@{schema_name}' +``` + +This would perform the following mappings: + +- a schema called `sales` would not be mapped to a location at all because it doesnt match any of the patterns. It would be created without a `LOCATION` property +- a schema called `utils` would be mapped to the location `s3://utils-bucket/utils` because it directly matches the `^utils$` pattern +- a schema called `transactions` in a catalog called `landing` would be mapped to the location `s3://raw-data/landing/transactions` because the string `landing.transactions` matches the `^landing\..*$` pattern +- schemas called `staging_customers` and `staging_accounts` would be mapped to the locations `s3://bucket/staging_customers_dev` and `s3://bucket/staging_accounts_dev` respectively because they match the `^staging.*$` pattern +- a schema called `accounts` in a catalog called `staging` would be mapped to the location `s3://bucket/accounts_dev` because the string `staging.accounts` matches the `^staging.*$` pattern +- schemas called `sqlmesh__staging_customers` and `sqlmesh__staging_utils` would be mapped to the locations `s3://sqlmesh-internal/dev/sqlmesh__staging_customers` and `s3://sqlmesh-internal/dev/sqlmesh__staging_utils` respectively because they match the `^sqlmesh.*$` pattern + +!!! info "Placeholders" + You may use the `@{catalog_name}` and `@{schema_name}` placeholders in the mapping value. + + If there is a match on one of the patterns then the catalog / schema that SQLMesh is about to use in the `CREATE SCHEMA` statement will be substituted into these placeholders. 
+ + Note the use of curly brace syntax `@{}` when referencing these placeholders - learn more [here](../../concepts/macros/sqlmesh_macros.md#embedding-variables-in-strings). + +#### Tables + +Often, you don't need to configure an explicit table location because if you have configured explicit schema locations, table locations are automatically inferred by Trino to be a subdirectory under the schema location. + +However, if you need to, you can configure an explicit table location by adding a `location` property to the model `physical_properties`. + +Note that you need to use the [@resolve_template](../../concepts/macros/sqlmesh_macros.md#resolve_template) macro to generate a unique table location for each model version. Otherwise, all model versions will be written to the same location and clobber each other. + +```sql hl_lines="5" +MODEL ( + name staging.customers, + kind FULL, + physical_properties ( + location = @resolve_template('s3://warehouse/@{catalog_name}/@{schema_name}/@{table_name}') + ) +); + +SELECT ... ``` -```yaml linenums="1" -gateway_name: - connection: - type: trino - user: [user] - host: [host] - catalog: [catalog] + +This will cause SQLMesh to set the specified `LOCATION` when issuing a `CREATE TABLE` statement. + +## Catalog Type Overrides + +SQLMesh attempts to determine the connector type of a catalog by querying the `system.metadata.catalogs` table and checking the `connector_name` column. +It checks if the connector name is `hive` for Hive connector behavior or contains `iceberg` or `delta_lake` for Iceberg or Delta Lake connector behavior respectively. +However, the connector name may not always be a reliable way to determine the connector type, for example when using a custom connector or a fork of an existing connector. +To handle such cases, you can use the `catalog_type_overrides` connection property to explicitly specify the connector type for specific catalogs. 
+For example, to specify that the `datalake` catalog is using the Iceberg connector and the `analytics` catalog is using the Hive connector, you can configure the connection as follows: + +```yaml title="config.yaml" +gateways: + trino: + connection: + type: trino + ... + catalog_type_overrides: + datalake: iceberg + analytics: hive ``` -### Authentication +## Authentication === "No Auth" | Option | Description | Type | Required | diff --git a/docs/integrations/github.md b/docs/integrations/github.md index 7cf0c8a10e..07903fce56 100644 --- a/docs/integrations/github.md +++ b/docs/integrations/github.md @@ -5,6 +5,7 @@ The GitHub Actions CI/CD Bot enables teams to automate their SQLMesh projects using GitHub Actions. It can be configured to perform the following things: * Automatically run unit tests on PRs +* Automatically run the linter on PRs * Automatically create PR environments that represent the code changes in the PR * Automatically categorize and backfill data for models that have changed * Automatically deploy changes to production with automatic data gap prevention and merge the PR @@ -112,6 +113,7 @@ In this example we configured the merge method to be `squash`. See [Bot Configur One way to signal to SQLMesh that a PR is ready to go to production is through the use of "Required Approvers". In this approach users configure their SQLMesh project to list users that are designated as "Required Approver" and then when the bot detects an approval was received from one of these individuals then it determines that it is time to deploy to production. +The bot will only do the deploy to prod if the base branch is a production branch (as defined in the bot's configuration but defaults to either `main` or `master`). 
This pattern can be a great fit for teams that already have an approval process like this in place and therefore it actually removes an extra step from either the author or the approver since SQLMesh will automate the deployment and merge until of it having to be manually done. ##### Required Approval Configuration @@ -292,10 +294,14 @@ Below is an example of how to define the default config for the bot in either YA | `command_namespace` | The namespace to use for SQLMesh commands. For example if you provide `#SQLMesh` as a value then commands will be expected in the format of `#SQLMesh/`. Default: `None` meaning no namespace is used. | string | N | | `auto_categorize_changes` | Auto categorization behavior to use for the bot. If not provided then the project-wide categorization behavior is used. See [Auto-categorize model changes](https://sqlmesh.readthedocs.io/en/stable/guides/configuration/#auto-categorize-model-changes) for details. | dict | N | | `default_pr_start` | Default start when creating PR environment plans. If running in a mode where the bot automatically backfills models (based on `auto_categorize_changes` behavior) then this can be used to limit the amount of data backfilled. Defaults to `None` meaning the start date is set to the earliest model's start or to 1 day ago if [data previews](../concepts/plans.md#data-preview) need to be computed. | str | N | +| `pr_min_intervals` | Intended for use when `default_pr_start` is set to a relative time, eg `1 week ago`. This ensures that at least this many intervals across every model are included for backfill in the PR environment. Without this, models with an interval unit wider than `default_pr_start` (such as `@monthly` models if `default_pr_start` was set to `1 week ago`) will be excluded from backfill entirely. | int | N | | `skip_pr_backfill` | Indicates if the bot should skip backfilling models in the PR environment. 
Default: `True` | bool | N | | `pr_include_unmodified` | Indicates whether to include unmodified models in the PR environment. Default to the project's config value (which defaults to `False`) | bool | N | -| `run_on_deploy_to_prod` | Indicates whether to run latest intervals when deploying to prod. If set to false, the deployment will backfill only the changed models up to the existing latest interval in production, ignoring any missing intervals beyond this point. Default: `True` | bool | N | -| `pr_environment_name` | The name of the PR environment to create for which a PR number will be appended to. Defaults to the repo name if not provided. Note: The name will be normalized to alphanumeric + underscore and lowercase. | str | N | +| `run_on_deploy_to_prod` | Indicates whether to run latest intervals when deploying to prod. If set to false, the deployment will backfill only the changed models up to the existing latest interval in production, ignoring any missing intervals beyond this point. Default: `False` | bool | N | +| `pr_environment_name` | The name of the PR environment to create for which a PR number will be appended to. Defaults to the repo name if not provided. Note: The name will be normalized to alphanumeric + underscore and lowercase. | str | N | +| `prod_branch_name` | The name of the git branch associated with production. Ex: `prod`. Default: `main` or `master` is considered prod | str | N | +| `forward_only_branch_suffix` | If the git branch has this suffix, trigger a [forward-only](../concepts/plans.md#forward-only-plans) plan instead of a normal plan. Default: `-forward-only` | str | N | +| `check_if_blocked_on_deploy_to_prod` | The bot normally checks if a PR is blocked from merging before deploying to production. Setting this to `False` will skip that check. 
Default: `True` | bool | N | Example with all properties defined: @@ -315,7 +321,8 @@ Example with all properties defined: seed: full default_pr_start: "1 week ago" skip_pr_backfill: false - run_on_deploy_to_prod: true + run_on_deploy_to_prod: false + prod_branch_name: production ``` === "Python" @@ -338,7 +345,8 @@ Example with all properties defined: ), default_pr_start="1 week ago", skip_pr_backfill=False, - run_on_deploy_to_prod=True, + run_on_deploy_to_prod=False, + prod_branch_name="production", ) ) ``` @@ -350,12 +358,13 @@ These can be used to potentially trigger follow up steps in the workflow. These are the possible outputs (based on how the bot is configured) that are created by the bot: * `run_unit_tests` +* `linter` * `has_required_approval` * `pr_environment_synced` * `prod_plan_preview` * `prod_environment_synced` -[There are many possible conclusions](https://github.com/TobikoData/sqlmesh/blob/main/sqlmesh/integrations/github/cicd/controller.py#L96-L102) so the best use case for this is likely to check for `success` conclusion in order to potentially run follow up steps. +[There are many possible conclusions](https://github.com/SQLMesh/sqlmesh/blob/main/sqlmesh/integrations/github/cicd/controller.py#L96-L102) so the best use case for this is likely to check for `success` conclusion in order to potentially run follow up steps. Note that in error cases conclusions may not be set and therefore you will get an empty string. Example of running a step after pr environment has been synced: @@ -373,6 +382,8 @@ In addition, there are custom outputs listed below: * `created_pr_environment` - set to `"true"` (a string with a value of `true`) if a PR environment was created for the first time. It is absent, or considered empty string if you check for it, if it is not created for the first time * `pr_environment_name` - the name of the PR environment. It is output whenever PR environment synced check reaches a conclusion. 
Therefore make sure to check the status of `created_pr_environment` or `pr_environment_synced` before acting on this output +Note: The `linter` step will run only if it's enabled in the project's configuration (`config.yaml` / `config.py`). The step will fail if the linter finds errors, otherwise it'll output only the warnings. + ## Custom Workflow Configuration You can configure each individual action to run as a separate step. This can allow for more complex workflows or integrating specific steps with other actions you want to trigger. Run `sqlmesh_cicd github` to see a list of commands that can be supplied and their potential options. ```bash @@ -393,7 +404,7 @@ Commands: ``` ## Example Synchronized Full Workflow -This workflow involves configuring a SQLMesh connection to Databricks and configuring access to GCP to talk to Cloud Composer (Airflow). +This workflow involves configuring a SQLMesh connection to Databricks. ```yaml name: SQLMesh Bot @@ -447,11 +458,6 @@ jobs: - name: Install Dependencies run: pip install -r requirements.txt shell: bash - - id: auth - name: Authenticate to Google Cloud - uses: google-github-actions/auth@v1 - with: - credentials_json: '${{ secrets.GOOGLE_CREDENTIALS }}' - name: Run CI/CD Bot run: | sqlmesh_cicd -p ${{ github.workspace }} github --token ${{ secrets.GITHUB_TOKEN }} run-all @@ -460,6 +466,10 @@ jobs: ## Example Screenshots ### Automated Unit Tests with Error Summary ![Automated Unit Tests with Error Summary](github/github_test_summary.png) +### Automated Linting with Error Summary +![Automated Linting with Error Summary](github/linter_errors.png) +### Automated Linting with Warning Summary +![Automated Linting with Warning Summary](github/linter_warnings.png) ### Automatically create PR Environments that represent the code changes in the PR ![Environment Summary](github/github_env_summary.png) ### Enforce that certain reviewers have approved of the PR before it can be merged diff --git 
a/docs/integrations/github/linter_errors.png b/docs/integrations/github/linter_errors.png new file mode 100644 index 0000000000..d12b9f0ab2 Binary files /dev/null and b/docs/integrations/github/linter_errors.png differ diff --git a/docs/integrations/github/linter_warnings.png b/docs/integrations/github/linter_warnings.png new file mode 100644 index 0000000000..609ea1a524 Binary files /dev/null and b/docs/integrations/github/linter_warnings.png differ diff --git a/docs/integrations/overview.md b/docs/integrations/overview.md index 7200fc6cfc..94b9289d21 100644 --- a/docs/integrations/overview.md +++ b/docs/integrations/overview.md @@ -3,23 +3,27 @@ ## Tools SQLMesh supports integrations with the following tools: -* [Airflow](airflow.md) * [dbt](dbt.md) +* [dlt](dlt.md) * [GitHub Actions](github.md) * [Kestra](https://kestra.io/plugins/plugin-sqlmesh/tasks/cli/io.kestra.plugin.sqlmesh.cli.sqlmeshcli) ## Execution engines -SQLMesh supports the following execution engines for running SQLMesh projects: +SQLMesh supports the following execution engines for running SQLMesh projects (engine `type` in parentheses - example usage: `pip install "sqlmesh[databricks]"`): -* [BigQuery](./engines/bigquery.md) -* [Databricks](./engines/databricks.md) -* [DuckDB](./engines/duckdb.md) -* [MotherDuck](./engines/motherduck.md) -* [MySQL](./engines/mysql.md) -* [MSSQL](./engines/mssql.md) -* [Postgres](./engines/postgres.md) -* [GCP Postgres](./engines/gcp-postgres.md) -* [Redshift](./engines/redshift.md) -* [Snowflake](./engines/snowflake.md) -* [Spark](./engines/spark.md) -* [Trino](./engines/trino.md) +* [Athena](./engines/athena.md) (athena) +* [Azure SQL](./engines/azuresql.md) (azuresql) +* [BigQuery](./engines/bigquery.md) (bigquery) +* [ClickHouse](./engines/clickhouse.md) (clickhouse) +* [Databricks](./engines/databricks.md) (databricks) +* [DuckDB](./engines/duckdb.md) (duckdb) +* [Fabric](./engines/fabric.md) (fabric) +* [MotherDuck](./engines/motherduck.md) (motherduck) +* 
[MSSQL](./engines/mssql.md) (mssql) +* [MySQL](./engines/mysql.md) (mysql) +* [Postgres](./engines/postgres.md) (postgres) +* [GCP Postgres](./engines/gcp-postgres.md) (gcppostgres) +* [Redshift](./engines/redshift.md) (redshift) +* [Snowflake](./engines/snowflake.md) (snowflake) +* [Spark](./engines/spark.md) (spark) +* [Trino](./engines/trino.md) (trino) diff --git a/docs/prerequisites.md b/docs/prerequisites.md index 638132cdc4..11acfca64e 100644 --- a/docs/prerequisites.md +++ b/docs/prerequisites.md @@ -17,10 +17,6 @@ python --version **Note:** If `python --version` returns 2.x, replace all `python` commands with `python3`, and `pip` with `pip3`. -## Additional prerequisites for integrations - -If integrating with Airflow, you'll also need to install the SQLMesh Python package on all nodes of the Airflow cluster. For more information, refer to [Integrate with Airflow](./guides/scheduling.md#integrating-with-airflow). - ## Next steps Now that your machine meets the prerequisites, [install SQLMesh](installation.md). diff --git a/docs/quick_start.md b/docs/quick_start.md index d410829b1a..a80fe2c2d1 100644 --- a/docs/quick_start.md +++ b/docs/quick_start.md @@ -6,6 +6,7 @@ The example project runs locally on your machine with a DuckDB SQL engine, and S All you need to do is install SQLMesh on your machine - get started by ensuring your system meets the basic [prerequisites for using SQLMesh](./prerequisites.md). +Head over to the [CLI Quickstart](./quickstart/cli.md) or check out the video below. ## Video Quickstart @@ -13,4 +14,4 @@ This video walks through the quickstart installation, setup, and creating your f -If you're ready to keep going after your first plan, head over to the [CLI Quickstart](./quickstart/cli.md#3-update-a-model) to start modifying the project's models. 
\ No newline at end of file +If you're ready to keep going after your first plan, head over to the [CLI Quickstart Step 3](./quickstart/cli.md#3-update-a-model) to start modifying the project's models. \ No newline at end of file diff --git a/docs/quickstart/cli.md b/docs/quickstart/cli.md index aec94a2457..a592847470 100644 --- a/docs/quickstart/cli.md +++ b/docs/quickstart/cli.md @@ -1,6 +1,8 @@ # CLI -In this quick start guide, you'll use the SQLMesh command line interface (CLI) to get up and running with SQLMesh's scaffold generator. This example project will run locally on your computer using [DuckDB](https://duckdb.org/) as an embedded SQL engine. +In this quickstart, you'll use the SQLMesh command line interface (CLI) to get up and running with SQLMesh's scaffold generator. + +It will create an example project that runs locally on your computer using [DuckDB](https://duckdb.org/) as an embedded SQL engine. Before beginning, ensure that you meet all the [prerequisites](../prerequisites.md) for using SQLMesh. @@ -39,41 +41,180 @@ mkdir sqlmesh-example cd sqlmesh-example ``` -If using a python virtual environment, ensure it's activated first by running the `source .env/bin/activate` command from the folder used during [installation](../installation.md). +If using a Python virtual environment, ensure it's activated first by running the `source .venv/bin/activate` command from the folder used during [installation](../installation.md). + +### 1.1 Initialize the project + +SQLMesh includes a scaffold generator to initialize a new SQLMesh project. -Create a SQLMesh scaffold with the following command, specifying a default SQL dialect for your models. The dialect should correspond to the dialect most of your models are written in; it can be overridden for specific models in the model's `MODEL` specification. All SQL dialects [supported by the SQLGlot library](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/dialect.py) are allowed. 
+The scaffold generator will ask you some questions and create a SQLMesh configuration file based on your responses. -In this example, we specify the `duckdb` dialect: +Depending on your answers, it will also create multiple files for the SQLMesh example project used in this quickstart. + +Start the scaffold generator by executing the `sqlmesh init` command: ```bash -sqlmesh init duckdb +sqlmesh init +``` + +??? info "Skip the questions" + + If you don't want to use the interactive scaffold generator, you can initialize your project with arguments to the [`sqlmesh init` command](../reference/cli.md#init). + + The only required argument is `engine`, which specifies the SQL engine your project will use. Specify one of the engine `type`s in the [list of supported engines](../integrations/overview.md#execution-engines). + + In this example, we specify the `duckdb` engine: + + ```bash + sqlmesh init duckdb + ``` + + The scaffold will include a SQLMesh configuration file and example project directories and files. You're now ready to continue the quickstart [below](#2-create-a-prod-environment). + +#### Project type + +The first question asks about the type of project you want to create. Enter the number corresponding to the type of project you want to create and press `Enter`. + +``` bash +────────────────────────────── +Welcome to SQLMesh! +────────────────────────────── + +What type of project do you want to set up? + + [1] DEFAULT - Create SQLMesh example project models and files + [2] dbt - You have an existing dbt project and want to run it with SQLMesh + [3] EMPTY - Create a SQLMesh configuration file and project directories only + +Enter a number: 1 +``` + +For this quickstart, choose the `DEFAULT` option `1` so the example project files are included in the project directories. + +#### SQL engine + +The second question asks which SQL engine your project will use. 
SQLMesh will include that engine's connection settings in the configuration file, which you will fill in later to connect your project to the engine. + +For this quickstart, choose the `DuckDB` option `1` so we can run the example project with the built-in DuckDB engine that doesn't need additional configuration. + +``` bash +Choose your SQL engine: + + [1] DuckDB + [2] Snowflake + [3] Databricks + [4] BigQuery + [5] MotherDuck + [6] ClickHouse + [7] Redshift + [8] Spark + [9] Trino + [10] Azure SQL + [11] MSSQL + [12] Postgres + [13] GCP Postgres + [14] MySQL + [15] Athena + [16] RisingWave + +Enter a number: 1 +``` + +#### CLI mode + +SQLMesh's core commands have multiple options that alter their behavior. Some of those options streamline the SQLMesh `plan` workflow and CLI output. + +If you prefer a streamlined workflow (no prompts, no file diff previews, auto-apply changes), choose the `FLOW` CLI mode to automatically include those options in your project configuration file. + +If you prefer to see all the output SQLMesh provides, choose `DEFAULT` mode, which we will use in this quickstart: + +``` bash +Choose your SQLMesh CLI experience: + + [1] DEFAULT - See and control every detail + [2] FLOW - Automatically run changes and show summary output + +Enter a number: 1 ``` -The scaffold will include a SQLMesh configuration file for the example project. +#### Ready to go + +Your project is now ready to go, and SQLMesh displays a message with some good next steps. + +If you chose the DuckDB engine, you're ready to move forward and run the example project with DuckDB. -??? info "Learn more about the project's configuration" +If you chose a different engine, add your engine's connection information to the `config.yaml` file before you run any additional SQLMesh commands. + +``` bash +Your SQLMesh project is ready! 
+ +Next steps: +- Update your gateway connection settings (e.g., username/password) in the project configuration file: + /sqlmesh-example/config.yaml +- Run command in CLI: sqlmesh plan +- (Optional) Explain a plan: sqlmesh plan --explain + +Quickstart guide: +https://sqlmesh.readthedocs.io/en/stable/quickstart/cli/ + +Need help? +- Docs: https://sqlmesh.readthedocs.io +- Slack: https://www.tobikodata.com/slack +- GitHub: https://github.com/SQLMesh/sqlmesh/issues +``` + +??? info "Learn more about the project's configuration: `config.yaml`" SQLMesh project-level configuration parameters are specified in the `config.yaml` file in the project directory. - This example project uses the embedded DuckDB SQL engine, so its configuration specifies `duckdb` as the local gateway's connection and the `local` gateway as the default. + This example project uses the embedded DuckDB SQL engine, so its configuration specifies `duckdb` as the gateway's connection type. All available configuration settings are included in the file, with optional settings set to their default value and commented out. - The command to run the scaffold generator **requires** a default SQL dialect for your models, which it places in the config `model_defaults` `dialect` key. In this example, we specified the `duckdb` SQL dialect as the default: + SQLMesh requires a default model SQL dialect. SQLMesh automatically specifies the SQL dialect for your project's SQL engine, which it places in the config `model_defaults` `dialect` key. 
In this example, we specified the DuckDB engine, so `duckdb` is the default SQL dialect: ```yaml linenums="1" + # --- Gateway Connection --- gateways: - local: + duckdb: connection: - type: duckdb - database: ./db.db - - default_gateway: local + # For more information on configuring the connection to your execution engine, visit: + # https://sqlmesh.readthedocs.io/en/stable/reference/configuration/#connection + # https://sqlmesh.readthedocs.io/en/stable/integrations/engines/duckdb/#connection-options + # + type: duckdb # <-- DuckDB engine + database: db.db + # concurrent_tasks: 1 + # register_comments: True # <-- Optional setting `register_comments` has a default value of True + # pre_ping: False + # pretty_sql: False + # catalogs: # <-- Optional setting `catalogs` has no default value + # extensions: + # connector_config: + # secrets: + # token: + + default_gateway: duckdb + + # --- Model Defaults --- + # https://sqlmesh.readthedocs.io/en/stable/reference/model_configuration/#model-defaults model_defaults: - dialect: duckdb + dialect: duckdb # <-- Models written in DuckDB SQL dialect by default + start: 2025-06-12 # Start date for backfill history + cron: '@daily' # Run models daily at 12am UTC (can override per model) + + # --- Linting Rules --- + # Enforce standards for your team + # https://sqlmesh.readthedocs.io/en/stable/guides/linter/ + + linter: + enabled: true + rules: + - ambiguousorinvalidcolumn + - invalidselectstarexpansion ``` Learn more about SQLMesh project configuration [here](../reference/configuration.md). -The scaffold will also include multiple directories where SQLMesh project files are stored and multiple files that constitute the example project (e.g., SQL models). +The scaffold generator creates multiple directories where SQLMesh project files are stored and multiple files that constitute the example project (e.g., SQL models). ??? 
info "Learn more about the project directories and files" SQLMesh uses a scaffold generator to initiate a new project. The generator will create multiple sub-directories and files for organizing your SQLMesh project code. @@ -106,7 +247,7 @@ The scaffold will also include multiple directories where SQLMesh project files - ./tests - test_full_model.yaml -Finally, the scaffold will include data for the example project to use. +Finally, the scaffold generator creates data for the example project to use. ??? info "Learn more about the project's data" The data used in this example project is contained in the `seed_data.csv` file in the `/seeds` project directory. The data reflects sales of 3 items over 7 days in January 2020. @@ -133,46 +274,121 @@ SQLMesh's key actions are creating and applying *plans* to *environments*. At th SQLMesh's key actions are creating and applying *plans* to *environments*. - A [SQLMesh environment](../concepts/environments.md) is an isolated namespace containing models and the data they generated. The most important environment is `prod` ("production"), which consists of the databases behind the applications your business uses to operate each day. Environments other than `prod` provide a place where you can test and preview changes to model code before they go live and affect business operations. + A [SQLMesh environment](../concepts/environments.md) is an isolated namespace containing models and the data they generated. + + The most important environment is `prod` ("production"), which consists of the databases behind the applications your business uses to operate each day. Environments other than `prod` provide a place where you can test and preview changes to model code before they go live and affect business operations. + + A [SQLMesh plan](../concepts/plans.md) contains a comparison of one environment to another and the set of changes needed to bring them into alignment. 
+ + For example, if a new SQL model was added, tested, and run in the `dev` environment, it would need to be added and run in the `prod` environment to bring them into alignment. SQLMesh identifies all such changes and classifies them as either breaking or non-breaking. + + Breaking changes are those that invalidate data already existing in an environment. For example, if a `WHERE` clause was added to a model in the `dev` environment, existing data created by that model in the `prod` environment are now invalid because they may contain rows that would be filtered out by the new `WHERE` clause. + + Other changes, like adding a new column to a model in `dev`, are non-breaking because all the existing data in `prod` are still valid to use - only new data must be added to align the environments. - A [SQLMesh plan](../concepts/plans.md) contains a comparison of one environment to another and the set of changes needed to bring them into alignment. For example, if a new SQL model was added, tested, and run in the `dev` environment, it would need to be added and run in the `prod` environment to bring them into alignment. SQLMesh identifies all such changes and classifies them as either breaking or non-breaking. + After SQLMesh creates a plan, it summarizes the breaking and non-breaking changes so you can understand what will happen if you apply the plan. It will prompt you to "backfill" data to apply the plan. (In this context, backfill is a generic term for updating or adding to a table's data, including an initial load or full refresh.) - Breaking changes are those that invalidate data already existing in an environment. For example, if a `WHERE` clause was added to a model in the `dev` environment, existing data created by that model in the `prod` environment are now invalid because they may contain rows that would be filtered out by the new `WHERE` clause. 
Other changes, like adding a new column to a model in `dev`, are non-breaking because all the existing data in `prod` are still valid to use - only new data must be added to align the environments. +??? info "Learn more about a plan's actions: `sqlmesh plan --explain`" - After SQLMesh creates a plan, it summarizes the breaking and non-breaking changes so you can understand what will happen if you apply the plan. It will prompt you to "backfill" data to apply the plan - in this context, backfill is a generic term for updating or adding to a table's data (including an initial load or full refresh). + Before applying a plan, you can view a detailed description of the actions it will take by passing the explain flag in your `sqlmesh plan` command: + + ```bash + sqlmesh plan --explain + ``` + + Passing the explain flag for the quickstart example project above adds the following information to the output: + + ```bash + Explained plan + ├── Validate SQL and create physical layer tables and views if they do not exist + │ ├── sqlmesh_example.seed_model -> db.sqlmesh__sqlmesh_example.sqlmesh_example__seed_model__2185867172 + │ │ ├── Dry run model query without inserting results + │ │ └── Create table if it doesn't exist + │ ├── sqlmesh_example.full_model -> db.sqlmesh__sqlmesh_example.sqlmesh_example__full_model__2278521865 + │ │ ├── Dry run model query without inserting results + │ │ └── Create table if it doesn't exist + │ └── sqlmesh_example.incremental_model -> db.sqlmesh__sqlmesh_example.sqlmesh_example__incremental_model__1880815781 + │ ├── Dry run model query without inserting results + │ └── Create table if it doesn't exist + ├── Backfill models by running their queries and run standalone audits + │ ├── sqlmesh_example.seed_model -> db.sqlmesh__sqlmesh_example.sqlmesh_example__seed_model__2185867172 + │ │ └── Fully refresh table + │ ├── sqlmesh_example.full_model -> db.sqlmesh__sqlmesh_example.sqlmesh_example__full_model__2278521865 + │ │ ├── Fully refresh table + │ 
│ └── Run 'assert_positive_order_ids' audit + │ └── sqlmesh_example.incremental_model -> db.sqlmesh__sqlmesh_example.sqlmesh_example__incremental_model__1880815781 + │ └── Fully refresh table + └── Update the virtual layer for environment 'prod' + └── Create or update views in the virtual layer to point at new physical tables and views + ├── sqlmesh_example.full_model -> db.sqlmesh__sqlmesh_example.sqlmesh_example__full_model__2278521865 + ├── sqlmesh_example.seed_model -> db.sqlmesh__sqlmesh_example.sqlmesh_example__seed_model__2185867172 + └── sqlmesh_example.incremental_model -> db.sqlmesh__sqlmesh_example.sqlmesh_example__incremental_model__1880815781 + ``` + + The explanation has three top-level sections, corresponding to the three types of actions a plan takes: + + - Validate SQL and create physical layer tables and views if they do not exist + - Backfill models by running their queries and run standalone audits + - Update the virtual layer for environment 'prod' + + Each section lists the affected models and provides more information about what will occur. For example, the first model in the first section is: + + ```bash + ├── sqlmesh_example.seed_model -> db.sqlmesh__sqlmesh_example.sqlmesh_example__seed_model__2185867172 + │ ├── Dry run model query without inserting results + │ └── Create table if it doesn't exist + ``` + + The first line shows the model name `sqlmesh_example.seed_model` and the physical layer table SQLMesh will create to store its data: `db.sqlmesh__sqlmesh_example.sqlmesh_example__seed_model__2185867172`. The second and third lines tell us that in this step SQLMesh will dry-run the model query and create the physical layer table if it doesn't exist. + + The second section describes what will occur during the backfill step. 
The second model in this section is: + + ```bash + ├── sqlmesh_example.full_model -> db.sqlmesh__sqlmesh_example.sqlmesh_example__full_model__2278521865 + │ ├── Fully refresh table + │ └── Run 'assert_positive_order_ids' audit + ``` + + The first line shows the model name `sqlmesh_example.full_model` and the physical layer table SQLMesh will insert the model's data into: `db.sqlmesh__sqlmesh_example.sqlmesh_example__full_model__2278521865`. The second and third lines tell us that the backfill action will fully refresh the model's physical table and run the `assert_positive_order_ids` audit. + + The final section describes SQLMesh's action during the virtual layer update step. The first model in this section is: + + ```bash + └── Create or update views in the virtual layer to point at new physical tables and views + ├── sqlmesh_example.full_model -> db.sqlmesh__sqlmesh_example.sqlmesh_example__full_model__2278521865 + ``` + + The virtual layer step will update the `sqlmesh_example.full_model` virtual layer view to `SELECT * FROM` the physical table `db.sqlmesh__sqlmesh_example.sqlmesh_example__full_model__2278521865`. The first SQLMesh plan must execute every model to populate the production environment. Running `sqlmesh plan` will generate the plan and the following output: ```bash linenums="1" $ sqlmesh plan ====================================================================== -Successfully Ran 1 tests against duckdb +Successfully Ran 1 tests against duckdb in 0.1 seconds. 
---------------------------------------------------------------------- -New environment `prod` will be created from `prod` -Summary of differences against `prod`: -└── Added Models: - ├── sqlmesh_example.seed_model + +`prod` environment will be initialized + +Models: +└── Added: + ├── sqlmesh_example.full_model ├── sqlmesh_example.incremental_model - └── sqlmesh_example.full_model -Models needing backfill (missing dates): -├── sqlmesh_example.full_model: 2020-01-01 - 2023-05-31 -├── sqlmesh_example.incremental_model: 2020-01-01 - 2023-05-31 -└── sqlmesh_example.seed_model: 2023-05-31 - 2023-05-31 + └── sqlmesh_example.seed_model +Models needing backfill: +├── sqlmesh_example.full_model: [full refresh] +├── sqlmesh_example.incremental_model: [2020-01-01 - 2025-06-22] +└── sqlmesh_example.seed_model: [full refresh] Apply - Backfill Tables [y/n]: ``` Line 3 of the output notes that `sqlmesh plan` successfully executed the project's test `tests/test_full_model.yaml` with duckdb. -Line 5 describes what environments the plan will affect when applied - a new `prod` environment in this case. - -Lines 7-10 of the output show that SQLMesh detected three new models relative to the current empty environment. +Line 6 describes what environments the plan will affect when applied - a new `prod` environment in this case. -Lines 11-14 list each model that will be executed by the plan, along with the date intervals that will be run. Note that `full_model` and `incremental_model` both show `2020-01-01` as their start date because: +Lines 8-12 of the output show that SQLMesh detected three new models relative to the current empty environment. -1. The incremental model specifies that date in the `start` property of its `MODEL` statement and -2. The full model depends on the incremental model. 
- -The `seed_model` date range begins on the same day the plan was made because `SEED` models have no temporality associated with them other than whether they have been modified since the previous SQLMesh plan. +Lines 13-16 list each model that will be executed by the plan, along with the date intervals or refresh types. For both `full_model` and `seed_model`, it shows `[full refresh]`, while for `incremental_model` it shows a specific date range `[2020-01-01 - 2025-06-22]`. The incremental model date range begins from 2020-01-01 because its definition specifies a model start date of `2020-01-01`. ??? info "Learn more about the project's models" @@ -249,24 +465,25 @@ The `seed_model` date range begins on the same day the plan was made because `SE GROUP BY item_id ``` -Line 15 asks you whether to proceed with executing the model backfills described in lines 11-14. Enter `y` and press `Enter`, and SQLMesh will execute the models and return this output: +Line 18 asks you whether to proceed with executing the model backfills described in lines 13-16. 
Enter `y` and press `Enter`, and SQLMesh will execute the models and return this output: ```bash linenums="1" Apply - Backfill Tables [y/n]: y -Creating physical tables ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 -All model versions have been created successfully +Updating physical layer ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 + +✔ Physical layer updated -[1/1] sqlmesh_example.seed_model evaluated in 0.01s -[1/1] sqlmesh_example.incremental_model evaluated in 0.01s -[1/1] sqlmesh_example.full_model evaluated in 0.02s -Evaluating models ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 +[1/1] sqlmesh_example.seed_model [insert seed file] 0.01s +[1/1] sqlmesh_example.incremental_model [insert 2020-01-01 - 2025-06-22] 0.01s +[1/1] sqlmesh_example.full_model [full refresh, audits ✔1] 0.01s +Executing model batches ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 -All model batches have been executed successfully +✔ Model batches executed -Virtually Updating 'prod' ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 0:00:00 +Updating virtual layer ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 3/3 • 0:00:00 -The target environment has been updated successfully +✔ Virtual layer updated ``` SQLMesh performs three actions when applying the plan: @@ -275,13 +492,25 @@ SQLMesh performs three actions when applying the plan: - Evaluating/running the models - Virtually updating the plan's target environment -Line 2 provides a progress bar and elapsed time for the first step of creating new model versions (very fast in this simple project). Line 4 reports that the first step has completed. +Lines 2-4 show the progress and completion of the first step - updating the physical layer (creating new model versions). + +Lines 6-11 show the execution of each model with their specific operations and timing. 
Line 6 shows the seed model being inserted, line 8 shows the incremental model being inserted for the specified date range, and line 10 shows the full model being processed with its audit check passing. + +Lines 12-14 show the progress and completion of the second step - executing model batches. + +Lines 16-18 show the progress and completion of the final step - virtually updating the plan's target environment, which makes the data available for querying. + +Let's take a quick look at the project's DuckDB database file to see the objects SQLMesh created. First, we open the built-in DuckDB CLI tool with the `duckdb db.db` command, then run our two queries. -Lines 6-8 show the run time for each model in the project. Line 9 provides a progress bar and total elapsed time for the second step of evaluating the models. Line 11 reports that the second step has completed. +Our first query shows the three physical tables SQLMesh created in the `sqlmesh__sqlmesh_example` schema (one table for each model): -Line 13 provides a progress bar and total elapsed time for the third step of virtually updating the plan's target environment. Line 15 reports that the third step has completed and the `prod` environment now points to the tables created during model execution. +![Example project physical layer tables in the DuckDB CLI](./cli/cli-quickstart_duckdb-tables.png) -You've now created a new production environment with all of history backfilled. +Our second query shows that in the `sqlmesh` schema SQLMesh created three virtual layer views that read from the three physical tables: + +![Example project virtual layer views in the DuckDB CLI](./cli/cli-quickstart_duckdb-views.png) + +You've now created a new production environment with all of history backfilled! ## 3. 
Update a model @@ -323,18 +552,23 @@ $ sqlmesh plan dev ====================================================================== Successfully Ran 1 tests against duckdb ---------------------------------------------------------------------- + New environment `dev` will be created from `prod` -Summary of differences against `dev`: + + +Differences from the `prod` environment: + Models: ├── Directly Modified: │ └── sqlmesh_example__dev.incremental_model └── Indirectly Modified: └── sqlmesh_example__dev.full_model + --- +++ -@@ -10,6 +10,7 @@ +@@ -14,6 +14,7 @@ SELECT id, @@ -343,42 +577,48 @@ Models: event_date FROM sqlmesh_example.seed_model WHERE -Directly Modified: sqlmesh_example__dev.incremental_model (Non-breaking) + +Directly Modified: sqlmesh_example__dev.incremental_model +(Non-breaking) └── Indirectly Modified Children: └── sqlmesh_example__dev.full_model (Indirect Non-breaking) -Models needing backfill (missing dates): -└── sqlmesh_example__dev.incremental_model: 2020-01-01 - 2023-05-31 -Enter the backfill start date (eg. '1 year', '2020-01-01') or blank to backfill from the beginning of history: +Models needing backfill: +└── sqlmesh_example__dev.incremental_model: [2020-01-01 - 2025-04-17] +Apply - Backfill Tables [y/n]: ``` -Line 5 of the output states that a new environment `dev` will be created from the existing `prod` environment. +Line 6 of the output states that a new environment `dev` will be created from the existing `prod` environment. -Lines 6-11 summarize the differences between the modified model and the `prod` environment, detecting that we directly modified `incremental_model` and that `full_model` was indirectly modified because it selects from the incremental model. Note that the model schemas are `sqlmesh_example__dev`, indicating that they are being created in the `dev` environment. 
+Lines 10-15 summarize the differences between the modified model and the `prod` environment, detecting that we directly modified `incremental_model` and that `full_model` was indirectly modified because it selects from the incremental model. Note that the model schemas are `sqlmesh_example__dev`, indicating that they are being created in the `dev` environment. -On line 25, we see that SQLMesh automatically classified the change as `Non-breaking` because it understood that the change was additive (added a column not used by `full_model`) and did not invalidate any data already in `prod`. +On line 31, we see that SQLMesh automatically classified the change as `Non-breaking` because it understood that the change was additive (added a column not used by `full_model`) and did not invalidate any data already in `prod`. -Hit `Enter` at the prompt to backfill data from our start date `2020-01-01`. Another prompt will appear asking for a backfill end date; hit `Enter` to backfill until now. Finally, enter `y` and press `Enter` to apply the plan and execute the backfill: +Enter `y` at the prompt and press `Enter` to apply the plan and execute the backfill: ```bash linenums="1" -Enter the backfill start date (eg. '1 year', '2020-01-01') or blank to backfill from the beginning of history: -Enter the backfill end date (eg. 
'1 month ago', '2020-01-01') or blank to backfill up until now: Apply - Backfill Tables [y/n]: y -Creating physical tables ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 2/2 • 0:00:00 -All model versions have been created successfully +Updating physical layer ━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 2/2 • 0:00:00 -[1/1] sqlmesh_example__dev.incremental_model evaluated in 0.01s -Evaluating models ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 1/1 • 0:00:00 +✔ Physical layer updated +[1/1] sqlmesh_example__dev.incremental_model [insert 2020-01-01 - 2025-04-17] 0.03s +Executing model batches ━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 1/1 • 0:00:00 -All model batches have been executed successfully +✔ Model batches executed -Virtually Updating 'dev' ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 0:00:00 +Updating virtual layer ━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 2/2 • 0:00:00 -The target environment has been updated successfully +✔ Virtual layer updated ``` -Line 8 of the output shows that SQLMesh applied the change and evaluated `sqlmesh_example__dev.incremental_model`. +Lines 3-5 show the progress and completion of updating the physical layer. + +Line 7 shows that SQLMesh applied the change and evaluated `sqlmesh_example__dev.incremental_model` for the date range from 2020-01-01 to 2025-04-17. + +Lines 9-11 show the progress and completion of executing model batches. + +Lines 13-15 show the progress and completion of updating the virtual layer. SQLMesh did not need to backfill anything for the `full_model` since the change was `Non-breaking`. 
@@ -431,17 +671,20 @@ $ sqlmesh plan ====================================================================== Successfully Ran 1 tests against duckdb ---------------------------------------------------------------------- -Summary of differences against `prod`: + +Differences from the `prod` environment: + Models: ├── Directly Modified: │ └── sqlmesh_example.incremental_model └── Indirectly Modified: └── sqlmesh_example.full_model + --- +++ -@@ -10,6 +10,7 @@ +@@ -14,6 +14,7 @@ SELECT id, @@ -450,18 +693,22 @@ Models: event_date FROM sqlmesh_example.seed_model WHERE + Directly Modified: sqlmesh_example.incremental_model (Non-breaking) └── Indirectly Modified Children: └── sqlmesh_example.full_model (Indirect Non-breaking) Apply - Virtual Update [y/n]: y -Virtually Updating 'prod' ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 0:00:00 -The target environment has been updated successfully +SKIP: No physical layer updates to perform + +SKIP: No model batches to execute + +Updating virtual layer ━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 2/2 • 0:00:00 -Virtual Update executed successfully +✔ Virtual layer updated ``` -Note that a backfill was not necessary and only a Virtual Update occurred. +Note that a backfill was not necessary and only a Virtual Update occurred, as indicated by the "SKIP: No physical layer updates to perform" and "SKIP: No model batches to execute" messages. This is because the changes were already calculated and executed in the `dev` environment, and SQLMesh is smart enough to recognize that it only needs to update the virtual references to the existing tables rather than recomputing everything. 
### 5.2 Validate updates in prod Double-check that the data updated in `prod` by running `sqlmesh fetchdf "select * from sqlmesh_example.incremental_model"`: @@ -488,4 +735,4 @@ From here, you can: * [Learn more about SQLMesh CLI commands](../reference/cli.md) * [Set up a connection to a database or SQL engine](../guides/connections.md) * [Learn more about SQLMesh concepts](../concepts/overview.md) -* [Join our Slack community](https://tobikodata.com/slack) +* [Join our Slack community](https://tobikodata.com/slack) \ No newline at end of file diff --git a/docs/quickstart/cli/cli-quickstart_duckdb-tables.png b/docs/quickstart/cli/cli-quickstart_duckdb-tables.png new file mode 100644 index 0000000000..27d7180f8d Binary files /dev/null and b/docs/quickstart/cli/cli-quickstart_duckdb-tables.png differ diff --git a/docs/quickstart/cli/cli-quickstart_duckdb-views.png b/docs/quickstart/cli/cli-quickstart_duckdb-views.png new file mode 100644 index 0000000000..5d6af7fc87 Binary files /dev/null and b/docs/quickstart/cli/cli-quickstart_duckdb-views.png differ diff --git a/docs/quickstart/notebook.md b/docs/quickstart/notebook.md index bb0ce1d68d..a1dae6b822 100644 --- a/docs/quickstart/notebook.md +++ b/docs/quickstart/notebook.md @@ -1,6 +1,6 @@ # Notebook -In this quick start guide, you'll use the SQLMesh notebook interface to get up and running with SQLMesh's scaffold generator. This example project will run locally on your computer using [DuckDB](https://duckdb.org/) as an embedded SQL engine. +In this quickstart, you'll use the SQLMesh notebook interface to get up and running with SQLMesh's scaffold generator. This example project will run locally on your computer using [DuckDB](https://duckdb.org/) as an embedded SQL engine. Before beginning, ensure that you meet all the [prerequisites](../prerequisites.md) for using SQLMesh. @@ -34,17 +34,17 @@ The notebook interface works with both Jupyter and Databricks notebooks. Learn m ## 1. 
Create the SQLMesh project First, create a SQLMesh project directory with your operating system's graphical or command-line tools. Next, create a Jupyter or Databricks notebook file - it does not need to be in the SQLMesh project directory. -If using a python virtual environment, ensure it's activated first by running the `source .env/bin/activate` command from the folder used during [installation](../installation.md). +If using a python virtual environment, ensure it's activated first by running the `source .venv/bin/activate` command from the folder used during [installation](../installation.md). Import the SQLMesh library to load the notebook magic commands: -![Cell importing the SQLMesh library](./notebook/nb-quickstart_import.png) +![Cell importing the SQLMesh library](./notebook/nb-quickstart_import.png){ loading=lazy } Next, create a SQLMesh scaffold with the `%init` notebook magic, specifying a default SQL dialect for your models. The dialect should correspond to the dialect most of your models are written in; it can be overridden for specific models in the model's `MODEL` specification. All SQL dialects [supported by the SQLGlot library](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/dialect.py) are allowed. In this example, we specify the `duckdb` dialect: -![Notebook output after project initiation](./notebook/nb-quickstart_init.png) +![Notebook output after project initiation](./notebook/nb-quickstart_init.png){ loading=lazy } If the scaffold is successfully created, it will return `SQLMesh project scaffold created`. @@ -126,7 +126,7 @@ Finally, the scaffold will include data for the example project to use. Inform SQLMesh of the project location by setting a context with the `%context` notebook magic. 
If the context is set successfully, it will return a message including the repository or list of repositories: -![Notebook output after setting SQLMesh context](./notebook/nb-quickstart_context.png) +![Notebook output after setting SQLMesh context](./notebook/nb-quickstart_context.png){ loading=lazy } You can specify multiple directories in one call to `%context` if your SQLMesh project has [multiple repositories](../guides/multi_repo.md). @@ -240,7 +240,7 @@ The `seed_model` date range begins on the same day the plan was made because `SE Click the green button labeled `Apply - Backfill Tables` to apply the plan and initiate backfill. The following output will be displayed: -![Notebook output after plan application](./notebook/nb-quickstart_apply-plan.png) +![Notebook output after plan application](./notebook/nb-quickstart_apply-plan.png){ loading=lazy } The first output block shows the completion percentage and run time for each model (very fast in this simple example). The following line shows that the `prod` environment now points to the tables created during model execution. @@ -252,15 +252,15 @@ Now that we have populated the `prod` environment, let's modify one of the We can modify the incremental SQL model using the `%model` *line* notebook magic (note the single `%`) and the model name: -![%model line magic for sqlmesh_example.incremental_model](./notebook/nb-quickstart_model-line.png) +![%model line magic for sqlmesh_example.incremental_model](./notebook/nb-quickstart_model-line.png){ loading=lazy } After we execute the cell, the contents will be replaced by the `%%model` *cell* notebook magic (note the double `%%`) and the model contents, along with a rendered version of the model SQL query.
SQLMesh has automatically added explicit column aliases to the query (e.g., `id AS id`): -![%%model cell magic for sqlmesh_example.incremental_model](./notebook/nb-quickstart_model-cell.png) +![%%model cell magic for sqlmesh_example.incremental_model](./notebook/nb-quickstart_model-cell.png){ loading=lazy } We modify the incremental SQL model by adding a new column to the query. When we execute the cell it will write the updated model contents to the file and update the rendered version of the query: -![%%model cell magic for updated sqlmesh_example.incremental_model](./notebook/nb-quickstart_model-cell-updated.png) +![%%model cell magic for updated sqlmesh_example.incremental_model](./notebook/nb-quickstart_model-cell-updated.png){ loading=lazy } ## 4. Work with a development environment @@ -269,7 +269,7 @@ Now that you've modified a model, it's time to create a development environment Run `%plan dev` to create a development environment called `dev`. The following output will be displayed: -![Notebook output after dev plan creation](./notebook/nb-quickstart_plan-dev.png) +![Notebook output after dev plan creation](./notebook/nb-quickstart_plan-dev.png){ loading=lazy } The first block of output notes that `%plan` successfully executed the project's test `tests/test_full_model.yaml` with duckdb. @@ -283,7 +283,7 @@ The `Models needing backfill` section shows that only the directly modified `inc Click the green button to perform the backfill: -![Notebook output after dev plan application](./notebook/nb-quickstart_apply-plan-dev.png) +![Notebook output after dev plan application](./notebook/nb-quickstart_apply-plan-dev.png){ loading=lazy } The output shows that SQLMesh created a new model version in `dev`. The last line of the output shows that SQLMesh applied the change to `sqlmesh_example__dev.incremental_model`. In the model schema, the suffix "`__dev`" indicates that it is in the `dev` environment. 
@@ -294,7 +294,7 @@ You can now view this change by querying data from `incremental_model` with the Note that the environment name `__dev` is appended to the schema namespace `sqlmesh_example` in the query: -![Notebook output after executing %%fetchdf on `dev` incremental_model](./notebook/nb-quickstart_fetchdf-dev.png) +![Notebook output after executing %%fetchdf on `dev` incremental_model](./notebook/nb-quickstart_fetchdf-dev.png){ loading=lazy } You can see that `new_column` was added to the dataset. @@ -302,25 +302,25 @@ The production table was not modified; you can validate this by querying the pro Note that nothing has been appended to the schema namespace `sqlmesh_example` because `prod` is the default environment: -![Notebook output after executing %%fetchdf on prod incremental_model before model update applied](./notebook/nb-quickstart_fetchdf-prod.png) +![Notebook output after executing %%fetchdf on prod incremental_model before model update applied](./notebook/nb-quickstart_fetchdf-prod.png){ loading=lazy } The production table does not have `new_column` because the changes to `dev` have not yet been applied to `prod`. ## 5. Update the prod environment Now that we've tested the changes in dev, it's time to move them to production. 
Run `%plan` to plan and apply your changes to the `prod` environment: -![Notebook output after executing %plan on prod](./notebook/nb-quickstart_apply-plan-prod-modified.png) +![Notebook output after executing %plan on prod](./notebook/nb-quickstart_apply-plan-prod-modified.png){ loading=lazy } Click the green `Apply - Virtual Update` button to apply the plan and execute the backfill: -![Notebook output after executing applying virtual update on prod](./notebook/nb-quickstart_apply-plan-prod-modified-update.png) +![Notebook output after applying virtual update on prod](./notebook/nb-quickstart_apply-plan-prod-modified-update.png){ loading=lazy } Note that a backfill was not necessary and only a Virtual Update occurred. ### 5.2 Validate updates in prod Double-check that the data updated in `prod` by running `%%fetchdf` with the SQL query `select * from sqlmesh_example.incremental_model`: -![Notebook output after executing %%fetchdf on prod incremental_model after model update applied](./notebook/nb-quickstart_fetchdf-prod-modified.png) +![Notebook output after executing %%fetchdf on prod incremental_model after model update applied](./notebook/nb-quickstart_fetchdf-prod-modified.png){ loading=lazy } `new_column` is now present in the `prod` incremental model. diff --git a/docs/quickstart/ui.md b/docs/quickstart/ui.md index 42cf7c5a20..2891536876 100644 --- a/docs/quickstart/ui.md +++ b/docs/quickstart/ui.md @@ -1,6 +1,10 @@ # Browser UI -In this quick start guide, you'll use the SQLMesh browser user interface to get up and running with SQLMesh's scaffold generator. This example project will run locally on your computer using [DuckDB](https://duckdb.org/) as an embedded SQL engine. +!!! warning + + Browser UI is deprecated. Please use the [VSCode extension](../guides/vscode.md) instead. + +In this quickstart, you'll use the SQLMesh browser user interface to get up and running with SQLMesh's scaffold generator. 
This example project will run locally on your computer using [DuckDB](https://duckdb.org/) as an embedded SQL engine. ??? info "Learn more about the quickstart project structure" This project demonstrates key SQLMesh features by walking through the SQLMesh workflow on a simple data pipeline. This section describes the project structure and the SQLMesh concepts you will encounter as you work through it. @@ -31,7 +35,7 @@ In this quick start guide, you'll use the SQLMesh browser user interface to get Before beginning, ensure that you meet all the [prerequisites](../prerequisites.md) for using SQLMesh. The SQLMesh browser UI requires additional Python libraries not included in the base SQLMesh installation. -To use the UI, install SQLMesh with the `web` add-on. First, if using a python virtual environment, ensure it's activated by running `source .env/bin/activate` command from the folder used during [installation](../installation.md). +To use the UI, install SQLMesh with the `web` add-on. First, if using a python virtual environment, ensure it's activated by running `source .venv/bin/activate` command from the folder used during [installation](../installation.md). Next, install the UI with `pip`: @@ -52,7 +56,7 @@ Navigate to the directory on the command line: cd sqlmesh-example ``` -If using a python virtual environment, ensure it's activated by running `source .env/bin/activate` from the folder used during [installation](../installation.md). +If using a python virtual environment, ensure it's activated by running `source .venv/bin/activate` from the folder used during [installation](../installation.md). Create a SQLMesh scaffold with the following command, specifying a default SQL dialect for your models. The dialect should correspond to the dialect most of your models are written in; it can be overridden for specific models in the model's `MODEL` specification. 
All SQL dialects [supported by the SQLGlot library](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/dialect.py) are allowed. @@ -148,11 +152,11 @@ sqlmesh ui After starting up, the SQLMesh web UI is served at `http://127.0.0.1:8000` by default: -![SQLMesh web UI startup on CLI](./ui/ui-quickstart_cli.png) +![SQLMesh web UI startup on CLI](./ui/ui-quickstart_cli.png){ loading=lazy } Navigate to the URL by clicking the link in your terminal (if supported) or copy-pasting it into your web browser: -![SQLMesh web UI startup in browser](./ui/ui-quickstart_ui-startup.png) +![SQLMesh web UI startup in browser](./ui/ui-quickstart_ui-startup.png){ loading=lazy } The SQLMesh UI default view contains five panes: @@ -162,11 +166,11 @@ The SQLMesh UI default view contains five panes: 4. Inspector provides settings and information based on recent actions and the currently active pane. (Note: inspector pane is collapsed by default. Expand it by clicking the hamburger button at the top of the collapsed pane - see previous image.) 5. Details displays column-level lineage for models open in the editor and results of queries. (Note: details pane is collapsed by default. It will automatically expand upon opening a model in the editor or running a query.) -![SQLMesh web UI panes](./ui/ui-quickstart_ui-startup-panes.png) +![SQLMesh web UI panes](./ui/ui-quickstart_ui-startup-panes.png){ loading=lazy } It also contains nine buttons: -1. Toggle Editor/Docs/Errors toggles among the Code Editor (default), Docs, and Errors views. Errors view is only available if an error has occurred. +1. Toggle Editor/Data Catalog/Errors toggles among the Code Editor (default), Data Catalog, and Errors views. Errors view is only available if an error has occurred. 2. History navigation returns to previous views, similar to the back button in a web browser. 3. Add new tab opens a new code editor window. 4. Run plan command executes the [`sqlmesh plan` command](../reference/cli.md#plan). 
@@ -176,7 +180,7 @@ It also contains nine buttons: 8. Format SQL query reformats a SQL query using SQLGlot's pretty layout. 9. Change SQL dialect specifies the SQL dialect of the current tab for custom SQL queries. It does not affect the SQL dialect for the project. -![SQLMesh web UI buttons](./ui/ui-quickstart_ui-startup-buttons.png) +![SQLMesh web UI buttons](./ui/ui-quickstart_ui-startup-buttons.png){ loading=lazy } The default view contains four status indicators: @@ -185,7 +189,7 @@ The default view contains four status indicators: 3. Change indicator displays a summary of the changes in the project files relative to the most recently run SQLMesh plan in the selected environment. 4. Error indicator displays the count of errors in the project. -![SQLMesh web UI status indicators](./ui/ui-quickstart_ui-startup-status.png) +![SQLMesh web UI status indicators](./ui/ui-quickstart_ui-startup-status.png){ loading=lazy } ## 3. Plan and apply environments ### 3.1 Create a prod environment @@ -288,7 +292,7 @@ The pane contains multiple pieces of information about the plan: GROUP BY item_id ``` -![Run plan pane](./ui/ui-quickstart_run-plan.png) +![Run plan pane](./ui/ui-quickstart_run-plan.png){ loading=lazy } Click the blue button labeled `Apply Changes And Backfill` to apply the plan and initiate backfill. @@ -301,7 +305,7 @@ The `Snapshot Tables Created` indicates that [snapshots](../concepts/architectur The `Backfilled` section shows progress indicators for the backfill operations. The first progress indicator shows the total number of tasks and completion percentage for the entire backfill operation. The remaining progress bars show completion percentage and run time for each model (very fast in this simple example). -![Apply plan pane](./ui/ui-quickstart_apply-plan.png) +![Apply plan pane](./ui/ui-quickstart_apply-plan.png){ loading=lazy } Click the `Go Back` button to close the pane. 
@@ -312,21 +316,21 @@ Now that you've created the production environment, it's time to create a develo Open the environment menu by clicking the button labeled `prod \/` next to the green `Plan` button on the top right. Type `dev` into the Environment field and click the blue `Add` button. -![Open environment menu](./ui/ui-quickstart_create-dev.png) +![Open environment menu](./ui/ui-quickstart_create-dev.png){ loading=lazy } The button now shows that the SQLMesh UI is working in the `dev` environment: -![Working in dev environment](./ui/ui-quickstart_plan-dev.png) +![Working in dev environment](./ui/ui-quickstart_plan-dev.png){ loading=lazy } Click the green `Plan` button, and a new pane will open: -![Run plan on dev pane](./ui/ui-quickstart_run-plan-dev.png) +![Run plan on dev pane](./ui/ui-quickstart_run-plan-dev.png){ loading=lazy } The output section does not list any added/modified models or backfills because `dev` is being created from the existing `prod` environment without modification. Because the project has not been modified, no new computations need to run and a virtual update occurs. Click the blue `Apply Virtual Update` button to apply the new plan: -![Run plan on dev pane output](./ui/ui-quickstart_run-plan-dev-output.png) +![Run plan on dev pane output](./ui/ui-quickstart_run-plan-dev-output.png){ loading=lazy } The output confirms that the tests, virtual update, snapshot table creation, and environment promotion steps have completed. Click the `Go Back` button to close the pane. @@ -339,68 +343,68 @@ To modify the incremental SQL model, open it in the editor by clicking on it in The `Details` pane at the bottom displays the project's table and column lineage. -![Incremental model open in editor](./ui/ui-quickstart_incremental-model.png) +![Incremental model open in editor](./ui/ui-quickstart_incremental-model.png){ loading=lazy } Modify the incremental SQL model by adding a new column to the query. 
Press `Cmd + S` (`Ctrl + S` on Windows) to save the modified model file and display the updated lineage: -![Incremental model modified in editor](./ui/ui-quickstart_incremental-model-modified.png) +![Incremental model modified in editor](./ui/ui-quickstart_incremental-model-modified.png){ loading=lazy } ## 4. Plan and apply updates Preview the impact of the change by clicking the green `Plan` button in the top right. -![Plan pane after running plan with modified incremental model](./ui/ui-quickstart_run-plan-dev-modified.png) +![Plan pane after running plan with modified incremental model](./ui/ui-quickstart_run-plan-dev-modified.png){ loading=lazy } The `Changes` section detects that we directly modified `incremental_model` and that `full_model` was indirectly modified because it selects from the incremental model. SQLMesh understood that the change was additive (added a column not used by `full_model`) and was automatically classified as a non-breaking change. The `Backfill` section shows that only `incremental_model` requires backfill. Click the blue `Apply Changes And Backfill` button to apply the plan and execute the backfill: -![Plan after applying updated plan with modified incremental model](./ui/ui-quickstart_apply-plan-dev-modified.png) +![Plan after applying updated plan with modified incremental model](./ui/ui-quickstart_apply-plan-dev-modified.png){ loading=lazy } SQLMesh applies the change to `sqlmesh_example.incremental_model` and backfills the model. The `Backfilled` section shows that the backfill completed successfully. ### 4.1 Validate updates in dev You can now view this change by querying data from `incremental_model`. 
Add the SQL query `select * from sqlmesh_example__dev.incremental_model` to the Custom SQL 1 tab in the editor: -![Querying `dev` incremental model with SQL query in editor](./ui/ui-quickstart_fetchdf-dev.png) +![Querying `dev` incremental model with SQL query in editor](./ui/ui-quickstart_fetchdf-dev.png){ loading=lazy } Note that the environment name `__dev` is appended to the schema namespace `sqlmesh_example` in the query: `select * from sqlmesh_example__dev.incremental_model`. Click the `Run Query` button in the bottom right to execute the query: -![Results from querying dev incremental model with SQL query in editor](./ui/ui-quickstart_fetchdf-dev-results.png) +![Results from querying dev incremental model with SQL query in editor](./ui/ui-quickstart_fetchdf-dev-results.png){ loading=lazy } You can see that `new_column` was added to the dataset. The production table was not modified; you can validate this by modifying the query so it selects from the production table with `select * from sqlmesh_example.incremental_model`. Note that nothing has been appended to the schema namespace `sqlmesh_example` because `prod` is the default environment. -![Results from querying `prod` incremental model with SQL query in editor](./ui/ui-quickstart_fetchdf-prod.png) +![Results from querying `prod` incremental model with SQL query in editor](./ui/ui-quickstart_fetchdf-prod.png){ loading=lazy } The production table does not have `new_column` because the changes to `dev` have not yet been applied to `prod`. ### 4.2 Apply updates to prod Now that we've tested the changes in dev, it's time to move them to prod. 
Open the environment menu in top right and select the `prod` environment: -![`prod` environment selected in environment menu](./ui/ui-quickstart_plan-prod-modified.png) +![`prod` environment selected in environment menu](./ui/ui-quickstart_plan-prod-modified.png){ loading=lazy } Click the green `Plan` button to open the run plan interface: -![`prod` environment plan pane](./ui/ui-quickstart_plan-prod-modified-pane.png) +![`prod` environment plan pane](./ui/ui-quickstart_plan-prod-modified-pane.png){ loading=lazy } Click the blue `Apply Virtual Update` button, and a warning screen will appear: -![`prod` environment modification warning](./ui/ui-quickstart_plan-prod-modified-warning.png) +![`prod` environment modification warning](./ui/ui-quickstart_plan-prod-modified-warning.png){ loading=lazy } Click the `Yes, Run prod` button to proceed with applying the plan: -![`prod` environment after applying plan](./ui/ui-quickstart_apply-plan-prod-modified.png) +![`prod` environment after applying plan](./ui/ui-quickstart_apply-plan-prod-modified.png){ loading=lazy } Note that a backfill was not necessary and only a Virtual Update occurred - the computations have already occurred when backfilling the model in `dev`. Click the `Go Back` button to close the pane. ### 4.3. Validate updates in prod Double-check that the data updated in `prod` by re-running the SQL query from the editor. Click the `Run Query` button to execute the query: -![Results from querying updated `prod` incremental model with SQL query in editor](./ui/ui-quickstart_fetchdf-prod-modified.png) +![Results from querying updated `prod` incremental model with SQL query in editor](./ui/ui-quickstart_fetchdf-prod-modified.png){ loading=lazy } `new_column` is now present in the `prod` incremental model. 
diff --git a/docs/readme/architecture_diagram.png b/docs/readme/architecture_diagram.png new file mode 100644 index 0000000000..e3da16d1d6 Binary files /dev/null and b/docs/readme/architecture_diagram.png differ diff --git a/docs/readme/docs-site_2nd-level-nav_get-started.png b/docs/readme/docs-site_2nd-level-nav_get-started.png new file mode 100644 index 0000000000..43e4b01c0f Binary files /dev/null and b/docs/readme/docs-site_2nd-level-nav_get-started.png differ diff --git a/docs/readme/docs-site_2nd-level-nav_guides.png b/docs/readme/docs-site_2nd-level-nav_guides.png new file mode 100644 index 0000000000..da66d5a0dd Binary files /dev/null and b/docs/readme/docs-site_2nd-level-nav_guides.png differ diff --git a/docs/readme/docs-site_code-block-options.png b/docs/readme/docs-site_code-block-options.png new file mode 100644 index 0000000000..627fae56e1 Binary files /dev/null and b/docs/readme/docs-site_code-block-options.png differ diff --git a/docs/readme/docs-site_mkdocs-extra.png b/docs/readme/docs-site_mkdocs-extra.png new file mode 100644 index 0000000000..97192a4bc7 Binary files /dev/null and b/docs/readme/docs-site_mkdocs-extra.png differ diff --git a/docs/readme/docs-site_mkdocs-plugins.png b/docs/readme/docs-site_mkdocs-plugins.png new file mode 100644 index 0000000000..29a298e893 Binary files /dev/null and b/docs/readme/docs-site_mkdocs-plugins.png differ diff --git a/docs/readme/docs-site_mkdocs-theme.png b/docs/readme/docs-site_mkdocs-theme.png new file mode 100644 index 0000000000..552f4f3bb9 Binary files /dev/null and b/docs/readme/docs-site_mkdocs-theme.png differ diff --git a/docs/readme/docs-site_top-level-nav.png b/docs/readme/docs-site_top-level-nav.png new file mode 100644 index 0000000000..9a79c7b445 Binary files /dev/null and b/docs/readme/docs-site_top-level-nav.png differ diff --git a/docs/readme/docs-site_within-page-nav_config-guide.png b/docs/readme/docs-site_within-page-nav_config-guide.png new file mode 100644 index 
0000000000..cd72cb8a75 Binary files /dev/null and b/docs/readme/docs-site_within-page-nav_config-guide.png differ diff --git a/docs/readme/mkdocs-file.png b/docs/readme/mkdocs-file.png new file mode 100644 index 0000000000..d1f91cd9b9 Binary files /dev/null and b/docs/readme/mkdocs-file.png differ diff --git a/docs/readme/sqlmesh.png b/docs/readme/sqlmesh.png new file mode 100644 index 0000000000..1adce67650 Binary files /dev/null and b/docs/readme/sqlmesh.png differ diff --git a/docs/readme/transpile_example.png b/docs/readme/transpile_example.png new file mode 100644 index 0000000000..5dbb323a44 Binary files /dev/null and b/docs/readme/transpile_example.png differ diff --git a/docs/reference/cli.md b/docs/reference/cli.md index e1957f0eeb..a9ce9366e1 100644 --- a/docs/reference/cli.md +++ b/docs/reference/cli.md @@ -23,7 +23,10 @@ Commands: create_external_models Create a schema file containing external model... create_test Generate a unit test fixture for a given model. dag Render the DAG as an html file. + destroy The destroy command removes all project resources. diff Show the diff between the local state and the... + dlt_refresh Attaches to a DLT pipeline with the option to... + environments Prints the list of SQLMesh environments with... evaluate Evaluate a model and return a dataframe with a... fetchdf Run a SQL query and display the results. format Format all SQL models and audits. @@ -38,10 +41,12 @@ Commands: rewrite Rewrite a SQL expression with semantic... rollback Rollback SQLMesh to the previous migration. run Evaluate missing intervals for the target... + state Commands for interacting with state table_diff Show the diff between two tables. table_name Prints the name of the physical table for the... test Run model unit tests. ui Start a browser-based SQLMesh UI. + lint Run the linter for the target model(s). ``` ## audit @@ -61,6 +66,24 @@ Options: --help Show this message and exit. 
``` +## check_intervals + +``` +Usage: sqlmesh check_intervals [OPTIONS] [ENVIRONMENT] + + Show missing intervals in an environment, respecting signals. + +Options: + --no-signals Disable signal checks and only show missing intervals. + --select-model TEXT Select specific models to show missing intervals for. + -s, --start TEXT The start datetime of the interval for which this + command will be applied. + -e, --end TEXT The end datetime of the interval for which this command + will be applied. + --help Show this message and exit. +``` + + ## clean ``` @@ -92,7 +115,7 @@ Usage: sqlmesh create_test [OPTIONS] MODEL Options: -q, --query ... Queries that will be used to generate data for - the model's dependencies. [required] + the model's dependencies. -o, --overwrite When true, the fixture file will be overwritten in case it already exists. -v, --var ... Key-value pairs that will define variables @@ -121,6 +144,30 @@ Options: --help Show this message and exit. ``` +## destroy + +``` +Usage: sqlmesh destroy + + Removes all state tables, the SQLMesh cache and all project resources, including warehouse objects. This includes all tables, views and schemas managed by SQLMesh, as well as any external resources that may have been created by other tools within those schemas. + +Options: + --help Show this message and exit. +``` + +## dlt_refresh + +``` +Usage: dlt_refresh PIPELINE [OPTIONS] + + Attaches to a DLT pipeline with the option to update specific or all models of the SQLMesh project. + +Options: + -t, --table TEXT The DLT tables to generate SQLMesh models from. When none specified, all new missing tables will be generated. + -f, --force If set it will overwrite existing models with the new generated models from the DLT tables. + --help Show this message and exit. +``` + ## diff ``` @@ -132,6 +179,16 @@ Options: --help Show this message and exit. 
``` +## environments +``` +Usage: sqlmesh environments [OPTIONS] + + Prints the list of SQLMesh environments with its expiry datetime. + +Options: + --help Show this message and exit. +``` + ## evaluate ``` @@ -172,6 +229,8 @@ Options: -t, --transpile TEXT Transpile project models to the specified dialect. --append-newline Include a newline at the end of each file. + --no-rewrite-casts Preserve the existing casts, without rewriting + them to use the :: syntax. --normalize Whether or not to normalize identifiers to lowercase. --pad INTEGER Determines the pad size in a formatted string. @@ -184,6 +243,8 @@ Options: trailing. --max-text-width INTEGER The max number of characters in a segment before creating new lines in pretty mode. + --check Whether or not to check formatting (but not + actually format anything). --help Show this message and exit. ``` @@ -198,19 +259,25 @@ Usage: sqlmesh info [OPTIONS] data warehouse. Options: + --skip-connection Skip the connection test. + -v, --verbose Verbose output. --help Show this message and exit. ``` ## init ``` -Usage: sqlmesh init [OPTIONS] [SQL_DIALECT] +Usage: sqlmesh init [OPTIONS] [ENGINE] Create a new SQLMesh repository. Options: - -t, --template TEXT Project template. Supported values: airflow, dbt, - default, empty. + -t, --template TEXT Project template. Supported values: dbt, dlt, default, + empty. + --dlt-pipeline TEXT DLT pipeline for which to generate a SQLMesh project. + Use alongside template: dlt + --dlt-path TEXT The directory where the DLT pipeline resides. Use + alongside template: dlt --help Show this message and exit. ``` @@ -256,7 +323,9 @@ Options: --help Show this message and exit. ``` -**Caution**: this command affects all SQLMesh users. Contact your SQLMesh administrator before running. +!!! danger "Caution" + + The `migrate` command affects all SQLMesh users. Contact your SQLMesh administrator before running. ## plan @@ -276,6 +345,8 @@ Options: Default: prod. 
--skip-tests Skip tests prior to generating the plan if they are defined. + --skip-linter Skip linting prior to generating the plan if + the linter is enabled. -r, --restate-model TEXT Restate data for specified models and models downstream from the one specified. For production environment, all related model @@ -287,10 +358,17 @@ Options: --no-gaps Ensure that new snapshots have no data gaps when comparing to existing snapshots for matching models in the target environment. - --skip-backfill Skip the backfill step. + --skip-backfill, --dry-run Skip the backfill step and only create a + virtual update for the plan. + --empty-backfill Produce empty backfill. Like --skip-backfill + no models will be backfilled, unlike --skip- + backfill missing intervals will be recorded + as if they were backfilled. --forward-only Create a plan for forward-only changes. --allow-destructive-model TEXT Allow destructive forward-only changes to models whose names match the expression. + --allow-additive-model TEXT Allow additive forward-only changes to + models whose names match the expression. --effective-from TEXT The effective date from which to apply forward-only changes on production. --no-prompts Disable interactive prompts for the backfill @@ -305,14 +383,18 @@ Options: --select-model TEXT Select specific model changes that should be included in the plan. --backfill-model TEXT Backfill only the models whose names match - the expression. This is supported only when - targeting a development environment. + the expression. --no-diff Hide text differences for changed models. --run Run latest intervals as part of the plan application (prod environment only). --enable-preview Enable preview for forward-only models when targeting a development environment. - -v, --verbose Verbose output. + --diff-rendered Output text differences for the rendered + versions of the models and standalone + audits. + --explain Explain the plan instead of applying it. + -v, --verbose Verbose output. 
Use -vv for very verbose + output. --help Show this message and exit. ``` @@ -341,19 +423,31 @@ Usage: sqlmesh render [OPTIONS] MODEL Render a model's query, optionally expanding referenced models. Options: - -s, --start TEXT The start datetime of the interval for which this - command will be applied. - -e, --end TEXT The end datetime of the interval for which this - command will be applied. - --execution-time TEXT The execution time (defaults to now). - --expand TEXT Whether or not to expand materialized models - (defaults to False). If True, all referenced models - are expanded as raw queries. Multiple model names can - also be specified, in which case only they will be - expanded as raw queries. - --dialect TEXT The SQL dialect to render the query as. - --no-format Disable fancy formatting of the query. - --help Show this message and exit. + -s, --start TEXT The start datetime of the interval for which + this command will be applied. + -e, --end TEXT The end datetime of the interval for which this + command will be applied. + --execution-time TEXT The execution time (defaults to now). + --expand TEXT Whether or not to expand materialized models + (defaults to False). If True, all referenced + models are expanded as raw queries. Multiple + model names can also be specified, in which case + only they will be expanded as raw queries. + --dialect TEXT The SQL dialect to render the query as. + --no-format Disable fancy formatting of the query. + --max-text-width INTEGER The max number of characters in a segment before + creating new lines in pretty mode. + --leading-comma Determines whether or not the comma is leading + or trailing in select expressions. Default is + trailing. + --normalize-functions TEXT Whether or not to normalize all function names. + Possible values are: 'upper', 'lower' + --indent INTEGER Determines the indentation size in a formatted + string. + --pad INTEGER Determines the pad size in a formatted string. 
+ --normalize Whether or not to normalize identifiers to + lowercase. + --help Show this message and exit. ``` ## rewrite @@ -382,7 +476,9 @@ Options: --help Show this message and exit. ``` -**Caution**: this command affects all SQLMesh users. Contact your SQLMesh administrator before running. +!!! danger "Caution" + + The `rollback` command affects all SQLMesh users. Contact your SQLMesh administrator before running. ## run @@ -392,14 +488,73 @@ Usage: sqlmesh run [OPTIONS] [ENVIRONMENT] Evaluate missing intervals for the target environment. Options: - -s, --start TEXT The start datetime of the interval for which this command - will be applied. - -e, --end TEXT The end datetime of the interval for which this command - will be applied. - --skip-janitor Skip the janitor task. - --ignore-cron Run for all missing intervals, ignoring individual cron - schedules. - --help Show this message and exit. + -s, --start TEXT The start datetime of the interval for which + this command will be applied. + -e, --end TEXT The end datetime of the interval for which + this command will be applied. + --skip-janitor Skip the janitor task. + --ignore-cron Run for all missing intervals, ignoring + individual cron schedules. + --select-model TEXT Select specific models to run. Note: this + always includes upstream dependencies. + --exit-on-env-update INTEGER If set, the command will exit with the + specified code if the run is interrupted by an + update to the target environment. + --no-auto-upstream Do not automatically include upstream models. + Only applicable when --select-model is used. + Note: this may result in missing / invalid + data for the selected models. + --help Show this message and exit. +``` + +## state + +``` +Usage: sqlmesh state [OPTIONS] COMMAND [ARGS]... + + Commands for interacting with state + +Options: + --help Show this message and exit. 
+ +Commands: + export Export the state database to a file + import Import a state export file back into the state database +``` + +### export + +``` +Usage: sqlmesh state export [OPTIONS] + + Export the state database to a file + +Options: + -o, --output-file FILE Path to write the state export to [required] + --environment TEXT Name of environment to export. Specify multiple + --environment arguments to export multiple + environments + --local Export local state only. Note that the resulting + file will not be importable + --no-confirm Do not prompt for confirmation before exporting + existing state + --help Show this message and exit. +``` + +### import + +``` +Usage: sqlmesh state import [OPTIONS] + + Import a state export file back into the state database + +Options: + -i, --input-file FILE Path to the state file [required] + --replace Clear the remote state before loading the file. If + omitted, a merge is performed instead + --no-confirm Do not prompt for confirmation before updating + existing state + --help Show this message and exit. ``` ## table_diff @@ -407,7 +562,7 @@ Options: ``` Usage: sqlmesh table_diff [OPTIONS] SOURCE:TARGET [MODEL] - Show the diff between two tables. + Show the diff between two tables or multiple models across two environments. Options: -o, --on TEXT The column to join on. Can be specified multiple @@ -423,6 +578,12 @@ Options: floating point columns. Default: 3 --skip-grain-check Disable the check for a primary key (grain) that is missing or is not unique. + --warn-grain-check Warn if any selected model is missing a grain, + and compute diffs for the remaining models. + --temp-schema TEXT Schema used for temporary tables. It can be + `CATALOG.SCHEMA` or `SCHEMA`. Default: + `sqlmesh_temp` + -m, --select-model TEXT Select specific models to table diff. --help Show this message and exit. ``` @@ -434,9 +595,11 @@ Usage: sqlmesh table_name [OPTIONS] MODEL_NAME Prints the name of the physical table for the given model. 
Options: - --dev Print the name of the snapshot table used for previews in - development environments. - --help Show this message and exit. + --environment, --env TEXT The environment to source the model version from. + --prod If set, return the name of the physical table + that will be used in production for the model + version promoted in the target environment. + --help Show this message and exit. ``` ## test @@ -464,6 +627,17 @@ Usage: sqlmesh ui [OPTIONS] Options: --host TEXT Bind socket to this host. Default: 127.0.0.1 --port INTEGER Bind socket to this port. Default: 8000 - --mode [ide|default|docs|plan] Mode to start the UI in. Default: default + --mode [ide|catalog|docs|plan] Mode to start the UI in. Default: ide --help Show this message and exit. ``` + +## lint +``` +Usage: sqlmesh lint [OPTIONS] + Run linter for the target model(s). + +Options: + --model TEXT A model to lint. Multiple models can be linted. If no models are specified, every model will be linted. + --help Show this message and exit. + +``` diff --git a/docs/reference/configuration.md b/docs/reference/configuration.md index 7c93d254f2..b13438ee2d 100644 --- a/docs/reference/configuration.md +++ b/docs/reference/configuration.md @@ -16,28 +16,45 @@ This section describes the other root level configuration parameters. Configuration options for SQLMesh project directories. -| Option | Description | Type | Required | -| ----------------- | ------------------------------------------------------------------------------------------------------------------ | :----------: | :------: | -| `ignore_patterns` | Files that match glob patterns specified in this list are ignored when scanning the project folder (Default: `[]`) | list[string] | N | -| `project` | The project name of this config. Used for [multi-repo setups](../guides/multi_repo.md). 
| string | N | +| Option | Description | Type | Required | +| ------------------ | --------------------------------------------------------------------------------------------------------------------------- | :----------: | :------: | +| `ignore_patterns` | Files that match glob patterns specified in this list are ignored when scanning the project folder (Default: `[]`) | list[string] | N | +| `project` | The project name of this config. Used for [multi-repo setups](../guides/multi_repo.md). | string | N | +| `cache_dir` | The directory to store the SQLMesh cache. Can be an absolute path or relative to the project directory. (Default: `.cache`) | string | N | +| `log_limit` | The default number of historical log files to keep (Default: `20`) | int | N | -### Environments +### Database (Physical Layer) -Configuration options for SQLMesh environment creation and promotion. +Configuration options for how SQLMesh manages database objects in the [physical layer](../concepts/glossary.md#physical-layer). | Option | Description | Type | Required | |-------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:--------------------:|:--------:| | `snapshot_ttl` | The period of time that a model snapshot not a part of any environment should exist before being deleted. This is defined as a string with the default `in 1 week`. Other [relative dates](https://dateparser.readthedocs.io/en/latest/) can be used, such as `in 30 days`. (Default: `in 1 week`) | string | N | +| `physical_schema_override` | (Deprecated) Use `physical_schema_mapping` instead. A mapping from model schema names to names of schemas in which physical tables for the corresponding models will be placed. 
| dict[string, string] | N | +| `physical_schema_mapping` | A mapping from regular expressions to names of schemas in which physical tables for the corresponding models [will be placed](../guides/configuration.md#physical-table-schemas). (Default physical schema name: `sqlmesh__[model schema]`) | dict[string, string] | N | +| `physical_table_naming_convention`| Sets which parts of the model name are included in the physical table names. Options are `schema_and_table`, `table_only` or `hash_md5` - [additional details](../guides/configuration.md#physical-table-naming-convention). (Default: `schema_and_table`) | string | N | + +### Environments (Virtual Layer) + +Configuration options for how SQLMesh manages environment creation and promotion in the [virtual layer](../concepts/glossary.md#virtual-layer). + +| Option | Description | Type | Required | +|-------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:--------------------:|:--------:| | `environment_ttl` | The period of time that a development environment should exist before being deleted. This is defined as a string with the default `in 1 week`. Other [relative dates](https://dateparser.readthedocs.io/en/latest/) can be used, such as `in 30 days`. (Default: `in 1 week`) | string | N | | `pinned_environments` | The list of development environments that are exempt from deletion due to expiration | list[string] | N | -| `time_column_format` | The default format to use for all model time columns. 
This time format uses [python format codes](https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes) (Default: `%Y-%m-%d`) | string | N | | `default_target_environment` | The name of the environment that will be the default target for the `sqlmesh plan` and `sqlmesh run` commands. (Default: `prod`) | string | N | -| `physical_schema_override` | A mapping from model schema names to names of schemas in which physical tables for the corresponding models will be placed - [addition details](../guides/configuration.md#physical-schema-override). (Default physical schema name: `sqlmesh__[model schema]`) | string | N | -| `environment_suffix_target` | Whether SQLMesh views should append their environment name to the `schema` or `table` - [additional details](../guides/configuration.md#view-schema-override). (Default: `schema`) | string | N | +| `environment_suffix_target` | Whether SQLMesh views should append their environment name to the `schema`, `table` or `catalog` - [additional details](../guides/configuration.md#view-schema-override). (Default: `schema`) | string | N | +| `gateway_managed_virtual_layer` | Whether SQLMesh views of the virtual layer will be created by the default gateway or model specified gateways - [additional details](../guides/multi_engine.md#gateway-managed-virtual-layer). (Default: False) | boolean | N | | `environment_catalog_mapping` | A mapping from regular expressions to catalog names. The catalog name is used to determine the target catalog for a given environment. | dict[string, string] | N | -| `log_limit` | The default number of logs to keep (Default: `20`) | int | N | +| `virtual_environment_mode` | Determines the Virtual Data Environment (VDE) mode. If set to `full`, VDE is used in both production and development environments. 
The `dev_only` option enables VDE only in development environments, while in production, no virtual layer is used and models are materialized directly using their original names (i.e., no versioned physical tables). (Default: `full`) | string | N | -### Model defaults +### Models + +| Option | Description | Type | Required | +|-------------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:--------------------:|:--------:| +| `time_column_format` | The default format to use for all model time columns. This time format uses [python format codes](https://docs.python.org/3/library/datetime.html#strftime-and-strptime-format-codes) (Default: `%Y-%m-%d`) | string | N | +| `infer_python_dependencies` | Whether SQLMesh will statically analyze Python code to automatically infer Python package requirements. (Default: True) | boolean | N | +| `model_defaults` | Default [properties](./model_configuration.md#model-defaults) to set on each model. At a minimum, `dialect` must be set. | dict[string, any] | Y | The `model_defaults` key is **required** and must contain a value for the `dialect` key. @@ -55,20 +72,29 @@ Global variable values may be any of the data types in the table below or lists |-------------|-------------------------------------|:------------------------------------------------------------:|:--------:| | `variables` | Mapping of variable names to values | dict[string, int \| float \| bool \| string \| list \| dict] | N | +### Before_all / after_all + +The `before_all` and `after_all` keys can be used to specify lists of SQL statements and/or SQLMesh macros that are executed at the start and end, respectively, of the `sqlmesh plan` and `sqlmesh run` commands. 
For more information and examples, see [the configuration guide](../guides/configuration.md#before_all-and-after_all-statements). + +| Option | Description | Type | Required | +|--------------|--------------------------------------------------------------------------------------|:------------:|:--------:| +| `before_all` | List of SQL statements to be executed at the start of the `plan` and `run` commands. | list[string] | N | +| `after_all` | List of SQL statements to be executed at the end of the `plan` and `run` commands. | list[string] | N | ## Plan Configuration for the `sqlmesh plan` command. -| Option | Description | Type | Required | -| ------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------: | :------: | -| `auto_categorize_changes` | Indicates whether SQLMesh should attempt to automatically [categorize](../concepts/plans.md#change-categories) model changes during plan creation per each model source type ([additional details](../guides/configuration.md#auto-categorize-changes)) | dict[string, string] | N | -| `include_unmodified` | Indicates whether to create views for all models in the target development environment or only for modified ones (Default: False) | boolean | N | -| `auto_apply` | Indicates whether to automatically apply a new plan after creation (Default: False) | boolean | N | -| `forward_only` | Indicates whether the plan should be [forward-only](../concepts/plans.md#forward-only-plans) (Default: False) | boolean | N | -| `enable_preview` | Indicates whether to enable [data preview](../concepts/plans.md#data-preview) for forward-only models when targeting a development environment (Default: False) | boolean | N | -| `no_diff` | Don't show diffs for changed models (Default: False) | boolean | N | -| 
`no_prompts` | Disables interactive prompts in CLI (Default: False) | boolean | N | +| Option | Description | Type | Required | +|---------------------------|---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:--------------------:|:--------:| +| `auto_categorize_changes` | Indicates whether SQLMesh should attempt to automatically [categorize](../concepts/plans.md#change-categories) model changes during plan creation per each model source type ([additional details](../guides/configuration.md#auto-categorize-changes)) | dict[string, string] | N | +| `include_unmodified` | Indicates whether to create views for all models in the target development environment or only for modified ones (Default: False) | boolean | N | +| `auto_apply` | Indicates whether to automatically apply a new plan after creation (Default: False) | boolean | N | +| `forward_only` | Indicates whether the plan should be [forward-only](../concepts/plans.md#forward-only-plans) (Default: False) | boolean | N | +| `enable_preview` | Indicates whether to enable [data preview](../concepts/plans.md#data-preview) for forward-only models when targeting a development environment (Default: True, except for dbt projects where the target engine does not support cloning) | Boolean | N | +| `no_diff` | Don't show diffs for changed models (Default: False) | boolean | N | +| `no_prompts` | Disables interactive prompts in CLI (Default: True) | boolean | N | +| `always_recreate_environment` | Always recreates the target environment from the environment specified in `create_from` (by default `prod`) (Default: False) | boolean | N | ## Run @@ -92,6 +118,18 @@ Formatting settings for the `sqlmesh format` command and UI. 
| `leading_comma` | Whether to use leading commas (Default: False) | boolean | N | | `max_text_width` | The maximum text width in a segment before creating new lines (Default: 80) | int | N | | `append_newline` | Whether to append a newline to the end of the file (Default: False) | boolean | N | +| `no_rewrite_casts` | Preserve the existing casts, without rewriting them to use the :: syntax. (Default: False) | boolean | N | + + +## Janitor + +Configuration for the `sqlmesh janitor` command. + +| Option | Description | Type | Required | +|---------------------------------|----------------------------------------------------------------------------------------------------------------------------|:-------:|:--------:| +| `warn_on_delete_failure` | Whether to warn instead of erroring if the janitor fails to delete the expired environment schema / views (Default: False) | boolean | N | +| `expired_snapshots_batch_size` | Maximum number of expired snapshots to clean in a single batch (Default: 200) | int | N | + ## UI @@ -105,7 +143,7 @@ SQLMesh UI settings. The `gateways` dictionary defines how SQLMesh should connect to the data warehouse, state backend, test backend, and scheduler. -It takes one or more named `gateway` configuration keys, each of which can define its own connections. A named gateway does not need to specify all four components and will use defaults if any are omitted - more information is provided about [gateway defaults](#gatewayconnection-defaults) below. +It takes one or more named `gateway` configuration keys, each of which can define its own connections. **Gateway names are case-insensitive** - SQLMesh normalizes all gateway names to lowercase during configuration validation, allowing you to use any case when referencing gateways. A named gateway does not need to specify all four components and will use defaults if any are omitted - more information is provided about [gateway defaults](#gatewayconnection-defaults) below. 
For example, a project might configure the `gate1` and `gate2` gateways: @@ -114,7 +152,7 @@ gateways: gate1: connection: ... - state_connection: # defaults to `connection` if omitted and not using airflow or google cloud composer scheduler + state_connection: # defaults to `connection` if omitted ... test_connection: # defaults to `connection` if omitted ... @@ -138,7 +176,7 @@ A named gateway key may define any or all of a data warehouse connection, state Some connections use default values if not specified: - The `connection` key may be omitted if a [`default_connection`](#default-connectionsscheduler) is specified. -- The state connection defaults to `connection` unless the configuration uses an Airflow or Google Cloud Composer scheduler. If using one of those schedulers, the state connection defaults to the scheduler's database. +- The state connection defaults to `connection` if omitted. - The test connection defaults to `connection` if omitted. NOTE: Spark and Trino engines may not be used for the state connection. @@ -162,16 +200,19 @@ Most parameters are specific to the connection engine `type` - see [below](#engi | Option | Description | Type | Required | |---------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----:|:--------:| -| `type` | The engine type name, listed in engine-specific configuration pages below. | str | Y | -| `concurrent_tasks` | The maximum number of concurrent tasks that will be run by SQLMesh. (Default: 4 for engines that support concurrent tasks.) | int | N | -| `register_comments` | Whether SQLMesh should register model comments with the SQL engine (if the engine supports it). (Default: `true`.) | bool | N | -| `pre_ping` | Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. This can only be enabled for engines with transaction support. 
| bool | N | +| `type` | The engine type name, listed in engine-specific configuration pages below. | str | Y | +| `concurrent_tasks` | The maximum number of concurrent tasks that will be run by SQLMesh. (Default: 4 for engines that support concurrent tasks.) | int | N | +| `register_comments` | Whether SQLMesh should register model comments with the SQL engine (if the engine supports it). (Default: `true`.) | bool | N | +| `pre_ping` | Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. This can only be enabled for engines with transaction support. | bool | N | +| `pretty_sql` | If SQL should be formatted before being executed, not recommended in a production setting. (Default: `false`.) | bool | N | #### Engine-specific These pages describe the connection configuration options for each execution engine. +* [Athena](../integrations/engines/athena.md) * [BigQuery](../integrations/engines/bigquery.md) +* [ClickHouse](../integrations/engines/clickhouse.md) * [Databricks](../integrations/engines/databricks.md) * [DuckDB](../integrations/engines/duckdb.md) * [MotherDuck](../integrations/engines/motherduck.md) @@ -188,7 +229,7 @@ These pages describe the connection configuration options for each execution eng Identifies which scheduler backend to use. The scheduler backend is used both for storing metadata and for executing [plans](../concepts/plans.md). -By default, the scheduler type is set to `builtin` and uses the gateway's connection to store metadata. Use the `airflow` type to integrate with Airflow. +By default, the scheduler type is set to `builtin` and uses the gateway's connection to store metadata. Below is the list of configuration options specific to each corresponding scheduler type. Find additional details in the [configuration overview scheduler section](../guides/configuration.md#scheduler). 
@@ -198,33 +239,6 @@ Below is the list of configuration options specific to each corresponding schedu No configuration options are supported by this scheduler type. -#### Airflow - -**Type:** `airflow` - -See [Airflow Integration Guide](../integrations/airflow.md) for information about how to integrate Airflow with SQLMesh. - -| Option | Description | Type | Required | -| --------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :-----: | :------: | -| `airflow_url` | The URL of the Airflow Webserver | string | Y | -| `username` | The Airflow username | string | Y | -| `password` | The Airflow password | string | Y | -| `dag_run_poll_interval_secs` | Determines, in seconds, how often a running DAG can be polled (Default: `10`) | int | N | -| `dag_creation_poll_interval_secs` | Determines, in seconds, how often SQLMesh should check whether a DAG has been created (Default: `30`) | int | N | -| `dag_creation_max_retry_attempts` | Determines the maximum number of attempts that SQLMesh will make while checking for whether a DAG has been created (Default: `10`) | int | N | -| `backfill_concurrent_tasks` | The number of concurrent tasks used for model backfilling during plan application (Default: `4`) | int | N | -| `ddl_concurrent_tasks` | The number of concurrent tasks used for DDL operations like table/view creation, deletion, and so forth (Default: `4`) | int | N | -| `max_snapshot_ids_per_request` | The maximum number of snapshot IDs that can be sent in a single HTTP GET request to the Airflow Webserver (Default: `None`) | int | N | -| `use_state_connection` | Whether to use the `state_connection` configuration to bypass Airflow Webserver and access the SQLMesh state directly (Default: 
`false`) | boolean | N | -| `default_catalog_override` | Overrides the default catalog value for this project. If specified, this value takes precedence over the default catalog value set on the Airflow side. This only applies in the [multi-repo](../guides/multi_repo.md) setup when different projects require different default catalog values (Default: `None`) | string | N | - - -#### Cloud Composer - -**Type:** `cloud_composer` - -The Google Cloud Composer scheduler type shares the same configuration options as the `airflow` type, except for `username` and `password`. Cloud Composer relies on `gcloud` authentication, so the `username` and `password` options are not required. - ## Gateway/connection defaults The default gateway and connection keys specify what should happen when gateways or connections are not explicitly specified. Find additional details in the configuration overview page [gateway/connection defaults section](../guides/configuration.md#gatewayconnection-defaults). @@ -235,7 +249,7 @@ If a configuration contains multiple gateways, SQLMesh will use the first one in | Option | Description | Type | Required | | ----------------- | ---------------------------------------------------------------------------------------------------------------------------- | :----: | :------: | -| `default_gateway` | The name of a gateway to use if one is not provided explicitly (Default: the gateway defined first in the `gateways` option) | string | N | +| `default_gateway` | The name of a gateway to use if one is not provided explicitly (Default: the gateway defined first in the `gateways` option). Gateway names are case-insensitive. 
| string | N | ### Default connections/scheduler @@ -246,12 +260,15 @@ For example, you might have a specific connection where your tests should run re | Option | Description | Type | Required | | ------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------ | :---------: | :------: | | `default_connection` | The default connection to use if one is not specified in a gateway (Default: A DuckDB connection that creates an in-memory database) | connection | N | -| `default_test_connection` | The default connection to use when running tests if one is not specified in a gateway (Default: A DuckDB connection that creates an in-memory database | connection) | N | +| `default_test_connection` | The default connection to use when running tests if one is not specified in a gateway (Default: A DuckDB connection that creates an in-memory database) | connection | N | | `default_scheduler` | The default scheduler configuration to use if one is not specified in a gateway (Default: built-in scheduler) | scheduler | N | ## Debug mode -To enable debug mode set the `SQLMESH_DEBUG` environment variable to one of the following values: "1", "true", "t", "yes" or "y". +Enable debug mode in one of two ways: + +- Pass the `--debug` flag between the CLI command and the subcommand. For example, `sqlmesh --debug plan`. +- Set the `SQLMESH_DEBUG` environment variable to one of the following values: "1", "true", "t", "yes" or "y". Enabling this mode ensures that full backtraces are printed when using CLI. The default log level is set to `DEBUG` when this mode is enabled. 
@@ -259,12 +276,20 @@ Example enabling debug mode for the CLI command `sqlmesh plan`:
 
 === "Bash"
 
+    ```bash
+    $ sqlmesh --debug plan
+    ```
+
     ```bash
     $ SQLMESH_DEBUG=1 sqlmesh plan
     ```
 
 === "MS Powershell"
 
+    ```powershell
+    PS> sqlmesh --debug plan
+    ```
+
     ```powershell
     PS> $env:SQLMESH_DEBUG=1
     PS> sqlmesh plan
@@ -272,11 +297,32 @@ Example enabling debug mode for the CLI command `sqlmesh plan`:
 
 === "MS CMD"
 
+    ```cmd
+    C:\> sqlmesh --debug plan
+    ```
+
     ```cmd
     C:\> set SQLMESH_DEBUG=1
     C:\> sqlmesh plan
     ```
 
+## Runtime Environment
+
+SQLMesh can run in different runtime environments. For example, you might run it in a regular command-line terminal, in a Jupyter notebook, or in GitHub's CI/CD platform.
+
+When it starts up, SQLMesh automatically detects the runtime environment and adjusts its behavior accordingly. For example, it registers `%magic` commands if in a Jupyter notebook and adjusts logging behavior if in a CI/CD environment.
+
+If necessary, you may force SQLMesh to use a specific runtime environment by setting the `SQLMESH_RUNTIME_ENVIRONMENT` environment variable.
+
+It accepts the following values, which will cause SQLMesh to behave as if it were in the runtime environment in parentheses:
+
+- `terminal` (CLI console)
+- `databricks` (Databricks notebook)
+- `google_colab` (Google Colab notebook)
+- `jupyter` (Jupyter notebook)
+- `debugger` (Debugging output)
+- `ci` (CI/CD or other non-interactive environment)
+
 ## Anonymized usage information
 
 We strive to make SQLMesh the best data transformation tool on the market. Part of accomplishing that is continually fixing bugs, adding features, and improving SQLMesh's performance. 
@@ -291,3 +337,8 @@ You can disable collection of anonymized usage information with these methods: - Set the root `disable_anonymized_analytics: true` key in your SQLMesh project configuration file - Execute SQLMesh commands with an environment variable `SQLMESH__DISABLE_ANONYMIZED_ANALYTICS` set to `1`, `true`, `t`, `yes`, or `y` + +## Parallel loading +SQLMesh by default uses all of your cores when loading models and snapshots. It takes advantage of `fork` which is not available on Windows. The default is to use the same number of workers as cores on your machine if fork is available. + +You can override this setting by setting the environment variable `MAX_FORK_WORKERS`. A value of 1 will disable forking and load things sequentially. diff --git a/docs/reference/model_configuration.md b/docs/reference/model_configuration.md index 48f9e6246e..9d040fe6db 100644 --- a/docs/reference/model_configuration.md +++ b/docs/reference/model_configuration.md @@ -8,37 +8,171 @@ Learn more about specifying SQLMesh model properties in the [model concepts over Configuration options for SQLMesh model properties. Supported by all model kinds other than [`SEED` models](#seed-models). -| Option | Description | Type | Required | +| Option | Description | Type | Required | |-----------------------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:-----------------:|:--------:| -| `name` | The model name. Must include at least a qualifying schema (`.`) and may include a catalog (`..`). Can be omitted if [infer_names](#model-naming) is set to true. | str | N | -| `kind` | The model kind ([Additional Details](#model-kind-properties)). 
(Default: `VIEW`) | str \| dict | N | -| `audits` | SQLMesh [audits](../concepts/audits.md) that should run against the model's output | array[str] | N | -| `dialect` | The SQL dialect in which the model's query is written. All SQL dialects [supported by the SQLGlot library](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/dialect.py) are allowed. | str | N | -| `owner` | The owner of a model; may be used for notification purposes | str | N | -| `stamp` | Arbitrary string used to indicate a model's version without changing the model name | str | N | -| `tags` | Arbitrary strings used to organize or classify a model | array[str] | N | -| `cron` | The cron expression specifying how often the model should be refreshed. (Default: `@daily`) | str | N | -| `interval_unit` | The temporal granularity of the model's data intervals. Supported values: `year`, `month`, `day`, `hour`, `half_hour`, `quarter_hour`, `five_minute`. (Default: inferred from `cron`) | str | N | -| `start` | The date/time that determines the earliest date interval that should be processed by a model. Can be a datetime string, epoch time in milliseconds, or a relative datetime such as `1 year ago`. (Default: `yesterday`) | str \| int | N | -| `end` | The date/time that determines the latest date interval that should be processed by a model. Can be a datetime string, epoch time in milliseconds, or a relative datetime such as `1 year ago`. | str \| int | N | -| `grains` | The column(s) whose combination uniquely identifies each row in the model | str \| array[str] | N | -| `references` | The model column(s) used to join to other models' grains | str \| array[str] | N | -| `depends_on` | Models on which this model depends in addition to the ones inferred from the model's query. 
(Default: dependencies inferred from model code) | array[str] | N | -| `storage_format` | The storage format that should be used to store physical tables; only applicable to engines such as Spark | str | N | -| `partitioned_by` | The column(s) and/or column expressions used define a model's partitioning key. Required for the `INCREMENTAL_BY_PARTITION` model kind. Optional for all other model kinds; used to partition the model's physical table in engines that support partitioning. | str \| array[str] | N | -| `clustered_by` | The column(s) used to cluster the model's physical table; only applicable to engines that support clustering | str | N | -| `columns` | The column names and data types returned by the model. Disables [automatic inference of column names and types](../concepts/models/overview.md#conventions) from the SQL query. | array[str] | N | -| `physical_properties` | A key-value mapping of arbitrary properties specific to the target engine that are applied to the model table / view in the physical layer. Specified as key-value pairs (`key = value`). | dict | N | -| `virtual_properties` | A key-value mapping of arbitrary properties specific to the target engine that are applied to the model view in the virtual layer. Specified as key-value pairs (`key = value`). | dict | N | -| `allow_partials` | Whether this model can process partial (incomplete) data intervals | bool | N | -| `description` | Description of the model. Automatically registered in the SQL engine's table COMMENT field or equivalent (if supported by the engine). | str | N | -| `column_descriptions` | A key-value mapping of column names to column comments that will be registered in the SQL engine's table COMMENT field (if supported by the engine). Specified as key-value pairs (`column_name = 'column comment'`). If present, [inline column comments](../concepts/models/overview.md#inline-column-comments) will not be registered in the SQL engine. 
| dict | N | -| `enabled` | Whether the model is enabled. This attribute is `true` by default. Setting it to `false` causes SQLMesh to ignore this model when loading the project. | bool | N | +| `name` | The model name. Must include at least a qualifying schema (`.`) and may include a catalog (`..`). Can be omitted if [infer_names](#model-naming) is set to true. | str | N | +| `project` | The name of the project the model belongs to - used in multi-repo deployments | str | N | +| `kind` | The model kind ([Additional Details](#model-kind-properties)). (Default: `VIEW`) | str \| dict | N | +| `audits` | SQLMesh [audits](../concepts/audits.md) that should run against the model's output | array[str] | N | +| `dialect` | The SQL dialect in which the model's query is written. All SQL dialects [supported by the SQLGlot library](https://github.com/tobymao/sqlglot/blob/main/sqlglot/dialects/dialect.py) are allowed. | str | N | +| `owner` | The owner of a model; may be used for notification purposes | str | N | +| `stamp` | Arbitrary string used to indicate a model's version without changing the model name | str | N | +| `tags` | Arbitrary strings used to organize or classify a model | array[str] | N | +| `cron` | The cron expression specifying how often the model should be refreshed. (Default: `@daily`) | str | N | +| `interval_unit` | The temporal granularity of the model's data intervals. Supported values: `year`, `month`, `day`, `hour`, `half_hour`, `quarter_hour`, `five_minute`. (Default: inferred from `cron`) | str | N | +| `start` | The date/time that determines the earliest date interval that should be processed by a model. Can be a datetime string, epoch time in milliseconds, or a relative datetime such as `1 year ago`. (Default: `yesterday`) | str \| int | N | +| `end` | The date/time that determines the latest date interval that should be processed by a model. Can be a datetime string, epoch time in milliseconds, or a relative datetime such as `1 year ago`. 
| str \| int | N | +| `description` | Description of the model. Automatically registered in the SQL engine's table COMMENT field or equivalent (if supported by the engine). | str | N | +| `column_descriptions` | A key-value mapping of column names to column comments that will be registered in the SQL engine's table COMMENT field (if supported by the engine). Specified as key-value pairs (`column_name = 'column comment'`). If present, [inline column comments](../concepts/models/overview.md#inline-column-comments) will not be registered in the SQL engine. | dict | N | +| `grains` | The column(s) whose combination uniquely identifies each row in the model | str \| array[str] | N | +| `references` | The model column(s) used to join to other models' grains | str \| array[str] | N | +| `depends_on` | Models on which this model depends, in addition to the ones inferred from the model's query. (Default: dependencies inferred from model code) | array[str] | N | +| `table_format` | The table format that should be used to manage the physical files (eg `iceberg`, `hive`, `delta`); only applicable to engines such as Spark and Athena | str | N | +| `storage_format` | The storage format that should be used to store physical files (eg `parquet`, `orc`); only applicable to engines such as Spark and Athena | str | N | +| `partitioned_by` | The column(s) and/or column expressions used define a model's partitioning key. Required for the `INCREMENTAL_BY_PARTITION` model kind. Optional for all other model kinds; used to partition the model's physical table in engines that support partitioning. | str \| array[str] | N | +| `clustered_by` | The column(s) and/or column expressions used to cluster the model's physical table; only applicable to engines that support clustering | str | N | +| `columns` | The column names and data types returned by the model. Disables [automatic inference of column names and types](../concepts/models/overview.md#conventions) from the SQL query. 
| array[str] | N | +| `physical_properties` | A key-value mapping of arbitrary properties specific to the target engine that are applied to the model table / view in the physical layer. Specified as key-value pairs (`key = value`). The view/table type (e.g. `TEMPORARY`, `TRANSIENT`) can be added with the `creatable_type` key. | dict | N | +| `virtual_properties` | A key-value mapping of arbitrary properties specific to the target engine that are applied to the model view in the virtual layer. Specified as key-value pairs (`key = value`). The view type (e.g. `SECURE`) can be added with the `creatable_type` key. | dict | N | +| `session_properties` | A key-value mapping of arbitrary properties specific to the target engine that are applied to the engine session. Specified as key-value pairs (`key = value`). | dict | N | +| `allow_partials` | Whether this model can process partial (incomplete) data intervals | bool | N | +| `enabled` | Whether the model is enabled. This attribute is `true` by default. Setting it to `false` causes SQLMesh to ignore this model when loading the project. | bool | N | +| `gateway` | Specifies the gateway to use for the execution of this model. When not specified, the default gateway is used. | str | N | +| `optimize_query` | Whether the model's query should be optimized. This attribute is `true` by default. Setting it to `false` causes SQLMesh to disable query canonicalization & simplification. This should be turned off only if the optimized query leads to errors such as surpassing text limit. | bool | N | +| `ignored_rules` | A list of linter rule names (or "ALL") to be ignored/excluded for this model | str \| array[str] | N | +| `formatting` | Whether the model will be formatted. All models are formatted by default. Setting this to `false` causes SQLMesh to ignore this model during `sqlmesh format`. 
| bool | N | ### Model defaults The SQLMesh project-level configuration must contain the `model_defaults` key and must specify a value for its `dialect` key. Other values are set automatically unless explicitly overridden in the model definition. Learn more about project-level configuration in the [configuration guide](../guides/configuration.md). +In `physical_properties`, `virtual_properties`, and `session_properties`, when both project-level and model-specific properties are defined, they are merged, with model-level properties taking precedence. To unset a project-wide property for a specific model, set it to `None` in the `MODEL`'s DDL properties or within the `@model` decorator for Python models. + +For example, with the following `model_defaults` configuration: + +=== "YAML" + + ```yaml linenums="1" + model_defaults: + dialect: snowflake + start: 2022-01-01 + physical_properties: + partition_expiration_days: 7 + require_partition_filter: True + project_level_property: "value" + ``` + +=== "Python" + + ```python linenums="1" + from sqlmesh.core.config import Config, ModelDefaultsConfig + + config = Config( + model_defaults=ModelDefaultsConfig( + dialect="snowflake", + start="2022-01-01", + physical_properties={ + "partition_expiration_days": 7, + "require_partition_filter": True, + "project_level_property": "value" + }, + ), + ) + ``` + +To override `partition_expiration_days`, add a new `creatable_type` property and unset `project_level_property`, you can define the model as follows: + +=== "SQL" + + ```sql linenums="1" + MODEL ( + ..., + physical_properties ( + partition_expiration_days = 14, + creatable_type = TRANSIENT, + project_level_property = None, + ) + ); + ``` + +=== "Python" + + ```python linenums="1" + @model( + ..., + physical_properties={ + "partition_expiration_days": 14, + "creatable_type": "TRANSIENT", + "project_level_property": None + }, + ) + ``` + +You can also use the `@model_kind_name` variable to fine-tune control over 
`physical_properties` in `model_defaults`. This holds the current model's kind name and is useful for conditionally assigning a property. For example, to disable `creatable_type` for your project's `VIEW` kind models:
+
+=== "YAML"
+
+    ```yaml linenums="1"
+    model_defaults:
+      dialect: snowflake
+      start: 2022-01-01
+      physical_properties:
+        creatable_type: "@IF(@model_kind_name != 'VIEW', 'TRANSIENT', NULL)"
+    ```
+
+=== "Python"
+
+    ```python linenums="1"
+    from sqlmesh.core.config import Config, ModelDefaultsConfig
+
+    config = Config(
+        model_defaults=ModelDefaultsConfig(
+            dialect="snowflake",
+            start="2022-01-01",
+            physical_properties={
+                "creatable_type": "@IF(@model_kind_name != 'VIEW', 'TRANSIENT', NULL)",
+            },
+        ),
+    )
+    ```
+
+You can also define `pre_statements`, `post_statements` and `on_virtual_update` statements at the project level that will be applied to all models. These default statements are merged with any model-specific statements, with default statements executing first, followed by model-specific statements. 
+ +=== "YAML" + + ```yaml linenums="1" + model_defaults: + dialect: duckdb + pre_statements: + - "SET timeout = 300000" + post_statements: + - "@IF(@runtime_stage = 'evaluating', ANALYZE @this_model)" + on_virtual_update: + - "GRANT SELECT ON @this_model TO ROLE analyst_role" + ``` + +=== "Python" + + ```python linenums="1" + from sqlmesh.core.config import Config, ModelDefaultsConfig + + config = Config( + model_defaults=ModelDefaultsConfig( + dialect="duckdb", + pre_statements=[ + "SET query_timeout = 300000", + ], + post_statements=[ + "@IF(@runtime_stage = 'evaluating', ANALYZE @this_model)", + ], + on_virtual_update=[ + "GRANT SELECT ON @this_model TO ROLE analyst_role", + ], + ), + ) + ``` + + The SQLMesh project-level `model_defaults` key supports the following options, described in the [general model properties](#general-model-properties) table above: - kind @@ -46,18 +180,30 @@ The SQLMesh project-level `model_defaults` key supports the following options, d - cron - owner - start +- table_format - storage_format +- physical_properties +- virtual_properties - session_properties (on per key basis) - on_destructive_change (described [below](#incremental-models)) +- on_additive_change (described [below](#incremental-models)) +- audits (described [here](../concepts/audits.md#generic-audits)) +- optimize_query +- allow_partials +- enabled +- interval_unit +- pre_statements (described [here](../concepts/models/sql_models.md#pre--and-post-statements)) +- post_statements (described [here](../concepts/models/sql_models.md#pre--and-post-statements)) +- on_virtual_update (described [here](../concepts/models/sql_models.md#on-virtual-update-statements)) ### Model Naming Configuration option for name inference. Learn more in the [model naming guide](../guides/configuration.md#model-naming). 
-| Option | Description | Type | Required | -| --------------- | --------------------------------------------------------------------------------------- | :-----: | :------: | -| `infer_names` | Whether to automatically infer model names based on the directory structure (Default: `False`) | bool | N | +| Option | Description | Type | Required | +|---------------|------------------------------------------------------------------------------------------------|:----:|:--------:| +| `infer_names` | Whether to automatically infer model names based on the directory structure (Default: `False`) | bool | N | ## Model kind properties @@ -71,7 +217,7 @@ Learn more about model kinds at the [model kind concepts page](../concepts/model Configuration options for models of the [`VIEW` kind](../concepts/models/model_kinds.md#view) (in addition to [general model properties](#general-model-properties)). | Option | Description | Type | Required | -| -------------- | ---------------------------------------------------------------------------------------------------- | :--: | :------: | +|----------------|------------------------------------------------------------------------------------------------------|:----:|:--------:| | `materialized` | Whether views should be materialized (for engines supporting materialized views). (Default: `False`) | bool | N | Python model kind `name` enum value: [ModelKindName.VIEW](https://sqlmesh.readthedocs.io/en/stable/_readthedocs/html/sqlmesh/core/model/kind.html#ModelKindName) @@ -86,34 +232,36 @@ Python model kind `name` enum value: [ModelKindName.FULL](https://sqlmesh.readth Configuration options for all incremental models (in addition to [general model properties](#general-model-properties)). 
-| Option | Description | Type | Required | -|-------------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----:|:--------:| -| `on_destructive_change` | What should happen when a change to a [forward-only model](../guides/incremental_time.md#forward-only-models) or incremental model in a [forward-only plan](../concepts/plans.md#forward-only-plans) causes a destructive modification to the model schema. Valid values: `allow`, `warn`, `error`. (Default: `error`) | str | N | -| `forward_only` | Whether the model's changes should always be classified as [forward-only](../concepts/plans.md#forward-only-change). (Default: `False`) | bool | N | -| `disable_restatement` | Whether [restatements](../concepts/plans.md#restatement-plans) should be disabled for the model. (Default: `False`) | bool | N | +| Option | Description | Type | Required | +|-------------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----:|:--------:| +| `forward_only` | Whether the model's changes should always be classified as [forward-only](../concepts/plans.md#forward-only-change). (Default: `False`) | bool | N | +| `on_destructive_change` | What should happen when a change to a [forward-only model](../guides/incremental_time.md#forward-only-models) or incremental model in a [forward-only plan](../concepts/plans.md#forward-only-plans) causes a destructive modification to the model schema. 
Valid values: `allow`, `warn`, `error`, `ignore`. (Default: `error`) | str | N | +| `on_additive_change` | What should happen when a change to a [forward-only model](../guides/incremental_time.md#forward-only-models) or incremental model in a [forward-only plan](../concepts/plans.md#forward-only-plans) causes an additive modification to the model schema (like adding new columns). Valid values: `allow`, `warn`, `error`, `ignore`. (Default: `allow`) | str | N | +| `disable_restatement` | Whether [restatements](../concepts/plans.md#restatement-plans) should be disabled for the model. (Default: `False`) | bool | N | #### Incremental by time range Configuration options for [`INCREMENTAL_BY_TIME_RANGE` models](../concepts/models/model_kinds.md#incremental_by_time_range) (in addition to [general model properties](#general-model-properties) and [incremental model properties](#incremental-models)). | Option | Description | Type | Required | -|---------------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|:----:|:--------:| -| `time_column` | The model column containing each row's timestamp. | str | Y | -| `format` | Argument to `time_column`. Format of the time column's data. (Default: `%Y-%m-%d`) | str | N | -| `batch_size` | The maximum number of intervals that can be evaluated in a single backfill task. If this is `None`, all intervals will be processed as part of a single task. If this is set, a model's backfill will be chunked such that each individual task only contains jobs with the maximum of `batch_size` intervals. (Default: `None`) | int | N | -| `batch_concurrency` | The maximum number of batches that can run concurrently for this model. 
(Default: the number of concurrent tasks set in the connection settings) | int | N | -| `lookback` | The number of time unit intervals prior to the current interval that should be processed. (Default: `0`) | int | N | +| ------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :--: | :------: | +| `time_column` | The model column containing each row's timestamp. Should be UTC time zone. | str | Y | +| `format` | Argument to `time_column`. Format of the time column's data. (Default: `%Y-%m-%d`) | str | N | +| `batch_size` | The maximum number of intervals that can be evaluated in a single backfill task. If this is `None`, all intervals will be processed as part of a single task. If this is set, a model's backfill will be chunked such that each individual task only contains jobs with the maximum of `batch_size` intervals. (Default: `None`) | int | N | +| `batch_concurrency` | The maximum number of batches that can run concurrently for this model. (Default: the number of concurrent tasks set in the connection settings) | int | N | +| `lookback` | The number of `interval_unit`s prior to the current interval that should be processed - [learn more](../concepts/models/overview.md#lookback). (Default: `0`) | int | N | Python model kind `name` enum value: [ModelKindName.INCREMENTAL_BY_TIME_RANGE](https://sqlmesh.readthedocs.io/en/stable/_readthedocs/html/sqlmesh/core/model/kind.html#ModelKindName) #### Incremental by unique key -Configuration options for [`INCREMENTAL_BY_UNIQUE_KEY` models](../concepts/models/model_kinds.md#incremental_by_unique_key) (in addition to [general model properties](#general-model-properties) and [incremental model properties](#incremental-models)). 
+Configuration options for [`INCREMENTAL_BY_UNIQUE_KEY` models](../concepts/models/model_kinds.md#incremental_by_unique_key) (in addition to [general model properties](#general-model-properties) and [incremental model properties](#incremental-models)). Batch concurrency cannot be set for incremental by unique key models because they cannot safely be run in parallel. | Option | Description | Type | Required | |----------------|----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-------------------|----------| | `unique_key` | The model column(s) containing each row's unique key | str \| array[str] | Y | | `when_matched` | SQL logic used to update columns when a match occurs - only available on engines that support `MERGE`. (Default: update all columns) | str | N | +| `merge_filter` | A single or a conjunction of predicates used to filter data in the ON clause of a MERGE operation - only available on engines that support `MERGE` | str | N | | `batch_size` | The maximum number of intervals that can be evaluated in a single backfill task. If this is `None`, all intervals will be processed as part of a single task. If this is set, a model's backfill will be chunked such that each individual task only contains jobs with the maximum of `batch_size` intervals. (Default: `None`) | int | N | | `lookback` | The number of time unit intervals prior to the current interval that should be processed. (Default: `0`) | int | N | @@ -134,7 +282,7 @@ Configuration options for [`SCD_TYPE_2` models](../concepts/models/model_kinds.m | `unique_key` | The model column(s) containing each row's unique key | array[str] | Y | | `valid_from_name` | The model column containing each row's valid from date. 
(Default: `valid_from`) | str | N | | `valid_to_name` | The model column containing each row's valid to date. (Default: `valid_to`) | str | N | -| `invalidate_hard_deletes` | If set to true, when a record is missing from the source table it will be marked as invalid - see [here](../concepts/models/model_kinds.md#deletes) for more information. (Default: `True`) | bool | N | +| `invalidate_hard_deletes` | If set to true, when a record is missing from the source table it will be marked as invalid - see [here](../concepts/models/model_kinds.md#deletes) for more information. (Default: `False`) | bool | N | ##### SCD Type 2 By Time @@ -195,5 +343,8 @@ Options specified within the `kind` property's `csv_settings` property (override | `skipinitialspace` | Skip spaces after delimiter. More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | bool | N | | `lineterminator` | Character used to denote a line break. More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | str | N | | `encoding` | Encoding to use for UTF when reading/writing (ex. 'utf-8'). More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | str | N | +| `na_values` | An array of values that should be recognized as NA/NaN. In order to specify such an array per column, a mapping in the form of `(col1 = (v1, v2, ...), col2 = ...)` can be passed instead. These values can be integers, strings, booleans or NULL, and they are converted to their corresponding Python values. More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | array[value] \| array[array[key = value]] | N | +| `keep_default_na` | Whether or not to include the default NaN values when parsing the data. 
More information at the [Pandas documentation](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.read_csv.html). | bool | N | + Python model kind `name` enum value: [ModelKindName.SEED](https://sqlmesh.readthedocs.io/en/stable/_readthedocs/html/sqlmesh/core/model/kind.html#ModelKindName) diff --git a/docs/reference/notebook.md b/docs/reference/notebook.md index 5b34da79e9..6cac4e1078 100644 --- a/docs/reference/notebook.md +++ b/docs/reference/notebook.md @@ -24,14 +24,14 @@ import sqlmesh %context path_to_sqlmesh_project ``` -### Quick start project +### Quickstart project -If desired, you can create the [quick start example project](../quick_start.md) with the Python `init_example_project` function. The function requires a default SQL dialect for the project's models; this example uses `snowflake`: +If desired, you can create the [quickstart example project](../quick_start.md) with the Python `init_example_project` function. The function requires a default SQL dialect for the project's models; this example uses `snowflake`: ```python -from sqlmesh.cli.example_project import init_example_project +from sqlmesh.cli.project_init import init_example_project -init_example_project("path_to_project_directory", dialect="snowflake") +init_example_project("path_to_project_directory", engine_type="snowflake") ``` Alternatively, create the project with a notebook magic: @@ -70,7 +70,7 @@ options: #### init ``` -%init [--template TEMPLATE] path sql_dialect +%init [--template TEMPLATE] [--dlt-pipeline PIPELINE] path sql_dialect Creates a SQLMesh project scaffold with a default SQL dialect. @@ -86,8 +86,11 @@ positional arguments: options: --template TEMPLATE, -t TEMPLATE - Project template. Supported values: airflow, dbt, - default, empty. + Project template. Supported values: dbt, + dlt, default, empty. + --dlt-pipeline PIPELINE + DLT pipeline for which to generate a SQLMesh project. + This option is supported if the template is dlt. 
``` #### plan @@ -95,12 +98,12 @@ options: %plan [--start START] [--end END] [--execution-time EXECUTION_TIME] [--create-from CREATE_FROM] [--skip-tests] [--restate-model [RESTATE_MODEL ...]] [--no-gaps] - [--skip-backfill] [--forward-only] + [--skip-backfill, --dry-run] [--forward-only] [--effective-from EFFECTIVE_FROM] [--no-prompts] [--auto-apply] [--no-auto-categorization] [--include-unmodified] [--select-model [SELECT_MODEL ...]] [--backfill-model [BACKFILL_MODEL ...]] [--no-diff] [--run] - [environment] + [environment] [--diff-rendered] Goes through a set of prompts to both establish a plan and apply it @@ -127,7 +130,8 @@ options: --no-gaps, -g Ensure that new snapshots have no data gaps when comparing to existing snapshots for matching models in the target environment. - --skip-backfill Skip the backfill step. + --skip-backfill, --dry-run + Skip the backfill step and only create a virtual update for the plan. --forward-only Create a plan for forward-only changes. --effective-from EFFECTIVE_FROM The effective date from which to apply forward-only @@ -149,6 +153,8 @@ options: --no-diff Hide text differences for changed models. --run Run latest intervals as part of the plan application (prod environment only). + --diff-rendered Output text differences for the rendered versions of models and standalone audits + ``` #### run_dag @@ -195,6 +201,9 @@ options: ``` %render [--start START] [--end END] [--execution-time EXECUTION_TIME] [--expand EXPAND] [--dialect DIALECT] [--no-format] + [--normalize] [--pad PAD] [--indent INDENT] + [--normalize-functions NORMALIZE_FUNCTIONS] [--leading-comma] + [--max-text-width MAX_TEXT_WIDTH] model Renders a model's query, optionally expanding referenced models. @@ -214,6 +223,17 @@ options: models are expanded as raw queries. --dialect DIALECT SQL dialect to render. --no-format Disable fancy formatting of the query. + --normalize Whether or not to normalize identifiers to lowercase. 
+ --pad PAD Determines the pad size in a formatted string. + --indent INDENT Determines the indentation size in a formatted string. + --normalize-functions NORMALIZE_FUNCTIONS + Whether or not to normalize all function names. + Possible values are: 'upper', 'lower' + --leading-comma Determines whether or not the comma is leading or + trailing in select expressions. Default is trailing. + --max-text-width MAX_TEXT_WIDTH + The max number of characters in a segment before + creating new lines in pretty mode. ``` #### dag @@ -226,6 +246,31 @@ options: --file FILE, -f FILE An optional file path to write the HTML output to. ``` +#### destroy +``` +%destroy + +Removes all state tables, the SQLMesh cache, and other project resources, including warehouse objects. This includes all tables, views, and schemas managed by SQLMesh, as well as any external resources that may have been created by other tools within those schemas. +``` + +#### dlt_refresh +``` +%dlt_refresh PIPELINE [--table] TABLE [--force] + +Attaches to a DLT pipeline with the option to update specific or all models of the SQLMesh project. + +options: + --table TABLE, -t TABLE The DLT tables to generate SQLMesh models from. When none specified, all new missing tables will be generated. + --force, -f If set it will overwrite existing models with the new generated models from the DLT tables. +``` + +#### environments +``` +%environments + +Prints the list of SQLMesh environments with its expiry datetime. +``` + #### fetchdf ``` %%fetchdf [df_var] @@ -269,7 +314,8 @@ Create a schema file containing external model schemas. %table_diff [--on [ON ...]] [--skip-columns [SKIP_COLUMNS ...]] [--model MODEL] [--where WHERE] [--limit LIMIT] [--show-sample] [--decimals DECIMALS] [--skip-grain-check] - SOURCE:TARGET + [--warn-grain-check] [--temp-schema SCHEMA] + [--select-model [SELECT_MODEL ...]] SOURCE:TARGET Show the diff between two tables. @@ -294,6 +340,11 @@ options: floating point columns. 
Default: 3 --skip-grain-check Disable the check for a primary key (grain) that is missing or is not unique. + --warn-grain-check Warn if any selected model is missing a grain, + and compute diffs for the remaining models. + --temp-schema SCHEMA The schema to use for temporary tables. + --select-model <[SELECT_MODEL ...]> + Select specific models to diff using a pattern. ``` #### model @@ -349,8 +400,9 @@ options: #### create_test ``` -%create_test --query QUERY [QUERY ...] [--overwrite] +%create_test [--query QUERY [QUERY ...]] [--overwrite] [--var VAR [VAR ...]] [--path PATH] [--name NAME] + [--include-ctes] model Generate a unit test fixture for a given model. @@ -411,6 +463,27 @@ options: Execution time. ``` +#### check_intervals +``` +%check_intervals [--no-signals] [--select-model [SELECT_MODEL ...]] + [--start START] [--end END] + [environment] + +Show missing intervals in an environment, respecting signals. + +positional arguments: + environment The environment to check intervals for. + +options: + --no-signals Disable signal checks and only show missing intervals. + --select-model <[SELECT_MODEL ...]> + Select specific model changes that should be included + in the plan. + --start START, -s START + Start date of intervals to check for. + --end END, -e END End date of intervals to check for. +``` + #### rollback ``` %rollback @@ -437,3 +510,45 @@ options: --read READ The input dialect of the sql string. --write WRITE The output dialect of the sql string. ``` + +#### format +``` +%format [--transpile TRANSPILE] [--append-newline] [--no-rewrite-casts] + [--normalize] [--pad PAD] [--indent INDENT] + [--normalize-functions NORMALIZE_FUNCTIONS] [--leading-comma] + [--max-text-width MAX_TEXT_WIDTH] [--check] + +Format all SQL models and audits. + +options: + --transpile TRANSPILE, -t TRANSPILE + Transpile project models to the specified dialect. + --append-newline Whether or not to append a newline to the end of the + file. 
+ --no-rewrite-casts Preserve the existing casts, without rewriting them + to use the :: syntax. + --normalize Whether or not to normalize identifiers to lowercase. + --pad PAD Determines the pad size in a formatted string. + --indent INDENT Determines the indentation size in a formatted string. + --normalize-functions NORMALIZE_FUNCTIONS + Whether or not to normalize all function names. + Possible values are: 'upper', 'lower' + --leading-comma Determines whether or not the comma is leading or + trailing in select expressions. Default is trailing. + --max-text-width MAX_TEXT_WIDTH + The max number of characters in a segment before + creating new lines in pretty mode. + --check Whether or not to check formatting (but not actually + format anything). +``` + + +#### lint +``` +%lint [--models ...] + +Run the linter on the target models(s) + +positional arguments: + --models A model to lint. Multiple models can be linted. If no models are specified, every model will be linted. +``` diff --git a/docs/reference/overview.md b/docs/reference/overview.md index a10d10ca71..9d1aa64f7f 100644 --- a/docs/reference/overview.md +++ b/docs/reference/overview.md @@ -1,36 +1,3 @@ # Overview -SQLMesh can be used with a [CLI](cli.md), [Notebook](notebook.md), or directly through [Python](python.md). Each interface aims to have parity in both functionality and arguments. The following is a list of available commands. - -## plan -Plan is the main command of SQLMesh. It allows you to interactively create a migration plan, understand the downstream impact, and apply it. All changes to models and environments are materialized through `plan`. - -Read more about [plans](../concepts/plans.md). - -## evaluate -Evaluate a model or snapshot (running its query against a DB/Engine). This command is used to test or iterate on models without side effects. - -## render -Renders a model's SQL query with the provided arguments. - -## fetchdf -Given a SQL query, fetches a pandas dataframe. 
- -## test -Runs all tests. - -Read more about [testing](../concepts/tests.md). - -## audit -Runs all audits. - -Read more about [auditing](../concepts/audits.md). - -## format -Formats all SQL model and audit files in place. - -## diff -Shows the diff between the local model and a model in an environment. - -## dag -Shows the [DAG](../concepts/glossary.md#dag). +SQLMesh can be used with a [CLI](cli.md), [Notebook](notebook.md), or directly through [Python](python.md). Each interface aims to have parity in both functionality and arguments. \ No newline at end of file diff --git a/docs/reference/python.md b/docs/reference/python.md index 14e0da84c8..1c4c9191ff 100644 --- a/docs/reference/python.md +++ b/docs/reference/python.md @@ -4,6 +4,6 @@ SQLMesh is built in Python, and its complete Python API reference is located [he The Python API reference is comprehensive and includes the internal components of SQLMesh. Those components are likely only of interest if you want to modify SQLMesh itself. -If you want to use SQLMesh via its Python API, the best approach is to study how the SQLMesh [CLI](./cli.md) calls it behind the scenes. The CLI implementation code shows exactly which Python methods are called for each CLI command and can be viewed [on Github](https://github.com/TobikoData/sqlmesh/blob/main/sqlmesh/cli/main.py). For example, the Python code executed by the `plan` command is located [here](https://github.com/TobikoData/sqlmesh/blob/15c8788100fa1cfb8b0cc1879ccd1ad21dc3e679/sqlmesh/cli/main.py#L302). +If you want to use SQLMesh via its Python API, the best approach is to study how the SQLMesh [CLI](./cli.md) calls it behind the scenes. The CLI implementation code shows exactly which Python methods are called for each CLI command and can be viewed [on Github](https://github.com/SQLMesh/sqlmesh/blob/main/sqlmesh/cli/main.py). 
For example, the Python code executed by the `plan` command is located [here](https://github.com/SQLMesh/sqlmesh/blob/15c8788100fa1cfb8b0cc1879ccd1ad21dc3e679/sqlmesh/cli/main.py#L302). Almost all the relevant Python methods are in the [SQLMesh `Context` class](https://sqlmesh.readthedocs.io/en/stable/_readthedocs/html/sqlmesh/core/context.html#Context). diff --git a/docs/requirements.txt b/docs/requirements.txt index ba7a51feaf..1035ffe94a 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -3,4 +3,4 @@ mkdocs-include-markdown-plugin==6.0.6 mkdocs-material==9.0.5 mkdocs-material-extensions==1.1.1 mkdocs-glightbox==0.3.7 -pdoc==13.0.1 +pdoc==14.5.1 diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css index 0304b1755b..44daab1f86 100644 --- a/docs/stylesheets/extra.css +++ b/docs/stylesheets/extra.css @@ -3,7 +3,7 @@ border-color: rgb(68, 138, 255); } -.md-typeset .question > .admonition-title, +.md-typeset .question > .admonition-title, .md-typeset .question > summary { background-color: rgba(68, 138, 255, 0.1); @@ -16,6 +16,38 @@ } } +.md-typeset .admonition.tip, +.md-typeset details.tip{ + border-color: rgb(61, 100, 226); +} + +.md-typeset .tip > .admonition-title, +.md-typeset .tip > summary { + background-color: rgba(61, 100, 226, 0.1); + + &::before { + background-color: #3D64E2; + } + + &::after { + color: #3D64E2; + } +} + .md-nav__link { word-break: break-word; } + +:root { + --md-primary-fg-color: #3C64E2; +} + +.md-tabs__item:last-child a::before { + content: ""; + background: transparent url("tobiko-logo.svg") center left no-repeat; + background-size: contain; + width: 18px; + height: 18px; + display: inline-block; + margin-bottom: -4px; +} \ No newline at end of file diff --git a/docs/stylesheets/tobiko-logo.svg b/docs/stylesheets/tobiko-logo.svg new file mode 100644 index 0000000000..3dbc9a6194 --- /dev/null +++ b/docs/stylesheets/tobiko-logo.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git 
a/examples/airflow/Dockerfile.template b/examples/airflow/Dockerfile.template deleted file mode 100644 index 29fe3c6828..0000000000 --- a/examples/airflow/Dockerfile.template +++ /dev/null @@ -1,75 +0,0 @@ -FROM apache/spark:3.5.0-python3 AS spark - -FROM apache/airflow:$AIRFLOW_VERSION-python3.8 - -USER root - -# Fix the airflow user UID -ENV AIRFLOW_UID=$AIRFLOW_UID -RUN usermod -u $AIRFLOW_UID airflow - -# Workaround the expired MySQL GPG key. -RUN rm -f /etc/apt/sources.list.d/mysql.list - -RUN apt-get autoclean -RUN apt-get update - -# Install system packages -RUN apt install -y default-jdk gcc g++ make git - -ENV JAVA_HOME="/usr/lib/jvm/default-java/" - -# Install Spark -COPY --from=spark /opt/spark /opt/spark -RUN chown -R airflow /opt/spark -ENV SPARK_HOME="/opt/spark" -ENV PATH="$PATH:$SPARK_HOME/bin" - -# Install Postgres driver and Iceberg for Spark -RUN curl https://jdbc.postgresql.org/download/postgresql-42.5.0.jar -o /opt/spark/jars/postgresql-42.5.0.jar && \ - curl -L https://search.maven.org/remotecontent?filepath=org/apache/iceberg/iceberg-spark-runtime-3.5_2.12/1.5.1/iceberg-spark-runtime-3.5_2.12-1.5.1.jar -o /opt/spark/jars/iceberg-spark-runtime-3.5_2.12-1.5.1.jar - -# Install Hadoop -RUN curl https://dlcdn.apache.org/hadoop/common/hadoop-3.3.6/hadoop-3.3.6.tar.gz -o hadoop-3.3.6.tar.gz && \ - tar xf hadoop-3.3.6.tar.gz -C /opt/ && \ - mv /opt/hadoop-3.3.6 /opt/hadoop - -ENV HADOOP_HOME="/opt/hadoop" - -# Install Hive -RUN curl https://dlcdn.apache.org/hive/hive-3.1.3/apache-hive-3.1.3-bin.tar.gz -o apache-hive-3.1.3-bin.tar.gz && \ - tar xf apache-hive-3.1.3-bin.tar.gz -C /opt/ && \ - mv /opt/apache-hive-3.1.3-bin /opt/hive - -ENV HIVE_HOME="/opt/hive" - -# Airflow connections -ENV AIRFLOW_CONN_SPARK_DEFAULT="spark://local?deploy-mode=client" - -# Airflow configuration -ENV AIRFLOW__SCHEDULER__MIN_FILE_PROCESS_INTERVAL=3 - -# SQLMesh configuration -ENV SQLMESH__DISABLE_ANONYMIZED_ANALYTICS=1 - -USER airflow - -# Install Spark provider for 
Airflow -# skip install pyspark since it's part of the image -RUN pip install apache-airflow-providers-apache-spark==4.8.0 --no-deps -RUN pip install apache-airflow-providers-databricks==6.4.0 \ - apache-airflow-providers-github==2.6.0 \ - apache-airflow-providers-common-sql==1.13.0 \ - pandas==1.5.2 # python 3.8 spark 3.4 and pandas 2.0 have issues with casting timestamp - -# Install Deps -USER root -ADD setup.py /opt/sqlmesh/setup.py -RUN mkdir /opt/sqlmesh/sqlmesh && touch /opt/sqlmesh/sqlmesh/__init__.py && chown -R airflow /opt/sqlmesh - -ADD examples/custom_materializations /opt/custom_materializations -RUN chown -R airflow /opt/custom_materializations - -USER airflow -RUN cd /opt/sqlmesh && pip install -e .[dbt] -RUN cd /opt/custom_materializations && pip install -e . diff --git a/examples/airflow/Makefile b/examples/airflow/Makefile deleted file mode 100644 index 6a0b618776..0000000000 --- a/examples/airflow/Makefile +++ /dev/null @@ -1,56 +0,0 @@ -AIRFLOW_VERSION ?= 2.9.1 -AIRFLOW_IMAGE_NAME ?= airflow-sqlmesh -AIRFLOW_UID ?= $(shell id -u) - -install-requirements: - pip3 install -r requirements.txt - -download-docker-compose: - curl -LfO 'https://airflow.apache.org/docs/apache-airflow/$(AIRFLOW_VERSION)/docker-compose.yaml' - -decorate-docker-compose: install-requirements download-docker-compose - python3 ./docker_compose_decorator.py - -download-cli: - curl -LfO 'https://airflow.apache.org/docs/apache-airflow/$(AIRFLOW_VERSION)/airflow.sh' && chmod +x airflow.sh - -package-sqlmesh: - make -C ../../ package - -init-folders: - mkdir -p ./dags ./logs ./plugins ./warehouse - -init-airflow-dockerfile: - export AIRFLOW_VERSION=$(AIRFLOW_VERSION) AIRFLOW_UID=$(AIRFLOW_UID) && cat Dockerfile.template | envsubst '$$AIRFLOW_VERSION,$$AIRFLOW_UID' > Dockerfile - -build-airflow-image: init-airflow-dockerfile - cd ../../ && docker build -t $(AIRFLOW_IMAGE_NAME) -f ./examples/airflow/Dockerfile . 
- -create-metastore-db: build-airflow-image decorate-docker-compose - export AIRFLOW_IMAGE_NAME=$(AIRFLOW_IMAGE_NAME) AIRFLOW_UID=$(AIRFLOW_UID) && docker-compose up --force-recreate create-metastore-db - -provision-metastore-tables: build-airflow-image decorate-docker-compose create-metastore-db - export AIRFLOW_IMAGE_NAME=$(AIRFLOW_IMAGE_NAME) AIRFLOW_UID=$(AIRFLOW_UID) && docker-compose up --force-recreate provision-metastore-tables - -init-airflow: decorate-docker-compose - export AIRFLOW_IMAGE_NAME=$(AIRFLOW_IMAGE_NAME) AIRFLOW_UID=$(AIRFLOW_UID) && docker-compose up --force-recreate airflow-init - -init: decorate-docker-compose download-cli package-sqlmesh init-folders build-airflow-image init-airflow provision-metastore-tables - -run: build-airflow-image - export AIRFLOW_IMAGE_NAME=$(AIRFLOW_IMAGE_NAME) AIRFLOW_UID=$(AIRFLOW_UID) && docker-compose up --force-recreate -d - -stop: - docker-compose down - -clean: decorate-docker-compose - docker-compose down --volumes --remove-orphans && docker rmi -f $(AIRFLOW_IMAGE_NAME) && rm -rf ./logs/* && rm -rf ./warehouse/* - -psql: - docker-compose exec -it postgres psql -U airflow airflow - -spark-sql: - docker-compose exec -it airflow-worker spark-sql - -docker-test: decorate-docker-compose - docker-compose up --force-recreate --exit-code-from sqlmesh-tests sqlmesh-tests diff --git a/examples/airflow/README.md b/examples/airflow/README.md deleted file mode 100644 index fdc8942ae1..0000000000 --- a/examples/airflow/README.md +++ /dev/null @@ -1,45 +0,0 @@ -# SQLMesh Airflow Examples - -## Requirements -1. Docker -2. Docker Compose - -**Note:** the Docker instance must be configured to use 4GB of memory for all containers to run properly. - -## Install and Run -Initialize the Airflow environment first. This should only be done once: -```bash -make init -``` -Run the Airflow cluster in Docker: -```bash -make run -``` -The UI should now become available at [http://localhost:8080/](http://localhost:8080/). 
The account created has the login `airflow` and the password `airflow`. - -Terminate the Airflow cluster: -```bash -make stop -``` -Clean the environment: -```bash -make clean -``` -Re-create and re-launch the environment in one command: -```bash -make clean init run -``` -Access the Postgres instance with psql: -```bash -make psql -``` -Run the Spark SQL REPL on a running cluster: -```bash -make spark-sql -``` - -## CLI -After installation is complete the Airflow CLI script will become available: -```bash -./airflow.sh -``` diff --git a/examples/airflow/dags/sqlmesh_integration.py b/examples/airflow/dags/sqlmesh_integration.py deleted file mode 100644 index 2e1619703e..0000000000 --- a/examples/airflow/dags/sqlmesh_integration.py +++ /dev/null @@ -1,9 +0,0 @@ -import os - -from sqlmesh.schedulers.airflow.integration import SQLMeshAirflow - -engine_operator = os.environ.get("AIRFLOW_ENGINE_OPERATOR", "spark") -sqlmesh_airflow = SQLMeshAirflow(engine_operator, default_catalog="spark_catalog") - -for dag in sqlmesh_airflow.dags: - globals()[dag.dag_id] = dag diff --git a/examples/airflow/docker_compose_decorator.py b/examples/airflow/docker_compose_decorator.py deleted file mode 100644 index e61ea0a699..0000000000 --- a/examples/airflow/docker_compose_decorator.py +++ /dev/null @@ -1,125 +0,0 @@ -import os - -from ruamel.yaml import YAML - -DOCKER_COMPOSE_YAML = "docker-compose.yaml" - -yaml = YAML(typ="safe") -yaml.default_flow_style = False - - -with open(DOCKER_COMPOSE_YAML, "r", encoding="utf-8") as fd: - docker_compose = yaml.load(fd) - -docker_compose["x-airflow-common"]["volumes"].extend( - [ - "./spark_conf:/opt/spark/conf", - "./spark_conf:/opt/hive/conf", - "./warehouse:/opt/warehouse", - "../../:/opt/sqlmesh", - ] -) - -# Dont load Airflow example DAGs because they cause visual pollution -docker_compose["x-airflow-common"]["environment"]["AIRFLOW__CORE__LOAD_EXAMPLES"] = "false" - -docker_compose["services"]["postgres"]["ports"] = ["5432:5432"] - 
-docker_compose["services"]["create-metastore-db"] = { - "command": [ - "psql", - "-U", - "airflow", - "--host", - "postgres", - "-c", - "CREATE DATABASE metastore_db", - ], - "environment": { - "PGPASSWORD": "airflow", - }, - "image": docker_compose["services"]["postgres"]["image"], - "depends_on": { - "postgres": { - "condition": "service_healthy", - }, - }, - "profiles": [ - "sqlmesh-warehouse-init", - ], -} - -docker_compose["services"]["provision-metastore-tables"] = { - "entrypoint": "/bin/bash", - "command": [ - "-c", - "/opt/hive/bin/schematool -dbType postgres -initSchema", - ], - "image": "airflow-sqlmesh", - "user": "airflow", - "volumes": [ - "./spark_conf:/opt/spark/conf", - "./spark_conf:/opt/hive/conf", - "./warehouse:/opt/warehouse", - ], - "depends_on": { - "postgres": { - "condition": "service_healthy", - }, - }, - "profiles": [ - "sqlmesh-warehouse-init", - ], -} - -docker_compose["services"]["sqlmesh-tests"] = { - "entrypoint": "/bin/bash", - "command": [ - "-c", - "make install-dev && pytest -m 'airflow and docker'", - ], - "image": "airflow-sqlmesh", - "user": "airflow", - "volumes": [ - "./spark_conf:/opt/spark/conf", - "./spark_conf:/opt/hive/conf", - "./warehouse:/opt/warehouse", - "../../:/opt/sqlmesh", - ], - "environment": { - "AIRFLOW__DATABASE__SQL_ALCHEMY_CONN": "postgresql+psycopg2://airflow:airflow@postgres/airflow", - "IS_DOCKER": "true", - }, - "working_dir": "/opt/sqlmesh", - "profiles": [ - "sqlmesh-tests", - ], -} - -engine_operator = os.environ.get("AIRFLOW_ENGINE_OPERATOR", "spark").lower() -for airflow_component in ["airflow-scheduler", "airflow-worker"]: - environment_variables = {"AIRFLOW_ENGINE_OPERATOR": engine_operator} - if engine_operator == "databricks": - if not all( - variable in os.environ - for variable in [ - "DATABRICKS_SERVER_HOSTNAME", - "DATABRICKS_TOKEN", - "DATABRICKS_HTTP_PATH", - ] - ): - raise RuntimeError( - "Tried to use Databricks Airflow Engine operator but did not define 
`DATABRICKS_SERVER_HOSTNAME`, `DATABRICKS_TOKEN`, `DATABRICKS_HTTP_PATH`" - ) - environment_variables["AIRFLOW_CONN_DATABRICKS_DEFAULT"] = ( - "databricks://${DATABRICKS_SERVER_HOSTNAME}?token=${DATABRICKS_TOKEN}&http_path=${DATABRICKS_HTTP_PATH}" - ) - if os.getenv("DEMO_GITHUB_PAT"): - environment_variables["AIRFLOW_CONN_GITHUB_DEFAULT"] = ( - '{"conn_type": "github", "password": "${DEMO_GITHUB_PAT}"}' - ) - docker_compose["services"][airflow_component]["environment"].update(environment_variables) - - -with open(DOCKER_COMPOSE_YAML, "w", encoding="utf-8") as fd: - yaml.dump(docker_compose, fd) diff --git a/examples/airflow/requirements.txt b/examples/airflow/requirements.txt deleted file mode 100644 index 58d360844e..0000000000 --- a/examples/airflow/requirements.txt +++ /dev/null @@ -1 +0,0 @@ -ruamel.yaml diff --git a/examples/airflow/spark_conf/hive-site.xml b/examples/airflow/spark_conf/hive-site.xml deleted file mode 100644 index 62b857aab7..0000000000 --- a/examples/airflow/spark_conf/hive-site.xml +++ /dev/null @@ -1,25 +0,0 @@ - - - javax.jdo.option.ConnectionURL - jdbc:postgresql://postgres:5432/metastore_db - JDBC connect string for a JDBC metastore - - - javax.jdo.option.ConnectionDriverName - org.postgresql.Driver - Driver class name for a JDBC metastore - - - javax.jdo.option.ConnectionUserName - airflow - - - javax.jdo.option.ConnectionPassword - airflow - - - hive.metastore.warehouse.dir - /opt/warehouse - location of default database for the warehouse - - diff --git a/examples/airflow/spark_conf/spark-defaults.conf b/examples/airflow/spark_conf/spark-defaults.conf deleted file mode 100644 index 6904b823eb..0000000000 --- a/examples/airflow/spark_conf/spark-defaults.conf +++ /dev/null @@ -1,7 +0,0 @@ -spark.hadoop.hive.exec.dynamic.partition true -spark.hadoop.hive.exec.dynamic.partition.mode nonstrict -spark.sql.sources.partitionOverwriteMode dynamic - -spark.sql.extensions org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions 
-spark.sql.catalog.spark_catalog org.apache.iceberg.spark.SparkSessionCatalog -spark.sql.catalog.spark_catalog.type hive diff --git a/examples/custom_materializations/custom_materializations/custom_kind.py b/examples/custom_materializations/custom_materializations/custom_kind.py new file mode 100644 index 0000000000..8a0eabcfa7 --- /dev/null +++ b/examples/custom_materializations/custom_materializations/custom_kind.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import typing as t + +from sqlmesh import CustomMaterialization, CustomKind, Model +from sqlmesh.utils.pydantic import validate_string + +if t.TYPE_CHECKING: + from sqlmesh import QueryOrDF + + +class ExtendedCustomKind(CustomKind): + @property + def custom_property(self) -> str: + return validate_string(self.materialization_properties.get("custom_property")) + + +class CustomFullWithCustomKindMaterialization(CustomMaterialization[ExtendedCustomKind]): + NAME = "custom_full_with_custom_kind" + + def insert( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + assert type(model.kind).__name__ == "ExtendedCustomKind" + + self._replace_query_for_model(model, table_name, query_or_df, render_kwargs) diff --git a/examples/custom_materializations/custom_materializations/full.py b/examples/custom_materializations/custom_materializations/full.py index 79aa50232a..d2a7c64993 100644 --- a/examples/custom_materializations/custom_materializations/full.py +++ b/examples/custom_materializations/custom_materializations/full.py @@ -17,6 +17,7 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: - self._replace_query_for_model(model, table_name, query_or_df) + self._replace_query_for_model(model, table_name, query_or_df, render_kwargs) diff --git a/examples/custom_materializations/pyproject.toml 
b/examples/custom_materializations/pyproject.toml new file mode 100644 index 0000000000..fd233a9986 --- /dev/null +++ b/examples/custom_materializations/pyproject.toml @@ -0,0 +1,19 @@ +[build-system] +requires = ["setuptools >= 61.0", "setuptools_scm"] +build-backend = "setuptools.build_meta" + +[project] +name = "custom_materializations" +requires-python = ">=3.9" +version = "0.1.0" +dependencies = ["sqlmesh"] + +[project.entry-points."sqlmesh.materializations"] +custom_full_materialization = "custom_materializations.full:CustomFullMaterialization" +custom_full_with_custom_kind = "custom_materializations.custom_kind:CustomFullWithCustomKindMaterialization" + +[tool.setuptools.packages.find] +include = ["custom_materializations"] + +[tool.setuptools_scm] +fallback_version = "0.0.0" diff --git a/examples/custom_materializations/setup.py b/examples/custom_materializations/setup.py deleted file mode 100644 index f2ce8176d5..0000000000 --- a/examples/custom_materializations/setup.py +++ /dev/null @@ -1,14 +0,0 @@ -from setuptools import find_packages, setup - -setup( - name="custom_materializations", - packages=find_packages(include=["custom_materializations"]), - entry_points={ - "sqlmesh.materializations": [ - "custom_full_materialization = custom_materializations.full:CustomFullMaterialization", - ], - }, - install_requires=[ - "sqlmesh", - ], -) diff --git a/examples/ibis/models/ibis_full_model_python.py b/examples/ibis/models/ibis_full_model_python.py index 5a50a021c5..1ded44775d 100644 --- a/examples/ibis/models/ibis_full_model_python.py +++ b/examples/ibis/models/ibis_full_model_python.py @@ -2,7 +2,7 @@ from datetime import datetime import ibis # type: ignore -import pandas as pd +import pandas as pd # noqa: TID253 from constants import DB_PATH # type: ignore from sqlglot import exp @@ -28,7 +28,7 @@ def execute( **kwargs: t.Any, ) -> pd.DataFrame: # get physical table name - upstream_model = exp.to_table(context.table("ibis.incremental_model")) + 
upstream_model = exp.to_table(context.resolve_table("ibis.incremental_model")) # connect ibis to database con = ibis.duckdb.connect(DB_PATH) diff --git a/examples/multi/.vscode/settings.json b/examples/multi/.vscode/settings.json new file mode 100644 index 0000000000..e08af7514c --- /dev/null +++ b/examples/multi/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "sqlmesh.projectPaths": ["./repo_1", "./repo_2"] +} \ No newline at end of file diff --git a/examples/multi/repo_1/config.yaml b/examples/multi/repo_1/config.yaml index f4e111275d..0f35441b86 100644 --- a/examples/multi/repo_1/config.yaml +++ b/examples/multi/repo_1/config.yaml @@ -4,7 +4,7 @@ gateways: local: connection: type: duckdb - database: db.db + database: db.duckdb memory: connection: @@ -12,5 +12,16 @@ gateways: default_gateway: local + +before_all: + - CREATE TABLE IF NOT EXISTS before_1 AS select @one() +after_all: + - CREATE TABLE IF NOT EXISTS after_1 AS select @dup() + model_defaults: dialect: 'duckdb' + +linter: + enabled: True + + warn_rules: "ALL" \ No newline at end of file diff --git a/examples/airflow/__init__.py b/examples/multi/repo_1/linter/__init__.py similarity index 100% rename from examples/airflow/__init__.py rename to examples/multi/repo_1/linter/__init__.py diff --git a/examples/multi/repo_1/linter/user.py b/examples/multi/repo_1/linter/user.py new file mode 100644 index 0000000000..1dfc7c8ae2 --- /dev/null +++ b/examples/multi/repo_1/linter/user.py @@ -0,0 +1,15 @@ +"""Contains all the standard rules included with SQLMesh""" + +from __future__ import annotations + +import typing as t + +from sqlmesh.core.linter.rule import Rule, RuleViolation +from sqlmesh.core.model import Model + + +class NoMissingDescription(Rule): + """All models should be documented.""" + + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + return self.violation() if not model.description else None diff --git a/examples/multi/repo_1/macros/__init__.py b/examples/multi/repo_1/macros/__init__.py 
index e69de29bb2..e95e68fd17 100644 --- a/examples/multi/repo_1/macros/__init__.py +++ b/examples/multi/repo_1/macros/__init__.py @@ -0,0 +1,11 @@ +from sqlmesh import macro + + +@macro() +def one(context): + return 1 + + +@macro() +def dup(context): + return "'repo_1'" diff --git a/examples/multi/repo_1/models/a.sql b/examples/multi/repo_1/models/a.sql index 0eb1782054..31ef81b2d7 100644 --- a/examples/multi/repo_1/models/a.sql +++ b/examples/multi/repo_1/models/a.sql @@ -1,7 +1,10 @@ MODEL ( - name bronze.a + name bronze.a, + kind FULL ); SELECT 1 AS col_a, - 'b' AS col_b \ No newline at end of file + 'b' AS col_b, + @one() AS one, + @dup() AS dup diff --git a/examples/multi/repo_1/models/b.sql b/examples/multi/repo_1/models/b.sql index b32897705e..b80918d6d5 100644 --- a/examples/multi/repo_1/models/b.sql +++ b/examples/multi/repo_1/models/b.sql @@ -1,5 +1,6 @@ MODEL ( - name bronze.b + name bronze.b, + kind FULL ); SELECT diff --git a/examples/multi/repo_2/config.yaml b/examples/multi/repo_2/config.yaml index 6bd2063a84..23bec6d8fe 100644 --- a/examples/multi/repo_2/config.yaml +++ b/examples/multi/repo_2/config.yaml @@ -4,7 +4,7 @@ gateways: local: connection: type: duckdb - database: db.db + database: db.duckdb memory: connection: @@ -12,5 +12,16 @@ gateways: default_gateway: local + +before_all: + - CREATE TABLE IF NOT EXISTS before_2 AS select @two() +after_all: + - CREATE TABLE IF NOT EXISTS after_2 AS select @dup() + model_defaults: - dialect: 'duckdb' \ No newline at end of file + dialect: 'duckdb' + +linter: + enabled: True + + ignored_rules: "ALL" \ No newline at end of file diff --git a/examples/multi/repo_2/macros/__init__.py b/examples/multi/repo_2/macros/__init__.py index e69de29bb2..04b00b2c55 100644 --- a/examples/multi/repo_2/macros/__init__.py +++ b/examples/multi/repo_2/macros/__init__.py @@ -0,0 +1,11 @@ +from sqlmesh import macro + + +@macro() +def two(context): + return 2 + + +@macro() +def dup(context): + return "'repo_2'" diff --git 
a/examples/multi/repo_2/models/c.sql b/examples/multi/repo_2/models/c.sql index 6a5c42619c..08551704f4 100644 --- a/examples/multi/repo_2/models/c.sql +++ b/examples/multi/repo_2/models/c.sql @@ -1,5 +1,6 @@ MODEL ( - name silver.c + name silver.c, + kind FULL ); SELECT DISTINCT col_a diff --git a/examples/multi/repo_2/models/d.sql b/examples/multi/repo_2/models/d.sql index 897d38272e..6935763f59 100644 --- a/examples/multi/repo_2/models/d.sql +++ b/examples/multi/repo_2/models/d.sql @@ -1,7 +1,10 @@ MODEL ( - name silver.d + name silver.d, + kind FULL ); SELECT - * + *, + @two() as two, + @dup() AS dup FROM silver.c diff --git a/examples/multi/repo_2/models/e.sql b/examples/multi/repo_2/models/e.sql new file mode 100644 index 0000000000..168dbc143d --- /dev/null +++ b/examples/multi/repo_2/models/e.sql @@ -0,0 +1,8 @@ +MODEL ( + name silver.e, + kind FULL +); + +SELECT + * EXCEPT(dup) +FROM bronze.a diff --git a/examples/multi_dbt/bronze/dbt_project.yml b/examples/multi_dbt/bronze/dbt_project.yml index 14f841251c..1fadcdc1cd 100644 --- a/examples/multi_dbt/bronze/dbt_project.yml +++ b/examples/multi_dbt/bronze/dbt_project.yml @@ -19,3 +19,6 @@ require-dbt-version: [">=1.0.0", "<2.0.0"] models: start: "2024-01-01" +materialized: table + +on-run-start: + - 'CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);' \ No newline at end of file diff --git a/examples/multi_dbt/silver/dbt_project.yml b/examples/multi_dbt/silver/dbt_project.yml index e78f4643d3..57edd1f72c 100644 --- a/examples/multi_dbt/silver/dbt_project.yml +++ b/examples/multi_dbt/silver/dbt_project.yml @@ -19,3 +19,6 @@ require-dbt-version: [">=1.0.0", "<2.0.0"] models: start: "2024-01-01" +materialized: table + +on-run-end: + - '{{ store_schemas(schemas) }}' \ No newline at end of file diff --git a/examples/multi_dbt/silver/macros/store_schemas.sql b/examples/multi_dbt/silver/macros/store_schemas.sql new file mode 100644 index 0000000000..564d2b24bb --- /dev/null 
+++ b/examples/multi_dbt/silver/macros/store_schemas.sql @@ -0,0 +1,3 @@ +{% macro store_schemas(schemas) %} + create or replace table schema_table as select {{schemas}} as all_schemas; +{% endmacro %} \ No newline at end of file diff --git a/examples/multi_hybrid/dbt_repo/config.py b/examples/multi_hybrid/dbt_repo/config.py new file mode 100644 index 0000000000..e921536244 --- /dev/null +++ b/examples/multi_hybrid/dbt_repo/config.py @@ -0,0 +1,5 @@ +from pathlib import Path + +from sqlmesh.dbt.loader import sqlmesh_config + +config = sqlmesh_config(Path(__file__).parent, project="dbt_repo") diff --git a/examples/multi_hybrid/dbt_repo/dbt_project.yml b/examples/multi_hybrid/dbt_repo/dbt_project.yml new file mode 100644 index 0000000000..18a5d6b57c --- /dev/null +++ b/examples/multi_hybrid/dbt_repo/dbt_project.yml @@ -0,0 +1,18 @@ +name: 'dbt_repo' + +version: '1.0.0' + +profile: 'dbt_repo' + +model-paths: ["models"] +macro-paths: ["macros"] + +clean-targets: + - "target" + - "dbt_packages" + +models: + dbt_repo: + +materialized: view + + +start: Jan 1 2000 \ No newline at end of file diff --git a/examples/multi_hybrid/dbt_repo/macros/round_dollars.sql b/examples/multi_hybrid/dbt_repo/macros/round_dollars.sql new file mode 100644 index 0000000000..d184247e30 --- /dev/null +++ b/examples/multi_hybrid/dbt_repo/macros/round_dollars.sql @@ -0,0 +1,3 @@ +{% macro round_dollars(column, scale=2) %} + ROUND(({{ column }} / 100)::numeric(16, {{ scale }}), {{ scale }}) +{% endmacro %} \ No newline at end of file diff --git a/examples/multi_hybrid/dbt_repo/models/c.sql b/examples/multi_hybrid/dbt_repo/models/c.sql new file mode 100644 index 0000000000..e38030b07c --- /dev/null +++ b/examples/multi_hybrid/dbt_repo/models/c.sql @@ -0,0 +1,3 @@ +SELECT DISTINCT + {{ round_dollars('col_a') }} as rounded_col_a +FROM {{ source("sqlmesh_repo", "b") }} diff --git a/examples/multi_hybrid/dbt_repo/models/d.sql b/examples/multi_hybrid/dbt_repo/models/d.sql new file mode 100644 index 
0000000000..d9c7de4a95 --- /dev/null +++ b/examples/multi_hybrid/dbt_repo/models/d.sql @@ -0,0 +1,3 @@ +SELECT + rounded_col_a +FROM {{ ref("c") }} diff --git a/examples/multi_hybrid/dbt_repo/models/e.sql b/examples/multi_hybrid/dbt_repo/models/e.sql new file mode 100644 index 0000000000..f9797074d0 --- /dev/null +++ b/examples/multi_hybrid/dbt_repo/models/e.sql @@ -0,0 +1,3 @@ +SELECT + 12.248 AS col_a, + 'b' AS col_b diff --git a/examples/multi_hybrid/dbt_repo/models/schema.yml b/examples/multi_hybrid/dbt_repo/models/schema.yml new file mode 100644 index 0000000000..8b76cac861 --- /dev/null +++ b/examples/multi_hybrid/dbt_repo/models/schema.yml @@ -0,0 +1,8 @@ +version: 2 + +sources: + - name: sqlmesh_repo + schema: sqlmesh_repo + tables: + - name: a + - name: b diff --git a/examples/multi_hybrid/dbt_repo/profiles.yml b/examples/multi_hybrid/dbt_repo/profiles.yml new file mode 100644 index 0000000000..bdeb799b0f --- /dev/null +++ b/examples/multi_hybrid/dbt_repo/profiles.yml @@ -0,0 +1,8 @@ +dbt_repo: + + target: prod + outputs: + prod: + type: duckdb + threads: 1 + schema: dbt_repo diff --git a/examples/multi_hybrid/sqlmesh_repo/config.yaml b/examples/multi_hybrid/sqlmesh_repo/config.yaml new file mode 100644 index 0000000000..9ef893eb8b --- /dev/null +++ b/examples/multi_hybrid/sqlmesh_repo/config.yaml @@ -0,0 +1,12 @@ +project: sqlmesh_repo + +gateways: + prod: + connection: + type: duckdb + database: prod.duckdb + +default_gateway: prod + +model_defaults: + dialect: duckdb diff --git a/examples/multi_hybrid/sqlmesh_repo/models/a.sql b/examples/multi_hybrid/sqlmesh_repo/models/a.sql new file mode 100644 index 0000000000..bf36e1a79d --- /dev/null +++ b/examples/multi_hybrid/sqlmesh_repo/models/a.sql @@ -0,0 +1,8 @@ +MODEL ( + name sqlmesh_repo.a +); + +SELECT + col_a, + col_b +FROM dbt_repo.e; diff --git a/examples/multi_hybrid/sqlmesh_repo/models/b.sql b/examples/multi_hybrid/sqlmesh_repo/models/b.sql new file mode 100644 index 0000000000..a40a6a5f97 --- 
/dev/null +++ b/examples/multi_hybrid/sqlmesh_repo/models/b.sql @@ -0,0 +1,7 @@ +MODEL ( + name sqlmesh_repo.b +); + +SELECT + col_a, col_b +FROM sqlmesh_repo.a diff --git a/examples/sushi/audits/raw_demographics.sql b/examples/sushi/audits/raw_demographics.sql new file mode 100644 index 0000000000..de33d52b84 --- /dev/null +++ b/examples/sushi/audits/raw_demographics.sql @@ -0,0 +1,6 @@ +AUDIT ( + name assert_raw_demographics +); +SELECT customer_id +FROM @this_model +WHERE customer_id <> 1 diff --git a/examples/sushi/config.py b/examples/sushi/config.py index afa8a83d35..b985e24ec5 100644 --- a/examples/sushi/config.py +++ b/examples/sushi/config.py @@ -1,7 +1,7 @@ import os +from sqlmesh.core.config.common import VirtualEnvironmentMode, TableNamingConvention from sqlmesh.core.config import ( - AirflowSchedulerConfig, AutoCategorizationMode, BigQueryConnectionConfig, CategorizerConfig, @@ -11,8 +11,8 @@ GatewayConfig, ModelDefaultsConfig, PlanConfig, - SparkConnectionConfig, ) +from sqlmesh.core.config.linter import LinterConfig from sqlmesh.core.notification_target import ( BasicSMTPNotificationTarget, SlackApiNotificationTarget, @@ -27,12 +27,37 @@ defaults = {"dialect": "duckdb"} model_defaults = ModelDefaultsConfig(**defaults) model_defaults_iceberg = ModelDefaultsConfig(**defaults, storage_format="iceberg") +before_all = [ + "CREATE SCHEMA IF NOT EXISTS raw", + "DROP VIEW IF EXISTS raw.demographics", + "CREATE VIEW raw.demographics AS (SELECT 1 AS customer_id, '00000' AS zip)", +] -# An in memory DuckDB config. +# A DuckDB config, in-memory by default. 
config = Config( - default_connection=DuckDBConnectionConfig(), + gateways={ + "duckdb": GatewayConfig( + connection=DuckDBConnectionConfig(), + ), + "duckdb_persistent": GatewayConfig( + connection=DuckDBConnectionConfig(database=f"{DATA_DIR}/duckdb.db"), + ), + }, + default_gateway="duckdb", model_defaults=model_defaults, + linter=LinterConfig( + enabled=False, + rules=[ + "ambiguousorinvalidcolumn", + "invalidselectstarexpansion", + "noselectstar", + "nomissingaudits", + "nomissingowner", + "nomissingexternalmodels", + ], + ), + before_all=before_all, ) bigquery_config = Config( @@ -44,47 +69,38 @@ }, default_gateway="bq", model_defaults=model_defaults, + before_all=before_all, ) # A configuration used for SQLMesh tests. test_config = Config( gateways={"in_memory": GatewayConfig(connection=DuckDBConnectionConfig())}, default_gateway="in_memory", - plan=PlanConfig(auto_categorize_changes=CategorizerConfig(sql=AutoCategorizationMode.SEMI)), - model_defaults=model_defaults, -) - -# A stateful DuckDB config. 
-local_config = Config( - default_connection=DuckDBConnectionConfig(database=f"{DATA_DIR}/local.duckdb"), - model_defaults=model_defaults, -) - -airflow_config = Config( - default_scheduler=AirflowSchedulerConfig(), - gateways=GatewayConfig( - connection=SparkConnectionConfig( - config_dir=os.path.join(CURRENT_FILE_PATH, "..", "airflow", "spark_conf"), - config={ - "spark.hadoop.javax.jdo.option.ConnectionURL": "jdbc:postgresql://localhost:5432/metastore_db" - }, + plan=PlanConfig( + auto_categorize_changes=CategorizerConfig( + sql=AutoCategorizationMode.SEMI, python=AutoCategorizationMode.OFF ) ), - model_defaults=model_defaults_iceberg, + model_defaults=model_defaults, + before_all=before_all, ) - -airflow_config_docker = Config( - default_scheduler=AirflowSchedulerConfig(airflow_url="http://airflow-webserver:8080/"), - gateways=GatewayConfig(connection=SparkConnectionConfig()), - model_defaults=model_defaults_iceberg, +# A configuration used for SQLMesh tests with virtual environment mode set to DEV_ONLY. +test_config_virtual_environment_mode_dev_only = test_config.copy( + update={ + "virtual_environment_mode": VirtualEnvironmentMode.DEV_ONLY, + "plan": PlanConfig( + auto_categorize_changes=CategorizerConfig.all_full(), + ), + }, ) # A DuckDB config with a physical schema map. 
map_config = Config( default_connection=DuckDBConnectionConfig(), - physical_schema_override={"sushi": "company_internal"}, + physical_schema_mapping={"^sushi$": "company_internal"}, model_defaults=model_defaults, + before_all=before_all, ) # A config representing isolated systems with a gateway per system @@ -96,6 +112,7 @@ }, default_gateway="dev", model_defaults=model_defaults, + before_all=before_all, ) required_approvers_config = Config( @@ -130,15 +147,22 @@ ), ], model_defaults=model_defaults, + before_all=before_all, ) -environment_suffix_config = Config( +environment_suffix_table_config = Config( default_connection=DuckDBConnectionConfig(), model_defaults=model_defaults, environment_suffix_target=EnvironmentSuffixTarget.TABLE, + before_all=before_all, ) +environment_suffix_catalog_config = environment_suffix_table_config.model_copy( + update={ + "environment_suffix_target": EnvironmentSuffixTarget.CATALOG, + }, +) CATALOGS = { "in_memory": ":memory:", @@ -149,6 +173,7 @@ default_connection=DuckDBConnectionConfig(catalogs=CATALOGS), default_test_connection=DuckDBConnectionConfig(catalogs=CATALOGS), model_defaults=model_defaults, + before_all=before_all, ) environment_catalog_mapping_config = Config( @@ -165,4 +190,13 @@ "^prod$": "prod_catalog", ".*": "dev_catalog", }, + before_all=before_all, +) + +hash_md5_naming_config = config.copy( + update={"physical_table_naming_convention": TableNamingConvention.HASH_MD5} +) + +table_only_naming_config = config.copy( + update={"physical_table_naming_convention": TableNamingConvention.TABLE_ONLY} ) diff --git a/examples/sushi/external_models.yaml b/examples/sushi/external_models.yaml index dab7c24d71..c83446a772 100644 --- a/examples/sushi/external_models.yaml +++ b/examples/sushi/external_models.yaml @@ -1,5 +1,6 @@ - name: raw.demographics description: Table containing demographics information + dialect: duckdb start: 1 week ago audits: - name: not_null @@ -8,6 +9,7 @@ column: zip min_v: "'00000'" max_v: "'99999'" 
+ - name: assert_raw_demographics columns: customer_id: int zip: text diff --git a/examples/sushi/external_models/model1.yaml b/examples/sushi/external_models/model1/model1.yaml similarity index 100% rename from examples/sushi/external_models/model1.yaml rename to examples/sushi/external_models/model1/model1.yaml diff --git a/examples/sushi/external_models/model2.yaml b/examples/sushi/external_models/model2/model2.yaml similarity index 100% rename from examples/sushi/external_models/model2.yaml rename to examples/sushi/external_models/model2/model2.yaml diff --git a/examples/sushi/helper.py b/examples/sushi/helper.py index 4d5c1ad960..9a853b0903 100644 --- a/examples/sushi/helper.py +++ b/examples/sushi/helper.py @@ -2,7 +2,7 @@ import typing as t from datetime import datetime, timedelta -import numpy as np +import numpy as np # noqa: TID253 def set_seed(dt: datetime) -> None: diff --git a/examples/sushi/linter/__init__.py b/examples/sushi/linter/__init__.py new file mode 100644 index 0000000000..fc4d40b05e --- /dev/null +++ b/examples/sushi/linter/__init__.py @@ -0,0 +1 @@ +# this makes "linter" a package so "linter.user" is a valid module for importlib.import_module() to load diff --git a/examples/sushi/linter/user.py b/examples/sushi/linter/user.py new file mode 100644 index 0000000000..3f83c76e3a --- /dev/null +++ b/examples/sushi/linter/user.py @@ -0,0 +1,16 @@ +from __future__ import annotations + +import typing as t + +from sqlmesh.core.context import Context +from sqlmesh.core.linter.rule import Rule, RuleViolation +from sqlmesh.core.model import Model + + +class NoMissingOwner(Rule): + """All models should have an owner specified.""" + + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + assert isinstance(self.context, Context) + assert len(self.context.models) > 10 + return self.violation() if not model.owner else None diff --git a/examples/sushi/macros/__init__.py b/examples/sushi/macros/__init__.py index e69de29bb2..749326c4d3 100644 
--- a/examples/sushi/macros/__init__.py +++ b/examples/sushi/macros/__init__.py @@ -0,0 +1,6 @@ +from sqlmesh import macro + + +@macro() +def noop(context): + return "SELECT 1" diff --git a/examples/sushi/macros/macros.py b/examples/sushi/macros/macros.py index f4c2b6a829..763f6d62b2 100644 --- a/examples/sushi/macros/macros.py +++ b/examples/sushi/macros/macros.py @@ -7,3 +7,22 @@ @macro() def incremental_by_ds(evaluator, column: exp.Column): return between(evaluator, column, evaluator.locals["start_date"], evaluator.locals["end_date"]) + + +@macro() +def assert_has_columns(evaluator, model, columns_to_types): + if evaluator.runtime_stage == "creating": + expected_schema = { + column_type.name: exp.maybe_parse( + column_type.text("expression"), into=exp.DataType, dialect=evaluator.dialect + ) + for column_type in columns_to_types.expressions + } + assert expected_schema.items() <= evaluator.columns_to_types(model).items() + + return None + + +@macro() +def waiter_names_threshold(evaluator): + return 200 diff --git a/examples/sushi/macros/utils.py b/examples/sushi/macros/utils.py index fbc4c04e31..fb2ccc21b0 100644 --- a/examples/sushi/macros/utils.py +++ b/examples/sushi/macros/utils.py @@ -5,14 +5,14 @@ @macro() def add_one(evaluator, column: int): - # typed column will be cast to an int and return an integer back + """typed column will be cast to an int and return an integer back""" assert isinstance(column, int) return column + 1 @macro() def multiply(evaluator, column, num): - # untyped column will be a sqlglot column and return a sqlglot exp "column > 0" + """untyped column will be a sqlglot column and return a sqlglot exp "column > 0""" assert isinstance(column, exp.Column) return column * num @@ -26,6 +26,7 @@ def sql_literal( column_str: str, column_quoted: str, ): + """A macro that accepts various types of SQL literals and returns the column.""" assert isinstance(column, str) assert isinstance(str_lit, str) assert str_lit == "'x'" diff --git 
a/examples/sushi/models/blueprint.sql b/examples/sushi/models/blueprint.sql new file mode 100644 index 0000000000..54f797dba4 --- /dev/null +++ b/examples/sushi/models/blueprint.sql @@ -0,0 +1,21 @@ +MODEL ( + name @name, + kind FULL, + description "Count of customers by status, done with a fancy unnecessary blueprint", + grain status, + blueprints ( + ( + name := sushi.count_customers_active, + blueprint_status := 'active', + ), + ( + name := sushi.count_customers_inactive, + blueprint_status := 'inactive', + ) + ) +); + +SELECT + COUNT(customer_id) AS count_customers +FROM sushi.customers +WHERE status = @blueprint_status; \ No newline at end of file diff --git a/examples/sushi/models/customer_revenue_by_day.sql b/examples/sushi/models/customer_revenue_by_day.sql index 3b7f3724cb..248af2db8d 100644 --- a/examples/sushi/models/customer_revenue_by_day.sql +++ b/examples/sushi/models/customer_revenue_by_day.sql @@ -21,7 +21,7 @@ WITH order_total AS ( LEFT JOIN sushi.items AS i ON oi.item_id = i.id AND oi.event_date = i.event_date WHERE - oi.event_date BETWEEN CAST('{{ start_ds }}' as DATE) AND CAST('{{ end_ds }}' as DATE) + oi.event_date BETWEEN @start_date AND @end_date GROUP BY oi.order_id, oi.event_date @@ -35,7 +35,7 @@ FROM sushi.orders AS o LEFT JOIN order_total AS ot ON o.id = ot.order_id AND o.event_date = ot.event_date WHERE - o.event_date BETWEEN CAST('{{ start_ds }}' as DATE) AND CAST('{{ end_ds }}' as DATE) + o.event_date BETWEEN @start_date AND @end_date GROUP BY o.customer_id, o.event_date diff --git a/examples/sushi/models/customers.sql b/examples/sushi/models/customers.sql index fe8b3c5e57..d2bda09ed3 100644 --- a/examples/sushi/models/customers.sql +++ b/examples/sushi/models/customers.sql @@ -17,7 +17,7 @@ CREATE VIEW raw.demographics AS ( SELECT 1 AS customer_id, '00000' AS zip ); -WITH current_marketing AS ( +WITH current_marketing_outer AS ( SELECT customer_id, status @@ -29,7 +29,17 @@ SELECT DISTINCT m.status, d.zip FROM sushi.orders AS o 
-LEFT JOIN current_marketing AS m +LEFT JOIN ( + WITH current_marketing AS ( + SELECT + customer_id, + status, + @ADD_ONE(1) AS another_column, + FROM current_marketing_outer + ) + SELECT current_marketing.* FROM current_marketing WHERE current_marketing.customer_id != 100 +) AS m ON o.customer_id = m.customer_id LEFT JOIN raw.demographics AS d ON o.customer_id = d.customer_id +WHERE o.customer_id > 0 \ No newline at end of file diff --git a/examples/sushi/models/items.py b/examples/sushi/models/items.py index d20ea45d10..54c9442dc5 100644 --- a/examples/sushi/models/items.py +++ b/examples/sushi/models/items.py @@ -2,8 +2,8 @@ import typing as t from datetime import datetime -import numpy as np -import pandas as pd +import numpy as np # noqa: TID253 +import pandas as pd # noqa: TID253 from helper import iter_dates # type: ignore from sqlglot.expressions import to_column @@ -94,4 +94,4 @@ def execute( .rename(columns={"index": "id"}) ) - return pd.concat(dfs) + return pd.concat(dfs).reset_index(drop=True) diff --git a/examples/sushi/models/latest_order.sql b/examples/sushi/models/latest_order.sql new file mode 100644 index 0000000000..4523537505 --- /dev/null +++ b/examples/sushi/models/latest_order.sql @@ -0,0 +1,14 @@ +MODEL ( + name sushi.latest_order, + kind CUSTOM ( + materialization 'custom_full_with_custom_kind', + materialization_properties ( + custom_property = 'sushi!!!' 
+ ) + ), + cron '@daily' +); + +SELECT id, customer_id, start_ts, end_ts, event_date +FROM sushi.orders +ORDER BY event_date DESC LIMIT 1 diff --git a/examples/sushi/models/marketing.sql b/examples/sushi/models/marketing.sql index 4c7bab851e..445fbf7787 100644 --- a/examples/sushi/models/marketing.sql +++ b/examples/sushi/models/marketing.sql @@ -12,4 +12,13 @@ SELECT status::TEXT AS status, updated_at::TIMESTAMP AS updated_at FROM - sushi.raw_marketing + sushi.raw_marketing; + +@assert_has_columns( + sushi.marketing, + { + customer_id: 'int', + status: 'text', + updated_at: 'timestamp', + } +) diff --git a/examples/sushi/models/order_items.py b/examples/sushi/models/order_items.py index 12d26c27ad..9d4dc551e3 100644 --- a/examples/sushi/models/order_items.py +++ b/examples/sushi/models/order_items.py @@ -2,8 +2,8 @@ import typing as t from datetime import datetime -import numpy as np -import pandas as pd +import numpy as np # noqa: TID253 +import pandas as pd # noqa: TID253 from helper import iter_dates # type: ignore from sqlglot import exp from sqlglot.expressions import to_column @@ -16,7 +16,7 @@ def get_items_table(context: ExecutionContext) -> str: - return context.table(ITEMS) + return context.resolve_table(ITEMS) @model( @@ -49,7 +49,7 @@ def execute( execution_time: datetime, **kwargs: t.Any, ) -> t.Generator[pd.DataFrame, None, None]: - orders_table = context.table("sushi.orders") + orders_table = context.resolve_table("sushi.orders") engine_dialect = context.engine_adapter.dialect items_table = get_items_table(context) @@ -99,4 +99,4 @@ def execute( .rename(columns={"index": "id"}) ) - yield pd.concat(dfs) + yield pd.concat(dfs).reset_index(drop=True) diff --git a/examples/sushi/models/orders.py b/examples/sushi/models/orders.py index 75b0b25b37..8d8718a3e3 100644 --- a/examples/sushi/models/orders.py +++ b/examples/sushi/models/orders.py @@ -2,7 +2,7 @@ import typing as t from datetime import datetime, timedelta -import pandas as pd +import pandas as 
pd # noqa: TID253 from helper import iter_dates # type: ignore from sqlmesh import ExecutionContext, model @@ -36,6 +36,7 @@ "end_ts": "int", "event_date": "date", }, + signals=[("test_signal", {"arg": 1})], ) def execute( context: ExecutionContext, @@ -69,4 +70,4 @@ def execute( .rename(columns={"index": "id"}) ) - return pd.concat(dfs) + return pd.concat(dfs).reset_index(drop=True) diff --git a/examples/sushi/models/raw_marketing.py b/examples/sushi/models/raw_marketing.py index 2c7f549f90..b17c471895 100644 --- a/examples/sushi/models/raw_marketing.py +++ b/examples/sushi/models/raw_marketing.py @@ -2,8 +2,8 @@ import typing as t from datetime import datetime -import numpy as np -import pandas as pd +import numpy as np # noqa: TID253 +import pandas as pd # noqa: TID253 from sqlglot import exp from sqlmesh import ExecutionContext, model @@ -35,7 +35,7 @@ def execute( **kwargs: t.Any, ) -> pd.DataFrame: # Generate query with sqlglot dialect/quoting - existing_table = context.table("sushi.raw_marketing") + existing_table = context.resolve_table("sushi.raw_marketing") engine_dialect = context.engine_adapter.dialect df_existing = context.fetchdf( @@ -59,7 +59,15 @@ def execute( "updated_at": [exec_time] * num_customers, } ) - df = df_new.merge(df_existing, on="customer_id", how="left", suffixes=(None, "_old")) + + # clickhouse returns a dataframe with no columns if the query is empty, so we can't merge + if not df_existing.empty: + df = df_new.merge(df_existing, on="customer_id", how="left", suffixes=(None, "_old")) + else: + df = df_new + df["status_old"] = pd.NA + df["updated_at_old"] = pd.NA + df["updated_at"] = pd.to_datetime( np.where( # type: ignore df["status_old"] != df["status"], execution_time, df["updated_at_old"] @@ -67,5 +75,5 @@ def execute( errors="coerce", utc=True, ) - df = df.drop(columns=["status_old", "updated_at_old"]) + df = df.drop(columns=["status_old", "updated_at_old"]).reset_index(drop=True) return df diff --git 
a/examples/sushi/models/waiter_as_customer_by_day.sql b/examples/sushi/models/waiter_as_customer_by_day.sql index 3f7e053c48..dd9f79b5a3 100644 --- a/examples/sushi/models/waiter_as_customer_by_day.sql +++ b/examples/sushi/models/waiter_as_customer_by_day.sql @@ -8,7 +8,11 @@ MODEL ( audits ( not_null(columns := (waiter_id)), forall(criteria := (LENGTH(waiter_name) > 0)) - ) + ), + signals ( + test_signal(arg := 1) + ), + ); JINJA_QUERY_BEGIN; @@ -23,6 +27,6 @@ SELECT FROM sushi.waiters AS w JOIN sushi.customers as c ON w.waiter_id = c.customer_id JOIN sushi.waiter_names as wn ON w.waiter_id = wn.id -WHERE w.event_date BETWEEN @start_date AND @end_date; +WHERE w.event_date BETWEEN CAST('{{ start_ds }}' as DATE) AND @end_date; JINJA_END; diff --git a/examples/sushi/models/waiter_names.sql b/examples/sushi/models/waiter_names.sql index 9f3ac04f72..9410364a29 100644 --- a/examples/sushi/models/waiter_names.sql +++ b/examples/sushi/models/waiter_names.sql @@ -9,7 +9,7 @@ MODEL ( description 'List of waiter names', audits ( assert_positive_id, - does_not_exceed_threshold(column := id, threshold := 200), + does_not_exceed_threshold(column := id, threshold := @waiter_names_threshold()), assert_valid_name, ) ); diff --git a/examples/sushi/models/waiters.py b/examples/sushi/models/waiters.py index e1de3786cc..f9e26eda82 100644 --- a/examples/sushi/models/waiters.py +++ b/examples/sushi/models/waiters.py @@ -12,6 +12,7 @@ kind=dict(name=ModelKindName.EMBEDDED), owner="jen", cron="@daily", + pre_statements=["@noop()"], ) def entrypoint(evaluator: MacroEvaluator) -> exp.Select: """ @@ -41,14 +42,27 @@ def entrypoint(evaluator: MacroEvaluator) -> exp.Select: # There are tests which force not having a default catalog so we check here if one is defined # and add it to the name if it is default_catalog = evaluator.default_catalog + parent_snapshot_name = parent_snapshots[0].name + parent_snapshot_catalog = parent_snapshot_name.split(".")[0].strip('"') + if default_catalog: # make 
sure we don't double quote the default catalog default_catalog = default_catalog.strip('"') + + # Snowflake normalizes unquoted names to uppercase, which can cause case mismatches with + # default_catalog due to sqlmesh's default normalization behavior. SQLMesh addresses this + # by rewriting the default catalog name on the fly in snowflake `_to_sql()`. This model + # code manually extracts default_catalog name, so we manually lowercase the default_catalog + # if the parent catalog name is an uppercase version of the default catalog name. + default_catalog = ( + default_catalog.lower() + if parent_snapshot_catalog == default_catalog.lower() + else default_catalog + ) + name = ".".join([f'"{default_catalog}"', name]) - assert ( - parent_snapshots[0].name == name - ), f"Snapshot Name: {parent_snapshots[0].name}, Name: {name}" + assert parent_snapshot_name == name, f"Snapshot Name: {parent_snapshot_name}, Name: {name}" excluded = {"id", "customer_id", "start_ts", "end_ts"} projections = [] diff --git a/examples/sushi/signals/__init__.py b/examples/sushi/signals/__init__.py new file mode 100644 index 0000000000..bd7c839fce --- /dev/null +++ b/examples/sushi/signals/__init__.py @@ -0,0 +1,9 @@ +import typing as t + +from sqlmesh import signal, DatetimeRanges + + +@signal() +def test_signal(batch: DatetimeRanges, arg: int = 0) -> t.Union[bool, DatetimeRanges]: + assert arg == 1 + return True diff --git a/examples/sushi/sqlmesh-requirements.lock b/examples/sushi/sqlmesh-requirements.lock new file mode 100644 index 0000000000..4b2e332e24 --- /dev/null +++ b/examples/sushi/sqlmesh-requirements.lock @@ -0,0 +1 @@ +pandas==2.2.2 diff --git a/examples/sushi_dbt/config.py b/examples/sushi_dbt/config.py index d5cdd7b874..2305cf79f2 100644 --- a/examples/sushi_dbt/config.py +++ b/examples/sushi_dbt/config.py @@ -1,14 +1,9 @@ from pathlib import Path -from sqlmesh.core.config import AirflowSchedulerConfig from sqlmesh.dbt.loader import sqlmesh_config config = 
sqlmesh_config(Path(__file__).parent) test_config = config - -airflow_config = sqlmesh_config( - Path(__file__).parent, - default_scheduler=AirflowSchedulerConfig(), -) +migration_test_config = sqlmesh_config(Path(__file__).parent, dbt_target_name="duckdb") diff --git a/examples/sushi_dbt/models/customer_revenue_by_day.sql b/examples/sushi_dbt/models/customer_revenue_by_day.sql index f3f49cfc14..9810481eff 100644 --- a/examples/sushi_dbt/models/customer_revenue_by_day.sql +++ b/examples/sushi_dbt/models/customer_revenue_by_day.sql @@ -1,7 +1,7 @@ {{ config( materialized='incremental', - incremental_strategy='delete+insert', + incremental_strategy='incremental_by_time_range', cluster_by=['ds'], time_column='ds', ) diff --git a/examples/sushi_dbt/models/schema.yml b/examples/sushi_dbt/models/schema.yml index eb5d288d5d..8fd62c4efe 100644 --- a/examples/sushi_dbt/models/schema.yml +++ b/examples/sushi_dbt/models/schema.yml @@ -21,6 +21,8 @@ models: tests: - less_than_amount: amount: 1000 + - greater_than_amount: + amount: 0 - name: ds description: Date - name: top_waiters @@ -34,12 +36,12 @@ models: field: waiter_id - name: revenue description: Revenue from orders served by this waiter + - name: unused_column + data_type: int - name: waiters columns: - name: waiter_id description: Waiter id - tests: - - not_null - name: ds description: Date - name: waiter_as_customer_by_day diff --git a/examples/sushi_dbt/models/top_waiters.sql b/examples/sushi_dbt/models/top_waiters.sql index f839b31dc2..e4a74fd8b3 100644 --- a/examples/sushi_dbt/models/top_waiters.sql +++ b/examples/sushi_dbt/models/top_waiters.sql @@ -6,7 +6,8 @@ SELECT waiter_id::INT AS waiter_id, - revenue::DOUBLE AS revenue + revenue::DOUBLE AS revenue, + 1 AS unused_column FROM {{ ref('waiter_revenue_by_day', version=1) }} WHERE ds = ( diff --git a/examples/sushi_dbt/models/waiter_as_customer_by_day.sql b/examples/sushi_dbt/models/waiter_as_customer_by_day.sql index 3d4967aec7..a1145c2b5c 100644 --- 
a/examples/sushi_dbt/models/waiter_as_customer_by_day.sql +++ b/examples/sushi_dbt/models/waiter_as_customer_by_day.sql @@ -1,7 +1,7 @@ {{ config( materialized='incremental', - incremental_strategy='delete+insert', + incremental_strategy='incremental_by_time_range', cluster_by=['ds'], time_column='ds', ) diff --git a/examples/sushi_dbt/models/waiter_revenue_by_day.sql b/examples/sushi_dbt/models/waiter_revenue_by_day.sql index d430c6125b..670e238962 100644 --- a/examples/sushi_dbt/models/waiter_revenue_by_day.sql +++ b/examples/sushi_dbt/models/waiter_revenue_by_day.sql @@ -1,7 +1,7 @@ {{ config( materialized='incremental', - incremental_strategy='delete+insert', + incremental_strategy='incremental_by_time_range', cluster_by=['ds'], time_column='ds', ) diff --git a/examples/sushi_dbt/models/waiter_revenue_by_day_v1.sql b/examples/sushi_dbt/models/waiter_revenue_by_day_v1.sql index d430c6125b..670e238962 100644 --- a/examples/sushi_dbt/models/waiter_revenue_by_day_v1.sql +++ b/examples/sushi_dbt/models/waiter_revenue_by_day_v1.sql @@ -1,7 +1,7 @@ {{ config( materialized='incremental', - incremental_strategy='delete+insert', + incremental_strategy='incremental_by_time_range', cluster_by=['ds'], time_column='ds', ) diff --git a/examples/sushi_dbt/profiles.yml b/examples/sushi_dbt/profiles.yml index e279803368..794b083793 100644 --- a/examples/sushi_dbt/profiles.yml +++ b/examples/sushi_dbt/profiles.yml @@ -3,6 +3,14 @@ sushi: in_memory: type: duckdb schema: sushi + postgres: + type: postgres + host: "host" + user: "user" + password: "password" + dbname: "dbname" + port: 5432 + schema: sushi duckdb: type: duckdb path: 'local.duckdb' @@ -53,4 +61,21 @@ sushi: database: "{{ env_var('TRINO_DATABASE') }}" schema: sushi threads: 1 + clickhouse: + type: clickhouse + host: "{{ env_var('CLICKHOUSE_HOST') }}" + port: 8123 + user: "{{ env_var('CLICKHOUSE_USER') }}" + password: "{{ env_var('CLICKHOUSE_PASSWORD') }}" + schema: sushi + threads: 1 + clickhouse_cluster: + type: 
clickhouse + host: "{{ env_var('CLICKHOUSE_HOST') }}" + port: 8123 + user: "{{ env_var('CLICKHOUSE_USER') }}" + password: "{{ env_var('CLICKHOUSE_PASSWORD') }}" + cluster: "{{ env_var('CLICKHOUSE_CLUSTER') }}" + schema: sushi + threads: 1 target: in_memory diff --git a/examples/sushi_dbt/tests/generic/greater_than_amount.sql b/examples/sushi_dbt/tests/generic/greater_than_amount.sql new file mode 100644 index 0000000000..0778e48a83 --- /dev/null +++ b/examples/sushi_dbt/tests/generic/greater_than_amount.sql @@ -0,0 +1,5 @@ +{%- test greater_than_amount(model, column_name, amount) -%} + select * + from {{ model }} + where {{ column_name }} < {{ amount }} +{%- endtest -%} diff --git a/examples/sushi_dlt/sushi_pipeline.py b/examples/sushi_dlt/sushi_pipeline.py new file mode 100644 index 0000000000..3a44a4897e --- /dev/null +++ b/examples/sushi_dlt/sushi_pipeline.py @@ -0,0 +1,92 @@ +import typing as t +import dlt + + +# Example sushi_types table +@dlt.resource(name="sushi_types", primary_key="id", write_disposition="merge") +def sushi_types() -> t.Iterator[t.Dict[str, t.Any]]: + yield from [ + {"id": 0, "name": "Tobiko"}, + {"id": 1, "name": "Sashimi"}, + {"id": 2, "name": "Maki"}, + {"id": 3, "name": "Temaki"}, + ] + + +# Example waiters table +@dlt.resource(name="waiters", primary_key="id", write_disposition="merge") +def waiters() -> t.Iterator[t.Dict[str, t.Any]]: + yield from [ + {"id": 0, "name": "Toby"}, + {"id": 1, "name": "Tyson"}, + {"id": 2, "name": "Ryan"}, + {"id": 3, "name": "George"}, + {"id": 4, "name": "Chris"}, + {"id": 5, "name": "Max"}, + {"id": 6, "name": "Vincent"}, + {"id": 7, "name": "Iaroslav"}, + {"id": 8, "name": "Emma"}, + {"id": 9, "name": "Maia"}, + ] + + +# Example sushi menu table with extra one and two levels of nesting tables +@dlt.resource(name="sushi_menu", primary_key="id", write_disposition="merge") +def sushi_menu() -> t.Iterator[t.Dict[str, t.Any]]: + yield from [ + { + "id": 0, + "name": "Tobiko", + "fillings": ["Red Tobiko", 
"Black Tobiko", "Wasabi Tobiko", "Green Tobiko"], + "details": { + "preparation": "Raw", + "ingredients": ["Seaweed", "Rice", "Tobiko"], + "price": 12.99, + "spicy": False, + }, + }, + { + "id": 1, + "name": "Sashimi", + "fillings": [ + "Tuna Sashimi", + "Salmon Sashimi", + "Yellowtail Sashimi", + "Octopus Sashimi", + "Scallop Sashimi", + ], + "details": { + "preparation": "Raw", + "ingredients": ["Fish", "Soy Sauce", "Wasabi"], + "price": 19.99, + "spicy": False, + }, + }, + { + "id": 2, + "name": "Maki", + "fillings": ["Cucumber", "Tuna", "Salmon", "Avocado", "Tempura Shrimp"], + "details": { + "preparation": "Rolled", + "ingredients": ["Seaweed", "Rice", "Fish", "Vegetables"], + "price": 14.99, + "spicy": True, + }, + }, + { + "id": 3, + "name": "Temaki", + "fillings": ["Tuna Temaki", "Salmon Temaki", "Vegetable Temaki", "Ebi Temaki"], + "details": { + "preparation": "Hand Roll", + "ingredients": ["Seaweed", "Rice", "Fish", "Vegetables"], + "price": 10.99, + "spicy": True, + }, + }, + ] + + +# Run the pipeline +p = dlt.pipeline(pipeline_name="sushi", destination="duckdb") +info = p.run([sushi_types(), waiters(), sushi_menu()]) diff --git a/examples/wursthall/config.yaml b/examples/wursthall/config.yaml index 299a2e51a9..d4a21061a4 100644 --- a/examples/wursthall/config.yaml +++ b/examples/wursthall/config.yaml @@ -3,6 +3,10 @@ gateways: duckdb: connection: type: duckdb + duckdb_persistent: + connection: + type: duckdb + database: duckdb.db model_defaults: dialect: 'duckdb' diff --git a/examples/wursthall/models/db/order_f.py b/examples/wursthall/models/db/order_f.py index c37ebda21c..d682d55f02 100644 --- a/examples/wursthall/models/db/order_f.py +++ b/examples/wursthall/models/db/order_f.py @@ -2,8 +2,8 @@ import typing as t from datetime import datetime -import numpy as np -import pandas as pd +import numpy as np # noqa: TID253 +import pandas as pd # noqa: TID253 from models.src.shared import DATA_START_DATE_STR, set_seed # type: ignore from sqlglot import 
parse_one @@ -41,8 +41,8 @@ def execute( end: datetime, **kwargs: t.Any, ) -> pd.DataFrame: - item_d_table_name = context.table("db.item_d") - order_item_f_table_name = context.table("db.order_item_f") + item_d_table_name = context.resolve_table("db.item_d") + order_item_f_table_name = context.resolve_table("db.order_item_f") # We use parse_one here instead of a raw string because this is a multi-dialect # project and we want to ensure that the resulting query is properly quoted in diff --git a/examples/wursthall/models/src/customer_details.py b/examples/wursthall/models/src/customer_details.py index 95914b1e53..44c56a60e5 100644 --- a/examples/wursthall/models/src/customer_details.py +++ b/examples/wursthall/models/src/customer_details.py @@ -3,7 +3,7 @@ from dataclasses import dataclass from datetime import datetime, timedelta -import pandas as pd +import pandas as pd # noqa: TID253 from faker import Faker from models.src.shared import DATA_START_DATE_STR, iter_dates, set_seed # type: ignore diff --git a/examples/wursthall/models/src/order_item_details.py b/examples/wursthall/models/src/order_item_details.py index 32ce003cb7..852250d2e2 100644 --- a/examples/wursthall/models/src/order_item_details.py +++ b/examples/wursthall/models/src/order_item_details.py @@ -3,8 +3,8 @@ from dataclasses import dataclass from datetime import datetime -import numpy as np -import pandas as pd +import numpy as np # noqa: TID253 +import pandas as pd # noqa: TID253 from faker import Faker from models.src.shared import DATA_START_DATE_STR, iter_dates, set_seed # type: ignore from sqlglot import parse_one diff --git a/examples/wursthall/models/src/shared.py b/examples/wursthall/models/src/shared.py index d220825ba4..d183a1c5db 100644 --- a/examples/wursthall/models/src/shared.py +++ b/examples/wursthall/models/src/shared.py @@ -4,7 +4,7 @@ import typing as t from datetime import date, timedelta -import numpy as np +import numpy as np # noqa: TID253 from faker import Faker SEED = 
99999999 diff --git a/mkdocs.yml b/mkdocs.yml index d51ed6793b..86761de9d7 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,6 +1,6 @@ site_name: SQLMesh -repo_url: https://github.com/TobikoData/sqlmesh -repo_name: TobikoData/sqlmesh +repo_url: https://github.com/SQLMesh/sqlmesh +repo_name: SQLMesh/sqlmesh nav: - "Overview": index.md - Get started: @@ -15,6 +15,7 @@ nav: - guides/projects.md - guides/multi_repo.md - guides/isolated_systems.md + - guides/multi_engine.md - Project setup: - guides/configuration.md - guides/connections.md @@ -29,9 +30,13 @@ nav: - guides/testing.md - guides/model_selection.md - SQLMesh tools: - - guides/ui.md + - guides/vscode.md - guides/tablediff.md - - guides/observer.md + - guides/linter.md + - guides/ui.md + - Advanced usage: + - guides/customizing_sqlmesh.md + - guides/signals.md - Concepts: - concepts/overview.md - Development: @@ -39,6 +44,7 @@ nav: - concepts/environments.md - concepts/tests.md - concepts/audits.md + - concepts/state.md - Models: - concepts/models/overview.md - concepts/models/model_kinds.md @@ -59,22 +65,32 @@ nav: - concepts/architecture/snapshots.md - concepts/architecture/serialization.md - concepts/glossary.md + - Examples: + - examples/overview.md + - Walkthroughs: + - examples/sqlmesh_cli_crash_course.md + - examples/incremental_time_full_walkthrough.md - Integrations: - "Overview": integrations/overview.md - Tools: - - integrations/airflow.md - integrations/dbt.md + - integrations/dlt.md - integrations/github.md - Execution engines: + - integrations/engines/athena.md + - integrations/engines/azuresql.md - integrations/engines/bigquery.md + - integrations/engines/clickhouse.md - integrations/engines/databricks.md - integrations/engines/duckdb.md + - integrations/engines/fabric.md - integrations/engines/motherduck.md - integrations/engines/mssql.md - integrations/engines/mysql.md - integrations/engines/postgres.md - integrations/engines/gcp-postgres.md - integrations/engines/redshift.md + - 
integrations/engines/risingwave.md - integrations/engines/snowflake.md - integrations/engines/spark.md - integrations/engines/trino.md @@ -91,38 +107,77 @@ nav: - Configuration: - reference/configuration.md - reference/model_configuration.md + - Tobiko Cloud: # NOTE: if this item is no longer last, need to update extra.css to adjust logo positioning + - "Overview": cloud/cloud_index.md + - "Getting Started": cloud/tcloud_getting_started.md + - Cloud Features: + - "Alerts & Notifications": cloud/features/alerts_notifications.md + - cloud/features/data_catalog.md + - cloud/features/debugger_view.md + - Maintenance: + - cloud/features/incident_reporting.md + - cloud/features/upgrades.md + - Scheduler: + - "Cloud": cloud/features/scheduler/scheduler.md + - "Cloud Hybrid Deployments": + - "Overview": cloud/features/scheduler/hybrid_executors_overview.md + - "Helm Chart example": cloud/features/scheduler/hybrid_executors_helm.md + - "Docker Compose example": cloud/features/scheduler/hybrid_executors_docker_compose.md + - cloud/features/scheduler/airflow.md + - cloud/features/scheduler/dagster.md + - Security: + - cloud/features/security/security.md + - cloud/features/security/single_sign_on.md + - Tools: + - cloud/features/xdb_diffing.md +# - Observability: +# - cloud/features/observability/overview.md +# - cloud/features/observability/model_freshness.md +# - cloud/features/observability/prod_environment.md +# - cloud/features/observability/development_environment.md +# - cloud/features/observability/plan.md +# - cloud/features/observability/run.md +# - cloud/features/observability/model.md +# - "Measures & Dashboards": cloud/features/observability/measures_dashboards.md theme: name: material logo: _readthedocs/html/sqlmesh.png + favicon: _readthedocs/html/favicon.svg palette: - media: "(prefers-color-scheme: light)" scheme: default - primary: #091c79 - accent: #cce8fe + primary: #3C64E2 + accent: #3C64E2 toggle: icon: material/weather-sunny name: Switch to dark mode # 
Palette toggle for dark mode - media: "(prefers-color-scheme: dark)" scheme: slate - primary: #cce8fe - accent: #091c79 + primary: #3C64E2 + accent: #3C64E2 toggle: icon: material/weather-night name: Switch to light mode features: - content.tabs.link + - content.code.copy + - navigation.expand ## expands navigation bar by default - navigation.tracking - navigation.tabs + - navigation.tabs.sticky - navigation.sections - navigation.top - - toc.integrate - toc.follow + - search.suggest + - search.highlight plugins: - include-markdown - - search + - search: + separator: '[\s\-,:!=\[\]()"/_]+' - glightbox markdown_extensions: + - def_list - tables - pymdownx.highlight: anchor_linenums: true @@ -133,6 +188,10 @@ markdown_extensions: alternate_style: true - admonition - pymdownx.details + - attr_list + - md_in_html + - pymdownx.caret + - sane_lists extra_css: - stylesheets/extra.css copyright: Tobiko Data Inc. @@ -143,7 +202,7 @@ extra: - icon: fontawesome/solid/paper-plane link: mailto:hello@tobikodata.com - icon: fontawesome/brands/github - link: https://github.com/TobikoData/sqlmesh/issues/new + link: https://github.com/SQLMesh/sqlmesh/issues/new analytics: provider: google property: G-JXQ1R227VS diff --git a/package.json b/package.json new file mode 100644 index 0000000000..6d49a853ef --- /dev/null +++ b/package.json @@ -0,0 +1,16 @@ +{ + "engines": { + "node": ">=20.0.0", + "pnpm": ">=10.0.0" + }, + "scripts": { + "ci": "pnpm run lint && pnpm run -r ci", + "fmt": "prettier --write .", + "fmt:check": "prettier --check .", + "lint": "pnpm run fmt:check && pnpm run -r lint", + "lint:fix": "pnpm run fmt && pnpm run -r lint:fix" + }, + "devDependencies": { + "prettier": "^3.6.2" + } +} diff --git a/pdoc/cli.py b/pdoc/cli.py index e603302bba..9301ae0444 100755 --- a/pdoc/cli.py +++ b/pdoc/cli.py @@ -29,7 +29,7 @@ def mocked_import(*args, **kwargs): opts.logo_link = "https://tobikodata.com" opts.footer_text = "Copyright Tobiko Data Inc. 
2022" opts.template_directory = Path(__file__).parent.joinpath("templates").absolute() - opts.edit_url = ["sqlmesh=https://github.com/TobikoData/sqlmesh/"] + opts.edit_url = ["sqlmesh=https://github.com/SQLMesh/sqlmesh/tree/main/sqlmesh/"] with mock.patch("pdoc.__main__.parser", **{"parse_args.return_value": opts}): cli() diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml new file mode 100644 index 0000000000..aeacb362d0 --- /dev/null +++ b/pnpm-lock.yaml @@ -0,0 +1,14475 @@ +lockfileVersion: '9.0' + +settings: + autoInstallPeers: true + excludeLinksFromLockfile: false + +importers: + + .: + devDependencies: + prettier: + specifier: ^3.6.2 + version: 3.6.2 + + vscode/bus: + devDependencies: + typescript: + specifier: ^5.8.3 + version: 5.8.3 + + vscode/extension: + dependencies: + '@duckdb/node-api': + specifier: 1.3.2-alpha.25 + version: 1.3.2-alpha.25 + '@types/fs-extra': + specifier: ^11.0.4 + version: 11.0.4 + '@types/shell-quote': + specifier: ^1.7.5 + version: 1.7.5 + '@vscode/python-extension': + specifier: ^1.0.5 + version: 1.0.5 + fs-extra: + specifier: ^11.3.0 + version: 11.3.0 + shell-quote: + specifier: ^1.8.3 + version: 1.8.3 + vscode-jsonrpc: + specifier: ^8.2.1 + version: 8.2.1 + vscode-languageclient: + specifier: ^9.0.1 + version: 9.0.1 + zod: + specifier: ^3.25.76 + version: 3.25.76 + devDependencies: + '@eslint/js': + specifier: ^9.31.0 + version: 9.31.0 + '@playwright/test': + specifier: ^1.54.1 + version: 1.54.1 + '@types/mocha': + specifier: ^10.0.10 + version: 10.0.10 + '@types/node': + specifier: 20.11.25 + version: 20.11.25 + '@types/vscode': + specifier: 1.96.0 + version: 1.96.0 + '@vitest/ui': + specifier: ^3.2.4 + version: 3.2.4(vitest@3.2.4) + '@vscode/test-cli': + specifier: ^0.0.10 + version: 0.0.10 + '@vscode/test-electron': + specifier: ^2.5.2 + version: 2.5.2 + '@vscode/vsce': + specifier: ^3.6.0 + version: 3.6.0 + esbuild: + specifier: ^0.25.8 + version: 0.25.8 + eslint: + specifier: ^9.31.0 + version: 9.31.0(jiti@2.4.2) + ts-loader: + 
specifier: ^9.5.2 + version: 9.5.2(typescript@5.8.3)(webpack@5.99.8(esbuild@0.25.8)) + typescript: + specifier: ^5.8.3 + version: 5.8.3 + typescript-eslint: + specifier: ^8.38.0 + version: 8.38.0(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3) + vitest: + specifier: ^3.2.4 + version: 3.2.4(@types/debug@4.1.12)(@types/node@20.11.25)(@vitest/browser@3.2.4)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + yaml: + specifier: ^2.8.0 + version: 2.8.0 + + vscode/react: + dependencies: + '@headlessui/react': + specifier: ^2.2.5 + version: 2.2.5(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@heroicons/react': + specifier: ^2.2.0 + version: 2.2.0(react@18.3.1) + '@radix-ui/react-select': + specifier: ^2.2.5 + version: 2.2.5(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@tailwindcss/postcss': + specifier: ^4.1.11 + version: 4.1.11 + '@tailwindcss/vite': + specifier: ^4.1.11 + version: 4.1.11(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + '@tanstack/react-query': + specifier: ^5.83.0 + version: 5.83.0(react@18.3.1) + '@tanstack/react-router': + specifier: ^1.129.8 + version: 1.129.8(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@tanstack/react-router-devtools': + specifier: ^1.131.26 + version: 1.131.26(@tanstack/react-router@1.129.8(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(@tanstack/router-core@1.129.8)(csstype@3.1.3)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(solid-js@1.9.7)(tiny-invariant@1.3.3) + '@tanstack/react-virtual': + specifier: ^3.13.12 + version: 3.13.12(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@tanstack/router-plugin': + specifier: ^1.129.8 + version: 
1.129.8(@tanstack/react-router@1.129.8(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))(webpack@5.99.8(esbuild@0.25.8)) + apache-arrow: + specifier: ^19.0.1 + version: 19.0.1 + clsx: + specifier: ^2.1.1 + version: 2.1.1 + elkjs: + specifier: ^0.8.2 + version: 0.8.2 + orval: + specifier: ^7.10.0 + version: 7.10.0(openapi-types@12.1.3) + react: + specifier: ^18.3.1 + version: 18.3.1 + react-dom: + specifier: ^18.3.1 + version: 18.3.1(react@18.3.1) + react-router: + specifier: ^7.7.0 + version: 7.7.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + reactflow: + specifier: ^11.11.4 + version: 11.11.4(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + tailwindcss: + specifier: ^4.1.11 + version: 4.1.11 + vscode-uri: + specifier: ^3.1.0 + version: 3.1.0 + devDependencies: + '@chromatic-com/storybook': + specifier: ^4.0.1 + version: 4.0.1(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2)) + '@storybook/addon-a11y': + specifier: ^9.0.18 + version: 9.0.18(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2)) + '@storybook/addon-docs': + specifier: ^9.0.18 + version: 9.0.18(@types/react@18.3.23)(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2)) + '@storybook/addon-onboarding': + specifier: ^9.0.18 + version: 9.0.18(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2)) + '@storybook/addon-vitest': + specifier: ^9.0.18 + version: 9.0.18(@vitest/browser@3.2.3)(@vitest/runner@3.2.4)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2))(vitest@3.2.4) + '@storybook/react-vite': + specifier: ^9.0.18 + version: 
9.0.18(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(rollup@4.45.1)(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2))(typescript@5.8.3)(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + '@testing-library/dom': + specifier: ^10.4.1 + version: 10.4.1 + '@testing-library/react': + specifier: ^16.3.0 + version: 16.3.0(@testing-library/dom@10.4.1)(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@types/react': + specifier: ^18.3.23 + version: 18.3.23 + '@types/react-dom': + specifier: ^18.3.7 + version: 18.3.7(@types/react@18.3.23) + '@vitejs/plugin-react': + specifier: ^4.7.0 + version: 4.7.0(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + '@vitest/browser': + specifier: 3.2.3 + version: 3.2.3(playwright@1.54.1)(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))(vitest@3.2.4) + '@vitest/coverage-v8': + specifier: 3.2.3 + version: 3.2.3(@vitest/browser@3.2.3)(vitest@3.2.4) + jsdom: + specifier: ^26.1.0 + version: 26.1.0 + playwright: + specifier: ^1.54.1 + version: 1.54.1 + storybook: + specifier: ^9.0.18 + version: 9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2) + typescript: + specifier: ^5.8.3 + version: 5.8.3 + vite: + specifier: ^6.3.5 + version: 6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + vitest: + specifier: ^3.2.4 + version: 3.2.4(@types/debug@4.1.12)(@types/node@24.1.0)(@vitest/browser@3.2.3)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + web-vitals: + specifier: ^4.2.4 + version: 4.2.4 + + web/client: + dependencies: + '@codemirror/autocomplete': + specifier: ^6.18.6 + version: 6.18.6 + '@codemirror/commands': + specifier: ^6.8.1 + version: 6.8.1 + '@codemirror/lang-python': + specifier: ^6.2.1 + 
version: 6.2.1 + '@codemirror/lang-sql': + specifier: ^6.9.0 + version: 6.9.0 + '@codemirror/language': + specifier: ^6.11.2 + version: 6.11.2 + '@codemirror/legacy-modes': + specifier: ^6.5.1 + version: 6.5.1 + '@codemirror/state': + specifier: ^6.5.2 + version: 6.5.2 + '@codemirror/view': + specifier: ^6.38.1 + version: 6.38.1 + '@headlessui/react': + specifier: ^2.2.5 + version: 2.2.5(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@heroicons/react': + specifier: ^2.2.0 + version: 2.2.0(react@18.3.1) + '@lit/react': + specifier: ^1.0.8 + version: 1.0.8(@types/react@18.3.23) + '@radix-ui/react-context-menu': + specifier: ^2.2.15 + version: 2.2.15(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-select': + specifier: ^2.2.5 + version: 2.2.5(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@tailwindcss/container-queries': + specifier: ^0.1.1 + version: 0.1.1(tailwindcss@3.4.17) + '@tanstack/react-query': + specifier: ^5.83.0 + version: 5.83.0(react@18.3.1) + '@tanstack/react-table': + specifier: ^8.21.3 + version: 8.21.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@tanstack/react-virtual': + specifier: ^3.13.12 + version: 3.13.12(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@uidotdev/usehooks': + specifier: ^2.4.1 + version: 2.4.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@uiw/react-codemirror': + specifier: ^4.24.1 + version: 4.24.1(@babel/runtime@7.28.2)(@codemirror/autocomplete@6.18.6)(@codemirror/language@6.11.2)(@codemirror/lint@6.8.5)(@codemirror/search@6.5.10)(@codemirror/state@6.5.2)(@codemirror/theme-one-dark@6.1.2)(@codemirror/view@6.38.1)(codemirror@6.0.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + apache-arrow: + specifier: ^19.0.1 + version: 19.0.1 + clsx: + specifier: ^2.1.1 + version: 2.1.1 + diff: + specifier: ^8.0.2 + version: 8.0.2 + elkjs: + specifier: ^0.8.2 + version: 
0.8.2 + pluralize: + specifier: ^8.0.0 + version: 8.0.0 + react: + specifier: ^18.3.1 + version: 18.3.1 + react-dnd: + specifier: ^16.0.1 + version: 16.0.1(@types/node@24.1.0)(@types/react@18.3.23)(react@18.3.1) + react-dnd-html5-backend: + specifier: ^16.0.1 + version: 16.0.1 + react-dom: + specifier: ^18.3.1 + version: 18.3.1(react@18.3.1) + react-markdown: + specifier: ^10.1.0 + version: 10.1.0(@types/react@18.3.23)(react@18.3.1) + react-router: + specifier: ^7.7.0 + version: 7.7.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react-split: + specifier: ^2.0.14 + version: 2.0.14(react@18.3.1) + reactflow: + specifier: ^11.11.4 + version: 11.11.4(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + thememirror: + specifier: ^2.0.1 + version: 2.0.1(@codemirror/language@6.11.2)(@codemirror/state@6.5.2)(@codemirror/view@6.38.1) + zustand: + specifier: ^5.0.6 + version: 5.0.6(@types/react@18.3.23)(immer@9.0.21)(react@18.3.1)(use-sync-external-store@1.5.0(react@18.3.1)) + devDependencies: + '@eslint/js': + specifier: ^9.31.0 + version: 9.31.0 + '@playwright/test': + specifier: ^1.54.1 + version: 1.54.1 + '@swc/core': + specifier: ^1.13.2 + version: 1.13.2(@swc/helpers@0.5.17) + '@testing-library/jest-dom': + specifier: ^6.6.3 + version: 6.6.3 + '@testing-library/react': + specifier: ^16.3.0 + version: 16.3.0(@testing-library/dom@10.4.1)(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@testing-library/user-event': + specifier: ^14.6.1 + version: 14.6.1(@testing-library/dom@10.4.1) + '@types/pluralize': + specifier: ^0.0.33 + version: 0.0.33 + '@types/react': + specifier: ^18.3.23 + version: 18.3.23 + '@types/react-dom': + specifier: ^18.3.7 + version: 18.3.7(@types/react@18.3.23) + '@vitejs/plugin-react-swc': + specifier: ^3.11.0 + version: 
3.11.0(@swc/helpers@0.5.17)(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + ajv: + specifier: ^8.17.1 + version: 8.17.1 + autoprefixer: + specifier: ^10.4.21 + version: 10.4.21(postcss@8.5.6) + eslint: + specifier: ^9.31.0 + version: 9.31.0(jiti@2.4.2) + jsdom: + specifier: ^26.1.0 + version: 26.1.0 + orval: + specifier: ^7.10.0 + version: 7.10.0(openapi-types@12.1.3) + postcss: + specifier: ^8.5.6 + version: 8.5.6 + tailwindcss: + specifier: ^3.4.17 + version: 3.4.17 + typescript: + specifier: ^5.8.3 + version: 5.8.3 + typescript-eslint: + specifier: ^8.38.0 + version: 8.38.0(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3) + vite: + specifier: ^6.3.5 + version: 6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + vite-plugin-css-injected-by-js: + specifier: ^3.5.2 + version: 3.5.2(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + vitest: + specifier: ^3.2.4 + version: 3.2.4(@types/debug@4.1.12)(@types/node@24.1.0)(@vitest/browser@3.2.4)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + optionalDependencies: + '@swc/core-linux-x64-gnu': + specifier: ^1.13.2 + version: 1.13.2 + + web/common: + devDependencies: + '@eslint/js': + specifier: 9.31.0 + version: 9.31.0 + '@radix-ui/react-slot': + specifier: 1.2.3 + version: 1.2.3(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-tooltip': + specifier: 1.2.8 + version: 1.2.8(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@storybook/addon-docs': + specifier: 9.1.5 + version: 9.1.5(@types/react@18.3.23)(storybook@9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))) + '@storybook/react-vite': + specifier: 9.1.5 + version: 
9.1.5(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(rollup@4.45.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)))(typescript@5.8.3)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + '@tailwindcss/typography': + specifier: 0.5.16 + version: 0.5.16(tailwindcss@3.4.17) + '@tanstack/react-virtual': + specifier: 3.13.12 + version: 3.13.12(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@testing-library/dom': + specifier: 10.4.1 + version: 10.4.1 + '@testing-library/jest-dom': + specifier: 6.6.3 + version: 6.6.3 + '@testing-library/react': + specifier: 16.3.0 + version: 16.3.0(@testing-library/dom@10.4.1)(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@testing-library/user-event': + specifier: 14.6.1 + version: 14.6.1(@testing-library/dom@10.4.1) + '@types/dagre': + specifier: 0.7.53 + version: 0.7.53 + '@types/node': + specifier: 20.11.25 + version: 20.11.25 + '@types/react': + specifier: 18.3.23 + version: 18.3.23 + '@types/react-dom': + specifier: 18.3.7 + version: 18.3.7(@types/react@18.3.23) + '@vitejs/plugin-react': + specifier: 4.7.0 + version: 4.7.0(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + '@vitest/browser': + specifier: 3.2.4 + version: 3.2.4(playwright@1.54.1)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))(vitest@3.2.4) + '@xyflow/react': + specifier: 12.8.4 + version: 12.8.4(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + autoprefixer: + specifier: 10.4.21 + version: 10.4.21(postcss@8.5.6) + browserslist: + specifier: 4.26.2 + version: 4.26.2 + caniuse-lite: + specifier: 1.0.30001746 + version: 1.0.30001746 + class-variance-authority: + specifier: 0.7.1 + 
version: 0.7.1 + clsx: + specifier: 2.1.1 + version: 2.1.1 + cronstrue: + specifier: 3.3.0 + version: 3.3.0 + dagre: + specifier: 0.8.5 + version: 0.8.5 + deepmerge: + specifier: 4.3.1 + version: 4.3.1 + eslint: + specifier: 9.31.0 + version: 9.31.0(jiti@2.4.2) + eslint-plugin-react-hooks: + specifier: 5.2.0 + version: 5.2.0(eslint@9.31.0(jiti@2.4.2)) + eslint-plugin-storybook: + specifier: 9.1.5 + version: 9.1.5(eslint@9.31.0(jiti@2.4.2))(storybook@9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)))(typescript@5.8.3) + fuse.js: + specifier: 7.1.0 + version: 7.1.0 + globals: + specifier: 16.3.0 + version: 16.3.0 + lucide-react: + specifier: 0.542.0 + version: 0.542.0(react@18.3.1) + playwright: + specifier: 1.54.1 + version: 1.54.1 + postcss: + specifier: 8.5.6 + version: 8.5.6 + react: + specifier: 18.3.1 + version: 18.3.1 + react-dom: + specifier: 18.3.1 + version: 18.3.1(react@18.3.1) + storybook: + specifier: 9.1.5 + version: 9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + syncpack: + specifier: 13.0.4 + version: 13.0.4(typescript@5.8.3) + tailwind-merge: + specifier: 3.3.1 + version: 3.3.1 + tailwind-scrollbar: + specifier: 3.1.0 + version: 3.1.0(tailwindcss@3.4.17) + tailwindcss: + specifier: 3.4.17 + version: 3.4.17 + typescript: + specifier: 5.8.3 + version: 5.8.3 + typescript-eslint: + specifier: 8.38.0 + version: 8.38.0(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3) + vite: + specifier: 6.3.5 + version: 6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + vite-plugin-dts: + specifier: 4.5.4 + version: 4.5.4(@types/node@20.11.25)(rollup@4.45.1)(typescript@5.8.3)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + vite-plugin-static-copy: + 
specifier: 3.1.1 + version: 3.1.1(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + vitest: + specifier: 3.2.4 + version: 3.2.4(@types/debug@4.1.12)(@types/node@20.11.25)(@vitest/browser@3.2.4)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + +packages: + + '@adobe/css-tools@4.4.3': + resolution: {integrity: sha512-VQKMkwriZbaOgVCby1UDY/LDk5fIjhQicCvVPFqfe+69fWaPWydbWJ3wRt59/YzIwda1I81loas3oCoHxnqvdA==} + + '@alloc/quick-lru@5.2.0': + resolution: {integrity: sha512-UrcABB+4bUrFABwbluTIBErXwvbsU/V7TZWfmbgJfbkwiBuziS9gxdODUyuiecfdGQ85jglMW6juS3+z5TsKLw==} + engines: {node: '>=10'} + + '@ampproject/remapping@2.3.0': + resolution: {integrity: sha512-30iZtAPgz+LTIYoeivqYo853f02jBYSd5uGnGpkFV0M3xOt9aN73erkgYAmZU43x4VfqcnLxW9Kpg3R5LC4YYw==} + engines: {node: '>=6.0.0'} + + '@apidevtools/json-schema-ref-parser@11.7.2': + resolution: {integrity: sha512-4gY54eEGEstClvEkGnwVkTkrx0sqwemEFG5OSRRn3tD91XH0+Q8XIkYIfo7IwEWPpJZwILb9GUXeShtplRc/eA==} + engines: {node: '>= 16'} + + '@apidevtools/openapi-schemas@2.1.0': + resolution: {integrity: sha512-Zc1AlqrJlX3SlpupFGpiLi2EbteyP7fXmUOGup6/DnkRgjP9bgMM/ag+n91rsv0U1Gpz0H3VILA/o3bW7Ua6BQ==} + engines: {node: '>=10'} + + '@apidevtools/swagger-methods@3.0.2': + resolution: {integrity: sha512-QAkD5kK2b1WfjDS/UQn/qQkbwF31uqRjPTrsCs5ZG9BQGAkjwvqGFjjPqAuzac/IYzpPtRzjCP1WrTuAIjMrXg==} + + '@apidevtools/swagger-parser@10.1.1': + resolution: {integrity: sha512-u/kozRnsPO/x8QtKYJOqoGtC4kH6yg1lfYkB9Au0WhYB0FNLpyFusttQtvhlwjtG3rOwiRz4D8DnnXa8iEpIKA==} + peerDependencies: + openapi-types: '>=7' + + '@asamuzakjp/css-color@3.2.0': + resolution: {integrity: sha512-K1A6z8tS3XsmCMM86xoWdn7Fkdn9m6RSVtocUrJYIwZnFVkng/PvkEoWtOWmP+Scc6saYWHWZYbndEEXxl24jw==} + + '@asyncapi/specs@6.8.1': + resolution: {integrity: sha512-czHoAk3PeXTLR+X8IUaD+IpT+g+zUvkcgMDJVothBsan+oHN3jfcFcFUNdOPAAFoUCQN1hXF1dWuphWy05THlA==} + + '@azu/format-text@1.0.2': + 
resolution: {integrity: sha512-Swi4N7Edy1Eqq82GxgEECXSSLyn6GOb5htRFPzBDdUkECGXtlf12ynO5oJSpWKPwCaUssOu7NfhDcCWpIC6Ywg==} + + '@azu/style-format@1.0.1': + resolution: {integrity: sha512-AHcTojlNBdD/3/KxIKlg8sxIWHfOtQszLvOpagLTO+bjC3u7SAszu1lf//u7JJC50aUSH+BVWDD/KvaA6Gfn5g==} + + '@azure/abort-controller@2.1.2': + resolution: {integrity: sha512-nBrLsEWm4J2u5LpAPjxADTlq3trDgVZZXHNKabeXZtpq3d3AbN/KGO82R87rdDz5/lYB024rtEf10/q0urNgsA==} + engines: {node: '>=18.0.0'} + + '@azure/core-auth@1.10.0': + resolution: {integrity: sha512-88Djs5vBvGbHQHf5ZZcaoNHo6Y8BKZkt3cw2iuJIQzLEgH4Ox6Tm4hjFhbqOxyYsgIG/eJbFEHpxRIfEEWv5Ow==} + engines: {node: '>=20.0.0'} + + '@azure/core-client@1.10.0': + resolution: {integrity: sha512-O4aP3CLFNodg8eTHXECaH3B3CjicfzkxVtnrfLkOq0XNP7TIECGfHpK/C6vADZkWP75wzmdBnsIA8ksuJMk18g==} + engines: {node: '>=20.0.0'} + + '@azure/core-rest-pipeline@1.22.0': + resolution: {integrity: sha512-OKHmb3/Kpm06HypvB3g6Q3zJuvyXcpxDpCS1PnU8OV6AJgSFaee/covXBcPbWc6XDDxtEPlbi3EMQ6nUiPaQtw==} + engines: {node: '>=20.0.0'} + + '@azure/core-tracing@1.3.0': + resolution: {integrity: sha512-+XvmZLLWPe67WXNZo9Oc9CrPj/Tm8QnHR92fFAFdnbzwNdCH1h+7UdpaQgRSBsMY+oW1kHXNUZQLdZ1gHX3ROw==} + engines: {node: '>=20.0.0'} + + '@azure/core-util@1.13.0': + resolution: {integrity: sha512-o0psW8QWQ58fq3i24Q1K2XfS/jYTxr7O1HRcyUE9bV9NttLU+kYOH82Ixj8DGlMTOWgxm1Sss2QAfKK5UkSPxw==} + engines: {node: '>=20.0.0'} + + '@azure/identity@4.10.2': + resolution: {integrity: sha512-Uth4vz0j+fkXCkbvutChUj03PDCokjbC6Wk9JT8hHEUtpy/EurNKAseb3+gO6Zi9VYBvwt61pgbzn1ovk942Qg==} + engines: {node: '>=20.0.0'} + + '@azure/logger@1.3.0': + resolution: {integrity: sha512-fCqPIfOcLE+CGqGPd66c8bZpwAji98tZ4JI9i/mlTNTlsIWslCfpg48s/ypyLxZTump5sypjrKn2/kY7q8oAbA==} + engines: {node: '>=20.0.0'} + + '@azure/msal-browser@4.15.0': + resolution: {integrity: sha512-+AIGTvpVz+FIx5CsM1y+nW0r/qOb/ChRdM8/Cbp+jKWC0Wdw4ldnwPdYOBi5NaALUQnYITirD9XMZX7LdklEzQ==} + engines: {node: '>=0.8.0'} + + '@azure/msal-common@15.8.1': + resolution: 
{integrity: sha512-ltIlFK5VxeJ5BurE25OsJIfcx1Q3H/IZg2LjV9d4vmH+5t4c1UCyRQ/HgKLgXuCZShs7qfc/TC95GYZfsUsJUQ==} + engines: {node: '>=0.8.0'} + + '@azure/msal-node@3.6.3': + resolution: {integrity: sha512-95wjsKGyUcAd5tFmQBo5Ug/kOj+hFh/8FsXuxluEvdfbgg6xCimhSP9qnyq6+xIg78/jREkBD1/BSqd7NIDDYQ==} + engines: {node: '>=16'} + + '@babel/code-frame@7.27.1': + resolution: {integrity: sha512-cjQ7ZlQ0Mv3b47hABuTevyTuYN4i+loJKGeV9flcCgIK37cCXRh+L1bd3iBHlynerhQ7BhCkn2BPbQUL+rGqFg==} + engines: {node: '>=6.9.0'} + + '@babel/compat-data@7.28.0': + resolution: {integrity: sha512-60X7qkglvrap8mn1lh2ebxXdZYtUcpd7gsmy9kLaBJ4i/WdY8PqTSdxyA8qraikqKQK5C1KRBKXqznrVapyNaw==} + engines: {node: '>=6.9.0'} + + '@babel/core@7.28.0': + resolution: {integrity: sha512-UlLAnTPrFdNGoFtbSXwcGFQBtQZJCNjaN6hQNP3UPvuNXT1i82N26KL3dZeIpNalWywr9IuQuncaAfUaS1g6sQ==} + engines: {node: '>=6.9.0'} + + '@babel/generator@7.28.0': + resolution: {integrity: sha512-lJjzvrbEeWrhB4P3QBsH7tey117PjLZnDbLiQEKjQ/fNJTjuq4HSqgFA+UNSwZT8D7dxxbnuSBMsa1lrWzKlQg==} + engines: {node: '>=6.9.0'} + + '@babel/helper-annotate-as-pure@7.27.3': + resolution: {integrity: sha512-fXSwMQqitTGeHLBC08Eq5yXz2m37E4pJX1qAU1+2cNedz/ifv/bVXft90VeSav5nFO61EcNgwr0aJxbyPaWBPg==} + engines: {node: '>=6.9.0'} + + '@babel/helper-compilation-targets@7.27.2': + resolution: {integrity: sha512-2+1thGUUWWjLTYTHZWK1n8Yga0ijBz1XAhUXcKy81rd5g6yh7hGqMp45v7cadSbEHc9G3OTv45SyneRN3ps4DQ==} + engines: {node: '>=6.9.0'} + + '@babel/helper-create-class-features-plugin@7.27.1': + resolution: {integrity: sha512-QwGAmuvM17btKU5VqXfb+Giw4JcN0hjuufz3DYnpeVDvZLAObloM77bhMXiqry3Iio+Ai4phVRDwl6WU10+r5A==} + engines: {node: '>=6.9.0'} + peerDependencies: + '@babel/core': ^7.0.0 + + '@babel/helper-globals@7.28.0': + resolution: {integrity: sha512-+W6cISkXFa1jXsDEdYA8HeevQT/FULhxzR99pxphltZcVaugps53THCeiWA8SguxxpSp3gKPiuYfSWopkLQ4hw==} + engines: {node: '>=6.9.0'} + + '@babel/helper-member-expression-to-functions@7.27.1': + resolution: {integrity: 
sha512-E5chM8eWjTp/aNoVpcbfM7mLxu9XGLWYise2eBKGQomAk/Mb4XoxyqXTZbuTohbsl8EKqdlMhnDI2CCLfcs9wA==} + engines: {node: '>=6.9.0'} + + '@babel/helper-module-imports@7.27.1': + resolution: {integrity: sha512-0gSFWUPNXNopqtIPQvlD5WgXYI5GY2kP2cCvoT8kczjbfcfuIljTbcWrulD1CIPIX2gt1wghbDy08yE1p+/r3w==} + engines: {node: '>=6.9.0'} + + '@babel/helper-module-transforms@7.27.3': + resolution: {integrity: sha512-dSOvYwvyLsWBeIRyOeHXp5vPj5l1I011r52FM1+r1jCERv+aFXYk4whgQccYEGYxK2H3ZAIA8nuPkQ0HaUo3qg==} + engines: {node: '>=6.9.0'} + peerDependencies: + '@babel/core': ^7.0.0 + + '@babel/helper-optimise-call-expression@7.27.1': + resolution: {integrity: sha512-URMGH08NzYFhubNSGJrpUEphGKQwMQYBySzat5cAByY1/YgIRkULnIy3tAMeszlL/so2HbeilYloUmSpd7GdVw==} + engines: {node: '>=6.9.0'} + + '@babel/helper-plugin-utils@7.27.1': + resolution: {integrity: sha512-1gn1Up5YXka3YYAHGKpbideQ5Yjf1tDa9qYcgysz+cNCXukyLl6DjPXhD3VRwSb8c0J9tA4b2+rHEZtc6R0tlw==} + engines: {node: '>=6.9.0'} + + '@babel/helper-replace-supers@7.27.1': + resolution: {integrity: sha512-7EHz6qDZc8RYS5ElPoShMheWvEgERonFCs7IAonWLLUTXW59DP14bCZt89/GKyreYn8g3S83m21FelHKbeDCKA==} + engines: {node: '>=6.9.0'} + peerDependencies: + '@babel/core': ^7.0.0 + + '@babel/helper-skip-transparent-expression-wrappers@7.27.1': + resolution: {integrity: sha512-Tub4ZKEXqbPjXgWLl2+3JpQAYBJ8+ikpQ2Ocj/q/r0LwE3UhENh7EUabyHjz2kCEsrRY83ew2DQdHluuiDQFzg==} + engines: {node: '>=6.9.0'} + + '@babel/helper-string-parser@7.27.1': + resolution: {integrity: sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==} + engines: {node: '>=6.9.0'} + + '@babel/helper-validator-identifier@7.27.1': + resolution: {integrity: sha512-D2hP9eA+Sqx1kBZgzxZh0y1trbuU+JoDkiEwqhQ36nodYqJwyEIhPSdMNd7lOm/4io72luTPWH20Yda0xOuUow==} + engines: {node: '>=6.9.0'} + + '@babel/helper-validator-option@7.27.1': + resolution: {integrity: sha512-YvjJow9FxbhFFKDSuFnVCe2WxXk1zWc22fFePVNEaWJEu8IrZVlda6N0uHwzZrUM1il7NC9Mlp4MaJYbYd9JSg==} + engines: {node: 
'>=6.9.0'} + + '@babel/helpers@7.27.6': + resolution: {integrity: sha512-muE8Tt8M22638HU31A3CgfSUciwz1fhATfoVai05aPXGor//CdWDCbnlY1yvBPo07njuVOCNGCSp/GTt12lIug==} + engines: {node: '>=6.9.0'} + + '@babel/parser@7.28.0': + resolution: {integrity: sha512-jVZGvOxOuNSsuQuLRTh13nU0AogFlw32w/MT+LV6D3sP5WdbW61E77RnkbaO2dUvmPAYrBDJXGn5gGS6tH4j8g==} + engines: {node: '>=6.0.0'} + hasBin: true + + '@babel/plugin-syntax-jsx@7.27.1': + resolution: {integrity: sha512-y8YTNIeKoyhGd9O0Jiyzyyqk8gdjnumGTQPsz0xOZOQ2RmkVJeZ1vmmfIvFEKqucBG6axJGBZDE/7iI5suUI/w==} + engines: {node: '>=6.9.0'} + peerDependencies: + '@babel/core': ^7.0.0-0 + + '@babel/plugin-syntax-typescript@7.27.1': + resolution: {integrity: sha512-xfYCBMxveHrRMnAWl1ZlPXOZjzkN82THFvLhQhFXFt81Z5HnN+EtUkZhv/zcKpmT3fzmWZB0ywiBrbC3vogbwQ==} + engines: {node: '>=6.9.0'} + peerDependencies: + '@babel/core': ^7.0.0-0 + + '@babel/plugin-transform-modules-commonjs@7.27.1': + resolution: {integrity: sha512-OJguuwlTYlN0gBZFRPqwOGNWssZjfIUdS7HMYtN8c1KmwpwHFBwTeFZrg9XZa+DFTitWOW5iTAG7tyCUPsCCyw==} + engines: {node: '>=6.9.0'} + peerDependencies: + '@babel/core': ^7.0.0-0 + + '@babel/plugin-transform-react-jsx-self@7.27.1': + resolution: {integrity: sha512-6UzkCs+ejGdZ5mFFC/OCUrv028ab2fp1znZmCZjAOBKiBK2jXD1O+BPSfX8X2qjJ75fZBMSnQn3Rq2mrBJK2mw==} + engines: {node: '>=6.9.0'} + peerDependencies: + '@babel/core': ^7.0.0-0 + + '@babel/plugin-transform-react-jsx-source@7.27.1': + resolution: {integrity: sha512-zbwoTsBruTeKB9hSq73ha66iFeJHuaFkUbwvqElnygoNbj/jHRsSeokowZFN3CZ64IvEqcmmkVe89OPXc7ldAw==} + engines: {node: '>=6.9.0'} + peerDependencies: + '@babel/core': ^7.0.0-0 + + '@babel/plugin-transform-typescript@7.28.0': + resolution: {integrity: sha512-4AEiDEBPIZvLQaWlc9liCavE0xRM0dNca41WtBeM3jgFptfUOSG9z0uteLhq6+3rq+WB6jIvUwKDTpXEHPJ2Vg==} + engines: {node: '>=6.9.0'} + peerDependencies: + '@babel/core': ^7.0.0-0 + + '@babel/preset-typescript@7.27.1': + resolution: {integrity: 
sha512-l7WfQfX0WK4M0v2RudjuQK4u99BS6yLHYEmdtVPP7lKV013zr9DygFuWNlnbvQ9LR+LS0Egz/XAvGx5U9MX0fQ==} + engines: {node: '>=6.9.0'} + peerDependencies: + '@babel/core': ^7.0.0-0 + + '@babel/runtime@7.28.2': + resolution: {integrity: sha512-KHp2IflsnGywDjBWDkR9iEqiWSpc8GIi0lgTT3mOElT0PP1tG26P4tmFI2YvAdzgq9RGyoHZQEIEdZy6Ec5xCA==} + engines: {node: '>=6.9.0'} + + '@babel/template@7.27.2': + resolution: {integrity: sha512-LPDZ85aEJyYSd18/DkjNh4/y1ntkE5KwUHWTiqgRxruuZL2F1yuHligVHLvcHY2vMHXttKFpJn6LwfI7cw7ODw==} + engines: {node: '>=6.9.0'} + + '@babel/traverse@7.28.0': + resolution: {integrity: sha512-mGe7UK5wWyh0bKRfupsUchrQGqvDbZDbKJw+kcRGSmdHVYrv+ltd0pnpDTVpiTqnaBru9iEvA8pz8W46v0Amwg==} + engines: {node: '>=6.9.0'} + + '@babel/types@7.28.1': + resolution: {integrity: sha512-x0LvFTekgSX+83TI28Y9wYPUfzrnl2aT5+5QLnO6v7mSJYtEEevuDRN0F0uSHRk1G1IWZC43o00Y0xDDrpBGPQ==} + engines: {node: '>=6.9.0'} + + '@bcoe/v8-coverage@0.2.3': + resolution: {integrity: sha512-0hYQ8SB4Db5zvZB4axdMHGwEaQjkZzFjQiN9LVYvIFB2nSUHW9tYpxWriPrWDASIxiaXax83REcLxuSdnGPZtw==} + + '@bcoe/v8-coverage@1.0.2': + resolution: {integrity: sha512-6zABk/ECA/QYSCQ1NGiVwwbQerUCZ+TQbp64Q3AgmfNvurHH0j8TtXa1qbShXA6qqkpAj4V5W8pP6mLe1mcMqA==} + engines: {node: '>=18'} + + '@chromatic-com/storybook@4.0.1': + resolution: {integrity: sha512-GQXe5lyZl3yLewLJQyFXEpOp2h+mfN2bPrzYaOFNCJjO4Js9deKbRHTOSaiP2FRwZqDLdQwy2+SEGeXPZ94yYw==} + engines: {node: '>=20.0.0', yarn: '>=1.22.18'} + peerDependencies: + storybook: ^0.0.0-0 || ^9.0.0 || ^9.1.0-0 + + '@codemirror/autocomplete@6.18.6': + resolution: {integrity: sha512-PHHBXFomUs5DF+9tCOM/UoW6XQ4R44lLNNhRaW9PKPTU0D7lIjRg3ElxaJnTwsl/oHiR93WSXDBrekhoUGCPtg==} + + '@codemirror/autocomplete@6.19.0': + resolution: {integrity: sha512-61Hfv3cF07XvUxNeC3E7jhG8XNi1Yom1G0lRC936oLnlF+jrbrv8rc/J98XlYzcsAoTVupfsf5fLej1aI8kyIg==} + + '@codemirror/commands@6.8.1': + resolution: {integrity: sha512-KlGVYufHMQzxbdQONiLyGQDUW0itrLZwq3CcY7xpv9ZLRHqzkBSoteocBHtMCoY7/Ci4xhzSrToIeLg7FxHuaw==} + + 
'@codemirror/lang-python@6.2.1': + resolution: {integrity: sha512-IRjC8RUBhn9mGR9ywecNhB51yePWCGgvHfY1lWN/Mrp3cKuHr0isDKia+9HnvhiWNnMpbGhWrkhuWOc09exRyw==} + + '@codemirror/lang-sql@6.9.0': + resolution: {integrity: sha512-xmtpWqKSgum1B1J3Ro6rf7nuPqf2+kJQg5SjrofCAcyCThOe0ihSktSoXfXuhQBnwx1QbmreBbLJM5Jru6zitg==} + + '@codemirror/language@6.11.2': + resolution: {integrity: sha512-p44TsNArL4IVXDTbapUmEkAlvWs2CFQbcfc0ymDsis1kH2wh0gcY96AS29c/vp2d0y2Tquk1EDSaawpzilUiAw==} + + '@codemirror/language@6.11.3': + resolution: {integrity: sha512-9HBM2XnwDj7fnu0551HkGdrUrrqmYq/WC5iv6nbY2WdicXdGbhR/gfbZOH73Aqj4351alY1+aoG9rCNfiwS1RA==} + + '@codemirror/legacy-modes@6.5.1': + resolution: {integrity: sha512-DJYQQ00N1/KdESpZV7jg9hafof/iBNp9h7TYo1SLMk86TWl9uDsVdho2dzd81K+v4retmK6mdC7WpuOQDytQqw==} + + '@codemirror/lint@6.8.5': + resolution: {integrity: sha512-s3n3KisH7dx3vsoeGMxsbRAgKe4O1vbrnKBClm99PU0fWxmxsx5rR2PfqQgIt+2MMJBHbiJ5rfIdLYfB9NNvsA==} + + '@codemirror/search@6.5.10': + resolution: {integrity: sha512-RMdPdmsrUf53pb2VwflKGHEe1XVM07hI7vV2ntgw1dmqhimpatSJKva4VA9h4TLUDOD4EIF02201oZurpnEFsg==} + + '@codemirror/state@6.5.2': + resolution: {integrity: sha512-FVqsPqtPWKVVL3dPSxy8wEF/ymIEuVzF1PK3VbUgrxXpJUSHQWWZz4JMToquRxnkw+36LTamCZG2iua2Ptq0fA==} + + '@codemirror/theme-one-dark@6.1.2': + resolution: {integrity: sha512-F+sH0X16j/qFLMAfbciKTxVOwkdAS336b7AXTKOZhy8BR3eH/RelsnLgLFINrpST63mmN2OuwUt0W2ndUgYwUA==} + + '@codemirror/view@6.38.1': + resolution: {integrity: sha512-RmTOkE7hRU3OVREqFVITWHz6ocgBjv08GoePscAakgVQfciA3SGCEk7mb9IzwW61cKKmlTpHXG6DUE5Ubx+MGQ==} + + '@codemirror/view@6.38.4': + resolution: {integrity: sha512-hduz0suCcUSC/kM8Fq3A9iLwInJDl8fD1xLpTIk+5xkNm8z/FT7UsIa9sOXrkpChh+XXc18RzswE8QqELsVl+g==} + + '@csstools/color-helpers@5.0.2': + resolution: {integrity: sha512-JqWH1vsgdGcw2RR6VliXXdA0/59LttzlU8UlRT/iUUsEeWfYq8I+K0yhihEUTTHLRm1EXvpsCx3083EU15ecsA==} + engines: {node: '>=18'} + + '@csstools/css-calc@2.1.4': + resolution: {integrity: 
sha512-3N8oaj+0juUw/1H3YwmDDJXCgTB1gKU6Hc/bB502u9zR0q2vd786XJH9QfrKIEgFlZmhZiq6epXl4rHqhzsIgQ==} + engines: {node: '>=18'} + peerDependencies: + '@csstools/css-parser-algorithms': ^3.0.5 + '@csstools/css-tokenizer': ^3.0.4 + + '@csstools/css-color-parser@3.0.10': + resolution: {integrity: sha512-TiJ5Ajr6WRd1r8HSiwJvZBiJOqtH86aHpUjq5aEKWHiII2Qfjqd/HCWKPOW8EP4vcspXbHnXrwIDlu5savQipg==} + engines: {node: '>=18'} + peerDependencies: + '@csstools/css-parser-algorithms': ^3.0.5 + '@csstools/css-tokenizer': ^3.0.4 + + '@csstools/css-parser-algorithms@3.0.5': + resolution: {integrity: sha512-DaDeUkXZKjdGhgYaHNJTV9pV7Y9B3b644jCLs9Upc3VeNGg6LWARAT6O+Q+/COo+2gg/bM5rhpMAtf70WqfBdQ==} + engines: {node: '>=18'} + peerDependencies: + '@csstools/css-tokenizer': ^3.0.4 + + '@csstools/css-tokenizer@3.0.4': + resolution: {integrity: sha512-Vd/9EVDiu6PPJt9yAh6roZP6El1xHrdvIVGjyBsHR0RYwNHgL7FJPyIIW4fANJNG6FtyZfvlRPpFI4ZM/lubvw==} + engines: {node: '>=18'} + + '@duckdb/node-api@1.3.2-alpha.25': + resolution: {integrity: sha512-AzDyyjTtnYUxoy/MHDFRwfOggDOkS8RBgGA82OI6nla8B9NDNZeAYJ97T3PvCL8cx7y00EtGVN3g03aoW4fRmw==} + + '@duckdb/node-bindings-darwin-arm64@1.3.2-alpha.25': + resolution: {integrity: sha512-vRjzNgkz2TAYW5c2rzPwcHBctBWr0lxQ4blFASAv0DdeGPOeuCMXJUA3982X7iPNwAppH0VMII6cYzON0GA+RA==} + cpu: [arm64] + os: [darwin] + + '@duckdb/node-bindings-darwin-x64@1.3.2-alpha.25': + resolution: {integrity: sha512-BSg/DZjT25QZe87+pmdMfE1XlHdi2WxtAO+F2PEXN6VnPeLyTdl5bYlnhOGrDKquKDmUEqok5OwF7mR4QfU+Aw==} + cpu: [x64] + os: [darwin] + + '@duckdb/node-bindings-linux-arm64@1.3.2-alpha.25': + resolution: {integrity: sha512-VhjUH/AvolZWDX/URqiIh58JbAB1vYbDgSmQ0wvqhS9jzJ9Sj88urGDw+XWXw49Rr4BhIgDtX70SoARhO2i/Gg==} + cpu: [arm64] + os: [linux] + + '@duckdb/node-bindings-linux-x64@1.3.2-alpha.25': + resolution: {integrity: sha512-raav2ypBiV4TlpnKU9hocsuFDO4ipwIcQQmkMIh20/Qd9vkv35QcQYNqStiZVJh2LAaVoQffNvcKMlclblYqUQ==} + cpu: [x64] + os: [linux] + + '@duckdb/node-bindings-win32-x64@1.3.2-alpha.25': + 
resolution: {integrity: sha512-/fAKax+xYkdRhkUl3PkL3HfFd1ZsezG1yiOkL0StHBdD3xB80Njm1JGHxx1fO3WWE5XTbE1MTJ5I0xjEzPwsfQ==} + cpu: [x64] + os: [win32] + + '@duckdb/node-bindings@1.3.2-alpha.25': + resolution: {integrity: sha512-FkoSaoeRAi6Em0hs0qzr3SN04ykN99R+Qap5kLwhi6GNPnHzWMU1VrNpK9cE4eBj0n+RWlNK0TiO712dn44QzQ==} + + '@esbuild/aix-ppc64@0.25.8': + resolution: {integrity: sha512-urAvrUedIqEiFR3FYSLTWQgLu5tb+m0qZw0NBEasUeo6wuqatkMDaRT+1uABiGXEu5vqgPd7FGE1BhsAIy9QVA==} + engines: {node: '>=18'} + cpu: [ppc64] + os: [aix] + + '@esbuild/android-arm64@0.25.8': + resolution: {integrity: sha512-OD3p7LYzWpLhZEyATcTSJ67qB5D+20vbtr6vHlHWSQYhKtzUYrETuWThmzFpZtFsBIxRvhO07+UgVA9m0i/O1w==} + engines: {node: '>=18'} + cpu: [arm64] + os: [android] + + '@esbuild/android-arm@0.25.8': + resolution: {integrity: sha512-RONsAvGCz5oWyePVnLdZY/HHwA++nxYWIX1atInlaW6SEkwq6XkP3+cb825EUcRs5Vss/lGh/2YxAb5xqc07Uw==} + engines: {node: '>=18'} + cpu: [arm] + os: [android] + + '@esbuild/android-x64@0.25.8': + resolution: {integrity: sha512-yJAVPklM5+4+9dTeKwHOaA+LQkmrKFX96BM0A/2zQrbS6ENCmxc4OVoBs5dPkCCak2roAD+jKCdnmOqKszPkjA==} + engines: {node: '>=18'} + cpu: [x64] + os: [android] + + '@esbuild/darwin-arm64@0.25.8': + resolution: {integrity: sha512-Jw0mxgIaYX6R8ODrdkLLPwBqHTtYHJSmzzd+QeytSugzQ0Vg4c5rDky5VgkoowbZQahCbsv1rT1KW72MPIkevw==} + engines: {node: '>=18'} + cpu: [arm64] + os: [darwin] + + '@esbuild/darwin-x64@0.25.8': + resolution: {integrity: sha512-Vh2gLxxHnuoQ+GjPNvDSDRpoBCUzY4Pu0kBqMBDlK4fuWbKgGtmDIeEC081xi26PPjn+1tct+Bh8FjyLlw1Zlg==} + engines: {node: '>=18'} + cpu: [x64] + os: [darwin] + + '@esbuild/freebsd-arm64@0.25.8': + resolution: {integrity: sha512-YPJ7hDQ9DnNe5vxOm6jaie9QsTwcKedPvizTVlqWG9GBSq+BuyWEDazlGaDTC5NGU4QJd666V0yqCBL2oWKPfA==} + engines: {node: '>=18'} + cpu: [arm64] + os: [freebsd] + + '@esbuild/freebsd-x64@0.25.8': + resolution: {integrity: sha512-MmaEXxQRdXNFsRN/KcIimLnSJrk2r5H8v+WVafRWz5xdSVmWLoITZQXcgehI2ZE6gioE6HirAEToM/RvFBeuhw==} + engines: {node: '>=18'} + 
cpu: [x64] + os: [freebsd] + + '@esbuild/linux-arm64@0.25.8': + resolution: {integrity: sha512-WIgg00ARWv/uYLU7lsuDK00d/hHSfES5BzdWAdAig1ioV5kaFNrtK8EqGcUBJhYqotlUByUKz5Qo6u8tt7iD/w==} + engines: {node: '>=18'} + cpu: [arm64] + os: [linux] + + '@esbuild/linux-arm@0.25.8': + resolution: {integrity: sha512-FuzEP9BixzZohl1kLf76KEVOsxtIBFwCaLupVuk4eFVnOZfU+Wsn+x5Ryam7nILV2pkq2TqQM9EZPsOBuMC+kg==} + engines: {node: '>=18'} + cpu: [arm] + os: [linux] + + '@esbuild/linux-ia32@0.25.8': + resolution: {integrity: sha512-A1D9YzRX1i+1AJZuFFUMP1E9fMaYY+GnSQil9Tlw05utlE86EKTUA7RjwHDkEitmLYiFsRd9HwKBPEftNdBfjg==} + engines: {node: '>=18'} + cpu: [ia32] + os: [linux] + + '@esbuild/linux-loong64@0.25.8': + resolution: {integrity: sha512-O7k1J/dwHkY1RMVvglFHl1HzutGEFFZ3kNiDMSOyUrB7WcoHGf96Sh+64nTRT26l3GMbCW01Ekh/ThKM5iI7hQ==} + engines: {node: '>=18'} + cpu: [loong64] + os: [linux] + + '@esbuild/linux-mips64el@0.25.8': + resolution: {integrity: sha512-uv+dqfRazte3BzfMp8PAQXmdGHQt2oC/y2ovwpTteqrMx2lwaksiFZ/bdkXJC19ttTvNXBuWH53zy/aTj1FgGw==} + engines: {node: '>=18'} + cpu: [mips64el] + os: [linux] + + '@esbuild/linux-ppc64@0.25.8': + resolution: {integrity: sha512-GyG0KcMi1GBavP5JgAkkstMGyMholMDybAf8wF5A70CALlDM2p/f7YFE7H92eDeH/VBtFJA5MT4nRPDGg4JuzQ==} + engines: {node: '>=18'} + cpu: [ppc64] + os: [linux] + + '@esbuild/linux-riscv64@0.25.8': + resolution: {integrity: sha512-rAqDYFv3yzMrq7GIcen3XP7TUEG/4LK86LUPMIz6RT8A6pRIDn0sDcvjudVZBiiTcZCY9y2SgYX2lgK3AF+1eg==} + engines: {node: '>=18'} + cpu: [riscv64] + os: [linux] + + '@esbuild/linux-s390x@0.25.8': + resolution: {integrity: sha512-Xutvh6VjlbcHpsIIbwY8GVRbwoviWT19tFhgdA7DlenLGC/mbc3lBoVb7jxj9Z+eyGqvcnSyIltYUrkKzWqSvg==} + engines: {node: '>=18'} + cpu: [s390x] + os: [linux] + + '@esbuild/linux-x64@0.25.8': + resolution: {integrity: sha512-ASFQhgY4ElXh3nDcOMTkQero4b1lgubskNlhIfJrsH5OKZXDpUAKBlNS0Kx81jwOBp+HCeZqmoJuihTv57/jvQ==} + engines: {node: '>=18'} + cpu: [x64] + os: [linux] + + '@esbuild/netbsd-arm64@0.25.8': + resolution: 
{integrity: sha512-d1KfruIeohqAi6SA+gENMuObDbEjn22olAR7egqnkCD9DGBG0wsEARotkLgXDu6c4ncgWTZJtN5vcgxzWRMzcw==} + engines: {node: '>=18'} + cpu: [arm64] + os: [netbsd] + + '@esbuild/netbsd-x64@0.25.8': + resolution: {integrity: sha512-nVDCkrvx2ua+XQNyfrujIG38+YGyuy2Ru9kKVNyh5jAys6n+l44tTtToqHjino2My8VAY6Lw9H7RI73XFi66Cg==} + engines: {node: '>=18'} + cpu: [x64] + os: [netbsd] + + '@esbuild/openbsd-arm64@0.25.8': + resolution: {integrity: sha512-j8HgrDuSJFAujkivSMSfPQSAa5Fxbvk4rgNAS5i3K+r8s1X0p1uOO2Hl2xNsGFppOeHOLAVgYwDVlmxhq5h+SQ==} + engines: {node: '>=18'} + cpu: [arm64] + os: [openbsd] + + '@esbuild/openbsd-x64@0.25.8': + resolution: {integrity: sha512-1h8MUAwa0VhNCDp6Af0HToI2TJFAn1uqT9Al6DJVzdIBAd21m/G0Yfc77KDM3uF3T/YaOgQq3qTJHPbTOInaIQ==} + engines: {node: '>=18'} + cpu: [x64] + os: [openbsd] + + '@esbuild/openharmony-arm64@0.25.8': + resolution: {integrity: sha512-r2nVa5SIK9tSWd0kJd9HCffnDHKchTGikb//9c7HX+r+wHYCpQrSgxhlY6KWV1nFo1l4KFbsMlHk+L6fekLsUg==} + engines: {node: '>=18'} + cpu: [arm64] + os: [openharmony] + + '@esbuild/sunos-x64@0.25.8': + resolution: {integrity: sha512-zUlaP2S12YhQ2UzUfcCuMDHQFJyKABkAjvO5YSndMiIkMimPmxA+BYSBikWgsRpvyxuRnow4nS5NPnf9fpv41w==} + engines: {node: '>=18'} + cpu: [x64] + os: [sunos] + + '@esbuild/win32-arm64@0.25.8': + resolution: {integrity: sha512-YEGFFWESlPva8hGL+zvj2z/SaK+pH0SwOM0Nc/d+rVnW7GSTFlLBGzZkuSU9kFIGIo8q9X3ucpZhu8PDN5A2sQ==} + engines: {node: '>=18'} + cpu: [arm64] + os: [win32] + + '@esbuild/win32-ia32@0.25.8': + resolution: {integrity: sha512-hiGgGC6KZ5LZz58OL/+qVVoZiuZlUYlYHNAmczOm7bs2oE1XriPFi5ZHHrS8ACpV5EjySrnoCKmcbQMN+ojnHg==} + engines: {node: '>=18'} + cpu: [ia32] + os: [win32] + + '@esbuild/win32-x64@0.25.8': + resolution: {integrity: sha512-cn3Yr7+OaaZq1c+2pe+8yxC8E144SReCQjN6/2ynubzYjvyqZjTXfQJpAcQpsdJq3My7XADANiYGHoFC69pLQw==} + engines: {node: '>=18'} + cpu: [x64] + os: [win32] + + '@eslint-community/eslint-utils@4.7.0': + resolution: {integrity: 
sha512-dyybb3AcajC7uha6CvhdVRJqaKyn7w2YKqKyAN37NKYgZT36w+iRb0Dymmc5qEJ549c/S31cMMSFd75bteCpCw==} + engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0} + peerDependencies: + eslint: ^6.0.0 || ^7.0.0 || >=8.0.0 + + '@eslint-community/regexpp@4.12.1': + resolution: {integrity: sha512-CCZCDJuduB9OUkFkY2IgppNZMi2lBQgD2qzwXkEia16cge2pijY/aXi96CJMquDMn3nJdlPV1A5KrJEXwfLNzQ==} + engines: {node: ^12.0.0 || ^14.0.0 || >=16.0.0} + + '@eslint/config-array@0.21.0': + resolution: {integrity: sha512-ENIdc4iLu0d93HeYirvKmrzshzofPw6VkZRKQGe9Nv46ZnWUzcF1xV01dcvEg/1wXUR61OmmlSfyeyO7EvjLxQ==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + + '@eslint/config-helpers@0.3.0': + resolution: {integrity: sha512-ViuymvFmcJi04qdZeDc2whTHryouGcDlaxPqarTD0ZE10ISpxGUVZGZDx4w01upyIynL3iu6IXH2bS1NhclQMw==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + + '@eslint/core@0.15.1': + resolution: {integrity: sha512-bkOp+iumZCCbt1K1CmWf0R9pM5yKpDv+ZXtvSyQpudrI9kuFLp+bM2WOPXImuD/ceQuaa8f5pj93Y7zyECIGNA==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + + '@eslint/eslintrc@3.3.1': + resolution: {integrity: sha512-gtF186CXhIl1p4pJNGZw8Yc6RlshoePRvE0X91oPGb3vZ8pM3qOS9W9NGPat9LziaBV7XrJWGylNQXkGcnM3IQ==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + + '@eslint/js@9.31.0': + resolution: {integrity: sha512-LOm5OVt7D4qiKCqoiPbA7LWmI+tbw1VbTUowBcUMgQSuM6poJufkFkYDcQpo5KfgD39TnNySV26QjOh7VFpSyw==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + + '@eslint/object-schema@2.1.6': + resolution: {integrity: sha512-RBMg5FRL0I0gs51M/guSAj5/e14VQ4tpZnQNWwuDT66P14I43ItmPfIZRhO9fUVIPOAQXU47atlywZ/czoqFPA==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + + '@eslint/plugin-kit@0.3.4': + resolution: {integrity: sha512-Ul5l+lHEcw3L5+k8POx6r74mxEYKG5kOb6Xpy2gCRW6zweT6TEhAf8vhxGgjhqrd/VO/Dirhsb+1hNpD1ue9hw==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + + '@exodus/schemasafe@1.3.0': + resolution: {integrity: 
sha512-5Aap/GaRupgNx/feGBwLLTVv8OQFfv3pq2lPRzPg9R+IOBnDgghTGW7l7EuVXOvg5cc/xSAlRW8rBrjIC3Nvqw==} + + '@floating-ui/core@1.7.2': + resolution: {integrity: sha512-wNB5ooIKHQc+Kui96jE/n69rHFWAVoxn5CAzL1Xdd8FG03cgY3MLO+GF9U3W737fYDSgPWA6MReKhBQBop6Pcw==} + + '@floating-ui/dom@1.7.2': + resolution: {integrity: sha512-7cfaOQuCS27HD7DX+6ib2OrnW+b4ZBwDNnCcT0uTyidcmyWb03FnQqJybDBoCnpdxwBSfA94UAYlRCt7mV+TbA==} + + '@floating-ui/react-dom@2.1.4': + resolution: {integrity: sha512-JbbpPhp38UmXDDAu60RJmbeme37Jbgsm7NrHGgzYYFKmblzRUh6Pa641dII6LsjwF4XlScDrde2UAzDo/b9KPw==} + peerDependencies: + react: '>=16.8.0' + react-dom: '>=16.8.0' + + '@floating-ui/react@0.26.28': + resolution: {integrity: sha512-yORQuuAtVpiRjpMhdc0wJj06b9JFjrYF4qp96j++v2NBpbi6SEGF7donUJ3TMieerQ6qVkAv1tgr7L4r5roTqw==} + peerDependencies: + react: '>=16.8.0' + react-dom: '>=16.8.0' + + '@floating-ui/utils@0.2.10': + resolution: {integrity: sha512-aGTxbpbg8/b5JfU1HXSrbH3wXZuLPJcNEcZQFMxLs3oSzgtVu6nFPkbbGGUvBcUjKV2YyB9Wxxabo+HEH9tcRQ==} + + '@gerrit0/mini-shiki@3.8.1': + resolution: {integrity: sha512-HVZW+8pxoOExr5ZMPK15U79jQAZTO/S6i5byQyyZGjtNj+qaYd82cizTncwFzTQgiLo8uUBym6vh+/1tfJklTw==} + + '@headlessui/react@2.2.5': + resolution: {integrity: sha512-h1+2Vu1yR5pp/fBcTnwVEW8Kb94Hbxp7MXZLORfDzvSrbmGgiTyaTZ4LI/tPNZnK8eDrYD9s9cMbjm5HS5otIQ==} + engines: {node: '>=10'} + peerDependencies: + react: ^18 || ^19 || ^19.0.0-rc + react-dom: ^18 || ^19 || ^19.0.0-rc + + '@heroicons/react@2.2.0': + resolution: {integrity: sha512-LMcepvRaS9LYHJGsF0zzmgKCUim/X3N/DQKc4jepAXJ7l8QxJ1PmxJzqplF2Z3FE4PqBAIGyJAQ/w4B5dsqbtQ==} + peerDependencies: + react: '>= 16 || ^19.0.0-rc' + + '@humanfs/core@0.19.1': + resolution: {integrity: sha512-5DyQ4+1JEUzejeK1JGICcideyfUbGixgS9jNgex5nqkW+cY7WZhxBigmieN5Qnw9ZosSNVC9KQKyb+GUaGyKUA==} + engines: {node: '>=18.18.0'} + + '@humanfs/node@0.16.6': + resolution: {integrity: sha512-YuI2ZHQL78Q5HbhDiBA1X4LmYdXCKCMQIfw0pw7piHJwyREFebJUvrQN4cMssyES6x+vfUbx1CIpaQUKYdQZOw==} + engines: {node: '>=18.18.0'} 
+ + '@humanwhocodes/module-importer@1.0.1': + resolution: {integrity: sha512-bxveV4V8v5Yb4ncFTT3rPSgZBOpCkjfK0y4oVVVJwIuDVBRMDXrPyXRL988i5ap9m9bnyEEjWfm5WkBmtffLfA==} + engines: {node: '>=12.22'} + + '@humanwhocodes/retry@0.3.1': + resolution: {integrity: sha512-JBxkERygn7Bv/GbN5Rv8Ul6LVknS+5Bp6RgDC/O8gEBU/yeH5Ui5C/OlWrTb6qct7LjjfT6Re2NxB0ln0yYybA==} + engines: {node: '>=18.18'} + + '@humanwhocodes/retry@0.4.3': + resolution: {integrity: sha512-bV0Tgo9K4hfPCek+aMAn81RppFKv2ySDQeMoSZuvTASywNTnVJCArCZE2FWqpvIatKu7VMRLWlR1EazvVhDyhQ==} + engines: {node: '>=18.18'} + + '@ibm-cloud/openapi-ruleset-utilities@1.9.0': + resolution: {integrity: sha512-AoFbSarOqFBYH+1TZ9Ahkm2IWYSi5v0pBk88fpV+5b3qGJukypX8PwvCWADjuyIccKg48/F73a6hTTkBzDQ2UA==} + engines: {node: '>=16.0.0'} + + '@ibm-cloud/openapi-ruleset@1.31.1': + resolution: {integrity: sha512-3WK2FREmDA2aadCjD71PE7tx5evyvmhg80ts1kXp2IzXIA0ZJ7guGM66tj40kxaqwpMSGchwEnnfYswntav76g==} + engines: {node: '>=16.0.0'} + + '@isaacs/balanced-match@4.0.1': + resolution: {integrity: sha512-yzMTt9lEb8Gv7zRioUilSglI0c0smZ9k5D65677DLWLtWJaXIS3CqcGyUFByYKlnUj6TkjLVs54fBl6+TiGQDQ==} + engines: {node: 20 || >=22} + + '@isaacs/brace-expansion@5.0.0': + resolution: {integrity: sha512-ZT55BDLV0yv0RBm2czMiZ+SqCGO7AvmOM3G/w2xhVPH+te0aKgFjmBvGlL1dH+ql2tgGO3MVrbb3jCKyvpgnxA==} + engines: {node: 20 || >=22} + + '@isaacs/cliui@8.0.2': + resolution: {integrity: sha512-O8jcjabXaleOG9DQ0+ARXWZBTfnP4WNAqzuiJK7ll44AmxGKv/J2M4TPjxjY3znBCfvBXFzucm1twdyFybFqEA==} + engines: {node: '>=12'} + + '@isaacs/fs-minipass@4.0.1': + resolution: {integrity: sha512-wgm9Ehl2jpeqP3zw/7mo3kRHFp5MEDhqAdwy1fTGkHAwnkGOVsgpvQhL8B5n1qlb01jV3n/bI0ZfZp5lWA1k4w==} + engines: {node: '>=18.0.0'} + + '@istanbuljs/schema@0.1.3': + resolution: {integrity: sha512-ZXRY4jNvVgSVQ8DL3LTcakaAtXwTVUxE81hslsyD2AtoXW/wVob10HkOJ1X/pAlcI7D+2YoZKg5do8G/w6RYgA==} + engines: {node: '>=8'} + + '@joshwooding/vite-plugin-react-docgen-typescript@0.6.1': + resolution: {integrity: 
sha512-J4BaTocTOYFkMHIra1JDWrMWpNmBl4EkplIwHEsV8aeUOtdWjwSnln9U7twjMFTAEB7mptNtSKyVi1Y2W9sDJw==} + peerDependencies: + typescript: '>= 4.3.x' + vite: ^3.0.0 || ^4.0.0 || ^5.0.0 || ^6.0.0 || ^7.0.0 + peerDependenciesMeta: + typescript: + optional: true + + '@jridgewell/gen-mapping@0.3.12': + resolution: {integrity: sha512-OuLGC46TjB5BbN1dH8JULVVZY4WTdkF7tV9Ys6wLL1rubZnCMstOhNHueU5bLCrnRuDhKPDM4g6sw4Bel5Gzqg==} + + '@jridgewell/gen-mapping@0.3.13': + resolution: {integrity: sha512-2kkt/7niJ6MgEPxF0bYdQ6etZaA+fQvDcLKckhy1yIQOzaoKjBBjSj63/aLVjYE3qhRt5dvM+uUyfCg6UKCBbA==} + + '@jridgewell/resolve-uri@3.1.2': + resolution: {integrity: sha512-bRISgCIjP20/tbWSPWMEi54QVPRZExkuD9lJL+UIxUKtwVJA8wW1Trb1jMs1RFXo1CBTNZ/5hpC9QvmKWdopKw==} + engines: {node: '>=6.0.0'} + + '@jridgewell/source-map@0.3.11': + resolution: {integrity: sha512-ZMp1V8ZFcPG5dIWnQLr3NSI1MiCU7UETdS/A0G8V/XWHvJv3ZsFqutJn1Y5RPmAPX6F3BiE397OqveU/9NCuIA==} + + '@jridgewell/sourcemap-codec@1.5.4': + resolution: {integrity: sha512-VT2+G1VQs/9oz078bLrYbecdZKs912zQlkelYpuf+SXF+QvZDYJlbx/LSx+meSAwdDFnF8FVXW92AVjjkVmgFw==} + + '@jridgewell/sourcemap-codec@1.5.5': + resolution: {integrity: sha512-cYQ9310grqxueWbl+WuIUIaiUaDcj7WOq5fVhEljNVgRfOUhY9fy2zTvfoqWsnebh8Sl70VScFbICvJnLKB0Og==} + + '@jridgewell/trace-mapping@0.3.29': + resolution: {integrity: sha512-uw6guiW/gcAGPDhLmd77/6lW8QLeiV5RUTsAX46Db6oLhGaVj4lhnPwb184s1bkc8kdVg/+h988dro8GRDpmYQ==} + + '@jridgewell/trace-mapping@0.3.30': + resolution: {integrity: sha512-GQ7Nw5G2lTu/BtHTKfXhKHok2WGetd4XYcVKGx00SjAk8GMwgJM3zr6zORiPGuOE+/vkc90KtTosSSvaCjKb2Q==} + + '@jridgewell/trace-mapping@0.3.31': + resolution: {integrity: sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==} + + '@jsdevtools/ono@7.1.3': + resolution: {integrity: sha512-4JQNk+3mVzK3xh2rqd6RB4J46qUR19azEHBneZyTZM+c456qOrbbM/5xcR8huNCCcbVt7+UmizG6GuUvPvKUYg==} + + '@jsep-plugin/assignment@1.3.0': + resolution: {integrity: 
sha512-VVgV+CXrhbMI3aSusQyclHkenWSAm95WaiKrMxRFam3JSUiIaQjoMIw2sEs/OX4XifnqeQUN4DYbJjlA8EfktQ==} + engines: {node: '>= 10.16.0'} + peerDependencies: + jsep: ^0.4.0||^1.0.0 + + '@jsep-plugin/regex@1.0.4': + resolution: {integrity: sha512-q7qL4Mgjs1vByCaTnDFcBnV9HS7GVPJX5vyVoCgZHNSC9rjwIlmbXG5sUuorR5ndfHAIlJ8pVStxvjXHbNvtUg==} + engines: {node: '>= 10.16.0'} + peerDependencies: + jsep: ^0.4.0||^1.0.0 + + '@jsep-plugin/ternary@1.1.4': + resolution: {integrity: sha512-ck5wiqIbqdMX6WRQztBL7ASDty9YLgJ3sSAK5ZpBzXeySvFGCzIvM6UiAI4hTZ22fEcYQVV/zhUbNscggW+Ukg==} + engines: {node: '>= 10.16.0'} + peerDependencies: + jsep: ^0.4.0||^1.0.0 + + '@lezer/common@1.2.3': + resolution: {integrity: sha512-w7ojc8ejBqr2REPsWxJjrMFsA/ysDCFICn8zEOR9mrqzOu2amhITYuLD8ag6XZf0CFXDrhKqw7+tW8cX66NaDA==} + + '@lezer/highlight@1.2.1': + resolution: {integrity: sha512-Z5duk4RN/3zuVO7Jq0pGLJ3qynpxUVsh7IbUbGj88+uV2ApSAn6kWg2au3iJb+0Zi7kKtqffIESgNcRXWZWmSA==} + + '@lezer/lr@1.4.2': + resolution: {integrity: sha512-pu0K1jCIdnQ12aWNaAVU5bzi7Bd1w54J3ECgANPmYLtQKP0HBj2cE/5coBD66MT10xbtIuUr7tg0Shbsvk0mDA==} + + '@lezer/python@1.1.18': + resolution: {integrity: sha512-31FiUrU7z9+d/ElGQLJFXl+dKOdx0jALlP3KEOsGTex8mvj+SoE1FgItcHWK/axkxCHGUSpqIHt6JAWfWu9Rhg==} + + '@lit/react@1.0.8': + resolution: {integrity: sha512-p2+YcF+JE67SRX3mMlJ1TKCSTsgyOVdAwd/nxp3NuV1+Cb6MWALbN6nT7Ld4tpmYofcE5kcaSY1YBB9erY+6fw==} + peerDependencies: + '@types/react': 17 || 18 || 19 + + '@marijn/find-cluster-break@1.0.2': + resolution: {integrity: sha512-l0h88YhZFyKdXIFNfSWpyjStDjGHwZ/U7iobcK1cQQD8sejsONdQtTVU+1wVN1PBw40PiiHB1vA5S7VTfQiP9g==} + + '@mdx-js/react@3.1.0': + resolution: {integrity: sha512-QjHtSaoameoalGnKDT3FoIl4+9RwyTmo9ZJGBdLOks/YOiWHoRDI3PUwEzOE7kEmGcV3AFcp9K6dYu9rEuKLAQ==} + peerDependencies: + '@types/react': '>=16' + react: '>=16' + + '@microsoft/api-extractor-model@7.30.7': + resolution: {integrity: sha512-TBbmSI2/BHpfR9YhQA7nH0nqVmGgJ0xH0Ex4D99/qBDAUpnhA2oikGmdXanbw9AWWY/ExBYIpkmY8dBHdla3YQ==} + + 
'@microsoft/api-extractor@7.52.10': + resolution: {integrity: sha512-LhKytJM5ZJkbHQVfW/3o747rZUNs/MGg6j/wt/9qwwqEOfvUDTYXXxIBuMgrRXhJ528p41iyz4zjBVHZU74Odg==} + hasBin: true + + '@microsoft/tsdoc-config@0.17.1': + resolution: {integrity: sha512-UtjIFe0C6oYgTnad4q1QP4qXwLhe6tIpNTRStJ2RZEPIkqQPREAwE5spzVxsdn9UaEMUqhh0AqSx3X4nWAKXWw==} + + '@microsoft/tsdoc@0.15.1': + resolution: {integrity: sha512-4aErSrCR/On/e5G2hDP0wjooqDdauzEbIq8hIkIe5pXV0rtWJZvdCEKL0ykZxex+IxIwBp0eGeV48hQN07dXtw==} + + '@neoconfetti/react@1.0.0': + resolution: {integrity: sha512-klcSooChXXOzIm+SE5IISIAn3bYzYfPjbX7D7HoqZL84oAfgREeSg5vSIaSFH+DaGzzvImTyWe1OyrJ67vik4A==} + + '@nodelib/fs.scandir@2.1.5': + resolution: {integrity: sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==} + engines: {node: '>= 8'} + + '@nodelib/fs.stat@2.0.5': + resolution: {integrity: sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==} + engines: {node: '>= 8'} + + '@nodelib/fs.walk@1.2.8': + resolution: {integrity: sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==} + engines: {node: '>= 8'} + + '@orval/angular@7.10.0': + resolution: {integrity: sha512-M89GKo/PibxYXvOKp9+i6BLxhEW8YsO+evwuV2kMbDGNS3RiYDwzmMBcA9SVL7m8CumeZoxNEAXsupzq96ZAXA==} + + '@orval/axios@7.10.0': + resolution: {integrity: sha512-AB6BjEwyguIcH8olzOTFPvwUP8z63yP4Jfl3T2UoeFchK04KqWqxbUoxmDG9xVQ79uMs/uOrb0X+GFwdZ56gAg==} + + '@orval/core@7.10.0': + resolution: {integrity: sha512-Lm7HY4Kwzehe+2HNfi+Ov/IZ+m3nj3NskVGvOyJDAqaaHB7G/xydSCtgELG32ur4G+M/XmwChAjoP4TCNVh0VA==} + + '@orval/fetch@7.10.0': + resolution: {integrity: sha512-bWcXPmARcXhXRveBtUnkfPlkUcLEzfGaflAdqN4CtScS48LgNrXXtuyt2BV2wvEXAavCWIhnRyQvz2foTU4U8Q==} + + '@orval/hono@7.10.0': + resolution: {integrity: sha512-bOxTdZxx2BpGQf7fFuCeeUe//ZYDWc6Yz9WOhj3HrnsD06xTRKFWVBi/QZ29QcAPxqwunu/VWwbqoiHHuuX3bA==} + + '@orval/mcp@7.10.0': + resolution: {integrity: 
sha512-ztLXGOSxK7jFwPKAeYPR85BjKRh3KTClKEnM2MFmo2FHHojn72DPXRPCmy0Wbw5Ee+JOxK2kIpyx+HZi9XVxiA==} + + '@orval/mock@7.10.0': + resolution: {integrity: sha512-vkEWCaKEyMfWGJF5MtxVzl+blwc9vYzwdYxMoSdjA5yS2dNBrdNlt1aLtb4+aoI1jgBgpCg/OB7VtWaL5QYidA==} + + '@orval/query@7.10.0': + resolution: {integrity: sha512-DBVg8RyKWSQKhr5Zfvxx5XICUdDUkG4MJKSd4BQCrRjUWgN6vwGunMEKyfnjpS5mFUSCkwWD/I3rTkjW6aysJA==} + + '@orval/swr@7.10.0': + resolution: {integrity: sha512-ZdApomZQhJ5ZogjJgBK+haeCOP9gUaMaGKGjTVJr86jJaygDcKn54Ok1quiDUCbX42Eye+cgmQJeKeZvqnPohA==} + + '@orval/zod@7.10.0': + resolution: {integrity: sha512-AB/508IBMlVDBcGvlq+ASz7DvqU3nhoDnIeBCyjwNfQwhYzREU0qqiFBnH0XAW70c6SCMf9/bIcYbw8GAx/zxA==} + + '@pkgjs/parseargs@0.11.0': + resolution: {integrity: sha512-+1VkjdD0QBLPodGrJUeqarH8VAIvQODIbwh9XpP5Syisf7YoQgsJKPNFoqqLQlu+VQ/tVSshMR6loPMn8U+dPg==} + engines: {node: '>=14'} + + '@playwright/test@1.54.1': + resolution: {integrity: sha512-FS8hQ12acieG2dYSksmLOF7BNxnVf2afRJdCuM1eMSxj6QTSE6G4InGF7oApGgDb65MX7AwMVlIkpru0yZA4Xw==} + engines: {node: '>=18'} + hasBin: true + + '@polka/url@1.0.0-next.29': + resolution: {integrity: sha512-wwQAWhWSuHaag8c4q/KN/vCoeOJYshAIvMQwD4GpSb3OiZklFfvAgmj0VCBBImRpuF/aFgIRzllXlVX93Jevww==} + + '@radix-ui/number@1.1.1': + resolution: {integrity: sha512-MkKCwxlXTgz6CFoJx3pCwn07GKp36+aZyu/u2Ln2VrA5DcdyCZkASEDBTd8x5whTQQL5CiYf4prXKLcgQdv29g==} + + '@radix-ui/primitive@1.1.2': + resolution: {integrity: sha512-XnbHrrprsNqZKQhStrSwgRUQzoCI1glLzdw79xiZPoofhGICeZRSQ3dIxAKH1gb3OHfNf4d6f+vAv3kil2eggA==} + + '@radix-ui/primitive@1.1.3': + resolution: {integrity: sha512-JTF99U/6XIjCBo0wqkU5sK10glYe27MRRsfwoiq5zzOEZLHU3A3KCMa5X/azekYRCJ0HlwI0crAXS/5dEHTzDg==} + + '@radix-ui/react-arrow@1.1.7': + resolution: {integrity: sha512-F+M1tLhO+mlQaOWspE8Wstg+z6PwxwRd8oQ8IXceWz92kfAmalTRf0EjrouQeo7QssEPfCn05B4Ihs1K9WQ/7w==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || 
^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + '@radix-ui/react-collection@1.1.7': + resolution: {integrity: sha512-Fh9rGN0MoI4ZFUNyfFVNU4y9LUz93u9/0K+yLgA2bwRojxM8JU1DyvvMBabnZPBgMWREAJvU2jjVzq+LrFUglw==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + '@radix-ui/react-compose-refs@1.1.2': + resolution: {integrity: sha512-z4eqJvfiNnFMHIIvXP3CY57y2WJs5g2v3X0zm9mEJkrkNv4rDxu+sg9Jh8EkXyeqBkB7SOcboo9dMVqhyrACIg==} + peerDependencies: + '@types/react': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + '@radix-ui/react-context-menu@2.2.15': + resolution: {integrity: sha512-UsQUMjcYTsBjTSXw0P3GO0werEQvUY2plgRQuKoCTtkNr45q1DiL51j4m7gxhABzZ0BadoXNsIbg7F3KwiUBbw==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + '@radix-ui/react-context@1.1.2': + resolution: {integrity: sha512-jCi/QKUM2r1Ju5a3J64TH2A5SpKAgh0LpknyqdQ4m6DCV0xJ2HG1xARRwNGPQfi1SLdLWZ1OJz6F4OMBBNiGJA==} + peerDependencies: + '@types/react': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + '@radix-ui/react-direction@1.1.1': + resolution: {integrity: sha512-1UEWRX6jnOA2y4H5WczZ44gOOjTEmlqv1uNW4GAJEO5+bauCBhv8snY65Iw5/VOS/ghKN9gr2KjnLKxrsvoMVw==} + peerDependencies: + '@types/react': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + 
'@radix-ui/react-dismissable-layer@1.1.10': + resolution: {integrity: sha512-IM1zzRV4W3HtVgftdQiiOmA0AdJlCtMLe00FXaHwgt3rAnNsIyDqshvkIW3hj/iu5hu8ERP7KIYki6NkqDxAwQ==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + '@radix-ui/react-dismissable-layer@1.1.11': + resolution: {integrity: sha512-Nqcp+t5cTB8BinFkZgXiMJniQH0PsUt2k51FUhbdfeKvc4ACcG2uQniY/8+h1Yv6Kza4Q7lD7PQV0z0oicE0Mg==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + '@radix-ui/react-focus-guards@1.1.2': + resolution: {integrity: sha512-fyjAACV62oPV925xFCrH8DR5xWhg9KYtJT4s3u54jxp+L/hbpTY2kIeEFFbFe+a/HCE94zGQMZLIpVTPVZDhaA==} + peerDependencies: + '@types/react': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + '@radix-ui/react-focus-scope@1.1.7': + resolution: {integrity: sha512-t2ODlkXBQyn7jkl6TNaw/MtVEVvIGelJDCG41Okq/KwUsJBwQ4XVZsHAVUkK4mBv3ewiAS3PGuUWuY2BoK4ZUw==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + '@radix-ui/react-id@1.1.1': + resolution: {integrity: sha512-kGkGegYIdQsOb4XjsfM97rXsiHaBwco+hFI66oO4s9LU+PLAC5oJ7khdOVFxkhsmlbpUqDAvXw11CluXP+jkHg==} + peerDependencies: + '@types/react': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + 
'@radix-ui/react-menu@2.1.15': + resolution: {integrity: sha512-tVlmA3Vb9n8SZSd+YSbuFR66l87Wiy4du+YE+0hzKQEANA+7cWKH1WgqcEX4pXqxUFQKrWQGHdvEfw00TjFiew==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + '@radix-ui/react-popper@1.2.7': + resolution: {integrity: sha512-IUFAccz1JyKcf/RjB552PlWwxjeCJB8/4KxT7EhBHOJM+mN7LdW+B3kacJXILm32xawcMMjb2i0cIZpo+f9kiQ==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + '@radix-ui/react-popper@1.2.8': + resolution: {integrity: sha512-0NJQ4LFFUuWkE7Oxf0htBKS6zLkkjBH+hM1uk7Ng705ReR8m/uelduy1DBo0PyBXPKVnBA6YBlU94MBGXrSBCw==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + '@radix-ui/react-portal@1.1.9': + resolution: {integrity: sha512-bpIxvq03if6UNwXZ+HTK71JLh4APvnXntDc6XOX8UVq4XQOVl7lwok0AvIl+b8zgCw3fSaVTZMpAPPagXbKmHQ==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + '@radix-ui/react-presence@1.1.4': + resolution: {integrity: sha512-ueDqRbdc4/bkaQT3GIpLQssRlFgWaL/U2z/S31qRwwLWoxHLgry3SIfCwhxeQNbirEUXFa+lq3RL3oBYXtcmIA==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: 
^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + '@radix-ui/react-presence@1.1.5': + resolution: {integrity: sha512-/jfEwNDdQVBCNvjkGit4h6pMOzq8bHkopq458dPt2lMjx+eBQUohZNG9A7DtO/O5ukSbxuaNGXMjHicgwy6rQQ==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + '@radix-ui/react-primitive@2.1.3': + resolution: {integrity: sha512-m9gTwRkhy2lvCPe6QJp4d3G1TYEUHn/FzJUtq9MjH46an1wJU+GdoGC5VLof8RX8Ft/DlpshApkhswDLZzHIcQ==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + '@radix-ui/react-roving-focus@1.1.10': + resolution: {integrity: sha512-dT9aOXUen9JSsxnMPv/0VqySQf5eDQ6LCk5Sw28kamz8wSOW2bJdlX2Bg5VUIIcV+6XlHpWTIuTPCf/UNIyq8Q==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + '@radix-ui/react-select@2.2.5': + resolution: {integrity: sha512-HnMTdXEVuuyzx63ME0ut4+sEMYW6oouHWNGUZc7ddvUWIcfCva/AMoqEW/3wnEllriMWBa0RHspCYnfCWJQYmA==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + 
'@radix-ui/react-slot@1.2.3': + resolution: {integrity: sha512-aeNmHnBxbi2St0au6VBVC7JXFlhLlOnvIIlePNniyUNAClzmtAUEY8/pBiK3iHjufOlwA+c20/8jngo7xcrg8A==} + peerDependencies: + '@types/react': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + '@radix-ui/react-tooltip@1.2.8': + resolution: {integrity: sha512-tY7sVt1yL9ozIxvmbtN5qtmH2krXcBCfjEiCgKGLqunJHvgvZG2Pcl2oQ3kbcZARb1BGEHdkLzcYGO8ynVlieg==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + '@radix-ui/react-use-callback-ref@1.1.1': + resolution: {integrity: sha512-FkBMwD+qbGQeMu1cOHnuGB6x4yzPjho8ap5WtbEJ26umhgqVXbhekKUQO+hZEL1vU92a3wHwdp0HAcqAUF5iDg==} + peerDependencies: + '@types/react': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + '@radix-ui/react-use-controllable-state@1.2.2': + resolution: {integrity: sha512-BjasUjixPFdS+NKkypcyyN5Pmg83Olst0+c6vGov0diwTEo6mgdqVR6hxcEgFuh4QrAs7Rc+9KuGJ9TVCj0Zzg==} + peerDependencies: + '@types/react': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + '@radix-ui/react-use-effect-event@0.0.2': + resolution: {integrity: sha512-Qp8WbZOBe+blgpuUT+lw2xheLP8q0oatc9UpmiemEICxGvFLYmHm9QowVZGHtJlGbS6A6yJ3iViad/2cVjnOiA==} + peerDependencies: + '@types/react': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + '@radix-ui/react-use-escape-keydown@1.1.1': + resolution: {integrity: sha512-Il0+boE7w/XebUHyBjroE+DbByORGR9KKmITzbR7MyQ4akpORYP/ZmbhAr0DG7RmmBqoOnZdy2QlvajJ2QA59g==} + peerDependencies: + '@types/react': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || 
^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + '@radix-ui/react-use-layout-effect@1.1.1': + resolution: {integrity: sha512-RbJRS4UWQFkzHTTwVymMTUv8EqYhOp8dOOviLj2ugtTiXRaRQS7GLGxZTLL1jWhMeoSCf5zmcZkqTl9IiYfXcQ==} + peerDependencies: + '@types/react': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + '@radix-ui/react-use-previous@1.1.1': + resolution: {integrity: sha512-2dHfToCj/pzca2Ck724OZ5L0EVrr3eHRNsG/b3xQJLA2hZpVCS99bLAX+hm1IHXDEnzU6by5z/5MIY794/a8NQ==} + peerDependencies: + '@types/react': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + '@radix-ui/react-use-rect@1.1.1': + resolution: {integrity: sha512-QTYuDesS0VtuHNNvMh+CjlKJ4LJickCMUAqjlE3+j8w+RlRpwyX3apEQKGFzbZGdo7XNG1tXa+bQqIE7HIXT2w==} + peerDependencies: + '@types/react': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + '@radix-ui/react-use-size@1.1.1': + resolution: {integrity: sha512-ewrXRDTAqAXlkl6t/fkXWNAhFX9I+CkKlw6zjEwk86RSPKwZr3xpBRso655aqYafwtnbpHLj6toFzmd6xdVptQ==} + peerDependencies: + '@types/react': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + '@radix-ui/react-visually-hidden@1.2.3': + resolution: {integrity: sha512-pzJq12tEaaIhqjbzpCuv/OypJY/BPavOofm+dbab+MHLajy277+1lLm6JFcGgF5eskJ6mquGirhXY2GD/8u8Ug==} + peerDependencies: + '@types/react': '*' + '@types/react-dom': '*' + react: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + react-dom: ^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + '@radix-ui/rect@1.1.1': + resolution: {integrity: sha512-HPwpGIzkl28mWyZqG52jiqDJ12waP11Pa1lGoiyUkIEuMLBP0oeK/C89esbXrxsky5we7dfd8U58nm0SgAWpVw==} + + '@react-aria/focus@3.21.0': + 
resolution: {integrity: sha512-7NEGtTPsBy52EZ/ToVKCu0HSelE3kq9qeis+2eEq90XSuJOMaDHUQrA7RC2Y89tlEwQB31bud/kKRi9Qme1dkA==} + peerDependencies: + react: ^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1 + react-dom: ^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1 + + '@react-aria/interactions@3.25.4': + resolution: {integrity: sha512-HBQMxgUPHrW8V63u9uGgBymkMfj6vdWbB0GgUJY49K9mBKMsypcHeWkWM6+bF7kxRO728/IK8bWDV6whDbqjHg==} + peerDependencies: + react: ^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1 + react-dom: ^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1 + + '@react-aria/ssr@3.9.10': + resolution: {integrity: sha512-hvTm77Pf+pMBhuBm760Li0BVIO38jv1IBws1xFm1NoL26PU+fe+FMW5+VZWyANR6nYL65joaJKZqOdTQMkO9IQ==} + engines: {node: '>= 12'} + peerDependencies: + react: ^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1 + + '@react-aria/utils@3.30.0': + resolution: {integrity: sha512-ydA6y5G1+gbem3Va2nczj/0G0W7/jUVo/cbN10WA5IizzWIwMP5qhFr7macgbKfHMkZ+YZC3oXnt2NNre5odKw==} + peerDependencies: + react: ^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1 + react-dom: ^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1 + + '@react-dnd/asap@5.0.2': + resolution: {integrity: sha512-WLyfoHvxhs0V9U+GTsGilGgf2QsPl6ZZ44fnv0/b8T3nQyvzxidxsg/ZltbWssbsRDlYW8UKSQMTGotuTotZ6A==} + + '@react-dnd/invariant@4.0.2': + resolution: {integrity: sha512-xKCTqAK/FFauOM9Ta2pswIyT3D8AQlfrYdOi/toTPEhqCuAs1v5tcJ3Y08Izh1cJ5Jchwy9SeAXmMg6zrKs2iw==} + + '@react-dnd/shallowequal@4.0.2': + resolution: {integrity: sha512-/RVXdLvJxLg4QKvMoM5WlwNR9ViO9z8B/qPcc+C0Sa/teJY7QG7kJ441DwzOjMYEY7GmU4dj5EcGHIkKZiQZCA==} + + '@react-stately/flags@3.1.2': + resolution: {integrity: sha512-2HjFcZx1MyQXoPqcBGALwWWmgFVUk2TuKVIQxCbRq7fPyWXIl6VHcakCLurdtYC2Iks7zizvz0Idv48MQ38DWg==} + + '@react-stately/utils@3.10.8': + resolution: {integrity: sha512-SN3/h7SzRsusVQjQ4v10LaVsDc81jyyR0DD5HnsQitm/I5WDpaSr2nRHtyloPFU48jlql1XX/S04T2DLQM7Y3g==} + peerDependencies: + react: ^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1 + + 
'@react-types/shared@3.31.0': + resolution: {integrity: sha512-ua5U6V66gDcbLZe4P2QeyNgPp4YWD1ymGA6j3n+s8CGExtrCPe64v+g4mvpT8Bnb985R96e4zFT61+m0YCwqMg==} + peerDependencies: + react: ^16.8.0 || ^17.0.0-rc.1 || ^18.0.0 || ^19.0.0-rc.1 + + '@reactflow/background@11.3.14': + resolution: {integrity: sha512-Gewd7blEVT5Lh6jqrvOgd4G6Qk17eGKQfsDXgyRSqM+CTwDqRldG2LsWN4sNeno6sbqVIC2fZ+rAUBFA9ZEUDA==} + peerDependencies: + react: '>=17' + react-dom: '>=17' + + '@reactflow/controls@11.2.14': + resolution: {integrity: sha512-MiJp5VldFD7FrqaBNIrQ85dxChrG6ivuZ+dcFhPQUwOK3HfYgX2RHdBua+gx+40p5Vw5It3dVNp/my4Z3jF0dw==} + peerDependencies: + react: '>=17' + react-dom: '>=17' + + '@reactflow/core@11.11.4': + resolution: {integrity: sha512-H4vODklsjAq3AMq6Np4LE12i1I4Ta9PrDHuBR9GmL8uzTt2l2jh4CiQbEMpvMDcp7xi4be0hgXj+Ysodde/i7Q==} + peerDependencies: + react: '>=17' + react-dom: '>=17' + + '@reactflow/minimap@11.7.14': + resolution: {integrity: sha512-mpwLKKrEAofgFJdkhwR5UQ1JYWlcAAL/ZU/bctBkuNTT1yqV+y0buoNVImsRehVYhJwffSWeSHaBR5/GJjlCSQ==} + peerDependencies: + react: '>=17' + react-dom: '>=17' + + '@reactflow/node-resizer@2.2.14': + resolution: {integrity: sha512-fwqnks83jUlYr6OHcdFEedumWKChTHRGw/kbCxj0oqBd+ekfs+SIp4ddyNU0pdx96JIm5iNFS0oNrmEiJbbSaA==} + peerDependencies: + react: '>=17' + react-dom: '>=17' + + '@reactflow/node-toolbar@1.3.14': + resolution: {integrity: sha512-rbynXQnH/xFNu4P9H+hVqlEUafDCkEoCy0Dg9mG22Sg+rY/0ck6KkrAQrYrTgXusd+cEJOMK0uOOFCK2/5rSGQ==} + peerDependencies: + react: '>=17' + react-dom: '>=17' + + '@rolldown/pluginutils@1.0.0-beta.27': + resolution: {integrity: sha512-+d0F4MKMCbeVUJwG96uQ4SgAznZNSq93I3V+9NHA4OpvqG8mRCpGdKmK8l/dl02h2CCDHwW2FqilnTyDcAnqjA==} + + '@rollup/pluginutils@5.2.0': + resolution: {integrity: sha512-qWJ2ZTbmumwiLFomfzTyt5Kng4hwPi9rwCYN4SHb6eaRU1KNO4ccxINHr/VhH4GgPlt1XfSTLX2LBTme8ne4Zw==} + engines: {node: '>=14.0.0'} + peerDependencies: + rollup: ^1.20.0||^2.0.0||^3.0.0||^4.0.0 + peerDependenciesMeta: + rollup: + optional: true + + 
'@rollup/rollup-android-arm-eabi@4.45.1': + resolution: {integrity: sha512-NEySIFvMY0ZQO+utJkgoMiCAjMrGvnbDLHvcmlA33UXJpYBCvlBEbMMtV837uCkS+plG2umfhn0T5mMAxGrlRA==} + cpu: [arm] + os: [android] + + '@rollup/rollup-android-arm64@4.45.1': + resolution: {integrity: sha512-ujQ+sMXJkg4LRJaYreaVx7Z/VMgBBd89wGS4qMrdtfUFZ+TSY5Rs9asgjitLwzeIbhwdEhyj29zhst3L1lKsRQ==} + cpu: [arm64] + os: [android] + + '@rollup/rollup-darwin-arm64@4.45.1': + resolution: {integrity: sha512-FSncqHvqTm3lC6Y13xncsdOYfxGSLnP+73k815EfNmpewPs+EyM49haPS105Rh4aF5mJKywk9X0ogzLXZzN9lA==} + cpu: [arm64] + os: [darwin] + + '@rollup/rollup-darwin-x64@4.45.1': + resolution: {integrity: sha512-2/vVn/husP5XI7Fsf/RlhDaQJ7x9zjvC81anIVbr4b/f0xtSmXQTFcGIQ/B1cXIYM6h2nAhJkdMHTnD7OtQ9Og==} + cpu: [x64] + os: [darwin] + + '@rollup/rollup-freebsd-arm64@4.45.1': + resolution: {integrity: sha512-4g1kaDxQItZsrkVTdYQ0bxu4ZIQ32cotoQbmsAnW1jAE4XCMbcBPDirX5fyUzdhVCKgPcrwWuucI8yrVRBw2+g==} + cpu: [arm64] + os: [freebsd] + + '@rollup/rollup-freebsd-x64@4.45.1': + resolution: {integrity: sha512-L/6JsfiL74i3uK1Ti2ZFSNsp5NMiM4/kbbGEcOCps99aZx3g8SJMO1/9Y0n/qKlWZfn6sScf98lEOUe2mBvW9A==} + cpu: [x64] + os: [freebsd] + + '@rollup/rollup-linux-arm-gnueabihf@4.45.1': + resolution: {integrity: sha512-RkdOTu2jK7brlu+ZwjMIZfdV2sSYHK2qR08FUWcIoqJC2eywHbXr0L8T/pONFwkGukQqERDheaGTeedG+rra6Q==} + cpu: [arm] + os: [linux] + + '@rollup/rollup-linux-arm-musleabihf@4.45.1': + resolution: {integrity: sha512-3kJ8pgfBt6CIIr1o+HQA7OZ9mp/zDk3ctekGl9qn/pRBgrRgfwiffaUmqioUGN9hv0OHv2gxmvdKOkARCtRb8Q==} + cpu: [arm] + os: [linux] + + '@rollup/rollup-linux-arm64-gnu@4.45.1': + resolution: {integrity: sha512-k3dOKCfIVixWjG7OXTCOmDfJj3vbdhN0QYEqB+OuGArOChek22hn7Uy5A/gTDNAcCy5v2YcXRJ/Qcnm4/ma1xw==} + cpu: [arm64] + os: [linux] + + '@rollup/rollup-linux-arm64-musl@4.45.1': + resolution: {integrity: sha512-PmI1vxQetnM58ZmDFl9/Uk2lpBBby6B6rF4muJc65uZbxCs0EA7hhKCk2PKlmZKuyVSHAyIw3+/SiuMLxKxWog==} + cpu: [arm64] + os: [linux] + + 
'@rollup/rollup-linux-loongarch64-gnu@4.45.1': + resolution: {integrity: sha512-9UmI0VzGmNJ28ibHW2GpE2nF0PBQqsyiS4kcJ5vK+wuwGnV5RlqdczVocDSUfGX/Na7/XINRVoUgJyFIgipoRg==} + cpu: [loong64] + os: [linux] + + '@rollup/rollup-linux-powerpc64le-gnu@4.45.1': + resolution: {integrity: sha512-7nR2KY8oEOUTD3pBAxIBBbZr0U7U+R9HDTPNy+5nVVHDXI4ikYniH1oxQz9VoB5PbBU1CZuDGHkLJkd3zLMWsg==} + cpu: [ppc64] + os: [linux] + + '@rollup/rollup-linux-riscv64-gnu@4.45.1': + resolution: {integrity: sha512-nlcl3jgUultKROfZijKjRQLUu9Ma0PeNv/VFHkZiKbXTBQXhpytS8CIj5/NfBeECZtY2FJQubm6ltIxm/ftxpw==} + cpu: [riscv64] + os: [linux] + + '@rollup/rollup-linux-riscv64-musl@4.45.1': + resolution: {integrity: sha512-HJV65KLS51rW0VY6rvZkiieiBnurSzpzore1bMKAhunQiECPuxsROvyeaot/tcK3A3aGnI+qTHqisrpSgQrpgA==} + cpu: [riscv64] + os: [linux] + + '@rollup/rollup-linux-s390x-gnu@4.45.1': + resolution: {integrity: sha512-NITBOCv3Qqc6hhwFt7jLV78VEO/il4YcBzoMGGNxznLgRQf43VQDae0aAzKiBeEPIxnDrACiMgbqjuihx08OOw==} + cpu: [s390x] + os: [linux] + + '@rollup/rollup-linux-x64-gnu@4.45.1': + resolution: {integrity: sha512-+E/lYl6qu1zqgPEnTrs4WysQtvc/Sh4fC2nByfFExqgYrqkKWp1tWIbe+ELhixnenSpBbLXNi6vbEEJ8M7fiHw==} + cpu: [x64] + os: [linux] + + '@rollup/rollup-linux-x64-musl@4.45.1': + resolution: {integrity: sha512-a6WIAp89p3kpNoYStITT9RbTbTnqarU7D8N8F2CV+4Cl9fwCOZraLVuVFvlpsW0SbIiYtEnhCZBPLoNdRkjQFw==} + cpu: [x64] + os: [linux] + + '@rollup/rollup-win32-arm64-msvc@4.45.1': + resolution: {integrity: sha512-T5Bi/NS3fQiJeYdGvRpTAP5P02kqSOpqiopwhj0uaXB6nzs5JVi2XMJb18JUSKhCOX8+UE1UKQufyD6Or48dJg==} + cpu: [arm64] + os: [win32] + + '@rollup/rollup-win32-ia32-msvc@4.45.1': + resolution: {integrity: sha512-lxV2Pako3ujjuUe9jiU3/s7KSrDfH6IgTSQOnDWr9aJ92YsFd7EurmClK0ly/t8dzMkDtd04g60WX6yl0sGfdw==} + cpu: [ia32] + os: [win32] + + '@rollup/rollup-win32-x64-msvc@4.45.1': + resolution: {integrity: sha512-M/fKi4sasCdM8i0aWJjCSFm2qEnYRR8AMLG2kxp6wD13+tMGA4Z1tVAuHkNRjud5SW2EM3naLuK35w9twvf6aA==} + cpu: [x64] + os: [win32] + + 
'@rushstack/node-core-library@5.14.0': + resolution: {integrity: sha512-eRong84/rwQUlATGFW3TMTYVyqL1vfW9Lf10PH+mVGfIb9HzU3h5AASNIw+axnBLjnD0n3rT5uQBwu9fvzATrg==} + peerDependencies: + '@types/node': '*' + peerDependenciesMeta: + '@types/node': + optional: true + + '@rushstack/rig-package@0.5.3': + resolution: {integrity: sha512-olzSSjYrvCNxUFZowevC3uz8gvKr3WTpHQ7BkpjtRpA3wK+T0ybep/SRUMfr195gBzJm5gaXw0ZMgjIyHqJUow==} + + '@rushstack/terminal@0.15.4': + resolution: {integrity: sha512-OQSThV0itlwVNHV6thoXiAYZlQh4Fgvie2CzxFABsbO2MWQsI4zOh3LRNigYSTrmS+ba2j0B3EObakPzf/x6Zg==} + peerDependencies: + '@types/node': '*' + peerDependenciesMeta: + '@types/node': + optional: true + + '@rushstack/ts-command-line@5.0.2': + resolution: {integrity: sha512-+AkJDbu1GFMPIU8Sb7TLVXDv/Q7Mkvx+wAjEl8XiXVVq+p1FmWW6M3LYpJMmoHNckSofeMecgWg5lfMwNAAsEQ==} + + '@secretlint/config-creator@10.2.1': + resolution: {integrity: sha512-nyuRy8uo2+mXPIRLJ93wizD1HbcdDIsVfgCT01p/zGVFrtvmiL7wqsl4KgZH0QFBM/KRLDLeog3/eaM5ASjtvw==} + engines: {node: '>=20.0.0'} + + '@secretlint/config-loader@10.2.1': + resolution: {integrity: sha512-ob1PwhuSw/Hc6Y4TA63NWj6o++rZTRJOwPZG82o6tgEURqkrAN44fXH9GIouLsOxKa8fbCRLMeGmSBtJLdSqtw==} + engines: {node: '>=20.0.0'} + + '@secretlint/core@10.2.1': + resolution: {integrity: sha512-2sPp5IE7pM5Q+f1/NK6nJ49FKuqh+e3fZq5MVbtVjegiD4NMhjcoML1Cg7atCBgXPufhXRHY1DWhIhkGzOx/cw==} + engines: {node: '>=20.0.0'} + + '@secretlint/formatter@10.2.1': + resolution: {integrity: sha512-0A7ho3j0Y4ysK0mREB3O6FKQtScD4rQgfzuI4Slv9Cut1ynQOI7JXAoIFm4XVzhNcgtmEPeD3pQB206VFphBgQ==} + engines: {node: '>=20.0.0'} + + '@secretlint/node@10.2.1': + resolution: {integrity: sha512-MQFte7C+5ZHINQGSo6+eUECcUCGvKR9PVgZcTsRj524xsbpeBqF1q1dHsUsdGb9r2jlvf40Q14MRZwMcpmLXWQ==} + engines: {node: '>=20.0.0'} + + '@secretlint/profiler@10.2.1': + resolution: {integrity: sha512-gOlfPZ1ASc5mP5cqsL809uMJGp85t+AJZg1ZPscWvB/m5UFFgeNTZcOawggb1S5ExDvR388sIJxagx5hyDZ34g==} + + '@secretlint/resolver@10.2.1': + resolution: 
{integrity: sha512-AuwehKwnE2uxKaJVv2Z5a8FzGezBmlNhtLKm70Cvsvtwd0oAtenxCSTKXkiPGYC0+S91fAw3lrX7CUkyr9cTCA==} + + '@secretlint/secretlint-formatter-sarif@10.2.1': + resolution: {integrity: sha512-qOZUYBesLkhCBP7YVMv0l1Pypt8e3V2rX2PT2Q5aJhJvKTcMiP9YTHG/3H9Zb7Gq3UIwZLEAGXRqJOu1XlE0Fg==} + + '@secretlint/secretlint-rule-no-dotenv@10.2.1': + resolution: {integrity: sha512-XwPjc9Wwe2QljerfvGlBmLJAJVATLvoXXw1fnKyCDNgvY33cu1Z561Kxg93xfRB5LSep0S5hQrAfZRJw6x7MBQ==} + engines: {node: '>=20.0.0'} + + '@secretlint/secretlint-rule-preset-recommend@10.2.1': + resolution: {integrity: sha512-/kj3UOpFbJt80dqoeEaUVv5nbeW1jPqPExA447FItthiybnaDse5C5HYcfNA2ywEInr399ELdcmpEMRe+ld1iQ==} + engines: {node: '>=20.0.0'} + + '@secretlint/source-creator@10.2.1': + resolution: {integrity: sha512-1CgO+hsRx8KdA5R/LEMNTJkujjomwSQQVV0BcuKynpOefV/rRlIDVQJOU0tJOZdqUMC15oAAwQXs9tMwWLu4JQ==} + engines: {node: '>=20.0.0'} + + '@secretlint/types@10.2.1': + resolution: {integrity: sha512-F5k1qpoMoUe7rrZossOBgJ3jWKv/FGDBZIwepqnefgPmNienBdInxhtZeXiGwjcxXHVhsdgp6I5Fi/M8PMgwcw==} + engines: {node: '>=20.0.0'} + + '@shikijs/engine-oniguruma@3.8.1': + resolution: {integrity: sha512-KGQJZHlNY7c656qPFEQpIoqOuC4LrxjyNndRdzk5WKB/Ie87+NJCF1xo9KkOUxwxylk7rT6nhlZyTGTC4fCe1g==} + + '@shikijs/langs@3.12.2': + resolution: {integrity: sha512-bVx5PfuZHDSHoBal+KzJZGheFuyH4qwwcwG/n+MsWno5cTlKmaNtTsGzJpHYQ8YPbB5BdEdKU1rga5/6JGY8ww==} + + '@shikijs/themes@3.12.2': + resolution: {integrity: sha512-fTR3QAgnwYpfGczpIbzPjlRnxyONJOerguQv1iwpyQZ9QXX4qy/XFQqXlf17XTsorxnHoJGbH/LXBvwtqDsF5A==} + + '@shikijs/types@3.12.2': + resolution: {integrity: sha512-K5UIBzxCyv0YoxN3LMrKB9zuhp1bV+LgewxuVwHdl4Gz5oePoUFrr9EfgJlGlDeXCU1b/yhdnXeuRvAnz8HN8Q==} + + '@shikijs/types@3.8.1': + resolution: {integrity: sha512-5C39Q8/8r1I26suLh+5TPk1DTrbY/kn3IdWA5HdizR0FhlhD05zx5nKCqhzSfDHH3p4S0ZefxWd77DLV+8FhGg==} + + '@shikijs/vscode-textmate@10.0.2': + resolution: {integrity: 
sha512-83yeghZ2xxin3Nj8z1NMd/NCuca+gsYXswywDy5bHvwlWL8tpTQmzGeUuHd9FC3E/SBEMvzJRwWEOz5gGes9Qg==} + + '@sindresorhus/merge-streams@2.3.0': + resolution: {integrity: sha512-LtoMMhxAlorcGhmFYI+LhPgbPZCkgP6ra1YL604EeF6U98pLlQ3iWIGMdWSC+vWmPBWBNgmDBAhnAobLROJmwg==} + engines: {node: '>=18'} + + '@standard-schema/spec@1.0.0': + resolution: {integrity: sha512-m2bOd0f2RT9k8QJx1JN85cZYyH1RqFBdlwtkSlf4tBDYLCiiZnv1fIIwacK6cqwXavOydf0NPToMQgpKq+dVlA==} + + '@stoplight/better-ajv-errors@1.0.3': + resolution: {integrity: sha512-0p9uXkuB22qGdNfy3VeEhxkU5uwvp/KrBTAbrLBURv6ilxIVwanKwjMc41lQfIVgPGcOkmLbTolfFrSsueu7zA==} + engines: {node: ^12.20 || >= 14.13} + peerDependencies: + ajv: '>=8' + + '@stoplight/json-ref-readers@1.2.2': + resolution: {integrity: sha512-nty0tHUq2f1IKuFYsLM4CXLZGHdMn+X/IwEUIpeSOXt0QjMUbL0Em57iJUDzz+2MkWG83smIigNZ3fauGjqgdQ==} + engines: {node: '>=8.3.0'} + + '@stoplight/json-ref-resolver@3.1.6': + resolution: {integrity: sha512-YNcWv3R3n3U6iQYBsFOiWSuRGE5su1tJSiX6pAPRVk7dP0L7lqCteXGzuVRQ0gMZqUl8v1P0+fAKxF6PLo9B5A==} + engines: {node: '>=8.3.0'} + + '@stoplight/json@3.21.7': + resolution: {integrity: sha512-xcJXgKFqv/uCEgtGlPxy3tPA+4I+ZI4vAuMJ885+ThkTHFVkC+0Fm58lA9NlsyjnkpxFh4YiQWpH+KefHdbA0A==} + engines: {node: '>=8.3.0'} + + '@stoplight/ordered-object-literal@1.0.5': + resolution: {integrity: sha512-COTiuCU5bgMUtbIFBuyyh2/yVVzlr5Om0v5utQDgBCuQUOPgU1DwoffkTfg4UBQOvByi5foF4w4T+H9CoRe5wg==} + engines: {node: '>=8'} + + '@stoplight/path@1.3.2': + resolution: {integrity: sha512-lyIc6JUlUA8Ve5ELywPC8I2Sdnh1zc1zmbYgVarhXIp9YeAB0ReeqmGEOWNtlHkbP2DAA1AL65Wfn2ncjK/jtQ==} + engines: {node: '>=8'} + + '@stoplight/spectral-core@1.20.0': + resolution: {integrity: sha512-5hBP81nCC1zn1hJXL/uxPNRKNcB+/pEIHgCjPRpl/w/qy9yC9ver04tw1W0l/PMiv0UeB5dYgozXVQ4j5a6QQQ==} + engines: {node: ^16.20 || ^18.18 || >= 20.17} + + '@stoplight/spectral-formats@1.8.2': + resolution: {integrity: sha512-c06HB+rOKfe7tuxg0IdKDEA5XnjL2vrn/m/OVIIxtINtBzphZrOgtRn7epQ5bQF5SWp84Ue7UJWaGgDwVngMFw==} + 
engines: {node: ^16.20 || ^18.18 || >= 20.17} + + '@stoplight/spectral-functions@1.10.1': + resolution: {integrity: sha512-obu8ZfoHxELOapfGsCJixKZXZcffjg+lSoNuttpmUFuDzVLT3VmH8QkPXfOGOL5Pz80BR35ClNAToDkdnYIURg==} + engines: {node: ^16.20 || ^18.18 || >= 20.17} + + '@stoplight/spectral-parsers@1.0.5': + resolution: {integrity: sha512-ANDTp2IHWGvsQDAY85/jQi9ZrF4mRrA5bciNHX+PUxPr4DwS6iv4h+FVWJMVwcEYdpyoIdyL+SRmHdJfQEPmwQ==} + engines: {node: ^16.20 || ^18.18 || >= 20.17} + + '@stoplight/spectral-ref-resolver@1.0.5': + resolution: {integrity: sha512-gj3TieX5a9zMW29z3mBlAtDOCgN3GEc1VgZnCVlr5irmR4Qi5LuECuFItAq4pTn5Zu+sW5bqutsCH7D4PkpyAA==} + engines: {node: ^16.20 || ^18.18 || >= 20.17} + + '@stoplight/spectral-rulesets@1.22.0': + resolution: {integrity: sha512-l2EY2jiKKLsvnPfGy+pXC0LeGsbJzcQP5G/AojHgf+cwN//VYxW1Wvv4WKFx/CLmLxc42mJYF2juwWofjWYNIQ==} + engines: {node: ^16.20 || ^18.18 || >= 20.17} + + '@stoplight/spectral-runtime@1.1.4': + resolution: {integrity: sha512-YHbhX3dqW0do6DhiPSgSGQzr6yQLlWybhKwWx0cqxjMwxej3TqLv3BXMfIUYFKKUqIwH4Q2mV8rrMM8qD2N0rQ==} + engines: {node: ^16.20 || ^18.18 || >= 20.17} + + '@stoplight/types@13.20.0': + resolution: {integrity: sha512-2FNTv05If7ib79VPDA/r9eUet76jewXFH2y2K5vuge6SXbRHtWBhcaRmu+6QpF4/WRNoJj5XYRSwLGXDxysBGA==} + engines: {node: ^12.20 || >=14.13} + + '@stoplight/types@13.6.0': + resolution: {integrity: sha512-dzyuzvUjv3m1wmhPfq82lCVYGcXG0xUYgqnWfCq3PCVR4BKFhjdkHrnJ+jIDoMKvXb05AZP/ObQF6+NpDo29IQ==} + engines: {node: ^12.20 || >=14.13} + + '@stoplight/types@14.1.1': + resolution: {integrity: sha512-/kjtr+0t0tjKr+heVfviO9FrU/uGLc+QNX3fHJc19xsCNYqU7lVhaXxDmEID9BZTjG+/r9pK9xP/xU02XGg65g==} + engines: {node: ^12.20 || >=14.13} + + '@stoplight/yaml-ast-parser@0.0.50': + resolution: {integrity: sha512-Pb6M8TDO9DtSVla9yXSTAxmo9GVEouq5P40DWXdOie69bXogZTkgvopCq+yEvTMA0F6PEvdJmbtTV3ccIp11VQ==} + + '@stoplight/yaml@4.3.0': + resolution: {integrity: 
sha512-JZlVFE6/dYpP9tQmV0/ADfn32L9uFarHWxfcRhReKUnljz1ZiUM5zpX+PH8h5CJs6lao3TuFqnPm9IJJCEkE2w==} + engines: {node: '>=10.8'} + + '@storybook/addon-a11y@9.0.18': + resolution: {integrity: sha512-msbsTI9TmePQ5ElVclLi7ns5WaAntouJFaj9ElNugFWME21k68RiyXnioDjDfEoi/+y8tthQNNqjsHoX/Ev0Og==} + peerDependencies: + storybook: ^9.0.18 + + '@storybook/addon-docs@9.0.18': + resolution: {integrity: sha512-1mLhaRDx8s1JAF51o56OmwMnIsg4BOQJ8cn+4wbMjh14pDFALrovlFl/BpAXnV1VaZqHjCB4ZWuP+y5CwXEpeQ==} + peerDependencies: + storybook: ^9.0.18 + + '@storybook/addon-docs@9.1.5': + resolution: {integrity: sha512-q1j5RRElxFSnHOh60eS3dS2TAyAHzcQeH/2B9UXo6MUHu7HmhNpw3qt2YibIw0zEogHCvZhLNx6TNzSy+7wRUw==} + peerDependencies: + storybook: ^9.1.5 + + '@storybook/addon-onboarding@9.0.18': + resolution: {integrity: sha512-A079BfJ3g3wYOtAuq9cPf2l6JHo+6UzEw1A2AbSNBBNP4hKfXpHcLadIVwuyOxuKjDUWzY5f4dJa3hCMurHXGQ==} + peerDependencies: + storybook: ^9.0.18 + + '@storybook/addon-vitest@9.0.18': + resolution: {integrity: sha512-uPLh9H7kRho+raxyIBCm8Ymd3j0VPuWIQ1HSAkdx8itmNafNqs4HE67Z8Cfl259YzdWU/j5BhZqoiT62BCbIDw==} + peerDependencies: + '@vitest/browser': ^3.0.0 + '@vitest/runner': ^3.0.0 + storybook: ^9.0.18 + vitest: ^3.0.0 + peerDependenciesMeta: + '@vitest/browser': + optional: true + '@vitest/runner': + optional: true + vitest: + optional: true + + '@storybook/builder-vite@9.0.18': + resolution: {integrity: sha512-lfbrozA6UPVizDrgbPEe04WMtxIraESwUkmwW3+Lxh8rKEUj5cXngcrJUW+meQNNaggdZZWEqeEtweuaLIR+Hg==} + peerDependencies: + storybook: ^9.0.18 + vite: ^5.0.0 || ^6.0.0 || ^7.0.0 + + '@storybook/builder-vite@9.1.5': + resolution: {integrity: sha512-sgt/9+Yl/5O7Bj5hdbHfadN8e/e4CNiDZKDcbLOMpOjKKoqF8vm19I1QocWIAiKjTOhF+4E9v9LddjtAGnfqHQ==} + peerDependencies: + storybook: ^9.1.5 + vite: ^5.0.0 || ^6.0.0 || ^7.0.0 + + '@storybook/csf-plugin@9.0.18': + resolution: {integrity: sha512-MQ3WwXnMua5sX0uYyuO7dC5WOWuJCLqf8CsOn3zQ2ptNoH6hD7DFx5ZOa1uD6VxIuJ3LkA+YqfSRBncomJoRnA==} + peerDependencies: + storybook: 
^9.0.18 + + '@storybook/csf-plugin@9.1.5': + resolution: {integrity: sha512-PmHuF+j11Z7BxAI2/4wQYn0gH1d67gNvycyR+EWgp4P/AWam9wFbuI/T1R45CRQTV2/VrfGdts/tFrvo5kXWig==} + peerDependencies: + storybook: ^9.1.5 + + '@storybook/global@5.0.0': + resolution: {integrity: sha512-FcOqPAXACP0I3oJ/ws6/rrPT9WGhu915Cg8D02a9YxLo0DE9zI+a9A5gRGvmQ09fiWPukqI8ZAEoQEdWUKMQdQ==} + + '@storybook/icons@1.4.0': + resolution: {integrity: sha512-Td73IeJxOyalzvjQL+JXx72jlIYHgs+REaHiREOqfpo3A2AYYG71AUbcv+lg7mEDIweKVCxsMQ0UKo634c8XeA==} + engines: {node: '>=14.0.0'} + peerDependencies: + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta + react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta + + '@storybook/react-dom-shim@9.0.18': + resolution: {integrity: sha512-qGR/d9x9qWRRxITaBVQkMnb73kwOm+N8fkbZRxc7U4lxupXRvkMIDh247nn71SYVBnvbh6//AL7P6ghiPWZYjA==} + peerDependencies: + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta + react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta + storybook: ^9.0.18 + + '@storybook/react-dom-shim@9.1.5': + resolution: {integrity: sha512-blSq9uzSYnfgEYPHYKgM5O14n8hbXNiXx2GiVJyDSg8QPNicbsBg+lCb1TC7/USfV26pNZr/lGNNKGkcCEN6Gw==} + peerDependencies: + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta + react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta + storybook: ^9.1.5 + + '@storybook/react-vite@9.0.18': + resolution: {integrity: sha512-dHzUoeY0/S35TvSYxCkPuBlNQZx4Zj9QDhAZ0qdv+nSll++uPgqSe2y2vF+2p+XVYhjDn+YX5LORv00YtuQezg==} + engines: {node: '>=20.0.0'} + peerDependencies: + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta + react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta + storybook: ^9.0.18 + vite: ^5.0.0 || ^6.0.0 || ^7.0.0 + + '@storybook/react-vite@9.1.5': + resolution: {integrity: sha512-OYbkHHNCrn8MNPd+4KxMjcSR4M/YHa84h8sWDUHhKRTRtZFmj8i/QDW3E8tGx2BRLxXw3dTYe9J5UYBhJDDxFA==} + engines: {node: '>=20.0.0'} + peerDependencies: + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta + react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || 
^19.0.0-beta + storybook: ^9.1.5 + vite: ^5.0.0 || ^6.0.0 || ^7.0.0 + + '@storybook/react@9.0.18': + resolution: {integrity: sha512-CCH6Vj/O6I07PrhCHxc1pvCWYMfZhRzK7CVHAtrBP9xxnYA7OoXhM2wymuDogml5HW1BKtyVMeQ3oWZXFNgDXQ==} + engines: {node: '>=20.0.0'} + peerDependencies: + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta + react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta + storybook: ^9.0.18 + typescript: '>= 4.9.x' + peerDependenciesMeta: + typescript: + optional: true + + '@storybook/react@9.1.5': + resolution: {integrity: sha512-fBVP7Go09gzpImtaMcZ2DipLEWdWeTmz7BrACr3Z8uCyKcoH8/d1Wv0JgIiBo1UKDh5ZgYx5pLafaPNqmVAepg==} + engines: {node: '>=20.0.0'} + peerDependencies: + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta + react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta + storybook: ^9.1.5 + typescript: '>= 4.9.x' + peerDependenciesMeta: + typescript: + optional: true + + '@swc/core-darwin-arm64@1.13.2': + resolution: {integrity: sha512-44p7ivuLSGFJ15Vly4ivLJjg3ARo4879LtEBAabcHhSZygpmkP8eyjyWxrH3OxkY1eRZSIJe8yRZPFw4kPXFPw==} + engines: {node: '>=10'} + cpu: [arm64] + os: [darwin] + + '@swc/core-darwin-x64@1.13.2': + resolution: {integrity: sha512-Lb9EZi7X2XDAVmuUlBm2UvVAgSCbD3qKqDCxSI4jEOddzVOpNCnyZ/xEampdngUIyDDhhJLYU9duC+Mcsv5Y+A==} + engines: {node: '>=10'} + cpu: [x64] + os: [darwin] + + '@swc/core-linux-arm-gnueabihf@1.13.2': + resolution: {integrity: sha512-9TDe/92ee1x57x+0OqL1huG4BeljVx0nWW4QOOxp8CCK67Rpc/HHl2wciJ0Kl9Dxf2NvpNtkPvqj9+BUmM9WVA==} + engines: {node: '>=10'} + cpu: [arm] + os: [linux] + + '@swc/core-linux-arm64-gnu@1.13.2': + resolution: {integrity: sha512-KJUSl56DBk7AWMAIEcU83zl5mg3vlQYhLELhjwRFkGFMvghQvdqQ3zFOYa4TexKA7noBZa3C8fb24rI5sw9Exg==} + engines: {node: '>=10'} + cpu: [arm64] + os: [linux] + + '@swc/core-linux-arm64-musl@1.13.2': + resolution: {integrity: sha512-teU27iG1oyWpNh9CzcGQ48ClDRt/RCem7mYO7ehd2FY102UeTws2+OzLESS1TS1tEZipq/5xwx3FzbVgiolCiQ==} + engines: {node: '>=10'} + cpu: [arm64] + os: [linux] + + 
'@swc/core-linux-x64-gnu@1.13.2': + resolution: {integrity: sha512-dRPsyPyqpLD0HMRCRpYALIh4kdOir8pPg4AhNQZLehKowigRd30RcLXGNVZcc31Ua8CiPI4QSgjOIxK+EQe4LQ==} + engines: {node: '>=10'} + cpu: [x64] + os: [linux] + + '@swc/core-linux-x64-musl@1.13.2': + resolution: {integrity: sha512-CCxETW+KkYEQDqz1SYC15YIWYheqFC+PJVOW76Maa/8yu8Biw+HTAcblKf2isrlUtK8RvrQN94v3UXkC2NzCEw==} + engines: {node: '>=10'} + cpu: [x64] + os: [linux] + + '@swc/core-win32-arm64-msvc@1.13.2': + resolution: {integrity: sha512-Wv/QTA6PjyRLlmKcN6AmSI4jwSMRl0VTLGs57PHTqYRwwfwd7y4s2fIPJVBNbAlXd795dOEP6d/bGSQSyhOX3A==} + engines: {node: '>=10'} + cpu: [arm64] + os: [win32] + + '@swc/core-win32-ia32-msvc@1.13.2': + resolution: {integrity: sha512-PuCdtNynEkUNbUXX/wsyUC+t4mamIU5y00lT5vJcAvco3/r16Iaxl5UCzhXYaWZSNVZMzPp9qN8NlSL8M5pPxw==} + engines: {node: '>=10'} + cpu: [ia32] + os: [win32] + + '@swc/core-win32-x64-msvc@1.13.2': + resolution: {integrity: sha512-qlmMkFZJus8cYuBURx1a3YAG2G7IW44i+FEYV5/32ylKkzGNAr9tDJSA53XNnNXkAB5EXSPsOz7bn5C3JlEtdQ==} + engines: {node: '>=10'} + cpu: [x64] + os: [win32] + + '@swc/core@1.13.2': + resolution: {integrity: sha512-YWqn+0IKXDhqVLKoac4v2tV6hJqB/wOh8/Br8zjqeqBkKa77Qb0Kw2i7LOFzjFNZbZaPH6AlMGlBwNrxaauaAg==} + engines: {node: '>=10'} + peerDependencies: + '@swc/helpers': '>=0.5.17' + peerDependenciesMeta: + '@swc/helpers': + optional: true + + '@swc/counter@0.1.3': + resolution: {integrity: sha512-e2BR4lsJkkRlKZ/qCHPw9ZaSxc0MVUd7gtbtaB7aMvHeJVYe8sOB8DBZkP2DtISHGSku9sCK6T6cnY0CtXrOCQ==} + + '@swc/helpers@0.5.17': + resolution: {integrity: sha512-5IKx/Y13RsYd+sauPb2x+U/xZikHjolzfuDgTAl/Tdf3Q8rslRvC19NKDLgAJQ6wsqADk10ntlv08nPFw/gO/A==} + + '@swc/types@0.1.23': + resolution: {integrity: sha512-u1iIVZV9Q0jxY+yM2vw/hZGDNudsN85bBpTqzAQ9rzkxW9D+e3aEM4Han+ow518gSewkXgjmEK0BD79ZcNVgPw==} + + '@tailwindcss/container-queries@0.1.1': + resolution: {integrity: sha512-p18dswChx6WnTSaJCSGx6lTmrGzNNvm2FtXmiO6AuA1V4U5REyoqwmT6kgAsIMdjo07QdAfYXHJ4hnMtfHzWgA==} + peerDependencies: + 
tailwindcss: '>=3.2.0' + + '@tailwindcss/node@4.1.11': + resolution: {integrity: sha512-yzhzuGRmv5QyU9qLNg4GTlYI6STedBWRE7NjxP45CsFYYq9taI0zJXZBMqIC/c8fViNLhmrbpSFS57EoxUmD6Q==} + + '@tailwindcss/oxide-android-arm64@4.1.11': + resolution: {integrity: sha512-3IfFuATVRUMZZprEIx9OGDjG3Ou3jG4xQzNTvjDoKmU9JdmoCohQJ83MYd0GPnQIu89YoJqvMM0G3uqLRFtetg==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [android] + + '@tailwindcss/oxide-darwin-arm64@4.1.11': + resolution: {integrity: sha512-ESgStEOEsyg8J5YcMb1xl8WFOXfeBmrhAwGsFxxB2CxY9evy63+AtpbDLAyRkJnxLy2WsD1qF13E97uQyP1lfQ==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [darwin] + + '@tailwindcss/oxide-darwin-x64@4.1.11': + resolution: {integrity: sha512-EgnK8kRchgmgzG6jE10UQNaH9Mwi2n+yw1jWmof9Vyg2lpKNX2ioe7CJdf9M5f8V9uaQxInenZkOxnTVL3fhAw==} + engines: {node: '>= 10'} + cpu: [x64] + os: [darwin] + + '@tailwindcss/oxide-freebsd-x64@4.1.11': + resolution: {integrity: sha512-xdqKtbpHs7pQhIKmqVpxStnY1skuNh4CtbcyOHeX1YBE0hArj2romsFGb6yUmzkq/6M24nkxDqU8GYrKrz+UcA==} + engines: {node: '>= 10'} + cpu: [x64] + os: [freebsd] + + '@tailwindcss/oxide-linux-arm-gnueabihf@4.1.11': + resolution: {integrity: sha512-ryHQK2eyDYYMwB5wZL46uoxz2zzDZsFBwfjssgB7pzytAeCCa6glsiJGjhTEddq/4OsIjsLNMAiMlHNYnkEEeg==} + engines: {node: '>= 10'} + cpu: [arm] + os: [linux] + + '@tailwindcss/oxide-linux-arm64-gnu@4.1.11': + resolution: {integrity: sha512-mYwqheq4BXF83j/w75ewkPJmPZIqqP1nhoghS9D57CLjsh3Nfq0m4ftTotRYtGnZd3eCztgbSPJ9QhfC91gDZQ==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [linux] + + '@tailwindcss/oxide-linux-arm64-musl@4.1.11': + resolution: {integrity: sha512-m/NVRFNGlEHJrNVk3O6I9ggVuNjXHIPoD6bqay/pubtYC9QIdAMpS+cswZQPBLvVvEF6GtSNONbDkZrjWZXYNQ==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [linux] + + '@tailwindcss/oxide-linux-x64-gnu@4.1.11': + resolution: {integrity: sha512-YW6sblI7xukSD2TdbbaeQVDysIm/UPJtObHJHKxDEcW2exAtY47j52f8jZXkqE1krdnkhCMGqP3dbniu1Te2Fg==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + 
'@tailwindcss/oxide-linux-x64-musl@4.1.11': + resolution: {integrity: sha512-e3C/RRhGunWYNC3aSF7exsQkdXzQ/M+aYuZHKnw4U7KQwTJotnWsGOIVih0s2qQzmEzOFIJ3+xt7iq67K/p56Q==} + engines: {node: '>= 10'} + cpu: [x64] + os: [linux] + + '@tailwindcss/oxide-wasm32-wasi@4.1.11': + resolution: {integrity: sha512-Xo1+/GU0JEN/C/dvcammKHzeM6NqKovG+6921MR6oadee5XPBaKOumrJCXvopJ/Qb5TH7LX/UAywbqrP4lax0g==} + engines: {node: '>=14.0.0'} + cpu: [wasm32] + bundledDependencies: + - '@napi-rs/wasm-runtime' + - '@emnapi/core' + - '@emnapi/runtime' + - '@tybys/wasm-util' + - '@emnapi/wasi-threads' + - tslib + + '@tailwindcss/oxide-win32-arm64-msvc@4.1.11': + resolution: {integrity: sha512-UgKYx5PwEKrac3GPNPf6HVMNhUIGuUh4wlDFR2jYYdkX6pL/rn73zTq/4pzUm8fOjAn5L8zDeHp9iXmUGOXZ+w==} + engines: {node: '>= 10'} + cpu: [arm64] + os: [win32] + + '@tailwindcss/oxide-win32-x64-msvc@4.1.11': + resolution: {integrity: sha512-YfHoggn1j0LK7wR82TOucWc5LDCguHnoS879idHekmmiR7g9HUtMw9MI0NHatS28u/Xlkfi9w5RJWgz2Dl+5Qg==} + engines: {node: '>= 10'} + cpu: [x64] + os: [win32] + + '@tailwindcss/oxide@4.1.11': + resolution: {integrity: sha512-Q69XzrtAhuyfHo+5/HMgr1lAiPP/G40OMFAnws7xcFEYqcypZmdW8eGXaOUIeOl1dzPJBPENXgbjsOyhg2nkrg==} + engines: {node: '>= 10'} + + '@tailwindcss/postcss@4.1.11': + resolution: {integrity: sha512-q/EAIIpF6WpLhKEuQSEVMZNMIY8KhWoAemZ9eylNAih9jxMGAYPPWBn3I9QL/2jZ+e7OEz/tZkX5HwbBR4HohA==} + + '@tailwindcss/typography@0.5.16': + resolution: {integrity: sha512-0wDLwCVF5V3x3b1SGXPCDcdsbDHMBe+lkFzBRaHeLvNi+nrrnZ1lA18u+OTWO8iSWU2GxUOCvlXtDuqftc1oiA==} + peerDependencies: + tailwindcss: '>=3.0.0 || insiders || >=4.0.0-alpha.20 || >=4.0.0-beta.1' + + '@tailwindcss/vite@4.1.11': + resolution: {integrity: sha512-RHYhrR3hku0MJFRV+fN2gNbDNEh3dwKvY8XJvTxCSXeMOsCRSr+uKvDWQcbizrHgjML6ZmTE5OwMrl5wKcujCw==} + peerDependencies: + vite: ^5.2.0 || ^6 || ^7 + + '@tanstack/history@1.129.7': + resolution: {integrity: sha512-I3YTkbe4RZQN54Qw4+IUhOjqG2DdbG2+EBWuQfew4MEk0eddLYAQVa50BZVww4/D2eh5I9vEk2Fd1Y0Wty7pug==} + 
engines: {node: '>=12'} + + '@tanstack/query-core@5.83.0': + resolution: {integrity: sha512-0M8dA+amXUkyz5cVUm/B+zSk3xkQAcuXuz5/Q/LveT4ots2rBpPTZOzd7yJa2Utsf8D2Upl5KyjhHRY+9lB/XA==} + + '@tanstack/react-query@5.83.0': + resolution: {integrity: sha512-/XGYhZ3foc5H0VM2jLSD/NyBRIOK4q9kfeml4+0x2DlL6xVuAcVEW+hTlTapAmejObg0i3eNqhkr2dT+eciwoQ==} + peerDependencies: + react: ^18 || ^19 + + '@tanstack/react-router-devtools@1.131.26': + resolution: {integrity: sha512-QdDF2t3ILZLqblBYDWQXpQ8QsHzo2ZJcWhaeQEdAkMZ0w0mlfKdZKOGigA21KvDbyTOgkfuQBj+DlkiQPqKYMA==} + engines: {node: '>=12'} + peerDependencies: + '@tanstack/react-router': ^1.131.26 + react: '>=18.0.0 || >=19.0.0' + react-dom: '>=18.0.0 || >=19.0.0' + + '@tanstack/react-router@1.129.8': + resolution: {integrity: sha512-d5mfM+67h3wq7aHkLjRKXD1ddbzx1YuxaEbNvW45jjZXMgaikZSVfJrZBiUWXE/nhV1sTdbMQ48JcPagvGPmYQ==} + engines: {node: '>=12'} + peerDependencies: + react: '>=18.0.0 || >=19.0.0' + react-dom: '>=18.0.0 || >=19.0.0' + + '@tanstack/react-store@0.7.3': + resolution: {integrity: sha512-3Dnqtbw9P2P0gw8uUM8WP2fFfg8XMDSZCTsywRPZe/XqqYW8PGkXKZTvP0AHkE4mpqP9Y43GpOg9vwO44azu6Q==} + peerDependencies: + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 + react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 + + '@tanstack/react-table@8.21.3': + resolution: {integrity: sha512-5nNMTSETP4ykGegmVkhjcS8tTLW6Vl4axfEGQN3v0zdHYbK4UfoqfPChclTrJ4EoK9QynqAu9oUf8VEmrpZ5Ww==} + engines: {node: '>=12'} + peerDependencies: + react: '>=16.8' + react-dom: '>=16.8' + + '@tanstack/react-virtual@3.13.12': + resolution: {integrity: sha512-Gd13QdxPSukP8ZrkbgS2RwoZseTTbQPLnQEn7HY/rqtM+8Zt95f7xKC7N0EsKs7aoz0WzZ+fditZux+F8EzYxA==} + peerDependencies: + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 + react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 + + '@tanstack/router-core@1.129.8': + resolution: {integrity: sha512-Izqf5q8TzJv0DJURynitJioPJT3dPAefrzHi2wlY/Q5+7nEG41SkjYMotTX2Q9i/Pjl91lW8gERCHpksszRdRw==} + engines: {node: '>=12'} + + 
'@tanstack/router-devtools-core@1.131.26': + resolution: {integrity: sha512-TGHmRDQpYphuRbDH+jJp418vQuIydzITaUx7MiPk5U1ZZ+2O/GxcF/ycXmyYR0IHTpSky35I83X3bKTiv+thyw==} + engines: {node: '>=12'} + peerDependencies: + '@tanstack/router-core': ^1.131.26 + csstype: ^3.0.10 + solid-js: '>=1.9.5' + tiny-invariant: ^1.3.3 + peerDependenciesMeta: + csstype: + optional: true + + '@tanstack/router-generator@1.129.8': + resolution: {integrity: sha512-i4QTtJeRq3jdRTuUXHKcmPNm6STS0jLJNTKEdeUCIzuVBiiP53oujMOd84e5ARP83k2IB2XcMHekTSzDlWD2fg==} + engines: {node: '>=12'} + + '@tanstack/router-plugin@1.129.8': + resolution: {integrity: sha512-DdO6el2slgBO2mIqIGdGyHCzsbQLsTNxsgbNz9ZY9y324iP4G+p3iEYopHWgzLKM2DKinMs9F7AxjLow4V3klQ==} + engines: {node: '>=12'} + peerDependencies: + '@rsbuild/core': '>=1.0.2' + '@tanstack/react-router': ^1.129.8 + vite: '>=5.0.0 || >=6.0.0' + vite-plugin-solid: ^2.11.2 + webpack: '>=5.92.0' + peerDependenciesMeta: + '@rsbuild/core': + optional: true + '@tanstack/react-router': + optional: true + vite: + optional: true + vite-plugin-solid: + optional: true + webpack: + optional: true + + '@tanstack/router-utils@1.129.7': + resolution: {integrity: sha512-I2OyQF5U6sxHJApXKCUmCncTHKcpj4681FwyxpYg5QYOatHcn/zVMl7Rj4h36fu8/Lo2ZRLxUMd5kmXgp5Pb/A==} + engines: {node: '>=12'} + + '@tanstack/store@0.7.2': + resolution: {integrity: sha512-RP80Z30BYiPX2Pyo0Nyw4s1SJFH2jyM6f9i3HfX4pA+gm5jsnYryscdq2aIQLnL4TaGuQMO+zXmN9nh1Qck+Pg==} + + '@tanstack/table-core@8.21.3': + resolution: {integrity: sha512-ldZXEhOBb8Is7xLs01fR3YEc3DERiz5silj8tnGkFZytt1abEvl/GhUmCE0PMLaMPTa3Jk4HbKmRlHmu+gCftg==} + engines: {node: '>=12'} + + '@tanstack/virtual-core@3.13.12': + resolution: {integrity: sha512-1YBOJfRHV4sXUmWsFSf5rQor4Ss82G8dQWLRbnk3GA4jeP8hQt1hxXh0tmflpC0dz3VgEv/1+qwPyLeWkQuPFA==} + + '@tanstack/virtual-file-routes@1.129.7': + resolution: {integrity: sha512-a+MxoAXG+Sq94Jp67OtveKOp2vQq75AWdVI8DRt6w19B0NEqpfm784FTLbVp/qdR1wmxCOmKAvElGSIiBOx5OQ==} + engines: {node: '>=12'} + + 
'@testing-library/dom@10.4.1': + resolution: {integrity: sha512-o4PXJQidqJl82ckFaXUeoAW+XysPLauYI43Abki5hABd853iMhitooc6znOnczgbTYmEP6U6/y1ZyKAIsvMKGg==} + engines: {node: '>=18'} + + '@testing-library/jest-dom@6.6.3': + resolution: {integrity: sha512-IteBhl4XqYNkM54f4ejhLRJiZNqcSCoXUOG2CPK7qbD322KjQozM4kHQOfkG2oln9b9HTYqs+Sae8vBATubxxA==} + engines: {node: '>=14', npm: '>=6', yarn: '>=1'} + + '@testing-library/react@16.3.0': + resolution: {integrity: sha512-kFSyxiEDwv1WLl2fgsq6pPBbw5aWKrsY2/noi1Id0TK0UParSF62oFQFGHXIyaG4pp2tEub/Zlel+fjjZILDsw==} + engines: {node: '>=18'} + peerDependencies: + '@testing-library/dom': ^10.0.0 + '@types/react': ^18.0.0 || ^19.0.0 + '@types/react-dom': ^18.0.0 || ^19.0.0 + react: ^18.0.0 || ^19.0.0 + react-dom: ^18.0.0 || ^19.0.0 + peerDependenciesMeta: + '@types/react': + optional: true + '@types/react-dom': + optional: true + + '@testing-library/user-event@14.6.1': + resolution: {integrity: sha512-vq7fv0rnt+QTXgPxr5Hjc210p6YKq2kmdziLgnsZGgLJ9e6VAShx1pACLuRjd/AS/sr7phAR58OIIpf0LlmQNw==} + engines: {node: '>=12', npm: '>=6'} + peerDependencies: + '@testing-library/dom': '>=7.21.4' + + '@textlint/ast-node-types@15.2.0': + resolution: {integrity: sha512-nr9wEiZCNYafGZ++uWFZgPlDX3Bi7u4T2d5swpaoMvc1G2toXsBfe7UNVwXZq5dvYDbQN7vDeb3ltlKQ8JnPNQ==} + + '@textlint/linter-formatter@15.2.0': + resolution: {integrity: sha512-L+fM2OTs17hRxPCLKUdPjHce7cJp81gV9ku53FCL+cXnq5bZx0XYYkqKdtC0jnXujkQmrTYU3SYFrb4DgXqbtA==} + + '@textlint/module-interop@15.2.0': + resolution: {integrity: sha512-M3y1s2dZZH8PSHo4RUlnPOdK3qN90wmYGaEdy+il9/BQfrrift7S9R8lOfhHoPS0m9FEsnwyj3dQLkCUugPd9Q==} + + '@textlint/resolver@15.2.0': + resolution: {integrity: sha512-1UC+5bEtuoht7uu0uGofb7sX7j17Mvyst9InrRtI4XgKhh1uMZz5YFiMYpNwry1GgCZvq7Wyq1fqtEIsvYWqFw==} + + '@textlint/types@15.2.0': + resolution: {integrity: sha512-wpF+xjGJgJK2JiwUdYjuNZrbuas3KfC9VDnHKac6aBLFyrI1iXuXtuxKXQDFi5/hebACactSJOuVVbuQbdJZ1Q==} + + '@types/argparse@1.0.38': + resolution: {integrity: 
sha512-ebDJ9b0e702Yr7pWgB0jzm+CX4Srzz8RcXtLJDJB+BSccqMa36uyH/zUsSYao5+BD1ytv3k3rPYCq4mAE1hsXA==} + + '@types/aria-query@5.0.4': + resolution: {integrity: sha512-rfT93uj5s0PRL7EzccGMs3brplhcrghnDoV26NqKhCAS1hVo+WdNsPvE/yb6ilfr5hi2MEk6d5EWJTKdxg8jVw==} + + '@types/babel__core@7.20.5': + resolution: {integrity: sha512-qoQprZvz5wQFJwMDqeseRXWv3rqMvhgpbXFfVyWhbx9X47POIA6i/+dXefEmZKoAgOaTdaIgNSMqMIU61yRyzA==} + + '@types/babel__generator@7.27.0': + resolution: {integrity: sha512-ufFd2Xi92OAVPYsy+P4n7/U7e68fex0+Ee8gSG9KX7eo084CWiQ4sdxktvdl0bOPupXtVJPY19zk6EwWqUQ8lg==} + + '@types/babel__template@7.4.4': + resolution: {integrity: sha512-h/NUaSyG5EyxBIp8YRxo4RMe2/qQgvyowRwVMzhYhBCONbW8PUsg4lkFMrhgZhUe5z3L3MiLDuvyJ/CaPa2A8A==} + + '@types/babel__traverse@7.20.7': + resolution: {integrity: sha512-dkO5fhS7+/oos4ciWxyEyjWe48zmG6wbCheo/G2ZnHx4fs3EU6YC6UM8rk56gAjNJ9P3MTH2jo5jb92/K6wbng==} + + '@types/chai@5.2.2': + resolution: {integrity: sha512-8kB30R7Hwqf40JPiKhVzodJs2Qc1ZJ5zuT3uzw5Hq/dhNCl3G3l83jfpdI1e20BP348+fV7VIL/+FxaXkqBmWg==} + + '@types/command-line-args@5.2.3': + resolution: {integrity: sha512-uv0aG6R0Y8WHZLTamZwtfsDLVRnOa+n+n5rEvFWL5Na5gZ8V2Teab/duDPFzIIIhs9qizDpcavCusCLJZu62Kw==} + + '@types/command-line-usage@5.0.4': + resolution: {integrity: sha512-BwR5KP3Es/CSht0xqBcUXS3qCAUVXwpRKsV2+arxeb65atasuXG9LykC9Ab10Cw3s2raH92ZqOeILaQbsB2ACg==} + + '@types/d3-array@3.2.1': + resolution: {integrity: sha512-Y2Jn2idRrLzUfAKV2LyRImR+y4oa2AntrgID95SHJxuMUrkNXmanDSed71sRNZysveJVt1hLLemQZIady0FpEg==} + + '@types/d3-axis@3.0.6': + resolution: {integrity: sha512-pYeijfZuBd87T0hGn0FO1vQ/cgLk6E1ALJjfkC0oJ8cbwkZl3TpgS8bVBLZN+2jjGgg38epgxb2zmoGtSfvgMw==} + + '@types/d3-brush@3.0.6': + resolution: {integrity: sha512-nH60IZNNxEcrh6L1ZSMNA28rj27ut/2ZmI3r96Zd+1jrZD++zD3LsMIjWlvg4AYrHn/Pqz4CF3veCxGjtbqt7A==} + + '@types/d3-chord@3.0.6': + resolution: {integrity: sha512-LFYWWd8nwfwEmTZG9PfQxd17HbNPksHBiJHaKuY1XeqscXacsS2tyoo6OdRsjf+NQYeB6XrNL3a25E3gH69lcg==} + + '@types/d3-color@3.1.3': + 
resolution: {integrity: sha512-iO90scth9WAbmgv7ogoq57O9YpKmFBbmoEoCHDB2xMBY0+/KVrqAaCDyCE16dUspeOvIxFFRI+0sEtqDqy2b4A==} + + '@types/d3-contour@3.0.6': + resolution: {integrity: sha512-BjzLgXGnCWjUSYGfH1cpdo41/hgdWETu4YxpezoztawmqsvCeep+8QGfiY6YbDvfgHz/DkjeIkkZVJavB4a3rg==} + + '@types/d3-delaunay@6.0.4': + resolution: {integrity: sha512-ZMaSKu4THYCU6sV64Lhg6qjf1orxBthaC161plr5KuPHo3CNm8DTHiLw/5Eq2b6TsNP0W0iJrUOFscY6Q450Hw==} + + '@types/d3-dispatch@3.0.6': + resolution: {integrity: sha512-4fvZhzMeeuBJYZXRXrRIQnvUYfyXwYmLsdiN7XXmVNQKKw1cM8a5WdID0g1hVFZDqT9ZqZEY5pD44p24VS7iZQ==} + + '@types/d3-drag@3.0.7': + resolution: {integrity: sha512-HE3jVKlzU9AaMazNufooRJ5ZpWmLIoc90A37WU2JMmeq28w1FQqCZswHZ3xR+SuxYftzHq6WU6KJHvqxKzTxxQ==} + + '@types/d3-dsv@3.0.7': + resolution: {integrity: sha512-n6QBF9/+XASqcKK6waudgL0pf/S5XHPPI8APyMLLUHd8NqouBGLsU8MgtO7NINGtPBtk9Kko/W4ea0oAspwh9g==} + + '@types/d3-ease@3.0.2': + resolution: {integrity: sha512-NcV1JjO5oDzoK26oMzbILE6HW7uVXOHLQvHshBUW4UMdZGfiY6v5BeQwh9a9tCzv+CeefZQHJt5SRgK154RtiA==} + + '@types/d3-fetch@3.0.7': + resolution: {integrity: sha512-fTAfNmxSb9SOWNB9IoG5c8Hg6R+AzUHDRlsXsDZsNp6sxAEOP0tkP3gKkNSO/qmHPoBFTxNrjDprVHDQDvo5aA==} + + '@types/d3-force@3.0.10': + resolution: {integrity: sha512-ZYeSaCF3p73RdOKcjj+swRlZfnYpK1EbaDiYICEEp5Q6sUiqFaFQ9qgoshp5CzIyyb/yD09kD9o2zEltCexlgw==} + + '@types/d3-format@3.0.4': + resolution: {integrity: sha512-fALi2aI6shfg7vM5KiR1wNJnZ7r6UuggVqtDA+xiEdPZQwy/trcQaHnwShLuLdta2rTymCNpxYTiMZX/e09F4g==} + + '@types/d3-geo@3.1.0': + resolution: {integrity: sha512-856sckF0oP/diXtS4jNsiQw/UuK5fQG8l/a9VVLeSouf1/PPbBE1i1W852zVwKwYCBkFJJB7nCFTbk6UMEXBOQ==} + + '@types/d3-hierarchy@3.1.7': + resolution: {integrity: sha512-tJFtNoYBtRtkNysX1Xq4sxtjK8YgoWUNpIiUee0/jHGRwqvzYxkq0hGVbbOGSz+JgFxxRu4K8nb3YpG3CMARtg==} + + '@types/d3-interpolate@3.0.4': + resolution: {integrity: sha512-mgLPETlrpVV1YRJIglr4Ez47g7Yxjl1lj7YKsiMCb27VJH9W8NVM6Bb9d8kkpG/uAQS5AmbA48q2IAolKKo1MA==} + + '@types/d3-path@3.1.1': + 
resolution: {integrity: sha512-VMZBYyQvbGmWyWVea0EHs/BwLgxc+MKi1zLDCONksozI4YJMcTt8ZEuIR4Sb1MMTE8MMW49v0IwI5+b7RmfWlg==} + + '@types/d3-polygon@3.0.2': + resolution: {integrity: sha512-ZuWOtMaHCkN9xoeEMr1ubW2nGWsp4nIql+OPQRstu4ypeZ+zk3YKqQT0CXVe/PYqrKpZAi+J9mTs05TKwjXSRA==} + + '@types/d3-quadtree@3.0.6': + resolution: {integrity: sha512-oUzyO1/Zm6rsxKRHA1vH0NEDG58HrT5icx/azi9MF1TWdtttWl0UIUsjEQBBh+SIkrpd21ZjEv7ptxWys1ncsg==} + + '@types/d3-random@3.0.3': + resolution: {integrity: sha512-Imagg1vJ3y76Y2ea0871wpabqp613+8/r0mCLEBfdtqC7xMSfj9idOnmBYyMoULfHePJyxMAw3nWhJxzc+LFwQ==} + + '@types/d3-scale-chromatic@3.1.0': + resolution: {integrity: sha512-iWMJgwkK7yTRmWqRB5plb1kadXyQ5Sj8V/zYlFGMUBbIPKQScw+Dku9cAAMgJG+z5GYDoMjWGLVOvjghDEFnKQ==} + + '@types/d3-scale@4.0.9': + resolution: {integrity: sha512-dLmtwB8zkAeO/juAMfnV+sItKjlsw2lKdZVVy6LRr0cBmegxSABiLEpGVmSJJ8O08i4+sGR6qQtb6WtuwJdvVw==} + + '@types/d3-selection@3.0.11': + resolution: {integrity: sha512-bhAXu23DJWsrI45xafYpkQ4NtcKMwWnAC/vKrd2l+nxMFuvOT3XMYTIj2opv8vq8AO5Yh7Qac/nSeP/3zjTK0w==} + + '@types/d3-shape@3.1.7': + resolution: {integrity: sha512-VLvUQ33C+3J+8p+Daf+nYSOsjB4GXp19/S/aGo60m9h1v6XaxjiT82lKVWJCfzhtuZ3yD7i/TPeC/fuKLLOSmg==} + + '@types/d3-time-format@4.0.3': + resolution: {integrity: sha512-5xg9rC+wWL8kdDj153qZcsJ0FWiFt0J5RB6LYUNZjwSnesfblqrI/bJ1wBdJ8OQfncgbJG5+2F+qfqnqyzYxyg==} + + '@types/d3-time@3.0.4': + resolution: {integrity: sha512-yuzZug1nkAAaBlBBikKZTgzCeA+k1uy4ZFwWANOfKw5z5LRhV0gNA7gNkKm7HoK+HRN0wX3EkxGk0fpbWhmB7g==} + + '@types/d3-timer@3.0.2': + resolution: {integrity: sha512-Ps3T8E8dZDam6fUyNiMkekK3XUsaUEik+idO9/YjPtfj2qruF8tFBXS7XhtE4iIXBLxhmLjP3SXpLhVf21I9Lw==} + + '@types/d3-transition@3.0.9': + resolution: {integrity: sha512-uZS5shfxzO3rGlu0cC3bjmMFKsXv+SmZZcgp0KD22ts4uGXp5EVYGzu/0YdwZeKmddhcAccYtREJKkPfXkZuCg==} + + '@types/d3-zoom@3.0.8': + resolution: {integrity: sha512-iqMC4/YlFCSlO8+2Ii1GGGliCAY4XdeG748w5vQUbevlbDu0zSjH/+jojorQVBK/se0j6DUFNPBGSqD3YWYnDw==} + + '@types/d3@7.4.3': + 
resolution: {integrity: sha512-lZXZ9ckh5R8uiFVt8ogUNf+pIrK4EsWrx2Np75WvF/eTpJ0FMHNhjXk8CKEx/+gpHbNQyJWehbFaTvqmHWB3ww==} + + '@types/dagre@0.7.53': + resolution: {integrity: sha512-f4gkWqzPZvYmKhOsDnhq/R8mO4UMcKdxZo+i5SCkOU1wvGeHJeUXGIHeE9pnwGyPMDof1Vx5ZQo4nxpeg2TTVQ==} + + '@types/debug@4.1.12': + resolution: {integrity: sha512-vIChWdVG3LG1SMxEvI/AK+FWJthlrqlTu7fbrlywTkkaONwk/UAGaULXRlf8vkzFBLVm0zkMdCquhL5aOjhXPQ==} + + '@types/deep-eql@4.0.2': + resolution: {integrity: sha512-c9h9dVVMigMPc4bwTvC5dxqtqJZwQPePsWjPlpSOnojbor6pGqdk541lfA7AqFQr5pB1BRdq0juY9db81BwyFw==} + + '@types/doctrine@0.0.9': + resolution: {integrity: sha512-eOIHzCUSH7SMfonMG1LsC2f8vxBFtho6NGBznK41R84YzPuvSBzrhEps33IsQiOW9+VL6NQ9DbjQJznk/S4uRA==} + + '@types/es-aggregate-error@1.0.6': + resolution: {integrity: sha512-qJ7LIFp06h1QE1aVxbVd+zJP2wdaugYXYfd6JxsyRMrYHaxb6itXPogW2tz+ylUJ1n1b+JF1PHyYCfYHm0dvUg==} + + '@types/eslint-scope@3.7.7': + resolution: {integrity: sha512-MzMFlSLBqNF2gcHWO0G1vP/YQyfvrxZ0bF+u7mzUdZ1/xK4A4sru+nraZz5i3iEIk1l1uyicaDVTB4QbbEkAYg==} + + '@types/eslint@9.6.1': + resolution: {integrity: sha512-FXx2pKgId/WyYo2jXw63kk7/+TY7u7AziEJxJAnSFzHlqTAS3Ync6SvgYAN/k4/PQpnnVuzoMuVnByKK2qp0ag==} + + '@types/estree-jsx@1.0.5': + resolution: {integrity: sha512-52CcUVNFyfb1A2ALocQw/Dd1BQFNmSdkuC3BkZ6iqhdMfQz7JWOFRuJFloOzjk+6WijU56m9oKXFAXc7o3Towg==} + + '@types/estree@1.0.8': + resolution: {integrity: sha512-dWHzHa2WqEXI/O1E9OjrocMTKJl2mSrEolh1Iomrv6U+JuNwaHXsXx9bLu5gG7BUWFIN0skIQJQ/L1rIex4X6w==} + + '@types/fs-extra@11.0.4': + resolution: {integrity: sha512-yTbItCNreRooED33qjunPthRcSjERP1r4MqCZc7wv0u2sUkzTFp45tgUfS5+r7FrZPdmCCNflLhVSP/o+SemsQ==} + + '@types/geojson@7946.0.16': + resolution: {integrity: sha512-6C8nqWur3j98U6+lXDfTUWIfgvZU+EumvpHKcYjujKH7woYyLj2sUmff0tRhrqM7BohUw7Pz3ZB1jj2gW9Fvmg==} + + '@types/hast@3.0.4': + resolution: {integrity: sha512-WPs+bbQw5aCj+x6laNGWLH3wviHtoCv/P3+otBhbOhJgG8qtpdAMlTCxLtsTWA7LH1Oh/bFCHsBn0TPS5m30EQ==} + + '@types/istanbul-lib-coverage@2.0.6': + 
resolution: {integrity: sha512-2QF/t/auWm0lsy8XtKVPG19v3sSOQlJe/YHZgfjb/KBBHOGSV+J2q/S671rcq9uTBrLAXmZpqJiaQbMT+zNU1w==} + + '@types/json-schema@7.0.15': + resolution: {integrity: sha512-5+fP8P8MFNC+AyZCDxrB2pkZFPGzqQWUzpSeuuVLvm8VMcorNYavBqoFcxK8bQz4Qsbn4oUEEem4wDLfcysGHA==} + + '@types/jsonfile@6.1.4': + resolution: {integrity: sha512-D5qGUYwjvnNNextdU59/+fI+spnwtTFmyQP0h+PfIOSkNfpU6AOICUOkm4i0OnSk+NyjdPJrxCDro0sJsWlRpQ==} + + '@types/mdast@4.0.4': + resolution: {integrity: sha512-kGaNbPh1k7AFzgpud/gMdvIm5xuECykRR+JnWKQno9TAXVa6WIVCGTPvYGekIDL4uwCZQSYbUxNBSb1aUo79oA==} + + '@types/mdx@2.0.13': + resolution: {integrity: sha512-+OWZQfAYyio6YkJb3HLxDrvnx6SWWDbC0zVPfBRzUk0/nqoDyf6dNxQi3eArPe8rJ473nobTMQ/8Zk+LxJ+Yuw==} + + '@types/mocha@10.0.10': + resolution: {integrity: sha512-xPyYSz1cMPnJQhl0CLMH68j3gprKZaTjG3s5Vi+fDgx+uhG9NOXwbVt52eFS8ECyXhyKcjDLCBEqBExKuiZb7Q==} + + '@types/ms@2.1.0': + resolution: {integrity: sha512-GsCCIZDE/p3i96vtEqx+7dBUGXrc7zeSK3wwPHIaRThS+9OhWIXRqzs4d6k1SVU8g91DrNRWxWUGhp5KXQb2VA==} + + '@types/node@20.11.25': + resolution: {integrity: sha512-TBHyJxk2b7HceLVGFcpAUjsa5zIdsPWlR6XHfyGzd0SFu+/NFgQgMAl96MSDZgQDvJAvV6BKsFOrt6zIL09JDw==} + + '@types/node@20.19.9': + resolution: {integrity: sha512-cuVNgarYWZqxRJDQHEB58GEONhOK79QVR/qYx4S7kcUObQvUwvFnYxJuuHUKm2aieN9X3yZB4LZsuYNU1Qphsw==} + + '@types/node@24.1.0': + resolution: {integrity: sha512-ut5FthK5moxFKH2T1CUOC6ctR67rQRvvHdFLCD2Ql6KXmMuCrjsSsRI9UsLCm9M18BMwClv4pn327UvB7eeO1w==} + + '@types/normalize-package-data@2.4.4': + resolution: {integrity: sha512-37i+OaWTh9qeK4LSHPsyRC7NahnGotNuZvjLSgcPzblpHB3rrCJxAOgI5gCdKm7coonsaX1Of0ILiTcnZjbfxA==} + + '@types/pluralize@0.0.33': + resolution: {integrity: sha512-JOqsl+ZoCpP4e8TDke9W79FDcSgPAR0l6pixx2JHkhnRjvShyYiAYw2LVsnA7K08Y6DeOnaU6ujmENO4os/cYg==} + + '@types/prop-types@15.7.15': + resolution: {integrity: sha512-F6bEyamV9jKGAFBEmlQnesRPGOQqS2+Uwi0Em15xenOxHaf2hv6L8YCVn3rPdPJOiJfPiCnLIRyvwVaqMY3MIw==} + + '@types/react-dom@18.3.7': + resolution: 
{integrity: sha512-MEe3UeoENYVFXzoXEWsvcpg6ZvlrFNlOQ7EOsvhI3CfAXwzPfO8Qwuxd40nepsYKqyyVQnTdEfv68q91yLcKrQ==} + peerDependencies: + '@types/react': ^18.0.0 + + '@types/react@18.3.23': + resolution: {integrity: sha512-/LDXMQh55EzZQ0uVAZmKKhfENivEvWz6E+EYzh+/MCjMhNsotd+ZHhBGIjFDTi6+fz0OhQQQLbTgdQIxxCsC0w==} + + '@types/resolve@1.20.6': + resolution: {integrity: sha512-A4STmOXPhMUtHH+S6ymgE2GiBSMqf4oTvcQZMcHzokuTLVYzXTB8ttjcgxOVaAp2lGwEdzZ0J+cRbbeevQj1UQ==} + + '@types/sarif@2.1.7': + resolution: {integrity: sha512-kRz0VEkJqWLf1LLVN4pT1cg1Z9wAuvI6L97V3m2f5B76Tg8d413ddvLBPTEHAZJlnn4XSvu0FkZtViCQGVyrXQ==} + + '@types/shell-quote@1.7.5': + resolution: {integrity: sha512-+UE8GAGRPbJVQDdxi16dgadcBfQ+KG2vgZhV1+3A1XmHbmwcdwhCUwIdy+d3pAGrbvgRoVSjeI9vOWyq376Yzw==} + + '@types/unist@2.0.11': + resolution: {integrity: sha512-CmBKiL6NNo/OqgmMn95Fk9Whlp2mtvIv+KNpQKN2F4SjvrEesubTRWGYSg+BnWZOnlCaSTU1sMpsBOzgbYhnsA==} + + '@types/unist@3.0.3': + resolution: {integrity: sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q==} + + '@types/urijs@1.19.25': + resolution: {integrity: sha512-XOfUup9r3Y06nFAZh3WvO0rBU4OtlfPB/vgxpjg+NRdGU6CN6djdc6OEiH+PcqHCY6eFLo9Ista73uarf4gnBg==} + + '@types/vscode@1.96.0': + resolution: {integrity: sha512-qvZbSZo+K4ZYmmDuaodMbAa67Pl6VDQzLKFka6rq+3WUTY4Kro7Bwoi0CuZLO/wema0ygcmpwow7zZfPJTs5jg==} + + '@typescript-eslint/eslint-plugin@8.38.0': + resolution: {integrity: sha512-CPoznzpuAnIOl4nhj4tRr4gIPj5AfKgkiJmGQDaq+fQnRJTYlcBjbX3wbciGmpoPf8DREufuPRe1tNMZnGdanA==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + peerDependencies: + '@typescript-eslint/parser': ^8.38.0 + eslint: ^8.57.0 || ^9.0.0 + typescript: '>=4.8.4 <5.9.0' + + '@typescript-eslint/parser@8.38.0': + resolution: {integrity: sha512-Zhy8HCvBUEfBECzIl1PKqF4p11+d0aUJS1GeUiuqK9WmOug8YCmC4h4bjyBvMyAMI9sbRczmrYL5lKg/YMbrcQ==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + peerDependencies: + eslint: ^8.57.0 || ^9.0.0 + typescript: '>=4.8.4 <5.9.0' + + 
'@typescript-eslint/project-service@8.38.0': + resolution: {integrity: sha512-dbK7Jvqcb8c9QfH01YB6pORpqX1mn5gDZc9n63Ak/+jD67oWXn3Gs0M6vddAN+eDXBCS5EmNWzbSxsn9SzFWWg==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + peerDependencies: + typescript: '>=4.8.4 <5.9.0' + + '@typescript-eslint/scope-manager@8.38.0': + resolution: {integrity: sha512-WJw3AVlFFcdT9Ri1xs/lg8LwDqgekWXWhH3iAF+1ZM+QPd7oxQ6jvtW/JPwzAScxitILUIFs0/AnQ/UWHzbATQ==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + + '@typescript-eslint/tsconfig-utils@8.38.0': + resolution: {integrity: sha512-Lum9RtSE3EroKk/bYns+sPOodqb2Fv50XOl/gMviMKNvanETUuUcC9ObRbzrJ4VSd2JalPqgSAavwrPiPvnAiQ==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + peerDependencies: + typescript: '>=4.8.4 <5.9.0' + + '@typescript-eslint/type-utils@8.38.0': + resolution: {integrity: sha512-c7jAvGEZVf0ao2z+nnz8BUaHZD09Agbh+DY7qvBQqLiz8uJzRgVPj5YvOh8I8uEiH8oIUGIfHzMwUcGVco/SJg==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + peerDependencies: + eslint: ^8.57.0 || ^9.0.0 + typescript: '>=4.8.4 <5.9.0' + + '@typescript-eslint/types@8.38.0': + resolution: {integrity: sha512-wzkUfX3plUqij4YwWaJyqhiPE5UCRVlFpKn1oCRn2O1bJ592XxWJj8ROQ3JD5MYXLORW84063z3tZTb/cs4Tyw==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + + '@typescript-eslint/typescript-estree@8.38.0': + resolution: {integrity: sha512-fooELKcAKzxux6fA6pxOflpNS0jc+nOQEEOipXFNjSlBS6fqrJOVY/whSn70SScHrcJ2LDsxWrneFoWYSVfqhQ==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + peerDependencies: + typescript: '>=4.8.4 <5.9.0' + + '@typescript-eslint/utils@8.38.0': + resolution: {integrity: sha512-hHcMA86Hgt+ijJlrD8fX0j1j8w4C92zue/8LOPAFioIno+W0+L7KqE8QZKCcPGc/92Vs9x36w/4MPTJhqXdyvg==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + peerDependencies: + eslint: ^8.57.0 || ^9.0.0 + typescript: '>=4.8.4 <5.9.0' + + '@typescript-eslint/visitor-keys@8.38.0': + resolution: {integrity: 
sha512-pWrTcoFNWuwHlA9CvlfSsGWs14JxfN1TH25zM5L7o0pRLhsoZkDnTsXfQRJBEWJoV5DL0jf+Z+sxiud+K0mq1g==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + + '@typespec/ts-http-runtime@0.3.0': + resolution: {integrity: sha512-sOx1PKSuFwnIl7z4RN0Ls7N9AQawmR9r66eI5rFCzLDIs8HTIYrIpH9QjYWoX0lkgGrkLxXhi4QnK7MizPRrIg==} + engines: {node: '>=20.0.0'} + + '@uidotdev/usehooks@2.4.1': + resolution: {integrity: sha512-1I+RwWyS+kdv3Mv0Vmc+p0dPYH0DTRAo04HLyXReYBL9AeseDWUJyi4THuksBJcu9F0Pih69Ak150VDnqbVnXg==} + engines: {node: '>=16'} + peerDependencies: + react: '>=18.0.0' + react-dom: '>=18.0.0' + + '@uiw/codemirror-extensions-basic-setup@4.24.1': + resolution: {integrity: sha512-o1m1a8eUS3fWERMbDFvN8t8sZUFPgDKNemmlQ5Ot2vKm+Ax84lKP1dhEFgkiOaZ1bDHk4T5h6SjHuTghrJHKww==} + peerDependencies: + '@codemirror/autocomplete': '>=6.0.0' + '@codemirror/commands': '>=6.0.0' + '@codemirror/language': '>=6.0.0' + '@codemirror/lint': '>=6.0.0' + '@codemirror/search': '>=6.0.0' + '@codemirror/state': '>=6.0.0' + '@codemirror/view': '>=6.0.0' + + '@uiw/react-codemirror@4.24.1': + resolution: {integrity: sha512-BivF4NLqbuBQK5gPVhSkOARi9nPXw8X5r25EnInPeY+I9l1dfEX8O9V6+0xHTlGHyUo0cNfGEF9t1KHEicUfJw==} + peerDependencies: + '@babel/runtime': '>=7.11.0' + '@codemirror/state': '>=6.0.0' + '@codemirror/theme-one-dark': '>=6.0.0' + '@codemirror/view': '>=6.0.0' + codemirror: '>=6.0.0' + react: '>=16.8.0' + react-dom: '>=16.8.0' + + '@ungap/structured-clone@1.3.0': + resolution: {integrity: sha512-WmoN8qaIAo7WTYWbAZuG8PYEhn5fkz7dZrqTBZ7dtt//lL2Gwms1IcnQ5yHqjDfX8Ft5j4YzDM23f87zBfDe9g==} + + '@vitejs/plugin-react-swc@3.11.0': + resolution: {integrity: sha512-YTJCGFdNMHCMfjODYtxRNVAYmTWQ1Lb8PulP/2/f/oEEtglw8oKxKIZmmRkyXrVrHfsKOaVkAc3NT9/dMutO5w==} + peerDependencies: + vite: ^4 || ^5 || ^6 || ^7 + + '@vitejs/plugin-react@4.7.0': + resolution: {integrity: sha512-gUu9hwfWvvEDBBmgtAowQCojwZmJ5mcLn3aufeCsitijs3+f2NsrPtlAWIR6OPiqljl96GVCUbLe0HyqIpVaoA==} + engines: {node: ^14.18.0 || >=16.0.0} + peerDependencies: + 
vite: ^4.2.0 || ^5.0.0 || ^6.0.0 || ^7.0.0 + + '@vitest/browser@3.2.3': + resolution: {integrity: sha512-5HpUb0ixGF8JWSAjb/P1x/VPuTYUkL4pL0+YO6DJiuvQgqJN3PREaUEcXwfXjU4nBc37EahfpRbAwdE9pHs9lQ==} + peerDependencies: + playwright: '*' + safaridriver: '*' + vitest: 3.2.3 + webdriverio: ^7.0.0 || ^8.0.0 || ^9.0.0 + peerDependenciesMeta: + playwright: + optional: true + safaridriver: + optional: true + webdriverio: + optional: true + + '@vitest/browser@3.2.4': + resolution: {integrity: sha512-tJxiPrWmzH8a+w9nLKlQMzAKX/7VjFs50MWgcAj7p9XQ7AQ9/35fByFYptgPELyLw+0aixTnC4pUWV+APcZ/kw==} + peerDependencies: + playwright: '*' + safaridriver: '*' + vitest: 3.2.4 + webdriverio: ^7.0.0 || ^8.0.0 || ^9.0.0 + peerDependenciesMeta: + playwright: + optional: true + safaridriver: + optional: true + webdriverio: + optional: true + + '@vitest/coverage-v8@3.2.3': + resolution: {integrity: sha512-D1QKzngg8PcDoCE8FHSZhREDuEy+zcKmMiMafYse41RZpBE5EDJyKOTdqK3RQfsV2S2nyKor5KCs8PyPRFqKPg==} + peerDependencies: + '@vitest/browser': 3.2.3 + vitest: 3.2.3 + peerDependenciesMeta: + '@vitest/browser': + optional: true + + '@vitest/expect@3.2.4': + resolution: {integrity: sha512-Io0yyORnB6sikFlt8QW5K7slY4OjqNX9jmJQ02QDda8lyM6B5oNgVWoSoKPac8/kgnCUzuHQKrSLtu/uOqqrig==} + + '@vitest/mocker@3.2.3': + resolution: {integrity: sha512-cP6fIun+Zx8he4rbWvi+Oya6goKQDZK+Yq4hhlggwQBbrlOQ4qtZ+G4nxB6ZnzI9lyIb+JnvyiJnPC2AGbKSPA==} + peerDependencies: + msw: ^2.4.9 + vite: ^5.0.0 || ^6.0.0 || ^7.0.0-0 + peerDependenciesMeta: + msw: + optional: true + vite: + optional: true + + '@vitest/mocker@3.2.4': + resolution: {integrity: sha512-46ryTE9RZO/rfDd7pEqFl7etuyzekzEhUbTW3BvmeO/BcCMEgq59BKhek3dXDWgAj4oMK6OZi+vRr1wPW6qjEQ==} + peerDependencies: + msw: ^2.4.9 + vite: ^5.0.0 || ^6.0.0 || ^7.0.0-0 + peerDependenciesMeta: + msw: + optional: true + vite: + optional: true + + '@vitest/pretty-format@3.2.3': + resolution: {integrity: sha512-yFglXGkr9hW/yEXngO+IKMhP0jxyFw2/qys/CK4fFUZnSltD+MU7dVYGrH8rvPcK/O6feXQA+EU33gjaBBbAng==} 
+ + '@vitest/pretty-format@3.2.4': + resolution: {integrity: sha512-IVNZik8IVRJRTr9fxlitMKeJeXFFFN0JaB9PHPGQ8NKQbGpfjlTx9zO4RefN8gp7eqjNy8nyK3NZmBzOPeIxtA==} + + '@vitest/runner@3.2.4': + resolution: {integrity: sha512-oukfKT9Mk41LreEW09vt45f8wx7DordoWUZMYdY/cyAk7w5TWkTRCNZYF7sX7n2wB7jyGAl74OxgwhPgKaqDMQ==} + + '@vitest/snapshot@3.2.4': + resolution: {integrity: sha512-dEYtS7qQP2CjU27QBC5oUOxLE/v5eLkGqPE0ZKEIDGMs4vKWe7IjgLOeauHsR0D5YuuycGRO5oSRXnwnmA78fQ==} + + '@vitest/spy@3.2.3': + resolution: {integrity: sha512-JHu9Wl+7bf6FEejTCREy+DmgWe+rQKbK+y32C/k5f4TBIAlijhJbRBIRIOCEpVevgRsCQR2iHRUH2/qKVM/plw==} + + '@vitest/spy@3.2.4': + resolution: {integrity: sha512-vAfasCOe6AIK70iP5UD11Ac4siNUNJ9i/9PZ3NKx07sG6sUxeag1LWdNrMWeKKYBLlzuK+Gn65Yd5nyL6ds+nw==} + + '@vitest/ui@3.2.4': + resolution: {integrity: sha512-hGISOaP18plkzbWEcP/QvtRW1xDXF2+96HbEX6byqQhAUbiS5oH6/9JwW+QsQCIYON2bI6QZBF+2PvOmrRZ9wA==} + peerDependencies: + vitest: 3.2.4 + + '@vitest/utils@3.2.3': + resolution: {integrity: sha512-4zFBCU5Pf+4Z6v+rwnZ1HU1yzOKKvDkMXZrymE2PBlbjKJRlrOxbvpfPSvJTGRIwGoahaOGvp+kbCoxifhzJ1Q==} + + '@vitest/utils@3.2.4': + resolution: {integrity: sha512-fB2V0JFrQSMsCo9HiSq3Ezpdv4iYaXRG1Sx8edX3MwxfyNn83mKiGzOcH+Fkxt4MHxr3y42fQi1oeAInqgX2QA==} + + '@volar/language-core@2.4.23': + resolution: {integrity: sha512-hEEd5ET/oSmBC6pi1j6NaNYRWoAiDhINbT8rmwtINugR39loROSlufGdYMF9TaKGfz+ViGs1Idi3mAhnuPcoGQ==} + + '@volar/source-map@2.4.23': + resolution: {integrity: sha512-Z1Uc8IB57Lm6k7q6KIDu/p+JWtf3xsXJqAX/5r18hYOTpJyBn0KXUR8oTJ4WFYOcDzWC9n3IflGgHowx6U6z9Q==} + + '@volar/typescript@2.4.23': + resolution: {integrity: sha512-lAB5zJghWxVPqfcStmAP1ZqQacMpe90UrP5RJ3arDyrhy4aCUQqmxPPLB2PWDKugvylmO41ljK7vZ+t6INMTag==} + + '@vscode/python-extension@1.0.5': + resolution: {integrity: sha512-uYhXUrL/gn92mfqhjAwH2+yGOpjloBxj9ekoL4BhUsKcyJMpEg6WlNf3S3si+5x9zlbHHe7FYQNjZEbz1ymI9Q==} + engines: {node: '>=16.17.1', vscode: ^1.78.0} + + '@vscode/test-cli@0.0.10': + resolution: {integrity: 
sha512-B0mMH4ia+MOOtwNiLi79XhA+MLmUItIC8FckEuKrVAVriIuSWjt7vv4+bF8qVFiNFe4QRfzPaIZk39FZGWEwHA==} + engines: {node: '>=18'} + hasBin: true + + '@vscode/test-electron@2.5.2': + resolution: {integrity: sha512-8ukpxv4wYe0iWMRQU18jhzJOHkeGKbnw7xWRX3Zw1WJA4cEKbHcmmLPdPrPtL6rhDcrlCZN+xKRpv09n4gRHYg==} + engines: {node: '>=16'} + + '@vscode/vsce-sign-alpine-arm64@2.0.5': + resolution: {integrity: sha512-XVmnF40APwRPXSLYA28Ye+qWxB25KhSVpF2eZVtVOs6g7fkpOxsVnpRU1Bz2xG4ySI79IRuapDJoAQFkoOgfdQ==} + cpu: [arm64] + os: [alpine] + + '@vscode/vsce-sign-alpine-x64@2.0.5': + resolution: {integrity: sha512-JuxY3xcquRsOezKq6PEHwCgd1rh1GnhyH6urVEWUzWn1c1PC4EOoyffMD+zLZtFuZF5qR1I0+cqDRNKyPvpK7Q==} + cpu: [x64] + os: [alpine] + + '@vscode/vsce-sign-darwin-arm64@2.0.5': + resolution: {integrity: sha512-z2Q62bk0ptADFz8a0vtPvnm6vxpyP3hIEYMU+i1AWz263Pj8Mc38cm/4sjzxu+LIsAfhe9HzvYNS49lV+KsatQ==} + cpu: [arm64] + os: [darwin] + + '@vscode/vsce-sign-darwin-x64@2.0.5': + resolution: {integrity: sha512-ma9JDC7FJ16SuPXlLKkvOD2qLsmW/cKfqK4zzM2iJE1PbckF3BlR08lYqHV89gmuoTpYB55+z8Y5Fz4wEJBVDA==} + cpu: [x64] + os: [darwin] + + '@vscode/vsce-sign-linux-arm64@2.0.5': + resolution: {integrity: sha512-Hr1o0veBymg9SmkCqYnfaiUnes5YK6k/lKFA5MhNmiEN5fNqxyPUCdRZMFs3Ajtx2OFW4q3KuYVRwGA7jdLo7Q==} + cpu: [arm64] + os: [linux] + + '@vscode/vsce-sign-linux-arm@2.0.5': + resolution: {integrity: sha512-cdCwtLGmvC1QVrkIsyzv01+o9eR+wodMJUZ9Ak3owhcGxPRB53/WvrDHAFYA6i8Oy232nuen1YqWeEohqBuSzA==} + cpu: [arm] + os: [linux] + + '@vscode/vsce-sign-linux-x64@2.0.5': + resolution: {integrity: sha512-XLT0gfGMcxk6CMRLDkgqEPTyG8Oa0OFe1tPv2RVbphSOjFWJwZgK3TYWx39i/7gqpDHlax0AP6cgMygNJrA6zg==} + cpu: [x64] + os: [linux] + + '@vscode/vsce-sign-win32-arm64@2.0.5': + resolution: {integrity: sha512-hco8eaoTcvtmuPhavyCZhrk5QIcLiyAUhEso87ApAWDllG7djIrWiOCtqn48k4pHz+L8oCQlE0nwNHfcYcxOPw==} + cpu: [arm64] + os: [win32] + + '@vscode/vsce-sign-win32-x64@2.0.5': + resolution: {integrity: 
sha512-1ixKFGM2FwM+6kQS2ojfY3aAelICxjiCzeg4nTHpkeU1Tfs4RC+lVLrgq5NwcBC7ZLr6UfY3Ct3D6suPeOf7BQ==} + cpu: [x64] + os: [win32] + + '@vscode/vsce-sign@2.0.6': + resolution: {integrity: sha512-j9Ashk+uOWCDHYDxgGsqzKq5FXW9b9MW7QqOIYZ8IYpneJclWTBeHZz2DJCSKQgo+JAqNcaRRE1hzIx0dswqAw==} + + '@vscode/vsce@3.6.0': + resolution: {integrity: sha512-u2ZoMfymRNJb14aHNawnXJtXHLXDVKc1oKZaH4VELKT/9iWKRVgtQOdwxCgtwSxJoqYvuK4hGlBWQJ05wxADhg==} + engines: {node: '>= 20'} + hasBin: true + + '@vue/compiler-core@3.5.18': + resolution: {integrity: sha512-3slwjQrrV1TO8MoXgy3aynDQ7lslj5UqDxuHnrzHtpON5CBinhWjJETciPngpin/T3OuW3tXUf86tEurusnztw==} + + '@vue/compiler-dom@3.5.18': + resolution: {integrity: sha512-RMbU6NTU70++B1JyVJbNbeFkK+A+Q7y9XKE2EM4NLGm2WFR8x9MbAtWxPPLdm0wUkuZv9trpwfSlL6tjdIa1+A==} + + '@vue/compiler-vue2@2.7.16': + resolution: {integrity: sha512-qYC3Psj9S/mfu9uVi5WvNZIzq+xnXMhOwbTFKKDD7b1lhpnn71jXSFdTQ+WsIEk0ONCd7VV2IMm7ONl6tbQ86A==} + + '@vue/language-core@2.2.0': + resolution: {integrity: sha512-O1ZZFaaBGkKbsRfnVH1ifOK1/1BUkyK+3SQsfnh6PmMmD4qJcTU8godCeA96jjDRTL6zgnK7YzCHfaUlH2r0Mw==} + peerDependencies: + typescript: '*' + peerDependenciesMeta: + typescript: + optional: true + + '@vue/shared@3.5.18': + resolution: {integrity: sha512-cZy8Dq+uuIXbxCZpuLd2GJdeSO/lIzIspC2WtkqIpje5QyFbvLaI5wZtdUjLHjGZrlVX6GilejatWwVYYRc8tA==} + + '@webassemblyjs/ast@1.14.1': + resolution: {integrity: sha512-nuBEDgQfm1ccRp/8bCQrx1frohyufl4JlbMMZ4P1wpeOfDhF6FQkxZJ1b/e+PLwr6X1Nhw6OLme5usuBWYBvuQ==} + + '@webassemblyjs/floating-point-hex-parser@1.13.2': + resolution: {integrity: sha512-6oXyTOzbKxGH4steLbLNOu71Oj+C8Lg34n6CqRvqfS2O71BxY6ByfMDRhBytzknj9yGUPVJ1qIKhRlAwO1AovA==} + + '@webassemblyjs/helper-api-error@1.13.2': + resolution: {integrity: sha512-U56GMYxy4ZQCbDZd6JuvvNV/WFildOjsaWD3Tzzvmw/mas3cXzRJPMjP83JqEsgSbyrmaGjBfDtV7KDXV9UzFQ==} + + '@webassemblyjs/helper-buffer@1.14.1': + resolution: {integrity: 
sha512-jyH7wtcHiKssDtFPRB+iQdxlDf96m0E39yb0k5uJVhFGleZFoNw1c4aeIcVUPPbXUVJ94wwnMOAqUHyzoEPVMA==} + + '@webassemblyjs/helper-numbers@1.13.2': + resolution: {integrity: sha512-FE8aCmS5Q6eQYcV3gI35O4J789wlQA+7JrqTTpJqn5emA4U2hvwJmvFRC0HODS+3Ye6WioDklgd6scJ3+PLnEA==} + + '@webassemblyjs/helper-wasm-bytecode@1.13.2': + resolution: {integrity: sha512-3QbLKy93F0EAIXLh0ogEVR6rOubA9AoZ+WRYhNbFyuB70j3dRdwH9g+qXhLAO0kiYGlg3TxDV+I4rQTr/YNXkA==} + + '@webassemblyjs/helper-wasm-section@1.14.1': + resolution: {integrity: sha512-ds5mXEqTJ6oxRoqjhWDU83OgzAYjwsCV8Lo/N+oRsNDmx/ZDpqalmrtgOMkHwxsG0iI//3BwWAErYRHtgn0dZw==} + + '@webassemblyjs/ieee754@1.13.2': + resolution: {integrity: sha512-4LtOzh58S/5lX4ITKxnAK2USuNEvpdVV9AlgGQb8rJDHaLeHciwG4zlGr0j/SNWlr7x3vO1lDEsuePvtcDNCkw==} + + '@webassemblyjs/leb128@1.13.2': + resolution: {integrity: sha512-Lde1oNoIdzVzdkNEAWZ1dZ5orIbff80YPdHx20mrHwHrVNNTjNr8E3xz9BdpcGqRQbAEa+fkrCb+fRFTl/6sQw==} + + '@webassemblyjs/utf8@1.13.2': + resolution: {integrity: sha512-3NQWGjKTASY1xV5m7Hr0iPeXD9+RDobLll3T9d2AO+g3my8xy5peVyjSag4I50mR1bBSN/Ct12lo+R9tJk0NZQ==} + + '@webassemblyjs/wasm-edit@1.14.1': + resolution: {integrity: sha512-RNJUIQH/J8iA/1NzlE4N7KtyZNHi3w7at7hDjvRNm5rcUXa00z1vRz3glZoULfJ5mpvYhLybmVcwcjGrC1pRrQ==} + + '@webassemblyjs/wasm-gen@1.14.1': + resolution: {integrity: sha512-AmomSIjP8ZbfGQhumkNvgC33AY7qtMCXnN6bL2u2Js4gVCg8fp735aEiMSBbDR7UQIj90n4wKAFUSEd0QN2Ukg==} + + '@webassemblyjs/wasm-opt@1.14.1': + resolution: {integrity: sha512-PTcKLUNvBqnY2U6E5bdOQcSM+oVP/PmrDY9NzowJjislEjwP/C4an2303MCVS2Mg9d3AJpIGdUFIQQWbPds0Sw==} + + '@webassemblyjs/wasm-parser@1.14.1': + resolution: {integrity: sha512-JLBl+KZ0R5qB7mCnud/yyX08jWFw5MsoalJ1pQ4EdFlgj9VdXKGuENGsiCIjegI1W7p91rUlcB/LB5yRJKNTcQ==} + + '@webassemblyjs/wast-printer@1.14.1': + resolution: {integrity: sha512-kPSSXE6De1XOR820C90RIo2ogvZG+c3KiHzqUoO/F34Y2shGzesfqv7o57xrxovZJH/MetF5UjroJ/R/3isoiw==} + + '@xtuc/ieee754@1.2.0': + resolution: {integrity: 
sha512-DX8nKgqcGwsc0eJSqYt5lwP4DH5FlHnmuWWBRy7X0NcaGR0ZtuyeESgMwTYVEtxmsNGY+qit4QYT/MIYTOTPeA==} + + '@xtuc/long@4.2.2': + resolution: {integrity: sha512-NuHqBY1PB/D8xU6s/thBgOAiAP7HOYDQ32+BFZILJ8ivkUkAHQnWfn6WhL79Owj1qmUnoN/YPhktdIoucipkAQ==} + + '@xyflow/react@12.8.4': + resolution: {integrity: sha512-bqUu4T5QSHiCFPkoH+b+LROKwQJdLvcjhGbNW9c1dLafCBRjmH1IYz0zPE+lRDXCtQ9kRyFxz3tG19+8VORJ1w==} + peerDependencies: + react: '>=17' + react-dom: '>=17' + + '@xyflow/system@0.0.68': + resolution: {integrity: sha512-QDG2wxIG4qX+uF8yzm1ULVZrcXX3MxPBoxv7O52FWsX87qIImOqifUhfa/TwsvLdzn7ic2DDBH1uI8TKbdNTYA==} + + abort-controller@3.0.0: + resolution: {integrity: sha512-h8lQ8tacZYnR3vNQTgibj+tODHI5/+l06Au2Pcriv/Gmet0eaj4TwWH41sO9wnHDiQsEj19q0drzdWdeAHtweg==} + engines: {node: '>=6.5'} + + acorn-jsx@5.3.2: + resolution: {integrity: sha512-rq9s+JNhf0IChjtDXxllJ7g41oZk5SlXtp0LHwyA5cejwn7vKmKp4pPri6YEePv2PU65sAsegbXtIinmDFDXgQ==} + peerDependencies: + acorn: ^6.0.0 || ^7.0.0 || ^8.0.0 + + acorn@8.15.0: + resolution: {integrity: sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==} + engines: {node: '>=0.4.0'} + hasBin: true + + agent-base@7.1.4: + resolution: {integrity: sha512-MnA+YT8fwfJPgBx3m60MNqakm30XOkyIoH1y6huTQvC0PwZG7ki8NacLBcrPbNoo8vEZy7Jpuk7+jMO+CUovTQ==} + engines: {node: '>= 14'} + + ajv-draft-04@1.0.0: + resolution: {integrity: sha512-mv00Te6nmYbRp5DCwclxtt7yV/joXJPGS7nM+97GdxvuttCOfgI3K4U25zboyeX0O+myI8ERluxQe5wljMmVIw==} + peerDependencies: + ajv: ^8.5.0 + peerDependenciesMeta: + ajv: + optional: true + + ajv-errors@3.0.0: + resolution: {integrity: sha512-V3wD15YHfHz6y0KdhYFjyy9vWtEVALT9UrxfN3zqlI6dMioHnJrqOYfyPKol3oqrnCM9uwkcdCwkJ0WUcbLMTQ==} + peerDependencies: + ajv: ^8.0.1 + + ajv-formats@2.1.1: + resolution: {integrity: sha512-Wx0Kx52hxE7C18hkMEggYlEifqWZtYaRgouJor+WMdPnQyEK13vgEWyVNup7SoeeoLMsr4kf5h6dOW11I15MUA==} + peerDependencies: + ajv: ^8.0.0 + peerDependenciesMeta: + ajv: + optional: true + + ajv-formats@3.0.1: + 
resolution: {integrity: sha512-8iUql50EUR+uUcdRQ3HDqa6EVyo3docL8g5WJ3FNcWmu62IbkGUue/pEyLBW8VGKKucTPgqeks4fIU1DA4yowQ==} + peerDependencies: + ajv: ^8.0.0 + peerDependenciesMeta: + ajv: + optional: true + + ajv-keywords@5.1.0: + resolution: {integrity: sha512-YCS/JNFAUyr5vAuhk1DWm1CBxRHW9LbJ2ozWeemrIqpbsqKjHVxYPyi5GC0rjZIT5JxJ3virVTS8wk4i/Z+krw==} + peerDependencies: + ajv: ^8.8.2 + + ajv@6.12.6: + resolution: {integrity: sha512-j3fVLgvTo527anyYyJOGTYJbG+vnnQYvE0m5mmkc1TK+nxAppkCLMIL0aZ4dblVCNoGShhm+kzE4ZUykBoMg4g==} + + ajv@8.12.0: + resolution: {integrity: sha512-sRu1kpcO9yLtYxBKvqfTeh9KzZEwO3STyX1HT+4CaDzC6HpTGYhIhPIzj9XuKU7KYDwnaeh5hcOwjy1QuJzBPA==} + + ajv@8.13.0: + resolution: {integrity: sha512-PRA911Blj99jR5RMeTunVbNXMF6Lp4vZXnk5GQjcnUWUTsrXtekg/pnmFFI2u/I36Y/2bITGS30GZCXei6uNkA==} + + ajv@8.17.1: + resolution: {integrity: sha512-B/gBuNg5SiMTrPkC+A2+cW0RszwxYmn6VYxB/inlBStS5nx6xHIt/ehKRhIMhqusl7a8LjQoZnjCs5vhwxOQ1g==} + + alien-signals@0.4.14: + resolution: {integrity: sha512-itUAVzhczTmP2U5yX67xVpsbbOiquusbWVyA9N+sy6+r6YVbFkahXvNCeEPWEOMhwDYwbVbGHFkVL03N9I5g+Q==} + + ansi-colors@4.1.3: + resolution: {integrity: sha512-/6w/C21Pm1A7aZitlI5Ni/2J6FFQN8i1Cvz3kHABAAbw93v/NlvKdVOqz7CCWz/3iv/JplRSEEZ83XION15ovw==} + engines: {node: '>=6'} + + ansi-escapes@7.0.0: + resolution: {integrity: sha512-GdYO7a61mR0fOlAsvC9/rIHf7L96sBc6dEWzeOu+KAea5bZyQRPIpojrVoI4AXGJS/ycu/fBTdLrUkA4ODrvjw==} + engines: {node: '>=18'} + + ansi-regex@5.0.1: + resolution: {integrity: sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==} + engines: {node: '>=8'} + + ansi-regex@6.1.0: + resolution: {integrity: sha512-7HSX4QQb4CspciLpVFwyRe79O3xsIZDDLER21kERQ71oaPodF8jL725AgJMFAYbooIqolJoRLuM81SpeUkpkvA==} + engines: {node: '>=12'} + + ansi-styles@4.3.0: + resolution: {integrity: sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==} + engines: {node: '>=8'} + + ansi-styles@5.2.0: + resolution: {integrity: 
sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==} + engines: {node: '>=10'} + + ansi-styles@6.2.1: + resolution: {integrity: sha512-bN798gFfQX+viw3R7yrGWRqnrN2oRkEkUjjl4JNn4E8GxxbjtG3FbrEIIY3l8/hrwUwIeCZvi4QuOTP4MErVug==} + engines: {node: '>=12'} + + ansis@4.1.0: + resolution: {integrity: sha512-BGcItUBWSMRgOCe+SVZJ+S7yTRG0eGt9cXAHev72yuGcY23hnLA7Bky5L/xLyPINoSN95geovfBkqoTlNZYa7w==} + engines: {node: '>=14'} + + any-promise@1.3.0: + resolution: {integrity: sha512-7UvmKalWRt1wgjL1RrGxoSJW/0QZFIegpeGvZG9kjp8vrRu55XTHbwnqq2GpXm9uLbcuhxm3IqX9OB4MZR1b2A==} + + anymatch@3.1.3: + resolution: {integrity: sha512-KMReFUr0B4t+D+OBkjR3KYqvocp2XaSzO55UcB6mgQMd3KbcE+mWTyvVV7D/zsdEbNnV6acZUutkiHQXvTr1Rw==} + engines: {node: '>= 8'} + + apache-arrow@19.0.1: + resolution: {integrity: sha512-APmMLzS4qbTivLrPdQXexGM4JRr+0g62QDaobzEvip/FdQIrv2qLy0mD5Qdmw4buydtVJgbFeKR8f59I6PPGDg==} + hasBin: true + + arg@5.0.2: + resolution: {integrity: sha512-PYjyFOLKQ9y57JvQ6QLo8dAgNqswh8M1RMJYdQduT6xbWSgK36P/Z/v+p888pM69jMMfS8Xd8F6I1kQ/I9HUGg==} + + argparse@1.0.10: + resolution: {integrity: sha512-o5Roy6tNG4SL/FOkCAN6RzjiakZS25RLYFrcMttJqbdd8BWrnA+fGz57iN5Pb06pvBGvl5gQ0B48dJlslXvoTg==} + + argparse@2.0.1: + resolution: {integrity: sha512-8+9WqebbFzpX9OR+Wa6O29asIogeRMzcGtAINdpMHHyAg10f05aSFVBbcEqGf/PXw1EjAZ+q2/bEBg3DvurK3Q==} + + aria-hidden@1.2.6: + resolution: {integrity: sha512-ik3ZgC9dY/lYVVM++OISsaYDeg1tb0VtP5uL3ouh1koGOaUMDPpbFIei4JkFimWUFPn90sbMNMXQAIVOlnYKJA==} + engines: {node: '>=10'} + + aria-query@5.3.0: + resolution: {integrity: sha512-b0P0sZPKtyu8HkeRAfCq0IfURZK+SuwMjY1UXGBU27wpAiTwQAIlq56IbIO+ytk/JjS1fMR14ee5WBBfKi5J6A==} + + aria-query@5.3.2: + resolution: {integrity: sha512-COROpnaoap1E2F000S62r6A60uHZnmlvomhfyT2DlTcrY1OrBKn2UhH7qn5wTC9zMvD0AY7csdPSNwKP+7WiQw==} + engines: {node: '>= 0.4'} + + array-back@6.2.2: + resolution: {integrity: sha512-gUAZ7HPyb4SJczXAMUXMGAvI976JoK3qEx9v1FTmeYuJj0IBiaKttG1ydtGKdkfqWkIkouke7nG8ufGy77+Cvw==} 
+ engines: {node: '>=12.17'} + + array-buffer-byte-length@1.0.2: + resolution: {integrity: sha512-LHE+8BuR7RYGDKvnrmcuSq3tDcKv9OFEXQt/HpbZhY7V6h0zlUXutnAD82GiFx9rdieCMjkvtcsPqBwgUl1Iiw==} + engines: {node: '>= 0.4'} + + array-union@2.1.0: + resolution: {integrity: sha512-HGyxoOTYUyCM6stUe6EJgnd4EoewAI7zMdfqO+kGjnlZmBDz/cR5pf8r/cR4Wq60sL/p0IkcjUEEPwS3GFrIyw==} + engines: {node: '>=8'} + + arraybuffer.prototype.slice@1.0.4: + resolution: {integrity: sha512-BNoCY6SXXPQ7gF2opIP4GBE+Xw7U+pHMYKuzjgCN3GwiaIR09UUeKfheyIry77QtrCBlC0KK0q5/TER/tYh3PQ==} + engines: {node: '>= 0.4'} + + assertion-error@2.0.1: + resolution: {integrity: sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==} + engines: {node: '>=12'} + + ast-types@0.16.1: + resolution: {integrity: sha512-6t10qk83GOG8p0vKmaCr8eiilZwO171AvbROMtvvNiwrTly62t+7XkA8RdIIVbpMhCASAsxgAzdRSwh6nw/5Dg==} + engines: {node: '>=4'} + + ast-v8-to-istanbul@0.3.3: + resolution: {integrity: sha512-MuXMrSLVVoA6sYN/6Hke18vMzrT4TZNbZIj/hvh0fnYFpO+/kFXcLIaiPwXXWaQUPg4yJD8fj+lfJ7/1EBconw==} + + astral-regex@2.0.0: + resolution: {integrity: sha512-Z7tMw1ytTXt5jqMcOP+OQteU1VuNK9Y02uuJtKQ1Sv69jXQKKg5cibLwGJow8yzZP+eAc18EmLGPal0bp36rvQ==} + engines: {node: '>=8'} + + astring@1.9.0: + resolution: {integrity: sha512-LElXdjswlqjWrPpJFg1Fx4wpkOCxj1TDHlSV4PlaRxHGWko024xICaa97ZkMfs6DRKlCguiAI+rbXv5GWwXIkg==} + hasBin: true + + async-function@1.0.0: + resolution: {integrity: sha512-hsU18Ae8CDTR6Kgu9DYf0EbCr/a5iGL0rytQDobUcdpYOKokk8LEjVphnXkDkgpi0wYVsqrXuP0bZxJaTqdgoA==} + engines: {node: '>= 0.4'} + + asynckit@0.4.0: + resolution: {integrity: sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==} + + autoprefixer@10.4.21: + resolution: {integrity: sha512-O+A6LWV5LDHSJD3LjHYoNi4VLsj/Whi7k6zG12xTYaU4cQ8oxQGckXNX8cRHK5yOZ/ppVHe0ZBXGzSV9jXdVbQ==} + engines: {node: ^10 || ^12 || >=14} + hasBin: true + peerDependencies: + postcss: ^8.1.0 + + available-typed-arrays@1.0.7: 
+ resolution: {integrity: sha512-wvUjBtSGN7+7SjNpq/9M2Tg350UZD3q62IFZLbRAR1bSMlCo1ZaeW+BJ+D090e4hIIZLBcTDWe4Mh4jvUDajzQ==} + engines: {node: '>= 0.4'} + + axe-core@4.10.3: + resolution: {integrity: sha512-Xm7bpRXnDSX2YE2YFfBk2FnF0ep6tmG7xPh8iHee8MIcrgq762Nkce856dYtJYLkuIoYZvGfTs/PbZhideTcEg==} + engines: {node: '>=4'} + + azure-devops-node-api@12.5.0: + resolution: {integrity: sha512-R5eFskGvOm3U/GzeAuxRkUsAl0hrAwGgWn6zAd2KrZmrEhWZVqLew4OOupbQlXUuojUzpGtq62SmdhJ06N88og==} + + babel-dead-code-elimination@1.0.10: + resolution: {integrity: sha512-DV5bdJZTzZ0zn0DC24v3jD7Mnidh6xhKa4GfKCbq3sfW8kaWhDdZjP3i81geA8T33tdYqWKw4D3fVv0CwEgKVA==} + + bail@2.0.2: + resolution: {integrity: sha512-0xO6mYd7JB2YesxDKplafRpsiOzPt9V02ddPCLbY1xYGPOX24NTyN50qnUxgCPcSoYMhKpAuBTjQoRZCAkUDRw==} + + balanced-match@1.0.2: + resolution: {integrity: sha512-3oSeUO0TMV67hN1AmbXsK4yaqU7tjiHlbxRDZOpH0KW9+CeX4bRAaX0Anxt0tx2MrpRpWwQaPwIlISEJhYU5Pw==} + + base64-js@1.5.1: + resolution: {integrity: sha512-AKpaYlHn8t4SVbOHCy+b5+KKgvR4vrsD8vbvrbiQJps7fKDTkjkDry6ji0rUJjC0kzbNePLwzxq8iypo41qeWA==} + + baseline-browser-mapping@2.8.9: + resolution: {integrity: sha512-hY/u2lxLrbecMEWSB0IpGzGyDyeoMFQhCvZd2jGFSE5I17Fh01sYUBPCJtkWERw7zrac9+cIghxm/ytJa2X8iA==} + hasBin: true + + better-opn@3.0.2: + resolution: {integrity: sha512-aVNobHnJqLiUelTaHat9DZ1qM2w0C0Eym4LPI/3JxOnSokGVdsl1T1kN7TFvsEAD8G47A6VKQ0TVHqbBnYMJlQ==} + engines: {node: '>=12.0.0'} + + binary-extensions@2.3.0: + resolution: {integrity: sha512-Ceh+7ox5qe7LJuLHoY0feh3pHuUDHAcRUeyL2VYghZwfpkNIy/+8Ocg0a3UuSoYzavmylwuLWQOf3hl0jjMMIw==} + engines: {node: '>=8'} + + binaryextensions@6.11.0: + resolution: {integrity: sha512-sXnYK/Ij80TO3lcqZVV2YgfKN5QjUWIRk/XSm2J/4bd/lPko3lvk0O4ZppH6m+6hB2/GTu+ptNwVFe1xh+QLQw==} + engines: {node: '>=4'} + + bl@4.1.0: + resolution: {integrity: sha512-1W07cM9gS6DcLperZfFSj+bWLtaPGSOHWhPiGzXmvVJbRLdG82sH/Kn8EtW1VqWVA54AKf2h5k5BbnIbwF3h6w==} + + boolbase@1.0.0: + resolution: {integrity: 
sha512-JZOSA7Mo9sNGB8+UjSgzdLtokWAky1zbztM3WRLCbZ70/3cTANmQmOdR7y2g+J0e2WXywy1yS468tY+IruqEww==} + + boundary@2.0.0: + resolution: {integrity: sha512-rJKn5ooC9u8q13IMCrW0RSp31pxBCHE3y9V/tp3TdWSLf8Em3p6Di4NBpfzbJge9YjjFEsD0RtFEjtvHL5VyEA==} + + brace-expansion@1.1.12: + resolution: {integrity: sha512-9T9UjW3r0UW5c1Q7GTwllptXwhvYmEzFhzMfZ9H7FQWt+uZePjZPjBP/W1ZEyZ1twGWom5/56TF4lPcqjnDHcg==} + + brace-expansion@2.0.2: + resolution: {integrity: sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==} + + braces@3.0.3: + resolution: {integrity: sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==} + engines: {node: '>=8'} + + browser-stdout@1.3.1: + resolution: {integrity: sha512-qhAVI1+Av2X7qelOfAIYwXONood6XlZE/fXaBSmW/T5SzLAmCgzi+eiWE7fUvbHaeNBQH13UftjpXxsfLkMpgw==} + + browserslist@4.26.2: + resolution: {integrity: sha512-ECFzp6uFOSB+dcZ5BK/IBaGWssbSYBHvuMeMt3MMFyhI0Z8SqGgEkBLARgpRH3hutIgPVsALcMwbDrJqPxQ65A==} + engines: {node: ^6 || ^7 || ^8 || ^9 || ^10 || ^11 || ^12 || >=13.7} + hasBin: true + + buffer-crc32@0.2.13: + resolution: {integrity: sha512-VO9Ht/+p3SN7SKWqcrgEzjGbRSJYTx+Q1pTQC0wrWqHx0vpJraQ6GtHx8tvcg1rlK1byhU5gccxgOgj7B0TDkQ==} + + buffer-equal-constant-time@1.0.1: + resolution: {integrity: sha512-zRpUiDwd/xk6ADqPMATG8vc9VPrkck7T07OIx0gnjmJAnHnTVXNQG3vfvWNuiZIkwu9KrKdA1iJKfsfTVxE6NA==} + + buffer-from@1.1.2: + resolution: {integrity: sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==} + + buffer@5.7.1: + resolution: {integrity: sha512-EHcyIPBQ4BSGlvjB16k5KgAJ27CIsHY/2JBmCRReo48y9rQ3MaUzWX3KVlBa4U7MyX02HdVj0K7C3WaB3ju7FQ==} + + bundle-name@4.1.0: + resolution: {integrity: sha512-tjwM5exMg6BGRI+kNmTntNsvdZS1X8BFYS6tnJ2hdH0kVxM6/eVZ2xy+FqStSWvYmtfFMDLIxurorHwDKfDz5Q==} + engines: {node: '>=18'} + + c8@9.1.0: + resolution: {integrity: sha512-mBWcT5iqNir1zIkzSPyI3NCR9EZCVI3WUD+AVO17MVWTSFNyUueXE82qTeampNtTr+ilN/5Ua3j24LgbCKjDVg==} + 
engines: {node: '>=14.14.0'} + hasBin: true + + cac@6.7.14: + resolution: {integrity: sha512-b6Ilus+c3RrdDk+JhLKUAQfzzgLEPy6wcXqS7f/xe1EETvsDP6GORG7SFuOs6cID5YkqchW/LXZbX5bc8j7ZcQ==} + engines: {node: '>=8'} + + call-bind-apply-helpers@1.0.2: + resolution: {integrity: sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==} + engines: {node: '>= 0.4'} + + call-bind@1.0.8: + resolution: {integrity: sha512-oKlSFMcMwpUg2ednkhQ454wfWiU/ul3CkJe/PEHcTKuiX6RpbehUiFMXu13HalGZxfUwCQzZG747YXBn1im9ww==} + engines: {node: '>= 0.4'} + + call-bound@1.0.4: + resolution: {integrity: sha512-+ys997U96po4Kx/ABpBCqhA9EuxJaQWDQg7295H4hBphv3IZg0boBKuwYpt4YXp6MZ5AmZQnU/tyMTlRpaSejg==} + engines: {node: '>= 0.4'} + + call-me-maybe@1.0.2: + resolution: {integrity: sha512-HpX65o1Hnr9HH25ojC1YGs7HCQLq0GCOibSaWER0eNpgJ/Z1MZv2mTc7+xh6WOPxbRVcmgbv4hGU+uSQ/2xFZQ==} + + callsites@3.1.0: + resolution: {integrity: sha512-P8BjAsXvZS+VIDUI11hHCQEv74YT67YUi5JJFNWIqL235sBmjX4+qx9Muvls5ivyNENctx46xQLQ3aTuE7ssaQ==} + engines: {node: '>=6'} + + camelcase-css@2.0.1: + resolution: {integrity: sha512-QOSvevhslijgYwRx6Rv7zKdMF8lbRmx+uQGx2+vDc+KI/eBnsy9kit5aj23AgGu3pa4t9AgwbnXWqS+iOY+2aA==} + engines: {node: '>= 6'} + + camelcase@6.3.0: + resolution: {integrity: sha512-Gmy6FhYlCY7uOElZUSbxo2UCDH8owEk996gkbrpsgGtrJLM3J7jGxl9Ic7Qwwj4ivOE5AWZWRMecDdF7hqGjFA==} + engines: {node: '>=10'} + + caniuse-lite@1.0.30001746: + resolution: {integrity: sha512-eA7Ys/DGw+pnkWWSE/id29f2IcPHVoE8wxtvE5JdvD2V28VTDPy1yEeo11Guz0sJ4ZeGRcm3uaTcAqK1LXaphA==} + + ccount@2.0.1: + resolution: {integrity: sha512-eyrF0jiFpY+3drT6383f1qhkbGsLSifNAjA61IUjZjmLCWjItY6LB9ft9YhoDgwfmclB2zhu51Lc7+95b8NRAg==} + + chai@5.2.1: + resolution: {integrity: sha512-5nFxhUrX0PqtyogoYOA8IPswy5sZFTOsBFl/9bNsmDLgsxYTzSZQJDPppDnZPTQbzSEm0hqGjWPzRemQCYbD6A==} + engines: {node: '>=18'} + + chalk-template@0.4.0: + resolution: {integrity: 
sha512-/ghrgmhfY8RaSdeo43hNXxpoHAtxdbskUHjPpfqUWGttFgycUhYPGx3YZBCnUCvOa7Doivn1IZec3DEGFoMgLg==} + engines: {node: '>=12'} + + chalk-template@1.1.0: + resolution: {integrity: sha512-T2VJbcDuZQ0Tb2EWwSotMPJjgpy1/tGee1BTpUNsGZ/qgNjV2t7Mvu+d4600U564nbLesN1x2dPL+xii174Ekg==} + engines: {node: '>=14.16'} + + chalk@3.0.0: + resolution: {integrity: sha512-4D3B6Wf41KOYRFdszmDqMCGq5VV/uMAB273JILmO+3jAlh8X4qDtdtgCR3fxtbLEMzSx22QdhnDcJvu2u1fVwg==} + engines: {node: '>=8'} + + chalk@4.1.2: + resolution: {integrity: sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==} + engines: {node: '>=10'} + + chalk@5.4.1: + resolution: {integrity: sha512-zgVZuo2WcZgfUEmsn6eO3kINexW8RAE4maiQ8QNs8CtpPCSyMiYsULR3HQYkm3w8FIA3SberyMJMSldGsW+U3w==} + engines: {node: ^12.17.0 || ^14.13 || >=16.0.0} + + character-entities-html4@2.1.0: + resolution: {integrity: sha512-1v7fgQRj6hnSwFpq1Eu0ynr/CDEw0rXo2B61qXrLNdHZmPKgb7fqS1a2JwF0rISo9q77jDI8VMEHoApn8qDoZA==} + + character-entities-legacy@3.0.0: + resolution: {integrity: sha512-RpPp0asT/6ufRm//AJVwpViZbGM/MkjQFxJccQRHmISF/22NBtsHqAWmL+/pmkPWoIUJdWyeVleTl1wydHATVQ==} + + character-entities@2.0.2: + resolution: {integrity: sha512-shx7oQ0Awen/BRIdkjkvz54PnEEI/EjwXDSIZp86/KKdbafHh1Df/RYGBhn4hbe2+uKC9FnT5UCEdyPz3ai9hQ==} + + character-reference-invalid@2.0.1: + resolution: {integrity: sha512-iBZ4F4wRbyORVsu0jPV7gXkOsGYjGHPmAyv+HiHG8gi5PtC9KI2j1+v8/tlibRvjoWX027ypmG/n0HtO5t7unw==} + + check-error@2.1.1: + resolution: {integrity: sha512-OAlb+T7V4Op9OwdkjmguYRqncdlx5JiofwOAUkmTF+jNdHwzTaTs4sRAGpzLF3oOz5xAyDGrPgeIDFQmDOTiJw==} + engines: {node: '>= 16'} + + cheerio-select@2.1.0: + resolution: {integrity: sha512-9v9kG0LvzrlcungtnJtpGNxY+fzECQKhK4EGJX2vByejiMX84MFNQw4UxPJl3bFbTMw+Dfs37XaIkCwTZfLh4g==} + + cheerio@1.1.2: + resolution: {integrity: sha512-IkxPpb5rS/d1IiLbHMgfPuS0FgiWTtFIm/Nj+2woXDLTZ7fOT2eqzgYbdMlLweqlHbsZjxEChoVK+7iph7jyQg==} + engines: {node: '>=20.18.1'} + + chokidar@3.6.0: + resolution: {integrity: 
sha512-7VT13fmjotKpGipCW9JEQAusEPE+Ei8nl6/g4FBAmIm0GOOLMua9NDDo/DWp0ZAxCr3cPq5ZpBqmPAQgDda2Pw==} + engines: {node: '>= 8.10.0'} + + chokidar@4.0.3: + resolution: {integrity: sha512-Qgzu8kfBvo+cA4962jnP1KkS6Dop5NS6g7R5LFYJr4b8Ub94PPQXUksCw9PvXoeXPRRddRNC5C1JQUR2SMGtnA==} + engines: {node: '>= 14.16.0'} + + chownr@1.1.4: + resolution: {integrity: sha512-jJ0bqzaylmJtVnNgzTeSOs8DPavpbYgEr/b0YL8/2GO3xJEhInFmhKMUnEJQjZumK7KXGFhUy89PrsJWlakBVg==} + + chownr@3.0.0: + resolution: {integrity: sha512-+IxzY9BZOQd/XuYPRmrvEVjF/nqj5kgT4kEq7VofrDoM1MxoRjEWkrCC3EtLi59TVawxTAn+orJwFQcrqEN1+g==} + engines: {node: '>=18'} + + chromatic@12.2.0: + resolution: {integrity: sha512-GswmBW9ZptAoTns1BMyjbm55Z7EsIJnUvYKdQqXIBZIKbGErmpA+p4c0BYA+nzw5B0M+rb3Iqp1IaH8TFwIQew==} + hasBin: true + peerDependencies: + '@chromatic-com/cypress': ^0.*.* || ^1.0.0 + '@chromatic-com/playwright': ^0.*.* || ^1.0.0 + peerDependenciesMeta: + '@chromatic-com/cypress': + optional: true + '@chromatic-com/playwright': + optional: true + + chrome-trace-event@1.0.4: + resolution: {integrity: sha512-rNjApaLzuwaOTjCiT8lSDdGN1APCiqkChLMJxJPWLunPAt5fy8xgU9/jNOchV84wfIxrA0lRQB7oCT8jrn/wrQ==} + engines: {node: '>=6.0'} + + class-variance-authority@0.7.1: + resolution: {integrity: sha512-Ka+9Trutv7G8M6WT6SeiRWz792K5qEqIGEGzXKhAE6xOWAY6pPH8U+9IY3oCMv6kqTmLsv7Xh/2w2RigkePMsg==} + + classcat@5.0.5: + resolution: {integrity: sha512-JhZUT7JFcQy/EzW605k/ktHtncoo9vnyW/2GspNYwFlN1C/WmjuV/xtS04e9SOkL2sTdw0VAZ2UGCcQ9lR6p6w==} + + cli-cursor@5.0.0: + resolution: {integrity: sha512-aCj4O5wKyszjMmDT4tZj93kxyydN/K5zPWSCe6/0AV/AA1pqe5ZBIw0a2ZfPQV7lL5/yb5HsUreJ6UFAF1tEQw==} + engines: {node: '>=18'} + + cli-spinners@2.9.2: + resolution: {integrity: sha512-ywqV+5MmyL4E7ybXgKys4DugZbX0FC6LnwrhjuykIjnK9k8OQacQ7axGKnjDXWNhns0xot3bZI5h55H8yo9cJg==} + engines: {node: '>=6'} + + cliui@7.0.4: + resolution: {integrity: sha512-OcRE68cOsVMXp1Yvonl/fzkQOyjLSu/8bhPDfQt0e0/Eb283TKP20Fs2MqoPsr9SwA595rRCA+QMzYc9nBP+JQ==} + + cliui@8.0.1: + resolution: 
{integrity: sha512-BSeNnyus75C4//NQ9gQt1/csTXyo/8Sb+afLAkzAptFuMsod9HFokGNudZpi/oQV73hnVK+sR+5PVRMd+Dr7YQ==} + engines: {node: '>=12'} + + clsx@2.1.1: + resolution: {integrity: sha512-eYm0QWBtUrBWZWG0d386OGAw16Z995PiOVo2B7bjWSbHedGl5e0ZWaq65kOGgUSNesEIDkB9ISbTg/JK9dhCZA==} + engines: {node: '>=6'} + + cockatiel@3.2.1: + resolution: {integrity: sha512-gfrHV6ZPkquExvMh9IOkKsBzNDk6sDuZ6DdBGUBkvFnTCqCxzpuq48RySgP0AnaqQkw2zynOFj9yly6T1Q2G5Q==} + engines: {node: '>=16'} + + codemirror@6.0.1: + resolution: {integrity: sha512-J8j+nZ+CdWmIeFIGXEFbFPtpiYacFMDR8GlHK3IyHQJMCaVRfGx9NT+Hxivv1ckLWPvNdZqndbr/7lVhrf/Svg==} + + color-convert@2.0.1: + resolution: {integrity: sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==} + engines: {node: '>=7.0.0'} + + color-name@1.1.4: + resolution: {integrity: sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==} + + combined-stream@1.0.8: + resolution: {integrity: sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==} + engines: {node: '>= 0.8'} + + comma-separated-tokens@2.0.3: + resolution: {integrity: sha512-Fu4hJdvzeylCfQPp9SGWidpzrMs7tTrlu6Vb8XGaRGck8QSNZJJp538Wrb60Lax4fPwR64ViY468OIUTbRlGZg==} + + command-line-args@6.0.1: + resolution: {integrity: sha512-Jr3eByUjqyK0qd8W0SGFW1nZwqCaNCtbXjRo2cRJC1OYxWl3MZ5t1US3jq+cO4sPavqgw4l9BMGX0CBe+trepg==} + engines: {node: '>=12.20'} + peerDependencies: + '@75lb/nature': latest + peerDependenciesMeta: + '@75lb/nature': + optional: true + + command-line-usage@7.0.3: + resolution: {integrity: sha512-PqMLy5+YGwhMh1wS04mVG44oqDsgyLRSKJBdOo1bnYhMKBW65gZF1dRp2OZRhiTjgUHljy99qkO7bsctLaw35Q==} + engines: {node: '>=12.20.0'} + + commander@12.1.0: + resolution: {integrity: sha512-Vw8qHK3bZM9y/P10u3Vib8o/DdkvA2OtPtZvD871QKjy74Wj1WSKFILMPRPSdUSx5RFK1arlJzEtA4PkFgnbuA==} + engines: {node: '>=18'} + + commander@13.1.0: + resolution: {integrity: 
sha512-/rFeCpNJQbhSZjGVwO9RFV3xPqbnERS8MmIQzCtD/zl6gpJuV/bMLuN92oG3F7d8oDEHHRrujSXNUr8fpjntKw==} + engines: {node: '>=18'} + + commander@2.20.3: + resolution: {integrity: sha512-GpVkmM8vF2vQUkj2LvZmD35JxeJOLCwJ9cUkugyk2nuhbv3+mJvpLYYt+0+USMxE+oj+ey/lJEnhZw75x/OMcQ==} + + commander@4.1.1: + resolution: {integrity: sha512-NOKm8xhkzAjzFx8B2v5OAHT+u5pRQc2UCa2Vq9jYL/31o2wi9mxBA7LIFs3sV5VSC49z6pEhfbMULvShKj26WA==} + engines: {node: '>= 6'} + + compare-versions@6.1.1: + resolution: {integrity: sha512-4hm4VPpIecmlg59CHXnRDnqGplJFrbLG4aFEl5vl6cK1u76ws3LLvX7ikFnTDl5vo39sjWD6AaDPYodJp/NNHg==} + + concat-map@0.0.1: + resolution: {integrity: sha512-/Srv4dswyQNBfohGpz9o6Yb3Gz3SrUDqBH5rTuhGR7ahtlbYKnVxw2bCFMRljaA7EXHaXZ8wsHdodFvbkhKmqg==} + + confbox@0.1.8: + resolution: {integrity: sha512-RMtmw0iFkeR4YV+fUOSucriAQNb9g8zFR52MWCtl+cCZOFRNL6zeB395vPzFhEjjn4fMxXudmELnl/KF/WrK6w==} + + confbox@0.2.2: + resolution: {integrity: sha512-1NB+BKqhtNipMsov4xI/NnhCKp9XG9NamYp5PVm9klAT0fsrNPjaFICsCFhNhwZJKNh7zB/3q8qXz0E9oaMNtQ==} + + convert-source-map@2.0.0: + resolution: {integrity: sha512-Kvp459HrV2FEJ1CAsi1Ku+MY3kasH19TFykTz2xWmMeq6bk2NU3XXvfJ+Q61m0xktWwt+1HSYf3JZsTms3aRJg==} + + cookie-es@1.2.2: + resolution: {integrity: sha512-+W7VmiVINB+ywl1HGXJXmrqkOhpKrIiVZV6tQuV54ZyQC7MMuBt81Vc336GMLoHBq5hV/F9eXgt5Mnx0Rha5Fg==} + + cookie@1.0.2: + resolution: {integrity: sha512-9Kr/j4O16ISv8zBBhJoi4bXOYNTkFLOqSL3UDB0njXxCXNezjeyVrJyGOWtgfs/q2km1gwBcfH8q1yEGoMYunA==} + engines: {node: '>=18'} + + core-util-is@1.0.3: + resolution: {integrity: sha512-ZQBvi1DcpJ4GDqanjucZ2Hj3wEO5pZDS89BWbkcrvdxksJorwUDDZamX9ldFkp9aw2lmBDLgkObEA4DWNJ9FYQ==} + + cosmiconfig@9.0.0: + resolution: {integrity: sha512-itvL5h8RETACmOTFc4UfIyB2RfEHi71Ax6E/PivVxq9NseKbOWpeyHEOIbmAw1rs8Ak0VursQNww7lf7YtUwzg==} + engines: {node: '>=14'} + peerDependencies: + typescript: '>=4.9.5' + peerDependenciesMeta: + typescript: + optional: true + + crelt@1.0.6: + resolution: {integrity: 
sha512-VQ2MBenTq1fWZUH9DJNGti7kKv6EeAuYr3cLwxUWhIu1baTaXh4Ib5W2CqHVqib4/MqbYGJqiL3Zb8GJZr3l4g==} + + cronstrue@3.3.0: + resolution: {integrity: sha512-iwJytzJph1hosXC09zY8F5ACDJKerr0h3/2mOxg9+5uuFObYlgK0m35uUPk4GCvhHc2abK7NfnR9oMqY0qZFAg==} + hasBin: true + + cross-spawn@7.0.6: + resolution: {integrity: sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==} + engines: {node: '>= 8'} + + css-select@5.2.2: + resolution: {integrity: sha512-TizTzUddG/xYLA3NXodFM0fSbNizXjOKhqiQQwvhlspadZokn1KDy0NZFS0wuEubIYAV5/c1/lAr0TaaFXEXzw==} + + css-what@6.2.2: + resolution: {integrity: sha512-u/O3vwbptzhMs3L1fQE82ZSLHQQfto5gyZzwteVIEyeaY5Fc7R4dapF/BvRoSYFeqfBk4m0V1Vafq5Pjv25wvA==} + engines: {node: '>= 6'} + + css.escape@1.5.1: + resolution: {integrity: sha512-YUifsXXuknHlUsmlgyY0PKzgPOr7/FjCePfHNt0jxm83wHZi44VDMQ7/fGNkjY3/jV1MC+1CmZbaHzugyeRtpg==} + + cssesc@3.0.0: + resolution: {integrity: sha512-/Tb/JcjK111nNScGob5MNtsntNM1aCNUDipB/TkwZFhyDrrE47SOx/18wF2bbjgc3ZzCSKW1T5nt5EbFoAz/Vg==} + engines: {node: '>=4'} + hasBin: true + + cssstyle@4.6.0: + resolution: {integrity: sha512-2z+rWdzbbSZv6/rhtvzvqeZQHrBaqgogqt85sqFNbabZOuFbCVFb8kPeEtZjiKkbrm395irpNKiYeFeLiQnFPg==} + engines: {node: '>=18'} + + csstype@3.1.3: + resolution: {integrity: sha512-M1uQkMl8rQK/szD0LNhtqxIPLpimGm8sOBwU7lLnCpSbTyY3yeU1Vc7l4KT5zT4s/yOxHH5O7tIuuLOCnLADRw==} + + d3-color@3.1.0: + resolution: {integrity: sha512-zg/chbXyeBtMQ1LbD/WSoW2DpC3I0mpmPdW+ynRTj/x2DAWYrIY7qeZIHidozwV24m4iavr15lNwIwLxRmOxhA==} + engines: {node: '>=12'} + + d3-dispatch@3.0.1: + resolution: {integrity: sha512-rzUyPU/S7rwUflMyLc1ETDeBj0NRuHKKAcvukozwhshr6g6c5d8zh4c2gQjY2bZ0dXeGLWc1PF174P2tVvKhfg==} + engines: {node: '>=12'} + + d3-drag@3.0.0: + resolution: {integrity: sha512-pWbUJLdETVA8lQNJecMxoXfH6x+mO2UQo8rSmZ+QqxcbyA3hfeprFgIT//HW2nlHChWeIIMwS2Fq+gEARkhTkg==} + engines: {node: '>=12'} + + d3-ease@3.0.1: + resolution: {integrity: 
sha512-wR/XK3D3XcLIZwpbvQwQ5fK+8Ykds1ip7A2Txe0yxncXSdq1L9skcG7blcedkOX+ZcgxGAmLX1FrRGbADwzi0w==} + engines: {node: '>=12'} + + d3-interpolate@3.0.1: + resolution: {integrity: sha512-3bYs1rOD33uo8aqJfKP3JWPAibgw8Zm2+L9vBKEHJ2Rg+viTR7o5Mmv5mZcieN+FRYaAOWX5SJATX6k1PWz72g==} + engines: {node: '>=12'} + + d3-selection@3.0.0: + resolution: {integrity: sha512-fmTRWbNMmsmWq6xJV8D19U/gw/bwrHfNXxrIN+HfZgnzqTHp9jOmKMhsTUjXOJnZOdZY9Q28y4yebKzqDKlxlQ==} + engines: {node: '>=12'} + + d3-timer@3.0.1: + resolution: {integrity: sha512-ndfJ/JxxMd3nw31uyKoY2naivF+r29V+Lc0svZxe1JvvIRmi8hUsrMvdOwgS1o6uBHmiz91geQ0ylPP0aj1VUA==} + engines: {node: '>=12'} + + d3-transition@3.0.1: + resolution: {integrity: sha512-ApKvfjsSR6tg06xrL434C0WydLr7JewBB3V+/39RMHsaXTOG0zmt/OAXeng5M5LBm0ojmxJrpomQVZ1aPvBL4w==} + engines: {node: '>=12'} + peerDependencies: + d3-selection: 2 - 3 + + d3-zoom@3.0.0: + resolution: {integrity: sha512-b8AmV3kfQaqWAuacbPuNbL6vahnOJflOhexLzMMNLga62+/nh0JzvJ0aO/5a5MVgUFGS7Hu1P9P03o3fJkDCyw==} + engines: {node: '>=12'} + + dagre@0.8.5: + resolution: {integrity: sha512-/aTqmnRta7x7MCCpExk7HQL2O4owCT2h8NT//9I1OQ9vt29Pa0BzSAkR5lwFUcQ7491yVi/3CXU9jQ5o0Mn2Sw==} + + data-urls@5.0.0: + resolution: {integrity: sha512-ZYP5VBHshaDAiVZxjbRVcFJpc+4xGgT0bK3vzy1HLN8jTO975HEbuYzZJcHoQEY5K1a0z8YayJkyVETa08eNTg==} + engines: {node: '>=18'} + + data-view-buffer@1.0.2: + resolution: {integrity: sha512-EmKO5V3OLXh1rtK2wgXRansaK1/mtVdTUEiEI0W8RkvgT05kfxaH29PliLnpLP73yYO6142Q72QNa8Wx/A5CqQ==} + engines: {node: '>= 0.4'} + + data-view-byte-length@1.0.2: + resolution: {integrity: sha512-tuhGbE6CfTM9+5ANGf+oQb72Ky/0+s3xKUpHvShfiz2RxMFgFPjsXuRLBVMtvMs15awe45SRb83D6wH4ew6wlQ==} + engines: {node: '>= 0.4'} + + data-view-byte-offset@1.0.1: + resolution: {integrity: sha512-BS8PfmtDGnrgYdOonGZQdLZslWIeCGFP9tpan0hi1Co2Zr2NKADsvGYA8XxuG/4UWgJ6Cjtv+YJnB6MM69QGlQ==} + engines: {node: '>= 0.4'} + + de-indent@1.0.2: + resolution: {integrity: 
sha512-e/1zu3xH5MQryN2zdVaF0OrdNLUbvWxzMbi+iNA6Bky7l1RoP8a2fIbRocyHclXt/arDrrR6lL3TqFD9pMQTsg==} + + debug@4.4.1: + resolution: {integrity: sha512-KcKCqiftBJcZr++7ykoDIEwSa3XWowTfNPo92BYxjXiyYEVrUQh2aLyhxBCwww+heortUFxEJYcRzosstTEBYQ==} + engines: {node: '>=6.0'} + peerDependencies: + supports-color: '*' + peerDependenciesMeta: + supports-color: + optional: true + + decamelize@4.0.0: + resolution: {integrity: sha512-9iE1PgSik9HeIIw2JO94IidnE3eBoQrFJ3w7sFuzSX4DpmZ3v5sZpUiV5Swcf6mQEF+Y0ru8Neo+p+nyh2J+hQ==} + engines: {node: '>=10'} + + decimal.js@10.6.0: + resolution: {integrity: sha512-YpgQiITW3JXGntzdUmyUR1V812Hn8T1YVXhCu+wO3OpS4eU9l4YdD3qjyiKdV6mvV29zapkMeD390UVEf2lkUg==} + + decode-named-character-reference@1.2.0: + resolution: {integrity: sha512-c6fcElNV6ShtZXmsgNgFFV5tVX2PaV4g+MOAkb8eXHvn6sryJBrZa9r0zV6+dtTyoCKxtDy5tyQ5ZwQuidtd+Q==} + + decompress-response@6.0.0: + resolution: {integrity: sha512-aW35yZM6Bb/4oJlZncMH2LCoZtJXTRxES17vE3hoRiowU2kWHaJKFkSBDnDR+cm9J+9QhXmREyIfv0pji9ejCQ==} + engines: {node: '>=10'} + + deep-eql@5.0.2: + resolution: {integrity: sha512-h5k/5U50IJJFpzfL6nO9jaaumfjO/f2NjK/oYB2Djzm4p9L+3T9qWpZqZ2hAbLPuuYq9wrU08WQyBTL5GbPk5Q==} + engines: {node: '>=6'} + + deep-extend@0.6.0: + resolution: {integrity: sha512-LOHxIOaPYdHlJRtCQfDIVZtfw/ufM8+rVj649RIHzcm/vGwQRXFt6OPqIFWsm2XEMrNIEtWR64sY1LEKD2vAOA==} + engines: {node: '>=4.0.0'} + + deep-is@0.1.4: + resolution: {integrity: sha512-oIPzksmTg4/MriiaYGO+okXDT7ztn/w3Eptv/+gSIdMdKsJo0u4CfYNFJPy+4SKMuCqGw2wxnA+URMg3t8a/bQ==} + + deepmerge@4.3.1: + resolution: {integrity: sha512-3sUqbMEc77XqpdNO7FRyRog+eW3ph+GYCbj+rK+uYyRMuwsVy0rMiVtPn+QJlKFvWP/1PYpapqYn0Me2knFn+A==} + engines: {node: '>=0.10.0'} + + default-browser-id@5.0.0: + resolution: {integrity: sha512-A6p/pu/6fyBcA1TRz/GqWYPViplrftcW2gZC9q79ngNCKAeR/X3gcEdXQHl4KNXV+3wgIJ1CPkJQ3IHM6lcsyA==} + engines: {node: '>=18'} + + default-browser@5.2.1: + resolution: {integrity: 
sha512-WY/3TUME0x3KPYdRRxEJJvXRHV4PyPoUsxtZa78lwItwRQRHhd2U9xOscaT/YTf8uCXIAjeJOFBVEh/7FtD8Xg==} + engines: {node: '>=18'} + + define-data-property@1.1.4: + resolution: {integrity: sha512-rBMvIzlpA8v6E+SJZoo++HAYqsLrkg7MSfIinMPFhmkorw7X+dOXVJQs+QT69zGkzMyfDnIMN2Wid1+NbL3T+A==} + engines: {node: '>= 0.4'} + + define-lazy-prop@2.0.0: + resolution: {integrity: sha512-Ds09qNh8yw3khSjiJjiUInaGX9xlqZDY7JVryGxdxV7NPeuqQfplOpQ66yJFZut3jLa5zOwkXw1g9EI2uKh4Og==} + engines: {node: '>=8'} + + define-lazy-prop@3.0.0: + resolution: {integrity: sha512-N+MeXYoqr3pOgn8xfyRPREN7gHakLYjhsHhWGT3fWAiL4IkAt0iDw14QiiEm2bE30c5XX5q0FtAA3CK5f9/BUg==} + engines: {node: '>=12'} + + define-properties@1.2.1: + resolution: {integrity: sha512-8QmQKqEASLd5nx0U1B1okLElbUuuttJ/AnYmRXbbbGDWh6uS208EjD4Xqq/I9wK7u0v6O08XhTWnt5XtEbR6Dg==} + engines: {node: '>= 0.4'} + + delayed-stream@1.0.0: + resolution: {integrity: sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==} + engines: {node: '>=0.4.0'} + + dependency-graph@0.11.0: + resolution: {integrity: sha512-JeMq7fEshyepOWDfcfHK06N3MhyPhz++vtqWhMT5O9A3K42rdsEDpfdVqjaqaAhsw6a+ZqeDvQVtD0hFHQWrzg==} + engines: {node: '>= 0.6.0'} + + dequal@2.0.3: + resolution: {integrity: sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==} + engines: {node: '>=6'} + + detect-libc@2.0.4: + resolution: {integrity: sha512-3UDv+G9CsCKO1WKMGw9fwq/SWJYbI0c5Y7LU1AXYoDdbhE2AHQ6N6Nb34sG8Fj7T5APy8qXDCKuuIHd1BR0tVA==} + engines: {node: '>=8'} + + detect-node-es@1.1.0: + resolution: {integrity: sha512-ypdmJU/TbBby2Dxibuv7ZLW3Bs1QEmM7nHjEANfohJLvE0XVujisn1qPJcZxg+qDucsr+bP6fLD1rPS3AhJ7EQ==} + + devlop@1.1.0: + resolution: {integrity: sha512-RWmIqhcFf1lRYBvNmr7qTNuyCt/7/ns2jbpp1+PalgE/rDQcBT0fioSMUpJ93irlUhC5hrg4cYqe6U+0ImW0rA==} + + didyoumean@1.2.2: + resolution: {integrity: sha512-gxtyfqMg7GKyhQmb056K7M3xszy/myH8w+B4RT+QXBQsvAOdc3XymqDDPHx1BgPgsdAA5SIifona89YtRATDzw==} + + diff@5.2.0: + 
resolution: {integrity: sha512-uIFDxqpRZGZ6ThOk84hEfqWoHx2devRFvpTZcTHur85vImfaxUbTW9Ryh4CpCuDnToOP1CEtXKIgytHBPVff5A==} + engines: {node: '>=0.3.1'} + + diff@8.0.2: + resolution: {integrity: sha512-sSuxWU5j5SR9QQji/o2qMvqRNYRDOcBTgsJ/DeCf4iSN4gW+gNMXM7wFIP+fdXZxoNiAnHUTGjCr+TSWXdRDKg==} + engines: {node: '>=0.3.1'} + + dir-glob@3.0.1: + resolution: {integrity: sha512-WkrWp9GR4KXfKGYzOLmTuGVi1UWFfws377n9cc55/tb6DuqyF6pcQ5AbiHEshaDpY9v6oaSr2XCDidGmMwdzIA==} + engines: {node: '>=8'} + + dlv@1.1.3: + resolution: {integrity: sha512-+HlytyjlPKnIG8XuRG8WvmBP8xs8P71y+SKKS6ZXWoEgLuePxtDoUEiH7WkdePWrQ5JBpE6aoVqfZfJUQkjXwA==} + + dnd-core@16.0.1: + resolution: {integrity: sha512-HK294sl7tbw6F6IeuK16YSBUoorvHpY8RHO+9yFfaJyCDVb6n7PRcezrOEOa2SBCqiYpemh5Jx20ZcjKdFAVng==} + + doctrine@3.0.0: + resolution: {integrity: sha512-yS+Q5i3hBf7GBkd4KG8a7eBNNWNGLTaEwwYWUijIYM7zrlYDM0BFXHjjPWlWZ1Rg7UaddZeIDmi9jF3HmqiQ2w==} + engines: {node: '>=6.0.0'} + + dom-accessibility-api@0.5.16: + resolution: {integrity: sha512-X7BJ2yElsnOJ30pZF4uIIDfBEVgF4XEBxL9Bxhy6dnrm5hkzqmsWHGTiHqRiITNhMyFLyAiWndIJP7Z1NTteDg==} + + dom-accessibility-api@0.6.3: + resolution: {integrity: sha512-7ZgogeTnjuHbo+ct10G9Ffp0mif17idi0IyWNVA/wcwcm7NPOD/WEHVP3n7n3MhXqxoIYm8d6MuZohYWIZ4T3w==} + + dom-serializer@2.0.0: + resolution: {integrity: sha512-wIkAryiqt/nV5EQKqQpo3SToSOV9J0DnbJqwK7Wv/Trc92zIAYZ4FlMu+JPFW1DfGFt81ZTCGgDEabffXeLyJg==} + + domelementtype@2.3.0: + resolution: {integrity: sha512-OLETBj6w0OsagBwdXnPdN0cnMfF9opN69co+7ZrbfPGrdpPVNBUj02spi6B1N7wChLQiPn4CSH/zJvXw56gmHw==} + + domhandler@5.0.3: + resolution: {integrity: sha512-cgwlv/1iFQiFnU96XXgROh8xTeetsnJiDsTc7TYCLFd9+/WNkIqPTxiM/8pSd8VIrhXGTf1Ny1q1hquVqDJB5w==} + engines: {node: '>= 4'} + + domutils@3.2.2: + resolution: {integrity: sha512-6kZKyUajlDuqlHKVX1w7gyslj9MPIXzIFiz/rGu35uC1wMi+kMhQwGhl4lt9unC9Vb9INnY9Z3/ZA3+FhASLaw==} + + dunder-proto@1.0.1: + resolution: {integrity: 
sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==} + engines: {node: '>= 0.4'} + + eastasianwidth@0.2.0: + resolution: {integrity: sha512-I88TYZWc9XiYHRQ4/3c5rjjfgkjhLyW2luGIheGERbNQ6OY7yTybanSpDXZa8y7VUP9YmDcYa+eyq4ca7iLqWA==} + + ecdsa-sig-formatter@1.0.11: + resolution: {integrity: sha512-nagl3RYrbNv6kQkeJIpt6NJZy8twLB/2vtz6yN9Z4vRKHN4/QZJIEbqohALSgwKdnksuY3k5Addp5lg8sVoVcQ==} + + editions@6.21.0: + resolution: {integrity: sha512-ofkXJtn7z0urokN62DI3SBo/5xAtF0rR7tn+S/bSYV79Ka8pTajIIl+fFQ1q88DQEImymmo97M4azY3WX/nUdg==} + engines: {node: '>=4'} + + effect@3.17.9: + resolution: {integrity: sha512-Nkkn9n1zhy30Dq0MpQatDCH7nfYnOIiebkOHNxmmvoVnEDKCto+2ZwDDWFGzcN/ojwfqjRXWGC9Lo91K5kwZCg==} + + electron-to-chromium@1.5.227: + resolution: {integrity: sha512-ITxuoPfJu3lsNWUi2lBM2PaBPYgH3uqmxut5vmBxgYvyI4AlJ6P3Cai1O76mOrkJCBzq0IxWg/NtqOrpu/0gKA==} + + elkjs@0.8.2: + resolution: {integrity: sha512-L6uRgvZTH+4OF5NE/MBbzQx/WYpru1xCBE9respNj6qznEewGUIfhzmm7horWWxbNO2M0WckQypGctR8lH79xQ==} + + emoji-regex@10.4.0: + resolution: {integrity: sha512-EC+0oUMY1Rqm4O6LLrgjtYDvcVYTy7chDnM4Q7030tP4Kwj3u/pR6gP9ygnp2CJMK5Gq+9Q2oqmrFJAz01DXjw==} + + emoji-regex@8.0.0: + resolution: {integrity: sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==} + + emoji-regex@9.2.2: + resolution: {integrity: sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg==} + + encoding-sniffer@0.2.1: + resolution: {integrity: sha512-5gvq20T6vfpekVtqrYQsSCFZ1wEg5+wW0/QaZMWkFr6BqD3NfKs0rLCx4rrVlSWJeZb5NBJgVLswK/w2MWU+Gw==} + + end-of-stream@1.4.5: + resolution: {integrity: sha512-ooEGc6HP26xXq/N+GCGOT0JKCLDGrq2bQUZrQ7gyrJiZANJ/8YDTxTpQBXGMn+WbIQXNVpyWymm7KYVICQnyOg==} + + enhanced-resolve@5.18.2: + resolution: {integrity: sha512-6Jw4sE1maoRJo3q8MsSIn2onJFbLTOjY9hlx4DZXmOKvLRd1Ok2kXmAGXaafL2+ijsJZ1ClYbl/pmqr9+k4iUQ==} + engines: {node: '>=10.13.0'} + + enhanced-resolve@5.18.3: + 
resolution: {integrity: sha512-d4lC8xfavMeBjzGr2vECC3fsGXziXZQyJxD868h2M/mBI3PwAuODxAkLkq5HYuvrPYcUtiLzsTo8U3PgX3Ocww==} + engines: {node: '>=10.13.0'} + + enquirer@2.4.1: + resolution: {integrity: sha512-rRqJg/6gd538VHvR3PSrdRBb/1Vy2YfzHqzvbhGIQpDRKIa4FgV/54b5Q1xYSxOOwKvjXweS26E0Q+nAMwp2pQ==} + engines: {node: '>=8.6'} + + entities@4.5.0: + resolution: {integrity: sha512-V0hjH4dGPh9Ao5p0MoRY6BVqtwCjhz6vI5LT8AJ55H+4g9/4vbHx1I54fS0XuclLhDHArPQCiMjDxjaL8fPxhw==} + engines: {node: '>=0.12'} + + entities@6.0.1: + resolution: {integrity: sha512-aN97NXWF6AWBTahfVOIrB/NShkzi5H7F9r1s9mD3cDj4Ko5f2qhhVoYMibXF7GlLveb/D2ioWay8lxI97Ven3g==} + engines: {node: '>=0.12'} + + env-paths@2.2.1: + resolution: {integrity: sha512-+h1lkLKhZMTYjog1VEpJNG7NZJWcuc2DDk/qsqSTRRCOXiLjeQ1d1/udrUGhqMxUgAlwKNZ0cf2uqan5GLuS2A==} + engines: {node: '>=6'} + + environment@1.1.0: + resolution: {integrity: sha512-xUtoPkMggbz0MPyPiIWr1Kp4aeWJjDZ6SMvURhimjdZgsRuDplF5/s9hcgGhyXMhs+6vpnuoiZ2kFiu3FMnS8Q==} + engines: {node: '>=18'} + + error-ex@1.3.2: + resolution: {integrity: sha512-7dFHNmqeFSEt2ZBsCriorKnn3Z2pj+fd9kmI6QoWw4//DL+icEBfc0U7qJCisqrTsKTjw4fNFy2pW9OqStD84g==} + + es-abstract@1.24.0: + resolution: {integrity: sha512-WSzPgsdLtTcQwm4CROfS5ju2Wa1QQcVeT37jFjYzdFz1r9ahadC8B8/a4qxJxM+09F18iumCdRmlr96ZYkQvEg==} + engines: {node: '>= 0.4'} + + es-aggregate-error@1.0.14: + resolution: {integrity: sha512-3YxX6rVb07B5TV11AV5wsL7nQCHXNwoHPsQC8S4AmBiqYhyNCJ5BRKXkXyDJvs8QzXN20NgRtxe3dEEQD9NLHA==} + engines: {node: '>= 0.4'} + + es-define-property@1.0.1: + resolution: {integrity: sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==} + engines: {node: '>= 0.4'} + + es-errors@1.3.0: + resolution: {integrity: sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==} + engines: {node: '>= 0.4'} + + es-module-lexer@1.7.0: + resolution: {integrity: 
sha512-jEQoCwk8hyb2AZziIOLhDqpm5+2ww5uIE6lkO/6jcOCusfk6LhMHpXXfBLXTZ7Ydyt0j4VoUQv6uGNYbdW+kBA==} + + es-object-atoms@1.1.1: + resolution: {integrity: sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==} + engines: {node: '>= 0.4'} + + es-set-tostringtag@2.1.0: + resolution: {integrity: sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==} + engines: {node: '>= 0.4'} + + es-to-primitive@1.3.0: + resolution: {integrity: sha512-w+5mJ3GuFL+NjVtJlvydShqE1eN3h3PbI7/5LAsYJP/2qtuMXjfL2LpHSRqo4b4eSF5K/DH1JXKUAHSB2UW50g==} + engines: {node: '>= 0.4'} + + es6-promise@3.3.1: + resolution: {integrity: sha512-SOp9Phqvqn7jtEUxPWdWfWoLmyt2VaJ6MpvP9Comy1MceMXqE6bxvaTu4iaxpYYPzhny28Lc+M87/c2cPK6lDg==} + + esbuild-register@3.6.0: + resolution: {integrity: sha512-H2/S7Pm8a9CL1uhp9OvjwrBh5Pvx0H8qVOxNu8Wed9Y7qv56MPtq+GGM8RJpq6glYJn9Wspr8uw7l55uyinNeg==} + peerDependencies: + esbuild: '>=0.12 <1' + + esbuild@0.25.8: + resolution: {integrity: sha512-vVC0USHGtMi8+R4Kz8rt6JhEWLxsv9Rnu/lGYbPR8u47B+DCBksq9JarW0zOO7bs37hyOK1l2/oqtbciutL5+Q==} + engines: {node: '>=18'} + hasBin: true + + escalade@3.2.0: + resolution: {integrity: sha512-WUj2qlxaQtO4g6Pq5c29GTcWGDyd8itL8zTlipgECz3JesAiiOKotd8JU6otB3PACgG6xkJUyVhboMS+bje/jA==} + engines: {node: '>=6'} + + escape-string-regexp@4.0.0: + resolution: {integrity: sha512-TtpcNJ3XAzx3Gq8sWRzJaVajRs0uVxA2YAkdb1jm2YkPz4G6egUFAyA3n5vtEIZefPk5Wa4UXbKuS5fKkJWdgA==} + engines: {node: '>=10'} + + eslint-plugin-react-hooks@5.2.0: + resolution: {integrity: sha512-+f15FfK64YQwZdJNELETdn5ibXEUQmW1DZL6KXhNnc2heoy/sg9VJJeT7n8TlMWouzWqSWavFkIhHyIbIAEapg==} + engines: {node: '>=10'} + peerDependencies: + eslint: ^3.0.0 || ^4.0.0 || ^5.0.0 || ^6.0.0 || ^7.0.0 || ^8.0.0-0 || ^9.0.0 + + eslint-plugin-storybook@9.1.5: + resolution: {integrity: sha512-vCfaZ2Wk1N1vvK4vmNZoA6y2CYxJwbgIs6BE8/toPf4Z6hCAipoobP6a/30Rs0g/B2TSxTSj41TfrJKJrowpjQ==} + engines: {node: '>=20.0.0'} + peerDependencies: 
+ eslint: '>=8' + storybook: ^9.1.5 + + eslint-scope@5.1.1: + resolution: {integrity: sha512-2NxwbF/hZ0KpepYN0cNbo+FN6XoK7GaHlQhgx/hIZl6Va0bF45RQOOwhLIy8lQDbuCiadSLCBnH2CFYquit5bw==} + engines: {node: '>=8.0.0'} + + eslint-scope@8.4.0: + resolution: {integrity: sha512-sNXOfKCn74rt8RICKMvJS7XKV/Xk9kA7DyJr8mJik3S7Cwgy3qlkkmyS2uQB3jiJg6VNdZd/pDBJu0nvG2NlTg==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + + eslint-visitor-keys@3.4.3: + resolution: {integrity: sha512-wpc+LXeiyiisxPlEkUzU6svyS1frIO3Mgxj1fdy7Pm8Ygzguax2N3Fa/D/ag1WqbOprdI+uY6wMUl8/a2G+iag==} + engines: {node: ^12.22.0 || ^14.17.0 || >=16.0.0} + + eslint-visitor-keys@4.2.1: + resolution: {integrity: sha512-Uhdk5sfqcee/9H/rCOJikYz67o0a2Tw2hGRPOG2Y1R2dg7brRe1uG0yaNQDHu+TO/uQPF/5eCapvYSmHUjt7JQ==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + + eslint@9.31.0: + resolution: {integrity: sha512-QldCVh/ztyKJJZLr4jXNUByx3gR+TDYZCRXEktiZoUR3PGy4qCmSbkxcIle8GEwGpb5JBZazlaJ/CxLidXdEbQ==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + hasBin: true + peerDependencies: + jiti: '*' + peerDependenciesMeta: + jiti: + optional: true + + espree@10.4.0: + resolution: {integrity: sha512-j6PAQ2uUr79PZhBjP5C5fhl8e39FmRnOjsD5lGnWrFU8i2G776tBK7+nP8KuQUTTyAZUwfQqXAgrVH5MbH9CYQ==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + + esprima@4.0.1: + resolution: {integrity: sha512-eGuFFw7Upda+g4p+QHvnW0RyTX/SVeJBDM/gCtMARO0cLuT2HcEKnTPvhjV6aGeqrCB/sbNop0Kszm0jsaWU4A==} + engines: {node: '>=4'} + hasBin: true + + esquery@1.6.0: + resolution: {integrity: sha512-ca9pw9fomFcKPvFLXhBKUK90ZvGibiGOvRJNbjljY7s7uq/5YO4BOzcYtJqExdx99rF6aAcnRxHmcUHcz6sQsg==} + engines: {node: '>=0.10'} + + esrecurse@4.3.0: + resolution: {integrity: sha512-KmfKL3b6G+RXvP8N1vr3Tq1kL/oCFgn2NYXEtqP8/L3pKapUA4G8cFVaoF3SU323CD4XypR/ffioHmkti6/Tag==} + engines: {node: '>=4.0'} + + estraverse@4.3.0: + resolution: {integrity: sha512-39nnKffWz8xN1BU/2c79n9nB9HDzo0niYUqx6xyqUnyoAnQyyWpOTdZEeiCch8BBu515t4wp9ZmgVfVhn9EBpw==} + engines: {node: '>=4.0'} + 
+ estraverse@5.3.0: + resolution: {integrity: sha512-MMdARuVEQziNTeJD8DgMqmhwR11BRQ/cBP+pLtYdSTnf3MIO8fFeiINEbX36ZdNlfU/7A9f3gUw49B3oQsvwBA==} + engines: {node: '>=4.0'} + + estree-util-is-identifier-name@3.0.0: + resolution: {integrity: sha512-hFtqIDZTIUZ9BXLb8y4pYGyk6+wekIivNVTcmvk8NoOh+VeRn5y6cEHzbURrWbfp1fIqdVipilzj+lfaadNZmg==} + + estree-walker@2.0.2: + resolution: {integrity: sha512-Rfkk/Mp/DL7JVje3u18FxFujQlTNR2q6QfMSMB7AvCBx91NGj/ba3kCfza0f6dVDbw7YlRf/nDrn7pQrCCyQ/w==} + + estree-walker@3.0.3: + resolution: {integrity: sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==} + + esutils@2.0.3: + resolution: {integrity: sha512-kVscqXk4OCp68SZ0dkgEKVi6/8ij300KBWTJq32P/dYeWTSwK41WyTxalN1eRmA5Z9UU/LX9D7FWSmV9SAYx6g==} + engines: {node: '>=0.10.0'} + + event-target-shim@5.0.1: + resolution: {integrity: sha512-i/2XbnSz/uxRCU6+NdVJgKWDTM427+MqYbkQzD321DuCQJUqOuJKIA0IM2+W2xtYHdKOmZ4dR6fExsd4SXL+WQ==} + engines: {node: '>=6'} + + events@3.3.0: + resolution: {integrity: sha512-mQw+2fkQbALzQ7V0MY0IqdnXNOeTtP4r0lN9z7AAawCXgqea7bDii20AYrIBrFd/Hx0M2Ocz6S111CaFkUcb0Q==} + engines: {node: '>=0.8.x'} + + execa@5.1.1: + resolution: {integrity: sha512-8uSpZZocAZRBAPIEINJj3Lo9HyGitllczc27Eh5YYojjMFMn8yHMDMaUHE2Jqfq05D/wucwI4JGURyXt1vchyg==} + engines: {node: '>=10'} + + expand-template@2.0.3: + resolution: {integrity: sha512-XYfuKMvj4O35f/pOXLObndIRvyQ+/+6AhODh+OKWj9S9498pHHn/IMszH+gt0fBCRWMNfk1ZSp5x3AifmnI2vg==} + engines: {node: '>=6'} + + expect-type@1.2.2: + resolution: {integrity: sha512-JhFGDVJ7tmDJItKhYgJCGLOWjuK9vPxiXoUFLwLDc99NlmklilbiQJwoctZtt13+xMw91MCk/REan6MWHqDjyA==} + engines: {node: '>=12.0.0'} + + exsolve@1.0.7: + resolution: {integrity: sha512-VO5fQUzZtI6C+vx4w/4BWJpg3s/5l+6pRQEHzFRM8WFi4XffSP1Z+4qi7GbjWbvRQEbdIco5mIMq+zX4rPuLrw==} + + extend@3.0.2: + resolution: {integrity: sha512-fjquC59cD7CyW6urNXK0FBufkZcoiGG80wTuPujX590cB5Ttln20E2UB4S/WARVqhXffZl2LNgS+gQdPIIim/g==} + + fast-check@3.23.2: + resolution: {integrity: 
sha512-h5+1OzzfCC3Ef7VbtKdcv7zsstUQwUDlYpUTvjeUsJAssPgLn7QzbboPtL5ro04Mq0rPOsMzl7q5hIbRs2wD1A==} + engines: {node: '>=8.0.0'} + + fast-deep-equal@3.1.3: + resolution: {integrity: sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==} + + fast-glob@3.3.3: + resolution: {integrity: sha512-7MptL8U0cqcFdzIzwOTHoilX9x5BrNqye7Z/LuC7kCMRio1EMSyqRK3BEAUD7sXRq4iT4AzTVuZdhgQ2TCvYLg==} + engines: {node: '>=8.6.0'} + + fast-json-stable-stringify@2.1.0: + resolution: {integrity: sha512-lhd/wF+Lk98HZoTCtlVraHtfh5XYijIjalXck7saUtuanSDyLMxnHhSXEDJqHxD7msR8D0uCmqlkwjCV8xvwHw==} + + fast-levenshtein@2.0.6: + resolution: {integrity: sha512-DCXu6Ifhqcks7TZKY3Hxp3y6qphY5SJZmrWMDrKcERSOXWQdMhU9Ig/PYrzyw/ul9jOIyh0N4M0tbC5hodg8dw==} + + fast-memoize@2.5.2: + resolution: {integrity: sha512-Ue0LwpDYErFbmNnZSF0UH6eImUwDmogUO1jyE+JbN2gsQz/jICm1Ve7t9QT0rNSsfJt+Hs4/S3GnsDVjL4HVrw==} + + fast-safe-stringify@2.1.1: + resolution: {integrity: sha512-W+KJc2dmILlPplD/H4K9l9LcAHAfPtP6BY84uVLXQ6Evcz9Lcg33Y2z1IVblT6xdY54PXYVHEv+0Wpq8Io6zkA==} + + fast-uri@3.0.6: + resolution: {integrity: sha512-Atfo14OibSv5wAp4VWNsFYE1AchQRTv9cBGWET4pZWHzYshFSS9NQI6I57rdKn9croWVMbYFbLhJ+yJvmZIIHw==} + + fastq@1.19.1: + resolution: {integrity: sha512-GwLTyxkCXjXbxqIhTsMI2Nui8huMPtnxg7krajPJAjnEG/iiOS7i+zCtWGZR9G0NBKbXKh6X9m9UIsYX/N6vvQ==} + + fd-slicer@1.1.0: + resolution: {integrity: sha512-cE1qsB/VwyQozZ+q1dGxR8LBYNZeofhEdUNGSMbQD3Gw2lAzX9Zb3uIU6Ebc/Fmyjo9AWWfnn0AUCHqtevs/8g==} + + fdir@6.4.6: + resolution: {integrity: sha512-hiFoqpyZcfNm1yc4u8oWCf9A2c4D3QjCrks3zmoVKVxpQRzmPNar1hUJcBG2RQHvEVGDN+Jm81ZheVLAQMK6+w==} + peerDependencies: + picomatch: ^3 || ^4 + peerDependenciesMeta: + picomatch: + optional: true + + fflate@0.8.2: + resolution: {integrity: sha512-cPJU47OaAoCbg0pBvzsgpTPhmhqI5eJjh/JIu8tPj5q+T7iLvW/JAYUqmE7KOB4R1ZyEhzBaIQpQpardBF5z8A==} + + file-entry-cache@8.0.0: + resolution: {integrity: 
sha512-XXTUwCvisa5oacNGRP9SfNtYBNAMi+RPwBFmblZEF7N7swHYQS6/Zfk7SRwx4D5j3CH211YNRco1DEMNVfZCnQ==} + engines: {node: '>=16.0.0'} + + filesize@10.1.6: + resolution: {integrity: sha512-sJslQKU2uM33qH5nqewAwVB2QgR6w1aMNsYUp3aN5rMRyXEwJGmZvaWzeJFNTOXWlHQyBFCWrdj3fV/fsTOX8w==} + engines: {node: '>= 10.4.0'} + + fill-range@7.1.1: + resolution: {integrity: sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==} + engines: {node: '>=8'} + + find-replace@5.0.2: + resolution: {integrity: sha512-Y45BAiE3mz2QsrN2fb5QEtO4qb44NcS7en/0y9PEVsg351HsLeVclP8QPMH79Le9sH3rs5RSwJu99W0WPZO43Q==} + engines: {node: '>=14'} + peerDependencies: + '@75lb/nature': latest + peerDependenciesMeta: + '@75lb/nature': + optional: true + + find-up@5.0.0: + resolution: {integrity: sha512-78/PXT1wlLLDgTzDs7sjq9hzz0vXD+zn+7wypEe4fXQxCmdmqfGsEPQxmiCSQI3ajFV91bVSsvNtrJRiW6nGng==} + engines: {node: '>=10'} + + find-up@7.0.0: + resolution: {integrity: sha512-YyZM99iHrqLKjmt4LJDj58KI+fYyufRLBSYcqycxf//KpBk9FoewoGX0450m9nB44qrZnovzC2oeP5hUibxc/g==} + engines: {node: '>=18'} + + flat-cache@4.0.1: + resolution: {integrity: sha512-f7ccFPK3SXFHpx15UIGyRJ/FJQctuKZ0zVuN3frBo4HnK3cay9VEW0R6yPYFHC0AgqhukPzKjq22t5DmAyqGyw==} + engines: {node: '>=16'} + + flat@5.0.2: + resolution: {integrity: sha512-b6suED+5/3rTpUBdG1gupIl8MPFCAMA0QXwmljLhvCUKcUvdE4gWky9zpuGCcXHOsz4J9wPGNWq6OKpmIzz3hQ==} + hasBin: true + + flatbuffers@24.12.23: + resolution: {integrity: sha512-dLVCAISd5mhls514keQzmEG6QHmUUsNuWsb4tFafIUwvvgDjXhtfAYSKOzt5SWOy+qByV5pbsDZ+Vb7HUOBEdA==} + + flatted@3.3.3: + resolution: {integrity: sha512-GX+ysw4PBCz0PzosHDepZGANEuFCMLrnRTiEy9McGjmkCQYwRq4A/X786G/fjM/+OjsWSU1ZrY5qyARZmO/uwg==} + + for-each@0.3.5: + resolution: {integrity: sha512-dKx12eRCVIzqCxFGplyFKJMPvLEWgmNtUrpTiJIR5u97zEhRG8ySrtboPHZXx7daLxQVrl643cTzbab2tkQjxg==} + engines: {node: '>= 0.4'} + + foreground-child@3.3.1: + resolution: {integrity: 
sha512-gIXjKqtFuWEgzFRJA9WCQeSJLZDjgJUOMCMzxtvFq/37KojM1BFGufqsCy0r4qSQmYLsZYMeyRqzIWOMup03sw==} + engines: {node: '>=14'} + + form-data@4.0.4: + resolution: {integrity: sha512-KrGhL9Q4zjj0kiUt5OO4Mr/A/jlI2jDYs5eHBpYHPcBEVSiipAvn2Ko2HnPe20rmcuuvMHNdZFp+4IlGTMF0Ow==} + engines: {node: '>= 6'} + + fraction.js@4.3.7: + resolution: {integrity: sha512-ZsDfxO51wGAXREY55a7la9LScWpwv9RxIrYABrlvOFBlH/ShPnrtsXeuUIfXKKOVicNxQ+o8JTbJvjS4M89yew==} + + fs-constants@1.0.0: + resolution: {integrity: sha512-y6OAwoSIf7FyjMIv94u+b5rdheZEjzR63GTyZJm5qh4Bi+2YgwLCcI/fPFZkL5PSixOt6ZNKm+w+Hfp/Bciwow==} + + fs-extra@11.3.0: + resolution: {integrity: sha512-Z4XaCL6dUDHfP/jT25jJKMmtxvuwbkrD1vNSMFlo9lNLY2c5FHYSQgHPRZUjAB26TpDEoW9HCOgplrdbaPV/ew==} + engines: {node: '>=14.14'} + + fs.realpath@1.0.0: + resolution: {integrity: sha512-OO0pH2lK6a0hZnAdau5ItzHPI6pUlvI7jMVnxUQRtw4owF2wk8lOSabtGDCTP4Ggrg2MbGnWO9X8K1t4+fGMDw==} + + fsevents@2.3.2: + resolution: {integrity: sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==} + engines: {node: ^8.16.0 || ^10.6.0 || >=11.0.0} + os: [darwin] + + fsevents@2.3.3: + resolution: {integrity: sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==} + engines: {node: ^8.16.0 || ^10.6.0 || >=11.0.0} + os: [darwin] + + function-bind@1.1.2: + resolution: {integrity: sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==} + + function.prototype.name@1.1.8: + resolution: {integrity: sha512-e5iwyodOHhbMr/yNrc7fDYG4qlbIvI5gajyzPnb5TCwyhjApznQh1BMFou9b30SevY43gCJKXycoCBjMbsuW0Q==} + engines: {node: '>= 0.4'} + + functions-have-names@1.2.3: + resolution: {integrity: sha512-xckBUXyTIqT97tq2x2AMb+g163b5JFysYk0x4qxNFwbfQkmNZoiRHb6sPzI9/QV33WeuvVYBUIiD4NzNIyqaRQ==} + + fuse.js@7.1.0: + resolution: {integrity: sha512-trLf4SzuuUxfusZADLINj+dE8clK1frKdmqiJNb1Es75fmI5oY6X2mxLVUciLLjxqw/xr72Dhy+lER6dGd02FQ==} + engines: {node: '>=10'} + + 
gensync@1.0.0-beta.2: + resolution: {integrity: sha512-3hN7NaskYvMDLQY55gnW3NQ+mesEAepTqlg+VEbj7zzqEMBVNhzcGYYeqFo/TlYz6eQiFcp1HcsCZO+nGgS8zg==} + engines: {node: '>=6.9.0'} + + get-caller-file@2.0.5: + resolution: {integrity: sha512-DyFP3BM/3YHTQOCUL/w0OZHR0lpKeGrxotcHWcqNEdnltqFwXVfhEBQ94eIo34AfQpo0rGki4cyIiftY06h2Fg==} + engines: {node: 6.* || 8.* || >= 10.*} + + get-east-asian-width@1.3.0: + resolution: {integrity: sha512-vpeMIQKxczTD/0s2CdEWHcb0eeJe6TFjxb+J5xgX7hScxqrGuyjmv4c1D4A/gelKfyox0gJJwIHF+fLjeaM8kQ==} + engines: {node: '>=18'} + + get-intrinsic@1.3.0: + resolution: {integrity: sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==} + engines: {node: '>= 0.4'} + + get-nonce@1.0.1: + resolution: {integrity: sha512-FJhYRoDaiatfEkUK8HKlicmu/3SGFD51q3itKDGoSTysQJBnfOcxU5GxnhE1E6soB76MbT0MBtnKJuXyAx+96Q==} + engines: {node: '>=6'} + + get-proto@1.0.1: + resolution: {integrity: sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==} + engines: {node: '>= 0.4'} + + get-stream@6.0.1: + resolution: {integrity: sha512-ts6Wi+2j3jQjqi70w5AlN8DFnkSwC+MqmxEzdEALB2qXZYV3X/b1CTfgPLGJNMeAWxdPfU8FO1ms3NUfaHCPYg==} + engines: {node: '>=10'} + + get-symbol-description@1.1.0: + resolution: {integrity: sha512-w9UMqWwJxHNOvoNzSJ2oPF5wvYcvP7jUvYzhp67yEhTi17ZDBBC1z9pTdGuzjD+EFIqLSYRweZjqfiPzQ06Ebg==} + engines: {node: '>= 0.4'} + + get-tsconfig@4.10.1: + resolution: {integrity: sha512-auHyJ4AgMz7vgS8Hp3N6HXSmlMdUyhSUrfBF16w153rxtLIEOE+HGqaBppczZvnHLqQJfiHotCYpNhl0lUROFQ==} + + github-from-package@0.0.0: + resolution: {integrity: sha512-SyHy3T1v2NUXn29OsWdxmK6RwHD+vkj3v8en8AOBZ1wBQ/hCAQ5bAQTD02kW4W9tUp/3Qh6J8r9EvntiyCmOOw==} + + glob-parent@5.1.2: + resolution: {integrity: sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==} + engines: {node: '>= 6'} + + glob-parent@6.0.2: + resolution: {integrity: 
sha512-XxwI8EOhVQgWp6iDL+3b0r86f4d6AX6zSU55HfB4ydCEuXLXc5FcYeOu+nnGftS4TEju/11rt4KJPTMgbfmv4A==} + engines: {node: '>=10.13.0'} + + glob-to-regexp@0.4.1: + resolution: {integrity: sha512-lkX1HJXwyMcprw/5YUZc2s7DrpAiHB21/V+E1rHUrVNokkvB6bqMzT0VfV6/86ZNabt1k14YOIaT7nDvOX3Iiw==} + + glob@10.4.5: + resolution: {integrity: sha512-7Bv8RF0k6xjo7d4A/PxYLbUCfb6c+Vpd2/mB2yRDlew7Jb5hEXiCD9ibfO7wpk8i4sevK6DFny9h7EYbM3/sHg==} + hasBin: true + + glob@11.0.3: + resolution: {integrity: sha512-2Nim7dha1KVkaiF4q6Dj+ngPPMdfvLJEOpZk/jKiUAkqKebpGAWQXAq9z1xu9HKu5lWfqw/FASuccEjyznjPaA==} + engines: {node: 20 || >=22} + hasBin: true + + glob@7.2.3: + resolution: {integrity: sha512-nFR0zLpU2YCaRxwoCJvL6UvCH2JFyFVIvwTLsIf21AuHlMskA1hhTdk+LlYJtOlYt9v6dvszD2BGRqBL+iQK9Q==} + deprecated: Glob versions prior to v9 are no longer supported + + glob@8.1.0: + resolution: {integrity: sha512-r8hpEjiQEYlF2QU0df3dS+nxxSIreXQS1qRhMJM0Q5NDdR386C7jb7Hwwod8Fgiuex+k0GFjgft18yvxm5XoCQ==} + engines: {node: '>=12'} + deprecated: Glob versions prior to v9 are no longer supported + + globals@14.0.0: + resolution: {integrity: sha512-oahGvuMGQlPw/ivIYBjVSrWAfWLBeku5tpPE2fOPLi+WHffIWbuh2tCjhyQhTBPMf5E9jDEH4FOmTYgYwbKwtQ==} + engines: {node: '>=18'} + + globals@16.3.0: + resolution: {integrity: sha512-bqWEnJ1Nt3neqx2q5SFfGS8r/ahumIakg3HcwtNlrVlwXIeNumWn/c7Pn/wKzGhf6SaW6H6uWXLqC30STCMchQ==} + engines: {node: '>=18'} + + globalthis@1.0.4: + resolution: {integrity: sha512-DpLKbNU4WylpxJykQujfCcwYWiV/Jhm50Goo0wrVILAv5jOr9d+H+UR3PhSCD2rCCEIg0uc+G+muBTwD54JhDQ==} + engines: {node: '>= 0.4'} + + globby@11.1.0: + resolution: {integrity: sha512-jhIXaOzy1sb8IyocaruWSn1TjmnBVs8Ayhcy83rmxNJ8q2uWKCAj3CnJY+KpGSXCueAPc0i05kVvVKtP1t9S3g==} + engines: {node: '>=10'} + + globby@14.1.0: + resolution: {integrity: sha512-0Ia46fDOaT7k4og1PDW4YbodWWr3scS2vAr2lTbsplOt2WkKp0vQbkI9wKis/T5LV/dqPjO3bpS/z6GTJB82LA==} + engines: {node: '>=18'} + + goober@2.1.16: + resolution: {integrity: 
sha512-erjk19y1U33+XAMe1VTvIONHYoSqE4iS7BYUZfHaqeohLmnC0FdxEh7rQU+6MZ4OajItzjZFSRtVANrQwNq6/g==} + peerDependencies: + csstype: ^3.0.10 + + gopd@1.2.0: + resolution: {integrity: sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==} + engines: {node: '>= 0.4'} + + graceful-fs@4.2.11: + resolution: {integrity: sha512-RbJ5/jmFcNNCcDV5o9eTnBLJ/HszWV0P73bc+Ff4nS/rJj+YaS6IGyiOL0VoBYX+l1Wrl3k63h/KrH+nhJ0XvQ==} + + graphemer@1.4.0: + resolution: {integrity: sha512-EtKwoO6kxCL9WO5xipiHTZlSzBm7WLT627TqC/uVRd0HKmq8NXyebnNYxDoBi7wt8eTWrUrKXCOVaFq9x1kgag==} + + graphlib@2.1.8: + resolution: {integrity: sha512-jcLLfkpoVGmH7/InMC/1hIvOPSUh38oJtGhvrOFGzioE1DZ+0YW16RgmOJhHiuWTvGiJQ9Z1Ik43JvkRPRvE+A==} + + has-bigints@1.1.0: + resolution: {integrity: sha512-R3pbpkcIqv2Pm3dUwgjclDRVmWpTJW2DcMzcIhEXEx1oh/CEMObMm3KLmRJOdvhM7o4uQBnwr8pzRK2sJWIqfg==} + engines: {node: '>= 0.4'} + + has-flag@4.0.0: + resolution: {integrity: sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==} + engines: {node: '>=8'} + + has-property-descriptors@1.0.2: + resolution: {integrity: sha512-55JNKuIW+vq4Ke1BjOTjM2YctQIvCT7GFzHwmfZPGo5wnrgkid0YQtnAleFSqumZm4az3n2BS+erby5ipJdgrg==} + + has-proto@1.2.0: + resolution: {integrity: sha512-KIL7eQPfHQRC8+XluaIw7BHUwwqL19bQn4hzNgdr+1wXoU0KKj6rufu47lhY7KbJR2C6T6+PfyN0Ea7wkSS+qQ==} + engines: {node: '>= 0.4'} + + has-symbols@1.1.0: + resolution: {integrity: sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==} + engines: {node: '>= 0.4'} + + has-tostringtag@1.0.2: + resolution: {integrity: sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==} + engines: {node: '>= 0.4'} + + hasown@2.0.2: + resolution: {integrity: sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==} + engines: {node: '>= 0.4'} + + hast-util-to-jsx-runtime@2.3.6: + resolution: {integrity: 
sha512-zl6s8LwNyo1P9uw+XJGvZtdFF1GdAkOg8ujOw+4Pyb76874fLps4ueHXDhXWdk6YHQ6OgUtinliG7RsYvCbbBg==} + + hast-util-whitespace@3.0.0: + resolution: {integrity: sha512-88JUN06ipLwsnv+dVn+OIYOvAuvBMy/Qoi6O7mQHxdPXpjy+Cd6xRkWwux7DKO+4sYILtLBRIKgsdpS2gQc7qw==} + + he@1.2.0: + resolution: {integrity: sha512-F/1DnUGPopORZi0ni+CvrCgHQ5FyEAHRLSApuYWMmrbSwoN2Mn/7k+Gl38gJnR7yyDZk6WLXwiGod1JOWNDKGw==} + hasBin: true + + hoist-non-react-statics@3.3.2: + resolution: {integrity: sha512-/gGivxi8JPKWNm/W0jSmzcMPpfpPLc3dY/6GxhX2hQ9iGj3aDfklV4ET7NjKpSinLpJ5vafa9iiGIEZg10SfBw==} + + hosted-git-info@4.1.0: + resolution: {integrity: sha512-kyCuEOWjJqZuDbRHzL8V93NzQhwIB71oFWSyzVo+KPZI+pnQPPxucdkrOZvkLRnrf5URsQM+IJ09Dw29cRALIA==} + engines: {node: '>=10'} + + hosted-git-info@7.0.2: + resolution: {integrity: sha512-puUZAUKT5m8Zzvs72XWy3HtvVbTWljRE66cP60bxJzAqf2DgICo7lYTY2IHUmLnNpjYvw5bvmoHvPc0QO2a62w==} + engines: {node: ^16.14.0 || >=18.0.0} + + hosted-git-info@8.1.0: + resolution: {integrity: sha512-Rw/B2DNQaPBICNXEm8balFz9a6WpZrkCGpcWFpy7nCj+NyhSdqXipmfvtmWt9xGfp0wZnBxB+iVpLmQMYt47Tw==} + engines: {node: ^18.17.0 || >=20.5.0} + + html-encoding-sniffer@4.0.0: + resolution: {integrity: sha512-Y22oTqIU4uuPgEemfz7NDJz6OeKf12Lsu+QC+s3BVpda64lTiMYCyGwg5ki4vFxkMwQdeZDl2adZoqUgdFuTgQ==} + engines: {node: '>=18'} + + html-escaper@2.0.2: + resolution: {integrity: sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg==} + + html-url-attributes@3.0.1: + resolution: {integrity: sha512-ol6UPyBWqsrO6EJySPz2O7ZSr856WDrEzM5zMqp+FJJLGMW35cLYmmZnl0vztAZxRUoNZJFTCohfjuIJ8I4QBQ==} + + htmlparser2@10.0.0: + resolution: {integrity: sha512-TwAZM+zE5Tq3lrEHvOlvwgj1XLWQCtaaibSN11Q+gGBAS7Y1uZSWwXXRe4iF6OXnaq1riyQAPFOBtYc77Mxq0g==} + + http-proxy-agent@7.0.2: + resolution: {integrity: sha512-T1gkAiYYDWYx3V5Bmyu7HcfcvL7mUrTWiM6yOfa3PIphViJ/gFPbvidQ+veqSOHci/PxBcDabeUNCzpOODJZig==} + engines: {node: '>= 14'} + + http2-client@1.3.5: + resolution: {integrity: 
sha512-EC2utToWl4RKfs5zd36Mxq7nzHHBuomZboI0yYL6Y0RmBgT7Sgkq4rQ0ezFTYoIsSs7Tm9SJe+o2FcAg6GBhGA==} + + https-proxy-agent@7.0.6: + resolution: {integrity: sha512-vK9P5/iUfdl95AI+JVyUuIcVtd4ofvtrOr3HNtM2yxC9bnMbEdp3x01OhQNnjb8IJYi38VlTE3mBXwcfvywuSw==} + engines: {node: '>= 14'} + + human-signals@2.1.0: + resolution: {integrity: sha512-B4FFZ6q/T2jhhksgkbEW3HBvWIfDW85snkQgawt07S7J5QXTk6BkNV+0yAeZrM5QpMAdYlocGoljn0sJ/WQkFw==} + engines: {node: '>=10.17.0'} + + iconv-lite@0.6.3: + resolution: {integrity: sha512-4fCk79wshMdzMp2rH06qWrJE4iolqLhCUH+OiuIgU++RB0+94NlDL81atO7GX55uUKueo0txHNtvEyI6D7WdMw==} + engines: {node: '>=0.10.0'} + + ieee754@1.2.1: + resolution: {integrity: sha512-dcyqhDvX1C46lXZcVqCpK+FtMRQVdIMN6/Df5js2zouUsqG7I6sFxitIC+7KYK29KdXOLHdu9zL4sFnoVQnqaA==} + + ignore@5.3.2: + resolution: {integrity: sha512-hsBTNUqQTDwkWtcdYI2i06Y/nUBEsNEDJKjWdigLvegy8kDuJAS8uRlpkkcQpyEXL0Z/pjDy5HBmMjRCJ2gq+g==} + engines: {node: '>= 4'} + + ignore@7.0.5: + resolution: {integrity: sha512-Hs59xBNfUIunMFgWAbGX5cq6893IbWg4KnrjbYwX3tx0ztorVgTDA6B2sxf8ejHJ4wz8BqGUMYlnzNBer5NvGg==} + engines: {node: '>= 4'} + + immediate@3.0.6: + resolution: {integrity: sha512-XXOFtyqDjNDAQxVfYxuF7g9Il/IbWmmlQg2MYKOH8ExIT1qg6xc4zyS3HaEEATgs1btfzxq15ciUiY7gjSXRGQ==} + + immer@9.0.21: + resolution: {integrity: sha512-bc4NBHqOqSfRW7POMkHd51LvClaeMXpm8dx0e8oE2GORbq5aRK7Bxl4FyzVLdGtLmvLKL7BTDBG5ACQm4HWjTA==} + + import-fresh@3.3.1: + resolution: {integrity: sha512-TR3KfrTZTYLPB6jUjfx6MF9WcWrHL9su5TObK4ZkYgBdWKPOFoSoQIdEuTuR82pmtxH2spWG9h6etwfr1pLBqQ==} + engines: {node: '>=6'} + + import-lazy@4.0.0: + resolution: {integrity: sha512-rKtvo6a868b5Hu3heneU+L4yEQ4jYKLtjpnPeUdK7h0yzXGmyBTypknlkCvHFBqfX9YlorEiMM6Dnq/5atfHkw==} + engines: {node: '>=8'} + + imurmurhash@0.1.4: + resolution: {integrity: sha512-JmXMZ6wuvDmLiHEml9ykzqO6lwFbof0GG4IkcGaENdCRDDmMVnny7s5HsIgHCbaq0w2MyPhDqkhTUgS2LU2PHA==} + engines: {node: '>=0.8.19'} + + indent-string@4.0.0: + resolution: {integrity: 
sha512-EdDDZu4A2OyIK7Lr/2zG+w5jmbuk1DVBnEwREQvBzspBJkCEbRa8GxU1lghYcaGJCnRWibjDXlq779X1/y5xwg==} + engines: {node: '>=8'} + + index-to-position@1.1.0: + resolution: {integrity: sha512-XPdx9Dq4t9Qk1mTMbWONJqU7boCoumEH7fRET37HX5+khDUl3J2W6PdALxhILYlIYx2amlwYcRPp28p0tSiojg==} + engines: {node: '>=18'} + + inflight@1.0.6: + resolution: {integrity: sha512-k92I/b08q4wvFscXCLvqfsHCrjrF7yiXsQuIVvVE7N82W3+aqpzuUdBbfhWcy/FZR3/4IgflMgKLOsvPDrGCJA==} + deprecated: This module is not supported, and leaks memory. Do not use it. Check out lru-cache if you want a good and tested way to coalesce async requests by a key value, which is much more comprehensive and powerful. + + inherits@2.0.4: + resolution: {integrity: sha512-k/vGaX4/Yla3WzyMCvTQOXYeIHvqOKtnqBduzTHpzpQZzAskKMhZ2K+EnBiSM9zGSoIFeMpXKxa4dYeZIQqewQ==} + + ini@1.3.8: + resolution: {integrity: sha512-JV/yugV2uzW5iMRSiZAyDtQd+nxtUnjeLt0acNdw98kKLrvuRVyB80tsREOE7yvGVgalhZ6RNXCmEHkUKBKxew==} + + inline-style-parser@0.2.4: + resolution: {integrity: sha512-0aO8FkhNZlj/ZIbNi7Lxxr12obT7cL1moPfE4tg1LkX7LlLfC6DeX4l2ZEud1ukP9jNQyNnfzQVqwbwmAATY4Q==} + + internal-slot@1.1.0: + resolution: {integrity: sha512-4gd7VpWNQNB4UKKCFFVcp1AVv+FMOgs9NKzjHKusc8jTMhd5eL1NqQqOpE0KzMds804/yHlglp3uxgluOqAPLw==} + engines: {node: '>= 0.4'} + + is-alphabetical@2.0.1: + resolution: {integrity: sha512-FWyyY60MeTNyeSRpkM2Iry0G9hpr7/9kD40mD/cGQEuilcZYS4okz8SN2Q6rLCJ8gbCt6fN+rC+6tMGS99LaxQ==} + + is-alphanumerical@2.0.1: + resolution: {integrity: sha512-hmbYhX/9MUMF5uh7tOXyK/n0ZvWpad5caBA17GsC6vyuCqaWliRG5K1qS9inmUhEMaOBIW7/whAnSwveW/LtZw==} + + is-array-buffer@3.0.5: + resolution: {integrity: sha512-DDfANUiiG2wC1qawP66qlTugJeL5HyzMpfr8lLK+jMQirGzNod0B12cFB/9q838Ru27sBwfw78/rdoU7RERz6A==} + engines: {node: '>= 0.4'} + + is-arrayish@0.2.1: + resolution: {integrity: sha512-zz06S8t0ozoDXMG+ube26zeCTNXcKIPJZJi8hBrF4idCLms4CG9QtK7qBl1boi5ODzFpjswb5JPmHCbMpjaYzg==} + + is-async-function@2.1.1: + resolution: {integrity: 
sha512-9dgM/cZBnNvjzaMYHVoxxfPj2QXt22Ev7SuuPrs+xav0ukGB0S6d4ydZdEiM48kLx5kDV+QBPrpVnFyefL8kkQ==} + engines: {node: '>= 0.4'} + + is-bigint@1.1.0: + resolution: {integrity: sha512-n4ZT37wG78iz03xPRKJrHTdZbe3IicyucEtdRsV5yglwc3GyUfbAfpSeD0FJ41NbUNSt5wbhqfp1fS+BgnvDFQ==} + engines: {node: '>= 0.4'} + + is-binary-path@2.1.0: + resolution: {integrity: sha512-ZMERYes6pDydyuGidse7OsHxtbI7WVeUEozgR/g7rd0xUimYNlvZRE/K2MgZTjWy725IfelLeVcEM97mmtRGXw==} + engines: {node: '>=8'} + + is-boolean-object@1.2.2: + resolution: {integrity: sha512-wa56o2/ElJMYqjCjGkXri7it5FbebW5usLw/nPmCMs5DeZ7eziSYZhSmPRn0txqeW4LnAmQQU7FgqLpsEFKM4A==} + engines: {node: '>= 0.4'} + + is-callable@1.2.7: + resolution: {integrity: sha512-1BC0BVFhS/p0qtw6enp8e+8OD0UrK0oFLztSjNzhcKA3WDuJxxAPXzPuPtKkjEY9UUoEWlX/8fgKeu2S8i9JTA==} + engines: {node: '>= 0.4'} + + is-core-module@2.16.1: + resolution: {integrity: sha512-UfoeMA6fIJ8wTYFEUjelnaGI67v6+N7qXJEvQuIGa99l4xsCruSYOVSQ0uPANn4dAzm8lkYPaKLrrijLq7x23w==} + engines: {node: '>= 0.4'} + + is-data-view@1.0.2: + resolution: {integrity: sha512-RKtWF8pGmS87i2D6gqQu/l7EYRlVdfzemCJN/P3UOs//x1QE7mfhvzHIApBTRf7axvT6DMGwSwBXYCT0nfB9xw==} + engines: {node: '>= 0.4'} + + is-date-object@1.1.0: + resolution: {integrity: sha512-PwwhEakHVKTdRNVOw+/Gyh0+MzlCl4R6qKvkhuvLtPMggI1WAHt9sOwZxQLSGpUaDnrdyDsomoRgNnCfKNSXXg==} + engines: {node: '>= 0.4'} + + is-decimal@2.0.1: + resolution: {integrity: sha512-AAB9hiomQs5DXWcRB1rqsxGUstbRroFOPPVAomNk/3XHR5JyEZChOyTWe2oayKnsSsr/kcGqF+z6yuH6HHpN0A==} + + is-docker@2.2.1: + resolution: {integrity: sha512-F+i2BKsFrH66iaUFc0woD8sLy8getkwTwtOBjvs56Cx4CgJDeKQeqfz8wAYiSb8JOprWhHH5p77PbmYCvvUuXQ==} + engines: {node: '>=8'} + hasBin: true + + is-docker@3.0.0: + resolution: {integrity: sha512-eljcgEDlEns/7AXFosB5K/2nCM4P7FQPkGc/DWLy5rmFEWvZayGrik1d9/QIY5nJ4f9YsVvBkA6kJpHn9rISdQ==} + engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} + hasBin: true + + is-extglob@2.1.1: + resolution: {integrity: 
sha512-SbKbANkN603Vi4jEZv49LeVJMn4yGwsbzZworEoyEiutsN3nJYdbO36zfhGJ6QEDpOZIFkDtnq5JRxmvl3jsoQ==} + engines: {node: '>=0.10.0'} + + is-finalizationregistry@1.1.1: + resolution: {integrity: sha512-1pC6N8qWJbWoPtEjgcL2xyhQOP491EQjeUo3qTKcmV8YSDDJrOepfG8pcC7h/QgnQHYSv0mJ3Z/ZWxmatVrysg==} + engines: {node: '>= 0.4'} + + is-fullwidth-code-point@3.0.0: + resolution: {integrity: sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==} + engines: {node: '>=8'} + + is-generator-function@1.1.0: + resolution: {integrity: sha512-nPUB5km40q9e8UfN/Zc24eLlzdSf9OfKByBw9CIdw4H1giPMeA0OIJvbchsCu4npfI2QcMVBsGEBHKZ7wLTWmQ==} + engines: {node: '>= 0.4'} + + is-glob@4.0.3: + resolution: {integrity: sha512-xelSayHH36ZgE7ZWhli7pW34hNbNl8Ojv5KVmkJD4hBdD3th8Tfk9vYasLM+mXWOZhFkgZfxhLSnrwRr4elSSg==} + engines: {node: '>=0.10.0'} + + is-hexadecimal@2.0.1: + resolution: {integrity: sha512-DgZQp241c8oO6cA1SbTEWiXeoxV42vlcJxgH+B3hi1AiqqKruZR3ZGF8In3fj4+/y/7rHvlOZLZtgJ/4ttYGZg==} + + is-inside-container@1.0.0: + resolution: {integrity: sha512-KIYLCCJghfHZxqjYBE7rEy0OBuTd5xCHS7tHVgvCLkx7StIoaxwNW3hCALgEUjFfeRk+MG/Qxmp/vtETEF3tRA==} + engines: {node: '>=14.16'} + hasBin: true + + is-interactive@2.0.0: + resolution: {integrity: sha512-qP1vozQRI+BMOPcjFzrjXuQvdak2pHNUMZoeG2eRbiSqyvbEf/wQtEOTOX1guk6E3t36RkaqiSt8A/6YElNxLQ==} + engines: {node: '>=12'} + + is-map@2.0.3: + resolution: {integrity: sha512-1Qed0/Hr2m+YqxnM09CjA2d/i6YZNfF6R2oRAOj36eUdS6qIV/huPJNSEpKbupewFs+ZsJlxsjjPbc0/afW6Lw==} + engines: {node: '>= 0.4'} + + is-negative-zero@2.0.3: + resolution: {integrity: sha512-5KoIu2Ngpyek75jXodFvnafB6DJgr3u8uuK0LEZJjrU19DrMD3EVERaR8sjz8CCGgpZvxPl9SuE1GMVPFHx1mw==} + engines: {node: '>= 0.4'} + + is-number-object@1.1.1: + resolution: {integrity: sha512-lZhclumE1G6VYD8VHe35wFaIif+CTy5SJIi5+3y4psDgWu4wPDoBhF8NxUOinEc7pHgiTsT6MaBb92rKhhD+Xw==} + engines: {node: '>= 0.4'} + + is-number@7.0.0: + resolution: {integrity: 
sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==} + engines: {node: '>=0.12.0'} + + is-plain-obj@2.1.0: + resolution: {integrity: sha512-YWnfyRwxL/+SsrWYfOpUtz5b3YD+nyfkHvjbcanzk8zgyO4ASD67uVMRt8k5bM4lLMDnXfriRhOpemw+NfT1eA==} + engines: {node: '>=8'} + + is-plain-obj@4.1.0: + resolution: {integrity: sha512-+Pgi+vMuUNkJyExiMBt5IlFoMyKnr5zhJ4Uspz58WOhBF5QoIZkFyNHIbBAtHwzVAgk5RtndVNsDRN61/mmDqg==} + engines: {node: '>=12'} + + is-potential-custom-element-name@1.0.1: + resolution: {integrity: sha512-bCYeRA2rVibKZd+s2625gGnGF/t7DSqDs4dP7CrLA1m7jKWz6pps0LpYLJN8Q64HtmPKJ1hrN3nzPNKFEKOUiQ==} + + is-regex@1.2.1: + resolution: {integrity: sha512-MjYsKHO5O7mCsmRGxWcLWheFqN9DJ/2TmngvjKXihe6efViPqc274+Fx/4fYj/r03+ESvBdTXK0V6tA3rgez1g==} + engines: {node: '>= 0.4'} + + is-set@2.0.3: + resolution: {integrity: sha512-iPAjerrse27/ygGLxw+EBR9agv9Y6uLeYVJMu+QNCoouJ1/1ri0mGrcWpfCqFZuzzx3WjtwxG098X+n4OuRkPg==} + engines: {node: '>= 0.4'} + + is-shared-array-buffer@1.0.4: + resolution: {integrity: sha512-ISWac8drv4ZGfwKl5slpHG9OwPNty4jOWPRIhBpxOoD+hqITiwuipOQ2bNthAzwA3B4fIjO4Nln74N0S9byq8A==} + engines: {node: '>= 0.4'} + + is-stream@2.0.1: + resolution: {integrity: sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg==} + engines: {node: '>=8'} + + is-string@1.1.1: + resolution: {integrity: sha512-BtEeSsoaQjlSPBemMQIrY1MY0uM6vnS1g5fmufYOtnxLGUZM2178PKbhsk7Ffv58IX+ZtcvoGwccYsh0PglkAA==} + engines: {node: '>= 0.4'} + + is-symbol@1.1.1: + resolution: {integrity: sha512-9gGx6GTtCQM73BgmHQXfDmLtfjjTUDSyoxTCbp5WtoixAhfgsDirWIcVQ/IHpvI5Vgd5i/J5F7B9cN/WlVbC/w==} + engines: {node: '>= 0.4'} + + is-typed-array@1.1.15: + resolution: {integrity: sha512-p3EcsicXjit7SaskXHs1hA91QxgTw46Fv6EFKKGS5DRFLD8yKnohjF3hxoju94b/OcMZoQukzpPpBE9uLVKzgQ==} + engines: {node: '>= 0.4'} + + is-unicode-supported@0.1.0: + resolution: {integrity: 
sha512-knxG2q4UC3u8stRGyAVJCOdxFmv5DZiRcdlIaAQXAbSfJya+OhopNotLQrstBhququ4ZpuKbDc/8S6mgXgPFPw==} + engines: {node: '>=10'} + + is-unicode-supported@1.3.0: + resolution: {integrity: sha512-43r2mRvz+8JRIKnWJ+3j8JtjRKZ6GmjzfaE/qiBJnikNnYv/6bagRJ1kUhNk8R5EX/GkobD+r+sfxCPJsiKBLQ==} + engines: {node: '>=12'} + + is-unicode-supported@2.1.0: + resolution: {integrity: sha512-mE00Gnza5EEB3Ds0HfMyllZzbBrmLOX3vfWoj9A9PEnTfratQ/BcaJOuMhnkhjXvb2+FkY3VuHqtAGpTPmglFQ==} + engines: {node: '>=18'} + + is-weakmap@2.0.2: + resolution: {integrity: sha512-K5pXYOm9wqY1RgjpL3YTkF39tni1XajUIkawTLUo9EZEVUFga5gSQJF8nNS7ZwJQ02y+1YCNYcMh+HIf1ZqE+w==} + engines: {node: '>= 0.4'} + + is-weakref@1.1.1: + resolution: {integrity: sha512-6i9mGWSlqzNMEqpCp93KwRS1uUOodk2OJ6b+sq7ZPDSy2WuI5NFIxp/254TytR8ftefexkWn5xNiHUNpPOfSew==} + engines: {node: '>= 0.4'} + + is-weakset@2.0.4: + resolution: {integrity: sha512-mfcwb6IzQyOKTs84CQMrOwW4gQcaTOAWJ0zzJCl2WSPDrWk/OzDaImWFH3djXhb24g4eudZfLRozAvPGw4d9hQ==} + engines: {node: '>= 0.4'} + + is-wsl@2.2.0: + resolution: {integrity: sha512-fKzAra0rGJUUBwGBgNkHZuToZcn+TtXHpeCgmkMJMMYx1sQDYaCSyjJBSCa2nH1DGm7s3n1oBnohoVTBaN7Lww==} + engines: {node: '>=8'} + + is-wsl@3.1.0: + resolution: {integrity: sha512-UcVfVfaK4Sc4m7X3dUSoHoozQGBEFeDC+zVo06t98xe8CzHSZZBekNXH+tu0NalHolcJ/QAGqS46Hef7QXBIMw==} + engines: {node: '>=16'} + + isarray@1.0.0: + resolution: {integrity: sha512-VLghIWNM6ELQzo7zwmcg0NmTVyWKYjvIeM83yjp0wRDTmUnrM678fQbcKBo6n2CJEF0szoG//ytg+TKla89ALQ==} + + isarray@2.0.5: + resolution: {integrity: sha512-xHjhDr3cNBK0BzdUJSPXZntQUx/mwMS5Rw4A7lPJ90XGAO6ISP/ePDNuo0vhqOZU+UD5JoodwCAAoZQd3FeAKw==} + + isbot@5.1.28: + resolution: {integrity: sha512-qrOp4g3xj8YNse4biorv6O5ZShwsJM0trsoda4y7j/Su7ZtTTfVXFzbKkpgcSoDrHS8FcTuUwcU04YimZlZOxw==} + engines: {node: '>=18'} + + isexe@2.0.0: + resolution: {integrity: sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==} + + istanbul-lib-coverage@3.2.2: + resolution: {integrity: 
sha512-O8dpsF+r0WV/8MNRKfnmrtCWhuKjxrq2w+jpzBL5UZKTi2LeVWnWOmWRxFlesJONmc+wLAGvKQZEOanko0LFTg==} + engines: {node: '>=8'} + + istanbul-lib-report@3.0.1: + resolution: {integrity: sha512-GCfE1mtsHGOELCU8e/Z7YWzpmybrx/+dSTfLrvY8qRmaY6zXTKWn6WQIjaAFw069icm6GVMNkgu0NzI4iPZUNw==} + engines: {node: '>=10'} + + istanbul-lib-source-maps@5.0.6: + resolution: {integrity: sha512-yg2d+Em4KizZC5niWhQaIomgf5WlL4vOOjZ5xGCmF8SnPE/mDWWXgvRExdcpCgh9lLRRa1/fSYp2ymmbJ1pI+A==} + engines: {node: '>=10'} + + istanbul-reports@3.1.7: + resolution: {integrity: sha512-BewmUXImeuRk2YY0PVbxgKAysvhRPUQE0h5QRM++nVWyubKGV0l8qQ5op8+B2DOmwSe63Jivj0BjkPQVf8fP5g==} + engines: {node: '>=8'} + + istextorbinary@9.5.0: + resolution: {integrity: sha512-5mbUj3SiZXCuRf9fT3ibzbSSEWiy63gFfksmGfdOzujPjW3k+z8WvIBxcJHBoQNlaZaiyB25deviif2+osLmLw==} + engines: {node: '>=4'} + + jackspeak@3.4.3: + resolution: {integrity: sha512-OGlZQpz2yfahA/Rd1Y8Cd9SIEsqvXkLVoSw/cgwhnhFMDbsQFeZYoJJ7bIZBS9BcamUW96asq/npPWugM+RQBw==} + + jackspeak@4.1.1: + resolution: {integrity: sha512-zptv57P3GpL+O0I7VdMJNBZCu+BPHVQUk55Ft8/QCJjTVxrnJHuVuX/0Bl2A6/+2oyR/ZMEuFKwmzqqZ/U5nPQ==} + engines: {node: 20 || >=22} + + jest-worker@27.5.1: + resolution: {integrity: sha512-7vuh85V5cdDofPyxn58nrPjBktZo0u9x1g8WtjQol+jZDaE+fhN+cIvTj11GndBnMnyfrUOG1sZQxCdjKh+DKg==} + engines: {node: '>= 10.13.0'} + + jiti@1.21.7: + resolution: {integrity: sha512-/imKNG4EbWNrVjoNC/1H5/9GFy+tqjGBHCaSsN+P2RnPqjsLmv6UD3Ej+Kj8nBWaRAwyk7kK5ZUc+OEatnTR3A==} + hasBin: true + + jiti@2.4.2: + resolution: {integrity: sha512-rg9zJN+G4n2nfJl5MW3BMygZX56zKPNVEYYqq7adpmMh4Jn2QNEwhvQlFy6jPVdcod7txZtKHWnyZiA3a0zP7A==} + hasBin: true + + jju@1.4.0: + resolution: {integrity: sha512-8wb9Yw966OSxApiCt0K3yNJL8pnNeIv+OEq2YMidz4FKP6nonSRoOXc80iXY4JaN2FC11B9qsNmDsm+ZOfMROA==} + + js-tokens@4.0.0: + resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==} + + js-tokens@9.0.1: + resolution: {integrity: 
sha512-mxa9E9ITFOt0ban3j6L5MpjwegGz6lBQmM1IJkWeBZGcMxto50+eWdjC/52xDbS2vy0k7vIMK0Fe2wfL9OQSpQ==} + + js-yaml@3.14.1: + resolution: {integrity: sha512-okMH7OXXJ7YrN9Ok3/SXrnu4iX9yOk+25nqX4imS2npuvTYDmo/QEZoqwZkYaIDk3jVvBOTOIEgEhaLOynBS9g==} + hasBin: true + + js-yaml@4.1.0: + resolution: {integrity: sha512-wpxZs9NoxZaJESJGIZTyDEaYpl0FKSA+FB9aJiyemKhMwkxQg63h4T1KJgUGHpTqPDNRcmmYLugrRjJlBtWvRA==} + hasBin: true + + jsdom@26.1.0: + resolution: {integrity: sha512-Cvc9WUhxSMEo4McES3P7oK3QaXldCfNWp7pl2NNeiIFlCoLr3kfq9kb1fxftiwk1FLV7CvpvDfonxtzUDeSOPg==} + engines: {node: '>=18'} + peerDependencies: + canvas: ^3.0.0 + peerDependenciesMeta: + canvas: + optional: true + + jsep@1.4.0: + resolution: {integrity: sha512-B7qPcEVE3NVkmSJbaYxvv4cHkVW7DQsZz13pUMrfS8z8Q/BuShN+gcTXrUlPiGqM2/t/EEaI030bpxMqY8gMlw==} + engines: {node: '>= 10.16.0'} + + jsesc@3.1.0: + resolution: {integrity: sha512-/sM3dO2FOzXjKQhJuo0Q173wf2KOo8t4I8vHy6lF9poUp7bKT0/NHE8fPX23PwfhnykfqnC2xRxOnVw5XuGIaA==} + engines: {node: '>=6'} + hasBin: true + + json-bignum@0.0.3: + resolution: {integrity: sha512-2WHyXj3OfHSgNyuzDbSxI1w2jgw5gkWSWhS7Qg4bWXx1nLk3jnbwfUeS0PSba3IzpTUWdHxBieELUzXRjQB2zg==} + engines: {node: '>=0.8'} + + json-buffer@3.0.1: + resolution: {integrity: sha512-4bV5BfR2mqfQTJm+V5tPPdf+ZpuhiIvTuAB5g8kcrXOZpTT/QwwVRWBywX1ozr6lEuPdbHxwaJlm9G6mI2sfSQ==} + + json-parse-even-better-errors@2.3.1: + resolution: {integrity: sha512-xyFwyhro/JEof6Ghe2iz2NcXoj2sloNsWr/XsERDK/oiPCfaNhl5ONfp+jQdAZRQQ0IJWNzH9zIZF7li91kh2w==} + + json-schema-traverse@0.4.1: + resolution: {integrity: sha512-xbbCH5dCYU5T8LcEhhuh7HJ88HXuW3qsI3Y0zOZFKfZEHcpWiHU/Jxzk629Brsab/mMiHQti9wMP+845RPe3Vg==} + + json-schema-traverse@1.0.0: + resolution: {integrity: sha512-NM8/P9n3XjXhIZn1lLhkFaACTOURQXjWhV4BA/RnOv8xvgqtqpAX9IO4mRQxSx1Rlo4tqzeqb0sOlruaOy3dug==} + + json-stable-stringify-without-jsonify@1.0.1: + resolution: {integrity: sha512-Bdboy+l7tA3OGW6FjyFHWkP5LuByj1Tk33Ljyq0axyzdk9//JSi2u3fP1QSmd1KNwq6VOKYGlAu87CisVir6Pw==} + + json5@2.2.3: 
+ resolution: {integrity: sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg==} + engines: {node: '>=6'} + hasBin: true + + jsonc-parser@2.2.1: + resolution: {integrity: sha512-o6/yDBYccGvTz1+QFevz6l6OBZ2+fMVu2JZ9CIhzsYRX4mjaK5IyX9eldUdCmga16zlgQxyrj5pt9kzuj2C02w==} + + jsonc-parser@3.3.1: + resolution: {integrity: sha512-HUgH65KyejrUFPvHFPbqOY0rsFip3Bo5wb4ngvdi1EpCYWUQDC5V+Y7mZws+DLkr4M//zQJoanu1SP+87Dv1oQ==} + + jsonfile@6.1.0: + resolution: {integrity: sha512-5dgndWOriYSm5cnYaJNhalLNDKOqFwyDB/rr1E9ZsGciGvKPs8R2xYGCacuf3z6K1YKDz182fd+fY3cn3pMqXQ==} + + jsonpath-plus@10.3.0: + resolution: {integrity: sha512-8TNmfeTCk2Le33A3vRRwtuworG/L5RrgMvdjhKZxvyShO+mBu2fP50OWUjRLNtvw344DdDarFh9buFAZs5ujeA==} + engines: {node: '>=18.0.0'} + hasBin: true + + jsonpointer@5.0.1: + resolution: {integrity: sha512-p/nXbhSEcu3pZRdkW1OfJhpsVtW1gd4Wa1fnQc9YLiTfAjn0312eMKimbdIQzuZl9aa9xUGaRlP9T/CJE/ditQ==} + engines: {node: '>=0.10.0'} + + jsonschema@1.5.0: + resolution: {integrity: sha512-K+A9hhqbn0f3pJX17Q/7H6yQfD/5OXgdrR5UE12gMXCiN9D5Xq2o5mddV2QEcX/bjla99ASsAAQUyMCCRWAEhw==} + + jsonwebtoken@9.0.2: + resolution: {integrity: sha512-PRp66vJ865SSqOlgqS8hujT5U4AOgMfhrwYIuIhfKaoSCZcirrmASQr8CX7cUg+RMih+hgznrjp99o+W4pJLHQ==} + engines: {node: '>=12', npm: '>=6'} + + jszip@3.10.1: + resolution: {integrity: sha512-xXDvecyTpGLrqFrvkrUSoxxfJI5AH7U8zxxtVclpsUtMCq4JQ290LY8AW5c7Ggnr/Y/oK+bQMbqK2qmtk3pN4g==} + + jwa@1.4.2: + resolution: {integrity: sha512-eeH5JO+21J78qMvTIDdBXidBd6nG2kZjg5Ohz/1fpa28Z4CcsWUzJ1ZZyFq/3z3N17aZy+ZuBoHljASbL1WfOw==} + + jws@3.2.2: + resolution: {integrity: sha512-YHlZCB6lMTllWDtSPHz/ZXTsi8S00usEV6v1tjq8tOUZzw7DpSDWVXjXDre6ed1w/pd495ODpHZYSdkRTsa0HA==} + + keytar@7.9.0: + resolution: {integrity: sha512-VPD8mtVtm5JNtA2AErl6Chp06JBfy7diFQ7TQQhdpWOl6MrCRB+eRbvAZUsbGQS9kiMq0coJsy0W0vHpDCkWsQ==} + + keyv@4.5.4: + resolution: {integrity: sha512-oxVHkHR/EJf2CNXnWxRLW6mg7JyCCUcG0DtEGmL2ctUo1PNTin1PUil+r/+4r5MpVgC/fn1kjsx7mjSujKqIpw==} + 
+ kleur@3.0.3: + resolution: {integrity: sha512-eTIzlVOSUR+JxdDFepEYcBMtZ9Qqdef+rnzWdRZuMbOywu5tO2w2N7rqjoANZ5k9vywhL6Br1VRjUIgTQx4E8w==} + engines: {node: '>=6'} + + kolorist@1.8.0: + resolution: {integrity: sha512-Y+60/zizpJ3HRH8DCss+q95yr6145JXZo46OTpFvDZWLfRCE4qChOyk1b26nMaNpfHHgxagk9dXT5OP0Tfe+dQ==} + + leven@3.1.0: + resolution: {integrity: sha512-qsda+H8jTaUaN/x5vzW2rzc+8Rw4TAQ/4KjB46IwK5VH+IlVeeeje/EoZRpiXvIqjFgK84QffqPztGI3VBLG1A==} + engines: {node: '>=6'} + + levn@0.4.1: + resolution: {integrity: sha512-+bT2uH4E5LGE7h/n3evcS/sQlJXCpIp6ym8OWJ5eV6+67Dsql/LaaT7qJBAt2rzfoa/5QBGBhxDix1dMt2kQKQ==} + engines: {node: '>= 0.8.0'} + + lie@3.3.0: + resolution: {integrity: sha512-UaiMJzeWRlEujzAuw5LokY1L5ecNQYZKfmyZ9L7wDHb/p5etKaxXhohBcrw0EYby+G/NA52vRSN4N39dxHAIwQ==} + + lightningcss-darwin-arm64@1.30.1: + resolution: {integrity: sha512-c8JK7hyE65X1MHMN+Viq9n11RRC7hgin3HhYKhrMyaXflk5GVplZ60IxyoVtzILeKr+xAJwg6zK6sjTBJ0FKYQ==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [darwin] + + lightningcss-darwin-x64@1.30.1: + resolution: {integrity: sha512-k1EvjakfumAQoTfcXUcHQZhSpLlkAuEkdMBsI/ivWw9hL+7FtilQc0Cy3hrx0AAQrVtQAbMI7YjCgYgvn37PzA==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [darwin] + + lightningcss-freebsd-x64@1.30.1: + resolution: {integrity: sha512-kmW6UGCGg2PcyUE59K5r0kWfKPAVy4SltVeut+umLCFoJ53RdCUWxcRDzO1eTaxf/7Q2H7LTquFHPL5R+Gjyig==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [freebsd] + + lightningcss-linux-arm-gnueabihf@1.30.1: + resolution: {integrity: sha512-MjxUShl1v8pit+6D/zSPq9S9dQ2NPFSQwGvxBCYaBYLPlCWuPh9/t1MRS8iUaR8i+a6w7aps+B4N0S1TYP/R+Q==} + engines: {node: '>= 12.0.0'} + cpu: [arm] + os: [linux] + + lightningcss-linux-arm64-gnu@1.30.1: + resolution: {integrity: sha512-gB72maP8rmrKsnKYy8XUuXi/4OctJiuQjcuqWNlJQ6jZiWqtPvqFziskH3hnajfvKB27ynbVCucKSm2rkQp4Bw==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [linux] + + lightningcss-linux-arm64-musl@1.30.1: + resolution: {integrity: 
sha512-jmUQVx4331m6LIX+0wUhBbmMX7TCfjF5FoOH6SD1CttzuYlGNVpA7QnrmLxrsub43ClTINfGSYyHe2HWeLl5CQ==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [linux] + + lightningcss-linux-x64-gnu@1.30.1: + resolution: {integrity: sha512-piWx3z4wN8J8z3+O5kO74+yr6ze/dKmPnI7vLqfSqI8bccaTGY5xiSGVIJBDd5K5BHlvVLpUB3S2YCfelyJ1bw==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [linux] + + lightningcss-linux-x64-musl@1.30.1: + resolution: {integrity: sha512-rRomAK7eIkL+tHY0YPxbc5Dra2gXlI63HL+v1Pdi1a3sC+tJTcFrHX+E86sulgAXeI7rSzDYhPSeHHjqFhqfeQ==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [linux] + + lightningcss-win32-arm64-msvc@1.30.1: + resolution: {integrity: sha512-mSL4rqPi4iXq5YVqzSsJgMVFENoa4nGTT/GjO2c0Yl9OuQfPsIfncvLrEW6RbbB24WtZ3xP/2CCmI3tNkNV4oA==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [win32] + + lightningcss-win32-x64-msvc@1.30.1: + resolution: {integrity: sha512-PVqXh48wh4T53F/1CCu8PIPCxLzWyCnn/9T5W1Jpmdy5h9Cwd+0YQS6/LwhHXSafuc61/xg9Lv5OrCby6a++jg==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [win32] + + lightningcss@1.30.1: + resolution: {integrity: sha512-xi6IyHML+c9+Q3W0S4fCQJOym42pyurFiJUHEcEyHS0CeKzia4yZDEsLlqOFykxOdHpNy0NmvVO31vcSqAxJCg==} + engines: {node: '>= 12.0.0'} + + lilconfig@3.1.3: + resolution: {integrity: sha512-/vlFKAoH5Cgt3Ie+JLhRbwOsCQePABiU3tJ1egGvyQ+33R/vcwM2Zl2QR/LzjsBeItPt3oSVXapn+m4nQDvpzw==} + engines: {node: '>=14'} + + lines-and-columns@1.2.4: + resolution: {integrity: sha512-7ylylesZQ/PV29jhEDl3Ufjo6ZX7gCqJr5F7PKrqc93v7fzSymt1BpwEU8nAUXs8qzzvqhbjhK5QZg6Mt/HkBg==} + + linkify-it@5.0.0: + resolution: {integrity: sha512-5aHCbzQRADcdP+ATqnDuhhJ/MRIqDkZX5pyjFHRRysS8vZ5AbqGEoFIb6pYHPZ+L/OC2Lc+xT8uHVVR5CAK/wQ==} + + loader-runner@4.3.0: + resolution: {integrity: sha512-3R/1M+yS3j5ou80Me59j7F9IMs4PXs3VqRrm0TU3AbKPxlmpoY1TNscJV/oGJXo8qCatFGTfDbY6W6ipGOYXfg==} + engines: {node: '>=6.11.5'} + + local-pkg@1.1.1: + resolution: {integrity: 
sha512-WunYko2W1NcdfAFpuLUoucsgULmgDBRkdxHxWQ7mK0cQqwPiy8E1enjuRBrhLtZkB5iScJ1XIPdhVEFK8aOLSg==} + engines: {node: '>=14'} + + locate-path@6.0.0: + resolution: {integrity: sha512-iPZK6eYjbxRu3uB4/WZ3EsEIMJFMqAoopl3R+zuq0UjcAm/MO6KCweDgPfP3elTztoKP3KtnVHxTn2NHBSDVUw==} + engines: {node: '>=10'} + + locate-path@7.2.0: + resolution: {integrity: sha512-gvVijfZvn7R+2qyPX8mAuKcFGDf6Nc61GdvGafQsHL0sBIxfKzA+usWn4GFC/bk+QdwPUD4kWFJLhElipq+0VA==} + engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} + + lodash.camelcase@4.3.0: + resolution: {integrity: sha512-TwuEnCnxbc3rAvhf/LbG7tJUDzhqXyFnv3dtzLOPgCG/hODL7WFnsbwktkD7yUV0RrreP/l1PALq/YSg6VvjlA==} + + lodash.castarray@4.4.0: + resolution: {integrity: sha512-aVx8ztPv7/2ULbArGJ2Y42bG1mEQ5mGjpdvrbJcJFU3TbYybe+QlLS4pst9zV52ymy2in1KpFPiZnAOATxD4+Q==} + + lodash.includes@4.3.0: + resolution: {integrity: sha512-W3Bx6mdkRTGtlJISOvVD/lbqjTlPPUDTMnlXZFnVwi9NKJ6tiAk6LVdlhZMm17VZisqhKcgzpO5Wz91PCt5b0w==} + + lodash.isboolean@3.0.3: + resolution: {integrity: sha512-Bz5mupy2SVbPHURB98VAcw+aHh4vRV5IPNhILUCsOzRmsTmSQ17jIuqopAentWoehktxGd9e/hbIXq980/1QJg==} + + lodash.isempty@4.4.0: + resolution: {integrity: sha512-oKMuF3xEeqDltrGMfDxAPGIVMSSRv8tbRSODbrs4KGsRRLEhrW8N8Rd4DRgB2+621hY8A8XwwrTVhXWpxFvMzg==} + + lodash.isinteger@4.0.4: + resolution: {integrity: sha512-DBwtEWN2caHQ9/imiNeEA5ys1JoRtRfY3d7V9wkqtbycnAmTvRRmbHKDV4a0EYc678/dia0jrte4tjYwVBaZUA==} + + lodash.isnumber@3.0.3: + resolution: {integrity: sha512-QYqzpfwO3/CWf3XP+Z+tkQsfaLL/EnUlXWVkIk5FUPc4sBdTehEqZONuyRt2P67PXAk+NXmTBcc97zw9t1FQrw==} + + lodash.isplainobject@4.0.6: + resolution: {integrity: sha512-oSXzaWypCMHkPC3NvBEaPHf0KsA5mvPrOPgQWDsbg8n7orZ290M0BmC/jgRZ4vcJ6DTAhjrsSYgdsW/F+MFOBA==} + + lodash.isstring@4.0.1: + resolution: {integrity: sha512-0wJxfxH1wgO3GrbuP+dTTk7op+6L41QCXbGINEmD+ny/G/eCqGzxyCsh7159S+mgDDcoarnBw6PC1PS5+wUGgw==} + + lodash.merge@4.6.2: + resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==} + 
+ lodash.omitby@4.6.0: + resolution: {integrity: sha512-5OrRcIVR75M288p4nbI2WLAf3ndw2GD9fyNv3Bc15+WCxJDdZ4lYndSxGd7hnG6PVjiJTeJE2dHEGhIuKGicIQ==} + + lodash.once@4.1.1: + resolution: {integrity: sha512-Sb487aTOCr9drQVL8pIxOzVhafOjZN9UU54hiN8PU3uAiSV7lx1yYNpbNmex2PK6dSJoNTSJUUswT651yww3Mg==} + + lodash.topath@4.5.2: + resolution: {integrity: sha512-1/W4dM+35DwvE/iEd1M9ekewOSTlpFekhw9mhAtrwjVqUr83/ilQiyAvmg4tVX7Unkcfl1KC+i9WdaT4B6aQcg==} + + lodash.truncate@4.4.2: + resolution: {integrity: sha512-jttmRe7bRse52OsWIMDLaXxWqRAmtIUccAQ3garviCqJjafXOfNMO0yMfNpdD6zbGaTU0P5Nz7e7gAT6cKmJRw==} + + lodash.uniq@4.5.0: + resolution: {integrity: sha512-xfBaXQd9ryd9dlSDvnvI0lvxfLJlYAZzXomUYzLKtUeOQvOP5piqAWuGtrhWeqaXK9hhoM/iyJc5AV+XfsX3HQ==} + + lodash.uniqby@4.7.0: + resolution: {integrity: sha512-e/zcLx6CSbmaEgFHCA7BnoQKyCtKMxnuWrJygbwPs/AIn+IMKl66L8/s+wBUn5LRw2pZx3bUHibiV1b6aTWIww==} + + lodash.uniqwith@4.5.0: + resolution: {integrity: sha512-7lYL8bLopMoy4CTICbxygAUq6CdRJ36vFc80DucPueUee+d5NBRxz3FdT9Pes/HEx5mPoT9jwnsEJWz1N7uq7Q==} + + lodash@4.17.21: + resolution: {integrity: sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==} + + log-symbols@4.1.0: + resolution: {integrity: sha512-8XPvpAA8uyhfteu8pIvQxpJZ7SYYdpUivZpGy6sFsBuKRY/7rQGavedeB8aK+Zkyq6upMFVL/9AW6vOYzfRyLg==} + engines: {node: '>=10'} + + log-symbols@6.0.0: + resolution: {integrity: sha512-i24m8rpwhmPIS4zscNzK6MSEhk0DUWa/8iYQWxhffV8jkI4Phvs3F+quL5xvS0gdQR0FyTCMMH33Y78dDTzzIw==} + engines: {node: '>=18'} + + loglevel-plugin-prefix@0.8.4: + resolution: {integrity: sha512-WpG9CcFAOjz/FtNht+QJeGpvVl/cdR6P0z6OcXSkr8wFJOsV2GRj2j10JLfjuA4aYkcKCNIEqRGCyTife9R8/g==} + + loglevel@1.9.2: + resolution: {integrity: sha512-HgMmCqIJSAKqo68l0rS2AanEWfkxaZ5wNiEFb5ggm08lDs9Xl2KxBlX3PTcaD2chBM1gXAYf491/M2Rv8Jwayg==} + engines: {node: '>= 0.6.0'} + + longest-streak@3.1.0: + resolution: {integrity: sha512-9Ri+o0JYgehTaVBBDoMqIl8GXtbWg711O3srftcHhZ0dqnETqLaoIK0x17fUw9rFSlK/0NlsKe0Ahhyl5pXE2g==} 
+ + loose-envify@1.4.0: + resolution: {integrity: sha512-lyuxPGr/Wfhrlem2CL/UcnUc1zcqKAImBDzukY7Y5F/yQiNdko6+fRLevlw1HgMySw7f611UIY408EtxRSoK3Q==} + hasBin: true + + loupe@3.1.4: + resolution: {integrity: sha512-wJzkKwJrheKtknCOKNEtDK4iqg/MxmZheEMtSTYvnzRdEYaZzmgH976nenp8WdJRdx5Vc1X/9MO0Oszl6ezeXg==} + + lru-cache@10.4.3: + resolution: {integrity: sha512-JNAzZcXrCt42VGLuYz0zfAzDfAvJWW6AfYlDBQyDV5DClI2m5sAmK+OIO7s59XfsRsWHp02jAJrRadPRGTt6SQ==} + + lru-cache@11.1.0: + resolution: {integrity: sha512-QIXZUBJUx+2zHUdQujWejBkcD9+cs94tLn0+YL8UrCh+D5sCXZ4c7LaEH48pNwRY3MLDgqUFyhlCyjJPf1WP0A==} + engines: {node: 20 || >=22} + + lru-cache@5.1.1: + resolution: {integrity: sha512-KpNARQA3Iwv+jTA0utUVVbrh+Jlrr1Fv0e56GGzAFOXN7dk/FviaDW8LHmK52DlcH4WP2n6gI8vN1aesBFgo9w==} + + lru-cache@6.0.0: + resolution: {integrity: sha512-Jo6dJ04CmSjuznwJSS3pUeWmd/H0ffTlkXXgwZi+eq1UCmqQwCh+eLsYOYCwY991i2Fah4h1BEMCx4qThGbsiA==} + engines: {node: '>=10'} + + lucide-react@0.542.0: + resolution: {integrity: sha512-w3hD8/SQB7+lzU2r4VdFyzzOzKnUjTZIF/MQJGSSvni7Llewni4vuViRppfRAa2guOsY5k4jZyxw/i9DQHv+dw==} + peerDependencies: + react: ^16.5.1 || ^17.0.0 || ^18.0.0 || ^19.0.0 + + lunr@2.3.9: + resolution: {integrity: sha512-zTU3DaZaF3Rt9rhN3uBMGQD3dD2/vFQqnvZCDv4dl5iOzq2IZQqTxu90r4E5J+nP70J3ilqVCrbho2eWaeW8Ow==} + + lz-string@1.5.0: + resolution: {integrity: sha512-h5bgJWpxJNswbU7qCrV0tIKQCaS3blPDrqKWx+QxzuzL1zGUzij9XCWLrSLsJPu5t+eWA/ycetzYAO5IOMcWAQ==} + hasBin: true + + magic-string@0.30.17: + resolution: {integrity: sha512-sNPKHvyjVf7gyjwS4xGTaW/mCnF8wnjtifKBEhxfZ7E/S8tQ0rssrwGNn6q8JH/ohItJfSQp9mBtQYuTlH5QnA==} + + magicast@0.3.5: + resolution: {integrity: sha512-L0WhttDl+2BOsybvEOLK7fW3UA0OQ0IQ2d6Zl2x/a6vVRs3bAY0ECOSHHeL5jD+SbOpOCUEi0y1DgHEn9Qn1AQ==} + + make-dir@4.0.0: + resolution: {integrity: sha512-hXdUTZYIVOt1Ex//jAQi+wTZZpUpwBj/0QsOzqegb3rGMMeJiSEu5xLHnYfBrRV4RH2+OCSOO95Is/7x1WJ4bw==} + engines: {node: '>=10'} + + markdown-it@14.1.0: + resolution: {integrity: 
sha512-a54IwgWPaeBCAAsv13YgmALOF1elABB08FxO9i+r4VFk5Vl4pKokRPeX8u5TCgSsPi6ec1otfLjdOpVcgbpshg==} + hasBin: true + + math-intrinsics@1.1.0: + resolution: {integrity: sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==} + engines: {node: '>= 0.4'} + + mdast-util-from-markdown@2.0.2: + resolution: {integrity: sha512-uZhTV/8NBuw0WHkPTrCqDOl0zVe1BIng5ZtHoDk49ME1qqcjYmmLmOf0gELgcRMxN4w2iuIeVso5/6QymSrgmA==} + + mdast-util-mdx-expression@2.0.1: + resolution: {integrity: sha512-J6f+9hUp+ldTZqKRSg7Vw5V6MqjATc+3E4gf3CFNcuZNWD8XdyI6zQ8GqH7f8169MM6P7hMBRDVGnn7oHB9kXQ==} + + mdast-util-mdx-jsx@3.2.0: + resolution: {integrity: sha512-lj/z8v0r6ZtsN/cGNNtemmmfoLAFZnjMbNyLzBafjzikOM+glrjNHPlf6lQDOTccj9n5b0PPihEBbhneMyGs1Q==} + + mdast-util-mdxjs-esm@2.0.1: + resolution: {integrity: sha512-EcmOpxsZ96CvlP03NghtH1EsLtr0n9Tm4lPUJUBccV9RwUOneqSycg19n5HGzCf+10LozMRSObtVr3ee1WoHtg==} + + mdast-util-phrasing@4.1.0: + resolution: {integrity: sha512-TqICwyvJJpBwvGAMZjj4J2n0X8QWp21b9l0o7eXyVJ25YNWYbJDVIyD1bZXE6WtV6RmKJVYmQAKWa0zWOABz2w==} + + mdast-util-to-hast@13.2.0: + resolution: {integrity: sha512-QGYKEuUsYT9ykKBCMOEDLsU5JRObWQusAolFMeko/tYPufNkRffBAQjIE+99jbA87xv6FgmjLtwjh9wBWajwAA==} + + mdast-util-to-markdown@2.1.2: + resolution: {integrity: sha512-xj68wMTvGXVOKonmog6LwyJKrYXZPvlwabaryTjLh9LuvovB/KAH+kvi8Gjj+7rJjsFi23nkUxRQv1KqSroMqA==} + + mdast-util-to-string@4.0.0: + resolution: {integrity: sha512-0H44vDimn51F0YwvxSJSm0eCDOJTRlmN0R1yBh4HLj9wiV1Dn0QoXGbvFAWj2hSItVTlCmBF1hqKlIyUBVFLPg==} + + mdurl@2.0.0: + resolution: {integrity: sha512-Lf+9+2r+Tdp5wXDXC4PcIBjTDtq4UKjCPMQhKIuzpJNW0b96kVqSwW0bT7FhRSfmAiFYgP+SCRvdrDozfh0U5w==} + + merge-stream@2.0.0: + resolution: {integrity: sha512-abv/qOcuPfk3URPfDzmZU1LKmuw8kT+0nIHvKrKgFrwifol/doWcdA4ZqsWQ8ENrFKkd67Mfpo/LovbIUsbt3w==} + + merge2@1.4.1: + resolution: {integrity: sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==} + engines: {node: '>= 8'} + + 
micromark-core-commonmark@2.0.3: + resolution: {integrity: sha512-RDBrHEMSxVFLg6xvnXmb1Ayr2WzLAWjeSATAoxwKYJV94TeNavgoIdA0a9ytzDSVzBy2YKFK+emCPOEibLeCrg==} + + micromark-factory-destination@2.0.1: + resolution: {integrity: sha512-Xe6rDdJlkmbFRExpTOmRj9N3MaWmbAgdpSrBQvCFqhezUn4AHqJHbaEnfbVYYiexVSs//tqOdY/DxhjdCiJnIA==} + + micromark-factory-label@2.0.1: + resolution: {integrity: sha512-VFMekyQExqIW7xIChcXn4ok29YE3rnuyveW3wZQWWqF4Nv9Wk5rgJ99KzPvHjkmPXF93FXIbBp6YdW3t71/7Vg==} + + micromark-factory-space@2.0.1: + resolution: {integrity: sha512-zRkxjtBxxLd2Sc0d+fbnEunsTj46SWXgXciZmHq0kDYGnck/ZSGj9/wULTV95uoeYiK5hRXP2mJ98Uo4cq/LQg==} + + micromark-factory-title@2.0.1: + resolution: {integrity: sha512-5bZ+3CjhAd9eChYTHsjy6TGxpOFSKgKKJPJxr293jTbfry2KDoWkhBb6TcPVB4NmzaPhMs1Frm9AZH7OD4Cjzw==} + + micromark-factory-whitespace@2.0.1: + resolution: {integrity: sha512-Ob0nuZ3PKt/n0hORHyvoD9uZhr+Za8sFoP+OnMcnWK5lngSzALgQYKMr9RJVOWLqQYuyn6ulqGWSXdwf6F80lQ==} + + micromark-util-character@2.1.1: + resolution: {integrity: sha512-wv8tdUTJ3thSFFFJKtpYKOYiGP2+v96Hvk4Tu8KpCAsTMs6yi+nVmGh1syvSCsaxz45J6Jbw+9DD6g97+NV67Q==} + + micromark-util-chunked@2.0.1: + resolution: {integrity: sha512-QUNFEOPELfmvv+4xiNg2sRYeS/P84pTW0TCgP5zc9FpXetHY0ab7SxKyAQCNCc1eK0459uoLI1y5oO5Vc1dbhA==} + + micromark-util-classify-character@2.0.1: + resolution: {integrity: sha512-K0kHzM6afW/MbeWYWLjoHQv1sgg2Q9EccHEDzSkxiP/EaagNzCm7T/WMKZ3rjMbvIpvBiZgwR3dKMygtA4mG1Q==} + + micromark-util-combine-extensions@2.0.1: + resolution: {integrity: sha512-OnAnH8Ujmy59JcyZw8JSbK9cGpdVY44NKgSM7E9Eh7DiLS2E9RNQf0dONaGDzEG9yjEl5hcqeIsj4hfRkLH/Bg==} + + micromark-util-decode-numeric-character-reference@2.0.2: + resolution: {integrity: sha512-ccUbYk6CwVdkmCQMyr64dXz42EfHGkPQlBj5p7YVGzq8I7CtjXZJrubAYezf7Rp+bjPseiROqe7G6foFd+lEuw==} + + micromark-util-decode-string@2.0.1: + resolution: {integrity: sha512-nDV/77Fj6eH1ynwscYTOsbK7rR//Uj0bZXBwJZRfaLEJ1iGBR6kIfNmlNqaqJf649EP0F3NWNdeJi03elllNUQ==} + + micromark-util-encode@2.0.1: + 
resolution: {integrity: sha512-c3cVx2y4KqUnwopcO9b/SCdo2O67LwJJ/UyqGfbigahfegL9myoEFoDYZgkT7f36T0bLrM9hZTAaAyH+PCAXjw==} + + micromark-util-html-tag-name@2.0.1: + resolution: {integrity: sha512-2cNEiYDhCWKI+Gs9T0Tiysk136SnR13hhO8yW6BGNyhOC4qYFnwF1nKfD3HFAIXA5c45RrIG1ub11GiXeYd1xA==} + + micromark-util-normalize-identifier@2.0.1: + resolution: {integrity: sha512-sxPqmo70LyARJs0w2UclACPUUEqltCkJ6PhKdMIDuJ3gSf/Q+/GIe3WKl0Ijb/GyH9lOpUkRAO2wp0GVkLvS9Q==} + + micromark-util-resolve-all@2.0.1: + resolution: {integrity: sha512-VdQyxFWFT2/FGJgwQnJYbe1jjQoNTS4RjglmSjTUlpUMa95Htx9NHeYW4rGDJzbjvCsl9eLjMQwGeElsqmzcHg==} + + micromark-util-sanitize-uri@2.0.1: + resolution: {integrity: sha512-9N9IomZ/YuGGZZmQec1MbgxtlgougxTodVwDzzEouPKo3qFWvymFHWcnDi2vzV1ff6kas9ucW+o3yzJK9YB1AQ==} + + micromark-util-subtokenize@2.1.0: + resolution: {integrity: sha512-XQLu552iSctvnEcgXw6+Sx75GflAPNED1qx7eBJ+wydBb2KCbRZe+NwvIEEMM83uml1+2WSXpBAcp9IUCgCYWA==} + + micromark-util-symbol@2.0.1: + resolution: {integrity: sha512-vs5t8Apaud9N28kgCrRUdEed4UJ+wWNvicHLPxCa9ENlYuAY31M0ETy5y1vA33YoNPDFTghEbnh6efaE8h4x0Q==} + + micromark-util-types@2.0.2: + resolution: {integrity: sha512-Yw0ECSpJoViF1qTU4DC6NwtC4aWGt1EkzaQB8KPPyCRR8z9TWeV0HbEFGTO+ZY1wB22zmxnJqhPyTpOVCpeHTA==} + + micromark@4.0.2: + resolution: {integrity: sha512-zpe98Q6kvavpCr1NPVSCMebCKfD7CA2NqZ+rykeNhONIJBpc1tFKt9hucLGwha3jNTNI8lHpctWJWoimVF4PfA==} + + micromatch@4.0.8: + resolution: {integrity: sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==} + engines: {node: '>=8.6'} + + mime-db@1.52.0: + resolution: {integrity: sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==} + engines: {node: '>= 0.6'} + + mime-types@2.1.35: + resolution: {integrity: sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==} + engines: {node: '>= 0.6'} + + mime@1.6.0: + resolution: {integrity: 
sha512-x0Vn8spI+wuJ1O6S7gnbaQg8Pxh4NNHb7KSINmEWKiPE4RKOplvijn+NkmYmmRgP68mc70j2EbeTFRsrswaQeg==} + engines: {node: '>=4'} + hasBin: true + + mimic-fn@2.1.0: + resolution: {integrity: sha512-OqbOk5oEQeAZ8WXWydlu9HJjz9WVdEIvamMCcXmuqUYjTknH/sqsWvhQ3vgwKFRR1HpjvNBKQ37nbJgYzGqGcg==} + engines: {node: '>=6'} + + mimic-function@5.0.1: + resolution: {integrity: sha512-VP79XUPxV2CigYP3jWwAUFSku2aKqBH7uTAapFWCBqutsbmDo96KY5o8uh6U+/YSIn5OxJnXp73beVkpqMIGhA==} + engines: {node: '>=18'} + + mimic-response@3.1.0: + resolution: {integrity: sha512-z0yWI+4FDrrweS8Zmt4Ej5HdJmky15+L2e6Wgn3+iK5fWzb6T3fhNFq2+MeTRb064c6Wr4N/wv0DzQTjNzHNGQ==} + engines: {node: '>=10'} + + min-indent@1.0.1: + resolution: {integrity: sha512-I9jwMn07Sy/IwOj3zVkVik2JTvgpaykDZEigL6Rx6N9LbMywwUSMtxET+7lVoDLLd3O3IXwJwvuuns8UB/HeAg==} + engines: {node: '>=4'} + + minimatch@10.0.3: + resolution: {integrity: sha512-IPZ167aShDZZUMdRk66cyQAW3qr0WzbHkPdMYa8bzZhlHhO3jALbKdxcaak7W9FfT2rZNpQuUu4Od7ILEpXSaw==} + engines: {node: 20 || >=22} + + minimatch@3.1.2: + resolution: {integrity: sha512-J7p63hRiAjw1NDEww1W7i37+ByIrOWO5XQQAzZ3VOcL0PNybwpfmV/N05zFAzwQ9USyEcX6t3UO+K5aqBQOIHw==} + + minimatch@5.1.6: + resolution: {integrity: sha512-lKwV/1brpG6mBUFHtb7NUmtABCb2WZZmm2wNiOA5hAb8VdCS4B3dtMWyvcoViccwAW/COERjXLt0zP1zXUN26g==} + engines: {node: '>=10'} + + minimatch@6.2.0: + resolution: {integrity: sha512-sauLxniAmvnhhRjFwPNnJKaPFYyddAgbYdeUpHULtCT/GhzdCx/MDNy+Y40lBxTQUrMzDE8e0S43Z5uqfO0REg==} + engines: {node: '>=10'} + + minimatch@9.0.5: + resolution: {integrity: sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==} + engines: {node: '>=16 || 14 >=14.17'} + + minimist@1.2.8: + resolution: {integrity: sha512-2yyAR8qBkN3YuheJanUpWC5U3bb5osDywNB8RzDVlDwDHbocAJveqqj1u8+SVD7jkWT4yvsHCpWqqWqAxb0zCA==} + + minipass@7.1.2: + resolution: {integrity: sha512-qOOzS1cBTWYF4BH8fVePDBOO9iptMnGUEZwNc/cMWnTV2nVLZ7VoNWEPHkYczZA0pdoA7dl6e7FL659nX9S2aw==} + engines: {node: '>=16 || 14 >=14.17'} + + 
minizlib@3.0.2: + resolution: {integrity: sha512-oG62iEk+CYt5Xj2YqI5Xi9xWUeZhDI8jjQmC5oThVH5JGCTgIjr7ciJDzC7MBzYd//WvR1OTmP5Q38Q8ShQtVA==} + engines: {node: '>= 18'} + + mkdirp-classic@0.5.3: + resolution: {integrity: sha512-gKLcREMhtuZRwRAfqP3RFW+TK4JqApVBtOIftVgjuABpAtpxhPGaDcfvbhNvD0B8iD1oUr/txX35NjcaY6Ns/A==} + + mkdirp@3.0.1: + resolution: {integrity: sha512-+NsyUUAZDmo6YVHzL/stxSu3t9YS1iljliy3BSDrXJ/dkn1KYdmtZODGGjLcc9XLgVVpH4KshHB8XmZgMhaBXg==} + engines: {node: '>=10'} + hasBin: true + + mlly@1.7.4: + resolution: {integrity: sha512-qmdSIPC4bDJXgZTCR7XosJiNKySV7O215tsPtDN9iEO/7q/76b/ijtgRu/+epFXSJhijtTCCGp3DWS549P3xKw==} + + mocha@10.8.2: + resolution: {integrity: sha512-VZlYo/WE8t1tstuRmqgeyBgCbJc/lEdopaa+axcKzTBJ+UIdlAB9XnmvTCAH4pwR4ElNInaedhEBmZD8iCSVEg==} + engines: {node: '>= 14.0.0'} + hasBin: true + + mrmime@2.0.1: + resolution: {integrity: sha512-Y3wQdFg2Va6etvQ5I82yUhGdsKrcYox6p7FfL1LbK2J4V01F9TGlepTIhnK24t7koZibmg82KGglhA1XK5IsLQ==} + engines: {node: '>=10'} + + ms@2.1.3: + resolution: {integrity: sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==} + + muggle-string@0.4.1: + resolution: {integrity: sha512-VNTrAak/KhO2i8dqqnqnAHOa3cYBwXEZe9h+D5h/1ZqFSTEFHdM65lR7RoIqq3tBBYavsOXV84NoHXZ0AkPyqQ==} + + mute-stream@0.0.8: + resolution: {integrity: sha512-nnbWWOkoWyUsTjKrhgD0dcz22mdkSnpYqbEjIm2nhwhuxlSkpywJmBo8h0ZqJdkp73mb90SssHkN4rsRaBAfAA==} + + mz@2.7.0: + resolution: {integrity: sha512-z81GNO7nnYMEhrGh9LeymoE4+Yr0Wn5McHIZMK5cfQCl+NDX08sCZgUc9/6MHni9IWuFLm1Z3HTCXu2z9fN62Q==} + + nanoid@3.3.11: + resolution: {integrity: sha512-N8SpfPUnUp1bK+PMYW8qSWdl9U+wwNWI4QKxOYDy9JAro3WMX7p2OeVRF9v+347pnakNevPmiHhNmZ2HbFA76w==} + engines: {node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1} + hasBin: true + + napi-build-utils@2.0.0: + resolution: {integrity: sha512-GEbrYkbfF7MoNaoh2iGG84Mnf/WZfB0GdGEsM8wz7Expx/LlWf5U8t9nvJKXSp3qr5IsEbK04cBGhol/KwOsWA==} + + natural-compare@1.4.0: + resolution: {integrity: 
sha512-OWND8ei3VtNC9h7V60qff3SVobHr996CTwgxubgyQYEpg290h9J0buyECNNJexkFm5sOajh5G116RYA1c8ZMSw==} + + neo-async@2.6.2: + resolution: {integrity: sha512-Yd3UES5mWCSqR+qNT93S3UoYUkqAZ9lLg8a7g9rimsWmYGK8cVToA4/sF3RrshdyV3sAGMXVUmpMYOw+dLpOuw==} + + nimma@0.2.3: + resolution: {integrity: sha512-1ZOI8J+1PKKGceo/5CT5GfQOG6H8I2BencSK06YarZ2wXwH37BSSUWldqJmMJYA5JfqDqffxDXynt6f11AyKcA==} + engines: {node: ^12.20 || >=14.13} + + node-abi@3.75.0: + resolution: {integrity: sha512-OhYaY5sDsIka7H7AtijtI9jwGYLyl29eQn/W623DiN/MIv5sUqc4g7BIDThX+gb7di9f6xK02nkp8sdfFWZLTg==} + engines: {node: '>=10'} + + node-addon-api@4.3.0: + resolution: {integrity: sha512-73sE9+3UaLYYFmDsFZnqCInzPyh3MqIwZO9cw58yIqAZhONrrabrYyYe3TuIqtIiOuTXVhsGau8hcrhhwSsDIQ==} + + node-fetch-h2@2.3.0: + resolution: {integrity: sha512-ofRW94Ab0T4AOh5Fk8t0h8OBWrmjb0SSB20xh1H8YnPV9EJ+f5AMoYSUQ2zgJ4Iq2HAK0I2l5/Nequ8YzFS3Hg==} + engines: {node: 4.x || >=6.0.0} + + node-fetch@2.7.0: + resolution: {integrity: sha512-c4FRfUm/dbcWZ7U+1Wq0AwCyFL+3nt2bEw05wfxSz+DWpWsitgmSgYmy2dQdWyKC1694ELPqMs/YzUSNozLt8A==} + engines: {node: 4.x || >=6.0.0} + peerDependencies: + encoding: ^0.1.0 + peerDependenciesMeta: + encoding: + optional: true + + node-readfiles@0.2.0: + resolution: {integrity: sha512-SU00ZarexNlE4Rjdm83vglt5Y9yiQ+XI1XpflWlb7q7UTN1JUItm69xMeiQCTxtTfnzt+83T8Cx+vI2ED++VDA==} + + node-releases@2.0.21: + resolution: {integrity: sha512-5b0pgg78U3hwXkCM8Z9b2FJdPZlr9Psr9V2gQPESdGHqbntyFJKFW4r5TeWGFzafGY3hzs1JC62VEQMbl1JFkw==} + + node-sarif-builder@3.2.0: + resolution: {integrity: sha512-kVIOdynrF2CRodHZeP/97Rh1syTUHBNiw17hUCIVhlhEsWlfJm19MuO56s4MdKbr22xWx6mzMnNAgXzVlIYM9Q==} + engines: {node: '>=18'} + + normalize-package-data@6.0.2: + resolution: {integrity: sha512-V6gygoYb/5EmNI+MEGrWkC+e6+Rr7mTmfHrxDbLzxQogBkgzo76rkok0Am6thgSF7Mv2nLOajAJj5vDJZEFn7g==} + engines: {node: ^16.14.0 || >=18.0.0} + + normalize-path@3.0.0: + resolution: {integrity: 
sha512-6eZs5Ls3WtCisHWp9S2GUy8dqkpGi4BVSz3GaqiE6ezub0512ESztXUwUB6C6IKbQkY2Pnb/mD4WYojCRwcwLA==} + engines: {node: '>=0.10.0'} + + normalize-range@0.1.2: + resolution: {integrity: sha512-bdok/XvKII3nUpklnV6P2hxtMNrCboOjAcyBuQnWEhO665FwrSNRxU+AqpsyvO6LgGYPspN+lu5CLtw4jPRKNA==} + engines: {node: '>=0.10.0'} + + npm-package-arg@12.0.2: + resolution: {integrity: sha512-f1NpFjNI9O4VbKMOlA5QoBq/vSQPORHcTZ2feJpFkTHJ9eQkdlmZEKSjcAhxTGInC7RlEyScT9ui67NaOsjFWA==} + engines: {node: ^18.17.0 || >=20.5.0} + + npm-run-path@4.0.1: + resolution: {integrity: sha512-S48WzZW777zhNIrn7gxOlISNAqi9ZC/uQFnRdbeIHhZhCA6UqpkOT8T1G7BvfdgP4Er8gF4sUbaS0i7QvIfCWw==} + engines: {node: '>=8'} + + nth-check@2.1.1: + resolution: {integrity: sha512-lqjrjmaOoAnWfMmBPL+XNnynZh2+swxiX3WUE0s4yEHI6m+AwrK2UZOimIRl3X/4QctVqS8AiZjFqyOGrMXb/w==} + + nwsapi@2.2.20: + resolution: {integrity: sha512-/ieB+mDe4MrrKMT8z+mQL8klXydZWGR5Dowt4RAGKbJ3kIGEx3X4ljUo+6V73IXtUPWgfOlU5B9MlGxFO5T+cA==} + + oas-kit-common@1.0.8: + resolution: {integrity: sha512-pJTS2+T0oGIwgjGpw7sIRU8RQMcUoKCDWFLdBqKB2BNmGpbBMH2sdqAaOXUg8OzonZHU0L7vfJu1mJFEiYDWOQ==} + + oas-linter@3.2.2: + resolution: {integrity: sha512-KEGjPDVoU5K6swgo9hJVA/qYGlwfbFx+Kg2QB/kd7rzV5N8N5Mg6PlsoCMohVnQmo+pzJap/F610qTodKzecGQ==} + + oas-resolver@2.5.6: + resolution: {integrity: sha512-Yx5PWQNZomfEhPPOphFbZKi9W93CocQj18NlD2Pa4GWZzdZpSJvYwoiuurRI7m3SpcChrnO08hkuQDL3FGsVFQ==} + hasBin: true + + oas-schema-walker@1.1.5: + resolution: {integrity: sha512-2yucenq1a9YPmeNExoUa9Qwrt9RFkjqaMAA1X+U7sbb0AqBeTIdMHky9SQQ6iN94bO5NW0W4TRYXerG+BdAvAQ==} + + oas-validator@5.0.8: + resolution: {integrity: sha512-cu20/HE5N5HKqVygs3dt94eYJfBi0TsZvPVXDhbXQHiEityDN+RROTleefoKRKKJ9dFAF2JBkDHgvWj0sjKGmw==} + + object-assign@4.1.1: + resolution: {integrity: sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==} + engines: {node: '>=0.10.0'} + + object-hash@3.0.0: + resolution: {integrity: 
sha512-RSn9F68PjH9HqtltsSnqYC1XXoWe9Bju5+213R98cNGttag9q9yAOTzdbsqvIa7aNm5WffBZFpWYr2aWrklWAw==} + engines: {node: '>= 6'} + + object-inspect@1.13.4: + resolution: {integrity: sha512-W67iLl4J2EXEGTbfeHCffrjDfitvLANg0UlX3wFUUSTx92KXRFegMHUVgSqE+wvhAbi4WqjGg9czysTV2Epbew==} + engines: {node: '>= 0.4'} + + object-keys@1.1.1: + resolution: {integrity: sha512-NuAESUOUMrlIXOfHKzD6bpPu3tYt3xvjNdRIQ+FeT0lNb4K8WR70CaDxhuNguS2XG+GjkyMwOzsN5ZktImfhLA==} + engines: {node: '>= 0.4'} + + object.assign@4.1.7: + resolution: {integrity: sha512-nK28WOo+QIjBkDduTINE4JkF/UJJKyf2EJxvJKfblDpyg0Q+pkOHNTL0Qwy6NP6FhE/EnzV73BxxqcJaXY9anw==} + engines: {node: '>= 0.4'} + + once@1.4.0: + resolution: {integrity: sha512-lNaJgI+2Q5URQBkccEKHTQOPaXdUxnZZElQTZY0MFUAuaEqe1E+Nyvgdz/aIyNi6Z9MzO5dv1H8n58/GELp3+w==} + + onetime@5.1.2: + resolution: {integrity: sha512-kbpaSSGJTWdAY5KPVeMOKXSrPtr8C8C7wodJbcsd51jRnmD+GZu8Y0VoU6Dm5Z4vWr0Ig/1NKuWRKf7j5aaYSg==} + engines: {node: '>=6'} + + onetime@7.0.0: + resolution: {integrity: sha512-VXJjc87FScF88uafS3JllDgvAm+c/Slfz06lorj2uAY34rlUu0Nt+v8wreiImcrgAjjIHp1rXpTDlLOGw29WwQ==} + engines: {node: '>=18'} + + open@10.2.0: + resolution: {integrity: sha512-YgBpdJHPyQ2UE5x+hlSXcnejzAvD0b22U2OuAP+8OnlJT+PjWPxtgmGqKKc+RgTM63U9gN0YzrYc71R2WT/hTA==} + engines: {node: '>=18'} + + open@8.4.2: + resolution: {integrity: sha512-7x81NCL719oNbsq/3mh+hVrAWmFuEYUqrq/Iw3kUzH8ReypT9QQ0BLoJS7/G9k6N81XjW4qHWtjWwe/9eLy1EQ==} + engines: {node: '>=12'} + + openapi-types@12.1.3: + resolution: {integrity: sha512-N4YtSYJqghVu4iek2ZUvcN/0aqH1kRDuNqzcycDxhOUpg7GdvLa2F3DgS6yBNhInhv2r/6I0Flkn7CqL8+nIcw==} + + openapi3-ts@4.2.2: + resolution: {integrity: sha512-+9g4actZKeb3czfi9gVQ4Br2Ju3KwhCAQJBNaKgye5KggqcBLIhFHH+nIkcm0BUX00TrAJl6dH4JWgM4G4JWrw==} + + openapi3-ts@4.4.0: + resolution: {integrity: sha512-9asTNB9IkKEzWMcHmVZE7Ts3kC9G7AFHfs8i7caD8HbI76gEjdkId4z/AkP83xdZsH7PLAnnbl47qZkXuxpArw==} + + optionator@0.9.4: + resolution: {integrity: 
sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g==} + engines: {node: '>= 0.8.0'} + + ora@8.2.0: + resolution: {integrity: sha512-weP+BZ8MVNnlCm8c0Qdc1WSWq4Qn7I+9CJGm7Qali6g44e/PUzbjNqJX5NJ9ljlNMosfJvg1fKEGILklK9cwnw==} + engines: {node: '>=18'} + + orval@7.10.0: + resolution: {integrity: sha512-R1TlDDgK82dHfTXG0IuaIXHOrk6HQ1CuGejQQpQW9mBSCQA84AInp8U4Ovxw3upjMFNhghE8OlAQqD0ES8GgHQ==} + hasBin: true + + own-keys@1.0.1: + resolution: {integrity: sha512-qFOyK5PjiWZd+QQIh+1jhdb9LpxTF0qs7Pm8o5QHYZ0M3vKqSqzsZaEB6oWlxZ+q2sJBMI/Ktgd2N5ZwQoRHfg==} + engines: {node: '>= 0.4'} + + p-limit@3.1.0: + resolution: {integrity: sha512-TYOanM3wGwNGsZN2cVTYPArw454xnXj5qmWF1bEoAc4+cU/ol7GVh7odevjp1FNHduHc3KZMcFduxU5Xc6uJRQ==} + engines: {node: '>=10'} + + p-limit@4.0.0: + resolution: {integrity: sha512-5b0R4txpzjPWVw/cXXUResoD4hb6U/x9BH08L7nw+GN1sezDzPdxeRvpc9c433fZhBan/wusjbCsqwqm4EIBIQ==} + engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} + + p-locate@5.0.0: + resolution: {integrity: sha512-LaNjtRWUBY++zB5nE/NwcaoMylSPk+S+ZHNB1TzdbMJMny6dynpAGt7X/tl/QYq3TIeE6nxHppbo2LGymrG5Pw==} + engines: {node: '>=10'} + + p-locate@6.0.0: + resolution: {integrity: sha512-wPrq66Llhl7/4AGC6I+cqxT07LhXvWL08LNXz1fENOw0Ap4sRZZ/gZpTTJ5jpurzzzfS2W/Ge9BY3LgLjCShcw==} + engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} + + p-map@7.0.3: + resolution: {integrity: sha512-VkndIv2fIB99swvQoA65bm+fsmt6UNdGeIB0oxBs+WhAhdh08QA04JXpI7rbB9r08/nkbysKoya9rtDERYOYMA==} + engines: {node: '>=18'} + + package-json-from-dist@1.0.1: + resolution: {integrity: sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw==} + + pako@1.0.11: + resolution: {integrity: sha512-4hLB8Py4zZce5s4yd9XzopqwVv/yGNhV1Bl8NTmCq1763HeK2+EwVTv+leGeL13Dnh2wfbqowVPXCIO0z4taYw==} + + parent-module@1.0.1: + resolution: {integrity: sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==} + engines: {node: '>=6'} + + 
parse-entities@4.0.2: + resolution: {integrity: sha512-GG2AQYWoLgL877gQIKeRPGO1xF9+eG1ujIb5soS5gPvLQ1y2o8FL90w2QWNdf9I361Mpp7726c+lj3U0qK1uGw==} + + parse-json@5.2.0: + resolution: {integrity: sha512-ayCKvm/phCGxOkYRSCM82iDwct8/EonSEgCSxWxD7ve6jHggsFl4fZVQBPRNgQoKiuV/odhFrGzQXZwbifC8Rg==} + engines: {node: '>=8'} + + parse-json@8.3.0: + resolution: {integrity: sha512-ybiGyvspI+fAoRQbIPRddCcSTV9/LsJbf0e/S85VLowVGzRmokfneg2kwVW/KU5rOXrPSbF1qAKPMgNTqqROQQ==} + engines: {node: '>=18'} + + parse-semver@1.1.1: + resolution: {integrity: sha512-Eg1OuNntBMH0ojvEKSrvDSnwLmvVuUOSdylH/pSCPNMIspLlweJyIWXCE+k/5hm3cj/EBUYwmWkjhBALNP4LXQ==} + + parse5-htmlparser2-tree-adapter@7.1.0: + resolution: {integrity: sha512-ruw5xyKs6lrpo9x9rCZqZZnIUntICjQAd0Wsmp396Ul9lN/h+ifgVV1x1gZHi8euej6wTfpqX8j+BFQxF0NS/g==} + + parse5-parser-stream@7.1.2: + resolution: {integrity: sha512-JyeQc9iwFLn5TbvvqACIF/VXG6abODeB3Fwmv/TGdLk2LfbWkaySGY72at4+Ty7EkPZj854u4CrICqNk2qIbow==} + + parse5@7.3.0: + resolution: {integrity: sha512-IInvU7fabl34qmi9gY8XOVxhYyMyuH2xUNpb2q8/Y+7552KlejkRvqvD19nMoUW/uQGGbqNpA6Tufu5FL5BZgw==} + + path-browserify@1.0.1: + resolution: {integrity: sha512-b7uo2UCUOYZcnF/3ID0lulOJi/bafxa1xPe7ZPsammBSpjSWQkjNxlt635YGS2MiR9GjvuXCtz2emr3jbsz98g==} + + path-exists@4.0.0: + resolution: {integrity: sha512-ak9Qy5Q7jYb2Wwcey5Fpvg2KoAc/ZIhLSLOSBmRmygPsGwkVVt0fZa0qrtMz+m6tJTAHfZQ8FnmB4MG4LWy7/w==} + engines: {node: '>=8'} + + path-exists@5.0.0: + resolution: {integrity: sha512-RjhtfwJOxzcFmNOi6ltcbcu4Iu+FL3zEj83dk4kAS+fVpTxXLO1b38RvJgT/0QwvV/L3aY9TAnyv0EOqW4GoMQ==} + engines: {node: ^12.20.0 || ^14.13.1 || >=16.0.0} + + path-is-absolute@1.0.1: + resolution: {integrity: sha512-AVbw3UJ2e9bq64vSaS9Am0fje1Pa8pbGqTTsmXfaIiMpnr5DlDhfJOuLj9Sf95ZPVDAUerDfEk88MPmPe7UCQg==} + engines: {node: '>=0.10.0'} + + path-key@3.1.1: + resolution: {integrity: sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==} + engines: {node: '>=8'} + + path-parse@1.0.7: + resolution: 
{integrity: sha512-LDJzPVEEEPR+y48z93A0Ed0yXb8pAByGWo/k5YYdYgpY2/2EsOsksJrq7lOHxryrVOn1ejG6oAp8ahvOIQD8sw==} + + path-scurry@1.11.1: + resolution: {integrity: sha512-Xa4Nw17FS9ApQFJ9umLiJS4orGjm7ZzwUrwamcGQuHSzDyth9boKDaycYdDcZDuqYATXw4HFXgaqWTctW/v1HA==} + engines: {node: '>=16 || 14 >=14.18'} + + path-scurry@2.0.0: + resolution: {integrity: sha512-ypGJsmGtdXUOeM5u93TyeIEfEhM6s+ljAhrk5vAvSx8uyY/02OvrZnA0YNGUrPXfpJMgI1ODd3nwz8Npx4O4cg==} + engines: {node: 20 || >=22} + + path-type@4.0.0: + resolution: {integrity: sha512-gDKb8aZMDeD/tZWs9P6+q0J9Mwkdl6xMV8TjnGP3qJVJ06bdMgkbBlLU8IdfOsIsFz2BW1rNVT3XuNEl8zPAvw==} + engines: {node: '>=8'} + + path-type@6.0.0: + resolution: {integrity: sha512-Vj7sf++t5pBD637NSfkxpHSMfWaeig5+DKWLhcqIYx6mWQz5hdJTGDVMQiJcw1ZYkhs7AazKDGpRVji1LJCZUQ==} + engines: {node: '>=18'} + + pathe@2.0.3: + resolution: {integrity: sha512-WUjGcAqP1gQacoQe+OBJsFA7Ld4DyXuUIjZ5cc75cLHvJ7dtNsTugphxIADwspS+AraAUePCKrSVtPLFj/F88w==} + + pathval@2.0.1: + resolution: {integrity: sha512-//nshmD55c46FuFw26xV/xFAaB5HF9Xdap7HJBBnrKdAd6/GxDBaNA1870O79+9ueg61cZLSVc+OaFlfmObYVQ==} + engines: {node: '>= 14.16'} + + pend@1.2.0: + resolution: {integrity: sha512-F3asv42UuXchdzt+xXqfW1OGlVBe+mxa2mqI0pg5yAHZPvFmY3Y6drSf/GQ1A86WgWEN9Kzh/WrgKa6iGcHXLg==} + + picocolors@1.1.1: + resolution: {integrity: sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==} + + picomatch@2.3.1: + resolution: {integrity: sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==} + engines: {node: '>=8.6'} + + picomatch@4.0.3: + resolution: {integrity: sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==} + engines: {node: '>=12'} + + pify@2.3.0: + resolution: {integrity: sha512-udgsAY+fTnvv7kI7aaxbqwWNb0AHiB0qBO89PZKPkoTmGOgdbrHDKD+0B2X4uTfJ/FT1R09r9gTsjUjNJotuog==} + engines: {node: '>=0.10.0'} + + pirates@4.0.7: + resolution: {integrity: 
sha512-TfySrs/5nm8fQJDcBDuUng3VOUKsd7S+zqvbOTiGXHfxX4wK31ard+hoNuvkicM/2YFzlpDgABOevKSsB4G/FA==} + engines: {node: '>= 6'} + + pkg-types@1.3.1: + resolution: {integrity: sha512-/Jm5M4RvtBFVkKWRu2BLUTNP8/M2a+UwuAX+ae4770q1qVGtfjG+WTCupoZixokjmHiry8uI+dlY8KXYV5HVVQ==} + + pkg-types@2.2.0: + resolution: {integrity: sha512-2SM/GZGAEkPp3KWORxQZns4M+WSeXbC2HEvmOIJe3Cmiv6ieAJvdVhDldtHqM5J1Y7MrR1XhkBT/rMlhh9FdqQ==} + + playwright-core@1.54.1: + resolution: {integrity: sha512-Nbjs2zjj0htNhzgiy5wu+3w09YetDx5pkrpI/kZotDlDUaYk0HVA5xrBVPdow4SAUIlhgKcJeJg4GRKW6xHusA==} + engines: {node: '>=18'} + hasBin: true + + playwright@1.54.1: + resolution: {integrity: sha512-peWpSwIBmSLi6aW2auvrUtf2DqY16YYcCMO8rTVx486jKmDTJg7UAhyrraP98GB8BoPURZP8+nxO7TSd4cPr5g==} + engines: {node: '>=18'} + hasBin: true + + pluralize@2.0.0: + resolution: {integrity: sha512-TqNZzQCD4S42De9IfnnBvILN7HAW7riLqsCyp8lgjXeysyPlX5HhqKAcJHHHb9XskE4/a+7VGC9zzx8Ls0jOAw==} + + pluralize@8.0.0: + resolution: {integrity: sha512-Nc3IT5yHzflTfbjgqWcCPpo7DaKy4FnpB0l/zCAW0Tc7jxAiuqSxHasntB3D7887LSrA93kDJ9IXovxJYxyLCA==} + engines: {node: '>=4'} + + pony-cause@1.1.1: + resolution: {integrity: sha512-PxkIc/2ZpLiEzQXu5YRDOUgBlfGYBY8156HY5ZcRAwwonMk5W/MrJP2LLkG/hF7GEQzaHo2aS7ho6ZLCOvf+6g==} + engines: {node: '>=12.0.0'} + + possible-typed-array-names@1.1.0: + resolution: {integrity: sha512-/+5VFTchJDoVj3bhoqi6UeymcD00DAwb1nJwamzPvHEszJ4FpF6SNNbUbOS8yI56qHzdV8eK0qEfOSiodkTdxg==} + engines: {node: '>= 0.4'} + + postcss-import@15.1.0: + resolution: {integrity: sha512-hpr+J05B2FVYUAXHeK1YyI267J/dDDhMU6B6civm8hSY1jYJnBXxzKDKDswzJmtLHryrjhnDjqqp/49t8FALew==} + engines: {node: '>=14.0.0'} + peerDependencies: + postcss: ^8.0.0 + + postcss-js@4.0.1: + resolution: {integrity: sha512-dDLF8pEO191hJMtlHFPRa8xsizHaM82MLfNkUHdUtVEV3tgTp5oj+8qbEqYM57SLfc74KSbw//4SeJma2LRVIw==} + engines: {node: ^12 || ^14 || >= 16} + peerDependencies: + postcss: ^8.4.21 + + postcss-load-config@4.0.2: + resolution: {integrity: 
sha512-bSVhyJGL00wMVoPUzAVAnbEoWyqRxkjv64tUl427SKnPrENtq6hJwUojroMz2VB+Q1edmi4IfrAPpami5VVgMQ==} + engines: {node: '>= 14'} + peerDependencies: + postcss: '>=8.0.9' + ts-node: '>=9.0.0' + peerDependenciesMeta: + postcss: + optional: true + ts-node: + optional: true + + postcss-nested@6.2.0: + resolution: {integrity: sha512-HQbt28KulC5AJzG+cZtj9kvKB93CFCdLvog1WFLf1D+xmMvPGlBstkpTEZfK5+AN9hfJocyBFCNiqyS48bpgzQ==} + engines: {node: '>=12.0'} + peerDependencies: + postcss: ^8.2.14 + + postcss-selector-parser@6.0.10: + resolution: {integrity: sha512-IQ7TZdoaqbT+LCpShg46jnZVlhWD2w6iQYAcYXfHARZ7X1t/UGhhceQDs5X0cGqKvYlHNOuv7Oa1xmb0oQuA3w==} + engines: {node: '>=4'} + + postcss-selector-parser@6.1.2: + resolution: {integrity: sha512-Q8qQfPiZ+THO/3ZrOrO0cJJKfpYCagtMUkXbnEfmgUjwXg6z/WBeOyS9APBBPCTSiDV+s4SwQGu8yFsiMRIudg==} + engines: {node: '>=4'} + + postcss-value-parser@4.2.0: + resolution: {integrity: sha512-1NNCs6uurfkVbeXG4S8JFT9t19m45ICnif8zWLd5oPSZ50QnwMfK+H3jv408d4jw/7Bttv5axS5IiHoLaVNHeQ==} + + postcss@8.5.6: + resolution: {integrity: sha512-3Ybi1tAuwAP9s0r1UQ2J4n5Y0G05bJkpUIO0/bI9MhwmD70S5aTWbXGBwxHrelT+XM1k6dM0pk+SwNkpTRN7Pg==} + engines: {node: ^10 || ^12 || >=14} + + prebuild-install@7.1.3: + resolution: {integrity: sha512-8Mf2cbV7x1cXPUILADGI3wuhfqWvtiLA1iclTDbFRZkgRQS0NqsPZphna9V+HyTEadheuPmjaJMsbzKQFOzLug==} + engines: {node: '>=10'} + hasBin: true + + prelude-ls@1.2.1: + resolution: {integrity: sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==} + engines: {node: '>= 0.8.0'} + + prettier@3.6.2: + resolution: {integrity: sha512-I7AIg5boAr5R0FFtJ6rCfD+LFsWHp81dolrFD8S79U9tb8Az2nGrJncnMSnys+bpQJfRUzqs9hnA81OAA3hCuQ==} + engines: {node: '>=14'} + hasBin: true + + pretty-format@27.5.1: + resolution: {integrity: sha512-Qb1gy5OrP5+zDf2Bvnzdl3jsTf1qXVMazbvCoKhtKqVs4/YK4ozX4gKQJJVyNe+cajNPn0KoC0MC3FUmaHWEmQ==} + engines: {node: ^10.13.0 || ^12.13.0 || ^14.15.0 || >=15.0.0} + + proc-log@5.0.0: + resolution: {integrity: 
sha512-Azwzvl90HaF0aCz1JrDdXQykFakSSNPaPoiZ9fm5qJIMHioDZEi7OAdRwSm6rSoPtY3Qutnm3L7ogmg3dc+wbQ==} + engines: {node: ^18.17.0 || >=20.5.0} + + process-nextick-args@2.0.1: + resolution: {integrity: sha512-3ouUOpQhtgrbOa17J7+uxOTpITYWaGP7/AhoR3+A+/1e9skrzelGi/dXzEYyvbxubEF6Wn2ypscTKiKJFFn1ag==} + + prompts@2.4.2: + resolution: {integrity: sha512-NxNv/kLguCA7p3jE8oL2aEBsrJWgAakBpgmgK6lpPWV+WuOmY6r2/zbAVnP+T8bQlA0nzHXSJSJW0Hq7ylaD2Q==} + engines: {node: '>= 6'} + + prop-types@15.8.1: + resolution: {integrity: sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==} + + property-information@7.1.0: + resolution: {integrity: sha512-TwEZ+X+yCJmYfL7TPUOcvBZ4QfoT5YenQiJuX//0th53DE6w0xxLEtfK3iyryQFddXuvkIk51EEgrJQ0WJkOmQ==} + + pump@3.0.3: + resolution: {integrity: sha512-todwxLMY7/heScKmntwQG8CXVkWUOdYxIvY2s0VWAAMh/nd8SoYiRaKjlr7+iCs984f2P8zvrfWcDDYVb73NfA==} + + punycode.js@2.3.1: + resolution: {integrity: sha512-uxFIHU0YlHYhDQtV4R9J6a52SLx28BCjT+4ieh7IGbgwVJWO+km431c4yRlREUAsAmt/uMjQUyQHNEPf0M39CA==} + engines: {node: '>=6'} + + punycode@2.3.1: + resolution: {integrity: sha512-vYt7UD1U9Wg6138shLtLOvdAu+8DsC/ilFtEVHcH+wydcSpNE20AfSOduf6MkRFahL5FY7X1oU7nKVZFtfq8Fg==} + engines: {node: '>=6'} + + pure-rand@6.1.0: + resolution: {integrity: sha512-bVWawvoZoBYpp6yIoQtQXHZjmz35RSVHnUOTefl8Vcjr8snTPY1wnpSPMWekcFwbxI6gtmT7rSYPFvz71ldiOA==} + + qs@6.14.0: + resolution: {integrity: sha512-YWWTjgABSKcvs/nWBi9PycY/JiPJqOD4JA6o9Sej2AtvSGarXxKC3OQSk4pAarbdQlKAh5D4FCQkJNkW+GAn3w==} + engines: {node: '>=0.6'} + + quansync@0.2.10: + resolution: {integrity: sha512-t41VRkMYbkHyCYmOvx/6URnN80H7k4X0lLdBMGsz+maAwrJQYB1djpV6vHrQIBE0WBSGqhtEHrK9U3DWWH8v7A==} + + queue-microtask@1.2.3: + resolution: {integrity: sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==} + + randombytes@2.1.0: + resolution: {integrity: sha512-vYl3iOX+4CKUWuxGi9Ukhie6fsqXqS9FE2Zaic4tNFD2N2QQaXOMFbuKK4QmDHC0JO6B1Zp41J0LpT0oR68amQ==} + + 
rc-config-loader@4.1.3: + resolution: {integrity: sha512-kD7FqML7l800i6pS6pvLyIE2ncbk9Du8Q0gp/4hMPhJU6ZxApkoLcGD8ZeqgiAlfwZ6BlETq6qqe+12DUL207w==} + + rc@1.2.8: + resolution: {integrity: sha512-y3bGgqKj3QBdxLbLkomlohkvsA8gdAiUQlSBJnBhfn+BPxg4bc62d8TcBW15wavDfgexCgccckhcZvywyQYPOw==} + hasBin: true + + react-dnd-html5-backend@16.0.1: + resolution: {integrity: sha512-Wu3dw5aDJmOGw8WjH1I1/yTH+vlXEL4vmjk5p+MHxP8HuHJS1lAGeIdG/hze1AvNeXWo/JgULV87LyQOr+r5jw==} + + react-dnd@16.0.1: + resolution: {integrity: sha512-QeoM/i73HHu2XF9aKksIUuamHPDvRglEwdHL4jsp784BgUuWcg6mzfxT0QDdQz8Wj0qyRKx2eMg8iZtWvU4E2Q==} + peerDependencies: + '@types/hoist-non-react-statics': '>= 3.3.1' + '@types/node': '>= 12' + '@types/react': '>= 16' + react: '>= 16.14' + peerDependenciesMeta: + '@types/hoist-non-react-statics': + optional: true + '@types/node': + optional: true + '@types/react': + optional: true + + react-docgen-typescript@2.4.0: + resolution: {integrity: sha512-ZtAp5XTO5HRzQctjPU0ybY0RRCQO19X/8fxn3w7y2VVTUbGHDKULPTL4ky3vB05euSgG5NpALhEhDPvQ56wvXg==} + peerDependencies: + typescript: '>= 4.3.x' + + react-docgen@8.0.0: + resolution: {integrity: sha512-kmob/FOTwep7DUWf9KjuenKX0vyvChr3oTdvvPt09V60Iz75FJp+T/0ZeHMbAfJj2WaVWqAPP5Hmm3PYzSPPKg==} + engines: {node: ^20.9.0 || >=22} + + react-dom@18.3.1: + resolution: {integrity: sha512-5m4nQKp+rZRb09LNH59GM4BxTh9251/ylbKIbpe7TpGxfJ+9kv6BLkLBXIjjspbgbnIBNqlI23tRnTWT0snUIw==} + peerDependencies: + react: ^18.3.1 + + react-is@16.13.1: + resolution: {integrity: sha512-24e6ynE2H+OKt4kqsOvNd8kBpV65zoxbA4BVsEOB3ARVWQki/DHzaUoC5KuON/BiccDaCCTZBuOcfZs70kR8bQ==} + + react-is@17.0.2: + resolution: {integrity: sha512-w2GsyukL62IJnlaff/nRegPQR94C/XXamvMWmSHRJ4y7Ts/4ocGRmTHvOs8PSE6pB3dWOrD/nueuU5sduBsQ4w==} + + react-markdown@10.1.0: + resolution: {integrity: sha512-qKxVopLT/TyA6BX3Ue5NwabOsAzm0Q7kAPwq6L+wWDwisYs7R8vZ0nRXqq6rkueboxpkjvLGU9fWifiX/ZZFxQ==} + peerDependencies: + '@types/react': '>=18' + react: '>=18' + + react-refresh@0.17.0: + resolution: 
{integrity: sha512-z6F7K9bV85EfseRCp2bzrpyQ0Gkw1uLoCel9XBVWPg/TjRj94SkJzUTGfOa4bs7iJvBWtQG0Wq7wnI0syw3EBQ==} + engines: {node: '>=0.10.0'} + + react-remove-scroll-bar@2.3.8: + resolution: {integrity: sha512-9r+yi9+mgU33AKcj6IbT9oRCO78WriSj6t/cF8DWBZJ9aOGPOTEDvdUDz1FwKim7QXWwmHqtdHnRJfhAxEG46Q==} + engines: {node: '>=10'} + peerDependencies: + '@types/react': '*' + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 + peerDependenciesMeta: + '@types/react': + optional: true + + react-remove-scroll@2.7.1: + resolution: {integrity: sha512-HpMh8+oahmIdOuS5aFKKY6Pyog+FNaZV/XyJOq7b4YFwsFHe5yYfdbIalI4k3vU2nSDql7YskmUseHsRrJqIPA==} + engines: {node: '>=10'} + peerDependencies: + '@types/react': '*' + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + react-router@7.7.0: + resolution: {integrity: sha512-3FUYSwlvB/5wRJVTL/aavqHmfUKe0+Xm9MllkYgGo9eDwNdkvwlJGjpPxono1kCycLt6AnDTgjmXvK3/B4QGuw==} + engines: {node: '>=20.0.0'} + peerDependencies: + react: '>=18' + react-dom: '>=18' + peerDependenciesMeta: + react-dom: + optional: true + + react-split@2.0.14: + resolution: {integrity: sha512-bKWydgMgaKTg/2JGQnaJPg51T6dmumTWZppFgEbbY0Fbme0F5TuatAScCLaqommbGQQf/ZT1zaejuPDriscISA==} + peerDependencies: + react: '*' + + react-style-singleton@2.2.3: + resolution: {integrity: sha512-b6jSvxvVnyptAiLjbkWLE/lOnR4lfTtDAl+eUC7RZy+QQWc6wRzIV2CE6xBuMmDxc2qIihtDCZD5NPOFl7fRBQ==} + engines: {node: '>=10'} + peerDependencies: + '@types/react': '*' + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + react@18.3.1: + resolution: {integrity: sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==} + engines: {node: '>=0.10.0'} + + reactflow@11.11.4: + resolution: {integrity: sha512-70FOtJkUWH3BAOsN+LU9lCrKoKbtOPnz2uq0CV2PLdNSwxTXOhCbsZr50GmZ+Rtw3jx8Uv7/vBFtCGixLfd4Og==} + peerDependencies: + react: '>=17' + react-dom: 
'>=17' + + read-cache@1.0.0: + resolution: {integrity: sha512-Owdv/Ft7IjOgm/i0xvNDZ1LrRANRfew4b2prF3OWMQLxLfu3bS8FVhCsrSCMK4lR56Y9ya+AThoTpDCTxCmpRA==} + + read-pkg@9.0.1: + resolution: {integrity: sha512-9viLL4/n1BJUCT1NXVTdS1jtm80yDEgR5T4yCelII49Mbj0v1rZdKqj7zCiYdbB0CuCgdrvHcNogAKTFPBocFA==} + engines: {node: '>=18'} + + read-yaml-file@2.1.0: + resolution: {integrity: sha512-UkRNRIwnhG+y7hpqnycCL/xbTk7+ia9VuVTC0S+zVbwd65DI9eUpRMfsWIGrCWxTU/mi+JW8cHQCrv+zfCbEPQ==} + engines: {node: '>=10.13'} + + read@1.0.7: + resolution: {integrity: sha512-rSOKNYUmaxy0om1BNjMN4ezNT6VKK+2xF4GBhc81mkH7L60i6dp8qPYrkndNLT3QPphoII3maL9PVC9XmhHwVQ==} + engines: {node: '>=0.8'} + + readable-stream@2.3.8: + resolution: {integrity: sha512-8p0AUk4XODgIewSi0l8Epjs+EVnWiK7NoDIEGU0HhE7+ZyY8D1IMY7odu5lRrFXGg71L15KG8QrPmum45RTtdA==} + + readable-stream@3.6.2: + resolution: {integrity: sha512-9u/sniCrY3D5WdsERHzHE4G2YCXqoG5FTHUiCC4SIbr6XcLZBY05ya9EKjYek9O5xOAwjGq+1JdGBAS7Q9ScoA==} + engines: {node: '>= 6'} + + readdirp@3.6.0: + resolution: {integrity: sha512-hOS089on8RduqdbhvQ5Z37A0ESjsqz6qnRcffsMU3495FuTdqSm+7bhJ29JvIOsBDEEnan5DPu9t3To9VRlMzA==} + engines: {node: '>=8.10.0'} + + readdirp@4.1.2: + resolution: {integrity: sha512-GDhwkLfywWL2s6vEjyhri+eXmfH6j1L7JE27WhqLeYzoh/A3DBaYGEj2H/HFZCn/kMfim73FXxEJTw06WtxQwg==} + engines: {node: '>= 14.18.0'} + + recast@0.23.11: + resolution: {integrity: sha512-YTUo+Flmw4ZXiWfQKGcwwc11KnoRAYgzAE2E7mXKCjSviTKShtxBsN6YUUBB2gtaBzKzeKunxhUwNHQuRryhWA==} + engines: {node: '>= 4'} + + redent@3.0.0: + resolution: {integrity: sha512-6tDA8g98We0zd0GvVeMT9arEOnTw9qM03L9cJXaCjrip1OO764RDBLBfrB4cwzNGDj5OA5ioymC9GkizgWJDUg==} + engines: {node: '>=8'} + + redux@4.2.1: + resolution: {integrity: sha512-LAUYz4lc+Do8/g7aeRa8JkyDErK6ekstQaqWQrNRW//MY1TvCEpMtpTWvlQ+FPbWCx+Xixu/6SHt5N0HR+SB4w==} + + reflect.getprototypeof@1.0.10: + resolution: {integrity: sha512-00o4I+DVrefhv+nX0ulyi3biSHCPDe+yLv5o/p6d/UVlirijB8E16FtfwSAi4g3tcqrQ4lRAqQSoFEZJehYEcw==} + engines: {node: '>= 0.4'} 
+ + reftools@1.1.9: + resolution: {integrity: sha512-OVede/NQE13xBQ+ob5CKd5KyeJYU2YInb1bmV4nRoOfquZPkAkxuOXicSe1PvqIuZZ4kD13sPKBbR7UFDmli6w==} + + regexp.prototype.flags@1.5.4: + resolution: {integrity: sha512-dYqgNSZbDwkaJ2ceRd9ojCGjBq+mOm9LmtXnAnEGyHhN/5R7iDW2TRw3h+o/jCFxus3P2LfWIIiwowAjANm7IA==} + engines: {node: '>= 0.4'} + + remark-parse@11.0.0: + resolution: {integrity: sha512-FCxlKLNGknS5ba/1lmpYijMUzX2esxW5xQqjWxw2eHFfS2MSdaHVINFmhjo+qN1WhZhNimq0dZATN9pH0IDrpA==} + + remark-rehype@11.1.2: + resolution: {integrity: sha512-Dh7l57ianaEoIpzbp0PC9UKAdCSVklD8E5Rpw7ETfbTl3FqcOOgq5q2LVDhgGCkaBv7p24JXikPdvhhmHvKMsw==} + + require-directory@2.1.1: + resolution: {integrity: sha512-fGxEI7+wsG9xrvdjsrlmL22OMTTiHRwAMroiEeMgq8gzoLC/PQr7RsRDSTLUg/bZAZtF+TVIkHc6/4RIKrui+Q==} + engines: {node: '>=0.10.0'} + + require-from-string@2.0.2: + resolution: {integrity: sha512-Xf0nWe6RseziFMu+Ap9biiUbmplq6S9/p+7w7YXP/JBHhrUDDUhwa+vANyubuqfZWTveU//DYVGsDG7RKL/vEw==} + engines: {node: '>=0.10.0'} + + resolve-from@4.0.0: + resolution: {integrity: sha512-pb/MYmXstAkysRFx8piNI1tGFNQIFA3vkE3Gq4EuA1dF6gHp/+vgZqsCGJapvy8N3Q+4o7FwvquPJcnZ7RYy4g==} + engines: {node: '>=4'} + + resolve-pkg-maps@1.0.0: + resolution: {integrity: sha512-seS2Tj26TBVOC2NIc2rOe2y2ZO7efxITtLZcGSOnHHNOQ7CkiUBfw0Iw2ck6xkIhPwLhKNLS8BO+hEpngQlqzw==} + + resolve@1.22.10: + resolution: {integrity: sha512-NPRy+/ncIMeDlTAsuqwKIiferiawhefFJtkNSW0qZJEqMEb+qBt/77B/jGeeek+F0uOeN05CDa6HXbbIgtVX4w==} + engines: {node: '>= 0.4'} + hasBin: true + + restore-cursor@5.1.0: + resolution: {integrity: sha512-oMA2dcrw6u0YfxJQXm342bFKX/E4sG9rbTzO9ptUcR/e8A33cHuvStiYOwH7fszkZlZ1z/ta9AAoPk2F4qIOHA==} + engines: {node: '>=18'} + + reusify@1.1.0: + resolution: {integrity: sha512-g6QUff04oZpHs0eG5p83rFLhHeV00ug/Yf9nZM6fLeUrPguBTkTQOdpAWWspMh55TZfVQDPaN3NQJfbVRAxdIw==} + engines: {iojs: '>=1.0.0', node: '>=0.10.0'} + + rollup@4.45.1: + resolution: {integrity: 
sha512-4iya7Jb76fVpQyLoiVpzUrsjQ12r3dM7fIVz+4NwoYvZOShknRmiv+iu9CClZml5ZLGb0XMcYLutK6w9tgxHDw==} + engines: {node: '>=18.0.0', npm: '>=8.0.0'} + hasBin: true + + rrweb-cssom@0.8.0: + resolution: {integrity: sha512-guoltQEx+9aMf2gDZ0s62EcV8lsXR+0w8915TC3ITdn2YueuNjdAYh/levpU9nFaoChh9RUS5ZdQMrKfVEN9tw==} + + run-applescript@7.0.0: + resolution: {integrity: sha512-9by4Ij99JUr/MCFBUkDKLWK3G9HVXmabKz9U5MlIAIuvuzkiOicRYs8XJLxX+xahD+mLiiCYDqF9dKAgtzKP1A==} + engines: {node: '>=18'} + + run-parallel@1.2.0: + resolution: {integrity: sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==} + + safe-array-concat@1.1.3: + resolution: {integrity: sha512-AURm5f0jYEOydBj7VQlVvDrjeFgthDdEF5H1dP+6mNpoXOMo1quQqJ4wvJDyRZ9+pO3kGWoOdmV08cSv2aJV6Q==} + engines: {node: '>=0.4'} + + safe-buffer@5.1.2: + resolution: {integrity: sha512-Gd2UZBJDkXlY7GbJxfsE8/nvKkUEU1G38c1siN6QP6a9PT9MmHB8GnpscSmMJSoF8LOIrt8ud/wPtojys4G6+g==} + + safe-buffer@5.2.1: + resolution: {integrity: sha512-rp3So07KcdmmKbGvgaNxQSJr7bGVSVk5S9Eq1F+ppbRo70+YeaDxkw5Dd8NPN+GD6bjnYm2VuPuCXmpuYvmCXQ==} + + safe-push-apply@1.0.0: + resolution: {integrity: sha512-iKE9w/Z7xCzUMIZqdBsp6pEQvwuEebH4vdpjcDWnyzaI6yl6O9FHvVpmGelvEHNsoY6wGblkxR6Zty/h00WiSA==} + engines: {node: '>= 0.4'} + + safe-regex-test@1.1.0: + resolution: {integrity: sha512-x/+Cz4YrimQxQccJf5mKEbIa1NzeCRNI5Ecl/ekmlYaampdNLPalVyIcCZNNH3MvmqBugV5TMYZXv0ljslUlaw==} + engines: {node: '>= 0.4'} + + safe-stable-stringify@1.1.1: + resolution: {integrity: sha512-ERq4hUjKDbJfE4+XtZLFPCDi8Vb1JqaxAPTxWFLBx8XcAlf9Bda/ZJdVezs/NAfsMQScyIlUMx+Yeu7P7rx5jw==} + + safer-buffer@2.1.2: + resolution: {integrity: sha512-YZo3K82SD7Riyi0E1EQPojLz7kpepnSQI9IyPbHHg1XXXevb5dJI7tpyN2ADxGcQbHG7vcyRHk0cbwqcQriUtg==} + + sax@1.4.1: + resolution: {integrity: sha512-+aWOz7yVScEGoKNd4PA10LZ8sk0A/z5+nXQG5giUO5rprX9jgYsTdov9qCchZiPIZezbZH+jRut8nPodFAX4Jg==} + + saxes@6.0.0: + resolution: {integrity: 
sha512-xAg7SOnEhrm5zI3puOOKyy1OMcMlIJZYNJY7xLBwSze0UjhPLnWfj2GF2EpT0jmzaJKIWKHLsaSSajf35bcYnA==} + engines: {node: '>=v12.22.7'} + + scheduler@0.23.2: + resolution: {integrity: sha512-UOShsPwz7NrMUqhR6t0hWjFduvOzbtv7toDH1/hIrfRNIDBnnBWd0CwJTGvTpngVlmwGCdP9/Zl/tVrDqcuYzQ==} + + schema-utils@4.3.2: + resolution: {integrity: sha512-Gn/JaSk/Mt9gYubxTtSn/QCV4em9mpAPiR1rqy/Ocu19u/G9J5WWdNoUT4SiV6mFC3y6cxyFcFwdzPM3FgxGAQ==} + engines: {node: '>= 10.13.0'} + + secretlint@10.2.1: + resolution: {integrity: sha512-3BghQkIGrDz3xJklX/COxgKbxHz2CAsGkXH4oh8MxeYVLlhA3L/TLhAxZiTyqeril+CnDGg8MUEZdX1dZNsxVA==} + engines: {node: '>=20.0.0'} + hasBin: true + + semver@5.7.2: + resolution: {integrity: sha512-cBznnQ9KjJqU67B52RMC65CMarK2600WFnbkcaiwWq3xy/5haFJlshgnpjovMVJ+Hff49d8GEn0b87C5pDQ10g==} + hasBin: true + + semver@6.3.1: + resolution: {integrity: sha512-BR7VvDCVHO+q2xBEWskxS6DJE1qRnb7DxzUrogb71CWoSficBxYsiAGd+Kl0mmq/MprG9yArRkyrQxTO6XjMzA==} + hasBin: true + + semver@7.5.4: + resolution: {integrity: sha512-1bCSESV6Pv+i21Hvpxp3Dx+pSD8lIPt8uVjRrxAUt/nbswYc+tK6Y2btiULjd4+fnq15PX+nqQDC7Oft7WkwcA==} + engines: {node: '>=10'} + hasBin: true + + semver@7.7.2: + resolution: {integrity: sha512-RF0Fw+rO5AMf9MAyaRXI4AV0Ulj5lMHqVxxdSgiVbixSCXoEmmX/jk0CuJw4+3SqroYO9VoUh+HcuJivvtJemA==} + engines: {node: '>=10'} + hasBin: true + + serialize-javascript@6.0.2: + resolution: {integrity: sha512-Saa1xPByTTq2gdeFZYLLo+RFE35NHZkAbqZeWNd3BpzppeVisAqpDjcp8dyf6uIvEqJRd46jemmyA4iFIeVk8g==} + + seroval-plugins@1.3.2: + resolution: {integrity: sha512-0QvCV2lM3aj/U3YozDiVwx9zpH0q8A60CTWIv4Jszj/givcudPb48B+rkU5D51NJ0pTpweGMttHjboPa9/zoIQ==} + engines: {node: '>=10'} + peerDependencies: + seroval: ^1.0 + + seroval-plugins@1.3.3: + resolution: {integrity: sha512-16OL3NnUBw8JG1jBLUoZJsLnQq0n5Ua6aHalhJK4fMQkz1lqR7Osz1sA30trBtd9VUDc2NgkuRCn8+/pBwqZ+w==} + engines: {node: '>=10'} + peerDependencies: + seroval: ^1.0 + + seroval@1.3.2: + resolution: {integrity: 
sha512-RbcPH1n5cfwKrru7v7+zrZvjLurgHhGyso3HTyGtRivGWgYjbOmGuivCQaORNELjNONoK35nj28EoWul9sb1zQ==} + engines: {node: '>=10'} + + set-cookie-parser@2.7.1: + resolution: {integrity: sha512-IOc8uWeOZgnb3ptbCURJWNjWUPcO3ZnTTdzsurqERrP6nPyv+paC55vJM0LpOlT2ne+Ix+9+CRG1MNLlyZ4GjQ==} + + set-function-length@1.2.2: + resolution: {integrity: sha512-pgRc4hJ4/sNjWCSS9AmnS40x3bNMDTknHgL5UaMBTMyJnU90EgWh1Rz+MC9eFu4BuN/UwZjKQuY/1v3rM7HMfg==} + engines: {node: '>= 0.4'} + + set-function-name@2.0.2: + resolution: {integrity: sha512-7PGFlmtwsEADb0WYyvCMa1t+yke6daIG4Wirafur5kcf+MhUnPms1UeR0CKQdTZD81yESwMHbtn+TR+dMviakQ==} + engines: {node: '>= 0.4'} + + set-proto@1.0.0: + resolution: {integrity: sha512-RJRdvCo6IAnPdsvP/7m6bsQqNnn1FCBX5ZNtFL98MmFF/4xAIJTIg1YbHW5DC2W5SKZanrC6i4HsJqlajw/dZw==} + engines: {node: '>= 0.4'} + + setimmediate@1.0.5: + resolution: {integrity: sha512-MATJdZp8sLqDl/68LfQmbP8zKPLQNV6BIZoIgrscFDQ+RsvK/BxeDQOgyxKKoh0y/8h3BqVFnCqQ/gd+reiIXA==} + + shebang-command@2.0.0: + resolution: {integrity: sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==} + engines: {node: '>=8'} + + shebang-regex@3.0.0: + resolution: {integrity: sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==} + engines: {node: '>=8'} + + shell-quote@1.8.3: + resolution: {integrity: sha512-ObmnIF4hXNg1BqhnHmgbDETF8dLPCggZWBjkQfhZpbszZnYur5DUljTcCHii5LC3J5E0yeO/1LIMyH+UvHQgyw==} + engines: {node: '>= 0.4'} + + should-equal@2.0.0: + resolution: {integrity: sha512-ZP36TMrK9euEuWQYBig9W55WPC7uo37qzAEmbjHz4gfyuXrEUgF8cUvQVO+w+d3OMfPvSRQJ22lSm8MQJ43LTA==} + + should-format@3.0.3: + resolution: {integrity: sha512-hZ58adtulAk0gKtua7QxevgUaXTTXxIi8t41L3zo9AHvjXO1/7sdLECuHeIN2SRtYXpNkmhoUP2pdeWgricQ+Q==} + + should-type-adaptors@1.1.0: + resolution: {integrity: sha512-JA4hdoLnN+kebEp2Vs8eBe9g7uy0zbRo+RMcU0EsNy+R+k049Ki+N5tT5Jagst2g7EAja+euFuoXFCa8vIklfA==} + + should-type@1.4.0: + resolution: {integrity: 
sha512-MdAsTu3n25yDbIe1NeN69G4n6mUnJGtSJHygX3+oN0ZbO3DTiATnf7XnYJdGT42JCXurTb1JI0qOBR65shvhPQ==} + + should-util@1.0.1: + resolution: {integrity: sha512-oXF8tfxx5cDk8r2kYqlkUJzZpDBqVY/II2WhvU0n9Y3XYvAYRmeaf1PvvIvTgPnv4KJ+ES5M0PyDq5Jp+Ygy2g==} + + should@13.2.3: + resolution: {integrity: sha512-ggLesLtu2xp+ZxI+ysJTmNjh2U0TsC+rQ/pfED9bUZZ4DKefP27D+7YJVVTvKsmjLpIi9jAa7itwDGkDDmt1GQ==} + + side-channel-list@1.0.0: + resolution: {integrity: sha512-FCLHtRD/gnpCiCHEiJLOwdmFP+wzCmDEkc9y7NsYxeF4u7Btsn1ZuwgwJGxImImHicJArLP4R0yX4c2KCrMrTA==} + engines: {node: '>= 0.4'} + + side-channel-map@1.0.1: + resolution: {integrity: sha512-VCjCNfgMsby3tTdo02nbjtM/ewra6jPHmpThenkTYh8pG9ucZ/1P8So4u4FGBek/BjpOVsDCMoLA/iuBKIFXRA==} + engines: {node: '>= 0.4'} + + side-channel-weakmap@1.0.2: + resolution: {integrity: sha512-WPS/HvHQTYnHisLo9McqBHOJk2FkHO/tlpvldyrnem4aeQp4hai3gythswg6p01oSoTl58rcpiFAjF2br2Ak2A==} + engines: {node: '>= 0.4'} + + side-channel@1.1.0: + resolution: {integrity: sha512-ZX99e6tRweoUXqR+VBrslhda51Nh5MTQwou5tnUDgbtyM0dBgmhEDtWGP/xbKn6hqfPRHujUNwz5fy/wbbhnpw==} + engines: {node: '>= 0.4'} + + siginfo@2.0.0: + resolution: {integrity: sha512-ybx0WO1/8bSBLEWXZvEd7gMW3Sn3JFlW3TvX1nREbDLRNQNaeNN8WK0meBwPdAaOI7TtRRRJn/Es1zhrrCHu7g==} + + signal-exit@3.0.7: + resolution: {integrity: sha512-wnD2ZE+l+SPC/uoS0vXeE9L1+0wuaMqKlfz9AMUo38JsyLSBWSFcHR1Rri62LZc12vLr1gb3jl7iwQhgwpAbGQ==} + + signal-exit@4.1.0: + resolution: {integrity: sha512-bzyZ1e88w9O1iNJbKnOlvYTrWPDl46O1bG0D3XInv+9tkPrxrN8jUUTiFlDkkmKWgn1M6CfIA13SuGqOa9Korw==} + engines: {node: '>=14'} + + simple-concat@1.0.1: + resolution: {integrity: sha512-cSFtAPtRhljv69IK0hTVZQ+OfE9nePi/rtJmw5UjHeVyVroEqJXP1sFztKUy1qU+xvz3u/sfYJLa947b7nAN2Q==} + + simple-eval@1.0.1: + resolution: {integrity: sha512-LH7FpTAkeD+y5xQC4fzS+tFtaNlvt3Ib1zKzvhjv/Y+cioV4zIuw4IZr2yhRLu67CWL7FR9/6KXKnjRoZTvGGQ==} + engines: {node: '>=12'} + + simple-get@4.0.1: + resolution: {integrity: 
sha512-brv7p5WgH0jmQJr1ZDDfKDOSeWWg+OVypG99A/5vYGPqJ6pxiaHLy8nxtFjBA7oMa01ebA9gfh1uMCFqOuXxvA==} + + sirv@3.0.1: + resolution: {integrity: sha512-FoqMu0NCGBLCcAkS1qA+XJIQTR6/JHfQXl+uGteNCQ76T91DMUjPa9xfmeqMY3z80nLSg9yQmNjK0Px6RWsH/A==} + engines: {node: '>=18'} + + sisteransi@1.0.5: + resolution: {integrity: sha512-bLGGlR1QxBcynn2d5YmDX4MGjlZvy2MRBDRNHLJ8VI6l6+9FUiyTFNJ0IveOSP0bcXgVDPRcfGqA0pjaqUpfVg==} + + slash@3.0.0: + resolution: {integrity: sha512-g9Q1haeby36OSStwb4ntCGGGaKsaVSjQ68fBxoQcutl5fS1vuY18H3wSt3jFyFtrkx+Kz0V1G85A4MyAdDMi2Q==} + engines: {node: '>=8'} + + slash@5.1.0: + resolution: {integrity: sha512-ZA6oR3T/pEyuqwMgAKT0/hAv8oAXckzbkmR0UkUosQ+Mc4RxGoJkRmwHgHufaenlyAgE1Mxgpdcrf75y6XcnDg==} + engines: {node: '>=14.16'} + + slice-ansi@4.0.0: + resolution: {integrity: sha512-qMCMfhY040cVHT43K9BFygqYbUPFZKHOg7K73mtTWJRb8pyP3fzf4Ixd5SzdEJQ6MRUg/WBnOLxghZtKKurENQ==} + engines: {node: '>=10'} + + solid-js@1.9.7: + resolution: {integrity: sha512-/saTKi8iWEM233n5OSi1YHCCuh66ZIQ7aK2hsToPe4tqGm7qAejU1SwNuTPivbWAYq7SjuHVVYxxuZQNRbICiw==} + + source-map-js@1.2.1: + resolution: {integrity: sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==} + engines: {node: '>=0.10.0'} + + source-map-support@0.5.21: + resolution: {integrity: sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==} + + source-map@0.6.1: + resolution: {integrity: sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==} + engines: {node: '>=0.10.0'} + + source-map@0.7.4: + resolution: {integrity: sha512-l3BikUxvPOcn5E74dZiq5BGsTb5yEwhaTSzccU6t4sDOH8NWJCstKO5QT2CvtFoK6F0saL7p9xHAqHOlCPJygA==} + engines: {node: '>= 8'} + + space-separated-tokens@2.0.2: + resolution: {integrity: sha512-PEGlAwrG8yXGXRjW32fGbg66JAlOAwbObuqVoJpv/mRgoWDQfgH1wDPvtzWyUSNAXBGSk8h755YDbbcEy3SH2Q==} + + spdx-correct@3.2.0: + resolution: {integrity: 
sha512-kN9dJbvnySHULIluDHy32WHRUu3Og7B9sbY7tsFLctQkIqnMh3hErYgdMjTYuqmcXX+lK5T1lnUt3G7zNswmZA==} + + spdx-exceptions@2.5.0: + resolution: {integrity: sha512-PiU42r+xO4UbUS1buo3LPJkjlO7430Xn5SVAhdpzzsPHsjbYVflnnFdATgabnLude+Cqu25p6N+g2lw/PFsa4w==} + + spdx-expression-parse@3.0.1: + resolution: {integrity: sha512-cbqHunsQWnJNE6KhVSMsMeH5H/L9EpymbzqTQ3uLwNCLZ1Q481oWaofqH7nO6V07xlXwY6PhQdQ2IedWx/ZK4Q==} + + spdx-license-ids@3.0.21: + resolution: {integrity: sha512-Bvg/8F5XephndSK3JffaRqdT+gyhfqIPwDHpX80tJrF8QQRYMo8sNMeaZ2Dp5+jhwKnUmIOyFFQfHRkjJm5nXg==} + + split.js@1.6.5: + resolution: {integrity: sha512-mPTnGCiS/RiuTNsVhCm9De9cCAUsrNFFviRbADdKiiV+Kk8HKp/0fWu7Kr8pi3/yBmsqLFHuXGT9UUZ+CNLwFw==} + + sprintf-js@1.0.3: + resolution: {integrity: sha512-D9cPgkvLlV3t3IzL0D0YLvGA9Ahk4PcvVwUbN0dSGr1aP0Nrt4AEnTUbuGvquEC0mA64Gqt1fzirlRs5ibXx8g==} + + stackback@0.0.2: + resolution: {integrity: sha512-1XMJE5fQo1jGH6Y/7ebnwPOBEkIEnT4QF32d5R1+VXdXveM0IBMJt8zfaxX1P3QhVwrYe+576+jkANtSS2mBbw==} + + std-env@3.9.0: + resolution: {integrity: sha512-UGvjygr6F6tpH7o2qyqR6QYpwraIjKSdtzyBdyytFOHmPZY917kwdwLG0RbOjWOnKmnm3PeHjaoLLMie7kPLQw==} + + stdin-discarder@0.2.2: + resolution: {integrity: sha512-UhDfHmA92YAlNnCfhmq0VeNL5bDbiZGg7sZ2IvPsXubGkiNa9EC+tUTsjBRsYUAz87btI6/1wf4XoVvQ3uRnmQ==} + engines: {node: '>=18'} + + stop-iteration-iterator@1.1.0: + resolution: {integrity: sha512-eLoXW/DHyl62zxY4SCaIgnRhuMr6ri4juEYARS8E6sCEqzKpOiE521Ucofdx+KnDZl5xmvGYaaKCk5FEOxJCoQ==} + engines: {node: '>= 0.4'} + + storybook@9.0.18: + resolution: {integrity: sha512-ruxpEpizwoYQTt1hBOrWyp9trPYWD9Apt1TJ37rs1rzmNQWpSNGJDMg91JV4mUhBChzRvnid/oRBFFCWJz/dfw==} + hasBin: true + peerDependencies: + prettier: ^2 || ^3 + peerDependenciesMeta: + prettier: + optional: true + + storybook@9.1.5: + resolution: {integrity: sha512-cGwJ2AE6nxlwqQlOiI+HKX5qa7+FOV7Ha7Qa+GoASBIQSSnLfbY6UldgAxHCJGJOFtgW/wuqfDtNvni6sj1/OQ==} + hasBin: true + peerDependencies: + prettier: ^2 || ^3 + peerDependenciesMeta: + prettier: + optional: true + 
+ string-argv@0.3.2: + resolution: {integrity: sha512-aqD2Q0144Z+/RqG52NeHEkZauTAUWJO8c6yTftGJKO3Tja5tUgIfmIl6kExvhtxSDP7fXB6DvzkfMpCd/F3G+Q==} + engines: {node: '>=0.6.19'} + + string-width@4.2.3: + resolution: {integrity: sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==} + engines: {node: '>=8'} + + string-width@5.1.2: + resolution: {integrity: sha512-HnLOCR3vjcY8beoNLtcjZ5/nxn2afmME6lhrDrebokqMap+XbeW8n9TXpPDOqdGK5qcI3oT0GKTW6wC7EMiVqA==} + engines: {node: '>=12'} + + string-width@7.2.0: + resolution: {integrity: sha512-tsaTIkKW9b4N+AEj+SVA+WhJzV7/zMhcSu78mLKWSk7cXMOSHsBKFWUs0fWwq8QyK3MgJBQRX6Gbi4kYbdvGkQ==} + engines: {node: '>=18'} + + string.prototype.trim@1.2.10: + resolution: {integrity: sha512-Rs66F0P/1kedk5lyYyH9uBzuiI/kNRmwJAR9quK6VOtIpZ2G+hMZd+HQbbv25MgCA6gEffoMZYxlTod4WcdrKA==} + engines: {node: '>= 0.4'} + + string.prototype.trimend@1.0.9: + resolution: {integrity: sha512-G7Ok5C6E/j4SGfyLCloXTrngQIQU3PWtXGst3yM7Bea9FRURf1S42ZHlZZtsNque2FN2PoUhfZXYLNWwEr4dLQ==} + engines: {node: '>= 0.4'} + + string.prototype.trimstart@1.0.8: + resolution: {integrity: sha512-UXSH262CSZY1tfu3G3Secr6uGLCFVPMhIqHjlgCUtCCcgihYc/xKs9djMTMUOb2j1mVSeU8EU6NWc/iQKU6Gfg==} + engines: {node: '>= 0.4'} + + string_decoder@1.1.1: + resolution: {integrity: sha512-n/ShnvDi6FHbbVfviro+WojiFzv+s8MPMHBczVePfUpDJLwoLT0ht1l4YwBCbi8pJAveEEdnkHyPyTP/mzRfwg==} + + string_decoder@1.3.0: + resolution: {integrity: sha512-hkRX8U1WjJFd8LsDJ2yQ/wWWxaopEsABU1XfkM8A+j0+85JAGppt16cr1Whg6KIbb4okU6Mql6BOj+uup/wKeA==} + + stringify-entities@4.0.4: + resolution: {integrity: sha512-IwfBptatlO+QCJUo19AqvrPNqlVMpW9YEL2LIVY+Rpv2qsjCGxaDLNRgeGsQWJhfItebuJhsGSLjaBbNSQ+ieg==} + + strip-ansi@6.0.1: + resolution: {integrity: sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==} + engines: {node: '>=8'} + + strip-ansi@7.1.0: + resolution: {integrity: 
sha512-iq6eVVI64nQQTRYq2KtEg2d2uU7LElhTJwsH4YzIHZshxlgZms/wIc4VoDQTlG/IvVIrBKG06CrZnp0qv7hkcQ==} + engines: {node: '>=12'} + + strip-bom@3.0.0: + resolution: {integrity: sha512-vavAMRXOgBVNF6nyEEmL3DBK19iRpDcoIwW+swQ+CbGiu7lju6t+JklA1MHweoWtadgt4ISVUsXLyDq34ddcwA==} + engines: {node: '>=4'} + + strip-bom@4.0.0: + resolution: {integrity: sha512-3xurFv5tEgii33Zi8Jtp55wEIILR9eh34FAW00PZf+JnSsTmV/ioewSgQl97JHvgjoRGwPShsWm+IdrxB35d0w==} + engines: {node: '>=8'} + + strip-final-newline@2.0.0: + resolution: {integrity: sha512-BrpvfNAE3dcvq7ll3xVumzjKjZQ5tI1sEUIKr3Uoks0XUl45St3FlatVqef9prk4jRDzhW6WZg+3bk93y6pLjA==} + engines: {node: '>=6'} + + strip-indent@3.0.0: + resolution: {integrity: sha512-laJTa3Jb+VQpaC6DseHhF7dXVqHTfJPCRDaEbid/drOhgitgYku/letMUqOXFoWV0zIIUbjpdH2t+tYj4bQMRQ==} + engines: {node: '>=8'} + + strip-indent@4.0.0: + resolution: {integrity: sha512-mnVSV2l+Zv6BLpSD/8V87CW/y9EmmbYzGCIavsnsI6/nwn26DwffM/yztm30Z/I2DY9wdS3vXVCMnHDgZaVNoA==} + engines: {node: '>=12'} + + strip-json-comments@2.0.1: + resolution: {integrity: sha512-4gB8na07fecVVkOI6Rs4e7T6NOTki5EmL7TUduTs6bu3EdnSycntVJ4re8kgZA+wx9IueI2Y11bfbgwtzuE0KQ==} + engines: {node: '>=0.10.0'} + + strip-json-comments@3.1.1: + resolution: {integrity: sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==} + engines: {node: '>=8'} + + strip-literal@3.0.0: + resolution: {integrity: sha512-TcccoMhJOM3OebGhSBEmp3UZ2SfDMZUEBdRA/9ynfLi8yYajyWX3JiXArcJt4Umh4vISpspkQIY8ZZoCqjbviA==} + + structured-source@4.0.0: + resolution: {integrity: sha512-qGzRFNJDjFieQkl/sVOI2dUjHKRyL9dAJi2gCPGJLbJHBIkyOHxjuocpIEfbLioX+qSJpvbYdT49/YCdMznKxA==} + + style-mod@4.1.2: + resolution: {integrity: sha512-wnD1HyVqpJUI2+eKZ+eo1UwghftP6yuFheBqqe+bWCotBjC2K1YnteJILRMs3SM4V/0dLEW1SC27MWP5y+mwmw==} + + style-to-js@1.1.17: + resolution: {integrity: sha512-xQcBGDxJb6jjFCTzvQtfiPn6YvvP2O8U1MDIPNfJQlWMYfktPy+iGsHE7cssjs7y84d9fQaK4UF3RIJaAHSoYA==} + + style-to-object@1.0.9: + resolution: {integrity: 
sha512-G4qppLgKu/k6FwRpHiGiKPaPTFcG3g4wNVX/Qsfu+RqQM30E7Tyu/TEgxcL9PNLF5pdRLwQdE3YKKf+KF2Dzlw==} + + sucrase@3.35.0: + resolution: {integrity: sha512-8EbVDiu9iN/nESwxeSxDKe0dunta1GOlHufmSSXxMD2z2/tMZpDMpvXQGsc+ajGo8y2uYUmixaSRUc/QPoQ0GA==} + engines: {node: '>=16 || 14 >=14.17'} + hasBin: true + + supports-color@7.2.0: + resolution: {integrity: sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==} + engines: {node: '>=8'} + + supports-color@8.1.1: + resolution: {integrity: sha512-MpUEN2OodtUzxvKQl72cUF7RQ5EiHsGvSsVG0ia9c5RbWGL2CI4C7EpPS8UTBIplnlzZiNuV56w+FuNxy3ty2Q==} + engines: {node: '>=10'} + + supports-color@9.4.0: + resolution: {integrity: sha512-VL+lNrEoIXww1coLPOmiEmK/0sGigko5COxI09KzHc2VJXJsQ37UaQ+8quuxjDeA7+KnLGTWRyOXSLLR2Wb4jw==} + engines: {node: '>=12'} + + supports-hyperlinks@3.2.0: + resolution: {integrity: sha512-zFObLMyZeEwzAoKCyu1B91U79K2t7ApXuQfo8OuxwXLDgcKxuwM+YvcbIhm6QWqz7mHUH1TVytR1PwVVjEuMig==} + engines: {node: '>=14.18'} + + supports-preserve-symlinks-flag@1.0.0: + resolution: {integrity: sha512-ot0WnXS9fgdkgIcePe6RHNk1WA8+muPa6cSjeR3V8K27q9BB1rTE3R1p7Hv0z1ZyAc8s6Vvv8DIyWf681MAt0w==} + engines: {node: '>= 0.4'} + + swagger2openapi@7.0.8: + resolution: {integrity: sha512-upi/0ZGkYgEcLeGieoz8gT74oWHA0E7JivX7aN9mAf+Tc7BQoRBvnIGHoPDw+f9TXTW4s6kGYCZJtauP6OYp7g==} + hasBin: true + + symbol-tree@3.2.4: + resolution: {integrity: sha512-9QNk5KwDF+Bvz+PyObkmSYjI5ksVUYtjW7AU22r2NKcfLJcXp96hkDWU3+XndOsUb+AQ9QhfzfCT2O+CNWT5Tw==} + + syncpack@13.0.4: + resolution: {integrity: sha512-kJ9VlRxNCsBD5pJAE29oXeBYbPLhEySQmK4HdpsLv81I6fcDDW17xeJqMwiU3H7/woAVsbgq25DJNS8BeiN5+w==} + engines: {node: '>=18.18.0'} + hasBin: true + + tabbable@6.2.0: + resolution: {integrity: sha512-Cat63mxsVJlzYvN51JmVXIgNoUokrIaT2zLclCXjRd8boZ0004U4KCs/sToJ75C6sdlByWxpYnb5Boif1VSFew==} + + table-layout@4.1.1: + resolution: {integrity: sha512-iK5/YhZxq5GO5z8wb0bY1317uDF3Zjpha0QFFLA8/trAoiLbQD0HUbMesEaxyzUgDxi2QlcbM8IvqOlEjgoXBA==} + engines: 
{node: '>=12.17'} + + table@6.9.0: + resolution: {integrity: sha512-9kY+CygyYM6j02t5YFHbNz2FN5QmYGv9zAjVp4lCDjlCw7amdckXlEt/bjMhUIfj4ThGRE4gCUH5+yGnNuPo5A==} + engines: {node: '>=10.0.0'} + + tailwind-merge@3.3.1: + resolution: {integrity: sha512-gBXpgUm/3rp1lMZZrM/w7D8GKqshif0zAymAhbCyIt8KMe+0v9DQ7cdYLR4FHH/cKpdTXb+A/tKKU3eolfsI+g==} + + tailwind-scrollbar@3.1.0: + resolution: {integrity: sha512-pmrtDIZeHyu2idTejfV59SbaJyvp1VRjYxAjZBH0jnyrPRo6HL1kD5Glz8VPagasqr6oAx6M05+Tuw429Z8jxg==} + engines: {node: '>=12.13.0'} + peerDependencies: + tailwindcss: 3.x + + tailwindcss@3.4.17: + resolution: {integrity: sha512-w33E2aCvSDP0tW9RZuNXadXlkHXqFzSkQew/aIa2i/Sj8fThxwovwlXHSPXTbAHwEIhBFXAedUhP2tueAKP8Og==} + engines: {node: '>=14.0.0'} + hasBin: true + + tailwindcss@4.1.11: + resolution: {integrity: sha512-2E9TBm6MDD/xKYe+dvJZAmg3yxIEDNRc0jwlNyDg/4Fil2QcSLjFKGVff0lAf1jjeaArlG/M75Ey/EYr/OJtBA==} + + tapable@2.2.2: + resolution: {integrity: sha512-Re10+NauLTMCudc7T5WLFLAwDhQ0JWdrMK+9B2M8zR5hRExKmsRDCBA7/aV/pNJFltmBFO5BAMlQFi/vq3nKOg==} + engines: {node: '>=6'} + + tapable@2.2.3: + resolution: {integrity: sha512-ZL6DDuAlRlLGghwcfmSn9sK3Hr6ArtyudlSAiCqQ6IfE+b+HHbydbYDIG15IfS5do+7XQQBdBiubF/cV2dnDzg==} + engines: {node: '>=6'} + + tar-fs@2.1.3: + resolution: {integrity: sha512-090nwYJDmlhwFwEW3QQl+vaNnxsO2yVsd45eTKRBzSzu+hlb1w2K9inVq5b0ngXuLVqQ4ApvsUHHnu/zQNkWAg==} + + tar-stream@2.2.0: + resolution: {integrity: sha512-ujeqbceABgwMZxEJnk2HDY2DlnUZ+9oEcb1KzTVfYHio0UE6dG71n60d8D2I4qNvleWrrXpmjpt7vZeF1LnMZQ==} + engines: {node: '>=6'} + + tar@7.4.3: + resolution: {integrity: sha512-5S7Va8hKfV7W5U6g3aYxXmlPoZVAwUMy9AOKyF2fVuZa2UD3qZjg578OrLRt8PcNN1PleVaL/5/yYATNL0ICUw==} + engines: {node: '>=18'} + + terminal-link@4.0.0: + resolution: {integrity: sha512-lk+vH+MccxNqgVqSnkMVKx4VLJfnLjDBGzH16JVZjKE2DoxP57s6/vt6JmXV5I3jBcfGrxNrYtC+mPtU7WJztA==} + engines: {node: '>=18'} + + terser-webpack-plugin@5.3.14: + resolution: {integrity: 
sha512-vkZjpUjb6OMS7dhV+tILUW6BhpDR7P2L/aQSAv+Uwk+m8KATX9EccViHTJR2qDtACKPIYndLGCyl3FMo+r2LMw==} + engines: {node: '>= 10.13.0'} + peerDependencies: + '@swc/core': '*' + esbuild: '*' + uglify-js: '*' + webpack: ^5.1.0 + peerDependenciesMeta: + '@swc/core': + optional: true + esbuild: + optional: true + uglify-js: + optional: true + + terser@5.44.0: + resolution: {integrity: sha512-nIVck8DK+GM/0Frwd+nIhZ84pR/BX7rmXMfYwyg+Sri5oGVE99/E3KvXqpC2xHFxyqXyGHTKBSioxxplrO4I4w==} + engines: {node: '>=10'} + hasBin: true + + test-exclude@6.0.0: + resolution: {integrity: sha512-cAGWPIyOHU6zlmg88jwm7VRyXnMN7iV68OGAbYDk/Mh/xC/pzVPlQtY6ngoIH/5/tciuhGfvESU8GrHrcxD56w==} + engines: {node: '>=8'} + + test-exclude@7.0.1: + resolution: {integrity: sha512-pFYqmTw68LXVjeWJMST4+borgQP2AyMNbg1BpZh9LbyhUeNkeaPF9gzfPGUAnSMV3qPYdWUwDIjjCLiSDOl7vg==} + engines: {node: '>=18'} + + text-table@0.2.0: + resolution: {integrity: sha512-N+8UisAXDGk8PFXP4HAzVR9nbfmVJ3zYLAWiTIoqC5v5isinhr+r5uaO8+7r3BMfuNIufIsA7RdpVgacC2cSpw==} + + textextensions@6.11.0: + resolution: {integrity: sha512-tXJwSr9355kFJI3lbCkPpUH5cP8/M0GGy2xLO34aZCjMXBaK3SoPnZwr/oWmo1FdCnELcs4npdCIOFtq9W3ruQ==} + engines: {node: '>=4'} + + thememirror@2.0.1: + resolution: {integrity: sha512-d5i6FVvWWPkwrm4cHLI3t9AT1OrkAt7Ig8dtdYSofgF7C/eiyNuq6zQzSTusWTde3jpW9WLvA9J/fzNKMUsd0w==} + peerDependencies: + '@codemirror/language': ^6.0.0 + '@codemirror/state': ^6.0.0 + '@codemirror/view': ^6.0.0 + + thenify-all@1.6.0: + resolution: {integrity: sha512-RNxQH/qI8/t3thXJDwcstUO4zeqo64+Uy/+sNVRBx4Xn2OX+OZ9oP+iJnNFqplFra2ZUVeKCSa2oVWi3T4uVmA==} + engines: {node: '>=0.8'} + + thenify@3.3.1: + resolution: {integrity: sha512-RVZSIV5IG10Hk3enotrhvz0T9em6cyHBLkH/YAZuKqd8hRkKhSfCGIcP2KUY0EPxndzANBmNllzWPwak+bheSw==} + + tightrope@0.2.0: + resolution: {integrity: sha512-Kw36UHxJEELq2VUqdaSGR2/8cAsPgMtvX8uGVU6Jk26O66PhXec0A5ZnRYs47btbtwPDpXXF66+Fo3vimCM9aQ==} + engines: {node: '>=16'} + + tiny-invariant@1.3.3: + resolution: {integrity: 
sha512-+FbBPE1o9QAYvviau/qC5SE3caw21q3xkvWKBtja5vgqOWIHHJ3ioaq1VPfn/Szqctz2bU/oYeKd9/z5BL+PVg==} + + tiny-warning@1.0.3: + resolution: {integrity: sha512-lBN9zLN/oAf68o3zNXYrdCt1kP8WsiGW8Oo2ka41b2IM5JL/S1CTyX1rW0mb/zSuJun0ZUrDxx4sqvYS2FWzPA==} + + tinybench@2.9.0: + resolution: {integrity: sha512-0+DUvqWMValLmha6lr4kD8iAMK1HzV0/aKnCtWb9v9641TnP/MFb7Pc2bxoxQjTXAErryXVgUOfv2YqNllqGeg==} + + tinyexec@0.3.2: + resolution: {integrity: sha512-KQQR9yN7R5+OSwaK0XQoj22pwHoTlgYqmUscPYoknOoWCWfj/5/ABTMRi69FrKU5ffPVh5QcFikpWJI/P1ocHA==} + + tinyglobby@0.2.14: + resolution: {integrity: sha512-tX5e7OM1HnYr2+a2C/4V0htOcSQcoSTH9KgJnVvNm5zm/cyEWKJ7j7YutsH9CxMdtOkkLFy2AHrMci9IM8IPZQ==} + engines: {node: '>=12.0.0'} + + tinypool@1.1.1: + resolution: {integrity: sha512-Zba82s87IFq9A9XmjiX5uZA/ARWDrB03OHlq+Vw1fSdt0I+4/Kutwy8BP4Y/y/aORMo61FQ0vIb5j44vSo5Pkg==} + engines: {node: ^18.0.0 || >=20.0.0} + + tinyrainbow@2.0.0: + resolution: {integrity: sha512-op4nsTR47R6p0vMUUoYl/a+ljLFVtlfaXkLQmqfLR1qHma1h/ysYk4hEXZ880bf2CYgTskvTa/e196Vd5dDQXw==} + engines: {node: '>=14.0.0'} + + tinyspy@4.0.3: + resolution: {integrity: sha512-t2T/WLB2WRgZ9EpE4jgPJ9w+i66UZfDc8wHh0xrwiRNN+UwH98GIJkTeZqX9rg0i0ptwzqW+uYeIF0T4F8LR7A==} + engines: {node: '>=14.0.0'} + + tldts-core@6.1.86: + resolution: {integrity: sha512-Je6p7pkk+KMzMv2XXKmAE3McmolOQFdxkKw0R8EYNr7sELW46JqnNeTX8ybPiQgvg1ymCoF8LXs5fzFaZvJPTA==} + + tldts@6.1.86: + resolution: {integrity: sha512-WMi/OQ2axVTf/ykqCQgXiIct+mSQDFdH2fkwhPwgEwvJ1kSzZRiinb0zF2Xb8u4+OqPChmyI6MEu4EezNJz+FQ==} + hasBin: true + + tmp@0.2.3: + resolution: {integrity: sha512-nZD7m9iCPC5g0pYmcaxogYKggSfLsdxl8of3Q/oIbqCqLLIO9IAF0GWjX1z9NZRHPiXv8Wex4yDCaZsgEw0Y8w==} + engines: {node: '>=14.14'} + + to-regex-range@5.0.1: + resolution: {integrity: sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==} + engines: {node: '>=8.0'} + + totalist@3.0.1: + resolution: {integrity: 
sha512-sf4i37nQ2LBx4m3wB74y+ubopq6W/dIzXg0FDGjsYnZHVa1Da8FH853wlL2gtUhg+xJXjfk3kUZS3BRoQeoQBQ==} + engines: {node: '>=6'} + + tough-cookie@5.1.2: + resolution: {integrity: sha512-FVDYdxtnj0G6Qm/DhNPSb8Ju59ULcup3tuJxkFb5K8Bv2pUXILbf0xZWU8PX8Ov19OXljbUyveOFwRMwkXzO+A==} + engines: {node: '>=16'} + + tr46@0.0.3: + resolution: {integrity: sha512-N3WMsuqV66lT30CrXNbEjx4GEwlow3v6rr4mCcv6prnfwhS01rkgyFdjPNBYd9br7LpXV1+Emh01fHnq2Gdgrw==} + + tr46@5.1.1: + resolution: {integrity: sha512-hdF5ZgjTqgAntKkklYw0R03MG2x/bSzTtkxmIRw/sTNV8YXsCJ1tfLAX23lhxhHJlEf3CRCOCGGWw3vI3GaSPw==} + engines: {node: '>=18'} + + trim-lines@3.0.1: + resolution: {integrity: sha512-kRj8B+YHZCc9kQYdWfJB2/oUl9rA99qbowYYBtr4ui4mZyAQ2JpvVBd/6U2YloATfqBhBTSMhTpgBHtU0Mf3Rg==} + + trough@2.2.0: + resolution: {integrity: sha512-tmMpK00BjZiUyVyvrBK7knerNgmgvcV/KLVyuma/SC+TQN167GrMRciANTz09+k3zW8L8t60jWO1GpfkZdjTaw==} + + ts-api-utils@2.1.0: + resolution: {integrity: sha512-CUgTZL1irw8u29bzrOD/nH85jqyc74D6SshFgujOIA7osm2Rz7dYH77agkx7H4FBNxDq7Cjf+IjaX/8zwFW+ZQ==} + engines: {node: '>=18.12'} + peerDependencies: + typescript: '>=4.8.4' + + ts-dedent@2.2.0: + resolution: {integrity: sha512-q5W7tVM71e2xjHZTlgfTDoPF/SmqKG5hddq9SzR49CH2hayqRKJtQ4mtRlSxKaJlR/+9rEM+mnBHf7I2/BQcpQ==} + engines: {node: '>=6.10'} + + ts-interface-checker@0.1.13: + resolution: {integrity: sha512-Y/arvbn+rrz3JCKl9C4kVNfTfSm2/mEp5FSz5EsZSANGPSlQrpRI5M4PKF+mJnE52jOO90PnPSc3Ur3bTQw0gA==} + + ts-loader@9.5.2: + resolution: {integrity: sha512-Qo4piXvOTWcMGIgRiuFa6nHNm+54HbYaZCKqc9eeZCLRy3XqafQgwX2F7mofrbJG3g7EEb+lkiR+z2Lic2s3Zw==} + engines: {node: '>=12.0.0'} + peerDependencies: + typescript: '*' + webpack: ^5.0.0 + + ts-toolbelt@9.6.0: + resolution: {integrity: sha512-nsZd8ZeNUzukXPlJmTBwUAuABDe/9qtVDelJeT/qW0ow3ZS3BsQJtNkan1802aM9Uf68/Y8ljw86Hu0h5IUW3w==} + + tsconfck@2.1.2: + resolution: {integrity: sha512-ghqN1b0puy3MhhviwO2kGF8SeMDNhEbnKxjK7h6+fvY9JAxqvXi8y5NAHSQv687OVboS2uZIByzGd45/YxrRHg==} + engines: {node: ^14.13.1 || ^16 || >=18} + 
hasBin: true + peerDependencies: + typescript: ^4.3.5 || ^5.0.0 + peerDependenciesMeta: + typescript: + optional: true + + tsconfig-paths@4.2.0: + resolution: {integrity: sha512-NoZ4roiN7LnbKn9QqE1amc9DJfzvZXxF4xDavcOWt1BPkdx+m+0gJuPM+S0vCe7zTJMYUP0R8pO2XMr+Y8oLIg==} + engines: {node: '>=6'} + + tslib@1.14.1: + resolution: {integrity: sha512-Xni35NKzjgMrwevysHTCArtLDpPvye8zV/0E4EyYn43P7/7qvQwPh9BGkHewbMulVntbigmcT7rdX3BNo9wRJg==} + + tslib@2.8.1: + resolution: {integrity: sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==} + + tsx@4.20.3: + resolution: {integrity: sha512-qjbnuR9Tr+FJOMBqJCW5ehvIo/buZq7vH7qD7JziU98h6l3qGy0a/yPFjwO+y0/T7GFpNgNAvEcPPVfyT8rrPQ==} + engines: {node: '>=18.0.0'} + hasBin: true + + tunnel-agent@0.6.0: + resolution: {integrity: sha512-McnNiV1l8RYeY8tBgEpuodCC1mLUdbSN+CYBL7kJsJNInOP8UjDDEwdk6Mw60vdLLrr5NHKZhMAOSrR2NZuQ+w==} + + tunnel@0.0.6: + resolution: {integrity: sha512-1h/Lnq9yajKY2PEbBadPXj3VxsDDu844OnaAo52UVmIzIvwwtBPIuNvkjuzBlTWpfJyUbG3ez0KSBibQkj4ojg==} + engines: {node: '>=0.6.11 <=0.7.0 || >=0.7.3'} + + type-check@0.4.0: + resolution: {integrity: sha512-XleUoc9uwGXqjWwXaUTZAmzMcFZ5858QA2vvx1Ur5xIcixXIP+8LnFDgRplU30us6teqdlskFfu+ae4K79Ooew==} + engines: {node: '>= 0.8.0'} + + type-fest@4.41.0: + resolution: {integrity: sha512-TeTSQ6H5YHvpqVwBRcnLDCBnDOHWYu7IvGbHT6N8AOymcr9PJGjc1GTtiWZTYg0NCgYwvnYWEkVChQAr9bjfwA==} + engines: {node: '>=16'} + + typed-array-buffer@1.0.3: + resolution: {integrity: sha512-nAYYwfY3qnzX30IkA6AQZjVbtK6duGontcQm1WSG1MD94YLqK0515GNApXkoxKOWMusVssAHWLh9SeaoefYFGw==} + engines: {node: '>= 0.4'} + + typed-array-byte-length@1.0.3: + resolution: {integrity: sha512-BaXgOuIxz8n8pIq3e7Atg/7s+DpiYrxn4vdot3w9KbnBhcRQq6o3xemQdIfynqSeXeDrF32x+WvfzmOjPiY9lg==} + engines: {node: '>= 0.4'} + + typed-array-byte-offset@1.0.4: + resolution: {integrity: sha512-bTlAFB/FBYMcuX81gbL4OcpH5PmlFHqlCCpAl8AlEzMz5k53oNDvN8p1PNOWLEmI2x4orp3raOFB51tv9X+MFQ==} + engines: {node: '>= 0.4'} + + 
typed-array-length@1.0.7: + resolution: {integrity: sha512-3KS2b+kL7fsuk/eJZ7EQdnEmQoaho/r6KUef7hxvltNA5DR8NAUM+8wJMbJyZ4G9/7i3v5zPBIMN5aybAh2/Jg==} + engines: {node: '>= 0.4'} + + typed-rest-client@1.8.11: + resolution: {integrity: sha512-5UvfMpd1oelmUPRbbaVnq+rHP7ng2cE4qoQkQeAqxRL6PklkxsM0g32/HL0yfvruK6ojQ5x8EE+HF4YV6DtuCA==} + + typedoc-plugin-markdown@4.7.1: + resolution: {integrity: sha512-HN/fHLm2S6MD4HX8txfB4eWvVBzX/mEYy5U5s1KTAdh3E5uX5/lilswqTzZlPTT6fNZInAboAdFGpbAuBKnE4A==} + engines: {node: '>= 18'} + peerDependencies: + typedoc: 0.28.x + + typedoc@0.28.7: + resolution: {integrity: sha512-lpz0Oxl6aidFkmS90VQDQjk/Qf2iw0IUvFqirdONBdj7jPSN9mGXhy66BcGNDxx5ZMyKKiBVAREvPEzT6Uxipw==} + engines: {node: '>= 18', pnpm: '>= 10'} + hasBin: true + peerDependencies: + typescript: 5.0.x || 5.1.x || 5.2.x || 5.3.x || 5.4.x || 5.5.x || 5.6.x || 5.7.x || 5.8.x + + typescript-eslint@8.38.0: + resolution: {integrity: sha512-FsZlrYK6bPDGoLeZRuvx2v6qrM03I0U0SnfCLPs/XCCPCFD80xU9Pg09H/K+XFa68uJuZo7l/Xhs+eDRg2l3hg==} + engines: {node: ^18.18.0 || ^20.9.0 || >=21.1.0} + peerDependencies: + eslint: ^8.57.0 || ^9.0.0 + typescript: '>=4.8.4 <5.9.0' + + typescript@5.8.2: + resolution: {integrity: sha512-aJn6wq13/afZp/jT9QZmwEjDqqvSGp1VT5GVg+f/t6/oVyrgXM6BY1h9BRh/O5p3PlUPAe+WuiEZOmb/49RqoQ==} + engines: {node: '>=14.17'} + hasBin: true + + typescript@5.8.3: + resolution: {integrity: sha512-p1diW6TqL9L07nNxvRMM7hMMw4c5XOo/1ibL4aAIGmSAt9slTE1Xgw5KWuof2uTOvCg9BY7ZRi+GaF+7sfgPeQ==} + engines: {node: '>=14.17'} + hasBin: true + + typical@7.3.0: + resolution: {integrity: sha512-ya4mg/30vm+DOWfBg4YK3j2WD6TWtRkCbasOJr40CseYENzCUby/7rIvXA99JGsQHeNxLbnXdyLLxKSv3tauFw==} + engines: {node: '>=12.17'} + + uc.micro@2.1.0: + resolution: {integrity: sha512-ARDJmphmdvUk6Glw7y9DQ2bFkKBHwQHLi2lsaH6PPmz/Ka9sFOBsBluozhDltWmnv9u/cF6Rt87znRTPV+yp/A==} + + ufo@1.6.1: + resolution: {integrity: sha512-9a4/uxlTWJ4+a5i0ooc1rU7C7YOw3wT+UGqdeNNHWnOF9qcMBgLRS+4IYUqbczewFx4mLEig6gawh7X6mFlEkA==} + + 
unbox-primitive@1.1.0: + resolution: {integrity: sha512-nWJ91DjeOkej/TA8pXQ3myruKpKEYgqvpw9lz4OPHj/NWFNluYrjbz9j01CJ8yKQd2g4jFoOkINCTW2I5LEEyw==} + engines: {node: '>= 0.4'} + + underscore@1.13.7: + resolution: {integrity: sha512-GMXzWtsc57XAtguZgaQViUOzs0KTkk8ojr3/xAxXLITqf/3EMwxC0inyETfDFjH/Krbhuep0HNbbjI9i/q3F3g==} + + undici-types@5.26.5: + resolution: {integrity: sha512-JlCMO+ehdEIKqlFxk6IfVoAUVmgz7cU7zD/h9XZ0qzeosSHmUJVOzSQvvYSYWXkFXC+IfLKSIffhv0sVZup6pA==} + + undici-types@6.21.0: + resolution: {integrity: sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==} + + undici-types@7.8.0: + resolution: {integrity: sha512-9UJ2xGDvQ43tYyVMpuHlsgApydB8ZKfVYTsLDhXkFL/6gfkp+U8xTGdh8pMJv1SpZna0zxG1DwsKZsreLbXBxw==} + + undici@7.12.0: + resolution: {integrity: sha512-GrKEsc3ughskmGA9jevVlIOPMiiAHJ4OFUtaAH+NhfTUSiZ1wMPIQqQvAJUrJspFXJt3EBWgpAeoHEDVT1IBug==} + engines: {node: '>=20.18.1'} + + unicorn-magic@0.1.0: + resolution: {integrity: sha512-lRfVq8fE8gz6QMBuDM6a+LO3IAzTi05H6gCVaUpir2E1Rwpo4ZUog45KpNXKC/Mn3Yb9UDuHumeFTo9iV/D9FQ==} + engines: {node: '>=18'} + + unicorn-magic@0.3.0: + resolution: {integrity: sha512-+QBBXBCvifc56fsbuxZQ6Sic3wqqc3WWaqxs58gvJrcOuN83HGTCwz3oS5phzU9LthRNE9VrJCFCLUgHeeFnfA==} + engines: {node: '>=18'} + + unified@11.0.5: + resolution: {integrity: sha512-xKvGhPWw3k84Qjh8bI3ZeJjqnyadK+GEFtazSfZv/rKeTkTjOJho6mFqh2SM96iIcZokxiOpg78GazTSg8+KHA==} + + unist-util-is@6.0.0: + resolution: {integrity: sha512-2qCTHimwdxLfz+YzdGfkqNlH0tLi9xjTnHddPmJwtIG9MGsdbutfTc4P+haPD7l7Cjxf/WZj+we5qfVPvvxfYw==} + + unist-util-position@5.0.0: + resolution: {integrity: sha512-fucsC7HjXvkB5R3kTCO7kUjRdrS0BJt3M/FPxmHMBOm8JQi2BsHAHFsy27E0EolP8rp0NzXsJ+jNPyDWvOJZPA==} + + unist-util-stringify-position@4.0.0: + resolution: {integrity: sha512-0ASV06AAoKCDkS2+xw5RXJywruurpbC4JZSm7nr7MOt1ojAzvyyaO+UxZf18j8FCF6kmzCZKcAgN/yu2gm2XgQ==} + + unist-util-visit-parents@6.0.1: + resolution: {integrity: 
sha512-L/PqWzfTP9lzzEa6CKs0k2nARxTdZduw3zyh8d2NVBnsyvHjSX4TWse388YrrQKbvI8w20fGjGlhgT96WwKykw==} + + unist-util-visit@5.0.0: + resolution: {integrity: sha512-MR04uvD+07cwl/yhVuVWAtw+3GOR/knlL55Nd/wAdblk27GCVt3lqpTivy/tkJcZoNPzTwS1Y+KMojlLDhoTzg==} + + universalify@2.0.1: + resolution: {integrity: sha512-gptHNQghINnc/vTGIk0SOFGFNXw7JVrlRUtConJRlvaw6DuX0wO5Jeko9sWrMBhh+PsYAZ7oXAiOnf/UKogyiw==} + engines: {node: '>= 10.0.0'} + + unplugin@1.16.1: + resolution: {integrity: sha512-4/u/j4FrCKdi17jaxuJA0jClGxB1AvU2hw/IuayPc4ay1XGaJs/rbb4v5WKwAjNifjmXK9PIFyuPiaK8azyR9w==} + engines: {node: '>=14.0.0'} + + unplugin@2.3.5: + resolution: {integrity: sha512-RyWSb5AHmGtjjNQ6gIlA67sHOsWpsbWpwDokLwTcejVdOjEkJZh7QKu14J00gDDVSh8kGH4KYC/TNBceXFZhtw==} + engines: {node: '>=18.12.0'} + + update-browserslist-db@1.1.3: + resolution: {integrity: sha512-UxhIZQ+QInVdunkDAaiazvvT/+fXL5Osr0JZlJulepYu6Jd7qJtDZjlur0emRlT71EN3ScPoE7gvsuIKKNavKw==} + hasBin: true + peerDependencies: + browserslist: '>= 4.21.0' + + uri-js@4.4.1: + resolution: {integrity: sha512-7rKUyy33Q1yc98pQ1DAmLtwX109F7TIfWlW1Ydo8Wl1ii1SeHieeh0HHfPeL2fMXK6z0s8ecKs9frCuLJvndBg==} + + urijs@1.19.11: + resolution: {integrity: sha512-HXgFDgDommxn5/bIv0cnQZsPhHDA90NPHD6+c/v21U5+Sx5hoP8+dP9IZXBU1gIfvdRfhG8cel9QNPeionfcCQ==} + + url-join@4.0.1: + resolution: {integrity: sha512-jk1+QP6ZJqyOiuEI9AEWQfju/nB2Pw466kbA0LEZljHwKeMgd9WrAEgEGxjPDD2+TNbbb37rTyhEfrCXfuKXnA==} + + use-callback-ref@1.3.3: + resolution: {integrity: sha512-jQL3lRnocaFtu3V00JToYz/4QkNWswxijDaCVNZRiRTO3HQDLsdu1ZtmIUvV4yPp+rvWm5j0y0TG/S61cuijTg==} + engines: {node: '>=10'} + peerDependencies: + '@types/react': '*' + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + use-sidecar@1.1.3: + resolution: {integrity: sha512-Fedw0aZvkhynoPYlA5WXrMCAMm+nSWdZt6lzJQ7Ok8S6Q+VsHmHpRWndVRJ8Be0ZbkfPc5LRYH+5XrzXcEeLRQ==} + engines: {node: '>=10'} + peerDependencies: + '@types/react': '*' + react: ^16.8.0 || 
^17.0.0 || ^18.0.0 || ^19.0.0 || ^19.0.0-rc + peerDependenciesMeta: + '@types/react': + optional: true + + use-sync-external-store@1.5.0: + resolution: {integrity: sha512-Rb46I4cGGVBmjamjphe8L/UnvJD+uPPtTkNvX5mZgqdbavhI4EbgIWJiIHXJ8bc/i9EQGPRh4DwEURJ552Do0A==} + peerDependencies: + react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 + + util-deprecate@1.0.2: + resolution: {integrity: sha512-EPD5q1uXyFxJpCrLnCc1nHnq3gOa6DZBocAIiI2TaSCA7VCJ1UJDMagCzIkXNsUYfD1daK//LTEQ8xiIbrHtcw==} + + utility-types@3.11.0: + resolution: {integrity: sha512-6Z7Ma2aVEWisaL6TvBCy7P8rm2LQoPv6dJ7ecIaIixHcwfbJ0x7mWdbcwlIM5IGQxPZSFYeqRCqlOOeKoJYMkw==} + engines: {node: '>= 4'} + + uuid@8.3.2: + resolution: {integrity: sha512-+NYs2QeMWy+GWFOEm9xnn6HCDp0l7QBD7ml8zLUmJ+93Q5NF0NocErnwkTkXVFNiX3/fpC6afS8Dhb/gz7R7eg==} + hasBin: true + + v8-to-istanbul@9.3.0: + resolution: {integrity: sha512-kiGUalWN+rgBJ/1OHZsBtU4rXZOfj/7rKQxULKlIzwzQSvMJUUNgPwJEEh7gU6xEVxC0ahoOBvN2YI8GH6FNgA==} + engines: {node: '>=10.12.0'} + + validate-npm-package-license@3.0.4: + resolution: {integrity: sha512-DpKm2Ui/xN7/HQKCtpZxoRWBhZ9Z0kqtygG8XCgNQ8ZlDnxuQmWhj566j8fN4Cu3/JmbhsDo7fcAJq4s9h27Ew==} + + validate-npm-package-name@6.0.2: + resolution: {integrity: sha512-IUoow1YUtvoBBC06dXs8bR8B9vuA3aJfmQNKMoaPG/OFsPmoQvw8xh+6Ye25Gx9DQhoEom3Pcu9MKHerm/NpUQ==} + engines: {node: ^18.17.0 || >=20.5.0} + + validator@13.15.15: + resolution: {integrity: sha512-BgWVbCI72aIQy937xbawcs+hrVaN/CZ2UwutgaJ36hGqRrLNM+f5LUT/YPRbo8IV/ASeFzXszezV+y2+rq3l8A==} + engines: {node: '>= 0.10'} + + version-range@4.14.0: + resolution: {integrity: sha512-gjb0ARm9qlcBAonU4zPwkl9ecKkas+tC2CGwFfptTCWWIVTWY1YUbT2zZKsOAF1jR/tNxxyLwwG0cb42XlYcTg==} + engines: {node: '>=4'} + + vfile-message@4.0.2: + resolution: {integrity: sha512-jRDZ1IMLttGj41KcZvlrYAaI3CfqpLpfpf+Mfig13viT6NKvRzWZ+lXz0Y5D60w6uJIBAOGq9mSHf0gktF0duw==} + + vfile@6.0.3: + resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==} + + 
vite-node@3.2.4: + resolution: {integrity: sha512-EbKSKh+bh1E1IFxeO0pg1n4dvoOTt0UDiXMd/qn++r98+jPO1xtJilvXldeuQ8giIB5IkpjCgMleHMNEsGH6pg==} + engines: {node: ^18.0.0 || ^20.0.0 || >=22.0.0} + hasBin: true + + vite-plugin-css-injected-by-js@3.5.2: + resolution: {integrity: sha512-2MpU/Y+SCZyWUB6ua3HbJCrgnF0KACAsmzOQt1UvRVJCGF6S8xdA3ZUhWcWdM9ivG4I5az8PnQmwwrkC2CAQrQ==} + peerDependencies: + vite: '>2.0.0-0' + + vite-plugin-dts@4.5.4: + resolution: {integrity: sha512-d4sOM8M/8z7vRXHHq/ebbblfaxENjogAAekcfcDCCwAyvGqnPrc7f4NZbvItS+g4WTgerW0xDwSz5qz11JT3vg==} + peerDependencies: + typescript: '*' + vite: '*' + peerDependenciesMeta: + vite: + optional: true + + vite-plugin-static-copy@3.1.1: + resolution: {integrity: sha512-oR53SkL5cX4KT1t18E/xU50vJDo0N8oaHza4EMk0Fm+2/u6nQivxavOfrDk3udWj+dizRizB/QnBvJOOQrTTAQ==} + engines: {node: ^18.0.0 || >=20.0.0} + peerDependencies: + vite: ^5.0.0 || ^6.0.0 || ^7.0.0 + + vite@6.3.5: + resolution: {integrity: sha512-cZn6NDFE7wdTpINgs++ZJ4N49W2vRp8LCKrn3Ob1kYNtOo21vfDoaV5GzBfLU4MovSAB8uNRm4jgzVQZ+mBzPQ==} + engines: {node: ^18.0.0 || ^20.0.0 || >=22.0.0} + hasBin: true + peerDependencies: + '@types/node': ^18.0.0 || ^20.0.0 || >=22.0.0 + jiti: '>=1.21.0' + less: '*' + lightningcss: ^1.21.0 + sass: '*' + sass-embedded: '*' + stylus: '*' + sugarss: '*' + terser: ^5.16.0 + tsx: ^4.8.1 + yaml: ^2.4.2 + peerDependenciesMeta: + '@types/node': + optional: true + jiti: + optional: true + less: + optional: true + lightningcss: + optional: true + sass: + optional: true + sass-embedded: + optional: true + stylus: + optional: true + sugarss: + optional: true + terser: + optional: true + tsx: + optional: true + yaml: + optional: true + + vitest@3.2.4: + resolution: {integrity: sha512-LUCP5ev3GURDysTWiP47wRRUpLKMOfPh+yKTx3kVIEiu5KOMeqzpnYNsKyOoVrULivR8tLcks4+lga33Whn90A==} + engines: {node: ^18.0.0 || ^20.0.0 || >=22.0.0} + hasBin: true + peerDependencies: + '@edge-runtime/vm': '*' + '@types/debug': ^4.1.12 + '@types/node': ^18.0.0 || ^20.0.0 || 
>=22.0.0 + '@vitest/browser': 3.2.4 + '@vitest/ui': 3.2.4 + happy-dom: '*' + jsdom: '*' + peerDependenciesMeta: + '@edge-runtime/vm': + optional: true + '@types/debug': + optional: true + '@types/node': + optional: true + '@vitest/browser': + optional: true + '@vitest/ui': + optional: true + happy-dom: + optional: true + jsdom: + optional: true + + vscode-jsonrpc@8.2.0: + resolution: {integrity: sha512-C+r0eKJUIfiDIfwJhria30+TYWPtuHJXHtI7J0YlOmKAo7ogxP20T0zxB7HZQIFhIyvoBPwWskjxrvAtfjyZfA==} + engines: {node: '>=14.0.0'} + + vscode-jsonrpc@8.2.1: + resolution: {integrity: sha512-kdjOSJ2lLIn7r1rtrMbbNCHjyMPfRnowdKjBQ+mGq6NAW5QY2bEZC/khaC5OR8svbbjvLEaIXkOq45e2X9BIbQ==} + engines: {node: '>=14.0.0'} + + vscode-languageclient@9.0.1: + resolution: {integrity: sha512-JZiimVdvimEuHh5olxhxkht09m3JzUGwggb5eRUkzzJhZ2KjCN0nh55VfiED9oez9DyF8/fz1g1iBV3h+0Z2EA==} + engines: {vscode: ^1.82.0} + + vscode-languageserver-protocol@3.17.5: + resolution: {integrity: sha512-mb1bvRJN8SVznADSGWM9u/b07H7Ecg0I3OgXDuLdn307rl/J3A9YD6/eYOssqhecL27hK1IPZAsaqh00i/Jljg==} + + vscode-languageserver-types@3.17.5: + resolution: {integrity: sha512-Ld1VelNuX9pdF39h2Hgaeb5hEZM2Z3jUrrMgWQAu82jMtZp7p3vJT3BzToKtZI7NgQssZje5o0zryOrhQvzQAg==} + + vscode-uri@3.1.0: + resolution: {integrity: sha512-/BpdSx+yCQGnCvecbyXdxHDkuk55/G3xwnC0GqY4gmQ3j+A+g8kzzgB4Nk/SINjqn6+waqw3EgbVF2QKExkRxQ==} + + w3c-keyname@2.2.8: + resolution: {integrity: sha512-dpojBhNsCNN7T82Tm7k26A6G9ML3NkhDsnw9n/eoxSRlVBB4CEtIQ/KTCLI2Fwf3ataSXRhYFkQi3SlnFwPvPQ==} + + w3c-xmlserializer@5.0.0: + resolution: {integrity: sha512-o8qghlI8NZHU1lLPrpi2+Uq7abh4GGPpYANlalzWxyWteJOCsr/P+oPBA49TOLu5FTZO4d3F9MnWJfiMo4BkmA==} + engines: {node: '>=18'} + + watchpack@2.4.4: + resolution: {integrity: sha512-c5EGNOiyxxV5qmTtAB7rbiXxi1ooX1pQKMLX/MIabJjRA0SJBQOjKF+KSVfHkr9U1cADPon0mRiVe/riyaiDUA==} + engines: {node: '>=10.13.0'} + + web-vitals@4.2.4: + resolution: {integrity: 
sha512-r4DIlprAGwJ7YM11VZp4R884m0Vmgr6EAKe3P+kO0PPj3Unqyvv59rczf6UiGcb9Z8QxZVcqKNwv/g0WNdWwsw==} + + webidl-conversions@3.0.1: + resolution: {integrity: sha512-2JAn3z8AR6rjK8Sm8orRC0h/bcl/DqL7tRPdGZ4I1CjdF+EaMLmYxBHyXuKL849eucPFhvBoxMsflfOb8kxaeQ==} + + webidl-conversions@7.0.0: + resolution: {integrity: sha512-VwddBukDzu71offAQR975unBIGqfKZpM+8ZX6ySk8nYhVoo5CYaZyzt3YBvYtRtO+aoGlqxPg/B87NGVZ/fu6g==} + engines: {node: '>=12'} + + webpack-sources@3.3.3: + resolution: {integrity: sha512-yd1RBzSGanHkitROoPFd6qsrxt+oFhg/129YzheDGqeustzX0vTZJZsSsQjVQC4yzBQ56K55XU8gaNCtIzOnTg==} + engines: {node: '>=10.13.0'} + + webpack-virtual-modules@0.6.2: + resolution: {integrity: sha512-66/V2i5hQanC51vBQKPH4aI8NMAcBW59FVBs+rC7eGHupMyfn34q7rZIE+ETlJ+XTevqfUhVVBgSUNSW2flEUQ==} + + webpack@5.99.8: + resolution: {integrity: sha512-lQ3CPiSTpfOnrEGeXDwoq5hIGzSjmwD72GdfVzF7CQAI7t47rJG9eDWvcEkEn3CUQymAElVvDg3YNTlCYj+qUQ==} + engines: {node: '>=10.13.0'} + hasBin: true + peerDependencies: + webpack-cli: '*' + peerDependenciesMeta: + webpack-cli: + optional: true + + whatwg-encoding@3.1.1: + resolution: {integrity: sha512-6qN4hJdMwfYBtE3YBTTHhoeuUrDBPZmbQaxWAqSALV/MeEnR5z1xd8UKud2RAkFoPkmB+hli1TZSnyi84xz1vQ==} + engines: {node: '>=18'} + + whatwg-mimetype@4.0.0: + resolution: {integrity: sha512-QaKxh0eNIi2mE9p2vEdzfagOKHCcj1pJ56EEHGQOVxp8r9/iszLUUV7v89x9O1p/T+NlTM5W7jW6+cz4Fq1YVg==} + engines: {node: '>=18'} + + whatwg-url@14.2.0: + resolution: {integrity: sha512-De72GdQZzNTUBBChsXueQUnPKDkg/5A5zp7pFDuQAj5UFoENpiACU0wlCvzpAGnTkj++ihpKwKyYewn/XNUbKw==} + engines: {node: '>=18'} + + whatwg-url@5.0.0: + resolution: {integrity: sha512-saE57nupxk6v3HY35+jzBwYa0rKSy0XR8JSxZPwgLr7ys0IBzhGviA1/TUGJLmSVqs8pb9AnvICXEuOHLprYTw==} + + which-boxed-primitive@1.1.1: + resolution: {integrity: sha512-TbX3mj8n0odCBFVlY8AxkqcHASw3L60jIuF8jFP78az3C2YhmGvqbHBpAjTRH2/xqYunrJ9g1jSyjCjpoWzIAA==} + engines: {node: '>= 0.4'} + + which-builtin-type@1.2.1: + resolution: {integrity: 
sha512-6iBczoX+kDQ7a3+YJBnh3T+KZRxM/iYNPXicqk66/Qfm1b93iu+yOImkg0zHbj5LNOcNv1TEADiZ0xa34B4q6Q==} + engines: {node: '>= 0.4'} + + which-collection@1.0.2: + resolution: {integrity: sha512-K4jVyjnBdgvc86Y6BkaLZEN933SwYOuBFkdmBu9ZfkcAbdVbpITnDmjvZ/aQjRXQrv5EPkTnD1s39GiiqbngCw==} + engines: {node: '>= 0.4'} + + which-typed-array@1.1.19: + resolution: {integrity: sha512-rEvr90Bck4WZt9HHFC4DJMsjvu7x+r6bImz0/BrbWb7A2djJ8hnZMrWnHo9F8ssv0OMErasDhftrfROTyqSDrw==} + engines: {node: '>= 0.4'} + + which@2.0.2: + resolution: {integrity: sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==} + engines: {node: '>= 8'} + hasBin: true + + why-is-node-running@2.3.0: + resolution: {integrity: sha512-hUrmaWBdVDcxvYqnyh09zunKzROWjbZTiNy8dBEjkS7ehEDQibXJ7XvlmtbwuTclUiIyN+CyXQD4Vmko8fNm8w==} + engines: {node: '>=8'} + hasBin: true + + word-wrap@1.2.5: + resolution: {integrity: sha512-BN22B5eaMMI9UMtjrGd5g5eCYPpCPDUy0FJXbYsaT5zYxjFOckS53SQDE3pWkVoWpHXVb3BrYcEN4Twa55B5cA==} + engines: {node: '>=0.10.0'} + + wordwrapjs@5.1.0: + resolution: {integrity: sha512-JNjcULU2e4KJwUNv6CHgI46UvDGitb6dGryHajXTDiLgg1/RiGoPSDw4kZfYnwGtEXf2ZMeIewDQgFGzkCB2Sg==} + engines: {node: '>=12.17'} + + workerpool@6.5.1: + resolution: {integrity: sha512-Fs4dNYcsdpYSAfVxhnl1L5zTksjvOJxtC5hzMNl+1t9B8hTJTdKDyZ5ju7ztgPy+ft9tBFXoOlDNiOT9WUXZlA==} + + wrap-ansi@7.0.0: + resolution: {integrity: sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==} + engines: {node: '>=10'} + + wrap-ansi@8.1.0: + resolution: {integrity: sha512-si7QWI6zUMq56bESFvagtmzMdGOtoxfR+Sez11Mobfc7tm+VkUckk9bW2UeffTGVUbOksxmSw0AA2gs8g71NCQ==} + engines: {node: '>=12'} + + wrappy@1.0.2: + resolution: {integrity: sha512-l4Sp/DRseor9wL6EvV2+TuQn63dMkPjZ/sp9XkghTEbV9KlPS1xUsZ3u7/IQO4wxtcFB4bgpQPRcR3QCvezPcQ==} + + ws@8.18.3: + resolution: {integrity: sha512-PEIGCY5tSlUt50cqyMXfCzX+oOPqN0vuGqWzbcJ2xvnkzkq46oOpz7dQaTDBdfICb4N14+GARUDw2XV2N4tvzg==} + engines: {node: 
'>=10.0.0'} + peerDependencies: + bufferutil: ^4.0.1 + utf-8-validate: '>=5.0.2' + peerDependenciesMeta: + bufferutil: + optional: true + utf-8-validate: + optional: true + + wsl-utils@0.1.0: + resolution: {integrity: sha512-h3Fbisa2nKGPxCpm89Hk33lBLsnaGBvctQopaBSOW/uIs6FTe1ATyAnKFJrzVs9vpGdsTe73WF3V4lIsk4Gacw==} + engines: {node: '>=18'} + + xml-name-validator@5.0.0: + resolution: {integrity: sha512-EvGK8EJ3DhaHfbRlETOWAS5pO9MZITeauHKJyb8wyajUfQUenkIg2MvLDTZ4T/TgIcm3HU0TFBgWWboAZ30UHg==} + engines: {node: '>=18'} + + xml2js@0.5.0: + resolution: {integrity: sha512-drPFnkQJik/O+uPKpqSgr22mpuFHqKdbS835iAQrUC73L2F5WkboIRd63ai/2Yg6I1jzifPFKH2NTK+cfglkIA==} + engines: {node: '>=4.0.0'} + + xmlbuilder@11.0.1: + resolution: {integrity: sha512-fDlsI/kFEx7gLvbecc0/ohLG50fugQp8ryHzMTuW9vSa1GJ0XYWKnhsUx7oie3G98+r56aTQIUB4kht42R3JvA==} + engines: {node: '>=4.0'} + + xmlchars@2.2.0: + resolution: {integrity: sha512-JZnDKK8B0RCDw84FNdDAIpZK+JuJw+s7Lz8nksI7SIuU3UXJJslUthsi+uWBUYOwPFwW7W7PRLRfUKpxjtjFCw==} + + y18n@5.0.8: + resolution: {integrity: sha512-0pfFzegeDWJHJIAmTLRP2DwHjdF5s7jo9tuztdQxAhINCdvS+3nGINqPd00AphqJR/0LhANUS6/+7SCb98YOfA==} + engines: {node: '>=10'} + + yallist@3.1.1: + resolution: {integrity: sha512-a4UGQaWPH59mOXUYnAG2ewncQS4i4F43Tv3JoAM+s2VDAmS9NsK8GpDMLrCHPksFT7h3K6TOoUNn2pb7RoXx4g==} + + yallist@4.0.0: + resolution: {integrity: sha512-3wdGidZyq5PB084XLES5TpOSRA3wjXAlIWMhum2kRcv/41Sn2emQ0dycQW4uZXLejwKvg6EsvbdlVL+FYEct7A==} + + yallist@5.0.0: + resolution: {integrity: sha512-YgvUTfwqyc7UXVMrB+SImsVYSmTS8X/tSrtdNZMImM+n7+QTriRXyXim0mBrTXNeqzVF0KWGgHPeiyViFFrNDw==} + engines: {node: '>=18'} + + yaml@1.10.2: + resolution: {integrity: sha512-r3vXyErRCYJ7wg28yvBY5VSoAF8ZvlcW9/BwUzEtUsjvX/DKs24dIkuwjtuprwJJHsbyUbLApepYTR1BN4uHrg==} + engines: {node: '>= 6'} + + yaml@2.8.0: + resolution: {integrity: sha512-4lLa/EcQCB0cJkyts+FpIRx5G/llPxfP6VQU5KByHEhLxY3IJCH0f0Hy1MHI8sClTvsIb8qwRJ6R/ZdlDJ/leQ==} + engines: {node: '>= 14.6'} + hasBin: true + + yargs-parser@20.2.9: + 
resolution: {integrity: sha512-y11nGElTIV+CT3Zv9t7VKl+Q3hTQoT9a1Qzezhhl6Rp21gJ/IVTW7Z3y9EWXhuUBC2Shnf+DX0antecpAwSP8w==} + engines: {node: '>=10'} + + yargs-parser@21.1.1: + resolution: {integrity: sha512-tVpsJW7DdjecAiFpbIB1e3qxIQsE6NoPc5/eTdrbbIC4h0LVsWhnoa3g+m2HclBIujHzsxZ4VJVA+GUuc2/LBw==} + engines: {node: '>=12'} + + yargs-unparser@2.0.0: + resolution: {integrity: sha512-7pRTIA9Qc1caZ0bZ6RYRGbHJthJWuakf+WmHK0rVeLkNrrGhfoabBNdue6kdINI6r4if7ocq9aD/n7xwKOdzOA==} + engines: {node: '>=10'} + + yargs@16.2.0: + resolution: {integrity: sha512-D1mvvtDG0L5ft/jGWkLpG1+m0eQxOfaBvTNELraWj22wSVUMWxZUvYgJYcKh6jGGIkJFhH4IZPQhR4TKpc8mBw==} + engines: {node: '>=10'} + + yargs@17.7.2: + resolution: {integrity: sha512-7dSzzRQ++CKnNI/krKnYRV7JKKPUXMEh61soaHKg9mrWEhzFWhFnxPxGl+69cD1Ou63C13NUPCnmIcrvqCuM6w==} + engines: {node: '>=12'} + + yauzl@2.10.0: + resolution: {integrity: sha512-p4a9I6X6nu6IhoGmBqAcbJy1mlC4j27vEPZX9F4L4/vZT3Lyq1VkFHw/V/PUcB9Buo+DG3iHkT0x3Qya58zc3g==} + + yazl@2.5.1: + resolution: {integrity: sha512-phENi2PLiHnHb6QBVot+dJnaAZ0xosj7p3fWl+znIjBDlnMI2PsZCJZ306BPTFOaHf5qdDEI8x5qFrSOBN5vrw==} + + yocto-queue@0.1.0: + resolution: {integrity: sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==} + engines: {node: '>=10'} + + yocto-queue@1.2.1: + resolution: {integrity: sha512-AyeEbWOu/TAXdxlV9wmGcR0+yh2j3vYPGOECcIj2S7MkrLyC7ne+oye2BKTItt0ii2PHk4cDy+95+LshzbXnGg==} + engines: {node: '>=12.20'} + + zod@3.25.76: + resolution: {integrity: sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==} + + zustand@4.5.7: + resolution: {integrity: sha512-CHOUy7mu3lbD6o6LJLfllpjkzhHXSBlX8B9+qPddUsIfeF5S/UZ5q0kmCsnRqT1UHFQZchNFDDzMbQsuesHWlw==} + engines: {node: '>=12.7.0'} + peerDependencies: + '@types/react': '>=16.8' + immer: '>=9.0.6' + react: '>=16.8' + peerDependenciesMeta: + '@types/react': + optional: true + immer: + optional: true + react: + optional: true + + zustand@5.0.6: + resolution: 
{integrity: sha512-ihAqNeUVhe0MAD+X8M5UzqyZ9k3FFZLBTtqo6JLPwV53cbRB/mJwBI0PxcIgqhBBHlEs8G45OTDTMq3gNcLq3A==} + engines: {node: '>=12.20.0'} + peerDependencies: + '@types/react': '>=18.0.0' + immer: '>=9.0.6' + react: '>=18.0.0' + use-sync-external-store: '>=1.2.0' + peerDependenciesMeta: + '@types/react': + optional: true + immer: + optional: true + react: + optional: true + use-sync-external-store: + optional: true + + zwitch@2.0.4: + resolution: {integrity: sha512-bXE4cR/kVZhKZX/RjPEflHaKVhUVl85noU3v6b8apfQEc1x4A+zBxjZ4lN8LqGd6WZ3dl98pY4o717VFmoPp+A==} + +snapshots: + + '@adobe/css-tools@4.4.3': {} + + '@alloc/quick-lru@5.2.0': {} + + '@ampproject/remapping@2.3.0': + dependencies: + '@jridgewell/gen-mapping': 0.3.12 + '@jridgewell/trace-mapping': 0.3.29 + + '@apidevtools/json-schema-ref-parser@11.7.2': + dependencies: + '@jsdevtools/ono': 7.1.3 + '@types/json-schema': 7.0.15 + js-yaml: 4.1.0 + + '@apidevtools/openapi-schemas@2.1.0': {} + + '@apidevtools/swagger-methods@3.0.2': {} + + '@apidevtools/swagger-parser@10.1.1(openapi-types@12.1.3)': + dependencies: + '@apidevtools/json-schema-ref-parser': 11.7.2 + '@apidevtools/openapi-schemas': 2.1.0 + '@apidevtools/swagger-methods': 3.0.2 + '@jsdevtools/ono': 7.1.3 + ajv: 8.17.1 + ajv-draft-04: 1.0.0(ajv@8.17.1) + call-me-maybe: 1.0.2 + openapi-types: 12.1.3 + + '@asamuzakjp/css-color@3.2.0': + dependencies: + '@csstools/css-calc': 2.1.4(@csstools/css-parser-algorithms@3.0.5(@csstools/css-tokenizer@3.0.4))(@csstools/css-tokenizer@3.0.4) + '@csstools/css-color-parser': 3.0.10(@csstools/css-parser-algorithms@3.0.5(@csstools/css-tokenizer@3.0.4))(@csstools/css-tokenizer@3.0.4) + '@csstools/css-parser-algorithms': 3.0.5(@csstools/css-tokenizer@3.0.4) + '@csstools/css-tokenizer': 3.0.4 + lru-cache: 10.4.3 + + '@asyncapi/specs@6.8.1': + dependencies: + '@types/json-schema': 7.0.15 + + '@azu/format-text@1.0.2': {} + + '@azu/style-format@1.0.1': + dependencies: + '@azu/format-text': 1.0.2 + + '@azure/abort-controller@2.1.2': 
+ dependencies: + tslib: 2.8.1 + + '@azure/core-auth@1.10.0': + dependencies: + '@azure/abort-controller': 2.1.2 + '@azure/core-util': 1.13.0 + tslib: 2.8.1 + transitivePeerDependencies: + - supports-color + + '@azure/core-client@1.10.0': + dependencies: + '@azure/abort-controller': 2.1.2 + '@azure/core-auth': 1.10.0 + '@azure/core-rest-pipeline': 1.22.0 + '@azure/core-tracing': 1.3.0 + '@azure/core-util': 1.13.0 + '@azure/logger': 1.3.0 + tslib: 2.8.1 + transitivePeerDependencies: + - supports-color + + '@azure/core-rest-pipeline@1.22.0': + dependencies: + '@azure/abort-controller': 2.1.2 + '@azure/core-auth': 1.10.0 + '@azure/core-tracing': 1.3.0 + '@azure/core-util': 1.13.0 + '@azure/logger': 1.3.0 + '@typespec/ts-http-runtime': 0.3.0 + tslib: 2.8.1 + transitivePeerDependencies: + - supports-color + + '@azure/core-tracing@1.3.0': + dependencies: + tslib: 2.8.1 + + '@azure/core-util@1.13.0': + dependencies: + '@azure/abort-controller': 2.1.2 + '@typespec/ts-http-runtime': 0.3.0 + tslib: 2.8.1 + transitivePeerDependencies: + - supports-color + + '@azure/identity@4.10.2': + dependencies: + '@azure/abort-controller': 2.1.2 + '@azure/core-auth': 1.10.0 + '@azure/core-client': 1.10.0 + '@azure/core-rest-pipeline': 1.22.0 + '@azure/core-tracing': 1.3.0 + '@azure/core-util': 1.13.0 + '@azure/logger': 1.3.0 + '@azure/msal-browser': 4.15.0 + '@azure/msal-node': 3.6.3 + open: 10.2.0 + tslib: 2.8.1 + transitivePeerDependencies: + - supports-color + + '@azure/logger@1.3.0': + dependencies: + '@typespec/ts-http-runtime': 0.3.0 + tslib: 2.8.1 + transitivePeerDependencies: + - supports-color + + '@azure/msal-browser@4.15.0': + dependencies: + '@azure/msal-common': 15.8.1 + + '@azure/msal-common@15.8.1': {} + + '@azure/msal-node@3.6.3': + dependencies: + '@azure/msal-common': 15.8.1 + jsonwebtoken: 9.0.2 + uuid: 8.3.2 + + '@babel/code-frame@7.27.1': + dependencies: + '@babel/helper-validator-identifier': 7.27.1 + js-tokens: 4.0.0 + picocolors: 1.1.1 + + 
'@babel/compat-data@7.28.0': {} + + '@babel/core@7.28.0': + dependencies: + '@ampproject/remapping': 2.3.0 + '@babel/code-frame': 7.27.1 + '@babel/generator': 7.28.0 + '@babel/helper-compilation-targets': 7.27.2 + '@babel/helper-module-transforms': 7.27.3(@babel/core@7.28.0) + '@babel/helpers': 7.27.6 + '@babel/parser': 7.28.0 + '@babel/template': 7.27.2 + '@babel/traverse': 7.28.0 + '@babel/types': 7.28.1 + convert-source-map: 2.0.0 + debug: 4.4.1 + gensync: 1.0.0-beta.2 + json5: 2.2.3 + semver: 6.3.1 + transitivePeerDependencies: + - supports-color + + '@babel/generator@7.28.0': + dependencies: + '@babel/parser': 7.28.0 + '@babel/types': 7.28.1 + '@jridgewell/gen-mapping': 0.3.12 + '@jridgewell/trace-mapping': 0.3.29 + jsesc: 3.1.0 + + '@babel/helper-annotate-as-pure@7.27.3': + dependencies: + '@babel/types': 7.28.1 + + '@babel/helper-compilation-targets@7.27.2': + dependencies: + '@babel/compat-data': 7.28.0 + '@babel/helper-validator-option': 7.27.1 + browserslist: 4.26.2 + lru-cache: 5.1.1 + semver: 6.3.1 + + '@babel/helper-create-class-features-plugin@7.27.1(@babel/core@7.28.0)': + dependencies: + '@babel/core': 7.28.0 + '@babel/helper-annotate-as-pure': 7.27.3 + '@babel/helper-member-expression-to-functions': 7.27.1 + '@babel/helper-optimise-call-expression': 7.27.1 + '@babel/helper-replace-supers': 7.27.1(@babel/core@7.28.0) + '@babel/helper-skip-transparent-expression-wrappers': 7.27.1 + '@babel/traverse': 7.28.0 + semver: 6.3.1 + transitivePeerDependencies: + - supports-color + + '@babel/helper-globals@7.28.0': {} + + '@babel/helper-member-expression-to-functions@7.27.1': + dependencies: + '@babel/traverse': 7.28.0 + '@babel/types': 7.28.1 + transitivePeerDependencies: + - supports-color + + '@babel/helper-module-imports@7.27.1': + dependencies: + '@babel/traverse': 7.28.0 + '@babel/types': 7.28.1 + transitivePeerDependencies: + - supports-color + + '@babel/helper-module-transforms@7.27.3(@babel/core@7.28.0)': + dependencies: + '@babel/core': 7.28.0 + 
'@babel/helper-module-imports': 7.27.1 + '@babel/helper-validator-identifier': 7.27.1 + '@babel/traverse': 7.28.0 + transitivePeerDependencies: + - supports-color + + '@babel/helper-optimise-call-expression@7.27.1': + dependencies: + '@babel/types': 7.28.1 + + '@babel/helper-plugin-utils@7.27.1': {} + + '@babel/helper-replace-supers@7.27.1(@babel/core@7.28.0)': + dependencies: + '@babel/core': 7.28.0 + '@babel/helper-member-expression-to-functions': 7.27.1 + '@babel/helper-optimise-call-expression': 7.27.1 + '@babel/traverse': 7.28.0 + transitivePeerDependencies: + - supports-color + + '@babel/helper-skip-transparent-expression-wrappers@7.27.1': + dependencies: + '@babel/traverse': 7.28.0 + '@babel/types': 7.28.1 + transitivePeerDependencies: + - supports-color + + '@babel/helper-string-parser@7.27.1': {} + + '@babel/helper-validator-identifier@7.27.1': {} + + '@babel/helper-validator-option@7.27.1': {} + + '@babel/helpers@7.27.6': + dependencies: + '@babel/template': 7.27.2 + '@babel/types': 7.28.1 + + '@babel/parser@7.28.0': + dependencies: + '@babel/types': 7.28.1 + + '@babel/plugin-syntax-jsx@7.27.1(@babel/core@7.28.0)': + dependencies: + '@babel/core': 7.28.0 + '@babel/helper-plugin-utils': 7.27.1 + + '@babel/plugin-syntax-typescript@7.27.1(@babel/core@7.28.0)': + dependencies: + '@babel/core': 7.28.0 + '@babel/helper-plugin-utils': 7.27.1 + + '@babel/plugin-transform-modules-commonjs@7.27.1(@babel/core@7.28.0)': + dependencies: + '@babel/core': 7.28.0 + '@babel/helper-module-transforms': 7.27.3(@babel/core@7.28.0) + '@babel/helper-plugin-utils': 7.27.1 + transitivePeerDependencies: + - supports-color + + '@babel/plugin-transform-react-jsx-self@7.27.1(@babel/core@7.28.0)': + dependencies: + '@babel/core': 7.28.0 + '@babel/helper-plugin-utils': 7.27.1 + + '@babel/plugin-transform-react-jsx-source@7.27.1(@babel/core@7.28.0)': + dependencies: + '@babel/core': 7.28.0 + '@babel/helper-plugin-utils': 7.27.1 + + 
'@babel/plugin-transform-typescript@7.28.0(@babel/core@7.28.0)': + dependencies: + '@babel/core': 7.28.0 + '@babel/helper-annotate-as-pure': 7.27.3 + '@babel/helper-create-class-features-plugin': 7.27.1(@babel/core@7.28.0) + '@babel/helper-plugin-utils': 7.27.1 + '@babel/helper-skip-transparent-expression-wrappers': 7.27.1 + '@babel/plugin-syntax-typescript': 7.27.1(@babel/core@7.28.0) + transitivePeerDependencies: + - supports-color + + '@babel/preset-typescript@7.27.1(@babel/core@7.28.0)': + dependencies: + '@babel/core': 7.28.0 + '@babel/helper-plugin-utils': 7.27.1 + '@babel/helper-validator-option': 7.27.1 + '@babel/plugin-syntax-jsx': 7.27.1(@babel/core@7.28.0) + '@babel/plugin-transform-modules-commonjs': 7.27.1(@babel/core@7.28.0) + '@babel/plugin-transform-typescript': 7.28.0(@babel/core@7.28.0) + transitivePeerDependencies: + - supports-color + + '@babel/runtime@7.28.2': {} + + '@babel/template@7.27.2': + dependencies: + '@babel/code-frame': 7.27.1 + '@babel/parser': 7.28.0 + '@babel/types': 7.28.1 + + '@babel/traverse@7.28.0': + dependencies: + '@babel/code-frame': 7.27.1 + '@babel/generator': 7.28.0 + '@babel/helper-globals': 7.28.0 + '@babel/parser': 7.28.0 + '@babel/template': 7.27.2 + '@babel/types': 7.28.1 + debug: 4.4.1 + transitivePeerDependencies: + - supports-color + + '@babel/types@7.28.1': + dependencies: + '@babel/helper-string-parser': 7.27.1 + '@babel/helper-validator-identifier': 7.27.1 + + '@bcoe/v8-coverage@0.2.3': {} + + '@bcoe/v8-coverage@1.0.2': {} + + '@chromatic-com/storybook@4.0.1(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2))': + dependencies: + '@neoconfetti/react': 1.0.0 + chromatic: 12.2.0 + filesize: 10.1.6 + jsonfile: 6.1.0 + storybook: 9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2) + strip-ansi: 7.1.0 + transitivePeerDependencies: + - '@chromatic-com/cypress' + - '@chromatic-com/playwright' + + '@codemirror/autocomplete@6.18.6': + dependencies: + '@codemirror/language': 6.11.2 + '@codemirror/state': 
6.5.2 + '@codemirror/view': 6.38.1 + '@lezer/common': 1.2.3 + + '@codemirror/autocomplete@6.19.0': + dependencies: + '@codemirror/language': 6.11.3 + '@codemirror/state': 6.5.2 + '@codemirror/view': 6.38.4 + '@lezer/common': 1.2.3 + + '@codemirror/commands@6.8.1': + dependencies: + '@codemirror/language': 6.11.2 + '@codemirror/state': 6.5.2 + '@codemirror/view': 6.38.1 + '@lezer/common': 1.2.3 + + '@codemirror/lang-python@6.2.1': + dependencies: + '@codemirror/autocomplete': 6.18.6 + '@codemirror/language': 6.11.2 + '@codemirror/state': 6.5.2 + '@lezer/common': 1.2.3 + '@lezer/python': 1.1.18 + + '@codemirror/lang-sql@6.9.0': + dependencies: + '@codemirror/autocomplete': 6.18.6 + '@codemirror/language': 6.11.2 + '@codemirror/state': 6.5.2 + '@lezer/common': 1.2.3 + '@lezer/highlight': 1.2.1 + '@lezer/lr': 1.4.2 + + '@codemirror/language@6.11.2': + dependencies: + '@codemirror/state': 6.5.2 + '@codemirror/view': 6.38.1 + '@lezer/common': 1.2.3 + '@lezer/highlight': 1.2.1 + '@lezer/lr': 1.4.2 + style-mod: 4.1.2 + + '@codemirror/language@6.11.3': + dependencies: + '@codemirror/state': 6.5.2 + '@codemirror/view': 6.38.4 + '@lezer/common': 1.2.3 + '@lezer/highlight': 1.2.1 + '@lezer/lr': 1.4.2 + style-mod: 4.1.2 + + '@codemirror/legacy-modes@6.5.1': + dependencies: + '@codemirror/language': 6.11.2 + + '@codemirror/lint@6.8.5': + dependencies: + '@codemirror/state': 6.5.2 + '@codemirror/view': 6.38.4 + crelt: 1.0.6 + + '@codemirror/search@6.5.10': + dependencies: + '@codemirror/state': 6.5.2 + '@codemirror/view': 6.38.4 + crelt: 1.0.6 + + '@codemirror/state@6.5.2': + dependencies: + '@marijn/find-cluster-break': 1.0.2 + + '@codemirror/theme-one-dark@6.1.2': + dependencies: + '@codemirror/language': 6.11.3 + '@codemirror/state': 6.5.2 + '@codemirror/view': 6.38.4 + '@lezer/highlight': 1.2.1 + + '@codemirror/view@6.38.1': + dependencies: + '@codemirror/state': 6.5.2 + crelt: 1.0.6 + style-mod: 4.1.2 + w3c-keyname: 2.2.8 + + '@codemirror/view@6.38.4': + dependencies: + 
'@codemirror/state': 6.5.2 + crelt: 1.0.6 + style-mod: 4.1.2 + w3c-keyname: 2.2.8 + + '@csstools/color-helpers@5.0.2': {} + + '@csstools/css-calc@2.1.4(@csstools/css-parser-algorithms@3.0.5(@csstools/css-tokenizer@3.0.4))(@csstools/css-tokenizer@3.0.4)': + dependencies: + '@csstools/css-parser-algorithms': 3.0.5(@csstools/css-tokenizer@3.0.4) + '@csstools/css-tokenizer': 3.0.4 + + '@csstools/css-color-parser@3.0.10(@csstools/css-parser-algorithms@3.0.5(@csstools/css-tokenizer@3.0.4))(@csstools/css-tokenizer@3.0.4)': + dependencies: + '@csstools/color-helpers': 5.0.2 + '@csstools/css-calc': 2.1.4(@csstools/css-parser-algorithms@3.0.5(@csstools/css-tokenizer@3.0.4))(@csstools/css-tokenizer@3.0.4) + '@csstools/css-parser-algorithms': 3.0.5(@csstools/css-tokenizer@3.0.4) + '@csstools/css-tokenizer': 3.0.4 + + '@csstools/css-parser-algorithms@3.0.5(@csstools/css-tokenizer@3.0.4)': + dependencies: + '@csstools/css-tokenizer': 3.0.4 + + '@csstools/css-tokenizer@3.0.4': {} + + '@duckdb/node-api@1.3.2-alpha.25': + dependencies: + '@duckdb/node-bindings': 1.3.2-alpha.25 + + '@duckdb/node-bindings-darwin-arm64@1.3.2-alpha.25': + optional: true + + '@duckdb/node-bindings-darwin-x64@1.3.2-alpha.25': + optional: true + + '@duckdb/node-bindings-linux-arm64@1.3.2-alpha.25': + optional: true + + '@duckdb/node-bindings-linux-x64@1.3.2-alpha.25': + optional: true + + '@duckdb/node-bindings-win32-x64@1.3.2-alpha.25': + optional: true + + '@duckdb/node-bindings@1.3.2-alpha.25': + optionalDependencies: + '@duckdb/node-bindings-darwin-arm64': 1.3.2-alpha.25 + '@duckdb/node-bindings-darwin-x64': 1.3.2-alpha.25 + '@duckdb/node-bindings-linux-arm64': 1.3.2-alpha.25 + '@duckdb/node-bindings-linux-x64': 1.3.2-alpha.25 + '@duckdb/node-bindings-win32-x64': 1.3.2-alpha.25 + + '@esbuild/aix-ppc64@0.25.8': + optional: true + + '@esbuild/android-arm64@0.25.8': + optional: true + + '@esbuild/android-arm@0.25.8': + optional: true + + '@esbuild/android-x64@0.25.8': + optional: true + + 
'@esbuild/darwin-arm64@0.25.8': + optional: true + + '@esbuild/darwin-x64@0.25.8': + optional: true + + '@esbuild/freebsd-arm64@0.25.8': + optional: true + + '@esbuild/freebsd-x64@0.25.8': + optional: true + + '@esbuild/linux-arm64@0.25.8': + optional: true + + '@esbuild/linux-arm@0.25.8': + optional: true + + '@esbuild/linux-ia32@0.25.8': + optional: true + + '@esbuild/linux-loong64@0.25.8': + optional: true + + '@esbuild/linux-mips64el@0.25.8': + optional: true + + '@esbuild/linux-ppc64@0.25.8': + optional: true + + '@esbuild/linux-riscv64@0.25.8': + optional: true + + '@esbuild/linux-s390x@0.25.8': + optional: true + + '@esbuild/linux-x64@0.25.8': + optional: true + + '@esbuild/netbsd-arm64@0.25.8': + optional: true + + '@esbuild/netbsd-x64@0.25.8': + optional: true + + '@esbuild/openbsd-arm64@0.25.8': + optional: true + + '@esbuild/openbsd-x64@0.25.8': + optional: true + + '@esbuild/openharmony-arm64@0.25.8': + optional: true + + '@esbuild/sunos-x64@0.25.8': + optional: true + + '@esbuild/win32-arm64@0.25.8': + optional: true + + '@esbuild/win32-ia32@0.25.8': + optional: true + + '@esbuild/win32-x64@0.25.8': + optional: true + + '@eslint-community/eslint-utils@4.7.0(eslint@9.31.0(jiti@2.4.2))': + dependencies: + eslint: 9.31.0(jiti@2.4.2) + eslint-visitor-keys: 3.4.3 + + '@eslint-community/regexpp@4.12.1': {} + + '@eslint/config-array@0.21.0': + dependencies: + '@eslint/object-schema': 2.1.6 + debug: 4.4.1 + minimatch: 3.1.2 + transitivePeerDependencies: + - supports-color + + '@eslint/config-helpers@0.3.0': {} + + '@eslint/core@0.15.1': + dependencies: + '@types/json-schema': 7.0.15 + + '@eslint/eslintrc@3.3.1': + dependencies: + ajv: 6.12.6 + debug: 4.4.1 + espree: 10.4.0 + globals: 14.0.0 + ignore: 5.3.2 + import-fresh: 3.3.1 + js-yaml: 4.1.0 + minimatch: 3.1.2 + strip-json-comments: 3.1.1 + transitivePeerDependencies: + - supports-color + + '@eslint/js@9.31.0': {} + + '@eslint/object-schema@2.1.6': {} + + '@eslint/plugin-kit@0.3.4': + dependencies: + 
'@eslint/core': 0.15.1 + levn: 0.4.1 + + '@exodus/schemasafe@1.3.0': {} + + '@floating-ui/core@1.7.2': + dependencies: + '@floating-ui/utils': 0.2.10 + + '@floating-ui/dom@1.7.2': + dependencies: + '@floating-ui/core': 1.7.2 + '@floating-ui/utils': 0.2.10 + + '@floating-ui/react-dom@2.1.4(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@floating-ui/dom': 1.7.2 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + '@floating-ui/react@0.26.28(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@floating-ui/react-dom': 2.1.4(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@floating-ui/utils': 0.2.10 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + tabbable: 6.2.0 + + '@floating-ui/utils@0.2.10': {} + + '@gerrit0/mini-shiki@3.8.1': + dependencies: + '@shikijs/engine-oniguruma': 3.8.1 + '@shikijs/langs': 3.12.2 + '@shikijs/themes': 3.12.2 + '@shikijs/types': 3.12.2 + '@shikijs/vscode-textmate': 10.0.2 + + '@headlessui/react@2.2.5(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@floating-ui/react': 0.26.28(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@react-aria/focus': 3.21.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@react-aria/interactions': 3.25.4(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@tanstack/react-virtual': 3.13.12(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + use-sync-external-store: 1.5.0(react@18.3.1) + + '@heroicons/react@2.2.0(react@18.3.1)': + dependencies: + react: 18.3.1 + + '@humanfs/core@0.19.1': {} + + '@humanfs/node@0.16.6': + dependencies: + '@humanfs/core': 0.19.1 + '@humanwhocodes/retry': 0.3.1 + + '@humanwhocodes/module-importer@1.0.1': {} + + '@humanwhocodes/retry@0.3.1': {} + + '@humanwhocodes/retry@0.4.3': {} + + '@ibm-cloud/openapi-ruleset-utilities@1.9.0': {} + + '@ibm-cloud/openapi-ruleset@1.31.1': + dependencies: + '@ibm-cloud/openapi-ruleset-utilities': 1.9.0 + '@stoplight/spectral-formats': 1.8.2 + 
'@stoplight/spectral-functions': 1.10.1 + '@stoplight/spectral-rulesets': 1.22.0 + chalk: 4.1.2 + jsonschema: 1.5.0 + lodash: 4.17.21 + loglevel: 1.9.2 + loglevel-plugin-prefix: 0.8.4 + minimatch: 6.2.0 + validator: 13.15.15 + transitivePeerDependencies: + - encoding + + '@isaacs/balanced-match@4.0.1': {} + + '@isaacs/brace-expansion@5.0.0': + dependencies: + '@isaacs/balanced-match': 4.0.1 + + '@isaacs/cliui@8.0.2': + dependencies: + string-width: 5.1.2 + string-width-cjs: string-width@4.2.3 + strip-ansi: 7.1.0 + strip-ansi-cjs: strip-ansi@6.0.1 + wrap-ansi: 8.1.0 + wrap-ansi-cjs: wrap-ansi@7.0.0 + + '@isaacs/fs-minipass@4.0.1': + dependencies: + minipass: 7.1.2 + + '@istanbuljs/schema@0.1.3': {} + + '@joshwooding/vite-plugin-react-docgen-typescript@0.6.1(typescript@5.8.3)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))': + dependencies: + glob: 10.4.5 + magic-string: 0.30.17 + react-docgen-typescript: 2.4.0(typescript@5.8.3) + vite: 6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + optionalDependencies: + typescript: 5.8.3 + + '@joshwooding/vite-plugin-react-docgen-typescript@0.6.1(typescript@5.8.3)(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))': + dependencies: + glob: 10.4.5 + magic-string: 0.30.17 + react-docgen-typescript: 2.4.0(typescript@5.8.3) + vite: 6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + optionalDependencies: + typescript: 5.8.3 + + '@jridgewell/gen-mapping@0.3.12': + dependencies: + '@jridgewell/sourcemap-codec': 1.5.4 + '@jridgewell/trace-mapping': 0.3.29 + + '@jridgewell/gen-mapping@0.3.13': + dependencies: + '@jridgewell/sourcemap-codec': 1.5.5 + '@jridgewell/trace-mapping': 0.3.30 + + '@jridgewell/resolve-uri@3.1.2': {} + + '@jridgewell/source-map@0.3.11': + dependencies: + '@jridgewell/gen-mapping': 0.3.13 + 
'@jridgewell/trace-mapping': 0.3.31 + + '@jridgewell/sourcemap-codec@1.5.4': {} + + '@jridgewell/sourcemap-codec@1.5.5': {} + + '@jridgewell/trace-mapping@0.3.29': + dependencies: + '@jridgewell/resolve-uri': 3.1.2 + '@jridgewell/sourcemap-codec': 1.5.4 + + '@jridgewell/trace-mapping@0.3.30': + dependencies: + '@jridgewell/resolve-uri': 3.1.2 + '@jridgewell/sourcemap-codec': 1.5.5 + + '@jridgewell/trace-mapping@0.3.31': + dependencies: + '@jridgewell/resolve-uri': 3.1.2 + '@jridgewell/sourcemap-codec': 1.5.5 + + '@jsdevtools/ono@7.1.3': {} + + '@jsep-plugin/assignment@1.3.0(jsep@1.4.0)': + dependencies: + jsep: 1.4.0 + + '@jsep-plugin/regex@1.0.4(jsep@1.4.0)': + dependencies: + jsep: 1.4.0 + + '@jsep-plugin/ternary@1.1.4(jsep@1.4.0)': + dependencies: + jsep: 1.4.0 + + '@lezer/common@1.2.3': {} + + '@lezer/highlight@1.2.1': + dependencies: + '@lezer/common': 1.2.3 + + '@lezer/lr@1.4.2': + dependencies: + '@lezer/common': 1.2.3 + + '@lezer/python@1.1.18': + dependencies: + '@lezer/common': 1.2.3 + '@lezer/highlight': 1.2.1 + '@lezer/lr': 1.4.2 + + '@lit/react@1.0.8(@types/react@18.3.23)': + dependencies: + '@types/react': 18.3.23 + + '@marijn/find-cluster-break@1.0.2': {} + + '@mdx-js/react@3.1.0(@types/react@18.3.23)(react@18.3.1)': + dependencies: + '@types/mdx': 2.0.13 + '@types/react': 18.3.23 + react: 18.3.1 + + '@microsoft/api-extractor-model@7.30.7(@types/node@20.11.25)': + dependencies: + '@microsoft/tsdoc': 0.15.1 + '@microsoft/tsdoc-config': 0.17.1 + '@rushstack/node-core-library': 5.14.0(@types/node@20.11.25) + transitivePeerDependencies: + - '@types/node' + + '@microsoft/api-extractor@7.52.10(@types/node@20.11.25)': + dependencies: + '@microsoft/api-extractor-model': 7.30.7(@types/node@20.11.25) + '@microsoft/tsdoc': 0.15.1 + '@microsoft/tsdoc-config': 0.17.1 + '@rushstack/node-core-library': 5.14.0(@types/node@20.11.25) + '@rushstack/rig-package': 0.5.3 + '@rushstack/terminal': 0.15.4(@types/node@20.11.25) + '@rushstack/ts-command-line': 
5.0.2(@types/node@20.11.25) + lodash: 4.17.21 + minimatch: 10.0.3 + resolve: 1.22.10 + semver: 7.5.4 + source-map: 0.6.1 + typescript: 5.8.2 + transitivePeerDependencies: + - '@types/node' + + '@microsoft/tsdoc-config@0.17.1': + dependencies: + '@microsoft/tsdoc': 0.15.1 + ajv: 8.12.0 + jju: 1.4.0 + resolve: 1.22.10 + + '@microsoft/tsdoc@0.15.1': {} + + '@neoconfetti/react@1.0.0': {} + + '@nodelib/fs.scandir@2.1.5': + dependencies: + '@nodelib/fs.stat': 2.0.5 + run-parallel: 1.2.0 + + '@nodelib/fs.stat@2.0.5': {} + + '@nodelib/fs.walk@1.2.8': + dependencies: + '@nodelib/fs.scandir': 2.1.5 + fastq: 1.19.1 + + '@orval/angular@7.10.0(openapi-types@12.1.3)': + dependencies: + '@orval/core': 7.10.0(openapi-types@12.1.3) + transitivePeerDependencies: + - encoding + - openapi-types + - supports-color + + '@orval/axios@7.10.0(openapi-types@12.1.3)': + dependencies: + '@orval/core': 7.10.0(openapi-types@12.1.3) + transitivePeerDependencies: + - encoding + - openapi-types + - supports-color + + '@orval/core@7.10.0(openapi-types@12.1.3)': + dependencies: + '@apidevtools/swagger-parser': 10.1.1(openapi-types@12.1.3) + '@ibm-cloud/openapi-ruleset': 1.31.1 + acorn: 8.15.0 + ajv: 8.17.1 + chalk: 4.1.2 + compare-versions: 6.1.1 + debug: 4.4.1(supports-color@8.1.1) + esbuild: 0.25.8 + esutils: 2.0.3 + fs-extra: 11.3.0 + globby: 11.1.0 + lodash.isempty: 4.4.0 + lodash.uniq: 4.5.0 + lodash.uniqby: 4.7.0 + lodash.uniqwith: 4.5.0 + micromatch: 4.0.8 + openapi3-ts: 4.4.0 + swagger2openapi: 7.0.8 + transitivePeerDependencies: + - encoding + - openapi-types + - supports-color + + '@orval/fetch@7.10.0(openapi-types@12.1.3)': + dependencies: + '@orval/core': 7.10.0(openapi-types@12.1.3) + transitivePeerDependencies: + - encoding + - openapi-types + - supports-color + + '@orval/hono@7.10.0(openapi-types@12.1.3)': + dependencies: + '@orval/core': 7.10.0(openapi-types@12.1.3) + '@orval/zod': 7.10.0(openapi-types@12.1.3) + lodash.uniq: 4.5.0 + transitivePeerDependencies: + - encoding + - 
openapi-types + - supports-color + + '@orval/mcp@7.10.0(openapi-types@12.1.3)': + dependencies: + '@orval/core': 7.10.0(openapi-types@12.1.3) + '@orval/zod': 7.10.0(openapi-types@12.1.3) + transitivePeerDependencies: + - encoding + - openapi-types + - supports-color + + '@orval/mock@7.10.0(openapi-types@12.1.3)': + dependencies: + '@orval/core': 7.10.0(openapi-types@12.1.3) + openapi3-ts: 4.4.0 + transitivePeerDependencies: + - encoding + - openapi-types + - supports-color + + '@orval/query@7.10.0(openapi-types@12.1.3)': + dependencies: + '@orval/core': 7.10.0(openapi-types@12.1.3) + '@orval/fetch': 7.10.0(openapi-types@12.1.3) + lodash.omitby: 4.6.0 + transitivePeerDependencies: + - encoding + - openapi-types + - supports-color + + '@orval/swr@7.10.0(openapi-types@12.1.3)': + dependencies: + '@orval/core': 7.10.0(openapi-types@12.1.3) + '@orval/fetch': 7.10.0(openapi-types@12.1.3) + transitivePeerDependencies: + - encoding + - openapi-types + - supports-color + + '@orval/zod@7.10.0(openapi-types@12.1.3)': + dependencies: + '@orval/core': 7.10.0(openapi-types@12.1.3) + lodash.uniq: 4.5.0 + transitivePeerDependencies: + - encoding + - openapi-types + - supports-color + + '@pkgjs/parseargs@0.11.0': + optional: true + + '@playwright/test@1.54.1': + dependencies: + playwright: 1.54.1 + + '@polka/url@1.0.0-next.29': {} + + '@radix-ui/number@1.1.1': {} + + '@radix-ui/primitive@1.1.2': {} + + '@radix-ui/primitive@1.1.3': {} + + '@radix-ui/react-arrow@1.1.7(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@radix-ui/react-primitive': 2.1.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + 
'@radix-ui/react-collection@1.1.7(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-context': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-primitive': 2.1.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-slot': 1.2.3(@types/react@18.3.23)(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + '@radix-ui/react-compose-refs@1.1.2(@types/react@18.3.23)(react@18.3.1)': + dependencies: + react: 18.3.1 + optionalDependencies: + '@types/react': 18.3.23 + + '@radix-ui/react-context-menu@2.2.15(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@radix-ui/primitive': 1.1.2 + '@radix-ui/react-context': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-menu': 2.1.15(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-primitive': 2.1.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-use-callback-ref': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-use-controllable-state': 1.2.2(@types/react@18.3.23)(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + '@radix-ui/react-context@1.1.2(@types/react@18.3.23)(react@18.3.1)': + dependencies: + react: 18.3.1 + optionalDependencies: + '@types/react': 18.3.23 + + '@radix-ui/react-direction@1.1.1(@types/react@18.3.23)(react@18.3.1)': + dependencies: + react: 18.3.1 
+ optionalDependencies: + '@types/react': 18.3.23 + + '@radix-ui/react-dismissable-layer@1.1.10(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@radix-ui/primitive': 1.1.2 + '@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-primitive': 2.1.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-use-callback-ref': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-use-escape-keydown': 1.1.1(@types/react@18.3.23)(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + '@radix-ui/react-dismissable-layer@1.1.11(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@radix-ui/primitive': 1.1.3 + '@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-primitive': 2.1.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-use-callback-ref': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-use-escape-keydown': 1.1.1(@types/react@18.3.23)(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + '@radix-ui/react-focus-guards@1.1.2(@types/react@18.3.23)(react@18.3.1)': + dependencies: + react: 18.3.1 + optionalDependencies: + '@types/react': 18.3.23 + + '@radix-ui/react-focus-scope@1.1.7(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-primitive': 
2.1.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-use-callback-ref': 1.1.1(@types/react@18.3.23)(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + '@radix-ui/react-id@1.1.1(@types/react@18.3.23)(react@18.3.1)': + dependencies: + '@radix-ui/react-use-layout-effect': 1.1.1(@types/react@18.3.23)(react@18.3.1) + react: 18.3.1 + optionalDependencies: + '@types/react': 18.3.23 + + '@radix-ui/react-menu@2.1.15(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@radix-ui/primitive': 1.1.2 + '@radix-ui/react-collection': 1.1.7(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-context': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-direction': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-dismissable-layer': 1.1.10(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-focus-guards': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-focus-scope': 1.1.7(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-id': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-popper': 1.2.7(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-portal': 1.1.9(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-presence': 
1.1.4(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-primitive': 2.1.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-roving-focus': 1.1.10(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-slot': 1.2.3(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-use-callback-ref': 1.1.1(@types/react@18.3.23)(react@18.3.1) + aria-hidden: 1.2.6 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + react-remove-scroll: 2.7.1(@types/react@18.3.23)(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + '@radix-ui/react-popper@1.2.7(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@floating-ui/react-dom': 2.1.4(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-arrow': 1.1.7(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-context': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-primitive': 2.1.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-use-callback-ref': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-use-layout-effect': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-use-rect': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-use-size': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/rect': 1.1.1 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + 
'@radix-ui/react-popper@1.2.8(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@floating-ui/react-dom': 2.1.4(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-arrow': 1.1.7(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-context': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-primitive': 2.1.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-use-callback-ref': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-use-layout-effect': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-use-rect': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-use-size': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/rect': 1.1.1 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + '@radix-ui/react-portal@1.1.9(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@radix-ui/react-primitive': 2.1.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-use-layout-effect': 1.1.1(@types/react@18.3.23)(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + '@radix-ui/react-presence@1.1.4(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.23)(react@18.3.1) + 
'@radix-ui/react-use-layout-effect': 1.1.1(@types/react@18.3.23)(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + '@radix-ui/react-presence@1.1.5(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-use-layout-effect': 1.1.1(@types/react@18.3.23)(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + '@radix-ui/react-primitive@2.1.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@radix-ui/react-slot': 1.2.3(@types/react@18.3.23)(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + '@radix-ui/react-roving-focus@1.1.10(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@radix-ui/primitive': 1.1.2 + '@radix-ui/react-collection': 1.1.7(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-context': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-direction': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-id': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-primitive': 2.1.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-use-callback-ref': 1.1.1(@types/react@18.3.23)(react@18.3.1) + 
'@radix-ui/react-use-controllable-state': 1.2.2(@types/react@18.3.23)(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + '@radix-ui/react-select@2.2.5(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@radix-ui/number': 1.1.1 + '@radix-ui/primitive': 1.1.2 + '@radix-ui/react-collection': 1.1.7(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-context': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-direction': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-dismissable-layer': 1.1.10(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-focus-guards': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-focus-scope': 1.1.7(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-id': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-popper': 1.2.7(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-portal': 1.1.9(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-primitive': 2.1.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-slot': 1.2.3(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-use-callback-ref': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-use-controllable-state': 1.2.2(@types/react@18.3.23)(react@18.3.1) + 
'@radix-ui/react-use-layout-effect': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-use-previous': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-visually-hidden': 1.2.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + aria-hidden: 1.2.6 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + react-remove-scroll: 2.7.1(@types/react@18.3.23)(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + '@radix-ui/react-slot@1.2.3(@types/react@18.3.23)(react@18.3.1)': + dependencies: + '@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.23)(react@18.3.1) + react: 18.3.1 + optionalDependencies: + '@types/react': 18.3.23 + + '@radix-ui/react-tooltip@1.2.8(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@radix-ui/primitive': 1.1.3 + '@radix-ui/react-compose-refs': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-context': 1.1.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-dismissable-layer': 1.1.11(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-id': 1.1.1(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-popper': 1.2.8(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-portal': 1.1.9(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-presence': 1.1.5(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-primitive': 2.1.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@radix-ui/react-slot': 
1.2.3(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-use-controllable-state': 1.2.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-visually-hidden': 1.2.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + '@radix-ui/react-use-callback-ref@1.1.1(@types/react@18.3.23)(react@18.3.1)': + dependencies: + react: 18.3.1 + optionalDependencies: + '@types/react': 18.3.23 + + '@radix-ui/react-use-controllable-state@1.2.2(@types/react@18.3.23)(react@18.3.1)': + dependencies: + '@radix-ui/react-use-effect-event': 0.0.2(@types/react@18.3.23)(react@18.3.1) + '@radix-ui/react-use-layout-effect': 1.1.1(@types/react@18.3.23)(react@18.3.1) + react: 18.3.1 + optionalDependencies: + '@types/react': 18.3.23 + + '@radix-ui/react-use-effect-event@0.0.2(@types/react@18.3.23)(react@18.3.1)': + dependencies: + '@radix-ui/react-use-layout-effect': 1.1.1(@types/react@18.3.23)(react@18.3.1) + react: 18.3.1 + optionalDependencies: + '@types/react': 18.3.23 + + '@radix-ui/react-use-escape-keydown@1.1.1(@types/react@18.3.23)(react@18.3.1)': + dependencies: + '@radix-ui/react-use-callback-ref': 1.1.1(@types/react@18.3.23)(react@18.3.1) + react: 18.3.1 + optionalDependencies: + '@types/react': 18.3.23 + + '@radix-ui/react-use-layout-effect@1.1.1(@types/react@18.3.23)(react@18.3.1)': + dependencies: + react: 18.3.1 + optionalDependencies: + '@types/react': 18.3.23 + + '@radix-ui/react-use-previous@1.1.1(@types/react@18.3.23)(react@18.3.1)': + dependencies: + react: 18.3.1 + optionalDependencies: + '@types/react': 18.3.23 + + '@radix-ui/react-use-rect@1.1.1(@types/react@18.3.23)(react@18.3.1)': + dependencies: + '@radix-ui/rect': 1.1.1 + react: 18.3.1 + optionalDependencies: + '@types/react': 18.3.23 + + 
'@radix-ui/react-use-size@1.1.1(@types/react@18.3.23)(react@18.3.1)': + dependencies: + '@radix-ui/react-use-layout-effect': 1.1.1(@types/react@18.3.23)(react@18.3.1) + react: 18.3.1 + optionalDependencies: + '@types/react': 18.3.23 + + '@radix-ui/react-visually-hidden@1.2.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@radix-ui/react-primitive': 2.1.3(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + '@radix-ui/rect@1.1.1': {} + + '@react-aria/focus@3.21.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@react-aria/interactions': 3.25.4(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@react-aria/utils': 3.30.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@react-types/shared': 3.31.0(react@18.3.1) + '@swc/helpers': 0.5.17 + clsx: 2.1.1 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + '@react-aria/interactions@3.25.4(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@react-aria/ssr': 3.9.10(react@18.3.1) + '@react-aria/utils': 3.30.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@react-stately/flags': 3.1.2 + '@react-types/shared': 3.31.0(react@18.3.1) + '@swc/helpers': 0.5.17 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + '@react-aria/ssr@3.9.10(react@18.3.1)': + dependencies: + '@swc/helpers': 0.5.17 + react: 18.3.1 + + '@react-aria/utils@3.30.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@react-aria/ssr': 3.9.10(react@18.3.1) + '@react-stately/flags': 3.1.2 + '@react-stately/utils': 3.10.8(react@18.3.1) + '@react-types/shared': 3.31.0(react@18.3.1) + '@swc/helpers': 0.5.17 + clsx: 2.1.1 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + '@react-dnd/asap@5.0.2': {} + + 
'@react-dnd/invariant@4.0.2': {} + + '@react-dnd/shallowequal@4.0.2': {} + + '@react-stately/flags@3.1.2': + dependencies: + '@swc/helpers': 0.5.17 + + '@react-stately/utils@3.10.8(react@18.3.1)': + dependencies: + '@swc/helpers': 0.5.17 + react: 18.3.1 + + '@react-types/shared@3.31.0(react@18.3.1)': + dependencies: + react: 18.3.1 + + '@reactflow/background@11.3.14(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@reactflow/core': 11.11.4(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + classcat: 5.0.5 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + zustand: 4.5.7(@types/react@18.3.23)(immer@9.0.21)(react@18.3.1) + transitivePeerDependencies: + - '@types/react' + - immer + + '@reactflow/controls@11.2.14(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@reactflow/core': 11.11.4(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + classcat: 5.0.5 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + zustand: 4.5.7(@types/react@18.3.23)(immer@9.0.21)(react@18.3.1) + transitivePeerDependencies: + - '@types/react' + - immer + + '@reactflow/core@11.11.4(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@types/d3': 7.4.3 + '@types/d3-drag': 3.0.7 + '@types/d3-selection': 3.0.11 + '@types/d3-zoom': 3.0.8 + classcat: 5.0.5 + d3-drag: 3.0.0 + d3-selection: 3.0.0 + d3-zoom: 3.0.0 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + zustand: 4.5.7(@types/react@18.3.23)(immer@9.0.21)(react@18.3.1) + transitivePeerDependencies: + - '@types/react' + - immer + + '@reactflow/minimap@11.7.14(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@reactflow/core': 11.11.4(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@types/d3-selection': 3.0.11 + '@types/d3-zoom': 3.0.8 + 
classcat: 5.0.5 + d3-selection: 3.0.0 + d3-zoom: 3.0.0 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + zustand: 4.5.7(@types/react@18.3.23)(immer@9.0.21)(react@18.3.1) + transitivePeerDependencies: + - '@types/react' + - immer + + '@reactflow/node-resizer@2.2.14(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@reactflow/core': 11.11.4(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + classcat: 5.0.5 + d3-drag: 3.0.0 + d3-selection: 3.0.0 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + zustand: 4.5.7(@types/react@18.3.23)(immer@9.0.21)(react@18.3.1) + transitivePeerDependencies: + - '@types/react' + - immer + + '@reactflow/node-toolbar@1.3.14(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@reactflow/core': 11.11.4(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + classcat: 5.0.5 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + zustand: 4.5.7(@types/react@18.3.23)(immer@9.0.21)(react@18.3.1) + transitivePeerDependencies: + - '@types/react' + - immer + + '@rolldown/pluginutils@1.0.0-beta.27': {} + + '@rollup/pluginutils@5.2.0(rollup@4.45.1)': + dependencies: + '@types/estree': 1.0.8 + estree-walker: 2.0.2 + picomatch: 4.0.3 + optionalDependencies: + rollup: 4.45.1 + + '@rollup/rollup-android-arm-eabi@4.45.1': + optional: true + + '@rollup/rollup-android-arm64@4.45.1': + optional: true + + '@rollup/rollup-darwin-arm64@4.45.1': + optional: true + + '@rollup/rollup-darwin-x64@4.45.1': + optional: true + + '@rollup/rollup-freebsd-arm64@4.45.1': + optional: true + + '@rollup/rollup-freebsd-x64@4.45.1': + optional: true + + '@rollup/rollup-linux-arm-gnueabihf@4.45.1': + optional: true + + '@rollup/rollup-linux-arm-musleabihf@4.45.1': + optional: true + + '@rollup/rollup-linux-arm64-gnu@4.45.1': + optional: true + + '@rollup/rollup-linux-arm64-musl@4.45.1': + optional: true + + 
'@rollup/rollup-linux-loongarch64-gnu@4.45.1': + optional: true + + '@rollup/rollup-linux-powerpc64le-gnu@4.45.1': + optional: true + + '@rollup/rollup-linux-riscv64-gnu@4.45.1': + optional: true + + '@rollup/rollup-linux-riscv64-musl@4.45.1': + optional: true + + '@rollup/rollup-linux-s390x-gnu@4.45.1': + optional: true + + '@rollup/rollup-linux-x64-gnu@4.45.1': + optional: true + + '@rollup/rollup-linux-x64-musl@4.45.1': + optional: true + + '@rollup/rollup-win32-arm64-msvc@4.45.1': + optional: true + + '@rollup/rollup-win32-ia32-msvc@4.45.1': + optional: true + + '@rollup/rollup-win32-x64-msvc@4.45.1': + optional: true + + '@rushstack/node-core-library@5.14.0(@types/node@20.11.25)': + dependencies: + ajv: 8.13.0 + ajv-draft-04: 1.0.0(ajv@8.13.0) + ajv-formats: 3.0.1(ajv@8.13.0) + fs-extra: 11.3.0 + import-lazy: 4.0.0 + jju: 1.4.0 + resolve: 1.22.10 + semver: 7.5.4 + optionalDependencies: + '@types/node': 20.11.25 + + '@rushstack/rig-package@0.5.3': + dependencies: + resolve: 1.22.10 + strip-json-comments: 3.1.1 + + '@rushstack/terminal@0.15.4(@types/node@20.11.25)': + dependencies: + '@rushstack/node-core-library': 5.14.0(@types/node@20.11.25) + supports-color: 8.1.1 + optionalDependencies: + '@types/node': 20.11.25 + + '@rushstack/ts-command-line@5.0.2(@types/node@20.11.25)': + dependencies: + '@rushstack/terminal': 0.15.4(@types/node@20.11.25) + '@types/argparse': 1.0.38 + argparse: 1.0.10 + string-argv: 0.3.2 + transitivePeerDependencies: + - '@types/node' + + '@secretlint/config-creator@10.2.1': + dependencies: + '@secretlint/types': 10.2.1 + + '@secretlint/config-loader@10.2.1': + dependencies: + '@secretlint/profiler': 10.2.1 + '@secretlint/resolver': 10.2.1 + '@secretlint/types': 10.2.1 + ajv: 8.17.1 + debug: 4.4.1(supports-color@8.1.1) + rc-config-loader: 4.1.3 + transitivePeerDependencies: + - supports-color + + '@secretlint/core@10.2.1': + dependencies: + '@secretlint/profiler': 10.2.1 + '@secretlint/types': 10.2.1 + debug: 4.4.1(supports-color@8.1.1) 
+ structured-source: 4.0.0 + transitivePeerDependencies: + - supports-color + + '@secretlint/formatter@10.2.1': + dependencies: + '@secretlint/resolver': 10.2.1 + '@secretlint/types': 10.2.1 + '@textlint/linter-formatter': 15.2.0 + '@textlint/module-interop': 15.2.0 + '@textlint/types': 15.2.0 + chalk: 5.4.1 + debug: 4.4.1(supports-color@8.1.1) + pluralize: 8.0.0 + strip-ansi: 7.1.0 + table: 6.9.0 + terminal-link: 4.0.0 + transitivePeerDependencies: + - supports-color + + '@secretlint/node@10.2.1': + dependencies: + '@secretlint/config-loader': 10.2.1 + '@secretlint/core': 10.2.1 + '@secretlint/formatter': 10.2.1 + '@secretlint/profiler': 10.2.1 + '@secretlint/source-creator': 10.2.1 + '@secretlint/types': 10.2.1 + debug: 4.4.1(supports-color@8.1.1) + p-map: 7.0.3 + transitivePeerDependencies: + - supports-color + + '@secretlint/profiler@10.2.1': {} + + '@secretlint/resolver@10.2.1': {} + + '@secretlint/secretlint-formatter-sarif@10.2.1': + dependencies: + node-sarif-builder: 3.2.0 + + '@secretlint/secretlint-rule-no-dotenv@10.2.1': + dependencies: + '@secretlint/types': 10.2.1 + + '@secretlint/secretlint-rule-preset-recommend@10.2.1': {} + + '@secretlint/source-creator@10.2.1': + dependencies: + '@secretlint/types': 10.2.1 + istextorbinary: 9.5.0 + + '@secretlint/types@10.2.1': {} + + '@shikijs/engine-oniguruma@3.8.1': + dependencies: + '@shikijs/types': 3.8.1 + '@shikijs/vscode-textmate': 10.0.2 + + '@shikijs/langs@3.12.2': + dependencies: + '@shikijs/types': 3.12.2 + + '@shikijs/themes@3.12.2': + dependencies: + '@shikijs/types': 3.12.2 + + '@shikijs/types@3.12.2': + dependencies: + '@shikijs/vscode-textmate': 10.0.2 + '@types/hast': 3.0.4 + + '@shikijs/types@3.8.1': + dependencies: + '@shikijs/vscode-textmate': 10.0.2 + '@types/hast': 3.0.4 + + '@shikijs/vscode-textmate@10.0.2': {} + + '@sindresorhus/merge-streams@2.3.0': {} + + '@standard-schema/spec@1.0.0': {} + + '@stoplight/better-ajv-errors@1.0.3(ajv@8.17.1)': + dependencies: + ajv: 8.17.1 + jsonpointer: 
5.0.1 + leven: 3.1.0 + + '@stoplight/json-ref-readers@1.2.2': + dependencies: + node-fetch: 2.7.0 + tslib: 1.14.1 + transitivePeerDependencies: + - encoding + + '@stoplight/json-ref-resolver@3.1.6': + dependencies: + '@stoplight/json': 3.21.7 + '@stoplight/path': 1.3.2 + '@stoplight/types': 13.20.0 + '@types/urijs': 1.19.25 + dependency-graph: 0.11.0 + fast-memoize: 2.5.2 + immer: 9.0.21 + lodash: 4.17.21 + tslib: 2.8.1 + urijs: 1.19.11 + + '@stoplight/json@3.21.7': + dependencies: + '@stoplight/ordered-object-literal': 1.0.5 + '@stoplight/path': 1.3.2 + '@stoplight/types': 13.20.0 + jsonc-parser: 2.2.1 + lodash: 4.17.21 + safe-stable-stringify: 1.1.1 + + '@stoplight/ordered-object-literal@1.0.5': {} + + '@stoplight/path@1.3.2': {} + + '@stoplight/spectral-core@1.20.0': + dependencies: + '@stoplight/better-ajv-errors': 1.0.3(ajv@8.17.1) + '@stoplight/json': 3.21.7 + '@stoplight/path': 1.3.2 + '@stoplight/spectral-parsers': 1.0.5 + '@stoplight/spectral-ref-resolver': 1.0.5 + '@stoplight/spectral-runtime': 1.1.4 + '@stoplight/types': 13.6.0 + '@types/es-aggregate-error': 1.0.6 + '@types/json-schema': 7.0.15 + ajv: 8.17.1 + ajv-errors: 3.0.0(ajv@8.17.1) + ajv-formats: 2.1.1(ajv@8.17.1) + es-aggregate-error: 1.0.14 + jsonpath-plus: 10.3.0 + lodash: 4.17.21 + lodash.topath: 4.5.2 + minimatch: 3.1.2 + nimma: 0.2.3 + pony-cause: 1.1.1 + simple-eval: 1.0.1 + tslib: 2.8.1 + transitivePeerDependencies: + - encoding + + '@stoplight/spectral-formats@1.8.2': + dependencies: + '@stoplight/json': 3.21.7 + '@stoplight/spectral-core': 1.20.0 + '@types/json-schema': 7.0.15 + tslib: 2.8.1 + transitivePeerDependencies: + - encoding + + '@stoplight/spectral-functions@1.10.1': + dependencies: + '@stoplight/better-ajv-errors': 1.0.3(ajv@8.17.1) + '@stoplight/json': 3.21.7 + '@stoplight/spectral-core': 1.20.0 + '@stoplight/spectral-formats': 1.8.2 + '@stoplight/spectral-runtime': 1.1.4 + ajv: 8.17.1 + ajv-draft-04: 1.0.0(ajv@8.17.1) + ajv-errors: 3.0.0(ajv@8.17.1) + ajv-formats: 
2.1.1(ajv@8.17.1) + lodash: 4.17.21 + tslib: 2.8.1 + transitivePeerDependencies: + - encoding + + '@stoplight/spectral-parsers@1.0.5': + dependencies: + '@stoplight/json': 3.21.7 + '@stoplight/types': 14.1.1 + '@stoplight/yaml': 4.3.0 + tslib: 2.8.1 + + '@stoplight/spectral-ref-resolver@1.0.5': + dependencies: + '@stoplight/json-ref-readers': 1.2.2 + '@stoplight/json-ref-resolver': 3.1.6 + '@stoplight/spectral-runtime': 1.1.4 + dependency-graph: 0.11.0 + tslib: 2.8.1 + transitivePeerDependencies: + - encoding + + '@stoplight/spectral-rulesets@1.22.0': + dependencies: + '@asyncapi/specs': 6.8.1 + '@stoplight/better-ajv-errors': 1.0.3(ajv@8.17.1) + '@stoplight/json': 3.21.7 + '@stoplight/spectral-core': 1.20.0 + '@stoplight/spectral-formats': 1.8.2 + '@stoplight/spectral-functions': 1.10.1 + '@stoplight/spectral-runtime': 1.1.4 + '@stoplight/types': 13.20.0 + '@types/json-schema': 7.0.15 + ajv: 8.17.1 + ajv-formats: 2.1.1(ajv@8.17.1) + json-schema-traverse: 1.0.0 + leven: 3.1.0 + lodash: 4.17.21 + tslib: 2.8.1 + transitivePeerDependencies: + - encoding + + '@stoplight/spectral-runtime@1.1.4': + dependencies: + '@stoplight/json': 3.21.7 + '@stoplight/path': 1.3.2 + '@stoplight/types': 13.20.0 + abort-controller: 3.0.0 + lodash: 4.17.21 + node-fetch: 2.7.0 + tslib: 2.8.1 + transitivePeerDependencies: + - encoding + + '@stoplight/types@13.20.0': + dependencies: + '@types/json-schema': 7.0.15 + utility-types: 3.11.0 + + '@stoplight/types@13.6.0': + dependencies: + '@types/json-schema': 7.0.15 + utility-types: 3.11.0 + + '@stoplight/types@14.1.1': + dependencies: + '@types/json-schema': 7.0.15 + utility-types: 3.11.0 + + '@stoplight/yaml-ast-parser@0.0.50': {} + + '@stoplight/yaml@4.3.0': + dependencies: + '@stoplight/ordered-object-literal': 1.0.5 + '@stoplight/types': 14.1.1 + '@stoplight/yaml-ast-parser': 0.0.50 + tslib: 2.8.1 + + '@storybook/addon-a11y@9.0.18(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2))': + dependencies: + '@storybook/global': 5.0.0 
+ axe-core: 4.10.3 + storybook: 9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2) + + '@storybook/addon-docs@9.0.18(@types/react@18.3.23)(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2))': + dependencies: + '@mdx-js/react': 3.1.0(@types/react@18.3.23)(react@18.3.1) + '@storybook/csf-plugin': 9.0.18(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2)) + '@storybook/icons': 1.4.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@storybook/react-dom-shim': 9.0.18(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2)) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + storybook: 9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2) + ts-dedent: 2.2.0 + transitivePeerDependencies: + - '@types/react' + + '@storybook/addon-docs@9.1.5(@types/react@18.3.23)(storybook@9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)))': + dependencies: + '@mdx-js/react': 3.1.0(@types/react@18.3.23)(react@18.3.1) + '@storybook/csf-plugin': 9.1.5(storybook@9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))) + '@storybook/icons': 1.4.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@storybook/react-dom-shim': 9.1.5(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + storybook: 9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + ts-dedent: 2.2.0 + transitivePeerDependencies: + - '@types/react' + + 
'@storybook/addon-onboarding@9.0.18(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2))': + dependencies: + storybook: 9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2) + + '@storybook/addon-vitest@9.0.18(@vitest/browser@3.2.3)(@vitest/runner@3.2.4)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2))(vitest@3.2.4)': + dependencies: + '@storybook/global': 5.0.0 + '@storybook/icons': 1.4.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + prompts: 2.4.2 + storybook: 9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2) + ts-dedent: 2.2.0 + optionalDependencies: + '@vitest/browser': 3.2.3(playwright@1.54.1)(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))(vitest@3.2.4) + '@vitest/runner': 3.2.4 + vitest: 3.2.4(@types/debug@4.1.12)(@types/node@24.1.0)(@vitest/browser@3.2.3)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + transitivePeerDependencies: + - react + - react-dom + + '@storybook/builder-vite@9.0.18(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2))(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))': + dependencies: + '@storybook/csf-plugin': 9.0.18(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2)) + storybook: 9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2) + ts-dedent: 2.2.0 + vite: 6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + + '@storybook/builder-vite@9.1.5(storybook@9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)))(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))': + dependencies: + '@storybook/csf-plugin': 
9.1.5(storybook@9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))) + storybook: 9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + ts-dedent: 2.2.0 + vite: 6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + + '@storybook/csf-plugin@9.0.18(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2))': + dependencies: + storybook: 9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2) + unplugin: 1.16.1 + + '@storybook/csf-plugin@9.1.5(storybook@9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)))': + dependencies: + storybook: 9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + unplugin: 1.16.1 + + '@storybook/global@5.0.0': {} + + '@storybook/icons@1.4.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + '@storybook/react-dom-shim@9.0.18(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2))': + dependencies: + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + storybook: 9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2) + + '@storybook/react-dom-shim@9.1.5(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)))': + dependencies: + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + storybook: 
9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + + '@storybook/react-vite@9.0.18(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(rollup@4.45.1)(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2))(typescript@5.8.3)(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))': + dependencies: + '@joshwooding/vite-plugin-react-docgen-typescript': 0.6.1(typescript@5.8.3)(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + '@rollup/pluginutils': 5.2.0(rollup@4.45.1) + '@storybook/builder-vite': 9.0.18(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2))(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + '@storybook/react': 9.0.18(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2))(typescript@5.8.3) + find-up: 7.0.0 + magic-string: 0.30.17 + react: 18.3.1 + react-docgen: 8.0.0 + react-dom: 18.3.1(react@18.3.1) + resolve: 1.22.10 + storybook: 9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2) + tsconfig-paths: 4.2.0 + vite: 6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + transitivePeerDependencies: + - rollup + - supports-color + - typescript + + '@storybook/react-vite@9.1.5(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(rollup@4.45.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)))(typescript@5.8.3)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))': + dependencies: + '@joshwooding/vite-plugin-react-docgen-typescript': 
0.6.1(typescript@5.8.3)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + '@rollup/pluginutils': 5.2.0(rollup@4.45.1) + '@storybook/builder-vite': 9.1.5(storybook@9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)))(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + '@storybook/react': 9.1.5(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)))(typescript@5.8.3) + find-up: 7.0.0 + magic-string: 0.30.17 + react: 18.3.1 + react-docgen: 8.0.0 + react-dom: 18.3.1(react@18.3.1) + resolve: 1.22.10 + storybook: 9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + tsconfig-paths: 4.2.0 + vite: 6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + transitivePeerDependencies: + - rollup + - supports-color + - typescript + + '@storybook/react@9.0.18(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2))(typescript@5.8.3)': + dependencies: + '@storybook/global': 5.0.0 + '@storybook/react-dom-shim': 9.0.18(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2)) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + storybook: 9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2) + optionalDependencies: + typescript: 5.8.3 + + 
'@storybook/react@9.1.5(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)))(typescript@5.8.3)': + dependencies: + '@storybook/global': 5.0.0 + '@storybook/react-dom-shim': 9.1.5(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(storybook@9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + storybook: 9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + optionalDependencies: + typescript: 5.8.3 + + '@swc/core-darwin-arm64@1.13.2': + optional: true + + '@swc/core-darwin-x64@1.13.2': + optional: true + + '@swc/core-linux-arm-gnueabihf@1.13.2': + optional: true + + '@swc/core-linux-arm64-gnu@1.13.2': + optional: true + + '@swc/core-linux-arm64-musl@1.13.2': + optional: true + + '@swc/core-linux-x64-gnu@1.13.2': + optional: true + + '@swc/core-linux-x64-musl@1.13.2': + optional: true + + '@swc/core-win32-arm64-msvc@1.13.2': + optional: true + + '@swc/core-win32-ia32-msvc@1.13.2': + optional: true + + '@swc/core-win32-x64-msvc@1.13.2': + optional: true + + '@swc/core@1.13.2(@swc/helpers@0.5.17)': + dependencies: + '@swc/counter': 0.1.3 + '@swc/types': 0.1.23 + optionalDependencies: + '@swc/core-darwin-arm64': 1.13.2 + '@swc/core-darwin-x64': 1.13.2 + '@swc/core-linux-arm-gnueabihf': 1.13.2 + '@swc/core-linux-arm64-gnu': 1.13.2 + '@swc/core-linux-arm64-musl': 1.13.2 + '@swc/core-linux-x64-gnu': 1.13.2 + '@swc/core-linux-x64-musl': 1.13.2 + '@swc/core-win32-arm64-msvc': 1.13.2 + '@swc/core-win32-ia32-msvc': 1.13.2 + '@swc/core-win32-x64-msvc': 1.13.2 + '@swc/helpers': 0.5.17 + + '@swc/counter@0.1.3': {} + + '@swc/helpers@0.5.17': + dependencies: + tslib: 2.8.1 
+ + '@swc/types@0.1.23': + dependencies: + '@swc/counter': 0.1.3 + + '@tailwindcss/container-queries@0.1.1(tailwindcss@3.4.17)': + dependencies: + tailwindcss: 3.4.17 + + '@tailwindcss/node@4.1.11': + dependencies: + '@ampproject/remapping': 2.3.0 + enhanced-resolve: 5.18.2 + jiti: 2.4.2 + lightningcss: 1.30.1 + magic-string: 0.30.17 + source-map-js: 1.2.1 + tailwindcss: 4.1.11 + + '@tailwindcss/oxide-android-arm64@4.1.11': + optional: true + + '@tailwindcss/oxide-darwin-arm64@4.1.11': + optional: true + + '@tailwindcss/oxide-darwin-x64@4.1.11': + optional: true + + '@tailwindcss/oxide-freebsd-x64@4.1.11': + optional: true + + '@tailwindcss/oxide-linux-arm-gnueabihf@4.1.11': + optional: true + + '@tailwindcss/oxide-linux-arm64-gnu@4.1.11': + optional: true + + '@tailwindcss/oxide-linux-arm64-musl@4.1.11': + optional: true + + '@tailwindcss/oxide-linux-x64-gnu@4.1.11': + optional: true + + '@tailwindcss/oxide-linux-x64-musl@4.1.11': + optional: true + + '@tailwindcss/oxide-wasm32-wasi@4.1.11': + optional: true + + '@tailwindcss/oxide-win32-arm64-msvc@4.1.11': + optional: true + + '@tailwindcss/oxide-win32-x64-msvc@4.1.11': + optional: true + + '@tailwindcss/oxide@4.1.11': + dependencies: + detect-libc: 2.0.4 + tar: 7.4.3 + optionalDependencies: + '@tailwindcss/oxide-android-arm64': 4.1.11 + '@tailwindcss/oxide-darwin-arm64': 4.1.11 + '@tailwindcss/oxide-darwin-x64': 4.1.11 + '@tailwindcss/oxide-freebsd-x64': 4.1.11 + '@tailwindcss/oxide-linux-arm-gnueabihf': 4.1.11 + '@tailwindcss/oxide-linux-arm64-gnu': 4.1.11 + '@tailwindcss/oxide-linux-arm64-musl': 4.1.11 + '@tailwindcss/oxide-linux-x64-gnu': 4.1.11 + '@tailwindcss/oxide-linux-x64-musl': 4.1.11 + '@tailwindcss/oxide-wasm32-wasi': 4.1.11 + '@tailwindcss/oxide-win32-arm64-msvc': 4.1.11 + '@tailwindcss/oxide-win32-x64-msvc': 4.1.11 + + '@tailwindcss/postcss@4.1.11': + dependencies: + '@alloc/quick-lru': 5.2.0 + '@tailwindcss/node': 4.1.11 + '@tailwindcss/oxide': 4.1.11 + postcss: 8.5.6 + tailwindcss: 4.1.11 + + 
'@tailwindcss/typography@0.5.16(tailwindcss@3.4.17)': + dependencies: + lodash.castarray: 4.4.0 + lodash.isplainobject: 4.0.6 + lodash.merge: 4.6.2 + postcss-selector-parser: 6.0.10 + tailwindcss: 3.4.17 + + '@tailwindcss/vite@4.1.11(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))': + dependencies: + '@tailwindcss/node': 4.1.11 + '@tailwindcss/oxide': 4.1.11 + tailwindcss: 4.1.11 + vite: 6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + + '@tanstack/history@1.129.7': {} + + '@tanstack/query-core@5.83.0': {} + + '@tanstack/react-query@5.83.0(react@18.3.1)': + dependencies: + '@tanstack/query-core': 5.83.0 + react: 18.3.1 + + '@tanstack/react-router-devtools@1.131.26(@tanstack/react-router@1.129.8(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(@tanstack/router-core@1.129.8)(csstype@3.1.3)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)(solid-js@1.9.7)(tiny-invariant@1.3.3)': + dependencies: + '@tanstack/react-router': 1.129.8(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@tanstack/router-devtools-core': 1.131.26(@tanstack/router-core@1.129.8)(csstype@3.1.3)(solid-js@1.9.7)(tiny-invariant@1.3.3) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + transitivePeerDependencies: + - '@tanstack/router-core' + - csstype + - solid-js + - tiny-invariant + + '@tanstack/react-router@1.129.8(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@tanstack/history': 1.129.7 + '@tanstack/react-store': 0.7.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@tanstack/router-core': 1.129.8 + isbot: 5.1.28 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + tiny-invariant: 1.3.3 + tiny-warning: 1.0.3 + + '@tanstack/react-store@0.7.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@tanstack/store': 0.7.2 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + use-sync-external-store: 1.5.0(react@18.3.1) + + 
'@tanstack/react-table@8.21.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@tanstack/table-core': 8.21.3 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + '@tanstack/react-virtual@3.13.12(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@tanstack/virtual-core': 3.13.12 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + '@tanstack/router-core@1.129.8': + dependencies: + '@tanstack/history': 1.129.7 + '@tanstack/store': 0.7.2 + cookie-es: 1.2.2 + seroval: 1.3.2 + seroval-plugins: 1.3.2(seroval@1.3.2) + tiny-invariant: 1.3.3 + tiny-warning: 1.0.3 + + '@tanstack/router-devtools-core@1.131.26(@tanstack/router-core@1.129.8)(csstype@3.1.3)(solid-js@1.9.7)(tiny-invariant@1.3.3)': + dependencies: + '@tanstack/router-core': 1.129.8 + clsx: 2.1.1 + goober: 2.1.16(csstype@3.1.3) + solid-js: 1.9.7 + tiny-invariant: 1.3.3 + optionalDependencies: + csstype: 3.1.3 + + '@tanstack/router-generator@1.129.8': + dependencies: + '@tanstack/router-core': 1.129.8 + '@tanstack/router-utils': 1.129.7 + '@tanstack/virtual-file-routes': 1.129.7 + prettier: 3.6.2 + recast: 0.23.11 + source-map: 0.7.4 + tsx: 4.20.3 + zod: 3.25.76 + transitivePeerDependencies: + - supports-color + + '@tanstack/router-plugin@1.129.8(@tanstack/react-router@1.129.8(react-dom@18.3.1(react@18.3.1))(react@18.3.1))(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))(webpack@5.99.8(esbuild@0.25.8))': + dependencies: + '@babel/core': 7.28.0 + '@babel/plugin-syntax-jsx': 7.27.1(@babel/core@7.28.0) + '@babel/plugin-syntax-typescript': 7.27.1(@babel/core@7.28.0) + '@babel/template': 7.27.2 + '@babel/traverse': 7.28.0 + '@babel/types': 7.28.1 + '@tanstack/router-core': 1.129.8 + '@tanstack/router-generator': 1.129.8 + '@tanstack/router-utils': 1.129.7 + '@tanstack/virtual-file-routes': 1.129.7 + babel-dead-code-elimination: 1.0.10 + chokidar: 3.6.0 + unplugin: 2.3.5 + zod: 3.25.76 + optionalDependencies: + 
'@tanstack/react-router': 1.129.8(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + vite: 6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + webpack: 5.99.8(esbuild@0.25.8) + transitivePeerDependencies: + - supports-color + + '@tanstack/router-utils@1.129.7': + dependencies: + '@babel/core': 7.28.0 + '@babel/generator': 7.28.0 + '@babel/parser': 7.28.0 + '@babel/preset-typescript': 7.27.1(@babel/core@7.28.0) + ansis: 4.1.0 + diff: 8.0.2 + transitivePeerDependencies: + - supports-color + + '@tanstack/store@0.7.2': {} + + '@tanstack/table-core@8.21.3': {} + + '@tanstack/virtual-core@3.13.12': {} + + '@tanstack/virtual-file-routes@1.129.7': {} + + '@testing-library/dom@10.4.1': + dependencies: + '@babel/code-frame': 7.27.1 + '@babel/runtime': 7.28.2 + '@types/aria-query': 5.0.4 + aria-query: 5.3.0 + dom-accessibility-api: 0.5.16 + lz-string: 1.5.0 + picocolors: 1.1.1 + pretty-format: 27.5.1 + + '@testing-library/jest-dom@6.6.3': + dependencies: + '@adobe/css-tools': 4.4.3 + aria-query: 5.3.2 + chalk: 3.0.0 + css.escape: 1.5.1 + dom-accessibility-api: 0.6.3 + lodash: 4.17.21 + redent: 3.0.0 + + '@testing-library/react@16.3.0(@testing-library/dom@10.4.1)(@types/react-dom@18.3.7(@types/react@18.3.23))(@types/react@18.3.23)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@babel/runtime': 7.28.2 + '@testing-library/dom': 10.4.1 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + '@types/react-dom': 18.3.7(@types/react@18.3.23) + + '@testing-library/user-event@14.6.1(@testing-library/dom@10.4.1)': + dependencies: + '@testing-library/dom': 10.4.1 + + '@textlint/ast-node-types@15.2.0': {} + + '@textlint/linter-formatter@15.2.0': + dependencies: + '@azu/format-text': 1.0.2 + '@azu/style-format': 1.0.1 + '@textlint/module-interop': 15.2.0 + '@textlint/resolver': 15.2.0 + '@textlint/types': 15.2.0 + chalk: 4.1.2 + debug: 4.4.1(supports-color@8.1.1) + js-yaml: 
3.14.1 + lodash: 4.17.21 + pluralize: 2.0.0 + string-width: 4.2.3 + strip-ansi: 6.0.1 + table: 6.9.0 + text-table: 0.2.0 + transitivePeerDependencies: + - supports-color + + '@textlint/module-interop@15.2.0': {} + + '@textlint/resolver@15.2.0': {} + + '@textlint/types@15.2.0': + dependencies: + '@textlint/ast-node-types': 15.2.0 + + '@types/argparse@1.0.38': {} + + '@types/aria-query@5.0.4': {} + + '@types/babel__core@7.20.5': + dependencies: + '@babel/parser': 7.28.0 + '@babel/types': 7.28.1 + '@types/babel__generator': 7.27.0 + '@types/babel__template': 7.4.4 + '@types/babel__traverse': 7.20.7 + + '@types/babel__generator@7.27.0': + dependencies: + '@babel/types': 7.28.1 + + '@types/babel__template@7.4.4': + dependencies: + '@babel/parser': 7.28.0 + '@babel/types': 7.28.1 + + '@types/babel__traverse@7.20.7': + dependencies: + '@babel/types': 7.28.1 + + '@types/chai@5.2.2': + dependencies: + '@types/deep-eql': 4.0.2 + + '@types/command-line-args@5.2.3': {} + + '@types/command-line-usage@5.0.4': {} + + '@types/d3-array@3.2.1': {} + + '@types/d3-axis@3.0.6': + dependencies: + '@types/d3-selection': 3.0.11 + + '@types/d3-brush@3.0.6': + dependencies: + '@types/d3-selection': 3.0.11 + + '@types/d3-chord@3.0.6': {} + + '@types/d3-color@3.1.3': {} + + '@types/d3-contour@3.0.6': + dependencies: + '@types/d3-array': 3.2.1 + '@types/geojson': 7946.0.16 + + '@types/d3-delaunay@6.0.4': {} + + '@types/d3-dispatch@3.0.6': {} + + '@types/d3-drag@3.0.7': + dependencies: + '@types/d3-selection': 3.0.11 + + '@types/d3-dsv@3.0.7': {} + + '@types/d3-ease@3.0.2': {} + + '@types/d3-fetch@3.0.7': + dependencies: + '@types/d3-dsv': 3.0.7 + + '@types/d3-force@3.0.10': {} + + '@types/d3-format@3.0.4': {} + + '@types/d3-geo@3.1.0': + dependencies: + '@types/geojson': 7946.0.16 + + '@types/d3-hierarchy@3.1.7': {} + + '@types/d3-interpolate@3.0.4': + dependencies: + '@types/d3-color': 3.1.3 + + '@types/d3-path@3.1.1': {} + + '@types/d3-polygon@3.0.2': {} + + '@types/d3-quadtree@3.0.6': {} + 
+ '@types/d3-random@3.0.3': {} + + '@types/d3-scale-chromatic@3.1.0': {} + + '@types/d3-scale@4.0.9': + dependencies: + '@types/d3-time': 3.0.4 + + '@types/d3-selection@3.0.11': {} + + '@types/d3-shape@3.1.7': + dependencies: + '@types/d3-path': 3.1.1 + + '@types/d3-time-format@4.0.3': {} + + '@types/d3-time@3.0.4': {} + + '@types/d3-timer@3.0.2': {} + + '@types/d3-transition@3.0.9': + dependencies: + '@types/d3-selection': 3.0.11 + + '@types/d3-zoom@3.0.8': + dependencies: + '@types/d3-interpolate': 3.0.4 + '@types/d3-selection': 3.0.11 + + '@types/d3@7.4.3': + dependencies: + '@types/d3-array': 3.2.1 + '@types/d3-axis': 3.0.6 + '@types/d3-brush': 3.0.6 + '@types/d3-chord': 3.0.6 + '@types/d3-color': 3.1.3 + '@types/d3-contour': 3.0.6 + '@types/d3-delaunay': 6.0.4 + '@types/d3-dispatch': 3.0.6 + '@types/d3-drag': 3.0.7 + '@types/d3-dsv': 3.0.7 + '@types/d3-ease': 3.0.2 + '@types/d3-fetch': 3.0.7 + '@types/d3-force': 3.0.10 + '@types/d3-format': 3.0.4 + '@types/d3-geo': 3.1.0 + '@types/d3-hierarchy': 3.1.7 + '@types/d3-interpolate': 3.0.4 + '@types/d3-path': 3.1.1 + '@types/d3-polygon': 3.0.2 + '@types/d3-quadtree': 3.0.6 + '@types/d3-random': 3.0.3 + '@types/d3-scale': 4.0.9 + '@types/d3-scale-chromatic': 3.1.0 + '@types/d3-selection': 3.0.11 + '@types/d3-shape': 3.1.7 + '@types/d3-time': 3.0.4 + '@types/d3-time-format': 4.0.3 + '@types/d3-timer': 3.0.2 + '@types/d3-transition': 3.0.9 + '@types/d3-zoom': 3.0.8 + + '@types/dagre@0.7.53': {} + + '@types/debug@4.1.12': + dependencies: + '@types/ms': 2.1.0 + + '@types/deep-eql@4.0.2': {} + + '@types/doctrine@0.0.9': {} + + '@types/es-aggregate-error@1.0.6': + dependencies: + '@types/node': 20.11.25 + + '@types/eslint-scope@3.7.7': + dependencies: + '@types/eslint': 9.6.1 + '@types/estree': 1.0.8 + + '@types/eslint@9.6.1': + dependencies: + '@types/estree': 1.0.8 + '@types/json-schema': 7.0.15 + + '@types/estree-jsx@1.0.5': + dependencies: + '@types/estree': 1.0.8 + + '@types/estree@1.0.8': {} + + 
'@types/fs-extra@11.0.4': + dependencies: + '@types/jsonfile': 6.1.4 + '@types/node': 20.11.25 + + '@types/geojson@7946.0.16': {} + + '@types/hast@3.0.4': + dependencies: + '@types/unist': 3.0.3 + + '@types/istanbul-lib-coverage@2.0.6': {} + + '@types/json-schema@7.0.15': {} + + '@types/jsonfile@6.1.4': + dependencies: + '@types/node': 20.11.25 + + '@types/mdast@4.0.4': + dependencies: + '@types/unist': 3.0.3 + + '@types/mdx@2.0.13': {} + + '@types/mocha@10.0.10': {} + + '@types/ms@2.1.0': {} + + '@types/node@20.11.25': + dependencies: + undici-types: 5.26.5 + + '@types/node@20.19.9': + dependencies: + undici-types: 6.21.0 + + '@types/node@24.1.0': + dependencies: + undici-types: 7.8.0 + optional: true + + '@types/normalize-package-data@2.4.4': {} + + '@types/pluralize@0.0.33': {} + + '@types/prop-types@15.7.15': {} + + '@types/react-dom@18.3.7(@types/react@18.3.23)': + dependencies: + '@types/react': 18.3.23 + + '@types/react@18.3.23': + dependencies: + '@types/prop-types': 15.7.15 + csstype: 3.1.3 + + '@types/resolve@1.20.6': {} + + '@types/sarif@2.1.7': {} + + '@types/shell-quote@1.7.5': {} + + '@types/unist@2.0.11': {} + + '@types/unist@3.0.3': {} + + '@types/urijs@1.19.25': {} + + '@types/vscode@1.96.0': {} + + '@typescript-eslint/eslint-plugin@8.38.0(@typescript-eslint/parser@8.38.0(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3))(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3)': + dependencies: + '@eslint-community/regexpp': 4.12.1 + '@typescript-eslint/parser': 8.38.0(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3) + '@typescript-eslint/scope-manager': 8.38.0 + '@typescript-eslint/type-utils': 8.38.0(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3) + '@typescript-eslint/utils': 8.38.0(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3) + '@typescript-eslint/visitor-keys': 8.38.0 + eslint: 9.31.0(jiti@2.4.2) + graphemer: 1.4.0 + ignore: 7.0.5 + natural-compare: 1.4.0 + ts-api-utils: 2.1.0(typescript@5.8.3) + typescript: 5.8.3 + transitivePeerDependencies: + - supports-color + + 
'@typescript-eslint/parser@8.38.0(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3)': + dependencies: + '@typescript-eslint/scope-manager': 8.38.0 + '@typescript-eslint/types': 8.38.0 + '@typescript-eslint/typescript-estree': 8.38.0(typescript@5.8.3) + '@typescript-eslint/visitor-keys': 8.38.0 + debug: 4.4.1 + eslint: 9.31.0(jiti@2.4.2) + typescript: 5.8.3 + transitivePeerDependencies: + - supports-color + + '@typescript-eslint/project-service@8.38.0(typescript@5.8.3)': + dependencies: + '@typescript-eslint/tsconfig-utils': 8.38.0(typescript@5.8.3) + '@typescript-eslint/types': 8.38.0 + debug: 4.4.1 + typescript: 5.8.3 + transitivePeerDependencies: + - supports-color + + '@typescript-eslint/scope-manager@8.38.0': + dependencies: + '@typescript-eslint/types': 8.38.0 + '@typescript-eslint/visitor-keys': 8.38.0 + + '@typescript-eslint/tsconfig-utils@8.38.0(typescript@5.8.3)': + dependencies: + typescript: 5.8.3 + + '@typescript-eslint/type-utils@8.38.0(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3)': + dependencies: + '@typescript-eslint/types': 8.38.0 + '@typescript-eslint/typescript-estree': 8.38.0(typescript@5.8.3) + '@typescript-eslint/utils': 8.38.0(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3) + debug: 4.4.1 + eslint: 9.31.0(jiti@2.4.2) + ts-api-utils: 2.1.0(typescript@5.8.3) + typescript: 5.8.3 + transitivePeerDependencies: + - supports-color + + '@typescript-eslint/types@8.38.0': {} + + '@typescript-eslint/typescript-estree@8.38.0(typescript@5.8.3)': + dependencies: + '@typescript-eslint/project-service': 8.38.0(typescript@5.8.3) + '@typescript-eslint/tsconfig-utils': 8.38.0(typescript@5.8.3) + '@typescript-eslint/types': 8.38.0 + '@typescript-eslint/visitor-keys': 8.38.0 + debug: 4.4.1 + fast-glob: 3.3.3 + is-glob: 4.0.3 + minimatch: 9.0.5 + semver: 7.7.2 + ts-api-utils: 2.1.0(typescript@5.8.3) + typescript: 5.8.3 + transitivePeerDependencies: + - supports-color + + '@typescript-eslint/utils@8.38.0(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3)': + dependencies: + 
'@eslint-community/eslint-utils': 4.7.0(eslint@9.31.0(jiti@2.4.2)) + '@typescript-eslint/scope-manager': 8.38.0 + '@typescript-eslint/types': 8.38.0 + '@typescript-eslint/typescript-estree': 8.38.0(typescript@5.8.3) + eslint: 9.31.0(jiti@2.4.2) + typescript: 5.8.3 + transitivePeerDependencies: + - supports-color + + '@typescript-eslint/visitor-keys@8.38.0': + dependencies: + '@typescript-eslint/types': 8.38.0 + eslint-visitor-keys: 4.2.1 + + '@typespec/ts-http-runtime@0.3.0': + dependencies: + http-proxy-agent: 7.0.2 + https-proxy-agent: 7.0.6 + tslib: 2.8.1 + transitivePeerDependencies: + - supports-color + + '@uidotdev/usehooks@2.4.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + + '@uiw/codemirror-extensions-basic-setup@4.24.1(@codemirror/autocomplete@6.18.6)(@codemirror/commands@6.8.1)(@codemirror/language@6.11.2)(@codemirror/lint@6.8.5)(@codemirror/search@6.5.10)(@codemirror/state@6.5.2)(@codemirror/view@6.38.1)': + dependencies: + '@codemirror/autocomplete': 6.18.6 + '@codemirror/commands': 6.8.1 + '@codemirror/language': 6.11.2 + '@codemirror/lint': 6.8.5 + '@codemirror/search': 6.5.10 + '@codemirror/state': 6.5.2 + '@codemirror/view': 6.38.1 + + '@uiw/react-codemirror@4.24.1(@babel/runtime@7.28.2)(@codemirror/autocomplete@6.18.6)(@codemirror/language@6.11.2)(@codemirror/lint@6.8.5)(@codemirror/search@6.5.10)(@codemirror/state@6.5.2)(@codemirror/theme-one-dark@6.1.2)(@codemirror/view@6.38.1)(codemirror@6.0.1)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@babel/runtime': 7.28.2 + '@codemirror/commands': 6.8.1 + '@codemirror/state': 6.5.2 + '@codemirror/theme-one-dark': 6.1.2 + '@codemirror/view': 6.38.1 + '@uiw/codemirror-extensions-basic-setup': 4.24.1(@codemirror/autocomplete@6.18.6)(@codemirror/commands@6.8.1)(@codemirror/language@6.11.2)(@codemirror/lint@6.8.5)(@codemirror/search@6.5.10)(@codemirror/state@6.5.2)(@codemirror/view@6.38.1) + codemirror: 6.0.1 + react: 
18.3.1 + react-dom: 18.3.1(react@18.3.1) + transitivePeerDependencies: + - '@codemirror/autocomplete' + - '@codemirror/language' + - '@codemirror/lint' + - '@codemirror/search' + + '@ungap/structured-clone@1.3.0': {} + + '@vitejs/plugin-react-swc@3.11.0(@swc/helpers@0.5.17)(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))': + dependencies: + '@rolldown/pluginutils': 1.0.0-beta.27 + '@swc/core': 1.13.2(@swc/helpers@0.5.17) + vite: 6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + transitivePeerDependencies: + - '@swc/helpers' + + '@vitejs/plugin-react@4.7.0(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))': + dependencies: + '@babel/core': 7.28.0 + '@babel/plugin-transform-react-jsx-self': 7.27.1(@babel/core@7.28.0) + '@babel/plugin-transform-react-jsx-source': 7.27.1(@babel/core@7.28.0) + '@rolldown/pluginutils': 1.0.0-beta.27 + '@types/babel__core': 7.20.5 + react-refresh: 0.17.0 + vite: 6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + transitivePeerDependencies: + - supports-color + + '@vitejs/plugin-react@4.7.0(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))': + dependencies: + '@babel/core': 7.28.0 + '@babel/plugin-transform-react-jsx-self': 7.27.1(@babel/core@7.28.0) + '@babel/plugin-transform-react-jsx-source': 7.27.1(@babel/core@7.28.0) + '@rolldown/pluginutils': 1.0.0-beta.27 + '@types/babel__core': 7.20.5 + react-refresh: 0.17.0 + vite: 6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + transitivePeerDependencies: + - supports-color + + '@vitest/browser@3.2.3(playwright@1.54.1)(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))(vitest@3.2.4)': + dependencies: + '@testing-library/dom': 
10.4.1 + '@testing-library/user-event': 14.6.1(@testing-library/dom@10.4.1) + '@vitest/mocker': 3.2.3(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + '@vitest/utils': 3.2.3 + magic-string: 0.30.17 + sirv: 3.0.1 + tinyrainbow: 2.0.0 + vitest: 3.2.4(@types/debug@4.1.12)(@types/node@24.1.0)(@vitest/browser@3.2.3)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + ws: 8.18.3 + optionalDependencies: + playwright: 1.54.1 + transitivePeerDependencies: + - bufferutil + - msw + - utf-8-validate + - vite + + '@vitest/browser@3.2.4(playwright@1.54.1)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))(vitest@3.2.4)': + dependencies: + '@testing-library/dom': 10.4.1 + '@testing-library/user-event': 14.6.1(@testing-library/dom@10.4.1) + '@vitest/mocker': 3.2.4(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + '@vitest/utils': 3.2.4 + magic-string: 0.30.17 + sirv: 3.0.1 + tinyrainbow: 2.0.0 + vitest: 3.2.4(@types/debug@4.1.12)(@types/node@20.11.25)(@vitest/browser@3.2.4)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + ws: 8.18.3 + optionalDependencies: + playwright: 1.54.1 + transitivePeerDependencies: + - bufferutil + - msw + - utf-8-validate + - vite + + '@vitest/browser@3.2.4(playwright@1.54.1)(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))(vitest@3.2.4)': + dependencies: + '@testing-library/dom': 10.4.1 + '@testing-library/user-event': 14.6.1(@testing-library/dom@10.4.1) + '@vitest/mocker': 3.2.4(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + '@vitest/utils': 3.2.4 + magic-string: 0.30.17 + sirv: 3.0.1 + tinyrainbow: 2.0.0 + vitest: 
3.2.4(@types/debug@4.1.12)(@types/node@24.1.0)(@vitest/browser@3.2.4)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + ws: 8.18.3 + optionalDependencies: + playwright: 1.54.1 + transitivePeerDependencies: + - bufferutil + - msw + - utf-8-validate + - vite + optional: true + + '@vitest/coverage-v8@3.2.3(@vitest/browser@3.2.3)(vitest@3.2.4)': + dependencies: + '@ampproject/remapping': 2.3.0 + '@bcoe/v8-coverage': 1.0.2 + ast-v8-to-istanbul: 0.3.3 + debug: 4.4.1(supports-color@8.1.1) + istanbul-lib-coverage: 3.2.2 + istanbul-lib-report: 3.0.1 + istanbul-lib-source-maps: 5.0.6 + istanbul-reports: 3.1.7 + magic-string: 0.30.17 + magicast: 0.3.5 + std-env: 3.9.0 + test-exclude: 7.0.1 + tinyrainbow: 2.0.0 + vitest: 3.2.4(@types/debug@4.1.12)(@types/node@24.1.0)(@vitest/browser@3.2.3)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + optionalDependencies: + '@vitest/browser': 3.2.3(playwright@1.54.1)(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))(vitest@3.2.4) + transitivePeerDependencies: + - supports-color + + '@vitest/expect@3.2.4': + dependencies: + '@types/chai': 5.2.2 + '@vitest/spy': 3.2.4 + '@vitest/utils': 3.2.4 + chai: 5.2.1 + tinyrainbow: 2.0.0 + + '@vitest/mocker@3.2.3(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))': + dependencies: + '@vitest/spy': 3.2.3 + estree-walker: 3.0.3 + magic-string: 0.30.17 + optionalDependencies: + vite: 6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + + '@vitest/mocker@3.2.4(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))': + dependencies: + '@vitest/spy': 3.2.4 + estree-walker: 3.0.3 + magic-string: 0.30.17 + optionalDependencies: + vite: 
6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + + '@vitest/mocker@3.2.4(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))': + dependencies: + '@vitest/spy': 3.2.4 + estree-walker: 3.0.3 + magic-string: 0.30.17 + optionalDependencies: + vite: 6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + + '@vitest/pretty-format@3.2.3': + dependencies: + tinyrainbow: 2.0.0 + + '@vitest/pretty-format@3.2.4': + dependencies: + tinyrainbow: 2.0.0 + + '@vitest/runner@3.2.4': + dependencies: + '@vitest/utils': 3.2.4 + pathe: 2.0.3 + strip-literal: 3.0.0 + + '@vitest/snapshot@3.2.4': + dependencies: + '@vitest/pretty-format': 3.2.4 + magic-string: 0.30.17 + pathe: 2.0.3 + + '@vitest/spy@3.2.3': + dependencies: + tinyspy: 4.0.3 + + '@vitest/spy@3.2.4': + dependencies: + tinyspy: 4.0.3 + + '@vitest/ui@3.2.4(vitest@3.2.4)': + dependencies: + '@vitest/utils': 3.2.4 + fflate: 0.8.2 + flatted: 3.3.3 + pathe: 2.0.3 + sirv: 3.0.1 + tinyglobby: 0.2.14 + tinyrainbow: 2.0.0 + vitest: 3.2.4(@types/debug@4.1.12)(@types/node@20.11.25)(@vitest/browser@3.2.4)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + + '@vitest/utils@3.2.3': + dependencies: + '@vitest/pretty-format': 3.2.3 + loupe: 3.1.4 + tinyrainbow: 2.0.0 + + '@vitest/utils@3.2.4': + dependencies: + '@vitest/pretty-format': 3.2.4 + loupe: 3.1.4 + tinyrainbow: 2.0.0 + + '@volar/language-core@2.4.23': + dependencies: + '@volar/source-map': 2.4.23 + + '@volar/source-map@2.4.23': {} + + '@volar/typescript@2.4.23': + dependencies: + '@volar/language-core': 2.4.23 + path-browserify: 1.0.1 + vscode-uri: 3.1.0 + + '@vscode/python-extension@1.0.5': {} + + '@vscode/test-cli@0.0.10': + dependencies: + '@types/mocha': 10.0.10 + c8: 9.1.0 + chokidar: 3.6.0 + enhanced-resolve: 5.18.2 + glob: 10.4.5 + minimatch: 9.0.5 + mocha: 10.8.2 + 
supports-color: 9.4.0 + yargs: 17.7.2 + + '@vscode/test-electron@2.5.2': + dependencies: + http-proxy-agent: 7.0.2 + https-proxy-agent: 7.0.6 + jszip: 3.10.1 + ora: 8.2.0 + semver: 7.7.2 + transitivePeerDependencies: + - supports-color + + '@vscode/vsce-sign-alpine-arm64@2.0.5': + optional: true + + '@vscode/vsce-sign-alpine-x64@2.0.5': + optional: true + + '@vscode/vsce-sign-darwin-arm64@2.0.5': + optional: true + + '@vscode/vsce-sign-darwin-x64@2.0.5': + optional: true + + '@vscode/vsce-sign-linux-arm64@2.0.5': + optional: true + + '@vscode/vsce-sign-linux-arm@2.0.5': + optional: true + + '@vscode/vsce-sign-linux-x64@2.0.5': + optional: true + + '@vscode/vsce-sign-win32-arm64@2.0.5': + optional: true + + '@vscode/vsce-sign-win32-x64@2.0.5': + optional: true + + '@vscode/vsce-sign@2.0.6': + optionalDependencies: + '@vscode/vsce-sign-alpine-arm64': 2.0.5 + '@vscode/vsce-sign-alpine-x64': 2.0.5 + '@vscode/vsce-sign-darwin-arm64': 2.0.5 + '@vscode/vsce-sign-darwin-x64': 2.0.5 + '@vscode/vsce-sign-linux-arm': 2.0.5 + '@vscode/vsce-sign-linux-arm64': 2.0.5 + '@vscode/vsce-sign-linux-x64': 2.0.5 + '@vscode/vsce-sign-win32-arm64': 2.0.5 + '@vscode/vsce-sign-win32-x64': 2.0.5 + + '@vscode/vsce@3.6.0': + dependencies: + '@azure/identity': 4.10.2 + '@secretlint/node': 10.2.1 + '@secretlint/secretlint-formatter-sarif': 10.2.1 + '@secretlint/secretlint-rule-no-dotenv': 10.2.1 + '@secretlint/secretlint-rule-preset-recommend': 10.2.1 + '@vscode/vsce-sign': 2.0.6 + azure-devops-node-api: 12.5.0 + chalk: 4.1.2 + cheerio: 1.1.2 + cockatiel: 3.2.1 + commander: 12.1.0 + form-data: 4.0.4 + glob: 11.0.3 + hosted-git-info: 4.1.0 + jsonc-parser: 3.3.1 + leven: 3.1.0 + markdown-it: 14.1.0 + mime: 1.6.0 + minimatch: 3.1.2 + parse-semver: 1.1.1 + read: 1.0.7 + secretlint: 10.2.1 + semver: 7.7.2 + tmp: 0.2.3 + typed-rest-client: 1.8.11 + url-join: 4.0.1 + xml2js: 0.5.0 + yauzl: 2.10.0 + yazl: 2.5.1 + optionalDependencies: + keytar: 7.9.0 + transitivePeerDependencies: + - supports-color + + 
'@vue/compiler-core@3.5.18': + dependencies: + '@babel/parser': 7.28.0 + '@vue/shared': 3.5.18 + entities: 4.5.0 + estree-walker: 2.0.2 + source-map-js: 1.2.1 + + '@vue/compiler-dom@3.5.18': + dependencies: + '@vue/compiler-core': 3.5.18 + '@vue/shared': 3.5.18 + + '@vue/compiler-vue2@2.7.16': + dependencies: + de-indent: 1.0.2 + he: 1.2.0 + + '@vue/language-core@2.2.0(typescript@5.8.3)': + dependencies: + '@volar/language-core': 2.4.23 + '@vue/compiler-dom': 3.5.18 + '@vue/compiler-vue2': 2.7.16 + '@vue/shared': 3.5.18 + alien-signals: 0.4.14 + minimatch: 9.0.5 + muggle-string: 0.4.1 + path-browserify: 1.0.1 + optionalDependencies: + typescript: 5.8.3 + + '@vue/shared@3.5.18': {} + + '@webassemblyjs/ast@1.14.1': + dependencies: + '@webassemblyjs/helper-numbers': 1.13.2 + '@webassemblyjs/helper-wasm-bytecode': 1.13.2 + + '@webassemblyjs/floating-point-hex-parser@1.13.2': {} + + '@webassemblyjs/helper-api-error@1.13.2': {} + + '@webassemblyjs/helper-buffer@1.14.1': {} + + '@webassemblyjs/helper-numbers@1.13.2': + dependencies: + '@webassemblyjs/floating-point-hex-parser': 1.13.2 + '@webassemblyjs/helper-api-error': 1.13.2 + '@xtuc/long': 4.2.2 + + '@webassemblyjs/helper-wasm-bytecode@1.13.2': {} + + '@webassemblyjs/helper-wasm-section@1.14.1': + dependencies: + '@webassemblyjs/ast': 1.14.1 + '@webassemblyjs/helper-buffer': 1.14.1 + '@webassemblyjs/helper-wasm-bytecode': 1.13.2 + '@webassemblyjs/wasm-gen': 1.14.1 + + '@webassemblyjs/ieee754@1.13.2': + dependencies: + '@xtuc/ieee754': 1.2.0 + + '@webassemblyjs/leb128@1.13.2': + dependencies: + '@xtuc/long': 4.2.2 + + '@webassemblyjs/utf8@1.13.2': {} + + '@webassemblyjs/wasm-edit@1.14.1': + dependencies: + '@webassemblyjs/ast': 1.14.1 + '@webassemblyjs/helper-buffer': 1.14.1 + '@webassemblyjs/helper-wasm-bytecode': 1.13.2 + '@webassemblyjs/helper-wasm-section': 1.14.1 + '@webassemblyjs/wasm-gen': 1.14.1 + '@webassemblyjs/wasm-opt': 1.14.1 + '@webassemblyjs/wasm-parser': 1.14.1 + '@webassemblyjs/wast-printer': 1.14.1 + 
+ '@webassemblyjs/wasm-gen@1.14.1': + dependencies: + '@webassemblyjs/ast': 1.14.1 + '@webassemblyjs/helper-wasm-bytecode': 1.13.2 + '@webassemblyjs/ieee754': 1.13.2 + '@webassemblyjs/leb128': 1.13.2 + '@webassemblyjs/utf8': 1.13.2 + + '@webassemblyjs/wasm-opt@1.14.1': + dependencies: + '@webassemblyjs/ast': 1.14.1 + '@webassemblyjs/helper-buffer': 1.14.1 + '@webassemblyjs/wasm-gen': 1.14.1 + '@webassemblyjs/wasm-parser': 1.14.1 + + '@webassemblyjs/wasm-parser@1.14.1': + dependencies: + '@webassemblyjs/ast': 1.14.1 + '@webassemblyjs/helper-api-error': 1.13.2 + '@webassemblyjs/helper-wasm-bytecode': 1.13.2 + '@webassemblyjs/ieee754': 1.13.2 + '@webassemblyjs/leb128': 1.13.2 + '@webassemblyjs/utf8': 1.13.2 + + '@webassemblyjs/wast-printer@1.14.1': + dependencies: + '@webassemblyjs/ast': 1.14.1 + '@xtuc/long': 4.2.2 + + '@xtuc/ieee754@1.2.0': {} + + '@xtuc/long@4.2.2': {} + + '@xyflow/react@12.8.4(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1)': + dependencies: + '@xyflow/system': 0.0.68 + classcat: 5.0.5 + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + zustand: 4.5.7(@types/react@18.3.23)(immer@9.0.21)(react@18.3.1) + transitivePeerDependencies: + - '@types/react' + - immer + + '@xyflow/system@0.0.68': + dependencies: + '@types/d3-drag': 3.0.7 + '@types/d3-interpolate': 3.0.4 + '@types/d3-selection': 3.0.11 + '@types/d3-transition': 3.0.9 + '@types/d3-zoom': 3.0.8 + d3-drag: 3.0.0 + d3-interpolate: 3.0.1 + d3-selection: 3.0.0 + d3-zoom: 3.0.0 + + abort-controller@3.0.0: + dependencies: + event-target-shim: 5.0.1 + + acorn-jsx@5.3.2(acorn@8.15.0): + dependencies: + acorn: 8.15.0 + + acorn@8.15.0: {} + + agent-base@7.1.4: {} + + ajv-draft-04@1.0.0(ajv@8.13.0): + optionalDependencies: + ajv: 8.13.0 + + ajv-draft-04@1.0.0(ajv@8.17.1): + optionalDependencies: + ajv: 8.17.1 + + ajv-errors@3.0.0(ajv@8.17.1): + dependencies: + ajv: 8.17.1 + + ajv-formats@2.1.1(ajv@8.17.1): + optionalDependencies: + ajv: 8.17.1 + + 
ajv-formats@3.0.1(ajv@8.13.0): + optionalDependencies: + ajv: 8.13.0 + + ajv-keywords@5.1.0(ajv@8.17.1): + dependencies: + ajv: 8.17.1 + fast-deep-equal: 3.1.3 + + ajv@6.12.6: + dependencies: + fast-deep-equal: 3.1.3 + fast-json-stable-stringify: 2.1.0 + json-schema-traverse: 0.4.1 + uri-js: 4.4.1 + + ajv@8.12.0: + dependencies: + fast-deep-equal: 3.1.3 + json-schema-traverse: 1.0.0 + require-from-string: 2.0.2 + uri-js: 4.4.1 + + ajv@8.13.0: + dependencies: + fast-deep-equal: 3.1.3 + json-schema-traverse: 1.0.0 + require-from-string: 2.0.2 + uri-js: 4.4.1 + + ajv@8.17.1: + dependencies: + fast-deep-equal: 3.1.3 + fast-uri: 3.0.6 + json-schema-traverse: 1.0.0 + require-from-string: 2.0.2 + + alien-signals@0.4.14: {} + + ansi-colors@4.1.3: {} + + ansi-escapes@7.0.0: + dependencies: + environment: 1.1.0 + + ansi-regex@5.0.1: {} + + ansi-regex@6.1.0: {} + + ansi-styles@4.3.0: + dependencies: + color-convert: 2.0.1 + + ansi-styles@5.2.0: {} + + ansi-styles@6.2.1: {} + + ansis@4.1.0: {} + + any-promise@1.3.0: {} + + anymatch@3.1.3: + dependencies: + normalize-path: 3.0.0 + picomatch: 2.3.1 + + apache-arrow@19.0.1: + dependencies: + '@swc/helpers': 0.5.17 + '@types/command-line-args': 5.2.3 + '@types/command-line-usage': 5.0.4 + '@types/node': 20.19.9 + command-line-args: 6.0.1 + command-line-usage: 7.0.3 + flatbuffers: 24.12.23 + json-bignum: 0.0.3 + tslib: 2.8.1 + transitivePeerDependencies: + - '@75lb/nature' + + arg@5.0.2: {} + + argparse@1.0.10: + dependencies: + sprintf-js: 1.0.3 + + argparse@2.0.1: {} + + aria-hidden@1.2.6: + dependencies: + tslib: 2.8.1 + + aria-query@5.3.0: + dependencies: + dequal: 2.0.3 + + aria-query@5.3.2: {} + + array-back@6.2.2: {} + + array-buffer-byte-length@1.0.2: + dependencies: + call-bound: 1.0.4 + is-array-buffer: 3.0.5 + + array-union@2.1.0: {} + + arraybuffer.prototype.slice@1.0.4: + dependencies: + array-buffer-byte-length: 1.0.2 + call-bind: 1.0.8 + define-properties: 1.2.1 + es-abstract: 1.24.0 + es-errors: 1.3.0 + 
get-intrinsic: 1.3.0 + is-array-buffer: 3.0.5 + + assertion-error@2.0.1: {} + + ast-types@0.16.1: + dependencies: + tslib: 2.8.1 + + ast-v8-to-istanbul@0.3.3: + dependencies: + '@jridgewell/trace-mapping': 0.3.29 + estree-walker: 3.0.3 + js-tokens: 9.0.1 + + astral-regex@2.0.0: {} + + astring@1.9.0: {} + + async-function@1.0.0: {} + + asynckit@0.4.0: {} + + autoprefixer@10.4.21(postcss@8.5.6): + dependencies: + browserslist: 4.26.2 + caniuse-lite: 1.0.30001746 + fraction.js: 4.3.7 + normalize-range: 0.1.2 + picocolors: 1.1.1 + postcss: 8.5.6 + postcss-value-parser: 4.2.0 + + available-typed-arrays@1.0.7: + dependencies: + possible-typed-array-names: 1.1.0 + + axe-core@4.10.3: {} + + azure-devops-node-api@12.5.0: + dependencies: + tunnel: 0.0.6 + typed-rest-client: 1.8.11 + + babel-dead-code-elimination@1.0.10: + dependencies: + '@babel/core': 7.28.0 + '@babel/parser': 7.28.0 + '@babel/traverse': 7.28.0 + '@babel/types': 7.28.1 + transitivePeerDependencies: + - supports-color + + bail@2.0.2: {} + + balanced-match@1.0.2: {} + + base64-js@1.5.1: + optional: true + + baseline-browser-mapping@2.8.9: {} + + better-opn@3.0.2: + dependencies: + open: 8.4.2 + + binary-extensions@2.3.0: {} + + binaryextensions@6.11.0: + dependencies: + editions: 6.21.0 + + bl@4.1.0: + dependencies: + buffer: 5.7.1 + inherits: 2.0.4 + readable-stream: 3.6.2 + optional: true + + boolbase@1.0.0: {} + + boundary@2.0.0: {} + + brace-expansion@1.1.12: + dependencies: + balanced-match: 1.0.2 + concat-map: 0.0.1 + + brace-expansion@2.0.2: + dependencies: + balanced-match: 1.0.2 + + braces@3.0.3: + dependencies: + fill-range: 7.1.1 + + browser-stdout@1.3.1: {} + + browserslist@4.26.2: + dependencies: + baseline-browser-mapping: 2.8.9 + caniuse-lite: 1.0.30001746 + electron-to-chromium: 1.5.227 + node-releases: 2.0.21 + update-browserslist-db: 1.1.3(browserslist@4.26.2) + + buffer-crc32@0.2.13: {} + + buffer-equal-constant-time@1.0.1: {} + + buffer-from@1.1.2: {} + + buffer@5.7.1: + dependencies: + 
base64-js: 1.5.1 + ieee754: 1.2.1 + optional: true + + bundle-name@4.1.0: + dependencies: + run-applescript: 7.0.0 + + c8@9.1.0: + dependencies: + '@bcoe/v8-coverage': 0.2.3 + '@istanbuljs/schema': 0.1.3 + find-up: 5.0.0 + foreground-child: 3.3.1 + istanbul-lib-coverage: 3.2.2 + istanbul-lib-report: 3.0.1 + istanbul-reports: 3.1.7 + test-exclude: 6.0.0 + v8-to-istanbul: 9.3.0 + yargs: 17.7.2 + yargs-parser: 21.1.1 + + cac@6.7.14: {} + + call-bind-apply-helpers@1.0.2: + dependencies: + es-errors: 1.3.0 + function-bind: 1.1.2 + + call-bind@1.0.8: + dependencies: + call-bind-apply-helpers: 1.0.2 + es-define-property: 1.0.1 + get-intrinsic: 1.3.0 + set-function-length: 1.2.2 + + call-bound@1.0.4: + dependencies: + call-bind-apply-helpers: 1.0.2 + get-intrinsic: 1.3.0 + + call-me-maybe@1.0.2: {} + + callsites@3.1.0: {} + + camelcase-css@2.0.1: {} + + camelcase@6.3.0: {} + + caniuse-lite@1.0.30001746: {} + + ccount@2.0.1: {} + + chai@5.2.1: + dependencies: + assertion-error: 2.0.1 + check-error: 2.1.1 + deep-eql: 5.0.2 + loupe: 3.1.4 + pathval: 2.0.1 + + chalk-template@0.4.0: + dependencies: + chalk: 4.1.2 + + chalk-template@1.1.0: + dependencies: + chalk: 5.4.1 + + chalk@3.0.0: + dependencies: + ansi-styles: 4.3.0 + supports-color: 7.2.0 + + chalk@4.1.2: + dependencies: + ansi-styles: 4.3.0 + supports-color: 7.2.0 + + chalk@5.4.1: {} + + character-entities-html4@2.1.0: {} + + character-entities-legacy@3.0.0: {} + + character-entities@2.0.2: {} + + character-reference-invalid@2.0.1: {} + + check-error@2.1.1: {} + + cheerio-select@2.1.0: + dependencies: + boolbase: 1.0.0 + css-select: 5.2.2 + css-what: 6.2.2 + domelementtype: 2.3.0 + domhandler: 5.0.3 + domutils: 3.2.2 + + cheerio@1.1.2: + dependencies: + cheerio-select: 2.1.0 + dom-serializer: 2.0.0 + domhandler: 5.0.3 + domutils: 3.2.2 + encoding-sniffer: 0.2.1 + htmlparser2: 10.0.0 + parse5: 7.3.0 + parse5-htmlparser2-tree-adapter: 7.1.0 + parse5-parser-stream: 7.1.2 + undici: 7.12.0 + whatwg-mimetype: 4.0.0 + + 
chokidar@3.6.0: + dependencies: + anymatch: 3.1.3 + braces: 3.0.3 + glob-parent: 5.1.2 + is-binary-path: 2.1.0 + is-glob: 4.0.3 + normalize-path: 3.0.0 + readdirp: 3.6.0 + optionalDependencies: + fsevents: 2.3.3 + + chokidar@4.0.3: + dependencies: + readdirp: 4.1.2 + + chownr@1.1.4: + optional: true + + chownr@3.0.0: {} + + chromatic@12.2.0: {} + + chrome-trace-event@1.0.4: {} + + class-variance-authority@0.7.1: + dependencies: + clsx: 2.1.1 + + classcat@5.0.5: {} + + cli-cursor@5.0.0: + dependencies: + restore-cursor: 5.1.0 + + cli-spinners@2.9.2: {} + + cliui@7.0.4: + dependencies: + string-width: 4.2.3 + strip-ansi: 6.0.1 + wrap-ansi: 7.0.0 + + cliui@8.0.1: + dependencies: + string-width: 4.2.3 + strip-ansi: 6.0.1 + wrap-ansi: 7.0.0 + + clsx@2.1.1: {} + + cockatiel@3.2.1: {} + + codemirror@6.0.1: + dependencies: + '@codemirror/autocomplete': 6.19.0 + '@codemirror/commands': 6.8.1 + '@codemirror/language': 6.11.3 + '@codemirror/lint': 6.8.5 + '@codemirror/search': 6.5.10 + '@codemirror/state': 6.5.2 + '@codemirror/view': 6.38.4 + + color-convert@2.0.1: + dependencies: + color-name: 1.1.4 + + color-name@1.1.4: {} + + combined-stream@1.0.8: + dependencies: + delayed-stream: 1.0.0 + + comma-separated-tokens@2.0.3: {} + + command-line-args@6.0.1: + dependencies: + array-back: 6.2.2 + find-replace: 5.0.2 + lodash.camelcase: 4.3.0 + typical: 7.3.0 + + command-line-usage@7.0.3: + dependencies: + array-back: 6.2.2 + chalk-template: 0.4.0 + table-layout: 4.1.1 + typical: 7.3.0 + + commander@12.1.0: {} + + commander@13.1.0: {} + + commander@2.20.3: {} + + commander@4.1.1: {} + + compare-versions@6.1.1: {} + + concat-map@0.0.1: {} + + confbox@0.1.8: {} + + confbox@0.2.2: {} + + convert-source-map@2.0.0: {} + + cookie-es@1.2.2: {} + + cookie@1.0.2: {} + + core-util-is@1.0.3: {} + + cosmiconfig@9.0.0(typescript@5.8.3): + dependencies: + env-paths: 2.2.1 + import-fresh: 3.3.1 + js-yaml: 4.1.0 + parse-json: 5.2.0 + optionalDependencies: + typescript: 5.8.3 + + crelt@1.0.6: {} + 
+ cronstrue@3.3.0: {} + + cross-spawn@7.0.6: + dependencies: + path-key: 3.1.1 + shebang-command: 2.0.0 + which: 2.0.2 + + css-select@5.2.2: + dependencies: + boolbase: 1.0.0 + css-what: 6.2.2 + domhandler: 5.0.3 + domutils: 3.2.2 + nth-check: 2.1.1 + + css-what@6.2.2: {} + + css.escape@1.5.1: {} + + cssesc@3.0.0: {} + + cssstyle@4.6.0: + dependencies: + '@asamuzakjp/css-color': 3.2.0 + rrweb-cssom: 0.8.0 + + csstype@3.1.3: {} + + d3-color@3.1.0: {} + + d3-dispatch@3.0.1: {} + + d3-drag@3.0.0: + dependencies: + d3-dispatch: 3.0.1 + d3-selection: 3.0.0 + + d3-ease@3.0.1: {} + + d3-interpolate@3.0.1: + dependencies: + d3-color: 3.1.0 + + d3-selection@3.0.0: {} + + d3-timer@3.0.1: {} + + d3-transition@3.0.1(d3-selection@3.0.0): + dependencies: + d3-color: 3.1.0 + d3-dispatch: 3.0.1 + d3-ease: 3.0.1 + d3-interpolate: 3.0.1 + d3-selection: 3.0.0 + d3-timer: 3.0.1 + + d3-zoom@3.0.0: + dependencies: + d3-dispatch: 3.0.1 + d3-drag: 3.0.0 + d3-interpolate: 3.0.1 + d3-selection: 3.0.0 + d3-transition: 3.0.1(d3-selection@3.0.0) + + dagre@0.8.5: + dependencies: + graphlib: 2.1.8 + lodash: 4.17.21 + + data-urls@5.0.0: + dependencies: + whatwg-mimetype: 4.0.0 + whatwg-url: 14.2.0 + + data-view-buffer@1.0.2: + dependencies: + call-bound: 1.0.4 + es-errors: 1.3.0 + is-data-view: 1.0.2 + + data-view-byte-length@1.0.2: + dependencies: + call-bound: 1.0.4 + es-errors: 1.3.0 + is-data-view: 1.0.2 + + data-view-byte-offset@1.0.1: + dependencies: + call-bound: 1.0.4 + es-errors: 1.3.0 + is-data-view: 1.0.2 + + de-indent@1.0.2: {} + + debug@4.4.1: + dependencies: + ms: 2.1.3 + + debug@4.4.1(supports-color@8.1.1): + dependencies: + ms: 2.1.3 + optionalDependencies: + supports-color: 8.1.1 + + decamelize@4.0.0: {} + + decimal.js@10.6.0: {} + + decode-named-character-reference@1.2.0: + dependencies: + character-entities: 2.0.2 + + decompress-response@6.0.0: + dependencies: + mimic-response: 3.1.0 + optional: true + + deep-eql@5.0.2: {} + + deep-extend@0.6.0: + optional: true + + 
deep-is@0.1.4: {} + + deepmerge@4.3.1: {} + + default-browser-id@5.0.0: {} + + default-browser@5.2.1: + dependencies: + bundle-name: 4.1.0 + default-browser-id: 5.0.0 + + define-data-property@1.1.4: + dependencies: + es-define-property: 1.0.1 + es-errors: 1.3.0 + gopd: 1.2.0 + + define-lazy-prop@2.0.0: {} + + define-lazy-prop@3.0.0: {} + + define-properties@1.2.1: + dependencies: + define-data-property: 1.1.4 + has-property-descriptors: 1.0.2 + object-keys: 1.1.1 + + delayed-stream@1.0.0: {} + + dependency-graph@0.11.0: {} + + dequal@2.0.3: {} + + detect-libc@2.0.4: {} + + detect-node-es@1.1.0: {} + + devlop@1.1.0: + dependencies: + dequal: 2.0.3 + + didyoumean@1.2.2: {} + + diff@5.2.0: {} + + diff@8.0.2: {} + + dir-glob@3.0.1: + dependencies: + path-type: 4.0.0 + + dlv@1.1.3: {} + + dnd-core@16.0.1: + dependencies: + '@react-dnd/asap': 5.0.2 + '@react-dnd/invariant': 4.0.2 + redux: 4.2.1 + + doctrine@3.0.0: + dependencies: + esutils: 2.0.3 + + dom-accessibility-api@0.5.16: {} + + dom-accessibility-api@0.6.3: {} + + dom-serializer@2.0.0: + dependencies: + domelementtype: 2.3.0 + domhandler: 5.0.3 + entities: 4.5.0 + + domelementtype@2.3.0: {} + + domhandler@5.0.3: + dependencies: + domelementtype: 2.3.0 + + domutils@3.2.2: + dependencies: + dom-serializer: 2.0.0 + domelementtype: 2.3.0 + domhandler: 5.0.3 + + dunder-proto@1.0.1: + dependencies: + call-bind-apply-helpers: 1.0.2 + es-errors: 1.3.0 + gopd: 1.2.0 + + eastasianwidth@0.2.0: {} + + ecdsa-sig-formatter@1.0.11: + dependencies: + safe-buffer: 5.2.1 + + editions@6.21.0: + dependencies: + version-range: 4.14.0 + + effect@3.17.9: + dependencies: + '@standard-schema/spec': 1.0.0 + fast-check: 3.23.2 + + electron-to-chromium@1.5.227: {} + + elkjs@0.8.2: {} + + emoji-regex@10.4.0: {} + + emoji-regex@8.0.0: {} + + emoji-regex@9.2.2: {} + + encoding-sniffer@0.2.1: + dependencies: + iconv-lite: 0.6.3 + whatwg-encoding: 3.1.1 + + end-of-stream@1.4.5: + dependencies: + once: 1.4.0 + optional: true + + 
enhanced-resolve@5.18.2: + dependencies: + graceful-fs: 4.2.11 + tapable: 2.2.2 + + enhanced-resolve@5.18.3: + dependencies: + graceful-fs: 4.2.11 + tapable: 2.2.3 + + enquirer@2.4.1: + dependencies: + ansi-colors: 4.1.3 + strip-ansi: 6.0.1 + + entities@4.5.0: {} + + entities@6.0.1: {} + + env-paths@2.2.1: {} + + environment@1.1.0: {} + + error-ex@1.3.2: + dependencies: + is-arrayish: 0.2.1 + + es-abstract@1.24.0: + dependencies: + array-buffer-byte-length: 1.0.2 + arraybuffer.prototype.slice: 1.0.4 + available-typed-arrays: 1.0.7 + call-bind: 1.0.8 + call-bound: 1.0.4 + data-view-buffer: 1.0.2 + data-view-byte-length: 1.0.2 + data-view-byte-offset: 1.0.1 + es-define-property: 1.0.1 + es-errors: 1.3.0 + es-object-atoms: 1.1.1 + es-set-tostringtag: 2.1.0 + es-to-primitive: 1.3.0 + function.prototype.name: 1.1.8 + get-intrinsic: 1.3.0 + get-proto: 1.0.1 + get-symbol-description: 1.1.0 + globalthis: 1.0.4 + gopd: 1.2.0 + has-property-descriptors: 1.0.2 + has-proto: 1.2.0 + has-symbols: 1.1.0 + hasown: 2.0.2 + internal-slot: 1.1.0 + is-array-buffer: 3.0.5 + is-callable: 1.2.7 + is-data-view: 1.0.2 + is-negative-zero: 2.0.3 + is-regex: 1.2.1 + is-set: 2.0.3 + is-shared-array-buffer: 1.0.4 + is-string: 1.1.1 + is-typed-array: 1.1.15 + is-weakref: 1.1.1 + math-intrinsics: 1.1.0 + object-inspect: 1.13.4 + object-keys: 1.1.1 + object.assign: 4.1.7 + own-keys: 1.0.1 + regexp.prototype.flags: 1.5.4 + safe-array-concat: 1.1.3 + safe-push-apply: 1.0.0 + safe-regex-test: 1.1.0 + set-proto: 1.0.0 + stop-iteration-iterator: 1.1.0 + string.prototype.trim: 1.2.10 + string.prototype.trimend: 1.0.9 + string.prototype.trimstart: 1.0.8 + typed-array-buffer: 1.0.3 + typed-array-byte-length: 1.0.3 + typed-array-byte-offset: 1.0.4 + typed-array-length: 1.0.7 + unbox-primitive: 1.1.0 + which-typed-array: 1.1.19 + + es-aggregate-error@1.0.14: + dependencies: + define-data-property: 1.1.4 + define-properties: 1.2.1 + es-abstract: 1.24.0 + es-errors: 1.3.0 + function-bind: 1.1.2 + globalthis: 
1.0.4 + has-property-descriptors: 1.0.2 + set-function-name: 2.0.2 + + es-define-property@1.0.1: {} + + es-errors@1.3.0: {} + + es-module-lexer@1.7.0: {} + + es-object-atoms@1.1.1: + dependencies: + es-errors: 1.3.0 + + es-set-tostringtag@2.1.0: + dependencies: + es-errors: 1.3.0 + get-intrinsic: 1.3.0 + has-tostringtag: 1.0.2 + hasown: 2.0.2 + + es-to-primitive@1.3.0: + dependencies: + is-callable: 1.2.7 + is-date-object: 1.1.0 + is-symbol: 1.1.1 + + es6-promise@3.3.1: {} + + esbuild-register@3.6.0(esbuild@0.25.8): + dependencies: + debug: 4.4.1 + esbuild: 0.25.8 + transitivePeerDependencies: + - supports-color + + esbuild@0.25.8: + optionalDependencies: + '@esbuild/aix-ppc64': 0.25.8 + '@esbuild/android-arm': 0.25.8 + '@esbuild/android-arm64': 0.25.8 + '@esbuild/android-x64': 0.25.8 + '@esbuild/darwin-arm64': 0.25.8 + '@esbuild/darwin-x64': 0.25.8 + '@esbuild/freebsd-arm64': 0.25.8 + '@esbuild/freebsd-x64': 0.25.8 + '@esbuild/linux-arm': 0.25.8 + '@esbuild/linux-arm64': 0.25.8 + '@esbuild/linux-ia32': 0.25.8 + '@esbuild/linux-loong64': 0.25.8 + '@esbuild/linux-mips64el': 0.25.8 + '@esbuild/linux-ppc64': 0.25.8 + '@esbuild/linux-riscv64': 0.25.8 + '@esbuild/linux-s390x': 0.25.8 + '@esbuild/linux-x64': 0.25.8 + '@esbuild/netbsd-arm64': 0.25.8 + '@esbuild/netbsd-x64': 0.25.8 + '@esbuild/openbsd-arm64': 0.25.8 + '@esbuild/openbsd-x64': 0.25.8 + '@esbuild/openharmony-arm64': 0.25.8 + '@esbuild/sunos-x64': 0.25.8 + '@esbuild/win32-arm64': 0.25.8 + '@esbuild/win32-ia32': 0.25.8 + '@esbuild/win32-x64': 0.25.8 + + escalade@3.2.0: {} + + escape-string-regexp@4.0.0: {} + + eslint-plugin-react-hooks@5.2.0(eslint@9.31.0(jiti@2.4.2)): + dependencies: + eslint: 9.31.0(jiti@2.4.2) + + eslint-plugin-storybook@9.1.5(eslint@9.31.0(jiti@2.4.2))(storybook@9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)))(typescript@5.8.3): + dependencies: + '@typescript-eslint/utils': 
8.38.0(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3) + eslint: 9.31.0(jiti@2.4.2) + storybook: 9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + transitivePeerDependencies: + - supports-color + - typescript + + eslint-scope@5.1.1: + dependencies: + esrecurse: 4.3.0 + estraverse: 4.3.0 + + eslint-scope@8.4.0: + dependencies: + esrecurse: 4.3.0 + estraverse: 5.3.0 + + eslint-visitor-keys@3.4.3: {} + + eslint-visitor-keys@4.2.1: {} + + eslint@9.31.0(jiti@2.4.2): + dependencies: + '@eslint-community/eslint-utils': 4.7.0(eslint@9.31.0(jiti@2.4.2)) + '@eslint-community/regexpp': 4.12.1 + '@eslint/config-array': 0.21.0 + '@eslint/config-helpers': 0.3.0 + '@eslint/core': 0.15.1 + '@eslint/eslintrc': 3.3.1 + '@eslint/js': 9.31.0 + '@eslint/plugin-kit': 0.3.4 + '@humanfs/node': 0.16.6 + '@humanwhocodes/module-importer': 1.0.1 + '@humanwhocodes/retry': 0.4.3 + '@types/estree': 1.0.8 + '@types/json-schema': 7.0.15 + ajv: 6.12.6 + chalk: 4.1.2 + cross-spawn: 7.0.6 + debug: 4.4.1 + escape-string-regexp: 4.0.0 + eslint-scope: 8.4.0 + eslint-visitor-keys: 4.2.1 + espree: 10.4.0 + esquery: 1.6.0 + esutils: 2.0.3 + fast-deep-equal: 3.1.3 + file-entry-cache: 8.0.0 + find-up: 5.0.0 + glob-parent: 6.0.2 + ignore: 5.3.2 + imurmurhash: 0.1.4 + is-glob: 4.0.3 + json-stable-stringify-without-jsonify: 1.0.1 + lodash.merge: 4.6.2 + minimatch: 3.1.2 + natural-compare: 1.4.0 + optionator: 0.9.4 + optionalDependencies: + jiti: 2.4.2 + transitivePeerDependencies: + - supports-color + + espree@10.4.0: + dependencies: + acorn: 8.15.0 + acorn-jsx: 5.3.2(acorn@8.15.0) + eslint-visitor-keys: 4.2.1 + + esprima@4.0.1: {} + + esquery@1.6.0: + dependencies: + estraverse: 5.3.0 + + esrecurse@4.3.0: + dependencies: + estraverse: 5.3.0 + + estraverse@4.3.0: {} + + estraverse@5.3.0: {} + + estree-util-is-identifier-name@3.0.0: {} + + estree-walker@2.0.2: {} + + estree-walker@3.0.3: + dependencies: + 
'@types/estree': 1.0.8 + + esutils@2.0.3: {} + + event-target-shim@5.0.1: {} + + events@3.3.0: {} + + execa@5.1.1: + dependencies: + cross-spawn: 7.0.6 + get-stream: 6.0.1 + human-signals: 2.1.0 + is-stream: 2.0.1 + merge-stream: 2.0.0 + npm-run-path: 4.0.1 + onetime: 5.1.2 + signal-exit: 3.0.7 + strip-final-newline: 2.0.0 + + expand-template@2.0.3: + optional: true + + expect-type@1.2.2: {} + + exsolve@1.0.7: {} + + extend@3.0.2: {} + + fast-check@3.23.2: + dependencies: + pure-rand: 6.1.0 + + fast-deep-equal@3.1.3: {} + + fast-glob@3.3.3: + dependencies: + '@nodelib/fs.stat': 2.0.5 + '@nodelib/fs.walk': 1.2.8 + glob-parent: 5.1.2 + merge2: 1.4.1 + micromatch: 4.0.8 + + fast-json-stable-stringify@2.1.0: {} + + fast-levenshtein@2.0.6: {} + + fast-memoize@2.5.2: {} + + fast-safe-stringify@2.1.1: {} + + fast-uri@3.0.6: {} + + fastq@1.19.1: + dependencies: + reusify: 1.1.0 + + fd-slicer@1.1.0: + dependencies: + pend: 1.2.0 + + fdir@6.4.6(picomatch@4.0.3): + optionalDependencies: + picomatch: 4.0.3 + + fflate@0.8.2: {} + + file-entry-cache@8.0.0: + dependencies: + flat-cache: 4.0.1 + + filesize@10.1.6: {} + + fill-range@7.1.1: + dependencies: + to-regex-range: 5.0.1 + + find-replace@5.0.2: {} + + find-up@5.0.0: + dependencies: + locate-path: 6.0.0 + path-exists: 4.0.0 + + find-up@7.0.0: + dependencies: + locate-path: 7.2.0 + path-exists: 5.0.0 + unicorn-magic: 0.1.0 + + flat-cache@4.0.1: + dependencies: + flatted: 3.3.3 + keyv: 4.5.4 + + flat@5.0.2: {} + + flatbuffers@24.12.23: {} + + flatted@3.3.3: {} + + for-each@0.3.5: + dependencies: + is-callable: 1.2.7 + + foreground-child@3.3.1: + dependencies: + cross-spawn: 7.0.6 + signal-exit: 4.1.0 + + form-data@4.0.4: + dependencies: + asynckit: 0.4.0 + combined-stream: 1.0.8 + es-set-tostringtag: 2.1.0 + hasown: 2.0.2 + mime-types: 2.1.35 + + fraction.js@4.3.7: {} + + fs-constants@1.0.0: + optional: true + + fs-extra@11.3.0: + dependencies: + graceful-fs: 4.2.11 + jsonfile: 6.1.0 + universalify: 2.0.1 + + 
fs.realpath@1.0.0: {} + + fsevents@2.3.2: + optional: true + + fsevents@2.3.3: + optional: true + + function-bind@1.1.2: {} + + function.prototype.name@1.1.8: + dependencies: + call-bind: 1.0.8 + call-bound: 1.0.4 + define-properties: 1.2.1 + functions-have-names: 1.2.3 + hasown: 2.0.2 + is-callable: 1.2.7 + + functions-have-names@1.2.3: {} + + fuse.js@7.1.0: {} + + gensync@1.0.0-beta.2: {} + + get-caller-file@2.0.5: {} + + get-east-asian-width@1.3.0: {} + + get-intrinsic@1.3.0: + dependencies: + call-bind-apply-helpers: 1.0.2 + es-define-property: 1.0.1 + es-errors: 1.3.0 + es-object-atoms: 1.1.1 + function-bind: 1.1.2 + get-proto: 1.0.1 + gopd: 1.2.0 + has-symbols: 1.1.0 + hasown: 2.0.2 + math-intrinsics: 1.1.0 + + get-nonce@1.0.1: {} + + get-proto@1.0.1: + dependencies: + dunder-proto: 1.0.1 + es-object-atoms: 1.1.1 + + get-stream@6.0.1: {} + + get-symbol-description@1.1.0: + dependencies: + call-bound: 1.0.4 + es-errors: 1.3.0 + get-intrinsic: 1.3.0 + + get-tsconfig@4.10.1: + dependencies: + resolve-pkg-maps: 1.0.0 + + github-from-package@0.0.0: + optional: true + + glob-parent@5.1.2: + dependencies: + is-glob: 4.0.3 + + glob-parent@6.0.2: + dependencies: + is-glob: 4.0.3 + + glob-to-regexp@0.4.1: {} + + glob@10.4.5: + dependencies: + foreground-child: 3.3.1 + jackspeak: 3.4.3 + minimatch: 9.0.5 + minipass: 7.1.2 + package-json-from-dist: 1.0.1 + path-scurry: 1.11.1 + + glob@11.0.3: + dependencies: + foreground-child: 3.3.1 + jackspeak: 4.1.1 + minimatch: 10.0.3 + minipass: 7.1.2 + package-json-from-dist: 1.0.1 + path-scurry: 2.0.0 + + glob@7.2.3: + dependencies: + fs.realpath: 1.0.0 + inflight: 1.0.6 + inherits: 2.0.4 + minimatch: 3.1.2 + once: 1.4.0 + path-is-absolute: 1.0.1 + + glob@8.1.0: + dependencies: + fs.realpath: 1.0.0 + inflight: 1.0.6 + inherits: 2.0.4 + minimatch: 5.1.6 + once: 1.4.0 + + globals@14.0.0: {} + + globals@16.3.0: {} + + globalthis@1.0.4: + dependencies: + define-properties: 1.2.1 + gopd: 1.2.0 + + globby@11.1.0: + dependencies: + 
array-union: 2.1.0 + dir-glob: 3.0.1 + fast-glob: 3.3.3 + ignore: 5.3.2 + merge2: 1.4.1 + slash: 3.0.0 + + globby@14.1.0: + dependencies: + '@sindresorhus/merge-streams': 2.3.0 + fast-glob: 3.3.3 + ignore: 7.0.5 + path-type: 6.0.0 + slash: 5.1.0 + unicorn-magic: 0.3.0 + + goober@2.1.16(csstype@3.1.3): + dependencies: + csstype: 3.1.3 + + gopd@1.2.0: {} + + graceful-fs@4.2.11: {} + + graphemer@1.4.0: {} + + graphlib@2.1.8: + dependencies: + lodash: 4.17.21 + + has-bigints@1.1.0: {} + + has-flag@4.0.0: {} + + has-property-descriptors@1.0.2: + dependencies: + es-define-property: 1.0.1 + + has-proto@1.2.0: + dependencies: + dunder-proto: 1.0.1 + + has-symbols@1.1.0: {} + + has-tostringtag@1.0.2: + dependencies: + has-symbols: 1.1.0 + + hasown@2.0.2: + dependencies: + function-bind: 1.1.2 + + hast-util-to-jsx-runtime@2.3.6: + dependencies: + '@types/estree': 1.0.8 + '@types/hast': 3.0.4 + '@types/unist': 3.0.3 + comma-separated-tokens: 2.0.3 + devlop: 1.1.0 + estree-util-is-identifier-name: 3.0.0 + hast-util-whitespace: 3.0.0 + mdast-util-mdx-expression: 2.0.1 + mdast-util-mdx-jsx: 3.2.0 + mdast-util-mdxjs-esm: 2.0.1 + property-information: 7.1.0 + space-separated-tokens: 2.0.2 + style-to-js: 1.1.17 + unist-util-position: 5.0.0 + vfile-message: 4.0.2 + transitivePeerDependencies: + - supports-color + + hast-util-whitespace@3.0.0: + dependencies: + '@types/hast': 3.0.4 + + he@1.2.0: {} + + hoist-non-react-statics@3.3.2: + dependencies: + react-is: 16.13.1 + + hosted-git-info@4.1.0: + dependencies: + lru-cache: 6.0.0 + + hosted-git-info@7.0.2: + dependencies: + lru-cache: 10.4.3 + + hosted-git-info@8.1.0: + dependencies: + lru-cache: 10.4.3 + + html-encoding-sniffer@4.0.0: + dependencies: + whatwg-encoding: 3.1.1 + + html-escaper@2.0.2: {} + + html-url-attributes@3.0.1: {} + + htmlparser2@10.0.0: + dependencies: + domelementtype: 2.3.0 + domhandler: 5.0.3 + domutils: 3.2.2 + entities: 6.0.1 + + http-proxy-agent@7.0.2: + dependencies: + agent-base: 7.1.4 + debug: 4.4.1 + 
transitivePeerDependencies: + - supports-color + + http2-client@1.3.5: {} + + https-proxy-agent@7.0.6: + dependencies: + agent-base: 7.1.4 + debug: 4.4.1 + transitivePeerDependencies: + - supports-color + + human-signals@2.1.0: {} + + iconv-lite@0.6.3: + dependencies: + safer-buffer: 2.1.2 + + ieee754@1.2.1: + optional: true + + ignore@5.3.2: {} + + ignore@7.0.5: {} + + immediate@3.0.6: {} + + immer@9.0.21: {} + + import-fresh@3.3.1: + dependencies: + parent-module: 1.0.1 + resolve-from: 4.0.0 + + import-lazy@4.0.0: {} + + imurmurhash@0.1.4: {} + + indent-string@4.0.0: {} + + index-to-position@1.1.0: {} + + inflight@1.0.6: + dependencies: + once: 1.4.0 + wrappy: 1.0.2 + + inherits@2.0.4: {} + + ini@1.3.8: + optional: true + + inline-style-parser@0.2.4: {} + + internal-slot@1.1.0: + dependencies: + es-errors: 1.3.0 + hasown: 2.0.2 + side-channel: 1.1.0 + + is-alphabetical@2.0.1: {} + + is-alphanumerical@2.0.1: + dependencies: + is-alphabetical: 2.0.1 + is-decimal: 2.0.1 + + is-array-buffer@3.0.5: + dependencies: + call-bind: 1.0.8 + call-bound: 1.0.4 + get-intrinsic: 1.3.0 + + is-arrayish@0.2.1: {} + + is-async-function@2.1.1: + dependencies: + async-function: 1.0.0 + call-bound: 1.0.4 + get-proto: 1.0.1 + has-tostringtag: 1.0.2 + safe-regex-test: 1.1.0 + + is-bigint@1.1.0: + dependencies: + has-bigints: 1.1.0 + + is-binary-path@2.1.0: + dependencies: + binary-extensions: 2.3.0 + + is-boolean-object@1.2.2: + dependencies: + call-bound: 1.0.4 + has-tostringtag: 1.0.2 + + is-callable@1.2.7: {} + + is-core-module@2.16.1: + dependencies: + hasown: 2.0.2 + + is-data-view@1.0.2: + dependencies: + call-bound: 1.0.4 + get-intrinsic: 1.3.0 + is-typed-array: 1.1.15 + + is-date-object@1.1.0: + dependencies: + call-bound: 1.0.4 + has-tostringtag: 1.0.2 + + is-decimal@2.0.1: {} + + is-docker@2.2.1: {} + + is-docker@3.0.0: {} + + is-extglob@2.1.1: {} + + is-finalizationregistry@1.1.1: + dependencies: + call-bound: 1.0.4 + + is-fullwidth-code-point@3.0.0: {} + + 
is-generator-function@1.1.0: + dependencies: + call-bound: 1.0.4 + get-proto: 1.0.1 + has-tostringtag: 1.0.2 + safe-regex-test: 1.1.0 + + is-glob@4.0.3: + dependencies: + is-extglob: 2.1.1 + + is-hexadecimal@2.0.1: {} + + is-inside-container@1.0.0: + dependencies: + is-docker: 3.0.0 + + is-interactive@2.0.0: {} + + is-map@2.0.3: {} + + is-negative-zero@2.0.3: {} + + is-number-object@1.1.1: + dependencies: + call-bound: 1.0.4 + has-tostringtag: 1.0.2 + + is-number@7.0.0: {} + + is-plain-obj@2.1.0: {} + + is-plain-obj@4.1.0: {} + + is-potential-custom-element-name@1.0.1: {} + + is-regex@1.2.1: + dependencies: + call-bound: 1.0.4 + gopd: 1.2.0 + has-tostringtag: 1.0.2 + hasown: 2.0.2 + + is-set@2.0.3: {} + + is-shared-array-buffer@1.0.4: + dependencies: + call-bound: 1.0.4 + + is-stream@2.0.1: {} + + is-string@1.1.1: + dependencies: + call-bound: 1.0.4 + has-tostringtag: 1.0.2 + + is-symbol@1.1.1: + dependencies: + call-bound: 1.0.4 + has-symbols: 1.1.0 + safe-regex-test: 1.1.0 + + is-typed-array@1.1.15: + dependencies: + which-typed-array: 1.1.19 + + is-unicode-supported@0.1.0: {} + + is-unicode-supported@1.3.0: {} + + is-unicode-supported@2.1.0: {} + + is-weakmap@2.0.2: {} + + is-weakref@1.1.1: + dependencies: + call-bound: 1.0.4 + + is-weakset@2.0.4: + dependencies: + call-bound: 1.0.4 + get-intrinsic: 1.3.0 + + is-wsl@2.2.0: + dependencies: + is-docker: 2.2.1 + + is-wsl@3.1.0: + dependencies: + is-inside-container: 1.0.0 + + isarray@1.0.0: {} + + isarray@2.0.5: {} + + isbot@5.1.28: {} + + isexe@2.0.0: {} + + istanbul-lib-coverage@3.2.2: {} + + istanbul-lib-report@3.0.1: + dependencies: + istanbul-lib-coverage: 3.2.2 + make-dir: 4.0.0 + supports-color: 7.2.0 + + istanbul-lib-source-maps@5.0.6: + dependencies: + '@jridgewell/trace-mapping': 0.3.29 + debug: 4.4.1(supports-color@8.1.1) + istanbul-lib-coverage: 3.2.2 + transitivePeerDependencies: + - supports-color + + istanbul-reports@3.1.7: + dependencies: + html-escaper: 2.0.2 + istanbul-lib-report: 3.0.1 + + 
istextorbinary@9.5.0: + dependencies: + binaryextensions: 6.11.0 + editions: 6.21.0 + textextensions: 6.11.0 + + jackspeak@3.4.3: + dependencies: + '@isaacs/cliui': 8.0.2 + optionalDependencies: + '@pkgjs/parseargs': 0.11.0 + + jackspeak@4.1.1: + dependencies: + '@isaacs/cliui': 8.0.2 + + jest-worker@27.5.1: + dependencies: + '@types/node': 20.11.25 + merge-stream: 2.0.0 + supports-color: 8.1.1 + + jiti@1.21.7: {} + + jiti@2.4.2: {} + + jju@1.4.0: {} + + js-tokens@4.0.0: {} + + js-tokens@9.0.1: {} + + js-yaml@3.14.1: + dependencies: + argparse: 1.0.10 + esprima: 4.0.1 + + js-yaml@4.1.0: + dependencies: + argparse: 2.0.1 + + jsdom@26.1.0: + dependencies: + cssstyle: 4.6.0 + data-urls: 5.0.0 + decimal.js: 10.6.0 + html-encoding-sniffer: 4.0.0 + http-proxy-agent: 7.0.2 + https-proxy-agent: 7.0.6 + is-potential-custom-element-name: 1.0.1 + nwsapi: 2.2.20 + parse5: 7.3.0 + rrweb-cssom: 0.8.0 + saxes: 6.0.0 + symbol-tree: 3.2.4 + tough-cookie: 5.1.2 + w3c-xmlserializer: 5.0.0 + webidl-conversions: 7.0.0 + whatwg-encoding: 3.1.1 + whatwg-mimetype: 4.0.0 + whatwg-url: 14.2.0 + ws: 8.18.3 + xml-name-validator: 5.0.0 + transitivePeerDependencies: + - bufferutil + - supports-color + - utf-8-validate + + jsep@1.4.0: {} + + jsesc@3.1.0: {} + + json-bignum@0.0.3: {} + + json-buffer@3.0.1: {} + + json-parse-even-better-errors@2.3.1: {} + + json-schema-traverse@0.4.1: {} + + json-schema-traverse@1.0.0: {} + + json-stable-stringify-without-jsonify@1.0.1: {} + + json5@2.2.3: {} + + jsonc-parser@2.2.1: {} + + jsonc-parser@3.3.1: {} + + jsonfile@6.1.0: + dependencies: + universalify: 2.0.1 + optionalDependencies: + graceful-fs: 4.2.11 + + jsonpath-plus@10.3.0: + dependencies: + '@jsep-plugin/assignment': 1.3.0(jsep@1.4.0) + '@jsep-plugin/regex': 1.0.4(jsep@1.4.0) + jsep: 1.4.0 + + jsonpointer@5.0.1: {} + + jsonschema@1.5.0: {} + + jsonwebtoken@9.0.2: + dependencies: + jws: 3.2.2 + lodash.includes: 4.3.0 + lodash.isboolean: 3.0.3 + lodash.isinteger: 4.0.4 + lodash.isnumber: 3.0.3 + 
lodash.isplainobject: 4.0.6 + lodash.isstring: 4.0.1 + lodash.once: 4.1.1 + ms: 2.1.3 + semver: 7.7.2 + + jszip@3.10.1: + dependencies: + lie: 3.3.0 + pako: 1.0.11 + readable-stream: 2.3.8 + setimmediate: 1.0.5 + + jwa@1.4.2: + dependencies: + buffer-equal-constant-time: 1.0.1 + ecdsa-sig-formatter: 1.0.11 + safe-buffer: 5.2.1 + + jws@3.2.2: + dependencies: + jwa: 1.4.2 + safe-buffer: 5.2.1 + + keytar@7.9.0: + dependencies: + node-addon-api: 4.3.0 + prebuild-install: 7.1.3 + optional: true + + keyv@4.5.4: + dependencies: + json-buffer: 3.0.1 + + kleur@3.0.3: {} + + kolorist@1.8.0: {} + + leven@3.1.0: {} + + levn@0.4.1: + dependencies: + prelude-ls: 1.2.1 + type-check: 0.4.0 + + lie@3.3.0: + dependencies: + immediate: 3.0.6 + + lightningcss-darwin-arm64@1.30.1: + optional: true + + lightningcss-darwin-x64@1.30.1: + optional: true + + lightningcss-freebsd-x64@1.30.1: + optional: true + + lightningcss-linux-arm-gnueabihf@1.30.1: + optional: true + + lightningcss-linux-arm64-gnu@1.30.1: + optional: true + + lightningcss-linux-arm64-musl@1.30.1: + optional: true + + lightningcss-linux-x64-gnu@1.30.1: + optional: true + + lightningcss-linux-x64-musl@1.30.1: + optional: true + + lightningcss-win32-arm64-msvc@1.30.1: + optional: true + + lightningcss-win32-x64-msvc@1.30.1: + optional: true + + lightningcss@1.30.1: + dependencies: + detect-libc: 2.0.4 + optionalDependencies: + lightningcss-darwin-arm64: 1.30.1 + lightningcss-darwin-x64: 1.30.1 + lightningcss-freebsd-x64: 1.30.1 + lightningcss-linux-arm-gnueabihf: 1.30.1 + lightningcss-linux-arm64-gnu: 1.30.1 + lightningcss-linux-arm64-musl: 1.30.1 + lightningcss-linux-x64-gnu: 1.30.1 + lightningcss-linux-x64-musl: 1.30.1 + lightningcss-win32-arm64-msvc: 1.30.1 + lightningcss-win32-x64-msvc: 1.30.1 + + lilconfig@3.1.3: {} + + lines-and-columns@1.2.4: {} + + linkify-it@5.0.0: + dependencies: + uc.micro: 2.1.0 + + loader-runner@4.3.0: {} + + local-pkg@1.1.1: + dependencies: + mlly: 1.7.4 + pkg-types: 2.2.0 + quansync: 0.2.10 + 
+ locate-path@6.0.0: + dependencies: + p-locate: 5.0.0 + + locate-path@7.2.0: + dependencies: + p-locate: 6.0.0 + + lodash.camelcase@4.3.0: {} + + lodash.castarray@4.4.0: {} + + lodash.includes@4.3.0: {} + + lodash.isboolean@3.0.3: {} + + lodash.isempty@4.4.0: {} + + lodash.isinteger@4.0.4: {} + + lodash.isnumber@3.0.3: {} + + lodash.isplainobject@4.0.6: {} + + lodash.isstring@4.0.1: {} + + lodash.merge@4.6.2: {} + + lodash.omitby@4.6.0: {} + + lodash.once@4.1.1: {} + + lodash.topath@4.5.2: {} + + lodash.truncate@4.4.2: {} + + lodash.uniq@4.5.0: {} + + lodash.uniqby@4.7.0: {} + + lodash.uniqwith@4.5.0: {} + + lodash@4.17.21: {} + + log-symbols@4.1.0: + dependencies: + chalk: 4.1.2 + is-unicode-supported: 0.1.0 + + log-symbols@6.0.0: + dependencies: + chalk: 5.4.1 + is-unicode-supported: 1.3.0 + + loglevel-plugin-prefix@0.8.4: {} + + loglevel@1.9.2: {} + + longest-streak@3.1.0: {} + + loose-envify@1.4.0: + dependencies: + js-tokens: 4.0.0 + + loupe@3.1.4: {} + + lru-cache@10.4.3: {} + + lru-cache@11.1.0: {} + + lru-cache@5.1.1: + dependencies: + yallist: 3.1.1 + + lru-cache@6.0.0: + dependencies: + yallist: 4.0.0 + + lucide-react@0.542.0(react@18.3.1): + dependencies: + react: 18.3.1 + + lunr@2.3.9: {} + + lz-string@1.5.0: {} + + magic-string@0.30.17: + dependencies: + '@jridgewell/sourcemap-codec': 1.5.4 + + magicast@0.3.5: + dependencies: + '@babel/parser': 7.28.0 + '@babel/types': 7.28.1 + source-map-js: 1.2.1 + + make-dir@4.0.0: + dependencies: + semver: 7.7.2 + + markdown-it@14.1.0: + dependencies: + argparse: 2.0.1 + entities: 4.5.0 + linkify-it: 5.0.0 + mdurl: 2.0.0 + punycode.js: 2.3.1 + uc.micro: 2.1.0 + + math-intrinsics@1.1.0: {} + + mdast-util-from-markdown@2.0.2: + dependencies: + '@types/mdast': 4.0.4 + '@types/unist': 3.0.3 + decode-named-character-reference: 1.2.0 + devlop: 1.1.0 + mdast-util-to-string: 4.0.0 + micromark: 4.0.2 + micromark-util-decode-numeric-character-reference: 2.0.2 + micromark-util-decode-string: 2.0.1 + 
micromark-util-normalize-identifier: 2.0.1 + micromark-util-symbol: 2.0.1 + micromark-util-types: 2.0.2 + unist-util-stringify-position: 4.0.0 + transitivePeerDependencies: + - supports-color + + mdast-util-mdx-expression@2.0.1: + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/hast': 3.0.4 + '@types/mdast': 4.0.4 + devlop: 1.1.0 + mdast-util-from-markdown: 2.0.2 + mdast-util-to-markdown: 2.1.2 + transitivePeerDependencies: + - supports-color + + mdast-util-mdx-jsx@3.2.0: + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/hast': 3.0.4 + '@types/mdast': 4.0.4 + '@types/unist': 3.0.3 + ccount: 2.0.1 + devlop: 1.1.0 + mdast-util-from-markdown: 2.0.2 + mdast-util-to-markdown: 2.1.2 + parse-entities: 4.0.2 + stringify-entities: 4.0.4 + unist-util-stringify-position: 4.0.0 + vfile-message: 4.0.2 + transitivePeerDependencies: + - supports-color + + mdast-util-mdxjs-esm@2.0.1: + dependencies: + '@types/estree-jsx': 1.0.5 + '@types/hast': 3.0.4 + '@types/mdast': 4.0.4 + devlop: 1.1.0 + mdast-util-from-markdown: 2.0.2 + mdast-util-to-markdown: 2.1.2 + transitivePeerDependencies: + - supports-color + + mdast-util-phrasing@4.1.0: + dependencies: + '@types/mdast': 4.0.4 + unist-util-is: 6.0.0 + + mdast-util-to-hast@13.2.0: + dependencies: + '@types/hast': 3.0.4 + '@types/mdast': 4.0.4 + '@ungap/structured-clone': 1.3.0 + devlop: 1.1.0 + micromark-util-sanitize-uri: 2.0.1 + trim-lines: 3.0.1 + unist-util-position: 5.0.0 + unist-util-visit: 5.0.0 + vfile: 6.0.3 + + mdast-util-to-markdown@2.1.2: + dependencies: + '@types/mdast': 4.0.4 + '@types/unist': 3.0.3 + longest-streak: 3.1.0 + mdast-util-phrasing: 4.1.0 + mdast-util-to-string: 4.0.0 + micromark-util-classify-character: 2.0.1 + micromark-util-decode-string: 2.0.1 + unist-util-visit: 5.0.0 + zwitch: 2.0.4 + + mdast-util-to-string@4.0.0: + dependencies: + '@types/mdast': 4.0.4 + + mdurl@2.0.0: {} + + merge-stream@2.0.0: {} + + merge2@1.4.1: {} + + micromark-core-commonmark@2.0.3: + dependencies: + 
decode-named-character-reference: 1.2.0 + devlop: 1.1.0 + micromark-factory-destination: 2.0.1 + micromark-factory-label: 2.0.1 + micromark-factory-space: 2.0.1 + micromark-factory-title: 2.0.1 + micromark-factory-whitespace: 2.0.1 + micromark-util-character: 2.1.1 + micromark-util-chunked: 2.0.1 + micromark-util-classify-character: 2.0.1 + micromark-util-html-tag-name: 2.0.1 + micromark-util-normalize-identifier: 2.0.1 + micromark-util-resolve-all: 2.0.1 + micromark-util-subtokenize: 2.1.0 + micromark-util-symbol: 2.0.1 + micromark-util-types: 2.0.2 + + micromark-factory-destination@2.0.1: + dependencies: + micromark-util-character: 2.1.1 + micromark-util-symbol: 2.0.1 + micromark-util-types: 2.0.2 + + micromark-factory-label@2.0.1: + dependencies: + devlop: 1.1.0 + micromark-util-character: 2.1.1 + micromark-util-symbol: 2.0.1 + micromark-util-types: 2.0.2 + + micromark-factory-space@2.0.1: + dependencies: + micromark-util-character: 2.1.1 + micromark-util-types: 2.0.2 + + micromark-factory-title@2.0.1: + dependencies: + micromark-factory-space: 2.0.1 + micromark-util-character: 2.1.1 + micromark-util-symbol: 2.0.1 + micromark-util-types: 2.0.2 + + micromark-factory-whitespace@2.0.1: + dependencies: + micromark-factory-space: 2.0.1 + micromark-util-character: 2.1.1 + micromark-util-symbol: 2.0.1 + micromark-util-types: 2.0.2 + + micromark-util-character@2.1.1: + dependencies: + micromark-util-symbol: 2.0.1 + micromark-util-types: 2.0.2 + + micromark-util-chunked@2.0.1: + dependencies: + micromark-util-symbol: 2.0.1 + + micromark-util-classify-character@2.0.1: + dependencies: + micromark-util-character: 2.1.1 + micromark-util-symbol: 2.0.1 + micromark-util-types: 2.0.2 + + micromark-util-combine-extensions@2.0.1: + dependencies: + micromark-util-chunked: 2.0.1 + micromark-util-types: 2.0.2 + + micromark-util-decode-numeric-character-reference@2.0.2: + dependencies: + micromark-util-symbol: 2.0.1 + + micromark-util-decode-string@2.0.1: + dependencies: + 
decode-named-character-reference: 1.2.0 + micromark-util-character: 2.1.1 + micromark-util-decode-numeric-character-reference: 2.0.2 + micromark-util-symbol: 2.0.1 + + micromark-util-encode@2.0.1: {} + + micromark-util-html-tag-name@2.0.1: {} + + micromark-util-normalize-identifier@2.0.1: + dependencies: + micromark-util-symbol: 2.0.1 + + micromark-util-resolve-all@2.0.1: + dependencies: + micromark-util-types: 2.0.2 + + micromark-util-sanitize-uri@2.0.1: + dependencies: + micromark-util-character: 2.1.1 + micromark-util-encode: 2.0.1 + micromark-util-symbol: 2.0.1 + + micromark-util-subtokenize@2.1.0: + dependencies: + devlop: 1.1.0 + micromark-util-chunked: 2.0.1 + micromark-util-symbol: 2.0.1 + micromark-util-types: 2.0.2 + + micromark-util-symbol@2.0.1: {} + + micromark-util-types@2.0.2: {} + + micromark@4.0.2: + dependencies: + '@types/debug': 4.1.12 + debug: 4.4.1(supports-color@8.1.1) + decode-named-character-reference: 1.2.0 + devlop: 1.1.0 + micromark-core-commonmark: 2.0.3 + micromark-factory-space: 2.0.1 + micromark-util-character: 2.1.1 + micromark-util-chunked: 2.0.1 + micromark-util-combine-extensions: 2.0.1 + micromark-util-decode-numeric-character-reference: 2.0.2 + micromark-util-encode: 2.0.1 + micromark-util-normalize-identifier: 2.0.1 + micromark-util-resolve-all: 2.0.1 + micromark-util-sanitize-uri: 2.0.1 + micromark-util-subtokenize: 2.1.0 + micromark-util-symbol: 2.0.1 + micromark-util-types: 2.0.2 + transitivePeerDependencies: + - supports-color + + micromatch@4.0.8: + dependencies: + braces: 3.0.3 + picomatch: 2.3.1 + + mime-db@1.52.0: {} + + mime-types@2.1.35: + dependencies: + mime-db: 1.52.0 + + mime@1.6.0: {} + + mimic-fn@2.1.0: {} + + mimic-function@5.0.1: {} + + mimic-response@3.1.0: + optional: true + + min-indent@1.0.1: {} + + minimatch@10.0.3: + dependencies: + '@isaacs/brace-expansion': 5.0.0 + + minimatch@3.1.2: + dependencies: + brace-expansion: 1.1.12 + + minimatch@5.1.6: + dependencies: + brace-expansion: 2.0.2 + + 
minimatch@6.2.0: + dependencies: + brace-expansion: 2.0.2 + + minimatch@9.0.5: + dependencies: + brace-expansion: 2.0.2 + + minimist@1.2.8: {} + + minipass@7.1.2: {} + + minizlib@3.0.2: + dependencies: + minipass: 7.1.2 + + mkdirp-classic@0.5.3: + optional: true + + mkdirp@3.0.1: {} + + mlly@1.7.4: + dependencies: + acorn: 8.15.0 + pathe: 2.0.3 + pkg-types: 1.3.1 + ufo: 1.6.1 + + mocha@10.8.2: + dependencies: + ansi-colors: 4.1.3 + browser-stdout: 1.3.1 + chokidar: 3.6.0 + debug: 4.4.1(supports-color@8.1.1) + diff: 5.2.0 + escape-string-regexp: 4.0.0 + find-up: 5.0.0 + glob: 8.1.0 + he: 1.2.0 + js-yaml: 4.1.0 + log-symbols: 4.1.0 + minimatch: 5.1.6 + ms: 2.1.3 + serialize-javascript: 6.0.2 + strip-json-comments: 3.1.1 + supports-color: 8.1.1 + workerpool: 6.5.1 + yargs: 16.2.0 + yargs-parser: 20.2.9 + yargs-unparser: 2.0.0 + + mrmime@2.0.1: {} + + ms@2.1.3: {} + + muggle-string@0.4.1: {} + + mute-stream@0.0.8: {} + + mz@2.7.0: + dependencies: + any-promise: 1.3.0 + object-assign: 4.1.1 + thenify-all: 1.6.0 + + nanoid@3.3.11: {} + + napi-build-utils@2.0.0: + optional: true + + natural-compare@1.4.0: {} + + neo-async@2.6.2: {} + + nimma@0.2.3: + dependencies: + '@jsep-plugin/regex': 1.0.4(jsep@1.4.0) + '@jsep-plugin/ternary': 1.1.4(jsep@1.4.0) + astring: 1.9.0 + jsep: 1.4.0 + optionalDependencies: + jsonpath-plus: 10.3.0 + lodash.topath: 4.5.2 + + node-abi@3.75.0: + dependencies: + semver: 7.7.2 + optional: true + + node-addon-api@4.3.0: + optional: true + + node-fetch-h2@2.3.0: + dependencies: + http2-client: 1.3.5 + + node-fetch@2.7.0: + dependencies: + whatwg-url: 5.0.0 + + node-readfiles@0.2.0: + dependencies: + es6-promise: 3.3.1 + + node-releases@2.0.21: {} + + node-sarif-builder@3.2.0: + dependencies: + '@types/sarif': 2.1.7 + fs-extra: 11.3.0 + + normalize-package-data@6.0.2: + dependencies: + hosted-git-info: 7.0.2 + semver: 7.7.2 + validate-npm-package-license: 3.0.4 + + normalize-path@3.0.0: {} + + normalize-range@0.1.2: {} + + npm-package-arg@12.0.2: + 
dependencies: + hosted-git-info: 8.1.0 + proc-log: 5.0.0 + semver: 7.7.2 + validate-npm-package-name: 6.0.2 + + npm-run-path@4.0.1: + dependencies: + path-key: 3.1.1 + + nth-check@2.1.1: + dependencies: + boolbase: 1.0.0 + + nwsapi@2.2.20: {} + + oas-kit-common@1.0.8: + dependencies: + fast-safe-stringify: 2.1.1 + + oas-linter@3.2.2: + dependencies: + '@exodus/schemasafe': 1.3.0 + should: 13.2.3 + yaml: 1.10.2 + + oas-resolver@2.5.6: + dependencies: + node-fetch-h2: 2.3.0 + oas-kit-common: 1.0.8 + reftools: 1.1.9 + yaml: 1.10.2 + yargs: 17.7.2 + + oas-schema-walker@1.1.5: {} + + oas-validator@5.0.8: + dependencies: + call-me-maybe: 1.0.2 + oas-kit-common: 1.0.8 + oas-linter: 3.2.2 + oas-resolver: 2.5.6 + oas-schema-walker: 1.1.5 + reftools: 1.1.9 + should: 13.2.3 + yaml: 1.10.2 + + object-assign@4.1.1: {} + + object-hash@3.0.0: {} + + object-inspect@1.13.4: {} + + object-keys@1.1.1: {} + + object.assign@4.1.7: + dependencies: + call-bind: 1.0.8 + call-bound: 1.0.4 + define-properties: 1.2.1 + es-object-atoms: 1.1.1 + has-symbols: 1.1.0 + object-keys: 1.1.1 + + once@1.4.0: + dependencies: + wrappy: 1.0.2 + + onetime@5.1.2: + dependencies: + mimic-fn: 2.1.0 + + onetime@7.0.0: + dependencies: + mimic-function: 5.0.1 + + open@10.2.0: + dependencies: + default-browser: 5.2.1 + define-lazy-prop: 3.0.0 + is-inside-container: 1.0.0 + wsl-utils: 0.1.0 + + open@8.4.2: + dependencies: + define-lazy-prop: 2.0.0 + is-docker: 2.2.1 + is-wsl: 2.2.0 + + openapi-types@12.1.3: {} + + openapi3-ts@4.2.2: + dependencies: + yaml: 2.8.0 + + openapi3-ts@4.4.0: + dependencies: + yaml: 2.8.0 + + optionator@0.9.4: + dependencies: + deep-is: 0.1.4 + fast-levenshtein: 2.0.6 + levn: 0.4.1 + prelude-ls: 1.2.1 + type-check: 0.4.0 + word-wrap: 1.2.5 + + ora@8.2.0: + dependencies: + chalk: 5.4.1 + cli-cursor: 5.0.0 + cli-spinners: 2.9.2 + is-interactive: 2.0.0 + is-unicode-supported: 2.1.0 + log-symbols: 6.0.0 + stdin-discarder: 0.2.2 + string-width: 7.2.0 + strip-ansi: 7.1.0 + + 
orval@7.10.0(openapi-types@12.1.3): + dependencies: + '@apidevtools/swagger-parser': 10.1.1(openapi-types@12.1.3) + '@orval/angular': 7.10.0(openapi-types@12.1.3) + '@orval/axios': 7.10.0(openapi-types@12.1.3) + '@orval/core': 7.10.0(openapi-types@12.1.3) + '@orval/fetch': 7.10.0(openapi-types@12.1.3) + '@orval/hono': 7.10.0(openapi-types@12.1.3) + '@orval/mcp': 7.10.0(openapi-types@12.1.3) + '@orval/mock': 7.10.0(openapi-types@12.1.3) + '@orval/query': 7.10.0(openapi-types@12.1.3) + '@orval/swr': 7.10.0(openapi-types@12.1.3) + '@orval/zod': 7.10.0(openapi-types@12.1.3) + ajv: 8.17.1 + cac: 6.7.14 + chalk: 4.1.2 + chokidar: 4.0.3 + enquirer: 2.4.1 + execa: 5.1.1 + find-up: 5.0.0 + fs-extra: 11.3.0 + lodash.uniq: 4.5.0 + openapi3-ts: 4.2.2 + string-argv: 0.3.2 + tsconfck: 2.1.2(typescript@5.8.3) + typedoc: 0.28.7(typescript@5.8.3) + typedoc-plugin-markdown: 4.7.1(typedoc@0.28.7(typescript@5.8.3)) + typescript: 5.8.3 + transitivePeerDependencies: + - encoding + - openapi-types + - supports-color + + own-keys@1.0.1: + dependencies: + get-intrinsic: 1.3.0 + object-keys: 1.1.1 + safe-push-apply: 1.0.0 + + p-limit@3.1.0: + dependencies: + yocto-queue: 0.1.0 + + p-limit@4.0.0: + dependencies: + yocto-queue: 1.2.1 + + p-locate@5.0.0: + dependencies: + p-limit: 3.1.0 + + p-locate@6.0.0: + dependencies: + p-limit: 4.0.0 + + p-map@7.0.3: {} + + package-json-from-dist@1.0.1: {} + + pako@1.0.11: {} + + parent-module@1.0.1: + dependencies: + callsites: 3.1.0 + + parse-entities@4.0.2: + dependencies: + '@types/unist': 2.0.11 + character-entities-legacy: 3.0.0 + character-reference-invalid: 2.0.1 + decode-named-character-reference: 1.2.0 + is-alphanumerical: 2.0.1 + is-decimal: 2.0.1 + is-hexadecimal: 2.0.1 + + parse-json@5.2.0: + dependencies: + '@babel/code-frame': 7.27.1 + error-ex: 1.3.2 + json-parse-even-better-errors: 2.3.1 + lines-and-columns: 1.2.4 + + parse-json@8.3.0: + dependencies: + '@babel/code-frame': 7.27.1 + index-to-position: 1.1.0 + type-fest: 4.41.0 + + 
parse-semver@1.1.1: + dependencies: + semver: 5.7.2 + + parse5-htmlparser2-tree-adapter@7.1.0: + dependencies: + domhandler: 5.0.3 + parse5: 7.3.0 + + parse5-parser-stream@7.1.2: + dependencies: + parse5: 7.3.0 + + parse5@7.3.0: + dependencies: + entities: 6.0.1 + + path-browserify@1.0.1: {} + + path-exists@4.0.0: {} + + path-exists@5.0.0: {} + + path-is-absolute@1.0.1: {} + + path-key@3.1.1: {} + + path-parse@1.0.7: {} + + path-scurry@1.11.1: + dependencies: + lru-cache: 10.4.3 + minipass: 7.1.2 + + path-scurry@2.0.0: + dependencies: + lru-cache: 11.1.0 + minipass: 7.1.2 + + path-type@4.0.0: {} + + path-type@6.0.0: {} + + pathe@2.0.3: {} + + pathval@2.0.1: {} + + pend@1.2.0: {} + + picocolors@1.1.1: {} + + picomatch@2.3.1: {} + + picomatch@4.0.3: {} + + pify@2.3.0: {} + + pirates@4.0.7: {} + + pkg-types@1.3.1: + dependencies: + confbox: 0.1.8 + mlly: 1.7.4 + pathe: 2.0.3 + + pkg-types@2.2.0: + dependencies: + confbox: 0.2.2 + exsolve: 1.0.7 + pathe: 2.0.3 + + playwright-core@1.54.1: {} + + playwright@1.54.1: + dependencies: + playwright-core: 1.54.1 + optionalDependencies: + fsevents: 2.3.2 + + pluralize@2.0.0: {} + + pluralize@8.0.0: {} + + pony-cause@1.1.1: {} + + possible-typed-array-names@1.1.0: {} + + postcss-import@15.1.0(postcss@8.5.6): + dependencies: + postcss: 8.5.6 + postcss-value-parser: 4.2.0 + read-cache: 1.0.0 + resolve: 1.22.10 + + postcss-js@4.0.1(postcss@8.5.6): + dependencies: + camelcase-css: 2.0.1 + postcss: 8.5.6 + + postcss-load-config@4.0.2(postcss@8.5.6): + dependencies: + lilconfig: 3.1.3 + yaml: 2.8.0 + optionalDependencies: + postcss: 8.5.6 + + postcss-nested@6.2.0(postcss@8.5.6): + dependencies: + postcss: 8.5.6 + postcss-selector-parser: 6.1.2 + + postcss-selector-parser@6.0.10: + dependencies: + cssesc: 3.0.0 + util-deprecate: 1.0.2 + + postcss-selector-parser@6.1.2: + dependencies: + cssesc: 3.0.0 + util-deprecate: 1.0.2 + + postcss-value-parser@4.2.0: {} + + postcss@8.5.6: + dependencies: + nanoid: 3.3.11 + picocolors: 1.1.1 + 
source-map-js: 1.2.1 + + prebuild-install@7.1.3: + dependencies: + detect-libc: 2.0.4 + expand-template: 2.0.3 + github-from-package: 0.0.0 + minimist: 1.2.8 + mkdirp-classic: 0.5.3 + napi-build-utils: 2.0.0 + node-abi: 3.75.0 + pump: 3.0.3 + rc: 1.2.8 + simple-get: 4.0.1 + tar-fs: 2.1.3 + tunnel-agent: 0.6.0 + optional: true + + prelude-ls@1.2.1: {} + + prettier@3.6.2: {} + + pretty-format@27.5.1: + dependencies: + ansi-regex: 5.0.1 + ansi-styles: 5.2.0 + react-is: 17.0.2 + + proc-log@5.0.0: {} + + process-nextick-args@2.0.1: {} + + prompts@2.4.2: + dependencies: + kleur: 3.0.3 + sisteransi: 1.0.5 + + prop-types@15.8.1: + dependencies: + loose-envify: 1.4.0 + object-assign: 4.1.1 + react-is: 16.13.1 + + property-information@7.1.0: {} + + pump@3.0.3: + dependencies: + end-of-stream: 1.4.5 + once: 1.4.0 + optional: true + + punycode.js@2.3.1: {} + + punycode@2.3.1: {} + + pure-rand@6.1.0: {} + + qs@6.14.0: + dependencies: + side-channel: 1.1.0 + + quansync@0.2.10: {} + + queue-microtask@1.2.3: {} + + randombytes@2.1.0: + dependencies: + safe-buffer: 5.2.1 + + rc-config-loader@4.1.3: + dependencies: + debug: 4.4.1(supports-color@8.1.1) + js-yaml: 4.1.0 + json5: 2.2.3 + require-from-string: 2.0.2 + transitivePeerDependencies: + - supports-color + + rc@1.2.8: + dependencies: + deep-extend: 0.6.0 + ini: 1.3.8 + minimist: 1.2.8 + strip-json-comments: 2.0.1 + optional: true + + react-dnd-html5-backend@16.0.1: + dependencies: + dnd-core: 16.0.1 + + react-dnd@16.0.1(@types/node@24.1.0)(@types/react@18.3.23)(react@18.3.1): + dependencies: + '@react-dnd/invariant': 4.0.2 + '@react-dnd/shallowequal': 4.0.2 + dnd-core: 16.0.1 + fast-deep-equal: 3.1.3 + hoist-non-react-statics: 3.3.2 + react: 18.3.1 + optionalDependencies: + '@types/node': 24.1.0 + '@types/react': 18.3.23 + + react-docgen-typescript@2.4.0(typescript@5.8.3): + dependencies: + typescript: 5.8.3 + + react-docgen@8.0.0: + dependencies: + '@babel/core': 7.28.0 + '@babel/traverse': 7.28.0 + '@babel/types': 7.28.1 + 
'@types/babel__core': 7.20.5 + '@types/babel__traverse': 7.20.7 + '@types/doctrine': 0.0.9 + '@types/resolve': 1.20.6 + doctrine: 3.0.0 + resolve: 1.22.10 + strip-indent: 4.0.0 + transitivePeerDependencies: + - supports-color + + react-dom@18.3.1(react@18.3.1): + dependencies: + loose-envify: 1.4.0 + react: 18.3.1 + scheduler: 0.23.2 + + react-is@16.13.1: {} + + react-is@17.0.2: {} + + react-markdown@10.1.0(@types/react@18.3.23)(react@18.3.1): + dependencies: + '@types/hast': 3.0.4 + '@types/mdast': 4.0.4 + '@types/react': 18.3.23 + devlop: 1.1.0 + hast-util-to-jsx-runtime: 2.3.6 + html-url-attributes: 3.0.1 + mdast-util-to-hast: 13.2.0 + react: 18.3.1 + remark-parse: 11.0.0 + remark-rehype: 11.1.2 + unified: 11.0.5 + unist-util-visit: 5.0.0 + vfile: 6.0.3 + transitivePeerDependencies: + - supports-color + + react-refresh@0.17.0: {} + + react-remove-scroll-bar@2.3.8(@types/react@18.3.23)(react@18.3.1): + dependencies: + react: 18.3.1 + react-style-singleton: 2.2.3(@types/react@18.3.23)(react@18.3.1) + tslib: 2.8.1 + optionalDependencies: + '@types/react': 18.3.23 + + react-remove-scroll@2.7.1(@types/react@18.3.23)(react@18.3.1): + dependencies: + react: 18.3.1 + react-remove-scroll-bar: 2.3.8(@types/react@18.3.23)(react@18.3.1) + react-style-singleton: 2.2.3(@types/react@18.3.23)(react@18.3.1) + tslib: 2.8.1 + use-callback-ref: 1.3.3(@types/react@18.3.23)(react@18.3.1) + use-sidecar: 1.1.3(@types/react@18.3.23)(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + + react-router@7.7.0(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + cookie: 1.0.2 + react: 18.3.1 + set-cookie-parser: 2.7.1 + optionalDependencies: + react-dom: 18.3.1(react@18.3.1) + + react-split@2.0.14(react@18.3.1): + dependencies: + prop-types: 15.8.1 + react: 18.3.1 + split.js: 1.6.5 + + react-style-singleton@2.2.3(@types/react@18.3.23)(react@18.3.1): + dependencies: + get-nonce: 1.0.1 + react: 18.3.1 + tslib: 2.8.1 + optionalDependencies: + '@types/react': 18.3.23 + + 
react@18.3.1: + dependencies: + loose-envify: 1.4.0 + + reactflow@11.11.4(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + '@reactflow/background': 11.3.14(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@reactflow/controls': 11.2.14(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@reactflow/core': 11.11.4(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@reactflow/minimap': 11.7.14(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@reactflow/node-resizer': 2.2.14(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + '@reactflow/node-toolbar': 1.3.14(@types/react@18.3.23)(immer@9.0.21)(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + transitivePeerDependencies: + - '@types/react' + - immer + + read-cache@1.0.0: + dependencies: + pify: 2.3.0 + + read-pkg@9.0.1: + dependencies: + '@types/normalize-package-data': 2.4.4 + normalize-package-data: 6.0.2 + parse-json: 8.3.0 + type-fest: 4.41.0 + unicorn-magic: 0.1.0 + + read-yaml-file@2.1.0: + dependencies: + js-yaml: 4.1.0 + strip-bom: 4.0.0 + + read@1.0.7: + dependencies: + mute-stream: 0.0.8 + + readable-stream@2.3.8: + dependencies: + core-util-is: 1.0.3 + inherits: 2.0.4 + isarray: 1.0.0 + process-nextick-args: 2.0.1 + safe-buffer: 5.1.2 + string_decoder: 1.1.1 + util-deprecate: 1.0.2 + + readable-stream@3.6.2: + dependencies: + inherits: 2.0.4 + string_decoder: 1.3.0 + util-deprecate: 1.0.2 + optional: true + + readdirp@3.6.0: + dependencies: + picomatch: 2.3.1 + + readdirp@4.1.2: {} + + recast@0.23.11: + dependencies: + ast-types: 0.16.1 + esprima: 4.0.1 + source-map: 0.6.1 + tiny-invariant: 1.3.3 + tslib: 2.8.1 + + redent@3.0.0: + dependencies: + indent-string: 4.0.0 + strip-indent: 3.0.0 + + redux@4.2.1: + dependencies: + '@babel/runtime': 
7.28.2 + + reflect.getprototypeof@1.0.10: + dependencies: + call-bind: 1.0.8 + define-properties: 1.2.1 + es-abstract: 1.24.0 + es-errors: 1.3.0 + es-object-atoms: 1.1.1 + get-intrinsic: 1.3.0 + get-proto: 1.0.1 + which-builtin-type: 1.2.1 + + reftools@1.1.9: {} + + regexp.prototype.flags@1.5.4: + dependencies: + call-bind: 1.0.8 + define-properties: 1.2.1 + es-errors: 1.3.0 + get-proto: 1.0.1 + gopd: 1.2.0 + set-function-name: 2.0.2 + + remark-parse@11.0.0: + dependencies: + '@types/mdast': 4.0.4 + mdast-util-from-markdown: 2.0.2 + micromark-util-types: 2.0.2 + unified: 11.0.5 + transitivePeerDependencies: + - supports-color + + remark-rehype@11.1.2: + dependencies: + '@types/hast': 3.0.4 + '@types/mdast': 4.0.4 + mdast-util-to-hast: 13.2.0 + unified: 11.0.5 + vfile: 6.0.3 + + require-directory@2.1.1: {} + + require-from-string@2.0.2: {} + + resolve-from@4.0.0: {} + + resolve-pkg-maps@1.0.0: {} + + resolve@1.22.10: + dependencies: + is-core-module: 2.16.1 + path-parse: 1.0.7 + supports-preserve-symlinks-flag: 1.0.0 + + restore-cursor@5.1.0: + dependencies: + onetime: 7.0.0 + signal-exit: 4.1.0 + + reusify@1.1.0: {} + + rollup@4.45.1: + dependencies: + '@types/estree': 1.0.8 + optionalDependencies: + '@rollup/rollup-android-arm-eabi': 4.45.1 + '@rollup/rollup-android-arm64': 4.45.1 + '@rollup/rollup-darwin-arm64': 4.45.1 + '@rollup/rollup-darwin-x64': 4.45.1 + '@rollup/rollup-freebsd-arm64': 4.45.1 + '@rollup/rollup-freebsd-x64': 4.45.1 + '@rollup/rollup-linux-arm-gnueabihf': 4.45.1 + '@rollup/rollup-linux-arm-musleabihf': 4.45.1 + '@rollup/rollup-linux-arm64-gnu': 4.45.1 + '@rollup/rollup-linux-arm64-musl': 4.45.1 + '@rollup/rollup-linux-loongarch64-gnu': 4.45.1 + '@rollup/rollup-linux-powerpc64le-gnu': 4.45.1 + '@rollup/rollup-linux-riscv64-gnu': 4.45.1 + '@rollup/rollup-linux-riscv64-musl': 4.45.1 + '@rollup/rollup-linux-s390x-gnu': 4.45.1 + '@rollup/rollup-linux-x64-gnu': 4.45.1 + '@rollup/rollup-linux-x64-musl': 4.45.1 + '@rollup/rollup-win32-arm64-msvc': 
4.45.1 + '@rollup/rollup-win32-ia32-msvc': 4.45.1 + '@rollup/rollup-win32-x64-msvc': 4.45.1 + fsevents: 2.3.3 + + rrweb-cssom@0.8.0: {} + + run-applescript@7.0.0: {} + + run-parallel@1.2.0: + dependencies: + queue-microtask: 1.2.3 + + safe-array-concat@1.1.3: + dependencies: + call-bind: 1.0.8 + call-bound: 1.0.4 + get-intrinsic: 1.3.0 + has-symbols: 1.1.0 + isarray: 2.0.5 + + safe-buffer@5.1.2: {} + + safe-buffer@5.2.1: {} + + safe-push-apply@1.0.0: + dependencies: + es-errors: 1.3.0 + isarray: 2.0.5 + + safe-regex-test@1.1.0: + dependencies: + call-bound: 1.0.4 + es-errors: 1.3.0 + is-regex: 1.2.1 + + safe-stable-stringify@1.1.1: {} + + safer-buffer@2.1.2: {} + + sax@1.4.1: {} + + saxes@6.0.0: + dependencies: + xmlchars: 2.2.0 + + scheduler@0.23.2: + dependencies: + loose-envify: 1.4.0 + + schema-utils@4.3.2: + dependencies: + '@types/json-schema': 7.0.15 + ajv: 8.17.1 + ajv-formats: 2.1.1(ajv@8.17.1) + ajv-keywords: 5.1.0(ajv@8.17.1) + + secretlint@10.2.1: + dependencies: + '@secretlint/config-creator': 10.2.1 + '@secretlint/formatter': 10.2.1 + '@secretlint/node': 10.2.1 + '@secretlint/profiler': 10.2.1 + debug: 4.4.1(supports-color@8.1.1) + globby: 14.1.0 + read-pkg: 9.0.1 + transitivePeerDependencies: + - supports-color + + semver@5.7.2: {} + + semver@6.3.1: {} + + semver@7.5.4: + dependencies: + lru-cache: 6.0.0 + + semver@7.7.2: {} + + serialize-javascript@6.0.2: + dependencies: + randombytes: 2.1.0 + + seroval-plugins@1.3.2(seroval@1.3.2): + dependencies: + seroval: 1.3.2 + + seroval-plugins@1.3.3(seroval@1.3.2): + dependencies: + seroval: 1.3.2 + + seroval@1.3.2: {} + + set-cookie-parser@2.7.1: {} + + set-function-length@1.2.2: + dependencies: + define-data-property: 1.1.4 + es-errors: 1.3.0 + function-bind: 1.1.2 + get-intrinsic: 1.3.0 + gopd: 1.2.0 + has-property-descriptors: 1.0.2 + + set-function-name@2.0.2: + dependencies: + define-data-property: 1.1.4 + es-errors: 1.3.0 + functions-have-names: 1.2.3 + has-property-descriptors: 1.0.2 + + 
set-proto@1.0.0: + dependencies: + dunder-proto: 1.0.1 + es-errors: 1.3.0 + es-object-atoms: 1.1.1 + + setimmediate@1.0.5: {} + + shebang-command@2.0.0: + dependencies: + shebang-regex: 3.0.0 + + shebang-regex@3.0.0: {} + + shell-quote@1.8.3: {} + + should-equal@2.0.0: + dependencies: + should-type: 1.4.0 + + should-format@3.0.3: + dependencies: + should-type: 1.4.0 + should-type-adaptors: 1.1.0 + + should-type-adaptors@1.1.0: + dependencies: + should-type: 1.4.0 + should-util: 1.0.1 + + should-type@1.4.0: {} + + should-util@1.0.1: {} + + should@13.2.3: + dependencies: + should-equal: 2.0.0 + should-format: 3.0.3 + should-type: 1.4.0 + should-type-adaptors: 1.1.0 + should-util: 1.0.1 + + side-channel-list@1.0.0: + dependencies: + es-errors: 1.3.0 + object-inspect: 1.13.4 + + side-channel-map@1.0.1: + dependencies: + call-bound: 1.0.4 + es-errors: 1.3.0 + get-intrinsic: 1.3.0 + object-inspect: 1.13.4 + + side-channel-weakmap@1.0.2: + dependencies: + call-bound: 1.0.4 + es-errors: 1.3.0 + get-intrinsic: 1.3.0 + object-inspect: 1.13.4 + side-channel-map: 1.0.1 + + side-channel@1.1.0: + dependencies: + es-errors: 1.3.0 + object-inspect: 1.13.4 + side-channel-list: 1.0.0 + side-channel-map: 1.0.1 + side-channel-weakmap: 1.0.2 + + siginfo@2.0.0: {} + + signal-exit@3.0.7: {} + + signal-exit@4.1.0: {} + + simple-concat@1.0.1: + optional: true + + simple-eval@1.0.1: + dependencies: + jsep: 1.4.0 + + simple-get@4.0.1: + dependencies: + decompress-response: 6.0.0 + once: 1.4.0 + simple-concat: 1.0.1 + optional: true + + sirv@3.0.1: + dependencies: + '@polka/url': 1.0.0-next.29 + mrmime: 2.0.1 + totalist: 3.0.1 + + sisteransi@1.0.5: {} + + slash@3.0.0: {} + + slash@5.1.0: {} + + slice-ansi@4.0.0: + dependencies: + ansi-styles: 4.3.0 + astral-regex: 2.0.0 + is-fullwidth-code-point: 3.0.0 + + solid-js@1.9.7: + dependencies: + csstype: 3.1.3 + seroval: 1.3.2 + seroval-plugins: 1.3.3(seroval@1.3.2) + + source-map-js@1.2.1: {} + + source-map-support@0.5.21: + dependencies: + 
buffer-from: 1.1.2 + source-map: 0.6.1 + + source-map@0.6.1: {} + + source-map@0.7.4: {} + + space-separated-tokens@2.0.2: {} + + spdx-correct@3.2.0: + dependencies: + spdx-expression-parse: 3.0.1 + spdx-license-ids: 3.0.21 + + spdx-exceptions@2.5.0: {} + + spdx-expression-parse@3.0.1: + dependencies: + spdx-exceptions: 2.5.0 + spdx-license-ids: 3.0.21 + + spdx-license-ids@3.0.21: {} + + split.js@1.6.5: {} + + sprintf-js@1.0.3: {} + + stackback@0.0.2: {} + + std-env@3.9.0: {} + + stdin-discarder@0.2.2: {} + + stop-iteration-iterator@1.1.0: + dependencies: + es-errors: 1.3.0 + internal-slot: 1.1.0 + + storybook@9.0.18(@testing-library/dom@10.4.1)(prettier@3.6.2): + dependencies: + '@storybook/global': 5.0.0 + '@testing-library/jest-dom': 6.6.3 + '@testing-library/user-event': 14.6.1(@testing-library/dom@10.4.1) + '@vitest/expect': 3.2.4 + '@vitest/spy': 3.2.4 + better-opn: 3.0.2 + esbuild: 0.25.8 + esbuild-register: 3.6.0(esbuild@0.25.8) + recast: 0.23.11 + semver: 7.7.2 + ws: 8.18.3 + optionalDependencies: + prettier: 3.6.2 + transitivePeerDependencies: + - '@testing-library/dom' + - bufferutil + - supports-color + - utf-8-validate + + storybook@9.1.5(@testing-library/dom@10.4.1)(prettier@3.6.2)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)): + dependencies: + '@storybook/global': 5.0.0 + '@testing-library/jest-dom': 6.6.3 + '@testing-library/user-event': 14.6.1(@testing-library/dom@10.4.1) + '@vitest/expect': 3.2.4 + '@vitest/mocker': 3.2.4(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + '@vitest/spy': 3.2.4 + better-opn: 3.0.2 + esbuild: 0.25.8 + esbuild-register: 3.6.0(esbuild@0.25.8) + recast: 0.23.11 + semver: 7.7.2 + ws: 8.18.3 + optionalDependencies: + prettier: 3.6.2 + transitivePeerDependencies: + - '@testing-library/dom' + - bufferutil + - msw + - supports-color + - utf-8-validate + - vite + + string-argv@0.3.2: {} + + string-width@4.2.3: 
+ dependencies: + emoji-regex: 8.0.0 + is-fullwidth-code-point: 3.0.0 + strip-ansi: 6.0.1 + + string-width@5.1.2: + dependencies: + eastasianwidth: 0.2.0 + emoji-regex: 9.2.2 + strip-ansi: 7.1.0 + + string-width@7.2.0: + dependencies: + emoji-regex: 10.4.0 + get-east-asian-width: 1.3.0 + strip-ansi: 7.1.0 + + string.prototype.trim@1.2.10: + dependencies: + call-bind: 1.0.8 + call-bound: 1.0.4 + define-data-property: 1.1.4 + define-properties: 1.2.1 + es-abstract: 1.24.0 + es-object-atoms: 1.1.1 + has-property-descriptors: 1.0.2 + + string.prototype.trimend@1.0.9: + dependencies: + call-bind: 1.0.8 + call-bound: 1.0.4 + define-properties: 1.2.1 + es-object-atoms: 1.1.1 + + string.prototype.trimstart@1.0.8: + dependencies: + call-bind: 1.0.8 + define-properties: 1.2.1 + es-object-atoms: 1.1.1 + + string_decoder@1.1.1: + dependencies: + safe-buffer: 5.1.2 + + string_decoder@1.3.0: + dependencies: + safe-buffer: 5.2.1 + optional: true + + stringify-entities@4.0.4: + dependencies: + character-entities-html4: 2.1.0 + character-entities-legacy: 3.0.0 + + strip-ansi@6.0.1: + dependencies: + ansi-regex: 5.0.1 + + strip-ansi@7.1.0: + dependencies: + ansi-regex: 6.1.0 + + strip-bom@3.0.0: {} + + strip-bom@4.0.0: {} + + strip-final-newline@2.0.0: {} + + strip-indent@3.0.0: + dependencies: + min-indent: 1.0.1 + + strip-indent@4.0.0: + dependencies: + min-indent: 1.0.1 + + strip-json-comments@2.0.1: + optional: true + + strip-json-comments@3.1.1: {} + + strip-literal@3.0.0: + dependencies: + js-tokens: 9.0.1 + + structured-source@4.0.0: + dependencies: + boundary: 2.0.0 + + style-mod@4.1.2: {} + + style-to-js@1.1.17: + dependencies: + style-to-object: 1.0.9 + + style-to-object@1.0.9: + dependencies: + inline-style-parser: 0.2.4 + + sucrase@3.35.0: + dependencies: + '@jridgewell/gen-mapping': 0.3.13 + commander: 4.1.1 + glob: 10.4.5 + lines-and-columns: 1.2.4 + mz: 2.7.0 + pirates: 4.0.7 + ts-interface-checker: 0.1.13 + + supports-color@7.2.0: + dependencies: + has-flag: 4.0.0 + 
+ supports-color@8.1.1: + dependencies: + has-flag: 4.0.0 + + supports-color@9.4.0: {} + + supports-hyperlinks@3.2.0: + dependencies: + has-flag: 4.0.0 + supports-color: 7.2.0 + + supports-preserve-symlinks-flag@1.0.0: {} + + swagger2openapi@7.0.8: + dependencies: + call-me-maybe: 1.0.2 + node-fetch: 2.7.0 + node-fetch-h2: 2.3.0 + node-readfiles: 0.2.0 + oas-kit-common: 1.0.8 + oas-resolver: 2.5.6 + oas-schema-walker: 1.1.5 + oas-validator: 5.0.8 + reftools: 1.1.9 + yaml: 1.10.2 + yargs: 17.7.2 + transitivePeerDependencies: + - encoding + + symbol-tree@3.2.4: {} + + syncpack@13.0.4(typescript@5.8.3): + dependencies: + chalk: 5.4.1 + chalk-template: 1.1.0 + commander: 13.1.0 + cosmiconfig: 9.0.0(typescript@5.8.3) + effect: 3.17.9 + enquirer: 2.4.1 + fast-check: 3.23.2 + globby: 14.1.0 + jsonc-parser: 3.3.1 + minimatch: 9.0.5 + npm-package-arg: 12.0.2 + ora: 8.2.0 + prompts: 2.4.2 + read-yaml-file: 2.1.0 + semver: 7.7.2 + tightrope: 0.2.0 + ts-toolbelt: 9.6.0 + transitivePeerDependencies: + - typescript + + tabbable@6.2.0: {} + + table-layout@4.1.1: + dependencies: + array-back: 6.2.2 + wordwrapjs: 5.1.0 + + table@6.9.0: + dependencies: + ajv: 8.17.1 + lodash.truncate: 4.4.2 + slice-ansi: 4.0.0 + string-width: 4.2.3 + strip-ansi: 6.0.1 + + tailwind-merge@3.3.1: {} + + tailwind-scrollbar@3.1.0(tailwindcss@3.4.17): + dependencies: + tailwindcss: 3.4.17 + + tailwindcss@3.4.17: + dependencies: + '@alloc/quick-lru': 5.2.0 + arg: 5.0.2 + chokidar: 3.6.0 + didyoumean: 1.2.2 + dlv: 1.1.3 + fast-glob: 3.3.3 + glob-parent: 6.0.2 + is-glob: 4.0.3 + jiti: 1.21.7 + lilconfig: 3.1.3 + micromatch: 4.0.8 + normalize-path: 3.0.0 + object-hash: 3.0.0 + picocolors: 1.1.1 + postcss: 8.5.6 + postcss-import: 15.1.0(postcss@8.5.6) + postcss-js: 4.0.1(postcss@8.5.6) + postcss-load-config: 4.0.2(postcss@8.5.6) + postcss-nested: 6.2.0(postcss@8.5.6) + postcss-selector-parser: 6.1.2 + resolve: 1.22.10 + sucrase: 3.35.0 + transitivePeerDependencies: + - ts-node + + tailwindcss@4.1.11: {} + + 
tapable@2.2.2: {} + + tapable@2.2.3: {} + + tar-fs@2.1.3: + dependencies: + chownr: 1.1.4 + mkdirp-classic: 0.5.3 + pump: 3.0.3 + tar-stream: 2.2.0 + optional: true + + tar-stream@2.2.0: + dependencies: + bl: 4.1.0 + end-of-stream: 1.4.5 + fs-constants: 1.0.0 + inherits: 2.0.4 + readable-stream: 3.6.2 + optional: true + + tar@7.4.3: + dependencies: + '@isaacs/fs-minipass': 4.0.1 + chownr: 3.0.0 + minipass: 7.1.2 + minizlib: 3.0.2 + mkdirp: 3.0.1 + yallist: 5.0.0 + + terminal-link@4.0.0: + dependencies: + ansi-escapes: 7.0.0 + supports-hyperlinks: 3.2.0 + + terser-webpack-plugin@5.3.14(esbuild@0.25.8)(webpack@5.99.8(esbuild@0.25.8)): + dependencies: + '@jridgewell/trace-mapping': 0.3.31 + jest-worker: 27.5.1 + schema-utils: 4.3.2 + serialize-javascript: 6.0.2 + terser: 5.44.0 + webpack: 5.99.8(esbuild@0.25.8) + optionalDependencies: + esbuild: 0.25.8 + + terser@5.44.0: + dependencies: + '@jridgewell/source-map': 0.3.11 + acorn: 8.15.0 + commander: 2.20.3 + source-map-support: 0.5.21 + + test-exclude@6.0.0: + dependencies: + '@istanbuljs/schema': 0.1.3 + glob: 7.2.3 + minimatch: 3.1.2 + + test-exclude@7.0.1: + dependencies: + '@istanbuljs/schema': 0.1.3 + glob: 10.4.5 + minimatch: 9.0.5 + + text-table@0.2.0: {} + + textextensions@6.11.0: + dependencies: + editions: 6.21.0 + + thememirror@2.0.1(@codemirror/language@6.11.2)(@codemirror/state@6.5.2)(@codemirror/view@6.38.1): + dependencies: + '@codemirror/language': 6.11.2 + '@codemirror/state': 6.5.2 + '@codemirror/view': 6.38.1 + + thenify-all@1.6.0: + dependencies: + thenify: 3.3.1 + + thenify@3.3.1: + dependencies: + any-promise: 1.3.0 + + tightrope@0.2.0: {} + + tiny-invariant@1.3.3: {} + + tiny-warning@1.0.3: {} + + tinybench@2.9.0: {} + + tinyexec@0.3.2: {} + + tinyglobby@0.2.14: + dependencies: + fdir: 6.4.6(picomatch@4.0.3) + picomatch: 4.0.3 + + tinypool@1.1.1: {} + + tinyrainbow@2.0.0: {} + + tinyspy@4.0.3: {} + + tldts-core@6.1.86: {} + + tldts@6.1.86: + dependencies: + tldts-core: 6.1.86 + + tmp@0.2.3: {} + 
+ to-regex-range@5.0.1: + dependencies: + is-number: 7.0.0 + + totalist@3.0.1: {} + + tough-cookie@5.1.2: + dependencies: + tldts: 6.1.86 + + tr46@0.0.3: {} + + tr46@5.1.1: + dependencies: + punycode: 2.3.1 + + trim-lines@3.0.1: {} + + trough@2.2.0: {} + + ts-api-utils@2.1.0(typescript@5.8.3): + dependencies: + typescript: 5.8.3 + + ts-dedent@2.2.0: {} + + ts-interface-checker@0.1.13: {} + + ts-loader@9.5.2(typescript@5.8.3)(webpack@5.99.8(esbuild@0.25.8)): + dependencies: + chalk: 4.1.2 + enhanced-resolve: 5.18.2 + micromatch: 4.0.8 + semver: 7.7.2 + source-map: 0.7.4 + typescript: 5.8.3 + webpack: 5.99.8(esbuild@0.25.8) + + ts-toolbelt@9.6.0: {} + + tsconfck@2.1.2(typescript@5.8.3): + optionalDependencies: + typescript: 5.8.3 + + tsconfig-paths@4.2.0: + dependencies: + json5: 2.2.3 + minimist: 1.2.8 + strip-bom: 3.0.0 + + tslib@1.14.1: {} + + tslib@2.8.1: {} + + tsx@4.20.3: + dependencies: + esbuild: 0.25.8 + get-tsconfig: 4.10.1 + optionalDependencies: + fsevents: 2.3.3 + + tunnel-agent@0.6.0: + dependencies: + safe-buffer: 5.2.1 + optional: true + + tunnel@0.0.6: {} + + type-check@0.4.0: + dependencies: + prelude-ls: 1.2.1 + + type-fest@4.41.0: {} + + typed-array-buffer@1.0.3: + dependencies: + call-bound: 1.0.4 + es-errors: 1.3.0 + is-typed-array: 1.1.15 + + typed-array-byte-length@1.0.3: + dependencies: + call-bind: 1.0.8 + for-each: 0.3.5 + gopd: 1.2.0 + has-proto: 1.2.0 + is-typed-array: 1.1.15 + + typed-array-byte-offset@1.0.4: + dependencies: + available-typed-arrays: 1.0.7 + call-bind: 1.0.8 + for-each: 0.3.5 + gopd: 1.2.0 + has-proto: 1.2.0 + is-typed-array: 1.1.15 + reflect.getprototypeof: 1.0.10 + + typed-array-length@1.0.7: + dependencies: + call-bind: 1.0.8 + for-each: 0.3.5 + gopd: 1.2.0 + is-typed-array: 1.1.15 + possible-typed-array-names: 1.1.0 + reflect.getprototypeof: 1.0.10 + + typed-rest-client@1.8.11: + dependencies: + qs: 6.14.0 + tunnel: 0.0.6 + underscore: 1.13.7 + + typedoc-plugin-markdown@4.7.1(typedoc@0.28.7(typescript@5.8.3)): + 
dependencies: + typedoc: 0.28.7(typescript@5.8.3) + + typedoc@0.28.7(typescript@5.8.3): + dependencies: + '@gerrit0/mini-shiki': 3.8.1 + lunr: 2.3.9 + markdown-it: 14.1.0 + minimatch: 9.0.5 + typescript: 5.8.3 + yaml: 2.8.0 + + typescript-eslint@8.38.0(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3): + dependencies: + '@typescript-eslint/eslint-plugin': 8.38.0(@typescript-eslint/parser@8.38.0(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3))(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3) + '@typescript-eslint/parser': 8.38.0(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3) + '@typescript-eslint/typescript-estree': 8.38.0(typescript@5.8.3) + '@typescript-eslint/utils': 8.38.0(eslint@9.31.0(jiti@2.4.2))(typescript@5.8.3) + eslint: 9.31.0(jiti@2.4.2) + typescript: 5.8.3 + transitivePeerDependencies: + - supports-color + + typescript@5.8.2: {} + + typescript@5.8.3: {} + + typical@7.3.0: {} + + uc.micro@2.1.0: {} + + ufo@1.6.1: {} + + unbox-primitive@1.1.0: + dependencies: + call-bound: 1.0.4 + has-bigints: 1.1.0 + has-symbols: 1.1.0 + which-boxed-primitive: 1.1.1 + + underscore@1.13.7: {} + + undici-types@5.26.5: {} + + undici-types@6.21.0: {} + + undici-types@7.8.0: + optional: true + + undici@7.12.0: {} + + unicorn-magic@0.1.0: {} + + unicorn-magic@0.3.0: {} + + unified@11.0.5: + dependencies: + '@types/unist': 3.0.3 + bail: 2.0.2 + devlop: 1.1.0 + extend: 3.0.2 + is-plain-obj: 4.1.0 + trough: 2.2.0 + vfile: 6.0.3 + + unist-util-is@6.0.0: + dependencies: + '@types/unist': 3.0.3 + + unist-util-position@5.0.0: + dependencies: + '@types/unist': 3.0.3 + + unist-util-stringify-position@4.0.0: + dependencies: + '@types/unist': 3.0.3 + + unist-util-visit-parents@6.0.1: + dependencies: + '@types/unist': 3.0.3 + unist-util-is: 6.0.0 + + unist-util-visit@5.0.0: + dependencies: + '@types/unist': 3.0.3 + unist-util-is: 6.0.0 + unist-util-visit-parents: 6.0.1 + + universalify@2.0.1: {} + + unplugin@1.16.1: + dependencies: + acorn: 8.15.0 + webpack-virtual-modules: 0.6.2 + + unplugin@2.3.5: + 
dependencies: + acorn: 8.15.0 + picomatch: 4.0.3 + webpack-virtual-modules: 0.6.2 + + update-browserslist-db@1.1.3(browserslist@4.26.2): + dependencies: + browserslist: 4.26.2 + escalade: 3.2.0 + picocolors: 1.1.1 + + uri-js@4.4.1: + dependencies: + punycode: 2.3.1 + + urijs@1.19.11: {} + + url-join@4.0.1: {} + + use-callback-ref@1.3.3(@types/react@18.3.23)(react@18.3.1): + dependencies: + react: 18.3.1 + tslib: 2.8.1 + optionalDependencies: + '@types/react': 18.3.23 + + use-sidecar@1.1.3(@types/react@18.3.23)(react@18.3.1): + dependencies: + detect-node-es: 1.1.0 + react: 18.3.1 + tslib: 2.8.1 + optionalDependencies: + '@types/react': 18.3.23 + + use-sync-external-store@1.5.0(react@18.3.1): + dependencies: + react: 18.3.1 + + util-deprecate@1.0.2: {} + + utility-types@3.11.0: {} + + uuid@8.3.2: {} + + v8-to-istanbul@9.3.0: + dependencies: + '@jridgewell/trace-mapping': 0.3.29 + '@types/istanbul-lib-coverage': 2.0.6 + convert-source-map: 2.0.0 + + validate-npm-package-license@3.0.4: + dependencies: + spdx-correct: 3.2.0 + spdx-expression-parse: 3.0.1 + + validate-npm-package-name@6.0.2: {} + + validator@13.15.15: {} + + version-range@4.14.0: {} + + vfile-message@4.0.2: + dependencies: + '@types/unist': 3.0.3 + unist-util-stringify-position: 4.0.0 + + vfile@6.0.3: + dependencies: + '@types/unist': 3.0.3 + vfile-message: 4.0.2 + + vite-node@3.2.4(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0): + dependencies: + cac: 6.7.14 + debug: 4.4.1 + es-module-lexer: 1.7.0 + pathe: 2.0.3 + vite: 6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + transitivePeerDependencies: + - '@types/node' + - jiti + - less + - lightningcss + - sass + - sass-embedded + - stylus + - sugarss + - supports-color + - terser + - tsx + - yaml + + vite-node@3.2.4(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0): + dependencies: + cac: 6.7.14 + debug: 
4.4.1(supports-color@8.1.1) + es-module-lexer: 1.7.0 + pathe: 2.0.3 + vite: 6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + transitivePeerDependencies: + - '@types/node' + - jiti + - less + - lightningcss + - sass + - sass-embedded + - stylus + - sugarss + - supports-color + - terser + - tsx + - yaml + + vite-plugin-css-injected-by-js@3.5.2(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)): + dependencies: + vite: 6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + + vite-plugin-dts@4.5.4(@types/node@20.11.25)(rollup@4.45.1)(typescript@5.8.3)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)): + dependencies: + '@microsoft/api-extractor': 7.52.10(@types/node@20.11.25) + '@rollup/pluginutils': 5.2.0(rollup@4.45.1) + '@volar/typescript': 2.4.23 + '@vue/language-core': 2.2.0(typescript@5.8.3) + compare-versions: 6.1.1 + debug: 4.4.1 + kolorist: 1.8.0 + local-pkg: 1.1.1 + magic-string: 0.30.17 + typescript: 5.8.3 + optionalDependencies: + vite: 6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + transitivePeerDependencies: + - '@types/node' + - rollup + - supports-color + + vite-plugin-static-copy@3.1.1(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)): + dependencies: + chokidar: 3.6.0 + fs-extra: 11.3.0 + p-map: 7.0.3 + picocolors: 1.1.1 + tinyglobby: 0.2.14 + vite: 6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + + vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0): + dependencies: + esbuild: 0.25.8 + fdir: 6.4.6(picomatch@4.0.3) + picomatch: 4.0.3 + postcss: 8.5.6 + rollup: 4.45.1 + tinyglobby: 0.2.14 + optionalDependencies: + '@types/node': 20.11.25 
+ fsevents: 2.3.3 + jiti: 2.4.2 + lightningcss: 1.30.1 + terser: 5.44.0 + tsx: 4.20.3 + yaml: 2.8.0 + + vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0): + dependencies: + esbuild: 0.25.8 + fdir: 6.4.6(picomatch@4.0.3) + picomatch: 4.0.3 + postcss: 8.5.6 + rollup: 4.45.1 + tinyglobby: 0.2.14 + optionalDependencies: + '@types/node': 24.1.0 + fsevents: 2.3.3 + jiti: 2.4.2 + lightningcss: 1.30.1 + terser: 5.44.0 + tsx: 4.20.3 + yaml: 2.8.0 + + vitest@3.2.4(@types/debug@4.1.12)(@types/node@20.11.25)(@vitest/browser@3.2.4)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0): + dependencies: + '@types/chai': 5.2.2 + '@vitest/expect': 3.2.4 + '@vitest/mocker': 3.2.4(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + '@vitest/pretty-format': 3.2.4 + '@vitest/runner': 3.2.4 + '@vitest/snapshot': 3.2.4 + '@vitest/spy': 3.2.4 + '@vitest/utils': 3.2.4 + chai: 5.2.1 + debug: 4.4.1 + expect-type: 1.2.2 + magic-string: 0.30.17 + pathe: 2.0.3 + picomatch: 4.0.3 + std-env: 3.9.0 + tinybench: 2.9.0 + tinyexec: 0.3.2 + tinyglobby: 0.2.14 + tinypool: 1.1.1 + tinyrainbow: 2.0.0 + vite: 6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + vite-node: 3.2.4(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + why-is-node-running: 2.3.0 + optionalDependencies: + '@types/debug': 4.1.12 + '@types/node': 20.11.25 + '@vitest/browser': 3.2.4(playwright@1.54.1)(vite@6.3.5(@types/node@20.11.25)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))(vitest@3.2.4) + '@vitest/ui': 3.2.4(vitest@3.2.4) + jsdom: 26.1.0 + transitivePeerDependencies: + - jiti + - less + - lightningcss + - msw + - sass + - sass-embedded + - stylus + - sugarss + - supports-color + - terser + - tsx + - yaml + + 
vitest@3.2.4(@types/debug@4.1.12)(@types/node@24.1.0)(@vitest/browser@3.2.3)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0): + dependencies: + '@types/chai': 5.2.2 + '@vitest/expect': 3.2.4 + '@vitest/mocker': 3.2.4(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + '@vitest/pretty-format': 3.2.4 + '@vitest/runner': 3.2.4 + '@vitest/snapshot': 3.2.4 + '@vitest/spy': 3.2.4 + '@vitest/utils': 3.2.4 + chai: 5.2.1 + debug: 4.4.1(supports-color@8.1.1) + expect-type: 1.2.2 + magic-string: 0.30.17 + pathe: 2.0.3 + picomatch: 4.0.3 + std-env: 3.9.0 + tinybench: 2.9.0 + tinyexec: 0.3.2 + tinyglobby: 0.2.14 + tinypool: 1.1.1 + tinyrainbow: 2.0.0 + vite: 6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + vite-node: 3.2.4(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + why-is-node-running: 2.3.0 + optionalDependencies: + '@types/debug': 4.1.12 + '@types/node': 24.1.0 + '@vitest/browser': 3.2.3(playwright@1.54.1)(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))(vitest@3.2.4) + '@vitest/ui': 3.2.4(vitest@3.2.4) + jsdom: 26.1.0 + transitivePeerDependencies: + - jiti + - less + - lightningcss + - msw + - sass + - sass-embedded + - stylus + - sugarss + - supports-color + - terser + - tsx + - yaml + + vitest@3.2.4(@types/debug@4.1.12)(@types/node@24.1.0)(@vitest/browser@3.2.4)(@vitest/ui@3.2.4)(jiti@2.4.2)(jsdom@26.1.0)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0): + dependencies: + '@types/chai': 5.2.2 + '@vitest/expect': 3.2.4 + '@vitest/mocker': 3.2.4(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0)) + '@vitest/pretty-format': 3.2.4 + '@vitest/runner': 3.2.4 + '@vitest/snapshot': 3.2.4 + '@vitest/spy': 3.2.4 + '@vitest/utils': 3.2.4 + chai: 5.2.1 + debug: 
4.4.1(supports-color@8.1.1) + expect-type: 1.2.2 + magic-string: 0.30.17 + pathe: 2.0.3 + picomatch: 4.0.3 + std-env: 3.9.0 + tinybench: 2.9.0 + tinyexec: 0.3.2 + tinyglobby: 0.2.14 + tinypool: 1.1.1 + tinyrainbow: 2.0.0 + vite: 6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + vite-node: 3.2.4(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0) + why-is-node-running: 2.3.0 + optionalDependencies: + '@types/debug': 4.1.12 + '@types/node': 24.1.0 + '@vitest/browser': 3.2.4(playwright@1.54.1)(vite@6.3.5(@types/node@24.1.0)(jiti@2.4.2)(lightningcss@1.30.1)(terser@5.44.0)(tsx@4.20.3)(yaml@2.8.0))(vitest@3.2.4) + '@vitest/ui': 3.2.4(vitest@3.2.4) + jsdom: 26.1.0 + transitivePeerDependencies: + - jiti + - less + - lightningcss + - msw + - sass + - sass-embedded + - stylus + - sugarss + - supports-color + - terser + - tsx + - yaml + + vscode-jsonrpc@8.2.0: {} + + vscode-jsonrpc@8.2.1: {} + + vscode-languageclient@9.0.1: + dependencies: + minimatch: 5.1.6 + semver: 7.7.2 + vscode-languageserver-protocol: 3.17.5 + + vscode-languageserver-protocol@3.17.5: + dependencies: + vscode-jsonrpc: 8.2.0 + vscode-languageserver-types: 3.17.5 + + vscode-languageserver-types@3.17.5: {} + + vscode-uri@3.1.0: {} + + w3c-keyname@2.2.8: {} + + w3c-xmlserializer@5.0.0: + dependencies: + xml-name-validator: 5.0.0 + + watchpack@2.4.4: + dependencies: + glob-to-regexp: 0.4.1 + graceful-fs: 4.2.11 + + web-vitals@4.2.4: {} + + webidl-conversions@3.0.1: {} + + webidl-conversions@7.0.0: {} + + webpack-sources@3.3.3: {} + + webpack-virtual-modules@0.6.2: {} + + webpack@5.99.8(esbuild@0.25.8): + dependencies: + '@types/eslint-scope': 3.7.7 + '@types/estree': 1.0.8 + '@types/json-schema': 7.0.15 + '@webassemblyjs/ast': 1.14.1 + '@webassemblyjs/wasm-edit': 1.14.1 + '@webassemblyjs/wasm-parser': 1.14.1 + acorn: 8.15.0 + browserslist: 4.26.2 + chrome-trace-event: 1.0.4 + enhanced-resolve: 5.18.3 + es-module-lexer: 
1.7.0 + eslint-scope: 5.1.1 + events: 3.3.0 + glob-to-regexp: 0.4.1 + graceful-fs: 4.2.11 + json-parse-even-better-errors: 2.3.1 + loader-runner: 4.3.0 + mime-types: 2.1.35 + neo-async: 2.6.2 + schema-utils: 4.3.2 + tapable: 2.2.3 + terser-webpack-plugin: 5.3.14(esbuild@0.25.8)(webpack@5.99.8(esbuild@0.25.8)) + watchpack: 2.4.4 + webpack-sources: 3.3.3 + transitivePeerDependencies: + - '@swc/core' + - esbuild + - uglify-js + + whatwg-encoding@3.1.1: + dependencies: + iconv-lite: 0.6.3 + + whatwg-mimetype@4.0.0: {} + + whatwg-url@14.2.0: + dependencies: + tr46: 5.1.1 + webidl-conversions: 7.0.0 + + whatwg-url@5.0.0: + dependencies: + tr46: 0.0.3 + webidl-conversions: 3.0.1 + + which-boxed-primitive@1.1.1: + dependencies: + is-bigint: 1.1.0 + is-boolean-object: 1.2.2 + is-number-object: 1.1.1 + is-string: 1.1.1 + is-symbol: 1.1.1 + + which-builtin-type@1.2.1: + dependencies: + call-bound: 1.0.4 + function.prototype.name: 1.1.8 + has-tostringtag: 1.0.2 + is-async-function: 2.1.1 + is-date-object: 1.1.0 + is-finalizationregistry: 1.1.1 + is-generator-function: 1.1.0 + is-regex: 1.2.1 + is-weakref: 1.1.1 + isarray: 2.0.5 + which-boxed-primitive: 1.1.1 + which-collection: 1.0.2 + which-typed-array: 1.1.19 + + which-collection@1.0.2: + dependencies: + is-map: 2.0.3 + is-set: 2.0.3 + is-weakmap: 2.0.2 + is-weakset: 2.0.4 + + which-typed-array@1.1.19: + dependencies: + available-typed-arrays: 1.0.7 + call-bind: 1.0.8 + call-bound: 1.0.4 + for-each: 0.3.5 + get-proto: 1.0.1 + gopd: 1.2.0 + has-tostringtag: 1.0.2 + + which@2.0.2: + dependencies: + isexe: 2.0.0 + + why-is-node-running@2.3.0: + dependencies: + siginfo: 2.0.0 + stackback: 0.0.2 + + word-wrap@1.2.5: {} + + wordwrapjs@5.1.0: {} + + workerpool@6.5.1: {} + + wrap-ansi@7.0.0: + dependencies: + ansi-styles: 4.3.0 + string-width: 4.2.3 + strip-ansi: 6.0.1 + + wrap-ansi@8.1.0: + dependencies: + ansi-styles: 6.2.1 + string-width: 5.1.2 + strip-ansi: 7.1.0 + + wrappy@1.0.2: {} + + ws@8.18.3: {} + + wsl-utils@0.1.0: + 
dependencies: + is-wsl: 3.1.0 + + xml-name-validator@5.0.0: {} + + xml2js@0.5.0: + dependencies: + sax: 1.4.1 + xmlbuilder: 11.0.1 + + xmlbuilder@11.0.1: {} + + xmlchars@2.2.0: {} + + y18n@5.0.8: {} + + yallist@3.1.1: {} + + yallist@4.0.0: {} + + yallist@5.0.0: {} + + yaml@1.10.2: {} + + yaml@2.8.0: {} + + yargs-parser@20.2.9: {} + + yargs-parser@21.1.1: {} + + yargs-unparser@2.0.0: + dependencies: + camelcase: 6.3.0 + decamelize: 4.0.0 + flat: 5.0.2 + is-plain-obj: 2.1.0 + + yargs@16.2.0: + dependencies: + cliui: 7.0.4 + escalade: 3.2.0 + get-caller-file: 2.0.5 + require-directory: 2.1.1 + string-width: 4.2.3 + y18n: 5.0.8 + yargs-parser: 20.2.9 + + yargs@17.7.2: + dependencies: + cliui: 8.0.1 + escalade: 3.2.0 + get-caller-file: 2.0.5 + require-directory: 2.1.1 + string-width: 4.2.3 + y18n: 5.0.8 + yargs-parser: 21.1.1 + + yauzl@2.10.0: + dependencies: + buffer-crc32: 0.2.13 + fd-slicer: 1.1.0 + + yazl@2.5.1: + dependencies: + buffer-crc32: 0.2.13 + + yocto-queue@0.1.0: {} + + yocto-queue@1.2.1: {} + + zod@3.25.76: {} + + zustand@4.5.7(@types/react@18.3.23)(immer@9.0.21)(react@18.3.1): + dependencies: + use-sync-external-store: 1.5.0(react@18.3.1) + optionalDependencies: + '@types/react': 18.3.23 + immer: 9.0.21 + react: 18.3.1 + + zustand@5.0.6(@types/react@18.3.23)(immer@9.0.21)(react@18.3.1)(use-sync-external-store@1.5.0(react@18.3.1)): + optionalDependencies: + '@types/react': 18.3.23 + immer: 9.0.21 + react: 18.3.1 + use-sync-external-store: 1.5.0(react@18.3.1) + + zwitch@2.0.4: {} diff --git a/pnpm-workspace.yaml b/pnpm-workspace.yaml new file mode 100644 index 0000000000..fe1b5be597 --- /dev/null +++ b/pnpm-workspace.yaml @@ -0,0 +1,6 @@ +packages: + - vscode/bus + - vscode/extension + - vscode/react + - web/client + - web/common diff --git a/posts/virtual_data_environments.md b/posts/virtual_data_environments.md index dc3b2cb46e..5cde9dba51 100644 --- a/posts/virtual_data_environments.md +++ b/posts/virtual_data_environments.md @@ -8,7 +8,7 @@ In this 
post, I'm going to explain why existing approaches to managing developme I'll introduce [Virtual Data Environments](#virtual-data-environments-1) - a novel approach that provides low-cost, efficient, scalable, and safe data environments that are easy to use and manage. They significantly boost the productivity of anyone who has to create or maintain data pipelines. -Finally, I’m going to explain how **Virtual Data Environments** are implemented in [SQLMesh](https://github.com/TobikoData/sqlmesh) and share details on each core component involved: +Finally, I’m going to explain how **Virtual Data Environments** are implemented in [SQLMesh](https://github.com/SQLMesh/sqlmesh) and share details on each core component involved: - Data [fingerprinting](#fingerprinting) - [Automatic change categorization](#automatic-change-categorization) - Decoupling of [physical](#physical-layer) and [virtual](#virtual-layer) layers @@ -156,6 +156,6 @@ With **Virtual Data Environments**, SQLMesh is able to provide fully **isolated* - Rolling back a change happens almost instantaneously since no data movement is involved and only views that are part of the **virtual layer** get updated. - Deploying changes to production is a **virtual layer** operation, which ensures that results observed during development are exactly the same in production and that data and code are always in sync. -To streamline deploying changes to production, our team is about to release the SQLMesh [CI/CD bot](https://github.com/TobikoData/sqlmesh/blob/main/docs/integrations/github.md), which will help automate this process. +To streamline deploying changes to production, our team is about to release the SQLMesh [CI/CD bot](https://github.com/SQLMesh/sqlmesh/blob/main/docs/integrations/github.md), which will help automate this process. Don't miss out - join our [Slack channel](https://tobikodata.com/slack) and stay tuned! 
diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000..ebfc112567 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,317 @@ +[project] +name = "sqlmesh" +dynamic = ["version"] +description = "Next-generation data transformation framework" +readme = "README.md" +authors = [{ name = "SQLMesh Contributors" }] +license = { file = "LICENSE" } +requires-python = ">= 3.9" +dependencies = [ + "astor", + "click", + "croniter", + "duckdb>=0.10.0,!=0.10.3", + "dateparser<=1.2.1", + "humanize", + "hyperscript>=0.1.0", + "importlib-metadata; python_version<'3.12'", + "ipywidgets", + "jinja2", + "packaging", + "pandas<3.0.0", + "pydantic>=2.0.0", + "python-dotenv", + "requests", + "rich[jupyter]", + "ruamel.yaml", + "sqlglot[rs]~=28.10.1", + "tenacity", + "time-machine", + "json-stream" +] +classifiers = [ + "Intended Audience :: Developers", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", + "Programming Language :: SQL", + "Programming Language :: Python :: 3 :: Only", +] + +[project.optional-dependencies] +athena = ["PyAthena[Pandas]"] +azuresql = ["pymssql"] +azuresql-odbc = ["pyodbc>=5.0.0"] +bigquery = [ + "google-cloud-bigquery[pandas]", + "google-cloud-bigquery-storage" +] +# bigframes has to be separate to support environments with an older google-cloud-bigquery pin +# this is because that pin pulls in an older bigframes and the bigframes team +# pinned an older SQLGlot which is incompatible with SQLMesh +bigframes = ["bigframes>=1.32.0"] +clickhouse = ["clickhouse-connect"] +databricks = ["databricks-sql-connector[pyarrow]"] +dev = [ + "agate", + "beautifulsoup4", + "clickhouse-connect", + "cryptography", + "databricks-sql-connector", + "dbt-bigquery", + "dbt-core", + "dbt-duckdb>=1.7.1", + # version 1.10.1 of dbt-snowflake declares that it's compatible with dbt-adapters>=1.16 but in reality + # it depends on the 'InvalidCatalogIntegrationConfigError' class 
that only exists as of dbt-adapters==1.16.6 + # so we exclude it to prevent failures and hope that upstream releases a new version with the correct constraint + "dbt-snowflake!=1.10.1", + "dbt-athena-community", + "dbt-clickhouse", + "dbt-databricks", + "dbt-redshift", + "dbt-trino", + "Faker", + "google-auth", + "google-cloud-bigquery", + "google-cloud-bigquery-storage", + "httpx", + "mypy~=1.13.0", + "numpy", + "pandas-stubs", + "pre-commit", + "psycopg2-binary", + "pydantic", + "PyAthena[Pandas]", + "PyGithub>=2.6.0", + "pyodbc>=5.0.0", + "pyperf", + "pyspark~=3.5.0", + "pytest", + "pytest-asyncio", + "pytest-mock", + "pytest-rerunfailures", + "pytest-xdist", + "pytz", + "redshift_connector", + "ruff~=0.11.0", + "snowflake-connector-python[pandas,secure-local-storage]>=3.0.2", + "sqlalchemy-stubs", + "trino", + "types-croniter", + "types-dateparser", + "types-PyMySQL", + "types-python-dateutil", + "types-pytz", + "types-requests==2.28.8", + "typing-extensions", +] +dbt = ["dbt-core<2"] +dlt = ["dlt"] +duckdb = [] +fabric = ["pyodbc>=5.0.0"] +gcppostgres = ["cloud-sql-python-connector[pg8000]>=1.8.0"] +github = ["PyGithub>=2.6.0"] +motherduck = ["duckdb>=1.3.2"] +mssql = ["pymssql"] +mssql-odbc = ["pyodbc>=5.0.0"] +mysql = ["pymysql"] +mwaa = ["boto3"] +postgres = ["psycopg2"] +redshift = ["redshift_connector"] +slack = ["slack_sdk"] +snowflake = [ + "cryptography", + "snowflake-connector-python[pandas,secure-local-storage]", + "snowflake-snowpark-python", +] +trino = ["trino"] +web = [ + "fastapi==0.120.1", + "watchfiles>=0.19.0", + "uvicorn[standard]==0.22.0", + "sse-starlette>=0.2.2", + "pyarrow", +] +lsp = [ + # Duplicate of web + "fastapi==0.120.1", + "watchfiles>=0.19.0", + # "uvicorn[standard]==0.22.0", + "sse-starlette>=0.2.2", + "pyarrow", + # For lsp + "pygls>=1.2.0,<2.0.0", + "lsprotocol", +] +risingwave = ["psycopg2"] + +[project.scripts] +sqlmesh = "sqlmesh.cli.main:cli" +sqlmesh_dbt = "sqlmesh_dbt.cli:dbt" +sqlmesh_cicd = "sqlmesh.cicd.bot:bot" 
+sqlmesh_lsp = "sqlmesh.lsp.main:main" + +[project.urls] +Homepage = "https://sqlmesh.com/" +Documentation = "https://sqlmesh.readthedocs.io/en/stable/" +Repository = "https://github.com/SQLMesh/sqlmesh" +Issues = "https://github.com/SQLMesh/sqlmesh/issues" + +[build-system] +requires = ["setuptools >= 61.0", "setuptools_scm"] +build-backend = "setuptools.build_meta" + +[tool.setuptools] +include-package-data = false + +[tool.setuptools_scm] +version_file = "sqlmesh/_version.py" +fallback_version = "0.0.0" +local_scheme = "no-local-version" + +[tool.setuptools.packages.find] +include = ["sqlmesh", "sqlmesh.*", "sqlmesh_dbt", "sqlmesh_dbt.*", "web*"] + +[tool.setuptools.package-data] +web = ["client/dist/**"] +"*" = ["py.typed"] + +# MyPy Rules +[tool.mypy] +plugins = "pydantic.mypy" +no_implicit_optional = true +disallow_untyped_defs = true + +[[tool.mypy.overrides]] +module = [ + "examples.*.macros.*", + "tests.*", + "sqlmesh.migrations.*" +] +disallow_untyped_defs = false +# Sometimes it's helpful to use types within an "untyped" function because it allows IDE assistance +# Unfortunately this causes MyPy to print an annoying 'By default the bodies of untyped functions are not checked' +# warning so we disable that warning here +disable_error_code = "annotation-unchecked" + +[[tool.mypy.overrides]] +module = [ + "api.*", + "astor.*", + "IPython.*", + "hyperscript.*", + "py.*", + "ruamel.*", + "setuptools.*", + "graphviz.*", + "ipywidgets.*", + "google.*", + "snowflake.*", + "redshift_connector", + "databricks.*", + "faker.*", + "agate.*", + "databricks_cli.*", + "mysql.*", + "pymssql.*", + "pyodbc.*", + "psycopg2.*", + "pytest_lazyfixture.*", + "dbt.adapters.*", + "slack_sdk.*", + "py4j.*", + "boto3.*", + "trino.*", + "bs4.*", + "pydantic_core.*", + "dlt.*", + "bigframes.*", + "json_stream.*", + "duckdb.*" +] +ignore_missing_imports = true + +[tool.pytest.ini_options] +markers = [ + # Test Type Markers + # Tests are ordered from fastest to slowest + "fast: fast 
tests (automatically applied if no type markers)", + "slow: slow tests that typically involve interacting with a local DB (like DuckDB)", + "docker: test that involves interacting with a Docker container", + "remote: test that involves interacting with a remote DB", + "cicdonly: test that only runs on CI/CD", + "isolated: tests that need to run sequentially usually because they use fork", + "dialect_isolated: tests that need to run separately due to global dialect overrides", + + # Test Domain Markers + # default: core functionality + "cli: test for CLI", + "dbt: test for dbt adapter", + "github: test for Github CI/CD bot", + "jupyter: tests for Jupyter integration", + "web: tests for web UI", + + # Engine Adapters + "engine: test all engine adapters", + "athena: test for Athena", + "bigquery: test for BigQuery", + "clickhouse: test for Clickhouse (standalone mode / cluster mode)", + "clickhouse_cloud: test for Clickhouse (cloud mode)", + "databricks: test for Databricks", + "duckdb: test for DuckDB", + "fabric: test for Fabric", + "motherduck: test for MotherDuck", + "mssql: test for MSSQL", + "mysql: test for MySQL", + "postgres: test for Postgres", + "gcp_postgres: test for Postgres on GCP", + "redshift: test for Redshift", + "snowflake: test for Snowflake", + "spark: test for Spark", + "pyspark: test for PySpark that need to run separately from the other spark tests", + "trino: test for Trino (all connectors)", + "risingwave: test for Risingwave", + + # Other + "set_default_connection", + "registry_isolation" +] +addopts = "-n 0 --dist=loadgroup" +asyncio_default_fixture_loop_scope = "session" +log_cli = false # Set this to true to enable logging during tests +log_cli_format = "%(asctime)s.%(msecs)03d %(filename)s:%(lineno)d %(levelname)s %(message)s" +log_cli_level = "INFO" +filterwarnings = [ + "ignore:The localize method is no longer necessary, as this time zone supports the fold attribute" +] +reruns_delay = 10 + +[tool.ruff] +line-length = 100 + 
+[tool.ruff.lint] +select = [ + "F401", + "RET505", + "T100", +] +extend-select = ["TID"] + + +[tool.ruff.lint.flake8-tidy-imports] +banned-module-level-imports = [ + "duckdb", + "numpy", + "pandas", +] + +# Bans imports from sqlmesh.lsp in files outside of sqlmesh/lsp +[tool.ruff.lint.flake8-tidy-imports.banned-api] +"sqlmesh.lsp".msg = "Only files within sqlmesh/lsp can import from sqlmesh.lsp" + +[tool.ruff.lint.per-file-ignores] +# TID251 is used to ignore the import of sqlmesh.lsp in files outside sqlmesh/lsp +"sqlmesh/lsp/**/*.py" = ["TID251"] +"tests/lsp/**/*.py" = ["TID251"] +"benchmarks/lsp*.py" = ["TID251"] +"sqlmesh/dbt/builtin.py" = ["T100"] diff --git a/pytest.ini b/pytest.ini deleted file mode 100644 index 915c8004e4..0000000000 --- a/pytest.ini +++ /dev/null @@ -1,42 +0,0 @@ -[pytest] -markers = - # Test Type Markers - # Tests are ordered from fastest to slowest - fast: fast tests (automatically applied if no type markers) - slow: slow tests that typically involve interacting with a local DB (like DuckDB) - docker: test that involves interacting with a Docker container - remote: test that involves interacting with a remote DB - cicdonly: test that only runs on CI/CD - - # Test Domain Markers - # default: core functionality - airflow: test for Airflow scheduler - cli: test for CLI - dbt: test for dbt adapter - github: test for Github CI/CD bot - jupyter: tests for Jupyter integration - web: tests for web UI - spark_pyspark: test for Spark with PySpark dependency - # Engine Adapters - engine: test all engine adapters - bigquery: test for BigQuery - databricks: test for Databricks - duckdb: test for DuckDB - motherduck: test for MotherDuck - mssql: test for MSSQL - mysql: test for MySQL - postgres: test for Postgres - redshift: test for Redshift - snowflake: test for Snowflake - spark: test for Spark - trino: test for Trino (Hive connector) - trino_iceberg: test for Trino (Iceberg connector) - trino_delta: test for Trino (Delta connector) -addopts = -n 
0 --dist=loadgroup - -# Set this to True to enable logging during tests -log_cli = False -log_cli_format = %(asctime)s.%(msecs)03d %(filename)s:%(lineno)d %(levelname)s %(message)s -log_cli_level = INFO -filterwarnings = - ignore:The localize method is no longer necessary, as this time zone supports the fold attribute diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index f7e19479b8..0000000000 --- a/setup.cfg +++ /dev/null @@ -1,100 +0,0 @@ -[metadata] -version = attr: sqlmesh.__version__ - -[mypy] -plugins = pydantic.mypy -no_implicit_optional = True -disallow_untyped_defs = True - -[mypy-api.*] -ignore_missing_imports = True - -[mypy-examples.*.macros.*] -disallow_untyped_defs = False - -[mypy-airflow.*] -ignore_missing_imports = True - -[mypy-tests.*] -disallow_untyped_defs = False - -[mypy-astor.*] -ignore_missing_imports = True - -[mypy-IPython.*] -ignore_missing_imports = True - -[mypy-hyperscript.*] -ignore_missing_imports = True - -[mypy-py.*] -ignore_missing_imports = True - -[mypy-ruamel.*] -ignore_missing_imports = True - -[mypy-setuptools.*] -ignore_missing_imports = True - -[mypy-graphviz.*] -ignore_missing_imports = True - -[mypy-ipywidgets.*] -ignore_missing_imports = True - -[mypy-google.*] -ignore_missing_imports = True - -[mypy-snowflake.*] -ignore_missing_imports = True - -[mypy-redshift_connector] -ignore_missing_imports = True - -[mypy-databricks.*] -ignore_missing_imports = True - -[mypy-faker.*] -ignore_missing_imports = True - -[mypy-agate.*] -ignore_missing_imports = True - -[mypy-databricks_cli.*] -ignore_missing_imports = True - -[mypy-mysql.*] -ignore_missing_imports = True - -[mypy-pymssql.*] -ignore_missing_imports = True - -[mypy-psycopg2.*] -ignore_missing_imports = True - -[mypy-langchain.*] -ignore_missing_imports = True - -[mypy-pytest_lazyfixture.*] -ignore_missing_imports = True - -[mypy-dbt.adapters.*] -ignore_missing_imports = True - -[mypy-slack_sdk.*] -ignore_missing_imports = True - -[mypy-py4j.*] 
-ignore_missing_imports = True - -[mypy-boto3.*] -ignore_missing_imports = True - -[mypy-trino.*] -ignore_missing_imports = True - -[mypy-bs4.*] -ignore_missing_imports = True - -[mypy-pydantic_core.*] -ignore_missing_imports = True diff --git a/setup.py b/setup.py deleted file mode 100644 index 9cb888d610..0000000000 --- a/setup.py +++ /dev/null @@ -1,159 +0,0 @@ -import os -from os.path import exists - -from setuptools import find_packages, setup - -description = open("README.md").read() if exists("README.md") else "" - -setup( - name="sqlmesh", - description="", - long_description=description, - long_description_content_type="text/markdown", - url="https://github.com/TobikoData/sqlmesh", - author="TobikoData Inc.", - author_email="engineering@tobikodata.com", - license="Apache License 2.0", - packages=find_packages(include=["sqlmesh", "sqlmesh.*", "web*"]), - package_data={"web": ["client/dist/**"], "": ["py.typed"]}, - entry_points={ - "console_scripts": [ - "sqlmesh = sqlmesh.cli.main:cli", - "sqlmesh_cicd = sqlmesh.cicd.bot:bot", - ], - "airflow.plugins": [ - "sqlmesh_airflow = sqlmesh.schedulers.airflow.plugin:SqlmeshAirflowPlugin", - ], - }, - use_scm_version={ - "write_to": "sqlmesh/_version.py", - "fallback_version": "0.0.0", - "local_scheme": "no-local-version", - }, - setup_requires=["setuptools_scm"], - install_requires=[ - "astor", - "click", - "croniter", - "duckdb!=0.10.3", - "dateparser", - "freezegun", - "hyperscript>=0.1.0", - "importlib-metadata; python_version<'3.12'", - "ipywidgets", - "jinja2", - "pandas", - "pydantic", - "requests", - "rich[jupyter]", - "ruamel.yaml", - "sqlglot[rs]~=25.6.0", - ], - extras_require={ - "bigquery": [ - "google-cloud-bigquery[pandas]", - "google-cloud-bigquery-storage", - ], - "databricks": [ - "databricks-sql-connector", - "databricks-cli", - ], - "dev": [ - f"apache-airflow=={os.environ.get('AIRFLOW_VERSION', '2.9.1')}", - "agate==1.7.1", - "beautifulsoup4", - "ruff~=0.4.0", - "cryptography~=42.0.4", - 
"dbt-core", - "dbt-duckdb>=1.7.1", - "dbt-snowflake", - "dbt-bigquery", - "Faker", - "google-auth", - "google-cloud-bigquery", - "google-cloud-bigquery-storage", - "mypy~=1.10.0", - "pre-commit", - "pandas-stubs", - "psycopg2-binary", - "pydantic<2.6.0", - "PyGithub", - "pytest", - "pytest-asyncio<0.23.0", - "pytest-mock", - "pytest-xdist", - "pyspark~=3.5.0", - "pytz", - "snowflake-connector-python[pandas,secure-local-storage]>=3.0.2", - "sqlalchemy-stubs", - "tenacity==8.1.0", - "types-croniter", - "types-dateparser", - "types-python-dateutil", - "types-pytz", - "types-requests==2.28.8", - "typing-extensions", - "custom-materializations", - ], - "cicdtest": [ - "dbt-databricks", - "dbt-redshift", - "dbt-sqlserver>=1.7.0", - "dbt-trino", - ], - "dbt": [ - "dbt-core<2", - ], - "gcppostgres": [ - "cloud-sql-python-connector[pg8000]", - ], - "github": [ - "PyGithub", - ], - "llm": [ - "langchain", - "openai", - ], - "mssql": [ - "pymssql", - ], - "mysql": [ - "mysql-connector-python", - ], - "mwaa": [ - "boto3", - ], - "postgres": [ - "psycopg2", - ], - "redshift": [ - "redshift_connector", - ], - "slack": [ - "slack_sdk", - ], - "snowflake": [ - # https://github.com/dbt-labs/dbt-snowflake/blob/main/dev-requirements.txt#L12 - "cryptography~=42.0.4", - "snowflake-connector-python[pandas,secure-local-storage]", - ], - "trino": [ - "trino", - ], - "web": [ - "fastapi==0.110.2", - "watchfiles>=0.19.0", - "uvicorn[standard]==0.22.0", - "sse-starlette>=0.2.2", - "pyarrow", - ], - }, - classifiers=[ - "Intended Audience :: Developers", - "Intended Audience :: Science/Research", - "License :: OSI Approved :: Apache Software License", - "Operating System :: OS Independent", - "Programming Language :: SQL", - "Programming Language :: Python :: 3 :: Only", - ], -) diff --git a/sqlmesh-technical-charter.pdf b/sqlmesh-technical-charter.pdf new file mode 100644 index 0000000000..107f015050 Binary files /dev/null and b/sqlmesh-technical-charter.pdf differ diff --git a/sqlmesh.png 
b/sqlmesh.png deleted file mode 100644 index 3a786f7a16..0000000000 Binary files a/sqlmesh.png and /dev/null differ diff --git a/sqlmesh/__init__.py b/sqlmesh/__init__.py index bfa0094dc5..577a3aaf02 100644 --- a/sqlmesh/__init__.py +++ b/sqlmesh/__init__.py @@ -24,14 +24,18 @@ from sqlmesh.core.engine_adapter import EngineAdapter as EngineAdapter from sqlmesh.core.macros import SQL as SQL, macro as macro from sqlmesh.core.model import Model as Model, model as model +from sqlmesh.core.signal import signal as signal from sqlmesh.core.snapshot import Snapshot as Snapshot from sqlmesh.core.snapshot.evaluator import ( CustomMaterialization as CustomMaterialization, ) +from sqlmesh.core.model.kind import CustomKind as CustomKind from sqlmesh.utils import ( debug_mode_enabled as debug_mode_enabled, enable_debug_mode as enable_debug_mode, + str_to_bool, ) +from sqlmesh.utils.date import DatetimeRanges as DatetimeRanges try: from sqlmesh._version import __version__ as __version__, __version_tuple__ as __version_tuple__ @@ -51,6 +55,7 @@ class RuntimeEnv(str, Enum): GOOGLE_COLAB = "google_colab" # Not currently officially supported JUPYTER = "jupyter" DEBUGGER = "debugger" + CI = "ci" # CI or other envs that shouldn't use emojis @classmethod def get(cls) -> RuntimeEnv: @@ -59,6 +64,16 @@ def get(cls) -> RuntimeEnv: Unlike the rich implementation we try to split out by notebook type instead of treating it all as Jupyter. """ + runtime_env_var = os.getenv("SQLMESH_RUNTIME_ENVIRONMENT") + if runtime_env_var: + try: + return RuntimeEnv(runtime_env_var) + except ValueError: + valid_values = [f'"{member.value}"' for member in RuntimeEnv] + raise ValueError( + f"Invalid SQLMESH_RUNTIME_ENVIRONMENT value: {runtime_env_var}. Must be one of {', '.join(valid_values)}." 
+ ) + try: shell = get_ipython() # type: ignore if os.getenv("DATABRICKS_RUNTIME_VERSION"): @@ -72,6 +87,10 @@ def get(cls) -> RuntimeEnv: if debug_mode_enabled(): return RuntimeEnv.DEBUGGER + + if is_cicd_environment() or not is_interactive_environment(): + return RuntimeEnv.CI + return RuntimeEnv.TERMINAL @property @@ -90,9 +109,26 @@ def is_jupyter(self) -> bool: def is_google_colab(self) -> bool: return self == RuntimeEnv.GOOGLE_COLAB + @property + def is_ci(self) -> bool: + return self == RuntimeEnv.CI + @property def is_notebook(self) -> bool: - return not self.is_terminal + return not self.is_terminal and not self.is_ci + + +def is_cicd_environment() -> bool: + for key in ("CI", "GITHUB_ACTIONS", "TRAVIS", "CIRCLECI", "GITLAB_CI", "BUILDKITE"): + if str_to_bool(os.environ.get(key, "false")): + return True + return False + + +def is_interactive_environment() -> bool: + if sys.stdin is None or sys.stdout is None: + return False + return sys.stdin.isatty() and sys.stdout.isatty() if RuntimeEnv.get().is_notebook: @@ -132,52 +168,79 @@ def format(self, record: logging.LogRecord) -> str: return formatter.format(record) +def remove_excess_logs( + log_file_dir: t.Optional[t.Union[str, Path]] = None, + log_limit: int = c.DEFAULT_LOG_LIMIT, +) -> None: + if log_limit <= 0: + return + + log_file_dir = log_file_dir or c.DEFAULT_LOG_FILE_DIR + log_path_prefix = Path(log_file_dir) / LOG_FILENAME_PREFIX + + for path in list(sorted(glob.glob(f"{log_path_prefix}*.log"), reverse=True))[log_limit:]: + os.remove(path) + + def configure_logging( force_debug: bool = False, - ignore_warnings: bool = False, write_to_stdout: bool = False, write_to_file: bool = True, - log_limit: int = c.DEFAULT_LOG_LIMIT, log_file_dir: t.Optional[t.Union[str, Path]] = None, + ignore_warnings: bool = False, + log_level: t.Optional[t.Union[str, int]] = None, ) -> None: + # Remove noisy grpc logs that are not useful for users + os.environ["GRPC_VERBOSITY"] = os.environ.get("GRPC_VERBOSITY", "NONE") + 
logger = logging.getLogger() debug = force_debug or debug_mode_enabled() - # base logger needs to be the lowest level that we plan to log - level = logging.DEBUG if debug else logging.INFO + if log_level is not None: + if isinstance(log_level, str): + level = logging._nameToLevel.get(log_level.upper()) or logging.INFO + else: + level = log_level + else: + # base logger needs to be the lowest level that we plan to log + level = logging.DEBUG if debug else logging.INFO + logger.setLevel(level) - stdout_handler = logging.StreamHandler(sys.stdout) - stdout_handler.setFormatter(CustomFormatter()) - stdout_handler.setLevel( - level if write_to_stdout else (logging.ERROR if ignore_warnings else logging.WARNING) - ) - logger.addHandler(stdout_handler) + if debug: + # Remove noisy snowflake connector logs that are not useful for users + logging.getLogger("snowflake.connector").setLevel(logging.INFO) + + if write_to_stdout: + stdout_handler = logging.StreamHandler(sys.stdout) + stdout_handler.setFormatter(CustomFormatter()) + stdout_handler.setLevel(logging.ERROR if ignore_warnings else level) + logger.addHandler(stdout_handler) log_file_dir = log_file_dir or c.DEFAULT_LOG_FILE_DIR log_path_prefix = Path(log_file_dir) / LOG_FILENAME_PREFIX + if write_to_file: os.makedirs(str(log_file_dir), exist_ok=True) filename = f"{log_path_prefix}{datetime.now().strftime('%Y_%m_%d_%H_%M_%S')}.log" file_handler = logging.FileHandler(filename, mode="w", encoding="utf-8") + # the log files should always log at least info so that users will always have # minimal info for debugging even if they specify "ignore_warnings" file_handler.setLevel(level) file_handler.setFormatter(logging.Formatter(LOG_FORMAT)) logger.addHandler(file_handler) - if log_limit > 0: - for path in list(sorted(glob.glob(f"{log_path_prefix}*.log"), reverse=True))[log_limit:]: - os.remove(path) - if debug: import faulthandler - import signal enable_debug_mode() # Enable threadumps. 
faulthandler.enable() + # Windows doesn't support register so we check for it here if hasattr(faulthandler, "register"): - faulthandler.register(signal.SIGUSR1.value) + from signal import SIGUSR1 + + faulthandler.register(SIGUSR1.value) diff --git a/sqlmesh/cli/__init__.py b/sqlmesh/cli/__init__.py index b0d59356f9..3b417eb478 100644 --- a/sqlmesh/cli/__init__.py +++ b/sqlmesh/cli/__init__.py @@ -4,10 +4,8 @@ import click from sqlglot.errors import SqlglotError - from sqlmesh.core.context import Context from sqlmesh.utils import debug_mode_enabled -from sqlmesh.utils.concurrency import NodeExecutionFailedError from sqlmesh.utils.errors import SQLMeshError DECORATOR_RETURN_TYPE = t.TypeVar("DECORATOR_RETURN_TYPE") @@ -38,11 +36,9 @@ def _default_exception_handler( ) -> DECORATOR_RETURN_TYPE: try: return func() - except NodeExecutionFailedError as ex: - cause = ex.__cause__ - raise click.ClickException(f"Failed processing {ex.node}. {cause}") except (SQLMeshError, SqlglotError, ValueError) as ex: - raise click.ClickException(str(ex)) + click.echo(click.style("Error: " + str(ex), fg="red")) + exit(1) finally: if context: context.close() diff --git a/sqlmesh/cli/example_project.py b/sqlmesh/cli/example_project.py deleted file mode 100644 index 990260aa1f..0000000000 --- a/sqlmesh/cli/example_project.py +++ /dev/null @@ -1,238 +0,0 @@ -import typing as t -from enum import Enum -from pathlib import Path - -import click -from sqlglot import Dialect -from sqlmesh.utils.date import yesterday_ds - - -class ProjectTemplate(Enum): - AIRFLOW = "airflow" - DBT = "dbt" - DEFAULT = "default" - EMPTY = "empty" - - -def _gen_config(dialect: t.Optional[str], template: ProjectTemplate) -> str: - default_configs = { - ProjectTemplate.DEFAULT: f"""gateways: - local: - connection: - type: duckdb - database: db.db - -default_gateway: local - -model_defaults: - dialect: {dialect} - start: {yesterday_ds()} -""", - ProjectTemplate.AIRFLOW: f"""gateways: - local: - connection: - type: duckdb 
- database: db.db - -default_gateway: local - -default_scheduler: - type: airflow - airflow_url: http://localhost:8080/ - username: airflow - password: airflow - -model_defaults: - dialect: {dialect} - start: {yesterday_ds()} -""", - ProjectTemplate.DBT: """from pathlib import Path - -from sqlmesh.dbt.loader import sqlmesh_config - -config = sqlmesh_config(Path(__file__).parent) -""", - } - - default_configs[ProjectTemplate.EMPTY] = default_configs[ProjectTemplate.DEFAULT] - return default_configs[template] - - -EXAMPLE_SCHEMA_NAME = "sqlmesh_example" -EXAMPLE_FULL_MODEL_NAME = f"{EXAMPLE_SCHEMA_NAME}.full_model" -EXAMPLE_INCREMENTAL_MODEL_NAME = f"{EXAMPLE_SCHEMA_NAME}.incremental_model" -EXAMPLE_SEED_MODEL_NAME = f"{EXAMPLE_SCHEMA_NAME}.seed_model" - -EXAMPLE_FULL_MODEL_DEF = f"""MODEL ( - name {EXAMPLE_FULL_MODEL_NAME}, - kind FULL, - cron '@daily', - grain item_id, - audits (assert_positive_order_ids), -); - -SELECT - item_id, - COUNT(DISTINCT id) AS num_orders, -FROM - {EXAMPLE_INCREMENTAL_MODEL_NAME} -GROUP BY item_id -""" - -EXAMPLE_INCREMENTAL_MODEL_DEF = f"""MODEL ( - name {EXAMPLE_INCREMENTAL_MODEL_NAME}, - kind INCREMENTAL_BY_TIME_RANGE ( - time_column event_date - ), - start '2020-01-01', - cron '@daily', - grain (id, event_date) -); - -SELECT - id, - item_id, - event_date, -FROM - {EXAMPLE_SEED_MODEL_NAME} -WHERE - event_date BETWEEN @start_date AND @end_date -""" - -EXAMPLE_SEED_MODEL_DEF = f"""MODEL ( - name {EXAMPLE_SEED_MODEL_NAME}, - kind SEED ( - path '../seeds/seed_data.csv' - ), - columns ( - id INTEGER, - item_id INTEGER, - event_date DATE - ), - grain (id, event_date) -); -""" - -EXAMPLE_AUDIT = """AUDIT ( - name assert_positive_order_ids, -); - -SELECT * -FROM @this_model -WHERE - item_id < 0 -""" - -EXAMPLE_SEED_DATA = """id,item_id,event_date -1,2,2020-01-01 -2,1,2020-01-01 -3,3,2020-01-03 -4,1,2020-01-04 -5,1,2020-01-05 -6,1,2020-01-06 -7,1,2020-01-07 -""" - -EXAMPLE_TEST = f"""test_example_full_model: - model: {EXAMPLE_FULL_MODEL_NAME} - 
inputs: - {EXAMPLE_INCREMENTAL_MODEL_NAME}: - rows: - - id: 1 - item_id: 1 - - id: 2 - item_id: 1 - - id: 3 - item_id: 2 - outputs: - query: - rows: - - item_id: 1 - num_orders: 2 - - item_id: 2 - num_orders: 1 -""" - - -def init_example_project( - path: t.Union[str, Path], - dialect: t.Optional[str], - template: ProjectTemplate = ProjectTemplate.DEFAULT, -) -> None: - root_path = Path(path) - config_extension = "py" if template == ProjectTemplate.DBT else "yaml" - config_path = root_path / f"config.{config_extension}" - audits_path = root_path / "audits" - macros_path = root_path / "macros" - models_path = root_path / "models" - seeds_path = root_path / "seeds" - tests_path = root_path / "tests" - - if config_path.exists(): - raise click.ClickException(f"Found an existing config in '{config_path}'") - - if not dialect and template != ProjectTemplate.DBT: - raise click.ClickException( - "Default SQL dialect is a required argument for SQLMesh projects" - ) - - _create_config(config_path, dialect, template) - if template == ProjectTemplate.DBT: - return - - _create_folders([audits_path, macros_path, models_path, seeds_path, tests_path]) - - if template != ProjectTemplate.EMPTY: - _create_macros(macros_path) - _create_audits(audits_path) - _create_models(models_path) - _create_seeds(seeds_path) - _create_tests(tests_path) - - -def _create_folders(target_folders: t.Sequence[Path]) -> None: - for folder_path in target_folders: - folder_path.mkdir(exist_ok=True) - (folder_path / ".gitkeep").touch() - - -def _create_config(config_path: Path, dialect: t.Optional[str], template: ProjectTemplate) -> None: - if dialect: - Dialect.get_or_raise(dialect) - - project_config = _gen_config(dialect, template) - - _write_file( - config_path, - project_config, - ) - - -def _create_macros(macros_path: Path) -> None: - (macros_path / "__init__.py").touch() - - -def _create_audits(audits_path: Path) -> None: - _write_file(audits_path / "assert_positive_order_ids.sql", EXAMPLE_AUDIT) - - 
-def _create_models(models_path: Path) -> None: - for model_name, model_def in [ - (EXAMPLE_FULL_MODEL_NAME, EXAMPLE_FULL_MODEL_DEF), - (EXAMPLE_INCREMENTAL_MODEL_NAME, EXAMPLE_INCREMENTAL_MODEL_DEF), - (EXAMPLE_SEED_MODEL_NAME, EXAMPLE_SEED_MODEL_DEF), - ]: - _write_file(models_path / f"{model_name.split('.')[-1]}.sql", model_def) - - -def _create_seeds(seeds_path: Path) -> None: - _write_file(seeds_path / "seed_data.csv", EXAMPLE_SEED_DATA) - - -def _create_tests(tests_path: Path) -> None: - _write_file(tests_path / "test_full_model.yaml", EXAMPLE_TEST) - - -def _write_file(path: Path, payload: str) -> None: - with open(path, "w", encoding="utf-8") as fd: - fd.write(payload) diff --git a/sqlmesh/cli/main.py b/sqlmesh/cli/main.py index 3d0ed10a4a..ec5acbea59 100644 --- a/sqlmesh/cli/main.py +++ b/sqlmesh/cli/main.py @@ -4,22 +4,43 @@ import os import sys import typing as t +from pathlib import Path import click -from sqlmesh import configure_logging +from sqlmesh import configure_logging, remove_excess_logs from sqlmesh.cli import error_handler from sqlmesh.cli import options as opt -from sqlmesh.cli.example_project import ProjectTemplate, init_example_project +from sqlmesh.cli.project_init import ( + InitCliMode, + ProjectTemplate, + init_example_project, + interactive_init, +) from sqlmesh.core.analytics import cli_analytics from sqlmesh.core.config import load_configs +from sqlmesh.core.console import configure_console, get_console from sqlmesh.core.context import Context +from sqlmesh.utils import Verbosity from sqlmesh.utils.date import TimeLike -from sqlmesh.utils.errors import MissingDependencyError +from sqlmesh.utils.errors import MissingDependencyError, SQLMeshError logger = logging.getLogger(__name__) -SKIP_LOAD_COMMANDS = ("create_external_models", "migrate", "rollback") + +SKIP_LOAD_COMMANDS = ( + "clean", + "create_external_models", + "destroy", + "environments", + "invalidate", + "janitor", + "migrate", + "rollback", + "run", + "table_name", +) 
+SKIP_CONTEXT_COMMANDS = ("init", "ui") def _sqlmesh_version() -> str: @@ -39,11 +60,13 @@ def _sqlmesh_version() -> str: "--gateway", type=str, help="The name of the gateway.", + envvar="SQLMESH_GATEWAY", ) @click.option( "--ignore-warnings", is_flag=True, help="Ignore warnings.", + envvar="SQLMESH_IGNORE_WARNINGS", ) @click.option( "--debug", @@ -60,6 +83,12 @@ def _sqlmesh_version() -> str: type=str, help="The directory to write log files to.", ) +@click.option( + "--dotenv", + type=click.Path(exists=True, path_type=Path), + help="Path to a custom .env file to load environment variables.", + envvar="SQLMESH_DOTENV_PATH", +) @click.pass_context @error_handler def cli( @@ -71,26 +100,34 @@ def cli( debug: bool = False, log_to_stdout: bool = False, log_file_dir: t.Optional[str] = None, + dotenv: t.Optional[Path] = None, ) -> None: """SQLMesh command line tool.""" if "--help" in sys.argv: return + configure_logging( + debug, + log_to_stdout, + log_file_dir=log_file_dir, + ignore_warnings=ignore_warnings, + ) + configure_console(ignore_warnings=ignore_warnings) + load = True if len(paths) == 1: path = os.path.abspath(paths[0]) - if ctx.invoked_subcommand in ("init", "ui"): + if ctx.invoked_subcommand in SKIP_CONTEXT_COMMANDS: ctx.obj = path return - elif ctx.invoked_subcommand in SKIP_LOAD_COMMANDS: + if ctx.invoked_subcommand in SKIP_LOAD_COMMANDS: load = False - configs = load_configs(config, Context.CONFIG_TYPE, paths) + configs = load_configs(config, Context.CONFIG_TYPE, paths, dotenv_path=dotenv) log_limit = list(configs.values())[0].log_limit - configure_logging( - debug, ignore_warnings, log_to_stdout, log_limit=log_limit, log_file_dir=log_file_dir - ) + + remove_excess_logs(log_file_dir, log_limit) try: context = Context( @@ -113,25 +150,104 @@ def cli( @cli.command("init") -@click.argument("sql_dialect", required=False) +@click.argument("engine", required=False) @click.option( "-t", "--template", type=str, - help="Project template. 
Supported values: airflow, dbt, default, empty.", + help="Project template. Supported values: dbt, dlt, default, empty.", +) +@click.option( + "--dlt-pipeline", + type=str, + help="DLT pipeline for which to generate a SQLMesh project. Use alongside template: dlt", +) +@click.option( + "--dlt-path", + type=str, + help="The directory where the DLT pipeline resides. Use alongside template: dlt", ) @click.pass_context @error_handler @cli_analytics def init( - ctx: click.Context, sql_dialect: t.Optional[str] = None, template: t.Optional[str] = None + ctx: click.Context, + engine: t.Optional[str] = None, + template: t.Optional[str] = None, + dlt_pipeline: t.Optional[str] = None, + dlt_path: t.Optional[str] = None, ) -> None: """Create a new SQLMesh repository.""" - try: - project_template = ProjectTemplate(template.lower() if template else "default") - except ValueError: - raise click.ClickException(f"Invalid project template '{template}'") - init_example_project(ctx.obj, dialect=sql_dialect, template=project_template) + project_template = None + if template: + try: + project_template = ProjectTemplate(template.lower()) + except ValueError: + template_strings = "', '".join([template.value for template in ProjectTemplate]) + raise click.ClickException( + f"Invalid project template '{template}'. Please specify one of '{template_strings}'." 
+ ) + + if engine or project_template == ProjectTemplate.DBT: + init_example_project( + path=ctx.obj, + template=project_template or ProjectTemplate.DEFAULT, + engine_type=engine, + pipeline=dlt_pipeline, + dlt_path=dlt_path, + ) + return + + import sqlmesh.utils.rich as srich + + console = srich.console + + project_template, engine_type, cli_mode = interactive_init(ctx.obj, console, project_template) + + config_path = init_example_project( + path=ctx.obj, + template=project_template, + engine_type=engine_type, + cli_mode=cli_mode or InitCliMode.DEFAULT, + pipeline=dlt_pipeline, + dlt_path=dlt_path, + ) + + engine_install_text = "" + if engine_type and engine_type not in ("duckdb", "motherduck"): + install_text = ( + "pyspark" if engine_type == "spark" else f"sqlmesh\\[{engine_type.replace('_', '')}]" + ) + engine_install_text = f'• Run command in CLI to install your SQL engine\'s Python dependencies: pip install "{install_text}"\n' + # interactive init does not support DLT template + next_step_text = { + ProjectTemplate.DEFAULT: f"{engine_install_text}• Update your gateway connection settings (e.g., username/password) in the project configuration file:\n {config_path}", + ProjectTemplate.DBT: "", + } + next_step_text[ProjectTemplate.EMPTY] = next_step_text[ProjectTemplate.DEFAULT] + + quickstart_text = { + ProjectTemplate.DEFAULT: "Quickstart guide:\nhttps://sqlmesh.readthedocs.io/en/stable/quickstart/cli/", + ProjectTemplate.DBT: "dbt guide:\nhttps://sqlmesh.readthedocs.io/en/stable/integrations/dbt/", + } + quickstart_text[ProjectTemplate.EMPTY] = quickstart_text[ProjectTemplate.DEFAULT] + + console.print(f"""────────────────────────────── + +Your SQLMesh project is ready! + +Next steps: +{next_step_text[project_template]} +• Run command in CLI: sqlmesh plan +• (Optional) Explain a plan: sqlmesh plan --explain + +{quickstart_text[project_template]} + +Need help? 
+• Docs: https://sqlmesh.readthedocs.io +• Slack: https://www.tobikodata.com/slack +• GitHub: https://github.com/SQLMesh/sqlmesh/issues +""") @cli.command("render") @@ -146,6 +262,7 @@ def init( help="The SQL dialect to render the query as.", ) @click.option("--no-format", is_flag=True, help="Disable fancy formatting of the query.") +@opt.format_options @click.pass_context @error_handler @cli_analytics @@ -158,8 +275,11 @@ def render( expand: t.Optional[t.Union[bool, t.Iterable[str]]] = None, dialect: t.Optional[str] = None, no_format: bool = False, + **format_kwargs: t.Any, ) -> None: """Render a model's query, optionally expanding referenced models.""" + model = ctx.obj.get_model(model, raise_if_missing=True) + rendered = ctx.obj.render( model, start=start, @@ -168,7 +288,17 @@ def render( expand=expand, ) - sql = rendered.sql(pretty=True, dialect=ctx.obj.config.dialect if dialect is None else dialect) + format_config = ctx.obj.config_for_node(model).format + format_kwargs = { + **format_config.generator_options, + **{k: v for k, v in format_kwargs.items() if v is not None}, + } + + sql = rendered.sql( + pretty=True, + dialect=ctx.obj.config.dialect if dialect is None else dialect, + **format_kwargs, + ) if no_format: print(sql) else: @@ -211,6 +341,7 @@ def evaluate( @cli.command("format") +@click.argument("paths", nargs=-1) @click.option( "-t", "--transpile", @@ -218,49 +349,33 @@ def evaluate( help="Transpile project models to the specified dialect.", ) @click.option( - "--append-newline", + "--check", is_flag=True, - help="Include a newline at the end of each file.", + help="Whether or not to check formatting (but not actually format anything).", default=None, ) @click.option( - "--normalize", + "--rewrite-casts/--no-rewrite-casts", is_flag=True, - help="Whether or not to normalize identifiers to lowercase.", + help="Rewrite casts to use the :: syntax.", default=None, ) @click.option( - "--pad", - type=int, - help="Determines the pad size in a formatted 
string.", -) -@click.option( - "--indent", - type=int, - help="Determines the indentation size in a formatted string.", -) -@click.option( - "--normalize-functions", - type=str, - help="Whether or not to normalize all function names. Possible values are: 'upper', 'lower'", -) -@click.option( - "--leading-comma", + "--append-newline", is_flag=True, - help="Determines whether or not the comma is leading or trailing in select expressions. Default is trailing.", + help="Include a newline at the end of each file.", default=None, ) -@click.option( - "--max-text-width", - type=int, - help="The max number of characters in a segment before creating new lines in pretty mode.", -) +@opt.format_options @click.pass_context @error_handler @cli_analytics -def format(ctx: click.Context, **kwargs: t.Any) -> None: +def format( + ctx: click.Context, paths: t.Optional[t.Tuple[str, ...]] = None, **kwargs: t.Any +) -> None: """Format all SQL models and audits.""" - ctx.obj.format(**{k: v for k, v in kwargs.items() if v is not None}) + if not ctx.obj.format(**{k: v for k, v in kwargs.items() if v is not None}, paths=paths): + ctx.exit(1) @cli.command("diff") @@ -288,6 +403,13 @@ def diff(ctx: click.Context, environment: t.Optional[str] = None) -> None: "--skip-tests", is_flag=True, help="Skip tests prior to generating the plan if they are defined.", + default=None, +) +@click.option( + "--skip-linter", + is_flag=True, + help="Skip linting prior to generating the plan if the linter is enabled.", + default=None, ) @click.option( "--restate-model", @@ -300,11 +422,20 @@ def diff(ctx: click.Context, environment: t.Optional[str] = None) -> None: "--no-gaps", is_flag=True, help="Ensure that new snapshots have no data gaps when comparing to existing snapshots for matching models in the target environment.", + default=None, ) @click.option( "--skip-backfill", + "--dry-run", is_flag=True, - help="Skip the backfill step.", + help="Skip the backfill step and only create a virtual update for the 
plan.", + default=None, +) +@click.option( + "--empty-backfill", + is_flag=True, + help="Produce empty backfill. Like --skip-backfill no models will be backfilled, unlike --skip-backfill missing intervals will be recorded as if they were backfilled.", + default=None, ) @click.option( "--forward-only", @@ -318,6 +449,12 @@ def diff(ctx: click.Context, environment: t.Optional[str] = None) -> None: multiple=True, help="Allow destructive forward-only changes to models whose names match the expression.", ) +@click.option( + "--allow-additive-model", + type=str, + multiple=True, + help="Allow additive forward-only changes to models whose names match the expression.", +) @click.option( "--effective-from", type=str, @@ -358,7 +495,7 @@ def diff(ctx: click.Context, environment: t.Optional[str] = None) -> None: "--backfill-model", type=str, multiple=True, - help="Backfill only the models whose names match the expression. This is supported only when targeting a development environment.", + help="Backfill only the models whose names match the expression.", ) @click.option( "--no-diff", @@ -370,6 +507,7 @@ def diff(ctx: click.Context, environment: t.Optional[str] = None) -> None: "--run", is_flag=True, help="Run latest intervals as part of the plan application (prod environment only).", + default=None, ) @click.option( "--enable-preview", @@ -377,26 +515,57 @@ def diff(ctx: click.Context, environment: t.Optional[str] = None) -> None: help="Enable preview for forward-only models when targeting a development environment.", default=None, ) +@click.option( + "--diff-rendered", + is_flag=True, + help="Output text differences for the rendered versions of the models and standalone audits.", + default=None, +) +@click.option( + "--explain", + is_flag=True, + help="Explain the plan instead of applying it.", + default=None, +) +@click.option( + "--ignore-cron", + is_flag=True, + help="Run all missing intervals, ignoring individual cron schedules. 
Only applies if --run is set.", + default=None, +) +@click.option( + "--min-intervals", + default=None, + help="For every model, ensure at least this many intervals are covered by a missing intervals check regardless of the plan start date", +) @opt.verbose @click.pass_context @error_handler @cli_analytics def plan( - ctx: click.Context, verbose: bool, environment: t.Optional[str] = None, **kwargs: t.Any + ctx: click.Context, + verbose: int, + environment: t.Optional[str] = None, + **kwargs: t.Any, ) -> None: """Apply local changes to the target environment.""" context = ctx.obj restate_models = kwargs.pop("restate_model") or None select_models = kwargs.pop("select_model") or None allow_destructive_models = kwargs.pop("allow_destructive_model") or None + allow_additive_models = kwargs.pop("allow_additive_model") or None backfill_models = kwargs.pop("backfill_model") or None - context.console.verbose = verbose + ignore_cron = kwargs.pop("ignore_cron") or None + setattr(get_console(), "verbosity", Verbosity(verbose)) + context.plan( environment, restate_models=restate_models, select_models=select_models, allow_destructive_models=allow_destructive_models, + allow_additive_models=allow_additive_models, backfill_models=backfill_models, + ignore_cron=ignore_cron, **kwargs, ) @@ -411,15 +580,32 @@ def plan( is_flag=True, help="Run for all missing intervals, ignoring individual cron schedules.", ) +@click.option( + "--select-model", + type=str, + multiple=True, + help="Select specific models to run. Note: this always includes upstream dependencies.", +) +@click.option( + "--exit-on-env-update", + type=int, + help="If set, the command will exit with the specified code if the run is interrupted by an update to the target environment.", +) +@click.option( + "--no-auto-upstream", + is_flag=True, + help="Do not automatically include upstream models. Only applicable when --select-model is used. 
Note: this may result in missing / invalid data for the selected models.", +) @click.pass_context @error_handler @cli_analytics def run(ctx: click.Context, environment: t.Optional[str] = None, **kwargs: t.Any) -> None: """Evaluate missing intervals for the target environment.""" context = ctx.obj - success = context.run(environment, **kwargs) - if not success: - raise click.ClickException("Run DAG Failed. See output for details.") + select_models = kwargs.pop("select_model") or None + completion_status = context.run(environment, select_models=select_models, **kwargs) + if completion_status.is_failure: + raise click.ClickException("Run failed.") @cli.command("invalidate") @@ -457,6 +643,19 @@ def janitor(ctx: click.Context, ignore_ttl: bool, **kwargs: t.Any) -> None: ctx.obj.run_janitor(ignore_ttl, **kwargs) +@cli.command("destroy") +@click.pass_context +@error_handler +@cli_analytics +def destroy(ctx: click.Context, **kwargs: t.Any) -> None: + """ + The destroy command removes all project resources. + + This includes engine-managed objects, state tables, the SQLMesh cache and any build artifacts. 
+ """ + ctx.obj.destroy(**kwargs) + + @cli.command("dag") @click.argument("file", required=True) @click.option( @@ -483,7 +682,7 @@ def dag(ctx: click.Context, file: str, select_model: t.List[str]) -> None: "queries", type=(str, str), multiple=True, - required=True, + default=[], help="Queries that will be used to generate data for the model's dependencies.", ) @click.option( @@ -566,7 +765,7 @@ def create_test( def test( obj: Context, k: t.List[str], - verbose: bool, + verbose: int, preserve_fixtures: bool, tests: t.List[str], ) -> None: @@ -574,7 +773,7 @@ def test( result = obj.test( match_patterns=k, tests=tests, - verbose=verbose, + verbosity=Verbosity(verbose), preserve_fixtures=preserve_fixtures, ) if not result.wasSuccessful(): @@ -602,7 +801,48 @@ def audit( execution_time: t.Optional[TimeLike] = None, ) -> None: """Run audits for the target model(s).""" - obj.audit(models=models, start=start, end=end, execution_time=execution_time) + if not obj.audit(models=models, start=start, end=end, execution_time=execution_time): + exit(1) + + +@cli.command("check_intervals") +@click.option( + "--no-signals", + is_flag=True, + help="Disable signal checks and only show missing intervals.", + default=False, +) +@click.argument("environment", required=False) +@click.option( + "--select-model", + type=str, + multiple=True, + help="Select specific models to show missing intervals for.", +) +@opt.start_time +@opt.end_time +@click.pass_context +@error_handler +@cli_analytics +def check_intervals( + ctx: click.Context, + environment: t.Optional[str], + no_signals: bool, + select_model: t.List[str], + start: TimeLike, + end: TimeLike, +) -> None: + """Show missing intervals in an environment, respecting signals.""" + context = ctx.obj + context.console.show_intervals( + context.check_intervals( + environment, + no_signals=no_signals, + select_models=select_model, + start=start, + end=end, + ) + ) @cli.command("fetchdf") @@ -622,16 +862,17 @@ def fetchdf(ctx: click.Context, 
sql: str) -> None: is_flag=True, help="Skip the connection test.", ) +@opt.verbose @click.pass_obj @error_handler @cli_analytics -def info(obj: Context, skip_connection: bool) -> None: +def info(obj: Context, skip_connection: bool, verbose: int) -> None: """ Print information about a SQLMesh project. Includes counts of project models and macros and connection tests for the data warehouse. """ - obj.print_info(skip_connection=skip_connection) + obj.print_info(skip_connection=skip_connection, verbosity=Verbosity(verbose)) @cli.command("ui") @@ -649,15 +890,22 @@ def info(obj: Context, skip_connection: bool) -> None: ) @click.option( "--mode", - type=click.Choice(["ide", "default", "docs", "plan"], case_sensitive=False), - default="default", - help="Mode to start the UI in. Default: default", + type=click.Choice(["ide", "catalog", "docs", "plan"], case_sensitive=False), + default="ide", + help="Mode to start the UI in. Default: ide", ) @click.pass_context @error_handler @cli_analytics def ui(ctx: click.Context, host: str, port: int, mode: str) -> None: """Start a browser-based SQLMesh UI.""" + from sqlmesh.core.console import get_console + + get_console().log_warning( + "The UI is deprecated and will be removed in a future version. Please use the SQLMesh VSCode extension instead. 
" + "Learn more at https://sqlmesh.readthedocs.io/en/stable/guides/vscode/" + ) + try: import uvicorn except ModuleNotFoundError as e: @@ -675,7 +923,7 @@ def ui(ctx: click.Context, host: str, port: int, mode: str) -> None: if gateway: os.environ["GATEWAY"] = gateway uvicorn.run( - "web.server.main:app", + "web.server.app:app", host=host, port=port, log_level="info", @@ -760,18 +1008,48 @@ def create_external_models(obj: Context, **kwargs: t.Any) -> None: is_flag=True, help="Disable the check for a primary key (grain) that is missing or is not unique.", ) +@click.option( + "--warn-grain-check", + is_flag=True, + help="Warn if any selected model is missing a grain, and compute diffs for the remaining models.", +) +@click.option( + "--temp-schema", + type=str, + help="Schema used for temporary tables. It can be `CATALOG.SCHEMA` or `SCHEMA`. Default: `sqlmesh_temp`", +) +@click.option( + "--select-model", + "-m", + type=str, + multiple=True, + help="Specify one or more models to data diff. Use wildcards to diff multiple models. Ex: '*' (all models with applied plan diffs), 'demo.model+' (this and downstream models), 'git:feature_branch' (models with direct modifications in this branch only)", +) +@click.option( + "--schema-diff-ignore-case", + is_flag=True, + help="If set, when performing a schema diff the case of column names is ignored when matching between the two schemas. 
For example, 'col_a' in the source schema and 'COL_A' in the target schema will be treated as the same column.", +) @click.pass_obj @error_handler @cli_analytics def table_diff( obj: Context, source_to_target: str, model: t.Optional[str], **kwargs: t.Any ) -> None: - """Show the diff between two tables.""" + """Show the diff between two tables or a selection of models when they are specified.""" source, target = source_to_target.split(":") + select_model = kwargs.pop("select_model", None) + + if model and select_model: + raise SQLMeshError( + "The --select-model option cannot be used together with a model argument. Please choose one of them." + ) + + select_models = {model} if model else select_model obj.table_diff( source=source, target=target, - model_or_snapshot=model, + select_models=select_models, **kwargs, ) @@ -801,66 +1079,186 @@ def rewrite(obj: Context, sql: str, read: str = "", write: str = "") -> None: ) -@cli.command("prompt") -@click.argument("prompt") +@cli.command("clean") +@click.pass_obj +@error_handler +@cli_analytics +def clean(obj: Context) -> None: + """Clears the SQLMesh cache and any build artifacts.""" + obj.clear_caches() + + +@cli.command("table_name") +@click.argument("model_name", required=True) +@click.option( + "--environment", + "--env", + help="The environment to source the model version from.", +) @click.option( - "-e", - "--evaluate", + "--prod", is_flag=True, - help="Evaluate the generated SQL query and display the results.", + default=False, + help="If set, return the name of the physical table that will be used in production for the model version promoted in the target environment.", ) +@click.pass_obj +@error_handler +@cli_analytics +def table_name( + obj: Context, + model_name: str, + environment: t.Optional[str] = None, + prod: bool = False, +) -> None: + """Prints the name of the physical table for the given model.""" + print(obj.table_name(model_name, environment, prod)) + + +@cli.command("dlt_refresh") 
+@click.argument("pipeline", required=True) @click.option( "-t", - "--temperature", - type=float, - help="Sampling temperature. 0.0 - precise and predictable, 0.5 - balanced, 1.0 - creative. Default: 0.7", - default=0.7, + "--table", + type=str, + multiple=True, + help="The specific dlt tables to refresh in the SQLMesh models.", +) +@click.option( + "-f", + "--force", + is_flag=True, + default=False, + help="If set, existing models are overwritten with the new DLT tables.", +) +@click.option( + "--dlt-path", + type=str, + help="The directory where the DLT pipeline resides.", ) -@opt.verbose @click.pass_context @error_handler @cli_analytics -def prompt( - ctx: click.Context, prompt: str, evaluate: bool, temperature: float, verbose: bool +def dlt_refresh( + ctx: click.Context, + pipeline: str, + force: bool, + table: t.List[str] = [], + dlt_path: t.Optional[str] = None, ) -> None: - """Uses LLM to generate a SQL query from a prompt.""" - from sqlmesh.integrations.llm import LLMIntegration + """Attaches to a DLT pipeline with the option to update specific or all missing tables in the SQLMesh project.""" + from sqlmesh.integrations.dlt import generate_dlt_models - context = ctx.obj + sqlmesh_models = generate_dlt_models(ctx.obj, pipeline, list(table or []), force, dlt_path) + if sqlmesh_models: + model_names = "\n".join([f"- {model_name}" for model_name in sqlmesh_models]) + ctx.obj.console.log_success(f"Updated SQLMesh project with models:\n{model_names}") + else: + ctx.obj.console.log_success("All SQLMesh models are up to date.") - llm_integration = LLMIntegration( - context.models.values(), - context.engine_adapter.dialect, - temperature=temperature, - verbose=verbose, - ) - query = llm_integration.query(prompt) - context.console.log_status_update(query) - if evaluate: - context.console.log_success(context.fetchdf(query)) +@cli.command("environments") +@click.pass_obj +@error_handler +@cli_analytics +def environments(obj: Context) -> None: + """Prints the list of 
SQLMesh environments with its expiry datetime.""" + obj.print_environment_names() -@cli.command("clean") +@cli.command("lint") +@click.option( + "--models", + "--model", + multiple=True, + help="A model to lint. Multiple models can be linted. If no models are specified, every model will be linted.", +) @click.pass_obj @error_handler @cli_analytics -def clean(obj: Context) -> None: - """Clears the SQLMesh cache and any build artifacts.""" - obj.clear_caches() +def lint( + obj: Context, + models: t.Iterator[str], +) -> None: + """Run the linter for the target model(s).""" + obj.lint_models(models) -@cli.command("table_name") -@click.argument("model_name", required=True) +@cli.group(no_args_is_help=True) +def state() -> None: + """Commands for interacting with state""" + pass + + +@state.command("export") @click.option( - "--dev", + "-o", + "--output-file", + required=True, + help="Path to write the state export to", + type=click.Path(dir_okay=False, writable=True, path_type=Path), +) +@click.option( + "--environment", + multiple=True, + help="Name of environment to export. Specify multiple --environment arguments to export multiple environments", +) +@click.option( + "--local", is_flag=True, - help="Print the name of the snapshot table used for previews in development environments.", - default=False, + help="Export local state only. 
Note that the resulting file will not be importable", +) +@click.option( + "--no-confirm", + is_flag=True, + help="Do not prompt for confirmation before exporting existing state", ) @click.pass_obj @error_handler @cli_analytics -def table_name(obj: Context, model_name: str, dev: bool) -> None: - """Prints the name of the physical table for the given model.""" - print(obj.table_name(model_name, dev)) +def state_export( + obj: Context, + output_file: Path, + environment: t.Optional[t.Tuple[str]], + local: bool, + no_confirm: bool, +) -> None: + """Export the state database to a file""" + confirm = not no_confirm + + if environment and local: + raise click.ClickException("Cannot specify both --environment and --local") + + environment_names = list(environment) if environment else None + obj.export_state( + output_file=output_file, + environment_names=environment_names, + local_only=local, + confirm=confirm, + ) + + +@state.command("import") +@click.option( + "-i", + "--input-file", + help="Path to the state file", + required=True, + type=click.Path(exists=True, dir_okay=False, readable=True, path_type=Path), +) +@click.option( + "--replace", + is_flag=True, + help="Clear the remote state before loading the file. 
If omitted, a merge is performed instead", +) +@click.option( + "--no-confirm", + is_flag=True, + help="Do not prompt for confirmation before updating existing state", +) +@click.pass_obj +@error_handler +@cli_analytics +def state_import(obj: Context, input_file: Path, replace: bool, no_confirm: bool) -> None: + """Import a state export file back into the state database""" + confirm = not no_confirm + obj.import_state(input_file=input_file, clear=replace, confirm=confirm) diff --git a/sqlmesh/cli/options.py b/sqlmesh/cli/options.py index 869cd46e19..2e4642eb0e 100644 --- a/sqlmesh/cli/options.py +++ b/sqlmesh/cli/options.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import typing as t import click @@ -51,6 +52,43 @@ verbose = click.option( "-v", "--verbose", - is_flag=True, - help="Verbose output.", + count=True, + help="Verbose output. Use -vv for very verbose output.", ) + + +def format_options(func: t.Callable) -> t.Callable: + """Decorator to add common format options to CLI commands.""" + func = click.option( + "--normalize", + is_flag=True, + help="Whether or not to normalize identifiers to lowercase.", + default=None, + )(func) + func = click.option( + "--pad", + type=int, + help="Determines the pad size in a formatted string.", + )(func) + func = click.option( + "--indent", + type=int, + help="Determines the indentation size in a formatted string.", + )(func) + func = click.option( + "--normalize-functions", + type=str, + help="Whether or not to normalize all function names. Possible values are: 'upper', 'lower'", + )(func) + func = click.option( + "--leading-comma", + is_flag=True, + default=None, + help="Determines whether or not the comma is leading or trailing in select expressions. 
Default is trailing.", + )(func) + func = click.option( + "--max-text-width", + type=int, + help="The max number of characters in a segment before creating new lines in pretty mode.", + )(func) + return func diff --git a/sqlmesh/cli/project_init.py b/sqlmesh/cli/project_init.py new file mode 100644 index 0000000000..e3132a6de3 --- /dev/null +++ b/sqlmesh/cli/project_init.py @@ -0,0 +1,525 @@ +import typing as t +from enum import Enum +from pathlib import Path +from dataclasses import dataclass +from rich.prompt import Prompt +from rich.console import Console +from sqlmesh.integrations.dlt import generate_dlt_models_and_settings +from sqlmesh.utils.date import yesterday_ds +from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.core.config.common import VirtualEnvironmentMode + +from sqlmesh.core.config.common import DBT_PROJECT_FILENAME +from sqlmesh.core.config.connection import ( + CONNECTION_CONFIG_TO_TYPE, + DIALECT_TO_TYPE, + INIT_DISPLAY_INFO_TO_TYPE, +) + + +PRIMITIVES = (str, int, bool, float) + + +class ProjectTemplate(Enum): + DEFAULT = "default" + DBT = "dbt" + EMPTY = "empty" + DLT = "dlt" + + +class InitCliMode(Enum): + DEFAULT = "default" + FLOW = "flow" + + +def _gen_config( + engine_type: t.Optional[str], + settings: t.Optional[str], + start: t.Optional[str], + template: ProjectTemplate, + cli_mode: InitCliMode, + dialect: t.Optional[str] = None, +) -> str: + project_dialect = dialect or DIALECT_TO_TYPE.get(engine_type) + + connection_settings = ( + settings + or """ type: duckdb + database: db.db""" + ) + + if not settings and template != ProjectTemplate.DBT: + doc_link = "https://sqlmesh.readthedocs.io/en/stable/integrations/engines{engine_link}" + engine_link = "" + + if engine_type in CONNECTION_CONFIG_TO_TYPE: + required_fields = [] + non_required_fields = [] + + for name, field in CONNECTION_CONFIG_TO_TYPE[engine_type].model_fields.items(): + field_name = field.alias or name + + default_value = field.get_default() + + if 
isinstance(default_value, Enum): + default_value = default_value.value + elif not isinstance(default_value, PRIMITIVES): + default_value = "" + + required = field.is_required() or field_name == "type" + option_str = f" {'# ' if not required else ''}{field_name}: {default_value}\n" + + # specify the DuckDB database field so quickstart runs out of the box + if engine_type == "duckdb" and field_name == "database": + option_str = " database: db.db\n" + required = True + + if required: + required_fields.append(option_str) + else: + non_required_fields.append(option_str) + + connection_settings = "".join(required_fields + non_required_fields) + + engine_link = f"/{engine_type}/#connection-options" + + connection_settings = ( + " # For more information on configuring the connection to your execution engine, visit:\n" + " # https://sqlmesh.readthedocs.io/en/stable/reference/configuration/#connection\n" + f" # {doc_link.format(engine_link=engine_link)}\n{connection_settings}" + ) + + default_configs = { + ProjectTemplate.DEFAULT: f"""# --- Gateway Connection --- +gateways: + {engine_type}: + connection: +{connection_settings} +default_gateway: {engine_type} + +# --- Model Defaults --- +# https://sqlmesh.readthedocs.io/en/stable/reference/model_configuration/#model-defaults + +model_defaults: + dialect: {project_dialect} + start: {start or yesterday_ds()} # Start date for backfill history + cron: '@daily' # Run models daily at 12am UTC (can override per model) + +# --- Linting Rules --- +# Enforce standards for your team +# https://sqlmesh.readthedocs.io/en/stable/guides/linter/ + +linter: + enabled: true + rules: + - ambiguousorinvalidcolumn + - invalidselectstarexpansion + - noambiguousprojections +""", + ProjectTemplate.DBT: f"""# --- DBT-specific options --- +dbt: + # This configuration ensures that each dbt target gets its own isolated state. 
+ # The inferred state schemas are named "sqlmesh_state__", eg "sqlmesh_state_jaffle_shop_dev" + # If this is undesirable, you may manually configure the gateway to use a specific state schema name + # https://sqlmesh.readthedocs.io/en/stable/integrations/dbt/#selecting-a-different-state-connection + infer_state_schema_name: True + +# --- Virtual Data Environment Mode --- +# Enable Virtual Data Environments (VDE) for *development* environments. +# Note that the production environment in dbt projects is not virtual by default to maintain compatibility with existing tooling. +# https://sqlmesh.readthedocs.io/en/stable/guides/configuration/#virtual-data-environment-modes +virtual_environment_mode: {VirtualEnvironmentMode.DEV_ONLY.lower()} + +# --- Plan Defaults --- +# https://sqlmesh.readthedocs.io/en/stable/reference/configuration/#plan +plan: + # For Virtual Data Environments, this ensures that any changes are always considered against prod, + # rather than the previous state of that environment + always_recreate_environment: True + +# --- Model Defaults --- +# https://sqlmesh.readthedocs.io/en/stable/reference/model_configuration/#model-defaults +model_defaults: + start: {start or yesterday_ds()} +""", + } + + default_configs[ProjectTemplate.EMPTY] = default_configs[ProjectTemplate.DEFAULT] + default_configs[ProjectTemplate.DLT] = default_configs[ProjectTemplate.DEFAULT] + + flow_cli_mode = """ +# FLOW: Minimal prompts, automatic changes, summary output +# https://sqlmesh.readthedocs.io/en/stable/reference/configuration/#plan + +plan: + no_diff: true # Hide detailed text differences for changed models + no_prompts: true # No interactive prompts + auto_apply: true # Apply changes automatically + +# --- Optional: Set a default target environment --- +# This is intended for local development to prevent users from accidentally applying plans to the prod environment. +# It is a development only config and should NOT be committed to your git repo. 
+# https://sqlmesh.readthedocs.io/en/stable/guides/configuration/#default-target-environment + +# Uncomment the following line to use a default target environment derived from the logged in user's name. +# default_target_environment: dev_{{ user() }} + +# Example usage: +# sqlmesh plan # Automatically resolves to: sqlmesh plan dev_yourname +# sqlmesh plan prod # Specify `prod` to apply changes to production +""" + + return default_configs[template] + (flow_cli_mode if cli_mode == InitCliMode.FLOW else "") + + +@dataclass +class ExampleObjects: + sql_models: t.Dict[str, str] + python_models: t.Dict[str, str] + seeds: t.Dict[str, str] + audits: t.Dict[str, str] + tests: t.Dict[str, str] + sql_macros: t.Dict[str, str] + python_macros: t.Dict[str, str] + + +def _gen_example_objects(schema_name: str) -> ExampleObjects: + sql_models: t.Dict[str, str] = {} + python_models: t.Dict[str, str] = {} + seeds: t.Dict[str, str] = {} + audits: t.Dict[str, str] = {} + tests: t.Dict[str, str] = {} + sql_macros: t.Dict[str, str] = {} + python_macros: t.Dict[str, str] = {"__init__": ""} + + full_model_name = f"{schema_name}.full_model" + incremental_model_name = f"{schema_name}.incremental_model" + seed_model_name = f"{schema_name}.seed_model" + + sql_models[full_model_name] = f"""MODEL ( + name {full_model_name}, + kind FULL, + cron '@daily', + grain item_id, + audits (assert_positive_order_ids), +); + +SELECT + item_id, + COUNT(DISTINCT id) AS num_orders, +FROM + {incremental_model_name} +GROUP BY item_id + """ + + sql_models[incremental_model_name] = f"""MODEL ( + name {incremental_model_name}, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column event_date + ), + start '2020-01-01', + cron '@daily', + grain (id, event_date) +); + +SELECT + id, + item_id, + event_date, +FROM + {seed_model_name} +WHERE + event_date BETWEEN @start_date AND @end_date + """ + + sql_models[seed_model_name] = f"""MODEL ( + name {seed_model_name}, + kind SEED ( + path '../seeds/seed_data.csv' + ), + columns ( 
+ id INTEGER, + item_id INTEGER, + event_date DATE + ), + grain (id, event_date) +); + """ + + seeds["seed_data"] = """id,item_id,event_date +1,2,2020-01-01 +2,1,2020-01-01 +3,3,2020-01-03 +4,1,2020-01-04 +5,1,2020-01-05 +6,1,2020-01-06 +7,1,2020-01-07 +""" + + audits["assert_positive_order_ids"] = """AUDIT ( + name assert_positive_order_ids, +); + +SELECT * +FROM @this_model +WHERE + item_id < 0 + """ + + tests["test_full_model"] = f"""test_example_full_model: + model: {full_model_name} + inputs: + {incremental_model_name}: + rows: + - id: 1 + item_id: 1 + - id: 2 + item_id: 1 + - id: 3 + item_id: 2 + outputs: + query: + rows: + - item_id: 1 + num_orders: 2 + - item_id: 2 + num_orders: 1 + """ + + return ExampleObjects( + sql_models=sql_models, + python_models=python_models, + seeds=seeds, + audits=audits, + tests=tests, + python_macros=python_macros, + sql_macros=sql_macros, + ) + + +def init_example_project( + path: t.Union[str, Path], + engine_type: t.Optional[str], + dialect: t.Optional[str] = None, + template: ProjectTemplate = ProjectTemplate.DEFAULT, + pipeline: t.Optional[str] = None, + dlt_path: t.Optional[str] = None, + schema_name: str = "sqlmesh_example", + cli_mode: InitCliMode = InitCliMode.DEFAULT, + start: t.Optional[str] = None, +) -> Path: + root_path = Path(path) + + config_path = root_path / "config.yaml" + if template == ProjectTemplate.DBT: + # name the config file `sqlmesh.yaml` to make it clear that within the context of all + # the existing yaml files DBT project, this one specifically relates to configuring the sqlmesh engine + config_path = root_path / "sqlmesh.yaml" + + audits_path = root_path / "audits" + macros_path = root_path / "macros" + models_path = root_path / "models" + seeds_path = root_path / "seeds" + tests_path = root_path / "tests" + + if config_path.exists(): + raise SQLMeshError( + f"Found an existing config file '{config_path}'.\n\nPlease change to another directory or remove the existing file." 
+ ) + + if template == ProjectTemplate.DBT and not Path(root_path, DBT_PROJECT_FILENAME).exists(): + raise SQLMeshError( + "Required dbt project file 'dbt_project.yml' not found in the current directory.\n\nPlease add it or change directories before running `sqlmesh init` to set up your project." + ) + + engine_types = "', '".join(CONNECTION_CONFIG_TO_TYPE) + if engine_type is None and template != ProjectTemplate.DBT: + raise SQLMeshError( + f"Missing `engine` argument to `sqlmesh init` - please specify a SQL engine for your project. Options: '{engine_types}'." + ) + + if engine_type and engine_type not in CONNECTION_CONFIG_TO_TYPE: + raise SQLMeshError( + f"Invalid engine '{engine_type}'. Please specify one of '{engine_types}'." + ) + + models: t.Set[t.Tuple[str, str]] = set() + settings = None + if engine_type and template == ProjectTemplate.DLT: + project_dialect = dialect or DIALECT_TO_TYPE.get(engine_type) + if pipeline and project_dialect: + dlt_models, settings, start = generate_dlt_models_and_settings( + pipeline_name=pipeline, dialect=project_dialect, dlt_path=dlt_path + ) + else: + raise SQLMeshError( + "Please provide a DLT pipeline with the `--dlt-pipeline` flag to generate a SQLMesh project from DLT." 
+ ) + + _create_config(config_path, engine_type, dialect, settings, start, template, cli_mode) + if template == ProjectTemplate.DBT: + return config_path + + _create_folders([audits_path, macros_path, models_path, seeds_path, tests_path]) + + if template == ProjectTemplate.DLT: + _create_object_files( + models_path, {model[0].split(".")[-1]: model[1] for model in dlt_models}, "sql" + ) + return config_path + + example_objects = _gen_example_objects(schema_name=schema_name) + + if template != ProjectTemplate.EMPTY: + _create_object_files(models_path, example_objects.sql_models, "sql") + _create_object_files(models_path, example_objects.python_models, "py") + _create_object_files(seeds_path, example_objects.seeds, "csv") + _create_object_files(audits_path, example_objects.audits, "sql") + _create_object_files(tests_path, example_objects.tests, "yaml") + _create_object_files(macros_path, example_objects.python_macros, "py") + _create_object_files(macros_path, example_objects.sql_macros, "sql") + + return config_path + + +def _create_folders(target_folders: t.Sequence[Path]) -> None: + for folder_path in target_folders: + folder_path.mkdir(exist_ok=True) + (folder_path / ".gitkeep").touch() + + +def _create_config( + config_path: Path, + engine_type: t.Optional[str], + dialect: t.Optional[str], + settings: t.Optional[str], + start: t.Optional[str], + template: ProjectTemplate, + cli_mode: InitCliMode, +) -> None: + project_config = _gen_config(engine_type, settings, start, template, cli_mode, dialect) + + _write_file( + config_path, + project_config, + ) + + +def _create_object_files(path: Path, object_dict: t.Dict[str, str], file_extension: str) -> None: + for object_name, object_def in object_dict.items(): + # file name is table component of catalog.schema.table + _write_file(path / f"{object_name.split('.')[-1]}.{file_extension}", object_def) + + +def _write_file(path: Path, payload: str) -> None: + with open(path, "w", encoding="utf-8") as fd: + fd.write(payload) + 
+ +def interactive_init( + path: Path, + console: Console, + project_template: t.Optional[ProjectTemplate] = None, +) -> t.Tuple[ProjectTemplate, t.Optional[str], t.Optional[InitCliMode]]: + console.print("──────────────────────────────") + console.print("Welcome to SQLMesh!") + + project_template = _init_template_prompt(console) if not project_template else project_template + + if project_template == ProjectTemplate.DBT: + return (project_template, None, None) + + engine_type = _init_engine_prompt(console) + cli_mode = _init_cli_mode_prompt(console) + + return (project_template, engine_type, cli_mode) + + +def _init_integer_prompt( + console: Console, err_msg_entity: str, num_options: int, retry_func: t.Callable[[t.Any], t.Any] +) -> int: + err_msg = "\nERROR: '{option_str}' is not a valid {err_msg_entity} number - please enter a number between 1 and {num_options} or exit with control+c\n" + while True: + option_str = Prompt.ask("Enter a number", console=console) + + value_error = False + try: + option_num = int(option_str) + except ValueError: + value_error = True + + if value_error or option_num < 1 or option_num > num_options: + console.print( + err_msg.format( + option_str=option_str, err_msg_entity=err_msg_entity, num_options=num_options + ), + style="red", + ) + continue + console.print("") + return option_num + + +def _init_display_choices(values_dict: t.Dict[str, str], console: Console) -> t.Dict[int, str]: + display_num_to_value = {} + for i, value_str in enumerate(values_dict.keys()): + console.print(f" \\[{i + 1}] {' ' if i < 9 else ''}{value_str} {values_dict[value_str]}") + display_num_to_value[i + 1] = value_str + console.print("") + return display_num_to_value + + +def _init_template_prompt(console: Console) -> ProjectTemplate: + console.print("──────────────────────────────\n") + console.print("What type of project do you want to set up?\n") + + # These are ordered for user display - do not reorder + template_descriptions = { + 
ProjectTemplate.DEFAULT.name: "- Create SQLMesh example project models and files", + ProjectTemplate.DBT.value: " - You have an existing dbt project and want to run it with SQLMesh", + ProjectTemplate.EMPTY.name: " - Create a SQLMesh configuration file and project directories only", + } + + display_num_to_template = _init_display_choices(template_descriptions, console) + + template_num = _init_integer_prompt( + console, "project type", len(template_descriptions), _init_template_prompt + ) + + return ProjectTemplate(display_num_to_template[template_num].lower()) + + +def _init_engine_prompt(console: Console) -> str: + console.print("──────────────────────────────\n") + console.print("Choose your SQL engine:\n") + + # INIT_DISPLAY_INFO_TO_TYPE is a dict of {engine_type: (display_order, display_name)} + DISPLAY_NAME_TO_TYPE = {v[1]: k for k, v in INIT_DISPLAY_INFO_TO_TYPE.items()} + ordered_engine_display_names = { + info[1]: "" for info in sorted(INIT_DISPLAY_INFO_TO_TYPE.values(), key=lambda x: x[0]) + } + display_num_to_display_name = _init_display_choices(ordered_engine_display_names, console) + + engine_num = _init_integer_prompt( + console, "engine", len(ordered_engine_display_names), _init_engine_prompt + ) + + return DISPLAY_NAME_TO_TYPE[display_num_to_display_name[engine_num]] + + +def _init_cli_mode_prompt(console: Console) -> InitCliMode: + console.print("──────────────────────────────\n") + console.print("Choose your SQLMesh CLI experience:\n") + + cli_mode_descriptions = { + InitCliMode.DEFAULT.name: "- See and control every detail", + InitCliMode.FLOW.name: " - Automatically run changes and show summary output", + } + + display_num_to_cli_mode = _init_display_choices(cli_mode_descriptions, console) + + cli_mode_num = _init_integer_prompt( + console, "config", len(cli_mode_descriptions), _init_cli_mode_prompt + ) + + return InitCliMode(display_num_to_cli_mode[cli_mode_num].lower()) diff --git a/sqlmesh/core/_typing.py b/sqlmesh/core/_typing.py index 
197d07bf2d..8e28312c1a 100644 --- a/sqlmesh/core/_typing.py +++ b/sqlmesh/core/_typing.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys import typing as t from sqlglot import exp @@ -9,3 +10,9 @@ SchemaName = t.Union[str, exp.Table] SessionProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]] CustomMaterializationProperties = t.Dict[str, t.Union[exp.Expression, str, int, float, bool]] + + +if sys.version_info >= (3, 11): + from typing import Self as Self +else: + from typing_extensions import Self as Self diff --git a/sqlmesh/core/analytics/__init__.py b/sqlmesh/core/analytics/__init__.py index 1b87aae829..fcf6d52064 100644 --- a/sqlmesh/core/analytics/__init__.py +++ b/sqlmesh/core/analytics/__init__.py @@ -9,6 +9,17 @@ from sqlmesh.core.analytics.dispatcher import AsyncEventDispatcher, NoopEventDispatcher from sqlmesh.utils import str_to_bool +if t.TYPE_CHECKING: + import sys + + if sys.version_info >= (3, 10): + from typing import ParamSpec + else: + from typing_extensions import ParamSpec + + _P = ParamSpec("_P") + _T = t.TypeVar("_T") + def init_collector() -> AnalyticsCollector: dispatcher = ( @@ -31,9 +42,9 @@ def disable_analytics() -> None: collector = AnalyticsCollector(dispatcher=NoopEventDispatcher()) -def cli_analytics(func: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: +def cli_analytics(func: t.Callable[_P, _T]) -> t.Callable[_P, _T]: @wraps(func) - def wrapper(*args: t.List[t.Any], **kwargs: t.Any) -> t.Any: + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: import click from click.core import ParameterSource @@ -73,9 +84,9 @@ def wrapper(*args: t.List[t.Any], **kwargs: t.Any) -> t.Any: return wrapper -def python_api_analytics(func: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: +def python_api_analytics(func: t.Callable[_P, _T]) -> t.Callable[_P, _T]: @wraps(func) - def wrapper(*args: t.List[t.Any], **kwargs: t.Any) -> t.Any: + def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T: import inspect 
from sqlmesh import magics diff --git a/sqlmesh/core/analytics/collector.py b/sqlmesh/core/analytics/collector.py index fdb2b667eb..cfdb60aadd 100644 --- a/sqlmesh/core/analytics/collector.py +++ b/sqlmesh/core/analytics/collector.py @@ -17,7 +17,7 @@ if t.TYPE_CHECKING: from sqlmesh.cicd.config import CICDBotConfig - from sqlmesh.core.plan import Plan + from sqlmesh.core.plan import EvaluatablePlan from sqlmesh.core.snapshot import Snapshot @@ -136,7 +136,7 @@ def on_project_loaded( "project_name_hash": _anonymize(project_name), } - if project_type == c.DBT: + if project_type in {c.DBT, c.HYBRID}: from dbt.version import __version__ as dbt_version event_data["dbt_version"] = dbt_version @@ -146,7 +146,7 @@ def on_project_loaded( def on_plan_apply_start( self, *, - plan: Plan, + plan: EvaluatablePlan, engine_type: t.Optional[str], state_sync_type: t.Optional[str], scheduler_type: str, @@ -157,7 +157,7 @@ def on_plan_apply_start( plan: The plan that is being applied. engine_type: The type of the target engine. state_sync_type: The type of the engine used to store the SQLMesh state. - scheduler_type: The type of the scheduler being used. Eg. "builtin" or "airflow". + scheduler_type: The type of the scheduler being used. Eg. "builtin". 
""" self._add_event( "PLAN_APPLY_START", @@ -172,11 +172,15 @@ def on_plan_apply_start( "forward_only": plan.forward_only, "ensure_finalized_snapshots": plan.ensure_finalized_snapshots, "has_restatements": bool(plan.restatements), - "directly_modified_count": len(plan.directly_modified), + "directly_modified_count": len(plan.directly_modified_snapshots), "indirectly_modified_count": len( - {s_id for s_ids in plan.indirectly_modified.values() for s_id in s_ids} + { + s_id + for s_ids in plan.indirectly_modified_snapshots.values() + for s_id in s_ids + } ), - "environment_name_hash": _anonymize(plan.environment_naming_info.name), + "environment_name_hash": _anonymize(plan.environment.name), }, ) @@ -248,11 +252,15 @@ def on_run_start(self, *, engine_type: str, state_sync_type: str) -> str: ) return run_id - def on_run_end(self, *, run_id: str, succeeded: bool, error: t.Optional[t.Any] = None) -> None: + def on_run_end( + self, *, run_id: str, succeeded: bool, interrupted: bool, error: t.Optional[t.Any] = None + ) -> None: """Called after a run ends. Args: run_id: The ID of the run. + succeeded: Whether the run succeeded. + interrupted: Whether the run was interrupted. error: The error that occurred during the run, if any. 
""" self._add_event( @@ -260,6 +268,7 @@ def on_run_end(self, *, run_id: str, succeeded: bool, error: t.Optional[t.Any] = { "run_id": run_id, "succeeded": succeeded, + "interrupted": interrupted, "error": type(error).__name__ if error else None, }, ) diff --git a/sqlmesh/core/audit/__init__.py b/sqlmesh/core/audit/__init__.py index 4fd74c2e49..65f77a8eca 100644 --- a/sqlmesh/core/audit/__init__.py +++ b/sqlmesh/core/audit/__init__.py @@ -1,29 +1,7 @@ -import inspect -import typing as t -from types import ModuleType - -from sqlmesh.core.audit import builtin from sqlmesh.core.audit.definition import ( Audit as Audit, - AuditResult as AuditResult, ModelAudit as ModelAudit, StandaloneAudit as StandaloneAudit, load_audit as load_audit, load_multiple_audits as load_multiple_audits, ) - - -def create_non_blocking_copy(audit: Audit) -> Audit: - return audit.copy(update={"name": f"{audit.name}_non_blocking", "blocking": False}) - - -def _discover_audits(modules: t.Iterable[ModuleType]) -> t.Dict[str, Audit]: - return { - audit.name: audit - for module in modules - for _, model_audit in inspect.getmembers(module, lambda v: isinstance(v, ModelAudit)) - for audit in (model_audit, create_non_blocking_copy(model_audit)) - } - - -BUILT_IN_AUDITS = _discover_audits([builtin]) diff --git a/sqlmesh/core/audit/builtin.py b/sqlmesh/core/audit/builtin.py index dbaf59f19b..b4eaab8d50 100644 --- a/sqlmesh/core/audit/builtin.py +++ b/sqlmesh/core/audit/builtin.py @@ -1,8 +1,16 @@ from __future__ import annotations +import inspect +import sys + from sqlglot import exp -from sqlmesh.core.audit.definition import ModelAudit +from sqlmesh.core.audit.definition import Audit, ModelAudit + + +def create_non_blocking_copy(audit: Audit) -> Audit: + return audit.copy(update={"name": f"{audit.name}_non_blocking", "blocking": False}) + # not_null(columns=(column_1, column_2)) not_null_audit = ModelAudit( @@ -341,7 +349,7 @@ @patterns, c -> NOT REGEXP_LIKE(@column, c) ), - (l, r) -> l OR r + (l, r) -> 
l AND r ), @condition, ) @@ -375,14 +383,15 @@ query=""" SELECT * FROM @this_model -WHERE @condition AND ( +WHERE @AND( @REDUCE( @EACH( @patterns, c -> NOT @column LIKE c ), - (l, r) -> l OR r - ) + (l, r) -> l AND r + ), + @condition, ) """, ) @@ -407,7 +416,7 @@ """, ) -# z_score_audit(column=column_name, threshold=3) +# z_score(column=column_name, threshold=3) z_score_audit = ModelAudit( name="z_score", defaults={"condition": exp.true()}, @@ -427,7 +436,7 @@ """, ) -# string_length_between_audit(column=column_name, max_v=22) +# string_length_between(column=column_name, max_v=22) string_length_between_audit = ModelAudit( name="string_length_between", defaults={ @@ -452,7 +461,7 @@ """, ) -# string_length_equal_audit(column=column_name, v=22) +# string_length_equal(column=column_name, v=22) string_length_equal_audit = ModelAudit( name="string_length_equal", defaults={"condition": exp.true()}, @@ -764,3 +773,12 @@ # ) AS tgt ON src.cnt >= tgt.cnt # """, # ) + + +BUILT_IN_AUDITS = { + audit.name: audit + for _, model_audit in inspect.getmembers( + sys.modules[__name__], lambda v: isinstance(v, ModelAudit) + ) + for audit in (model_audit, create_non_blocking_copy(model_audit)) +} diff --git a/sqlmesh/core/audit/definition.py b/sqlmesh/core/audit/definition.py index aa98f6e7b7..9f470872fe 100644 --- a/sqlmesh/core/audit/definition.py +++ b/sqlmesh/core/audit/definition.py @@ -1,27 +1,25 @@ from __future__ import annotations import pathlib -import sys import typing as t from functools import cached_property from pathlib import Path from pydantic import Field from sqlglot import exp -from sqlglot.optimizer.qualify_columns import quote_identifiers from sqlglot.optimizer.simplify import gen -from sqlmesh.core import constants as c from sqlmesh.core import dialect as d from sqlmesh.core.macros import MacroRegistry, macro from sqlmesh.core.model.common import ( bool_validator, default_catalog_validator, depends_on_validator, - expression_validator, + sort_python_env, + 
sorted_python_env_payloads, ) -from sqlmesh.core.model.definition import _Model, _python_env, _single_value_or_tuple -from sqlmesh.core.node import _Node +from sqlmesh.core.model.common import make_python_env, single_value_or_tuple, ParsableSql +from sqlmesh.core.node import _Node, DbtInfoMixin, DbtNodeInfo from sqlmesh.core.renderer import QueryRenderer from sqlmesh.utils.date import TimeLike from sqlmesh.utils.errors import AuditConfigError, SQLMeshError, raise_config_error @@ -31,21 +29,12 @@ extract_macro_references_and_variables, ) from sqlmesh.utils.metaprogramming import Executable -from sqlmesh.utils.pydantic import ( - PydanticModel, - field_validator, - model_validator, - model_validator_v1_args, -) +from sqlmesh.utils.pydantic import PydanticModel, field_validator, model_validator if t.TYPE_CHECKING: + from sqlmesh.core._typing import Self from sqlmesh.core.snapshot import DeployabilityIndex, Node, Snapshot -if sys.version_info >= (3, 9): - from typing import Literal -else: - from typing_extensions import Literal - class AuditCommonMetaMixin: """ @@ -77,68 +66,32 @@ class AuditMixin(AuditCommonMetaMixin): jinja_macros: A registry of jinja macros to use when rendering the audit query. """ - query: t.Union[exp.Query, d.JinjaQuery] + query_: ParsableSql defaults: t.Dict[str, exp.Expression] - expressions_: t.Optional[t.List[exp.Expression]] + expressions_: t.Optional[t.List[ParsableSql]] jinja_macros: JinjaMacroRegistry + formatting: t.Optional[bool] - def render_query( - self, - snapshot_or_node: t.Union[Snapshot, _Node], - *, - start: t.Optional[TimeLike] = None, - end: t.Optional[TimeLike] = None, - execution_time: t.Optional[TimeLike] = None, - snapshots: t.Optional[t.Dict[str, Snapshot]] = None, - deployability_index: t.Optional[DeployabilityIndex] = None, - **kwargs: t.Any, - ) -> exp.Query: - """Renders the audit's query. - - Args: - snapshot_or_node: The snapshot or node which is being audited. - start: The start datetime to render. 
Defaults to epoch start. - end: The end datetime to render. Defaults to epoch start. - execution_time: The date/time time reference to use for execution time. - snapshots: All snapshots (by name) to use for mapping of physical locations. - audit_name: The name of audit if the query to render is for an audit. - deployability_index: Determines snapshots that are deployable in the context of this render. - kwargs: Additional kwargs to pass to the renderer. - - Returns: - The rendered expression. - """ - node = snapshot_or_node if isinstance(snapshot_or_node, _Node) else snapshot_or_node.node - query_renderer = self._create_query_renderer(node) - - rendered_query = query_renderer.render( - start=start, - end=end, - execution_time=execution_time, - snapshots=snapshots, - deployability_index=deployability_index, - **{**self.defaults, **kwargs}, # type: ignore - ) - - if rendered_query is None: - raise SQLMeshError( - f"Failed to render query for audit '{self.name}', node '{node.name}'." - ) - - return rendered_query + @property + def query(self) -> t.Union[exp.Query, d.JinjaQuery]: + return t.cast(t.Union[exp.Query, d.JinjaQuery], self.query_.parse(self.dialect)) @property def expressions(self) -> t.List[exp.Expression]: - return self.expressions_ or [] + if not self.expressions_: + return [] + result = [] + for e in self.expressions_: + parsed = e.parse(self.dialect) + if not isinstance(parsed, exp.Semicolon): + result.append(parsed) + return result @property def macro_definitions(self) -> t.List[d.MacroDef]: """All macro definitions from the list of expressions.""" return [s for s in self.expressions if isinstance(s, d.MacroDef)] - def _create_query_renderer(self, node: _Node) -> QueryRenderer: - raise NotImplementedError - @field_validator("name", "dialect", mode="before", check_fields=False) def audit_string_validator(cls: t.Type, v: t.Any) -> t.Optional[str]: @@ -148,19 +101,26 @@ def audit_string_validator(cls: t.Type, v: t.Any) -> t.Optional[str]: 
@field_validator("defaults", mode="before", check_fields=False) -def audit_map_validator(cls: t.Type, v: t.Any) -> t.Dict[str, t.Any]: +def audit_map_validator(cls: t.Type, v: t.Any, values: t.Any) -> t.Dict[str, t.Any]: + from sqlmesh.utils.pydantic import get_dialect + + if isinstance(v, exp.Paren): + return dict([_maybe_parse_arg_pair(v.unnest())]) if isinstance(v, (exp.Tuple, exp.Array)): return dict(map(_maybe_parse_arg_pair, v.expressions)) - elif isinstance(v, dict): - return v - else: - raise_config_error( - "Defaults must be a tuple of exp.EQ or a dict", error_type=AuditConfigError - ) + if isinstance(v, dict): + dialect = get_dialect(values) + return { + key: value + if isinstance(value, exp.Expression) + else d.parse_one(str(value), dialect=dialect) + for key, value in v.items() + } + raise_config_error("Defaults must be a tuple of exp.EQ or a dict", error_type=AuditConfigError) return {} -class ModelAudit(PydanticModel, AuditMixin, frozen=True): +class ModelAudit(PydanticModel, AuditMixin, DbtInfoMixin, frozen=True): """ Audit is an assertion made about your tables. 
@@ -171,99 +131,30 @@ class ModelAudit(PydanticModel, AuditMixin, frozen=True): dialect: str = "" skip: bool = False blocking: bool = True - standalone: Literal[False] = False - query: t.Union[exp.Query, d.JinjaQuery] + standalone: t.Literal[False] = False + query_: ParsableSql = Field(alias="query") defaults: t.Dict[str, exp.Expression] = {} - expressions_: t.Optional[t.List[exp.Expression]] = Field(default=None, alias="expressions") + expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions") jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry() + formatting: t.Optional[bool] = Field(default=None, exclude=True) + dbt_node_info_: t.Optional[DbtNodeInfo] = Field(alias="dbt_node_info", default=None) _path: t.Optional[Path] = None # Validators - _query_validator = expression_validator + _query_validator = ParsableSql.validator() _bool_validator = bool_validator _string_validator = audit_string_validator _map_validator = audit_map_validator - def render_query( - self, - snapshot_or_node: t.Union[Snapshot, _Node], - *, - start: t.Optional[TimeLike] = None, - end: t.Optional[TimeLike] = None, - execution_time: t.Optional[TimeLike] = None, - snapshots: t.Optional[t.Dict[str, Snapshot]] = None, - deployability_index: t.Optional[DeployabilityIndex] = None, - **kwargs: t.Any, - ) -> exp.Query: - from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot - - deployability_index = deployability_index or DeployabilityIndex.all_deployable() - - extra_kwargs = {} - - node = snapshot_or_node if isinstance(snapshot_or_node, _Node) else snapshot_or_node.node - this_model = kwargs.pop("this_model", None) or ( - node.fqn - if isinstance(snapshot_or_node, _Node) - else t.cast(Snapshot, snapshot_or_node).table_name( - deployability_index.is_deployable(snapshot_or_node) - ) - ) - - columns_to_types: t.Optional[t.Dict[str, t.Any]] = None - if "engine_adapter" in kwargs: - try: - columns_to_types = kwargs["engine_adapter"].columns(this_model) - except 
Exception: - pass - - node = t.cast(_Model, node) - if node.time_column: - where = node.time_column.column.between( - node.convert_to_time_column(start or c.EPOCH, columns_to_types), - node.convert_to_time_column(end or c.EPOCH, columns_to_types), - ) - else: - where = None - - # The model's name is already normalized, but in case of snapshots we also prepend a - # case-sensitive physical schema name, so we quote here to ensure that we won't have - # a broken schema reference after the resulting query is normalized in `render`. - quoted_model_name = quote_identifiers( - exp.to_table(this_model, dialect=self.dialect), dialect=self.dialect - ) - extra_kwargs["this_model"] = ( - exp.select("*").from_(quoted_model_name).where(where).subquery() - ) - - return super().render_query( - snapshot_or_node, - start=start, - end=end, - execution_time=execution_time, - snapshots=snapshots, - deployability_index=deployability_index, - **{**extra_kwargs, **kwargs}, - ) - - def _create_query_renderer(self, node: _Node) -> QueryRenderer: - model = t.cast(_Model, node) - return QueryRenderer( - self.query, - self.dialect or model.dialect, - self.macro_definitions, - path=self._path or Path(), - jinja_macro_registry=self.jinja_macros, - python_env=model.python_env, - only_execution_time=model.kind.only_execution_time, - default_catalog=model.default_catalog, - ) - def __str__(self) -> str: path = f": {self._path.name}" if self._path else "" return f"{self.__class__.__name__}<{self.name}{path}>" + @property + def dbt_node_info(self) -> t.Optional[DbtNodeInfo]: + return self.dbt_node_info_ + class StandaloneAudit(_Node, AuditMixin): """ @@ -276,19 +167,20 @@ class StandaloneAudit(_Node, AuditMixin): dialect: str = "" skip: bool = False blocking: bool = False - standalone: Literal[True] = True - query: t.Union[exp.Query, d.JinjaQuery] + standalone: t.Literal[True] = True + query_: ParsableSql = Field(alias="query") defaults: t.Dict[str, exp.Expression] = {} - expressions_: 
t.Optional[t.List[exp.Expression]] = Field(default=None, alias="expressions") + expressions_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="expressions") jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry() default_catalog: t.Optional[str] = None depends_on_: t.Optional[t.Set[str]] = Field(default=None, alias="depends_on") - python_env_: t.Optional[t.Dict[str, Executable]] = Field(default=None, alias="python_env") + python_env: t.Dict[str, Executable] = {} + formatting: t.Optional[bool] = Field(default=None, exclude=True) - source_type: Literal["audit"] = "audit" + source_type: t.Literal["audit"] = "audit" # Validators - _query_validator = expression_validator + _query_validator = ParsableSql.validator() _bool_validator = bool_validator _string_validator = audit_string_validator _map_validator = audit_map_validator @@ -296,18 +188,63 @@ class StandaloneAudit(_Node, AuditMixin): _depends_on_validator = depends_on_validator @model_validator(mode="after") - @model_validator_v1_args - def _node_root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - if values.get("blocking"): - name = values.get("name") - raise AuditConfigError(f"Standalone audits cannot be blocking: '{name}'.") - return values + def _node_root_validator(self) -> Self: + if self.blocking: + raise AuditConfigError(f"Standalone audits cannot be blocking: '{self.name}'.") + return self + + def render_audit_query( + self, + *, + start: t.Optional[TimeLike] = None, + end: t.Optional[TimeLike] = None, + execution_time: t.Optional[TimeLike] = None, + snapshots: t.Optional[t.Dict[str, Snapshot]] = None, + deployability_index: t.Optional[DeployabilityIndex] = None, + **kwargs: t.Any, + ) -> exp.Query: + """Renders the audit's query. + + Args: + start: The start datetime to render. Defaults to epoch start. + end: The end datetime to render. Defaults to epoch start. + execution_time: The date/time time reference to use for execution time. 
+ snapshots: All snapshots (by name) to use for mapping of physical locations. + deployability_index: Determines snapshots that are deployable in the context of this render. + kwargs: Additional kwargs to pass to the renderer. + + Returns: + The rendered expression. + """ + query_renderer = QueryRenderer( + self.query, + self.dialect, + self.macro_definitions, + path=self._path or Path(), + jinja_macro_registry=self.jinja_macros, + python_env=self.python_env, + default_catalog=self.default_catalog, + ) + + rendered_query = query_renderer.render( + start=start, + end=end, + execution_time=execution_time, + snapshots=snapshots, + deployability_index=deployability_index, + **{**self.defaults, **kwargs}, # type: ignore + ) + + if rendered_query is None: + raise SQLMeshError(f"Failed to render query for audit '{self.name}'.") + + return rendered_query @cached_property def depends_on(self) -> t.Set[str]: depends_on = self.depends_on_ or set() - query = self.render_query(self) + query = self.render_audit_query() if query is not None: depends_on |= d.find_tables( query, default_catalog=self.default_catalog, dialect=self.dialect @@ -316,14 +253,10 @@ def depends_on(self) -> t.Set[str]: depends_on -= {self.name} return depends_on - @property - def python_env(self) -> t.Dict[str, Executable]: - return self.python_env_ or {} - @property def sorted_python_env(self) -> t.List[t.Tuple[str, Executable]]: """Returns the python env sorted by executable kind and then var name.""" - return sorted(self.python_env.items(), key=lambda x: (x[1].kind, x[0])) + return sort_python_env(self.python_env) @property def data_hash(self) -> str: @@ -336,7 +269,8 @@ def data_hash(self) -> str: # StandaloneAudits do not have a data hash return hash_data("") - def metadata_hash(self, audits: t.Dict[str, ModelAudit]) -> str: + @property + def metadata_hash(self) -> str: """ Computes the metadata hash for the node. 
@@ -346,24 +280,28 @@ def metadata_hash(self, audits: t.Dict[str, ModelAudit]) -> str: Returns: The metadata hash for the node. """ - data = [ - self.owner, - self.description, - *sorted(self.tags), - str(self.sorted_python_env), - self.stamp, - ] - - query = self.render_query(self) or self.query - data.append(gen(query)) - - return hash_data(data) - - def text_diff(self, other: Node) -> str: + if self._metadata_hash is None: + data = [ + self.owner, + self.description, + *sorted(self.tags), + str(self.sorted_python_env), + self.stamp, + self.cron, + self.cron_tz.key if self.cron_tz else None, + ] + + data.append(self.query_.sql) + data.extend([e.sql for e in self.expressions_ or []]) + self._metadata_hash = hash_data(data) + return self._metadata_hash + + def text_diff(self, other: Node, rendered: bool = False) -> str: """Produce a text diff against another node. Args: other: The node to diff against. + rendered: Whether the diff should be between raw vs rendered nodes Returns: A unified text diff showing additions and deletions. @@ -374,11 +312,17 @@ def text_diff(self, other: Node) -> str: ) return d.text_diff( - self.render_definition(), other.render_definition(), self.dialect, other.dialect + self.render_definition(render_query=rendered), + other.render_definition(render_query=rendered), + self.dialect, + other.dialect, ).strip() def render_definition( - self, include_python: bool = True, include_defaults: bool = False + self, + include_python: bool = True, + include_defaults: bool = False, + render_query: bool = False, ) -> t.List[exp.Expression]: """Returns the original list of sql expressions comprising the model definition. 
@@ -413,18 +357,19 @@ def render_definition( jinja_expressions = [] python_expressions = [] if include_python: - python_env = d.PythonCode( - expressions=[ - v.payload if v.is_import or v.is_definition else f"{k} = {v.payload}" - for k, v in self.sorted_python_env - ] - ) + python_env = d.PythonCode(expressions=sorted_python_env_payloads(self.python_env)) if python_env.expressions: python_expressions.append(python_env) jinja_expressions = self.jinja_macros.to_expressions() - return [audit, *python_expressions, *jinja_expressions, *self.expressions, self.query] + return [ + audit, + *python_expressions, + *jinja_expressions, + *self.expressions, + self.render_audit_query() if render_query else self.query, + ] @property def is_audit(self) -> bool: @@ -435,33 +380,14 @@ def is_audit(self) -> bool: def meta_fields(self) -> t.Iterable[str]: return set(AuditCommonMetaMixin.__annotations__) | set(_Node.all_field_infos()) - def _create_query_renderer(self, node: _Node) -> QueryRenderer: - return QueryRenderer( - self.query, - self.dialect, - self.macro_definitions, - path=self._path or Path(), - jinja_macro_registry=self.jinja_macros, - python_env=self.python_env, - default_catalog=self.default_catalog, - ) + @property + def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expression]]]: + return [(self, {})] Audit = t.Union[ModelAudit, StandaloneAudit] -class AuditResult(PydanticModel): - audit: Audit - """The audit this result is for.""" - model: t.Optional[_Model] = None - """The model this audit is for.""" - count: t.Optional[int] = None - """The number of records returned by the audit query. This could be None if the audit was skipped.""" - query: t.Optional[exp.Expression] = None - """The rendered query used by the audit. 
This could be None if the audit was skipped.""" - skipped: bool = False - - def load_audit( expressions: t.List[exp.Expression], *, @@ -472,6 +398,7 @@ def load_audit( dialect: t.Optional[str] = None, default_catalog: t.Optional[str] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, + project: t.Optional[str] = None, ) -> Audit: """Load an audit from a parsed SQLMesh audit file. @@ -526,30 +453,40 @@ def load_audit( extra_kwargs: t.Dict[str, t.Any] = {} if is_standalone: - jinja_macro_refrences, used_variables = extract_macro_references_and_variables( + jinja_macro_refrences, referenced_variables = extract_macro_references_and_variables( *(gen(s) for s in statements), gen(query), ) jinja_macros = (jinja_macros or JinjaMacroRegistry()).trim(jinja_macro_refrences) for jinja_macro in jinja_macros.root_macros.values(): - used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1]) + referenced_variables.update( + extract_macro_references_and_variables(jinja_macro.definition)[1] + ) extra_kwargs["jinja_macros"] = jinja_macros - extra_kwargs["python_env"] = _python_env( + extra_kwargs["python_env"] = make_python_env( [*statements, query], jinja_macro_refrences, module_path, macros or macro.get_registry(), variables=variables, - used_variables=used_variables, + referenced_variables=referenced_variables, ) extra_kwargs["default_catalog"] = default_catalog + if project is not None: + extra_kwargs["project"] = project + + dialect = meta_fields.pop("dialect", dialect) or "" + + parsable_query = ParsableSql.from_parsed_expression(query, dialect, use_meta_sql=True) + parsable_statements = [ + ParsableSql.from_parsed_expression(s, dialect, use_meta_sql=True) for s in statements + ] - dialect = meta_fields.pop("dialect", dialect) try: audit = audit_class( - query=query, - expressions=statements, + query=parsable_query, + expressions=parsable_statements, dialect=dialect, **extra_kwargs, **meta_fields, @@ -571,6 +508,7 @@ def 
load_multiple_audits( dialect: t.Optional[str] = None, default_catalog: t.Optional[str] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, + project: t.Optional[str] = None, ) -> t.Generator[Audit, None, None]: audit_block: t.List[exp.Expression] = [] for expression in expressions: @@ -585,6 +523,7 @@ def load_multiple_audits( dialect=dialect, default_catalog=default_catalog, variables=variables, + project=project, ) audit_block.clear() audit_block.append(expression) @@ -594,6 +533,7 @@ def load_multiple_audits( dialect=dialect, default_catalog=default_catalog, variables=variables, + project=project, ) @@ -615,6 +555,7 @@ def _maybe_parse_arg_pair(e: exp.Expression) -> t.Tuple[str, exp.Expression]: "blocking": exp.convert, "standalone": exp.convert, "depends_on_": lambda value: exp.Tuple(expressions=sorted(value)), - "tags": _single_value_or_tuple, + "tags": single_value_or_tuple, "default_catalog": exp.to_identifier, + "dbt_node_info_": lambda value: value.to_expression(), } diff --git a/sqlmesh/core/config/__init__.py b/sqlmesh/core/config/__init__.py index efdaa1c710..42ed82c6e6 100644 --- a/sqlmesh/core/config/__init__.py +++ b/sqlmesh/core/config/__init__.py @@ -2,12 +2,18 @@ AutoCategorizationMode as AutoCategorizationMode, CategorizerConfig as CategorizerConfig, ) -from sqlmesh.core.config.common import EnvironmentSuffixTarget as EnvironmentSuffixTarget +from sqlmesh.core.config.common import ( + EnvironmentSuffixTarget as EnvironmentSuffixTarget, + TableNamingConvention as TableNamingConvention, +) from sqlmesh.core.config.connection import ( + AthenaConnectionConfig as AthenaConnectionConfig, + BaseDuckDBConnectionConfig as BaseDuckDBConnectionConfig, BigQueryConnectionConfig as BigQueryConnectionConfig, ConnectionConfig as ConnectionConfig, DatabricksConnectionConfig as DatabricksConnectionConfig, DuckDBConnectionConfig as DuckDBConnectionConfig, + FabricConnectionConfig as FabricConnectionConfig, GCPPostgresConnectionConfig as 
GCPPostgresConnectionConfig, MotherDuckConnectionConfig as MotherDuckConnectionConfig, MSSQLConnectionConfig as MSSQLConnectionConfig, @@ -16,6 +22,7 @@ RedshiftConnectionConfig as RedshiftConnectionConfig, SnowflakeConnectionConfig as SnowflakeConnectionConfig, SparkConnectionConfig as SparkConnectionConfig, + TrinoConnectionConfig as TrinoConnectionConfig, parse_connection_config as parse_connection_config, ) from sqlmesh.core.config.gateway import GatewayConfig as GatewayConfig @@ -27,12 +34,8 @@ from sqlmesh.core.config.migration import MigrationConfig as MigrationConfig from sqlmesh.core.config.model import ModelDefaultsConfig as ModelDefaultsConfig from sqlmesh.core.config.naming import NameInferenceConfig as NameInferenceConfig +from sqlmesh.core.config.linter import LinterConfig as LinterConfig from sqlmesh.core.config.plan import PlanConfig as PlanConfig -from sqlmesh.core.config.root import Config as Config +from sqlmesh.core.config.root import Config as Config, DbtConfig as DbtConfig from sqlmesh.core.config.run import RunConfig as RunConfig -from sqlmesh.core.config.scheduler import ( - AirflowSchedulerConfig as AirflowSchedulerConfig, - BuiltInSchedulerConfig as BuiltInSchedulerConfig, - CloudComposerSchedulerConfig as CloudComposerSchedulerConfig, - MWAASchedulerConfig as MWAASchedulerConfig, -) +from sqlmesh.core.config.scheduler import BuiltInSchedulerConfig as BuiltInSchedulerConfig diff --git a/sqlmesh/core/config/base.py b/sqlmesh/core/config/base.py index 78d1ed70d7..0da36e4754 100644 --- a/sqlmesh/core/config/base.py +++ b/sqlmesh/core/config/base.py @@ -36,6 +36,15 @@ def update_field( The updated field """ + + def _update_pydantic_config(old: BaseConfig, new: BaseConfig) -> PydanticModel: + if type(new) != type(old): + raise ConfigError( + "NESTED_UPDATE behavior requires both values to have the same type. " + f"{type(old)} and {type(new)} were given instead." 
+ ) + return old.update_with(new) + if not old: return new @@ -78,18 +87,20 @@ def update_field( return combined if update_strategy == UpdateStrategy.NESTED_UPDATE: - if not isinstance(old, BaseConfig): + if not isinstance(old, BaseConfig) and not isinstance(old, dict): raise ConfigError( - f"NESTED_UPDATE behavior requires a config object. {type(old)} was given instead." + f"NESTED_UPDATE behavior requires a config object and a dict of config objects as values. {type(old)} was given instead." ) - if type(new) != type(old): - raise ConfigError( - "NESTED_UPDATE behavior requires both values to have the same type. " - f"{type(old)} and {type(new)} were given instead." - ) + if isinstance(old, dict): + for k, pydantic_model in new.items(): + if k in old: + old[k] = _update_pydantic_config(old[k], pydantic_model) + else: + old[k] = pydantic_model - return old.update_with(new) + return old + return _update_pydantic_config(old, new) raise ConfigError(f"Unknown update strategy {update_strategy}.") diff --git a/sqlmesh/core/config/categorizer.py b/sqlmesh/core/config/categorizer.py index e269b96bb3..3cbedb922e 100644 --- a/sqlmesh/core/config/categorizer.py +++ b/sqlmesh/core/config/categorizer.py @@ -29,7 +29,7 @@ class CategorizerConfig(BaseConfig): """ external: AutoCategorizationMode = AutoCategorizationMode.FULL - python: AutoCategorizationMode = AutoCategorizationMode.OFF + python: AutoCategorizationMode = AutoCategorizationMode.FULL sql: AutoCategorizationMode = AutoCategorizationMode.FULL seed: AutoCategorizationMode = AutoCategorizationMode.FULL diff --git a/sqlmesh/core/config/common.py b/sqlmesh/core/config/common.py index 819740c4ea..dca472d7a9 100644 --- a/sqlmesh/core/config/common.py +++ b/sqlmesh/core/config/common.py @@ -2,16 +2,40 @@ import typing as t from enum import Enum +import re from sqlmesh.utils import classproperty from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.pydantic import field_validator +# Config files that can be present 
in the project dir +ALL_CONFIG_FILENAMES = ("config.py", "config.yml", "config.yaml", "sqlmesh.yml", "sqlmesh.yaml") + +# For personal paths (~/.sqlmesh/) where python config is not supported +YAML_CONFIG_FILENAMES = tuple(n for n in ALL_CONFIG_FILENAMES if not n.endswith(".py")) + +# Note: is here to prevent having to import from sqlmesh.dbt.loader which introduces a dependency +# on dbt-core in a native project +DBT_PROJECT_FILENAME = "dbt_project.yml" + class EnvironmentSuffixTarget(str, Enum): + # Intended to create virtual environments in their own schemas, with names like "__". The view name is untouched. + # For example, a model named 'sqlmesh_example.full_model' created in an environment called 'dev' + # would have its virtual layer view created as 'sqlmesh_example__dev.full_model' SCHEMA = "schema" + + # Intended to create virtual environments in the same schema as their production counterparts by adjusting the table name. + # For example, a model named 'sqlmesh_example.full_model' created in an environment called 'dev' + # would have its virtual layer view created as "sqlmesh_example.full_model__dev" TABLE = "table" + # Intended to create virtual environments in their own catalogs to preserve the schema and view name of the models + # For example, a model named 'sqlmesh_example.full_model' created in an environment called 'dev' + # with a default catalog of "warehouse" would have its virtual layer view created as "warehouse__dev.sqlmesh_example.full_model" + # note: this only works for engines that can query across catalogs + CATALOG = "catalog" + @property def is_schema(self) -> bool: return self == EnvironmentSuffixTarget.SCHEMA @@ -20,6 +44,10 @@ def is_schema(self) -> bool: def is_table(self) -> bool: return self == EnvironmentSuffixTarget.TABLE + @property + def is_catalog(self) -> bool: + return self == EnvironmentSuffixTarget.CATALOG + @classproperty def default(cls) -> EnvironmentSuffixTarget: return EnvironmentSuffixTarget.SCHEMA @@ -31,6 +59,63 
@@ def __repr__(self) -> str: return str(self) +class VirtualEnvironmentMode(str, Enum): + """Mode for virtual environment behavior. + + FULL: Use full virtual environment functionality with versioned table names and virtual layer updates. + DEV_ONLY: Bypass virtual environments in production, using original unversioned model names. + """ + + FULL = "full" + DEV_ONLY = "dev_only" + + @property + def is_full(self) -> bool: + return self == VirtualEnvironmentMode.FULL + + @property + def is_dev_only(self) -> bool: + return self == VirtualEnvironmentMode.DEV_ONLY + + @classproperty + def default(cls) -> VirtualEnvironmentMode: + return VirtualEnvironmentMode.FULL + + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return str(self) + + +class TableNamingConvention(str, Enum): + # Causes table names at the physical layer to follow the convention: + # ____ + SCHEMA_AND_TABLE = "schema_and_table" + + # Causes table names at the physical layer to follow the convention: + # __ + TABLE_ONLY = "table_only" + + # Takes the table name that would be returned from SCHEMA_AND_TABLE and wraps it in md5() + # to generate a hash and prefixes the has with `sqlmesh_md5__`, for the following reasons: + # - at a glance, you can still see it's managed by sqlmesh and that md5 was used to generate the hash + # - unquoted identifiers that start with numbers can trip up DB engine parsers, so having a text prefix prevents this + # This causes table names at the physical layer to follow the convention: + # sqlmesh_md5__3b07384d113edec49eaa6238ad5ff00d + HASH_MD5 = "hash_md5" + + @classproperty + def default(cls) -> TableNamingConvention: + return TableNamingConvention.SCHEMA_AND_TABLE + + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return str(self) + + def _concurrent_tasks_validator(v: t.Any) -> int: if isinstance(v, str): v = int(v) @@ -86,3 +171,16 @@ def _validate_type(v: t.Any) -> None: mode="before", check_fields=False, 
)(_variables_validator) + + +def compile_regex_mapping(value: t.Dict[str | re.Pattern, t.Any]) -> t.Dict[re.Pattern, t.Any]: + """ + Utility function to compile a dict of { "string regex pattern" : "string value" } into { "": "string value" } + """ + compiled_regexes = {} + for k, v in value.items(): + try: + compiled_regexes[re.compile(k)] = v + except re.error: + raise ConfigError(f"`{k}` is not a valid regular expression.") + return compiled_regexes diff --git a/sqlmesh/core/config/connection.py b/sqlmesh/core/config/connection.py index 38d20f5176..26bfa78730 100644 --- a/sqlmesh/core/config/connection.py +++ b/sqlmesh/core/config/connection.py @@ -4,47 +4,109 @@ import base64 import logging import os +import importlib import pathlib -import sys +import re import typing as t from enum import Enum from functools import partial +import pydantic from pydantic import Field +from pydantic_core import from_json +from packaging import version from sqlglot import exp from sqlglot.helper import subclasses +from sqlglot.errors import ParseError from sqlmesh.core import engine_adapter from sqlmesh.core.config.base import BaseConfig from sqlmesh.core.config.common import ( concurrent_tasks_validator, http_headers_validator, + compile_regex_mapping, ) +from sqlmesh.core.engine_adapter.shared import CatalogSupport from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.utils import debug_mode_enabled, str_to_bool from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.pydantic import ( - PYDANTIC_MAJOR_VERSION, + ValidationInfo, field_validator, model_validator, - model_validator_v1_args, + validation_error_message, + get_concrete_types_from_typehint, ) +from sqlmesh.utils.aws import validate_s3_uri -if sys.version_info >= (3, 9): - from typing import Literal -else: - from typing_extensions import Literal - +if t.TYPE_CHECKING: + from sqlmesh.core._typing import Self logger = logging.getLogger(__name__) -RECOMMENDED_STATE_SYNC_ENGINES = {"postgres", 
"gcp_postgres", "mysql", "duckdb"} +RECOMMENDED_STATE_SYNC_ENGINES = { + "postgres", + "gcp_postgres", + "mysql", + "mssql", + "azuresql", +} +FORBIDDEN_STATE_SYNC_ENGINES = { + # Do not support row-level operations + "spark", + "trino", + # Nullable types are problematic + "clickhouse", +} +MOTHERDUCK_TOKEN_REGEX = re.compile(r"(\?|\&)(motherduck_token=)(\S*)") +PASSWORD_REGEX = re.compile(r"(password=)(\S+)") + + +def _get_engine_import_validator( + import_name: str, engine_type: str, extra_name: t.Optional[str] = None, decorate: bool = True +) -> t.Callable: + extra_name = extra_name or engine_type + + def validate(cls: t.Any, data: t.Any) -> t.Any: + check_import = ( + str_to_bool(str(data.pop("check_import", True))) if isinstance(data, dict) else True + ) + if not check_import: + return data + try: + importlib.import_module(import_name) + except ImportError: + if debug_mode_enabled(): + raise + + logger.exception("Failed to import the engine library") + + raise ConfigError( + f"Failed to import the '{engine_type}' engine library. This may be due to a missing " + "or incompatible installation. Please ensure the required dependency is installed by " + f'running: `pip install "sqlmesh[{extra_name}]"`. For more details, check the logs ' + "in the 'logs/' folder, or rerun the command with the '--debug' flag." + ) + + return data + + return model_validator(mode="before")(validate) if decorate else validate class ConnectionConfig(abc.ABC, BaseConfig): type_: str + DIALECT: t.ClassVar[str] + DISPLAY_NAME: t.ClassVar[str] + DISPLAY_ORDER: t.ClassVar[int] concurrent_tasks: int register_comments: bool pre_ping: bool + pretty_sql: bool = False + schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None + catalog_type_overrides: t.Optional[t.Dict[str, str]] = None + + # Whether to share a single connection across threads or create a new connection per thread. 
+ shared_connection: t.ClassVar[bool] = False @property @abc.abstractmethod @@ -71,11 +133,6 @@ def _extra_engine_config(self) -> t.Dict[str, t.Any]: """kwargs that are for execution config only""" return {} - @property - def _cursor_kwargs(self) -> t.Optional[t.Dict[str, t.Any]]: - """Key-value arguments that will be passed during cursor construction.""" - return None - @property def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]: """A function that is called to initialize the cursor""" @@ -83,9 +140,14 @@ def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]: @property def is_recommended_for_state_sync(self) -> bool: - """Whether this connection is recommended for being used as a state sync for production state syncs""" + """Whether this engine is recommended for being used as a state sync for production state syncs""" return self.type_ in RECOMMENDED_STATE_SYNC_ENGINES + @property + def is_forbidden_for_state_sync(self) -> bool: + """Whether this engine is forbidden from being used as a state sync""" + return self.type_ in FORBIDDEN_STATE_SYNC_ENGINES + @property def _connection_factory_with_kwargs(self) -> t.Callable[[], t.Any]: """A function that is called to return a connection object for the given Engine Adapter""" @@ -101,16 +163,23 @@ def connection_validator(self) -> t.Callable[[], None]: """A function that validates the connection configuration""" return self.create_engine_adapter().ping - def create_engine_adapter(self, register_comments_override: bool = False) -> EngineAdapter: + def create_engine_adapter( + self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None + ) -> EngineAdapter: """Returns a new instance of the Engine Adapter.""" + + concurrent_tasks = concurrent_tasks or self.concurrent_tasks return self._engine_adapter( self._connection_factory_with_kwargs, - multithreaded=self.concurrent_tasks > 1, - cursor_kwargs=self._cursor_kwargs, + multithreaded=concurrent_tasks > 1, 
default_catalog=self.get_catalog(), cursor_init=self._cursor_init, register_comments=register_comments_override or self.register_comments, pre_ping=self.pre_ping, + pretty_sql=self.pretty_sql, + shared_connection=self.shared_connection, + schema_differ_overrides=self.schema_differ_overrides, + catalog_type_overrides=self.catalog_type_overrides, **self._extra_engine_config, ) @@ -124,29 +193,146 @@ def get_catalog(self) -> t.Optional[str]: return self.db return None + @model_validator(mode="before") + @classmethod + def _expand_json_strings_to_concrete_types(cls, data: t.Any) -> t.Any: + """ + There are situations where a connection config class has a field that is some kind of complex type + (eg a list of strings or a dict) but the value is being supplied from a source such as an environment variable + + When this happens, the value is supplied as a string rather than a Python object. We need some way + of turning this string into the corresponding Python list or dict. + + Rather than doing this piecemeal on every config subclass, this provides a generic implementatation + to identify fields that may be be supplied as JSON strings and handle them transparently + """ + if data and isinstance(data, dict): + for maybe_json_field_name in cls._get_list_and_dict_field_names(): + if (value := data.get(maybe_json_field_name)) and isinstance(value, str): + # crude JSON check as we dont want to try and parse every string we get + value = value.strip() + if value.startswith("{") or value.startswith("["): + data[maybe_json_field_name] = from_json(value) + + return data + + @classmethod + def _get_list_and_dict_field_names(cls) -> t.Set[str]: + field_names = set() + for name, field in cls.model_fields.items(): + if field.annotation: + field_types = get_concrete_types_from_typehint(field.annotation) + + # check if the field type is something that could concievably be supplied as a json string + if any(ft is t for t in (list, tuple, set, dict) for ft in field_types): + 
field_names.add(name) + + return field_names + + +class DuckDBAttachOptions(BaseConfig): + type: str + path: str + read_only: bool = False + + # DuckLake specific options + data_path: t.Optional[str] = None + encrypted: bool = False + data_inlining_row_limit: t.Optional[int] = None + metadata_schema: t.Optional[str] = None + + def to_sql(self, alias: str) -> str: + options = [] + # 'duckdb' is actually not a supported type, but we'd like to allow it for + # fully qualified attach options or integration testing, similar to duckdb-dbt + if self.type not in ("duckdb", "ducklake", "motherduck"): + options.append(f"TYPE {self.type.upper()}") + if self.read_only: + options.append("READ_ONLY") + + # DuckLake specific options + path = self.path + if self.type == "ducklake": + if not path.startswith("ducklake:"): + path = f"ducklake:{path}" + if self.data_path is not None: + options.append(f"DATA_PATH '{self.data_path}'") + if self.encrypted: + options.append("ENCRYPTED") + if self.data_inlining_row_limit is not None: + options.append(f"DATA_INLINING_ROW_LIMIT {self.data_inlining_row_limit}") + if self.metadata_schema is not None: + options.append(f"METADATA_SCHEMA '{self.metadata_schema}'") + + options_sql = f" ({', '.join(options)})" if options else "" + alias_sql = "" + # TODO: Add support for Postgres schema. Currently adding it blocks access to the information_schema + + # MotherDuck does not support aliasing + alias_sql = ( + f" AS {alias}" if not (self.type == "motherduck" or self.path.startswith("md:")) else "" + ) + return f"ATTACH IF NOT EXISTS '{path}'{alias_sql}{options_sql}" + class BaseDuckDBConnectionConfig(ConnectionConfig): """Common configuration for the DuckDB-based connections. Args: + database: The optional database name. If not specified, the in-memory database will be used. + catalogs: Key is the name of the catalog and value is the path. extensions: A list of autoloadable extensions to load. 
connector_config: A dictionary of configuration to pass into the duckdb connector. + secrets: A list of dictionaries used to generate DuckDB secrets for authenticating with external services (e.g. S3). + filesystems: A list of dictionaries used to register `fsspec` filesystems to the DuckDB cursor. concurrent_tasks: The maximum number of tasks that can use this connection concurrently. register_comments: Whether or not to register model comments with the SQL engine. pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. + token: The optional MotherDuck token. If not specified and a MotherDuck path is in the catalog, the user will be prompted to login with their web browser. """ - extensions: t.List[str] = [] + database: t.Optional[str] = None + catalogs: t.Optional[t.Dict[str, t.Union[str, DuckDBAttachOptions]]] = None + extensions: t.List[t.Union[str, t.Dict[str, t.Any]]] = [] connector_config: t.Dict[str, t.Any] = {} + secrets: t.Union[t.List[t.Dict[str, t.Any]], t.Dict[str, t.Dict[str, t.Any]]] = [] + filesystems: t.List[t.Dict[str, t.Any]] = [] - concurrent_tasks: Literal[1] = 1 + concurrent_tasks: int = 1 register_comments: bool = True - pre_ping: Literal[False] = False + pre_ping: t.Literal[False] = False + + token: t.Optional[str] = None + + shared_connection: t.ClassVar[bool] = True + + _data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {} + + @model_validator(mode="before") + def _validate_database_catalogs(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + db_path = data.get("database") + if db_path and data.get("catalogs"): + raise ConfigError( + "Cannot specify both `database` and `catalogs`. 
Define all your catalogs in `catalogs` and have the first entry be the default catalog" + ) + if isinstance(db_path, str) and db_path.startswith("md:"): + raise ConfigError( + "Please use connection type 'motherduck' without the `md:` prefix if you want to use a MotherDuck database as the single `database`." + ) + + return data @property def _engine_adapter(self) -> t.Type[EngineAdapter]: return engine_adapter.DuckDBEngineAdapter + @property + def _connection_kwargs_keys(self) -> t.Set[str]: + return {"database"} + @property def _connection_factory(self) -> t.Callable: import duckdb @@ -161,17 +347,85 @@ def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]: def init(cursor: duckdb.DuckDBPyConnection) -> None: for extension in self.extensions: - try: - cursor.execute(f"INSTALL {extension}") - cursor.execute(f"LOAD {extension}") - except Exception as e: - raise ConfigError(f"Failed to load extension {extension}: {e}") + extension = extension if isinstance(extension, dict) else {"name": extension} + + install_command = f"INSTALL {extension['name']}" + + if extension.get("repository"): + install_command = f"{install_command} FROM {extension['repository']}" + + if extension.get("force_install"): + install_command = f"FORCE {install_command}" - for field, setting in self.connector_config.items(): try: - cursor.execute(f"SET {field} = '{setting}'") + cursor.execute(install_command) + cursor.execute(f"LOAD {extension['name']}") except Exception as e: - raise ConfigError(f"Failed to set connector config {field} to {setting}: {e}") + raise ConfigError(f"Failed to load extension {extension['name']}: {e}") + + if self.connector_config: + option_names = list(self.connector_config) + in_part = ",".join("?" 
for _ in range(len(option_names))) + + cursor.execute( + f"SELECT name, value FROM duckdb_settings() WHERE name IN ({in_part})", + option_names, + ) + + existing_values = {field: setting for field, setting in cursor.fetchall()} + + # only set connector_config items if the values differ from what is already set + # trying to set options like 'temp_directory' even to the same value can throw errors like: + # Not implemented Error: Cannot switch temporary directory after the current one has been used + for field, setting in self.connector_config.items(): + if existing_values.get(field) != setting: + try: + cursor.execute(f"SET {field} = '{setting}'") + except Exception as e: + raise ConfigError( + f"Failed to set connector config {field} to {setting}: {e}" + ) + + if self.secrets: + duckdb_version = duckdb.__version__ + if version.parse(duckdb_version) < version.parse("0.10.0"): + from sqlmesh.core.console import get_console + + get_console().log_warning( + f"DuckDB version {duckdb_version} does not support secrets-based authentication (requires 0.10.0 or later).\n" + "To use secrets, please upgrade DuckDB. 
For older versions, configure legacy authentication via `connector_config`.\n" + "More info: https://duckdb.org/docs/stable/extensions/httpfs/s3api_legacy_authentication.html" + ) + else: + if isinstance(self.secrets, list): + secrets_items = [(secret_dict, "") for secret_dict in self.secrets] + else: + secrets_items = [ + (secret_dict, secret_name) + for secret_name, secret_dict in self.secrets.items() + ] + + for secret_dict, secret_name in secrets_items: + secret_settings: t.List[str] = [] + for field, setting in secret_dict.items(): + secret_settings.append(f"{field} '{setting}'") + if secret_settings: + secret_clause = ", ".join(secret_settings) + try: + cursor.execute( + f"CREATE OR REPLACE SECRET {secret_name} ({secret_clause});" + ) + except Exception as e: + raise ConfigError(f"Failed to create secret: {e}") + + if self.filesystems: + from fsspec import filesystem # type: ignore + + for file_system in self.filesystems: + options = file_system.copy() + fs = options.pop("fs") + fs = filesystem(fs, **options) + cursor.register_filesystem(fs) for i, (alias, path_options) in enumerate( (getattr(self, "catalogs", None) or {}).items() @@ -182,19 +436,26 @@ def init(cursor: duckdb.DuckDBPyConnection) -> None: identify=True, dialect="duckdb" ) try: - query = ( - path_options.to_sql(alias) - if isinstance(path_options, DuckDBAttachOptions) - else f"ATTACH '{path_options}' AS {alias}" - ) + if isinstance(path_options, DuckDBAttachOptions): + query = path_options.to_sql(alias) + else: + query = f"ATTACH IF NOT EXISTS '{path_options}'" + if not path_options.startswith("md:"): + query += f" AS {alias}" cursor.execute(query) except BinderException as e: # If a user tries to create a catalog pointing at `:memory:` and with the name `memory` # then we don't want to raise since this happens by default. They are just doing this to # set it as the default catalog. 
- if not ( - 'database with name "memory" already exists' in str(e) - and path_options == ":memory:" + # If a user tries to attach a MotherDuck database/share which has already been attached via + # `ATTACH 'md:'`, then we don't want to raise since this is expected. + if ( + not ( + 'database with name "memory" already exists' in str(e) + and path_options == ":memory:" + ) + and f"""database with name "{path_options.path.replace("md:", "")}" already exists""" + not in str(e) ): raise e if i == 0 and not getattr(self, "database", None): @@ -202,117 +463,104 @@ def init(cursor: duckdb.DuckDBPyConnection) -> None: return init - -class MotherDuckConnectionConfig(BaseDuckDBConnectionConfig): - """Configuration for the MotherDuck connection. - - Args: - database: The database name. - token: The optional MotherDuck token. If not specified, the user will be prompted to login with their web browser. - """ - - database: str - token: t.Optional[str] = None - - type_: Literal["motherduck"] = Field(alias="type", default="motherduck") - - @property - def _connection_kwargs_keys(self) -> t.Set[str]: - return set() - - @property - def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: - """kwargs that are for execution config only""" - from sqlmesh import __version__ - - connection_str = f"md:{self.database}" - if self.token: - connection_str += f"?motherduck_token={self.token}" - return { - "database": connection_str, - "config": {"custom_user_agent": f"SQLMesh/{__version__}"}, - } - - -class DuckDBAttachOptions(BaseConfig): - type: str - path: str - read_only: bool = False - - def to_sql(self, alias: str) -> str: - options = [] - # 'duckdb' is actually not a supported type, but we'd like to allow it for - # fully qualified attach options or integration testing, similar to duckdb-dbt - if self.type != "duckdb": - options.append(f"TYPE {self.type.upper()}") - if self.read_only: - options.append("READ_ONLY") - options_sql = f" ({', '.join(options)})" if options else "" - return 
f"ATTACH '{self.path}' AS {alias}{options_sql}" - - -class DuckDBConnectionConfig(BaseDuckDBConnectionConfig): - """Configuration for the DuckDB connection. - - Args: - database: The optional database name. If not specified, the in-memory database will be used. - catalogs: Key is the name of the catalog and value is the path. - """ - - database: t.Optional[str] = None - catalogs: t.Optional[t.Dict[str, t.Union[str, DuckDBAttachOptions]]] = None - - type_: Literal["duckdb"] = Field(alias="type", default="duckdb") - - _data_file_to_adapter: t.ClassVar[t.Dict[str, EngineAdapter]] = {} - - @model_validator(mode="before") - @model_validator_v1_args - def _validate_database_catalogs( - cls, values: t.Dict[str, t.Optional[str]] - ) -> t.Dict[str, t.Optional[str]]: - if values.get("database") and values.get("catalogs"): - raise ConfigError( - "Cannot specify both `database` and `catalogs`. Define all your catalogs in `catalogs` and have the first entry be the default catalog" - ) - return values - - @property - def _connection_kwargs_keys(self) -> t.Set[str]: - return {"database"} - - def create_engine_adapter(self, register_comments_override: bool = False) -> EngineAdapter: + def create_engine_adapter( + self, register_comments_override: bool = False, concurrent_tasks: t.Optional[int] = None + ) -> EngineAdapter: """Checks if another engine adapter has already been created that shares a catalog that points to the same data file. If so, it uses that same adapter instead of creating a new one. 
As a result, any additional configuration associated with the new adapter will be ignored.""" data_files = set((self.catalogs or {}).values()) if self.database: - data_files.add(self.database) + if isinstance(self, MotherDuckConnectionConfig): + data_files.add( + f"md:{self.database}" + + (f"?motherduck_token={self.token}" if self.token else "") + ) + else: + data_files.add(self.database) data_files.discard(":memory:") for data_file in data_files: key = data_file if isinstance(data_file, str) else data_file.path - if adapter := DuckDBConnectionConfig._data_file_to_adapter.get(key): - logger.info(f"Using existing DuckDB adapter due to overlapping data file: {key}") + adapter = BaseDuckDBConnectionConfig._data_file_to_adapter.get(key) + if adapter is not None: + logger.info( + f"Using existing DuckDB adapter due to overlapping data file: {self._mask_sensitive_data(key)}" + ) return adapter if data_files: - logger.info(f"Creating new DuckDB adapter for data files: {data_files}") + masked_files = { + self._mask_sensitive_data(file if isinstance(file, str) else file.path) + for file in data_files + } + logger.info(f"Creating new DuckDB adapter for data files: {masked_files}") else: logger.info("Creating new DuckDB adapter for in-memory database") - adapter = super().create_engine_adapter(register_comments_override) + adapter = super().create_engine_adapter( + register_comments_override, concurrent_tasks=concurrent_tasks + ) for data_file in data_files: key = data_file if isinstance(data_file, str) else data_file.path - DuckDBConnectionConfig._data_file_to_adapter[key] = adapter + BaseDuckDBConnectionConfig._data_file_to_adapter[key] = adapter return adapter def get_catalog(self) -> t.Optional[str]: if self.database: # Remove `:` from the database name in order to handle if `:memory:` is passed in - return pathlib.Path(self.database.replace(":", "")).stem + return pathlib.Path(self.database.replace(":memory:", "memory")).stem if self.catalogs: return 
list(self.catalogs)[0] return None + def _mask_sensitive_data(self, string: str) -> str: + # Mask MotherDuck tokens with fixed number of asterisks + result = MOTHERDUCK_TOKEN_REGEX.sub( + lambda m: f"{m.group(1)}{m.group(2)}{'*' * 8 if m.group(3) else ''}", string + ) + # Mask PostgreSQL/MySQL passwords with fixed number of asterisks + result = PASSWORD_REGEX.sub(lambda m: f"{m.group(1)}{'*' * 8}", result) + return result + + +class MotherDuckConnectionConfig(BaseDuckDBConnectionConfig): + """Configuration for the MotherDuck connection.""" + + type_: t.Literal["motherduck"] = Field(alias="type", default="motherduck") + DIALECT: t.ClassVar[t.Literal["duckdb"]] = "duckdb" + DISPLAY_NAME: t.ClassVar[t.Literal["MotherDuck"]] = "MotherDuck" + DISPLAY_ORDER: t.ClassVar[t.Literal[5]] = 5 + + @property + def _connection_kwargs_keys(self) -> t.Set[str]: + return set() + + @property + def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: + """kwargs that are for execution config only""" + from sqlmesh import __version__ + + custom_user_agent_config = {"custom_user_agent": f"SQLMesh/{__version__}"} + connection_str = "md:" + if self.database: + # Attach single MD database instead of all databases on the account + connection_str += f"{self.database}?attach_mode=single" + if self.token: + connection_str += f"{'&' if self.database else '?'}motherduck_token={self.token}" + return {"database": connection_str, "config": custom_user_agent_config} + + @property + def _extra_engine_config(self) -> t.Dict[str, t.Any]: + return {"is_motherduck": True} + + +class DuckDBConnectionConfig(BaseDuckDBConnectionConfig): + """Configuration for the DuckDB connection.""" + + type_: t.Literal["duckdb"] = Field(alias="type", default="duckdb") + DIALECT: t.ClassVar[t.Literal["duckdb"]] = "duckdb" + DISPLAY_NAME: t.ClassVar[t.Literal["DuckDB"]] = "DuckDB" + DISPLAY_ORDER: t.ClassVar[t.Literal[1]] = 1 + class SnowflakeConnectionConfig(ConnectionConfig): """Configuration for the Snowflake 
connection. @@ -334,6 +582,8 @@ class SnowflakeConnectionConfig(ConnectionConfig): register_comments: Whether or not to register model comments with the SQL engine. pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. session_parameters: The optional session parameters to set for the connection. + host: Host address for the connection. + port: Port for the connection. """ account: str @@ -344,6 +594,9 @@ class SnowflakeConnectionConfig(ConnectionConfig): role: t.Optional[str] = None authenticator: t.Optional[str] = None token: t.Optional[str] = None + host: t.Optional[str] = None + port: t.Optional[int] = None + application: t.Literal["Tobiko_SQLMesh"] = "Tobiko_SQLMesh" # Private Key Auth private_key: t.Optional[t.Union[str, bytes]] = None @@ -356,34 +609,41 @@ class SnowflakeConnectionConfig(ConnectionConfig): session_parameters: t.Optional[dict] = None - type_: Literal["snowflake"] = Field(alias="type", default="snowflake") + type_: t.Literal["snowflake"] = Field(alias="type", default="snowflake") + DIALECT: t.ClassVar[t.Literal["snowflake"]] = "snowflake" + DISPLAY_NAME: t.ClassVar[t.Literal["Snowflake"]] = "Snowflake" + DISPLAY_ORDER: t.ClassVar[t.Literal[2]] = 2 _concurrent_tasks_validator = concurrent_tasks_validator @model_validator(mode="before") - @model_validator_v1_args - def _validate_authenticator( - cls, values: t.Dict[str, t.Optional[str]] - ) -> t.Dict[str, t.Optional[str]]: - from snowflake.connector.network import ( - DEFAULT_AUTHENTICATOR, - OAUTH_AUTHENTICATOR, - ) + def _validate_authenticator(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data - auth = values.get("authenticator") + from snowflake.connector.network import DEFAULT_AUTHENTICATOR, OAUTH_AUTHENTICATOR + + auth = data.get("authenticator") auth = auth.upper() if auth else DEFAULT_AUTHENTICATOR - user = values.get("user") - password = values.get("password") - values["private_key"] = 
cls._get_private_key(values, auth) # type: ignore + user = data.get("user") + password = data.get("password") + data["private_key"] = cls._get_private_key(data, auth) # type: ignore + if ( auth == DEFAULT_AUTHENTICATOR - and not values.get("private_key") + and not data.get("private_key") and (not user or not password) ): raise ConfigError("User and password must be provided if using default authentication") - if auth == OAUTH_AUTHENTICATOR and not values.get("token"): + + if auth == OAUTH_AUTHENTICATOR and not data.get("token"): raise ConfigError("Token must be provided if using oauth authentication") - return values + + return data + + _engine_import_validator = _get_engine_import_validator( + "snowflake.connector.network", "snowflake" + ) @classmethod def _get_private_key(cls, values: t.Dict[str, t.Optional[str]], auth: str) -> t.Optional[bytes]: @@ -475,6 +735,9 @@ def _connection_kwargs_keys(self) -> t.Set[str]: "token", "private_key", "session_parameters", + "application", + "host", + "port", } @property @@ -497,11 +760,16 @@ class DatabricksConnectionConfig(ConnectionConfig): Databricks connection that uses the SQL connector for SQL models and then Databricks Connect for Dataframe operations Arg Source: https://github.com/databricks/databricks-sql-python/blob/main/src/databricks/sql/client.py#L39 + OAuth ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication + Args: server_hostname: Databricks instance host name. http_path: Http path either to a DBSQL endpoint (e.g. /sql/1.0/endpoints/1234567890abcdef) or to a DBR interactive cluster (e.g. /sql/protocolv1/o/1234567890123456/1234-123456-slid123) access_token: Http Bearer access token, e.g. Databricks Personal Access Token. 
+ auth_type: Set to 'databricks-oauth' or 'azure-oauth' to trigger OAuth (or don't set at all to use `access_token`) + oauth_client_id: Client ID to use when auth_type is set to one of the 'oauth' types + oauth_client_secret: Client Secret to use when auth_type is set to one of the 'oauth' types catalog: Default catalog to use for SQL models. Defaults to None which means it will use the default set in the Databricks cluster (most likely `hive_metastore`). http_headers: An optional list of (k, v) pairs that will be set as Http headers on every request @@ -522,53 +790,109 @@ class DatabricksConnectionConfig(ConnectionConfig): server_hostname: t.Optional[str] = None http_path: t.Optional[str] = None access_token: t.Optional[str] = None + auth_type: t.Optional[str] = None + oauth_client_id: t.Optional[str] = None + oauth_client_secret: t.Optional[str] = None catalog: t.Optional[str] = None http_headers: t.Optional[t.List[t.Tuple[str, str]]] = None session_configuration: t.Optional[t.Dict[str, t.Any]] = None databricks_connect_server_hostname: t.Optional[str] = None databricks_connect_access_token: t.Optional[str] = None databricks_connect_cluster_id: t.Optional[str] = None + databricks_connect_use_serverless: bool = False force_databricks_connect: bool = False disable_databricks_connect: bool = False disable_spark_session: bool = False concurrent_tasks: int = 1 register_comments: bool = True - pre_ping: Literal[False] = False + pre_ping: t.Literal[False] = False - type_: Literal["databricks"] = Field(alias="type", default="databricks") + type_: t.Literal["databricks"] = Field(alias="type", default="databricks") + DIALECT: t.ClassVar[t.Literal["databricks"]] = "databricks" + DISPLAY_NAME: t.ClassVar[t.Literal["Databricks"]] = "Databricks" + DISPLAY_ORDER: t.ClassVar[t.Literal[3]] = 3 _concurrent_tasks_validator = concurrent_tasks_validator _http_headers_validator = http_headers_validator @model_validator(mode="before") - @model_validator_v1_args - def 
_databricks_connect_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: + def _databricks_connect_validator(cls, data: t.Any) -> t.Any: + # SQLQueryContextLogger will output any error SQL queries even if they are in a try/except block. + # Disabling this allows SQLMesh to determine what should be shown to the user. + # Ex: We describe a table to see if it exists and therefore that execution can fail but we don't need to show + # the user since it is expected if the table doesn't exist. Without this change the user would see the error. + logging.getLogger("SQLQueryContextLogger").setLevel(logging.CRITICAL) + + if not isinstance(data, dict): + return data + from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter if DatabricksEngineAdapter.can_access_spark_session( - bool(values.get("disable_spark_session")) + bool(data.get("disable_spark_session")) ): - return values - server_hostname, http_path, access_token = ( - values.get("server_hostname"), - values.get("http_path"), - values.get("access_token"), + return data + + databricks_connect_use_serverless = data.get("databricks_connect_use_serverless") + server_hostname, http_path, access_token, auth_type = ( + data.get("server_hostname"), + data.get("http_path"), + data.get("access_token"), + data.get("auth_type"), ) - if not server_hostname or not http_path or not access_token: + + if (not server_hostname or not http_path or not access_token) and ( + not databricks_connect_use_serverless and not auth_type + ): raise ValueError( "`server_hostname`, `http_path`, and `access_token` are required for Databricks connections when not running in a notebook" ) + if ( + databricks_connect_use_serverless + and not server_hostname + and not data.get("databricks_connect_server_hostname") + ): + raise ValueError( + "`server_hostname` or `databricks_connect_server_hostname` is required when `databricks_connect_use_serverless` is set" + ) if DatabricksEngineAdapter.can_access_databricks_connect( - 
bool(values.get("disable_databricks_connect")) + bool(data.get("disable_databricks_connect")) ): - if not values.get("databricks_connect_server_hostname"): - values["databricks_connect_server_hostname"] = f"https://{server_hostname}" - if not values.get("databricks_connect_access_token"): - values["databricks_connect_access_token"] = access_token - if not values.get("databricks_connect_cluster_id"): - values["databricks_connect_cluster_id"] = http_path.split("/")[-1] - return values + if not data.get("databricks_connect_access_token"): + data["databricks_connect_access_token"] = access_token + if not data.get("databricks_connect_server_hostname"): + data["databricks_connect_server_hostname"] = f"https://{server_hostname}" + if not databricks_connect_use_serverless and not data.get( + "databricks_connect_cluster_id" + ): + if t.TYPE_CHECKING: + assert http_path is not None + data["databricks_connect_cluster_id"] = http_path.split("/")[-1] + + if auth_type: + from databricks.sql.auth.auth import AuthType + + all_data = [m.value for m in AuthType] + if auth_type not in all_data: + raise ValueError( + f"`auth_type` {auth_type} does not match a valid option: {all_data}" + ) + + client_id = data.get("oauth_client_id") + client_secret = data.get("oauth_client_secret") + + if client_secret and not client_id: + raise ValueError( + "`oauth_client_id` is required when `oauth_client_secret` is specified" + ) + + if not http_path: + raise ValueError("`http_path` is still required when using `auth_type`") + + return data + + _engine_import_validator = _get_engine_import_validator("databricks", "databricks") @property def _connection_kwargs_keys(self) -> t.Set[str]: @@ -612,7 +936,7 @@ def _connection_factory(self) -> t.Callable: return connection - from databricks import sql + from databricks import sql # type: ignore return sql.connect @@ -621,10 +945,31 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: from sqlmesh.core.engine_adapter.databricks import 
DatabricksEngineAdapter if not self.use_spark_session_only: - return { + conn_kwargs: t.Dict[str, t.Any] = { "_user_agent_entry": "sqlmesh", } + if self.auth_type and "oauth" in self.auth_type: + # there are two types of oauth: User-to-Machine (U2M) and Machine-to-Machine (M2M) + if self.oauth_client_secret: + # if a client_secret exists, then a client_id also exists and we are using M2M + # ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-machine-to-machine-m2m-authentication + # ref: https://github.com/databricks/databricks-sql-python/blob/main/examples/m2m_oauth.py + from databricks.sdk.core import oauth_service_principal, Config + + config = Config( + host=f"https://{self.server_hostname}", + client_id=self.oauth_client_id, + client_secret=self.oauth_client_secret, + ) + conn_kwargs["credentials_provider"] = lambda: oauth_service_principal(config) + else: + # if auth_type is set to an 'oauth' type but no client_id/secret are set, then we are using U2M + # ref: https://docs.databricks.com/en/dev-tools/python-sql-connector.html#oauth-user-to-machine-u2m-authentication + conn_kwargs["auth_type"] = self.auth_type + + return conn_kwargs + if DatabricksEngineAdapter.can_access_spark_session(self.disable_spark_session): from pyspark.sql import SparkSession @@ -635,14 +980,27 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: from databricks.connect import DatabricksSession - return dict( - spark=DatabricksSession.builder.remote( + if t.TYPE_CHECKING: + assert self.databricks_connect_server_hostname is not None + assert self.databricks_connect_access_token is not None + + if self.databricks_connect_use_serverless: + builder = DatabricksSession.builder.remote( + host=self.databricks_connect_server_hostname, + token=self.databricks_connect_access_token, + serverless=True, + ) + else: + if t.TYPE_CHECKING: + assert self.databricks_connect_cluster_id is not None + builder = DatabricksSession.builder.remote( 
host=self.databricks_connect_server_hostname, token=self.databricks_connect_access_token, cluster_id=self.databricks_connect_cluster_id, ) - .userAgent("sqlmesh") - .getOrCreate(), + + return dict( + spark=builder.userAgent("sqlmesh").getOrCreate(), catalog=self.catalog, ) @@ -684,6 +1042,7 @@ class BigQueryConnectionConfig(ConnectionConfig): project: t.Optional[str] = None execution_project: t.Optional[str] = None + quota_project: t.Optional[str] = None location: t.Optional[str] = None # Keyfile Auth keyfile: t.Optional[str] = None @@ -695,8 +1054,9 @@ class BigQueryConnectionConfig(ConnectionConfig): client_secret: t.Optional[str] = None token_uri: t.Optional[str] = None scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/bigquery",) - job_creation_timeout_seconds: t.Optional[int] = None + impersonated_service_account: t.Optional[str] = None # Extra Engine Config + job_creation_timeout_seconds: t.Optional[int] = None job_execution_timeout_seconds: t.Optional[int] = None job_retries: t.Optional[int] = 1 job_retry_deadline_seconds: t.Optional[int] = None @@ -705,9 +1065,38 @@ class BigQueryConnectionConfig(ConnectionConfig): concurrent_tasks: int = 1 register_comments: bool = True - pre_ping: Literal[False] = False + pre_ping: t.Literal[False] = False + + type_: t.Literal["bigquery"] = Field(alias="type", default="bigquery") + DIALECT: t.ClassVar[t.Literal["bigquery"]] = "bigquery" + DISPLAY_NAME: t.ClassVar[t.Literal["BigQuery"]] = "BigQuery" + DISPLAY_ORDER: t.ClassVar[t.Literal[4]] = 4 + + _engine_import_validator = _get_engine_import_validator("google.cloud.bigquery", "bigquery") + + @field_validator("execution_project") + def validate_execution_project( + cls, + v: t.Optional[str], + info: ValidationInfo, + ) -> t.Optional[str]: + if v and not info.data.get("project"): + raise ConfigError( + "If the `execution_project` field is specified, you must also specify the `project` field to provide a default object location." 
+ ) + return v - type_: Literal["bigquery"] = Field(alias="type", default="bigquery") + @field_validator("quota_project") + def validate_quota_project( + cls, + v: t.Optional[str], + info: ValidationInfo, + ) -> t.Optional[str]: + if v and not info.data.get("project"): + raise ConfigError( + "If the `quota_project` field is specified, you must also specify the `project` field to provide a default object location." + ) + return v @property def _connection_kwargs_keys(self) -> t.Set[str]: @@ -721,7 +1110,8 @@ def _engine_adapter(self) -> t.Type[EngineAdapter]: def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: """The static connection kwargs for this connection""" import google.auth - from google.api_core import client_info + from google.auth import impersonated_credentials + from google.api_core import client_info, client_options from google.oauth2 import credentials, service_account if self.method == BigQueryConnectionMethod.OAUTH: @@ -745,11 +1135,23 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: ) else: raise ConfigError("Invalid BigQuery Connection Method") + + if self.impersonated_service_account: + creds = impersonated_credentials.Credentials( + source_credentials=creds, + target_principal=self.impersonated_service_account, + target_scopes=self.scopes, + ) + + options = client_options.ClientOptions(quota_project_id=self.quota_project) + project = self.execution_project or self.project or None + client = google.cloud.bigquery.Client( - project=self.execution_project or self.project, + project=project and exp.parse_identifier(project, dialect="bigquery").name, credentials=creds, location=self.location, client_info=client_info.ClientInfo(user_agent="sqlmesh"), + client_options=options, ) return { @@ -792,6 +1194,8 @@ class GCPPostgresConnectionConfig(ConnectionConfig): password: The postgres user's password. Only needed when the user is a postgres user. enable_iam_auth: Set to True when user is an IAM user. db: Name of the db to connect to. 
+ keyfile: string path to json service account credentials file + keyfile_json: dict service account credentials info pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. """ @@ -800,34 +1204,43 @@ class GCPPostgresConnectionConfig(ConnectionConfig): password: t.Optional[str] = None enable_iam_auth: t.Optional[bool] = None db: str + ip_type: t.Union[t.Literal["public"], t.Literal["private"], t.Literal["psc"]] = "public" + # Keyfile Auth + keyfile: t.Optional[str] = None + keyfile_json: t.Optional[t.Dict[str, t.Any]] = None timeout: t.Optional[int] = None - + scopes: t.Tuple[str, ...] = ("https://www.googleapis.com/auth/sqlservice.admin",) driver: str = "pg8000" - type_: Literal["gcp_postgres"] = Field(alias="type", default="gcp_postgres") + + type_: t.Literal["gcp_postgres"] = Field(alias="type", default="gcp_postgres") + DIALECT: t.ClassVar[t.Literal["postgres"]] = "postgres" + DISPLAY_NAME: t.ClassVar[t.Literal["GCP Postgres"]] = "GCP Postgres" + DISPLAY_ORDER: t.ClassVar[t.Literal[13]] = 13 + concurrent_tasks: int = 4 register_comments: bool = True pre_ping: bool = True + _engine_import_validator = _get_engine_import_validator( + "google.cloud.sql", "gcp_postgres", "gcppostgres" + ) + @model_validator(mode="before") - @model_validator_v1_args - def _validate_auth_method( - cls, values: t.Dict[str, t.Optional[str]] - ) -> t.Dict[str, t.Optional[str]]: - password = values.get("password") - enable_iam_auth = values.get("enable_iam_auth") - if password and enable_iam_auth: - raise ConfigError( - "Invalid GCP Postgres connection configuration - both password and" - " enable_iam_auth set. Use password when connecting to a postgres" - " user and enable_iam_auth 'True' when connecting to an IAM user." 
- ) + def _validate_auth_method(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + password = data.get("password") + enable_iam_auth = data.get("enable_iam_auth") + if not password and not enable_iam_auth: raise ConfigError( "GCP Postgres connection configuration requires either password set" " for a postgres user account or enable_iam_auth set to 'True'" " for an IAM user account." ) - return values + + return data @property def _connection_kwargs_keys(self) -> t.Set[str]: @@ -848,8 +1261,27 @@ def _engine_adapter(self) -> t.Type[EngineAdapter]: @property def _connection_factory(self) -> t.Callable: from google.cloud.sql.connector import Connector + from google.oauth2 import service_account + + creds = None + if self.keyfile: + creds = service_account.Credentials.from_service_account_file( + self.keyfile, scopes=self.scopes + ) + elif self.keyfile_json: + creds = service_account.Credentials.from_service_account_info( + self.keyfile_json, scopes=self.scopes + ) + + kwargs = { + "credentials": creds, + "ip_type": self.ip_type, + } - return Connector().connect + if self.timeout: + kwargs["timeout"] = self.timeout + + return Connector(**kwargs).connect # type: ignore class RedshiftConnectionConfig(ConnectionConfig): @@ -882,6 +1314,7 @@ class RedshiftConnectionConfig(ConnectionConfig): serverless_acct_id: The account ID of the serverless. Default value None serverless_work_group: The name of work group for serverless end point. Default value None. pre_ping: Whether or not to pre-ping the connection before starting a new transaction to ensure it is still alive. + enable_merge: Whether to use the Redshift merge operation instead of the SQLMesh logical merge. 
""" user: t.Optional[str] = None @@ -905,12 +1338,18 @@ class RedshiftConnectionConfig(ConnectionConfig): is_serverless: t.Optional[bool] = None serverless_acct_id: t.Optional[str] = None serverless_work_group: t.Optional[str] = None + enable_merge: t.Optional[bool] = None concurrent_tasks: int = 4 register_comments: bool = True pre_ping: bool = False - type_: Literal["redshift"] = Field(alias="type", default="redshift") + type_: t.Literal["redshift"] = Field(alias="type", default="redshift") + DIALECT: t.ClassVar[t.Literal["redshift"]] = "redshift" + DISPLAY_NAME: t.ClassVar[t.Literal["Redshift"]] = "Redshift" + DISPLAY_ORDER: t.ClassVar[t.Literal[7]] = 7 + + _engine_import_validator = _get_engine_import_validator("redshift_connector", "redshift") @property def _connection_kwargs_keys(self) -> t.Set[str]: @@ -948,6 +1387,10 @@ def _connection_factory(self) -> t.Callable: return connect + @property + def _extra_engine_config(self) -> t.Dict[str, t.Any]: + return {"enable_merge": self.enable_merge} + class PostgresConnectionConfig(ConnectionConfig): host: str @@ -959,12 +1402,18 @@ class PostgresConnectionConfig(ConnectionConfig): connect_timeout: int = 10 role: t.Optional[str] = None sslmode: t.Optional[str] = None + application_name: t.Optional[str] = None concurrent_tasks: int = 4 register_comments: bool = True pre_ping: bool = True - type_: Literal["postgres"] = Field(alias="type", default="postgres") + type_: t.Literal["postgres"] = Field(alias="type", default="postgres") + DIALECT: t.ClassVar[t.Literal["postgres"]] = "postgres" + DISPLAY_NAME: t.ClassVar[t.Literal["Postgres"]] = "Postgres" + DISPLAY_ORDER: t.ClassVar[t.Literal[12]] = 12 + + _engine_import_validator = _get_engine_import_validator("psycopg2", "postgres") @property def _connection_kwargs_keys(self) -> t.Set[str]: @@ -976,8 +1425,8 @@ def _connection_kwargs_keys(self) -> t.Set[str]: "database", "keepalives_idle", "connect_timeout", - "role", "sslmode", + "application_name", } @property @@ -990,25 
+1439,37 @@ def _connection_factory(self) -> t.Callable: return connect + @property + def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]: + if not self.role: + return None + + def init(cursor: t.Any) -> None: + cursor.execute(f"SET ROLE {self.role}") + + return init + class MySQLConnectionConfig(ConnectionConfig): host: str user: str password: str port: t.Optional[int] = None + database: t.Optional[str] = None charset: t.Optional[str] = None + collation: t.Optional[str] = None ssl_disabled: t.Optional[bool] = None concurrent_tasks: int = 4 register_comments: bool = True pre_ping: bool = True - type_: Literal["mysql"] = Field(alias="type", default="mysql") + type_: t.Literal["mysql"] = Field(alias="type", default="mysql") + DIALECT: t.ClassVar[t.Literal["mysql"]] = "mysql" + DISPLAY_NAME: t.ClassVar[t.Literal["MySQL"]] = "MySQL" + DISPLAY_ORDER: t.ClassVar[t.Literal[14]] = 14 - @property - def _cursor_kwargs(self) -> t.Optional[t.Dict[str, t.Any]]: - """Key-value arguments that will be passed during cursor construction.""" - return {"buffered": True} + _engine_import_validator = _get_engine_import_validator("pymysql", "mysql") @property def _connection_kwargs_keys(self) -> t.Set[str]: @@ -1016,13 +1477,15 @@ def _connection_kwargs_keys(self) -> t.Set[str]: "host", "user", "password", - "port", - "database", } if self.port is not None: connection_keys.add("port") + if self.database is not None: + connection_keys.add("database") if self.charset is not None: connection_keys.add("charset") + if self.collation is not None: + connection_keys.add("collation") if self.ssl_disabled is not None: connection_keys.add("ssl_disabled") return connection_keys @@ -1033,7 +1496,7 @@ def _engine_adapter(self) -> t.Type[EngineAdapter]: @property def _connection_factory(self) -> t.Callable: - from mysql.connector import connect + from pymysql import connect return connect @@ -1048,19 +1511,57 @@ class MSSQLConnectionConfig(ConnectionConfig): charset: t.Optional[str] = 
"UTF-8" appname: t.Optional[str] = None port: t.Optional[int] = 1433 - conn_properties: t.Optional[t.Union[t.Iterable[str], str]] = None + conn_properties: t.Optional[t.Union[t.List[str], str]] = None autocommit: t.Optional[bool] = False tds_version: t.Optional[str] = None + # Driver options + driver: t.Literal["pymssql", "pyodbc"] = "pymssql" + # PyODBC specific options + driver_name: t.Optional[str] = None # e.g. "ODBC Driver 18 for SQL Server" + trust_server_certificate: t.Optional[bool] = None + encrypt: t.Optional[bool] = None + # Dictionary of arbitrary ODBC connection properties + # See: https://learn.microsoft.com/en-us/sql/connect/odbc/dsn-connection-string-attribute + odbc_properties: t.Optional[t.Dict[str, t.Any]] = None + concurrent_tasks: int = 4 register_comments: bool = True pre_ping: bool = True - type_: Literal["mssql"] = Field(alias="type", default="mssql") + type_: t.Literal["mssql"] = Field(alias="type", default="mssql") + DIALECT: t.ClassVar[t.Literal["tsql"]] = "tsql" + DISPLAY_NAME: t.ClassVar[t.Literal["MSSQL"]] = "MSSQL" + DISPLAY_ORDER: t.ClassVar[t.Literal[11]] = 11 + + @model_validator(mode="before") + @classmethod + def _mssql_engine_import_validator(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + driver = data.get("driver", "pymssql") + + # Define the mapping of driver to import module and extra name + driver_configs = {"pymssql": ("pymssql", "mssql"), "pyodbc": ("pyodbc", "mssql-odbc")} + + if driver not in driver_configs: + raise ValueError(f"Unsupported driver: {driver}") + + import_module, extra_name = driver_configs[driver] + + # Use _get_engine_import_validator with decorate=False to get the raw validation function + # This avoids the __wrapped__ issue in Python 3.9 + validator_func = _get_engine_import_validator( + import_module, driver, extra_name, decorate=False + ) + + # Call the raw validation function directly + return validator_func(cls, data) @property def _connection_kwargs_keys(self) -> 
t.Set[str]: - return { + base_keys = { "host", "user", "password", @@ -1075,15 +1576,184 @@ def _connection_kwargs_keys(self) -> t.Set[str]: "tds_version", } + if self.driver == "pyodbc": + base_keys.update( + { + "driver_name", + "trust_server_certificate", + "encrypt", + "odbc_properties", + } + ) + # Remove pymssql-specific parameters + base_keys.discard("tds_version") + base_keys.discard("conn_properties") + + return base_keys + @property def _engine_adapter(self) -> t.Type[EngineAdapter]: return engine_adapter.MSSQLEngineAdapter @property def _connection_factory(self) -> t.Callable: - import pymssql + if self.driver == "pymssql": + import pymssql + + return pymssql.connect + + import pyodbc + + def connect(**kwargs: t.Any) -> t.Callable: + # Extract parameters for connection string + host = kwargs.pop("host") + port = kwargs.pop("port", 1433) + database = kwargs.pop("database", "") + user = kwargs.pop("user", None) + password = kwargs.pop("password", None) + driver_name = kwargs.pop("driver_name", "ODBC Driver 18 for SQL Server") + trust_server_certificate = kwargs.pop("trust_server_certificate", False) + encrypt = kwargs.pop("encrypt", True) + login_timeout = kwargs.pop("login_timeout", 60) + + # Build connection string + conn_str_parts = [ + f"DRIVER={{{driver_name}}}", + f"SERVER={host},{port}", + ] + + if database: + conn_str_parts.append(f"DATABASE={database}") + + # Add security options + conn_str_parts.append(f"Encrypt={'YES' if encrypt else 'NO'}") + if trust_server_certificate: + conn_str_parts.append("TrustServerCertificate=YES") + + conn_str_parts.append(f"Connection Timeout={login_timeout}") + + # Standard SQL Server authentication + if user: + conn_str_parts.append(f"UID={user}") + if password: + conn_str_parts.append(f"PWD={password}") + + # Add any additional ODBC properties from the odbc_properties dictionary + if self.odbc_properties: + for key, value in self.odbc_properties.items(): + # Skip properties that we've already set above + if 
key.lower() in ( + "driver", + "server", + "database", + "uid", + "pwd", + "encrypt", + "trustservercertificate", + "connection timeout", + ): + continue + + # Handle boolean values properly + if isinstance(value, bool): + conn_str_parts.append(f"{key}={'YES' if value else 'NO'}") + else: + conn_str_parts.append(f"{key}={value}") + + # Create the connection string + conn_str = ";".join(conn_str_parts) + + conn = pyodbc.connect(conn_str, autocommit=kwargs.get("autocommit", False)) + + # Set up output converters for MSSQL-specific data types + # Handle SQL type -155 (DATETIMEOFFSET) which is not yet supported by pyodbc + # ref: https://github.com/mkleehammer/pyodbc/issues/134#issuecomment-281739794 + def handle_datetimeoffset(dto_value: t.Any) -> t.Any: + from datetime import datetime, timedelta, timezone + import struct + + # Unpack the DATETIMEOFFSET binary format: + # Format: <6hI2h = (year, month, day, hour, minute, second, nanoseconds, tz_hour_offset, tz_minute_offset) + tup = struct.unpack("<6hI2h", dto_value) + return datetime( + tup[0], + tup[1], + tup[2], + tup[3], + tup[4], + tup[5], + tup[6] // 1000, + timezone(timedelta(hours=tup[7], minutes=tup[8])), + ) - return pymssql.connect + conn.add_output_converter(-155, handle_datetimeoffset) + + return conn + + return connect + + @property + def _extra_engine_config(self) -> t.Dict[str, t.Any]: + return {"catalog_support": CatalogSupport.REQUIRES_SET_CATALOG} + + +class AzureSQLConnectionConfig(MSSQLConnectionConfig): + type_: t.Literal["azuresql"] = Field(alias="type", default="azuresql") # type: ignore + DISPLAY_NAME: t.ClassVar[t.Literal["Azure SQL"]] = "Azure SQL" # type: ignore + DISPLAY_ORDER: t.ClassVar[t.Literal[10]] = 10 # type: ignore + + @property + def _extra_engine_config(self) -> t.Dict[str, t.Any]: + return {"catalog_support": CatalogSupport.SINGLE_CATALOG_ONLY} + + +class FabricConnectionConfig(MSSQLConnectionConfig): + """ + Fabric Connection Configuration. 
+ Inherits most settings from MSSQLConnectionConfig and sets the type to 'fabric'. + It is recommended to use the 'pyodbc' driver for Fabric. + """ + + type_: t.Literal["fabric"] = Field(alias="type", default="fabric") # type: ignore + DIALECT: t.ClassVar[t.Literal["fabric"]] = "fabric" # type: ignore + DISPLAY_NAME: t.ClassVar[t.Literal["Fabric"]] = "Fabric" # type: ignore + DISPLAY_ORDER: t.ClassVar[t.Literal[17]] = 17 # type: ignore + driver: t.Literal["pyodbc"] = "pyodbc" + workspace_id: str + tenant_id: str + autocommit: t.Optional[bool] = True + + @property + def _engine_adapter(self) -> t.Type[EngineAdapter]: + from sqlmesh.core.engine_adapter.fabric import FabricEngineAdapter + + return FabricEngineAdapter + + @property + def _connection_factory(self) -> t.Callable: + # Override to support catalog switching for Fabric + base_factory = super()._connection_factory + + def create_fabric_connection( + target_catalog: t.Optional[str] = None, *args: t.Any, **kwargs: t.Any + ) -> t.Callable: + kwargs["database"] = target_catalog or self.database + return base_factory(*args, **kwargs) + + return create_fabric_connection + + @property + def _extra_engine_config(self) -> t.Dict[str, t.Any]: + return { + "database": self.database, + # more operations than not require a specific catalog to be already active + # in particular, create/drop view, create/drop schema and querying information_schema + "catalog_support": CatalogSupport.REQUIRES_SET_CATALOG, + "workspace_id": self.workspace_id, + "tenant_id": self.tenant_id, + "user": self.user, + "password": self.password, + } class SparkConnectionConfig(ConnectionConfig): @@ -1094,12 +1764,18 @@ class SparkConnectionConfig(ConnectionConfig): config_dir: t.Optional[str] = None catalog: t.Optional[str] = None config: t.Dict[str, t.Any] = {} + wap_enabled: bool = False concurrent_tasks: int = 4 register_comments: bool = True - pre_ping: Literal[False] = False + pre_ping: t.Literal[False] = False + + type_: t.Literal["spark"] = 
Field(alias="type", default="spark") + DIALECT: t.ClassVar[t.Literal["spark"]] = "spark" + DISPLAY_NAME: t.ClassVar[t.Literal["Spark"]] = "Spark" + DISPLAY_ORDER: t.ClassVar[t.Literal[8]] = 8 - type_: Literal["spark"] = Field(alias="type", default="spark") + _engine_import_validator = _get_engine_import_validator("pyspark", "spark") @property def _connection_kwargs_keys(self) -> t.Set[str]: @@ -1135,6 +1811,10 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: .getOrCreate(), } + @property + def _extra_engine_config(self) -> t.Dict[str, t.Any]: + return {"wap_enabled": self.wap_enabled} + class TrinoAuthenticationMethod(str, Enum): NO_AUTH = "no-auth" @@ -1180,7 +1860,7 @@ class TrinoConnectionConfig(ConnectionConfig): user: str catalog: str port: t.Optional[int] = None - http_scheme: Literal["http", "https"] = "https" + http_scheme: t.Literal["http", "https"] = "https" # General Optional roles: t.Optional[t.Dict[str, str]] = None http_headers: t.Optional[t.Dict[str, str]] = None @@ -1208,48 +1888,95 @@ class TrinoConnectionConfig(ConnectionConfig): client_certificate: t.Optional[str] = None client_private_key: t.Optional[str] = None cert: t.Optional[str] = None + source: str = "sqlmesh" + # SQLMesh options + schema_location_mapping: t.Optional[dict[re.Pattern, str]] = None + timestamp_mapping: t.Optional[dict[exp.DataType, exp.DataType]] = None concurrent_tasks: int = 4 register_comments: bool = True - pre_ping: Literal[False] = False + pre_ping: t.Literal[False] = False + + type_: t.Literal["trino"] = Field(alias="type", default="trino") + DIALECT: t.ClassVar[t.Literal["trino"]] = "trino" + DISPLAY_NAME: t.ClassVar[t.Literal["Trino"]] = "Trino" + DISPLAY_ORDER: t.ClassVar[t.Literal[9]] = 9 - type_: Literal["trino"] = Field(alias="type", default="trino") + _engine_import_validator = _get_engine_import_validator("trino", "trino") + + @field_validator("schema_location_mapping", mode="before") + @classmethod + def _validate_regex_keys( + cls, value: 
t.Dict[str | re.Pattern, str] + ) -> t.Dict[re.Pattern, t.Any]: + compiled = compile_regex_mapping(value) + for replacement in compiled.values(): + if "@{schema_name}" not in replacement: + raise ConfigError( + "schema_location_mapping needs to include the '@{schema_name}' placeholder in the value so SQLMesh knows where to substitute the schema name" + ) + return compiled + + @field_validator("timestamp_mapping", mode="before") + @classmethod + def _validate_timestamp_mapping( + cls, value: t.Optional[dict[str, str]] + ) -> t.Optional[dict[exp.DataType, exp.DataType]]: + if value is None: + return value + + result: dict[exp.DataType, exp.DataType] = {} + for source_type, target_type in value.items(): + try: + source_datatype = exp.DataType.build(source_type) + except ParseError: + raise ConfigError( + f"Invalid SQL type string in timestamp_mapping: " + f"'{source_type}' is not a valid SQL data type." + ) + try: + target_datatype = exp.DataType.build(target_type) + except ParseError: + raise ConfigError( + f"Invalid SQL type string in timestamp_mapping: " + f"'{target_type}' is not a valid SQL data type." 
+ ) + result[source_datatype] = target_datatype + + return result @model_validator(mode="after") - @model_validator_v1_args - def _root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - port = values.get("port") - if ( - values["http_scheme"] == "http" - and not values["method"].is_no_auth - and not values["method"].is_basic - ): + def _root_validator(self) -> Self: + port = self.port + if self.http_scheme == "http" and not self.method.is_no_auth and not self.method.is_basic: raise ConfigError("HTTP scheme can only be used with no-auth or basic method") + if port is None: - values["port"] = 80 if values["http_scheme"] == "http" else 443 - if (values["method"].is_ldap or values["method"].is_basic) and ( - not values["password"] or not values["user"] - ): + self.port = 80 if self.http_scheme == "http" else 443 + + if (self.method.is_ldap or self.method.is_basic) and (not self.password or not self.user): raise ConfigError( - f"Username and Password must be provided if using {values['method'].value} authentication" + f"Username and Password must be provided if using {self.method.value} authentication" ) - if values["method"].is_kerberos and ( - not values["principal"] or not values["keytab"] or not values["krb5_config"] + + if self.method.is_kerberos and ( + not self.principal or not self.keytab or not self.krb5_config ): raise ConfigError( "Kerberos requires the following fields: principal, keytab, and krb5_config" ) - if values["method"].is_jwt and not values["jwt_token"]: + + if self.method.is_jwt and not self.jwt_token: raise ConfigError("JWT requires `jwt_token` to be set") - if values["method"].is_certificate and ( - not values["cert"] - or not values["client_certificate"] - or not values["client_private_key"] + + if self.method.is_certificate and ( + not self.cert or not self.client_certificate or not self.client_private_key ): raise ConfigError( "Certificate requires the following fields: cert, client_certificate, and client_private_key" ) - 
return values + + return self @property def _connection_kwargs_keys(self) -> t.Set[str]: @@ -1258,6 +1985,7 @@ def _connection_kwargs_keys(self) -> t.Set[str]: "port", "catalog", "roles", + "source", "http_scheme", "http_headers", "session_properties", @@ -1285,7 +2013,17 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: OAuth2Authentication, ) + auth: t.Optional[ + t.Union[ + BasicAuthentication, + KerberosAuthentication, + OAuth2Authentication, + JWTAuthentication, + CertificateAuthentication, + ] + ] = None if self.method.is_basic or self.method.is_ldap: + assert self.password is not None # for mypy since validator already checks this auth = BasicAuthentication(self.user, self.password) elif self.method.is_kerberos: if self.keytab: @@ -1304,26 +2042,331 @@ def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: elif self.method.is_oauth: auth = OAuth2Authentication() elif self.method.is_jwt: + assert self.jwt_token is not None auth = JWTAuthentication(self.jwt_token) elif self.method.is_certificate: + assert self.client_certificate is not None + assert self.client_private_key is not None auth = CertificateAuthentication(self.client_certificate, self.client_private_key) - else: - auth = None return { "auth": auth, "user": self.impersonation_user or self.user, "max_attempts": self.retries, "verify": self.cert if self.cert is not None else self.verify, - "source": "sqlmesh", + "source": self.source, + } + + @property + def _extra_engine_config(self) -> t.Dict[str, t.Any]: + return { + "schema_location_mapping": self.schema_location_mapping, + "timestamp_mapping": self.timestamp_mapping, } +class ClickhouseConnectionConfig(ConnectionConfig): + """ + Clickhouse Connection Configuration. 
+ + Property reference: https://clickhouse.com/docs/en/integrations/python#client-initialization + """ + + host: str + username: str + password: t.Optional[str] = None + port: t.Optional[int] = None + cluster: t.Optional[str] = None + connect_timeout: int = 10 + send_receive_timeout: int = 300 + query_limit: int = 0 + use_compression: bool = True + compression_method: t.Optional[str] = None + connection_settings: t.Optional[t.Dict[str, t.Any]] = None + http_proxy: t.Optional[str] = None + # HTTPS/TLS settings + verify: bool = True + ca_cert: t.Optional[str] = None + client_cert: t.Optional[str] = None + client_cert_key: t.Optional[str] = None + https_proxy: t.Optional[str] = None + server_host_name: t.Optional[str] = None + tls_mode: t.Optional[str] = None + + concurrent_tasks: int = 1 + register_comments: bool = True + pre_ping: bool = False + + # This object expects options from urllib3 and also from clickhouse-connect + # See: + # * https://urllib3.readthedocs.io/en/stable/advanced-usage.html + # * https://clickhouse.com/docs/en/integrations/python#customizing-the-http-connection-pool + connection_pool_options: t.Optional[t.Dict[str, t.Any]] = None + + type_: t.Literal["clickhouse"] = Field(alias="type", default="clickhouse") + DIALECT: t.ClassVar[t.Literal["clickhouse"]] = "clickhouse" + DISPLAY_NAME: t.ClassVar[t.Literal["ClickHouse"]] = "ClickHouse" + DISPLAY_ORDER: t.ClassVar[t.Literal[6]] = 6 + + _engine_import_validator = _get_engine_import_validator("clickhouse_connect", "clickhouse") + + @property + def _connection_kwargs_keys(self) -> t.Set[str]: + kwargs = { + "host", + "username", + "port", + "password", + "connect_timeout", + "send_receive_timeout", + "query_limit", + "http_proxy", + "verify", + "ca_cert", + "client_cert", + "client_cert_key", + "https_proxy", + "server_host_name", + "tls_mode", + } + return kwargs + + @property + def _engine_adapter(self) -> t.Type[EngineAdapter]: + return engine_adapter.ClickhouseEngineAdapter + + @property + def 
_connection_factory(self) -> t.Callable: + from clickhouse_connect.dbapi import connect # type: ignore + from clickhouse_connect.driver import httputil # type: ignore + from functools import partial + + pool_manager_options: t.Dict[str, t.Any] = dict( + # Match the maxsize to the number of concurrent tasks + maxsize=self.concurrent_tasks, + # Block if there are no free connections + block=True, + verify=self.verify, + ca_cert=self.ca_cert, + client_cert=self.client_cert, + client_cert_key=self.client_cert_key, + https_proxy=self.https_proxy, + ) + # this doesn't happen automatically because we always supply our own pool manager to the connection + # https://github.com/ClickHouse/clickhouse-connect/blob/3a7f4b04cad29c7c2536661b831fb744248e2ec0/clickhouse_connect/driver/httpclient.py#L109 + if self.server_host_name: + pool_manager_options["server_hostname"] = self.server_host_name + if self.verify: + pool_manager_options["assert_hostname"] = self.server_host_name + if self.connection_pool_options: + pool_manager_options.update(self.connection_pool_options) + pool_mgr = httputil.get_pool_manager(**pool_manager_options) + + return partial(connect, pool_mgr=pool_mgr) + + @property + def cloud_mode(self) -> bool: + return "clickhouse.cloud" in self.host + + @property + def _extra_engine_config(self) -> t.Dict[str, t.Any]: + return {"cluster": self.cluster, "cloud_mode": self.cloud_mode} + + @property + def _static_connection_kwargs(self) -> t.Dict[str, t.Any]: + from sqlmesh import __version__ + + # False = no compression + # True = Clickhouse default compression method + # string = specific compression method + compress: bool | str = self.use_compression + if compress and self.compression_method: + compress = self.compression_method + + # Clickhouse system settings passed to connection + # https://clickhouse.com/docs/en/operations/settings/settings + # - below are set to align with dbt-clickhouse + # - 
https://github.com/ClickHouse/dbt-clickhouse/blob/44d26308ea6a3c8ead25c280164aa88191f05f47/dbt/adapters/clickhouse/dbclient.py#L77 + settings = self.connection_settings or {} + # mutations_sync = 2: "The query waits for all mutations [ALTER statements] to complete on all replicas (if they exist)" + settings["mutations_sync"] = "2" + # insert_distributed_sync = 1: "INSERT operation succeeds only after all the data is saved on all shards" + settings["insert_distributed_sync"] = "1" + if self.cluster or self.cloud_mode: + # database_replicated_enforce_synchronous_settings = 1: + # - "Enforces synchronous waiting for some queries" + # - https://github.com/ClickHouse/ClickHouse/blob/ccaa8d03a9351efc16625340268b9caffa8a22ba/src/Core/Settings.h#L709 + settings["database_replicated_enforce_synchronous_settings"] = "1" + # insert_quorum = auto: + # - "INSERT succeeds only when ClickHouse manages to correctly write data to the insert_quorum of replicas during + # the insert_quorum_timeout" + # - "use majority number (number_of_replicas / 2 + 1) as quorum number" + settings["insert_quorum"] = "auto" + + return { + "compress": compress, + "client_name": f"SQLMesh/{__version__}", + **settings, + } + + +class AthenaConnectionConfig(ConnectionConfig): + # PyAthena connection options + aws_access_key_id: t.Optional[str] = None + aws_secret_access_key: t.Optional[str] = None + role_arn: t.Optional[str] = None + role_session_name: t.Optional[str] = None + region_name: t.Optional[str] = None + work_group: t.Optional[str] = None + s3_staging_dir: t.Optional[str] = None + schema_name: t.Optional[str] = None + catalog_name: t.Optional[str] = None + + # SQLMesh options + s3_warehouse_location: t.Optional[str] = None + concurrent_tasks: int = 4 + register_comments: t.Literal[False] = ( + False # because Athena doesnt support comments in most cases + ) + pre_ping: t.Literal[False] = False + + type_: t.Literal["athena"] = Field(alias="type", default="athena") + DIALECT: 
t.ClassVar[t.Literal["athena"]] = "athena" + DISPLAY_NAME: t.ClassVar[t.Literal["Athena"]] = "Athena" + DISPLAY_ORDER: t.ClassVar[t.Literal[15]] = 15 + + _engine_import_validator = _get_engine_import_validator("pyathena", "athena") + + @model_validator(mode="after") + def _root_validator(self) -> Self: + work_group = self.work_group + s3_staging_dir = self.s3_staging_dir + s3_warehouse_location = self.s3_warehouse_location + + if not work_group and not s3_staging_dir: + raise ConfigError("At least one of work_group or s3_staging_dir must be set") + + if s3_staging_dir: + self.s3_staging_dir = validate_s3_uri(s3_staging_dir, base=True, error_type=ConfigError) + + if s3_warehouse_location: + self.s3_warehouse_location = validate_s3_uri( + s3_warehouse_location, base=True, error_type=ConfigError + ) + + return self + + @property + def _connection_kwargs_keys(self) -> t.Set[str]: + return { + "aws_access_key_id", + "aws_secret_access_key", + "role_arn", + "role_session_name", + "region_name", + "work_group", + "s3_staging_dir", + "schema_name", + "catalog_name", + } + + @property + def _engine_adapter(self) -> t.Type[EngineAdapter]: + return engine_adapter.AthenaEngineAdapter + + @property + def _extra_engine_config(self) -> t.Dict[str, t.Any]: + return {"s3_warehouse_location": self.s3_warehouse_location} + + @property + def _connection_factory(self) -> t.Callable: + from pyathena import connect # type: ignore + + return connect + + def get_catalog(self) -> t.Optional[str]: + return self.catalog_name + + +class RisingwaveConnectionConfig(ConnectionConfig): + host: str + user: str + password: t.Optional[str] = None + port: int + database: str + role: t.Optional[str] = None + sslmode: t.Optional[str] = None + + concurrent_tasks: int = 4 + register_comments: bool = True + pre_ping: bool = True + + type_: t.Literal["risingwave"] = Field(alias="type", default="risingwave") + DIALECT: t.ClassVar[t.Literal["risingwave"]] = "risingwave" + DISPLAY_NAME: 
t.ClassVar[t.Literal["RisingWave"]] = "RisingWave" + DISPLAY_ORDER: t.ClassVar[t.Literal[16]] = 16 + + _engine_import_validator = _get_engine_import_validator("psycopg2", "risingwave") + + @property + def _connection_kwargs_keys(self) -> t.Set[str]: + return { + "host", + "user", + "password", + "port", + "database", + "role", + "sslmode", + } + + @property + def _engine_adapter(self) -> t.Type[EngineAdapter]: + return engine_adapter.RisingwaveEngineAdapter + + @property + def _connection_factory(self) -> t.Callable: + from psycopg2 import connect + + return connect + + @property + def _cursor_init(self) -> t.Optional[t.Callable[[t.Any], None]]: + def init(cursor: t.Any) -> None: + sql = "SET RW_IMPLICIT_FLUSH TO true;" + cursor.execute(sql) + + return init + + CONNECTION_CONFIG_TO_TYPE = { # Map all subclasses of ConnectionConfig to the value of their `type_` field. tpe.all_field_infos()["type_"].default: tpe for tpe in subclasses( - __name__, ConnectionConfig, exclude=(ConnectionConfig, BaseDuckDBConnectionConfig) + __name__, + ConnectionConfig, + exclude={ConnectionConfig, BaseDuckDBConnectionConfig}, + ) +} + +DIALECT_TO_TYPE = { + tpe.all_field_infos()["type_"].default: tpe.DIALECT + for tpe in subclasses( + __name__, + ConnectionConfig, + exclude={ConnectionConfig, BaseDuckDBConnectionConfig}, + ) +} + +INIT_DISPLAY_INFO_TO_TYPE = { + tpe.all_field_infos()["type_"].default: ( + tpe.DISPLAY_ORDER, + tpe.DISPLAY_NAME, + ) + for tpe in subclasses( + __name__, + ConnectionConfig, + exclude={ConnectionConfig, BaseDuckDBConnectionConfig}, ) } @@ -1344,10 +2387,21 @@ def _connection_config_validator( ) -> ConnectionConfig | None: if v is None or isinstance(v, ConnectionConfig): return v - return parse_connection_config(v) + check_config_and_vars_msg = "\n\nVerify your config.yaml and environment variables." 
+ + try: + return parse_connection_config(v) + except pydantic.ValidationError as e: + raise ConfigError( + validation_error_message(e, f"Invalid '{v['type']}' connection config:") + + check_config_and_vars_msg + ) + except ConfigError as e: + raise ConfigError(str(e) + check_config_and_vars_msg) -connection_config_validator = field_validator( + +connection_config_validator: t.Callable = field_validator( "connection", "state_connection", "test_connection", @@ -1362,10 +2416,8 @@ def _connection_config_validator( # TypeAlias hasn't been introduced until Python 3.10 which means that we can't use it # outside the TYPE_CHECKING guard. SerializableConnectionConfig: t.TypeAlias = ConnectionConfig # type: ignore -elif PYDANTIC_MAJOR_VERSION >= 2: +else: import pydantic # Workaround for https://docs.pydantic.dev/latest/concepts/serialization/#serializing-with-duck-typing SerializableConnectionConfig = pydantic.SerializeAsAny[ConnectionConfig] # type: ignore -else: - SerializableConnectionConfig = ConnectionConfig # type: ignore diff --git a/sqlmesh/core/config/dbt.py b/sqlmesh/core/config/dbt.py new file mode 100644 index 0000000000..e3132c40a4 --- /dev/null +++ b/sqlmesh/core/config/dbt.py @@ -0,0 +1,13 @@ +from sqlmesh.core.config.base import BaseConfig + + +class DbtConfig(BaseConfig): + """ + Represents dbt-specific options on the SQLMesh root config. 
+ + These options are only taken into account for dbt projects and are ignored on native projects + """ + + infer_state_schema_name: bool = False + """If set, indicates to the dbt loader that the state schema should be inferred based on the profile/target + so that each target gets its own isolated state""" diff --git a/sqlmesh/core/config/feature_flag.py b/sqlmesh/core/config/feature_flag.py deleted file mode 100644 index b04dda2b08..0000000000 --- a/sqlmesh/core/config/feature_flag.py +++ /dev/null @@ -1,9 +0,0 @@ -from sqlmesh.utils.pydantic import PydanticModel - - -class DbtFeatureFlag(PydanticModel): - scd_type_2_support: bool = True - - -class FeatureFlag(PydanticModel): - dbt: DbtFeatureFlag = DbtFeatureFlag() diff --git a/sqlmesh/core/config/format.py b/sqlmesh/core/config/format.py index 2b8fd21524..8730425d2e 100644 --- a/sqlmesh/core/config/format.py +++ b/sqlmesh/core/config/format.py @@ -16,6 +16,7 @@ class FormatConfig(BaseConfig): leading_comma: Whether to use leading commas or not. max_text_width: The maximum text width in a segment before creating new lines. append_newline: Whether to append a newline to the end of the file or not. + no_rewrite_casts: Preserve the existing casts, without rewriting them to use the :: syntax. """ normalize: bool = False @@ -25,6 +26,7 @@ class FormatConfig(BaseConfig): leading_comma: bool = False max_text_width: int = 80 append_newline: bool = False + no_rewrite_casts: bool = False @property def generator_options(self) -> t.Dict[str, t.Any]: @@ -33,4 +35,4 @@ def generator_options(self) -> t.Dict[str, t.Any]: Returns: The generator options. 
""" - return self.dict(exclude={"append_newline"}) + return self.dict(exclude={"append_newline", "no_rewrite_casts"}) diff --git a/sqlmesh/core/config/gateway.py b/sqlmesh/core/config/gateway.py index 5c330a16e0..a51557c4d7 100644 --- a/sqlmesh/core/config/gateway.py +++ b/sqlmesh/core/config/gateway.py @@ -4,12 +4,13 @@ from sqlmesh.core import constants as c from sqlmesh.core.config.base import BaseConfig +from sqlmesh.core.config.model import ModelDefaultsConfig from sqlmesh.core.config.common import variables_validator from sqlmesh.core.config.connection import ( SerializableConnectionConfig, connection_config_validator, ) -from sqlmesh.core.config.scheduler import SchedulerConfig +from sqlmesh.core.config.scheduler import SchedulerConfig, scheduler_config_validator class GatewayConfig(BaseConfig): @@ -34,6 +35,8 @@ class GatewayConfig(BaseConfig): scheduler: t.Optional[SchedulerConfig] = None state_schema: t.Optional[str] = c.SQLMESH variables: t.Dict[str, t.Any] = {} + model_defaults: t.Optional[ModelDefaultsConfig] = None _connection_config_validator = connection_config_validator + _scheduler_config_validator = scheduler_config_validator _variables_validator = variables_validator diff --git a/sqlmesh/core/config/janitor.py b/sqlmesh/core/config/janitor.py new file mode 100644 index 0000000000..0f1c953bc0 --- /dev/null +++ b/sqlmesh/core/config/janitor.py @@ -0,0 +1,26 @@ +from __future__ import annotations + +import typing as t + +from sqlmesh.core.config.base import BaseConfig +from sqlmesh.utils.pydantic import field_validator + + +class JanitorConfig(BaseConfig): + """The configuration for the janitor. + + Args: + warn_on_delete_failure: Whether to warn instead of erroring if the janitor fails to delete the expired environment schema / views. + expired_snapshots_batch_size: Maximum number of expired snapshots to clean in a single batch. 
+ """ + + warn_on_delete_failure: bool = False + expired_snapshots_batch_size: t.Optional[int] = None + + @field_validator("expired_snapshots_batch_size", mode="before") + @classmethod + def _validate_batch_size(cls, value: int) -> int: + batch_size = int(value) + if batch_size <= 0: + raise ValueError("expired_snapshots_batch_size must be greater than 0") + return batch_size diff --git a/sqlmesh/core/config/linter.py b/sqlmesh/core/config/linter.py new file mode 100644 index 0000000000..c2a40e09aa --- /dev/null +++ b/sqlmesh/core/config/linter.py @@ -0,0 +1,44 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import exp +from sqlglot.helper import ensure_collection + +from sqlmesh.core.config.base import BaseConfig + +from sqlmesh.utils.pydantic import field_validator + + +class LinterConfig(BaseConfig): + """Configuration for model linting + + Args: + enabled: Flag indicating whether the linter should run + + rules: A list of error rules to be applied on model + warn_rules: A list of rules to be applied on models but produce warnings instead of raising errors. 
+ ignored_rules: A list of rules to be excluded/ignored + + """ + + enabled: bool = False + + rules: t.Set[str] = set() + warn_rules: t.Set[str] = set() + ignored_rules: t.Set[str] = set() + + @classmethod + def _validate_rules(cls, v: t.Any) -> t.Set[str]: + if isinstance(v, exp.Paren): + v = v.unnest().name + elif isinstance(v, (exp.Tuple, exp.Array)): + v = [e.name for e in v.expressions] + elif isinstance(v, exp.Expression): + v = v.name + + return {name.lower() for name in ensure_collection(v)} + + @field_validator("rules", "warn_rules", "ignored_rules", mode="before") + def rules_validator(cls, vs: t.Any) -> t.Set[str]: + return cls._validate_rules(vs) diff --git a/sqlmesh/core/config/loader.py b/sqlmesh/core/config/loader.py index 78d0a9bb0e..e92c62960a 100644 --- a/sqlmesh/core/config/loader.py +++ b/sqlmesh/core/config/loader.py @@ -1,17 +1,26 @@ from __future__ import annotations +import glob import os import typing as t from pathlib import Path +from pydantic import ValidationError +from dotenv import load_dotenv from sqlglot.helper import ensure_list from sqlmesh.core import constants as c +from sqlmesh.core.config.common import ( + ALL_CONFIG_FILENAMES, + YAML_CONFIG_FILENAMES, + DBT_PROJECT_FILENAME, +) from sqlmesh.core.config.model import ModelDefaultsConfig from sqlmesh.core.config.root import Config from sqlmesh.utils import env_vars, merge_dicts, sys_path from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.metaprogramming import import_python_file +from sqlmesh.utils.pydantic import validation_error_message from sqlmesh.utils.yaml import load as yaml_load C = t.TypeVar("C", bound=Config) @@ -22,24 +31,33 @@ def load_configs( config_type: t.Type[C], paths: t.Union[str | Path, t.Iterable[str | Path]], sqlmesh_path: t.Optional[Path] = None, + dotenv_path: t.Optional[Path] = None, + **kwargs: t.Any, ) -> t.Dict[Path, C]: sqlmesh_path = sqlmesh_path or c.SQLMESH_PATH config = config or "config" absolute_paths = [ - Path(t.cast(t.Union[str, 
Path], path)).absolute() for path in ensure_list(paths) + Path(t.cast(t.Union[str, Path], p)).absolute() + for path in ensure_list(paths) + for p in (glob.glob(str(path)) or [str(path)]) ] + if dotenv_path and dotenv_path.exists() and dotenv_path.is_file(): + load_dotenv(dotenv_path=dotenv_path, override=True) + else: + for path in absolute_paths: + env_file = path / ".env" + if env_file.exists() and env_file.is_file(): + load_dotenv(dotenv_path=env_file, override=True) + if not isinstance(config, str): if type(config) != config_type: config = convert_config_type(config, config_type) return {path: config for path in absolute_paths} config_env_vars = None - personal_paths = [ - sqlmesh_path / "config.yml", - sqlmesh_path / "config.yaml", - ] + personal_paths = [sqlmesh_path / name for name in YAML_CONFIG_FILENAMES] for path in personal_paths: if path.exists(): config_env_vars = load_config_from_yaml(path).get("env_vars") @@ -50,9 +68,10 @@ def load_configs( return { path: load_config_from_paths( config_type, - project_paths=[path / "config.py", path / "config.yml", path / "config.yaml"], + project_paths=[path / name for name in ALL_CONFIG_FILENAMES], personal_paths=personal_paths, config_name=config, + **kwargs, ) for path in absolute_paths } @@ -64,6 +83,8 @@ def load_config_from_paths( personal_paths: t.Optional[t.List[Path]] = None, config_name: str = "config", load_from_env: bool = True, + variables: t.Optional[t.Dict[str, t.Any]] = None, + **kwargs: t.Any, ) -> C: project_paths = project_paths or [] personal_paths = personal_paths or [] @@ -76,6 +97,7 @@ def load_config_from_paths( "SQLMesh project config could not be found. Point the cli to the project path with `sqlmesh -p`. If you haven't set up the SQLMesh project, run `sqlmesh init`." 
) + yaml_config_path: t.Optional[Path] = None for path in [*project_paths, *personal_paths]: if not path.exists(): continue @@ -92,13 +114,20 @@ def load_config_from_paths( if extension in ("yml", "yaml"): if config_name != "config" and not python_config: raise ConfigError( - "YAML configs do not support multiple configs. Use Python instead." + "YAML configs do not support multiple configs. Use Python instead.", ) - non_python_configs.append(load_config_from_yaml(path)) + yaml_config_path = path.resolve() + non_python_configs.append(load_config_from_yaml(path, variables)) elif extension == "py": - python_config = load_config_from_python_module( - config_type, path, config_name=config_name - ) + try: + python_config = load_config_from_python_module( + config_type, path, config_name=config_name + ) + except ValidationError as e: + raise ConfigError( + validation_error_message(e, f"Invalid project config '{config_name}':") + + "\n\nVerify your config.py." + ) else: raise ConfigError( f"Unsupported config file extension '{extension}' in config file '{path}'." @@ -123,9 +152,44 @@ def load_config_from_paths( f"'{default}' is not a valid model default configuration key. Please remove it from the `model_defaults` specification in your config file." ) - non_python_config = config_type.parse_obj(non_python_config_dict) + try: + non_python_config = config_type.parse_obj(non_python_config_dict) + except ValidationError as e: + raise ConfigError( + validation_error_message(e, "Invalid project config:") + + "\n\nVerify your config.yaml and environment variables.", + location=yaml_config_path, + ) no_dialect_err_msg = "Default model SQL dialect is a required configuration parameter. Set it in the `model_defaults` `dialect` key in your config file." + + # if "dbt_project.yml" is present *and there was no python config already defined*, + # create a basic one to ensure we are using the DBT loader. + # any config within yaml files will get overlayed on top of it. 
+ if not python_config: + potential_project_files = [f / DBT_PROJECT_FILENAME for f in visited_folders] + dbt_project_file = next((f for f in potential_project_files if f.exists()), None) + if dbt_project_file: + from sqlmesh.dbt.loader import sqlmesh_config + + infer_state_schema_name = False + if dbt := non_python_config.dbt: + infer_state_schema_name = dbt.infer_state_schema_name + + dbt_python_config = sqlmesh_config( + project_root=dbt_project_file.parent, + profiles_dir=kwargs.pop("profiles_dir", None), + dbt_profile_name=kwargs.pop("profile", None), + dbt_target_name=kwargs.pop("target", None), + variables=variables, + threads=kwargs.pop("threads", None), + infer_state_schema_name=infer_state_schema_name, + ) + if type(dbt_python_config) != config_type: + dbt_python_config = convert_config_type(dbt_python_config, config_type) + + python_config = dbt_python_config + if python_config: model_defaults = python_config.model_defaults if model_defaults.dialect is None: @@ -135,11 +199,21 @@ def load_config_from_paths( model_defaults = non_python_config.model_defaults if model_defaults.dialect is None: raise ConfigError(no_dialect_err_msg) + return non_python_config -def load_config_from_yaml(path: Path) -> t.Dict[str, t.Any]: - return yaml_load(path) +def load_config_from_yaml( + path: Path, variables: t.Optional[t.Dict[str, t.Any]] = None +) -> t.Dict[str, t.Any]: + content = yaml_load(path, variables=variables) + if not isinstance(content, dict): + raise ConfigError( + f"Invalid YAML configuration: expected a dictionary but got {type(content).__name__}. 
" + f"Please check the YAML syntax in your config file.", + location=path, + ) + return content def load_config_from_python_module( @@ -147,8 +221,14 @@ def load_config_from_python_module( module_path: Path, config_name: str = "config", ) -> C: - with sys_path(module_path.parent): - config_module = import_python_file(module_path, module_path.parent) + try: + with sys_path(module_path.parent): + config_module = import_python_file(module_path, module_path.parent) + except Exception as e: + raise ConfigError( + f"Failed to load config file: {e}", + location=module_path, + ) try: config_obj = getattr(config_module, config_name) @@ -157,7 +237,8 @@ def load_config_from_python_module( if config_obj is None or not isinstance(config_obj, Config): raise ConfigError( - f"Config needs to be a valid object of type sqlmesh.core.config.Config. Found `{config_obj}` instead at '{module_path}'." + f"Config needs to be a valid object of type sqlmesh.core.config.Config. Found `{config_obj}` instead at '{module_path}'.", + module_path, ) return ( @@ -172,7 +253,7 @@ def load_config_from_env() -> t.Dict[str, t.Any]: for key, value in os.environ.items(): key = key.lower() - if key.startswith(f"{c.SQLMESH}__"): + if key.startswith(f"{c.SQLMESH}__") and key != (c.DISABLE_SQLMESH_STATE_MIGRATION).lower(): segments = key.split("__")[1:] if not segments or not segments[-1]: raise ConfigError(f"Invalid SQLMesh configuration variable '{key}'.") diff --git a/sqlmesh/core/config/model.py b/sqlmesh/core/config/model.py index 796734281d..aeefdf2557 100644 --- a/sqlmesh/core/config/model.py +++ b/sqlmesh/core/config/model.py @@ -2,14 +2,21 @@ import typing as t +from sqlglot import exp +from sqlmesh.core.dialect import parse_one, extract_func_call from sqlmesh.core.config.base import BaseConfig from sqlmesh.core.model.kind import ( ModelKind, OnDestructiveChange, model_kind_validator, on_destructive_change_validator, + on_additive_change_validator, + OnAdditiveChange, ) +from 
sqlmesh.core.model.meta import FunctionCall +from sqlmesh.core.node import IntervalUnit from sqlmesh.utils.date import TimeLike +from sqlmesh.utils.pydantic import field_validator class ModelDefaultsConfig(BaseConfig): @@ -24,9 +31,25 @@ class ModelDefaultsConfig(BaseConfig): start: The earliest date that the model will be backfilled for. If this is None, then the date is inferred by taking the most recent start date of its ancestors. The start date can be a static datetime or a relative datetime like "1 year ago" + table_format: The table format used to manage the physical table files defined by `storage_format`, only applicable in certain engines. + (eg, 'iceberg', 'delta', 'hudi') storage_format: The storage format used to store the physical table, only applicable in certain engines. - (eg. 'parquet') + (eg. 'parquet', 'orc') on_destructive_change: What should happen when a forward-only model requires a destructive schema change. + on_additive_change: What should happen when a forward-only model requires an additive schema change. + physical_properties: A key-value mapping of arbitrary properties that are applied to the model table / view in the physical layer. + virtual_properties: A key-value mapping of arbitrary properties that are applied to the model view in the virtual layer. + session_properties: A key-value mapping of properties specific to the target engine that are applied to the engine session. + audits: The audits to be applied globally to all models in the project. + optimize_query: Whether the SQL models should be optimized. + allow_partials: Whether the models can process partial (incomplete) data intervals. + enabled: Whether the models are enabled. + interval_unit: The temporal granularity of the models data intervals. By default computed from cron. + batch_concurrency: The maximum number of batches that can run concurrently for an incremental model. + pre_statements: The list of SQL statements that get executed before a model runs. 
+ post_statements: The list of SQL statements that get executed before a model runs. + on_virtual_update: The list of SQL statements to be executed after the virtual update. + """ kind: t.Optional[ModelKind] = None @@ -34,9 +57,31 @@ class ModelDefaultsConfig(BaseConfig): cron: t.Optional[str] = None owner: t.Optional[str] = None start: t.Optional[TimeLike] = None + table_format: t.Optional[str] = None storage_format: t.Optional[str] = None on_destructive_change: t.Optional[OnDestructiveChange] = None + on_additive_change: t.Optional[OnAdditiveChange] = None + physical_properties: t.Optional[t.Dict[str, t.Any]] = None + virtual_properties: t.Optional[t.Dict[str, t.Any]] = None session_properties: t.Optional[t.Dict[str, t.Any]] = None + audits: t.Optional[t.List[FunctionCall]] = None + optimize_query: t.Optional[t.Union[str, bool]] = None + allow_partials: t.Optional[t.Union[str, bool]] = None + interval_unit: t.Optional[t.Union[str, IntervalUnit]] = None + enabled: t.Optional[t.Union[str, bool]] = None + formatting: t.Optional[t.Union[str, bool]] = None + batch_concurrency: t.Optional[int] = None + pre_statements: t.Optional[t.List[t.Union[str, exp.Expression]]] = None + post_statements: t.Optional[t.List[t.Union[str, exp.Expression]]] = None + on_virtual_update: t.Optional[t.List[t.Union[str, exp.Expression]]] = None _model_kind_validator = model_kind_validator _on_destructive_change_validator = on_destructive_change_validator + _on_additive_change_validator = on_additive_change_validator + + @field_validator("audits", mode="before") + def _audits_validator(cls, v: t.Any) -> t.Any: + if isinstance(v, list): + return [extract_func_call(parse_one(audit)) for audit in v] + + return v diff --git a/sqlmesh/core/config/plan.py b/sqlmesh/core/config/plan.py index faf7a28aef..df1ca44873 100644 --- a/sqlmesh/core/config/plan.py +++ b/sqlmesh/core/config/plan.py @@ -1,5 +1,7 @@ from __future__ import annotations +import typing as t + from sqlmesh.core.config.base import 
BaseConfig from sqlmesh.core.config.categorizer import CategorizerConfig @@ -14,17 +16,19 @@ class PlanConfig(BaseConfig): include_unmodified: Whether to include unmodified models in the target development environment. enable_preview: Whether to enable preview for forward-only models in development environments. no_diff: Hide text differences for changed models. - no_prompts: Whether to disable interactive prompts for the backfill time range. Please note that + no_prompts: Whether to disable interactive prompts for the backfill time range. auto_apply: Whether to automatically apply the new plan after creation. use_finalized_state: Whether to compare against the latest finalized environment state, or to use whatever state the target environment is currently in. + always_recreate_environment: Whether to always recreate the target environment from the `create_from` environment. """ forward_only: bool = False auto_categorize_changes: CategorizerConfig = CategorizerConfig() include_unmodified: bool = False - enable_preview: bool = False + enable_preview: t.Optional[bool] = None no_diff: bool = False - no_prompts: bool = False + no_prompts: bool = True auto_apply: bool = False use_finalized_state: bool = False + always_recreate_environment: bool = False diff --git a/sqlmesh/core/config/root.py b/sqlmesh/core/config/root.py index a2e18e8461..211d271b01 100644 --- a/sqlmesh/core/config/root.py +++ b/sqlmesh/core/config/root.py @@ -6,40 +6,86 @@ import zlib from pydantic import Field +from pydantic.functional_validators import BeforeValidator from sqlglot import exp from sqlglot.helper import first from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlmesh.cicd.config import CICDBotConfig from sqlmesh.core import constants as c -from sqlmesh.core.config import EnvironmentSuffixTarget +from sqlmesh.core.console import get_console +from sqlmesh.core.config.common import ( + EnvironmentSuffixTarget, + TableNamingConvention, + VirtualEnvironmentMode, 
+) from sqlmesh.core.config.base import BaseConfig, UpdateStrategy -from sqlmesh.core.config.common import variables_validator +from sqlmesh.core.config.common import variables_validator, compile_regex_mapping from sqlmesh.core.config.connection import ( ConnectionConfig, DuckDBConnectionConfig, SerializableConnectionConfig, connection_config_validator, ) -from sqlmesh.core.config.feature_flag import FeatureFlag from sqlmesh.core.config.format import FormatConfig from sqlmesh.core.config.gateway import GatewayConfig +from sqlmesh.core.config.janitor import JanitorConfig from sqlmesh.core.config.migration import MigrationConfig from sqlmesh.core.config.model import ModelDefaultsConfig from sqlmesh.core.config.naming import NameInferenceConfig as NameInferenceConfig +from sqlmesh.core.config.linter import LinterConfig as LinterConfig from sqlmesh.core.config.plan import PlanConfig from sqlmesh.core.config.run import RunConfig -from sqlmesh.core.config.scheduler import BuiltInSchedulerConfig, SchedulerConfig +from sqlmesh.core.config.dbt import DbtConfig +from sqlmesh.core.config.scheduler import ( + BuiltInSchedulerConfig, + SchedulerConfig, + scheduler_config_validator, +) from sqlmesh.core.config.ui import UIConfig from sqlmesh.core.loader import Loader, SqlMeshLoader from sqlmesh.core.notification_target import NotificationTarget from sqlmesh.core.user import User +from sqlmesh.utils.date import to_timestamp, now from sqlmesh.utils.errors import ConfigError -from sqlmesh.utils.pydantic import ( - field_validator, - model_validator, - model_validator_v1_args, -) +from sqlmesh.utils.pydantic import model_validator + + +def validate_no_past_ttl(v: str) -> str: + current_time = now() + if to_timestamp(v, relative_base=current_time) < to_timestamp(current_time): + raise ValueError( + f"TTL '{v}' is in the past. Please specify a relative time in the future. Ex: `in 1 week` instead of `1 week`." 
+ ) + return v + + +def gateways_ensure_dict(value: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: + try: + if not isinstance(value, GatewayConfig): + GatewayConfig.parse_obj(value) + return {"": value} + except Exception: + # Normalize all gateway keys to lowercase for case-insensitive matching + if isinstance(value, dict): + return {k.lower(): v for k, v in value.items()} + return value + + +def validate_regex_key_dict(value: t.Dict[str | re.Pattern, t.Any]) -> t.Dict[re.Pattern, t.Any]: + return compile_regex_mapping(value) + + +if t.TYPE_CHECKING: + from sqlmesh.core._typing import Self + + NoPastTTLString = str + GatewayDict = t.Dict[str, GatewayConfig] + RegexKeyDict = t.Dict[re.Pattern, str] +else: + NoPastTTLString = t.Annotated[str, BeforeValidator(validate_no_past_ttl)] + GatewayDict = t.Annotated[t.Dict[str, GatewayConfig], BeforeValidator(gateways_ensure_dict)] + RegexKeyDict = t.Annotated[t.Dict[re.Pattern, str], BeforeValidator(validate_regex_key_dict)] class Config(BaseConfig): @@ -65,22 +111,28 @@ class Config(BaseConfig): loader_kwargs: Key-value arguments to pass to the loader instance. env_vars: A dictionary of environmental variable names and values. model_defaults: Default values for model definitions. - physical_schema_override: A mapping from model schema names to names of schemas in which physical tables for corresponding models will be placed. + physical_schema_mapping: A mapping from regular expressions to names of schemas in which physical tables for corresponding models will be placed. environment_suffix_target: Indicates whether to append the environment name to the schema or table name. + physical_table_naming_convention: Indicates how tables should be named at the physical layer + virtual_environment_mode: Indicates how environments should be handled. + gateway_managed_virtual_layer: Whether the models' views in the virtual layer are created by the model-specific gateway rather than the default gateway. 
+ infer_python_dependencies: Whether to statically analyze Python code to automatically infer Python package requirements. environment_catalog_mapping: A mapping from regular expressions to catalog names. The catalog name is used to determine the target catalog for a given environment. default_target_environment: The name of the environment that will be the default target for the `sqlmesh plan` and `sqlmesh run` commands. log_limit: The default number of logs to keep. format: The formatting options for SQL code. ui: The UI configuration for SQLMesh. - feature_flags: Feature flags to enable/disable certain features. plan: The plan configuration. migration: The migration configuration. variables: A dictionary of variables that can be used in models / macros. disable_anonymized_analytics: Whether to disable the anonymized analytics collection. + before_all: SQL statements or macros to be executed at the start of the `sqlmesh plan` and `sqlmesh run` commands. + after_all: SQL statements or macros to be executed at the end of the `sqlmesh plan` and `sqlmesh run` commands. + cache_dir: The directory to store the SQLMesh cache. Defaults to .cache in the project folder. 
""" - gateways: t.Dict[str, GatewayConfig] = {"": GatewayConfig()} - default_connection: SerializableConnectionConfig = DuckDBConnectionConfig() + gateways: GatewayDict = {"": GatewayConfig()} + default_connection: t.Optional[SerializableConnectionConfig] = None default_test_connection_: t.Optional[SerializableConnectionConfig] = Field( default=None, alias="default_test_connection" ) @@ -88,8 +140,8 @@ class Config(BaseConfig): default_gateway: str = "" notification_targets: t.List[NotificationTarget] = [] project: str = "" - snapshot_ttl: str = c.DEFAULT_SNAPSHOT_TTL - environment_ttl: t.Optional[str] = c.DEFAULT_ENVIRONMENT_TTL + snapshot_ttl: NoPastTTLString = c.DEFAULT_SNAPSHOT_TTL + environment_ttl: t.Optional[NoPastTTLString] = c.DEFAULT_ENVIRONMENT_TTL ignore_patterns: t.List[str] = c.IGNORE_PATTERNS time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT users: t.List[User] = [] @@ -99,26 +151,33 @@ class Config(BaseConfig): loader_kwargs: t.Dict[str, t.Any] = {} env_vars: t.Dict[str, str] = {} username: str = "" - physical_schema_override: t.Dict[str, str] = {} - environment_suffix_target: EnvironmentSuffixTarget = Field( - default=EnvironmentSuffixTarget.default - ) - environment_catalog_mapping: t.Dict[re.Pattern, str] = {} + physical_schema_mapping: RegexKeyDict = {} + environment_suffix_target: EnvironmentSuffixTarget = EnvironmentSuffixTarget.default + physical_table_naming_convention: TableNamingConvention = TableNamingConvention.default + virtual_environment_mode: VirtualEnvironmentMode = VirtualEnvironmentMode.default + gateway_managed_virtual_layer: bool = False + infer_python_dependencies: bool = True + environment_catalog_mapping: RegexKeyDict = {} default_target_environment: str = c.PROD log_limit: int = c.DEFAULT_LOG_LIMIT cicd_bot: t.Optional[CICDBotConfig] = None run: RunConfig = RunConfig() format: FormatConfig = FormatConfig() ui: UIConfig = UIConfig() - feature_flags: FeatureFlag = FeatureFlag() plan: PlanConfig = PlanConfig() migration: 
MigrationConfig = MigrationConfig() model_naming: NameInferenceConfig = NameInferenceConfig() variables: t.Dict[str, t.Any] = {} disable_anonymized_analytics: bool = False + before_all: t.Optional[t.List[str]] = None + after_all: t.Optional[t.List[str]] = None + linter: LinterConfig = LinterConfig() + janitor: JanitorConfig = JanitorConfig() + cache_dir: t.Optional[str] = None + dbt: t.Optional[DbtConfig] = None _FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = { - "gateways": UpdateStrategy.KEY_UPDATE, + "gateways": UpdateStrategy.NESTED_UPDATE, "notification_targets": UpdateStrategy.EXTEND, "ignore_patterns": UpdateStrategy.EXTEND, "users": UpdateStrategy.EXTEND, @@ -131,57 +190,94 @@ class Config(BaseConfig): "ui": UpdateStrategy.NESTED_UPDATE, "loader_kwargs": UpdateStrategy.KEY_UPDATE, "plan": UpdateStrategy.NESTED_UPDATE, + "before_all": UpdateStrategy.EXTEND, + "after_all": UpdateStrategy.EXTEND, + "linter": UpdateStrategy.NESTED_UPDATE, + "dbt": UpdateStrategy.NESTED_UPDATE, } _connection_config_validator = connection_config_validator + _scheduler_config_validator = scheduler_config_validator # type: ignore _variables_validator = variables_validator - @field_validator("gateways", mode="before", always=True) - @classmethod - def _gateways_ensure_dict(cls, value: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - try: - if not isinstance(value, GatewayConfig): - GatewayConfig.parse_obj(value) - return {"": value} - except Exception: - return value - - @field_validator("environment_catalog_mapping", mode="before") - @classmethod - def _validate_regex_keys( - cls, value: t.Dict[str | re.Pattern, t.Any] - ) -> t.Dict[re.Pattern, t.Any]: - compiled_regexes = {} - for k, v in value.items(): - try: - compiled_regexes[re.compile(k)] = v - except re.error: - raise ConfigError(f"`{k}` is not a valid regular expression.") - return compiled_regexes - @model_validator(mode="before") - @model_validator_v1_args - def _normalize_and_validate_fields(cls, values: 
t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - if "gateways" not in values and "gateway" in values: - values["gateways"] = values.pop("gateway") + def _normalize_and_validate_fields(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + if "gateways" not in data and "gateway" in data: + data["gateways"] = data.pop("gateway") for plan_deprecated in ("auto_categorize_changes", "include_unmodified"): - if plan_deprecated in values: + if plan_deprecated in data: raise ConfigError( f"The `{plan_deprecated}` config is deprecated. Please use the `plan.{plan_deprecated}` config instead." ) - return values + if "physical_schema_override" in data: + get_console().log_warning( + "`physical_schema_override` is deprecated. Please use `physical_schema_mapping` instead." + ) + + if "physical_schema_mapping" in data: + raise ConfigError( + "Only one of `physical_schema_override` and `physical_schema_mapping` can be specified." + ) + + physical_schema_override: t.Dict[str, str] = data.pop("physical_schema_override") + # translate physical_schema_override to physical_schema_mapping + data["physical_schema_mapping"] = { + f"^{k}$": v for k, v in physical_schema_override.items() + } + + return data + + @model_validator(mode="after") + def _normalize_fields_after(self) -> Self: + dialect = self.model_defaults.dialect + + def _normalize_identifiers(key: str) -> None: + setattr( + self, + key, + { + k: normalize_identifiers(v, dialect=dialect).name + for k, v in getattr(self, key, {}).items() + }, + ) + + if ( + self.environment_suffix_target == EnvironmentSuffixTarget.CATALOG + and self.environment_catalog_mapping + ): + raise ConfigError( + f"'environment_suffix_target: catalog' is mutually exclusive with 'environment_catalog_mapping'.\n" + "Please specify one or the other" + ) + + if self.plan.use_finalized_state and not self.virtual_environment_mode.is_full: + raise ConfigError( + "Using the finalized state is only supported when `virtual_environment_mode` is 
set to `full`." + ) + + if self.environment_catalog_mapping: + _normalize_identifiers("environment_catalog_mapping") + if self.physical_schema_mapping: + _normalize_identifiers("physical_schema_mapping") + + return self @model_validator(mode="after") - @model_validator_v1_args - def _normalize_fields_after(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - dialect = values["model_defaults"].dialect - values["environment_catalog_mapping"] = { - k: normalize_identifiers(v, dialect=dialect).name - for k, v in values.get("environment_catalog_mapping", {}).items() - } - return values + def _inherit_project_config_in_cicd_bot(self) -> Self: + if self.cicd_bot: + # inherit the project-level settings into the CICD bot if they have not been explicitly overridden + if self.cicd_bot.auto_categorize_changes_ is None: + self.cicd_bot.auto_categorize_changes_ = self.plan.auto_categorize_changes + + if self.cicd_bot.pr_include_unmodified_ is None: + self.cicd_bot.pr_include_unmodified_ = self.plan.include_unmodified + + return self def get_default_test_connection( self, @@ -205,28 +301,33 @@ def get_gateway(self, name: t.Optional[str] = None) -> GatewayConfig: if isinstance(self.gateways, dict): if name is None: if self.default_gateway: - if self.default_gateway not in self.gateways: + # Normalize default_gateway name to lowercase for lookup + default_key = self.default_gateway.lower() + if default_key not in self.gateways: raise ConfigError(f"Missing gateway with name '{self.default_gateway}'") - return self.gateways[self.default_gateway] + return self.gateways[default_key] if "" in self.gateways: return self.gateways[""] return first(self.gateways.values()) - if name not in self.gateways: + # Normalize lookup name to lowercase since gateway keys are already lowercase + lookup_key = name.lower() + if lookup_key not in self.gateways: raise ConfigError(f"Missing gateway with name '{name}'.") - return self.gateways[name] - else: - if name is not None: - raise ConfigError( - 
"Gateway name is not supported when only one gateway is configured." - ) - return self.gateways + return self.gateways[lookup_key] + if name is not None: + raise ConfigError("Gateway name is not supported when only one gateway is configured.") + return self.gateways def get_connection(self, gateway_name: t.Optional[str] = None) -> ConnectionConfig: - return self.get_gateway(gateway_name).connection or self.default_connection + connection = self.get_gateway(gateway_name).connection or self.default_connection + if connection is None: + msg = f" for gateway '{gateway_name}'" if gateway_name else "" + raise ConfigError(f"No connection configured{msg}.") + return connection def get_state_connection( self, gateway_name: t.Optional[str] = None diff --git a/sqlmesh/core/config/scheduler.py b/sqlmesh/core/config/scheduler.py index b8bf1c1f1a..970defee62 100644 --- a/sqlmesh/core/config/scheduler.py +++ b/sqlmesh/core/config/scheduler.py @@ -1,44 +1,30 @@ from __future__ import annotations import abc -import logging -import sys import typing as t -from pydantic import Field -from requests import Session +from pydantic import Field, ValidationError +from sqlglot.helper import subclasses from sqlmesh.core.config.base import BaseConfig -from sqlmesh.core.config.common import concurrent_tasks_validator -from sqlmesh.core.console import Console +from sqlmesh.core.console import get_console from sqlmesh.core.plan import ( - AirflowPlanEvaluator, BuiltInPlanEvaluator, - MWAAPlanEvaluator, PlanEvaluator, ) +from sqlmesh.core.config import DuckDBConnectionConfig from sqlmesh.core.state_sync import EngineAdapterStateSync, StateSync -from sqlmesh.schedulers.airflow.client import AirflowClient -from sqlmesh.schedulers.airflow.mwaa_client import MWAAClient from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.hashing import md5 -from sqlmesh.utils.pydantic import model_validator, model_validator_v1_args +from sqlmesh.utils.pydantic import field_validator, 
validation_error_message if t.TYPE_CHECKING: - from google.auth.transport.requests import AuthorizedSession - from sqlmesh.core.context import GenericContext -if sys.version_info >= (3, 9): - from typing import Annotated, Literal -else: - from typing_extensions import Annotated, Literal - +from sqlmesh.utils.config import sensitive_fields, excluded_fields -logger = logging.getLogger(__name__) - -class _SchedulerConfig(abc.ABC): +class SchedulerConfig(abc.ABC): """Abstract base class for Scheduler configurations.""" @abc.abstractmethod @@ -61,8 +47,8 @@ def create_state_sync(self, context: GenericContext) -> StateSync: """ @abc.abstractmethod - def get_default_catalog(self, context: GenericContext) -> t.Optional[str]: - """Returns the default catalog for the Scheduler. + def get_default_catalog_per_gateway(self, context: GenericContext) -> t.Dict[str, str]: + """Returns the default catalog for each gateway. Args: context: The SQLMesh Context. @@ -77,47 +63,60 @@ def state_sync_fingerprint(self, context: GenericContext) -> str: """ -class _EngineAdapterStateSyncSchedulerConfig(_SchedulerConfig): +class _EngineAdapterStateSyncSchedulerConfig(SchedulerConfig): def create_state_sync(self, context: GenericContext) -> StateSync: state_connection = ( - context.config.get_state_connection(context.gateway) or context._connection_config + context.config.get_state_connection(context.gateway) or context.connection_config ) + + warehouse_connection = context.config.get_connection(context.gateway) + + if ( + isinstance(state_connection, DuckDBConnectionConfig) + and state_connection.concurrent_tasks <= 1 + ): + # If we are using DuckDB, ensure that multithreaded mode gets enabled if necessary + if warehouse_connection.concurrent_tasks > 1: + get_console().log_warning( + "The duckdb state connection is configured for single threaded mode but the warehouse connection is configured for " + + f"multi threaded mode with {warehouse_connection.concurrent_tasks} concurrent tasks." 
+ + " This can cause SQLMesh to hang. Overriding the duckdb state connection config to use multi threaded mode." + ) + # this triggers multithreaded mode and has to happen before the engine adapter is created below + state_connection.concurrent_tasks = warehouse_connection.concurrent_tasks + engine_adapter = state_connection.create_engine_adapter() - if not engine_adapter.SUPPORTS_ROW_LEVEL_OP: + if state_connection.is_forbidden_for_state_sync: raise ConfigError( f"The {engine_adapter.DIALECT.upper()} engine cannot be used to store SQLMesh state - please specify a different `state_connection` engine." + " See https://sqlmesh.readthedocs.io/en/stable/reference/configuration/#gateways for more information." ) + + # If the user is using DuckDB for both the state and the warehouse connection, they are most likely running an example project + # or POC. To reduce friction, we wont log a warning about DuckDB being used for state until they change to a proper warehouse + if not isinstance(state_connection, DuckDBConnectionConfig) or not isinstance( + warehouse_connection, DuckDBConnectionConfig + ): + if not state_connection.is_recommended_for_state_sync: + get_console().log_warning( + f"The {state_connection.type_} engine is not recommended for storing SQLMesh state in production deployments. Please see" + + " https://sqlmesh.readthedocs.io/en/stable/guides/configuration/#state-connection for a list of recommended engines and more information." 
+ ) + schema = context.config.get_state_schema(context.gateway) return EngineAdapterStateSync( - engine_adapter, schema=schema, context_path=context.path, console=context.console + engine_adapter, schema=schema, cache_dir=context.cache_dir, console=context.console ) def state_sync_fingerprint(self, context: GenericContext) -> str: state_connection = ( - context.config.get_state_connection(context.gateway) or context._connection_config + context.config.get_state_connection(context.gateway) or context.connection_config ) return md5( [ state_connection.json( sort_keys=True, - exclude={ - "access_token", - "concurrent_tasks", - "user", - "password", - "keytab", - "keyfile", - "keyfile_json", - "pre_ping", - "principal", - "private_key", - "private_key_passphrase", - "private_key_path", - "refresh_token", - "register_comments", - "token", - }, + exclude=sensitive_fields.union(excluded_fields), ) ] ) @@ -126,257 +125,56 @@ def state_sync_fingerprint(self, context: GenericContext) -> str: class BuiltInSchedulerConfig(_EngineAdapterStateSyncSchedulerConfig, BaseConfig): """The Built-In Scheduler configuration.""" - type_: Literal["builtin"] = Field(alias="type", default="builtin") + type_: t.Literal["builtin"] = Field(alias="type", default="builtin") def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator: return BuiltInPlanEvaluator( state_sync=context.state_sync, snapshot_evaluator=context.snapshot_evaluator, - default_catalog=self.get_default_catalog(context), - backfill_concurrent_tasks=context.concurrent_tasks, + create_scheduler=context.create_scheduler, + default_catalog=context.default_catalog, console=context.console, - notification_target_manager=context.notification_target_manager, - signal_factory=context._signal_factory, ) - def get_default_catalog(self, context: GenericContext) -> t.Optional[str]: - return context.engine_adapter.default_catalog + def get_default_catalog_per_gateway(self, context: GenericContext) -> t.Dict[str, str]: + 
default_catalogs_per_gateway: t.Dict[str, str] = {} + for gateway, adapter in context.engine_adapters.items(): + if catalog := adapter.default_catalog: + default_catalogs_per_gateway[gateway] = catalog + return default_catalogs_per_gateway -class _BaseAirflowSchedulerConfig(_EngineAdapterStateSyncSchedulerConfig): - airflow_url: str - dag_run_poll_interval_secs: int - dag_creation_poll_interval_secs: int - dag_creation_max_retry_attempts: int +SCHEDULER_CONFIG_TO_TYPE = { + tpe.all_field_infos()["type_"].default: tpe + for tpe in subclasses(__name__, BaseConfig, exclude={BaseConfig}) +} - backfill_concurrent_tasks: int - ddl_concurrent_tasks: int - use_state_connection: bool +def _scheduler_config_validator( + cls: t.Type, v: SchedulerConfig | t.Dict[str, t.Any] | None +) -> SchedulerConfig | None: + if v is None or isinstance(v, SchedulerConfig): + return v - default_catalog_override: t.Optional[str] + if "type" not in v: + raise ConfigError("Missing scheduler type.") - @abc.abstractmethod - def get_client(self, console: t.Optional[Console] = None) -> AirflowClient: - """Constructs the Airflow Client instance.""" + scheduler_type = v["type"] + if scheduler_type not in SCHEDULER_CONFIG_TO_TYPE: + raise ConfigError(f"Unknown scheduler type '{scheduler_type}'.") - def create_state_sync(self, context: GenericContext) -> StateSync: - if self.use_state_connection: - return super().create_state_sync(context) - - from sqlmesh.schedulers.airflow.state_sync import HttpStateSync - - return HttpStateSync( - client=self.get_client(context.console), - dag_run_poll_interval_secs=self.dag_run_poll_interval_secs, - console=context.console, + try: + return SCHEDULER_CONFIG_TO_TYPE[scheduler_type](**v) + except ValidationError as e: + raise ConfigError( + validation_error_message(e, f"Invalid '{scheduler_type}' scheduler config:") + + "\n\nVerify your config.yaml and environment variables." 
) - def state_sync_fingerprint(self, context: GenericContext) -> str: - if self.use_state_connection: - return super().state_sync_fingerprint(context) - return md5([self.airflow_url]) - - def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator: - return AirflowPlanEvaluator( - airflow_client=self.get_client(context.console), - dag_run_poll_interval_secs=self.dag_run_poll_interval_secs, - dag_creation_poll_interval_secs=self.dag_creation_poll_interval_secs, - dag_creation_max_retry_attempts=self.dag_creation_max_retry_attempts, - console=context.console, - notification_targets=context.notification_targets, - backfill_concurrent_tasks=self.backfill_concurrent_tasks, - ddl_concurrent_tasks=self.ddl_concurrent_tasks, - users=context.users, - state_sync=context.state_sync if self.use_state_connection else None, - ) - - def get_default_catalog(self, context: GenericContext) -> t.Optional[str]: - default_catalog = self.get_client(context.console).default_catalog - return self.default_catalog_override or default_catalog - - -class AirflowSchedulerConfig(_BaseAirflowSchedulerConfig, BaseConfig): - """The Airflow Scheduler configuration. - - Args: - airflow_url: The URL of the Airflow Webserver. - username: The Airflow username. - password: The Airflow password. - dag_run_poll_interval_secs: Determines how often a running DAG can be polled (in seconds). - dag_creation_poll_interval_secs: Determines how often SQLMesh should check whether a DAG has been created (in seconds). - dag_creation_max_retry_attempts: Determines the maximum number of attempts that SQLMesh will make while checking for - whether a DAG has been created. - backfill_concurrent_tasks: The number of concurrent tasks used for model backfilling during plan application. - ddl_concurrent_tasks: The number of concurrent tasks used for DDL operations (table / view creation, deletion, etc). 
- max_snapshot_ids_per_request: The maximum number of snapshot IDs that can be sent in a single HTTP GET request to the Airflow Webserver. - use_state_connection: Whether to use the `state_connection` configuration to access the SQLMesh state. - default_catalog_override: Overrides the default catalog value for this project. If specified, this value takes precedence - over the default catalog value set on the Airflow side. - """ - - airflow_url: str = "http://localhost:8080/" - username: str = "airflow" - password: str = "airflow" - token: t.Optional[str] = None - dag_run_poll_interval_secs: int = 10 - dag_creation_poll_interval_secs: int = 30 - dag_creation_max_retry_attempts: int = 10 - - backfill_concurrent_tasks: int = 4 - ddl_concurrent_tasks: int = 4 - - max_snapshot_ids_per_request: t.Optional[int] = None - use_state_connection: bool = False - - default_catalog_override: t.Optional[str] = None - - type_: Literal["airflow"] = Field(alias="type", default="airflow") - - _concurrent_tasks_validator = concurrent_tasks_validator - - def get_client(self, console: t.Optional[Console] = None) -> AirflowClient: - session = Session() - if self.token is None: - session.auth = (self.username, self.password) - else: - session.headers.update({"Authorization": f"Bearer {self.token}"}) - - return AirflowClient( - session=session, - airflow_url=self.airflow_url, - console=console, - snapshot_ids_batch_size=self.max_snapshot_ids_per_request, - ) - - -class CloudComposerSchedulerConfig(_BaseAirflowSchedulerConfig, BaseConfig, extra="allow"): - """The Google Cloud Composer configuration. - - Args: - airflow_url: The URL of the Airflow Webserver. - dag_run_poll_interval_secs: Determines how often a running DAG can be polled (in seconds). - dag_creation_poll_interval_secs: Determines how often SQLMesh should check whether a DAG has been created (in seconds). 
- dag_creation_max_retry_attempts: Determines the maximum number of attempts that SQLMesh will make while checking for - whether a DAG has been created. - backfill_concurrent_tasks: The number of concurrent tasks used for model backfilling during plan application. - ddl_concurrent_tasks: The number of concurrent tasks used for DDL operations (table / view creation, deletion, etc). - max_snapshot_ids_per_request: The maximum number of snapshot IDs that can be sent in a single HTTP GET request to the Airflow Webserver. - use_state_connection: Whether to use the `state_connection` configuration to access the SQLMesh state. - default_catalog_override: Overrides the default catalog value for this project. If specified, this value takes precedence - over the default catalog value set on the Airflow side. - """ - - airflow_url: str - dag_run_poll_interval_secs: int = 10 - dag_creation_poll_interval_secs: int = 30 - dag_creation_max_retry_attempts: int = 10 - - backfill_concurrent_tasks: int = 4 - ddl_concurrent_tasks: int = 4 - - max_snapshot_ids_per_request: t.Optional[int] = 20 - use_state_connection: bool = False - - default_catalog_override: t.Optional[str] = None - - type_: Literal["cloud_composer"] = Field(alias="type", default="cloud_composer") - - _concurrent_tasks_validator = concurrent_tasks_validator - - def __init__(self, **data: t.Any) -> None: - super().__init__(**data) - self._session: t.Optional[AuthorizedSession] = data.get("session") - - @property - def session(self) -> AuthorizedSession: - import google.auth - from google.auth.transport.requests import AuthorizedSession - - if self._session is None: - self._session = AuthorizedSession( - google.auth.default(scopes=["https://www.googleapis.com/auth/cloud-platform"])[0] - ) - self._session.headers.update({"Content-Type": "application/json"}) - return self._session - - def get_client(self, console: t.Optional[Console] = None) -> AirflowClient: - return AirflowClient( - airflow_url=self.airflow_url, - 
session=self.session, - console=console, - snapshot_ids_batch_size=self.max_snapshot_ids_per_request, - ) - - @model_validator(mode="before") - @model_validator_v1_args - def check_supported_fields(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - allowed_field_names = {field.alias or name for name, field in cls.all_field_infos().items()} - allowed_field_names.add("session") - - for field_name in values: - if field_name not in allowed_field_names: - raise ValueError(f"Unsupported Field: {field_name}") - return values - - -class MWAASchedulerConfig(_EngineAdapterStateSyncSchedulerConfig, BaseConfig): - """The AWS MWAA Scheduler configuration. - - Args: - environment: The name of the MWAA environment. - dag_run_poll_interval_secs: Determines how often a running DAG can be polled (in seconds). - dag_creation_poll_interval_secs: Determines how often SQLMesh should check whether a DAG has been created (in seconds). - dag_creation_max_retry_attempts: Determines the maximum number of attempts that SQLMesh will make while checking for - whether a DAG has been created. - backfill_concurrent_tasks: The number of concurrent tasks used for model backfilling during plan application. - ddl_concurrent_tasks: The number of concurrent tasks used for DDL operations (table / view creation, deletion, etc). - default_catalog_override: Overrides the default catalog value for this project. If specified, this value takes precedence - over the default catalog value set on the Airflow side. 
- """ - - environment: str - dag_run_poll_interval_secs: int = 10 - dag_creation_poll_interval_secs: int = 30 - dag_creation_max_retry_attempts: int = 10 - - backfill_concurrent_tasks: int = 4 - ddl_concurrent_tasks: int = 4 - - default_catalog_override: t.Optional[str] = None - - type_: Literal["mwaa"] = Field(alias="type", default="mwaa") - - _concurrent_tasks_validator = concurrent_tasks_validator - - def get_client(self, console: t.Optional[Console] = None) -> MWAAClient: - return MWAAClient(self.environment, console=console) - - def create_plan_evaluator(self, context: GenericContext) -> PlanEvaluator: - return MWAAPlanEvaluator( - client=self.get_client(context.console), - state_sync=context.state_sync, - console=context.console, - dag_run_poll_interval_secs=self.dag_run_poll_interval_secs, - dag_creation_poll_interval_secs=self.dag_creation_poll_interval_secs, - dag_creation_max_retry_attempts=self.dag_creation_max_retry_attempts, - notification_targets=context.notification_targets, - backfill_concurrent_tasks=self.backfill_concurrent_tasks, - ddl_concurrent_tasks=self.ddl_concurrent_tasks, - users=context.users, - ) - - def get_default_catalog(self, context: GenericContext) -> t.Optional[str]: - default_catalog = self.get_client(context.console).default_catalog - return self.default_catalog_override or default_catalog - -SchedulerConfig = Annotated[ - t.Union[ - BuiltInSchedulerConfig, - AirflowSchedulerConfig, - CloudComposerSchedulerConfig, - MWAASchedulerConfig, - ], - Field(discriminator="type_"), -] +scheduler_config_validator = field_validator( + "scheduler", + "default_scheduler", + mode="before", + check_fields=False, +)(_scheduler_config_validator) diff --git a/sqlmesh/core/console.py b/sqlmesh/core/console.py index 59ed086177..8af837b08a 100644 --- a/sqlmesh/core/console.py +++ b/sqlmesh/core/console.py @@ -5,7 +5,11 @@ import typing as t import unittest import uuid - +import logging +import textwrap +from humanize import metric, naturalsize +from 
itertools import zip_longest +from pathlib import Path from hyperscript import h from rich.console import Console as RichConsole from rich.live import Live @@ -22,30 +26,51 @@ from rich.syntax import Syntax from rich.table import Table from rich.tree import Tree +from sqlglot import exp -from sqlmesh.core.environment import EnvironmentNamingInfo +from sqlmesh.core.schema_diff import TableAlterOperation +from sqlmesh.core.test.result import ModelTextTestResult +from sqlmesh.core.environment import EnvironmentNamingInfo, EnvironmentSummary +from sqlmesh.core.linter.rule import RuleViolation +from sqlmesh.core.model import Model from sqlmesh.core.snapshot import ( Snapshot, SnapshotChangeCategory, SnapshotId, SnapshotInfoLike, - start_date, ) +from sqlmesh.core.snapshot.definition import Interval, Intervals, SnapshotTableInfo +from sqlmesh.core.snapshot.execution_tracker import QueryExecutionStats from sqlmesh.core.test import ModelTest from sqlmesh.utils import rich as srich -from sqlmesh.utils.date import time_like_to_str, to_date, yesterday_ds +from sqlmesh.utils import Verbosity +from sqlmesh.utils.concurrency import NodeExecutionFailedError +from sqlmesh.utils.date import time_like_to_str, to_date, yesterday_ds, to_ds, make_inclusive +from sqlmesh.utils.errors import ( + PythonModelEvalError, + NodeAuditsErrors, + format_destructive_change_msg, + format_additive_change_msg, +) +from sqlmesh.utils.rich import strip_ansi_codes if t.TYPE_CHECKING: import ipywidgets as widgets + from sqlglot import exp from sqlglot.dialects.dialect import DialectType from sqlmesh.core.context_diff import ContextDiff - from sqlmesh.core.plan import Plan, PlanBuilder - from sqlmesh.core.table_diff import RowDiff, SchemaDiff + from sqlmesh.core.plan import Plan, EvaluatablePlan, PlanBuilder, SnapshotIntervals + from sqlmesh.core.table_diff import TableDiff, RowDiff, SchemaDiff + from sqlmesh.core.config.connection import ConnectionConfig + from sqlmesh.core.state_sync import Versions 
LayoutWidget = t.TypeVar("LayoutWidget", bound=t.Union[widgets.VBox, widgets.HBox]) +logger = logging.getLogger(__name__) + + SNAPSHOT_CHANGE_CATEGORY_STR = { None: "Unknown", SnapshotChangeCategory.BREAKING: "Breaking", @@ -56,15 +81,334 @@ SnapshotChangeCategory.METADATA: "Metadata", } +PROGRESS_BAR_WIDTH = 40 +LINE_WRAP_WIDTH = 100 + + +class LinterConsole(abc.ABC): + """Console for displaying linter violations""" + + @abc.abstractmethod + def show_linter_violations( + self, violations: t.List[RuleViolation], model: Model, is_error: bool = False + ) -> None: + """Prints all linter violations depending on their severity""" + + +class StateExporterConsole(abc.ABC): + """Console for describing a state export""" + + @abc.abstractmethod + def start_state_export( + self, + output_file: Path, + gateway: t.Optional[str] = None, + state_connection_config: t.Optional[ConnectionConfig] = None, + environment_names: t.Optional[t.List[str]] = None, + local_only: bool = False, + confirm: bool = True, + ) -> bool: + """State a state export""" + + @abc.abstractmethod + def update_state_export_progress( + self, + version_count: t.Optional[int] = None, + versions_complete: bool = False, + snapshot_count: t.Optional[int] = None, + snapshots_complete: bool = False, + environment_count: t.Optional[int] = None, + environments_complete: bool = False, + ) -> None: + """Update the state export progress""" + + @abc.abstractmethod + def stop_state_export(self, success: bool, output_file: Path) -> None: + """Finish a state export""" + + +class StateImporterConsole(abc.ABC): + """Console for describing a state import""" + + @abc.abstractmethod + def start_state_import( + self, + input_file: Path, + gateway: str, + state_connection_config: ConnectionConfig, + clear: bool = False, + confirm: bool = True, + ) -> bool: + """Start a state import""" + + @abc.abstractmethod + def update_state_import_progress( + self, + timestamp: t.Optional[str] = None, + state_file_version: t.Optional[int] = None, 
+ versions: t.Optional[Versions] = None, + snapshot_count: t.Optional[int] = None, + snapshots_complete: bool = False, + environment_count: t.Optional[int] = None, + environments_complete: bool = False, + ) -> None: + """Update the state import process""" + + @abc.abstractmethod + def stop_state_import(self, success: bool, input_file: Path) -> None: + """Finish a state import""" + + +class JanitorConsole(abc.ABC): + """Console for describing a janitor / snapshot cleanup run""" + + @abc.abstractmethod + def start_cleanup(self, ignore_ttl: bool) -> bool: + """Start a janitor / snapshot cleanup run. + + Args: + ignore_ttl: Indicates that the user wants to ignore the snapshot TTL and clean up everything not promoted to an environment + + Returns: + Whether or not the cleanup run should proceed + """ + + @abc.abstractmethod + def update_cleanup_progress(self, object_name: str) -> None: + """Update the snapshot cleanup progress.""" + + @abc.abstractmethod + def stop_cleanup(self, success: bool = True) -> None: + """Indicates the janitor / snapshot cleanup run has ended + + Args: + success: Whether or not the cleanup completed successfully + """ + + +class DestroyConsole(abc.ABC): + """Console for describing a destroy operation""" + + @abc.abstractmethod + def start_destroy( + self, + schemas_to_delete: t.Optional[t.Set[str]] = None, + views_to_delete: t.Optional[t.Set[str]] = None, + tables_to_delete: t.Optional[t.Set[str]] = None, + ) -> bool: + """Start a destroy operation. 
+ + Args: + schemas_to_delete: Set of schemas that will be deleted + views_to_delete: Set of views that will be deleted + tables_to_delete: Set of tables that will be deleted + + Returns: + Whether or not the destroy operation should proceed + """ + + @abc.abstractmethod + def stop_destroy(self, success: bool = True) -> None: + """Indicates the destroy operation has ended -class Console(abc.ABC): + Args: + success: Whether or not the cleanup completed successfully + """ + + +class EnvironmentsConsole(abc.ABC): + """Console for displaying environments""" + + @abc.abstractmethod + def print_environments(self, environments_summary: t.List[EnvironmentSummary]) -> None: + """Prints all environment names along with expiry datetime.""" + + @abc.abstractmethod + def show_intervals(self, snapshot_intervals: t.Dict[Snapshot, SnapshotIntervals]) -> None: + """Show ready intervals""" + + +class DifferenceConsole(abc.ABC): + """Console for displaying environment differences""" + + @abc.abstractmethod + def show_environment_difference_summary( + self, + context_diff: ContextDiff, + no_diff: bool = True, + ) -> None: + """Displays a summary of differences for the environment.""" + + @abc.abstractmethod + def show_model_difference_summary( + self, + context_diff: ContextDiff, + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + no_diff: bool = True, + ) -> None: + """Displays a summary of differences for the given models.""" + + +class TableDiffConsole(abc.ABC): + """Console for displaying table differences""" + + @abc.abstractmethod + def show_table_diff( + self, + table_diffs: t.List[TableDiff], + show_sample: bool = True, + skip_grain_check: bool = False, + temp_schema: t.Optional[str] = None, + ) -> None: + """Display the table diff between two or multiple tables.""" + + @abc.abstractmethod + def update_table_diff_progress(self, model: str) -> None: + """Update table diff progress bar""" + + @abc.abstractmethod + def 
start_table_diff_progress(self, models_to_diff: int) -> None: + """Start table diff progress bar""" + + @abc.abstractmethod + def start_table_diff_model_progress(self, model: str) -> None: + """Start table diff model progress""" + + @abc.abstractmethod + def stop_table_diff_progress(self, success: bool) -> None: + """Stop table diff progress bar""" + + @abc.abstractmethod + def show_table_diff_details( + self, + models_to_diff: t.List[str], + ) -> None: + """Display information about which tables are going to be diffed""" + + @abc.abstractmethod + def show_table_diff_summary(self, table_diff: TableDiff) -> None: + """Display information about the tables being diffed and how they are being joined""" + + @abc.abstractmethod + def show_schema_diff(self, schema_diff: SchemaDiff) -> None: + """Show table schema diff.""" + + @abc.abstractmethod + def show_row_diff( + self, row_diff: RowDiff, show_sample: bool = True, skip_grain_check: bool = False + ) -> None: + """Show table summary diff.""" + + +class BaseConsole(abc.ABC): + @abc.abstractmethod + def log_error(self, message: str, *args: t.Any, **kwargs: t.Any) -> None: + """Display error info to the user.""" + + @abc.abstractmethod + def log_warning( + self, + short_message: str, + long_message: t.Optional[str] = None, + *args: t.Any, + **kwargs: t.Any, + ) -> None: + """Display warning info to the user. + + Args: + short_message: The warning message to print to console. + long_message: The warning message to log to file. If not provided, `short_message` is used. 
+ """ + + @abc.abstractmethod + def log_success(self, message: str) -> None: + """Display a general successful message to the user.""" + + +class PlanBuilderConsole(BaseConsole, abc.ABC): + @abc.abstractmethod + def log_destructive_change( + self, + snapshot_name: str, + alter_operations: t.List[TableAlterOperation], + dialect: str, + error: bool = True, + ) -> None: + """Display a destructive change error or warning to the user.""" + + @abc.abstractmethod + def log_additive_change( + self, + snapshot_name: str, + alter_operations: t.List[TableAlterOperation], + dialect: str, + error: bool = True, + ) -> None: + """Display an additive change error or warning to the user.""" + + +class UnitTestConsole(abc.ABC): + @abc.abstractmethod + def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None: + """Display the test result and output. + + Args: + result: The unittest test result that contains metrics like num success, fails, ect. + target_dialect: The dialect that tests were run against. Assumes all tests run against the same dialect. 
+ """ + + +class SignalConsole(abc.ABC): + @abc.abstractmethod + def start_signal_progress( + self, + snapshot: Snapshot, + default_catalog: t.Optional[str], + environment_naming_info: EnvironmentNamingInfo, + ) -> None: + """Indicates that signal checking has begun for a snapshot.""" + + @abc.abstractmethod + def update_signal_progress( + self, + snapshot: Snapshot, + signal_name: str, + signal_idx: int, + total_signals: int, + ready_intervals: Intervals, + check_intervals: Intervals, + duration: float, + ) -> None: + """Updates the signal checking progress.""" + + @abc.abstractmethod + def stop_signal_progress(self) -> None: + """Indicates that signal checking has completed for a snapshot.""" + + +class Console( + SignalConsole, + PlanBuilderConsole, + LinterConsole, + StateExporterConsole, + StateImporterConsole, + JanitorConsole, + DestroyConsole, + EnvironmentsConsole, + DifferenceConsole, + TableDiffConsole, + BaseConsole, + UnitTestConsole, + abc.ABC, +): """Abstract base class for defining classes used for displaying information to the user and also interact with them when their input is needed.""" INDIRECTLY_MODIFIED_DISPLAY_THRESHOLD = 10 @abc.abstractmethod - def start_plan_evaluation(self, plan: Plan) -> None: + def start_plan_evaluation(self, plan: EvaluatablePlan) -> None: """Indicates that a new evaluation has begun.""" @abc.abstractmethod @@ -74,19 +418,31 @@ def stop_plan_evaluation(self) -> None: @abc.abstractmethod def start_evaluation_progress( self, - batches: t.Dict[Snapshot, int], + batched_intervals: t.Dict[Snapshot, Intervals], environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], + audit_only: bool = False, ) -> None: - """Indicates that a new snapshot evaluation progress has begun.""" + """Indicates that a new snapshot evaluation/auditing progress has begun.""" @abc.abstractmethod - def start_snapshot_evaluation_progress(self, snapshot: Snapshot) -> None: + def start_snapshot_evaluation_progress( + self, 
snapshot: Snapshot, audit_only: bool = False + ) -> None: """Starts the snapshot evaluation progress.""" @abc.abstractmethod def update_snapshot_evaluation_progress( - self, snapshot: Snapshot, batch_idx: int, duration_ms: t.Optional[int] + self, + snapshot: Snapshot, + interval: Interval, + batch_idx: int, + duration_ms: t.Optional[int], + num_audits_passed: int, + num_audits_failed: int, + audit_only: bool = False, + execution_stats: t.Optional[QueryExecutionStats] = None, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: """Updates the snapshot evaluation progress.""" @@ -97,7 +453,7 @@ def stop_evaluation_progress(self, success: bool = True) -> None: @abc.abstractmethod def start_creation_progress( self, - total_tasks: int, + snapshots: t.List[Snapshot], environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], ) -> None: @@ -111,33 +467,10 @@ def update_creation_progress(self, snapshot: SnapshotInfoLike) -> None: def stop_creation_progress(self, success: bool = True) -> None: """Stop the snapshot creation progress.""" - @abc.abstractmethod - def start_cleanup(self, ignore_ttl: bool) -> bool: - """Start a janitor / snapshot cleanup run. 
- - Args: - ignore_ttl: Indicates that the user wants to ignore the snapshot TTL and clean up everything not promoted to an environment - - Returns: - Whether or not the cleanup run should proceed - """ - - @abc.abstractmethod - def update_cleanup_progress(self, object_name: str) -> None: - """Update the snapshot cleanup progress.""" - - @abc.abstractmethod - def stop_cleanup(self, success: bool = True) -> None: - """Indicates the janitor / snapshot cleanup run has ended - - Args: - success: Whether or not the cleanup completed successfully - """ - @abc.abstractmethod def start_promotion_progress( self, - total_tasks: int, + snapshots: t.List[SnapshotTableInfo], environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], ) -> None: @@ -179,17 +512,6 @@ def update_env_migration_progress(self, num_tasks: int) -> None: def stop_env_migration_progress(self, success: bool = True) -> None: """Stop the environment migration progress.""" - @abc.abstractmethod - def show_model_difference_summary( - self, - context_diff: ContextDiff, - environment_naming_info: EnvironmentNamingInfo, - default_catalog: t.Optional[str], - no_diff: bool = True, - ignored_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None, - ) -> None: - """Displays a summary of differences for the given models.""" - @abc.abstractmethod def plan( self, @@ -214,67 +536,379 @@ def plan( """ @abc.abstractmethod - def log_test_results( - self, result: unittest.result.TestResult, output: str, target_dialect: str + def show_sql(self, sql: str) -> None: + """Display to the user SQL.""" + + @abc.abstractmethod + def log_status_update(self, message: str) -> None: + """Display general status update to the user.""" + + @abc.abstractmethod + def log_skipped_models(self, snapshot_names: t.Set[str]) -> None: + """Display list of models skipped during evaluation to the user.""" + + @abc.abstractmethod + def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: + """Display list of 
models that failed during evaluation to the user.""" + + @abc.abstractmethod + def log_models_updated_during_restatement( + self, + snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]], + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], ) -> None: - """Display the test result and output. + """Display a list of models where new versions got deployed to the specified :environment while we were restating data the old versions Args: - result: The unittest test result that contains metrics like num success, fails, ect. - output: The generated output from the unittest. - target_dialect: The dialect that tests were run against. Assumes all tests run against the same dialect. + snapshots: a list of (snapshot_we_restated, snapshot_it_got_replaced_with_during_restatement) tuples + environment: which environment got updated while we were restating models + environment_naming_info: how snapshots are named in that :environment (for display name purposes) + default_catalog: the configured default catalog (for display name purposes) """ - @abc.abstractmethod + @abc.abstractmethod + def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID: + """Starts loading and returns a unique ID that can be used to stop the loading. 
Optionally can display a message.""" + + @abc.abstractmethod + def loading_stop(self, id: uuid.UUID) -> None: + """Stop loading for the given id.""" + + +class NoopConsole(Console): + def start_plan_evaluation(self, plan: EvaluatablePlan) -> None: + pass + + def stop_plan_evaluation(self) -> None: + pass + + def start_evaluation_progress( + self, + batched_intervals: t.Dict[Snapshot, Intervals], + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + audit_only: bool = False, + ) -> None: + pass + + def start_snapshot_evaluation_progress( + self, snapshot: Snapshot, audit_only: bool = False + ) -> None: + pass + + def update_snapshot_evaluation_progress( + self, + snapshot: Snapshot, + interval: Interval, + batch_idx: int, + duration_ms: t.Optional[int], + num_audits_passed: int, + num_audits_failed: int, + audit_only: bool = False, + execution_stats: t.Optional[QueryExecutionStats] = None, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, + ) -> None: + pass + + def stop_evaluation_progress(self, success: bool = True) -> None: + pass + + def start_signal_progress( + self, + snapshot: Snapshot, + default_catalog: t.Optional[str], + environment_naming_info: EnvironmentNamingInfo, + ) -> None: + pass + + def update_signal_progress( + self, + snapshot: Snapshot, + signal_name: str, + signal_idx: int, + total_signals: int, + ready_intervals: Intervals, + check_intervals: Intervals, + duration: float, + ) -> None: + pass + + def stop_signal_progress(self) -> None: + pass + + def start_creation_progress( + self, + snapshots: t.List[Snapshot], + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + ) -> None: + pass + + def update_creation_progress(self, snapshot: SnapshotInfoLike) -> None: + pass + + def stop_creation_progress(self, success: bool = True) -> None: + pass + + def start_cleanup(self, ignore_ttl: bool) -> bool: + return True + + def update_cleanup_progress(self, object_name: 
str) -> None: + pass + + def stop_cleanup(self, success: bool = True) -> None: + pass + + def start_promotion_progress( + self, + snapshots: t.List[SnapshotTableInfo], + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + ) -> None: + pass + + def update_promotion_progress(self, snapshot: SnapshotInfoLike, promoted: bool) -> None: + pass + + def stop_promotion_progress(self, success: bool = True) -> None: + pass + + def start_snapshot_migration_progress(self, total_tasks: int) -> None: + pass + + def update_snapshot_migration_progress(self, num_tasks: int) -> None: + pass + + def log_migration_status(self, success: bool = True) -> None: + pass + + def stop_snapshot_migration_progress(self, success: bool = True) -> None: + pass + + def start_env_migration_progress(self, total_tasks: int) -> None: + pass + + def update_env_migration_progress(self, num_tasks: int) -> None: + pass + + def stop_env_migration_progress(self, success: bool = True) -> None: + pass + + def start_state_export( + self, + output_file: Path, + gateway: t.Optional[str] = None, + state_connection_config: t.Optional[ConnectionConfig] = None, + environment_names: t.Optional[t.List[str]] = None, + local_only: bool = False, + confirm: bool = True, + ) -> bool: + return confirm + + def update_state_export_progress( + self, + version_count: t.Optional[int] = None, + versions_complete: bool = False, + snapshot_count: t.Optional[int] = None, + snapshots_complete: bool = False, + environment_count: t.Optional[int] = None, + environments_complete: bool = False, + ) -> None: + pass + + def stop_state_export(self, success: bool, output_file: Path) -> None: + pass + + def start_state_import( + self, + input_file: Path, + gateway: str, + state_connection_config: ConnectionConfig, + clear: bool = False, + confirm: bool = True, + ) -> bool: + return confirm + + def update_state_import_progress( + self, + timestamp: t.Optional[str] = None, + state_file_version: t.Optional[int] = 
None, + versions: t.Optional[Versions] = None, + snapshot_count: t.Optional[int] = None, + snapshots_complete: bool = False, + environment_count: t.Optional[int] = None, + environments_complete: bool = False, + ) -> None: + pass + + def stop_state_import(self, success: bool, input_file: Path) -> None: + pass + + def show_environment_difference_summary( + self, + context_diff: ContextDiff, + no_diff: bool = True, + ) -> None: + pass + + def show_model_difference_summary( + self, + context_diff: ContextDiff, + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + no_diff: bool = True, + ) -> None: + pass + + def plan( + self, + plan_builder: PlanBuilder, + auto_apply: bool, + default_catalog: t.Optional[str], + no_diff: bool = False, + no_prompts: bool = False, + ) -> None: + if auto_apply: + plan_builder.apply() + + def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None: + pass + def show_sql(self, sql: str) -> None: - """Display to the user SQL.""" + pass - @abc.abstractmethod def log_status_update(self, message: str) -> None: - """Display general status update to the user.""" + pass + + def log_skipped_models(self, snapshot_names: t.Set[str]) -> None: + pass + + def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: + pass + + def log_models_updated_during_restatement( + self, + snapshots: t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]], + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + ) -> None: + pass + + def log_destructive_change( + self, + snapshot_name: str, + alter_operations: t.List[TableAlterOperation], + dialect: str, + error: bool = True, + ) -> None: + pass + + def log_additive_change( + self, + snapshot_name: str, + alter_operations: t.List[TableAlterOperation], + dialect: str, + error: bool = True, + ) -> None: + pass - @abc.abstractmethod def log_error(self, message: str) -> None: - """Display error info to the user.""" 
+ pass + + def log_warning(self, short_message: str, long_message: t.Optional[str] = None) -> None: + logger.warning(long_message or short_message) - @abc.abstractmethod def log_success(self, message: str) -> None: - """Display a general successful message to the user.""" + pass - @abc.abstractmethod def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID: - """Starts loading and returns a unique ID that can be used to stop the loading. Optionally can display a message.""" + return uuid.uuid4() - @abc.abstractmethod def loading_stop(self, id: uuid.UUID) -> None: - """Stop loading for the given id.""" + pass + + def show_table_diff( + self, + table_diffs: t.List[TableDiff], + show_sample: bool = True, + skip_grain_check: bool = False, + temp_schema: t.Optional[str] = None, + ) -> None: + for table_diff in table_diffs: + self.show_table_diff_summary(table_diff) + self.show_schema_diff(table_diff.schema_diff()) + self.show_row_diff( + table_diff.row_diff(temp_schema=temp_schema, skip_grain_check=skip_grain_check), + show_sample=show_sample, + skip_grain_check=skip_grain_check, + ) + + def update_table_diff_progress(self, model: str) -> None: + pass + + def start_table_diff_progress(self, models_to_diff: int) -> None: + pass + + def start_table_diff_model_progress(self, model: str) -> None: + pass + + def stop_table_diff_progress(self, success: bool) -> None: + pass + + def show_table_diff_details( + self, + models_to_diff: t.List[str], + ) -> None: + pass + + def show_table_diff_summary(self, table_diff: TableDiff) -> None: + pass - @abc.abstractmethod def show_schema_diff(self, schema_diff: SchemaDiff) -> None: - """Show table schema diff.""" + pass - @abc.abstractmethod def show_row_diff( self, row_diff: RowDiff, show_sample: bool = True, skip_grain_check: bool = False ) -> None: - """Show table summary diff.""" + pass - def _limit_model_names(self, tree: Tree, verbose: bool = False) -> Tree: - """Trim long indirectly modified model lists below 
threshold.""" - modified_length = len(tree.children) - if not verbose and modified_length > self.INDIRECTLY_MODIFIED_DISPLAY_THRESHOLD: - tree.children = [ - tree.children[0], - Tree(f".... {modified_length-2} more ...."), - tree.children[-1], - ] - return tree + def print_environments(self, environments_summary: t.List[EnvironmentSummary]) -> None: + pass + + def show_intervals(self, snapshot_intervals: t.Dict[Snapshot, SnapshotIntervals]) -> None: + pass + + def show_linter_violations( + self, violations: t.List[RuleViolation], model: Model, is_error: bool = False + ) -> None: + pass + + def print_connection_config( + self, config: ConnectionConfig, title: t.Optional[str] = "Connection" + ) -> None: + pass + + def start_destroy( + self, + schemas_to_delete: t.Optional[t.Set[str]] = None, + views_to_delete: t.Optional[t.Set[str]] = None, + tables_to_delete: t.Optional[t.Set[str]] = None, + ) -> bool: + return True + + def stop_destroy(self, success: bool = True) -> None: + pass -def make_progress_bar(message: str, console: t.Optional[RichConsole] = None) -> Progress: +def make_progress_bar( + message: str, + console: t.Optional[RichConsole] = None, + justify: t.Literal["default", "left", "center", "right", "full"] = "right", +) -> Progress: return Progress( - TextColumn(f"[bold blue]{message}", justify="right"), - BarColumn(bar_width=40), + TextColumn(f"[bold blue]{message}", justify=justify), + BarColumn(bar_width=PROGRESS_BAR_WIDTH), "[progress.percentage]{task.percentage:>3.1f}%", "•", srich.BatchColumn(), @@ -287,11 +921,20 @@ def make_progress_bar(message: str, console: t.Optional[RichConsole] = None) -> class TerminalConsole(Console): """A rich based implementation of the console.""" + TABLE_DIFF_SOURCE_BLUE = "#0248ff" + TABLE_DIFF_TARGET_GREEN = "green" + AUDIT_PASS_MARK = "\u2714" + GREEN_AUDIT_PASS_MARK = f"[green]{AUDIT_PASS_MARK}[/green]" + AUDIT_FAIL_MARK = "\u274c" + AUDIT_PADDING = 0 + CHECK_MARK = f"{AUDIT_PASS_MARK} " + def __init__( self, 
console: t.Optional[RichConsole] = None, - verbose: bool = False, + verbosity: Verbosity = Verbosity.DEFAULT, dialect: DialectType = None, + ignore_warnings: bool = False, **kwargs: t.Any, ) -> None: self.console: RichConsole = console or srich.console @@ -301,16 +944,19 @@ def __init__( self.evaluation_total_task: t.Optional[TaskID] = None self.evaluation_model_progress: t.Optional[Progress] = None self.evaluation_model_tasks: t.Dict[str, TaskID] = {} - self.evaluation_model_batches: t.Dict[Snapshot, int] = {} + self.evaluation_model_batch_sizes: t.Dict[Snapshot, int] = {} + self.evaluation_column_widths: t.Dict[str, int] = {} # Put in temporary values that are replaced when evaluating self.environment_naming_info = EnvironmentNamingInfo() self.default_catalog: t.Optional[str] = None self.creation_progress: t.Optional[Progress] = None + self.creation_column_widths: t.Dict[str, int] = {} self.creation_task: t.Optional[TaskID] = None self.promotion_progress: t.Optional[Progress] = None + self.promotion_column_widths: t.Dict[str, int] = {} self.promotion_task: t.Optional[TaskID] = None self.migration_progress: t.Optional[Progress] = None @@ -321,8 +967,41 @@ def __init__( self.loading_status: t.Dict[uuid.UUID, Status] = {} - self.verbose = verbose + self.state_export_progress: t.Optional[Progress] = None + self.state_export_version_task: t.Optional[TaskID] = None + self.state_export_snapshot_task: t.Optional[TaskID] = None + self.state_export_environment_task: t.Optional[TaskID] = None + + self.state_import_progress: t.Optional[Progress] = None + self.state_import_version_task: t.Optional[TaskID] = None + self.state_import_snapshot_task: t.Optional[TaskID] = None + self.state_import_environment_task: t.Optional[TaskID] = None + + self.table_diff_progress: t.Optional[Progress] = None + self.table_diff_model_progress: t.Optional[Progress] = None + self.table_diff_model_tasks: t.Dict[str, TaskID] = {} + self.table_diff_progress_live: t.Optional[Live] = None + + 
self.signal_progress_logged = False + self.signal_status_tree: t.Optional[Tree] = None + + self.verbosity = verbosity self.dialect = dialect + self.ignore_warnings = ignore_warnings + + def _limit_model_names(self, tree: Tree, verbosity: Verbosity = Verbosity.DEFAULT) -> Tree: + """Trim long indirectly modified model lists below threshold.""" + modified_length = len(tree.children) + if ( + verbosity < Verbosity.VERY_VERBOSE + and modified_length > self.INDIRECTLY_MODIFIED_DISPLAY_THRESHOLD + ): + tree.children = [ + tree.children[0], + Tree(f".... {modified_length - 2} more ...."), + tree.children[-1], + ] + return tree def _print(self, value: t.Any, **kwargs: t.Any) -> None: self.console.print(value, **kwargs) @@ -333,7 +1012,7 @@ def _prompt(self, message: str, **kwargs: t.Any) -> t.Any: def _confirm(self, message: str, **kwargs: t.Any) -> bool: return Confirm.ask(message, console=self.console, **kwargs) - def start_plan_evaluation(self, plan: Plan) -> None: + def start_plan_evaluation(self, plan: EvaluatablePlan) -> None: pass def stop_plan_evaluation(self) -> None: @@ -341,13 +1020,20 @@ def stop_plan_evaluation(self) -> None: def start_evaluation_progress( self, - batches: t.Dict[Snapshot, int], + batched_intervals: t.Dict[Snapshot, Intervals], environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], + audit_only: bool = False, ) -> None: - """Indicates that a new snapshot evaluation progress has begun.""" + """Indicates that a new snapshot evaluation/auditing progress has begun.""" + # Add a newline to separate signal checking from evaluation + if self.signal_progress_logged: + self._print("") + if not self.evaluation_progress_live: - self.evaluation_total_progress = make_progress_bar("Evaluating models", self.console) + self.evaluation_total_progress = make_progress_bar( + "Executing model batches" if not audit_only else "Auditing models", self.console + ) self.evaluation_model_progress = Progress( 
TextColumn("{task.fields[view_name]}", justify="right"), @@ -359,30 +1045,70 @@ def start_evaluation_progress( progress_table.add_row(self.evaluation_total_progress) progress_table.add_row(self.evaluation_model_progress) - self.evaluation_progress_live = Live(progress_table, refresh_per_second=10) + self.evaluation_progress_live = Live( + progress_table, console=self.console, refresh_per_second=10 + ) self.evaluation_progress_live.start() + batch_sizes = { + snapshot: len(intervals) for snapshot, intervals in batched_intervals.items() + } + message = "Executing" if not audit_only else "Auditing" self.evaluation_total_task = self.evaluation_total_progress.add_task( - "Evaluating models...", total=sum(batches.values()) + f"{message} models...", total=sum(batch_sizes.values()) + ) + + # determine column widths + self.evaluation_column_widths["annotation"] = ( + _calculate_annotation_str_len( + batched_intervals, self.AUDIT_PADDING, len(" (123.4m rows, 123.4 KiB)") + ) + + 3 # brackets and opening escape backslash + ) + self.evaluation_column_widths["name"] = max( + len( + snapshot.display_name( + environment_naming_info, + default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, + ) + ) + for snapshot in batched_intervals ) + largest_batch_size = max(batch_sizes.values()) + self.evaluation_column_widths["batch"] = len(str(largest_batch_size)) * 2 + 3 # [X/X] + self.evaluation_column_widths["duration"] = 8 - self.evaluation_model_batches = batches + self.evaluation_model_batch_sizes = batch_sizes self.environment_naming_info = environment_naming_info self.default_catalog = default_catalog - def start_snapshot_evaluation_progress(self, snapshot: Snapshot) -> None: + def start_snapshot_evaluation_progress( + self, snapshot: Snapshot, audit_only: bool = False + ) -> None: if self.evaluation_model_progress and snapshot.name not in self.evaluation_model_tasks: display_name = snapshot.display_name( - self.environment_naming_info, 
self.default_catalog, dialect=self.dialect + self.environment_naming_info, + self.default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, ) self.evaluation_model_tasks[snapshot.name] = self.evaluation_model_progress.add_task( - f"Evaluating {display_name}...", + f"{'Evaluating' if not audit_only else 'Auditing'} {display_name}...", view_name=display_name, - total=self.evaluation_model_batches[snapshot], + total=self.evaluation_model_batch_sizes[snapshot], ) def update_snapshot_evaluation_progress( - self, snapshot: Snapshot, batch_idx: int, duration_ms: t.Optional[int] + self, + snapshot: Snapshot, + interval: Interval, + batch_idx: int, + duration_ms: t.Optional[int], + num_audits_passed: int, + num_audits_failed: int, + audit_only: bool = False, + execution_stats: t.Optional[QueryExecutionStats] = None, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: """Update the snapshot evaluation progress.""" if ( @@ -390,20 +1116,54 @@ def update_snapshot_evaluation_progress( and self.evaluation_model_progress and self.evaluation_progress_live ): - total_batches = self.evaluation_model_batches[snapshot] + total_batches = self.evaluation_model_batch_sizes[snapshot] + batch_num = str(batch_idx + 1).rjust(len(str(total_batches))) + batch = f"[{batch_num}/{total_batches}]".ljust(self.evaluation_column_widths["batch"]) if duration_ms: - self.evaluation_progress_live.console.print( - f"[{batch_idx + 1}/{total_batches}] {snapshot.display_name(self.environment_naming_info, self.default_catalog, dialect=self.dialect)} [green]evaluated[/green] in {(duration_ms / 1000.0):.2f}s" + display_name = snapshot.display_name( + self.environment_naming_info, + self.default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, + ).ljust(self.evaluation_column_widths["name"]) + + annotation = _create_evaluation_model_annotation( + snapshot, _format_evaluation_model_interval(snapshot, interval), 
execution_stats + ) + audits_str = "" + if num_audits_passed: + audits_str += f" {self.AUDIT_PASS_MARK}{num_audits_passed}" + if num_audits_failed: + audits_str += f" {self.AUDIT_FAIL_MARK}{num_audits_failed}" + audits_str = f", audits{audits_str}" if audits_str else "" + annotation_len = self.evaluation_column_widths["annotation"] + # don't adjust the annotation_len if we're using AUDIT_PADDING + annotation = f"\\[{annotation + audits_str}]".ljust( + annotation_len - 1 + if num_audits_failed and self.AUDIT_PADDING == 0 + else annotation_len ) + duration = f"{(duration_ms / 1000.0):.2f}s".ljust( + self.evaluation_column_widths["duration"] + ) + + msg = f"{f'{batch} ' if not audit_only else ''}{display_name} {annotation} {duration}".replace( + self.AUDIT_PASS_MARK, self.GREEN_AUDIT_PASS_MARK + ) + + self.evaluation_progress_live.console.print(msg) + self.evaluation_total_progress.update( self.evaluation_total_task or TaskID(0), refresh=True, advance=1 ) model_task_id = self.evaluation_model_tasks[snapshot.name] self.evaluation_model_progress.update(model_task_id, refresh=True, advance=1) - if self.evaluation_model_progress._tasks[model_task_id].completed >= total_batches: + if ( + self.evaluation_model_progress._tasks[model_task_id].completed >= total_batches + or audit_only + ): self.evaluation_model_progress.remove_task(model_task_id) def stop_evaluation_progress(self, success: bool = True) -> None: @@ -411,43 +1171,144 @@ def stop_evaluation_progress(self, success: bool = True) -> None: if self.evaluation_progress_live: self.evaluation_progress_live.stop() if success: - self.log_success("All model batches have been executed successfully") + self.log_success(f"{self.CHECK_MARK}Model batches executed") self.evaluation_progress_live = None self.evaluation_total_progress = None self.evaluation_total_task = None self.evaluation_model_progress = None self.evaluation_model_tasks = {} - self.evaluation_model_batches = {} + self.evaluation_model_batch_sizes = {} + 
self.evaluation_column_widths = {} self.environment_naming_info = EnvironmentNamingInfo() self.default_catalog = None + def start_signal_progress( + self, + snapshot: Snapshot, + default_catalog: t.Optional[str], + environment_naming_info: EnvironmentNamingInfo, + ) -> None: + """Indicates that signal checking has begun for a snapshot.""" + display_name = snapshot.display_name( + environment_naming_info, + default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, + ) + self.signal_status_tree = Tree(f"Checking signals for {display_name}") + + def update_signal_progress( + self, + snapshot: Snapshot, + signal_name: str, + signal_idx: int, + total_signals: int, + ready_intervals: Intervals, + check_intervals: Intervals, + duration: float, + ) -> None: + """Updates the signal checking progress.""" + tree = Tree(f"[{signal_idx + 1}/{total_signals}] {signal_name} {duration:.2f}s") + + formatted_check_intervals = [_format_signal_interval(snapshot, i) for i in check_intervals] + formatted_ready_intervals = [_format_signal_interval(snapshot, i) for i in ready_intervals] + + if not formatted_check_intervals: + formatted_check_intervals = ["no intervals"] + if not formatted_ready_intervals: + formatted_ready_intervals = ["no intervals"] + + # Color coding to help detect partial interval ranges quickly + if ready_intervals == check_intervals: + msg = "All ready" + color = "green" + elif ready_intervals: + msg = "Some ready" + color = "yellow" + else: + msg = "None ready" + color = "red" + + if self.verbosity < Verbosity.VERY_VERBOSE: + num_check_intervals = len(formatted_check_intervals) + if num_check_intervals > 3: + formatted_check_intervals = formatted_check_intervals[:3] + formatted_check_intervals.append(f"... and {num_check_intervals - 3} more") + + num_ready_intervals = len(formatted_ready_intervals) + if num_ready_intervals > 3: + formatted_ready_intervals = formatted_ready_intervals[:3] + formatted_ready_intervals.append(f"... 
and {num_ready_intervals - 3} more") + + check = ", ".join(formatted_check_intervals) + tree.add(f"Check: {check}") + + ready = ", ".join(formatted_ready_intervals) + tree.add(f"[{color}]{msg}: {ready}[/{color}]") + else: + check_tree = Tree("Check") + tree.add(check_tree) + for interval in formatted_check_intervals: + check_tree.add(interval) + + ready_tree = Tree(f"[{color}]{msg}[/{color}]") + tree.add(ready_tree) + for interval in formatted_ready_intervals: + ready_tree.add(f"[{color}]{interval}[/{color}]") + + if self.signal_status_tree is not None: + self.signal_status_tree.add(tree) + + def stop_signal_progress(self) -> None: + """Indicates that signal checking has completed for a snapshot.""" + if self.signal_status_tree is not None: + self._print(self.signal_status_tree) + self.signal_status_tree = None + self.signal_progress_logged = True + def start_creation_progress( self, - total_tasks: int, + snapshots: t.List[Snapshot], environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], ) -> None: """Indicates that a new creation progress has begun.""" if self.creation_progress is None: - self.creation_progress = make_progress_bar("Creating physical table", self.console) + self.creation_progress = make_progress_bar("Updating physical layer", self.console) + self._print("") self.creation_progress.start() self.creation_task = self.creation_progress.add_task( - "Creating physical tables...", - total=total_tasks, + "Updating physical layer...", + total=len(snapshots), ) + # determine name column widths if we're printing name + if self.verbosity >= Verbosity.VERBOSE: + self.creation_column_widths["name"] = max( + len( + snapshot.display_name( + environment_naming_info, + default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, + ) + ) + for snapshot in snapshots + ) + self.environment_naming_info = environment_naming_info self.default_catalog = default_catalog def update_creation_progress(self, snapshot: 
SnapshotInfoLike) -> None: """Update the snapshot creation progress.""" if self.creation_progress is not None and self.creation_task is not None: - if self.verbose: - self.creation_progress.live.console.print( - f"{snapshot.display_name(self.environment_naming_info, self.default_catalog, dialect=self.dialect)} [green]created[/green]" - ) + if self.verbosity >= Verbosity.VERBOSE: + msg = snapshot.display_name( + self.environment_naming_info, + self.default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, + ).ljust(self.creation_column_widths["name"]) + self.creation_progress.live.console.print(msg + " [green]created[/green]") self.creation_progress.update(self.creation_task, refresh=True, advance=1) def stop_creation_progress(self, success: bool = True) -> None: @@ -457,10 +1318,11 @@ def stop_creation_progress(self, success: bool = True) -> None: self.creation_progress.stop() self.creation_progress = None if success: - self.log_success("All model versions have been created successfully") + self.log_success(f"\n{self.CHECK_MARK}Physical layer updated") self.environment_naming_info = EnvironmentNamingInfo() self.default_catalog = None + self.creation_column_widths = {} def start_cleanup(self, ignore_ttl: bool) -> bool: if ignore_ttl: @@ -488,43 +1350,109 @@ def stop_cleanup(self, success: bool = False) -> None: else: self.log_error("Cleanup failed!") + def start_destroy( + self, + schemas_to_delete: t.Optional[t.Set[str]] = None, + views_to_delete: t.Optional[t.Set[str]] = None, + tables_to_delete: t.Optional[t.Set[str]] = None, + ) -> bool: + self.log_warning( + "This will permanently delete all engine-managed objects, state tables and SQLMesh cache.\n" + "The operation may disrupt any currently running or scheduled plans.\n" + ) + + if schemas_to_delete or views_to_delete or tables_to_delete: + if schemas_to_delete: + self.log_error("Schemas to be deleted:") + for schema in sorted(schemas_to_delete): + self.log_error(f" • 
{schema}") + + if views_to_delete: + self.log_error("\nEnvironment views to be deleted:") + for view in sorted(views_to_delete): + self.log_error(f" • {view}") + + if tables_to_delete: + self.log_error("\nSnapshot tables to be deleted:") + for table in sorted(tables_to_delete): + self.log_error(f" • {table}") + + self.log_error( + "\nThis action will DELETE ALL the above resources managed by SQLMesh AND\n" + "potentially external resources created by other tools in these schemas.\n" + ) + + if not self._confirm("Are you ABSOLUTELY SURE you want to proceed with deletion?"): + self.log_error("Destroy operation cancelled.") + return False + return True + + def stop_destroy(self, success: bool = False) -> None: + if success: + self.log_success("Destroy completed successfully.") + else: + self.log_error("Destroy failed!") + def start_promotion_progress( self, - total_tasks: int, + snapshots: t.List[SnapshotTableInfo], environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], ) -> None: """Indicates that a new snapshot promotion progress has begun.""" - if self.promotion_progress is None: - self.promotion_progress = Progress( - TextColumn( - f"[bold blue]Virtually Updating '{environment_naming_info.name}'", - justify="right", - ), - BarColumn(bar_width=40), - "[progress.percentage]{task.percentage:>3.1f}%", - "•", - TimeElapsedColumn(), - console=self.console, + if snapshots and self.promotion_progress is None: + self.promotion_progress = make_progress_bar( + "Updating virtual layer ", self.console, justify="left" ) + snapshots_with_virtual_views = [ + s for s in snapshots if s.is_model and not s.is_symbolic + ] self.promotion_progress.start() self.promotion_task = self.promotion_progress.add_task( - f"Virtually Updating {environment_naming_info.name}...", - total=total_tasks, + f"Virtually updating {environment_naming_info.name}...", + total=len(snapshots_with_virtual_views), ) + # determine name column widths if we're printing names + if 
self.verbosity >= Verbosity.VERBOSE: + self.promotion_column_widths["name"] = max( + len( + snapshot.display_name( + environment_naming_info, + default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, + ) + ) + for snapshot in snapshots_with_virtual_views + ) + self.environment_naming_info = environment_naming_info self.default_catalog = default_catalog def update_promotion_progress(self, snapshot: SnapshotInfoLike, promoted: bool) -> None: """Update the snapshot promotion progress.""" - if self.promotion_progress is not None and self.promotion_task is not None: - if self.verbose: - action_str = "[green]promoted[/green]" if promoted else "[yellow]demoted[/yellow]" - self.promotion_progress.live.console.print( - f"{snapshot.display_name(self.environment_naming_info, self.default_catalog, dialect=self.dialect)} {action_str}" - ) + if ( + self.promotion_progress is not None + and self.promotion_task is not None + and snapshot.is_model + and not snapshot.is_symbolic + ): + if self.verbosity >= Verbosity.VERBOSE: + display_name = snapshot.display_name( + self.environment_naming_info, + self.default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, + ).ljust(self.promotion_column_widths["name"]) + action_str = "" + if promoted: + action_str = ( + "[yellow]updated[/yellow]" + if snapshot.previous_version + else "[green]created[/green]" + ) + action_str = action_str or "[red]dropped[/red]" + self.promotion_progress.live.console.print(f"{display_name} {action_str}") self.promotion_progress.update(self.promotion_task, refresh=True, advance=1) def stop_promotion_progress(self, success: bool = True) -> None: @@ -534,10 +1462,11 @@ def stop_promotion_progress(self, success: bool = True) -> None: self.promotion_progress.stop() self.promotion_progress = None if success: - self.log_success("The target environment has been updated successfully") + self.log_success(f"\n{self.CHECK_MARK}Virtual layer updated") 
self.environment_naming_info = EnvironmentNamingInfo() self.default_catalog = None + self.promotion_column_widths = {} def start_snapshot_migration_progress(self, total_tasks: int) -> None: """Indicates that a new snapshot migration progress has begun.""" @@ -560,7 +1489,7 @@ def log_migration_status(self, success: bool = True) -> None: if self.migration_progress is not None: self.migration_progress = None if success: - self.log_success("The migration has been completed successfully") + self.log_success("Migration completed successfully") def stop_snapshot_migration_progress(self, success: bool = True) -> None: """Stop the migration progress.""" @@ -568,33 +1497,320 @@ def stop_snapshot_migration_progress(self, success: bool = True) -> None: if self.migration_progress is not None: self.migration_progress.stop() if success: - self.log_success("All snapshots have been migrated successfully") + self.log_success("Snapshots migrated successfully") + + def start_env_migration_progress(self, total_tasks: int) -> None: + """Indicates that a new environment migration has begun.""" + if self.env_migration_progress is None: + self.env_migration_progress = make_progress_bar("Migrating environments", self.console) + self.env_migration_progress.start() + self.env_migration_task = self.env_migration_progress.add_task( + "Migrating environments...", + total=total_tasks, + ) + + def update_env_migration_progress(self, num_tasks: int) -> None: + """Update the environment migration progress.""" + if self.env_migration_progress is not None and self.env_migration_task is not None: + self.env_migration_progress.update( + self.env_migration_task, refresh=True, advance=num_tasks + ) + + def stop_env_migration_progress(self, success: bool = True) -> None: + """Stop the environment migration progress.""" + self.env_migration_task = None + if self.env_migration_progress is not None: + self.env_migration_progress.stop() + self.env_migration_progress = None + if success: + 
self.log_success("Environments migrated successfully") + + def start_state_export( + self, + output_file: Path, + gateway: t.Optional[str] = None, + state_connection_config: t.Optional[ConnectionConfig] = None, + environment_names: t.Optional[t.List[str]] = None, + local_only: bool = False, + confirm: bool = True, + ) -> bool: + self.state_export_progress = None + + if local_only: + self.log_status_update(f"Exporting [b]local[/b] state to '{output_file.as_posix()}'\n") + self.log_warning( + "Local state exports just contain the model versions in your local context. Therefore, the resulting file cannot be imported." + ) + else: + self.log_status_update( + f"Exporting state to '{output_file.as_posix()}' from the following connection:\n" + ) + if gateway: + self.log_status_update(f"[b]Gateway[/b]: [green]{gateway}[/green]") + if state_connection_config: + self.print_connection_config(state_connection_config, title="State Connection") + if environment_names: + heading = "Environments" if len(environment_names) > 1 else "Environment" + self.log_status_update( + f"[b]{heading}[/b]: [yellow]{', '.join(environment_names)}[/yellow]" + ) + + should_continue = True + if confirm: + should_continue = self._confirm("\nContinue?") + self.log_status_update("") + + if should_continue: + self.state_export_progress = make_progress_bar("{task.description}", self.console) + assert isinstance(self.state_export_progress, Progress) + + self.state_export_version_task = self.state_export_progress.add_task( + "Exporting versions", start=False + ) + self.state_export_snapshot_task = self.state_export_progress.add_task( + "Exporting snapshots", start=False + ) + self.state_export_environment_task = self.state_export_progress.add_task( + "Exporting environments", start=False + ) + + self.state_export_progress.start() + + return should_continue + + def update_state_export_progress( + self, + version_count: t.Optional[int] = None, + versions_complete: bool = False, + snapshot_count: 
t.Optional[int] = None, + snapshots_complete: bool = False, + environment_count: t.Optional[int] = None, + environments_complete: bool = False, + ) -> None: + if self.state_export_progress: + if self.state_export_version_task is not None: + if version_count is not None: + self.state_export_progress.start_task(self.state_export_version_task) + self.state_export_progress.update( + self.state_export_version_task, + total=version_count, + completed=version_count, + refresh=True, + ) + if versions_complete: + self.state_export_progress.stop_task(self.state_export_version_task) + + if self.state_export_snapshot_task is not None: + if snapshot_count is not None: + self.state_export_progress.start_task(self.state_export_snapshot_task) + self.state_export_progress.update( + self.state_export_snapshot_task, + total=snapshot_count, + completed=snapshot_count, + refresh=True, + ) + if snapshots_complete: + self.state_export_progress.stop_task(self.state_export_snapshot_task) + + if self.state_export_environment_task is not None: + if environment_count is not None: + self.state_export_progress.start_task(self.state_export_environment_task) + self.state_export_progress.update( + self.state_export_environment_task, + total=environment_count, + completed=environment_count, + refresh=True, + ) + if environments_complete: + self.state_export_progress.stop_task(self.state_export_environment_task) + + def stop_state_export(self, success: bool, output_file: Path) -> None: + if self.state_export_progress: + self.state_export_progress.stop() + self.state_export_progress = None + + self.log_status_update("") + + if success: + self.log_success(f"State exported successfully to '{output_file.as_posix()}'") + else: + self.log_error("State export failed!") + + def start_state_import( + self, + input_file: Path, + gateway: str, + state_connection_config: ConnectionConfig, + clear: bool = False, + confirm: bool = True, + ) -> bool: + self.log_status_update( + f"Loading state from 
'{input_file.as_posix()}' into the following connection:\n" + ) + self.log_status_update(f"[b]Gateway[/b]: [green]{gateway}[/green]") + self.print_connection_config(state_connection_config, title="State Connection") + self.log_status_update("") + + if clear: + self.log_warning( + f"This [b]destructive[/b] operation will delete all existing state against the '{gateway}' gateway \n" + f"and replace it with what's in the '{input_file.as_posix()}' file.\n" + ) + else: + self.log_warning( + f"This operation will [b]merge[/b] the contents of the state file to the state located at the '{gateway}' gateway.\n" + "Matching snapshots or environments will be replaced.\n" + "Non-matching snapshots or environments will be ignored.\n" + ) + + should_continue = True + if confirm: + should_continue = self._confirm("[red]Are you sure?[/red]") + self.log_status_update("") + + if should_continue: + self.state_import_progress = make_progress_bar("{task.description}", self.console) + + self.state_import_info = Tree("[bold]State File Information:") + + self.state_import_version_task = self.state_import_progress.add_task( + "Importing versions", start=False + ) + self.state_import_snapshot_task = self.state_import_progress.add_task( + "Importing snapshots", start=False + ) + self.state_import_environment_task = self.state_import_progress.add_task( + "Importing environments", start=False + ) + + self.state_import_progress.start() + + return should_continue + + def update_state_import_progress( + self, + timestamp: t.Optional[str] = None, + state_file_version: t.Optional[int] = None, + versions: t.Optional[Versions] = None, + snapshot_count: t.Optional[int] = None, + snapshots_complete: bool = False, + environment_count: t.Optional[int] = None, + environments_complete: bool = False, + ) -> None: + if self.state_import_progress: + if self.state_import_info: + if timestamp: + self.state_import_info.add(f"Creation Timestamp: {timestamp}") + if state_file_version: + 
self.state_import_info.add(f"File Version: {state_file_version}") + if versions: + self.state_import_info.add(f"SQLMesh version: {versions.sqlmesh_version}") + self.state_import_info.add( + f"SQLMesh migration version: {versions.schema_version}" + ) + self.state_import_info.add(f"SQLGlot version: {versions.sqlglot_version}\n") + + self._print(self.state_import_info) + + version_count = len(versions.model_dump()) + + if self.state_import_version_task is not None: + self.state_import_progress.start_task(self.state_import_version_task) + self.state_import_progress.update( + self.state_import_version_task, + total=version_count, + completed=version_count, + ) + self.state_import_progress.stop_task(self.state_import_version_task) + + if self.state_import_snapshot_task is not None: + if snapshot_count is not None: + self.state_import_progress.start_task(self.state_import_snapshot_task) + self.state_import_progress.update( + self.state_import_snapshot_task, + completed=snapshot_count, + total=snapshot_count, + refresh=True, + ) + + if snapshots_complete: + self.state_import_progress.stop_task(self.state_import_snapshot_task) + + if self.state_import_environment_task is not None: + if environment_count is not None: + self.state_import_progress.start_task(self.state_import_environment_task) + self.state_import_progress.update( + self.state_import_environment_task, + completed=environment_count, + total=environment_count, + refresh=True, + ) + + if environments_complete: + self.state_import_progress.stop_task(self.state_import_environment_task) + + def stop_state_import(self, success: bool, input_file: Path) -> None: + if self.state_import_progress: + self.state_import_progress.stop() + self.state_import_progress = None + + self.log_status_update("") + + if success: + self.log_success(f"State imported successfully from '{input_file.as_posix()}'") + else: + self.log_error("State import failed!") + + def show_environment_difference_summary( + self, + context_diff: ContextDiff, 
+ no_diff: bool = True, + ) -> None: + """Shows a summary of the environment differences. + + Args: + context_diff: The context diff to use to print the summary + no_diff: Hide the actual environment statement differences. + """ + if context_diff.is_new_environment: + msg = ( + f"\n`{context_diff.environment}` environment will be initialized" + if not context_diff.create_from_env_exists + else f"\nNew environment `{context_diff.environment}` will be created from `{context_diff.create_from}`" + ) + self._print(Tree(f"[bold]{msg}\n")) + if not context_diff.has_snapshot_changes: + return - def start_env_migration_progress(self, total_tasks: int) -> None: - """Indicates that a new environment migration has begun.""" - if self.env_migration_progress is None: - self.env_migration_progress = make_progress_bar("Migrating environments", self.console) - self.env_migration_progress.start() - self.env_migration_task = self.env_migration_progress.add_task( - "Migrating environments...", - total=total_tasks, + if not context_diff.has_changes: + # This is only reached when the plan is against an existing environment, so we use the environment + # name instead of the create_from name. The equivalent message for new environments happens in + # the PlanBuilder. 
+ self._print( + Tree( + f"\n[bold]No changes to plan: project files match the `{context_diff.environment}` environment\n" + ) ) + return - def update_env_migration_progress(self, num_tasks: int) -> None: - """Update the environment migration progress.""" - if self.env_migration_progress is not None and self.env_migration_task is not None: - self.env_migration_progress.update( - self.env_migration_task, refresh=True, advance=num_tasks + if not context_diff.is_new_environment or ( + context_diff.is_new_environment and context_diff.create_from_env_exists + ): + self._print( + Tree( + f"\n[bold]Differences from the `{context_diff.create_from if context_diff.is_new_environment else context_diff.environment}` environment:\n" + ) ) - def stop_env_migration_progress(self, success: bool = True) -> None: - """Stop the environment migration progress.""" - self.env_migration_task = None - if self.env_migration_progress is not None: - self.env_migration_progress.stop() - self.env_migration_progress = None - if success: - self.log_success("All environments have been migrated successfully") + if context_diff.has_requirement_changes: + self._print(f"[bold]Requirements:\n{context_diff.requirements_diff()}") + + if context_diff.has_environment_statements_changes and not no_diff: + self._print("[bold]Environment statements:\n") + for type, diff in context_diff.environment_statements_diff( + include_python_env=not context_diff.is_new_environment + ): + self._print(Syntax(diff, type, line_numbers=False)) def show_model_difference_summary( self, @@ -602,32 +1818,15 @@ def show_model_difference_summary( environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], no_diff: bool = True, - ignored_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None, ) -> None: - """Shows a summary of the differences. + """Shows a summary of the model differences. 
Args: context_diff: The context diff to use to print the summary environment_naming_info: The environment naming info to reference when printing model names default_catalog: The default catalog to reference when deciding to remove catalog from display names no_diff: Hide the actual SQL differences. - ignored_snapshot_ids: A set of snapshot ids that are ignored """ - ignored_snapshot_ids = ignored_snapshot_ids or set() - if context_diff.is_new_environment: - self._print( - Tree( - f"[bold]New environment `{context_diff.environment}` will be created from `{context_diff.create_from}`" - ) - ) - if not context_diff.has_snapshot_changes: - return - - if not context_diff.has_changes: - self._print(Tree(f"[bold]No differences when compared to `{context_diff.environment}`")) - return - - self._print(Tree(f"[bold]Summary of differences against `{context_diff.environment}`:")) self._show_summary_tree_for( context_diff, "Models", @@ -635,7 +1834,6 @@ def show_model_difference_summary( environment_naming_info, default_catalog, no_diff=no_diff, - ignored_snapshot_ids=ignored_snapshot_ids, ) self._show_summary_tree_for( context_diff, @@ -644,7 +1842,6 @@ def show_model_difference_summary( environment_naming_info, default_catalog, no_diff=no_diff, - ignored_snapshot_ids=ignored_snapshot_ids, ) def plan( @@ -677,29 +1874,13 @@ def plan( default_catalog=default_catalog, ) - if not no_prompts: - self._show_options_after_categorization( - plan_builder, auto_apply, default_catalog=default_catalog - ) + self._show_options_after_categorization( + plan_builder, auto_apply, default_catalog=default_catalog, no_prompts=no_prompts + ) if auto_apply: plan_builder.apply() - def _get_ignored_tree( - self, - ignored_snapshot_ids: t.Set[SnapshotId], - snapshots: t.Dict[SnapshotId, Snapshot], - environment_naming_info: EnvironmentNamingInfo, - default_catalog: t.Optional[str], - ) -> Tree: - ignored = Tree("[bold][ignored]Ignored Models (Expected Plan Start):") - for s_id in ignored_snapshot_ids: 
- snapshot = snapshots[s_id] - ignored.add( - f"[ignored]{snapshot.display_name(environment_naming_info, default_catalog, dialect=self.dialect)} ({snapshot.get_latest(start_date(snapshot, snapshots.values()))})" - ) - return ignored - def _show_summary_tree_for( self, context_diff: ContextDiff, @@ -708,36 +1889,25 @@ def _show_summary_tree_for( environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], no_diff: bool = True, - ignored_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None, ) -> None: - ignored_snapshot_ids = ignored_snapshot_ids or set() - selected_snapshots = { - s_id: snapshot - for s_id, snapshot in context_diff.snapshots.items() - if snapshot_selector(snapshot) - } - selected_ignored_snapshot_ids = { - s_id for s_id in selected_snapshots if s_id in ignored_snapshot_ids - } added_snapshot_ids = { s_id for s_id in context_diff.added if snapshot_selector(context_diff.snapshots[s_id]) - } - selected_ignored_snapshot_ids + } removed_snapshot_ids = { s_id for s_id, snapshot in context_diff.removed_snapshots.items() if snapshot_selector(snapshot) - } - selected_ignored_snapshot_ids + } modified_snapshot_ids = { current_snapshot.snapshot_id for _, (current_snapshot, _) in context_diff.modified_snapshots.items() if snapshot_selector(current_snapshot) - } - selected_ignored_snapshot_ids + } tree_sets = ( added_snapshot_ids, removed_snapshot_ids, modified_snapshot_ids, - selected_ignored_snapshot_ids, ) if all(not s_ids for s_ids in tree_sets): return @@ -748,67 +1918,92 @@ def _show_summary_tree_for( for s_id in sorted(added_snapshot_ids): snapshot = context_diff.snapshots[s_id] added_tree.add( - f"[added]{snapshot.display_name(environment_naming_info, default_catalog, dialect=self.dialect)}" + f"[added]{snapshot.display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}" ) - tree.add(self._limit_model_names(added_tree, self.verbose)) + 
tree.add(self._limit_model_names(added_tree, self.verbosity)) if removed_snapshot_ids: removed_tree = Tree("[bold][removed]Removed:") for s_id in sorted(removed_snapshot_ids): snapshot_table_info = context_diff.removed_snapshots[s_id] removed_tree.add( - f"[removed]{snapshot_table_info.display_name(environment_naming_info, default_catalog, dialect=self.dialect)}" + f"[removed]{snapshot_table_info.display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}" ) - tree.add(self._limit_model_names(removed_tree, self.verbose)) + tree.add(self._limit_model_names(removed_tree, self.verbosity)) if modified_snapshot_ids: - direct = Tree("[bold][direct]Directly Modified:") - indirect = Tree("[bold][indirect]Indirectly Modified:") - metadata = Tree("[bold][metadata]Metadata Updated:") - for s_id in modified_snapshot_ids: - name = s_id.name - display_name = context_diff.snapshots[s_id].display_name( - environment_naming_info, default_catalog, dialect=self.dialect - ) - if context_diff.directly_modified(name): - direct.add( - f"[direct]{display_name}" - if no_diff - else Syntax(f"{display_name}\n{context_diff.text_diff(name)}", "sql") - ) - elif context_diff.indirectly_modified(name): - indirect.add(f"[indirect]{display_name}") - elif context_diff.metadata_updated(name): - metadata.add( - f"[metadata]{display_name}" - if no_diff - else Syntax(f"{display_name}", "sql", word_wrap=True) - ) - if direct.children: - tree.add(direct) - if indirect.children: - tree.add(self._limit_model_names(indirect, self.verbose)) - if metadata.children: - tree.add(metadata) - if selected_ignored_snapshot_ids: - tree.add( - self._get_ignored_tree( - selected_ignored_snapshot_ids, - selected_snapshots, - environment_naming_info, - default_catalog, - ) + tree = self._add_modified_models( + context_diff, + modified_snapshot_ids, + tree, + environment_naming_info, + default_catalog, + no_diff, ) + self._print(tree) + def 
_add_modified_models( + self, + context_diff: ContextDiff, + modified_snapshot_ids: t.Set[SnapshotId], + tree: Tree, + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str] = None, + no_diff: bool = True, + ) -> Tree: + direct = Tree("[bold][direct]Directly Modified:") + indirect = Tree("[bold][indirect]Indirectly Modified:") + metadata = Tree("[bold][metadata]Metadata Updated:") + for s_id in modified_snapshot_ids: + name = s_id.name + display_name = context_diff.snapshots[s_id].display_name( + environment_naming_info, + default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, + ) + if context_diff.directly_modified(name): + direct.add( + f"[direct]{display_name}" + if no_diff + else Syntax(f"{display_name}\n{context_diff.text_diff(name)}", "sql") + ) + elif context_diff.indirectly_modified(name): + indirect.add(f"[indirect]{display_name}") + elif context_diff.metadata_updated(name): + metadata.add( + f"[metadata]{display_name}" + if no_diff + else Syntax(f"{display_name}\n{context_diff.text_diff(name)}", "sql") + ) + if direct.children: + tree.add(direct) + if indirect.children: + tree.add(self._limit_model_names(indirect, self.verbosity)) + if metadata.children: + tree.add(metadata) + return tree + def _show_options_after_categorization( - self, plan_builder: PlanBuilder, auto_apply: bool, default_catalog: t.Optional[str] + self, + plan_builder: PlanBuilder, + auto_apply: bool, + default_catalog: t.Optional[str], + no_prompts: bool, ) -> None: plan = plan_builder.build() - if plan.forward_only and plan.new_snapshots: + if not no_prompts and plan.forward_only and plan.new_snapshots: self._prompt_effective_from(plan_builder, auto_apply, default_catalog) if plan.requires_backfill: self._show_missing_dates(plan_builder.build(), default_catalog) - self._prompt_backfill(plan_builder, auto_apply, default_catalog) + + if not no_prompts: + self._prompt_backfill(plan_builder, auto_apply, default_catalog) + 
+ backfill_or_preview = "preview" if plan.is_dev and plan.forward_only else "backfill" + if not auto_apply and self._confirm( + f"Apply - {backfill_or_preview.capitalize()} Tables" + ): + plan_builder.apply() elif plan.has_changes and not auto_apply: self._prompt_promote(plan_builder) elif plan.has_unmodified_unpromoted and not auto_apply: @@ -826,21 +2021,58 @@ def _prompt_categorize( """Get the user's change category for the directly modified models.""" plan = plan_builder.build() - self.show_model_difference_summary( - plan.context_diff, - plan.environment_naming_info, - default_catalog=default_catalog, - ignored_snapshot_ids=plan.ignored, - ) + if plan.restatements: + # A plan can have restatements for the following reasons: + # - The user specifically called `sqlmesh plan` with --restate-model. + # This creates a "restatement plan" which disallows all other changes and simply force-backfills + # the selected models and their downstream dependencies using the versions of the models stored in state. + # - There are no specific restatements (so changes are allowed) AND dev previews need to be computed. + # The "restatements" feature is currently reused for dev previews. 
+ if plan.selected_models_to_restate: + # There were legitimate restatements, no dev previews + tree = Tree( + "[bold]Models selected for restatement:[/bold]\n" + "This causes backfill of the model itself as well as affected downstream models" + ) + model_fqn_to_snapshot = {s.name: s for s in plan.snapshots.values()} + for model_fqn in plan.selected_models_to_restate: + snapshot = model_fqn_to_snapshot[model_fqn] + display_name = snapshot.display_name( + plan.environment_naming_info, + default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, + ) + tree.add( + display_name + ) # note: we deliberately dont show any intervals here; they get shown in the backfill section + self._print(tree) + else: + # We are computing dev previews, do not confuse the user by printing out something to do + # with restatements. Dev previews are already highlighted in the backfill step + pass + else: + self.show_environment_difference_summary( + plan.context_diff, + no_diff=no_diff, + ) + + if plan.context_diff.has_changes: + self.show_model_difference_summary( + plan.context_diff, + plan.environment_naming_info, + default_catalog=default_catalog, + ) if not no_diff: self._show_categorized_snapshots(plan, default_catalog) for snapshot in plan.uncategorized: + if snapshot.is_model and snapshot.model.forward_only: + continue if not no_diff: self.show_sql(plan.context_diff.text_diff(snapshot.name)) tree = Tree( - f"[bold][direct]Directly Modified: {snapshot.display_name(plan.environment_naming_info, default_catalog, dialect=self.dialect)}" + f"[bold][direct]Directly Modified: {snapshot.display_name(plan.environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}" ) indirect_tree = None @@ -850,10 +2082,10 @@ def _prompt_categorize( indirect_tree = Tree("[indirect]Indirectly Modified Children:") tree.add(indirect_tree) indirect_tree.add( - 
f"[indirect]{child_snapshot.display_name(plan.environment_naming_info, default_catalog, dialect=self.dialect)}" + f"[indirect]{child_snapshot.display_name(plan.environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}" ) if indirect_tree: - indirect_tree = self._limit_model_names(indirect_tree, self.verbose) + indirect_tree = self._limit_model_names(indirect_tree, self.verbosity) self._print(tree) if not no_prompts: @@ -865,27 +2097,36 @@ def _show_categorized_snapshots(self, plan: Plan, default_catalog: t.Optional[st context_diff = plan.context_diff for snapshot in plan.categorized: - if not context_diff.directly_modified(snapshot.name): - continue - - category_str = SNAPSHOT_CHANGE_CATEGORY_STR[snapshot.change_category] - tree = Tree( - f"[bold][direct]Directly Modified: {snapshot.display_name(plan.environment_naming_info, default_catalog, dialect=self.dialect)} ({category_str})" - ) - indirect_tree = None - for child_sid in sorted(plan.indirectly_modified.get(snapshot.snapshot_id, set())): - child_snapshot = context_diff.snapshots[child_sid] - if not indirect_tree: - indirect_tree = Tree("[indirect]Indirectly Modified Children:") - tree.add(indirect_tree) - child_category_str = SNAPSHOT_CHANGE_CATEGORY_STR[child_snapshot.change_category] - indirect_tree.add( - f"[indirect]{child_snapshot.display_name(plan.environment_naming_info, default_catalog, dialect=self.dialect)} ({child_category_str})" + if context_diff.directly_modified(snapshot.name): + category_str = SNAPSHOT_CHANGE_CATEGORY_STR[snapshot.change_category] + tree = Tree( + f"\n[bold][direct]Directly Modified: {snapshot.display_name(plan.environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)} ({category_str})" ) - if indirect_tree: - indirect_tree = self._limit_model_names(indirect_tree, self.verbose) + indirect_tree = None + for child_sid in 
sorted(plan.indirectly_modified.get(snapshot.snapshot_id, set())): + child_snapshot = context_diff.snapshots[child_sid] + if not indirect_tree: + indirect_tree = Tree("[indirect]Indirectly Modified Children:") + tree.add(indirect_tree) + child_category_str = SNAPSHOT_CHANGE_CATEGORY_STR[ + child_snapshot.change_category + ] + indirect_tree.add( + f"[indirect]{child_snapshot.display_name(plan.environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)} ({child_category_str})" + ) + if indirect_tree: + indirect_tree = self._limit_model_names(indirect_tree, self.verbosity) + elif context_diff.metadata_updated(snapshot.name): + tree = Tree( + f"\n[bold][metadata]Metadata Updated: {snapshot.display_name(plan.environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}" + ) + else: + continue - self._print(Syntax(context_diff.text_diff(snapshot.name), "sql", word_wrap=True)) + text_diff = context_diff.text_diff(snapshot.name) + if text_diff: + self._print("") + self._print(Syntax(text_diff, "sql", word_wrap=True)) self._print(tree) def _show_missing_dates(self, plan: Plan, default_catalog: t.Optional[str]) -> None: @@ -893,7 +2134,7 @@ def _show_missing_dates(self, plan: Plan, default_catalog: t.Optional[str]) -> N missing_intervals = plan.missing_intervals if not missing_intervals: return - backfill = Tree("[bold]Models needing backfill (missing dates):") + backfill = Tree("[bold]Models needing backfill:[/bold]") for missing in missing_intervals: snapshot = plan.context_diff.snapshots[missing.snapshot_id] if not snapshot.is_model: @@ -903,11 +2144,17 @@ def _show_missing_dates(self, plan: Plan, default_catalog: t.Optional[str]) -> N if not plan.deployability_index.is_deployable(snapshot): preview_modifier = " ([orange1]preview[/orange1])" + display_name = snapshot.display_name( + plan.environment_naming_info, + default_catalog if self.verbosity < 
Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, + ) backfill.add( - f"{snapshot.display_name(plan.environment_naming_info, default_catalog, dialect=self.dialect)}: {missing.format_intervals(snapshot.node.interval_unit)}{preview_modifier}" + f"{display_name}: \\[{_format_missing_intervals(snapshot, missing)}]{preview_modifier}" ) + if backfill: - backfill = self._limit_model_names(backfill, self.verbose) + backfill = self._limit_model_names(backfill, self.verbosity) self._print(backfill) def _prompt_effective_from( @@ -954,6 +2201,9 @@ def _prompt_backfill( if not plan_builder.override_end: if plan.provided_end: blank_meaning = f"'{time_like_to_str(plan.provided_end)}'" + elif plan.end_override_per_model: + max_end = max(plan.end_override_per_model.values()) + blank_meaning = f"'{time_like_to_str(max_end)}'" else: blank_meaning = "now" end = self._prompt( @@ -964,47 +2214,47 @@ def _prompt_backfill( plan = plan_builder.build() - if plan.ignored: - self._print( - self._get_ignored_tree( - plan.ignored, - plan.context_diff.snapshots, - plan.environment_naming_info, - default_catalog, - ) - ) - if not auto_apply and self._confirm(f"Apply - {backfill_or_preview.capitalize()} Tables"): - plan_builder.apply() - def _prompt_promote(self, plan_builder: PlanBuilder) -> None: if self._confirm( "Apply - Virtual Update", ): plan_builder.apply() - def log_test_results( - self, result: unittest.result.TestResult, output: str, target_dialect: str - ) -> None: + def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None: + # We don't log the test results if no tests were ran + if not result.testsRun: + return + divider_length = 70 + + self._log_test_details(result) + + message = ( + f"Ran {result.testsRun} tests against {target_dialect} in {result.duration} seconds." 
+ ) if result.wasSuccessful(): self._print("=" * divider_length) self._print( - f"Successfully Ran {str(result.testsRun)} tests against {target_dialect}", + f"Successfully {message}", style="green", ) self._print("-" * divider_length) else: self._print("-" * divider_length) - self._print("Test Failure Summary") - self._print("=" * divider_length) - self._print( - f"Num Successful Tests: {result.testsRun - len(result.failures) - len(result.errors)}" - ) - for test, _ in result.failures + result.errors: - if isinstance(test, ModelTest): - self._print(f"Failure Test: {test.model.name} {test.test_name}") + self._print("Test Failure Summary", style="red") self._print("=" * divider_length) - self._print(output) + fail_and_error_tests = result.get_fail_and_error_tests() + self._print(f"{message} \n") + + self._print(f"Failed tests ({len(fail_and_error_tests)}):") + for test in fail_and_error_tests: + self._print(f" • {test.path}::{test.test_name}") + self._print("=" * divider_length, end="\n\n") + + def _captured_unit_test_results(self, result: ModelTextTestResult) -> str: + with self.console.capture() as capture: + self._log_test_details(result) + return strip_ansi_codes(capture.get()) def show_sql(self, sql: str) -> None: self._print(Syntax(sql, "sql", word_wrap=True), crop=False) @@ -1012,11 +2262,93 @@ def show_sql(self, sql: str) -> None: def log_status_update(self, message: str) -> None: self._print(message) + def log_skipped_models(self, snapshot_names: t.Set[str]) -> None: + if snapshot_names: + msg = " " + "\n ".join(snapshot_names) + self._print(f"[dark_orange3]Skipped models[/dark_orange3]\n\n{msg}") + + def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: + if errors: + self._print("\n[red]Failed models[/red]\n") + + error_messages = _format_node_errors(errors) + + for node_name, msg in error_messages.items(): + self._print(f" [red]{node_name}[/red]\n\n{msg}") + + def log_models_updated_during_restatement( + self, + snapshots: 
t.List[t.Tuple[SnapshotTableInfo, SnapshotTableInfo]], + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str] = None, + ) -> None: + if snapshots: + tree = Tree( + f"[yellow]The following models had new versions deployed while data was being restated:[/yellow]" + ) + + for restated_snapshot, updated_snapshot in snapshots: + display_name = restated_snapshot.display_name( + environment_naming_info, + default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, + ) + current_branch = tree.add(display_name) + current_branch.add(f"restated version: '{restated_snapshot.version}'") + current_branch.add(f"currently active version: '{updated_snapshot.version}'") + + self._print(tree) + self._print("") # newline spacer + + def log_destructive_change( + self, + snapshot_name: str, + alter_operations: t.List[TableAlterOperation], + dialect: str, + error: bool = True, + ) -> None: + if error: + self._print(format_destructive_change_msg(snapshot_name, alter_operations, dialect)) + else: + self.log_warning( + format_destructive_change_msg(snapshot_name, alter_operations, dialect, error) + ) + + def log_additive_change( + self, + snapshot_name: str, + alter_operations: t.List[TableAlterOperation], + dialect: str, + error: bool = True, + ) -> None: + if error: + self._print(format_additive_change_msg(snapshot_name, alter_operations, dialect)) + else: + self.log_warning( + format_additive_change_msg(snapshot_name, alter_operations, dialect, error) + ) + def log_error(self, message: str) -> None: self._print(f"[red]{message}[/red]") + def log_warning(self, short_message: str, long_message: t.Optional[str] = None) -> None: + logger.warning(long_message or short_message) + if not self.ignore_warnings: + if long_message: + file_path = None + for handler in logger.root.handlers: + if isinstance(handler, logging.FileHandler): + file_path = handler.baseFilename + break + file_path_msg = f" Learn more in logs: {file_path}\n" if 
file_path else "" + short_message = f"{short_message}{file_path_msg}" + message_lstrip = short_message.lstrip() + leading_ws = short_message[: -len(message_lstrip)] + message_formatted = f"{leading_ws}[yellow]\\[WARNING] {message_lstrip}[/yellow]" + self._print(message_formatted) + def log_success(self, message: str) -> None: - self._print(f"\n[green]{message}[/green]\n") + self._print(f"[green]{message}[/green]\n") def loading_start(self, message: t.Optional[str] = None) -> uuid.UUID: id = uuid.uuid4() @@ -1028,6 +2360,113 @@ def loading_stop(self, id: uuid.UUID) -> None: self.loading_status[id].stop() del self.loading_status[id] + def show_table_diff_details( + self, + models_to_diff: t.List[str], + ) -> None: + """Display information about which tables are going to be diffed""" + + if models_to_diff: + m_tree = Tree("\n[b]Models to compare:") + for m in models_to_diff: + m_tree.add(f"[{self.TABLE_DIFF_SOURCE_BLUE}]{m}[/{self.TABLE_DIFF_SOURCE_BLUE}]") + self._print(m_tree) + self._print("") + + def start_table_diff_progress(self, models_to_diff: int) -> None: + if not self.table_diff_progress: + self.table_diff_progress = make_progress_bar( + "Calculating model differences", self.console + ) + self.table_diff_model_progress = Progress( + TextColumn("{task.fields[view_name]}", justify="right"), + SpinnerColumn(spinner_name="simpleDots"), + console=self.console, + ) + + progress_table = Table.grid() + progress_table.add_row(self.table_diff_progress) + progress_table.add_row(self.table_diff_model_progress) + + self.table_diff_progress_live = Live(progress_table, refresh_per_second=10) + self.table_diff_progress_live.start() + + self.table_diff_model_task = self.table_diff_progress.add_task( + "Diffing", total=models_to_diff + ) + + def start_table_diff_model_progress(self, model: str) -> None: + if self.table_diff_model_progress and model not in self.table_diff_model_tasks: + self.table_diff_model_tasks[model] = self.table_diff_model_progress.add_task( + f"Diffing 
{model}...", + view_name=model, + total=1, + ) + + def update_table_diff_progress(self, model: str) -> None: + if self.table_diff_progress: + self.table_diff_progress.update(self.table_diff_model_task, refresh=True, advance=1) + if self.table_diff_model_progress and model in self.table_diff_model_tasks: + model_task_id = self.table_diff_model_tasks[model] + self.table_diff_model_progress.remove_task(model_task_id) + + def stop_table_diff_progress(self, success: bool) -> None: + if self.table_diff_progress_live: + self.table_diff_progress_live.stop() + self.table_diff_progress_live = None + self.log_status_update("") + + if success: + self.log_success(f"Table diff completed successfully!") + else: + self.log_error("Table diff failed!") + + self.table_diff_progress = None + self.table_diff_model_progress = None + self.table_diff_model_tasks = {} + + def show_table_diff_summary(self, table_diff: TableDiff) -> None: + tree = Tree("\n[b]Table Diff") + + if table_diff.model_name: + model = Tree("Model:") + model.add(f"[blue]{table_diff.model_name}[/blue]") + + tree.add(model) + + envs = Tree("Environment:") + source = Tree( + f"Source: [{self.TABLE_DIFF_SOURCE_BLUE}]{table_diff.source_alias}[/{self.TABLE_DIFF_SOURCE_BLUE}]" + ) + envs.add(source) + + target = Tree( + f"Target: [{self.TABLE_DIFF_TARGET_GREEN}]{table_diff.target_alias}[/{self.TABLE_DIFF_TARGET_GREEN}]" + ) + envs.add(target) + + tree.add(envs) + + tables = Tree("Tables:") + + tables.add( + f"Source: [{self.TABLE_DIFF_SOURCE_BLUE}]{table_diff.source}[/{self.TABLE_DIFF_SOURCE_BLUE}]" + ) + tables.add( + f"Target: [{self.TABLE_DIFF_TARGET_GREEN}]{table_diff.target}[/{self.TABLE_DIFF_TARGET_GREEN}]" + ) + + tree.add(tables) + + join = Tree("Join On:") + _, _, key_column_names = table_diff.key_columns + for col_name in key_column_names: + join.add(f"[yellow]{col_name}[/yellow]") + + tree.add(join) + + self._print(tree) + def show_schema_diff(self, schema_diff: SchemaDiff) -> None: source_name = 
schema_diff.source if schema_diff.source_alias: @@ -1036,7 +2475,7 @@ def show_schema_diff(self, schema_diff: SchemaDiff) -> None: if schema_diff.target_alias: target_name = schema_diff.target_alias.upper() - first_line = f"\n[b]Schema Diff Between '[yellow]{source_name}[/yellow]' and '[green]{target_name}[/green]'" + first_line = f"\n[b]Schema Diff Between '[{self.TABLE_DIFF_SOURCE_BLUE}]{source_name}[/{self.TABLE_DIFF_SOURCE_BLUE}]' and '[{self.TABLE_DIFF_TARGET_GREEN}]{target_name}[/{self.TABLE_DIFF_TARGET_GREEN}]'" if schema_diff.model_name: first_line = ( first_line + f" environments for model '[blue]{schema_diff.model_name}[/blue]'" @@ -1070,6 +2509,12 @@ def show_schema_diff(self, schema_diff: SchemaDiff) -> None: def show_row_diff( self, row_diff: RowDiff, show_sample: bool = True, skip_grain_check: bool = False ) -> None: + if row_diff.empty: + self.console.print( + "\n[b][red]Neither the source nor the target table contained any records[/red][/b]" + ) + return + source_name = row_diff.source if row_diff.source_alias: source_name = row_diff.source_alias.upper() @@ -1114,9 +2559,78 @@ def show_row_diff( self.console.print(" No columns with same name and data type in both tables") if show_sample: + sample = row_diff.joined_sample self.console.print("\n[b][blue]COMMON ROWS[/blue] sample data differences:[/b]") - if row_diff.joined_sample.shape[0] > 0: - self.console.print(row_diff.joined_sample.to_string(index=False), end="\n\n") + if sample.shape[0] > 0: + keys: list[str] = [] + columns: dict[str, list[str]] = {} + source_prefix, source_name = ( + (f"{source_name}__", source_name) + if source_name.lower() != row_diff.source.lower() + else ("s__", "SOURCE") + ) + target_prefix, target_name = ( + (f"{target_name}__", target_name) + if target_name.lower() != row_diff.target.lower() + else ("t__", "TARGET") + ) + + # Extract key and column names from the joined sample + for column in row_diff.joined_sample.columns: + if source_prefix in column: + column_name = 
"__".join(column.split(source_prefix)[1:]) + columns[column_name] = [column, target_prefix + column_name] + elif target_prefix not in column: + keys.append(column) + + column_styles = { + source_name: self.TABLE_DIFF_SOURCE_BLUE, + target_name: self.TABLE_DIFF_TARGET_GREEN, + } + + for column, [source_column, target_column] in columns.items(): + # Create a table with the joined keys and comparison columns + column_table = row_diff.joined_sample[keys + [source_column, target_column]] + + # Filter to retain non identical-valued rows + column_table = column_table[ + column_table.apply( + lambda row: not _cells_match(row[source_column], row[target_column]), + axis=1, + ) + ] + + # Rename the column headers for readability + column_table = column_table.rename( + columns={ + source_column: source_name, + target_column: target_name, + } + ) + + table = Table(show_header=True) + for column_name in column_table.columns: + style = column_styles.get(column_name, "") + table.add_column(column_name, style=style, header_style=style) + + for _, row in column_table.iterrows(): + table.add_row( + *[ + str( + round(cell, row_diff.decimals) + if isinstance(cell, float) + else cell + ) + for cell in row + ] + ) + + self.console.print( + f"Column: [underline][bold cyan]{column}[/bold cyan][/underline]", + table, + end="\n", + ) + else: self.console.print(" All joined rows match") @@ -1128,6 +2642,100 @@ def show_row_diff( self.console.print(f"\n[b][green]{target_name} ONLY[/green] sample rows:[/b]") self.console.print(row_diff.t_sample.to_string(index=False), end="\n\n") + def show_table_diff( + self, + table_diffs: t.List[TableDiff], + show_sample: bool = True, + skip_grain_check: bool = False, + temp_schema: t.Optional[str] = None, + ) -> None: + """ + Display the table diff between all mismatched tables. 
+ """ + if len(table_diffs) > 1: + mismatched_tables = [] + fully_matched = [] + for table_diff in table_diffs: + if ( + table_diff.schema_diff().source_schema == table_diff.schema_diff().target_schema + ) and ( + table_diff.row_diff( + temp_schema=temp_schema, skip_grain_check=skip_grain_check + ).full_match_pct + == 100 + ): + fully_matched.append(table_diff) + else: + mismatched_tables.append(table_diff) + table_diffs = mismatched_tables if mismatched_tables else [] + if fully_matched: + m_tree = Tree("\n[b]Identical Tables") + for m in fully_matched: + m_tree.add( + f"[{self.TABLE_DIFF_SOURCE_BLUE}]{m.source}[/{self.TABLE_DIFF_SOURCE_BLUE}] - [{self.TABLE_DIFF_TARGET_GREEN}]{m.target}[/{self.TABLE_DIFF_TARGET_GREEN}]" + ) + self._print(m_tree) + + if mismatched_tables: + m_tree = Tree("\n[b]Mismatched Tables") + for m in mismatched_tables: + m_tree.add( + f"[{self.TABLE_DIFF_SOURCE_BLUE}]{m.source}[/{self.TABLE_DIFF_SOURCE_BLUE}] - [{self.TABLE_DIFF_TARGET_GREEN}]{m.target}[/{self.TABLE_DIFF_TARGET_GREEN}]" + ) + self._print(m_tree) + + for table_diff in table_diffs: + self.show_table_diff_summary(table_diff) + self.show_schema_diff(table_diff.schema_diff()) + self.show_row_diff( + table_diff.row_diff(temp_schema=temp_schema, skip_grain_check=skip_grain_check), + show_sample=show_sample, + skip_grain_check=skip_grain_check, + ) + + def print_environments(self, environments_summary: t.List[EnvironmentSummary]) -> None: + """Prints all environment names along with expiry datetime.""" + output = [ + f"{summary.name} - {time_like_to_str(summary.expiration_ts)}" + if summary.expiration_ts + else f"{summary.name} - No Expiry" + for summary in environments_summary + ] + output_str = "\n".join([str(len(output)), *output]) + self.log_status_update(f"Number of SQLMesh environments are: {output_str}") + + def show_intervals(self, snapshot_intervals: t.Dict[Snapshot, SnapshotIntervals]) -> None: + complete = Tree(f"[b]Complete Intervals[/b]") + incomplete = 
Tree(f"[b]Missing Intervals[/b]") + + for snapshot, intervals in sorted(snapshot_intervals.items(), key=lambda s: s[0].node.name): + if intervals.intervals: + incomplete.add( + f"{snapshot.node.name}: [{intervals.format_intervals(snapshot.node.interval_unit)}]" + ) + else: + complete.add(snapshot.node.name) + + if complete.children: + self._print(complete) + + if incomplete.children: + self._print(incomplete) + + def print_connection_config(self, config: ConnectionConfig, title: str = "Connection") -> None: + tree = Tree(f"[b]{title}:[/b]") + tree.add(f"Type: [bold cyan]{config.type_}[/bold cyan]") + tree.add(f"Catalog: [bold cyan]{config.get_catalog()}[/bold cyan]") + + try: + engine_adapter_type = config._engine_adapter + tree.add(f"Dialect: [bold cyan]{engine_adapter_type.DIALECT}[/bold cyan]") + except NotImplementedError: + # not all ConnectionConfig's have an engine adapter associated. The CloudConnectionConfig has a HTTP client instead + pass + + self._print(tree) + def _get_snapshot_change_category( self, snapshot: Snapshot, @@ -1139,9 +2747,9 @@ def _get_snapshot_change_category( snapshot, plan_builder.environment_naming_info, default_catalog ) response = self._prompt( - "\n".join([f"[{i+1}] {choice}" for i, choice in enumerate(choices.values())]), + "\n".join([f"[{i + 1}] {choice}" for i, choice in enumerate(choices.values())]), show_choices=False, - choices=[f"{i+1}" for i in range(len(choices))], + choices=[f"{i + 1}" for i in range(len(choices))], ) choice = list(choices)[int(response) - 1] plan_builder.set_choice(snapshot, choice) @@ -1154,7 +2762,9 @@ def _snapshot_change_choices( use_rich_formatting: bool = True, ) -> t.Dict[SnapshotChangeCategory, str]: direct = snapshot.display_name( - environment_naming_info, default_catalog, dialect=self.dialect + environment_naming_info, + default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, ) if use_rich_formatting: direct = f"[direct]{direct}[/direct]" @@ -1181,6 
+2791,100 @@ def _snapshot_change_choices( } return labeled_choices + def show_linter_violations( + self, violations: t.List[RuleViolation], model: Model, is_error: bool = False + ) -> None: + severity = "errors" if is_error else "warnings" + + # Sort violations by line, then alphabetically the name of the violation + # Violations with no range go first + sorted_violations = sorted( + violations, + key=lambda v: ( + v.violation_range.start.line if v.violation_range else -1, + v.rule.name.lower(), + ), + ) + violations_text = [ + ( + f" - Line {v.violation_range.start.line + 1}: {v.rule.name} - {v.violation_msg}" + if v.violation_range + else f" - {v.rule.name}: {v.violation_msg}" + ) + for v in sorted_violations + ] + violations_msg = "\n".join(violations_text) + msg = f"Linter {severity} for {model._path}:\n{violations_msg}" + + if is_error: + self.log_error(msg) + else: + self.log_warning(msg) + + def _log_test_details( + self, result: ModelTextTestResult, unittest_char_separator: bool = True + ) -> None: + """ + This is a helper method that encapsulates the logic for logging the relevant unittest for the result. + The top level method (`log_test_results`) reuses `_log_test_details` differently based on the console. + + Args: + result: The unittest test result that contains metrics like num success, fails, ect. 
+ """ + if result.wasSuccessful(): + self._print("\n", end="") + return + + if unittest_char_separator: + self._print(f"\n{unittest.TextTestResult.separator1}\n\n", end="") + + for (test_case, failure), test_failure_tables in zip_longest( # type: ignore + result.failures, result.failure_tables + ): + self._print(unittest.TextTestResult.separator2) + self._print(f"FAIL: {test_case}") + + if test_description := test_case.shortDescription(): + self._print(test_description) + self._print(f"{unittest.TextTestResult.separator2}") + + if not test_failure_tables: + self._print(failure) + else: + for failure_table in test_failure_tables: + self._print(failure_table) + self._print("\n", end="") + + for test_case, error in result.errors: + self._print(unittest.TextTestResult.separator2) + self._print(f"ERROR: {test_case}") + self._print(f"{unittest.TextTestResult.separator2}") + self._print(error) + + +def _cells_match(x: t.Any, y: t.Any) -> bool: + """Helper function to compare two cells and returns true if they're equal, handling array objects.""" + import pandas as pd + import numpy as np + + # Convert array-like objects to list for consistent comparison + def _normalize(val: t.Any) -> t.Any: + # Convert Pandas null to Python null for the purposes of comparison to prevent errors like the following on boolean fields: + # - TypeError: boolean value of NA is ambiguous + # note pd.isnull() returns either a bool or a ndarray[bool] depending on if the input + # is scalar or an array + isnull = pd.isnull(val) + + if isinstance(isnull, bool): # scalar + if isnull: + val = None + elif all(isnull): # array + val = None + + return list(val) if isinstance(val, (pd.Series, np.ndarray)) else val + + return _normalize(x) == _normalize(y) + def add_to_layout_widget(target_widget: LayoutWidget, *widgets: widgets.Widget) -> LayoutWidget: """Helper function to add a widget to a layout widget. 
@@ -1261,13 +2965,18 @@ def _prompt_effective_from( def effective_from_change_callback(change: t.Dict[str, datetime.datetime]) -> None: plan_builder.set_effective_from(change["new"]) - self._show_options_after_categorization(plan_builder, auto_apply, default_catalog) + self._show_options_after_categorization( + plan_builder, auto_apply, default_catalog, no_prompts=False + ) def going_forward_change_callback(change: t.Dict[str, bool]) -> None: checked = change["new"] plan_builder.set_effective_from(None if checked else yesterday_ds()) self._show_options_after_categorization( - plan_builder, auto_apply=auto_apply, default_catalog=default_catalog + plan_builder, + auto_apply=auto_apply, + default_catalog=default_catalog, + no_prompts=False, ) date_picker = widgets.DatePicker( @@ -1325,11 +3034,15 @@ def _date_picker( def start_change_callback(change: t.Dict[str, datetime.datetime]) -> None: plan_builder.set_start(change["new"]) - self._show_options_after_categorization(plan_builder, auto_apply, default_catalog) + self._show_options_after_categorization( + plan_builder, auto_apply, default_catalog, no_prompts=False + ) def end_change_callback(change: t.Dict[str, datetime.datetime]) -> None: plan_builder.set_end(change["new"]) - self._show_options_after_categorization(plan_builder, auto_apply, default_catalog) + self._show_options_after_categorization( + plan_builder, auto_apply, default_catalog, no_prompts=False + ) if plan_builder.is_start_and_end_allowed: add_to_layout_widget( @@ -1377,11 +3090,17 @@ def end_change_callback(change: t.Dict[str, datetime.datetime]) -> None: button.output = output def _show_options_after_categorization( - self, plan_builder: PlanBuilder, auto_apply: bool, default_catalog: t.Optional[str] + self, + plan_builder: PlanBuilder, + auto_apply: bool, + default_catalog: t.Optional[str], + no_prompts: bool, ) -> None: self.dynamic_options_after_categorization_output.children = () self.display(self.dynamic_options_after_categorization_output) - 
super()._show_options_after_categorization(plan_builder, auto_apply, default_catalog) + super()._show_options_after_categorization( + plan_builder, auto_apply, default_catalog, no_prompts + ) def _add_to_dynamic_options(self, *widgets: widgets.Widget) -> None: add_to_layout_widget(self.dynamic_options_after_categorization_output, *widgets) @@ -1406,7 +3125,9 @@ def _get_snapshot_change_category( def radio_button_selected(change: t.Dict[str, t.Any]) -> None: plan_builder.set_choice(snapshot, choices[change["owner"].index]) - self._show_options_after_categorization(plan_builder, auto_apply, default_catalog) + self._show_options_after_categorization( + plan_builder, auto_apply, default_catalog, no_prompts=False + ) radio = widgets.RadioButtons( options=choice_mapping.values(), @@ -1419,9 +3140,11 @@ def radio_button_selected(change: t.Dict[str, t.Any]) -> None: ) self.display(radio) - def log_test_results( - self, result: unittest.result.TestResult, output: str, target_dialect: str - ) -> None: + def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None: + # We don't log the test results if no tests were ran + if not result.testsRun: + return + import ipywidgets as widgets divider_length = 70 @@ -1430,6 +3153,11 @@ def log_test_results( "font-weight": "bold", "font-family": "Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace", } + + message = ( + f"Ran {result.testsRun} tests against {target_dialect} in {result.duration} seconds." + ) + if result.wasSuccessful(): success_color = {"color": "#008000"} header = str(h("span", {"style": shared_style}, "-" * divider_length)) @@ -1437,41 +3165,43 @@ def log_test_results( h( "span", {"style": {**shared_style, **success_color}}, - f"Successfully Ran {str(result.testsRun)} Tests Against {target_dialect}", + f"Successfully {message}", ) ) footer = str(h("span", {"style": shared_style}, "=" * divider_length)) self.display(widgets.HTML("
".join([header, message, footer]))) else: + output = self._captured_unit_test_results(result) + fail_color = {"color": "#db3737"} fail_shared_style = {**shared_style, **fail_color} header = str(h("span", {"style": fail_shared_style}, "-" * divider_length)) message = str(h("span", {"style": fail_shared_style}, "Test Failure Summary")) - num_success = str( - h( - "span", - {"style": fail_shared_style}, - f"Num Successful Tests: {result.testsRun - len(result.failures) - len(result.errors)}", + fail_and_error_tests = result.get_fail_and_error_tests() + failed_tests = [ + str( + h( + "span", + {"style": fail_shared_style}, + f"Failed tests ({len(fail_and_error_tests)}):", + ) ) - ) - failure_tests = [] - for test, _ in result.failures + result.errors: - if isinstance(test, ModelTest): - failure_tests.append( - str( - h( - "span", - {"style": fail_shared_style}, - f"Failure Test: {test.model.name} {test.test_name}", - ) + ] + + for test in fail_and_error_tests: + failed_tests.append( + str( + h( + "span", + {"style": fail_shared_style}, + f" • {test.model.name}::{test.test_name}", ) ) - failures = "
".join(failure_tests) + ) + failures = "
".join(failed_tests) footer = str(h("span", {"style": fail_shared_style}, "=" * divider_length)) error_output = widgets.Textarea(output, layout={"height": "300px", "width": "100%"}) - test_info = widgets.HTML( - "
".join([header, message, footer, num_success, failures, footer]) - ) + test_info = widgets.HTML("
".join([header, message, footer, failures, footer])) self.display(widgets.VBox(children=[test_info, error_output], layout={"width": "100%"})) @@ -1487,35 +3217,67 @@ class CaptureTerminalConsole(TerminalConsole): def __init__(self, console: t.Optional[RichConsole] = None, **kwargs: t.Any) -> None: super().__init__(console=console, **kwargs) self._captured_outputs: t.List[str] = [] + self._warnings: t.List[str] = [] self._errors: t.List[str] = [] @property def captured_output(self) -> str: return "".join(self._captured_outputs) + @property + def captured_warnings(self) -> str: + return "".join(self._warnings) + @property def captured_errors(self) -> str: return "".join(self._errors) def consume_captured_output(self) -> str: - output = self.captured_output - self.clear_captured_outputs() - return output + try: + return self.captured_output + finally: + self._captured_outputs = [] - def consume_captured_errors(self) -> str: - errors = self.captured_errors - self.clear_captured_errors() - return errors + def consume_captured_warnings(self) -> str: + try: + return self.captured_warnings + finally: + self._warnings = [] - def clear_captured_outputs(self) -> None: - self._captured_outputs = [] + def consume_captured_errors(self) -> str: + try: + return self.captured_errors + finally: + self._errors = [] - def clear_captured_errors(self) -> None: - self._errors = [] + def log_warning( + self, + short_message: str, + long_message: t.Optional[str] = None, + *args: t.Any, + **kwargs: t.Any, + ) -> None: + if short_message not in self._warnings: + self._warnings.append(short_message) + if kwargs.pop("print", True): + super().log_warning(short_message, long_message) + + def log_error(self, message: str, *args: t.Any, **kwargs: t.Any) -> None: + if message not in self._errors: + self._errors.append(message) + if kwargs.pop("print", True): + super().log_error(message) + + def log_skipped_models(self, snapshot_names: t.Set[str]) -> None: + if snapshot_names: + 
self._captured_outputs.append( + "\n".join([f"SKIPPED snapshot {skipped}\n" for skipped in snapshot_names]) + ) + super().log_skipped_models(snapshot_names) - def log_error(self, message: str) -> None: - self._errors.append(message) - super().log_error(message) + def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: + self._errors.extend([str(ex) for ex in errors if str(ex) not in self._errors]) + super().log_failed_models(errors) def _print(self, value: t.Any, **kwargs: t.Any) -> None: with self.console.capture() as capture: @@ -1529,147 +3291,212 @@ class MarkdownConsole(CaptureTerminalConsole): where you want to display a plan or test results in markdown. """ - def show_model_difference_summary( + CHECK_MARK = "" + AUDIT_PASS_MARK = "passed " + GREEN_AUDIT_PASS_MARK = AUDIT_PASS_MARK + AUDIT_FAIL_MARK = "failed " + AUDIT_PADDING = 7 + + def __init__(self, **kwargs: t.Any) -> None: + self.alert_block_max_content_length = int(kwargs.pop("alert_block_max_content_length", 500)) + self.alert_block_collapsible_threshold = int( + kwargs.pop("alert_block_collapsible_threshold", 200) + ) + + # capture_only = True: capture but dont print to console + # capture_only = False: capture and also print to console + self.warning_capture_only = kwargs.pop("warning_capture_only", False) + self.error_capture_only = kwargs.pop("error_capture_only", False) + + super().__init__( + **{**kwargs, "console": RichConsole(no_color=True, width=kwargs.pop("width", None))} + ) + + def show_environment_difference_summary( self, context_diff: ContextDiff, - environment_naming_info: EnvironmentNamingInfo, - default_catalog: t.Optional[str], no_diff: bool = True, - ignored_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None, ) -> None: - """Shows a summary of the differences. + """Shows a summary of the environment differences. Args: context_diff: The context diff to use to print the summary. 
- environment_naming_info: The environment naming info to reference when printing model names - default_catalog: The default catalog to reference when deciding to remove catalog from display names - no_diff: Hide the actual SQL differences. - ignored_snapshot_ids: A set of snapshot names that are ignored + no_diff: Hide the actual environment statements differences. """ - ignored_snapshot_ids = ignored_snapshot_ids or set() if context_diff.is_new_environment: - self._print( - f"**New environment `{context_diff.environment}` will be created from `{context_diff.create_from}`**\n" + msg = ( + f"\n**`{context_diff.environment}` environment will be initialized**" + if not context_diff.create_from_env_exists + else f"\n**New environment `{context_diff.environment}` will be created from `{context_diff.create_from}`**" ) + self._print(msg) if not context_diff.has_snapshot_changes: return if not context_diff.has_changes: - self._print(f"**No differences when compared to `{context_diff.environment}`**\n") + self._print( + f"\n**No changes to plan: project files match the `{context_diff.environment}` environment**\n" + ) return - self._print(f"**Summary of differences against `{context_diff.environment}`:**\n") + self._print(f"\n**Summary of differences from `{context_diff.environment}`:**") - added_snapshots = { - context_diff.snapshots[s_id] - for s_id in context_diff.added - if s_id not in ignored_snapshot_ids - } - added_snapshot_models = {s for s in added_snapshots if s.is_model} - if added_snapshot_models: + if context_diff.has_requirement_changes: + self._print(f"\nRequirements:\n{context_diff.requirements_diff()}") + + if context_diff.has_environment_statements_changes and not no_diff: + self._print("\nEnvironment statements:\n") + for _, diff in context_diff.environment_statements_diff( + include_python_env=not context_diff.is_new_environment + ): + self._print(diff) + + def show_model_difference_summary( + self, + context_diff: ContextDiff, + 
environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + no_diff: bool = True, + ) -> None: + """Shows a summary of the model differences. + + Args: + context_diff: The context diff to use to print the summary. + environment_naming_info: The environment naming info to reference when printing model names + default_catalog: The default catalog to reference when deciding to remove catalog from display names + no_diff: Hide the actual SQL differences. + """ + added_snapshots = {context_diff.snapshots[s_id] for s_id in context_diff.added} + if added_snapshots: self._print("\n**Added Models:**") - added_models = sorted(added_snapshot_models) - list_length = len(added_models) - if not self.verbose and list_length > self.INDIRECTLY_MODIFIED_DISPLAY_THRESHOLD: - self._print(added_models[0]) - self._print(f"- `.... {list_length-2} more ....`\n") - self._print(added_models[-1]) - else: - for snapshot in added_models: - self._print( - f"- `{snapshot.display_name(environment_naming_info, default_catalog, dialect=self.dialect)}`" - ) + self._print_models_with_threshold( + environment_naming_info, {s for s in added_snapshots if s.is_model}, default_catalog + ) added_snapshot_audits = {s for s in added_snapshots if s.is_audit} if added_snapshot_audits: self._print("\n**Added Standalone Audits:**") for snapshot in sorted(added_snapshot_audits): self._print( - f"- `{snapshot.display_name(environment_naming_info, default_catalog, dialect=self.dialect)}`" + f"- `{snapshot.display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}`" ) - removed_snapshot_table_infos = { - snapshot_table_info - for s_id, snapshot_table_info in context_diff.removed_snapshots.items() - if s_id not in ignored_snapshot_ids - } - removed_model_snapshot_table_infos = {s for s in removed_snapshot_table_infos if s.is_model} - if removed_model_snapshot_table_infos: + removed_snapshot_table_infos = 
set(context_diff.removed_snapshots.values()) + if removed_snapshot_table_infos: self._print("\n**Removed Models:**") - removed_models = sorted(removed_model_snapshot_table_infos) - list_length = len(removed_models) - if not self.verbose and list_length > self.INDIRECTLY_MODIFIED_DISPLAY_THRESHOLD: - self._print(removed_models[0]) - self._print(f"- `.... {list_length-2} more ....`\n") - self._print(removed_models[-1]) - else: - for snapshot_table_info in removed_models: - self._print( - f"- `{snapshot_table_info.display_name(environment_naming_info, default_catalog, dialect=self.dialect)}`" - ) + self._print_models_with_threshold( + environment_naming_info, + {s for s in removed_snapshot_table_infos if s.is_model}, + default_catalog, + ) removed_audit_snapshot_table_infos = {s for s in removed_snapshot_table_infos if s.is_audit} if removed_audit_snapshot_table_infos: self._print("\n**Removed Standalone Audits:**") for snapshot_table_info in sorted(removed_audit_snapshot_table_infos): self._print( - f"- `{snapshot_table_info.display_name(environment_naming_info, default_catalog, dialect=self.dialect)}`" + f"- `{snapshot_table_info.display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}`" ) modified_snapshots = { - current_snapshot - for current_snapshot, _ in context_diff.modified_snapshots.values() - if current_snapshot.snapshot_id not in ignored_snapshot_ids + current_snapshot for current_snapshot, _ in context_diff.modified_snapshots.values() } if modified_snapshots: - directly_modified = [] - indirectly_modified = [] - metadata_modified = [] - for snapshot in modified_snapshots: - if context_diff.directly_modified(snapshot.name): - directly_modified.append(snapshot) - elif context_diff.indirectly_modified(snapshot.name): - indirectly_modified.append(snapshot) - elif context_diff.metadata_updated(snapshot.name): - metadata_modified.append(snapshot) - if directly_modified: - 
self._print("\n**Directly Modified:**") - for snapshot in sorted(directly_modified): - self._print( - f"- `{snapshot.display_name(environment_naming_info, default_catalog, dialect=self.dialect)}`" - ) - if not no_diff: - self._print(f"```diff\n{context_diff.text_diff(snapshot.name)}\n```") - if indirectly_modified: - self._print("\n**Indirectly Modified:**") - indirectly_modified = sorted(indirectly_modified) - modified_length = len(indirectly_modified) - if ( - not self.verbose - and modified_length > self.INDIRECTLY_MODIFIED_DISPLAY_THRESHOLD - ): - self._print( - f"- `{indirectly_modified[0].display_name(environment_naming_info, default_catalog, dialect=self.dialect)}`\n" - f"- `.... {modified_length-2} more ....`\n" - f"- `{indirectly_modified[-1].display_name(environment_naming_info, default_catalog, dialect=self.dialect)}`" - ) - else: - for snapshot in indirectly_modified: + self._print_modified_models( + context_diff, modified_snapshots, environment_naming_info, default_catalog, no_diff + ) + + def _print_models_with_threshold( + self, + environment_naming_info: EnvironmentNamingInfo, + snapshot_table_infos: t.Set[SnapshotInfoLike], + default_catalog: t.Optional[str] = None, + ) -> None: + models = sorted(snapshot_table_infos) + list_length = len(models) + if ( + self.verbosity < Verbosity.VERY_VERBOSE + and list_length > self.INDIRECTLY_MODIFIED_DISPLAY_THRESHOLD + ): + self._print( + f"- `{models[0].display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}`" + ) + self._print(f"- `.... 
{list_length - 2} more ....`\n") + self._print( + f"- `{models[-1].display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}`" + ) + else: + for snapshot_table_info in models: + category_str = SNAPSHOT_CHANGE_CATEGORY_STR[snapshot_table_info.change_category] + self._print( + f"- `{snapshot_table_info.display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}` ({category_str})" + ) + + def _print_modified_models( + self, + context_diff: ContextDiff, + modified_snapshots: t.Set[Snapshot], + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str] = None, + no_diff: bool = True, + ) -> None: + directly_modified = [] + indirectly_modified: t.List[Snapshot] = [] + metadata_modified = [] + for snapshot in modified_snapshots: + if context_diff.directly_modified(snapshot.name): + directly_modified.append(snapshot) + elif context_diff.indirectly_modified(snapshot.name): + indirectly_modified.append(snapshot) + elif context_diff.metadata_updated(snapshot.name): + metadata_modified.append(snapshot) + if directly_modified: + self._print("\n**Directly Modified:**") + for snapshot in sorted(directly_modified): + category_str = SNAPSHOT_CHANGE_CATEGORY_STR[snapshot.change_category] + self._print( + f"* `{snapshot.display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}` ({category_str})" + ) + + indirectly_modified_children = sorted( + [s for s in indirectly_modified if snapshot.snapshot_id in s.parents] + ) + + if not no_diff: + diff_text = context_diff.text_diff(snapshot.name) + # sometimes there is no text_diff, like on a seed model where the data has been updated + if diff_text: + diff_text = f"\n```diff\n{diff_text}\n```" + # these are part of a Markdown list, so indent them by 2 spaces to relate them to the current list item + 
diff_text_indented = "\n".join( + [f" {line}" for line in diff_text.splitlines()] + ) + self._print(diff_text_indented) + else: + if indirectly_modified_children: + self._print("\n") + + if indirectly_modified_children: + self._print(" Indirectly Modified Children:") + for child_snapshot in indirectly_modified_children: + child_category_str = SNAPSHOT_CHANGE_CATEGORY_STR[ + child_snapshot.change_category + ] self._print( - f"- `{snapshot.display_name(environment_naming_info, default_catalog, dialect=self.dialect)}`" + f" - `{child_snapshot.display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}` ({child_category_str})" ) - if metadata_modified: - self._print("\n**Metadata Updated:**") - for snapshot in sorted(metadata_modified): - self._print( - f"- `{snapshot.display_name(environment_naming_info, default_catalog, dialect=self.dialect)}`" - ) - if ignored_snapshot_ids: - self._print("\n**Ignored Models (Expected Plan Start):**") - for s_id in sorted(ignored_snapshot_ids): - snapshot = context_diff.snapshots[s_id] + self._print("\n") + + if indirectly_modified: + self._print("\n**Indirectly Modified:**") + self._print_models_with_threshold( + environment_naming_info, set(indirectly_modified), default_catalog + ) + if metadata_modified: + self._print("\n**Metadata Updated:**") + for snapshot in sorted(metadata_modified): self._print( - f"- `{snapshot.display_name(environment_naming_info, default_catalog, dialect=self.dialect)}` ({snapshot.get_latest(start_date(snapshot, context_diff.snapshots.values()))})" + f"- `{snapshot.display_name(environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}`" ) def _show_missing_dates(self, plan: Plan, default_catalog: t.Optional[str]) -> None: @@ -1677,7 +3504,7 @@ def _show_missing_dates(self, plan: Plan, default_catalog: t.Optional[str]) -> N missing_intervals = plan.missing_intervals if not 
missing_intervals: return - self._print("\n**Models needing backfill (missing dates):**") + self._print("\n**Models needing backfill:**") snapshots = [] for missing in missing_intervals: snapshot = plan.context_diff.snapshots[missing.snapshot_id] @@ -1688,14 +3515,22 @@ def _show_missing_dates(self, plan: Plan, default_catalog: t.Optional[str]) -> N if not plan.deployability_index.is_deployable(snapshot): preview_modifier = " (**preview**)" + display_name = snapshot.display_name( + plan.environment_naming_info, + default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, + ) snapshots.append( - f"* `{snapshot.display_name(plan.environment_naming_info, default_catalog, dialect=self.dialect)}`: {missing.format_intervals(snapshot.node.interval_unit)}{preview_modifier}" + f"* `{display_name}`: \\[{_format_missing_intervals(snapshot, missing)}]{preview_modifier}" ) length = len(snapshots) - if not self.verbose and length > self.INDIRECTLY_MODIFIED_DISPLAY_THRESHOLD: + if ( + self.verbosity < Verbosity.VERY_VERBOSE + and length > self.INDIRECTLY_MODIFIED_DISPLAY_THRESHOLD + ): self._print(snapshots[0]) - self._print(f"- `.... {length-2} more ....`\n") + self._print(f"- `.... 
{length - 2} more ....`\n") self._print(snapshots[-1]) else: for snap in snapshots: @@ -1704,49 +3539,153 @@ def _show_missing_dates(self, plan: Plan, default_catalog: t.Optional[str]) -> N def _show_categorized_snapshots(self, plan: Plan, default_catalog: t.Optional[str]) -> None: context_diff = plan.context_diff for snapshot in plan.categorized: - if not context_diff.directly_modified(snapshot.name): + if context_diff.directly_modified(snapshot.name): + category_str = SNAPSHOT_CHANGE_CATEGORY_STR[snapshot.change_category] + tree = Tree( + f"[bold][direct]Directly Modified: {snapshot.display_name(plan.environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)} ({category_str})" + ) + indirect_tree = None + for child_sid in sorted(plan.indirectly_modified.get(snapshot.snapshot_id, set())): + child_snapshot = context_diff.snapshots[child_sid] + if not indirect_tree: + indirect_tree = Tree("[indirect]Indirectly Modified Children:") + tree.add(indirect_tree) + child_category_str = SNAPSHOT_CHANGE_CATEGORY_STR[ + child_snapshot.change_category + ] + indirect_tree.add( + f"[indirect]{child_snapshot.display_name(plan.environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)} ({child_category_str})" + ) + if indirect_tree: + indirect_tree = self._limit_model_names(indirect_tree, self.verbosity) + elif context_diff.metadata_updated(snapshot.name): + tree = Tree( + f"[bold][metadata]Metadata Updated: {snapshot.display_name(plan.environment_naming_info, default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, dialect=self.dialect)}" + ) + else: continue - category_str = SNAPSHOT_CHANGE_CATEGORY_STR[snapshot.change_category] - tree = Tree( - f"[bold][direct]Directly Modified: {snapshot.display_name(plan.environment_naming_info, default_catalog, dialect=self.dialect)} ({category_str})" - ) - indirect_tree = None - for child_sid in 
sorted(plan.indirectly_modified.get(snapshot.snapshot_id, set())): - child_snapshot = context_diff.snapshots[child_sid] - if not indirect_tree: - indirect_tree = Tree("[indirect]Indirectly Modified Children:") - tree.add(indirect_tree) - child_category_str = SNAPSHOT_CHANGE_CATEGORY_STR[child_snapshot.change_category] - indirect_tree.add( - f"[indirect]{child_snapshot.display_name(plan.environment_naming_info, default_catalog, dialect=self.dialect)} ({child_category_str})" - ) - if indirect_tree: - indirect_tree = self._limit_model_names(indirect_tree, self.verbose) self._print(f"```diff\n{context_diff.text_diff(snapshot.name)}\n```\n") self._print("```\n") self._print(tree) self._print("\n```") - def log_test_results( - self, result: unittest.result.TestResult, output: str, target_dialect: str - ) -> None: - # import ipywidgets as widgets + def stop_evaluation_progress(self, success: bool = True) -> None: + super().stop_evaluation_progress(success) + self._print("\n") + + def stop_creation_progress(self, success: bool = True) -> None: + super().stop_creation_progress(success) + self._print("\n") + + def stop_promotion_progress(self, success: bool = True) -> None: + super().stop_promotion_progress(success) + self._print("\n") + + def log_warning(self, short_message: str, long_message: t.Optional[str] = None) -> None: + super().log_warning(short_message, long_message, print=not self.warning_capture_only) + + def log_error(self, message: str) -> None: + super().log_error(message, print=not self.error_capture_only) + + def log_success(self, message: str) -> None: + self._print(message) + + def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None: + # We don't log the test results if no tests were ran + if not result.testsRun: + return + + message = f"Ran `{result.testsRun}` Tests Against `{target_dialect}`" + if result.wasSuccessful(): - self._print( - f"**Successfully Ran `{str(result.testsRun)}` Tests Against `{target_dialect}`**\n\n" - ) 
+ self._print(f"**Successfully {message}**\n\n") else: - self._print( - f"**Num Successful Tests: {result.testsRun - len(result.failures) - len(result.errors)}**\n\n" - ) - for test, _ in result.failures + result.errors: + self._print("```") + self._log_test_details(result, unittest_char_separator=False) + self._print("```\n\n") + + fail_and_error_tests = result.get_fail_and_error_tests() + self._print(f"**{message}**\n") + self._print(f"**Failed tests ({len(fail_and_error_tests)}):**") + for test in fail_and_error_tests: if isinstance(test, ModelTest): - self._print(f"* Failure Test: `{test.model.name}` - `{test.test_name}`\n\n") - self._print(f"```{output}```\n\n") + self._print(f" • `{test.model.name}`::`{test.test_name}`\n\n") - def log_error(self, message: str) -> None: - super().log_error(f"```\n{message}```\n\n") + def log_skipped_models(self, snapshot_names: t.Set[str]) -> None: + if snapshot_names: + self._print(f"**Skipped models**") + for snapshot_name in snapshot_names: + self._print(f"* `{snapshot_name}`") + self._print("") + + def log_failed_models(self, errors: t.List[NodeExecutionFailedError]) -> None: + if errors: + self._print("**Failed models**") + + error_messages = _format_node_errors(errors) + + for node_name, msg in error_messages.items(): + self._print(f"* `{node_name}`\n") + self._print(" ```") + self._print(msg) + self._print(" ```") + + self._print("") + + def show_linter_violations( + self, violations: t.List[RuleViolation], model: Model, is_error: bool = False + ) -> None: + severity = "**errors**" if is_error else "warnings" + violations_msg = "\n".join(f" - {violation}" for violation in violations) + msg = f"\nLinter {severity} for `{model._path}`:\n{violations_msg}\n" + + self._print(msg) + if is_error: + self._errors.append(msg) + else: + self._warnings.append(msg) + + @property + def captured_warnings(self) -> str: + return self._render_alert_block("WARNING", self._warnings) + + @property + def captured_errors(self) -> str: + 
return self._render_alert_block("CAUTION", self._errors) + + def _render_alert_block(self, block_type: str, items: t.List[str]) -> str: + # GitHub Markdown alert syntax, https://docs.github.com/en/get-started/writing-on-github/getting-started-with-writing-and-formatting-on-github/basic-writing-and-formatting-syntax#alerts + if items: + item_contents = "" + list_indicator = "- " if len(items) > 1 else "" + + for item in items: + item = item.replace("\n", "\n> ") + item_contents += f">\n> {list_indicator}{item}\n" + + if len(item_contents) > self.alert_block_max_content_length: + truncation_msg = ( + "...\n>\n> Truncated. Please check the console for full information.\n" + ) + item_contents = item_contents[ + 0 : self.alert_block_max_content_length - len(truncation_msg) + ] + item_contents += truncation_msg + break + + if len(item_contents) > self.alert_block_collapsible_threshold: + item_contents = f">
\n{item_contents}>
" + + return f"> [!{block_type}]\n{item_contents}\n" + + return "" + + def _print(self, value: t.Any, **kwargs: t.Any) -> None: + self.console.print(value, **kwargs) + with self.console.capture() as capture: + self.console.print(value, **kwargs) + self._captured_outputs.append(capture.get()) class DatabricksMagicConsole(CaptureTerminalConsole): @@ -1767,7 +3706,7 @@ def _print(self, value: t.Any, **kwargs: t.Any) -> None: super()._print(value, **kwargs) for captured_output in self._captured_outputs: print(captured_output) - self.clear_captured_outputs() + self.consume_captured_output() def _prompt(self, message: str, **kwargs: t.Any) -> t.Any: self._print(message) @@ -1780,27 +3719,50 @@ def _confirm(self, message: str, **kwargs: t.Any) -> bool: def start_evaluation_progress( self, - batches: t.Dict[Snapshot, int], + batched_intervals: t.Dict[Snapshot, Intervals], environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], + audit_only: bool = False, ) -> None: - self.evaluation_batches = batches + self.evaluation_model_batch_sizes = { + snapshot: len(intervals) for snapshot, intervals in batched_intervals.items() + } self.evaluation_environment_naming_info = environment_naming_info self.default_catalog = default_catalog - def start_snapshot_evaluation_progress(self, snapshot: Snapshot) -> None: + def start_snapshot_evaluation_progress( + self, snapshot: Snapshot, audit_only: bool = False + ) -> None: if not self.evaluation_batch_progress.get(snapshot.snapshot_id): display_name = snapshot.display_name( - self.evaluation_environment_naming_info, self.default_catalog, dialect=self.dialect + self.evaluation_environment_naming_info, + self.default_catalog if self.verbosity < Verbosity.VERY_VERBOSE else None, + dialect=self.dialect, ) self.evaluation_batch_progress[snapshot.snapshot_id] = (display_name, 0) - print(f"Starting '{display_name}', Total batches: {self.evaluation_batches[snapshot]}") + print( + f"Starting '{display_name}', Total batches: 
{self.evaluation_model_batch_sizes[snapshot]}" + ) def update_snapshot_evaluation_progress( - self, snapshot: Snapshot, batch_idx: int, duration_ms: t.Optional[int] + self, + snapshot: Snapshot, + interval: Interval, + batch_idx: int, + duration_ms: t.Optional[int], + num_audits_passed: int, + num_audits_failed: int, + audit_only: bool = False, + execution_stats: t.Optional[QueryExecutionStats] = None, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: view_name, loaded_batches = self.evaluation_batch_progress[snapshot.snapshot_id] - total_batches = self.evaluation_batches[snapshot] + + if audit_only: + print(f"Completed Auditing {view_name}") + return + + total_batches = self.evaluation_model_batch_sizes[snapshot] loaded_batches += 1 self.evaluation_batch_progress[snapshot.snapshot_id] = (view_name, loaded_batches) @@ -1812,7 +3774,7 @@ def update_snapshot_evaluation_progress( total_finished_loading = len( [ s - for s, total in self.evaluation_batches.items() + for s, total in self.evaluation_model_batch_sizes.items() if self.evaluation_batch_progress.get(s.snapshot_id, (None, -1))[1] == total ] ) @@ -1826,12 +3788,12 @@ def stop_evaluation_progress(self, success: bool = True) -> None: def start_creation_progress( self, - total_tasks: int, + snapshots: t.List[Snapshot], environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], ) -> None: """Indicates that a new creation progress has begun.""" - self.model_creation_status = (0, total_tasks) + self.model_creation_status = (0, len(snapshots)) print("Starting Creating New Model Versions") def update_creation_progress(self, snapshot: SnapshotInfoLike) -> None: @@ -1849,12 +3811,12 @@ def stop_creation_progress(self, success: bool = True) -> None: def start_promotion_progress( self, - total_tasks: int, + snapshots: t.List[SnapshotTableInfo], environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], ) -> None: """Indicates that a new snapshot 
promotion progress has begun.""" - self.promotion_status = (0, total_tasks) + self.promotion_status = (0, len(snapshots)) print(f"Virtually Updating '{environment_naming_info.name}'") def update_promotion_progress(self, snapshot: SnapshotInfoLike, promoted: bool) -> None: @@ -1919,15 +3881,18 @@ def __init__( console: t.Optional[RichConsole], *args: t.Any, dialect: DialectType = None, + ignore_warnings: bool = False, **kwargs: t.Any, ) -> None: self.console: RichConsole = console or srich.console self.dialect = dialect + self.verbosity = Verbosity.DEFAULT + self.ignore_warnings = ignore_warnings def _write(self, msg: t.Any, *args: t.Any, **kwargs: t.Any) -> None: self.console.log(msg, *args, **kwargs) - def start_plan_evaluation(self, plan: Plan) -> None: + def start_plan_evaluation(self, plan: EvaluatablePlan) -> None: self._write("Starting plan", plan.plan_id) def stop_plan_evaluation(self) -> None: @@ -1935,30 +3900,53 @@ def stop_plan_evaluation(self) -> None: def start_evaluation_progress( self, - batches: t.Dict[Snapshot, int], + batched_intervals: t.Dict[Snapshot, Intervals], environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], + audit_only: bool = False, ) -> None: - self._write(f"Starting evaluation for {len(batches)} snapshots") + message = "evaluation" if not audit_only else "auditing" + self._write( + f"Starting {message} for {sum(len(intervals) for intervals in batched_intervals.values())} snapshots" + ) - def start_snapshot_evaluation_progress(self, snapshot: Snapshot) -> None: - self._write(f"Evaluating {snapshot.name}") + def start_snapshot_evaluation_progress( + self, snapshot: Snapshot, audit_only: bool = False + ) -> None: + self._write(f"{'Evaluating' if not audit_only else 'Auditing'} {snapshot.name}") def update_snapshot_evaluation_progress( - self, snapshot: Snapshot, batch_idx: int, duration_ms: t.Optional[int] + self, + snapshot: Snapshot, + interval: Interval, + batch_idx: int, + duration_ms: t.Optional[int], + 
num_audits_passed: int, + num_audits_failed: int, + audit_only: bool = False, + execution_stats: t.Optional[QueryExecutionStats] = None, + auto_restatement_triggers: t.Optional[t.List[SnapshotId]] = None, ) -> None: - self._write(f"Evaluating {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms") + message = f"Evaluated {snapshot.name} | batch={batch_idx} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" + + if auto_restatement_triggers: + message += f" | auto_restatement_triggers=[{', '.join(trigger.name for trigger in auto_restatement_triggers)}]" + + if audit_only: + message = f"Audited {snapshot.name} | duration={duration_ms}ms | num_audits_passed={num_audits_passed} | num_audits_failed={num_audits_failed}" + + self._write(message) def stop_evaluation_progress(self, success: bool = True) -> None: self._write(f"Stopping evaluation with success={success}") def start_creation_progress( self, - total_tasks: int, + snapshots: t.List[Snapshot], environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], ) -> None: - self._write(f"Starting creation for {total_tasks} snapshots") + self._write(f"Starting creation for {len(snapshots)} snapshots") def update_creation_progress(self, snapshot: SnapshotInfoLike) -> None: self._write(f"Creating {snapshot.name}") @@ -1971,11 +3959,12 @@ def update_cleanup_progress(self, object_name: str) -> None: def start_promotion_progress( self, - total_tasks: int, + snapshots: t.List[SnapshotTableInfo], environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], ) -> None: - self._write(f"Starting promotion for {total_tasks} snapshots") + if snapshots: + self._write(f"Starting promotion for {len(snapshots)} snapshots") def update_promotion_progress(self, snapshot: SnapshotInfoLike, promoted: bool) -> None: self._write(f"Promoting {snapshot.name}") @@ -2004,15 +3993,32 @@ def update_env_migration_progress(self, num_tasks: int) -> 
None: def stop_env_migration_progress(self, success: bool = True) -> None: self._write(f"Stopping environment migration with success={success}") + def show_environment_difference_summary( + self, + context_diff: ContextDiff, + no_diff: bool = True, + ) -> None: + self._write("Environment Difference Summary:") + + if context_diff.has_requirement_changes: + self._write(f"Requirements:\n{context_diff.requirements_diff()}") + + if context_diff.has_environment_statements_changes and not no_diff: + self._write("Environment statements:\n") + for _, diff in context_diff.environment_statements_diff( + include_python_env=not context_diff.is_new_environment + ): + self._write(diff) + def show_model_difference_summary( self, context_diff: ContextDiff, environment_naming_info: EnvironmentNamingInfo, default_catalog: t.Optional[str], no_diff: bool = True, - ignored_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None, ) -> None: self._write("Model Difference Summary:") + for added in context_diff.new_snapshots: self._write(f" Added: {added}") for removed in context_diff.removed_snapshots: @@ -2020,9 +4026,7 @@ def show_model_difference_summary( for modified in context_diff.modified_snapshots: self._write(f" Modified: {modified}") - def log_test_results( - self, result: unittest.result.TestResult, output: str, target_dialect: str - ) -> None: + def log_test_results(self, result: ModelTextTestResult, target_dialect: str) -> None: self._write("Test Results:", result) def show_sql(self, sql: str) -> None: @@ -2034,6 +4038,11 @@ def log_status_update(self, message: str) -> None: def log_error(self, message: str) -> None: self._write(message, style="bold red") + def log_warning(self, short_message: str, long_message: t.Optional[str] = None) -> None: + logger.warning(long_message or short_message) + if not self.ignore_warnings: + self._write(short_message, style="bold yellow") + def log_success(self, message: str) -> None: self._write(message, style="bold green") @@ -2052,10 +4061,79 @@ 
def show_row_diff( ) -> None: self._write(row_diff) + def show_table_diff( + self, + table_diffs: t.List[TableDiff], + show_sample: bool = True, + skip_grain_check: bool = False, + temp_schema: t.Optional[str] = None, + ) -> None: + for table_diff in table_diffs: + self.show_table_diff_summary(table_diff) + self.show_schema_diff(table_diff.schema_diff()) + self.show_row_diff( + table_diff.row_diff(temp_schema=temp_schema, skip_grain_check=skip_grain_check), + show_sample=show_sample, + skip_grain_check=skip_grain_check, + ) + + def update_table_diff_progress(self, model: str) -> None: + self._write(f"Finished table diff for: {model}") + + def start_table_diff_progress(self, models_to_diff: int) -> None: + self._write("Table diff started") + + def start_table_diff_model_progress(self, model: str) -> None: + self._write(f"Calculating differences for: {model}") + + def stop_table_diff_progress(self, success: bool) -> None: + self._write(f"Table diff finished with success={success}") + + def show_table_diff_details( + self, + models_to_diff: t.List[str], + ) -> None: + if models_to_diff: + models = "\n".join(models_to_diff) + self._write(f"Models to compare: {models}") + + def show_table_diff_summary(self, table_diff: TableDiff) -> None: + if table_diff.model_name: + self._write(f"Model: {table_diff.model_name}") + self._write(f"Source env: {table_diff.source_alias}") + self._write(f"Target env: {table_diff.target_alias}") + self._write(f"Source table: {table_diff.source}") + self._write(f"Target table: {table_diff.target}") + _, _, key_column_names = table_diff.key_columns + keys = ", ".join(key_column_names) + self._write(f"Join On: {keys}") + + +_CONSOLE: Console = NoopConsole() + + +def set_console(console: Console) -> None: + """Sets the console instance.""" + global _CONSOLE + _CONSOLE = console + -def get_console(**kwargs: t.Any) -> TerminalConsole | DatabricksMagicConsole | NotebookMagicConsole: +def configure_console(**kwargs: t.Any) -> None: + """Configures 
the console instance.""" + global _CONSOLE + _CONSOLE = create_console(**kwargs) + + +def get_console() -> Console: + """Returns the console instance or creates a new one if it hasn't been created yet.""" + return _CONSOLE + + +def create_console( + **kwargs: t.Any, +) -> TerminalConsole | DatabricksMagicConsole | NotebookMagicConsole: """ - Returns the console that is appropriate for the current runtime environment. + Creates a new console instance that is appropriate for the current runtime environment. Note: Google Colab environment is untested and currently assumes is compatible with the base NotebookMagicConsole. @@ -2070,6 +4148,7 @@ def get_console(**kwargs: t.Any) -> TerminalConsole | DatabricksMagicConsole | N RuntimeEnv.TERMINAL: TerminalConsole, RuntimeEnv.GOOGLE_COLAB: NotebookMagicConsole, RuntimeEnv.DEBUGGER: DebuggerTerminalConsole, + RuntimeEnv.CI: MarkdownConsole, } rich_console_kwargs: t.Dict[str, t.Any] = {"theme": srich.theme} if runtime_env.is_jupyter or runtime_env.is_google_colab: @@ -2077,3 +4156,224 @@ def get_console(**kwargs: t.Any) -> TerminalConsole | DatabricksMagicConsole | N return runtime_env_mapping[runtime_env]( **{**{"console": RichConsole(**rich_console_kwargs)}, **kwargs} ) + + +def _format_missing_intervals(snapshot: Snapshot, missing: SnapshotIntervals) -> str: + return ( + missing.format_intervals(snapshot.node.interval_unit) + if snapshot.is_incremental + else "recreate view" + if snapshot.is_view + else "full refresh" + ) + + +def _format_node_errors(errors: t.List[NodeExecutionFailedError]) -> t.Dict[str, str]: + """Formats a list of node execution errors for display.""" + + def _format_node_error(ex: NodeExecutionFailedError) -> str: + cause = ex.__cause__ if ex.__cause__ else ex + + error_msg = str(cause) + + if isinstance(cause, NodeAuditsErrors): + error_msg = _format_audits_errors(cause) + elif not isinstance(cause, (NodeExecutionFailedError, PythonModelEvalError)): + error_msg = " " + error_msg.replace("\n", "\n ") 
+ error_msg = ( + f" {cause.__class__.__name__}:\n{error_msg}" # include error class name in msg + ) + error_msg = error_msg.replace("\n", "\n ") + error_msg = error_msg + "\n" if not error_msg.rstrip(" ").endswith("\n") else error_msg + + return error_msg + + error_messages = {} + + num_fails = len(errors) + for i, error in enumerate(errors): + node_name = "" + if isinstance(error.node, SnapshotId): + node_name = error.node.name + elif hasattr(error.node, "snapshot_name"): + node_name = error.node.snapshot_name + + msg = _format_node_error(error) + msg = " " + msg.replace("\n", "\n ") + if i == (num_fails - 1): + msg = msg if msg.rstrip(" ").endswith("\n") else msg + "\n" + + error_messages[node_name] = msg + + return error_messages + + +def _format_audits_errors(error: NodeAuditsErrors) -> str: + error_messages = [] + for err in error.errors: + audit_args_sql = [] + for arg_name, arg_value in err.audit_args.items(): + audit_args_sql.append(f"{arg_name} := {arg_value.sql(dialect=err.adapter_dialect)}") + audit_args_sql_msg = ("\n".join(audit_args_sql) + "\n\n") if audit_args_sql else "" + + err_msg = f"'{err.audit_name}' audit error: {err.count} {'row' if err.count == 1 else 'rows'} failed" + + query = "\n ".join(textwrap.wrap(err.sql(err.adapter_dialect), width=LINE_WRAP_WIDTH)) + msg = f"{err_msg}\n\nAudit arguments\n {audit_args_sql_msg}Audit query\n {query}\n\n" + msg = msg.replace("\n", "\n ") + error_messages.append(msg) + return " " + "\n".join(error_messages) + + +def _format_interval(snapshot: Snapshot, interval: Interval) -> str: + """Format an interval with an optional prefix.""" + inclusive_interval = make_inclusive(interval[0], interval[1]) + if snapshot.model.interval_unit.is_date_granularity: + return f"{to_ds(inclusive_interval[0])} - {to_ds(inclusive_interval[1])}" + + if inclusive_interval[0].date() == inclusive_interval[1].date(): + # omit end date if interval start/end on same day + return f"{to_ds(inclusive_interval[0])} 
{inclusive_interval[0].strftime('%H:%M:%S')}-{inclusive_interval[1].strftime('%H:%M:%S')}" + + return f"{inclusive_interval[0].strftime('%Y-%m-%d %H:%M:%S')} - {inclusive_interval[1].strftime('%Y-%m-%d %H:%M:%S')}" + + +def _format_signal_interval(snapshot: Snapshot, interval: Interval) -> str: + """Format an interval for signal output (without 'insert' prefix).""" + return _format_interval(snapshot, interval) + + +def _format_evaluation_model_interval(snapshot: Snapshot, interval: Interval) -> str: + """Format an interval for evaluation output (with 'insert' prefix).""" + if snapshot.is_model and ( + snapshot.model.kind.is_incremental + or snapshot.model.kind.is_managed + or snapshot.model.kind.is_custom + ): + formatted_interval = _format_interval(snapshot, interval) + return f"insert {formatted_interval}" + + return "" + + +def _create_evaluation_model_annotation( + snapshot: Snapshot, + interval_info: t.Optional[str], + execution_stats: t.Optional[QueryExecutionStats], +) -> str: + annotation = None + execution_stats_str = "" + if execution_stats: + rows_processed = execution_stats.total_rows_processed + if rows_processed: + # 1.00 and 1.0 to 1 + rows_processed_str = metric(rows_processed).replace(".00", "").replace(".0", "") + execution_stats_str += f"{rows_processed_str} row{'s' if rows_processed > 1 else ''}" + + bytes_processed = execution_stats.total_bytes_processed + execution_stats_str += ( + f"{', ' if execution_stats_str else ''}{naturalsize(bytes_processed, binary=True)}" + if bytes_processed + else "" + ) + execution_stats_str = f" ({execution_stats_str})" if execution_stats_str else "" + + if snapshot.is_audit: + annotation = "run standalone audit" + if snapshot.is_model: + if snapshot.model.kind.is_external: + annotation = "run external audits" + if snapshot.model.kind.is_view: + annotation = "recreate view" + if snapshot.model.kind.is_seed: + annotation = f"insert seed file{execution_stats_str}" + if snapshot.model.kind.is_full: + annotation = 
f"full refresh{execution_stats_str}" + if snapshot.model.kind.is_incremental_by_unique_key: + annotation = f"insert/update rows{execution_stats_str}" + if snapshot.model.kind.is_incremental_by_partition: + annotation = f"insert partitions{execution_stats_str}" + + if annotation: + return annotation + + return f"{interval_info}{execution_stats_str}" if interval_info else "" + + +def _calculate_interval_str_len( + snapshot: Snapshot, + intervals: t.List[Interval], + execution_stats: t.Optional[QueryExecutionStats] = None, +) -> int: + interval_str_len = 0 + for interval in intervals: + interval_str_len = max( + interval_str_len, + len( + _create_evaluation_model_annotation( + snapshot, _format_evaluation_model_interval(snapshot, interval), execution_stats + ) + ), + ) + return interval_str_len + + +def _calculate_audit_str_len(snapshot: Snapshot, audit_padding: int = 0) -> int: + # The annotation includes audit results. We cannot build the audits result string + # until after evaluation occurs, but we must determine the annotation column width here. + # Therefore, we add enough padding for the longest possible audits result string. 
+ audit_str_len = 0 + audit_base_str_len = len(f", audits ") + 1 # +1 for check/X + if snapshot.is_audit: + # +1 for "1" audit count, +1 for red X + audit_str_len = max( + audit_str_len, audit_base_str_len + (2 if not snapshot.audit.blocking else 1) + ) + if snapshot.is_model and snapshot.model.audits: + num_audits = len(snapshot.model.audits_with_args) + num_nonblocking_audits = sum( + 1 + for audit in snapshot.model.audits_with_args + if not audit[0].blocking + or ("blocking" in audit[1] and audit[1]["blocking"] == exp.false()) + ) + if num_audits == 1: + # +1 for "1" audit count, +1 for red X + # if audit_padding is > 0 we're using "failed" instead of red X + audit_len = ( + audit_base_str_len + + (2 if num_nonblocking_audits else 1) + + ( + audit_padding - 1 + if num_nonblocking_audits and audit_padding > 0 + else audit_padding + ) + ) + else: + audit_len = audit_base_str_len + len(str(num_audits)) + audit_padding + if num_nonblocking_audits: + # +1 for space, +1 for red X + # if audit_padding is > 0 we're using "failed" instead of red X + audit_len += ( + len(str(num_nonblocking_audits)) + + 2 + + (audit_padding - 1 if audit_padding > 0 else audit_padding) + ) + audit_str_len = max(audit_str_len, audit_len) + return audit_str_len + + +def _calculate_annotation_str_len( + batched_intervals: t.Dict[Snapshot, t.List[Interval]], + audit_padding: int = 0, + execution_stats_len: int = 0, +) -> int: + annotation_str_len = 0 + for snapshot, intervals in batched_intervals.items(): + annotation_str_len = max( + annotation_str_len, + _calculate_interval_str_len(snapshot, intervals) + + _calculate_audit_str_len(snapshot, audit_padding) + + execution_stats_len, + ) + return annotation_str_len diff --git a/sqlmesh/core/constants.py b/sqlmesh/core/constants.py index 593075acb3..66dadb0b5d 100644 --- a/sqlmesh/core/constants.py +++ b/sqlmesh/core/constants.py @@ -1,10 +1,14 @@ from __future__ import annotations import datetime +import multiprocessing as mp +import os +import 
typing as t from pathlib import Path SQLMESH = "sqlmesh" -SQLMESH_PATH = Path.home() / ".sqlmesh" +SQLMESH_MANAGED = "sqlmesh_managed" +SQLMESH_PATH = Path(os.getenv("SQLMESH_HOME") or Path.home() / ".sqlmesh") PROD = "prod" """Prod""" @@ -28,6 +32,21 @@ MAX_MODEL_DEFINITION_SIZE = 10000 """Maximum number of characters in a model definition""" + +# The maximum number of fork processes, used for loading projects +# None means default to process pool, 1 means don't fork, :N is number of processes +# Factors in the number of available CPUs even if the process is bound to a subset of them +# (e.g. via taskset) to avoid oversubscribing the system and causing kill signals +if hasattr(os, "fork") and not mp.current_process().daemon: + try: + MAX_FORK_WORKERS: t.Optional[int] = int(os.getenv("MAX_FORK_WORKERS")) # type: ignore + except TypeError: + MAX_FORK_WORKERS = ( + len(os.sched_getaffinity(0)) if hasattr(os, "sched_getaffinity") else None # type: ignore + ) +else: + MAX_FORK_WORKERS = 1 + EPOCH = datetime.date(1970, 1, 1) DEFAULT_MAX_LIMIT = 1000 @@ -42,20 +61,28 @@ AUDITS = "audits" CACHE = ".cache" EXTERNAL_MODELS = "external_models" +LINTER = "linter" MACROS = "macros" MATERIALIZATIONS = "materializations" METRICS = "metrics" MODELS = "models" SEEDS = "seeds" +SIGNALS = "signals" TESTS = "tests" EXTERNAL_MODELS_YAML = "external_models.yaml" EXTERNAL_MODELS_DEPRECATED_YAML = "schema.yaml" +REQUIREMENTS = "sqlmesh-requirements.lock" DEFAULT_SCHEMA = "default" SQLMESH_VARS = "__sqlmesh__vars__" +SQLMESH_VARS_METADATA = "__sqlmesh__vars__metadata__" +SQLMESH_BLUEPRINT_VARS = "__sqlmesh__blueprint__vars__" +SQLMESH_BLUEPRINT_VARS_METADATA = "__sqlmesh__blueprint__vars__metadata__" + VAR = "var" +BLUEPRINT_VAR = "blueprint_var" GATEWAY = "gateway" SQLMESH_MACRO = "__sqlmesh__macro__" @@ -64,6 +91,8 @@ BUILTIN = "builtin" -AIRFLOW = "airflow" DBT = "dbt" NATIVE = "native" +HYBRID = "hybrid" + +DISABLE_SQLMESH_STATE_MIGRATION = "SQLMESH__AIRFLOW__DISABLE_STATE_MIGRATION" 
diff --git a/sqlmesh/core/context.py b/sqlmesh/core/context.py index 06b865067b..860194278b 100644 --- a/sqlmesh/core/context.py +++ b/sqlmesh/core/context.py @@ -11,7 +11,7 @@ Creating and applying a plan against the staging environment. ```python from sqlmesh.core.context import Context -context = Context(path="example", config="local_config") +context = Context(paths="example", config="local_config") plan = context.plan("staging") context.apply(plan) ``` @@ -19,14 +19,14 @@ Running audits on your data. ```python from sqlmesh.core.context import Context -context = Context(path="example", config="local_config") +context = Context(paths="example", config="local_config") context.audit("yesterday", "now") ``` Running tests on your models. ```python from sqlmesh.core.context import Context -context = Context(path="example") +context = Context(paths="example") context.test() ``` """ @@ -35,92 +35,119 @@ import abc import collections -import gc import logging +import sys import time import traceback import typing as t -import unittest.result -from datetime import timedelta from functools import cached_property from io import StringIO +from itertools import chain from pathlib import Path from shutil import rmtree from types import MappingProxyType +from datetime import datetime -import pandas as pd from sqlglot import Dialect, exp +from sqlglot.helper import first from sqlglot.lineage import GraphHTML from sqlmesh.core import analytics from sqlmesh.core import constants as c from sqlmesh.core.analytics import python_api_analytics -from sqlmesh.core.audit import Audit, StandaloneAudit -from sqlmesh.core.config import CategorizerConfig, Config, load_configs +from sqlmesh.core.audit import Audit, ModelAudit, StandaloneAudit +from sqlmesh.core.config import ( + CategorizerConfig, + Config, + load_configs, +) +from sqlmesh.core.config.connection import ConnectionConfig from sqlmesh.core.config.loader import C -from sqlmesh.core.console import Console, get_console +from 
sqlmesh.core.config.root import RegexKeyDict +from sqlmesh.core.console import get_console from sqlmesh.core.context_diff import ContextDiff from sqlmesh.core.dialect import ( format_model_expressions, + is_meta_expression, normalize_model_name, pandas_to_sql, parse, parse_one, ) from sqlmesh.core.engine_adapter import EngineAdapter -from sqlmesh.core.environment import Environment, EnvironmentNamingInfo -from sqlmesh.core.loader import Loader, update_model_schemas +from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements +from sqlmesh.core.loader import Loader +from sqlmesh.core.linter.definition import AnnotatedRuleViolation, Linter +from sqlmesh.core.linter.rules import BUILTIN_RULES from sqlmesh.core.macros import ExecutableOrMacro, macro from sqlmesh.core.metric import Metric, rewrite -from sqlmesh.core.model import Model +from sqlmesh.core.model import Model, update_model_schemas +from sqlmesh.core.config.model import ModelDefaultsConfig from sqlmesh.core.notification_target import ( NotificationEvent, NotificationTarget, NotificationTargetManager, ) -from sqlmesh.core.plan import Plan, PlanBuilder +from sqlmesh.core.plan import Plan, PlanBuilder, SnapshotIntervals, PlanExplainer +from sqlmesh.core.plan.definition import UserProvidedFlags from sqlmesh.core.reference import ReferenceGraph -from sqlmesh.core.scheduler import Scheduler, SignalFactory +from sqlmesh.core.scheduler import Scheduler, CompletionStatus from sqlmesh.core.schema_loader import create_external_models_file -from sqlmesh.core.selector import Selector +from sqlmesh.core.selector import Selector, NativeSelector from sqlmesh.core.snapshot import ( DeployabilityIndex, Snapshot, SnapshotEvaluator, SnapshotFingerprint, + missing_intervals, to_table_mapping, ) +from sqlmesh.core.snapshot.definition import get_next_model_interval_start from sqlmesh.core.state_sync import ( CachingStateSync, StateReader, StateSync, - cleanup_expired_views, ) +from 
sqlmesh.core.janitor import cleanup_expired_views, delete_expired_snapshots from sqlmesh.core.table_diff import TableDiff from sqlmesh.core.test import ( ModelTextTestResult, + ModelTestMetadata, generate_test, - get_all_model_tests, - run_model_tests, run_tests, + filter_tests_by_patterns, ) from sqlmesh.core.user import User -from sqlmesh.utils import UniqueKeyDict, sys_path +from sqlmesh.utils import UniqueKeyDict, Verbosity +from sqlmesh.utils.concurrency import concurrent_apply_to_values from sqlmesh.utils.dag import DAG -from sqlmesh.utils.date import TimeLike, now_ds, to_date +from sqlmesh.utils.date import ( + TimeLike, + to_timestamp, + format_tz_datetime, + now_timestamp, + now, + to_datetime, + make_exclusive, +) from sqlmesh.utils.errors import ( CircuitBreakerError, ConfigError, PlanError, SQLMeshError, UncategorizedPlanError, + LinterError, ) +from sqlmesh.utils.config import print_config from sqlmesh.utils.jinja import JinjaMacroRegistry +from sqlmesh.utils.windows import IS_WINDOWS, fix_windows_path if t.TYPE_CHECKING: + import pandas as pd from typing_extensions import Literal from sqlmesh.core.engine_adapter._typing import ( + BigframeSession, DF, PySparkDataFrame, PySparkSession, @@ -128,6 +155,8 @@ ) from sqlmesh.core.snapshot import Node + from sqlmesh.core.snapshot.definition import Intervals + ModelOrSnapshot = t.Union[str, Model, Snapshot] NodeOrSnapshot = t.Union[str, Model, StandaloneAudit, Snapshot] @@ -162,11 +191,23 @@ def snowpark(self) -> t.Optional[SnowparkSession]: """Returns the snowpark session if it exists.""" return self.engine_adapter.snowpark + @property + def bigframe(self) -> t.Optional[BigframeSession]: + """Returns the bigframe session if it exists.""" + return self.engine_adapter.bigframe + @property def default_catalog(self) -> t.Optional[str]: raise NotImplementedError def table(self, model_name: str) -> str: + get_console().log_warning( + "The SQLMesh context's `table` method is deprecated and will be removed " + "in a 
future release. Please use the `resolve_table` method instead." + ) + return self.resolve_table(model_name) + + def resolve_table(self, model_name: str) -> str: """Gets the physical table name for a given model. Args: @@ -177,9 +218,20 @@ def table(self, model_name: str) -> str: """ model_name = normalize_model_name(model_name, self.default_catalog, self.default_dialect) + if model_name not in self._model_tables: + model_name_list = "\n".join(list(self._model_tables)) + logger.debug( + f"'{model_name}' not found in model to table mapping. Available model names: \n{model_name_list}" + ) + raise SQLMeshError( + f"Unable to find a table mapping for model '{model_name}'. Has it been spelled correctly?" + ) + # We generate SQL for the default dialect because the table name may be used in a # fetchdf call and so the quotes need to be correct (eg. backticks for bigquery) - return parse_one(self._model_tables[model_name]).sql(dialect=self.default_dialect) + return parse_one(self._model_tables[model_name]).sql( + dialect=self.default_dialect, identify=True + ) def fetchdf( self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False @@ -226,7 +278,10 @@ def __init__( deployability_index: t.Optional[DeployabilityIndex] = None, default_dialect: t.Optional[str] = None, default_catalog: t.Optional[str] = None, + is_restatement: t.Optional[bool] = None, + parent_intervals: t.Optional[Intervals] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, ): self.snapshots = snapshots self.deployability_index = deployability_index @@ -234,6 +289,9 @@ def __init__( self._default_catalog = default_catalog self._default_dialect = default_dialect self._variables = variables or {} + self._blueprint_variables = blueprint_variables or {} + self._is_restatement = is_restatement + self._parent_intervals = parent_intervals @property def default_dialect(self) -> t.Optional[str]: @@ -258,11 +316,27 @@ def gateway(self) 
-> t.Optional[str]: """Returns the gateway name.""" return self.var(c.GATEWAY) + @property + def is_restatement(self) -> t.Optional[bool]: + return self._is_restatement + + @property + def parent_intervals(self) -> t.Optional[Intervals]: + return self._parent_intervals + def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: """Returns a variable value.""" return self._variables.get(var_name.lower(), default) - def with_variables(self, variables: t.Dict[str, t.Any]) -> ExecutionContext: + def blueprint_var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: + """Returns a blueprint variable value.""" + return self._blueprint_variables.get(var_name.lower(), default) + + def with_variables( + self, + variables: t.Dict[str, t.Any], + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, + ) -> ExecutionContext: """Returns a new ExecutionContext with additional variables.""" return ExecutionContext( self._engine_adapter, @@ -270,7 +344,9 @@ def with_variables(self, variables: t.Dict[str, t.Any]) -> ExecutionContext: self.deployability_index, self._default_dialect, self._default_catalog, + self._is_restatement, variables=variables, + blueprint_variables=blueprint_variables, ) @@ -278,7 +354,6 @@ class GenericContext(BaseContext, t.Generic[C]): """Encapsulates a SQLMesh environment supplying convenient functions to perform various tasks. Args: - engine_adapter: The default engine adapter to use. notification_targets: The notification target to use. Defaults to what is defined in config. paths: The directories containing SQLMesh files. config: A Config object or the name of a Config object in config.py. @@ -290,14 +365,16 @@ class GenericContext(BaseContext, t.Generic[C]): load: Whether or not to automatically load all models and macros (default True). console: The rich instance used for printing out CLI command results. users: A list of users to make known to SQLMesh. 
- config_type: The type of config object to use (default Config). """ CONFIG_TYPE: t.Type[C] + """The type of config object to use (default: Config).""" + + PLAN_BUILDER_TYPE = PlanBuilder + """The type of plan builder object to use (default: PlanBuilder).""" def __init__( self, - engine_adapter: t.Optional[EngineAdapter] = None, notification_targets: t.Optional[t.List[NotificationTarget]] = None, state_sync: t.Optional[StateSync] = None, paths: t.Union[str | Path, t.Iterable[str | Path]] = "", @@ -306,35 +383,42 @@ def __init__( concurrent_tasks: t.Optional[int] = None, loader: t.Optional[t.Type[Loader]] = None, load: bool = True, - console: t.Optional[Console] = None, users: t.Optional[t.List[User]] = None, - signal_factory: t.Optional[SignalFactory] = None, + config_loader_kwargs: t.Optional[t.Dict[str, t.Any]] = None, + selector: t.Optional[t.Type[Selector]] = None, ): self.configs = ( - config if isinstance(config, dict) else load_configs(config, self.CONFIG_TYPE, paths) + config + if isinstance(config, dict) + else load_configs(config, self.CONFIG_TYPE, paths, **(config_loader_kwargs or {})) ) + self._projects = {config.project for config in self.configs.values()} self.dag: DAG[str] = DAG() self._models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") - self._audits: UniqueKeyDict[str, Audit] = UniqueKeyDict("audits") + self._audits: UniqueKeyDict[str, ModelAudit] = UniqueKeyDict("audits") self._standalone_audits: UniqueKeyDict[str, StandaloneAudit] = UniqueKeyDict( "standaloneaudits" ) + self._model_test_metadata: t.List[ModelTestMetadata] = [] + self._model_test_metadata_path_index: t.Dict[Path, t.List[ModelTestMetadata]] = {} + self._model_test_metadata_fully_qualified_name_index: t.Dict[str, ModelTestMetadata] = {} + self._models_with_tests: t.Set[str] = set() + self._macros: UniqueKeyDict[str, ExecutableOrMacro] = UniqueKeyDict("macros") self._metrics: UniqueKeyDict[str, Metric] = UniqueKeyDict("metrics") self._jinja_macros = JinjaMacroRegistry() - 
self._default_catalog: t.Optional[str] = None + self._requirements: t.Dict[str, str] = {} + self._environment_statements: t.List[EnvironmentStatements] = [] + self._excluded_requirements: t.Set[str] = set() + self._engine_adapter: t.Optional[EngineAdapter] = None + self._linters: t.Dict[str, Linter] = {} + self._loaded: bool = False + self._selector_cls = selector or NativeSelector self.path, self.config = t.cast(t.Tuple[Path, C], next(iter(self.configs.items()))) self._all_dialects: t.Set[str] = {self.config.dialect or ""} - # This allows overriding the default dialect's normalization strategy, so for example - # one can do `dialect="duckdb,normalization_strategy=lowercase"` and this will be - # applied to the DuckDB dialect globally - if "normalization_strategy" in str(self.config.dialect): - dialect = Dialect.get_or_raise(self.config.dialect) - type(dialect).NORMALIZATION_STRATEGY = dialect.normalization_strategy - if self.config.disable_anonymized_analytics: analytics.disable_analytics() @@ -343,39 +427,49 @@ def __init__( self.environment_ttl = self.config.environment_ttl self.pinned_environments = Environment.sanitize_names(self.config.pinned_environments) self.auto_categorize_changes = self.config.plan.auto_categorize_changes + self.selected_gateway = (gateway or self.config.default_gateway_name).lower() + + gw_model_defaults = self.config.get_gateway(self.selected_gateway).model_defaults + if gw_model_defaults: + # Merge global model defaults with the selected gateway's, if it's overriden + global_defaults = self.config.model_defaults.model_dump(exclude_unset=True) + gateway_defaults = gw_model_defaults.model_dump(exclude_unset=True) + + self.config.model_defaults = ModelDefaultsConfig( + **{**global_defaults, **gateway_defaults} + ) - self._connection_config = self.config.get_connection(self.gateway) - self.concurrent_tasks = concurrent_tasks or self._connection_config.concurrent_tasks - self._engine_adapter = engine_adapter or 
self._connection_config.create_engine_adapter() + # This allows overriding the default dialect's normalization strategy, so for example + # one can do `dialect="duckdb,normalization_strategy=lowercase"` and this will be + # applied to the DuckDB dialect globally + if "normalization_strategy" in str(self.config.dialect): + dialect = Dialect.get_or_raise(self.config.dialect) + type(dialect).NORMALIZATION_STRATEGY = dialect.normalization_strategy - self.console = console or get_console(dialect=self._engine_adapter.dialect) + self._loaders = [ + (loader or config.loader)(self, path, **config.loader_kwargs) + for path, config in self.configs.items() + ] - self._test_connection_config = self.config.get_test_connection( - self.gateway, self.default_catalog, default_catalog_dialect=self.engine_adapter.DIALECT + self._concurrent_tasks = concurrent_tasks + self._state_connection_config = ( + self.config.get_state_connection(self.gateway) or self.connection_config ) self._snapshot_evaluator: t.Optional[SnapshotEvaluator] = None - self._signal_factory = signal_factory + + self.console = get_console() + setattr(self.console, "dialect", self.config.dialect) self._provided_state_sync: t.Optional[StateSync] = state_sync self._state_sync: t.Optional[StateSync] = None - self._loader = (loader or self.config.loader)(**self.config.loader_kwargs) - # Should we dedupe notification_targets? If so how? 
self.notification_targets = (notification_targets or []) + self.config.notification_targets self.users = (users or []) + self.config.users self.users = list({user.username: user for user in self.users}.values()) self._register_notification_targets() - if ( - self.config.environment_catalog_mapping - and not self.engine_adapter.CATALOG_SUPPORT.is_multi_catalog_supported - ): - raise SQLMeshError( - "Environment catalog mapping is only supported for engine adapters that support multiple catalogs" - ) - if load: self.load() @@ -385,25 +479,34 @@ def default_dialect(self) -> t.Optional[str]: @property def engine_adapter(self) -> EngineAdapter: - """Returns an engine adapter.""" + """Returns the default engine adapter.""" + if self._engine_adapter is None: + self._engine_adapter = self.connection_config.create_engine_adapter() return self._engine_adapter @property def snapshot_evaluator(self) -> SnapshotEvaluator: if not self._snapshot_evaluator: self._snapshot_evaluator = SnapshotEvaluator( - self.engine_adapter.with_log_level(logging.INFO), + { + gateway: adapter.with_settings(execute_log_level=logging.INFO) + for gateway, adapter in self.engine_adapters.items() + }, ddl_concurrent_tasks=self.concurrent_tasks, + selected_gateway=self.selected_gateway, ) return self._snapshot_evaluator def execution_context( - self, deployability_index: t.Optional[DeployabilityIndex] = None + self, + deployability_index: t.Optional[DeployabilityIndex] = None, + engine_adapter: t.Optional[EngineAdapter] = None, + snapshots: t.Optional[t.Dict[str, Snapshot]] = None, ) -> ExecutionContext: """Returns an execution context.""" return ExecutionContext( - engine_adapter=self._engine_adapter, - snapshots=self.snapshots, + engine_adapter=engine_adapter or self.engine_adapter, + snapshots=snapshots or self.snapshots, deployability_index=deployability_index, default_dialect=self.default_dialect, default_catalog=self.default_catalog, @@ -427,16 +530,23 @@ def upsert_model(self, model: t.Union[str, 
Model], **kwargs: t.Any) -> Model: raise SQLMeshError(f"The disabled model '{model.name}' cannot be upserted") path = model._path - # model.copy() can't be used here due to a cached state that can be a part of a model instance. - model = t.cast(Model, type(model)(**{**t.cast(Model, model).dict(), **kwargs})) + model = model.copy(update=kwargs) model._path = path - self._models.update({model.fqn: model}) self.dag.add(model.fqn, model.depends_on) + + self._models.update( + { + model.fqn: model, + # bust the fingerprint cache for all downstream models + **{fqn: self._models[fqn].copy() for fqn in self.dag.downstream(model.fqn)}, + } + ) + update_model_schemas( self.dag, - self._models, - self.path, + models=self._models, + cache_dir=self.cache_dir, ) if model.dialect: @@ -446,7 +556,11 @@ def upsert_model(self, model: t.Union[str, Model], **kwargs: t.Any) -> Model: return model - def scheduler(self, environment: t.Optional[str] = None) -> Scheduler: + def scheduler( + self, + environment: t.Optional[str] = None, + snapshot_evaluator: t.Optional[SnapshotEvaluator] = None, + ) -> Scheduler: """Returns the built-in scheduler. Args: @@ -468,15 +582,27 @@ def scheduler(self, environment: t.Optional[str] = None) -> Scheduler: if not snapshots: raise ConfigError("No models were found") + return self.create_scheduler(snapshots, snapshot_evaluator or self.snapshot_evaluator) + + def create_scheduler( + self, snapshots: t.Iterable[Snapshot], snapshot_evaluator: SnapshotEvaluator + ) -> Scheduler: + """Creates the built-in scheduler. + + Args: + snapshots: The snapshots to schedule. + + Returns: + The built-in scheduler instance. 
+ """ return Scheduler( snapshots, - self.snapshot_evaluator, + snapshot_evaluator, self.state_sync, default_catalog=self.default_catalog, max_workers=self.concurrent_tasks, console=self.console, notification_target_manager=self.notification_target_manager, - signal_factory=self._signal_factory, ) @property @@ -485,7 +611,8 @@ def state_sync(self) -> StateSync: self._state_sync = self._new_state_sync() if self._state_sync.get_versions(validate=False).schema_version == 0: - self._state_sync.migrate(default_catalog=self.default_catalog) + self.console.log_status_update("Initializing new project state...") + self._state_sync.migrate() self._state_sync.get_versions() self._state_sync = CachingStateSync(self._state_sync) # type: ignore return self._state_sync @@ -496,43 +623,122 @@ def state_reader(self) -> StateReader: def refresh(self) -> None: """Refresh all models that have been updated.""" - if self._loader.reload_needed(): + if any(loader.reload_needed() for loader in self._loaders): self.load() def load(self, update_schemas: bool = True) -> GenericContext[C]: """Load all files in the context's path.""" load_start_ts = time.perf_counter() - with sys_path(*self.configs): - gc.disable() - project = self._loader.load(self, update_schemas) - self._macros = project.macros - self._jinja_macros = project.jinja_macros - self._models = project.models - self._metrics = project.metrics - self._standalone_audits.clear() - self._audits.clear() - for name, audit in project.audits.items(): - if isinstance(audit, StandaloneAudit): - self._standalone_audits[name] = audit - else: - self._audits[name] = audit - self.dag = project.dag - gc.enable() - - duplicates = set(self._models) & set(self._standalone_audits) - if duplicates: - raise ConfigError( - f"Models and Standalone audits cannot have the same name: {duplicates}" - ) - self._all_dialects = {m.dialect for m in self._models.values() if m.dialect} | { - self.default_dialect or "" - } + loaded_projects = [loader.load() for 
loader in self._loaders] + + self.dag = DAG() + self._standalone_audits.clear() + self._audits.clear() + self._macros.clear() + self._models.clear() + self._metrics.clear() + self._requirements.clear() + self._excluded_requirements.clear() + self._linters.clear() + self._environment_statements = [] + self._model_test_metadata.clear() + self._model_test_metadata_path_index.clear() + self._model_test_metadata_fully_qualified_name_index.clear() + self._models_with_tests.clear() + + for loader, project in zip(self._loaders, loaded_projects): + self._jinja_macros = self._jinja_macros.merge(project.jinja_macros) + self._macros.update(project.macros) + self._models.update(project.models) + self._metrics.update(project.metrics) + self._audits.update(project.audits) + self._standalone_audits.update(project.standalone_audits) + self._requirements.update(project.requirements) + self._excluded_requirements.update(project.excluded_requirements) + self._environment_statements.extend(project.environment_statements) + + self._model_test_metadata.extend(project.model_test_metadata) + for metadata in project.model_test_metadata: + if metadata.path not in self._model_test_metadata_path_index: + self._model_test_metadata_path_index[metadata.path] = [] + self._model_test_metadata_path_index[metadata.path].append(metadata) + self._model_test_metadata_fully_qualified_name_index[ + metadata.fully_qualified_test_name + ] = metadata + self._models_with_tests.add(metadata.model_name) + + config = loader.config + self._linters[config.project] = Linter.from_rules( + BUILTIN_RULES.union(project.user_rules), config.linter + ) + + # Load environment statements from state for projects not in current load + if any(self._projects): + prod = self.state_reader.get_environment(c.PROD) + if prod: + existing_statements = self.state_reader.get_environment_statements(c.PROD) + for stmt in existing_statements: + if stmt.project and stmt.project not in self._projects: + 
self._environment_statements.append(stmt) + + uncached = set() + + if any(self._projects): + prod = self.state_reader.get_environment(c.PROD) + + if prod: + for snapshot in self.state_reader.get_snapshots(prod.snapshots).values(): + if snapshot.node.project in self._projects: + uncached.add(snapshot.name) + else: + local_store = self._standalone_audits if snapshot.is_audit else self._models + if snapshot.name in local_store: + uncached.add(snapshot.name) + else: + local_store[snapshot.name] = snapshot.node # type: ignore + + for model in self._models.values(): + self.dag.add(model.fqn, model.depends_on) + + if update_schemas: + for fqn in self.dag: + model = self._models.get(fqn) # type: ignore + + if not model or fqn in uncached: + continue + + # make a copy of remote models that depend on local models or in the downstream chain + # without this, a SELECT * FROM local will not propogate properly because the downstream + # model will get mutated (schema changes) but the object is the same as the remote cache + if any(dep in uncached for dep in model.depends_on): + uncached.add(fqn) + self._models.update({fqn: model.copy(update={"mapping_schema": {}})}) + continue + + update_model_schemas( + self.dag, + models=self._models, + cache_dir=self.cache_dir, + ) + + models = self.models.values() + for model in models: + # The model definition can be validated correctly only after the schema is set. 
+ model.validate_definition() + + duplicates = set(self._models) & set(self._standalone_audits) + if duplicates: + raise ConfigError( + f"Models and Standalone audits cannot have the same name: {duplicates}" + ) + + self._all_dialects = {m.dialect for m in self._models.values() if m.dialect} | { + self.default_dialect or "" + } analytics.collector.on_project_loaded( - project_type=( - c.DBT if type(self._loader).__name__.lower().startswith(c.DBT) else c.NATIVE - ), + project_type=self._project_type, models_count=len(self._models), audits_count=len(self._audits), standalone_audits_count=len(self._standalone_audits), @@ -543,6 +749,7 @@ def load(self, update_schemas: bool = True) -> GenericContext[C]: project_name=self.config.project, ) + self._loaded = True return self @python_api_analytics @@ -555,7 +762,10 @@ def run( execution_time: t.Optional[TimeLike] = None, skip_janitor: bool = False, ignore_cron: bool = False, - ) -> bool: + select_models: t.Optional[t.Collection[str]] = None, + exit_on_env_update: t.Optional[int] = None, + no_auto_upstream: bool = False, + ) -> CompletionStatus: """Run the entire dag through the scheduler. Args: @@ -565,11 +775,19 @@ def run( execution_time: The date/time time reference to use for execution time. Defaults to now. skip_janitor: Whether to skip the janitor task. ignore_cron: Whether to ignore the model's cron schedule and run all available missing intervals. + select_models: A list of model selection expressions to filter models that should run. Note that + upstream dependencies of selected models will also be evaluated. + exit_on_env_update: If set, exits with the provided code if the run is interrupted by an update + to the target environment. + no_auto_upstream: Whether to not force upstream models to run. Only applicable when using `select_models`. Returns: True if the run was successful, False otherwise. 
""" environment = environment or self.config.default_target_environment + environment = Environment.sanitize_name(environment) + if not skip_janitor and environment.lower() == c.PROD: + self._run_janitor() self.notification_target_manager.notify( NotificationEvent.RUN_START, environment=environment @@ -578,38 +796,95 @@ def run( engine_type=self.snapshot_evaluator.adapter.dialect, state_sync_type=self.state_sync.state_type(), ) + self._load_materializations() - success = False - try: - success = self._run( - environment=environment, - start=start, - end=end, - execution_time=execution_time, - skip_janitor=skip_janitor, - ignore_cron=ignore_cron, - ) - except Exception as e: - self.notification_target_manager.notify( - NotificationEvent.RUN_FAILURE, traceback.format_exc() + env_check_attempts_num = max( + 1, + self.config.run.environment_check_max_wait + // self.config.run.environment_check_interval, + ) + + def _block_until_finalized() -> str: + for _ in range(env_check_attempts_num): + assert environment is not None # mypy + environment_state = self.state_sync.get_environment(environment) + if not environment_state: + raise SQLMeshError(f"Environment '{environment}' was not found.") + if environment_state.finalized_ts: + return environment_state.plan_id + self.console.log_warning( + f"Environment '{environment}' is being updated by plan '{environment_state.plan_id}'. " + f"Retrying in {self.config.run.environment_check_interval} seconds..." + ) + time.sleep(self.config.run.environment_check_interval) + raise SQLMeshError( + f"Exceeded the maximum wait time for environment '{environment}' to be ready. " + "This means that the environment either failed to update or the update is taking longer than expected. " + "See https://sqlmesh.readthedocs.io/en/stable/reference/configuration/#run to adjust the timeout settings." 
) - logger.error(f"Run Failure: {traceback.format_exc()}") - analytics.collector.on_run_end(run_id=analytics_run_id, succeeded=False, error=e) - raise e - if success: + success = False + interrupted = False + done = False + while not done: + plan_id_at_start = _block_until_finalized() + + def _has_environment_changed() -> bool: + assert environment is not None # mypy + current_environment_state = self.state_sync.get_environment(environment) + return ( + not current_environment_state + or current_environment_state.plan_id != plan_id_at_start + or not current_environment_state.finalized_ts + ) + + try: + completion_status = self._run( + environment, + start=start, + end=end, + execution_time=execution_time, + ignore_cron=ignore_cron, + select_models=select_models, + circuit_breaker=_has_environment_changed, + no_auto_upstream=no_auto_upstream, + ) + done = True + except CircuitBreakerError: + self.console.log_warning( + f"Environment '{environment}' modified while running. Restarting the run..." + ) + if exit_on_env_update: + interrupted = True + done = True + except Exception as e: + self.notification_target_manager.notify( + NotificationEvent.RUN_FAILURE, traceback.format_exc() + ) + logger.info("Run failed.", exc_info=e) + analytics.collector.on_run_end( + run_id=analytics_run_id, succeeded=False, interrupted=False, error=e + ) + raise e + + if completion_status.is_success or interrupted: self.notification_target_manager.notify( NotificationEvent.RUN_END, environment=environment ) self.console.log_success(f"Run finished for environment '{environment}'") - else: + elif completion_status.is_failure: self.notification_target_manager.notify( NotificationEvent.RUN_FAILURE, "See console logs for details." 
) - analytics.collector.on_run_end(run_id=analytics_run_id, succeeded=success) + analytics.collector.on_run_end( + run_id=analytics_run_id, succeeded=success, interrupted=interrupted + ) + + if interrupted and exit_on_env_update is not None: + sys.exit(exit_on_env_update) - return success + return completion_status @python_api_analytics def run_janitor(self, ignore_ttl: bool) -> bool: @@ -624,6 +899,61 @@ def run_janitor(self, ignore_ttl: bool) -> bool: return success + @python_api_analytics + def destroy(self) -> bool: + success = False + + # Collect resources to be deleted + environments = self.state_reader.get_environments() + schemas_to_delete = set() + tables_to_delete = set() + views_to_delete = set() + all_snapshot_infos = set() + + # For each environment find schemas and tables + for environment in environments: + all_snapshot_infos.update(environment.snapshots) + snapshots = self.state_reader.get_snapshots(environment.snapshots).values() + for snapshot in snapshots: + if snapshot.is_model and not snapshot.is_symbolic: + # Get the appropriate adapter + if environment.gateway_managed and snapshot.model_gateway: + adapter = self.engine_adapters.get( + snapshot.model_gateway, self.engine_adapter + ) + else: + adapter = self.engine_adapter + + if environment.suffix_target.is_schema or environment.suffix_target.is_catalog: + schema = snapshot.qualified_view_name.schema_for_environment( + environment.naming_info, dialect=adapter.dialect + ) + catalog = snapshot.qualified_view_name.catalog_for_environment( + environment.naming_info, dialect=adapter.dialect + ) + if catalog: + schemas_to_delete.add(f"{catalog}.{schema}") + else: + schemas_to_delete.add(schema) + + if environment.suffix_target.is_table: + view_name = snapshot.qualified_view_name.for_environment( + environment.naming_info, dialect=adapter.dialect + ) + views_to_delete.add(view_name) + + # Add snapshot tables + table_name = snapshot.table_name() + tables_to_delete.add(table_name) + + if 
self.console.start_destroy(schemas_to_delete, views_to_delete, tables_to_delete): + try: + success = self._destroy() + finally: + self.console.stop_destroy(success=success) + + return success + @t.overload def get_model( self, model_or_snapshot: ModelOrSnapshot, raise_if_missing: Literal[True] = True @@ -648,7 +978,12 @@ def get_model( Returns: The expected model. """ - if isinstance(model_or_snapshot, str): + if isinstance(model_or_snapshot, Snapshot): + return model_or_snapshot.model + if not isinstance(model_or_snapshot, str): + return model_or_snapshot + + try: # We should try all dialects referenced in the project for cases when models use mixed dialects. for dialect in self._all_dialects: normalized_name = normalize_model_name( @@ -658,13 +993,16 @@ def get_model( ) if normalized_name in self._models: return self._models[normalized_name] - elif isinstance(model_or_snapshot, Snapshot): - return model_or_snapshot.model - else: - return model_or_snapshot + except: + pass if raise_if_missing: - raise SQLMeshError(f"Cannot find model for '{model_or_snapshot}'") + if model_or_snapshot.endswith((".sql", ".py")): + msg = "Resolving models by path is not supported, please pass in the model name instead." 
+ else: + msg = f"Cannot find model with name '{model_or_snapshot}'" + + raise SQLMeshError(msg) return None @@ -695,13 +1033,7 @@ def get_snapshot( """ if isinstance(node_or_snapshot, Snapshot): return node_or_snapshot - if isinstance(node_or_snapshot, str) and not self.standalone_audits.get(node_or_snapshot): - node_or_snapshot = normalize_model_name( - node_or_snapshot, - dialect=self.default_dialect, - default_catalog=self.default_catalog, - ) - fqn = node_or_snapshot if isinstance(node_or_snapshot, str) else node_or_snapshot.fqn + fqn = self._node_or_snapshot_to_fqn(node_or_snapshot) snapshot = self.snapshots.get(fqn) if raise_if_missing and not snapshot: @@ -709,19 +1041,21 @@ def get_snapshot( return snapshot - def config_for_path(self, path: Path) -> Config: + def config_for_path(self, path: Path) -> t.Tuple[Config, Path]: + """Returns the config and path of the said project for a given file path.""" for config_path, config in self.configs.items(): try: path.relative_to(config_path) - return config + return config, config_path except ValueError: pass - return self.config + return self.config, self.path - def config_for_node(self, node: str | Model | StandaloneAudit) -> Config: - if isinstance(node, str): - return self.config_for_path(self.get_snapshot(node, raise_if_missing=True).node._path) # type: ignore - return self.config_for_path(node._path) # type: ignore + def config_for_node(self, node: Model | Audit) -> Config: + path = node._path + if path is None: + return self.config + return self.config_for_path(path)[0] # type: ignore @property def models(self) -> MappingProxyType[str, Model]: @@ -738,6 +1072,11 @@ def standalone_audits(self) -> MappingProxyType[str, StandaloneAudit]: """Returns all registered standalone audits in this context.""" return MappingProxyType(self._standalone_audits) + @property + def models_with_tests(self) -> t.Set[str]: + """Returns all models with tests in this context.""" + return self._models_with_tests + @property def 
snapshots(self) -> t.Dict[str, Snapshot]: """Generates and returns snapshots based on models registered in this context. @@ -748,10 +1087,13 @@ def snapshots(self) -> t.Dict[str, Snapshot]: return self._snapshots() @property + def requirements(self) -> t.Dict[str, str]: + """Returns the Python dependencies of the project loaded in this context.""" + return self._requirements.copy() + + @cached_property def default_catalog(self) -> t.Optional[str]: - if self._default_catalog is None: - self._default_catalog = self._scheduler.get_default_catalog(self) - return self._default_catalog + return self.default_catalog_per_gateway.get(self.selected_gateway) @python_api_analytics def render( @@ -778,7 +1120,7 @@ def render( Returns: The rendered expression. """ - execution_time = execution_time or now_ds() + execution_time = execution_time or now() model = self.get_model(model_or_snapshot, raise_if_missing=True) @@ -793,9 +1135,13 @@ def render( expand = self.dag.upstream(model.fqn) if expand is True else expand or [] if model.is_seed: + import pandas as pd + df = next( model.render( - context=self.execution_context(), + context=self.execution_context( + engine_adapter=self._get_engine_adapter(model.gateway) + ), start=start, end=end, execution_time=execution_time, @@ -804,13 +1150,17 @@ def render( ) return next(pandas_to_sql(t.cast(pd.DataFrame, df), model.columns_to_types)) + snapshots = self.snapshots + deployability_index = DeployabilityIndex.create(snapshots.values(), start=start) + return model.render_query_or_raise( start=start, end=end, execution_time=execution_time, - snapshots=self.snapshots, + snapshots=snapshots, expand=expand, - engine_adapter=self.engine_adapter, + deployability_index=deployability_index, + engine_adapter=self._get_engine_adapter(model.gateway), **kwargs, ) @@ -835,7 +1185,21 @@ def evaluate( execution_time: The date/time time reference to use for execution time. limit: A limit applied to the model. 
""" - snapshot = self.get_snapshot(model_or_snapshot, raise_if_missing=True) + snapshots = self.snapshots + fqn = self._node_or_snapshot_to_fqn(model_or_snapshot) + if fqn not in snapshots: + raise SQLMeshError(f"Cannot find snapshot for '{fqn}'") + snapshot = snapshots[fqn] + + # Expand all uncategorized parents since physical tables don't exist for them yet + expand = [ + parent + for parent in self.dag.upstream(snapshot.model.fqn) + if (parent_snapshot := snapshots.get(parent)) + and parent_snapshot.is_model + and parent_snapshot.model.is_sql + and not parent_snapshot.categorized + ] df = self.snapshot_evaluator.evaluate_and_fetch( snapshot, @@ -844,6 +1208,7 @@ def evaluate( execution_time=execution_time, snapshots=self.snapshots, limit=limit or c.DEFAULT_MAX_LIMIT, + expand=expand, ) if df is None: @@ -855,38 +1220,95 @@ def evaluate( def format( self, transpile: t.Optional[str] = None, + rewrite_casts: t.Optional[bool] = None, append_newline: t.Optional[bool] = None, + *, + check: t.Optional[bool] = None, + paths: t.Optional[t.Tuple[t.Union[str, Path], ...]] = None, **kwargs: t.Any, - ) -> None: + ) -> bool: """Format all SQL models and audits.""" - format_targets = {**self._models, **self._audits} - for target in format_targets.values(): - if not target._path.suffix == ".sql": + filtered_targets = [ + target + for target in chain(self._models.values(), self._audits.values()) + if target._path is not None + and target._path.suffix == ".sql" + and (not paths or any(target._path.samefile(p) for p in paths)) + ] + unformatted_file_paths = [] + + for target in filtered_targets: + if ( + target._path is None or target.formatting is False + ): # introduced to satisfy type checker as still want to pull filter out as many targets as possible before loop continue + with open(target._path, "r+", encoding="utf-8") as file: - expressions = parse( - file.read(), default_dialect=self.config_for_node(target).dialect - ) - if transpile: - for prop in 
expressions[0].expressions: - if prop.name.lower() == "dialect": - prop.replace( - exp.Property( - this="dialect", - value=exp.Literal.string(transpile or target.dialect), - ) - ) - format = self.config_for_node(target).format - opts = {**format.generator_options, **kwargs} - file.seek(0) - file.write( - format_model_expressions(expressions, transpile or target.dialect, **opts) + before = file.read() + + after = self._format( + target, + before, + transpile=transpile, + rewrite_casts=rewrite_casts, + append_newline=append_newline, + **kwargs, ) - if append_newline is None: - append_newline = format.append_newline - if append_newline: - file.write("\n") - file.truncate() + + if not check: + file.seek(0) + file.write(after) + file.truncate() + elif before != after: + unformatted_file_paths.append(target._path) + + if unformatted_file_paths: + for path in unformatted_file_paths: + self.console.log_status_update(f"{path} needs reformatting.") + self.console.log_status_update( + f"\n{len(unformatted_file_paths)} file(s) need reformatting." 
+ ) + return False + + return True + + def _format( + self, + target: Model | Audit, + before: str, + *, + transpile: t.Optional[str] = None, + rewrite_casts: t.Optional[bool] = None, + append_newline: t.Optional[bool] = None, + **kwargs: t.Any, + ) -> str: + expressions = parse(before, default_dialect=self.config_for_node(target).dialect) + if transpile and is_meta_expression(expressions[0]): + for prop in expressions[0].expressions: + if prop.name.lower() == "dialect": + prop.replace( + exp.Property( + this="dialect", + value=exp.Literal.string(transpile or target.dialect), + ) + ) + + format_config = self.config_for_node(target).format + after = format_model_expressions( + expressions, + transpile or target.dialect, + rewrite_casts=( + rewrite_casts if rewrite_casts is not None else not format_config.no_rewrite_casts + ), + **{**format_config.generator_options, **kwargs}, + ) + + if append_newline is None: + append_newline = format_config.append_newline + if append_newline: + after += "\n" + + return after @python_api_analytics def plan( @@ -897,12 +1319,14 @@ def plan( end: t.Optional[TimeLike] = None, execution_time: t.Optional[TimeLike] = None, create_from: t.Optional[str] = None, - skip_tests: bool = False, + skip_tests: t.Optional[bool] = None, restate_models: t.Optional[t.Iterable[str]] = None, - no_gaps: bool = False, - skip_backfill: bool = False, + no_gaps: t.Optional[bool] = None, + skip_backfill: t.Optional[bool] = None, + empty_backfill: t.Optional[bool] = None, forward_only: t.Optional[bool] = None, allow_destructive_models: t.Optional[t.Collection[str]] = None, + allow_additive_models: t.Optional[t.Collection[str]] = None, no_prompts: t.Optional[bool] = None, auto_apply: t.Optional[bool] = None, no_auto_categorization: t.Optional[bool] = None, @@ -913,7 +1337,12 @@ def plan( categorizer_config: t.Optional[CategorizerConfig] = None, enable_preview: t.Optional[bool] = None, no_diff: t.Optional[bool] = None, - run: bool = False, + run: 
t.Optional[bool] = None, + diff_rendered: t.Optional[bool] = None, + skip_linter: t.Optional[bool] = None, + explain: t.Optional[bool] = None, + ignore_cron: t.Optional[bool] = None, + min_intervals: t.Optional[int] = None, ) -> Plan: """Interactively creates a plan. @@ -938,8 +1367,10 @@ def plan( part of the target environment have no data gaps when compared against previous snapshots for same models. skip_backfill: Whether to skip the backfill step. Default: False. + empty_backfill: Like skip_backfill, but also records processed intervals. forward_only: Whether the purpose of the plan is to make forward only changes. allow_destructive_models: Models whose forward-only changes are allowed to be destructive. + allow_additive_models: Models whose forward-only changes are allowed to be additive. no_prompts: Whether to disable interactive prompts for the backfill time range. Please note that if this flag is set to true and there are uncategorized changes the plan creation will fail. Default: False. @@ -956,6 +1387,11 @@ def plan( enable_preview: Indicates whether to enable preview for forward-only models in development environments. no_diff: Hide text differences for changed models. run: Whether to run latest intervals as part of the plan application. + diff_rendered: Whether the diff should compare raw vs rendered models + skip_linter: Linter runs by default so this will skip it if enabled + explain: Whether to explain the plan instead of applying it. + min_intervals: Adjust the plan start date on a per-model basis in order to ensure at least this many intervals are covered + on every model when checking for missing intervals Returns: The populated Plan object. 
@@ -970,8 +1406,10 @@ def plan( restate_models=restate_models, no_gaps=no_gaps, skip_backfill=skip_backfill, + empty_backfill=empty_backfill, forward_only=forward_only, allow_destructive_models=allow_destructive_models, + allow_additive_models=allow_additive_models, no_auto_categorization=no_auto_categorization, effective_from=effective_from, include_unmodified=include_unmodified, @@ -980,8 +1418,23 @@ def plan( categorizer_config=categorizer_config, enable_preview=enable_preview, run=run, + diff_rendered=diff_rendered, + skip_linter=skip_linter, + explain=explain, + ignore_cron=ignore_cron, + min_intervals=min_intervals, ) + plan = plan_builder.build() + + if no_auto_categorization or plan.uncategorized: + # Prompts are required if the auto categorization is disabled + # or if there are any uncategorized snapshots in the plan + no_prompts = False + + if explain: + auto_apply = True + self.console.plan( plan_builder, auto_apply if auto_apply is not None else self.config.plan.auto_apply, @@ -990,7 +1443,7 @@ def plan( no_prompts=no_prompts if no_prompts is not None else self.config.plan.no_prompts, ) - return plan_builder.build() + return plan @python_api_analytics def plan_builder( @@ -1001,12 +1454,14 @@ def plan_builder( end: t.Optional[TimeLike] = None, execution_time: t.Optional[TimeLike] = None, create_from: t.Optional[str] = None, - skip_tests: bool = False, + skip_tests: t.Optional[bool] = None, restate_models: t.Optional[t.Iterable[str]] = None, - no_gaps: bool = False, - skip_backfill: bool = False, + no_gaps: t.Optional[bool] = None, + skip_backfill: t.Optional[bool] = None, + empty_backfill: t.Optional[bool] = None, forward_only: t.Optional[bool] = None, allow_destructive_models: t.Optional[t.Collection[str]] = None, + allow_additive_models: t.Optional[t.Collection[str]] = None, no_auto_categorization: t.Optional[bool] = None, effective_from: t.Optional[TimeLike] = None, include_unmodified: t.Optional[bool] = None, @@ -1014,7 +1469,13 @@ def 
plan_builder( backfill_models: t.Optional[t.Collection[str]] = None, categorizer_config: t.Optional[CategorizerConfig] = None, enable_preview: t.Optional[bool] = None, - run: bool = False, + run: t.Optional[bool] = None, + diff_rendered: t.Optional[bool] = None, + skip_linter: t.Optional[bool] = None, + explain: t.Optional[bool] = None, + ignore_cron: t.Optional[bool] = None, + min_intervals: t.Optional[int] = None, + always_include_local_changes: t.Optional[bool] = None, ) -> PlanBuilder: """Creates a plan builder. @@ -1036,6 +1497,7 @@ def plan_builder( part of the target environment have no data gaps when compared against previous snapshots for same models. skip_backfill: Whether to skip the backfill step. Default: False. + empty_backfill: Like skip_backfill, but also records processed intervals. forward_only: Whether the purpose of the plan is to make forward only changes. allow_destructive_models: Models whose forward-only changes are allowed to be destructive. no_auto_categorization: Indicates whether to disable automatic categorization of model @@ -1049,21 +1511,73 @@ def plan_builder( backfill_models: A list of model selection strings to filter the models for which the data should be backfilled. enable_preview: Indicates whether to enable preview for forward-only models in development environments. run: Whether to run latest intervals as part of the plan application. + diff_rendered: Whether the diff should compare raw vs rendered models + min_intervals: Adjust the plan start date on a per-model basis in order to ensure at least this many intervals are covered + on every model when checking for missing intervals + always_include_local_changes: Usually when restatements are present, local changes in the filesystem are ignored. + However, it can be desirable to deploy changes + restatements in the same plan, so this flag overrides the default behaviour. Returns: The plan builder. 
""" + kwargs: t.Dict[str, t.Optional[UserProvidedFlags]] = { + "start": start, + "end": end, + "execution_time": execution_time, + "create_from": create_from, + "skip_tests": skip_tests, + "restate_models": list(restate_models) if restate_models is not None else None, + "no_gaps": no_gaps, + "skip_backfill": skip_backfill, + "empty_backfill": empty_backfill, + "forward_only": forward_only, + "allow_destructive_models": list(allow_destructive_models) + if allow_destructive_models is not None + else None, + "allow_additive_models": list(allow_additive_models) + if allow_additive_models is not None + else None, + "no_auto_categorization": no_auto_categorization, + "effective_from": effective_from, + "include_unmodified": include_unmodified, + "select_models": list(select_models) if select_models is not None else None, + "backfill_models": list(backfill_models) if backfill_models is not None else None, + "enable_preview": enable_preview, + "run": run, + "diff_rendered": diff_rendered, + "skip_linter": skip_linter, + "min_intervals": min_intervals, + } + user_provided_flags: t.Dict[str, UserProvidedFlags] = { + k: v for k, v in kwargs.items() if v is not None + } + + skip_tests = explain or skip_tests or False + no_gaps = no_gaps or False + skip_backfill = skip_backfill or False + empty_backfill = empty_backfill or False + run = run or False + diff_rendered = diff_rendered or False + skip_linter = skip_linter or False + min_intervals = min_intervals or 0 + environment = environment or self.config.default_target_environment environment = Environment.sanitize_name(environment) is_dev = environment != c.PROD + if include_unmodified is None: + include_unmodified = self.config.plan.include_unmodified + if skip_backfill and not no_gaps and not is_dev: - raise ConfigError( - "When targeting the production environment either the backfill should not be skipped or the lack of data gaps should be enforced (--no-gaps flag)." 
+ # note: we deliberately don't mention the --no-gaps flag in case the plan came from the sqlmesh_dbt command + # todo: perhaps we could have better error messages if we check sys.argv[0] for which cli is running? + self.console.log_warning( + "Skipping the backfill stage for production can lead to unexpected results, such as tables being empty or incremental data with non-contiguous time ranges being made available.\n" + "If you are doing this deliberately to create an empty version of a table to test a change, please consider using Virtual Data Environments instead." ) - if run and is_dev: - raise ConfigError("The '--run' flag is only supported for the production environment.") + if not skip_linter: + self.lint_models() self._run_plan_tests(skip_tests=skip_tests) @@ -1080,6 +1594,11 @@ def plan_builder( else: expanded_destructive_models = None + if allow_additive_models: + expanded_additive_models = model_selector.expand_model_selections(allow_additive_models) + else: + expanded_additive_models = None + if backfill_models: backfill_models = model_selector.expand_model_selections(backfill_models) else: @@ -1087,12 +1606,18 @@ def plan_builder( models_override: t.Optional[UniqueKeyDict[str, Model]] = None if select_models: - models_override = model_selector.select_models( - select_models, - environment, - fallback_env_name=create_from or c.PROD, - ensure_finalized_snapshots=self.config.plan.use_finalized_state, - ) + try: + models_override = model_selector.select_models( + select_models, + environment, + fallback_env_name=create_from or c.PROD, + ensure_finalized_snapshots=self.config.plan.use_finalized_state, + ) + except SQLMeshError as e: + logger.exception(e) # ensure the full stack trace is logged + raise PlanError( + f"{e}\nCheck the SQLMesh log file for the full stack trace.\nIf the model has been fixed locally, please ensure that the --select-model expression includes it." 
+ ) if not backfill_models: # Only backfill selected models unless explicitly specified. backfill_models = model_selector.expand_model_selections(select_models) @@ -1101,83 +1626,130 @@ def plan_builder( if restate_models is not None: expanded_restate_models = model_selector.expand_model_selections(restate_models) + if (restate_models is not None and not expanded_restate_models) or ( + backfill_models is not None and not backfill_models + ): + raise PlanError( + "Selector did not return any models. Please check your model selection and try again." + ) + + if always_include_local_changes is None: + # default behaviour - if restatements are detected; we operate entirely out of state and ignore local changes + force_no_diff = restate_models is not None or ( + backfill_models is not None and not backfill_models + ) + else: + force_no_diff = not always_include_local_changes + snapshots = self._snapshots(models_override) context_diff = self._context_diff( environment or c.PROD, snapshots=snapshots, create_from=create_from, - force_no_diff=restate_models is not None - or (backfill_models is not None and not backfill_models), + force_no_diff=force_no_diff, ensure_finalized_snapshots=self.config.plan.use_finalized_state, + diff_rendered=diff_rendered, + always_recreate_environment=self.config.plan.always_recreate_environment, ) + modified_model_names = { + *context_diff.modified_snapshots, + *[s.name for s in context_diff.added], + } + + if ( + is_dev + and not include_unmodified + and backfill_models is None + and expanded_restate_models is None + ): + # Only backfill modified and added models. + # This ensures that no models outside the impacted sub-DAG(s) will be backfilled unexpectedly. + backfill_models = modified_model_names or None - # If no end date is specified, use the max interval end from prod - # to prevent unintended evaluation of the entire DAG. 
+ max_interval_end_per_model = None + default_start, default_end = None, None if not run: - if backfill_models is not None: - # Only consider selected models for the default end value. - models_for_default_end = backfill_models.copy() - for name in backfill_models: - if name not in snapshots: - continue - snapshot = snapshots[name] - snapshot_id = snapshot.snapshot_id - if ( - snapshot_id in context_diff.added - and snapshot_id in context_diff.new_snapshots - ): - # If the selected model is a newly added model, then we should narrow down the intervals - # that should be considered for the default plan end value by including its parents. - models_for_default_end |= {s.name for s in snapshot.parents} - default_end = self.state_sync.greatest_common_interval_end( - c.PROD, - models_for_default_end, - ensure_finalized_snapshots=self.config.plan.use_finalized_state, - ) - else: - default_end = self.state_sync.max_interval_end_for_environment( - c.PROD, ensure_finalized_snapshots=self.config.plan.use_finalized_state - ) - else: - default_end = None + ignore_cron = False + max_interval_end_per_model = self._get_max_interval_end_per_model( + snapshots, backfill_models + ) + # If no end date is specified, use the max interval end from prod + # to prevent unintended evaluation of the entire DAG. + default_start, default_end = self._get_plan_default_start_end( + snapshots, + max_interval_end_per_model, + backfill_models, + modified_model_names, + execution_time or now(), + ) - default_start = to_date(default_end) - timedelta(days=1) if default_end and is_dev else None + # Refresh snapshot intervals to ensure that they are up to date with values reflected in the max_interval_end_per_model. 
+ self.state_sync.refresh_snapshot_intervals(context_diff.snapshots.values()) - return PlanBuilder( - context_diff=context_diff, - start=start, + start_override_per_model = self._calculate_start_override_per_model( + min_intervals, + start or default_start, + end or default_end, + execution_time or now(), + backfill_models, + snapshots, + max_interval_end_per_model, + ) + + if not self.config.virtual_environment_mode.is_full: + forward_only = True + elif forward_only is None: + forward_only = self.config.plan.forward_only + + # When handling prod restatements, only clear intervals from other model versions if we are using full virtual environments + # If we are not, then there is no point, because none of the data in dev environments can be promoted by definition + restate_all_snapshots = ( + expanded_restate_models is not None + and not is_dev + and self.config.virtual_environment_mode.is_full + ) + + return self.PLAN_BUILDER_TYPE( + context_diff=context_diff, + start=start, end=end, execution_time=execution_time, apply=self.apply, restate_models=expanded_restate_models, + restate_all_snapshots=restate_all_snapshots, backfill_models=backfill_models, no_gaps=no_gaps, skip_backfill=skip_backfill, + empty_backfill=empty_backfill, is_dev=is_dev, - forward_only=( - forward_only if forward_only is not None else self.config.plan.forward_only - ), + forward_only=forward_only, allow_destructive_models=expanded_destructive_models, + allow_additive_models=expanded_additive_models, environment_ttl=environment_ttl, environment_suffix_target=self.config.environment_suffix_target, - environment_catalog_mapping=self.config.environment_catalog_mapping, + environment_catalog_mapping=self.environment_catalog_mapping, categorizer_config=categorizer_config or self.auto_categorize_changes, auto_categorization_enabled=not no_auto_categorization, effective_from=effective_from, - include_unmodified=( - include_unmodified - if include_unmodified is not None - else 
self.config.plan.include_unmodified - ), + include_unmodified=include_unmodified, default_start=default_start, default_end=default_end, enable_preview=( - enable_preview if enable_preview is not None else self.config.plan.enable_preview + enable_preview if enable_preview is not None else self._plan_preview_enabled ), end_bounded=not run, ensure_finalized_snapshots=self.config.plan.use_finalized_state, - engine_schema_differ=self.engine_adapter.SCHEMA_DIFFER, + start_override_per_model=start_override_per_model, + end_override_per_model=max_interval_end_per_model, console=self.console, + user_provided_flags=user_provided_flags, + selected_models={ + dbt_unique_id + for model in model_selector.expand_model_selections(select_models or "*") + if (dbt_unique_id := snapshots[model].node.dbt_unique_id) + }, + explain=explain or False, + ignore_cron=ignore_cron or False, ) def apply( @@ -1202,6 +1774,16 @@ def apply( return if plan.uncategorized: raise UncategorizedPlanError("Can't apply a plan with uncategorized changes.") + + if plan.explain: + explainer = PlanExplainer( + state_reader=self.state_reader, + default_catalog=self.default_catalog, + console=self.console, + ) + explainer.evaluate(plan.to_evaluatable()) + return + self.notification_target_manager.notify( NotificationEvent.APPLY_START, environment=plan.environment_naming_info.name, @@ -1216,7 +1798,7 @@ def apply( plan_id=plan.plan_id, exc=traceback.format_exc(), ) - logger.error(f"Apply Failure: {traceback.format_exc()}") + logger.info("Plan application failed.", exc_info=e) raise e self.notification_target_manager.notify( NotificationEvent.APPLY_END, @@ -1233,12 +1815,13 @@ def invalidate_environment(self, name: str, sync: bool = False) -> None: sync: If True, the call blocks until the environment is deleted. Otherwise, the environment will be deleted asynchronously by the janitor process. 
""" + name = Environment.sanitize_name(name) self.state_sync.invalidate_environment(name) if sync: self._cleanup_environments() - self.console.log_success(f"Environment '{name}' has been deleted.") + self.console.log_success(f"Environment '{name}' deleted.") else: - self.console.log_success(f"Environment '{name}' has been invalidated.") + self.console.log_success(f"Environment '{name}' invalidated.") @python_api_analytics def diff(self, environment: t.Optional[str] = None, detailed: bool = False) -> bool: @@ -1254,17 +1837,22 @@ def diff(self, environment: t.Optional[str] = None, detailed: bool = False) -> b environment = environment or self.config.default_target_environment environment = Environment.sanitize_name(environment) context_diff = self._context_diff(environment) - self.console.show_model_difference_summary( + self.console.show_environment_difference_summary( context_diff, - EnvironmentNamingInfo.from_environment_catalog_mapping( - self.config.environment_catalog_mapping, - name=environment, - suffix_target=self.config.environment_suffix_target, - normalize_name=context_diff.normalize_environment_name, - ), - self.default_catalog, no_diff=not detailed, ) + if context_diff.has_changes: + self.console.show_model_difference_summary( + context_diff, + EnvironmentNamingInfo.from_environment_catalog_mapping( + self.environment_catalog_mapping, + name=environment, + suffix_target=self.config.environment_suffix_target, + normalize_name=context_diff.normalize_environment_name, + ), + self.default_catalog, + no_diff=not detailed, + ) return context_diff.has_changes @python_api_analytics @@ -1272,16 +1860,20 @@ def table_diff( self, source: str, target: str, - on: t.List[str] | exp.Condition | None = None, - skip_columns: t.List[str] | None = None, - model_or_snapshot: t.Optional[ModelOrSnapshot] = None, + on: t.Optional[t.List[str] | exp.Condition] = None, + skip_columns: t.Optional[t.List[str]] = None, + select_models: t.Optional[t.Collection[str]] = None, where: 
t.Optional[str | exp.Condition] = None, limit: int = 20, show: bool = True, show_sample: bool = True, decimals: int = 3, skip_grain_check: bool = False, - ) -> TableDiff: + warn_grain_check: bool = False, + temp_schema: t.Optional[str] = None, + schema_diff_ignore_case: bool = False, + **kwargs: t.Any, # catch-all to prevent an 'unexpected keyword argument' error if an table_diff extension passes in some extra arguments + ) -> t.List[TableDiff]: """Show a diff between two tables. Args: @@ -1290,56 +1882,220 @@ def table_diff( on: The join condition, table aliases must be "s" and "t" for source and target. If omitted, the table's grain will be used. skip_columns: The columns to skip when computing the table diff. - model_or_snapshot: The model or snapshot to use when environments are passed in. + select_models: The models or snapshots to use when environments are passed in. where: An optional where statement to filter results. limit: The limit of the sample dataframe. show: Show the table diff output in the console. show_sample: Show the sample dataframe in the console. Requires show=True. decimals: The number of decimal places to keep when comparing floating point columns. skip_grain_check: Skip check for rows that contain null or duplicate grains. + temp_schema: The schema to use for temporary tables. Returns: - The TableDiff object containing schema and summary differences. + The list of TableDiff objects containing schema and summary differences. """ - source_alias, target_alias = source, target - if model_or_snapshot: - model = self.get_model(model_or_snapshot, raise_if_missing=True) + if "|" in source or "|" in target: + raise ConfigError( + "Cross-database table diffing is available in Tobiko Cloud. 
Read more here: " + "https://sqlmesh.readthedocs.io/en/stable/guides/tablediff/#diffing-tables-or-views-across-gateways" + ) + + table_diffs: t.List[TableDiff] = [] + + # Diffs multiple or a single model across two environments + if select_models: source_env = self.state_reader.get_environment(source) target_env = self.state_reader.get_environment(target) - if not source_env: raise SQLMeshError(f"Could not find environment '{source}'") if not target_env: - raise SQLMeshError(f"Could not find environment '{target}')") - - source = next( - snapshot for snapshot in source_env.snapshots if snapshot.name == model.fqn - ).table_name() - target = next( - snapshot for snapshot in target_env.snapshots if snapshot.name == model.fqn - ).table_name() - source_alias = source_env.name - target_alias = target_env.name - - if not on: - for ref in model.all_references: - if ref.unique: - expr = ref.expression - - if isinstance(expr, exp.Tuple): - on = [key.this.sql() for key in expr.expressions] + raise SQLMeshError(f"Could not find environment '{target}'") + criteria = ", ".join(f"'{c}'" for c in select_models) + try: + selected_models = self._new_selector().expand_model_selections(select_models) + if not selected_models: + self.console.log_status_update( + f"No models matched the selection criteria: {criteria}" + ) + except Exception as e: + raise SQLMeshError(e) + + models_to_diff: t.List[ + t.Tuple[Model, EngineAdapter, str, str, t.Optional[t.List[str] | exp.Condition]] + ] = [] + models_without_grain: t.List[Model] = [] + source_snapshots_to_name = { + snapshot.name: snapshot for snapshot in source_env.snapshots + } + target_snapshots_to_name = { + snapshot.name: snapshot for snapshot in target_env.snapshots + } + + for model_fqn in selected_models: + model = self._models[model_fqn] + adapter = self._get_engine_adapter(model.gateway) + source_snapshot = source_snapshots_to_name.get(model.fqn) + target_snapshot = target_snapshots_to_name.get(model.fqn) + + if target_snapshot 
and source_snapshot: + if (source_snapshot.fingerprint != target_snapshot.fingerprint) and ( + (source_snapshot.version != target_snapshot.version) + or source_snapshot.is_forward_only + ): + # Compare the virtual layer instead of the physical layer because the virtual layer is guaranteed to point + # to the correct/active snapshot for the model in the specified environment, taking into account things like dev previews + source = source_snapshot.qualified_view_name.for_environment( + source_env.naming_info, adapter.dialect + ) + target = target_snapshot.qualified_view_name.for_environment( + target_env.naming_info, adapter.dialect + ) + model_on = on or model.on + if not model_on: + models_without_grain.append(model) else: - # Handle a single Column or Paren expression - on = [expr.this.sql()] + models_to_diff.append((model, adapter, source, target, model_on)) + + if models_without_grain: + model_names = "\n".join( + f"─ {model.name} \n at '{model._path}'" for model in models_without_grain + ) + message = ( + "SQLMesh doesn't know how to join the tables for the following models:\n" + f"{model_names}\n\n" + "Please specify a `grain` in each model definition. It must be unique and not null." 
+ ) + if warn_grain_check: + self.console.log_warning(message) + else: + raise SQLMeshError(message) + + if models_to_diff: + self.console.show_table_diff_details( + [model[0].name for model in models_to_diff], + ) + + self.console.start_table_diff_progress(len(models_to_diff)) + try: + tasks_num = min(len(models_to_diff), self.concurrent_tasks) + table_diffs = concurrent_apply_to_values( + list(models_to_diff), + lambda model_info: self._model_diff( + model=model_info[0], + adapter=model_info[1], + source=model_info[2], + target=model_info[3], + on=model_info[4], + source_alias=source_env.name, + target_alias=target_env.name, + limit=limit, + decimals=decimals, + skip_columns=skip_columns, + where=where, + show=show, + temp_schema=temp_schema, + skip_grain_check=skip_grain_check, + schema_diff_ignore_case=schema_diff_ignore_case, + ), + tasks_num=tasks_num, + ) + self.console.stop_table_diff_progress(success=True) + except: + self.console.stop_table_diff_progress(success=False) + raise + elif selected_models: + self.console.log_status_update( + f"No models contain differences with the selection criteria: {criteria}" + ) + + else: + table_diffs = [ + self._table_diff( + source=source, + target=target, + source_alias=source, + target_alias=target, + limit=limit, + decimals=decimals, + adapter=self.engine_adapter, + on=on, + skip_columns=skip_columns, + where=where, + schema_diff_ignore_case=schema_diff_ignore_case, + ) + ] + + if show: + self.console.show_table_diff(table_diffs, show_sample, skip_grain_check, temp_schema) + + return table_diffs + + def _model_diff( + self, + model: Model, + adapter: EngineAdapter, + source: str, + target: str, + source_alias: str, + target_alias: str, + limit: int, + decimals: int, + on: t.Optional[t.List[str] | exp.Condition] = None, + skip_columns: t.Optional[t.List[str]] = None, + where: t.Optional[str | exp.Condition] = None, + show: bool = True, + temp_schema: t.Optional[str] = None, + skip_grain_check: bool = False, + 
schema_diff_ignore_case: bool = False, + ) -> TableDiff: + self.console.start_table_diff_model_progress(model.name) + table_diff = self._table_diff( + on=on, + skip_columns=skip_columns, + where=where, + limit=limit, + decimals=decimals, + model=model, + adapter=adapter, + source=source, + target=target, + source_alias=source_alias, + target_alias=target_alias, + schema_diff_ignore_case=schema_diff_ignore_case, + ) + + if show: + # Trigger row_diff in parallel execution so it's available for ordered display later + table_diff.row_diff(temp_schema=temp_schema, skip_grain_check=skip_grain_check) + + self.console.update_table_diff_progress(model.name) + + return table_diff + + def _table_diff( + self, + source: str, + target: str, + source_alias: str, + target_alias: str, + limit: int, + decimals: int, + adapter: EngineAdapter, + on: t.Optional[t.List[str] | exp.Condition] = None, + model: t.Optional[Model] = None, + skip_columns: t.Optional[t.List[str]] = None, + where: t.Optional[str | exp.Condition] = None, + schema_diff_ignore_case: bool = False, + ) -> TableDiff: if not on: raise SQLMeshError( "SQLMesh doesn't know how to join the two tables. Specify the `grains` in each model definition or pass join column names in separate `-o` flags." 
) - table_diff = TableDiff( - adapter=self._engine_adapter, + return TableDiff( + adapter=adapter.with_settings(execute_log_level=logger.getEffectiveLevel()), source=source, target=target, on=on, @@ -1347,19 +2103,12 @@ def table_diff( where=where, source_alias=source_alias, target_alias=target_alias, - model_name=model.name if model_or_snapshot else None, - model_dialect=model.dialect if model_or_snapshot else None, limit=limit, decimals=decimals, + model_name=model.name if model else None, + model_dialect=model.dialect if model else None, + schema_diff_ignore_case=schema_diff_ignore_case, ) - if show: - self.console.show_schema_diff(table_diff.schema_diff()) - self.console.show_row_diff( - table_diff.row_diff(skip_grain_check=skip_grain_check), - show_sample=show_sample, - skip_grain_check=skip_grain_check, - ) - return table_diff @python_api_analytics def get_dag( @@ -1422,7 +2171,7 @@ def render_dag(self, path: str, select_models: t.Optional[t.Collection[str]] = N suffix = file_path.suffix if suffix != ".html": if suffix: - logger.warning( + get_console().log_warning( f"The extension {suffix} does not designate an html file. A file with a `.html` extension will be created instead." 
) path = str(file_path.with_suffix(".html")) @@ -1463,14 +2212,16 @@ def create_test( } try: - test_adapter = self._test_connection_config.create_engine_adapter( + model_to_test = self.get_model(model, raise_if_missing=True) + test_adapter = self.test_connection_config.create_engine_adapter( register_comments_override=False ) + generate_test( - model=self.get_model(model, raise_if_missing=True), + model=model_to_test, input_queries=input_queries, models=self._models, - engine_adapter=self._engine_adapter, + engine_adapter=self._get_engine_adapter(model_to_test.gateway), test_engine_adapter=test_adapter, project_path=self.path, overwrite=overwrite, @@ -1480,63 +2231,43 @@ def create_test( include_ctes=include_ctes, ) finally: - test_adapter.close() + if test_adapter: + test_adapter.close() @python_api_analytics def test( self, match_patterns: t.Optional[t.List[str]] = None, tests: t.Optional[t.List[str]] = None, - verbose: bool = False, + verbosity: Verbosity = Verbosity.DEFAULT, preserve_fixtures: bool = False, stream: t.Optional[t.TextIO] = None, ) -> ModelTextTestResult: """Discover and run model tests""" - if verbose: + if verbosity >= Verbosity.VERBOSE: + import pandas as pd + pd.set_option("display.max_columns", None) - verbosity = 2 - else: - verbosity = 1 - if tests: - result = run_model_tests( - tests=tests, - models=self._models, - config=self.config, - gateway=self.gateway, - dialect=self.default_dialect, - verbosity=verbosity, - patterns=match_patterns, - preserve_fixtures=preserve_fixtures, - stream=stream, - default_catalog=self.default_catalog, - default_catalog_dialect=self.engine_adapter.DIALECT, - ) - else: - test_meta = [] - - for path, config in self.configs.items(): - test_meta.extend( - get_all_model_tests( - path / c.TESTS, - patterns=match_patterns, - ignore_patterns=config.ignore_patterns, - variables=config.variables, - ) - ) + test_meta = self.select_tests(tests=tests, patterns=match_patterns) - result = run_tests( - 
model_test_metadata=test_meta, - models=self._models, - config=self.config, - gateway=self.gateway, - dialect=self.default_dialect, - verbosity=verbosity, - preserve_fixtures=preserve_fixtures, - stream=stream, - default_catalog=self.default_catalog, - default_catalog_dialect=self.engine_adapter.DIALECT, - ) + result = run_tests( + model_test_metadata=test_meta, + models=self._models, + config=self.config, + selected_gateway=self.selected_gateway, + dialect=self.default_dialect, + verbosity=verbosity, + preserve_fixtures=preserve_fixtures, + stream=stream, + default_catalog=self.default_catalog, + default_catalog_dialect=self.config.dialect or "", + ) + + self.console.log_test_results( + result, + self.test_connection_config._engine_adapter.DIALECT, + ) return result @@ -1548,7 +2279,7 @@ def audit( *, models: t.Optional[t.Iterator[str]] = None, execution_time: t.Optional[TimeLike] = None, - ) -> None: + ) -> bool: """Audit models. Args: @@ -1556,6 +2287,9 @@ def audit( end: The end of the interval to audit. models: The models to audit. All models will be audited if not specified. execution_time: The date/time time reference to use for execution time. Defaults to now. + + Returns: + False if any of the audits failed, True otherwise. 
""" snapshots = ( @@ -1564,8 +2298,9 @@ def audit( else self.snapshots.values() ) - num_audits = sum(len(snapshot.audits_with_args) for snapshot in snapshots) + num_audits = sum(len(snapshot.node.audits_with_args) for snapshot in snapshots) self.console.log_status_update(f"Found {num_audits} audit(s).") + errors = [] skipped_count = 0 for snapshot in snapshots: @@ -1573,8 +2308,8 @@ def audit( snapshot=snapshot, start=start, end=end, + execution_time=execution_time, snapshots=self.snapshots, - raise_exception=False, ): audit_id = f"{audit_result.audit.name}" if audit_result.model: @@ -1606,6 +2341,7 @@ def audit( ) self.console.log_status_update("Done.") + return not errors @python_api_analytics def rewrite(self, sql: str, dialect: str = "") -> exp.Expression: @@ -1627,6 +2363,61 @@ def rewrite(self, sql: str, dialect: str = "") -> exp.Expression: dialect=dialect or self.default_dialect, ) + @python_api_analytics + def check_intervals( + self, + environment: t.Optional[str], + no_signals: bool, + select_models: t.Collection[str], + start: t.Optional[TimeLike] = None, + end: t.Optional[TimeLike] = None, + ) -> t.Dict[Snapshot, SnapshotIntervals]: + """Check intervals for a given environment. + + Args: + environment: The environment or prod if None. + select_models: A list of model selection strings to show intervals for. + start: The start of the intervals to check. + end: The end of the intervals to check. 
+ """ + + environment = environment or c.PROD + env = self.state_reader.get_environment(environment) + if not env: + raise SQLMeshError(f"Environment '{environment}' was not found.") + + snapshots = {k.name: v for k, v in self.state_sync.get_snapshots(env.snapshots).items()} + + missing = { + k.name: v + for k, v in missing_intervals( + snapshots.values(), start=start, end=end, execution_time=end + ).items() + } + + if select_models: + selected: t.Collection[str] = self._select_models_for_run( + select_models, True, snapshots.values() + ) + else: + selected = snapshots.keys() + + results = {} + execution_context = self.execution_context(snapshots=snapshots) + + for fqn in selected: + snapshot = snapshots[fqn] + intervals = missing.get(fqn) or [] + + results[snapshot] = SnapshotIntervals( + snapshot.snapshot_id, + intervals + if no_signals + else snapshot.check_ready_intervals(intervals, execution_context), + ) + + return results + @python_api_analytics def migrate(self) -> None: """Migrates SQLMesh to the current running version. @@ -1634,9 +2425,9 @@ def migrate(self) -> None: Please contact your SQLMesh administrator before doing this. 
""" self.notification_target_manager.notify(NotificationEvent.MIGRATION_START) + self._load_materializations() try: self._new_state_sync().migrate( - default_catalog=self.default_catalog, promoted_snapshots_only=self.config.migration.promoted_snapshots_only, ) except Exception as e: @@ -1673,6 +2464,14 @@ def create_external_models(self, strict: bool = False) -> None: external_models_yaml = ( path / c.EXTERNAL_MODELS_YAML if not deprecated_yaml.exists() else deprecated_yaml ) + + external_models_gateway: t.Optional[str] = self.gateway or self.config.default_gateway + if not external_models_gateway: + # can happen if there was no --gateway defined and the default_gateway is '' + # which means that the single gateway syntax is being used which means there is + # no named gateway which means we should not stamp `gateway:` on the external models + external_models_gateway = None + create_external_models_file( path=external_models_yaml, models=UniqueKeyDict( @@ -1683,16 +2482,18 @@ def create_external_models(self, strict: bool = False) -> None: if self.config_for_node(model) is config }, ), - adapter=self._engine_adapter, + adapter=self.engine_adapter, state_reader=self.state_reader, dialect=config.model_defaults.dialect, - gateway=self.gateway, + gateway=external_models_gateway, max_workers=self.concurrent_tasks, strict=strict, ) @python_api_analytics - def print_info(self, skip_connection: bool = False) -> None: + def print_info( + self, skip_connection: bool = False, verbosity: Verbosity = Verbosity.DEFAULT + ) -> None: """Prints information about connections, models, macros, etc. 
to the console.""" self.console.log_status_update(f"Models: {len(self.models)}") self.console.log_status_update(f"Macros: {len(self._macros) - len(macro.get_registry())}") @@ -1700,16 +2501,38 @@ def print_info(self, skip_connection: bool = False) -> None: if skip_connection: return - self._try_connection("data warehouse", self._engine_adapter.ping) + if verbosity >= Verbosity.VERBOSE: + self.console.log_status_update("") + print_config(self.config.get_connection(self.gateway), self.console, "Connection") + print_config( + self.config.get_test_connection(self.gateway), self.console, "Test Connection" + ) + print_config( + self.config.get_state_connection(self.gateway), self.console, "State Connection" + ) + self._try_connection("data warehouse", self.engine_adapter.ping) state_connection = self.config.get_state_connection(self.gateway) if state_connection: self._try_connection("state backend", state_connection.connection_validator()) + @python_api_analytics + def print_environment_names(self) -> None: + """Prints all environment names along with expiry datetime.""" + result = self._new_state_sync().get_environments_summary() + if not result: + raise SQLMeshError( + "This project has no environments. Create an environment using the `sqlmesh plan` command." 
+ ) + self.console.print_environments(result) + def close(self) -> None: """Releases all resources allocated by this context.""" - self.snapshot_evaluator.close() - self.state_sync.close() + if self._snapshot_evaluator: + self._snapshot_evaluator.close() + + if self._state_sync: + self._state_sync.close() def _run( self, @@ -1718,118 +2541,181 @@ def _run( start: t.Optional[TimeLike], end: t.Optional[TimeLike], execution_time: t.Optional[TimeLike], - skip_janitor: bool, ignore_cron: bool, - ) -> bool: - if not skip_janitor and environment.lower() == c.PROD: - self._run_janitor() + select_models: t.Optional[t.Collection[str]], + circuit_breaker: t.Optional[t.Callable[[], bool]], + no_auto_upstream: bool, + ) -> CompletionStatus: + scheduler = self.scheduler(environment=environment) + snapshots = scheduler.snapshots + + if select_models is not None: + select_models = self._select_models_for_run( + select_models, no_auto_upstream, snapshots.values() + ) - env_check_attempts_num = max( - 1, - self.config.run.environment_check_max_wait - // self.config.run.environment_check_interval, + completion_status = scheduler.run( + environment, + start=start, + end=end, + execution_time=execution_time, + ignore_cron=ignore_cron, + circuit_breaker=circuit_breaker, + selected_snapshots=select_models, + auto_restatement_enabled=environment.lower() == c.PROD, + run_environment_statements=True, ) - def _block_until_finalized() -> str: - for _ in range(env_check_attempts_num): - assert environment is not None # mypy - environment_state = self.state_sync.get_environment(environment) - if not environment_state: - raise SQLMeshError(f"Environment '{environment}' was not found.") - if environment_state.finalized_ts: - return environment_state.plan_id - logger.warning( - "Environment '%s' is being updated by plan '%s'. 
Retrying in %s seconds...", - environment, - environment_state.plan_id, - self.config.run.environment_check_interval, - ) - time.sleep(self.config.run.environment_check_interval) - raise SQLMeshError( - f"Exceeded the maximum wait time for environment '{environment}' to be ready. " - "This means that the environment either failed to update or the update is taking longer than expected. " - "See https://sqlmesh.readthedocs.io/en/stable/reference/configuration/#run to adjust the timeout settings." - ) + if completion_status.is_nothing_to_do: + next_run_ready_msg = "" - done = False - while not done: - plan_id_at_start = _block_until_finalized() - - def _has_environment_changed() -> bool: - assert environment is not None # mypy - current_environment_state = self.state_sync.get_environment(environment) - return ( - not current_environment_state - or current_environment_state.plan_id != plan_id_at_start - or not current_environment_state.finalized_ts - ) + next_ready_interval_start = get_next_model_interval_start(snapshots.values()) + if next_ready_interval_start: + utc_time = format_tz_datetime(next_ready_interval_start) + local_time = format_tz_datetime(next_ready_interval_start, use_local_timezone=True) + time_msg = local_time if local_time == utc_time else f"{local_time} ({utc_time})" + next_run_ready_msg = f"\n\nNext run will be ready at {time_msg}." - try: - success = self.scheduler(environment=environment).run( - environment, - start=start, - end=end, - execution_time=execution_time, - ignore_cron=ignore_cron, - circuit_breaker=_has_environment_changed, - ) - done = True - except CircuitBreakerError: - logger.warning( - "Environment '%s' has been modified while running. Restarting the run...", - environment, - ) + self.console.log_status_update( + f"No models are ready to run. 
Please wait until a model `cron` interval has elapsed.{next_run_ready_msg}" + ) - return success + return completion_status def _apply(self, plan: Plan, circuit_breaker: t.Optional[t.Callable[[], bool]]) -> None: - self._scheduler.create_plan_evaluator(self).evaluate(plan, circuit_breaker=circuit_breaker) + self._scheduler.create_plan_evaluator(self).evaluate( + plan.to_evaluatable(), circuit_breaker=circuit_breaker + ) @python_api_analytics - def table_name(self, model_name: str, dev: bool) -> str: - """Returns the name of the pysical table for the given model name. + def table_name( + self, model_name: str, environment: t.Optional[str] = None, prod: bool = False + ) -> str: + """Returns the name of the pysical table for the given model name in the target environment. Args: model_name: The name of the model. - dev: Whether to use the deployability index for the table name. + environment: The environment to source the model version from. + prod: If True, return the name of the physical table that will be used in production for the model version + promoted in the target environment. Returns: The name of the physical table. """ - deployability_index = ( - DeployabilityIndex.create(self.snapshots.values()) - if dev - else DeployabilityIndex.all_deployable() + environment = environment or self.config.default_target_environment + fqn = self._node_or_snapshot_to_fqn(model_name) + target_env = self.state_reader.get_environment(environment) + if not target_env: + raise SQLMeshError(f"Environment '{environment}' was not found.") + + snapshot_info = None + for s in target_env.snapshots: + if s.name == fqn: + snapshot_info = s + break + if not snapshot_info: + raise SQLMeshError( + f"Model '{model_name}' was not found in environment '{environment}'." 
+ ) + + if target_env.name == c.PROD or prod: + return snapshot_info.table_name() + + snapshots = self.state_reader.get_snapshots(target_env.snapshots) + deployability_index = DeployabilityIndex.create(snapshots) + + return snapshot_info.table_name( + is_deployable=deployability_index.is_deployable(snapshot_info.snapshot_id) ) - snapshot = self.get_snapshot(model_name) - if not snapshot: - raise SQLMeshError(f"Model '{model_name}' was not found.") - return snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot)) def clear_caches(self) -> None: - for path in self.configs: - rmtree(path / c.CACHE) + paths_to_remove = [path / c.CACHE for path in self.configs] + paths_to_remove.append(self.cache_dir) - def _run_tests(self, verbose: bool = False) -> t.Tuple[unittest.result.TestResult, str]: + if IS_WINDOWS: + paths_to_remove = [fix_windows_path(path) for path in paths_to_remove] + + for path in paths_to_remove: + if path.exists(): + rmtree(path) + + if isinstance(self._state_sync, CachingStateSync): + self._state_sync.clear_cache() + + def export_state( + self, + output_file: Path, + environment_names: t.Optional[t.List[str]] = None, + local_only: bool = False, + confirm: bool = True, + ) -> None: + from sqlmesh.core.state_sync.export_import import export_state + + # trigger a connection to the StateSync so we can fail early if there is a problem + # note we still need to do this even if we are doing a local export so we know what 'versions' to write + self.state_sync.get_versions(validate=True) + + local_snapshots = self.snapshots if local_only else None + + if self.console.start_state_export( + output_file=output_file, + gateway=self.selected_gateway, + state_connection_config=self._state_connection_config, + environment_names=environment_names, + local_only=local_only, + confirm=confirm, + ): + try: + export_state( + state_sync=self.state_sync, + output_file=output_file, + local_snapshots=local_snapshots, + environment_names=environment_names, + 
console=self.console, + ) + self.console.stop_state_export(success=True, output_file=output_file) + except: + self.console.stop_state_export(success=False, output_file=output_file) + raise + + def import_state(self, input_file: Path, clear: bool = False, confirm: bool = True) -> None: + from sqlmesh.core.state_sync.export_import import import_state + + if self.console.start_state_import( + input_file=input_file, + gateway=self.selected_gateway, + state_connection_config=self._state_connection_config, + clear=clear, + confirm=confirm, + ): + try: + import_state( + state_sync=self.state_sync, + input_file=input_file, + clear=clear, + console=self.console, + ) + self.console.stop_state_import(success=True, input_file=input_file) + except: + self.console.stop_state_import(success=False, input_file=input_file) + raise + + def _run_tests( + self, verbosity: Verbosity = Verbosity.DEFAULT + ) -> t.Tuple[ModelTextTestResult, str]: test_output_io = StringIO() - result = self.test(stream=test_output_io, verbose=verbose) + result = self.test(stream=test_output_io, verbosity=verbosity) return result, test_output_io.getvalue() - def _run_plan_tests( - self, skip_tests: bool = False - ) -> t.Tuple[t.Optional[unittest.result.TestResult], t.Optional[str]]: + def _run_plan_tests(self, skip_tests: bool = False) -> t.Optional[ModelTextTestResult]: if not skip_tests: - result, test_output = self._run_tests() - if result.testsRun > 0: - self.console.log_test_results( - result, test_output, self._test_connection_config._engine_adapter.DIALECT - ) + result = self.test() if not result.wasSuccessful(): raise PlanError( - "Cannot generate plan due to failing test(s). Fix test(s) and run again" + "Cannot generate plan due to failing test(s). Fix test(s) and run again." 
) - return result, test_output - return None, None + return result + return None @property def _model_tables(self) -> t.Dict[str, str]: @@ -1843,7 +2729,7 @@ def _model_tables(self) -> t.Dict[str, str]: if snapshot.version else snapshot.qualified_view_name.for_environment( EnvironmentNamingInfo.from_environment_catalog_mapping( - self.config.environment_catalog_mapping, + self.environment_catalog_mapping, name=c.PROD, suffix_target=self.config.environment_suffix_target, ) @@ -1852,71 +2738,101 @@ def _model_tables(self) -> t.Dict[str, str]: for fqn, snapshot in self.snapshots.items() } - def _snapshots( - self, models_override: t.Optional[UniqueKeyDict[str, Model]] = None - ) -> t.Dict[str, Snapshot]: - prod = self.state_reader.get_environment(c.PROD) - remote_snapshots = ( - { - snapshot.name: snapshot - for snapshot in self.state_reader.get_snapshots(prod.snapshots).values() - } - if prod - else {} + @cached_property + def cache_dir(self) -> Path: + if self.config.cache_dir: + cache_path = Path(self.config.cache_dir) + if cache_path.is_absolute(): + return cache_path + return self.path / cache_path + + # Default to .cache directory in the project path + return self.path / c.CACHE + + @cached_property + def engine_adapters(self) -> t.Dict[str, EngineAdapter]: + """Returns all the engine adapters for the gateways defined in the configurations.""" + adapters: t.Dict[str, EngineAdapter] = {self.selected_gateway: self.engine_adapter} + for config in self.configs.values(): + for gateway_name in config.gateways: + if gateway_name not in adapters: + connection = config.get_connection(gateway_name) + adapter = connection.create_engine_adapter( + concurrent_tasks=self.concurrent_tasks, + ) + adapters[gateway_name] = adapter + return adapters + + @cached_property + def default_catalog_per_gateway(self) -> t.Dict[str, str]: + """Returns the default catalogs for each engine adapter.""" + return self._scheduler.get_default_catalog_per_gateway(self) + + @property + def 
concurrent_tasks(self) -> int: + if self._concurrent_tasks is None: + self._concurrent_tasks = self.connection_config.concurrent_tasks + return self._concurrent_tasks + + @cached_property + def connection_config(self) -> ConnectionConfig: + return self.config.get_connection(self.selected_gateway) + + @cached_property + def test_connection_config(self) -> ConnectionConfig: + return self.config.get_test_connection( + self.gateway, + self.default_catalog, + default_catalog_dialect=self.config.dialect, ) - local_nodes = {**(models_override or self._models), **self._standalone_audits} - nodes = local_nodes.copy() - audits = self._audits.copy() - projects = {config.project for config in self.configs.values()} - - for name, snapshot in remote_snapshots.items(): - if name not in nodes and snapshot.node.project not in projects: - nodes[name] = snapshot.node - if snapshot.is_model: - for audit in snapshot.audits: - if name not in audits: - audits[name] = audit - - def _nodes_to_snapshots(nodes: t.Dict[str, Node]) -> t.Dict[str, Snapshot]: - snapshots: t.Dict[str, Snapshot] = {} - fingerprint_cache: t.Dict[str, SnapshotFingerprint] = {} - - for node in nodes.values(): - if node.fqn not in local_nodes and node.fqn in remote_snapshots: - ttl = remote_snapshots[node.fqn].ttl - else: - config = self.config_for_node(node) - ttl = config.snapshot_ttl - - snapshot = Snapshot.from_node( - node, - nodes=nodes, - audits=audits, - cache=fingerprint_cache, - ttl=ttl, - ) - snapshots[snapshot.name] = snapshot - return snapshots + @cached_property + def environment_catalog_mapping(self) -> RegexKeyDict: + engine_adapter = None + try: + engine_adapter = self.engine_adapter + except Exception: + pass - snapshots = _nodes_to_snapshots(nodes) + if ( + self.config.environment_catalog_mapping + and engine_adapter + and not self.engine_adapter.catalog_support.is_multi_catalog_supported + ): + raise SQLMeshError( + "Environment catalog mapping is only supported for engine adapters that support 
multiple catalogs" + ) + return self.config.environment_catalog_mapping + + def _get_engine_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter: + if gateway: + if adapter := self.engine_adapters.get(gateway): + return adapter + raise SQLMeshError(f"Gateway '{gateway}' not found in the available engine adapters.") + return self.engine_adapter + + def _snapshots( + self, models_override: t.Optional[UniqueKeyDict[str, Model]] = None + ) -> t.Dict[str, Snapshot]: + nodes = {**(models_override or self._models), **self._standalone_audits} + snapshots = self._nodes_to_snapshots(nodes) stored_snapshots = self.state_reader.get_snapshots(snapshots.values()) unrestorable_snapshots = { snapshot for snapshot in stored_snapshots.values() - if snapshot.name in local_nodes and snapshot.unrestorable + if snapshot.name in nodes and snapshot.unrestorable } if unrestorable_snapshots: for snapshot in unrestorable_snapshots: logger.info( "Found a unrestorable snapshot %s. Restamping the model...", snapshot.name ) - node = local_nodes[snapshot.name] + node = nodes[snapshot.name] nodes[snapshot.name] = node.copy( update={"stamp": f"revert to {snapshot.identifier}"} ) - snapshots = _nodes_to_snapshots(nodes) + snapshots = self._nodes_to_snapshots(nodes) stored_snapshots = self.state_reader.get_snapshots(snapshots.values()) for snapshot in stored_snapshots.values(): @@ -1932,6 +2848,8 @@ def _context_diff( create_from: t.Optional[str] = None, force_no_diff: bool = False, ensure_finalized_snapshots: bool = False, + diff_rendered: bool = False, + always_recreate_environment: bool = False, ) -> ContextDiff: environment = Environment.sanitize_name(environment) if force_no_diff: @@ -1942,21 +2860,70 @@ def _context_diff( snapshots=snapshots or self.snapshots, create_from=create_from or c.PROD, state_reader=self.state_reader, + provided_requirements=self._requirements, + excluded_requirements=self._excluded_requirements, ensure_finalized_snapshots=ensure_finalized_snapshots, + 
diff_rendered=diff_rendered, + environment_statements=self._environment_statements, + gateway_managed_virtual_layer=self.config.gateway_managed_virtual_layer, + infer_python_dependencies=self.config.infer_python_dependencies, + always_recreate_environment=always_recreate_environment, ) + def _destroy(self) -> bool: + # Invalidate all environments, including prod + for environment in self.state_reader.get_environments(): + self.state_sync.invalidate_environment(name=environment.name, protect_prod=False) + self.console.log_success(f"Environment '{environment.name}' invalidated.") + + # Run janitor to clean up all objects + self._run_janitor(ignore_ttl=True) + + # Remove state tables, including backup tables + self.state_sync.remove_state(including_backup=True) + self.console.log_status_update("State tables removed.") + + # Finally clear caches + self.clear_caches() + + return True + def _run_janitor(self, ignore_ttl: bool = False) -> None: - self._cleanup_environments() - expired_snapshots = self.state_sync.delete_expired_snapshots(ignore_ttl=ignore_ttl) - self.snapshot_evaluator.cleanup( - expired_snapshots, on_complete=self.console.update_cleanup_progress - ) + current_ts = now_timestamp() + # Clean up expired environments by removing their views and schemas + self._cleanup_environments(current_ts=current_ts) + + delete_expired_snapshots( + self.state_sync, + self.snapshot_evaluator, + current_ts=current_ts, + ignore_ttl=ignore_ttl, + console=self.console, + batch_size=self.config.janitor.expired_snapshots_batch_size, + ) self.state_sync.compact_intervals() - def _cleanup_environments(self) -> None: - expired_environments = self.state_sync.delete_expired_environments() - cleanup_expired_views(self.engine_adapter, expired_environments, console=self.console) + def _cleanup_environments(self, current_ts: t.Optional[int] = None) -> None: + current_ts = current_ts or now_timestamp() + + expired_environments_summaries = self.state_sync.get_expired_environments( + 
current_ts=current_ts + ) + + for expired_env_summary in expired_environments_summaries: + expired_env = self.state_reader.get_environment(expired_env_summary.name) + + if expired_env: + cleanup_expired_views( + default_adapter=self.engine_adapter, + engine_adapters=self.engine_adapters, + environments=[expired_env], + warn_on_delete_failure=self.config.janitor.warn_on_delete_failure, + console=self.console, + ) + + self.state_sync.delete_expired_environments(current_ts=current_ts) def _try_connection(self, connection_name: str, validator: t.Callable[[], None]) -> None: connection_name = connection_name.capitalize() @@ -1969,13 +2936,17 @@ def _try_connection(self, connection_name: str, validator: t.Callable[[], None]) def _new_state_sync(self) -> StateSync: return self._provided_state_sync or self._scheduler.create_state_sync(self) - def _new_selector(self) -> Selector: - return Selector( + def _new_selector( + self, models: t.Optional[UniqueKeyDict[str, Model]] = None, dag: t.Optional[DAG[str]] = None + ) -> Selector: + return self._selector_cls( self.state_reader, - self._models, + models=models or self._models, context_path=self.path, + dag=dag, default_catalog=self.default_catalog, dialect=self.default_dialect, + cache_dir=self.cache_dir, ) def _register_notification_targets(self) -> None: @@ -1994,6 +2965,299 @@ def _register_notification_targets(self) -> None: event_notifications, user_notification_targets, username=self.config.username ) + def _load_materializations(self) -> None: + if not self._loaded: + for loader in self._loaders: + loader.load_materializations() + + def _select_models_for_run( + self, + select_models: t.Collection[str], + no_auto_upstream: bool, + snapshots: t.Collection[Snapshot], + ) -> t.Set[str]: + models: UniqueKeyDict[str, Model] = UniqueKeyDict( + "models", **{s.name: s.model for s in snapshots if s.is_model} + ) + dag: DAG[str] = DAG() + for fqn, model in models.items(): + dag.add(fqn, model.depends_on) + model_selector = 
self._new_selector(models=models, dag=dag) + result = set(model_selector.expand_model_selections(select_models)) + if not no_auto_upstream: + result = set(dag.subdag(*result)) + return result + + @cached_property + def _project_type(self) -> str: + project_types = { + c.DBT if loader.__class__.__name__.lower().startswith(c.DBT) else c.NATIVE + for loader in self._loaders + } + return c.HYBRID if len(project_types) > 1 else first(project_types) + + def _nodes_to_snapshots(self, nodes: t.Dict[str, Node]) -> t.Dict[str, Snapshot]: + snapshots: t.Dict[str, Snapshot] = {} + fingerprint_cache: t.Dict[str, SnapshotFingerprint] = {} + + for node in nodes.values(): + kwargs: t.Dict[str, t.Any] = {} + if node.project in self._projects: + config = self.config_for_node(node) + kwargs["ttl"] = config.snapshot_ttl + kwargs["table_naming_convention"] = config.physical_table_naming_convention + + snapshot = Snapshot.from_node( + node, + nodes=nodes, + cache=fingerprint_cache, + **kwargs, + ) + snapshots[snapshot.name] = snapshot + return snapshots + + def _node_or_snapshot_to_fqn(self, node_or_snapshot: NodeOrSnapshot) -> str: + if isinstance(node_or_snapshot, Snapshot): + return node_or_snapshot.name + if isinstance(node_or_snapshot, str) and not self.standalone_audits.get(node_or_snapshot): + return normalize_model_name( + node_or_snapshot, + dialect=self.default_dialect, + default_catalog=self.default_catalog, + ) + if not isinstance(node_or_snapshot, str): + return node_or_snapshot.fqn + return node_or_snapshot + + @property + def _plan_preview_enabled(self) -> bool: + if self.config.plan.enable_preview is not None: + return self.config.plan.enable_preview + # It is dangerous to enable preview by default for dbt projects that rely on engines that don't support cloning. + # Enabling previews in such cases can result in unintended full refreshes because dbt incremental models rely on + # the maximum timestamp value in the target table. 
+ return self._project_type == c.NATIVE or self.engine_adapter.SUPPORTS_CLONING + + def _get_plan_default_start_end( + self, + snapshots: t.Dict[str, Snapshot], + max_interval_end_per_model: t.Dict[str, datetime], + backfill_models: t.Optional[t.Set[str]], + modified_model_names: t.Set[str], + execution_time: t.Optional[TimeLike] = None, + ) -> t.Tuple[t.Optional[int], t.Optional[int]]: + # exclude seeds so their stale interval ends does not become the default plan end date + # when they're the only ones that contain intervals in this plan + non_seed_interval_ends = { + model_fqn: end + for model_fqn, end in max_interval_end_per_model.items() + if model_fqn not in snapshots or not snapshots[model_fqn].is_seed + } + if not non_seed_interval_ends: + return None, None + + default_end = to_timestamp(max(non_seed_interval_ends.values())) + default_start: t.Optional[int] = None + # Infer the default start by finding the smallest interval start that corresponds to the default end. + for model_name in backfill_models or modified_model_names or max_interval_end_per_model: + if model_name not in snapshots: + continue + node = snapshots[model_name].node + interval_unit = node.interval_unit + default_start = min( + default_start or sys.maxsize, + to_timestamp( + interval_unit.cron_prev( + interval_unit.cron_floor( + max_interval_end_per_model.get( + model_name, node.cron_floor(default_end) + ), + ), + estimate=True, + ) + ), + ) + + if execution_time and to_timestamp(default_end) > to_timestamp(execution_time): + # the end date can't be in the future, which can happen if a specific `execution_time` is set and prod intervals + # are newer than it + default_end = to_timestamp(execution_time) + + return default_start, default_end + + def _calculate_start_override_per_model( + self, + min_intervals: t.Optional[int], + plan_start: t.Optional[TimeLike], + plan_end: t.Optional[TimeLike], + plan_execution_time: TimeLike, + backfill_model_fqns: t.Optional[t.Set[str]], + 
snapshots_by_model_fqn: t.Dict[str, Snapshot], + end_override_per_model: t.Optional[t.Dict[str, datetime]], + ) -> t.Dict[str, datetime]: + if not min_intervals or not backfill_model_fqns or not plan_start: + # If there are no models to backfill, there are no intervals to consider for backfill, so we dont need to consider a minimum number + # If the plan doesnt have a start date, all intervals are considered already so we dont need to consider a minimum number + # If we dont have a minimum number of intervals to consider, then we dont need to adjust the start date on a per-model basis + return {} + + start_overrides: t.Dict[str, datetime] = {} + end_override_per_model = end_override_per_model or {} + + plan_execution_time_dt = to_datetime(plan_execution_time) + plan_start_dt = to_datetime(plan_start, relative_base=plan_execution_time_dt) + plan_end_dt = to_datetime( + plan_end or plan_execution_time_dt, relative_base=plan_execution_time_dt + ) + + # we need to take the DAG into account so that parent models can be expanded to cover at least as much as their children + # for example, A(hourly) <- B(daily) + # if min_intervals=1, A would have 1 hour and B would have 1 day + # but B depends on A so in order for B to have 1 valid day, A needs to be expanded to 24 hours + backfill_dag: DAG[str] = DAG() + for fqn in backfill_model_fqns: + backfill_dag.add( + fqn, + [ + p.name + for p in snapshots_by_model_fqn[fqn].parents + if p.name in backfill_model_fqns + ], + ) + + # start from the leaf nodes and work back towards the root because the min_start at the root node is determined by the calculated starts in the leaf nodes + reversed_dag = backfill_dag.reversed + graph = reversed_dag.graph + + for model_fqn in reversed_dag: + # Get the earliest start from all immediate children of this snapshot + # this works because topological ordering guarantees that they've already been visited + # and we always set a start override + min_child_start = min( + 
[start_overrides[immediate_child_fqn] for immediate_child_fqn in graph[model_fqn]], + default=plan_start_dt, + ) + + snapshot = snapshots_by_model_fqn.get(model_fqn) + + if not snapshot: + continue + + starting_point = end_override_per_model.get(model_fqn, plan_end_dt) + if node_end := snapshot.node.end: + # if we dont do this, if the node end is a *date* (as opposed to a timestamp) + # we end up incorrectly winding back an extra day + node_end_dt = make_exclusive(node_end) + + if node_end_dt < plan_end_dt: + # if the model has an end date that has already elapsed, use that as a starting point for calculating min_intervals + # instead of the plan end. If we use the plan end, we will return intervals in the future which are invalid + starting_point = node_end_dt + + snapshot_start = snapshot.node.cron_floor(starting_point) + + for _ in range(min_intervals): + # wind back the starting point by :min_intervals intervals to arrive at the minimum snapshot start date + snapshot_start = snapshot.node.cron_prev(snapshot_start) + + start_overrides[model_fqn] = min(min_child_start, snapshot_start) + + return start_overrides + + def _get_max_interval_end_per_model( + self, snapshots: t.Dict[str, Snapshot], backfill_models: t.Optional[t.Set[str]] + ) -> t.Dict[str, datetime]: + models_for_interval_end = ( + self._get_models_for_interval_end(snapshots, backfill_models) + if backfill_models is not None + else None + ) + return { + model_fqn: to_datetime(ts) + for model_fqn, ts in self.state_sync.max_interval_end_per_model( + c.PROD, + models=models_for_interval_end, + ensure_finalized_snapshots=self.config.plan.use_finalized_state, + ).items() + } + + @staticmethod + def _get_models_for_interval_end( + snapshots: t.Dict[str, Snapshot], backfill_models: t.Set[str] + ) -> t.Set[str]: + models_for_interval_end = set() + models_stack = list(backfill_models) + while models_stack: + next_model = models_stack.pop() + if next_model not in snapshots: + continue + 
models_for_interval_end.add(next_model) + models_stack.extend( + s.name + for s in snapshots[next_model].parents + if s.name not in models_for_interval_end + ) + return models_for_interval_end + + def lint_models( + self, + models: t.Optional[t.Iterable[t.Union[str, Model]]] = None, + raise_on_error: bool = True, + ) -> t.List[AnnotatedRuleViolation]: + found_error = False + + model_list = ( + list(self.get_model(model, raise_if_missing=True) for model in models) + if models + else self.models.values() + ) + all_violations = [] + for model in model_list: + # Linter may be `None` if the context is not loaded yet + if linter := self._linters.get(model.project): + lint_violation, violations = ( + linter.lint_model(model, self, console=self.console) or found_error + ) + if lint_violation: + found_error = True + all_violations.extend(violations) + + if raise_on_error and found_error: + raise LinterError( + "Linter detected errors in the code. Please fix them before proceeding." + ) + + return all_violations + + def select_tests( + self, + tests: t.Optional[t.List[str]] = None, + patterns: t.Optional[t.List[str]] = None, + ) -> t.List[ModelTestMetadata]: + """Filter pre-loaded test metadata based on tests and patterns.""" + + test_meta = self._model_test_metadata + + if tests: + filtered_tests = [] + for test in tests: + if "::" in test: + if test in self._model_test_metadata_fully_qualified_name_index: + filtered_tests.append( + self._model_test_metadata_fully_qualified_name_index[test] + ) + else: + test_path = Path(test) + if test_path in self._model_test_metadata_path_index: + filtered_tests.extend(self._model_test_metadata_path_index[test_path]) + + test_meta = filtered_tests + + if patterns: + test_meta = filter_tests_by_patterns(test_meta, patterns) + + return test_meta + class Context(GenericContext[Config]): CONFIG_TYPE = Config diff --git a/sqlmesh/core/context_diff.py b/sqlmesh/core/context_diff.py index 11d4861c45..07d13b1c2f 100644 --- 
a/sqlmesh/core/context_diff.py +++ b/sqlmesh/core/context_diff.py @@ -12,18 +12,31 @@ from __future__ import annotations -import logging +import sys import typing as t +from difflib import ndiff, unified_diff from functools import cached_property - +from sqlmesh.core import constants as c +from sqlmesh.core.console import get_console +from sqlmesh.core.macros import RuntimeStage +from sqlmesh.core.model.common import sorted_python_env_payloads from sqlmesh.core.snapshot import Snapshot, SnapshotId, SnapshotTableInfo from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.pydantic import PydanticModel +if sys.version_info >= (3, 12): + from importlib import metadata +else: + import importlib_metadata as metadata # type: ignore + + if t.TYPE_CHECKING: from sqlmesh.core.state_sync import StateReader -logger = logging.getLogger(__name__) +from sqlmesh.utils.metaprogramming import Executable # noqa +from sqlmesh.core.environment import EnvironmentStatements + +IGNORED_PACKAGES = {"sqlmesh", "sqlglot"} class ContextDiff(PydanticModel): @@ -41,8 +54,14 @@ class ContextDiff(PydanticModel): """Whether the currently stored environment record is in unfinalized state.""" normalize_environment_name: bool """Whether the environment name should be normalized.""" + previous_gateway_managed_virtual_layer: bool + """Whether the previous environment's virtual layer's views were created by the model specified gateways.""" + gateway_managed_virtual_layer: bool + """Whether the virtual layer's views will be created by the model specified gateways.""" create_from: str """The name of the environment the target environment will be created from if new.""" + create_from_env_exists: bool + """Whether the create_from environment already exists at plan time.""" added: t.Set[SnapshotId] """New nodes.""" removed_snapshots: t.Dict[SnapshotId, SnapshotTableInfo] @@ -59,6 +78,16 @@ class ContextDiff(PydanticModel): """Snapshot IDs that were promoted by the previous plan.""" 
previous_finalized_snapshots: t.Optional[t.List[SnapshotTableInfo]] """Snapshots from the previous finalized state.""" + previous_requirements: t.Dict[str, str] = {} + """Previous requirements.""" + requirements: t.Dict[str, str] = {} + """Python dependencies.""" + previous_environment_statements: t.List[EnvironmentStatements] = [] + """Previous environment statements.""" + environment_statements: t.List[EnvironmentStatements] + """Environment statements.""" + diff_rendered: bool = False + """Whether the diff should compare raw vs rendered models""" @classmethod def create( @@ -68,6 +97,13 @@ def create( create_from: str, state_reader: StateReader, ensure_finalized_snapshots: bool = False, + provided_requirements: t.Optional[t.Dict[str, str]] = None, + excluded_requirements: t.Optional[t.Set[str]] = None, + diff_rendered: bool = False, + environment_statements: t.Optional[t.List[EnvironmentStatements]] = [], + gateway_managed_virtual_layer: bool = False, + infer_python_dependencies: bool = True, + always_recreate_environment: bool = False, ) -> ContextDiff: """Create a ContextDiff object. @@ -80,18 +116,37 @@ def create( ensure_finalized_snapshots: Whether to compare against snapshots from the latest finalized environment state, or to use whatever snapshots are in the current environment state even if the environment is not finalized. + provided_requirements: Python dependencies sourced from the lock file. + excluded_requirements: Python dependencies to exclude. + diff_rendered: Whether to compute the diff of the rendered version of the compared expressions. + environment_statements: A list of `before_all` or `after_all` statements associated with the environment. + gateway_managed_virtual_layer: Whether the models' views in the virtual layer are created by the + model-specific gateway rather than the default gateway. + infer_python_dependencies: Whether to statically analyze Python code to automatically infer Python + package requirements. 
Returns: The ContextDiff object. """ environment = environment.lower() - env = state_reader.get_environment(environment) + existing_env = state_reader.get_environment(environment) + create_from_env_exists = False - if env is None: + recreate_environment = always_recreate_environment and not environment == create_from + + if existing_env is None or existing_env.expired or recreate_environment: env = state_reader.get_environment(create_from.lower()) + + if not env and create_from != c.PROD: + get_console().log_warning( + f"The environment name '{create_from}' was passed to the `plan` command's `--create-from` argument, but '{create_from}' does not exist. Initializing new environment '{environment}' from scratch." + ) + is_new_environment = True + create_from_env_exists = env is not None previously_promoted_snapshot_ids = set() else: + env = existing_env is_new_environment = False previously_promoted_snapshot_ids = {s.snapshot_id for s in env.promoted_snapshots} @@ -160,20 +215,44 @@ def create( stored[modified_snapshot_info.snapshot_id], ) + requirements = _build_requirements( + provided_requirements or {}, + excluded_requirements or set(), + snapshots.values(), + infer_python_dependencies=infer_python_dependencies, + ) + + previous_environment_statements = ( + state_reader.get_environment_statements(env.name) if env else [] + ) + + if existing_env and always_recreate_environment: + previous_plan_id: t.Optional[str] = existing_env.plan_id + else: + previous_plan_id = env.plan_id if env and not is_new_environment else None + return ContextDiff( environment=environment, is_new_environment=is_new_environment, is_unfinalized_environment=bool(env and not env.finalized_ts), normalize_environment_name=is_new_environment or bool(env and env.normalize_name), create_from=create_from, + create_from_env_exists=create_from_env_exists, added=added, removed_snapshots=removed, modified_snapshots=modified_snapshots, snapshots=merged_snapshots, new_snapshots=new_snapshots, - 
previous_plan_id=env.plan_id if env and not is_new_environment else None, + previous_plan_id=previous_plan_id, previously_promoted_snapshot_ids=previously_promoted_snapshot_ids, previous_finalized_snapshots=env.previous_finalized_snapshots if env else None, + previous_requirements=env.requirements if env else {}, + requirements=requirements, + diff_rendered=diff_rendered, + previous_environment_statements=previous_environment_statements, + environment_statements=environment_statements, + previous_gateway_managed_virtual_layer=env.gateway_managed if env else False, + gateway_managed_virtual_layer=gateway_managed_virtual_layer, ) @classmethod @@ -191,6 +270,7 @@ def create_no_diff(cls, environment: str, state_reader: StateReader) -> ContextD if not env: raise SQLMeshError(f"Environment '{environment}' must exist for this operation.") + environment_statements = state_reader.get_environment_statements(environment) snapshots = state_reader.get_snapshots(env.snapshots) return ContextDiff( @@ -199,6 +279,7 @@ def create_no_diff(cls, environment: str, state_reader: StateReader) -> ContextD is_unfinalized_environment=False, normalize_environment_name=env.normalize_name, create_from="", + create_from_env_exists=False, added=set(), removed_snapshots={}, modified_snapshots={}, @@ -207,12 +288,33 @@ def create_no_diff(cls, environment: str, state_reader: StateReader) -> ContextD previous_plan_id=env.plan_id, previously_promoted_snapshot_ids={s.snapshot_id for s in env.promoted_snapshots}, previous_finalized_snapshots=env.previous_finalized_snapshots, + previous_requirements=env.requirements, + requirements=env.requirements, + previous_environment_statements=environment_statements, + environment_statements=environment_statements, + previous_gateway_managed_virtual_layer=env.gateway_managed, + gateway_managed_virtual_layer=env.gateway_managed, ) @property def has_changes(self) -> bool: return ( - self.has_snapshot_changes or self.is_new_environment or 
self.is_unfinalized_environment + self.has_snapshot_changes + or self.is_new_environment + or self.is_unfinalized_environment + or self.has_requirement_changes + or self.has_environment_statements_changes + or self.previous_gateway_managed_virtual_layer != self.gateway_managed_virtual_layer + ) + + @property + def has_requirement_changes(self) -> bool: + return self.previous_requirements != self.requirements + + @property + def has_environment_statements_changes(self) -> bool: + return sorted(self.environment_statements, key=lambda s: s.project or "") != sorted( + self.previous_environment_statements, key=lambda s: s.project or "" ) @property @@ -251,6 +353,61 @@ def current_modified_snapshot_ids(self) -> t.Set[SnapshotId]: def snapshots_by_name(self) -> t.Dict[str, Snapshot]: return {x.name: x for x in self.snapshots.values()} + def requirements_diff(self) -> str: + return " " + "\n ".join( + ndiff( + [ + f"{k}=={self.previous_requirements[k]}" + for k in sorted(self.previous_requirements) + ], + [f"{k}=={self.requirements[k]}" for k in sorted(self.requirements)], + ) + ) + + def environment_statements_diff( + self, include_python_env: bool = False + ) -> t.List[t.Tuple[str, str]]: + def extract_statements(statements: t.List[EnvironmentStatements], attr: str) -> t.List[str]: + return [ + string + for statement in statements + for expr in ( + sorted_python_env_payloads(statement.python_env) + if attr == "python_env" + else getattr(statement, attr) + ) + for string in expr.split("\n") + ] + + def compute_diff(attribute: str) -> t.Optional[t.Tuple[str, str]]: + previous = extract_statements(self.previous_environment_statements, attribute) + current = extract_statements(self.environment_statements, attribute) + + if previous == current: + return None + + diff_text = attribute if not attribute == "python_env" else "dependencies" + diff_text += ":\n" + if attribute == "python_env": + diff = list(unified_diff(previous, current)) + diff_text += "\n".join(diff[2:] if 
len(diff) > 1 else diff) + return "python", diff_text + "\n" + + diff_lines = list(ndiff(previous, current)) + if any(line.startswith(("-", "+")) for line in diff_lines): + diff_text += " " + "\n ".join(diff_lines) + "\n" + return "sql", diff_text + + return [ + diff + for attribute in [ + RuntimeStage.BEFORE_ALL.value, + RuntimeStage.AFTER_ALL.value, + *(["python_env"] if include_python_env else []), + ] + if (diff := compute_diff(attribute)) is not None + ] + @property def environment_snapshots(self) -> t.List[SnapshotTableInfo]: """Returns current snapshots in the environment.""" @@ -278,7 +435,7 @@ def directly_modified(self, name: str) -> bool: return False current, previous = self.modified_snapshots[name] - return current.fingerprint.data_hash != previous.fingerprint.data_hash + return current.is_directly_modified(previous) def indirectly_modified(self, name: str) -> bool: """Returns whether or not a node was indirectly modified in this context. @@ -294,10 +451,7 @@ def indirectly_modified(self, name: str) -> bool: return False current, previous = self.modified_snapshots[name] - return ( - current.fingerprint.data_hash == previous.fingerprint.data_hash - and current.fingerprint.parent_data_hash != previous.fingerprint.parent_data_hash - ) + return current.is_indirectly_modified(previous) def metadata_updated(self, name: str) -> bool: """Returns whether or not the given node's metadata has been updated. @@ -313,7 +467,7 @@ def metadata_updated(self, name: str) -> bool: return False current, previous = self.modified_snapshots[name] - return current.fingerprint.metadata_hash != previous.fingerprint.metadata_hash + return current.is_metadata_updated(previous) def text_diff(self, name: str) -> str: """Finds the difference of a node between the current and remote environment. 
@@ -331,7 +485,51 @@ def text_diff(self, name: str) -> str: new, old = self.modified_snapshots[name] try: - return old.node.text_diff(new.node) + return old.node.text_diff(new.node, rendered=self.diff_rendered) except SQLMeshError as e: - logger.warning("Failed to diff model '%s': %s", name, str(e)) + get_console().log_warning(f"Failed to diff model '{name}': {str(e)}.") return "" + + +def _build_requirements( + provided_requirements: t.Dict[str, str], + excluded_requirements: t.Set[str], + snapshots: t.Collection[Snapshot], + infer_python_dependencies: bool = True, +) -> t.Dict[str, str]: + requirements = { + k: v for k, v in provided_requirements.items() if k not in excluded_requirements + } + + if not infer_python_dependencies: + return requirements + + distributions = metadata.packages_distributions() + + for snapshot in snapshots: + if not snapshot.is_model: + continue + + for executable in snapshot.model.python_env.values(): + if executable.kind != "import": + continue + + try: + start = "from " if executable.payload.startswith("from ") else "import " + lib = executable.payload.split(start)[1].split()[0].split(".")[0] + if lib not in distributions: + continue + + for dist in distributions[lib]: + if ( + dist not in requirements + and dist not in IGNORED_PACKAGES + and dist not in excluded_requirements + ): + requirements[dist] = metadata.version(dist) + except metadata.PackageNotFoundError: + from sqlmesh.core.console import get_console + + get_console().log_warning(f"Failed to find package for {lib}.") + + return requirements diff --git a/sqlmesh/core/dialect.py b/sqlmesh/core/dialect.py index 19407487c4..c0a48326f2 100644 --- a/sqlmesh/core/dialect.py +++ b/sqlmesh/core/dialect.py @@ -1,17 +1,19 @@ from __future__ import annotations import functools +import logging import re import sys import typing as t from contextlib import contextmanager from difflib import unified_diff from enum import Enum, auto +from functools import lru_cache -import pandas as pd 
from sqlglot import Dialect, Generator, ParseError, Parser, Tokenizer, TokenType, exp from sqlglot.dialects.dialect import DialectType -from sqlglot.dialects.snowflake import Snowflake +from sqlglot.dialects import DuckDB, Snowflake, TSQL +import sqlglot.dialects.athena as athena from sqlglot.helper import seq_get from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import quote_identifiers @@ -21,10 +23,13 @@ from sqlglot.tokens import Token from sqlmesh.core.constants import MAX_MODEL_DEFINITION_SIZE -from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.utils import get_source_columns_to_types +from sqlmesh.utils.errors import SQLMeshError, ConfigError from sqlmesh.utils.pandas import columns_to_types_from_df if t.TYPE_CHECKING: + import pandas as pd + from sqlglot._typing import E @@ -32,6 +37,8 @@ TABLES_META = "sqlmesh.tables" +logger = logging.getLogger(__name__) + class Model(exp.Expression): arg_types = {"expressions": True} @@ -57,6 +64,10 @@ class JinjaStatement(Jinja): pass +class VirtualUpdateStatement(exp.Expression): + arg_types = {"expressions": True} + + class ModelKind(exp.Expression): arg_types = {"this": True, "expressions": False} @@ -101,15 +112,18 @@ def output_name(self) -> str: return self.this.name -class StagedFilePath(exp.Table): +class StagedFilePath(exp.Expression): """Represents paths to "staged files" in Snowflake.""" + arg_types = exp.Table.arg_types.copy() + def _parse_statement(self: Parser) -> t.Optional[exp.Expression]: if self._curr is None: return None parser = PARSERS.get(self._curr.text.upper()) + error_msg = None if parser: # Capture any available description in the form of a comment @@ -119,7 +133,8 @@ def _parse_statement(self: Parser) -> t.Optional[exp.Expression]: try: self._advance() meta = self._parse_wrapped(lambda: t.cast(t.Callable, parser)(self)) - except ParseError: + except ParseError as parse_error: + error_msg = parse_error.args[0] 
self._retreat(index) # Only return the DDL expression if we actually managed to parse one. This is @@ -129,7 +144,12 @@ def _parse_statement(self: Parser) -> t.Optional[exp.Expression]: meta.comments = comments return meta - return self.__parse_statement() # type: ignore + try: + return self.__parse_statement() # type: ignore + except ParseError: + if error_msg: + raise ParseError(error_msg) + raise def _parse_lambda(self: Parser, alias: bool = False) -> t.Optional[exp.Expression]: @@ -154,6 +174,7 @@ def _parse_id_var( while ( identifier + and not identifier.args.get("quoted") and self._is_connected() and ( self._match_texts(("{", SQLMESH_MACRO_PREFIX)) @@ -311,6 +332,46 @@ def _parse_join( return macro +def _warn_unsupported(self: Parser) -> None: + from sqlmesh.core.console import get_console + + sql = self._find_sql(self._tokens[0], self._tokens[-1])[: self.error_message_context] + + get_console().log_warning( + f"'{sql}' could not be semantically understood as it contains unsupported syntax, SQLMesh will treat the command as is. 
Note that any references to the model's " + "underlying physical table can't be resolved in this case, consider using Jinja as explained here https://sqlmesh.readthedocs.io/en/stable/concepts/macros/macro_variables/#audit-only-variables" + ) + + +def _parse_select( + self: Parser, + nested: bool = False, + table: bool = False, + parse_subquery_alias: bool = True, + parse_set_operation: bool = True, + consume_pipe: bool = True, + from_: t.Optional[exp.From] = None, +) -> t.Optional[exp.Expression]: + select = self.__parse_select( # type: ignore + nested=nested, + table=table, + parse_subquery_alias=parse_subquery_alias, + parse_set_operation=parse_set_operation, + consume_pipe=consume_pipe, + from_=from_, + ) + + if ( + not select + and not parse_set_operation + and self._match_pair(TokenType.PARAMETER, TokenType.VAR, advance=False) + ): + self._advance() + return _parse_macro(self) + + return select + + def _parse_where(self: Parser, skip_where_token: bool = False) -> t.Optional[exp.Expression]: macro = _parse_matching_macro(self, "WHERE") if not macro: @@ -363,14 +424,41 @@ def _parse_limit( return macro +def _parse_value(self: Parser, values: bool = True) -> t.Optional[exp.Expression]: + wrapped = self._match(TokenType.L_PAREN, advance=False) + + # The base _parse_value method always constructs a Tuple instance. This is problematic when + # generating values with a macro function, because it's impossible to tell whether the user's + # intention was to construct a row or a column with the VALUES expression. To avoid this, we + # amend the AST such that the Tuple is replaced by the macro function call itself. 
+ expr = self.__parse_value() # type: ignore + if expr and not wrapped and isinstance(seq_get(expr.expressions, 0), MacroFunc): + return expr.expressions[0] + + return expr + + +def _parse_macro_or_clause(self: Parser, parser: t.Callable) -> t.Optional[exp.Expression]: + return _parse_macro(self) if self._match(TokenType.PARAMETER) else parser() + + def _parse_props(self: Parser) -> t.Optional[exp.Expression]: key = self._parse_id_var(any_token=True) if not key: return None name = key.name.lower() - if name == "when_matched": - value: t.Optional[exp.Expression] = self._parse_when_matched()[0] + if name == "time_data_type": + # TODO: if we make *_data_type a convention to parse things into exp.DataType, we could make this more generic + value = self._parse_types(schema=True) + elif name == "when_matched": + # Parentheses around the WHEN clauses can be used to disambiguate them from other properties + value = self._parse_wrapped( + lambda: _parse_macro_or_clause(self, self._parse_when_matched), + optional=True, + ) + elif name == "merge_filter": + value = self._parse_conjunction() elif self._match(TokenType.L_PAREN): value = self.expression(exp.Tuple, expressions=self._parse_csv(self._parse_equality)) self._match_r_paren() @@ -379,7 +467,7 @@ def _parse_props(self: Parser) -> t.Optional[exp.Expression]: if name == "path" and value: # Make sure if we get a windows path that it is converted to posix - value = exp.Literal.string(value.this.replace("\\", "/")) + value = exp.Literal.string(value.this.replace("\\", "/")) # type: ignore return self.expression(exp.Property, this=name, value=value) @@ -408,17 +496,35 @@ def _parse_types( # # See: https://docs.snowflake.com/en/user-guide/querying-stage def _parse_table_parts( - self: Parser, schema: bool = False, is_db_reference: bool = False -) -> exp.Table: + self: Parser, schema: bool = False, is_db_reference: bool = False, wildcard: bool = False +) -> exp.Table | StagedFilePath: index = self._index - table = 
self.__parse_table_parts(schema=schema, is_db_reference=is_db_reference) # type: ignore + table = self.__parse_table_parts( # type: ignore + schema=schema, is_db_reference=is_db_reference, wildcard=wildcard + ) table_arg = table.this - name = table_arg.name - - if isinstance(table_arg, exp.Var) and name.startswith(SQLMESH_MACRO_PREFIX): - # Macro functions do not clash with the staged file syntax, so we can safely parse them - if self._prev.token_type == TokenType.STRING or any(ch in name for ch in ("(", "{")): + name = table_arg.name if isinstance(table_arg, exp.Var) else "" + + if name.startswith(SQLMESH_MACRO_PREFIX): + # In these cases, we don't want to produce a `StagedFilePath` node: + # + # - @'...' needs to parsed as a string template + # - @{foo}.bar needs to be parsed as a table with a macro var part + # - @name(arg1 [, arg2 ...]) needs to be parsed as a macro function call + # + # These cases can unambiguously be parsed using the base `_parse_table_parts`, as there + # is no overlap with staged files https://docs.snowflake.com/en/user-guide/querying-stage + if ( + self._prev.token_type == TokenType.STRING + or "{" in name + or ( + self._curr + and self._prev.token_type in (TokenType.L_PAREN, TokenType.R_PAREN) + and self._curr.text.upper() not in ("FILE_FORMAT", "PATTERN") + and not (table.args.get("format") or table.args.get("pattern")) + ) + ): self._retreat(index) return Parser._parse_table_parts(self, schema=schema, is_db_reference=is_db_reference) @@ -451,10 +557,16 @@ def _parse_if(self: Parser) -> t.Optional[exp.Expression]: else: self.raise_error("Expecting )") - return exp.Anonymous(this="IF", expressions=[cond, self._parse_statement()]) + index = self._index + stmt = self._parse_statement() + if self._curr: + self._retreat(index) + stmt = self._parse_as_command(self._tokens[index]) + + return exp.Anonymous(this="IF", expressions=[cond, stmt]) -def _create_parser(parser_type: t.Type[exp.Expression], table_keys: t.List[str]) -> t.Callable: +def 
_create_parser(expression_type: t.Type[exp.Expression], table_keys: t.List[str]) -> t.Callable: def parse(self: Parser) -> t.Optional[exp.Expression]: from sqlmesh.core.model.kind import ModelKindName @@ -481,19 +593,31 @@ def parse(self: Parser) -> t.Optional[exp.Expression]: if key in table_keys: value = self._parse_table_parts() + if value and self._prev.token_type == TokenType.STRING: + self.raise_error( + f"'{key}' property cannot be a string value: {value}. " + "Please use the identifier syntax instead, e.g. foo.bar instead of 'foo.bar'" + ) elif key == "columns": value = self._parse_schema() elif key == "kind": - id_var = self._parse_id_var(any_token=True) - if not id_var: - value = None + field = _parse_macro_or_clause(self, lambda: self._parse_id_var(any_token=True)) + + if not field or isinstance(field, (MacroVar, MacroFunc)): + value = field else: - kind = ModelKindName[id_var.name.upper()] + try: + kind = ModelKindName[field.name.upper()] + except KeyError: + raise SQLMeshError( + f"Model kind specified as '{field.name}', but that is not a valid model kind.\n\nPlease specify one of {', '.join(ModelKindName)}." 
+ ) if kind in ( ModelKindName.INCREMENTAL_BY_TIME_RANGE, ModelKindName.INCREMENTAL_BY_UNIQUE_KEY, ModelKindName.INCREMENTAL_BY_PARTITION, + ModelKindName.INCREMENTAL_UNMANAGED, ModelKindName.SEED, ModelKindName.VIEW, ModelKindName.SCD_TYPE_2, @@ -505,13 +629,15 @@ def parse(self: Parser) -> t.Optional[exp.Expression]: else: props = None - value = self.expression( - ModelKind, - this=kind.value, - expressions=props, - ) + value = self.expression(ModelKind, this=kind.value, expressions=props) elif key == "expression": value = self._parse_conjunction() + elif key == "partitioned_by": + partitioned_by = self._parse_partitioned_by() + if isinstance(partitioned_by.this, exp.Schema): + value = exp.tuple_(*partitioned_by.this.expressions) + else: + value = partitioned_by.this else: value = self._parse_bracket(self._parse_field(any_token=True)) @@ -520,7 +646,7 @@ def parse(self: Parser) -> t.Optional[exp.Expression]: expressions.append(self.expression(exp.Property, this=key, value=value)) - return self.expression(parser_type, expressions=expressions) + return self.expression(expression_type, expressions=expressions) return parse @@ -537,15 +663,29 @@ def _props_sql(self: Generator, expressions: t.List[exp.Expression]) -> str: size = len(expressions) for i, prop in enumerate(expressions): - sql = self.indent(f"{prop.name} {self.sql(prop, 'value')}") + if isinstance(prop, MacroFunc): + sql = self.indent(self.sql(prop, comment=False)) + else: + sql = self.indent(f"{prop.name} {self.sql(prop, 'value')}") if i < size - 1: sql += "," + props.append(self.maybe_comment(sql, expression=prop)) return "\n".join(props) +def _on_virtual_update_sql(self: Generator, expressions: t.List[exp.Expression]) -> str: + statements = "\n".join( + self.sql(expression) + if isinstance(expression, JinjaStatement) + else f"{self.sql(expression)};" + for expression in expressions + ) + return f"{ON_VIRTUAL_UPDATE_BEGIN};\n{statements}\n{ON_VIRTUAL_UPDATE_END};" + + def _sqlmesh_ddl_sql(self: 
Generator, expression: Model | Audit | Metric, name: str) -> str: return "\n".join([f"{name} (", _props_sql(self, expression.expressions), ")"]) @@ -569,8 +709,19 @@ def _macro_func_sql(self: Generator, expression: MacroFunc) -> str: expression = expression.this name = expression.name if name in KEYWORD_MACROS: - return _macro_keyword_func_sql(self, expression) - return f"@{name}({self.format_args(*expression.expressions)})" + sql = _macro_keyword_func_sql(self, expression) + else: + sql = f"@{name}({self.format_args(*expression.expressions)})" + return self.maybe_comment(sql, expression) + + +def _whens_sql(self: Generator, expression: exp.Whens) -> str: + if isinstance(expression.parent, exp.Merge): + return self.whens_sql(expression) + + # If the `WHEN` clauses aren't part of a MERGE statement (e.g. they + # appear in the `MODEL` DDL), then we will wrap them with parentheses. + return self.wrap(self.expressions(expression, sep=" ", indent=False)) def _override(klass: t.Type[Tokenizer | Parser], func: t.Callable) -> None: @@ -580,49 +731,55 @@ def _override(klass: t.Type[Tokenizer | Parser], func: t.Callable) -> None: def format_model_expressions( - expressions: t.List[exp.Expression], dialect: t.Optional[str] = None, **kwargs: t.Any + expressions: t.List[exp.Expression], + dialect: t.Optional[str] = None, + rewrite_casts: bool = True, + **kwargs: t.Any, ) -> str: """Format a model's expressions into a standardized format. Args: expressions: The model's expressions, must be at least model def + query. dialect: The dialect to render the expressions as. + rewrite_casts: Whether to rewrite all casts to use the :: syntax. **kwargs: Additional keyword arguments to pass to the sql generator. Returns: A string representing the formatted model. 
""" - if len(expressions) == 1: + if len(expressions) == 1 and is_meta_expression(expressions[0]): return expressions[0].sql(pretty=True, dialect=dialect) - *statements, query = expressions + if rewrite_casts: - def cast_to_colon(node: exp.Expression) -> exp.Expression: - if isinstance(node, exp.Cast) and not any( - # Only convert CAST into :: if it doesn't have additional args set, otherwise this - # conversion could alter the semantics (eg. changing SAFE_CAST in BigQuery to CAST) - arg - for name, arg in node.args.items() - if name not in ("this", "to") - ): - this = node.this + def cast_to_colon(node: exp.Expression) -> exp.Expression: + if isinstance(node, exp.Cast) and not any( + # Only convert CAST into :: if it doesn't have additional args set, otherwise this + # conversion could alter the semantics (eg. changing SAFE_CAST in BigQuery to CAST) + arg + for name, arg in node.args.items() + if name not in ("this", "to") + ): + this = node.this - if not isinstance(this, (exp.Binary, exp.Unary)) or isinstance(this, exp.Paren): - cast = DColonCast(this=this, to=node.to) - cast.comments = node.comments - node = cast + if not isinstance(this, (exp.Binary, exp.Unary)) or isinstance(this, exp.Paren): + cast = DColonCast(this=this, to=node.to) + cast.comments = node.comments + node = cast - exp.replace_children(node, cast_to_colon) - return node + exp.replace_children(node, cast_to_colon) + return node + + new_expressions = [] + for expression in expressions: + expression = expression.copy() + exp.replace_children(expression, cast_to_colon) + new_expressions.append(expression) - query = query.copy() - exp.replace_children(query, cast_to_colon) + expressions = new_expressions return ";\n\n".join( - [ - *(statement.sql(pretty=True, dialect=dialect, **kwargs) for statement in statements), - query.sql(pretty=True, dialect=dialect, **kwargs), - ] + expression.sql(pretty=True, dialect=dialect, **kwargs) for expression in expressions ).strip() @@ -646,8 +803,15 @@ def 
text_diff( return "\n".join(unified_diff(a_sql, b_sql)) +WS_OR_COMMENT = r"(?:\s|--[^\n]*\n|/\*.*?\*/)" +HEADER = r"\b(?:model|audit)\b(?=\s*\()" +KEY_BOUNDARY = r"(?:\(|,)" # key is preceded by either '(' or ',' +DIALECT_VALUE = r"['\"]?(?P[a-z][a-z0-9]*)['\"]?" +VALUE_BOUNDARY = r"(?=,|\))" # value is followed by comma or closing paren + DIALECT_PATTERN = re.compile( - r"(model|audit).*?\(.*?dialect[^a-z,]+([a-z]*|,)", re.IGNORECASE | re.DOTALL + rf"{HEADER}.*?{KEY_BOUNDARY}{WS_OR_COMMENT}*dialect{WS_OR_COMMENT}+{DIALECT_VALUE}{WS_OR_COMMENT}*{VALUE_BOUNDARY}", + re.IGNORECASE | re.DOTALL, ) @@ -664,6 +828,8 @@ def _is_command_statement(command: str, tokens: t.List[Token], pos: int) -> bool JINJA_QUERY_BEGIN = "JINJA_QUERY_BEGIN" JINJA_STATEMENT_BEGIN = "JINJA_STATEMENT_BEGIN" JINJA_END = "JINJA_END" +ON_VIRTUAL_UPDATE_BEGIN = "ON_VIRTUAL_UPDATE_BEGIN" +ON_VIRTUAL_UPDATE_END = "ON_VIRTUAL_UPDATE_END" def _is_jinja_statement_begin(tokens: t.List[Token], pos: int) -> bool: @@ -686,10 +852,24 @@ def jinja_statement(statement: str) -> JinjaStatement: return JinjaStatement(this=exp.Literal.string(statement.strip())) +def _is_virtual_statement_begin(tokens: t.List[Token], pos: int) -> bool: + return _is_command_statement(ON_VIRTUAL_UPDATE_BEGIN, tokens, pos) + + +def _is_virtual_statement_end(tokens: t.List[Token], pos: int) -> bool: + return _is_command_statement(ON_VIRTUAL_UPDATE_END, tokens, pos) + + +def virtual_statement(statements: t.List[exp.Expression]) -> VirtualUpdateStatement: + return VirtualUpdateStatement(expressions=statements) + + class ChunkType(Enum): JINJA_QUERY = auto() JINJA_STATEMENT = auto() SQL = auto() + VIRTUAL_STATEMENT = auto() + VIRTUAL_JINJA_STATEMENT = auto() def parse_one( @@ -722,16 +902,23 @@ def parse( A list of the parsed expressions: [Model, *Statements, Query, *Statements] """ match = match_dialect and DIALECT_PATTERN.search(sql[:MAX_MODEL_DEFINITION_SIZE]) - dialect = Dialect.get_or_raise(match.group(2) if match else 
default_dialect) + dialect_str = match.group("dialect") if match else None + dialect = Dialect.get_or_raise(dialect_str or default_dialect) - tokens = dialect.tokenizer.tokenize(sql) + tokens = dialect.tokenize(sql) chunks: t.List[t.Tuple[t.List[Token], ChunkType]] = [([], ChunkType.SQL)] total = len(tokens) pos = 0 + virtual = False while pos < total: token = tokens[pos] - if _is_jinja_end(tokens, pos) or ( + if _is_virtual_statement_end(tokens, pos): + chunks[-1][0].append(token) + virtual = False + chunks.append(([], ChunkType.SQL)) + pos += 2 + elif _is_jinja_end(tokens, pos) or ( chunks[-1][1] == ChunkType.SQL and token.token_type == TokenType.SEMICOLON and pos < total - 1 @@ -742,13 +929,24 @@ def parse( # Jinja end statement chunks[-1][0].append(token) pos += 2 - chunks.append(([], ChunkType.SQL)) + chunks.append( + ( + [], + ChunkType.VIRTUAL_STATEMENT + if virtual and tokens[pos] != ON_VIRTUAL_UPDATE_END + else ChunkType.SQL, + ) + ) elif _is_jinja_query_begin(tokens, pos): chunks.append(([token], ChunkType.JINJA_QUERY)) pos += 2 elif _is_jinja_statement_begin(tokens, pos): chunks.append(([token], ChunkType.JINJA_STATEMENT)) pos += 2 + elif _is_virtual_statement_begin(tokens, pos): + chunks.append(([token], ChunkType.VIRTUAL_STATEMENT)) + pos += 2 + virtual = True else: chunks[-1][0].append(token) pos += 1 @@ -756,22 +954,68 @@ def parse( parser = dialect.parser() expressions: t.List[exp.Expression] = [] - for chunk, chunk_type in chunks: - if chunk_type == ChunkType.SQL: - parsed_expressions: t.List[t.Optional[exp.Expression]] = ( - parser.parse(chunk, sql) if into is None else parser.parse_into(into, chunk, sql) - ) - for expression in parsed_expressions: - if expression: + def parse_sql_chunk(chunk: t.List[Token], meta_sql: bool = True) -> t.List[exp.Expression]: + parsed_expressions: t.List[t.Optional[exp.Expression]] = ( + parser.parse(chunk, sql) if into is None else parser.parse_into(into, chunk, sql) + ) + expressions = [] + for expression in 
parsed_expressions: + if expression: + if meta_sql: expression.meta["sql"] = parser._find_sql(chunk[0], chunk[-1]) - expressions.append(expression) - else: - start, *_, end = chunk - segment = sql[start.end + 2 : end.start - 1] - factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement - expression = factory(segment.strip()) + expressions.append(expression) + return expressions + + def parse_jinja_chunk(chunk: t.List[Token], meta_sql: bool = True) -> exp.Expression: + start, *_, end = chunk + segment = sql[start.end + 2 : end.start - 1] + factory = jinja_query if chunk_type == ChunkType.JINJA_QUERY else jinja_statement + expression = factory(segment.strip()) + if meta_sql: expression.meta["sql"] = sql[start.start : end.end + 1] - expressions.append(expression) + return expression + + def parse_virtual_statement( + chunks: t.List[t.Tuple[t.List[Token], ChunkType]], pos: int + ) -> t.Tuple[t.List[exp.Expression], int]: + # For virtual statements we need to handle both SQL and Jinja nested blocks within the chunk + virtual_update_statements = [] + start = chunks[pos][0][0].start + + while ( + chunks[pos - 1][0] == [] or chunks[pos - 1][0][-1].text.upper() != ON_VIRTUAL_UPDATE_END + ): + chunk, chunk_type = chunks[pos] + if chunk_type == ChunkType.JINJA_STATEMENT: + virtual_update_statements.append(parse_jinja_chunk(chunk, False)) + else: + virtual_update_statements.extend( + parse_sql_chunk( + chunk[int(chunk[0].text.upper() == ON_VIRTUAL_UPDATE_BEGIN) : -1], False + ), + ) + pos += 1 + + if virtual_update_statements: + statements = virtual_statement(virtual_update_statements) + end = chunk[-1].end + 1 + statements.meta["sql"] = sql[start:end] + return [statements], pos + + return [], pos + + pos = 0 + total_chunks = len(chunks) + while pos < total_chunks: + chunk, chunk_type = chunks[pos] + if chunk_type == ChunkType.VIRTUAL_STATEMENT: + virtual_expression, pos = parse_virtual_statement(chunks, pos) + expressions.extend(virtual_expression) + 
elif chunk_type == ChunkType.SQL: + expressions.extend(parse_sql_chunk(chunk)) + else: + expressions.append(parse_jinja_chunk(chunk)) + pos += 1 return expressions @@ -783,6 +1027,14 @@ def extend_sqlglot() -> None: generators = {Generator} for dialect in Dialect.classes.values(): + # Athena picks a different Tokenizer / Parser / Generator depending on the query + # so this ensures that the extra ones it defines are also extended + if dialect == athena.Athena: + tokenizers.add(athena._TrinoTokenizer) + parsers.add(athena._TrinoParser) + generators.add(athena._TrinoGenerator) + generators.add(athena._HiveGenerator) + if hasattr(dialect, "Tokenizer"): tokenizers.add(dialect.Tokenizer) if hasattr(dialect, "Parser"): @@ -810,6 +1062,7 @@ def extend_sqlglot() -> None: JinjaQuery: lambda self, e: f"{JINJA_QUERY_BEGIN};\n{e.name}\n{JINJA_END};", JinjaStatement: lambda self, e: f"{JINJA_STATEMENT_BEGIN};\n{e.name}\n{JINJA_END};", + VirtualUpdateStatement: lambda self, e: _on_virtual_update_sql(self, e), MacroDef: lambda self, e: f"@DEF({self.sql(e.this)}, {self.sql(e.expression)})", MacroFunc: _macro_func_sql, MacroStrReplace: lambda self, e: f"@{self.sql(e.this)}", @@ -820,15 +1073,23 @@ def extend_sqlglot() -> None: ModelKind: _model_kind_sql, PythonCode: lambda self, e: self.expressions(e, sep="\n", indent=False), StagedFilePath: lambda self, e: self.table_sql(e), + exp.Whens: _whens_sql, } ) - + if MacroDef not in generator.WITH_SEPARATED_COMMENTS: generator.WITH_SEPARATED_COMMENTS = ( *generator.WITH_SEPARATED_COMMENTS, Model, MacroDef, ) + generator.UNWRAPPED_INTERVAL_VALUES = ( + *generator.UNWRAPPED_INTERVAL_VALUES, + MacroStrReplace, + MacroVar, + ) + + _override(Parser, _parse_select) _override(Parser, _parse_statement) _override(Parser, _parse_join) _override(Parser, _parse_order) @@ -837,12 +1098,18 @@ def extend_sqlglot() -> None: _override(Parser, _parse_with) _override(Parser, _parse_having) _override(Parser, _parse_limit) + _override(Parser, _parse_value) 
_override(Parser, _parse_lambda) _override(Parser, _parse_types) + _override(TSQL.Parser, Parser._parse_if) _override(Parser, _parse_if) _override(Parser, _parse_id_var) + _override(Parser, _warn_unsupported) _override(Snowflake.Parser, _parse_table_parts) + # DuckDB's prefix absolute power operator `@` clashes with the macro syntax + DuckDB.Parser.NO_PAREN_FUNCTION_PARSERS.pop("@", None) + def select_from_values( values: t.List[t.Tuple[t.Any, ...]], @@ -867,7 +1134,7 @@ def select_from_values( for i in range(0, num_rows, batch_size): yield select_from_values_for_batch_range( values=values, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, batch_start=i, batch_end=min(i + batch_size, num_rows), alias=alias, @@ -876,35 +1143,49 @@ def select_from_values( def select_from_values_for_batch_range( values: t.List[t.Tuple[t.Any, ...]], - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_start: int, batch_end: int, alias: str = "t", + source_columns: t.Optional[t.List[str]] = None, ) -> exp.Select: - casted_columns = [ - exp.alias_(exp.cast(exp.column(column), to=kind), column, copy=False) - for column, kind in columns_to_types.items() - ] + source_columns = source_columns or list(target_columns_to_types) + source_columns_to_types = get_source_columns_to_types(target_columns_to_types, source_columns) if not values: # Ensures we don't generate an empty VALUES clause & forces a zero-row output where = exp.false() - expressions = [tuple(exp.cast(exp.null(), to=kind) for kind in columns_to_types.values())] + expressions = [ + tuple(exp.cast(exp.null(), to=kind) for kind in source_columns_to_types.values()) + ] else: where = None expressions = [ - tuple(transform_values(v, columns_to_types)) for v in values[batch_start:batch_end] + tuple(transform_values(v, source_columns_to_types)) + for v in values[batch_start:batch_end] ] - values_exp = exp.values(expressions, alias=alias, 
columns=columns_to_types) + values_exp = exp.values(expressions, alias=alias, columns=source_columns_to_types) if values: # BigQuery crashes on `SELECT CAST(x AS TIMESTAMP) FROM UNNEST([NULL]) AS x`, but not # on `SELECT CAST(x AS TIMESTAMP) FROM UNNEST([CAST(NULL AS TIMESTAMP)]) AS x`. This # ensures nulls under the `Values` expression are cast to avoid similar issues. - for value, kind in zip(values_exp.expressions[0].expressions, columns_to_types.values()): + for value, kind in zip( + values_exp.expressions[0].expressions, source_columns_to_types.values() + ): if isinstance(value, exp.Null): value.replace(exp.cast(value, to=kind)) + casted_columns = [ + exp.alias_( + exp.cast( + exp.column(column) if column in source_columns_to_types else exp.Null(), to=kind + ), + column, + copy=False, + ) + for column, kind in target_columns_to_types.items() + ] return exp.select(*casted_columns).from_(values_exp, copy=False).where(where, copy=False) @@ -945,6 +1226,7 @@ def set_default_catalog( return table +@lru_cache(maxsize=16384) def normalize_model_name( table: str | exp.Table | exp.Column, default_catalog: t.Optional[str], @@ -1009,19 +1291,52 @@ def transform_values( values: t.Tuple[t.Any, ...], columns_to_types: t.Dict[str, exp.DataType] ) -> t.Iterator[t.Any]: """Perform transformations on values given columns_to_types.""" - for value, col_type in zip(values, columns_to_types.values()): - if col_type.is_type(exp.DataType.Type.JSON): - yield exp.func("PARSE_JSON", f"'{value}'") - elif isinstance(value, dict) and col_type.is_type(*exp.DataType.STRUCT_TYPES): - yield _dict_to_struct(value) - else: - yield value + def _transform_value(value: t.Any, dtype: exp.DataType) -> t.Any: + if ( + isinstance(value, list) + and dtype.is_type(*exp.DataType.ARRAY_TYPES) + and len(dtype.expressions) == 1 + ): + element_type = dtype.expressions[0] + return exp.convert([_transform_value(v, element_type) for v in value]) -def to_schema(sql_path: str | exp.Table) -> exp.Table: + if ( + 
isinstance(value, dict) + and dtype.is_type(*exp.DataType.STRUCT_TYPES) + and len(value) == len(dtype.expressions) + ): + expressions = [] + for (field_name, field_value), field_type in zip(value.items(), dtype.expressions): + if isinstance(field_type, exp.ColumnDef): + field_type = field_type.kind + else: + field_type = exp.DataType.build(exp.DataType.Type.UNKNOWN) + + expressions.append( + exp.PropertyEQ( + this=exp.to_identifier(field_name), + expression=_transform_value(field_value, field_type), + ) + ) + + return exp.Struct(expressions=expressions) + + if dtype.is_type(exp.DataType.Type.JSON): + return exp.func("PARSE_JSON", f"'{value}'") + + return exp.convert(value) + + for col_value, col_type in zip(values, columns_to_types.values()): + yield _transform_value(col_value, col_type) + + +def to_schema(sql_path: str | exp.Table, dialect: DialectType = None) -> exp.Table: if isinstance(sql_path, exp.Table) and sql_path.this is None: return sql_path - table = exp.to_table(sql_path.copy() if isinstance(sql_path, exp.Table) else sql_path) + table = exp.to_table( + sql_path.copy() if isinstance(sql_path, exp.Table) else sql_path, dialect=dialect + ) table.set("catalog", table.args.get("db")) table.set("db", table.args.get("this")) table.set("this", None) @@ -1061,19 +1376,9 @@ def _unquote_schema(schema: t.Dict) -> t.Dict: } -def _dict_to_struct(values: t.Dict) -> exp.Struct: - expressions = [] - for key, value in values.items(): - key = exp.to_identifier(key) - value = _dict_to_struct(value) if isinstance(value, dict) else exp.convert(value) - expressions.append(exp.PropertyEQ(this=key, expression=value)) - - return exp.Struct(expressions=expressions) - - @contextmanager def normalize_and_quote( - query: E, dialect: str, default_catalog: t.Optional[str], quote: bool = True + query: E, dialect: DialectType, default_catalog: t.Optional[str], quote: bool = True ) -> t.Iterator[E]: qualify_tables(query, catalog=default_catalog, dialect=dialect) 
normalize_identifiers(query, dialect=dialect) @@ -1096,3 +1401,92 @@ def interpret_key_value_pairs( e: exp.Tuple, ) -> t.Dict[str, exp.Expression | str | int | float | bool]: return {i.this.name: interpret_expression(i.expression) for i in e.expressions} + + +def extract_func_call( + v: exp.Expression, allow_tuples: bool = False +) -> t.Tuple[str, t.Dict[str, exp.Expression]]: + kwargs = {} + + if isinstance(v, exp.Anonymous): + func = v.name + args = v.expressions + elif isinstance(v, exp.Func): + func = v.sql_name() + args = list(v.args.values()) + elif isinstance(v, exp.Paren): + func = "" + args = [v.this] + elif isinstance(v, exp.Tuple): # airflow only + if not allow_tuples: + raise ConfigError("Audit name is missing (eg. MY_AUDIT())") + + func = "" + args = v.expressions + else: + return v.name.lower(), {} + + for arg in args: + if not isinstance(arg, (exp.PropertyEQ, exp.EQ)): + raise ConfigError( + f"Function '{func}' must be called with key-value arguments like {func}(arg := value)." + ) + kwargs[arg.left.name.lower()] = arg.right + return func.lower(), kwargs + + +def extract_function_calls(func_calls: t.Any, allow_tuples: bool = False) -> t.Any: + """Used for extracting function calls for signals or audits.""" + + if isinstance(func_calls, (exp.Tuple, exp.Array)): + return [extract_func_call(i, allow_tuples=allow_tuples) for i in func_calls.expressions] + if isinstance(func_calls, exp.Paren): + return [extract_func_call(func_calls.this, allow_tuples=allow_tuples)] + if isinstance(func_calls, exp.Expression): + return [extract_func_call(func_calls, allow_tuples=allow_tuples)] + if isinstance(func_calls, list): + function_calls = [] + for entry in func_calls: + if isinstance(entry, dict): + args = entry + name = "" if allow_tuples else entry.pop("name") + elif isinstance(entry, (tuple, list)): + name, args = entry + else: + raise ConfigError(f"Audit must be a dictionary or named tuple. 
Got {entry}.") + + function_calls.append( + ( + name.lower(), + { + key: parse_one(value) if isinstance(value, str) else value + for key, value in args.items() + }, + ) + ) + + return function_calls + + return func_calls or [] + + +def is_meta_expression(v: t.Any) -> bool: + return isinstance(v, (Audit, Metric, Model)) + + +def replace_merge_table_aliases( + expression: exp.Expression, dialect: t.Optional[str] = None +) -> exp.Expression: + """ + Resolves references from the "source" and "target" tables (or their DBT equivalents) + with the corresponding SQLMesh merge aliases (MERGE_SOURCE_ALIAS and MERGE_TARGET_ALIAS) + """ + from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS + + if isinstance(expression, exp.Column) and (first_part := expression.parts[0]): + if first_part.this.lower() in ("target", "dbt_internal_dest", "__merge_target__"): + first_part.replace(exp.to_identifier(MERGE_TARGET_ALIAS, quoted=True)) + elif first_part.this.lower() in ("source", "dbt_internal_source", "__merge_source__"): + first_part.replace(exp.to_identifier(MERGE_SOURCE_ALIAS, quoted=True)) + + return expression diff --git a/sqlmesh/core/engine_adapter/__init__.py b/sqlmesh/core/engine_adapter/__init__.py index cec483f6f7..ab29885c7b 100644 --- a/sqlmesh/core/engine_adapter/__init__.py +++ b/sqlmesh/core/engine_adapter/__init__.py @@ -7,6 +7,7 @@ EngineAdapterWithIndexSupport, ) from sqlmesh.core.engine_adapter.bigquery import BigQueryEngineAdapter +from sqlmesh.core.engine_adapter.clickhouse import ClickhouseEngineAdapter from sqlmesh.core.engine_adapter.databricks import DatabricksEngineAdapter from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter from sqlmesh.core.engine_adapter.mssql import MSSQLEngineAdapter @@ -16,11 +17,15 @@ from sqlmesh.core.engine_adapter.snowflake import SnowflakeEngineAdapter from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter from sqlmesh.core.engine_adapter.trino import TrinoEngineAdapter +from 
sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter +from sqlmesh.core.engine_adapter.risingwave import RisingwaveEngineAdapter +from sqlmesh.core.engine_adapter.fabric import FabricEngineAdapter DIALECT_TO_ENGINE_ADAPTER = { "hive": SparkEngineAdapter, "spark": SparkEngineAdapter, "bigquery": BigQueryEngineAdapter, + "clickhouse": ClickhouseEngineAdapter, "duckdb": DuckDBEngineAdapter, "snowflake": SnowflakeEngineAdapter, "databricks": DatabricksEngineAdapter, @@ -29,6 +34,9 @@ "mysql": MySQLEngineAdapter, "mssql": MSSQLEngineAdapter, "trino": TrinoEngineAdapter, + "athena": AthenaEngineAdapter, + "risingwave": RisingwaveEngineAdapter, + "fabric": FabricEngineAdapter, } DIALECT_ALIASES = { diff --git a/sqlmesh/core/engine_adapter/_typing.py b/sqlmesh/core/engine_adapter/_typing.py index 1ce4268929..77bcf2c015 100644 --- a/sqlmesh/core/engine_adapter/_typing.py +++ b/sqlmesh/core/engine_adapter/_typing.py @@ -1,17 +1,19 @@ import typing as t -import pandas as pd from sqlglot import exp from sqlmesh.utils import optional_import if t.TYPE_CHECKING: + import pandas as pd import pyspark import pyspark.sql.connect.dataframe + from bigframes.session import Session as BigframeSession # noqa + from bigframes.dataframe import DataFrame as BigframeDataFrame snowpark = optional_import("snowflake.snowpark") - Query = t.Union[exp.Query, exp.DerivedTable] + Query = exp.Query PySparkSession = t.Union[pyspark.sql.SparkSession, pyspark.sql.connect.dataframe.SparkSession] PySparkDataFrame = t.Union[pyspark.sql.DataFrame, pyspark.sql.connect.dataframe.DataFrame] @@ -23,7 +25,10 @@ pd.DataFrame, pyspark.sql.DataFrame, pyspark.sql.connect.dataframe.DataFrame, + BigframeDataFrame, SnowparkDataFrame, ] QueryOrDF = t.Union[Query, DF] + GrantsConfig = t.Dict[str, t.List[str]] + DCL = t.TypeVar("DCL", exp.Grant, exp.Revoke) diff --git a/sqlmesh/core/engine_adapter/athena.py b/sqlmesh/core/engine_adapter/athena.py new file mode 100644 index 0000000000..bd84ba5276 --- /dev/null +++ 
b/sqlmesh/core/engine_adapter/athena.py @@ -0,0 +1,622 @@ +from __future__ import annotations +from functools import lru_cache +import typing as t +import logging +from sqlglot import exp +from sqlmesh.core.dialect import to_schema +from sqlmesh.utils.aws import validate_s3_uri, parse_s3_uri +from sqlmesh.core.engine_adapter.mixins import PandasNativeFetchDFSupportMixin, RowDiffMixin +from sqlmesh.core.engine_adapter.trino import TrinoEngineAdapter +from sqlmesh.core.node import IntervalUnit +import posixpath +from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.core.engine_adapter.shared import ( + CatalogSupport, + DataObject, + DataObjectType, + CommentCreationTable, + CommentCreationView, + SourceQuery, + InsertOverwriteStrategy, +) + +if t.TYPE_CHECKING: + from sqlmesh.core._typing import SchemaName, TableName + from sqlmesh.core.engine_adapter._typing import QueryOrDF + + TableType = t.Union[t.Literal["hive"], t.Literal["iceberg"]] + +logger = logging.getLogger(__name__) + + +class AthenaEngineAdapter(PandasNativeFetchDFSupportMixin, RowDiffMixin): + DIALECT = "athena" + SUPPORTS_TRANSACTIONS = False + SUPPORTS_REPLACE_TABLE = False + # Athena's support for table and column comments is too patchy to consider "supported" + # Hive tables: Table + Column comments are supported + # Iceberg tables: Column comments only + # CTAS, Views: No comment support at all + COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED + COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED + SCHEMA_DIFFER_KWARGS = TrinoEngineAdapter.SCHEMA_DIFFER_KWARGS + MAX_TIMESTAMP_PRECISION = 3 # copied from Trino + # Athena does not deal with comments well, e.g: + # >>> self._execute('/* test */ DESCRIBE foo') + # pyathena.error.OperationalError: FAILED: ParseException line 1:0 cannot recognize input near '/' '*' 'test' + ATTACH_CORRELATION_ID = False + SUPPORTS_QUERY_EXECUTION_TRACKING = True + SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA"] + + def __init__( + self, 
*args: t.Any, s3_warehouse_location: t.Optional[str] = None, **kwargs: t.Any + ): + # Need to pass s3_warehouse_location to the superclass so that it goes into _extra_config + # which means that EngineAdapter.with_settings() keeps this property when it makes a clone + super().__init__(*args, s3_warehouse_location=s3_warehouse_location, **kwargs) + self.s3_warehouse_location = s3_warehouse_location + + self._default_catalog = self._default_catalog or "awsdatacatalog" + + @property + def s3_warehouse_location(self) -> t.Optional[str]: + return self._s3_warehouse_location + + @s3_warehouse_location.setter + def s3_warehouse_location(self, value: t.Optional[str]) -> None: + if value: + value = validate_s3_uri(value, base=True) + self._s3_warehouse_location = value + + @property + def s3_warehouse_location_or_raise(self) -> str: + # this makes tests easier to write without extra null checks to keep mypy happy + if location := self.s3_warehouse_location: + return location + + raise SQLMeshError("s3_warehouse_location was expected to be populated; it isnt") + + @property + def catalog_support(self) -> CatalogSupport: + # Athena has the concept of catalogs but the current catalog is set in the connection parameters with no way to query or change it after that + # It also cant create new catalogs, you have to configure them in AWS. 
Typically, catalogs that are not "awsdatacatalog" + # are pointers to the "awsdatacatalog" of other AWS accounts + return CatalogSupport.SINGLE_CATALOG_ONLY + + def create_state_table( + self, + table_name: str, + target_columns_to_types: t.Dict[str, exp.DataType], + primary_key: t.Optional[t.Tuple[str, ...]] = None, + ) -> None: + self.create_table( + table_name, + target_columns_to_types, + primary_key=primary_key, + # it's painfully slow, but it works + table_format="iceberg", + ) + + def _get_data_objects( + self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None + ) -> t.List[DataObject]: + """ + Returns all the data objects that exist in the given schema and optionally catalog. + """ + schema_name = to_schema(schema_name) + schema = schema_name.db + query = ( + exp.select( + exp.column("table_catalog").as_("catalog"), + exp.column("table_schema", table="t").as_("schema"), + exp.column("table_name", table="t").as_("name"), + exp.case() + .when( + exp.column("table_type", table="t").eq("BASE TABLE"), + exp.Literal.string("table"), + ) + .else_(exp.column("table_type", table="t")) + .as_("type"), + ) + .from_(exp.to_table("information_schema.tables", alias="t")) + .where(exp.column("table_schema", table="t").eq(schema)) + ) + if object_names: + query = query.where(exp.column("table_name", table="t").isin(*object_names)) + + df = self.fetchdf(query) + + return [ + DataObject( + catalog=row.catalog, # type: ignore + schema=row.schema, # type: ignore + name=row.name, # type: ignore + type=DataObjectType.from_str(row.type), # type: ignore + ) + for row in df.itertuples() + ] + + def columns( + self, table_name: TableName, include_pseudo_columns: bool = False + ) -> t.Dict[str, exp.DataType]: + table = exp.to_table(table_name) + # note: the data_type column contains the full parameterized type, eg 'varchar(10)' + query = ( + exp.select("column_name", "data_type") + .from_("information_schema.columns") + 
.where(exp.column("table_schema").eq(table.db), exp.column("table_name").eq(table.name)) + .order_by("ordinal_position") + ) + result = self.fetchdf(query, quote_identifiers=True) + return { + str(r.column_name): exp.DataType.build(str(r.data_type)) + for r in result.itertuples(index=False) + } + + def _create_schema( + self, + schema_name: SchemaName, + ignore_if_exists: bool, + warn_on_error: bool, + properties: t.List[exp.Expression], + kind: str, + ) -> None: + if location := self._table_location(table_properties=None, table=exp.to_table(schema_name)): + # don't add extra LocationProperty's if one already exists + if not any(p for p in properties if isinstance(p, exp.LocationProperty)): + properties.append(location) + + return super()._create_schema( + schema_name=schema_name, + ignore_if_exists=ignore_if_exists, + warn_on_error=warn_on_error, + properties=properties, + kind=kind, + ) + + def _build_create_table_exp( + self, + table_name_or_schema: t.Union[exp.Schema, TableName], + expression: t.Optional[exp.Expression], + exists: bool = True, + replace: bool = False, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + table_description: t.Optional[str] = None, + table_kind: t.Optional[str] = None, + partitioned_by: t.Optional[t.List[exp.Expression]] = None, + table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + **kwargs: t.Any, + ) -> exp.Create: + exists = False if replace else exists + + table: exp.Table + if isinstance(table_name_or_schema, str): + table = exp.to_table(table_name_or_schema) + elif isinstance(table_name_or_schema, exp.Schema): + table = table_name_or_schema.this + else: + table = table_name_or_schema + + properties = self._build_table_properties_exp( + table=table, + expression=expression, + target_columns_to_types=target_columns_to_types, + partitioned_by=partitioned_by, + table_properties=table_properties, + table_description=table_description, + table_kind=table_kind, + **kwargs, + ) + + is_hive = 
self._table_type(kwargs.get("table_format", None)) == "hive" + + # Filter any PARTITIONED BY properties from the main column list since they cant be specified in both places + # ref: https://docs.aws.amazon.com/athena/latest/ug/partitions.html + if is_hive and partitioned_by and isinstance(table_name_or_schema, exp.Schema): + partitioned_by_column_names = {e.name for e in partitioned_by} + filtered_expressions = [ + e + for e in table_name_or_schema.expressions + if isinstance(e, exp.ColumnDef) and e.this.name not in partitioned_by_column_names + ] + table_name_or_schema.args["expressions"] = filtered_expressions + + return exp.Create( + this=table_name_or_schema, + kind=table_kind or "TABLE", + replace=replace, + exists=exists, + expression=expression, + properties=properties, + ) + + def _build_table_properties_exp( + self, + catalog_name: t.Optional[str] = None, + table_format: t.Optional[str] = None, + storage_format: t.Optional[str] = None, + partitioned_by: t.Optional[t.List[exp.Expression]] = None, + partition_interval_unit: t.Optional[IntervalUnit] = None, + clustered_by: t.Optional[t.List[exp.Expression]] = None, + table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + table_description: t.Optional[str] = None, + table_kind: t.Optional[str] = None, + table: t.Optional[exp.Table] = None, + expression: t.Optional[exp.Expression] = None, + **kwargs: t.Any, + ) -> t.Optional[exp.Properties]: + properties: t.List[exp.Expression] = [] + table_properties = table_properties or {} + + is_hive = self._table_type(table_format) == "hive" + is_iceberg = not is_hive + + if is_hive and not expression: + # Hive tables are CREATE EXTERNAL TABLE, Iceberg tables are CREATE TABLE + # Unless it's a CTAS, those are always CREATE TABLE + properties.append(exp.ExternalProperty()) + + if table_format: + properties.append( + exp.Property(this=exp.var("table_type"), 
value=exp.Literal.string(table_format)) + ) + + if table_description: + properties.append(exp.SchemaCommentProperty(this=exp.Literal.string(table_description))) + + if partitioned_by: + schema_expressions: t.List[exp.Expression] = [] + if is_hive and target_columns_to_types: + # For Hive-style tables, you cannot include the partitioned by columns in the main set of columns + # In the PARTITIONED BY expression, you also cant just include the column names, you need to include the data type as well + # ref: https://docs.aws.amazon.com/athena/latest/ug/partitions.html + for match_name, match_dtype in self._find_matching_columns( + partitioned_by, target_columns_to_types + ): + column_def = exp.ColumnDef(this=exp.to_identifier(match_name), kind=match_dtype) + schema_expressions.append(column_def) + else: + schema_expressions = partitioned_by + + properties.append( + exp.PartitionedByProperty(this=exp.Schema(expressions=schema_expressions)) + ) + + if clustered_by: + # Athena itself supports CLUSTERED BY, via the syntax CLUSTERED BY (col) INTO BUCKETS + # However, SQLMesh is more closely aligned with BigQuery's notion of clustering and + # defines `clustered_by` as a List[str] with no way of indicating the number of buckets + # + # Athena's concept of CLUSTER BY is more like Iceberg's `bucket(, col)` partition transform + logging.warning("clustered_by is not supported in the Athena adapter at this time") + + if storage_format: + if is_iceberg: + # TBLPROPERTIES('format'='parquet') + table_properties["format"] = exp.Literal.string(storage_format) + else: + # STORED AS PARQUET + properties.append(exp.FileFormatProperty(this=storage_format)) + + if table and (location := self._table_location_or_raise(table_properties, table)): + properties.append(location) + + if is_iceberg and expression: + # To make a CTAS expression persist as iceberg, alongside setting `table_type=iceberg`, you also need to set is_external=false + # Note that SQLGlot does the right thing with 
LocationProperty and writes it as `location` (Iceberg) instead of `external_location` (Hive) + # ref: https://docs.aws.amazon.com/athena/latest/ug/create-table-as.html#ctas-table-properties + properties.append(exp.Property(this=exp.var("is_external"), value="false")) + + for name, value in table_properties.items(): + properties.append(exp.Property(this=exp.var(name), value=value)) + + if properties: + return exp.Properties(expressions=properties) + + return None + + def drop_table(self, table_name: TableName, exists: bool = True, **kwargs: t.Any) -> None: + table = exp.to_table(table_name) + + if self._query_table_type(table) == "hive": + self._truncate_table(table) + + return super().drop_table(table_name=table, exists=exists, **kwargs) + + def _truncate_table(self, table_name: TableName) -> None: + table = exp.to_table(table_name) + + # Truncating an Iceberg table is just DELETE FROM
+ if self._query_table_type(table) == "iceberg": + return self.delete_from(table, exp.true()) + + # Truncating a partitioned Hive table is dropping all partitions and deleting the data from S3 + if self._is_hive_partitioned_table(table): + self._clear_partition_data(table, exp.true()) + elif s3_location := self._query_table_s3_location(table): + # Truncating a non-partitioned Hive table is clearing out all data in its Location + self._clear_s3_location(s3_location) + + def _table_type(self, table_format: t.Optional[str] = None) -> TableType: + """ + Interpret the "table_format" property to check if this is a Hive or an Iceberg table + """ + if table_format and table_format.lower() == "iceberg": + return "iceberg" + + # if we cant detect any indication of Iceberg, this is a Hive table + return "hive" + + def _query_table_type(self, table: exp.Table) -> t.Optional[TableType]: + if self.table_exists(table): + return self._query_table_type_or_raise(table) + return None + + @lru_cache() + def _query_table_type_or_raise(self, table: exp.Table) -> TableType: + """ + Hit the DB to check if this is a Hive or an Iceberg table. 
+ + Note that in order to @lru_cache() this method, we have the following assumptions: + - The table must exist (otherwise we would cache None if this method was called before table creation and always return None after creation) + - The table type will not change within the same SQLMesh session + """ + # Note: SHOW TBLPROPERTIES gets parsed by SQLGlot as an exp.Command anyway so we just use a string here + # This also means we need to use dialect="hive" instead of dialect="athena" so that the identifiers get the correct quoting (backticks) + for row in self.fetchall(f"SHOW TBLPROPERTIES {table.sql(dialect='hive', identify=True)}"): + # This query returns a single column with values like 'EXTERNAL\tTRUE' + row_lower = row[0].lower() + if "external" in row_lower and "true" in row_lower: + return "hive" + return "iceberg" + + def _is_hive_partitioned_table(self, table: exp.Table) -> bool: + try: + self._list_partitions(table=table, where=None, limit=1) + return True + except Exception as e: + if "TABLE_NOT_FOUND" in str(e): + return False + raise e + + def _table_location_or_raise( + self, table_properties: t.Optional[t.Dict[str, exp.Expression]], table: exp.Table + ) -> exp.LocationProperty: + location = self._table_location(table_properties, table) + if not location: + raise SQLMeshError( + f"Cannot figure out location for table {table}. 
Please either set `s3_base_location` in `physical_properties` or set `s3_warehouse_location` in the Athena connection config" + ) + return location + + def _table_location( + self, + table_properties: t.Optional[t.Dict[str, exp.Expression]], + table: exp.Table, + ) -> t.Optional[exp.LocationProperty]: + base_uri: str + + # If the user has manually specified a `s3_base_location`, use it + if table_properties and "s3_base_location" in table_properties: + s3_base_location_property = table_properties.pop( + "s3_base_location" + ) # pop because it's handled differently and we dont want it to end up in the TBLPROPERTIES clause + if isinstance(s3_base_location_property, exp.Expression): + base_uri = s3_base_location_property.name + else: + base_uri = s3_base_location_property + + elif self.s3_warehouse_location: + # If the user has set `s3_warehouse_location` in the connection config, the base URI is /// + base_uri = posixpath.join( + self.s3_warehouse_location, table.catalog or "", table.db or "" + ) + else: + return None + + full_uri = validate_s3_uri(posixpath.join(base_uri, table.text("this") or ""), base=True) + return exp.LocationProperty(this=exp.Literal.string(full_uri)) + + def _find_matching_columns( + self, partitioned_by: t.List[exp.Expression], columns_to_types: t.Dict[str, exp.DataType] + ) -> t.List[t.Tuple[str, exp.DataType]]: + matches = [] + for col in partitioned_by: + # TODO: do we care about normalization? 
+ key = col.name + if isinstance(col, exp.Column) and (match_dtype := columns_to_types.get(key)): + matches.append((key, match_dtype)) + return matches + + def replace_query( + self, + table_name: TableName, + query_or_df: QueryOrDF, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + table_description: t.Optional[str] = None, + column_descriptions: t.Optional[t.Dict[str, str]] = None, + source_columns: t.Optional[t.List[str]] = None, + supports_replace_table_override: t.Optional[bool] = None, + **kwargs: t.Any, + ) -> None: + table = exp.to_table(table_name) + + if self._query_table_type(table=table) == "hive": + self.drop_table(table) + + return super().replace_query( + table_name=table, + query_or_df=query_or_df, + target_columns_to_types=target_columns_to_types, + table_description=table_description, + column_descriptions=column_descriptions, + source_columns=source_columns, + **kwargs, + ) + + def _insert_overwrite_by_time_partition( + self, + table_name: TableName, + source_queries: t.List[SourceQuery], + target_columns_to_types: t.Dict[str, exp.DataType], + where: exp.Condition, + **kwargs: t.Any, + ) -> None: + table = exp.to_table(table_name) + + table_type = self._query_table_type(table) + + if table_type == "iceberg": + # Iceberg tables work as expected, we can use the default behaviour + return super()._insert_overwrite_by_time_partition( + table, source_queries, target_columns_to_types, where, **kwargs + ) + + # For Hive tables, we need to drop all the partitions covered by the query and delete the data from S3 + self._clear_partition_data(table, where) + + # Now the data is physically gone, we can continue with inserting a new partition + return super()._insert_overwrite_by_time_partition( + table, + source_queries, + target_columns_to_types, + where, + insert_overwrite_strategy_override=InsertOverwriteStrategy.INTO_IS_OVERWRITE, # since we already cleared the data + **kwargs, + ) + + def _clear_partition_data(self, table: 
exp.Table, where: t.Optional[exp.Condition]) -> None: + if partitions_to_drop := self._list_partitions(table, where): + for _, s3_location in partitions_to_drop: + logger.debug( + f"Clearing S3 location for '{table.sql(dialect=self.dialect)}': {s3_location}" + ) + self._clear_s3_location(s3_location) + + partition_values = [k for k, _ in partitions_to_drop] + logger.debug( + f"Dropping partitions for '{table.sql(dialect=self.dialect)}' from metastore: {partition_values}" + ) + self._drop_partitions_from_metastore(table, partition_values) + + def _list_partitions( + self, + table: exp.Table, + where: t.Optional[exp.Condition] = None, + limit: t.Optional[int] = None, + ) -> t.List[t.Tuple[t.List[str], str]]: + # Use Athena's magic "$partitions" metadata table to identify the partitions to drop + # Doing it this way allows us to use SQL to filter the partition list + partition_table_name = table.copy() + partition_table_name.this.replace( + exp.to_identifier(f"{table.name}$partitions", quoted=True) + ) + + query = exp.select("*").from_(partition_table_name).where(where) + if limit: + query = query.limit(limit) + + partition_values = [list(r) for r in self.fetchall(query, quote_identifiers=True)] + + if partition_values: + response = self._glue_client.batch_get_partition( + DatabaseName=table.db, + TableName=table.name, + PartitionsToGet=[{"Values": [str(v) for v in lst]} for lst in partition_values], + ) + return sorted( + [(p["Values"], p["StorageDescriptor"]["Location"]) for p in response["Partitions"]] + ) + + return [] + + def _query_table_s3_location(self, table: exp.Table) -> str: + response = self._glue_client.get_table(DatabaseName=table.db, Name=table.name) + + # Athena wont let you create a table without a location, so *theoretically* this should never be empty + if location := response.get("Table", {}).get("StorageDescriptor", {}).get("Location", None): + return location + + raise SQLMeshError(f"Table {table} has no location set in the metastore!") + + def 
_drop_partitions_from_metastore( + self, table: exp.Table, partition_values: t.List[t.List[str]] + ) -> None: + # todo: switch to itertools.batched when our minimum supported Python is 3.12 + # 25 = maximum number of partitions that batch_delete_partition can process at once + # ref: https://docs.aws.amazon.com/glue/latest/webapi/API_BatchDeletePartition.html#API_BatchDeletePartition_RequestParameters + def _chunks() -> t.Iterable[t.List[t.List[str]]]: + for i in range(0, len(partition_values), 25): + yield partition_values[i : i + 25] + + for batch in _chunks(): + self._glue_client.batch_delete_partition( + DatabaseName=table.db, + TableName=table.name, + PartitionsToDelete=[{"Values": v} for v in batch], + ) + + def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None: + table = exp.to_table(table_name) + + table_type = self._query_table_type(table) + + # If Iceberg, DELETE operations work as expected + if table_type == "iceberg": + return super().delete_from(table, where) + + # If Hive, DELETE is an error + if table_type == "hive": + # However, if there are no actual records to delete, we can make DELETE a no-op + # This simplifies a bunch of calling code that just assumes DELETE works (which to be fair is a reasonable assumption since it does for every other engine) + empty_check = ( + exp.select("*").from_(table).where(where).limit(1) + ) # deliberately not count(*) because we want the engine to stop as soon as it finds a record + if len(self.fetchall(empty_check)) > 0: + raise SQLMeshError("Cannot delete individual records from a Hive table") + + return None + + def _clear_s3_location(self, s3_uri: str) -> None: + s3 = self._s3_client + + bucket, key = parse_s3_uri(s3_uri) + if not key.endswith("/"): + key = f"{key}/" + + keys_to_delete = [] + + # note: uses Delimiter=/ to prevent stepping into folders + # the assumption is that all the files in a partition live directly at the partition `Location` + for page in 
s3.get_paginator("list_objects_v2").paginate( + Bucket=bucket, Prefix=key, Delimiter="/" + ): + # list_objects_v2() returns 1000 keys per page so that lines up nicely with delete_objects() being able to delete 1000 keys at a time + keys = [item["Key"] for item in page.get("Contents", [])] + if keys: + keys_to_delete.append(keys) + + for chunk in keys_to_delete: + s3.delete_objects(Bucket=bucket, Delete={"Objects": [{"Key": k} for k in chunk]}) + + @property + def _glue_client(self) -> t.Any: + return self._boto3_client("glue") + + @property + def _s3_client(self) -> t.Any: + return self._boto3_client("s3") + + def _boto3_client(self, name: str) -> t.Any: + # use the client factory from PyAthena which is already configured with the correct AWS details + conn = self.connection + return conn.session.client( + name, + region_name=conn.region_name, + config=conn.config, + **conn._client_kwargs, + ) # type: ignore + + def get_current_catalog(self) -> t.Optional[str]: + return self.connection.catalog_name diff --git a/sqlmesh/core/engine_adapter/base.py b/sqlmesh/core/engine_adapter/base.py index 2077c3bc8a..e2dbb51459 100644 --- a/sqlmesh/core/engine_adapter/base.py +++ b/sqlmesh/core/engine_adapter/base.py @@ -14,12 +14,11 @@ import logging import sys import typing as t -from functools import partial +from functools import cached_property, partial -import pandas as pd from sqlglot import Dialect, exp from sqlglot.errors import ErrorLevel -from sqlglot.helper import ensure_list +from sqlglot.helper import ensure_list, seq_get from sqlglot.optimizer.qualify_columns import quote_identifiers from sqlmesh.core.dialect import ( @@ -33,22 +32,38 @@ CommentCreationTable, CommentCreationView, DataObject, + DataObjectType, + EngineRunMode, InsertOverwriteStrategy, SourceQuery, set_catalog, ) from sqlmesh.core.model.kind import TimeColumn -from sqlmesh.core.schema_diff import SchemaDiffer -from sqlmesh.utils import columns_to_types_all_known, random_id -from 
sqlmesh.utils.connection_pool import create_connection_pool +from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation +from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker +from sqlmesh.utils import ( + CorrelationId, + columns_to_types_all_known, + random_id, + get_source_columns_to_types, +) +from sqlmesh.utils.connection_pool import ConnectionPool, create_connection_pool from sqlmesh.utils.date import TimeLike, make_inclusive, to_time_column -from sqlmesh.utils.errors import SQLMeshError, UnsupportedCatalogOperationError +from sqlmesh.utils.errors import ( + MissingDefaultCatalogError, + SQLMeshError, + UnsupportedCatalogOperationError, +) from sqlmesh.utils.pandas import columns_to_types_from_df if t.TYPE_CHECKING: + import pandas as pd + from sqlmesh.core._typing import SchemaName, SessionProperties, TableName from sqlmesh.core.engine_adapter._typing import ( DF, + BigframeSession, + GrantsConfig, PySparkDataFrame, PySparkSession, Query, @@ -62,6 +77,8 @@ MERGE_TARGET_ALIAS = "__MERGE_TARGET__" MERGE_SOURCE_ALIAS = "__MERGE_SOURCE__" +KEY_FOR_CREATABLE_TYPE = "CREATABLE_TYPE" + @set_catalog() class EngineAdapter: @@ -71,7 +88,7 @@ class EngineAdapter: with the underlying engine and data store. Args: - connection_factory: a callable which produces a new Database API-compliant + connection_factory_or_pool: a callable which produces a new Database API-compliant connection on every call. dialect: The dialect with which this adapter is associated. multithreaded: Indicates whether this adapter will be used by more than one thread. 
@@ -89,34 +106,51 @@ class EngineAdapter: INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.DELETE_INSERT SUPPORTS_MATERIALIZED_VIEWS = False SUPPORTS_MATERIALIZED_VIEW_SCHEMA = False + SUPPORTS_VIEW_SCHEMA = True SUPPORTS_CLONING = False SUPPORTS_MANAGED_MODELS = False - SCHEMA_DIFFER = SchemaDiffer() + SUPPORTS_CREATE_DROP_CATALOG = False + SUPPORTED_DROP_CASCADE_OBJECT_KINDS: t.List[str] = [] + SCHEMA_DIFFER_KWARGS: t.Dict[str, t.Any] = {} SUPPORTS_TUPLE_IN = True - CATALOG_SUPPORT = CatalogSupport.UNSUPPORTED - SUPPORTS_ROW_LEVEL_OP = True HAS_VIEW_BINDING = False SUPPORTS_REPLACE_TABLE = True + SUPPORTS_GRANTS = False DEFAULT_CATALOG_TYPE = DIALECT QUOTE_IDENTIFIERS_IN_VIEWS = True + MAX_IDENTIFIER_LENGTH: t.Optional[int] = None + ATTACH_CORRELATION_ID = True + SUPPORTS_QUERY_EXECUTION_TRACKING = False + SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS = False def __init__( self, - connection_factory: t.Callable[[], t.Any], + connection_factory_or_pool: t.Union[t.Callable[[], t.Any], ConnectionPool], dialect: str = "", sql_gen_kwargs: t.Optional[t.Dict[str, Dialect | bool | str]] = None, multithreaded: bool = False, - cursor_kwargs: t.Optional[t.Dict[str, t.Any]] = None, cursor_init: t.Optional[t.Callable[[t.Any], None]] = None, default_catalog: t.Optional[str] = None, execute_log_level: int = logging.DEBUG, register_comments: bool = True, pre_ping: bool = False, + pretty_sql: bool = False, + shared_connection: bool = False, + correlation_id: t.Optional[CorrelationId] = None, + schema_differ_overrides: t.Optional[t.Dict[str, t.Any]] = None, + query_execution_tracker: t.Optional[QueryExecutionTracker] = None, **kwargs: t.Any, ): self.dialect = dialect.lower() or self.DIALECT - self._connection_pool = create_connection_pool( - connection_factory, multithreaded, cursor_kwargs=cursor_kwargs, cursor_init=cursor_init + self._connection_pool = ( + connection_factory_or_pool + if isinstance(connection_factory_or_pool, ConnectionPool) + else create_connection_pool( + 
connection_factory_or_pool, + multithreaded, + shared_connection=shared_connection, + cursor_init=cursor_init, + ) ) self._sql_gen_kwargs = sql_gen_kwargs or {} self._default_catalog = default_catalog @@ -124,26 +158,46 @@ def __init__( self._extra_config = kwargs self._register_comments = register_comments self._pre_ping = pre_ping + self._pretty_sql = pretty_sql + self._multithreaded = multithreaded + self.correlation_id = correlation_id + self._schema_differ_overrides = schema_differ_overrides + self._query_execution_tracker = query_execution_tracker + self._data_object_cache: t.Dict[str, t.Optional[DataObject]] = {} + + def with_settings(self, **kwargs: t.Any) -> EngineAdapter: + extra_kwargs = { + "null_connection": True, + "execute_log_level": kwargs.pop("execute_log_level", self._execute_log_level), + "correlation_id": kwargs.pop("correlation_id", self.correlation_id), + "query_execution_tracker": kwargs.pop( + "query_execution_tracker", self._query_execution_tracker + ), + **self._extra_config, + **kwargs, + } - def with_log_level(self, level: int) -> EngineAdapter: adapter = self.__class__( - lambda: None, + self._connection_pool, dialect=self.dialect, sql_gen_kwargs=self._sql_gen_kwargs, default_catalog=self._default_catalog, - execute_log_level=level, register_comments=self._register_comments, - **self._extra_config, + multithreaded=self._multithreaded, + pretty_sql=self._pretty_sql, + **extra_kwargs, ) - adapter._connection_pool = self._connection_pool - return adapter @property def cursor(self) -> t.Any: return self._connection_pool.get_cursor() + @property + def connection(self) -> t.Any: + return self._connection_pool.get() + @property def spark(self) -> t.Optional[PySparkSession]: return None @@ -152,39 +206,103 @@ def spark(self) -> t.Optional[PySparkSession]: def snowpark(self) -> t.Optional[SnowparkSession]: return None + @property + def bigframe(self) -> t.Optional[BigframeSession]: + return None + @property def comments_enabled(self) -> bool: 
return self._register_comments and self.COMMENT_CREATION_TABLE.is_supported + @property + def catalog_support(self) -> CatalogSupport: + return CatalogSupport.UNSUPPORTED + + @cached_property + def schema_differ(self) -> SchemaDiffer: + return SchemaDiffer( + **{ + **self.SCHEMA_DIFFER_KWARGS, + **(self._schema_differ_overrides or {}), + } + ) + + @property + def _catalog_type_overrides(self) -> t.Dict[str, str]: + return self._extra_config.get("catalog_type_overrides") or {} + @classmethod - def _casted_columns(cls, columns_to_types: t.Dict[str, exp.DataType]) -> t.List[exp.Alias]: + def _casted_columns( + cls, + target_columns_to_types: t.Dict[str, exp.DataType], + source_columns: t.Optional[t.List[str]] = None, + ) -> t.List[exp.Alias]: + source_columns_lookup = set(source_columns or target_columns_to_types) return [ - exp.alias_(exp.cast(exp.column(column), to=kind), column, copy=False) - for column, kind in columns_to_types.items() + exp.alias_( + exp.cast( + exp.column(column, quoted=True) + if column in source_columns_lookup + else exp.Null(), + to=kind, + ), + column, + copy=False, + quoted=True, + ) + for column, kind in target_columns_to_types.items() ] @property def default_catalog(self) -> t.Optional[str]: - if self.CATALOG_SUPPORT.is_unsupported: + if self.catalog_support.is_unsupported: return None default_catalog = self._default_catalog or self.get_current_catalog() if not default_catalog: - raise SQLMeshError("Could not determine a default catalog despite it being supported.") + raise MissingDefaultCatalogError( + "Could not determine a default catalog despite it being supported." 
+ ) return default_catalog + @property + def engine_run_mode(self) -> EngineRunMode: + return EngineRunMode.SINGLE_MODE_ENGINE + def _get_source_queries( self, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], target_table: TableName, *, batch_size: t.Optional[int] = None, + source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: + import pandas as pd + batch_size = self.DEFAULT_BATCH_SIZE if batch_size is None else batch_size - if isinstance(query_or_df, (exp.Query, exp.DerivedTable)): - return [SourceQuery(query_factory=lambda: query_or_df)] # type: ignore + if isinstance(query_or_df, exp.Query): + query_factory = lambda: query_or_df + if source_columns: + source_columns_lookup = set(source_columns) + if not target_columns_to_types: + raise SQLMeshError("columns_to_types must be set if source_columns is set") + if not set(target_columns_to_types).issubset(source_columns_lookup): + select_columns = [ + exp.column(c, quoted=True) + if c in source_columns_lookup + else exp.cast(exp.Null(), target_columns_to_types[c], copy=False).as_( + c, copy=False, quoted=True + ) + for c in target_columns_to_types + ] + query_factory = ( + lambda: exp.Select() + .select(*select_columns) + .from_(query_or_df.subquery("select_source_columns")) + ) + return [SourceQuery(query_factory=query_factory)] # type: ignore - if not columns_to_types: + if not target_columns_to_types: raise SQLMeshError( "It is expected that if a DataFrame is passed in then columns_to_types is set" ) @@ -197,28 +315,41 @@ def _get_source_queries( ) return self._df_to_source_queries( - query_or_df, columns_to_types, batch_size, target_table=target_table + query_or_df, + target_columns_to_types, + batch_size, + target_table=target_table, + source_columns=source_columns, ) def _df_to_source_queries( self, df: DF, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, 
exp.DataType], batch_size: int, target_table: TableName, + source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: + import pandas as pd + assert isinstance(df, pd.DataFrame) num_rows = len(df.index) batch_size = sys.maxsize if batch_size == 0 else batch_size + + # we need to ensure that the order of the columns in columns_to_types columns matches the order of the values + # they can differ if a user specifies columns() on a python model in a different order than what's in the DataFrame's emitted by that model + df = df[list(source_columns or target_columns_to_types)] values = list(df.itertuples(index=False, name=None)) + return [ SourceQuery( query_factory=partial( self._values_to_sql, - values=values, - columns_to_types=columns_to_types, + values=values, # type: ignore + target_columns_to_types=target_columns_to_types, batch_start=i, batch_end=min(i + batch_size, num_rows), + source_columns=source_columns, ), ) for i in range(0, num_rows, batch_size) @@ -227,37 +358,60 @@ def _df_to_source_queries( def _get_source_queries_and_columns_to_types( self, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], target_table: TableName, *, batch_size: t.Optional[int] = None, + source_columns: t.Optional[t.List[str]] = None, ) -> t.Tuple[t.List[SourceQuery], t.Optional[t.Dict[str, exp.DataType]]]: - columns_to_types = self._columns_to_types(query_or_df, columns_to_types) - return ( - self._get_source_queries( - query_or_df, columns_to_types, target_table=target_table, batch_size=batch_size - ), - columns_to_types, + target_columns_to_types, source_columns = self._columns_to_types( + query_or_df, target_columns_to_types, source_columns ) + source_queries = self._get_source_queries( + query_or_df, + target_columns_to_types, + target_table=target_table, + batch_size=batch_size, + source_columns=source_columns, + ) + return source_queries, target_columns_to_types 
@t.overload def _columns_to_types( - self, query_or_df: DF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Dict[str, exp.DataType]: ... + self, + query_or_df: DF, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ... @t.overload def _columns_to_types( - self, query_or_df: Query, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Optional[t.Dict[str, exp.DataType]]: ... + self, + query_or_df: Query, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ... def _columns_to_types( - self, query_or_df: QueryOrDF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Optional[t.Dict[str, exp.DataType]]: - if columns_to_types: - return columns_to_types - if isinstance(query_or_df, pd.DataFrame): - return columns_to_types_from_df(t.cast(pd.DataFrame, query_or_df)) - return columns_to_types + self, + query_or_df: QueryOrDF, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: + import pandas as pd + + if not target_columns_to_types and isinstance(query_or_df, pd.DataFrame): + target_columns_to_types = columns_to_types_from_df(t.cast(pd.DataFrame, query_or_df)) + if not source_columns and target_columns_to_types: + source_columns = list(target_columns_to_types) + # source columns should only contain columns that are defined in the target. 
If there are extras then + # that means they are intended to be ignored and will be excluded + source_columns = ( + [x for x in source_columns if x in target_columns_to_types] + if source_columns and target_columns_to_types + else None + ) + return target_columns_to_types, source_columns def recycle(self) -> None: """Closes all open connections and releases all allocated resources associated with any thread @@ -280,23 +434,37 @@ def get_catalog_type(self, catalog: t.Optional[str]) -> str: """Intended to be overridden for data virtualization systems like Trino that, depending on the target catalog, require slightly different properties to be set when creating / updating tables """ - if self.CATALOG_SUPPORT.is_unsupported: + if self.catalog_support.is_unsupported: raise UnsupportedCatalogOperationError( f"{self.dialect} does not support catalogs and a catalog was provided: {catalog}" ) - return self.DEFAULT_CATALOG_TYPE + return ( + self._catalog_type_overrides.get(catalog, self.DEFAULT_CATALOG_TYPE) + if catalog + else self.DEFAULT_CATALOG_TYPE + ) + + def get_catalog_type_from_table(self, table: TableName) -> str: + """Get the catalog type from a table name if it has a catalog specified, otherwise return the current catalog type""" + catalog = exp.to_table(table).catalog or self.get_current_catalog() + return self.get_catalog_type(catalog) @property def current_catalog_type(self) -> str: + # `get_catalog_type_from_table` should be used over this property. Reason is that the table that is the target + # of the operation is what matters and not the catalog type of the connection. + # This still remains for legacy reasons and should be refactored out. 
return self.get_catalog_type(self.get_current_catalog()) def replace_query( self, table_name: TableName, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, + source_columns: t.Optional[t.List[str]] = None, + supports_replace_table_override: t.Optional[bool] = None, **kwargs: t.Any, ) -> None: """Replaces an existing table with a query. @@ -306,15 +474,25 @@ def replace_query( Args: table_name: The name of the table (eg. prod.table) query_or_df: The SQL query to run or a dataframe. - columns_to_types: Only used if a dataframe is provided. A mapping between the column name and its data type. + target_columns_to_types: Only used if a dataframe is provided. A mapping between the column name and its data type. Expected to be ordered to match the order of values in the dataframe. kwargs: Optional create table properties. 
""" target_table = exp.to_table(table_name) - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, target_table=target_table + + target_data_object = self.get_data_object(target_table) + table_exists = target_data_object is not None + if self.drop_data_object_on_type_mismatch(target_data_object, DataObjectType.TABLE): + table_exists = False + + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + query_or_df, + target_columns_to_types, + target_table=target_table, + source_columns=source_columns, ) - columns_to_types = columns_to_types or self.columns(target_table) + if not target_columns_to_types and table_exists: + target_columns_to_types = self.columns(target_table) query = source_queries[0].query_factory() self_referencing = any( quote_identifiers(table) == quote_identifiers(target_table) @@ -322,52 +500,65 @@ def replace_query( ) # If a query references itself then it must have a table created regardless of approach used. if self_referencing: + if not target_columns_to_types: + raise SQLMeshError( + f"Cannot create a self-referencing table {target_table.sql(dialect=self.dialect)} without knowing the column types. " + "Try casting the columns to an expected type or defining the columns in the model metadata. 
" + ) self._create_table_from_columns( target_table, - columns_to_types, + target_columns_to_types, exists=True, table_description=table_description, column_descriptions=column_descriptions, + **kwargs, ) # All engines support `CREATE TABLE AS` so we use that if the table doesn't already exist and we # use `CREATE OR REPLACE TABLE AS` if the engine supports it - if self.SUPPORTS_REPLACE_TABLE or not self.table_exists(target_table): + supports_replace_table = ( + self.SUPPORTS_REPLACE_TABLE + if supports_replace_table_override is None + else supports_replace_table_override + ) + if supports_replace_table or not table_exists: return self._create_table_from_source_queries( target_table, source_queries, - columns_to_types, - replace=self.SUPPORTS_REPLACE_TABLE, + target_columns_to_types, + replace=supports_replace_table, table_description=table_description, column_descriptions=column_descriptions, **kwargs, ) - else: - if self_referencing: - with self.temp_table( - self._select_columns(columns_to_types).from_(target_table), - name=target_table, - columns_to_types=columns_to_types, - **kwargs, - ) as temp_table: - for source_query in source_queries: - source_query.add_transform( - lambda node: ( # type: ignore - temp_table # type: ignore - if isinstance(node, exp.Table) - and quote_identifiers(node) == quote_identifiers(target_table) - else node - ) + if self_referencing: + assert target_columns_to_types is not None + with self.temp_table( + self._select_columns(target_columns_to_types).from_(target_table), + name=target_table, + target_columns_to_types=target_columns_to_types, + **kwargs, + ) as temp_table: + for source_query in source_queries: + source_query.add_transform( + lambda node: ( # type: ignore + temp_table # type: ignore + if isinstance(node, exp.Table) + and quote_identifiers(node) == quote_identifiers(target_table) + else node ) - return self._insert_overwrite_by_condition( - target_table, - source_queries, - columns_to_types, ) - return 
self._insert_overwrite_by_condition( - target_table, - source_queries, - columns_to_types, - ) + return self._insert_overwrite_by_condition( + target_table, + source_queries, + target_columns_to_types, + **kwargs, + ) + return self._insert_overwrite_by_condition( + target_table, + source_queries, + target_columns_to_types, + **kwargs, + ) def create_index( self, @@ -398,10 +589,37 @@ def create_index( ) self.execute(expression) + def _pop_creatable_type_from_properties( + self, + properties: t.Dict[str, exp.Expression], + ) -> t.Optional[exp.Property]: + """Pop out the creatable_type from the properties dictionary (if exists (return it/remove it) else return none). + It also checks that none of the expressions are MATERIALIZE as that conflicts with the `materialize` parameter. + """ + for key in list(properties.keys()): + upper_key = key.upper() + if upper_key == KEY_FOR_CREATABLE_TYPE: + value = properties.pop(key).name + parsed_properties = exp.maybe_parse( + value, into=exp.Properties, dialect=self.dialect + ) + property, *others = parsed_properties.expressions + if others: + # Multiple properties are unsupported today, can look into it in the future if needed + raise SQLMeshError( + f"Invalid creatable_type value with multiple properties: {value}" + ) + if isinstance(property, exp.MaterializedProperty): + raise SQLMeshError( + f"Cannot use {value} as a creatable_type as it conflicts with the `materialize` parameter." + ) + return property + return None + def create_table( self, table_name: TableName, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], primary_key: t.Optional[t.Tuple[str, ...]] = None, exists: bool = True, table_description: t.Optional[str] = None, @@ -412,7 +630,7 @@ def create_table( Args: table_name: The name of the table to create. Can be fully qualified or just table name. - columns_to_types: A mapping between the column name and its data type. 
+ target_columns_to_types: A mapping between the column name and its data type. primary_key: Determines the table primary key. exists: Indicates whether to include the IF NOT EXISTS check. table_description: Optional table description from MODEL DDL. @@ -421,7 +639,7 @@ def create_table( """ self._create_table_from_columns( table_name, - columns_to_types, + target_columns_to_types, primary_key, exists, table_description, @@ -433,12 +651,13 @@ def create_managed_table( self, table_name: TableName, query: Query, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, partitioned_by: t.Optional[t.List[exp.Expression]] = None, - clustered_by: t.Optional[t.List[str]] = None, + clustered_by: t.Optional[t.List[exp.Expression]] = None, table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: """Create a managed table using a query. @@ -448,9 +667,9 @@ def create_managed_table( Args: table_name: The name of the table to create. Can be fully qualified or just table name. query: The SQL query for the engine to base the managed table on - columns_to_types: A mapping between the column name and its data type. + target_columns_to_types: A mapping between the column name and its data type. partitioned_by: The partition columns or engine specific expressions, only applicable in certain engines. (eg. (ds, hour)) - clustered_by: The cluster columns, only applicable in certain engines. (eg. (ds, hour)) + clustered_by: The cluster columns or engine specific expressions, only applicable in certain engines. (eg. (ds, hour)) table_properties: Optional mapping of engine-specific properties to be set on the managed table table_description: Optional table description from MODEL DDL. 
column_descriptions: Optional column descriptions from model query. @@ -462,10 +681,11 @@ def ctas( self, table_name: TableName, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, exists: bool = True, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: """Create a table using a CTAS statement @@ -473,19 +693,22 @@ def ctas( Args: table_name: The name of the table to create. Can be fully qualified or just table name. query_or_df: The SQL query to run or a dataframe for the CTAS. - columns_to_types: A mapping between the column name and its data type. Required if using a DataFrame. + target_columns_to_types: A mapping between the column name and its data type. Required if using a DataFrame. exists: Indicates whether to include the IF NOT EXISTS check. table_description: Optional table description from MODEL DDL. column_descriptions: Optional column descriptions from model query. kwargs: Optional create table properties. 
""" - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, target_table=table_name + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + query_or_df, + target_columns_to_types, + target_table=table_name, + source_columns=source_columns, ) return self._create_table_from_source_queries( table_name, source_queries, - columns_to_types, + target_columns_to_types, exists, table_description=table_description, column_descriptions=column_descriptions, @@ -495,26 +718,26 @@ def ctas( def create_state_table( self, table_name: str, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], primary_key: t.Optional[t.Tuple[str, ...]] = None, ) -> None: """Create a table to store SQLMesh internal state. Args: table_name: The name of the table to create. Can be fully qualified or just table name. - columns_to_types: A mapping between the column name and its data type. + target_columns_to_types: A mapping between the column name and its data type. primary_key: Determines the table primary key. """ self.create_table( table_name, - columns_to_types, + target_columns_to_types, primary_key=primary_key, ) def _create_table_from_columns( self, table_name: TableName, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], primary_key: t.Optional[t.Tuple[str, ...]] = None, exists: bool = True, table_description: t.Optional[str] = None, @@ -526,7 +749,7 @@ def _create_table_from_columns( Args: table_name: The name of the table to create. Can be fully qualified or just table name. - columns_to_types: Mapping between the column name and its data type. + target_columns_to_types: Mapping between the column name and its data type. primary_key: Determines the table primary key. exists: Indicates whether to include the IF NOT EXISTS check. table_description: Optional table description from MODEL DDL. 
@@ -535,14 +758,14 @@ def _create_table_from_columns( """ table = exp.to_table(table_name) - if not columns_to_types_all_known(columns_to_types): + if not columns_to_types_all_known(target_columns_to_types): # It is ok if the columns types are not known if the table already exists and IF NOT EXISTS is set if exists and self.table_exists(table_name): return raise SQLMeshError( "Cannot create a table without knowing the column types. " "Try casting the columns to an expected type or defining the columns in the model metadata. " - f"Columns to types: {columns_to_types}" + f"Columns to types: {target_columns_to_types}" ) primary_key_expression = ( @@ -553,7 +776,7 @@ def _create_table_from_columns( schema = self._build_schema_exp( table, - columns_to_types, + target_columns_to_types, column_descriptions, primary_key_expression, ) @@ -562,7 +785,7 @@ def _create_table_from_columns( schema, None, exists=exists, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, table_description=table_description, **kwargs, ) @@ -584,37 +807,66 @@ def _create_table_from_columns( def _build_schema_exp( self, table: exp.Table, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], column_descriptions: t.Optional[t.Dict[str, str]] = None, expressions: t.Optional[t.List[exp.PrimaryKey]] = None, is_view: bool = False, + materialized: bool = False, ) -> exp.Schema: """ Build a schema expression for a table, columns, column comments, and additional schema properties. 
""" expressions = expressions or [] + + return exp.Schema( + this=table, + expressions=self._build_column_defs( + target_columns_to_types=target_columns_to_types, + column_descriptions=column_descriptions, + is_view=is_view, + materialized=materialized, + ) + + expressions, + ) + + def _build_column_defs( + self, + target_columns_to_types: t.Dict[str, exp.DataType], + column_descriptions: t.Optional[t.Dict[str, str]] = None, + is_view: bool = False, + materialized: bool = False, + ) -> t.List[exp.ColumnDef]: engine_supports_schema_comments = ( self.COMMENT_CREATION_VIEW.supports_schema_def if is_view else self.COMMENT_CREATION_TABLE.supports_schema_def ) - return exp.Schema( - this=table, - expressions=[ - exp.ColumnDef( - this=exp.to_identifier(column), - kind=None if is_view else kind, # don't include column data type for views - constraints=( - self._build_col_comment_exp(column, column_descriptions) - if column_descriptions - and engine_supports_schema_comments - and self.comments_enabled - else None - ), - ) - for column, kind in columns_to_types.items() - ] - + expressions, + return [ + self._build_column_def( + column, + column_descriptions=column_descriptions, + engine_supports_schema_comments=engine_supports_schema_comments, + col_type=None if is_view else kind, # don't include column data type for views + ) + for column, kind in target_columns_to_types.items() + ] + + def _build_column_def( + self, + col_name: str, + column_descriptions: t.Optional[t.Dict[str, str]] = None, + engine_supports_schema_comments: bool = False, + col_type: t.Optional[exp.DATA_TYPE] = None, + nested_names: t.List[str] = [], + ) -> exp.ColumnDef: + return exp.ColumnDef( + this=exp.to_identifier(col_name), + kind=col_type, + constraints=( + self._build_col_comment_exp(col_name, column_descriptions) + if engine_supports_schema_comments and self.comments_enabled and column_descriptions + else None + ), ) def _build_col_comment_exp( @@ -635,12 +887,13 @@ def 
_create_table_from_source_queries( self, table_name: TableName, source_queries: t.List[SourceQuery], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, exists: bool = True, replace: bool = False, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: table = exp.to_table(table_name) @@ -659,36 +912,42 @@ def _create_table_from_source_queries( # types, and for evaluation methods like `LogicalReplaceQueryMixin.replace_query()` # calls and SCD Type 2 model calls. schema = None - columns_to_types_known = columns_to_types and columns_to_types_all_known(columns_to_types) + target_columns_to_types_known = target_columns_to_types and columns_to_types_all_known( + target_columns_to_types + ) if ( column_descriptions - and columns_to_types_known + and target_columns_to_types_known and self.COMMENT_CREATION_TABLE.is_in_schema_def_ctas and self.comments_enabled ): - schema = self._build_schema_exp(table, columns_to_types, column_descriptions) # type: ignore + schema = self._build_schema_exp(table, target_columns_to_types, column_descriptions) # type: ignore with self.transaction(condition=len(source_queries) > 1): for i, source_query in enumerate(source_queries): with source_query as query: - if columns_to_types and columns_to_types_known: + if target_columns_to_types and target_columns_to_types_known: query = self._order_projections_and_filter( - query, columns_to_types, coerce_types=True + query, target_columns_to_types, coerce_types=True ) if i == 0: self._create_table( schema if schema else table, query, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, exists=exists, replace=replace, table_description=table_description, table_kind=table_kind, + track_rows_processed=track_rows_processed, **kwargs, ) else: 
self._insert_append_query( - table_name, query, columns_to_types or self.columns(table) + table_name, + query, + target_columns_to_types or self.columns(table), + track_rows_processed=track_rows_processed, ) # Register comments with commands if the engine supports comments and we weren't able to @@ -708,10 +967,11 @@ def _create_table( expression: t.Optional[exp.Expression], exists: bool = True, replace: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: self.execute( @@ -720,7 +980,7 @@ def _create_table( expression=expression, exists=exists, replace=replace, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, table_description=( table_description if self.COMMENT_CREATION_TABLE.supports_schema_def and self.comments_enabled @@ -728,8 +988,16 @@ def _create_table( ), table_kind=table_kind, **kwargs, - ) + ), + track_rows_processed=track_rows_processed, + ) + # Extract table name to clear cache + table_name = ( + table_name_or_schema.this + if isinstance(table_name_or_schema, exp.Schema) + else table_name_or_schema ) + self._clear_data_object_cache(table_name) def _build_create_table_exp( self, @@ -737,7 +1005,7 @@ def _build_create_table_exp( expression: t.Optional[exp.Expression], exists: bool = True, replace: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, **kwargs: t.Any, @@ -755,7 +1023,7 @@ def _build_create_table_exp( self._build_table_properties_exp( **kwargs, catalog_name=catalog_name, - columns_to_types=columns_to_types, + 
target_columns_to_types=target_columns_to_types, table_description=table_description, table_kind=table_kind, ) @@ -776,29 +1044,24 @@ def create_table_like( target_table_name: TableName, source_table_name: TableName, exists: bool = True, + **kwargs: t.Any, ) -> None: + """Create a table to store SQLMesh internal state based on the definition of another table, including any + column attributes and indexes defined in the original table. + + Args: + target_table_name: The name of the table to create. Can be fully qualified or just table name. + source_table_name: The name of the table to base the new table on. """ - Create a table like another table or view. - """ - target_table = exp.to_table(target_table_name) - source_table = exp.to_table(source_table_name) - create_expression = exp.Create( - this=target_table, - kind="TABLE", - exists=exists, - properties=exp.Properties( - expressions=[ - exp.LikeProperty(this=source_table), - ] - ), - ) - self.execute(create_expression) + self._create_table_like(target_table_name, source_table_name, exists=exists, **kwargs) + self._clear_data_object_cache(target_table_name) def clone_table( self, target_table_name: TableName, source_table_name: TableName, replace: bool = False, + exists: bool = True, clone_kwargs: t.Optional[t.Dict[str, t.Any]] = None, **kwargs: t.Any, ) -> None: @@ -808,14 +1071,18 @@ def clone_table( target_table_name: The name of the table that should be created. source_table_name: The name of the source table that should be cloned. replace: Whether or not to replace an existing table. + exists: Indicates whether to include the IF NOT EXISTS check. 
""" if not self.SUPPORTS_CLONING: raise NotImplementedError(f"Engine does not support cloning: {type(self)}") + + kwargs.pop("rendered_physical_properties", None) self.execute( exp.Create( this=exp.to_table(target_table_name), kind="TABLE", replace=replace, + exists=exists, clone=exp.Clone( this=exp.to_table(source_table_name), **(clone_kwargs or {}), @@ -823,15 +1090,38 @@ def clone_table( **kwargs, ) ) + self._clear_data_object_cache(target_table_name) + + def drop_data_object(self, data_object: DataObject, ignore_if_not_exists: bool = True) -> None: + """Drops a data object of arbitrary type. + + Args: + data_object: The data object to drop. + ignore_if_not_exists: If True, no error will be raised if the data object does not exist. + """ + if data_object.type.is_view: + self.drop_view(data_object.to_table(), ignore_if_not_exists=ignore_if_not_exists) + elif data_object.type.is_materialized_view: + self.drop_view( + data_object.to_table(), ignore_if_not_exists=ignore_if_not_exists, materialized=True + ) + elif data_object.type.is_table: + self.drop_table(data_object.to_table(), exists=ignore_if_not_exists) + elif data_object.type.is_managed_table: + self.drop_managed_table(data_object.to_table(), exists=ignore_if_not_exists) + else: + raise SQLMeshError( + f"Can't drop data object '{data_object.to_table().sql(dialect=self.dialect)}' of type '{data_object.type.value}'" + ) - def drop_table(self, table_name: TableName, exists: bool = True) -> None: + def drop_table(self, table_name: TableName, exists: bool = True, **kwargs: t.Any) -> None: """Drops a table. Args: table_name: The name of the table to drop. exists: If exists, defaults to True. """ - self._drop_tablelike_object(table_name, exists) + self._drop_object(name=table_name, exists=exists, **kwargs) def drop_managed_table(self, table_name: TableName, exists: bool = True) -> None: """Drops a managed table. 
@@ -842,56 +1132,79 @@ def drop_managed_table(self, table_name: TableName, exists: bool = True) -> None """ raise NotImplementedError(f"Engine does not support managed tables: {type(self)}") - def _drop_tablelike_object( - self, table_name: TableName, exists: bool = True, kind: str = "TABLE" + def _drop_object( + self, + name: TableName | SchemaName, + exists: bool = True, + kind: str = "TABLE", + cascade: bool = False, + **drop_args: t.Any, ) -> None: - """Drops a "tablelike object". + """Drops an object. - A "tablelike object" could be a TABLE or a DYNAMIC TABLE or a TEMPORARY TABLE etc depending on the context. + An object could be a DATABASE, SCHEMA, VIEW, TABLE, DYNAMIC TABLE, TEMPORARY TABLE etc depending on the :kind. Args: - table_name: The name of the table to drop. + name: The name of the table to drop. exists: If exists, defaults to True. kind: What kind of object to drop. Defaults to TABLE + cascade: Whether or not to DROP ... CASCADE. + Note that this is ignored for :kind's that are not present in self.SUPPORTED_DROP_CASCADE_OBJECT_KINDS + **drop_args: Any extra arguments to set on the Drop expression """ - drop_expression = exp.Drop(this=exp.to_table(table_name), kind=kind, exists=exists) - self.execute(drop_expression) + if cascade and kind.upper() in self.SUPPORTED_DROP_CASCADE_OBJECT_KINDS: + drop_args["cascade"] = cascade - def get_alter_expressions( + self.execute(exp.Drop(this=exp.to_table(name), kind=kind, exists=exists, **drop_args)) + self._clear_data_object_cache(name) + + def get_alter_operations( self, current_table_name: TableName, target_table_name: TableName, - ) -> t.List[exp.AlterTable]: + *, + ignore_destructive: bool = False, + ignore_additive: bool = False, + ) -> t.List[TableAlterOperation]: """ Determines the alter statements needed to change the current table into the structure of the target table. 
""" - return self.SCHEMA_DIFFER.compare_columns( - current_table_name, - self.columns(current_table_name), - self.columns(target_table_name), + return t.cast( + t.List[TableAlterOperation], + self.schema_differ.compare_columns( + current_table_name, + self.columns(current_table_name), + self.columns(target_table_name), + ignore_destructive=ignore_destructive, + ignore_additive=ignore_additive, + ), ) def alter_table( self, - alter_expressions: t.List[exp.AlterTable], + alter_expressions: t.Union[t.List[exp.Alter], t.List[TableAlterOperation]], ) -> None: """ Performs the alter statements to change the current table into the structure of the target table. """ with self.transaction(): - for alter_expression in alter_expressions: + for alter_expression in [ + x.expression if isinstance(x, TableAlterOperation) else x for x in alter_expressions + ]: self.execute(alter_expression) def create_view( self, view_name: TableName, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, replace: bool = True, materialized: bool = False, + materialized_properties: t.Optional[t.Dict[str, t.Any]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + source_columns: t.Optional[t.List[str]] = None, **create_kwargs: t.Any, ) -> None: """Create a view with a query or dataframe. @@ -902,50 +1215,104 @@ def create_view( Args: view_name: The view name. query_or_df: A query or dataframe. - columns_to_types: Columns to use in the view statement. + target_columns_to_types: Columns to use in the view statement. replace: Whether or not to replace an existing view defaults to True. materialized: Whether to create a a materialized view. Only used for engines that support this feature. + materialized_properties: Optional materialized view properties to add to the view. 
table_description: Optional table description from MODEL DDL. column_descriptions: Optional column descriptions from model query. view_properties: Optional view properties to add to the view. create_kwargs: Additional kwargs to pass into the Create expression """ + import pandas as pd + + if materialized_properties and not materialized: + raise SQLMeshError("Materialized properties are only supported for materialized views") + + query_or_df = self._native_df_to_pandas_df(query_or_df) + if isinstance(query_or_df, pd.DataFrame): values: t.List[t.Tuple[t.Any, ...]] = list( query_or_df.itertuples(index=False, name=None) ) - columns_to_types = columns_to_types or self._columns_to_types(query_or_df) - if not columns_to_types: + target_columns_to_types, source_columns = self._columns_to_types( + query_or_df, target_columns_to_types, source_columns + ) + if not target_columns_to_types: raise SQLMeshError("columns_to_types must be provided for dataframes") + source_columns_to_types = get_source_columns_to_types( + target_columns_to_types, source_columns + ) query_or_df = self._values_to_sql( values, - columns_to_types, + source_columns_to_types, batch_start=0, batch_end=len(values), ) - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, batch_size=0, target_table=view_name + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + query_or_df, + target_columns_to_types, + batch_size=0, + target_table=view_name, + source_columns=source_columns, ) if len(source_queries) != 1: raise SQLMeshError("Only one source query is supported for creating views") schema: t.Union[exp.Table, exp.Schema] = exp.to_table(view_name) - if columns_to_types: + if target_columns_to_types: schema = self._build_schema_exp( - exp.to_table(view_name), columns_to_types, column_descriptions, is_view=True + exp.to_table(view_name), + target_columns_to_types, + column_descriptions, + is_view=True, + 
materialized=materialized, ) properties = create_kwargs.pop("properties", None) if not properties: properties = exp.Properties(expressions=[]) + if view_properties: + table_type = self._pop_creatable_type_from_properties(view_properties) + if table_type: + properties.append("expressions", table_type) + if materialized and self.SUPPORTS_MATERIALIZED_VIEWS: properties.append("expressions", exp.MaterializedProperty()) if not self.SUPPORTS_MATERIALIZED_VIEW_SCHEMA and isinstance(schema, exp.Schema): schema = schema.this + if not self.SUPPORTS_VIEW_SCHEMA and isinstance(schema, exp.Schema): + schema = schema.this + + if materialized_properties: + partitioned_by = materialized_properties.pop("partitioned_by", None) + clustered_by = materialized_properties.pop("clustered_by", None) + if ( + partitioned_by + and ( + partitioned_by_prop := self._build_partitioned_by_exp( + partitioned_by, **materialized_properties + ) + ) + is not None + ): + materialized_properties["catalog_name"] = exp.to_table(view_name).catalog + properties.append("expressions", partitioned_by_prop) + if ( + clustered_by + and ( + clustered_by_prop := self._build_clustered_by_exp( + clustered_by, **materialized_properties + ) + ) + is not None + ): + properties.append("expressions", clustered_by_prop) + create_view_properties = self._build_view_properties_exp( view_properties, ( @@ -953,14 +1320,25 @@ def create_view( if self.COMMENT_CREATION_VIEW.supports_schema_def and self.comments_enabled else None ), + physical_cluster=create_kwargs.pop("physical_cluster", None), ) if create_view_properties: for view_property in create_view_properties.expressions: - properties.append("expressions", view_property) + # Small hack to make sure SECURE goes at the beginning before materialized as required by Snowflake + if isinstance(view_property, exp.SecureProperty): + properties.set("expressions", view_property, index=0, overwrite=False) + else: + properties.append("expressions", view_property) if 
properties.expressions: create_kwargs["properties"] = properties + if replace: + self.drop_data_object_on_type_mismatch( + self.get_data_object(view_name), + DataObjectType.VIEW if not materialized else DataObjectType.MATERIALIZED_VIEW, + ) + with source_queries[0] as query: self.execute( exp.Create( @@ -973,6 +1351,8 @@ def create_view( quote_identifiers=self.QUOTE_IDENTIFIERS_IN_VIEWS, ) + self._clear_data_object_cache(view_name) + # Register table comment with commands if the engine doesn't support doing it in CREATE if ( table_description @@ -989,12 +1369,12 @@ def create_view( self.COMMENT_CREATION_VIEW.is_comment_command_only or ( self.COMMENT_CREATION_VIEW.is_in_schema_def_and_commands - and not columns_to_types + and not target_columns_to_types ) ) and self.comments_enabled ): - self._create_column_comments(view_name, column_descriptions, "VIEW") + self._create_column_comments(view_name, column_descriptions, "VIEW", materialized) @set_catalog() def create_schema( @@ -1002,34 +1382,55 @@ def create_schema( schema_name: SchemaName, ignore_if_exists: bool = True, warn_on_error: bool = True, + properties: t.Optional[t.List[exp.Expression]] = None, + ) -> None: + properties = properties or [] + return self._create_schema( + schema_name=schema_name, + ignore_if_exists=ignore_if_exists, + warn_on_error=warn_on_error, + properties=properties, + kind="SCHEMA", + ) + + def _create_schema( + self, + schema_name: SchemaName, + ignore_if_exists: bool, + warn_on_error: bool, + properties: t.List[exp.Expression], + kind: str, ) -> None: """Create a schema from a name or qualified table name.""" try: self.execute( exp.Create( this=to_schema(schema_name), - kind="SCHEMA", + kind=kind, exists=ignore_if_exists, + properties=exp.Properties( # this renders as '' (empty string) if expressions is empty + expressions=properties + ), ) ) except Exception as e: if not warn_on_error: raise - logger.warning("Failed to create schema '%s': %s", schema_name, e) + logger.warning("Failed 
to create %s '%s': %s", kind.lower(), schema_name, e) def drop_schema( self, schema_name: SchemaName, ignore_if_not_exists: bool = True, cascade: bool = False, + **drop_args: t.Dict[str, exp.Expression], ) -> None: - self.execute( - exp.Drop( - this=to_schema(schema_name), - kind="SCHEMA", - exists=ignore_if_not_exists, - cascade=cascade, - ) + return self._drop_object( + name=schema_name, + exists=ignore_if_not_exists, + kind="SCHEMA", + cascade=cascade, + **drop_args, ) def drop_view( @@ -1040,14 +1441,28 @@ def drop_view( **kwargs: t.Any, ) -> None: """Drop a view.""" - self.execute( - exp.Drop( - this=exp.to_table(view_name), - exists=ignore_if_not_exists, - materialized=materialized and self.SUPPORTS_MATERIALIZED_VIEWS, - kind="VIEW", - **kwargs, - ) + self._drop_object( + name=view_name, + exists=ignore_if_not_exists, + kind="VIEW", + materialized=materialized and self.SUPPORTS_MATERIALIZED_VIEWS, + **kwargs, + ) + + def create_catalog(self, catalog_name: str | exp.Identifier) -> None: + return self._create_catalog(exp.parse_identifier(catalog_name, dialect=self.dialect)) + + def _create_catalog(self, catalog_name: exp.Identifier) -> None: + raise SQLMeshError( + f"Unable to create catalog '{catalog_name.sql(dialect=self.dialect)}' as automatic catalog management is not implemented in the {self.dialect} engine." + ) + + def drop_catalog(self, catalog_name: str | exp.Identifier) -> None: + return self._drop_catalog(exp.parse_identifier(catalog_name, dialect=self.dialect)) + + def _drop_catalog(self, catalog_name: exp.Identifier) -> None: + raise SQLMeshError( + f"Unable to drop catalog '{catalog_name.sql(dialect=self.dialect)}' as automatic catalog management is not implemented in the {self.dialect} engine." 
) def columns( @@ -1067,8 +1482,14 @@ def columns( } def table_exists(self, table_name: TableName) -> bool: + table = exp.to_table(table_name) + data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name) + if data_object_cache_key in self._data_object_cache: + logger.debug("Table existence cache hit: %s", data_object_cache_key) + return self._data_object_cache[data_object_cache_key] is not None + try: - self.execute(exp.Describe(this=exp.to_table(table_name), kind="TABLE")) + self.execute(exp.Describe(this=table, kind="TABLE")) return True except Exception: return False @@ -1080,54 +1501,80 @@ def insert_append( self, table_name: TableName, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + track_rows_processed: bool = True, + source_columns: t.Optional[t.List[str]] = None, ) -> None: - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, target_table=table_name + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + query_or_df, + target_columns_to_types, + target_table=table_name, + source_columns=source_columns, + ) + self._insert_append_source_queries( + table_name, source_queries, target_columns_to_types, track_rows_processed ) - self._insert_append_source_queries(table_name, source_queries, columns_to_types) def _insert_append_source_queries( self, table_name: TableName, source_queries: t.List[SourceQuery], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + track_rows_processed: bool = True, ) -> None: with self.transaction(condition=len(source_queries) > 0): - columns_to_types = columns_to_types or self.columns(table_name) + target_columns_to_types = target_columns_to_types or self.columns(table_name) for source_query in 
source_queries: with source_query as query: - self._insert_append_query(table_name, query, columns_to_types) + self._insert_append_query( + table_name, + query, + target_columns_to_types, + track_rows_processed=track_rows_processed, + ) def _insert_append_query( self, table_name: TableName, query: Query, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], order_projections: bool = True, + track_rows_processed: bool = True, ) -> None: if order_projections: - query = self._order_projections_and_filter(query, columns_to_types) - self.execute(exp.insert(query, table_name, columns=list(columns_to_types))) + query = self._order_projections_and_filter(query, target_columns_to_types) + self.execute( + exp.insert(query, table_name, columns=list(target_columns_to_types)), + track_rows_processed=track_rows_processed, + ) def insert_overwrite_by_partition( self, table_name: TableName, query_or_df: QueryOrDF, partitioned_by: t.List[exp.Expression], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, ) -> None: if self.INSERT_OVERWRITE_STRATEGY.is_insert_overwrite: target_table = exp.to_table(table_name) - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, target_table=target_table + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + query_or_df, + target_columns_to_types, + target_table=target_table, + source_columns=source_columns, ) self._insert_overwrite_by_condition( - table_name, source_queries, columns_to_types=columns_to_types + table_name, source_queries, target_columns_to_types=target_columns_to_types ) else: self._replace_by_key( - table_name, query_or_df, columns_to_types, partitioned_by, is_unique_key=False + table_name, + query_or_df, + target_columns_to_types, + partitioned_by, + 
is_unique_key=False, + source_columns=source_columns, ) def insert_overwrite_by_time_partition( @@ -1140,14 +1587,22 @@ def insert_overwrite_by_time_partition( [TimeLike, t.Optional[t.Dict[str, exp.DataType]]], exp.Expression ], time_column: TimeColumn | exp.Expression | str, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, target_table=table_name + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + query_or_df, + target_columns_to_types, + target_table=table_name, + source_columns=source_columns, ) - columns_to_types = columns_to_types or self.columns(table_name) - low, high = [time_formatter(dt, columns_to_types) for dt in make_inclusive(start, end)] + if not target_columns_to_types or not columns_to_types_all_known(target_columns_to_types): + target_columns_to_types = self.columns(table_name) + low, high = [ + time_formatter(dt, target_columns_to_types) + for dt in make_inclusive(start, end, self.dialect) + ] if isinstance(time_column, TimeColumn): time_column = time_column.column where = exp.Between( @@ -1155,31 +1610,48 @@ def insert_overwrite_by_time_partition( low=low, high=high, ) - self._insert_overwrite_by_condition(table_name, source_queries, columns_to_types, where) + return self._insert_overwrite_by_time_partition( + table_name, source_queries, target_columns_to_types, where, **kwargs + ) + + def _insert_overwrite_by_time_partition( + self, + table_name: TableName, + source_queries: t.List[SourceQuery], + target_columns_to_types: t.Dict[str, exp.DataType], + where: exp.Condition, + **kwargs: t.Any, + ) -> None: + return self._insert_overwrite_by_condition( + table_name, source_queries, target_columns_to_types, where, **kwargs + ) def 
_values_to_sql( self, values: t.List[t.Tuple[t.Any, ...]], - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_start: int, batch_end: int, alias: str = "t", + source_columns: t.Optional[t.List[str]] = None, ) -> Query: return select_from_values_for_batch_range( values=values, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, batch_start=batch_start, batch_end=batch_end, alias=alias, + source_columns=source_columns, ) def _insert_overwrite_by_condition( self, table_name: TableName, source_queries: t.List[SourceQuery], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, where: t.Optional[exp.Condition] = None, insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None, + **kwargs: t.Any, ) -> None: table = exp.to_table(table_name) insert_overwrite_strategy = ( @@ -1188,25 +1660,51 @@ def _insert_overwrite_by_condition( with self.transaction( condition=len(source_queries) > 0 or insert_overwrite_strategy.is_delete_insert ): - columns_to_types = columns_to_types or self.columns(table_name) + target_columns_to_types = target_columns_to_types or self.columns(table_name) for i, source_query in enumerate(source_queries): with source_query as query: - query = self._order_projections_and_filter(query, columns_to_types, where=where) + query = self._order_projections_and_filter( + query, target_columns_to_types, where=where + ) if i > 0 or insert_overwrite_strategy.is_delete_insert: if i == 0: self.delete_from(table_name, where=where or exp.true()) self._insert_append_query( table_name, query, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, order_projections=False, ) + elif insert_overwrite_strategy.is_merge: + columns = [exp.column(col) for col in target_columns_to_types] + when_not_matched_by_source = exp.When( + matched=False, + 
source=True, + condition=where, + then=exp.Delete(), + ) + when_not_matched_by_target = exp.When( + matched=False, + source=False, + then=exp.Insert( + this=exp.Tuple(expressions=columns), + expression=exp.Tuple(expressions=columns), + ), + ) + self._merge( + target_table=table_name, + query=query, + on=exp.false(), + whens=exp.Whens( + expressions=[when_not_matched_by_source, when_not_matched_by_target] + ), + ) else: insert_exp = exp.insert( query, table, columns=( - list(columns_to_types) + list(target_columns_to_types) if not insert_overwrite_strategy.is_replace_where else None ), @@ -1214,7 +1712,7 @@ def _insert_overwrite_by_condition( ) if insert_overwrite_strategy.is_replace_where: insert_exp.set("where", where or exp.true()) - self.execute(insert_exp) + self.execute(insert_exp, track_rows_processed=True) def update_table( self, @@ -1229,19 +1727,14 @@ def _merge( target_table: TableName, query: Query, on: exp.Expression, - match_expressions: t.List[exp.When], + whens: exp.Whens, ) -> None: this = exp.alias_(exp.to_table(target_table), alias=MERGE_TARGET_ALIAS, table=True) using = exp.alias_( exp.Subquery(this=query), alias=MERGE_SOURCE_ALIAS, copy=False, table=True ) self.execute( - exp.Merge( - this=this, - using=using, - on=on, - expressions=match_expressions, - ) + exp.Merge(this=this, using=using, on=on, whens=whens), track_rows_processed=True ) def scd_type_2_by_time( @@ -1251,14 +1744,15 @@ def scd_type_2_by_time( unique_key: t.Sequence[exp.Expression], valid_from_col: exp.Column, valid_to_col: exp.Column, - execution_time: TimeLike, + execution_time: t.Union[TimeLike, exp.Column], updated_at_col: exp.Column, invalidate_hard_deletes: bool = True, updated_at_as_valid_from: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, + 
source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: self._scd_type_2( @@ -1271,10 +1765,12 @@ def scd_type_2_by_time( updated_at_col=updated_at_col, invalidate_hard_deletes=invalidate_hard_deletes, updated_at_as_valid_from=updated_at_as_valid_from, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, table_description=table_description, column_descriptions=column_descriptions, truncate=truncate, + source_columns=source_columns, + **kwargs, ) def scd_type_2_by_column( @@ -1284,14 +1780,15 @@ def scd_type_2_by_column( unique_key: t.Sequence[exp.Expression], valid_from_col: exp.Column, valid_to_col: exp.Column, - execution_time: TimeLike, - check_columns: t.Union[exp.Star, t.Sequence[exp.Column]], + execution_time: t.Union[TimeLike, exp.Column], + check_columns: t.Union[exp.Star, t.Sequence[exp.Expression]], invalidate_hard_deletes: bool = True, execution_time_as_valid_from: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: self._scd_type_2( @@ -1302,12 +1799,14 @@ def scd_type_2_by_column( valid_to_col=valid_to_col, execution_time=execution_time, check_columns=check_columns, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, invalidate_hard_deletes=invalidate_hard_deletes, execution_time_as_valid_from=execution_time_as_valid_from, table_description=table_description, column_descriptions=column_descriptions, truncate=truncate, + source_columns=source_columns, + **kwargs, ) def _scd_type_2( @@ -1317,32 +1816,51 @@ def _scd_type_2( unique_key: t.Sequence[exp.Expression], valid_from_col: exp.Column, valid_to_col: exp.Column, - execution_time: TimeLike, + execution_time: 
t.Union[TimeLike, exp.Column], invalidate_hard_deletes: bool = True, updated_at_col: t.Optional[exp.Column] = None, - check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None, + check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None, updated_at_as_valid_from: bool = False, execution_time_as_valid_from: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, + source_columns: t.Optional[t.List[str]] = None, + **kwargs: t.Any, ) -> None: - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - source_table, columns_to_types, target_table=target_table, batch_size=0 - ) - columns_to_types = columns_to_types or self.columns(target_table) + def remove_managed_columns( + cols_to_types: t.Dict[str, exp.DataType], + ) -> t.Dict[str, exp.DataType]: + return { + k: v for k, v in cols_to_types.items() if k not in {valid_from_name, valid_to_name} + } + valid_from_name = valid_from_col.name valid_to_name = valid_to_col.name - updated_at_name = updated_at_col.name if updated_at_col else None + target_columns_to_types = target_columns_to_types or self.columns(target_table) if ( - valid_from_name not in columns_to_types - or valid_to_name not in columns_to_types - or not columns_to_types_all_known(columns_to_types) + valid_from_name not in target_columns_to_types + or valid_to_name not in target_columns_to_types + or not columns_to_types_all_known(target_columns_to_types) ): - columns_to_types = self.columns(target_table) - if not columns_to_types: + target_columns_to_types = self.columns(target_table) + unmanaged_columns_to_types = ( + remove_managed_columns(target_columns_to_types) if target_columns_to_types else None + ) + source_queries, unmanaged_columns_to_types = 
self._get_source_queries_and_columns_to_types( + source_table, + unmanaged_columns_to_types, + target_table=target_table, + batch_size=0, + source_columns=source_columns, + ) + updated_at_name = updated_at_col.name if updated_at_col else None + if not target_columns_to_types: raise SQLMeshError(f"Could not get columns_to_types. Does {target_table} exist?") + unmanaged_columns_to_types = unmanaged_columns_to_types or remove_managed_columns( + target_columns_to_types + ) if not unique_key: raise SQLMeshError("unique_key must be provided for SCD Type 2") if check_columns and updated_at_col: @@ -1357,19 +1875,15 @@ def _scd_type_2( raise SQLMeshError( "Cannot use `execution_time_as_valid_from` without `check_columns` for SCD Type 2" ) - if updated_at_name and updated_at_name not in columns_to_types: + if updated_at_name and updated_at_name not in target_columns_to_types: raise SQLMeshError( f"Column {updated_at_name} not found in {target_table}. Table must contain an `updated_at` timestamp for SCD Type 2" ) - - unmanaged_columns = [ - col for col in columns_to_types if col not in {valid_from_name, valid_to_name} - ] - time_data_type = columns_to_types[valid_from_name] + time_data_type = target_columns_to_types[valid_from_name] select_source_columns: t.List[t.Union[str, exp.Alias]] = [ - col for col in unmanaged_columns if col != updated_at_name + col for col in unmanaged_columns_to_types if col != updated_at_name ] - table_columns = [exp.column(c, quoted=True) for c in columns_to_types] + table_columns = [exp.column(c, quoted=True) for c in target_columns_to_types] if updated_at_name: select_source_columns.append( exp.cast(updated_at_col, time_data_type).as_(updated_at_col.this) # type: ignore @@ -1380,19 +1894,30 @@ def _scd_type_2( # they are equal or not, the extra check is not a problem and we gain simplified logic here. 
# If we want to change this, then we just need to check the expressions in unique_key and pull out the # column names and then remove them from the unmanaged_columns - if check_columns and check_columns == exp.Star(): - check_columns = [exp.column(col) for col in unmanaged_columns] - execution_ts = to_time_column(execution_time, time_data_type) + if check_columns: + # Handle both Star directly and [Star()] (which can happen during serialization/deserialization) + if isinstance(seq_get(ensure_list(check_columns), 0), exp.Star): + check_columns = [exp.column(col) for col in unmanaged_columns_to_types] + execution_ts = ( + exp.cast(execution_time, time_data_type, dialect=self.dialect) + if isinstance(execution_time, exp.Column) + else to_time_column(execution_time, time_data_type, self.dialect, nullable=True) + ) if updated_at_as_valid_from: if not updated_at_col: raise SQLMeshError( "Cannot use `updated_at_as_valid_from` without `updated_at_name` for SCD Type 2" ) update_valid_from_start: t.Union[str, exp.Expression] = updated_at_col - elif execution_time_as_valid_from: + # If using check_columns and the user doesn't always want execution_time for valid from + # then we only use epoch 0 if we are truncating the table and loading rows for the first time. + # All future new rows should have execution time. 
+ elif check_columns and (execution_time_as_valid_from or not truncate): update_valid_from_start = execution_ts else: - update_valid_from_start = to_time_column("1970-01-01 00:00:00+00:00", time_data_type) + update_valid_from_start = to_time_column( + "1970-01-01 00:00:00+00:00", time_data_type, self.dialect, nullable=True + ) insert_valid_from_start = execution_ts if check_columns else updated_at_col # type: ignore # joined._exists IS NULL is saying "if the row is deleted" delete_check = ( @@ -1405,22 +1930,30 @@ def _scd_type_2( if check_columns: row_check_conditions = [] for col in check_columns: - t_col = col.copy() - t_col.this.set("this", f"t_{col.name}") + col_qualified = col.copy() + col_qualified.set("table", exp.to_identifier("joined")) + + t_col = col_qualified.copy() + for column in t_col.find_all(exp.Column): + column.this.set("this", f"t_{column.name}") + row_check_conditions.extend( [ - col.neq(t_col), - exp.and_(t_col.is_(exp.Null()), col.is_(exp.Null()).not_()), - exp.and_(t_col.is_(exp.Null()).not_(), col.is_(exp.Null())), + col_qualified.neq(t_col), + exp.and_(t_col.is_(exp.Null()), col_qualified.is_(exp.Null()).not_()), + exp.and_(t_col.is_(exp.Null()).not_(), col_qualified.is_(exp.Null())), ] ) row_value_check = exp.or_(*row_check_conditions) unique_key_conditions = [] - for col in unique_key: - t_col = col.copy() - t_col.this.set("this", f"t_{col.name}") + for key in unique_key: + key_qualified = key.copy() + key_qualified.set("table", exp.to_identifier("joined")) + t_key = key_qualified.copy() + for col in t_key.find_all(exp.Column): + col.this.set("this", f"t_{col.name}") unique_key_conditions.extend( - [t_col.is_(exp.Null()).not_(), col.is_(exp.Null()).not_()] + [t_key.is_(exp.Null()).not_(), key_qualified.is_(exp.Null()).not_()] ) unique_key_check = exp.and_(*unique_key_conditions) # unique_key_check is saying "if the row is updated" @@ -1447,11 +1980,15 @@ def _scd_type_2( ).as_(valid_from_col.this) else: assert updated_at_col is not None 
- prefixed_updated_at_col = updated_at_col.copy() - prefixed_updated_at_col.this.set("this", f"t_{updated_at_col.name}") - updated_row_filter = updated_at_col > prefixed_updated_at_col - - valid_to_case_stmt_builder = exp.Case().when(updated_row_filter, updated_at_col) + updated_at_col_qualified = updated_at_col.copy() + updated_at_col_qualified.set("table", exp.to_identifier("joined")) + prefixed_updated_at_col = updated_at_col_qualified.copy() + prefixed_updated_at_col.this.set("this", f"t_{updated_at_col_qualified.name}") + updated_row_filter = updated_at_col_qualified > prefixed_updated_at_col + + valid_to_case_stmt_builder = exp.Case().when( + updated_row_filter, updated_at_col_qualified + ) if delete_check: valid_to_case_stmt_builder = valid_to_case_stmt_builder.when( delete_check, execution_ts @@ -1486,12 +2023,12 @@ def _scd_type_2( with source_queries[0] as source_query: prefixed_columns_to_types = [] - for column in columns_to_types: + for column in target_columns_to_types: prefixed_col = exp.column(column).copy() prefixed_col.this.set("this", f"t_{prefixed_col.name}") prefixed_columns_to_types.append(prefixed_col) prefixed_unmanaged_columns = [] - for column in unmanaged_columns: + for column in unmanaged_columns_to_types: prefixed_col = exp.column(column).copy() prefixed_col.this.set("this", f"t_{prefixed_col.name}") prefixed_unmanaged_columns.append(prefixed_col) @@ -1511,7 +2048,11 @@ def _scd_type_2( "source", exp.select(exp.true().as_("_exists"), *select_source_columns) .distinct(*unique_key) - .from_(source_query.subquery("raw_source")), # type: ignore + .from_( + self.use_server_nulls_for_unmatched_after_join(source_query).subquery( # type: ignore + "raw_source" + ) + ), ) # Historical Records that Do Not Change .with_( @@ -1526,7 +2067,7 @@ def _scd_type_2( # Deleted records which can be used to determine `valid_from` for undeleted source records .with_( "deleted", - exp.select(*[exp.column(col, "static") for col in columns_to_types]) + 
exp.select(*[exp.column(col, "static") for col in target_columns_to_types]) .from_("static") .join( "latest", @@ -1552,18 +2093,21 @@ def _scd_type_2( .group_by(*unique_key), ) # Do a full join between latest records and source table in order to combine them together - # MySQL doesn't suport full join so going to do a left then right join and remove dups with union + # MySQL doesn't support full join so going to do a left then right join and remove dups with union # We do a left/right and filter right on only matching to remove the need to do union distinct # which allows scd type 2 to be compatible with unhashable data types .with_( "joined", exp.select( - exp.column("_exists", table="source"), + exp.column("_exists", table="source").as_("_exists"), *( exp.column(col, table="latest").as_(prefixed_columns_to_types[i].this) - for i, col in enumerate(columns_to_types) + for i, col in enumerate(target_columns_to_types) + ), + *( + exp.column(col, table="source").as_(col) + for col in unmanaged_columns_to_types ), - *(exp.column(col, table="source").as_(col) for col in unmanaged_columns), ) .from_("latest") .join( @@ -1578,16 +2122,16 @@ def _scd_type_2( ) .union( exp.select( - exp.column("_exists", table="source"), + exp.column("_exists", table="source").as_("_exists"), *( exp.column(col, table="latest").as_( prefixed_columns_to_types[i].this ) - for i, col in enumerate(columns_to_types) + for i, col in enumerate(target_columns_to_types) ), *( exp.column(col, table="source").as_(col) - for col in unmanaged_columns + for col in unmanaged_columns_to_types ), ) .from_("latest") @@ -1615,7 +2159,7 @@ def _scd_type_2( exp.column(prefixed_unmanaged_columns[i].this, table="joined"), exp.column(col, table="joined"), ).as_(col) - for i, col in enumerate(unmanaged_columns) + for i, col in enumerate(unmanaged_columns_to_types) ), valid_from_case_stmt, valid_to_case_stmt, @@ -1638,9 +2182,11 @@ def _scd_type_2( .with_( "inserted_rows", exp.select( - *unmanaged_columns, + 
*unmanaged_columns_to_types, insert_valid_from_start.as_(valid_from_col.this), # type: ignore - to_time_column(exp.null(), time_data_type).as_(valid_to_col.this), + to_time_column(exp.null(), time_data_type, self.dialect, nullable=True).as_( + valid_to_col.this + ), ) .from_("joined") .where(updated_row_filter), @@ -1649,50 +2195,73 @@ def _scd_type_2( self.replace_query( target_table, - query, - columns_to_types=columns_to_types, + self.ensure_nulls_for_unmatched_after_join(query), + target_columns_to_types=target_columns_to_types, table_description=table_description, column_descriptions=column_descriptions, + **kwargs, ) def merge( self, target_table: TableName, source_table: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], unique_key: t.Sequence[exp.Expression], - when_matched: t.Optional[exp.When] = None, + when_matched: t.Optional[exp.Whens] = None, + merge_filter: t.Optional[exp.Expression] = None, + source_columns: t.Optional[t.List[str]] = None, + **kwargs: t.Any, ) -> None: - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - source_table, columns_to_types, target_table=target_table + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + source_table, + target_columns_to_types, + target_table=target_table, + source_columns=source_columns, ) - columns_to_types = columns_to_types or self.columns(target_table) + target_columns_to_types = target_columns_to_types or self.columns(target_table) on = exp.and_( *( add_table(part, MERGE_TARGET_ALIAS).eq(add_table(part, MERGE_SOURCE_ALIAS)) for part in unique_key ) ) + if merge_filter: + on = exp.and_(merge_filter, on) + if not when_matched: - when_matched = exp.When( - matched=True, + match_expressions = [ + exp.When( + matched=True, + source=False, + then=exp.Update( + expressions=[ + exp.column(col, MERGE_TARGET_ALIAS).eq( + exp.column(col, MERGE_SOURCE_ALIAS) 
+ ) + for col in target_columns_to_types + ], + ), + ) + ] + else: + match_expressions = when_matched.copy().expressions + + match_expressions.append( + exp.When( + matched=False, source=False, - then=exp.Update( - expressions=[ - exp.column(col, MERGE_TARGET_ALIAS).eq(exp.column(col, MERGE_SOURCE_ALIAS)) - for col in columns_to_types - ], + then=exp.Insert( + this=exp.Tuple( + expressions=[exp.column(col) for col in target_columns_to_types] + ), + expression=exp.Tuple( + expressions=[ + exp.column(col, MERGE_SOURCE_ALIAS) for col in target_columns_to_types + ] + ), ), ) - when_not_matched = exp.When( - matched=False, - source=False, - then=exp.Insert( - this=exp.Tuple(expressions=[exp.column(col) for col in columns_to_types]), - expression=exp.Tuple( - expressions=[exp.column(col, MERGE_SOURCE_ALIAS) for col in columns_to_types] - ), - ), ) for source_query in source_queries: with source_query as query: @@ -1700,7 +2269,7 @@ def merge( target_table=target_table, query=query, on=on, - match_expressions=[when_matched, when_not_matched], + whens=exp.Whens(expressions=match_expressions), ) def rename_table( @@ -1717,15 +2286,34 @@ def rename_table( "Tried to rename table across catalogs which is not supported" ) self._rename_table(old_table_name, new_table_name) + self._clear_data_object_cache(old_table_name) + self._clear_data_object_cache(new_table_name) + + def get_data_object( + self, target_name: TableName, safe_to_cache: bool = False + ) -> t.Optional[DataObject]: + target_table = exp.to_table(target_name) + existing_data_objects = self.get_data_objects( + schema_(target_table.db, target_table.catalog), + {target_table.name}, + safe_to_cache=safe_to_cache, + ) + if existing_data_objects: + return existing_data_objects[0] + return None def get_data_objects( - self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None + self, + schema_name: SchemaName, + object_names: t.Optional[t.Set[str]] = None, + safe_to_cache: bool = False, ) -> 
t.List[DataObject]: """Lists all data objects in the target schema. Args: schema_name: The name of the schema to list data objects from. object_names: If provided, only return data objects with these names. + safe_to_cache: Whether it is safe to cache the results of this call. Returns: A list of data objects in the target schema. @@ -1733,22 +2321,71 @@ def get_data_objects( if object_names is not None: if not object_names: return [] - object_names_list = list(object_names) - batches = [ - object_names_list[i : i + self.DATA_OBJECT_FILTER_BATCH_SIZE] - for i in range(0, len(object_names_list), self.DATA_OBJECT_FILTER_BATCH_SIZE) - ] - return [ - obj for batch in batches for obj in self._get_data_objects(schema_name, set(batch)) - ] - return self._get_data_objects(schema_name) + + # Check cache for each object name + target_schema = to_schema(schema_name) + cached_objects = [] + missing_names = set() + + for name in object_names: + cache_key = _get_data_object_cache_key( + target_schema.catalog, target_schema.db, name + ) + if cache_key in self._data_object_cache: + logger.debug("Data object cache hit: %s", cache_key) + data_object = self._data_object_cache[cache_key] + # If the object is none, then the table was previously looked for but not found + if data_object: + cached_objects.append(data_object) + else: + logger.debug("Data object cache miss: %s", cache_key) + missing_names.add(name) + + # Fetch missing objects from database + if missing_names: + object_names_list = list(missing_names) + batches = [ + object_names_list[i : i + self.DATA_OBJECT_FILTER_BATCH_SIZE] + for i in range(0, len(object_names_list), self.DATA_OBJECT_FILTER_BATCH_SIZE) + ] + + fetched_objects = [] + fetched_object_names = set() + for batch in batches: + objects = self._get_data_objects(schema_name, set(batch)) + for obj in objects: + if safe_to_cache: + cache_key = _get_data_object_cache_key( + obj.catalog, obj.schema_name, obj.name + ) + self._data_object_cache[cache_key] = obj + 
fetched_objects.append(obj) + fetched_object_names.add(obj.name) + + if safe_to_cache: + for missing_name in missing_names - fetched_object_names: + cache_key = _get_data_object_cache_key( + target_schema.catalog, target_schema.db, missing_name + ) + self._data_object_cache[cache_key] = None + + return cached_objects + fetched_objects + + return cached_objects + + fetched_objects = self._get_data_objects(schema_name) + if safe_to_cache: + for obj in fetched_objects: + cache_key = _get_data_object_cache_key(obj.catalog, obj.schema_name, obj.name) + self._data_object_cache[cache_key] = obj + return fetched_objects def fetchone( self, query: t.Union[exp.Expression, str], ignore_unsupported_errors: bool = False, quote_identifiers: bool = False, - ) -> t.Tuple: + ) -> t.Optional[t.Tuple]: with self.transaction(): self.execute( query, @@ -1779,10 +2416,27 @@ def _fetch_native_df( self.execute(query, quote_identifiers=quote_identifiers) return self.cursor.fetchdf() + def _native_df_to_pandas_df( + self, + query_or_df: QueryOrDF, + ) -> t.Union[Query, pd.DataFrame]: + """ + Take a "native" DataFrame (eg Pyspark, Bigframe, Snowpark etc) and convert it to Pandas + """ + import pandas as pd + + if isinstance(query_or_df, (exp.Query, pd.DataFrame)): + return query_or_df + + # EngineAdapter subclasses that have native DataFrame types should override this + raise NotImplementedError(f"Unable to convert {type(query_or_df)} to Pandas") + def fetchdf( self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False ) -> pd.DataFrame: """Fetches a Pandas DataFrame from the cursor""" + import pandas as pd + df = self._fetch_native_df(query, quote_identifiers=quote_identifiers) if not isinstance(df, pd.DataFrame): raise NotImplementedError( @@ -1796,6 +2450,11 @@ def fetch_pyspark_df( """Fetches a PySpark DataFrame from the cursor""" raise NotImplementedError(f"Engine does not support PySpark DataFrames: {type(self)}") + @property + def wap_enabled(self) -> bool: + 
"""Returns whether WAP is enabled for this engine.""" + return self._extra_config.get("wap_enabled", False) + def wap_supported(self, table_name: TableName) -> bool: """Returns whether WAP for the target table is supported.""" return False @@ -1833,6 +2492,33 @@ def wap_publish(self, table_name: TableName, wap_id: str) -> None: """ raise NotImplementedError(f"Engine does not support WAP: {type(self)}") + def sync_grants_config( + self, + table: exp.Table, + grants_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> None: + """Applies the grants_config to a table authoritatively. + It first compares the specified grants against the current grants, and then + applies the diffs to the table by revoking and granting privileges as needed. + + Args: + table: The table/view to apply grants to. + grants_config: Dictionary mapping privileges to lists of grantees. + table_type: The type of database object (TABLE, VIEW, MATERIALIZED_VIEW). + """ + if not self.SUPPORTS_GRANTS: + raise NotImplementedError(f"Engine does not support grants: {type(self)}") + + current_grants = self._get_current_grants_config(table) + new_grants, revoked_grants = self._diff_grants_configs(grants_config, current_grants) + revoke_exprs = self._revoke_grants_config_expr(table, revoked_grants, table_type) + grant_exprs = self._apply_grants_config_expr(table, new_grants, table_type) + dcl_exprs = revoke_exprs + grant_exprs + + if dcl_exprs: + self.execute(dcl_exprs) + @contextlib.contextmanager def transaction( self, @@ -1892,38 +2578,83 @@ def execute( expressions: t.Union[str, exp.Expression, t.Sequence[exp.Expression]], ignore_unsupported_errors: bool = False, quote_identifiers: bool = True, + track_rows_processed: bool = False, **kwargs: t.Any, ) -> None: """Execute a sql query.""" to_sql_kwargs = ( {"unsupported_level": ErrorLevel.IGNORE} if ignore_unsupported_errors else {} ) - with self.transaction(): for e in ensure_list(expressions): - sql = t.cast( - str, - ( - 
self._to_sql(e, quote=quote_identifiers, **to_sql_kwargs) - if isinstance(e, exp.Expression) - else e - ), + if isinstance(e, exp.Expression): + self._check_identifier_length(e) + sql = self._to_sql(e, quote=quote_identifiers, **to_sql_kwargs) + else: + sql = t.cast(str, e) + + sql = self._attach_correlation_id(sql) + + self._log_sql( + sql, + expression=e if isinstance(e, exp.Expression) else None, + quote_identifiers=quote_identifiers, ) - self._log_sql(sql) - self._execute(sql, **kwargs) + self._execute(sql, track_rows_processed, **kwargs) + + def _attach_correlation_id(self, sql: str) -> str: + if self.ATTACH_CORRELATION_ID and self.correlation_id: + return f"/* {self.correlation_id} */ {sql}" + return sql + + def _log_sql( + self, + sql: str, + expression: t.Optional[exp.Expression] = None, + quote_identifiers: bool = True, + ) -> None: + if not logger.isEnabledFor(self._execute_log_level): + return - def _log_sql(self, sql: str) -> None: - logger.log(self._execute_log_level, "Executing SQL: %s", sql) + sql_to_log = sql + if expression is not None and not isinstance(expression, exp.Query): + values = expression.find(exp.Values) + if values: + values.set("expressions", [exp.to_identifier("")]) + sql_to_log = self._to_sql(expression, quote=quote_identifiers) - def _execute(self, sql: str, **kwargs: t.Any) -> None: + logger.log(self._execute_log_level, "Executing SQL: %s", sql_to_log) + + def _record_execution_stats( + self, sql: str, rowcount: t.Optional[int] = None, bytes_processed: t.Optional[int] = None + ) -> None: + if self._query_execution_tracker: + self._query_execution_tracker.record_execution(sql, rowcount, bytes_processed) + + def _execute(self, sql: str, track_rows_processed: bool = False, **kwargs: t.Any) -> None: self.cursor.execute(sql, **kwargs) + if ( + self.SUPPORTS_QUERY_EXECUTION_TRACKING + and track_rows_processed + and self._query_execution_tracker + and self._query_execution_tracker.is_tracking() + ): + if ( + rowcount := 
getattr(self.cursor, "rowcount", None) + ) is not None and rowcount is not None: + try: + self._record_execution_stats(sql, int(rowcount)) + except (TypeError, ValueError): + return + @contextlib.contextmanager def temp_table( self, query_or_df: QueryOrDF, name: TableName = "diff", - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> t.Iterator[exp.Table]: """A context manager for working a temp table. @@ -1933,13 +2664,21 @@ def temp_table( Args: query_or_df: The query or df to create a temp table for. name: The base name of the temp table. - columns_to_types: A mapping between the column name and its data type. + target_columns_to_types: A mapping between the column name and its data type. Yields: The table expression """ - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types=columns_to_types, target_table=name + name = exp.to_table(name) + # ensure that we use default catalog if none is not specified + if isinstance(name, exp.Table) and not name.catalog and name.db and self.default_catalog: + name.set("catalog", exp.parse_identifier(self.default_catalog)) + + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + query_or_df, + target_columns_to_types=target_columns_to_types, + target_table=name, + source_columns=source_columns, ) with self.transaction(): @@ -1949,10 +2688,11 @@ def temp_table( self._create_table_from_source_queries( table, source_queries, - columns_to_types, + target_columns_to_types, exists=True, table_description=None, column_descriptions=None, + track_rows_processed=False, **kwargs, ) @@ -1972,17 +2712,37 @@ def _table_or_view_properties_to_expressions( for key, value in table_or_view_properties.items() ] + def _build_partitioned_by_exp( + self, + partitioned_by: t.List[exp.Expression], + *, 
+ partition_interval_unit: t.Optional[IntervalUnit] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + catalog_name: t.Optional[str] = None, + **kwargs: t.Any, + ) -> t.Optional[t.Union[exp.PartitionedByProperty, exp.Property]]: + return None + + def _build_clustered_by_exp( + self, + clustered_by: t.List[exp.Expression], + **kwargs: t.Any, + ) -> t.Optional[exp.Cluster]: + return None + def _build_table_properties_exp( self, catalog_name: t.Optional[str] = None, + table_format: t.Optional[str] = None, storage_format: t.Optional[str] = None, partitioned_by: t.Optional[t.List[exp.Expression]] = None, partition_interval_unit: t.Optional[IntervalUnit] = None, - clustered_by: t.Optional[t.List[str]] = None, + clustered_by: t.Optional[t.List[exp.Expression]] = None, table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, + **kwargs: t.Any, ) -> t.Optional[exp.Properties]: """Creates a SQLGlot table properties expression for ddl.""" properties: t.List[exp.Expression] = [] @@ -1994,6 +2754,10 @@ def _build_table_properties_exp( ) ) + if table_properties: + table_type = self._pop_creatable_type_from_properties(table_properties) + properties.extend(ensure_list(table_type)) + if properties: return exp.Properties(expressions=properties) return None @@ -2002,6 +2766,7 @@ def _build_view_properties_exp( self, view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, table_description: t.Optional[str] = None, + **kwargs: t.Any, ) -> t.Optional[exp.Properties]: """Creates a SQLGlot table properties expression for view""" properties: t.List[exp.Expression] = [] @@ -2034,7 +2799,7 @@ def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.An """ sql_gen_kwargs = { "dialect": self.dialect, - 
"pretty": False, + "pretty": self._pretty_sql, "comments": False, **self._sql_gen_kwargs, **kwargs, @@ -2047,6 +2812,17 @@ def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.An return expression.sql(**sql_gen_kwargs, copy=False) # type: ignore + def _clear_data_object_cache(self, table_name: t.Optional[TableName] = None) -> None: + """Clears the cache entry for the given table name, or clears the entire cache if table_name is None.""" + if table_name is None: + logger.debug("Clearing entire data object cache") + self._data_object_cache.clear() + else: + table = exp.to_table(table_name) + cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name) + logger.debug("Clearing data object cache key: %s", cache_key) + self._data_object_cache.pop(cache_key, None) + def _get_data_objects( self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None ) -> t.List[DataObject]: @@ -2075,25 +2851,25 @@ def _get_temp_table( def _order_projections_and_filter( self, query: Query, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], where: t.Optional[exp.Expression] = None, coerce_types: bool = False, ) -> Query: if not isinstance(query, exp.Query) or ( - not where and not coerce_types and query.named_selects == list(columns_to_types) + not where and not coerce_types and query.named_selects == list(target_columns_to_types) ): return query query = t.cast(exp.Query, query.copy()) - with_ = query.args.pop("with", None) + with_ = query.args.pop("with_", None) select_exprs: t.List[exp.Expression] = [ - exp.column(c, quoted=True) for c in columns_to_types + exp.column(c, quoted=True) for c in target_columns_to_types ] - if coerce_types and columns_to_types_all_known(columns_to_types): + if coerce_types and columns_to_types_all_known(target_columns_to_types): select_exprs = [ exp.cast(select_exprs[i], col_tpe).as_(col, quoted=True) - for i, (col, col_tpe) in 
enumerate(columns_to_types.items()) + for i, (col, col_tpe) in enumerate(target_columns_to_types.items()) ] query = exp.select(*select_exprs).from_(query.subquery("_subquery", copy=False), copy=False) @@ -2101,7 +2877,7 @@ def _order_projections_and_filter( query = query.where(where, copy=False) if with_: - query.set("with", with_) + query.set("with_", with_) return query @@ -2109,27 +2885,58 @@ def _truncate_table(self, table_name: TableName) -> None: table = exp.to_table(table_name) self.execute(f"TRUNCATE TABLE {table.sql(dialect=self.dialect, identify=True)}") + def drop_data_object_on_type_mismatch( + self, data_object: t.Optional[DataObject], expected_type: DataObjectType + ) -> bool: + """Drops a data object if it exists and is not of the expected type. + + Args: + data_object: The data object to check. + expected_type: The expected type of the data object. + + Returns: + True if the data object was dropped, False otherwise. + """ + if data_object is None or data_object.type == expected_type: + return False + + logger.warning( + "Target data object '%s' is a %s and not a %s, dropping it", + data_object.to_table().sql(dialect=self.dialect), + data_object.type.value, + expected_type.value, + ) + self.drop_data_object(data_object) + return True + def _replace_by_key( self, target_table: TableName, source_table: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], key: t.Sequence[exp.Expression], is_unique_key: bool, + source_columns: t.Optional[t.List[str]] = None, ) -> None: - if columns_to_types is None: - columns_to_types = self.columns(target_table) + if target_columns_to_types is None: + target_columns_to_types = self.columns(target_table) temp_table = self._get_temp_table(target_table) key_exp = exp.func("CONCAT_WS", "'__SQLMESH_DELIM__'", *key) if len(key) > 1 else key[0] - column_names = list(columns_to_types or []) + column_names = list(target_columns_to_types or []) 
with self.transaction(): - self.ctas(temp_table, source_table, columns_to_types=columns_to_types, exists=False) + self.ctas( + temp_table, + source_table, + target_columns_to_types=target_columns_to_types, + exists=False, + source_columns=source_columns, + ) try: delete_query = exp.select(key_exp).from_(temp_table) - insert_query = self._select_columns(columns_to_types).from_(temp_table) + insert_query = self._select_columns(target_columns_to_types).from_(temp_table) if not is_unique_key: delete_query = delete_query.distinct() else: @@ -2143,12 +2950,12 @@ def _replace_by_key( delete_filter = key_exp.isin(query=delete_query) if not self.INSERT_OVERWRITE_STRATEGY.is_replace_where: - self.execute(exp.delete(target_table).where(delete_filter)) + self.delete_from(target_table, delete_filter) else: insert_statement.set("where", delete_filter) insert_statement.set("this", exp.to_table(target_table)) - self.execute(insert_statement) + self.execute(insert_statement, track_rows_processed=True) finally: self.drop_table(temp_table) @@ -2170,7 +2977,7 @@ def _create_table_comment( self.execute(self._build_create_comment_table_exp(table, table_comment, table_kind)) except Exception: logger.warning( - f"Table comment for '{table.alias_or_name}' not registered - this may be due to limited permissions.", + f"Table comment for '{table.alias_or_name}' not registered - this may be due to limited permissions", exc_info=True, ) @@ -2188,6 +2995,7 @@ def _create_column_comments( table_name: TableName, column_comments: t.Dict[str, str], table_kind: str = "TABLE", + materialized_view: bool = False, ) -> None: table = exp.to_table(table_name) @@ -2196,10 +3004,19 @@ def _create_column_comments( self.execute(self._build_create_comment_column_exp(table, col, comment, table_kind)) except Exception: logger.warning( - f"Column comments for table '{table.alias_or_name}' not registered - this may be due to limited permissions.", + f"Column comments for column '{col}' in table 
'{table.alias_or_name}' not registered - this may be due to limited permissions", exc_info=True, ) + def _create_table_like( + self, + target_table_name: TableName, + source_table_name: TableName, + exists: bool, + **kwargs: t.Any, + ) -> None: + self.create_table(target_table_name, self.columns(source_table_name), exists=exists) + def _rename_table( self, old_table_name: TableName, @@ -2207,6 +3024,18 @@ def _rename_table( ) -> None: self.execute(exp.rename_table(old_table_name, new_table_name)) + def ensure_nulls_for_unmatched_after_join( + self, + query: Query, + ) -> Query: + return query + + def use_server_nulls_for_unmatched_after_join( + self, + query: Query, + ) -> Query: + return query + def ping(self) -> None: try: self._execute(exp.select("1").sql(dialect=self.dialect)) @@ -2214,8 +3043,150 @@ def ping(self) -> None: self._connection_pool.close_cursor() @classmethod - def _select_columns(cls, columns: t.Iterable[str]) -> exp.Select: - return exp.select(*(exp.column(c, quoted=True) for c in columns)) + def _select_columns( + cls, columns: t.Iterable[str], source_columns: t.Optional[t.List[str]] = None + ) -> exp.Select: + return exp.select( + *( + exp.column(c, quoted=True) + if c in (source_columns or columns) + else exp.alias_(exp.Null(), c, quoted=True) + for c in columns + ) + ) + + def _check_identifier_length(self, expression: exp.Expression) -> None: + if self.MAX_IDENTIFIER_LENGTH is None or not isinstance(expression, exp.DDL): + return + + for identifier in expression.find_all(exp.Identifier): + name = identifier.name + name_length = len(name) + if name_length > self.MAX_IDENTIFIER_LENGTH: + raise SQLMeshError( + f"Identifier name '{name}' (length {name_length}) exceeds {self.dialect.capitalize()}'s max identifier limit of {self.MAX_IDENTIFIER_LENGTH} characters" + ) + + def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]: + raise NotImplementedError() + + @classmethod + def _diff_grants_configs( + cls, new_config: 
GrantsConfig, old_config: GrantsConfig + ) -> t.Tuple[GrantsConfig, GrantsConfig]: + """Compute additions and removals between two grants configurations. + + This method compares new (desired) and old (current) GrantsConfigs case-insensitively + for both privilege keys and grantees, while preserving original casing + in the output GrantsConfigs. + + Args: + new_config: Desired grants configuration (specified by the user). + old_config: Current grants configuration (returned by the database). + + Returns: + A tuple of (additions, removals) GrantsConfig where: + - additions contains privileges/grantees present in new_config but not in old_config + - additions uses keys and grantee strings from new_config (user-specified casing) + - removals contains privileges/grantees present in old_config but not in new_config + - removals uses keys and grantee strings from old_config (database-returned casing) + + Notes: + - Comparison is case-insensitive using casefold(); original casing is preserved in results. + - Overlapping grantees (case-insensitive) are excluded from the results. + """ + + def _diffs(config1: GrantsConfig, config2: GrantsConfig) -> GrantsConfig: + diffs: GrantsConfig = {} + cf_config2 = {k.casefold(): {g.casefold() for g in v} for k, v in config2.items()} + for key, grantees in config1.items(): + cf_key = key.casefold() + + # Missing key (add all grantees) + if cf_key not in cf_config2: + diffs[key] = grantees.copy() + continue + + # Include only grantees not in config2 + cf_grantees2 = cf_config2[cf_key] + diff_grantees = [] + for grantee in grantees: + if grantee.casefold() not in cf_grantees2: + diff_grantees.append(grantee) + if diff_grantees: + diffs[key] = diff_grantees + return diffs + + return _diffs(new_config, old_config), _diffs(old_config, new_config) + + def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig: + """Returns current grants for a table as a dictionary. 
+ + This method queries the database and returns the current grants/permissions + for the given table, parsed into a dictionary format. The it handles + case-insensitive comparison between these current grants and the desired + grants from model configuration. + + Args: + table: The table/view to query grants for. + + Returns: + Dictionary mapping permissions to lists of grantees. Permission names + should be returned as the database provides them (typically uppercase + for standard SQL permissions, but engine-specific roles may vary). + + Raises: + NotImplementedError: If the engine does not support grants. + """ + if not self.SUPPORTS_GRANTS: + raise NotImplementedError(f"Engine does not support grants: {type(self)}") + raise NotImplementedError("Subclass must implement get_current_grants") + + def _apply_grants_config_expr( + self, + table: exp.Table, + grants_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + """Returns SQLGlot Grant expressions to apply grants to a table. + + Args: + table: The table/view to grant permissions on. + grants_config: Dictionary mapping permissions to lists of grantees. + table_type: The type of database object (TABLE, VIEW, MATERIALIZED_VIEW). + + Returns: + List of SQLGlot expressions for grant operations. + + Raises: + NotImplementedError: If the engine does not support grants. + """ + if not self.SUPPORTS_GRANTS: + raise NotImplementedError(f"Engine does not support grants: {type(self)}") + raise NotImplementedError("Subclass must implement _apply_grants_config_expr") + + def _revoke_grants_config_expr( + self, + table: exp.Table, + grants_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + """Returns SQLGlot expressions to revoke grants from a table. + + Args: + table: The table/view to revoke permissions from. + grants_config: Dictionary mapping permissions to lists of grantees. 
+ table_type: The type of database object (TABLE, VIEW, MATERIALIZED_VIEW). + + Returns: + List of SQLGlot expressions for revoke operations. + + Raises: + NotImplementedError: If the engine does not support grants. + """ + if not self.SUPPORTS_GRANTS: + raise NotImplementedError(f"Engine does not support grants: {type(self)}") + raise NotImplementedError("Subclass must implement _revoke_grants_config_expr") class EngineAdapterWithIndexSupport(EngineAdapter): @@ -2226,3 +3197,9 @@ def _decoded_str(value: t.Union[str, bytes]) -> str: if isinstance(value, bytes): return value.decode("utf-8") return value + + +def _get_data_object_cache_key(catalog: t.Optional[str], schema_name: str, object_name: str) -> str: + """Returns a cache key for a data object based on its fully qualified name.""" + catalog = f"{catalog}." if catalog else "" + return f"{catalog}{schema_name}.{object_name}" diff --git a/sqlmesh/core/engine_adapter/base_postgres.py b/sqlmesh/core/engine_adapter/base_postgres.py index 7b49982dac..11f56da133 100644 --- a/sqlmesh/core/engine_adapter/base_postgres.py +++ b/sqlmesh/core/engine_adapter/base_postgres.py @@ -1,11 +1,12 @@ from __future__ import annotations import typing as t +import logging from sqlglot import exp from sqlmesh.core.dialect import to_schema -from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.engine_adapter.base import EngineAdapter, _get_data_object_cache_key from sqlmesh.core.engine_adapter.shared import ( CatalogSupport, CommentCreationTable, @@ -20,13 +21,22 @@ from sqlmesh.core.engine_adapter._typing import QueryOrDF +logger = logging.getLogger(__name__) + + class BasePostgresEngineAdapter(EngineAdapter): DEFAULT_BATCH_SIZE = 400 - CATALOG_SUPPORT = CatalogSupport.SINGLE_CATALOG_ONLY COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY + SUPPORTS_QUERY_EXECUTION_TRACKING = True + SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA", 
"TABLE", "VIEW"] + + def columns( + self, table_name: TableName, include_pseudo_columns: bool = False + ) -> t.Dict[str, exp.DataType]: + """Fetches column names and types for the target table.""" + table = exp.to_table(table_name) - def _columns_query(self, table: exp.Table) -> exp.Select: sql = ( exp.select( "attname AS column_name", @@ -45,22 +55,23 @@ def _columns_query(self, table: exp.Table) -> exp.Select: ) if table.args.get("db"): sql = sql.where(exp.column("nspname").eq(table.args["db"].name)) - return sql - def columns( - self, table_name: TableName, include_pseudo_columns: bool = False - ) -> t.Dict[str, exp.DataType]: - """Fetches column names and types for the target table.""" - table = exp.to_table(table_name) - self.execute(self._columns_query(table)) + self.execute(sql) resp = self.cursor.fetchall() if not resp: - raise SQLMeshError("Could not get columns for table '%s'. Table not found.", table_name) + raise SQLMeshError( + f"Could not get columns for table '{table.sql(dialect=self.dialect)}'. Table not found." 
+ ) + return { column_name: exp.DataType.build(data_type, dialect=self.dialect, udt=True) for column_name, data_type in resp } + @property + def catalog_support(self) -> CatalogSupport: + return CatalogSupport.SINGLE_CATALOG_ONLY + def table_exists(self, table_name: TableName) -> bool: """ Postgres doesn't support describe so I'm using what the redshift cursor does to check if a table @@ -69,6 +80,10 @@ def table_exists(self, table_name: TableName) -> bool: Reference: https://github.com/aws/amazon-redshift-python-driver/blob/master/redshift_connector/cursor.py#L528-L553 """ table = exp.to_table(table_name) + data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name) + if data_object_cache_key in self._data_object_cache: + logger.debug("Table existence cache hit: %s", data_object_cache_key) + return self._data_object_cache[data_object_cache_key] is not None sql = ( exp.select("1") @@ -89,12 +104,14 @@ def create_view( self, view_name: TableName, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, replace: bool = True, materialized: bool = False, + materialized_properties: t.Optional[t.Dict[str, t.Any]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + source_columns: t.Optional[t.List[str]] = None, **create_kwargs: t.Any, ) -> None: """ @@ -110,12 +127,14 @@ def create_view( super().create_view( view_name, query_or_df, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, replace=False, materialized=materialized, + materialized_properties=materialized_properties, table_description=table_description, column_descriptions=column_descriptions, view_properties=view_properties, + source_columns=source_columns, **create_kwargs, ) @@ -178,3 +197,10 @@ def _get_data_objects( ) for 
row in df.itertuples() ] + + def _get_current_schema(self) -> str: + """Returns the current default schema for the connection.""" + result = self.fetchone(exp.select(exp.func("current_schema"))) + if result and result[0]: + return result[0] + return "public" diff --git a/sqlmesh/core/engine_adapter/bigquery.py b/sqlmesh/core/engine_adapter/bigquery.py index 9ef8227414..59a56b6ace 100644 --- a/sqlmesh/core/engine_adapter/bigquery.py +++ b/sqlmesh/core/engine_adapter/bigquery.py @@ -2,43 +2,61 @@ import logging import typing as t +from collections import defaultdict -import pandas as pd -from sqlglot import exp +from sqlglot import exp, parse_one from sqlglot.transforms import remove_precision_parameterized_types from sqlmesh.core.dialect import to_schema -from sqlmesh.core.engine_adapter.mixins import InsertOverwriteWithMergeMixin +from sqlmesh.core.engine_adapter.base import _get_data_object_cache_key +from sqlmesh.core.engine_adapter.mixins import ( + ClusteredByMixin, + GrantsFromInfoSchemaMixin, + RowDiffMixin, + TableAlterClusterByOperation, +) from sqlmesh.core.engine_adapter.shared import ( CatalogSupport, DataObject, DataObjectType, SourceQuery, set_catalog, + InsertOverwriteStrategy, ) from sqlmesh.core.node import IntervalUnit -from sqlmesh.core.schema_diff import SchemaDiffer +from sqlmesh.core.schema_diff import TableAlterOperation, NestedSupport +from sqlmesh.utils import optional_import, get_source_columns_to_types from sqlmesh.utils.date import to_datetime from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.utils.pandas import columns_to_types_from_dtypes if t.TYPE_CHECKING: + import pandas as pd from google.api_core.retry import Retry from google.cloud import bigquery from google.cloud.bigquery import StandardSqlDataType from google.cloud.bigquery.client import Client as BigQueryClient - from google.cloud.bigquery.client import Connection as BigQueryConnection + from google.cloud.bigquery.job import QueryJob from 
google.cloud.bigquery.job.base import _AsyncJob as BigQueryQueryResult from google.cloud.bigquery.table import Table as BigQueryTable from sqlmesh.core._typing import SchemaName, SessionProperties, TableName - from sqlmesh.core.engine_adapter._typing import DF, Query + from sqlmesh.core.engine_adapter._typing import BigframeSession, DCL, DF, GrantsConfig, Query from sqlmesh.core.engine_adapter.base import QueryOrDF + logger = logging.getLogger(__name__) +bigframes = optional_import("bigframes") +bigframes_pd = optional_import("bigframes.pandas") + + +NestedField = t.Tuple[str, str, t.List[str]] +NestedFieldsDict = t.Dict[str, t.List[NestedField]] + @set_catalog() -class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin): +class BigQueryEngineAdapter(ClusteredByMixin, RowDiffMixin, GrantsFromInfoSchemaMixin): """ BigQuery Engine Adapter using the `google-cloud-bigquery` library's DB API. """ @@ -48,14 +66,19 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin): SUPPORTS_TRANSACTIONS = False SUPPORTS_MATERIALIZED_VIEWS = True SUPPORTS_CLONING = True - CATALOG_SUPPORT = CatalogSupport.FULL_SUPPORT + SUPPORTS_GRANTS = True + CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("session_user") + SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True + USE_CATALOG_IN_GRANTS = True + GRANT_INFORMATION_SCHEMA_TABLE_NAME = "OBJECT_PRIVILEGES" MAX_TABLE_COMMENT_LENGTH = 1024 MAX_COLUMN_COMMENT_LENGTH = 1024 + SUPPORTS_QUERY_EXECUTION_TRACKING = True + SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"] + INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE - # SQL is not supported for adding columns to structs: https://cloud.google.com/bigquery/docs/managing-table-schemas#api_1 - # Can explore doing this with the API in the future - SCHEMA_DIFFER = SchemaDiffer( - compatible_types={ + SCHEMA_DIFFER_KWARGS = { + "compatible_types": { exp.DataType.build("INT64", dialect=DIALECT): { exp.DataType.build("NUMERIC", dialect=DIALECT), exp.DataType.build("FLOAT64", 
dialect=DIALECT), @@ -69,12 +92,17 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin): exp.DataType.build("DATETIME", dialect=DIALECT), }, }, - support_coercing_compatible_types=True, - parameterized_type_defaults={ + "coerceable_types": { + exp.DataType.build("FLOAT64", dialect=DIALECT): { + exp.DataType.build("BIGNUMERIC", dialect=DIALECT), + }, + }, + "support_coercing_compatible_types": True, + "parameterized_type_defaults": { exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(38, 9), (0,)], exp.DataType.build("BIGDECIMAL", dialect=DIALECT).this: [(76.76, 38), (0,)], }, - types_with_unlimited_length={ + "types_with_unlimited_length": { # parameterized `STRING(n)` can ALTER to unparameterized `STRING` exp.DataType.build("STRING", dialect=DIALECT).this: { exp.DataType.build("STRING", dialect=DIALECT).this, @@ -84,15 +112,23 @@ class BigQueryEngineAdapter(InsertOverwriteWithMergeMixin): exp.DataType.build("BYTES", dialect=DIALECT).this, }, }, - ) + "nested_support": NestedSupport.ALL_BUT_DROP, + } @property def client(self) -> BigQueryClient: return self.connection._client @property - def connection(self) -> BigQueryConnection: - return self.cursor.connection + def bigframe(self) -> t.Optional[BigframeSession]: + if bigframes: + options = bigframes.BigQueryOptions( + credentials=self.client._credentials, + project=self.client.project, + location=self.client.location, + ) + return bigframes.connect(context=options) + return None @property def _job_params(self) -> t.Dict[str, t.Any]: @@ -106,17 +142,32 @@ def _job_params(self) -> t.Dict[str, t.Any]: } if self._extra_config.get("maximum_bytes_billed"): params["maximum_bytes_billed"] = self._extra_config.get("maximum_bytes_billed") + if self.correlation_id: + # BigQuery label keys must be lowercase + key = self.correlation_id.job_type.value.lower() + params["labels"] = {key: self.correlation_id.job_id} return params + @property + def catalog_support(self) -> CatalogSupport: + return 
CatalogSupport.FULL_SUPPORT + def _df_to_source_queries( self, df: DF, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, + source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: + import pandas as pd + + source_columns_to_types = get_source_columns_to_types( + target_columns_to_types, source_columns + ) + temp_bq_table = self.__get_temp_bq_table( - self._get_temp_table(target_table or "pandas"), columns_to_types + self._get_temp_table(target_table or "pandas"), source_columns_to_types ) temp_table = exp.table_( temp_bq_table.table_id, @@ -125,16 +176,24 @@ def _df_to_source_queries( ) def query_factory() -> Query: - if not self.table_exists(temp_table): + ordered_df = df[list(source_columns_to_types)] + if bigframes_pd and isinstance(ordered_df, bigframes_pd.DataFrame): + ordered_df.to_gbq( + f"{temp_bq_table.project}.{temp_bq_table.dataset_id}.{temp_bq_table.table_id}", + if_exists="replace", + ) + elif not self.table_exists(temp_table): # Make mypy happy - assert isinstance(df, pd.DataFrame) + assert isinstance(ordered_df, pd.DataFrame) self._db_call(self.client.create_table, table=temp_bq_table, exists_ok=False) result = self.__load_pandas_to_table( - temp_bq_table, df, columns_to_types, replace=False + temp_bq_table, ordered_df, source_columns_to_types, replace=False ) if result.errors: raise SQLMeshError(result.errors) - return self._select_columns(columns_to_types).from_(temp_table) + return exp.select( + *self._casted_columns(target_columns_to_types, source_columns=source_columns) + ).from_(temp_table) return [ SourceQuery( @@ -143,10 +202,68 @@ def query_factory() -> Query: ) ] + def close(self) -> t.Any: + # Cancel all pending query jobs across all threads + all_query_jobs = self._connection_pool.get_all_attributes("query_job") + for query_job in all_query_jobs: + if query_job: + try: + if not self._db_call(query_job.done): + 
self._db_call(query_job.cancel) + logger.debug( + "Cancelled BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s", + query_job.project, + query_job.location, + query_job.job_id, + ) + except Exception as ex: + logger.debug( + "Failed to cancel BigQuery job: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s. %s", + query_job.project, + query_job.location, + query_job.job_id, + str(ex), + ) + + return super().close() + def _begin_session(self, properties: SessionProperties) -> None: from google.cloud.bigquery import QueryJobConfig - job = self.client.query("SELECT 1;", job_config=QueryJobConfig(create_session=True)) + query_label_property = properties.get("query_label") + parsed_query_label: list[tuple[str, str]] = [] + if isinstance(query_label_property, (exp.Array, exp.Paren, exp.Tuple)): + label_tuples = ( + [query_label_property.unnest()] + if isinstance(query_label_property, exp.Paren) + else query_label_property.expressions + ) + + # query_label is a Paren, Array or Tuple of 2-tuples and validated at load time + parsed_query_label.extend( + (label_tuple.expressions[0].name, label_tuple.expressions[1].name) + for label_tuple in label_tuples + ) + elif query_label_property is not None: + raise SQLMeshError( + "Invalid value for `session_properties.query_label`. Must be an array or tuple." 
+ ) + + if self.correlation_id: + parsed_query_label.append( + (self.correlation_id.job_type.value.lower(), self.correlation_id.job_id) + ) + + if parsed_query_label: + query_label_str = ",".join([":".join(label) for label in parsed_query_label]) + query = f'SET @@query_label = "{query_label_str}";SELECT 1;' + else: + query = "SELECT 1;" + + job = self.client.query( + query, + job_config=QueryJobConfig(create_session=True), + ) session_info = job.session_info session_id = session_info.session_id if session_info else None self._session_id = session_id @@ -171,6 +288,7 @@ def create_schema( schema_name: SchemaName, ignore_if_exists: bool = True, warn_on_error: bool = True, + properties: t.List[exp.Expression] = [], ) -> None: """Create a schema from a name or qualified table name.""" from google.api_core.exceptions import Conflict @@ -189,13 +307,25 @@ def create_schema( raise logger.warning("Failed to create schema '%s': %s", schema_name, e) + def get_bq_schema(self, table_name: TableName) -> t.List[bigquery.SchemaField]: + table = exp.to_table(table_name) + if len(table.parts) == 3 and "." in table.name: + self.execute(exp.select("*").from_(table).limit(0)) + query_job = self._query_job + assert query_job is not None + return query_job._query_results.schema + return self._get_table(table).schema + def columns( self, table_name: TableName, include_pseudo_columns: bool = False ) -> t.Dict[str, exp.DataType]: """Fetches column names and types for the target table.""" - def dtype_to_sql(dtype: t.Optional[StandardSqlDataType]) -> str: + def dtype_to_sql( + dtype: t.Optional[StandardSqlDataType], field: bigquery.SchemaField + ) -> str: assert dtype + assert field kind = dtype.type_kind assert kind @@ -203,37 +333,110 @@ def dtype_to_sql(dtype: t.Optional[StandardSqlDataType]) -> str: # Not using the enum value to preserve compatibility with older versions # of the BigQuery library. 
if kind.name == "ARRAY": - return f"ARRAY<{dtype_to_sql(dtype.array_element_type)}>" + return f"ARRAY<{dtype_to_sql(dtype.array_element_type, field)}>" if kind.name == "STRUCT": struct_type = dtype.struct_type assert struct_type fields = ", ".join( - f"{field.name} {dtype_to_sql(field.type)}" for field in struct_type.fields + f"{struct_field.name} {dtype_to_sql(struct_field.type, nested_field)}" + for struct_field, nested_field in zip(struct_type.fields, field.fields) ) return f"STRUCT<{fields}>" if kind.name == "TYPE_KIND_UNSPECIFIED": - return "JSON" + field_type = field.field_type + + if field_type == "RANGE": + # If the field is a RANGE then `range_element_type` should be set to + # one of `"DATE"`, `"DATETIME"` or `"TIMESTAMP"`. + return f"RANGE<{field.range_element_type.element_type}>" + + return field_type + return kind.name - table = self._get_table(table_name) - columns = { - field.name: exp.DataType.build( - dtype_to_sql(field.to_standard_sql().type), dialect=self.dialect - ) - for field in table.schema - } - if include_pseudo_columns and table.time_partitioning and not table.time_partitioning.field: - columns["_PARTITIONTIME"] = exp.DataType.build("TIMESTAMP") - if table.time_partitioning.type_ == "DAY": - columns["_PARTITIONDATE"] = exp.DataType.build("DATE") + def create_mapping_schema( + schema: t.Sequence[bigquery.SchemaField], + ) -> t.Dict[str, exp.DataType]: + return { + field.name: exp.DataType.build( + dtype_to_sql(field.to_standard_sql().type, field), dialect=self.dialect + ) + for field in schema + } + + table = exp.to_table(table_name) + if len(table.parts) == 3 and "." 
in table.name: + # The client's `get_table` method can't handle paths with >3 identifiers + self.execute(exp.select("*").from_(table).limit(0)) + query_job = self._query_job + assert query_job is not None + + query_results = query_job._query_results + columns = create_mapping_schema(query_results.schema) + else: + bq_table = self._get_table(table) + columns = create_mapping_schema(bq_table.schema) + + if include_pseudo_columns: + if bq_table.time_partitioning and not bq_table.time_partitioning.field: + columns["_PARTITIONTIME"] = exp.DataType.build("TIMESTAMP", dialect="bigquery") + if bq_table.time_partitioning.type_ == "DAY": + columns["_PARTITIONDATE"] = exp.DataType.build("DATE") + if bq_table.table_id.endswith("*"): + columns["_TABLE_SUFFIX"] = exp.DataType.build("STRING", dialect="bigquery") + if ( + bq_table.external_data_configuration is not None + and bq_table.external_data_configuration.source_format + in ( + "CSV", + "NEWLINE_DELIMITED_JSON", + "AVRO", + "PARQUET", + "ORC", + "DATASTORE_BACKUP", + ) + ): + columns["_FILE_NAME"] = exp.DataType.build("STRING", dialect="bigquery") + return columns + def alter_table( + self, + alter_expressions: t.Union[t.List[exp.Alter], t.List[TableAlterOperation]], + ) -> None: + """ + Performs the alter statements to change the current table into the structure of the target table, + and uses the API to add columns to structs, where SQL is not supported. 
+ """ + if not alter_expressions: + return + + cluster_by_operations, alter_statements = [], [] + for e in alter_expressions: + if isinstance(e, TableAlterClusterByOperation): + cluster_by_operations.append(e) + elif isinstance(e, TableAlterOperation): + alter_statements.append(e.expression) + else: + alter_statements.append(e) + + for op in cluster_by_operations: + self._update_clustering_key(op) + + nested_fields, non_nested_expressions = self._split_alter_expressions(alter_statements) + + if nested_fields: + self._update_table_schema_nested_fields(nested_fields, alter_statements[0].this) + + if non_nested_expressions: + super().alter_table(non_nested_expressions) + def fetchone( self, query: t.Union[exp.Expression, str], ignore_unsupported_errors: bool = False, quote_identifiers: bool = False, - ) -> t.Tuple: + ) -> t.Optional[t.Tuple]: """ BigQuery's `fetchone` method doesn't call execute and therefore would not benefit from the execute configuration we have in place. Therefore this implementation calls execute instead. @@ -246,7 +449,7 @@ def fetchone( try: return next(self._query_data) except StopIteration: - return () + return None def fetchall( self, @@ -265,6 +468,104 @@ def fetchall( ) return list(self._query_data) + def _split_alter_expressions( + self, + alter_expressions: t.List[exp.Alter], + ) -> t.Tuple[NestedFieldsDict, t.List[exp.Alter]]: + """ + Returns a dictionary of the nested fields to add and a list of the non-nested alter expressions. 
+ """ + nested_fields_to_add: NestedFieldsDict = defaultdict(list) + non_nested_expressions = [] + + for alter_expression in alter_expressions: + action = alter_expression.args["actions"][0] + if ( + isinstance(action, exp.ColumnDef) + and isinstance(action.this, exp.Dot) + and isinstance(action.kind, exp.DataType) + ): + root_field, *leaf_fields = action.this.this.sql(dialect=self.dialect).split(".") + new_field = action.this.expression.sql(dialect=self.dialect) + data_type = action.kind.sql(dialect=self.dialect) + nested_fields_to_add[root_field].append((new_field, data_type, leaf_fields)) + else: + non_nested_expressions.append(alter_expression) + + return nested_fields_to_add, non_nested_expressions + + def _build_nested_fields( + self, + current_fields: t.List[bigquery.SchemaField], + fields_to_add: t.List[NestedField], + ) -> t.List[bigquery.SchemaField]: + """ + Recursively builds and updates the schema fields with the new nested fields. + """ + from google.cloud import bigquery + + new_fields = [] + root: t.List[t.Tuple[str, str]] = [] + leaves: NestedFieldsDict = defaultdict(list) + for new_field, data_type, leaf_fields in fields_to_add: + if leaf_fields: + leaves[leaf_fields[0]].append((new_field, data_type, leaf_fields[1:])) + else: + root.append((new_field, data_type)) + + for field in current_fields: + # If the new fields are nested, we need to recursively build them + if field.name in leaves: + subfields = list(field.fields) + subfields = self._build_nested_fields(subfields, leaves[field.name]) + new_fields.append( + bigquery.SchemaField( + field.name, "RECORD", mode=field.mode, fields=tuple(subfields) + ) + ) + else: + new_fields.append(field) + + # Build and append the new root-level fields + new_fields.extend( + self.__get_bq_schemafield( + new_field[0], exp.DataType.build(new_field[1], dialect=self.dialect) + ) + for new_field in root + ) + return new_fields + + def _update_table_schema_nested_fields( + self, nested_fields_to_add: 
NestedFieldsDict, table_name: str + ) -> None: + """ + Updates a BigQuery table schema by adding the new nested fields provided. + """ + from google.cloud import bigquery + + table = self._get_table(table_name) + original_schema = table.schema + new_schema = [] + for field in original_schema: + if field.name in nested_fields_to_add: + fields = self._build_nested_fields( + list(field.fields), nested_fields_to_add[field.name] + ) + new_schema.append( + bigquery.SchemaField( + field.name, + "RECORD", + mode=field.mode, + fields=tuple(fields), + ) + ) + else: + new_schema.append(field) + + if new_schema != original_schema: + table.schema = new_schema + self.client.update_table(table, ["schema"]) + def __load_pandas_to_table( self, table: bigquery.Table, @@ -298,20 +599,50 @@ def __db_load_table_from_dataframe( ) return self._db_call(job.result) + def __get_bq_schemafield(self, name: str, tpe: exp.DataType) -> bigquery.SchemaField: + from google.cloud import bigquery + + mode = "NULLABLE" + if tpe.is_type(exp.DataType.Type.ARRAY): + mode = "REPEATED" + tpe = tpe.expressions[0] + + field_type = tpe.sql(dialect=self.dialect) + fields = [] + if tpe.is_type(*exp.DataType.NESTED_TYPES): + field_type = "RECORD" + for inner_field in tpe.expressions: + if isinstance(inner_field, exp.ColumnDef): + inner_name = inner_field.this.sql(dialect=self.dialect) + inner_type = inner_field.kind + if inner_type is None: + raise ValueError( + f"cannot convert unknown type to BQ schema field {inner_field}" + ) + fields.append(self.__get_bq_schemafield(name=inner_name, tpe=inner_type)) + else: + raise ValueError(f"unexpected nested expression {inner_field}") + + return bigquery.SchemaField( + name=name, + field_type=field_type, + mode=mode, + fields=fields, + ) + def __get_bq_schema( self, columns_to_types: t.Dict[str, exp.DataType] ) -> t.List[bigquery.SchemaField]: """ Returns a bigquery schema object from a dictionary of column names to types. 
""" - from google.cloud import bigquery precisionless_col_to_types = { col_name: remove_precision_parameterized_types(col_type) for col_name, col_type in columns_to_types.items() } return [ - bigquery.SchemaField(col_name, col_type.sql(dialect=self.dialect)) + self.__get_bq_schemafield(name=col_name, tpe=t.cast(exp.DataType, col_type)) for col_name, col_type in precisionless_col_to_types.items() ] @@ -336,7 +667,7 @@ def __get_bq_table( table_ = exp.to_table(table).copy() if not table_.catalog: - table_.set("catalog", exp.to_identifier(self.client.project)) + table_.set("catalog", exp.to_identifier(self.default_catalog)) return bigquery.Table( table_ref=self._table_name(table_), @@ -359,7 +690,8 @@ def insert_overwrite_by_partition( table_name: TableName, query_or_df: QueryOrDF, partitioned_by: t.List[exp.Expression], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, ) -> None: if len(partitioned_by) != 1: raise SQLMeshError( @@ -378,15 +710,23 @@ def insert_overwrite_by_partition( raise SQLMeshError( f"The partition expression '{partition_sql}' doesn't contain a column." 
) - with self.session({}), self.temp_table( - query_or_df, name=table_name, partitioned_by=partitioned_by - ) as temp_table_name: - if columns_to_types is None or columns_to_types[ + with ( + self.session({}), + self.temp_table( + query_or_df, + name=table_name, + partitioned_by=partitioned_by, + source_columns=source_columns, + ) as temp_table_name, + ): + if target_columns_to_types is None or target_columns_to_types[ partition_column.name ] == exp.DataType.build("unknown"): - columns_to_types = self.columns(temp_table_name) + target_columns_to_types = self.columns(table_name) - partition_type_sql = columns_to_types[partition_column.name].sql(dialect=self.dialect) + partition_type_sql = target_columns_to_types[partition_column.name].sql( + dialect=self.dialect + ) select_array_agg_partitions = select_partitions_expr( temp_table_name.db, @@ -394,6 +734,7 @@ def insert_overwrite_by_partition( partition_type_sql, granularity=granularity, agg_func="ARRAY_AGG", + catalog=temp_table_name.catalog or self.default_catalog, ) self.execute( @@ -405,11 +746,17 @@ def insert_overwrite_by_partition( self._insert_overwrite_by_condition( table_name, [SourceQuery(query_factory=lambda: exp.select("*").from_(temp_table_name))], - columns_to_types, + target_columns_to_types, where=where, ) def table_exists(self, table_name: TableName) -> bool: + table = exp.to_table(table_name) + data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name) + if data_object_cache_key in self._data_object_cache: + logger.debug("Table existence cache hit: %s", data_object_cache_key) + return self._data_object_cache[data_object_cache_key] is not None + try: from google.cloud.exceptions import NotFound except ModuleNotFoundError: @@ -421,6 +768,28 @@ def table_exists(self, table_name: TableName) -> bool: except NotFound: return False + def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]: + from sqlmesh.utils.date import to_timestamp + + 
datasets_to_tables: t.DefaultDict[str, t.List[str]] = defaultdict(list) + for table_name in table_names: + table = exp.to_table(table_name) + datasets_to_tables[table.db].append(table.name) + + results = [] + + for dataset, tables in datasets_to_tables.items(): + query = ( + f"SELECT TIMESTAMP_MILLIS(last_modified_time) FROM `{dataset}.__TABLES__` WHERE " + ) + for i, table_name in enumerate(tables): + query += f"TABLE_ID = '{table_name}'" + if i < len(tables) - 1: + query += " OR " + results.extend(self.fetchall(query)) + + return [to_timestamp(row[0]) for row in results] + def _get_table(self, table_name: TableName) -> BigQueryTable: """ Returns a BigQueryTable object for the given table name. @@ -437,38 +806,50 @@ def _fetch_native_df( self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False ) -> DF: self.execute(query, quote_identifiers=quote_identifiers) - return self._query_job.to_dataframe() + query_job = self._query_job + assert query_job is not None + return query_job.to_dataframe() def _create_column_comments( self, table_name: TableName, column_comments: t.Dict[str, str], table_kind: str = "TABLE", + materialized_view: bool = False, ) -> None: - table = self._get_table(table_name) - - # convert Table object to dict - table_def = table.to_api_repr() - - # set the column descriptions - for i in range(len(table_def["schema"]["fields"])): - comment = column_comments.get(table_def["schema"]["fields"][i]["name"], None) - if comment: - table_def["schema"]["fields"][i]["description"] = self._truncate_comment( - comment, self.MAX_COLUMN_COMMENT_LENGTH - ) - - # An "etag" is BQ versioning metadata that changes when an object is updated/modified. `update_table` - # compares the etags of the table object passed to it and the remote table, erroring if the etags - # don't match. We set the local etag to None to avoid this check. 
- table_def["etag"] = None - - # convert dict back to a Table object - table = table.from_api_repr(table_def) - - # update table schema - logger.info(f"Registering column comments for table {table_name}") - self._db_call(self.client.update_table, table=table, fields=["schema"]) + if not (table_kind == "VIEW" and materialized_view): + table = self._get_table(table_name) + + # convert Table object to dict + table_def = table.to_api_repr() + + # Set column descriptions, supporting nested fields (e.g. record.field.nested_field) + for column, comment in column_comments.items(): + fields = table_def["schema"]["fields"] + field_names = column.split(".") + last_index = len(field_names) - 1 + + # Traverse the fields with nested fields down to leaf level + for idx, name in enumerate(field_names): + if field := next((field for field in fields if field["name"] == name), None): + if idx == last_index: + field["description"] = self._truncate_comment( + comment, self.MAX_COLUMN_COMMENT_LENGTH + ) + else: + fields = field.get("fields") or [] + + # An "etag" is BQ versioning metadata that changes when an object is updated/modified. `update_table` + # compares the etags of the table object passed to it and the remote table, erroring if the etags + # don't match. We set the local etag to None to avoid this check. 
+ table_def["etag"] = None + + # convert dict back to a Table object + table = table.from_api_repr(table_def) + + # update table schema + logger.info(f"Registering column comments for table {table_name}") + self._db_call(self.client.update_table, table=table, fields=["schema"]) def _build_description_property_exp( self, @@ -480,59 +861,76 @@ def _build_description_property_exp( value=exp.Literal.string(trunc_method(description)), ) + def _build_partitioned_by_exp( + self, + partitioned_by: t.List[exp.Expression], + *, + partition_interval_unit: t.Optional[IntervalUnit] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + **kwargs: t.Any, + ) -> t.Optional[exp.PartitionedByProperty]: + if len(partitioned_by) > 1: + raise SQLMeshError("BigQuery only supports partitioning by a single column") + + this = partitioned_by[0] + if ( + isinstance(this, exp.Column) + and partition_interval_unit is not None + and not partition_interval_unit.is_minute + ): + column_type: t.Optional[exp.DataType] = (target_columns_to_types or {}).get(this.name) + + if column_type == exp.DataType.build( + "date", dialect=self.dialect + ) and partition_interval_unit in ( + IntervalUnit.MONTH, + IntervalUnit.YEAR, + ): + trunc_func = "DATE_TRUNC" + elif column_type == exp.DataType.build("timestamp", dialect=self.dialect): + trunc_func = "TIMESTAMP_TRUNC" + elif column_type == exp.DataType.build("datetime", dialect=self.dialect): + trunc_func = "DATETIME_TRUNC" + else: + trunc_func = "" + + if trunc_func: + this = exp.func( + trunc_func, + this, + exp.var(partition_interval_unit.value.upper()), + dialect=self.dialect, + ) + + return exp.PartitionedByProperty(this=this) + def _build_table_properties_exp( self, catalog_name: t.Optional[str] = None, + table_format: t.Optional[str] = None, storage_format: t.Optional[str] = None, partitioned_by: t.Optional[t.List[exp.Expression]] = None, partition_interval_unit: t.Optional[IntervalUnit] = None, - clustered_by: 
t.Optional[t.List[str]] = None, + clustered_by: t.Optional[t.List[exp.Expression]] = None, table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, + **kwargs: t.Any, ) -> t.Optional[exp.Properties]: properties: t.List[exp.Expression] = [] - if partitioned_by: - if len(partitioned_by) > 1: - raise SQLMeshError("BigQuery only supports partitioning by a single column") - - this = partitioned_by[0] - - if ( - isinstance(this, exp.Column) - and partition_interval_unit is not None - and not partition_interval_unit.is_minute - ): - column_type: t.Optional[exp.DataType] = (columns_to_types or {}).get(this.name) - - if column_type == exp.DataType.build( - "date", dialect=self.dialect - ) and partition_interval_unit in ( - IntervalUnit.MONTH, - IntervalUnit.YEAR, - ): - trunc_func = "DATE_TRUNC" - elif column_type == exp.DataType.build("timestamp", dialect=self.dialect): - trunc_func = "TIMESTAMP_TRUNC" - elif column_type == exp.DataType.build("datetime", dialect=self.dialect): - trunc_func = "DATETIME_TRUNC" - else: - trunc_func = "" - - if trunc_func: - this = exp.func( - trunc_func, - this, - exp.var(partition_interval_unit.value.upper()), - dialect=self.dialect, - ) - - properties.append(exp.PartitionedByProperty(this=this)) + if partitioned_by and ( + partitioned_by_prop := self._build_partitioned_by_exp( + partitioned_by, + partition_interval_unit=partition_interval_unit, + target_columns_to_types=target_columns_to_types, + ) + ): + properties.append(partitioned_by_prop) - if clustered_by: - properties.append(exp.Cluster(expressions=[exp.column(col) for col in clustered_by])) + if clustered_by and (clustered_by_exp := self._build_clustered_by_exp(clustered_by)): + properties.append(clustered_by_exp) if table_description: properties.append( @@ 
-547,6 +945,66 @@ def _build_table_properties_exp( return exp.Properties(expressions=properties) return None + def _build_column_def( + self, + col_name: str, + column_descriptions: t.Optional[t.Dict[str, str]] = None, + engine_supports_schema_comments: bool = False, + col_type: t.Optional[exp.DATA_TYPE] = None, + nested_names: t.List[str] = [], + ) -> exp.ColumnDef: + # Helper function to build column definitions with column descriptions + def _build_struct_with_descriptions( + col_type: exp.DataType, + nested_names: t.List[str], + ) -> exp.DataType: + column_expressions = [] + for column_def in col_type.expressions: + # This is expected to be true, but this check is included as a + # precautionary measure in case of an unexpected edge case + if isinstance(column_def, exp.ColumnDef): + column = self._build_column_def( + col_name=column_def.name, + column_descriptions=column_descriptions, + engine_supports_schema_comments=engine_supports_schema_comments, + col_type=column_def.kind, + nested_names=nested_names, + ) + else: + column = column_def + column_expressions.append(column) + return exp.DataType(this=col_type.this, expressions=column_expressions, nested=True) + + # Recursively build column definitions for BigQuery's RECORDs (struct) and REPEATED RECORDs (array of struct) + if isinstance(col_type, exp.DataType) and col_type.expressions: + expressions = col_type.expressions + if col_type.is_type(exp.DataType.Type.STRUCT): + col_type = _build_struct_with_descriptions(col_type, nested_names + [col_name]) + elif col_type.is_type(exp.DataType.Type.ARRAY) and expressions[0].is_type( + exp.DataType.Type.STRUCT + ): + col_type = exp.DataType( + this=exp.DataType.Type.ARRAY, + expressions=[ + _build_struct_with_descriptions( + col_type.expressions[0], nested_names + [col_name] + ) + ], + nested=True, + ) + + return exp.ColumnDef( + this=exp.to_identifier(col_name), + kind=col_type, + constraints=( + self._build_col_comment_exp( + ".".join(nested_names + [col_name]), 
column_descriptions + ) + if engine_supports_schema_comments and self.comments_enabled and column_descriptions + else None + ), + ) + def _build_col_comment_exp( self, col_name: str, column_descriptions: t.Dict[str, str] ) -> t.List[exp.ColumnConstraint]: @@ -569,6 +1027,7 @@ def _build_view_properties_exp( self, view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, table_description: t.Optional[str] = None, + **kwargs: t.Any, ) -> t.Optional[exp.Properties]: """Creates a SQLGlot table properties expression for view""" properties: t.List[exp.Expression] = [] @@ -610,12 +1069,12 @@ def _build_create_comment_column_exp( def create_state_table( self, table_name: str, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], primary_key: t.Optional[t.Tuple[str, ...]] = None, ) -> None: self.create_table( table_name, - columns_to_types, + target_columns_to_types, ) def _db_call(self, func: t.Callable[..., t.Any], *args: t.Any, **kwargs: t.Any) -> t.Any: @@ -628,6 +1087,7 @@ def _db_call(self, func: t.Callable[..., t.Any], *args: t.Any, **kwargs: t.Any) def _execute( self, sql: str, + track_rows_processed: bool = False, **kwargs: t.Any, ) -> None: """Execute a sql query.""" @@ -653,23 +1113,43 @@ def _execute( job_config=job_config, timeout=self._extra_config.get("job_creation_timeout_seconds"), ) + query_job = self._query_job + assert query_job is not None logger.debug( "BigQuery job created: https://console.cloud.google.com/bigquery?project=%s&j=bq:%s:%s", - self._query_job.project, - self._query_job.location, - self._query_job.job_id, + query_job.project, + query_job.location, + query_job.job_id, ) results = self._db_call( - self._query_job.result, + query_job.result, timeout=self._extra_config.get("job_execution_timeout_seconds"), # type: ignore ) + self._query_data = iter(results) if results.total_rows else iter([]) - query_results = self._query_job._query_results + query_results = query_job._query_results 
self.cursor._set_rowcount(query_results) self.cursor._set_description(query_results.schema) + if ( + track_rows_processed + and self._query_execution_tracker + and self._query_execution_tracker.is_tracking() + ): + num_rows = None + if query_job.statement_type == "CREATE_TABLE_AS_SELECT": + # since table was just created, number rows in table == number rows processed + query_table = self.client.get_table(query_job.destination) + num_rows = query_table.num_rows + elif query_job.statement_type in ["INSERT", "DELETE", "MERGE", "UPDATE"]: + num_rows = query_job.num_dml_affected_rows + + self._query_execution_tracker.record_execution( + sql, num_rows, query_job.total_bytes_processed + ) + def _get_data_objects( self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None ) -> t.List[DataObject]: @@ -680,26 +1160,56 @@ def _get_data_objects( # The BigQuery Client's list_tables method does not support filtering by table name, so we have to # resort to using SQL instead. schema = to_schema(schema_name) - catalog = schema.catalog or self.get_current_catalog() - query = exp.select( - exp.column("table_catalog").as_("catalog"), - exp.column("table_name").as_("name"), - exp.column("table_schema").as_("schema_name"), - exp.case() - .when(exp.column("table_type").eq("BASE TABLE"), exp.Literal.string("TABLE")) - .when(exp.column("table_type").eq("CLONE"), exp.Literal.string("TABLE")) - .when(exp.column("table_type").eq("EXTERNAL"), exp.Literal.string("TABLE")) - .when(exp.column("table_type").eq("SNAPSHOT"), exp.Literal.string("TABLE")) - .when(exp.column("table_type").eq("VIEW"), exp.Literal.string("VIEW")) - .when( - exp.column("table_type").eq("MATERIALIZED VIEW"), - exp.Literal.string("MATERIALIZED_VIEW"), + catalog = schema.catalog or self.default_catalog + query = ( + exp.select( + exp.column("table_catalog").as_("catalog"), + exp.column("table_name").as_("name"), + exp.column("table_schema").as_("schema_name"), + exp.case() + 
.when(exp.column("table_type").eq("BASE TABLE"), exp.Literal.string("TABLE")) + .when(exp.column("table_type").eq("CLONE"), exp.Literal.string("TABLE")) + .when(exp.column("table_type").eq("EXTERNAL"), exp.Literal.string("TABLE")) + .when(exp.column("table_type").eq("SNAPSHOT"), exp.Literal.string("TABLE")) + .when(exp.column("table_type").eq("VIEW"), exp.Literal.string("VIEW")) + .when( + exp.column("table_type").eq("MATERIALIZED VIEW"), + exp.Literal.string("MATERIALIZED_VIEW"), + ) + .else_(exp.column("table_type")) + .as_("type"), + exp.column("clustering_key", "ci").as_("clustering_key"), ) - .else_(exp.column("table_type")) - .as_("type"), - ).from_( - exp.to_table( - f"`{catalog}`.`{schema.db}`.INFORMATION_SCHEMA.TABLES", dialect=self.dialect + .with_( + "clustering_info", + as_=exp.select( + exp.column("table_catalog"), + exp.column("table_schema"), + exp.column("table_name"), + parse_one( + "string_agg(column_name order by clustering_ordinal_position)", + dialect=self.dialect, + ).as_("clustering_key"), + ) + .from_( + exp.to_table( + f"`{catalog}`.`{schema.db}`.INFORMATION_SCHEMA.COLUMNS", + dialect=self.dialect, + ) + ) + .where(exp.column("clustering_ordinal_position").is_(exp.not_(exp.null()))) + .group_by("1", "2", "3"), + ) + .from_( + exp.to_table( + f"`{catalog}`.`{schema.db}`.INFORMATION_SCHEMA.TABLES", dialect=self.dialect + ) + ) + .join( + "clustering_info", + using=["table_catalog", "table_schema", "table_name"], + join_type="left", + join_alias="ci", ) ) if object_names: @@ -720,25 +1230,99 @@ def _get_data_objects( schema=row.schema_name, # type: ignore name=row.name, # type: ignore type=DataObjectType.from_str(row.type), # type: ignore + clustering_key=f"({row.clustering_key})" if row.clustering_key else None, # type: ignore ) for row in df.itertuples() ] + def _update_clustering_key(self, operation: TableAlterClusterByOperation) -> None: + cluster_key_expressions = getattr(operation, "cluster_key_expressions", []) + bq_table = 
self._get_table(operation.target_table) + + rendered_columns = [c.sql(dialect=self.dialect) for c in cluster_key_expressions] + bq_table.clustering_fields = ( + rendered_columns or None + ) # causes a drop of the key if cluster_by is empty or None + + self._db_call(self.client.update_table, table=bq_table, fields=["clustering_fields"]) + + if cluster_key_expressions: + # BigQuery only applies new clustering going forward, so this rewrites the columns to apply the new clustering to historical data + # ref: https://cloud.google.com/bigquery/docs/creating-clustered-tables#modifying-cluster-spec + self.execute( + exp.update( + operation.target_table, + {c: c for c in cluster_key_expressions}, + where=exp.true(), + ) + ) + + def _normalize_decimal_value(self, col: exp.Expression, precision: int) -> exp.Expression: + return exp.func("FORMAT", exp.Literal.string(f"%.{precision}f"), col) + + def _normalize_nested_value(self, col: exp.Expression) -> exp.Expression: + return exp.func("TO_JSON_STRING", col, dialect=self.dialect) + + @t.overload + def _columns_to_types( + self, + query_or_df: DF, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ... + + @t.overload + def _columns_to_types( + self, + query_or_df: Query, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ... 
+ + def _columns_to_types( + self, + query_or_df: QueryOrDF, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: + if ( + not target_columns_to_types + and bigframes + and isinstance(query_or_df, bigframes.dataframe.DataFrame) + ): + # using dry_run=True attempts to prevent the DataFrame from being materialized just to read the column types from it + dtypes = query_or_df.to_pandas(dry_run=True).columnDtypes + target_columns_to_types = columns_to_types_from_dtypes(dtypes.items()) + return target_columns_to_types, list(source_columns or target_columns_to_types) + + return super()._columns_to_types( + query_or_df, target_columns_to_types, source_columns=source_columns + ) + + def _native_df_to_pandas_df( + self, + query_or_df: QueryOrDF, + ) -> t.Union[Query, pd.DataFrame]: + if bigframes and isinstance(query_or_df, bigframes.dataframe.DataFrame): + return query_or_df.to_pandas() + + return super()._native_df_to_pandas_df(query_or_df) + @property def _query_data(self) -> t.Any: return self._connection_pool.get_attribute("query_data") @_query_data.setter def _query_data(self, value: t.Any) -> None: - return self._connection_pool.set_attribute("query_data", value) + self._connection_pool.set_attribute("query_data", value) @property - def _query_job(self) -> t.Any: + def _query_job(self) -> t.Optional[QueryJob]: return self._connection_pool.get_attribute("query_job") @_query_job.setter def _query_job(self, value: t.Any) -> None: - return self._connection_pool.set_attribute("query_job", value) + self._connection_pool.set_attribute("query_job", value) @property def _session_id(self) -> t.Any: @@ -746,7 +1330,109 @@ def _session_id(self) -> t.Any: @_session_id.setter def _session_id(self, value: t.Any) -> None: - return self._connection_pool.set_attribute("session_id", value) + 
self._connection_pool.set_attribute("session_id", value) + + def _get_current_schema(self) -> str: + raise NotImplementedError("BigQuery does not support current schema") + + def _get_bq_dataset_location(self, project: str, dataset: str) -> str: + return self._db_call(self.client.get_dataset, dataset_ref=f"{project}.{dataset}").location + + def _get_grant_expression(self, table: exp.Table) -> exp.Expression: + if not table.db: + raise ValueError( + f"Table {table.sql(dialect=self.dialect)} does not have a schema (dataset)" + ) + project = table.catalog or self.get_current_catalog() + if not project: + raise ValueError( + f"Table {table.sql(dialect=self.dialect)} does not have a catalog (project)" + ) + + dataset = table.db + table_name = table.name + location = self._get_bq_dataset_location(project, dataset) + + # https://cloud.google.com/bigquery/docs/information-schema-object-privileges + # OBJECT_PRIVILEGES is a project-level INFORMATION_SCHEMA view with regional qualifier + object_privileges_table = exp.to_table( + f"`{project}`.`region-{location}`.INFORMATION_SCHEMA.{self.GRANT_INFORMATION_SCHEMA_TABLE_NAME}", + dialect=self.dialect, + ) + return ( + exp.select("privilege_type", "grantee") + .from_(object_privileges_table) + .where( + exp.and_( + exp.column("object_schema").eq(exp.Literal.string(dataset)), + exp.column("object_name").eq(exp.Literal.string(table_name)), + # Filter out current_user + # BigQuery grantees format: "user:email" or "group:name" + exp.func("split", exp.column("grantee"), exp.Literal.string(":"))[ + exp.func("OFFSET", exp.Literal.number("1")) + ].neq(self.CURRENT_USER_OR_ROLE_EXPRESSION), + ) + ) + ) + + @staticmethod + def _grant_object_kind(table_type: DataObjectType) -> str: + if table_type == DataObjectType.VIEW: + return "VIEW" + if table_type == DataObjectType.MATERIALIZED_VIEW: + # We actually need to use "MATERIALIZED VIEW" here even though it's not listed + # as a supported resource_type in the BigQuery DCL doc: + # 
https://cloud.google.com/bigquery/docs/reference/standard-sql/data-control-language + return "MATERIALIZED VIEW" + return "TABLE" + + def _dcl_grants_config_expr( + self, + dcl_cmd: t.Type[DCL], + table: exp.Table, + grants_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + expressions: t.List[exp.Expression] = [] + if not grants_config: + return expressions + + # https://cloud.google.com/bigquery/docs/reference/standard-sql/data-control-language + + def normalize_principal(p: str) -> str: + if ":" not in p: + raise ValueError(f"Principal '{p}' missing a prefix label") + + # allUsers and allAuthenticatedUsers special groups that are case-sensitive and must start with "specialGroup:" + if p.endswith("allUsers") or p.endswith("allAuthenticatedUsers"): + if not p.startswith("specialGroup:"): + raise ValueError( + f"Special group principal '{p}' must start with 'specialGroup:' prefix label" + ) + return p + + label, principal = p.split(":", 1) + # always lowercase principals + return f"{label}:{principal.lower()}" + + object_kind = self._grant_object_kind(table_type) + for privilege, principals in grants_config.items(): + if not principals: + continue + + normalized_principals = [exp.Literal.string(normalize_principal(p)) for p in principals] + args: t.Dict[str, t.Any] = { + "privileges": [exp.GrantPrivilege(this=exp.to_identifier(privilege, quoted=True))], + "securable": table.copy(), + "principals": normalized_principals, + } + + if object_kind: + args["kind"] = exp.Var(this=object_kind) + + expressions.append(dcl_cmd(**args)) # type: ignore[arg-type] + + return expressions class _ErrorCounter: @@ -778,7 +1464,7 @@ def _is_retryable(self, error: BaseException) -> bool: if isinstance(error, self.retryable_errors): return True - elif isinstance(error, Forbidden) and any( + if isinstance(error, Forbidden) and any( e["reason"] == "rateLimitExceeded" for e in error.errors ): return True @@ -800,24 +1486,24 @@ def 
select_partitions_expr( data_type: t.Union[str, exp.DataType], granularity: t.Optional[str] = None, agg_func: str = "MAX", - database: t.Optional[str] = None, + catalog: t.Optional[str] = None, ) -> str: """Generates a SQL expression that aggregates partition values for a table. Args: - schema: The schema (BigQueyr dataset) of the table. + schema: The schema (BigQuery dataset) of the table. table_name: The name of the table. data_type: The data type of the partition column. granularity: The granularity of the partition. Supported values are: 'day', 'month', 'year' and 'hour'. agg_func: The aggregation function to use. - database: The database (BigQuery project ID) of the table. + catalog: The catalog (BigQuery project ID) of the table. Returns: A SELECT statement that aggregates partition values for a table. """ partitions_table_name = f"`{schema}`.INFORMATION_SCHEMA.PARTITIONS" - if database: - partitions_table_name = f"`{database}`.{partitions_table_name}" + if catalog: + partitions_table_name = f"`{catalog}`.{partitions_table_name}" if isinstance(data_type, exp.DataType): data_type = data_type.sql(dialect="bigquery") diff --git a/sqlmesh/core/engine_adapter/clickhouse.py b/sqlmesh/core/engine_adapter/clickhouse.py new file mode 100644 index 0000000000..45c22a6e55 --- /dev/null +++ b/sqlmesh/core/engine_adapter/clickhouse.py @@ -0,0 +1,906 @@ +from __future__ import annotations + +import typing as t +import logging +import re +from sqlglot import exp, maybe_parse +from sqlmesh.core.dialect import to_schema +from sqlmesh.core.engine_adapter.mixins import LogicalMergeMixin +from sqlmesh.core.engine_adapter.base import EngineAdapterWithIndexSupport +from sqlmesh.core.engine_adapter.shared import ( + DataObject, + DataObjectType, + EngineRunMode, + SourceQuery, + CommentCreationView, + InsertOverwriteStrategy, +) +from sqlmesh.core.schema_diff import TableAlterOperation +from sqlmesh.utils import get_source_columns_to_types + +if t.TYPE_CHECKING: + import pandas as 
pd + + from sqlmesh.core._typing import SchemaName, TableName + from sqlmesh.core.engine_adapter._typing import DF, Query, QueryOrDF + + from sqlmesh.core.node import IntervalUnit + + +logger = logging.getLogger(__name__) + + +class ClickhouseEngineAdapter(EngineAdapterWithIndexSupport, LogicalMergeMixin): + DIALECT = "clickhouse" + SUPPORTS_TRANSACTIONS = False + SUPPORTS_VIEW_SCHEMA = False + SUPPORTS_REPLACE_TABLE = False + COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY + + SCHEMA_DIFFER_KWARGS = {} + + DEFAULT_TABLE_ENGINE = "MergeTree" + ORDER_BY_TABLE_ENGINE_REGEX = "^.*?MergeTree.*$" + + @property + def engine_run_mode(self) -> EngineRunMode: + if self._extra_config.get("cloud_mode"): + return EngineRunMode.CLOUD + # we use the user's specification of a cluster in the connection config to determine if + # the engine is in cluster mode + if self._extra_config.get("cluster"): + return EngineRunMode.CLUSTER + return EngineRunMode.STANDALONE + + @property + def cluster(self) -> t.Optional[str]: + return self._extra_config.get("cluster") + + # Workaround for clickhouse-connect cursor bug + # - cursor does not reset row index correctly on `close()`, so `fetchone()` and `fetchmany()` + # return the wrong (or no) rows after the very first cursor query that returns rows + # in the connection + # - cursor does reset the data rows correctly on `close()`, so `fetchall()` works because it + # doesn't use the row index at all + def fetchone( + self, + query: t.Union[exp.Expression, str], + ignore_unsupported_errors: bool = False, + quote_identifiers: bool = False, + ) -> t.Tuple: + with self.transaction(): + self.execute( + query, + ignore_unsupported_errors=ignore_unsupported_errors, + quote_identifiers=quote_identifiers, + ) + return self.cursor.fetchall()[0] + + def _fetch_native_df( + self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False + ) -> pd.DataFrame: + """Fetches a Pandas DataFrame from the cursor""" + return 
self.cursor.client.query_df( + self._to_sql(query, quote=quote_identifiers) + if isinstance(query, exp.Expression) + else query, + use_extended_dtypes=True, + ) + + def _df_to_source_queries( + self, + df: DF, + target_columns_to_types: t.Dict[str, exp.DataType], + batch_size: int, + target_table: TableName, + source_columns: t.Optional[t.List[str]] = None, + **kwargs: t.Any, + ) -> t.List[SourceQuery]: + temp_table = self._get_temp_table(target_table, **kwargs) + source_columns_to_types = get_source_columns_to_types( + target_columns_to_types, source_columns + ) + + def query_factory() -> Query: + # It is possible for the factory to be called multiple times and if so then the temp table will already + # be created so we skip creating again. This means we are assuming the first call is the same result + # as later calls. + if not self.table_exists(temp_table): + self.create_table( + temp_table, + source_columns_to_types, + storage_format=exp.var("MergeTree"), + **kwargs, + ) + ordered_df = df[list(source_columns_to_types)] + + self.cursor.client.insert_df(temp_table.sql(dialect=self.dialect), df=ordered_df) + + return exp.select(*self._casted_columns(target_columns_to_types, source_columns)).from_( + temp_table + ) + + return [ + SourceQuery( + query_factory=query_factory, + cleanup_func=lambda: self.drop_table(temp_table, **kwargs), + ) + ] + + def _get_data_objects( + self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None + ) -> t.List[DataObject]: + """ + Returns all the data objects that exist in the given database. 
+ """ + query = ( + exp.select( + exp.column("database").as_("schema_name"), + exp.column("name"), + exp.case(exp.column("engine")) + .when( + exp.Literal.string("View"), + exp.Literal.string("view"), + ) + .else_( + exp.Literal.string("table"), + ) + .as_("type"), + ) + .from_("system.tables") + .where(exp.column("database").eq(to_schema(schema_name).db)) + ) + if object_names: + query = query.where(exp.column("name").isin(*object_names)) + df = self.fetchdf(query) + return [ + DataObject( + catalog=None, + schema=row.schema_name, + name=row.name, + type=DataObjectType.from_str(row.type), # type: ignore + ) + for row in df.itertuples() + ] + + def create_schema( + self, + schema_name: SchemaName, + ignore_if_exists: bool = True, + warn_on_error: bool = True, + properties: t.List[exp.Expression] = [], + ) -> None: + """Create a Clickhouse database from a name or qualified table name. + + Clickhouse has a two-level naming scheme [database].[table]. + """ + properties_copy = properties.copy() + if self.engine_run_mode.is_cluster: + properties_copy.append(exp.OnCluster(this=exp.to_identifier(self.cluster))) + + # can't call super() because it will try to set a catalog + return self._create_schema( + schema_name=schema_name, + ignore_if_exists=ignore_if_exists, + warn_on_error=warn_on_error, + properties=properties_copy, + # sqlglot transpiles CREATE SCHEMA to CREATE DATABASE, but this text is used in an error message + kind="DATABASE", + ) + + def _insert_overwrite_by_condition( + self, + table_name: TableName, + source_queries: t.List[SourceQuery], + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + where: t.Optional[exp.Condition] = None, + insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None, + **kwargs: t.Any, + ) -> None: + """ + Implements the table or partition swap approach to insert-overwriting records. 
+ + Because this method executes multiple variants (full table replace, replace by time + range, replace by key, replace by partition), some upstream caller info is needed and + passed via kwargs. + + Args: + table_name: Name of target table + source_queries: Source queries returning records to insert + target_columns_to_types: Column names and data types of target table + where: SQLGlot expression determining which target table rows should be overwritten + insert_overwrite_strategy_override: Not used by Clickhouse + kwargs: + dynamic_key: Key columns (replace by key only) + dynamic_key_exp: Expression to build key (replace by key only) + dynamic_key_unique: Whether more than one record can exist per key value (replace by key only) + + keep_existing_partition_rows: Whether to overwrite partitions with only new records (incremental by partition only) + + Returns: + Side effects only: execution of insert-overwrite operation. + """ + target_table = exp.to_table(table_name) + target_columns_to_types = target_columns_to_types or self.columns(target_table) + + temp_table = self._get_temp_table(target_table) + self.create_table_like(temp_table, target_table) + + # REPLACE BY KEY: extract kwargs if present + dynamic_key = kwargs.get("dynamic_key") + if dynamic_key: + dynamic_key_exp = t.cast(exp.Expression, kwargs.get("dynamic_key_exp")) + dynamic_key_unique = t.cast(bool, kwargs.get("dynamic_key_unique")) + + try: + # insert new records into temp table + for source_query in source_queries: + with source_query as query: + # REPLACE BY KEY: if unique key, DISTINCTify by key columns so only one row is present per key + if dynamic_key and dynamic_key_unique: + query = query.distinct(*dynamic_key) # type: ignore + + query = self._order_projections_and_filter( + query, target_columns_to_types, where=where + ) + self._insert_append_query( + temp_table, + query, + target_columns_to_types=target_columns_to_types, + order_projections=False, + ) + + # REPLACE BY KEY: build `where` 
expression as "key IN (new rows' key values)" + if dynamic_key: + key_query = exp.select(dynamic_key_exp).from_(temp_table) + if not dynamic_key_unique: + key_query = key_query.distinct() + where = dynamic_key_exp.isin(query=key_query) + + # get target table partition key to confirm it's actually partitioned + table_partition_exp = self.fetchone( + exp.select("partition_key") + .from_("system.tables") + .where( + exp.column("database").eq(target_table.db), + exp.column("name").eq(target_table.name), + ) + ) + + all_affected_partitions: t.Set[str] = set() + + if where: + # identify existing records to keep by inverting the delete `where` clause + existing_records_insert_exp = exp.insert( + self._select_columns(target_columns_to_types) + .from_(target_table) + .where(exp.paren(expression=where).not_()), + temp_table, + ) + + # if target table is partitioned, modify insert expression to only insert + # existing records that are in one of the affected partitions + if table_partition_exp: + partitions_temp_table_name = self._get_temp_table( + exp.to_table(f"{target_table.db}._affected_partitions") + ) + all_affected_partitions, existing_records_insert_exp = ( + self._get_affected_partitions_and_insert_exp( + target_table, + temp_table, + where, + existing_records_insert_exp, + partitions_temp_table_name, + ) + ) + + try: + self.execute(existing_records_insert_exp, track_rows_processed=True) + finally: + if table_partition_exp: + self.drop_table(partitions_temp_table_name) + + # process by partition if: + # 1. The table is partitioned AND + # (2a. There are existing records to keep (`where`) OR + # 2b. 
We're overwriting existing partition rows (incremental by partition model)) + if table_partition_exp and ( + where or kwargs.get("keep_existing_partition_rows") is False + ): + # only replace partitions that have records in temp_table + partitions_to_replace = self._get_partition_ids(temp_table) + + # drop affected partitions that have no records in temp_table + # - NOTE: `all_affected_partitions` will be empty when keep_existing_partition_rows=False + # because previous code block is skipped + partitions_to_drop = all_affected_partitions - partitions_to_replace + + if partitions_to_replace or partitions_to_drop: + self.alter_table( + [ + self._build_alter_partition_exp( + target_table, temp_table, partitions_to_replace, partitions_to_drop + ) + ] + ) + else: + self._exchange_tables(target_table, temp_table) + finally: + self.drop_table(temp_table) + + def _get_affected_partitions_and_insert_exp( + self, + target_table: exp.Table, + temp_table: exp.Table, + where: exp.Condition, + existing_records_insert_exp: exp.Insert, + partitions_temp_table_name: exp.Table, + ) -> tuple[t.Set[str], exp.Insert]: + # identify all affected partition IDs + # - store in temp table so we can reuse results + self.ctas( + partitions_temp_table_name, + exp.select("partition_id") + .distinct() + .from_( + exp.union( + # target table partitions with records in `where` + exp.select(exp.column("_partition_id").as_("partition_id")) + .from_(target_table) + .where(where), + # temp table partitions with new records to insert + exp.select( + exp.column("_partition_id").as_("partition_id"), + ).from_(temp_table), + ).subquery("_affected_partitions") + ), + ) + + # read all affected partition IDs into memory + all_affected_partitions = self._get_partition_ids( + partitions_temp_table_name, "partition_id" + ) + + # limit existing records insert expression WHERE to affected target table partitions + # by adding `AND _partition_id IN (SELECT partition_id FROM partitions_temp_table)` + 
existing_records_insert_exp.set( + "expression", + existing_records_insert_exp.expression.where( + exp.column("_partition_id").isin( + exp.select("partition_id").from_(partitions_temp_table_name) + ) + ), + ) + + return all_affected_partitions, existing_records_insert_exp + + def _build_alter_partition_exp( + self, + target_table: exp.Table, + temp_table: exp.Table, + partitions_to_replace: t.Set[str], + partitions_to_drop: t.Set[str], + ) -> exp.Alter: + alter_expr = exp.Alter(this=target_table, kind="TABLE") + + for partition in partitions_to_replace: + alter_expr.append( + "actions", + exp.ReplacePartition( + expression=exp.Partition( + expressions=[exp.PartitionId(this=exp.Literal.string(str(partition)))] + ), + source=temp_table, + ), + ) + + for partition in partitions_to_drop: + alter_expr.append( + "actions", + exp.DropPartition( + expressions=[ + exp.Partition( + expressions=[exp.PartitionId(this=exp.Literal.string(str(partition)))] + ) + ], + source=temp_table, + ), + ) + + return alter_expr + + def _replace_by_key( + self, + target_table: TableName, + source_table: QueryOrDF, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + key: t.Sequence[exp.Expression], + is_unique_key: bool, + source_columns: t.Optional[t.List[str]] = None, + ) -> None: + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + source_table, + target_columns_to_types, + target_table=target_table, + source_columns=source_columns, + ) + + key_exp = exp.func("CONCAT_WS", "'__SQLMESH_DELIM__'", *key) if len(key) > 1 else key[0] + + self._insert_overwrite_by_condition( + target_table, + source_queries, + target_columns_to_types, + dynamic_key=key, + dynamic_key_exp=key_exp, + dynamic_key_unique=is_unique_key, + ) + + def insert_overwrite_by_partition( + self, + table_name: TableName, + query_or_df: QueryOrDF, + partitioned_by: t.List[exp.Expression], + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + 
source_columns: t.Optional[t.List[str]] = None, + ) -> None: + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + query_or_df, + target_columns_to_types, + target_table=table_name, + source_columns=source_columns, + ) + + self._insert_overwrite_by_condition( + table_name, source_queries, target_columns_to_types, keep_existing_partition_rows=False + ) + + def _create_table_like( + self, + target_table_name: TableName, + source_table_name: TableName, + exists: bool, + **kwargs: t.Any, + ) -> None: + """Create table with identical structure as source table""" + self.execute( + f"CREATE TABLE {target_table_name}{self._on_cluster_sql()} AS {source_table_name}" + ) + + def _get_partition_ids( + self, + table: exp.Table, + partition_col_name: str = "_partition_id", + where: t.Optional[exp.Condition] = None, + limit: t.Optional[int] = None, + ) -> t.Set[t.Any]: + """List partition IDs present in table""" + partitions_query = exp.select(partition_col_name).distinct().from_(table) + if where: + partitions_query = partitions_query.where(where) + if limit: + partitions_query = partitions_query.limit(limit) + partitions = self.fetchall(partitions_query) + + return set([part[0] for part in partitions] if partitions else []) + + def _create_table( + self, + table_name_or_schema: t.Union[exp.Schema, TableName], + expression: t.Optional[exp.Expression], + exists: bool = True, + replace: bool = False, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + table_description: t.Optional[str] = None, + column_descriptions: t.Optional[t.Dict[str, str]] = None, + table_kind: t.Optional[str] = None, + track_rows_processed: bool = True, + **kwargs: t.Any, + ) -> None: + """Creates a table in the database. + + Clickhouse Cloud requires doing CTAS in two steps. + + First, we add the `EMPTY` property to the CTAS call to create a table with the proper + schema, then insert the data with the CTAS query. 
+ """ + # ensure columns used for partitioning are non-Nullable + # - normally user's responsibility, but we automatically partition by time column in + # incremental by time models + if kwargs.get("partitioned_by"): + partition_cols = [ + col.name + for part_expr in kwargs["partitioned_by"] + for col in part_expr.find_all(exp.Column) + ] + if isinstance(table_name_or_schema, exp.Schema): + for coldef in table_name_or_schema.expressions: + if coldef.name in partition_cols: + coldef.kind.set("nullable", False) + if target_columns_to_types: + for col in partition_cols: + target_columns_to_types[col].set("nullable", False) + + super()._create_table( + table_name_or_schema, + expression, + exists, + replace, + target_columns_to_types, + table_description, + column_descriptions, + table_kind, + empty_ctas=(self.engine_run_mode.is_cloud and expression is not None), + track_rows_processed=track_rows_processed, + **kwargs, + ) + + # execute the second INSERT step if on cloud and creating a table + # - Additional clause is to avoid clickhouse-connect HTTP client bug where CTAS LIMIT 0 + # returns a success code but malformed response + if ( + self.engine_run_mode.is_cloud + and table_kind != "VIEW" + and expression + and not ( + expression.args.get("limit") is not None + and expression.args["limit"].expression.this == "0" + ) + ): + table_name = ( + table_name_or_schema.this + if isinstance(table_name_or_schema, exp.Schema) + else table_name_or_schema + ) + self._insert_append_query( + table_name, + expression, # type: ignore + target_columns_to_types or self.columns(table_name), + ) + + def _exchange_tables( + self, + old_table_name: TableName, + new_table_name: TableName, + ) -> None: + from clickhouse_connect.driver.exceptions import DatabaseError # type: ignore + + old_table_sql = exp.to_table(old_table_name).sql(dialect=self.dialect, identify=True) + new_table_sql = exp.to_table(new_table_name).sql(dialect=self.dialect, identify=True) + + try: + self.execute( + 
f"EXCHANGE TABLES {old_table_sql} AND {new_table_sql}{self._on_cluster_sql()}" + ) + except DatabaseError as e: + if "NOT_IMPLEMENTED" in str(e): + # If someone is using an old Clickhouse version, an OS that doesn't support atomic exchanges, + # or a database engine that doesn't support atomic exchanges, we do a non-atomic rename instead. + # + # Executing multiple renames in one call like `RENAME TABLE a to b, c to a` is supported + # but not an atomic operation. Because it is not atomic, doing it in two calls is equivalent + # and does not require defining an additional method. + throwaway_table_name = self._get_temp_table(old_table_name) + self._rename_table(old_table_name, throwaway_table_name) + self._rename_table(new_table_name, old_table_name) + self.drop_table(throwaway_table_name) + + def _rename_table( + self, + old_table_name: TableName, + new_table_name: TableName, + ) -> None: + old_table_sql = exp.to_table(old_table_name).sql(dialect=self.dialect, identify=True) + new_table_sql = exp.to_table(new_table_name).sql(dialect=self.dialect, identify=True) + + self.execute(f"RENAME TABLE {old_table_sql} TO {new_table_sql}{self._on_cluster_sql()}") + + def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None: + delete_expr = exp.delete(table_name, where) + if self.engine_run_mode.is_cluster: + delete_expr.set("cluster", exp.OnCluster(this=exp.to_identifier(self.cluster))) + self.execute(delete_expr) + + def alter_table( + self, + alter_expressions: t.Union[t.List[exp.Alter], t.List[TableAlterOperation]], + ) -> None: + """ + Performs the alter statements to change the current table into the structure of the target table. 
+ """ + with self.transaction(): + for alter_expression in [ + x.expression if isinstance(x, TableAlterOperation) else x for x in alter_expressions + ]: + if self.engine_run_mode.is_cluster: + alter_expression.set( + "cluster", exp.OnCluster(this=exp.to_identifier(self.cluster)) + ) + self.execute(alter_expression) + + def _drop_object( + self, + name: TableName | SchemaName, + exists: bool = True, + kind: str = "TABLE", + cascade: bool = False, + **drop_args: t.Any, + ) -> None: + """Drops an object. + + An object could be a DATABASE, SCHEMA, VIEW, TABLE, DYNAMIC TABLE, TEMPORARY TABLE etc depending on the :kind. + + Args: + name: The name of the table to drop. + exists: If exists, defaults to True. + kind: What kind of object to drop. Defaults to TABLE + **drop_args: Any extra arguments to set on the Drop expression + """ + super()._drop_object( + name=name, + exists=exists, + kind=kind, + cascade=cascade, + cluster=exp.OnCluster(this=exp.to_identifier(self.cluster)) + if self.engine_run_mode.is_cluster + else None, + **drop_args, + ) + + def _build_partitioned_by_exp( + self, + partitioned_by: t.List[exp.Expression], + **kwargs: t.Any, + ) -> t.Optional[t.Union[exp.PartitionedByProperty, exp.Property]]: + return exp.PartitionedByProperty( + this=exp.Schema(expressions=partitioned_by), + ) + + def ensure_nulls_for_unmatched_after_join( + self, + query: Query, + ) -> Query: + # Set `join_use_nulls = 1` in a query's SETTINGS clause + query.append("settings", exp.var("join_use_nulls").eq(exp.Literal.number("1"))) + return query + + def use_server_nulls_for_unmatched_after_join( + self, + query: Query, + ) -> Query: + # Set the `join_use_nulls` server value in a query's SETTINGS clause + # + # Use in SCD models: + # - The SCD query we build must include the setting `join_use_nulls = 1` to ensure that empty cells in a join + # are filled with NULL instead of the default data type value. The default join_use_nulls value is `0`. 
+ # - The SCD embeds the user's original query in the `source` CTE + # - Settings are dynamically scoped, so our setting may override the server's default setting the user expects + # for their query. + # - To prevent this, we: + # - If the user query sets `join_use_nulls`, we do nothing + # - If the user query does not set `join_use_nulls`, we query the server for the current setting + # - If the server value is 1, we do nothing + # - If the server values is not 1, we inject its `join_use_nulls` value into the user query + # - We do not need to check user subqueries because our injected setting operates at the same scope the + # server value would normally operate at + setting_name = "join_use_nulls" + setting_value = "1" + + user_settings = query.args.get("settings") + # if user has not already set it explicitly + if not ( + user_settings + and any( + [ + isinstance(setting, exp.EQ) and setting.name == setting_name + for setting in user_settings + ] + ) + ): + server_value = self.fetchone( + exp.select("value") + .from_("system.settings") + .where(exp.column("name").eq(exp.Literal.string(setting_name))) + )[0] + # only inject the setting if the server value isn't 1 + inject_setting = setting_value != server_value + setting_value = server_value if inject_setting else setting_value + + if inject_setting: + query.append( + "settings", exp.var(setting_name).eq(exp.Literal.number(setting_value)) + ) + + return query + + def _build_settings_property( + self, key: str, value: exp.Expression | str | int | float + ) -> exp.SettingsProperty: + return exp.SettingsProperty( + expressions=[ + exp.EQ( + this=exp.var(key.lower()), + expression=value + if isinstance(value, exp.Expression) + else exp.Literal(this=value, is_string=isinstance(value, str)), + ) + ] + ) + + def _build_table_properties_exp( + self, + catalog_name: t.Optional[str] = None, + table_format: t.Optional[str] = None, + storage_format: t.Optional[str] = None, + partitioned_by: 
t.Optional[t.List[exp.Expression]] = None, + partition_interval_unit: t.Optional[IntervalUnit] = None, + clustered_by: t.Optional[t.List[exp.Expression]] = None, + table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + table_description: t.Optional[str] = None, + table_kind: t.Optional[str] = None, + empty_ctas: bool = False, + **kwargs: t.Any, + ) -> t.Optional[exp.Properties]: + properties: t.List[exp.Expression] = [] + + table_engine = self.DEFAULT_TABLE_ENGINE + if storage_format: + table_engine = ( + storage_format.this if isinstance(storage_format, exp.Var) else storage_format # type: ignore + ) + properties.append(exp.EngineProperty(this=table_engine)) + + # copy of table_properties so we can pop items off below then consume the rest later + table_properties_copy = { + k.upper(): v for k, v in (table_properties.copy() if table_properties else {}).items() + } + + mergetree_engine = bool(re.search(self.ORDER_BY_TABLE_ENGINE_REGEX, table_engine)) + ordered_by_raw = table_properties_copy.pop("ORDER_BY", None) + if mergetree_engine: + ordered_by_exprs = [] + if ordered_by_raw: + ordered_by_vals = [] + + if isinstance(ordered_by_raw, (exp.Tuple, exp.Array)): + ordered_by_vals = ordered_by_raw.expressions + if isinstance(ordered_by_raw, exp.Paren): + ordered_by_vals = [ordered_by_raw.this] + + if not ordered_by_vals: + ordered_by_vals = ( + ordered_by_raw if isinstance(ordered_by_raw, list) else [ordered_by_raw] + ) + + for col in ordered_by_vals: + ordered_by_exprs.append( + col + if isinstance(col, exp.Column) + else maybe_parse( + col.name if isinstance(col, exp.Literal) else col, + dialect=self.dialect, + into=exp.Ordered, + ) + ) + + properties.append(exp.Order(expressions=[exp.Tuple(expressions=ordered_by_exprs)])) + + primary_key = table_properties_copy.pop("PRIMARY_KEY", None) + if mergetree_engine and primary_key: + primary_key_vals = [] + if isinstance(primary_key, 
(exp.Tuple, exp.Array)): + primary_key_vals = primary_key.expressions + if isinstance(ordered_by_raw, exp.Paren): + primary_key_vals = [primary_key.this] + + if not primary_key_vals: + primary_key_vals = primary_key if isinstance(primary_key, list) else [primary_key] + + properties.append( + exp.PrimaryKey( + expressions=[ + exp.to_column(k.name if isinstance(k, exp.Literal) else k) + for k in primary_key_vals + ] + ) + ) + + ttl = table_properties_copy.pop("TTL", None) + if ttl: + properties.append( + exp.MergeTreeTTL( + expressions=[ttl if isinstance(ttl, exp.Expression) else exp.var(ttl)] + ) + ) + + if ( + partitioned_by + and (partitioned_by_prop := self._build_partitioned_by_exp(partitioned_by)) is not None + ): + properties.append(partitioned_by_prop) + + if self.engine_run_mode.is_cluster: + properties.append(exp.OnCluster(this=exp.to_identifier(self.cluster))) + + if empty_ctas: + properties.append(exp.EmptyProperty()) + + if table_properties_copy: + properties.extend( + [self._build_settings_property(k, v) for k, v in table_properties_copy.items()] + ) + + if table_description: + properties.append( + exp.SchemaCommentProperty( + this=exp.Literal.string(self._truncate_table_comment(table_description)) + ) + ) + + if properties: + return exp.Properties(expressions=properties) + + return None + + def _build_view_properties_exp( + self, + view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + table_description: t.Optional[str] = None, + **kwargs: t.Any, + ) -> t.Optional[exp.Properties]: + """Creates a SQLGlot table properties expression for view""" + properties: t.List[exp.Expression] = [] + + view_properties_copy = view_properties.copy() if view_properties else {} + + if self.engine_run_mode.is_cluster: + properties.append(exp.OnCluster(this=exp.to_identifier(self.cluster))) + + if view_properties_copy: + properties.extend( + [self._build_settings_property(k, v) for k, v in view_properties_copy.items()] + ) + + if table_description: + 
properties.append( + exp.SchemaCommentProperty( + this=exp.Literal.string(self._truncate_table_comment(table_description)) + ) + ) + + if properties: + return exp.Properties(expressions=properties) + return None + + def _build_create_comment_table_exp( + self, table: exp.Table, table_comment: str, table_kind: str, **kwargs: t.Any + ) -> exp.Comment | str: + table_sql = table.sql(dialect=self.dialect, identify=True) + + truncated_comment = self._truncate_table_comment(table_comment) + comment_sql = exp.Literal.string(truncated_comment).sql(dialect=self.dialect) + + return f"ALTER TABLE {table_sql}{self._on_cluster_sql()} MODIFY COMMENT {comment_sql}" + + def _build_create_comment_column_exp( + self, + table: exp.Table, + column_name: str, + column_comment: str, + table_kind: str = "TABLE", + **kwargs: t.Any, + ) -> exp.Comment | str: + table_sql = table.sql(dialect=self.dialect, identify=True) + column_sql = exp.to_column(column_name).sql(dialect=self.dialect, identify=True) + + truncated_comment = self._truncate_table_comment(column_comment) + comment_sql = exp.Literal.string(truncated_comment).sql(dialect=self.dialect) + + return f"ALTER TABLE {table_sql}{self._on_cluster_sql()} COMMENT COLUMN {column_sql} {comment_sql}" + + def _on_cluster_sql(self) -> str: + if self.engine_run_mode.is_cluster: + cluster_name = exp.to_identifier(self.cluster, quoted=True).sql(dialect=self.dialect) # type: ignore + return f" ON CLUSTER {cluster_name} " + return "" diff --git a/sqlmesh/core/engine_adapter/databricks.py b/sqlmesh/core/engine_adapter/databricks.py index 8e7b52c65f..870b946e7d 100644 --- a/sqlmesh/core/engine_adapter/databricks.py +++ b/sqlmesh/core/engine_adapter/databricks.py @@ -2,52 +2,56 @@ import logging import typing as t +from functools import partial -import pandas as pd from sqlglot import exp +from sqlmesh.core.dialect import to_schema +from sqlmesh.core.engine_adapter.mixins import GrantsFromInfoSchemaMixin from sqlmesh.core.engine_adapter.shared import ( 
CatalogSupport, DataObject, + DataObjectType, InsertOverwriteStrategy, - set_catalog, + SourceQuery, ) from sqlmesh.core.engine_adapter.spark import SparkEngineAdapter -from sqlmesh.core.schema_diff import SchemaDiffer -from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.core.node import IntervalUnit +from sqlmesh.core.schema_diff import NestedSupport +from sqlmesh.engines.spark.db_api.spark_session import connection, SparkSessionConnection +from sqlmesh.utils.errors import SQLMeshError, MissingDefaultCatalogError if t.TYPE_CHECKING: - from sqlmesh.core._typing import SchemaName, TableName - from sqlmesh.core.engine_adapter._typing import DF, PySparkSession + import pandas as pd + + from sqlmesh.core._typing import SchemaName, TableName, SessionProperties + from sqlmesh.core.engine_adapter._typing import DF, PySparkSession, Query logger = logging.getLogger(__name__) -@set_catalog( - { - "_get_data_objects": CatalogSupport.REQUIRES_SET_CATALOG, - } -) -class DatabricksEngineAdapter(SparkEngineAdapter): +class DatabricksEngineAdapter(SparkEngineAdapter, GrantsFromInfoSchemaMixin): DIALECT = "databricks" INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE SUPPORTS_CLONING = True SUPPORTS_MATERIALIZED_VIEWS = True SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True - SCHEMA_DIFFER = SchemaDiffer( - support_positional_add=True, - support_nested_operations=True, - array_element_selector="element", - parameterized_type_defaults={ + SUPPORTS_GRANTS = True + USE_CATALOG_IN_GRANTS = True + # Spark has this set to false for compatibility when mixing with Trino but that isn't a concern with Databricks + QUOTE_IDENTIFIERS_IN_VIEWS = True + SCHEMA_DIFFER_KWARGS = { + "support_positional_add": True, + "nested_support": NestedSupport.ALL, + "array_element_selector": "element", + "parameterized_type_defaults": { exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(10, 0), (0,)], }, - ) - CATALOG_SUPPORT = CatalogSupport.FULL_SUPPORT - SUPPORTS_ROW_LEVEL_OP = True + } - 
def __init__(self, *args: t.Any, **kwargs: t.Any): + def __init__(self, *args: t.Any, **kwargs: t.Any) -> None: super().__init__(*args, **kwargs) - self._spark: t.Optional[PySparkSession] = None + self._set_spark_engine_adapter_if_needed() @classmethod def can_access_spark_session(cls, disable_spark_session: bool) -> bool: @@ -74,19 +78,66 @@ def can_access_databricks_connect(cls, disable_databricks_connect: bool) -> bool def _use_spark_session(self) -> bool: if self.can_access_spark_session(bool(self._extra_config.get("disable_spark_session"))): return True - return self.can_access_databricks_connect( + + if self.can_access_databricks_connect( bool(self._extra_config.get("disable_databricks_connect")) - ) and { - "databricks_connect_server_hostname", - "databricks_connect_access_token", - "databricks_connect_cluster_id", - }.issubset(self._extra_config) + ): + if self._extra_config.get("databricks_connect_use_serverless"): + return True + + if { + "databricks_connect_cluster_id", + "databricks_connect_server_hostname", + "databricks_connect_access_token", + }.issubset(self._extra_config): + return True + + return False @property - def is_spark_session_cursor(self) -> bool: - from sqlmesh.engines.spark.db_api.spark_session import SparkSessionCursor + def is_spark_session_connection(self) -> bool: + return isinstance(self.connection, SparkSessionConnection) - return isinstance(self.cursor, SparkSessionCursor) + def _set_spark_engine_adapter_if_needed(self) -> None: + self._spark_engine_adapter = None + + if not self._use_spark_session or self.is_spark_session_connection: + return + + from databricks.connect import DatabricksSession + + connect_kwargs = dict( + host=self._extra_config["databricks_connect_server_hostname"], + token=self._extra_config.get("databricks_connect_access_token"), + ) + if "databricks_connect_use_serverless" in self._extra_config: + connect_kwargs["serverless"] = True + else: + connect_kwargs["cluster_id"] = 
self._extra_config["databricks_connect_cluster_id"] + + catalog = self._extra_config.get("catalog") + spark = ( + DatabricksSession.builder.remote(**connect_kwargs).userAgent("sqlmesh").getOrCreate() + ) + self._spark_engine_adapter = SparkEngineAdapter( + partial(connection, spark=spark, catalog=catalog), + default_catalog=catalog, + execute_log_level=self._execute_log_level, + multithreaded=self._multithreaded, + sql_gen_kwargs=self._sql_gen_kwargs, + register_comments=self._register_comments, + pre_ping=self._pre_ping, + pretty_sql=self._pretty_sql, + ) + + @property + def cursor(self) -> t.Any: + if ( + self._connection_pool.get_attribute("use_spark_engine_adapter") + and not self.is_spark_session_connection + ): + return self._spark_engine_adapter.cursor # type: ignore + return super().cursor @property def spark(self) -> PySparkSession: @@ -96,41 +147,78 @@ def spark(self) -> PySparkSession: "Either run from a Databricks Notebook or " "install `databricks-connect` and configure it to connect to your Databricks cluster." 
) + if self.is_spark_session_connection: + return self.connection.spark + return self._spark_engine_adapter.spark # type: ignore - if self.is_spark_session_cursor: - return self._connection_pool.get().spark + @property + def catalog_support(self) -> CatalogSupport: + return CatalogSupport.FULL_SUPPORT - from databricks.connect import DatabricksSession + @staticmethod + def _grant_object_kind(table_type: DataObjectType) -> str: + if table_type == DataObjectType.VIEW: + return "VIEW" + if table_type == DataObjectType.MATERIALIZED_VIEW: + return "MATERIALIZED VIEW" + return "TABLE" - if self._spark is None: - self._spark = ( - DatabricksSession.builder.remote( - host=self._extra_config["databricks_connect_server_hostname"], - token=self._extra_config["databricks_connect_access_token"], - cluster_id=self._extra_config["databricks_connect_cluster_id"], - ) - .userAgent("sqlmesh") - .getOrCreate() + def _get_grant_expression(self, table: exp.Table) -> exp.Expression: + # We only care about explicitly granted privileges and not inherited ones + # if this is removed you would see grants inherited from the catalog get returned + expression = super()._get_grant_expression(table) + expression.args["where"].set( + "this", + exp.and_( + expression.args["where"].this, + exp.column("inherited_from").eq(exp.Literal.string("NONE")), + wrap=False, + ), + ) + return expression + + def _begin_session(self, properties: SessionProperties) -> t.Any: + """Begin a new session.""" + # Align the different possible connectors to a single catalog + self.set_current_catalog(self.default_catalog) # type: ignore + + def _end_session(self) -> None: + self._connection_pool.set_attribute("use_spark_engine_adapter", False) + + def _df_to_source_queries( + self, + df: DF, + target_columns_to_types: t.Dict[str, exp.DataType], + batch_size: int, + target_table: TableName, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.List[SourceQuery]: + if not self._use_spark_session: + return 
super(SparkEngineAdapter, self)._df_to_source_queries( + df, target_columns_to_types, batch_size, target_table, source_columns=source_columns ) - catalog = self._extra_config.get("catalog") - if catalog: - self.set_current_catalog(catalog) - return self._spark + pyspark_df = self._ensure_pyspark_df( + df, target_columns_to_types, source_columns=source_columns + ) + + def query_factory() -> Query: + temp_table = self._get_temp_table(target_table or "spark", table_only=True) + pyspark_df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect)) + self._connection_pool.set_attribute("use_spark_engine_adapter", True) + return exp.select(*self._select_columns(target_columns_to_types)).from_(temp_table) + + return [SourceQuery(query_factory=query_factory)] def _fetch_native_df( self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False ) -> DF: """Fetches a DataFrame that can be either Pandas or PySpark from the cursor""" - if self.is_spark_session_cursor: + if self.is_spark_session_connection: return super()._fetch_native_df(query, quote_identifiers=quote_identifiers) - if self._use_spark_session: - sql = ( - self._to_sql(query, quote=quote_identifiers) - if isinstance(query, exp.Expression) - else query + if self._spark_engine_adapter: + return self._spark_engine_adapter._fetch_native_df( # type: ignore + query, quote_identifiers=quote_identifiers ) - self._log_sql(sql) - return self.spark.sql(sql) self.execute(query) return self.cursor.fetchall_arrow().to_pandas() @@ -140,53 +228,106 @@ def fetchdf( """ Returns a Pandas DataFrame from a query or expression. 
""" + import pandas as pd + df = self._fetch_native_df(query, quote_identifiers=quote_identifiers) if not isinstance(df, pd.DataFrame): return df.toPandas() return df def get_current_catalog(self) -> t.Optional[str]: - # Update the Dataframe API if we have a spark session - if self._use_spark_session: + pyspark_catalog = None + sql_connector_catalog = None + if self._spark_engine_adapter: from py4j.protocol import Py4JError from pyspark.errors.exceptions.connect import SparkConnectGrpcException try: # Note: Spark 3.4+ Only API - return self.spark.catalog.currentCatalog() + pyspark_catalog = self._spark_engine_adapter.get_current_catalog() except (Py4JError, SparkConnectGrpcException): pass - result = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION)) - if result: - return result[0] - return None + elif self.is_spark_session_connection: + pyspark_catalog = self.connection.spark.catalog.currentCatalog() + if not self.is_spark_session_connection: + result = self.fetchone(exp.select(self.CURRENT_CATALOG_EXPRESSION)) + sql_connector_catalog = result[0] if result else None + if self._spark_engine_adapter and pyspark_catalog != sql_connector_catalog: + logger.warning( + f"Current catalog mismatch between Databricks SQL Connector and Databricks-Connect: `{sql_connector_catalog}` != `{pyspark_catalog}`. Set `catalog` connection property to make them the same." 
+ ) + return pyspark_catalog or sql_connector_catalog def set_current_catalog(self, catalog_name: str) -> None: - # Since Databricks splits commands across the Dataframe API and the SQL Connector - # (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both - # are set to the same catalog since they maintain their default catalog seperately - self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG")) - # Update the Dataframe API is we have a spark session - if self._use_spark_session: + def _set_spark_session_current_catalog(spark: PySparkSession) -> None: from py4j.protocol import Py4JError from pyspark.errors.exceptions.connect import SparkConnectGrpcException try: # Note: Spark 3.4+ Only API - self.spark.catalog.setCurrentCatalog(catalog_name) + spark.catalog.setCurrentCatalog(catalog_name) except (Py4JError, SparkConnectGrpcException): pass + # Since Databricks splits commands across the Dataframe API and the SQL Connector + # (depending if databricks-connect is installed and a Dataframe is used) we need to ensure both + # are set to the same catalog since they maintain their default catalog separately + self.execute(exp.Use(this=exp.to_identifier(catalog_name), kind="CATALOG")) + if self.is_spark_session_connection: + _set_spark_session_current_catalog(self.connection.spark) + + if self._spark_engine_adapter: + _set_spark_session_current_catalog(self._spark_engine_adapter.spark) + def _get_data_objects( self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None ) -> t.List[DataObject]: - return super()._get_data_objects(schema_name, object_names=object_names) + """ + Returns all the data objects that exist in the given schema and catalog. 
+ """ + schema = to_schema(schema_name) + catalog_name = schema.catalog or self.get_current_catalog() + query = ( + exp.select( + exp.column("table_name").as_("name"), + exp.column("table_schema").as_("schema"), + exp.column("table_catalog").as_("catalog"), + exp.case(exp.column("table_type")) + .when(exp.Literal.string("VIEW"), exp.Literal.string("view")) + .when( + exp.Literal.string("MATERIALIZED_VIEW"), exp.Literal.string("materialized_view") + ) + .else_(exp.Literal.string("table")) + .as_("type"), + ) + .from_( + # always query `system` information_schema + exp.table_("tables", "information_schema", "system") + ) + .where(exp.column("table_catalog").eq(catalog_name)) + .where(exp.column("table_schema").eq(schema.db)) + ) + + if object_names: + query = query.where(exp.column("table_name").isin(*object_names)) + + df = self.fetchdf(query) + return [ + DataObject( + catalog=row.catalog, # type: ignore + schema=row.schema, # type: ignore + name=row.name, # type: ignore + type=DataObjectType.from_str(row.type), # type: ignore + ) + for row in df.itertuples() + ] def clone_table( self, target_table_name: TableName, source_table_name: TableName, replace: bool = False, + exists: bool = True, clone_kwargs: t.Optional[t.Dict[str, t.Any]] = None, **kwargs: t.Any, ) -> None: @@ -202,3 +343,71 @@ def clone_table( def wap_supported(self, table_name: TableName) -> bool: return False + + def close(self) -> t.Any: + """Closes all open connections and releases all allocated resources.""" + super().close() + if self._spark_engine_adapter: + self._spark_engine_adapter.close() + + @property + def default_catalog(self) -> t.Optional[str]: + try: + return super().default_catalog + except MissingDefaultCatalogError as e: + raise MissingDefaultCatalogError( + "Could not determine default catalog. Define the connection property `catalog` since it can't be inferred from your connection. 
See SQLMesh Databricks documentation for details" + ) from e + + def _build_table_properties_exp( + self, + catalog_name: t.Optional[str] = None, + table_format: t.Optional[str] = None, + storage_format: t.Optional[str] = None, + partitioned_by: t.Optional[t.List[exp.Expression]] = None, + partition_interval_unit: t.Optional[IntervalUnit] = None, + clustered_by: t.Optional[t.List[exp.Expression]] = None, + table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + table_description: t.Optional[str] = None, + table_kind: t.Optional[str] = None, + **kwargs: t.Any, + ) -> t.Optional[exp.Properties]: + properties = super()._build_table_properties_exp( + catalog_name=catalog_name, + table_format=table_format, + storage_format=storage_format, + partitioned_by=partitioned_by, + partition_interval_unit=partition_interval_unit, + clustered_by=clustered_by, + table_properties=table_properties, + target_columns_to_types=target_columns_to_types, + table_description=table_description, + table_kind=table_kind, + ) + if clustered_by: + # Databricks expects wrapped CLUSTER BY expressions + clustered_by_exp = exp.Cluster( + expressions=[exp.Tuple(expressions=[c.copy() for c in clustered_by])] + ) + expressions = properties.expressions if properties else [] + expressions.append(clustered_by_exp) + properties = exp.Properties(expressions=expressions) + return properties + + def _build_column_defs( + self, + target_columns_to_types: t.Dict[str, exp.DataType], + column_descriptions: t.Optional[t.Dict[str, str]] = None, + is_view: bool = False, + materialized: bool = False, + ) -> t.List[exp.ColumnDef]: + # Databricks requires column types to be specified when adding column comments + # in CREATE MATERIALIZED VIEW statements. Override is_view to False to force + # column types to be included when comments are present. 
+ if is_view and materialized and column_descriptions: + is_view = False + + return super()._build_column_defs( + target_columns_to_types, column_descriptions, is_view, materialized + ) diff --git a/sqlmesh/core/engine_adapter/duckdb.py b/sqlmesh/core/engine_adapter/duckdb.py index aa7c0e0db6..3b057219e0 100644 --- a/sqlmesh/core/engine_adapter/duckdb.py +++ b/sqlmesh/core/engine_adapter/duckdb.py @@ -1,12 +1,13 @@ from __future__ import annotations import typing as t -from duckdb import __version__ as duckdb_version from sqlglot import exp +from pathlib import Path from sqlmesh.core.engine_adapter.mixins import ( GetCurrentCatalogFromFunctionMixin, LogicalMergeMixin, + RowDiffMixin, ) from sqlmesh.core.engine_adapter.shared import ( CatalogSupport, @@ -17,8 +18,6 @@ SourceQuery, set_catalog, ) -from sqlmesh.utils import major_minor -from sqlmesh.core.schema_diff import SchemaDiffer if t.TYPE_CHECKING: from sqlmesh.core._typing import SchemaName, TableName @@ -26,44 +25,73 @@ @set_catalog(override_mapping={"_get_data_objects": CatalogSupport.REQUIRES_SET_CATALOG}) -class DuckDBEngineAdapter(LogicalMergeMixin, GetCurrentCatalogFromFunctionMixin): +class DuckDBEngineAdapter(LogicalMergeMixin, GetCurrentCatalogFromFunctionMixin, RowDiffMixin): DIALECT = "duckdb" SUPPORTS_TRANSACTIONS = False - CATALOG_SUPPORT = CatalogSupport.FULL_SUPPORT - SCHEMA_DIFFER = SchemaDiffer( - parameterized_type_defaults={ + SCHEMA_DIFFER_KWARGS = { + "parameterized_type_defaults": { exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 3), (0,)], }, - ) + } + COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY + COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY + SUPPORTS_CREATE_DROP_CATALOG = True + SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA", "TABLE", "VIEW"] - # TODO: remove once we stop supporting DuckDB 0.9 - COMMENT_CREATION_TABLE, COMMENT_CREATION_VIEW = ( - (CommentCreationTable.UNSUPPORTED, CommentCreationView.UNSUPPORTED) - if 
major_minor(duckdb_version) < (0, 10) - else (CommentCreationTable.COMMENT_COMMAND_ONLY, CommentCreationView.COMMENT_COMMAND_ONLY) - ) + @property + def catalog_support(self) -> CatalogSupport: + return CatalogSupport.FULL_SUPPORT def set_current_catalog(self, catalog: str) -> None: """Sets the catalog name of the current connection.""" self.execute(exp.Use(this=exp.to_identifier(catalog))) + def _create_catalog(self, catalog_name: exp.Identifier) -> None: + if not self._is_motherduck: + db_filename = f"{catalog_name.output_name}.db" + self.execute( + exp.Attach( + this=exp.alias_(exp.Literal.string(db_filename), catalog_name), exists=True + ) + ) + else: + self.execute( + exp.Create(this=exp.Table(this=catalog_name), kind="DATABASE", exists=True) + ) + + def _drop_catalog(self, catalog_name: exp.Identifier) -> None: + if not self._is_motherduck: + db_file_path = Path(f"{catalog_name.output_name}.db") + self.execute(exp.Detach(this=catalog_name, exists=True)) + if db_file_path.exists(): + db_file_path.unlink() + else: + self.execute( + exp.Drop( + this=exp.Table(this=catalog_name), kind="DATABASE", cascade=True, exists=True + ) + ) + def _df_to_source_queries( self, df: DF, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, + source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: temp_table = self._get_temp_table(target_table) temp_table_sql = ( - exp.select(*self._casted_columns(columns_to_types)) + exp.select(*self._casted_columns(target_columns_to_types, source_columns)) .from_("df") .sql(dialect=self.dialect) ) self.cursor.sql(f"CREATE TABLE {temp_table} AS {temp_table_sql}") return [ SourceQuery( - query_factory=lambda: self._select_columns(columns_to_types).from_(temp_table), # type: ignore + query_factory=lambda: self._select_columns(target_columns_to_types).from_( + temp_table + ), # type: ignore cleanup_func=lambda: self.drop_table(temp_table), ) 
] @@ -99,7 +127,7 @@ def _get_data_objects( ) .as_("type"), ) - .from_(exp.to_table("information_schema.tables")) + .from_(exp.to_table("system.information_schema.tables")) .where( exp.column("table_catalog").eq(catalog), exp.column("table_schema").eq(schema_name) ) @@ -116,3 +144,77 @@ def _get_data_objects( ) for row in df.itertuples() ] + + def _normalize_decimal_value(self, col: exp.Expression, precision: int) -> exp.Expression: + """ + duckdb truncates instead of rounding when casting to decimal. + + other databases: select cast(3.14159 as decimal(38,3)) -> 3.142 + duckdb: select cast(3.14159 as decimal(38,3)) -> 3.141 + + however, we can get the behaviour of other databases by casting to double first: + select cast(cast(3.14159 as double) as decimal(38, 3)) -> 3.142 + """ + return exp.cast( + exp.cast(col, "DOUBLE"), + f"DECIMAL(38, {precision})", + ) + + def _create_table( + self, + table_name_or_schema: t.Union[exp.Schema, TableName], + expression: t.Optional[exp.Expression], + exists: bool = True, + replace: bool = False, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + table_description: t.Optional[str] = None, + column_descriptions: t.Optional[t.Dict[str, str]] = None, + table_kind: t.Optional[str] = None, + track_rows_processed: bool = True, + **kwargs: t.Any, + ) -> None: + catalog = self.get_current_catalog() + catalog_type_tuple = self.fetchone( + exp.select("type") + .from_("duckdb_databases()") + .where(exp.column("database_name").eq(catalog)) + ) + catalog_type = catalog_type_tuple[0] if catalog_type_tuple else None + + partitioned_by_exps = None + if catalog_type == "ducklake": + partitioned_by_exps = kwargs.pop("partitioned_by", None) + + super()._create_table( + table_name_or_schema, + expression, + exists, + replace, + target_columns_to_types, + table_description, + column_descriptions, + table_kind, + track_rows_processed=track_rows_processed, + **kwargs, + ) + + if partitioned_by_exps: + # Schema object contains 
column definitions, so we extract Table + table_name = ( + table_name_or_schema.this + if isinstance(table_name_or_schema, exp.Schema) + else table_name_or_schema + ) + table_name_str = ( + table_name.sql(dialect=self.dialect) + if isinstance(table_name, exp.Table) + else table_name + ) + partitioned_by_str = ", ".join( + expr.sql(dialect=self.dialect) for expr in partitioned_by_exps + ) + self.execute(f"ALTER TABLE {table_name_str} SET PARTITIONED BY ({partitioned_by_str});") + + @property + def _is_motherduck(self) -> bool: + return self._extra_config.get("is_motherduck", False) diff --git a/sqlmesh/core/engine_adapter/fabric.py b/sqlmesh/core/engine_adapter/fabric.py new file mode 100644 index 0000000000..e1dffe88f4 --- /dev/null +++ b/sqlmesh/core/engine_adapter/fabric.py @@ -0,0 +1,429 @@ +from __future__ import annotations + +import typing as t +import logging +import requests +import time +from functools import cached_property +from sqlglot import exp +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_result +from sqlmesh.core.engine_adapter.mssql import MSSQLEngineAdapter +from sqlmesh.core.engine_adapter.shared import ( + InsertOverwriteStrategy, +) +from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.utils.connection_pool import ConnectionPool +from sqlmesh.core.schema_diff import TableAlterOperation +from sqlmesh.utils import random_id + + +logger = logging.getLogger(__name__) + + +class FabricEngineAdapter(MSSQLEngineAdapter): + """ + Adapter for Microsoft Fabric. 
+ """ + + DIALECT = "fabric" + SUPPORTS_INDEXES = False + SUPPORTS_TRANSACTIONS = False + SUPPORTS_CREATE_DROP_CATALOG = True + INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.DELETE_INSERT + + def __init__( + self, connection_factory_or_pool: t.Union[t.Callable, t.Any], *args: t.Any, **kwargs: t.Any + ) -> None: + # Wrap connection factory to support changing the catalog dynamically at runtime + if not isinstance(connection_factory_or_pool, ConnectionPool): + original_connection_factory = connection_factory_or_pool + + connection_factory_or_pool = lambda *args, **kwargs: original_connection_factory( + target_catalog=self._target_catalog, *args, **kwargs + ) + + super().__init__(connection_factory_or_pool, *args, **kwargs) + + @property + def _target_catalog(self) -> t.Optional[str]: + return self._connection_pool.get_attribute("target_catalog") + + @_target_catalog.setter + def _target_catalog(self, value: t.Optional[str]) -> None: + self._connection_pool.set_attribute("target_catalog", value) + + @property + def api_client(self) -> FabricHttpClient: + # the requests Session is not guaranteed to be threadsafe + # so we create a http client per thread on demand + if existing_client := self._connection_pool.get_attribute("api_client"): + return existing_client + + tenant_id: t.Optional[str] = self._extra_config.get("tenant_id") + workspace_id: t.Optional[str] = self._extra_config.get("workspace_id") + client_id: t.Optional[str] = self._extra_config.get("user") + client_secret: t.Optional[str] = self._extra_config.get("password") + + if not tenant_id or not client_id or not client_secret: + raise SQLMeshError( + "Service Principal authentication requires tenant_id, client_id, and client_secret " + "in the Fabric connection configuration" + ) + + if not workspace_id: + raise SQLMeshError( + "Fabric requires the workspace_id to be configured in the connection configuration to create / drop catalogs" + ) + + client = FabricHttpClient( + tenant_id=tenant_id, + 
workspace_id=workspace_id, + client_id=client_id, + client_secret=client_secret, + ) + + self._connection_pool.set_attribute("api_client", client) + return client + + def _create_catalog(self, catalog_name: exp.Identifier) -> None: + """Create a catalog (warehouse) in Microsoft Fabric via REST API.""" + warehouse_name = catalog_name.sql(dialect=self.dialect, identify=False) + logger.info(f"Creating Fabric warehouse: {warehouse_name}") + + self.api_client.create_warehouse(warehouse_name) + + def _drop_catalog(self, catalog_name: exp.Identifier) -> None: + """Drop a catalog (warehouse) in Microsoft Fabric via REST API.""" + warehouse_name = catalog_name.sql(dialect=self.dialect, identify=False) + current_catalog = self.get_current_catalog() + + logger.info(f"Deleting Fabric warehouse: {warehouse_name}") + self.api_client.delete_warehouse(warehouse_name) + + if warehouse_name == current_catalog: + # Somewhere around 2025-09-08, Fabric started validating the "Database=" connection argument and throwing 'Authentication failed' if the database doesnt exist + # In addition, set_current_catalog() is implemented using a threadlocal variable "target_catalog" + # So, when we drop a warehouse, and there are still threads with "target_catalog" set to reference it, any operations on those threads + # that use an either use an existing connection pointing to this warehouse or trigger a new connection + # will fail with an 'Authentication Failed' error unless we close all connections here, which also clears all the threadlocal data + self.close() + + def set_current_catalog(self, catalog_name: str) -> None: + """ + Set the current catalog for Microsoft Fabric connections. + + Override to handle Fabric's stateless session limitation where USE statements + don't persist across queries. Instead, we close existing connections and + recreate them with the new catalog in the connection configuration. 
+ + Args: + catalog_name: The name of the catalog (warehouse) to switch to + + Note: + Fabric doesn't support catalog switching via USE statements because each + statement runs as an independent session. This method works around this + limitation by updating the connection pool with new catalog configuration. + + See: + https://learn.microsoft.com/en-us/fabric/data-warehouse/sql-query-editor#limitations + """ + current_catalog = self.get_current_catalog() + + # If already using the requested catalog, do nothing + if current_catalog and current_catalog == catalog_name: + logger.debug(f"Already using catalog '{catalog_name}', no action needed") + return + + logger.info(f"Switching from catalog '{current_catalog}' to '{catalog_name}'") + + # commit the transaction before closing the connection to help prevent errors like: + # > Snapshot isolation transaction failed in database because the object accessed by the statement has been modified by a + # > DDL statement in another concurrent transaction since the start of this transaction + # on subsequent queries in the new connection + self._connection_pool.commit() + + # note: we call close() on the connection pool instead of self.close() because self.close() calls close_all() + # on the connection pool but we just want to close the connection for this thread + self._connection_pool.close() + self._target_catalog = catalog_name # new connections will use this catalog + + catalog_after_switch = self.get_current_catalog() + + if catalog_after_switch != catalog_name: + # We need to raise an error if the catalog switch failed to prevent the operation that needed the catalog switch from being run against the wrong catalog + raise SQLMeshError( + f"Unable to switch catalog to {catalog_name}, catalog ended up as {catalog_after_switch}" + ) + + def alter_table( + self, alter_expressions: t.Union[t.List[exp.Alter], t.List[TableAlterOperation]] + ) -> None: + """ + Applies alter expressions to a table. 
Fabric has limited support for ALTER TABLE, + so this method implements a workaround for column type changes. + This method is self-contained and sets its own catalog context. + """ + if not alter_expressions: + return + + # Get the target table from the first expression to determine the correct catalog. + first_op = alter_expressions[0] + expression = first_op.expression if isinstance(first_op, TableAlterOperation) else first_op + if not isinstance(expression, exp.Alter) or not expression.this.catalog: + # Fallback for unexpected scenarios + logger.warning( + "Could not determine catalog from alter expression, executing with current context." + ) + super().alter_table(alter_expressions) + return + + target_catalog = expression.this.catalog + self.set_current_catalog(target_catalog) + + with self.transaction(): + for op in alter_expressions: + expression = op.expression if isinstance(op, TableAlterOperation) else op + + if not isinstance(expression, exp.Alter): + self.execute(expression) + continue + + for action in expression.actions: + table_name = expression.this + + table_name_without_catalog = table_name.copy() + table_name_without_catalog.set("catalog", None) + + is_type_change = isinstance(action, exp.AlterColumn) and action.args.get( + "dtype" + ) + + if is_type_change: + column_to_alter = action.this + new_type = action.args["dtype"] + temp_column_name_str = f"{column_to_alter.name}__{random_id(short=True)}" + temp_column_name = exp.to_identifier(temp_column_name_str) + + logger.info( + "Applying workaround for column '%s' on table '%s' to change type to '%s'.", + column_to_alter.sql(), + table_name.sql(), + new_type.sql(), + ) + + # Step 1: Add a temporary column. + add_column_expr = exp.Alter( + this=table_name_without_catalog.copy(), + kind="TABLE", + actions=[ + exp.ColumnDef(this=temp_column_name.copy(), kind=new_type.copy()) + ], + ) + add_sql = self._to_sql(add_column_expr) + self.execute(add_sql) + + # Step 2: Copy and cast data. 
+ update_sql = self._to_sql( + exp.Update( + this=table_name_without_catalog.copy(), + expressions=[ + exp.EQ( + this=temp_column_name.copy(), + expression=exp.Cast( + this=column_to_alter.copy(), to=new_type.copy() + ), + ) + ], + ) + ) + self.execute(update_sql) + + # Step 3: Drop the original column. + drop_sql = self._to_sql( + exp.Alter( + this=table_name_without_catalog.copy(), + kind="TABLE", + actions=[exp.Drop(this=column_to_alter.copy(), kind="COLUMN")], + ) + ) + self.execute(drop_sql) + + # Step 4: Rename the temporary column. + old_name_qualified = f"{table_name_without_catalog.sql(dialect=self.dialect)}.{temp_column_name.sql(dialect=self.dialect)}" + new_name_unquoted = column_to_alter.sql( + dialect=self.dialect, identify=False + ) + rename_sql = f"EXEC sp_rename '{old_name_qualified}', '{new_name_unquoted}', 'COLUMN'" + self.execute(rename_sql) + else: + # For other alterations, execute directly. + direct_alter_expr = exp.Alter( + this=table_name_without_catalog.copy(), kind="TABLE", actions=[action] + ) + self.execute(direct_alter_expr) + + +class FabricHttpClient: + def __init__(self, tenant_id: str, workspace_id: str, client_id: str, client_secret: str): + self.tenant_id = tenant_id + self.client_id = client_id + self.client_secret = client_secret + self.workspace_id = workspace_id + + def create_warehouse( + self, warehouse_name: str, if_not_exists: bool = True, attempt: int = 0 + ) -> None: + """Create a catalog (warehouse) in Microsoft Fabric via REST API.""" + + # attempt count is arbitrary, it essentially equates to 5 minutes of 30 second waits + if attempt > 10: + raise SQLMeshError( + f"Gave up waiting for Fabric warehouse {warehouse_name} to become available" + ) + + logger.info(f"Creating Fabric warehouse: {warehouse_name}") + + request_data = { + "displayName": warehouse_name, + "description": f"Warehouse created by SQLMesh: {warehouse_name}", + } + + response = self.session.post(self._endpoint_url("warehouses"), json=request_data) + + 
if ( + if_not_exists + and response.status_code == 400 + and (errorCode := response.json().get("errorCode", None)) + ): + if errorCode == "ItemDisplayNameAlreadyInUse": + logger.warning(f"Fabric warehouse {warehouse_name} already exists") + return + if errorCode == "ItemDisplayNameNotAvailableYet": + logger.warning(f"Fabric warehouse {warehouse_name} is still spinning up; waiting") + # Fabric error message is something like: + # - "Requested 'circleci_51d7087e__dev' is not available yet and is expected to become available in the upcoming minutes." + # This seems to happen if a catalog is dropped and then a new one with the same name is immediately created. + # There appears to be some delayed async process on the Fabric side that actually drops the warehouses and frees up the names to be used again + time.sleep(30) + return self.create_warehouse( + warehouse_name=warehouse_name, if_not_exists=if_not_exists, attempt=attempt + 1 + ) + + try: + response.raise_for_status() + except: + # the important information to actually debug anything is in the response body which Requests never prints + logger.exception( + f"Failed to create warehouse {warehouse_name}. 
status: {response.status_code}, body: {response.text}" + ) + raise + + # Handle direct success (201) or async creation (202) + if response.status_code == 201: + logger.info(f"Successfully created Fabric warehouse: {warehouse_name}") + return + + if response.status_code == 202 and (location_header := response.headers.get("location")): + logger.info(f"Warehouse creation initiated for: {warehouse_name}") + self._wait_for_completion(location_header, warehouse_name) + logger.info(f"Successfully created Fabric warehouse: {warehouse_name}") + else: + logger.error(f"Unexpected response from Fabric API: {response}\n{response.text}") + raise SQLMeshError(f"Unable to create warehouse: {response}") + + def delete_warehouse(self, warehouse_name: str, if_exists: bool = True) -> None: + """Drop a catalog (warehouse) in Microsoft Fabric via REST API.""" + logger.info(f"Deleting Fabric warehouse: {warehouse_name}") + + # Get the warehouse ID by listing warehouses + # TODO: handle continuationUri for pagination, ref: https://learn.microsoft.com/en-us/rest/api/fabric/warehouse/items/list-warehouses?tabs=HTTP#warehouses + response = self.session.get(self._endpoint_url("warehouses")) + response.raise_for_status() + + warehouse_name_to_id = { + warehouse.get("displayName"): warehouse.get("id") + for warehouse in response.json().get("value", []) + } + + warehouse_id = warehouse_name_to_id.get(warehouse_name, None) + + if not warehouse_id: + logger.warning( + f"Fabric warehouse does not exist: {warehouse_name}\n(available warehouses: {', '.join(warehouse_name_to_id)})" + ) + if if_exists: + return + + raise SQLMeshError( + f"Unable to delete Fabric warehouse {warehouse_name} as it doesnt exist" + ) + + # Delete the warehouse by ID + response = self.session.delete(self._endpoint_url(f"warehouses/{warehouse_id}")) + response.raise_for_status() + + logger.info(f"Successfully deleted Fabric warehouse: {warehouse_name}") + + @cached_property + def session(self) -> requests.Session: + s = 
requests.Session() + + access_token = self._get_access_token() + s.headers.update({"Authorization": f"Bearer {access_token}"}) + + return s + + def _endpoint_url(self, endpoint: str) -> str: + if endpoint.startswith("/"): + endpoint = endpoint[1:] + + return f"https://api.fabric.microsoft.com/v1/workspaces/{self.workspace_id}/{endpoint}" + + def _get_access_token(self) -> str: + """Get access token using Service Principal authentication.""" + + # Use Azure AD OAuth2 token endpoint + token_url = f"https://login.microsoftonline.com/{self.tenant_id}/oauth2/v2.0/token" + + data = { + "grant_type": "client_credentials", + "client_id": self.client_id, + "client_secret": self.client_secret, + "scope": "https://api.fabric.microsoft.com/.default", + } + + response = requests.post(token_url, data=data) + response.raise_for_status() + token_data = response.json() + return token_data["access_token"] + + def _wait_for_completion(self, location_url: str, operation_name: str) -> None: + """Poll the operation status until completion.""" + + @retry( + wait=wait_exponential(multiplier=1, min=1, max=30), + stop=stop_after_attempt(20), + retry=retry_if_result(lambda result: result not in ["Succeeded", "Failed"]), + ) + def _poll() -> str: + response = self.session.get(location_url) + response.raise_for_status() + + result = response.json() + status = result.get("status", "Unknown") + + logger.debug(f"Operation {operation_name} status: {status}") + + if status == "Failed": + error_msg = result.get("error", {}).get("message", "Unknown error") + raise SQLMeshError(f"Operation {operation_name} failed: {error_msg}") + elif status in ["InProgress", "Running"]: + logger.debug(f"Operation {operation_name} still in progress...") + elif status not in ["Succeeded"]: + logger.warning(f"Unknown status '{status}' for operation {operation_name}") + + return status + + final_status = _poll() + if final_status != "Succeeded": + raise SQLMeshError(f"Operation {operation_name} completed with status: 
{final_status}") diff --git a/sqlmesh/core/engine_adapter/mixins.py b/sqlmesh/core/engine_adapter/mixins.py index 171d453c9d..c8ef32b9da 100644 --- a/sqlmesh/core/engine_adapter/mixins.py +++ b/sqlmesh/core/engine_adapter/mixins.py @@ -1,48 +1,58 @@ from __future__ import annotations +import abc import logging import typing as t +from dataclasses import dataclass -from sqlglot import exp +from sqlglot import exp, parse_one +from sqlglot.helper import seq_get +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlmesh.core.engine_adapter.base import EngineAdapter -from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, SourceQuery +from sqlmesh.core.engine_adapter.shared import DataObjectType from sqlmesh.core.node import IntervalUnit +from sqlmesh.core.dialect import schema_ +from sqlmesh.core.schema_diff import TableAlterOperation from sqlmesh.utils.errors import SQLMeshError if t.TYPE_CHECKING: from sqlmesh.core._typing import TableName - from sqlmesh.core.engine_adapter._typing import DF + from sqlmesh.core.engine_adapter._typing import ( + DCL, + DF, + GrantsConfig, + QueryOrDF, + ) from sqlmesh.core.engine_adapter.base import QueryOrDF logger = logging.getLogger(__name__) +NORMALIZED_DATE_FORMAT = "%Y-%m-%d" +NORMALIZED_TIMESTAMP_FORMAT = "%Y-%m-%d %H:%M:%S.%f" + class LogicalMergeMixin(EngineAdapter): def merge( self, target_table: TableName, source_table: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], unique_key: t.Sequence[exp.Expression], - when_matched: t.Optional[exp.When] = None, + when_matched: t.Optional[exp.Whens] = None, + merge_filter: t.Optional[exp.Expression] = None, + source_columns: t.Optional[t.List[str]] = None, + **kwargs: t.Any, ) -> None: - """ - Merge implementation for engine adapters that do not support merge natively. - - The merge is executed as follows: - 1. 
Create a temporary table containing the new data to merge. - 2. Delete rows from target table where unique_key cols match a row in the temporary table. - 3. Insert the temporary table contents into the target table. Any duplicate, non-unique rows - within the temporary table are ommitted. - 4. Drop the temporary table. - """ - if when_matched: - raise SQLMeshError( - "This engine does not support MERGE expressions and therefore `when_matched` is not supported." - ) - self._replace_by_key( - target_table, source_table, columns_to_types, unique_key, is_unique_key=True + logical_merge( + self, + target_table, + source_table, + target_columns_to_types, + unique_key, + when_matched=when_matched, + merge_filter=merge_filter, + source_columns=source_columns, ) @@ -71,100 +81,75 @@ def _fetch_native_df( return df -class InsertOverwriteWithMergeMixin(EngineAdapter): - def _insert_overwrite_by_condition( - self, - table_name: TableName, - source_queries: t.List[SourceQuery], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, - where: t.Optional[exp.Condition] = None, - insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None, - ) -> None: - """ - Some engines do not support `INSERT OVERWRITE` but instead support - doing an "INSERT OVERWRITE" using a Merge expression but with the - predicate being `False`. 
- """ - columns_to_types = columns_to_types or self.columns(table_name) - for source_query in source_queries: - with source_query as query: - query = self._order_projections_and_filter(query, columns_to_types, where=where) - columns = [exp.to_column(col) for col in columns_to_types] - when_not_matched_by_source = exp.When( - matched=False, - source=True, - condition=where, - then=exp.Delete(), - ) - when_not_matched_by_target = exp.When( - matched=False, - source=False, - then=exp.Insert( - this=exp.Tuple(expressions=columns), - expression=exp.Tuple(expressions=columns), - ), - ) - self._merge( - target_table=table_name, - query=query, - on=exp.false(), - match_expressions=[when_not_matched_by_source, when_not_matched_by_target], - ) - - class HiveMetastoreTablePropertiesMixin(EngineAdapter): MAX_TABLE_COMMENT_LENGTH = 4000 MAX_COLUMN_COMMENT_LENGTH = 4000 + def _build_partitioned_by_exp( + self, + partitioned_by: t.List[exp.Expression], + *, + catalog_name: t.Optional[str] = None, + **kwargs: t.Any, + ) -> t.Union[exp.PartitionedByProperty, exp.Property]: + if ( + self.dialect == "trino" + and self.get_catalog_type(catalog_name or self.get_current_catalog()) == "iceberg" + ): + # On the Trino Iceberg catalog, the table property is called "partitioning" - not "partitioned_by" + # In addition, partition column transform expressions like `day(col)` or `bucket(col, 5)` are allowed + # Also, column names and transforms need to be strings and supplied as an ARRAY[varchar] + # ref: https://trino.io/docs/current/connector/iceberg.html#table-properties + return exp.Property( + this=exp.var("PARTITIONING"), + value=exp.array( + *(exp.Literal.string(e.sql(dialect=self.dialect)) for e in partitioned_by) + ), + ) + for expr in partitioned_by: + if not isinstance(expr, exp.Column): + raise SQLMeshError( + f"PARTITIONED BY contains non-column value '{expr.sql(dialect=self.dialect)}'." 
+ ) + return exp.PartitionedByProperty( + this=exp.Schema(expressions=partitioned_by), + ) + def _build_table_properties_exp( self, catalog_name: t.Optional[str] = None, + table_format: t.Optional[str] = None, storage_format: t.Optional[str] = None, partitioned_by: t.Optional[t.List[exp.Expression]] = None, partition_interval_unit: t.Optional[IntervalUnit] = None, - clustered_by: t.Optional[t.List[str]] = None, + clustered_by: t.Optional[t.List[exp.Expression]] = None, table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, + **kwargs: t.Any, ) -> t.Optional[exp.Properties]: properties: t.List[exp.Expression] = [] - if storage_format: - properties.append(exp.FileFormatProperty(this=exp.Var(this=storage_format))) - - if partitioned_by: - if ( - self.dialect == "trino" - and self.get_catalog_type(catalog_name or self.get_current_catalog()) == "iceberg" - ): - # On the Trino Iceberg catalog, the table property is called "partitioning" - not "partitioned_by" - # In addition, partition column transform expressions like `day(col)` or `bucket(col, 5)` are allowed - # Also, column names and transforms need to be strings and supplied as an ARRAY[varchar] - # ref: https://trino.io/docs/current/connector/iceberg.html#table-properties + if table_format and self.dialect == "spark": + properties.append(exp.FileFormatProperty(this=exp.Var(this=table_format))) + if storage_format: properties.append( exp.Property( - this=exp.var("PARTITIONING"), - value=exp.array( - *( - exp.Literal.string(e.sql(dialect=self.dialect)) - for e in partitioned_by - ) - ), + this="write.format.default", value=exp.Literal.string(storage_format) ) ) - else: - for expr in partitioned_by: - if not isinstance(expr, exp.Column): - raise SQLMeshError( - f"PARTITIONED BY contains 
non-column value '{expr.sql(dialect=self.dialect)}'." - ) + elif storage_format: + properties.append(exp.FileFormatProperty(this=exp.Var(this=storage_format))) - properties.append( - exp.PartitionedByProperty( - this=exp.Schema(expressions=partitioned_by), - ) + if partitioned_by: + properties.append( + self._build_partitioned_by_exp( + partitioned_by, + partition_interval_unit=partition_interval_unit, + catalog_name=catalog_name, ) + ) if table_description: properties.append( @@ -183,6 +168,7 @@ def _build_view_properties_exp( self, view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, table_description: t.Optional[str] = None, + **kwargs: t.Any, ) -> t.Optional[exp.Properties]: """Creates a SQLGlot table properties expression for view""" properties: t.List[exp.Expression] = [] @@ -202,7 +188,7 @@ def _build_view_properties_exp( def _truncate_comment(self, comment: str, length: t.Optional[int]) -> str: # iceberg and delta do not have a comment length limit - if self.current_catalog_type in ("iceberg", "delta"): + if self.current_catalog_type in ("iceberg", "delta_lake"): return comment return super()._truncate_comment(comment, length) @@ -233,9 +219,9 @@ def _default_precision_to_max( ) -> t.Dict[str, exp.DataType]: # get default lengths for types that support "max" length types_with_max_default_param = { - k: [self.SCHEMA_DIFFER.parameterized_type_defaults[k][0][0]] - for k in self.SCHEMA_DIFFER.max_parameter_length - if k in self.SCHEMA_DIFFER.parameterized_type_defaults + k: [self.schema_differ.parameterized_type_defaults[k][0][0]] + for k in self.schema_differ.max_parameter_length + if k in self.schema_differ.parameterized_type_defaults } # Redshift and MSSQL have a bug where CTAS statements have non-deterministic types. If a LIMIT @@ -244,7 +230,7 @@ def _default_precision_to_max( # and supports "max" length, we convert it to "max" length to prevent inadvertent data truncation. 
for col_name, col_type in columns_to_types.items(): if col_type.this in types_with_max_default_param and col_type.expressions: - parameter = self.SCHEMA_DIFFER.get_type_parameters(col_type) + parameter = self.schema_differ.get_type_parameters(col_type) type_default = types_with_max_default_param[col_type.this] if parameter == type_default: col_type.set("expressions", [exp.DataTypeParam(this=exp.var("max"))]) @@ -257,7 +243,7 @@ def _build_create_table_exp( expression: t.Optional[exp.Expression], exists: bool = True, replace: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, **kwargs: t.Any, @@ -267,7 +253,7 @@ def _build_create_table_exp( expression=expression, exists=exists, replace=replace, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, table_description=table_description, table_kind=table_kind, **kwargs, @@ -279,17 +265,20 @@ def _build_create_table_exp( and statement.expression.args["limit"].expression.this == "0" ): assert not isinstance(table_name_or_schema, exp.Schema) + # redshift and mssql have a bug where CTAS statements have non determistic types. if a limit # is applied to a ctas statement, VARCHAR types default to 1 in some instances. 
select_statement = statement.expression.copy() for select_or_union in select_statement.find_all(exp.Select, exp.SetOperation): - select_or_union.set("limit", None) + limit = select_or_union.args.get("limit") + if limit is not None and limit.expression.this == "0": + limit.pop() + select_or_union.set("where", None) temp_view_name = self._get_temp_table("ctas") - self.create_view( - temp_view_name, select_statement, replace=False, no_schema_binding=False - ) + + self.create_view(temp_view_name, select_statement, replace=False) try: columns_to_types_from_view = self._default_precision_to_max( self.columns(temp_view_name) @@ -304,7 +293,7 @@ def _build_create_table_exp( None, exists=exists, replace=replace, - columns_to_types=columns_to_types_from_view, + target_columns_to_types=columns_to_types_from_view, table_description=table_description, **kwargs, ) @@ -312,3 +301,391 @@ def _build_create_table_exp( self.drop_view(temp_view_name) return statement + + +@dataclass(frozen=True) +class TableAlterClusterByOperation(TableAlterOperation, abc.ABC): + pass + + +@dataclass(frozen=True) +class TableAlterChangeClusterKeyOperation(TableAlterClusterByOperation): + clustering_key: str + dialect: str + + @property + def is_additive(self) -> bool: + return False + + @property + def is_destructive(self) -> bool: + return False + + @property + def _alter_actions(self) -> t.List[exp.Expression]: + return [exp.Cluster(expressions=self.cluster_key_expressions)] + + @property + def cluster_key_expressions(self) -> t.List[exp.Expression]: + # Note: Assumes `clustering_key` as a string like: + # - "(col_a)" + # - "(col_a, col_b)" + # - "func(col_a, transform(col_b))" + parsed_cluster_key = parse_one(self.clustering_key, dialect=self.dialect) + return parsed_cluster_key.expressions or [parsed_cluster_key.this] + + +@dataclass(frozen=True) +class TableAlterDropClusterKeyOperation(TableAlterClusterByOperation): + @property + def is_additive(self) -> bool: + return False + + @property + def 
is_destructive(self) -> bool: + return False + + @property + def _alter_actions(self) -> t.List[exp.Expression]: + return [exp.Command(this="DROP", expression="CLUSTERING KEY")] + + +class ClusteredByMixin(EngineAdapter): + def _build_clustered_by_exp( + self, + clustered_by: t.List[exp.Expression], + **kwargs: t.Any, + ) -> t.Optional[exp.Cluster]: + return exp.Cluster(expressions=[c.copy() for c in clustered_by]) + + def get_alter_operations( + self, + current_table_name: TableName, + target_table_name: TableName, + *, + ignore_destructive: bool = False, + ignore_additive: bool = False, + ) -> t.List[TableAlterOperation]: + operations = super().get_alter_operations( + current_table_name, + target_table_name, + ignore_destructive=ignore_destructive, + ignore_additive=ignore_additive, + ) + + # check for a change in clustering + current_table = exp.to_table(current_table_name) + target_table = exp.to_table(target_table_name) + + current_table_schema = schema_(current_table.db, catalog=current_table.catalog) + target_table_schema = schema_(target_table.db, catalog=target_table.catalog) + + current_table_info = seq_get( + self.get_data_objects(current_table_schema, {current_table.name}), 0 + ) + target_table_info = seq_get( + self.get_data_objects(target_table_schema, {target_table.name}), 0 + ) + + if current_table_info and target_table_info: + if target_table_info.is_clustered: + if target_table_info.clustering_key and ( + current_table_info.clustering_key != target_table_info.clustering_key + ): + operations.append( + TableAlterChangeClusterKeyOperation( + target_table=current_table, + clustering_key=target_table_info.clustering_key, + dialect=self.dialect, + ) + ) + elif current_table_info.is_clustered: + operations.append(TableAlterDropClusterKeyOperation(target_table=current_table)) + + return operations + + +def logical_merge( + engine_adapter: EngineAdapter, + target_table: TableName, + source_table: QueryOrDF, + target_columns_to_types: 
t.Optional[t.Dict[str, exp.DataType]], + unique_key: t.Sequence[exp.Expression], + when_matched: t.Optional[exp.Whens] = None, + merge_filter: t.Optional[exp.Expression] = None, + source_columns: t.Optional[t.List[str]] = None, +) -> None: + """ + Merge implementation for engine adapters that do not support merge natively. + + The merge is executed as follows: + 1. Create a temporary table containing the new data to merge. + 2. Delete rows from target table where unique_key cols match a row in the temporary table. + 3. Insert the temporary table contents into the target table. Any duplicate, non-unique rows + within the temporary table are ommitted. + 4. Drop the temporary table. + """ + if when_matched or merge_filter: + prop = "when_matched" if when_matched else "merge_filter" + raise SQLMeshError( + f"This engine does not support MERGE expressions and therefore `{prop}` is not supported." + ) + + engine_adapter._replace_by_key( + target_table, + source_table, + target_columns_to_types, + unique_key, + is_unique_key=True, + source_columns=source_columns, + ) + + +class RowDiffMixin(EngineAdapter): + # The maximum supported value for n in timestamp(n). 
+ # Most databases are microsecond (6) but some can only handle millisecond (3) while others go to nanosecond (9) + MAX_TIMESTAMP_PRECISION = 6 + + def concat_columns( + self, + columns_to_types: t.Dict[str, exp.DataType], + decimal_precision: int = 3, + timestamp_precision: int = MAX_TIMESTAMP_PRECISION, + delimiter: str = ",", + ) -> exp.Expression: + """ + Produce an expression that generates a string version of a record, that is: + - Every column converted to a string representation, joined together into a single string using the specified :delimiter + """ + expressions_to_concat: t.List[exp.Expression] = [] + for idx, (column, type) in enumerate(columns_to_types.items()): + expressions_to_concat.append( + exp.func( + "COALESCE", + self.normalize_value( + exp.to_column(column), type, decimal_precision, timestamp_precision + ), + exp.Literal.string(""), + ) + ) + if idx < len(columns_to_types) - 1: + expressions_to_concat.append(exp.Literal.string(delimiter)) + + return exp.func("CONCAT", *expressions_to_concat) + + def normalize_value( + self, + expr: exp.Expression, + type: exp.DataType, + decimal_precision: int = 3, + timestamp_precision: int = MAX_TIMESTAMP_PRECISION, + ) -> exp.Expression: + """ + Return an expression that converts the values inside the column `col` to a normalized string + + This string should be comparable across database engines, eg: + - `date` columns -> YYYY-MM-DD string + - `datetime`/`timestamp`/`timestamptz` columns -> ISO-8601 string to :timestamp_precision digits of subsecond precision + - `float` / `double` / `decimal` -> Value formatted to :decimal_precision decimal places + - `boolean` columns -> '1' or '0' + - NULLS -> "" (empty string) + """ + if type.is_type(exp.DataType.Type.BOOLEAN): + value = self._normalize_boolean_value(expr) + elif type.is_type(*exp.DataType.INTEGER_TYPES): + value = self._normalize_integer_value(expr) + elif type.is_type(*exp.DataType.REAL_TYPES): + # If there is no scale on the decimal type, treat it 
like an integer when comparing + # Some databases like Snowflake deliberately create all integer types as NUMERIC(, 0) + # and they should be treated as integers and not decimals + type_params = list(type.find_all(exp.DataTypeParam)) + if len(type_params) == 2 and type_params[-1].this.to_py() == 0: + value = self._normalize_integer_value(expr) + else: + value = self._normalize_decimal_value(expr, decimal_precision) + elif type.is_type(*exp.DataType.TEMPORAL_TYPES): + value = self._normalize_timestamp_value(expr, type, timestamp_precision) + elif type.is_type(*exp.DataType.NESTED_TYPES): + value = self._normalize_nested_value(expr) + else: + value = expr + + return exp.cast(value, to=exp.DataType.build("VARCHAR")) + + def _normalize_nested_value(self, expr: exp.Expression) -> exp.Expression: + return expr + + def _normalize_timestamp_value( + self, expr: exp.Expression, type: exp.DataType, precision: int + ) -> exp.Expression: + if precision > self.MAX_TIMESTAMP_PRECISION: + raise ValueError( + f"Requested timestamp precision '{precision}' exceeds maximum supported precision: {self.MAX_TIMESTAMP_PRECISION}" + ) + + is_date = type.is_type(exp.DataType.Type.DATE, exp.DataType.Type.DATE32) + + format = NORMALIZED_DATE_FORMAT if is_date else NORMALIZED_TIMESTAMP_FORMAT + + if type.is_type( + exp.DataType.Type.TIMESTAMPTZ, + exp.DataType.Type.TIMESTAMPLTZ, + exp.DataType.Type.TIMESTAMPNTZ, + ): + # Convert all timezone-aware values to UTC for comparison + expr = exp.AtTimeZone(this=expr, zone=exp.Literal.string("UTC")) + + digits_to_chop_off = ( + 6 - precision + ) # 6 = max precision across all adapters and also the max amount of digits TimeToStr will render since its based on `strftime` and `%f` only renders to microseconds + + expr = exp.TimeToStr(this=expr, format=exp.Literal.string(format)) + if digits_to_chop_off > 0: + expr = exp.func( + "SUBSTRING", expr, 1, len("2023-01-01 12:13:14.000000") - digits_to_chop_off + ) + + return expr + + def 
_normalize_integer_value(self, expr: exp.Expression) -> exp.Expression: + return exp.cast(expr, "BIGINT") + + def _normalize_decimal_value(self, expr: exp.Expression, precision: int) -> exp.Expression: + return exp.cast(expr, f"DECIMAL(38,{precision})") + + def _normalize_boolean_value(self, expr: exp.Expression) -> exp.Expression: + return exp.cast(expr, "INT") + + +class GrantsFromInfoSchemaMixin(EngineAdapter): + CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("current_user") + SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = False + USE_CATALOG_IN_GRANTS = False + GRANT_INFORMATION_SCHEMA_TABLE_NAME = "table_privileges" + + @staticmethod + @abc.abstractmethod + def _grant_object_kind(table_type: DataObjectType) -> t.Optional[str]: + pass + + @abc.abstractmethod + def _get_current_schema(self) -> str: + pass + + def _dcl_grants_config_expr( + self, + dcl_cmd: t.Type[DCL], + table: exp.Table, + grants_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + expressions: t.List[exp.Expression] = [] + if not grants_config: + return expressions + + object_kind = self._grant_object_kind(table_type) + for privilege, principals in grants_config.items(): + args: t.Dict[str, t.Any] = { + "privileges": [exp.GrantPrivilege(this=exp.Var(this=privilege))], + "securable": table.copy(), + } + if object_kind: + args["kind"] = exp.Var(this=object_kind) + if self.SUPPORTS_MULTIPLE_GRANT_PRINCIPALS: + args["principals"] = [ + normalize_identifiers( + parse_one(principal, into=exp.GrantPrincipal, dialect=self.dialect), + dialect=self.dialect, + ) + for principal in principals + ] + expressions.append(dcl_cmd(**args)) # type: ignore[arg-type] + else: + for principal in principals: + args["principals"] = [ + normalize_identifiers( + parse_one(principal, into=exp.GrantPrincipal, dialect=self.dialect), + dialect=self.dialect, + ) + ] + expressions.append(dcl_cmd(**args)) # type: ignore[arg-type] + + return expressions + + def 
_apply_grants_config_expr( + self, + table: exp.Table, + grants_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + return self._dcl_grants_config_expr(exp.Grant, table, grants_config, table_type) + + def _revoke_grants_config_expr( + self, + table: exp.Table, + grants_config: GrantsConfig, + table_type: DataObjectType = DataObjectType.TABLE, + ) -> t.List[exp.Expression]: + return self._dcl_grants_config_expr(exp.Revoke, table, grants_config, table_type) + + def _get_grant_expression(self, table: exp.Table) -> exp.Expression: + schema_identifier = table.args.get("db") or normalize_identifiers( + exp.to_identifier(self._get_current_schema(), quoted=True), dialect=self.dialect + ) + schema_name = schema_identifier.this + table_name = table.args.get("this").this # type: ignore + + grant_conditions = [ + exp.column("table_schema").eq(exp.Literal.string(schema_name)), + exp.column("table_name").eq(exp.Literal.string(table_name)), + exp.column("grantor").eq(self.CURRENT_USER_OR_ROLE_EXPRESSION), + exp.column("grantee").neq(self.CURRENT_USER_OR_ROLE_EXPRESSION), + ] + + info_schema_table = normalize_identifiers( + exp.table_(self.GRANT_INFORMATION_SCHEMA_TABLE_NAME, db="information_schema"), + dialect=self.dialect, + ) + if self.USE_CATALOG_IN_GRANTS: + catalog_identifier = table.args.get("catalog") + if not catalog_identifier: + catalog_name = self.get_current_catalog() + if not catalog_name: + raise SQLMeshError( + "Current catalog could not be determined for fetching grants. This is unexpected." 
+ ) + catalog_identifier = normalize_identifiers( + exp.to_identifier(catalog_name, quoted=True), dialect=self.dialect + ) + catalog_name = catalog_identifier.this + info_schema_table.set("catalog", catalog_identifier.copy()) + grant_conditions.insert( + 0, exp.column("table_catalog").eq(exp.Literal.string(catalog_name)) + ) + + return ( + exp.select("privilege_type", "grantee") + .from_(info_schema_table) + .where(exp.and_(*grant_conditions)) + ) + + def _get_current_grants_config(self, table: exp.Table) -> GrantsConfig: + grant_expr = self._get_grant_expression(table) + + results = self.fetchall(grant_expr) + + grants_dict: GrantsConfig = {} + for privilege_raw, grantee_raw in results: + if privilege_raw is None or grantee_raw is None: + continue + + privilege = str(privilege_raw) + grantee = str(grantee_raw) + if not privilege or not grantee: + continue + + grantees = grants_dict.setdefault(privilege, []) + if grantee not in grantees: + grantees.append(grantee) + + return grants_dict diff --git a/sqlmesh/core/engine_adapter/mssql.py b/sqlmesh/core/engine_adapter/mssql.py index a632f01d72..359d1f0818 100644 --- a/sqlmesh/core/engine_adapter/mssql.py +++ b/sqlmesh/core/engine_adapter/mssql.py @@ -3,19 +3,24 @@ from __future__ import annotations import typing as t +import logging -import numpy as np -import pandas as pd -from pandas.api.types import is_datetime64_any_dtype # type: ignore from sqlglot import exp -from sqlmesh.core.dialect import to_schema -from sqlmesh.core.engine_adapter.base import EngineAdapterWithIndexSupport +from sqlmesh.core.dialect import to_schema, add_table +from sqlmesh.core.engine_adapter.base import ( + EngineAdapterWithIndexSupport, + EngineAdapter, + InsertOverwriteStrategy, + MERGE_SOURCE_ALIAS, + MERGE_TARGET_ALIAS, + _get_data_object_cache_key, +) from sqlmesh.core.engine_adapter.mixins import ( GetCurrentCatalogFromFunctionMixin, - InsertOverwriteWithMergeMixin, PandasNativeFetchDFSupportMixin, VarcharSizeWorkaroundMixin, + 
RowDiffMixin, ) from sqlmesh.core.engine_adapter.shared import ( CatalogSupport, @@ -26,31 +31,35 @@ SourceQuery, set_catalog, ) -from sqlmesh.core.schema_diff import SchemaDiffer +from sqlmesh.utils import get_source_columns_to_types if t.TYPE_CHECKING: from sqlmesh.core._typing import SchemaName, TableName - from sqlmesh.core.engine_adapter._typing import DF, Query + from sqlmesh.core.engine_adapter._typing import DF, Query, QueryOrDF + + +logger = logging.getLogger(__name__) @set_catalog() class MSSQLEngineAdapter( EngineAdapterWithIndexSupport, PandasNativeFetchDFSupportMixin, - InsertOverwriteWithMergeMixin, GetCurrentCatalogFromFunctionMixin, VarcharSizeWorkaroundMixin, + RowDiffMixin, ): DIALECT: str = "tsql" SUPPORTS_TUPLE_IN = False SUPPORTS_MATERIALIZED_VIEWS = False - CATALOG_SUPPORT = CatalogSupport.REQUIRES_SET_CATALOG CURRENT_CATALOG_EXPRESSION = exp.func("db_name") COMMENT_CREATION_TABLE = CommentCreationTable.UNSUPPORTED COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED SUPPORTS_REPLACE_TABLE = False - SCHEMA_DIFFER = SchemaDiffer( - parameterized_type_defaults={ + MAX_IDENTIFIER_LENGTH = 128 + SUPPORTS_QUERY_EXECUTION_TRACKING = True + SCHEMA_DIFFER_KWARGS = { + "parameterized_type_defaults": { exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 0), (0,)], exp.DataType.build("BINARY", dialect=DIALECT).this: [(1,)], exp.DataType.build("VARBINARY", dialect=DIALECT).this: [(1,)], @@ -62,12 +71,21 @@ class MSSQLEngineAdapter( exp.DataType.build("DATETIME2", dialect=DIALECT).this: [(7,)], exp.DataType.build("DATETIMEOFFSET", dialect=DIALECT).this: [(7,)], }, - max_parameter_length={ + "max_parameter_length": { exp.DataType.build("VARBINARY", dialect=DIALECT).this: 2147483647, # 2 GB exp.DataType.build("VARCHAR", dialect=DIALECT).this: 2147483647, exp.DataType.build("NVARCHAR", dialect=DIALECT).this: 2147483647, }, - ) + } + VARIABLE_LENGTH_DATA_TYPES = {"binary", "varbinary", "char", "varchar", "nchar", "nvarchar"} + 
INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE + + @property + def catalog_support(self) -> CatalogSupport: + # MSSQL and AzureSQL both use this engine adapter, but they differ in catalog support. + # Therefore, we specify the catalog support in the connection config `_extra_engine_config` + # instead of in the adapter itself. + return self._extra_config["catalog_support"] def columns( self, @@ -80,38 +98,49 @@ def columns( sql = ( exp.select( - "column_name", - "data_type", - "character_maximum_length", - "numeric_precision", - "numeric_scale", + "COLUMN_NAME", + "DATA_TYPE", + "CHARACTER_MAXIMUM_LENGTH", + "NUMERIC_PRECISION", + "NUMERIC_SCALE", ) - .from_("information_schema.columns") - .where(f"table_name = '{table.name}'") + .from_("INFORMATION_SCHEMA.COLUMNS") + .where(f"TABLE_NAME = '{table.name}'") ) database_name = table.db if database_name: - sql = sql.where(f"table_schema = '{database_name}'") + sql = sql.where(f"TABLE_SCHEMA = '{database_name}'") columns_raw = self.fetchall(sql, quote_identifiers=True) - def build_var_length_col(row: tuple) -> tuple: - var_len_chars = ("binary", "varbinary", "char", "varchar", "nchar", "nvarchar") - if row[1] in var_len_chars and row[2] > 0: - return (row[0], f"{row[1]}({row[2]})") - if row[1] in ("varbinary", "varchar", "nvarchar") and row[2] == -1: - return (row[0], f"{row[1]}(max)") - if row[1] in ( - "decimal", - "numeric", + def build_var_length_col( + column_name: str, + data_type: str, + character_maximum_length: t.Optional[int] = None, + numeric_precision: t.Optional[int] = None, + numeric_scale: t.Optional[int] = None, + ) -> tuple: + data_type = data_type.lower() + if ( + data_type in self.VARIABLE_LENGTH_DATA_TYPES + and character_maximum_length is not None + and character_maximum_length > 0 ): - return (row[0], f"{row[1]}({row[3]}, {row[4]})") - if row[1] == "float": - return (row[0], f"{row[1]}({row[3]})") + return (column_name, f"{data_type}({character_maximum_length})") + if ( + data_type in 
("varbinary", "varchar", "nvarchar") + and character_maximum_length is not None + and character_maximum_length == -1 + ): + return (column_name, f"{data_type}(max)") + if data_type in ("decimal", "numeric"): + return (column_name, f"{data_type}({numeric_precision}, {numeric_scale})") + if data_type == "float": + return (column_name, f"{data_type}({numeric_precision})") - return (row[0], row[1]) + return (column_name, data_type) - columns = [build_var_length_col(col) for col in columns_raw] + columns = [build_var_length_col(*row) for row in columns_raw] return { column_name: exp.DataType.build(data_type, dialect=self.dialect) @@ -121,15 +150,19 @@ def build_var_length_col(row: tuple) -> tuple: def table_exists(self, table_name: TableName) -> bool: """MsSql doesn't support describe so we query information_schema.""" table = exp.to_table(table_name) + data_object_cache_key = _get_data_object_cache_key(table.catalog, table.db, table.name) + if data_object_cache_key in self._data_object_cache: + logger.debug("Table existence cache hit: %s", data_object_cache_key) + return self._data_object_cache[data_object_cache_key] is not None sql = ( exp.select("1") - .from_("information_schema.tables") - .where(f"table_name = '{table.alias_or_name}'") + .from_("INFORMATION_SCHEMA.TABLES") + .where(f"TABLE_NAME = '{table.alias_or_name}'") ) database_name = table.db if database_name: - sql = sql.where(f"table_schema = '{database_name}'") + sql = sql.where(f"TABLE_SCHEMA = '{database_name}'") result = self.fetchone(sql, quote_identifiers=True) @@ -143,6 +176,7 @@ def drop_schema( schema_name: SchemaName, ignore_if_not_exists: bool = True, cascade: bool = False, + **drop_args: t.Dict[str, exp.Expression], ) -> None: """ MsSql doesn't support CASCADE clause and drops schemas unconditionally. 
@@ -150,20 +184,123 @@ def drop_schema( if cascade: objects = self._get_data_objects(schema_name) for obj in objects: + # Build properly quoted table for MSSQL using square brackets when needed + object_table = exp.table_(obj.name, obj.schema_name) + # _get_data_objects is catalog-specific, so these can't accidentally drop view/tables in another catalog if obj.type == DataObjectType.VIEW: self.drop_view( - ".".join([obj.schema_name, obj.name]), + object_table, ignore_if_not_exists=ignore_if_not_exists, ) else: self.drop_table( - ".".join([obj.schema_name, obj.name]), + object_table, exists=ignore_if_not_exists, ) super().drop_schema(schema_name, ignore_if_not_exists=ignore_if_not_exists, cascade=False) + def merge( + self, + target_table: TableName, + source_table: QueryOrDF, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + unique_key: t.Sequence[exp.Expression], + when_matched: t.Optional[exp.Whens] = None, + merge_filter: t.Optional[exp.Expression] = None, + source_columns: t.Optional[t.List[str]] = None, + **kwargs: t.Any, + ) -> None: + mssql_merge_exists = kwargs.get("physical_properties", {}).get("mssql_merge_exists") + + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + source_table, + target_columns_to_types, + target_table=target_table, + source_columns=source_columns, + ) + target_columns_to_types = target_columns_to_types or self.columns(target_table) + on = exp.and_( + *( + add_table(part, MERGE_TARGET_ALIAS).eq(add_table(part, MERGE_SOURCE_ALIAS)) + for part in unique_key + ) + ) + if merge_filter: + on = exp.and_(merge_filter, on) + + match_expressions = [] + if not when_matched: + unique_key_names = [y.name for y in unique_key] + columns_to_types_no_keys = [ + c for c in target_columns_to_types if c not in unique_key_names + ] + + target_columns_no_keys = [ + exp.column(c, MERGE_TARGET_ALIAS) for c in columns_to_types_no_keys + ] + source_columns_no_keys = [ + exp.column(c, 
MERGE_SOURCE_ALIAS) for c in columns_to_types_no_keys + ] + + match_condition = ( + exp.Exists( + this=exp.select(*target_columns_no_keys).except_( + exp.select(*source_columns_no_keys) + ) + ) + if mssql_merge_exists + else None + ) + + if target_columns_no_keys: + match_expressions.append( + exp.When( + matched=True, + source=False, + condition=match_condition, + then=exp.Update( + expressions=[ + exp.column(col, MERGE_TARGET_ALIAS).eq( + exp.column(col, MERGE_SOURCE_ALIAS) + ) + for col in columns_to_types_no_keys + ], + ), + ) + ) + else: + match_expressions.extend(when_matched.copy().expressions) + + match_expressions.append( + exp.When( + matched=False, + source=False, + then=exp.Insert( + this=exp.Tuple( + expressions=[exp.column(col) for col in target_columns_to_types] + ), + expression=exp.Tuple( + expressions=[ + exp.column(col, MERGE_SOURCE_ALIAS) for col in target_columns_to_types + ] + ), + ), + ) + ) + for source_query in source_queries: + with source_query as query: + self._merge( + target_table=target_table, + query=query, + on=on, + whens=exp.Whens(expressions=match_expressions), + ) + def _convert_df_datetime(self, df: DF, columns_to_types: t.Dict[str, exp.DataType]) -> None: + import pandas as pd + from pandas.api.types import is_datetime64_any_dtype # type: ignore + # pymssql doesn't convert Pandas Timestamp (datetime64) types # - this code is based on snowflake adapter implementation for column, kind in columns_to_types.items(): @@ -184,27 +321,44 @@ def _convert_df_datetime(self, df: DF, columns_to_types: t.Dict[str, exp.DataTyp def _df_to_source_queries( self, df: DF, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, + source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: + import pandas as pd + import numpy as np + assert isinstance(df, pd.DataFrame) temp_table = self._get_temp_table(target_table or "pandas") + # Return the 
superclass implementation if the connection pool doesn't support bulk_copy + if not hasattr(self._connection_pool.get(), "bulk_copy"): + return super()._df_to_source_queries( + df, target_columns_to_types, batch_size, target_table, source_columns=source_columns + ) + def query_factory() -> Query: # It is possible for the factory to be called multiple times and if so then the temp table will already # be created so we skip creating again. This means we are assuming the first call is the same result # as later calls. if not self.table_exists(temp_table): - columns_to_types_create = columns_to_types.copy() - self._convert_df_datetime(df, columns_to_types_create) - self.create_table(temp_table, columns_to_types_create) + source_columns_to_types = get_source_columns_to_types( + target_columns_to_types, source_columns + ) + ordered_df = df[ + list(source_columns_to_types) + ] # reorder DataFrame so it matches columns_to_types + self._convert_df_datetime(ordered_df, source_columns_to_types) + self.create_table(temp_table, source_columns_to_types) rows: t.List[t.Tuple[t.Any, ...]] = list( - df.replace({np.nan: None}).itertuples(index=False, name=None) - ) # type: ignore + ordered_df.replace({np.nan: None}).itertuples(index=False, name=None) # type: ignore + ) conn = self._connection_pool.get() conn.bulk_copy(temp_table.sql(dialect=self.dialect), rows) - return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table) # type: ignore + return exp.select( + *self._casted_columns(target_columns_to_types, source_columns=source_columns) + ).from_(temp_table) # type: ignore return [ SourceQuery( @@ -219,6 +373,8 @@ def _get_data_objects( """ Returns all the data objects that exist in the given schema and catalog. 
""" + import pandas as pd + catalog = self.get_current_catalog() query = ( exp.select( @@ -257,3 +413,47 @@ def _rename_table( # The function that renames tables in MSSQL takes string literals as arguments instead of identifiers, # so we shouldn't quote the identifiers. self.execute(exp.rename_table(old_table_name, new_table_name), quote_identifiers=False) + + def _insert_overwrite_by_condition( + self, + table_name: TableName, + source_queries: t.List[SourceQuery], + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + where: t.Optional[exp.Condition] = None, + insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None, + **kwargs: t.Any, + ) -> None: + # note that this is passed as table_properties here rather than physical_properties + use_merge_strategy = kwargs.get("table_properties", {}).get("mssql_merge_exists") + if (not where or where == exp.true()) and not use_merge_strategy: + # this is a full table replacement, call the base strategy to do DELETE+INSERT + # which will result in TRUNCATE+INSERT due to how we have overridden self.delete_from() + return EngineAdapter._insert_overwrite_by_condition( + self, + table_name=table_name, + source_queries=source_queries, + target_columns_to_types=target_columns_to_types, + where=where, + insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT, + **kwargs, + ) + + # For conditional overwrites or when mssql_merge_exists is set use MERGE + return super()._insert_overwrite_by_condition( + table_name=table_name, + source_queries=source_queries, + target_columns_to_types=target_columns_to_types, + where=where, + insert_overwrite_strategy_override=insert_overwrite_strategy_override, + **kwargs, + ) + + def delete_from(self, table_name: TableName, where: t.Union[str, exp.Expression]) -> None: + if where == exp.true(): + # "A TRUNCATE TABLE operation can be rolled back within a transaction." 
+ # ref: https://learn.microsoft.com/en-us/sql/t-sql/statements/truncate-table-transact-sql?view=sql-server-ver15#remarks + return self.execute( + exp.TruncateTable(expressions=[exp.to_table(table_name, dialect=self.dialect)]) + ) + + return super().delete_from(table_name, where) diff --git a/sqlmesh/core/engine_adapter/mysql.py b/sqlmesh/core/engine_adapter/mysql.py index fd0538c801..31773d6c63 100644 --- a/sqlmesh/core/engine_adapter/mysql.py +++ b/sqlmesh/core/engine_adapter/mysql.py @@ -10,6 +10,7 @@ LogicalMergeMixin, NonTransactionalTruncateMixin, PandasNativeFetchDFSupportMixin, + RowDiffMixin, ) from sqlmesh.core.engine_adapter.shared import ( CommentCreationTable, @@ -18,7 +19,6 @@ DataObjectType, set_catalog, ) -from sqlmesh.core.schema_diff import SchemaDiffer if t.TYPE_CHECKING: from sqlmesh.core._typing import SchemaName, TableName @@ -28,9 +28,7 @@ @set_catalog() class MySQLEngineAdapter( - LogicalMergeMixin, - PandasNativeFetchDFSupportMixin, - NonTransactionalTruncateMixin, + LogicalMergeMixin, PandasNativeFetchDFSupportMixin, NonTransactionalTruncateMixin, RowDiffMixin ): DEFAULT_BATCH_SIZE = 200 DIALECT = "mysql" @@ -40,8 +38,10 @@ class MySQLEngineAdapter( MAX_TABLE_COMMENT_LENGTH = 2048 MAX_COLUMN_COMMENT_LENGTH = 1024 SUPPORTS_REPLACE_TABLE = False - SCHEMA_DIFFER = SchemaDiffer( - parameterized_type_defaults={ + MAX_IDENTIFIER_LENGTH = 64 + SUPPORTS_QUERY_EXECUTION_TRACKING = True + SCHEMA_DIFFER_KWARGS = { + "parameterized_type_defaults": { exp.DataType.build("BIT", dialect=DIALECT).this: [(1,)], exp.DataType.build("BINARY", dialect=DIALECT).this: [(1,)], exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(10, 0), (0,)], @@ -52,7 +52,7 @@ class MySQLEngineAdapter( exp.DataType.build("DATETIME", dialect=DIALECT).this: [(0,)], exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(0,)], }, - ) + } def get_current_catalog(self) -> t.Optional[str]: """Returns the catalog name of the current connection.""" @@ -73,6 +73,7 @@ def drop_schema( 
schema_name: SchemaName, ignore_if_not_exists: bool = True, cascade: bool = False, + **drop_args: t.Dict[str, exp.Expression], ) -> None: # MySQL doesn't support CASCADE clause and drops schemas unconditionally. super().drop_schema(schema_name, ignore_if_not_exists=ignore_if_not_exists, cascade=False) @@ -129,35 +130,61 @@ def _create_column_comments( table_name: TableName, column_comments: t.Dict[str, str], table_kind: str = "TABLE", + materialized_view: bool = False, ) -> None: table = exp.to_table(table_name) table_sql = table.sql(dialect=self.dialect, identify=True) - try: - # MySQL ALTER TABLE MODIFY completely replaces the column (overwriting options and constraints). - # self.columns() only returns the column types so doesn't allow us to fully/correctly replace a column definition. - # To get the full column definition we retrieve and parse the table's CREATE TABLE statement. - create_table_exp = parse_one( - self.fetchone(f"SHOW CREATE TABLE {table_sql}")[1], dialect=self.dialect - ) - col_def_exps = { - col_def.name: col_def.copy() - for col_def in create_table_exp.find(exp.Schema).find_all(exp.ColumnDef) # type: ignore - } - - for col in column_comments: - col_def = col_def_exps.get(col) - col_def.args["constraints"].extend( # type: ignore - self._build_col_comment_exp(col_def.alias_or_name, column_comments) # type: ignore - ) - self.execute( - f"ALTER TABLE {table_sql} MODIFY {col_def.sql(dialect=self.dialect, identify=True)}", # type: ignore + # MySQL ALTER TABLE MODIFY completely replaces the column (overwriting options and constraints). + # self.columns() only returns the column types so doesn't allow us to fully/correctly replace a column definition. + # To get the full column definition we retrieve and parse the table's CREATE TABLE statement. 
+ create_table_exp = parse_one( + self.fetchone(f"SHOW CREATE TABLE {table_sql}")[1], # type: ignore + dialect=self.dialect, + ) + col_def_exps = { + col_def.name: col_def.copy() + for col_def in create_table_exp.find(exp.Schema).find_all(exp.ColumnDef) # type: ignore + } + + for col in column_comments: + col_def = col_def_exps.get(col) + if col_def: + col_def.args["constraints"].extend( + self._build_col_comment_exp(col_def.alias_or_name, column_comments) ) - except Exception: - logger.warning( - f"Column comments for table '{table.alias_or_name}' not registered - this may be due to limited permissions.", - exc_info=True, + + try: + self.execute( + f"ALTER TABLE {table_sql} MODIFY {col_def.sql(dialect=self.dialect, identify=True)}", + ) + except Exception: + logger.warning( + f"Column comments for column '{col_def.alias_or_name}' in table '{table.alias_or_name}' not registered - this may be due to limited permissions.", + exc_info=True, + ) + + def _create_table_like( + self, + target_table_name: TableName, + source_table_name: TableName, + exists: bool, + **kwargs: t.Any, + ) -> None: + self.execute( + exp.Create( + this=exp.to_table(target_table_name), + kind="TABLE", + exists=exists, + properties=exp.Properties( + expressions=[ + exp.LikeProperty( + this=exp.to_table(source_table_name), + ), + ], + ), ) + ) def ping(self) -> None: self._connection_pool.get().ping(reconnect=False) diff --git a/sqlmesh/core/engine_adapter/postgres.py b/sqlmesh/core/engine_adapter/postgres.py index 178059946f..3dd108cf91 100644 --- a/sqlmesh/core/engine_adapter/postgres.py +++ b/sqlmesh/core/engine_adapter/postgres.py @@ -1,19 +1,24 @@ from __future__ import annotations import logging +import re import typing as t +from functools import cached_property, partial from sqlglot import exp from sqlmesh.core.engine_adapter.base_postgres import BasePostgresEngineAdapter from sqlmesh.core.engine_adapter.mixins import ( GetCurrentCatalogFromFunctionMixin, PandasNativeFetchDFSupportMixin, + 
RowDiffMixin, + logical_merge, + GrantsFromInfoSchemaMixin, ) from sqlmesh.core.engine_adapter.shared import set_catalog -from sqlmesh.core.schema_diff import SchemaDiffer if t.TYPE_CHECKING: - from sqlmesh.core.engine_adapter._typing import DF + from sqlmesh.core._typing import TableName + from sqlmesh.core.engine_adapter._typing import DF, QueryOrDF logger = logging.getLogger(__name__) @@ -23,21 +28,29 @@ class PostgresEngineAdapter( BasePostgresEngineAdapter, PandasNativeFetchDFSupportMixin, GetCurrentCatalogFromFunctionMixin, + RowDiffMixin, + GrantsFromInfoSchemaMixin, ): DIALECT = "postgres" + SUPPORTS_GRANTS = True SUPPORTS_INDEXES = True HAS_VIEW_BINDING = True CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog") SUPPORTS_REPLACE_TABLE = False - SCHEMA_DIFFER = SchemaDiffer( - parameterized_type_defaults={ + MAX_IDENTIFIER_LENGTH: t.Optional[int] = 63 + SUPPORTS_QUERY_EXECUTION_TRACKING = True + GRANT_INFORMATION_SCHEMA_TABLE_NAME = "role_table_grants" + CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.column("current_role") + SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True + SCHEMA_DIFFER_KWARGS = { + "parameterized_type_defaults": { # DECIMAL without precision is "up to 131072 digits before the decimal point; up to 16383 digits after the decimal point" exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(131072 + 16383, 16383), (0,)], exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)], exp.DataType.build("TIME", dialect=DIALECT).this: [(6,)], exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(6,)], }, - types_with_unlimited_length={ + "types_with_unlimited_length": { # all can ALTER to `TEXT` exp.DataType.build("TEXT", dialect=DIALECT).this: { exp.DataType.build("VARCHAR", dialect=DIALECT).this, @@ -56,7 +69,8 @@ class PostgresEngineAdapter( exp.DataType.build("BPCHAR", dialect=DIALECT).this }, }, - ) + "drop_cascade": True, + } def _fetch_native_df( self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False @@ -70,3 
+84,60 @@ def _fetch_native_df( if not self._connection_pool.is_transaction_active: self._connection_pool.commit() return df + + def _create_table_like( + self, + target_table_name: TableName, + source_table_name: TableName, + exists: bool, + **kwargs: t.Any, + ) -> None: + self.execute( + exp.Create( + this=exp.Schema( + this=exp.to_table(target_table_name), + expressions=[ + exp.LikeProperty( + this=exp.to_table(source_table_name), + expressions=[exp.Property(this="INCLUDING", value=exp.Var(this="ALL"))], + ) + ], + ), + kind="TABLE", + exists=exists, + ) + ) + + def merge( + self, + target_table: TableName, + source_table: QueryOrDF, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + unique_key: t.Sequence[exp.Expression], + when_matched: t.Optional[exp.Whens] = None, + merge_filter: t.Optional[exp.Expression] = None, + source_columns: t.Optional[t.List[str]] = None, + **kwargs: t.Any, + ) -> None: + # Merge isn't supported until Postgres 15 + major, minor = self.server_version + merge_impl = super().merge if major >= 15 else partial(logical_merge, self) + merge_impl( # type: ignore + target_table, + source_table, + target_columns_to_types, + unique_key, + when_matched=when_matched, + merge_filter=merge_filter, + source_columns=source_columns, + ) + + @cached_property + def server_version(self) -> t.Tuple[int, int]: + """Lazily fetch and cache major and minor server version""" + if result := self.fetchone("SHOW server_version"): + server_version, *_ = result + match = re.search(r"(\d+)\.(\d+)", server_version) + if match: + return int(match.group(1)), int(match.group(2)) + return 0, 0 diff --git a/sqlmesh/core/engine_adapter/redshift.py b/sqlmesh/core/engine_adapter/redshift.py index 77abb2bb47..03dc89053e 100644 --- a/sqlmesh/core/engine_adapter/redshift.py +++ b/sqlmesh/core/engine_adapter/redshift.py @@ -1,17 +1,20 @@ from __future__ import annotations +import logging import typing as t -import pandas as pd from sqlglot import exp from 
sqlmesh.core.dialect import to_schema +from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS from sqlmesh.core.engine_adapter.base_postgres import BasePostgresEngineAdapter from sqlmesh.core.engine_adapter.mixins import ( GetCurrentCatalogFromFunctionMixin, - LogicalMergeMixin, NonTransactionalTruncateMixin, VarcharSizeWorkaroundMixin, + RowDiffMixin, + logical_merge, + GrantsFromInfoSchemaMixin, ) from sqlmesh.core.engine_adapter.shared import ( CommentCreationView, @@ -20,28 +23,36 @@ SourceQuery, set_catalog, ) -from sqlmesh.core.schema_diff import SchemaDiffer +from sqlmesh.utils.errors import SQLMeshError if t.TYPE_CHECKING: + import pandas as pd + from sqlmesh.core._typing import SchemaName, TableName - from sqlmesh.core.engine_adapter.base import QueryOrDF + from sqlmesh.core.engine_adapter.base import QueryOrDF, Query + +logger = logging.getLogger(__name__) @set_catalog() class RedshiftEngineAdapter( BasePostgresEngineAdapter, - LogicalMergeMixin, GetCurrentCatalogFromFunctionMixin, NonTransactionalTruncateMixin, VarcharSizeWorkaroundMixin, + RowDiffMixin, + GrantsFromInfoSchemaMixin, ): DIALECT = "redshift" CURRENT_CATALOG_EXPRESSION = exp.func("current_database") # Redshift doesn't support comments for VIEWs WITH NO SCHEMA BINDING (which we always use) COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED SUPPORTS_REPLACE_TABLE = False - SCHEMA_DIFFER = SchemaDiffer( - parameterized_type_defaults={ + SUPPORTS_GRANTS = True + SUPPORTS_MULTIPLE_GRANT_PRINCIPALS = True + + SCHEMA_DIFFER_KWARGS = { + "parameterized_type_defaults": { exp.DataType.build("VARBYTE", dialect=DIALECT).this: [(64000,)], exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(18, 0), (0,)], exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)], @@ -49,21 +60,78 @@ class RedshiftEngineAdapter( exp.DataType.build("NCHAR", dialect=DIALECT).this: [(1,)], exp.DataType.build("NVARCHAR", dialect=DIALECT).this: [(256,)], }, - max_parameter_length={ + 
"max_parameter_length": { exp.DataType.build("CHAR", dialect=DIALECT).this: 4096, exp.DataType.build("VARCHAR", dialect=DIALECT).this: 65535, }, - ) + "precision_increase_allowed_types": {exp.DataType.build("VARCHAR", dialect=DIALECT).this}, + "drop_cascade": True, + } + VARIABLE_LENGTH_DATA_TYPES = { + "char", + "character", + "nchar", + "varchar", + "character varying", + "nvarchar", + "varbyte", + "varbinary", + "binary varying", + } + + def columns( + self, + table_name: TableName, + include_pseudo_columns: bool = True, + ) -> t.Dict[str, exp.DataType]: + table = exp.to_table(table_name) - def _columns_query(self, table: exp.Table) -> exp.Select: sql = ( - exp.select("column_name", "data_type") + exp.select( + "column_name", + "data_type", + "character_maximum_length", + "numeric_precision", + "numeric_scale", + ) .from_("svv_columns") # Includes late-binding views .where(exp.column("table_name").eq(table.alias_or_name)) ) if table.args.get("db"): sql = sql.where(exp.column("table_schema").eq(table.args["db"].name)) - return sql + + columns_raw = self.fetchall(sql, quote_identifiers=True) + + def build_var_length_col( + column_name: str, + data_type: str, + character_maximum_length: t.Optional[int] = None, + numeric_precision: t.Optional[int] = None, + numeric_scale: t.Optional[int] = None, + ) -> tuple: + data_type = data_type.lower() + if ( + data_type in self.VARIABLE_LENGTH_DATA_TYPES + and character_maximum_length is not None + ): + return (column_name, f"{data_type}({character_maximum_length})") + if data_type in ("decimal", "numeric"): + return (column_name, f"{data_type}({numeric_precision}, {numeric_scale})") + + return (column_name, data_type) + + columns = [build_var_length_col(*row) for row in columns_raw] + + return { + column_name: exp.DataType.build(data_type, dialect=self.dialect) + for column_name, data_type in columns + } + + @property + def enable_merge(self) -> bool: + # Redshift supports the MERGE operation but we use the logical merge + # 
unless the user has opted in by setting enable_merge in the connection. + return bool(self._extra_config.get("enable_merge")) @property def cursor(self) -> t.Any: @@ -78,19 +146,38 @@ def _fetch_native_df( self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False ) -> pd.DataFrame: """Fetches a Pandas DataFrame from the cursor""" + import pandas as pd + self.execute(query, quote_identifiers=quote_identifiers) - return self.cursor.fetch_dataframe() + + # We manually build the `DataFrame` here because the driver's `fetch_dataframe` + # method does not respect the active case-sensitivity configuration. + # + # Context: https://github.com/aws/amazon-redshift-python-driver/issues/238 + fetcheddata = self.cursor.fetchall() + + try: + columns = [column[0] for column in self.cursor.description] + except Exception: + columns = None + logging.warning( + "No row description was found, pandas dataframe will be missing column labels." + ) + + result = [tuple(row) for row in fetcheddata] + return pd.DataFrame(result, columns=columns) def _create_table_from_source_queries( self, table_name: TableName, source_queries: t.List[SourceQuery], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, exists: bool = True, replace: bool = False, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: """ @@ -103,7 +190,7 @@ def _create_table_from_source_queries( return super()._create_table_from_source_queries( table_name, source_queries, - columns_to_types, + target_columns_to_types, exists, table_description=table_description, column_descriptions=column_descriptions, @@ -124,12 +211,14 @@ def create_view( self, view_name: TableName, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + 
target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, replace: bool = True, materialized: bool = False, + materialized_properties: t.Optional[t.Dict[str, t.Any]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + source_columns: t.Optional[t.List[str]] = None, **create_kwargs: t.Any, ) -> None: """ @@ -137,16 +226,26 @@ def create_view( underlying table without dropping the view first. This is a problem for us since we want to be able to swap tables out from under views. Therefore, we create the view as non-binding. """ + no_schema_binding = True + if isinstance(query_or_df, exp.Expression): + # We can't include NO SCHEMA BINDING if the query has a recursive CTE + has_recursive_cte = any( + w.args.get("recursive", False) for w in query_or_df.find_all(exp.With) + ) + no_schema_binding = not has_recursive_cte + return super().create_view( view_name, query_or_df, - columns_to_types, + target_columns_to_types, replace, materialized, + materialized_properties, table_description=table_description, column_descriptions=column_descriptions, - no_schema_binding=create_kwargs.pop("no_schema_binding", True), + no_schema_binding=no_schema_binding, view_properties=view_properties, + source_columns=source_columns, **create_kwargs, ) @@ -154,9 +253,11 @@ def replace_query( self, table_name: TableName, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, + source_columns: t.Optional[t.List[str]] = None, + supports_replace_table_override: t.Optional[bool] = None, **kwargs: t.Any, ) -> None: """ @@ -167,32 +268,43 @@ def replace_query( If it does exist then we need to do the: `CREATE TABLE...`, `INSERT INTO...`, `RENAME TABLE...`, 
`RENAME TABLE...`, DROP TABLE...` dance. """ - if not isinstance(query_or_df, pd.DataFrame) or not self.table_exists(table_name): + import pandas as pd + + target_data_object = self.get_data_object(table_name) + table_exists = target_data_object is not None + if self.drop_data_object_on_type_mismatch(target_data_object, DataObjectType.TABLE): + table_exists = False + + if not isinstance(query_or_df, pd.DataFrame) or not table_exists: return super().replace_query( table_name, query_or_df, - columns_to_types, + target_columns_to_types, table_description, column_descriptions, + source_columns=source_columns, **kwargs, ) - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query_or_df, columns_to_types, target_table=table_name + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + query_or_df, + target_columns_to_types, + target_table=table_name, + source_columns=source_columns, ) - columns_to_types = columns_to_types or self.columns(table_name) + target_columns_to_types = target_columns_to_types or self.columns(table_name) target_table = exp.to_table(table_name) with self.transaction(): temp_table = self._get_temp_table(target_table) old_table = self._get_temp_table(target_table) self.create_table( temp_table, - columns_to_types, + target_columns_to_types, exists=False, table_description=table_description, column_descriptions=column_descriptions, **kwargs, ) - self._insert_append_source_queries(temp_table, source_queries, columns_to_types) + self._insert_append_source_queries(temp_table, source_queries, target_columns_to_types) self.rename_table(target_table, old_table) self.rename_table(temp_table, target_table) self.drop_table(old_table) @@ -249,3 +361,92 @@ def _get_data_objects( ) for row in df.itertuples() ] + + def merge( + self, + target_table: TableName, + source_table: QueryOrDF, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]], + unique_key: t.Sequence[exp.Expression], + 
when_matched: t.Optional[exp.Whens] = None, + merge_filter: t.Optional[exp.Expression] = None, + source_columns: t.Optional[t.List[str]] = None, + **kwargs: t.Any, + ) -> None: + if self.enable_merge: + # By default we use the logical merge unless the user has opted in + super().merge( + target_table=target_table, + source_table=source_table, + target_columns_to_types=target_columns_to_types, + unique_key=unique_key, + when_matched=when_matched, + merge_filter=merge_filter, + source_columns=source_columns, + ) + else: + logical_merge( + self, + target_table, + source_table, + target_columns_to_types, + unique_key, + when_matched=when_matched, + merge_filter=merge_filter, + source_columns=source_columns, + ) + + def _merge( + self, + target_table: TableName, + query: Query, + on: exp.Expression, + whens: exp.Whens, + ) -> None: + # Redshift does not support table aliases in the target table of a MERGE statement. + # So we must use the actual table name instead of an alias, as we do with the source table. + def resolve_target_table(expression: exp.Expression) -> exp.Expression: + if ( + isinstance(expression, exp.Column) + and expression.table.upper() == MERGE_TARGET_ALIAS + ): + expression.set("table", exp.to_table(target_table)) + return expression + + # Ensure that there is exactly one "WHEN MATCHED" and one "WHEN NOT MATCHED" clause. + # Since Redshift does not support multiple "WHEN MATCHED" clauses. 
+ if ( + len(whens.expressions) != 2 + or whens.expressions[0].args["matched"] == whens.expressions[1].args["matched"] + ): + raise SQLMeshError( + "Redshift only supports a single WHEN MATCHED and WHEN NOT MATCHED clause" + ) + + using = exp.alias_( + exp.Subquery(this=query), alias=MERGE_SOURCE_ALIAS, copy=False, table=True + ) + self.execute( + exp.Merge( + this=target_table, + using=using, + on=on.transform(resolve_target_table), + whens=whens.transform(resolve_target_table), + ), + track_rows_processed=True, + ) + + def _normalize_decimal_value(self, expr: exp.Expression, precision: int) -> exp.Expression: + # Redshift is finicky. It truncates when the data is already in a table, but rounds when the data is generated as part of a SELECT. + # + # The following works: + # > select cast(cast(3.14159 as decimal(6, 5)) as decimal(6, 3)); --produces '3.142', the value we want / what every other database produces + # + # However, if you write that to a table, and then cast it to a less precise decimal, you get _truncation_. 
+ # > create table foo (val decimal(6, 5)); insert into foo(val) values (3.14159); + # > select cast(val as decimal(6, 3)) from foo; --produces '3.141' + # + # So to make up for this, we force it to round by injecting a round() expression + rounded = exp.func("ROUND", expr, precision) + + return super()._normalize_decimal_value(rounded, precision) diff --git a/sqlmesh/core/engine_adapter/risingwave.py b/sqlmesh/core/engine_adapter/risingwave.py new file mode 100644 index 0000000000..61b44f5bbb --- /dev/null +++ b/sqlmesh/core/engine_adapter/risingwave.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import logging +import typing as t + + +from sqlglot import exp + +from sqlmesh.core.engine_adapter.postgres import PostgresEngineAdapter +from sqlmesh.core.engine_adapter.shared import ( + set_catalog, + CatalogSupport, + CommentCreationView, + CommentCreationTable, +) + +from sqlmesh.utils.errors import SQLMeshError + +if t.TYPE_CHECKING: + from sqlmesh.core._typing import TableName + +logger = logging.getLogger(__name__) + + +@set_catalog() +class RisingwaveEngineAdapter(PostgresEngineAdapter): + DIALECT = "risingwave" + DEFAULT_BATCH_SIZE = 400 + CATALOG_SUPPORT = CatalogSupport.SINGLE_CATALOG_ONLY + COMMENT_CREATION_TABLE = CommentCreationTable.COMMENT_COMMAND_ONLY + COMMENT_CREATION_VIEW = CommentCreationView.UNSUPPORTED + SUPPORTS_MATERIALIZED_VIEWS = True + SUPPORTS_TRANSACTIONS = False + MAX_IDENTIFIER_LENGTH = None + SUPPORTS_GRANTS = False + + def columns( + self, table_name: TableName, include_pseudo_columns: bool = False + ) -> t.Dict[str, exp.DataType]: + """Fetches column names and types for the target_table""" + table = exp.to_table(table_name) + + sql = ( + exp.select("rw_columns.name AS column_name", "rw_columns.data_type AS data_type") + .from_("rw_catalog.rw_columns") + .join("rw_catalog.rw_relations", on="rw_relations.id=rw_columns.relation_id") + .join("rw_catalog.rw_schemas", on="rw_schemas.id=rw_relations.schema_id") + .where( + 
exp.and_( + exp.column("name", table="rw_relations").eq(table.alias_or_name), + exp.column("name", table="rw_columns").neq("_row_id"), + exp.column("name", table="rw_columns").neq("_rw_timestamp"), + ) + ) + ) + + if table.db: + sql = sql.where(exp.column("name", table="rw_schemas").eq(table.db)) + + self.execute(sql) + resp = self.cursor.fetchall() + if not resp: + raise SQLMeshError(f"Could not get columns for table {table_name}. Table not found.") + return { + column_name: exp.DataType.build(data_type, dialect=self.dialect, udt=True) + for column_name, data_type in resp + } + + def _truncate_table(self, table_name: TableName) -> None: + return self.execute(exp.Delete(this=exp.to_table(table_name))) diff --git a/sqlmesh/core/engine_adapter/shared.py b/sqlmesh/core/engine_adapter/shared.py index 89fe7ed11d..ba0e1fa619 100644 --- a/sqlmesh/core/engine_adapter/shared.py +++ b/sqlmesh/core/engine_adapter/shared.py @@ -11,7 +11,7 @@ from sqlglot import exp from sqlmesh.core.dialect import to_schema -from sqlmesh.utils.errors import UnsupportedCatalogOperationError +from sqlmesh.utils.errors import UnsupportedCatalogOperationError, SQLMeshError from sqlmesh.utils.pydantic import PydanticModel if t.TYPE_CHECKING: @@ -164,11 +164,28 @@ class DataObject(PydanticModel): name: str type: DataObjectType + # for type=DataObjectType.Table, only if the DB supports it + clustering_key: t.Optional[str] = None + + @property + def is_clustered(self) -> bool: + return bool(self.clustering_key) + + def to_table(self) -> exp.Table: + return exp.table_(self.name, db=self.schema_name, catalog=self.catalog, quoted=True) + class CatalogSupport(Enum): + # The engine has no concept of catalogs UNSUPPORTED = 1 + + # The engine has a concept of catalogs, but they are isolated from each other and cannot reference each others tables SINGLE_CATALOG_ONLY = 2 + + # The engine supports multiple catalogs but some operations require a SET CATALOG query to set the active catalog before proceeding 
REQUIRES_SET_CATALOG = 3 + + # The engine supports multiple catalogs and can unambiguously target a specific catalog when performing operations (without running SET CATALOG first) FULL_SUPPORT = 4 @property @@ -192,12 +209,42 @@ def is_multi_catalog_supported(self) -> bool: return self.is_requires_set_catalog or self.is_full_support +class EngineRunMode(Enum): + SINGLE_MODE_ENGINE = 1 + STANDALONE = 2 + CLUSTER = 3 + CLOUD = 4 + + @property + def is_single_mode_engine(self) -> bool: + return self == EngineRunMode.SINGLE_MODE_ENGINE + + @property + def is_standalone(self) -> bool: + return self == EngineRunMode.STANDALONE + + @property + def is_cluster(self) -> bool: + return self == EngineRunMode.CLUSTER + + @property + def is_cloud(self) -> bool: + return self == EngineRunMode.CLOUD + + class InsertOverwriteStrategy(Enum): + # First, issue a DELETE to clear the data range. Then, issue an INSERT query to insert the new data DELETE_INSERT = 1 + # Issue a single INSERT OVERWRITE query to replace a data range. INSERT_OVERWRITE = 2 + # Issue a single INSERT INTO... REPLACE WHERE query # Note: Replace where on Databricks requires that `spark.sql.sources.partitionOverwriteMode` be set to `static` REPLACE_WHERE = 3 + # Issue a single INSERT query to replace a data range. The assumption is that the query engine will transparently match partition bounds + # and replace data rather than append to it. 
Trino is an example of this when `hive.insert-existing-partitions-behavior=OVERWRITE` is configured INTO_IS_OVERWRITE = 4 + # Do the INSERT OVERWRITE using merge since the engine doesn't support it natively + MERGE = 5 @property def is_delete_insert(self) -> bool: @@ -215,6 +262,10 @@ def is_replace_where(self) -> bool: def is_into_is_overwrite(self) -> bool: return self == InsertOverwriteStrategy.INTO_IS_OVERWRITE + @property + def is_merge(self) -> bool: + return self == InsertOverwriteStrategy.MERGE + class SourceQuery: def __init__( @@ -261,7 +312,7 @@ def internal_wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any: # Need to convert args to list in order to later do assignment to the object list_args = list(args) engine_adapter = list_args[0] - catalog_support = override or engine_adapter.CATALOG_SUPPORT + catalog_support = override or engine_adapter.catalog_support # If there is full catalog support then we have nothing to do if catalog_support.is_full_support: return func(*list_args, **kwargs) @@ -282,6 +333,7 @@ def internal_wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any: catalog_name = expression.catalog if not catalog_name: return func(*list_args, **kwargs) + # If we have a catalog and this engine doesn't support catalogs then we need to error if catalog_support.is_unsupported: raise UnsupportedCatalogOperationError( @@ -292,8 +344,8 @@ def internal_wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any: container[key] = expression # type: ignore if catalog_support.is_single_catalog_only: if catalog_name != engine_adapter._default_catalog: - logger.warning( - f"{engine_adapter.dialect} requires that all catalog operations be against a single catalog: {engine_adapter._default_catalog}. Ignoring catalog: {catalog_name}" + raise SQLMeshError( + f"{engine_adapter.dialect} requires that all catalog operations be against a single catalog: {engine_adapter._default_catalog}. 
Provided catalog: {catalog_name}" ) return func(*list_args, **kwargs) # Set the catalog name on the engine adapter if needed diff --git a/sqlmesh/core/engine_adapter/snowflake.py b/sqlmesh/core/engine_adapter/snowflake.py index 38944d960f..a8eabe070d 100644 --- a/sqlmesh/core/engine_adapter/snowflake.py +++ b/sqlmesh/core/engine_adapter/snowflake.py @@ -1,16 +1,22 @@ from __future__ import annotations import contextlib +import logging import typing as t -import pandas as pd -from pandas.api.types import is_datetime64_any_dtype # type: ignore from sqlglot import exp +from sqlglot.helper import ensure_list from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import quote_identifiers +import sqlmesh.core.constants as c from sqlmesh.core.dialect import to_schema -from sqlmesh.core.engine_adapter.mixins import GetCurrentCatalogFromFunctionMixin +from sqlmesh.core.engine_adapter.mixins import ( + GetCurrentCatalogFromFunctionMixin, + ClusteredByMixin, + RowDiffMixin, + GrantsFromInfoSchemaMixin, +) from sqlmesh.core.engine_adapter.shared import ( CatalogSupport, DataObject, @@ -18,15 +24,23 @@ SourceQuery, set_catalog, ) -from sqlmesh.core.schema_diff import SchemaDiffer -from sqlmesh.utils import optional_import +from sqlmesh.utils import optional_import, get_source_columns_to_types from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.utils.pandas import columns_to_types_from_dtypes +logger = logging.getLogger(__name__) snowpark = optional_import("snowflake.snowpark") if t.TYPE_CHECKING: + import pandas as pd + from sqlmesh.core._typing import SchemaName, SessionProperties, TableName - from sqlmesh.core.engine_adapter._typing import DF, Query, SnowparkSession + from sqlmesh.core.engine_adapter._typing import ( + DF, + Query, + QueryOrDF, + SnowparkSession, + ) from sqlmesh.core.node import IntervalUnit @@ -35,18 +49,23 @@ "_get_data_objects": CatalogSupport.REQUIRES_SET_CATALOG, "create_schema": 
CatalogSupport.REQUIRES_SET_CATALOG, "drop_schema": CatalogSupport.REQUIRES_SET_CATALOG, + "drop_catalog": CatalogSupport.REQUIRES_SET_CATALOG, # needs a catalog to issue a query to information_schema.databases even though the result is global } ) -class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin): +class SnowflakeEngineAdapter( + GetCurrentCatalogFromFunctionMixin, ClusteredByMixin, RowDiffMixin, GrantsFromInfoSchemaMixin +): DIALECT = "snowflake" SUPPORTS_MATERIALIZED_VIEWS = True SUPPORTS_MATERIALIZED_VIEW_SCHEMA = True SUPPORTS_CLONING = True SUPPORTS_MANAGED_MODELS = True - CATALOG_SUPPORT = CatalogSupport.FULL_SUPPORT CURRENT_CATALOG_EXPRESSION = exp.func("current_database") - SCHEMA_DIFFER = SchemaDiffer( - parameterized_type_defaults={ + SUPPORTS_CREATE_DROP_CATALOG = True + SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS = True + SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA", "TABLE"] + SCHEMA_DIFFER_KWARGS = { + "parameterized_type_defaults": { exp.DataType.build("BINARY", dialect=DIALECT).this: [(8388608,)], exp.DataType.build("VARBINARY", dialect=DIALECT).this: [(8388608,)], exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(38, 0), (0,)], @@ -59,8 +78,13 @@ class SnowflakeEngineAdapter(GetCurrentCatalogFromFunctionMixin): exp.DataType.build("TIMESTAMP_NTZ", dialect=DIALECT).this: [(9,)], exp.DataType.build("TIMESTAMP_TZ", dialect=DIALECT).this: [(9,)], }, - ) + } MANAGED_TABLE_KIND = "DYNAMIC TABLE" + SNOWPARK = "snowpark" + SUPPORTS_QUERY_EXECUTION_TRACKING = True + SUPPORTS_GRANTS = True + CURRENT_USER_OR_ROLE_EXPRESSION: exp.Expression = exp.func("CURRENT_ROLE") + USE_CATALOG_IN_GRANTS = True @contextlib.contextmanager def session(self, properties: SessionProperties) -> t.Iterator[None]: @@ -85,33 +109,128 @@ def session(self, properties: SessionProperties) -> t.Iterator[None]: return self.execute(f"USE WAREHOUSE {warehouse_sql}") - yield - self.execute(f"USE WAREHOUSE {current_warehouse_sql}") + try: + yield + finally: + 
self.execute(f"USE WAREHOUSE {current_warehouse_sql}") @property def _current_warehouse(self) -> exp.Identifier: - current_warehouse_str = self.fetchone("SELECT CURRENT_WAREHOUSE()")[0] + current_warehouse_str = self.fetchone("SELECT CURRENT_WAREHOUSE()")[0] # type: ignore # The warehouse value returned by Snowflake is already normalized, so only quoting is needed. return quote_identifiers(exp.to_identifier(current_warehouse_str), dialect=self.dialect) @property def snowpark(self) -> t.Optional[SnowparkSession]: if snowpark: - return snowpark.Session.builder.configs( - {"connection": self._connection_pool.get()} - ).getOrCreate() + if not self._connection_pool.get_attribute(self.SNOWPARK): + # Snowpark sessions are not thread safe so we create a session per thread to prevent them from interfering with each other + # The sessions are cleaned up when close() is called + new_session = snowpark.Session.builder.configs( + {"connection": self._connection_pool.get()} + ).create() + self._connection_pool.set_attribute(self.SNOWPARK, new_session) + + return self._connection_pool.get_attribute(self.SNOWPARK) + return None + @property + def catalog_support(self) -> CatalogSupport: + return CatalogSupport.FULL_SUPPORT + + @staticmethod + def _grant_object_kind(table_type: DataObjectType) -> str: + if table_type == DataObjectType.VIEW: + return "VIEW" + if table_type == DataObjectType.MATERIALIZED_VIEW: + return "MATERIALIZED VIEW" + if table_type == DataObjectType.MANAGED_TABLE: + return "DYNAMIC TABLE" + return "TABLE" + + def _get_current_schema(self) -> str: + """Returns the current default schema for the connection.""" + result = self.fetchone("SELECT CURRENT_SCHEMA()") + if not result or not result[0]: + raise SQLMeshError("Unable to determine current schema") + return str(result[0]) + + def _create_catalog(self, catalog_name: exp.Identifier) -> None: + props = exp.Properties( + expressions=[exp.SchemaCommentProperty(this=exp.Literal.string(c.SQLMESH_MANAGED))] + ) + 
self.execute( + exp.Create( + this=exp.Table(this=catalog_name), kind="DATABASE", exists=True, properties=props + ) + ) + + def _drop_catalog(self, catalog_name: exp.Identifier) -> None: + # only drop the catalog if it was created by SQLMesh, which is indicated by its comment matching {c.SQLMESH_MANAGED} + exists_check = ( + exp.select(exp.Literal.number(1)) + .from_(exp.to_table("information_schema.databases")) + .where( + exp.and_( + exp.column("database_name").eq(exp.Literal.string(catalog_name)), + exp.column("comment").eq(exp.Literal.string(c.SQLMESH_MANAGED)), + ) + ) + ) + normalize_identifiers(exists_check, dialect=self.dialect) + if self.fetchone(exists_check, quote_identifiers=True) is not None: + self.execute(exp.Drop(this=exp.Table(this=catalog_name), kind="DATABASE", exists=True)) + else: + logger.warning( + f"Not dropping database {catalog_name.sql(dialect=self.dialect)} because there is no indication it is '{c.SQLMESH_MANAGED}'" + ) + + def _create_table( + self, + table_name_or_schema: t.Union[exp.Schema, TableName], + expression: t.Optional[exp.Expression], + exists: bool = True, + replace: bool = False, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + table_description: t.Optional[str] = None, + column_descriptions: t.Optional[t.Dict[str, str]] = None, + table_kind: t.Optional[str] = None, + track_rows_processed: bool = True, + **kwargs: t.Any, + ) -> None: + table_format = kwargs.get("table_format") + if table_format and isinstance(table_format, str): + table_format = table_format.upper() + if not table_kind: + table_kind = f"{table_format} TABLE" + elif table_kind == self.MANAGED_TABLE_KIND: + table_kind = f"DYNAMIC {table_format} TABLE" + + super()._create_table( + table_name_or_schema=table_name_or_schema, + expression=expression, + exists=exists, + replace=replace, + target_columns_to_types=target_columns_to_types, + table_description=table_description, + column_descriptions=column_descriptions, + 
table_kind=table_kind, + track_rows_processed=False, # snowflake tracks CTAS row counts incorrectly + **kwargs, + ) + def create_managed_table( self, table_name: TableName, query: Query, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, partitioned_by: t.Optional[t.List[exp.Expression]] = None, - clustered_by: t.Optional[t.List[str]] = None, + clustered_by: t.Optional[t.List[exp.Expression]] = None, table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, + source_columns: t.Optional[t.List[str]] = None, **kwargs: t.Any, ) -> None: target_table = exp.to_table(table_name) @@ -131,14 +250,14 @@ def create_managed_table( "`target_lag` must be specified in the model physical_properties for a Snowflake Dynamic Table" ) - source_queries, columns_to_types = self._get_source_queries_and_columns_to_types( - query, columns_to_types, target_table=target_table + source_queries, target_columns_to_types = self._get_source_queries_and_columns_to_types( + query, target_columns_to_types, target_table=target_table, source_columns=source_columns ) self._create_table_from_source_queries( target_table, source_queries, - columns_to_types, + target_columns_to_types, replace=self.SUPPORTS_REPLACE_TABLE, partitioned_by=partitioned_by, clustered_by=clustered_by, @@ -149,20 +268,57 @@ def create_managed_table( **kwargs, ) + def create_view( + self, + view_name: TableName, + query_or_df: QueryOrDF, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + replace: bool = True, + materialized: bool = False, + materialized_properties: t.Optional[t.Dict[str, t.Any]] = None, + table_description: t.Optional[str] = None, + column_descriptions: t.Optional[t.Dict[str, str]] = None, + view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, + source_columns: t.Optional[t.List[str]] 
= None, + **create_kwargs: t.Any, + ) -> None: + properties = create_kwargs.pop("properties", None) + if not properties: + properties = exp.Properties(expressions=[]) + if replace: + properties.append("expressions", exp.CopyGrantsProperty()) + + super().create_view( + view_name=view_name, + query_or_df=query_or_df, + target_columns_to_types=target_columns_to_types, + replace=replace, + materialized=materialized, + materialized_properties=materialized_properties, + table_description=table_description, + column_descriptions=column_descriptions, + view_properties=view_properties, + properties=properties, + source_columns=source_columns, + **create_kwargs, + ) + def drop_managed_table(self, table_name: TableName, exists: bool = True) -> None: - self._drop_tablelike_object(table_name, exists, kind=self.MANAGED_TABLE_KIND) + self._drop_object(table_name, exists, kind=self.MANAGED_TABLE_KIND) def _build_table_properties_exp( self, catalog_name: t.Optional[str] = None, + table_format: t.Optional[str] = None, storage_format: t.Optional[str] = None, partitioned_by: t.Optional[t.List[exp.Expression]] = None, partition_interval_unit: t.Optional[IntervalUnit] = None, - clustered_by: t.Optional[t.List[str]] = None, + clustered_by: t.Optional[t.List[exp.Expression]] = None, table_properties: t.Optional[t.Dict[str, exp.Expression]] = None, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, table_kind: t.Optional[str] = None, + **kwargs: t.Any, ) -> t.Optional[exp.Properties]: properties: t.List[exp.Expression] = [] @@ -176,40 +332,83 @@ def _build_table_properties_exp( ) ) - if clustered_by: - properties.append(exp.Cluster(expressions=[exp.column(col) for col in clustered_by])) + if ( + clustered_by + and (clustered_by_prop := self._build_clustered_by_exp(clustered_by)) is not None + ): + properties.append(clustered_by_prop) if table_properties: 
table_properties = {k.upper(): v for k, v in table_properties.items()} # if we are creating a non-dynamic table; remove any properties that are only valid for dynamic tables - if table_kind != self.MANAGED_TABLE_KIND: + # this is necessary because we create "normal" tables from the same managed model definition for dev previews and the "normal" tables dont support these parameters + if "DYNAMIC" not in (table_kind or "").upper(): for prop in {"WAREHOUSE", "TARGET_LAG", "REFRESH_MODE", "INITIALIZE"}: table_properties.pop(prop, None) - properties.extend(self._table_or_view_properties_to_expressions(table_properties)) + table_type = self._pop_creatable_type_from_properties(table_properties) + properties.extend(ensure_list(table_type)) - if properties: - return exp.Properties(expressions=properties) + properties.extend(self._table_or_view_properties_to_expressions(table_properties)) - return None + return exp.Properties(expressions=properties) if properties else None def _df_to_source_queries( self, df: DF, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, + source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: + import pandas as pd + from pandas.api.types import is_datetime64_any_dtype + + source_columns_to_types = get_source_columns_to_types( + target_columns_to_types, source_columns + ) + temp_table = self._get_temp_table( target_table or "pandas", quoted=False ) # write_pandas() re-quotes everything without checking if its already quoted + is_snowpark_dataframe = snowpark and isinstance(df, snowpark.dataframe.DataFrame) + def query_factory() -> Query: - if snowpark and isinstance(df, snowpark.dataframe.DataFrame): - df.createOrReplaceTempView(temp_table.sql(dialect=self.dialect, identify=True)) + # The catalog needs to be normalized before being passed to Snowflake's library functions because they + # just wrap whatever they are given in quotes without 
checking if its already quoted + database = ( + normalize_identifiers(temp_table.catalog, dialect=self.dialect) + if temp_table.catalog + else None + ) + + if is_snowpark_dataframe: + temp_table.set("catalog", database) + + # only quote columns if they arent already quoted + # if the Snowpark dataframe was created from a Pandas dataframe via snowpark.create_dataframe(pandas_df), + # then they will be quoted already. But if the Snowpark dataframe was created manually by the user, then the + # columns may not be quoted + columns_already_quoted = all( + col.startswith('"') and col.endswith('"') for col in df.columns + ) + local_df = df + if not columns_already_quoted: + local_df = df.rename( + { + col: exp.to_identifier(col).sql(dialect=self.dialect, identify=True) + for col in source_columns_to_types + } + ) # type: ignore + local_df.createOrReplaceTempView( + temp_table.sql(dialect=self.dialect, identify=True) + ) # type: ignore elif isinstance(df, pd.DataFrame): from snowflake.connector.pandas_tools import write_pandas + ordered_df = df[list(source_columns_to_types)] + # Workaround for https://github.com/snowflakedb/snowflake-connector-python/issues/1034 # The above issue has already been fixed upstream, but we keep the following # line anyway in order to support a wider range of Snowflake versions. 
@@ -219,34 +418,30 @@ def query_factory() -> Query: self.set_current_schema(schema) # See: https://stackoverflow.com/a/75627721 - for column, kind in columns_to_types.items(): - if is_datetime64_any_dtype(df.dtypes[column]): + for column, kind in source_columns_to_types.items(): + if is_datetime64_any_dtype(ordered_df.dtypes[column]): if kind.is_type("date"): # type: ignore - df[column] = pd.to_datetime(df[column]).dt.date # type: ignore - elif getattr(df.dtypes[column], "tz", None) is not None: # type: ignore - df[column] = pd.to_datetime(df[column]).dt.strftime( + ordered_df[column] = pd.to_datetime(ordered_df[column]).dt.date # type: ignore + elif getattr(ordered_df.dtypes[column], "tz", None) is not None: # type: ignore + ordered_df[column] = pd.to_datetime(ordered_df[column]).dt.strftime( "%Y-%m-%d %H:%M:%S.%f%z" ) # type: ignore # https://github.com/snowflakedb/snowflake-connector-python/issues/1677 else: # type: ignore - df[column] = pd.to_datetime(df[column]).dt.strftime( + ordered_df[column] = pd.to_datetime(ordered_df[column]).dt.strftime( "%Y-%m-%d %H:%M:%S.%f" ) # type: ignore # create the table first using our usual method ensure the column datatypes match what we parsed with sqlglot # otherwise we would be trusting `write_pandas()` from the snowflake lib to do this correctly - self.create_table(temp_table, columns_to_types, table_kind="TEMPORARY TABLE") + self.create_table(temp_table, source_columns_to_types, table_kind="TEMPORARY TABLE") write_pandas( self._connection_pool.get(), - df, + ordered_df, temp_table.name, schema=temp_table.db or None, - database=normalize_identifiers(temp_table.catalog, dialect=self.dialect).sql( - dialect=self.dialect - ) - if temp_table.catalog - else None, + database=database.sql(dialect=self.dialect) if database else None, chunk_size=self.DEFAULT_BATCH_SIZE, overwrite=True, table_type="temp", @@ -256,19 +451,29 @@ def query_factory() -> Query: f"Unknown dataframe type: {type(df)} for {target_table}. 
Expecting pandas or snowpark." ) - return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table) + return exp.select( + *self._casted_columns(target_columns_to_types, source_columns=source_columns) + ).from_(temp_table) + + def cleanup() -> None: + if is_snowpark_dataframe: + if hasattr(df, "table_name"): + if isinstance(df.table_name, str): + # created by the Snowpark library if the Snowpark DataFrame was created from a Pandas DataFrame + # (if the Snowpark DataFrame was created via native means then there is no 'table_name' property and no temp table) + self.drop_table(df.table_name) + self.drop_view(temp_table) + else: + self.drop_table(temp_table) # the cleanup_func technically isnt needed because the temp table gets dropped when the session ends # but boy does it make our multi-adapter integration tests easier to write - return [ - SourceQuery( - query_factory=query_factory, cleanup_func=lambda: self.drop_table(temp_table) - ) - ] + return [SourceQuery(query_factory=query_factory, cleanup_func=cleanup)] def _fetch_native_df( self, query: t.Union[exp.Expression, str], quote_identifiers: bool = False ) -> DF: + import pandas as pd from snowflake.connector.errors import NotSupportedError self.execute(query, quote_identifiers=quote_identifiers) @@ -283,6 +488,15 @@ def _fetch_native_df( columns = self.cursor._result_set.batches[0].column_names return pd.DataFrame([dict(zip(columns, row)) for row in rows]) + def _native_df_to_pandas_df( + self, + query_or_df: QueryOrDF, + ) -> t.Union[Query, pd.DataFrame]: + if snowpark and isinstance(query_or_df, snowpark.DataFrame): + return query_or_df.to_pandas() + + return super()._native_df_to_pandas_df(query_or_df) + def _get_data_objects( self, schema_name: SchemaName, object_names: t.Optional[t.Set[str]] = None ) -> t.List[DataObject]: @@ -318,6 +532,7 @@ def _get_data_objects( ) .else_(exp.column("TABLE_TYPE")) .as_("type"), + exp.column("CLUSTERING_KEY").as_("clustering_key"), ) .from_(exp.table_("TABLES", 
db="INFORMATION_SCHEMA", catalog=catalog_name)) .where(exp.column("TABLE_SCHEMA").eq(schema.db)) @@ -327,6 +542,10 @@ def _get_data_objects( if object_names: query = query.where(exp.column("TABLE_NAME").isin(*object_names)) + # exclude SNOWPARK_TEMP_TABLE tables that are managed by the Snowpark library and are an implementation + # detail of dealing with DataFrame's + query = query.where(exp.column("TABLE_NAME").like("SNOWPARK_TEMP_TABLE%").not_()) + df = self.fetchdf(query, quote_identifiers=True) if df.empty: return [] @@ -336,17 +555,38 @@ def _get_data_objects( schema=row.schema_name, # type: ignore name=row.name, # type: ignore type=DataObjectType.from_str(row.type), # type: ignore + clustering_key=row.clustering_key, # type: ignore ) - for row in df.itertuples() + # lowercase the column names for cases where Snowflake might return uppercase column names for certain catalogs + for row in df.rename(columns={col: col.lower() for col in df.columns}).itertuples() ] + def _get_grant_expression(self, table: exp.Table) -> exp.Expression: + # Upon execute the catalog in table expressions are properly normalized to handle the case where a user provides + # the default catalog in their connection config. This doesn't though update catalogs in strings like when querying + # the information schema. So we need to manually replace those here. 
+ expression = super()._get_grant_expression(table) + for col_exp in expression.find_all(exp.Column): + if col_exp.this.name == "table_catalog": + and_exp = col_exp.parent + assert and_exp is not None, "Expected column expression to have a parent" + assert and_exp.expression, "Expected AND expression to have an expression" + normalized_catalog = self._normalize_catalog( + exp.table_("placeholder", db="placeholder", catalog=and_exp.expression.this) + ) + and_exp.set( + "expression", + exp.Literal.string(normalized_catalog.args["catalog"].alias_or_name), + ) + return expression + def set_current_catalog(self, catalog: str) -> None: self.execute(exp.Use(this=exp.to_identifier(catalog))) def set_current_schema(self, schema: str) -> None: self.execute(exp.Use(kind="SCHEMA", this=to_schema(schema))) - def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str: + def _normalize_catalog(self, expression: exp.Expression) -> exp.Expression: # note: important to use self._default_catalog instead of the self.default_catalog property # otherwise we get RecursionError: maximum recursion depth exceeded # because it calls get_current_catalog(), which executes a query, which needs the default catalog, which calls get_current_catalog()... etc @@ -379,16 +619,120 @@ def catalog_rewriter(node: exp.Expression) -> exp.Expression: # Snowflake connection config. 
This is because the catalog present on the model gets normalized and quoted to match # the source dialect, which isnt always compatible with Snowflake expression = expression.transform(catalog_rewriter) + return expression + + def _to_sql(self, expression: exp.Expression, quote: bool = True, **kwargs: t.Any) -> str: + return super()._to_sql( + expression=self._normalize_catalog(expression), quote=quote, **kwargs + ) + + def _create_column_comments( + self, + table_name: TableName, + column_comments: t.Dict[str, str], + table_kind: str = "TABLE", + materialized_view: bool = False, + ) -> None: + """ + Reference: https://docs.snowflake.com/en/sql-reference/sql/alter-table-column#syntax + """ + if not column_comments: + return + + table = exp.to_table(table_name) + table_sql = self._to_sql(table) + + list_comment_sql = [] + for column_name, column_comment in column_comments.items(): + column_sql = exp.column(column_name).sql(dialect=self.dialect, identify=True) + + truncated_comment = self._truncate_column_comment(column_comment) + comment_sql = exp.Literal.string(truncated_comment).sql(dialect=self.dialect) + + list_comment_sql.append(f"COLUMN {column_sql} COMMENT {comment_sql}") + + combined_sql = f"ALTER {table_kind} {table_sql} ALTER {', '.join(list_comment_sql)}" + try: + self.execute(combined_sql) + except Exception: + logger.warning( + f"Column comments for table '{table.alias_or_name}' not registered - this may be due to limited permissions.", + exc_info=True, + ) + + def clone_table( + self, + target_table_name: TableName, + source_table_name: TableName, + replace: bool = False, + exists: bool = True, + clone_kwargs: t.Optional[t.Dict[str, t.Any]] = None, + **kwargs: t.Any, + ) -> None: + # The Snowflake adapter should use the transient property to clone transient tables + if physical_properties := kwargs.get("rendered_physical_properties"): + table_type = self._pop_creatable_type_from_properties(physical_properties) + if isinstance(table_type, 
exp.TransientProperty): + kwargs["properties"] = exp.Properties(expressions=[table_type]) + + super().clone_table( + target_table_name, + source_table_name, + replace=replace, + clone_kwargs=clone_kwargs, + **kwargs, + ) + + @t.overload + def _columns_to_types( + self, + query_or_df: DF, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ... + + @t.overload + def _columns_to_types( + self, + query_or_df: Query, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ... + + def _columns_to_types( + self, + query_or_df: QueryOrDF, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: + if not target_columns_to_types and snowpark and isinstance(query_or_df, snowpark.DataFrame): + target_columns_to_types = columns_to_types_from_dtypes( + query_or_df.sample(n=1).to_pandas().dtypes.items() + ) + return target_columns_to_types, list(source_columns or target_columns_to_types) + + return super()._columns_to_types( + query_or_df, target_columns_to_types, source_columns=source_columns + ) + + def close(self) -> t.Any: + if snowpark_session := self._connection_pool.get_attribute(self.SNOWPARK): + snowpark_session.close() # type: ignore + self._connection_pool.set_attribute(self.SNOWPARK, None) + + return super().close() - return super()._to_sql(expression=expression, quote=quote, **kwargs) + def get_table_last_modified_ts(self, table_names: t.List[TableName]) -> t.List[int]: + from sqlmesh.utils.date import to_timestamp - def _build_create_comment_column_exp( - self, table: exp.Table, column_name: str, column_comment: str, table_kind: str = "TABLE" - ) -> 
exp.Comment | str: - table_sql = self._to_sql(table) # so that catalog replacement happens - column_sql = exp.column(column_name).sql(dialect=self.dialect, identify=True) + num_tables = len(table_names) - truncated_comment = self._truncate_column_comment(column_comment) - comment_sql = exp.Literal.string(truncated_comment).sql(dialect=self.dialect) + query = "SELECT LAST_ALTERED FROM INFORMATION_SCHEMA.TABLES WHERE" + for i, table_name in enumerate(table_names): + table = exp.to_table(table_name) + query += f"""(TABLE_NAME = '{table.name}' AND TABLE_SCHEMA = '{table.db}' AND TABLE_CATALOG = '{table.catalog}')""" + if i < num_tables - 1: + query += " OR " - return f"ALTER {table_kind} {table_sql} ALTER COLUMN {column_sql} COMMENT {comment_sql}" + result = self.fetchall(query) + return [to_timestamp(row[0]) for row in result] diff --git a/sqlmesh/core/engine_adapter/spark.py b/sqlmesh/core/engine_adapter/spark.py index 58837a1284..5216b0a329 100644 --- a/sqlmesh/core/engine_adapter/spark.py +++ b/sqlmesh/core/engine_adapter/spark.py @@ -4,13 +4,13 @@ import typing as t from functools import partial -import pandas as pd from sqlglot import exp from sqlmesh.core.dialect import to_schema from sqlmesh.core.engine_adapter.mixins import ( GetCurrentCatalogFromFunctionMixin, HiveMetastoreTablePropertiesMixin, + RowDiffMixin, ) from sqlmesh.core.engine_adapter.shared import ( CatalogSupport, @@ -22,11 +22,11 @@ SourceQuery, set_catalog, ) -from sqlmesh.core.schema_diff import SchemaDiffer -from sqlmesh.utils import classproperty +from sqlmesh.utils import classproperty, get_source_columns_to_types from sqlmesh.utils.errors import SQLMeshError if t.TYPE_CHECKING: + import pandas as pd from pyspark.sql import types as spark_types from sqlmesh.core._typing import SchemaName, TableName @@ -44,27 +44,28 @@ @set_catalog() -class SparkEngineAdapter(GetCurrentCatalogFromFunctionMixin, HiveMetastoreTablePropertiesMixin): +class SparkEngineAdapter( + 
GetCurrentCatalogFromFunctionMixin, HiveMetastoreTablePropertiesMixin, RowDiffMixin +): DIALECT = "spark" SUPPORTS_TRANSACTIONS = False INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.INSERT_OVERWRITE - CATALOG_SUPPORT = CatalogSupport.FULL_SUPPORT - SUPPORTS_ROW_LEVEL_OP = False COMMENT_CREATION_TABLE = CommentCreationTable.IN_SCHEMA_DEF_NO_CTAS COMMENT_CREATION_VIEW = CommentCreationView.IN_SCHEMA_DEF_NO_COMMANDS # Note: Some formats (like Delta and Iceberg) support REPLACE TABLE but since we don't # currently check for storage formats we say we don't support REPLACE TABLE SUPPORTS_REPLACE_TABLE = False QUOTE_IDENTIFIERS_IN_VIEWS = False + SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["DATABASE", "SCHEMA"] WAP_PREFIX = "wap_" BRANCH_PREFIX = "branch_" - SCHEMA_DIFFER = SchemaDiffer( - parameterized_type_defaults={ + SCHEMA_DIFFER_KWARGS = { + "parameterized_type_defaults": { # default decimal precision varies across backends exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(), (0,)], }, - ) + } @property def connection(self) -> SparkSessionConnection: @@ -78,6 +79,14 @@ def spark(self) -> PySparkSession: def _use_spark_session(self) -> bool: return True + @property + def use_serverless(self) -> bool: + return False + + @property + def catalog_support(self) -> CatalogSupport: + return CatalogSupport.FULL_SUPPORT + @classproperty def _sqlglot_to_spark_primitive_mapping(self) -> t.Dict[t.Any, t.Any]: from pyspark.sql import types as spark_types @@ -90,17 +99,17 @@ def _sqlglot_to_spark_primitive_mapping(self) -> t.Dict[t.Any, t.Any]: exp.DataType.Type.FLOAT: spark_types.FloatType, exp.DataType.Type.DOUBLE: spark_types.DoubleType, exp.DataType.Type.DECIMAL: spark_types.DecimalType, - # SQLGlot currently converts VARCHAR and CHAR to Strings exp.DataType.Type.VARCHAR: spark_types.StringType, exp.DataType.Type.CHAR: spark_types.StringType, exp.DataType.Type.TEXT: spark_types.StringType, exp.DataType.Type.BINARY: spark_types.BinaryType, exp.DataType.Type.BOOLEAN: 
spark_types.BooleanType, exp.DataType.Type.DATE: spark_types.DateType, + exp.DataType.Type.TIMESTAMPNTZ: spark_types.TimestampNTZType, exp.DataType.Type.DATETIME: spark_types.TimestampNTZType, exp.DataType.Type.TIMESTAMPLTZ: spark_types.TimestampType, - exp.DataType.Type.TIMESTAMPTZ: spark_types.TimestampType, exp.DataType.Type.TIMESTAMP: spark_types.TimestampType, + exp.DataType.Type.TIMESTAMPTZ: spark_types.TimestampType, } @classproperty @@ -190,9 +199,9 @@ def sqlglot_complex_to_spark_complex(complex_type: exp.DataType) -> spark_types. else partial(sqlglot_complex_to_spark_complex, data_type) ) if is_struct: - expressions.append(spark_types.StructField(col_name, type_func())) + expressions.append(spark_types.StructField(col_name, type_func())) # type: ignore else: - expressions.append(type_func()) + expressions.append(type_func()) # type: ignore klass = cls._sqlglot_to_spark_complex_mapping[complex_type.this] if is_struct: return klass(expressions) @@ -223,65 +232,99 @@ def try_get_pyspark_df(cls, value: t.Any) -> t.Optional[PySparkDataFrame]: @classmethod def try_get_pandas_df(cls, value: t.Any) -> t.Optional[pd.DataFrame]: + import pandas as pd + if isinstance(value, pd.DataFrame): return value return None @t.overload def _columns_to_types( - self, query_or_df: DF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Dict[str, exp.DataType]: ... + self, + query_or_df: DF, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Dict[str, exp.DataType], t.List[str]]: ... @t.overload def _columns_to_types( - self, query_or_df: Query, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Optional[t.Dict[str, exp.DataType]]: ... 
+ self, + query_or_df: Query, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: ... def _columns_to_types( - self, query_or_df: QueryOrDF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None - ) -> t.Optional[t.Dict[str, exp.DataType]]: - if columns_to_types: - return columns_to_types + self, + query_or_df: QueryOrDF, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, + ) -> t.Tuple[t.Optional[t.Dict[str, exp.DataType]], t.Optional[t.List[str]]]: + if target_columns_to_types: + return target_columns_to_types, list(source_columns or target_columns_to_types) if self.is_pyspark_df(query_or_df): from pyspark.sql import DataFrame - return self.spark_to_sqlglot_types(t.cast(DataFrame, query_or_df).schema) - return super()._columns_to_types(query_or_df, columns_to_types) + target_columns_to_types = self.spark_to_sqlglot_types( + t.cast(DataFrame, query_or_df).schema + ) + return target_columns_to_types, list(source_columns or target_columns_to_types) + return super()._columns_to_types( + query_or_df, target_columns_to_types, source_columns=source_columns + ) def _df_to_source_queries( self, df: DF, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, + source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: - if not self._use_spark_session: - return super()._df_to_source_queries(df, columns_to_types, batch_size, target_table) - df = self._ensure_pyspark_df(df, columns_to_types) + df = self._ensure_pyspark_df(df, target_columns_to_types, source_columns=source_columns) def query_factory() -> Query: temp_table = self._get_temp_table(target_table or "spark", table_only=True) 
df.createOrReplaceGlobalTempView(temp_table.sql(dialect=self.dialect)) # type: ignore temp_table.set("db", "global_temp") - return exp.select(*self._casted_columns(columns_to_types)).from_(temp_table) + return exp.select(*self._select_columns(target_columns_to_types)).from_(temp_table) - if self._use_spark_session: - return [SourceQuery(query_factory=query_factory)] - return super()._df_to_source_queries(df, columns_to_types, batch_size, target_table) + return [SourceQuery(query_factory=query_factory)] def _ensure_pyspark_df( - self, generic_df: DF, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None + self, + generic_df: DF, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + source_columns: t.Optional[t.List[str]] = None, ) -> PySparkDataFrame: pyspark_df = self.try_get_pyspark_df(generic_df) - if pyspark_df: - return pyspark_df - df = self.try_get_pandas_df(generic_df) - if df is None: - raise SQLMeshError("Ensure PySpark DF can only be run on a PySpark or Pandas DataFrame") - kwargs = ( - dict(schema=self.sqlglot_to_spark_types(columns_to_types)) if columns_to_types else {} - ) - return self.spark.createDataFrame(df, **kwargs) # type: ignore + if not pyspark_df: + df = self.try_get_pandas_df(generic_df) + if df is None: + raise SQLMeshError( + "Ensure PySpark DF can only be run on a PySpark or Pandas DataFrame" + ) + + if target_columns_to_types: + source_columns_to_types = get_source_columns_to_types( + target_columns_to_types, source_columns + ) + # ensure Pandas dataframe column order matches columns_to_types + df = df[list(source_columns_to_types)] + else: + source_columns_to_types = None + kwargs = ( + dict(schema=self.sqlglot_to_spark_types(source_columns_to_types)) + if source_columns_to_types + else {} + ) + pyspark_df = self.spark.createDataFrame(df, **kwargs) # type: ignore + if target_columns_to_types: + select_columns = self._casted_columns( + target_columns_to_types, source_columns=source_columns + ) + pyspark_df = 
pyspark_df.selectExpr(*[x.sql(self.dialect) for x in select_columns]) # type: ignore + return pyspark_df def _get_temp_table( self, table: TableName, table_only: bool = False, quoted: bool = True @@ -325,21 +368,26 @@ def _get_data_objects( # Therefore just doing except Exception for now. except Exception: return [] - return [ - DataObject( - catalog=self.get_current_catalog(), - # This varies between Spark and Databricks - schema=(row.asDict() if not isinstance(row, dict) else row).get("namespace") - or row["database"], - name=row["tableName"], - type=( - DataObjectType.VIEW - if "Type: VIEW" in row["information"] - else DataObjectType.TABLE - ), + data_objects = [] + catalog = self.get_current_catalog() + for row in results: # type: ignore + row_dict = row.asDict() if not isinstance(row, dict) else row + if row_dict.get("isTemporary"): + continue + schema = row_dict.get("namespace") or row_dict.get("database") + data_objects.append( + DataObject( + catalog=catalog, + schema=schema, + name=row_dict["tableName"], + type=( + DataObjectType.VIEW + if "Type: VIEW" in row_dict["information"] + else DataObjectType.TABLE + ), + ) ) - for row in results # type: ignore - ] + return data_objects def get_current_catalog(self) -> t.Optional[str]: if self._use_spark_session: @@ -349,64 +397,42 @@ def get_current_catalog(self) -> t.Optional[str]: def set_current_catalog(self, catalog_name: str) -> None: self.connection.set_current_catalog(catalog_name) - def get_current_database(self) -> str: + def _get_current_schema(self) -> str: if self._use_spark_session: return self.spark.catalog.currentDatabase() - return self.fetchone(exp.select(exp.func("current_database")))[0] + return self.fetchone(exp.select(exp.func("current_database")))[0] # type: ignore + + def get_data_object( + self, target_name: TableName, safe_to_cache: bool = False + ) -> t.Optional[DataObject]: + target_table = exp.to_table(target_name) + if isinstance(target_table.this, exp.Dot) and 
target_table.this.expression.name.startswith( + f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}" + ): + # Exclude the branch name + target_table.set("this", target_table.this.this) + return super().get_data_object(target_table, safe_to_cache=safe_to_cache) def create_state_table( self, table_name: str, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], primary_key: t.Optional[t.Tuple[str, ...]] = None, ) -> None: self.create_table( table_name, - columns_to_types, + target_columns_to_types, partitioned_by=[exp.column(x) for x in primary_key] if primary_key else None, ) - def create_view( + def _native_df_to_pandas_df( self, - view_name: TableName, query_or_df: QueryOrDF, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, - replace: bool = True, - materialized: bool = False, - table_description: t.Optional[str] = None, - column_descriptions: t.Optional[t.Dict[str, str]] = None, - view_properties: t.Optional[t.Dict[str, exp.Expression]] = None, - **create_kwargs: t.Any, - ) -> None: - """Create a view with a query or dataframe. - - If a dataframe is passed in, it will be converted into a literal values statement. - This should only be done if the dataframe is very small! - - Args: - view_name: The view name. - query_or_df: A query or dataframe. - columns_to_types: Columns to use in the view statement. - replace: Whether or not to replace an existing view - defaults to True. - materialized: Whether or not the view should be materialized - defaults to False. - table_description: Optional table description from MODEL DDL. - column_descriptions: Optional column descriptions from model query. 
- create_kwargs: Additional kwargs to pass into the Create expression - """ - pyspark_df = self.try_get_pyspark_df(query_or_df) - if pyspark_df: - query_or_df = pyspark_df.toPandas() - super().create_view( - view_name, - query_or_df, - columns_to_types, - replace, - materialized, - table_description, - column_descriptions, - view_properties=view_properties, - **create_kwargs, - ) + ) -> t.Union[Query, pd.DataFrame]: + if pyspark_df := self.try_get_pyspark_df(query_or_df): + return pyspark_df.toPandas() + + return super()._native_df_to_pandas_df(query_or_df) def _create_table( self, @@ -414,10 +440,11 @@ def _create_table( expression: t.Optional[exp.Expression], exists: bool = True, replace: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, table_kind: t.Optional[str] = None, + track_rows_processed: bool = True, **kwargs: t.Any, ) -> None: table_name = ( @@ -432,20 +459,23 @@ def _create_table( if wap_id.startswith(f"{self.BRANCH_PREFIX}{self.WAP_PREFIX}"): table_name.set("this", table_name.this.this) - wap_supported = ( - kwargs.get("storage_format") or "" - ).lower() == "iceberg" or self.wap_supported(table_name) - do_dummy_insert = ( - False if not wap_supported or not exists else not self.table_exists(table_name) - ) + do_dummy_insert = False + if self.wap_enabled: + wap_supported = ( + kwargs.get("storage_format") or "" + ).lower() == "iceberg" or self.wap_supported(table_name) + do_dummy_insert = ( + False if not wap_supported or not exists else not self.table_exists(table_name) + ) super()._create_table( table_name_or_schema, expression, exists=exists, replace=replace, - columns_to_types=columns_to_types, + target_columns_to_types=target_columns_to_types, table_description=table_description, column_descriptions=column_descriptions, + 
track_rows_processed=track_rows_processed, **kwargs, ) table_name = ( @@ -509,7 +539,7 @@ def _ensure_fqn(self, table_name: TableName) -> exp.Table: if not table.catalog: table.set("catalog", self.get_current_catalog()) if not table.db: - table.set("db", self.get_current_database()) + table.set("db", self._get_current_schema()) return table def _build_create_comment_column_exp( diff --git a/sqlmesh/core/engine_adapter/trino.py b/sqlmesh/core/engine_adapter/trino.py index 47619a34e8..89470728f2 100644 --- a/sqlmesh/core/engine_adapter/trino.py +++ b/sqlmesh/core/engine_adapter/trino.py @@ -1,18 +1,20 @@ from __future__ import annotations +import contextlib +import re import typing as t from functools import lru_cache -import pandas as pd -from pandas.api.types import is_datetime64_any_dtype # type: ignore from sqlglot import exp from sqlglot.helper import seq_get +from tenacity import retry, wait_fixed, stop_after_attempt, retry_if_result from sqlmesh.core.dialect import schema_, to_schema from sqlmesh.core.engine_adapter.mixins import ( GetCurrentCatalogFromFunctionMixin, HiveMetastoreTablePropertiesMixin, PandasNativeFetchDFSupportMixin, + RowDiffMixin, ) from sqlmesh.core.engine_adapter.shared import ( CatalogSupport, @@ -24,48 +26,83 @@ SourceQuery, set_catalog, ) -from sqlmesh.core.schema_diff import SchemaDiffer +from sqlmesh.utils import get_source_columns_to_types +from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.date import TimeLike if t.TYPE_CHECKING: - from trino.dbapi import Connection as TrinoConnection - - from sqlmesh.core._typing import SchemaName, TableName + from sqlmesh.core._typing import SchemaName, SessionProperties, TableName from sqlmesh.core.engine_adapter._typing import DF, QueryOrDF +CATALOG_TYPES_SUPPORTING_REPLACE_TABLE = {"iceberg", "delta_lake"} + @set_catalog() class TrinoEngineAdapter( PandasNativeFetchDFSupportMixin, HiveMetastoreTablePropertiesMixin, GetCurrentCatalogFromFunctionMixin, + RowDiffMixin, ): DIALECT = 
"trino" INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.INTO_IS_OVERWRITE - CATALOG_SUPPORT = CatalogSupport.FULL_SUPPORT # Trino does technically support transactions but it doesn't work correctly with partition overwrite so we # disable transactions. If we need to get them enabled again then we would need to disable auto commit on the # connector and then figure out how to get insert/overwrite to work correctly without it. SUPPORTS_TRANSACTIONS = False - SUPPORTS_ROW_LEVEL_OP = False CURRENT_CATALOG_EXPRESSION = exp.column("current_catalog") COMMENT_CREATION_TABLE = CommentCreationTable.IN_SCHEMA_DEF_NO_CTAS COMMENT_CREATION_VIEW = CommentCreationView.COMMENT_COMMAND_ONLY SUPPORTS_REPLACE_TABLE = False + SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"] DEFAULT_CATALOG_TYPE = "hive" QUOTE_IDENTIFIERS_IN_VIEWS = False - SCHEMA_DIFFER = SchemaDiffer( - parameterized_type_defaults={ + SUPPORTS_QUERY_EXECUTION_TRACKING = True + SCHEMA_DIFFER_KWARGS = { + "parameterized_type_defaults": { # default decimal precision varies across backends exp.DataType.build("DECIMAL", dialect=DIALECT).this: [(), (0,)], exp.DataType.build("CHAR", dialect=DIALECT).this: [(1,)], exp.DataType.build("TIMESTAMP", dialect=DIALECT).this: [(3,)], }, - ) + } + # some catalogs support microsecond (precision 6) but it has to be specifically enabled (Hive) or just isnt available (Delta / TIMESTAMP WITH TIME ZONE) + # and even if you have a TIMESTAMP(6) the date formatting functions still only support millisecond precision + MAX_TIMESTAMP_PRECISION = 3 + + @property + def schema_location_mapping(self) -> t.Optional[t.Dict[re.Pattern, str]]: + return self._extra_config.get("schema_location_mapping") + + @property + def timestamp_mapping(self) -> t.Optional[t.Dict[exp.DataType, exp.DataType]]: + return self._extra_config.get("timestamp_mapping") + + def _apply_timestamp_mapping( + self, columns_to_types: t.Dict[str, exp.DataType] + ) -> t.Tuple[t.Dict[str, exp.DataType], t.Set[str]]: + """Apply 
custom timestamp mapping to column types. + + Returns: + A tuple of (mapped_columns_to_types, mapped_column_names) where mapped_column_names + contains the names of columns that were found in the mapping. + """ + if not self.timestamp_mapping: + return columns_to_types, set() + + result = {} + mapped_columns: t.Set[str] = set() + for column, column_type in columns_to_types.items(): + if column_type in self.timestamp_mapping: + result[column] = self.timestamp_mapping[column_type] + mapped_columns.add(column) + else: + result[column] = column_type + return result, mapped_columns @property - def connection(self) -> TrinoConnection: - return self.cursor.connection + def catalog_support(self) -> CatalogSupport: + return CatalogSupport.FULL_SUPPORT def set_current_catalog(self, catalog: str) -> None: """Sets the catalog name of the current connection.""" @@ -75,18 +112,78 @@ def set_current_catalog(self, catalog: str) -> None: def get_catalog_type(self, catalog: t.Optional[str]) -> str: row: t.Tuple = tuple() if catalog: - row = self.fetchone( - f"select connector_name from system.metadata.catalogs where catalog_name='{catalog}'" + if catalog_type_override := self._catalog_type_overrides.get(catalog): + return catalog_type_override + row = ( + self.fetchone( + f"select connector_name from system.metadata.catalogs where catalog_name='{catalog}'" + ) + or () ) return seq_get(row, 0) or self.DEFAULT_CATALOG_TYPE + @contextlib.contextmanager + def session(self, properties: SessionProperties) -> t.Iterator[None]: + authorization = properties.get("authorization") + if not authorization: + yield + return + + if not isinstance(authorization, exp.Expression): + authorization = exp.Literal.string(authorization) + + if not authorization.is_string: + raise SQLMeshError( + "Invalid value for `session_properties.authorization`. Must be a string literal." 
+ ) + + authorization_sql = authorization.sql(dialect=self.dialect) + + self.execute(f"SET SESSION AUTHORIZATION {authorization_sql}") + try: + yield + finally: + self.execute("RESET SESSION AUTHORIZATION") + + def replace_query( + self, + table_name: TableName, + query_or_df: QueryOrDF, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + table_description: t.Optional[str] = None, + column_descriptions: t.Optional[t.Dict[str, str]] = None, + source_columns: t.Optional[t.List[str]] = None, + supports_replace_table_override: t.Optional[bool] = None, + **kwargs: t.Any, + ) -> None: + catalog_type = self.get_catalog_type_from_table(table_name) + # User may have a custom catalog type name so we are assuming they keep the catalog type still in the name + # Ex: `acme_iceberg` would be identified as an iceberg catalog and therefore supports replace table + supports_replace_table_override = None + for replace_table_catalog_type in CATALOG_TYPES_SUPPORTING_REPLACE_TABLE: + if replace_table_catalog_type in catalog_type: + supports_replace_table_override = True + break + + super().replace_query( + table_name=table_name, + query_or_df=query_or_df, + target_columns_to_types=target_columns_to_types, + table_description=table_description, + column_descriptions=column_descriptions, + source_columns=source_columns, + supports_replace_table_override=supports_replace_table_override, + **kwargs, + ) + def _insert_overwrite_by_condition( self, table_name: TableName, source_queries: t.List[SourceQuery], - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, where: t.Optional[exp.Condition] = None, insert_overwrite_strategy_override: t.Optional[InsertOverwriteStrategy] = None, + **kwargs: t.Any, ) -> None: catalog = exp.to_table(table_name).catalog or self.get_current_catalog() @@ -96,14 +193,14 @@ def _insert_overwrite_by_condition( # "Session property 
'catalog.insert_existing_partitions_behavior' does not exist" self.execute(f"SET SESSION {catalog}.insert_existing_partitions_behavior='OVERWRITE'") super()._insert_overwrite_by_condition( - table_name, source_queries, columns_to_types, where + table_name, source_queries, target_columns_to_types, where ) self.execute(f"SET SESSION {catalog}.insert_existing_partitions_behavior='APPEND'") else: super()._insert_overwrite_by_condition( table_name, source_queries, - columns_to_types, + target_columns_to_types, where, insert_overwrite_strategy_override=InsertOverwriteStrategy.DELETE_INSERT, ) @@ -181,35 +278,48 @@ def _get_data_objects( def _df_to_source_queries( self, df: DF, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], batch_size: int, target_table: TableName, + source_columns: t.Optional[t.List[str]] = None, ) -> t.List[SourceQuery]: + import pandas as pd + from pandas.api.types import is_datetime64_any_dtype # type: ignore + assert isinstance(df, pd.DataFrame) + source_columns_to_types = get_source_columns_to_types( + target_columns_to_types, source_columns + ) # Trino does not accept timestamps in ISOFORMAT that include the "T". `execution_time` is stored in # Pandas with that format, so we convert the column to a string with the proper format and CAST to # timestamp in Trino. 
- for column, kind in (columns_to_types or {}).items(): + for column, kind in source_columns_to_types.items(): dtype = df.dtypes[column] if is_datetime64_any_dtype(dtype) and getattr(dtype, "tz", None) is not None: df[column] = pd.to_datetime(df[column]).map(lambda x: x.isoformat(" ")) - return super()._df_to_source_queries(df, columns_to_types, batch_size, target_table) + return super()._df_to_source_queries( + df, target_columns_to_types, batch_size, target_table, source_columns=source_columns + ) def _build_schema_exp( self, table: exp.Table, - columns_to_types: t.Dict[str, exp.DataType], + target_columns_to_types: t.Dict[str, exp.DataType], column_descriptions: t.Optional[t.Dict[str, str]] = None, expressions: t.Optional[t.List[exp.PrimaryKey]] = None, is_view: bool = False, + materialized: bool = False, ) -> exp.Schema: - if self.current_catalog_type == "delta_lake": - columns_to_types = self._to_delta_ts(columns_to_types) + target_columns_to_types, mapped_columns = self._apply_timestamp_mapping( + target_columns_to_types + ) + if "delta_lake" in self.get_catalog_type_from_table(table): + target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns) return super()._build_schema_exp( - table, columns_to_types, column_descriptions, expressions, is_view + table, target_columns_to_types, column_descriptions, expressions, is_view ) def _scd_type_2( @@ -219,19 +329,28 @@ def _scd_type_2( unique_key: t.Sequence[exp.Expression], valid_from_col: exp.Column, valid_to_col: exp.Column, - execution_time: TimeLike, + execution_time: t.Union[TimeLike, exp.Column], invalidate_hard_deletes: bool = True, updated_at_col: t.Optional[exp.Column] = None, - check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Column]]] = None, + check_columns: t.Optional[t.Union[exp.Star, t.Sequence[exp.Expression]]] = None, updated_at_as_valid_from: bool = False, execution_time_as_valid_from: bool = False, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = 
None, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, table_description: t.Optional[str] = None, column_descriptions: t.Optional[t.Dict[str, str]] = None, truncate: bool = False, + source_columns: t.Optional[t.List[str]] = None, + **kwargs: t.Any, ) -> None: - if columns_to_types and self.current_catalog_type == "delta_lake": - columns_to_types = self._to_delta_ts(columns_to_types) + mapped_columns: t.Set[str] = set() + if target_columns_to_types: + target_columns_to_types, mapped_columns = self._apply_timestamp_mapping( + target_columns_to_types + ) + if target_columns_to_types and "delta_lake" in self.get_catalog_type_from_table( + target_table + ): + target_columns_to_types = self._to_delta_ts(target_columns_to_types, mapped_columns) return super()._scd_type_2( target_table, @@ -245,10 +364,12 @@ def _scd_type_2( check_columns, updated_at_as_valid_from, execution_time_as_valid_from, - columns_to_types, + target_columns_to_types, table_description, column_descriptions, truncate, + source_columns, + **kwargs, ) # delta_lake only supports two timestamp data types. 
This method converts other @@ -259,19 +380,100 @@ def _scd_type_2( # - `timestamp(3) with time zone` for timezone-aware # https://trino.io/docs/current/connector/delta-lake.html#delta-lake-to-trino-type-mapping def _to_delta_ts( - self, columns_to_types: t.Dict[str, exp.DataType] + self, + columns_to_types: t.Dict[str, exp.DataType], + skip_columns: t.Optional[t.Set[str]] = None, ) -> t.Dict[str, exp.DataType]: ts6 = exp.DataType.build("timestamp(6)") ts3_tz = exp.DataType.build("timestamp(3) with time zone") + skip = skip_columns or set() delta_columns_to_types = { - k: ts6 if v.is_type(exp.DataType.Type.TIMESTAMP) else v + k: ts6 if k not in skip and v.is_type(exp.DataType.Type.TIMESTAMP) else v for k, v in columns_to_types.items() } delta_columns_to_types = { - k: ts3_tz if v.is_type(exp.DataType.Type.TIMESTAMPTZ) else v + k: ts3_tz if k not in skip and v.is_type(exp.DataType.Type.TIMESTAMPTZ) else v for k, v in delta_columns_to_types.items() } return delta_columns_to_types + + @retry(wait=wait_fixed(1), stop=stop_after_attempt(10), retry=retry_if_result(lambda v: not v)) + def _block_until_table_exists(self, table_name: TableName) -> bool: + return self.table_exists(table_name) + + def _create_schema( + self, + schema_name: SchemaName, + ignore_if_exists: bool, + warn_on_error: bool, + properties: t.List[exp.Expression], + kind: str, + ) -> None: + if mapped_location := self._schema_location(schema_name): + properties.append(exp.LocationProperty(this=exp.Literal.string(mapped_location))) + + return super()._create_schema( + schema_name=schema_name, + ignore_if_exists=ignore_if_exists, + warn_on_error=warn_on_error, + properties=properties, + kind=kind, + ) + + def _create_table( + self, + table_name_or_schema: t.Union[exp.Schema, TableName], + expression: t.Optional[exp.Expression], + exists: bool = True, + replace: bool = False, + target_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + table_description: t.Optional[str] = None, + 
column_descriptions: t.Optional[t.Dict[str, str]] = None, + table_kind: t.Optional[str] = None, + track_rows_processed: bool = True, + **kwargs: t.Any, + ) -> None: + super()._create_table( + table_name_or_schema=table_name_or_schema, + expression=expression, + exists=exists, + replace=replace, + target_columns_to_types=target_columns_to_types, + table_description=table_description, + column_descriptions=column_descriptions, + table_kind=table_kind, + track_rows_processed=track_rows_processed, + **kwargs, + ) + + # extract the table name + if isinstance(table_name_or_schema, exp.Schema): + table_name = table_name_or_schema.this + assert isinstance(table_name, exp.Table) + else: + table_name = table_name_or_schema + + if "hive" in self.get_catalog_type_from_table(table_name): + # the Trino Hive connector can take a few seconds for metadata changes to propagate to all internal threads + # (even if metadata TTL is set to 0s) + # Blocking until the table shows up means that subsequent code expecting it to exist immediately will not fail + self._block_until_table_exists(table_name) + + def _schema_location(self, schema_name: SchemaName) -> t.Optional[str]: + if mapping := self.schema_location_mapping: + schema = to_schema(schema_name) + match_key = schema.db + + # only consider the catalog if it is present + if schema.catalog: + match_key = f"{schema.catalog}.{match_key}" + + for k, v in mapping.items(): + if re.match(k, match_key): + return v.replace("@{schema_name}", schema.db).replace( + "@{catalog_name}", schema.catalog + ) + return None diff --git a/sqlmesh/core/environment.py b/sqlmesh/core/environment.py index e28956bcad..4a1f417468 100644 --- a/sqlmesh/core/environment.py +++ b/sqlmesh/core/environment.py @@ -8,12 +8,19 @@ from sqlmesh.core import constants as c from sqlmesh.core.config import EnvironmentSuffixTarget -from sqlmesh.core.snapshot import SnapshotId, SnapshotTableInfo +from sqlmesh.core.engine_adapter.base import EngineAdapter +from 
sqlmesh.core.macros import RuntimeStage +from sqlmesh.core.renderer import render_statements +from sqlmesh.core.snapshot import SnapshotId, SnapshotTableInfo, Snapshot from sqlmesh.utils import word_characters_only -from sqlmesh.utils.date import TimeLike -from sqlmesh.utils.pydantic import PydanticModel, field_validator +from sqlmesh.utils.date import TimeLike, now_timestamp +from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.utils.jinja import JinjaMacroRegistry +from sqlmesh.utils.metaprogramming import Executable +from sqlmesh.utils.pydantic import PydanticModel, field_validator, ValidationInfo T = t.TypeVar("T", bound="EnvironmentNamingInfo") +PydanticType = t.TypeVar("PydanticType", bound="PydanticModel") class EnvironmentNamingInfo(PydanticModel): @@ -26,22 +33,31 @@ class EnvironmentNamingInfo(PydanticModel): catalog_name_override: The name of the catalog to use for this environment if an override was provided normalize_name: Indicates whether the environment's name will be normalized. For example, if it's `dev`, then it will become `DEV` when targeting Snowflake. + gateway_managed: Determines whether the virtual layer's views are created by the model-specific + gateways, otherwise the default gateway is used. Default: False. 
""" name: str = c.PROD suffix_target: EnvironmentSuffixTarget = Field(default=EnvironmentSuffixTarget.SCHEMA) catalog_name_override: t.Optional[str] = None normalize_name: bool = True + gateway_managed: bool = False + + @property + def is_dev(self) -> bool: + return self.name.lower() != c.PROD @field_validator("name", mode="before") @classmethod def _sanitize_name(cls, v: str) -> str: return word_characters_only(v).lower() - @field_validator("normalize_name", mode="before") + @field_validator("normalize_name", "gateway_managed", mode="before") @classmethod - def _validate_normalize_name(cls, v: t.Any) -> bool: - return True if v is None else bool(v) + def _validate_boolean_field(cls, v: t.Any, info: ValidationInfo) -> bool: + if v is None: + return info.field_name == "normalize_name" + return bool(v) @t.overload @classmethod @@ -84,50 +100,92 @@ def from_environment_catalog_mapping( return cls(**construction_kwargs) -class Environment(EnvironmentNamingInfo): - """Represents an isolated environment. - - Environments are isolated workspaces that hold pointers to physical tables. +class EnvironmentSummary(PydanticModel): + """Represents summary information of an isolated environment. Args: - snapshots: The snapshots that are part of this environment. + name: The name of the environment. start_at: The start time of the environment. end_at: The end time of the environment. plan_id: The ID of the plan that last updated this environment. previous_plan_id: The ID of the previous plan that updated this environment. expiration_ts: The timestamp when this environment will expire. finalized_ts: The timestamp when this environment was finalized. - promoted_snapshot_ids: The IDs of the snapshots that are promoted in this environment - (i.e. for which the views are created). If not specified, all snapshots are promoted. - previous_finalized_snapshots: Snapshots that were part of this environment last time it was finalized. 
""" - snapshots: t.List[SnapshotTableInfo] + name: str start_at: TimeLike end_at: t.Optional[TimeLike] = None plan_id: str previous_plan_id: t.Optional[str] = None expiration_ts: t.Optional[int] = None finalized_ts: t.Optional[int] = None - promoted_snapshot_ids: t.Optional[t.List[SnapshotId]] = None - previous_finalized_snapshots: t.Optional[t.List[SnapshotTableInfo]] = None - @field_validator("snapshots", "previous_finalized_snapshots", mode="before") + @property + def expired(self) -> bool: + return self.expiration_ts is not None and self.expiration_ts <= now_timestamp() + + +class Environment(EnvironmentNamingInfo, EnvironmentSummary): + """Represents an isolated environment. + + Environments are isolated workspaces that hold pointers to physical tables. + + Args: + snapshots: The snapshots that are part of this environment. + promoted_snapshot_ids: The IDs of the snapshots that are promoted in this environment + (i.e. for which the views are created). If not specified, all snapshots are promoted. + previous_finalized_snapshots: Snapshots that were part of this environment last time it was finalized. + requirements: A mapping of library versions for all the snapshots in this environment. 
+ """ + + snapshots_: t.List[t.Any] = Field(alias="snapshots") + promoted_snapshot_ids_: t.Optional[t.List[t.Any]] = Field( + default=None, alias="promoted_snapshot_ids" + ) + previous_finalized_snapshots_: t.Optional[t.List[t.Any]] = Field( + default=None, alias="previous_finalized_snapshots" + ) + requirements: t.Dict[str, str] = {} + + @field_validator("snapshots_", "previous_finalized_snapshots_", mode="before") @classmethod - def _convert_snapshots( - cls, v: str | t.List[SnapshotTableInfo] | None - ) -> t.List[SnapshotTableInfo] | None: + def _load_snapshots(cls, v: str | t.List[t.Any] | None) -> t.List[t.Any] | None: if isinstance(v, str): - return [SnapshotTableInfo.parse_obj(obj) for obj in json.loads(v)] + return json.loads(v) + if v and not isinstance(next(iter(v)), (dict, SnapshotTableInfo)): + raise ValueError("Must be a list of SnapshotTableInfo dicts or objects") return v - @field_validator("promoted_snapshot_ids", mode="before") + @field_validator("promoted_snapshot_ids_", mode="before") @classmethod - def _convert_snapshot_ids(cls, v: str | t.List[SnapshotId]) -> t.List[SnapshotId]: + def _load_snapshot_ids(cls, v: str | t.List[t.Any] | None) -> t.List[t.Any] | None: if isinstance(v, str): - return [SnapshotId.parse_obj(obj) for obj in json.loads(v)] + return json.loads(v) + if v and not isinstance(next(iter(v)), (dict, SnapshotId)): + raise ValueError("Must be a list of SnapshotId dicts or objects") return v + @field_validator("requirements", mode="before") + def _load_requirements(cls, v: t.Any) -> t.Any: + if isinstance(v, str): + v = json.loads(v) + return v or {} + + @property + def snapshots(self) -> t.List[SnapshotTableInfo]: + return self._convert_list_to_models_and_store("snapshots_", SnapshotTableInfo) or [] + + def snapshot_dicts(self) -> t.List[dict]: + return self._convert_list_to_dicts(self.snapshots_) + + @property + def promoted_snapshot_ids(self) -> t.Optional[t.List[SnapshotId]]: + return 
self._convert_list_to_models_and_store("promoted_snapshot_ids_", SnapshotId) + + def promoted_snapshot_id_dicts(self) -> t.List[dict]: + return self._convert_list_to_dicts(self.promoted_snapshot_ids_) + @property def promoted_snapshots(self) -> t.List[SnapshotTableInfo]: if self.promoted_snapshot_ids is None: @@ -136,6 +194,15 @@ def promoted_snapshots(self) -> t.List[SnapshotTableInfo]: promoted_snapshot_ids = set(self.promoted_snapshot_ids) return [s for s in self.snapshots if s.snapshot_id in promoted_snapshot_ids] + @property + def previous_finalized_snapshots(self) -> t.Optional[t.List[SnapshotTableInfo]]: + return self._convert_list_to_models_and_store( + "previous_finalized_snapshots_", SnapshotTableInfo + ) + + def previous_finalized_snapshot_dicts(self) -> t.List[dict]: + return self._convert_list_to_dicts(self.previous_finalized_snapshots_) + @property def finalized_or_current_snapshots(self) -> t.List[SnapshotTableInfo]: return ( @@ -151,4 +218,129 @@ def naming_info(self) -> EnvironmentNamingInfo: suffix_target=self.suffix_target, catalog_name_override=self.catalog_name_override, normalize_name=self.normalize_name, + gateway_managed=self.gateway_managed, + ) + + @property + def summary(self) -> EnvironmentSummary: + return EnvironmentSummary( + name=self.name, + start_at=self.start_at, + end_at=self.end_at, + plan_id=self.plan_id, + previous_plan_id=self.previous_plan_id, + expiration_ts=self.expiration_ts, + finalized_ts=self.finalized_ts, + ) + + def can_partially_promote(self, existing_environment: Environment) -> bool: + """Returns True if the existing environment can be partially promoted to the current environment. + + Partial promotion means that we don't need to re-create views for snapshots that are already promoted in the + target environment. 
+ """ + return ( + bool(existing_environment.finalized_ts) + and not existing_environment.expired + and existing_environment.gateway_managed == self.gateway_managed + and existing_environment.name == c.PROD + ) + + def _convert_list_to_models_and_store( + self, field: str, type_: t.Type[PydanticType] + ) -> t.Optional[t.List[PydanticType]]: + value = getattr(self, field) + if value and not isinstance(value[0], type_): + value = [type_.parse_obj(obj) for obj in value] + setattr(self, field, value) + return value + + def _convert_list_to_dicts(self, value: t.Optional[t.List[t.Any]]) -> t.List[dict]: + if not value: + return [] + return value if isinstance(value[0], dict) else [v.dict() for v in value] + + +class EnvironmentStatements(PydanticModel): + before_all: t.List[str] + after_all: t.List[str] + python_env: t.Dict[str, Executable] + jinja_macros: t.Optional[JinjaMacroRegistry] = None + project: t.Optional[str] = None + + def render_before_all( + self, + dialect: str, + default_catalog: t.Optional[str] = None, + **render_kwargs: t.Any, + ) -> t.List[str]: + return self.render(RuntimeStage.BEFORE_ALL, dialect, default_catalog, **render_kwargs) + + def render_after_all( + self, + dialect: str, + default_catalog: t.Optional[str] = None, + **render_kwargs: t.Any, + ) -> t.List[str]: + return self.render(RuntimeStage.AFTER_ALL, dialect, default_catalog, **render_kwargs) + + def render( + self, + runtime_stage: RuntimeStage, + dialect: str, + default_catalog: t.Optional[str] = None, + **render_kwargs: t.Any, + ) -> t.List[str]: + return render_statements( + statements=getattr(self, runtime_stage.value), + dialect=dialect, + default_catalog=default_catalog, + python_env=self.python_env, + jinja_macros=self.jinja_macros, + runtime_stage=runtime_stage, + **render_kwargs, + ) + + +def execute_environment_statements( + adapter: EngineAdapter, + environment_statements: t.List[EnvironmentStatements], + runtime_stage: RuntimeStage, + environment_naming_info: 
EnvironmentNamingInfo, + default_catalog: t.Optional[str] = None, + snapshots: t.Optional[t.Dict[str, Snapshot]] = None, + start: t.Optional[TimeLike] = None, + end: t.Optional[TimeLike] = None, + execution_time: t.Optional[TimeLike] = None, + selected_models: t.Optional[t.Set[str]] = None, +) -> None: + try: + rendered_expressions = [ + expr + for statements in environment_statements + for expr in statements.render( + runtime_stage=runtime_stage, + dialect=adapter.dialect, + default_catalog=default_catalog, + snapshots=snapshots, + start=start, + end=end, + execution_time=execution_time, + environment_naming_info=environment_naming_info, + engine_adapter=adapter, + selected_models=selected_models, + ) + ] + except Exception as e: + raise SQLMeshError( + f"An error occurred during rendering of the '{runtime_stage.value}' statements:\n\n{e}" ) + if rendered_expressions: + with adapter.transaction(): + for expr in rendered_expressions: + try: + adapter.execute(expr) + except Exception as e: + raise SQLMeshError( + f"An error occurred during execution of the following '{runtime_stage.value}' statement:\n\n{expr}\n\n{e}" + ) diff --git a/sqlmesh/core/janitor.py b/sqlmesh/core/janitor.py new file mode 100644 index 0000000000..e050d6ef6c --- /dev/null +++ b/sqlmesh/core/janitor.py @@ -0,0 +1,181 @@ +from __future__ import annotations + +import typing as t + +from sqlglot import exp + +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.console import Console +from sqlmesh.core.dialect import schema_ +from sqlmesh.core.environment import Environment +from sqlmesh.core.snapshot import SnapshotEvaluator +from sqlmesh.core.state_sync import StateSync +from sqlmesh.core.state_sync.common import ( + logger, + iter_expired_snapshot_batches, + RowBoundary, + ExpiredBatchRange, +) +from sqlmesh.utils.errors import SQLMeshError + + +def cleanup_expired_views( + default_adapter: EngineAdapter, + engine_adapters: t.Dict[str, EngineAdapter], + environments: 
t.List[Environment], + warn_on_delete_failure: bool = False, + console: t.Optional[Console] = None, +) -> None: + expired_schema_or_catalog_environments = [ + environment + for environment in environments + if environment.suffix_target.is_schema or environment.suffix_target.is_catalog + ] + expired_table_environments = [ + environment for environment in environments if environment.suffix_target.is_table + ] + + # We have to use the corresponding adapter if the virtual layer is gateway managed + def get_adapter(gateway_managed: bool, gateway: t.Optional[str] = None) -> EngineAdapter: + if gateway_managed and gateway: + return engine_adapters.get(gateway, default_adapter) + return default_adapter + + catalogs_to_drop: t.Set[t.Tuple[EngineAdapter, str]] = set() + schemas_to_drop: t.Set[t.Tuple[EngineAdapter, exp.Table]] = set() + + # Collect schemas and catalogs to drop + for engine_adapter, expired_catalog, expired_schema, suffix_target in { + ( + (engine_adapter := get_adapter(environment.gateway_managed, snapshot.model_gateway)), + snapshot.qualified_view_name.catalog_for_environment( + environment.naming_info, dialect=engine_adapter.dialect + ), + snapshot.qualified_view_name.schema_for_environment( + environment.naming_info, dialect=engine_adapter.dialect + ), + environment.suffix_target, + ) + for environment in expired_schema_or_catalog_environments + for snapshot in environment.snapshots + if snapshot.is_model and not snapshot.is_symbolic + }: + if suffix_target.is_catalog: + if expired_catalog: + catalogs_to_drop.add((engine_adapter, expired_catalog)) + else: + schema = schema_(expired_schema, expired_catalog) + schemas_to_drop.add((engine_adapter, schema)) + + # Drop the views for the expired environments + for engine_adapter, expired_view in { + ( + (engine_adapter := get_adapter(environment.gateway_managed, snapshot.model_gateway)), + snapshot.qualified_view_name.for_environment( + environment.naming_info, dialect=engine_adapter.dialect + ), + ) + for 
environment in expired_table_environments + for snapshot in environment.snapshots + if snapshot.is_model and not snapshot.is_symbolic + }: + try: + engine_adapter.drop_view(expired_view, ignore_if_not_exists=True) + if console: + console.update_cleanup_progress(expired_view) + except Exception as e: + message = f"Failed to drop the expired environment view '{expired_view}': {e}" + if warn_on_delete_failure: + logger.warning(message) + else: + raise SQLMeshError(message) from e + + # Drop the schemas for the expired environments + for engine_adapter, schema in schemas_to_drop: + try: + engine_adapter.drop_schema( + schema, + ignore_if_not_exists=True, + cascade=True, + ) + if console: + console.update_cleanup_progress(schema.sql(dialect=engine_adapter.dialect)) + except Exception as e: + message = f"Failed to drop the expired environment schema '{schema}': {e}" + if warn_on_delete_failure: + logger.warning(message) + else: + raise SQLMeshError(message) from e + + # Drop any catalogs that were associated with a snapshot where the engine adapter supports dropping catalogs + # catalogs_to_drop is only populated when environment_suffix_target is set to 'catalog' + for engine_adapter, catalog in catalogs_to_drop: + if engine_adapter.SUPPORTS_CREATE_DROP_CATALOG: + try: + engine_adapter.drop_catalog(catalog) + if console: + console.update_cleanup_progress(catalog) + except Exception as e: + message = f"Failed to drop the expired environment catalog '{catalog}': {e}" + if warn_on_delete_failure: + logger.warning(message) + else: + raise SQLMeshError(message) from e + + +def delete_expired_snapshots( + state_sync: StateSync, + snapshot_evaluator: SnapshotEvaluator, + *, + current_ts: int, + ignore_ttl: bool = False, + batch_size: t.Optional[int] = None, + console: t.Optional[Console] = None, +) -> None: + """Delete all expired snapshots in batches. 
+ + This helper function encapsulates the logic for deleting expired snapshots in batches, + eliminating code duplication across different use cases. + + Args: + state_sync: StateSync instance to query and delete expired snapshots from. + snapshot_evaluator: SnapshotEvaluator instance to clean up tables associated with snapshots. + current_ts: Timestamp used to evaluate expiration. + ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced). + batch_size: Maximum number of snapshots to fetch per batch. + console: Optional console for reporting progress. + + Returns: + The total number of deleted expired snapshots. + """ + num_expired_snapshots = 0 + for batch in iter_expired_snapshot_batches( + state_reader=state_sync, + current_ts=current_ts, + ignore_ttl=ignore_ttl, + batch_size=batch_size, + ): + end_info = ( + f"updated_ts={batch.batch_range.end.updated_ts}" + if isinstance(batch.batch_range.end, RowBoundary) + else f"limit={batch.batch_range.end.batch_size}" + ) + logger.info( + "Processing batch of size %s with end %s", + len(batch.expired_snapshot_ids), + end_info, + ) + snapshot_evaluator.cleanup( + target_snapshots=batch.cleanup_tasks, + on_complete=console.update_cleanup_progress if console else None, + ) + state_sync.delete_expired_snapshots( + batch_range=ExpiredBatchRange( + start=RowBoundary.lowest_boundary(), + end=batch.batch_range.end, + ), + ignore_ttl=ignore_ttl, + ) + logger.info("Cleaned up expired snapshots batch") + num_expired_snapshots += len(batch.expired_snapshot_ids) + logger.info("Cleaned up %s expired snapshots", num_expired_snapshots) diff --git a/sqlmesh/core/lineage.py b/sqlmesh/core/lineage.py index e133fca872..777a2a7d9a 100644 --- a/sqlmesh/core/lineage.py +++ b/sqlmesh/core/lineage.py @@ -66,10 +66,13 @@ def lineage( scope=scope, trim_selects=trim_selects, dialect=model.dialect, + copy=False, ) -def column_dependencies(context: Context, model_name: str, column: str) -> t.Dict[str, t.Set[str]]: +def 
column_dependencies( + context: Context, model_name: str, column: str | exp.Column +) -> t.Dict[str, t.Set[str]]: model = context.get_model(model_name) parents = defaultdict(set) @@ -86,7 +89,9 @@ def column_dependencies(context: Context, model_name: str, column: str) -> t.Dic return dict(parents) -def column_description(context: Context, model_name: str, column: str) -> t.Optional[str]: +def column_description( + context: Context, model_name: str, column: str, quote_column: bool = False +) -> t.Optional[str]: """Returns a column's description, inferring if needed.""" model = context.get_model(model_name) @@ -96,7 +101,7 @@ def column_description(context: Context, model_name: str, column: str) -> t.Opti if column in model.column_descriptions: return model.column_descriptions[column] - dependencies = column_dependencies(context, model_name, column) + dependencies = column_dependencies(context, model_name, exp.column(column, quoted=quote_column)) if len(dependencies) != 1: return None diff --git a/sqlmesh/schedulers/airflow/hooks/__init__.py b/sqlmesh/core/linter/__init__.py similarity index 100% rename from sqlmesh/schedulers/airflow/hooks/__init__.py rename to sqlmesh/core/linter/__init__.py diff --git a/sqlmesh/core/linter/definition.py b/sqlmesh/core/linter/definition.py new file mode 100644 index 0000000000..7dc64bbf95 --- /dev/null +++ b/sqlmesh/core/linter/definition.py @@ -0,0 +1,162 @@ +from __future__ import annotations + +import operator as op +import typing as t +from collections.abc import Iterator, Iterable, Set, Mapping, Callable +from functools import reduce + +from sqlmesh.core.config.linter import LinterConfig +from sqlmesh.core.console import LinterConsole, get_console +from sqlmesh.core.linter.rule import Rule, RuleViolation, Range, Fix +from sqlmesh.core.model import Model +from sqlmesh.utils.errors import raise_config_error + +if t.TYPE_CHECKING: + from sqlmesh.core.context import GenericContext + + +def select_rules(all_rules: RuleSet, 
rule_names: t.Set[str]) -> RuleSet: + if "all" in rule_names: + return all_rules + + rules = set() + for rule_name in rule_names: + if rule_name not in all_rules: + raise_config_error(f"Rule {rule_name} could not be found") + + rules.add(all_rules[rule_name]) + + return RuleSet(rules) + + +class Linter: + def __init__( + self, enabled: bool, all_rules: RuleSet, rules: RuleSet, warn_rules: RuleSet + ) -> None: + self.enabled = enabled + self.all_rules = all_rules + self.rules = rules + self.warn_rules = warn_rules + + if overlapping := rules.intersection(warn_rules): + overlapping_rules = ", ".join(rule for rule in overlapping) + raise_config_error( + f"Rules cannot simultaneously warn and raise an error: [{overlapping_rules}]" + ) + + @classmethod + def from_rules(cls, all_rules: RuleSet, config: LinterConfig) -> Linter: + ignored_rules = select_rules(all_rules, config.ignored_rules) + included_rules = all_rules.difference(ignored_rules) + + rules = select_rules(included_rules, config.rules) + warn_rules = select_rules(included_rules, config.warn_rules) + + return Linter(config.enabled, all_rules, rules, warn_rules) + + def lint_model( + self, model: Model, context: GenericContext, console: LinterConsole = get_console() + ) -> t.Tuple[bool, t.List[AnnotatedRuleViolation]]: + if not self.enabled: + return False, [] + + ignored_rules = select_rules(self.all_rules, model.ignored_rules) + + rules = self.rules.difference(ignored_rules) + warn_rules = self.warn_rules.difference(ignored_rules) + + error_violations = rules.check_model(model, context) + warn_violations = warn_rules.check_model(model, context) + + all_violations: t.List[AnnotatedRuleViolation] = [ + AnnotatedRuleViolation( + rule=violation.rule, + violation_msg=violation.violation_msg, + model=model, + violation_type="error", + violation_range=violation.violation_range, + fixes=violation.fixes, + ) + for violation in error_violations + ] + [ + AnnotatedRuleViolation( + rule=violation.rule, + 
violation_msg=violation.violation_msg, + model=model, + violation_type="warning", + violation_range=violation.violation_range, + fixes=violation.fixes, + ) + for violation in warn_violations + ] + + if warn_violations: + console.show_linter_violations(warn_violations, model) + if error_violations: + console.show_linter_violations(error_violations, model, is_error=True) + return True, all_violations + + return False, all_violations + + +class RuleSet(Mapping[str, type[Rule]]): + def __init__(self, rules: Iterable[type[Rule]] = ()) -> None: + self._underlying = {rule.name: rule for rule in rules} + + def check_model(self, model: Model, context: GenericContext) -> t.List[RuleViolation]: + violations = [] + + for rule in self._underlying.values(): + violation = rule(context).check_model(model) + if isinstance(violation, RuleViolation): + violation = [violation] + if violation: + violations.extend(violation) + + return violations + + def __iter__(self) -> Iterator[str]: + return iter(self._underlying) + + def __len__(self) -> int: + return len(self._underlying) + + def __getitem__(self, rule: str | type[Rule]) -> type[Rule]: + key = rule if isinstance(rule, str) else rule.name + return self._underlying[key] + + def __op( + self, + op: Callable[[Set[type[Rule]], Set[type[Rule]]], Set[type[Rule]]], + other: RuleSet, + /, + ) -> RuleSet: + rules = set() + for rule in op(set(self.values()), set(other.values())): + rules.add(other[rule] if rule in other else self[rule]) + + return RuleSet(rules) + + def union(self, *others: RuleSet) -> RuleSet: + return reduce(lambda lhs, rhs: lhs.__op(op.or_, rhs), (self, *others)) + + def intersection(self, *others: RuleSet) -> RuleSet: + return reduce(lambda lhs, rhs: lhs.__op(op.and_, rhs), (self, *others)) + + def difference(self, *others: RuleSet) -> RuleSet: + return reduce(lambda lhs, rhs: lhs.__op(op.sub, rhs), (self, *others)) + + +class AnnotatedRuleViolation(RuleViolation): + def __init__( + self, + rule: Rule, + violation_msg: 
str, + model: Model, + violation_type: t.Literal["error", "warning"], + violation_range: t.Optional[Range] = None, + fixes: t.Optional[t.List[Fix]] = None, + ) -> None: + super().__init__(rule, violation_msg, violation_range, fixes) + self.model = model + self.violation_type = violation_type diff --git a/sqlmesh/core/linter/helpers.py b/sqlmesh/core/linter/helpers.py new file mode 100644 index 0000000000..3c79f83a43 --- /dev/null +++ b/sqlmesh/core/linter/helpers.py @@ -0,0 +1,315 @@ +from pathlib import Path + +from sqlmesh.core.linter.rule import Range, Position +from sqlmesh.utils.pydantic import PydanticModel +from sqlglot import tokenize, TokenType, Token +import typing as t + + +class TokenPositionDetails(PydanticModel): + """ + Details about a token's position in the source code in the structure provided by SQLGlot. + + Attributes: + line (int): The line that the token ends on. + col (int): The column that the token ends on. + start (int): The start index of the token. + end (int): The ending index of the token. + """ + + line: int + col: int + start: int + end: int + + @staticmethod + def from_meta(meta: t.Dict[str, int]) -> "TokenPositionDetails": + return TokenPositionDetails( + line=meta["line"], + col=meta["col"], + start=meta["start"], + end=meta["end"], + ) + + def to_range(self, read_file: t.Optional[t.List[str]]) -> Range: + """ + Convert a TokenPositionDetails object to a Range object. + + In the circumstances where the token's start and end positions are the same, + there is no need for a read_file parameter, as the range can be derived from the token's + line and column. This is an optimization to avoid unnecessary file reads and should + only be used when the token represents a single character or position in the file. + + If the token's start and end positions are different, the read_file parameter is required. + + :param read_file: List of lines from the file. 
Optional + :return: A Range object representing the token's position + """ + if self.start == self.end: + # If the start and end positions are the same, we can create a range directly + return Range( + start=Position(line=self.line - 1, character=self.col - 1), + end=Position(line=self.line - 1, character=self.col), + ) + + if read_file is None: + raise ValueError("read_file must be provided when start and end positions differ.") + + # Convert from 1-indexed to 0-indexed for line only + end_line_0 = self.line - 1 + end_col_0 = self.col + + # Find the start line and column by counting backwards from the end position + start_pos = self.start + end_pos = self.end + + # Initialize with the end position + start_line_0 = end_line_0 + start_col_0 = end_col_0 - (end_pos - start_pos + 1) + + # If start_col_0 is negative, we need to go back to previous lines + while start_col_0 < 0 and start_line_0 > 0: + start_line_0 -= 1 + start_col_0 += len(read_file[start_line_0]) + # Account for newline character + if start_col_0 >= 0: + break + start_col_0 += 1 # For the newline character + + # Ensure we don't have negative values + start_col_0 = max(0, start_col_0) + return Range( + start=Position(line=start_line_0, character=start_col_0), + end=Position(line=end_line_0, character=end_col_0), + ) + + +def read_range_from_string(content: str, text_range: Range) -> str: + lines = content.splitlines(keepends=False) + + # Ensure the range is within bounds + start_line = max(0, text_range.start.line) + end_line = min(len(lines), text_range.end.line + 1) + + if start_line >= end_line: + return "" + + # Extract the relevant portions of each line + result = [] + for i in range(start_line, end_line): + line = lines[i] + start_char = text_range.start.character if i == text_range.start.line else 0 + end_char = text_range.end.character if i == text_range.end.line else len(line) + result.append(line[start_char:end_char]) + + return "".join(result) + + +def read_range_from_file(file: Path, 
text_range: Range) -> str: + """ + Read the file and return the content within the specified range. + + Args: + file: Path to the file to read + text_range: The range of text to extract + + Returns: + The content within the specified range + """ + with file.open("r", encoding="utf-8") as f: + lines = f.readlines() + + return read_range_from_string("".join(lines), text_range) + + +def get_start_and_end_of_model_block( + tokens: t.List[Token], +) -> t.Optional[t.Tuple[int, int]]: + """ + Returns the start and end tokens of the MODEL block in an SQL file. + The MODEL block is defined as the first occurrence of the keyword "MODEL" followed by + an opening parenthesis and a closing parenthesis that matches the opening one. + """ + # 1) Find the MODEL token + try: + model_idx = next( + i + for i, tok in enumerate(tokens) + if tok.token_type is TokenType.VAR and tok.text.upper() == "MODEL" + ) + except StopIteration: + return None + + # 2) Find the opening parenthesis for the MODEL properties list + try: + lparen_idx = next( + i + for i in range(model_idx + 1, len(tokens)) + if tokens[i].token_type is TokenType.L_PAREN + ) + except StopIteration: + return None + + # 3) Find the matching closing parenthesis by looking for the first semicolon after + # the opening parenthesis and assuming the MODEL block ends there. 
+ try: + closing_semicolon = next( + i + for i in range(lparen_idx + 1, len(tokens)) + if tokens[i].token_type is TokenType.SEMICOLON + ) + # If we find a semicolon, we can assume the MODEL block ends there + rparen_idx = closing_semicolon - 1 + if tokens[rparen_idx].token_type is TokenType.R_PAREN: + return (lparen_idx, rparen_idx) + return None + except StopIteration: + return None + + +def get_range_of_model_block( + sql: str, + dialect: str, +) -> t.Optional[Range]: + """ + Get the range of the model block in an SQL file, + """ + tokens = tokenize(sql, dialect=dialect) + block = get_start_and_end_of_model_block(tokens) + if not block: + return None + (start_idx, end_idx) = block + start = tokens[start_idx - 1] + end = tokens[end_idx + 1] + start_position = TokenPositionDetails( + line=start.line, + col=start.col, + start=start.start, + end=start.end, + ) + end_position = TokenPositionDetails( + line=end.line, + col=end.col, + start=end.start, + end=end.end, + ) + splitlines = sql.splitlines() + return Range( + start=start_position.to_range(splitlines).start, + end=end_position.to_range(splitlines).end, + ) + + +def get_range_of_a_key_in_model_block( + sql: str, + dialect: str, + key: str, +) -> t.Optional[t.Tuple[Range, Range]]: + """ + Get the ranges of a specific key and its value in the MODEL block of an SQL file. + + Returns a tuple of (key_range, value_range) if found, otherwise None. 
+ """ + tokens = tokenize(sql, dialect=dialect) + if not tokens: + return None + + block = get_start_and_end_of_model_block(tokens) + if not block: + return None + (lparen_idx, rparen_idx) = block + + # 4) Scan within the MODEL property list for the key at top-level (depth == 1) + # Initialize depth to 1 since we're inside the first parentheses + depth = 1 + for i in range(lparen_idx + 1, rparen_idx): + tok = tokens[i] + tt = tok.token_type + + if tt is TokenType.L_PAREN: + depth += 1 + continue + if tt is TokenType.R_PAREN: + depth -= 1 + # If we somehow exit before rparen_idx, stop early + if depth <= 0: + break + continue + + if depth == 1 and tt is TokenType.VAR and tok.text.upper() == key.upper(): + # Validate key position: it should immediately follow '(' or ',' at top level + prev_idx = i - 1 + prev_tt = tokens[prev_idx].token_type if prev_idx >= 0 else None + if prev_tt not in (TokenType.L_PAREN, TokenType.COMMA): + continue + + # Key range + lines = sql.splitlines() + key_start = TokenPositionDetails( + line=tok.line, col=tok.col, start=tok.start, end=tok.end + ) + key_range = key_start.to_range(lines) + + value_start_idx = i + 1 + if value_start_idx >= rparen_idx: + return None + + # Walk to the end of the value expression: until top-level comma or closing paren + # Track internal nesting for (), [], {} + nested = 0 + j = value_start_idx + value_end_idx = value_start_idx + + def is_open(t: TokenType) -> bool: + return t in (TokenType.L_PAREN, TokenType.L_BRACE, TokenType.L_BRACKET) + + def is_close(t: TokenType) -> bool: + return t in (TokenType.R_PAREN, TokenType.R_BRACE, TokenType.R_BRACKET) + + while j < rparen_idx: + ttype = tokens[j].token_type + if is_open(ttype): + nested += 1 + elif is_close(ttype): + nested -= 1 + + # End of value: at top-level (nested == 0) encountering a comma or the end paren + if nested == 0 and ( + ttype is TokenType.COMMA or (ttype is TokenType.R_PAREN and depth == 1) + ): + # For comma, don't include it in the value range 
+ # For closing paren, include it only if it's part of the value structure + if ttype is TokenType.COMMA: + # Don't include the comma in the value range + break + else: + # Include the closing parenthesis in the value range + value_end_idx = j + break + + value_end_idx = j + j += 1 + + value_start_tok = tokens[value_start_idx] + value_end_tok = tokens[value_end_idx] + + value_start_pos = TokenPositionDetails( + line=value_start_tok.line, + col=value_start_tok.col, + start=value_start_tok.start, + end=value_start_tok.end, + ) + value_end_pos = TokenPositionDetails( + line=value_end_tok.line, + col=value_end_tok.col, + start=value_end_tok.start, + end=value_end_tok.end, + ) + value_range = Range( + start=value_start_pos.to_range(lines).start, + end=value_end_pos.to_range(lines).end, + ) + + return (key_range, value_range) + + return None diff --git a/sqlmesh/core/linter/rule.py b/sqlmesh/core/linter/rule.py new file mode 100644 index 0000000000..8dd1a2ebbd --- /dev/null +++ b/sqlmesh/core/linter/rule.py @@ -0,0 +1,151 @@ +from __future__ import annotations + +import abc +from dataclasses import dataclass, field +from pathlib import Path + +from sqlmesh.core.model import Model + +from typing import Type + +import typing as t + +from sqlmesh.utils.pydantic import PydanticModel + + +if t.TYPE_CHECKING: + from sqlmesh.core.context import GenericContext + + +class RuleLocation(PydanticModel): + """The location of a rule in a file.""" + + file_path: str + start_line: t.Optional[int] = None + + +@dataclass(frozen=True) +class Position: + """The position of a rule violation in a file, the position follows the LSP standard.""" + + line: int + character: int + + +@dataclass(frozen=True) +class Range: + """The range of a rule violation in a file. 
The range follows the LSP standard.""" + + start: Position + end: Position + + +@dataclass(frozen=True) +class TextEdit: + """A text edit to apply to a file.""" + + path: Path + range: Range + new_text: str + + +@dataclass(frozen=True) +class CreateFile: + """Create a new file with the provided text.""" + + path: Path + text: str + + +@dataclass(frozen=True) +class Fix: + """A fix that can be applied to resolve a rule violation.""" + + title: str + edits: t.List[TextEdit] = field(default_factory=list) + create_files: t.List[CreateFile] = field(default_factory=list) + + +class _Rule(abc.ABCMeta): + def __new__(cls: Type[_Rule], clsname: str, bases: t.Tuple, attrs: t.Dict) -> _Rule: + attrs["name"] = clsname.lower() + return super().__new__(cls, clsname, bases, attrs) + + +class Rule(abc.ABC, metaclass=_Rule): + """The base class for a rule.""" + + name = "rule" + + def __init__(self, context: GenericContext): + self.context = context + + @abc.abstractmethod + def check_model( + self, model: Model + ) -> t.Optional[t.Union[RuleViolation, t.List[RuleViolation]]]: + """The evaluation function that'll check for a violation of this rule.""" + + @property + def summary(self) -> str: + """A summary of what this rule checks for.""" + return self.__doc__ or "" + + def violation( + self, + violation_msg: t.Optional[str] = None, + violation_range: t.Optional[Range] = None, + fixes: t.Optional[t.List[Fix]] = None, + ) -> RuleViolation: + """Create a RuleViolation instance for this rule""" + return RuleViolation( + rule=self, + violation_msg=violation_msg or self.summary, + violation_range=violation_range, + fixes=fixes, + ) + + def get_definition_location(self) -> RuleLocation: + """Return the file path and position information for this rule. + + This method returns information about where this rule is defined, + which can be used in diagnostics to link to the rule's documentation. + + Returns: + A dictionary containing file path and position information. 
+ """ + import inspect + + # Get the file where the rule class is defined + file_path = inspect.getfile(self.__class__) + + try: + # Get the source code and line number + source_lines, start_line = inspect.getsourcelines(self.__class__) + return RuleLocation( + file_path=file_path, + start_line=start_line, + ) + except (IOError, TypeError): + # Fall back to just returning the file path if we can't get source lines + return RuleLocation(file_path=file_path) + + def __repr__(self) -> str: + return self.name + + +class RuleViolation: + def __init__( + self, + rule: Rule, + violation_msg: str, + violation_range: t.Optional[Range] = None, + fixes: t.Optional[t.List[Fix]] = None, + ) -> None: + self.rule = rule + self.violation_msg = violation_msg + self.violation_range = violation_range + self.fixes = fixes or [] + + def __repr__(self) -> str: + return f"{self.rule.name}: {self.violation_msg}" diff --git a/sqlmesh/core/linter/rules/__init__.py b/sqlmesh/core/linter/rules/__init__.py new file mode 100644 index 0000000000..43812479a5 --- /dev/null +++ b/sqlmesh/core/linter/rules/__init__.py @@ -0,0 +1 @@ +from sqlmesh.core.linter.rules.builtin import BUILTIN_RULES as BUILTIN_RULES diff --git a/sqlmesh/core/linter/rules/builtin.py b/sqlmesh/core/linter/rules/builtin.py new file mode 100644 index 0000000000..4547ac0528 --- /dev/null +++ b/sqlmesh/core/linter/rules/builtin.py @@ -0,0 +1,321 @@ +"""Contains all the standard rules included with SQLMesh""" + +from __future__ import annotations + +import typing as t + +from sqlglot.expressions import Star +from sqlglot.helper import subclasses + +from sqlmesh.core.constants import EXTERNAL_MODELS_YAML +from sqlmesh.core.dialect import normalize_model_name +from sqlmesh.core.linter.helpers import ( + TokenPositionDetails, + get_range_of_model_block, + read_range_from_string, +) +from sqlmesh.core.linter.rule import ( + Rule, + RuleViolation, + Range, + Fix, + TextEdit, + Position, + CreateFile, +) +from 
sqlmesh.core.linter.definition import RuleSet +from sqlmesh.core.model import Model, SqlModel, ExternalModel +from sqlmesh.utils.lineage import extract_references_from_query, ExternalModelReference + + +class NoSelectStar(Rule): + """Query should not contain SELECT * on its outer most projections, even if it can be expanded.""" + + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + # Only applies to SQL models, as other model types do not have a query. + if not isinstance(model, SqlModel): + return None + if model.query.is_star: + violation_range = self._get_range(model) + fixes = self._create_fixes(model, violation_range) + return self.violation(violation_range=violation_range, fixes=fixes) + return None + + def _get_range(self, model: SqlModel) -> t.Optional[Range]: + """Get the range of the violation if available.""" + try: + if len(model.query.expressions) == 1 and isinstance(model.query.expressions[0], Star): + return TokenPositionDetails.from_meta(model.query.expressions[0].meta).to_range( + None + ) + except Exception: + pass + + return None + + def _create_fixes( + self, model: SqlModel, violation_range: t.Optional[Range] + ) -> t.Optional[t.List[Fix]]: + """Create fixes for the SELECT * violation.""" + if not violation_range: + return None + columns = model.columns_to_types + if not columns: + return None + path = model._path + if path is None: + return None + new_text = ", ".join(columns.keys()) + return [ + Fix( + title="Replace SELECT * with explicit column list", + edits=[ + TextEdit( + path=path, + range=violation_range, + new_text=new_text, + ) + ], + ) + ] + + +class InvalidSelectStarExpansion(Rule): + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + deps = model.violated_rules_for_query.get(InvalidSelectStarExpansion) + if not deps: + return None + + violation_msg = ( + f"SELECT * cannot be expanded due to missing schema(s) for model(s): {deps}. 
" + "Run `sqlmesh create_external_models` and / or make sure that the model " + f"'{model.fqn}' can be rendered at parse time." + ) + + return self.violation(violation_msg) + + +class AmbiguousOrInvalidColumn(Rule): + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + sqlglot_err = model.violated_rules_for_query.get(AmbiguousOrInvalidColumn) + if not sqlglot_err: + return None + + violation_msg = ( + f"{sqlglot_err} for model '{model.fqn}', the column may not exist or is ambiguous." + ) + + return self.violation(violation_msg) + + +class NoMissingAudits(Rule): + """Model `audits` must be configured to test data quality.""" + + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + if model.audits or model.kind.is_symbolic: + return None + if model._path is None or not str(model._path).endswith(".sql"): + return self.violation() + + try: + with open(model._path, "r", encoding="utf-8") as file: + content = file.read() + + range = get_range_of_model_block(content, model.dialect) + if range: + return self.violation(violation_range=range) + return self.violation() + except Exception: + return self.violation() + + +class NoMissingUnitTest(Rule): + """All models must have a unit test found in the tests/ directory yaml files""" + + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + # External models cannot have unit tests + if isinstance(model, ExternalModel): + return None + + if model.name not in self.context.models_with_tests: + return self.violation( + violation_msg=f"Model {model.name} is missing unit test(s). Please add in the tests/ directory." 
+ ) + return None + + +class NoMissingExternalModels(Rule): + """All external models must be registered in the external_models.yaml file""" + + def check_model( + self, model: Model + ) -> t.Optional[t.Union[RuleViolation, t.List[RuleViolation]]]: + # Ignore external models themselves, because either they are registered, + # and if they are not, they will be caught as referenced in another model. + if isinstance(model, ExternalModel): + return None + + # Handle other models that may refer to the external models. + not_registered_external_models: t.Set[str] = set() + for depends_on_model in model.depends_on: + existing_model = self.context.get_model(depends_on_model) + if existing_model is None: + not_registered_external_models.add(depends_on_model) + + if not not_registered_external_models: + return None + + # If the model is anything other than a sql model that and has a path + # that ends with .sql, we cannot extract the references from the query. + path = model._path + if not isinstance(model, SqlModel) or not path or not str(path).endswith(".sql"): + return self._standard_error_message( + model_name=model.fqn, + external_models=not_registered_external_models, + ) + + with open(path, "r", encoding="utf-8") as file: + read_file = file.read() + split_read_file = read_file.splitlines() + + # If there are any unregistered external models, return a violation find + # the ranges for them. + references = extract_references_from_query( + query=model.query, + context=self.context, + document_path=path, + read_file=split_read_file, + depends_on=model.depends_on, + dialect=model.dialect, + ) + external_references = { + normalize_model_name( + table=read_range_from_string(read_file, ref.range), + default_catalog=model.default_catalog, + dialect=model.dialect, + ): ref + for ref in references + if isinstance(ref, ExternalModelReference) and ref.path is None + } + + # Ensure that depends_on and external references match. 
+ if not_registered_external_models != set(external_references.keys()): + return self._standard_error_message( + model_name=model.fqn, + external_models=not_registered_external_models, + ) + + # Return a violation for each unregistered external model with its range. + violations = [] + for ref_name, ref in external_references.items(): + if ref_name in not_registered_external_models: + fix = self.create_fix(ref_name) + violations.append( + RuleViolation( + rule=self, + violation_msg=f"Model '{model.fqn}' depends on unregistered external model '{ref_name}'. " + "Please register it in the external models file. This can be done by running 'sqlmesh create_external_models'.", + violation_range=ref.range, + fixes=[fix] if fix else [], + ) + ) + + if len(violations) < len(not_registered_external_models): + return self._standard_error_message( + model_name=model.fqn, + external_models=not_registered_external_models, + ) + + return violations + + def _standard_error_message( + self, model_name: str, external_models: t.Set[str] + ) -> RuleViolation: + return RuleViolation( + rule=self, + violation_msg=f"Model '{model_name}' depends on unregistered external models: " + f"{', '.join(m for m in external_models)}. " + "Please register them in the external models file. This can be done by running 'sqlmesh create_external_models'.", + ) + + def create_fix(self, model_name: str) -> t.Optional[Fix]: + """ + Add an external model to the external models file. + - If no external models file exists, it will create one with the model. + - If the model already exists, it will not add it again. 
+ """ + root = self.context.path + if not root: + return None + + external_models_path = root / EXTERNAL_MODELS_YAML + if not external_models_path.exists(): + return Fix( + title="Add external model file", + edits=[], + create_files=[ + CreateFile( + path=external_models_path, + text=f"- name: '{model_name}'\n", + ) + ], + ) + + # Figure out the position to insert the new external model at the end of the file, whether + # needs new line or not. + with open(external_models_path, "r", encoding="utf-8") as file: + lines = file.read() + + # If a file ends in newline, we can add the new model directly. + split_lines = lines.splitlines() + if lines.endswith("\n"): + new_text = f"- name: '{model_name}'\n" + position = Position(line=len(split_lines), character=0) + else: + new_text = f"\n- name: '{model_name}'\n" + position = Position( + line=len(split_lines) - 1, character=len(split_lines[-1]) if split_lines else 0 + ) + + return Fix( + title="Add external model", + edits=[ + TextEdit( + path=external_models_path, + range=Range(start=position, end=position), + new_text=new_text, + ) + ], + ) + + +class NoAmbiguousProjections(Rule): + """All projections in a model must have unique & inferrable names or explicit aliases.""" + + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + query = model.render_query() + if query is None: + return None + + name_counts: t.Dict[str, int] = {} + projection_list = query.selects + for expression in projection_list: + alias = expression.output_name + if alias == "*": + continue + + if not alias: + return self.violation( + f"Outer projection '{expression.sql(dialect=model.dialect)}' must have inferrable names or explicit aliases." 
+ ) + + name_counts[alias] = name_counts.get(alias, 0) + 1 + + for name, count in name_counts.items(): + if count > 1: + return self.violation(f"Found duplicate outer select name '{name}'") + + return None + + +BUILTIN_RULES = RuleSet(subclasses(__name__, Rule, exclude={Rule})) diff --git a/sqlmesh/core/loader.py b/sqlmesh/core/loader.py index 43803d7ff2..4b7b1bac02 100644 --- a/sqlmesh/core/loader.py +++ b/sqlmesh/core/loader.py @@ -1,156 +1,265 @@ from __future__ import annotations import abc +import glob +import itertools import linecache -import logging import os +import re import typing as t -from collections import defaultdict +from collections import Counter, defaultdict from dataclasses import dataclass from pathlib import Path +from pydantic import ValidationError +import concurrent.futures -from sqlglot.errors import SchemaError, SqlglotError -from sqlglot.schema import MappingSchema +from sqlglot.errors import SqlglotError +from sqlglot import exp +from sqlglot.helper import subclasses from sqlmesh.core import constants as c -from sqlmesh.core.audit import Audit, load_multiple_audits +from sqlmesh.core.audit import Audit, ModelAudit, StandaloneAudit, load_multiple_audits +from sqlmesh.core.console import Console from sqlmesh.core.dialect import parse +from sqlmesh.core.environment import EnvironmentStatements +from sqlmesh.core.linter.rule import Rule +from sqlmesh.core.linter.definition import RuleSet from sqlmesh.core.macros import MacroRegistry, macro from sqlmesh.core.metric import Metric, MetricMeta, expand_metrics, load_metric_ddl from sqlmesh.core.model import ( Model, - ExternalModel, ModelCache, - OptimizedQueryCache, - SeedModel, create_external_model, - load_sql_based_model, + load_sql_based_models, ) from sqlmesh.core.model import model as model_registry -from sqlmesh.utils import UniqueKeyDict -from sqlmesh.utils.dag import DAG +from sqlmesh.core.model.common import make_python_env +from sqlmesh.core.signal import signal +from 
sqlmesh.core.test import ModelTestMetadata +from sqlmesh.utils import UniqueKeyDict, sys_path from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroExtractor from sqlmesh.utils.metaprogramming import import_python_file -from sqlmesh.utils.yaml import YAML +from sqlmesh.utils.pydantic import validation_error_message +from sqlmesh.utils.process import create_process_pool_executor +from sqlmesh.utils.yaml import YAML, load as yaml_load + if t.TYPE_CHECKING: - from sqlmesh.core.config import Config from sqlmesh.core.context import GenericContext -logger = logging.getLogger(__name__) +GATEWAY_PATTERN = re.compile(r"gateway:\s*([^\s]+)") + + +@dataclass +class LoadedProject: + macros: MacroRegistry + jinja_macros: JinjaMacroRegistry + models: UniqueKeyDict[str, Model] + standalone_audits: UniqueKeyDict[str, StandaloneAudit] + audits: UniqueKeyDict[str, ModelAudit] + metrics: UniqueKeyDict[str, Metric] + requirements: t.Dict[str, str] + excluded_requirements: t.Set[str] + environment_statements: t.List[EnvironmentStatements] + user_rules: RuleSet + model_test_metadata: t.List[ModelTestMetadata] + + +class CacheBase(abc.ABC): + @abc.abstractmethod + def get_or_load_models( + self, target_path: Path, loader: t.Callable[[], t.List[Model]] + ) -> t.List[Model]: + """Get or load all models from cache.""" + pass + + @abc.abstractmethod + def put(self, models: t.List[Model], path: Path) -> bool: + """Store models in the cache associated with the given path. + + Args: + models: List of models to cache + path: File path to associate with the cached models + + Returns: + True if the models were successfully cached, + False otherwise (empty list, not a list, unsupported model types) + """ + pass + + @abc.abstractmethod + def get(self, path: Path) -> t.List[Model]: + """Retrieve models from the cache for a given path. 
+ + Args: + path: File path to look up in the cache + + Returns: + List of cached models associated with the path, an empty list if no cache entry exists + """ + pass -# TODO: consider moving this to context -def update_model_schemas( - dag: DAG[str], - models: UniqueKeyDict[str, Model], - context_path: Path, +_defaults: t.Optional[t.Dict[str, t.Any]] = None +_cache: t.Optional[CacheBase] = None +_config_essentials: t.Optional[t.Dict[str, t.Any]] = None +_selected_gateway: t.Optional[str] = None + + +def _init_model_defaults( + config_essentials: t.Dict[str, t.Any], + selected_gateway: t.Optional[str], + model_loading_defaults: t.Optional[t.Dict[str, t.Any]] = None, + cache: t.Optional[CacheBase] = None, + console: t.Optional[Console] = None, ) -> None: - schema = MappingSchema(normalize=False) - optimized_query_cache: OptimizedQueryCache = OptimizedQueryCache(context_path / c.CACHE) + global _defaults, _cache, _config_essentials, _selected_gateway + _defaults = model_loading_defaults + _cache = cache + _config_essentials = config_essentials + _selected_gateway = selected_gateway - for name in dag.sorted: - model = models.get(name) + # Set the console passed from the parent process + if console is not None: + from sqlmesh.core.console import set_console - # External models don't exist in the context, so we need to skip them - if not model: - continue + set_console(console) - try: - model.update_schema(schema) - optimized_query_cache.with_optimized_query(model) - columns_to_types = model.columns_to_types - if columns_to_types is not None: - schema.add_table( - model.fqn, columns_to_types, dialect=model.dialect, normalize=False - ) - except SchemaError as e: - if "nesting level:" in str(e): - logger.error( - "SQLMesh requires all model names and references to have the same level of nesting." 
- ) - raise +def load_sql_models(path: Path) -> t.List[Model]: + assert _defaults + assert _cache + with open(path, "r", encoding="utf-8") as file: + expressions = parse(file.read(), default_dialect=_defaults["dialect"]) + models = load_sql_based_models(expressions, path=Path(path).absolute(), **_defaults) -@dataclass -class LoadedProject: - macros: MacroRegistry - jinja_macros: JinjaMacroRegistry - models: UniqueKeyDict[str, Model] - audits: UniqueKeyDict[str, Audit] - metrics: UniqueKeyDict[str, Metric] - dag: DAG[str] + return [] if _cache.put(models, path) else models + + +def get_variables(gateway_name: t.Optional[str] = None) -> t.Dict[str, t.Any]: + assert _config_essentials + + gateway_name = gateway_name or _selected_gateway + + try: + gateway = _config_essentials["gateways"].get(gateway_name) + except ConfigError: + from sqlmesh.core.console import get_console + + get_console().log_warning( + f"Gateway '{gateway_name}' not found in project '{_config_essentials['project']}'." + ) + gateway = None + + return { + **_config_essentials["variables"], + **(gateway.variables if gateway else {}), + c.GATEWAY: gateway_name, + } class Loader(abc.ABC): """Abstract base class to load macros and models for a context""" - def __init__(self) -> None: + def __init__(self, context: GenericContext, path: Path) -> None: + # This ensures pandas is imported before any model loading happens in the forked process + # to avoid macOS fork() safety issues, see https://stackoverflow.com/a/52230415. Without + # it, the following error was observerd in a macOS 15.5 system: + # + # "+[NSMutableString initialize] may have been in progress in another thread when fork() was called." 
+ import pandas as pd # noqa + + from sqlmesh.core.console import get_console + self._path_mtimes: t.Dict[Path, float] = {} - self._dag: DAG[str] = DAG() + self.context = context + self.config_path = path + self.config = self.context.configs[self.config_path] + self._variables_by_gateway: t.Dict[str, t.Dict[str, t.Any]] = {} + self._console = get_console() + + self.config_essentials = { + "project": self.config.project, + "variables": self.config.variables, + "gateways": self.config.gateways, + } + _init_model_defaults(self.config_essentials, self.context.selected_gateway) - def load(self, context: GenericContext, update_schemas: bool = True) -> LoadedProject: + def load(self) -> LoadedProject: """ Loads all macros and models in the context's path. - Args: - context: The context to load macros and models for. - update_schemas: Convert star projections to explicit columns. + Returns: + A loaded project object. """ - # python files are cached by the system - # need to manually clear here so we can reload macros - linecache.clearcache() + with sys_path(self.config_path): + # python files are cached by the system + # need to manually clear here so we can reload macros + linecache.clearcache() + self._path_mtimes.clear() - self._context = context - self._path_mtimes.clear() - self._dag = DAG() + self._load_materializations() + signals = self._load_signals() - self._load_materializations() + config_mtimes: t.Dict[Path, t.List[float]] = defaultdict(list) - config_mtimes: t.Dict[Path, t.List[float]] = defaultdict(list) - for context_path, config in self._context.configs.items(): - for config_file in context_path.glob("config.*"): + for config_file in self.config_path.glob("config.*"): self._track_file(config_file) - config_mtimes[context_path].append(self._path_mtimes[config_file]) + config_mtimes[self.config_path].append(self._path_mtimes[config_file]) - for config_file in c.SQLMESH_PATH.glob("config.*"): - self._track_file(config_file) - 
config_mtimes[c.SQLMESH_PATH].append(self._path_mtimes[config_file]) + for config_file in c.SQLMESH_PATH.glob("config.*"): + self._track_file(config_file) + config_mtimes[c.SQLMESH_PATH].append(self._path_mtimes[config_file]) - self._config_mtimes = {path: max(mtimes) for path, mtimes in config_mtimes.items()} + self._config_mtimes = {path: max(mtimes) for path, mtimes in config_mtimes.items()} - macros, jinja_macros = self._load_scripts() - models = self._load_models( - macros, jinja_macros, context.gateway or context.config.default_gateway - ) + macros, jinja_macros = self._load_scripts() + audits: UniqueKeyDict[str, ModelAudit] = UniqueKeyDict("audits") + standalone_audits: UniqueKeyDict[str, StandaloneAudit] = UniqueKeyDict( + "standalone_audits" + ) + + for name, audit in self._load_audits(macros=macros, jinja_macros=jinja_macros).items(): + if isinstance(audit, ModelAudit): + audits[name] = audit + else: + standalone_audits[name] = audit + + models = self._load_models( + macros, + jinja_macros, + self.context.selected_gateway, + audits, + signals, + ) - for model in models.values(): - self._add_model_to_dag(model) + metrics = self._load_metrics() - if update_schemas: - update_model_schemas( - self._dag, - models, - self._context.path, + requirements, excluded_requirements = self._load_requirements() + + environment_statements = self._load_environment_statements(macros=macros) + + user_rules = self._load_linting_rules() + + model_test_metadata = self.load_model_tests() + + project = LoadedProject( + macros=macros, + jinja_macros=jinja_macros, + models=models, + audits=audits, + standalone_audits=standalone_audits, + metrics=expand_metrics(metrics), + requirements=requirements, + excluded_requirements=excluded_requirements, + environment_statements=environment_statements, + user_rules=user_rules, + model_test_metadata=model_test_metadata, ) - for model in models.values(): - # The model definition can be validated correctly only after the schema is set. 
- model.validate_definition() - - metrics = self._load_metrics() - - project = LoadedProject( - macros=macros, - jinja_macros=jinja_macros, - models=models, - audits=self._load_audits(macros=macros, jinja_macros=jinja_macros), - metrics=expand_metrics(metrics), - dag=self._dag, - ) - return project + return project def reload_needed(self) -> bool: """ @@ -162,7 +271,7 @@ def reload_needed(self) -> bool: """ return any( not path.exists() or path.stat().st_mtime > initial_mtime - for path, initial_mtime in self._path_mtimes.items() + for path, initial_mtime in self._path_mtimes.copy().items() ) @abc.abstractmethod @@ -171,7 +280,12 @@ def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]: @abc.abstractmethod def _load_models( - self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry, gateway: t.Optional[str] + self, + macros: MacroRegistry, + jinja_macros: JinjaMacroRegistry, + gateway: t.Optional[str], + audits: UniqueKeyDict[str, ModelAudit], + signals: UniqueKeyDict[str, signal], ) -> UniqueKeyDict[str, Model]: """Loads all models.""" @@ -181,62 +295,188 @@ def _load_audits( ) -> UniqueKeyDict[str, Audit]: """Loads all audits.""" - def _load_materializations(self) -> None: + def _load_environment_statements(self, macros: MacroRegistry) -> t.List[EnvironmentStatements]: + """Loads environment statements.""" + return [] + + def load_materializations(self) -> None: """Loads custom materializations.""" + def _load_materializations(self) -> None: + pass + + def _load_signals(self) -> UniqueKeyDict[str, signal]: + return UniqueKeyDict("signals") + def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]: return UniqueKeyDict("metrics") - def _load_external_models(self, gateway: t.Optional[str] = None) -> UniqueKeyDict[str, Model]: + def _load_external_models( + self, + audits: UniqueKeyDict[str, ModelAudit], + cache: CacheBase, + gateway: t.Optional[str] = None, + ) -> UniqueKeyDict[str, Model]: models: UniqueKeyDict[str, Model] = 
UniqueKeyDict("models") - for context_path, config in self._context.configs.items(): - external_models_yaml = Path(context_path / c.EXTERNAL_MODELS_YAML) - deprecated_yaml = Path(context_path / c.EXTERNAL_MODELS_DEPRECATED_YAML) - external_models_path = context_path / c.EXTERNAL_MODELS - - paths_to_load = [] - if external_models_yaml.exists(): - paths_to_load.append(external_models_yaml) - elif deprecated_yaml.exists(): - paths_to_load.append(deprecated_yaml) + external_models_yaml = Path(self.config_path / c.EXTERNAL_MODELS_YAML) + deprecated_yaml = Path(self.config_path / c.EXTERNAL_MODELS_DEPRECATED_YAML) + external_models_path = self.config_path / c.EXTERNAL_MODELS - if external_models_path.exists() and external_models_path.is_dir(): - paths_to_load.extend(external_models_path.glob("*.yaml")) + paths_to_load = [] + if external_models_yaml.exists(): + paths_to_load.append(external_models_yaml) + elif deprecated_yaml.exists(): + paths_to_load.append(deprecated_yaml) - for path in paths_to_load: - self._track_file(path) + if external_models_path.exists() and external_models_path.is_dir(): + paths_to_load.extend(self._glob_paths(external_models_path, extension=".yaml")) + def _load(path: Path) -> t.List[Model]: + try: with open(path, "r", encoding="utf-8") as file: - external_models: t.List[ExternalModel] = [] - for row in YAML().load(file.read()): - model = create_external_model( - **row, - dialect=config.model_defaults.dialect, - defaults=config.model_defaults.dict(), + yaml = YAML().load(file) + # Allow empty YAML files to return an empty list + if yaml is None: + return [] + return [ + create_external_model( + defaults=self.config.model_defaults.dict(), path=path, - project=config.project, - default_catalog=self._context.default_catalog, + project=self.config.project, + audit_definitions=audits, + **{ + "dialect": self.config.model_defaults.dialect, + "default_catalog": self.context.default_catalog, + **row, + }, ) - external_models.append(model) - - # external 
models with no explicit gateway defined form the base set - for model in (e for e in external_models if e.gateway is None): - models[model.fqn] = model + for row in yaml + ] + except Exception as ex: + raise ConfigError(self._failed_to_load_model_error(path, ex), path) + + for path in paths_to_load: + self._track_file(path) + + external_models = cache.get_or_load_models(path, lambda: _load(path)) + + # external models with no explicit gateway defined form the base set + for model in external_models: + if model.gateway is None: + if model.fqn in models: + raise ConfigError( + self._failed_to_load_model_error( + path, f"Duplicate external model name: '{model.name}'." + ), + path, + ) + models[model.fqn] = model - # however, if there is a gateway defined, gateway-specific models take precedence - if gateway: - for model in (e for e in external_models if e.gateway == gateway): - models.update({model.fqn: model}) + # however, if there is a gateway defined, gateway-specific models take precedence + if gateway: + gateway = gateway.lower() + for model in external_models: + if model.gateway == gateway: + if model.fqn in models and models[model.fqn].gateway == gateway: + raise ConfigError( + self._failed_to_load_model_error( + path, f"Duplicate external model name: '{model.name}'." + ), + path, + ) + models.update({model.fqn: model}) return models - def _add_model_to_dag(self, model: Model) -> None: - self._dag.add(model.fqn, model.depends_on) + def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]: + """Loads Python dependencies from the lock file. + + Returns: + A tuple of requirements and excluded requirements. 
+ """ + requirements: t.Dict[str, str] = {} + excluded_requirements: t.Set[str] = set() + + requirements_path = self.config_path / c.REQUIREMENTS + if requirements_path.is_file(): + with open(requirements_path, "r", encoding="utf-8") as file: + for line in file: + line = line.strip() + if line.startswith("^"): + excluded_requirements.add(line[1:]) + continue + + args = [k.strip() for k in line.split("==")] + if len(args) != 2: + raise ConfigError( + f"Invalid lock file entry '{line.strip()}'. Only 'dep==ver' is supported", + requirements_path, + ) + dep, ver = args + other_ver = requirements.get(dep, ver) + if ver != other_ver: + raise ConfigError( + f"Conflicting requirement {dep}: {ver} != {other_ver}. Fix your {c.REQUIREMENTS} file.", + requirements_path, + ) + requirements[dep] = ver + + return requirements, excluded_requirements + + def _load_linting_rules(self) -> RuleSet: + """Loads user linting rules""" + return RuleSet() + + def load_model_tests(self) -> t.List[ModelTestMetadata]: + """Loads YAML-based model tests""" + return [] + + def _glob_paths( + self, + path: Path, + ignore_patterns: t.Optional[t.List[str]] = None, + extension: t.Optional[str] = None, + ) -> t.Generator[Path, None, None]: + """ + Globs the provided path for the file extension but also removes any filepaths that match an ignore + pattern either set in constants or provided in config + + Args: + path: The filepath to glob + ignore_patterns: A list of patterns for glob to ignore + extension: The extension to check for in that path (checks recursively in zero or more subdirectories) + + Returns: + Matched paths that are not ignored + """ + ignore_patterns = ignore_patterns or [] + extension = extension or "" + + # We try to match both ignore_pattern itself and every file returned by glob, + # so that we will always ignore file names that do not appear in the latter. 
+ ignored_filepaths = set(ignore_patterns) | { + ignored_path + for ignore_pattern in ignore_patterns + for ignored_path in glob.glob(str(self.config_path / ignore_pattern), recursive=True) + } + for filepath in path.glob(f"**/*{extension}"): + if any(filepath.match(ignored_filepath) for ignored_filepath in ignored_filepaths): + continue + + yield filepath def _track_file(self, path: Path) -> None: """Project file to track for modifications""" self._path_mtimes[path] = path.stat().st_mtime + def _failed_to_load_model_error(self, path: Path, error: t.Union[str, Exception]) -> str: + base_message = f"Failed to load model from file '{path}':" + if isinstance(error, ValidationError): + return validation_error_message(error, base_message) + # indent all lines of error message + error_message = str(error).replace("\n", "\n ") + return f"{base_message}\n\n {error_message}" + class SqlMeshLoader(Loader): """Loads macros and models for a context using the SQLMesh file formats""" @@ -250,18 +490,12 @@ def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]: macros_max_mtime: t.Optional[float] = None - for context_path, config in self._context.configs.items(): - for path in self._glob_paths(context_path / c.MACROS, config=config, extension=".py"): - if import_python_file(path, context_path): - self._track_file(path) - macro_file_mtime = self._path_mtimes[path] - macros_max_mtime = ( - max(macros_max_mtime, macro_file_mtime) - if macros_max_mtime - else macro_file_mtime - ) - - for path in self._glob_paths(context_path / c.MACROS, config=config, extension=".sql"): + for path in self._glob_paths( + self.config_path / c.MACROS, + ignore_patterns=self.config.ignore_patterns, + extension=".py", + ): + if import_python_file(path, self.config_path): self._track_file(path) macro_file_mtime = self._path_mtimes[path] macros_max_mtime = ( @@ -269,8 +503,21 @@ def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]: if macros_max_mtime else macro_file_mtime ) 
- with open(path, "r", encoding="utf-8") as file: - jinja_macros.add_macros(extractor.extract(file.read())) + + for path in self._glob_paths( + self.config_path / c.MACROS, + ignore_patterns=self.config.ignore_patterns, + extension=".sql", + ): + self._track_file(path) + macro_file_mtime = self._path_mtimes[path] + macros_max_mtime = ( + max(macros_max_mtime, macro_file_mtime) if macros_max_mtime else macro_file_mtime + ) + with open(path, "r", encoding="utf-8") as file: + jinja_macros.add_macros( + extractor.extract(file.read(), dialect=self.config.model_defaults.dialect) + ) self._macros_max_mtime = macros_max_mtime @@ -280,217 +527,405 @@ def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]: return macros, jinja_macros def _load_models( - self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry, gateway: t.Optional[str] + self, + macros: MacroRegistry, + jinja_macros: JinjaMacroRegistry, + gateway: t.Optional[str], + audits: UniqueKeyDict[str, ModelAudit], + signals: UniqueKeyDict[str, signal], ) -> UniqueKeyDict[str, Model]: """ Loads all of the models within the model directory with their associated audits into a Dict and creates the dag """ - models = self._load_sql_models(macros, jinja_macros) - models.update(self._load_external_models(gateway)) - models.update(self._load_python_models()) + cache = SqlMeshLoader._Cache(self, self.config_path) - return models + sql_models = self._load_sql_models(macros, jinja_macros, audits, signals, cache, gateway) + external_models = self._load_external_models(audits, cache, gateway) + python_models = self._load_python_models(macros, jinja_macros, audits, signals) + + all_model_names = list(sql_models) + list(external_models) + list(python_models) + duplicates = [name for name, count in Counter(all_model_names).items() if count > 1] + if duplicates: + raise ConfigError(f"Duplicate model name(s) found: {', '.join(duplicates)}.") + + return UniqueKeyDict("models", **sql_models, **external_models, 
**python_models) def _load_sql_models( - self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry + self, + macros: MacroRegistry, + jinja_macros: JinjaMacroRegistry, + audits: UniqueKeyDict[str, ModelAudit], + signals: UniqueKeyDict[str, signal], + cache: CacheBase, + gateway: t.Optional[str], ) -> UniqueKeyDict[str, Model]: """Loads the sql models into a Dict""" models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") - for context_path, config in self._context.configs.items(): - cache = SqlMeshLoader._Cache(self, context_path) - variables = self._variables(config) - - for path in self._glob_paths(context_path / c.MODELS, config=config, extension=".sql"): - if not os.path.getsize(path): - continue - - self._track_file(path) - - def _load() -> Model: - with open(path, "r", encoding="utf-8") as file: - try: - expressions = parse( - file.read(), default_dialect=config.model_defaults.dialect - ) - except SqlglotError as ex: - raise ConfigError( - f"Failed to parse a model definition at '{path}': {ex}." 
- ) - - return load_sql_based_model( - expressions, - defaults=config.model_defaults.dict(), - macros=macros, - jinja_macros=jinja_macros, - path=Path(path).absolute(), - module_path=context_path, - dialect=config.model_defaults.dialect, - time_column_format=config.time_column_format, - physical_schema_override=config.physical_schema_override, - project=config.project, - default_catalog=self._context.default_catalog, - variables=variables, - infer_names=config.model_naming.infer_names, - ) - - model = cache.get_or_load_model(path, _load) + paths: t.Set[Path] = set() + cached_paths: UniqueKeyDict[Path, t.List[Model]] = UniqueKeyDict("cached_paths") + + for path in self._glob_paths( + self.config_path / c.MODELS, + ignore_patterns=self.config.ignore_patterns, + extension=".sql", + ): + if not os.path.getsize(path): + continue + + self._track_file(path) + paths.add(path) + if cached_models := cache.get(path): + cached_paths[path] = cached_models + + for path, cached_models in cached_paths.items(): + paths.remove(path) + for model in cached_models: if model.enabled: models[model.fqn] = model - if isinstance(model, SeedModel): - seed_path = model.seed_path - self._track_file(seed_path) + if paths: + model_loading_defaults = dict( + get_variables=get_variables, + defaults=self.config.model_defaults.dict(), + macros=macros, + jinja_macros=jinja_macros, + audit_definitions=audits, + module_path=self.config_path, + dialect=self.config.model_defaults.dialect, + time_column_format=self.config.time_column_format, + physical_schema_mapping=self.config.physical_schema_mapping, + project=self.config.project, + default_catalog=self.context.default_catalog, + infer_names=self.config.model_naming.infer_names, + signal_definitions=signals, + default_catalog_per_gateway=self.context.default_catalog_per_gateway, + virtual_environment_mode=self.config.virtual_environment_mode, + ) + + with create_process_pool_executor( + initializer=_init_model_defaults, + initargs=( + 
self.config_essentials, + gateway, + model_loading_defaults, + cache, + self._console, + ), + max_workers=c.MAX_FORK_WORKERS, + ) as pool: + futures_to_paths = {pool.submit(load_sql_models, path): path for path in paths} + for future in concurrent.futures.as_completed(futures_to_paths): + path = futures_to_paths[future] + try: + loaded = future.result() + for model in loaded or cache.get(path): + if model.fqn in models: + raise ConfigError( + self._failed_to_load_model_error( + path, f"Duplicate SQL model name: '{model.name}'." + ), + path, + ) + elif model.enabled: + model._path = path + models[model.fqn] = model + except Exception as ex: + raise ConfigError(self._failed_to_load_model_error(path, ex), path) return models - def _load_python_models(self) -> UniqueKeyDict[str, Model]: + def _load_python_models( + self, + macros: MacroRegistry, + jinja_macros: JinjaMacroRegistry, + audits: UniqueKeyDict[str, ModelAudit], + signals: UniqueKeyDict[str, signal], + ) -> UniqueKeyDict[str, Model]: """Loads the python models into a Dict""" models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") registry = model_registry.registry() registry.clear() registered: t.Set[str] = set() - for context_path, config in self._context.configs.items(): - variables = self._variables(config) - model_registry._dialect = config.model_defaults.dialect - try: - for path in self._glob_paths( - context_path / c.MODELS, config=config, extension=".py" - ): - if not os.path.getsize(path): - continue + model_registry._dialect = self.config.model_defaults.dialect + try: + for path in self._glob_paths( + self.config_path / c.MODELS, + ignore_patterns=self.config.ignore_patterns, + extension=".py", + ): + if not os.path.getsize(path): + continue - self._track_file(path) - import_python_file(path, context_path) + self._track_file(path) + try: + import_python_file(path, self.config_path) new = registry.keys() - registered registered |= new for name in new: - model = registry[name].model( + for model 
in registry[name].models( + get_variables, path=path, - module_path=context_path, - defaults=config.model_defaults.dict(), - dialect=config.model_defaults.dialect, - time_column_format=config.time_column_format, - physical_schema_override=config.physical_schema_override, - project=config.project, - default_catalog=self._context.default_catalog, - variables=variables, - infer_names=config.model_naming.infer_names, - ) - if model.enabled: - models[model.fqn] = model - finally: - model_registry._dialect = None + module_path=self.config_path, + defaults=self.config.model_defaults.dict(), + macros=macros, + jinja_macros=jinja_macros, + dialect=self.config.model_defaults.dialect, + time_column_format=self.config.time_column_format, + physical_schema_mapping=self.config.physical_schema_mapping, + project=self.config.project, + default_catalog=self.context.default_catalog, + infer_names=self.config.model_naming.infer_names, + audit_definitions=audits, + signal_definitions=signals, + default_catalog_per_gateway=self.context.default_catalog_per_gateway, + virtual_environment_mode=self.config.virtual_environment_mode, + ): + if model.enabled: + models[model.fqn] = model + except Exception as ex: + raise ConfigError(self._failed_to_load_model_error(path, ex), path) + + finally: + model_registry._dialect = None return models + def load_materializations(self) -> None: + with sys_path(self.config_path): + self._load_materializations() + def _load_materializations(self) -> None: - """Loads custom materializations.""" - for context_path, config in self._context.configs.items(): - for path in self._glob_paths( - context_path / c.MATERIALIZATIONS, config=config, extension=".py" - ): - if os.path.getsize(path): - import_python_file(path, context_path) + for path in self._glob_paths( + self.config_path / c.MATERIALIZATIONS, + ignore_patterns=self.config.ignore_patterns, + extension=".py", + ): + if os.path.getsize(path): + import_python_file(path, self.config_path) + + def 
_load_signals(self) -> UniqueKeyDict[str, signal]: + """Loads signals for the built-in scheduler.""" + + base_signals = signal.get_registry() + + signals_max_mtime: t.Optional[float] = None + + for path in self._glob_paths( + self.config_path / c.SIGNALS, + ignore_patterns=self.config.ignore_patterns, + extension=".py", + ): + if os.path.getsize(path): + self._track_file(path) + signal_file_mtime = self._path_mtimes[path] + signals_max_mtime = ( + max(signals_max_mtime, signal_file_mtime) + if signals_max_mtime + else signal_file_mtime + ) + import_python_file(path, self.config_path) + + self._signals_max_mtime = signals_max_mtime + + signals = signal.get_registry() + signal.set_registry(base_signals) + + return signals def _load_audits( self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry ) -> UniqueKeyDict[str, Audit]: """Loads all the model audits.""" audits_by_name: UniqueKeyDict[str, Audit] = UniqueKeyDict("audits") - for context_path, config in self._context.configs.items(): - variables = self._variables(config) - for path in self._glob_paths(context_path / c.AUDITS, config=config, extension=".sql"): - self._track_file(path) - with open(path, "r", encoding="utf-8") as file: - expressions = parse(file.read(), default_dialect=config.model_defaults.dialect) - audits = load_multiple_audits( - expressions=expressions, - path=path, - module_path=context_path, - macros=macros, - jinja_macros=jinja_macros, - dialect=config.model_defaults.dialect, - default_catalog=self._context.default_catalog, - variables=variables, - ) - for audit in audits: - audits_by_name[audit.name] = audit + audits_max_mtime: t.Optional[float] = None + variables = get_variables() + + for path in self._glob_paths( + self.config_path / c.AUDITS, + ignore_patterns=self.config.ignore_patterns, + extension=".sql", + ): + self._track_file(path) + with open(path, "r", encoding="utf-8") as file: + audits_file_mtime = self._path_mtimes[path] + audits_max_mtime = ( + max(audits_max_mtime, 
audits_file_mtime) + if audits_max_mtime + else audits_file_mtime + ) + expressions = parse(file.read(), default_dialect=self.config.model_defaults.dialect) + audits = load_multiple_audits( + expressions=expressions, + path=path, + module_path=self.config_path, + macros=macros, + jinja_macros=jinja_macros, + dialect=self.config.model_defaults.dialect, + default_catalog=self.context.default_catalog, + variables=variables, + project=self.config.project, + ) + for audit in audits: + audits_by_name[audit.name] = audit + + self._audits_max_mtime = audits_max_mtime + return audits_by_name def _load_metrics(self) -> UniqueKeyDict[str, MetricMeta]: """Loads all metrics.""" metrics: UniqueKeyDict[str, MetricMeta] = UniqueKeyDict("metrics") - for context_path, config in self._context.configs.items(): - for path in self._glob_paths(context_path / c.METRICS, config=config, extension=".sql"): - if not os.path.getsize(path): - continue + for path in self._glob_paths( + self.config_path / c.METRICS, + ignore_patterns=self.config.ignore_patterns, + extension=".sql", + ): + if not os.path.getsize(path): + continue + self._track_file(path) + + with open(path, "r", encoding="utf-8") as file: + dialect = self.config.model_defaults.dialect + try: + for expression in parse(file.read(), default_dialect=dialect): + metric = load_metric_ddl(expression, path=path, dialect=dialect) + metrics[metric.name] = metric + except SqlglotError as ex: + raise ConfigError( + f"Failed to parse metric definitions at '{path}': {ex}.", path + ) + + return metrics + + def _load_environment_statements(self, macros: MacroRegistry) -> t.List[EnvironmentStatements]: + """Loads environment statements.""" + + if self.config.before_all or self.config.after_all: + statements = { + "before_all": self.config.before_all or [], + "after_all": self.config.after_all or [], + } + dialect = self.config.model_defaults.dialect + python_env = make_python_env( + [ + exp.maybe_parse(stmt, dialect=dialect) + for stmts in 
statements.values() + for stmt in stmts + ], + module_path=self.config_path, + jinja_macro_references=None, + macros=macros, + variables=get_variables(), + path=self.config_path, + ) + + return [ + EnvironmentStatements( + **statements, python_env=python_env, project=self.config.project or None + ) + ] + return [] + + def _load_linting_rules(self) -> RuleSet: + user_rules: UniqueKeyDict[str, type[Rule]] = UniqueKeyDict("rules") + + for path in self._glob_paths( + self.config_path / c.LINTER, + ignore_patterns=self.config.ignore_patterns, + extension=".py", + ): + if os.path.getsize(path): self._track_file(path) + module = import_python_file(path, self.config_path) + module_rules = subclasses(module.__name__, Rule, exclude={Rule}) + for user_rule in module_rules: + user_rules[user_rule.name] = user_rule - with open(path, "r", encoding="utf-8") as file: - dialect = config.model_defaults.dialect - try: - for expression in parse(file.read(), default_dialect=dialect): - metric = load_metric_ddl(expression, path=path, dialect=dialect) - metrics[metric.name] = metric - except SqlglotError as ex: - raise ConfigError(f"Failed to parse metric definitions at '{path}': {ex}.") + return RuleSet(user_rules.values()) - return metrics + def _load_model_test_file(self, path: Path) -> dict[str, ModelTestMetadata]: + """Load a single model test file.""" + model_test_metadata = {} - def _glob_paths( - self, path: Path, config: Config, extension: str - ) -> t.Generator[Path, None, None]: - """ - Globs the provided path for the file extension but also removes any filepaths that match an ignore - pattern either set in constants or provided in config + with open(path, "r", encoding="utf-8") as file: + source = file.read() + # If the user has specified a quoted/escaped gateway (e.g. 
"gateway: 'ma\tin'"), we need to + # parse it as YAML to match the gateway name stored in the config + gateway_line = GATEWAY_PATTERN.search(source) + gateway = YAML().load(gateway_line.group(0))["gateway"] if gateway_line else None - Args: - path: The filepath to glob - extension: The extension to check for in that path (checks recursively in zero or more subdirectories) + contents = yaml_load(source, variables=get_variables(gateway)) - Returns: - Matched paths that are not ignored - """ - for filepath in path.glob(f"**/*{extension}"): - for ignore_pattern in config.ignore_patterns: - if filepath.match(ignore_pattern): - break - else: - yield filepath - - def _variables(self, config: Config) -> t.Dict[str, t.Any]: - gateway_name = self._context.gateway or self._context.config.default_gateway_name - try: - gateway = config.get_gateway(gateway_name) - except ConfigError: - logger.warning("Gateway '%s' not found in project '%s'", gateway_name, config.project) - gateway = None - return { - **config.variables, - **(gateway.variables if gateway else {}), - c.GATEWAY: gateway_name, - } + for test_name, value in contents.items(): + model_test_metadata[test_name] = ModelTestMetadata( + path=path, test_name=test_name, body=value + ) + + return model_test_metadata - class _Cache: - def __init__(self, loader: SqlMeshLoader, context_path: Path): + def load_model_tests(self) -> t.List[ModelTestMetadata]: + """Loads YAML-based model tests""" + test_meta_list: t.List[ModelTestMetadata] = [] + + search_path = Path(self.config_path) / c.TESTS + + for yaml_file in itertools.chain( + search_path.glob("**/test*.yaml"), + search_path.glob("**/test*.yml"), + ): + if any( + yaml_file.match(ignore_pattern) + for ignore_pattern in self.config.ignore_patterns or [] + ): + continue + + test_meta_list.extend(self._load_model_test_file(yaml_file).values()) + + return test_meta_list + + class _Cache(CacheBase): + def __init__(self, loader: SqlMeshLoader, config_path: Path): self._loader = 
loader - self._context_path = context_path - self._model_cache = ModelCache(self._context_path / c.CACHE) + self.config_path = config_path + self._model_cache = ModelCache(self._loader.context.cache_dir) - def get_or_load_model(self, target_path: Path, loader: t.Callable[[], Model]) -> Model: - model = self._model_cache.get_or_load( + def get_or_load_models( + self, target_path: Path, loader: t.Callable[[], t.List[Model]] + ) -> t.List[Model]: + models = self._model_cache.get_or_load( self._cache_entry_name(target_path), self._model_cache_entry_id(target_path), loader=loader, ) - model._path = target_path - return model + + for model in models: + model._path = target_path + + return models + + def put(self, models: t.List[Model], path: Path) -> bool: + return self._model_cache.put( + models, + self._cache_entry_name(path), + self._model_cache_entry_id(path), + ) + + def get(self, path: Path) -> t.List[Model]: + models = self._model_cache.get( + self._cache_entry_name(path), + self._model_cache_entry_id(path), + ) + + for model in models: + model._path = path + + return models def _cache_entry_name(self, target_path: Path) -> str: - return "__".join(target_path.relative_to(self._context_path).parts).replace( + return "__".join(target_path.relative_to(self.config_path).parts).replace( target_path.suffix, "" ) @@ -498,16 +933,22 @@ def _model_cache_entry_id(self, model_path: Path) -> str: mtimes = [ self._loader._path_mtimes[model_path], self._loader._macros_max_mtime, - self._loader._config_mtimes.get(self._context_path), + self._loader._signals_max_mtime, + self._loader._audits_max_mtime, + self._loader._config_mtimes.get(self.config_path), self._loader._config_mtimes.get(c.SQLMESH_PATH), ] return "__".join( [ str(max(m for m in mtimes if m is not None)), - self._loader._context.config.fingerprint, - # We need to check default catalog since the provided config could not change but the - # gateway we are using could change, therefore potentially changing the default 
catalog - # which would then invalidate the cached model definition. - self._loader._context.default_catalog or "", + self._loader.config.fingerprint, + # default catalog can change outside sqlmesh (e.g., DB user's + # default catalog), and it is retained in cached model's fully + # qualified name + self._loader.context.default_catalog or "", + # gateway is configurable, and it is retained in a cached + # model's python environment if the @gateway macro variable is + # used in the model + self._loader.context.gateway or self._loader.config.default_gateway_name, ] ) diff --git a/sqlmesh/core/macros.py b/sqlmesh/core/macros.py index dda8bc1f47..af7c344081 100644 --- a/sqlmesh/core/macros.py +++ b/sqlmesh/core/macros.py @@ -1,18 +1,17 @@ from __future__ import annotations import inspect -import logging import sys import types import typing as t from enum import Enum -from functools import reduce +from functools import lru_cache, reduce from itertools import chain from pathlib import Path from string import Template +from datetime import datetime, date import sqlglot -from jinja2 import Environment from sqlglot import Generator, exp, parse_one from sqlglot.executor.env import ENV from sqlglot.executor.python import Python @@ -38,14 +37,21 @@ columns_to_types_all_known, registry_decorator, ) +from sqlmesh.utils.date import DatetimeRanges, to_datetime, to_date from sqlmesh.utils.errors import MacroEvalError, SQLMeshError -from sqlmesh.utils.jinja import JinjaMacroRegistry, has_jinja -from sqlmesh.utils.metaprogramming import Executable, prepare_env, print_exception +from sqlmesh.utils.metaprogramming import ( + Executable, + SqlValue, + format_evaluated_code_exception, + prepare_env, +) if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import DialectType from sqlmesh.core._typing import TableName from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.snapshot import Snapshot + from sqlmesh.core.environment import EnvironmentNamingInfo if 
sys.version_info >= (3, 10): @@ -54,14 +60,16 @@ UNION_TYPES = (t.Union,) -logger = logging.getLogger(__name__) - - class RuntimeStage(Enum): LOADING = "loading" CREATING = "creating" EVALUATING = "evaluating" + PROMOTING = "promoting" + DEMOTING = "demoting" + AUDITING = "auditing" TESTING = "testing" + BEFORE_ALL = "before_all" + AFTER_ALL = "after_all" class MacroStrTemplate(Template): @@ -71,6 +79,25 @@ class MacroStrTemplate(Template): EXPRESSIONS_NAME_MAP = {} SQL = t.NewType("SQL", str) + +@lru_cache() +def get_supported_types() -> t.Dict[str, t.Any]: + from sqlmesh.core.context import ExecutionContext + + return { + "t": t, + "typing": t, + "List": t.List, + "Tuple": t.Tuple, + "Union": t.Union, + "DatetimeRanges": DatetimeRanges, + "exp": exp, + "SQL": SQL, + "MacroEvaluator": MacroEvaluator, + "ExecutionContext": ExecutionContext, + } + + for klass in sqlglot.Parser.EXPRESSION_PARSERS: name = klass if isinstance(klass, str) else klass.__name__ # type: ignore EXPRESSIONS_NAME_MAP[name.lower()] = name @@ -101,6 +128,17 @@ def _macro_str_replace(text: str) -> str: return f"self.template({text}, locals())" +class CaseInsensitiveMapping(t.Dict[str, t.Any]): + def __init__(self, data: t.Dict[str, t.Any]) -> None: + super().__init__(data) + + def __getitem__(self, key: str) -> t.Any: + return super().__getitem__(key.lower()) + + def get(self, key: str, default: t.Any = None, /) -> t.Any: + return super().get(key.lower(), default) + + class MacroDialect(Python): class Generator(Python.Generator): TRANSFORMS = { @@ -135,15 +173,17 @@ class MacroEvaluator: def __init__( self, - dialect: str = "", + dialect: DialectType = "", python_env: t.Optional[t.Dict[str, Executable]] = None, - jinja_env: t.Optional[Environment] = None, schema: t.Optional[MappingSchema] = None, runtime_stage: RuntimeStage = RuntimeStage.LOADING, + resolve_table: t.Optional[t.Callable[[str | exp.Table], str]] = None, resolve_tables: t.Optional[t.Callable[[exp.Expression], exp.Expression]] = 
None, snapshots: t.Optional[t.Dict[str, Snapshot]] = None, default_catalog: t.Optional[str] = None, - path: Path = Path(), + path: t.Optional[Path] = None, + environment_naming_info: t.Optional[EnvironmentNamingInfo] = None, + model_fqn: t.Optional[str] = None, ): self.dialect = dialect self.generator = MacroDialect().generator() @@ -158,14 +198,17 @@ def __init__( "MacroEvaluator": MacroEvaluator, } self.python_env = python_env or {} - self._jinja_env: t.Optional[Environment] = jinja_env self.macros = {normalize_macro_name(k): v.func for k, v in macro.get_registry().items()} + self.columns_to_types_called = False + self.default_catalog = default_catalog + self._schema = schema + self._resolve_table = resolve_table self._resolve_tables = resolve_tables - self.columns_to_types_called = False self._snapshots = snapshots if snapshots is not None else {} - self.default_catalog = default_catalog self._path = path + self._environment_naming_info = environment_naming_info + self._model_fqn = model_fqn prepare_env(self.python_env, self.env) for k, v in self.python_env.items(): @@ -174,7 +217,23 @@ def __init__( elif v.is_import and getattr(self.env.get(k), c.SQLMESH_MACRO, None): self.macros[normalize_macro_name(k)] = self.env[k] elif v.is_value: - self.locals[k] = self.env[k] + value = self.env[k] + if k in ( + c.SQLMESH_VARS, + c.SQLMESH_VARS_METADATA, + c.SQLMESH_BLUEPRINT_VARS, + c.SQLMESH_BLUEPRINT_VARS_METADATA, + ): + value = { + var_name: ( + self.parse_one(var_value.sql) + if isinstance(var_value, SqlValue) + else var_value + ) + for var_name, var_value in value.items() + } + + self.locals[k] = value def send( self, name: str, *args: t.Any, **kwargs: t.Any @@ -182,44 +241,17 @@ def send( func = self.macros.get(normalize_macro_name(name)) if not callable(func): - raise SQLMeshError(f"Macro '{name}' does not exist.") - - try: - # Bind the macro's actual parameters to its formal parameters - sig = inspect.signature(func) - bound = sig.bind(self, *args, **kwargs) - 
bound.apply_defaults() - except Exception as e: - print_exception(e, self.python_env) - raise MacroEvalError("Error trying to eval macro.") from e + raise MacroEvalError(f"Macro '{name}' does not exist.") try: - annotations = t.get_type_hints(func) - except NameError: # forward references aren't handled - annotations = {} - - # If the macro is annotated, we try coerce the actual parameters to the corresponding types - if annotations: - for arg, value in bound.arguments.items(): - typ = annotations.get(arg) - if not typ: - continue - - # Changes to bound.arguments will reflect in bound.args and bound.kwargs - # https://docs.python.org/3/library/inspect.html#inspect.BoundArguments.arguments - param = sig.parameters[arg] - if param.kind is inspect.Parameter.VAR_POSITIONAL: - bound.arguments[arg] = tuple(self._coerce(v, typ) for v in value) - elif param.kind is inspect.Parameter.VAR_KEYWORD: - bound.arguments[arg] = {k: self._coerce(v, typ) for k, v in value.items()} - else: - bound.arguments[arg] = self._coerce(value, typ) - - try: - return func(*bound.args, **bound.kwargs) + return call_macro( + func, self.dialect, self._path, provided_args=(self, *args), provided_kwargs=kwargs + ) # type: ignore except Exception as e: - print_exception(e, self.python_env) - raise MacroEvalError("Error trying to eval macro.") from e + raise MacroEvalError( + f"An error occurred during evaluation of '{name}'\n\n" + + format_evaluated_code_exception(e, self.python_env) + ) def transform( self, expression: exp.Expression @@ -233,35 +265,35 @@ def evaluate_macros( if isinstance(node, MacroVar): changed = True - variables = self.locals.get(c.SQLMESH_VARS, {}) - if node.name not in self.locals and node.name.lower() not in variables: + variables = self.variables + + # This makes all variables case-insensitive, e.g. @X is the same as @x. We do this + # for consistency, since `variables` and `blueprint_variables` are normalized. 
+ var_name = node.name.lower() + + if var_name not in self.locals and var_name not in variables: if not isinstance(node.parent, StagedFilePath): raise SQLMeshError(f"Macro variable '{node.name}' is undefined.") return node - value = self.locals.get(node.name, variables.get(node.name.lower())) + # Precedence order is locals (e.g. @DEF) > blueprint variables > config variables + value = self.locals.get(var_name, variables.get(var_name)) if isinstance(value, list): return exp.convert( tuple( self.transform(v) if isinstance(v, exp.Expression) else v for v in value ) ) + return exp.convert( self.transform(value) if isinstance(value, exp.Expression) else value ) if isinstance(node, exp.Identifier) and "@" in node.this: - text = self.template(node.this, self.locals) + text = self.template(node.this, {}) if node.this != text: changed = True - node.args["this"] = text - return node - if node.is_string: - text = node.this - if has_jinja(text): - changed = True - node.set("this", self.jinja_env.from_string(node.this).render()) - return node + return exp.to_identifier(text, quoted=node.quoted or None) if isinstance(node, MacroFunc): changed = True return self.evaluate(node) @@ -279,7 +311,7 @@ def evaluate_macros( self.parse_one(node.sql(dialect=self.dialect, copy=False)) for node in transformed ] - elif isinstance(transformed, exp.Expression): + if isinstance(transformed, exp.Expression): return self.parse_one(transformed.sql(dialect=self.dialect, copy=False)) return transformed @@ -294,27 +326,18 @@ def template(self, text: t.Any, local_variables: t.Dict[str, t.Any]) -> str: Returns: The rendered string. 
""" - mapping = {} - - variables = self.locals.get(c.SQLMESH_VARS, {}) - - for k, v in chain(variables.items(), self.locals.items(), local_variables.items()): - # try to convert all variables into sqlglot expressions - # because they're going to be converted into strings in sql - # we use bare Exception instead of ValueError because there's - # a recursive error with MagicMock. - # we don't convert strings because that would result in adding quotes - if not isinstance(v, str): - try: - v = exp.convert(v) - except Exception: - pass - - if isinstance(v, exp.Expression): - v = v.sql(dialect=self.dialect) - mapping[k] = v - - return MacroStrTemplate(str(text)).safe_substitute(mapping) + # We try to convert all variables into sqlglot expressions because they're going to be converted + # into strings; in sql we don't convert strings because that would result in adding quotes + base_mapping = { + k.lower(): convert_sql(v, self.dialect) + for k, v in chain(self.variables.items(), self.locals.items(), local_variables.items()) + if k.lower() + not in ( + "engine_adapter", + "snapshot", + ) + } + return MacroStrTemplate(str(text)).safe_substitute(CaseInsensitiveMapping(base_mapping)) def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | None: if isinstance(node, MacroDef): @@ -324,7 +347,9 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | args[0] if len(args) == 1 else exp.Tuple(expressions=list(args)) ) else: - self.locals[node.name] = self.transform(node.expression) + # Make variables defined through `@DEF` case-insensitive + self.locals[node.name.lower()] = self.transform(node.expression) + return node if isinstance(node, (MacroSQL, MacroStrReplace)): @@ -354,8 +379,37 @@ def evaluate(self, node: MacroFunc) -> exp.Expression | t.List[exp.Expression] | return None if isinstance(result, (tuple, list)): - return [self.parse_one(item) for item in result if item is not None] - return self.parse_one(result) + result = 
[self.parse_one(item) for item in result if item is not None] + + if ( + len(result) == 1 + and isinstance(result[0], (exp.Array, exp.Tuple)) + and node.find_ancestor(MacroFunc) + ): + """ + if: + - the output of evaluating this node is being passed as an argument to another macro function + - and that output is something that _norm_var_arg_lambda() will unpack into varargs + > (a list containing a single item of type exp.Tuple/exp.Array) + then we will get inconsistent behaviour depending on if this node emits a list with a single item vs multiple items. + + In the first case, emitting a list containing a single array item will cause that array to get unpacked and its *members* passed to the calling macro + In the second case, emitting a list containing multiple array items will cause each item to get passed as-is to the calling macro + + To prevent this inconsistency, we wrap this node output in an exp.Array so that _norm_var_arg_lambda() can "unpack" that into the + actual argument we want to pass to the parent macro function + + Note we only do this for evaluation results that get passed as an argument to another macro, because when the final + result is given to something like SELECT, we still want that to be unpacked into a list of items like: + - SELECT ARRAY(1), ARRAY(2) + rather than a single item like: + - SELECT ARRAY(ARRAY(1), ARRAY(2)) + """ + result = [exp.Array(expressions=result)] + else: + result = self.parse_one(result) + + return result def eval_expression(self, node: t.Any) -> t.Any: """Converts a SQLGlot expression into executable Python code and evals it. 
@@ -374,10 +428,10 @@ def eval_expression(self, node: t.Any) -> t.Any: code = self.generator.generate(node) return eval(code, self.env, self.locals) except Exception as e: - print_exception(e, self.python_env) raise MacroEvalError( - f"Error trying to eval macro.\n\nGenerated code: {code}\n\nOriginal sql: {node}" - ) from e + f"Error trying to eval macro.\n\nGenerated code: {code}\n\nOriginal sql: {node}\n\n" + + format_evaluated_code_exception(e, self.python_env) + ) def parse_one( self, sql: str | exp.Expression, into: t.Optional[exp.IntoType] = None, **opts: t.Any @@ -395,17 +449,13 @@ def parse_one( """ return sqlglot.maybe_parse(sql, dialect=self.dialect, into=into, **opts) - @property - def jinja_env(self) -> Environment: - if not self._jinja_env: - jinja_env_methods = {**self.locals, **self.env} - del jinja_env_methods["self"] - self._jinja_env = JinjaMacroRegistry().build_environment(**jinja_env_methods) - return self._jinja_env - def columns_to_types(self, model_name: TableName | exp.Column) -> t.Dict[str, exp.DataType]: """Returns the columns-to-types mapping corresponding to the specified model.""" - if self._schema is None or self._schema.empty: + + # We only return this dummy schema at load time, because if we don't actually know the + # target model's schema at creation/evaluation time, returning a dummy schema could lead + # to unintelligible errors when the query is executed + if (self._schema is None or self._schema.empty) and self.runtime_stage == "loading": self.columns_to_types_called = True return {"__schema_unavailable_at_load__": exp.DataType.build("unknown")} @@ -414,9 +464,16 @@ def columns_to_types(self, model_name: TableName | exp.Column) -> t.Dict[str, ex default_catalog=self.default_catalog, dialect=self.dialect, ) - columns_to_types = self._schema.find( - exp.to_table(normalized_model_name), ensure_data_types=True + model_name = exp.to_table(normalized_model_name) + + columns_to_types = ( + self._schema.find(model_name, 
ensure_data_types=True) if self._schema else None ) + if columns_to_types is None: + snapshot = self.get_snapshot(model_name) + if snapshot and snapshot.node.is_model: + columns_to_types = snapshot.node.columns_to_types # type: ignore + if columns_to_types is None: raise SQLMeshError(f"Schema for model '{model_name}' can't be statically determined.") @@ -432,6 +489,14 @@ def get_snapshot(self, model_name: TableName | exp.Column) -> t.Optional[Snapsho ) ) + def resolve_table(self, table: str | exp.Table) -> str: + """Gets the physical table name for a given model.""" + if not self._resolve_table: + raise SQLMeshError( + "Macro evaluator not properly initialized with resolve_table lambda." + ) + return self._resolve_table(table) + def resolve_tables(self, query: exp.Expression) -> exp.Expression: """Resolves queries with references to SQLMesh model names to their physical tables.""" if not self._resolve_tables: @@ -445,6 +510,20 @@ def runtime_stage(self) -> RuntimeStage: """Returns the current runtime stage of the macro evaluation.""" return self.locals["runtime_stage"] + @property + def this_model(self) -> str: + """Returns the resolved name of the surrounding model.""" + this_model = self.locals.get("this_model") + if not this_model: + raise SQLMeshError("Model name is not available in the macro evaluator.") + return this_model.sql(dialect=self.dialect, identify=True, comments=False) + + @property + def this_model_fqn(self) -> str: + if self._model_fqn is None: + raise SQLMeshError("Model name is not available in the macro evaluator.") + return self._model_fqn + @property def engine_adapter(self) -> EngineAdapter: engine_adapter = self.locals.get("engine_adapter") @@ -460,84 +539,58 @@ def gateway(self) -> t.Optional[str]: """Returns the gateway name.""" return self.var(c.GATEWAY) + @property + def snapshots(self) -> t.Dict[str, Snapshot]: + """Returns the snapshots if available.""" + return self._snapshots + + @property + def this_env(self) -> str: + """Returns 
the name of the current environment in before after all.""" + if "this_env" not in self.locals: + raise SQLMeshError("Environment name is only available in before_all and after_all") + return self.locals["this_env"] + + @property + def schemas(self) -> t.List[str]: + """Returns the schemas of the current environment in before after all macros.""" + if "schemas" not in self.locals: + raise SQLMeshError("Schemas are only available in before_all and after_all") + return self.locals["schemas"] + + @property + def views(self) -> t.List[str]: + """Returns the views of the current environment in before after all macros.""" + if "views" not in self.locals: + raise SQLMeshError("Views are only available in before_all and after_all") + return self.locals["views"] + def var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: """Returns the value of the specified variable, or the default value if it doesn't exist.""" - return (self.locals.get(c.SQLMESH_VARS) or {}).get(var_name.lower(), default) + return { + **(self.locals.get(c.SQLMESH_VARS) or {}), + **(self.locals.get(c.SQLMESH_VARS_METADATA) or {}), + }.get(var_name.lower(), default) + + def blueprint_var(self, var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: + """Returns the value of the specified blueprint variable, or the default value if it doesn't exist.""" + return { + **(self.locals.get(c.SQLMESH_BLUEPRINT_VARS) or {}), + **(self.locals.get(c.SQLMESH_BLUEPRINT_VARS_METADATA) or {}), + }.get(var_name.lower(), default) + + @property + def variables(self) -> t.Dict[str, t.Any]: + return { + **self.locals.get(c.SQLMESH_VARS, {}), + **self.locals.get(c.SQLMESH_VARS_METADATA, {}), + **self.locals.get(c.SQLMESH_BLUEPRINT_VARS, {}), + **self.locals.get(c.SQLMESH_BLUEPRINT_VARS_METADATA, {}), + } def _coerce(self, expr: exp.Expression, typ: t.Any, strict: bool = False) -> t.Any: """Coerces the given expression to the specified type on a best-effort basis.""" - base_err_msg = 
f"Failed to coerce expression '{expr}' to type '{typ}'." - try: - if typ is None or typ is t.Any: - return expr - base = t.get_origin(typ) or typ - - # We need to handle Union and TypeVars first since we cannot use isinstance with it - if base in UNION_TYPES: - for branch in t.get_args(typ): - try: - return self._coerce(expr, branch, True) - except Exception: - pass - raise SQLMeshError(base_err_msg) - if base is SQL and isinstance(expr, exp.Expression): - return expr.sql(self.dialect) - - if isinstance(expr, base): - return expr - if issubclass(base, exp.Expression): - d = Dialect.get_or_raise(self.dialect) - into = base if base in d.parser_class.EXPRESSION_PARSERS else None - if into is None: - if isinstance(expr, exp.Literal): - coerced = parse_one(expr.this) - else: - raise SQLMeshError( - f"{base_err_msg} Coercion to {base} requires a literal expression." - ) - else: - coerced = parse_one( - expr.this if isinstance(expr, exp.Literal) else expr.sql(), into=into - ) - if isinstance(coerced, base): - return coerced - raise SQLMeshError(base_err_msg) - - if base in (int, float, str) and isinstance(expr, exp.Literal): - return base(expr.this) - if base is str and isinstance(expr, exp.Column) and not expr.table: - return expr.name - if base is bool and isinstance(expr, exp.Boolean): - return expr.this - # if base is str and isinstance(expr, exp.Expression): - # return expr.sql(self.dialect) - if base is tuple and isinstance(expr, (exp.Tuple, exp.Array)): - generic = t.get_args(typ) - if not generic: - return tuple(expr.expressions) - if generic[-1] is ...: - return tuple(self._coerce(expr, generic[0]) for expr in expr.expressions) - elif len(generic) == len(expr.expressions): - return tuple( - self._coerce(expr, generic[i]) for i, expr in enumerate(expr.expressions) - ) - raise SQLMeshError(f"{base_err_msg} Expected {len(generic)} items.") - if base is list and isinstance(expr, (exp.Array, exp.Tuple)): - generic = t.get_args(typ) - if not generic: - return 
expr.expressions - return [self._coerce(expr, generic[0]) for expr in expr.expressions] - raise SQLMeshError(base_err_msg) - except Exception: - if strict: - raise - logger.error( - "Coercion of expression '%s' to type '%s' failed. Using non coerced expression at '%s'", - expr, - typ, - self._path, - ) - return expr + return _coerce(expr, typ, self.dialect, self._path, strict) class macro(registry_decorator): @@ -599,7 +652,7 @@ def substitute( ) -> exp.Expression | t.List[exp.Expression] | None: if isinstance(node, (exp.Identifier, exp.Var)): if not isinstance(node.parent, exp.Column): - name = node.name + name = node.name.lower() if name in args: return args[name].copy() if name in evaluator.locals: @@ -618,7 +671,13 @@ def substitute( if len(items) == 1: item = items[0] - expressions = item.expressions if isinstance(item, (exp.Array, exp.Tuple)) else item + expressions = ( + item.expressions + if isinstance(item, (exp.Array, exp.Tuple)) + else [item.this] + if isinstance(item, exp.Paren) + else item + ) else: expressions = items @@ -626,7 +685,7 @@ def substitute( return expressions, lambda args: func.this.transform( substitute, { - expression.name: arg + expression.name.lower(): arg for expression, arg in zip( func.expressions, args.expressions if isinstance(args, exp.Tuple) else [args] ) @@ -834,7 +893,9 @@ def star( if exclude and not isinstance(exclude, (exp.Array, exp.Tuple)): raise SQLMeshError(f"Invalid exclude '{exclude}'. Expected an array.") if except_ != exp.tuple_(): - logger.warning( + from sqlmesh.core.console import get_console + + get_console().log_warning( "The 'except_' argument in @STAR will soon be deprecated. Use 'exclude' instead." 
) if not isinstance(exclude, (exp.Array, exp.Tuple)): @@ -851,7 +912,9 @@ def star( for excluded in exclude.expressions or except_.expressions } quoted = quote_identifiers.this - table_identifier = alias.name or relation.name + table_identifier = normalize_identifiers( + alias if alias.name else relation, dialect=evaluator.dialect + ).name columns_to_types = { k: v for k, v in evaluator.columns_to_types(relation).items() if k not in excluded_names @@ -869,20 +932,29 @@ def star( exp.column(column, table=table_identifier, quoted=quoted).as_( f"{prefix.this}{column}{suffix.this}", quoted=quoted ) - for column, type_ in evaluator.columns_to_types(relation).items() + for column, type_ in columns_to_types.items() ] @macro() -def generate_surrogate_key(_: MacroEvaluator, *fields: exp.Expression) -> exp.Func: - """Generates a surrogate key for the given fields. +def generate_surrogate_key( + evaluator: MacroEvaluator, + *fields: exp.Expression, + hash_function: exp.Literal = exp.Literal.string("MD5"), +) -> exp.Func: + """Generates a surrogate key (string) for the given fields. 
Example: >>> from sqlglot import parse_one >>> from sqlmesh.core.macros import MacroEvaluator + >>> >>> sql = "SELECT @GENERATE_SURROGATE_KEY(a, b, c) FROM foo" - >>> MacroEvaluator().transform(parse_one(sql)).sql() - "SELECT MD5(CONCAT(COALESCE(CAST(a AS TEXT), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(b AS TEXT), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(c AS TEXT), '_sqlmesh_surrogate_key_null_'))) FROM foo" + >>> MacroEvaluator(dialect="bigquery").transform(parse_one(sql, dialect="bigquery")).sql("bigquery") + "SELECT TO_HEX(MD5(CONCAT(COALESCE(CAST(a AS STRING), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(b AS STRING), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(c AS STRING), '_sqlmesh_surrogate_key_null_')))) FROM foo" + >>> + >>> sql = "SELECT @GENERATE_SURROGATE_KEY(a, b, c, hash_function := 'SHA256') FROM foo" + >>> MacroEvaluator(dialect="bigquery").transform(parse_one(sql, dialect="bigquery")).sql("bigquery") + "SELECT SHA256(CONCAT(COALESCE(CAST(a AS STRING), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(b AS STRING), '_sqlmesh_surrogate_key_null_'), '|', COALESCE(CAST(c AS STRING), '_sqlmesh_surrogate_key_null_'))) FROM foo" """ string_fields: t.List[exp.Expression] = [] for i, field in enumerate(fields): @@ -895,7 +967,16 @@ def generate_surrogate_key(_: MacroEvaluator, *fields: exp.Expression) -> exp.Fu exp.Literal.string("_sqlmesh_surrogate_key_null_"), ) ) - return exp.func("MD5", exp.func("CONCAT", *string_fields)) + + func = exp.func( + hash_function.name, + exp.func("CONCAT", *string_fields), + dialect=evaluator.dialect, + ) + if isinstance(func, exp.MD5Digest): + func = exp.MD5(this=func.this) + + return func @macro() @@ -951,11 +1032,17 @@ def safe_div(_: MacroEvaluator, numerator: exp.Expression, denominator: exp.Expr @macro() def union( evaluator: MacroEvaluator, - type_: exp.Literal = exp.Literal.string("ALL"), - *tables: exp.Table, + *args: exp.Expression, ) -> exp.Query: """Returns a UNION of the 
given tables. Only choosing columns that have the same name and type. + Args: + evaluator: MacroEvaluator that invoked the macro + args: Variable arguments that can be: + - First argument can be a condition (exp.Condition) + - A union type ('ALL' or 'DISTINCT') as exp.Literal + - Tables (exp.Table) + Example: >>> from sqlglot import parse_one >>> from sqlglot.schema import MappingSchema @@ -963,11 +1050,36 @@ def union( >>> sql = "@UNION('distinct', foo, bar)" >>> MacroEvaluator(schema=MappingSchema({"foo": {"a": "int", "b": "string", "c": "string"}, "bar": {"c": "string", "a": "int", "b": "int"}})).transform(parse_one(sql)).sql() 'SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM foo UNION SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM bar' + >>> sql = "@UNION(True, 'distinct', foo, bar)" + >>> MacroEvaluator(schema=MappingSchema({"foo": {"a": "int", "b": "string", "c": "string"}, "bar": {"c": "string", "a": "int", "b": "int"}})).transform(parse_one(sql)).sql() + 'SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM foo UNION SELECT CAST(a AS INT) AS a, CAST(c AS TEXT) AS c FROM bar' """ + + if not args: + raise SQLMeshError("At least one table is required for the @UNION macro.") + + arg_idx = 0 + # Check for condition + condition = evaluator.eval_expression(args[arg_idx]) + if isinstance(condition, bool): + arg_idx += 1 + if arg_idx >= len(args): + raise SQLMeshError("Expected more arguments after the condition of the `@UNION` macro.") + + # Check for union type + type_ = exp.Literal.string("ALL") + if isinstance(args[arg_idx], exp.Literal): + type_ = args[arg_idx] # type: ignore + arg_idx += 1 kind = type_.name.upper() if kind not in ("ALL", "DISTINCT"): raise SQLMeshError(f"Invalid type '{type_}'. 
Expected 'ALL' or 'DISTINCT'.") + # Remaining args should be tables + tables = [ + exp.to_table(e.sql(evaluator.dialect), dialect=evaluator.dialect) for e in args[arg_idx:] + ] + columns = { column for column, _ in reduce( @@ -982,6 +1094,10 @@ def union( if column in columns ] + # Skip the union if condition is False + if condition == False: + return exp.select(*projections).from_(tables[0]) + return reduce( lambda a, b: a.union(b, distinct=kind == "DISTINCT"), # type: ignore [exp.select(*projections).from_(t) for t in tables], @@ -1033,17 +1149,17 @@ def haversine_distance( @macro() def pivot( evaluator: MacroEvaluator, - column: exp.Column, - values: t.Union[exp.Array, exp.Tuple], - alias: exp.Boolean = exp.true(), - agg: exp.Literal = exp.Literal.string("SUM"), - cmp: exp.Literal = exp.Literal.string("="), - prefix: exp.Literal = exp.Literal.string(""), - suffix: exp.Literal = exp.Literal.string(""), - then_value: exp.Literal = exp.Literal.number(1), - else_value: exp.Literal = exp.Literal.number(0), - quote: exp.Boolean = exp.true(), - distinct: exp.Boolean = exp.false(), + column: SQL, + values: t.List[exp.Expression], + alias: bool = True, + agg: exp.Expression = exp.Literal.string("SUM"), + cmp: exp.Expression = exp.Literal.string("="), + prefix: exp.Expression = exp.Literal.string(""), + suffix: exp.Expression = exp.Literal.string(""), + then_value: SQL = SQL("1"), + else_value: SQL = SQL("0"), + quote: bool = True, + distinct: bool = False, ) -> t.List[exp.Expression]: """Returns a list of projections as a result of pivoting the given column on the given values. 
@@ -1052,18 +1168,30 @@ def pivot( >>> from sqlmesh.core.macros import MacroEvaluator >>> sql = "SELECT date_day, @PIVOT(status, ['cancelled', 'completed']) FROM rides GROUP BY 1" >>> MacroEvaluator().transform(parse_one(sql)).sql() - 'SELECT date_day, SUM(CASE WHEN status = \\'cancelled\\' THEN 1 ELSE 0 END) AS "\\'cancelled\\'", SUM(CASE WHEN status = \\'completed\\' THEN 1 ELSE 0 END) AS "\\'completed\\'" FROM rides GROUP BY 1' + 'SELECT date_day, SUM(CASE WHEN status = \\'cancelled\\' THEN 1 ELSE 0 END) AS "cancelled", SUM(CASE WHEN status = \\'completed\\' THEN 1 ELSE 0 END) AS "completed" FROM rides GROUP BY 1' + >>> sql = "SELECT @PIVOT(a, ['v'], then_value := tv, suffix := '_sfx', quote := FALSE)" + >>> MacroEvaluator(dialect="bigquery").transform(parse_one(sql)).sql("bigquery") + "SELECT SUM(CASE WHEN a = 'v' THEN tv ELSE 0 END) AS v_sfx" """ aggregates: t.List[exp.Expression] = [] - for value in values.expressions: - proj = f"{agg.this}(" - if distinct.this: + for value in values: + proj = f"{agg.name}(" + if distinct: proj += "DISTINCT " - proj += f"CASE WHEN {column} {cmp.this} {value} THEN {then_value} ELSE {else_value} END) " + + proj += f"CASE WHEN {column} {cmp.name} {value.sql(evaluator.dialect)} THEN {then_value} ELSE {else_value} END) " node = evaluator.parse_one(proj) - if alias.this: - node = node.as_(f"{prefix.this}{value}{suffix.this}", quoted=quote.this, copy=False) + + if alias: + node = node.as_( + f"{prefix.name}{value.name}{suffix.name}", + quoted=quote, + copy=False, + dialect=evaluator.dialect, + ) + aggregates.append(node) + return aggregates @@ -1100,6 +1228,204 @@ def var( return exp.convert(evaluator.var(var_name.this, default)) +@macro("BLUEPRINT_VAR") +def blueprint_var( + evaluator: MacroEvaluator, var_name: exp.Expression, default: t.Optional[exp.Expression] = None +) -> exp.Expression: + """Returns the value of a blueprint variable or the default value if the variable is not set.""" + if not var_name.is_string: + raise 
SQLMeshError( + f"Invalid blueprint variable name '{var_name.sql()}'. Expected a string literal." + ) + + return exp.convert(evaluator.blueprint_var(var_name.this, default)) + + +@macro() +def deduplicate( + evaluator: MacroEvaluator, + relation: exp.Expression, + partition_by: t.List[exp.Expression], + order_by: t.List[str], +) -> exp.Query: + """Returns a QUERY to deduplicate rows within a table + + Args: + relation: table or CTE name to deduplicate + partition_by: column names, or expressions to use to identify a window of rows out of which to select one as the deduplicated row + order_by: A list of strings representing the ORDER BY clause + + Example: + >>> from sqlglot import parse_one + >>> from sqlglot.schema import MappingSchema + >>> from sqlmesh.core.macros import MacroEvaluator + >>> sql = "@deduplicate(demo.table, [user_id, cast(timestamp as date)], ['timestamp desc', 'status asc'])" + >>> MacroEvaluator().transform(parse_one(sql)).sql() + 'SELECT * FROM demo.table QUALIFY ROW_NUMBER() OVER (PARTITION BY user_id, CAST(timestamp AS DATE) ORDER BY timestamp DESC, status ASC) = 1' + """ + if not isinstance(partition_by, list): + raise SQLMeshError( + "partition_by must be a list of columns: [, cast( as )]" + ) + + if not isinstance(order_by, list): + raise SQLMeshError( + "order_by must be a list of strings, optional - nulls ordering: [' nulls ']" + ) + + partition_clause = exp.tuple_(*partition_by) + + order_expressions = [ + evaluator.transform(parse_one(order_item, into=exp.Ordered, dialect=evaluator.dialect)) + for order_item in order_by + ] + + if not order_expressions: + raise SQLMeshError( + "order_by must be a list of strings, optional - nulls ordering: [' nulls ']" + ) + + order_clause = exp.Order(expressions=order_expressions) + + window_function = exp.Window( + this=exp.RowNumber(), partition_by=partition_clause, order=order_clause + ) + + first_unique_row = window_function.eq(1) + + query = 
exp.select("*").from_(relation).qualify(first_unique_row) + + return query + + +@macro() +def date_spine( + evaluator: MacroEvaluator, + datepart: exp.Expression, + start_date: exp.Expression, + end_date: exp.Expression, +) -> exp.Select: + """Returns a query that produces a date spine with the given datepart, and range of start_date and end_date. Useful for joining as a date lookup table. + + Args: + datepart: The datepart to use for the date spine - day, week, month, quarter, year + start_date: The start date for the date spine in format YYYY-MM-DD + end_date: The end date for the date spine in format YYYY-MM-DD + + Example: + >>> from sqlglot import parse_one + >>> from sqlglot.schema import MappingSchema + >>> from sqlmesh.core.macros import MacroEvaluator + >>> sql = "@date_spine('week', '2022-01-20', '2024-12-16')" + >>> MacroEvaluator().transform(parse_one(sql)).sql() + "SELECT date_week FROM UNNEST(GENERATE_DATE_ARRAY(CAST(\'2022-01-20\' AS DATE), CAST(\'2024-12-16\' AS DATE), INTERVAL \'1\' WEEK)) AS _exploded(date_week)" + """ + datepart_name = datepart.name.lower() + if datepart_name not in ("day", "week", "month", "quarter", "year"): + raise SQLMeshError( + f"Invalid datepart '{datepart_name}'. Expected: 'day', 'week', 'month', 'quarter', or 'year'" + ) + + start_date_name = start_date.name + end_date_name = end_date.name + + try: + if start_date.is_string and end_date.is_string: + start_date_obj = datetime.strptime(start_date_name, "%Y-%m-%d").date() + end_date_obj = datetime.strptime(end_date_name, "%Y-%m-%d").date() + else: + start_date_obj = None + end_date_obj = None + except Exception as e: + raise SQLMeshError( + f"Invalid date format - start_date and end_date must be in format: YYYY-MM-DD. Error: {e}" + ) + + if start_date_obj and end_date_obj: + if start_date_obj > end_date_obj: + raise SQLMeshError( + f"Invalid date range - start_date '{start_date_name}' is after end_date '{end_date_name}'." 
+ ) + + start_date = exp.cast(start_date, "DATE") + end_date = exp.cast(end_date, "DATE") + + if datepart_name == "quarter" and evaluator.dialect in ( + "spark", + "spark2", + "databricks", + "postgres", + ): + date_interval = exp.Interval(this=exp.Literal.number(3), unit=exp.var("month")) + else: + date_interval = exp.Interval(this=exp.Literal.number(1), unit=exp.var(datepart_name)) + + generate_date_array = exp.func( + "GENERATE_DATE_ARRAY", + start_date, + end_date, + date_interval, + ) + + alias_name = f"date_{datepart_name}" + exploded = exp.alias_(exp.func("unnest", generate_date_array), "_exploded", table=[alias_name]) + + return exp.select(alias_name).from_(exploded) + + +@macro() +def resolve_template( + evaluator: MacroEvaluator, + template: exp.Literal, + mode: str = "literal", +) -> t.Union[exp.Literal, exp.Table]: + """ + Generates either a String literal or an exp.Table representing a physical table location, based on rendering the provided template String literal. + + Note: It relies on the @this_model variable being available in the evaluation context (@this_model resolves to an exp.Table object + representing the current physical table). + Therefore, the @resolve_template macro must be used at creation or evaluation time and not at load time. + + Args: + template: Template string literal. Can contain the following placeholders: + @{catalog_name} -> replaced with the catalog of the exp.Table returned from @this_model + @{schema_name} -> replaced with the schema of the exp.Table returned from @this_model + @{table_name} -> replaced with the name of the exp.Table returned from @this_model + mode: What to return. 
+ 'literal' -> return an exp.Literal string + 'table' -> return an exp.Table + + Example: + >>> from sqlglot import parse_one, exp + >>> from sqlmesh.core.macros import MacroEvaluator, RuntimeStage + >>> sql = "@resolve_template('s3://data-bucket/prod/@{catalog_name}/@{schema_name}/@{table_name}')" + >>> evaluator = MacroEvaluator(runtime_stage=RuntimeStage.CREATING) + >>> evaluator.locals.update({"this_model": exp.to_table("test_catalog.sqlmesh__test.test__test_model__2517971505")}) + >>> evaluator.transform(parse_one(sql)).sql() + "'s3://data-bucket/prod/test_catalog/sqlmesh__test/test__test_model__2517971505'" + """ + if "this_model" in evaluator.locals: + this_model = exp.to_table(evaluator.locals["this_model"], dialect=evaluator.dialect) + template_str: str = template.this + result = ( + template_str.replace("@{catalog_name}", this_model.catalog) + .replace("@{schema_name}", this_model.db) + .replace("@{table_name}", this_model.name) + ) + + if mode.lower() == "table": + return exp.to_table(result, dialect=evaluator.dialect) + return exp.Literal.string(result) + if evaluator.runtime_stage != RuntimeStage.LOADING.value: + # only error if we are CREATING, EVALUATING or TESTING and @this_model is not present; this could indicate a bug + # otherwise, for LOADING, it's a no-op + raise SQLMeshError( + "@this_model must be present in the macro evaluation context in order to use @resolve_template" + ) + + return template + + def normalize_macro_name(name: str) -> str: """Prefix macro name with @ and upcase""" return f"@{name.upper()}" @@ -1107,3 +1433,185 @@ def normalize_macro_name(name: str) -> str: for m in macro.get_registry().values(): setattr(m, c.SQLMESH_BUILTIN, True) + + +def call_macro( + func: t.Callable, + dialect: DialectType, + path: t.Optional[Path], + provided_args: t.Tuple[t.Any, ...], + provided_kwargs: t.Dict[str, t.Any], + **optional_kwargs: t.Any, +) -> t.Any: + # Bind the macro's actual parameters to its formal parameters + sig = 
inspect.signature(func) + + if optional_kwargs: + provided_kwargs = provided_kwargs.copy() + + for k, v in optional_kwargs.items(): + if k in sig.parameters: + provided_kwargs[k] = v + + bound = sig.bind(*provided_args, **provided_kwargs) + bound.apply_defaults() + + try: + annotations = t.get_type_hints(func, localns=get_supported_types()) + except (NameError, TypeError): # forward references aren't handled + annotations = {} + + # If the macro is annotated, we try coerce the actual parameters to the corresponding types + if annotations: + for arg, value in bound.arguments.items(): + typ = annotations.get(arg) + if not typ: + continue + + # Changes to bound.arguments will reflect in bound.args and bound.kwargs + # https://docs.python.org/3/library/inspect.html#inspect.BoundArguments.arguments + param = sig.parameters[arg] + if param.kind is inspect.Parameter.VAR_POSITIONAL: + bound.arguments[arg] = tuple(_coerce(v, typ, dialect, path) for v in value) + elif param.kind is inspect.Parameter.VAR_KEYWORD: + bound.arguments[arg] = {k: _coerce(v, typ, dialect, path) for k, v in value.items()} + else: + bound.arguments[arg] = _coerce(value, typ, dialect, path) + + return func(*bound.args, **bound.kwargs) + + +def _coerce( + expr: t.Any, + typ: t.Any, + dialect: DialectType, + path: t.Optional[Path] = None, + strict: bool = False, +) -> t.Any: + """Coerces the given expression to the specified type on a best-effort basis.""" + base_err_msg = f"Failed to coerce expression '{expr}' to type '{typ}'." 
+ try: + if typ is None or typ is t.Any or not isinstance(expr, exp.Expression): + return expr + base = t.get_origin(typ) or typ + + # We need to handle Union and TypeVars first since we cannot use isinstance with it + if base in UNION_TYPES: + for branch in t.get_args(typ): + try: + return _coerce(expr, branch, dialect, path, strict=True) + except Exception: + pass + raise SQLMeshError(base_err_msg) + if base is SQL and isinstance(expr, exp.Expression): + return expr.sql(dialect) + + if base is t.Literal: + if not isinstance(expr, (exp.Literal, exp.Boolean)): + raise SQLMeshError( + f"{base_err_msg} Coercion to {base} requires a literal expression." + ) + literal_type_args = t.get_args(typ) + try: + for literal_type_arg in literal_type_args: + expr_is_bool = isinstance(expr.this, bool) + literal_is_bool = isinstance(literal_type_arg, bool) + if (expr_is_bool and literal_is_bool and literal_type_arg == expr.this) or ( + not expr_is_bool + and not literal_is_bool + and str(literal_type_arg) == str(expr.this) + ): + return type(literal_type_arg)(expr.this) + except Exception: + raise SQLMeshError(base_err_msg) + raise SQLMeshError(base_err_msg) + + if isinstance(expr, base): + return expr + if issubclass(base, exp.Expression): + d = Dialect.get_or_raise(dialect) + into = base if base in d.parser_class.EXPRESSION_PARSERS else None + if into is None: + if isinstance(expr, exp.Literal): + coerced = parse_one(expr.this) + else: + raise SQLMeshError( + f"{base_err_msg} Coercion to {base} requires a literal expression." 
+ ) + else: + coerced = parse_one( + expr.this if isinstance(expr, exp.Literal) else expr.sql(), into=into + ) + if isinstance(coerced, base): + return coerced + raise SQLMeshError(base_err_msg) + + if base in (int, float, str) and isinstance(expr, exp.Literal): + return base(expr.this) + if base is str and isinstance(expr, exp.Column) and not expr.table: + return expr.name + if base is bool and isinstance(expr, exp.Boolean): + return expr.this + if base is datetime and isinstance(expr, exp.Literal): + return to_datetime(expr.this) + if base is date and isinstance(expr, exp.Literal): + return to_date(expr.this) + if base is tuple and isinstance(expr, (exp.Tuple, exp.Array)): + generic = t.get_args(typ) + if not generic: + return tuple(expr.expressions) + if generic[-1] is ...: + return tuple(_coerce(expr, generic[0], dialect, path) for expr in expr.expressions) + if len(generic) == len(expr.expressions): + return tuple( + _coerce(expr, generic[i], dialect, path) + for i, expr in enumerate(expr.expressions) + ) + raise SQLMeshError(f"{base_err_msg} Expected {len(generic)} items.") + if base is list and isinstance(expr, (exp.Array, exp.Tuple)): + generic = t.get_args(typ) + if not generic: + return expr.expressions + return [_coerce(expr, generic[0], dialect, path) for expr in expr.expressions] + raise SQLMeshError(base_err_msg) + except Exception: + if strict: + raise + + from sqlmesh.core.console import get_console + + get_console().log_error( + f"Coercion of expression '{expr}' to type '{typ}' failed. 
Using non coerced expression at '{path}'", + ) + return expr + + +def convert_sql(v: t.Any, dialect: DialectType) -> t.Any: + try: + return _cache_convert_sql(v, dialect, v.__class__) + # dicts aren't hashable but are convertable + except TypeError: + return _convert_sql(v, dialect) + + +def _convert_sql(v: t.Any, dialect: DialectType) -> t.Any: + if not isinstance(v, str): + try: + v = exp.convert(v) + # we use bare Exception instead of ValueError because there's + # a recursive error with MagicMock. + except Exception: + pass + + if isinstance(v, exp.Expression): + if (isinstance(v, exp.Column) and not v.table) or ( + isinstance(v, exp.Identifier) or v.is_string + ): + return v.name + v = v.sql(dialect=dialect) + return v + + +@lru_cache(maxsize=16384) +def _cache_convert_sql(v: t.Any, dialect: DialectType, t: type) -> t.Any: + return _convert_sql(v, dialect) diff --git a/sqlmesh/core/metric/definition.py b/sqlmesh/core/metric/definition.py index 10da247b86..dd11cfd38d 100644 --- a/sqlmesh/core/metric/definition.py +++ b/sqlmesh/core/metric/definition.py @@ -10,11 +10,7 @@ from sqlmesh.core.node import str_or_exp_to_str from sqlmesh.utils import UniqueKeyDict from sqlmesh.utils.errors import ConfigError -from sqlmesh.utils.pydantic import ( - PydanticModel, - field_validator, - field_validator_v1_args, -) +from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator MeasureAndDimTables = t.Tuple[str, t.Tuple[str, ...]] @@ -83,7 +79,7 @@ class MetricMeta(PydanticModel, frozen=True): @field_validator("name", mode="before") @classmethod def _name_validator(cls, v: t.Any) -> str: - return cls._string_validator(v).lower() + return (cls._string_validator(v) or "").lower() @field_validator("dialect", "owner", "description", mode="before") @classmethod @@ -91,14 +87,9 @@ def _string_validator(cls, v: t.Any) -> t.Optional[str]: return str_or_exp_to_str(v) @field_validator("expression", mode="before") - @field_validator_v1_args - def 
_validate_expression( - cls, - v: t.Any, - values: t.Dict[str, t.Any], - ) -> exp.Expression: + def _validate_expression(cls, v: t.Any, info: ValidationInfo) -> exp.Expression: if isinstance(v, str): - dialect = values.get("dialect") + dialect = info.data.get("dialect") return d.parse_one(v, dialect=dialect) if isinstance(v, exp.Expression): return v diff --git a/sqlmesh/core/metric/rewriter.py b/sqlmesh/core/metric/rewriter.py index 3519a77e68..bbdc6c6135 100644 --- a/sqlmesh/core/metric/rewriter.py +++ b/sqlmesh/core/metric/rewriter.py @@ -57,7 +57,7 @@ def _build_sources(self, projections: t.List[exp.Expression]) -> SourceAggsAndJo return sources def _expand(self, select: exp.Select) -> None: - base = select.args["from"].this.find(exp.Table) + base = select.args["from_"].this.find(exp.Table) base_alias = base.alias_or_name base_name = exp.table_name(base) diff --git a/sqlmesh/core/model/__init__.py b/sqlmesh/core/model/__init__.py index baac47aec8..c2ab47d9e7 100644 --- a/sqlmesh/core/model/__init__.py +++ b/sqlmesh/core/model/__init__.py @@ -4,6 +4,7 @@ ) from sqlmesh.core.model.decorator import model as model from sqlmesh.core.model.definition import ( + AuditResult as AuditResult, ExternalModel as ExternalModel, Model as Model, PythonModel as PythonModel, @@ -14,6 +15,7 @@ create_seed_model as create_seed_model, create_sql_model as create_sql_model, load_sql_based_model as load_sql_based_model, + load_sql_based_models as load_sql_based_models, ) from sqlmesh.core.model.kind import ( CustomKind as CustomKind, @@ -32,7 +34,9 @@ SeedKind as SeedKind, TimeColumn as TimeColumn, ViewKind as ViewKind, + ManagedKind as ManagedKind, model_kind_validator as model_kind_validator, ) from sqlmesh.core.model.meta import ModelMeta as ModelMeta +from sqlmesh.core.model.schema import update_model_schemas as update_model_schemas from sqlmesh.core.model.seed import Seed as Seed diff --git a/sqlmesh/core/model/cache.py b/sqlmesh/core/model/cache.py index 49b57efcec..774bfa402b 
100644 --- a/sqlmesh/core/model/cache.py +++ b/sqlmesh/core/model/cache.py @@ -1,20 +1,29 @@ from __future__ import annotations +import logging import typing as t from pathlib import Path from sqlglot import exp +from sqlglot.helper import seq_get from sqlglot.optimizer.simplify import gen +from sqlglot.schema import MappingSchema -from sqlmesh.core.model.definition import Model, SqlModel +from sqlmesh.core import constants as c +from sqlmesh.core.model.definition import ExternalModel, Model, SqlModel, _Model from sqlmesh.utils.cache import FileCache from sqlmesh.utils.hashing import crc32 -from sqlmesh.utils.pydantic import PydanticModel +from sqlmesh.utils.process import PoolExecutor, create_process_pool_executor +from dataclasses import dataclass -class SqlModelCacheEntry(PydanticModel): - model: SqlModel - full_depends_on: t.Set[str] +logger = logging.getLogger(__name__) + +if t.TYPE_CHECKING: + from sqlmesh.core.snapshot import SnapshotId + from sqlmesh.core.linter.rule import Rule + + T = t.TypeVar("T") class ModelCache: @@ -26,41 +35,54 @@ class ModelCache: def __init__(self, path: Path): self.path = path - self._file_cache: FileCache[SqlModelCacheEntry] = FileCache( + self._file_cache: FileCache[t.List[Model]] = FileCache( path, - SqlModelCacheEntry, prefix="model_definition", ) - def get_or_load(self, name: str, entry_id: str = "", *, loader: t.Callable[[], Model]) -> Model: + def get_or_load( + self, name: str, entry_id: str = "", *, loader: t.Callable[[], t.List[Model]] + ) -> t.List[Model]: """Returns an existing cached model definition or loads and caches a new one. - Args: name: The name of the entry. entry_id: The unique entry identifier. Used for cache invalidation. loader: Used to load a new model definition when no cached instance was found. - Returns: The model definition. 
""" cache_entry = self._file_cache.get(name, entry_id) - if cache_entry: - model = cache_entry.model - model._full_depends_on = cache_entry.full_depends_on - return model + if isinstance(cache_entry, list) and isinstance(seq_get(cache_entry, 0), _Model): + return cache_entry + + models = loader() + if isinstance(models, list) and isinstance(seq_get(models, 0), (SqlModel, ExternalModel)): + # make sure we preload full_depends_on + for model in models: + model.full_depends_on - loaded_model = loader() - if isinstance(loaded_model, SqlModel): - new_entry = SqlModelCacheEntry( - model=loaded_model, full_depends_on=loaded_model.full_depends_on - ) - self._file_cache.put(name, entry_id, value=new_entry) + self._file_cache.put(name, entry_id, value=models) + return models - return loaded_model + def put(self, models: t.List[Model], name: str, entry_id: str = "") -> bool: + if models and isinstance(seq_get(models, 0), (SqlModel, ExternalModel)): + # make sure we preload full_depends_on + for model in models: + model.full_depends_on + self._file_cache.put(name, entry_id, value=models) + return True -class OptimizedQueryCacheEntry(PydanticModel): + return False + + def get(self, name: str, entry_id: str = "") -> t.List[Model]: + return self._file_cache.get(name, entry_id) or [] + + +@dataclass +class OptimizedQueryCacheEntry: optimized_rendered_query: t.Optional[exp.Expression] + renderer_violations: t.Optional[t.Dict[type[Rule], t.Any]] class OptimizedQueryCache: @@ -73,43 +95,129 @@ class OptimizedQueryCache: def __init__(self, path: Path): self.path = path self._file_cache: FileCache[OptimizedQueryCacheEntry] = FileCache( - path, OptimizedQueryCacheEntry, prefix="optimized_query" + path, prefix="optimized_query" ) - def with_optimized_query(self, model: Model) -> bool: + def with_optimized_query(self, model: Model, name: t.Optional[str] = None) -> bool: """Adds an optimized query to the model's in-memory cache. Args: model: The model to add the optimized query to. 
+ name: The cache entry name of the model. """ if not isinstance(model, SqlModel): return False - hash_data = _mapping_schema_hash_data(model.mapping_schema) - hash_data.append(gen(model.query)) - hash_data.append(str([(k, v) for k, v in model.sorted_python_env])) - hash_data.extend(model.jinja_macros.data_hash_values) - - name = f"{model.name}_{crc32(hash_data)}" + name = self._entry_name(model) if name is None else name cache_entry = self._file_cache.get(name) - if cache_entry: - if cache_entry.optimized_rendered_query: - model._query_renderer.update_cache( - cache_entry.optimized_rendered_query, optimized=True - ) - else: + try: # If the optimized rendered query is None, then there are likely adapter calls in the query # that prevent us from rendering it at load time. This means that we can safely set the # unoptimized cache to None as well to prevent attempts to render it downstream. - model._query_renderer.update_cache(None, optimized=False) - return True + optimized = cache_entry.optimized_rendered_query is not None + model._query_renderer.update_cache( + cache_entry.optimized_rendered_query, + cache_entry.renderer_violations, + optimized=optimized, + ) + return True + except Exception as ex: + logger.warning("Failed to load a cache entry '%s': %s", name, ex) + + self._put(name, model) + return False + + def put(self, model: Model) -> t.Optional[str]: + if not isinstance(model, SqlModel): + return None + name = self._entry_name(model) + + if self._file_cache.exists(name): + return name + + self._put(name, model) + return name + + def _put(self, name: str, model: SqlModel) -> None: optimized_query = model.render_query() - new_entry = OptimizedQueryCacheEntry(optimized_rendered_query=optimized_query) + + new_entry = OptimizedQueryCacheEntry( + optimized_rendered_query=optimized_query, + renderer_violations=model.violated_rules_for_query, + ) self._file_cache.put(name, value=new_entry) - return False + @staticmethod + def _entry_name(model: SqlModel) -> str: + 
hash_data = _mapping_schema_hash_data(model.mapping_schema) + hash_data.append(gen(model.query, comments=True)) + hash_data.append(str([gen(d) for d in model.macro_definitions])) + hash_data.append(str([(k, v) for k, v in model.sorted_python_env])) + hash_data.extend(model.jinja_macros.data_hash_values) + return f"{model.name}_{crc32(hash_data)}" + + +def optimized_query_cache_pool(optimized_query_cache: OptimizedQueryCache) -> PoolExecutor: + return create_process_pool_executor( + initializer=_init_optimized_query_cache, + initargs=(optimized_query_cache,), + max_workers=c.MAX_FORK_WORKERS, + ) + + +_optimized_query_cache: t.Optional[OptimizedQueryCache] = None + + +def _init_optimized_query_cache(optimized_query_cache: OptimizedQueryCache) -> None: + global _optimized_query_cache + _optimized_query_cache = optimized_query_cache + + +def load_optimized_query( + model_snapshot_id: t.Tuple[Model, SnapshotId], +) -> t.Tuple[SnapshotId, t.Optional[str]]: + assert _optimized_query_cache + model, snapshot_id = model_snapshot_id + + entry_name = None + + if isinstance(model, SqlModel): + try: + entry_name = _optimized_query_cache.put(model) + except: + # this can happen if there is a query rendering error. 
+ # for example, the model query references some python library or function that was available + # at the time the model was created but has since been removed locally + logger.exception(f"Failed to cache optimized query for model '{model.name}'") + + return snapshot_id, entry_name + + +def load_optimized_query_and_mapping( + model: Model, mapping: t.Dict +) -> t.Tuple[str, t.Optional[str], str, str, t.Dict]: + assert _optimized_query_cache + + schema = MappingSchema(normalize=False) + for parent, columns_to_types in mapping.items(): + schema.add_table(parent, columns_to_types, dialect=model.dialect) + model.update_schema(schema) + + if isinstance(model, SqlModel): + entry_name = _optimized_query_cache._entry_name(model) + _optimized_query_cache.with_optimized_query(model, entry_name) + else: + entry_name = None + + return ( + model.fqn, + entry_name, + model.data_hash, + model.metadata_hash, + model.mapping_schema, + ) def _mapping_schema_hash_data(schema: t.Dict[str, t.Any]) -> t.List[str]: diff --git a/sqlmesh/core/model/common.py b/sqlmesh/core/model/common.py index f3bcfd5b12..dc51b3379c 100644 --- a/sqlmesh/core/model/common.py +++ b/sqlmesh/core/model/common.py @@ -1,20 +1,476 @@ from __future__ import annotations +import ast import typing as t +from pathlib import Path +from astor import to_source +from difflib import get_close_matches from sqlglot import exp +from sqlglot.helper import ensure_list -from sqlmesh.core.dialect import normalize_model_name, parse_one +from sqlmesh.core import constants as c +from sqlmesh.core import dialect as d +from sqlmesh.core.macros import MacroRegistry, MacroStrTemplate from sqlmesh.utils import str_to_bool -from sqlmesh.utils.errors import ConfigError, SQLMeshError -from sqlmesh.utils.pydantic import field_validator, field_validator_v1_args +from sqlmesh.utils.errors import ConfigError, SQLMeshError, raise_config_error +from sqlmesh.utils.metaprogramming import ( + Executable, + SqlValue, + build_env, + prepare_env, + 
serialize_env, +) +from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator, get_dialect + +if t.TYPE_CHECKING: + from sqlglot.dialects.dialect import DialectType + from sqlmesh.utils import registry_decorator + from sqlmesh.utils.jinja import MacroReference + + MacroCallable = t.Union[Executable, registry_decorator] + + +def make_python_env( + expressions: t.Union[ + exp.Expression, + t.List[t.Union[exp.Expression, t.Tuple[exp.Expression, bool]]], + ], + jinja_macro_references: t.Optional[t.Set[MacroReference]], + module_path: Path, + macros: MacroRegistry, + variables: t.Optional[t.Dict[str, t.Any]] = None, + referenced_variables: t.Optional[t.Set[str]] = None, + path: t.Optional[Path] = None, + python_env: t.Optional[t.Dict[str, Executable]] = None, + strict_resolution: bool = True, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, + dialect: DialectType = None, +) -> t.Dict[str, Executable]: + python_env = {} if python_env is None else python_env + env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {} + + variables = variables or {} + blueprint_variables = blueprint_variables or {} + + used_macros: t.Dict[str, t.Tuple[MacroCallable, bool]] = {} + + # var -> True: var is metadata-only + # var -> False: var is not metadata-only + # var -> None: cannot determine whether var is metadata-only yet, need to walk macros first + used_variables: t.Dict[str, t.Optional[bool]] = dict.fromkeys( + referenced_variables or set(), False + ) + + # id(expr) -> true: expr appears under the AST of a metadata-only macro function + # id(expr) -> false: expr appears under the AST of a macro function whose metadata status we don't yet know + expr_under_metadata_macro_func: t.Dict[int, bool] = {} + + # For @m1(@m2(@x), @y), we'd get x -> m1 and y -> m1 + outermost_macro_func_ancestor_by_var: t.Dict[str, str] = {} + visited_macro_funcs: t.Set[int] = set() + + def _is_metadata_var( + name: str, expression: exp.Expression, 
appears_in_metadata_expression: bool + ) -> t.Optional[bool]: + is_metadata_so_far = used_variables.get(name, True) + if is_metadata_so_far is False: + # We've concluded this variable is definitely not metadata-only + return False + + appears_under_metadata_macro_func = expr_under_metadata_macro_func.get(id(expression)) + if is_metadata_so_far and ( + appears_in_metadata_expression or appears_under_metadata_macro_func + ): + # The variable appears in a metadata expression, e.g., audits (...), + # or in the AST of metadata-only macro call, e.g., @FOO(@x) + return True + + # The variable appears in the AST of a macro call, but we don't know if it's metadata-only + if appears_under_metadata_macro_func is False: + return None + + # The variable appears elsewhere, e.g., in the model's query: SELECT @x + return False + + def _is_metadata_macro(name: str, appears_in_metadata_expression: bool) -> bool: + if name in used_macros: + is_metadata_so_far = used_macros[name][1] + return is_metadata_so_far and appears_in_metadata_expression + + return appears_in_metadata_expression + + expressions = ensure_list(expressions) + for expression_metadata in expressions: + if isinstance(expression_metadata, tuple): + expression, is_metadata = expression_metadata + else: + expression, is_metadata = expression_metadata, False + + if isinstance(expression, d.Jinja): + continue + + for macro_func_or_var in expression.find_all(d.MacroFunc, d.MacroVar, exp.Identifier): + if macro_func_or_var.__class__ is d.MacroFunc: + name = macro_func_or_var.this.name.lower() + if name not in macros: + continue + + used_macros[name] = (macros[name], _is_metadata_macro(name, is_metadata)) + + if name in (c.VAR, c.BLUEPRINT_VAR): + args = macro_func_or_var.this.expressions + if len(args) < 1: + raise_config_error( + f"Macro {name.upper()} requires at least one argument", path + ) + + if not args[0].is_string: + raise_config_error( + f"The variable name must be a string literal, '{args[0].sql()}' was given 
instead", + path, + ) + + var_name = args[0].this.lower() + used_variables[var_name] = _is_metadata_var( + var_name, macro_func_or_var, is_metadata + ) + elif id(macro_func_or_var) not in visited_macro_funcs: + # We only care about the top-level macro function calls to determine the metadata + # status of the variables referenced in their ASTs. For example, in @m1(@m2(@x)), + # if m1 is metadata-only but m2 is not, we can still determine that @x only affects + # the metadata hash, since m2's result feeds into a metadata-only macro function. + # + # Generally, if the top-level call is known to be metadata-only or appear in a + # metadata expression, then we can avoid traversing nested macro function calls. + + var_refs, _expr_under_metadata_macro_func, _visited_macro_funcs = ( + _extract_macro_func_variable_references(macro_func_or_var, is_metadata) + ) + expr_under_metadata_macro_func.update(_expr_under_metadata_macro_func) + visited_macro_funcs.update(_visited_macro_funcs) + outermost_macro_func_ancestor_by_var |= {var_ref: name for var_ref in var_refs} + elif macro_func_or_var.__class__ is d.MacroVar: + var_name = macro_func_or_var.name.lower() + if var_name in macros: + used_macros[var_name] = ( + macros[var_name], + _is_metadata_macro(var_name, is_metadata), + ) + elif var_name in variables or var_name in blueprint_variables: + used_variables[var_name] = _is_metadata_var( + var_name, macro_func_or_var, is_metadata + ) + elif ( + isinstance(macro_func_or_var, (exp.Identifier, d.MacroStrReplace, d.MacroSQL)) + ) and "@" in macro_func_or_var.name: + for _, identifier, braced_identifier, _ in MacroStrTemplate.pattern.findall( + macro_func_or_var.name + ): + var_name = braced_identifier or identifier + if var_name in variables or var_name in blueprint_variables: + used_variables[var_name] = _is_metadata_var( + var_name, macro_func_or_var, is_metadata + ) + + for macro_ref in jinja_macro_references or set(): + if macro_ref.package is None and macro_ref.name in 
macros: + used_macros[macro_ref.name] = (macros[macro_ref.name], False) + + for name, (used_macro, is_metadata) in used_macros.items(): + if isinstance(used_macro, Executable): + python_env[name] = used_macro + elif not hasattr(used_macro, c.SQLMESH_BUILTIN) and name not in python_env: + build_env( + used_macro.func, + env=env, + name=name, + path=module_path, + is_metadata_obj=is_metadata, + ) + + python_env.update(serialize_env(env, path=module_path)) + return _add_variables_to_python_env( + python_env, + used_variables, + variables, + blueprint_variables=blueprint_variables, + dialect=dialect, + strict_resolution=strict_resolution, + outermost_macro_func_ancestor_by_var=outermost_macro_func_ancestor_by_var, + ) + + +def _extract_macro_func_variable_references( + macro_func: exp.Expression, + is_metadata: bool, +) -> t.Tuple[t.Set[str], t.Dict[int, bool], t.Set[int]]: + var_references = set() + visited_macro_funcs = set() + expr_under_metadata_macro_func = {} + + for n in macro_func.walk(): + if type(n) is d.MacroFunc: + visited_macro_funcs.add(id(n)) + + this = n.this + args = this.expressions + + if this.name.lower() in (c.VAR, c.BLUEPRINT_VAR) and args and args[0].is_string: + var_references.add(args[0].this.lower()) + expr_under_metadata_macro_func[id(n)] = is_metadata + elif isinstance(n, d.MacroVar): + var_references.add(n.name.lower()) + expr_under_metadata_macro_func[id(n)] = is_metadata + elif isinstance(n, (exp.Identifier, d.MacroStrReplace, d.MacroSQL)) and "@" in n.name: + var_references.update( + (braced_identifier or identifier).lower() + for _, identifier, braced_identifier, _ in MacroStrTemplate.pattern.findall(n.name) + ) + expr_under_metadata_macro_func[id(n)] = is_metadata + + return (var_references, expr_under_metadata_macro_func, visited_macro_funcs) + + +def _add_variables_to_python_env( + python_env: t.Dict[str, Executable], + used_variables: t.Dict[str, t.Optional[bool]], + variables: t.Optional[t.Dict[str, t.Any]], + strict_resolution: 
bool = True, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, + dialect: DialectType = None, + outermost_macro_func_ancestor_by_var: t.Optional[t.Dict[str, str]] = None, +) -> t.Dict[str, Executable]: + _, python_used_variables = parse_dependencies( + python_env, + None, + strict_resolution=strict_resolution, + variables=variables, + blueprint_variables=blueprint_variables, + ) + for var_name, is_metadata in python_used_variables.items(): + used_variables[var_name] = is_metadata and used_variables.get(var_name, True) + + # Variables are treated as metadata-only when all of their references either: + # - appear in metadata-only expressions, such as `audits (...)`, virtual statements, etc + # - appear in the ASTs or definitions of metadata-only macros + # + # See also: https://github.com/SQLMesh/sqlmesh/pull/4936#issuecomment-3136339936, + # specifically the "Terminology" and "Observations" section. + metadata_used_variables = { + var_name for var_name, is_metadata in used_variables.items() if is_metadata + } + for used_var, outermost_macro_func in (outermost_macro_func_ancestor_by_var or {}).items(): + used_var_is_metadata = used_variables.get(used_var) + if used_var_is_metadata is False: + continue + + # At this point we can decide whether a variable reference in a macro call's AST is + # metadata-only, because we've annotated the corresponding macro call in the python env. + if outermost_macro_func in python_env and python_env[outermost_macro_func].is_metadata: + metadata_used_variables.add(used_var) + + non_metadata_used_variables = set(used_variables) - metadata_used_variables + + if overlapping_variables := (non_metadata_used_variables & metadata_used_variables): + raise ConfigError( + f"Variables {', '.join(overlapping_variables)} are both metadata and non-metadata, " + "which is unexpected. Please file an issue at https://github.com/SQLMesh/sqlmesh/issues/new." 
+ ) + + metadata_variables = { + k: v for k, v in (variables or {}).items() if k in metadata_used_variables + } + variables = {k: v for k, v in (variables or {}).items() if k in non_metadata_used_variables} + + if variables: + python_env[c.SQLMESH_VARS] = Executable.value(variables, sort_root_dict=True) + if metadata_variables: + python_env[c.SQLMESH_VARS_METADATA] = Executable.value( + metadata_variables, sort_root_dict=True, is_metadata=True + ) + + if blueprint_variables: + metadata_blueprint_variables = { + k: SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v + for k, v in blueprint_variables.items() + if k in metadata_used_variables + } + blueprint_variables = { + k.lower(): SqlValue(sql=v.sql(dialect=dialect)) if isinstance(v, exp.Expression) else v + for k, v in blueprint_variables.items() + if k in non_metadata_used_variables + } + if blueprint_variables: + python_env[c.SQLMESH_BLUEPRINT_VARS] = Executable.value( + blueprint_variables, sort_root_dict=True + ) + if metadata_blueprint_variables: + python_env[c.SQLMESH_BLUEPRINT_VARS_METADATA] = Executable.value( + metadata_blueprint_variables, sort_root_dict=True, is_metadata=True + ) + + return python_env + + +def parse_dependencies( + python_env: t.Dict[str, Executable], + entrypoint: t.Optional[str], + strict_resolution: bool = True, + variables: t.Optional[t.Dict[str, t.Any]] = None, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, +) -> t.Tuple[t.Set[str], t.Dict[str, bool]]: + """ + Parses the source of a model function and finds upstream table dependencies + and referenced variables based on calls to context / evaluator. + + Args: + python_env: A dictionary of Python definitions. + entrypoint: The name of the function. + strict_resolution: If true, the arguments of `table` and `resolve_table` calls must + be resolvable at parse time, otherwise an exception will be raised. + variables: The variables available to the python environment. 
+ blueprint_variables: The blueprint variables available to the python environment. + + Returns: + A tuple containing the set of upstream table dependencies and a mapping of + the referenced variables associated with their metadata status. + """ + + class VariableResolutionContext: + """This enables calls like `resolve_table` to reference `var()` and `blueprint_var()`.""" + + @staticmethod + def var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: + return (variables or {}).get(var_name.lower(), default) + + @staticmethod + def blueprint_var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: + return (blueprint_variables or {}).get(var_name.lower(), default) + + env = prepare_env(python_env) + local_env = dict.fromkeys(("context", "evaluator"), VariableResolutionContext) + + depends_on = set() + used_variables: t.Dict[str, bool] = {} + + for executable in python_env.values(): + if not executable.is_definition: + continue + + is_metadata = executable.is_metadata + for node in ast.walk(ast.parse(executable.payload)): + next_variables = set() + + if isinstance(node, ast.Call): + func = node.func + if not isinstance(func, ast.Attribute) or not isinstance(func.value, ast.Name): + continue + + def get_first_arg(keyword_arg_name: str) -> t.Any: + if node.args: + first_arg: t.Optional[ast.expr] = node.args[0] + else: + first_arg = next( + ( + keyword.value + for keyword in node.keywords + if keyword.arg == keyword_arg_name + ), + None, + ) + + try: + expression = to_source(first_arg) + return eval(expression, env, local_env) + except Exception: + if strict_resolution: + raise ConfigError( + f"Error resolving dependencies for '{executable.path}'. " + f"Argument '{expression.strip()}' must be resolvable at parse time." 
+ ) + + if func.value.id == "context" and func.attr in ("table", "resolve_table"): + depends_on.add(get_first_arg("model_name")) + elif func.value.id in ("context", "evaluator") and func.attr in ( + c.VAR, + c.BLUEPRINT_VAR, + ): + next_variables.add(get_first_arg("var_name").lower()) + elif ( + isinstance(node, ast.Attribute) + and isinstance(node.value, ast.Name) + and node.value.id in ("context", "evaluator") + and node.attr == c.GATEWAY + ): + # Check whether the gateway attribute is referenced. + next_variables.add(c.GATEWAY) + elif isinstance(node, ast.FunctionDef) and node.name == entrypoint: + next_variables.update( + [ + arg.arg + for arg in [*node.args.args, *node.args.kwonlyargs] + if arg.arg != "context" + ] + ) + + for var_name in next_variables: + used_variables[var_name] = used_variables.get(var_name, True) and bool(is_metadata) + + return depends_on, used_variables + + +def validate_extra_and_required_fields( + klass: t.Type[PydanticModel], + provided_fields: t.Set[str], + entity_name: str, + path: t.Optional[Path] = None, +) -> None: + missing_required_fields = klass.missing_required_fields(provided_fields) + if missing_required_fields: + field_names = "'" + "', '".join(missing_required_fields) + "'" + raise_config_error( + f"Please add required field{'s' if len(missing_required_fields) > 1 else ''} {field_names} to the {entity_name}.", + path, + ) + + extra_fields = klass.extra_fields(provided_fields) + if extra_fields: + extra_field_names = "'" + "', '".join(extra_fields) + "'" + + all_fields = klass.all_fields() + close_matches = {} + for field in extra_fields: + matches = get_close_matches(field, all_fields, n=1) + if matches: + close_matches[field] = matches[0] + + if len(close_matches) == 1: + similar_msg = ". Did you mean " + "'" + "', '".join(close_matches.values()) + "'?" + else: + similar = [ + f"- {field}: Did you mean '{match}'?" 
for field, match in close_matches.items() + ] + similar_msg = "\n\n " + "\n ".join(similar) if similar else "" + + raise_config_error( + f"Invalid field name{'s' if len(extra_fields) > 1 else ''} present in the {entity_name}: {extra_field_names}{similar_msg}", + path, + ) + + +def single_value_or_tuple(values: t.Sequence) -> exp.Identifier | exp.Tuple: + return ( + exp.to_identifier(values[0]) + if len(values) == 1 + else exp.Tuple(expressions=[exp.to_identifier(v) for v in values]) + ) -@field_validator_v1_args def parse_expression( cls: t.Type, v: t.Union[t.List[str], t.List[exp.Expression], str, exp.Expression, t.Callable, None], - values: t.Dict[str, t.Any], + info: t.Optional[ValidationInfo], ) -> t.List[exp.Expression] | exp.Expression | t.Callable | None: """Helper method to deserialize SQLGlot expressions in Pydantic Models.""" if v is None: @@ -23,15 +479,17 @@ def parse_expression( if callable(v): return v - dialect = values.get("dialect") + dialect = info.data.get("dialect") if info else "" if isinstance(v, list): return [ - parse_one(e, dialect=dialect) if not isinstance(e, exp.Expression) else e for e in v + e if isinstance(e, exp.Expression) else d.parse_one(e, dialect=dialect) + for e in v + if not isinstance(e, exp.Semicolon) ] if isinstance(v, str): - return parse_one(v, dialect=dialect) + return d.parse_one(v, dialect=dialect) if not v: raise ConfigError(f"Could not parse {v}") @@ -40,21 +498,31 @@ def parse_expression( def parse_bool(v: t.Any) -> bool: - if isinstance(v, exp.Boolean): - return v.this if isinstance(v, exp.Expression): + if not isinstance(v, exp.Boolean): + from sqlglot.optimizer.simplify import simplify + + # Try to reduce expressions like (1 = 1) (see: T-SQL boolean generation) + v = simplify(v) + + if isinstance(v, exp.Boolean): + return v.this + return str_to_bool(v.name) + return str_to_bool(str(v or "")) -@field_validator_v1_args -def parse_properties(cls: t.Type, v: t.Any, values: t.Dict[str, t.Any]) -> 
t.Optional[exp.Tuple]: +def parse_properties( + cls: t.Type, v: t.Any, info: t.Optional[ValidationInfo] +) -> t.Optional[exp.Tuple]: if v is None: return v - dialect = values.get("dialect") + dialect = info.data.get("dialect") if info else "" + if isinstance(v, str): - v = parse_one(v, dialect=dialect) + v = d.parse_one(v, dialect=dialect) if isinstance(v, (exp.Array, exp.Paren, exp.Tuple)): eq_expressions: t.List[exp.Expression] = ( [v.unnest()] if isinstance(v, exp.Paren) else v.expressions @@ -88,14 +556,16 @@ def default_catalog(cls: t.Type, v: t.Any) -> t.Optional[str]: return str(v) -@field_validator_v1_args -def depends_on(cls: t.Type, v: t.Any, values: t.Dict[str, t.Any]) -> t.Optional[t.Set[str]]: - dialect = values.get("dialect") - default_catalog = values.get("default_catalog") +def depends_on(cls: t.Type, v: t.Any, info: ValidationInfo) -> t.Optional[t.Set[str]]: + dialect = info.data.get("dialect") + default_catalog = info.data.get("default_catalog") + + if isinstance(v, exp.Paren): + v = v.unnest() if isinstance(v, (exp.Array, exp.Tuple)): return { - normalize_model_name( + d.normalize_model_name( table.name if table.is_string else table, default_catalog=default_catalog, dialect=dialect, @@ -103,28 +573,56 @@ def depends_on(cls: t.Type, v: t.Any, values: t.Dict[str, t.Any]) -> t.Optional[ for table in v.expressions } if isinstance(v, (exp.Table, exp.Column)): - return {normalize_model_name(v, default_catalog=default_catalog, dialect=dialect)} + return {d.normalize_model_name(v, default_catalog=default_catalog, dialect=dialect)} if hasattr(v, "__iter__") and not isinstance(v, str): return { - normalize_model_name(name, default_catalog=default_catalog, dialect=dialect) + d.normalize_model_name(name, default_catalog=default_catalog, dialect=dialect) for name in v } return v -expression_validator = field_validator( - "query", - "expressions_", - "pre_statements_", - "post_statements_", +def sort_python_env(python_env: t.Dict[str, Executable]) -> 
t.List[t.Tuple[str, Executable]]: + """Returns the python env sorted.""" + return sorted(python_env.items(), key=lambda x: (x[1].kind, x[0])) + + +def sorted_python_env_payloads(python_env: t.Dict[str, Executable]) -> t.List[str]: + """Returns the payloads of the sorted python env.""" + + def _executable_to_str(k: str, v: Executable) -> str: + result = f"# {v.path}\n" if v.path is not None else "" + if v.is_import or v.is_definition: + result += v.payload + else: + result += f"{k} = {v.payload}" + return result + + return [_executable_to_str(k, v) for k, v in sort_python_env(python_env)] + + +def parse_strings_with_macro_refs(value: t.Any, dialect: DialectType) -> t.Any: + if isinstance(value, str) and "@" in value: + return exp.maybe_parse(value, dialect=dialect) + + if isinstance(value, dict): + for k, v in dict(value).items(): + value[k] = parse_strings_with_macro_refs(v, dialect) + elif isinstance(value, list): + value = [parse_strings_with_macro_refs(v, dialect) for v in value] + + return value + + +expression_validator: t.Callable = field_validator( "unique_key", mode="before", check_fields=False, )(parse_expression) -bool_validator = field_validator( +bool_validator: t.Callable = field_validator( "skip", "blocking", "forward_only", @@ -132,30 +630,95 @@ def depends_on(cls: t.Type, v: t.Any, values: t.Dict[str, t.Any]) -> t.Optional[ "insert_overwrite", "allow_partials", "enabled", + "optimize_query", + "formatting", mode="before", check_fields=False, )(parse_bool) -properties_validator = field_validator( +properties_validator: t.Callable = field_validator( "physical_properties_", "virtual_properties_", - "session_properties_", "materialization_properties_", + "grants_", mode="before", check_fields=False, )(parse_properties) -default_catalog_validator = field_validator( +default_catalog_validator: t.Callable = field_validator( "default_catalog", mode="before", check_fields=False, )(default_catalog) -depends_on_validator = field_validator( 
+depends_on_validator: t.Callable = field_validator( "depends_on_", mode="before", check_fields=False, )(depends_on) + + +class ParsableSql(PydanticModel): + sql: str + transaction: t.Optional[bool] = None + + _parsed: t.Optional[exp.Expression] = None + _parsed_dialect: t.Optional[str] = None + + def parse(self, dialect: str) -> exp.Expression: + if self._parsed is None or self._parsed_dialect != dialect: + self._parsed = d.parse_one(self.sql, dialect=dialect) + self._parsed_dialect = dialect + return self._parsed + + @classmethod + def from_parsed_expression( + cls, parsed_expression: exp.Expression, dialect: str, use_meta_sql: bool = False + ) -> ParsableSql: + sql = ( + parsed_expression.meta.get("sql") or parsed_expression.sql(dialect=dialect) + if use_meta_sql + else parsed_expression.sql(dialect=dialect) + ) + result = cls(sql=sql) + result._parsed = parsed_expression + result._parsed_dialect = dialect + return result + + @classmethod + def validator(cls) -> classmethod: + def _validate_parsable_sql( + v: t.Any, info: ValidationInfo + ) -> t.Optional[t.Union[ParsableSql, t.List[ParsableSql]]]: + if v is None: + return v + if isinstance(v, str): + return ParsableSql(sql=v) + if isinstance(v, exp.Expression): + return ParsableSql.from_parsed_expression( + v, get_dialect(info.data), use_meta_sql=False + ) + if isinstance(v, list): + dialect = get_dialect(info.data) + return [ + ParsableSql(sql=s) + if isinstance(s, str) + else ParsableSql.from_parsed_expression(s, dialect, use_meta_sql=False) + if isinstance(s, exp.Expression) + else ParsableSql.parse_obj(s) + for s in v + ] + return ParsableSql.parse_obj(v) + + return field_validator( + "query_", + "expressions_", + "pre_statements_", + "post_statements_", + "on_virtual_update_", + mode="before", + check_fields=False, + )(_validate_parsable_sql) diff --git a/sqlmesh/core/model/decorator.py b/sqlmesh/core/model/decorator.py index 6f4f0c9f18..73452cc165 100644 --- a/sqlmesh/core/model/decorator.py +++ 
b/sqlmesh/core/model/decorator.py @@ -1,27 +1,37 @@ from __future__ import annotations -import logging import typing as t from pathlib import Path import inspect +import re from sqlglot import exp from sqlglot.dialects.dialect import DialectType +from sqlmesh.core.config.common import VirtualEnvironmentMode +from sqlmesh.core.macros import MacroRegistry +from sqlmesh.core.signal import SignalRegistry +from sqlmesh.utils.jinja import JinjaMacroRegistry from sqlmesh.core import constants as c -from sqlmesh.core.dialect import MacroFunc +from sqlmesh.core.dialect import MacroFunc, parse_one from sqlmesh.core.model.definition import ( Model, create_python_model, create_sql_model, + create_models_from_blueprints, get_model_name, + parse_defaults_properties, + render_meta_fields, + render_model_defaults, ) from sqlmesh.core.model.kind import ModelKindName, _ModelKind -from sqlmesh.utils import registry_decorator -from sqlmesh.utils.errors import ConfigError +from sqlmesh.utils import registry_decorator, DECORATOR_RETURN_TYPE +from sqlmesh.utils.errors import ConfigError, raise_config_error from sqlmesh.utils.metaprogramming import build_env, serialize_env -logger = logging.getLogger(__name__) + +if t.TYPE_CHECKING: + from sqlmesh.core.audit import ModelAudit class model(registry_decorator): @@ -34,28 +44,30 @@ def __init__(self, name: t.Optional[str] = None, is_sql: bool = False, **kwargs: if not is_sql and "columns" not in kwargs: raise ConfigError("Python model must define column schema.") + self.name_provided = bool(name) self.name = name or "" self.is_sql = is_sql self.kwargs = kwargs # Make sure that argument values are expressions in order to pass validation in ModelMeta. 
- calls = self.kwargs.pop("audits", []) - self.kwargs["audits"] = [ - ( - (call, {}) - if isinstance(call, str) - else ( - call[0], - { - arg_key: exp.convert( - tuple(arg_value) if isinstance(arg_value, list) else arg_value - ) - for arg_key, arg_value in call[1].items() - }, + for function_call_attribute in ("audits", "signals"): + calls = self.kwargs.pop(function_call_attribute, []) + self.kwargs[function_call_attribute] = [ + ( + (call, {}) + if isinstance(call, str) + else ( + call[0], + { + arg_key: exp.convert( + tuple(arg_value) if isinstance(arg_value, list) else arg_value + ) + for arg_key, arg_value in call[1].items() + }, + ) ) - ) - for call in calls - ] + for call in calls + ] if "default_catalog" in kwargs: raise ConfigError("`default_catalog` cannot be set on a per-model basis.") @@ -64,39 +76,100 @@ def __init__(self, name: t.Optional[str] = None, is_sql: bool = False, **kwargs: column_name: ( column_type if isinstance(column_type, exp.DataType) - else exp.DataType.build(str(column_type)) + else exp.DataType.build( + str(column_type), dialect=self.kwargs.get("dialect", self._dialect) + ) ) for column_name, column_type in self.kwargs.pop("columns", {}).items() } + def __call__( + self, func: t.Callable[..., DECORATOR_RETURN_TYPE] + ) -> t.Callable[..., DECORATOR_RETURN_TYPE]: + if not self.name_provided: + self.name = get_model_name(Path(inspect.getfile(func))) + return super().__call__(func) + + def models( + self, + get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]], + path: Path, + module_path: Path, + dialect: t.Optional[str] = None, + default_catalog_per_gateway: t.Optional[t.Dict[str, str]] = None, + **loader_kwargs: t.Any, + ) -> t.List[Model]: + blueprints = self.kwargs.pop("blueprints", None) + + if isinstance(blueprints, str): + blueprints = parse_one(blueprints, dialect=dialect) + + if isinstance(blueprints, MacroFunc): + from sqlmesh.core.model.definition import render_expression + + blueprints = render_expression( + 
expression=blueprints, + module_path=module_path, + macros=loader_kwargs.get("macros"), + jinja_macros=loader_kwargs.get("jinja_macros"), + variables=get_variables(None), + path=path, + dialect=dialect, + default_catalog=loader_kwargs.get("default_catalog"), + ) + if not blueprints: + raise_config_error("Failed to render blueprints property", path) + + if len(blueprints) > 1: + blueprints = [exp.Tuple(expressions=blueprints)] + + blueprints = blueprints[0] + + return create_models_from_blueprints( + gateway=self.kwargs.get("gateway"), + blueprints=blueprints, + get_variables=get_variables, + loader=self.model, + path=path, + module_path=module_path, + dialect=dialect, + default_catalog_per_gateway=default_catalog_per_gateway, + **loader_kwargs, + ) + def model( self, *, module_path: Path, path: Path, defaults: t.Optional[t.Dict[str, t.Any]] = None, + macros: t.Optional[MacroRegistry] = None, + jinja_macros: t.Optional[JinjaMacroRegistry] = None, + signal_definitions: t.Optional[SignalRegistry] = None, + audit_definitions: t.Optional[t.Dict[str, ModelAudit]] = None, dialect: t.Optional[str] = None, time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT, - physical_schema_override: t.Optional[t.Dict[str, str]] = None, + physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None, project: str = "", default_catalog: t.Optional[str] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, infer_names: t.Optional[bool] = False, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, + virtual_environment_mode: VirtualEnvironmentMode = VirtualEnvironmentMode.default, ) -> Model: """Get the model registered by this function.""" - env: t.Dict[str, t.Any] = {} + env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {} entrypoint = self.func.__name__ - if not self.name and infer_names: - self.name = get_model_name(Path(inspect.getfile(self.func))) - - if not self.name: + if not self.name_provided and not infer_names: raise ConfigError("Python model must have a 
name.") kind = self.kwargs.get("kind", None) if kind is not None: if isinstance(kind, _ModelKind): - logger.warning( + from sqlmesh.core.console import get_console + + get_console().log_warning( f"""Python model "{self.name}"'s `kind` argument was passed a SQLMesh `{type(kind).__name__}` object. This may result in unexpected behavior - provide a dictionary instead.""" ) elif isinstance(kind, dict): @@ -107,30 +180,69 @@ def model( build_env(self.func, env=env, name=entrypoint, path=module_path) - common_kwargs = dict( - defaults=defaults, + rendered_fields = render_meta_fields( + fields={"name": self.name, **self.kwargs}, + module_path=module_path, + macros=macros, + jinja_macros=jinja_macros, + variables=variables, path=path, - time_column_format=time_column_format, - python_env=serialize_env(env, path=module_path), - physical_schema_override=physical_schema_override, - project=project, + dialect=dialect, default_catalog=default_catalog, - variables=variables, - **self.kwargs, + blueprint_variables=blueprint_variables, + ) + + rendered_name = rendered_fields["name"] + if isinstance(rendered_name, exp.Expression): + rendered_fields["name"] = rendered_name.sql(dialect=dialect) + + rendered_defaults = ( + render_model_defaults( + defaults=defaults, + module_path=module_path, + macros=macros, + jinja_macros=jinja_macros, + variables=variables, + path=path, + dialect=dialect, + default_catalog=default_catalog, + ) + if defaults + else {} ) - dialect = common_kwargs.pop("dialect", dialect) - for key in ("pre_statements", "post_statements"): + rendered_defaults = parse_defaults_properties(rendered_defaults, dialect=dialect) + + common_kwargs = { + "defaults": rendered_defaults, + "path": path, + "time_column_format": time_column_format, + "python_env": serialize_env(env, path=module_path), + "physical_schema_mapping": physical_schema_mapping, + "project": project, + "default_catalog": default_catalog, + "variables": variables, + "dialect": dialect, + "columns": 
self.columns if self.columns else None, + "module_path": module_path, + "macros": macros, + "jinja_macros": jinja_macros, + "audit_definitions": audit_definitions, + "signal_definitions": signal_definitions, + "blueprint_variables": blueprint_variables, + "virtual_environment_mode": virtual_environment_mode, + **rendered_fields, + } + + for key in ("pre_statements", "post_statements", "on_virtual_update"): statements = common_kwargs.get(key) if statements: - common_kwargs[key] = [exp.maybe_parse(s, dialect=dialect) for s in statements] + common_kwargs[key] = [ + parse_one(s, dialect=common_kwargs.get("dialect")) if isinstance(s, str) else s + for s in statements + ] if self.is_sql: query = MacroFunc(this=exp.Anonymous(this=entrypoint)) - return create_sql_model( - self.name, query, module_path=module_path, dialect=dialect, **common_kwargs - ) - - return create_python_model( - self.name, entrypoint, columns=self.columns, dialect=dialect, **common_kwargs - ) + return create_sql_model(query=query, **common_kwargs) + return create_python_model(entrypoint=entrypoint, **common_kwargs) diff --git a/sqlmesh/core/model/definition.py b/sqlmesh/core/model/definition.py index 3ff3003624..831b52a44e 100644 --- a/sqlmesh/core/model/definition.py +++ b/sqlmesh/core/model/definition.py @@ -1,67 +1,101 @@ from __future__ import annotations -import ast import json import logging -import sys import types +import re import typing as t -from functools import cached_property +from functools import cached_property, partial from pathlib import Path -import pandas as pd -import numpy as np -from astor import to_source from pydantic import Field from sqlglot import diff, exp -from sqlglot.diff import Insert, Keep -from sqlglot.helper import ensure_list +from sqlglot.diff import Insert +from sqlglot.helper import seq_get +from sqlglot.optimizer.qualify_columns import quote_identifiers from sqlglot.optimizer.simplify import gen +from sqlglot.optimizer.normalize_identifiers import 
normalize_identifiers from sqlglot.schema import MappingSchema, nested_set from sqlglot.time import format_time from sqlmesh.core import constants as c from sqlmesh.core import dialect as d -from sqlmesh.core.macros import MacroRegistry, MacroStrTemplate, macro -from sqlmesh.core.model.common import expression_validator -from sqlmesh.core.model.kind import ModelKindName, SeedKind, create_model_kind +from sqlmesh.core.audit import Audit, ModelAudit +from sqlmesh.core.node import IntervalUnit +from sqlmesh.core.macros import MacroRegistry, macro +from sqlmesh.core.model.common import ( + ParsableSql, + make_python_env, + parse_dependencies, + parse_strings_with_macro_refs, + single_value_or_tuple, + sorted_python_env_payloads, + validate_extra_and_required_fields, +) from sqlmesh.core.model.meta import ModelMeta +from sqlmesh.core.model.kind import ( + ExternalKind, + ModelKindName, + SeedKind, + ModelKind, + FullKind, + create_model_kind, + CustomKind, +) from sqlmesh.core.model.seed import CsvSeedReader, Seed, create_seed from sqlmesh.core.renderer import ExpressionRenderer, QueryRenderer +from sqlmesh.core.signal import SignalRegistry from sqlmesh.utils import columns_to_types_all_known, str_to_bool, UniqueKeyDict +from sqlmesh.utils.cron import CroniterCache from sqlmesh.utils.date import TimeLike, make_inclusive, to_datetime, to_time_column -from sqlmesh.utils.errors import ConfigError, SQLMeshError, raise_config_error +from sqlmesh.utils.errors import ConfigError, SQLMeshError, raise_config_error, PythonModelEvalError from sqlmesh.utils.hashing import hash_data -from sqlmesh.utils.jinja import ( - JinjaMacroRegistry, - extract_macro_references_and_variables, -) -from sqlmesh.utils.pydantic import field_validator, field_validator_v1_args +from sqlmesh.utils.jinja import JinjaMacroRegistry, extract_macro_references_and_variables +from sqlmesh.utils.pydantic import PydanticModel, PRIVATE_FIELDS from sqlmesh.utils.metaprogramming import ( Executable, + SqlValue, 
build_env, prepare_env, - print_exception, serialize_env, + format_evaluated_code_exception, ) if t.TYPE_CHECKING: - from sqlmesh.core._typing import TableName - from sqlmesh.core.audit import ModelAudit + from sqlglot.dialects.dialect import DialectType + from sqlmesh.core.node import _Node + from sqlmesh.core._typing import Self, TableName, SessionProperties from sqlmesh.core.context import ExecutionContext from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.engine_adapter._typing import QueryOrDF + from sqlmesh.core.engine_adapter.shared import DataObjectType + from sqlmesh.core.linter.rule import Rule from sqlmesh.core.snapshot import DeployabilityIndex, Node, Snapshot from sqlmesh.utils.jinja import MacroReference -if sys.version_info >= (3, 9): - from typing import Literal -else: - from typing_extensions import Literal logger = logging.getLogger(__name__) +PROPERTIES = {"physical_properties", "session_properties", "virtual_properties"} + +RUNTIME_RENDERED_MODEL_FIELDS = { + "audits", + "signals", + "merge_filter", +} | PROPERTIES + +CRON_SHORTCUTS = { + "@midnight", + "@hourly", + "@daily", + "@weekly", + "@monthly", + "@yearly", + "@annually", +} + + class _Model(ModelMeta, frozen=True): """Model is the core abstraction for user defined datasets. @@ -102,22 +136,45 @@ class _Model(ModelMeta, frozen=True): end: The date that the model will be backfilled up until. Follows the same syntax as 'start', should be omitted if there is no end date. lookback: The number of previous incremental intervals in the lookback window. + table_format: The table format used to manage the physical table files defined by `storage_format`, only applicable in certain engines. + (eg, 'iceberg', 'delta', 'hudi') storage_format: The storage format used to store the physical table, only applicable in certain engines. - (eg. 'parquet') + (eg. 'parquet', 'orc') partitioned_by: The partition columns or engine specific expressions, only applicable in certain engines. 
(eg. (ds, hour)) - clustered_by: The cluster columns, only applicable in certain engines. (eg. (ds, hour)) + clustered_by: The cluster columns or engine specific expressions, only applicable in certain engines. (eg. (ds, hour)) python_env: Dictionary containing all global variables needed to render the model's macros. mapping_schema: The schema of table names to column and types. + extract_dependencies_from_query: Whether to extract additional dependencies from the rendered model's query. physical_schema_override: The desired physical schema name override. """ - python_env_: t.Optional[t.Dict[str, Executable]] = Field(default=None, alias="python_env") + python_env: t.Dict[str, Executable] = {} jinja_macros: JinjaMacroRegistry = JinjaMacroRegistry() + audit_definitions: t.Dict[str, ModelAudit] = {} mapping_schema: t.Dict[str, t.Any] = {} + extract_dependencies_from_query: bool = True + pre_statements_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="pre_statements") + post_statements_: t.Optional[t.List[ParsableSql]] = Field(default=None, alias="post_statements") + on_virtual_update_: t.Optional[t.List[ParsableSql]] = Field( + default=None, alias="on_virtual_update" + ) _full_depends_on: t.Optional[t.Set[str]] = None + _statement_renderer_cache: t.Dict[int, ExpressionRenderer] = {} + _is_metadata_only_change_cache: t.Dict[int, bool] = {} + + _expressions_validator = ParsableSql.validator() + + def __getstate__(self) -> t.Dict[t.Any, t.Any]: + state = super().__getstate__() + private = state[PRIVATE_FIELDS] + private["_statement_renderer_cache"] = {} + return state - _expressions_validator = expression_validator + def copy(self, **kwargs: t.Any) -> Self: + model = super().copy(**kwargs) + model._statement_renderer_cache = {} + return model def render( self, @@ -154,7 +211,10 @@ def render( ) def render_definition( - self, include_python: bool = True, include_defaults: bool = False + self, + include_python: bool = True, + include_defaults: bool = False, + 
render_query: bool = False, ) -> t.List[exp.Expression]: """Returns the original list of sql expressions comprising the model definition. @@ -183,12 +243,7 @@ def render_definition( value=exp.to_table(field_value, dialect=self.dialect), ) ) - elif field_name not in ( - "column_descriptions_", - "default_catalog", - "enabled", - "inline_audits", - ): + elif field_name not in ("default_catalog", "enabled", "ignored_rules_"): expressions.append( exp.Property( this=field_info.alias or field_name, @@ -204,12 +259,7 @@ def render_definition( jinja_expressions = [] python_expressions = [] if include_python: - python_env = d.PythonCode( - expressions=[ - v.payload if v.is_import or v.is_definition else f"{k} = {v.payload}" - for k, v in self.sorted_python_env - ] - ) + python_env = d.PythonCode(expressions=sorted_python_env_payloads(self.python_env)) if python_env.expressions: python_expressions.append(python_env) @@ -314,6 +364,7 @@ def render_pre_statements( expand: t.Iterable[str] = tuple(), deployability_index: t.Optional[DeployabilityIndex] = None, engine_adapter: t.Optional[EngineAdapter] = None, + inside_transaction: t.Optional[bool] = True, **kwargs: t.Any, ) -> t.List[exp.Expression]: """Renders pre-statements for a model. @@ -334,7 +385,21 @@ def render_pre_statements( Returns: The list of rendered expressions. 
""" - return [] + return self._render_statements( + [ + stmt + for stmt in self.pre_statements + if stmt.args.get("transaction", True) == inside_transaction + ], + start=start, + end=end, + execution_time=execution_time, + snapshots=snapshots, + expand=expand, + deployability_index=deployability_index, + engine_adapter=engine_adapter, + **kwargs, + ) def render_post_statements( self, @@ -342,10 +407,11 @@ def render_post_statements( start: t.Optional[TimeLike] = None, end: t.Optional[TimeLike] = None, execution_time: t.Optional[TimeLike] = None, - snapshots: t.Optional[t.Collection[Snapshot]] = None, + snapshots: t.Optional[t.Dict[str, Snapshot]] = None, expand: t.Iterable[str] = tuple(), deployability_index: t.Optional[DeployabilityIndex] = None, engine_adapter: t.Optional[EngineAdapter] = None, + inside_transaction: t.Optional[bool] = True, **kwargs: t.Any, ) -> t.List[exp.Expression]: """Renders post-statements for a model. @@ -361,12 +427,191 @@ def render_post_statements( that depend on materialized tables. Model definitions are inlined and can thus be run end to end on the fly. deployability_index: Determines snapshots that are deployable in the context of this render. + inside_transaction: Whether to render hooks with transaction=True (inside) or transaction=False (outside). kwargs: Additional kwargs to pass to the renderer. Returns: The list of rendered expressions. 
""" - return [] + return self._render_statements( + [ + stmt + for stmt in self.post_statements + if stmt.args.get("transaction", True) == inside_transaction + ], + start=start, + end=end, + execution_time=execution_time, + snapshots=snapshots, + expand=expand, + deployability_index=deployability_index, + engine_adapter=engine_adapter, + **kwargs, + ) + + def render_on_virtual_update( + self, + *, + start: t.Optional[TimeLike] = None, + end: t.Optional[TimeLike] = None, + execution_time: t.Optional[TimeLike] = None, + snapshots: t.Optional[t.Dict[str, Snapshot]] = None, + expand: t.Iterable[str] = tuple(), + deployability_index: t.Optional[DeployabilityIndex] = None, + engine_adapter: t.Optional[EngineAdapter] = None, + **kwargs: t.Any, + ) -> t.List[exp.Expression]: + return self._render_statements( + self.on_virtual_update, + start=start, + end=end, + execution_time=execution_time, + snapshots=snapshots, + expand=expand, + deployability_index=deployability_index, + engine_adapter=engine_adapter, + **kwargs, + ) + + def render_audit_query( + self, + audit: Audit, + *, + start: t.Optional[TimeLike] = None, + end: t.Optional[TimeLike] = None, + execution_time: t.Optional[TimeLike] = None, + snapshots: t.Optional[t.Dict[str, Snapshot]] = None, + deployability_index: t.Optional[DeployabilityIndex] = None, + **kwargs: t.Any, + ) -> exp.Query: + from sqlmesh.core.snapshot import DeployabilityIndex + + deployability_index = deployability_index or DeployabilityIndex.all_deployable() + snapshot = (snapshots or {}).get(self.fqn) + + this_model = kwargs.pop("this_model", None) or ( + snapshot.table_name(deployability_index.is_deployable(snapshot)) + if snapshot + else self.fqn + ) + + columns_to_types: t.Optional[t.Dict[str, t.Any]] = None + if "engine_adapter" in kwargs: + try: + columns_to_types = kwargs["engine_adapter"].columns(this_model) + except Exception: + pass + + if self.time_column: + low, high = [ + self.convert_to_time_column(dt, columns_to_types) + for dt in 
make_inclusive(start or c.EPOCH, end or c.EPOCH, self.dialect) + ] + where = self.time_column.column.between(low, high) + else: + where = None + + # The model's name is already normalized, but in case of snapshots we also prepend a + # case-sensitive physical schema name, so we quote here to ensure that we won't have + # a broken schema reference after the resulting query is normalized in `render`. + quoted_model_name = quote_identifiers( + exp.to_table(this_model, dialect=self.dialect), dialect=self.dialect + ) + + query_renderer = QueryRenderer( + audit.query, + audit.dialect or self.dialect, + audit.macro_definitions, + path=audit._path or Path(), + jinja_macro_registry=audit.jinja_macros, + python_env=self.python_env, + only_execution_time=self.kind.only_execution_time, + default_catalog=self.default_catalog, + ) + + rendered_query = query_renderer.render( + start=start, + end=end, + execution_time=execution_time, + snapshots=snapshots, + deployability_index=deployability_index, + **{ + **audit.defaults, + "this_model": exp.select("*").from_(quoted_model_name).where(where).subquery() + if where is not None + else quoted_model_name, + **kwargs, + }, # type: ignore + ) + + if rendered_query is None: + raise SQLMeshError( + f"Failed to render query for audit '{audit.name}', model '{self.name}'." 
+ ) + + return rendered_query + + @property + def pre_statements(self) -> t.List[exp.Expression]: + return self._get_parsed_statements("pre_statements_") + + @property + def post_statements(self) -> t.List[exp.Expression]: + return self._get_parsed_statements("post_statements_") + + @property + def on_virtual_update(self) -> t.List[exp.Expression]: + return self._get_parsed_statements("on_virtual_update_") + + @property + def macro_definitions(self) -> t.List[d.MacroDef]: + """All macro definitions from the list of expressions.""" + return [ + s + for s in self.pre_statements + self.post_statements + self.on_virtual_update + if isinstance(s, d.MacroDef) + ] + + def _get_parsed_statements(self, attr_name: str) -> t.List[exp.Expression]: + value = getattr(self, attr_name) + if not value: + return [] + result = [] + for v in value: + parsed = v.parse(self.dialect) + if getattr(v, "transaction", None) is not None: + parsed.set("transaction", v.transaction) + if not isinstance(parsed, exp.Semicolon): + result.append(parsed) + return result + + def _render_statements( + self, + statements: t.Iterable[exp.Expression], + **kwargs: t.Any, + ) -> t.List[exp.Expression]: + rendered = ( + self._statement_renderer(statement).render(**kwargs) + for statement in statements + if not isinstance(statement, d.MacroDef) + ) + return [r for expressions in rendered if expressions for r in expressions] + + def _statement_renderer(self, expression: exp.Expression) -> ExpressionRenderer: + expression_key = id(expression) + if expression_key not in self._statement_renderer_cache: + self._statement_renderer_cache[expression_key] = ExpressionRenderer( + expression, + self.dialect, + self.macro_definitions, + path=self._path, + jinja_macro_registry=self.jinja_macros, + python_env=self.python_env, + only_execution_time=False, + default_catalog=self.default_catalog, + model=self, + ) + return self._statement_renderer_cache[expression_key] def render_signals( self, @@ -386,20 +631,9 @@ def 
render_signals( The list of rendered expressions. """ - def _create_renderer(expression: exp.Expression) -> ExpressionRenderer: - return ExpressionRenderer( - expression, - self.dialect, - [], - path=self._path, - jinja_macro_registry=self.jinja_macros, - python_env=self.python_env, - only_execution_time=False, - ) - def _render(e: exp.Expression) -> str | int | float | bool: rendered_exprs = ( - _create_renderer(e).render(start=start, end=end, execution_time=execution_time) + self._create_renderer(e).render(start=start, end=end, execution_time=execution_time) or [] ) if len(rendered_exprs) != 1: @@ -414,7 +648,95 @@ def _render(e: exp.Expression) -> str | int | float | bool: return rendered.this return rendered.sql(dialect=self.dialect) - return [{t.this.name: _render(t.expression) for t in signal} for signal in self.signals] + # airflow only + return [ + {k: _render(v) for k, v in signal.items()} for name, signal in self.signals if not name + ] + + def render_signal_calls(self) -> EvaluatableSignals: + python_env = self.python_env + env = prepare_env(python_env) + signals_to_kwargs = { + name: { + k: seq_get(self._create_renderer(v).render() or [], 0) for k, v in kwargs.items() + } + for name, kwargs in self.signals + if name + } + + return EvaluatableSignals( + signals_to_kwargs=signals_to_kwargs, + python_env=python_env, + prepared_python_env=env, + ) + + def render_merge_filter( + self, + *, + start: t.Optional[TimeLike] = None, + end: t.Optional[TimeLike] = None, + execution_time: t.Optional[TimeLike] = None, + ) -> t.Optional[exp.Expression]: + if self.merge_filter is None: + return None + rendered_exprs = ( + self._create_renderer(self.merge_filter).render( + start=start, end=end, execution_time=execution_time + ) + or [] + ) + if len(rendered_exprs) != 1: + raise SQLMeshError(f"Expected one expression but got {len(rendered_exprs)}") + return rendered_exprs[0].transform(d.replace_merge_table_aliases, dialect=self.dialect) + + def _render_properties( + self, 
properties: t.Dict[str, exp.Expression] | SessionProperties, **render_kwargs: t.Any + ) -> t.Dict[str, t.Any]: + def _render(expression: exp.Expression) -> exp.Expression | None: + # note: we use the _statement_renderer instead of _create_renderer because it sets model_fqn which + # in turn makes @this_model available in the evaluation context + rendered_exprs = self._statement_renderer(expression).render(**render_kwargs) + + # Inform instead of raising for cases where a property is conditionally assigned + if not rendered_exprs or rendered_exprs[0].sql().lower() in {"none", "null"}: + logger.info( + f"Rendering '{expression.sql(dialect=self.dialect)}' did not return an expression" + ) + return None + + if len(rendered_exprs) != 1: + raise SQLMeshError( + f"Expected one result when rendering '{expression.sql(dialect=self.dialect)}' but got {len(rendered_exprs)}" + ) + + return rendered_exprs[0] + + return { + k: rendered + for k, v in properties.items() + if (rendered := (_render(v) if isinstance(v, exp.Expression) else v)) + } + + def render_physical_properties(self, **render_kwargs: t.Any) -> t.Dict[str, t.Any]: + return self._render_properties(properties=self.physical_properties, **render_kwargs) + + def render_virtual_properties(self, **render_kwargs: t.Any) -> t.Dict[str, t.Any]: + return self._render_properties(properties=self.virtual_properties, **render_kwargs) + + def render_session_properties(self, **render_kwargs: t.Any) -> t.Dict[str, t.Any]: + return self._render_properties(properties=self.session_properties, **render_kwargs) + + def _create_renderer(self, expression: exp.Expression) -> ExpressionRenderer: + return ExpressionRenderer( + expression, + self.dialect, + [], + path=self._path, + jinja_macro_registry=self.jinja_macros, + python_env=self.python_env, + only_execution_time=False, + quote_identifiers=False, + ) def ctas_query(self, **render_kwarg: t.Any) -> exp.Query: """Return a dummy query to do a CTAS. 
@@ -428,26 +750,11 @@ def ctas_query(self, **render_kwarg: t.Any) -> exp.Query: Return: The mocked out ctas query. """ - query = self.render_query_or_raise(**render_kwarg).copy() + query = self.render_query_or_raise(**render_kwarg).limit(0) for select_or_set_op in query.find_all(exp.Select, exp.SetOperation): - skip_limit = False - ancestor = select_or_set_op.parent - while ancestor and not skip_limit: - if isinstance(ancestor, exp.With) and ancestor.recursive: - skip_limit = True - ancestor = ancestor.parent - - if isinstance(select_or_set_op, exp.Select) and select_or_set_op.args.get("from"): + if isinstance(select_or_set_op, exp.Select) and select_or_set_op.args.get("from_"): select_or_set_op.where(exp.false(), copy=False) - if not skip_limit and not isinstance(select_or_set_op.parent, exp.SetOperation): - select_or_set_op.limit(0, copy=False) - elif ( - not skip_limit - and isinstance(select_or_set_op, exp.SetOperation) - and not isinstance(select_or_set_op.parent, exp.SetOperation) - ): - select_or_set_op.set("limit", exp.Limit(expression=exp.Literal.number(0))) if self.managed_columns: query.select( @@ -461,33 +768,12 @@ def ctas_query(self, **render_kwarg: t.Any) -> exp.Query: ) return query - def referenced_audits(self, audits: t.Dict[str, ModelAudit]) -> t.List[ModelAudit]: - """Returns audits referenced in this model. - - Args: - audits: Available audits by name. 
- """ - from sqlmesh.core.audit import BUILT_IN_AUDITS - - referenced_audits = [] - - for audit_name, _ in self.audits: - if audit_name in self.inline_audits: - referenced_audits.append(self.inline_audits[audit_name]) - elif audit_name in audits: - referenced_audits.append(audits[audit_name]) - elif audit_name not in BUILT_IN_AUDITS: - raise_config_error( - f"Unknown audit '{audit_name}' referenced in model '{self.name}'", - self._path, - ) - return referenced_audits - - def text_diff(self, other: Node) -> str: + def text_diff(self, other: Node, rendered: bool = False) -> str: """Produce a text diff against another node. Args: other: The node to diff against. + rendered: Whether the diff should compare raw vs rendered models Returns: A unified text diff showing additions and deletions. @@ -497,10 +783,23 @@ def text_diff(self, other: Node) -> str: f"Cannot diff model '{self.name} against a non-model node '{other.name}'" ) - return d.text_diff( - self.render_definition(), other.render_definition(), self.dialect, other.dialect + text_diff = d.text_diff( + self.render_definition(render_query=rendered), + other.render_definition(render_query=rendered), + self.dialect, + other.dialect, ).strip() + if not text_diff and not rendered: + text_diff = d.text_diff( + self.render_definition(render_query=True), + other.render_definition(render_query=True), + self.dialect, + other.dialect, + ).strip() + + return text_diff + def set_time_format(self, default_time_format: str = c.DEFAULT_TIME_COLUMN_FORMAT) -> None: """Sets the default time format for a model. 
@@ -536,13 +835,19 @@ def convert_to_time_column( time_column_type = columns_to_types[self.time_column.column.name] - return to_time_column(time, time_column_type, self.time_column.format) + return to_time_column( + time, + time_column_type, + self.dialect, + self.time_column.format, + ) return exp.convert(time) - def update_schema( - self, - schema: MappingSchema, - ) -> None: + def set_mapping_schema(self, schema: t.Dict) -> None: + self.mapping_schema.clear() + self.mapping_schema.update(schema) + + def update_schema(self, schema: MappingSchema) -> None: """Updates the schema for this model's dependencies based on the given mapping schema.""" for dep in self.depends_on: table = exp.to_table(dep) @@ -555,7 +860,7 @@ def update_schema( {col: dtype.sql(dialect=self.dialect) for col, dtype in mapping_schema.items()}, ) - @cached_property + @property def depends_on(self) -> t.Set[str]: """All of the upstream dependencies referenced in the model's query, excluding self references. @@ -600,10 +905,6 @@ def sorted_python_env(self) -> t.List[t.Tuple[str, Executable]]: def view_name(self) -> str: return self.fully_qualified_table.name - @property - def python_env(self) -> t.Dict[str, Executable]: - return self.python_env_ or {} - @property def schema_name(self) -> str: return self.fully_qualified_table.db or c.DEFAULT_SCHEMA @@ -624,7 +925,7 @@ def is_python(self) -> bool: def is_seed(self) -> bool: return False - @cached_property + @property def depends_on_self(self) -> bool: return self.fqn in self.full_depends_on @@ -637,12 +938,22 @@ def disable_restatement(self) -> bool: return getattr(self.kind, "disable_restatement", False) @property - def wap_supported(self) -> bool: - return self.kind.is_materialized and (self.storage_format or "").lower() == "iceberg" + def auto_restatement_intervals(self) -> t.Optional[int]: + return getattr(self.kind, "auto_restatement_intervals", None) @property - def inline_audits(self) -> t.Dict[str, ModelAudit]: - return {} + def 
auto_restatement_cron(self) -> t.Optional[str]: + return getattr(self.kind, "auto_restatement_cron", None) + + def auto_restatement_croniter(self, value: TimeLike) -> CroniterCache: + cron = self.auto_restatement_cron + if cron is None: + raise SQLMeshError("Auto restatement cron is not set.") + return CroniterCache(cron, value) + + @property + def wap_supported(self) -> bool: + return self.kind.is_materialized and (self.storage_format or "").lower() == "iceberg" def validate_definition(self) -> None: """Validates the model's definition. @@ -667,7 +978,7 @@ def validate_definition(self) -> None: if len(values) != len(unique_keys): raise_config_error( - "All keys in '{field}' must be unique in the model definition", + f"All keys in '{field}' must be unique in the model definition", self._path, ) @@ -701,9 +1012,37 @@ def validate_definition(self) -> None: # TODO: would this sort of logic be better off moved into the Kind? if self.dialect == "snowflake" and "target_lag" not in self.physical_properties: raise_config_error( - "Snowflake managed tables must specify the 'target_lag' physical property" + "Snowflake managed tables must specify the 'target_lag' physical property", + self._path, + ) + + if self.physical_version is not None and not self.forward_only: + raise_config_error( + "Pinning a physical version is only supported for forward only models", + self._path, + ) + + # The following attributes should be set only for SQL models + if not self.is_sql: + if self.optimize_query: + raise_config_error( + "SQLMesh query optimizer can only be enabled for SQL models", + self._path, ) + if isinstance(self.kind, CustomKind): + from sqlmesh.core.snapshot.evaluator import get_custom_materialization_type_or_raise + + # Will raise if the custom materialization points to an invalid class + get_custom_materialization_type_or_raise(self.kind.materialization) + + # Embedded model kind shouldn't have audits + if self.kind.name == ModelKindName.EMBEDDED and self.audits: + 
raise_config_error( + "Audits are not supported for embedded models", + self._path, + ) + def is_breaking_change(self, previous: Model) -> t.Optional[bool]: """Determines whether this model is a breaking change in relation to the `previous` model. @@ -716,6 +1055,45 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]: """ raise NotImplementedError + def is_metadata_only_change(self, other: _Node) -> bool: + if self._is_metadata_only_change_cache.get(id(other), None) is not None: + return self._is_metadata_only_change_cache[id(other)] + + is_metadata_change = True + if ( + not isinstance(other, _Model) + or self.metadata_hash == other.metadata_hash + or self._data_hash_values_no_sql != other._data_hash_values_no_sql + ): + is_metadata_change = False + else: + this_statements = [ + s + for s in [*self.pre_statements, *self.post_statements] + if not self._is_metadata_statement(s) + ] + other_statements = [ + s + for s in [*other.pre_statements, *other.post_statements] + if not other._is_metadata_statement(s) + ] + if len(this_statements) != len(other_statements): + is_metadata_change = False + else: + for this_statement, other_statement in zip(this_statements, other_statements): + this_rendered = ( + self._statement_renderer(this_statement).render() or this_statement + ) + other_rendered = ( + other._statement_renderer(other_statement).render() or other_statement + ) + if this_rendered != other_rendered: + is_metadata_change = False + break + + self._is_metadata_only_change_cache[id(other)] = is_metadata_change + return is_metadata_change + @property def data_hash(self) -> str: """ @@ -724,27 +1102,48 @@ def data_hash(self) -> str: Returns: The data hash for the node. 
""" - return hash_data(self._data_hash_values) + if self._data_hash is None: + self._data_hash = hash_data(self._data_hash_values) + return self._data_hash @property def _data_hash_values(self) -> t.List[str]: + return self._data_hash_values_no_sql + self._data_hash_values_sql + + @property + def _data_hash_values_sql(self) -> t.List[str]: + data = [] + + for statements in [self.pre_statements_, self.post_statements_]: + for statement in statements or []: + data.append(statement.sql) + + return data + + @property + def _data_hash_values_no_sql(self) -> t.List[str]: data = [ str( # Exclude metadata only macro funcs [(k, v) for k, v in self.sorted_python_env if not v.is_metadata] ), *self.kind.data_hash_values, + self.table_format, self.storage_format, str(self.lookback), *(gen(expr) for expr in (self.partitioned_by or [])), - *(self.clustered_by or []), + *(gen(expr) for expr in (self.clustered_by or [])), self.stamp, self.physical_schema, + self.physical_version, + self.gateway, self.interval_unit.value if self.interval_unit is not None else None, + str(self.optimize_query) if self.optimize_query is not None else None, + self.virtual_environment_mode.value, ] for column_name, column_type in (self.columns_to_types_ or {}).items(): data.append(column_name) - data.append(column_type.sql()) + data.append(column_type.sql(dialect=self.dialect)) for key, value in (self.physical_properties or {}).items(): data.append(key) @@ -752,80 +1151,107 @@ def _data_hash_values(self) -> t.List[str]: return data # type: ignore - def metadata_hash(self, audits: t.Dict[str, ModelAudit]) -> str: - """ - Computes the metadata hash for the node. - - Args: - audits: Available audits by name. - - Returns: - The metadata hash for the node. 
- """ - from sqlmesh.core.audit import BUILT_IN_AUDITS + def _audit_metadata_hash_values(self) -> t.List[str]: + from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS - metadata = [ - self.dialect, - self.owner, - self.description, - json.dumps(self.column_descriptions, sort_keys=True), - self.cron, - str(self.start) if self.start else None, - str(self.end) if self.end else None, - str(self.retention) if self.retention else None, - str(self.batch_size) if self.batch_size is not None else None, - str(self.batch_concurrency) if self.batch_concurrency is not None else None, - json.dumps(self.mapping_schema, sort_keys=True), - *sorted(self.tags), - *sorted(ref.json(sort_keys=True) for ref in self.all_references), - *self.kind.metadata_hash_values, - self.project, - str(self.allow_partials), - gen(self.session_properties_) if self.session_properties_ else None, - ] + metadata = [] for audit_name, audit_args in sorted(self.audits, key=lambda a: a[0]): metadata.append(audit_name) - audit = None if audit_name in BUILT_IN_AUDITS: for arg_name, arg_value in audit_args.items(): metadata.append(arg_name) metadata.append(gen(arg_value)) - elif audit_name in self.inline_audits: - audit = self.inline_audits[audit_name] - elif audit_name in audits: - audit = audits[audit_name] else: - raise SQLMeshError(f"Unexpected audit name '{audit_name}'.") - - if audit: - query = ( - audit.render_query(self, **t.cast(t.Dict[str, t.Any], audit_args)) - or audit.query - ) + audit = self.audit_definitions[audit_name] metadata.extend( [ - gen(query), + audit.query_.sql, audit.dialect, str(audit.skip), str(audit.blocking), ] ) - for key, value in (self.virtual_properties or {}).items(): - metadata.append(key) - metadata.append(gen(value)) + return metadata - metadata.extend(gen(s) for s in self.signals) - metadata.extend(self._additional_metadata) + def audit_metadata_hash(self) -> str: + return hash_data(self._audit_metadata_hash_values()) - return hash_data(metadata) + @property + def 
metadata_hash(self) -> str: + """ + Computes the metadata hash for the node. + + Returns: + The metadata hash for the node. + """ + if self._metadata_hash is None: + metadata = [ + self.dialect, + self.owner, + self.description, + json.dumps(self.column_descriptions, sort_keys=True), + self.cron, + self.cron_tz.key if self.cron_tz else None, + str(self.start) if self.start else None, + str(self.end) if self.end else None, + str(self.retention) if self.retention else None, + str(self.batch_size) if self.batch_size is not None else None, + str(self.batch_concurrency) if self.batch_concurrency is not None else None, + json.dumps(self.mapping_schema, sort_keys=True), + *sorted(self.tags), + *sorted(ref.json(sort_keys=True) for ref in self.all_references), + *self.kind.metadata_hash_values, + self.project, + str(self.allow_partials), + gen(self.session_properties_) if self.session_properties_ else None, + *[gen(g) for g in self.grains], + *self._audit_metadata_hash_values(), + json.dumps(self.grants, sort_keys=True) if self.grants else None, + self.grants_target_layer, + ] + + for key, value in (self.virtual_properties or {}).items(): + metadata.append(key) + metadata.append(gen(value)) + + for signal_name, args in sorted(self.signals, key=lambda x: x[0]): + metadata.append(signal_name) + for k, v in sorted(args.items()): + metadata.append(f"{k}:{gen(v)}") + + if self.dbt_node_info: + metadata.append(self.dbt_node_info.json(sort_keys=True)) + + metadata.extend(self._additional_metadata) + + self._metadata_hash = hash_data(metadata) + return self._metadata_hash @property def is_model(self) -> bool: """Return True if this is a model node""" return True + @property + def grants_table_type(self) -> DataObjectType: + """Get the table type for grants application (TABLE, VIEW, MATERIALIZED_VIEW). + + Returns: + The DataObjectType that should be used when applying grants to this model. 
+ """ + from sqlmesh.core.engine_adapter.shared import DataObjectType + + if self.kind.is_view: + if hasattr(self.kind, "materialized") and getattr(self.kind, "materialized", False): + return DataObjectType.MATERIALIZED_VIEW + return DataObjectType.VIEW + if self.kind.is_managed: + return DataObjectType.MANAGED_TABLE + # All other materialized models are tables + return DataObjectType.TABLE + @property def _additional_metadata(self) -> t.List[str]: additional_metadata = [] @@ -834,14 +1260,32 @@ def _additional_metadata(self) -> t.List[str]: if metadata_only_macros: additional_metadata.append(str(metadata_only_macros)) + for statements in [self.pre_statements_, self.post_statements_, self.on_virtual_update_]: + for statement in statements or []: + additional_metadata.append(statement.sql) + return additional_metadata + def _is_metadata_statement(self, statement: exp.Expression) -> bool: + if isinstance(statement, d.MacroDef): + return True + if isinstance(statement, d.MacroFunc): + target_macro = macro.get_registry().get(statement.name) + if target_macro: + return target_macro.metadata_only + target_macro = self.python_env.get(statement.name) + if target_macro: + return bool(target_macro.is_metadata) + return False + @property def full_depends_on(self) -> t.Set[str]: - if not self._full_depends_on: + if not self.extract_dependencies_from_query: + return self.depends_on_ or set() + if self._full_depends_on is None: depends_on = self.depends_on_ or set() - query = self.render_query(optimize=False) + query = self.render_query(needs_optimization=False) if query is not None: depends_on |= d.find_tables( query, default_catalog=self.default_catalog, dialect=self.dialect @@ -850,186 +1294,102 @@ def full_depends_on(self) -> t.Set[str]: return self._full_depends_on - -class _SqlBasedModel(_Model): - pre_statements_: t.Optional[t.List[exp.Expression]] = Field( - default=None, alias="pre_statements" - ) - post_statements_: t.Optional[t.List[exp.Expression]] = Field( - 
default=None, alias="post_statements" - ) - inline_audits_: t.Dict[str, t.Any] = Field(default={}, alias="inline_audits") - - __statement_renderers: t.Dict[int, ExpressionRenderer] = {} - - _expression_validator = expression_validator - - @field_validator("inline_audits_", mode="before") - @field_validator_v1_args - def _inline_audits_validator(cls, v: t.Any, values: t.Dict[str, t.Any]) -> t.Any: - if not isinstance(v, dict): - return {} - - from sqlmesh.core.audit import ModelAudit - - inline_audits = {} - - for name, audit in v.items(): - if isinstance(audit, ModelAudit): - inline_audits[name] = audit - elif isinstance(audit, dict): - inline_audits[name] = ModelAudit.parse_obj(audit) - - return inline_audits - - def render_pre_statements( - self, - *, - start: t.Optional[TimeLike] = None, - end: t.Optional[TimeLike] = None, - execution_time: t.Optional[TimeLike] = None, - snapshots: t.Optional[t.Collection[Snapshot]] = None, - expand: t.Iterable[str] = tuple(), - deployability_index: t.Optional[DeployabilityIndex] = None, - engine_adapter: t.Optional[EngineAdapter] = None, - **kwargs: t.Any, - ) -> t.List[exp.Expression]: - return self._render_statements( - self.pre_statements, - start=start, - end=end, - execution_time=execution_time, - snapshots=snapshots, - expand=expand, - deployability_index=deployability_index, - engine_adapter=engine_adapter, - **kwargs, - ) - - def render_post_statements( - self, - *, - start: t.Optional[TimeLike] = None, - end: t.Optional[TimeLike] = None, - execution_time: t.Optional[TimeLike] = None, - snapshots: t.Optional[t.Collection[Snapshot]] = None, - expand: t.Iterable[str] = tuple(), - deployability_index: t.Optional[DeployabilityIndex] = None, - engine_adapter: t.Optional[EngineAdapter] = None, - **kwargs: t.Any, - ) -> t.List[exp.Expression]: - return self._render_statements( - self.post_statements, - start=start, - end=end, - execution_time=execution_time, - snapshots=snapshots, - expand=expand, - 
deployability_index=deployability_index, - engine_adapter=engine_adapter, - **kwargs, - ) - @property - def pre_statements(self) -> t.List[exp.Expression]: - return self.pre_statements_ or [] + def partitioned_by(self) -> t.List[exp.Expression]: + """Columns to partition the model by, including the time column if it is not already included.""" + if self.time_column and not self._is_time_column_in_partitioned_by: + # This allows the user to opt out of automatic time_column injection + # by setting `partition_by_time_column false` on the model kind + if ( + hasattr(self.kind, "partition_by_time_column") + and self.kind.partition_by_time_column + ): + return [ + TIME_COL_PARTITION_FUNC.get(self.dialect, lambda x, y: x)( + self.time_column.column, self.columns_to_types + ), + *self.partitioned_by_, + ] + return self.partitioned_by_ @property - def post_statements(self) -> t.List[exp.Expression]: - return self.post_statements_ or [] + def partition_interval_unit(self) -> t.Optional[IntervalUnit]: + """The interval unit to use for partitioning if applicable.""" + # Only return the interval unit for partitioning if the partitioning + # wasn't explicitly set by the user. Otherwise, the user-provided + # value should always take precedence. 
+ if self.time_column and not self._is_time_column_in_partitioned_by: + return self.interval_unit + return None @property - def macro_definitions(self) -> t.List[d.MacroDef]: - """All macro definitions from the list of expressions.""" - return [s for s in self.pre_statements + self.post_statements if isinstance(s, d.MacroDef)] + def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expression]]]: + from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS - @property - def inline_audits(self) -> t.Dict[str, ModelAudit]: - return self.inline_audits_ + audits_by_name = {**BUILT_IN_AUDITS, **self.audit_definitions} + audits_with_args = [] + added_audits = set() - def _render_statements( - self, - statements: t.Iterable[exp.Expression], - **kwargs: t.Any, - ) -> t.List[exp.Expression]: - rendered = ( - self._statement_renderer(statement).render(**kwargs) - for statement in statements - if not isinstance(statement, d.MacroDef) - ) - return [r for expressions in rendered if expressions for r in expressions] + for audit_name, audit_args in self.audits: + audits_with_args.append((audits_by_name[audit_name], audit_args.copy())) + added_audits.add(audit_name) - def _statement_renderer(self, expression: exp.Expression) -> ExpressionRenderer: - expression_key = id(expression) - if expression_key not in self.__statement_renderers: - self.__statement_renderers[expression_key] = ExpressionRenderer( - expression, - self.dialect, - self.macro_definitions, - path=self._path, - jinja_macro_registry=self.jinja_macros, - python_env=self.python_env, - only_execution_time=self.kind.only_execution_time, - default_catalog=self.default_catalog, - model_fqn=self.fqn, - ) - return self.__statement_renderers[expression_key] + for audit_name in self.audit_definitions: + if audit_name not in added_audits: + audits_with_args.append((audits_by_name[audit_name], {})) - @property - def _data_hash_values(self) -> t.List[str]: - data_hash_values = super()._data_hash_values - - for statement 
in (*self.pre_statements, *self.post_statements): - statement_exprs: t.List[exp.Expression] = [] - if not isinstance(statement, d.MacroDef): - rendered = self._statement_renderer(statement).render() - if self._is_metadata_statement(statement): - continue - if rendered: - statement_exprs = rendered - else: - statement_exprs = [statement] - data_hash_values.extend(gen(e) for e in statement_exprs) - - return data_hash_values + return audits_with_args @property - def _additional_metadata(self) -> t.List[str]: - additional_metadata = super()._additional_metadata - - for statement in (*self.pre_statements, *self.post_statements): - if self._is_metadata_statement(statement): - additional_metadata.append(gen(statement)) - - return additional_metadata + def _is_time_column_in_partitioned_by(self) -> bool: + return self.time_column is not None and self.time_column.column in { + col for expr in self.partitioned_by_ for col in expr.find_all(exp.Column) + } - def _is_metadata_statement(self, statement: exp.Expression) -> bool: - if isinstance(statement, d.MacroDef): - return True - if isinstance(statement, d.MacroFunc): - target_macro = macro.get_registry().get(statement.name) - if target_macro: - return target_macro.metadata_only - target_macro = self.python_env.get(statement.name) - if target_macro: - return bool(target_macro.is_metadata) - return False + @property + def violated_rules_for_query(self) -> t.Dict[type[Rule], t.Any]: + return {} -class SqlModel(_SqlBasedModel): +class SqlModel(_Model): """The model definition which relies on a SQL query to fetch the data. Args: query: The main query representing the model. pre_statements: The list of SQL statements that precede the model's query. post_statements: The list of SQL statements that follow after the model's query. + on_virtual_update: The list of SQL statements to be executed after the virtual update. 
""" - query: t.Union[exp.Query, d.JinjaQuery, d.MacroFunc] - source_type: Literal["sql"] = "sql" + query_: ParsableSql = Field(alias="query") + source_type: t.Literal["sql"] = "sql" _columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None + def __getstate__(self) -> t.Dict[t.Any, t.Any]: + state = super().__getstate__() + state["__dict__"] = state["__dict__"].copy() + # query renderer is very expensive to serialize + state["__dict__"].pop("_query_renderer", None) + state["__dict__"].pop("column_descriptions", None) + private = state[PRIVATE_FIELDS] + private["_columns_to_types"] = None + return state + + def copy(self, **kwargs: t.Any) -> Self: + model = super().copy(**kwargs) + model.__dict__.pop("_query_renderer", None) + model.__dict__.pop("column_descriptions", None) + model._columns_to_types = None + if kwargs.get("update", {}).keys() & {"depends_on_", "query"}: + model._full_depends_on = None + return model + + @property + def query(self) -> t.Union[exp.Query, d.JinjaQuery, d.MacroFunc]: + parsed_query = self.query_.parse(self.dialect) + return t.cast(t.Union[exp.Query, d.JinjaQuery, d.MacroFunc], parsed_query) + def render_query( self, *, @@ -1054,17 +1414,32 @@ def render_query( engine_adapter=engine_adapter, **kwargs, ) + return query def render_definition( - self, include_python: bool = True, include_defaults: bool = False + self, + include_python: bool = True, + include_defaults: bool = False, + render_query: bool = False, ) -> t.List[exp.Expression]: result = super().render_definition( include_python=include_python, include_defaults=include_defaults ) - result.extend(self.pre_statements) - result.append(self.query) - result.extend(self.post_statements) + + if render_query: + result.extend(self.render_pre_statements()) + result.append(self.render_query() or self.query) + result.extend(self.render_post_statements()) + if virtual_update := self.render_on_virtual_update(): + result.append(d.VirtualUpdateStatement(expressions=virtual_update)) + else: 
+ result.extend(self.pre_statements) + result.append(self.query) + result.extend(self.post_statements) + if self.on_virtual_update: + result.append(d.VirtualUpdateStatement(expressions=self.on_virtual_update)) + return result @property @@ -1076,15 +1451,31 @@ def columns_to_types(self) -> t.Optional[t.Dict[str, exp.DataType]]: if self.columns_to_types_ is not None: self._columns_to_types = self.columns_to_types_ elif self._columns_to_types is None: - query = self._query_renderer.render() + try: + query = self._query_renderer.render() + except Exception: + logger.exception("Failed to render query for model %s", self.fqn) + return None if query is None: return None - self._columns_to_types = { - select.output_name: select.type or exp.DataType.build("unknown") - for select in query.selects - } + unknown = exp.DataType.build("unknown") + + columns_to_types = {} + for select in query.selects: + output_name = select.output_name + + # If model validation is disabled, we cannot assume that projections + # will have inferrable output names or even that they will be unique + if not output_name or output_name in columns_to_types: + return None + + # copy data type because it is used in the engine to build CTAS and other queries + # this can change the parent which will mess up the diffing algo + columns_to_types[output_name] = (select.type or unknown).copy() + + self._columns_to_types = columns_to_types if "*" in self._columns_to_types: return None @@ -1106,11 +1497,15 @@ def column_descriptions(self) -> t.Dict[str, str]: if select.comments } - def update_schema( - self, - schema: MappingSchema, - ) -> None: + def set_mapping_schema(self, schema: t.Dict) -> None: + super().set_mapping_schema(schema) + self._on_mapping_schema_set() + + def update_schema(self, schema: MappingSchema) -> None: super().update_schema(schema) + self._on_mapping_schema_set() + + def _on_mapping_schema_set(self) -> None: self._columns_to_types = None 
self._query_renderer.update_schema(self.mapping_schema) @@ -1131,22 +1526,6 @@ def validate_definition(self) -> None: if not projection_list: raise_config_error("Query missing select statements", self._path) - name_counts: t.Dict[str, int] = {} - for expression in projection_list: - alias = expression.output_name - if alias == "*": - continue - if not alias: - raise_config_error( - f"Outer projection '{expression}' must have inferrable names or explicit aliases.", - self._path, - ) - name_counts[alias] = name_counts.get(alias, 0) + 1 - - for name, count in name_counts.items(): - if count > 1: - raise_config_error(f"Found duplicate outer select name '{name}'", self._path) - if self.depends_on_self and not self.annotated: raise_config_error( "Self-referencing models require inferrable column types. There are three options available to mitigate this issue: add explicit types to all projections in the outermost SELECT statement, leverage external models (https://sqlmesh.readthedocs.io/en/stable/concepts/models/external_models/), or use the `columns` model attribute (https://sqlmesh.readthedocs.io/en/stable/concepts/models/overview/#columns).", @@ -1173,29 +1552,56 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]: # Can't determine if there's a breaking change if we can't render the query. 
return None - edits = diff(previous_query, this_query, matchings=[(previous_query, this_query)]) + if previous_query is this_query: + edits = [] + else: + edits = diff( + previous_query, + this_query, + matchings=[(previous_query, this_query)], + delta_only=True, + dialect=self.dialect if self.dialect == previous.dialect else None, + ) inserted_expressions = {e.expression for e in edits if isinstance(e, Insert)} for edit in edits: - if isinstance(edit, Insert): - expr = edit.expression - if _is_udtf(expr): - # projection subqueries do not change cardinality, engines don't allow these to return - # more than one row of data - parent = expr.find_ancestor(exp.Subquery) - - if not parent: - return None + if not isinstance(edit, Insert): + return None - expr = parent + expr = edit.expression + if isinstance(expr, exp.UDTF): + # projection subqueries do not change cardinality, engines don't allow these to return + # more than one row of data + parent = expr.find_ancestor(exp.Subquery) - if not _is_projection(expr) and expr.parent not in inserted_expressions: + if not parent: return None - elif not isinstance(edit, Keep): + + expr = parent + + if not _is_projection(expr) and expr.parent not in inserted_expressions: return None return False + def is_metadata_only_change(self, previous: _Node) -> bool: + if self._is_metadata_only_change_cache.get(id(previous), None) is not None: + return self._is_metadata_only_change_cache[id(previous)] + + if not super().is_metadata_only_change(previous): + return False + + if not isinstance(previous, SqlModel): + self._is_metadata_only_change_cache[id(previous)] = False + return False + + this_rendered_query = self.render_query() or self.query + previous_rendered_query = previous.render_query() or previous.query + is_metadata_change = this_rendered_query == previous_rendered_query + + self._is_metadata_only_change_cache[id(previous)] = is_metadata_change + return is_metadata_change + @cached_property def _query_renderer(self) -> 
QueryRenderer: no_quote_identifiers = self.kind.is_view and self.dialect in ("trino", "spark") @@ -1204,26 +1610,41 @@ def _query_renderer(self) -> QueryRenderer: self.dialect, self.macro_definitions, schema=self.mapping_schema, - model_fqn=self.fqn, path=self._path, jinja_macro_registry=self.jinja_macros, python_env=self.python_env, only_execution_time=self.kind.only_execution_time, default_catalog=self.default_catalog, quote_identifiers=not no_quote_identifiers, + optimize_query=self.optimize_query, + model=self, ) @property - def _data_hash_values(self) -> t.List[str]: - data = super()._data_hash_values + def _data_hash_values_no_sql(self) -> t.List[str]: + return [ + *super()._data_hash_values_no_sql, + *self.jinja_macros.data_hash_values, + ] - query = self.render_query() or self.query - data.append(gen(query)) - data.extend(self.jinja_macros.data_hash_values) - return data + @property + def _data_hash_values_sql(self) -> t.List[str]: + return [ + *super()._data_hash_values_sql, + self.query_.sql, + ] + + @property + def _additional_metadata(self) -> t.List[str]: + return [*super()._additional_metadata, self.query_.sql] + + @property + def violated_rules_for_query(self) -> t.Dict[type[Rule], t.Any]: + self.render_query() + return self._query_renderer._violated_rules -class SeedModel(_SqlBasedModel): +class SeedModel(_Model): """The model definition which uses a pre-built static dataset to source the data from. 
Args: @@ -1235,7 +1656,18 @@ class SeedModel(_SqlBasedModel): column_hashes_: t.Optional[t.Dict[str, str]] = Field(default=None, alias="column_hashes") derived_columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None is_hydrated: bool = True - source_type: Literal["seed"] = "seed" + source_type: t.Literal["seed"] = "seed" + + def __getstate__(self) -> t.Dict[t.Any, t.Any]: + state = super().__getstate__() + state["__dict__"] = state["__dict__"].copy() + state["__dict__"].pop("_reader", None) + return state + + def copy(self, **kwargs: t.Any) -> Self: + model = super().copy(**kwargs) + model.__dict__.pop("_reader", None) + return model def render( self, @@ -1251,6 +1683,8 @@ def render( yield from self.render_seed() def render_seed(self) -> t.Iterator[QueryOrDF]: + import numpy as np + self._ensure_hydrated() date_columns = [] @@ -1258,7 +1692,9 @@ def render_seed(self) -> t.Iterator[QueryOrDF]: bool_columns = [] string_columns = [] - for name, tpe in (self.columns_to_types_ or {}).items(): + columns_to_types = self.columns_to_types_ or {} + column_names_to_check = set(columns_to_types) + for name, tpe in columns_to_types.items(): if tpe.this in (exp.DataType.Type.DATE, exp.DataType.Type.DATE32): date_columns.append(name) elif tpe.this in exp.DataType.TEMPORAL_TYPES: @@ -1269,15 +1705,44 @@ def render_seed(self) -> t.Iterator[QueryOrDF]: string_columns.append(name) for df in self._reader.read(batch_size=self.kind.batch_size): + rename_dict = {} + for column in columns_to_types: + if column not in df: + normalized_name = normalize_identifiers(column, dialect=self.dialect).name + if normalized_name in df: + rename_dict[normalized_name] = column + if rename_dict: + df.rename(columns=rename_dict, inplace=True) + # These names have already been checked + column_names_to_check -= set(rename_dict) + + missing_columns = column_names_to_check - set(df.columns) + if missing_columns: + raise_config_error( + f"Seed model '{self.name}' has missing columns: 
{missing_columns}", self._path + ) + # convert all date/time types to native pandas timestamp for column in [*date_columns, *datetime_columns]: - df[column] = pd.to_datetime(df[column]) + import pandas as pd + + df[column] = pd.to_datetime(df[column], infer_datetime_format=True, errors="ignore") # type: ignore # extract datetime.date from pandas timestamp for DATE columns for column in date_columns: - df[column] = df[column].dt.date + try: + df[column] = df[column].dt.date + except Exception as ex: + logger.error( + "Failed to convert column '%s' to date in seed model '%s': %s", + column, + self.name, + ex, + ) + + for column in bool_columns: + df[column] = df[column].apply(lambda i: str_to_bool(str(i))) - df[bool_columns] = df[bool_columns].apply(lambda i: str_to_bool(str(i))) df.loc[:, string_columns] = df[string_columns].mask( cond=lambda x: x.notna(), # type: ignore other=df[string_columns].astype(str), # type: ignore @@ -1309,10 +1774,12 @@ def is_seed(self) -> bool: def seed_path(self) -> Path: seed_path = Path(self.kind.path) if not seed_path.is_absolute(): + if self._path is None: + raise SQLMeshError(f"Seed model '{self.name}' has no path") return self._path.parent / seed_path return seed_path - @cached_property + @property def depends_on(self) -> t.Set[str]: return (self.depends_on_ or set()) - {self.fqn} @@ -1393,11 +1860,17 @@ def _reader(self) -> CsvSeedReader: return self.seed.reader(dialect=self.dialect, settings=self.kind.csv_settings) @property - def _data_hash_values(self) -> t.List[str]: - data = super()._data_hash_values + def _data_hash_values_no_sql(self) -> t.List[str]: + data = super()._data_hash_values_no_sql for column_name, column_hash in self.column_hashes.items(): data.append(column_name) data.append(column_hash) + + # Include grants in data hash for seed models to force recreation on grant changes + # since seed models don't support migration + data.append(json.dumps(self.grants, sort_keys=True) if self.grants else "") + 
data.append(self.grants_target_layer) + return data @@ -1408,15 +1881,17 @@ class PythonModel(_Model): entrypoint: The name of a Python function which contains the data fetching / transformation logic. """ + kind: ModelKind = FullKind() entrypoint: str - source_type: Literal["python"] = "python" + source_type: t.Literal["python"] = "python" def validate_definition(self) -> None: super().validate_definition() if self.kind and not self.kind.supports_python_models: - raise SQLMeshError( - f"Cannot create Python model '{self.name}' as the '{self.kind.name}' kind doesnt support Python models" + raise_config_error( + f"Cannot create Python model '{self.name}' as the '{self.kind.name}' kind doesn't support Python models", + self._path, ) def render( @@ -1429,12 +1904,21 @@ def render( **kwargs: t.Any, ) -> t.Iterator[QueryOrDF]: env = prepare_env(self.python_env) - start, end = make_inclusive(start or c.EPOCH, end or c.EPOCH) + start, end = make_inclusive(start or c.EPOCH, end or c.EPOCH, self.dialect) execution_time = to_datetime(execution_time or c.EPOCH) - variables = env.get(c.SQLMESH_VARS, {}) - variables.update(kwargs.pop("variables", {})) - + variables = { + **env.get(c.SQLMESH_VARS, {}), + **env.get(c.SQLMESH_VARS_METADATA, {}), + **kwargs.pop("variables", {}), + } + blueprint_variables = { + k: d.parse_one(v.sql, dialect=self.dialect) if isinstance(v, SqlValue) else v + for k, v in { + **env.get(c.SQLMESH_BLUEPRINT_VARS, {}), + **env.get(c.SQLMESH_BLUEPRINT_VARS_METADATA, {}), + }.items() + } try: kwargs = { **variables, @@ -1445,7 +1929,7 @@ def render( "latest": execution_time, # TODO: Preserved for backward compatibility. Remove in 1.0.0. 
} df_or_iter = env[self.entrypoint]( - context=context.with_variables(variables), + context=context.with_variables(variables, blueprint_variables=blueprint_variables), **kwargs, ) @@ -1455,15 +1939,19 @@ def render( for df in df_or_iter: yield df except Exception as e: - print_exception(e, self.python_env) - raise SQLMeshError(f"Error executing Python model '{self.name}'") + raise PythonModelEvalError(format_evaluated_code_exception(e, self.python_env)) def render_definition( - self, include_python: bool = True, include_defaults: bool = False + self, + include_python: bool = True, + include_defaults: bool = False, + render_query: bool = False, ) -> t.List[exp.Expression]: # Ignore the provided value for the include_python flag, since the Pyhon model's # definition without Python code is meaningless. - return super().render_definition(include_python=True, include_defaults=include_defaults) + return super().render_definition( + include_python=True, include_defaults=include_defaults, render_query=render_query + ) @property def is_python(self) -> bool: @@ -1473,8 +1961,8 @@ def is_breaking_change(self, previous: Model) -> t.Optional[bool]: return None @property - def _data_hash_values(self) -> t.List[str]: - data = super()._data_hash_values + def _data_hash_values_no_sql(self) -> t.List[str]: + data = super()._data_hash_values_no_sql data.append(self.entrypoint) return data @@ -1482,8 +1970,8 @@ def _data_hash_values(self) -> t.List[str]: class ExternalModel(_Model): """The model definition which represents an external source/table.""" - source_type: Literal["external"] = "external" - gateway: t.Optional[str] = None + kind: ModelKind = ExternalKind() + source_type: t.Literal["external"] = "external" def is_breaking_change(self, previous: Model) -> t.Optional[bool]: if not isinstance(previous, ExternalModel): @@ -1504,21 +1992,190 @@ def depends_on_self(self) -> bool: Model = t.Union[SqlModel, SeedModel, PythonModel, ExternalModel] +class AuditResult(PydanticModel): + 
audit: Audit + """The audit this result is for.""" + audit_args: t.Dict[t.Any, t.Any] + """Arguments passed to the audit.""" + model: t.Optional[_Model] = None + """The model this audit is for.""" + count: t.Optional[int] = None + """The number of records returned by the audit query. This could be None if the audit was skipped.""" + query: t.Optional[exp.Expression] = None + """The rendered query used by the audit. This could be None if the audit was skipped.""" + skipped: bool = False + """Whether or not the audit was blocking. This can be overriden by the user.""" + blocking: bool = True + + +class EvaluatableSignals(PydanticModel): + signals_to_kwargs: t.Dict[str, t.Dict[str, t.Optional[exp.Expression]]] + """A mapping of signal names to the kwargs passed to the signal.""" + python_env: t.Dict[str, Executable] + """The Python environment that should be used to evaluated the rendered signal calls.""" + prepared_python_env: t.Dict[str, t.Any] + """The prepared Python environment that should be used to evaluated the rendered signal calls.""" + + +def _extract_blueprints(blueprints: t.Any, path: Path) -> t.List[t.Any]: + if not blueprints: + return [None] + if isinstance(blueprints, exp.Paren): + return [blueprints.unnest()] + if isinstance(blueprints, (exp.Tuple, exp.Array)): + return blueprints.expressions + if isinstance(blueprints, list): + return blueprints + + raise_config_error( + "Expected a list or tuple consisting of key-value mappings for " + f"the 'blueprints' property, got '{blueprints}' instead", + path, + ) + return [] # This is unreachable, but is done to satisfy mypy + + +def _extract_blueprint_variables(blueprint: t.Any, path: Path) -> t.Dict[str, t.Any]: + if not blueprint: + return {} + if isinstance(blueprint, (exp.Paren, exp.PropertyEQ)): + blueprint = blueprint.unnest() + return {blueprint.left.name.lower(): blueprint.right} + if isinstance(blueprint, (exp.Tuple, exp.Array)): + return {e.left.name.lower(): e.right for e in 
blueprint.expressions} + if isinstance(blueprint, dict): + return {k.lower(): v for k, v in blueprint.items()} + + raise_config_error( + f"Expected a key-value mapping for the blueprint value, got '{blueprint}' instead", + path, + ) + return {} # This is unreachable, but is done to satisfy mypy + + +def create_models_from_blueprints( + gateway: t.Optional[str | exp.Expression], + blueprints: t.Any, + get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]], + loader: t.Callable[..., Model], + path: Path = Path(), + module_path: Path = Path(), + dialect: DialectType = None, + default_catalog_per_gateway: t.Optional[t.Dict[str, str]] = None, + **loader_kwargs: t.Any, +) -> t.List[Model]: + model_blueprints: t.List[Model] = [] + for blueprint in _extract_blueprints(blueprints, path): + blueprint_variables = _extract_blueprint_variables(blueprint, path) + + if gateway: + rendered_gateway = render_expression( + expression=exp.maybe_parse(gateway, dialect=dialect), + module_path=module_path, + macros=loader_kwargs.get("macros"), + jinja_macros=loader_kwargs.get("jinja_macros"), + path=path, + dialect=dialect, + default_catalog=loader_kwargs.get("default_catalog"), + blueprint_variables=blueprint_variables, + ) + gateway_name = rendered_gateway[0].name if rendered_gateway else None + else: + gateway_name = None + + if ( + default_catalog_per_gateway + and gateway_name + and (catalog := default_catalog_per_gateway.get(gateway_name)) is not None + ): + loader_kwargs["default_catalog"] = catalog + + model_blueprints.append( + loader( + path=path, + module_path=module_path, + dialect=dialect, + variables=get_variables(gateway_name), + blueprint_variables=blueprint_variables, + **loader_kwargs, + ) + ) + + return model_blueprints + + +def load_sql_based_models( + expressions: t.List[exp.Expression], + get_variables: t.Callable[[t.Optional[str]], t.Dict[str, str]], + path: Path = Path(), + module_path: Path = Path(), + dialect: DialectType = None, + 
default_catalog_per_gateway: t.Optional[t.Dict[str, str]] = None, + **loader_kwargs: t.Any, +) -> t.List[Model]: + gateway: t.Optional[exp.Expression] = None + blueprints: t.Optional[exp.Expression] = None + + model_meta = seq_get(expressions, 0) + for prop in (isinstance(model_meta, d.Model) and model_meta.expressions) or []: + if prop.name == "gateway": + gateway = prop.args["value"] + elif prop.name == "blueprints": + # We pop the `blueprints` here to avoid walking large lists when rendering the meta + blueprints = prop.pop().args["value"] + + if isinstance(blueprints, d.MacroFunc): + rendered_blueprints = render_expression( + expression=blueprints, + module_path=module_path, + macros=loader_kwargs.get("macros"), + jinja_macros=loader_kwargs.get("jinja_macros"), + variables=get_variables(None), + path=path, + dialect=dialect, + default_catalog=loader_kwargs.get("default_catalog"), + ) + if not rendered_blueprints: + raise_config_error("Failed to render blueprints property", path) + + # Help mypy see that rendered_blueprints can't be None + assert rendered_blueprints + + if len(rendered_blueprints) > 1: + rendered_blueprints = [exp.Tuple(expressions=rendered_blueprints)] + + blueprints = rendered_blueprints[0] + + return create_models_from_blueprints( + gateway=gateway, + blueprints=blueprints, + get_variables=get_variables, + loader=partial(load_sql_based_model, expressions), + path=path, + module_path=module_path, + dialect=dialect, + default_catalog_per_gateway=default_catalog_per_gateway, + **loader_kwargs, + ) + + def load_sql_based_model( expressions: t.List[exp.Expression], *, defaults: t.Optional[t.Dict[str, t.Any]] = None, - path: Path = Path(), + path: t.Optional[Path] = None, module_path: Path = Path(), time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT, macros: t.Optional[MacroRegistry] = None, jinja_macros: t.Optional[JinjaMacroRegistry] = None, + audits: t.Optional[t.Dict[str, ModelAudit]] = None, python_env: t.Optional[t.Dict[str, Executable]] 
= None, dialect: t.Optional[str] = None, - physical_schema_override: t.Optional[t.Dict[str, str]] = None, + physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None, default_catalog: t.Optional[str] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, infer_names: t.Optional[bool] = False, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, **kwargs: t.Any, ) -> Model: """Load a model from a parsed SQLMesh model SQL file. @@ -1535,58 +2192,98 @@ def load_sql_based_model( from the macro registry. dialect: The default dialect if no model dialect is configured. The format must adhere to Python's strftime codes. - physical_schema_override: The physical schema override for the model. + physical_schema_mapping: A mapping of regular expressions to match against the model schema to produce the corresponding physical schema default_catalog: The default catalog if no model catalog is configured. variables: The variables to pass to the model. kwargs: Additional kwargs to pass to the loader. """ + missing_model_msg = f"""Please add a MODEL block at the top of the file. Example: + +MODEL ( + name sqlmesh_example.full_model, --model name + kind FULL, --materialization + cron '@daily', --schedule +); + +Learn more at https://sqlmesh.readthedocs.io/en/stable/concepts/models/overview +""" + if not expressions: - raise_config_error("Incomplete model definition, missing MODEL statement", path) + raise_config_error(missing_model_msg) dialect = dialect or "" meta = expressions[0] if not isinstance(meta, d.Model): - raise_config_error( - "MODEL statement is required as the first statement in the definition", - path, - ) + if not infer_names: + raise_config_error(missing_model_msg) + meta = d.Model(expressions=[]) # Dummy meta node + expressions.insert(0, meta) + + # We deliberately hold off rendering some properties at load time because there is not enough information available + # at load time to render them. 
They will get rendered later at evaluation time + unrendered_properties = {} + unrendered_merge_filter = None - unrendered_signals = None for prop in meta.expressions: - if prop.name.lower() == "signals": - unrendered_signals = prop.args.get("value") + # Macro functions that programmaticaly generate the key-value pair properties should be rendered + # This is needed in the odd case where a macro shares the name of one of the properties + # eg `@session_properties()` Test: `test_macros_in_model_statement` Reference PR: #2574 + if isinstance(prop, d.MacroFunc): + continue - meta_python_env = _python_env( - expressions=meta, - jinja_macro_references=None, + prop_name = prop.name.lower() + if prop_name in {"signals", "audits"} | PROPERTIES: + unrendered_properties[prop_name] = prop.args.get("value") + elif ( + prop.name.lower() == "kind" + and (value := prop.args.get("value")) + and value.name.lower() == "incremental_by_unique_key" + ): + for kind_prop in value.expressions: + if kind_prop.name.lower() == "merge_filter": + unrendered_merge_filter = kind_prop + + rendered_meta_exprs = render_expression( + expression=meta, module_path=module_path, - macros=macros or macro.get_registry(), + macros=macros, + jinja_macros=jinja_macros, variables=variables, path=path, - ) - meta_renderer = ExpressionRenderer( - meta, - dialect, - [], - path=path, - jinja_macro_registry=jinja_macros, - python_env=meta_python_env, + dialect=dialect, default_catalog=default_catalog, - quote_identifiers=False, - normalize_identifiers=False, + blueprint_variables=blueprint_variables, ) - rendered_meta_exprs = meta_renderer.render() + if rendered_meta_exprs is None or len(rendered_meta_exprs) != 1: raise_config_error( f"Invalid MODEL statement:\n{meta.sql(dialect=dialect, pretty=True)}", path, ) raise + rendered_meta = rendered_meta_exprs[0] + rendered_defaults = ( + render_model_defaults( + defaults=defaults, + module_path=module_path, + macros=macros, + jinja_macros=jinja_macros, + 
variables=variables, + path=path, + dialect=dialect, + default_catalog=default_catalog, + ) + if defaults + else {} + ) + + rendered_defaults = parse_defaults_properties(rendered_defaults, dialect=dialect) + # Extract the query and any pre/post statements - query_or_seed_insert, pre_statements, post_statements, inline_audits = ( - _split_sql_model_statements(expressions[1:], path, dialect) + query_or_seed_insert, pre_statements, post_statements, on_virtual_update, inline_audits = ( + _split_sql_model_statements(expressions[1:], path, dialect=dialect) ) meta_fields: t.Dict[str, t.Any] = { @@ -1599,9 +2296,15 @@ def load_sql_based_model( **{prop.name.lower(): prop.args.get("value") for prop in rendered_meta.expressions}, **kwargs, } - if unrendered_signals: - # Signals must remain unrendered, so that they can be rendered later at evaluation runtime. - meta_fields["signals"] = unrendered_signals + + # Discard the potentially half-rendered versions of these properties and replace them with the + # original unrendered versions. 
They will get rendered properly at evaluation time + meta_fields.update(unrendered_properties) + + if unrendered_merge_filter: + for idx, kind_prop in enumerate(meta_fields["kind"].expressions): + if kind_prop.name.lower() == "merge_filter": + meta_fields["kind"].expressions[idx] = unrendered_merge_filter if isinstance(meta_fields.get("dialect"), exp.Expression): meta_fields["dialect"] = meta_fields["dialect"].name @@ -1609,92 +2312,62 @@ def load_sql_based_model( # The name of the model will be inferred from its path relative to `models/`, if it's not explicitly specified name = meta_fields.pop("name", "") if not name and infer_names: + if path is None: + raise ValueError(f"Model {name} must have a name") name = get_model_name(path) if not name: - raise_config_error("Model must have a name", path) + raise_config_error( + "Please add the required 'name' field to the MODEL block at the top of the file.\n\n" + + "Learn more at https://sqlmesh.readthedocs.io/en/stable/concepts/models/overview" + ) if "default_catalog" in meta_fields: raise_config_error( - "`default_catalog` cannot be set on a per-model basis. It must be set at the connection level or in Airflow.", + "`default_catalog` cannot be set on a per-model basis. 
It must be set at the connection level.", path, ) - jinja_macro_references, used_variables = extract_macro_references_and_variables( - *(gen(e) for e in pre_statements), - *(gen(e) for e in post_statements), - *([gen(query_or_seed_insert)] if query_or_seed_insert is not None else []), - ) - - jinja_macros = (jinja_macros or JinjaMacroRegistry()).trim(jinja_macro_references) - for jinja_macro in jinja_macros.root_macros.values(): - used_variables.update(extract_macro_references_and_variables(jinja_macro.definition)[1]) - common_kwargs = dict( pre_statements=pre_statements, post_statements=post_statements, - defaults=defaults, + on_virtual_update=on_virtual_update, + defaults=rendered_defaults, path=path, module_path=module_path, macros=macros, python_env=python_env, jinja_macros=jinja_macros, - jinja_macro_references=jinja_macro_references, - physical_schema_override=physical_schema_override, + physical_schema_mapping=physical_schema_mapping, default_catalog=default_catalog, variables=variables, - used_variables=used_variables, inline_audits=inline_audits, + blueprint_variables=blueprint_variables, + use_original_sql=True, **meta_fields, ) - if query_or_seed_insert is not None and ( - isinstance(query_or_seed_insert, (exp.Query, d.JinjaQuery)) - or ( - isinstance(query_or_seed_insert, d.MacroFunc) - and query_or_seed_insert.this.name.lower() == "union" - ) - ): + kind = common_kwargs.pop("kind", ModelMeta.all_field_infos()["kind"].default) + + if kind.name != ModelKindName.SEED: return create_sql_model( name, query_or_seed_insert, + kind=kind, time_column_format=time_column_format, **common_kwargs, ) - else: - try: - seed_properties = { - p.name.lower(): p.args.get("value") for p in common_kwargs.pop("kind").expressions - } - return create_seed_model( - name, - SeedKind(**seed_properties), - **common_kwargs, - ) - except Exception as ex: - raise_config_error( - f"The model definition must either have a SELECT query, a JINJA_QUERY block, or a valid Seed kind. {ex}." 
- ) - raise + + seed_properties = {p.name.lower(): p.args.get("value") for p in kind.expressions} + return create_seed_model( + name, + SeedKind(**seed_properties), + **common_kwargs, + ) def create_sql_model( name: TableName, - query: exp.Expression, - *, - pre_statements: t.Optional[t.List[exp.Expression]] = None, - post_statements: t.Optional[t.List[exp.Expression]] = None, - defaults: t.Optional[t.Dict[str, t.Any]] = None, - path: Path = Path(), - module_path: Path = Path(), - time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT, - macros: t.Optional[MacroRegistry] = None, - python_env: t.Optional[t.Dict[str, Executable]] = None, - jinja_macros: t.Optional[JinjaMacroRegistry] = None, - jinja_macro_references: t.Optional[t.Set[MacroReference]] = None, - dialect: t.Optional[str] = None, - physical_schema_override: t.Optional[t.Dict[str, str]] = None, - variables: t.Optional[t.Dict[str, t.Any]] = None, - used_variables: t.Optional[t.Set[str]] = None, + query: t.Optional[exp.Expression], **kwargs: t.Any, ) -> Model: """Creates a SQL model. @@ -1703,82 +2376,23 @@ def create_sql_model( name: The name of the model, which is of the form [catalog].[db].table. The catalog and db are optional. query: The model's logic in a form of a SELECT query. - pre_statements: The list of SQL statements that precede the model's query. - post_statements: The list of SQL statements that follow after the model's query. - defaults: Definition default values. - path: An optional path to the model definition file. - module_path: The python module path to serialize macros for. - time_column_format: The default time column format to use if no model time column is configured. - The format must adhere to Python's strftime codes. - macros: The custom registry of macros. If not provided the default registry will be used. - python_env: The custom Python environment for macros. If not provided the environment will be constructed - from the macro registry. 
- jinja_macros: The registry of Jinja macros. - jinja_macro_references: The set of Jinja macros referenced by this model. - dialect: The default dialect if no model dialect is configured. - physical_schema_override: The physical schema override. - variables: User-defined variables. - used_variables: The set of variable names used by this model. """ if not isinstance(query, (exp.Query, d.JinjaQuery, d.MacroFunc)): - # Users are not expected to pass in a single MacroFunc instance for a model's query; - # this is an implementation detail which allows us to create python models that return - # SQL, either in the form of SQLGlot expressions or just plain strings. raise_config_error( "A query is required and must be a SELECT statement, a UNION statement, or a JINJA_QUERY block", - path, + kwargs.get("path"), ) + assert isinstance(query, (exp.Query, d.JinjaQuery, d.MacroFunc)) - pre_statements = pre_statements or [] - post_statements = post_statements or [] - - if not python_env: - python_env = _python_env( - [*pre_statements, query, *post_statements], - jinja_macro_references, - module_path, - macros or macro.get_registry(), - variables=variables, - used_variables=used_variables, - path=path, - ) - else: - python_env = _add_variables_to_python_env(python_env, used_variables, variables) - - return _create_model( - SqlModel, - name, - defaults=defaults, - path=path, - time_column_format=time_column_format, - python_env=python_env, - jinja_macros=jinja_macros, - dialect=dialect, - query=query, - pre_statements=pre_statements, - post_statements=post_statements, - physical_schema_override=physical_schema_override, - **kwargs, - ) + return _create_model(SqlModel, name, query=query, **kwargs) def create_seed_model( name: TableName, seed_kind: SeedKind, *, - dialect: t.Optional[str] = None, - pre_statements: t.Optional[t.List[exp.Expression]] = None, - post_statements: t.Optional[t.List[exp.Expression]] = None, - defaults: t.Optional[t.Dict[str, t.Any]] = None, - path: Path = 
Path(), + path: t.Optional[Path] = None, module_path: Path = Path(), - macros: t.Optional[MacroRegistry] = None, - python_env: t.Optional[t.Dict[str, Executable]] = None, - jinja_macros: t.Optional[JinjaMacroRegistry] = None, - jinja_macro_references: t.Optional[t.Set[MacroReference]] = None, - physical_schema_override: t.Optional[t.Dict[str, str]] = None, - variables: t.Optional[t.Dict[str, t.Any]] = None, - used_variables: t.Optional[t.Set[str]] = None, **kwargs: t.Any, ) -> Model: """Creates a Seed model. @@ -1787,19 +2401,8 @@ def create_seed_model( name: The name of the model, which is of the form [catalog].[db].table. The catalog and db are optional. seed_kind: The information about the location of a seed and other related configuration. - dialect: The default dialect if no model dialect is configured. - pre_statements: The list of SQL statements that precede the insertion of the seed's content. - post_statements: The list of SQL statements that follow after the insertion of the seed's content. - defaults: Definition default values. path: An optional path to the model definition file. - macros: The custom registry of macros. If not provided the default registry will be used. - python_env: The custom Python environment for macros. If not provided the environment will be constructed from the macro registry. - jinja_macros: The registry of Jinja macros. - jinja_macro_references: The set of Jinja macros referenced by this model. - physical_schema_override: The physical schema override. - variables: User-defined variables. - used_variables: The set of variable names used by this model. 
""" seed_path = Path(seed_kind.path) marker, *subdirs = seed_path.parts @@ -1807,40 +2410,23 @@ def create_seed_model( seed_path = module_path.joinpath(*subdirs) seed_kind.path = str(seed_path) elif not seed_path.is_absolute(): - seed_path = path / seed_path if path.is_dir() else path.parent / seed_path + if path is None: + seed_path = seed_path + elif path.is_dir(): + seed_path = path / seed_path + else: + seed_path = path.parent / seed_path seed = create_seed(seed_path) - pre_statements = pre_statements or [] - post_statements = post_statements or [] - - if not python_env: - python_env = _python_env( - [*pre_statements, *post_statements], - jinja_macro_references, - module_path, - macros or macro.get_registry(), - variables=variables, - used_variables=used_variables, - path=path, - ) - else: - python_env = _add_variables_to_python_env(python_env, used_variables, variables) - return _create_model( SeedModel, name, - dialect=dialect, - defaults=defaults, path=path, seed=seed, kind=seed_kind, depends_on=kwargs.pop("depends_on", None), - python_env=python_env, - jinja_macros=jinja_macros, - pre_statements=pre_statements, - post_statements=post_statements, - physical_schema_override=physical_schema_override, + module_path=module_path, **kwargs, ) @@ -1850,12 +2436,13 @@ def create_python_model( entrypoint: str, python_env: t.Dict[str, Executable], *, - defaults: t.Optional[t.Dict[str, t.Any]] = None, + macros: t.Optional[MacroRegistry] = None, + jinja_macros: t.Optional[JinjaMacroRegistry] = None, path: Path = Path(), - time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT, + module_path: Path = Path(), depends_on: t.Optional[t.Set[str]] = None, - physical_schema_override: t.Optional[t.Dict[str, str]] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, **kwargs: t.Any, ) -> Model: """Creates a Python model. @@ -1865,33 +2452,65 @@ def create_python_model( The catalog and db are optional. 
entrypoint: The name of a Python function which contains the data fetching / transformation logic. python_env: The Python environment of all objects referenced by the model implementation. - defaults: Definition default values. path: An optional path to the model definition file. - time_column_format: The default time column format to use if no model time column is configured. depends_on: The custom set of model's upstream dependencies. + variables: The variables to pass to the model. + blueprint_variables: The blueprint's variables to pass to the model. """ # Find dependencies for python models by parsing code if they are not explicitly defined # Also remove self-references that are found + + dialect = kwargs.get("dialect") + + dependencies_unspecified = depends_on is None + parsed_depends_on, referenced_variables = ( - _parse_dependencies(python_env, entrypoint) if python_env is not None else (set(), set()) + parse_dependencies( + python_env, + entrypoint, + strict_resolution=dependencies_unspecified, + variables=variables, + blueprint_variables=blueprint_variables, + ) + if python_env is not None + else (set(), set()) ) - if depends_on is None: + if dependencies_unspecified: depends_on = parsed_depends_on - {name} + else: + depends_on_rendered = render_expression( + expression=exp.Array( + expressions=[exp.maybe_parse(dep, dialect=dialect) for dep in depends_on or []] + ), + module_path=module_path, + macros=macros, + jinja_macros=jinja_macros, + variables=variables, + path=path, + dialect=dialect, + default_catalog=kwargs.get("default_catalog"), + ) + depends_on = { + dep.sql(dialect=dialect) + for dep in t.cast(t.List[exp.Expression], depends_on_rendered)[0].expressions + } - variables = {k: v for k, v in (variables or {}).items() if k in referenced_variables} - if variables: - python_env[c.SQLMESH_VARS] = Executable.value(variables) + used_variables = {k: v for k, v in (variables or {}).items() if k in referenced_variables} + if used_variables: + 
python_env[c.SQLMESH_VARS] = Executable.value(used_variables, sort_root_dict=True) return _create_model( PythonModel, name, - defaults=defaults, path=path, - time_column_format=time_column_format, depends_on=depends_on, entrypoint=entrypoint, python_env=python_env, - physical_schema_override=physical_schema_override, + macros=macros, + jinja_macros=jinja_macros, + module_path=module_path, + variables=variables, + blueprint_variables=blueprint_variables, **kwargs, ) @@ -1931,48 +2550,189 @@ def _create_model( name: TableName, *, defaults: t.Optional[t.Dict[str, t.Any]] = None, - path: Path = Path(), + path: t.Optional[Path] = None, time_column_format: str = c.DEFAULT_TIME_COLUMN_FORMAT, jinja_macros: t.Optional[JinjaMacroRegistry] = None, jinja_macro_references: t.Optional[t.Set[MacroReference]] = None, depends_on: t.Optional[t.Set[str]] = None, dialect: t.Optional[str] = None, - physical_schema_override: t.Optional[t.Dict[str, str]] = None, + physical_schema_mapping: t.Optional[t.Dict[re.Pattern, str]] = None, + python_env: t.Optional[t.Dict[str, Executable]] = None, + audit_definitions: t.Optional[t.Dict[str, ModelAudit]] = None, + inline_audits: t.Optional[t.Dict[str, ModelAudit]] = None, + module_path: Path = Path(), + macros: t.Optional[MacroRegistry] = None, + signal_definitions: t.Optional[SignalRegistry] = None, + variables: t.Optional[t.Dict[str, t.Any]] = None, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, + use_original_sql: bool = False, **kwargs: t.Any, ) -> Model: - _validate_model_fields(klass, {"name", *kwargs} - {"grain", "table_properties"}, path) - - kwargs["session_properties"] = _resolve_session_properties( - (defaults or {}).get("session_properties"), kwargs.get("session_properties") + validate_extra_and_required_fields( + klass, + {"name", *kwargs} - {"grain", "table_properties"}, + "MODEL block", + path, ) + for prop in PROPERTIES: + kwargs[prop] = _resolve_properties((defaults or {}).get(prop), kwargs.get(prop)) + dialect = 
dialect or "" - physical_schema_override = physical_schema_override or {} + + physical_schema_mapping = physical_schema_mapping or {} + model_schema_name = exp.to_table(name, dialect=dialect).db + physical_schema_override: t.Optional[str] = None + + for re_pattern, override_schema in physical_schema_mapping.items(): + if re.match(re_pattern, model_schema_name): + physical_schema_override = override_schema + break raw_kind = kwargs.pop("kind", None) if raw_kind: kwargs["kind"] = create_model_kind(raw_kind, dialect, defaults or {}) defaults = {k: v for k, v in (defaults or {}).items() if k in klass.all_fields()} + if not issubclass(klass, SqlModel): + defaults.pop("optimize_query", None) - try: - model = klass( - name=name, - **{ - **(defaults or {}), - "jinja_macros": jinja_macros or JinjaMacroRegistry(), - "dialect": dialect, - "depends_on": depends_on, - "physical_schema_override": physical_schema_override.get( - exp.to_table(name, dialect=dialect).db - ), - **kwargs, - }, + statements: t.List[t.Union[exp.Expression, t.Tuple[exp.Expression, bool]]] = [] + + if "query" in kwargs: + statements.append(kwargs["query"]) + kwargs["query"] = ParsableSql.from_parsed_expression( + kwargs["query"], dialect, use_meta_sql=use_original_sql ) - except Exception as ex: - raise_config_error(str(ex), location=path) - raise + # Merge default statements with model-specific statements + for statement_field in ["pre_statements", "post_statements", "on_virtual_update"]: + if statement_field in defaults: + kwargs[statement_field] = [ + exp.maybe_parse(stmt, dialect=dialect) for stmt in defaults[statement_field] + ] + kwargs.get(statement_field, []) + if statement_field in kwargs: + # Macros extracted from these statements need to be treated as metadata only + is_metadata = statement_field == "on_virtual_update" + for stmt in kwargs[statement_field]: + # Extract the expression if it's ParsableSql already + expr = stmt.parse(dialect) if isinstance(stmt, ParsableSql) else stmt + 
statements.append((expr, is_metadata)) + kwargs[statement_field] = [ + # this to retain the transaction information + stmt + if isinstance(stmt, ParsableSql) + else ParsableSql.from_parsed_expression( + stmt, dialect, use_meta_sql=use_original_sql + ) + for stmt in kwargs[statement_field] + ] + + # This is done to allow variables like @gateway to be used in these properties + # since rendering shifted from load time to run time. + # Note: we check for Tuple since that's what we expect from _resolve_properties + for property_name in PROPERTIES: + property_values = kwargs.get(property_name) + if isinstance(property_values, exp.Tuple): + statements.extend(property_values.expressions) + + if isinstance(getattr(kwargs.get("kind"), "merge_filter", None), exp.Expression): + statements.append(kwargs["kind"].merge_filter) + + jinja_macro_references, referenced_variables = extract_macro_references_and_variables( + *(gen(e if isinstance(e, exp.Expression) else e[0]) for e in statements) + ) + + if jinja_macros: + jinja_macros = ( + jinja_macros if jinja_macros.trimmed else jinja_macros.trim(jinja_macro_references) + ) + else: + jinja_macros = JinjaMacroRegistry() + + for jinja_macro in jinja_macros.root_macros.values(): + referenced_variables.update( + extract_macro_references_and_variables(jinja_macro.definition)[1] + ) + + # Merge model-specific audits with default audits + if default_audits := defaults.pop("audits", None): + kwargs["audits"] = default_audits + d.extract_function_calls(kwargs.pop("audits", [])) + + model = klass( + name=name, + **{ + **(defaults or {}), + "jinja_macros": jinja_macros or JinjaMacroRegistry(), + "dialect": dialect, + "depends_on": depends_on, + "physical_schema_override": physical_schema_override, + **kwargs, + }, + ) + + audit_definitions = { + **(audit_definitions or {}), + **(inline_audits or {}), + } + + used_audits: t.Set[str] = {audit_name for audit_name, _ in model.audits} + + audit_definitions = { + audit_name: 
audit_definitions[audit_name] + for audit_name in used_audits + if audit_name in audit_definitions + } + + model.audit_definitions.update(audit_definitions) + + # Any macro referenced in audits or signals needs to be treated as metadata-only + statements.extend((audit.query, True) for audit in audit_definitions.values()) + + # Ensure that all audits referenced in the model are defined + from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS + + available_audits = BUILT_IN_AUDITS.keys() | model.audit_definitions.keys() + for referenced_audit, audit_args in model.audits: + if referenced_audit not in available_audits: + raise_config_error(f"Audit '{referenced_audit}' is undefined", location=path) + + statements.extend( + (audit_arg_expression, True) for audit_arg_expression in audit_args.values() + ) + + signal_definitions = signal_definitions or UniqueKeyDict("signals") + + for referenced_signal, kwargs in model.signals: + if referenced_signal and referenced_signal not in signal_definitions: + raise_config_error(f"Signal '{referenced_signal}' is undefined", location=path) + + statements.extend((signal_kwarg, True) for signal_kwarg in kwargs.values()) + + python_env = make_python_env( + statements, + jinja_macro_references, + module_path, + macros or macro.get_registry(), + variables=variables, + referenced_variables=referenced_variables, + path=path, + python_env=python_env, + strict_resolution=depends_on is None, + blueprint_variables=blueprint_variables, + dialect=dialect, + ) + + env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {} + + for signal_name, _ in model.signals: + if signal_name and signal_name in signal_definitions: + func = signal_definitions[signal_name].func + setattr(func, c.SQLMESH_METADATA, True) + build_env(func, env=env, name=signal_name, path=module_path) + + model.python_env.update(python_env) + model.python_env.update(serialize_env(env, path=module_path)) model._path = path model.set_time_format(time_column_format) @@ -1983,11 +2743,14 @@ 
def _create_model( def _split_sql_model_statements( - expressions: t.List[exp.Expression], path: Path, dialect: t.Optional[str] = None + expressions: t.List[exp.Expression], + path: t.Optional[Path], + dialect: t.Optional[str] = None, ) -> t.Tuple[ t.Optional[exp.Expression], t.List[exp.Expression], t.List[exp.Expression], + t.List[exp.Expression], UniqueKeyDict[str, ModelAudit], ]: """Extracts the SELECT query from a sequence of expressions. @@ -2006,6 +2769,7 @@ def _split_sql_model_statements( query_positions = [] sql_statements = [] + on_virtual_update = [] inline_audits: UniqueKeyDict[str, ModelAudit] = UniqueKeyDict("inline_audits") idx = 0 @@ -2018,204 +2782,58 @@ def _split_sql_model_statements( assert isinstance(loaded_audit, ModelAudit) inline_audits[loaded_audit.name] = loaded_audit idx += 2 + elif isinstance(expr, d.VirtualUpdateStatement): + for statement in expr.expressions: + on_virtual_update.append(statement) + idx += 1 else: if ( isinstance(expr, (exp.Query, d.JinjaQuery)) or expr == INSERT_SEED_MACRO_CALL - or (isinstance(expr, d.MacroFunc) and expr.this.name.lower() == "union") + or ( + isinstance(expr, d.MacroFunc) + and (expr.this.name.lower() == "union" or length == 1) + ) ): query_positions.append((expr, idx)) sql_statements.append(expr) idx += 1 if not query_positions: - return None, sql_statements, [], inline_audits + return None, sql_statements, [], on_virtual_update, inline_audits - elif len(query_positions) > 1: + if len(query_positions) > 1: raise_config_error("Only one SELECT query is allowed per model", path) query, pos = query_positions[0] - return query, sql_statements[:pos], sql_statements[pos + 1 :], inline_audits + return query, sql_statements[:pos], sql_statements[pos + 1 :], on_virtual_update, inline_audits -def _resolve_session_properties( +def _resolve_properties( default: t.Optional[t.Dict[str, t.Any]], provided: t.Optional[exp.Expression | t.Dict[str, t.Any]], ) -> t.Optional[exp.Expression]: if isinstance(provided, dict): 
- session_properties = {k: exp.Literal.string(k).eq(v) for k, v in provided.items()} + properties = {k: exp.Literal.string(k).eq(v) for k, v in provided.items()} elif provided: if isinstance(provided, exp.Paren): provided = exp.Tuple(expressions=[provided.this]) - session_properties = {expr.this.name: expr for expr in provided} + properties = {expr.this.name: expr for expr in provided} else: - session_properties = {} + properties = {} for k, v in (default or {}).items(): - if k not in session_properties: - session_properties[k] = exp.Literal.string(k).eq(v) + if k not in properties: + properties[k] = exp.Literal.string(k).eq(v) + elif properties[k].expression.sql().lower() in {"none", "null"}: + del properties[k] - if session_properties: - return exp.Tuple(expressions=list(session_properties.values())) + if properties: + return exp.Tuple(expressions=list(properties.values())) return None -def _validate_model_fields(klass: t.Type[_Model], provided_fields: t.Set[str], path: Path) -> None: - missing_required_fields = klass.missing_required_fields(provided_fields) - if missing_required_fields: - raise_config_error( - f"Missing required fields {missing_required_fields} in the model definition", - path, - ) - - extra_fields = klass.extra_fields(provided_fields) - if extra_fields: - raise_config_error(f"Invalid extra fields {extra_fields} in the model definition", path) - - -def _python_env( - expressions: t.Union[exp.Expression, t.List[exp.Expression]], - jinja_macro_references: t.Optional[t.Set[MacroReference]], - module_path: Path, - macros: MacroRegistry, - variables: t.Optional[t.Dict[str, t.Any]] = None, - used_variables: t.Optional[t.Set[str]] = None, - path: t.Optional[str | Path] = None, -) -> t.Dict[str, Executable]: - python_env: t.Dict[str, Executable] = {} - variables = variables or {} - - used_macros = {} - used_variables = (used_variables or set()).copy() - serialized_env = {} - - expressions = ensure_list(expressions) - for expression in expressions: - if 
not isinstance(expression, d.Jinja): - for macro_func_or_var in expression.find_all(d.MacroFunc, d.MacroVar, exp.Identifier): - if macro_func_or_var.__class__ is d.MacroFunc: - name = macro_func_or_var.this.name.lower() - if name in macros: - used_macros[name] = macros[name] - if name == c.VAR: - args = macro_func_or_var.this.expressions - if len(args) < 1: - raise_config_error("Macro VAR requires at least one argument", path) - if not args[0].is_string: - raise_config_error( - f"The variable name must be a string literal, '{args[0].sql()}' was given instead", - path, - ) - used_variables.add(args[0].this.lower()) - elif macro_func_or_var.__class__ is d.MacroVar: - name = macro_func_or_var.name.lower() - if name in macros: - used_macros[name] = macros[name] - elif name in variables: - used_variables.add(name) - elif ( - isinstance(macro_func_or_var, (exp.Identifier, d.MacroStrReplace, d.MacroSQL)) - ) and "@" in macro_func_or_var.name: - for _, identifier, braced_identifier, _ in MacroStrTemplate.pattern.findall( - macro_func_or_var.name - ): - var_name = braced_identifier or identifier - if var_name in variables: - used_variables.add(var_name) - - for macro_ref in jinja_macro_references or set(): - if macro_ref.package is None and macro_ref.name in macros: - used_macros[macro_ref.name] = macros[macro_ref.name] - - for name, used_macro in used_macros.items(): - if isinstance(used_macro, Executable): - serialized_env[name] = used_macro - elif not hasattr(used_macro, c.SQLMESH_BUILTIN): - build_env(used_macro.func, env=python_env, name=name, path=module_path) - - serialized_env.update(serialize_env(python_env, path=module_path)) - return _add_variables_to_python_env(serialized_env, used_variables, variables) - - -def _add_variables_to_python_env( - python_env: t.Dict[str, Executable], - used_variables: t.Optional[t.Set[str]], - variables: t.Optional[t.Dict[str, t.Any]], -) -> t.Dict[str, Executable]: - _, python_used_variables = _parse_dependencies(python_env, None) 
- used_variables = (used_variables or set()) | python_used_variables - - variables = {k: v for k, v in (variables or {}).items() if k in used_variables} - if variables: - python_env[c.SQLMESH_VARS] = Executable.value(variables) - - return python_env - - -def _parse_dependencies( - python_env: t.Dict[str, Executable], entrypoint: t.Optional[str] -) -> t.Tuple[t.Set[str], t.Set[str]]: - """Parses the source of a model function and finds upstream table dependencies and referenced variables based on calls to context / evaluator. - - Args: - python_env: A dictionary of Python definitions. - - Returns: - A tuple containing the set of upstream table dependencies and the set of referenced variables. - """ - env = prepare_env(python_env) - depends_on = set() - variables = set() - - for executable in python_env.values(): - if not executable.is_definition: - continue - for node in ast.walk(ast.parse(executable.payload)): - if isinstance(node, ast.Call): - func = node.func - if not isinstance(func, ast.Attribute) or not isinstance(func.value, ast.Name): - continue - - def get_first_arg(keyword_arg_name: str) -> t.Any: - if node.args: - first_arg: t.Optional[ast.expr] = node.args[0] - else: - first_arg = next( - ( - keyword.value - for keyword in node.keywords - if keyword.arg == keyword_arg_name - ), - None, - ) - - try: - expression = to_source(first_arg) - return eval(expression, env) - except Exception: - raise ConfigError( - f"Error resolving dependencies for '{executable.path}'. Argument '{expression.strip()}' must be resolvable at parse time." 
- ) - - if func.value.id == "context" and func.attr == "table": - depends_on.add(get_first_arg("model_name")) - elif func.value.id in ("context", "evaluator") and func.attr == c.VAR: - variables.add(get_first_arg("var_name").lower()) - elif ( - isinstance(node, ast.Attribute) - and isinstance(node.value, ast.Name) - and node.value.id in ("context", "evaluator") - and node.attr == c.GATEWAY - ): - # Check whether the gateway attribute is referenced. - variables.add(c.GATEWAY) - elif isinstance(node, ast.FunctionDef) and node.name == entrypoint: - variables.update([arg.arg for arg in node.args.args if arg.arg != "context"]) - - return depends_on, variables - - def _list_of_calls_to_exp(value: t.List[t.Tuple[str, t.Dict[str, t.Any]]]) -> exp.Expression: return exp.Tuple( expressions=[ @@ -2236,34 +2854,201 @@ def _is_projection(expr: exp.Expression) -> bool: return isinstance(parent, exp.Select) and expr.arg_key == "expressions" -def _is_udtf(expr: exp.Expression) -> bool: - return isinstance(expr, (exp.Explode, exp.Posexplode, exp.Unnest)) or ( - isinstance(expr, exp.Anonymous) - and expr.this.upper() in ("EXPLODE_OUTER", "POSEXPLODE_OUTER", "UNNEST") - ) +def _single_expr_or_tuple(values: t.Sequence[exp.Expression]) -> exp.Expression | exp.Tuple: + return values[0] if len(values) == 1 else exp.Tuple(expressions=values) -def _single_value_or_tuple(values: t.Sequence) -> exp.Identifier | exp.Tuple: - return ( - exp.to_identifier(values[0]) - if len(values) == 1 - else exp.Tuple(expressions=[exp.to_identifier(v) for v in values]) +def _refs_to_sql(values: t.Any) -> exp.Expression: + return exp.Tuple(expressions=values) + + +def render_meta_fields( + fields: t.Dict[str, t.Any], + module_path: Path, + path: t.Optional[Path], + jinja_macros: t.Optional[JinjaMacroRegistry], + macros: t.Optional[MacroRegistry], + dialect: DialectType, + variables: t.Optional[t.Dict[str, t.Any]], + default_catalog: t.Optional[str], + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = 
None, +) -> t.Dict[str, t.Any]: + def render_field_value(value: t.Any) -> t.Any: + if isinstance(value, exp.Expression) or (isinstance(value, str) and "@" in value): + expression = exp.maybe_parse(value, dialect=dialect) + rendered_expr = render_expression( + expression=expression, + module_path=module_path, + macros=macros, + jinja_macros=jinja_macros, + variables=variables, + path=path, + dialect=dialect, + default_catalog=default_catalog, + blueprint_variables=blueprint_variables, + ) + if not rendered_expr: + raise SQLMeshError( + f"Rendering `{expression.sql(dialect=dialect)}` did not return an expression" + ) + + if len(rendered_expr) != 1: + raise SQLMeshError( + f"Rendering `{expression.sql(dialect=dialect)}` must return one result, but got {len(rendered_expr)}" + ) + + # For cases where a property is conditionally assigned + if rendered_expr[0].sql().lower() in {"none", "null"}: + return None + + return rendered_expr[0] + + return value + + for field_name, field_info in ModelMeta.all_field_infos().items(): + field = field_info.alias or field_name + field_value = fields.get(field) + + # We don't want to parse python model cron="@..." kwargs (e.g. @daily) into MacroVar + if ( + field == "cron" + and isinstance(field_value, str) + and field_value.lower() in CRON_SHORTCUTS + ) or field_value is None: + continue + + if field in RUNTIME_RENDERED_MODEL_FIELDS: + fields[field] = parse_strings_with_macro_refs(field_value, dialect) + continue + + if isinstance(field_value, dict): + rendered_dict = {} + for key, value in field_value.items(): + if key in RUNTIME_RENDERED_MODEL_FIELDS: + rendered_dict[key] = parse_strings_with_macro_refs(value, dialect) + elif ( + # don't parse kind auto_restatement_cron="@..." kwargs (e.g. 
@daily) into MacroVar + key == "auto_restatement_cron" + and isinstance(value, str) + and value.lower() in CRON_SHORTCUTS + ): + rendered_dict[key] = value + elif (rendered := render_field_value(value)) is not None: + rendered_dict[key] = rendered + + if rendered_dict: + fields[field] = rendered_dict + else: + fields.pop(field) + elif isinstance(field_value, list): + rendered_list = [ + rendered + for value in field_value + if (rendered := render_field_value(value)) is not None + ] + if rendered_list: + fields[field] = rendered_list + else: + fields.pop(field) + else: + rendered_field = render_field_value(field_value) + if rendered_field is not None: + fields[field] = rendered_field + else: + fields.pop(field) + + return fields + + +def render_model_defaults( + defaults: t.Dict[str, t.Any], + module_path: Path, + path: t.Optional[Path], + jinja_macros: t.Optional[JinjaMacroRegistry], + macros: t.Optional[MacroRegistry], + dialect: DialectType, + variables: t.Optional[t.Dict[str, t.Any]], + default_catalog: t.Optional[str], +) -> t.Dict[str, t.Any]: + rendered_defaults = render_meta_fields( + fields=defaults, + module_path=module_path, + macros=macros, + jinja_macros=jinja_macros, + variables=variables, + path=path, + dialect=dialect, + default_catalog=default_catalog, ) + # Validate defaults that have macros are rendered to boolean + for boolean in {"optimize_query", "allow_partials", "enabled"}: + var = rendered_defaults.get(boolean) + if var is not None and not isinstance(var, (exp.Boolean, bool)): + raise ConfigError(f"Expected boolean for '{var}', got '{type(var)}' instead") -def _single_expr_or_tuple(values: t.Sequence[exp.Expression]) -> exp.Expression | exp.Tuple: - return values[0] if len(values) == 1 else exp.Tuple(expressions=values) + # Validate the 'interval_unit' if present is an Interval Unit + var = rendered_defaults.get("interval_unit") + if isinstance(var, str): + try: + rendered_defaults["interval_unit"] = IntervalUnit(var) + except ValueError as 
e: + raise ConfigError(f"Invalid interval unit: {var}") from e + return rendered_defaults -def _refs_to_sql(values: t.Any) -> exp.Expression: - return exp.Tuple(expressions=values) + +def parse_defaults_properties( + defaults: t.Dict[str, t.Any], dialect: DialectType +) -> t.Dict[str, t.Any]: + for prop in PROPERTIES: + default_properties = defaults.get(prop) + for key, value in (default_properties or {}).items(): + if isinstance(key, str) and d.SQLMESH_MACRO_PREFIX in str(value): + defaults[prop][key] = exp.maybe_parse(value, dialect=dialect) + + return defaults + + +def render_expression( + expression: exp.Expression, + module_path: Path, + path: t.Optional[Path], + jinja_macros: t.Optional[JinjaMacroRegistry] = None, + macros: t.Optional[MacroRegistry] = None, + dialect: DialectType = None, + variables: t.Optional[t.Dict[str, t.Any]] = None, + default_catalog: t.Optional[str] = None, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, +) -> t.Optional[t.List[exp.Expression]]: + meta_python_env = make_python_env( + expressions=expression, + jinja_macro_references=None, + module_path=module_path, + macros=macros or macro.get_registry(), + variables=variables, + path=path, + blueprint_variables=blueprint_variables, + ) + return ExpressionRenderer( + expression, + dialect, + [], + path=path, + jinja_macro_registry=jinja_macros, + python_env=meta_python_env, + default_catalog=default_catalog, + quote_identifiers=False, + normalize_identifiers=False, + ).render() META_FIELD_CONVERTER: t.Dict[str, t.Callable] = { "start": lambda value: exp.Literal.string(value), "cron": lambda value: exp.Literal.string(value), + "cron_tz": lambda value: exp.Literal.string(value), "partitioned_by_": _single_expr_or_tuple, - "clustered_by": _single_value_or_tuple, + "clustered_by": _single_expr_or_tuple, "depends_on_": lambda value: exp.Tuple(expressions=sorted(value)), "pre": _list_of_calls_to_exp, "post": _list_of_calls_to_exp, @@ -2271,17 +3056,77 @@ def _refs_to_sql(values: 
t.Any) -> exp.Expression: "columns_to_types_": lambda value: exp.Schema( expressions=[exp.ColumnDef(this=exp.to_column(c), kind=t) for c, t in value.items()] ), - "tags": _single_value_or_tuple, + "column_descriptions_": lambda value: exp.Schema( + expressions=[exp.to_column(c).eq(d) for c, d in value.items()] + ), + "tags": single_value_or_tuple, "grains": _refs_to_sql, "references": _refs_to_sql, "physical_properties_": lambda value: value, "virtual_properties_": lambda value: value, "session_properties_": lambda value: value, "allow_partials": exp.convert, - "signals": lambda values: exp.Tuple(expressions=values), + "signals": lambda values: exp.tuple_( + *( + exp.func( + name, *(exp.PropertyEQ(this=exp.var(k), expression=v) for k, v in args.items()) + ) + if name + else exp.Tuple(expressions=[exp.var(k).eq(v) for k, v in args.items()]) + for name, args in values + ) + ), + "formatting": str, + "optimize_query": str, + "virtual_environment_mode": lambda value: exp.Literal.string(value.value), + "dbt_node_info_": lambda value: value.to_expression(), + "grants_": lambda value: value, + "grants_target_layer": lambda value: exp.Literal.string(value.value), } def get_model_name(path: Path) -> str: path_parts = list(path.parts[path.parts.index("models") + 1 : -1]) + [path.stem] return ".".join(path_parts[-3:]) + + +# function applied to time column when automatically used for partitioning in INCREMENTAL_BY_TIME_RANGE models +def clickhouse_partition_func( + column: exp.Expression, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] +) -> exp.Expression: + # `toMonday()` function accepts a Date or DateTime type column + + col_type = (columns_to_types and columns_to_types.get(column.name)) or exp.DataType.build( + "UNKNOWN" + ) + col_type_is_conformable = col_type.is_type( + exp.DataType.Type.DATE, + exp.DataType.Type.DATE32, + exp.DataType.Type.DATETIME, + exp.DataType.Type.DATETIME64, + ) + + # if input column is already a conformable type, just pass the column + 
if col_type_is_conformable: + return exp.func("toMonday", column, dialect="clickhouse") + + # if input column type is not known, cast input to DateTime64 + if col_type.is_type(exp.DataType.Type.UNKNOWN): + return exp.func( + "toMonday", + exp.cast(column, exp.DataType.build("DateTime64(9, 'UTC')", dialect="clickhouse")), + dialect="clickhouse", + ) + + # if input column type is known but not conformable, cast input to DateTime64 and cast output back to original type + return exp.cast( + exp.func( + "toMonday", + exp.cast(column, exp.DataType.build("DateTime64(9, 'UTC')", dialect="clickhouse")), + dialect="clickhouse", + ), + col_type, + ) + + +TIME_COL_PARTITION_FUNC = {"clickhouse": clickhouse_partition_func} diff --git a/sqlmesh/core/model/kind.py b/sqlmesh/core/model/kind.py index a1ec505653..9abaa9c650 100644 --- a/sqlmesh/core/model/kind.py +++ b/sqlmesh/core/model/kind.py @@ -1,8 +1,8 @@ from __future__ import annotations -import sys import typing as t from enum import Enum +from typing_extensions import Self from pydantic import Field from sqlglot import exp @@ -12,29 +12,30 @@ from sqlglot.time import format_time from sqlmesh.core import dialect as d -from sqlmesh.core.model.common import parse_properties, properties_validator +from sqlmesh.core.model.common import ( + parse_properties, + properties_validator, + validate_extra_and_required_fields, +) from sqlmesh.core.model.seed import CsvSettings from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.pydantic import ( PydanticModel, SQLGlotBool, SQLGlotColumn, - SQLGlotListOfColumnsOrStar, + SQLGlotListOfFieldsOrStar, SQLGlotListOfFields, SQLGlotPositiveInt, SQLGlotString, + SQLGlotCron, + ValidationInfo, column_validator, field_validator, - field_validator_v1_args, get_dialect, validate_string, + validate_expression, ) -if sys.version_info >= (3, 9): - from typing import Annotated, Literal -else: - from typing_extensions import Annotated, Literal - if t.TYPE_CHECKING: from sqlmesh.core._typing 
import CustomMaterializationProperties @@ -118,6 +119,10 @@ def is_custom(self) -> bool: def is_managed(self) -> bool: return self.model_kind_name == ModelKindName.MANAGED + @property + def is_dbt_custom(self) -> bool: + return self.model_kind_name == ModelKindName.DBT_CUSTOM + @property def is_symbolic(self) -> bool: """A symbolic model is one that doesn't execute at all.""" @@ -125,7 +130,7 @@ def is_symbolic(self) -> bool: @property def is_materialized(self) -> bool: - return not (self.is_symbolic or self.is_view) + return self.model_kind_name is not None and not (self.is_symbolic or self.is_view) @property def only_execution_time(self) -> bool: @@ -135,18 +140,25 @@ def only_execution_time(self) -> bool: @property def full_history_restatement_only(self) -> bool: """Whether or not this model only supports restatement of full history.""" - return self.model_kind_name in ( - ModelKindName.INCREMENTAL_UNMANAGED, - ModelKindName.INCREMENTAL_BY_UNIQUE_KEY, - ModelKindName.INCREMENTAL_BY_PARTITION, - ModelKindName.SCD_TYPE_2, - ModelKindName.MANAGED, + return ( + self.is_incremental_unmanaged + or self.is_incremental_by_unique_key + or self.is_incremental_by_partition + or self.is_scd_type_2 + or self.is_managed + or self.is_full + or self.is_view ) @property def supports_python_models(self) -> bool: return True + @property + def supports_grants(self) -> bool: + """Whether this model kind supports grants configuration.""" + return self.is_materialized or self.is_view + class ModelKindName(str, ModelKindMixin, Enum): """The kind of model, determining how this data is computed and stored in the warehouse.""" @@ -167,6 +179,7 @@ class ModelKindName(str, ModelKindMixin, Enum): EXTERNAL = "EXTERNAL" CUSTOM = "CUSTOM" MANAGED = "MANAGED" + DBT_CUSTOM = "DBT_CUSTOM" @property def model_kind_name(self) -> t.Optional[ModelKindName]: @@ -185,6 +198,7 @@ class OnDestructiveChange(str, Enum): ERROR = "ERROR" WARN = "WARN" ALLOW = "ALLOW" + IGNORE = "IGNORE" @property def 
is_error(self) -> bool: @@ -198,6 +212,35 @@ def is_warn(self) -> bool: def is_allow(self) -> bool: return self == OnDestructiveChange.ALLOW + @property + def is_ignore(self) -> bool: + return self == OnDestructiveChange.IGNORE + + +class OnAdditiveChange(str, Enum): + """What should happen when a forward-only model change requires an additive schema change.""" + + ERROR = "ERROR" + WARN = "WARN" + ALLOW = "ALLOW" + IGNORE = "IGNORE" + + @property + def is_error(self) -> bool: + return self == OnAdditiveChange.ERROR + + @property + def is_warn(self) -> bool: + return self == OnAdditiveChange.WARN + + @property + def is_allow(self) -> bool: + return self == OnAdditiveChange.ALLOW + + @property + def is_ignore(self) -> bool: + return self == OnAdditiveChange.IGNORE + def _on_destructive_change_validator( cls: t.Type, v: t.Union[OnDestructiveChange, str, exp.Identifier] @@ -209,6 +252,20 @@ def _on_destructive_change_validator( return v +def _on_additive_change_validator( + cls: t.Type, v: t.Union[OnAdditiveChange, str, exp.Identifier] +) -> t.Any: + if v and not isinstance(v, OnAdditiveChange): + return OnAdditiveChange( + v.this.upper() if isinstance(v, (exp.Identifier, exp.Literal)) else v.upper() + ) + return v + + +on_additive_change_validator = field_validator("on_additive_change", mode="before")( + _on_additive_change_validator +) + on_destructive_change_validator = field_validator("on_destructive_change", mode="before")( _on_destructive_change_validator ) @@ -242,44 +299,8 @@ class TimeColumn(PydanticModel): @classmethod def validator(cls) -> classmethod: - def _time_column_validator(v: t.Any, values: t.Any) -> TimeColumn: - dialect = get_dialect(values) - - if isinstance(v, exp.Tuple): - column_expr = v.expressions[0] - column = ( - exp.column(column_expr) - if isinstance(column_expr, exp.Identifier) - else column_expr - ) - format = v.expressions[1].name if len(v.expressions) > 1 else None - elif isinstance(v, exp.Expression): - column = exp.column(v) if 
isinstance(v, exp.Identifier) else v - format = None - elif isinstance(v, str): - column = d.parse_one(v, dialect=dialect) - column.meta.pop("sql") - format = None - elif isinstance(v, dict): - column_raw = v["column"] - column = ( - d.parse_one(column_raw, dialect=dialect) - if isinstance(column_raw, str) - else column_raw - ) - format = v.get("format") - elif isinstance(v, TimeColumn): - column = v.column - format = v.format - else: - raise ConfigError(f"Invalid time_column: '{v}'.") - - column = quote_identifiers( - normalize_identifiers(column, dialect=dialect), dialect=dialect - ) - column.meta["dialect"] = dialect - - return TimeColumn(column=column, format=format) + def _time_column_validator(v: t.Any, info: ValidationInfo) -> TimeColumn: + return TimeColumn.create(v, get_dialect(info.data)) return field_validator("time_column", mode="before")(_time_column_validator) @@ -317,6 +338,40 @@ def to_expression(self, dialect: str) -> exp.Expression: def to_property(self, dialect: str = "") -> exp.Property: return exp.Property(this="time_column", value=self.to_expression(dialect)) + @classmethod + def create(cls, v: t.Any, dialect: str) -> Self: + if isinstance(v, exp.Tuple): + column_expr = v.expressions[0] + column = ( + exp.column(column_expr) if isinstance(column_expr, exp.Identifier) else column_expr + ) + format = v.expressions[1].name if len(v.expressions) > 1 else None + elif isinstance(v, exp.Expression): + column = exp.column(v) if isinstance(v, exp.Identifier) else v + format = None + elif isinstance(v, str): + column = d.parse_one(v, dialect=dialect) + column.meta.pop("sql") + format = None + elif isinstance(v, dict): + column_raw = v["column"] + column = ( + d.parse_one(column_raw, dialect=dialect) + if isinstance(column_raw, str) + else column_raw + ) + format = v.get("format") + elif isinstance(v, TimeColumn): + column = v.column + format = v.format + else: + raise ConfigError(f"Invalid time_column: '{v}'.") + + column = 
quote_identifiers(normalize_identifiers(column, dialect=dialect), dialect=dialect) + column.meta["dialect"] = dialect + + return cls(column=column, format=format) + def _kind_dialect_validator(cls: t.Type, v: t.Optional[str]) -> str: if v is None: @@ -324,21 +379,24 @@ def _kind_dialect_validator(cls: t.Type, v: t.Optional[str]) -> str: return v -kind_dialect_validator = field_validator("dialect", mode="before", always=True)( - _kind_dialect_validator -) +kind_dialect_validator = field_validator("dialect", mode="before")(_kind_dialect_validator) class _Incremental(_ModelKind): on_destructive_change: OnDestructiveChange = OnDestructiveChange.ERROR + on_additive_change: OnAdditiveChange = OnAdditiveChange.ALLOW + auto_restatement_cron: t.Optional[SQLGlotCron] = None _on_destructive_change_validator = on_destructive_change_validator + _on_additive_change_validator = on_additive_change_validator @property def metadata_hash_values(self) -> t.List[t.Optional[str]]: return [ *super().metadata_hash_values, str(self.on_destructive_change), + str(self.on_additive_change), + self.auto_restatement_cron, ] def to_expression( @@ -347,8 +405,12 @@ def to_expression( return super().to_expression( expressions=[ *(expressions or []), - _property( - "on_destructive_change", exp.Literal.string(self.on_destructive_change.value) + *_properties( + { + "on_destructive_change": self.on_destructive_change.value, + "on_additive_change": self.on_additive_change.value, + "auto_restatement_cron": self.auto_restatement_cron, + } ), ], ) @@ -401,8 +463,12 @@ def to_expression( class IncrementalByTimeRangeKind(_IncrementalBy): - name: Literal[ModelKindName.INCREMENTAL_BY_TIME_RANGE] = ModelKindName.INCREMENTAL_BY_TIME_RANGE + name: t.Literal[ModelKindName.INCREMENTAL_BY_TIME_RANGE] = ( + ModelKindName.INCREMENTAL_BY_TIME_RANGE + ) time_column: TimeColumn + auto_restatement_intervals: t.Optional[SQLGlotPositiveInt] = None + partition_by_time_column: SQLGlotBool = True _time_column_validator = 
TimeColumn.validator() @@ -413,6 +479,16 @@ def to_expression( expressions=[ *(expressions or []), self.time_column.to_property(kwargs.get("dialect") or ""), + *_properties( + { + "partition_by_time_column": self.partition_by_time_column, + } + ), + *( + [_property("auto_restatement_intervals", self.auto_restatement_intervals)] + if self.auto_restatement_intervals is not None + else [] + ), ] ) @@ -420,44 +496,67 @@ def to_expression( def data_hash_values(self) -> t.List[t.Optional[str]]: return [*super().data_hash_values, gen(self.time_column.column), self.time_column.format] + @property + def metadata_hash_values(self) -> t.List[t.Optional[str]]: + return [ + *super().metadata_hash_values, + str(self.partition_by_time_column), + str(self.auto_restatement_intervals) + if self.auto_restatement_intervals is not None + else None, + ] + class IncrementalByUniqueKeyKind(_IncrementalBy): - name: Literal[ModelKindName.INCREMENTAL_BY_UNIQUE_KEY] = ModelKindName.INCREMENTAL_BY_UNIQUE_KEY + name: t.Literal[ModelKindName.INCREMENTAL_BY_UNIQUE_KEY] = ( + ModelKindName.INCREMENTAL_BY_UNIQUE_KEY + ) unique_key: SQLGlotListOfFields - when_matched: t.Optional[exp.When] = None - batch_concurrency: Literal[1] = 1 + when_matched: t.Optional[exp.Whens] = None + merge_filter: t.Optional[exp.Expression] = None + batch_concurrency: t.Literal[1] = 1 @field_validator("when_matched", mode="before") - @field_validator_v1_args def _when_matched_validator( - cls, v: t.Optional[t.Union[exp.When, str]], values: t.Dict[str, t.Any] - ) -> t.Optional[exp.When]: - def replace_table_references(expression: exp.Expression) -> exp.Expression: - from sqlmesh.core.engine_adapter.base import ( - MERGE_SOURCE_ALIAS, - MERGE_TARGET_ALIAS, - ) + cls, + v: t.Optional[t.Union[str, list, exp.Whens]], + info: ValidationInfo, + ) -> t.Optional[exp.Whens]: + if v is None: + return v + if isinstance(v, list): + v = " ".join(v) - if isinstance(expression, exp.Column): - if expression.table.lower() == "target": - 
expression.set( - "table", - exp.to_identifier(MERGE_TARGET_ALIAS), - ) - elif expression.table.lower() == "source": - expression.set( - "table", - exp.to_identifier(MERGE_SOURCE_ALIAS), - ) - return expression + dialect = get_dialect(info.data) if isinstance(v, str): - return t.cast(exp.When, d.parse_one(v, into=exp.When, dialect=get_dialect(values))) - - if not v: + # Whens wrap the WHEN clauses, but the parentheses aren't parsed by sqlglot + v = v.strip() + if v.startswith("("): + v = v[1:-1] + + v = t.cast(exp.Whens, d.parse_one(v, into=exp.Whens, dialect=dialect)) + + v = validate_expression(v, dialect=dialect) + return t.cast(exp.Whens, v.transform(d.replace_merge_table_aliases, dialect=dialect)) + + @field_validator("merge_filter", mode="before") + def _merge_filter_validator( + cls, + v: t.Optional[exp.Expression], + info: ValidationInfo, + ) -> t.Optional[exp.Expression]: + if v is None: return v - return t.cast(exp.When, v.transform(replace_table_references)) + dialect = get_dialect(info.data) + + if isinstance(v, str): + v = v.strip() + v = d.parse_one(v, dialect=dialect) + + v = validate_expression(v, dialect=dialect) + return v.transform(d.replace_merge_table_aliases, dialect=dialect) @property def data_hash_values(self) -> t.List[t.Optional[str]]: @@ -465,6 +564,7 @@ def data_hash_values(self) -> t.List[t.Optional[str]]: *super().data_hash_values, *(gen(k) for k in self.unique_key), gen(self.when_matched) if self.when_matched is not None else None, + gen(self.merge_filter) if self.merge_filter is not None else None, ] def to_expression( @@ -477,6 +577,7 @@ def to_expression( { "unique_key": exp.Tuple(expressions=self.unique_key), "when_matched": self.when_matched, + "merge_filter": self.merge_filter, } ), ], @@ -484,9 +585,17 @@ def to_expression( class IncrementalByPartitionKind(_Incremental): - name: Literal[ModelKindName.INCREMENTAL_BY_PARTITION] = ModelKindName.INCREMENTAL_BY_PARTITION - forward_only: Literal[True] = True - disable_restatement: 
SQLGlotBool = True + name: t.Literal[ModelKindName.INCREMENTAL_BY_PARTITION] = ModelKindName.INCREMENTAL_BY_PARTITION + forward_only: t.Literal[True] = True + disable_restatement: SQLGlotBool = False + + @field_validator("forward_only", mode="before") + def _forward_only_validator(cls, v: t.Union[bool, exp.Expression]) -> t.Literal[True]: + if v is not True: + raise ConfigError( + "Do not specify the `forward_only` configuration key - INCREMENTAL_BY_PARTITION models are always forward_only." + ) + return v @property def metadata_hash_values(self) -> t.List[t.Optional[str]]: @@ -513,7 +622,7 @@ def to_expression( class IncrementalUnmanagedKind(_Incremental): - name: Literal[ModelKindName.INCREMENTAL_UNMANAGED] = ModelKindName.INCREMENTAL_UNMANAGED + name: t.Literal[ModelKindName.INCREMENTAL_UNMANAGED] = ModelKindName.INCREMENTAL_UNMANAGED insert_overwrite: SQLGlotBool = False forward_only: SQLGlotBool = True disable_restatement: SQLGlotBool = True @@ -548,13 +657,17 @@ def to_expression( class ViewKind(_ModelKind): - name: Literal[ModelKindName.VIEW] = ModelKindName.VIEW + name: t.Literal[ModelKindName.VIEW] = ModelKindName.VIEW materialized: SQLGlotBool = False @property def data_hash_values(self) -> t.List[t.Optional[str]]: return [*super().data_hash_values, str(self.materialized)] + @property + def supports_python_models(self) -> bool: + return False + def to_expression( self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any ) -> d.ModelKind: @@ -567,7 +680,7 @@ def to_expression( class SeedKind(_ModelKind): - name: Literal[ModelKindName.SEED] = ModelKindName.SEED + name: t.Literal[ModelKindName.SEED] = ModelKindName.SEED path: SQLGlotString batch_size: SQLGlotPositiveInt = 1000 csv_settings: t.Optional[CsvSettings] = None @@ -578,7 +691,7 @@ def _parse_csv_settings(cls, v: t.Any) -> t.Optional[CsvSettings]: if v is None or isinstance(v, CsvSettings): return v if isinstance(v, exp.Expression): - tuple_exp = parse_properties(cls, v, {}) + 
tuple_exp = parse_properties(cls, v, None) if not tuple_exp: return None return CsvSettings(**{e.left.name: e.right for e in tuple_exp.expressions}) @@ -604,18 +717,23 @@ def to_expression( @property def data_hash_values(self) -> t.List[t.Optional[str]]: + csv_setting_values = (self.csv_settings or CsvSettings()).dict().values() return [ *super().data_hash_values, - *(self.csv_settings or CsvSettings()).dict().values(), + *(v if isinstance(v, (str, type(None))) else str(v) for v in csv_setting_values), ] @property def metadata_hash_values(self) -> t.List[t.Optional[str]]: return [*super().metadata_hash_values, str(self.batch_size)] + @property + def supports_python_models(self) -> bool: + return False + class FullKind(_ModelKind): - name: Literal[ModelKindName.FULL] = ModelKindName.FULL + name: t.Literal[ModelKindName.FULL] = ModelKindName.FULL class _SCDType2Kind(_Incremental): @@ -625,19 +743,18 @@ class _SCDType2Kind(_Incremental): valid_to_name: SQLGlotColumn = Field(exp.column("valid_to"), validate_default=True) invalidate_hard_deletes: SQLGlotBool = False time_data_type: exp.DataType = Field(exp.DataType.build("TIMESTAMP"), validate_default=True) + batch_size: t.Optional[SQLGlotPositiveInt] = None forward_only: SQLGlotBool = True disable_restatement: SQLGlotBool = True _dialect_validator = kind_dialect_validator - # Remove once Pydantic 1 is deprecated - _always_validate_column = field_validator( - "valid_from_name", "valid_to_name", mode="before", always=True - )(column_validator) + _always_validate_column = field_validator("valid_from_name", "valid_to_name", mode="before")( + column_validator + ) - # always=True can be removed once Pydantic 1 is deprecated - @field_validator("time_data_type", mode="before", always=True) + @field_validator("time_data_type", mode="before") @classmethod def _time_data_type_validator( cls, v: t.Union[str, exp.Expression], values: t.Any @@ -665,7 +782,8 @@ def data_hash_values(self) -> t.List[t.Optional[str]]: 
gen(self.valid_from_name), gen(self.valid_to_name), str(self.invalidate_hard_deletes), - gen(self.time_data_type), + self.time_data_type.sql(self.dialect), + gen(self.batch_size) if self.batch_size is not None else None, ] @property @@ -698,14 +816,13 @@ def to_expression( class SCDType2ByTimeKind(_SCDType2Kind): - name: Literal[ModelKindName.SCD_TYPE_2, ModelKindName.SCD_TYPE_2_BY_TIME] = ( + name: t.Literal[ModelKindName.SCD_TYPE_2, ModelKindName.SCD_TYPE_2_BY_TIME] = ( ModelKindName.SCD_TYPE_2_BY_TIME ) updated_at_name: SQLGlotColumn = Field(exp.column("updated_at"), validate_default=True) updated_at_as_valid_from: SQLGlotBool = False - # Remove once Pydantic 1 is deprecated - _always_validate_updated_at = field_validator("updated_at_name", mode="before", always=True)( + _always_validate_updated_at = field_validator("updated_at_name", mode="before")( column_validator ) @@ -734,9 +851,10 @@ def to_expression( class SCDType2ByColumnKind(_SCDType2Kind): - name: Literal[ModelKindName.SCD_TYPE_2_BY_COLUMN] = ModelKindName.SCD_TYPE_2_BY_COLUMN - columns: SQLGlotListOfColumnsOrStar + name: t.Literal[ModelKindName.SCD_TYPE_2_BY_COLUMN] = ModelKindName.SCD_TYPE_2_BY_COLUMN + columns: SQLGlotListOfFieldsOrStar execution_time_as_valid_from: SQLGlotBool = False + updated_at_name: t.Optional[SQLGlotColumn] = None @property def data_hash_values(self) -> t.List[t.Optional[str]]: @@ -745,7 +863,12 @@ def data_hash_values(self) -> t.List[t.Optional[str]]: if isinstance(self.columns, list) else [gen(self.columns)] ) - return [*super().data_hash_values, *columns_sql, str(self.execution_time_as_valid_from)] + return [ + *super().data_hash_values, + *columns_sql, + str(self.execution_time_as_valid_from), + gen(self.updated_at_name) if self.updated_at_name is not None else None, + ] def to_expression( self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any @@ -766,7 +889,7 @@ def to_expression( class ManagedKind(_ModelKind): - name: 
Literal[ModelKindName.MANAGED] = ModelKindName.MANAGED + name: t.Literal[ModelKindName.MANAGED] = ModelKindName.MANAGED disable_restatement: t.Literal[True] = True @property @@ -774,16 +897,60 @@ def supports_python_models(self) -> bool: return False +class DbtCustomKind(_ModelKind): + name: t.Literal[ModelKindName.DBT_CUSTOM] = ModelKindName.DBT_CUSTOM + materialization: str + adapter: str = "default" + definition: str + dialect: t.Optional[str] = Field(None, validate_default=True) + + _dialect_validator = kind_dialect_validator + + @field_validator("materialization", "adapter", "definition", mode="before") + @classmethod + def _validate_fields(cls, v: t.Any) -> str: + return validate_string(v) + + @property + def data_hash_values(self) -> t.List[t.Optional[str]]: + return [ + *super().data_hash_values, + self.materialization, + self.definition, + self.adapter, + self.dialect, + ] + + def to_expression( + self, expressions: t.Optional[t.List[exp.Expression]] = None, **kwargs: t.Any + ) -> d.ModelKind: + return super().to_expression( + expressions=[ + *(expressions or []), + *_properties( + { + "materialization": exp.Literal.string(self.materialization), + "adapter": exp.Literal.string(self.adapter), + } + ), + ], + ) + + class EmbeddedKind(_ModelKind): - name: Literal[ModelKindName.EMBEDDED] = ModelKindName.EMBEDDED + name: t.Literal[ModelKindName.EMBEDDED] = ModelKindName.EMBEDDED + + @property + def supports_python_models(self) -> bool: + return False class ExternalKind(_ModelKind): - name: Literal[ModelKindName.EXTERNAL] = ModelKindName.EXTERNAL + name: t.Literal[ModelKindName.EXTERNAL] = ModelKindName.EXTERNAL class CustomKind(_ModelKind): - name: Literal[ModelKindName.CUSTOM] = ModelKindName.CUSTOM + name: t.Literal[ModelKindName.CUSTOM] = ModelKindName.CUSTOM materialization: str materialization_properties_: t.Optional[exp.Tuple] = Field( default=None, alias="materialization_properties" @@ -793,18 +960,19 @@ class CustomKind(_ModelKind): batch_size: 
t.Optional[SQLGlotPositiveInt] = None batch_concurrency: t.Optional[SQLGlotPositiveInt] = None lookback: t.Optional[SQLGlotPositiveInt] = None + auto_restatement_cron: t.Optional[SQLGlotCron] = None + auto_restatement_intervals: t.Optional[SQLGlotPositiveInt] = None + + # so that CustomKind subclasses know the dialect when validating / normalizing / interpreting values in `materialization_properties` + dialect: str = Field(exclude=True) _properties_validator = properties_validator @field_validator("materialization", mode="before") @classmethod def _validate_materialization(cls, v: t.Any) -> str: - from sqlmesh.core.snapshot.evaluator import get_custom_materialization_type - - materialization = validate_string(v) - # The below call fails if a materialization with the given name doesn't exist. - get_custom_materialization_type(materialization) - return materialization + # note: create_model_kind() validates the custom materialization class + return validate_string(v) @property def materialization_properties(self) -> CustomMaterializationProperties: @@ -830,6 +998,10 @@ def metadata_hash_values(self) -> t.List[t.Optional[str]]: str(self.batch_concurrency) if self.batch_concurrency is not None else None, str(self.forward_only), str(self.disable_restatement), + self.auto_restatement_cron, + str(self.auto_restatement_intervals) + if self.auto_restatement_intervals is not None + else None, ] def to_expression( @@ -847,13 +1019,15 @@ def to_expression( "batch_size": self.batch_size, "batch_concurrency": self.batch_concurrency, "lookback": self.lookback, + "auto_restatement_cron": self.auto_restatement_cron, + "auto_restatement_intervals": self.auto_restatement_intervals, } ), ], ) -ModelKind = Annotated[ +ModelKind = t.Annotated[ t.Union[ EmbeddedKind, ExternalKind, @@ -868,6 +1042,7 @@ def to_expression( SCDType2ByColumnKind, CustomKind, ManagedKind, + DbtCustomKind, ], Field(discriminator="name"), ] @@ -887,6 +1062,7 @@ def to_expression( 
ModelKindName.SCD_TYPE_2_BY_COLUMN: SCDType2ByColumnKind, ModelKindName.CUSTOM: CustomKind, ModelKindName.MANAGED: ManagedKind, + ModelKindName.DBT_CUSTOM: DbtCustomKind, } @@ -919,28 +1095,63 @@ def create_model_kind(v: t.Any, dialect: str, defaults: t.Dict[str, t.Any]) -> M if "dialect" in kind_type.all_fields() and props.get("dialect") is None: props["dialect"] = dialect - # only pass the on_destructive_change user default to models inheriting from _Incremental + # only pass the on_destructive_change or on_additive_change user default to models inheriting from _Incremental # that don't explicitly set it in the model definition - if ( - issubclass(kind_type, _Incremental) - and props.get("on_destructive_change") is None - and defaults.get("on_destructive_change") is not None - ): - props["on_destructive_change"] = defaults.get("on_destructive_change") + if issubclass(kind_type, _Incremental): + for on_change_property in ("on_additive_change", "on_destructive_change"): + if ( + props.get(on_change_property) is None + and defaults.get(on_change_property) is not None + ): + props[on_change_property] = defaults.get(on_change_property) + + # only pass the batch_concurrency user default to models inheriting from _IncrementalBy + # that don't explicitly set it in the model definition, but ignore subclasses of _IncrementalBy + # that hardcode a specific batch_concurrency + if issubclass(kind_type, _IncrementalBy): + BATCH_CONCURRENCY: t.Final = "batch_concurrency" + if ( + props.get(BATCH_CONCURRENCY) is None + and defaults.get(BATCH_CONCURRENCY) is not None + and kind_type.all_field_infos()[BATCH_CONCURRENCY].default is None + ): + props[BATCH_CONCURRENCY] = defaults.get(BATCH_CONCURRENCY) + + if kind_type == CustomKind: + # load the custom materialization class and check if it uses a custom kind type + from sqlmesh.core.snapshot.evaluator import get_custom_materialization_type + + if "materialization" not in props: + raise ConfigError( + "The 'materialization' property 
is required for models of the CUSTOM kind" + ) + + # The below call will print a warning if a materialization with the given name doesn't exist + # we dont want to throw an error here because we still want Models with a CustomKind to be able + # to be serialized / deserialized in contexts where the custom materialization class may not be available, + # such as in HTTP request handlers + custom_materialization = get_custom_materialization_type( + validate_string(props.get("materialization")), raise_errors=False + ) + if custom_materialization is not None: + actual_kind_type, _ = custom_materialization + return actual_kind_type(**props) + validate_extra_and_required_fields( + kind_type, set(props), f"MODEL block 'kind {name}' field" + ) return kind_type(**props) name = (v.name if isinstance(v, exp.Expression) else str(v)).upper() return model_kind_type_from_name(name)(name=name) # type: ignore -@field_validator_v1_args -def _model_kind_validator(cls: t.Type, v: t.Any, values: t.Dict[str, t.Any]) -> ModelKind: - dialect = get_dialect(values) +def _model_kind_validator(cls: t.Type, v: t.Any, info: t.Optional[ValidationInfo]) -> ModelKind: + dialect = get_dialect(info.data) if info else "" return create_model_kind(v, dialect, {}) -model_kind_validator = field_validator("kind", mode="before")(_model_kind_validator) +model_kind_validator: t.Callable = field_validator("kind", mode="before")(_model_kind_validator) def _property(name: str, value: t.Any) -> exp.Property: diff --git a/sqlmesh/core/model/meta.py b/sqlmesh/core/model/meta.py index 86a2e6e927..c48b7d1524 100644 --- a/sqlmesh/core/model/meta.py +++ b/sqlmesh/core/model/meta.py @@ -1,22 +1,26 @@ from __future__ import annotations -import logging import typing as t +from enum import Enum from functools import cached_property +from typing_extensions import Self from pydantic import Field -from sqlglot import Dialect, exp +from sqlglot import Dialect, exp, parse_one from sqlglot.helper import ensure_collection, 
ensure_list from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlmesh.core import dialect as d +from sqlmesh.core.config.common import VirtualEnvironmentMode +from sqlmesh.core.config.linter import LinterConfig from sqlmesh.core.dialect import normalize_model_name +from sqlmesh.utils import classproperty from sqlmesh.core.model.common import ( bool_validator, default_catalog_validator, depends_on_validator, - parse_properties, properties_validator, + parse_properties, ) from sqlmesh.core.model.kind import ( CustomKind, @@ -27,27 +31,56 @@ SCDType2ByTimeKind, TimeColumn, ViewKind, - _IncrementalBy, model_kind_validator, + OnAdditiveChange, ) from sqlmesh.core.node import _Node, str_or_exp_to_str from sqlmesh.core.reference import Reference from sqlmesh.utils.date import TimeLike from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.pydantic import ( + ValidationInfo, field_validator, - field_validator_v1_args, list_of_fields_validator, model_validator, - model_validator_v1_args, + get_dialect, ) if t.TYPE_CHECKING: from sqlmesh.core._typing import CustomMaterializationProperties, SessionProperties + from sqlmesh.core.engine_adapter._typing import GrantsConfig + +FunctionCall = t.Tuple[str, t.Dict[str, exp.Expression]] + + +class GrantsTargetLayer(str, Enum): + """Target layer(s) where grants should be applied.""" + + ALL = "all" + PHYSICAL = "physical" + VIRTUAL = "virtual" -AuditReference = t.Tuple[str, t.Dict[str, exp.Expression]] + @classproperty + def default(cls) -> "GrantsTargetLayer": + return GrantsTargetLayer.VIRTUAL -logger = logging.getLogger(__name__) + @property + def is_all(self) -> bool: + return self == GrantsTargetLayer.ALL + + @property + def is_physical(self) -> bool: + return self == GrantsTargetLayer.PHYSICAL + + @property + def is_virtual(self) -> bool: + return self == GrantsTargetLayer.VIRTUAL + + def __str__(self) -> str: + return self.name + + def __repr__(self) -> str: + return str(self) class 
ModelMeta(_Node): @@ -57,16 +90,17 @@ class ModelMeta(_Node): name: str kind: ModelKind = ViewKind() retention: t.Optional[int] = None # not implemented yet + table_format: t.Optional[str] = None storage_format: t.Optional[str] = None partitioned_by_: t.List[exp.Expression] = Field(default=[], alias="partitioned_by") - clustered_by: t.List[str] = [] + clustered_by: t.List[exp.Expression] = [] default_catalog: t.Optional[str] = None depends_on_: t.Optional[t.Set[str]] = Field(default=None, alias="depends_on") columns_to_types_: t.Optional[t.Dict[str, exp.DataType]] = Field(default=None, alias="columns") column_descriptions_: t.Optional[t.Dict[str, str]] = Field( default=None, alias="column_descriptions" ) - audits: t.List[AuditReference] = [] + audits: t.List[FunctionCall] = [] grains: t.List[exp.Expression] = [] references: t.List[exp.Expression] = [] physical_schema_override: t.Optional[str] = None @@ -74,8 +108,18 @@ class ModelMeta(_Node): virtual_properties_: t.Optional[exp.Tuple] = Field(default=None, alias="virtual_properties") session_properties_: t.Optional[exp.Tuple] = Field(default=None, alias="session_properties") allow_partials: bool = False - signals: t.List[exp.Tuple] = [] + signals: t.List[FunctionCall] = [] enabled: bool = True + physical_version: t.Optional[str] = None + gateway: t.Optional[str] = None + optimize_query: t.Optional[bool] = None + ignored_rules_: t.Optional[t.Set[str]] = Field( + default=None, exclude=True, alias="ignored_rules" + ) + formatting: t.Optional[bool] = Field(default=None, exclude=True) + virtual_environment_mode: VirtualEnvironmentMode = VirtualEnvironmentMode.default + grants_: t.Optional[exp.Tuple] = Field(default=None, alias="grants") + grants_target_layer: GrantsTargetLayer = GrantsTargetLayer.default _bool_validator = bool_validator _model_kind_validator = model_kind_validator @@ -83,75 +127,21 @@ class ModelMeta(_Node): _default_catalog_validator = default_catalog_validator _depends_on_validator = 
depends_on_validator - @field_validator("audits", mode="before") - def _audits_validator(cls, v: t.Any) -> t.Any: - def extract(v: exp.Expression) -> t.Tuple[str, t.Dict[str, exp.Expression]]: - kwargs = {} - - if isinstance(v, exp.Anonymous): - func = v.name - args = v.expressions - elif isinstance(v, exp.Func): - func = v.sql_name() - args = list(v.args.values()) - else: - return v.name.lower(), {} - - for arg in args: - if not isinstance(arg, (exp.PropertyEQ, exp.EQ)): - raise ConfigError( - f"Function '{func}' must be called with key-value arguments like {func}(arg := value)." - ) - kwargs[arg.left.name.lower()] = arg.right - return func.lower(), kwargs - - if isinstance(v, (exp.Tuple, exp.Array)): - return [extract(i) for i in v.expressions] - if isinstance(v, exp.Paren): - return [extract(v.this)] - if isinstance(v, exp.Expression): - return [extract(v)] - if isinstance(v, list): - audits = [] - - for entry in v: - if isinstance(entry, dict): - args = entry - name = entry.pop("name") - elif isinstance(entry, (tuple, list)): - name, args = entry - else: - raise ConfigError(f"Audit must be a dictionary or named tuple. 
Got {entry}.") - - audits.append( - ( - name.lower(), - { - key: d.parse_one(value) if isinstance(value, str) else value - for key, value in args.items() - }, - ) - ) - - return audits + @field_validator("audits", "signals", mode="before") + def _func_call_validator(cls, v: t.Any, field: t.Any) -> t.Any: + is_signal = getattr(field, "name" if hasattr(field, "name") else "field_name") == "signals" - return v + return d.extract_function_calls(v, allow_tuples=is_signal) @field_validator("tags", mode="before") - @field_validator_v1_args - def _value_or_tuple_validator(cls, v: t.Any, values: t.Dict[str, t.Any]) -> t.Any: - return ensure_list(cls._validate_value_or_tuple(v, values)) - - @field_validator("clustered_by", mode="before") - @field_validator_v1_args - def _normalized_value_or_tuple_validator(cls, v: t.Any, values: t.Dict[str, t.Any]) -> t.Any: - return ensure_list(cls._validate_value_or_tuple(v, values, normalize=True)) + def _value_or_tuple_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any: + return ensure_list(cls._validate_value_or_tuple(v, info.data)) @classmethod def _validate_value_or_tuple( - cls, v: t.Dict[str, t.Any], values: t.Dict[str, t.Any], normalize: bool = False + cls, v: t.Dict[str, t.Any], data: t.Dict[str, t.Any], normalize: bool = False ) -> t.Any: - dialect = values.get("dialect") + dialect = data.get("dialect") def _normalize(value: t.Any) -> t.Any: return normalize_identifiers(value, dialect=dialect) if normalize else value @@ -167,12 +157,14 @@ def _normalize(value: t.Any) -> t.Any: value = _normalize(v) return value.name if isinstance(value, exp.Expression) else value if isinstance(v, (list, tuple)): - return [cls._validate_value_or_tuple(elm, values, normalize=normalize) for elm in v] + return [cls._validate_value_or_tuple(elm, data, normalize=normalize) for elm in v] return v - @field_validator("storage_format", mode="before") - def _storage_format_validator(cls, v: t.Any) -> t.Optional[str]: + @field_validator("table_format", 
"storage_format", mode="before") + def _format_validator(cls, v: t.Any, info: ValidationInfo) -> t.Optional[str]: + if isinstance(v, exp.Expression) and not (isinstance(v, (exp.Literal, exp.Identifier))): + return v.sql(info.data.get("dialect")) return str_or_exp_to_str(v) @field_validator("dialect", mode="before") @@ -182,15 +174,43 @@ def _dialect_validator(cls, v: t.Any) -> t.Optional[str]: dialect = str_or_exp_to_str(v) return dialect and dialect.lower() - @field_validator("partitioned_by_", mode="before") - @field_validator_v1_args - def _partition_by_validator( - cls, v: t.Any, values: t.Dict[str, t.Any] + @field_validator("physical_version", mode="before") + def _physical_version_validator(cls, v: t.Any) -> t.Optional[str]: + if v is None: + return v + return str_or_exp_to_str(v) + + @field_validator("gateway", mode="before") + def _gateway_validator(cls, v: t.Any) -> t.Optional[str]: + if v is None: + return None + gateway = str_or_exp_to_str(v) + return gateway and gateway.lower() + + @field_validator("partitioned_by_", "clustered_by", mode="before") + def _partition_and_cluster_validator( + cls, v: t.Any, info: ValidationInfo ) -> t.List[exp.Expression]: - partitions = list_of_fields_validator(v, values) + if ( + isinstance(v, list) + and all(isinstance(i, str) for i in v) + and info.field_name == "partitioned_by_" + ): + # this branch gets hit when we are deserializing from json because `partitioned_by` is stored as a List[str] + # however, we should only invoke this if the list contains strings because this validator is also + # called by Python models which might pass a List[exp.Expression] + string_to_parse = ( + f"({','.join(v)})" # recreate the (a, b, c) part of "partitioned_by (a, b, c)" + ) + parsed = parse_one( + string_to_parse, into=exp.PartitionedByProperty, dialect=get_dialect(info) + ) + v = parsed.this.expressions if isinstance(parsed.this, exp.Schema) else v + + expressions = list_of_fields_validator(v, info.data) - for partition in 
partitions: - num_cols = len(list(partition.find_all(exp.Column))) + for expression in expressions: + num_cols = len(list(expression.find_all(exp.Column))) error_msg: t.Optional[str] = None if num_cols == 0: @@ -199,23 +219,25 @@ def _partition_by_validator( error_msg = "contains multiple columns" if error_msg: - raise ConfigError(f"partitioned_by field '{partition}' {error_msg}") + raise ConfigError(f"Field '{expression}' {error_msg}") - return partitions + return expressions @field_validator( "columns_to_types_", "derived_columns_to_types", mode="before", check_fields=False ) - @field_validator_v1_args def _columns_validator( - cls, v: t.Any, values: t.Dict[str, t.Any] + cls, v: t.Any, info: ValidationInfo ) -> t.Optional[t.Dict[str, exp.DataType]]: columns_to_types = {} - dialect = values.get("dialect") + dialect = info.data.get("dialect") if isinstance(v, exp.Schema): for column in v.expressions: - expr = column.args["kind"] + expr = column.args.get("kind") + if not isinstance(expr, exp.DataType): + raise ConfigError(f"Missing data type for column '{column.name}'.") + expr.meta["dialect"] = dialect columns_to_types[normalize_identifiers(column, dialect=dialect).name] = expr @@ -233,11 +255,10 @@ def _columns_validator( return v @field_validator("column_descriptions_", mode="before") - @field_validator_v1_args def _column_descriptions_validator( - cls, vs: t.Any, values: t.Dict[str, t.Any] + cls, vs: t.Any, info: ValidationInfo ) -> t.Optional[t.Dict[str, str]]: - dialect = values.get("dialect") + dialect = info.data.get("dialect") if vs is None: return None @@ -249,7 +270,9 @@ def _column_descriptions_validator( vs = vs.expressions raw_col_descriptions = ( - vs if isinstance(vs, dict) else {v.this.name: v.expression.name for v in vs} + vs + if isinstance(vs, dict) + else {".".join([part.this for part in v.this.parts]): v.expression.name for v in vs} ) col_descriptions = { @@ -257,20 +280,23 @@ def _column_descriptions_validator( for k, v in 
raw_col_descriptions.items() } - columns_to_types = values.get("columns_to_types_") + columns_to_types = info.data.get("columns_to_types_") if columns_to_types: - for column_name in col_descriptions: + from sqlmesh.core.console import get_console + + console = get_console() + for column_name in list(col_descriptions): if column_name not in columns_to_types: - raise ConfigError( - f"In model '{values['name']}', a description is provided for column '{column_name}' but it is not a column in the model." + console.log_warning( + f"In model '{info.data['name']}', a description is provided for column '{column_name}' but it is not a column in the model." ) + del col_descriptions[column_name] return col_descriptions @field_validator("grains", "references", mode="before") - @field_validator_v1_args - def _refs_validator(cls, vs: t.Any, values: t.Dict[str, t.Any]) -> t.List[exp.Expression]: - dialect = values.get("dialect") + def _refs_validator(cls, vs: t.Any, info: ValidationInfo) -> t.List[exp.Expression]: + dialect = info.data.get("dialect") if isinstance(vs, exp.Paren): vs = vs.unnest() @@ -292,72 +318,130 @@ def _refs_validator(cls, vs: t.Any, values: t.Dict[str, t.Any]) -> t.List[exp.Ex return refs - @field_validator("signals", mode="before") - @field_validator_v1_args - def _signals_validator(cls, v: t.Any, values: t.Dict[str, t.Any]) -> t.Any: - if v is None: - return [] + @field_validator("ignored_rules_", mode="before") + def ignored_rules_validator(cls, vs: t.Any) -> t.Any: + return LinterConfig._validate_rules(vs) - if isinstance(v, str): - dialect = values.get("dialect") - v = d.parse_one(v, dialect=dialect) + @field_validator("grants_target_layer", mode="before") + def _grants_target_layer_validator(cls, v: t.Any) -> t.Any: + if isinstance(v, exp.Identifier): + return v.this + if isinstance(v, exp.Literal) and v.is_string: + return v.this + return v - if isinstance(v, (exp.Array, exp.Paren, exp.Tuple)): - tuples: t.List[exp.Expression] = ( - [v.unnest()] if 
isinstance(v, exp.Paren) else v.expressions - ) - signals = [parse_properties(cls, t, values) for t in tuples] - elif isinstance(v, list): - signals = [parse_properties(cls, t, values) for t in v] - else: - raise ConfigError(f"Unexpected signals '{v}'") + @field_validator("session_properties_", mode="before") + def session_properties_validator(cls, v: t.Any, info: ValidationInfo) -> t.Any: + # use the generic properties validator to parse the session properties + parsed_session_properties = parse_properties(type(cls), v, info) + if not parsed_session_properties: + return parsed_session_properties + + for eq in parsed_session_properties: + prop_name = eq.left.name + + if prop_name == "query_label": + query_label = eq.right + if not isinstance( + query_label, (exp.Array, exp.Tuple, exp.Paren, d.MacroFunc, d.MacroVar) + ): + raise ConfigError( + "Invalid value for `session_properties.query_label`. Must be an array or tuple." + ) - return signals + label_tuples: t.List[exp.Expression] = ( + [query_label.unnest()] + if isinstance(query_label, exp.Paren) + else query_label.expressions + ) + + for label_tuple in label_tuples: + if not ( + isinstance(label_tuple, exp.Tuple) + and len(label_tuple.expressions) == 2 + and all(isinstance(label, exp.Literal) for label in label_tuple.expressions) + ): + raise ConfigError( + "Invalid entry in `session_properties.query_label`. Must be tuples of string literals with length 2." + ) + elif prop_name == "authorization": + authorization = eq.right + if not ( + isinstance(authorization, exp.Literal) and authorization.is_string + ) and not isinstance(authorization, (d.MacroFunc, d.MacroVar)): + raise ConfigError( + "Invalid value for `session_properties.authorization`. Must be a string literal." 
+ ) + + return parsed_session_properties @model_validator(mode="before") - @model_validator_v1_args - def _pre_root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - grain = values.pop("grain", None) + def _pre_root_validator(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + grain = data.pop("grain", None) if grain: - grains = values.get("grains") + grains = data.get("grains") if grains: raise ConfigError( f"Cannot use argument 'grain' ({grain}) with 'grains' ({grains}), use only grains" ) - values["grains"] = ensure_list(grain) + data["grains"] = ensure_list(grain) - table_properties = values.pop("table_properties", None) + table_properties = data.pop("table_properties", None) if table_properties: if not isinstance(table_properties, str): # Do not warn when deserializing from the state. - model_name = values["name"] - logger.warning( + model_name = data["name"] + from sqlmesh.core.console import get_console + + get_console().log_warning( f"Model '{model_name}' is using the `table_properties` attribute which is deprecated. Please use `physical_properties` instead." ) - physical_properties = values.get("physical_properties") + physical_properties = data.get("physical_properties") if physical_properties: raise ConfigError( f"Cannot use argument 'table_properties' ({table_properties}) with 'physical_properties' ({physical_properties}), use only physical_properties." 
) - values["physical_properties"] = table_properties - return values + + data["physical_properties"] = table_properties + + return data @model_validator(mode="after") - @model_validator_v1_args - def _root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - values = cls._kind_validator(values) - return values + def _root_validator(self) -> Self: + kind: t.Any = self.kind + + for field in ("partitioned_by_", "clustered_by"): + if ( + getattr(self, field, None) + and not kind.is_materialized + and not (kind.is_view and kind.materialized) + ): + name = field[:-1] if field.endswith("_") else field + raise ValueError(f"{name} field cannot be set for {kind.name} models") + if kind.is_incremental_by_partition and not getattr(self, "partitioned_by_", None): + raise ValueError(f"partitioned_by field is required for {kind.name} models") + + # needs to be in a mode=after model validator so that the field validators have run to convert from Expression -> str + if (storage_format := self.storage_format) and storage_format.lower() in { + "iceberg", + "hive", + "hudi", + "delta", + }: + from sqlmesh.core.console import get_console + + get_console().log_warning( + f"Model {self.name} has `storage_format` set to a table format '{storage_format}' which is deprecated. Please use the `table_format` property instead." 
+ ) - @classmethod - def _kind_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - kind = values.get("kind") - if kind: - for field in ("partitioned_by_", "clustered_by"): - if values.get(field) and not kind.is_materialized: - raise ValueError(f"{field} field cannot be set for {kind} models") - if kind.is_incremental_by_partition and not values.get("partitioned_by_"): - raise ValueError(f"partitioned_by field is required for {kind.name} models") - return values + # Validate grants configuration for model kind support + if self.grants is not None and not kind.supports_grants: + raise ValueError(f"grants cannot be set for {kind.name} models") + + return self @property def time_column(self) -> t.Optional[TimeColumn]: @@ -372,14 +456,6 @@ def unique_key(self) -> t.List[exp.Expression]: return self.kind.unique_key return [] - @property - def partitioned_by(self) -> t.List[exp.Expression]: - if self.time_column and self.time_column.column not in [ - col for col in self._partition_by_columns - ]: - return [self.time_column.column, *self.partitioned_by_] - return self.partitioned_by_ - @property def column_descriptions(self) -> t.Dict[str, str]: """A dictionary of column names to annotation comments.""" @@ -388,7 +464,7 @@ def column_descriptions(self) -> t.Dict[str, str]: @property def lookback(self) -> int: """The incremental lookback window.""" - return (self.kind.lookback if isinstance(self.kind, _IncrementalBy) else 0) or 0 + return getattr(self.kind, "lookback", 0) or 0 def lookback_start(self, start: TimeLike) -> TimeLike: if self.lookback == 0: @@ -436,6 +512,30 @@ def custom_materialization_properties(self) -> CustomMaterializationProperties: return self.kind.materialization_properties return {} + @cached_property + def grants(self) -> t.Optional[GrantsConfig]: + """A dictionary of grants mapping permission names to lists of grantees.""" + + if self.grants_ is None: + return None + + if not self.grants_.expressions: + return {} + + grants_dict = {} 
+ for eq_expr in self.grants_.expressions: + try: + permission_name = self._validate_config_expression(eq_expr.left) + grantee_list = self._validate_nested_config_values(eq_expr.expression) + grants_dict[permission_name] = grantee_list + except ConfigError as e: + permission_name = ( + eq_expr.left.name if hasattr(eq_expr.left, "name") else str(eq_expr.left) + ) + raise ConfigError(f"Invalid grants configuration for '{permission_name}': {e}") + + return grants_dict if grants_dict else None + @property def all_references(self) -> t.List[Reference]: """All references including grains.""" @@ -444,19 +544,35 @@ def all_references(self) -> t.List[Reference]: ] @property - def _partition_by_columns(self) -> t.List[exp.Column]: - return [col for expr in self.partitioned_by_ for col in expr.find_all(exp.Column)] + def on(self) -> t.List[str]: + """The grains to be used as join condition in table_diff.""" + + on: t.List[str] = [] + for expr in [ref.expression for ref in self.all_references if ref.unique]: + if isinstance(expr, exp.Tuple): + on.extend([key.this.sql(dialect=self.dialect) for key in expr.expressions]) + else: + # Handle a single Column or Paren expression + on.append(expr.this.sql(dialect=self.dialect)) + + return on @property def managed_columns(self) -> t.Dict[str, exp.DataType]: return getattr(self.kind, "managed_columns", {}) @property - def when_matched(self) -> t.Optional[exp.When]: + def when_matched(self) -> t.Optional[exp.Whens]: if isinstance(self.kind, IncrementalByUniqueKeyKind): return self.kind.when_matched return None + @property + def merge_filter(self) -> t.Optional[exp.Expression]: + if isinstance(self.kind, IncrementalByUniqueKeyKind): + return self.kind.merge_filter + return None + @property def catalog(self) -> t.Optional[str]: """Returns the catalog of a model.""" @@ -475,3 +591,42 @@ def fqn(self) -> str: @property def on_destructive_change(self) -> OnDestructiveChange: return getattr(self.kind, "on_destructive_change", 
OnDestructiveChange.ALLOW) + + @property + def on_additive_change(self) -> OnAdditiveChange: + """Return the model's additive change setting if it has one.""" + return getattr(self.kind, "on_additive_change", OnAdditiveChange.ALLOW) + + @property + def ignored_rules(self) -> t.Set[str]: + return self.ignored_rules_ or set() + + def _validate_config_expression(self, expr: exp.Expression) -> str: + if isinstance(expr, (d.MacroFunc, d.MacroVar)): + raise ConfigError(f"Unresolved macro: {expr.sql(dialect=self.dialect)}") + + if isinstance(expr, exp.Null): + raise ConfigError("NULL value") + + if isinstance(expr, exp.Literal): + return str(expr.this).strip() + if isinstance(expr, (exp.Column, exp.Identifier)): + return expr.name + return expr.sql(dialect=self.dialect).strip() + + def _validate_nested_config_values(self, value_expr: exp.Expression) -> t.List[str]: + result = [] + + def flatten_expr(expr: exp.Expression) -> None: + if isinstance(expr, exp.Array): + for elem in expr.expressions: + flatten_expr(elem) + elif isinstance(expr, (exp.Tuple, exp.Paren)): + expressions = [expr.unnest()] if isinstance(expr, exp.Paren) else expr.expressions + for elem in expressions: + flatten_expr(elem) + else: + result.append(self._validate_config_expression(expr)) + + flatten_expr(value_expr) + return result diff --git a/sqlmesh/core/model/schema.py b/sqlmesh/core/model/schema.py new file mode 100644 index 0000000000..e29cacade0 --- /dev/null +++ b/sqlmesh/core/model/schema.py @@ -0,0 +1,100 @@ +from __future__ import annotations + +import typing as t +from concurrent.futures import as_completed +from pathlib import Path + +from sqlglot.errors import SchemaError +from sqlglot.schema import MappingSchema + +from sqlmesh.core.model.cache import ( + load_optimized_query_and_mapping, + optimized_query_cache_pool, + OptimizedQueryCache, +) + +if t.TYPE_CHECKING: + from sqlmesh.core.model.definition import Model + from sqlmesh.utils import UniqueKeyDict + from sqlmesh.utils.dag import 
DAG + + +def update_model_schemas( + dag: DAG[str], + models: UniqueKeyDict[str, Model], + cache_dir: Path, +) -> None: + schema = MappingSchema(normalize=False) + optimized_query_cache: OptimizedQueryCache = OptimizedQueryCache(cache_dir) + + _update_model_schemas(dag, models, schema, optimized_query_cache) + + +def _update_schema_with_model(schema: MappingSchema, model: Model) -> None: + columns_to_types = model.columns_to_types + if columns_to_types: + try: + schema.add_table(model.fqn, columns_to_types, dialect=model.dialect) + except SchemaError as e: + if "nesting level:" in str(e): + from sqlmesh.core.console import get_console + + get_console().log_error( + "SQLMesh requires all model names and references to have the same level of nesting." + ) + raise + + +def _update_model_schemas( + dag: DAG[str], + models: UniqueKeyDict[str, Model], + schema: MappingSchema, + optimized_query_cache: OptimizedQueryCache, +) -> None: + futures = set() + graph = { + model: {dep for dep in deps if dep in models} + for model, deps in dag._dag.items() + if model in models + } + + def process_models(completed_model: t.Optional[Model] = None) -> None: + for name in list(graph): + deps = graph[name] + + if completed_model: + deps.discard(completed_model.fqn) + + if not deps: + del graph[name] + model = models[name] + futures.add( + executor.submit( + load_optimized_query_and_mapping, + model, + mapping={ + parent: models[parent].columns_to_types + for parent in model.depends_on + if parent in models + }, + ) + ) + + with optimized_query_cache_pool(optimized_query_cache) as executor: + process_models() + + while futures: + for future in as_completed(futures): + try: + futures.remove(future) + fqn, entry_name, data_hash, metadata_hash, mapping_schema = future.result() + model = models[fqn] + model._data_hash = data_hash + model._metadata_hash = metadata_hash + if model.mapping_schema != mapping_schema: + model.set_mapping_schema(mapping_schema) + 
optimized_query_cache.with_optimized_query(model, entry_name) + _update_schema_with_model(schema, model) + process_models(completed_model=model) + except Exception as ex: + raise SchemaError(f"Failed to update model schemas\n\n{ex}") diff --git a/sqlmesh/core/model/seed.py b/sqlmesh/core/model/seed.py index c6dbb7245f..fe1aa85204 100644 --- a/sqlmesh/core/model/seed.py +++ b/sqlmesh/core/model/seed.py @@ -1,19 +1,28 @@ from __future__ import annotations +import logging import typing as t import zlib from io import StringIO from pathlib import Path -import pandas as pd from sqlglot import exp from sqlglot.dialects.dialect import UNESCAPED_SEQUENCES +from sqlglot.helper import seq_get from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlmesh.core.model.common import parse_bool from sqlmesh.utils.pandas import columns_to_types_from_df from sqlmesh.utils.pydantic import PydanticModel, field_validator +if t.TYPE_CHECKING: + import pandas as pd + +logger = logging.getLogger(__name__) + +NaHashables = t.List[t.Union[int, str, bool, t.Literal[None]]] +NaValues = t.Union[NaHashables, t.Dict[str, NaHashables]] + class CsvSettings(PydanticModel): """Settings for CSV seeds.""" @@ -25,8 +34,10 @@ class CsvSettings(PydanticModel): skipinitialspace: t.Optional[bool] = None lineterminator: t.Optional[str] = None encoding: t.Optional[str] = None + na_values: t.Optional[NaValues] = None + keep_default_na: t.Optional[bool] = None - @field_validator("doublequote", "skipinitialspace", mode="before") + @field_validator("doublequote", "skipinitialspace", "keep_default_na", mode="before") @classmethod def _bool_validator(cls, v: t.Any) -> t.Optional[bool]: if v is None: @@ -46,6 +57,36 @@ def _str_validator(cls, v: t.Any) -> t.Optional[str]: v = v.this return UNESCAPED_SEQUENCES.get(v, v) + @field_validator("na_values", mode="before") + @classmethod + def _na_values_validator(cls, v: t.Any) -> t.Optional[NaValues]: + if v is None or not isinstance(v, 
exp.Expression): + return v + + try: + if isinstance(v, exp.Paren) or not isinstance(v, (exp.Tuple, exp.Array)): + v = exp.Tuple(expressions=[v.unnest()]) + + expressions = v.expressions + if isinstance(seq_get(expressions, 0), (exp.PropertyEQ, exp.EQ)): + return { + e.left.name: [ + rhs_val.to_py() + for rhs_val in ( + [e.right.unnest()] + if isinstance(e.right, exp.Paren) + else e.right.expressions + ) + ] + for e in expressions + } + + return [e.to_py() for e in expressions] + except ValueError as e: + logger.warning(f"Failed to coerce na_values '{v}', proceeding with defaults. {str(e)}") + + return None + class CsvSeedReader: def __init__(self, content: str, dialect: str, settings: CsvSettings): @@ -76,6 +117,8 @@ def read(self, batch_size: t.Optional[int] = None) -> t.Generator[pd.DataFrame, batch_start += batch_size def _get_df(self) -> pd.DataFrame: + import pandas as pd + if self._df is None: self._df = pd.read_csv( StringIO(self.content), diff --git a/sqlmesh/core/node.py b/sqlmesh/core/node.py index a6a6bea8b2..4a3bf2564b 100644 --- a/sqlmesh/core/node.py +++ b/sqlmesh/core/node.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing as t +import zoneinfo from datetime import datetime from enum import Enum from pathlib import Path @@ -13,13 +14,14 @@ from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.pydantic import ( PydanticModel, + SQLGlotCron, field_validator, model_validator, - model_validator_v1_args, + PRIVATE_FIELDS, ) if t.TYPE_CHECKING: - from sqlmesh.core.audit import ModelAudit + from sqlmesh.core._typing import Self from sqlmesh.core.snapshot import Node @@ -29,6 +31,9 @@ class IntervalUnit(str, Enum): IntervalUnit can be one of 5 types, YEAR, MONTH, DAY, HOUR, MINUTE. The unit is inferred based on the cron schedule of a node. The minimum time delta between a sample set of dates is used to determine which unit a node's schedule is. 
+ + It's designed to align with common partitioning schemes, hence why there is no WEEK unit + because generally tables are not partitioned by week """ YEAR = "year" @@ -148,6 +153,101 @@ def milliseconds(self) -> int: return self.seconds * 1000 +class DbtNodeInfo(PydanticModel): + """ + Represents dbt-specific model information set by the dbt loader and intended to be made available at the Snapshot level + (as opposed to hidden within the individual model jinja macro registries). + + This allows for things like injecting implementations of variables / functions into the Jinja context that are compatible with + their dbt equivalents but are backed by the sqlmesh snapshots in any given plan / environment + """ + + unique_id: str + """This is the node/resource name/unique_id that's used as the node key in the dbt manifest. + It's prefixed by the resource type and is exposed in context variables like {{ selected_resources }}. + + Examples: + - test.jaffle_shop.unique_stg_orders_order_id.e3b841c71a + - seed.jaffle_shop.raw_payments + - model.jaffle_shop.stg_orders + """ + + name: str + """Name of this object in the dbt global namespace, used by things like {{ ref() }} calls. + + Examples: + - unique_stg_orders_order_id + - raw_payments + - stg_orders + """ + + fqn: str + """Used for selectors in --select/--exclude. + Takes the filesystem into account so may be structured differently to :unique_id. + + Examples: + - jaffle_shop.staging.unique_stg_orders_order_id + - jaffle_shop.raw_payments + - jaffle_shop.staging.stg_orders + """ + + alias: t.Optional[str] = None + """This is dbt's way of overriding the _physical table_ a model is written to. 
+ + It's used in the following situation: + - Say you have two models, "stg_customers" and "customers" + - You want "stg_customers" to be written to the "staging" schema as eg "staging.customers" - NOT "staging.stg_customers" + - But you cant rename the file to "customers" because it will conflict with your other model file "customers" + - Even if you put it in a different folder, eg "staging/customers.sql" - dbt still has a global namespace so it will conflict + when you try to do something like "{{ ref('customers') }}" + - So dbt's solution to this problem is to keep calling it "stg_customers" at the dbt project/model level, + but allow overriding the physical table to "customers" via something like "{{ config(alias='customers', schema='staging') }}" + + Note that if :alias is set, it does *not* replace :name at the model level and cannot be used interchangably with :name. + It also does not affect the :fqn or :unique_id. It's just used to override :name when it comes time to generate the physical table name. 
+ """ + + @model_validator(mode="after") + def post_init(self) -> Self: + # by default, dbt sets alias to the same as :name + # however, we only want to include :alias if it is actually different / actually providing an override + if self.alias == self.name: + self.alias = None + return self + + def to_expression(self) -> exp.Expression: + """Produce a SQLGlot expression representing this object, for use in things like the model/audit definition renderers""" + return exp.tuple_( + *( + exp.PropertyEQ(this=exp.var(k), expression=exp.Literal.string(v)) + for k, v in sorted(self.model_dump(exclude_none=True).items()) + ) + ) + + +class DbtInfoMixin: + """This mixin encapsulates properties that only exist for dbt compatibility and are otherwise not required + for native projects""" + + @property + def dbt_node_info(self) -> t.Optional[DbtNodeInfo]: + raise NotImplementedError() + + @property + def dbt_unique_id(self) -> t.Optional[str]: + """Used for compatibility with jinja context variables such as {{ selected_resources }}""" + if self.dbt_node_info: + return self.dbt_node_info.unique_id + return None + + @property + def dbt_fqn(self) -> t.Optional[str]: + """Used in the selector engine for compatibility with selectors that select models by dbt fqn""" + if self.dbt_node_info: + return self.dbt_node_info.fqn + return None + + # this must be sorted in descending order INTERVAL_SECONDS = { IntervalUnit.YEAR: 60 * 60 * 24 * 365, @@ -160,13 +260,13 @@ def milliseconds(self) -> int: } -class _Node(PydanticModel): +class _Node(DbtInfoMixin, PydanticModel): """ Node is the core abstraction for entity that can be executed within the scheduler. Args: name: The name of the node. - description: The name of the project this node belongs to, used in multi-repo deployments. + project: The name of the project this node belongs to, used in multi-repo deployments. description: The optional node description. owner: The owner of the node. 
start: The earliest date that the node will be executed for. If this is None, @@ -176,6 +276,7 @@ class _Node(PydanticModel): the date from the scheduler will be used cron: A cron string specifying how often the node should be run, leveraging the [croniter](https://github.com/kiorky/croniter) library. + cron_tz: Time zone for the cron, defaults to utc, [IANA time zones](https://docs.python.org/3/library/zoneinfo.html). interval_unit: The duration of an interval for the node. By default, it is computed from the cron expression. tags: A list of tags that can be used to filter nodes. stamp: An optional arbitrary string sequence used to create new node versions without making @@ -188,11 +289,15 @@ class _Node(PydanticModel): owner: t.Optional[str] = None start: t.Optional[TimeLike] = None end: t.Optional[TimeLike] = None - cron: str = "@daily" + cron: SQLGlotCron = "@daily" + cron_tz: t.Optional[zoneinfo.ZoneInfo] = None interval_unit_: t.Optional[IntervalUnit] = Field(alias="interval_unit", default=None) tags: t.List[str] = [] stamp: t.Optional[str] = None - _path: Path = Path() + dbt_node_info_: t.Optional[DbtNodeInfo] = Field(alias="dbt_node_info", default=None) + _path: t.Optional[Path] = None + _data_hash: t.Optional[str] = None + _metadata_hash: t.Optional[str] = None _croniter: t.Optional[CroniterCache] = None __inferred_interval_unit: t.Optional[IntervalUnit] = None @@ -201,6 +306,19 @@ def __str__(self) -> str: path = f": {self._path.name}" if self._path else "" return f"{self.__class__.__name__}<{self.name}{path}>" + def __getstate__(self) -> t.Dict[t.Any, t.Any]: + state = super().__getstate__() + private = state[PRIVATE_FIELDS] + private["_data_hash"] = None + private["_metadata_hash"] = None + return state + + def copy(self, **kwargs: t.Any) -> Self: + node = super().copy(**kwargs) + node._data_hash = None + node._metadata_hash = None + return node + @field_validator("name", mode="before") @classmethod def _name_validator(cls, v: t.Any) -> t.Optional[str]: 
@@ -210,6 +328,27 @@ def _name_validator(cls, v: t.Any) -> t.Optional[str]: return v.meta["sql"] return str(v) + @field_validator("cron_tz", mode="before") + def _cron_tz_validator(cls, v: t.Any) -> t.Optional[zoneinfo.ZoneInfo]: + if not v or v == "UTC": + return None + + v = str_or_exp_to_str(v) + + try: + return zoneinfo.ZoneInfo(v) + except Exception as e: + available_timezones = zoneinfo.available_timezones() + + if available_timezones: + raise ConfigError(f"{e}. {v} must be in {available_timezones}.") + else: + raise ConfigError( + f"{e}. IANA time zone data is not available on your system. `pip install tzdata` to leverage cron time zones or remove this field which will default to UTC." + ) + + return None + @field_validator("start", "end", mode="before") @classmethod def _date_validator(cls, v: t.Any) -> t.Optional[TimeLike]: @@ -219,19 +358,6 @@ def _date_validator(cls, v: t.Any) -> t.Optional[TimeLike]: raise ConfigError(f"'{v}' needs to be time-like: https://pypi.org/project/dateparser") return v - @field_validator("cron", mode="before") - @classmethod - def _cron_validator(cls, v: t.Any) -> t.Optional[str]: - cron = str_or_exp_to_str(v) - if cron: - from croniter import CroniterBadCronError, croniter - - try: - croniter(cron) - except CroniterBadCronError: - raise ConfigError(f"Invalid cron expression '{cron}'") - return cron - @field_validator("owner", "description", "stamp", mode="before") @classmethod def _string_expr_validator(cls, v: t.Any) -> t.Optional[str]: @@ -248,22 +374,24 @@ def _interval_unit_validator(cls, v: t.Any) -> t.Optional[t.Union[IntervalUnit, return v @model_validator(mode="after") - @model_validator_v1_args - def _node_root_validator(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - interval_unit = values.get("interval_unit_") - if interval_unit: - cron = values["cron"] + def _node_root_validator(self) -> Self: + interval_unit = self.interval_unit_ + if interval_unit and not getattr(self, "allow_partials", None): + cron = 
self.cron max_interval_unit = IntervalUnit.from_cron(cron) if interval_unit.seconds > max_interval_unit.seconds: raise ConfigError( - f"Interval unit of '{interval_unit}' is larger than cron period of '{cron}'" + f"Cron '{cron}' cannot be more frequent than interval unit '{interval_unit.value}'. " + "If this is intentional, set allow_partials to True." ) - start = values.get("start") - end = values.get("end") + + start = self.start + end = self.end + if end is not None and start is None: raise ConfigError("Must define a start date if an end date is defined.") validate_date_range(start, end) - return values + return self @property def batch_size(self) -> t.Optional[int]: @@ -275,16 +403,6 @@ def batch_concurrency(self) -> t.Optional[int]: """The maximal number of batches that can run concurrently for a backfill.""" return None - @property - def data_hash(self) -> str: - """ - Computes the data hash for the node. - - Returns: - The data hash for the node. - """ - raise NotImplementedError - @property def interval_unit(self) -> IntervalUnit: """Returns the interval unit using which data intervals are computed for this node.""" @@ -300,23 +418,55 @@ def depends_on(self) -> t.Set[str]: def fqn(self) -> str: return self.name - def metadata_hash(self, audits: t.Dict[str, ModelAudit]) -> str: + @property + def data_hash(self) -> str: """ - Computes the metadata hash for the node. + Computes the data hash for the node. - Args: - audits: Available audits by name. + Returns: + The data hash for the node. + """ + raise NotImplementedError + + @property + def metadata_hash(self) -> str: + """ + Computes the metadata hash for the node. Returns: The metadata hash for the node. """ raise NotImplementedError + def is_metadata_only_change(self, previous: _Node) -> bool: + """Determines if this node is a metadata only change in relation to the `previous` node. + + Args: + previous: The previous node to compare against. 
+ + Returns: + True if this node is a metadata only change, False otherwise. + """ + return self.data_hash == previous.data_hash and self.metadata_hash != previous.metadata_hash + + def is_data_change(self, previous: _Node) -> bool: + """Determines if this node is a data change in relation to the `previous` node. + + Args: + previous: The previous node to compare against. + + Returns: + True if this node is a data change, False otherwise. + """ + return ( + self.data_hash != previous.data_hash or self.metadata_hash != previous.metadata_hash + ) and not self.is_metadata_only_change(previous) + def croniter(self, value: TimeLike) -> CroniterCache: if self._croniter is None: - self._croniter = CroniterCache(self.cron, value) + self._croniter = CroniterCache(self.cron, value, tz=self.cron_tz) else: - self._croniter.curr = to_datetime(value) + self._croniter.curr = to_datetime(value, tz=self.cron_tz) return self._croniter def cron_next(self, value: TimeLike, estimate: bool = False) -> datetime: @@ -358,7 +508,7 @@ def cron_floor(self, value: TimeLike, estimate: bool = False) -> datetime: """ return self.croniter(self.cron_next(value, estimate=estimate)).get_prev(estimate=True) - def text_diff(self, other: Node) -> str: + def text_diff(self, other: Node, rendered: bool = False) -> str: """Produce a text diff against another node. 
Args: @@ -391,6 +541,10 @@ def is_audit(self) -> bool: """Return True if this is an audit node""" return False + @property + def dbt_node_info(self) -> t.Optional[DbtNodeInfo]: + return self.dbt_node_info_ + class NodeType(str, Enum): MODEL = "model" diff --git a/sqlmesh/core/notification_target.py b/sqlmesh/core/notification_target.py index 7eaf1c3aaf..fba6e36f66 100644 --- a/sqlmesh/core/notification_target.py +++ b/sqlmesh/core/notification_target.py @@ -13,16 +13,6 @@ from sqlmesh.utils.errors import AuditError, ConfigError, MissingDependencyError from sqlmesh.utils.pydantic import PydanticModel -if sys.version_info >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal - -if sys.version_info >= (3, 9): - from typing import Annotated -else: - from typing_extensions import Annotated - if t.TYPE_CHECKING: from slack_sdk import WebClient, WebhookClient @@ -80,8 +70,7 @@ class NotificationEvent(str, Enum): class BaseNotificationTarget(PydanticModel, frozen=True): """ Base notification target model. Provides a command for sending notifications that is currently only used - by the built-in scheduler. Other schedulers like Airflow use the configuration of the target itself - to create the notification constructs appropriate for the scheduler. + by the built-in scheduler. Notification functions follow a naming convention of `notify_` + NotificationEvent value. """ @@ -97,7 +86,9 @@ def send(self, notification_status: NotificationStatus, msg: str, **kwargs: t.An msg: The message to send. """ - def notify_apply_start(self, environment: str, plan_id: str) -> None: + def notify_apply_start( + self, environment: str, plan_id: str, *args: t.Any, **kwargs: t.Any + ) -> None: """Notify when an apply starts. 
Args: @@ -109,7 +100,9 @@ def notify_apply_start(self, environment: str, plan_id: str) -> None: f"Plan `{plan_id}` apply started for environment `{environment}`.", ) - def notify_apply_end(self, environment: str, plan_id: str) -> None: + def notify_apply_end( + self, environment: str, plan_id: str, *args: t.Any, **kwargs: t.Any + ) -> None: """Notify when an apply ends. Args: @@ -121,7 +114,7 @@ def notify_apply_end(self, environment: str, plan_id: str) -> None: f"Plan `{plan_id}` apply finished for environment `{environment}`.", ) - def notify_run_start(self, environment: str) -> None: + def notify_run_start(self, environment: str, *args: t.Any, **kwargs: t.Any) -> None: """Notify when a SQLMesh run starts. Args: @@ -129,7 +122,7 @@ def notify_run_start(self, environment: str) -> None: """ self.send(NotificationStatus.INFO, f"SQLMesh run started for environment `{environment}`.") - def notify_run_end(self, environment: str) -> None: + def notify_run_end(self, environment: str, *args: t.Any, **kwargs: t.Any) -> None: """Notify when a SQLMesh run ends. Args: @@ -139,15 +132,17 @@ def notify_run_end(self, environment: str) -> None: NotificationStatus.SUCCESS, f"SQLMesh run finished for environment `{environment}`." ) - def notify_migration_start(self) -> None: + def notify_migration_start(self, *args: t.Any, **kwargs: t.Any) -> None: """Notify when a SQLMesh migration starts.""" self.send(NotificationStatus.INFO, "SQLMesh migration started.") - def notify_migration_end(self) -> None: + def notify_migration_end(self, *args: t.Any, **kwargs: t.Any) -> None: """Notify when a SQLMesh migration ends.""" self.send(NotificationStatus.SUCCESS, "SQLMesh migration finished.") - def notify_apply_failure(self, environment: str, plan_id: str, exc: str) -> None: + def notify_apply_failure( + self, environment: str, plan_id: str, exc: str, *args: t.Any, **kwargs: t.Any + ) -> None: """Notify in the case of an apply failure. 
Args: @@ -161,7 +156,7 @@ def notify_apply_failure(self, environment: str, plan_id: str, exc: str) -> None exc=exc, ) - def notify_run_failure(self, exc: str) -> None: + def notify_run_failure(self, exc: str, *args: t.Any, **kwargs: t.Any) -> None: """Notify in the case of a run failure. Args: @@ -169,7 +164,7 @@ def notify_run_failure(self, exc: str) -> None: """ self.send(NotificationStatus.FAILURE, "SQLMesh run failed.", exc=exc) - def notify_audit_failure(self, audit_error: AuditError) -> None: + def notify_audit_failure(self, audit_error: AuditError, *args: t.Any, **kwargs: t.Any) -> None: """Notify in the case of an audit failure. Args: @@ -177,7 +172,7 @@ def notify_audit_failure(self, audit_error: AuditError) -> None: """ self.send(NotificationStatus.FAILURE, "Audit failure.", audit_error=audit_error) - def notify_migration_failure(self, exc: str) -> None: + def notify_migration_failure(self, exc: str, *args: t.Any, **kwargs: t.Any) -> None: """Notify in the case of a migration failure. Args: @@ -220,7 +215,7 @@ class ConsoleNotificationTarget(BaseTextBasedNotificationTarget): Example console notification target. Keeping this around for testing purposes. 
""" - type_: Literal["console"] = Field(alias="type", default="console") + type_: t.Literal["console"] = Field(alias="type", default="console") _console: t.Optional[Console] = None @property @@ -283,6 +278,8 @@ def send( ), ) + composed.add_text(msg) + self._send_slack_message( composed=composed.slack_message, ) @@ -297,7 +294,7 @@ def _send_slack_message(self, composed: slack.TSlackMessage) -> None: class SlackWebhookNotificationTarget(BaseSlackNotificationTarget): url: t.Optional[str] = None - type_: Literal["slack_webhook"] = Field(alias="type", default="slack_webhook") + type_: t.Literal["slack_webhook"] = Field(alias="type", default="slack_webhook") _client: t.Optional[WebhookClient] = None @property @@ -318,6 +315,7 @@ def client(self) -> WebhookClient: def _send_slack_message(self, composed: slack.TSlackMessage) -> None: self.client.send( + text=composed["text"], blocks=composed["blocks"], attachments=composed["attachments"], # type: ignore ) @@ -330,7 +328,7 @@ def is_configured(self) -> bool: class SlackApiNotificationTarget(BaseSlackNotificationTarget): token: t.Optional[str] = None channel: t.Optional[str] = None - type_: Literal["slack_api"] = Field(alias="type", default="slack_api") + type_: t.Literal["slack_api"] = Field(alias="type", default="slack_api") _client: t.Optional[WebClient] = None @property @@ -352,6 +350,7 @@ def _send_slack_message(self, composed: slack.TSlackMessage) -> None: self.client.chat_postMessage( channel=self.channel, + text=composed["text"], blocks=composed["blocks"], attachments=composed["attachments"], # type: ignore ) @@ -369,7 +368,7 @@ class BasicSMTPNotificationTarget(BaseTextBasedNotificationTarget): sender: t.Optional[str] = None recipients: t.Optional[t.FrozenSet[str]] = None subject: t.Optional[str] = "SQLMesh Notification" - type_: Literal["smtp"] = Field(alias="type", default="smtp") + type_: t.Literal["smtp"] = Field(alias="type", default="smtp") def send_text_message( self, @@ -394,9 +393,36 @@ def 
is_configured(self) -> bool: return all((self.host, self.user, self.password, self.sender)) -NotificationTarget = Annotated[ +class GenericNotificationTarget(BaseNotificationTarget): + """A generic notification target that can be used to create custom notification targets. + + This target is not meant to be used directly, but rather as a base class for custom notification targets. + + The `send` method should be overridden to provide the actual notification functionality. + + Example: + ```python + class MyCustomNotificationTarget(GenericNotificationTarget): + def send(self, notification_status: NotificationStatus, msg: str, audit_error: t.Optional[AuditError] = None, exc: t.Optional[str] = None, **kwargs: t.Any) -> None: + error = None + if audit_error: + error = str(audit_error) + elif exc: + error = exc + + if error: + msg = f"{error} - {msg}" + print(f"Sending notification: {msg}") + ``` + """ + + type_: t.Literal["generic"] = Field(alias="type", default="generic") + + +NotificationTarget = t.Annotated[ t.Union[ BasicSMTPNotificationTarget, + GenericNotificationTarget, ConsoleNotificationTarget, SlackApiNotificationTarget, SlackWebhookNotificationTarget, diff --git a/sqlmesh/core/plan/__init__.py b/sqlmesh/core/plan/__init__.py index 139979e24e..8b3ba63e55 100644 --- a/sqlmesh/core/plan/__init__.py +++ b/sqlmesh/core/plan/__init__.py @@ -1,13 +1,12 @@ from sqlmesh.core.plan.builder import PlanBuilder as PlanBuilder from sqlmesh.core.plan.definition import ( Plan as Plan, + EvaluatablePlan as EvaluatablePlan, PlanStatus as PlanStatus, SnapshotIntervals as SnapshotIntervals, ) from sqlmesh.core.plan.evaluator import ( - AirflowPlanEvaluator as AirflowPlanEvaluator, BuiltInPlanEvaluator as BuiltInPlanEvaluator, - MWAAPlanEvaluator as MWAAPlanEvaluator, PlanEvaluator as PlanEvaluator, - update_intervals_for_new_snapshots as update_intervals_for_new_snapshots, ) +from sqlmesh.core.plan.explainer import PlanExplainer as PlanExplainer diff --git 
a/sqlmesh/core/plan/builder.py b/sqlmesh/core/plan/builder.py index 4fdd9cc6ad..01834594cd 100644 --- a/sqlmesh/core/plan/builder.py +++ b/sqlmesh/core/plan/builder.py @@ -2,14 +2,13 @@ import logging import re -import sys import typing as t from collections import defaultdict -from datetime import datetime from functools import cached_property +from datetime import datetime -from sqlmesh.core.console import Console, get_console +from sqlmesh.core.console import PlanBuilderConsole, get_console from sqlmesh.core.config import ( AutoCategorizationMode, CategorizerConfig, @@ -17,19 +16,38 @@ ) from sqlmesh.core.context_diff import ContextDiff from sqlmesh.core.environment import EnvironmentNamingInfo -from sqlmesh.core.plan.definition import Plan, SnapshotMapping, earliest_interval_start -from sqlmesh.core.schema_diff import SchemaDiffer, has_drop_alteration +from sqlmesh.core.plan.common import should_force_rebuild, is_breaking_kind_change +from sqlmesh.core.plan.definition import ( + Plan, + SnapshotMapping, + UserProvidedFlags, + earliest_interval_start, +) +from sqlmesh.core.schema_diff import ( + get_schema_differ, + has_drop_alteration, + has_additive_alteration, + TableAlterOperation, +) from sqlmesh.core.snapshot import ( DeployabilityIndex, Snapshot, SnapshotChangeCategory, ) from sqlmesh.core.snapshot.categorizer import categorize_change -from sqlmesh.core.snapshot.definition import Interval, SnapshotId, start_date +from sqlmesh.core.snapshot.definition import Interval, SnapshotId from sqlmesh.utils import columns_to_types_all_known, random_id from sqlmesh.utils.dag import DAG -from sqlmesh.utils.date import TimeLike, now, to_datetime, yesterday_ds -from sqlmesh.utils.errors import NoChangesPlanError, PlanError, SQLMeshError +from sqlmesh.utils.date import ( + TimeLike, + now, + to_datetime, + yesterday_ds, + to_timestamp, + time_like_to_str, + is_relative, +) +from sqlmesh.utils.errors import NoChangesPlanError, PlanError logger = 
logging.getLogger(__name__) @@ -42,18 +60,24 @@ class PlanBuilder: start: The start time to backfill data. end: The end time to backfill data. execution_time: The date/time time reference to use for execution time. Defaults to now. + If :start or :end are relative time expressions, they are interpreted as relative to the :execution_time apply: The callback to apply the plan. restate_models: A list of models for which the data should be restated for the time range specified in this plan. Note: models defined outside SQLMesh (external) won't be a part of the restatement. + restate_all_snapshots: If restatements are present, this flag indicates whether or not the intervals + being restated should be cleared from state for other versions of this model (typically, versions that are present in other environments). + If set to None, the default behaviour is to not clear anything unless the target environment is prod. backfill_models: A list of fully qualified model names for which the data should be backfilled as part of this plan. no_gaps: Whether to ensure that new snapshots for nodes that are already a part of the target environment have no data gaps when compared against previous snapshots for same nodes. skip_backfill: Whether to skip the backfill step. + empty_backfill: Like skip_backfill, but also records processed intervals. is_dev: Whether this plan is for development purposes. forward_only: Whether the purpose of the plan is to make forward only changes. allow_destructive_models: A list of fully qualified model names whose forward-only changes are allowed to be destructive. + allow_additive_models: A list of fully qualified model names whose forward-only changes are allowed to be additive. environment_ttl: The period of time that a development environment should exist before being deleted. categorizer_config: Auto categorization settings. auto_categorization_enabled: Whether to apply auto categorization. 
@@ -68,24 +92,29 @@ class PlanBuilder: ensure_finalized_snapshots: Whether to compare against snapshots from the latest finalized environment state, or to use whatever snapshots are in the current environment state even if the environment is not finalized. - engine_schema_differ: Schema differ from the context engine adapter. + start_override_per_model: A mapping of model FQNs to target start dates. + end_override_per_model: A mapping of model FQNs to target end dates. + ignore_cron: Whether to ignore the node's cron schedule when computing missing intervals. + explain: Whether to explain the plan instead of applying it. """ def __init__( self, context_diff: ContextDiff, - engine_schema_differ: SchemaDiffer, start: t.Optional[TimeLike] = None, end: t.Optional[TimeLike] = None, execution_time: t.Optional[TimeLike] = None, apply: t.Optional[t.Callable[[Plan], None]] = None, restate_models: t.Optional[t.Iterable[str]] = None, + restate_all_snapshots: bool = False, backfill_models: t.Optional[t.Iterable[str]] = None, no_gaps: bool = False, skip_backfill: bool = False, + empty_backfill: bool = False, is_dev: bool = False, forward_only: bool = False, allow_destructive_models: t.Optional[t.Iterable[str]] = None, + allow_additive_models: t.Optional[t.Iterable[str]] = None, environment_ttl: t.Optional[str] = None, environment_suffix_target: EnvironmentSuffixTarget = EnvironmentSuffixTarget.default, environment_catalog_mapping: t.Optional[t.Dict[re.Pattern, str]] = None, @@ -98,34 +127,61 @@ def __init__( enable_preview: bool = False, end_bounded: bool = False, ensure_finalized_snapshots: bool = False, - console: t.Optional[Console] = None, + explain: bool = False, + ignore_cron: bool = False, + start_override_per_model: t.Optional[t.Dict[str, datetime]] = None, + end_override_per_model: t.Optional[t.Dict[str, datetime]] = None, + console: t.Optional[PlanBuilderConsole] = None, + user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None, + selected_models: 
t.Optional[t.Set[str]] = None, ): self._context_diff = context_diff self._no_gaps = no_gaps self._skip_backfill = skip_backfill + self._empty_backfill = empty_backfill self._is_dev = is_dev self._forward_only = forward_only self._allow_destructive_models = set( allow_destructive_models if allow_destructive_models is not None else [] ) + self._allow_additive_models = set( + allow_additive_models if allow_additive_models is not None else [] + ) self._enable_preview = enable_preview self._end_bounded = end_bounded self._ensure_finalized_snapshots = ensure_finalized_snapshots + self._ignore_cron = ignore_cron + self._start_override_per_model = start_override_per_model + self._end_override_per_model = end_override_per_model self._environment_ttl = environment_ttl self._categorizer_config = categorizer_config or CategorizerConfig() self._auto_categorization_enabled = auto_categorization_enabled self._include_unmodified = include_unmodified self._restate_models = set(restate_models) if restate_models is not None else None + self._restate_all_snapshots = restate_all_snapshots self._effective_from = effective_from + + # note: this deliberately doesnt default to now() here. 
+ # There may be an significant delay between the PlanBuilder producing a Plan and the Plan actually being run + # so if execution_time=None is passed to the PlanBuilder, then the resulting Plan should also have execution_time=None + # in order to prevent the Plan that was intended to run "as at now" from having "now" fixed to some time in the past + # ref: https://github.com/SQLMesh/sqlmesh/pull/4702#discussion_r2140696156 self._execution_time = execution_time + self._backfill_models = backfill_models self._end = end or default_end + self._default_start = default_start self._apply = apply - self._engine_schema_differ = engine_schema_differ self._console = console or get_console() + self._choices: t.Dict[SnapshotId, SnapshotChangeCategory] = {} + self._user_provided_flags = user_provided_flags + self._selected_models = selected_models + self._explain = explain self._start = start - if not self._start and self._forward_only_preview_needed: + if not self._start and ( + self._forward_only_preview_needed or self._non_forward_only_preview_needed + ): self._start = default_start or yesterday_ds() self._plan_id: str = random_id() @@ -138,6 +194,7 @@ def __init__( name=self._context_diff.environment, suffix_target=environment_suffix_target, normalize_name=self._context_diff.normalize_environment_name, + gateway_managed=self._context_diff.gateway_managed_virtual_layer, ) self._latest_plan: t.Optional[Plan] = None @@ -147,6 +204,26 @@ def is_start_and_end_allowed(self) -> bool: """Indicates whether this plan allows to set the start and end dates.""" return self._is_dev or bool(self._restate_models) + @property + def start(self) -> t.Optional[TimeLike]: + if self._start and is_relative(self._start): + # only do this for relative expressions otherwise inclusive date strings like '2020-01-01' can be turned into exclusive timestamps eg '2020-01-01 00:00:00' + return to_datetime(self._start, relative_base=to_datetime(self.execution_time)) + return self._start + + @property + def 
end(self) -> t.Optional[TimeLike]: + if self._end and is_relative(self._end): + # only do this for relative expressions otherwise inclusive date strings like '2020-01-01' can be turned into exclusive timestamps eg '2020-01-01 00:00:00' + return to_datetime(self._end, relative_base=to_datetime(self.execution_time)) + return self._end + + @cached_property + def execution_time(self) -> TimeLike: + # this is cached to return a stable value from now() in the places where the execution time matters for resolving relative date strings + # during the plan building process + return self._execution_time or now() + def set_start(self, new_start: TimeLike) -> PlanBuilder: self._start = new_start self.override_start = True @@ -180,15 +257,24 @@ def set_choice(self, snapshot: Snapshot, choice: SnapshotChangeCategory) -> Plan snapshot: The target snapshot. choice: The user decision on how to version the target snapshot and its children. """ - plan = self.build() - self._set_choice(snapshot, choice, plan.directly_modified, plan.indirectly_modified) - self._adjust_new_snapshot_intervals() + if not self._is_new_snapshot(snapshot): + raise PlanError( + f"A choice can't be changed for the existing version of {snapshot.name}." 
+ ) + if ( + not self._context_diff.directly_modified(snapshot.name) + and snapshot.snapshot_id not in self._context_diff.added + ): + raise PlanError(f"Only directly modified models can be categorized ({snapshot.name}).") + + self._choices[snapshot.snapshot_id] = choice + self._latest_plan = None return self def apply(self) -> None: """Builds and applies the plan.""" if not self._apply: - raise SQLMeshError("Plan was not initialized with an applier.") + raise PlanError("Plan was not initialized with an applier.") self._apply(self.build()) def build(self) -> Plan: @@ -196,10 +282,8 @@ def build(self) -> Plan: if self._latest_plan: return self._latest_plan - self._ensure_no_new_snapshots_with_restatements() self._ensure_new_env_with_changes() self._ensure_valid_date_range() - self._ensure_no_forward_only_revert() self._ensure_no_broken_references() self._apply_effective_from() @@ -207,62 +291,71 @@ def build(self) -> Plan: dag = self._build_dag() directly_modified, indirectly_modified = self._build_directly_and_indirectly_modified(dag) - self._check_destructive_changes(directly_modified) - self._categorize_snapshots(dag, directly_modified, indirectly_modified) - self._adjust_new_snapshot_intervals() + self._check_destructive_additive_changes(directly_modified) + self._categorize_snapshots(dag, indirectly_modified) + self._adjust_snapshot_intervals() deployability_index = ( - DeployabilityIndex.create(self._context_diff.snapshots.values()) + DeployabilityIndex.create( + self._context_diff.snapshots.values(), + start=self._start, + start_override_per_model=self._start_override_per_model, + ) if self._is_dev else DeployabilityIndex.all_deployable() ) - filtered_dag, ignored = self._build_filtered_dag(dag, deployability_index) - - # Exclude ignored snapshots from the modified sets. 
- directly_modified = {s_id for s_id in directly_modified if s_id not in ignored} - for s_id in list(indirectly_modified): - if s_id in ignored: - indirectly_modified.pop(s_id, None) - else: - indirectly_modified[s_id] = { - s_id for s_id in indirectly_modified[s_id] if s_id not in ignored - } - - filtered_snapshots = { - s.snapshot_id: s - for s in self._context_diff.snapshots.values() - if s.snapshot_id not in ignored - } - - models_to_backfill = self._build_models_to_backfill(filtered_dag) restatements = self._build_restatements( - dag, earliest_interval_start(filtered_snapshots.values()) + dag, + earliest_interval_start(self._context_diff.snapshots.values(), self.execution_time), ) + models_to_backfill = self._build_models_to_backfill(dag, restatements) + + end_override_per_model = self._end_override_per_model + if end_override_per_model and self.override_end: + # If the end date was provided explicitly by a user, then interval end for each individual + # model should be ignored. + end_override_per_model = None + + # this deliberately uses the passed in self._execution_time and not self.execution_time cached property + # the reason is because that there can be a delay between the Plan being built and the Plan being actually run, + # so this ensures that an _execution_time of None can be propagated to the Plan and thus be re-resolved to + # the current timestamp of when the Plan is eventually run + plan_execution_time = self._execution_time plan = Plan( context_diff=self._context_diff, plan_id=self._plan_id, - provided_start=self._start, - provided_end=self._end, + provided_start=self.start, + provided_end=self.end, is_dev=self._is_dev, skip_backfill=self._skip_backfill, + empty_backfill=self._empty_backfill, no_gaps=self._no_gaps, forward_only=self._forward_only, + explain=self._explain, allow_destructive_models=t.cast(t.Set, self._allow_destructive_models), + allow_additive_models=t.cast(t.Set, self._allow_additive_models), 
include_unmodified=self._include_unmodified, environment_ttl=self._environment_ttl, environment_naming_info=self.environment_naming_info, directly_modified=directly_modified, indirectly_modified=indirectly_modified, - ignored=ignored, deployability_index=deployability_index, + selected_models_to_restate=self._restate_models, restatements=restatements, + restate_all_snapshots=self._restate_all_snapshots, + start_override_per_model=self._start_override_per_model, + end_override_per_model=end_override_per_model, + selected_models_to_backfill=self._backfill_models, models_to_backfill=models_to_backfill, effective_from=self._effective_from, - execution_time=self._execution_time, + execution_time=plan_execution_time, end_bounded=self._end_bounded, ensure_finalized_snapshots=self._ensure_finalized_snapshots, + ignore_cron=self._ignore_cron, + user_provided_flags=self._user_provided_flags, + selected_models=self._selected_models, ) self._latest_plan = plan return plan @@ -273,38 +366,9 @@ def _build_dag(self) -> DAG[SnapshotId]: dag.add(s_id, context_snapshot.parents) return dag - def _build_filtered_dag( - self, full_dag: DAG[SnapshotId], deployability_index: DeployabilityIndex - ) -> t.Tuple[DAG[SnapshotId], t.Set[SnapshotId]]: - ignored_snapshot_ids: t.Set[SnapshotId] = set() - filtered_dag: DAG[SnapshotId] = DAG() - cache: t.Optional[t.Dict[str, datetime]] = {} - for s_id in full_dag: - snapshot = self._context_diff.snapshots.get(s_id) - # If the snapshot doesn't exist then it must be an external model - if not snapshot: - continue - - is_deployable = deployability_index.is_deployable(s_id) - is_valid_start = snapshot.is_valid_start( - self._start, start_date(snapshot, self._context_diff.snapshots.values(), cache) - ) - if set(snapshot.parents).isdisjoint(ignored_snapshot_ids) and ( - not is_deployable or is_valid_start - ): - filtered_dag.add(s_id, snapshot.parents) - else: - ignored_snapshot_ids.add(s_id) - return filtered_dag, ignored_snapshot_ids - def 
_build_restatements( self, dag: DAG[SnapshotId], earliest_interval_start: TimeLike ) -> t.Dict[SnapshotId, Interval]: - def is_restateable_snapshot(snapshot: Snapshot) -> bool: - if not self._is_dev and snapshot.disable_restatement: - return False - return not snapshot.is_symbolic and not snapshot.is_seed - restate_models = self._restate_models if restate_models == set(): # This is a warning but we print this as error since the Console is lacking API for warnings. @@ -322,8 +386,10 @@ def is_restateable_snapshot(snapshot: Snapshot) -> bool: restate_models = { s.name for s in self._context_diff.new_snapshots.values() - if s.is_materialized - and (self._forward_only or s.model.forward_only) + if s.is_model + and not s.is_symbolic + and (s.is_forward_only or s.model.forward_only) + and not s.is_no_preview and ( # Metadata changes should not be previewed. self._context_diff.directly_modified(s.name) @@ -335,49 +401,77 @@ def is_restateable_snapshot(snapshot: Snapshot) -> bool: if not restate_models: return {} + start = self._start or earliest_interval_start + end = self._end or now() + # Add restate snapshots and their downstream snapshots - dummy_interval = (sys.maxsize, -sys.maxsize) for model_fqn in restate_models: - snapshot = self._model_fqn_to_snapshot.get(model_fqn) - if not snapshot: + if model_fqn not in self._model_fqn_to_snapshot: raise PlanError(f"Cannot restate model '{model_fqn}'. Model does not exist.") - if not forward_only_preview_needed: - if not self._is_dev and snapshot.disable_restatement: - # This is a warning but we print this as error since the Console is lacking API for warnings. - self._console.log_error( - f"Cannot restate model '{model_fqn}'. Restatement is disabled for this model." 
- ) - continue - elif snapshot.is_symbolic or snapshot.is_seed: - logger.info("Skipping restatement for model '%s'", model_fqn) - continue - restatements[snapshot.snapshot_id] = dummy_interval - for downstream_s_id in dag.downstream(snapshot.snapshot_id): - if is_restateable_snapshot(self._context_diff.snapshots[downstream_s_id]): - restatements[downstream_s_id] = dummy_interval - # Get restatement intervals for all restated snapshots and make sure that if a snapshot expands it's + # Get restatement intervals for all restated snapshots and make sure that if an incremental snapshot expands it's # restatement range that it's downstream dependencies all expand their restatement ranges as well. for s_id in dag: - if s_id not in restatements: - continue snapshot = self._context_diff.snapshots[s_id] - interval = snapshot.get_removal_interval( - self._start or earliest_interval_start, - self._end or now(), - self._execution_time, - strict=False, - is_preview=is_preview, - ) + + if is_preview and snapshot.is_no_preview: + continue + # Since we are traversing the graph in topological order and the largest interval range is pushed down # the graph we just have to check our immediate parents in the graph and not the whole upstream graph. - snapshot_dependencies = snapshot.parents - possible_intervals = [ - restatements.get(s, dummy_interval) for s in snapshot_dependencies - ] + [interval] + restating_parents = [ + self._context_diff.snapshots[s] for s in snapshot.parents if s in restatements + ] + + if not restating_parents and snapshot.name not in restate_models: + continue + + if not forward_only_preview_needed: + if self._is_dev and not snapshot.is_paused: + self._console.log_warning( + f"Cannot restate model '{snapshot.name}' because the current version is used in production. " + "Run the restatement against the production environment instead to restate this model." 
+ ) + continue + elif (not self._is_dev or not snapshot.is_paused) and snapshot.disable_restatement: + self._console.log_warning( + f"Cannot restate model '{snapshot.name}'. " + "Restatement is disabled for this model to prevent possible data loss. " + "If you want to restate this model, change the model's `disable_restatement` setting to `false`." + ) + continue + elif snapshot.is_seed: + logger.info("Skipping restatement for model '%s'", snapshot.name) + continue + + possible_intervals = { + restatements[p.snapshot_id] for p in restating_parents if p.is_incremental + } + possible_intervals.add( + snapshot.get_removal_interval( + start, + end, + self._execution_time, + strict=False, + is_preview=is_preview, + ) + ) snapshot_start = min(i[0] for i in possible_intervals) snapshot_end = max(i[1] for i in possible_intervals) + + # We may be tasked with restating a time range smaller than the target snapshot interval unit + # For example, restating an hour of Hourly Model A, which has a downstream dependency of Daily Model B + # we need to ensure the whole affected day in Model B is restated + floored_snapshot_start = snapshot.node.interval_unit.cron_floor(snapshot_start) + floored_snapshot_end = snapshot.node.interval_unit.cron_floor(snapshot_end) + if to_timestamp(floored_snapshot_end) < snapshot_end: + snapshot_start = to_timestamp(floored_snapshot_start) + snapshot_end = to_timestamp( + snapshot.node.interval_unit.cron_next(floored_snapshot_end) + ) + restatements[s_id] = (snapshot_start, snapshot_end) + return restatements def _build_directly_and_indirectly_modified( @@ -412,47 +506,61 @@ def _build_directly_and_indirectly_modified( indirectly_modified, ) - def _build_models_to_backfill(self, dag: DAG[SnapshotId]) -> t.Optional[t.Set[str]]: - if self._backfill_models is None: + def _build_models_to_backfill( + self, dag: DAG[SnapshotId], restatements: t.Collection[SnapshotId] + ) -> t.Optional[t.Set[str]]: + backfill_models = ( + self._backfill_models + if 
self._backfill_models is not None + else [r.name for r in restatements] + # Only backfill models explicitly marked for restatement. + if self._restate_models + else None + ) + if backfill_models is None: return None - if not self._is_dev: - raise PlanError( - "Selecting models to backfill is only supported for development environments." - ) return { self._context_diff.snapshots[s_id].name for s_id in dag.subdag( *[ self._model_fqn_to_snapshot[m].snapshot_id - for m in self._backfill_models + for m in backfill_models if m in self._model_fqn_to_snapshot ] ).sorted } - def _adjust_new_snapshot_intervals(self) -> None: - old_snapshots = { - (old.name, old.version_get_or_generate()): old - for _, old in self._context_diff.modified_snapshots.values() - } - - for new in self._context_diff.new_snapshots.values(): - new.intervals = [] - new.dev_intervals = [] - old = old_snapshots.get((new.name, new.version_get_or_generate())) - if not old: + def _adjust_snapshot_intervals(self) -> None: + for new, old in self._context_diff.modified_snapshots.values(): + if not new.is_model or not old.is_model: continue - new.merge_intervals(old) - if new.is_forward_only: - new.dev_intervals = new.intervals.copy() - - def _check_destructive_changes(self, directly_modified: t.Set[SnapshotId]) -> None: + is_same_version = old.version_get_or_generate() == new.version_get_or_generate() + if is_same_version and should_force_rebuild(old, new): + # If the difference between 2 snapshots requires a full rebuild, + # then clear the intervals for the new snapshot. 
+ self._context_diff.snapshots[new.snapshot_id].intervals = [] + elif new.snapshot_id in self._context_diff.new_snapshots: + new.intervals = [] + new.dev_intervals = [] + if is_same_version: + new.merge_intervals(old) + if new.is_forward_only: + new.dev_intervals = new.intervals.copy() + + def _check_destructive_additive_changes(self, directly_modified: t.Set[SnapshotId]) -> None: for s_id in sorted(directly_modified): + if s_id.name not in self._context_diff.modified_snapshots: + continue + snapshot = self._context_diff.snapshots[s_id] + needs_destructive_check = snapshot.needs_destructive_check( + self._allow_destructive_models + ) + needs_additive_check = snapshot.needs_additive_check(self._allow_additive_models) # should we raise/warn if this snapshot has/inherits a destructive change? - should_raise_or_warn = ( - self._is_forward_only_change(s_id) or self._forward_only - ) and snapshot.needs_destructive_check(self._allow_destructive_models) + should_raise_or_warn = (self._is_forward_only_change(s_id) or self._forward_only) and ( + needs_destructive_check or needs_additive_check + ) if not should_raise_or_warn or not snapshot.is_model: continue @@ -466,26 +574,44 @@ def _check_destructive_changes(self, directly_modified: t.Set[SnapshotId]) -> No if columns_to_types_all_known(old_columns_to_types) and columns_to_types_all_known( new_columns_to_types ): - schema_diff = self._engine_schema_differ.compare_columns( - new.name, - old_columns_to_types, - new_columns_to_types, + alter_operations = t.cast( + t.List[TableAlterOperation], + get_schema_differ(snapshot.model.dialect).compare_columns( + new.name, + old_columns_to_types, + new_columns_to_types, + ignore_destructive=new.model.on_destructive_change.is_ignore, + ignore_additive=new.model.on_additive_change.is_ignore, + ), ) - if has_drop_alteration(schema_diff): - warning_msg = f"Plan results in a destructive change to forward-only model '{snapshot.name}'s schema" - if 
snapshot.model.on_destructive_change.is_warn: - logger.warning(warning_msg) - else: + snapshot_name = snapshot.name + model_dialect = snapshot.model.dialect + + if needs_destructive_check and has_drop_alteration(alter_operations): + self._console.log_destructive_change( + snapshot_name, + alter_operations, + model_dialect, + error=not snapshot.model.on_destructive_change.is_warn, + ) + if snapshot.model.on_destructive_change.is_error: raise PlanError( - f"{warning_msg}. To allow this, change the model's `on_destructive_change` setting to `warn` or `allow` or include it in the plan's `--allow-destructive-model` option." + "Plan requires a destructive change to a forward-only model." ) + if needs_additive_check and has_additive_alteration(alter_operations): + self._console.log_additive_change( + snapshot_name, + alter_operations, + model_dialect, + error=not snapshot.model.on_additive_change.is_warn, + ) + if snapshot.model.on_additive_change.is_error: + raise PlanError("Plan requires an additive change to a forward-only model.") + def _categorize_snapshots( - self, - dag: DAG[SnapshotId], - directly_modified: t.Set[SnapshotId], - indirectly_modified: SnapshotMapping, + self, dag: DAG[SnapshotId], indirectly_modified: SnapshotMapping ) -> None: """Automatically categorizes snapshots that can be automatically categorized and returns a list of added and directly modified snapshots as well as the mapping of @@ -496,152 +622,173 @@ def _categorize_snapshots( # assigned to its upstream dependencies. 
for s_id in dag: snapshot = self._context_diff.snapshots.get(s_id) - if not snapshot or snapshot.change_category: + + if not snapshot or not self._is_new_snapshot(snapshot): continue - if s_id.name in self._context_diff.modified_snapshots: - is_directly_modified = self._context_diff.directly_modified(s_id.name) - if self._is_new_snapshot(snapshot): - if self._forward_only: - # In case of the forward only plan any modifications result in reuse of the - # previous version for non-seed models. - # New snapshots of seed models are considered non-breaking ones. - if not snapshot.is_seed: - snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) - else: - snapshot.categorize_as(SnapshotChangeCategory.NON_BREAKING) - elif self._is_forward_only_change(s_id) and is_directly_modified: - self._set_choice( - snapshot, - SnapshotChangeCategory.FORWARD_ONLY, - directly_modified, - indirectly_modified, - ) - elif self._auto_categorization_enabled and is_directly_modified: - s_id_with_missing_columns: t.Optional[SnapshotId] = None - this_sid_with_downstream = indirectly_modified.get(s_id, set()) | {s_id} - for downstream_s_id in this_sid_with_downstream: - downstream_snapshot = self._context_diff.snapshots[downstream_s_id] - if ( - downstream_snapshot.is_model - and downstream_snapshot.model.columns_to_types is None - ): - s_id_with_missing_columns = downstream_s_id - break - - new, old = self._context_diff.modified_snapshots[s_id.name] - if s_id_with_missing_columns is None: - change_category = categorize_change( - new, old, config=self._categorizer_config - ) - if change_category is not None: - self._set_choice( - new, change_category, directly_modified, indirectly_modified - ) - else: - mode = self._categorizer_config.dict().get( - new.model.source_type, AutoCategorizationMode.OFF - ) - if mode == AutoCategorizationMode.FULL: - self._set_choice( - new, - SnapshotChangeCategory.BREAKING, - directly_modified, - indirectly_modified, - ) - - if ( - not is_directly_modified - 
and not snapshot.version - and not any( - self._context_diff.directly_modified(upstream.name) - and not self._context_diff.snapshots[upstream].version - for upstream in dag.upstream(s_id) - ) - ): - if self._context_diff.indirectly_modified(snapshot.name): - # Set to breaking if an indirect child has no directly modified parents - # that need a decision. this can happen when a revert to a parent causes - # an indirectly modified snapshot to be created because of a new parent - snapshot.categorize_as( - SnapshotChangeCategory.FORWARD_ONLY - if self._is_forward_only_change(s_id) - else SnapshotChangeCategory.INDIRECT_BREAKING - ) - else: - # Metadata updated. - snapshot.categorize_as(SnapshotChangeCategory.METADATA) + forward_only = self._forward_only or self._is_forward_only_change(s_id) + if forward_only and s_id.name in self._context_diff.modified_snapshots: + new, old = self._context_diff.modified_snapshots[s_id.name] + if is_breaking_kind_change(old, new) or snapshot.is_seed: + # Breaking kind changes and seed changes can't be forward-only. 
+ forward_only = False - elif s_id in self._context_diff.added and self._is_new_snapshot(snapshot): - snapshot.categorize_as( - SnapshotChangeCategory.FORWARD_ONLY - if self._is_forward_only_change(s_id) - else SnapshotChangeCategory.BREAKING - ) + if s_id in self._choices: + snapshot.categorize_as(self._choices[s_id], forward_only) + continue + + if s_id in self._context_diff.added: + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only) + elif s_id.name in self._context_diff.modified_snapshots: + self._categorize_snapshot(snapshot, forward_only, dag, indirectly_modified) - def _set_choice( + def _categorize_snapshot( self, snapshot: Snapshot, - choice: SnapshotChangeCategory, - directly_modified: t.Set[SnapshotId], + forward_only: bool, + dag: DAG[SnapshotId], indirectly_modified: SnapshotMapping, ) -> None: - if self._forward_only: - raise PlanError("Choice setting is not supported by a forward-only plan.") - if not self._is_new_snapshot(snapshot): - raise SQLMeshError( - f"A choice can't be changed for the existing version of '{snapshot.name}'." 
- ) - - snapshot.categorize_as(choice) - - is_breaking_choice = choice in ( - SnapshotChangeCategory.BREAKING, - SnapshotChangeCategory.INDIRECT_BREAKING, - ) + s_id = snapshot.snapshot_id + + if self._context_diff.directly_modified(s_id.name): + if self._auto_categorization_enabled: + new, old = self._context_diff.modified_snapshots[s_id.name] + if is_breaking_kind_change(old, new): + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, False) + return + + s_id_with_missing_columns: t.Optional[SnapshotId] = None + this_sid_with_downstream = indirectly_modified.get(s_id, set()) | {s_id} + for downstream_s_id in this_sid_with_downstream: + downstream_snapshot = self._context_diff.snapshots[downstream_s_id] + if ( + downstream_snapshot.is_model + and downstream_snapshot.model.columns_to_types is None + ): + s_id_with_missing_columns = downstream_s_id + break + + if s_id_with_missing_columns is None: + change_category = categorize_change(new, old, config=self._categorizer_config) + if change_category is not None: + snapshot.categorize_as(change_category, forward_only) + else: + mode = self._categorizer_config.dict().get( + new.model.source_type, AutoCategorizationMode.OFF + ) + if mode == AutoCategorizationMode.FULL: + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only) + elif self._context_diff.indirectly_modified(snapshot.name): + if snapshot.is_materialized_view and not forward_only: + # We categorize changes as breaking to allow for instantaneous switches in a virtual layer. + # Otherwise, there might be a potentially long downtime during MVs recreation. + # In the case of forward-only changes this optimization is not applicable because we want to continue + # using the same (existing) table version. 
+ snapshot.categorize_as(SnapshotChangeCategory.INDIRECT_BREAKING, forward_only) + return + + all_upstream_forward_only = set() + all_upstream_categories = set() + direct_parent_categories = set() + + for p_id in dag.upstream(s_id): + parent = self._context_diff.snapshots.get(p_id) + + if parent and self._is_new_snapshot(parent): + all_upstream_categories.add(parent.change_category) + all_upstream_forward_only.add(parent.is_forward_only) + if p_id in snapshot.parents: + direct_parent_categories.add(parent.change_category) + + if all_upstream_forward_only == {True} or ( + snapshot.is_model and snapshot.model.forward_only + ): + forward_only = True - for child_s_id in indirectly_modified.get(snapshot.snapshot_id, set()): - child_snapshot = self._context_diff.snapshots[child_s_id] - # If the snapshot isn't new then we are reverting to a previously existing snapshot - # and therefore we don't want to recategorize it. - if not self._is_new_snapshot(child_snapshot): + if direct_parent_categories.intersection( + {SnapshotChangeCategory.BREAKING, SnapshotChangeCategory.INDIRECT_BREAKING} + ): + snapshot.categorize_as(SnapshotChangeCategory.INDIRECT_BREAKING, forward_only) + elif not direct_parent_categories: + snapshot.categorize_as( + self._get_orphaned_indirect_change_category(snapshot), forward_only + ) + elif all_upstream_categories == {SnapshotChangeCategory.METADATA}: + snapshot.categorize_as(SnapshotChangeCategory.METADATA, forward_only) + else: + snapshot.categorize_as(SnapshotChangeCategory.INDIRECT_NON_BREAKING, forward_only) + else: + # Metadata updated. + snapshot.categorize_as(SnapshotChangeCategory.METADATA, forward_only) + + def _get_orphaned_indirect_change_category( + self, indirect_snapshot: Snapshot + ) -> SnapshotChangeCategory: + """Sometimes an indirectly changed downstream snapshot ends up with no directly changed parents introduced in the same plan. 
+ This may happen when 2 or more parent models were changed independently in different plans and then the changes were + merged together and applied in a single plan. As a result, a combination of 2 or more previously changed parents produces + a new downstream snapshot not previously seen. + + This function is used to infer the correct change category for such downstream snapshots based on change categories of their parents. + """ + previous_snapshot = self._context_diff.modified_snapshots[indirect_snapshot.name][1] + previous_parent_snapshot_ids = {p.name: p for p in previous_snapshot.parents} + + current_parent_snapshots = [ + self._context_diff.snapshots[p_id] + for p_id in indirect_snapshot.parents + if p_id in self._context_diff.snapshots + ] + + indirect_category: t.Optional[SnapshotChangeCategory] = None + for current_parent_snapshot in current_parent_snapshots: + if current_parent_snapshot.name not in previous_parent_snapshot_ids: + # This is a new parent so falling back to INDIRECT_BREAKING + return SnapshotChangeCategory.INDIRECT_BREAKING + pevious_parent_snapshot_id = previous_parent_snapshot_ids[current_parent_snapshot.name] + + if current_parent_snapshot.snapshot_id == pevious_parent_snapshot_id: + # There were no new versions of this parent since the previous version of this snapshot, + # so we can skip it continue - is_forward_only_child = self._is_forward_only_change(child_s_id) + # Find the previous snapshot ID of the same parent in the historical chain + previous_parent_found = False + previous_parent_categories = set() + for pv in reversed(current_parent_snapshot.all_versions): + pv_snapshot_id = pv.snapshot_id(current_parent_snapshot.name) + if pv_snapshot_id == pevious_parent_snapshot_id: + previous_parent_found = True + break + previous_parent_categories.add(pv.change_category) - if is_breaking_choice: - child_snapshot.categorize_as( - SnapshotChangeCategory.FORWARD_ONLY - if is_forward_only_child - else 
SnapshotChangeCategory.INDIRECT_BREAKING - ) - elif choice == SnapshotChangeCategory.FORWARD_ONLY: - child_snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) - else: - child_snapshot.categorize_as(SnapshotChangeCategory.INDIRECT_NON_BREAKING) + if not previous_parent_found: + # The previous parent is not in the historical chain so falling back to INDIRECT_BREAKING + return SnapshotChangeCategory.INDIRECT_BREAKING - for upstream_id in directly_modified: - if upstream_id == snapshot.snapshot_id or child_s_id not in indirectly_modified.get( - upstream_id, set() - ): - continue + if previous_parent_categories.intersection( + {SnapshotChangeCategory.BREAKING, SnapshotChangeCategory.INDIRECT_BREAKING} + ): + # One of the new parents in the chain was breaking so this indirect snapshot is breaking + return SnapshotChangeCategory.INDIRECT_BREAKING - upstream = self._context_diff.snapshots[upstream_id] - if upstream.change_category == SnapshotChangeCategory.BREAKING: - # If any other snapshot specified breaking this child, then that child - # needs to be backfilled as a part of the plan. - child_snapshot.categorize_as( - SnapshotChangeCategory.FORWARD_ONLY - if is_forward_only_child - else SnapshotChangeCategory.INDIRECT_BREAKING - ) - break - elif ( - upstream.change_category == SnapshotChangeCategory.FORWARD_ONLY - and child_snapshot.is_indirect_non_breaking - ): - # FORWARD_ONLY takes precedence over INDIRECT_NON_BREAKING. 
- child_snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + if previous_parent_categories.intersection( + { + SnapshotChangeCategory.NON_BREAKING, + SnapshotChangeCategory.INDIRECT_NON_BREAKING, + } + ): + # All changes in the chain were non-breaking so this indirect snapshot can be non-breaking too + indirect_category = SnapshotChangeCategory.INDIRECT_NON_BREAKING + elif ( + previous_parent_categories == {SnapshotChangeCategory.METADATA} + and indirect_category is None + ): + # All changes in the chain were metadata so this indirect snapshot can be metadata too + indirect_category = SnapshotChangeCategory.METADATA + + return indirect_category or SnapshotChangeCategory.INDIRECT_BREAKING def _apply_effective_from(self) -> None: if self._effective_from: @@ -651,21 +798,26 @@ def _apply_effective_from(self) -> None: raise PlanError("Effective date cannot be in the future.") for snapshot in self._context_diff.new_snapshots.values(): - if not snapshot.disable_restatement and not snapshot.full_history_restatement_only: + if ( + snapshot.evaluatable + and not snapshot.disable_restatement + and (not snapshot.full_history_restatement_only or not snapshot.is_incremental) + ): snapshot.effective_from = self._effective_from def _is_forward_only_change(self, s_id: SnapshotId) -> bool: + if not self._context_diff.directly_modified( + s_id.name + ) and not self._context_diff.indirectly_modified(s_id.name): + return False snapshot = self._context_diff.snapshots[s_id] if snapshot.name in self._context_diff.modified_snapshots: _, old = self._context_diff.modified_snapshots[snapshot.name] - # If the model kind has changed, then we should not consider this to be a forward-only change. - if snapshot.is_model and old.model.kind.name != snapshot.model.kind.name: + # If the model kind has changed in a breaking way, then we can't consider this to be a forward-only change. 
+ if snapshot.is_model and is_breaking_kind_change(old, snapshot): return False return ( - snapshot.is_model - and snapshot.model.forward_only - and not snapshot.change_category - and bool(snapshot.previous_versions) + snapshot.is_model and snapshot.model.forward_only and bool(snapshot.previous_versions) ) def _is_new_snapshot(self, snapshot: Snapshot) -> bool: @@ -678,25 +830,37 @@ def _ensure_valid_date_range(self) -> None: "The start and end dates can't be set for a production plan without restatements." ) - def _ensure_no_forward_only_revert(self) -> None: - """Ensures that a previously superseded breaking / non-breaking snapshot is not being - used again to replace an existing forward-only snapshot with the same version. + if (start := self.start) and (end := self.end): + if to_datetime(start) > to_datetime(end): + raise PlanError( + f"Plan end date: '{time_like_to_str(end)}' must be after the plan start date: '{time_like_to_str(start)}'" + ) - In other words there is no going back to the original non-forward-only snapshot with - the same version once a forward-only change for that version has been introduced. - """ - for name, (candidate, promoted) in self._context_diff.modified_snapshots.items(): - if ( - candidate.snapshot_id not in self._context_diff.new_snapshots - and promoted.is_forward_only - and not promoted.is_paused - and not candidate.reuses_previous_version - and promoted.version == candidate.version - ): + if end := self.end: + if to_datetime(end) > to_datetime(self.execution_time): raise PlanError( - f"Attempted to revert to an unrevertable version of model '{name}'. Run `sqlmesh plan` again to mitigate the issue." 
+ f"Plan end date: '{time_like_to_str(end)}' cannot be in the future (execution time: '{time_like_to_str(self.execution_time)}')" ) + # Validate model-specific start/end dates + if (start := self.start or self._default_start) and (end := self.end): + start_ts = to_datetime(start) + end_ts = to_datetime(end) + if start_ts > end_ts: + models_to_check: t.Set[str] = ( + set(self._backfill_models or []) + | set(self._context_diff.modified_snapshots.keys()) + | {s.name for s in self._context_diff.added} + | set((self._end_override_per_model or {}).keys()) + ) + for model_name in models_to_check: + if snapshot := self._model_fqn_to_snapshot.get(model_name): + if snapshot.node.start is None or to_datetime(snapshot.node.start) > end_ts: + raise PlanError( + f"Model '{model_name}': Start date / time '({time_like_to_str(start_ts)})' can't be greater than end date / time '({time_like_to_str(end_ts)})'.\n" + f"Set the `start` attribute in your project config model defaults to avoid this issue." + ) + def _ensure_no_broken_references(self) -> None: for snapshot in self._context_diff.snapshots.values(): broken_references = { @@ -708,24 +872,17 @@ def _ensure_no_broken_references(self) -> None: f"""Removed {broken_references_msg} are referenced in '{snapshot.name}'. Please remove broken references before proceeding.""" ) - def _ensure_no_new_snapshots_with_restatements(self) -> None: - if self._restate_models is not None and ( - self._context_diff.new_snapshots or self._context_diff.modified_snapshots - ): - raise PlanError( - "Model changes and restatements can't be a part of the same plan. " - "Revert or apply changes before proceeding with restatements." 
- ) - def _ensure_new_env_with_changes(self) -> None: if ( self._is_dev and not self._include_unmodified and self._context_diff.is_new_environment and not self._context_diff.has_snapshot_changes + and not self._context_diff.has_environment_statements_changes + and not self._backfill_models ): raise NoChangesPlanError( - "No changes were detected. Make a change or run with --include-unmodified to create a new environment without changes." + f"Creating a new environment requires a change, but project files match the `{self._context_diff.create_from}` environment. Make a change or use the --include-unmodified flag to create a new environment without changes." ) @cached_property @@ -737,8 +894,31 @@ def _forward_only_preview_needed(self) -> bool: self._enable_preview and any( snapshot.model.forward_only - for snapshot, _ in self._context_diff.modified_snapshots.values() + for snapshot in self._modified_and_added_snapshots if snapshot.is_model ) ) ) + + @cached_property + def _non_forward_only_preview_needed(self) -> bool: + if not self._is_dev: + return False + for snapshot in self._modified_and_added_snapshots: + if not snapshot.is_model: + continue + if ( + not snapshot.virtual_environment_mode.is_full + or snapshot.model.auto_restatement_cron is not None + ): + return True + return False + + @cached_property + def _modified_and_added_snapshots(self) -> t.List[Snapshot]: + return [ + snapshot + for snapshot in self._context_diff.snapshots.values() + if snapshot.name in self._context_diff.modified_snapshots + or snapshot.snapshot_id in self._context_diff.added + ] diff --git a/sqlmesh/core/plan/common.py b/sqlmesh/core/plan/common.py new file mode 100644 index 0000000000..bece17639c --- /dev/null +++ b/sqlmesh/core/plan/common.py @@ -0,0 +1,223 @@ +from __future__ import annotations +import typing as t +import logging +from dataclasses import dataclass, field + +from sqlmesh.core.state_sync import StateReader +from sqlmesh.core.snapshot import Snapshot, SnapshotId, 
SnapshotIdAndVersion, SnapshotNameVersion +from sqlmesh.core.snapshot.definition import Interval +from sqlmesh.utils.dag import DAG +from sqlmesh.utils.date import now_timestamp + +logger = logging.getLogger(__name__) + + +def should_force_rebuild(old: Snapshot, new: Snapshot) -> bool: + if new.is_view and new.is_indirect_non_breaking and not new.is_forward_only: + # View models always need to be rebuilt to reflect updated upstream dependencies + return True + if new.is_seed and not ( + new.is_metadata + and new.previous_version + and new.previous_version.snapshot_id(new.name) == old.snapshot_id + ): + # Seed models always need to be rebuilt to reflect changes in the seed file + # Unless only their metadata has been updated (eg description added) and the seed file has not been touched + return True + return is_breaking_kind_change(old, new) + + +def is_breaking_kind_change(old: Snapshot, new: Snapshot) -> bool: + if new.is_model != old.is_model: + # If one is a model and the other isn't, then we need to rebuild + return True + if not new.is_model or not old.is_model: + # If neither are models, then we don't need to rebuild + # Note that the remaining checks only apply to model snapshots + return False + if old.virtual_environment_mode != new.virtual_environment_mode: + # If the virtual environment mode has changed, then we need to rebuild + return True + if old.model.kind.name == new.model.kind.name: + # If the kind hasn't changed, then we don't need to rebuild + return False + if not old.is_incremental or not new.is_incremental: + # If either is not incremental, then we need to rebuild + return True + if old.model.partitioned_by == new.model.partitioned_by: + # If the partitioning hasn't changed, then we don't need to rebuild + return False + return True + + +@dataclass +class SnapshotIntervalClearRequest: + # affected snapshot + snapshot: SnapshotIdAndVersion + + # which interval to clear + interval: Interval + + # which environments this snapshot is currently 
promoted + # note that this can be empty if the snapshot exists because its ttl has not expired + # but it is not part of any particular environment + environment_names: t.Set[str] = field(default_factory=set) + + @property + def snapshot_id(self) -> SnapshotId: + return self.snapshot.snapshot_id + + @property + def sorted_environment_names(self) -> t.List[str]: + return list(sorted(self.environment_names)) + + +def identify_restatement_intervals_across_snapshot_versions( + state_reader: StateReader, + prod_restatements: t.Dict[str, Interval], + disable_restatement_models: t.Set[str], + loaded_snapshots: t.Dict[SnapshotId, Snapshot], + current_ts: t.Optional[int] = None, +) -> t.Dict[SnapshotId, SnapshotIntervalClearRequest]: + """ + Given a map of snapshot names + intervals to restate in prod: + - Look up matching snapshots (match based on name - regardless of version, to get all versions) + - For each match, also match downstream snapshots in each dev environment while filtering out models that have restatement disabled + - Return a list of all snapshots that are affected + the interval that needs to be cleared for each + + The goal here is to produce a list of intervals to invalidate across all dev snapshots so that a subsequent plan or + cadence run in those environments causes the intervals to be repopulated. 
+ """ + if not prod_restatements: + return {} + + # Although :loaded_snapshots is sourced from RestatementStage.all_snapshots, since the only time we ever need + # to clear intervals across all environments is for prod, the :loaded_snapshots here are always from prod + prod_name_versions: t.Set[SnapshotNameVersion] = { + s.name_version for s in loaded_snapshots.values() + } + + snapshot_intervals_to_clear: t.Dict[SnapshotId, SnapshotIntervalClearRequest] = {} + + for env_summary in state_reader.get_environments_summary(): + # Fetch the full environment object one at a time to avoid loading all environments into memory at once + env = state_reader.get_environment(env_summary.name) + if not env: + logger.warning("Environment %s not found", env_summary.name) + continue + + snapshots_by_name = {s.name: s.table_info for s in env.snapshots} + + # We dont just restate matching snapshots, we also have to restate anything downstream of them + # so that if A gets restated in prod and dev has A <- B <- C, B and C get restated in dev + env_dag = DAG({s.name: {p.name for p in s.parents} for s in env.snapshots}) + + for restate_snapshot_name, interval in prod_restatements.items(): + if restate_snapshot_name not in snapshots_by_name: + # snapshot is not promoted in this environment + continue + + affected_snapshot_names = [ + x + for x in ([restate_snapshot_name] + env_dag.downstream(restate_snapshot_name)) + if x not in disable_restatement_models + ] + + for affected_snapshot_name in affected_snapshot_names: + affected_snapshot = snapshots_by_name[affected_snapshot_name] + + # Don't clear intervals for a dev snapshot if it shares the same physical version with prod. 
+ # Otherwise, prod will be affected by what should be a dev operation + if affected_snapshot.name_version in prod_name_versions: + continue + + clear_request = snapshot_intervals_to_clear.get(affected_snapshot.snapshot_id) + if not clear_request: + clear_request = SnapshotIntervalClearRequest( + snapshot=affected_snapshot.id_and_version, interval=interval + ) + snapshot_intervals_to_clear[affected_snapshot.snapshot_id] = clear_request + + clear_request.environment_names |= set([env.name]) + + # snapshot_intervals_to_clear now contains the entire hierarchy of affected snapshots based + # on building the DAG for each environment and including downstream snapshots + # but, what if there are affected snapshots that arent part of any environment? + unique_snapshot_names = set(snapshot_id.name for snapshot_id in snapshot_intervals_to_clear) + + current_ts = current_ts or now_timestamp() + all_matching_non_prod_snapshots = { + s.snapshot_id: s + for s in state_reader.get_snapshots_by_names( + snapshot_names=unique_snapshot_names, current_ts=current_ts, exclude_expired=True + ) + # Don't clear intervals for a snapshot if it shares the same physical version with prod. + # Otherwise, prod will be affected by what should be a dev operation + if s.name_version not in prod_name_versions + } + + # identify the ones that we havent picked up yet, which are the ones that dont exist in any environment + if remaining_snapshot_ids := set(all_matching_non_prod_snapshots).difference( + snapshot_intervals_to_clear + ): + # these snapshot id's exist in isolation and may be related to a downstream dependency of the :prod_restatements, + # rather than directly related, so we can't simply look up the interval to clear based on :prod_restatements. 
+ # To figure out the interval that should be cleared, we can match to the existing list based on name + # and conservatively take the widest interval that shows up + snapshot_name_to_widest_interval: t.Dict[str, Interval] = {} + for s_id, clear_request in snapshot_intervals_to_clear.items(): + current_start, current_end = snapshot_name_to_widest_interval.get( + s_id.name, clear_request.interval + ) + next_start, next_end = clear_request.interval + + next_start = min(current_start, next_start) + next_end = max(current_end, next_end) + + snapshot_name_to_widest_interval[s_id.name] = (next_start, next_end) + + for remaining_snapshot_id in remaining_snapshot_ids: + remaining_snapshot = all_matching_non_prod_snapshots[remaining_snapshot_id] + snapshot_intervals_to_clear[remaining_snapshot_id] = SnapshotIntervalClearRequest( + snapshot=remaining_snapshot, + interval=snapshot_name_to_widest_interval[remaining_snapshot_id.name], + ) + + # for any affected full_history_restatement_only snapshots, we need to widen the intervals being restated to + # include the whole time range for that snapshot. 
This requires a call to state to load the full snapshot record, + # so we only do it if necessary + full_history_restatement_snapshot_ids = [ + # FIXME: full_history_restatement_only is just one indicator that the snapshot can only be fully refreshed, the other one is Model.depends_on_self + # however, to figure out depends_on_self, we have to render all the model queries which, alongside having to fetch full snapshots from state, + # is problematic in secure environments that are deliberately isolated from arbitrary user code (since rendering a query may require user macros to be present) + # So for now, these are not considered + s_id + for s_id, s in snapshot_intervals_to_clear.items() + if s.snapshot.full_history_restatement_only + ] + if full_history_restatement_snapshot_ids: + # only load full snapshot records that we havent already loaded + additional_snapshots = state_reader.get_snapshots( + [ + s.snapshot_id + for s in full_history_restatement_snapshot_ids + if s.snapshot_id not in loaded_snapshots + ] + ) + + all_snapshots = loaded_snapshots | additional_snapshots + + for full_snapshot_id in full_history_restatement_snapshot_ids: + full_snapshot = all_snapshots[full_snapshot_id] + intervals_to_clear = snapshot_intervals_to_clear[full_snapshot_id] + + original_start, original_end = intervals_to_clear.interval + + # get_removal_interval() widens intervals if necessary + new_interval = full_snapshot.get_removal_interval( + start=original_start, end=original_end + ) + + intervals_to_clear.interval = new_interval + + return snapshot_intervals_to_clear diff --git a/sqlmesh/core/plan/definition.py b/sqlmesh/core/plan/definition.py index ccef658e5a..866299eff8 100644 --- a/sqlmesh/core/plan/definition.py +++ b/sqlmesh/core/plan/definition.py @@ -5,9 +5,11 @@ from datetime import datetime from enum import Enum from functools import cached_property +from pydantic import Field from sqlmesh.core.context_diff import ContextDiff -from sqlmesh.core.environment import 
Environment, EnvironmentNamingInfo +from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements +from sqlmesh.utils.metaprogramming import Executable # noqa from sqlmesh.core.node import IntervalUnit from sqlmesh.core.snapshot import ( DeployabilityIndex, @@ -27,6 +29,7 @@ from sqlmesh.utils.pydantic import PydanticModel SnapshotMapping = t.Dict[SnapshotId, t.Set[SnapshotId]] +UserProvidedFlags = t.Union[TimeLike, str, bool, t.List[str]] class Plan(PydanticModel, frozen=True): @@ -37,26 +40,49 @@ class Plan(PydanticModel, frozen=True): is_dev: bool skip_backfill: bool + empty_backfill: bool no_gaps: bool forward_only: bool allow_destructive_models: t.Set[str] + allow_additive_models: t.Set[str] include_unmodified: bool end_bounded: bool ensure_finalized_snapshots: bool + explain: bool + ignore_cron: bool = False environment_ttl: t.Optional[str] = None environment_naming_info: EnvironmentNamingInfo directly_modified: t.Set[SnapshotId] indirectly_modified: t.Dict[SnapshotId, t.Set[SnapshotId]] - ignored: t.Set[SnapshotId] deployability_index: DeployabilityIndex + selected_models_to_restate: t.Optional[t.Set[str]] = None + """Models that have been explicitly selected for restatement by a user""" restatements: t.Dict[SnapshotId, Interval] + """ + All models being restated, which are typically the explicitly selected ones + their downstream dependencies. 
+ Note that dev previews are also considered restatements, so :selected_models_to_restate can be empty + while :restatements is still populated with dev previews + """ + restate_all_snapshots: bool + """Whether or not to clear intervals from state for other versions of the models listed in :restatements""" + + start_override_per_model: t.Optional[t.Dict[str, datetime]] + end_override_per_model: t.Optional[t.Dict[str, datetime]] + + selected_models_to_backfill: t.Optional[t.Set[str]] = None + """Models that have been explicitly selected for backfill by a user.""" models_to_backfill: t.Optional[t.Set[str]] = None + """All models that should be backfilled as part of this plan.""" effective_from: t.Optional[TimeLike] = None - execution_time: t.Optional[TimeLike] = None + execution_time_: t.Optional[TimeLike] = Field(default=None, alias="execution_time") + + user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None + selected_models: t.Optional[t.Set[str]] = None + """Models that have been selected for this plan (used for dbt selected_resources)""" @cached_property def start(self) -> TimeLike: @@ -71,7 +97,12 @@ def start(self) -> TimeLike: @cached_property def end(self) -> TimeLike: - return self.provided_end or now() + return self.provided_end or self.execution_time + + @cached_property + def execution_time(self) -> TimeLike: + # note: property is cached so that it returns a consistent timestamp for now() + return self.execution_time_ or now() @property def previous_plan_id(self) -> t.Optional[str]: @@ -79,20 +110,15 @@ def previous_plan_id(self) -> t.Optional[str]: @property def requires_backfill(self) -> bool: - return not self.skip_backfill and (bool(self.restatements) or bool(self.missing_intervals)) + return ( + not self.skip_backfill + and not self.empty_backfill + and (bool(self.restatements) or bool(self.missing_intervals)) + ) @property def has_changes(self) -> bool: - modified_snapshot_ids = { - *self.context_diff.added, - 
*self.context_diff.removed_snapshots, - *self.context_diff.current_modified_snapshot_ids, - } - self.ignored - return ( - self.context_diff.is_new_environment - or self.context_diff.is_unfinalized_environment - or bool(modified_snapshot_ids) - ) + return self.context_diff.has_changes @property def has_unmodified_unpromoted(self) -> bool: @@ -109,7 +135,7 @@ def categorized(self) -> t.List[Snapshot]: """Returns the already categorized snapshots.""" return [ self.context_diff.snapshots[s_id] - for s_id in sorted(self.directly_modified) + for s_id in sorted({*self.directly_modified, *self.metadata_updated}) if self.context_diff.snapshots[s_id].version ] @@ -122,11 +148,9 @@ def uncategorized(self) -> t.List[Snapshot]: if not self.context_diff.snapshots[s_id].version ] - @cached_property + @property def snapshots(self) -> t.Dict[SnapshotId, Snapshot]: - return { - s_id: s for s_id, s in self.context_diff.snapshots.items() if s_id not in self.ignored - } + return self.context_diff.snapshots @cached_property def modified_snapshots(self) -> t.Dict[SnapshotId, t.Union[Snapshot, SnapshotTableInfo]]: @@ -139,14 +163,21 @@ def modified_snapshots(self) -> t.Dict[SnapshotId, t.Union[Snapshot, SnapshotTab for s_id in sorted(downstream_s_ids) }, **self.context_diff.removed_snapshots, + **{s_id: self.context_diff.snapshots[s_id] for s_id in sorted(self.metadata_updated)}, + } + + @cached_property + def metadata_updated(self) -> t.Set[SnapshotId]: + return { + snapshot.snapshot_id + for snapshot, _ in self.context_diff.modified_snapshots.values() + if self.context_diff.metadata_updated(snapshot.name) } @property def new_snapshots(self) -> t.List[Snapshot]: """Gets only new snapshots in the plan/environment.""" - return [ - s for s in self.context_diff.new_snapshots.values() if s.snapshot_id not in self.ignored - ] + return list(self.context_diff.new_snapshots.values()) @property def missing_intervals(self) -> t.List[SnapshotIntervals]: @@ -162,7 +193,10 @@ def 
missing_intervals(self) -> t.List[SnapshotIntervals]: execution_time=self.execution_time, restatements=self.restatements, deployability_index=self.deployability_index, + start_override_per_model=self.start_override_per_model, + end_override_per_model=self.end_override_per_model, end_bounded=self.end_bounded, + ignore_cron=self.ignore_cron, ).items() if snapshot.is_model and missing ] @@ -179,22 +213,26 @@ def environment(self) -> Environment: snapshots_by_name = self.context_diff.snapshots_by_name snapshots = [s.table_info for s in self.snapshots.values()] - promoted_snapshot_ids = None - if self.is_dev and not self.include_unmodified: - promotable_snapshot_ids = self.context_diff.promotable_snapshot_ids.copy() - if self.models_to_backfill is not None: + promotable_snapshot_ids = None + if self.is_dev: + if self.selected_models_to_backfill is not None: # Only promote models that have been explicitly selected for backfill. - promotable_snapshot_ids &= { + promotable_snapshot_ids = { *self.context_diff.previously_promoted_snapshot_ids, *[ snapshots_by_name[m].snapshot_id - for m in self.models_to_backfill + for m in self.selected_models_to_backfill if m in snapshots_by_name ], } - promoted_snapshot_ids = [ - s.snapshot_id for s in snapshots if s.snapshot_id in promotable_snapshot_ids - ] + elif not self.include_unmodified: + promotable_snapshot_ids = self.context_diff.promotable_snapshot_ids.copy() + + promoted_snapshot_ids = ( + [s.snapshot_id for s in snapshots if s.snapshot_id in promotable_snapshot_ids] + if promotable_snapshot_ids is not None + else None + ) previous_finalized_snapshots = ( self.context_diff.environment_snapshots @@ -211,21 +249,106 @@ def environment(self) -> Environment: expiration_ts=expiration_ts, promoted_snapshot_ids=promoted_snapshot_ids, previous_finalized_snapshots=previous_finalized_snapshots, + requirements=self.context_diff.requirements, **self.environment_naming_info.dict(), ) def is_new_snapshot(self, snapshot: Snapshot) -> bool: 
"""Returns True if the given snapshot is a new snapshot in this plan.""" snapshot_id = snapshot.snapshot_id - return snapshot_id in self.context_diff.new_snapshots and snapshot_id not in self.ignored + return snapshot_id in self.context_diff.new_snapshots def is_selected_for_backfill(self, model_fqn: str) -> bool: """Returns True if a model with the given FQN should be backfilled as part of this plan.""" return self.models_to_backfill is None or model_fqn in self.models_to_backfill + def to_evaluatable(self) -> EvaluatablePlan: + return EvaluatablePlan( + start=self.start, + end=self.end, + new_snapshots=self.new_snapshots, + environment=self.environment, + no_gaps=self.no_gaps, + skip_backfill=self.skip_backfill, + empty_backfill=self.empty_backfill, + restatements={s.name: i for s, i in self.restatements.items()}, + restate_all_snapshots=self.restate_all_snapshots, + is_dev=self.is_dev, + allow_destructive_models=self.allow_destructive_models, + allow_additive_models=self.allow_additive_models, + forward_only=self.forward_only, + end_bounded=self.end_bounded, + ensure_finalized_snapshots=self.ensure_finalized_snapshots, + ignore_cron=self.ignore_cron, + directly_modified_snapshots=sorted(self.directly_modified), + indirectly_modified_snapshots={ + s.name: sorted(snapshot_ids) for s, snapshot_ids in self.indirectly_modified.items() + }, + metadata_updated_snapshots=sorted(self.metadata_updated), + removed_snapshots=sorted(self.context_diff.removed_snapshots), + requires_backfill=self.requires_backfill, + models_to_backfill=self.models_to_backfill, + start_override_per_model=self.start_override_per_model, + end_override_per_model=self.end_override_per_model, + execution_time=self.execution_time, + disabled_restatement_models={ + s.name + for s in self.snapshots.values() + if s.is_model and s.model.disable_restatement + }, + environment_statements=self.context_diff.environment_statements, + user_provided_flags=self.user_provided_flags, + 
selected_models=self.selected_models, + ) + @cached_property def _earliest_interval_start(self) -> datetime: - return earliest_interval_start(self.snapshots.values()) + return earliest_interval_start(self.snapshots.values(), self.execution_time) + + +class EvaluatablePlan(PydanticModel): + """A serializable version of a plan that can be evaluated.""" + + start: TimeLike + end: TimeLike + new_snapshots: t.List[Snapshot] + environment: Environment + no_gaps: bool + skip_backfill: bool + empty_backfill: bool + restatements: t.Dict[str, Interval] + restate_all_snapshots: bool + is_dev: bool + allow_destructive_models: t.Set[str] + allow_additive_models: t.Set[str] + forward_only: bool + end_bounded: bool + ensure_finalized_snapshots: bool + ignore_cron: bool = False + directly_modified_snapshots: t.List[SnapshotId] + indirectly_modified_snapshots: t.Dict[str, t.List[SnapshotId]] + metadata_updated_snapshots: t.List[SnapshotId] + removed_snapshots: t.List[SnapshotId] + requires_backfill: bool + models_to_backfill: t.Optional[t.Set[str]] = None + start_override_per_model: t.Optional[t.Dict[str, datetime]] = None + end_override_per_model: t.Optional[t.Dict[str, datetime]] = None + execution_time: t.Optional[TimeLike] = None + disabled_restatement_models: t.Set[str] + environment_statements: t.Optional[t.List[EnvironmentStatements]] = None + user_provided_flags: t.Optional[t.Dict[str, UserProvidedFlags]] = None + selected_models: t.Optional[t.Set[str]] = None + + def is_selected_for_backfill(self, model_fqn: str) -> bool: + return self.models_to_backfill is None or model_fqn in self.models_to_backfill + + @property + def plan_id(self) -> str: + return self.environment.plan_id + + @property + def is_prod(self) -> bool: + return not self.is_dev class PlanStatus(str, Enum): @@ -260,8 +383,10 @@ def format_intervals(self, unit: t.Optional[IntervalUnit] = None) -> str: return format_intervals(self.merged_intervals, unit) -def earliest_interval_start(snapshots: 
t.Collection[Snapshot]) -> datetime: - earliest_start = earliest_start_date(snapshots) +def earliest_interval_start( + snapshots: t.Collection[Snapshot], execution_time: t.Optional[TimeLike] = None +) -> datetime: + earliest_start = earliest_start_date(snapshots, relative_to=execution_time) earliest_interval_starts = [s.intervals[0][0] for s in snapshots if s.intervals] return ( min(earliest_start, to_datetime(min(earliest_interval_starts))) diff --git a/sqlmesh/core/plan/evaluator.py b/sqlmesh/core/plan/evaluator.py index 5d408154a6..f2f432a97e 100644 --- a/sqlmesh/core/plan/evaluator.py +++ b/sqlmesh/core/plan/evaluator.py @@ -17,24 +17,31 @@ import abc import logging import typing as t - from sqlmesh.core import analytics from sqlmesh.core import constants as c from sqlmesh.core.console import Console, get_console -from sqlmesh.core.notification_target import ( - NotificationTarget, - NotificationTargetManager, +from sqlmesh.core.environment import EnvironmentNamingInfo, execute_environment_statements +from sqlmesh.core.macros import RuntimeStage +from sqlmesh.core.snapshot.definition import to_view_mapping, SnapshotTableInfo +from sqlmesh.core.plan import stages +from sqlmesh.core.plan.definition import EvaluatablePlan +from sqlmesh.core.scheduler import Scheduler +from sqlmesh.core.snapshot import ( + DeployabilityIndex, + Snapshot, + SnapshotEvaluator, + SnapshotIntervals, + SnapshotId, + SnapshotInfoLike, + SnapshotCreationFailedError, ) -from sqlmesh.core.plan.definition import Plan -from sqlmesh.core.scheduler import Scheduler, SignalFactory -from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot, SnapshotEvaluator +from sqlmesh.utils import to_snake_case from sqlmesh.core.state_sync import StateSync -from sqlmesh.core.state_sync.base import PromotionResult -from sqlmesh.core.user import User -from sqlmesh.schedulers.airflow import common as airflow_common -from sqlmesh.schedulers.airflow.client import AirflowClient, BaseAirflowClient -from 
sqlmesh.schedulers.airflow.mwaa_client import MWAAClient -from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.core.plan.common import identify_restatement_intervals_across_snapshot_versions +from sqlmesh.utils import CorrelationId +from sqlmesh.utils.concurrency import NodeExecutionFailedError +from sqlmesh.utils.errors import PlanError, ConflictingPlanError, SQLMeshError +from sqlmesh.utils.date import now, to_timestamp logger = logging.getLogger(__name__) @@ -42,7 +49,9 @@ class PlanEvaluator(abc.ABC): @abc.abstractmethod def evaluate( - self, plan: Plan, circuit_breaker: t.Optional[t.Callable[[], bool]] = None + self, + plan: EvaluatablePlan, + circuit_breaker: t.Optional[t.Callable[[], bool]] = None, ) -> None: """Evaluates a plan by pushing snapshots and backfilling data. @@ -53,6 +62,7 @@ def evaluate( Args: plan: The plan to evaluate. + circuit_breaker: The circuit breaker to use. """ @@ -61,25 +71,27 @@ def __init__( self, state_sync: StateSync, snapshot_evaluator: SnapshotEvaluator, + create_scheduler: t.Callable[[t.Iterable[Snapshot], SnapshotEvaluator], Scheduler], default_catalog: t.Optional[str], - backfill_concurrent_tasks: int = 1, console: t.Optional[Console] = None, - notification_target_manager: t.Optional[NotificationTargetManager] = None, - signal_factory: t.Optional[SignalFactory] = None, ): self.state_sync = state_sync self.snapshot_evaluator = snapshot_evaluator + self.create_scheduler = create_scheduler self.default_catalog = default_catalog - self.backfill_concurrent_tasks = backfill_concurrent_tasks self.console = console or get_console() - self.notification_target_manager = notification_target_manager - self.signal_factory = signal_factory + self._circuit_breaker: t.Optional[t.Callable[[], bool]] = None def evaluate( self, - plan: Plan, + plan: EvaluatablePlan, circuit_breaker: t.Optional[t.Callable[[], bool]] = None, ) -> None: + self._circuit_breaker = circuit_breaker + self.snapshot_evaluator = 
self.snapshot_evaluator.set_correlation_id( + CorrelationId.from_plan_id(plan.plan_id) + ) + self.console.start_plan_evaluation(plan) analytics.collector.on_plan_apply_start( plan=plan, @@ -89,445 +101,390 @@ def evaluate( ) try: - snapshots = plan.snapshots - all_names = { - s.name for s in snapshots.values() if plan.is_selected_for_backfill(s.name) - } - deployability_index_for_evaluation = DeployabilityIndex.create(snapshots) - deployability_index_for_creation = deployability_index_for_evaluation - if plan.is_dev: - before_promote_snapshots = all_names - after_promote_snapshots = set() - else: - before_promote_snapshots = { - s.name - for s in snapshots.values() - if deployability_index_for_evaluation.is_representative(s) - and plan.is_selected_for_backfill(s.name) - } - after_promote_snapshots = all_names - before_promote_snapshots - deployability_index_for_evaluation = DeployabilityIndex.all_deployable() - - self._push(plan, deployability_index_for_creation) - update_intervals_for_new_snapshots(plan.new_snapshots, self.state_sync) - self._restate(plan) - self._backfill( - plan, - before_promote_snapshots, - deployability_index_for_evaluation, - circuit_breaker=circuit_breaker, - ) - promotion_result = self._promote(plan, before_promote_snapshots) - self._backfill( - plan, - after_promote_snapshots, - deployability_index_for_evaluation, - circuit_breaker=circuit_breaker, - ) - self._update_views(plan, promotion_result, deployability_index_for_evaluation) - - if not plan.requires_backfill: - self.console.log_success("Virtual Update executed successfully") + plan_stages = stages.build_plan_stages(plan, self.state_sync, self.default_catalog) + self._evaluate_stages(plan_stages, plan) except Exception as e: analytics.collector.on_plan_apply_end(plan_id=plan.plan_id, error=e) raise else: analytics.collector.on_plan_apply_end(plan_id=plan.plan_id) finally: + self.snapshot_evaluator.recycle() self.console.stop_plan_evaluation() - def _backfill( - self, - plan: Plan, - 
selected_snapshots: t.Set[str], - deployability_index: DeployabilityIndex, - circuit_breaker: t.Optional[t.Callable[[], bool]] = None, + def _evaluate_stages( + self, plan_stages: t.List[stages.PlanStage], plan: EvaluatablePlan ) -> None: - """Backfill missing intervals for snapshots that are part of the given plan. - - Args: - plan: The plan to source snapshots from. - selected_snapshots: The snapshots to backfill. - """ - if not plan.requires_backfill or not selected_snapshots: - return - - snapshots = plan.snapshots - scheduler = Scheduler( - snapshots.values(), - self.snapshot_evaluator, - self.state_sync, + for stage in plan_stages: + stage_name = stage.__class__.__name__ + handler_name = f"visit_{to_snake_case(stage_name)}" + if not hasattr(self, handler_name): + raise SQLMeshError(f"Unexpected plan stage: {stage_name}") + logger.info("Evaluating plan stage %s", stage_name) + handler = getattr(self, handler_name) + handler(stage, plan) + + def visit_before_all_stage(self, stage: stages.BeforeAllStage, plan: EvaluatablePlan) -> None: + execute_environment_statements( + adapter=self.snapshot_evaluator.adapter, + environment_statements=stage.statements, + runtime_stage=RuntimeStage.BEFORE_ALL, + environment_naming_info=plan.environment.naming_info, default_catalog=self.default_catalog, - max_workers=self.backfill_concurrent_tasks, - console=self.console, - notification_target_manager=self.notification_target_manager, - signal_factory=self.signal_factory, - ) - is_run_successful = scheduler.run( - plan.environment_naming_info, - plan.start, - plan.end, + snapshots=stage.all_snapshots, + start=plan.start, + end=plan.end, execution_time=plan.execution_time, - restatements=plan.restatements, - selected_snapshots=selected_snapshots, - deployability_index=deployability_index, - circuit_breaker=circuit_breaker, - end_bounded=plan.end_bounded, + selected_models=plan.selected_models, ) - if not is_run_successful: - raise SQLMeshError("Plan application failed.") - def 
_push(self, plan: Plan, deployability_index: t.Optional[DeployabilityIndex] = None) -> None: - """Push the snapshots to the state sync. + def visit_after_all_stage(self, stage: stages.AfterAllStage, plan: EvaluatablePlan) -> None: + execute_environment_statements( + adapter=self.snapshot_evaluator.adapter, + environment_statements=stage.statements, + runtime_stage=RuntimeStage.AFTER_ALL, + environment_naming_info=plan.environment.naming_info, + default_catalog=self.default_catalog, + snapshots=stage.all_snapshots, + start=plan.start, + end=plan.end, + execution_time=plan.execution_time, + selected_models=plan.selected_models, + ) - As a part of plan pushing, snapshot tables are created. + def visit_create_snapshot_records_stage( + self, stage: stages.CreateSnapshotRecordsStage, plan: EvaluatablePlan + ) -> None: + self.state_sync.push_snapshots(stage.snapshots) + analytics.collector.on_snapshots_created( + new_snapshots=stage.snapshots, plan_id=plan.plan_id + ) + # Update the intervals for the new forward-only snapshots + self._update_intervals_for_new_snapshots(stage.snapshots) - Args: - plan: The plan to source snapshots from. - deployability_index: Indicates which snapshots are deployable in the context of this creation. 
- """ - snapshots_to_create = [ - s - for s in plan.snapshots.values() - if s.is_model and not s.is_symbolic and plan.is_selected_for_backfill(s.name) - ] - snapshots_to_create_count = len(snapshots_to_create) + def visit_physical_layer_update_stage( + self, stage: stages.PhysicalLayerUpdateStage, plan: EvaluatablePlan + ) -> None: + skip_message = "" if plan.restatements else "\nSKIP: No physical layer updates to perform" - if snapshots_to_create_count > 0: - self.console.start_creation_progress( - snapshots_to_create_count, plan.environment_naming_info, self.default_catalog - ) + snapshots_to_create = stage.snapshots + if not snapshots_to_create: + self.console.log_success(skip_message) + return - completed = False + completion_status = None + progress_stopped = False try: - self.snapshot_evaluator.create( + completion_status = self.snapshot_evaluator.create( snapshots_to_create, - plan.snapshots, + stage.all_snapshots, allow_destructive_snapshots=plan.allow_destructive_models, - deployability_index=deployability_index, + allow_additive_snapshots=plan.allow_additive_models, + deployability_index=stage.deployability_index, + on_start=lambda x: self.console.start_creation_progress( + x, plan.environment, self.default_catalog + ), on_complete=self.console.update_creation_progress, ) - completed = True + if completion_status.is_nothing_to_do: + self.console.log_success(skip_message) + return + except SnapshotCreationFailedError as ex: + self.console.stop_creation_progress(success=False) + progress_stopped = True + + for error in ex.errors: + logger.info(str(error), exc_info=error) + + self.console.log_skipped_models({s.name for s in ex.skipped}) + self.console.log_failed_models(ex.errors) + + raise PlanError("Plan application failed.") finally: - self.console.stop_creation_progress(success=completed) + if not progress_stopped: + self.console.stop_creation_progress( + success=completion_status is not None and completion_status.is_success + ) - 
self.state_sync.push_snapshots(plan.new_snapshots) + def visit_physical_layer_schema_creation_stage( + self, stage: stages.PhysicalLayerSchemaCreationStage, plan: EvaluatablePlan + ) -> None: + try: + self.snapshot_evaluator.create_physical_schemas( + stage.snapshots, stage.deployability_index + ) + except Exception as ex: + raise PlanError("Plan application failed.") from ex + + def visit_backfill_stage(self, stage: stages.BackfillStage, plan: EvaluatablePlan) -> None: + if plan.empty_backfill: + intervals_to_add = [] + for snapshot in stage.all_snapshots.values(): + if not snapshot.evaluatable or not plan.is_selected_for_backfill(snapshot.name): + # Skip snapshots that are not evaluatable or not selected for backfill. + continue + intervals = [ + snapshot.inclusive_exclusive(plan.start, plan.end, strict=False, expand=False) + ] + is_deployable = stage.deployability_index.is_deployable(snapshot) + intervals_to_add.append( + SnapshotIntervals( + name=snapshot.name, + identifier=snapshot.identifier, + version=snapshot.version, + dev_version=snapshot.dev_version, + intervals=intervals if is_deployable else [], + dev_intervals=intervals if not is_deployable else [], + ) + ) + self.state_sync.add_snapshots_intervals(intervals_to_add) + self.console.log_success("SKIP: No model batches to execute") + return - analytics.collector.on_snapshots_created( - new_snapshots=plan.new_snapshots, plan_id=plan.plan_id + if not stage.snapshot_to_intervals: + self.console.log_success("SKIP: No model batches to execute") + return + + scheduler = self.create_scheduler(stage.all_snapshots.values(), self.snapshot_evaluator) + errors, _ = scheduler.run_merged_intervals( + merged_intervals=stage.snapshot_to_intervals, + deployability_index=stage.deployability_index, + environment_naming_info=plan.environment.naming_info, + execution_time=plan.execution_time, + circuit_breaker=self._circuit_breaker, + start=plan.start, + end=plan.end, + 
allow_destructive_snapshots=plan.allow_destructive_models, + allow_additive_snapshots=plan.allow_additive_models, + selected_snapshot_ids=stage.selected_snapshot_ids, + selected_models=plan.selected_models, + is_restatement=bool(plan.restatements), ) + if errors: + raise PlanError("Plan application failed.") - def _promote( - self, plan: Plan, no_gaps_snapshot_names: t.Optional[t.Set[str]] = None - ) -> PromotionResult: - """Promote a plan. + def visit_audit_only_run_stage( + self, stage: stages.AuditOnlyRunStage, plan: EvaluatablePlan + ) -> None: + audit_snapshots = stage.snapshots + if not audit_snapshots: + return - Args: - plan: The plan to promote. - no_gaps_snapshot_names: The names of snapshots to check for gaps if the no gaps check is enabled in the plan. - If not provided, all snapshots are checked. - """ - promotion_result = self.state_sync.promote( + # If there are any snapshots to be audited, we'll reuse the scheduler's internals to audit them + scheduler = self.create_scheduler(audit_snapshots, self.snapshot_evaluator) + completion_status = scheduler.audit( plan.environment, - no_gaps_snapshot_names=no_gaps_snapshot_names if plan.no_gaps else set(), + plan.start, + plan.end, + execution_time=plan.execution_time, + end_bounded=plan.end_bounded, + start_override_per_model=plan.start_override_per_model, + end_override_per_model=plan.end_override_per_model, ) - if not plan.is_dev: - self.snapshot_evaluator.migrate( - [s for s in plan.snapshots.values() if s.is_paused], - plan.snapshots, - plan.allow_destructive_models, + if completion_status.is_failure: + raise PlanError("Plan application failed.") + + def visit_restatement_stage( + self, stage: stages.RestatementStage, plan: EvaluatablePlan + ) -> None: + # Restating intervals on prod plans means that once the data for the intervals being restated has been backfilled + # (which happens in the backfill stage) then we need to clear those intervals *from state* across all other environments. 
+ # + # This ensures that work done in dev environments can still be promoted to prod by forcing dev environments to + # re-run intervals that changed in prod (because after this stage runs they are cleared from state and thus show as missing) + # + # It also means that any new dev environments created while this restatement plan was running also get the + # correct intervals cleared because we look up matching snapshots as at right now and not as at the time the plan + # was created, which could have been several hours ago if there was a lot of data to restate. + # + # Without this rule, its possible that promoting a dev table to prod will introduce old data to prod + + intervals_to_clear = identify_restatement_intervals_across_snapshot_versions( + state_reader=self.state_sync, + prod_restatements=plan.restatements, + disable_restatement_models=plan.disabled_restatement_models, + loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()}, + current_ts=to_timestamp(plan.execution_time or now()), + ) + + if not intervals_to_clear: + # Nothing to do + return + + # While the restatements were being processed, did any of the snapshots being restated get new versions deployed? 
+ # If they did, they will not reflect the data that just got restated, so we need to notify the user + deployed_during_restatement: t.Dict[ + str, t.Tuple[SnapshotTableInfo, SnapshotTableInfo] + ] = {} # tuple of (restated_snapshot, current_prod_snapshot) + + if deployed_env := self.state_sync.get_environment(plan.environment.name): + promoted_snapshots_by_name = {s.name: s for s in deployed_env.snapshots} + + for name in plan.restatements: + snapshot = stage.all_snapshots[name] + version = snapshot.table_info.version + if ( + prod_snapshot := promoted_snapshots_by_name.get(name) + ) and prod_snapshot.version != version: + deployed_during_restatement[name] = ( + snapshot.table_info, + prod_snapshot.table_info, + ) + + # we need to *not* clear the intervals on the snapshots where new versions were deployed while the restatement was running in order to prevent + # subsequent plans from having unexpected intervals to backfill. + # we instead list the affected models and abort the plan with an error so the user can decide what to do + # (either re-attempt the restatement plan or leave things as they are) + filtered_intervals_to_clear = [ + (s.snapshot, s.interval) + for s in intervals_to_clear.values() + if s.snapshot.name not in deployed_during_restatement + ] + + if filtered_intervals_to_clear: + # We still clear intervals in other envs for models that were successfully restated without having new versions promoted during restatement + self.state_sync.remove_intervals( + snapshot_intervals=filtered_intervals_to_clear, + remove_shared_versions=plan.is_prod, ) - if not plan.ensure_finalized_snapshots: - # Only unpause at this point if we don't have to use the finalized snapshots - # for subsequent plan applications. Otherwise, unpause right before finalizing - # the environment. 
- self.state_sync.unpause_snapshots(promotion_result.added, plan.end) - return promotion_result + if deployed_env and deployed_during_restatement: + self.console.log_models_updated_during_restatement( + list(deployed_during_restatement.values()), + plan.environment.naming_info, + self.default_catalog, + ) + raise ConflictingPlanError( + f"Another plan ({deployed_env.summary.plan_id}) deployed new versions of {len(deployed_during_restatement)} models in the target environment '{plan.environment.name}' while they were being restated by this plan.\n" + "Please re-apply your plan if these new versions should be restated." + ) - def _update_views( - self, - plan: Plan, - promotion_result: PromotionResult, - deployability_index: t.Optional[DeployabilityIndex] = None, + def visit_environment_record_update_stage( + self, stage: stages.EnvironmentRecordUpdateStage, plan: EvaluatablePlan ) -> None: - """Update environment views. + self.state_sync.promote( + plan.environment, + no_gaps_snapshot_names=stage.no_gaps_snapshot_names if plan.no_gaps else set(), + environment_statements=plan.environment_statements, + ) - Args: - plan: The plan to promote. - promotion_result: The result of the promotion. - deployability_index: Indicates which snapshots are deployable in the context of this promotion. - """ - if not plan.is_dev and plan.ensure_finalized_snapshots: - # Unpause right before finalizing the environment in case when - # we need to use the finalized snapshots for subsequent plan applications. - # Otherwise, unpause right after updatig the environment record. 
- self.state_sync.unpause_snapshots(promotion_result.added, plan.end) + def visit_migrate_schemas_stage( + self, stage: stages.MigrateSchemasStage, plan: EvaluatablePlan + ) -> None: + try: + self.snapshot_evaluator.migrate( + stage.snapshots, + stage.all_snapshots, + allow_destructive_snapshots=plan.allow_destructive_models, + allow_additive_snapshots=plan.allow_additive_models, + deployability_index=stage.deployability_index, + ) + except NodeExecutionFailedError as ex: + raise PlanError(str(ex.__cause__) if ex.__cause__ else str(ex)) + def visit_unpause_stage(self, stage: stages.UnpauseStage, plan: EvaluatablePlan) -> None: + self.state_sync.unpause_snapshots(stage.promoted_snapshots, plan.end) + + def visit_virtual_layer_update_stage( + self, stage: stages.VirtualLayerUpdateStage, plan: EvaluatablePlan + ) -> None: environment = plan.environment self.console.start_promotion_progress( - len(promotion_result.added) + len(promotion_result.removed), + list(stage.promoted_snapshots) + list(stage.demoted_snapshots), environment.naming_info, self.default_catalog, ) completed = False try: - self.snapshot_evaluator.promote( - [plan.context_diff.snapshots[s.snapshot_id] for s in promotion_result.added], + self._promote_snapshots( + plan, + [stage.all_snapshots[s.snapshot_id] for s in stage.promoted_snapshots], environment.naming_info, - deployability_index=deployability_index, + deployability_index=stage.deployability_index, on_complete=lambda s: self.console.update_promotion_progress(s, True), + snapshots=stage.all_snapshots, ) - if promotion_result.removed_environment_naming_info: - self.snapshot_evaluator.demote( - promotion_result.removed, - promotion_result.removed_environment_naming_info, + if stage.demoted_environment_naming_info: + self._demote_snapshots( + [stage.all_snapshots[s.snapshot_id] for s in stage.demoted_snapshots], + stage.demoted_environment_naming_info, + deployability_index=stage.deployability_index, on_complete=lambda s: 
self.console.update_promotion_progress(s, False), + snapshots=stage.all_snapshots, ) - self.state_sync.finalize(environment) + completed = True finally: self.console.stop_promotion_progress(success=completed) - def _restate(self, plan: Plan) -> None: - if not plan.restatements: - return - - self.state_sync.remove_interval( - [ - (plan.context_diff.snapshots[s_id], interval) - for s_id, interval in plan.restatements.items() - ], - remove_shared_versions=not plan.is_dev, - ) - + def visit_finalize_environment_stage( + self, stage: stages.FinalizeEnvironmentStage, plan: EvaluatablePlan + ) -> None: + self.state_sync.finalize(plan.environment) -class BaseAirflowPlanEvaluator(PlanEvaluator): - def __init__( + def _promote_snapshots( self, - console: t.Optional[Console], - blocking: bool, - dag_run_poll_interval_secs: int, - dag_creation_poll_interval_secs: int, - dag_creation_max_retry_attempts: int, - ): - self.blocking = blocking - self.dag_run_poll_interval_secs = dag_run_poll_interval_secs - self.dag_creation_poll_interval_secs = dag_creation_poll_interval_secs - self.dag_creation_max_retry_attempts = dag_creation_max_retry_attempts - self.console = console or get_console() - - def evaluate( - self, plan: Plan, circuit_breaker: t.Optional[t.Callable[[], bool]] = None + plan: EvaluatablePlan, + target_snapshots: t.Iterable[Snapshot], + environment_naming_info: EnvironmentNamingInfo, + snapshots: t.Dict[SnapshotId, Snapshot], + deployability_index: t.Optional[DeployabilityIndex] = None, + on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None, ) -> None: - plan_request_id = plan.plan_id - self._apply_plan(plan, plan_request_id) - - analytics.collector.on_plan_apply_start( - plan=plan, - engine_type=None, - state_sync_type=None, - scheduler_type=c.AIRFLOW, - ) - - if self.blocking: - plan_application_dag_id = airflow_common.plan_application_dag_id( - plan.environment_naming_info.name, plan_request_id - ) - - self.console.log_status_update( - f"Waiting for 
the plan application DAG '{plan_application_dag_id}' to be provisioned on Airflow" - ) - - plan_application_dag_run_id = self.client.wait_for_first_dag_run( - plan_application_dag_id, - self.dag_creation_poll_interval_secs, - self.dag_creation_max_retry_attempts, - ) - - self.client.print_tracking_url( - plan_application_dag_id, - plan_application_dag_run_id, - "plan application", - ) - plan_application_succeeded = self.client.wait_for_dag_run_completion( - plan_application_dag_id, - plan_application_dag_run_id, - self.dag_run_poll_interval_secs, - ) - if not plan_application_succeeded: - raise SQLMeshError("Plan application failed.") - - self.console.log_success("The plan has been applied successfully") - - @property - def client(self) -> BaseAirflowClient: - raise NotImplementedError - - def _apply_plan(self, plan: Plan, plan_request_id: str) -> None: - raise NotImplementedError - - -class StateBasedAirflowPlanEvaluator(BaseAirflowPlanEvaluator): - backfill_concurrent_tasks: int - ddl_concurrent_tasks: int - notification_targets: t.Optional[t.List[NotificationTarget]] - users: t.Optional[t.List[User]] - - def _apply_plan(self, plan: Plan, plan_request_id: str) -> None: - from sqlmesh.schedulers.airflow.plan import PlanDagState, create_plan_dag_spec - - plan_application_request = airflow_common.PlanApplicationRequest( - new_snapshots=plan.new_snapshots, - environment=plan.environment, - no_gaps=plan.no_gaps, - skip_backfill=plan.skip_backfill, - request_id=plan_request_id, - restatements={s.name: i for s, i in (plan.restatements or {}).items()}, - notification_targets=self.notification_targets or [], - backfill_concurrent_tasks=self.backfill_concurrent_tasks, - ddl_concurrent_tasks=self.ddl_concurrent_tasks, - users=self.users or [], - is_dev=plan.is_dev, - forward_only=plan.forward_only, - models_to_backfill=plan.models_to_backfill, - end_bounded=plan.end_bounded, - ensure_finalized_snapshots=plan.ensure_finalized_snapshots, - 
directly_modified_snapshots=list(plan.directly_modified), - indirectly_modified_snapshots={ - change_source.name: list(snapshots) - for change_source, snapshots in plan.indirectly_modified.items() - }, - removed_snapshots=list(plan.context_diff.removed_snapshots), - execution_time=plan.execution_time, - allow_destructive_snapshots=plan.allow_destructive_models, + self.snapshot_evaluator.promote( + target_snapshots, + start=plan.start, + end=plan.end, + execution_time=plan.execution_time or now(), + snapshots=snapshots, + table_mapping=to_view_mapping( + snapshots.values(), + environment_naming_info, + default_catalog=self.default_catalog, + dialect=self.snapshot_evaluator.adapter.dialect, + ), + environment_naming_info=environment_naming_info, + deployability_index=deployability_index, + on_complete=on_complete, ) - plan_dag_spec = create_plan_dag_spec(plan_application_request, self.state_sync) - PlanDagState.from_state_sync(self.state_sync).add_dag_spec(plan_dag_spec) - - @property - def state_sync(self) -> StateSync: - raise NotImplementedError - -class AirflowPlanEvaluator(StateBasedAirflowPlanEvaluator): - def __init__( + def _demote_snapshots( self, - airflow_client: AirflowClient, - console: t.Optional[Console] = None, - blocking: bool = True, - dag_run_poll_interval_secs: int = 10, - dag_creation_poll_interval_secs: int = 30, - dag_creation_max_retry_attempts: int = 10, - notification_targets: t.Optional[t.List[NotificationTarget]] = None, - backfill_concurrent_tasks: int = 1, - ddl_concurrent_tasks: int = 1, - users: t.Optional[t.List[User]] = None, - state_sync: t.Optional[StateSync] = None, - ): - super().__init__( - console, - blocking, - dag_run_poll_interval_secs, - dag_creation_poll_interval_secs, - dag_creation_max_retry_attempts, - ) - self._airflow_client = airflow_client - self.notification_targets = notification_targets or [] - self.backfill_concurrent_tasks = backfill_concurrent_tasks - self.ddl_concurrent_tasks = ddl_concurrent_tasks - 
self.users = users or [] - - self._state_sync = state_sync - - @property - def client(self) -> BaseAirflowClient: - return self._airflow_client - - @property - def state_sync(self) -> StateSync: - if self._state_sync is None: - raise SQLMeshError("State Sync is not configured") - return self._state_sync - - def _apply_plan(self, plan: Plan, plan_request_id: str) -> None: - if self._state_sync is not None: - super()._apply_plan(plan, plan_request_id) - return - - self._airflow_client.apply_plan( - plan.new_snapshots, - plan.environment, - plan_request_id, - no_gaps=plan.no_gaps, - skip_backfill=plan.skip_backfill, - restatements=plan.restatements, - notification_targets=self.notification_targets, - backfill_concurrent_tasks=self.backfill_concurrent_tasks, - ddl_concurrent_tasks=self.ddl_concurrent_tasks, - users=self.users, - is_dev=plan.is_dev, - forward_only=plan.forward_only, - models_to_backfill=plan.models_to_backfill, - end_bounded=plan.end_bounded, - ensure_finalized_snapshots=plan.ensure_finalized_snapshots, - directly_modified_snapshots=list(plan.directly_modified), - indirectly_modified_snapshots={ - change_source.name: list(snapshots) - for change_source, snapshots in plan.indirectly_modified.items() - }, - removed_snapshots=list(plan.context_diff.removed_snapshots), - execution_time=plan.execution_time, - allow_destructive_snapshots=plan.allow_destructive_models, + target_snapshots: t.Iterable[Snapshot], + environment_naming_info: EnvironmentNamingInfo, + snapshots: t.Dict[SnapshotId, Snapshot], + deployability_index: t.Optional[DeployabilityIndex] = None, + on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None, + ) -> None: + self.snapshot_evaluator.demote( + target_snapshots, + environment_naming_info, + table_mapping=to_view_mapping( + snapshots.values(), + environment_naming_info, + default_catalog=self.default_catalog, + dialect=self.snapshot_evaluator.adapter.dialect, + ), + deployability_index=deployability_index, + 
on_complete=on_complete, ) + def _update_intervals_for_new_snapshots(self, snapshots: t.Collection[Snapshot]) -> None: + snapshots_intervals: t.List[SnapshotIntervals] = [] + for snapshot in snapshots: + if snapshot.is_forward_only: + snapshots_intervals.append( + SnapshotIntervals( + name=snapshot.name, + identifier=snapshot.identifier, + version=snapshot.version, + dev_version=snapshot.dev_version, + dev_intervals=snapshot.dev_intervals, + ) + ) -class MWAAPlanEvaluator(StateBasedAirflowPlanEvaluator): - def __init__( - self, - client: MWAAClient, - state_sync: StateSync, - console: t.Optional[Console] = None, - blocking: bool = True, - dag_run_poll_interval_secs: int = 10, - dag_creation_poll_interval_secs: int = 30, - dag_creation_max_retry_attempts: int = 10, - notification_targets: t.Optional[t.List[NotificationTarget]] = None, - backfill_concurrent_tasks: int = 1, - ddl_concurrent_tasks: int = 1, - users: t.Optional[t.List[User]] = None, - ): - super().__init__( - console, - blocking, - dag_run_poll_interval_secs, - dag_creation_poll_interval_secs, - dag_creation_max_retry_attempts, - ) - self._mwaa_client = client - self._state_sync = state_sync - self.notification_targets = notification_targets or [] - self.backfill_concurrent_tasks = backfill_concurrent_tasks - self.ddl_concurrent_tasks = ddl_concurrent_tasks - self.users = users or [] - - @property - def client(self) -> BaseAirflowClient: - return self._mwaa_client - - @property - def state_sync(self) -> StateSync: - return self._state_sync - - -def update_intervals_for_new_snapshots( - snapshots: t.Collection[Snapshot], state_sync: StateSync -) -> None: - for snapshot in state_sync.refresh_snapshot_intervals(snapshots): - if snapshot.is_forward_only: - snapshot.dev_intervals = snapshot.intervals.copy() - for start, end in snapshot.dev_intervals: - state_sync.add_interval(snapshot, start, end, is_dev=True) + if snapshots_intervals: + self.state_sync.add_snapshots_intervals(snapshots_intervals) diff --git 
a/sqlmesh/core/plan/explainer.py b/sqlmesh/core/plan/explainer.py new file mode 100644 index 0000000000..f0a1e44aff --- /dev/null +++ b/sqlmesh/core/plan/explainer.py @@ -0,0 +1,386 @@ +from __future__ import annotations + +import abc +import typing as t +import logging +from dataclasses import dataclass +from collections import defaultdict + +from rich.console import Console as RichConsole +from rich.tree import Tree +from sqlglot.dialects.dialect import DialectType +from sqlmesh.core import constants as c +from sqlmesh.core.console import Console, TerminalConsole, get_console +from sqlmesh.core.environment import EnvironmentNamingInfo +from sqlmesh.core.plan.common import ( + SnapshotIntervalClearRequest, + identify_restatement_intervals_across_snapshot_versions, +) +from sqlmesh.core.plan.definition import EvaluatablePlan, SnapshotIntervals +from sqlmesh.core.plan import stages +from sqlmesh.core.plan.evaluator import ( + PlanEvaluator, +) +from sqlmesh.core.state_sync import StateReader +from sqlmesh.core.snapshot.definition import ( + SnapshotInfoMixin, + SnapshotIdAndVersion, + model_display_name, +) +from sqlmesh.utils import Verbosity, rich as srich, to_snake_case +from sqlmesh.utils.date import to_ts +from sqlmesh.utils.errors import SQLMeshError + + +logger = logging.getLogger(__name__) + + +class PlanExplainer(PlanEvaluator): + def __init__( + self, + state_reader: StateReader, + default_catalog: t.Optional[str], + console: t.Optional[Console] = None, + ): + self.state_reader = state_reader + self.default_catalog = default_catalog + self.console = console or get_console() + + def evaluate( + self, + plan: EvaluatablePlan, + circuit_breaker: t.Optional[t.Callable[[], bool]] = None, + ) -> None: + plan_stages = stages.build_plan_stages(plan, self.state_reader, self.default_catalog) + explainer_console = _get_explainer_console( + self.console, plan.environment, self.default_catalog + ) + + # add extra metadata that's only needed at this point for better 
--explain output + plan_stages = [ + ExplainableRestatementStage.from_restatement_stage(stage, self.state_reader, plan) + if isinstance(stage, stages.RestatementStage) + else stage + for stage in plan_stages + ] + + explainer_console.explain(plan_stages) + + +class ExplainerConsole(abc.ABC): + @abc.abstractmethod + def explain(self, stages: t.List[stages.PlanStage]) -> None: + pass + + +@dataclass +class ExplainableRestatementStage(stages.RestatementStage): + """ + This brings forward some calculations that would usually be done in the evaluator so the user can be given a better indication + of what might happen when they ask for the plan to be explained + """ + + snapshot_intervals_to_clear: t.Dict[str, t.List[SnapshotIntervalClearRequest]] + """Which snapshots from other environments would have intervals cleared as part of restatement, grouped by name.""" + + @classmethod + def from_restatement_stage( + cls: t.Type[ExplainableRestatementStage], + stage: stages.RestatementStage, + state_reader: StateReader, + plan: EvaluatablePlan, + ) -> ExplainableRestatementStage: + all_restatement_intervals = identify_restatement_intervals_across_snapshot_versions( + state_reader=state_reader, + prod_restatements=plan.restatements, + disable_restatement_models=plan.disabled_restatement_models, + loaded_snapshots={s.snapshot_id: s for s in stage.all_snapshots.values()}, + ) + + # Group the interval clear requests by snapshot name to make them easier to write to the console + snapshot_intervals_to_clear = defaultdict(list) + for clear_request in all_restatement_intervals.values(): + snapshot_intervals_to_clear[clear_request.snapshot.name].append(clear_request) + + return cls( + snapshot_intervals_to_clear=snapshot_intervals_to_clear, + all_snapshots=stage.all_snapshots, + ) + + +MAX_TREE_LENGTH = 10 + + +class RichExplainerConsole(ExplainerConsole): + def __init__( + self, + environment_naming_info: EnvironmentNamingInfo, + dialect: DialectType, + default_catalog: 
t.Optional[str], + verbosity: Verbosity = Verbosity.DEFAULT, + console: t.Optional[RichConsole] = None, + ): + self.environment_naming_info = environment_naming_info + self.dialect = dialect + self.default_catalog = default_catalog + self.verbosity = verbosity + self.console: RichConsole = console or srich.console + + def explain(self, stages: t.List[stages.PlanStage]) -> None: + tree = Tree("[bold]Explained plan[/bold]") + for stage in stages: + handler_name = f"visit_{to_snake_case(stage.__class__.__name__)}" + if not hasattr(self, handler_name): + logger.error("Unexpected stage: %s", stage.__class__.__name__) + continue + handler = getattr(self, handler_name) + result = handler(stage) + if result: + tree.add(self._limit_tree(result)) + self.console.print(tree) + + def visit_before_all_stage(self, stage: stages.BeforeAllStage) -> Tree: + return Tree("[bold]Execute before all statements[/bold]") + + def visit_after_all_stage(self, stage: stages.AfterAllStage) -> Tree: + return Tree("[bold]Execute after all statements[/bold]") + + def visit_physical_layer_update_stage(self, stage: stages.PhysicalLayerUpdateStage) -> Tree: + snapshots = [ + s for s in stage.snapshots if s.snapshot_id in stage.snapshots_with_missing_intervals + ] + if not snapshots: + return Tree("[bold]SKIP: No physical layer updates to perform[/bold]") + + tree = Tree( + "[bold]Validate SQL and create physical layer tables and views if they do not exist[/bold]" + ) + for snapshot in snapshots: + is_deployable = ( + stage.deployability_index.is_deployable(snapshot) + if self.environment_naming_info.name != c.PROD + else True + ) + display_name = self._display_name(snapshot) + table_name = snapshot.table_name(is_deployable) + model_tree = Tree(f"{display_name} -> {table_name}") + + if snapshot.is_model: + if snapshot.model.pre_statements: + model_tree.add("Run pre-statements") + if snapshot.model.annotated: + model_tree.add("Dry run model query without inserting results") + + if snapshot.is_view: + 
create_tree = Tree("Create view if it doesn't exist") + elif ( + snapshot.is_forward_only and snapshot.previous_versions and not snapshot.is_managed + ): + prod_table = snapshot.table_name(True) + create_tree = Tree( + f"Clone {prod_table} into {table_name} and then update its schema if it doesn't exist" + ) + else: + create_tree = Tree("Create table if it doesn't exist") + + if not is_deployable: + create_tree.add("[orange1]preview[/orange1]: data will NOT be reused in production") + model_tree.add(create_tree) + + if snapshot.is_model and snapshot.model.post_statements: + model_tree.add("Run post-statements") + + tree.add(model_tree) + return tree + + def visit_audit_only_run_stage(self, stage: stages.AuditOnlyRunStage) -> Tree: + tree = Tree("[bold]Audit-only execution[/bold]") + for snapshot in stage.snapshots: + display_name = self._display_name(snapshot) + tree.add(display_name) + return tree + + def visit_explainable_restatement_stage(self, stage: ExplainableRestatementStage) -> Tree: + return self.visit_restatement_stage(stage) + + def visit_restatement_stage( + self, stage: t.Union[ExplainableRestatementStage, stages.RestatementStage] + ) -> Tree: + tree = Tree( + "[bold]Invalidate data intervals in state for development environments to prevent old data from being promoted[/bold]\n" + "This only affects state and will not clear physical data from the tables until the next plan for each environment" + ) + + if isinstance(stage, ExplainableRestatementStage) and ( + snapshot_intervals := stage.snapshot_intervals_to_clear + ): + for name, clear_requests in snapshot_intervals.items(): + display_name = model_display_name( + name, self.environment_naming_info, self.default_catalog, self.dialect + ) + interval_start = min(cr.interval[0] for cr in clear_requests) + interval_end = max(cr.interval[1] for cr in clear_requests) + + if not interval_start or not interval_end: + continue + + node = tree.add(f"{display_name} [{to_ts(interval_start)} - 
{to_ts(interval_end)}]") + + all_environment_names = sorted( + set(env_name for cr in clear_requests for env_name in cr.environment_names) + ) + node.add("in environments: " + ", ".join(all_environment_names)) + + return tree + + def visit_backfill_stage(self, stage: stages.BackfillStage) -> Tree: + if not stage.snapshot_to_intervals: + return Tree("[bold]SKIP: No model batches to execute[/bold]") + + tree = Tree( + "[bold]Backfill models by running their queries and run standalone audits[/bold]" + ) + for snapshot, intervals in stage.snapshot_to_intervals.items(): + display_name = self._display_name(snapshot) + if snapshot.is_model: + is_deployable = stage.deployability_index.is_deployable(snapshot) + table_name = snapshot.table_name(is_deployable) + model_tree = Tree(f"{display_name} -> {table_name}") + + for signal_name, _ in snapshot.model.signals: + model_tree.add(f"Check '{signal_name}' signal") + + if snapshot.model.pre_statements: + model_tree.add("Run pre-statements") + + backfill_tree = Tree("Fully refresh table") + if snapshot.is_incremental: + current_intervals = ( + snapshot.intervals + if stage.deployability_index.is_deployable(snapshot) + else snapshot.dev_intervals + ) + # If there are no intervals, the table will be fully refreshed + if current_intervals: + formatted_range = SnapshotIntervals( + snapshot_id=snapshot.snapshot_id, intervals=intervals + ).format_intervals(snapshot.node.interval_unit) + backfill_tree = Tree( + f"Incrementally insert records within the range [{formatted_range}]" + ) + elif snapshot.is_view: + backfill_tree = Tree("Recreate view") + + if not is_deployable: + backfill_tree.add( + "[orange1]preview[/orange1]: data will NOT be reused in production" + ) + + model_tree.add(backfill_tree) + + if snapshot.model.post_statements: + model_tree.add("Run post-statements") + + if snapshot.model.audits: + for audit_name, _ in snapshot.model.audits: + model_tree.add(f"Run '{audit_name}' audit") + + tree.add(model_tree) + else: + 
tree.add(f"{display_name} \\[standalone audit]") + return tree + + def visit_migrate_schemas_stage(self, stage: stages.MigrateSchemasStage) -> Tree: + tree = Tree( + "[bold]Update schemas (add, drop, alter columns) of production physical tables to reflect forward-only changes[/bold]" + ) + for snapshot in stage.snapshots: + display_name = self._display_name(snapshot) + table_name = snapshot.table_name(True) + tree.add(f"{display_name} -> {table_name}") + return tree + + def visit_virtual_layer_update_stage(self, stage: stages.VirtualLayerUpdateStage) -> Tree: + tree = Tree( + f"[bold]Update the virtual layer for environment '{self.environment_naming_info.name}'[/bold]" + ) + promote_tree = Tree( + "[bold]Create or update views in the virtual layer to point at new physical tables and views[/bold]" + ) + for snapshot in stage.promoted_snapshots: + display_name = self._display_name(snapshot) + table_name = snapshot.table_name(stage.deployability_index.is_representative(snapshot)) + promote_tree.add(f"{display_name} -> {table_name}") + + demote_tree = Tree( + "[bold]Delete views in the virtual layer for models that were removed[/bold]" + ) + for snapshot in stage.demoted_snapshots: + display_name = self._display_name(snapshot, stage.demoted_environment_naming_info) + demote_tree.add(display_name) + + if stage.promoted_snapshots: + tree.add(self._limit_tree(promote_tree)) + if stage.demoted_snapshots: + tree.add(self._limit_tree(demote_tree)) + return tree + + def visit_create_snapshot_records_stage( + self, stage: stages.CreateSnapshotRecordsStage + ) -> t.Optional[Tree]: + return None + + def visit_environment_record_update_stage( + self, stage: stages.EnvironmentRecordUpdateStage + ) -> t.Optional[Tree]: + return None + + def visit_unpause_stage(self, stage: stages.UnpauseStage) -> t.Optional[Tree]: + return None + + def visit_finalize_environment_stage( + self, stage: stages.FinalizeEnvironmentStage + ) -> t.Optional[Tree]: + return None + + def _display_name( + 
self,
+        snapshot: t.Union[SnapshotInfoMixin, SnapshotIdAndVersion],
+        environment_naming_info: t.Optional[EnvironmentNamingInfo] = None,
+    ) -> str:
+        return snapshot.display_name(
+            environment_naming_info=environment_naming_info or self.environment_naming_info,
+            default_catalog=self.default_catalog
+            if self.verbosity < Verbosity.VERY_VERBOSE
+            else None,
+            dialect=self.dialect,
+        )
+
+    def _limit_tree(self, tree: Tree) -> Tree:
+        tree_length = len(tree.children)
+        if tree_length <= MAX_TREE_LENGTH:
+            return tree
+        if self.verbosity < Verbosity.VERY_VERBOSE:
+            tree.children = [
+                tree.children[0],
+                Tree(f".... {tree_length - 2} more ...."),
+                tree.children[-1],
+            ]
+        return tree
+
+
+def _get_explainer_console(
+    console: t.Optional[Console],
+    environment_naming_info: EnvironmentNamingInfo,
+    default_catalog: t.Optional[str],
+) -> ExplainerConsole:
+    console = console or get_console()
+    if not isinstance(console, TerminalConsole):
+        raise SQLMeshError("Plain explanation is only supported in the terminal.")
+    return RichExplainerConsole(
+        environment_naming_info=environment_naming_info,
+        dialect=console.dialect,
+        default_catalog=default_catalog,
+        verbosity=console.verbosity,
+        console=console.console,
+    )
diff --git a/sqlmesh/core/plan/stages.py b/sqlmesh/core/plan/stages.py
new file mode 100644
index 0000000000..729e1705b4
--- /dev/null
+++ b/sqlmesh/core/plan/stages.py
@@ -0,0 +1,703 @@
+import typing as t
+
+from dataclasses import dataclass
+from sqlmesh.core import constants as c
+from sqlmesh.core.environment import EnvironmentStatements, EnvironmentNamingInfo, Environment
+from sqlmesh.core.plan.common import should_force_rebuild
+from sqlmesh.core.plan.definition import EvaluatablePlan
+from sqlmesh.core.state_sync import StateReader
+from sqlmesh.core.scheduler import merged_missing_intervals, SnapshotToIntervals
+from sqlmesh.core.snapshot.definition import (
+    DeployabilityIndex,
+    Snapshot,
+    SnapshotTableInfo,
+    SnapshotId,
+    snapshots_to_dag,
+)
+from sqlmesh.utils.errors import PlanError
+
+
+@dataclass
+class BeforeAllStage:
+    """Run environment statements before every other stage.
+
+    Args:
+        statements: Environment statements to run before every other stage.
+        all_snapshots: All snapshots in the plan by name.
+    """
+
+    statements: t.List[EnvironmentStatements]
+    all_snapshots: t.Dict[str, Snapshot]
+
+
+@dataclass
+class AfterAllStage:
+    """Run environment statements after all other stages.
+
+    Args:
+        statements: Environment statements to run after all other stages.
+        all_snapshots: All snapshots in the plan by name.
+    """
+
+    statements: t.List[EnvironmentStatements]
+    all_snapshots: t.Dict[str, Snapshot]
+
+
+@dataclass
+class CreateSnapshotRecordsStage:
+    """Create new snapshot records in the state.
+
+    Args:
+        snapshots: New snapshots to create records for.
+    """
+
+    snapshots: t.List[Snapshot]
+
+
+@dataclass
+class PhysicalLayerUpdateStage:
+    """Update the physical layer by creating physical tables and views for given snapshots.
+
+    Args:
+        snapshots: Snapshots to create physical tables and views for. This collection can be empty in which case
+            no physical layer update is needed. This can be useful to report the lack of physical layer updates
+            back to the user.
+        all_snapshots: All snapshots in the plan by snapshot ID.
+        snapshots_with_missing_intervals: Snapshots that have missing intervals.
+        deployability_index: Deployability index for this stage.
+    """
+
+    snapshots: t.List[Snapshot]
+    all_snapshots: t.Dict[SnapshotId, Snapshot]
+    snapshots_with_missing_intervals: t.Set[SnapshotId]
+    deployability_index: DeployabilityIndex
+
+
+@dataclass
+class PhysicalLayerSchemaCreationStage:
+    """Create the physical schemas for the given snapshots.
+
+    Args:
+        snapshots: Snapshots to create physical schemas for.
+        deployability_index: Deployability index for this stage.
+ """ + + snapshots: t.List[Snapshot] + deployability_index: DeployabilityIndex + + +@dataclass +class AuditOnlyRunStage: + """Run audits only for given snapshots. + + Args: + snapshots: Snapshots to run audits for. + """ + + snapshots: t.List[Snapshot] + + +@dataclass +class RestatementStage: + """Clear intervals from state for snapshots in *other* environments, when restatements are requested in prod. + + This stage is effectively a "marker" stage to trigger the plan evaluator to perform the "clear intervals" logic after the BackfillStage has completed. + The "clear intervals" logic is executed just-in-time using the latest state available in order to pick up new snapshots that may have + been created while the BackfillStage was running, which is why we do not build a list of snapshots to clear at plan time and defer to evaluation time. + + Note that this stage is only present on `prod` plans because dev plans do not need to worry about clearing intervals in other environments. + + Args: + all_snapshots: All snapshots in the plan by name. Note that this does not include the snapshots from other environments that will get their + intervals cleared, it's included here as an optimization to prevent having to re-fetch the current plan's snapshots + """ + + all_snapshots: t.Dict[str, Snapshot] + + +@dataclass +class BackfillStage: + """Backfill given missing intervals. + + Args: + snapshot_to_intervals: Intervals to backfill. This collection can be empty in which case no backfill is needed. + This can be useful to report the lack of backfills back to the user. + selected_snapshot_ids: The snapshots to include in the run DAG. + all_snapshots: All snapshots in the plan by name. + deployability_index: Deployability index for this stage. + before_promote: Whether this stage is before the promotion stage. 
+ """ + + snapshot_to_intervals: SnapshotToIntervals + selected_snapshot_ids: t.Set[SnapshotId] + all_snapshots: t.Dict[str, Snapshot] + deployability_index: DeployabilityIndex + before_promote: bool = True + + +@dataclass +class EnvironmentRecordUpdateStage: + """Update the environment record in the state. + + Args: + no_gaps_snapshot_names: Names of snapshots for which there should be no interval gaps. + """ + + no_gaps_snapshot_names: t.Set[str] + + +@dataclass +class MigrateSchemasStage: + """Migrate schemas of physical tables for given snapshots. + + Args: + snapshots: Snapshots to migrate schemas for. + all_snapshots: All snapshots in the plan by snapshot ID. + deployability_index: Deployability index for this stage. + """ + + snapshots: t.List[Snapshot] + all_snapshots: t.Dict[SnapshotId, Snapshot] + deployability_index: DeployabilityIndex + + +@dataclass +class VirtualLayerUpdateStage: + """Update the virtual layer by creating and deleting views for given snapshots. + + Args: + promoted_snapshots: Snapshots to create views for. + demoted_snapshots: Snapshots to delete views for. + demoted_environment_naming_info: Environment naming info of the previous environment record. + all_snapshots: All snapshots in the plan by snapshot ID. + deployability_index: Deployability index for this stage. + """ + + promoted_snapshots: t.Set[SnapshotTableInfo] + demoted_snapshots: t.Set[SnapshotTableInfo] + demoted_environment_naming_info: t.Optional[EnvironmentNamingInfo] + all_snapshots: t.Dict[SnapshotId, Snapshot] + deployability_index: DeployabilityIndex + + +@dataclass +class UnpauseStage: + """Unpause given snapshots that are being deployed to prod. + + Args: + promoted_snapshots: Snapshots to unpause. + """ + + promoted_snapshots: t.Set[SnapshotTableInfo] + + +@dataclass +class FinalizeEnvironmentStage: + """Finalize the enviornment record in the state. 
+ + Finalization means that all stages have been applied and that the environment has been transitioned + to the new state successfully. This should be the last stage in the plan application process. + """ + + pass + + +PlanStage = t.Union[ + BeforeAllStage, + AfterAllStage, + CreateSnapshotRecordsStage, + PhysicalLayerUpdateStage, + PhysicalLayerSchemaCreationStage, + AuditOnlyRunStage, + RestatementStage, + BackfillStage, + EnvironmentRecordUpdateStage, + MigrateSchemasStage, + VirtualLayerUpdateStage, + UnpauseStage, + FinalizeEnvironmentStage, +] + + +class PlanStagesBuilder: + """The builder for the plan stages. + + Args: + state_reader: The state reader to use to read the snapshots and environment. + default_catalog: The default catalog to use for the snapshots. + """ + + def __init__( + self, + state_reader: StateReader, + default_catalog: t.Optional[str], + ): + self.state_reader = state_reader + self.default_catalog = default_catalog + + def build(self, plan: EvaluatablePlan) -> t.List[PlanStage]: + """Builds the plan stages for the given plan. + + NOTE: Building the plan stages should NOT produce any side effects in the state or the data warehouse. + + Args: + plan: The plan to build the stages for. + + Returns: + A list of plan stages. 
+ """ + new_snapshots = {s.snapshot_id: s for s in plan.new_snapshots} + stored_snapshots = self.state_reader.get_snapshots(plan.environment.snapshots) + snapshots = {**new_snapshots, **stored_snapshots} + snapshots_by_name = {s.name: s for s in snapshots.values()} + dag = snapshots_to_dag(snapshots.values()) + + all_selected_for_backfill_snapshots = { + s.snapshot_id for s in snapshots.values() if plan.is_selected_for_backfill(s.name) + } + existing_environment = self.state_reader.get_environment(plan.environment.name) + + self._adjust_intervals(snapshots_by_name, plan, existing_environment) + + deployability_index = DeployabilityIndex.create(snapshots, start=plan.start) + if plan.is_dev: + before_promote_snapshots = all_selected_for_backfill_snapshots + after_promote_snapshots = set() + snapshots_with_schema_migration = [] + else: + before_promote_snapshots = { + s.snapshot_id + for s in snapshots.values() + if (deployability_index.is_representative(s) or s.is_seed) + and plan.is_selected_for_backfill(s.name) + } + after_promote_snapshots = all_selected_for_backfill_snapshots - before_promote_snapshots + deployability_index = DeployabilityIndex.all_deployable() + + snapshot_ids_with_schema_migration = [ + s.snapshot_id for s in snapshots.values() if s.requires_schema_migration_in_prod + ] + # Include all upstream dependencies of snapshots that require schema migration to make sure + # the upstream tables are created before the schema updates are applied + snapshots_with_schema_migration = [ + snapshots[s_id] + for s_id in dag.subdag(*snapshot_ids_with_schema_migration) + if snapshots[s_id].supports_schema_migration_in_prod + ] + + snapshots_to_intervals = self._missing_intervals( + plan, snapshots_by_name, deployability_index + ) + needs_backfill = ( + not plan.empty_backfill and not plan.skip_backfill and bool(snapshots_to_intervals) + ) + missing_intervals_before_promote: SnapshotToIntervals = {} + missing_intervals_after_promote: SnapshotToIntervals = {} + if 
needs_backfill: + for snapshot, intervals in snapshots_to_intervals.items(): + if snapshot.snapshot_id in before_promote_snapshots: + missing_intervals_before_promote[snapshot] = intervals + elif snapshot.snapshot_id in after_promote_snapshots: + missing_intervals_after_promote[snapshot] = intervals + + promoted_snapshots, demoted_snapshots, demoted_environment_naming_info = ( + self._get_promoted_demoted_snapshots(plan, existing_environment) + ) + + stages: t.List[PlanStage] = [] + + before_all_stage = self._get_before_all_stage(plan, snapshots_by_name) + if before_all_stage: + stages.append(before_all_stage) + + if plan.new_snapshots: + stages.append(CreateSnapshotRecordsStage(snapshots=plan.new_snapshots)) + + snapshots_to_create = self._get_snapshots_to_create(plan, snapshots) + if snapshots_to_create: + stages.append( + PhysicalLayerSchemaCreationStage( + snapshots=snapshots_to_create, deployability_index=deployability_index + ) + ) + if not needs_backfill: + stages.append( + self._get_physical_layer_update_stage( + plan, + snapshots_to_create, + snapshots, + snapshots_to_intervals, + deployability_index, + ) + ) + + audit_only_snapshots = self._get_audit_only_snapshots(new_snapshots) + if audit_only_snapshots: + stages.append(AuditOnlyRunStage(snapshots=list(audit_only_snapshots.values()))) + + if missing_intervals_before_promote: + stages.append( + BackfillStage( + snapshot_to_intervals=missing_intervals_before_promote, + selected_snapshot_ids={ + s_id + for s_id in before_promote_snapshots + if plan.is_selected_for_backfill(s_id.name) + }, + all_snapshots=snapshots_by_name, + deployability_index=deployability_index, + ) + ) + elif not needs_backfill: + # Append an empty backfill stage so that explainer can show that the stage is skipped + stages.append( + BackfillStage( + snapshot_to_intervals={}, + selected_snapshot_ids=set(), + all_snapshots=snapshots_by_name, + deployability_index=deployability_index, + ) + ) + + # note: "restatement stage" (which is 
clearing intervals in state - not actually performing the restatements, that's the backfill stage) + # needs to come *after* the backfill stage so that at no time do other plans / runs see empty prod intervals and compete with this plan to try to fill them. + # in addition, when we update intervals in state, we only clear intervals from dev snapshots to force dev models to be backfilled based on the new prod data. + # we can leave prod intervals alone because by the time this plan finishes, the intervals in state have not actually changed, since restatement replaces + # data for existing intervals and does not produce new ones + restatement_stage = self._get_restatement_stage(plan, snapshots_by_name) + if restatement_stage: + stages.append(restatement_stage) + + stages.append( + EnvironmentRecordUpdateStage( + no_gaps_snapshot_names={s.name for s in before_promote_snapshots} + ) + ) + + if snapshots_with_schema_migration: + stages.append( + MigrateSchemasStage( + snapshots=snapshots_with_schema_migration, + all_snapshots=snapshots, + deployability_index=deployability_index, + ) + ) + + if not plan.is_dev and not plan.ensure_finalized_snapshots and promoted_snapshots: + # Only unpause at this point if we don't have to use the finalized snapshots + # for subsequent plan applications. Otherwise, unpause right before updating + # the virtual layer. 
+ stages.append(UnpauseStage(promoted_snapshots=promoted_snapshots)) + + if missing_intervals_after_promote: + stages.append( + BackfillStage( + snapshot_to_intervals=missing_intervals_after_promote, + selected_snapshot_ids={ + s_id + for s_id in after_promote_snapshots + if plan.is_selected_for_backfill(s_id.name) + }, + all_snapshots=snapshots_by_name, + deployability_index=deployability_index, + ) + ) + + if not plan.is_dev and plan.ensure_finalized_snapshots and promoted_snapshots: + # Unpause right before updating the virtual layer and finalizing the environment in case when + # we need to use the finalized snapshots for subsequent plan applications. + # Otherwise, unpause right after updatig the environment record. + stages.append(UnpauseStage(promoted_snapshots=promoted_snapshots)) + + full_demoted_snapshots = self.state_reader.get_snapshots( + s.snapshot_id for s in demoted_snapshots if s.snapshot_id not in snapshots + ) + virtual_layer_update_stage = self._get_virtual_layer_update_stage( + promoted_snapshots, + demoted_snapshots, + demoted_environment_naming_info, + snapshots | full_demoted_snapshots, + deployability_index, + plan.is_dev, + ) + if virtual_layer_update_stage: + stages.append(virtual_layer_update_stage) + + stages.append(FinalizeEnvironmentStage()) + + after_all_stage = self._get_after_all_stage(plan, snapshots_by_name) + if after_all_stage: + stages.append(after_all_stage) + + return stages + + def _get_before_all_stage( + self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapshot] + ) -> t.Optional[BeforeAllStage]: + before_all = [ + environment_statements + for environment_statements in plan.environment_statements or [] + if environment_statements.before_all + ] + return ( + BeforeAllStage(statements=before_all, all_snapshots=snapshots_by_name) + if before_all + else None + ) + + def _get_after_all_stage( + self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapshot] + ) -> t.Optional[AfterAllStage]: + after_all = [ + 
environment_statements + for environment_statements in plan.environment_statements or [] + if environment_statements.after_all + ] + return ( + AfterAllStage(statements=after_all, all_snapshots=snapshots_by_name) + if after_all + else None + ) + + def _get_restatement_stage( + self, plan: EvaluatablePlan, snapshots_by_name: t.Dict[str, Snapshot] + ) -> t.Optional[RestatementStage]: + if plan.restate_all_snapshots: + if plan.is_dev: + raise PlanError( + "Clearing intervals from state across dev model versions is only valid for prod plans" + ) + + if plan.restatements: + return RestatementStage( + all_snapshots=snapshots_by_name, + ) + + return None + + def _get_physical_layer_update_stage( + self, + plan: EvaluatablePlan, + snapshots_to_create: t.List[Snapshot], + all_snapshots: t.Dict[SnapshotId, Snapshot], + snapshots_to_intervals: SnapshotToIntervals, + deployability_index: DeployabilityIndex, + ) -> PhysicalLayerUpdateStage: + return PhysicalLayerUpdateStage( + snapshots=snapshots_to_create, + all_snapshots=all_snapshots, + snapshots_with_missing_intervals={ + s.snapshot_id + for s in snapshots_to_intervals + if plan.is_selected_for_backfill(s.name) + }, + deployability_index=deployability_index, + ) + + def _get_virtual_layer_update_stage( + self, + promoted_snapshots: t.Set[SnapshotTableInfo], + demoted_snapshots: t.Set[SnapshotTableInfo], + demoted_environment_naming_info: t.Optional[EnvironmentNamingInfo], + all_snapshots: t.Dict[SnapshotId, Snapshot], + deployability_index: DeployabilityIndex, + is_dev: bool, + ) -> t.Optional[VirtualLayerUpdateStage]: + def _should_update_virtual_layer(snapshot: SnapshotTableInfo) -> bool: + # Skip virtual layer update for snapshots with virtual environment support disabled + virtual_environment_enabled = is_dev or snapshot.virtual_environment_mode.is_full + return snapshot.is_model and not snapshot.is_symbolic and virtual_environment_enabled + + promoted_snapshots = {s for s in promoted_snapshots if 
_should_update_virtual_layer(s)} + demoted_snapshots = {s for s in demoted_snapshots if _should_update_virtual_layer(s)} + if not promoted_snapshots and not demoted_snapshots: + return None + + return VirtualLayerUpdateStage( + promoted_snapshots=promoted_snapshots, + demoted_snapshots=demoted_snapshots, + demoted_environment_naming_info=demoted_environment_naming_info, + all_snapshots=all_snapshots, + deployability_index=deployability_index, + ) + + def _get_promoted_demoted_snapshots( + self, plan: EvaluatablePlan, existing_environment: t.Optional[Environment] + ) -> t.Tuple[ + t.Set[SnapshotTableInfo], t.Set[SnapshotTableInfo], t.Optional[EnvironmentNamingInfo] + ]: + if existing_environment: + new_table_infos = { + table_info.name: table_info for table_info in plan.environment.promoted_snapshots + } + existing_table_infos = { + table_info.name: table_info + for table_info in existing_environment.promoted_snapshots + } + views_that_changed_location = { + existing_table_info + for existing_table_info in existing_environment.promoted_snapshots + if existing_table_info.name in new_table_infos + and existing_table_info.qualified_view_name.for_environment( + existing_environment.naming_info + ) + != new_table_infos[existing_table_info.name].qualified_view_name.for_environment( + plan.environment.naming_info + ) + } + missing_model_names = set(existing_table_infos) - { + s.name for s in plan.environment.promoted_snapshots + } + demoted_snapshots = { + existing_table_infos[name] for name in missing_model_names + } | views_that_changed_location + else: + demoted_snapshots = set() + + promoted_snapshots = set(plan.environment.promoted_snapshots) + if existing_environment and plan.environment.can_partially_promote(existing_environment): + promoted_snapshots -= set(existing_environment.promoted_snapshots) + + demoted_environment_naming_info = ( + existing_environment.naming_info if demoted_snapshots and existing_environment else None + ) + + return ( + promoted_snapshots, 
+ demoted_snapshots, + demoted_environment_naming_info, + ) + + def _missing_intervals( + self, + plan: EvaluatablePlan, + snapshots_by_name: t.Dict[str, Snapshot], + deployability_index: DeployabilityIndex, + ) -> SnapshotToIntervals: + return merged_missing_intervals( + snapshots=snapshots_by_name.values(), + start=plan.start, + end=plan.end, + execution_time=plan.execution_time, + restatements={ + snapshots_by_name[name].snapshot_id: interval + for name, interval in plan.restatements.items() + }, + deployability_index=deployability_index, + end_bounded=plan.end_bounded, + ignore_cron=plan.ignore_cron, + start_override_per_model=plan.start_override_per_model, + end_override_per_model=plan.end_override_per_model, + ) + + def _get_audit_only_snapshots( + self, new_snapshots: t.Dict[SnapshotId, Snapshot] + ) -> t.Dict[SnapshotId, Snapshot]: + metadata_snapshots = [] + for snapshot in new_snapshots.values(): + if ( + not snapshot.is_metadata + or not snapshot.is_model + or not snapshot.evaluatable + or not snapshot.previous_version + ): + continue + + metadata_snapshots.append(snapshot) + + # Bulk load all the previous snapshots + previous_snapshot_ids = [ + s.previous_version.snapshot_id(s.name) for s in metadata_snapshots if s.previous_version + ] + previous_snapshots = { + s.name: s for s in self.state_reader.get_snapshots(previous_snapshot_ids).values() + } + + # Check if any of the snapshots have modifications to the audits field by comparing the hashes + audit_snapshots = {} + for snapshot in metadata_snapshots: + if snapshot.name not in previous_snapshots: + continue + + previous_snapshot = previous_snapshots[snapshot.name] + new_audits_hash = snapshot.model.audit_metadata_hash() + previous_audit_hash = previous_snapshot.model.audit_metadata_hash() + + if snapshot.model.audits and previous_audit_hash != new_audits_hash: + audit_snapshots[snapshot.snapshot_id] = snapshot + + return audit_snapshots + + def _get_snapshots_to_create( + self, plan: EvaluatablePlan, 
snapshots: t.Dict[SnapshotId, Snapshot] + ) -> t.List[Snapshot]: + promoted_snapshot_ids = ( + set(plan.environment.promoted_snapshot_ids) + if plan.environment.promoted_snapshot_ids is not None + else None + ) + + def _should_create(s: Snapshot) -> bool: + if not s.is_model or s.is_symbolic: + return False + # Only create tables for snapshots that we're planning to promote or that were selected for backfill + return ( + plan.is_selected_for_backfill(s.name) + or promoted_snapshot_ids is None + or s.snapshot_id in promoted_snapshot_ids + ) + + return [s for s in snapshots.values() if _should_create(s)] + + def _adjust_intervals( + self, + snapshots_by_name: t.Dict[str, Snapshot], + plan: EvaluatablePlan, + existing_environment: t.Optional[Environment], + ) -> None: + # Make sure the intervals are up to date and restatements are reflected + self.state_reader.refresh_snapshot_intervals(snapshots_by_name.values()) + + if not existing_environment: + existing_environment = self.state_reader.get_environment(c.PROD) + + if existing_environment: + new_snapshot_ids = set() + new_snapshot_versions = set() + for s in snapshots_by_name.values(): + if s.is_model: + new_snapshot_ids.add(s.snapshot_id) + new_snapshot_versions.add(s.name_version) + # Only compare to old snapshots that share the same version as the new snapshots + old_snapshot_ids = { + s.snapshot_id + for s in existing_environment.snapshots + if s.is_model + and s.name_version in new_snapshot_versions + and s.snapshot_id not in new_snapshot_ids + } + if old_snapshot_ids: + old_snapshots = self.state_reader.get_snapshots(old_snapshot_ids) + for old in old_snapshots.values(): + new = snapshots_by_name.get(old.name) + if not new or old.version != new.version: + continue + if should_force_rebuild(old, new): + # If the difference between 2 snapshots requires a full rebuild, + # then clear the intervals for the new snapshot. 
+ new.intervals = [] + + for new_snapshot in plan.new_snapshots: + if new_snapshot.is_forward_only: + # Forward-only snapshots inherit intervals in dev because of cloning + new_snapshot.dev_intervals = new_snapshot.intervals.copy() + for s_name, interval in plan.restatements.items(): + snapshots_by_name[s_name].remove_interval(interval) + + +def build_plan_stages( + plan: EvaluatablePlan, + state_reader: StateReader, + default_catalog: t.Optional[str], +) -> t.List[PlanStage]: + return PlanStagesBuilder(state_reader, default_catalog).build(plan) diff --git a/sqlmesh/core/renderer.py b/sqlmesh/core/renderer.py index 95d3528846..50c1faeb63 100644 --- a/sqlmesh/core/renderer.py +++ b/sqlmesh/core/renderer.py @@ -3,10 +3,12 @@ import logging import typing as t from contextlib import contextmanager +from functools import partial from pathlib import Path -from sqlglot import exp, parse +from sqlglot import exp, Dialect from sqlglot.errors import SqlglotError +from sqlglot.helper import ensure_list from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.qualify import qualify from sqlglot.optimizer.simplify import simplify @@ -14,20 +16,29 @@ from sqlmesh.core import constants as c from sqlmesh.core import dialect as d from sqlmesh.core.macros import MacroEvaluator, RuntimeStage -from sqlmesh.utils.date import TimeLike, date_dict, make_inclusive_end, to_datetime +from sqlmesh.utils.date import ( + TimeLike, + date_dict, + make_inclusive, + to_datetime, + make_ts_exclusive, + to_tstz, +) from sqlmesh.utils.errors import ( ConfigError, - MacroEvalError, ParsetimeAdapterCallError, SQLMeshError, raise_config_error, ) -from sqlmesh.utils.jinja import JinjaMacroRegistry +from sqlmesh.utils.jinja import JinjaMacroRegistry, extract_error_details from sqlmesh.utils.metaprogramming import Executable, prepare_env if t.TYPE_CHECKING: from sqlglot._typing import E + from sqlglot.dialects.dialect import DialectType + from sqlmesh.core.linter.rule import Rule 
+ from sqlmesh.core.model.definition import _Model from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot @@ -38,17 +49,18 @@ class BaseExpressionRenderer: def __init__( self, expression: exp.Expression, - dialect: str, + dialect: DialectType, macro_definitions: t.List[d.MacroDef], - path: Path = Path(), + path: t.Optional[Path] = None, jinja_macro_registry: t.Optional[JinjaMacroRegistry] = None, python_env: t.Optional[t.Dict[str, Executable]] = None, only_execution_time: bool = False, schema: t.Optional[t.Dict[str, t.Any]] = None, default_catalog: t.Optional[str] = None, quote_identifiers: bool = True, - model_fqn: t.Optional[str] = None, normalize_identifiers: bool = True, + optimize_query: t.Optional[bool] = True, + model: t.Optional[_Model] = None, ): self._expression = expression self._dialect = dialect @@ -62,7 +74,9 @@ def __init__( self._quote_identifiers = quote_identifiers self.update_schema({} if schema is None else schema) self._cache: t.List[t.Optional[exp.Expression]] = [] - self._model_fqn = model_fqn + self._model_fqn = model.fqn if model else None + self._optimize_query_flag = optimize_query is not False + self._model = model def update_schema(self, schema: t.Dict[str, t.Any]) -> None: self.schema = d.normalize_mapping_schema(schema, dialect=self._dialect) @@ -101,62 +115,55 @@ def _render( if should_cache and self._cache: return self._cache - if self._model_fqn and "this_model" not in kwargs: - kwargs["this_model"] = exp.to_table( - self._to_table_mapping( - ( - [snapshots[self._model_fqn]] - if snapshots and self._model_fqn in snapshots - else [] - ), - deployability_index, - ).get(self._model_fqn, self._model_fqn), - dialect=self._dialect, - ).sql(dialect=self._dialect, identify=True) - - expressions = [self._expression] - - render_kwargs = { - **date_dict( - to_datetime(execution_time or c.EPOCH), - to_datetime(start or c.EPOCH) if not self._only_execution_time else None, - make_inclusive_end(end or c.EPOCH) if not 
self._only_execution_time else None, - ), - **kwargs, - } - - variables = kwargs.pop("variables", {}) - jinja_env = self._jinja_macro_registry.build_environment( - **{**render_kwargs, **prepare_env(self._python_env), **variables}, - snapshots=(snapshots or {}), - table_mapping=table_mapping, - deployability_index=deployability_index, - default_catalog=self._default_catalog, - runtime_stage=runtime_stage.value, - ) - - if isinstance(self._expression, d.Jinja): - try: - expressions = [] - rendered_expression = jinja_env.from_string(self._expression.name).render() - if rendered_expression.strip(): - expressions = [e for e in parse(rendered_expression, read=self._dialect) if e] + environment_naming_info = kwargs.get("environment_naming_info") + if environment_naming_info is not None: + kwargs["this_env"] = getattr(environment_naming_info, "name") + if snapshots: + schemas, views = set(), [] + for snapshot in snapshots.values(): + if snapshot.is_model and not snapshot.is_symbolic: + schemas.add( + snapshot.qualified_view_name.schema_for_environment( + environment_naming_info, dialect=self._dialect + ) + ) + views.append( + snapshot.display_name( + environment_naming_info, self._default_catalog, self._dialect + ) + ) + if schemas: + kwargs["schemas"] = list(schemas) + if views: + kwargs["views"] = views + + this_model = kwargs.pop("this_model", None) + + this_snapshot = (snapshots or {}).get(self._model_fqn) if self._model_fqn else None + if not this_model and self._model_fqn: + this_model = self._resolve_table( + self._model_fqn, + snapshots={self._model_fqn: this_snapshot} if this_snapshot else None, + deployability_index=deployability_index, + table_mapping=table_mapping, + ) + if this_snapshot and (kind := this_snapshot.model_kind_name): + kwargs["model_kind_name"] = kind.name - if not expressions: - raise ConfigError(f"Failed to parse an expression:\n{self._expression}") - except ParsetimeAdapterCallError: - raise - except Exception as ex: - raise ConfigError( - 
f"Could not render or parse jinja at '{self._path}'.\n{ex}" - ) from ex + def _resolve_table(table: str | exp.Table) -> str: + return self._resolve_table( + d.normalize_model_name(table, self._default_catalog, self._dialect), + snapshots=snapshots, + table_mapping=table_mapping, + deployability_index=deployability_index, + ).sql(dialect=self._dialect, identify=True, comments=False) macro_evaluator = MacroEvaluator( self._dialect, python_env=self._python_env, - jinja_env=jinja_env, schema=self.schema, runtime_stage=runtime_stage, + resolve_table=_resolve_table, resolve_tables=lambda e: self._resolve_tables( e, snapshots=snapshots, @@ -170,28 +177,124 @@ def _render( snapshots=snapshots, default_catalog=self._default_catalog, path=self._path, + environment_naming_info=environment_naming_info, + model_fqn=self._model_fqn, ) - for definition in self._macro_definitions: - try: - macro_evaluator.evaluate(definition) - except MacroEvalError as ex: - raise_config_error(f"Failed to evaluate macro '{definition}'. 
{ex}", self._path) + start_time, end_time = ( + make_inclusive(start or c.EPOCH, end or c.EPOCH, self._dialect) + if not self._only_execution_time + else (None, None) + ) + + render_kwargs = { + **date_dict( + to_datetime(execution_time or c.EPOCH), + start_time, + end_time, + ), + **kwargs, + } + + if this_model: + render_kwargs["this_model"] = this_model macro_evaluator.locals.update(render_kwargs) + variables = kwargs.pop("variables", {}) if variables: macro_evaluator.locals.setdefault(c.SQLMESH_VARS, {}).update(variables) + expressions = [self._expression] + if isinstance(self._expression, d.Jinja): + try: + jinja_env_kwargs = { + **{ + **render_kwargs, + **_prepare_python_env_for_jinja(macro_evaluator, self._python_env), + **variables, + }, + "snapshots": snapshots or {}, + "table_mapping": table_mapping, + "deployability_index": deployability_index, + "default_catalog": self._default_catalog, + "runtime_stage": runtime_stage.value, + "resolve_table": _resolve_table, + "model_instance": self._model, + } + + if this_model: + jinja_env_kwargs["this_model"] = this_model.sql( + dialect=self._dialect, identify=True, comments=False + ) + + if self._model and self._model.kind.is_incremental_by_time_range: + all_refs = list( + self._jinja_macro_registry.global_objs.get("sources", {}).values() # type: ignore + ) + list( + self._jinja_macro_registry.global_objs.get("refs", {}).values() # type: ignore + ) + for ref in all_refs: + if ref.event_time_filter: + ref.event_time_filter["start"] = render_kwargs["start_tstz"] + ref.event_time_filter["end"] = to_tstz( + make_ts_exclusive(render_kwargs["end_tstz"], dialect=self._dialect) + ) + + jinja_env = self._jinja_macro_registry.build_environment(**jinja_env_kwargs) + + expressions = [] + rendered_expression = jinja_env.from_string(self._expression.name).render() + logger.debug( + f"Rendered Jinja expression for model '{self._model_fqn}' at '{self._path}': '{rendered_expression}'" + ) + except ParsetimeAdapterCallError: + 
raise + except Exception as ex: + raise ConfigError( + f"Could not render jinja for '{self._path}'.\n" + extract_error_details(ex) + ) from ex + + if rendered_expression.strip(): + # ensure there is actual SQL and not just comments and non-SQL jinja + dialect = Dialect.get_or_raise(self._dialect) + tokens = dialect.tokenize(rendered_expression) + + if tokens: + try: + expressions = [ + e for e in dialect.parser().parse(tokens, rendered_expression) if e + ] + + if not expressions: + raise ConfigError( + f"Failed to parse an expression:\n{rendered_expression}" + ) + except Exception as ex: + raise ConfigError( + f"Could not parse the rendered jinja at '{self._path}'.\n{ex}" + ) from ex + + for definition in self._macro_definitions: + try: + macro_evaluator.evaluate(definition) + except Exception as ex: + raise_config_error( + f"Failed to evaluate macro '{definition}'.\n\n{ex}\n", self._path + ) + resolved_expressions: t.List[t.Optional[exp.Expression]] = [] for expression in expressions: try: - expression = macro_evaluator.transform(expression) # type: ignore - except MacroEvalError as ex: - raise_config_error(f"Failed to resolve macro for expression. 
{ex}", self._path) + transformed_expressions = ensure_list(macro_evaluator.transform(expression)) + except Exception as ex: + raise_config_error( + f"Failed to resolve macros for\n\n{expression.sql(dialect=self._dialect, pretty=True)}\n\n{ex}\n", + self._path, + ) - if expression: + for expression in t.cast(t.List[exp.Expression], transformed_expressions): with self._normalize_and_quote(expression) as expression: if hasattr(expression, "selects"): for select in expression.selects: @@ -220,6 +323,28 @@ def _render( def update_cache(self, expression: t.Optional[exp.Expression]) -> None: self._cache = [expression] + def _resolve_table( + self, + table_name: str | exp.Expression, + snapshots: t.Optional[t.Dict[str, Snapshot]] = None, + table_mapping: t.Optional[t.Dict[str, str]] = None, + deployability_index: t.Optional[DeployabilityIndex] = None, + ) -> exp.Table: + table = exp.replace_tables( + exp.maybe_parse(table_name, into=exp.Table, dialect=self._dialect), + { + **self._to_table_mapping((snapshots or {}).values(), deployability_index), + **(table_mapping or {}), + }, + dialect=self._dialect, + copy=False, + ) + # We quote the table here to mimic the behavior of _resolve_tables, otherwise we may end + # up normalizing twice, because _to_table_mapping returns the mapped names unquoted. 
+ return ( + d.quote_identifiers(table, dialect=self._dialect) if self._quote_identifiers else table + ) + def _resolve_tables( self, expression: E, @@ -280,8 +405,7 @@ def _expand(node: exp.Expression) -> exp.Expression: alias=node.alias or model.view_name, copy=False, ) - else: - logger.warning("Failed to expand the nested model '%s'", name) + logger.warning("Failed to expand the nested model '%s'", name) return node expression = expression.transform(_expand, copy=False) # type: ignore @@ -333,6 +457,7 @@ def render( execution_time=execution_time, snapshots=snapshots, deployability_index=deployability_index, + table_mapping=table_mapping, **kwargs, ) except ParsetimeAdapterCallError: @@ -355,10 +480,45 @@ def render( ] +def render_statements( + statements: t.List[str], + dialect: str, + default_catalog: t.Optional[str] = None, + python_env: t.Optional[t.Dict[str, Executable]] = None, + jinja_macros: t.Optional[JinjaMacroRegistry] = None, + **render_kwargs: t.Any, +) -> t.List[str]: + rendered_statements: t.List[str] = [] + for statement in statements: + for expression in d.parse(statement, default_dialect=dialect): + if expression: + rendered = ExpressionRenderer( + expression, + dialect, + [], + jinja_macro_registry=jinja_macros, + python_env=python_env, + default_catalog=default_catalog, + quote_identifiers=False, + normalize_identifiers=False, + ).render(**render_kwargs) + + if not rendered: + # Warning instead of raising for cases where a statement is conditionally executed + logger.warning( + f"Rendering `{expression.sql(dialect=dialect)}` did not return an expression" + ) + else: + rendered_statements.extend(expr.sql(dialect=dialect) for expr in rendered) + + return rendered_statements + + class QueryRenderer(BaseExpressionRenderer): def __init__(self, *args: t.Any, **kwargs: t.Any): super().__init__(*args, **kwargs) self._optimized_cache: t.Optional[exp.Query] = None + self._violated_rules: t.Dict[type[Rule], t.Any] = {} def update_schema(self, schema: 
t.Dict[str, t.Any]) -> None: super().update_schema(schema) @@ -373,7 +533,7 @@ def render( table_mapping: t.Optional[t.Dict[str, str]] = None, deployability_index: t.Optional[DeployabilityIndex] = None, expand: t.Iterable[str] = tuple(), - optimize: bool = True, + needs_optimization: bool = True, runtime_stage: RuntimeStage = RuntimeStage.LOADING, **kwargs: t.Any, ) -> t.Optional[exp.Query]: @@ -390,7 +550,8 @@ def render( expand: Expand referenced models as subqueries. This is used to bypass backfills when running queries that depend on materialized tables. Model definitions are inlined and can thus be run end to end on the fly. - optimize: Whether to optimize the query. + needs_optimization: Whether or not an optimization should be attempted + (if passing False, it still may return a cached optimized query). runtime_stage: Indicates the current runtime stage, for example if we're still loading the project, etc. kwargs: Additional kwargs to pass to the renderer. @@ -402,7 +563,7 @@ def render( runtime_stage, start, end, execution_time, *kwargs.values() ) - if should_cache and self._optimized_cache and optimize: + if should_cache and self._optimized_cache: query = self._optimized_cache else: try: @@ -419,7 +580,14 @@ def render( except ParsetimeAdapterCallError: return None + expressions = [e for e in expressions if not isinstance(e, exp.Semicolon)] + if not expressions: + # We assume that if there are no expressions, then the model contains dynamic Jinja SQL + # and we thus treat it similar to models with adapter calls to match dbt's behavior. 
+ if isinstance(self._expression, d.JinjaQuery): + return None + raise ConfigError(f"Failed to render query at '{self._path}':\n{self._expression}") if len(expressions) > 1: @@ -430,10 +598,12 @@ def render( if not query: return None if not isinstance(query, exp.Query): - raise_config_error(f"Query needs to be a SELECT or a UNION {query}.", self._path) + raise_config_error( + f"Model query needs to be a SELECT or a UNION, got {query}.", self._path + ) raise - if optimize: + if needs_optimization and self._optimize_query_flag: deps = d.find_tables( query, default_catalog=self._default_catalog, dialect=self._dialect ) @@ -443,7 +613,7 @@ def render( if should_cache: self._optimized_cache = query - if optimize: + if needs_optimization: query = self._resolve_tables( query, snapshots=snapshots, @@ -459,7 +629,12 @@ def render( return query - def update_cache(self, expression: t.Optional[exp.Expression], optimized: bool = False) -> None: + def update_cache( + self, + expression: t.Optional[exp.Expression], + violated_rules: t.Optional[t.Dict[type[Rule], t.Any]] = None, + optimized: bool = False, + ) -> None: if optimized: if not isinstance(expression, exp.Query): raise SQLMeshError(f"Expected a Query but got: {expression}") @@ -467,7 +642,14 @@ def update_cache(self, expression: t.Optional[exp.Expression], optimized: bool = else: super().update_cache(expression) + self._violated_rules = violated_rules or {} + def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query: + from sqlmesh.core.linter.rules.builtin import ( + AmbiguousOrInvalidColumn, + InvalidSelectStarExpansion, + ) + # We don't want to normalize names in the schema because that's handled by the optimizer original = query missing_deps = set() @@ -481,12 +663,7 @@ def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query: if self._model_fqn and not should_optimize and any(s.is_star for s in query.selects): deps = ", ".join(f"'{dep}'" for dep in sorted(missing_deps)) - - 
logger.warning( - f"SELECT * cannot be expanded due to missing schema(s) for model(s): {deps}. " - "Run `sqlmesh create_external_models` and / or make sure that the model " - f"'{self._model_fqn}' can be rendered at parse time.", - ) + self._violated_rules[InvalidSelectStarExpansion] = deps try: if should_optimize: @@ -502,13 +679,19 @@ def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query: quote_identifiers=self._quote_identifiers, ), schema=self.schema, - ) + dialect=self._dialect, + ), + dialect=self._dialect, ) except SqlglotError as ex: + self._violated_rules[AmbiguousOrInvalidColumn] = ex + query = original - logger.warning( - "%s for model '%s', the column may not exist or is ambiguous", ex, self._model_fqn + except Exception as ex: + raise_config_error( + f"Failed to optimize query, please file an issue at https://github.com/SQLMesh/sqlmesh/issues/new. {ex}", + self._path, ) if not query.type: @@ -516,3 +699,15 @@ def _optimize_query(self, query: exp.Query, all_deps: t.Set[str]) -> exp.Query: annotate_types(select) return query + + +def _prepare_python_env_for_jinja( + evaluator: MacroEvaluator, + python_env: t.Dict[str, Executable], +) -> t.Dict[str, t.Any]: + prepared_env = prepare_env(python_env) + # Pass the evaluator to all macro functions + return { + key: partial(value, evaluator) if callable(value) else value + for key, value in prepared_env.items() + } diff --git a/sqlmesh/core/scheduler.py b/sqlmesh/core/scheduler.py index 6364b8425c..5eb0ff40ff 100644 --- a/sqlmesh/core/scheduler.py +++ b/sqlmesh/core/scheduler.py @@ -1,14 +1,17 @@ from __future__ import annotations - +from dataclasses import dataclass import abc import logging -import traceback import typing as t +import time from datetime import datetime - +from sqlglot import exp from sqlmesh.core import constants as c from sqlmesh.core.console import Console, get_console -from sqlmesh.core.environment import EnvironmentNamingInfo +from sqlmesh.core.environment import 
EnvironmentNamingInfo, execute_environment_statements +from sqlmesh.core.macros import RuntimeStage +from sqlmesh.core.model.definition import AuditResult +from sqlmesh.core.node import IntervalUnit from sqlmesh.core.notification_target import ( NotificationEvent, NotificationTargetManager, @@ -16,66 +19,81 @@ from sqlmesh.core.snapshot import ( DeployabilityIndex, Snapshot, + SnapshotId, + SnapshotIdBatch, SnapshotEvaluator, + apply_auto_restatements, earliest_start_date, missing_intervals, + merge_intervals, + snapshots_to_dag, + Intervals, +) +from sqlmesh.core.snapshot.definition import check_ready_intervals +from sqlmesh.core.snapshot.definition import ( + Interval, + expand_range, + parent_snapshots_by_name, ) -from sqlmesh.core.snapshot.definition import Interval as SnapshotInterval -from sqlmesh.core.snapshot.definition import SnapshotId from sqlmesh.core.state_sync import StateSync -from sqlmesh.utils import format_exception -from sqlmesh.utils.concurrency import concurrent_apply_to_dag +from sqlmesh.utils import CompletionStatus +from sqlmesh.utils.concurrency import concurrent_apply_to_dag, NodeExecutionFailedError from sqlmesh.utils.dag import DAG from sqlmesh.utils.date import ( TimeLike, - now, now_timestamp, - to_datetime, - to_timestamp, validate_date_range, ) -from sqlmesh.utils.errors import AuditError, CircuitBreakerError, SQLMeshError +from sqlmesh.utils.errors import ( + AuditError, + NodeAuditsErrors, + CircuitBreakerError, + SQLMeshError, + SignalEvalError, +) -logger = logging.getLogger(__name__) -Interval = t.Tuple[datetime, datetime] -Batch = t.List[Interval] -SnapshotToBatches = t.Dict[Snapshot, Batch] -# we store snapshot name instead of snapshots/snapshotids because pydantic -# is extremely slow to hash. 
snapshot names should be unique within a dag run -SchedulingUnit = t.Tuple[str, t.Tuple[Interval, int]] +if t.TYPE_CHECKING: + from sqlmesh.core.context import ExecutionContext +logger = logging.getLogger(__name__) +SnapshotToIntervals = t.Dict[Snapshot, Intervals] -class Signal(abc.ABC): - @abc.abstractmethod - def check_intervals(self, batch: Batch) -> t.Union[bool, Batch]: - """Returns which intervals are ready from a list of scheduled intervals. - When SQLMesh wishes to execute a batch of intervals, say between `a` and `d`, then - the `batch` parameter will contain each individual interval within this batch, - i.e.: `[a,b),[b,c),[c,d)`. +class SchedulingUnit(abc.ABC): + snapshot_name: str - This function may return `True` to indicate that the whole batch is ready, - `False` to indicate none of the batch's intervals are ready, or a list of - intervals (a batch) to indicate exactly which ones are ready. + def __lt__(self, other: SchedulingUnit) -> bool: + return (self.__class__.__name__, self.snapshot_name) < ( + other.__class__.__name__, + other.snapshot_name, + ) - When returning a batch, the function is expected to return a subset of - the `batch` parameter, e.g.: `[a,b),[b,c)`. Note that it may return - gaps, e.g.: `[a,b),[c,d)`, but it may not alter the bounds of any of the - intervals. - The interface allows an implementation to check batches of intervals without - having to actually compute individual intervals itself. +@dataclass(frozen=True) +class EvaluateNode(SchedulingUnit): + snapshot_name: str + interval: Interval + batch_index: int + + def __lt__(self, other: SchedulingUnit) -> bool: + if not isinstance(other, EvaluateNode): + return super().__lt__(other) + return (self.__class__.__name__, self.snapshot_name, self.interval, self.batch_index) < ( + other.__class__.__name__, + other.snapshot_name, + other.interval, + other.batch_index, + ) - Args: - batch: the list of intervals that are missing and scheduled to run. 
- Returns: - Either `True` to indicate all intervals are ready, `False` to indicate none are - ready or a list of intervals to indicate exactly which ones are ready. - """ +@dataclass(frozen=True) +class CreateNode(SchedulingUnit): + snapshot_name: str -SignalFactory = t.Callable[[t.Dict[str, t.Union[str, int, float, bool]]], Signal] +@dataclass(frozen=True) +class DummyNode(SchedulingUnit): + snapshot_name: str class Scheduler: @@ -93,7 +111,6 @@ class Scheduler: state_sync: The state sync to pull saved snapshots. max_workers: The maximum number of parallel queries to run. console: The rich instance used for printing scheduling information. - signal_factory: A factory method for building Signal instances from model signal configuration. """ def __init__( @@ -105,10 +122,10 @@ def __init__( max_workers: int = 1, console: t.Optional[Console] = None, notification_target_manager: t.Optional[NotificationTargetManager] = None, - signal_factory: t.Optional[SignalFactory] = None, ): self.state_sync = state_sync self.snapshots = {s.snapshot_id: s for s in snapshots} + self.snapshots_by_name = {snapshot.name: snapshot for snapshot in self.snapshots.values()} self.snapshot_per_version = _resolve_one_snapshot_per_version(self.snapshots.values()) self.default_catalog = default_catalog self.snapshot_evaluator = snapshot_evaluator @@ -117,60 +134,59 @@ def __init__( self.notification_target_manager = ( notification_target_manager or NotificationTargetManager() ) - self.signal_factory = signal_factory - def batches( + def merged_missing_intervals( self, start: t.Optional[TimeLike] = None, end: t.Optional[TimeLike] = None, execution_time: t.Optional[TimeLike] = None, deployability_index: t.Optional[DeployabilityIndex] = None, - restatements: t.Optional[t.Dict[SnapshotId, SnapshotInterval]] = None, + restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None, + start_override_per_model: t.Optional[t.Dict[str, datetime]] = None, + end_override_per_model: t.Optional[t.Dict[str, 
datetime]] = None, ignore_cron: bool = False, end_bounded: bool = False, selected_snapshots: t.Optional[t.Set[str]] = None, - ) -> SnapshotToBatches: - """Find the optimal date interval paramaters based on what needs processing and maximal batch size. + ) -> SnapshotToIntervals: + """Find the largest contiguous date interval parameters based only on what is missing. For each node name, find all dependencies and look for a stored snapshot from the metastore. If a snapshot is found, calculate the missing intervals that need to be processed given the passed in start and end intervals. - If a snapshot's node specifies a batch size, consecutive intervals are merged into batches of a size that is less than - or equal to the configured one. If no batch size is specified, then it uses the intervals that correspond to the node's cron expression. - For example, if a node is supposed to run daily and has 70 days to backfill with a batch size set to 30, there would be 2 jobs - with 30 days and 1 job with 10. + This is a superset of what may actually get processed at runtime based on things like batch size, signal readiness, etc. Args: start: The start of the run. Defaults to the min node start date. end: The end of the run. Defaults to now. - execution_time: The date/time time reference to use for execution time. Defaults to now. + execution_time: The date/time reference to use for execution time. Defaults to now. deployability_index: Determines snapshots that are deployable in the context of this evaluation. restatements: A set of snapshot names being restated. + start_override_per_model: A mapping of model FQNs to target start dates. + end_override_per_model: A mapping of model FQNs to target end dates. ignore_cron: Whether to ignore the node's cron schedule. end_bounded: If set to true, the returned intervals will be bounded by the target end date, disregarding lookback, allow_partials, and other attributes that could cause the intervals to exceed the target end date. 
selected_snapshots: A set of snapshot names to run. If not provided, all snapshots will be run. """ - restatements = restatements or {} - validate_date_range(start, end) - - snapshots: t.Collection[Snapshot] = self.snapshot_per_version.values() - if selected_snapshots is not None: - snapshots = [s for s in snapshots if s.name in selected_snapshots] - - self.state_sync.refresh_snapshot_intervals(snapshots) - - return compute_interval_params( - snapshots, - start=start or earliest_start_date(snapshots), - end=end or now(), + snapshots_to_intervals = merged_missing_intervals( + snapshots=self.snapshot_per_version.values(), + start=start, + end=end, + execution_time=execution_time, deployability_index=deployability_index, - execution_time=execution_time or now(), restatements=restatements, + start_override_per_model=start_override_per_model, + end_override_per_model=end_override_per_model, ignore_cron=ignore_cron, end_bounded=end_bounded, - signal_factory=self.signal_factory, ) + # Filtering snapshots after computing missing intervals because we need all snapshots in order + # to correctly infer start dates. + if selected_snapshots is not None: + snapshots_to_intervals = { + s: i for s, i in snapshots_to_intervals.items() if s.name in selected_snapshots + } + return snapshots_to_intervals def evaluate( self, @@ -180,8 +196,12 @@ def evaluate( execution_time: TimeLike, deployability_index: DeployabilityIndex, batch_index: int, + environment_naming_info: t.Optional[EnvironmentNamingInfo] = None, + allow_destructive_snapshots: t.Optional[t.Set[str]] = None, + allow_additive_snapshots: t.Optional[t.Set[str]] = None, + target_table_exists: t.Optional[bool] = None, **kwargs: t.Any, - ) -> None: + ) -> t.List[AuditResult]: """Evaluate a snapshot and add the processed interval to the state sync. Args: @@ -189,16 +209,20 @@ def evaluate( start: The start datetime to render. end: The end datetime to render. execution_time: The date/time time reference to use for execution time. 
Defaults to now. + allow_destructive_snapshots: Snapshots for which destructive schema changes are allowed. + allow_additive_snapshots: Snapshots for which additive schema changes are allowed. deployability_index: Determines snapshots that are deployable in the context of this evaluation. batch_index: If the snapshot is part of a batch of related snapshots; which index in the batch is it + auto_restatement_enabled: Whether to enable auto restatements. + target_table_exists: Whether the target table exists. If None, the table will be checked for existence. kwargs: Additional kwargs to pass to the renderer. + + Returns: + Tuple of list of all audit results from the evaluation and list of non-blocking audit errors to warn. """ validate_date_range(start, end) - snapshots = { - self.snapshots[p_sid].name: self.snapshots[p_sid] for p_sid in snapshot.parents - } - snapshots[snapshot.name] = snapshot + snapshots = parent_snapshots_by_name(snapshot, self.snapshots) is_deployable = deployability_index.is_deployable(snapshot) @@ -208,46 +232,561 @@ def evaluate( end=end, execution_time=execution_time, snapshots=snapshots, + allow_destructive_snapshots=allow_destructive_snapshots, + allow_additive_snapshots=allow_additive_snapshots, deployability_index=deployability_index, batch_index=batch_index, + target_table_exists=target_table_exists, **kwargs, ) - try: - self.snapshot_evaluator.audit( - snapshot=snapshot, + audit_results = self._audit_snapshot( + snapshot=snapshot, + environment_naming_info=environment_naming_info, + start=start, + end=end, + execution_time=execution_time, + snapshots=snapshots, + deployability_index=deployability_index, + wap_id=wap_id, + **kwargs, + ) + + self.state_sync.add_interval( + snapshot, start, end, is_dev=not is_deployable, last_altered_ts=now_timestamp() + ) + return audit_results + + def run( + self, + environment: str | EnvironmentNamingInfo, + start: t.Optional[TimeLike] = None, + end: t.Optional[TimeLike] = None, + execution_time: 
t.Optional[TimeLike] = None, + restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None, + start_override_per_model: t.Optional[t.Dict[str, datetime]] = None, + end_override_per_model: t.Optional[t.Dict[str, datetime]] = None, + ignore_cron: bool = False, + end_bounded: bool = False, + selected_snapshots: t.Optional[t.Set[str]] = None, + circuit_breaker: t.Optional[t.Callable[[], bool]] = None, + deployability_index: t.Optional[DeployabilityIndex] = None, + auto_restatement_enabled: bool = False, + run_environment_statements: bool = False, + ) -> CompletionStatus: + return self._run_or_audit( + environment=environment, + start=start, + end=end, + execution_time=execution_time, + remove_intervals=restatements, + start_override_per_model=start_override_per_model, + end_override_per_model=end_override_per_model, + ignore_cron=ignore_cron, + end_bounded=end_bounded, + selected_snapshots=selected_snapshots, + circuit_breaker=circuit_breaker, + deployability_index=deployability_index, + auto_restatement_enabled=auto_restatement_enabled, + run_environment_statements=run_environment_statements, + ) + + def audit( + self, + environment: str | EnvironmentNamingInfo, + start: TimeLike, + end: TimeLike, + execution_time: t.Optional[TimeLike] = None, + start_override_per_model: t.Optional[t.Dict[str, datetime]] = None, + end_override_per_model: t.Optional[t.Dict[str, datetime]] = None, + ignore_cron: bool = False, + end_bounded: bool = False, + selected_snapshots: t.Optional[t.Set[str]] = None, + circuit_breaker: t.Optional[t.Callable[[], bool]] = None, + deployability_index: t.Optional[DeployabilityIndex] = None, + run_environment_statements: bool = False, + ) -> CompletionStatus: + # Remove the intervals from the snapshots that will be audited so that they can be recomputed + # by _run_or_audit as "missing intervals" to reuse the rest of it's logic + remove_intervals = {} + for snapshot in self.snapshots.values(): + removal_intervals = snapshot.get_removal_interval( + 
start, end, execution_time, is_preview=True + ) + remove_intervals[snapshot.snapshot_id] = removal_intervals + + return self._run_or_audit( + environment=environment, + start=start, + end=end, + execution_time=execution_time, + remove_intervals=remove_intervals, + start_override_per_model=start_override_per_model, + end_override_per_model=end_override_per_model, + ignore_cron=ignore_cron, + end_bounded=end_bounded, + selected_snapshots=selected_snapshots, + circuit_breaker=circuit_breaker, + deployability_index=deployability_index, + run_environment_statements=run_environment_statements, + audit_only=True, + ) + + def batch_intervals( + self, + merged_intervals: SnapshotToIntervals, + deployability_index: t.Optional[DeployabilityIndex], + environment_naming_info: EnvironmentNamingInfo, + dag: t.Optional[DAG[SnapshotId]] = None, + is_restatement: bool = False, + ) -> t.Dict[Snapshot, Intervals]: + dag = dag or snapshots_to_dag(merged_intervals) + + snapshot_intervals: t.Dict[SnapshotId, t.Tuple[Snapshot, t.List[Interval]]] = { + snapshot.snapshot_id: ( + snapshot, + [ + i + for interval in intervals + for i in _expand_range_as_interval(*interval, snapshot.node.interval_unit) + ], + ) + for snapshot, intervals in merged_intervals.items() + } + snapshot_batches: t.Dict[Snapshot, Intervals] = {} + all_unready_intervals: t.Dict[str, set[Interval]] = {} + for snapshot_id in dag: + if snapshot_id not in snapshot_intervals: + continue + snapshot, intervals = snapshot_intervals[snapshot_id] + unready = set(intervals) + + from sqlmesh.core.context import ExecutionContext + + adapter = self.snapshot_evaluator.get_adapter(snapshot.model_gateway) + + parent_intervals: Intervals = [] + for parent_id in snapshot.parents: + parent_snapshot, _ = snapshot_intervals.get(parent_id, (None, [])) + if not parent_snapshot or parent_snapshot.is_external: + continue + + parent_intervals.extend(snapshot_batches[parent_snapshot]) + + context = ExecutionContext( + adapter, + 
self.snapshots_by_name, + deployability_index, + default_dialect=adapter.dialect, + default_catalog=self.default_catalog, + is_restatement=is_restatement, + parent_intervals=parent_intervals, + ) + + intervals = self._check_ready_intervals( + snapshot, + intervals, + context, + environment_naming_info, + ) + unready -= set(intervals) + + for parent in snapshot.parents: + if parent.name in all_unready_intervals: + unready.update(all_unready_intervals[parent.name]) + + all_unready_intervals[snapshot.name] = unready + + batches = [] + batch_size = snapshot.node.batch_size + next_batch: t.List[t.Tuple[int, int]] = [] + + for interval in interval_diff( + intervals, merge_intervals(unready), uninterrupted=snapshot.depends_on_past + ): + if (batch_size and len(next_batch) >= batch_size) or ( + next_batch and interval[0] != next_batch[-1][-1] + ): + batches.append((next_batch[0][0], next_batch[-1][-1])) + next_batch = [] + + next_batch.append(interval) + + if next_batch: + batches.append((next_batch[0][0], next_batch[-1][-1])) + + snapshot_batches[snapshot] = batches + + return snapshot_batches + + def run_merged_intervals( + self, + *, + merged_intervals: SnapshotToIntervals, + deployability_index: DeployabilityIndex, + environment_naming_info: EnvironmentNamingInfo, + execution_time: t.Optional[TimeLike] = None, + circuit_breaker: t.Optional[t.Callable[[], bool]] = None, + start: t.Optional[TimeLike] = None, + end: t.Optional[TimeLike] = None, + allow_destructive_snapshots: t.Optional[t.Set[str]] = None, + selected_models: t.Optional[t.Set[str]] = None, + allow_additive_snapshots: t.Optional[t.Set[str]] = None, + selected_snapshot_ids: t.Optional[t.Set[SnapshotId]] = None, + run_environment_statements: bool = False, + audit_only: bool = False, + auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {}, + is_restatement: bool = False, + ) -> t.Tuple[t.List[NodeExecutionFailedError[SchedulingUnit]], t.List[SchedulingUnit]]: + """Runs precomputed batches of 
missing intervals. + + Args: + merged_intervals: The snapshots and contiguous interval ranges to evaluate. + deployability_index: Determines snapshots that are deployable in the context of this evaluation. + environment_naming_info: The environment naming info the user is targeting when applying their change. + execution_time: The date/time reference to use for execution time. + circuit_breaker: An optional handler which checks if the run should be aborted. + start: The start of the run. + end: The end of the run. + allow_destructive_snapshots: Snapshots for which destructive schema changes are allowed. + allow_additive_snapshots: Snapshots for which additive schema changes are allowed. + selected_snapshot_ids: The snapshots to include in the run DAG. If None, all snapshots with missing intervals will be included. + + Returns: + A tuple of errors and skipped intervals. + """ + execution_time = execution_time or now_timestamp() + + selected_snapshots = [self.snapshots[sid] for sid in (selected_snapshot_ids or set())] + if not selected_snapshots: + selected_snapshots = list(merged_intervals) + + # Build the full DAG from all snapshots to preserve transitive dependencies + full_dag = snapshots_to_dag(self.snapshots.values()) + + # Create a subdag that includes the selected snapshots and all their upstream dependencies + # This ensures that transitive dependencies are preserved even when intermediate nodes are not selected + selected_snapshot_ids_set = {s.snapshot_id for s in selected_snapshots} + snapshot_dag = full_dag.subdag(*selected_snapshot_ids_set) + + batched_intervals = self.batch_intervals( + merged_intervals, + deployability_index, + environment_naming_info, + dag=snapshot_dag, + is_restatement=is_restatement, + ) + self.console.start_evaluation_progress( + batched_intervals, + environment_naming_info, + self.default_catalog, + audit_only=audit_only, + ) + + if run_environment_statements: + environment_statements = self.state_sync.get_environment_statements( 
+ environment_naming_info.name + ) + execute_environment_statements( + adapter=self.snapshot_evaluator.adapter, + environment_statements=environment_statements, + runtime_stage=RuntimeStage.BEFORE_ALL, + environment_naming_info=environment_naming_info, + default_catalog=self.default_catalog, + snapshots=self.snapshots_by_name, start=start, end=end, execution_time=execution_time, - snapshots=snapshots, - deployability_index=deployability_index, - wap_id=wap_id, - **kwargs, + selected_models=selected_models, ) - except AuditError as e: - self.notification_target_manager.notify(NotificationEvent.AUDIT_FAILURE, e) - if is_deployable and snapshot.node.owner: - self.notification_target_manager.notify_user( - NotificationEvent.AUDIT_FAILURE, snapshot.node.owner, e + + # We only need to create physical tables if the snapshot is not representative or if it + # needs backfill + snapshots_to_create_candidates = [ + s + for s in selected_snapshots + if not deployability_index.is_representative(s) or s in batched_intervals + ] + snapshots_to_create = { + s.snapshot_id + for s in self.snapshot_evaluator.get_snapshots_to_create( + snapshots_to_create_candidates, deployability_index + ) + } + + dag = self._dag( + batched_intervals, snapshot_dag=snapshot_dag, snapshots_to_create=snapshots_to_create + ) + + def run_node(node: SchedulingUnit) -> None: + if circuit_breaker and circuit_breaker(): + raise CircuitBreakerError() + if isinstance(node, DummyNode): + return + + snapshot = self.snapshots_by_name[node.snapshot_name] + + if isinstance(node, EvaluateNode): + self.console.start_snapshot_evaluation_progress(snapshot) + execution_start_ts = now_timestamp() + evaluation_duration_ms: t.Optional[int] = None + start, end = node.interval + + audit_results: t.List[AuditResult] = [] + try: + assert execution_time # mypy + assert deployability_index # mypy + + if audit_only: + audit_results = self._audit_snapshot( + snapshot=snapshot, + environment_naming_info=environment_naming_info, + 
deployability_index=deployability_index, + snapshots=self.snapshots_by_name, + start=start, + end=end, + execution_time=execution_time, + ) + else: + # If batch_index > 0, then the target table must exist since the first batch would have created it + target_table_exists = ( + snapshot.snapshot_id not in snapshots_to_create or node.batch_index > 0 + ) + audit_results = self.evaluate( + snapshot=snapshot, + environment_naming_info=environment_naming_info, + start=start, + end=end, + execution_time=execution_time, + deployability_index=deployability_index, + batch_index=node.batch_index, + allow_destructive_snapshots=allow_destructive_snapshots, + allow_additive_snapshots=allow_additive_snapshots, + target_table_exists=target_table_exists, + selected_models=selected_models, + ) + + evaluation_duration_ms = now_timestamp() - execution_start_ts + finally: + num_audits = len(audit_results) + num_audits_failed = sum(1 for result in audit_results if result.count) + + execution_stats = self.snapshot_evaluator.execution_tracker.get_execution_stats( + SnapshotIdBatch(snapshot_id=snapshot.snapshot_id, batch_id=node.batch_index) + ) + + self.console.update_snapshot_evaluation_progress( + snapshot, + batched_intervals[snapshot][node.batch_index], + node.batch_index, + evaluation_duration_ms, + num_audits - num_audits_failed, + num_audits_failed, + execution_stats=execution_stats, + auto_restatement_triggers=auto_restatement_triggers.get( + snapshot.snapshot_id + ), + ) + elif isinstance(node, CreateNode): + self.snapshot_evaluator.create_snapshot( + snapshot=snapshot, + snapshots=self.snapshots_by_name, + deployability_index=deployability_index, + allow_destructive_snapshots=allow_destructive_snapshots or set(), + allow_additive_snapshots=allow_additive_snapshots or set(), ) - logger.error(f"Audit Failure: {traceback.format_exc()}") - raise e - self.state_sync.add_interval(snapshot, start, end, is_dev=not is_deployable) + try: + with self.snapshot_evaluator.concurrent_context(): 
+ errors, skipped_intervals = concurrent_apply_to_dag( + dag, + run_node, + self.max_workers, + raise_on_error=False, + ) + self.console.stop_evaluation_progress(success=not errors) - def run( + skipped_snapshots = { + i.snapshot_name for i in skipped_intervals if isinstance(i, EvaluateNode) + } + self.console.log_skipped_models(skipped_snapshots) + for skipped in skipped_snapshots: + logger.info(f"SKIPPED snapshot {skipped}\n") + + for error in errors: + if isinstance(error.__cause__, CircuitBreakerError): + raise error.__cause__ + logger.info(str(error), exc_info=error) + + self.console.log_failed_models(errors) + + return errors, skipped_intervals + finally: + if run_environment_statements: + execute_environment_statements( + adapter=self.snapshot_evaluator.adapter, + environment_statements=environment_statements, + runtime_stage=RuntimeStage.AFTER_ALL, + environment_naming_info=environment_naming_info, + default_catalog=self.default_catalog, + snapshots=self.snapshots_by_name, + start=start, + end=end, + execution_time=execution_time, + selected_models=selected_models, + ) + + self.state_sync.recycle() + + def _dag( + self, + batches: SnapshotToIntervals, + snapshot_dag: t.Optional[DAG[SnapshotId]] = None, + snapshots_to_create: t.Optional[t.Set[SnapshotId]] = None, + ) -> DAG[SchedulingUnit]: + """Builds a DAG of snapshot intervals to be evaluated. + + Args: + batches: The batches of snapshots and intervals to evaluate. + snapshot_dag: The DAG of all snapshots. + snapshots_to_create: The snapshots with missing physical tables. + + Returns: + A DAG of snapshot intervals to be evaluated. 
+ """ + + intervals_per_snapshot = { + snapshot.name: intervals for snapshot, intervals in batches.items() + } + snapshots_to_create = snapshots_to_create or set() + original_snapshots_to_create = snapshots_to_create.copy() + upstream_dependencies_cache: t.Dict[SnapshotId, t.Set[SchedulingUnit]] = {} + + snapshot_dag = snapshot_dag or snapshots_to_dag(batches) + dag = DAG[SchedulingUnit]() + + for snapshot_id in snapshot_dag: + if snapshot_id.name not in self.snapshots_by_name: + continue + + snapshot = self.snapshots_by_name[snapshot_id.name] + intervals = intervals_per_snapshot.get(snapshot.name, []) + + upstream_dependencies: t.Set[SchedulingUnit] = set() + + for p_sid in snapshot.parents: + upstream_dependencies.update( + self._find_upstream_dependencies( + p_sid, + intervals_per_snapshot, + original_snapshots_to_create, + upstream_dependencies_cache, + ) + ) + + batch_concurrency = snapshot.node.batch_concurrency + batch_size = snapshot.node.batch_size + if snapshot.depends_on_past: + batch_concurrency = 1 + + create_node: t.Optional[CreateNode] = None + if snapshot.snapshot_id in original_snapshots_to_create and ( + snapshot.is_incremental_by_time_range + or ((not batch_concurrency or batch_concurrency > 1) and batch_size) + or not intervals + ): + # Add a separate node for table creation in case when there multiple concurrent + # evaluation nodes or when there are no intervals to evaluate. 
+ create_node = CreateNode(snapshot_name=snapshot.name) + dag.add(create_node, upstream_dependencies) + snapshots_to_create.remove(snapshot.snapshot_id) + + for i, interval in enumerate(intervals): + node = EvaluateNode(snapshot_name=snapshot.name, interval=interval, batch_index=i) + + if create_node: + dag.add(node, [create_node]) + else: + dag.add(node, upstream_dependencies) + + if len(intervals) > 1: + dag.add(DummyNode(snapshot_name=snapshot.name), [node]) + + if batch_concurrency and i >= batch_concurrency: + batch_idx_to_wait_for = i - batch_concurrency + dag.add( + node, + [ + EvaluateNode( + snapshot_name=snapshot.name, + interval=intervals[batch_idx_to_wait_for], + batch_index=batch_idx_to_wait_for, + ), + ], + ) + return dag + + def _find_upstream_dependencies( + self, + parent_sid: SnapshotId, + intervals_per_snapshot: t.Dict[str, Intervals], + snapshots_to_create: t.Set[SnapshotId], + cache: t.Dict[SnapshotId, t.Set[SchedulingUnit]], + ) -> t.Set[SchedulingUnit]: + if parent_sid not in self.snapshots: + return set() + if parent_sid in cache: + return cache[parent_sid] + + p_intervals = intervals_per_snapshot.get(parent_sid.name, []) + + parent_node: t.Optional[SchedulingUnit] = None + if p_intervals: + if len(p_intervals) > 1: + parent_node = DummyNode(snapshot_name=parent_sid.name) + else: + interval = p_intervals[0] + parent_node = EvaluateNode( + snapshot_name=parent_sid.name, interval=interval, batch_index=0 + ) + elif parent_sid in snapshots_to_create: + parent_node = CreateNode(snapshot_name=parent_sid.name) + + if parent_node is not None: + cache[parent_sid] = {parent_node} + return {parent_node} + + # This snapshot has no intervals and doesn't need creation which means + # that it can be a transitive dependency + transitive_deps: t.Set[SchedulingUnit] = set() + parent_snapshot = self.snapshots[parent_sid] + for grandparent_sid in parent_snapshot.parents: + transitive_deps.update( + self._find_upstream_dependencies( + grandparent_sid, 
intervals_per_snapshot, snapshots_to_create, cache + ) + ) + cache[parent_sid] = transitive_deps + return transitive_deps + + def _run_or_audit( self, environment: str | EnvironmentNamingInfo, start: t.Optional[TimeLike] = None, end: t.Optional[TimeLike] = None, execution_time: t.Optional[TimeLike] = None, - restatements: t.Optional[t.Dict[SnapshotId, SnapshotInterval]] = None, + remove_intervals: t.Optional[t.Dict[SnapshotId, Interval]] = None, + start_override_per_model: t.Optional[t.Dict[str, datetime]] = None, + end_override_per_model: t.Optional[t.Dict[str, datetime]] = None, ignore_cron: bool = False, end_bounded: bool = False, selected_snapshots: t.Optional[t.Set[str]] = None, circuit_breaker: t.Optional[t.Callable[[], bool]] = None, deployability_index: t.Optional[DeployabilityIndex] = None, - ) -> bool: - """Concurrently runs all snapshots in topological order. + auto_restatement_enabled: bool = False, + run_environment_statements: bool = False, + audit_only: bool = False, + ) -> CompletionStatus: + """Concurrently runs or audits all snapshots in topological order. Args: environment: The environment naming info the user is targeting when applying their change. @@ -256,18 +795,21 @@ def run( start: The start of the run. Defaults to the min node start date. end: The end of the run. Defaults to now. execution_time: The date/time time reference to use for execution time. Defaults to now. - restatements: A dict of snapshots to restate and their intervals. + remove_intervals: A dict of snapshots to their intervals. For evaluation, these are the intervals that will be restated. For audits, + these are the intervals that will be reaudited + start_override_per_model: A mapping of model FQNs to target start dates. + end_override_per_model: A mapping of model FQNs to target end dates. ignore_cron: Whether to ignore the node's cron schedule. 
end_bounded: If set to true, the evaluated intervals will be bounded by the target end date, disregarding lookback, allow_partials, and other attributes that could cause the intervals to exceed the target end date. selected_snapshots: A set of snapshot names to run. If not provided, all snapshots will be run. circuit_breaker: An optional handler which checks if the run should be aborted. deployability_index: Determines snapshots that are deployable in the context of this render. + auto_restatement_enabled: Whether to enable auto restatements. Returns: True if the execution was successful and False otherwise. """ - restatements = restatements or {} validate_date_range(start, end) if isinstance(environment, str): env = self.state_sync.get_environment(environment) @@ -281,145 +823,249 @@ def run( environment_naming_info = environment deployability_index = deployability_index or ( - DeployabilityIndex.create(self.snapshots.values()) + DeployabilityIndex.create(self.snapshots.values(), start=start) if environment_naming_info.name != c.PROD else DeployabilityIndex.all_deployable() ) - execution_time = execution_time or now() - batches = self.batches( + execution_time = execution_time or now_timestamp() + + self.state_sync.refresh_snapshot_intervals(self.snapshots.values()) + for s_id, interval in (remove_intervals or {}).items(): + self.snapshots[s_id].remove_interval(interval) + + all_auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} + if auto_restatement_enabled: + auto_restated_intervals, all_auto_restatement_triggers = apply_auto_restatements( + self.snapshots, execution_time + ) + self.state_sync.add_snapshots_intervals(auto_restated_intervals) + self.state_sync.update_auto_restatements( + {s.name_version: s.next_auto_restatement_ts for s in self.snapshots.values()} + ) + + merged_intervals = self.merged_missing_intervals( start, end, execution_time, deployability_index=deployability_index, - restatements=restatements, + 
restatements=remove_intervals, + start_override_per_model=start_override_per_model, + end_override_per_model=end_override_per_model, ignore_cron=ignore_cron, end_bounded=end_bounded, selected_snapshots=selected_snapshots, ) - if not batches: - return True - - dag = self._dag(batches) - - self.console.start_evaluation_progress( - {snapshot: len(intervals) for snapshot, intervals in batches.items()}, - environment_naming_info, - self.default_catalog, + if not merged_intervals: + return CompletionStatus.NOTHING_TO_DO + + auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} + if all_auto_restatement_triggers: + merged_intervals_snapshots = {snapshot.snapshot_id for snapshot in merged_intervals} + auto_restatement_triggers = { + s_id: all_auto_restatement_triggers.get(s_id, []) + for s_id in merged_intervals_snapshots + } + + errors, _ = self.run_merged_intervals( + merged_intervals=merged_intervals, + deployability_index=deployability_index, + environment_naming_info=environment_naming_info, + execution_time=execution_time, + circuit_breaker=circuit_breaker, + start=start, + end=end, + run_environment_statements=run_environment_statements, + audit_only=audit_only, + auto_restatement_triggers=auto_restatement_triggers, + selected_models={ + s.node.dbt_unique_id for s in merged_intervals if s.node.dbt_unique_id + }, ) - snapshots_by_name = {snapshot.name: snapshot for snapshot in self.snapshots.values()} - - def evaluate_node(node: SchedulingUnit) -> None: - if circuit_breaker and circuit_breaker(): - raise CircuitBreakerError() + return CompletionStatus.FAILURE if errors else CompletionStatus.SUCCESS - snapshot_name, ((start, end), batch_idx) = node - if batch_idx == -1: - return - snapshot = snapshots_by_name[snapshot_name] - - self.console.start_snapshot_evaluation_progress(snapshot) + def _audit_snapshot( + self, + snapshot: Snapshot, + deployability_index: DeployabilityIndex, + snapshots: t.Dict[str, Snapshot], + start: t.Optional[TimeLike] = None, 
+ end: t.Optional[TimeLike] = None, + execution_time: t.Optional[TimeLike] = None, + wap_id: t.Optional[str] = None, + environment_naming_info: t.Optional[EnvironmentNamingInfo] = None, + **kwargs: t.Any, + ) -> t.List[AuditResult]: + is_deployable = deployability_index.is_deployable(snapshot) - execution_start_ts = now_timestamp() - evaluation_duration_ms: t.Optional[int] = None + audit_results = self.snapshot_evaluator.audit( + snapshot=snapshot, + start=start, + end=end, + execution_time=execution_time, + snapshots=snapshots, + deployability_index=deployability_index, + wap_id=wap_id, + **kwargs, + ) - try: - assert execution_time # mypy - assert deployability_index # mypy - self.evaluate(snapshot, start, end, execution_time, deployability_index, batch_idx) - evaluation_duration_ms = now_timestamp() - execution_start_ts - finally: - self.console.update_snapshot_evaluation_progress( - snapshot, batch_idx, evaluation_duration_ms + audit_errors_to_raise: t.List[AuditError] = [] + audit_errors_to_warn: t.List[AuditError] = [] + for audit_result in (result for result in audit_results if result.count): + error = AuditError( + audit_name=audit_result.audit.name, + audit_args=audit_result.audit_args, + model=snapshot.model_or_none, + count=t.cast(int, audit_result.count), + query=t.cast(exp.Query, audit_result.query), + adapter_dialect=self.snapshot_evaluator.adapter.dialect, + ) + self.notification_target_manager.notify(NotificationEvent.AUDIT_FAILURE, error) + if is_deployable and snapshot.node.owner: + self.notification_target_manager.notify_user( + NotificationEvent.AUDIT_FAILURE, snapshot.node.owner, error ) - - try: - with self.snapshot_evaluator.concurrent_context(): - errors, skipped_intervals = concurrent_apply_to_dag( - dag, - evaluate_node, - self.max_workers, - raise_on_error=False, + if audit_result.blocking: + audit_errors_to_raise.append(error) + else: + audit_errors_to_warn.append(error) + + if audit_errors_to_raise: + raise 
NodeAuditsErrors(audit_errors_to_raise) + + if environment_naming_info: + for audit_error in audit_errors_to_warn: + display_name = snapshot.display_name( + environment_naming_info, + self.default_catalog, + self.snapshot_evaluator.adapter.dialect, + ) + self.console.log_warning( + f"\n{display_name}: {audit_error}.", + f"{audit_error}. Audit query:\n{audit_error.query.sql(audit_error.adapter_dialect)}", ) - finally: - self.state_sync.recycle() - - self.console.stop_evaluation_progress(success=not errors) - - skipped_snapshots = {i[0] for i in skipped_intervals} - for skipped in skipped_snapshots: - log_message = f"SKIPPED snapshot {skipped}\n" - self.console.log_status_update(log_message) - logger.info(log_message) - for error in errors: - if isinstance(error.__cause__, CircuitBreakerError): - raise error.__cause__ - sid = error.node[0] - formatted_exception = "".join(format_exception(error.__cause__ or error)) - log_message = f"FAILED processing snapshot {sid}\n{formatted_exception}" - self.console.log_error(log_message) - # Log with INFO level to prevent duplicate messages in the console. - logger.info(log_message) + return audit_results - return not errors + def _check_ready_intervals( + self, + snapshot: Snapshot, + intervals: Intervals, + context: ExecutionContext, + environment_naming_info: EnvironmentNamingInfo, + ) -> Intervals: + """Checks if the intervals are ready for evaluation for the given snapshot. - def _dag(self, batches: SnapshotToBatches) -> DAG[SchedulingUnit]: - """Builds a DAG of snapshot intervals to be evaluated. + This implementation also includes the signal progress tracking. + Note that this will handle gaps in the provided intervals. The returned intervals + may introduce new gaps. Args: - batches: The batches of snapshots and intervals to evaluate. + snapshot: The snapshot to check. + intervals: The intervals to check. + context: The context to use. + environment_naming_info: The environment naming info to use. 
Returns: - A DAG of snapshot intervals to be evaluated. + The intervals that are ready for evaluation. """ + signals = snapshot.is_model and snapshot.model.render_signal_calls() - intervals_per_snapshot = { - snapshot.name: intervals for snapshot, intervals in batches.items() - } + if not (signals and signals.signals_to_kwargs): + return intervals - dag = DAG[SchedulingUnit]() - terminal_node = ((to_datetime(0), to_datetime(0)), -1) + self.console.start_signal_progress( + snapshot, + self.default_catalog, + environment_naming_info or EnvironmentNamingInfo(), + ) - for snapshot, intervals in batches.items(): - if not intervals: - continue + for signal_idx, (signal_name, kwargs) in enumerate(signals.signals_to_kwargs.items()): + # Capture intervals before signal check for display + intervals_to_check = merge_intervals(intervals) - upstream_dependencies = [] + signal_start_ts = time.perf_counter() - for p_sid in snapshot.parents: - if p_sid in self.snapshots: - p_intervals = intervals_per_snapshot.get(p_sid.name, []) + try: + intervals = check_ready_intervals( + signals.prepared_python_env[signal_name], + intervals, + context, + python_env=signals.python_env, + dialect=snapshot.model.dialect, + path=snapshot.model._path, + snapshot=snapshot, + kwargs=kwargs, + ) + except SQLMeshError as e: + raise SignalEvalError( + f"{e} '{signal_name}' for '{snapshot.model.name}' at {snapshot.model._path}" + ) - if len(p_intervals) > 1: - upstream_dependencies.append((p_sid.name, terminal_node)) - else: - for i, interval in enumerate(p_intervals): - upstream_dependencies.append((p_sid.name, (interval, i))) + duration = time.perf_counter() - signal_start_ts - batch_concurrency = snapshot.node.batch_concurrency - if snapshot.depends_on_past: - batch_concurrency = 1 + self.console.update_signal_progress( + snapshot=snapshot, + signal_name=signal_name, + signal_idx=signal_idx, + total_signals=len(signals.signals_to_kwargs), + ready_intervals=merge_intervals(intervals), + 
check_intervals=intervals_to_check, + duration=duration, + ) - for i, interval in enumerate(intervals): - node = (snapshot.name, (interval, i)) - dag.add(node, upstream_dependencies) + self.console.stop_signal_progress() - if len(intervals) > 1: - dag.add((snapshot.name, terminal_node), [node]) + return intervals - if batch_concurrency and i >= batch_concurrency: - batch_idx_to_wait_for = i - batch_concurrency - dag.add( - node, - [ - ( - snapshot.name, - (intervals[batch_idx_to_wait_for], batch_idx_to_wait_for), - ) - ], - ) - return dag + +def merged_missing_intervals( + snapshots: t.Collection[Snapshot], + start: t.Optional[TimeLike] = None, + end: t.Optional[TimeLike] = None, + execution_time: t.Optional[TimeLike] = None, + deployability_index: t.Optional[DeployabilityIndex] = None, + restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None, + start_override_per_model: t.Optional[t.Dict[str, datetime]] = None, + end_override_per_model: t.Optional[t.Dict[str, datetime]] = None, + ignore_cron: bool = False, + end_bounded: bool = False, +) -> SnapshotToIntervals: + """Find the largest contiguous date interval parameters based only on what is missing. + + For each node name, find all dependencies and look for a stored snapshot from the metastore. If a snapshot is found, + calculate the missing intervals that need to be processed given the passed in start and end intervals. + + This is a superset of what may actually get processed at runtime based on things like batch size, signal readiness, etc. + + Args: + snapshots: A set of target snapshots for which intervals should be computed. + start: The start of the run. Defaults to the min node start date. + end: The end of the run. Defaults to now. + execution_time: The date/time reference to use for execution time. Defaults to now. + deployability_index: Determines snapshots that are deployable in the context of this evaluation. + restatements: A set of snapshot names being restated. 
+ start_override_per_model: A mapping of model FQNs to target start dates. + end_override_per_model: A mapping of model FQNs to target end dates. + ignore_cron: Whether to ignore the node's cron schedule. + end_bounded: If set to true, the returned intervals will be bounded by the target end date, disregarding lookback, + allow_partials, and other attributes that could cause the intervals to exceed the target end date. + """ + restatements = restatements or {} + validate_date_range(start, end) + + return compute_interval_params( + snapshots, + start=start or earliest_start_date(snapshots), + end=end or now_timestamp(), + deployability_index=deployability_index, + execution_time=execution_time or now_timestamp(), + restatements=restatements, + start_override_per_model=start_override_per_model, + end_override_per_model=end_override_per_model, + ignore_cron=ignore_cron, + end_bounded=end_bounded, + ) def compute_interval_params( @@ -429,29 +1075,28 @@ def compute_interval_params( end: TimeLike, deployability_index: t.Optional[DeployabilityIndex] = None, execution_time: t.Optional[TimeLike] = None, - restatements: t.Optional[t.Dict[SnapshotId, SnapshotInterval]] = None, + restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None, + start_override_per_model: t.Optional[t.Dict[str, datetime]] = None, + end_override_per_model: t.Optional[t.Dict[str, datetime]] = None, ignore_cron: bool = False, end_bounded: bool = False, - signal_factory: t.Optional[SignalFactory] = None, -) -> SnapshotToBatches: - """Find the optimal date interval paramaters based on what needs processing and maximal batch size. +) -> SnapshotToIntervals: + """Find the largest contiguous date interval parameters based only on what is missing. For each node name, find all dependencies and look for a stored snapshot from the metastore. If a snapshot is found, calculate the missing intervals that need to be processed given the passed in start and end intervals. 
- If a snapshot's node specifies a batch size, consecutive intervals are merged into batches of a size that is less than - or equal to the configured one. If no batch size is specified, then it uses the intervals that correspond to the node's cron expression. - For example, if a node is supposed to run daily and has 70 days to backfill with a batch size set to 30, there would be 2 jobs - with 30 days and 1 job with 10. + This is a superset of what may actually get processed at runtime based on things like batch size, signal readiness, etc. Args: snapshots: A set of target snapshots for which intervals should be computed. - intervals: A list of all snapshot intervals that should be considered. start: Start of the interval. end: End of the interval. deployability_index: Determines snapshots that are deployable in the context of this evaluation. - execution_time: The date/time time reference to use for execution time. + execution_time: The date/time reference to use for execution time. restatements: A dict of snapshot names being restated and their intervals. + start_override_per_model: A mapping of model FQNs to target start dates. + end_override_per_model: A mapping of model FQNs to target end dates. ignore_cron: Whether to ignore the node's cron schedule. end_bounded: If set to true, the returned intervals will be bounded by the target end date, disregarding lookback, allow_partials, and other attributes that could cause the intervals to exceed the target end date. @@ -459,7 +1104,7 @@ def compute_interval_params( Returns: A dict containing all snapshots needing to be run with their associated interval params. 
""" - snapshot_batches = {} + snapshot_merged_intervals = {} for snapshot, intervals in missing_intervals( snapshots, @@ -468,34 +1113,62 @@ def compute_interval_params( execution_time=execution_time, restatements=restatements, deployability_index=deployability_index, + start_override_per_model=start_override_per_model, + end_override_per_model=end_override_per_model, ignore_cron=ignore_cron, end_bounded=end_bounded, ).items(): - if signal_factory: - for signal in snapshot.model.render_signals( - start=start, end=end, execution_time=execution_time - ): - intervals = _check_ready_intervals( - signal=signal_factory(signal), - intervals=intervals, - ) - - batches = [] - batch_size = snapshot.node.batch_size - next_batch: t.List[t.Tuple[int, int]] = [] + contiguous_batch = [] + next_batch: Intervals = [] for interval in intervals: - if (batch_size and len(next_batch) >= batch_size) or ( - next_batch and interval[0] != next_batch[-1][-1] - ): - batches.append((next_batch[0][0], next_batch[-1][-1])) + if next_batch and interval[0] != next_batch[-1][-1]: + contiguous_batch.append((next_batch[0][0], next_batch[-1][-1])) next_batch = [] next_batch.append(interval) if next_batch: - batches.append((next_batch[0][0], next_batch[-1][-1])) - snapshot_batches[snapshot] = [(to_datetime(s), to_datetime(e)) for s, e in batches] + contiguous_batch.append((next_batch[0][0], next_batch[-1][-1])) + snapshot_merged_intervals[snapshot] = contiguous_batch + + return snapshot_merged_intervals + + +def interval_diff( + intervals_a: Intervals, intervals_b: Intervals, uninterrupted: bool = False +) -> Intervals: + if not intervals_a or not intervals_b: + return intervals_a + + index_a, index_b = 0, 0 + len_a = len(intervals_a) + len_b = len(intervals_b) + + results = [] + + while index_a < len_a and index_b < len_b: + interval_a = intervals_a[index_a] + interval_b = intervals_b[index_b] + + if interval_a[0] >= interval_b[1]: + index_b += 1 + elif interval_b[0] >= interval_a[1]: + 
results.append(interval_a) + index_a += 1 + else: + if uninterrupted: + return results + + if interval_a[0] >= interval_b[0]: + index_a += 1 + else: + index_b += 1 + + if index_a < len_a: + interval_a = intervals_a[index_a] + if interval_a[0] >= interval_b[1] or interval_b[0] >= interval_a[1]: + results.extend(intervals_a[index_a:]) - return snapshot_batches + return results def _resolve_one_snapshot_per_version( @@ -516,51 +1189,8 @@ def _resolve_one_snapshot_per_version( return snapshot_per_version -def _contiguous_intervals( - intervals: t.List[SnapshotInterval], -) -> t.List[t.List[SnapshotInterval]]: - """Given a list of intervals with gaps, returns a list of sequences of contiguous intervals.""" - contiguous_intervals = [] - current_batch: t.List[SnapshotInterval] = [] - for interval in intervals: - if len(current_batch) == 0 or interval[0] == current_batch[-1][-1]: - current_batch.append(interval) - else: - contiguous_intervals.append(current_batch) - current_batch = [interval] - - if len(current_batch) > 0: - contiguous_intervals.append(current_batch) - - return contiguous_intervals - - -def _check_ready_intervals( - signal: Signal, - intervals: t.List[SnapshotInterval], -) -> t.List[SnapshotInterval]: - """Returns a list of intervals that are considered ready by the provided signal. - - Note that this will handle gaps in the provided intervals. The returned intervals - may introduce new gaps. 
- """ - checked_intervals = [] - for interval_batch in _contiguous_intervals(intervals): - batch = [(to_datetime(start), to_datetime(end)) for start, end in interval_batch] - - ready_intervals = signal.check_intervals(batch=batch) - if isinstance(ready_intervals, bool): - if not ready_intervals: - batch = [] - elif isinstance(ready_intervals, list): - for i in ready_intervals: - if i not in batch: - raise RuntimeError(f"Signal returned unknown interval {i}") - batch = ready_intervals - else: - raise ValueError( - f"unexpected return value from signal, expected bool | list, got {type(ready_intervals)}" - ) - - checked_intervals.extend([(to_timestamp(start), to_timestamp(end)) for start, end in batch]) - return checked_intervals +def _expand_range_as_interval( + start_ts: int, end_ts: int, interval_unit: IntervalUnit +) -> t.List[Interval]: + values = expand_range(start_ts, end_ts, interval_unit) + return [(values[i], values[i + 1]) for i in range(len(values) - 1)] diff --git a/sqlmesh/core/schema_diff.py b/sqlmesh/core/schema_diff.py index a049c25053..e1f9d72a6c 100644 --- a/sqlmesh/core/schema_diff.py +++ b/sqlmesh/core/schema_diff.py @@ -1,14 +1,19 @@ from __future__ import annotations +import abc import logging import typing as t +from dataclasses import dataclass from collections import defaultdict -from enum import Enum, auto +from enum import Enum + +from pydantic import Field from sqlglot import exp from sqlglot.helper import ensure_list, seq_get from sqlmesh.utils import columns_to_types_to_struct from sqlmesh.utils.pydantic import PydanticModel +from sqlmesh.utils.errors import SQLMeshError if t.TYPE_CHECKING: from sqlmesh.core._typing import TableName @@ -16,25 +21,141 @@ logger = logging.getLogger(__name__) -class TableAlterOperationType(Enum): - ADD = auto() - DROP = auto() - ALTER_TYPE = auto() +@dataclass(frozen=True) +class TableAlterOperation(abc.ABC): + target_table: exp.Table + + @property + @abc.abstractmethod + def is_destructive(self) -> bool: + 
pass + + @property + @abc.abstractmethod + def is_additive(self) -> bool: + pass + + @property + @abc.abstractmethod + def _alter_actions(self) -> t.List[exp.Expression]: + pass + + @property + def expression(self) -> exp.Alter: + return exp.Alter( + this=self.target_table, + kind="TABLE", + actions=self._alter_actions, + ) + + +@dataclass(frozen=True) +class TableAlterColumnOperation(TableAlterOperation, abc.ABC): + column_parts: t.List[TableAlterColumn] + expected_table_struct: exp.DataType + array_element_selector: str + + @property + def column_identifiers(self) -> t.List[exp.Identifier]: + results = [] + for column in self.column_parts: + results.append(column.identifier) + if ( + column.is_array_of_struct + and len(self.column_parts) > 1 + and self.array_element_selector + ): + results.append(exp.to_identifier(self.array_element_selector)) + return results + + @property + def column(self) -> t.Union[exp.Dot, exp.Identifier]: + columns = self.column_identifiers + if len(columns) == 1: + return columns[0] + return exp.Dot.build(columns) + + +@dataclass(frozen=True) +class TableAlterTypedColumnOperation(TableAlterColumnOperation, abc.ABC): + column_type: exp.DataType + + @property + def column_def(self) -> exp.ColumnDef: + if not self.column_type: + raise SQLMeshError("Tried to access column type when it shouldn't be needed") + return exp.ColumnDef( + this=self.column, + kind=self.column_type, + ) + + +@dataclass(frozen=True) +class TableAlterAddColumnOperation(TableAlterTypedColumnOperation): + position: t.Optional[TableAlterColumnPosition] = None + is_part_of_destructive_change: bool = False + + @property + def is_additive(self) -> bool: + return not self.is_part_of_destructive_change + + @property + def is_destructive(self) -> bool: + return self.is_part_of_destructive_change + + @property + def _alter_actions(self) -> t.List[exp.Expression]: + column_def = exp.ColumnDef( + this=self.column, + kind=self.column_type, + ) + if self.position: + 
column_def.set("position", self.position.column_position_node) + return [column_def] + + +@dataclass(frozen=True) +class TableAlterDropColumnOperation(TableAlterColumnOperation): + cascade: bool = False + + @property + def is_additive(self) -> bool: + return False + + @property + def is_destructive(self) -> bool: + return True + + @property + def _alter_actions(self) -> t.List[exp.Expression]: + return [exp.Drop(this=self.column, kind="COLUMN", cascade=self.cascade)] + + +@dataclass(frozen=True) +class TableAlterChangeColumnTypeOperation(TableAlterTypedColumnOperation): + current_type: exp.DataType + is_part_of_destructive_change: bool = False @property - def is_add(self) -> bool: - return self == TableAlterOperationType.ADD + def is_additive(self) -> bool: + return not self.is_part_of_destructive_change @property - def is_drop(self) -> bool: - return self == TableAlterOperationType.DROP + def is_destructive(self) -> bool: + return self.is_part_of_destructive_change @property - def is_alter_type(self) -> bool: - return self == TableAlterOperationType.ALTER_TYPE + def _alter_actions(self) -> t.List[exp.Expression]: + return [ + exp.AlterColumn( + this=self.column, + dtype=self.column_type, + ) + ] -class TableAlterColumn(PydanticModel): +@dataclass(frozen=True) +class TableAlterColumn: name: str is_struct: bool is_array_of_struct: bool @@ -89,15 +210,13 @@ def from_struct_kwarg(cls, struct: exp.ColumnDef) -> TableAlterColumn: if kwarg_type.is_type(exp.DataType.Type.STRUCT): return cls.struct(name, quoted=quoted) - elif kwarg_type.is_type(exp.DataType.Type.ARRAY): + if kwarg_type.is_type(exp.DataType.Type.ARRAY): if kwarg_type.expressions and kwarg_type.expressions[0].is_type( exp.DataType.Type.STRUCT ): return cls.array_of_struct(name, quoted=quoted) - else: - return cls.array_of_primitive(name, quoted=quoted) - else: - return cls.primitive(name, quoted=quoted) + return cls.array_of_primitive(name, quoted=quoted) + return cls.primitive(name, quoted=quoted) @property 
def is_array(self) -> bool: @@ -116,7 +235,8 @@ def identifier(self) -> exp.Identifier: return exp.to_identifier(self.name, quoted=self.quoted) -class TableAlterColumnPosition(PydanticModel): +@dataclass(frozen=True) +class TableAlterColumnPosition: is_first: bool is_last: bool after: t.Optional[exp.Identifier] = None @@ -161,122 +281,31 @@ def column_position_node(self) -> t.Optional[exp.ColumnPosition]: return exp.ColumnPosition(this=column, position=position) -class TableAlterOperation(PydanticModel): - op: TableAlterOperationType - columns: t.List[TableAlterColumn] - column_type: exp.DataType - expected_table_struct: exp.DataType - add_position: t.Optional[TableAlterColumnPosition] = None - current_type: t.Optional[exp.DataType] = None - - @classmethod - def add( - cls, - columns: t.Union[TableAlterColumn, t.List[TableAlterColumn]], - column_type: t.Union[str, exp.DataType], - expected_table_struct: t.Union[str, exp.DataType], - position: t.Optional[TableAlterColumnPosition] = None, - ) -> TableAlterOperation: - return cls( - op=TableAlterOperationType.ADD, - columns=ensure_list(columns), - column_type=exp.DataType.build(column_type), - add_position=position, - expected_table_struct=exp.DataType.build(expected_table_struct), - ) - - @classmethod - def drop( - cls, - columns: t.Union[TableAlterColumn, t.List[TableAlterColumn]], - expected_table_struct: t.Union[str, exp.DataType], - column_type: t.Optional[t.Union[str, exp.DataType]] = None, - ) -> TableAlterOperation: - column_type = exp.DataType.build(column_type) if column_type else exp.DataType.build("INT") - return cls( - op=TableAlterOperationType.DROP, - columns=ensure_list(columns), - column_type=column_type, - expected_table_struct=exp.DataType.build(expected_table_struct), - ) - - @classmethod - def alter_type( - cls, - columns: t.Union[TableAlterColumn, t.List[TableAlterColumn]], - column_type: t.Union[str, exp.DataType], - current_type: t.Union[str, exp.DataType], - expected_table_struct: t.Union[str, 
exp.DataType], - position: t.Optional[TableAlterColumnPosition] = None, - ) -> TableAlterOperation: - return cls( - op=TableAlterOperationType.ALTER_TYPE, - columns=ensure_list(columns), - column_type=exp.DataType.build(column_type), - add_position=position, - current_type=exp.DataType.build(current_type), - expected_table_struct=exp.DataType.build(expected_table_struct), - ) +class NestedSupport(str, Enum): + # Supports all nested data type operations + ALL = "ALL" + # Does not support any nested data type operations + NONE = "NONE" + # Supports nested data type operations except for those that require dropping a nested field + ALL_BUT_DROP = "ALL_BUT_DROP" + # Ignores all nested data type operations + IGNORE = "IGNORE" @property - def is_add(self) -> bool: - return self.op.is_add + def is_all(self) -> bool: + return self == NestedSupport.ALL @property - def is_drop(self) -> bool: - return self.op.is_drop + def is_none(self) -> bool: + return self == NestedSupport.NONE @property - def is_alter_type(self) -> bool: - return self.op.is_alter_type - - def column_identifiers(self, array_element_selector: str) -> t.List[exp.Identifier]: - results = [] - for column in self.columns: - results.append(column.identifier) - if column.is_array_of_struct and len(self.columns) > 1 and array_element_selector: - results.append(exp.to_identifier(array_element_selector)) - return results - - def column(self, array_element_selector: str) -> t.Union[exp.Dot, exp.Identifier]: - columns = self.column_identifiers(array_element_selector) - if len(columns) == 1: - return columns[0] - return exp.Dot.build(columns) - - def column_def(self, array_element_selector: str) -> exp.ColumnDef: - return exp.ColumnDef( - this=self.column(array_element_selector), - kind=self.column_type, - ) + def is_all_but_drop(self) -> bool: + return self == NestedSupport.ALL_BUT_DROP - def expression( - self, table_name: t.Union[str, exp.Table], array_element_selector: str - ) -> exp.AlterTable: - if 
self.is_alter_type: - return exp.AlterTable( - this=exp.to_table(table_name), - actions=[ - exp.AlterColumn( - this=self.column(array_element_selector), - dtype=self.column_type, - ) - ], - ) - elif self.is_add: - alter_table = exp.AlterTable(this=exp.to_table(table_name)) - column = self.column_def(array_element_selector) - alter_table.set("actions", [column]) - if self.add_position: - column.set("position", self.add_position.column_position_node) - return alter_table - elif self.is_drop: - alter_table = exp.AlterTable(this=exp.to_table(table_name)) - drop_column = exp.Drop(this=self.column(array_element_selector), kind="COLUMN") - alter_table.set("actions", [drop_column]) - return alter_table - else: - raise ValueError(f"Unknown operation {self.op}") + @property + def is_ignore(self) -> bool: + return self == NestedSupport.IGNORE class SchemaDiffer(PydanticModel): @@ -298,12 +327,17 @@ class SchemaDiffer(PydanticModel): Args: support_positional_add: Whether the engine for which the diff is being computed supports adding columns in a specific position in the set of existing columns. - support_nested_operations: Whether the engine for which the diff is being computed supports modifications to - nested data types like STRUCTs and ARRAYs. + nested_support: How the engine for which the diff is being computed supports nested types. compatible_types: Types that are compatible and automatically coerced in actions like UNION ALL. Dict key is data type, and value is the set of types that are compatible with it. + coerceable_types: The mapping from a current type to all types that can be safely coerced to the current one without + altering the column type. NOTE: usually callers should not specify this attribute manually and set the + `support_coercing_compatible_types` flag instead. Some engines are inconsistent about their type coercion rules. 
+ For example, in BigQuery a BIGNUMERIC column can't be altered to be FLOAT64, while BIGNUMERIC values can be inserted + into a FLOAT64 column just fine. support_coercing_compatible_types: Whether or not the engine for which the diff is being computed supports direct coercion of compatible types. + drop_cascade: Whether to add CASCADE modifier when dropping a column. parameterized_type_defaults: Default values for parameterized data types. Dict key is a sqlglot exp.DataType.Type, but in the engine adapter specification we build it from the dialect string instead of specifying it directly. Example: `exp.DataType.build("STRING", dialect=DIALECT).this` instead of the underlying `exp.DataType.Type.TEXT` @@ -316,18 +350,28 @@ class SchemaDiffer(PydanticModel): max_parameter_length: Numeric parameter values corresponding to "max". Example: `VARCHAR(max)` -> `VARCHAR(65535)`. types_with_unlimited_length: Data types that accept values of any length up to system limits. Any explicitly parameterized type can ALTER to its unlimited length version, along with different types in some engines. + treat_alter_data_type_as_destructive: The SchemaDiffer will only output change data type operations if it + concludes the change is compatible and won't result in data loss. If this flag is set to True, it will + flag these data type changes as destructive. This was added for dbt adapter support and likely shouldn't + be set outside of that context. 
""" support_positional_add: bool = False - support_nested_operations: bool = False + nested_support: NestedSupport = NestedSupport.NONE array_element_selector: str = "" compatible_types: t.Dict[exp.DataType, t.Set[exp.DataType]] = {} + coerceable_types_: t.Dict[exp.DataType, t.Set[exp.DataType]] = Field( + default_factory=dict, alias="coerceable_types" + ) + precision_increase_allowed_types: t.Optional[t.Set[exp.DataType.Type]] = None support_coercing_compatible_types: bool = False + drop_cascade: bool = False parameterized_type_defaults: t.Dict[ exp.DataType.Type, t.List[t.Tuple[t.Union[int, float], ...]] ] = {} max_parameter_length: t.Dict[exp.DataType.Type, t.Union[int, float]] = {} types_with_unlimited_length: t.Dict[exp.DataType.Type, t.Set[exp.DataType.Type]] = {} + treat_alter_data_type_as_destructive: bool = False _coerceable_types: t.Dict[exp.DataType, t.Set[exp.DataType]] = {} @@ -335,8 +379,9 @@ class SchemaDiffer(PydanticModel): def coerceable_types(self) -> t.Dict[exp.DataType, t.Set[exp.DataType]]: if not self._coerceable_types: if not self.support_coercing_compatible_types or not self.compatible_types: - return {} - coerceable_types = defaultdict(set) + return self.coerceable_types_ + coerceable_types: t.Dict[exp.DataType, t.Set[exp.DataType]] = defaultdict(set) + coerceable_types.update(self.coerceable_types_) for source_type, target_types in self.compatible_types.items(): for target_type in target_types: coerceable_types[target_type].add(source_type) @@ -346,7 +391,10 @@ def coerceable_types(self) -> t.Dict[exp.DataType, t.Set[exp.DataType]]: def _is_compatible_type(self, current_type: exp.DataType, new_type: exp.DataType) -> bool: # types are identical or both types are parameterized and new has higher precision # - default parameter values are automatically provided if not present - if current_type == new_type or self._is_precision_increase(current_type, new_type): + if current_type == new_type or ( + 
self._is_precision_increase_allowed(current_type) + and self._is_precision_increase(current_type, new_type) + ): return True # types are un-parameterized and compatible if current_type in self.compatible_types: @@ -357,17 +405,24 @@ def _is_compatible_type(self, current_type: exp.DataType, new_type: exp.DataType return False def _is_coerceable_type(self, current_type: exp.DataType, new_type: exp.DataType) -> bool: - if not self.support_coercing_compatible_types: - return False if current_type in self.coerceable_types: is_coerceable = new_type in self.coerceable_types[current_type] if is_coerceable: - logger.warning( - f"Coercing type {current_type} to {new_type} which means an alter will not be performed and therefore the resulting table structure will not match what is in the query.\nUpdate your model to cast the value to {current_type} type in order to remove this warning.", + from sqlmesh.core.console import get_console + + get_console().log_warning( + f"Coercing type {current_type} to {new_type} which means an alter will not be performed and therefore the resulting table structure will not match what is in the query.\nUpdate your model to cast the value to {current_type} type in order to remove this warning." 
) + return is_coerceable return False + def _is_precision_increase_allowed(self, current_type: exp.DataType) -> bool: + return ( + self.precision_increase_allowed_types is None + or current_type.this in self.precision_increase_allowed_types + ) + def _is_precision_increase(self, current_type: exp.DataType, new_type: exp.DataType) -> bool: if current_type.this == new_type.this and not current_type.is_type( *exp.DataType.NESTED_TYPES @@ -438,32 +493,54 @@ def _drop_operation( struct: exp.DataType, pos: int, root_struct: exp.DataType, - ) -> t.List[TableAlterOperation]: + table_name: TableName, + ) -> t.List[TableAlterColumnOperation]: columns = ensure_list(columns) - operations = [] + operations: t.List[TableAlterColumnOperation] = [] column_pos, column_kwarg = self._get_matching_kwarg(columns[-1].name, struct, pos) - assert column_pos is not None - assert column_kwarg + if column_pos is None or not column_kwarg: + raise SQLMeshError( + f"Cannot drop column '{columns[-1].name}' from table '{table_name}' - column not found. " + f"This may indicate a mismatch between the expected and actual table schemas." 
+ ) struct.expressions.pop(column_pos) operations.append( - TableAlterOperation.drop(columns, root_struct.copy(), column_kwarg.args["kind"]) + TableAlterDropColumnOperation( + target_table=exp.to_table(table_name), + column_parts=columns, + expected_table_struct=root_struct.copy(), + cascade=self.drop_cascade, + array_element_selector=self.array_element_selector, + ) ) return operations + def _requires_drop_alteration( + self, current_struct: exp.DataType, new_struct: exp.DataType + ) -> bool: + for current_pos, current_kwarg in enumerate(current_struct.expressions.copy()): + new_pos, _ = self._get_matching_kwarg(current_kwarg, new_struct, current_pos) + if new_pos is None: + return True + return False + def _resolve_drop_operation( self, parent_columns: t.List[TableAlterColumn], current_struct: exp.DataType, new_struct: exp.DataType, root_struct: exp.DataType, - ) -> t.List[TableAlterOperation]: + table_name: TableName, + ) -> t.List[TableAlterColumnOperation]: operations = [] for current_pos, current_kwarg in enumerate(current_struct.expressions.copy()): new_pos, _ = self._get_matching_kwarg(current_kwarg, new_struct, current_pos) columns = parent_columns + [TableAlterColumn.from_struct_kwarg(current_kwarg)] if new_pos is None: operations.extend( - self._drop_operation(columns, current_struct, current_pos, root_struct) + self._drop_operation( + columns, current_struct, current_pos, root_struct, table_name + ) ) return operations @@ -474,7 +551,9 @@ def _add_operation( new_kwarg: exp.ColumnDef, current_struct: exp.DataType, root_struct: exp.DataType, - ) -> t.List[TableAlterOperation]: + table_name: TableName, + is_part_of_destructive_change: bool = False, + ) -> t.List[TableAlterColumnOperation]: if self.support_positional_add: col_pos = TableAlterColumnPosition.create(new_pos, current_struct.expressions) current_struct.expressions.insert(new_pos, new_kwarg) @@ -482,11 +561,14 @@ def _add_operation( col_pos = None current_struct.expressions.append(new_kwarg) 
return [ - TableAlterOperation.add( - columns, - new_kwarg.args["kind"], - root_struct.copy(), - col_pos, + TableAlterAddColumnOperation( + target_table=exp.to_table(table_name), + column_parts=columns, + column_type=new_kwarg.args["kind"], + expected_table_struct=root_struct.copy(), + position=col_pos, + is_part_of_destructive_change=is_part_of_destructive_change, + array_element_selector=self.array_element_selector, ) ] @@ -496,14 +578,17 @@ def _resolve_add_operations( current_struct: exp.DataType, new_struct: exp.DataType, root_struct: exp.DataType, - ) -> t.List[TableAlterOperation]: + table_name: TableName, + ) -> t.List[TableAlterColumnOperation]: operations = [] for new_pos, new_kwarg in enumerate(new_struct.expressions): possible_current_pos, _ = self._get_matching_kwarg(new_kwarg, current_struct, new_pos) if possible_current_pos is None: columns = parent_columns + [TableAlterColumn.from_struct_kwarg(new_kwarg)] operations.extend( - self._add_operation(columns, new_pos, new_kwarg, current_struct, root_struct) + self._add_operation( + columns, new_pos, new_kwarg, current_struct, root_struct, table_name + ) ) return operations @@ -516,18 +601,31 @@ def _alter_operation( current_type: t.Union[str, exp.DataType], root_struct: exp.DataType, new_kwarg: exp.ColumnDef, - ) -> t.List[TableAlterOperation]: + table_name: TableName, + *, + ignore_destructive: bool = False, + ignore_additive: bool = False, + ) -> t.List[TableAlterColumnOperation]: # We don't copy on purpose here because current_type may need to be mutated inside # _get_operations (struct.expressions.pop and struct.expressions.insert) current_type = exp.DataType.build(current_type, copy=False) - if self.support_nested_operations: + if not self.nested_support.is_none: if new_type.this == current_type.this == exp.DataType.Type.STRUCT: - return self._get_operations( - columns, - current_type, - new_type, - root_struct, - ) + if self.nested_support.is_ignore: + return [] + if self.nested_support.is_all or 
not self._requires_drop_alteration( + current_type, new_type + ): + return self._get_operations( + columns, + current_type, + new_type, + root_struct, + table_name, + ignore_destructive=ignore_destructive, + ignore_additive=ignore_additive, + ) + if new_type.this == current_type.this == exp.DataType.Type.ARRAY: # Some engines (i.e. Snowflake) don't support defining types on arrays if not new_type.expressions or not current_type.expressions: @@ -535,35 +633,55 @@ def _alter_operation( new_array_type = new_type.expressions[0] current_array_type = current_type.expressions[0] if new_array_type.this == current_array_type.this == exp.DataType.Type.STRUCT: - return self._get_operations( - columns, - current_array_type, - new_array_type, - root_struct, - ) + if self.nested_support.is_ignore: + return [] + if self.nested_support.is_all or not self._requires_drop_alteration( + current_array_type, new_array_type + ): + return self._get_operations( + columns, + current_array_type, + new_array_type, + root_struct, + table_name, + ignore_destructive=ignore_destructive, + ignore_additive=ignore_additive, + ) if self._is_coerceable_type(current_type, new_type): return [] - elif self._is_compatible_type(current_type, new_type): + if self._is_compatible_type(current_type, new_type): + if ignore_additive: + return [] struct.expressions.pop(pos) struct.expressions.insert(pos, new_kwarg) - col_pos = ( - TableAlterColumnPosition.create(pos, struct.expressions, replacing_col=True) - if self.support_positional_add - else None - ) return [ - TableAlterOperation.alter_type( - columns, - new_type, - current_type, - root_struct.copy(), - col_pos, + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table(table_name), + column_parts=columns, + column_type=new_type, + current_type=current_type, + expected_table_struct=root_struct.copy(), + array_element_selector=self.array_element_selector, + is_part_of_destructive_change=self.treat_alter_data_type_as_destructive, ) ] - else: - return 
self._drop_operation( - columns, root_struct, pos, root_struct - ) + self._add_operation(columns, pos, new_kwarg, struct, root_struct) + if ignore_destructive: + return [] + return self._drop_operation( + columns, + root_struct, + pos, + root_struct, + table_name, + ) + self._add_operation( + columns, + pos, + new_kwarg, + struct, + root_struct, + table_name, + is_part_of_destructive_change=True, + ) def _resolve_alter_operations( self, @@ -571,11 +689,18 @@ def _resolve_alter_operations( current_struct: exp.DataType, new_struct: exp.DataType, root_struct: exp.DataType, - ) -> t.List[TableAlterOperation]: + table_name: TableName, + *, + ignore_destructive: bool = False, + ignore_additive: bool = False, + ) -> t.List[TableAlterColumnOperation]: operations = [] for current_pos, current_kwarg in enumerate(current_struct.expressions.copy()): _, new_kwarg = self._get_matching_kwarg(current_kwarg, new_struct, current_pos) - assert new_kwarg + if new_kwarg is None: + if ignore_destructive: + continue + raise ValueError("Cannot alter a column that is being dropped") _, new_type = _get_name_and_type(new_kwarg) _, current_type = _get_name_and_type(current_kwarg) columns = parent_columns + [TableAlterColumn.from_struct_kwarg(current_kwarg)] @@ -590,6 +715,9 @@ def _resolve_alter_operations( current_type, root_struct, new_kwarg, + table_name, + ignore_destructive=ignore_destructive, + ignore_additive=ignore_additive, ) ) return operations @@ -600,70 +728,150 @@ def _get_operations( current_struct: exp.DataType, new_struct: exp.DataType, root_struct: exp.DataType, - ) -> t.List[TableAlterOperation]: + table_name: TableName, + *, + ignore_destructive: bool = False, + ignore_additive: bool = False, + ) -> t.List[TableAlterColumnOperation]: root_struct = root_struct or current_struct parent_columns = parent_columns or [] operations = [] + if not ignore_destructive: + operations.extend( + self._resolve_drop_operation( + parent_columns, current_struct, new_struct, root_struct, 
table_name + ) + ) + if not ignore_additive: + operations.extend( + self._resolve_add_operations( + parent_columns, current_struct, new_struct, root_struct, table_name + ) + ) operations.extend( - self._resolve_drop_operation(parent_columns, current_struct, new_struct, root_struct) - ) - operations.extend( - self._resolve_add_operations(parent_columns, current_struct, new_struct, root_struct) - ) - operations.extend( - self._resolve_alter_operations(parent_columns, current_struct, new_struct, root_struct) + self._resolve_alter_operations( + parent_columns, + current_struct, + new_struct, + root_struct, + ignore_destructive=ignore_destructive, + ignore_additive=ignore_additive, + table_name=table_name, + ) ) return operations def _from_structs( - self, current_struct: exp.DataType, new_struct: exp.DataType - ) -> t.List[TableAlterOperation]: - return self._get_operations([], current_struct, new_struct, current_struct) - - def compare_structs( - self, table_name: t.Union[str, exp.Table], current: exp.DataType, new: exp.DataType - ) -> t.List[exp.AlterTable]: - """ - Compares two schemas represented as structs. - - Args: - current: The current schema. - new: The new schema. - - Returns: - The list of table alter operations. 
- """ - return [ - op.expression(table_name, self.array_element_selector) - for op in self._from_structs(current, new) - ] + self, + current_struct: exp.DataType, + new_struct: exp.DataType, + table_name: TableName, + *, + ignore_destructive: bool = False, + ignore_additive: bool = False, + ) -> t.List[TableAlterColumnOperation]: + return self._get_operations( + [], + current_struct, + new_struct, + current_struct, + table_name=table_name, + ignore_destructive=ignore_destructive, + ignore_additive=ignore_additive, + ) + + def _compare_structs( + self, + table_name: t.Union[str, exp.Table], + current: exp.DataType, + new: exp.DataType, + *, + ignore_destructive: bool = False, + ignore_additive: bool = False, + ) -> t.List[TableAlterColumnOperation]: + return self._from_structs( + current, + new, + table_name=table_name, + ignore_destructive=ignore_destructive, + ignore_additive=ignore_additive, + ) def compare_columns( self, table_name: TableName, current: t.Dict[str, exp.DataType], new: t.Dict[str, exp.DataType], - ) -> t.List[exp.AlterTable]: - """ - Compares two schemas represented as dictionaries of column names and types. - - Args: - current: The current schema. - new: The new schema. - - Returns: - The list of schema deltas. 
- """ - return self.compare_structs( - table_name, columns_to_types_to_struct(current), columns_to_types_to_struct(new) + *, + ignore_destructive: bool = False, + ignore_additive: bool = False, + ) -> t.List[TableAlterColumnOperation]: + return self._compare_structs( + table_name, + columns_to_types_to_struct(current), + columns_to_types_to_struct(new), + ignore_destructive=ignore_destructive, + ignore_additive=ignore_additive, ) -def has_drop_alteration(alter_expressions: t.List[exp.AlterTable]) -> bool: - return any( - isinstance(action, exp.Drop) - for actions in alter_expressions - for action in actions.args.get("actions", []) +def has_drop_alteration(alter_operations: t.List[TableAlterOperation]) -> bool: + return any(op.is_destructive for op in alter_operations) + + +def has_additive_alteration(alter_operations: t.List[TableAlterOperation]) -> bool: + return any(op.is_additive for op in alter_operations) + + +def get_additive_changes( + alter_operations: t.List[TableAlterOperation], +) -> t.List[TableAlterOperation]: + return [x for x in alter_operations if x.is_additive] + + +def get_dropped_column_names(alter_expressions: t.List[TableAlterOperation]) -> t.List[str]: + return [ + op.column.alias_or_name + for op in alter_expressions + if isinstance(op, TableAlterDropColumnOperation) + ] + + +def get_additive_column_names(alter_expressions: t.List[TableAlterOperation]) -> t.List[str]: + return [ + op.column.alias_or_name + for op in alter_expressions + if op.is_additive and isinstance(op, TableAlterColumnOperation) + ] + + +def get_schema_differ( + dialect: str, overrides: t.Optional[t.Dict[str, t.Any]] = None +) -> SchemaDiffer: + """ + Returns the appropriate SchemaDiffer for a given dialect without initializing the engine adapter. + + Args: + dialect: The dialect for which to get the schema differ. + overrides: Optional dictionary of overrides to apply to the SchemaDiffer instance. + + Returns: + The SchemaDiffer instance configured for the given dialect. 
+ """ + from sqlmesh.core.engine_adapter import ( + DIALECT_TO_ENGINE_ADAPTER, + DIALECT_ALIASES, + EngineAdapter, + ) + + dialect = dialect.lower() + dialect = DIALECT_ALIASES.get(dialect, dialect) + engine_adapter_class = DIALECT_TO_ENGINE_ADAPTER.get(dialect, EngineAdapter) + return SchemaDiffer( + **{ + **getattr(engine_adapter_class, "SCHEMA_DIFFER_KWARGS"), + **(overrides or {}), + } ) diff --git a/sqlmesh/core/schema_loader.py b/sqlmesh/core/schema_loader.py index 1b00c7bc18..52ab807c78 100644 --- a/sqlmesh/core/schema_loader.py +++ b/sqlmesh/core/schema_loader.py @@ -1,6 +1,5 @@ from __future__ import annotations -import logging import typing as t from concurrent.futures import ThreadPoolExecutor from pathlib import Path @@ -8,14 +7,13 @@ from sqlglot import exp from sqlglot.dialects.dialect import DialectType +from sqlmesh.core.console import get_console from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.model.definition import Model from sqlmesh.core.state_sync import StateReader from sqlmesh.utils import UniqueKeyDict, yaml from sqlmesh.utils.errors import SQLMeshError -logger = logging.getLogger(__name__) - def create_external_models_file( path: Path, @@ -51,36 +49,25 @@ def create_external_models_file( # Make sure we don't convert internal models into external ones. existing_model_fqns = state_reader.nodes_exist(external_model_fqns, exclude_external=True) if existing_model_fqns: - logger.warning( - "The following models already exist and can't be converted to external: %s." - "Perhaps these models have been removed, while downstream models that reference them weren't updated accordingly", - ", ".join(existing_model_fqns), + existing_model_fqns_str = ", ".join(existing_model_fqns) + get_console().log_warning( + f"The following models already exist and can't be converted to external: {existing_model_fqns_str}. " + "Perhaps these models have been removed, while downstream models that reference them weren't updated accordingly." 
) external_model_fqns -= existing_model_fqns with ThreadPoolExecutor(max_workers=max_workers) as pool: - - def _get_columns(table: str) -> t.Optional[t.Dict[str, t.Any]]: - try: - return adapter.columns(table, include_pseudo_columns=True) - except Exception as e: - msg = f"Unable to get schema for '{table}': '{e}'." - if strict: - raise SQLMeshError(msg) from e - logger.warning(msg) - return None - gateway_part = {"gateway": gateway} if gateway else {} schemas = [ { "name": exp.to_table(table).sql(dialect=dialect), - "columns": {c: dtype.sql(dialect=dialect) for c, dtype in columns.items()}, + "columns": columns, **gateway_part, } for table, columns in sorted( pool.map( - lambda table: (table, _get_columns(table)), + lambda table: (table, get_columns(adapter, dialect, table, strict)), external_model_fqns, ) ) @@ -96,3 +83,20 @@ def _get_columns(table: str) -> t.Optional[t.Dict[str, t.Any]]: with open(path, "w", encoding="utf-8") as file: yaml.dump(entries_to_keep + schemas, file) + + +def get_columns( + adapter: EngineAdapter, dialect: DialectType, table: str, strict: bool +) -> t.Optional[t.Dict[str, t.Any]]: + """ + Return the column and their types in a dictionary + """ + try: + columns = adapter.columns(table, include_pseudo_columns=True) + return {c: dtype.sql(dialect=dialect) for c, dtype in columns.items()} + except Exception as e: + msg = f"Unable to get schema for '{table}': '{e}'." 
+ if strict: + raise SQLMeshError(msg) from e + get_console().log_warning(msg) + return None diff --git a/sqlmesh/core/selector.py b/sqlmesh/core/selector.py index 3c4a844d46..3865327acd 100644 --- a/sqlmesh/core/selector.py +++ b/sqlmesh/core/selector.py @@ -1,24 +1,36 @@ from __future__ import annotations import fnmatch -import logging import typing as t -from collections import defaultdict from pathlib import Path +from itertools import zip_longest +import abc +from sqlglot import exp +from sqlglot.errors import ParseError +from sqlglot.tokens import Token, TokenType, Tokenizer as BaseTokenizer +from sqlglot.dialects.dialect import Dialect, DialectType +from sqlglot.helper import seq_get + +from sqlmesh.core import constants as c from sqlmesh.core.dialect import normalize_model_name from sqlmesh.core.environment import Environment -from sqlmesh.core.loader import update_model_schemas -from sqlmesh.core.model import Model -from sqlmesh.core.state_sync import StateReader +from sqlmesh.core.model import update_model_schemas +from sqlmesh.core.audit import StandaloneAudit from sqlmesh.utils import UniqueKeyDict from sqlmesh.utils.dag import DAG from sqlmesh.utils.git import GitClient +from sqlmesh.utils.errors import SQLMeshError + -logger = logging.getLogger(__name__) +if t.TYPE_CHECKING: + from typing_extensions import Literal as Lit # noqa + from sqlmesh.core.model import Model + from sqlmesh.core.node import Node + from sqlmesh.core.state_sync import StateReader -class Selector: +class Selector(abc.ABC): def __init__( self, state_reader: StateReader, @@ -27,14 +39,15 @@ def __init__( dag: t.Optional[DAG[str]] = None, default_catalog: t.Optional[str] = None, dialect: t.Optional[str] = None, + cache_dir: t.Optional[Path] = None, ): self._state_reader = state_reader self._models = models self._context_path = context_path + self._cache_dir = cache_dir if cache_dir else context_path / c.CACHE self._default_catalog = default_catalog self._dialect = dialect 
self._git_client = GitClient(context_path) - self.__models_by_tag: t.Optional[t.Dict[str, t.Set[str]]] = None if dag is None: self._dag: DAG[str] = DAG() @@ -66,6 +79,9 @@ def select_models( A dictionary of models. """ target_env = self._state_reader.get_environment(Environment.sanitize_name(target_env_name)) + if target_env and target_env.expired: + target_env = None + if not target_env and fallback_env_name: target_env = self._state_reader.get_environment( Environment.sanitize_name(fallback_env_name) @@ -85,7 +101,7 @@ def select_models( } all_selected_models = self.expand_model_selections( - model_selections, models={**self._models, **env_models} + model_selections, models={**env_models, **self._models} ) dag: DAG[str] = DAG() @@ -97,196 +113,389 @@ def select_models( subdag.update(self._dag.downstream(fqn)) models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") - all_model_fqns = set(self._models) | set(env_models) - for fqn in all_model_fqns: - model: t.Optional[Model] = None + needs_update = False + + def get_model(fqn: str) -> t.Optional[Model]: if fqn not in all_selected_models and fqn in env_models: # Unselected modified or added model. - model = env_models[fqn] - elif fqn in all_selected_models and fqn in self._models: + model_from_env = env_models[fqn] + try: + # this triggers a render_query() which can throw an exception + model_from_env.depends_on + return model_from_env + except Exception as e: + raise SQLMeshError( + f"Model '{model_from_env.name}' sourced from state cannot be rendered " + f"in the local environment due to:\n> {str(e)}" + ) from e + if fqn in all_selected_models and fqn in self._models: # Selected modified or removed model. 
- model = self._models[fqn] + return self._models[fqn] + return None + + for fqn in all_model_fqns: + model = get_model(fqn) + + if not model: + continue + + if model.fqn in subdag: + dag.add(model.fqn, model.depends_on) - if model: - # model.copy() can't be used here due to a cached state that can be a part of a model instance. - if model.fqn in subdag: - model = type(model).parse_obj(model.dict(exclude={"mapping_schema"})) - dag.add(model.fqn, model.depends_on) - models[model.fqn] = model + for dep in model.depends_on: + schema = model.mapping_schema - update_model_schemas(dag, models, self._context_path) + for part in exp.to_table(dep).parts: + schema = schema.get(part.sql()) or {} + + parent = get_model(dep) + + parent_schema = { + c: t.sql(dialect=model.dialect) + for c, t in ((parent and parent.columns_to_types) or {}).items() + } + + if schema != parent_schema: + model = model.copy(update={"mapping_schema": {}}) + needs_update = True + break + + models[model.fqn] = model + + if needs_update: + update_model_schemas(dag, models=models, cache_dir=self._cache_dir) return models def expand_model_selections( - self, model_selections: t.Iterable[str], models: t.Optional[t.Dict[str, Model]] = None + self, model_selections: t.Iterable[str], models: t.Optional[t.Dict[str, Node]] = None ) -> t.Set[str]: - """Expands a set of model selections into a set of model names. + """Expands a set of model selections into a set of model fqns that can be looked up in the Context. Args: model_selections: A set of model selections. Returns: - A set of model names. + A set of model fqns. 
""" - results: t.Set[str] = set() - models = models or self._models - models_by_tags: t.Optional[t.Dict[str, t.Set[str]]] = None - - for selection in model_selections: - sub_results: t.Optional[t.Set[str]] = None - - def add_sub_results(sr: t.Set[str]) -> None: - nonlocal sub_results - if sub_results is None: - sub_results = sr - else: - sub_results &= sr - - sub_selections = [s.strip() for s in selection.split("&")] - for sub_selection in sub_selections: - if not sub_selection: - continue - - if sub_selection.startswith("tag:"): - if models_by_tags is None: - models_by_tag = defaultdict(set) - for model in models.values(): - for tag in model.tags: - models_by_tag[tag.lower()].add(model.fqn) - add_sub_results( - self._expand_model_tag(sub_selection[4:], models, models_by_tag) - ) - elif sub_selection.startswith(("git:", "+git:")): - sub_selection = sub_selection.replace("git:", "") - add_sub_results(self._expand_git(sub_selection, models)) - else: - add_sub_results(self._expand_model_name(sub_selection, models)) - - if sub_results: - results.update(sub_results) + + node = parse(" | ".join(f"({s})" for s in model_selections)) + + all_models: t.Dict[str, Node] = models or dict(self._models) + models_by_tags: t.Dict[str, t.Set[str]] = {} + + for fqn, model in all_models.items(): + for tag in model.tags: + tag = tag.lower() + models_by_tags.setdefault(tag, set()) + models_by_tags[tag].add(model.fqn) + + def evaluate(node: exp.Expression) -> t.Set[str]: + if isinstance(node, exp.Var): + pattern = node.this + if "*" in pattern: + return { + fqn + for fqn, model in all_models.items() + if fnmatch.fnmatchcase(self._model_name(model), node.this) + } + return self._pattern_to_model_fqns(pattern, all_models) + if isinstance(node, exp.And): + return evaluate(node.left) & evaluate(node.right) + if isinstance(node, exp.Or): + return evaluate(node.left) | evaluate(node.right) + if isinstance(node, exp.Paren): + return evaluate(node.this) + if isinstance(node, exp.Not): + return 
set(all_models) - evaluate(node.this) + if isinstance(node, Git): + target_branch = node.name + git_modified_files = { + *self._git_client.list_untracked_files(), + *self._git_client.list_uncommitted_changed_files(), + *self._git_client.list_committed_changed_files(target_branch=target_branch), + } + return {m.fqn for m in all_models.values() if m._path in git_modified_files} + if isinstance(node, Tag): + pattern = node.name.lower() + + if "*" in pattern: + return { + model + for tag, models in models_by_tags.items() + for model in models + if fnmatch.fnmatchcase(tag, pattern) + } + return models_by_tags.get(pattern, set()) + if isinstance(node, ResourceType): + resource_type = node.name.lower() + return { + fqn + for fqn, model in all_models.items() + if self._matches_resource_type(resource_type, model) + } + if isinstance(node, Direction): + selected = set() + + for model_name in evaluate(node.this): + selected.add(model_name) + if node.args.get("up"): + for u in self._dag.upstream(model_name): + if u in all_models: + selected.add(u) + if node.args.get("down"): + selected.update(self._dag.downstream(model_name)) + return selected + raise ParseError(f"Unexpected node {node}") + + return evaluate(node) + + @abc.abstractmethod + def _model_name(self, model: Node) -> str: + """Given a model, return the name that a selector pattern contining wildcards should be fnmatch'd on""" + pass + + @abc.abstractmethod + def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Node]) -> t.Set[str]: + """Given a pattern, return the keys of the matching models from :all_models""" + pass + + @abc.abstractmethod + def _matches_resource_type(self, resource_type: str, model: Node) -> bool: + """Indicate whether or not the supplied model matches the supplied resource type""" + pass + + +class NativeSelector(Selector): + """Implementation of selectors that matches objects based on SQLMesh native names""" + + def _model_name(self, model: Node) -> str: + return model.name + + 
def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Node]) -> t.Set[str]: + fqn = normalize_model_name(pattern, self._default_catalog, self._dialect) + return {fqn} if fqn in all_models else set() + + def _matches_resource_type(self, resource_type: str, model: Node) -> bool: + if resource_type == "model": + return model.is_model + if resource_type == "audit": + return isinstance(model, StandaloneAudit) + + raise SQLMeshError(f"Unsupported resource type: {resource_type}") + + +class DbtSelector(Selector): + """Implementation of selectors that matches objects based on the DBT names instead of the SQLMesh native names""" + + def _model_name(self, model: Node) -> str: + if dbt_fqn := model.dbt_fqn: + return dbt_fqn + raise SQLMeshError("dbt node information must be populated to use dbt selectors") + + def _pattern_to_model_fqns(self, pattern: str, all_models: t.Dict[str, Node]) -> t.Set[str]: + # a pattern like "staging.customers" should match a model called "jaffle_shop.staging.customers" + # but not a model called "jaffle_shop.customers.staging" + # also a pattern like "aging" should not match "staging" so we need to consider components; not substrings + pattern_components = pattern.split(".") + first_pattern_component = pattern_components[0] + matches = set() + for fqn, model in all_models.items(): + if not model.dbt_fqn: + continue + + dbt_fqn_components = model.dbt_fqn.split(".") + try: + starting_idx = dbt_fqn_components.index(first_pattern_component) + except ValueError: + continue + for pattern_component, fqn_component in zip_longest( + pattern_components, dbt_fqn_components[starting_idx:] + ): + if pattern_component and not fqn_component: + # the pattern still goes but we have run out of fqn components to match; no match + break + if fqn_component and not pattern_component: + # all elements of the pattern have matched elements of the fqn; match + matches.add(fqn) + break + if pattern_component != fqn_component: + # the pattern explicitly 
doesnt match a component; no match + break else: - logger.warning(f"Expression '{selection}' doesn't match any models.") + # called if no explicit break, indicating all components of the pattern matched all components of the fqn + matches.add(fqn) + return matches + + def _matches_resource_type(self, resource_type: str, model: Node) -> bool: + """ + ref: https://docs.getdbt.com/reference/node-selection/methods#resource_type + + # supported by SQLMesh + "model" + "seed" + "source" # external model + "test" # standalone audit + + # not supported by SQLMesh yet, commented out to throw an error if someone tries to use them + "analysis" + "exposure" + "metric" + "saved_query" + "semantic_model" + "snapshot" + "unit_test" + """ + if resource_type not in ("model", "seed", "source", "test"): + raise SQLMeshError(f"Unsupported resource type: {resource_type}") + + if isinstance(model, StandaloneAudit): + return resource_type == "test" + + if resource_type == "model": + return model.is_model and not model.kind.is_external and not model.kind.is_seed + if resource_type == "source": + return model.kind.is_external + if resource_type == "seed": + return model.kind.is_seed + + return False + + +class SelectorDialect(Dialect): + IDENTIFIERS_CAN_START_WITH_DIGIT = True + + class Tokenizer(BaseTokenizer): + SINGLE_TOKENS = { + "(": TokenType.L_PAREN, + ")": TokenType.R_PAREN, + "&": TokenType.AMP, + "|": TokenType.PIPE, + "^": TokenType.CARET, + "+": TokenType.PLUS, + "*": TokenType.STAR, + ":": TokenType.COLON, + } - return results + KEYWORDS = {} + IDENTIFIERS = ["\\"] # there are no identifiers but need to put something here + IDENTIFIER_START = "" + IDENTIFIER_END = "" - def _expand_git(self, target_branch: str, models: t.Dict[str, Model]) -> t.Set[str]: - results: t.Set[str] = set() - ( - target_branch, - include_upstream, - include_downstream, - ) = self._get_value_and_dependency_inclusion(target_branch) +class Git(exp.Expression): + pass - git_modified_files = { - 
*self._git_client.list_untracked_files(), - *self._git_client.list_uncommitted_changed_files(), - *self._git_client.list_committed_changed_files(target_branch=target_branch), - } - matched_models = {m.fqn for m in self._models.values() if m._path in git_modified_files} - if not matched_models: - logger.warning(f"Expression 'git:{target_branch}' doesn't match any models.") - return matched_models +class Tag(exp.Expression): + pass - for model_fqn in matched_models: - results.update( - self._get_models(model_fqn, include_upstream, include_downstream, models) - ) - return results +class ResourceType(exp.Expression): + pass - def _expand_model_name(self, selection: str, models: t.Dict[str, Model]) -> t.Set[str]: - results = set() - ( - selection, - include_upstream, - include_downstream, - ) = self._get_value_and_dependency_inclusion(selection) +class Direction(exp.Expression): + pass - matched_models = set() - if "*" in selection: - for model in models.values(): - if fnmatch.fnmatchcase(model.name, selection): - matched_models.add(model.fqn) - else: - model_fqn = normalize_model_name(selection, self._default_catalog, self._dialect) - if model_fqn in models: - matched_models.add(model_fqn) +def parse(selector: str, dialect: DialectType = None) -> exp.Expression: + tokens = SelectorDialect().tokenize(selector) + i = 0 - if not matched_models: - logger.warning(f"Expression '{selection}' doesn't match any models.") + def _curr() -> t.Optional[Token]: + return seq_get(tokens, i) - for model_fqn in matched_models: - results.update( - self._get_models(model_fqn, include_upstream, include_downstream, models) - ) - return results + def _prev() -> Token: + return tokens[i - 1] - def _expand_model_tag( - self, tag_selection: str, models: t.Dict[str, Model], models_by_tag: t.Dict[str, t.Set[str]] - ) -> t.Set[str]: - """ - Expands a set of model tags into a set of model names. 
- The tag matching is case-insensitive and supports wildcards and + prefix and suffix to - include upstream and downstream models. + def _advance(num: int = 1) -> Token: + nonlocal i + i += num + return _prev() - Args: - tag_selection: A tag to match models against. + def _next() -> t.Optional[Token]: + return seq_get(tokens, i + 1) - Returns: - A set of model names. - """ - result = set() - matched_tags = set() - ( - selection, - include_upstream, - include_downstream, - ) = self._get_value_and_dependency_inclusion(tag_selection.lower()) - - if "*" in selection: - for model_tag in models_by_tag: - if fnmatch.fnmatchcase(model_tag, selection): - matched_tags.add(model_tag) - elif selection in models_by_tag: - matched_tags.add(selection) - - if not matched_tags: - logger.warning(f"Expression 'tag:{tag_selection}' doesn't match any models.") - - for tag in matched_tags: - for model in models_by_tag[tag]: - result.update(self._get_models(model, include_upstream, include_downstream, models)) - - return result - - def _get_models( - self, - model_name: str, - include_upstream: bool, - include_downstream: bool, - models: t.Dict[str, Model], - ) -> t.Set[str]: - result = {model_name} - if include_upstream: - result.update([u for u in self._dag.upstream(model_name) if u in models]) - if include_downstream: - result.update(self._dag.downstream(model_name)) - return result - - @staticmethod - def _get_value_and_dependency_inclusion(value: str) -> t.Tuple[str, bool, bool]: - include_upstream = False - include_downstream = False - if value[0] == "+": - value = value[1:] - include_upstream = True - if value[-1] == "+": - value = value[:-1] - include_downstream = True - return value, include_upstream, include_downstream + def _error(msg: str) -> str: + return f"{msg} at index {i}: {selector}" + + def _match(token_type: TokenType, raise_unmatched: bool = False) -> t.Optional[Token]: + token = _curr() + if token and token.token_type == token_type: + return _advance() + if 
raise_unmatched: + raise ParseError(_error(f"Expected {token_type}")) + return None + + def _parse_kind(kind: str) -> bool: + token = _curr() + next_token = _next() + + if ( + token + and token.token_type == TokenType.VAR + and token.text.lower() == kind + and next_token + and next_token.token_type == TokenType.COLON + ): + _advance(2) + return True + return False + + def _parse_var() -> exp.Expression: + upstream = _match(TokenType.PLUS) + downstream = None + tag = _parse_kind("tag") + resource_type = False if tag else _parse_kind("resource_type") + git = False if resource_type else _parse_kind("git") + lstar = "*" if _match(TokenType.STAR) else "" + directions = {} + + if _match(TokenType.VAR) or _match(TokenType.NUMBER): + name = _prev().text + rstar = "*" if _match(TokenType.STAR) else "" + downstream = _match(TokenType.PLUS) + this: exp.Expression = exp.Var(this=f"{lstar}{name}{rstar}") + + elif _match(TokenType.L_PAREN): + this = exp.Paren(this=_parse_conjunction()) + downstream = _match(TokenType.PLUS) + _match(TokenType.R_PAREN, True) + elif lstar: + this = exp.var("*") + else: + raise ParseError(_error("Expected model name.")) + + if upstream: + directions["up"] = True + if downstream: + directions["down"] = True + + if tag: + this = Tag(this=this) + if resource_type: + this = ResourceType(this=this) + if git: + this = Git(this=this) + if directions: + this = Direction(this=this, **directions) + return this + + def _parse_unary() -> exp.Expression: + if _match(TokenType.CARET): + return exp.Not(this=_parse_unary()) + return _parse_var() + + def _parse_conjunction() -> exp.Expression: + this = _parse_unary() + + if _match(TokenType.AMP): + this = exp.And(this=this, expression=_parse_unary()) + if _match(TokenType.PIPE): + this = exp.Or(this=this, expression=_parse_conjunction()) + + return this + + return _parse_conjunction() diff --git a/sqlmesh/core/signal.py b/sqlmesh/core/signal.py new file mode 100644 index 0000000000..554dd60a39 --- /dev/null +++ 
b/sqlmesh/core/signal.py @@ -0,0 +1,98 @@ +from __future__ import annotations + +import typing as t +from sqlmesh.utils import UniqueKeyDict, registry_decorator +from sqlmesh.utils.errors import MissingSourceError + +if t.TYPE_CHECKING: + from sqlmesh.core.context import ExecutionContext + from sqlmesh.core.snapshot.definition import Snapshot + from sqlmesh.utils.date import DatetimeRanges + from sqlmesh.core.snapshot.definition import DeployabilityIndex + + +class signal(registry_decorator): + """Specifies a function which intervals are ready from a list of scheduled intervals. + + When SQLMesh wishes to execute a batch of intervals, say between `a` and `d`, then + the `batch` parameter will contain each individual interval within this batch, + i.e.: `[a,b),[b,c),[c,d)`. + + This function may return `True` to indicate that the whole batch is ready, + `False` to indicate none of the batch's intervals are ready, or a list of + intervals (a batch) to indicate exactly which ones are ready. + + When returning a batch, the function is expected to return a subset of + the `batch` parameter, e.g.: `[a,b),[b,c)`. Note that it may return + gaps, e.g.: `[a,b),[c,d)`, but it may not alter the bounds of any of the + intervals. + + The interface allows an implementation to check batches of intervals without + having to actually compute individual intervals itself. + + Args: + batch: the list of intervals that are missing and scheduled to run. + + Returns: + Either `True` to indicate all intervals are ready, `False` to indicate none are + ready or a list of intervals to indicate exactly which ones are ready. 
+ """ + + +SignalRegistry = UniqueKeyDict[str, signal] + + +@signal() +def freshness( + batch: DatetimeRanges, + snapshot: Snapshot, + context: ExecutionContext, +) -> bool: + """ + Implements model freshness as a signal, i.e it considers this model to be fresh if: + - Any upstream SQLMesh model has available intervals to compute i.e is fresh + - Any upstream external model has been altered since the last time the model was evaluated + """ + adapter = context.engine_adapter + if context.is_restatement or not adapter.SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS: + return True + + deployability_index = context.deployability_index or DeployabilityIndex.all_deployable() + + last_altered_ts = ( + snapshot.last_altered_ts + if deployability_index.is_deployable(snapshot) + else snapshot.dev_last_altered_ts + ) + + if not last_altered_ts: + return True + + parent_snapshots = {context.snapshots[p.name] for p in snapshot.parents} + + upstream_parent_snapshots = {p for p in parent_snapshots if not p.is_external} + external_parents = snapshot.node.depends_on - {p.name for p in upstream_parent_snapshots} + + if context.parent_intervals: + # At least one upstream sqlmesh model has intervals to compute (i.e is fresh), + # so the current model is considered fresh too + return True + + if external_parents: + external_last_altered_timestamps = adapter.get_table_last_modified_ts( + list(external_parents) + ) + + if len(external_last_altered_timestamps) != len(external_parents): + raise MissingSourceError( + f"Expected {len(external_parents)} sources to be present, but got {len(external_last_altered_timestamps)}." 
+ ) + + # Finding new data means that the upstream depedencies have been altered + # since the last time the model was evaluated + return any( + external_last_altered_ts > last_altered_ts + for external_last_altered_ts in external_last_altered_timestamps + ) + + return False diff --git a/sqlmesh/core/snapshot/__init__.py b/sqlmesh/core/snapshot/__init__.py index a474bcf13a..65e5c2a822 100644 --- a/sqlmesh/core/snapshot/__init__.py +++ b/sqlmesh/core/snapshot/__init__.py @@ -4,17 +4,21 @@ Node as Node, QualifiedViewName as QualifiedViewName, Snapshot as Snapshot, + SnapshotIdAndVersion as SnapshotIdAndVersion, SnapshotChangeCategory as SnapshotChangeCategory, SnapshotDataVersion as SnapshotDataVersion, SnapshotFingerprint as SnapshotFingerprint, SnapshotId as SnapshotId, + SnapshotIdBatch as SnapshotIdBatch, SnapshotIdLike as SnapshotIdLike, + SnapshotIdAndVersionLike as SnapshotIdAndVersionLike, SnapshotInfoLike as SnapshotInfoLike, SnapshotIntervals as SnapshotIntervals, SnapshotNameVersion as SnapshotNameVersion, SnapshotNameVersionLike as SnapshotNameVersionLike, SnapshotTableCleanupTask as SnapshotTableCleanupTask, SnapshotTableInfo as SnapshotTableInfo, + apply_auto_restatements as apply_auto_restatements, earliest_start_date as earliest_start_date, fingerprint_from_node as fingerprint_from_node, has_paused_forward_only as has_paused_forward_only, @@ -25,4 +29,7 @@ table_name as table_name, to_table_mapping as to_table_mapping, ) -from sqlmesh.core.snapshot.evaluator import SnapshotEvaluator as SnapshotEvaluator +from sqlmesh.core.snapshot.evaluator import ( + SnapshotEvaluator as SnapshotEvaluator, + SnapshotCreationFailedError as SnapshotCreationFailedError, +) diff --git a/sqlmesh/core/snapshot/cache.py b/sqlmesh/core/snapshot/cache.py new file mode 100644 index 0000000000..d46b5f0620 --- /dev/null +++ b/sqlmesh/core/snapshot/cache.py @@ -0,0 +1,111 @@ +from __future__ import annotations + +import logging +import typing as t + +from pathlib import Path 
+from sqlmesh.core.model.cache import ( + OptimizedQueryCache, + optimized_query_cache_pool, + load_optimized_query, +) +from sqlmesh.core import constants as c +from sqlmesh.core.snapshot.definition import Snapshot, SnapshotId +from sqlmesh.utils.cache import FileCache + + +logger = logging.getLogger(__name__) + + +class SnapshotCache: + def __init__(self, path: Path): + self._snapshot_cache: FileCache[Snapshot] = FileCache(path, prefix="snapshot") + self._optimized_query_cache = OptimizedQueryCache(path) + + def get_or_load( + self, + snapshot_ids: t.Set[SnapshotId], + loader: t.Callable[[t.Set[SnapshotId]], t.Collection[Snapshot]], + ) -> t.Tuple[t.Dict[SnapshotId, Snapshot], t.Set[SnapshotId]]: + """Fetches the target snapshots from cache or loads them using the provided loader on cache miss. + + Args: + snapshot_ids: Target snapshot IDs to fetch. + loader: The loader to load snapshot records that are missing in the cache. + + Returns: + A tuple where the first value represents the fetched snapshots, and the second value is a set of + snapshot IDs for which records were retrieved from the cache. 
+ + """ + snapshots = {} + cache_hits: t.Set[SnapshotId] = set() + + for s_id in snapshot_ids: + snapshot = self._snapshot_cache.get(self._entry_name(s_id)) + if snapshot: + snapshot.intervals = [] + snapshot.dev_intervals = [] + snapshots[s_id] = snapshot + cache_hits.add(s_id) + + snapshot_ids_to_load = snapshot_ids - snapshots.keys() + if snapshot_ids_to_load: + loaded_snapshots = loader(snapshot_ids_to_load) + for snapshot in loaded_snapshots: + snapshots[snapshot.snapshot_id] = snapshot + + with optimized_query_cache_pool(self._optimized_query_cache) as executor: + for key, entry_name in executor.map( + load_optimized_query, + ( + (snapshot.model, s_id) + for s_id, snapshot in snapshots.items() + if snapshot.is_model + ), + ): + if entry_name: + self._optimized_query_cache.with_optimized_query( + snapshots[key].model, entry_name + ) + + for snapshot in snapshots.values(): + self._update_node_hash_cache(snapshot) + + if snapshot.is_model and c.MAX_FORK_WORKERS == 1: + try: + self._optimized_query_cache.with_optimized_query(snapshot.model) + except Exception: + logger.exception( + "Failed to cache optimized query for snapshot %s", snapshot.snapshot_id + ) + + self.put(snapshot) + + return snapshots, cache_hits + + def put(self, snapshot: Snapshot) -> None: + entry_name = self._entry_name(snapshot.snapshot_id) + + if self._snapshot_cache.exists(entry_name): + return + + try: + if snapshot.is_model: + # make sure we preload full_depends_on + snapshot.model.full_depends_on + self._snapshot_cache.put(entry_name, value=snapshot) + except Exception: + logger.exception("Failed to cache snapshot %s", snapshot.snapshot_id) + + def clear(self) -> None: + self._snapshot_cache.clear() + + @staticmethod + def _entry_name(snapshot_id: SnapshotId) -> str: + return f"{snapshot_id.name}_{snapshot_id.identifier}" + + @staticmethod + def _update_node_hash_cache(snapshot: Snapshot) -> None: + snapshot.node._data_hash = snapshot.fingerprint.data_hash + snapshot.node._metadata_hash = 
snapshot.fingerprint.metadata_hash diff --git a/sqlmesh/core/snapshot/categorizer.py b/sqlmesh/core/snapshot/categorizer.py index 09d83784b9..78ea7466ed 100644 --- a/sqlmesh/core/snapshot/categorizer.py +++ b/sqlmesh/core/snapshot/categorizer.py @@ -8,7 +8,11 @@ def categorize_change( - new: Snapshot, old: Snapshot, config: t.Optional[CategorizerConfig] = None + new: Snapshot, + old: Snapshot, + config: t.Optional[CategorizerConfig] = None, + is_breaking_change: t.Optional[t.Callable[..., t.Optional[bool]]] = None, + **kwargs: t.Any, ) -> t.Optional[SnapshotChangeCategory]: """Attempts to automatically categorize a change between two snapshots. @@ -19,6 +23,10 @@ def categorize_change( Args: new: The new snapshot. old: The old snapshot. + config: Configuration for the automatic categorizer of snapshot changes. + is_breaking_change: Callable that compares two models (new, old) and determines + whether there is a breaking change between them. + kwargs: Additional arguments to pass to is_breaking_change. Returns: The change category or None if the category can't be determined automatically. 
@@ -39,20 +47,24 @@ def categorize_change( if type(new_model) != type(old_model): return default_category - if new.fingerprint.data_hash == old.fingerprint.data_hash: - if new.fingerprint.metadata_hash == old.fingerprint.metadata_hash: - raise SQLMeshError( - f"{new} is unmodified or indirectly modified and should not be categorized" - ) + if new.fingerprint == old.fingerprint: + raise SQLMeshError( + f"{new} is unmodified or indirectly modified and should not be categorized" + ) + + if not new.is_directly_modified(old): if new.fingerprint.parent_data_hash == old.fingerprint.parent_data_hash: return SnapshotChangeCategory.NON_BREAKING return None - is_breaking_change = new_model.is_breaking_change(old_model) - if is_breaking_change is None: + breaking_change = ( + is_breaking_change(new_model, old_model, **kwargs) + if is_breaking_change + else new_model.is_breaking_change(old_model) + ) + if breaking_change is None: return default_category + return ( - SnapshotChangeCategory.BREAKING - if is_breaking_change - else SnapshotChangeCategory.NON_BREAKING + SnapshotChangeCategory.BREAKING if breaking_change else SnapshotChangeCategory.NON_BREAKING ) diff --git a/sqlmesh/core/snapshot/definition.py b/sqlmesh/core/snapshot/definition.py index 57143f1d7e..0c9635a7c2 100644 --- a/sqlmesh/core/snapshot/definition.py +++ b/sqlmesh/core/snapshot/definition.py @@ -5,27 +5,36 @@ from collections import defaultdict from datetime import datetime, timedelta from enum import IntEnum +import logging from functools import cached_property, lru_cache +from pathlib import Path from pydantic import Field from sqlglot import exp -from sqlglot.helper import seq_get from sqlglot.optimizer.normalize_identifiers import normalize_identifiers +from sqlmesh.core.config.common import ( + TableNamingConvention, + VirtualEnvironmentMode, + EnvironmentSuffixTarget, +) from sqlmesh.core import constants as c -from sqlmesh.core.audit import BUILT_IN_AUDITS, Audit, ModelAudit, StandaloneAudit +from 
sqlmesh.core.audit import StandaloneAudit +from sqlmesh.core.macros import call_macro from sqlmesh.core.model import Model, ModelKindMixin, ModelKindName, ViewKind, CustomKind from sqlmesh.core.model.definition import _Model from sqlmesh.core.node import IntervalUnit, NodeType -from sqlmesh.utils import sanitize_name +from sqlmesh.utils import sanitize_name, unique from sqlmesh.utils.dag import DAG from sqlmesh.utils.date import ( TimeLike, is_date, make_inclusive, + make_exclusive, make_inclusive_end, now, now_timestamp, + time_like_to_str, to_date, to_datetime, to_ds, @@ -34,23 +43,26 @@ validate_date_range, yesterday, ) -from sqlmesh.utils.errors import SQLMeshError -from sqlmesh.utils.hashing import hash_data +from sqlmesh.utils.errors import SQLMeshError, SignalEvalError +from sqlmesh.utils.metaprogramming import ( + format_evaluated_code_exception, + Executable, +) +from sqlmesh.utils.hashing import hash_data, md5 from sqlmesh.utils.pydantic import PydanticModel, field_validator -if sys.version_info >= (3, 9): - from typing import Annotated -else: - from typing_extensions import Annotated - if t.TYPE_CHECKING: from sqlglot.dialects.dialect import DialectType from sqlmesh.core.environment import EnvironmentNamingInfo + from sqlmesh.core.context import ExecutionContext Interval = t.Tuple[int, int] Intervals = t.List[Interval] -Node = Annotated[t.Union[Model, StandaloneAudit], Field(descriminator="source_type")] +Node = t.Annotated[t.Union[Model, StandaloneAudit], Field(discriminator="source_type")] + + +logger = logging.getLogger(__name__) class SnapshotChangeCategory(IntEnum): @@ -67,6 +79,7 @@ class SnapshotChangeCategory(IntEnum): BREAKING = 1 NON_BREAKING = 2 + # FORWARD_ONLY category is deprecated and is kept for backwards compatibility. 
FORWARD_ONLY = 3 INDIRECT_BREAKING = 4 INDIRECT_NON_BREAKING = 5 @@ -149,6 +162,11 @@ def __str__(self) -> str: return f"SnapshotId<{self.name}: {self.identifier}>" +class SnapshotIdBatch(PydanticModel, frozen=True): + snapshot_id: SnapshotId + batch_id: int + + class SnapshotNameVersion(PydanticModel, frozen=True): name: str version: str @@ -159,32 +177,91 @@ def name_version(self) -> SnapshotNameVersion: return self -class SnapshotIntervals(PydanticModel, frozen=True): +class SnapshotIntervals(PydanticModel): name: str - identifier: str + identifier: t.Optional[str] version: str - intervals: Intervals - dev_intervals: Intervals + dev_version: t.Optional[str] + intervals: Intervals = [] + dev_intervals: Intervals = [] + pending_restatement_intervals: Intervals = [] + last_altered_ts: t.Optional[int] = None + dev_last_altered_ts: t.Optional[int] = None @property - def snapshot_id(self) -> SnapshotId: + def snapshot_id(self) -> t.Optional[SnapshotId]: + if not self.identifier: + return None return SnapshotId(name=self.name, identifier=self.identifier) @property def name_version(self) -> SnapshotNameVersion: return SnapshotNameVersion(name=self.name, version=self.version) + def add_interval(self, start: int, end: int) -> None: + self._add_interval(start, end, "intervals") + + def add_dev_interval(self, start: int, end: int) -> None: + self._add_interval(start, end, "dev_intervals") + + def add_pending_restatement_interval(self, start: int, end: int) -> None: + self._add_interval(start, end, "pending_restatement_intervals") + + def update_last_altered_ts(self, last_altered_ts: t.Optional[int]) -> None: + self._update_last_altered_ts(last_altered_ts, "last_altered_ts") + + def update_dev_last_altered_ts(self, last_altered_ts: t.Optional[int]) -> None: + self._update_last_altered_ts(last_altered_ts, "dev_last_altered_ts") + + def remove_interval(self, start: int, end: int) -> None: + self._remove_interval(start, end, "intervals") + + def remove_dev_interval(self, start: 
int, end: int) -> None: + self._remove_interval(start, end, "dev_intervals") + + def remove_pending_restatement_interval(self, start: int, end: int) -> None: + self._remove_interval(start, end, "pending_restatement_intervals") + + def is_empty(self) -> bool: + return ( + not self.intervals and not self.dev_intervals and not self.pending_restatement_intervals + ) + + def _add_interval(self, start: int, end: int, interval_attr: str) -> None: + target_intervals = getattr(self, interval_attr) + target_intervals = merge_intervals([*target_intervals, (start, end)]) + setattr(self, interval_attr, target_intervals) + + def _update_last_altered_ts( + self, last_altered_ts: t.Optional[int], last_altered_attr: str + ) -> None: + if last_altered_ts: + existing_last_altered_ts = getattr(self, last_altered_attr) + setattr(self, last_altered_attr, max(existing_last_altered_ts or 0, last_altered_ts)) + + def _remove_interval(self, start: int, end: int, interval_attr: str) -> None: + target_intervals = getattr(self, interval_attr) + target_intervals = remove_interval(target_intervals, start, end) + setattr(self, interval_attr, target_intervals) + class SnapshotDataVersion(PydanticModel, frozen=True): fingerprint: SnapshotFingerprint version: str - temp_version: t.Optional[str] = None + dev_version_: t.Optional[str] = Field(default=None, alias="dev_version") change_category: t.Optional[SnapshotChangeCategory] = None physical_schema_: t.Optional[str] = Field(default=None, alias="physical_schema") + dev_table_suffix: str + table_naming_convention: TableNamingConvention = Field(default=TableNamingConvention.default) + virtual_environment_mode: VirtualEnvironmentMode = Field(default=VirtualEnvironmentMode.default) def snapshot_id(self, name: str) -> SnapshotId: return SnapshotId(name=name, identifier=self.fingerprint.to_identifier()) + @property + def dev_version(self) -> str: + return self.dev_version_ or self.fingerprint.to_version() + @property def physical_schema(self) -> str: # The 
physical schema here is optional to maintain backwards compatibility with @@ -217,13 +294,26 @@ def table_for_environment( return exp.table_( self.table_name_for_environment(environment_naming_info, dialect=dialect), db=self.schema_for_environment(environment_naming_info, dialect=dialect), - catalog=self.catalog_for_environment(environment_naming_info), + catalog=self.catalog_for_environment(environment_naming_info, dialect=dialect), ) def catalog_for_environment( - self, environment_naming_info: EnvironmentNamingInfo + self, environment_naming_info: EnvironmentNamingInfo, dialect: DialectType = None ) -> t.Optional[str]: - return environment_naming_info.catalog_name_override or self.catalog + catalog_name: t.Optional[str] = None + if environment_naming_info.is_dev and environment_naming_info.suffix_target.is_catalog: + catalog_name = f"{self.catalog}__{environment_naming_info.name}" + elif environment_naming_info.catalog_name_override: + catalog_name = environment_naming_info.catalog_name_override + + if catalog_name: + return ( + normalize_identifiers(catalog_name, dialect=dialect).name + if environment_naming_info.normalize_name + else catalog_name + ) + + return self.catalog def schema_for_environment( self, environment_naming_info: EnvironmentNamingInfo, dialect: DialectType = None @@ -237,10 +327,7 @@ def schema_for_environment( if normalize: schema = normalize_identifiers(schema, dialect=dialect).name - if ( - environment_naming_info.name.lower() != c.PROD - and environment_naming_info.suffix_target.is_schema - ): + if environment_naming_info.is_dev and environment_naming_info.suffix_target.is_schema: env_name = environment_naming_info.name if normalize: env_name = normalize_identifiers(env_name, dialect=dialect).name @@ -253,10 +340,7 @@ def table_name_for_environment( self, environment_naming_info: EnvironmentNamingInfo, dialect: DialectType = None ) -> str: table = self.table - if ( - environment_naming_info.name.lower() != c.PROD - and 
environment_naming_info.suffix_target.is_table - ): + if environment_naming_info.is_dev and environment_naming_info.suffix_target.is_table: env_name = environment_naming_info.name if environment_naming_info.normalize_name: env_name = normalize_identifiers(env_name, dialect=dialect).name @@ -268,19 +352,22 @@ def table_name_for_environment( class SnapshotInfoMixin(ModelKindMixin): name: str - temp_version: t.Optional[str] + dev_version_: t.Optional[str] change_category: t.Optional[SnapshotChangeCategory] fingerprint: SnapshotFingerprint previous_versions: t.Tuple[SnapshotDataVersion, ...] # Added to support Migration # 34 (default catalog) # This can be removed from this model once Pydantic 1 support is dropped (must remain in `Snapshot` though) base_table_name_override: t.Optional[str] + dev_table_suffix: str + table_naming_convention: TableNamingConvention + forward_only: bool - @property + @cached_property def identifier(self) -> str: return self.fingerprint.to_identifier() - @property + @cached_property def snapshot_id(self) -> SnapshotId: return SnapshotId(name=self.name, identifier=self.identifier) @@ -300,6 +387,10 @@ def previous_version(self) -> t.Optional[SnapshotDataVersion]: return self.previous_versions[-1] return None + @property + def dev_version(self) -> str: + return self.dev_version_ or self.fingerprint.to_version() + @property def physical_schema(self) -> str: raise NotImplementedError @@ -316,9 +407,13 @@ def is_new_version(self) -> bool: def fully_qualified_table(self) -> t.Optional[exp.Table]: raise NotImplementedError + @property + def virtual_environment_mode(self) -> VirtualEnvironmentMode: + raise NotImplementedError + @property def is_forward_only(self) -> bool: - return self.change_category == SnapshotChangeCategory.FORWARD_ONLY + return self.forward_only or self.change_category == SnapshotChangeCategory.FORWARD_ONLY @property def is_metadata(self) -> bool: @@ -329,9 +424,18 @@ def is_indirect_non_breaking(self) -> bool: return 
self.change_category == SnapshotChangeCategory.INDIRECT_NON_BREAKING @property - def reuses_previous_version(self) -> bool: - return self.change_category in ( - SnapshotChangeCategory.FORWARD_ONLY, + def is_no_rebuild(self) -> bool: + """Returns true if this snapshot doesn't require a rebuild in production.""" + return self.forward_only or self.change_category in ( + SnapshotChangeCategory.FORWARD_ONLY, # Backwards compatibility + SnapshotChangeCategory.METADATA, + SnapshotChangeCategory.INDIRECT_NON_BREAKING, + ) + + @property + def is_no_preview(self) -> bool: + """Returns true if this snapshot doesn't require a preview in development.""" + return self.forward_only and self.change_category in ( SnapshotChangeCategory.METADATA, SnapshotChangeCategory.INDIRECT_NON_BREAKING, ) @@ -357,10 +461,6 @@ def display_name( def data_hash_matches(self, other: t.Optional[SnapshotInfoMixin | SnapshotDataVersion]) -> bool: return other is not None and self.fingerprint.data_hash == other.fingerprint.data_hash - def temp_version_get_or_generate(self) -> str: - """Helper method to get the temp version or generate it from the fingerprint.""" - return self.temp_version or self.fingerprint.to_version() - def _table_name(self, version: str, is_deployable: bool) -> str: """Full table name pointing to the materialized location of the snapshot. 
@@ -371,9 +471,13 @@ def _table_name(self, version: str, is_deployable: bool) -> str: if self.is_external: return self.name + if is_deployable and self.virtual_environment_mode.is_dev_only: + # Use the model name as is if the target is deployable and the virtual environment mode is set to dev-only + return self.name + is_dev_table = not is_deployable if is_dev_table: - version = self.temp_version_get_or_generate() + version = self.dev_version if self.fully_qualified_table is None: raise SQLMeshError( @@ -387,12 +491,14 @@ def _table_name(self, version: str, is_deployable: bool) -> str: fqt = self.fully_qualified_table.copy() fqt.set("catalog", None) base_table_name = fqt.sql() + return table_name( self.physical_schema, base_table_name, version, - is_dev_table=is_dev_table, catalog=self.fully_qualified_table.catalog, + suffix=self.dev_table_suffix if is_dev_table else None, + naming_convention=self.table_naming_convention, ) @property @@ -412,7 +518,7 @@ class SnapshotTableInfo(PydanticModel, SnapshotInfoMixin, frozen=True): name: str fingerprint: SnapshotFingerprint version: str - temp_version: t.Optional[str] = None + dev_version_: t.Optional[str] = Field(default=None, alias="dev_version") physical_schema_: str = Field(alias="physical_schema") parents: t.Tuple[SnapshotId, ...] previous_versions: t.Tuple[SnapshotDataVersion, ...] 
= () @@ -422,12 +528,24 @@ class SnapshotTableInfo(PydanticModel, SnapshotInfoMixin, frozen=True): # Added to support Migration # 34 (default catalog) # This can be removed from this model once Pydantic 1 support is dropped (must remain in `Snapshot` though) base_table_name_override: t.Optional[str] = None - custom_materialization: t.Optional[str] = None + dev_table_suffix: str + model_gateway: t.Optional[str] = None + forward_only: bool = False + table_naming_convention: TableNamingConvention = TableNamingConvention.default + virtual_environment_mode_: VirtualEnvironmentMode = Field( + default=VirtualEnvironmentMode.default, alias="virtual_environment_mode" + ) def __lt__(self, other: SnapshotTableInfo) -> bool: return self.name < other.name + def __eq__(self, other: t.Any) -> bool: + return isinstance(other, SnapshotTableInfo) and self.fingerprint == other.fingerprint + + def __hash__(self) -> int: + return hash((self.__class__, self.name, self.fingerprint)) + def table_name(self, is_deployable: bool = True) -> str: """Full table name pointing to the materialized location of the snapshot. 
@@ -449,14 +567,21 @@ def table_info(self) -> SnapshotTableInfo: """Helper method to return self.""" return self + @property + def virtual_environment_mode(self) -> VirtualEnvironmentMode: + return self.virtual_environment_mode_ + @property def data_version(self) -> SnapshotDataVersion: return SnapshotDataVersion( fingerprint=self.fingerprint, version=self.version, - temp_version=self.temp_version, + dev_version=self.dev_version, change_category=self.change_category, physical_schema=self.physical_schema, + dev_table_suffix=self.dev_table_suffix, + table_naming_convention=self.table_naming_convention, + virtual_environment_mode=self.virtual_environment_mode, ) @property @@ -477,6 +602,67 @@ def name_version(self) -> SnapshotNameVersion: """Returns the name and version of the snapshot.""" return SnapshotNameVersion(name=self.name, version=self.version) + @property + def id_and_version(self) -> SnapshotIdAndVersion: + return SnapshotIdAndVersion( + name=self.name, + kind_name=self.kind_name, + identifier=self.identifier, + version=self.version, + dev_version=self.dev_version, + fingerprint=self.fingerprint, + ) + + +class SnapshotIdAndVersion(PydanticModel, ModelKindMixin): + """A stripped down version of a snapshot that is used in situations where we want to fetch the main fields of the snapshots table + without the overhead of parsing the full snapshot payload and fetching intervals. 
+ """ + + name: str + version: str + kind_name_: t.Optional[ModelKindName] = Field(default=None, alias="kind_name") + dev_version_: t.Optional[str] = Field(alias="dev_version") + identifier: str + fingerprint_: t.Union[str, SnapshotFingerprint] = Field(alias="fingerprint") + + @property + def snapshot_id(self) -> SnapshotId: + return SnapshotId(name=self.name, identifier=self.identifier) + + @property + def id_and_version(self) -> SnapshotIdAndVersion: + return self + + @property + def name_version(self) -> SnapshotNameVersion: + return SnapshotNameVersion(name=self.name, version=self.version) + + @property + def fingerprint(self) -> SnapshotFingerprint: + value = self.fingerprint_ + if isinstance(value, str): + self.fingerprint_ = value = SnapshotFingerprint.parse_raw(value) + return value + + @property + def dev_version(self) -> str: + return self.dev_version_ or self.fingerprint.to_version() + + @property + def model_kind_name(self) -> t.Optional[ModelKindName]: + return self.kind_name_ + + def display_name( + self, + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + dialect: DialectType = None, + ) -> str: + return model_display_name( + self.name, environment_naming_info, default_catalog, dialect=dialect + ) + class Snapshot(PydanticModel, SnapshotInfoMixin): """A snapshot represents a node at a certain point in time. @@ -496,7 +682,6 @@ class Snapshot(PydanticModel, SnapshotInfoMixin): fingerprint: A unique hash of the node definition so that nodes can be reused across environments. node: Node object that the snapshot encapsulates. parents: The list of parent snapshots (upstream dependencies). - audits: The list of generic audits used by the node. intervals: List of [start, end) intervals showing which time ranges a snapshot has data for. dev_intervals: List of [start, end) intervals showing development intervals (forward-only). created_ts: Epoch millis timestamp when a snapshot was first created. 
@@ -514,6 +699,8 @@ class Snapshot(PydanticModel, SnapshotInfoMixin): Applicable for forward-only snapshots only. migrated: Whether or not this snapshot has been created as a result of migration. unrestorable: Whether or not this snapshot can be used to revert its model to a previous version. + next_auto_restatement_ts: The timestamp which indicates when is the next time this snapshot should be restated. + table_naming_convention: Convention to follow when generating the physical table name """ name: str @@ -521,15 +708,15 @@ class Snapshot(PydanticModel, SnapshotInfoMixin): physical_schema_: t.Optional[str] = Field(default=None, alias="physical_schema") node: Node parents: t.Tuple[SnapshotId, ...] - audits: t.Tuple[ModelAudit, ...] = tuple() intervals: Intervals = [] dev_intervals: Intervals = [] + pending_restatement_intervals: Intervals = [] created_ts: int updated_ts: int ttl: str previous_versions: t.Tuple[SnapshotDataVersion, ...] = () version: t.Optional[str] = None - temp_version: t.Optional[str] = None + dev_version_: t.Optional[str] = Field(default=None, alias="dev_version") change_category: t.Optional[SnapshotChangeCategory] = None unpaused_ts: t.Optional[int] = None effective_from: t.Optional[TimeLike] = None @@ -537,6 +724,14 @@ class Snapshot(PydanticModel, SnapshotInfoMixin): unrestorable: bool = False # Added to support Migration # 34 (default catalog) base_table_name_override: t.Optional[str] = None + next_auto_restatement_ts: t.Optional[int] = None + dev_table_suffix: str = "dev" + table_naming_convention: TableNamingConvention = TableNamingConvention.default + forward_only: bool = False + # Physical table last modified timestamp, not to be confused with the "updated_ts" field + # which is for the snapshot record itself + last_altered_ts: t.Optional[int] = None + dev_last_altered_ts: t.Optional[int] = None @field_validator("ttl") @classmethod @@ -552,7 +747,6 @@ def _time_delta_must_be_positive(cls, v: str) -> str: def 
hydrate_with_intervals_by_version( snapshots: t.Iterable[Snapshot], intervals: t.Iterable[SnapshotIntervals], - is_dev: bool = False, ) -> t.List[Snapshot]: """Hydrates target snapshots with given intervals. @@ -561,7 +755,6 @@ def hydrate_with_intervals_by_version( Args: snapshots: Target snapshots. intervals: Target snapshot intervals. - is_dev: If in development mode ignores same version intervals for paused forward-only snapshots. Returns: List of target snapshots with hydrated intervals. @@ -577,32 +770,7 @@ def hydrate_with_intervals_by_version( ) for interval in snapshot_intervals: snapshot.merge_intervals(interval) - result.append(snapshot) - return result - - @staticmethod - def hydrate_with_intervals_by_identifier( - snapshots: t.Iterable[Snapshot], - intervals: t.Iterable[SnapshotIntervals], - ) -> t.List[Snapshot]: - """Hydrates target snapshots with given intervals. - - This will match snapshots with intervals by name and identifier rather than versions. - - Args: - snapshots: Target snapshots. - intervals: Target snapshot intervals. - - Returns: - List of target snapshots with hydrated intervals. - """ - intervals_by_snapshot_id = {i.snapshot_id: i for i in intervals} - - result = [] - for snapshot in snapshots: - if snapshot.snapshot_id in intervals_by_snapshot_id: - snapshot.merge_intervals(intervals_by_snapshot_id[snapshot.snapshot_id]) result.append(snapshot) return result @@ -615,8 +783,8 @@ def from_node( nodes: t.Dict[str, Node], ttl: str = c.DEFAULT_SNAPSHOT_TTL, version: t.Optional[str] = None, - audits: t.Optional[t.Dict[str, ModelAudit]] = None, cache: t.Optional[t.Dict[str, SnapshotFingerprint]] = None, + table_naming_convention: TableNamingConvention = TableNamingConvention.default, ) -> Snapshot: """Creates a new snapshot for a node. @@ -626,23 +794,19 @@ def from_node( If no dictionary is passed in the fingerprint will not be dependent on a node's parents. 
ttl: A TTL to determine how long orphaned (snapshots that are not promoted anywhere) should live. version: The version that a snapshot is associated with. Usually set during the planning phase. - audits: Available audits by name. cache: Cache of node name to fingerprints. + table_naming_convention: Convention to follow when generating the physical table name Returns: The newly created snapshot. """ created_ts = now_timestamp() - kwargs = {} - if node.is_model: - kwargs["audits"] = tuple(t.cast(_Model, node).referenced_audits(audits or {})) return cls( name=node.fqn, fingerprint=fingerprint_from_node( node, nodes=nodes, - audits=audits, cache=cache, ), node=node, @@ -652,7 +816,6 @@ def from_node( identifier=fingerprint_from_node( parent_node, nodes=nodes, - audits=audits, cache=cache, ).to_identifier(), ) @@ -664,7 +827,7 @@ def from_node( updated_ts=created_ts, ttl=ttl, version=version, - **kwargs, + table_naming_convention=table_naming_convention, ) def __eq__(self, other: t.Any) -> bool: @@ -693,7 +856,8 @@ def add_interval(self, start: TimeLike, end: TimeLike, is_dev: bool = False) -> f"Attempted to add an Invalid interval ({start}, {end}) to snapshot {self.snapshot_id}" ) - start_ts, end_ts = self.inclusive_exclusive(start, end, strict=False) + start_ts, end_ts = self.inclusive_exclusive(start, end, strict=False, expand=False) + if start_ts >= end_ts: # Skipping partial interval. return @@ -739,17 +903,42 @@ def get_removal_interval( When previewing, we are not actually restating a model, but removing an interval to trigger a run. 
""" - end = execution_time or now() if self.depends_on_past else end + end = execution_time or now_timestamp() if self.depends_on_past else end + removal_interval = self.inclusive_exclusive(start, end, strict) + if not is_preview and self.full_history_restatement_only and self.intervals: - start = self.intervals[0][0] - return self.inclusive_exclusive(start, end, strict) + expanded_removal_interval = self.inclusive_exclusive(self.intervals[0][0], end, strict) + requested_start, requested_end = removal_interval + expanded_start, expanded_end = expanded_removal_interval + + # only warn if the requested removal interval was a subset of the actual model intervals and was automatically expanded + # if the requested interval was the same or wider than the actual model intervals, no need to warn + if ( + requested_start > expanded_start or requested_end < expanded_end + ) and self.is_incremental: + from sqlmesh.core.console import get_console + + get_console().log_warning( + f"Model '{self.model.name}' is '{self.model_kind_name}' which does not support partial restatement.\n" + f"Expanding the requested restatement intervals from [{to_ts(requested_start)} - {to_ts(requested_end)}] " + f"to [{to_ts(expanded_start)} - {to_ts(expanded_end)}] in order to fully restate the model." + ) + + removal_interval = expanded_removal_interval + + return removal_interval + + @property + def allow_partials(self) -> bool: + return self.is_model and self.model.allow_partials def inclusive_exclusive( self, start: TimeLike, end: TimeLike, strict: bool = True, - allow_partial: bool = False, + allow_partial: t.Optional[bool] = None, + expand: bool = True, ) -> Interval: """Transform the inclusive start and end into a [start, end) pair. @@ -758,29 +947,19 @@ def inclusive_exclusive( end: The end date/time of the interval (inclusive) strict: Whether to fail when the inclusive start is the same as the exclusive end. allow_partial: Whether the interval can be partial or not. 
+ expand: Whether or not partial intervals are expanded outwards. Returns: A [start, end) pair. """ - interval_unit = self.node.interval_unit - start_ts = to_timestamp(interval_unit.cron_floor(start)) - if start_ts < to_timestamp(start) and not self.model.allow_partials: - start_ts = to_timestamp(interval_unit.cron_next(start_ts)) - - if is_date(end): - end = to_datetime(end) + timedelta(days=1) - end_ts = to_timestamp(interval_unit.cron_floor(end) if not allow_partial else end) - if end_ts < start_ts and to_timestamp(end) > to_timestamp(start) and not strict: - # This can happen when the interval unit is coarser than the size of the input interval. - # For example, if the interval unit is monthly, but the input interval is only 1 hour long. - return (start_ts, end_ts) - - if (strict and start_ts >= end_ts) or (start_ts > end_ts): - raise ValueError( - f"`end` ({to_datetime(end_ts)}) must be greater than `start` ({to_datetime(start_ts)})" - ) - - return (start_ts, end_ts) + return inclusive_exclusive( + start, + end, + self.node.interval_unit, + strict=strict, + allow_partial=self.allow_partials if allow_partial is None else allow_partial, + expand=expand, + ) def merge_intervals(self, other: t.Union[Snapshot, SnapshotIntervals]) -> None: """Inherits intervals from the target snapshot. @@ -790,25 +969,32 @@ def merge_intervals(self, other: t.Union[Snapshot, SnapshotIntervals]) -> None: """ effective_from_ts = self.normalized_effective_from_ts or 0 apply_effective_from = effective_from_ts > 0 and self.identifier != other.identifier - for start, end in other.intervals: # If the effective_from is set, then intervals that come after it must come from - # the current snapshost. + # the current snapshots. 
if apply_effective_from and start < effective_from_ts: end = min(end, effective_from_ts) if not apply_effective_from or end <= effective_from_ts: self.add_interval(start, end) - previous_ids = {s.snapshot_id(self.name) for s in self.previous_versions} - if self.identifier == other.identifier or ( - # Indirect Non-Breaking snapshots share the dev table with its previous version. - # The same applies to migrated snapshots. - (self.is_indirect_non_breaking or self.is_metadata or self.migrated) - and other.snapshot_id in previous_ids - ): + if other.last_altered_ts: + self.last_altered_ts = max(self.last_altered_ts or 0, other.last_altered_ts) + + if self.dev_version == other.dev_version: + # Merge dev intervals if the dev versions match which would mean + # that this and the other snapshot are pointing to the same dev table. for start, end in other.dev_intervals: self.add_interval(start, end, is_dev=True) + if other.dev_last_altered_ts: + self.dev_last_altered_ts = max( + self.dev_last_altered_ts or 0, other.dev_last_altered_ts + ) + + self.pending_restatement_intervals = merge_intervals( + [*self.pending_restatement_intervals, *other.pending_restatement_intervals] + ) + @property def evaluatable(self) -> bool: """Whether or not a snapshot should be evaluated and have intervals.""" @@ -849,14 +1035,13 @@ def missing_intervals( return [] if self.node.start and to_datetime(start) < to_datetime(self.node.start): start = self.node.start - if self.node.end and make_inclusive_end(end) > make_inclusive_end(self.node.end): - end = self.node.end # If the amount of time being checked is less than the size of a single interval then we # know that there can't being missing intervals within that range and return validate_date_range(start, end) + if ( not is_date(end) - and not (self.is_model and self.model.allow_partials) + and not self.allow_partials and to_timestamp(end) - to_timestamp(start) < self.node.interval_unit.milliseconds ): return [] @@ -869,74 +1054,119 @@ def 
missing_intervals( if not self.evaluatable or (self.is_seed and intervals): return [] - allow_partials = not end_bounded and self.is_model and self.model.allow_partials - start_ts, end_ts = ( - to_timestamp(ts) - for ts in self.inclusive_exclusive( - start, - end, - strict=False, - allow_partial=allow_partials, - ) - ) + start_ts, end_ts = (to_timestamp(ts) for ts in self.inclusive_exclusive(start, end)) interval_unit = self.node.interval_unit - - execution_time = execution_time or now() + execution_time_ts = to_timestamp(execution_time) if execution_time else now_timestamp() + upper_bound_ts = ( + execution_time_ts + if ignore_cron + else to_timestamp(self.node.cron_floor(execution_time_ts)) + ) if end_bounded: - execution_time = min(to_timestamp(execution_time), end_ts) + upper_bound_ts = min(upper_bound_ts, end_ts) + if not self.allow_partials: + upper_bound_ts = to_timestamp(interval_unit.cron_floor(upper_bound_ts)) - if not allow_partials: - upper_bound_ts = to_timestamp( - self.node.cron_floor(execution_time) if not ignore_cron else execution_time - ) - end_ts = min(end_ts, to_timestamp(interval_unit.cron_floor(upper_bound_ts))) - else: - upper_bound_ts = to_timestamp(execution_time) - end_ts = min(end_ts, upper_bound_ts) + end_ts = min(end_ts, upper_bound_ts) + + lookback = 0 + model_end_ts: t.Optional[int] = None - lookback = self.model.lookback if self.is_model else 0 + if self.is_model: + lookback = self.model.lookback + model_end_ts = to_timestamp(make_exclusive(self.model.end)) if self.model.end else None return compute_missing_intervals( - interval_unit, tuple(intervals), start_ts, end_ts, upper_bound_ts, lookback + interval_unit, + tuple(intervals), + start_ts, + end_ts, + lookback, + model_end_ts, ) - def categorize_as(self, category: SnapshotChangeCategory) -> None: + def check_ready_intervals( + self, + intervals: Intervals, + context: ExecutionContext, + ) -> Intervals: + """Returns a list of intervals that are considered ready by the provided 
signal. + + Note that this will handle gaps in the provided intervals. The returned intervals + may introduce new gaps. + """ + signals = self.is_model and self.model.render_signal_calls() + if not signals: + return intervals + + for signal_name, kwargs in signals.signals_to_kwargs.items(): + try: + intervals = check_ready_intervals( + signals.prepared_python_env[signal_name], + intervals, + context, + python_env=signals.python_env, + dialect=self.model.dialect, + path=self.model._path, + snapshot=self, + kwargs=kwargs, + ) + except SQLMeshError as e: + raise SignalEvalError( + f"{e} '{signal_name}' for '{self.model.name}' at {self.model._path}" + ) + return intervals + + def categorize_as(self, category: SnapshotChangeCategory, forward_only: bool = False) -> None: """Assigns the given category to this snapshot. Args: category: The change category to assign to this snapshot. + forward_only: Whether or not this snapshot is applied going forward in production. """ - self.temp_version = None - reuse_previous_version = category in ( - SnapshotChangeCategory.FORWARD_ONLY, + assert category != SnapshotChangeCategory.FORWARD_ONLY, ( + "FORWARD_ONLY change category is deprecated" + ) + + self.dev_version_ = self.fingerprint.to_version() + is_no_rebuild = forward_only or category in ( SnapshotChangeCategory.INDIRECT_NON_BREAKING, SnapshotChangeCategory.METADATA, ) - if reuse_previous_version and self.previous_version: + if self.is_model and not self.virtual_environment_mode.is_full: + # Hardcode the version if the virtual environment is not fully enabled. + self.version = "novde" + elif self.is_model and self.model.physical_version: + # If the model has a pinned version then use that. 
+ self.version = self.model.physical_version + elif is_no_rebuild and self.previous_version: + self.version = self.previous_version.data_version.version + elif self.is_model and self.model.forward_only and not self.previous_version: + # If this is a new model then use a deterministic version, independent of the fingerprint. + self.version = hash_data([self.name, *self.model.kind.data_hash_values]) + else: + self.version = self.fingerprint.to_version() + + if is_no_rebuild and self.previous_version: previous_version = self.previous_version - self.version = previous_version.data_version.version self.physical_schema_ = previous_version.physical_schema - if category.is_indirect_non_breaking or category.is_metadata: + self.table_naming_convention = previous_version.table_naming_convention + if self.is_materialized and (category.is_indirect_non_breaking or category.is_metadata): # Reuse the dev table for indirect non-breaking changes. - self.temp_version = ( - previous_version.data_version.temp_version + self.dev_version_ = ( + previous_version.data_version.dev_version or previous_version.fingerprint.to_version() ) - else: - self.version = self.fingerprint.to_version() + self.dev_table_suffix = previous_version.data_version.dev_table_suffix self.change_category = category + self.forward_only = forward_only - def set_unpaused_ts(self, unpaused_dt: t.Optional[TimeLike]) -> None: - """Sets the timestamp for when this snapshot was unpaused. - - Args: - unpaused_dt: The datetime object of when this snapshot was unpaused. - """ - self.unpaused_ts = ( - to_timestamp(self.node.interval_unit.cron_floor(unpaused_dt)) if unpaused_dt else None - ) + @property + def categorized(self) -> bool: + """Whether the snapshot has been categorized.""" + return self.change_category is not None and self.version is not None def table_name(self, is_deployable: bool = True) -> str: """Full table name pointing to the materialized location of the snapshot. 
@@ -1009,6 +1239,102 @@ def needs_destructive_check( and self.name not in allow_destructive_snapshots ) + def needs_additive_check( + self, + allow_additive_snapshots: t.Set[str], + ) -> bool: + return ( + self.is_model + and not self.model.on_additive_change.is_allow + and self.name not in allow_additive_snapshots + ) + + def get_next_auto_restatement_interval(self, execution_time: TimeLike) -> t.Optional[Interval]: + """Returns the next auto restatement interval for the snapshot. + + Args: + execution_time: The execution time to use for the restatement. + + Returns: + The interval that needs to be restated or None if no restatement is needed. + """ + if ( + not self.is_model + or not self.intervals + or not self.model.auto_restatement_cron + or self.model.disable_restatement + ): + return None + + execution_time_ts = to_timestamp(execution_time) + next_auto_restatement_ts = self.next_auto_restatement_ts or to_timestamp( + self.model.auto_restatement_croniter(self.created_ts).get_next(estimate=False) + ) + if execution_time_ts < next_auto_restatement_ts: + return None + + num_intervals_to_restate = self.model.auto_restatement_intervals + if num_intervals_to_restate is None: + return (self.intervals[0][0], self.intervals[-1][1]) + + auto_restatement_end_ts = to_timestamp( + self.node.interval_unit.cron_floor(execution_time_ts) + ) + auto_restatement_start_ts = ( + auto_restatement_end_ts + - num_intervals_to_restate * self.node.interval_unit.milliseconds + ) + return (auto_restatement_start_ts, auto_restatement_end_ts) + + def update_next_auto_restatement_ts(self, execution_time: TimeLike) -> t.Optional[int]: + """Updates the next auto restatement timestamp. + + Args: + execution_time: The execution time to use for the restatement. + + Returns: + The next auto restatement timestamp or None if not applicable. 
+ """ + if ( + not self.is_model + or not self.model.auto_restatement_cron + or self.model.disable_restatement + ): + self.next_auto_restatement_ts = None + else: + self.next_auto_restatement_ts = to_timestamp( + self.model.auto_restatement_croniter(execution_time).get_next(estimate=False) + ) + return self.next_auto_restatement_ts + + def apply_pending_restatement_intervals(self) -> None: + """Applies the pending restatement intervals to the snapshot's intervals.""" + if not self.is_model or self.model.disable_restatement: + return + for pending_restatement_interval in self.pending_restatement_intervals: + logger.info( + "Applying the auto restated interval (%s, %s) to snapshot %s", + time_like_to_str(pending_restatement_interval[0]), + time_like_to_str(pending_restatement_interval[1]), + self.snapshot_id, + ) + self.intervals = remove_interval(self.intervals, *pending_restatement_interval) + + def is_directly_modified(self, other: Snapshot) -> bool: + """Returns whether or not this snapshot is directly modified in relation to the other snapshot.""" + return self.node.is_data_change(other.node) + + def is_indirectly_modified(self, other: Snapshot) -> bool: + """Returns whether or not this snapshot is indirectly modified in relation to the other snapshot.""" + return ( + self.fingerprint.parent_data_hash != other.fingerprint.parent_data_hash + and not self.node.is_data_change(other.node) + ) + + def is_metadata_updated(self, other: Snapshot) -> bool: + """Returns whether or not this snapshot contains metadata changes in relation to the other snapshot.""" + return self.fingerprint.metadata_hash != other.fingerprint.metadata_hash + @property def physical_schema(self) -> str: if self.physical_schema_ is not None: @@ -1031,13 +1357,18 @@ def table_info(self) -> SnapshotTableInfo: name=self.name, fingerprint=self.fingerprint, version=self.version, - temp_version=self.temp_version, + dev_version=self.dev_version, parents=self.parents, 
previous_versions=self.previous_versions, change_category=self.change_category, kind_name=self.model_kind_name, node_type=self.node_type, custom_materialization=custom_materialization, + dev_table_suffix=self.dev_table_suffix, + model_gateway=self.model_gateway, + table_naming_convention=self.table_naming_convention, # type: ignore + forward_only=self.forward_only, + virtual_environment_mode=self.virtual_environment_mode, ) @property @@ -1046,9 +1377,12 @@ def data_version(self) -> SnapshotDataVersion: return SnapshotDataVersion( fingerprint=self.fingerprint, version=self.version, - temp_version=self.temp_version, + dev_version=self.dev_version, change_category=self.change_category, physical_schema=self.physical_schema, + dev_table_suffix=self.dev_table_suffix, + table_naming_convention=self.table_naming_convention, + virtual_environment_mode=self.virtual_environment_mode, ) @property @@ -1058,8 +1392,10 @@ def snapshot_intervals(self) -> SnapshotIntervals: name=self.name, identifier=self.identifier, version=self.version, + dev_version=self.dev_version, intervals=self.intervals.copy(), dev_intervals=self.dev_intervals.copy(), + pending_restatement_intervals=self.pending_restatement_intervals.copy(), ) @property @@ -1099,7 +1435,7 @@ def model_kind_name(self) -> t.Optional[ModelKindName]: def node_type(self) -> NodeType: if self.node.is_model: return NodeType.MODEL - elif self.node.is_audit: + if self.node.is_audit: return NodeType.AUDIT raise SQLMeshError(f"Snapshot {self.snapshot_id} has an unknown node type.") @@ -1116,6 +1452,10 @@ def model_or_none(self) -> t.Optional[Model]: return t.cast(Model, self.node) return None + @property + def model_gateway(self) -> t.Optional[str]: + return self.model.gateway if self.is_model else None + @property def audit(self) -> StandaloneAudit: if self.is_audit: @@ -1138,24 +1478,15 @@ def depends_on_self(self) -> bool: """Whether or not this models depends on self.""" return self.is_model and self.model.depends_on_self - 
@property - def audits_with_args(self) -> t.List[t.Tuple[Audit, t.Dict[str, exp.Expression]]]: - if self.is_model: - audits_by_name = {**BUILT_IN_AUDITS, **{a.name: a for a in self.audits}} - return [ - (audits_by_name[audit_name], audit_args) - for audit_name, audit_args in self.model.audits - ] - elif self.is_audit: - return [(self.audit, {})] - - return [] - @property def name_version(self) -> SnapshotNameVersion: """Returns the name and version of the snapshot.""" return SnapshotNameVersion(name=self.name, version=self.version) + @property + def id_and_version(self) -> SnapshotIdAndVersion: + return self.table_info.id_and_version + @property def disable_restatement(self) -> bool: """Is restatement disabled for the node""" @@ -1169,7 +1500,30 @@ def fully_qualified_table(self) -> t.Optional[exp.Table]: @property def expiration_ts(self) -> int: - return to_timestamp(self.ttl, relative_base=to_datetime(self.updated_ts)) + return to_timestamp( + self.ttl, + relative_base=to_datetime(self.updated_ts), + check_categorical_relative_expression=False, + ) + + @property + def supports_schema_migration_in_prod(self) -> bool: + """Returns whether or not this snapshot supports schema migration when deployed to production.""" + return self.is_paused and self.is_model and not self.is_symbolic and not self.is_seed + + @property + def requires_schema_migration_in_prod(self) -> bool: + """Returns whether or not this snapshot requires a schema migration when deployed to production.""" + return self.supports_schema_migration_in_prod and ( + (self.previous_version and self.previous_version.version == self.version) + or self.model.forward_only + or bool(self.model.physical_version) + or not self.virtual_environment_mode.is_full + ) + + @property + def ttl_ms(self) -> int: + return self.expiration_ts - self.updated_ts @property def custom_materialization(self) -> t.Optional[str]: @@ -1177,21 +1531,38 @@ def custom_materialization(self) -> t.Optional[str]: return t.cast(CustomKind, 
self.model.kind).materialization return None + @property + def virtual_environment_mode(self) -> VirtualEnvironmentMode: + return ( + self.model.virtual_environment_mode if self.is_model else VirtualEnvironmentMode.default + ) + def _ensure_categorized(self) -> None: if not self.change_category: raise SQLMeshError(f"Snapshot {self.snapshot_id} has not been categorized yet.") if not self.version: raise SQLMeshError(f"Snapshot {self.snapshot_id} has not been versioned yet.") + def __getstate__(self) -> t.Dict[t.Any, t.Any]: + state = super().__getstate__() + state["__dict__"] = state["__dict__"].copy() + # Don't store intervals. + state["__dict__"]["intervals"] = [] + state["__dict__"]["dev_intervals"] = [] + return state + class SnapshotTableCleanupTask(PydanticModel): snapshot: SnapshotTableInfo dev_table_only: bool -SnapshotIdLike = t.Union[SnapshotId, SnapshotTableInfo, Snapshot] +SnapshotIdLike = t.Union[SnapshotId, SnapshotIdAndVersion, SnapshotTableInfo, Snapshot] +SnapshotIdAndVersionLike = t.Union[SnapshotIdAndVersion, SnapshotTableInfo, Snapshot] SnapshotInfoLike = t.Union[SnapshotTableInfo, Snapshot] -SnapshotNameVersionLike = t.Union[SnapshotNameVersion, SnapshotTableInfo, Snapshot] +SnapshotNameVersionLike = t.Union[ + SnapshotNameVersion, SnapshotTableInfo, SnapshotIdAndVersion, Snapshot +] class DeployabilityIndex(PydanticModel, frozen=True): @@ -1214,7 +1585,7 @@ def _snapshot_ids_set_validator(cls, v: t.Any) -> t.Optional[t.FrozenSet[t.Tuple return frozenset( { ( - cls._snapshot_id_key(snapshot_id) + cls._snapshot_id_key(snapshot_id) # type: ignore if isinstance(snapshot_id, SnapshotId) else snapshot_id ) @@ -1238,8 +1609,10 @@ def is_deployable(self, snapshot: SnapshotIdLike) -> bool: ) def is_representative(self, snapshot: SnapshotIdLike) -> bool: - """Returns true if the output produced by the given snapshot in a development environment can be reused - in (deployed to) production, or if this snapshot already represents what is currently in 
production. + """Returns true if the deployable (non-dev) table of the given snapshot should be used for reading, table mapping, and + computing missing intervals. + + Note, that deployable snapshots are also representative, but the reverse is not always true. Unlike `is_deployable`, this variant also captures FORWARD_ONLY and INDIRECT_NON_BREAKING snapshots that are not deployable by their nature but are currently promoted in production. Therefore, it's safe to consider @@ -1287,57 +1660,81 @@ def none_deployable(cls) -> DeployabilityIndex: @classmethod def create( - cls, snapshots: t.Dict[SnapshotId, Snapshot] | t.Collection[Snapshot] + cls, + snapshots: t.Dict[SnapshotId, Snapshot] | t.Collection[Snapshot], + start: t.Optional[TimeLike] = None, # plan start + start_override_per_model: t.Optional[t.Dict[str, datetime]] = None, ) -> DeployabilityIndex: if not isinstance(snapshots, dict): snapshots = {s.snapshot_id: s for s in snapshots} - dag = snapshots_to_dag(snapshots.values()) - reversed_dag = dag.reversed.graph deployability_mapping: t.Dict[SnapshotId, bool] = {} + children_deployability_mapping: t.Dict[SnapshotId, bool] = {} representative_shared_version_ids: t.Set[SnapshotId] = set() + start_override_per_model = start_override_per_model or {} + + start_date_cache: t.Optional[t.Dict[str, datetime]] = {} - def _visit(node: SnapshotId, deployable: bool = True) -> None: - if deployability_mapping.get(node) in (False, deployable) and ( - deployable or node not in representative_shared_version_ids - ): - return - - if deployable and node in snapshots: - snapshot = snapshots[node] - # Capture uncategorized snapshot which represents a forward-only model. 
- is_uncategorized_forward_only_model = ( - snapshot.change_category is None - and snapshot.previous_versions - and snapshot.is_model - and snapshot.model.forward_only + dag = snapshots_to_dag(snapshots.values()) + for node in dag: + if node not in snapshots: + continue + snapshot = snapshots[node] + + if not snapshot.virtual_environment_mode.is_full: + # If the virtual environment is not fully enabled, then the snapshot can never be deployable + this_deployable = False + else: + # Make sure that the node is deployable according to all its parents + this_deployable = all( + children_deployability_mapping[p_id] + for p_id in snapshots[node].parents + if p_id in children_deployability_mapping + ) + + if this_deployable: + is_forward_only_model = ( + snapshot.is_model and snapshot.model.forward_only and not snapshot.is_metadata + ) + has_auto_restatement = ( + snapshot.is_model and snapshot.model.auto_restatement_cron is not None + ) + + snapshot_start = start_override_per_model.get( + node.name, start_date(snapshot, snapshots.values(), cache=start_date_cache) + ) + + is_valid_start = ( + snapshot.is_valid_start(start, snapshot_start) if start is not None else True ) + + children_deployable = is_valid_start and not has_auto_restatement if ( snapshot.is_forward_only or snapshot.is_indirect_non_breaking - or is_uncategorized_forward_only_model + or is_forward_only_model + or has_auto_restatement + or not is_valid_start ): # FORWARD_ONLY and INDIRECT_NON_BREAKING snapshots are not deployable by nature. + # Similarly, if the model depends on past and the start date is not aligned with the + # model's start, we should consider this snapshot non-deployable. this_deployable = False - if not snapshot.is_paused or snapshot.is_indirect_non_breaking: + if not snapshot.is_paused or ( + snapshot.is_indirect_non_breaking and snapshot.intervals + ): # This snapshot represents what's currently deployed in prod. 
representative_shared_version_ids.add(node) - else: - this_deployable = True - children_deployable = not ( - snapshot.is_paused - and (snapshot.is_forward_only or is_uncategorized_forward_only_model) - ) + else: + # If the parent is not representative then its children can't be deployable. + children_deployable = False else: - this_deployable, children_deployable = False, False - representative_shared_version_ids.discard(node) - - deployability_mapping[node] = deployability_mapping.get(node, True) and this_deployable - for child in reversed_dag[node]: - _visit(child, children_deployable) + children_deployable = False + if not snapshots[node].is_paused: + representative_shared_version_ids.add(node) - for node in dag.roots: - _visit(node) + deployability_mapping[node] = this_deployable + children_deployability_mapping[node] = children_deployable deployable_ids = { snapshot_id for snapshot_id, deployable in deployability_mapping.items() if deployable @@ -1365,19 +1762,46 @@ def table_name( physical_schema: str, name: str, version: str, - is_dev_table: bool = False, catalog: t.Optional[str] = None, + suffix: t.Optional[str] = None, + naming_convention: t.Optional[TableNamingConvention] = None, ) -> str: table = exp.to_table(name) - # bigquery projects usually have "-" in them which is illegal in the table name, so we aggressively prune - name = "__".join(sanitize_name(part.name) for part in table.parts) - temp_suffix = "__temp" if is_dev_table else "" + naming_convention = naming_convention or TableNamingConvention.default - table.set("this", exp.to_identifier(f"{name}__{version}{temp_suffix}")) + if naming_convention == TableNamingConvention.HASH_MD5: + # just take a MD5 hash of what we would have generated anyway using SCHEMA_AND_TABLE + value_to_hash = table_name( + physical_schema=physical_schema, + name=name, + version=version, + catalog=catalog, + suffix=suffix, + naming_convention=TableNamingConvention.SCHEMA_AND_TABLE, + ) + full_name = 
f"{c.SQLMESH}_md5__{md5(value_to_hash)}" + else: + # note: Snapshot._table_name() already strips the catalog from the model name before calling this function + # Therefore, a model with 3-part naming like "foo.bar.baz" gets passed as (name="bar.baz", catalog="foo") to this function + # This is why there is no TableNamingConvention.CATALOG_AND_SCHEMA_AND_TABLE + table_parts = table.parts + parts_to_consider = 2 if naming_convention == TableNamingConvention.SCHEMA_AND_TABLE else 1 + + # in case the parsed table name has less parts than what the naming convention says we should be considering + parts_to_consider = min(len(table_parts), parts_to_consider) + + # bigquery projects usually have "-" in them which is illegal in the table name, so we aggressively prune + name = "__".join(sanitize_name(part.name) for part in table_parts[-parts_to_consider:]) + + full_name = f"{name}__{version}" + + suffix = f"__{suffix}" if suffix else "" + + table.set("this", exp.to_identifier(f"{full_name}{suffix}")) table.set("db", exp.to_identifier(physical_schema)) if not table.catalog and catalog: - table.set("catalog", exp.parse_identifier(catalog)) + table.set("catalog", exp.to_identifier(catalog)) return exp.table_name(table) @@ -1391,16 +1815,43 @@ def display_name( Returns the model name as a qualified view name. This is just used for presenting information back to the user and `qualified_view_name` should be used when wanting a view name in all other cases. + + Args: + snapshot_info_like: The snapshot info object to get the display name for + environment_naming_info: Environment naming info to use for display name formatting + default_catalog: Optional default catalog name to use. If None, the default catalog will always be included in the display name. 
+ dialect: Optional dialect type to use for name formatting + + Returns: + The formatted display name as a string """ if snapshot_info_like.is_audit: return snapshot_info_like.name - view_name = exp.to_table(snapshot_info_like.name) + + return model_display_name( + snapshot_info_like.name, environment_naming_info, default_catalog, dialect + ) + + +def model_display_name( + node_name: str, + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str], + dialect: DialectType = None, +) -> str: + view_name = exp.to_table(node_name) + + catalog = ( + None + if ( + environment_naming_info.suffix_target != EnvironmentSuffixTarget.CATALOG + and view_name.catalog == default_catalog + ) + else view_name.catalog + ) + qvn = QualifiedViewName( - catalog=( - view_name.catalog - if view_name.catalog and view_name.catalog != default_catalog - else None - ), + catalog=catalog, schema_name=view_name.db or None, table=view_name.name, ) @@ -1411,7 +1862,6 @@ def fingerprint_from_node( node: Node, *, nodes: t.Dict[str, Node], - audits: t.Optional[t.Dict[str, ModelAudit]] = None, cache: t.Optional[t.Dict[str, SnapshotFingerprint]] = None, ) -> SnapshotFingerprint: """Helper function to generate a fingerprint based on the data and metadata of the node and its parents. @@ -1424,7 +1874,6 @@ def fingerprint_from_node( node: Node to fingerprint. nodes: Dictionary of all nodes in the graph to make the fingerprint dependent on parent changes. If no dictionary is passed in the fingerprint will not be dependent on a node's parents. - audits: Available audits by name. cache: Cache of node name to fingerprints. 
Returns: @@ -1434,12 +1883,7 @@ def fingerprint_from_node( if node.fqn not in cache: parents = [ - fingerprint_from_node( - nodes[table], - nodes=nodes, - audits=audits, - cache=cache, - ) + fingerprint_from_node(nodes[table], nodes=nodes, cache=cache) for table in node.depends_on if table in nodes ] @@ -1452,7 +1896,7 @@ def fingerprint_from_node( cache[node.fqn] = SnapshotFingerprint( data_hash=node.data_hash, - metadata_hash=node.metadata_hash(audits or {}), + metadata_hash=node.metadata_hash, parent_data_hash=parent_data_hash, parent_metadata_hash=parent_metadata_hash, ) @@ -1475,7 +1919,7 @@ def _parents_from_node( return parent_nodes -def merge_intervals(intervals: Intervals) -> Intervals: +def merge_intervals(intervals: t.Collection[Interval]) -> Intervals: """Merge a list of intervals. Args: @@ -1484,6 +1928,8 @@ def merge_intervals(intervals: Intervals) -> Intervals: Returns: A new list of sorted and merged intervals. """ + if not intervals: + return [] intervals = sorted(intervals) merged = [intervals[0]] @@ -1556,6 +2002,21 @@ def to_table_mapping( } +def to_view_mapping( + snapshots: t.Iterable[Snapshot], + environment_naming_info: EnvironmentNamingInfo, + default_catalog: t.Optional[str] = None, + dialect: t.Optional[str] = None, +) -> t.Dict[str, str]: + return { + snapshot.name: snapshot.display_name( + environment_naming_info, default_catalog=default_catalog, dialect=dialect + ) + for snapshot in snapshots + if snapshot.is_model + } + + def has_paused_forward_only( targets: t.Iterable[SnapshotIdLike], snapshots: t.Union[t.List[Snapshot], t.Dict[SnapshotId, Snapshot]], @@ -1570,39 +2031,62 @@ def has_paused_forward_only( def missing_intervals( - snapshots: t.Collection[Snapshot], + snapshots: t.Union[t.Collection[Snapshot], t.Dict[SnapshotId, Snapshot]], start: t.Optional[TimeLike] = None, end: t.Optional[TimeLike] = None, execution_time: t.Optional[TimeLike] = None, restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None, deployability_index: 
t.Optional[DeployabilityIndex] = None, + start_override_per_model: t.Optional[t.Dict[str, datetime]] = None, + end_override_per_model: t.Optional[t.Dict[str, datetime]] = None, ignore_cron: bool = False, end_bounded: bool = False, ) -> t.Dict[Snapshot, Intervals]: """Returns all missing intervals given a collection of snapshots.""" + if not isinstance(snapshots, dict): + # Make sure that the mapping is only constructed once + snapshots = {snapshot.snapshot_id: snapshot for snapshot in snapshots} missing = {} cache: t.Dict[str, datetime] = {} - end_date = end or now() + end_date = end or now_timestamp() start_dt = ( to_datetime(start) if start else earliest_start_date(snapshots, cache=cache, relative_to=end_date) ) restatements = restatements or {} - + start_override_per_model = start_override_per_model or {} + end_override_per_model = end_override_per_model or {} deployability_index = deployability_index or DeployabilityIndex.all_deployable() - for snapshot in snapshots: + for snapshot in snapshots.values(): if not snapshot.evaluatable: continue - interval = restatements.get(snapshot.snapshot_id) - snapshot_start_date = start_dt - snapshot_end_date = end_date - if interval: - snapshot_start_date, snapshot_end_date = (to_datetime(i) for i in interval) + + snapshot_start_date = start_override_per_model.get(snapshot.name, start_dt) + snapshot_end_date: TimeLike = end_date + + restated_interval = restatements.get(snapshot.snapshot_id) + if restated_interval: + snapshot_start_date, snapshot_end_date = (to_datetime(i) for i in restated_interval) snapshot = snapshot.copy() snapshot.intervals = snapshot.intervals.copy() - snapshot.remove_interval(interval) + snapshot.remove_interval(restated_interval) + + existing_interval_end = end_override_per_model.get(snapshot.name) + if existing_interval_end: + if snapshot_start_date >= existing_interval_end: + # The start exceeds the provided interval end, so we can skip this snapshot + # since it doesn't have missing intervals by 
definition + continue + snapshot_end_date = existing_interval_end + + snapshot_start_date = max( + to_datetime(snapshot_start_date), + to_datetime(start_date(snapshot, snapshots, cache, relative_to=snapshot_end_date)), + ) + if snapshot_start_date > to_datetime(snapshot_end_date): + continue missing_interval_end_date = snapshot_end_date node_end_date = snapshot.node.end @@ -1610,10 +2094,7 @@ def missing_intervals( missing_interval_end_date = node_end_date intervals = snapshot.missing_intervals( - max( - to_datetime(snapshot_start_date), - to_datetime(start_date(snapshot, snapshots, cache, relative_to=snapshot_end_date)), - ), + snapshot_start_date, missing_interval_end_date, execution_time=execution_time, deployability_index=deployability_index, @@ -1626,14 +2107,31 @@ def missing_intervals( return missing -@lru_cache(maxsize=None) +@lru_cache(maxsize=16384) +def expand_range(start_ts: int, end_ts: int, interval_unit: IntervalUnit) -> t.List[int]: + croniter = interval_unit.croniter(start_ts) + timestamps = [start_ts] + + while True: + ts = to_timestamp(croniter.get_next(estimate=True)) + + if ts > end_ts: + if timestamps and timestamps[-1] != end_ts: + timestamps.append(end_ts) + break + + timestamps.append(ts) + return timestamps + + +@lru_cache(maxsize=16384) def compute_missing_intervals( interval_unit: IntervalUnit, intervals: t.Tuple[Interval, ...], start_ts: int, end_ts: int, - upper_bound_ts: int, lookback: int, + model_end_ts: t.Optional[int], ) -> Intervals: """Computes all missing intervals between start and end given intervals. @@ -1642,61 +2140,104 @@ def compute_missing_intervals( intervals: The intervals to check what's missing. start_ts: Inclusive timestamp start. end_ts: Exclusive timestamp end. - upper_bound_ts: The exclusive upper bound timestamp for lookback. lookback: A lookback window. + model_end_ts: The exclusive end timestamp set on the model (if one is set) Returns: A list of all timestamps in this range. 
""" - croniter = interval_unit.croniter(start_ts) - timestamps = [start_ts] - - # get all individual timestamps with the addition of extra lookback timestamps up to the execution date - # when a model has lookback, we need to check all the intervals between itself and its lookback exist. - while True: - ts = to_timestamp(croniter.get_next(estimate=True)) - - if ts < end_ts: - timestamps.append(ts) - else: - croniter.get_prev(estimate=True) - break - - for _ in range(lookback): - ts = to_timestamp(croniter.get_next(estimate=True)) - if ts < upper_bound_ts: - timestamps.append(ts) - else: - break + if start_ts == end_ts: + return [] - missing = [] - for i in range(len(timestamps)): - if timestamps[i] >= end_ts: - break - current_ts = timestamps[i] - next_ts = ( - timestamps[i + 1] - if i + 1 < len(timestamps) - else min( - to_timestamp(interval_unit.cron_next(current_ts, estimate=True)), upper_bound_ts - ) - ) - compare_ts = seq_get(timestamps, i + lookback) or timestamps[-1] + timestamps = expand_range(start_ts, end_ts, interval_unit) + missing = set() + for current_ts, next_ts in zip(timestamps, timestamps[1:]): for low, high in intervals: - if compare_ts < low: - missing.append((current_ts, next_ts)) + if current_ts < low: + missing.add((current_ts, next_ts)) break - elif current_ts >= low and compare_ts < high: + elif current_ts >= low and next_ts <= high: break else: - missing.append((current_ts, next_ts)) + missing.add((current_ts, next_ts)) - return missing + if missing: + if lookback: + if model_end_ts: + croniter = interval_unit.croniter(end_ts) + end_ts = to_timestamp(croniter.get_prev(estimate=True)) + + while model_end_ts < end_ts: + end_ts = to_timestamp(croniter.get_prev(estimate=True)) + lookback -= 1 + + lookback = max(lookback, 0) + + for i, (current_ts, next_ts) in enumerate(zip(timestamps, timestamps[1:])): + parent = timestamps[i + lookback : i + lookback + 2] + + if len(parent) < 2 or tuple(parent) in missing: + missing.add((current_ts, next_ts)) 
+ + if model_end_ts: + missing = {interval for interval in missing if interval[0] < model_end_ts} + + return sorted(missing) + + +@lru_cache(maxsize=16384) +def inclusive_exclusive( + start: TimeLike, + end: TimeLike, + interval_unit: IntervalUnit, + strict: bool = True, + allow_partial: bool = False, + expand: bool = True, +) -> Interval: + """Transform the inclusive start and end into a [start, end) pair. + + Args: + start: The start date/time of the interval (inclusive) + end: The end date/time of the interval (inclusive) + interval_unit: The interval unit. + strict: Whether to fail when the inclusive start is the same as the exclusive end. + allow_partial: Whether the interval can be partial or not. + expand: Whether or not partial intervals are expanded outwards. + + Returns: + A [start, end) pair. + """ + start_dt = interval_unit.cron_floor(start) + + if not expand and not allow_partial and start_dt < to_datetime(start): + start_dt = interval_unit.cron_next(start_dt) + + start_ts = to_timestamp(start_dt) + + if is_date(end): + end = to_datetime(end) + timedelta(days=1) + + if allow_partial: + end_dt = end + else: + end_dt = interval_unit.cron_floor(end) + + if expand and end_dt != to_datetime(end): + end_dt = interval_unit.cron_next(end_dt) + + end_ts = to_timestamp(end_dt) + + if strict and start_ts >= end_ts: + raise ValueError( + f"`end` ({to_datetime(end_ts)}) must be greater than `start` ({to_datetime(start_ts)})" + ) + + return (start_ts, end_ts) def earliest_start_date( - snapshots: t.Collection[Snapshot], + snapshots: t.Union[t.Collection[Snapshot], t.Dict[SnapshotId, Snapshot]], cache: t.Optional[t.Dict[str, datetime]] = None, relative_to: t.Optional[TimeLike] = None, ) -> datetime: @@ -1711,11 +2252,19 @@ def earliest_start_date( """ cache = {} if cache is None else cache if snapshots: + if not isinstance(snapshots, dict): + # Make sure that the mapping is only constructed once + snapshots = {snapshot.snapshot_id: snapshot for snapshot in snapshots} 
return min( start_date(snapshot, snapshots, cache=cache, relative_to=relative_to) - for snapshot in snapshots + for snapshot in snapshots.values() ) - return yesterday() + + relative_base = None + if relative_to is not None: + relative_base = to_datetime(relative_to) + + return yesterday(relative_base=relative_base) def start_date( @@ -1750,14 +2299,16 @@ def start_date( if not isinstance(snapshots, dict): snapshots = {snapshot.snapshot_id: snapshot for snapshot in snapshots} - earliest = snapshot.node.cron_prev(snapshot.node.cron_floor(relative_to or now())) - - for parent in snapshot.parents: - if parent in snapshots: - earliest = min( - earliest, - start_date(snapshots[parent], snapshots, cache=cache, relative_to=relative_to), - ) + parent_starts = [ + start_date(snapshots[parent], snapshots, cache=cache, relative_to=relative_to) + for parent in snapshot.parents + if parent in snapshots + ] + earliest = ( + min(parent_starts) + if parent_starts + else snapshot.node.cron_prev(snapshot.node.cron_floor(relative_to or now())) + ) cache[key] = earliest return earliest @@ -1768,3 +2319,187 @@ def snapshots_to_dag(snapshots: t.Collection[Snapshot]) -> DAG[SnapshotId]: for snapshot in snapshots: dag.add(snapshot.snapshot_id, snapshot.parents) return dag + + +def apply_auto_restatements( + snapshots: t.Dict[SnapshotId, Snapshot], execution_time: TimeLike +) -> t.Tuple[t.List[SnapshotIntervals], t.Dict[SnapshotId, t.List[SnapshotId]]]: + """Applies auto restatements to the snapshots. + + This operation results in the removal of intervals for snapshots that are ready to be restated based + on the provided execution time and configured auto restatement settings. For each affected snapshot, + it also updates the next auto restatement timestamp. + + Args: + snapshots: A dictionary of snapshots to apply auto restatements to. + execution_time: The execution time. + + Returns: + A list of SnapshotIntervals with **new** intervals that need to be restated. 
+ """ + dag = snapshots_to_dag(snapshots.values()) + auto_restatement_triggers: t.Dict[SnapshotId, t.List[SnapshotId]] = {} + auto_restated_intervals_per_snapshot: t.Dict[SnapshotId, Interval] = {} + for s_id in dag: + if s_id not in snapshots: + continue + snapshot = snapshots[s_id] + if not snapshot.is_model or snapshot.model.disable_restatement: + continue + + next_auto_restated_interval = snapshot.get_next_auto_restatement_interval(execution_time) + auto_restated_intervals = [ + auto_restated_intervals_per_snapshot[parent_s_id] + for parent_s_id in snapshot.parents + if parent_s_id in auto_restated_intervals_per_snapshot + ] + upstream_triggers = [] + if next_auto_restated_interval: + logger.info( + "Calculated the next auto restated interval (%s, %s) for snapshot %s", + time_like_to_str(next_auto_restated_interval[0]), + time_like_to_str(next_auto_restated_interval[1]), + snapshot.snapshot_id, + ) + auto_restated_intervals.append(next_auto_restated_interval) + + # auto-restated snapshot is its own trigger + upstream_triggers = [s_id] + else: + # inherit each parent's auto-restatement triggers (if any) + for parent_s_id in snapshot.parents: + if parent_s_id in auto_restatement_triggers: + upstream_triggers.extend(auto_restatement_triggers[parent_s_id]) + + # remove duplicate triggers, retaining order and keeping first seen of duplicates + if upstream_triggers: + auto_restatement_triggers[s_id] = unique(upstream_triggers) + + if auto_restated_intervals: + auto_restated_interval_start = sys.maxsize + auto_restated_interval_end = -sys.maxsize + for interval in auto_restated_intervals: + auto_restated_interval_start = min(auto_restated_interval_start, interval[0]) + auto_restated_interval_end = max(auto_restated_interval_end, interval[1]) + + interval_to_remove_start = snapshot.node.interval_unit.cron_floor( + auto_restated_interval_start + ) + interval_to_remove_end = snapshot.node.interval_unit.cron_floor( + auto_restated_interval_end + ) + if 
auto_restated_interval_end > to_timestamp(interval_to_remove_end): + interval_to_remove_end = snapshot.node.interval_unit.cron_next( + interval_to_remove_end + ) + + removal_interval = snapshot.get_removal_interval( + interval_to_remove_start, interval_to_remove_end, execution_time=execution_time + ) + + auto_restated_intervals_per_snapshot[s_id] = removal_interval + snapshot.pending_restatement_intervals = merge_intervals( + [*snapshot.pending_restatement_intervals, removal_interval] + ) + + snapshot.apply_pending_restatement_intervals() + snapshot.update_next_auto_restatement_ts(execution_time) + return ( + [ + SnapshotIntervals( + name=snapshots[s_id].name, + identifier=None, + version=snapshots[s_id].version, + dev_version=None, + intervals=[], + dev_intervals=[], + pending_restatement_intervals=[interval], + ) + for s_id, interval in auto_restated_intervals_per_snapshot.items() + if s_id in snapshots + ], + auto_restatement_triggers, + ) + + +def parent_snapshots_by_name( + snapshot: Snapshot, snapshots: t.Dict[SnapshotId, Snapshot] +) -> t.Dict[str, Snapshot]: + parent_snapshots_by_name = { + snapshots[p_sid].name: snapshots[p_sid] for p_sid in snapshot.parents + } + parent_snapshots_by_name[snapshot.name] = snapshot + return parent_snapshots_by_name + + +def _contiguous_intervals(intervals: Intervals) -> t.List[Intervals]: + """Given a list of intervals with gaps, returns a list of sequences of contiguous intervals.""" + contiguous_intervals = [] + current_batch: t.List[Interval] = [] + for interval in intervals: + if len(current_batch) == 0 or interval[0] == current_batch[-1][-1]: + current_batch.append(interval) + else: + contiguous_intervals.append(current_batch) + current_batch = [interval] + + if len(current_batch) > 0: + contiguous_intervals.append(current_batch) + + return contiguous_intervals + + +def check_ready_intervals( + check: t.Callable, + intervals: Intervals, + context: ExecutionContext, + python_env: t.Dict[str, Executable], + dialect: 
DialectType = None, + path: t.Optional[Path] = None, + snapshot: t.Optional[Snapshot] = None, + kwargs: t.Optional[t.Dict] = None, +) -> Intervals: + checked_intervals: Intervals = [] + + for interval_batch in _contiguous_intervals(intervals): + batch = [(to_datetime(start), to_datetime(end)) for start, end in interval_batch] + + try: + ready_intervals = call_macro( + check, + dialect, + path, + provided_args=(batch,), + provided_kwargs=(kwargs or {}), + context=context, + snapshot=snapshot, + ) + except Exception as ex: + raise SignalEvalError(format_evaluated_code_exception(ex, python_env)) + + if isinstance(ready_intervals, bool): + if not ready_intervals: + batch = [] + elif isinstance(ready_intervals, list): + for i in ready_intervals: + if i not in batch: + raise SignalEvalError(f"Unknown interval {i} for signal") + batch = ready_intervals + else: + raise SignalEvalError(f"Expected bool | list, got {type(ready_intervals)} for signal") + + checked_intervals.extend((to_timestamp(start), to_timestamp(end)) for start, end in batch) + + return checked_intervals + + +def get_next_model_interval_start(snapshots: t.Iterable[Snapshot]) -> t.Optional[datetime]: + now_dt = now() + + starts = [ + snap.node.cron_next(now_dt) + for snap in snapshots + if snap.is_model and not snap.is_symbolic and not snap.is_seed + ] + + return min(starts) if starts else None diff --git a/sqlmesh/core/snapshot/evaluator.py b/sqlmesh/core/snapshot/evaluator.py index 386809ff84..4f5102cbef 100644 --- a/sqlmesh/core/snapshot/evaluator.py +++ b/sqlmesh/core/snapshot/evaluator.py @@ -30,42 +30,61 @@ from contextlib import contextmanager from functools import reduce -import pandas as pd from sqlglot import exp, select from sqlglot.executor import execute +from tenacity import retry, stop_after_attempt, wait_exponential, retry_if_not_exception_type from sqlmesh.core import constants as c from sqlmesh.core import dialect as d -from sqlmesh.core.audit import Audit, AuditResult +from 
sqlmesh.core.audit import Audit, StandaloneAudit from sqlmesh.core.dialect import schema_ -from sqlmesh.core.engine_adapter import EngineAdapter -from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy +from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, DataObjectType, DataObject +from sqlmesh.core.model.meta import GrantsTargetLayer from sqlmesh.core.macros import RuntimeStage from sqlmesh.core.model import ( + AuditResult, IncrementalUnmanagedKind, Model, SeedModel, SCDType2ByColumnKind, SCDType2ByTimeKind, ViewKind, + CustomKind, +) +from sqlmesh.core.model.kind import _Incremental, DbtCustomKind +from sqlmesh.utils import CompletionStatus, columns_to_types_all_known +from sqlmesh.core.schema_diff import ( + has_drop_alteration, + TableAlterOperation, + has_additive_alteration, ) -from sqlmesh.core.schema_diff import has_drop_alteration from sqlmesh.core.snapshot import ( DeployabilityIndex, Intervals, Snapshot, - SnapshotChangeCategory, SnapshotId, + SnapshotIdBatch, SnapshotInfoLike, SnapshotTableCleanupTask, ) -from sqlmesh.utils import random_id +from sqlmesh.core.snapshot.execution_tracker import QueryExecutionTracker +from sqlmesh.utils import random_id, CorrelationId, AttributeDict from sqlmesh.utils.concurrency import ( concurrent_apply_to_snapshots, concurrent_apply_to_values, + NodeExecutionFailedError, +) +from sqlmesh.utils.date import TimeLike, now, time_like_to_str +from sqlmesh.utils.errors import ( + ConfigError, + DestructiveChangeError, + MigrationNotSupportedError, + SQLMeshError, + format_destructive_change_msg, + format_additive_change_msg, + AdditiveChangeError, ) -from sqlmesh.utils.date import TimeLike, now -from sqlmesh.utils.errors import AuditError, ConfigError, SQLMeshError +from sqlmesh.utils.jinja import MacroReturnVal if sys.version_info >= (3, 12): from importlib import metadata @@ -74,11 +93,22 @@ if t.TYPE_CHECKING: from sqlmesh.core.engine_adapter._typing import DF, QueryOrDF + from 
sqlmesh.core.engine_adapter.base import EngineAdapter from sqlmesh.core.environment import EnvironmentNamingInfo logger = logging.getLogger(__name__) +class SnapshotCreationFailedError(SQLMeshError): + def __init__( + self, errors: t.List[NodeExecutionFailedError[SnapshotId]], skipped: t.List[SnapshotId] + ): + messages = "\n\n".join(f"{error}\n {error.__cause__}" for error in errors) + super().__init__(f"Physical table creation failed:\n\n{messages}") + self.errors = errors + self.skipped = skipped + + class SnapshotEvaluator: """Evaluates a snapshot given runtime arguments through an arbitrary EngineAdapter. @@ -87,13 +117,34 @@ class SnapshotEvaluator: does not directly communicate with the underlying execution engine. Args: - adapter: The adapter that interfaces with the execution engine. + adapters: A single EngineAdapter or a dictionary of EngineAdapters where + the key is the gateway name. When a dictionary is provided, and not an + explicit default gateway its first item is treated as the default + adapter and used for the virtual layer. ddl_concurrent_tasks: The number of concurrent tasks used for DDL operations (table / view creation, deletion, etc). Default: 1. 
""" - def __init__(self, adapter: EngineAdapter, ddl_concurrent_tasks: int = 1): - self.adapter = adapter + def __init__( + self, + adapters: EngineAdapter | t.Dict[str, EngineAdapter], + ddl_concurrent_tasks: int = 1, + selected_gateway: t.Optional[str] = None, + ): + self.adapters = ( + adapters if isinstance(adapters, t.Dict) else {selected_gateway or "": adapters} + ) + self.execution_tracker = QueryExecutionTracker() + self.adapters = { + gateway: adapter.with_settings(query_execution_tracker=self.execution_tracker) + for gateway, adapter in self.adapters.items() + } + self.adapter = ( + next(iter(self.adapters.values())) + if not selected_gateway + else self.adapters[selected_gateway] + ) + self.selected_gateway = selected_gateway self.ddl_concurrent_tasks = ddl_concurrent_tasks def evaluate( @@ -104,8 +155,11 @@ def evaluate( end: TimeLike, execution_time: TimeLike, snapshots: t.Dict[str, Snapshot], + allow_destructive_snapshots: t.Optional[t.Set[str]] = None, + allow_additive_snapshots: t.Optional[t.Set[str]] = None, deployability_index: t.Optional[DeployabilityIndex] = None, batch_index: int = 0, + target_table_exists: t.Optional[bool] = None, **kwargs: t.Any, ) -> t.Optional[str]: """Renders the snapshot's model, executes it and stores the result in the snapshot's physical table. @@ -116,23 +170,32 @@ def evaluate( end: The end datetime to render. execution_time: The date/time time reference to use for execution time. snapshots: All upstream snapshots (by name) to use for expansion and mapping of physical locations. + allow_destructive_snapshots: Snapshots for which destructive schema changes are allowed. + allow_additive_snapshots: Snapshots for which additive schema changes are allowed. deployability_index: Determines snapshots that are deployable in the context of this evaluation. batch_index: If the snapshot is part of a batch of related snapshots; which index in the batch is it + target_table_exists: Whether the target table exists. 
If None, the table will be checked for existence. kwargs: Additional kwargs to pass to the renderer. Returns: The WAP ID of this evaluation if supported, None otherwise. """ - result = self._evaluate_snapshot( - snapshot, - start, - end, - execution_time, - snapshots, - deployability_index=deployability_index, - batch_index=batch_index, - **kwargs, - ) + with self.execution_tracker.track_execution( + SnapshotIdBatch(snapshot_id=snapshot.snapshot_id, batch_id=batch_index) + ): + result = self._evaluate_snapshot( + start=start, + end=end, + execution_time=execution_time, + snapshot=snapshot, + snapshots=snapshots, + allow_destructive_snapshots=allow_destructive_snapshots or set(), + allow_additive_snapshots=allow_additive_snapshots or set(), + deployability_index=deployability_index, + batch_index=batch_index, + target_table_exists=target_table_exists, + **kwargs, + ) if result is None or isinstance(result, str): return result raise SQLMeshError( @@ -166,27 +229,51 @@ def evaluate_and_fetch( Returns: The result of the evaluation as a dataframe. """ - result = self._evaluate_snapshot( + import pandas as pd + + adapter = self.get_adapter(snapshot.model.gateway) + render_kwargs = dict( + start=start, + end=end, + execution_time=execution_time, + snapshot=snapshot, + runtime_stage=RuntimeStage.EVALUATING, + **kwargs, + ) + queries_or_dfs = self._render_snapshot_for_evaluation( snapshot, - start, - end, - execution_time, snapshots, - limit=limit, - deployability_index=deployability_index, - **kwargs, + deployability_index or DeployabilityIndex.all_deployable(), + render_kwargs, ) - if result is None or isinstance(result, str): - raise SQLMeshError( - f"Unexpected result {result} when evaluating snapshot {snapshot.snapshot_id}." 
- ) - return result + query_or_df = next(queries_or_dfs) + if isinstance(query_or_df, pd.DataFrame): + return query_or_df.head(limit) + if not isinstance(query_or_df, exp.Expression): + # We assume that if this branch is reached, `query_or_df` is a pyspark / snowpark / bigframe dataframe, + # so we use `limit` instead of `head` to get back a dataframe instead of List[Row] + # https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.DataFrame.head.html#pyspark.sql.DataFrame.head + return query_or_df.limit(limit) + + assert isinstance(query_or_df, exp.Query) + + existing_limit = query_or_df.args.get("limit") + if existing_limit: + limit = min(limit, execute(exp.select(existing_limit.expression)).rows[0][0]) + assert limit is not None + + return adapter._fetch_native_df(query_or_df.limit(limit)) def promote( self, target_snapshots: t.Iterable[Snapshot], environment_naming_info: EnvironmentNamingInfo, deployability_index: t.Optional[DeployabilityIndex] = None, + start: t.Optional[TimeLike] = None, + end: t.Optional[TimeLike] = None, + execution_time: t.Optional[TimeLike] = None, + snapshots: t.Optional[t.Dict[SnapshotId, Snapshot]] = None, + table_mapping: t.Optional[t.Dict[str, str]] = None, on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None, ) -> None: """Promotes the given collection of snapshots in the target environment by replacing a corresponding @@ -198,32 +285,56 @@ def promote( deployability_index: Determines snapshots that are deployable in the context of this promotion. on_complete: A callback to call on each successfully promoted snapshot. 
""" - self._create_schemas( - [ - s.qualified_view_name.table_for_environment( - environment_naming_info, dialect=self.adapter.dialect + + tables_by_gateway: t.Dict[t.Union[str, None], t.List[exp.Table]] = defaultdict(list) + for snapshot in target_snapshots: + if snapshot.is_model and not snapshot.is_symbolic: + gateway = ( + snapshot.model_gateway if environment_naming_info.gateway_managed else None ) - for s in target_snapshots - if s.is_model and not s.is_symbolic - ] - ) + adapter = self.get_adapter(gateway) + table = snapshot.qualified_view_name.table_for_environment( + environment_naming_info, dialect=adapter.dialect + ) + tables_by_gateway[gateway].append(table) + + # A schema can be shared across multiple engines, so we need to group by gateway + for gateway, tables in tables_by_gateway.items(): + if environment_naming_info.suffix_target.is_catalog: + self._create_catalogs(tables=tables, gateway=gateway) + + gateway_table_pairs = [ + (gateway, table) for gateway, tables in tables_by_gateway.items() for table in tables + ] + self._create_schemas(gateway_table_pairs=gateway_table_pairs) + + # Fetch the view data objects for the promoted snapshots to get them cached + self._get_virtual_data_objects(target_snapshots, environment_naming_info) + deployability_index = deployability_index or DeployabilityIndex.all_deployable() with self.concurrent_context(): concurrent_apply_to_snapshots( target_snapshots, lambda s: self._promote_snapshot( s, - environment_naming_info, - deployability_index, # type: ignore - on_complete, + start=start, + end=end, + execution_time=execution_time, + snapshots=snapshots, + table_mapping=table_mapping, + environment_naming_info=environment_naming_info, + deployability_index=deployability_index, # type: ignore + on_complete=on_complete, ), self.ddl_concurrent_tasks, ) def demote( self, - target_snapshots: t.Iterable[SnapshotInfoLike], + target_snapshots: t.Iterable[Snapshot], environment_naming_info: EnvironmentNamingInfo, + 
table_mapping: t.Optional[t.Dict[str, str]] = None, + deployability_index: t.Optional[DeployabilityIndex] = None, on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None, ) -> None: """Demotes the given collection of snapshots in the target environment by removing its view. @@ -236,7 +347,13 @@ def demote( with self.concurrent_context(): concurrent_apply_to_snapshots( target_snapshots, - lambda s: self._demote_snapshot(s, environment_naming_info, on_complete), + lambda s: self._demote_snapshot( + s, + environment_naming_info, + deployability_index=deployability_index, + on_complete=on_complete, + table_mapping=table_mapping, + ), self.ddl_concurrent_tasks, ) @@ -245,69 +362,121 @@ def create( target_snapshots: t.Iterable[Snapshot], snapshots: t.Dict[SnapshotId, Snapshot], deployability_index: t.Optional[DeployabilityIndex] = None, + on_start: t.Optional[t.Callable] = None, on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None, - allow_destructive_snapshots: t.Set[str] = set(), - ) -> None: + allow_destructive_snapshots: t.Optional[t.Set[str]] = None, + allow_additive_snapshots: t.Optional[t.Set[str]] = None, + ) -> CompletionStatus: """Creates a physical snapshot schema and table for the given collection of snapshots. Args: target_snapshots: Target snapshots. snapshots: Mapping of snapshot ID to snapshot. deployability_index: Determines snapshots that are deployable in the context of this creation. + on_start: A callback to initialize the snapshot creation progress bar. on_complete: A callback to call on each successfully created snapshot. allow_destructive_snapshots: Set of snapshots that are allowed to have destructive schema changes. + allow_additive_snapshots: Set of snapshots that are allowed to have additive schema changes. + + Returns: + CompletionStatus: The status of the creation operation (success, failure, nothing to do). 
""" - snapshots_with_table_names = defaultdict(set) - tables_by_schema = defaultdict(set) - for snapshot in target_snapshots: - if not snapshot.is_model or snapshot.is_symbolic: - continue - for is_deployable in (True, False): - table = exp.to_table( - snapshot.table_name(is_deployable), dialect=snapshot.model.dialect - ) - snapshots_with_table_names[snapshot].add(table.name) - tables_by_schema[d.schema_(table.db, catalog=table.catalog)].add(table.name) + deployability_index = deployability_index or DeployabilityIndex.all_deployable() - def _get_data_objects(schema: exp.Table) -> t.Set[str]: - logger.info("Listing data objects in schema %s", schema.sql()) - objs = self.adapter.get_data_objects(schema, tables_by_schema[schema]) - return {obj.name for obj in objs} + snapshots_to_create = self.get_snapshots_to_create(target_snapshots, deployability_index) + if not snapshots_to_create: + return CompletionStatus.NOTHING_TO_DO + if on_start: + on_start(snapshots_to_create) - with self.concurrent_context(): - existing_objects = { - obj - for objs in concurrent_apply_to_values( - list(tables_by_schema), _get_data_objects, self.ddl_concurrent_tasks + self._create_snapshots( + snapshots_to_create=snapshots_to_create, + snapshots={s.name: s for s in snapshots.values()}, + deployability_index=deployability_index, + on_complete=on_complete, + allow_destructive_snapshots=allow_destructive_snapshots or set(), + allow_additive_snapshots=allow_additive_snapshots or set(), + ) + return CompletionStatus.SUCCESS + + def create_physical_schemas( + self, snapshots: t.Iterable[Snapshot], deployability_index: DeployabilityIndex + ) -> None: + """Creates the physical schemas for the given snapshots. + + Args: + snapshots: Snapshots to create physical schemas for. + deployability_index: Determines snapshots that are deployable in the context of this creation. 
+ """ + tables_by_gateway: t.Dict[t.Optional[str], t.List[str]] = defaultdict(list) + for snapshot in snapshots: + if snapshot.is_model and not snapshot.is_symbolic: + tables_by_gateway[snapshot.model_gateway].append( + snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot)) ) - for obj in objs - } + gateway_table_pairs = [ + (gateway, table) for gateway, tables in tables_by_gateway.items() for table in tables + ] + self._create_schemas(gateway_table_pairs=gateway_table_pairs) + + def get_snapshots_to_create( + self, target_snapshots: t.Iterable[Snapshot], deployability_index: DeployabilityIndex + ) -> t.List[Snapshot]: + """Returns a list of snapshots that need to have their physical tables created. + + Args: + target_snapshots: Target snapshots. + deployability_index: Determines snapshots that are deployable / representative in the context of this creation. + """ + existing_data_objects = self._get_physical_data_objects( + target_snapshots, deployability_index + ) snapshots_to_create = [] - for snapshot, table_names in snapshots_with_table_names.items(): - if table_names - existing_objects or (snapshot.is_seed and not snapshot.intervals): + for snapshot in target_snapshots: + if not snapshot.is_model or snapshot.is_symbolic: + continue + if snapshot.snapshot_id not in existing_data_objects or ( + snapshot.is_seed and not snapshot.intervals + ): snapshots_to_create.append(snapshot) - elif on_complete: - on_complete(snapshot) - if not snapshots_to_create: - return + return snapshots_to_create - self._create_schemas(tables_by_schema) + def _create_snapshots( + self, + snapshots_to_create: t.Iterable[Snapshot], + snapshots: t.Dict[str, Snapshot], + deployability_index: DeployabilityIndex, + on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]], + allow_destructive_snapshots: t.Set[str], + allow_additive_snapshots: t.Set[str], + ) -> None: + """Internal method to create tables in parallel.""" with self.concurrent_context(): - 
concurrent_apply_to_snapshots( + errors, skipped = concurrent_apply_to_snapshots( snapshots_to_create, - lambda s: self._create_snapshot( - s, snapshots, deployability_index, on_complete, allow_destructive_snapshots + lambda s: self.create_snapshot( + s, + snapshots=snapshots, + deployability_index=deployability_index, + allow_destructive_snapshots=allow_destructive_snapshots, + allow_additive_snapshots=allow_additive_snapshots, + on_complete=on_complete, ), self.ddl_concurrent_tasks, + raise_on_error=False, ) + if errors: + raise SnapshotCreationFailedError(errors, skipped) def migrate( self, target_snapshots: t.Iterable[Snapshot], snapshots: t.Dict[SnapshotId, Snapshot], - allow_destructive_snapshots: t.Set[str] = set(), + allow_destructive_snapshots: t.Optional[t.Set[str]] = None, + allow_additive_snapshots: t.Optional[t.Set[str]] = None, + deployability_index: t.Optional[DeployabilityIndex] = None, ) -> None: """Alters a physical snapshot table to match its snapshot's schema for the given collection of snapshots. @@ -315,11 +484,33 @@ def migrate( target_snapshots: Target snapshots. snapshots: Mapping of snapshot ID to snapshot. allow_destructive_snapshots: Set of snapshots that are allowed to have destructive schema changes. + allow_additive_snapshots: Set of snapshots that are allowed to have additive schema changes. + deployability_index: Determines snapshots that are deployable in the context of this evaluation. 
""" + deployability_index = deployability_index or DeployabilityIndex.all_deployable() + target_data_objects = self._get_physical_data_objects(target_snapshots, deployability_index) + if not target_data_objects: + return + + if not snapshots: + snapshots = {s.snapshot_id: s for s in target_snapshots} + + allow_destructive_snapshots = allow_destructive_snapshots or set() + allow_additive_snapshots = allow_additive_snapshots or set() + snapshots_by_name = {s.name: s for s in snapshots.values()} with self.concurrent_context(): + # Only migrate snapshots for which there's an existing data object concurrent_apply_to_snapshots( target_snapshots, - lambda s: self._migrate_snapshot(s, snapshots, allow_destructive_snapshots), + lambda s: self._migrate_snapshot( + s, + snapshots_by_name, + target_data_objects.get(s.snapshot_id), + allow_destructive_snapshots, + allow_additive_snapshots, + self.get_adapter(s.model_gateway), + deployability_index, + ), self.ddl_concurrent_tasks, ) @@ -334,6 +525,9 @@ def cleanup( target_snapshots: Snapshots to cleanup. on_complete: A callback to call on each successfully deleted database object. 
""" + target_snapshots = [ + t for t in target_snapshots if t.snapshot.is_model and not t.snapshot.is_symbolic + ] snapshots_to_dev_table_only = { t.snapshot.snapshot_id: t.dev_table_only for t in target_snapshots } @@ -341,7 +535,10 @@ def cleanup( concurrent_apply_to_snapshots( [t.snapshot for t in target_snapshots], lambda s: self._cleanup_snapshot( - s, snapshots_to_dev_table_only[s.snapshot_id], on_complete + s, + snapshots_to_dev_table_only[s.snapshot_id], + self.get_adapter(s.model_gateway), + on_complete, ), self.ddl_concurrent_tasks, reverse_order=True, @@ -355,7 +552,6 @@ def audit( start: t.Optional[TimeLike] = None, end: t.Optional[TimeLike] = None, execution_time: t.Optional[TimeLike] = None, - raise_exception: bool = True, deployability_index: t.Optional[DeployabilityIndex] = None, wap_id: t.Optional[str] = None, **kwargs: t.Any, @@ -368,16 +564,12 @@ def audit( start: The start datetime to audit. Defaults to epoch start. end: The end datetime to audit. Defaults to epoch start. execution_time: The date/time time reference to use for execution time. - raise_exception: Whether to raise an exception if the audit fails. Blocking rules determine if an - AuditError is thrown or if we just warn with logger deployability_index: Determines snapshots that are deployable in the context of this evaluation. wap_id: The WAP ID if applicable, None otherwise. kwargs: Additional kwargs to pass to the renderer. """ deployability_index = deployability_index or DeployabilityIndex.all_deployable() - if not deployability_index.is_deployable(snapshot): - # We can't audit a temporary table. 
- return []
 + adapter = self.get_adapter(snapshot.model_gateway) if not snapshot.version: raise ConfigError( @@ -389,24 +581,40 @@ def audit( original_table_name = snapshot.table_name( is_deployable=deployability_index.is_deployable(snapshot) ) - wap_table_name = self.adapter.wap_table_name(original_table_name, wap_id) + wap_table_name = adapter.wap_table_name(original_table_name, wap_id) logger.info( - "Auditing WAP table '%s', snapshot %s", wap_table_name, snapshot.snapshot_id + "Auditing WAP table '%s', snapshot %s", + wap_table_name, + snapshot.snapshot_id, ) table_mapping = kwargs.get("table_mapping") or {} table_mapping[snapshot.name] = wap_table_name kwargs["table_mapping"] = table_mapping - kwargs["this_model"] = exp.to_table(wap_table_name, dialect=self.adapter.dialect) + kwargs["this_model"] = exp.to_table(wap_table_name, dialect=adapter.dialect) results = [] - audits_with_args = snapshot.audits_with_args + audits_with_args = snapshot.node.audits_with_args + + force_non_blocking = False if audits_with_args: logger.info("Auditing snapshot %s", snapshot.snapshot_id) - for audit, audit_args in snapshot.audits_with_args: + if not deployability_index.is_deployable(snapshot) and not adapter.SUPPORTS_CLONING: + # For dev preview tables that aren't based on clones of the production table, only a subset of the data is typically available + # However, users still expect audits to run anyway. 
Some audits (such as row count) are practically guaranteed to fail + # when run on only a subset of data, so we switch all audits to non blocking and the user can decide if they still want to proceed + force_non_blocking = True + + for audit, audit_args in audits_with_args: + if force_non_blocking: + # remove any blocking indicator on the model itself + audit_args.pop("blocking", None) + # so that we can fall back to the audit's setting, which we override to blocking: False + audit = audit.model_copy(update={"blocking": False}) + results.append( self._audit( audit=audit, @@ -416,7 +624,6 @@ def audit( start=start, end=end, execution_time=execution_time, - raise_exception=raise_exception, deployability_index=deployability_index, **kwargs, ) @@ -424,11 +631,11 @@ def audit( if wap_id is not None: logger.info( - "Publishing evalaution results for snapshot %s, WAP ID '%s'", + "Publishing evaluation results for snapshot %s, WAP ID '%s'", snapshot.snapshot_id, wap_id, ) - self._wap_publish_snapshot(snapshot, wap_id, deployability_index) + self.wap_publish_snapshot(snapshot, wap_id, deployability_index) return results @@ -443,29 +650,44 @@ def recycle(self) -> None: """Closes all open connections and releases all allocated resources associated with any thread except the calling one.""" try: - self.adapter.recycle() + for adapter in self.adapters.values(): + adapter.recycle() + except Exception: logger.exception("Failed to recycle Snapshot Evaluator") def close(self) -> None: """Closes all open connections and releases all allocated resources.""" try: - self.adapter.close() + for adapter in self.adapters.values(): + adapter.close() except Exception: logger.exception("Failed to close Snapshot Evaluator") + def set_correlation_id(self, correlation_id: CorrelationId) -> SnapshotEvaluator: + return SnapshotEvaluator( + { + gateway: adapter.with_settings(correlation_id=correlation_id) + for gateway, adapter in self.adapters.items() + }, + self.ddl_concurrent_tasks, + 
self.selected_gateway, + ) + def _evaluate_snapshot( self, - snapshot: Snapshot, start: TimeLike, end: TimeLike, execution_time: TimeLike, + snapshot: Snapshot, snapshots: t.Dict[str, Snapshot], - limit: t.Optional[int] = None, - deployability_index: t.Optional[DeployabilityIndex] = None, - batch_index: int = 0, + allow_destructive_snapshots: t.Set[str], + allow_additive_snapshots: t.Set[str], + deployability_index: t.Optional[DeployabilityIndex], + batch_index: int, + target_table_exists: t.Optional[bool], **kwargs: t.Any, - ) -> DF | str | None: + ) -> t.Optional[str]: """Renders the snapshot's model and executes it. The return value depends on whether the limit was specified. Args: @@ -474,279 +696,540 @@ def _evaluate_snapshot( end: The end datetime to render. execution_time: The date/time time reference to use for execution time. snapshots: All upstream snapshots to use for expansion and mapping of physical locations. - limit: If limit is not None, the query will not be persisted but evaluated and returned as a dataframe. + allow_destructive_snapshots: Snapshots for which destructive schema changes are allowed. + allow_additive_snapshots: Snapshots for which additive schema changes are allowed. deployability_index: Determines snapshots that are deployable in the context of this evaluation. batch_index: If the snapshot is part of a batch of related snapshots; which index in the batch is it + target_table_exists: Whether the target table exists. If None, the table will be checked for existence. kwargs: Additional kwargs to pass to the renderer. 
"""
 - if not snapshot.is_model or snapshot.is_seed: + if not snapshot.is_model: return None model = snapshot.model logger.info("Evaluating snapshot %s", snapshot.snapshot_id) + adapter = self.get_adapter(model.gateway) deployability_index = deployability_index or DeployabilityIndex.all_deployable() - table_name = ( - "" - if limit is not None - else snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot)) - ) - - evaluation_strategy = _evaluation_strategy(snapshot, self.adapter) - - # https://github.com/TobikoData/sqlmesh/issues/2609 + is_snapshot_deployable = deployability_index.is_deployable(snapshot) + target_table_name = snapshot.table_name(is_deployable=is_snapshot_deployable) + # https://github.com/TobikoData/sqlmesh/issues/2609 # If there are no existing intervals yet; only consider this a first insert for the first snapshot in the batch - is_first_insert = not _intervals(snapshot, deployability_index) and batch_index == 0 - - def apply(query_or_df: QueryOrDF, index: int = 0) -> None: - if index > 0: - evaluation_strategy.append( - table_name=table_name, - query_or_df=query_or_df, - model=snapshot.model, - snapshot=snapshot, - snapshots=snapshots, - deployability_index=deployability_index, - batch_index=batch_index, - start=start, - end=end, - execution_time=execution_time, - ) - else: - logger.info("Inserting batch (%s, %s) into %s'", start, end, table_name) - evaluation_strategy.insert( - table_name=table_name, - query_or_df=query_or_df, - is_first_insert=is_first_insert, - model=snapshot.model, - snapshot=snapshot, - snapshots=snapshots, - deployability_index=deployability_index, - batch_index=batch_index, - start=start, - end=end, - execution_time=execution_time, - ) - - from sqlmesh.core.context import ExecutionContext - + if target_table_exists is None: + target_table_exists = adapter.table_exists(target_table_name) + is_first_insert = ( + not _intervals(snapshot, deployability_index) or not target_table_exists + ) and batch_index ==
0 + + # Use the 'creating' stage if the table doesn't exist yet to preserve backwards compatibility with existing projects + # that depend on a separate physical table creation stage. + runtime_stage = RuntimeStage.EVALUATING if target_table_exists else RuntimeStage.CREATING common_render_kwargs = dict( start=start, end=end, execution_time=execution_time, snapshot=snapshot, - runtime_stage=RuntimeStage.EVALUATING, + runtime_stage=runtime_stage, **kwargs, ) - + create_render_kwargs = dict( + engine_adapter=adapter, + snapshots=snapshots, + deployability_index=deployability_index, + **common_render_kwargs, + ) + create_render_kwargs["runtime_stage"] = RuntimeStage.CREATING render_statements_kwargs = dict( - engine_adapter=self.adapter, + engine_adapter=adapter, snapshots=snapshots, deployability_index=deployability_index, **common_render_kwargs, ) + rendered_physical_properties = snapshot.model.render_physical_properties( + **render_statements_kwargs + ) + + evaluation_strategy = _evaluation_strategy(snapshot, adapter) + evaluation_strategy.run_pre_statements( + snapshot=snapshot, + render_kwargs={**render_statements_kwargs, "inside_transaction": False}, + ) + + with ( + adapter.transaction(), + adapter.session(snapshot.model.render_session_properties(**render_statements_kwargs)), + ): + evaluation_strategy.run_pre_statements( + snapshot=snapshot, + render_kwargs={**render_statements_kwargs, "inside_transaction": True}, + ) + + if not target_table_exists or (model.is_seed and not snapshot.intervals): + # Only create the empty table if the columns were provided explicitly by the user + should_create_empty_table = ( + model.kind.is_materialized + and model.columns_to_types_ + and columns_to_types_all_known(model.columns_to_types_) + ) + if not should_create_empty_table: + # Or if the model is self-referential and its query is fully annotated with types + should_create_empty_table = model.depends_on_self and model.annotated + if self._can_clone(snapshot, 
deployability_index): + self._clone_snapshot_in_dev( + snapshot=snapshot, + snapshots=snapshots, + deployability_index=deployability_index, + render_kwargs=create_render_kwargs, + rendered_physical_properties=rendered_physical_properties.copy(), + allow_destructive_snapshots=allow_destructive_snapshots, + allow_additive_snapshots=allow_additive_snapshots, + ) + runtime_stage = RuntimeStage.EVALUATING + target_table_exists = True + elif should_create_empty_table or model.is_seed or model.kind.is_scd_type_2: + self._execute_create( + snapshot=snapshot, + table_name=target_table_name, + is_table_deployable=is_snapshot_deployable, + deployability_index=deployability_index, + create_render_kwargs=create_render_kwargs, + rendered_physical_properties=rendered_physical_properties.copy(), + dry_run=False, + run_pre_post_statements=False, + ) + runtime_stage = RuntimeStage.EVALUATING + target_table_exists = True + + evaluate_render_kwargs = { + **common_render_kwargs, + "runtime_stage": runtime_stage, + "snapshot_table_exists": target_table_exists, + } - with self.adapter.transaction(), self.adapter.session(snapshot.model.session_properties): wap_id: t.Optional[str] = None if ( - table_name - and snapshot.is_materialized - and (model.wap_supported or self.adapter.wap_supported(table_name)) + snapshot.is_materialized + and target_table_exists + and adapter.wap_enabled + and (model.wap_supported or adapter.wap_supported(target_table_name)) ): wap_id = random_id()[0:8] logger.info("Using WAP ID '%s' for snapshot %s", wap_id, snapshot.snapshot_id) - table_name = self.adapter.wap_prepare(table_name, wap_id) - - if limit is None: - self.adapter.execute(model.render_pre_statements(**render_statements_kwargs)) - - queries_or_dfs = model.render( - context=ExecutionContext( - self.adapter, - snapshots, - deployability_index, - default_dialect=model.dialect, - default_catalog=model.default_catalog, - ), - **common_render_kwargs, + target_table_name = 
adapter.wap_prepare(target_table_name, wap_id) + + self._render_and_insert_snapshot( + start=start, + end=end, + execution_time=execution_time, + snapshot=snapshot, + snapshots=snapshots, + render_kwargs=evaluate_render_kwargs, + create_render_kwargs=create_render_kwargs, + rendered_physical_properties=rendered_physical_properties, + deployability_index=deployability_index, + target_table_name=target_table_name, + is_first_insert=is_first_insert, + batch_index=batch_index, ) - if limit is not None: - query_or_df = next(queries_or_dfs) - if isinstance(query_or_df, pd.DataFrame): - return query_or_df.head(limit) - if not isinstance(query_or_df, exp.Expression): - # We assume that if this branch is reached, `query_or_df` is a pyspark / snowpark dataframe, - # so we use `limit` instead of `head` to get back a dataframe instead of List[Row] - # https://spark.apache.org/docs/3.1.1/api/python/reference/api/pyspark.sql.DataFrame.head.html#pyspark.sql.DataFrame.head - return query_or_df.limit(limit) - - assert isinstance(query_or_df, exp.Query) - - existing_limit = query_or_df.args.get("limit") - if existing_limit: - limit = min(limit, execute(exp.select(existing_limit.expression)).rows[0][0]) - assert limit is not None - - return self.adapter._fetch_native_df(query_or_df.limit(limit)) - - # DataFrames, unlike SQL expressions, can provide partial results by yielding dataframes. As a result, - # if the engine supports INSERT OVERWRITE or REPLACE WHERE and the snapshot is incremental by time range, we risk - # having a partial result since each dataframe write can re-truncate partitions. To avoid this, we - # union all the dataframes together before writing. For pandas this could result in OOM and a potential - # workaround for that would be to serialize pandas to disk and then read it back with Spark. - # Note: We assume that if multiple things are yielded from `queries_or_dfs` that they are dataframes - # and not SQL expressions. 
- elif ( - self.adapter.INSERT_OVERWRITE_STRATEGY - in (InsertOverwriteStrategy.INSERT_OVERWRITE, InsertOverwriteStrategy.REPLACE_WHERE) - and snapshot.is_incremental_by_time_range - ): - query_or_df = reduce( - lambda a, b: ( - pd.concat([a, b], ignore_index=True) # type: ignore - if isinstance(a, pd.DataFrame) - else a.union_all(b) # type: ignore - ), # type: ignore - queries_or_dfs, - ) - apply(query_or_df, index=0) - else: - for index, query_or_df in enumerate(queries_or_dfs): - apply(query_or_df, index) + evaluation_strategy.run_post_statements( + snapshot=snapshot, + render_kwargs={**render_statements_kwargs, "inside_transaction": True}, + ) - if limit is None: - self.adapter.execute(model.render_post_statements(**render_statements_kwargs)) + evaluation_strategy.run_post_statements( + snapshot=snapshot, + render_kwargs={**render_statements_kwargs, "inside_transaction": False}, + ) - return wap_id + return wap_id - def _create_snapshot( + def create_snapshot( self, snapshot: Snapshot, - snapshots: t.Dict[SnapshotId, Snapshot], - deployability_index: t.Optional[DeployabilityIndex], - on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]], + snapshots: t.Dict[str, Snapshot], + deployability_index: DeployabilityIndex, allow_destructive_snapshots: t.Set[str], + allow_additive_snapshots: t.Set[str], + on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]] = None, ) -> None: + """Creates a physical table for the given snapshot. + + Args: + snapshot: Snapshot to create. + snapshots: All upstream snapshots to use for expansion and mapping of physical locations. + deployability_index: Determines snapshots that are deployable in the context of this creation. + on_complete: A callback to call on each successfully created database object. + allow_destructive_snapshots: Snapshots for which destructive schema changes are allowed. + allow_additive_snapshots: Snapshots for which additive schema changes are allowed. 
+ """ if not snapshot.is_model: return - parent_snapshots_by_name = { - snapshots[p_sid].name: snapshots[p_sid] for p_sid in snapshot.parents - } - parent_snapshots_by_name[snapshot.name] = snapshot - - deployability_index = deployability_index or DeployabilityIndex.all_deployable() + logger.info("Creating a physical table for snapshot %s", snapshot.snapshot_id) - common_render_kwargs: t.Dict[str, t.Any] = dict( - engine_adapter=self.adapter, - snapshots=parent_snapshots_by_name, + adapter = self.get_adapter(snapshot.model.gateway) + create_render_kwargs: t.Dict[str, t.Any] = dict( + engine_adapter=adapter, + snapshots=snapshots, runtime_stage=RuntimeStage.CREATING, + deployability_index=deployability_index, ) - pre_post_render_kwargs = dict( - **common_render_kwargs, - deployability_index=deployability_index.with_deployable(snapshot), - ) - create_render_kwargs = dict( - **common_render_kwargs, - # Refers to self as non-deployable to successfully create self-referential tables / views. 
- deployability_index=deployability_index.with_non_deployable(snapshot), - ) - - # It can still be useful for some strategies to know if the snapshot was actually deployable - is_snapshot_deployable = deployability_index.is_deployable(snapshot) - - evaluation_strategy = _evaluation_strategy(snapshot, self.adapter) - - with self.adapter.transaction(), self.adapter.session(snapshot.model.session_properties): - self.adapter.execute(snapshot.model.render_pre_statements(**pre_post_render_kwargs)) - if ( - snapshot.is_forward_only - and snapshot.is_materialized - and snapshot.previous_versions - and self.adapter.SUPPORTS_CLONING - ): - target_table_name = snapshot.table_name(is_deployable=False) - tmp_table_name = f"{target_table_name}__schema_migration_source" - source_table_name = snapshot.table_name() + evaluation_strategy = _evaluation_strategy(snapshot, adapter) + evaluation_strategy.run_pre_statements( + snapshot=snapshot, render_kwargs={**create_render_kwargs, "inside_transaction": False} + ) - logger.info(f"Cloning table '{source_table_name}' into '{target_table_name}'") + with ( + adapter.transaction(), + adapter.session(snapshot.model.render_session_properties(**create_render_kwargs)), + ): + rendered_physical_properties = snapshot.model.render_physical_properties( + **create_render_kwargs + ) - evaluation_strategy.create( - table_name=tmp_table_name, - model=snapshot.model, - is_table_deployable=False, - render_kwargs=dict( - table_mapping={snapshot.name: tmp_table_name}, - **create_render_kwargs, - ), - is_snapshot_deployable=is_snapshot_deployable, + if self._can_clone(snapshot, deployability_index): + self._clone_snapshot_in_dev( + snapshot=snapshot, + snapshots=snapshots, + deployability_index=deployability_index, + render_kwargs=create_render_kwargs, + rendered_physical_properties=rendered_physical_properties, + allow_destructive_snapshots=allow_destructive_snapshots, + allow_additive_snapshots=allow_additive_snapshots, + run_pre_post_statements=True, ) - 
try: - self.adapter.clone_table(target_table_name, snapshot.table_name(), replace=True) - alter_expressions = self.adapter.get_alter_expressions( - target_table_name, tmp_table_name - ) - _check_destructive_schema_change( - snapshot, alter_expressions, allow_destructive_snapshots - ) - self.adapter.alter_table(alter_expressions) - except Exception: - self.adapter.drop_table(target_table_name) - raise - finally: - self.adapter.drop_table(tmp_table_name) else: - table_deployability_flags = [False] - if not snapshot.reuses_previous_version: - table_deployability_flags.append(True) - for is_table_deployable in table_deployability_flags: - evaluation_strategy.create( - table_name=snapshot.table_name(is_deployable=is_table_deployable), - model=snapshot.model, - is_table_deployable=is_table_deployable, - render_kwargs=create_render_kwargs, - is_snapshot_deployable=is_snapshot_deployable, - ) + is_table_deployable = deployability_index.is_deployable(snapshot) + self._execute_create( + snapshot=snapshot, + table_name=snapshot.table_name(is_deployable=is_table_deployable), + is_table_deployable=is_table_deployable, + deployability_index=deployability_index, + create_render_kwargs=create_render_kwargs, + rendered_physical_properties=rendered_physical_properties, + dry_run=True, + ) - self.adapter.execute(snapshot.model.render_post_statements(**pre_post_render_kwargs)) + evaluation_strategy.run_post_statements( + snapshot=snapshot, render_kwargs={**create_render_kwargs, "inside_transaction": False} + ) if on_complete is not None: on_complete(snapshot) - def _migrate_snapshot( + def wap_publish_snapshot( self, snapshot: Snapshot, - snapshots: t.Dict[SnapshotId, Snapshot], - allow_destructive_snapshots: t.Set[str], + wap_id: str, + deployability_index: t.Optional[DeployabilityIndex], ) -> None: - if ( - not snapshot.is_paused - or snapshot.change_category - not in ( - SnapshotChangeCategory.FORWARD_ONLY, - SnapshotChangeCategory.INDIRECT_NON_BREAKING, - ) - or not 
snapshot.is_model - ): - return - - parent_snapshots_by_name = { - snapshots[p_sid].name: snapshots[p_sid] for p_sid in snapshot.parents - } - parent_snapshots_by_name[snapshot.name] = snapshot + deployability_index = deployability_index or DeployabilityIndex.all_deployable() + table_name = snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot)) + adapter = self.get_adapter(snapshot.model_gateway) + adapter.wap_publish(table_name, wap_id) - tmp_table_name = snapshot.table_name(is_deployable=False) - target_table_name = snapshot.table_name() - _evaluation_strategy(snapshot, self.adapter).migrate( - target_table_name=target_table_name, - source_table_name=tmp_table_name, - snapshot=snapshot, - snapshots=parent_snapshots_by_name, - allow_destructive_snapshots=allow_destructive_snapshots, - ) + def _render_and_insert_snapshot( + self, + start: TimeLike, + end: TimeLike, + execution_time: TimeLike, + snapshot: Snapshot, + snapshots: t.Dict[str, Snapshot], + render_kwargs: t.Dict[str, t.Any], + create_render_kwargs: t.Dict[str, t.Any], + rendered_physical_properties: t.Dict[str, exp.Expression], + deployability_index: DeployabilityIndex, + target_table_name: str, + is_first_insert: bool, + batch_index: int, + ) -> None: + if not snapshot.is_model or snapshot.is_seed: + return + + logger.info("Inserting data for snapshot %s", snapshot.snapshot_id) + + model = snapshot.model + adapter = self.get_adapter(model.gateway) + evaluation_strategy = _evaluation_strategy(snapshot, adapter) + is_snapshot_deployable = deployability_index.is_deployable(snapshot) + + queries_or_dfs = self._render_snapshot_for_evaluation( + snapshot, + snapshots, + deployability_index, + render_kwargs, + ) + + def apply(query_or_df: QueryOrDF, index: int = 0) -> None: + if index > 0: + evaluation_strategy.append( + table_name=target_table_name, + query_or_df=query_or_df, + model=snapshot.model, + snapshot=snapshot, + snapshots=snapshots, + deployability_index=deployability_index, 
+ batch_index=batch_index, + start=start, + end=end, + execution_time=execution_time, + physical_properties=rendered_physical_properties, + render_kwargs=create_render_kwargs, + is_snapshot_deployable=is_snapshot_deployable, + ) + else: + logger.info( + "Inserting batch (%s, %s) into %s'", + time_like_to_str(start), + time_like_to_str(end), + target_table_name, + ) + evaluation_strategy.insert( + table_name=target_table_name, + query_or_df=query_or_df, + is_first_insert=is_first_insert, + model=snapshot.model, + snapshot=snapshot, + snapshots=snapshots, + deployability_index=deployability_index, + batch_index=batch_index, + start=start, + end=end, + execution_time=execution_time, + physical_properties=rendered_physical_properties, + render_kwargs=create_render_kwargs, + is_snapshot_deployable=is_snapshot_deployable, + ) + + # DataFrames, unlike SQL expressions, can provide partial results by yielding dataframes. As a result, + # if the engine supports INSERT OVERWRITE or REPLACE WHERE and the snapshot is incremental by time range, we risk + # having a partial result since each dataframe write can re-truncate partitions. To avoid this, we + # union all the dataframes together before writing. For pandas this could result in OOM and a potential + # workaround for that would be to serialize pandas to disk and then read it back with Spark. + # Note: We assume that if multiple things are yielded from `queries_or_dfs` that they are dataframes + # and not SQL expressions. 
+ if ( + adapter.INSERT_OVERWRITE_STRATEGY + in ( + InsertOverwriteStrategy.INSERT_OVERWRITE, + InsertOverwriteStrategy.REPLACE_WHERE, + ) + and snapshot.is_incremental_by_time_range + ): + import pandas as pd + + try: + first_query_or_df = next(queries_or_dfs) + except StopIteration: + return + + query_or_df = reduce( + lambda a, b: ( + pd.concat([a, b], ignore_index=True) # type: ignore + if isinstance(a, pd.DataFrame) + else a.union_all(b) # type: ignore + ), # type: ignore + queries_or_dfs, + first_query_or_df, + ) + apply(query_or_df, index=0) + else: + for index, query_or_df in enumerate(queries_or_dfs): + apply(query_or_df, index) + + def _render_snapshot_for_evaluation( + self, + snapshot: Snapshot, + snapshots: t.Dict[str, Snapshot], + deployability_index: DeployabilityIndex, + render_kwargs: t.Dict[str, t.Any], + ) -> t.Iterator[QueryOrDF]: + from sqlmesh.core.context import ExecutionContext + + model = snapshot.model + adapter = self.get_adapter(model.gateway) + + return model.render( + context=ExecutionContext( + adapter, + snapshots, + deployability_index, + default_dialect=model.dialect, + default_catalog=model.default_catalog, + ), + **render_kwargs, + ) + + def _clone_snapshot_in_dev( + self, + snapshot: Snapshot, + snapshots: t.Dict[str, Snapshot], + deployability_index: DeployabilityIndex, + render_kwargs: t.Dict[str, t.Any], + rendered_physical_properties: t.Dict[str, exp.Expression], + allow_destructive_snapshots: t.Set[str], + allow_additive_snapshots: t.Set[str], + run_pre_post_statements: bool = False, + ) -> None: + adapter = self.get_adapter(snapshot.model.gateway) + + target_table_name = snapshot.table_name(is_deployable=False) + source_table_name = snapshot.table_name() + + try: + logger.info(f"Cloning table '{source_table_name}' into '{target_table_name}'") + adapter.clone_table( + target_table_name, + snapshot.table_name(), + rendered_physical_properties=rendered_physical_properties, + ) + self._migrate_target_table( + 
target_table_name=target_table_name, + snapshot=snapshot, + snapshots=snapshots, + deployability_index=deployability_index, + render_kwargs=render_kwargs, + rendered_physical_properties=rendered_physical_properties, + allow_destructive_snapshots=allow_destructive_snapshots, + allow_additive_snapshots=allow_additive_snapshots, + run_pre_post_statements=run_pre_post_statements, + ) + + except Exception: + adapter.drop_table(target_table_name) + raise + + def _migrate_snapshot( + self, + snapshot: Snapshot, + snapshots: t.Dict[str, Snapshot], + target_data_object: t.Optional[DataObject], + allow_destructive_snapshots: t.Set[str], + allow_additive_snapshots: t.Set[str], + adapter: EngineAdapter, + deployability_index: DeployabilityIndex, + ) -> None: + if not snapshot.is_model or snapshot.is_symbolic: + return + + deployability_index = DeployabilityIndex.all_deployable() + render_kwargs: t.Dict[str, t.Any] = dict( + engine_adapter=adapter, + snapshots=snapshots, + runtime_stage=RuntimeStage.CREATING, + deployability_index=deployability_index, + ) + target_table_name = snapshot.table_name() + + evaluation_strategy = _evaluation_strategy(snapshot, adapter) + evaluation_strategy.run_pre_statements( + snapshot=snapshot, render_kwargs={**render_kwargs, "inside_transaction": False} + ) + + with ( + adapter.transaction(), + adapter.session(snapshot.model.render_session_properties(**render_kwargs)), + ): + table_exists = target_data_object is not None + if adapter.drop_data_object_on_type_mismatch( + target_data_object, _snapshot_to_data_object_type(snapshot) + ): + table_exists = False + + rendered_physical_properties = snapshot.model.render_physical_properties( + **render_kwargs + ) + + if table_exists: + self._migrate_target_table( + target_table_name=target_table_name, + snapshot=snapshot, + snapshots=snapshots, + deployability_index=deployability_index, + render_kwargs=render_kwargs, + rendered_physical_properties=rendered_physical_properties, + 
allow_destructive_snapshots=allow_destructive_snapshots, + allow_additive_snapshots=allow_additive_snapshots, + run_pre_post_statements=True, + ) + else: + self._execute_create( + snapshot=snapshot, + table_name=snapshot.table_name(is_deployable=True), + is_table_deployable=True, + deployability_index=deployability_index, + create_render_kwargs=render_kwargs, + rendered_physical_properties=rendered_physical_properties, + dry_run=True, + ) + + evaluation_strategy.run_post_statements( + snapshot=snapshot, render_kwargs={**render_kwargs, "inside_transaction": False} + ) + + # Retry in case when the table is migrated concurrently from another plan application + @retry( + reraise=True, + stop=stop_after_attempt(5), + wait=wait_exponential(min=1, max=16), + retry=retry_if_not_exception_type( + (DestructiveChangeError, AdditiveChangeError, MigrationNotSupportedError) + ), + ) + def _migrate_target_table( + self, + target_table_name: str, + snapshot: Snapshot, + snapshots: t.Dict[str, Snapshot], + deployability_index: DeployabilityIndex, + render_kwargs: t.Dict[str, t.Any], + rendered_physical_properties: t.Dict[str, exp.Expression], + allow_destructive_snapshots: t.Set[str], + allow_additive_snapshots: t.Set[str], + run_pre_post_statements: bool = False, + ) -> None: + adapter = self.get_adapter(snapshot.model.gateway) + + tmp_table = exp.to_table(target_table_name) + tmp_table.this.set("this", f"{tmp_table.name}_schema_tmp") + tmp_table_name = tmp_table.sql() + + if snapshot.is_materialized: + self._execute_create( + snapshot=snapshot, + table_name=tmp_table_name, + is_table_deployable=False, + deployability_index=deployability_index, + create_render_kwargs=render_kwargs, + rendered_physical_properties=rendered_physical_properties, + dry_run=False, + run_pre_post_statements=run_pre_post_statements, + skip_grants=True, # skip grants for tmp table + ) + try: + evaluation_strategy = _evaluation_strategy(snapshot, adapter) + logger.info( + "Migrating table schema from '%s' 
to '%s'", + tmp_table_name, + target_table_name, + ) + evaluation_strategy.migrate( + target_table_name=target_table_name, + source_table_name=tmp_table_name, + snapshot=snapshot, + snapshots=snapshots, + allow_destructive_snapshots=allow_destructive_snapshots, + allow_additive_snapshots=allow_additive_snapshots, + ignore_destructive=snapshot.model.on_destructive_change.is_ignore, + ignore_additive=snapshot.model.on_additive_change.is_ignore, + deployability_index=deployability_index, + ) + finally: + if snapshot.is_materialized: + adapter.drop_table(tmp_table_name) def _promote_snapshot( self, @@ -754,32 +1237,86 @@ def _promote_snapshot( environment_naming_info: EnvironmentNamingInfo, deployability_index: DeployabilityIndex, on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]], + start: t.Optional[TimeLike] = None, + end: t.Optional[TimeLike] = None, + execution_time: t.Optional[TimeLike] = None, + snapshots: t.Optional[t.Dict[SnapshotId, Snapshot]] = None, + table_mapping: t.Optional[t.Dict[str, str]] = None, ) -> None: - if snapshot.is_model: - table_name = snapshot.table_name(deployability_index.is_representative(snapshot)) - view_name = snapshot.qualified_view_name.for_environment( - environment_naming_info, dialect=self.adapter.dialect - ) - _evaluation_strategy(snapshot, self.adapter).promote( + if not snapshot.is_model: + return + + adapter = ( + self.get_adapter(snapshot.model_gateway) + if environment_naming_info.gateway_managed + else self.adapter + ) + table_name = snapshot.table_name(deployability_index.is_representative(snapshot)) + view_name = snapshot.qualified_view_name.for_environment( + environment_naming_info, dialect=adapter.dialect + ) + render_kwargs: t.Dict[str, t.Any] = dict( + start=start, + end=end, + execution_time=execution_time, + engine_adapter=adapter, + deployability_index=deployability_index, + table_mapping=table_mapping, + runtime_stage=RuntimeStage.PROMOTING, + ) + + with ( + adapter.transaction(), + 
adapter.session(snapshot.model.render_session_properties(**render_kwargs)), + ): + _evaluation_strategy(snapshot, adapter).promote( table_name=table_name, view_name=view_name, model=snapshot.model, environment=environment_naming_info.name, + snapshots=snapshots, + snapshot=snapshot, + **render_kwargs, ) + snapshot_by_name = {s.name: s for s in (snapshots or {}).values()} + render_kwargs["snapshots"] = snapshot_by_name + adapter.execute(snapshot.model.render_on_virtual_update(**render_kwargs)) + if on_complete is not None: on_complete(snapshot) def _demote_snapshot( self, - snapshot: SnapshotInfoLike, + snapshot: Snapshot, environment_naming_info: EnvironmentNamingInfo, + deployability_index: t.Optional[DeployabilityIndex], on_complete: t.Optional[t.Callable[[SnapshotInfoLike], None]], + table_mapping: t.Optional[t.Dict[str, str]] = None, ) -> None: + if not snapshot.is_model: + return + + adapter = ( + self.get_adapter(snapshot.model_gateway) + if environment_naming_info.gateway_managed + else self.adapter + ) view_name = snapshot.qualified_view_name.for_environment( - environment_naming_info, dialect=self.adapter.dialect + environment_naming_info, dialect=adapter.dialect ) - _evaluation_strategy(snapshot, self.adapter).demote(view_name) + with ( + adapter.transaction(), + adapter.session( + snapshot.model.render_session_properties( + engine_adapter=adapter, + deployability_index=deployability_index, + table_mapping=table_mapping, + runtime_stage=RuntimeStage.DEMOTING, + ) + ), + ): + _evaluation_strategy(snapshot, adapter).demote(view_name) if on_complete is not None: on_complete(snapshot) @@ -788,6 +1325,7 @@ def _cleanup_snapshot( self, snapshot: SnapshotInfoLike, dev_table_only: bool, + adapter: EngineAdapter, on_complete: t.Optional[t.Callable[[str], None]], ) -> None: snapshot = snapshot.table_info @@ -796,29 +1334,32 @@ def _cleanup_snapshot( if not dev_table_only: table_names.append((True, snapshot.table_name(is_deployable=True))) - evaluation_strategy = 
_evaluation_strategy(snapshot, self.adapter) - + evaluation_strategy = _evaluation_strategy(snapshot, adapter) for is_table_deployable, table_name in table_names: - table = exp.to_table(table_name) - if table.db != snapshot.physical_schema: - raise SQLMeshError( - f"Table '{table_name}' is not a part of the physical schema '{snapshot.physical_schema}' and so can't be dropped." + try: + evaluation_strategy.delete( + table_name, + is_table_deployable=is_table_deployable, + physical_schema=snapshot.physical_schema, + # we need to set cascade=true or we will get a 'cant drop because other objects depend on it'-style + # error on engines that enforce referential integrity, such as Postgres + # this situation can happen when a snapshot expires but downstream view snapshots that reference it have not yet expired + cascade=True, + ) + except Exception: + # Use `get_data_object` to check if the table exists instead of `table_exists` since the former + # is based on `INFORMATION_SCHEMA` and avoids touching the table directly. + # This is important when the table name is malformed for some reason and running any statement + # that touches the table would result in an error. 
+ if adapter.get_data_object(table_name) is not None: + raise + logger.warning( + "Skipping cleanup of table '%s' because it does not exist", table_name ) - evaluation_strategy.delete(table_name, is_table_deployable=is_table_deployable) if on_complete is not None: on_complete(table_name) - def _wap_publish_snapshot( - self, - snapshot: Snapshot, - wap_id: str, - deployability_index: t.Optional[DeployabilityIndex], - ) -> None: - deployability_index = deployability_index or DeployabilityIndex.all_deployable() - table_name = snapshot.table_name(is_deployable=deployability_index.is_deployable(snapshot)) - self.adapter.wap_publish(table_name, wap_id) - def _audit( self, audit: Audit, @@ -828,13 +1369,13 @@ def _audit( start: t.Optional[TimeLike], end: t.Optional[TimeLike], execution_time: t.Optional[TimeLike], - raise_exception: bool, deployability_index: t.Optional[DeployabilityIndex], **kwargs: t.Any, ) -> AuditResult: if audit.skip: return AuditResult( audit=audit, + audit_args=audit_args, model=snapshot.model_or_none, skipped=True, ) @@ -843,50 +1384,260 @@ def _audit( blocking = audit_args.pop("blocking", None) blocking = blocking == exp.true() if blocking else audit.blocking - query = audit.render_query( - snapshot, - start=start, - end=end, - execution_time=execution_time, - snapshots=snapshots, - deployability_index=deployability_index, - engine_adapter=self.adapter, + adapter = self.get_adapter(snapshot.model_gateway) + + kwargs = { + "start": start, + "end": end, + "execution_time": execution_time, + "snapshots": snapshots, + "deployability_index": deployability_index, + "engine_adapter": adapter, + "runtime_stage": RuntimeStage.AUDITING, **audit_args, **kwargs, - ) - count, *_ = self.adapter.fetchone( + } + + if snapshot.is_model: + query = snapshot.model.render_audit_query(audit, **kwargs) + elif isinstance(audit, StandaloneAudit): + query = audit.render_audit_query(**kwargs) + else: + raise SQLMeshError("Expected model or standalone audit. 
{snapshot}: {audit}") + + count, *_ = adapter.fetchone( select("COUNT(*)").from_(query.subquery("audit")), quote_identifiers=True, - ) - if count and raise_exception: - audit_error = AuditError( - audit_name=audit.name, - model=snapshot.model_or_none, - count=count, - query=query, - adapter_dialect=self.adapter.dialect, - ) - if blocking: - raise audit_error - else: - logger.warning(f"{audit_error}\nAudit is warn only so proceeding with execution.") + ) # type: ignore return AuditResult( audit=audit, + audit_args=audit_args, model=snapshot.model_or_none, count=count, query=query, + blocking=blocking, ) - def _create_schemas(self, tables: t.Iterable[t.Union[exp.Table, str]]) -> None: - table_exprs = [exp.to_table(t) for t in tables] - unique_schemas = {(t.args["db"], t.args.get("catalog")) for t in table_exprs if t and t.db} - # Create schemas sequentially, since some engines (eg. Postgres) may not support concurrent creation - # of schemas with the same name. - for schema_name, catalog in unique_schemas: + def _create_catalogs( + self, + tables: t.Iterable[t.Union[exp.Table, str]], + gateway: t.Optional[str] = None, + ) -> None: + # attempt to create catalogs for the virtual layer if possible + adapter = self.get_adapter(gateway) + if adapter.SUPPORTS_CREATE_DROP_CATALOG: + unique_catalogs = {t.catalog for t in [exp.to_table(maybe_t) for maybe_t in tables]} + for catalog_name in unique_catalogs: + adapter.create_catalog(catalog_name) + + def _create_schemas( + self, + gateway_table_pairs: t.Iterable[t.Tuple[t.Optional[str], t.Union[exp.Table, str]]], + ) -> None: + table_exprs = [(gateway, exp.to_table(t)) for gateway, t in gateway_table_pairs] + unique_schemas = { + (gateway, t.args["db"], t.args.get("catalog")) + for gateway, t in table_exprs + if t and t.db + } + + def _create_schema( + gateway: t.Optional[str], schema_name: str, catalog: t.Optional[str] + ) -> None: schema = schema_(schema_name, catalog) logger.info("Creating schema '%s'", schema) - 
self.adapter.create_schema(schema) + adapter = self.get_adapter(gateway) + adapter.create_schema(schema) + + with self.concurrent_context(): + concurrent_apply_to_values( + list(unique_schemas), + lambda item: _create_schema(item[0], item[1], item[2]), + self.ddl_concurrent_tasks, + ) + + def get_adapter(self, gateway: t.Optional[str] = None) -> EngineAdapter: + """Returns the adapter for the specified gateway or the default adapter if none is provided.""" + if gateway: + if adapter := self.adapters.get(gateway): + return adapter + raise SQLMeshError(f"Gateway '{gateway}' not found in the available engine adapters.") + return self.adapter + + def _execute_create( + self, + snapshot: Snapshot, + table_name: str, + is_table_deployable: bool, + deployability_index: DeployabilityIndex, + create_render_kwargs: t.Dict[str, t.Any], + rendered_physical_properties: t.Dict[str, exp.Expression], + dry_run: bool, + run_pre_post_statements: bool = True, + skip_grants: bool = False, + ) -> None: + adapter = self.get_adapter(snapshot.model.gateway) + evaluation_strategy = _evaluation_strategy(snapshot, adapter) + + # It can still be useful for some strategies to know if the snapshot was actually deployable + is_snapshot_deployable = deployability_index.is_deployable(snapshot) + is_snapshot_representative = deployability_index.is_representative(snapshot) + + create_render_kwargs = { + **create_render_kwargs, + "table_mapping": {snapshot.name: table_name}, + } + if run_pre_post_statements: + evaluation_strategy.run_pre_statements( + snapshot=snapshot, + render_kwargs={**create_render_kwargs, "inside_transaction": True}, + ) + evaluation_strategy.create( + table_name=table_name, + model=snapshot.model, + is_table_deployable=is_table_deployable, + skip_grants=skip_grants, + render_kwargs=create_render_kwargs, + is_snapshot_deployable=is_snapshot_deployable, + is_snapshot_representative=is_snapshot_representative, + dry_run=dry_run, + physical_properties=rendered_physical_properties, 
+ snapshot=snapshot, + deployability_index=deployability_index, + ) + if run_pre_post_statements: + evaluation_strategy.run_post_statements( + snapshot=snapshot, + render_kwargs={**create_render_kwargs, "inside_transaction": True}, + ) + + def _can_clone(self, snapshot: Snapshot, deployability_index: DeployabilityIndex) -> bool: + adapter = self.get_adapter(snapshot.model.gateway) + return ( + snapshot.is_forward_only + and snapshot.is_materialized + and bool(snapshot.previous_versions) + and adapter.SUPPORTS_CLONING + # managed models cannot have their schema mutated because they're based on queries, so clone + alter won't work + and not snapshot.is_managed + and not snapshot.is_dbt_custom + and not deployability_index.is_deployable(snapshot) + # If the deployable table is missing we can't clone it + and adapter.table_exists(snapshot.table_name()) + ) + + def _get_physical_data_objects( + self, + target_snapshots: t.Iterable[Snapshot], + deployability_index: DeployabilityIndex, + ) -> t.Dict[SnapshotId, DataObject]: + """Returns a dictionary of snapshot IDs to existing data objects of their physical tables. + + Args: + target_snapshots: Target snapshots. + deployability_index: The deployability index to determine whether to look for a deployable or + a non-deployable physical table. + + Returns: + A dictionary of snapshot IDs to existing data objects of their physical tables. If the data object + for a snapshot is not found, it will not be included in the dictionary. + """ + return self._get_data_objects( + target_snapshots, + lambda s: exp.to_table( + s.table_name(deployability_index.is_deployable(s)), dialect=s.model.dialect + ), + ) + + def _get_virtual_data_objects( + self, + target_snapshots: t.Iterable[Snapshot], + environment_naming_info: EnvironmentNamingInfo, + ) -> t.Dict[SnapshotId, DataObject]: + """Returns a dictionary of snapshot IDs to existing data objects of their virtual views. + + Args: + target_snapshots: Target snapshots. 
+ environment_naming_info: The environment naming info of the target virtual environment. + + Returns: + A dictionary of snapshot IDs to existing data objects of their virtual views. If the data object + for a snapshot is not found, it will not be included in the dictionary. + """ + + def _get_view_name(s: Snapshot) -> exp.Table: + adapter = ( + self.get_adapter(s.model_gateway) + if environment_naming_info.gateway_managed + else self.adapter + ) + return exp.to_table( + s.qualified_view_name.for_environment( + environment_naming_info, dialect=adapter.dialect + ), + dialect=adapter.dialect, + ) + + return self._get_data_objects(target_snapshots, _get_view_name) + + def _get_data_objects( + self, + target_snapshots: t.Iterable[Snapshot], + table_name_callable: t.Callable[[Snapshot], exp.Table], + ) -> t.Dict[SnapshotId, DataObject]: + """Returns a dictionary of snapshot IDs to existing data objects. + + Args: + target_snapshots: Target snapshots. + table_name_callable: A function that takes a snapshot and returns the table to look for. + + Returns: + A dictionary of snapshot IDs to existing data objects. If the data object for a snapshot is not found, + it will not be included in the dictionary. 
+ """ + tables_by_gateway_and_schema: t.Dict[t.Union[str, None], t.Dict[exp.Table, set[str]]] = ( + defaultdict(lambda: defaultdict(set)) + ) + snapshots_by_table_name: t.Dict[exp.Table, t.Dict[str, Snapshot]] = defaultdict(dict) + for snapshot in target_snapshots: + if not snapshot.is_model or snapshot.is_symbolic: + continue + table = table_name_callable(snapshot) + table_schema = d.schema_(table.db, catalog=table.catalog) + tables_by_gateway_and_schema[snapshot.model_gateway][table_schema].add(table.name) + snapshots_by_table_name[table_schema][table.name] = snapshot + + def _get_data_objects_in_schema( + schema: exp.Table, + object_names: t.Optional[t.Set[str]] = None, + gateway: t.Optional[str] = None, + ) -> t.List[DataObject]: + logger.info("Listing data objects in schema %s", schema.sql()) + return self.get_adapter(gateway).get_data_objects( + schema, object_names, safe_to_cache=True + ) + + with self.concurrent_context(): + snapshot_id_to_obj: t.Dict[SnapshotId, DataObject] = {} + # A schema can be shared across multiple engines, so we need to group tables by both gateway and schema + for gateway, tables_by_schema in tables_by_gateway_and_schema.items(): + schema_list = list(tables_by_schema.keys()) + results = concurrent_apply_to_values( + schema_list, + lambda s: _get_data_objects_in_schema( + schema=s, object_names=tables_by_schema.get(s), gateway=gateway + ), + self.ddl_concurrent_tasks, + ) + + for schema, objs in zip(schema_list, results): + snapshots_by_name = snapshots_by_table_name.get(schema, {}) + for obj in objs: + if obj.name in snapshots_by_name: + snapshot_id_to_obj[snapshots_by_name[obj.name].snapshot_id] = obj + + return snapshot_id_to_obj def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) -> EvaluationStrategy: @@ -911,12 +1662,25 @@ def _evaluation_strategy(snapshot: SnapshotInfoLike, adapter: EngineAdapter) -> klass = ViewStrategy elif snapshot.is_scd_type_2: klass = SCDType2Strategy + elif 
snapshot.is_dbt_custom: + if hasattr(snapshot, "model") and isinstance( + (model_kind := snapshot.model.kind), DbtCustomKind + ): + return DbtCustomMaterializationStrategy( + adapter=adapter, + materialization_name=model_kind.materialization, + materialization_template=model_kind.definition, + ) + + raise SQLMeshError( + f"Expected DbtCustomKind for dbt custom materialization in model '{snapshot.name}'" + ) elif snapshot.is_custom: if snapshot.custom_materialization is None: raise SQLMeshError( f"Missing the name of a custom evaluation strategy in model '{snapshot.name}'." ) - klass = get_custom_materialization_type(snapshot.custom_materialization) + _, klass = get_custom_materialization_type_or_raise(snapshot.custom_materialization) return klass(adapter) elif snapshot.is_managed: klass = EngineManagedStrategy @@ -937,6 +1701,7 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: """Inserts the given query or a DataFrame into the target table or a view. @@ -949,6 +1714,7 @@ def insert( if no data has been previously inserted into the target table, or when the entire history of the target model has been restated. Note that in the latter case, the table might contain data from previous executions, and it is the responsibility of a specific evaluation strategy to handle the truncation of the table if necessary. + render_kwargs: Additional key-value arguments to pass when rendering the model's query. """ @abc.abstractmethod @@ -957,6 +1723,7 @@ def append( table_name: str, query_or_df: QueryOrDF, model: Model, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: """Appends the given query or a DataFrame to the existing table. @@ -965,6 +1732,7 @@ def append( table_name: The target table name. query_or_df: A query or a DataFrame to insert. model: The target model. + render_kwargs: Additional key-value arguments to pass when rendering the model's query. 
""" @abc.abstractmethod @@ -974,6 +1742,7 @@ def create( model: Model, is_table_deployable: bool, render_kwargs: t.Dict[str, t.Any], + skip_grants: bool, **kwargs: t.Any, ) -> None: """Creates the target table or view. @@ -995,6 +1764,9 @@ def migrate( target_table_name: str, source_table_name: str, snapshot: Snapshot, + *, + ignore_destructive: bool, + ignore_additive: bool, **kwargs: t.Any, ) -> None: """Migrates the target table schema so that it corresponds to the source table schema. @@ -1003,6 +1775,10 @@ def migrate( target_table_name: The target table name. source_table_name: The source table name. snapshot: The target snapshot. + ignore_destructive: If True, destructive changes are not created when migrating. + This is used for forward-only models that are being migrated to a new version. + ignore_additive: If True, additive changes are not created when migrating. + This is used for forward-only models that are being migrated to a new version. """ @abc.abstractmethod @@ -1039,27 +1815,84 @@ def demote(self, view_name: str, **kwargs: t.Any) -> None: view_name: The name of the target view in the virtual layer. """ - def _replace_query_for_model(self, model: Model, name: str, query_or_df: QueryOrDF) -> None: - """Replaces the table for the given model. + @abc.abstractmethod + def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None: + """Executes the snapshot's pre statements. Args: - model: The target model. - name: The name of the target table. - query_or_df: The query or DataFrame to replace the target table with. + snapshot: The target snapshot. + render_kwargs: Additional key-value arguments to pass when rendering the statements. 
""" - self.adapter.replace_query( - name, - query_or_df, - columns_to_types=model.columns_to_types if model.annotated else None, - storage_format=model.storage_format, - partitioned_by=model.partitioned_by, - partition_interval_unit=model.interval_unit, - clustered_by=model.clustered_by, - table_properties=model.physical_properties, - table_description=model.description, - column_descriptions=model.column_descriptions, + + @abc.abstractmethod + def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None: + """Executes the snapshot's post statements. + + Args: + snapshot: The target snapshot. + render_kwargs: Additional key-value arguments to pass when rendering the statements. + """ + + def _apply_grants( + self, + model: Model, + table_name: str, + target_layer: GrantsTargetLayer, + is_snapshot_deployable: bool = False, + ) -> None: + """Apply grants for a model if grants are configured. + + This method provides consistent grants application across all evaluation strategies. + It ensures that whenever a physical database object (table, view, materialized view) + is created or modified, the appropriate grants are applied. + + Args: + model: The SQLMesh model containing grants configuration + table_name: The target table/view name to apply grants to + target_layer: The grants application layer (physical or virtual) + is_snapshot_deployable: Whether the snapshot is deployable (targeting production) + """ + grants_config = model.grants + if grants_config is None: + return + + if not self.adapter.SUPPORTS_GRANTS: + logger.warning( + f"Engine {self.adapter.__class__.__name__} does not support grants. 
" + f"Skipping grants application for model {model.name}" + ) + return + + model_grants_target_layer = model.grants_target_layer + deployable_vde_dev_only = ( + is_snapshot_deployable and model.virtual_environment_mode.is_dev_only ) + # table_type is always a VIEW in the virtual layer unless model is deployable and VDE is dev_only + # in which case we fall back to the model's model_grants_table_type + if target_layer == GrantsTargetLayer.VIRTUAL and not deployable_vde_dev_only: + model_grants_table_type = DataObjectType.VIEW + else: + model_grants_table_type = model.grants_table_type + + if ( + model_grants_target_layer.is_all + or model_grants_target_layer == target_layer + # Always apply grants in production when VDE is dev_only regardless of target_layer + # since only physical tables are created in production + or deployable_vde_dev_only + ): + logger.info(f"Applying grants for model {model.name} to table {table_name}") + self.adapter.sync_grants_config( + exp.to_table(table_name, dialect=self.adapter.dialect), + grants_config, + model_grants_table_type, + ) + else: + logger.debug( + f"Skipping grants application for model {model.name} in {target_layer} layer" + ) + class SymbolicStrategy(EvaluationStrategy): def insert( @@ -1068,6 +1901,7 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: pass @@ -1077,6 +1911,7 @@ def append( table_name: str, query_or_df: QueryOrDF, model: Model, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: pass @@ -1087,6 +1922,7 @@ def create( model: Model, is_table_deployable: bool, render_kwargs: t.Dict[str, t.Any], + skip_grants: bool, **kwargs: t.Any, ) -> None: pass @@ -1096,6 +1932,9 @@ def migrate( target_table_name: str, source_table_name: str, snapshot: Snapshot, + *, + ignore_destructive: bool, + ignore_additive: bool, **kwarg: t.Any, ) -> None: pass @@ -1116,6 +1955,12 @@ def promote( def demote(self, view_name: str, 
**kwargs: t.Any) -> None: pass + def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Dict[str, t.Any]) -> None: + pass + + def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Dict[str, t.Any]) -> None: + pass + class EmbeddedStrategy(SymbolicStrategy): def promote( @@ -1130,7 +1975,7 @@ def promote( self.adapter.drop_view(view_name, cascade=False) -class PromotableStrategy(EvaluationStrategy): +class PromotableStrategy(EvaluationStrategy, abc.ABC): def promote( self, table_name: str, @@ -1141,59 +1986,81 @@ def promote( ) -> None: is_prod = environment == c.PROD logger.info("Updating view '%s' to point at table '%s'", view_name, table_name) + render_kwargs: t.Dict[str, t.Any] = dict( + start=kwargs.get("start"), + end=kwargs.get("end"), + execution_time=kwargs.get("execution_time"), + engine_adapter=kwargs.get("engine_adapter"), + snapshots=kwargs.get("snapshots"), + deployability_index=kwargs.get("deployability_index"), + table_mapping=kwargs.get("table_mapping"), + runtime_stage=kwargs.get("runtime_stage"), + ) self.adapter.create_view( view_name, exp.select("*").from_(table_name, dialect=self.adapter.dialect), table_description=model.description if is_prod else None, column_descriptions=model.column_descriptions if is_prod else None, - view_properties=model.virtual_properties, + view_properties=model.render_virtual_properties(**render_kwargs), + ) + + snapshot = kwargs.get("snapshot") + deployability_index = kwargs.get("deployability_index") + is_snapshot_deployable = ( + deployability_index.is_deployable(snapshot) + if snapshot and deployability_index + else False ) + # Apply grants to the virtual layer (view) after promotion + self._apply_grants(model, view_name, GrantsTargetLayer.VIRTUAL, is_snapshot_deployable) + def demote(self, view_name: str, **kwargs: t.Any) -> None: logger.info("Dropping view '%s'", view_name) self.adapter.drop_view(view_name, cascade=False) + def run_pre_statements(self, snapshot: Snapshot, render_kwargs: 
t.Any) -> None: + self.adapter.execute(snapshot.model.render_pre_statements(**render_kwargs)) -class MaterializableStrategy(PromotableStrategy): - def append( - self, - table_name: str, - query_or_df: QueryOrDF, - model: Model, - **kwargs: t.Any, - ) -> None: - self.adapter.insert_append(table_name, query_or_df, columns_to_types=model.columns_to_types) + def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None: + self.adapter.execute(snapshot.model.render_post_statements(**render_kwargs)) + +class MaterializableStrategy(PromotableStrategy, abc.ABC): def create( self, table_name: str, model: Model, is_table_deployable: bool, render_kwargs: t.Dict[str, t.Any], + skip_grants: bool, **kwargs: t.Any, ) -> None: ctas_query = model.ctas_query(**render_kwargs) + physical_properties = kwargs.get("physical_properties", model.physical_properties) logger.info("Creating table '%s'", table_name) if model.annotated: self.adapter.create_table( table_name, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, + table_format=model.table_format, storage_format=model.storage_format, partitioned_by=model.partitioned_by, - partition_interval_unit=model.interval_unit, + partition_interval_unit=model.partition_interval_unit, clustered_by=model.clustered_by, - table_properties=model.physical_properties, + table_properties=physical_properties, table_description=model.description if is_table_deployable else None, column_descriptions=model.column_descriptions if is_table_deployable else None, ) + # If we create both temp and prod tables, we need to make sure that we dry run once. + dry_run = kwargs.get("dry_run", True) or not is_table_deployable + # Only sql models have queries that can be tested. - # Additionally, we always create temp tables and sometimes we additionally created prod tables, - # we need to make sure that we only dry run once. 
# We also need to make sure that we don't dry run on Redshift because its planner / optimizer sometimes # breaks on our CTAS queries due to us relying on the WHERE FALSE LIMIT 0 combo. - if model.is_sql and not is_table_deployable and self.adapter.dialect != "redshift": + if model.is_sql and dry_run and self.adapter.dialect != "redshift": logger.info("Dry running model '%s'", model.name) self.adapter.fetchall(ctas_query) else: @@ -1201,89 +2068,245 @@ def create( table_name, ctas_query, model.columns_to_types, + table_format=model.table_format, storage_format=model.storage_format, partitioned_by=model.partitioned_by, - partition_interval_unit=model.interval_unit, + partition_interval_unit=model.partition_interval_unit, clustered_by=model.clustered_by, - table_properties=model.physical_properties, + table_properties=physical_properties, table_description=model.description if is_table_deployable else None, column_descriptions=model.column_descriptions if is_table_deployable else None, ) + # Apply grants after table creation (unless explicitly skipped by caller) + if not skip_grants: + is_snapshot_deployable = kwargs.get("is_snapshot_deployable", False) + self._apply_grants( + model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) + def migrate( self, target_table_name: str, source_table_name: str, snapshot: Snapshot, + *, + ignore_destructive: bool, + ignore_additive: bool, **kwargs: t.Any, ) -> None: logger.info(f"Altering table '{target_table_name}'") - alter_expressions = self.adapter.get_alter_expressions(target_table_name, source_table_name) + alter_operations = self.adapter.get_alter_operations( + target_table_name, + source_table_name, + ignore_destructive=ignore_destructive, + ignore_additive=ignore_additive, + ) _check_destructive_schema_change( - snapshot, alter_expressions, kwargs["allow_destructive_snapshots"] + snapshot, alter_operations, kwargs["allow_destructive_snapshots"] + ) + _check_additive_schema_change( + snapshot, 
alter_operations, kwargs["allow_additive_snapshots"] + ) + self.adapter.alter_table(alter_operations) + + # Apply grants after schema migration + deployability_index = kwargs.get("deployability_index") + is_snapshot_deployable = ( + deployability_index.is_deployable(snapshot) if deployability_index else False + ) + self._apply_grants( + snapshot.model, target_table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable ) - self.adapter.alter_table(alter_expressions) def delete(self, name: str, **kwargs: t.Any) -> None: - self.adapter.drop_table(name) + _check_table_db_is_physical_schema(name, kwargs["physical_schema"]) + self.adapter.drop_table(name, cascade=kwargs.pop("cascade", False)) logger.info("Dropped table '%s'", name) + def _replace_query_for_model( + self, + model: Model, + name: str, + query_or_df: QueryOrDF, + render_kwargs: t.Dict[str, t.Any], + skip_grants: bool = False, + **kwargs: t.Any, + ) -> None: + """Replaces the table for the given model. -class IncrementalByPartitionStrategy(MaterializableStrategy): - def insert( + Args: + model: The target model. + name: The name of the target table. + query_or_df: The query or DataFrame to replace the target table with. + """ + if (model.is_seed or model.kind.is_full) and model.annotated: + columns_to_types = model.columns_to_types_or_raise + source_columns: t.Optional[t.List[str]] = list(columns_to_types) + else: + try: + # Source columns from the underlying table to prevent unintentional table schema changes during restatement of incremental models. 
+ columns_to_types, source_columns = self._get_target_and_source_columns( + model, name, render_kwargs, force_get_columns_from_target=True + ) + except Exception: + columns_to_types, source_columns = None, None + + self.adapter.replace_query( + name, + query_or_df, + table_format=model.table_format, + storage_format=model.storage_format, + partitioned_by=model.partitioned_by, + partition_interval_unit=model.partition_interval_unit, + clustered_by=model.clustered_by, + table_properties=kwargs.get("physical_properties", model.physical_properties), + table_description=model.description, + column_descriptions=model.column_descriptions, + target_columns_to_types=columns_to_types, + source_columns=source_columns, + ) + + # Apply grants after table replacement (unless explicitly skipped by caller) + if not skip_grants: + is_snapshot_deployable = kwargs.get("is_snapshot_deployable", False) + self._apply_grants(model, name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable) + + def _get_target_and_source_columns( + self, + model: Model, + table_name: str, + render_kwargs: t.Dict[str, t.Any], + force_get_columns_from_target: bool = False, + ) -> t.Tuple[t.Dict[str, exp.DataType], t.Optional[t.List[str]]]: + if force_get_columns_from_target: + target_column_to_types = self.adapter.columns(table_name) + else: + target_column_to_types = ( + model.columns_to_types # type: ignore + if model.annotated + and not model.on_destructive_change.is_ignore + and not model.on_additive_change.is_ignore + else self.adapter.columns(table_name) + ) + assert target_column_to_types is not None + if model.on_destructive_change.is_ignore or model.on_additive_change.is_ignore: + # We need to identify the columns that are only in the source so we create an empty table with + # the user query to determine that + temp_table_name = exp.table_( + "diff", + db=model.physical_schema, + ) + with self.adapter.temp_table( + model.ctas_query(**render_kwargs), name=temp_table_name + ) as temp_table: + 
source_columns = list(self.adapter.columns(temp_table)) + else: + source_columns = None + return target_column_to_types, source_columns + + +class IncrementalStrategy(MaterializableStrategy, abc.ABC): + def append( self, table_name: str, query_or_df: QueryOrDF, model: Model, - is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: - self.adapter.insert_overwrite_by_partition( + columns_to_types, source_columns = self._get_target_and_source_columns( + model, table_name, render_kwargs=render_kwargs + ) + self.adapter.insert_append( table_name, query_or_df, - partitioned_by=model.partitioned_by, - columns_to_types=model.columns_to_types, + target_columns_to_types=columns_to_types, + source_columns=source_columns, ) -class IncrementalByTimeRangeStrategy(MaterializableStrategy): +class IncrementalByPartitionStrategy(IncrementalStrategy): + def insert( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + if is_first_insert: + self._replace_query_for_model(model, table_name, query_or_df, render_kwargs, **kwargs) + else: + columns_to_types, source_columns = self._get_target_and_source_columns( + model, table_name, render_kwargs=render_kwargs + ) + self.adapter.insert_overwrite_by_partition( + table_name, + query_or_df, + partitioned_by=model.partitioned_by, + target_columns_to_types=columns_to_types, + source_columns=source_columns, + ) + + +class IncrementalByTimeRangeStrategy(IncrementalStrategy): def insert( self, table_name: str, query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: assert model.time_column + columns_to_types, source_columns = self._get_target_and_source_columns( + model, table_name, render_kwargs=render_kwargs + ) self.adapter.insert_overwrite_by_time_partition( table_name, query_or_df, time_formatter=model.convert_to_time_column, 
time_column=model.time_column, - columns_to_types=model.columns_to_types, + target_columns_to_types=columns_to_types, + source_columns=source_columns, **kwargs, ) -class IncrementalByUniqueKeyStrategy(MaterializableStrategy): +class IncrementalByUniqueKeyStrategy(IncrementalStrategy): def insert( self, table_name: str, query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: if is_first_insert: - self._replace_query_for_model(model, table_name, query_or_df) + self._replace_query_for_model(model, table_name, query_or_df, render_kwargs, **kwargs) else: + columns_to_types, source_columns = self._get_target_and_source_columns( + model, + table_name, + render_kwargs=render_kwargs, + ) self.adapter.merge( table_name, query_or_df, - columns_to_types=model.columns_to_types, + target_columns_to_types=columns_to_types, unique_key=model.unique_key, when_matched=model.when_matched, + merge_filter=model.render_merge_filter( + start=kwargs.get("start"), + end=kwargs.get("end"), + execution_time=kwargs.get("execution_time"), + ), + physical_properties=kwargs.get("physical_properties", model.physical_properties), + source_columns=source_columns, ) def append( @@ -1291,54 +2314,108 @@ def append( table_name: str, query_or_df: QueryOrDF, model: Model, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: + columns_to_types, source_columns = self._get_target_and_source_columns( + model, table_name, render_kwargs=render_kwargs + ) self.adapter.merge( table_name, query_or_df, - columns_to_types=model.columns_to_types, + target_columns_to_types=columns_to_types, unique_key=model.unique_key, when_matched=model.when_matched, + merge_filter=model.render_merge_filter( + start=kwargs.get("start"), + end=kwargs.get("end"), + execution_time=kwargs.get("execution_time"), + ), + physical_properties=kwargs.get("physical_properties", model.physical_properties), + source_columns=source_columns, ) -class 
IncrementalUnmanagedStrategy(MaterializableStrategy): +class IncrementalUnmanagedStrategy(IncrementalStrategy): + def append( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + columns_to_types, source_columns = self._get_target_and_source_columns( + model, table_name, render_kwargs=render_kwargs + ) + self.adapter.insert_append( + table_name, + query_or_df, + target_columns_to_types=columns_to_types, + source_columns=source_columns, + ) + def insert( self, table_name: str, query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: if is_first_insert: - self._replace_query_for_model(model, table_name, query_or_df) - elif isinstance(model.kind, IncrementalUnmanagedKind) and model.kind.insert_overwrite: - self.adapter.insert_overwrite_by_partition( + return self._replace_query_for_model( + model, table_name, query_or_df, render_kwargs, **kwargs + ) + if isinstance(model.kind, IncrementalUnmanagedKind) and model.kind.insert_overwrite: + columns_to_types, source_columns = self._get_target_and_source_columns( + model, table_name, - query_or_df, - model.partitioned_by, - columns_to_types=model.columns_to_types, + render_kwargs=render_kwargs, ) - else: - self.append( + + return self.adapter.insert_overwrite_by_partition( table_name, query_or_df, - model, - **kwargs, + model.partitioned_by, + target_columns_to_types=columns_to_types, + source_columns=source_columns, ) + return self.append( + table_name, + query_or_df, + model, + render_kwargs=render_kwargs, + **kwargs, + ) + +class FullRefreshStrategy(MaterializableStrategy): + def append( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + self.adapter.insert_append( + table_name, + query_or_df, + target_columns_to_types=model.columns_to_types, + ) -class 
FullRefreshStrategy(MaterializableStrategy): def insert( self, table_name: str, query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: - self._replace_query_for_model(model, table_name, query_or_df) + self._replace_query_for_model(model, table_name, query_or_df, render_kwargs, **kwargs) class SeedStrategy(MaterializableStrategy): @@ -1348,6 +2425,7 @@ def create( model: Model, is_table_deployable: bool, render_kwargs: t.Dict[str, t.Any], + skip_grants: bool, **kwargs: t.Any, ) -> None: model = t.cast(SeedModel, model) @@ -1361,20 +2439,52 @@ def create( ) return - super().create(table_name, model, is_table_deployable, render_kwargs, **kwargs) - if is_table_deployable: - # For seeds we insert data at the time of table creation. - try: - for index, df in enumerate(model.render_seed()): - if index == 0: - self._replace_query_for_model(model, table_name, df) - else: - self.adapter.insert_append( - table_name, df, columns_to_types=model.columns_to_types - ) - except Exception: - self.adapter.drop_table(table_name) - raise + super().create( + table_name, + model, + is_table_deployable, + render_kwargs, + skip_grants=True, # Skip grants; they're applied after data insertion + **kwargs, + ) + # For seeds we insert data at the time of table creation. 
+ try: + for index, df in enumerate(model.render_seed()): + if index == 0: + self._replace_query_for_model( + model, + table_name, + df, + render_kwargs, + skip_grants=True, # Skip grants; they're applied after data insertion + **kwargs, + ) + else: + self.adapter.insert_append( + table_name, df, target_columns_to_types=model.columns_to_types + ) + + if not skip_grants: + # Apply grants after seed table creation and data insertion + is_snapshot_deployable = kwargs.get("is_snapshot_deployable", False) + self._apply_grants( + model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) + except Exception: + self.adapter.drop_table(table_name) + raise + + def migrate( + self, + target_table_name: str, + source_table_name: str, + snapshot: Snapshot, + *, + ignore_destructive: bool, + ignore_additive: bool, + **kwargs: t.Any, + ) -> None: + raise NotImplementedError("Seeds do not support migrations.") def insert( self, @@ -1382,19 +2492,32 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + # Data has already been inserted at the time of table creation. + pass + + def append( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: # Data has already been inserted at the time of table creation. 
pass -class SCDType2Strategy(MaterializableStrategy): +class SCDType2Strategy(IncrementalStrategy): def create( self, table_name: str, model: Model, is_table_deployable: bool, render_kwargs: t.Dict[str, t.Any], + skip_grants: bool, **kwargs: t.Any, ) -> None: assert isinstance(model.kind, (SCDType2ByTimeKind, SCDType2ByColumnKind)) @@ -1405,12 +2528,13 @@ def create( columns_to_types[model.kind.updated_at_name.name] = model.kind.time_data_type self.adapter.create_table( table_name, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, + table_format=model.table_format, storage_format=model.storage_format, partitioned_by=model.partitioned_by, - partition_interval_unit=model.interval_unit, + partition_interval_unit=model.partition_interval_unit, clustered_by=model.clustered_by, - table_properties=model.physical_properties, + table_properties=kwargs.get("physical_properties", model.physical_properties), table_description=model.description if is_table_deployable else None, column_descriptions=model.column_descriptions if is_table_deployable else None, ) @@ -1423,17 +2547,33 @@ def create( model, is_table_deployable, render_kwargs, + skip_grants, **kwargs, ) + if not skip_grants: + # Apply grants after SCD Type 2 table creation + is_snapshot_deployable = kwargs.get("is_snapshot_deployable", False) + self._apply_grants( + model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) + def insert( self, table_name: str, query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: + # Source columns from the underlying table to prevent unintentional table schema changes during restatement of incremental models. 
+ columns_to_types, source_columns = self._get_target_and_source_columns( + model, + table_name, + render_kwargs=render_kwargs, + force_get_columns_from_target=True, + ) if isinstance(model.kind, SCDType2ByTimeKind): self.adapter.scd_type_2_by_time( target_table=table_name, @@ -1445,10 +2585,17 @@ def insert( updated_at_col=model.kind.updated_at_name, invalidate_hard_deletes=model.kind.invalidate_hard_deletes, updated_at_as_valid_from=model.kind.updated_at_as_valid_from, - columns_to_types=model.columns_to_types, + target_columns_to_types=columns_to_types, + table_format=model.table_format, table_description=model.description, column_descriptions=model.column_descriptions, truncate=is_first_insert, + source_columns=source_columns, + storage_format=model.storage_format, + partitioned_by=model.partitioned_by, + partition_interval_unit=model.partition_interval_unit, + clustered_by=model.clustered_by, + table_properties=kwargs.get("physical_properties", model.physical_properties), ) elif isinstance(model.kind, SCDType2ByColumnKind): self.adapter.scd_type_2_by_column( @@ -1457,61 +2604,47 @@ def insert( unique_key=model.unique_key, valid_from_col=model.kind.valid_from_name, valid_to_col=model.kind.valid_to_name, - execution_time=kwargs["execution_time"], + execution_time=model.kind.updated_at_name or kwargs["execution_time"], check_columns=model.kind.columns, invalidate_hard_deletes=model.kind.invalidate_hard_deletes, execution_time_as_valid_from=model.kind.execution_time_as_valid_from, - columns_to_types=model.columns_to_types, + target_columns_to_types=columns_to_types, + table_format=model.table_format, table_description=model.description, column_descriptions=model.column_descriptions, truncate=is_first_insert, + source_columns=source_columns, + storage_format=model.storage_format, + partitioned_by=model.partitioned_by, + partition_interval_unit=model.partition_interval_unit, + clustered_by=model.clustered_by, + table_properties=kwargs.get("physical_properties", 
model.physical_properties), ) else: raise SQLMeshError( f"Unexpected SCD Type 2 kind: {model.kind}. This is not expected and please report this as a bug." ) + # Apply grants after SCD Type 2 table recreation + is_snapshot_deployable = kwargs.get("is_snapshot_deployable", False) + self._apply_grants(model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable) + def append( self, table_name: str, query_or_df: QueryOrDF, model: Model, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: - if isinstance(model.kind, SCDType2ByTimeKind): - self.adapter.scd_type_2_by_time( - target_table=table_name, - source_table=query_or_df, - unique_key=model.unique_key, - valid_from_col=model.kind.valid_from_name, - valid_to_col=model.kind.valid_to_name, - updated_at_col=model.kind.updated_at_name, - invalidate_hard_deletes=model.kind.invalidate_hard_deletes, - updated_at_as_valid_from=model.kind.updated_at_as_valid_from, - columns_to_types=model.columns_to_types, - table_description=model.description, - column_descriptions=model.column_descriptions, - **kwargs, - ) - elif isinstance(model.kind, SCDType2ByColumnKind): - self.adapter.scd_type_2_by_column( - target_table=table_name, - source_table=query_or_df, - unique_key=model.unique_key, - valid_from_col=model.kind.valid_from_name, - valid_to_col=model.kind.valid_to_name, - check_columns=model.kind.columns, - columns_to_types=model.columns_to_types, - invalidate_hard_deletes=model.kind.invalidate_hard_deletes, - execution_time_as_valid_from=model.kind.execution_time_as_valid_from, - table_description=model.description, - column_descriptions=model.column_descriptions, - **kwargs, - ) - else: - raise SQLMeshError( - f"Unexpected SCD Type 2 kind: {model.kind}. This is not expected and please report this as a bug." 
- ) + return self.insert( + table_name, + query_or_df, + model, + is_first_insert=False, + render_kwargs=render_kwargs, + **kwargs, + ) class ViewStrategy(PromotableStrategy): @@ -1521,27 +2654,22 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: - deployability_index = ( - kwargs.get("deployability_index") or DeployabilityIndex.all_deployable() + # We should recreate MVs across supported engines (Snowflake, BigQuery etc) because + # if upstream tables were recreated (e.g FULL models), the MVs would be silently invalidated. + # The only exception to that rule is RisingWave which doesn't support CREATE OR REPLACE, so upstream + # models don't recreate their physical tables for the MVs to be invalidated. + # However, even for RW we still want to recreate MVs to avoid stale references, as is the case with normal views. + # The flag is_first_insert is used for that matter as a signal to recreate the MV if the snapshot's intervals + # have been cleared by `should_force_rebuild` + is_materialized_view = self._is_materialized_view(model) + must_recreate_view = not self.adapter.HAS_VIEW_BINDING or ( + is_materialized_view and is_first_insert ) - snapshot = kwargs["snapshot"] - snapshots = kwargs["snapshots"] - if ( - ( - isinstance(query_or_df, exp.Expression) - and snapshot.is_materialized_view - and deployability_index.is_deployable(snapshot) - and model.render_query( - snapshots=snapshots, - deployability_index=deployability_index, - engine_adapter=self.adapter, - ) - == query_or_df - ) - or self.adapter.HAS_VIEW_BINDING - ) and self.adapter.table_exists(table_name): + + if self.adapter.table_exists(table_name) and not must_recreate_view: logger.info("Skipping creation of the view '%s'", table_name) return @@ -1550,18 +2678,23 @@ def insert( table_name, query_or_df, model.columns_to_types, - replace=not self.adapter.HAS_VIEW_BINDING, - 
materialized=self._is_materialized_view(model), - view_properties=model.physical_properties, + replace=must_recreate_view, + materialized=is_materialized_view, + view_properties=kwargs.get("physical_properties", model.physical_properties), table_description=model.description, column_descriptions=model.column_descriptions, ) + # Apply grants after view creation / replacement + is_snapshot_deployable = kwargs.get("is_snapshot_deployable", False) + self._apply_grants(model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable) + def append( self, table_name: str, query_or_df: QueryOrDF, model: Model, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: raise ConfigError(f"Cannot append to a view '{table_name}'.") @@ -1572,74 +2705,111 @@ def create( model: Model, is_table_deployable: bool, render_kwargs: t.Dict[str, t.Any], + skip_grants: bool, **kwargs: t.Any, ) -> None: - is_snapshot_deployable: bool = kwargs["is_snapshot_deployable"] - if not is_snapshot_deployable and is_table_deployable: - # If the snapshot is not deployable, the query may contain references to non-deployable tables or views. - # Therefore, we postpone the creation of the deployable view until the snapshot is deployed to production. - logger.info( - "Skipping creation of the deployable view '%s' for the non-deployable snapshot", - table_name, - ) - return + is_snapshot_deployable = kwargs.get("is_snapshot_deployable", False) if self.adapter.table_exists(table_name): # Make sure we don't recreate the view to prevent deletion of downstream views in engines with no late # binding support (because of DROP CASCADE). 
logger.info("View '%s' already exists", table_name) + + if not skip_grants: + # Always apply grants when present, even if view exists, to handle grants updates + self._apply_grants( + model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) return logger.info("Creating view '%s'", table_name) + materialized = self._is_materialized_view(model) + materialized_properties = None + if materialized: + materialized_properties = { + "partitioned_by": model.partitioned_by, + "clustered_by": model.clustered_by, + "partition_interval_unit": model.partition_interval_unit, + } self.adapter.create_view( table_name, model.render_query_or_raise(**render_kwargs), # Make sure we never replace the view during creation to avoid race conditions in engines with no late binding support. replace=False, materialized=self._is_materialized_view(model), - view_properties=model.physical_properties, + materialized_properties=materialized_properties, + view_properties=kwargs.get("physical_properties", model.physical_properties), table_description=model.description if is_table_deployable else None, column_descriptions=model.column_descriptions if is_table_deployable else None, ) + if not skip_grants: + # Apply grants after view creation + self._apply_grants( + model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) + def migrate( self, target_table_name: str, source_table_name: str, snapshot: Snapshot, + *, + ignore_destructive: bool, + ignore_additive: bool, **kwargs: t.Any, ) -> None: logger.info("Migrating view '%s'", target_table_name) model = snapshot.model + render_kwargs = dict( + execution_time=now(), snapshots=kwargs["snapshots"], engine_adapter=self.adapter + ) + self.adapter.create_view( target_table_name, - model.render_query_or_raise( - execution_time=now(), snapshots=kwargs["snapshots"], engine_adapter=self.adapter - ), + model.render_query_or_raise(**render_kwargs), model.columns_to_types, materialized=self._is_materialized_view(model), - 
view_properties=model.physical_properties, + view_properties=model.render_physical_properties(**render_kwargs), table_description=model.description, column_descriptions=model.column_descriptions, ) + # Apply grants after view migration + deployability_index = kwargs.get("deployability_index") + is_snapshot_deployable = ( + deployability_index.is_deployable(snapshot) if deployability_index else False + ) + self._apply_grants( + snapshot.model, target_table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) + def delete(self, name: str, **kwargs: t.Any) -> None: + cascade = kwargs.pop("cascade", False) try: - self.adapter.drop_view(name) + # Some engines (e.g., RisingWave) don’t fail when dropping a materialized view with a DROP VIEW statement, + # because views and materialized views don’t share the same namespace. Therefore, we should not ignore if the + # view doesn't exist and let the exception handler attempt to drop the materialized view. + self.adapter.drop_view(name, cascade=cascade, ignore_if_not_exists=False) except Exception: logger.debug( "Failed to drop view '%s'. Trying to drop the materialized view instead", name, exc_info=True, ) - self.adapter.drop_view(name, materialized=True) + self.adapter.drop_view( + name, materialized=True, cascade=cascade, ignore_if_not_exists=True + ) logger.info("Dropped view '%s'", name) def _is_materialized_view(self, model: Model) -> bool: return isinstance(model.kind, ViewKind) and model.kind.materialized -class CustomMaterialization(MaterializableStrategy): +C = t.TypeVar("C", bound=CustomKind) + + +class CustomMaterialization(IncrementalStrategy, t.Generic[C]): """Base class for custom materializations.""" def insert( @@ -1648,6 +2818,7 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: """Inserts the given query or a DataFrame into the target table or a view. 
@@ -1660,46 +2831,259 @@ def insert( if no data has been previously inserted into the target table, or when the entire history of the target model has been restated. Note that in the latter case, the table might contain data from previous executions, and it is the responsibility of a specific evaluation strategy to handle the truncation of the table if necessary. + render_kwargs: Additional key-value arguments to pass when rendering the model's query. """ raise NotImplementedError( "Custom materialization strategies must implement the 'insert' method." ) -_custom_materialization_type_cache: t.Optional[t.Dict[str, t.Type[CustomMaterialization]]] = None +_custom_materialization_type_cache: t.Optional[ + t.Dict[str, t.Tuple[t.Type[CustomKind], t.Type[CustomMaterialization]]] +] = None + + +def get_custom_materialization_kind_type(st: t.Type[CustomMaterialization]) -> t.Type[CustomKind]: + # try to read if there is a custom 'kind' type in use by inspecting the type signature + # eg try to read 'MyCustomKind' from: + # >>>> class MyCustomMaterialization(CustomMaterialization[MyCustomKind]) + # and fall back to base CustomKind if there is no generic type declared + if hasattr(st, "__orig_bases__"): + for base in st.__orig_bases__: + if hasattr(base, "__origin__") and base.__origin__ == CustomMaterialization: + for generic_arg in t.get_args(base): + if not issubclass(generic_arg, CustomKind): + raise SQLMeshError( + f"Custom materialization kind '{generic_arg.__name__}' must be a subclass of CustomKind" + ) + + return generic_arg + + return CustomKind -def get_custom_materialization_type(name: str) -> t.Type[CustomMaterialization]: +def get_custom_materialization_type( + name: str, raise_errors: bool = True +) -> t.Optional[t.Tuple[t.Type[CustomKind], t.Type[CustomMaterialization]]]: global _custom_materialization_type_cache strategy_key = name.lower() - if ( - _custom_materialization_type_cache is None - or strategy_key not in _custom_materialization_type_cache - ): - 
strategy_types = list(CustomMaterialization.__subclasses__()) - - entry_points = metadata.entry_points(group="sqlmesh.materializations") - for entry_point in entry_points: - strategy_type = entry_point.load() - if not issubclass(strategy_type, CustomMaterialization): - raise SQLMeshError( - f"Custom materialization entry point '{entry_point.name}' must be a subclass of CustomMaterialization." + try: + if ( + _custom_materialization_type_cache is None + or strategy_key not in _custom_materialization_type_cache + ): + strategy_types = list(CustomMaterialization.__subclasses__()) + + entry_points = metadata.entry_points(group="sqlmesh.materializations") + for entry_point in entry_points: + strategy_type = entry_point.load() + if not issubclass(strategy_type, CustomMaterialization): + raise SQLMeshError( + f"Custom materialization entry point '{entry_point.name}' must be a subclass of CustomMaterialization." + ) + strategy_types.append(strategy_type) + + _custom_materialization_type_cache = { + getattr(strategy_type, "NAME", strategy_type.__name__).lower(): ( + get_custom_materialization_kind_type(strategy_type), + strategy_type, ) - strategy_types.append(strategy_type) + for strategy_type in strategy_types + } + + if strategy_key not in _custom_materialization_type_cache: + raise ConfigError(f"Materialization strategy with name '{name}' was not found.") + except (SQLMeshError, ConfigError) as e: + if raise_errors: + raise e + + from sqlmesh.core.console import get_console + + get_console().log_warning(str(e)) + return None + + strategy_kind_type, strategy_type = _custom_materialization_type_cache[strategy_key] + logger.debug( + "Resolved custom materialization '%s' to '%s' (%s)", name, strategy_type, strategy_kind_type + ) + + return strategy_kind_type, strategy_type + + +def get_custom_materialization_type_or_raise( + name: str, +) -> t.Tuple[t.Type[CustomKind], t.Type[CustomMaterialization]]: + types = get_custom_materialization_type(name, raise_errors=True) + if 
types is not None: + return types[0], types[1] + + # Shouldnt get here as get_custom_materialization_type() has raise_errors=True, but just in case... + raise SQLMeshError(f"Custom materialization '{name}' not present in the Python environment") + + +class DbtCustomMaterializationStrategy(MaterializableStrategy): + def __init__( + self, + adapter: EngineAdapter, + materialization_name: str, + materialization_template: str, + ): + super().__init__(adapter) + self.materialization_name = materialization_name + self.materialization_template = materialization_template + + def create( + self, + table_name: str, + model: Model, + is_table_deployable: bool, + render_kwargs: t.Dict[str, t.Any], + skip_grants: bool, + **kwargs: t.Any, + ) -> None: + original_query = model.render_query_or_raise(**render_kwargs) + self._execute_materialization( + table_name=table_name, + query_or_df=original_query.limit(0), + model=model, + is_first_insert=True, + render_kwargs=render_kwargs, + create_only=True, + **kwargs, + ) + + # Apply grants after dbt custom materialization table creation + if not skip_grants: + is_snapshot_deployable = kwargs.get("is_snapshot_deployable", False) + self._apply_grants( + model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) + + def insert( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + self._execute_materialization( + table_name=table_name, + query_or_df=query_or_df, + model=model, + is_first_insert=is_first_insert, + render_kwargs=render_kwargs, + **kwargs, + ) + + # Apply grants after custom materialization insert (only on first insert) + if is_first_insert: + is_snapshot_deployable = kwargs.get("is_snapshot_deployable", False) + self._apply_grants( + model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) + + def append( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + 
render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + return self.insert( + table_name, + query_or_df, + model, + is_first_insert=False, + render_kwargs=render_kwargs, + **kwargs, + ) + + def run_pre_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None: + # in dbt custom materialisations it's up to the user to run the pre hooks inside the transaction + if not render_kwargs.get("inside_transaction", True): + super().run_pre_statements( + snapshot=snapshot, + render_kwargs=render_kwargs, + ) - _custom_materialization_type_cache = { - getattr(strategy_type, "NAME", strategy_type.__name__).lower(): strategy_type - for strategy_type in strategy_types + def run_post_statements(self, snapshot: Snapshot, render_kwargs: t.Any) -> None: + # in dbt custom materialisations it's up to the user to run the post hooks inside the transaction + if not render_kwargs.get("inside_transaction", True): + super().run_post_statements( + snapshot=snapshot, + render_kwargs=render_kwargs, + ) + + def _execute_materialization( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], + create_only: bool = False, + **kwargs: t.Any, + ) -> None: + jinja_macros = model.jinja_macros + + # For vdes we need to use the table, since we don't know the schema/table at parse time + parts = exp.to_table(table_name, dialect=self.adapter.dialect) + + existing_globals = jinja_macros.global_objs + relation_info = existing_globals.get("this") + if isinstance(relation_info, dict): + relation_info["database"] = parts.catalog + relation_info["identifier"] = parts.name + relation_info["name"] = parts.name + + jinja_globals = { + **existing_globals, + "this": relation_info, + "database": parts.catalog, + "schema": parts.db, + "identifier": parts.name, + "target": existing_globals.get("target", {"type": self.adapter.dialect}), + "execution_dt": kwargs.get("execution_time"), + "engine_adapter": self.adapter, + "sql": 
str(query_or_df), + "is_first_insert": is_first_insert, + "create_only": create_only, + "pre_hooks": [ + AttributeDict({"sql": s.this.this, "transaction": transaction}) + for s in model.pre_statements + if (transaction := s.args.get("transaction", True)) + ], + "post_hooks": [ + AttributeDict({"sql": s.this.this, "transaction": transaction}) + for s in model.post_statements + if (transaction := s.args.get("transaction", True)) + ], + "model_instance": model, + **kwargs, } - if strategy_key not in _custom_materialization_type_cache: - raise ConfigError(f"Materialization strategy with name '{name}' was not found.") + try: + jinja_env = jinja_macros.build_environment(**jinja_globals) + template = jinja_env.from_string(self.materialization_template) + + try: + template.render() + except MacroReturnVal as ret: + # this is a successful return from a macro call (dbt uses this list of Relations to update their relation cache) + returned_relations = ret.value.get("relations", []) + logger.info( + f"Materialization {self.materialization_name} returned relations: {returned_relations}" + ) - strategy_type = _custom_materialization_type_cache[strategy_key] - logger.debug("Resolved custom materialization '%s' to '%s'", name, strategy_type) - return strategy_type + except Exception as e: + raise SQLMeshError( + f"Failed to execute dbt materialization '{self.materialization_name}': {e}" + ) from e class EngineManagedStrategy(MaterializableStrategy): @@ -1709,6 +3093,7 @@ def create( model: Model, is_table_deployable: bool, render_kwargs: t.Dict[str, t.Any], + skip_grants: bool, **kwargs: t.Any, ) -> None: is_snapshot_deployable: bool = kwargs["is_snapshot_deployable"] @@ -1719,13 +3104,21 @@ def create( self.adapter.create_managed_table( table_name=table_name, query=model.render_query_or_raise(**render_kwargs), - columns_to_types=model.columns_to_types, + target_columns_to_types=model.columns_to_types, partitioned_by=model.partitioned_by, clustered_by=model.clustered_by, - 
table_properties=model.physical_properties, + table_properties=kwargs.get("physical_properties", model.physical_properties), table_description=model.description, column_descriptions=model.column_descriptions, + table_format=model.table_format, ) + + # Apply grants after managed table creation + if not skip_grants: + self._apply_grants( + model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) + elif not is_table_deployable: # Only create the dev preview table as a normal table. # For the main table, if the snapshot is cant be deployed to prod (eg upstream is forward-only) do nothing. @@ -1736,6 +3129,7 @@ def create( model=model, is_table_deployable=is_table_deployable, render_kwargs=render_kwargs, + skip_grants=skip_grants, **kwargs, ) @@ -1745,36 +3139,49 @@ def insert( query_or_df: QueryOrDF, model: Model, is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: deployability_index: DeployabilityIndex = kwargs["deployability_index"] snapshot: Snapshot = kwargs["snapshot"] is_snapshot_deployable = deployability_index.is_deployable(snapshot) - if is_first_insert and is_snapshot_deployable and not self.adapter.table_exists(table_name): self.adapter.create_managed_table( table_name=table_name, query=query_or_df, # type: ignore - columns_to_types=model.columns_to_types, + target_columns_to_types=model.columns_to_types, partitioned_by=model.partitioned_by, clustered_by=model.clustered_by, - table_properties=model.physical_properties, + table_properties=kwargs.get("physical_properties", model.physical_properties), table_description=model.description, column_descriptions=model.column_descriptions, + table_format=model.table_format, + ) + self._apply_grants( + model, table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable ) elif not is_snapshot_deployable: # Snapshot isnt deployable; update the preview table instead # If the snapshot was deployable, then data would have already been loaded in create() because a 
managed table would have been created logger.info( - "Updating preview table: %s (for managed model: %s)", table_name, model.name + "Updating preview table: %s (for managed model: %s)", + table_name, + model.name, + ) + self._replace_query_for_model( + model=model, + name=table_name, + query_or_df=query_or_df, + render_kwargs=render_kwargs, + **kwargs, ) - self._replace_query_for_model(model=model, name=table_name, query_or_df=query_or_df) def append( self, table_name: str, query_or_df: QueryOrDF, model: Model, + render_kwargs: t.Dict[str, t.Any], **kwargs: t.Any, ) -> None: raise ConfigError(f"Cannot append to a managed table '{table_name}'.") @@ -1784,14 +3191,35 @@ def migrate( target_table_name: str, source_table_name: str, snapshot: Snapshot, + *, + ignore_destructive: bool, + ignore_additive: bool, **kwargs: t.Any, ) -> None: - # Not entirely true, many engines support modifying some of the metadata fields on a managed table - # eg Snowflake allows you to ALTER DYNAMIC TABLE foo SET WAREHOUSE=my_other_wh; - raise ConfigError(f"Cannot mutate managed table: {target_table_name}") + potential_alter_operations = self.adapter.get_alter_operations( + target_table_name, + source_table_name, + ignore_destructive=ignore_destructive, + ignore_additive=ignore_additive, + ) + if len(potential_alter_operations) > 0: + # this can happen if a user changes a managed model and deliberately overrides a plan to be forward only, eg `sqlmesh plan --forward-only` + raise MigrationNotSupportedError( + f"The schema of the managed model '{target_table_name}' cannot be updated in a forward-only fashion." 
+ ) + + # Apply grants after verifying no schema changes + deployability_index = kwargs.get("deployability_index") + is_snapshot_deployable = ( + deployability_index.is_deployable(snapshot) if deployability_index else False + ) + self._apply_grants( + snapshot.model, target_table_name, GrantsTargetLayer.PHYSICAL, is_snapshot_deployable + ) def delete(self, name: str, **kwargs: t.Any) -> None: # a dev preview table is created as a normal table, so it needs to be dropped as a normal table + _check_table_db_is_physical_schema(name, kwargs["physical_schema"]) if kwargs["is_table_deployable"]: self.adapter.drop_managed_table(name) logger.info("Dropped managed table '%s'", name) @@ -1810,18 +3238,80 @@ def _intervals(snapshot: Snapshot, deployability_index: DeployabilityIndex) -> I def _check_destructive_schema_change( snapshot: Snapshot, - alter_expressions: t.List[exp.AlterTable], + alter_operations: t.List[TableAlterOperation], allow_destructive_snapshots: t.Set[str], ) -> None: - if snapshot.needs_destructive_check(allow_destructive_snapshots) and has_drop_alteration( - alter_expressions + if ( + snapshot.is_no_rebuild + and snapshot.needs_destructive_check(allow_destructive_snapshots) + and has_drop_alteration(alter_operations) ): - warning_msg = ( - f"Plan results in a destructive change to forward-only table '{snapshot.name}'s schema." 
- ) + snapshot_name = snapshot.name + model_dialect = snapshot.model.dialect + if snapshot.model.on_destructive_change.is_warn: - logger.warning(warning_msg) + logger.warning( + format_destructive_change_msg( + snapshot_name, + alter_operations, + model_dialect, + error=False, + ) + ) + return + raise DestructiveChangeError( + format_destructive_change_msg(snapshot_name, alter_operations, model_dialect) + ) + + +def _check_additive_schema_change( + snapshot: Snapshot, + alter_operations: t.List[TableAlterOperation], + allow_additive_snapshots: t.Set[str], +) -> None: + # Only check additive changes for incremental models that have the on_additive_change property + if not isinstance(snapshot.model.kind, _Incremental): + return + + if snapshot.needs_additive_check(allow_additive_snapshots) and has_additive_alteration( + alter_operations + ): + # Note: IGNORE filtering is applied before this function is called + # so if we reach here, additive changes are not being ignored + snapshot_name = snapshot.name + model_dialect = snapshot.model.dialect + + if snapshot.model.on_additive_change.is_warn: + logger.warning( + format_additive_change_msg( + snapshot_name, + alter_operations, + model_dialect, + error=False, + ) + ) return + if snapshot.model.on_additive_change.is_error: + raise AdditiveChangeError( + format_additive_change_msg(snapshot_name, alter_operations, model_dialect) + ) + + +def _check_table_db_is_physical_schema(table_name: str, physical_schema: str) -> None: + table = exp.to_table(table_name) + if table.db != physical_schema: raise SQLMeshError( - f"{warning_msg} To allow this, change the model's `on_destructive_change` setting to `warn` or `allow` or include it in the plan's `--allow-destructive-model` option." + f"Table '{table_name}' is not a part of the physical schema '{physical_schema}' and so can't be dropped." 
) + + +def _snapshot_to_data_object_type(snapshot: Snapshot) -> DataObjectType: + if snapshot.is_managed: + return DataObjectType.MANAGED_TABLE + if snapshot.is_materialized_view: + return DataObjectType.MATERIALIZED_VIEW + if snapshot.is_view: + return DataObjectType.VIEW + if snapshot.is_materialized: + return DataObjectType.TABLE + return DataObjectType.UNKNOWN diff --git a/sqlmesh/core/snapshot/execution_tracker.py b/sqlmesh/core/snapshot/execution_tracker.py new file mode 100644 index 0000000000..bcafec8d28 --- /dev/null +++ b/sqlmesh/core/snapshot/execution_tracker.py @@ -0,0 +1,97 @@ +from __future__ import annotations + +import typing as t +from contextlib import contextmanager +from threading import local +from dataclasses import dataclass, field +from sqlmesh.core.snapshot import SnapshotIdBatch + + +@dataclass +class QueryExecutionStats: + snapshot_id_batch: SnapshotIdBatch + total_rows_processed: t.Optional[int] = None + total_bytes_processed: t.Optional[int] = None + + +@dataclass +class QueryExecutionContext: + """ + Container for tracking rows processed or other execution information during snapshot evaluation. + + It accumulates statistics from multiple cursor.execute() calls during a single snapshot evaluation. 
+ + Attributes: + snapshot_id_batch: Identifier linking this context to a specific snapshot evaluation + stats: Running sum of cursor.rowcount and possibly bytes processed from all executed queries during evaluation + """ + + snapshot_id_batch: SnapshotIdBatch + stats: QueryExecutionStats = field(init=False) + + def __post_init__(self) -> None: + self.stats = QueryExecutionStats(snapshot_id_batch=self.snapshot_id_batch) + + def add_execution( + self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int] + ) -> None: + if row_count is not None and row_count >= 0: + if self.stats.total_rows_processed is None: + self.stats.total_rows_processed = row_count + else: + self.stats.total_rows_processed += row_count + + # conditional on row_count because we should only count bytes corresponding to + # DML actions whose rows were captured + if bytes_processed is not None: + if self.stats.total_bytes_processed is None: + self.stats.total_bytes_processed = bytes_processed + else: + self.stats.total_bytes_processed += bytes_processed + + def get_execution_stats(self) -> QueryExecutionStats: + return self.stats + + +class QueryExecutionTracker: + """Thread-local context manager for snapshot execution statistics, such as rows processed.""" + + def __init__(self) -> None: + self._thread_local = local() + self._contexts: t.Dict[SnapshotIdBatch, QueryExecutionContext] = {} + + def get_execution_context( + self, snapshot_id_batch: SnapshotIdBatch + ) -> t.Optional[QueryExecutionContext]: + return self._contexts.get(snapshot_id_batch) + + def is_tracking(self) -> bool: + return getattr(self._thread_local, "context", None) is not None + + @contextmanager + def track_execution( + self, snapshot_id_batch: SnapshotIdBatch + ) -> t.Iterator[t.Optional[QueryExecutionContext]]: + """Context manager for tracking snapshot execution statistics such as row counts and bytes processed.""" + context = QueryExecutionContext(snapshot_id_batch=snapshot_id_batch) + 
self._thread_local.context = context + self._contexts[snapshot_id_batch] = context + + try: + yield context + finally: + self._thread_local.context = None + + def record_execution( + self, sql: str, row_count: t.Optional[int], bytes_processed: t.Optional[int] + ) -> None: + context = getattr(self._thread_local, "context", None) + if context is not None: + context.add_execution(sql, row_count, bytes_processed) + + def get_execution_stats( + self, snapshot_id_batch: SnapshotIdBatch + ) -> t.Optional[QueryExecutionStats]: + context = self._contexts.get(snapshot_id_batch) + self._contexts.pop(snapshot_id_batch, None) + return context.get_execution_stats() if context else None diff --git a/sqlmesh/core/state_sync/__init__.py b/sqlmesh/core/state_sync/__init__.py index 8b78e5749b..12ea77ac8f 100644 --- a/sqlmesh/core/state_sync/__init__.py +++ b/sqlmesh/core/state_sync/__init__.py @@ -20,8 +20,4 @@ Versions as Versions, ) from sqlmesh.core.state_sync.cache import CachingStateSync as CachingStateSync -from sqlmesh.core.state_sync.common import ( - CommonStateSyncMixin as CommonStateSyncMixin, - cleanup_expired_views as cleanup_expired_views, -) -from sqlmesh.core.state_sync.engine_adapter import EngineAdapterStateSync as EngineAdapterStateSync +from sqlmesh.core.state_sync.db import EngineAdapterStateSync as EngineAdapterStateSync diff --git a/sqlmesh/core/state_sync/base.py b/sqlmesh/core/state_sync/base.py index 08724d273d..3c8c72845d 100644 --- a/sqlmesh/core/state_sync/base.py +++ b/sqlmesh/core/state_sync/base.py @@ -9,23 +9,30 @@ from sqlglot import __version__ as SQLGLOT_VERSION from sqlmesh import migrations -from sqlmesh.core.environment import Environment, EnvironmentNamingInfo +from sqlmesh.core.environment import ( + Environment, + EnvironmentStatements, + EnvironmentSummary, +) from sqlmesh.core.snapshot import ( Snapshot, SnapshotId, SnapshotIdLike, + SnapshotIdAndVersionLike, SnapshotInfoLike, - SnapshotTableCleanupTask, - SnapshotTableInfo, + 
SnapshotNameVersion, + SnapshotIdAndVersion, ) from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals from sqlmesh.utils import major_minor from sqlmesh.utils.date import TimeLike from sqlmesh.utils.errors import SQLMeshError -from sqlmesh.utils.pydantic import ( - PydanticModel, - field_validator, - field_validator_v1_args, +from sqlmesh.utils.pydantic import PydanticModel, field_validator +from sqlmesh.core.state_sync.common import ( + StateStream, + ExpiredSnapshotBatch, + PromotionResult, + ExpiredBatchRange, ) logger = logging.getLogger(__name__) @@ -57,26 +64,14 @@ def _schema_version_validator(cls, v: t.Any) -> int: return 0 if v is None else int(v) +MIN_SCHEMA_VERSION = 60 +MIN_SQLMESH_VERSION = "0.134.0" MIGRATIONS = [ importlib.import_module(f"sqlmesh.migrations.{migration}") for migration in sorted(info.name for info in pkgutil.iter_modules(migrations.__path__)) ] -SCHEMA_VERSION: int = len(MIGRATIONS) - - -class PromotionResult(PydanticModel): - added: t.List[SnapshotTableInfo] - removed: t.List[SnapshotTableInfo] - removed_environment_naming_info: t.Optional[EnvironmentNamingInfo] - - @field_validator("removed_environment_naming_info") - @field_validator_v1_args - def _validate_removed_environment_naming_info( - cls, v: t.Optional[EnvironmentNamingInfo], values: t.Any - ) -> t.Optional[EnvironmentNamingInfo]: - if v and not values["removed"]: - raise ValueError("removed_environment_naming_info must be None if removed is empty") - return v +# -1 to account for the baseline script +SCHEMA_VERSION: int = MIN_SCHEMA_VERSION + len(MIGRATIONS) - 1 class StateReader(abc.ABC): @@ -84,19 +79,35 @@ class StateReader(abc.ABC): @abc.abstractmethod def get_snapshots( - self, - snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]], + self, snapshot_ids: t.Iterable[SnapshotIdLike] ) -> t.Dict[SnapshotId, Snapshot]: """Bulk fetch snapshots given the corresponding snapshot ids. Args: - snapshot_ids: Iterable of snapshot ids to get. 
If not provided all - available snapshots will be returned. + snapshot_ids: Iterable of snapshot ids to get. Returns: A dictionary of snapshot ids to snapshots for ones that could be found. """ + @abc.abstractmethod + def get_snapshots_by_names( + self, + snapshot_names: t.Iterable[str], + current_ts: t.Optional[int] = None, + exclude_expired: bool = True, + ) -> t.Set[SnapshotIdAndVersion]: + """Return the snapshot records for all versions of the specified snapshot names. + + Args: + snapshot_names: Iterable of snapshot names to fetch all snapshot records for + current_ts: Sets the current time for identifying which snapshots have expired so they can be excluded (only relevant if :exclude_expired=True) + exclude_expired: Whether or not to return the snapshot id's of expired snapshots in the result + + Returns: + A set containing all the matched snapshot records. To fetch full snapshots, pass it into StateSync.get_snapshots() + """ + @abc.abstractmethod def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]: """Checks if multiple snapshots exist in the state sync. @@ -108,6 +119,17 @@ def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[Sna A set of all the existing snapshot ids. """ + @abc.abstractmethod + def refresh_snapshot_intervals(self, snapshots: t.Collection[Snapshot]) -> t.List[Snapshot]: + """Updates given snapshots with latest intervals from the state. + + Args: + snapshots: The snapshots to refresh. + + Returns: + The updated snapshots. + """ + @abc.abstractmethod def nodes_exist(self, names: t.Iterable[str], exclude_external: bool = False) -> t.Set[str]: """Returns the node names that exist in the state sync. @@ -140,34 +162,30 @@ def get_environments(self) -> t.List[Environment]: """ @abc.abstractmethod - def max_interval_end_for_environment( - self, environment: str, ensure_finalized_snapshots: bool = False - ) -> t.Optional[int]: - """Returns the max interval end for the given environment. 
- - Args: - environment: The environment. - ensure_finalized_snapshots: Whether to use snapshots from the latest finalized environment state, - or to use whatever snapshots are in the current environment state even if the environment is not finalized. + def get_environments_summary(self) -> t.List[EnvironmentSummary]: + """Fetches all environment names along with expiry datetime. Returns: - A timestamp or None if no interval or environment exists. + A list of all environment summaries. """ @abc.abstractmethod - def greatest_common_interval_end( - self, environment: str, models: t.Set[str], ensure_finalized_snapshots: bool = False - ) -> t.Optional[int]: - """Returns the greatest common interval end for given models in the target environment. + def max_interval_end_per_model( + self, + environment: str, + models: t.Optional[t.Set[str]] = None, + ensure_finalized_snapshots: bool = False, + ) -> t.Dict[str, int]: + """Returns the max interval end per model for the given environment. Args: - environment: The environment. - models: The model FQNs to select intervals from. + environment: The target environment. + models: The models to get the max interval end for. If None, all models are considered. ensure_finalized_snapshots: Whether to use snapshots from the latest finalized environment state, or to use whatever snapshots are in the current environment state even if the environment is not finalized. Returns: - A timestamp or None if no interval or environment exists. + A dictionary of model FQNs to their respective interval ends in milliseconds since epoch. """ @abc.abstractmethod @@ -183,6 +201,24 @@ def close(self) -> None: def state_type(self) -> str: """Returns the type of state sync.""" + @abc.abstractmethod + def update_auto_restatements( + self, next_auto_restatement_ts: t.Dict[SnapshotNameVersion, t.Optional[int]] + ) -> None: + """Updates the next auto restatement timestamp for the snapshots. 
+ + Args: + next_auto_restatement_ts: A dictionary of snapshot name / version pairs to the next auto restatement timestamp. + """ + + @abc.abstractmethod + def get_environment_statements(self, environment: str) -> t.List[EnvironmentStatements]: + """Fetches environment statements from the environment_statements table. + + Returns: + A list of the Environment Statements. + """ + def get_versions(self, validate: bool = True) -> Versions: """Get the current versions of the SQLMesh schema and libraries. @@ -220,6 +256,15 @@ def raise_error( f"{lib} (local) is using version '{local}' which is behind '{remote}' (remote).{upgrade_suggestion}" ) + if major_minor(SQLMESH_VERSION) != major_minor(versions.sqlmesh_version): + raise_error( + "SQLMesh", + SQLMESH_VERSION, + versions.sqlmesh_version, + remote_package_version=versions.sqlmesh_version, + ahead=major_minor(SQLMESH_VERSION) > major_minor(versions.sqlmesh_version), + ) + if SCHEMA_VERSION != versions.schema_version: raise_error( "SQLMesh", @@ -238,26 +283,50 @@ def raise_error( ahead=major_minor(SQLGLOT_VERSION) > major_minor(versions.sqlglot_version), ) - if major_minor(SQLMESH_VERSION) != major_minor(versions.sqlmesh_version): - raise_error( - "SQLMesh", - SQLMESH_VERSION, - versions.sqlmesh_version, - remote_package_version=versions.sqlmesh_version, - ahead=major_minor(SQLMESH_VERSION) > major_minor(versions.sqlmesh_version), - ) - return versions @abc.abstractmethod - def _get_versions(self, lock_for_update: bool = False) -> Versions: + def _get_versions(self) -> Versions: """Queries the store to get the current versions of SQLMesh and deps. + Returns: + The versions object. + """ + + @abc.abstractmethod + def export(self, environment_names: t.Optional[t.List[str]] = None) -> StateStream: + """Export the contents of this StateSync as a StateStream + Args: - lock_for_update: Whether or not the usage of this method plans to update the row. + environment_names: An optional list of environment names to export. 
If not specified, all environments will be exported. + """ + + @abc.abstractmethod + def get_expired_snapshots( + self, + *, + batch_range: ExpiredBatchRange, + current_ts: t.Optional[int] = None, + ignore_ttl: bool = False, + ) -> t.Optional[ExpiredSnapshotBatch]: + """Returns a single batch of expired snapshots ordered by (updated_ts, name, identifier). + + Args: + current_ts: Timestamp used to evaluate expiration. + ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced). + batch_range: The range of the batch to fetch. Returns: - The versions object. + A batch describing expired snapshots or None if no snapshots are pending cleanup. + """ + + @abc.abstractmethod + def get_expired_environments(self, current_ts: int) -> t.List[EnvironmentSummary]: + """Returns the expired environments. + + Expired environments are environments that have exceeded their time-to-live value. + Returns: + The list of environment summaries to remove. """ @@ -288,33 +357,40 @@ def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: @abc.abstractmethod def delete_expired_snapshots( - self, ignore_ttl: bool = False - ) -> t.List[SnapshotTableCleanupTask]: + self, + batch_range: ExpiredBatchRange, + ignore_ttl: bool = False, + current_ts: t.Optional[int] = None, + ) -> None: """Removes expired snapshots. Expired snapshots are snapshots that have exceeded their time-to-live and are no longer in use within an environment. Args: + batch_range: The range of snapshots to delete in this batch. ignore_ttl: Ignore the TTL on the snapshot when considering it expired. This has the effect of deleting all snapshots that are not referenced in any environment - - Returns: - The list of table cleanup tasks. + current_ts: Timestamp used to evaluate expiration. 
""" @abc.abstractmethod - def invalidate_environment(self, name: str) -> None: + def invalidate_environment(self, name: str, protect_prod: bool = True) -> None: """Invalidates the target environment by setting its expiration timestamp to now. Args: name: The name of the environment to invalidate. + protect_prod: If True, prevents invalidation of the production environment. """ @abc.abstractmethod - def remove_interval( + def remove_state(self, including_backup: bool = False) -> None: + """Removes the state store objects.""" + + @abc.abstractmethod + def remove_intervals( self, - snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]], + snapshot_intervals: t.Sequence[t.Tuple[SnapshotIdAndVersionLike, Interval]], remove_shared_versions: bool = False, ) -> None: """Remove an interval from a list of snapshots and sync it to the store. @@ -323,21 +399,8 @@ def remove_interval( can also grab all snapshots tied to the passed in version. Args: - snapshots: The snapshot info like object to remove intervals from. - start: The start of the interval to add. - end: The end of the interval to add. - all_snapshots: All snapshots can be passed in to skip fetching matching snapshot versions. - """ - - @abc.abstractmethod - def refresh_snapshot_intervals(self, snapshots: t.Collection[Snapshot]) -> t.List[Snapshot]: - """Updates given snapshots with latest intervals from the state. - - Args: - snapshots: The snapshots to refresh. - - Returns: - The updated snapshots. + snapshot_intervals: The snapshot intervals to remove. + remove_shared_versions: Whether to remove intervals for snapshots that share the same version with the target snapshots. """ @abc.abstractmethod @@ -345,6 +408,7 @@ def promote( self, environment: Environment, no_gaps_snapshot_names: t.Optional[t.Set[str]] = None, + environment_statements: t.Optional[t.List[EnvironmentStatements]] = None, ) -> PromotionResult: """Update the environment to reflect the current state. 
@@ -371,7 +435,9 @@ def finalize(self, environment: Environment) -> None: """ @abc.abstractmethod - def delete_expired_environments(self) -> t.List[Environment]: + def delete_expired_environments( + self, current_ts: t.Optional[int] = None + ) -> t.List[EnvironmentSummary]: """Removes expired environments. Expired environments are environments that have exceeded their time-to-live value. @@ -406,7 +472,6 @@ def compact_intervals(self) -> None: @abc.abstractmethod def migrate( self, - default_catalog: t.Optional[str], skip_backup: bool = False, promoted_snapshots_only: bool = True, ) -> None: @@ -417,11 +482,11 @@ def rollback(self) -> None: """Rollback to previous backed up state.""" @abc.abstractmethod - def _add_snapshot_intervals(self, snapshot_intervals: SnapshotIntervals) -> None: + def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None: """Add snapshot intervals to state Args: - snapshot_intervals: The snapshot intervals to add. + snapshots_intervals: The intervals to add. """ def add_interval( @@ -430,6 +495,7 @@ def add_interval( start: TimeLike, end: TimeLike, is_dev: bool = False, + last_altered_ts: t.Optional[int] = None, ) -> None: """Add an interval to a snapshot and sync it to the store. @@ -438,8 +504,9 @@ def add_interval( start: The start of the interval to add. end: The end of the interval to add. 
is_dev: Indicates whether the given interval is being added while in development mode + last_altered_ts: The timestamp of the last modification of the physical table """ - start_ts, end_ts = snapshot.inclusive_exclusive(start, end, strict=False) + start_ts, end_ts = snapshot.inclusive_exclusive(start, end, strict=False, expand=False) if not snapshot.version: raise SQLMeshError("Snapshot version must be set to add an interval.") intervals = [(start_ts, end_ts)] @@ -447,10 +514,23 @@ def add_interval( name=snapshot.name, identifier=snapshot.identifier, version=snapshot.version, + dev_version=snapshot.dev_version, intervals=intervals if not is_dev else [], dev_intervals=intervals if is_dev else [], + last_altered_ts=last_altered_ts if not is_dev else None, + dev_last_altered_ts=last_altered_ts if is_dev else None, ) - self._add_snapshot_intervals(snapshot_intervals) + self.add_snapshots_intervals([snapshot_intervals]) + + @abc.abstractmethod + def import_(self, stream: StateStream, clear: bool = True) -> None: + """ + Replace the existing state with the state contained in the StateStream + + Args: + stream: The stream of new state + clear: Whether or not to clear existing state before inserting state from the stream + """ class DelegatingStateSync(StateSync): diff --git a/sqlmesh/core/state_sync/cache.py b/sqlmesh/core/state_sync/cache.py index bf59eeed69..77f3fc6ba5 100644 --- a/sqlmesh/core/state_sync/cache.py +++ b/sqlmesh/core/state_sync/cache.py @@ -1,6 +1,5 @@ from __future__ import annotations -import sys import typing as t from sqlmesh.core.model import SeedModel @@ -8,18 +7,14 @@ Snapshot, SnapshotId, SnapshotIdLike, + SnapshotIdAndVersionLike, SnapshotInfoLike, - SnapshotTableCleanupTask, ) from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals from sqlmesh.core.state_sync.base import DelegatingStateSync, StateSync +from sqlmesh.core.state_sync.common import ExpiredBatchRange from sqlmesh.utils.date import TimeLike, now_timestamp -if 
sys.version_info >= (3, 8): - from typing import Literal -else: - from typing_extensions import Literal - class CachingStateSync(DelegatingStateSync): """In memory cache for snapshots that implements the state sync api. @@ -35,15 +30,15 @@ def __init__(self, state_sync: StateSync, ttl: int = 120): # False means that the snapshot does not exist in the state sync but has been requested before # None means that the snapshot has not been requested. self.snapshot_cache: t.Dict[ - SnapshotId, t.Tuple[t.Optional[Snapshot | Literal[False]], int] + SnapshotId, t.Tuple[t.Optional[Snapshot | t.Literal[False]], int] ] = {} self.ttl = ttl def _from_cache( self, snapshot_id: SnapshotId, now: int - ) -> t.Optional[Snapshot | Literal[False]]: - snapshot: t.Optional[Snapshot | Literal[False]] = None + ) -> t.Optional[Snapshot | t.Literal[False]]: + snapshot: t.Optional[Snapshot | t.Literal[False]] = None snapshot_expiration = self.snapshot_cache.get(snapshot_id) if snapshot_expiration and snapshot_expiration[1] >= now: @@ -52,11 +47,8 @@ def _from_cache( return snapshot def get_snapshots( - self, snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]] + self, snapshot_ids: t.Iterable[SnapshotIdLike] ) -> t.Dict[SnapshotId, Snapshot]: - if snapshot_ids is None: - return self.state_sync.get_snapshots(snapshot_ids) - existing = {} missing = set() now = now_timestamp() @@ -117,26 +109,45 @@ def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: self.state_sync.delete_snapshots(snapshot_ids) def delete_expired_snapshots( - self, ignore_ttl: bool = False - ) -> t.List[SnapshotTableCleanupTask]: + self, + batch_range: ExpiredBatchRange, + ignore_ttl: bool = False, + current_ts: t.Optional[int] = None, + ) -> None: self.snapshot_cache.clear() - return self.state_sync.delete_expired_snapshots(ignore_ttl=ignore_ttl) - - def _add_snapshot_intervals(self, snapshot_intervals: SnapshotIntervals) -> None: - self.snapshot_cache.pop(snapshot_intervals.snapshot_id, None) - 
self.state_sync._add_snapshot_intervals(snapshot_intervals) - - def remove_interval( + self.state_sync.delete_expired_snapshots( + batch_range=batch_range, + ignore_ttl=ignore_ttl, + current_ts=current_ts, + ) + + def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None: + for snapshot_intervals in snapshots_intervals: + if snapshot_intervals.snapshot_id: + self.snapshot_cache.pop(snapshot_intervals.snapshot_id, None) + else: + # Evict all snapshots that share the same name + self.snapshot_cache = { + snapshot_id: value + for snapshot_id, value in self.snapshot_cache.items() + if snapshot_id.name != snapshot_intervals.name + } + self.state_sync.add_snapshots_intervals(snapshots_intervals) + + def remove_intervals( self, - snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]], + snapshot_intervals: t.Sequence[t.Tuple[SnapshotIdAndVersionLike, Interval]], remove_shared_versions: bool = False, ) -> None: for s, _ in snapshot_intervals: self.snapshot_cache.pop(s.snapshot_id, None) - self.state_sync.remove_interval(snapshot_intervals, remove_shared_versions) + self.state_sync.remove_intervals(snapshot_intervals, remove_shared_versions) def unpause_snapshots( self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike ) -> None: self.snapshot_cache.clear() self.state_sync.unpause_snapshots(snapshots, unpaused_dt) + + def clear_cache(self) -> None: + self.snapshot_cache.clear() diff --git a/sqlmesh/core/state_sync/common.py b/sqlmesh/core/state_sync/common.py index 12da604dfd..2e8c67ac29 100644 --- a/sqlmesh/core/state_sync/common.py +++ b/sqlmesh/core/state_sync/common.py @@ -1,78 +1,31 @@ from __future__ import annotations -import abc import logging import typing as t -from datetime import datetime from functools import wraps +import itertools +import abc + +from dataclasses import dataclass + +from pydantic_core.core_schema import ValidationInfo +from sqlglot import exp -from sqlmesh.core.console import 
Console -from sqlmesh.core.dialect import schema_ -from sqlmesh.core.environment import Environment +from sqlmesh.utils.pydantic import PydanticModel, field_validator +from sqlmesh.core.environment import Environment, EnvironmentStatements, EnvironmentNamingInfo from sqlmesh.core.snapshot import ( Snapshot, SnapshotId, - SnapshotIdLike, - SnapshotInfoLike, - SnapshotNameVersionLike, + SnapshotTableCleanupTask, SnapshotTableInfo, - start_date, ) -from sqlmesh.core.state_sync.base import PromotionResult, StateSync -from sqlmesh.utils.date import TimeLike, now, now_timestamp, to_timestamp -from sqlmesh.utils.errors import SQLMeshError if t.TYPE_CHECKING: - from sqlmesh.core.engine_adapter.base import EngineAdapter + from sqlmesh.core.state_sync.base import Versions, StateReader logger = logging.getLogger(__name__) - -def cleanup_expired_views( - adapter: EngineAdapter, environments: t.List[Environment], console: t.Optional[Console] = None -) -> None: - expired_schema_environments = [ - environment for environment in environments if environment.suffix_target.is_schema - ] - expired_table_environments = [ - environment for environment in environments if environment.suffix_target.is_table - ] - for expired_catalog, expired_schema in { - ( - snapshot.qualified_view_name.catalog_for_environment(environment.naming_info), - snapshot.qualified_view_name.schema_for_environment( - environment.naming_info, dialect=adapter.dialect - ), - ) - for environment in expired_schema_environments - for snapshot in environment.snapshots - if snapshot.is_model and not snapshot.is_symbolic - }: - schema = schema_(expired_schema, expired_catalog) - try: - adapter.drop_schema( - schema, - ignore_if_not_exists=True, - cascade=True, - ) - if console: - console.update_cleanup_progress(schema.sql(dialect=adapter.dialect)) - except Exception as e: - logger.warning("Falied to drop the expired environment schema '%s': %s", schema, e) - for expired_view in { - 
snapshot.qualified_view_name.for_environment( - environment.naming_info, dialect=adapter.dialect - ) - for environment in expired_table_environments - for snapshot in environment.snapshots - if snapshot.is_model and not snapshot.is_symbolic - }: - try: - adapter.drop_view(expired_view, ignore_if_not_exists=True) - if console: - console.update_cleanup_progress(expired_view) - except Exception as e: - logger.warning("Falied to drop the expired environment view '%s': %s", expired_view, e) +EXPIRED_SNAPSHOT_DEFAULT_BATCH_SIZE = 200 def transactional() -> t.Callable[[t.Callable], t.Callable]: @@ -90,283 +43,280 @@ def wrapper(self: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: return decorator -class CommonStateSyncMixin(StateSync): - def get_snapshots( - self, - snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]], - ) -> t.Dict[SnapshotId, Snapshot]: - return self._get_snapshots(snapshot_ids) +T = t.TypeVar("T") - def get_environment(self, environment: str) -> t.Optional[Environment]: - return self._get_environment(environment) - @transactional() - def promote( - self, - environment: Environment, - no_gaps_snapshot_names: t.Optional[t.Set[str]] = None, - ) -> PromotionResult: - """Update the environment to reflect the current state. +def chunk_iterable(iterable: t.Iterable[T], size: int = 10) -> t.Iterable[t.Iterable[T]]: + iterator = iter(iterable) + for first in iterator: + yield itertools.chain([first], itertools.islice(iterator, size - 1)) - This method verifies that snapshots have been pushed. - Args: - environment: The environment to promote. - no_gaps_snapshot_names: A set of snapshot names to check for data gaps. If None, - all snapshots will be checked. The data gap check ensures that models that are already a - part of the target environment have no data gaps when compared against previous - snapshots for same models. 
+class EnvironmentWithStatements(PydanticModel): + environment: Environment + statements: t.List[EnvironmentStatements] = [] - Returns: - A tuple of (added snapshot table infos, removed snapshot table infos, and environment target suffix for the removed table infos) - """ - logger.info("Promoting environment '%s'", environment.name) - missing = {s.snapshot_id for s in environment.snapshots} - self.snapshots_exist( - environment.snapshots - ) - if missing: - raise SQLMeshError( - f"Missing snapshots {missing}. Make sure to push and backfill your snapshots." - ) +@dataclass +class VersionsChunk: + versions: Versions - existing_environment = self._get_environment(environment.name, lock_for_update=True) - existing_table_infos = ( - {table_info.name: table_info for table_info in existing_environment.promoted_snapshots} - if existing_environment - else {} - ) - table_infos = {table_info.name: table_info for table_info in environment.promoted_snapshots} - views_that_changed_location: t.Set[SnapshotTableInfo] = set() - if existing_environment: - views_that_changed_location = { - existing_table_info - for name, existing_table_info in existing_table_infos.items() - if name in table_infos - and existing_table_info.qualified_view_name.for_environment( - existing_environment.naming_info - ) - != table_infos[name].qualified_view_name.for_environment(environment.naming_info) - } - if environment.previous_plan_id != existing_environment.plan_id: - raise SQLMeshError( - f"Plan '{environment.plan_id}' is no longer valid for the target environment '{environment.name}'. " - f"Expected previous plan ID: '{environment.previous_plan_id}', actual previous plan ID: '{existing_environment.plan_id}'. 
" - "Please recreate the plan and try again" - ) - if no_gaps_snapshot_names != set(): - snapshots = self._get_snapshots(environment.snapshots).values() - self._ensure_no_gaps( - snapshots, - existing_environment, - no_gaps_snapshot_names, - ) - demoted_snapshots = set(existing_environment.snapshots) - set(environment.snapshots) - for demoted_snapshot in self._get_snapshots(demoted_snapshots).values(): - # Update the updated_at attribute. - self._update_snapshot(demoted_snapshot) - - missing_models = set(existing_table_infos) - { - snapshot.name for snapshot in environment.promoted_snapshots - } - - added_table_infos = set(table_infos.values()) - if existing_environment and existing_environment.finalized_ts: - # Only promote new snapshots. - added_table_infos -= set(existing_environment.promoted_snapshots) - - self._update_environment(environment) - - removed = {existing_table_infos[name] for name in missing_models}.union( - views_that_changed_location - ) +class SnapshotsChunk: + def __init__(self, items: t.Iterator[Snapshot]): + self.items = items - return PromotionResult( - added=sorted(added_table_infos), - removed=list(removed), - removed_environment_naming_info=( - existing_environment.naming_info if removed and existing_environment else None - ), - ) + def __iter__(self) -> t.Iterator[Snapshot]: + return self.items - @transactional() - def finalize(self, environment: Environment) -> None: - """Finalize the target environment, indicating that this environment has been - fully promoted and is ready for use. - Args: - environment: The target environment to finalize. - """ - logger.info("Finalizing environment '%s'", environment.name) - - stored_environment = self._get_environment(environment.name, lock_for_update=True) - if stored_environment and stored_environment.plan_id != environment.plan_id: - raise SQLMeshError( - f"Plan '{environment.plan_id}' is no longer valid for the target environment '{environment.name}'. 
" - f"Stored plan ID: '{stored_environment.plan_id}'. Please recreate the plan and try again" - ) - - environment.finalized_ts = now_timestamp() - self._update_environment(environment) - - @transactional() - def unpause_snapshots( - self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike - ) -> None: - current_ts = now() - - target_snapshot_ids = {s.snapshot_id for s in snapshots} - snapshots = self._get_snapshots_with_same_version(snapshots, lock_for_update=True) - target_snapshots_by_version = { - (s.name, s.version): s for s in snapshots if s.snapshot_id in target_snapshot_ids - } - - for snapshot in snapshots: - is_target_snapshot = snapshot.snapshot_id in target_snapshot_ids - if is_target_snapshot and not snapshot.unpaused_ts: - logger.info("Unpausing snapshot %s", snapshot.snapshot_id) - snapshot.set_unpaused_ts(unpaused_dt) - self._update_snapshot(snapshot) - elif not is_target_snapshot: - target_snapshot = target_snapshots_by_version[(snapshot.name, snapshot.version)] - if target_snapshot.normalized_effective_from_ts: - # Making sure that there are no overlapping intervals. 
- effective_from_ts = target_snapshot.normalized_effective_from_ts - logger.info( - "Removing all intervals after '%s' for snapshot %s, superseded by snapshot %s", - target_snapshot.effective_from, - snapshot.snapshot_id, - target_snapshot.snapshot_id, - ) - self.remove_interval( - [(snapshot, snapshot.get_removal_interval(effective_from_ts, current_ts))] - ) - - update_required = False - - if snapshot.unpaused_ts: - logger.info("Pausing snapshot %s", snapshot.snapshot_id) - snapshot.set_unpaused_ts(None) - update_required = True - - if ( - not snapshot.is_forward_only - and target_snapshot.is_forward_only - and not snapshot.unrestorable - ): - logger.info("Marking snapshot %s as unrestorable", snapshot.snapshot_id) - snapshot.unrestorable = True - update_required = True - - if update_required: - self._update_snapshot(snapshot) - - def _ensure_no_gaps( - self, - target_snapshots: t.Iterable[Snapshot], - target_environment: Environment, - snapshot_names: t.Optional[t.Set[str]], - ) -> None: - target_snapshots_by_name = {s.name: s for s in target_snapshots} - - changed_version_prev_snapshots_by_name = { - s.name: s - for s in target_environment.snapshots - if s.name in target_snapshots_by_name - and target_snapshots_by_name[s.name].version != s.version - } - - prev_snapshots = self._get_snapshots( - changed_version_prev_snapshots_by_name.values() - ).values() - cache: t.Dict[str, datetime] = {} - - for prev_snapshot in prev_snapshots: - target_snapshot = target_snapshots_by_name[prev_snapshot.name] - if ( - (snapshot_names is None or prev_snapshot.name in snapshot_names) - and target_snapshot.is_incremental - and prev_snapshot.is_incremental - and prev_snapshot.intervals - ): - start = to_timestamp( - start_date(target_snapshot, target_snapshots_by_name.values(), cache) - ) - end = prev_snapshot.intervals[-1][1] - - if start < end: - missing_intervals = target_snapshot.missing_intervals( - start, end, end_bounded=True - ) - - if missing_intervals: - raise 
SQLMeshError( - f"Detected gaps in snapshot {target_snapshot.snapshot_id}: {missing_intervals}" - ) +class EnvironmentsChunk: + def __init__(self, items: t.Iterator[EnvironmentWithStatements]): + self.items = items - @abc.abstractmethod - def _update_environment(self, environment: Environment) -> None: - """Overwrites the target environment with a given environment. + def __iter__(self) -> t.Iterator[EnvironmentWithStatements]: + return self.items - Args: - environment: The new environment. - """ - @abc.abstractmethod - def _update_snapshot(self, snapshot: Snapshot) -> None: - """Updates the target snapshot. +StateStreamContents = t.Union[VersionsChunk, SnapshotsChunk, EnvironmentsChunk] - Args: - snapshot: The target snapshot. - """ - @abc.abstractmethod - def _get_snapshots( - self, - snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]] = None, - lock_for_update: bool = False, - hydrate_intervals: bool = True, - ) -> t.Dict[SnapshotId, Snapshot]: - """Fetches specified snapshots. +class StateStream(abc.ABC): + """ + Represents a stream of state either going into the StateSync (perhaps loaded from a file) + or out of the StateSync (perhaps being dumped to a file) - Args: - snapshot_ids: The collection of IDs of snapshots to fetch - lock_for_update: Lock the snapshot rows for future update - hydrate_intervals: Whether to hydrate result snapshots with intervals. + Iterating over the stream produces the following chunks: - Returns: - A dictionary of snapshot ids to snapshots for ones that could be found. - """ + VersionsChunk: The versions of the objects contained in this StateStream + SnapshotsChunk: Is itself an iterator that streams Snapshot objects. 
Note that they should be fully populated with any relevant Intervals + EnvironmentsChunk: Is itself an iterator emitting a stream of Environments with any EnvironmentStatements attached + + The idea here is to give some structure to the stream and ensure that callers have the opportunity to process all its components while not + needing to worry about the order they are emitted in + """ @abc.abstractmethod - def _get_snapshots_with_same_version( - self, - snapshots: t.Collection[SnapshotNameVersionLike], - lock_for_update: bool = False, - ) -> t.List[Snapshot]: - """Fetches all snapshots that share the same version as the snapshots. + def __iter__(self) -> t.Iterator[StateStreamContents]: + pass - The output includes the snapshots with the specified version. + @classmethod + def from_iterators( + cls: t.Type["StateStream"], + versions: Versions, + snapshots: t.Iterator[Snapshot], + environments: t.Iterator[EnvironmentWithStatements], + ) -> "StateStream": + class _StateStream(cls): # type: ignore + def __iter__(self) -> t.Iterator[StateStreamContents]: + yield VersionsChunk(versions) - Args: - snapshots: The collection of target name / version pairs. - lock_for_update: Lock the snapshot rows for future update + yield SnapshotsChunk(snapshots) - Returns: - The list of Snapshot objects. - """ + yield EnvironmentsChunk(environments) - @abc.abstractmethod - def _get_environment( - self, environment: str, lock_for_update: bool = False - ) -> t.Optional[Environment]: - """Fetches the environment if it exists. 
+ return _StateStream() + + +class ExpiredBatchRange(PydanticModel): + start: RowBoundary + end: t.Union[RowBoundary, LimitBoundary] + + @classmethod + def init_batch_range(cls, batch_size: int) -> ExpiredBatchRange: + return ExpiredBatchRange( + start=RowBoundary.lowest_boundary(), + end=LimitBoundary(batch_size=batch_size), + ) + + @classmethod + def all_batch_range(cls) -> ExpiredBatchRange: + return ExpiredBatchRange( + start=RowBoundary.lowest_boundary(), + end=RowBoundary.highest_boundary(), + ) + + @classmethod + def _expanded_tuple_comparison( + cls, + columns: t.List[exp.Column], + values: t.List[t.Union[exp.Literal, exp.Neg]], + operator: t.Type[exp.Expression], + ) -> exp.Expression: + """Generate expanded tuple comparison that works across all SQL engines. + + Converts tuple comparisons like (a, b, c) OP (x, y, z) into an expanded form + that's compatible with all SQL engines, since native tuple comparisons have + inconsistent support across engines (especially DuckDB, MySQL, SQLite). + + Repro of problem with DuckDB: + "SELECT * FROM VALUES(1,'2') as test(a,b) WHERE ((a, b) > (1, 'foo')) AND ((a, b) <= (10, 'baz'))" Args: - environment: The target environment name. - lock_for_update: Lock the snapshot rows for future update + columns: List of column expressions to compare + values: List of value expressions to compare against + operator: The comparison operator class (exp.GT, exp.GTE, exp.LT, exp.LTE) + + Examples: + (a, b, c) > (x, y, z) expands to: + a > x OR (a = x AND b > y) OR (a = x AND b = y AND c > z) + + (a, b, c) <= (x, y, z) expands to: + a < x OR (a = x AND b < y) OR (a = x AND b = y AND c <= z) + + (a, b, c) >= (x, y, z) expands to: + a > x OR (a = x AND b > y) OR (a = x AND b = y AND c >= z) Returns: - The target environment. + An expanded OR expression representing the tuple comparison """ + if operator not in (exp.GT, exp.GTE, exp.LT, exp.LTE): + raise ValueError(f"Unsupported operator: {operator}. 
Use GT, GTE, LT, or LTE.") + + # For <= and >=, we use the strict operator for all but the last column + # e.g., (a, b) <= (x, y) becomes: a < x OR (a = x AND b <= y) + # For < and >, we use the strict operator throughout + # e.g., (a, b) > (x, y) becomes: a > x OR (a = x AND b > x) + strict_operator: t.Type[exp.Expression] + final_operator: t.Type[exp.Expression] + + if operator in (exp.LTE, exp.GTE): + # For inclusive operators (<=, >=), use strict form for intermediate columns + # but keep inclusive form for the last column + strict_operator = exp.LT if operator == exp.LTE else exp.GT + final_operator = operator # Keep LTE/GTE for last column + else: + # For strict operators (<, >), use them throughout + strict_operator = operator + final_operator = operator + + conditions: t.List[exp.Expression] = [] + for i in range(len(columns)): + # Build equality conditions for all columns before current + equality_conditions = [exp.EQ(this=columns[j], expression=values[j]) for j in range(i)] + + # Use the final operator for the last column, strict for others + comparison_op = final_operator if i == len(columns) - 1 else strict_operator + comparison_condition = comparison_op(this=columns[i], expression=values[i]) + + if equality_conditions: + conditions.append(exp.and_(*equality_conditions, comparison_condition)) + else: + conditions.append(comparison_condition) + + return exp.or_(*conditions) if len(conditions) > 1 else conditions[0] + + @property + def where_filter(self) -> exp.Expression: + # Use expanded tuple comparisons for cross-engine compatibility + # Native tuple comparisons like (a, b) > (x, y) don't work reliably across all SQL engines + columns = [ + exp.column("updated_ts"), + exp.column("name"), + exp.column("identifier"), + ] + start_values = [ + exp.Literal.number(self.start.updated_ts), + exp.Literal.string(self.start.name), + exp.Literal.string(self.start.identifier), + ] + + start_condition = self._expanded_tuple_comparison(columns, start_values, exp.GT) 
+ + range_filter: exp.Expression + if isinstance(self.end, RowBoundary): + end_values = [ + exp.Literal.number(self.end.updated_ts), + exp.Literal.string(self.end.name), + exp.Literal.string(self.end.identifier), + ] + end_condition = self._expanded_tuple_comparison(columns, end_values, exp.LTE) + range_filter = exp.and_(start_condition, end_condition) + else: + range_filter = start_condition + return range_filter + + +class RowBoundary(PydanticModel): + updated_ts: int + name: str + identifier: str + + @classmethod + def lowest_boundary(cls) -> RowBoundary: + return RowBoundary(updated_ts=0, name="", identifier="") + + @classmethod + def highest_boundary(cls) -> RowBoundary: + # 9999-12-31T23:59:59.999Z in epoch milliseconds + return RowBoundary(updated_ts=253_402_300_799_999, name="", identifier="") + + +class LimitBoundary(PydanticModel): + batch_size: int + + @classmethod + def init_batch_boundary(cls, batch_size: int) -> LimitBoundary: + return LimitBoundary(batch_size=batch_size) + + +class PromotionResult(PydanticModel): + added: t.List[SnapshotTableInfo] + removed: t.List[SnapshotTableInfo] + removed_environment_naming_info: t.Optional[EnvironmentNamingInfo] + + @field_validator("removed_environment_naming_info") + def _validate_removed_environment_naming_info( + cls, v: t.Optional[EnvironmentNamingInfo], info: ValidationInfo + ) -> t.Optional[EnvironmentNamingInfo]: + if v and not info.data.get("removed"): + raise ValueError("removed_environment_naming_info must be None if removed is empty") + return v + + +class ExpiredSnapshotBatch(PydanticModel): + """A batch of expired snapshots to be cleaned up.""" + + expired_snapshot_ids: t.Set[SnapshotId] + cleanup_tasks: t.List[SnapshotTableCleanupTask] + batch_range: ExpiredBatchRange + + +def iter_expired_snapshot_batches( + state_reader: StateReader, + *, + current_ts: int, + ignore_ttl: bool = False, + batch_size: t.Optional[int] = None, +) -> t.Iterator[ExpiredSnapshotBatch]: + """Yields expired snapshot 
batches. + + Args: + state_reader: StateReader instance to query expired snapshots from. + current_ts: Timestamp used to evaluate expiration. + ignore_ttl: If True, include snapshots regardless of TTL (only checks if unreferenced). + batch_size: Maximum number of snapshots to fetch per batch. + """ + + batch_size = batch_size if batch_size is not None else EXPIRED_SNAPSHOT_DEFAULT_BATCH_SIZE + batch_range = ExpiredBatchRange.init_batch_range(batch_size=batch_size) + + while True: + batch = state_reader.get_expired_snapshots( + current_ts=current_ts, + ignore_ttl=ignore_ttl, + batch_range=batch_range, + ) + + if batch is None: + return + + yield batch + + assert isinstance(batch.batch_range.end, RowBoundary), ( + "Only RowBoundary is supported for pagination currently" + ) + batch_range = ExpiredBatchRange( + start=batch.batch_range.end, + end=LimitBoundary(batch_size=batch_size), + ) diff --git a/sqlmesh/core/state_sync/db/__init__.py b/sqlmesh/core/state_sync/db/__init__.py new file mode 100644 index 0000000000..3292449359 --- /dev/null +++ b/sqlmesh/core/state_sync/db/__init__.py @@ -0,0 +1,3 @@ +from sqlmesh.core.state_sync.db.facade import EngineAdapterStateSync + +__all__ = ["EngineAdapterStateSync"] diff --git a/sqlmesh/core/state_sync/db/environment.py b/sqlmesh/core/state_sync/db/environment.py new file mode 100644 index 0000000000..e3f1d1ec9e --- /dev/null +++ b/sqlmesh/core/state_sync/db/environment.py @@ -0,0 +1,386 @@ +from __future__ import annotations + +import typing as t +import json +import logging +from sqlglot import exp + +from sqlmesh.core import constants as c +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.state_sync.db.utils import ( + fetchall, + fetchone, +) +from sqlmesh.core.environment import Environment, EnvironmentStatements, EnvironmentSummary +from sqlmesh.utils.migration import index_text_type, blob_text_type +from sqlmesh.utils.date import now_timestamp, time_like_to_str +from sqlmesh.utils.errors import 
SQLMeshError + +if t.TYPE_CHECKING: + import pandas as pd + + +logger = logging.getLogger(__name__) + + +class EnvironmentState: + def __init__( + self, + engine_adapter: EngineAdapter, + schema: t.Optional[str] = None, + ): + self.engine_adapter = engine_adapter + self.environments_table = exp.table_("_environments", db=schema) + self.environment_statements_table = exp.table_("_environment_statements", db=schema) + + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + + self._environment_columns_to_types = { + "name": exp.DataType.build(index_type), + "snapshots": exp.DataType.build(blob_type), + "start_at": exp.DataType.build("text"), + "end_at": exp.DataType.build("text"), + "plan_id": exp.DataType.build("text"), + "previous_plan_id": exp.DataType.build("text"), + "expiration_ts": exp.DataType.build("bigint"), + "finalized_ts": exp.DataType.build("bigint"), + "promoted_snapshot_ids": exp.DataType.build(blob_type), + "suffix_target": exp.DataType.build("text"), + "catalog_name_override": exp.DataType.build("text"), + "previous_finalized_snapshots": exp.DataType.build(blob_type), + "normalize_name": exp.DataType.build("boolean"), + "gateway_managed": exp.DataType.build("boolean"), + "requirements": exp.DataType.build(blob_type), + } + + self._environment_statements_columns_to_types = { + "environment_name": exp.DataType.build(index_type), + "plan_id": exp.DataType.build("text"), + "environment_statements": exp.DataType.build(blob_type), + } + + def update_environment(self, environment: Environment) -> None: + """Updates the environment. 
+ + Args: + environment: The environment + """ + self.engine_adapter.delete_from( + self.environments_table, + where=exp.EQ( + this=exp.column("name"), + expression=exp.Literal.string(environment.name), + ), + ) + + self.engine_adapter.insert_append( + self.environments_table, + _environment_to_df(environment), + target_columns_to_types=self._environment_columns_to_types, + track_rows_processed=False, + ) + + def update_environment_statements( + self, + environment_name: str, + plan_id: str, + environment_statements: t.List[EnvironmentStatements], + ) -> None: + """Updates the environment's statements. + + Args: + environment_name: The environment name + plan_id: The environment's plan ID + environment_statements: The environment statements + + """ + + self.engine_adapter.delete_from( + self.environment_statements_table, + where=exp.EQ( + this=exp.column("environment_name"), + expression=exp.Literal.string(environment_name), + ), + ) + + if environment_statements: + self.engine_adapter.insert_append( + self.environment_statements_table, + _environment_statements_to_df(environment_name, plan_id, environment_statements), + target_columns_to_types=self._environment_statements_columns_to_types, + track_rows_processed=False, + ) + + def invalidate_environment(self, name: str, protect_prod: bool = True) -> None: + """Invalidates the environment. + + Args: + name: The name of the environment + protect_prod: If True, prevents invalidation of the production environment. + """ + name = name.lower() + if protect_prod and name == c.PROD: + raise SQLMeshError("Cannot invalidate the production environment.") + + filter_expr = exp.column("name").eq(name) + + self.engine_adapter.update_table( + self.environments_table, + {"expiration_ts": now_timestamp()}, + where=filter_expr, + ) + + def finalize(self, environment: Environment) -> None: + """Finalize the target environment, indicating that this environment has been + fully promoted and is ready for use. 
+ + Args: + environment: The target environment to finalize. + """ + logger.info("Finalizing environment '%s'", environment.name) + + environment_filter = exp.column("name").eq(exp.Literal.string(environment.name)) + + stored_plan_id_query = ( + exp.select("plan_id") + .from_(self.environments_table) + .where(environment_filter, copy=False) + .lock(copy=False) + ) + stored_plan_id_row = fetchone(self.engine_adapter, stored_plan_id_query) + + if not stored_plan_id_row: + raise SQLMeshError(f"Missing environment '{environment.name}' can't be finalized") + + stored_plan_id = stored_plan_id_row[0] + if stored_plan_id != environment.plan_id: + raise SQLMeshError( + f"Another plan ({stored_plan_id}) was applied to the target environment '{environment.name}' while your current plan " + f"({environment.plan_id}) was still in progress, interrupting it. Please re-apply your plan to resolve this error." + ) + + environment.finalized_ts = now_timestamp() + self.engine_adapter.update_table( + self.environments_table, + {"finalized_ts": environment.finalized_ts}, + where=environment_filter, + ) + + def get_expired_environments(self, current_ts: int) -> t.List[EnvironmentSummary]: + """Returns the expired environments. + + Expired environments are environments that have exceeded their time-to-live value. + Returns: + The list of environment summaries to remove. + """ + return self._fetch_environment_summaries( + where=self._create_expiration_filter_expr(current_ts) + ) + + def delete_expired_environments( + self, current_ts: t.Optional[int] = None + ) -> t.List[EnvironmentSummary]: + """Deletes expired environments. + + Returns: + A list of deleted environments. 
+ """ + current_ts = current_ts or now_timestamp() + expired_environments = self.get_expired_environments(current_ts=current_ts) + + self.engine_adapter.delete_from( + self.environments_table, + where=self._create_expiration_filter_expr(current_ts), + ) + + # Delete the expired environments' corresponding environment statements + if expired_environments_exprs := [ + exp.EQ(this=exp.column("environment_name"), expression=exp.Literal.string(env.name)) + for env in expired_environments + ]: + self.engine_adapter.delete_from( + self.environment_statements_table, + where=exp.or_(*expired_environments_exprs), + ) + + return expired_environments + + def get_environments(self) -> t.List[Environment]: + """Fetches all environments. + + Returns: + A list of all environments. + """ + return [ + self._environment_from_row(row) + for row in fetchall(self.engine_adapter, self._environments_query()) + ] + + def get_environments_summary(self) -> t.List[EnvironmentSummary]: + """Fetches summaries for all environments. + + Returns: + A list of all environment summaries. + """ + return self._fetch_environment_summaries() + + def get_environment( + self, environment: str, lock_for_update: bool = False + ) -> t.Optional[Environment]: + """Fetches the environment if it exists. + + Args: + environment: The environment + lock_for_update: Lock the snapshot rows for future update + + Returns: + The environment object. + """ + row = fetchone( + self.engine_adapter, + self._environments_query( + where=exp.EQ( + this=exp.column("name"), + expression=exp.Literal.string(environment), + ), + lock_for_update=lock_for_update, + ), + ) + + if not row: + return None + + env = self._environment_from_row(row) + return env + + def get_environment_statements(self, environment: str) -> t.List[EnvironmentStatements]: + """Fetches the environment's statements from the environment_statements table. + Args: + environment: The environment name + + Returns: + A list of the environment statements. 
+ + """ + query = ( + exp.select( + exp.to_identifier("environment_statements"), + ) + .from_(self.environment_statements_table) + .where( + exp.EQ( + this=exp.column("environment_name"), + expression=exp.Literal.string(environment), + ) + ) + ) + result = fetchone(engine_adapter=self.engine_adapter, query=query) + if result and (statements := json.loads(result[0])): + return [ + EnvironmentStatements.parse_obj(environment_statements) + for environment_statements in statements + ] + + return [] + + def _environment_from_row(self, row: t.Tuple[str, ...]) -> Environment: + return Environment( + **{field: row[i] for i, field in enumerate(sorted(Environment.all_fields()))} + ) + + def _environment_summmary_from_row(self, row: t.Tuple[str, ...]) -> EnvironmentSummary: + return EnvironmentSummary( + **{field: row[i] for i, field in enumerate(sorted(EnvironmentSummary.all_fields()))} + ) + + def _environments_query( + self, + where: t.Optional[str | exp.Expression] = None, + lock_for_update: bool = False, + required_fields: t.Optional[t.List[str]] = None, + ) -> exp.Select: + query_fields = required_fields if required_fields else sorted(Environment.all_fields()) + query = ( + exp.select(*(exp.to_identifier(field) for field in query_fields)) + .from_(self.environments_table) + .where(where) + ) + if lock_for_update: + return query.lock(copy=False) + return query + + def _create_expiration_filter_expr(self, current_ts: int) -> exp.Expression: + """Creates a SQLGlot filter expression to find expired environments. + + Args: + current_ts: The current timestamp. 
+ """ + return exp.LTE( + this=exp.column("expiration_ts"), + expression=exp.Literal.number(current_ts), + ) + + def _fetch_environment_summaries( + self, where: t.Optional[str | exp.Expression] = None + ) -> t.List[EnvironmentSummary]: + return [ + self._environment_summmary_from_row(row) + for row in fetchall( + self.engine_adapter, + self._environments_query( + where=where, + required_fields=sorted(EnvironmentSummary.all_fields()), + ), + ) + ] + + +def _environment_to_df(environment: Environment) -> pd.DataFrame: + import pandas as pd + + return pd.DataFrame( + [ + { + "name": environment.name, + "snapshots": json.dumps(environment.snapshot_dicts()), + "start_at": time_like_to_str(environment.start_at), + "end_at": time_like_to_str(environment.end_at) if environment.end_at else None, + "plan_id": environment.plan_id, + "previous_plan_id": environment.previous_plan_id, + "expiration_ts": environment.expiration_ts, + "finalized_ts": environment.finalized_ts, + "promoted_snapshot_ids": ( + json.dumps(environment.promoted_snapshot_id_dicts()) + if environment.promoted_snapshot_ids is not None + else None + ), + "suffix_target": environment.suffix_target.value, + "catalog_name_override": environment.catalog_name_override, + "previous_finalized_snapshots": ( + json.dumps(environment.previous_finalized_snapshot_dicts()) + if environment.previous_finalized_snapshots is not None + else None + ), + "normalize_name": environment.normalize_name, + "gateway_managed": environment.gateway_managed, + "requirements": json.dumps(environment.requirements), + } + ] + ) + + +def _environment_statements_to_df( + environment_name: str, plan_id: str, environment_statements: t.List[EnvironmentStatements] +) -> pd.DataFrame: + import pandas as pd + + return pd.DataFrame( + [ + { + "environment_name": environment_name, + "plan_id": plan_id, + "environment_statements": json.dumps([e.dict() for e in environment_statements]), + } + ] + ) diff --git a/sqlmesh/core/state_sync/db/facade.py 
b/sqlmesh/core/state_sync/db/facade.py new file mode 100644 index 0000000000..64042624f3 --- /dev/null +++ b/sqlmesh/core/state_sync/db/facade.py @@ -0,0 +1,652 @@ +""" +# StateSync + +State sync is how SQLMesh keeps track of environments and their states, e.g. snapshots. + +# StateReader + +StateReader provides a subset of the functionalities of the StateSync class. As its name +implies, it only allows for read-only operations on snapshots and environment states. + +# EngineAdapterStateSync + +The provided `sqlmesh.core.state_sync.EngineAdapterStateSync` leverages an existing engine +adapter to read and write state to the underlying data store. +""" + +from __future__ import annotations + +import contextlib +import logging +import typing as t +from pathlib import Path +from datetime import datetime + + +from sqlmesh.core.console import Console, get_console +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.environment import Environment, EnvironmentStatements, EnvironmentSummary +from sqlmesh.core.snapshot import ( + Snapshot, + SnapshotIdAndVersion, + SnapshotId, + SnapshotIdLike, + SnapshotIdAndVersionLike, + SnapshotInfoLike, + SnapshotIntervals, + SnapshotNameVersion, + SnapshotTableInfo, + start_date, +) +from sqlmesh.core.snapshot.definition import ( + Interval, +) +from sqlmesh.core.state_sync.base import ( + StateSync, + Versions, +) +from sqlmesh.core.state_sync.common import ( + EnvironmentsChunk, + SnapshotsChunk, + VersionsChunk, + transactional, + StateStream, + chunk_iterable, + EnvironmentWithStatements, + ExpiredSnapshotBatch, + PromotionResult, + ExpiredBatchRange, +) +from sqlmesh.core.state_sync.db.interval import IntervalState +from sqlmesh.core.state_sync.db.environment import EnvironmentState +from sqlmesh.core.state_sync.db.snapshot import SnapshotState +from sqlmesh.core.state_sync.db.version import VersionState +from sqlmesh.core.state_sync.db.migrator import StateMigrator, _backup_table_name +from sqlmesh.utils.date 
import TimeLike, to_timestamp, time_like_to_str, now_timestamp +from sqlmesh.utils.errors import ConflictingPlanError, SQLMeshError + +logger = logging.getLogger(__name__) + + +T = t.TypeVar("T") + + +class EngineAdapterStateSync(StateSync): + """Manages state of nodes and snapshot with an existing engine adapter. + + This state sync is convenient to use because it requires no additional setup. + You can reuse the same engine/warehouse that your data is stored in. + + Args: + engine_adapter: The EngineAdapter to use to store and fetch snapshots. + schema: The schema to store state metadata in. If None or empty string then no schema is defined + console: The console to log information to. + cache_dir: The cache path, used for caching snapshot models. + """ + + def __init__( + self, + engine_adapter: EngineAdapter, + schema: t.Optional[str], + console: t.Optional[Console] = None, + cache_dir: Path = Path(), + ): + self.interval_state = IntervalState(engine_adapter, schema=schema) + self.environment_state = EnvironmentState(engine_adapter, schema=schema) + self.snapshot_state = SnapshotState(engine_adapter, schema=schema, cache_dir=cache_dir) + self.version_state = VersionState(engine_adapter, schema=schema) + self.migrator = StateMigrator( + engine_adapter, + version_state=self.version_state, + snapshot_state=self.snapshot_state, + environment_state=self.environment_state, + interval_state=self.interval_state, + console=console, + ) + # Make sure that if an empty string is provided that we treat it as None + self.schema = schema or None + self.engine_adapter = engine_adapter + self.console = console or get_console() + + @transactional() + def push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: + """Pushes snapshots to the state store, merging them with existing ones. + + This method first finds all existing snapshots in the store and merges them with + the local snapshots. 
It will then delete all existing snapshots and then + insert all the local snapshots. This can be made safer with locks or merge/upsert. + + Args: + snapshots: The snapshots to push. + """ + snapshots_by_id = {} + for snapshot in snapshots: + if not snapshot.version: + raise SQLMeshError( + f"Snapshot {snapshot} has not been versioned yet. Create a plan before pushing a snapshot." + ) + snapshots_by_id[snapshot.snapshot_id] = snapshot + + existing = self.snapshots_exist(snapshots_by_id) + + if existing: + logger.error( + "Snapshots %s already exists. This could be due to a concurrent plan or a hash collision. If this is a hash collision, add a stamp to your model.", + str(existing), + ) + + for sid in tuple(snapshots_by_id): + if sid in existing: + snapshots_by_id.pop(sid) + + snapshots = snapshots_by_id.values() + if snapshots: + self.snapshot_state.push_snapshots(snapshots) + + @transactional() + def promote( + self, + environment: Environment, + no_gaps_snapshot_names: t.Optional[t.Set[str]] = None, + environment_statements: t.Optional[t.List[EnvironmentStatements]] = None, + ) -> PromotionResult: + """Update the environment to reflect the current state. + + This method verifies that snapshots have been pushed. + + Args: + environment: The environment to promote. + no_gaps_snapshot_names: A set of snapshot names to check for data gaps. If None, + all snapshots will be checked. The data gap check ensures that models that are already a + part of the target environment have no data gaps when compared against previous + snapshots for same models. + + Returns: + A tuple of (added snapshot table infos, removed snapshot table infos, and environment target suffix for the removed table infos) + """ + logger.info("Promoting environment '%s'", environment.name) + + missing = {s.snapshot_id for s in environment.snapshots} - self.snapshots_exist( + environment.snapshots + ) + if missing: + raise SQLMeshError( + f"Missing snapshots {missing}. 
Make sure to push and backfill your snapshots." + ) + + existing_environment = self.environment_state.get_environment( + environment.name, lock_for_update=True + ) + + existing_table_infos = ( + {table_info.name: table_info for table_info in existing_environment.promoted_snapshots} + if existing_environment + else {} + ) + table_infos = {table_info.name: table_info for table_info in environment.promoted_snapshots} + views_that_changed_location: t.Set[SnapshotTableInfo] = set() + if existing_environment: + views_that_changed_location = { + existing_table_info + for name, existing_table_info in existing_table_infos.items() + if name in table_infos + and existing_table_info.qualified_view_name.for_environment( + existing_environment.naming_info + ) + != table_infos[name].qualified_view_name.for_environment(environment.naming_info) + } + if not existing_environment.expired: + if environment.previous_plan_id != existing_environment.plan_id: + raise ConflictingPlanError( + f"Another plan ({existing_environment.plan_id}) was applied to the target environment '{environment.name}' while your current plan " + f"({environment.plan_id}) was still in progress, interrupting it. Please re-apply your plan to resolve this error." + ) + if no_gaps_snapshot_names != set(): + snapshots = self.get_snapshots(environment.snapshots).values() + self._ensure_no_gaps( + snapshots, + existing_environment, + no_gaps_snapshot_names, + ) + demoted_snapshots = set(existing_environment.snapshots) - set(environment.snapshots) + # Update the updated_at attribute. + self.snapshot_state.touch_snapshots(demoted_snapshots) + + missing_models = set(existing_table_infos) - { + snapshot.name for snapshot in environment.promoted_snapshots + } + + added_table_infos = set(table_infos.values()) + if existing_environment and environment.can_partially_promote(existing_environment): + # Only promote new snapshots. 
+ added_table_infos -= set(existing_environment.promoted_snapshots) + + self.environment_state.update_environment(environment) + + # If it is an empty list, we want to update the environment statements + # To reflect there are no statements anymore in this environment + if environment_statements is not None: + self.environment_state.update_environment_statements( + environment.name, environment.plan_id, environment_statements + ) + + removed = {existing_table_infos[name] for name in missing_models}.union( + views_that_changed_location + ) + + return PromotionResult( + added=sorted(added_table_infos), + removed=list(removed), + removed_environment_naming_info=( + existing_environment.naming_info if removed and existing_environment else None + ), + ) + + @transactional() + def finalize(self, environment: Environment) -> None: + """Finalize the target environment, indicating that this environment has been + fully promoted and is ready for use. + + Args: + environment: The target environment to finalize. 
+ """ + self.environment_state.finalize(environment) + + @transactional() + def unpause_snapshots( + self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike + ) -> None: + self.snapshot_state.unpause_snapshots(snapshots, unpaused_dt) + + def invalidate_environment(self, name: str, protect_prod: bool = True) -> None: + self.environment_state.invalidate_environment(name, protect_prod) + + def get_expired_snapshots( + self, + *, + batch_range: ExpiredBatchRange, + current_ts: t.Optional[int] = None, + ignore_ttl: bool = False, + ) -> t.Optional[ExpiredSnapshotBatch]: + current_ts = current_ts or now_timestamp() + return self.snapshot_state.get_expired_snapshots( + environments=self.environment_state.get_environments(), + current_ts=current_ts, + ignore_ttl=ignore_ttl, + batch_range=batch_range, + ) + + def get_expired_environments(self, current_ts: int) -> t.List[EnvironmentSummary]: + return self.environment_state.get_expired_environments(current_ts=current_ts) + + @transactional() + def delete_expired_snapshots( + self, + batch_range: ExpiredBatchRange, + ignore_ttl: bool = False, + current_ts: t.Optional[int] = None, + ) -> None: + batch = self.get_expired_snapshots( + ignore_ttl=ignore_ttl, + current_ts=current_ts, + batch_range=batch_range, + ) + if batch and batch.expired_snapshot_ids: + self.snapshot_state.delete_snapshots(batch.expired_snapshot_ids) + self.interval_state.cleanup_intervals(batch.cleanup_tasks, batch.expired_snapshot_ids) + + @transactional() + def delete_expired_environments( + self, current_ts: t.Optional[int] = None + ) -> t.List[EnvironmentSummary]: + current_ts = current_ts or now_timestamp() + return self.environment_state.delete_expired_environments(current_ts=current_ts) + + def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: + self.snapshot_state.delete_snapshots(snapshot_ids) + + def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]: + return 
self.snapshot_state.snapshots_exist(snapshot_ids) + + def nodes_exist(self, names: t.Iterable[str], exclude_external: bool = False) -> t.Set[str]: + return self.snapshot_state.nodes_exist(names, exclude_external) + + def remove_state(self, including_backup: bool = False) -> None: + """Removes the state store objects.""" + for table in ( + self.snapshot_state.snapshots_table, + self.snapshot_state.auto_restatements_table, + self.environment_state.environments_table, + self.environment_state.environment_statements_table, + self.interval_state.intervals_table, + self.version_state.versions_table, + ): + self.engine_adapter.drop_table(table) + if including_backup: + self.engine_adapter.drop_table(_backup_table_name(table)) + + self.snapshot_state.clear_cache() + + def reset(self, default_catalog: t.Optional[str]) -> None: + """Resets the state store to the state when it was first initialized.""" + self.remove_state() + self.migrate(default_catalog) + + @transactional() + def update_auto_restatements( + self, next_auto_restatement_ts: t.Dict[SnapshotNameVersion, t.Optional[int]] + ) -> None: + self.snapshot_state.update_auto_restatements(next_auto_restatement_ts) + + def get_environment(self, environment: str) -> t.Optional[Environment]: + return self.environment_state.get_environment(environment) + + def get_environment_statements(self, environment: str) -> t.List[EnvironmentStatements]: + return self.environment_state.get_environment_statements(environment) + + def get_environments(self) -> t.List[Environment]: + """Fetches all environments. + + Returns: + A list of all environments. + """ + return self.environment_state.get_environments() + + def get_environments_summary(self) -> t.List[EnvironmentSummary]: + """Fetches all environment names along with expiry datetime. + + Returns: + A list of all environment summaries. 
+ """ + return self.environment_state.get_environments_summary() + + def get_snapshots( + self, + snapshot_ids: t.Iterable[SnapshotIdLike], + ) -> t.Dict[SnapshotId, Snapshot]: + """Fetches snapshots from the state. + + Args: + snapshot_ids: The snapshot IDs to fetch. + + Returns: + A dict of snapshots. + """ + snapshots = self.snapshot_state.get_snapshots(snapshot_ids) + intervals = self.interval_state.get_snapshot_intervals(snapshots.values()) + Snapshot.hydrate_with_intervals_by_version(snapshots.values(), intervals) + return snapshots + + def get_snapshots_by_names( + self, + snapshot_names: t.Iterable[str], + current_ts: t.Optional[int] = None, + exclude_expired: bool = True, + ) -> t.Set[SnapshotIdAndVersion]: + return self.snapshot_state.get_snapshots_by_names( + snapshot_names=snapshot_names, current_ts=current_ts, exclude_expired=exclude_expired + ) + + @transactional() + def add_interval( + self, + snapshot: Snapshot, + start: TimeLike, + end: TimeLike, + is_dev: bool = False, + last_altered_ts: t.Optional[int] = None, + ) -> None: + super().add_interval(snapshot, start, end, is_dev, last_altered_ts) + + @transactional() + def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None: + intervals_to_insert = [] + for snapshot_intervals in snapshots_intervals: + snapshot_intervals = snapshot_intervals.copy( + update={ + "intervals": _remove_partial_intervals( + snapshot_intervals.intervals, snapshot_intervals.snapshot_id, is_dev=False + ), + "dev_intervals": _remove_partial_intervals( + snapshot_intervals.dev_intervals, + snapshot_intervals.snapshot_id, + is_dev=True, + ), + } + ) + if not snapshot_intervals.is_empty(): + intervals_to_insert.append(snapshot_intervals) + if intervals_to_insert: + self.interval_state.add_snapshots_intervals(intervals_to_insert) + + @transactional() + def remove_intervals( + self, + snapshot_intervals: t.Sequence[t.Tuple[SnapshotIdAndVersionLike, Interval]], + remove_shared_versions: bool = 
False, + ) -> None: + self.interval_state.remove_intervals(snapshot_intervals, remove_shared_versions) + + @transactional() + def compact_intervals(self) -> None: + self.interval_state.compact_intervals() + + def refresh_snapshot_intervals(self, snapshots: t.Collection[Snapshot]) -> t.List[Snapshot]: + return self.interval_state.refresh_snapshot_intervals(snapshots) + + def max_interval_end_per_model( + self, + environment: str, + models: t.Optional[t.Set[str]] = None, + ensure_finalized_snapshots: bool = False, + ) -> t.Dict[str, int]: + env = self.get_environment(environment) + if not env: + return {} + + snapshots = ( + env.snapshots if not ensure_finalized_snapshots else env.finalized_or_current_snapshots + ) + if models is not None: + snapshots = [s for s in snapshots if s.name in models] + + if not snapshots: + return {} + + return self.interval_state.max_interval_end_per_model(snapshots) + + def recycle(self) -> None: + self.engine_adapter.recycle() + + def close(self) -> None: + self.engine_adapter.close() + + @transactional() + def migrate( + self, + skip_backup: bool = False, + promoted_snapshots_only: bool = True, + ) -> None: + """Migrate the state sync to the latest SQLMesh / SQLGlot version.""" + self.migrator.migrate( + self.schema, + skip_backup=skip_backup, + promoted_snapshots_only=promoted_snapshots_only, + ) + + @transactional() + def rollback(self) -> None: + """Rollback to the previous migration.""" + self.migrator.rollback() + + @transactional() + def export(self, environment_names: t.Optional[t.List[str]] = None) -> StateStream: + versions = self.get_versions( + validate=True + ) # will throw if the state db hasnt been created or there is a version mismatch + + snapshot_ids_to_export: t.Set[SnapshotId] = set() + selected_environments: t.List[Environment] = [] + if environment_names: + for env_name in environment_names: + environment = self.get_environment(env_name) + if not environment: + raise SQLMeshError(f"No such environment: 
{env_name}") + selected_environments.append(environment) + else: + selected_environments = self.get_environments() + + for env in selected_environments: + snapshot_ids_to_export |= set([s.snapshot_id for s in env.snapshots or []]) + + def _export_snapshots() -> t.Iterator[Snapshot]: + for chunk in chunk_iterable(snapshot_ids_to_export, SnapshotState.SNAPSHOT_BATCH_SIZE): + yield from self.get_snapshots(chunk).values() + + def _export_environments() -> t.Iterator[EnvironmentWithStatements]: + for env in selected_environments: + yield EnvironmentWithStatements( + environment=env, statements=self.get_environment_statements(env.name) + ) + + return StateStream.from_iterators( + versions=versions, + snapshots=_export_snapshots(), + environments=_export_environments(), + ) + + @transactional() + def import_(self, stream: StateStream, clear: bool = True) -> None: + existing_versions = self.get_versions() + + for state_chunk in stream: + if isinstance(state_chunk, VersionsChunk): + # SQLMesh major/minor version must match so that we can be sure the JSON contained in the state file + # is compatible with our Pydantic model definitions. Patch versions dont need to match because the assumption + # is that they dont contain any breaking changes + incoming_versions = state_chunk.versions + if ( + incoming_versions.minor_sqlmesh_version + != existing_versions.minor_sqlmesh_version + ): + raise SQLMeshError( + f"SQLMesh version mismatch. You are running '{existing_versions.sqlmesh_version}' but the state file was created with '{incoming_versions.sqlmesh_version}'.\n" + "Please upgrade/downgrade your SQLMesh version to match the state file before performing the import." 
+ ) + + if clear: + self.reset(default_catalog=None) + + if isinstance(state_chunk, SnapshotsChunk): + auto_restatements: t.Dict[SnapshotNameVersion, t.Optional[int]] = {} + + for snapshot_chunk in chunk_iterable( + state_chunk, SnapshotState.SNAPSHOT_BATCH_SIZE + ): + snapshot_chunk = list(snapshot_chunk) + overwrite_existing_snapshots = ( + not clear + ) # if clear=True, all existing snapshots were dropped anyway + self.snapshot_state.push_snapshots( + snapshot_chunk, overwrite=overwrite_existing_snapshots + ) + self.add_snapshots_intervals((s.snapshot_intervals for s in snapshot_chunk)) + + auto_restatements.update( + { + s.name_version: s.next_auto_restatement_ts + for s in snapshot_chunk + if s.next_auto_restatement_ts + } + ) + + self.update_auto_restatements(auto_restatements) + + if isinstance(state_chunk, EnvironmentsChunk): + for environment_with_statements in state_chunk: + environment = environment_with_statements.environment + self.environment_state.update_environment(environment) + self.environment_state.update_environment_statements( + environment.name, + environment.plan_id, + environment_with_statements.statements, + ) + + def state_type(self) -> str: + return self.engine_adapter.dialect + + def _get_versions(self) -> Versions: + return self.version_state.get_versions() + + def _ensure_no_gaps( + self, + target_snapshots: t.Iterable[Snapshot], + target_environment: Environment, + snapshot_names: t.Optional[t.Set[str]], + ) -> None: + target_snapshots_by_name = {s.name: s for s in target_snapshots} + + changed_version_prev_snapshots_by_name = { + s.name: s + for s in target_environment.snapshots + if s.name in target_snapshots_by_name + and target_snapshots_by_name[s.name].version != s.version + } + + prev_snapshots = self.get_snapshots( + changed_version_prev_snapshots_by_name.values() + ).values() + cache: t.Dict[str, datetime] = {} + + for prev_snapshot in prev_snapshots: + target_snapshot = target_snapshots_by_name[prev_snapshot.name] + if ( + 
(snapshot_names is None or prev_snapshot.name in snapshot_names) + and target_snapshot.is_incremental + and prev_snapshot.is_incremental + and prev_snapshot.intervals + ): + start = to_timestamp( + start_date(target_snapshot, target_snapshots_by_name.values(), cache) + ) + end = prev_snapshot.intervals[-1][1] + + if start < end: + missing_intervals = target_snapshot.missing_intervals( + start, end, end_bounded=True + ) + + if missing_intervals: + raise SQLMeshError( + f"Detected missing intervals for model {target_snapshot.name}, interrupting your current plan. " + "Please re-apply your plan to resolve this error." + ) + + @contextlib.contextmanager + def _transaction(self) -> t.Iterator[None]: + with self.engine_adapter.transaction(): + yield + + +def _remove_partial_intervals( + intervals: t.List[Interval], snapshot_id: t.Optional[SnapshotId], *, is_dev: bool +) -> t.List[Interval]: + results = [] + for start_ts, end_ts in intervals: + if start_ts < end_ts: + logger.info( + "Adding %s (%s, %s) for snapshot %s", + "dev interval" if is_dev else "interval", + time_like_to_str(start_ts), + time_like_to_str(end_ts), + snapshot_id, + ) + results.append((start_ts, end_ts)) + else: + logger.info( + "Skipping partial interval (%s, %s) for snapshot %s", + start_ts, + end_ts, + snapshot_id, + ) + return results diff --git a/sqlmesh/core/state_sync/db/interval.py b/sqlmesh/core/state_sync/db/interval.py new file mode 100644 index 0000000000..8ccdc58fa0 --- /dev/null +++ b/sqlmesh/core/state_sync/db/interval.py @@ -0,0 +1,495 @@ +from __future__ import annotations + +import typing as t +import logging + +from sqlglot import exp + +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.state_sync.db.utils import ( + snapshot_name_version_filter, + snapshot_id_filter, + create_batches, + fetchall, +) +from sqlmesh.core.snapshot import ( + SnapshotIntervals, + SnapshotIdLike, + SnapshotIdAndVersionLike, + SnapshotNameVersionLike, + SnapshotTableCleanupTask, + 
SnapshotNameVersion, + Snapshot, +) +from sqlmesh.core.snapshot.definition import Interval +from sqlmesh.utils.migration import index_text_type +from sqlmesh.utils import random_id +from sqlmesh.utils.date import now_timestamp + +if t.TYPE_CHECKING: + import pandas as pd + + +logger = logging.getLogger(__name__) + + +class IntervalState: + INTERVAL_BATCH_SIZE = 1000 + SNAPSHOT_BATCH_SIZE = 1000 + + def __init__( + self, + engine_adapter: EngineAdapter, + schema: t.Optional[str] = None, + table_name: t.Optional[str] = None, + ): + self.engine_adapter = engine_adapter + self.intervals_table = exp.table_(table_name or "_intervals", db=schema) + + index_type = index_text_type(engine_adapter.dialect) + self._interval_columns_to_types = { + "id": exp.DataType.build(index_type), + "created_ts": exp.DataType.build("bigint"), + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build("text"), + "version": exp.DataType.build(index_type), + "dev_version": exp.DataType.build(index_type), + "start_ts": exp.DataType.build("bigint"), + "end_ts": exp.DataType.build("bigint"), + "is_dev": exp.DataType.build("boolean"), + "is_removed": exp.DataType.build("boolean"), + "is_compacted": exp.DataType.build("boolean"), + "is_pending_restatement": exp.DataType.build("boolean"), + "last_altered_ts": exp.DataType.build("bigint"), + } + + def add_snapshots_intervals(self, snapshots_intervals: t.Sequence[SnapshotIntervals]) -> None: + if snapshots_intervals: + self._push_snapshot_intervals(snapshots_intervals) + + def remove_intervals( + self, + snapshot_intervals: t.Sequence[t.Tuple[SnapshotIdAndVersionLike, Interval]], + remove_shared_versions: bool = False, + ) -> None: + intervals_to_remove: t.Sequence[ + t.Tuple[t.Union[SnapshotIdAndVersionLike, SnapshotIntervals], Interval] + ] = snapshot_intervals + if remove_shared_versions: + name_version_mapping = {s.name_version: interval for s, interval in snapshot_intervals} + all_snapshots = [] + for where in 
snapshot_name_version_filter( + self.engine_adapter, + name_version_mapping, + alias=None, + batch_size=self.SNAPSHOT_BATCH_SIZE, + ): + all_snapshots.extend( + [ + SnapshotIntervals( + name=r[0], + identifier=r[1], + version=r[2], + dev_version=r[3], + intervals=[], + dev_intervals=[], + ) + for r in fetchall( + self.engine_adapter, + exp.select("name", "identifier", "version", "dev_version") + .from_(self.intervals_table) + .where(where) + .distinct(), + ) + ] + ) + intervals_to_remove = [ + (snapshot, name_version_mapping[snapshot.name_version]) + for snapshot in all_snapshots + ] + + if logger.isEnabledFor(logging.INFO): + snapshot_ids = ", ".join(str(s.snapshot_id) for s, _ in intervals_to_remove) + logger.info("Removing interval for snapshots: %s", snapshot_ids) + + self.engine_adapter.insert_append( + self.intervals_table, + _intervals_to_df(intervals_to_remove, is_dev=False, is_removed=True), + target_columns_to_types=self._interval_columns_to_types, + track_rows_processed=False, + ) + + def get_snapshot_intervals( + self, snapshots: t.Collection[SnapshotNameVersionLike] + ) -> t.List[SnapshotIntervals]: + return self._get_snapshot_intervals(snapshots)[1] + + def compact_intervals(self) -> None: + interval_ids, snapshot_intervals = self._get_snapshot_intervals(uncompacted_only=True) + + logger.info( + "Compacting %s intervals for %s snapshots", len(interval_ids), len(snapshot_intervals) + ) + + self._push_snapshot_intervals(snapshot_intervals, is_compacted=True) + + if interval_ids: + for interval_id_batch in create_batches( + list(interval_ids), batch_size=self.INTERVAL_BATCH_SIZE + ): + self.engine_adapter.delete_from( + self.intervals_table, exp.column("id").isin(*interval_id_batch) + ) + + def refresh_snapshot_intervals(self, snapshots: t.Collection[Snapshot]) -> t.List[Snapshot]: + if not snapshots: + return [] + + _, intervals = self._get_snapshot_intervals([s for s in snapshots if s.version]) + for s in snapshots: + s.intervals = [] + s.dev_intervals 
= [] + return Snapshot.hydrate_with_intervals_by_version(snapshots, intervals) + + def max_interval_end_per_model( + self, snapshots: t.Collection[SnapshotNameVersionLike] + ) -> t.Dict[str, int]: + if not snapshots: + return {} + + table_alias = "intervals" + name_col = exp.column("name", table=table_alias) + version_col = exp.column("version", table=table_alias) + + result: t.Dict[str, int] = {} + + for where in snapshot_name_version_filter( + self.engine_adapter, snapshots, alias=table_alias, batch_size=self.SNAPSHOT_BATCH_SIZE + ): + query = ( + exp.select( + name_col, + exp.func("MAX", exp.column("end_ts", table=table_alias)).as_("max_end_ts"), + ) + .from_(exp.to_table(self.intervals_table).as_(table_alias)) + .where(where, copy=False) + .where( + exp.and_( + exp.to_column("is_dev").not_(), + exp.to_column("is_removed").not_(), + exp.to_column("is_pending_restatement").not_(), + ), + copy=False, + ) + .group_by(name_col, version_col, copy=False) + ) + + for name, max_end in fetchall(self.engine_adapter, query): + result[name] = max_end + + return result + + def cleanup_intervals( + self, + cleanup_targets: t.List[SnapshotTableCleanupTask], + expired_snapshot_ids: t.Collection[SnapshotIdLike], + ) -> None: + # Cleanup can only happen for compacted intervals + self.compact_intervals() + # Delete intervals for non-dev tables that are no longer used + self._delete_intervals_by_version(cleanup_targets) + # Delete dev intervals for dev tables that are no longer used + self._delete_intervals_by_dev_version(cleanup_targets) + # Nullify the snapshot identifiers of interval records for snapshots that have been deleted + self._update_intervals_for_deleted_snapshots(expired_snapshot_ids) + + def _push_snapshot_intervals( + self, + snapshots: t.Iterable[t.Union[Snapshot, SnapshotIntervals]], + is_compacted: bool = False, + ) -> None: + import pandas as pd + + new_intervals = [] + for snapshot in snapshots: + logger.info("Pushing intervals for snapshot %s", 
snapshot.snapshot_id) + for start_ts, end_ts in snapshot.intervals: + new_intervals.append( + _interval_to_df( + snapshot, + start_ts, + end_ts, + is_dev=False, + is_compacted=is_compacted, + last_altered_ts=snapshot.last_altered_ts, + ) + ) + for start_ts, end_ts in snapshot.dev_intervals: + new_intervals.append( + _interval_to_df( + snapshot, + start_ts, + end_ts, + is_dev=True, + is_compacted=is_compacted, + last_altered_ts=snapshot.dev_last_altered_ts, + ) + ) + + # Make sure that all pending restatement intervals are recorded last + for snapshot in snapshots: + for start_ts, end_ts in snapshot.pending_restatement_intervals: + new_intervals.append( + _interval_to_df( + snapshot, + start_ts, + end_ts, + is_dev=False, + is_compacted=is_compacted, + is_pending_restatement=True, + last_altered_ts=snapshot.last_altered_ts, + ) + ) + + if new_intervals: + self.engine_adapter.insert_append( + self.intervals_table, + pd.DataFrame(new_intervals), + target_columns_to_types=self._interval_columns_to_types, + track_rows_processed=False, + ) + + def _get_snapshot_intervals( + self, + snapshots: t.Optional[t.Collection[SnapshotNameVersionLike]] = None, + uncompacted_only: bool = False, + ) -> t.Tuple[t.Set[str], t.List[SnapshotIntervals]]: + if not snapshots and snapshots is not None: + return (set(), []) + + query = self._get_snapshot_intervals_query(uncompacted_only) + + interval_ids: t.Set[str] = set() + intervals: t.Dict[ + t.Tuple[str, str, t.Optional[str], t.Optional[str]], SnapshotIntervals + ] = {} + + for where in ( + snapshot_name_version_filter( + self.engine_adapter, + snapshots, + alias="intervals", + batch_size=self.SNAPSHOT_BATCH_SIZE, + ) + if snapshots + else [None] + ): + rows = fetchall(self.engine_adapter, query.where(where)) + for ( + interval_id, + name, + identifier, + version, + dev_version, + start, + end, + is_dev, + is_removed, + is_pending_restatement, + last_altered_ts, + ) in rows: + interval_ids.add(interval_id) + merge_key = (name, version, 
dev_version, identifier) + # Pending restatement intervals are merged by name and version + pending_restatement_interval_merge_key = (name, version, None, None) + + if merge_key not in intervals: + intervals[merge_key] = SnapshotIntervals( + name=name, + identifier=identifier, + version=version, + dev_version=dev_version, + ) + + if pending_restatement_interval_merge_key not in intervals: + intervals[pending_restatement_interval_merge_key] = SnapshotIntervals( + name=name, + identifier=None, + version=version, + dev_version=None, + ) + + if is_removed: + if is_dev: + intervals[merge_key].remove_dev_interval(start, end) + else: + intervals[merge_key].remove_interval(start, end) + elif is_pending_restatement: + intervals[ + pending_restatement_interval_merge_key + ].add_pending_restatement_interval(start, end) + else: + if is_dev: + intervals[merge_key].add_dev_interval(start, end) + intervals[merge_key].update_dev_last_altered_ts(last_altered_ts) + else: + intervals[merge_key].add_interval(start, end) + intervals[merge_key].update_last_altered_ts(last_altered_ts) + # Remove all pending restatement intervals recorded before the current interval has been added + intervals[ + pending_restatement_interval_merge_key + ].remove_pending_restatement_interval(start, end) + + return interval_ids, [i for i in intervals.values() if not i.is_empty()] + + def _get_snapshot_intervals_query(self, uncompacted_only: bool) -> exp.Select: + query = ( + exp.select( + "id", + exp.column("name", table="intervals"), + exp.column("identifier", table="intervals"), + exp.column("version", table="intervals"), + exp.column("dev_version", table="intervals"), + "start_ts", + "end_ts", + "is_dev", + "is_removed", + "is_pending_restatement", + "last_altered_ts", + ) + .from_(exp.to_table(self.intervals_table).as_("intervals")) + .order_by( + exp.column("name", table="intervals"), + exp.column("version", table="intervals"), + "created_ts", + "is_removed", + "is_pending_restatement", + ) + ) + if 
uncompacted_only: + query.join( + exp.select("name", "version") + .from_(exp.to_table(self.intervals_table).as_("intervals")) + .where(exp.column("is_compacted").not_()) + .distinct() + .subquery(alias="uncompacted"), + on=exp.and_( + exp.column("name", table="intervals").eq( + exp.column("name", table="uncompacted") + ), + exp.column("version", table="intervals").eq( + exp.column("version", table="uncompacted") + ), + ), + copy=False, + ) + return query + + def _update_intervals_for_deleted_snapshots( + self, snapshot_ids: t.Collection[SnapshotIdLike] + ) -> None: + """Nullifies the snapshot identifiers of dev interval records and snapshot identifiers and dev versions of + non-dev interval records for snapshots that have been deleted so that they can be compacted efficiently. + """ + if not snapshot_ids: + return + + for where in snapshot_id_filter( + self.engine_adapter, snapshot_ids, alias=None, batch_size=self.SNAPSHOT_BATCH_SIZE + ): + # Nullify the identifier for dev intervals + # Set is_compacted to False so that it's compacted during the next compaction + self.engine_adapter.update_table( + self.intervals_table, + {"identifier": None, "is_compacted": False}, + where=where.and_(exp.column("is_dev")), + ) + # Nullify both identifier and dev version for non-dev intervals + # Set is_compacted to False so that it's compacted during the next compaction + self.engine_adapter.update_table( + self.intervals_table, + {"identifier": None, "dev_version": None, "is_compacted": False}, + where=where.and_(exp.column("is_dev").not_()), + ) + + def _delete_intervals_by_dev_version(self, targets: t.List[SnapshotTableCleanupTask]) -> None: + """Deletes dev intervals for snapshot dev versions that are no longer used.""" + dev_keys_to_delete = [ + SnapshotNameVersion(name=t.snapshot.name, version=t.snapshot.dev_version) + for t in targets + if t.dev_table_only + ] + if not dev_keys_to_delete: + return + + for where in snapshot_name_version_filter( + self.engine_adapter, + 
dev_keys_to_delete, + version_column_name="dev_version", + alias=None, + batch_size=self.SNAPSHOT_BATCH_SIZE, + ): + self.engine_adapter.delete_from(self.intervals_table, where.and_(exp.column("is_dev"))) + + def _delete_intervals_by_version(self, targets: t.List[SnapshotTableCleanupTask]) -> None: + """Deletes intervals for snapshot versions that are no longer used.""" + non_dev_keys_to_delete = [t.snapshot for t in targets if not t.dev_table_only] + if not non_dev_keys_to_delete: + return + + for where in snapshot_name_version_filter( + self.engine_adapter, + non_dev_keys_to_delete, + alias=None, + batch_size=self.SNAPSHOT_BATCH_SIZE, + ): + self.engine_adapter.delete_from(self.intervals_table, where) + + +def _intervals_to_df( + snapshot_intervals: t.Sequence[ + t.Tuple[t.Union[SnapshotIdAndVersionLike, SnapshotIntervals], Interval] + ], + is_dev: bool, + is_removed: bool, +) -> pd.DataFrame: + import pandas as pd + + return pd.DataFrame( + [ + _interval_to_df( + s, + *interval, + is_dev=is_dev, + is_removed=is_removed, + ) + for s, interval in snapshot_intervals + ] + ) + + +def _interval_to_df( + snapshot: t.Union[SnapshotIdAndVersionLike, SnapshotIntervals], + start_ts: int, + end_ts: int, + is_dev: bool = False, + is_removed: bool = False, + is_compacted: bool = False, + is_pending_restatement: bool = False, + last_altered_ts: t.Optional[int] = None, +) -> t.Dict[str, t.Any]: + return { + "id": random_id(), + "created_ts": now_timestamp(), + "name": snapshot.name, + "identifier": snapshot.identifier if not is_pending_restatement else None, + "version": snapshot.version, + "dev_version": snapshot.dev_version if not is_pending_restatement else None, + "start_ts": start_ts, + "end_ts": end_ts, + "is_dev": is_dev, + "is_removed": is_removed, + "is_compacted": is_compacted, + "is_pending_restatement": is_pending_restatement, + "last_altered_ts": last_altered_ts, + } diff --git a/sqlmesh/core/state_sync/db/migrator.py b/sqlmesh/core/state_sync/db/migrator.py new 
file mode 100644 index 0000000000..8d73e1d395 --- /dev/null +++ b/sqlmesh/core/state_sync/db/migrator.py @@ -0,0 +1,483 @@ +from __future__ import annotations + +import json +import logging +import time +import typing as t +from copy import deepcopy + +from sqlglot import __version__ as SQLGLOT_VERSION +from sqlglot import exp + +from sqlmesh.core import analytics +from sqlmesh.core import constants as c +from sqlmesh.core.console import Console, get_console +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.environment import Environment +from sqlmesh.core.snapshot import ( + Node, + Snapshot, + SnapshotFingerprint, + SnapshotId, + SnapshotTableInfo, + fingerprint_from_node, +) +from sqlmesh.core.snapshot.definition import ( + _parents_from_node, +) +from sqlmesh.core.state_sync.base import ( + MIGRATIONS, + MIN_SCHEMA_VERSION, + MIN_SQLMESH_VERSION, +) +from sqlmesh.core.state_sync.db.environment import EnvironmentState +from sqlmesh.core.state_sync.db.interval import IntervalState +from sqlmesh.core.state_sync.db.snapshot import SnapshotState +from sqlmesh.core.state_sync.db.version import VersionState +from sqlmesh.core.state_sync.db.utils import ( + SQLMESH_VERSION, + snapshot_id_filter, + fetchall, +) +from sqlmesh.utils import major_minor +from sqlmesh.utils.dag import DAG +from sqlmesh.utils.date import now_timestamp +from sqlmesh.utils.errors import SQLMeshError, StateMigrationError + +logger = logging.getLogger(__name__) + + +if t.TYPE_CHECKING: + from sqlmesh.core._typing import TableName + + +class StateMigrator: + SNAPSHOT_BATCH_SIZE = 1000 + SNAPSHOT_MIGRATION_BATCH_SIZE = 500 + + def __init__( + self, + engine_adapter: EngineAdapter, + version_state: VersionState, + snapshot_state: SnapshotState, + environment_state: EnvironmentState, + interval_state: IntervalState, + console: t.Optional[Console] = None, + ): + self.engine_adapter = engine_adapter + self.console = console or get_console() + self.version_state = version_state + 
self.snapshot_state = snapshot_state + self.environment_state = environment_state + self.interval_state = interval_state + + self._state_tables = [ + self.snapshot_state.snapshots_table, + self.environment_state.environments_table, + self.version_state.versions_table, + ] + self._optional_state_tables = [ + self.interval_state.intervals_table, + self.snapshot_state.auto_restatements_table, + self.environment_state.environment_statements_table, + ] + + def migrate( + self, + schema: t.Optional[str], + skip_backup: bool = False, + promoted_snapshots_only: bool = True, + ) -> None: + """Migrate the state sync to the latest SQLMesh / SQLGlot version.""" + versions = self.version_state.get_versions() + migration_start_ts = time.perf_counter() + + try: + migrate_rows = self._apply_migrations(schema, skip_backup) + + if not migrate_rows and major_minor(SQLMESH_VERSION) == versions.minor_sqlmesh_version: + return + + if migrate_rows: + self._migrate_rows(promoted_snapshots_only) + self.version_state.update_versions() + + analytics.collector.on_migration_end( + from_sqlmesh_version=versions.sqlmesh_version, + state_sync_type=self.engine_adapter.dialect, + migration_time_sec=time.perf_counter() - migration_start_ts, + ) + except Exception as e: + if skip_backup: + logger.error("Backup was skipped so no rollback was attempted.") + else: + self.rollback() + + analytics.collector.on_migration_end( + from_sqlmesh_version=versions.sqlmesh_version, + state_sync_type=self.engine_adapter.dialect, + migration_time_sec=time.perf_counter() - migration_start_ts, + error=e, + ) + + self.console.log_migration_status(success=False) + if isinstance(e, StateMigrationError): + raise + raise SQLMeshError("SQLMesh migration failed.") from e + + self.console.log_migration_status() + + def rollback(self) -> None: + """Rollback to the previous migration.""" + logger.info("Starting migration rollback.") + versions = self.version_state.get_versions() + if versions.schema_version == 0: + # Clean up 
state tables + for table in self._state_tables + self._optional_state_tables: + self.engine_adapter.drop_table(table) + else: + if not all( + self.engine_adapter.table_exists(_backup_table_name(table)) + for table in self._state_tables + ): + raise SQLMeshError("There are no prior migrations to roll back to.") + for table in self._state_tables: + self._restore_table(table, _backup_table_name(table)) + + for optional_table in self._optional_state_tables: + if self.engine_adapter.table_exists(_backup_table_name(optional_table)): + self._restore_table(optional_table, _backup_table_name(optional_table)) + + logger.info("Migration rollback successful.") + + def _apply_migrations( + self, + schema: t.Optional[str], + skip_backup: bool, + ) -> bool: + versions = self.version_state.get_versions() + first_script_index = 0 + if versions.schema_version and versions.schema_version < MIN_SCHEMA_VERSION: + raise StateMigrationError( + "The current state belongs to an old version of SQLMesh that is no longer supported. " + f"Please upgrade to {MIN_SQLMESH_VERSION} first before upgrading to {SQLMESH_VERSION}." 
+ ) + elif versions.schema_version > 0: + # -1 to skip the baseline migration script + first_script_index = versions.schema_version - (MIN_SCHEMA_VERSION - 1) + + migrations = MIGRATIONS[first_script_index:] + should_backup = any( + [ + migrations, + major_minor(SQLGLOT_VERSION) != versions.minor_sqlglot_version, + major_minor(SQLMESH_VERSION) != versions.minor_sqlmesh_version, + ] + ) + if not skip_backup and should_backup: + self._backup_state() + + snapshot_count_before = self.snapshot_state.count() if versions.schema_version else None + + state_table_exist = any(self.engine_adapter.table_exists(t) for t in self._state_tables) + + for migration in migrations: + logger.info(f"Applying migration {migration}") + migration.migrate_schemas(engine_adapter=self.engine_adapter, schema=schema) + if state_table_exist: + # No need to run DML for the initial migration since all tables are empty + migration.migrate_rows(engine_adapter=self.engine_adapter, schema=schema) + + snapshot_count_after = self.snapshot_state.count() + + if snapshot_count_before is not None and snapshot_count_before != snapshot_count_after: + scripts = f"{versions.schema_version} - {versions.schema_version + len(migrations)}" + raise SQLMeshError( + f"Number of snapshots before ({snapshot_count_before}) and after " + f"({snapshot_count_after}) applying migration scripts {scripts} does not match. " + "Please file an issue issue at https://github.com/SQLMesh/sqlmesh/issues/new." + ) + + migrate_snapshots_and_environments = ( + bool(migrations) or major_minor(SQLGLOT_VERSION) != versions.minor_sqlglot_version + ) + return migrate_snapshots_and_environments + + def _migrate_rows(self, promoted_snapshots_only: bool) -> None: + logger.info("Fetching environments") + environments = self.environment_state.get_environments() + # Only migrate snapshots that are part of at least one environment. 
+ snapshots_to_migrate = ( + {s.snapshot_id for e in environments for s in e.snapshots} + if promoted_snapshots_only + else None + ) + snapshot_mapping = self._migrate_snapshot_rows(snapshots_to_migrate) + if not snapshot_mapping: + logger.info("No changes to snapshots detected") + return + self._migrate_environment_rows(environments, snapshot_mapping) + + def _migrate_snapshot_rows( + self, snapshots: t.Optional[t.Set[SnapshotId]] + ) -> t.Dict[SnapshotId, SnapshotTableInfo]: + logger.info("Migrating snapshot rows...") + raw_snapshots = { + SnapshotId(name=name, identifier=identifier): { + **json.loads(raw_snapshot), + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "unrestorable": unrestorable, + "forward_only": forward_only, + } + for where in ( + snapshot_id_filter( + self.engine_adapter, snapshots, batch_size=self.SNAPSHOT_BATCH_SIZE + ) + if snapshots is not None + else [None] + ) + for name, identifier, raw_snapshot, updated_ts, unpaused_ts, unrestorable, forward_only in fetchall( + self.engine_adapter, + exp.select( + "name", + "identifier", + "snapshot", + "updated_ts", + "unpaused_ts", + "unrestorable", + "forward_only", + ) + .from_(self.snapshot_state.snapshots_table) + .where(where) + .lock(), + ) + } + if not raw_snapshots: + return {} + + dag: DAG[SnapshotId] = DAG() + for snapshot_id, raw_snapshot in raw_snapshots.items(): + parent_ids = [SnapshotId.parse_obj(p_id) for p_id in raw_snapshot.get("parents", [])] + dag.add(snapshot_id, [p_id for p_id in parent_ids if p_id in raw_snapshots]) + + reversed_dag_raw = dag.reversed.graph + + self.console.start_snapshot_migration_progress(len(raw_snapshots)) + + parsed_snapshots = LazilyParsedSnapshots(raw_snapshots) + all_snapshot_mapping: t.Dict[SnapshotId, SnapshotTableInfo] = {} + snapshot_id_mapping: t.Dict[SnapshotId, SnapshotId] = {} + new_snapshots: t.Dict[SnapshotId, Snapshot] = {} + visited: t.Set[SnapshotId] = set() + + def _push_new_snapshots() -> None: + all_snapshot_mapping.update( + { 
+ from_id: new_snapshots[to_id].table_info + for from_id, to_id in snapshot_id_mapping.items() + } + ) + + existing_new_snapshots = self.snapshot_state.snapshots_exist(new_snapshots) + new_snapshots_to_push = [ + s for s in new_snapshots.values() if s.snapshot_id not in existing_new_snapshots + ] + if new_snapshots_to_push: + logger.info("Pushing %s migrated snapshots", len(new_snapshots_to_push)) + self._push_snapshots(new_snapshots_to_push) + new_snapshots.clear() + snapshot_id_mapping.clear() + + def _visit( + snapshot_id: SnapshotId, fingerprint_cache: t.Dict[str, SnapshotFingerprint] + ) -> None: + if snapshot_id in visited or snapshot_id not in raw_snapshots: + return + visited.add(snapshot_id) + + snapshot = parsed_snapshots[snapshot_id] + node = snapshot.node + + node_seen = set() + node_queue = {snapshot_id} + nodes: t.Dict[str, Node] = {} + while node_queue: + next_snapshot_id = node_queue.pop() + next_snapshot = parsed_snapshots.get(next_snapshot_id) + + if next_snapshot_id in node_seen or not next_snapshot: + continue + + node_seen.add(next_snapshot_id) + node_queue.update(next_snapshot.parents) + nodes[next_snapshot.name] = next_snapshot.node + + new_snapshot = deepcopy(snapshot) + try: + new_snapshot.fingerprint = fingerprint_from_node( + node, + nodes=nodes, + cache=fingerprint_cache, + ) + new_snapshot.parents = tuple( + SnapshotId( + name=parent_node.fqn, + identifier=fingerprint_from_node( + parent_node, + nodes=nodes, + cache=fingerprint_cache, + ).to_identifier(), + ) + for parent_node in _parents_from_node(node, nodes).values() + ) + except Exception: + logger.exception("Could not compute fingerprint for %s", snapshot.snapshot_id) + return + + # Reset the effective_from date for the new snapshot to avoid unexpected backfills. 
+ new_snapshot.effective_from = None + new_snapshot.previous_versions = snapshot.all_versions + new_snapshot.migrated = True + if not new_snapshot.dev_version_: + new_snapshot.dev_version_ = snapshot.dev_version + + self.console.update_snapshot_migration_progress(1) + + # Visit children and evict them from the parsed_snapshots cache after. + for child in reversed_dag_raw.get(snapshot_id, []): + # Make sure to copy the fingerprint cache to avoid sharing it between different child snapshots with the same name. + _visit(child, fingerprint_cache.copy()) + parsed_snapshots.evict(child) + + if new_snapshot.fingerprint == snapshot.fingerprint: + logger.debug(f"{new_snapshot.snapshot_id} is unchanged.") + return + + new_snapshot_id = new_snapshot.snapshot_id + + if new_snapshot_id in raw_snapshots: + # Mapped to an existing snapshot. + new_snapshots[new_snapshot_id] = parsed_snapshots[new_snapshot_id] + logger.debug("Migrated snapshot %s already exists", new_snapshot_id) + elif ( + new_snapshot_id not in new_snapshots + or new_snapshot.updated_ts > new_snapshots[new_snapshot_id].updated_ts + ): + new_snapshots[new_snapshot_id] = new_snapshot + + snapshot_id_mapping[snapshot.snapshot_id] = new_snapshot_id + logger.debug("%s mapped to %s", snapshot.snapshot_id, new_snapshot_id) + + if len(new_snapshots) >= self.SNAPSHOT_MIGRATION_BATCH_SIZE: + _push_new_snapshots() + + for root_snapshot_id in dag.roots: + _visit(root_snapshot_id, {}) + + if new_snapshots: + _push_new_snapshots() + + self.console.stop_snapshot_migration_progress() + return all_snapshot_mapping + + def _migrate_environment_rows( + self, + environments: t.List[Environment], + snapshot_mapping: t.Dict[SnapshotId, SnapshotTableInfo], + ) -> None: + logger.info("Migrating environment rows...") + + updated_prod_environment: t.Optional[Environment] = None + updated_environments = [] + for environment in environments: + snapshots = [ + ( + snapshot_mapping[info.snapshot_id] + if info.snapshot_id in snapshot_mapping + 
else info + ) + for info in environment.snapshots + ] + + if snapshots != environment.snapshots: + environment.snapshots_ = snapshots + updated_environments.append(environment) + if environment.name == c.PROD: + updated_prod_environment = environment + self.console.start_env_migration_progress(len(updated_environments)) + + for environment in updated_environments: + self._update_environment(environment) + self.console.update_env_migration_progress(1) + + if updated_prod_environment: + try: + self.snapshot_state.unpause_snapshots( + updated_prod_environment.snapshots, now_timestamp() + ) + except Exception: + logger.warning("Failed to unpause migrated snapshots", exc_info=True) + + self.console.stop_env_migration_progress() + + def _backup_state(self) -> None: + for table in [ + *self._state_tables, + *self._optional_state_tables, + ]: + if self.engine_adapter.table_exists(table): + with self.engine_adapter.transaction(): + backup_name = _backup_table_name(table) + self.engine_adapter.drop_table(backup_name) + self.engine_adapter.create_table_like(backup_name, table) + self.engine_adapter.insert_append( + backup_name, exp.select("*").from_(table), track_rows_processed=False + ) + + def _restore_table( + self, + table_name: TableName, + backup_table_name: TableName, + ) -> None: + self.engine_adapter.drop_table(table_name) + self.engine_adapter.rename_table( + old_table_name=backup_table_name, + new_table_name=table_name, + ) + + def _update_environment(self, environment: Environment) -> None: + self.environment_state.update_environment(environment) + + def _push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: + self.snapshot_state.push_snapshots(snapshots) + + +def _backup_table_name(table_name: TableName) -> exp.Table: + table = exp.to_table(table_name).copy() + table.set("this", exp.to_identifier(table.name + "_backup")) + return table + + +class LazilyParsedSnapshots: + def __init__(self, raw_snapshots: t.Dict[SnapshotId, t.Dict[str, t.Any]]): + 
self._raw_snapshots = raw_snapshots + self._parsed_snapshots: t.Dict[SnapshotId, t.Optional[Snapshot]] = {} + + def get(self, snapshot_id: SnapshotId) -> t.Optional[Snapshot]: + if snapshot_id not in self._parsed_snapshots: + raw_snapshot = self._raw_snapshots.get(snapshot_id) + if raw_snapshot: + self._parsed_snapshots[snapshot_id] = Snapshot.parse_obj(raw_snapshot) + else: + self._parsed_snapshots[snapshot_id] = None + return self._parsed_snapshots[snapshot_id] + + def evict(self, snapshot_id: SnapshotId) -> None: + self._parsed_snapshots.pop(snapshot_id, None) + + def __getitem__(self, snapshot_id: SnapshotId) -> Snapshot: + snapshot = self.get(snapshot_id) + if snapshot is None: + raise KeyError(snapshot_id) + return snapshot diff --git a/sqlmesh/core/state_sync/db/snapshot.py b/sqlmesh/core/state_sync/db/snapshot.py new file mode 100644 index 0000000000..d584c69d65 --- /dev/null +++ b/sqlmesh/core/state_sync/db/snapshot.py @@ -0,0 +1,790 @@ +from __future__ import annotations + +import typing as t +import json +import logging +from pathlib import Path +from collections import defaultdict +from sqlglot import exp + +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.state_sync.db.utils import ( + snapshot_name_filter, + snapshot_name_version_filter, + snapshot_id_filter, + fetchone, + fetchall, +) +from sqlmesh.core.environment import Environment +from sqlmesh.core.model import SeedModel, ModelKindName +from sqlmesh.core.snapshot.cache import SnapshotCache +from sqlmesh.core.snapshot import ( + SnapshotIdLike, + SnapshotNameVersionLike, + SnapshotTableCleanupTask, + SnapshotNameVersion, + SnapshotInfoLike, + Snapshot, + SnapshotIdAndVersion, + SnapshotId, + SnapshotFingerprint, +) +from sqlmesh.core.state_sync.common import ( + RowBoundary, + ExpiredSnapshotBatch, + ExpiredBatchRange, + LimitBoundary, +) +from sqlmesh.utils.migration import index_text_type, blob_text_type +from sqlmesh.utils.date import now_timestamp, TimeLike, 
to_timestamp +from sqlmesh.utils import unique + +if t.TYPE_CHECKING: + import pandas as pd + + +logger = logging.getLogger(__name__) + + +class SnapshotState: + SNAPSHOT_BATCH_SIZE = 1000 + + def __init__( + self, + engine_adapter: EngineAdapter, + schema: t.Optional[str] = None, + cache_dir: Path = Path(), + ): + self.engine_adapter = engine_adapter + self.snapshots_table = exp.table_("_snapshots", db=schema) + self.auto_restatements_table = exp.table_("_auto_restatements", db=schema) + + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + self._snapshot_columns_to_types = { + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "dev_version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build("text"), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + "forward_only": exp.DataType.build("boolean"), + "fingerprint": exp.DataType.build(blob_type), + } + + self._auto_restatement_columns_to_types = { + "snapshot_name": exp.DataType.build(index_type), + "snapshot_version": exp.DataType.build(index_type), + "next_auto_restatement_ts": exp.DataType.build("bigint"), + } + + self._snapshot_cache = SnapshotCache(cache_dir) + + def push_snapshots(self, snapshots: t.Iterable[Snapshot], overwrite: bool = False) -> None: + """Pushes snapshots to the state store. + + Args: + snapshots: The snapshots to push. + overwrite: Whether to overwrite existing snapshots. 
+ """ + if overwrite: + snapshots = tuple(snapshots) + self.delete_snapshots(snapshots) + + snapshots_to_store = [] + + for snapshot in snapshots: + if isinstance(snapshot.node, SeedModel): + seed_model = t.cast(SeedModel, snapshot.node) + snapshot = snapshot.copy(update={"node": seed_model.to_dehydrated()}) + snapshots_to_store.append(snapshot) + + self.engine_adapter.insert_append( + self.snapshots_table, + _snapshots_to_df(snapshots_to_store), + target_columns_to_types=self._snapshot_columns_to_types, + track_rows_processed=False, + ) + + for snapshot in snapshots: + self._snapshot_cache.put(snapshot) + + def unpause_snapshots( + self, + snapshots: t.Collection[SnapshotInfoLike], + unpaused_dt: TimeLike, + ) -> None: + unrestorable_snapshots_by_forward_only: t.Dict[bool, t.List[SnapshotNameVersion]] = ( + defaultdict(list) + ) + + for snapshot in snapshots: + # We need to mark all other snapshots that have forward-only opposite to the target snapshot as unrestorable + unrestorable_snapshots_by_forward_only[not snapshot.is_forward_only].append( + snapshot.name_version + ) + + updated_ts = now_timestamp() + unpaused_ts = to_timestamp(unpaused_dt) + + # Pause all snapshots with target names first + for where in snapshot_name_filter( + [s.name for s in snapshots], + batch_size=self.SNAPSHOT_BATCH_SIZE, + ): + self.engine_adapter.update_table( + self.snapshots_table, + {"unpaused_ts": None, "updated_ts": updated_ts}, + where=where, + ) + + # Now unpause the target snapshots + self._update_snapshots( + [s.snapshot_id for s in snapshots], + unpaused_ts=unpaused_ts, + updated_ts=updated_ts, + ) + + # Mark unrestorable snapshots + for forward_only, snapshot_name_versions in unrestorable_snapshots_by_forward_only.items(): + forward_only_exp = exp.column("forward_only").is_(exp.convert(forward_only)) + for where in snapshot_name_version_filter( + self.engine_adapter, + snapshot_name_versions, + batch_size=self.SNAPSHOT_BATCH_SIZE, + alias=None, + ): + 
self.engine_adapter.update_table( + self.snapshots_table, + {"unrestorable": True, "updated_ts": updated_ts}, + where=forward_only_exp.and_(where), + ) + + def get_expired_snapshots( + self, + environments: t.Iterable[Environment], + current_ts: int, + ignore_ttl: bool, + batch_range: ExpiredBatchRange, + ) -> t.Optional[ExpiredSnapshotBatch]: + expired_query = exp.select("name", "identifier", "version", "updated_ts").from_( + self.snapshots_table + ) + + if not ignore_ttl: + expired_query = expired_query.where( + (exp.column("updated_ts") + exp.column("ttl_ms")) <= current_ts + ) + + expired_query = expired_query.where(batch_range.where_filter) + + promoted_snapshot_ids = { + snapshot.snapshot_id + for environment in environments + for snapshot in ( + environment.snapshots + if environment.finalized_ts is not None + # If the environment is not finalized, check both the current snapshots and the previous finalized snapshots + else [*environment.snapshots, *(environment.previous_finalized_snapshots or [])] + ) + } + + if promoted_snapshot_ids: + not_in_conditions = [ + exp.not_(condition) + for condition in snapshot_id_filter( + self.engine_adapter, + promoted_snapshot_ids, + batch_size=self.SNAPSHOT_BATCH_SIZE, + ) + ] + expired_query = expired_query.where(exp.and_(*not_in_conditions)) + + expired_query = expired_query.order_by( + exp.column("updated_ts"), exp.column("name"), exp.column("identifier") + ) + + if isinstance(batch_range.end, LimitBoundary): + expired_query = expired_query.limit(batch_range.end.batch_size) + + rows = fetchall(self.engine_adapter, expired_query) + + if not rows: + return None + + expired_candidates = { + SnapshotId(name=name, identifier=identifier): SnapshotNameVersion( + name=name, version=version + ) + for name, identifier, version, _ in rows + } + if not expired_candidates: + return None + + def _is_snapshot_used(snapshot: SnapshotIdAndVersion) -> bool: + return ( + snapshot.snapshot_id in promoted_snapshot_ids + or 
snapshot.snapshot_id not in expired_candidates + ) + + # Extract cursor values from last row for pagination + last_row = rows[-1] + last_row_boundary = RowBoundary( + updated_ts=last_row[3], + name=last_row[0], + identifier=last_row[1], + ) + # The returned batch_range represents the actual range of rows in this batch + result_batch_range = ExpiredBatchRange( + start=batch_range.start, + end=last_row_boundary, + ) + + unique_expired_versions = unique(expired_candidates.values()) + expired_snapshot_ids: t.Set[SnapshotId] = set() + cleanup_tasks: t.List[SnapshotTableCleanupTask] = [] + + snapshots = self._get_snapshots_with_same_version(unique_expired_versions) + + snapshots_by_version = defaultdict(set) + snapshots_by_dev_version = defaultdict(set) + for s in snapshots: + snapshots_by_version[(s.name, s.version)].add(s.snapshot_id) + snapshots_by_dev_version[(s.name, s.dev_version)].add(s.snapshot_id) + + expired_snapshots = [s for s in snapshots if not _is_snapshot_used(s)] + all_expired_snapshot_ids = {s.snapshot_id for s in expired_snapshots} + + cleanup_targets: t.List[t.Tuple[SnapshotId, bool]] = [] + for snapshot in expired_snapshots: + shared_version_snapshots = snapshots_by_version[(snapshot.name, snapshot.version)] + shared_version_snapshots.discard(snapshot.snapshot_id) + + shared_dev_version_snapshots = snapshots_by_dev_version[ + (snapshot.name, snapshot.dev_version) + ] + shared_dev_version_snapshots.discard(snapshot.snapshot_id) + + if not shared_dev_version_snapshots: + dev_table_only = bool(shared_version_snapshots) + cleanup_targets.append((snapshot.snapshot_id, dev_table_only)) + + snapshot_ids_to_cleanup = [snapshot_id for snapshot_id, _ in cleanup_targets] + full_snapshots = self._get_snapshots(snapshot_ids_to_cleanup) + for snapshot_id, dev_table_only in cleanup_targets: + if snapshot_id in full_snapshots: + cleanup_tasks.append( + SnapshotTableCleanupTask( + snapshot=full_snapshots[snapshot_id].table_info, + dev_table_only=dev_table_only, + ) + 
) + expired_snapshot_ids.add(snapshot_id) + all_expired_snapshot_ids.discard(snapshot_id) + + # Add any remaining expired snapshots that don't require cleanup + if all_expired_snapshot_ids: + expired_snapshot_ids.update(all_expired_snapshot_ids) + + if expired_snapshot_ids or cleanup_tasks: + return ExpiredSnapshotBatch( + expired_snapshot_ids=expired_snapshot_ids, + cleanup_tasks=cleanup_tasks, + batch_range=result_batch_range, + ) + + return None + + def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: + """Deletes snapshots. + + Args: + snapshot_ids: The snapshot IDs to delete. + """ + if not snapshot_ids: + return + for where in snapshot_id_filter( + self.engine_adapter, snapshot_ids, batch_size=self.SNAPSHOT_BATCH_SIZE + ): + self.engine_adapter.delete_from(self.snapshots_table, where=where) + + def touch_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: + """Touch snapshots to set their updated_ts to the current timestamp. + + Args: + snapshot_ids: The snapshot IDs to touch. + """ + self._update_snapshots(snapshot_ids) + + def get_snapshots( + self, + snapshot_ids: t.Iterable[SnapshotIdLike], + ) -> t.Dict[SnapshotId, Snapshot]: + """Fetches snapshots. + + Args: + snapshot_ids: The snapshot IDs to fetch. + + Returns: + A dictionary of snapshot IDs to snapshots. + """ + return self._get_snapshots(snapshot_ids) + + def get_snapshots_by_names( + self, + snapshot_names: t.Iterable[str], + current_ts: t.Optional[int] = None, + exclude_expired: bool = True, + ) -> t.Set[SnapshotIdAndVersion]: + """Return the snapshot records for all versions of the specified snapshot names. 
+ + Args: + snapshot_names: Iterable of snapshot names to fetch all snapshot records for + current_ts: Sets the current time for identifying which snapshots have expired so they can be excluded (only relevant if :exclude_expired=True) + exclude_expired: Whether or not to return the snapshot id's of expired snapshots in the result + + Returns: + A set containing all the matched snapshot records. To fetch full snapshots, pass it into StateSync.get_snapshots() + """ + if not snapshot_names: + return set() + + if exclude_expired: + current_ts = current_ts or now_timestamp() + unexpired_expr = (exp.column("updated_ts") + exp.column("ttl_ms")) > current_ts + else: + unexpired_expr = None + + return { + SnapshotIdAndVersion( + name=name, + identifier=identifier, + version=version, + kind_name=kind_name or None, + dev_version=dev_version, + fingerprint=fingerprint, + ) + for where in snapshot_name_filter( + snapshot_names=snapshot_names, + batch_size=self.SNAPSHOT_BATCH_SIZE, + ) + for name, identifier, version, kind_name, dev_version, fingerprint in fetchall( + self.engine_adapter, + exp.select( + "name", "identifier", "version", "kind_name", "dev_version", "fingerprint" + ) + .from_(self.snapshots_table) + .where(where) + .and_(unexpired_expr), + ) + } + + def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]: + """Checks if snapshots exist. + + Args: + snapshot_ids: The snapshot IDs to check. + + Returns: + A set of snapshot IDs to check for existence. + """ + return { + SnapshotId(name=name, identifier=identifier) + for where in snapshot_id_filter( + self.engine_adapter, snapshot_ids, batch_size=self.SNAPSHOT_BATCH_SIZE + ) + for name, identifier in fetchall( + self.engine_adapter, + exp.select("name", "identifier").from_(self.snapshots_table).where(where), + ) + } + + def nodes_exist(self, names: t.Iterable[str], exclude_external: bool = False) -> t.Set[str]: + """Checks if nodes with given names exist. 
+ + Args: + names: The node names to check. + exclude_external: Whether to exclude external nodes. + + Returns: + A set of node names that exist. + """ + names = set(names) + + if not names: + return names + + query = ( + exp.select("name") + .from_(self.snapshots_table) + .where(exp.column("name").isin(*names)) + .distinct() + ) + if exclude_external: + query = query.where(exp.column("kind_name").neq(ModelKindName.EXTERNAL.value)) + return {name for (name,) in fetchall(self.engine_adapter, query)} + + def update_auto_restatements( + self, next_auto_restatement_ts: t.Dict[SnapshotNameVersion, t.Optional[int]] + ) -> None: + """Updates the auto restatement timestamps. + + Args: + next_auto_restatement_ts: A dictionary of snapshot name version to the next auto restatement timestamp. + """ + next_auto_restatement_ts_deleted = [] + next_auto_restatement_ts_filtered = {} + for k, v in next_auto_restatement_ts.items(): + if v is None: + next_auto_restatement_ts_deleted.append(k) + else: + next_auto_restatement_ts_filtered[k] = v + + for where in snapshot_name_version_filter( + self.engine_adapter, + next_auto_restatement_ts_deleted, + column_prefix="snapshot", + alias=None, + batch_size=self.SNAPSHOT_BATCH_SIZE, + ): + self.engine_adapter.delete_from(self.auto_restatements_table, where=where) + + if not next_auto_restatement_ts_filtered: + return + + self.engine_adapter.merge( + self.auto_restatements_table, + _auto_restatements_to_df(next_auto_restatement_ts_filtered), + target_columns_to_types=self._auto_restatement_columns_to_types, + unique_key=(exp.column("snapshot_name"), exp.column("snapshot_version")), + ) + + def count(self) -> int: + """Counts the number of snapshots in the state.""" + result = fetchone(self.engine_adapter, exp.select("COUNT(*)").from_(self.snapshots_table)) + return result[0] if result else 0 + + def clear_cache(self) -> None: + """Clears the snapshot cache.""" + self._snapshot_cache.clear() + + def _update_snapshots( + self, + snapshots: 
t.Iterable[SnapshotIdLike], + **kwargs: t.Any, + ) -> None: + properties = kwargs + if "updated_ts" not in properties: + properties["updated_ts"] = now_timestamp() + + for where in snapshot_id_filter( + self.engine_adapter, snapshots, batch_size=self.SNAPSHOT_BATCH_SIZE + ): + self.engine_adapter.update_table( + self.snapshots_table, + properties, + where=where, + ) + + def _push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: + snapshots_to_store = [] + for snapshot in snapshots: + if isinstance(snapshot.node, SeedModel): + seed_model = t.cast(SeedModel, snapshot.node) + snapshot = snapshot.copy(update={"node": seed_model.to_dehydrated()}) + snapshots_to_store.append(snapshot) + + self.engine_adapter.insert_append( + self.snapshots_table, + _snapshots_to_df(snapshots_to_store), + target_columns_to_types=self._snapshot_columns_to_types, + track_rows_processed=False, + ) + + def _get_snapshots( + self, + snapshot_ids: t.Iterable[SnapshotIdLike], + lock_for_update: bool = False, + ) -> t.Dict[SnapshotId, Snapshot]: + """Fetches specified snapshots or all snapshots. + + Args: + snapshot_ids: The collection of snapshot like objects to fetch. + lock_for_update: Lock the snapshot rows for future update + + Returns: + A dictionary of snapshot ids to snapshots for ones that could be found. 
+ """ + duplicates: t.Dict[SnapshotId, Snapshot] = {} + + def _loader(snapshot_ids_to_load: t.Set[SnapshotId]) -> t.Collection[Snapshot]: + fetched_snapshots: t.Dict[SnapshotId, Snapshot] = {} + for query in self._get_snapshots_expressions(snapshot_ids_to_load, lock_for_update): + for ( + serialized_snapshot, + _, + _, + _, + updated_ts, + unpaused_ts, + unrestorable, + forward_only, + next_auto_restatement_ts, + ) in fetchall(self.engine_adapter, query): + snapshot = parse_snapshot( + serialized_snapshot=serialized_snapshot, + updated_ts=updated_ts, + unpaused_ts=unpaused_ts, + unrestorable=unrestorable, + forward_only=forward_only, + next_auto_restatement_ts=next_auto_restatement_ts, + ) + snapshot_id = snapshot.snapshot_id + if snapshot_id in fetched_snapshots: + other = duplicates.get(snapshot_id, fetched_snapshots[snapshot_id]) + duplicates[snapshot_id] = ( + snapshot if snapshot.updated_ts > other.updated_ts else other + ) + fetched_snapshots[snapshot_id] = duplicates[snapshot_id] + else: + fetched_snapshots[snapshot_id] = snapshot + return fetched_snapshots.values() + + snapshots, cached_snapshots = self._snapshot_cache.get_or_load( + {s.snapshot_id for s in snapshot_ids}, _loader + ) + + if cached_snapshots: + cached_snapshots_in_state: t.Set[SnapshotId] = set() + for where in snapshot_id_filter( + self.engine_adapter, cached_snapshots, batch_size=self.SNAPSHOT_BATCH_SIZE + ): + query = ( + exp.select( + "name", + "identifier", + "updated_ts", + "unpaused_ts", + "unrestorable", + "forward_only", + "next_auto_restatement_ts", + ) + .from_(exp.to_table(self.snapshots_table).as_("snapshots")) + .join( + exp.to_table(self.auto_restatements_table).as_("auto_restatements"), + on=exp.and_( + exp.column("name", table="snapshots").eq( + exp.column("snapshot_name", table="auto_restatements") + ), + exp.column("version", table="snapshots").eq( + exp.column("snapshot_version", table="auto_restatements") + ), + ), + join_type="left", + copy=False, + ) + .where(where) + 
) + if lock_for_update: + query = query.lock(copy=False) + for ( + name, + identifier, + updated_ts, + unpaused_ts, + unrestorable, + forward_only, + next_auto_restatement_ts, + ) in fetchall(self.engine_adapter, query): + snapshot_id = SnapshotId(name=name, identifier=identifier) + snapshot = snapshots[snapshot_id] + snapshot.updated_ts = updated_ts + snapshot.unpaused_ts = unpaused_ts + snapshot.unrestorable = unrestorable + snapshot.forward_only = forward_only + snapshot.next_auto_restatement_ts = next_auto_restatement_ts + cached_snapshots_in_state.add(snapshot_id) + + missing_cached_snapshots = cached_snapshots - cached_snapshots_in_state + for missing_cached_snapshot_id in missing_cached_snapshots: + snapshots.pop(missing_cached_snapshot_id, None) + + if duplicates: + self.push_snapshots(duplicates.values(), overwrite=True) + logger.error("Found duplicate snapshots in the state store.") + + return snapshots + + def _get_snapshots_expressions( + self, + snapshot_ids: t.Iterable[SnapshotIdLike], + lock_for_update: bool = False, + ) -> t.Iterator[exp.Expression]: + for where in snapshot_id_filter( + self.engine_adapter, + snapshot_ids, + alias="snapshots", + batch_size=self.SNAPSHOT_BATCH_SIZE, + ): + query = ( + exp.select( + "snapshots.snapshot", + "snapshots.name", + "snapshots.identifier", + "snapshots.version", + "snapshots.updated_ts", + "snapshots.unpaused_ts", + "snapshots.unrestorable", + "snapshots.forward_only", + "auto_restatements.next_auto_restatement_ts", + ) + .from_(exp.to_table(self.snapshots_table).as_("snapshots")) + .join( + exp.to_table(self.auto_restatements_table).as_("auto_restatements"), + on=exp.and_( + exp.column("name", table="snapshots").eq( + exp.column("snapshot_name", table="auto_restatements") + ), + exp.column("version", table="snapshots").eq( + exp.column("snapshot_version", table="auto_restatements") + ), + ), + join_type="left", + copy=False, + ) + .where(where) + ) + if lock_for_update: + query = query.lock(copy=False) + 
yield query + + def _get_snapshots_with_same_version( + self, + snapshots: t.Collection[SnapshotNameVersionLike], + lock_for_update: bool = False, + ) -> t.List[SnapshotIdAndVersion]: + """Fetches all snapshots that share the same version as the snapshots. + + The output includes the snapshots with the specified identifiers. + + Args: + snapshots: The collection of target name / version pairs. + lock_for_update: Lock the snapshot rows for future update + + Returns: + The list of Snapshot objects. + """ + if not snapshots: + return [] + + snapshot_rows = [] + + for where in snapshot_name_version_filter( + self.engine_adapter, snapshots, batch_size=self.SNAPSHOT_BATCH_SIZE + ): + query = ( + exp.select( + "name", + "identifier", + "version", + "kind_name", + "dev_version", + "fingerprint", + ) + .from_(exp.to_table(self.snapshots_table).as_("snapshots")) + .where(where) + ) + if lock_for_update: + query = query.lock(copy=False) + + snapshot_rows.extend(fetchall(self.engine_adapter, query)) + + return [ + SnapshotIdAndVersion( + name=name, + identifier=identifier, + version=version, + kind_name=kind_name or None, + dev_version=dev_version, + fingerprint=SnapshotFingerprint.parse_raw(fingerprint), + ) + for name, identifier, version, kind_name, dev_version, fingerprint in snapshot_rows + ] + + +def parse_snapshot( + serialized_snapshot: str, + updated_ts: int, + unpaused_ts: t.Optional[int], + unrestorable: bool, + forward_only: bool, + next_auto_restatement_ts: t.Optional[int], +) -> Snapshot: + return Snapshot( + **{ + **json.loads(serialized_snapshot), + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "unrestorable": unrestorable, + "forward_only": forward_only, + "next_auto_restatement_ts": next_auto_restatement_ts, + } + ) + + +def _snapshot_to_json(snapshot: Snapshot) -> str: + return snapshot.json( + exclude={ + "intervals", + "dev_intervals", + "pending_restatement_intervals", + "updated_ts", + "unpaused_ts", + "unrestorable", + "forward_only", + 
"next_auto_restatement_ts", + } + ) + + +def _snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame: + import pandas as pd + + return pd.DataFrame( + [ + { + "name": snapshot.name, + "identifier": snapshot.identifier, + "version": snapshot.version, + "snapshot": _snapshot_to_json(snapshot), + "kind_name": snapshot.model_kind_name.value if snapshot.model_kind_name else None, + "updated_ts": snapshot.updated_ts, + "unpaused_ts": snapshot.unpaused_ts, + "ttl_ms": snapshot.ttl_ms, + "unrestorable": snapshot.unrestorable, + "forward_only": snapshot.forward_only, + "dev_version": snapshot.dev_version, + "fingerprint": snapshot.fingerprint.json(), + } + for snapshot in snapshots + ] + ) + + +def _auto_restatements_to_df(auto_restatements: t.Dict[SnapshotNameVersion, int]) -> pd.DataFrame: + import pandas as pd + + return pd.DataFrame( + [ + { + "snapshot_name": name_version.name, + "snapshot_version": name_version.version, + "next_auto_restatement_ts": ts, + } + for name_version, ts in auto_restatements.items() + ] + ) diff --git a/sqlmesh/core/state_sync/db/utils.py b/sqlmesh/core/state_sync/db/utils.py new file mode 100644 index 0000000000..87c259f5d6 --- /dev/null +++ b/sqlmesh/core/state_sync/db/utils.py @@ -0,0 +1,133 @@ +from __future__ import annotations + +import typing as t +import logging + +from sqlglot import exp +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.snapshot import SnapshotIdLike, SnapshotNameVersionLike + + +logger = logging.getLogger(__name__) + +try: + # We can't import directly from the root package due to circular dependency + from sqlmesh._version import __version__ as SQLMESH_VERSION # noqa +except ImportError: + logger.error( + 'Unable to set __version__, run "pip install -e ." or "python setup.py develop" first.' 
+ ) + + +T = t.TypeVar("T") + + +def snapshot_name_filter( + snapshot_names: t.Iterable[str], + batch_size: int, + alias: t.Optional[str] = None, +) -> t.Iterator[exp.Condition]: + names = sorted(snapshot_names) + + if not names: + yield exp.false() + else: + batches = create_batches(names, batch_size=batch_size) + for names in batches: + yield exp.column("name", table=alias).isin(*names) + + +def snapshot_id_filter( + engine_adapter: EngineAdapter, + snapshot_ids: t.Iterable[SnapshotIdLike], + batch_size: int, + alias: t.Optional[str] = None, +) -> t.Iterator[exp.Condition]: + name_identifiers = sorted( + {(snapshot_id.name, snapshot_id.identifier) for snapshot_id in snapshot_ids} + ) + batches = create_batches(name_identifiers, batch_size=batch_size) + + if not name_identifiers: + yield exp.false() + elif engine_adapter.SUPPORTS_TUPLE_IN: + for identifiers in batches: + yield t.cast( + exp.Tuple, + exp.convert( + ( + exp.column("name", table=alias), + exp.column("identifier", table=alias), + ) + ), + ).isin(*identifiers) + else: + for identifiers in batches: + yield exp.or_( + *[ + exp.and_( + exp.column("name", table=alias).eq(name), + exp.column("identifier", table=alias).eq(identifier), + ) + for name, identifier in identifiers + ] + ) + + +def snapshot_name_version_filter( + engine_adapter: EngineAdapter, + snapshot_name_versions: t.Iterable[SnapshotNameVersionLike], + batch_size: int, + version_column_name: str = "version", + alias: t.Optional[str] = "snapshots", + column_prefix: t.Optional[str] = None, +) -> t.Iterator[exp.Condition]: + name_versions = sorted({(s.name, s.version) for s in snapshot_name_versions}) + batches = create_batches(name_versions, batch_size=batch_size) + + name_column_name = "name" + if column_prefix: + name_column_name = f"{column_prefix}_{name_column_name}" + version_column_name = f"{column_prefix}_{version_column_name}" + + name_column = exp.column(name_column_name, table=alias) + version_column = exp.column(version_column_name, 
table=alias) + + if not name_versions: + yield exp.false() + elif engine_adapter.SUPPORTS_TUPLE_IN: + for versions in batches: + yield t.cast( + exp.Tuple, + exp.convert( + ( + name_column, + version_column, + ) + ), + ).isin(*versions) + else: + for versions in batches: + yield exp.or_( + *[ + exp.and_( + name_column.eq(name), + version_column.eq(version), + ) + for name, version in versions + ] + ) + + +def create_batches(l: t.List[T], batch_size: int) -> t.List[t.List[T]]: + return [l[i : i + batch_size] for i in range(0, len(l), batch_size)] + + +def fetchone( + engine_adapter: EngineAdapter, query: t.Union[exp.Expression, str] +) -> t.Optional[t.Tuple]: + return engine_adapter.fetchone(query, ignore_unsupported_errors=True, quote_identifiers=True) + + +def fetchall(engine_adapter: EngineAdapter, query: t.Union[exp.Expression, str]) -> t.List[t.Tuple]: + return engine_adapter.fetchall(query, ignore_unsupported_errors=True, quote_identifiers=True) diff --git a/sqlmesh/core/state_sync/db/version.py b/sqlmesh/core/state_sync/db/version.py new file mode 100644 index 0000000000..c95592bc31 --- /dev/null +++ b/sqlmesh/core/state_sync/db/version.py @@ -0,0 +1,74 @@ +from __future__ import annotations + +import logging +import typing as t + +from sqlglot import __version__ as SQLGLOT_VERSION +from sqlglot import exp +from sqlglot.helper import seq_get + +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.state_sync.db.utils import ( + fetchone, + SQLMESH_VERSION, +) +from sqlmesh.core.state_sync.base import ( + SCHEMA_VERSION, + Versions, +) +from sqlmesh.utils.migration import index_text_type + +logger = logging.getLogger(__name__) + + +class VersionState: + def __init__(self, engine_adapter: EngineAdapter, schema: t.Optional[str] = None): + self.engine_adapter = engine_adapter + self.versions_table = exp.table_("_versions", db=schema) + + index_type = index_text_type(engine_adapter.dialect) + self._version_columns_to_types = { + 
"schema_version": exp.DataType.build("int"), + "sqlglot_version": exp.DataType.build(index_type), + "sqlmesh_version": exp.DataType.build(index_type), + } + + def update_versions( + self, + schema_version: int = SCHEMA_VERSION, + sqlglot_version: str = SQLGLOT_VERSION, + sqlmesh_version: str = SQLMESH_VERSION, + ) -> None: + import pandas as pd + + self.engine_adapter.delete_from(self.versions_table, "TRUE") + + self.engine_adapter.insert_append( + self.versions_table, + pd.DataFrame( + [ + { + "schema_version": schema_version, + "sqlglot_version": sqlglot_version, + "sqlmesh_version": sqlmesh_version, + } + ] + ), + target_columns_to_types=self._version_columns_to_types, + track_rows_processed=False, + ) + + def get_versions(self) -> Versions: + no_version = Versions() + + if not self.engine_adapter.table_exists(self.versions_table): + return no_version + + query = exp.select("*").from_(self.versions_table) + row = fetchone(self.engine_adapter, query) + if not row: + return no_version + + return Versions( + schema_version=row[0], sqlglot_version=row[1], sqlmesh_version=seq_get(row, 2) + ) diff --git a/sqlmesh/core/state_sync/engine_adapter.py b/sqlmesh/core/state_sync/engine_adapter.py deleted file mode 100644 index fc04efc379..0000000000 --- a/sqlmesh/core/state_sync/engine_adapter.py +++ /dev/null @@ -1,1464 +0,0 @@ -""" -# StateSync - -State sync is how SQLMesh keeps track of environments and their states, e.g. snapshots. - -# StateReader - -StateReader provides a subset of the functionalities of the StateSync class. As its name -implies, it only allows for read-only operations on snapshots and environment states. - -# EngineAdapterStateSync - -The provided `sqlmesh.core.state_sync.EngineAdapterStateSync` leverages an existing engine -adapter to read and write state to the underlying data store. 
-""" - -from __future__ import annotations - -import contextlib -import json -import logging -import time -import typing as t -from collections import defaultdict -from copy import deepcopy -from pathlib import Path - -import pandas as pd -from sqlglot import __version__ as SQLGLOT_VERSION -from sqlglot import exp -from sqlglot.helper import seq_get - -from sqlmesh.core import analytics -from sqlmesh.core import constants as c -from sqlmesh.core.audit import ModelAudit -from sqlmesh.core.console import Console, get_console -from sqlmesh.core.engine_adapter import EngineAdapter -from sqlmesh.core.environment import Environment -from sqlmesh.core.model import ModelCache, ModelKindName, SeedModel -from sqlmesh.core.snapshot import ( - Intervals, - Node, - Snapshot, - SnapshotFingerprint, - SnapshotId, - SnapshotIdLike, - SnapshotInfoLike, - SnapshotIntervals, - SnapshotNameVersion, - SnapshotNameVersionLike, - SnapshotTableCleanupTask, - SnapshotTableInfo, - fingerprint_from_node, -) -from sqlmesh.core.snapshot.definition import ( - Interval, - _parents_from_node, - merge_intervals, - remove_interval, -) -from sqlmesh.core.state_sync.base import MIGRATIONS, SCHEMA_VERSION, StateSync, Versions -from sqlmesh.core.state_sync.common import CommonStateSyncMixin, transactional -from sqlmesh.utils import major_minor, random_id, unique -from sqlmesh.utils.dag import DAG -from sqlmesh.utils.date import TimeLike, now_timestamp, time_like_to_str -from sqlmesh.utils.errors import SQLMeshError -from sqlmesh.utils.pydantic import parse_obj_as - -logger = logging.getLogger(__name__) - - -T = t.TypeVar("T") - - -if t.TYPE_CHECKING: - from sqlmesh.core._typing import TableName - -try: - # We can't import directly from the root package due to circular dependency - from sqlmesh._version import __version__ as SQLMESH_VERSION # type: ignore -except ImportError: - logger.error( - 'Unable to set __version__, run "pip install -e ." or "python setup.py develop" first.' 
- ) - - -class EngineAdapterStateSync(CommonStateSyncMixin, StateSync): - """Manages state of nodes and snapshot with an existing engine adapter. - - This state sync is convenient to use because it requires no additional setup. - You can reuse the same engine/warehouse that your data is stored in. - - Args: - engine_adapter: The EngineAdapter to use to store and fetch snapshots. - schema: The schema to store state metadata in. If None or empty string then no schema is defined - console: The console to log information to. - context_path: The context path, used for caching snapshot models. - """ - - INTERVAL_BATCH_SIZE = 1000 - SNAPSHOT_BATCH_SIZE = 1000 - SNAPSHOT_MIGRATION_BATCH_SIZE = 500 - - def __init__( - self, - engine_adapter: EngineAdapter, - schema: t.Optional[str], - console: t.Optional[Console] = None, - context_path: Path = Path(), - ): - # Make sure that if an empty string is provided that we treat it as None - self.schema = schema or None - self.engine_adapter = engine_adapter - self._context_path = context_path - self.console = console or get_console() - self.snapshots_table = exp.table_("_snapshots", db=self.schema) - self.environments_table = exp.table_("_environments", db=self.schema) - self.intervals_table = exp.table_("_intervals", db=self.schema) - self.plan_dags_table = exp.table_("_plan_dags", db=self.schema) - self.versions_table = exp.table_("_versions", db=self.schema) - - self._snapshot_columns_to_types = { - "name": exp.DataType.build("text"), - "identifier": exp.DataType.build("text"), - "version": exp.DataType.build("text"), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build("text"), - "expiration_ts": exp.DataType.build("bigint"), - } - - self._environment_columns_to_types = { - "name": exp.DataType.build("text"), - "snapshots": exp.DataType.build("text"), - "start_at": exp.DataType.build("text"), - "end_at": exp.DataType.build("text"), - "plan_id": exp.DataType.build("text"), - "previous_plan_id": 
exp.DataType.build("text"), - "expiration_ts": exp.DataType.build("bigint"), - "finalized_ts": exp.DataType.build("bigint"), - "promoted_snapshot_ids": exp.DataType.build("text"), - "suffix_target": exp.DataType.build("text"), - "catalog_name_override": exp.DataType.build("text"), - "previous_finalized_snapshots": exp.DataType.build("text"), - "normalize_name": exp.DataType.build("boolean"), - } - - self._interval_columns_to_types = { - "id": exp.DataType.build("text"), - "created_ts": exp.DataType.build("bigint"), - "name": exp.DataType.build("text"), - "identifier": exp.DataType.build("text"), - "version": exp.DataType.build("text"), - "start_ts": exp.DataType.build("bigint"), - "end_ts": exp.DataType.build("bigint"), - "is_dev": exp.DataType.build("boolean"), - "is_removed": exp.DataType.build("boolean"), - "is_compacted": exp.DataType.build("boolean"), - } - - self._version_columns_to_types = { - "schema_version": exp.DataType.build("int"), - "sqlglot_version": exp.DataType.build("text"), - "sqlmesh_version": exp.DataType.build("text"), - } - - def _fetchone(self, query: t.Union[exp.Expression, str]) -> t.Tuple: - return self.engine_adapter.fetchone( - query, ignore_unsupported_errors=True, quote_identifiers=True - ) - - def _fetchall(self, query: t.Union[exp.Expression, str]) -> t.List[t.Tuple]: - return self.engine_adapter.fetchall( - query, ignore_unsupported_errors=True, quote_identifiers=True - ) - - @transactional() - def push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: - """Pushes snapshots to the state store, merging them with existing ones. - - This method first finds all existing snapshots in the store and merges them with - the local snapshots. It will then delete all existing snapshots and then - insert all the local snapshots. This can be made safer with locks or merge/upsert. - - Args: - snapshot_ids: Iterable of snapshot ids to bulk push. 
- """ - snapshots_by_id = {} - for snapshot in snapshots: - if not snapshot.version: - raise SQLMeshError( - f"Snapshot {snapshot} has not been versioned yet. Create a plan before pushing a snapshot." - ) - snapshots_by_id[snapshot.snapshot_id] = snapshot - - existing = self.snapshots_exist(snapshots_by_id) - - if existing: - raise SQLMeshError(f"Snapshots {existing} already exists.") - - snapshots = snapshots_by_id.values() - - if snapshots: - self._push_snapshots(snapshots) - - def _push_snapshots(self, snapshots: t.Iterable[Snapshot], overwrite: bool = False) -> None: - if overwrite: - snapshots = tuple(snapshots) - self.delete_snapshots(snapshots) - - snapshots_to_store = [] - - for snapshot in snapshots: - if isinstance(snapshot.node, SeedModel): - seed_model = t.cast(SeedModel, snapshot.node) - snapshot = snapshot.copy(update={"node": seed_model.to_dehydrated()}) - snapshots_to_store.append(snapshot) - - self.engine_adapter.insert_append( - self.snapshots_table, - _snapshots_to_df(snapshots_to_store), - columns_to_types=self._snapshot_columns_to_types, - ) - - def _update_versions( - self, - schema_version: int = SCHEMA_VERSION, - sqlglot_version: str = SQLGLOT_VERSION, - sqlmesh_version: str = SQLMESH_VERSION, - ) -> None: - self.engine_adapter.delete_from(self.versions_table, "TRUE") - - self.engine_adapter.insert_append( - self.versions_table, - pd.DataFrame( - [ - { - "schema_version": schema_version, - "sqlglot_version": sqlglot_version, - "sqlmesh_version": sqlmesh_version, - } - ] - ), - columns_to_types=self._version_columns_to_types, - ) - - def invalidate_environment(self, name: str) -> None: - name = name.lower() - if name == c.PROD: - raise SQLMeshError("Cannot invalidate the production environment.") - - filter_expr = exp.column("name").eq(name) - - self.engine_adapter.update_table( - self.environments_table, - {"expiration_ts": now_timestamp()}, - where=filter_expr, - ) - - @transactional() - def delete_expired_snapshots( - self, ignore_ttl: 
bool = False - ) -> t.List[SnapshotTableCleanupTask]: - current_ts = now_timestamp(minute_floor=False) - - expired_query = exp.select("name", "identifier", "version").from_(self.snapshots_table) - - if not ignore_ttl: - expired_query = expired_query.where(exp.column("expiration_ts") <= current_ts) - - expired_candidates = { - SnapshotId(name=name, identifier=identifier): SnapshotNameVersion( - name=name, version=version - ) - for name, identifier, version in self._fetchall(expired_query) - } - if not expired_candidates: - return [] - - promoted_snapshot_ids = { - snapshot.snapshot_id - for environment in self.get_environments() - for snapshot in environment.snapshots - } - - def _is_snapshot_used(snapshot: Snapshot) -> bool: - return ( - snapshot.snapshot_id in promoted_snapshot_ids - or snapshot.snapshot_id not in expired_candidates - ) - - unique_expired_versions = unique(expired_candidates.values()) - version_batches = self._batches(unique_expired_versions) - cleanup_targets = [] - for versions_batch in version_batches: - snapshots = self._get_snapshots_with_same_version(versions_batch) - - snapshots_by_version = defaultdict(set) - snapshots_by_temp_version = defaultdict(set) - for s in snapshots: - snapshots_by_version[(s.name, s.version)].add(s.snapshot_id) - snapshots_by_temp_version[(s.name, s.temp_version_get_or_generate())].add( - s.snapshot_id - ) - - expired_snapshots = [s for s in snapshots if not _is_snapshot_used(s)] - - if expired_snapshots: - self.delete_snapshots(expired_snapshots) - - for snapshot in expired_snapshots: - shared_version_snapshots = snapshots_by_version[(snapshot.name, snapshot.version)] - shared_version_snapshots.discard(snapshot.snapshot_id) - - shared_temp_version_snapshots = snapshots_by_temp_version[ - (snapshot.name, snapshot.temp_version_get_or_generate()) - ] - shared_temp_version_snapshots.discard(snapshot.snapshot_id) - - if not shared_temp_version_snapshots: - cleanup_targets.append( - SnapshotTableCleanupTask( - 
snapshot=snapshot.table_info, - dev_table_only=bool(shared_version_snapshots), - ) - ) - - return cleanup_targets - - def delete_expired_environments(self) -> t.List[Environment]: - now_ts = now_timestamp() - filter_expr = exp.LTE( - this=exp.column("expiration_ts"), - expression=exp.Literal.number(now_ts), - ) - - rows = self._fetchall( - self._environments_query( - where=filter_expr, - lock_for_update=True, - ) - ) - environments = [self._environment_from_row(r) for r in rows] - - self.engine_adapter.delete_from( - self.environments_table, - where=filter_expr, - ) - - return environments - - def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: - for where in self._snapshot_id_filter(snapshot_ids): - self.engine_adapter.delete_from(self.snapshots_table, where=where) - - def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]: - return self._snapshot_ids_exist(snapshot_ids, self.snapshots_table) - - def nodes_exist(self, names: t.Iterable[str], exclude_external: bool = False) -> t.Set[str]: - names = set(names) - - if not names: - return names - - query = ( - exp.select("name") - .from_(self.snapshots_table) - .where(exp.column("name").isin(*names)) - .distinct() - ) - if exclude_external: - query = query.where(exp.column("kind_name").neq(ModelKindName.EXTERNAL.value)) - return {name for (name,) in self._fetchall(query)} - - def reset(self, default_catalog: t.Optional[str]) -> None: - """Resets the state store to the state when it was first initialized.""" - self.engine_adapter.drop_table(self.snapshots_table) - self.engine_adapter.drop_table(self.environments_table) - self.engine_adapter.drop_table(self.versions_table) - self.migrate(default_catalog) - - def _update_environment(self, environment: Environment) -> None: - self.engine_adapter.delete_from( - self.environments_table, - where=exp.EQ( - this=exp.column("name"), - expression=exp.Literal.string(environment.name), - ), - ) - - 
self.engine_adapter.insert_append( - self.environments_table, - _environment_to_df(environment), - columns_to_types=self._environment_columns_to_types, - ) - - def _update_snapshot(self, snapshot: Snapshot) -> None: - snapshot.updated_ts = now_timestamp() - for where in self._snapshot_id_filter([snapshot.snapshot_id]): - self.engine_adapter.update_table( - self.snapshots_table, - {"snapshot": _snapshot_to_json(snapshot), "expiration_ts": snapshot.expiration_ts}, - where=where, - ) - - def get_environments(self) -> t.List[Environment]: - """Fetches all environments. - - Returns: - A list of all environments. - """ - return [ - self._environment_from_row(row) for row in self._fetchall(self._environments_query()) - ] - - def _environment_from_row(self, row: t.Tuple[str, ...]) -> Environment: - return Environment(**{field: row[i] for i, field in enumerate(Environment.all_fields())}) - - def _environments_query( - self, - where: t.Optional[str | exp.Expression] = None, - lock_for_update: bool = False, - ) -> exp.Select: - query = ( - exp.select(*(exp.to_identifier(field) for field in Environment.all_fields())) - .from_(self.environments_table) - .where(where) - ) - if lock_for_update: - return query.lock(copy=False) - return query - - def _get_snapshots_expressions( - self, - snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]] = None, - lock_for_update: bool = False, - batch_size: t.Optional[int] = None, - ) -> t.Iterator[exp.Expression]: - for where in ( - [None] - if snapshot_ids is None - else self._snapshot_id_filter(snapshot_ids, alias="snapshots", batch_size=batch_size) - ): - query = ( - exp.select( - "snapshots.snapshot", - "snapshots.name", - "snapshots.identifier", - "snapshots.version", - ) - .from_(exp.to_table(self.snapshots_table).as_("snapshots")) - .where(where) - ) - if lock_for_update: - query = query.lock(copy=False) - yield query - - def _get_snapshots( - self, - snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]] = None, - lock_for_update: bool = 
False, - hydrate_intervals: bool = True, - ) -> t.Dict[SnapshotId, Snapshot]: - """Fetches specified snapshots or all snapshots. - - Args: - snapshot_ids: The collection of snapshot like objects to fetch. - lock_for_update: Lock the snapshot rows for future update - hydrate_intervals: Whether to hydrate result snapshots with intervals. - - Returns: - A dictionary of snapshot ids to snapshots for ones that could be found. - """ - snapshots: t.Dict[SnapshotId, Snapshot] = {} - duplicates: t.Dict[SnapshotId, Snapshot] = {} - model_cache = ModelCache(self._context_path / c.CACHE) - - for query in self._get_snapshots_expressions(snapshot_ids, lock_for_update): - for serialized_snapshot, name, identifier, _ in self._fetchall(query): - snapshot = parse_snapshot( - model_cache, - serialized_snapshot=serialized_snapshot, - name=name, - identifier=identifier, - ) - snapshot_id = snapshot.snapshot_id - if snapshot_id in snapshots: - other = duplicates.get(snapshot_id, snapshots[snapshot_id]) - duplicates[snapshot_id] = ( - snapshot if snapshot.updated_ts > other.updated_ts else other - ) - snapshots[snapshot_id] = duplicates[snapshot_id] - else: - snapshots[snapshot_id] = snapshot - - if snapshots and hydrate_intervals: - _, intervals = self._get_snapshot_intervals(snapshots.values()) - Snapshot.hydrate_with_intervals_by_version(snapshots.values(), intervals) - - if duplicates: - self._push_snapshots(duplicates.values(), overwrite=True) - logger.error("Found duplicate snapshots in the state store.") - - return snapshots - - def _get_snapshots_with_same_version( - self, - snapshots: t.Collection[SnapshotNameVersionLike], - lock_for_update: bool = False, - ) -> t.List[Snapshot]: - """Fetches all snapshots that share the same version as the snapshots. - - The output includes the snapshots with the specified identifiers. - - Args: - snapshots: The collection of target name / version pairs. 
- lock_for_update: Lock the snapshot rows for future update - - Returns: - The list of Snapshot objects. - """ - if not snapshots: - return [] - - snapshot_rows = [] - - for where in self._snapshot_name_version_filter(snapshots): - query = ( - exp.select("snapshot") - .from_(exp.to_table(self.snapshots_table).as_("snapshots")) - .where(where) - ) - if lock_for_update: - query = query.lock(copy=False) - - snapshot_rows.extend(self._fetchall(query)) - - return [Snapshot(**json.loads(row[0])) for row in snapshot_rows] - - def _get_versions(self, lock_for_update: bool = False) -> Versions: - no_version = Versions() - - if not self.engine_adapter.table_exists(self.versions_table): - return no_version - - query = exp.select("*").from_(self.versions_table) - if lock_for_update: - query.lock(copy=False) - - row = self._fetchone(query) - if not row: - return no_version - - return Versions( - schema_version=row[0], sqlglot_version=row[1], sqlmesh_version=seq_get(row, 2) - ) - - def _get_environment( - self, environment: str, lock_for_update: bool = False - ) -> t.Optional[Environment]: - """Fetches the environment if it exists. - - Args: - environment: The environment - lock_for_update: Lock the snapshot rows for future update - - Returns: - The environment object. 
- """ - row = self._fetchone( - self._environments_query( - where=exp.EQ( - this=exp.column("name"), - expression=exp.Literal.string(environment), - ), - lock_for_update=lock_for_update, - ) - ) - - if not row: - return None - - env = self._environment_from_row(row) - return env - - @transactional() - def add_interval( - self, - snapshot: Snapshot, - start: TimeLike, - end: TimeLike, - is_dev: bool = False, - ) -> None: - super().add_interval(snapshot, start, end, is_dev) - - @transactional() - def _add_snapshot_intervals(self, snapshot_intervals: SnapshotIntervals) -> None: - def remove_partial_intervals( - intervals: t.List[Interval], snapshot_id: SnapshotId, *, is_dev: bool - ) -> t.List[Interval]: - results = [] - for start_ts, end_ts in intervals: - if start_ts < end_ts: - logger.info( - "Adding %s (%s, %s) for snapshot %s", - "dev interval" if is_dev else "interval", - start_ts, - end_ts, - snapshot_id, - ) - results.append((start_ts, end_ts)) - else: - logger.info( - "Skipping partial interval (%s, %s) for snapshot %s", - start_ts, - end_ts, - snapshot_intervals.snapshot_id, - ) - return results - - snapshot_intervals = snapshot_intervals.copy( - update={ - "intervals": remove_partial_intervals( - snapshot_intervals.intervals, snapshot_intervals.snapshot_id, is_dev=False - ), - "dev_intervals": remove_partial_intervals( - snapshot_intervals.dev_intervals, snapshot_intervals.snapshot_id, is_dev=True - ), - } - ) - if not snapshot_intervals.intervals and not snapshot_intervals.dev_intervals: - return - self.engine_adapter.insert_append( - self.intervals_table, - _snapshot_interval_to_df(snapshot_intervals, is_removed=False), - columns_to_types=self._interval_columns_to_types, - ) - - @transactional() - def remove_interval( - self, - snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]], - remove_shared_versions: bool = False, - ) -> None: - intervals_to_remove: t.Sequence[ - t.Tuple[t.Union[SnapshotInfoLike, SnapshotIntervals], Interval] - ] = 
snapshot_intervals - if remove_shared_versions: - name_version_mapping = {s.name_version: interval for s, interval in snapshot_intervals} - all_snapshots = [] - for where in self._snapshot_name_version_filter(name_version_mapping, alias=None): - all_snapshots.extend( - [ - SnapshotIntervals( - name=r[0], identifier=r[1], version=r[2], intervals=[], dev_intervals=[] - ) - for r in self._fetchall( - exp.select("name", "identifier", "version") - .from_(self.intervals_table) - .where(where) - ) - ] - ) - intervals_to_remove = [ - (snapshot, name_version_mapping[snapshot.name_version]) - for snapshot in all_snapshots - ] - - if logger.isEnabledFor(logging.INFO): - snapshot_ids = ", ".join(str(s.snapshot_id) for s, _ in intervals_to_remove) - logger.info("Removing interval for snapshots: %s", snapshot_ids) - - for is_dev in (True, False): - self.engine_adapter.insert_append( - self.intervals_table, - _intervals_to_df(intervals_to_remove, is_dev=is_dev, is_removed=True), - columns_to_types=self._interval_columns_to_types, - ) - - @transactional() - def compact_intervals(self) -> None: - interval_ids, snapshot_intervals = self._get_snapshot_intervals(uncompacted_only=True) - - logger.info( - "Compacting %s intervals for %s snapshots", len(interval_ids), len(snapshot_intervals) - ) - - self._push_snapshot_intervals(snapshot_intervals) - - if interval_ids: - for interval_id_batch in self._batches( - list(interval_ids), batch_size=self.INTERVAL_BATCH_SIZE - ): - self.engine_adapter.delete_from( - self.intervals_table, exp.column("id").isin(*interval_id_batch) - ) - - def refresh_snapshot_intervals(self, snapshots: t.Collection[Snapshot]) -> t.List[Snapshot]: - if not snapshots: - return [] - - _, intervals = self._get_snapshot_intervals(snapshots) - for s in snapshots: - s.intervals = [] - s.dev_intervals = [] - return Snapshot.hydrate_with_intervals_by_version(snapshots, intervals) - - def max_interval_end_for_environment( - self, environment: str, 
ensure_finalized_snapshots: bool = False - ) -> t.Optional[int]: - env = self._get_environment(environment) - if not env: - return None - - max_end = None - snapshots = ( - env.snapshots if not ensure_finalized_snapshots else env.finalized_or_current_snapshots - ) - for where in self._snapshot_name_version_filter(snapshots, "intervals"): - end = self._fetchone( - exp.select(exp.func("MAX", exp.to_column("end_ts"))) - .from_(exp.to_table(self.intervals_table).as_("intervals")) - .where(where, copy=False) - .where(exp.to_column("is_dev").not_(), copy=False), - )[0] - - if max_end is None: - max_end = end - elif end is not None: - max_end = max(max_end, end) - - return max_end - - def greatest_common_interval_end( - self, environment: str, models: t.Set[str], ensure_finalized_snapshots: bool = False - ) -> t.Optional[int]: - if not models: - return None - - env = self._get_environment(environment) - if not env: - return None - - snapshots = ( - env.snapshots if not ensure_finalized_snapshots else env.finalized_or_current_snapshots - ) - snapshots = [s for s in snapshots if s.name in models] - if not snapshots: - snapshots = env.snapshots - - greatest_common_end = None - - table_alias = "intervals" - name_col = exp.column("name", table=table_alias) - version_col = exp.column("version", table=table_alias) - - for where in self._snapshot_name_version_filter(snapshots, table_alias): - max_end_subquery = ( - exp.select( - name_col, - version_col, - exp.func("MAX", exp.column("end_ts", table=table_alias)).as_("max_end_ts"), - ) - .from_(exp.to_table(self.intervals_table).as_(table_alias)) - .where(where, copy=False) - .where(exp.to_column("is_dev").not_(), copy=False) - .group_by(name_col, version_col, copy=False) - ) - query = exp.select(exp.func("MIN", exp.column("max_end_ts"))).from_( - max_end_subquery.subquery(alias="max_ends") - ) - - end = self._fetchone(query)[0] - - if greatest_common_end is None: - greatest_common_end = end - elif end is not None: - 
greatest_common_end = min(greatest_common_end, end) - - return greatest_common_end - - def recycle(self) -> None: - self.engine_adapter.recycle() - - def close(self) -> None: - self.engine_adapter.close() - - def _get_snapshot_intervals( - self, - snapshots: t.Optional[t.Collection[SnapshotNameVersionLike]] = None, - uncompacted_only: bool = False, - ) -> t.Tuple[t.Set[str], t.List[SnapshotIntervals]]: - query = ( - exp.select( - "id", - exp.column("name", table="intervals"), - exp.column("identifier", table="intervals"), - "version", - "start_ts", - "end_ts", - "is_dev", - "is_removed", - ) - .from_(exp.to_table(self.intervals_table).as_("intervals")) - .order_by( - exp.column("name", table="intervals"), - exp.column("identifier", table="intervals"), - "created_ts", - "is_removed", - ) - ) - - if uncompacted_only: - query.join( - exp.select("name", "identifier") - .from_(exp.to_table(self.intervals_table).as_("intervals")) - .where(exp.column("is_compacted").not_()) - .distinct() - .subquery(alias="uncompacted"), - on=exp.and_( - exp.column("name", table="intervals").eq( - exp.column("name", table="uncompacted") - ), - exp.column("identifier", table="intervals").eq( - exp.column("identifier", table="uncompacted") - ), - ), - copy=False, - ) - - if not snapshots and snapshots is not None: - return (set(), []) - - interval_ids: t.Set[str] = set() - snapshot_intervals = [] - - for where in ( - self._snapshot_name_version_filter(snapshots, "intervals") if snapshots else [None] - ): - rows = self._fetchall(query.where(where)) - interval_ids.update(row[0] for row in rows) - - intervals: t.Dict[t.Tuple[str, str, str], Intervals] = defaultdict(list) - dev_intervals: t.Dict[t.Tuple[str, str, str], Intervals] = defaultdict(list) - for row in rows: - _, name, identifier, version, start, end, is_dev, is_removed = row - intervals_key = (name, identifier, version) - target_intervals = intervals if not is_dev else dev_intervals - if is_removed: - target_intervals[intervals_key] 
= remove_interval( - target_intervals[intervals_key], start, end - ) - else: - target_intervals[intervals_key] = merge_intervals( - [*target_intervals[intervals_key], (start, end)] - ) - - for name, identifier, version in {**intervals, **dev_intervals}: - snapshot_intervals.append( - SnapshotIntervals( - name=name, - identifier=identifier, - version=version, - intervals=intervals.get((name, identifier, version), []), - dev_intervals=dev_intervals.get((name, identifier, version), []), - ) - ) - - return interval_ids, snapshot_intervals - - def _push_snapshot_intervals( - self, snapshots: t.Iterable[t.Union[Snapshot, SnapshotIntervals]] - ) -> None: - new_intervals = [] - for snapshot in snapshots: - logger.info("Pushing intervals for snapshot %s", snapshot.snapshot_id) - for start_ts, end_ts in snapshot.intervals: - new_intervals.append( - _interval_to_df(snapshot, start_ts, end_ts, is_dev=False, is_compacted=True) - ) - for start_ts, end_ts in snapshot.dev_intervals: - new_intervals.append( - _interval_to_df(snapshot, start_ts, end_ts, is_dev=True, is_compacted=True) - ) - - if new_intervals: - self.engine_adapter.insert_append( - self.intervals_table, - pd.DataFrame(new_intervals), - columns_to_types=self._interval_columns_to_types, - ) - - def _restore_table( - self, - table_name: TableName, - backup_table_name: TableName, - ) -> None: - self.engine_adapter.drop_table(table_name) - self.engine_adapter.rename_table( - old_table_name=backup_table_name, - new_table_name=table_name, - ) - - @transactional() - def migrate( - self, - default_catalog: t.Optional[str], - skip_backup: bool = False, - promoted_snapshots_only: bool = True, - ) -> None: - """Migrate the state sync to the latest SQLMesh / SQLGlot version.""" - versions = self.get_versions(validate=False) - - migration_start_ts = time.perf_counter() - - try: - migrate_rows = self._apply_migrations(default_catalog, skip_backup) - - if not migrate_rows and major_minor(SQLMESH_VERSION) == 
versions.minor_sqlmesh_version: - return - - if migrate_rows: - self._migrate_rows(promoted_snapshots_only) - # Cleanup plan DAGs since we currently don't migrate snapshot records that are in there. - self.engine_adapter.delete_from(self.plan_dags_table, "TRUE") - self._update_versions() - - analytics.collector.on_migration_end( - from_sqlmesh_version=versions.sqlmesh_version, - state_sync_type=self.state_type(), - migration_time_sec=time.perf_counter() - migration_start_ts, - ) - except Exception as e: - if skip_backup: - logger.error("Backup was skipped so no rollback was attempted.") - else: - self.rollback() - - analytics.collector.on_migration_end( - from_sqlmesh_version=versions.sqlmesh_version, - state_sync_type=self.state_type(), - migration_time_sec=time.perf_counter() - migration_start_ts, - error=e, - ) - - self.console.log_migration_status(success=False) - raise SQLMeshError("SQLMesh migration failed.") from e - - self.console.log_migration_status() - - @transactional() - def rollback(self) -> None: - """Rollback to the previous migration.""" - logger.info("Starting migration rollback.") - tables = (self.snapshots_table, self.environments_table, self.versions_table) - optional_tables = (self.intervals_table, self.plan_dags_table) - versions = self.get_versions(validate=False) - if versions.schema_version == 0: - # Clean up state tables - for table in tables + optional_tables: - self.engine_adapter.drop_table(table) - else: - if not all( - self.engine_adapter.table_exists(_backup_table_name(table)) for table in tables - ): - raise SQLMeshError("There are no prior migrations to roll back to.") - for table in tables: - self._restore_table(table, _backup_table_name(table)) - - for optional_table in optional_tables: - if self.engine_adapter.table_exists(_backup_table_name(optional_table)): - self._restore_table(optional_table, _backup_table_name(optional_table)) - - logger.info("Migration rollback successful.") - - def state_type(self) -> str: - return 
self.engine_adapter.dialect - - def _backup_state(self) -> None: - for table in ( - self.snapshots_table, - self.environments_table, - self.versions_table, - self.intervals_table, - self.plan_dags_table, - ): - if self.engine_adapter.table_exists(table): - with self.engine_adapter.transaction(): - backup_name = _backup_table_name(table) - self.engine_adapter.drop_table(backup_name) - self.engine_adapter.ctas( - backup_name, exp.select("*").from_(table), exists=False - ) - - def _apply_migrations( - self, - default_catalog: t.Optional[str], - skip_backup: bool, - ) -> bool: - versions = self.get_versions(validate=False) - migrations = MIGRATIONS[versions.schema_version :] - - migrate_rows = migrations or major_minor(SQLGLOT_VERSION) != versions.minor_sqlglot_version - if not skip_backup and migrate_rows: - self._backup_state() - - for migration in migrations: - logger.info(f"Applying migration {migration}") - migration.migrate(self, default_catalog=default_catalog) - - return bool(migrate_rows) - - def _migrate_rows(self, promoted_snapshots_only: bool) -> None: - logger.info("Fetching environments") - environments = self.get_environments() - # Only migrate snapshots that are part of at least one environment. 
- snapshots_to_migrate = ( - {s.snapshot_id for e in environments for s in e.snapshots} - if promoted_snapshots_only - else None - ) - snapshot_mapping = self._migrate_snapshot_rows(snapshots_to_migrate) - if not snapshot_mapping: - logger.info("No changes to snapshots detected") - return - self._migrate_environment_rows(environments, snapshot_mapping) - - def _migrate_snapshot_rows( - self, snapshots: t.Optional[t.Set[SnapshotId]] - ) -> t.Dict[SnapshotId, SnapshotTableInfo]: - logger.info("Migrating snapshot rows...") - raw_snapshots = { - SnapshotId(name=name, identifier=identifier): json.loads(raw_snapshot) - for where in (self._snapshot_id_filter(snapshots) if snapshots is not None else [None]) - for name, identifier, raw_snapshot in self._fetchall( - exp.select("name", "identifier", "snapshot") - .from_(self.snapshots_table) - .where(where) - .lock() - ) - } - if not raw_snapshots: - return {} - - dag: DAG[SnapshotId] = DAG() - for snapshot_id, raw_snapshot in raw_snapshots.items(): - parent_ids = [SnapshotId.parse_obj(p_id) for p_id in raw_snapshot.get("parents", [])] - dag.add(snapshot_id, [p_id for p_id in parent_ids if p_id in raw_snapshots]) - - reversed_dag_raw = dag.reversed.graph - - self.console.start_snapshot_migration_progress(len(raw_snapshots)) - - parsed_snapshots = LazilyParsedSnapshots(raw_snapshots) - all_snapshot_mapping: t.Dict[SnapshotId, SnapshotTableInfo] = {} - snapshot_id_mapping: t.Dict[SnapshotId, SnapshotId] = {} - new_snapshots: t.Dict[SnapshotId, Snapshot] = {} - visited: t.Set[SnapshotId] = set() - - def _push_new_snapshots() -> None: - all_snapshot_mapping.update( - { - from_id: new_snapshots[to_id].table_info - for from_id, to_id in snapshot_id_mapping.items() - } - ) - - existing_new_snapshots = self.snapshots_exist(new_snapshots) - new_snapshots_to_push = [ - s for s in new_snapshots.values() if s.snapshot_id not in existing_new_snapshots - ] - if new_snapshots_to_push: - logger.info("Pushing %s migrated snapshots", 
len(new_snapshots_to_push)) - self._push_snapshots(new_snapshots_to_push) - new_snapshots.clear() - snapshot_id_mapping.clear() - - def _visit( - snapshot_id: SnapshotId, fingerprint_cache: t.Dict[str, SnapshotFingerprint] - ) -> None: - if snapshot_id in visited or snapshot_id not in raw_snapshots: - return - visited.add(snapshot_id) - - snapshot = parsed_snapshots[snapshot_id] - node = snapshot.node - - node_seen = set() - node_queue = {snapshot_id} - nodes: t.Dict[str, Node] = {} - audits: t.Dict[str, ModelAudit] = {} - while node_queue: - next_snapshot_id = node_queue.pop() - next_snapshot = parsed_snapshots.get(next_snapshot_id) - - if next_snapshot_id in node_seen or not next_snapshot: - continue - - node_seen.add(next_snapshot_id) - node_queue.update(next_snapshot.parents) - - nodes[next_snapshot.name] = next_snapshot.node - audits.update({a.name: a for a in next_snapshot.audits}) - - new_snapshot = deepcopy(snapshot) - try: - new_snapshot.fingerprint = fingerprint_from_node( - node, - nodes=nodes, - audits=audits, - cache=fingerprint_cache, - ) - new_snapshot.parents = tuple( - SnapshotId( - name=parent_node.fqn, - identifier=fingerprint_from_node( - parent_node, - nodes=nodes, - audits=audits, - cache=fingerprint_cache, - ).to_identifier(), - ) - for parent_node in _parents_from_node(node, nodes).values() - ) - except Exception: - logger.exception("Could not compute fingerprint for %s", snapshot.snapshot_id) - return - - # Reset the effective_from date for the new snapshot to avoid unexpected backfills. - new_snapshot.effective_from = None - new_snapshot.previous_versions = snapshot.all_versions - new_snapshot.migrated = True - if not new_snapshot.temp_version: - new_snapshot.temp_version = snapshot.fingerprint.to_version() - - self.console.update_snapshot_migration_progress(1) - - # Visit children and evict them from the parsed_snapshots cache after. 
- for child in reversed_dag_raw.get(snapshot_id, []): - # Make sure to copy the fingerprint cache to avoid sharing it between different child snapshots with the same name. - _visit(child, fingerprint_cache.copy()) - parsed_snapshots.evict(child) - - if new_snapshot.fingerprint == snapshot.fingerprint: - logger.debug(f"{new_snapshot.snapshot_id} is unchanged.") - return - - new_snapshot_id = new_snapshot.snapshot_id - - if new_snapshot_id in raw_snapshots: - # Mapped to an existing snapshot. - new_snapshots[new_snapshot_id] = parsed_snapshots[new_snapshot_id] - logger.debug("Migrated snapshot %s already exists", new_snapshot_id) - elif ( - new_snapshot_id not in new_snapshots - or new_snapshot.updated_ts > new_snapshots[new_snapshot_id].updated_ts - ): - new_snapshots[new_snapshot_id] = new_snapshot - - snapshot_id_mapping[snapshot.snapshot_id] = new_snapshot_id - logger.debug("%s mapped to %s", snapshot.snapshot_id, new_snapshot_id) - - if len(new_snapshots) >= self.SNAPSHOT_MIGRATION_BATCH_SIZE: - _push_new_snapshots() - - for root_snapshot_id in dag.roots: - _visit(root_snapshot_id, {}) - - if new_snapshots: - _push_new_snapshots() - - self.console.stop_snapshot_migration_progress() - return all_snapshot_mapping - - def _migrate_environment_rows( - self, - environments: t.List[Environment], - snapshot_mapping: t.Dict[SnapshotId, SnapshotTableInfo], - ) -> None: - logger.info("Migrating environment rows...") - - updated_prod_environment: t.Optional[Environment] = None - updated_environments = [] - for environment in environments: - snapshots = [ - ( - snapshot_mapping[info.snapshot_id] - if info.snapshot_id in snapshot_mapping - else info - ) - for info in environment.snapshots - ] - - if snapshots != environment.snapshots: - environment.snapshots = snapshots - updated_environments.append(environment) - if environment.name == c.PROD: - updated_prod_environment = environment - self.console.start_env_migration_progress(len(updated_environments)) - - for environment 
in updated_environments: - self._update_environment(environment) - self.console.update_env_migration_progress(1) - - if updated_prod_environment: - try: - self.unpause_snapshots(updated_prod_environment.snapshots, now_timestamp()) - except Exception: - logger.warning("Failed to unpause migrated snapshots", exc_info=True) - - self.console.stop_env_migration_progress() - - def _snapshot_ids_exist( - self, snapshot_ids: t.Iterable[SnapshotIdLike], table_name: exp.Table - ) -> t.Set[SnapshotId]: - return { - SnapshotId(name=name, identifier=identifier) - for where in self._snapshot_id_filter(snapshot_ids) - for name, identifier in self._fetchall( - exp.select("name", "identifier").from_(table_name).where(where) - ) - } - - def _snapshot_id_filter( - self, - snapshot_ids: t.Iterable[SnapshotIdLike], - alias: t.Optional[str] = None, - batch_size: t.Optional[int] = None, - ) -> t.Iterator[exp.Condition]: - name_identifiers = sorted( - {(snapshot_id.name, snapshot_id.identifier) for snapshot_id in snapshot_ids} - ) - batches = self._batches(name_identifiers, batch_size=batch_size) - - if not name_identifiers: - yield exp.false() - elif self.engine_adapter.SUPPORTS_TUPLE_IN: - for identifiers in batches: - yield t.cast( - exp.Tuple, - exp.convert( - ( - exp.column("name", table=alias), - exp.column("identifier", table=alias), - ) - ), - ).isin(*identifiers) - else: - for identifiers in batches: - yield exp.or_( - *[ - exp.and_( - exp.column("name", table=alias).eq(name), - exp.column("identifier", table=alias).eq(identifier), - ) - for name, identifier in identifiers - ] - ) - - def _snapshot_name_version_filter( - self, - snapshot_name_versions: t.Iterable[SnapshotNameVersionLike], - alias: t.Optional[str] = "snapshots", - ) -> t.Iterator[exp.Condition]: - name_versions = sorted({(s.name, s.version) for s in snapshot_name_versions}) - batches = self._batches(name_versions) - - if not name_versions: - yield exp.false() - elif self.engine_adapter.SUPPORTS_TUPLE_IN: - for 
versions in batches: - yield t.cast( - exp.Tuple, - exp.convert( - ( - exp.column("name", table=alias), - exp.column("version", table=alias), - ) - ), - ).isin(*versions) - else: - for versions in batches: - yield exp.or_( - *[ - exp.and_( - exp.column("name", table=alias).eq(name), - exp.column("version", table=alias).eq(version), - ) - for name, version in versions - ] - ) - - def _batches(self, l: t.List[T], batch_size: t.Optional[int] = None) -> t.List[t.List[T]]: - batch_size = batch_size or self.SNAPSHOT_BATCH_SIZE - return [l[i : i + batch_size] for i in range(0, len(l), batch_size)] - - @contextlib.contextmanager - def _transaction(self) -> t.Iterator[None]: - with self.engine_adapter.transaction(): - yield - - -def _intervals_to_df( - snapshot_intervals: t.Sequence[t.Tuple[t.Union[SnapshotInfoLike, SnapshotIntervals], Interval]], - is_dev: bool, - is_removed: bool, -) -> pd.DataFrame: - return pd.DataFrame( - [ - _interval_to_df( - s, - *interval, - is_dev=is_dev, - is_removed=is_removed, - ) - for s, interval in snapshot_intervals - ] - ) - - -def _snapshot_interval_to_df( - snapshot_intervals: SnapshotIntervals, - is_removed: bool = False, -) -> pd.DataFrame: - return pd.DataFrame( - [ - _interval_to_df( - snapshot_intervals, - start_ts, - end_ts, - is_dev=is_dev, - is_removed=is_removed, - ) - for is_dev in (False, True) - for start_ts, end_ts in getattr( - snapshot_intervals, "dev_intervals" if is_dev else "intervals" - ) - ] - ) - - -def _interval_to_df( - snapshot: t.Union[SnapshotInfoLike, SnapshotIntervals], - start_ts: int, - end_ts: int, - is_dev: bool = False, - is_removed: bool = False, - is_compacted: bool = False, -) -> t.Dict[str, t.Any]: - return { - "id": random_id(), - "created_ts": now_timestamp(), - "name": snapshot.name, - "identifier": snapshot.identifier, - "version": snapshot.version, - "start_ts": start_ts, - "end_ts": end_ts, - "is_dev": is_dev, - "is_removed": is_removed, - "is_compacted": is_compacted, - } - - -def 
_snapshots_to_df(snapshots: t.Iterable[Snapshot]) -> pd.DataFrame: - return pd.DataFrame( - [ - { - "name": snapshot.name, - "identifier": snapshot.identifier, - "version": snapshot.version, - "snapshot": _snapshot_to_json(snapshot), - "kind_name": snapshot.model_kind_name.value if snapshot.model_kind_name else None, - "expiration_ts": snapshot.expiration_ts, - } - for snapshot in snapshots - ] - ) - - -def _environment_to_df(environment: Environment) -> pd.DataFrame: - return pd.DataFrame( - [ - { - "name": environment.name, - "snapshots": json.dumps( - [snapshot.dict(mode="json") for snapshot in environment.snapshots] - ), - "start_at": time_like_to_str(environment.start_at), - "end_at": time_like_to_str(environment.end_at) if environment.end_at else None, - "plan_id": environment.plan_id, - "previous_plan_id": environment.previous_plan_id, - "expiration_ts": environment.expiration_ts, - "finalized_ts": environment.finalized_ts, - "promoted_snapshot_ids": ( - json.dumps([s.dict() for s in environment.promoted_snapshot_ids]) - if environment.promoted_snapshot_ids is not None - else None - ), - "suffix_target": environment.suffix_target.value, - "catalog_name_override": environment.catalog_name_override, - "previous_finalized_snapshots": ( - json.dumps( - [ - snapshot.dict(mode="json") - for snapshot in environment.previous_finalized_snapshots - ] - ) - if environment.previous_finalized_snapshots is not None - else None - ), - "normalize_name": environment.normalize_name, - } - ] - ) - - -def _backup_table_name(table_name: TableName) -> exp.Table: - table = exp.to_table(table_name).copy() - table.set("this", exp.to_identifier(table.name + "_backup")) - return table - - -def _snapshot_to_json(snapshot: Snapshot) -> str: - return snapshot.json(exclude={"intervals", "dev_intervals"}) - - -def parse_snapshot( - model_cache: ModelCache, - serialized_snapshot: str, - name: str, - identifier: str, -) -> Snapshot: - payload = json.loads(serialized_snapshot) - - def 
loader() -> Node: - return parse_obj_as(Node, payload["node"]) # type: ignore - - payload["node"] = model_cache.get_or_load(f"{name}_{identifier}", loader=loader) # type: ignore - snapshot = Snapshot(**payload) - - return snapshot - - -class LazilyParsedSnapshots: - def __init__(self, raw_snapshots: t.Dict[SnapshotId, t.Dict[str, t.Any]]): - self._raw_snapshots = raw_snapshots - self._parsed_snapshots: t.Dict[SnapshotId, t.Optional[Snapshot]] = {} - - def get(self, snapshot_id: SnapshotId) -> t.Optional[Snapshot]: - if snapshot_id not in self._parsed_snapshots: - raw_snapshot = self._raw_snapshots.get(snapshot_id) - if raw_snapshot: - self._parsed_snapshots[snapshot_id] = Snapshot.parse_obj(raw_snapshot) - else: - self._parsed_snapshots[snapshot_id] = None - return self._parsed_snapshots[snapshot_id] - - def evict(self, snapshot_id: SnapshotId) -> None: - self._parsed_snapshots.pop(snapshot_id, None) - - def __getitem__(self, snapshot_id: SnapshotId) -> Snapshot: - snapshot = self.get(snapshot_id) - if snapshot is None: - raise KeyError(snapshot_id) - return snapshot diff --git a/sqlmesh/core/state_sync/export_import.py b/sqlmesh/core/state_sync/export_import.py new file mode 100644 index 0000000000..3a63351ddb --- /dev/null +++ b/sqlmesh/core/state_sync/export_import.py @@ -0,0 +1,226 @@ +import json +import typing as t + +from sqlmesh.core.state_sync import StateSync +from sqlmesh.core.snapshot import Snapshot +from sqlmesh.utils.date import now, to_tstz +from sqlmesh.utils.pydantic import _expression_encoder +from sqlmesh.core.state_sync import Versions +from sqlmesh.core.state_sync.common import ( + EnvironmentsChunk, + SnapshotsChunk, + VersionsChunk, + EnvironmentWithStatements, + StateStream, +) +from sqlmesh.core.console import Console +from pathlib import Path +from sqlmesh.core.console import NoopConsole + +import json_stream +from json_stream import streamable_dict, to_standard_types, streamable_list +from json_stream.writer import StreamableDict +from 
json_stream.base import StreamingJSONObject +from json_stream.dump import JSONStreamEncoder +from sqlmesh.utils.errors import SQLMeshError +from sqlglot import exp +from sqlmesh.utils.pydantic import DEFAULT_ARGS as PYDANTIC_DEFAULT_ARGS, PydanticModel + + +class SQLMeshJSONStreamEncoder(JSONStreamEncoder): + def default(self, obj: t.Any) -> t.Any: + if isinstance(obj, exp.Expression): + return _expression_encoder(obj) + + return super().default(obj) + + +def _dump_pydantic_model(model: PydanticModel) -> t.Dict[str, t.Any]: + dump_args: t.Dict[str, t.Any] = PYDANTIC_DEFAULT_ARGS + return model.model_dump(mode="json", **dump_args) + + +def _export(state_stream: StateStream, importable: bool, console: Console) -> StreamableDict: + """ + Return the state in a format 'json_stream' can stream to a file + + Args: + state_stream: A stream of state to export + console: A Console instance to print progress to + """ + + @streamable_list + def _dump_snapshots( + snapshot_stream: t.Iterable[Snapshot], + ) -> t.Iterator[t.Dict[str, t.Any]]: + console.update_state_export_progress(snapshot_count=0) + for idx, snapshot in enumerate(snapshot_stream): + yield _dump_pydantic_model(snapshot) + console.update_state_export_progress(snapshot_count=idx + 1) + + @streamable_dict + def _dump_environments( + environment_stream: t.Iterable[EnvironmentWithStatements], + ) -> t.Iterator[t.Tuple[str, t.Any]]: + console.update_state_export_progress(environment_count=0) + for idx, env in enumerate(environment_stream): + yield env.environment.name, _dump_pydantic_model(env) + console.update_state_export_progress(environment_count=idx + 1) + + @streamable_dict + def _do_export() -> t.Iterator[t.Tuple[str, t.Any]]: + yield "metadata", {"timestamp": to_tstz(now()), "file_version": 1, "importable": importable} + + for state_chunk in state_stream: + if isinstance(state_chunk, VersionsChunk): + versions = _dump_pydantic_model(state_chunk.versions) + yield "versions", versions + 
console.update_state_export_progress( + version_count=len(versions), versions_complete=True + ) + + if isinstance(state_chunk, SnapshotsChunk): + yield "snapshots", _dump_snapshots(state_chunk) + console.update_state_export_progress(snapshots_complete=True) + + if isinstance(state_chunk, EnvironmentsChunk): + yield "environments", _dump_environments(state_chunk) + console.update_state_export_progress(environments_complete=True) + + return _do_export() + + +def _import( + state_sync: StateSync, data: t.Callable[[], StreamingJSONObject], clear: bool, console: Console +) -> None: + """ + Load the state defined by the :data into the supplied :state_sync. The data is in the same format as written by dump() + + Args: + state_sync: The StateSync that the user has requested to dump state from + data: A factory function that produces new streaming JSON reader attached to the file we are loading state from. + This is so each section of the file can have its own reader which allows it to be read in isolation / out-of-order + This puts less reliance on downstream consumers performing operations in a certain order + clear: Whether or not to clear the existing state before writing the new state + console: A Console instance to print progress to + """ + + def _load_snapshots() -> t.Iterator[Snapshot]: + stream = data()["snapshots"] + + console.update_state_import_progress(snapshot_count=0) + for idx, raw_snapshot in enumerate(stream): + snapshot = Snapshot.model_validate(to_standard_types(raw_snapshot)) + yield snapshot + console.update_state_import_progress(snapshot_count=idx + 1) + + console.update_state_import_progress(snapshots_complete=True) + + def _load_environments() -> t.Iterator[EnvironmentWithStatements]: + stream = data()["environments"] + + console.update_state_import_progress(environment_count=0) + for idx, (_, raw_environment) in enumerate(stream.items()): + environment = EnvironmentWithStatements.model_validate( + to_standard_types(raw_environment) + ) + yield 
environment + console.update_state_import_progress(environment_count=idx + 1) + + console.update_state_import_progress(environments_complete=True) + + metadata = to_standard_types(data()["metadata"]) + + timestamp = metadata["timestamp"] + if not isinstance(timestamp, str): + raise ValueError(f"'timestamp' contains an invalid value. Expecting str, got: {timestamp}") + console.update_state_import_progress( + timestamp=timestamp, state_file_version=metadata["file_version"] + ) + + versions = Versions.model_validate(to_standard_types(data()["versions"])) + + stream = StateStream.from_iterators( + versions=versions, snapshots=_load_snapshots(), environments=_load_environments() + ) + + console.update_state_import_progress(versions=versions) + + state_sync.import_(stream, clear=clear) + + +def export_state( + state_sync: StateSync, + output_file: Path, + local_snapshots: t.Optional[t.Dict[str, Snapshot]] = None, + environment_names: t.Optional[t.List[str]] = None, + console: t.Optional[Console] = None, +) -> None: + console = console or NoopConsole() + + state_stream = ( + StateStream.from_iterators( + versions=state_sync.get_versions(), + snapshots=iter(local_snapshots.values()), + environments=iter([]), + ) + if local_snapshots + else state_sync.export(environment_names=environment_names) + ) + + importable = False if local_snapshots else True + + json_stream = _export(state_stream=state_stream, importable=importable, console=console) + with output_file.open(mode="w", encoding="utf8") as fh: + json.dump(json_stream, fh, indent=2, cls=SQLMeshJSONStreamEncoder) + + +def import_state( + state_sync: StateSync, + input_file: Path, + clear: bool = False, + console: t.Optional[Console] = None, +) -> None: + console = console or NoopConsole() + + # we need to peek into the file to figure out what state version we are dealing with + with input_file.open("r", encoding="utf8") as fh: + stream = json_stream.load(fh) + if not isinstance(stream, StreamingJSONObject): + raise 
SQLMeshError(f"Expected JSON object, got: {type(stream)}") + + try: + metadata = stream["metadata"].persistent() + except KeyError: + raise SQLMeshError("Expecting a 'metadata' key to be present") + + if not isinstance(metadata, StreamingJSONObject): + raise SQLMeshError("Expecting the 'metadata' key to contain an object") + + file_version = metadata.get("file_version") + if file_version is None: + raise SQLMeshError("Unable to determine state file format version from the input file") + + try: + int(file_version) + except ValueError: + raise SQLMeshError(f"Unable to parse state file format version: {file_version}") + + if not metadata.get("importable", False): + # this can happen if the state file was created from local unversioned snapshots that were not sourced from the project state database + raise SQLMeshError("State file is marked as not importable. Aborting") + + handles: t.List[t.TextIO] = [] + + def _new_handle() -> StreamingJSONObject: + handle = input_file.open("r", encoding="utf8") + handles.append(handle) + stream = json_stream.load(handle) + assert isinstance(stream, StreamingJSONObject) + return stream + + try: + _import(state_sync=state_sync, data=_new_handle, clear=clear, console=console) + finally: + for handle in handles: + handle.close() diff --git a/sqlmesh/core/table_diff.py b/sqlmesh/core/table_diff.py index 2a686fbb73..bd32cc170f 100644 --- a/sqlmesh/core/table_diff.py +++ b/sqlmesh/core/table_diff.py @@ -2,19 +2,30 @@ import math import typing as t +from functools import cached_property -import pandas as pd +from sqlmesh.core.dialect import to_schema +from sqlmesh.core.engine_adapter.mixins import RowDiffMixin +from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter from sqlglot import exp, parse_one from sqlglot.helper import ensure_list from sqlglot.optimizer.normalize_identifiers import normalize_identifiers from sqlglot.optimizer.qualify_columns import quote_identifiers +from sqlglot.optimizer.scope import find_all_in_scope 
from sqlmesh.utils.pydantic import PydanticModel +from sqlmesh.utils.errors import SQLMeshError + if t.TYPE_CHECKING: + import pandas as pd + from sqlmesh.core._typing import TableName from sqlmesh.core.engine_adapter import EngineAdapter +SQLMESH_JOIN_KEY_COL = "__sqlmesh_join_key" +SQLMESH_SAMPLE_TYPE_COL = "__sqlmesh_sample_type" + class SchemaDiff(PydanticModel, frozen=True): """An object containing the schema difference between a source and target table.""" @@ -26,29 +37,78 @@ class SchemaDiff(PydanticModel, frozen=True): source_alias: t.Optional[str] = None target_alias: t.Optional[str] = None model_name: t.Optional[str] = None + ignore_case: bool = False + + @property + def _comparable_source_schema(self) -> t.Dict[str, exp.DataType]: + return ( + self._lowercase_schema_names(self.source_schema) + if self.ignore_case + else self.source_schema + ) + + @property + def _comparable_target_schema(self) -> t.Dict[str, exp.DataType]: + return ( + self._lowercase_schema_names(self.target_schema) + if self.ignore_case + else self.target_schema + ) + + def _lowercase_schema_names( + self, schema: t.Dict[str, exp.DataType] + ) -> t.Dict[str, exp.DataType]: + return {c.lower(): t for c, t in schema.items()} + + def _original_column_name( + self, maybe_lowercased_column_name: str, schema: t.Dict[str, exp.DataType] + ) -> str: + if not self.ignore_case: + return maybe_lowercased_column_name + + return next(c for c in schema if c.lower() == maybe_lowercased_column_name) @property def added(self) -> t.List[t.Tuple[str, exp.DataType]]: """Added columns.""" - return [(c, t) for c, t in self.target_schema.items() if c not in self.source_schema] + return [ + (self._original_column_name(c, self.target_schema), t) + for c, t in self._comparable_target_schema.items() + if c not in self._comparable_source_schema + ] @property def removed(self) -> t.List[t.Tuple[str, exp.DataType]]: """Removed columns.""" - return [(c, t) for c, t in self.source_schema.items() if c not in 
self.target_schema] + return [ + (self._original_column_name(c, self.source_schema), t) + for c, t in self._comparable_source_schema.items() + if c not in self._comparable_target_schema + ] @property def modified(self) -> t.Dict[str, t.Tuple[exp.DataType, exp.DataType]]: """Columns with modified types.""" modified = {} - for column in self.source_schema.keys() & self.target_schema.keys(): - source_type = self.source_schema[column] - target_type = self.target_schema[column] + for column in self._comparable_source_schema.keys() & self._comparable_target_schema.keys(): + source_type = self._comparable_source_schema[column] + target_type = self._comparable_target_schema[column] if source_type != target_type: modified[column] = (source_type, target_type) + + if self.ignore_case: + modified = { + self._original_column_name(c, self.source_schema): dt for c, dt in modified.items() + } + return modified + @property + def has_changes(self) -> bool: + """Does the schema contain any changes at all between source and target""" + return bool(self.added or self.removed or self.modified) + class RowDiff(PydanticModel, frozen=True): """Summary statistics and a sample dataframe.""" @@ -64,6 +124,22 @@ class RowDiff(PydanticModel, frozen=True): source_alias: t.Optional[str] = None target_alias: t.Optional[str] = None model_name: t.Optional[str] = None + decimals: int = 3 + + _types_resolved: t.ClassVar[bool] = False + + def __new__(cls, *args: t.Any, **kwargs: t.Any) -> RowDiff: + if not cls._types_resolved: + cls._resolve_types() + return super().__new__(cls) + + @classmethod + def _resolve_types(cls) -> None: + # Pandas is imported by type checking so we need to resolve the types with the real import before instantiating + import pandas as pd # noqa + + cls.model_rebuild() + cls._types_resolved = True @property def source_count(self) -> int: @@ -75,6 +151,15 @@ def target_count(self) -> int: """Count of the target.""" return int(self.stats["t_count"]) + @property + def empty(self) 
-> bool: + return ( + self.source_count == 0 + and self.target_count == 0 + and self.s_only_count == 0 + and self.t_only_count == 0 + ) + @property def count_pct_change(self) -> float: """The percentage change of the counts.""" @@ -148,39 +233,28 @@ def __init__( model_name: t.Optional[str] = None, model_dialect: t.Optional[str] = None, decimals: int = 3, + schema_diff_ignore_case: bool = False, ): + if not isinstance(adapter, RowDiffMixin): + raise ValueError(f"Engine {adapter} doesnt support RowDiff") + self.adapter = adapter self.source = source self.target = target self.dialect = adapter.dialect + self.source_table = exp.to_table(self.source, dialect=self.dialect) + self.target_table = exp.to_table(self.target, dialect=self.dialect) self.where = exp.condition(where, dialect=self.dialect) if where else None self.limit = limit self.model_name = model_name self.model_dialect = model_dialect self.decimals = decimals + self.schema_diff_ignore_case = schema_diff_ignore_case # Support environment aliases for diff output improvement in certain cases self.source_alias = source_alias self.target_alias = target_alias - if isinstance(on, (list, tuple)): - join_condition = [exp.parse_identifier(key) for key in on] - s_table = exp.to_identifier("s", quoted=True) - t_table = exp.to_identifier("t", quoted=True) - - self.on: exp.Condition = exp.and_( - *( - exp.column(c, s_table).eq(exp.column(c, t_table)) - | ( - exp.column(c, s_table).is_(exp.null()) - & exp.column(c, t_table).is_(exp.null()) - ) - for c in join_condition - ) - ) - else: - self.on = on - self.skip_columns = { normalize_identifiers( exp.parse_identifier(t.cast(str, col)), @@ -189,23 +263,67 @@ def __init__( for col in ensure_list(skip_columns) } - normalize_identifiers(self.on, dialect=self.model_dialect or self.dialect) - - self._source_schema: t.Optional[t.Dict[str, exp.DataType]] = None - self._target_schema: t.Optional[t.Dict[str, exp.DataType]] = None + self._on = on self._row_diff: t.Optional[RowDiff] = 
None - @property + @cached_property def source_schema(self) -> t.Dict[str, exp.DataType]: - if self._source_schema is None: - self._source_schema = self.adapter.columns(self.source) - return self._source_schema + return self.adapter.columns(self.source_table) - @property + @cached_property def target_schema(self) -> t.Dict[str, exp.DataType]: - if self._target_schema is None: - self._target_schema = self.adapter.columns(self.target) - return self._target_schema + return self.adapter.columns(self.target_table) + + @cached_property + def key_columns(self) -> t.Tuple[t.List[exp.Column], t.List[exp.Column], t.List[str]]: + dialect = self.model_dialect or self.dialect + + # If the columns to join on are explicitly specified, then just return them + if isinstance(self._on, (list, tuple)): + identifiers = [normalize_identifiers(c, dialect=dialect) for c in self._on] + s_index = [exp.column(c, "s") for c in identifiers] + t_index = [exp.column(c, "t") for c in identifiers] + return s_index, t_index, [i.name for i in identifiers] + + # Otherwise, we need to parse them out of the supplied "on" condition + index_cols = [] + s_index = [] + t_index = [] + + normalize_identifiers(self._on, dialect=dialect) + for col in self._on.find_all(exp.Column): + index_cols.append(col.name) + if col.table.lower() == "s": + s_index.append(col) + elif col.table.lower() == "t": + t_index.append(col) + + index_cols = list(dict.fromkeys(index_cols)) + s_index = list(dict.fromkeys(s_index)) + t_index = list(dict.fromkeys(t_index)) + + return s_index, t_index, index_cols + + @property + def source_key_expression(self) -> exp.Expression: + s_index, _, _ = self.key_columns + return self._key_expression(s_index, self.source_schema) + + @property + def target_key_expression(self) -> exp.Expression: + _, t_index, _ = self.key_columns + return self._key_expression(t_index, self.target_schema) + + def _key_expression( + self, cols: t.List[exp.Column], schema: t.Dict[str, exp.DataType] + ) -> 
exp.Expression: + # if there is a single column, dont do anything fancy to it in order to allow existing indexes to be hit + if len(cols) == 1: + return exp.to_column(cols[0].name) + + # if there are multiple columns, turn them into a single column by stringify-ing/concatenating them together + key_columns_to_types = {key.name: schema[key.name] for key in cols} + return self.adapter.concat_columns(key_columns_to_types, self.decimals) def schema_diff(self) -> SchemaDiff: return SchemaDiff( @@ -216,9 +334,12 @@ def schema_diff(self) -> SchemaDiff: source_alias=self.source_alias, target_alias=self.target_alias, model_name=self.model_name, + ignore_case=self.schema_diff_ignore_case, ) - def row_diff(self, skip_grain_check: bool = False) -> RowDiff: + def row_diff( + self, temp_schema: t.Optional[str] = None, skip_grain_check: bool = False + ) -> RowDiff: if self._row_diff is None: source_schema = { c: t for c, t in self.source_schema.items() if c not in self.skip_columns @@ -229,29 +350,29 @@ def row_diff(self, skip_grain_check: bool = False) -> RowDiff: s_selects = {c: exp.column(c, "s").as_(f"s__{c}") for c in source_schema} t_selects = {c: exp.column(c, "t").as_(f"t__{c}") for c in target_schema} - - index_cols = [] - s_index = [] - t_index = [] - - for col in self.on.find_all(exp.Column): - index_cols.append(col.name) - if col.table == "s": - s_index.append(col) - elif col.table == "t": - t_index.append(col) - index_cols = list(dict.fromkeys(index_cols)) - s_index = list(dict.fromkeys(s_index)) - t_index = list(dict.fromkeys(t_index)) + s_selects[SQLMESH_JOIN_KEY_COL] = exp.column(SQLMESH_JOIN_KEY_COL, "s").as_( + f"s__{SQLMESH_JOIN_KEY_COL}" + ) + t_selects[SQLMESH_JOIN_KEY_COL] = exp.column(SQLMESH_JOIN_KEY_COL, "t").as_( + f"t__{SQLMESH_JOIN_KEY_COL}" + ) matched_columns = {c: t for c, t in source_schema.items() if t == target_schema.get(c)} + s_index, t_index, index_cols = self.key_columns + s_index_names = [c.name for c in s_index] + t_index_names = [t.name 
for t in t_index] + def _column_expr(name: str, table: str) -> exp.Expression: - if matched_columns[name].this in exp.DataType.FLOAT_TYPES: - return exp.func( - "ROUND", exp.column(name, table), exp.Literal.number(self.decimals) - ) - return exp.column(name, table) + column_type = matched_columns[name] + qualified_column = exp.column(name, table) + + if column_type.is_type(*exp.DataType.REAL_TYPES): + return self.adapter._normalize_decimal_value(qualified_column, self.decimals) + if column_type.is_type(*exp.DataType.NESTED_TYPES): + return self.adapter._normalize_nested_value(qualified_column) + + return qualified_column comparisons = [ exp.Case() @@ -269,30 +390,47 @@ def _column_expr(name: str, table: str) -> exp.Expression: for c, t in matched_columns.items() ] - def name(e: exp.Expression) -> str: - return e.args["alias"].sql(identify=True) + source_query = ( + exp.select( + *(exp.column(c) for c in source_schema), + self.source_key_expression.as_(SQLMESH_JOIN_KEY_COL), + ) + .from_(self.source_table.as_("s")) + .where(self.where) + ) + target_query = ( + exp.select( + *(exp.column(c) for c in target_schema), + self.target_key_expression.as_(SQLMESH_JOIN_KEY_COL), + ) + .from_(self.target_table.as_("t")) + .where(self.where) + ) + + # Ensure every column is qualified with the alias in the source and target queries + for col in find_all_in_scope(source_query, exp.Column): + col.set("table", exp.to_identifier("s")) + for col in find_all_in_scope(target_query, exp.Column): + col.set("table", exp.to_identifier("t")) + + source_table = exp.table_("__source") + target_table = exp.table_("__target") + stats_table = exp.table_("__stats") - query = ( + stats_query = ( exp.select( *s_selects.values(), *t_selects.values(), - exp.func("IF", exp.or_(*(c.is_(exp.Null()).not_() for c in s_index)), 1, 0).as_( - "s_exists" - ), - exp.func("IF", exp.or_(*(c.is_(exp.Null()).not_() for c in t_index)), 1, 0).as_( - "t_exists" - ), + exp.func( + "IF", 
exp.column(SQLMESH_JOIN_KEY_COL, "s").is_(exp.Null()).not_(), 1, 0 + ).as_("s_exists"), + exp.func( + "IF", exp.column(SQLMESH_JOIN_KEY_COL, "t").is_(exp.Null()).not_(), 1, 0 + ).as_("t_exists"), exp.func( "IF", - exp.and_( - *( - exp.and_( - exp.column(c, "s").eq(exp.column(c, "t")), - exp.column(c, "s").is_(exp.Null()).not_(), - exp.column(c, "t").is_(exp.Null()).not_(), - ) - for c in index_cols - ), + exp.column(SQLMESH_JOIN_KEY_COL, "s").eq( + exp.column(SQLMESH_JOIN_KEY_COL, "t") ), 1, 0, @@ -302,10 +440,10 @@ def name(e: exp.Expression) -> str: exp.or_( *( exp.and_( - exp.column(c, "s").is_(exp.Null()), - exp.column(c, "t").is_(exp.Null()), + s.is_(exp.Null()), + t.is_(exp.Null()), ) - for c in index_cols + for s, t in zip(s_index, t_index) ), ), 1, @@ -313,36 +451,69 @@ def name(e: exp.Expression) -> str: ).as_("null_grain"), *comparisons, ) - .from_(exp.alias_(self.source, "s")) + .from_(source_table.as_("s")) .join( - self.target, - on=self.on, + target_table.as_("t"), + on=exp.column(SQLMESH_JOIN_KEY_COL, "s").eq( + exp.column(SQLMESH_JOIN_KEY_COL, "t") + ), join_type="FULL", - join_alias="t", ) - .where(self.where) ) - query = exp.select( - "*", - exp.Case() - .when( - exp.and_( - *[ - exp.column(f"{c}_matches").eq(exp.Literal.number(1)) - for c in matched_columns - ] - ), - exp.Literal.number(1), + base_query = ( + exp.Select() + .with_(source_table, source_query) + .with_(target_table, target_query) + .with_(stats_table, stats_query) + .select( + "*", + exp.Case() + .when( + exp.and_( + *[ + exp.column(f"{c}_matches").eq(exp.Literal.number(1)) + for c in matched_columns + ] + ), + exp.Literal.number(1), + ) + .else_(exp.Literal.number(0)) + .as_("row_full_match"), ) - .else_(exp.Literal.number(0)) - .as_("row_full_match"), - ).from_(query.subquery("stats")) + .from_(stats_table) + ) + + query = self.adapter.ensure_nulls_for_unmatched_after_join( + quote_identifiers(base_query.copy(), dialect=self.model_dialect or self.dialect) + ) - query = 
quote_identifiers(query, dialect=self.model_dialect or self.dialect) - temp_table = exp.table_("diff", db="sqlmesh_temp", quoted=True) + if not temp_schema: + temp_schema = "sqlmesh_temp" + + schema = to_schema(temp_schema, dialect=self.dialect) + temp_table = exp.table_("diff", db=schema.db, catalog=schema.catalog, quoted=True) + + temp_table_kwargs: t.Dict[str, t.Any] = {} + if isinstance(self.adapter, AthenaEngineAdapter): + # Athena has two table formats: Hive (the default) and Iceberg. TableDiff requires that + # the formats be the same for the source, target, and temp tables. + source_table_type = self.adapter._query_table_type(self.source_table) + target_table_type = self.adapter._query_table_type(self.target_table) + + if source_table_type == "iceberg" and target_table_type == "iceberg": + temp_table_kwargs["table_format"] = "iceberg" + # Sets the temp table's format to Iceberg. + # If neither source nor target table is Iceberg, it defaults to Hive (Athena's default). + elif source_table_type == "iceberg" or target_table_type == "iceberg": + raise SQLMeshError( + f"Source table '{self.source}' format '{source_table_type}' and target table '{self.target}' format '{target_table_type}' " + f"do not match for Athena. Diffing between different table formats is not supported." 
+ ) - with self.adapter.temp_table(query, name=temp_table) as table: + with self.adapter.temp_table( + query, name=temp_table, target_columns_to_types=None, **temp_table_kwargs + ) as table: summary_sums = [ exp.func("SUM", "s_exists").as_("s_count"), exp.func("SUM", "t_exists").as_("t_count"), @@ -353,18 +524,20 @@ def name(e: exp.Expression) -> str: ] if not skip_grain_check: - s_grains = ", ".join((f"s__{c}" for c in index_cols)) - t_grains = ", ".join((f"t__{c}" for c in index_cols)) summary_sums.extend( [ - parse_one(f"COUNT(DISTINCT({s_grains}))").as_("distinct_count_s"), - parse_one(f"COUNT(DISTINCT({t_grains}))").as_("distinct_count_t"), + parse_one(f"COUNT(DISTINCT(s__{SQLMESH_JOIN_KEY_COL}))").as_( + "distinct_count_s" + ), + parse_one(f"COUNT(DISTINCT(t__{SQLMESH_JOIN_KEY_COL}))").as_( + "distinct_count_t" + ), ] ) summary_query = exp.select(*summary_sums).from_(table) - stats_df = self.adapter.fetchdf(summary_query, quote_identifiers=True) + stats_df = self.adapter.fetchdf(summary_query, quote_identifiers=True).fillna(0) stats_df["s_only_count"] = stats_df["s_count"] - stats_df["join_count"] stats_df["t_only_count"] = stats_df["t_count"] - stats_df["join_count"] stats = stats_df.iloc[0].to_dict() @@ -374,8 +547,14 @@ def name(e: exp.Expression) -> str: *( exp.func( "ROUND", - 100 * (exp.func("SUM", name(c)) / exp.func("COUNT", name(c))), - 1, + 100 + * ( + exp.cast( + exp.func("SUM", name(c)), exp.DataType.build("NUMERIC") + ) + / exp.func("COUNT", name(c)) + ), + 9, ).as_(c.alias) for c in comparisons ) @@ -383,48 +562,47 @@ def name(e: exp.Expression) -> str: .from_(table) .where(exp.column("row_joined").eq(exp.Literal.number(1))) ) + column_stats = ( self.adapter.fetchdf(column_stats_query, quote_identifiers=True) .T.rename( columns={0: "pct_match"}, index=lambda x: str(x).replace("_matches", "") if x else "", ) - .drop(index=index_cols) + # errors=ignore because all the index_cols might not be present in the DF if the `on` condition was something 
like "s.id == t.item_id" + # because these would not be present in the matching_cols (since they have different names) and thus no summary would be generated + .drop(index=index_cols, errors="ignore") ) - sample_filter_cols = ["s_exists", "t_exists", "row_joined", "row_full_match"] - sample_query = ( - exp.select( - *(sample_filter_cols), - *(name(c) for c in s_selects.values()), - *(name(c) for c in t_selects.values()), - ) - .from_(table) - .where(exp.or_(*(exp.column(c.alias).eq(0) for c in comparisons))) - .order_by( - *(name(s_selects[c.name]) for c in s_index), - *(name(t_selects[c.name]) for c in t_index), - ) - .limit(self.limit) + sample = self._fetch_sample( + table, s_selects, s_index, t_selects, t_index, self.limit ) - sample = self.adapter.fetchdf(sample_query, quote_identifiers=True) - joined_sample_cols = [f"s__{c}" for c in index_cols] + joined_sample_cols = [f"s__{c}" for c in s_index_names] comparison_cols = [ (f"s__{c}", f"t__{c}") for c in column_stats[column_stats["pct_match"] < 100].index ] + for cols in comparison_cols: joined_sample_cols.extend(cols) + joined_renamed_cols = { c: c.split("__")[1] if c.split("__")[1] in index_cols else c for c in joined_sample_cols } - if self.source != self.source_alias and self.target != self.target_alias: + + if ( + self.source_alias + and self.target_alias + and self.source != self.source_alias + and self.target != self.target_alias + ): joined_renamed_cols = { c: ( n.replace( - "s__", f"{self.source_alias.upper() if self.source_alias else ''}__" + "s__", + f"{self.source_alias.upper()}__", ) if n.startswith("s__") else n @@ -434,40 +612,51 @@ def name(e: exp.Expression) -> str: joined_renamed_cols = { c: ( n.replace( - "t__", f"{self.target_alias.upper() if self.target_alias else ''}__" + "t__", + f"{self.target_alias.upper()}__", ) if n.startswith("t__") else n ) for c, n in joined_renamed_cols.items() } - joined_sample = sample[sample["row_joined"] == 1][joined_sample_cols] + + joined_sample = 
sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "common_rows"][ + joined_sample_cols + ] joined_sample.rename( columns=joined_renamed_cols, inplace=True, ) - s_sample = sample[(sample["s_exists"] == 1) & (sample["row_joined"] == 0)][ + s_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "source_only"][ [ - *[f"s__{c}" for c in index_cols], - *[f"s__{c}" for c in source_schema if c not in index_cols], + *[f"s__{c}" for c in s_index_names], + *[f"s__{c}" for c in source_schema if c not in s_index_names], ] ] s_sample.rename( columns={c: c.replace("s__", "") for c in s_sample.columns}, inplace=True ) - t_sample = sample[(sample["t_exists"] == 1) & (sample["row_joined"] == 0)][ + t_sample = sample[sample[SQLMESH_SAMPLE_TYPE_COL] == "target_only"][ [ - *[f"t__{c}" for c in index_cols], - *[f"t__{c}" for c in target_schema if c not in index_cols], + *[f"t__{c}" for c in t_index_names], + *[f"t__{c}" for c in target_schema if c not in t_index_names], ] ] t_sample.rename( columns={c: c.replace("t__", "") for c in t_sample.columns}, inplace=True ) - sample.drop(columns=sample_filter_cols, inplace=True) + sample.drop( + columns=[ + f"s__{SQLMESH_JOIN_KEY_COL}", + f"t__{SQLMESH_JOIN_KEY_COL}", + SQLMESH_SAMPLE_TYPE_COL, + ], + inplace=True, + ) self._row_diff = RowDiff( source=self.source, @@ -481,5 +670,77 @@ def name(e: exp.Expression) -> str: source_alias=self.source_alias, target_alias=self.target_alias, model_name=self.model_name, + decimals=self.decimals, ) + return self._row_diff + + def _fetch_sample( + self, + sample_table: exp.Table, + s_selects: t.Dict[str, exp.Alias], + s_index: t.List[exp.Column], + t_selects: t.Dict[str, exp.Alias], + t_index: t.List[exp.Column], + limit: int, + ) -> pd.DataFrame: + rendered_data_column_names = [ + name(s) for s in list(s_selects.values()) + list(t_selects.values()) + ] + sample_type = exp.to_identifier(SQLMESH_SAMPLE_TYPE_COL) + + source_only_sample = ( + exp.select( + exp.Literal.string("source_only").as_(sample_type), 
*rendered_data_column_names + ) + .from_(sample_table) + .where(exp.and_(exp.column("s_exists").eq(1), exp.column("row_joined").eq(0))) + .order_by(*(name(s_selects[c.name]) for c in s_index)) + .limit(limit) + ) + + target_only_sample = ( + exp.select( + exp.Literal.string("target_only").as_(sample_type), *rendered_data_column_names + ) + .from_(sample_table) + .where(exp.and_(exp.column("t_exists").eq(1), exp.column("row_joined").eq(0))) + .order_by(*(name(t_selects[c.name]) for c in t_index)) + .limit(limit) + ) + + common_rows_sample = ( + exp.select( + exp.Literal.string("common_rows").as_(sample_type), *rendered_data_column_names + ) + .from_(sample_table) + .where(exp.and_(exp.column("row_joined").eq(1), exp.column("row_full_match").eq(0))) + .order_by( + *(name(s_selects[c.name]) for c in s_index), + *(name(t_selects[c.name]) for c in t_index), + ) + .limit(limit) + ) + + query = ( + exp.Select() + .with_("source_only", source_only_sample) + .with_("target_only", target_only_sample) + .with_("common_rows", common_rows_sample) + .select(sample_type, *rendered_data_column_names) + .from_("source_only") + .union( + exp.select(sample_type, *rendered_data_column_names).from_("target_only"), + distinct=False, + ) + .union( + exp.select(sample_type, *rendered_data_column_names).from_("common_rows"), + distinct=False, + ) + ) + + return self.adapter.fetchdf(query, quote_identifiers=True) + + +def name(e: exp.Expression) -> str: + return e.args["alias"].sql(identify=True) diff --git a/sqlmesh/core/test/__init__.py b/sqlmesh/core/test/__init__.py index 4b1d00de1c..6353370f45 100644 --- a/sqlmesh/core/test/__init__.py +++ b/sqlmesh/core/test/__init__.py @@ -1,132 +1,9 @@ from __future__ import annotations -import pathlib -import typing as t -import unittest - -from sqlmesh.core.engine_adapter import EngineAdapter -from sqlmesh.core.model import Model from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test from 
sqlmesh.core.test.discovery import ( ModelTestMetadata as ModelTestMetadata, filter_tests_by_patterns as filter_tests_by_patterns, - get_all_model_tests as get_all_model_tests, - load_model_test_file as load_model_test_file, ) from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult -from sqlmesh.utils import UniqueKeyDict - -if t.TYPE_CHECKING: - from sqlmesh.core.config.loader import C - - -def run_tests( - model_test_metadata: list[ModelTestMetadata], - models: UniqueKeyDict[str, Model], - config: C, - gateway: t.Optional[str] = None, - dialect: str | None = None, - verbosity: int = 1, - preserve_fixtures: bool = False, - stream: t.TextIO | None = None, - default_catalog: str | None = None, - default_catalog_dialect: str = "", -) -> ModelTextTestResult: - """Create a test suite of ModelTest objects and run it. - - Args: - model_test_metadata: A list of ModelTestMetadata named tuples. - models: All models to use for expansion and mapping of physical locations. - verbosity: The verbosity level. - preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. 
- """ - testing_adapter_by_gateway: t.Dict[str, EngineAdapter] = {} - default_gateway = gateway or config.default_gateway_name - - try: - tests = [] - for metadata in model_test_metadata: - body = metadata.body - gateway = body.get("gateway") or default_gateway - testing_engine_adapter = testing_adapter_by_gateway.get(gateway) - if not testing_engine_adapter: - testing_engine_adapter = config.get_test_connection( - gateway, - default_catalog, - default_catalog_dialect, - ).create_engine_adapter(register_comments_override=False) - testing_adapter_by_gateway[gateway] = testing_engine_adapter - - tests.append( - ModelTest.create_test( - body=body, - test_name=metadata.test_name, - models=models, - engine_adapter=testing_engine_adapter, - dialect=dialect, - path=metadata.path, - default_catalog=default_catalog, - preserve_fixtures=preserve_fixtures, - ) - ) - - result = t.cast( - ModelTextTestResult, - unittest.TextTestRunner( - stream=stream, verbosity=verbosity, resultclass=ModelTextTestResult - ).run(unittest.TestSuite(tests)), - ) - finally: - for testing_engine_adapter in testing_adapter_by_gateway.values(): - testing_engine_adapter.close() - - return result - - -def run_model_tests( - tests: list[str], - models: UniqueKeyDict[str, Model], - config: C, - gateway: t.Optional[str] = None, - dialect: str | None = None, - verbosity: int = 1, - patterns: list[str] | None = None, - preserve_fixtures: bool = False, - stream: t.TextIO | None = None, - default_catalog: t.Optional[str] = None, - default_catalog_dialect: str = "", -) -> ModelTextTestResult: - """Load and run tests. - - Args: - tests: A list of tests to run, e.g. [tests/test_orders.yaml::test_single_order] - models: All models to use for expansion and mapping of physical locations. - verbosity: The verbosity level. - patterns: A list of patterns to match against. - preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. 
- """ - loaded_tests = [] - for test in tests: - filename, test_name = test.split("::", maxsplit=1) if "::" in test else (test, "") - path = pathlib.Path(filename) - - if test_name: - loaded_tests.append(load_model_test_file(path, variables=config.variables)[test_name]) - else: - loaded_tests.extend(load_model_test_file(path, variables=config.variables).values()) - - if patterns: - loaded_tests = filter_tests_by_patterns(loaded_tests, patterns) - - return run_tests( - loaded_tests, - models, - config, - gateway=gateway, - dialect=dialect, - verbosity=verbosity, - preserve_fixtures=preserve_fixtures, - stream=stream, - default_catalog=default_catalog, - default_catalog_dialect=default_catalog_dialect, - ) +from sqlmesh.core.test.runner import run_tests as run_tests diff --git a/sqlmesh/core/test/context.py b/sqlmesh/core/test/context.py index 6f4563cf51..a326c3c1b3 100644 --- a/sqlmesh/core/test/context.py +++ b/sqlmesh/core/test/context.py @@ -18,6 +18,8 @@ class TestExecutionContext(ExecutionContext): models: All upstream models to use for expansion and mapping of physical locations. 
""" + __test__ = False # prevent pytest trying to collect this as a test class + def __init__( self, engine_adapter: EngineAdapter, @@ -26,6 +28,7 @@ def __init__( default_dialect: t.Optional[str] = None, default_catalog: t.Optional[str] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, ): self._engine_adapter = engine_adapter self._models = models @@ -33,15 +36,24 @@ def __init__( self._default_catalog = default_catalog self._default_dialect = default_dialect self._variables = variables or {} + self._blueprint_variables = variables or {} @cached_property def _model_tables(self) -> t.Dict[str, str]: """Returns a mapping of model names to tables.""" + + # Include upstream dependencies to ensure they can be resolved during test execution return { - name: self._test._test_fixture_table(name).sql() for name, model in self._models.items() + name: self._test._test_fixture_table(name).sql() + for normalized_model_name, model in self._models.items() + for name in [normalized_model_name, *model.depends_on] } - def with_variables(self, variables: t.Dict[str, t.Any]) -> TestExecutionContext: + def with_variables( + self, + variables: t.Dict[str, t.Any], + blueprint_variables: t.Optional[t.Dict[str, t.Any]] = None, + ) -> TestExecutionContext: """Returns a new TestExecutionContext with additional variables.""" return TestExecutionContext( self._engine_adapter, @@ -50,4 +62,5 @@ def with_variables(self, variables: t.Dict[str, t.Any]) -> TestExecutionContext: self._default_dialect, self._default_catalog, variables=variables, + blueprint_variables=blueprint_variables, ) diff --git a/sqlmesh/core/test/definition.py b/sqlmesh/core/test/definition.py index 0871f5ced1..2a838753de 100644 --- a/sqlmesh/core/test/definition.py +++ b/sqlmesh/core/test/definition.py @@ -1,18 +1,19 @@ from __future__ import annotations +import sys + import datetime +import threading import typing as t import unittest from collections 
import Counter -from contextlib import AbstractContextManager, nullcontext +from contextlib import nullcontext, contextmanager, AbstractContextManager +from itertools import chain from pathlib import Path from unittest.mock import patch -import numpy as np -import pandas as pd + from io import StringIO -from freezegun import freeze_time -from pandas.api.types import is_object_dtype from sqlglot import Dialect, exp from sqlglot.optimizer.annotate_types import annotate_types from sqlglot.optimizer.normalize_identifiers import normalize_identifiers @@ -23,21 +24,35 @@ from sqlmesh.core.macros import RuntimeStage from sqlmesh.core.model import Model, PythonModel, SqlModel from sqlmesh.utils import UniqueKeyDict, random_id, type_is_known, yaml -from sqlmesh.utils.date import pandas_timestamp_to_pydatetime +from sqlmesh.utils.date import date_dict, pandas_timestamp_to_pydatetime, to_datetime from sqlmesh.utils.errors import ConfigError, TestError from sqlmesh.utils.yaml import load as yaml_load +from sqlmesh.utils import Verbosity +from sqlmesh.utils.rich import df_to_table if t.TYPE_CHECKING: + import pandas as pd + from sqlglot.dialects.dialect import DialectType Row = t.Dict[str, t.Any] -TIME_KWARG_KEYS = {"start", "end", "execution_time", "latest"} + +TIME_KWARG_KEYS = { + "start", + "end", + "execution_time", + "latest", + # all built-in datetime macro var names + *date_dict(execution_time="1970-01-01", start="1970-01-01", end="1970-01-01").keys(), +} class ModelTest(unittest.TestCase): __test__ = False + CONCURRENT_RENDER_LOCK = threading.Lock() + def __init__( self, body: t.Dict[str, t.Any], @@ -49,6 +64,8 @@ def __init__( path: Path | None = None, preserve_fixtures: bool = False, default_catalog: str | None = None, + concurrency: bool = False, + verbosity: Verbosity = Verbosity.DEFAULT, ) -> None: """ModelTest encapsulates a unit test for a model. 
@@ -71,6 +88,8 @@ def __init__( self.preserve_fixtures = preserve_fixtures self.default_catalog = default_catalog self.dialect = dialect + self.concurrency = concurrency + self.verbosity = verbosity self._fixture_table_cache: t.Dict[str, exp.Table] = {} self._normalized_column_name_cache: t.Dict[str, str] = {} @@ -81,60 +100,86 @@ def __init__( self._validate_and_normalize_test() if self.engine_adapter.default_catalog: - self._fixture_catalog: t.Optional[exp.Identifier] = exp.parse_identifier( - self.engine_adapter.default_catalog, dialect=self._test_adapter_dialect + self._fixture_catalog: t.Optional[exp.Identifier] = normalize_identifiers( + exp.parse_identifier( + self.engine_adapter.default_catalog, dialect=self._test_adapter_dialect + ), + dialect=self._test_adapter_dialect, ) else: self._fixture_catalog = None - # The test schema name is randomized to avoid concurrency issues - self._fixture_schema = exp.to_identifier(f"sqlmesh_test_{random_id(short=True)}") + # The test schema name is randomized to avoid concurrency issues, + # unless a schema is provided in the unit tests's body + self._fixture_schema = exp.parse_identifier( + self.body.get("schema") or f"sqlmesh_test_{random_id(short=True)}" + ) self._qualified_fixture_schema = schema_(self._fixture_schema, self._fixture_catalog) self._transforms = self._test_adapter_dialect.generator_class.TRANSFORMS self._execution_time = str(self.body.get("vars", {}).get("execution_time") or "") + if self._execution_time: + # Normalizes the execution time by converting it into UTC timezone + self._execution_time = str(to_datetime(self._execution_time)) + # When execution_time is set, we mock the CURRENT_* SQL expressions so they always return it if self._execution_time: exec_time = exp.Literal.string(self._execution_time) self._transforms = { **self._transforms, - exp.CurrentDate: lambda self, _: self.sql(exp.cast(exec_time, "date")), - exp.CurrentDatetime: lambda self, _: self.sql(exp.cast(exec_time, "datetime")), - 
exp.CurrentTime: lambda self, _: self.sql(exp.cast(exec_time, "time")), - exp.CurrentTimestamp: lambda self, _: self.sql(exp.cast(exec_time, "timestamp")), + exp.CurrentDate: lambda self, _: self.sql( + exp.cast(exec_time, "date", dialect=dialect) + ), + exp.CurrentDatetime: lambda self, _: self.sql( + exp.cast(exec_time, "datetime", dialect=dialect) + ), + exp.CurrentTime: lambda self, _: self.sql( + exp.cast(exec_time, "time", dialect=dialect) + ), + exp.CurrentTimestamp: lambda self, _: self.sql( + exp.cast(exec_time, "timestamp", dialect=dialect) + ), } super().__init__() + def defaultTestResult(self) -> unittest.TestResult: + from sqlmesh.core.test.result import ModelTextTestResult + + return ModelTextTestResult(stream=sys.stdout, descriptions=True, verbosity=self.verbosity) + def shortDescription(self) -> t.Optional[str]: return self.body.get("description") def setUp(self) -> None: """Load all input tables""" + import pandas as pd + import numpy as np + self.engine_adapter.create_schema(self._qualified_fixture_schema) for name, values in self.body.get("inputs", {}).items(): all_types_are_known = False - known_columns_to_types: t.Dict[str, exp.DataType] = {} + columns_to_known_types: t.Dict[str, exp.DataType] = {} model = self.models.get(name) if model: inferred_columns_to_types = model.columns_to_types or {} - known_columns_to_types = { + columns_to_known_types = { c: t for c, t in inferred_columns_to_types.items() if type_is_known(t) } all_types_are_known = bool(inferred_columns_to_types) and ( - len(known_columns_to_types) == len(inferred_columns_to_types) + len(columns_to_known_types) == len(inferred_columns_to_types) ) # Types specified in the test will override the corresponding inferred ones - known_columns_to_types.update(values.get("columns", {})) + columns_to_known_types.update(values.get("columns", {})) rows = values.get("rows") if not all_types_are_known and rows: for col, value in rows[0].items(): - if col not in known_columns_to_types: + if col 
not in columns_to_known_types: v_type = annotate_types(exp.convert(value)).type or type(value).__name__ v_type = exp.maybe_parse( v_type, into=exp.DataType, dialect=self._test_adapter_dialect @@ -149,17 +194,25 @@ def setUp(self) -> None: self.path, ) - known_columns_to_types[col] = v_type + columns_to_known_types[col] = v_type if rows is None: query_or_df: exp.Query | pd.DataFrame = self._add_missing_columns( - values["query"], known_columns_to_types + values["query"], columns_to_known_types ) + if columns_to_known_types: + columns_to_known_types = { + col: columns_to_known_types[col] for col in query_or_df.named_selects + } else: - query_or_df = self._create_df(values, columns=known_columns_to_types) + query_or_df = self._create_df(values, columns=columns_to_known_types) + + # Convert NaN/NaT values to None if DataFrame + if isinstance(query_or_df, pd.DataFrame): + query_or_df = query_or_df.replace({np.nan: None}) self.engine_adapter.create_view( - self._test_fixture_table(name), query_or_df, known_columns_to_types + self._test_fixture_table(name), query_or_df, columns_to_known_types ) def tearDown(self) -> None: @@ -175,9 +228,13 @@ def assert_equal( partial: t.Optional[bool] = False, ) -> None: """Compare two DataFrames""" + import numpy as np + import pandas as pd + from pandas.api.types import is_object_dtype + if partial: intersection = actual[actual.columns.intersection(expected.columns)] - if not intersection.empty: + if len(intersection.columns) > 0: actual = intersection # Two astypes are necessary, pandas converts strings to times as NS, @@ -204,26 +261,46 @@ def assert_equal( if is_object_dtype(actual_types[col]) and len(actual[col]) != 0 } for col, value in object_sentinel_values.items(): - # can't use `isinstance()` here - https://stackoverflow.com/a/68743663/1707525 - if type(value) is datetime.date: - expected[col] = pd.to_datetime(expected[col], errors="ignore").dt.date # type: ignore - elif type(value) is datetime.time: - expected[col] = 
pd.to_datetime(expected[col], errors="ignore").dt.time # type: ignore - elif type(value) is datetime.datetime: - expected[col] = pd.to_datetime(expected[col], errors="ignore").dt.to_pydatetime() # type: ignore + try: + # can't use `isinstance()` here - https://stackoverflow.com/a/68743663/1707525 + if type(value) is datetime.date: + expected[col] = pd.to_datetime(expected[col]).dt.date + elif type(value) is datetime.time: + expected[col] = pd.to_datetime(expected[col]).dt.time + elif type(value) is datetime.datetime: + expected[col] = pd.to_datetime(expected[col]).dt.to_pydatetime() + except Exception as e: + from sqlmesh.core.console import get_console + + get_console().log_warning( + f"Failed to convert expected value for {col} into `datetime` " + f"for unit test '{str(self)}'. {str(e)}." + ) actual = actual.replace({np.nan: None}) expected = expected.replace({np.nan: None}) + # We define this here to avoid a top-level import of numpy and pandas + DATETIME_TYPES = ( + datetime.datetime, + datetime.date, + datetime.time, + np.datetime64, + pd.Timestamp, + ) + def _to_hashable(x: t.Any) -> t.Any: if isinstance(x, (list, np.ndarray)): - return tuple(x) - return str(x) if not isinstance(x, t.Hashable) else x + return tuple(_to_hashable(v) for v in x) + if isinstance(x, dict): + return tuple((k, _to_hashable(v)) for k, v in x.items()) + return str(x) if isinstance(x, DATETIME_TYPES) or not isinstance(x, t.Hashable) else x + + actual = actual.apply(lambda col: col.map(_to_hashable)) + expected = expected.apply(lambda col: col.map(_to_hashable)) if sort: - actual = actual.apply(lambda col: col.map(_to_hashable)) actual = actual.sort_values(by=actual.columns.to_list()).reset_index(drop=True) - expected = expected.apply(lambda col: col.map(_to_hashable)) expected = expected.sort_values(by=expected.columns.to_list()).reset_index(drop=True) try: @@ -235,23 +312,63 @@ def _to_hashable(x: t.Any) -> t.Any: check_like=True, # Ignore column order ) except AssertionError as e: + 
# There are 2 concepts at play here: + # 1. The Exception args will contain the error message plus the diff dataframe table stringified + # (backwards compatibility with existing tests, possible to serialize/send over network etc) + # 2. Each test will also transform these diff dataframes into Rich tables, which will be the ones that'll + # be surfaced to the user through Console for better UX (versus stringified dataframes) + # + # This is a bit of a hack, but it's a way to get the best of both worlds. + args: t.List[t.Any] = [] + + failed_subtest = "" + + if subtest := getattr(self, "_subtest", None): + if cte := subtest.params.get("cte"): + failed_subtest = f" (CTE {cte})" + if expected.shape != actual.shape: _raise_if_unexpected_columns(expected.columns, actual.columns) - error_msg = "Data mismatch (rows are different)" + args.append("Data mismatch (rows are different)") missing_rows = _row_difference(expected, actual) if not missing_rows.empty: - error_msg += f"\n\nMissing rows:\n\n{missing_rows}" + args[0] += f"\n\nMissing rows:\n\n{missing_rows}" + args.append(df_to_table(f"Missing rows{failed_subtest}", missing_rows)) unexpected_rows = _row_difference(actual, expected) + if not unexpected_rows.empty: - error_msg += f"\n\nUnexpected rows:\n\n{unexpected_rows}" + args[0] += f"\n\nUnexpected rows:\n\n{unexpected_rows}" + args.append(df_to_table(f"Unexpected rows{failed_subtest}", unexpected_rows)) - e.args = (error_msg,) else: diff = expected.compare(actual).rename(columns={"self": "exp", "other": "act"}) - e.args = (f"Data mismatch (exp: expected, act: actual)\n\n{diff}",) + + args.append(f"Data mismatch (exp: expected, act: actual)\n\n{diff}") + + diff.rename(columns={"exp": "Expected", "act": "Actual"}, inplace=True) + if self.verbosity == Verbosity.DEFAULT: + args.extend( + df_to_table(f"Data mismatch{failed_subtest}", df) + for df in _split_df_by_column_pairs(diff) + ) + else: + from pandas import DataFrame, MultiIndex + + levels = t.cast(MultiIndex, 
diff.columns).levels[0] + for col in levels: + # diff[col] returns a DataFrame when columns is a MultiIndex + col_diff = t.cast(DataFrame, diff[col]) + if not col_diff.empty: + table = df_to_table( + f"[bold red]Column '{col}' mismatch{failed_subtest}[/bold red]", + col_diff, + ) + args.append(table) + + e.args = (*args,) raise e @@ -272,7 +389,9 @@ def create_test( path: Path | None, preserve_fixtures: bool = False, default_catalog: str | None = None, - ) -> ModelTest: + concurrency: bool = False, + verbosity: Verbosity = Verbosity.DEFAULT, + ) -> t.Optional[ModelTest]: """Create a SqlModelTest or a PythonModelTest. Args: @@ -284,10 +403,19 @@ def create_test( path: An optional path to the test definition yaml file. preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. """ - name = normalize_model_name(body["model"], default_catalog=default_catalog, dialect=dialect) + name = body.get("model") + if name is None: + _raise_error("Missing required 'model' field", path) + + name = normalize_model_name(name, default_catalog=default_catalog, dialect=dialect) model = models.get(name) if not model: - _raise_error(f"Model '{name}' was not found", path) + from sqlmesh.core.console import get_console + + get_console().log_warning( + f"Model '{name}' was not found{' at ' + str(path) if path else ''}" + ) + return None if isinstance(model, SqlModel): test_type: t.Type[ModelTest] = SqlModelTest @@ -296,17 +424,22 @@ def create_test( else: _raise_error(f"Model '{name}' is an unsupported model type for testing", path) - return test_type( - body, - test_name, - t.cast(Model, model), - models, - engine_adapter, - dialect, - path, - preserve_fixtures, - default_catalog, - ) + try: + return test_type( + body, + test_name, + t.cast(Model, model), + models, + engine_adapter, + dialect, + path, + preserve_fixtures, + default_catalog, + concurrency, + verbosity, + ) + except Exception as e: + raise TestError(f"Failed to create test {test_name} 
({path})\n{str(e)}") def __str__(self) -> str: return f"{self.test_name} ({self.path})" @@ -322,12 +455,17 @@ def _validate_and_normalize_test(self) -> None: query = outputs.get("query") partial = outputs.pop("partial", None) + if ctes is None and query is None: + _raise_error("Incomplete test, outputs must contain 'query' or 'ctes'", self.path) + def _normalize_rows( values: t.List[Row] | t.Dict, name: str, partial: bool = False, dialect: DialectType = None, ) -> t.Dict: + import pandas as pd + if not isinstance(values, dict): values = {"rows": values} @@ -465,10 +603,34 @@ def _normalize_column_name(self, name: str) -> str: return normalized_name - def _execute(self, query: exp.Query) -> pd.DataFrame: + @contextmanager + def _concurrent_render_context(self) -> t.Iterator[None]: + """ + Context manager that ensures that the tests are executed safely in a concurrent environment. + This is needed in case `execution_time` is set, as we'd then have to: + - Freeze time through `time_machine` (not thread safe) + - Globally patch the SQLGlot dialect so that any date/time nodes are evaluated at the `execution_time` during generation + """ + import time_machine + + lock_ctx: AbstractContextManager = ( + self.CONCURRENT_RENDER_LOCK if self.concurrency else nullcontext() + ) + time_ctx: AbstractContextManager = nullcontext() + dialect_patch_ctx: AbstractContextManager = nullcontext() + + if self._execution_time: + time_ctx = time_machine.travel(self._execution_time, tick=False) + dialect_patch_ctx = patch.dict( + self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms + ) + + with lock_ctx, time_ctx, dialect_patch_ctx: + yield + + def _execute(self, query: exp.Query | str) -> pd.DataFrame: """Executes the given query using the testing engine adapter and returns a DataFrame.""" - with patch.dict(self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms): - return self.engine_adapter.fetchdf(query) + return self.engine_adapter.fetchdf(query) def 
_create_df( self, @@ -476,19 +638,26 @@ def _create_df( columns: t.Optional[t.Collection] = None, partial: t.Optional[bool] = False, ) -> pd.DataFrame: + import pandas as pd + query = values.get("query") if query: - return self._execute(self._add_missing_columns(query, columns)) + if not partial: + query = self._add_missing_columns(query, columns) + + return self._execute(query) rows = values["rows"] + columns_str: t.Optional[t.List[str]] = None if columns: + columns_str = [str(c) for c in columns] referenced_columns = list(dict.fromkeys(col for row in rows for col in row)) _raise_if_unexpected_columns(columns, referenced_columns) if partial: - columns = referenced_columns + columns_str = [c for c in columns_str if c in referenced_columns] - return pd.DataFrame.from_records(rows, columns=columns) + return pd.DataFrame.from_records(rows, columns=columns_str) def _add_missing_columns( self, query: exp.Query, all_columns: t.Optional[t.Collection[str]] = None @@ -505,7 +674,7 @@ def _add_missing_columns( class SqlModelTest(ModelTest): - def test_ctes(self, ctes: t.Dict[str, exp.Expression]) -> None: + def test_ctes(self, ctes: t.Dict[str, exp.Expression], recursive: bool = False) -> None: """Run CTE queries and compare output to expected output""" for cte_name, values in self.body["outputs"].get("ctes", {}).items(): with self.subTest(cte=cte_name): @@ -515,40 +684,57 @@ def test_ctes(self, ctes: t.Dict[str, exp.Expression]) -> None: ) cte_query = ctes[cte_name].this - for alias, cte in ctes.items(): - cte_query = cte_query.with_(alias, cte.this) - partial = values.get("partial") sort = cte_query.args.get("order") is None + partial = values.get("partial") + + cte_query = exp.select(*_projection_identifiers(cte_query)).from_(cte_name) + for alias, cte in ctes.items(): + cte_query = cte_query.with_(alias, cte.this, recursive=recursive) - actual = self._execute(cte_query) + with self._concurrent_render_context(): + # Similar to the model's query, we render the CTE query 
under the locked context + # so that the execution (fetchdf) can continue concurrently between the threads + sql = cte_query.sql( + self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql + ) + + actual = self._execute(sql) expected = self._create_df(values, columns=cte_query.named_selects, partial=partial) self.assert_equal(expected, actual, sort=sort, partial=partial) def runTest(self) -> None: - query = self._render_model_query() - - self.test_ctes( - { - self._normalize_model_name(cte.alias, with_default_catalog=False): cte - for cte in query.ctes - } - ) + with self._concurrent_render_context(): + # Render the model's query and generate the SQL under the locked context so that + # execution (fetchdf) can continue concurrently between the threads + query = self._render_model_query() + sql = query.sql(self._test_adapter_dialect, pretty=self.engine_adapter._pretty_sql) + + with_clause = query.args.get("with_") + + if with_clause: + self.test_ctes( + { + self._normalize_model_name(cte.alias, with_default_catalog=False): cte + for cte in query.ctes + }, + recursive=with_clause.recursive, + ) values = self.body["outputs"].get("query") if values is not None: partial = values.get("partial") sort = query.args.get("order") is None - actual = self._execute(query) + actual = self._execute(sql) expected = self._create_df(values, columns=self.model.columns_to_types, partial=partial) self.assert_equal(expected, actual, sort=sort, partial=partial) def _render_model_query(self) -> exp.Query: variables = self.body.get("vars", {}).copy() - time_kwargs = {key: variables.pop(key, None) for key in TIME_KWARG_KEYS} + time_kwargs = {key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables} query = self.model.render_query_or_raise( **time_kwargs, @@ -574,6 +760,8 @@ def __init__( path: Path | None = None, preserve_fixtures: bool = False, default_catalog: str | None = None, + concurrency: bool = False, + verbosity: Verbosity = Verbosity.DEFAULT, ) -> None: 
"""PythonModelTest encapsulates a unit test for a Python model. @@ -599,6 +787,8 @@ def __init__( path, preserve_fixtures, default_catalog, + concurrency, + verbosity, ) self.context = TestExecutionContext( @@ -618,18 +808,19 @@ def runTest(self) -> None: actual_df.reset_index(drop=True, inplace=True) expected = self._create_df(values, columns=self.model.columns_to_types, partial=partial) - self.assert_equal(expected, actual_df, sort=False, partial=partial) + self.assert_equal(expected, actual_df, sort=True, partial=partial) def _execute_model(self) -> pd.DataFrame: """Executes the python model and returns a DataFrame.""" - time_ctx = freeze_time(self._execution_time) if self._execution_time else nullcontext() - with patch.dict(self._test_adapter_dialect.generator_class.TRANSFORMS, self._transforms): - with t.cast(AbstractContextManager, time_ctx): - variables = self.body.get("vars", {}).copy() - time_kwargs = {key: variables.pop(key, None) for key in TIME_KWARG_KEYS} - df = next(self.model.render(context=self.context, **time_kwargs, **variables)) - assert not isinstance(df, exp.Expression) - return df if isinstance(df, pd.DataFrame) else df.toPandas() + import pandas as pd + + with self._concurrent_render_context(): + variables = self.body.get("vars", {}).copy() + time_kwargs = {key: variables.pop(key) for key in TIME_KWARG_KEYS if key in variables} + df = next(self.model.render(context=self.context, variables=variables, **time_kwargs)) + + assert not isinstance(df, exp.Expression) + return df if isinstance(df, pd.DataFrame) else df.toPandas() def generate_test( @@ -664,6 +855,8 @@ def generate_test( name: The name of the test. This is inferred from the model name by default. include_ctes: When true, CTE fixtures will also be generated. 
""" + import numpy as np + test_name = name or f"test_{model.view_name}" path = path or f"{test_name}.yaml" @@ -680,7 +873,7 @@ def generate_test( # ruamel.yaml does not support pandas Timestamps, so we must convert them to python # datetime or datetime.date objects based on column type inputs = { - models[dep].name: pandas_timestamp_to_pydatetime( + dep: pandas_timestamp_to_pydatetime( engine_adapter.fetchdf(query).apply(lambda col: col.map(_normalize_df_value)), models[dep].columns_to_types, ) @@ -690,7 +883,7 @@ def generate_test( } outputs: t.Dict[str, t.Any] = {"query": {}} variables = variables or {} - test_body = {"model": model.name, "inputs": inputs, "outputs": outputs} + test_body = {"model": model.fqn, "inputs": inputs, "outputs": outputs} if variables: test_body["vars"] = variables @@ -704,26 +897,36 @@ def generate_test( path=fixture_path, default_catalog=model.default_catalog, ) + if not test: + return test.setUp() if isinstance(model, SqlModel): assert isinstance(test, SqlModelTest) model_query = test._render_model_query() + with_clause = model_query.args.get("with_") - if include_ctes: + if with_clause and include_ctes: ctes = {} + recursive = with_clause.recursive previous_ctes: t.List[exp.CTE] = [] + for cte in model_query.ctes: cte_query = cte.this - for prev in previous_ctes: - cte_query = cte_query.with_(prev.alias, prev.this) + cte_identifier = cte.args["alias"].this + + cte_query = exp.select(*_projection_identifiers(cte_query)).from_(cte_identifier) + + for prev in chain(previous_ctes, [cte]): + cte_query = cte_query.with_( + prev.args["alias"].this, prev.this, recursive=recursive + ) cte_output = test._execute(cte_query) ctes[cte.alias] = ( pandas_timestamp_to_pydatetime( - cte_output.apply(lambda col: col.map(_normalize_df_value)), - cte_query.named_selects, + df=cte_output.apply(lambda col: col.map(_normalize_df_value)), ) .replace({np.nan: None}) .to_dict(orient="records") @@ -753,6 +956,19 @@ def generate_test( yaml.dump({test_name: 
test_body}, file) +def _projection_identifiers(query: exp.Query) -> t.List[str | exp.Identifier]: + identifiers: t.List[str | exp.Identifier] = [] + for select in query.selects: + if isinstance(select, exp.Alias): + identifiers.append(select.args["alias"]) + elif isinstance(select, exp.Column): + identifiers.append(select.this) + else: + identifiers.append(select.output_name) + + return identifiers + + def _raise_if_unexpected_columns( expected_cols: t.Collection[str], actual_cols: t.Collection[str] ) -> None: @@ -767,6 +983,9 @@ def _raise_if_unexpected_columns( def _row_difference(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: """Returns all rows in `left` that don't appear in `right`.""" + import numpy as np + import pandas as pd + rows_missing_from_right = [] # `None` replaces `np.nan` because `np.nan != np.nan` and this would affect the mapping lookup @@ -785,12 +1004,14 @@ def _row_difference(left: pd.DataFrame, right: pd.DataFrame) -> pd.DataFrame: def _raise_error(msg: str, path: Path | None = None) -> None: if path: - raise TestError(f"{msg} at {path}") - raise TestError(msg) + raise TestError(f"Failed to run test at {path}:\n{msg}") + raise TestError(f"Failed to run test:\n{msg}") def _normalize_df_value(value: t.Any) -> t.Any: """Normalize data in a pandas dataframe so ruamel and sqlglot can deal with it.""" + import numpy as np + if isinstance(value, (list, np.ndarray)): return [_normalize_df_value(v) for v in value] if isinstance(value, dict): @@ -800,3 +1021,43 @@ def _normalize_df_value(value: t.Any) -> t.Any: return {k: _normalize_df_value(v) for k, v in zip(value["key"], value["value"])} return {k: _normalize_df_value(v) for k, v in value.items()} return value + + +def _split_df_by_column_pairs(df: pd.DataFrame, pairs_per_chunk: int = 4) -> t.List[pd.DataFrame]: + """Split a dataframe into chunks of column pairs. 
+ + Args: + df: The dataframe to split + pairs_per_chunk: Number of column pairs per chunk (default: 4) + + Returns: + List of dataframes, each containing an even number of columns + """ + total_columns = len(df.columns) + + # If we have fewer columns than pairs_per_chunk * 2, return the original df + if total_columns <= pairs_per_chunk * 2: + return [df] + + # Calculate number of chunks needed to split columns evenly + num_chunks = (total_columns + (pairs_per_chunk * 2 - 1)) // (pairs_per_chunk * 2) + + # Calculate columns per chunk to ensure equal distribution + # We round down to nearest even number to ensure each chunk has even columns + columns_per_chunk = (total_columns // num_chunks) & ~1 # Round down to nearest even number + remainder = total_columns - (columns_per_chunk * num_chunks) + + chunks = [] + start_idx = 0 + + # Distribute columns evenly across chunks + for i in range(num_chunks): + # Add 2 columns to early chunks if there's a remainder + # This ensures we always add pairs of columns + extra = 2 if i < remainder // 2 else 0 + end_idx = start_idx + columns_per_chunk + extra + chunk = df.iloc[:, start_idx:end_idx] + chunks.append(chunk) + start_idx = end_idx + + return chunks diff --git a/sqlmesh/core/test/discovery.py b/sqlmesh/core/test/discovery.py index a7977d2036..9afe3dd7fc 100644 --- a/sqlmesh/core/test/discovery.py +++ b/sqlmesh/core/test/discovery.py @@ -4,13 +4,11 @@ import itertools import pathlib import typing as t -from collections.abc import Iterator import ruamel from sqlmesh.utils import unique from sqlmesh.utils.pydantic import PydanticModel -from sqlmesh.utils.yaml import load as yaml_load class ModelTestMetadata(PydanticModel): @@ -22,63 +20,14 @@ class ModelTestMetadata(PydanticModel): def fully_qualified_test_name(self) -> str: return f"{self.path}::{self.test_name}" + @property + def model_name(self) -> str: + return self.body.get("model", "") + def __hash__(self) -> int: return self.fully_qualified_test_name.__hash__() -def 
load_model_test_file( - path: pathlib.Path, variables: dict[str, t.Any] | None = None -) -> dict[str, ModelTestMetadata]: - """Load a single model test file. - - Args: - path: The path to the test file - - returns: - A list of ModelTestMetadata named tuples. - """ - model_test_metadata = {} - contents = yaml_load(path, variables=variables) - - for test_name, value in contents.items(): - model_test_metadata[test_name] = ModelTestMetadata( - path=path, test_name=test_name, body=value - ) - return model_test_metadata - - -def discover_model_tests( - path: pathlib.Path, - ignore_patterns: list[str] | None = None, - variables: dict[str, t.Any] | None = None, -) -> Iterator[ModelTestMetadata]: - """Discover model tests. - - Model tests are defined in YAML files and contain the inputs and outputs used to test model queries. - - Args: - path: A path to search for tests. - ignore_patterns: An optional list of patterns to ignore. - - Returns: - A list of ModelTestMetadata named tuples. - """ - search_path = pathlib.Path(path) - - for yaml_file in itertools.chain( - search_path.glob("**/test*.yaml"), - search_path.glob("**/test*.yml"), - ): - for ignore_pattern in ignore_patterns or []: - if yaml_file.match(ignore_pattern): - break - else: - for model_test_metadata in load_model_test_file( - yaml_file, variables=variables - ).values(): - yield model_test_metadata - - def filter_tests_by_patterns( tests: list[ModelTestMetadata], patterns: list[str] ) -> list[ModelTestMetadata]: @@ -97,19 +46,3 @@ def filter_tests_by_patterns( if ("*" in pattern and fnmatch.fnmatchcase(test.fully_qualified_test_name, pattern)) or pattern in test.fully_qualified_test_name ) - - -def get_all_model_tests( - *paths: pathlib.Path, - patterns: list[str] | None = None, - ignore_patterns: list[str] | None = None, - variables: dict[str, t.Any] | None = None, -) -> list[ModelTestMetadata]: - model_test_metadatas = [ - meta - for path in paths - for meta in discover_model_tests(pathlib.Path(path), 
ignore_patterns, variables=variables) - ] - if patterns: - model_test_metadatas = filter_tests_by_patterns(model_test_metadatas, patterns) - return model_test_metadatas diff --git a/sqlmesh/core/test/result.py b/sqlmesh/core/test/result.py index 752b896bb3..eefa0be513 100644 --- a/sqlmesh/core/test/result.py +++ b/sqlmesh/core/test/result.py @@ -4,21 +4,58 @@ import typing as t import unittest +from sqlmesh.core.test.definition import ModelTest + +if t.TYPE_CHECKING: + ErrorType = t.Union[ + t.Tuple[type[BaseException], BaseException, types.TracebackType], + t.Tuple[None, None, None], + ] + class ModelTextTestResult(unittest.TextTestResult): successes: t.List[unittest.TestCase] def __init__(self, *args: t.Any, **kwargs: t.Any): + self.console = kwargs.pop("console", None) super().__init__(*args, **kwargs) self.successes = [] + self.original_failures: t.List[t.Tuple[unittest.TestCase, ErrorType]] = [] + self.failure_tables: t.List[t.Tuple[t.Any, ...]] = [] + self.original_errors: t.List[t.Tuple[unittest.TestCase, ErrorType]] = [] + self.duration: t.Optional[float] = None - def addFailure( + def addSubTest( self, test: unittest.TestCase, - err: ( - tuple[type[BaseException], BaseException, types.TracebackType] | tuple[None, None, None] - ), + subtest: unittest.TestCase, + err: t.Optional[ErrorType], ) -> None: + """Called at the end of a subtest. + + The traceback is suppressed because it is redundant and not useful. + + Args: + test: The test case. + subtest: The subtest instance. + err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback). 
+ """ + if err: + exctype, value, tb = err + err = (exctype, value, None) # type: ignore + + if err[0] and issubclass(err[0], test.failureException): + self.addFailure(test, err) + else: + self.addError(test, err) + + def _print_char(self, char: str) -> None: + from sqlmesh.core.console import TerminalConsole + + if isinstance(self.console, TerminalConsole): + self.console._print(char, end="") + + def addFailure(self, test: unittest.TestCase, err: ErrorType) -> None: """Called when the test case test signals a failure. The traceback is suppressed because it is redundant and not useful. @@ -27,9 +64,37 @@ def addFailure( test: The test case. err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback). """ - exctype, value, tb = err + exctype, value, _ = err + + if value and value.args: + exception_msg, rich_tables = value.args[:1], value.args[1:] + value.args = exception_msg + + if rich_tables: + self.failure_tables.append(rich_tables) + + self._print_char("F") + + self.original_failures.append((test, err)) + + # Intentionally ignore the traceback to hide it from the user return super().addFailure(test, (exctype, value, None)) # type: ignore + def addError(self, test: unittest.TestCase, err: ErrorType) -> None: + """Called when the test case test signals an error. + + Args: + test: The test case. + err: A tuple of the form returned by sys.exc_info(), i.e., (type, value, traceback). + """ + exctype, value, _ = err + self.original_errors.append((test, err)) + + self._print_char("E") + + # Intentionally ignore the traceback to hide it from the user + return super().addError(test, (exctype, value, None)) # type: ignore + def addSuccess(self, test: unittest.TestCase) -> None: """Called when the test case test succeeds. 
@@ -37,4 +102,34 @@ def addSuccess(self, test: unittest.TestCase) -> None: test: The test case """ super().addSuccess(test) + + self._print_char(".") + self.successes.append(test) + + def merge(self, other: ModelTextTestResult) -> None: + if other.successes: + self.addSuccess(other.successes[0]) + elif other.errors: + for error_test, error in other.original_errors: + self.addError(error_test, error) + elif other.failures: + for failure_test, failure in other.original_failures: + self.addFailure(failure_test, failure) + + self.failure_tables.extend(other.failure_tables) + elif other.skipped: + skipped_args = other.skipped[0] + self.addSkip(skipped_args[0], skipped_args[1]) + + self.testsRun += other.testsRun + + def get_fail_and_error_tests(self) -> t.List[ModelTest]: + # If tests contain failed subtests (e.g testing CTE outputs) we don't want + # to report it as different test failures + test_name_to_test = { + test.test_name: test + for test, _ in self.failures + self.errors + if isinstance(test, ModelTest) + } + return list(test_name_to_test.values()) diff --git a/sqlmesh/core/test/runner.py b/sqlmesh/core/test/runner.py new file mode 100644 index 0000000000..284558e1c8 --- /dev/null +++ b/sqlmesh/core/test/runner.py @@ -0,0 +1,179 @@ +from __future__ import annotations + +import time +import threading +import typing as t +import unittest +from io import StringIO + +import concurrent +from concurrent.futures import ThreadPoolExecutor + +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.model import Model +from sqlmesh.core.test.definition import ModelTest as ModelTest, generate_test as generate_test +from sqlmesh.core.test.discovery import ( + ModelTestMetadata as ModelTestMetadata, +) +from sqlmesh.core.config.connection import BaseDuckDBConnectionConfig +from sqlmesh.core.test.result import ModelTextTestResult as ModelTextTestResult +from sqlmesh.utils import UniqueKeyDict, Verbosity + + +if t.TYPE_CHECKING: + from 
sqlmesh.core.config.loader import C + + +class ModelTextTestRunner(unittest.TextTestRunner): + def __init__( + self, + **kwargs: t.Any, + ) -> None: + # StringIO is used to capture the output of the tests since we'll + # run them in parallel and we don't want to mix the output streams + from io import StringIO + + super().__init__( + stream=StringIO(), + resultclass=ModelTextTestResult, + **kwargs, + ) + + +def create_testing_engine_adapters( + model_test_metadata: list[ModelTestMetadata], + config: C, + selected_gateway: str, + default_catalog: str | None = None, + default_catalog_dialect: str = "", +) -> t.Dict[ModelTestMetadata, EngineAdapter]: + testing_adapter_by_gateway: t.Dict[str, EngineAdapter] = {} + metadata_to_adapter = {} + + for metadata in model_test_metadata: + gateway = metadata.body.get("gateway") or selected_gateway + test_connection = config.get_test_connection( + gateway, default_catalog, default_catalog_dialect + ) + + concurrent_tasks = test_connection.concurrent_tasks + + is_duckdb_connection = isinstance(test_connection, BaseDuckDBConnectionConfig) + adapter = None + if is_duckdb_connection: + # Ensure DuckDB connections are fully isolated from each other + # by forcing the creation of a new adapter with SingletonConnectionPool + test_connection.concurrent_tasks = 1 + adapter = test_connection.create_engine_adapter(register_comments_override=False) + test_connection.concurrent_tasks = concurrent_tasks + elif gateway not in testing_adapter_by_gateway: + # All other engines can share connections between threads + testing_adapter_by_gateway[gateway] = test_connection.create_engine_adapter( + register_comments_override=False + ) + + metadata_to_adapter[metadata] = adapter or testing_adapter_by_gateway[gateway] + + return metadata_to_adapter + + +def run_tests( + model_test_metadata: list[ModelTestMetadata], + models: UniqueKeyDict[str, Model], + config: C, + selected_gateway: str, + dialect: str | None = None, + verbosity: Verbosity = 
Verbosity.DEFAULT, + preserve_fixtures: bool = False, + stream: t.TextIO | None = None, + default_catalog: str | None = None, + default_catalog_dialect: str = "", +) -> ModelTextTestResult: + """Create a test suite of ModelTest objects and run it. + + Args: + model_test_metadata: A list of ModelTestMetadata named tuples. + models: All models to use for expansion and mapping of physical locations. + verbosity: The verbosity level. + preserve_fixtures: Preserve the fixture tables in the testing database, useful for debugging. + """ + default_test_connection = config.get_test_connection( + gateway_name=selected_gateway, + default_catalog=default_catalog, + default_catalog_dialect=default_catalog_dialect, + ) + + lock = threading.Lock() + + from sqlmesh.core.console import get_console + + combined_results = ModelTextTestResult( + stream=unittest.runner._WritelnDecorator(stream or StringIO()), # type: ignore + verbosity=2 if verbosity >= Verbosity.VERBOSE else 1, + descriptions=True, + console=get_console(), + ) + + metadata_to_adapter = create_testing_engine_adapters( + model_test_metadata=model_test_metadata, + config=config, + selected_gateway=selected_gateway, + default_catalog=default_catalog, + default_catalog_dialect=default_catalog_dialect, + ) + + # Ensure workers are not greater than the number of tests + num_workers = min(len(model_test_metadata) or 1, default_test_connection.concurrent_tasks) + + def _run_single_test( + metadata: ModelTestMetadata, engine_adapter: EngineAdapter + ) -> t.Optional[ModelTextTestResult]: + test = ModelTest.create_test( + body=metadata.body, + test_name=metadata.test_name, + models=models, + engine_adapter=engine_adapter, + dialect=dialect, + path=metadata.path, + default_catalog=default_catalog, + preserve_fixtures=preserve_fixtures, + concurrency=num_workers > 1, + verbosity=verbosity, + ) + + if not test: + return None + + result = t.cast( + ModelTextTestResult, + ModelTextTestRunner().run(t.cast(unittest.TestCase, test)), + ) 
+ + with lock: + combined_results.merge(result) + + return result + + test_results = [] + + start_time = time.perf_counter() + try: + with ThreadPoolExecutor(max_workers=num_workers) as pool: + futures = [ + pool.submit(_run_single_test, metadata=metadata, engine_adapter=engine_adapter) + for metadata, engine_adapter in metadata_to_adapter.items() + ] + + for future in concurrent.futures.as_completed(futures): + test_results.append(future.result()) + finally: + for engine_adapter in set(metadata_to_adapter.values()): + # The engine adapters list might have duplicates, so we ensure that we close each adapter once + if engine_adapter: + engine_adapter.close() + + end_time = time.perf_counter() + + combined_results.duration = round(end_time - start_time, 2) + + return combined_results diff --git a/sqlmesh/core/user.py b/sqlmesh/core/user.py index ad6a3221c8..fabc06516f 100644 --- a/sqlmesh/core/user.py +++ b/sqlmesh/core/user.py @@ -1,15 +1,8 @@ import typing as t from enum import Enum -from sqlmesh.core.notification_target import ( - BasicSMTPNotificationTarget, - NotificationTarget, -) -from sqlmesh.utils.pydantic import ( - PydanticModel, - field_validator, - field_validator_v1_args, -) +from sqlmesh.core.notification_target import BasicSMTPNotificationTarget, NotificationTarget +from sqlmesh.utils.pydantic import PydanticModel, ValidationInfo, field_validator class UserRole(str, Enum): @@ -44,13 +37,12 @@ def is_required_approver(self) -> bool: return UserRole.REQUIRED_APPROVER in self.roles @field_validator("notification_targets") - @field_validator_v1_args def validate_notification_targets( cls, v: t.List[NotificationTarget], - values: t.Dict[str, t.Any], + info: ValidationInfo, ) -> t.List[NotificationTarget]: - email = values["email"] + email = info.data["email"] for target in v: if isinstance(target, BasicSMTPNotificationTarget) and target.recipients != {email}: raise ValueError("Recipient emails do not match user email") diff --git a/sqlmesh/dbt/adapter.py 
b/sqlmesh/dbt/adapter.py index 1820e72b25..7f7c7eb4fb 100644 --- a/sqlmesh/dbt/adapter.py +++ b/sqlmesh/dbt/adapter.py @@ -4,20 +4,22 @@ import logging import typing as t -import pandas as pd from sqlglot import exp, parse_one -from sqlmesh.core.dialect import normalize_and_quote, normalize_model_name +from sqlmesh.core.dialect import normalize_and_quote, normalize_model_name, schema_ from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.snapshot import DeployabilityIndex, Snapshot, to_table_mapping from sqlmesh.utils.errors import ConfigError, ParsetimeAdapterCallError from sqlmesh.utils.jinja import JinjaMacroRegistry +from sqlmesh.utils import AttributeDict +from sqlmesh.core.schema_diff import TableAlterOperation if t.TYPE_CHECKING: import agate from dbt.adapters.base import BaseRelation from dbt.adapters.base.column import Column from dbt.adapters.base.impl import AdapterResponse + from sqlmesh.core.engine_adapter.base import DataObject from sqlmesh.dbt.relation import Policy @@ -30,11 +32,15 @@ def __init__( jinja_macros: JinjaMacroRegistry, jinja_globals: t.Optional[t.Dict[str, t.Any]] = None, project_dialect: t.Optional[str] = None, + quote_policy: t.Optional[Policy] = None, ): + from dbt.adapters.base.relation import Policy + self.jinja_macros = jinja_macros self.jinja_globals = jinja_globals.copy() if jinja_globals else {} self.jinja_globals["adapter"] = self self.project_dialect = project_dialect + self.quote_policy = quote_policy or Policy() @abc.abstractmethod def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]: @@ -77,6 +83,16 @@ def drop_schema(self, relation: BaseRelation) -> None: def drop_relation(self, relation: BaseRelation) -> None: """Drops a relation (table) in the target database.""" + @abc.abstractmethod + def expand_target_column_types( + self, from_relation: BaseRelation, to_relation: BaseRelation + ) -> None: + """Expand to_relation's column types to match those of 
from_relation.""" + + @abc.abstractmethod + def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None: + """Renames a relation (table) in the target database.""" + @abc.abstractmethod def execute( self, sql: str, auto_begin: bool = False, fetch: bool = False @@ -95,30 +111,43 @@ def quote(self, identifier: str) -> str: """Returns a quoted identifier.""" return exp.to_column(identifier).sql(dialect=self.project_dialect, identify=True) - def dispatch(self, name: str, package: t.Optional[str] = None) -> t.Callable: + def quote_as_configured(self, value: str, component_type: str) -> str: + """Returns the value quoted according to the quote policy.""" + return self.quote(value) if getattr(self.quote_policy, component_type, False) else value + + def dispatch( + self, + macro_name: str, + macro_namespace: t.Optional[str] = None, + ) -> t.Callable: """Returns a dialect-specific version of a macro with the given name.""" target_type = self.jinja_globals["target"]["type"] - macro_suffix = f"__{name}" + macro_suffix = f"__{macro_name}" def _relevance(package_name_pair: t.Tuple[t.Optional[str], str]) -> t.Tuple[int, int]: """Lower scores more relevant.""" - macro_package, macro_name = package_name_pair + macro_package, name = package_name_pair - package_score = 0 if macro_package == package else 1 + package_score = 0 if macro_package == macro_namespace else 1 name_score = 1 - if macro_name.startswith("default"): + if name.startswith("default"): name_score = 2 - elif macro_name.startswith(target_type): + elif name.startswith(target_type): name_score = 0 return name_score, package_score jinja_env = self.jinja_macros.build_environment(**self.jinja_globals).globals - packages_to_check: t.List[t.Optional[str]] = [ - package, - *(k for k in jinja_env if k.startswith("dbt")), - ] + + packages_to_check: t.List[t.Optional[str]] = [None] + if macro_namespace is not None: + if macro_namespace in jinja_env: + packages_to_check = 
[self.jinja_macros.root_package_name, macro_namespace] + + # Add dbt packages as fallback + packages_to_check.extend(k for k in jinja_env if k.startswith("dbt")) + candidates = {} for macro_package in packages_to_check: macros = jinja_env.get(macro_package, {}) if macro_package else jinja_env @@ -136,7 +165,7 @@ def _relevance(package_name_pair: t.Tuple[t.Optional[str], str]) -> t.Tuple[int, sorted_candidates = sorted(candidates, key=_relevance) return candidates[sorted_candidates[0]] - raise ConfigError(f"Macro '{name}', package '{package}' was not found.") + raise ConfigError(f"Macro '{macro_name}', package '{macro_namespace}' was not found.") def type(self) -> str: return self.project_dialect or "" @@ -146,6 +175,21 @@ def compare_dbr_version(self, major: int, minor: int) -> int: # Always return -1 to fallback to Spark macro implementations. return -1 + @property + def graph(self) -> t.Any: + flat_graph = self.jinja_globals.get("flat_graph", None) + return flat_graph or AttributeDict( + { + "exposures": {}, + "groups": {}, + "metrics": {}, + "nodes": {}, + "sources": {}, + "semantic_models": {}, + "saved_queries": {}, + } + ) + class ParsetimeAdapter(BaseAdapter): def get_relation(self, database: str, schema: str, identifier: str) -> t.Optional[BaseRelation]: @@ -183,6 +227,14 @@ def drop_schema(self, relation: BaseRelation) -> None: def drop_relation(self, relation: BaseRelation) -> None: self._raise_parsetime_adapter_call_error("drop relation") + def expand_target_column_types( + self, from_relation: BaseRelation, to_relation: BaseRelation + ) -> None: + self._raise_parsetime_adapter_call_error("expand target column types") + + def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None: + self._raise_parsetime_adapter_call_error("rename relation") + def execute( self, sql: str, auto_begin: bool = False, fetch: bool = False ) -> t.Tuple[AdapterResponse, agate.Table]: @@ -216,12 +268,12 @@ def __init__( ): from dbt.adapters.base 
import BaseRelation from dbt.adapters.base.column import Column - from dbt.adapters.base.relation import Policy super().__init__( jinja_macros, jinja_globals=jinja_globals, project_dialect=project_dialect or engine_adapter.dialect, + quote_policy=quote_policy, ) table_mapping = table_mapping or {} @@ -229,7 +281,6 @@ def __init__( self.engine_adapter = engine_adapter self.relation_type = relation_type or BaseRelation self.column_type = column_type or Column - self.quote_policy = quote_policy or Policy() self.table_mapping = { **to_table_mapping((snapshots or {}).values(), deployability_index), **table_mapping, @@ -238,56 +289,59 @@ def __init__( def get_relation( self, database: t.Optional[str], schema: str, identifier: str ) -> t.Optional[BaseRelation]: - return self.load_relation( - self.relation_type.create( - database=database, - schema=schema, - identifier=identifier, - quote_policy=self.quote_policy, - ) - ) + target_table = exp.table_(identifier, db=schema, catalog=database) + # Normalize before converting to a relation; otherwise, it will be too late, + # as quotes will have already been applied. 
+ target_table = self._normalize(target_table) + return self.load_relation(self._table_to_relation(target_table)) def load_relation(self, relation: BaseRelation) -> t.Optional[BaseRelation]: mapped_table = self._map_table_name(self._normalize(self._relation_to_table(relation))) - if not self.engine_adapter.table_exists(mapped_table): - return None - return self._table_to_relation(mapped_table) + data_object = self.engine_adapter.get_data_object(mapped_table) + return self._data_object_to_relation(data_object) if data_object is not None else None def list_relations(self, database: t.Optional[str], schema: str) -> t.List[BaseRelation]: - reference_relation = self.relation_type.create( - database=database, schema=schema, quote_policy=self.quote_policy - ) - return self.list_relations_without_caching(reference_relation) + target_schema = schema_(schema, catalog=database) + # Normalize before converting to a relation; otherwise, it will be too late, + # as quotes will have already been applied. 
+ target_schema = self._normalize(target_schema) + return self.list_relations_without_caching(self._table_to_relation(target_schema)) def list_relations_without_caching(self, schema_relation: BaseRelation) -> t.List[BaseRelation]: - from sqlmesh.dbt.relation import RelationType - schema = self._normalize(self._schema(schema_relation)) relations = [ - self.relation_type.create( - database=do.catalog, - schema=do.schema_name, - identifier=do.name, - quote_policy=self.quote_policy, - # DBT relation types aren't snake case and instead just one word without spaces so we remove underscores - type=( - RelationType.External - if do.type.is_unknown - else RelationType(do.type.lower().replace("_", "")) - ), - ) - for do in self.engine_adapter.get_data_objects(schema) + self._data_object_to_relation(do) for do in self.engine_adapter.get_data_objects(schema) ] return relations def get_columns_in_relation(self, relation: BaseRelation) -> t.List[Column]: - from dbt.adapters.base.column import Column - mapped_table = self._map_table_name(self._normalize(self._relation_to_table(relation))) + + if self.project_dialect == "bigquery": + # dbt.adapters.bigquery.column.BigQueryColumn has a different constructor signature + # We need to use BigQueryColumn.create_from_field() to create the column instead + if ( + hasattr(self.column_type, "create_from_field") + and callable(getattr(self.column_type, "create_from_field")) + and hasattr(self.engine_adapter, "get_bq_schema") + and callable(getattr(self.engine_adapter, "get_bq_schema")) + ): + return [ + self.column_type.create_from_field(field) # type: ignore + for field in self.engine_adapter.get_bq_schema(mapped_table) # type: ignore + ] + from dbt.adapters.base.column import Column + + return [ + Column.from_description( + name=name, raw_data_type=dtype.sql(dialect=self.project_dialect) + ) + for name, dtype in self.engine_adapter.columns(table_name=mapped_table).items() + ] return [ - Column.from_description( + 
self.column_type.from_description( name=name, raw_data_type=dtype.sql(dialect=self.project_dialect) ) for name, dtype in self.engine_adapter.columns(table_name=mapped_table).items() @@ -316,9 +370,49 @@ def drop_relation(self, relation: BaseRelation) -> None: if relation.schema is not None and relation.identifier is not None: self.engine_adapter.drop_table(self._normalize(self._relation_to_table(relation))) + def expand_target_column_types( + self, from_relation: BaseRelation, to_relation: BaseRelation + ) -> None: + from_dbt_columns = {c.name: c for c in self.get_columns_in_relation(from_relation)} + to_dbt_columns = {c.name: c for c in self.get_columns_in_relation(to_relation)} + + from_table_name = self._normalize(self._relation_to_table(from_relation)) + to_table_name = self._normalize(self._relation_to_table(to_relation)) + + from_columns = self.engine_adapter.columns(from_table_name) + to_columns = self.engine_adapter.columns(to_table_name) + + current_columns = {} + new_columns = {} + for column_name, from_column in from_dbt_columns.items(): + target_column = to_dbt_columns.get(column_name) + if target_column is not None and target_column.can_expand_to(from_column): + current_columns[column_name] = to_columns[column_name] + new_columns[column_name] = from_columns[column_name] + + alter_expressions = t.cast( + t.List[TableAlterOperation], + self.engine_adapter.schema_differ.compare_columns( + to_table_name, + current_columns, + new_columns, + ignore_destructive=True, + ), + ) + + if alter_expressions: + self.engine_adapter.alter_table(alter_expressions) + + def rename_relation(self, from_relation: BaseRelation, to_relation: BaseRelation) -> None: + old_table_name = self._normalize(self._relation_to_table(from_relation)) + new_table_name = self._normalize(self._relation_to_table(to_relation)) + + self.engine_adapter.rename_table(old_table_name, new_table_name) + def execute( self, sql: str, auto_begin: bool = False, fetch: bool = False ) -> 
t.Tuple[AdapterResponse, agate.Table]: + import pandas as pd from dbt.adapters.base.impl import AdapterResponse from sqlmesh.dbt.util import pandas_to_agate, empty_table @@ -372,6 +466,24 @@ def _map_table_name(self, table: exp.Table) -> exp.Table: def _relation_to_table(self, relation: BaseRelation) -> exp.Table: return exp.to_table(relation.render(), dialect=self.project_dialect) + def _data_object_to_relation(self, data_object: DataObject) -> BaseRelation: + from sqlmesh.dbt.relation import RelationType + + if data_object.type.is_unknown: + dbt_relation_type = RelationType.External + elif data_object.type.is_managed_table: + dbt_relation_type = RelationType.Table + else: + dbt_relation_type = RelationType(data_object.type.lower()) + + return self.relation_type.create( + database=data_object.catalog, + schema=data_object.schema_name, + identifier=data_object.name, + quote_policy=self.quote_policy, + type=dbt_relation_type, + ) + def _table_to_relation(self, table: exp.Table) -> BaseRelation: return self.relation_type.create( database=table.catalog or None, diff --git a/sqlmesh/dbt/basemodel.py b/sqlmesh/dbt/basemodel.py index 40a07ea954..32a76aba13 100644 --- a/sqlmesh/dbt/basemodel.py +++ b/sqlmesh/dbt/basemodel.py @@ -4,13 +4,17 @@ from abc import abstractmethod from enum import Enum from pathlib import Path +import logging from pydantic import Field from sqlglot.helper import ensure_list from sqlmesh.core import dialect as d from sqlmesh.core.config.base import UpdateStrategy +from sqlmesh.core.config.common import VirtualEnvironmentMode from sqlmesh.core.model import Model +from sqlmesh.core.model.common import ParsableSql +from sqlmesh.core.node import DbtNodeInfo from sqlmesh.dbt.column import ( ColumnConfig, column_descriptions_to_sqlmesh, @@ -20,23 +24,28 @@ DbtConfig, Dependencies, GeneralConfig, + RAW_CODE_KEY, SqlStr, sql_str_validator, ) from sqlmesh.dbt.relation import Policy, RelationType from sqlmesh.dbt.test import TestConfig +from 
sqlmesh.dbt.util import DBT_VERSION from sqlmesh.utils import AttributeDict -from sqlmesh.utils.conversions import ensure_bool from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.pydantic import field_validator if t.TYPE_CHECKING: + from sqlmesh.core.audit.definition import ModelAudit from sqlmesh.dbt.context import DbtContext BMC = t.TypeVar("BMC", bound="BaseModelConfig") +logger = logging.getLogger(__name__) + + class Materialization(str, Enum): """DBT model materializations""" @@ -46,6 +55,15 @@ class Materialization(str, Enum): EPHEMERAL = "ephemeral" SNAPSHOT = "snapshot" + # Snowflake, https://docs.getdbt.com/reference/resource-configs/snowflake-configs#dynamic-tables + DYNAMIC_TABLE = "dynamic_table" + + CUSTOM = "custom" + + @classmethod + def _missing_(cls, value): # type: ignore + return cls.CUSTOM + class SnapshotStrategy(str, Enum): """DBT snapshot strategies""" @@ -70,7 +88,7 @@ class Hook(DbtConfig): """ sql: SqlStr - transaction: bool = True # TODO not yet supported + transaction: bool = True _sql_validator = sql_str_validator @@ -101,15 +119,19 @@ class BaseModelConfig(GeneralConfig): # sqlmesh fields owner: t.Optional[str] = None stamp: t.Optional[str] = None + table_format: t.Optional[str] = None storage_format: t.Optional[str] = None path: Path = Path() dependencies: Dependencies = Dependencies() tests: t.List[TestConfig] = [] dialect_: t.Optional[str] = Field(None, alias="dialect") + grain: t.Union[str, t.List[str]] = [] # DBT configuration fields + unique_id: str = "" name: str = "" package_name: str = "" + fqn_: t.List[str] = Field(default_factory=list, alias="fqn") schema_: str = Field("", alias="schema") database: t.Optional[str] = None alias: t.Optional[str] = None @@ -119,6 +141,7 @@ class BaseModelConfig(GeneralConfig): grants: t.Dict[str, t.List[str]] = {} columns: t.Dict[str, ColumnConfig] = {} quoting: t.Dict[str, t.Optional[bool]] = {} + event_time: t.Optional[str] = None version: t.Optional[int] = None latest_version: 
t.Optional[int] = None @@ -141,14 +164,13 @@ def _validate_hooks(cls, v: t.Union[str, t.List[t.Union[SqlStr, str]]]) -> t.Lis return hooks - @field_validator("full_refresh", mode="before") - @classmethod - def _validate_bool(cls, v: str) -> bool: - return ensure_bool(v) - @field_validator("grants", mode="before") @classmethod - def _validate_grants(cls, v: t.Dict[str, str]) -> t.Dict[str, t.List[str]]: + def _validate_grants( + cls, v: t.Optional[t.Dict[str, str]] + ) -> t.Optional[t.Dict[str, t.List[str]]]: + if v is None: + return None return {key: ensure_list(value) for key, value in v.items()} _FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = { @@ -162,14 +184,6 @@ def _validate_grants(cls, v: t.Dict[str, str]) -> t.Dict[str, t.List[str]]: }, } - @property - def sql_no_config(self) -> SqlStr: - return SqlStr("") - - @property - def sql_embedded_config(self) -> SqlStr: - return SqlStr("") - @property def table_schema(self) -> str: """ @@ -187,9 +201,9 @@ def table_name(self) -> str: @property def config_name(self) -> str: """ - Get the model's config name (package_name.alias) + Get the model's config name (package_name.name) """ - return f"{self.package_name}.{self.alias}" + return f"{self.package_name}.{self.name}" def dialect(self, context: DbtContext) -> str: return self.dialect_ or context.default_dialect @@ -224,6 +238,12 @@ def relation_info(self) -> AttributeDict[str, t.Any]: else: relation_type = RelationType.Table + extras = {} + if DBT_VERSION >= (1, 9, 0) and self.event_time: + extras["event_time_filter"] = { + "field_name": self.event_time, + } + return AttributeDict( { "database": self.database, @@ -231,12 +251,10 @@ def relation_info(self) -> AttributeDict[str, t.Any]: "identifier": self.table_name, "type": relation_type.value, "quote_policy": AttributeDict(self.quoting), + **extras, } ) - def model_function(self) -> AttributeDict[str, t.Any]: - return AttributeDict({"config": self.config_attribute_dict}) - @property def 
tests_ref_source_dependencies(self) -> Dependencies: dependencies = Dependencies() @@ -247,12 +265,9 @@ def tests_ref_source_dependencies(self) -> Dependencies: dependencies.macros = [] return dependencies - def check_for_circular_test_refs(self, context: DbtContext) -> None: + def remove_tests_with_invalid_refs(self, context: DbtContext) -> None: """ - Checks for direct circular references between two models and raises an exception if found. - This addresses the most common circular reference seen when importing a dbt project - - relationship tests in both directions. In the future, we may want to increase coverage by - checking for indirect circular references. + Removes tests that reference models or sources that do not exist in the context in order to match dbt behavior. Args: context: The dbt context this model resides within. @@ -260,64 +275,138 @@ def check_for_circular_test_refs(self, context: DbtContext) -> None: Returns: None """ - for test in self.tests: - for ref in test.dependencies.refs: - model = context.refs[ref] - if ref == self.name or ref in self.dependencies.refs: - continue - elif self.name in model.dependencies.refs: - raise ConfigError( - f"Test '{test.name}' for model '{self.name}' depends on downstream model '{model.name}'." - " Move the test to the downstream model to avoid circular references." - ) - elif self.name in model.tests_ref_source_dependencies.refs: - circular_test = next( - test.name for test in model.tests if ref in test.dependencies.refs - ) - raise ConfigError( - f"Circular reference detected between tests for models '{self.name}' and '{model.name}':" - f" '{test.name}' ({self.name}), '{circular_test}' ({model.name})." 
- ) + self.tests = [ + test + for test in self.tests + if all(ref in context.refs for ref in test.dependencies.refs) + and all(source in context.sources for source in test.dependencies.sources) + ] + + @property + def fqn(self) -> str: + return ".".join(self.fqn_) @property def sqlmesh_config_fields(self) -> t.Set[str]: return {"description", "owner", "stamp", "storage_format"} - def sqlmesh_model_kwargs(self, context: DbtContext) -> t.Dict[str, t.Any]: + @property + def node_info(self) -> DbtNodeInfo: + return DbtNodeInfo(unique_id=self.unique_id, name=self.name, fqn=self.fqn, alias=self.alias) + + def sqlmesh_model_kwargs( + self, + context: DbtContext, + column_types_override: t.Optional[t.Dict[str, ColumnConfig]] = None, + ) -> t.Dict[str, t.Any]: """Get common sqlmesh model parameters""" - self.check_for_circular_test_refs(context) + self.remove_tests_with_invalid_refs(context) + + dependencies = self.dependencies.copy() + if dependencies.has_dynamic_var_names: + # Include ALL variables as dependencies since we couldn't determine + # precisely which variables are referenced in the model + dependencies.variables |= set(context.variables) + + if ( + getattr(self, "model_materialization", None) == Materialization.CUSTOM + and hasattr(self, "_get_custom_materialization") + and (custom_mat := self._get_custom_materialization(context)) + ): + # include custom materialization dependencies as they might use macros + dependencies = dependencies.union(custom_mat.dependencies) + + model_dialect = self.dialect(context) + + # Only keep refs and sources that exist in the context to match dbt behavior + dependencies.refs.intersection_update(context.refs) + dependencies.sources.intersection_update(context.sources) model_context = context.context_for_dependencies( - self.dependencies.union(self.tests_ref_source_dependencies) + dependencies.union(self.tests_ref_source_dependencies) ) jinja_macros = model_context.jinja_macros.trim( - self.dependencies.macros, 
package=self.package_name - ) - jinja_macros.add_globals( - { - "this": self.relation_info, - "model": self.model_function(), - "schema": self.table_schema, - "config": self.config_attribute_dict, - **model_context.jinja_globals, # type: ignore - } + dependencies.macros, package=self.package_name ) - return { - "audits": [(test.name, {}) for test in self.tests], - "columns": column_types_to_sqlmesh(self.columns, self.dialect(context)) or None, + jinja_macros.add_globals(self._model_jinja_context(model_context, dependencies)) + + model_kwargs = { + "audits": [(test.canonical_name, {}) for test in self.tests], "column_descriptions": column_descriptions_to_sqlmesh(self.columns) or None, "depends_on": { model.canonical_name(context) for model in model_context.refs.values() - }.union({source.canonical_name(context) for source in model_context.sources.values()}), + }.union( + { + source.canonical_name(context) + for source in model_context.sources.values() + if source.fqn not in context.model_fqns + # Allow dbt projects to reference a model as a source without causing a cycle + }, + ), "jinja_macros": jinja_macros, "path": self.path, - "pre_statements": [d.jinja_statement(hook.sql) for hook in self.pre_hook], - "post_statements": [d.jinja_statement(hook.sql) for hook in self.post_hook], + "pre_statements": [ + ParsableSql(sql=d.jinja_statement(hook.sql).sql(), transaction=hook.transaction) + for hook in self.pre_hook + ], + "post_statements": [ + ParsableSql(sql=d.jinja_statement(hook.sql).sql(), transaction=hook.transaction) + for hook in self.post_hook + ], "tags": self.tags, - "physical_schema_override": context.sqlmesh_config.physical_schema_override, + "physical_schema_mapping": context.sqlmesh_config.physical_schema_mapping, "default_catalog": context.target.database, + "grain": [d.parse_one(g, dialect=model_dialect) for g in ensure_list(self.grain)], **self.sqlmesh_config_kwargs, } + # dbt doesn't respect the data_type field for DDL statements– instead, it 
optionally uses + # it to validate the actual data types at runtime through contracts or external plugins. + # Only the `columns_types` config of seed models is actually respected. We don't set the + # columns attribute to self.columns intentionally in all other cases, as that could result + # in unfaithful types when models are materialized. + # + # See: + # - https://docs.getdbt.com/reference/resource-properties/columns + # - https://docs.getdbt.com/reference/resource-configs/contract + # - https://docs.getdbt.com/reference/resource-configs/column_types + if column_types_override: + model_kwargs["columns"] = ( + column_types_to_sqlmesh(column_types_override, self.dialect(context)) or None + ) + + return model_kwargs + @abstractmethod - def to_sqlmesh(self, context: DbtContext) -> Model: + def to_sqlmesh( + self, + context: DbtContext, + audit_definitions: t.Optional[t.Dict[str, ModelAudit]] = None, + virtual_environment_mode: VirtualEnvironmentMode = VirtualEnvironmentMode.default, + ) -> Model: """Convert DBT model into sqlmesh Model""" + + def _model_jinja_context( + self, context: DbtContext, dependencies: Dependencies + ) -> t.Dict[str, t.Any]: + if context._manifest and self.unique_id in context._manifest._manifest.nodes: + attributes = context._manifest._manifest.nodes[self.unique_id].to_dict() + if dependencies.model_attrs.all_attrs: + model_node: AttributeDict[str, t.Any] = AttributeDict(attributes) + else: + model_node = AttributeDict( + filter(lambda kv: kv[0] in dependencies.model_attrs.attrs, attributes.items()) + ) + + # We exclude the raw SQL code to reduce the payload size. It's still accessible through + # the JinjaQuery instance stored in the resulting SQLMesh model's `query` field. 
+ model_node.pop(RAW_CODE_KEY, None) + else: + model_node = AttributeDict({}) + + return { + "this": self.relation_info, + "model": model_node, + "schema": self.table_schema, + "config": self.config_attribute_dict, + **context.jinja_globals, + } diff --git a/sqlmesh/dbt/builtin.py b/sqlmesh/dbt/builtin.py index 00cd897a1e..fa05e3d7f9 100644 --- a/sqlmesh/dbt/builtin.py +++ b/sqlmesh/dbt/builtin.py @@ -12,15 +12,20 @@ from dbt import version from dbt.adapters.base import BaseRelation, Column from ruamel.yaml import YAMLError +from sqlglot import Dialect +from sqlmesh.core.console import get_console from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.model.definition import SqlModel +from sqlmesh.core.snapshot.definition import DeployabilityIndex from sqlmesh.dbt.adapter import BaseAdapter, ParsetimeAdapter, RuntimeAdapter +from sqlmesh.dbt.common import RAW_CODE_KEY from sqlmesh.dbt.relation import Policy from sqlmesh.dbt.target import TARGET_TYPE_TO_CONFIG_CLASS from sqlmesh.dbt.util import DBT_VERSION -from sqlmesh.utils import AttributeDict, yaml +from sqlmesh.utils import AttributeDict, debug_mode_enabled, yaml from sqlmesh.utils.date import now -from sqlmesh.utils.errors import ConfigError, MacroEvalError +from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroReference, MacroReturnVal logger = logging.getLogger(__name__) @@ -28,7 +33,7 @@ class Exceptions: def raise_compiler_error(self, msg: str) -> None: - if DBT_VERSION >= (1, 4): + if DBT_VERSION >= (1, 4, 0): from dbt.exceptions import CompilationError raise CompilationError(msg) @@ -45,10 +50,28 @@ def warn(self, msg: str) -> str: return "" +def try_or_compiler_error( + message_if_exception: str, func: t.Callable, *args: t.Any, **kwargs: t.Any +) -> t.Any: + try: + return func(*args, **kwargs) + except Exception: + if DBT_VERSION >= (1, 4, 0): + from dbt.exceptions import CompilationError + + raise CompilationError(message_if_exception) + 
else: + from dbt.exceptions import CompilationException # type: ignore + + raise CompilationException(message_if_exception) + + class Api: def __init__(self, dialect: t.Optional[str]) -> None: if dialect: - config_class = TARGET_TYPE_TO_CONFIG_CLASS[dialect] + config_class = TARGET_TYPE_TO_CONFIG_CLASS[ + Dialect.get_or_raise(dialect).__class__.__name__.lower() + ] self.Relation = config_class.relation_class self.Column = config_class.column_class self.quote_policy = config_class.quote_policy @@ -159,6 +182,74 @@ def has_var(self, name: str) -> bool: return name in self.variables +class Config: + def __init__(self, config_dict: t.Dict[str, t.Any]) -> None: + self._config = config_dict + + def __call__(self, *args: t.Any, **kwargs: t.Any) -> str: + if args and kwargs: + raise ConfigError( + "Invalid inline model config: cannot mix positional and keyword arguments" + ) + + if args: + if len(args) == 1 and isinstance(args[0], dict): + # Single dict argument: config({"materialized": "table"}) + self._config.update(args[0]) + else: + raise ConfigError( + f"Invalid inline model config: expected a single dictionary, got {len(args)} arguments" + ) + elif kwargs: + # Keyword arguments: config(materialized="table") + self._config.update(kwargs) + + return "" + + def set(self, name: str, value: t.Any) -> str: + self._config.update({name: value}) + return "" + + def _validate(self, name: str, validator: t.Callable, value: t.Optional[t.Any] = None) -> None: + try: + validator(value) + except Exception as e: + raise ConfigError(f"Config validation failed for '{name}': {e}") + + def require(self, name: str, validator: t.Optional[t.Callable] = None) -> t.Any: + if name not in self._config: + raise ConfigError(f"Missing required config: {name}") + + value = self._config[name] + + if validator is not None: + self._validate(name, validator, value) + + return value + + def get( + self, name: str, default: t.Any = None, validator: t.Optional[t.Callable] = None + ) -> t.Any: + value = 
self._config.get(name, default) + + if validator is not None and value is not None: + self._validate(name, validator, value) + + return value + + def persist_relation_docs(self) -> bool: + persist_docs = self.get("persist_docs", default={}) + if not isinstance(persist_docs, dict): + return False + return persist_docs.get("relation", False) + + def persist_column_docs(self) -> bool: + persist_docs = self.get("persist_docs", default={}) + if not isinstance(persist_docs, dict): + return False + return persist_docs.get("columns", False) + + def env_var(name: str, default: t.Optional[str] = None) -> t.Optional[str]: if name not in os.environ and default is None: raise ConfigError(f"Missing environment variable '{name}'") @@ -166,7 +257,13 @@ def env_var(name: str, default: t.Optional[str] = None) -> t.Optional[str]: def log(msg: str, info: bool = False) -> str: - logger.debug(msg) + if info: + # Write to both log file and stdout + logger.info(msg) + get_console().log_status_update(msg) + else: + logger.debug(msg) + return "" @@ -212,6 +309,27 @@ def source(package: str, name: str) -> t.Optional[BaseRelation]: logger.debug("Could not resolve source package='%s' name='%s'", package, name) return None + # Clickhouse uses a 2-level schema.table naming scheme, where the second level is called + # a "database" (instead of "schema" as one would reasonably assume). This can lead to confusion + # because it is not clear how Clickhouse identifiers map onto dbt's "database" and "schema" fields. + # + # This confusion can occur in source resolution. If a source's `schema` is not explicitly specified, + # the source name is used as the schema by default. + # + # If a source specified the `database` field and the schema has defaulted to the source name, + # we follow dbt-clickhouse in assuming that the user intended for the `database` field to be the + # second level identifier. 
+ # https://github.com/ClickHouse/dbt-clickhouse/blob/065f3a724fa09205446ecadac7a00d92b2d8c646/dbt/adapters/clickhouse/relation.py#L112 + # + # NOTE: determining relation class based on name so we don't introduce a dependency on dbt-clickhouse + if ( + api.Relation.__name__ == "ClickHouseRelation" + and relation_info.schema == package + and relation_info.database + ): + relation_info["schema"] = relation_info["database"] + relation_info["database"] = "" + return _relation_info_to_relation(relation_info, api.Relation, api.quote_policy) return source @@ -263,18 +381,16 @@ def do_zip(*args: t.Any, default: t.Optional[t.Any] = None) -> t.Optional[t.Any] return default -def as_bool(value: str) -> bool: - result = _try_literal_eval(value) - if isinstance(result, bool): - return result - raise MacroEvalError(f"Failed to convert '{value}' into boolean.") +def as_bool(value: t.Any) -> t.Any: + # dbt's jinja TEXT_FILTERS just return the input value as is + # https://github.com/dbt-labs/dbt-common/blob/main/dbt_common/clients/jinja.py#L559 + return value def as_number(value: str) -> t.Any: - result = _try_literal_eval(value) - if isinstance(value, (int, float)) and not isinstance(result, bool): - return result - raise MacroEvalError(f"Failed to convert '{value}' into number.") + # dbt's jinja TEXT_FILTERS just return the input value as is + # https://github.com/dbt-labs/dbt-common/blob/main/dbt_common/clients/jinja.py#L559 + return value def _try_literal_eval(value: str) -> t.Any: @@ -284,6 +400,15 @@ def _try_literal_eval(value: str) -> t.Any: return value +def debug() -> str: + import sys + import ipdb # type: ignore + + frame = sys._getframe(3) + ipdb.set_trace(frame) + return "" + + BUILTIN_GLOBALS = { "dbt_version": version.__version__, "env_var": env_var, @@ -300,10 +425,15 @@ def _try_literal_eval(value: str) -> t.Any: "sqlmesh_incremental": True, "tojson": to_json, "toyaml": to_yaml, + "try_or_compiler_error": try_or_compiler_error, "zip": do_zip, "zip_strict": lambda 
*args: list(zip(*args)), } +# Add debug function conditionally both with dbt or sqlmesh equivalent flag +if os.environ.get("DBT_MACRO_DEBUGGING") or debug_mode_enabled(): + BUILTIN_GLOBALS["debug"] = debug + BUILTIN_FILTERS = { "as_bool": as_bool, "as_native": _try_literal_eval, @@ -326,9 +456,7 @@ def create_builtin_globals( jinja_globals = jinja_globals.copy() target: t.Optional[AttributeDict] = jinja_globals.get("target", None) - project_dialect = jinja_globals.pop("dialect", None) or ( - target.get("dialect") if target else None - ) + project_dialect = jinja_globals.pop("dialect", None) or (target.get("type") if target else None) api = Api(project_dialect) builtin_globals["api"] = api @@ -352,16 +480,51 @@ def create_builtin_globals( if variables is not None: builtin_globals["var"] = Var(variables) + builtin_globals["config"] = Config(jinja_globals.pop("config", {"tags": []})) + + deployability_index = ( + jinja_globals.get("deployability_index") or DeployabilityIndex.all_deployable() + ) snapshot = jinja_globals.pop("snapshot", None) - is_incremental = bool(snapshot.intervals) if snapshot else False + + if snapshot and snapshot.is_incremental: + intervals = ( + snapshot.intervals + if deployability_index.is_deployable(snapshot) + else snapshot.dev_intervals + ) + is_incremental = bool(intervals) + + snapshot_table_exists = jinja_globals.get("snapshot_table_exists") + if is_incremental and snapshot_table_exists is not None: + # If we know the information about table existence, we can use it to correctly + # set the flag + is_incremental &= snapshot_table_exists + else: + is_incremental = False + builtin_globals["is_incremental"] = lambda: is_incremental builtin_globals["builtins"] = AttributeDict( {k: builtin_globals.get(k) for k in ("ref", "source", "config", "var")} ) + if (model := jinja_globals.pop("model", None)) is not None: + if isinstance(model_instance := jinja_globals.pop("model_instance", None), SqlModel): + builtin_globals["model"] = AttributeDict( 
+ {**model, RAW_CODE_KEY: model_instance.query.name} + ) + else: + builtin_globals["model"] = AttributeDict(model.copy()) + + builtin_globals["flags"] = ( + Flags(which="run") if engine_adapter is not None else Flags(which="parse") + ) + builtin_globals["invocation_args_dict"] = { + k.lower(): v for k, v in builtin_globals["flags"].__dict__.items() + } + if engine_adapter is not None: - builtin_globals["flags"] = Flags(which="run") adapter: BaseAdapter = RuntimeAdapter( engine_adapter, jinja_macros, @@ -375,15 +538,15 @@ def create_builtin_globals( quote_policy=api.quote_policy, snapshots=jinja_globals.get("snapshots", {}), table_mapping=jinja_globals.get("table_mapping", {}), - deployability_index=jinja_globals.get("deployability_index"), + deployability_index=deployability_index, project_dialect=project_dialect, ) else: - builtin_globals["flags"] = Flags(which="parse") adapter = ParsetimeAdapter( jinja_macros, jinja_globals={**builtin_globals, **jinja_globals}, project_dialect=project_dialect, + quote_policy=api.quote_policy, ) sql_execution = SQLExecution(adapter) @@ -396,11 +559,15 @@ def create_builtin_globals( "load_result": sql_execution.load_result, "run_query": sql_execution.run_query, "statement": sql_execution.statement, + "graph": adapter.graph, + "selected_resources": list(jinja_globals.get("selected_models") or []), + "write": lambda input: None, # We don't support writing yet } ) builtin_globals["run_started_at"] = jinja_globals.get("execution_dt") or now() builtin_globals["dbt"] = AttributeDict(builtin_globals) + builtin_globals["context"] = builtin_globals["dbt"] return {**builtin_globals, **jinja_globals} diff --git a/sqlmesh/dbt/column.py b/sqlmesh/dbt/column.py index 327f7cd539..80a6ad9325 100644 --- a/sqlmesh/dbt/column.py +++ b/sqlmesh/dbt/column.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing as t +import logging from sqlglot import exp, parse_one from sqlglot.helper import ensure_list @@ -9,6 +10,8 @@ from 
sqlmesh.utils.conversions import ensure_bool from sqlmesh.utils.pydantic import field_validator +logger = logging.getLogger(__name__) + def yaml_to_columns( yaml: t.Dict[str, ColumnConfig] | t.List[t.Dict[str, ColumnConfig]], @@ -31,11 +34,20 @@ def column_types_to_sqlmesh( Returns: A dict of column name to exp.DataType """ - return { - name: parse_one(column.data_type, into=exp.DataType, dialect=dialect or "") - for name, column in columns.items() - if column.enabled and column.data_type - } + col_types_to_sqlmesh: t.Dict[str, exp.DataType] = {} + for name, column in columns.items(): + if column.enabled and column.data_type: + column_def = parse_one( + f"{name} {column.data_type}", into=exp.ColumnDef, dialect=dialect or "" + ) + if column_def.args.get("constraints"): + logger.warning( + f"Ignoring unsupported constraints for column '{name}' with definition '{column.data_type}'. Please refer to github.com/SQLMesh/sqlmesh/issues/4717 for more information." + ) + kind = column_def.kind + if kind: + col_types_to_sqlmesh[name] = kind + return col_types_to_sqlmesh def column_descriptions_to_sqlmesh(columns: t.Dict[str, ColumnConfig]) -> t.Dict[str, str]: diff --git a/sqlmesh/dbt/common.py b/sqlmesh/dbt/common.py index 681e1cc8a7..67e1a788cf 100644 --- a/sqlmesh/dbt/common.py +++ b/sqlmesh/dbt/common.py @@ -2,12 +2,15 @@ import re import typing as t +from dataclasses import dataclass from pathlib import Path from ruamel.yaml.constructor import DuplicateKeyError from sqlglot.helper import ensure_list +from sqlmesh.dbt.util import DBT_VERSION from sqlmesh.core.config.base import BaseConfig, UpdateStrategy +from sqlmesh.core.config.common import DBT_PROJECT_FILENAME from sqlmesh.utils import AttributeDict from sqlmesh.utils.conversions import ensure_bool, try_str_to_bool from sqlmesh.utils.errors import ConfigError @@ -17,7 +20,8 @@ T = t.TypeVar("T", bound="GeneralConfig") -PROJECT_FILENAME = "dbt_project.yml" +PROJECT_FILENAME = DBT_PROJECT_FILENAME +RAW_CODE_KEY = 
"raw_code" if DBT_VERSION >= (1, 3, 0) else "raw_sql" # type: ignore JINJA_ONLY = { "adapter", @@ -35,12 +39,16 @@ def load_yaml(source: str | Path) -> t.Dict: try: - return load(source, render_jinja=False) + return load( + source, render_jinja=False, allow_duplicate_keys=True, keep_last_duplicate_key=True + ) except DuplicateKeyError as ex: raise ConfigError(f"{source}: {ex}" if isinstance(source, Path) else f"{ex}") -def parse_meta(v: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: +def parse_meta(v: t.Optional[t.Dict[str, t.Any]]) -> t.Dict[str, t.Any]: + if v is None: + return {} for key, value in v.items(): if isinstance(value, str): v[key] = try_str_to_bool(value) @@ -109,7 +117,7 @@ def _validate_list(cls, v: t.Union[str, t.List[str]]) -> t.List[str]: @field_validator("meta", mode="before") @classmethod - def _validate_meta(cls, v: t.Dict[str, t.Union[str, t.Any]]) -> t.Dict[str, t.Any]: + def _validate_meta(cls, v: t.Optional[t.Dict[str, t.Union[str, t.Any]]]) -> t.Dict[str, t.Any]: return parse_meta(v) _FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = { @@ -129,6 +137,10 @@ def _validate_meta(cls, v: t.Dict[str, t.Union[str, t.Any]]) -> t.Dict[str, t.An def config_attribute_dict(self) -> AttributeDict[str, t.Any]: return AttributeDict(self.dict(exclude=EXCLUDED_CONFIG_ATTRIBUTE_KEYS)) + def _get_field_value(self, field: str) -> t.Optional[t.Any]: + field_val = getattr(self, field, None) + return field_val if field_val is not None else self.meta.get(field, None) + def replace(self, other: T) -> None: """ Replace the contents of this instance with the passed in instance. 
@@ -149,8 +161,8 @@ def sqlmesh_config_kwargs(self) -> t.Dict[str, t.Any]: """ kwargs = {} for field in self.sqlmesh_config_fields: - field_val = getattr(self, field, None) or self.meta.get(field, None) - if field_val: + field_val = self._get_field_value(field) + if field_val is not None: kwargs[field] = field_val return kwargs @@ -165,6 +177,12 @@ def sqlmesh_config_fields(self) -> t.Set[str]: return set() +@dataclass +class ModelAttrs: + attrs: t.Set[str] + all_attrs: bool = False + + class Dependencies(PydanticModel): """ DBT dependencies for a model, macro, etc. @@ -179,6 +197,9 @@ class Dependencies(PydanticModel): sources: t.Set[str] = set() refs: t.Set[str] = set() variables: t.Set[str] = set() + model_attrs: ModelAttrs = ModelAttrs(attrs=set()) + + has_dynamic_var_names: bool = False def union(self, other: Dependencies) -> Dependencies: return Dependencies( @@ -186,6 +207,11 @@ def union(self, other: Dependencies) -> Dependencies: sources=self.sources | other.sources, refs=self.refs | other.refs, variables=self.variables | other.variables, + model_attrs=ModelAttrs( + attrs=self.model_attrs.attrs | other.model_attrs.attrs, + all_attrs=self.model_attrs.all_attrs or other.model_attrs.all_attrs, + ), + has_dynamic_var_names=self.has_dynamic_var_names or other.has_dynamic_var_names, ) @field_validator("macros", mode="after") diff --git a/sqlmesh/dbt/context.py b/sqlmesh/dbt/context.py index 14e14efa2e..29eb03700d 100644 --- a/sqlmesh/dbt/context.py +++ b/sqlmesh/dbt/context.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import typing as t from dataclasses import dataclass, field, replace from pathlib import Path @@ -8,10 +9,11 @@ from sqlmesh.core.config import Config as SQLMeshConfig from sqlmesh.dbt.builtin import _relation_info_to_relation +from sqlmesh.dbt.common import Dependencies from sqlmesh.dbt.manifest import ManifestHelper from sqlmesh.dbt.target import TargetConfig from sqlmesh.utils import AttributeDict -from 
sqlmesh.utils.errors import ConfigError, SQLMeshError +from sqlmesh.utils.errors import ConfigError, SQLMeshError, MissingModelError, MissingSourceError from sqlmesh.utils.jinja import ( JinjaGlobalAttribute, JinjaMacroRegistry, @@ -22,18 +24,21 @@ if t.TYPE_CHECKING: from jinja2 import Environment - from sqlmesh.dbt.basemodel import Dependencies from sqlmesh.dbt.model import ModelConfig from sqlmesh.dbt.relation import Policy from sqlmesh.dbt.seed import SeedConfig from sqlmesh.dbt.source import SourceConfig +logger = logging.getLogger(__name__) + @dataclass class DbtContext: """Context for DBT environment""" project_root: Path = Path() + profiles_dir: t.Optional[Path] = None + """Optional override to specify the directory where profiles.yml is located, if not at the :project_root""" target_name: t.Optional[str] = None profile_name: t.Optional[str] = None project_schema: t.Optional[str] = None @@ -48,6 +53,7 @@ class DbtContext: _project_name: t.Optional[str] = None _variables: t.Dict[str, t.Any] = field(default_factory=dict) _models: t.Dict[str, ModelConfig] = field(default_factory=dict) + _model_fqns: t.Set[str] = field(default_factory=set) _seeds: t.Dict[str, SeedConfig] = field(default_factory=dict) _sources: t.Dict[str, SourceConfig] = field(default_factory=dict) _refs: t.Dict[str, t.Union[ModelConfig, SeedConfig]] = field(default_factory=dict) @@ -101,9 +107,10 @@ def add_variables(self, variables: t.Dict[str, t.Any]) -> None: self._jinja_environment = None def set_and_render_variables(self, variables: t.Dict[str, t.Any], package: str) -> None: - self.variables = variables - - jinja_environment = self.jinja_macros.build_environment(**self.jinja_globals) + package_macros = self.jinja_macros.copy( + update={"top_level_packages": [*self.jinja_macros.top_level_packages, package]} + ) + jinja_environment = package_macros.build_environment(**self.jinja_globals) def _render_var(value: t.Any) -> t.Any: if isinstance(value, str): @@ -124,7 +131,7 @@ def _var(name: 
str, default: t.Optional[t.Any] = None) -> t.Any: try: rendered_variables[k] = _render_var(v) except Exception as ex: - raise ConfigError(f"Failed to render variable '{k}', value '{v}': {ex}") from ex + logger.warning(f"Failed to render variable '{k}', value '{v}': {ex}") self.variables = rendered_variables @@ -140,6 +147,7 @@ def models(self) -> t.Dict[str, ModelConfig]: def models(self, models: t.Dict[str, ModelConfig]) -> None: self._models = {} self._refs = {} + self._model_fqns = set() self.add_models(models) def add_models(self, models: t.Dict[str, ModelConfig]) -> None: @@ -147,6 +155,12 @@ def add_models(self, models: t.Dict[str, ModelConfig]) -> None: self._models.update(models) self._jinja_environment = None + @property + def model_fqns(self) -> t.Set[str]: + if not self._model_fqns: + self._model_fqns = {model.fqn for model in self._models.values()} + return self._model_fqns + @property def seeds(self) -> t.Dict[str, SeedConfig]: return self._seeds @@ -185,11 +199,14 @@ def refs(self) -> t.Dict[str, t.Union[ModelConfig, SeedConfig]]: for model in t.cast( t.Dict[str, t.Union[ModelConfig, SeedConfig]], {**self._seeds, **self._models} ).values(): - self._refs[model.alias or model.name] = model - self._refs[model.config_name] = model - if model.version is not None and model.version == model.latest_version: - self._refs[model.name] = model - self._refs[f"{model.package_name}.{model.name}"] = model + name = model.name + config_name = model.config_name + if model.version == model.latest_version: + self._refs[name] = model + self._refs[config_name] = model + if model.version: + self._refs[f"{name}_v{model.version}"] = model + self._refs[f"{config_name}_v{model.version}"] = model return self._refs @property @@ -239,6 +256,9 @@ def jinja_globals(self) -> t.Dict[str, JinjaGlobalAttribute]: # pass user-specified default dialect if we have already loaded the config if self.sqlmesh_config.dialect: output["dialect"] = self.sqlmesh_config.dialect + # Pass flat graph 
structure like dbt + if self._manifest is not None: + output["flat_graph"] = AttributeDict(self.manifest.flat_graph) return output def context_for_dependencies(self, dependencies: Dependencies) -> DbtContext: @@ -259,13 +279,13 @@ def context_for_dependencies(self, dependencies: Dependencies) -> DbtContext: else: models[ref] = t.cast(ModelConfig, model) else: - raise ConfigError(f"Model '{ref}' was not found.") + raise MissingModelError(ref) for source in dependencies.sources: if source in self.sources: sources[source] = self.sources[source] else: - raise ConfigError(f"Source '{source}' was not found.") + raise MissingSourceError(source) variables = {k: v for k, v in self.variables.items() if k in dependencies.variables} @@ -273,6 +293,7 @@ def context_for_dependencies(self, dependencies: Dependencies) -> DbtContext: dependency_context.seeds = seeds dependency_context.models = models dependency_context.variables = variables + dependency_context._refs = {**dependency_context._seeds, **dependency_context._models} # type: ignore return dependency_context diff --git a/sqlmesh/dbt/loader.py b/sqlmesh/dbt/loader.py index 90db2b7dc2..fb3ecb2c77 100644 --- a/sqlmesh/dbt/loader.py +++ b/sqlmesh/dbt/loader.py @@ -1,64 +1,122 @@ from __future__ import annotations import logging +import sys import typing as t +import sqlmesh.core.dialect as d from pathlib import Path -from sqlmesh.core import constants as c -from sqlmesh.core.audit import Audit +from collections import defaultdict from sqlmesh.core.config import ( Config, ConnectionConfig, GatewayConfig, ModelDefaultsConfig, + DbtConfig as RootDbtConfig, ) -from sqlmesh.core.loader import LoadedProject, Loader +from sqlmesh.core.environment import EnvironmentStatements +from sqlmesh.core.loader import CacheBase, LoadedProject, Loader from sqlmesh.core.macros import MacroRegistry, macro from sqlmesh.core.model import Model, ModelCache +from sqlmesh.core.signal import signal from sqlmesh.dbt.basemodel import BMC, BaseModelConfig 
+from sqlmesh.dbt.common import Dependencies from sqlmesh.dbt.context import DbtContext from sqlmesh.dbt.model import ModelConfig from sqlmesh.dbt.profile import Profile from sqlmesh.dbt.project import Project from sqlmesh.dbt.target import TargetConfig from sqlmesh.utils import UniqueKeyDict -from sqlmesh.utils.errors import ConfigError -from sqlmesh.utils.jinja import JinjaMacroRegistry +from sqlmesh.utils.errors import ConfigError, MissingModelError, BaseMissingReferenceError +from sqlmesh.utils.jinja import ( + JinjaMacroRegistry, + make_jinja_registry, +) -logger = logging.getLogger(__name__) +if sys.version_info >= (3, 12): + from importlib import metadata +else: + import importlib_metadata as metadata # type: ignore if t.TYPE_CHECKING: + from sqlmesh.core.audit import Audit, ModelAudit from sqlmesh.core.context import GenericContext +logger = logging.getLogger(__name__) + def sqlmesh_config( project_root: t.Optional[Path] = None, state_connection: t.Optional[ConnectionConfig] = None, + dbt_profile_name: t.Optional[str] = None, dbt_target_name: t.Optional[str] = None, variables: t.Optional[t.Dict[str, t.Any]] = None, + threads: t.Optional[int] = None, register_comments: t.Optional[bool] = None, + infer_state_schema_name: bool = False, + profiles_dir: t.Optional[Path] = None, **kwargs: t.Any, ) -> Config: project_root = project_root or Path() - context = DbtContext(project_root=project_root) + context = DbtContext( + project_root=project_root, profiles_dir=profiles_dir, profile_name=dbt_profile_name + ) + + # note: Profile.load() is called twice with different DbtContext's: + # - once here with the above DbtContext (to determine connnection / gateway config which has to be set up before everything else) + # - again on the SQLMesh side via GenericContext.load() -> DbtLoader._load_projects() -> Project.load() which constructs a fresh DbtContext and ignores the above one + # it's important to ensure that the DbtContext created within the DbtLoader uses the same 
project root / profiles dir that we use here profile = Profile.load(context, target_name=dbt_target_name) model_defaults = kwargs.pop("model_defaults", ModelDefaultsConfig()) if model_defaults.dialect is None: model_defaults.dialect = profile.target.dialect - target_to_sqlmesh_args = {} - if register_comments is not None: - target_to_sqlmesh_args["register_comments"] = register_comments + target_to_sqlmesh_args = { + "register_comments": register_comments or False, + } + + loader = kwargs.pop("loader", DbtLoader) + if not issubclass(loader, DbtLoader): + raise ConfigError("The loader must be a DbtLoader.") + + if threads is not None: + # the to_sqlmesh() function on TargetConfig maps self.threads -> concurrent_tasks + profile.target.threads = threads + + gateway_kwargs = {} + if infer_state_schema_name: + profile_name = context.profile_name + + # Note: we deliberately isolate state based on the target *schema* and not the target name. + # It is assumed that the project will define a target, eg 'dev', and then in each users own ~/.dbt/profiles.yml the schema + # for the 'dev' target is overriden to something user-specific, rather than making the target name itself user-specific. + # This means that the schema name is the indicator of isolated state, not the target name which may be re-used across multiple schemas. 
+ target_schema = profile.target.schema_ + + # dbt-core doesnt allow schema to be undefined, but it does allow an empty string, and then just + # fails at runtime when `CREATE SCHEMA ""` doesnt work + if not target_schema: + raise ConfigError( + f"Target '{profile.target_name}' does not specify a schema.\n" + "A schema is required in order to infer where to store SQLMesh state" + ) + + inferred_state_schema_name = f"sqlmesh_state_{profile_name}_{target_schema}" + logger.info("Inferring state schema: %s", inferred_state_schema_name) + gateway_kwargs["state_schema"] = inferred_state_schema_name return Config( - loader=DbtLoader, + loader=loader, + loader_kwargs=dict(profiles_dir=profiles_dir), model_defaults=model_defaults, variables=variables or {}, + dbt=RootDbtConfig(infer_state_schema_name=infer_state_schema_name), **{ "default_gateway": profile.target_name if "gateways" not in kwargs else "", "gateways": { profile.target_name: GatewayConfig( connection=profile.target.to_sqlmesh(**target_to_sqlmesh_args), state_connection=state_connection, + **gateway_kwargs, ) }, # type: ignore **kwargs, @@ -67,17 +125,20 @@ def sqlmesh_config( class DbtLoader(Loader): - def __init__(self) -> None: + def __init__( + self, context: GenericContext, path: Path, profiles_dir: t.Optional[Path] = None + ) -> None: self._projects: t.List[Project] = [] self._macros_max_mtime: t.Optional[float] = None - super().__init__() + self._profiles_dir = profiles_dir + super().__init__(context, path) - def load(self, context: GenericContext, update_schemas: bool = True) -> LoadedProject: + def load(self) -> LoadedProject: self._projects = [] - return super().load(context, update_schemas) + return super().load() def _load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]: - macro_files = list(Path(self._context.path, "macros").glob("**/*.sql")) + macro_files = list(Path(self.config_path, "macros").glob("**/*.sql")) for file in macro_files: self._track_file(file) @@ -89,13 +150,24 @@ def 
_load_scripts(self) -> t.Tuple[MacroRegistry, JinjaMacroRegistry]: ) def _load_models( - self, macros: MacroRegistry, jinja_macros: JinjaMacroRegistry, gateway: t.Optional[str] + self, + macros: MacroRegistry, + jinja_macros: JinjaMacroRegistry, + gateway: t.Optional[str], + audits: UniqueKeyDict[str, ModelAudit], + signals: UniqueKeyDict[str, signal], ) -> UniqueKeyDict[str, Model]: models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") - for project in self._load_projects(): - context = project.context.copy() + def _to_sqlmesh(config: BMC, context: DbtContext) -> Model: + logger.debug("Converting '%s' to sqlmesh format", config.canonical_name(context)) + return config.to_sqlmesh( + context, + audit_definitions=audits, + virtual_environment_mode=self.config.virtual_environment_mode, + ) + for project in self._load_projects(): macros_max_mtime = self._macros_max_mtime yaml_max_mtimes = self._compute_yaml_max_mtime_per_subfolder( project.context.project_root @@ -105,26 +177,28 @@ def _load_models( logger.debug("Converting models to sqlmesh") # Now that config is rendered, create the sqlmesh models for package in project.packages.values(): - context.set_and_render_variables(package.variables, package.name) + package_context = project.context.copy() + package_context.set_and_render_variables(package.variables, package.name) package_models: t.Dict[str, BaseModelConfig] = {**package.models, **package.seeds} + package_models_by_path: t.Dict[Path, t.List[BaseModelConfig]] = defaultdict(list) for model in package_models.values(): - if ( - not context.sqlmesh_config.feature_flags.dbt.scd_type_2_support - and isinstance(model, ModelConfig) - and model.model_kind(context).is_scd_type_2 - ): - logger.info( - "Skipping loading Snapshot (SCD Type 2) models due to the feature flag disabling this feature" - ) + if isinstance(model, ModelConfig) and not model.sql.strip(): + logger.info(f"Skipping empty model '{model.name}' at path '{model.path}'.") continue - sqlmesh_model = 
cache.get_or_load_model( - model.path, loader=lambda: self._to_sqlmesh(model, context) + package_models_by_path[model.path].append(model) + + for path, path_models in package_models_by_path.items(): + sqlmesh_models = cache.get_or_load_models( + path, + loader=lambda: [ + _to_sqlmesh(model, package_context) for model in path_models + ], ) + for sqlmesh_model in sqlmesh_models: + models[sqlmesh_model.fqn] = sqlmesh_model - models[sqlmesh_model.fqn] = sqlmesh_model - - models.update(self._load_external_models()) + models.update(self._load_external_models(audits, cache)) return models @@ -134,81 +208,166 @@ def _load_audits( audits: UniqueKeyDict = UniqueKeyDict("audits") for project in self._load_projects(): - context = project.context - logger.debug("Converting audits to sqlmesh") for package in project.packages.values(): - context.set_and_render_variables(package.variables, package.name) + package_context = project.context.copy() + package_context.set_and_render_variables(package.variables, package.name) for test in package.tests.values(): logger.debug("Converting '%s' to sqlmesh format", test.name) - audits[test.name] = test.to_sqlmesh(context) + try: + audits[test.canonical_name] = test.to_sqlmesh(package_context) + + except BaseMissingReferenceError as e: + ref_type = "model" if isinstance(e, MissingModelError) else "source" + logger.warning( + "Skipping audit '%s' because %s '%s' is not a valid ref", + test.name, + ref_type, + e.ref, + ) return audits def _load_projects(self) -> t.List[Project]: if not self._projects: - target_name = self._context.gateway or self._context.config.default_gateway + target_name = self.context.selected_gateway self._projects = [] - for path, config in self._context.configs.items(): - project = Project.load( - DbtContext( - project_root=path, - target_name=target_name, - sqlmesh_config=config, - ), - variables=self._context.config.variables, + project = Project.load( + DbtContext( + project_root=self.config_path, + 
profiles_dir=self._profiles_dir, + target_name=target_name, + sqlmesh_config=self.config, + ), + variables=self.config.variables, + ) + + self._projects.append(project) + + context_default_catalog = self.context.default_catalog or "" + if project.context.target.database != context_default_catalog: + raise ConfigError( + f"Project default catalog ('{project.context.target.database}') does not match context default catalog ('{context_default_catalog}')." ) + for path in project.project_files: + self._track_file(path) - self._projects.append(project) + context = project.context - if project.context.target.database != self._context.default_catalog: - raise ConfigError( - "Project default catalog does not match context default catalog" - ) - for path in project.project_files: - self._track_file(path) - - context = project.context - - macros_mtimes: t.List[float] = [] - - for package_name, package in project.packages.items(): - context.add_sources(package.sources) - context.add_seeds(package.seeds) - context.add_models(package.models) - macros_mtimes.extend( - [ - self._path_mtimes[m.path] - for m in package.macros.values() - if m.path in self._path_mtimes - ] - ) + macros_mtimes: t.List[float] = [] + + for package_name, package in project.packages.items(): + context.add_sources(package.sources) + context.add_seeds(package.seeds) + context.add_models(package.models) + macros_mtimes.extend( + [ + self._path_mtimes[m.path] + for m in package.macros.values() + if m.path in self._path_mtimes + ] + ) - for package_name, macro_infos in context.manifest.all_macros.items(): - context.add_macros(macro_infos, package=package_name) + for package_name, macro_infos in context.manifest.all_macros.items(): + context.add_macros(macro_infos, package=package_name) - self._macros_max_mtime = max(macros_mtimes) if macros_mtimes else None + self._macros_max_mtime = max(macros_mtimes) if macros_mtimes else None return self._projects - @classmethod - def _to_sqlmesh(cls, config: BMC, context: 
DbtContext) -> Model: - logger.debug("Converting '%s' to sqlmesh format", config.canonical_name(context)) - return config.to_sqlmesh(context) + def _load_requirements(self) -> t.Tuple[t.Dict[str, str], t.Set[str]]: + requirements, excluded_requirements = super()._load_requirements() - def _compute_yaml_max_mtime_per_subfolder(self, root: Path) -> t.Dict[Path, float]: - if not root.is_dir(): + target_packages = ["dbt-core"] + for project in self._load_projects(): + target_packages.append(f"dbt-{project.context.target.type}") + + for target_package in target_packages: + if target_package in requirements or target_package in excluded_requirements: + continue + try: + requirements[target_package] = metadata.version(target_package) + except metadata.PackageNotFoundError: + from sqlmesh.core.console import get_console + + get_console().log_warning(f"dbt package {target_package} is not installed.") + + return requirements, excluded_requirements + + def _load_environment_statements(self, macros: MacroRegistry) -> t.List[EnvironmentStatements]: + """Loads dbt's on_run_start, on_run_end hooks into sqlmesh's before_all, after_all statements respectively.""" + + hooks_by_package_name: t.Dict[str, EnvironmentStatements] = {} + project_names: t.Set[str] = set() + dialect = self.config.dialect + for project in self._load_projects(): + for package_name, package in project.packages.items(): + package_context = project.context.copy() + package_context.set_and_render_variables(package.variables, package_name) + on_run_start: t.List[str] = [ + on_run_hook.sql + for on_run_hook in sorted(package.on_run_start.values(), key=lambda h: h.index) + ] + on_run_end: t.List[str] = [ + on_run_hook.sql + for on_run_hook in sorted(package.on_run_end.values(), key=lambda h: h.index) + ] + + if on_run_start or on_run_end: + dependencies = Dependencies() + for hook in [*package.on_run_start.values(), *package.on_run_end.values()]: + dependencies = dependencies.union(hook.dependencies) + + 
statements_context = package_context.context_for_dependencies(dependencies) + jinja_registry = make_jinja_registry( + statements_context.jinja_macros, package_name, set(dependencies.macros) + ) + jinja_registry.add_globals(statements_context.jinja_globals) + + hooks_by_package_name[package_name] = EnvironmentStatements( + before_all=[ + d.jinja_statement(stmt).sql(dialect=dialect) + for stmt in on_run_start or [] + ], + after_all=[ + d.jinja_statement(stmt).sql(dialect=dialect) + for stmt in on_run_end or [] + ], + python_env={}, + jinja_macros=jinja_registry, + project=package_name, + ) + project_names.add(package_name) + + return [ + statements + for _, statements in sorted( + hooks_by_package_name.items(), + key=lambda item: 0 if item[0] in project_names else 1, + ) + ] + + def _compute_yaml_max_mtime_per_subfolder( + self, root: Path, visited: t.Optional[t.Set[Path]] = None + ) -> t.Dict[Path, float]: + root = root.resolve() + visited = visited or set() + if not root.is_dir() or root in visited: return {} + visited.add(root) + result = {} max_mtime: t.Optional[float] = None for nested in root.iterdir(): try: if nested.is_dir(): - result.update(self._compute_yaml_max_mtime_per_subfolder(nested)) + result.update( + self._compute_yaml_max_mtime_per_subfolder(nested, visited=visited) + ) elif nested.suffix.lower() in (".yaml", ".yml"): yaml_mtime = self._path_mtimes.get(nested) if yaml_mtime: @@ -223,7 +382,7 @@ def _compute_yaml_max_mtime_per_subfolder(self, root: Path) -> t.Dict[Path, floa return result - class _Cache: + class _Cache(CacheBase): MAX_ENTRY_NAME_LENGTH = 200 def __init__( @@ -239,17 +398,34 @@ def __init__( self._yaml_max_mtimes = yaml_max_mtimes target = t.cast(TargetConfig, project.context.target) - cache_path = loader._context.path / c.CACHE / target.name - self._model_cache = ModelCache(cache_path) + cache_dir = loader.context.cache_dir / target.name + self._model_cache = ModelCache(cache_dir) - def get_or_load_model(self, target_path: Path, 
loader: t.Callable[[], Model]) -> Model: - model = self._model_cache.get_or_load( + def get_or_load_models( + self, target_path: Path, loader: t.Callable[[], t.List[Model]] + ) -> t.List[Model]: + models = self._model_cache.get_or_load( self._cache_entry_name(target_path), self._cache_entry_id(target_path), loader=loader, ) - model._path = target_path - return model + for model in models: + model._path = target_path + + return models + + def put(self, models: t.List[Model], path: Path) -> bool: + return self._model_cache.put( + models, + self._cache_entry_name(path), + self._cache_entry_id(path), + ) + + def get(self, path: Path) -> t.List[Model]: + return self._model_cache.get( + self._cache_entry_name(path), + self._cache_entry_id(path), + ) def _cache_entry_name(self, target_path: Path) -> str: try: @@ -268,7 +444,7 @@ def _cache_entry_id(self, target_path: Path) -> str: return "__".join( [ str(int(max_mtime)) if max_mtime is not None else "na", - self._loader._context.config.fingerprint, + self._loader.config.fingerprint, ] ) diff --git a/sqlmesh/dbt/manifest.py b/sqlmesh/dbt/manifest.py index f62dbc3899..fce561a24d 100644 --- a/sqlmesh/dbt/manifest.py +++ b/sqlmesh/dbt/manifest.py @@ -1,3 +1,4 @@ +# ruff: noqa: E402 from __future__ import annotations import json @@ -7,25 +8,54 @@ import typing as t from argparse import Namespace from collections import defaultdict +from functools import cached_property from pathlib import Path from dbt import flags + +from sqlmesh.dbt.util import DBT_VERSION +from sqlmesh.utils.conversions import make_serializable + +# Override the file name to prevent dbt commands from invalidating the cache. 
+ +if DBT_VERSION >= (1, 6, 0): + from dbt import constants as dbt_constants + + dbt_constants.PARTIAL_PARSE_FILE_NAME = "sqlmesh_partial_parse.msgpack" # type: ignore +else: + from dbt.parser import manifest as dbt_manifest # type: ignore + + dbt_manifest.PARTIAL_PARSE_FILE_NAME = "sqlmesh_partial_parse.msgpack" # type: ignore + +import jinja2 from dbt.adapters.factory import register_adapter, reset_adapters from dbt.config import Profile, Project, RuntimeConfig from dbt.config.profile import read_profile from dbt.config.renderer import DbtProjectYamlRenderer, ProfileRenderer from dbt.parser.manifest import ManifestLoader + +try: + from dbt.parser.sources import merge_freshness # type: ignore[attr-defined] +except ImportError: + # merge_freshness was renamed to merge_source_freshness in dbt 1.10 + # ref: https://github.com/dbt-labs/dbt-core/commit/14fc39a76ff4830cdf2fcbe73f57ca27db500018#diff-1f09db95588f46879a83378c2a86d6b16b7cdfcaddbfe46afc5d919ee5e9a4d9R430 + from dbt.parser.sources import merge_source_freshness as merge_freshness # type: ignore[no-redef,attr-defined] + from dbt.tracking import do_not_track -from sqlmesh.dbt.basemodel import Dependencies +from sqlmesh.core import constants as c +from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.core.config import ModelDefaultsConfig from sqlmesh.dbt.builtin import BUILTIN_FILTERS, BUILTIN_GLOBALS, OVERRIDDEN_MACROS +from sqlmesh.dbt.common import Dependencies from sqlmesh.dbt.model import ModelConfig -from sqlmesh.dbt.package import MacroConfig +from sqlmesh.dbt.package import HookConfig, MacroConfig, MaterializationConfig from sqlmesh.dbt.seed import SeedConfig from sqlmesh.dbt.source import SourceConfig from sqlmesh.dbt.target import TargetConfig from sqlmesh.dbt.test import TestConfig from sqlmesh.dbt.util import DBT_VERSION +from sqlmesh.utils.cache import FileCache from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.jinja import ( MacroInfo, @@ -33,10 +63,12 @@ extract_call_names, 
jinja_call_arg_name, ) +from sqlglot.helper import ensure_list if t.TYPE_CHECKING: from dbt.contracts.graph.manifest import Macro, Manifest from dbt.contracts.graph.nodes import ManifestNode, SourceDefinition + from sqlmesh.utils.jinja import CallNames logger = logging.getLogger(__name__) @@ -45,11 +77,20 @@ SeedConfigs = t.Dict[str, SeedConfig] SourceConfigs = t.Dict[str, SourceConfig] MacroConfigs = t.Dict[str, MacroConfig] +HookConfigs = t.Dict[str, HookConfig] +MaterializationConfigs = t.Dict[str, MaterializationConfig] IGNORED_PACKAGES = {"elementary"} BUILTIN_CALLS = {*BUILTIN_GLOBALS, *BUILTIN_FILTERS} +# Patch Semantic Manifest to skip validation and avoid Pydantic v1 errors on DBT 1.6 +# We patch for 1.7+ since we don't care about semantic models +if DBT_VERSION >= (1, 6, 0): + from dbt.contracts.graph.semantic_manifest import SemanticManifest # type: ignore + + SemanticManifest.validate = lambda _: True # type: ignore + class ManifestHelper: def __init__( @@ -59,12 +100,15 @@ def __init__( profile_name: str, target: TargetConfig, variable_overrides: t.Optional[t.Dict[str, t.Any]] = None, + cache_dir: t.Optional[str] = None, + model_defaults: t.Optional[ModelDefaultsConfig] = None, ): self.project_path = project_path self.profiles_path = profiles_path self.profile_name = profile_name self.target = target self.variable_overrides = variable_overrides or {} + self.model_defaults = model_defaults or ModelDefaultsConfig() self.__manifest: t.Optional[Manifest] = None self._project_name: str = "" @@ -82,6 +126,21 @@ def __init__( self._disabled_refs: t.Optional[t.Set[str]] = None self._disabled_sources: t.Optional[t.Set[str]] = None + if cache_dir is not None: + cache_path = Path(cache_dir) + if not cache_path.is_absolute(): + cache_path = self.project_path / cache_path + else: + cache_path = self.project_path / c.CACHE + + self._call_cache: FileCache[t.Dict[str, t.List[CallNames]]] = FileCache( + cache_path, "jinja_calls" + ) + + self._on_run_start_per_package: 
t.Dict[str, HookConfigs] = defaultdict(dict) + self._on_run_end_per_package: t.Dict[str, HookConfigs] = defaultdict(dict) + self._materializations: MaterializationConfigs = {} + def tests(self, package_name: t.Optional[str] = None) -> TestConfigs: self._load_all() return self._tests_per_package[package_name or self._project_name] @@ -102,6 +161,18 @@ def macros(self, package_name: t.Optional[str] = None) -> MacroConfigs: self._load_all() return self._macros_per_package[package_name or self._project_name] + def on_run_start(self, package_name: t.Optional[str] = None) -> HookConfigs: + self._load_all() + return self._on_run_start_per_package[package_name or self._project_name] + + def on_run_end(self, package_name: t.Optional[str] = None) -> HookConfigs: + self._load_all() + return self._on_run_end_per_package[package_name or self._project_name] + + def materializations(self) -> MaterializationConfigs: + self._load_all() + return self._materializations + @property def all_macros(self) -> t.Dict[str, t.Dict[str, MacroInfo]]: self._load_all() @@ -111,20 +182,77 @@ def all_macros(self) -> t.Dict[str, t.Dict[str, MacroInfo]]: result[package_name][macro_name] = macro_config.info return result + @cached_property + def flat_graph(self) -> t.Dict[str, t.Any]: + return { + "exposures": { + k: make_serializable(v.to_dict(omit_none=False)) + for k, v in getattr(self._manifest, "exposures", {}).items() + }, + "groups": { + k: make_serializable(v.to_dict(omit_none=False)) + for k, v in getattr(self._manifest, "groups", {}).items() + }, + "metrics": { + k: make_serializable(v.to_dict(omit_none=False)) + for k, v in getattr(self._manifest, "metrics", {}).items() + }, + "nodes": { + k: make_serializable(v.to_dict(omit_none=False)) + for k, v in self._manifest.nodes.items() + }, + "sources": { + k: make_serializable(v.to_dict(omit_none=False)) + for k, v in self._manifest.sources.items() + }, + "semantic_models": { + k: make_serializable(v.to_dict(omit_none=False)) + for k, v in 
getattr(self._manifest, "semantic_models", {}).items() + }, + "saved_queries": { + k: make_serializable(v.to_dict(omit_none=False)) + for k, v in getattr(self._manifest, "saved_queries", {}).items() + }, + } + def _load_all(self) -> None: if self._is_loaded: return + + self._calls = {k: (v, False) for k, v in (self._call_cache.get("") or {}).items()} + self._load_macros() + self._load_materializations() self._load_sources() self._load_tests() self._load_models_and_seeds() + self._load_on_run_start_end() self._is_loaded = True + self._call_cache.put("", value={k: v for k, (v, used) in self._calls.items() if used}) + def _load_sources(self) -> None: for source in self._manifest.sources.values(): + # starting in dbt-core 1.9.5, freshness can be set in both source and source config + source_dict = source.to_dict() + source_dict.pop("freshness", None) + + source_config_dict = _config(source) + source_config_dict.pop("freshness", None) + + source_config_freshness = getattr(source.config, "freshness", None) + freshness = ( + merge_freshness(source.freshness, source_config_freshness) + if source_config_freshness + else source.freshness + ) + source_config = SourceConfig( - **_config(source), - **source.to_dict(), + **{ + **source_dict, + **source_config_dict, + "freshness": freshness.to_dict() if freshness else None, + } ) self._sources_per_package[source.package_name][source_config.config_name] = ( source_config @@ -132,11 +260,14 @@ def _load_sources(self) -> None: def _load_macros(self) -> None: for macro in self._manifest.macros.values(): + if macro.name.startswith("materialization_"): + continue + if macro.name.startswith("test_"): macro.macro_sql = _convert_jinja_test_to_macro(macro.macro_sql) dependencies = Dependencies(macros=_macro_references(self._manifest, macro)) - if not macro.name.startswith("materialization_") and not macro.name.startswith("test_"): + if not macro.name.startswith("test_"): dependencies = dependencies.union( 
self._extra_dependencies(macro.macro_sql, macro.package_name) ) @@ -150,6 +281,45 @@ def _load_macros(self) -> None: path=Path(macro.original_file_path), ) + # This is a workaround for dbt adapter macros (eg. "spark__dateadd") whcih are expected to be + # available in the global scope regardless of the package they came from. + adapter_macro_names = { + name[name.find("__") + 2 :] + for name in self._macros_per_package.get("dbt", {}) + if "__" in name + } + for macros in self._macros_per_package.values(): + for name, macro_config in macros.items(): + pos = name.find("__") + if pos > 0 and name[pos + 2 :] in adapter_macro_names: + macro_config.info.is_top_level = True + + def _load_materializations(self) -> None: + for macro in self._manifest.macros.values(): + if macro.name.startswith("materialization_"): + # Extract name and adapter ( "materialization_{name}_{adapter}" or "materialization_{name}_default") + name_parts = macro.name.split("_") + if len(name_parts) >= 3: + mat_name = "_".join(name_parts[1:-1]) + adapter = name_parts[-1] + + dependencies = Dependencies(macros=_macro_references(self._manifest, macro)) + macro.macro_sql = _strip_jinja_materialization_tags(macro.macro_sql) + dependencies = dependencies.union( + self._extra_dependencies(macro.macro_sql, macro.package_name) + ) + + materialization_config = MaterializationConfig( + name=mat_name, + adapter=adapter, + definition=macro.macro_sql, + dependencies=dependencies, + path=Path(macro.original_file_path), + ) + + key = f"{mat_name}_{adapter}" + self._materializations[key] = materialization_config + def _load_tests(self) -> None: for node in self._manifest.nodes.values(): if node.resource_type != "test": @@ -179,22 +349,24 @@ def _load_tests(self) -> None: dependencies.macros.append(MacroReference(package="dbt", name="get_where_subquery")) dependencies.macros.append(MacroReference(package="dbt", name="should_store_failures")) - sql = node.raw_code if DBT_VERSION >= (1, 3) else node.raw_sql # type: 
ignore + sql = node.raw_code if DBT_VERSION >= (1, 3, 0) else node.raw_sql # type: ignore dependencies = dependencies.union(self._extra_dependencies(sql, node.package_name)) dependencies = dependencies.union( self._flatten_dependencies_from_macros(dependencies.macros, node.package_name) ) test_model = _test_model(node) + node_config = _node_base_config(node) + node_config["name"] = _build_test_name(node, dependencies) test = TestConfig( sql=sql, model_name=test_model, test_kwargs=node.test_metadata.kwargs if hasattr(node, "test_metadata") else {}, dependencies=dependencies, - **_node_base_config(node), + **node_config, ) - self._tests_per_package[node.package_name][node.name.lower()] = test + self._tests_per_package[node.package_name][node.unique_id] = test if test_model: self._tests_by_owner[test_model].append(test) @@ -207,40 +379,95 @@ def _load_models_and_seeds(self) -> None: continue macro_references = _macro_references(self._manifest, node) - tests = ( + all_tests = ( self._tests_by_owner[node.name] + self._tests_by_owner[f"{node.package_name}.{node.name}"] ) + # Only include non-standalone tests (tests that don't reference other models) + tests = [test for test in all_tests if not test.is_standalone] node_config = _node_base_config(node) + node_name = node.name + node_version = getattr(node, "version", None) + if node_version: + node_name = f"{node_name}_v{node_version}" + if node.resource_type in {"model", "snapshot"}: - sql = node.raw_code if DBT_VERSION >= (1, 3) else node.raw_sql # type: ignore + sql = node.raw_code if DBT_VERSION >= (1, 3, 0) else node.raw_sql # type: ignore dependencies = Dependencies( macros=macro_references, refs=_refs(node), sources=_sources(node) ) - dependencies = dependencies.union(self._extra_dependencies(sql, node.package_name)) + dependencies = dependencies.union( + self._extra_dependencies(sql, node.package_name, track_all_model_attrs=True) + ) + for hook in [*node_config.get("pre-hook", []), *node_config.get("post-hook", 
[])]: + dependencies = dependencies.union( + self._extra_dependencies( + hook["sql"], node.package_name, track_all_model_attrs=True + ) + ) dependencies = dependencies.union( self._flatten_dependencies_from_macros(dependencies.macros, node.package_name) ) - # Using the alias instead of the name because the alias captures the version of the model. - self._models_per_package[node.package_name][node.alias] = ModelConfig( - sql=sql, - dependencies=dependencies, - tests=tests, - **node_config, + self._models_per_package[node.package_name][node_name] = ModelConfig( + **dict( + node_config, + sql=sql, + dependencies=dependencies, + tests=tests, + ) ) else: - self._seeds_per_package[node.package_name][node.alias] = SeedConfig( - dependencies=Dependencies(macros=macro_references), - tests=tests, - **node_config, + self._seeds_per_package[node.package_name][node_name] = SeedConfig( + **dict( + node_config, + dependencies=Dependencies(macros=macro_references), + tests=tests, + ) ) + def _load_on_run_start_end(self) -> None: + for node in self._manifest.nodes.values(): + if node.resource_type == "operation" and ( + set(node.tags) & {"on-run-start", "on-run-end"} + ): + sql = node.raw_code if DBT_VERSION >= (1, 3, 0) else node.raw_sql # type: ignore + node_name = node.name + node_path = Path(node.original_file_path) + + dependencies = Dependencies( + macros=_macro_references(self._manifest, node), + refs=_refs(node), + sources=_sources(node), + ) + dependencies = dependencies.union(self._extra_dependencies(sql, node.package_name)) + dependencies = dependencies.union( + self._flatten_dependencies_from_macros(dependencies.macros, node.package_name) + ) + + if "on-run-start" in node.tags: + self._on_run_start_per_package[node.package_name][node_name] = HookConfig( + sql=sql, + index=getattr(node, "index", None) or 0, + path=node_path, + dependencies=dependencies, + ) + else: + self._on_run_end_per_package[node.package_name][node_name] = HookConfig( + sql=sql, + index=getattr(node, 
"index", None) or 0, + path=node_path, + dependencies=dependencies, + ) + @property def _manifest(self) -> Manifest: if not self.__manifest: - self.__manifest = self._load_manifest() + try: + self.__manifest = self._load_manifest() + except Exception as ex: + raise SQLMeshError(f"Failed to load dbt manifest: {ex}") from ex return self.__manifest def _load_manifest(self) -> Manifest: @@ -248,13 +475,14 @@ def _load_manifest(self) -> Manifest: variables = ( self.variable_overrides - if DBT_VERSION >= (1, 5) + if DBT_VERSION >= (1, 5, 0) else json.dumps(self.variable_overrides) ) args: Namespace = Namespace( vars=variables, profile=self.profile_name, + project_dir=str(self.project_path), profiles_dir=str(self.profiles_path), target=self.target.name, macro_debugging=False, @@ -262,7 +490,7 @@ def _load_manifest(self) -> Manifest: ) flags.set_from_args(args, None) - if DBT_VERSION >= (1, 8): + if DBT_VERSION >= (1, 8, 0): from dbt_common.context import set_invocation_context # type: ignore set_invocation_context(os.environ) @@ -270,16 +498,19 @@ def _load_manifest(self) -> Manifest: profile = self._load_profile() project = self._load_project(profile) - if not any(k in project.models for k in ("start", "+start")): + if ( + not any(k in project.models for k in ("start", "+start")) + and not self.model_defaults.start + ): raise ConfigError( - "SQLMesh's requires a start date in order to have a finite range of backfilling data. Add start to the 'models:' block in dbt_project.yml. https://sqlmesh.readthedocs.io/en/stable/integrations/dbt/#setting-model-backfill-start-dates" + "SQLMesh requires a start date in order to have a finite range of backfilling data. Add start to the 'models:' block in dbt_project.yml. 
https://sqlmesh.readthedocs.io/en/stable/integrations/dbt/#setting-model-backfill-start-dates" ) runtime_config = RuntimeConfig.from_parts(project, profile, args) self._project_name = project.project_name - if DBT_VERSION >= (1, 8): + if DBT_VERSION >= (1, 8, 0): from dbt.mp_context import get_mp_context # type: ignore register_adapter(runtime_config, get_mp_context()) # type: ignore @@ -287,6 +518,8 @@ def _load_manifest(self) -> Manifest: register_adapter(runtime_config) # type: ignore manifest = ManifestLoader.get_full_manifest(runtime_config) + # This adapter doesn't care about semantic models so we clear them out to avoid issues + manifest.semantic_models = {} reset_adapters() return manifest @@ -370,14 +603,37 @@ def _flatten_dependencies_from_macros( dependencies = dependencies.union(macro_dependencies) return dependencies - def _extra_dependencies(self, target: str, package: str) -> Dependencies: - # We sometimes observe that the manifest doesn't capture all macros, refs, and sources within a macro. - # This behavior has been observed with macros like dbt.current_timestamp(), dbt_utils.slugify(), and source(). - # Here we apply our custom extractor to make a best effort to supplement references captured in the manifest. + def _extra_dependencies( + self, + target: str, + package: str, + track_all_model_attrs: bool = False, + ) -> Dependencies: + """ + We sometimes observe that the manifest doesn't capture all macros, refs, and sources within a macro. + This behavior has been observed with macros like dbt.current_timestamp(), dbt_utils.slugify(), and source(). + Here we apply our custom extractor to make a best effort to supplement references captured in the manifest. 
+ """ dependencies = Dependencies() - for call_name, node in extract_call_names(target): + + # Whether all `model` attributes (e.g., `model.config`) should be included in the dependencies + all_model_attrs = False + + for call_name, node in extract_call_names(target, cache=self._calls): if call_name[0] == "config": continue + + if ( + track_all_model_attrs + and not all_model_attrs + and isinstance(node, jinja2.nodes.Call) + and any(isinstance(a, jinja2.nodes.Name) and a.name == "model" for a in node.args) + ): + all_model_attrs = True + + if isinstance(node, jinja2.nodes.Getattr): + if call_name[0] == "model": + dependencies.model_attrs.attrs.add(call_name[1]) elif call_name[0] == "source": args = [jinja_call_arg_name(arg) for arg in node.args] if args and all(arg for arg in args): @@ -396,6 +652,9 @@ def _extra_dependencies(self, target: str, package: str) -> Dependencies: args = [jinja_call_arg_name(arg) for arg in node.args] if args and args[0]: dependencies.variables.add(args[0]) + else: + # We couldn't determine the var name statically + dependencies.has_dynamic_var_names = True dependencies.macros.append(MacroReference(name="var")) elif len(call_name) == 1: macro_name = call_name[0] @@ -418,6 +677,14 @@ def _extra_dependencies(self, target: str, package: str) -> Dependencies: call_name[0], call_name[1], dependencies.macros.append ) + # When `model` is referenced as-is, e.g. it's passed as an argument to a macro call like + # `{{ foo(model) }}`, we can't easily track the attributes that are actually used, because + # it may be aliased and hence tracking actual uses of `model` requires a proper data flow + # analysis. We conservatively deal with this by including all of its supported attributes + # if a standalone reference is found. 
+ if all_model_attrs: + dependencies.model_attrs.all_attrs = True + return dependencies @@ -437,8 +704,11 @@ def _macro_references( manifest: Manifest, node: t.Union[ManifestNode, Macro] ) -> t.Set[MacroReference]: result: t.Set[MacroReference] = set() + if not hasattr(node, "depends_on"): + return result + for macro_node_id in node.depends_on.macros: - if not macro_node_id: + if not macro_node_id or macro_node_id == "None": continue macro_node = manifest.macros[macro_node_id] @@ -451,20 +721,21 @@ def _macro_references( def _refs(node: ManifestNode) -> t.Set[str]: - if DBT_VERSION >= (1, 5): - result = set() + if DBT_VERSION >= (1, 5, 0): + result: t.Set[str] = set() + if not hasattr(node, "refs"): + return result for r in node.refs: - ref_name = f"{r.package}.{r.name}" if r.package else r.name + ref_name = f"{r.package}.{r.name}" if r.package else r.name # type: ignore if getattr(r, "version", None): - ref_name = f"{ref_name}_v{r.version}" + ref_name = f"{ref_name}_v{r.version}" # type: ignore result.add(ref_name) return result - else: - return {".".join(r) for r in node.refs} # type: ignore + return {".".join(r) for r in node.refs} # type: ignore def _sources(node: ManifestNode) -> t.Set[str]: - return {".".join(s) for s in node.sources} + return {".".join(s) for s in getattr(node, "sources", [])} def _model_node_id(model_name: str, package: str) -> str: @@ -475,7 +746,12 @@ def _test_model(node: ManifestNode) -> t.Optional[str]: attached_node = getattr(node, "attached_node", None) if attached_node: pieces = attached_node.split(".") - return pieces[-1] if pieces[0] in ["model", "seed"] else None + if pieces[0] in ["model", "seed"]: + # versioned models have format "model.package.model_name.v1" (4 parts) + if len(pieces) == 4: + return f"{pieces[2]}_{pieces[3]}" + return pieces[-1] + return None key_name = getattr(node, "file_key_name", None) if key_name: @@ -494,12 +770,91 @@ def _node_base_config(node: ManifestNode) -> t.Dict[str, t.Any]: def 
_convert_jinja_test_to_macro(test_jinja: str) -> str: - TEST_TAG_REGEX = r"\s*{%\s*test\s+" - ENDTEST_REGEX = r"{%\s*endtest\s*%}" + TEST_TAG_REGEX = r"\s*{%-?\s*test\s+" + ENDTEST_REGEX = r"{%-?\s*endtest\s*-?%}" + match = re.match(TEST_TAG_REGEX, test_jinja) if not match: # already a macro return test_jinja - macro = "{% macro test_" + test_jinja[match.span()[-1] :] - return re.sub(ENDTEST_REGEX, "{% endmacro %}", macro) + test_tag = test_jinja[: match.span()[-1]] + + macro_tag = re.sub(r"({%-?\s*)test\s+", r"\1macro test_", test_tag) + macro = macro_tag + test_jinja[match.span()[-1] :] + + return re.sub(ENDTEST_REGEX, lambda m: m.group(0).replace("endtest", "endmacro"), macro) + + +def _strip_jinja_materialization_tags(materialization_jinja: str) -> str: + MATERIALIZATION_TAG_REGEX = r"\s*{%-?\s*materialization\s+[^%]*%}\s*\n?" + ENDMATERIALIZATION_REGEX = r"{%-?\s*endmaterialization\s*-?%}\s*\n?" + + if not re.match(MATERIALIZATION_TAG_REGEX, materialization_jinja): + return materialization_jinja + + materialization_jinja = re.sub( + MATERIALIZATION_TAG_REGEX, + "", + materialization_jinja, + flags=re.IGNORECASE, + ) + + materialization_jinja = re.sub( + ENDMATERIALIZATION_REGEX, + "", + materialization_jinja, + flags=re.IGNORECASE, + ) + + return materialization_jinja.strip() + + +def _build_test_name(node: ManifestNode, dependencies: Dependencies) -> str: + """ + Build a user-friendly test name that includes the test's model/source, column, + and args for tests with custom user names. Needed because dbt only generates these + names for tests that do not specify the "name" field in their YAML definition. 
+ + Name structure + - Model test: [namespace]_[test name]_[model name]_[column name]__[arg values] + - Source test: [namespace]_source_[test name]_[source name]_[table name]_[column name]__[arg values] + """ + # standalone test + if not hasattr(node, "test_metadata"): + return node.name + + model_name = _test_model(node) + source_name = None + if not model_name and dependencies.sources: + # extract source and table names + source_parts = list(dependencies.sources)[0].split(".") + source_name = "_".join(source_parts) if len(source_parts) == 2 else source_parts[-1] + entity_name = model_name or source_name or "" + entity_name = f"_{entity_name}" if entity_name else "" + + name_prefix = "" + if namespace := getattr(node.test_metadata, "namespace", None): + name_prefix += f"{namespace}_" + if source_name and not model_name: + name_prefix += "source_" + + metadata_kwargs = node.test_metadata.kwargs + arg_val_parts = [] + for arg, val in sorted(metadata_kwargs.items()): + if arg == "model": + continue + if isinstance(val, dict): + val = list(val.values()) + val = [re.sub("[^0-9a-zA-Z_]+", "_", str(v)) for v in ensure_list(val)] + arg_val_parts.extend(val) + unique_args = "__".join(arg_val_parts) if arg_val_parts else "" + unique_args = f"_{unique_args}" if unique_args else "" + + auto_name = f"{name_prefix}{node.test_metadata.name}{entity_name}{unique_args}" + + if node.name == auto_name: + return node.name + + custom_prefix = name_prefix if source_name and not model_name else "" + return f"{custom_prefix}{node.name}{entity_name}{unique_args}" diff --git a/sqlmesh/dbt/model.py b/sqlmesh/dbt/model.py index 80fffa5889..41cea9b9ae 100644 --- a/sqlmesh/dbt/model.py +++ b/sqlmesh/dbt/model.py @@ -1,7 +1,8 @@ from __future__ import annotations -import logging +import datetime import typing as t +import logging from sqlglot import exp from sqlglot.errors import SqlglotError @@ -9,6 +10,8 @@ from sqlmesh.core import dialect as d from sqlmesh.core.config.base import 
UpdateStrategy +from sqlmesh.core.config.common import VirtualEnvironmentMode +from sqlmesh.core.console import get_console from sqlmesh.core.model import ( EmbeddedKind, FullKind, @@ -19,26 +22,41 @@ ModelKind, SCDType2ByColumnKind, ViewKind, + ManagedKind, create_sql_model, ) -from sqlmesh.core.model.kind import SCDType2ByTimeKind, OnDestructiveChange +from sqlmesh.core.model.kind import ( + SCDType2ByTimeKind, + OnDestructiveChange, + OnAdditiveChange, + on_destructive_change_validator, + on_additive_change_validator, + DbtCustomKind, +) from sqlmesh.dbt.basemodel import BaseModelConfig, Materialization, SnapshotStrategy -from sqlmesh.dbt.common import SqlStr, extract_jinja_config, sql_str_validator +from sqlmesh.dbt.common import SqlStr, sql_str_validator from sqlmesh.utils.errors import ConfigError from sqlmesh.utils.pydantic import field_validator if t.TYPE_CHECKING: + from sqlmesh.core.audit.definition import ModelAudit from sqlmesh.dbt.context import DbtContext + from sqlmesh.dbt.package import MaterializationConfig + +logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) -INCREMENTAL_BY_TIME_STRATEGIES = set(["delete+insert", "insert_overwrite"]) + +INCREMENTAL_BY_TIME_RANGE_STRATEGIES = set( + ["delete+insert", "insert_overwrite", "microbatch", "incremental_by_time_range"] +) INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES = set(["merge"]) def collection_to_str(collection: t.Iterable) -> str: - return ", ".join(f"'{item}'" for item in collection) + return ", ".join(f"'{item}'" for item in sorted(collection)) class ModelConfig(BaseModelConfig): @@ -69,14 +87,19 @@ class ModelConfig(BaseModelConfig): # sqlmesh fields sql: SqlStr = SqlStr("") - time_column: t.Optional[str] = None + time_column: t.Optional[t.Union[str, t.Dict[str, str]]] = None cron: t.Optional[str] = None interval_unit: t.Optional[str] = None - batch_size: t.Optional[int] = None - lookback: t.Optional[int] = None + batch_concurrency: t.Optional[int] = None forward_only: bool = True 
disable_restatement: t.Optional[bool] = None - allow_partials: t.Optional[bool] = None + allow_partials: bool = True + physical_version: t.Optional[str] = None + auto_restatement_cron: t.Optional[str] = None + auto_restatement_intervals: t.Optional[int] = None + partition_by_time_column: t.Optional[bool] = None + on_destructive_change: t.Optional[OnDestructiveChange] = None + on_additive_change: t.Optional[OnAdditiveChange] = None # DBT configuration fields cluster_by: t.Optional[t.List[str]] = None @@ -86,7 +109,7 @@ class ModelConfig(BaseModelConfig): unique_key: t.Optional[t.List[str]] = None partition_by: t.Optional[t.Union[t.List[str], t.Dict[str, t.Any]]] = None full_refresh: t.Optional[bool] = None - on_schema_change: t.Optional[str] = None + on_schema_change: str = "ignore" # Snapshot (SCD Type 2) Fields updated_at: t.Optional[str] = None @@ -95,6 +118,15 @@ class ModelConfig(BaseModelConfig): target_schema: t.Optional[str] = None check_cols: t.Optional[t.Union[t.List[str], str]] = None + # Microbatch Fields + event_time: t.Optional[str] = None + begin: t.Optional[datetime.datetime] = None + concurrent_batches: t.Optional[bool] = None + + # Shared SQLMesh and DBT configuration fields + batch_size: t.Optional[t.Union[int, str]] = None + lookback: t.Optional[int] = None + # redshift bind: t.Optional[bool] = None @@ -104,12 +136,24 @@ class ModelConfig(BaseModelConfig): # snowflake snowflake_warehouse: t.Optional[str] = None - - # Private fields - _sql_embedded_config: t.Optional[SqlStr] = None - _sql_no_config: t.Optional[SqlStr] = None + # note: for Snowflake dynamic tables, in the DBT adapter we only support properties that DBT supports + # which are defined here: https://docs.getdbt.com/reference/resource-configs/snowflake-configs#dynamic-tables + target_lag: t.Optional[str] = None + + # clickhouse + engine: t.Optional[str] = None + order_by: t.Optional[t.Union[t.List[str], str]] = None + primary_key: t.Optional[t.Union[t.List[str], str]] = None + 
sharding_key: t.Optional[t.Union[t.List[str], str]] = None + ttl: t.Optional[t.Union[t.List[str], str]] = None + settings: t.Optional[t.Dict[str, t.Any]] = None + query_settings: t.Optional[t.Dict[str, t.Any]] = None + inserts_only: t.Optional[bool] = None + incremental_predicates: t.Optional[t.List[str]] = None _sql_validator = sql_str_validator + _on_destructive_change_validator = on_destructive_change_validator + _on_additive_change_validator = on_additive_change_validator @field_validator( "unique_key", @@ -128,6 +172,22 @@ def _validate_check_cols(cls, v: t.Union[str, t.List[str]]) -> t.Union[str, t.Li return "*" return ensure_list(v) + @field_validator("updated_at", mode="before") + @classmethod + def _validate_updated_at(cls, v: t.Optional[str]) -> t.Optional[str]: + """ + Extract column name if updated_at contains a cast. + + SCDType2ByTimeKind and SCDType2ByColumnKind expect a column, and the casting is done later. + """ + if v is None: + return None + parsed = d.parse_one(v) + if isinstance(parsed, exp.Cast) and isinstance(parsed.this, exp.Column): + return parsed.this.name + + return v + @field_validator("sql", mode="before") @classmethod def _validate_sql(cls, v: t.Union[str, SqlStr]) -> SqlStr: @@ -135,7 +195,11 @@ def _validate_sql(cls, v: t.Union[str, SqlStr]) -> SqlStr: @field_validator("partition_by", mode="before") @classmethod - def _validate_partition_by(cls, v: t.Any) -> t.Union[t.List[str], t.Dict[str, t.Any]]: + def _validate_partition_by( + cls, v: t.Any + ) -> t.Optional[t.Union[t.List[str], t.Dict[str, t.Any]]]: + if v is None: + return None if isinstance(v, str): return [v] if isinstance(v, list): @@ -151,9 +215,40 @@ def _validate_partition_by(cls, v: t.Any) -> t.Union[t.List[str], t.Dict[str, t. 
): granularity = v["granularity"] raise ConfigError(f"Unexpected granularity '{granularity}' in partition_by '{v}'.") + if "data_type" in v and v["data_type"].lower() not in ( + "timestamp", + "date", + "datetime", + "int64", + ): + data_type = v["data_type"] + raise ConfigError(f"Unexpected data_type '{data_type}' in partition_by '{v}'.") return {"data_type": "date", "granularity": "day", **v} raise ConfigError(f"Invalid format for partition_by '{v}'") + @field_validator("materialized", mode="before") + @classmethod + def _validate_materialized(cls, v: str) -> str: + unsupported_materializations = [ + "materialized_view", # multiple engines + "dictionary", # clickhouse only + "distributed_table", # clickhouse only + "distributed_incremental", # clickhouse only + ] + if v in unsupported_materializations: + fallback = v.split("_") + msg = f"SQLMesh does not support the '{v}' model materialization." + if len(fallback) == 1: + # dictionary materialization + raise ConfigError(msg) + else: + get_console().log_warning( + f"{msg} Falling back to the '{fallback[1]}' materialization." 
+ ) + + return fallback[1] + return v + _FIELD_UPDATE_STRATEGY: t.ClassVar[t.Dict[str, UpdateStrategy]] = { **BaseModelConfig._FIELD_UPDATE_STRATEGY, **{ @@ -185,56 +280,124 @@ def model_kind(self, context: DbtContext) -> ModelKind: # args common to all sqlmesh incremental kinds, regardless of materialization incremental_kind_kwargs: t.Dict[str, t.Any] = {} - if self.on_schema_change: - on_schema_change = self.on_schema_change.lower() - - on_destructive_change = OnDestructiveChange.WARN - if on_schema_change == "sync_all_columns": - on_destructive_change = OnDestructiveChange.ALLOW - elif on_schema_change == "fail": - on_destructive_change = OnDestructiveChange.ERROR + on_schema_change = self.on_schema_change.lower() + if materialization == Materialization.SNAPSHOT: + # dbt snapshots default to `append_new_columns` behavior and can't be changed + on_schema_change = "append_new_columns" + + if on_schema_change == "ignore": + on_destructive_change = OnDestructiveChange.IGNORE + on_additive_change = OnAdditiveChange.IGNORE + elif on_schema_change == "fail": + on_destructive_change = OnDestructiveChange.ERROR + on_additive_change = OnAdditiveChange.ERROR + elif on_schema_change == "append_new_columns": + on_destructive_change = OnDestructiveChange.IGNORE + on_additive_change = OnAdditiveChange.ALLOW + elif on_schema_change == "sync_all_columns": + on_destructive_change = OnDestructiveChange.ALLOW + on_additive_change = OnAdditiveChange.ALLOW + else: + raise ConfigError( + f"{self.canonical_name(context)}: Invalid on_schema_change value '{on_schema_change}'. " + "Valid values are 'ignore', 'fail', 'append_new_columns', 'sync_all_columns'." 
+ ) - incremental_kind_kwargs["on_destructive_change"] = on_destructive_change + incremental_kind_kwargs["on_destructive_change"] = ( + self._get_field_value("on_destructive_change") or on_destructive_change + ) + incremental_kind_kwargs["on_additive_change"] = ( + self._get_field_value("on_additive_change") or on_additive_change + ) + auto_restatement_cron_value = self._get_field_value("auto_restatement_cron") + if auto_restatement_cron_value is not None: + incremental_kind_kwargs["auto_restatement_cron"] = auto_restatement_cron_value if materialization == Materialization.TABLE: return FullKind() if materialization == Materialization.VIEW: return ViewKind() if materialization == Materialization.INCREMENTAL: - incremental_materialization_kwargs: t.Dict[str, t.Any] = { - "dialect": self.dialect(context) - } - for field in ("batch_size", "lookback", "forward_only"): - field_val = getattr(self, field, None) or self.meta.get(field, None) - if field_val: - incremental_materialization_kwargs[field] = field_val + incremental_by_kind_kwargs: t.Dict[str, t.Any] = {"dialect": self.dialect(context)} + forward_only_value = self._get_field_value("forward_only") + if forward_only_value is not None: + incremental_kind_kwargs["forward_only"] = forward_only_value + + is_incremental_by_time_range = self.time_column or ( + self.incremental_strategy + and self.incremental_strategy in {"microbatch", "incremental_by_time_range"} + ) + # Get shared incremental by kwargs + for field in ("batch_size", "batch_concurrency", "lookback"): + field_val = self._get_field_value(field) + if field_val is not None: + # Check if `batch_size` is representing an interval unit and if so that will be handled at the model level + if field == "batch_size" and isinstance(field_val, str): + continue + incremental_by_kind_kwargs[field] = field_val + + disable_restatement = self.disable_restatement + if disable_restatement is None: + if is_incremental_by_time_range: + disable_restatement = False + else: + 
disable_restatement = ( + not self.full_refresh if self.full_refresh is not None else False + ) + incremental_by_kind_kwargs["disable_restatement"] = disable_restatement - if self.time_column: + if is_incremental_by_time_range: strategy = self.incremental_strategy or target.default_incremental_strategy( IncrementalByTimeRangeKind ) - if strategy not in INCREMENTAL_BY_TIME_STRATEGIES: - logger.warning( - "SQLMesh incremental by time strategy is not compatible with '%s' incremental strategy in model '%s'. Supported strategies include %s.", - strategy, - self.canonical_name(context), - collection_to_str(INCREMENTAL_BY_TIME_STRATEGIES), + if strategy not in INCREMENTAL_BY_TIME_RANGE_STRATEGIES: + get_console().log_warning( + f"SQLMesh incremental by time strategy is not compatible with '{strategy}' incremental strategy in model '{self.canonical_name(context)}'. " + f"Supported strategies include {collection_to_str(INCREMENTAL_BY_TIME_RANGE_STRATEGIES)}." + ) + + if self.time_column and strategy != "incremental_by_time_range": + get_console().log_warning( + f"Using `time_column` on a model with incremental_strategy '{strategy}' has been deprecated. " + f"Please use `incremental_by_time_range` instead in model '{self.canonical_name(context)}'." + ) + + if strategy == "microbatch": + if self.time_column: + raise ConfigError( + f"{self.canonical_name(context)}: 'time_column' cannot be used with 'microbatch' incremental strategy. Use 'event_time' instead." + ) + time_column = self._get_field_value("event_time") + if not time_column: + raise ConfigError( + f"{self.canonical_name(context)}: 'event_time' is required for microbatch incremental strategy." + ) + # dbt microbatch always processes batches in a size of 1 + incremental_by_kind_kwargs["batch_size"] = 1 + else: + if not self.time_column: + raise ConfigError( + f"{self.canonical_name(context)}: 'time_column' is required for incremental by time range models not defined using microbatch." 
+ ) + time_column = self.time_column + + incremental_by_time_range_kwargs = { + "time_column": time_column, + } + if self.auto_restatement_intervals: + incremental_by_time_range_kwargs["auto_restatement_intervals"] = ( + self.auto_restatement_intervals + ) + if self.partition_by_time_column is not None: + incremental_by_time_range_kwargs["partition_by_time_column"] = ( + self.partition_by_time_column ) return IncrementalByTimeRangeKind( - time_column=self.time_column, - disable_restatement=( - self.disable_restatement if self.disable_restatement is not None else False - ), **incremental_kind_kwargs, - **incremental_materialization_kwargs, - ) - - disable_restatement = self.disable_restatement - if disable_restatement is None: - disable_restatement = ( - not self.full_refresh if self.full_refresh is not None else False + **incremental_by_kind_kwargs, + **incremental_by_time_range_kwargs, ) if self.unique_key: @@ -245,30 +408,35 @@ def model_kind(self, context: DbtContext) -> ModelKind: self.incremental_strategy and strategy not in INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES ): - raise ConfigError( - f"{self.canonical_name(context)}: SQLMesh incremental by unique key strategy is not compatible with '{strategy}'" - f" incremental strategy. Supported strategies include {collection_to_str(INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES)}." + get_console().log_warning( + f"Unique key is not compatible with '{strategy}' incremental strategy in model '{self.canonical_name(context)}'. " + f"Supported strategies include {collection_to_str(INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES)}. Falling back to 'merge' strategy." 
) + + merge_filter = None + if self.incremental_predicates: + dialect = self.dialect(context) + merge_filter = exp.and_( + *[ + d.parse_one(predicate, dialect=dialect) + for predicate in self.incremental_predicates + ], + dialect=dialect, + ).transform(d.replace_merge_table_aliases) + return IncrementalByUniqueKeyKind( unique_key=self.unique_key, - disable_restatement=disable_restatement, + merge_filter=merge_filter, **incremental_kind_kwargs, - **incremental_materialization_kwargs, + **incremental_by_kind_kwargs, ) - logger.warning( - "Using unmanaged incremental materialization for model '%s'. Some features might not be available. Consider adding either a time_column (%s) or a unique_key (%s) configuration to mitigate this", - self.canonical_name(context), - collection_to_str(INCREMENTAL_BY_TIME_STRATEGIES), - collection_to_str(INCREMENTAL_BY_UNIQUE_KEY_STRATEGIES.union(["none"])), - ) strategy = self.incremental_strategy or target.default_incremental_strategy( IncrementalUnmanagedKind ) return IncrementalUnmanagedKind( - insert_overwrite=strategy in INCREMENTAL_BY_TIME_STRATEGIES, - forward_only=incremental_materialization_kwargs.get("forward_only", True), - disable_restatement=disable_restatement, + insert_overwrite=strategy in INCREMENTAL_BY_TIME_RANGE_STRATEGIES, + disable_restatement=incremental_by_kind_kwargs["disable_restatement"], **incremental_kind_kwargs, ) if materialization == Materialization.EPHEMERAL: @@ -298,26 +466,24 @@ def model_kind(self, context: DbtContext) -> ModelKind: return SCDType2ByTimeKind( updated_at_name=self.updated_at, updated_at_as_valid_from=True, **shared_kwargs ) - raise ConfigError(f"{materialization.value} materialization not supported.") - @property - def sql_no_config(self) -> SqlStr: - if self._sql_no_config is None: - self._sql_no_config = SqlStr("") - self._extract_sql_config() - return self._sql_no_config + if materialization == Materialization.DYNAMIC_TABLE: + return ManagedKind() - @property - def 
sql_embedded_config(self) -> SqlStr: - if self._sql_embedded_config is None: - self._sql_embedded_config = SqlStr("") - self._extract_sql_config() - return self._sql_embedded_config + if materialization == Materialization.CUSTOM: + if custom_materialization := self._get_custom_materialization(context): + return DbtCustomKind( + materialization=self.materialized, + adapter=custom_materialization.adapter, + dialect=self.dialect(context), + definition=custom_materialization.definition, + ) - def _extract_sql_config(self) -> None: - no_config, embedded_config = extract_jinja_config(self.sql) - self._sql_no_config = SqlStr(no_config) - self._sql_embedded_config = SqlStr(embedded_config) + raise ConfigError( + f"Unknown materialization '{self.materialized}'. Custom materializations must be defined in your dbt project." + ) + + raise ConfigError(f"{materialization.value} materialization not supported.") def _big_query_partition_by_expr(self, context: DbtContext) -> exp.Expression: assert isinstance(self.partition_by, dict) @@ -326,7 +492,9 @@ def _big_query_partition_by_expr(self, context: DbtContext) -> exp.Expression: try: field = d.parse_one(raw_field, dialect="bigquery") except SqlglotError as e: - raise ConfigError(f"Failed to parse partition_by field '{raw_field}': {e}") from e + raise ConfigError( + f"Failed to parse model '{self.canonical_name(context)}' partition_by field '{raw_field}' in '{self.path}': {e}" + ) from e if data_type == "date" and self.partition_by["granularity"].lower() == "day": return field @@ -354,32 +522,103 @@ def _big_query_partition_by_expr(self, context: DbtContext) -> exp.Expression: dialect="bigquery", ) + def _get_custom_materialization(self, context: DbtContext) -> t.Optional[MaterializationConfig]: + materializations = context.manifest.materializations() + name, target_adapter = self.materialized, context.target.dialect + + adapter_specific_key = f"{name}_{target_adapter}" + default_key = f"{name}_default" + if adapter_specific_key in 
materializations: + return materializations[adapter_specific_key] + if default_key in materializations: + return materializations[default_key] + return None + @property def sqlmesh_config_fields(self) -> t.Set[str]: - return super().sqlmesh_config_fields | {"cron", "interval_unit", "allow_partials"} - - def to_sqlmesh(self, context: DbtContext) -> Model: + return super().sqlmesh_config_fields | { + "cron", + "interval_unit", + "allow_partials", + "physical_version", + "start", + # In microbatch models `begin` is the same as `start` + "begin", + } + + def to_sqlmesh( + self, + context: DbtContext, + audit_definitions: t.Optional[t.Dict[str, ModelAudit]] = None, + virtual_environment_mode: VirtualEnvironmentMode = VirtualEnvironmentMode.default, + ) -> Model: """Converts the dbt model into a SQLMesh model.""" model_dialect = self.dialect(context) - query = d.jinja_query(self.sql_no_config) + query = d.jinja_query(self.sql) + kind = self.model_kind(context) optional_kwargs: t.Dict[str, t.Any] = {} + physical_properties: t.Dict[str, t.Any] = {} if self.partition_by: - optional_kwargs["partitioned_by"] = ( - [exp.to_column(val, dialect=model_dialect) for val in self.partition_by] - if isinstance(self.partition_by, list) - else self._big_query_partition_by_expr(context) - ) + if isinstance(kind, (ViewKind, EmbeddedKind)): + logger.warning( + "Ignoring partition_by config for model '%s'; partition_by is not supported for %s.", + self.name, + "views" if isinstance(kind, ViewKind) else "ephemeral models", + ) + elif context.target.dialect == "snowflake": + logger.warning( + "Ignoring partition_by config for model '%s' targeting %s. 
The partition_by config is not supported for Snowflake.", + self.name, + context.target.dialect, + ) + else: + partitioned_by = [] + if isinstance(self.partition_by, list): + for p in self.partition_by: + try: + partitioned_by.append(d.parse_one(p, dialect=model_dialect)) + except SqlglotError as e: + raise ConfigError( + f"Failed to parse model '{self.canonical_name(context)}' partition_by field '{p}' in '{self.path}': {e}" + ) from e + elif isinstance(self.partition_by, dict): + if context.target.dialect == "bigquery": + partitioned_by.append(self._big_query_partition_by_expr(context)) + else: + logger.warning( + "Ignoring partition_by config for model '%s' targeting %s. The format of the config field is only supported for BigQuery.", + self.name, + context.target.dialect, + ) + + if partitioned_by: + optional_kwargs["partitioned_by"] = partitioned_by if self.cluster_by: - clustered_by = [] - for c in self.cluster_by: - try: - clustered_by.append(d.parse_one(c, dialect=model_dialect).name) - except SqlglotError as e: - raise ConfigError(f"Failed to parse cluster_by field '{c}': {e}") from e - optional_kwargs["clustered_by"] = clustered_by + if isinstance(kind, (ViewKind, EmbeddedKind)): + logger.warning( + "Ignoring cluster_by config for model '%s'; cluster_by is not supported for %s.", + self.name, + "views" if isinstance(kind, ViewKind) else "ephemeral models", + ) + else: + clustered_by = [] + for c in self.cluster_by: + try: + cluster_expr = exp.maybe_parse( + c, into=exp.Cluster, prefix="CLUSTER BY", dialect=model_dialect + ) + for expr in cluster_expr.expressions: + clustered_by.append( + expr.this if isinstance(expr, exp.Ordered) else expr + ) + except SqlglotError as e: + raise ConfigError( + f"Failed to parse model '{self.canonical_name(context)}' cluster_by field '{c}' in '{self.path}': {e}" + ) from e + optional_kwargs["clustered_by"] = clustered_by model_kwargs = self.sqlmesh_model_kwargs(context) if self.sql_header: @@ -390,7 +629,6 @@ def 
to_sqlmesh(self, context: DbtContext) -> Model: if dbt_max_partition_blob: model_kwargs["pre_statements"].append(d.jinja_statement(dbt_max_partition_blob)) - physical_properties = {} if self.partition_expiration_days is not None: physical_properties["partition_expiration_days"] = self.partition_expiration_days if self.require_partition_filter is not None: @@ -399,18 +637,141 @@ def to_sqlmesh(self, context: DbtContext) -> Model: if physical_properties: model_kwargs["physical_properties"] = physical_properties - if context.target.dialect == "snowflake" and self.snowflake_warehouse is not None: - model_kwargs["session_properties"] = {"warehouse": self.snowflake_warehouse} + if context.target.dialect == "snowflake": + if self.snowflake_warehouse is not None: + model_kwargs["session_properties"] = {"warehouse": self.snowflake_warehouse} + + if self.model_materialization == Materialization.DYNAMIC_TABLE: + if not self.snowflake_warehouse: + raise ConfigError("`snowflake_warehouse` must be set for dynamic tables") + if not self.target_lag: + raise ConfigError("`target_lag` must be set for dynamic tables") + + model_kwargs["physical_properties"] = { + "warehouse": self.snowflake_warehouse, + "target_lag": self.target_lag, + } + + if context.target.dialect == "clickhouse": + if self.model_materialization == Materialization.INCREMENTAL: + # `inserts_only` overrides incremental_strategy setting (if present) + # https://github.com/ClickHouse/dbt-clickhouse/blob/065f3a724fa09205446ecadac7a00d92b2d8c646/README.md?plain=1#L108 + if self.inserts_only: + self.incremental_strategy = "append" + + if self.incremental_strategy == "delete+insert": + get_console().log_warning( + f"The '{self.incremental_strategy}' incremental strategy is not supported - SQLMesh will use the temp table/partition swap strategy." + ) + + if self.incremental_predicates: + get_console().log_warning( + "SQLMesh does not support 'incremental_predicates' - they will not be applied." 
+ ) + + if self.query_settings: + get_console().log_warning( + "SQLMesh does not support the 'query_settings' model configuration parameter. Specify the query settings directly in the model query." + ) + + if self.engine: + optional_kwargs["storage_format"] = self.engine + + if self.order_by: + order_by = [] + for o in self.order_by if isinstance(self.order_by, list) else [self.order_by]: + try: + order_by.append(d.parse_one(o, dialect=model_dialect)) + except SqlglotError as e: + raise ConfigError( + f"Failed to parse model '{self.canonical_name(context)}' 'order_by' field '{o}' in '{self.path}': {e}" + ) from e + physical_properties["order_by"] = order_by + + if self.primary_key: + primary_key = [] + for p in self.primary_key: + try: + primary_key.append(d.parse_one(p, dialect=model_dialect)) + except SqlglotError as e: + raise ConfigError( + f"Failed to parse model '{self.canonical_name(context)}' 'primary_key' field '{p}' in '{self.path}': {e}" + ) from e + physical_properties["primary_key"] = primary_key + + if self.sharding_key: + get_console().log_warning( + "SQLMesh does not support the 'sharding_key' model configuration parameter or distributed materializations." + ) + + if self.ttl: + physical_properties["ttl"] = exp.var( + self.ttl[0] if isinstance(self.ttl, list) else self.ttl + ) + + if self.settings: + physical_properties.update({k: exp.var(v) for k, v in self.settings.items()}) + + if physical_properties: + model_kwargs["physical_properties"] = physical_properties + + kind = self.model_kind(context) + + # A falsy grants config (None or {}) is considered as unmanaged per dbt semantics + if self.grants and kind.supports_grants: + model_kwargs["grants"] = self.grants + + allow_partials = model_kwargs.pop("allow_partials", None) + if allow_partials is None: + # Set allow_partials to True for dbt models to preserve the original semantics. 
+ allow_partials = True + + # pop begin for all models so we don't pass it through for non-incremental materializations + # (happens if model config is microbatch but project config overrides) + begin = model_kwargs.pop("begin", None) + if kind.is_incremental: + if self.batch_size and isinstance(self.batch_size, str): + if "interval_unit" in model_kwargs: + get_console().log_warning( + f"Both 'interval_unit' and 'batch_size' are set for model '{self.canonical_name(context)}'. 'interval_unit' will be used." + ) + else: + model_kwargs["interval_unit"] = self.batch_size + self.batch_size = None + if begin: + if "start" in model_kwargs: + get_console().log_warning( + f"Both 'begin' and 'start' are set for model '{self.canonical_name(context)}'. 'start' will be used." + ) + else: + model_kwargs["start"] = begin + # If user explicitly disables concurrent batches then we want to set depends on past to true which we + # will do by including the model in the depends_on + if self.concurrent_batches is not None and self.concurrent_batches is False: + depends_on = model_kwargs.get("depends_on", set()) + depends_on.add(self.canonical_name(context)) + + model_kwargs["start"] = model_kwargs.get( + "start", context.sqlmesh_config.model_defaults.start + ) - return create_sql_model( + model = create_sql_model( self.canonical_name(context), query, dialect=model_dialect, - kind=self.model_kind(context), - start=self.start, + kind=kind, + audit_definitions=audit_definitions, + # This ensures that we bypass query rendering that would otherwise be required to extract additional + # dependencies from the model's SQL. + # Note: any table dependencies that are not referenced using the `ref` macro will not be included. 
+ extract_dependencies_from_query=False, + allow_partials=allow_partials, + virtual_environment_mode=virtual_environment_mode, + dbt_node_info=self.node_info, **optional_kwargs, **model_kwargs, ) + return model def _dbt_max_partition_blob(self) -> t.Optional[str]: """Returns a SQL blob which declares the _dbt_max_partition variable. Only applicable to BigQuery.""" @@ -428,7 +789,7 @@ def _dbt_max_partition_blob(self) -> t.Optional[str]: "{{ adapter.resolve_identifier(this) }}", data_type, granularity=self.partition_by.get("granularity"), - database="{{ target.database }}", + catalog="{{ target.database }}", ) data_type = data_type.upper() diff --git a/sqlmesh/dbt/package.py b/sqlmesh/dbt/package.py index d67bc4a508..dbaa832c22 100644 --- a/sqlmesh/dbt/package.py +++ b/sqlmesh/dbt/package.py @@ -28,6 +28,25 @@ class MacroConfig(PydanticModel): path: Path +class HookConfig(PydanticModel): + """Class to contain on run start / on run end hooks.""" + + sql: str + index: int + path: Path + dependencies: Dependencies + + +class MaterializationConfig(PydanticModel): + """Class to contain custom materialization configuration.""" + + name: str + adapter: str + definition: str + dependencies: Dependencies + path: Path + + class Package(PydanticModel): """Class to contain package configuration""" @@ -38,6 +57,9 @@ class Package(PydanticModel): models: t.Dict[str, ModelConfig] variables: t.Dict[str, t.Any] macros: t.Dict[str, MacroConfig] + materializations: t.Dict[str, MaterializationConfig] + on_run_start: t.Dict[str, HookConfig] + on_run_end: t.Dict[str, HookConfig] files: t.Set[Path] @property @@ -83,6 +105,9 @@ def load(self, package_root: Path) -> Package: models = _fix_paths(self._context.manifest.models(package_name), package_root) seeds = _fix_paths(self._context.manifest.seeds(package_name), package_root) macros = _fix_paths(self._context.manifest.macros(package_name), package_root) + materializations = _fix_paths(self._context.manifest.materializations(), 
package_root) + on_run_start = _fix_paths(self._context.manifest.on_run_start(package_name), package_root) + on_run_end = _fix_paths(self._context.manifest.on_run_end(package_name), package_root) sources = self._context.manifest.sources(package_name) config_paths = { @@ -101,11 +126,16 @@ def load(self, package_root: Path) -> Package: seeds=seeds, variables=package_variables, macros=macros, + materializations=materializations, files=config_paths, + on_run_start=on_run_start, + on_run_end=on_run_end, ) -T = t.TypeVar("T", TestConfig, ModelConfig, MacroConfig, SeedConfig) +T = t.TypeVar( + "T", TestConfig, ModelConfig, MacroConfig, MaterializationConfig, SeedConfig, HookConfig +) def _fix_paths(configs: t.Dict[str, T], package_root: Path) -> t.Dict[str, T]: diff --git a/sqlmesh/dbt/profile.py b/sqlmesh/dbt/profile.py index 1c2ffa8726..a95c81501c 100644 --- a/sqlmesh/dbt/profile.py +++ b/sqlmesh/dbt/profile.py @@ -60,7 +60,7 @@ def load(cls, context: DbtContext, target_name: t.Optional[str] = None) -> Profi if not context.profile_name: raise ConfigError(f"{project_file.stem} must include project name.") - profile_filepath = cls._find_profile(context.project_root) + profile_filepath = cls._find_profile(context.project_root, context.profiles_dir) if not profile_filepath: raise ConfigError(f"{cls.PROFILE_FILE} not found.") @@ -68,12 +68,12 @@ def load(cls, context: DbtContext, target_name: t.Optional[str] = None) -> Profi return Profile(profile_filepath, target_name, target) @classmethod - def _find_profile(cls, project_root: Path) -> t.Optional[Path]: - dir = os.environ.get("DBT_PROFILES_DIR", "") + def _find_profile(cls, project_root: Path, profiles_dir: t.Optional[Path]) -> t.Optional[Path]: + dir = os.environ.get("DBT_PROFILES_DIR", profiles_dir or "") path = Path(project_root, dir, cls.PROFILE_FILE) if path.exists(): return path - elif dir: + if dir: return None path = Path(Path.home(), ".dbt", cls.PROFILE_FILE) @@ -101,8 +101,10 @@ def _read_profile( target_name = 
context.render(project_data.get("target")) if target_name not in outputs: + target_names = "\n".join(f"- {name}" for name in outputs) raise ConfigError( - f"Target '{target_name}' not specified in profiles for '{context.profile_name}'." + f"Target '{target_name}' not specified in profiles for '{context.profile_name}'. " + f"The valid target names for this profile are:\n{target_names}" ) target_fields = load_yaml(context.render(yaml.dump(outputs[target_name]))) diff --git a/sqlmesh/dbt/project.py b/sqlmesh/dbt/project.py index 8e3c255d05..2b0a2e0c3f 100644 --- a/sqlmesh/dbt/project.py +++ b/sqlmesh/dbt/project.py @@ -1,9 +1,10 @@ from __future__ import annotations -import logging import typing as t +import logging from pathlib import Path +from sqlmesh.core.console import get_console from sqlmesh.dbt.common import PROJECT_FILENAME, load_yaml from sqlmesh.dbt.context import DbtContext from sqlmesh.dbt.manifest import ManifestHelper @@ -54,12 +55,6 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N raise ConfigError(f"Could not find {PROJECT_FILENAME} in {context.project_root}") project_yaml = load_yaml(project_file_path) - variable_overrides = variables - variables = {**project_yaml.get("vars", {}), **(variables or {})} - global_variables = { - name: var for name, var in variables.items() if not isinstance(var, dict) - } - project_name = context.render(project_yaml.get("name", "")) context.project_name = project_name if not context.project_name: @@ -71,19 +66,22 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N profile = Profile.load(context, context.target_name) context.target = profile.target + variable_overrides = variables or {} context.manifest = ManifestHelper( project_file_path.parent, profile.path.parent, profile_name, target=profile.target, variable_overrides=variable_overrides, + cache_dir=context.sqlmesh_config.cache_dir, + model_defaults=context.sqlmesh_config.model_defaults, ) 
extra_fields = profile.target.extra if extra_fields: extra_str = ",".join(f"'{field}'" for field in extra_fields) - logger.warning( - "%s adapter does not currently support %s", profile.target.type, extra_str + get_console().log_warning( + f"{profile.target.type} adapter does not currently support {extra_str}." ) packages = {} @@ -101,13 +99,22 @@ def load(cls, context: DbtContext, variables: t.Optional[t.Dict[str, t.Any]] = N package = package_loader.load(path.parent) packages[package.name] = package + # Variable resolution precedence: + # 1. Variable overrides + # 2. Package-scoped variables in the root project's dbt_project.yml + # 3. Global project variables in the root project's dbt_project.yml + # 4. Variables in the package's dbt_project.yml + all_project_variables = {**(project_yaml.get("vars") or {}), **(variable_overrides or {})} for name, package in packages.items(): - package_vars = variables.get(name) - - if isinstance(package_vars, dict): - package.variables.update(package_vars) - - package.variables.update(global_variables) + if isinstance(all_project_variables.get(name), dict): + project_vars_copy = all_project_variables.copy() + package_scoped_vars = project_vars_copy.pop(name) + package.variables.update(project_vars_copy) + package.variables.update(package_scoped_vars) + else: + package.variables.update(all_project_variables) + if variable_overrides: + package.variables.update(variable_overrides) return Project(context, profile, packages) diff --git a/sqlmesh/dbt/relation.py b/sqlmesh/dbt/relation.py index 9d07db8bc6..fff9f75593 100644 --- a/sqlmesh/dbt/relation.py +++ b/sqlmesh/dbt/relation.py @@ -1,7 +1,7 @@ from sqlmesh.dbt.util import DBT_VERSION -if DBT_VERSION < (1, 8): - from dbt.contracts.relation import * # type: ignore # noqa: F403 -else: +if DBT_VERSION >= (1, 8, 0): from dbt.adapters.contracts.relation import * # type: ignore # noqa: F403 +else: + from dbt.contracts.relation import * # type: ignore # noqa: F403 diff --git 
a/sqlmesh/dbt/seed.py b/sqlmesh/dbt/seed.py index 50a1ec9d7c..c0c8186f29 100644 --- a/sqlmesh/dbt/seed.py +++ b/sqlmesh/dbt/seed.py @@ -4,20 +4,26 @@ import agate -try: +from sqlmesh.dbt.util import DBT_VERSION + +if DBT_VERSION >= (1, 8, 0): from dbt_common.clients import agate_helper # type: ignore SUPPORTS_DELIMITER = True -except ImportError: +else: from dbt.clients import agate_helper # type: ignore SUPPORTS_DELIMITER = False from sqlglot import exp +from sqlmesh.core.config.common import VirtualEnvironmentMode from sqlmesh.core.model import Model, SeedKind, create_seed_model +from sqlmesh.core.model.seed import CsvSettings from sqlmesh.dbt.basemodel import BaseModelConfig +from sqlmesh.dbt.column import ColumnConfig if t.TYPE_CHECKING: + from sqlmesh.core.audit.definition import ModelAudit from sqlmesh.dbt.context import DbtContext @@ -31,50 +37,67 @@ class SeedConfig(BaseModelConfig): """ delimiter: str = "," - - def to_sqlmesh(self, context: DbtContext) -> Model: + column_types: t.Optional[t.Dict[str, str]] = None + quote_columns: t.Optional[bool] = False + + def to_sqlmesh( + self, + context: DbtContext, + audit_definitions: t.Optional[t.Dict[str, ModelAudit]] = None, + virtual_environment_mode: VirtualEnvironmentMode = VirtualEnvironmentMode.default, + ) -> Model: """Converts the dbt seed into a SQLMesh model.""" seed_path = self.path.absolute().as_posix() - kwargs = self.sqlmesh_model_kwargs(context) - if kwargs.get("columns") is None: - agate_table = ( - agate_helper.from_csv(seed_path, [], delimiter=self.delimiter) - if SUPPORTS_DELIMITER - else agate_helper.from_csv(seed_path, []) - ) - kwargs["columns"] = { - name: AGATE_TYPE_MAPPING[tpe.__class__] - for name, tpe in zip(agate_table.column_names, agate_table.column_types) - } + + column_types_override = { + name: ColumnConfig(name=name, data_type=data_type, quote=self.quote_columns) + for name, data_type in (self.column_types or {}).items() + } + kwargs = self.sqlmesh_model_kwargs(context, 
column_types_override) + + columns = kwargs.get("columns") or {} + + agate_table = ( + agate_helper.from_csv(seed_path, [], delimiter=self.delimiter) + if SUPPORTS_DELIMITER + else agate_helper.from_csv(seed_path, []) + ) + inferred_types = { + name: AGATE_TYPE_MAPPING[tpe.__class__] + for name, tpe in zip(agate_table.column_names, agate_table.column_types) + } + + # The columns list built from the mixture of supplied and inferred types needs to + # be in the same order as the data for assumptions elsewhere in the codebase to hold true + new_columns = {} + for column_name in agate_table.column_names: + if column_name not in columns: + new_columns[column_name] = inferred_types[column_name] + else: + new_columns[column_name] = columns[column_name] + + kwargs["columns"] = new_columns + + # dbt treats single whitespace as a null value + csv_settings = CsvSettings( + delimiter=self.delimiter, + na_values=[" "], + keep_default_na=True, + ) return create_seed_model( self.canonical_name(context), - SeedKind(path=seed_path), + SeedKind(path=seed_path, csv_settings=csv_settings), dialect=self.dialect(context), + audit_definitions=audit_definitions, + virtual_environment_mode=virtual_environment_mode, + start=self.start or context.sqlmesh_config.model_defaults.start, + dbt_node_info=self.node_info, **kwargs, ) -class Integer(agate.data_types.DataType): - def cast(self, d: str) -> t.Optional[int]: - if d is None: - return d - try: - return int(d) - except ValueError: - raise agate.exceptions.CastError('Can not parse value "%s" as Integer.' % d) - - def jsonify(self, d: str) -> str: - return d - - -# The dbt version has a bug in which they check whether the type of the input value -# is int, while the input value is actually always a string. 
-agate_helper.Integer = Integer # type: ignore - - AGATE_TYPE_MAPPING = { - agate_helper.Integer: exp.DataType.build("int"), agate_helper.Number: exp.DataType.build("double"), agate_helper.ISODateTime: exp.DataType.build("datetime"), agate.Date: exp.DataType.build("date"), @@ -82,3 +105,25 @@ def jsonify(self, d: str) -> str: agate.Boolean: exp.DataType.build("boolean"), agate.Text: exp.DataType.build("text"), } + + +if DBT_VERSION >= (1, 7, 0): + + class Integer(agate_helper.Integer): + def cast(self, d: t.Any) -> t.Optional[int]: + if isinstance(d, str): + # The dbt's implementation doesn't support coercion of strings to integers. + if d.strip().lower() in self.null_values: + return None + try: + return int(d) + except ValueError: + raise agate.exceptions.CastError('Can not parse value "%s" as Integer.' % d) + return super().cast(d) + + def jsonify(self, d: t.Any) -> str: + return d + + agate_helper.Integer = Integer # type: ignore + + AGATE_TYPE_MAPPING[agate_helper.Integer] = exp.DataType.build("int") diff --git a/sqlmesh/dbt/source.py b/sqlmesh/dbt/source.py index 39651ce7f4..832ed0e156 100644 --- a/sqlmesh/dbt/source.py +++ b/sqlmesh/dbt/source.py @@ -8,6 +8,7 @@ from sqlmesh.dbt.column import ColumnConfig from sqlmesh.dbt.common import GeneralConfig from sqlmesh.dbt.relation import RelationType +from sqlmesh.dbt.util import DBT_VERSION from sqlmesh.utils import AttributeDict from sqlmesh.utils.errors import ConfigError @@ -35,6 +36,7 @@ class SourceConfig(GeneralConfig): # DBT configuration fields name: str = "" source_name_: str = Field("", alias="source_name") + fqn_: t.List[str] = Field(default_factory=list, alias="fqn") database: t.Optional[str] = None schema_: t.Optional[str] = Field(None, alias="schema") identifier: t.Optional[str] = None @@ -44,7 +46,9 @@ class SourceConfig(GeneralConfig): loaded_at_field: t.Optional[str] = None quoting: t.Dict[str, t.Optional[bool]] = {} external: t.Optional[t.Dict[str, t.Any]] = {} + source_meta: 
t.Optional[t.Dict[str, t.Any]] = {} columns: t.Dict[str, ColumnConfig] = {} + event_time: t.Optional[str] = None _canonical_name: t.Optional[str] = None @@ -61,6 +65,10 @@ def table_name(self) -> t.Optional[str]: def config_name(self) -> str: return f"{self.source_name_}.{self.name}" + @property + def fqn(self) -> str: + return ".".join(self.fqn_) + def canonical_name(self, context: DbtContext) -> str: if self._canonical_name is None: source = context.get_callable_macro("source") @@ -71,7 +79,7 @@ def canonical_name(self, context: DbtContext) -> str: relation = source(self.source_name_, self.name) except Exception as e: raise ConfigError( - f"'source' macro failed for '{self.config_name}' with exeception '{e}'." + f"'source' macro failed for '{self.config_name}' with exception '{e}'." ) relation = relation.quote( @@ -86,6 +94,18 @@ def canonical_name(self, context: DbtContext) -> str: @property def relation_info(self) -> AttributeDict: + extras = {} + external_location = ( + self.source_meta.get("external_location", None) if self.source_meta else None + ) + if external_location: + extras["external"] = external_location.replace("{name}", self.table_name) + + if DBT_VERSION >= (1, 9, 0) and self.event_time: + extras["event_time_filter"] = { + "field_name": self.event_time, + } + return AttributeDict( { "database": self.database, @@ -93,5 +113,6 @@ def relation_info(self) -> AttributeDict: "identifier": self.table_name, "type": RelationType.External.value, "quote_policy": AttributeDict(self.quoting), + **extras, } ) diff --git a/sqlmesh/dbt/target.py b/sqlmesh/dbt/target.py index 69bc5eb33f..62683ecfac 100644 --- a/sqlmesh/dbt/target.py +++ b/sqlmesh/dbt/target.py @@ -1,18 +1,19 @@ from __future__ import annotations import abc -import logging -import sys import typing as t from pathlib import Path from dbt.adapters.base import BaseRelation, Column -from pydantic import Field +from pydantic import Field, AliasChoices +from sqlmesh.core.console import get_console from 
sqlmesh.core.config.connection import ( + AthenaConnectionConfig, BigQueryConnectionConfig, BigQueryConnectionMethod, BigQueryPriority, + ClickhouseConnectionConfig, ConnectionConfig, DatabricksConnectionConfig, DuckDBConnectionConfig, @@ -28,23 +29,13 @@ IncrementalByUniqueKeyKind, IncrementalUnmanagedKind, ) +from sqlmesh.core.schema_diff import NestedSupport from sqlmesh.dbt.common import DbtConfig from sqlmesh.dbt.relation import Policy from sqlmesh.dbt.util import DBT_VERSION from sqlmesh.utils import AttributeDict, classproperty from sqlmesh.utils.errors import ConfigError -from sqlmesh.utils.pydantic import ( - field_validator, - model_validator, - model_validator_v1_args, -) - -if sys.version_info >= (3, 9): - from typing import Literal -else: - from typing_extensions import Literal - -logger = logging.getLogger(__name__) +from sqlmesh.utils.pydantic import field_validator, model_validator IncrementalKind = t.Union[ t.Type[IncrementalByUniqueKeyKind], @@ -54,12 +45,45 @@ # We only serialize a subset of fields in order to avoid persisting sensitive information SERIALIZABLE_FIELDS = { - "type", + # core "name", - "database", "schema_", + "type", + "threads", + # snowflake + "database", + "warehouse", + "user", + "role", + "account", + # postgres/redshift + "dbname", + "host", + "port", + # bigquery + "project", + "dataset", } +SCHEMA_DIFFER_OVERRIDES = { + "schema_differ_overrides": { + "treat_alter_data_type_as_destructive": True, + "nested_support": NestedSupport.IGNORE, + } +} + + +def with_schema_differ_overrides( + func: t.Callable[..., ConnectionConfig], +) -> t.Callable[..., ConnectionConfig]: + """Decorator that merges default config with kwargs.""" + + def wrapper(self: TargetConfig, **kwargs: t.Any) -> ConnectionConfig: + merged_kwargs = {**SCHEMA_DIFFER_OVERRIDES, **kwargs} + return func(self, **merged_kwargs) + + return wrapper + class TargetConfig(abc.ABC, DbtConfig): """ @@ -95,20 +119,24 @@ def load(cls, data: t.Dict[str, t.Any]) -> 
TargetConfig: db_type = data["type"] if db_type == "databricks": return DatabricksConfig(**data) - elif db_type == "duckdb": + if db_type == "duckdb": return DuckDbConfig(**data) - elif db_type == "postgres": + if db_type == "postgres": return PostgresConfig(**data) - elif db_type == "redshift": + if db_type == "redshift": return RedshiftConfig(**data) - elif db_type == "snowflake": + if db_type == "snowflake": return SnowflakeConfig(**data) - elif db_type == "bigquery": + if db_type == "bigquery": return BigQueryConfig(**data) - elif db_type == "sqlserver": + if db_type == "sqlserver": return MSSQLConfig(**data) - elif db_type == "trino": + if db_type == "trino": return TrinoConfig(**data) + if db_type == "clickhouse": + return ClickhouseConfig(**data) + if db_type == "athena": + return AthenaConfig(**data) raise ConfigError(f"{db_type} not supported.") @@ -116,6 +144,7 @@ def default_incremental_strategy(self, kind: IncrementalKind) -> str: """The default incremental strategy for the db""" raise NotImplementedError + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: """Converts target config to SQLMesh connection config""" raise NotImplementedError @@ -157,30 +186,36 @@ class DuckDbConfig(TargetConfig): path: Location of the database file. If not specified, an in memory database is used. extensions: A list of autoloadable extensions to load. settings: A dictionary of settings to pass into the duckdb connector. + secrets: A list of secrets to pass to the secret manager in the duckdb connector. + filesystems: A list of `fsspec` filesystems to register in the duckdb connection. 
""" - type: Literal["duckdb"] = "duckdb" + type: t.Literal["duckdb"] = "duckdb" database: str = "main" schema_: str = Field(default="main", alias="schema") path: str = DUCKDB_IN_MEMORY extensions: t.Optional[t.List[str]] = None settings: t.Optional[t.Dict[str, t.Any]] = None + secrets: t.Optional[t.List[t.Dict[str, t.Any]]] = None + filesystems: t.Optional[t.List[t.Dict[str, t.Any]]] = None @model_validator(mode="before") - @model_validator_v1_args - def validate_authentication( - cls, values: t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]] - ) -> t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]]: - if "database" not in values and DBT_VERSION >= (1, 5): - path = values.get("path") - values["database"] = ( + def validate_authentication(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + if "database" not in data and DBT_VERSION >= (1, 5, 0): + path = data.get("path") + data["database"] = ( "memory" if path is None or path == DUCKDB_IN_MEMORY else Path(t.cast(str, path)).stem ) - if "threads" in values and t.cast(int, values["threads"]) > 1: - logger.warning("DuckDB does not support concurrency - setting threads to 1.") - return values + + if "threads" in data and t.cast(int, data["threads"]) > 1: + get_console().log_warning("DuckDB does not support concurrency - setting threads to 1.") + + return data def default_incremental_strategy(self, kind: IncrementalKind) -> str: return "delete+insert" @@ -191,11 +226,16 @@ def relation_class(cls) -> t.Type[BaseRelation]: return DuckDBRelation + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: if self.extensions is not None: kwargs["extensions"] = self.extensions if self.settings is not None: kwargs["connector_config"] = self.settings + if self.secrets is not None: + kwargs["secrets"] = self.secrets + if self.filesystems is not None: + kwargs["filesystems"] = self.filesystems return 
DuckDBConnectionConfig( database=self.path, concurrent_tasks=1, @@ -226,7 +266,7 @@ class SnowflakeConfig(TargetConfig): token: OAuth authentication: The Snowflake OAuth 2.0 access token """ - type: Literal["snowflake"] = "snowflake" + type: t.Literal["snowflake"] = "snowflake" account: str user: str @@ -257,17 +297,16 @@ class SnowflakeConfig(TargetConfig): retry_all: bool = False @model_validator(mode="before") - @model_validator_v1_args - def validate_authentication( - cls, values: t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]] - ) -> t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]]: - if ( - values.get("password") - or values.get("authenticator") - or values.get("private_key") - or values.get("private_key_path") + @classmethod + def validate_authentication(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict) or ( + data.get("password") + or data.get("authenticator") + or data.get("private_key") + or data.get("private_key_path") ): - return values + return data + raise ConfigError("No supported Snowflake authentication method found in target profile.") def default_incremental_strategy(self, kind: IncrementalKind) -> str: @@ -285,6 +324,7 @@ def column_class(cls) -> t.Type[Column]: return SnowflakeColumn + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: return SnowflakeConnectionConfig( user=self.user, @@ -325,10 +365,10 @@ class PostgresConfig(TargetConfig): sslmode: SSL Mode used to connect to the database """ - type: Literal["postgres"] = "postgres" + type: t.Literal["postgres"] = "postgres" host: str user: str - password: str + password: str = Field(validation_alias=AliasChoices("pass", "password")) port: int dbname: str keepalives_idle: t.Optional[int] = None @@ -339,14 +379,16 @@ class PostgresConfig(TargetConfig): sslmode: t.Optional[str] = None @model_validator(mode="before") - @model_validator_v1_args - def validate_database( - cls, values: 
t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]] - ) -> t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]]: - values["database"] = values.get("database") or values.get("dbname") - if not values["database"]: + @classmethod + def validate_database(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + data["database"] = data.get("database") or data.get("dbname") + if not data["database"]: raise ConfigError("Either database or dbname must be set") - return values + + return data @field_validator("port") @classmethod @@ -356,6 +398,7 @@ def _validate_port(cls, v: t.Union[int, str]) -> int: def default_incremental_strategy(self, kind: IncrementalKind) -> str: return "delete+insert" if kind is IncrementalByUniqueKeyKind else "append" + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: return PostgresConnectionConfig( host=self.host, @@ -389,10 +432,10 @@ class RedshiftConfig(TargetConfig): """ # TODO add other forms of authentication - type: Literal["redshift"] = "redshift" + type: t.Literal["redshift"] = "redshift" host: str user: str - password: str + password: str = Field(validation_alias=AliasChoices("pass", "password")) port: int dbname: str connect_timeout: t.Optional[int] = None @@ -401,14 +444,16 @@ class RedshiftConfig(TargetConfig): sslmode: t.Optional[str] = None @model_validator(mode="before") - @model_validator_v1_args - def validate_database( - cls, values: t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]] - ) -> t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]]: - values["database"] = values.get("database") or values.get("dbname") - if not values["database"]: + @classmethod + def validate_database(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + data["database"] = data.get("database") or data.get("dbname") + if not data["database"]: raise ConfigError("Either 
database or dbname must be set") - return values + + return data def default_incremental_strategy(self, kind: IncrementalKind) -> str: return "append" @@ -421,13 +466,13 @@ def relation_class(cls) -> t.Type[BaseRelation]: @classproperty def column_class(cls) -> t.Type[Column]: - if DBT_VERSION < (1, 6): + if DBT_VERSION < (1, 6, 0): from dbt.adapters.redshift import RedshiftColumn # type: ignore return RedshiftColumn - else: - return super(RedshiftConfig, cls).column_class + return super(RedshiftConfig, cls).column_class + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: return RedshiftConnectionConfig( user=self.user, @@ -454,11 +499,14 @@ class DatabricksConfig(TargetConfig): database: Name of the database. Not applicable for Databricks and ignored """ - type: Literal["databricks"] = "databricks" + type: t.Literal["databricks"] = "databricks" host: str http_path: str - token: str + token: t.Optional[str] = None # only required if auth_type is not set to 'oauth' database: t.Optional[str] = Field(alias="catalog") # type: ignore + auth_type: t.Optional[str] = None + client_id: t.Optional[str] = None + client_secret: t.Optional[str] = None def default_incremental_strategy(self, kind: IncrementalKind) -> str: return "merge" @@ -475,6 +523,7 @@ def column_class(cls) -> t.Type[Column]: return DatabricksColumn + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: return DatabricksConnectionConfig( server_hostname=self.host, @@ -482,6 +531,9 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: access_token=self.token, concurrent_tasks=self.threads, catalog=self.database, + auth_type="databricks-oauth" if self.auth_type == "oauth" else self.auth_type, + oauth_client_id=self.client_id, + oauth_client_secret=self.client_secret, **kwargs, ) @@ -503,6 +555,8 @@ class BigQueryConfig(TargetConfig): client_secret: The BigQuery client secret token_uri: The BigQuery token URI scopes: The BigQuery 
scopes + impersonated_service_account: The service account to impersonate + job_creation_timeout_seconds: The maximum amount of time, in seconds, to wait for the underlying job to be created job_execution_timeout_seconds: The maximum amount of time, in seconds, to wait for the underlying job to complete timeout_seconds: Alias for job_execution_timeout_seconds job_retries: The number of times to retry the underlying job if it fails @@ -512,11 +566,12 @@ class BigQueryConfig(TargetConfig): maximum_bytes_billed: The maximum number of bytes to be billed for the underlying job """ - type: Literal["bigquery"] = "bigquery" + type: t.Literal["bigquery"] = "bigquery" method: t.Optional[str] = BigQueryConnectionMethod.OAUTH dataset: t.Optional[str] = None project: t.Optional[str] = None execution_project: t.Optional[str] = None + quota_project: t.Optional[str] = None location: t.Optional[str] = None keyfile: t.Optional[str] = None keyfile_json: t.Optional[t.Dict[str, t.Any]] = None @@ -530,6 +585,8 @@ class BigQueryConfig(TargetConfig): "https://www.googleapis.com/auth/cloud-platform", "https://www.googleapis.com/auth/drive", ) + impersonated_service_account: t.Optional[str] = None + job_creation_timeout_seconds: t.Optional[int] = None job_execution_timeout_seconds: t.Optional[int] = None timeout_seconds: t.Optional[int] = None # To support legacy config job_retries: t.Optional[int] = None @@ -539,17 +596,24 @@ class BigQueryConfig(TargetConfig): maximum_bytes_billed: t.Optional[int] = None @model_validator(mode="before") - @model_validator_v1_args - def validate_fields( - cls, values: t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]] - ) -> t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]]: - values["schema"] = values.get("schema") or values.get("dataset") - if not values["schema"]: + @classmethod + def validate_fields(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + # dbt treats schema and 
dataset interchangeably + schema = data.get("schema") or data.get("dataset") + if not schema: raise ConfigError("Either schema or dataset must be set") - values["database"] = values.get("database") or values.get("project") - if not values["database"]: + data["dataset"] = data["schema"] = schema + + # dbt treats database and project interchangeably + database = data.get("database") or data.get("project") + if not database: raise ConfigError("Either database or project must be set") - return values + data["database"] = data["project"] = database + + return data def default_incremental_strategy(self, kind: IncrementalKind) -> str: return "merge" @@ -566,6 +630,7 @@ def column_class(cls) -> t.Type[Column]: return BigQueryColumn + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: job_retries = self.job_retries if self.job_retries is not None else self.retries job_execution_timeout_seconds = ( @@ -577,6 +642,7 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: method=self.method, project=self.database, execution_project=self.execution_project, + quota_project=self.quota_project, location=self.location, concurrent_tasks=self.threads, keyfile=self.keyfile, @@ -587,6 +653,8 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: client_secret=self.client_secret, token_uri=self.token_uri, scopes=self.scopes, + impersonated_service_account=self.impersonated_service_account, + job_creation_timeout_seconds=self.job_creation_timeout_seconds, job_execution_timeout_seconds=job_execution_timeout_seconds, job_retries=job_retries, job_retry_deadline_seconds=self.job_retry_deadline_seconds, @@ -623,7 +691,7 @@ class MSSQLConfig(TargetConfig): client_secret: The client secret of the Azure Active Directory service principal, not used by SQLMesh """ - type: Literal["sqlserver"] = "sqlserver" + type: t.Literal["sqlserver"] = "sqlserver" host: t.Optional[str] = None server: t.Optional[str] = None port: int = 1433 @@ -653,23 +721,24 @@ 
class MSSQLConfig(TargetConfig): client_secret: t.Optional[str] = None # Azure Active Directory auth @model_validator(mode="before") - @model_validator_v1_args - def validate_alias_fields( - cls, values: t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]] - ) -> t.Dict[str, t.Union[t.Tuple[str, ...], t.Optional[str], t.Dict[str, t.Any]]]: - values["host"] = values.get("host") or values.get("server") - if not values["host"]: + @classmethod + def validate_alias_fields(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + data["host"] = data.get("host") or data.get("server") + if not data["host"]: raise ConfigError("Either host or server must be set") - values["user"] = values.get("user") or values.get("username") or values.get("UID") - if not values["user"]: + data["user"] = data.get("user") or data.get("username") or data.get("UID") + if not data["user"]: raise ConfigError("One of user, username, or UID must be set") - values["password"] = values.get("password") or values.get("PWD") - if not values["password"]: + data["password"] = data.get("password") or data.get("PWD") + if not data["password"]: raise ConfigError("Either password or PWD must be set") - return values + return data @field_validator("authentication") @classmethod @@ -689,7 +758,12 @@ def default_incremental_strategy(self, kind: IncrementalKind) -> str: @classproperty def column_class(cls) -> t.Type[Column]: - from dbt.adapters.sqlserver.sql_server_column import SQLServerColumn + try: + # 1.8.0+ + from dbt.adapters.sqlserver.sqlserver_column import SQLServerColumn + except ImportError: + # <1.8.0 + from dbt.adapters.sqlserver.sql_server_column import SQLServerColumn # type: ignore return SQLServerColumn @@ -697,6 +771,7 @@ def column_class(cls) -> t.Type[Column]: def dialect(self) -> str: return "tsql" + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: return MSSQLConnectionConfig( host=self.host, @@ -756,7 
+831,7 @@ class TrinoConfig(TargetConfig): "oauth_console": TrinoAuthenticationMethod.OAUTH, } - type: Literal["trino"] = "trino" + type: t.Literal["trino"] = "trino" host: str database: str schema_: str = Field(alias="schema") @@ -811,6 +886,7 @@ def column_class(cls) -> t.Type[Column]: return TrinoColumn + @with_schema_differ_overrides def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: return TrinoConnectionConfig( method=self._method_to_auth_enum[self.method], @@ -844,6 +920,184 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: ) +class ClickhouseConfig(TargetConfig): + """ + Project connection and operational configuration for the Clickhouse target + + Args: + host: [localhost] + user: [default] # User for all database operations + password: [] # Password for the user + secure: [False] # Use TLS (native protocol) or HTTPS (http protocol) + port: [8123] # If not set, defaults to 8123, 8443 depending on the secure and driver settings + connect_timeout: [10] # Timeout in seconds to establish a connection to ClickHouse + send_receive_timeout: [300] # Timeout in seconds to receive data from the ClickHouse server + verify: [True] # Validate TLS certificate if using TLS/SSL + cluster: [] # If set, certain DDL/table operations will be executed with the `ON CLUSTER` clause using this cluster. + custom_settings: [{}] # A dictionary/mapping of custom ClickHouse settings for the connection - default is empty. + schema: [default] # ClickHouse database for dbt models, not used by SQLMesh + driver: [http] # http or native. 
If not set this will be autodetermined based on port setting, not used by SQLMesh + retries: [1] # Number of times to retry a "retriable" database exception (such as a 503 'Service Unavailable' error), not used by SQLMesh + compression: [] # Use gzip compression if truthy (http), or compression type for a native connection, not used by SQLMesh + cluster_mode: [False] # Use specific settings designed to improve operation on Replicated databases (recommended for ClickHouse Cloud), not used by SQLMesh + use_lw_deletes: [False] # Use the strategy `delete+insert` as the default incremental strategy, not used by SQLMesh + check_exchange: [True] # Validate that clickhouse support the atomic EXCHANGE TABLES command. Not used by SQLMesh. + local_suffix: [_local] # Table suffix of local tables on shards for distributed materializations, not used by SQLMesh + local_db_prefix: [] # Database prefix of local tables on shards for distributed materializations, not used by SQLMesh + allow_automatic_deduplication: [False] # Enable ClickHouse automatic deduplication for Replicated tables, not used by SQLMesh + tcp_keepalive: [False] # Native client only, specify TCP keepalive configuration. 
Specify custom keepalive settings as [idle_time_sec, interval_sec, probes], not used by SQLMesh + sync_request_timeout: [5] # Timeout for server ping, not used by SQLMesh + compress_block_size: [1048576] # Compression block size if compression is enabled, not used by SQLMesh + """ + + host: str = "localhost" + user: str = Field(default="default", alias="username") + password: str = "" + port: t.Optional[int] = None + cluster: t.Optional[str] = None + schema_: str = Field(default="default", alias="schema") + connect_timeout: int = 10 + send_receive_timeout: int = 300 + verify: bool = True + compression: str = "" + custom_settings: t.Optional[t.Dict[str, t.Any]] = None + + # Not used by SQLMesh + driver: t.Optional[str] = None + secure: bool = False + retries: int = 1 + database_engine: t.Optional[str] = None + cluster_mode: bool = False + sync_request_timeout: int = 5 + compress_block_size: int = 1048576 + check_exchange: bool = True + use_lw_deletes: bool = False + allow_automatic_deduplication: bool = False + tcp_keepalive: t.Union[bool, t.Tuple[int, ...], t.List[int]] = False + database: str = "" + local_suffix: str = "local" + local_db_prefix: str = "" + + type: t.Literal["clickhouse"] = "clickhouse" + + def default_incremental_strategy(self, kind: IncrementalKind) -> str: + # dbt-clickhouse name for temp table swap. That is sqlmesh's default + # strategy so doesn't require special handling during conversion. 
+ return "legacy" + + @classproperty + def relation_class(cls) -> t.Type[BaseRelation]: + from dbt.adapters.clickhouse.relation import ClickHouseRelation + + return ClickHouseRelation + + @classproperty + def column_class(cls) -> t.Type[Column]: + from dbt.adapters.clickhouse.column import ClickHouseColumn + + return ClickHouseColumn + + @with_schema_differ_overrides + def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: + return ClickhouseConnectionConfig( + host=self.host, + username=self.user, + password=self.password, + port=self.port, + cluster=self.cluster, + connect_timeout=self.connect_timeout, + send_receive_timeout=self.send_receive_timeout, + verify=self.verify, + compression_method=self.compression, + connection_settings=self.custom_settings, + **kwargs, + ) + + +class AthenaConfig(TargetConfig): + """ + Project connection and operational configuration for the Athena target. + + Args: + s3_staging_dir: S3 location to store Athena query results and metadata + s3_data_dir: Prefix for storing tables, if different from the connection's s3_staging_dir + s3_data_naming: How to generate table paths in s3_data_dir + s3_tmp_table_dir: Prefix for storing temporary tables, if different from the connection's s3_data_dir + region_name: AWS region of your Athena instance + schema: Specify the schema (Athena database) to build models into (lowercase only) + database: Specify the database (Data catalog) to build models into (lowercase only) + poll_interval: Interval in seconds to use for polling the status of query results in Athena + debug_query_state: Flag if debug message with Athena query state is needed + aws_access_key_id: Access key ID of the user performing requests + aws_secret_access_key: Secret access key of the user performing requests + aws_profile_name: Profile to use from your AWS shared credentials file + work_group: Identifier of Athena workgroup + skip_workgroup_check: Indicates if the WorkGroup check (additional AWS call) can be skipped + 
num_retries: Number of times to retry a failing query + num_boto3_retries: Number of times to retry boto3 requests (e.g. deleting S3 files for materialized tables) + num_iceberg_retries: Number of times to retry iceberg commit queries to fix ICEBERG_COMMIT_ERROR + spark_work_group: Identifier of Athena Spark workgroup for running Python models + seed_s3_upload_args: Dictionary containing boto3 ExtraArgs when uploading to S3 + lf_tags_database: Default LF tags for new database if it's created by dbt + """ + + type: t.Literal["athena"] = "athena" + threads: int = 4 + + s3_staging_dir: t.Optional[str] = None + s3_data_dir: t.Optional[str] = None + s3_data_naming: t.Optional[str] = None + s3_tmp_table_dir: t.Optional[str] = None + poll_interval: t.Optional[int] = None + debug_query_state: bool = False + work_group: t.Optional[str] = None + skip_workgroup_check: t.Optional[bool] = None + spark_work_group: t.Optional[str] = None + + aws_access_key_id: t.Optional[str] = None + aws_secret_access_key: t.Optional[str] = None + aws_profile_name: t.Optional[str] = None + region_name: t.Optional[str] = None + + num_retries: t.Optional[int] = None + num_boto3_retries: t.Optional[int] = None + num_iceberg_retries: t.Optional[int] = None + + seed_s3_upload_args: t.Dict[str, str] = {} + lf_tags_database: t.Dict[str, str] = {} + + @classproperty + def relation_class(cls) -> t.Type[BaseRelation]: + from dbt.adapters.athena.relation import AthenaRelation + + return AthenaRelation + + @classproperty + def column_class(cls) -> t.Type[Column]: + from dbt.adapters.athena.column import AthenaColumn + + return AthenaColumn + + def default_incremental_strategy(self, kind: IncrementalKind) -> str: + return "insert_overwrite" + + @with_schema_differ_overrides + def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: + return AthenaConnectionConfig( + type="athena", + aws_access_key_id=self.aws_access_key_id, + aws_secret_access_key=self.aws_secret_access_key, + region_name=self.region_name, 
+ work_group=self.work_group, + s3_staging_dir=self.s3_staging_dir, + s3_warehouse_location=self.s3_data_dir, + schema_name=self.schema_, + catalog_name=self.database, + concurrent_tasks=self.threads, + **kwargs, + ) + + TARGET_TYPE_TO_CONFIG_CLASS: t.Dict[str, t.Type[TargetConfig]] = { "databricks": DatabricksConfig, "duckdb": DuckDbConfig, @@ -854,4 +1108,6 @@ def to_sqlmesh(self, **kwargs: t.Any) -> ConnectionConfig: "sqlserver": MSSQLConfig, "tsql": MSSQLConfig, "trino": TrinoConfig, + "athena": AthenaConfig, + "clickhouse": ClickhouseConfig, } diff --git a/sqlmesh/dbt/test.py b/sqlmesh/dbt/test.py index 5f005ebc36..c4a32b2189 100644 --- a/sqlmesh/dbt/test.py +++ b/sqlmesh/dbt/test.py @@ -8,11 +8,11 @@ from pydantic import Field import sqlmesh.core.dialect as d from sqlmesh.core.audit import Audit, ModelAudit, StandaloneAudit +from sqlmesh.core.node import DbtNodeInfo from sqlmesh.dbt.common import ( Dependencies, GeneralConfig, SqlStr, - extract_jinja_config, sql_str_validator, ) from sqlmesh.utils import AttributeDict @@ -61,6 +61,10 @@ class TestConfig(GeneralConfig): error_if: Conditional expression (default "!=0") to detect if error condition met (Not supported). 
""" + __test__ = ( + False # prevent pytest trying to collect this as a test class when it's imported in a test + ) + # SQLMesh fields path: Path = Path() name: str @@ -76,14 +80,16 @@ class TestConfig(GeneralConfig): dialect_: t.Optional[str] = Field(None, alias="dialect") # dbt fields + unique_id: str = "" package_name: str = "" alias: t.Optional[str] = None + fqn: t.List[str] = [] schema_: t.Optional[str] = Field("", alias="schema") database: t.Optional[str] = None severity: Severity = Severity.ERROR store_failures: t.Optional[bool] = None where: t.Optional[str] = None - limit: t.Optional[str] = None + limit: t.Optional[int] = None fail_calc: str = "count(*)" warn_if: str = "!=0" error_if: str = "!=0" @@ -103,9 +109,28 @@ def _validate_severity(cls, v: t.Union[Severity, str]) -> Severity: def _lowercase_name(cls, v: str) -> str: return v.lower() + @property + def canonical_name(self) -> str: + return f"{self.package_name}.{self.name}".lower() if self.package_name else self.name + @property def is_standalone(self) -> bool: - return not self.model_name + # A test is standalone if: + # 1. It has no model_name (already standalone), OR + # 2. 
It references other models besides its own model + if not self.model_name: + return True + + # Check if test has references to other models + # For versioned models, refs include version (e.g., "model_name_v1") but model_name may not + self_refs = {self.model_name} + for ref in self.dependencies.refs: + # versioned models end in _vX + if ref.startswith(f"{self.model_name}_v"): + self_refs.add(ref) + + other_refs = {ref for ref in self.dependencies.refs if ref not in self_refs} + return bool(other_refs) @property def sqlmesh_config_fields(self) -> t.Set[str]: @@ -134,9 +159,7 @@ def to_sqlmesh(self, context: DbtContext) -> Audit: } ) - sql_no_config, _sql_config_only = extract_jinja_config(self.sql) - sql_no_config = sql_no_config.replace("**_dbt_generic_test_kwargs", self._kwargs()) - query = d.jinja_query(sql_no_config) + query = d.jinja_query(self.sql.replace("**_dbt_generic_test_kwargs", self._kwargs())) skip = not self.enabled blocking = self.severity == Severity.ERROR @@ -146,6 +169,7 @@ def to_sqlmesh(self, context: DbtContext) -> Audit: jinja_macros.add_globals({"this": self.relation_info}) audit = StandaloneAudit( name=self.name, + dbt_node_info=self.node_info, dialect=self.dialect(context), skip=skip, query=query, @@ -162,6 +186,7 @@ def to_sqlmesh(self, context: DbtContext) -> Audit: else: audit = ModelAudit( name=self.name, + dbt_node_info=self.node_info, dialect=self.dialect(context), skip=skip, blocking=blocking, @@ -205,6 +230,12 @@ def relation_info(self) -> AttributeDict: } ) + @property + def node_info(self) -> DbtNodeInfo: + return DbtNodeInfo( + unique_id=self.unique_id, name=self.name, fqn=".".join(self.fqn), alias=self.alias + ) + def _remove_jinja_braces(jinja_str: str) -> str: no_braces = jinja_str diff --git a/sqlmesh/dbt/util.py b/sqlmesh/dbt/util.py index 8fc6c6ecd2..0de16e3b3e 100644 --- a/sqlmesh/dbt/util.py +++ b/sqlmesh/dbt/util.py @@ -1,21 +1,29 @@ +from __future__ import annotations + import typing as t import agate -import pandas as 
pd from dbt.version import get_installed_version +if t.TYPE_CHECKING: + import pandas as pd + -def _get_dbt_version() -> t.Tuple[int, int]: +def _get_dbt_version() -> t.Tuple[int, int, int]: dbt_version = get_installed_version() - return (int(dbt_version.major or "0"), int(dbt_version.minor or "0")) + return ( + int(dbt_version.major or "0"), + int(dbt_version.minor or "0"), + int(dbt_version.patch or "0"), + ) DBT_VERSION = _get_dbt_version() -if DBT_VERSION < (1, 8): - from dbt.clients.agate_helper import table_from_data_flat, empty_table, as_matrix # type: ignore # noqa: F401 -else: +if DBT_VERSION >= (1, 8, 0): from dbt_common.clients.agate_helper import table_from_data_flat, empty_table, as_matrix # type: ignore # noqa: F401 +else: + from dbt.clients.agate_helper import table_from_data_flat, empty_table, as_matrix # type: ignore # noqa: F401 def pandas_to_agate(df: pd.DataFrame) -> agate.Table: diff --git a/sqlmesh/engines/commands.py b/sqlmesh/engines/commands.py deleted file mode 100644 index 1eaa5c4522..0000000000 --- a/sqlmesh/engines/commands.py +++ /dev/null @@ -1,171 +0,0 @@ -import typing as t -from enum import Enum - -from sqlmesh.core.environment import Environment, EnvironmentNamingInfo -from sqlmesh.core.snapshot import ( - DeployabilityIndex, - Snapshot, - SnapshotEvaluator, - SnapshotId, - SnapshotTableCleanupTask, - SnapshotTableInfo, -) -from sqlmesh.core.state_sync import cleanup_expired_views -from sqlmesh.utils.date import TimeLike -from sqlmesh.utils.pydantic import PydanticModel - -COMMAND_PAYLOAD_FILE_NAME = "payload.json" - - -class CommandType(str, Enum): - EVALUATE = "evaluate" - PROMOTE = "promote" - DEMOTE = "demote" - CLEANUP = "cleanup" - CREATE_TABLES = "create_tables" - MIGRATE_TABLES = "migrate_tables" - - # This makes it easy to integrate with argparse - def __str__(self) -> str: - return self.value - - -class EvaluateCommandPayload(PydanticModel): - snapshot: Snapshot - parent_snapshots: t.Dict[str, Snapshot] - start: TimeLike 
- end: TimeLike - execution_time: TimeLike - deployability_index: DeployabilityIndex - batch_index: int - - -class PromoteCommandPayload(PydanticModel): - snapshots: t.List[Snapshot] - environment_naming_info: EnvironmentNamingInfo - deployability_index: DeployabilityIndex - - -class DemoteCommandPayload(PydanticModel): - snapshots: t.List[SnapshotTableInfo] - environment_naming_info: EnvironmentNamingInfo - - -class CleanupCommandPayload(PydanticModel): - environments: t.List[Environment] - tasks: t.List[SnapshotTableCleanupTask] - - -class CreateTablesCommandPayload(PydanticModel): - target_snapshot_ids: t.List[SnapshotId] - snapshots: t.List[Snapshot] - deployability_index: DeployabilityIndex - allow_destructive_snapshots: t.Set[str] - - -class MigrateTablesCommandPayload(PydanticModel): - target_snapshot_ids: t.List[SnapshotId] - snapshots: t.List[Snapshot] - allow_destructive_snapshots: t.Set[str] - - -def evaluate( - evaluator: SnapshotEvaluator, command_payload: t.Union[str, EvaluateCommandPayload] -) -> None: - if isinstance(command_payload, str): - command_payload = EvaluateCommandPayload.parse_raw(command_payload) - - parent_snapshots = command_payload.parent_snapshots - parent_snapshots[command_payload.snapshot.name] = command_payload.snapshot - - wap_id = evaluator.evaluate( - command_payload.snapshot, - start=command_payload.start, - end=command_payload.end, - execution_time=command_payload.execution_time, - snapshots=parent_snapshots, - deployability_index=command_payload.deployability_index, - batch_index=command_payload.batch_index, - ) - evaluator.audit( - snapshot=command_payload.snapshot, - start=command_payload.start, - end=command_payload.end, - execution_time=command_payload.execution_time, - snapshots=parent_snapshots, - deployability_index=command_payload.deployability_index, - wap_id=wap_id, - ) - - -def promote( - evaluator: SnapshotEvaluator, command_payload: t.Union[str, PromoteCommandPayload] -) -> None: - if isinstance(command_payload, 
str): - command_payload = PromoteCommandPayload.parse_raw(command_payload) - evaluator.promote( - command_payload.snapshots, - command_payload.environment_naming_info, - deployability_index=command_payload.deployability_index, - ) - - -def demote( - evaluator: SnapshotEvaluator, command_payload: t.Union[str, DemoteCommandPayload] -) -> None: - if isinstance(command_payload, str): - command_payload = DemoteCommandPayload.parse_raw(command_payload) - evaluator.demote( - command_payload.snapshots, - command_payload.environment_naming_info, - ) - - -def cleanup( - evaluator: SnapshotEvaluator, command_payload: t.Union[str, CleanupCommandPayload] -) -> None: - if isinstance(command_payload, str): - command_payload = CleanupCommandPayload.parse_raw(command_payload) - - cleanup_expired_views(evaluator.adapter, command_payload.environments) - evaluator.cleanup(command_payload.tasks) - - -def create_tables( - evaluator: SnapshotEvaluator, - command_payload: t.Union[str, CreateTablesCommandPayload], -) -> None: - if isinstance(command_payload, str): - command_payload = CreateTablesCommandPayload.parse_raw(command_payload) - - snapshots_by_id = {s.snapshot_id: s for s in command_payload.snapshots} - target_snapshots = [snapshots_by_id[sid] for sid in command_payload.target_snapshot_ids] - evaluator.create( - target_snapshots, - snapshots_by_id, - deployability_index=command_payload.deployability_index, - allow_destructive_snapshots=command_payload.allow_destructive_snapshots, - ) - - -def migrate_tables( - evaluator: SnapshotEvaluator, - command_payload: t.Union[str, MigrateTablesCommandPayload], -) -> None: - if isinstance(command_payload, str): - command_payload = MigrateTablesCommandPayload.parse_raw(command_payload) - snapshots_by_id = {s.snapshot_id: s for s in command_payload.snapshots} - target_snapshots = [snapshots_by_id[sid] for sid in command_payload.target_snapshot_ids] - evaluator.migrate( - target_snapshots, snapshots_by_id, 
command_payload.allow_destructive_snapshots - ) - - -COMMAND_HANDLERS: t.Dict[CommandType, t.Callable[[SnapshotEvaluator, str], None]] = { - CommandType.EVALUATE: evaluate, - CommandType.PROMOTE: promote, - CommandType.DEMOTE: demote, - CommandType.CLEANUP: cleanup, - CommandType.CREATE_TABLES: create_tables, - CommandType.MIGRATE_TABLES: migrate_tables, -} diff --git a/sqlmesh/engines/spark/app.py b/sqlmesh/engines/spark/app.py deleted file mode 100644 index a8709361fa..0000000000 --- a/sqlmesh/engines/spark/app.py +++ /dev/null @@ -1,114 +0,0 @@ -import argparse -import logging -import os -import tempfile - -from pyspark import SparkFiles -from pyspark.sql import SparkSession - -from sqlmesh.core.engine_adapter import create_engine_adapter -from sqlmesh.core.snapshot import SnapshotEvaluator -from sqlmesh.engines import commands -from sqlmesh.engines.spark.db_api import spark_session as spark_session_db -from sqlmesh.engines.spark.db_api.errors import NotSupportedError -from sqlmesh.utils.errors import SQLMeshError - -logger = logging.getLogger(__name__) - - -def get_or_create_spark_session(dialect: str) -> SparkSession: - if dialect == "databricks": - spark = SparkSession.getActiveSession() - if not spark: - raise SQLMeshError("Could not find an active SparkSession.") - return spark - return ( - SparkSession.builder.config("spark.scheduler.mode", "FAIR") - .enableHiveSupport() - .getOrCreate() - ) - - -def main( - dialect: str, - default_catalog: str, - command_type: commands.CommandType, - ddl_concurrent_tasks: int, - payload_path: str, -) -> None: - if dialect not in ("databricks", "spark"): - raise NotSupportedError( - f"Dialect '{dialect}' not supported. 
Must be either 'databricks' or 'spark'" - ) - logging.basicConfig( - format="%(name)s - %(levelname)s - %(message)s (%(filename)s:%(lineno)d)", - level=logging.INFO, - ) - command_handler = commands.COMMAND_HANDLERS.get(command_type) - if not command_handler: - raise NotSupportedError(f"Command '{command_type.value}' not supported") - - spark = get_or_create_spark_session(dialect) - - evaluator = SnapshotEvaluator( - create_engine_adapter( - lambda: spark_session_db.connection(spark), - dialect, - default_catalog=default_catalog, - multithreaded=ddl_concurrent_tasks > 1, - execute_log_level=logging.INFO, - ), - ddl_concurrent_tasks=ddl_concurrent_tasks, - ) - if dialect == "spark": - with open(SparkFiles.get(payload_path), "r", encoding="utf-8") as payload_fd: - command_payload = payload_fd.read() - else: - from pyspark.dbutils import DBUtils # type: ignore - - dbutils = DBUtils(spark) - with tempfile.TemporaryDirectory() as tmp: - local_payload_path = os.path.join(tmp, commands.COMMAND_PAYLOAD_FILE_NAME) - dbutils.fs.cp(payload_path, f"file://{local_payload_path}") - with open(local_payload_path, "r", encoding="utf-8") as payload_fd: - command_payload = payload_fd.read() - logger.info("Command payload:\n %s", command_payload) - command_handler(evaluator, command_payload) - - evaluator.close() - - -if __name__ == "__main__": - parser = argparse.ArgumentParser(description="SQLMesh Spark Submit App") - parser.add_argument( - "--dialect", - help="The dialect to use when creating the engine adapter.", - ) - parser.add_argument( - "--default_catalog", - help="The default catalog to use when creating the engine adapter.", - ) - parser.add_argument( - "--command_type", - type=commands.CommandType, - choices=list(commands.CommandType), - help="The type of command that is being run", - ) - parser.add_argument( - "--ddl_concurrent_tasks", - type=int, - default=1, - help="The number of ddl concurrent tasks to use. 
Default to 1.", - ) - parser.add_argument( - "--payload_path", - help="Path to the payload object. Can be a local or remote path.", - ) - args = parser.parse_args() - main( - args.dialect, - args.default_catalog, - args.command_type, - args.ddl_concurrent_tasks, - args.payload_path, - ) diff --git a/sqlmesh/engines/spark/db_api/spark_session.py b/sqlmesh/engines/spark/db_api/spark_session.py index c9f3a7e099..04229f2a44 100644 --- a/sqlmesh/engines/spark/db_api/spark_session.py +++ b/sqlmesh/engines/spark/db_api/spark_session.py @@ -1,12 +1,15 @@ +from __future__ import annotations + import logging import typing as t from threading import get_ident -from pyspark.sql import DataFrame, SparkSession -from pyspark.sql.types import Row - from sqlmesh.engines.spark.db_api.errors import NotSupportedError, ProgrammingError +if t.TYPE_CHECKING: + from pyspark.sql import DataFrame, SparkSession + from pyspark.sql.types import Row + logger = logging.getLogger(__name__) @@ -87,14 +90,17 @@ def set_current_catalog(self, catalog_name: str) -> None: ) def cursor(self) -> SparkSessionCursor: + from pyspark.errors import PySparkAttributeError + try: self.spark.sparkContext.setLocalProperty("spark.scheduler.pool", f"pool_{get_ident()}") self.spark.conf.set("spark.sql.sources.partitionOverwriteMode", "dynamic") self.spark.conf.set("hive.exec.dynamic.partition", "true") self.spark.conf.set("hive.exec.dynamic.partition.mode", "nonstrict") - except NotImplementedError: + except (NotImplementedError, PySparkAttributeError): # Databricks Connect does not support accessing the SparkContext nor does it support # setting dynamic partition overwrite since it uses replace where + # Also Serverless jobs don't support access to spark context so we pass for that too pass if self.catalog: from py4j.protocol import Py4JError diff --git a/sqlmesh/integrations/dlt.py b/sqlmesh/integrations/dlt.py new file mode 100644 index 0000000000..2d601a0e22 --- /dev/null +++ b/sqlmesh/integrations/dlt.py @@ -0,0 
+1,235 @@ +import typing as t +import click +from datetime import datetime, timedelta, timezone +from pydantic import ValidationError +from sqlglot import exp, parse_one +from sqlmesh.core.config.connection import parse_connection_config +from sqlmesh.core.context import Context +from sqlmesh.utils.date import yesterday_ds + + +def generate_dlt_models_and_settings( + pipeline_name: str, + dialect: str, + tables: t.Optional[t.List[str]] = None, + dlt_path: t.Optional[str] = None, +) -> t.Tuple[t.Set[t.Tuple[str, str]], t.Optional[str], str]: + """ + This function attaches to a DLT pipeline and retrieves the connection configs and + SQLMesh models based on the tables present in the pipeline's default schema. + + Args: + pipeline_name: The name of the DLT pipeline to attach to. + dialect: The SQL dialect to use for generating SQLMesh models. + tables: A list of table names to include. + dlt_path: The path to the directory containing the DLT pipelines. + + Returns: + A tuple containing a set of the SQLMesh model definitions, the connection config and the start date. 
+ """ + + import dlt + from dlt.common.schema.utils import has_table_seen_data, is_complete_column + from dlt.pipeline.exceptions import CannotRestorePipelineException + + try: + pipeline = dlt.attach(pipeline_name=pipeline_name, pipelines_dir=dlt_path or "") + except CannotRestorePipelineException: + raise click.ClickException(f"Could not attach to pipeline {pipeline_name}") + + schema = pipeline.default_schema + dataset = pipeline.dataset_name + + # Get the start date from the load_ids + storage_ids = list(pipeline._get_load_storage().list_loaded_packages()) + start_date = get_start_date(storage_ids) + + # Get the connection credentials + db_type = pipeline.destination.to_name(pipeline.destination) + if db_type == "filesystem": + connection_config = None + else: + if dlt.__version__ >= "1.10.0": + client = pipeline.destination_client() + else: + client = pipeline._sql_job_client(schema) # type: ignore + config = client.config + credentials = config.credentials + configs = { + key: value + for key in dir(credentials) + if not key.startswith("_") + and not callable(value := getattr(credentials, key)) + and value is not None + } + connection_config = format_config(configs, db_type) + + dlt_tables = { + name: table + for name, table in schema.tables.items() + if ( + (has_table_seen_data(table) and not name.startswith(schema._dlt_tables_prefix)) + or name == schema.loads_table_name + ) + and (name in tables if tables else True) + } + + sqlmesh_models = set() + for table_name, table in dlt_tables.items(): + dlt_columns = {} + primary_key = [] + + # is_complete_column returns true if column contains a name and a data type + for col in filter(is_complete_column, table["columns"].values()): + dlt_columns[col["name"]] = exp.DataType.build(str(col["data_type"]), dialect=dialect) + if col.get("primary_key"): + primary_key.append(str(col["name"])) + + load_id = next( + (col for col in ["_dlt_load_id", "load_id"] if col in dlt_columns), + None, + ) + load_key = "c." 
+ load_id if load_id else "" + parent_table = None + + # Handling for nested tables: https://dlthub.com/docs/general-usage/destination-tables#nested-tables + if not load_id: + if ( + "_dlt_parent_id" in dlt_columns + and (parent_table := table["parent"]) + and parent_table in dlt_tables + ): + load_key = "p._dlt_load_id" + parent_table = dataset + "." + parent_table + else: + break + + column_types = [ + exp.cast(exp.column(column, table="c"), data_type, dialect=dialect) + .as_(column) + .sql(dialect=dialect) + for column, data_type in dlt_columns.items() + if isinstance(column, str) + ] + select_columns = ( + ",\n".join(f" {column_name}" for column_name in column_types) if column_types else "" + ) + + grain = f"\n grain ({', '.join(primary_key)})," if primary_key else "" + incremental_model_name = f"{dataset}_sqlmesh.incremental_{table_name}" + incremental_model_sql = generate_incremental_model( + incremental_model_name, + select_columns, + grain, + dataset + "." + table_name, + dialect, + load_key, + parent_table, + ) + sqlmesh_models.add((incremental_model_name, incremental_model_sql)) + + return sqlmesh_models, connection_config, start_date + + +def generate_dlt_models( + context: Context, + pipeline_name: str, + tables: t.List[str], + force: bool, + dlt_path: t.Optional[str] = None, +) -> t.List[str]: + from sqlmesh.cli.project_init import _create_object_files + + sqlmesh_models, _, _ = generate_dlt_models_and_settings( + pipeline_name=pipeline_name, + dialect=context.config.dialect or "", + tables=tables if tables else None, + dlt_path=dlt_path, + ) + + if not tables and not force: + existing_models = [m.name for m in context.models.values()] + sqlmesh_models = {model for model in sqlmesh_models if model[0] not in existing_models} + + if sqlmesh_models: + _create_object_files( + context.path / "models", + {model[0].split(".")[-1]: model[1] for model in sqlmesh_models}, + "sql", + ) + return [model[0] for model in sqlmesh_models] + return [] + + +def 
generate_incremental_model( + model_name: str, + select_columns: str, + grain: str, + from_table: str, + dialect: str, + load_id: str, + parent_table: t.Optional[str] = None, +) -> str: + """Generate the SQL definition for an incremental model.""" + + time_column = parse_one(f"to_timestamp(CAST({load_id} AS DOUBLE))").sql(dialect=dialect) + + from_clause = f"{from_table} as c" + if parent_table: + from_clause += f"""\nJOIN + {parent_table} as p +ON + c._dlt_parent_id = p._dlt_id""" + + return f"""MODEL ( + name {model_name}, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column _dlt_load_time, + ),{grain} +); + +SELECT +{select_columns}, + {time_column} as _dlt_load_time +FROM + {from_clause} +WHERE + {time_column} BETWEEN @start_ds AND @end_ds +""" + + +def format_config(configs: t.Dict[str, str], db_type: str) -> str: + """Generate a string for the gateway connection config.""" + config = { + "type": db_type, + } + + for key, value in configs.items(): + if key == "password": + config[key] = f'"{value}"' + elif key == "username": + config["user"] = value + else: + config[key] = value + + # Validate the connection config fields + invalid_fields = [] + try: + parse_connection_config(config) + except ValidationError as e: + for error in e.errors(): + invalid_fields.append(error.get("loc", [])[0]) + + return "\n".join( + [f" {key}: {value}" for key, value in config.items() if key not in invalid_fields] + ) + + +def get_start_date(load_ids: t.List[str]) -> str: + """Convert the earliest load_id to UTC timestamp, subtract a day and format as 'YYYY-MM-DD'.""" + + timestamps = [datetime.fromtimestamp(float(id), tz=timezone.utc) for id in load_ids] + if timestamps: + start_timestamp = min(timestamps) - timedelta(days=1) + return start_timestamp.strftime("%Y-%m-%d") + return yesterday_ds() diff --git a/sqlmesh/integrations/github/cicd/command.py b/sqlmesh/integrations/github/cicd/command.py index e5fc207b2b..5506d4917b 100644 --- a/sqlmesh/integrations/github/cicd/command.py +++ 
b/sqlmesh/integrations/github/cicd/command.py @@ -6,13 +6,14 @@ import click from sqlmesh.core.analytics import cli_analytics +from sqlmesh.core.console import set_console, MarkdownConsole from sqlmesh.integrations.github.cicd.controller import ( GithubCheckConclusion, GithubCheckStatus, GithubController, TestFailure, ) -from sqlmesh.utils.errors import CICDBotError, PlanError +from sqlmesh.utils.errors import CICDBotError, ConflictingPlanError, PlanError, LinterError logger = logging.getLogger(__name__) @@ -21,11 +22,24 @@ @click.option( "--token", type=str, + envvar="GITHUB_TOKEN", help="The Github Token to be used. Pass in `${{ secrets.GITHUB_TOKEN }}` if you want to use the one created by Github actions", ) +@click.option( + "--full-logs", + is_flag=True, + help="Whether to print all logs in the Github Actions output or only in their relevant GA check", +) @click.pass_context -def github(ctx: click.Context, token: str) -> None: +def github(ctx: click.Context, token: str, full_logs: bool = False) -> None: """Github Action CI/CD Bot. 
See https://sqlmesh.readthedocs.io/en/stable/integrations/github/ for details""" + # set a larger width because if none is specified, it auto-detects 80 characters when running in GitHub Actions + # which can result in surprise newlines when outputting dates to backfill + set_console( + MarkdownConsole( + width=1000, warning_capture_only=not full_logs, error_capture_only=not full_logs + ) + ) ctx.obj["github"] = GithubController( paths=ctx.obj["paths"], token=token, @@ -41,7 +55,7 @@ def _check_required_approvers(controller: GithubController) -> bool: ) return True controller.update_required_approval_check( - status=GithubCheckStatus.COMPLETED, conclusion=GithubCheckConclusion.NEUTRAL + status=GithubCheckStatus.COMPLETED, conclusion=GithubCheckConclusion.FAILURE ) return False @@ -51,36 +65,58 @@ def _check_required_approvers(controller: GithubController) -> bool: @cli_analytics def check_required_approvers(ctx: click.Context) -> None: """Checks if a required approver has provided approval on the PR.""" - _check_required_approvers(ctx.obj["github"]) + if not _check_required_approvers(ctx.obj["github"]): + raise CICDBotError( + "Required approver has not approved the PR. See Pull Requests Checks for more information." 
+ ) def _run_tests(controller: GithubController) -> bool: controller.update_test_check(status=GithubCheckStatus.IN_PROGRESS) try: - result, output = controller.run_tests() + result, _ = controller.run_tests() controller.update_test_check( status=GithubCheckStatus.COMPLETED, # Conclusion will be updated with final status based on test results conclusion=GithubCheckConclusion.NEUTRAL, result=result, - output=output, ) return result.wasSuccessful() except Exception: controller.update_test_check( status=GithubCheckStatus.COMPLETED, conclusion=GithubCheckConclusion.FAILURE, - output=traceback.format_exc(), + traceback=traceback.format_exc(), ) return False +def _run_linter(controller: GithubController) -> bool: + controller.update_linter_check(status=GithubCheckStatus.IN_PROGRESS) + try: + controller.run_linter() + except LinterError: + controller.update_linter_check( + status=GithubCheckStatus.COMPLETED, + conclusion=GithubCheckConclusion.FAILURE, + ) + return False + + controller.update_linter_check( + status=GithubCheckStatus.COMPLETED, + conclusion=GithubCheckConclusion.SUCCESS, + ) + + return True + + @github.command() @click.pass_context @cli_analytics def run_tests(ctx: click.Context) -> None: """Runs the unit tests""" - _run_tests(ctx.obj["github"]) + if not _run_tests(ctx.obj["github"]): + raise CICDBotError("Failed to run tests. 
See Pull Requests Checks for more information.") def _update_pr_environment(controller: GithubController) -> bool: @@ -90,6 +126,7 @@ def _update_pr_environment(controller: GithubController) -> bool: conclusion = controller.update_pr_environment_check(status=GithubCheckStatus.COMPLETED) return conclusion is not None and conclusion.is_success except Exception as e: + logger.exception("Error occurred when updating PR environment") conclusion = controller.update_pr_environment_check( status=GithubCheckStatus.COMPLETED, exception=e ) @@ -105,7 +142,10 @@ def _update_pr_environment(controller: GithubController) -> bool: @cli_analytics def update_pr_environment(ctx: click.Context) -> None: """Creates or updates the PR environments""" - _update_pr_environment(ctx.obj["github"]) + if not _update_pr_environment(ctx.obj["github"]): + raise CICDBotError( + "Failed to update PR environment. See Pull Requests Checks for more information." + ) def _gen_prod_plan(controller: GithubController) -> bool: @@ -119,6 +159,7 @@ def _gen_prod_plan(controller: GithubController) -> bool: ) return bool(plan_summary) except Exception as e: + logger.exception("Error occurred generating prod plan") controller.update_prod_plan_preview_check( status=GithubCheckStatus.COMPLETED, conclusion=GithubCheckConclusion.FAILURE, @@ -134,7 +175,10 @@ def gen_prod_plan(ctx: click.Context) -> None: """Generates the production plan""" controller = ctx.obj["github"] controller.update_prod_plan_preview_check(status=GithubCheckStatus.IN_PROGRESS) - _gen_prod_plan(controller) + if not _gen_prod_plan(controller): + raise CICDBotError( + "Failed to generate production plan. See Pull Requests Checks for more information." 
+ ) def _deploy_production(controller: GithubController) -> bool: @@ -147,9 +191,23 @@ def _deploy_production(controller: GithubController) -> bool: controller.try_merge_pr() controller.try_invalidate_pr_environment() return True - except PlanError: + except ConflictingPlanError as e: controller.update_prod_environment_check( - status=GithubCheckStatus.COMPLETED, conclusion=GithubCheckConclusion.ACTION_REQUIRED + status=GithubCheckStatus.COMPLETED, + conclusion=GithubCheckConclusion.SKIPPED, + skip_reason=str(e), + ) + return False + except PlanError as e: + controller.update_prod_environment_check( + status=GithubCheckStatus.COMPLETED, + conclusion=GithubCheckConclusion.ACTION_REQUIRED, + plan_error=e, + ) + return False + except Exception: + controller.update_prod_environment_check( + status=GithubCheckStatus.COMPLETED, conclusion=GithubCheckConclusion.FAILURE ) return False @@ -159,14 +217,19 @@ def _deploy_production(controller: GithubController) -> bool: @cli_analytics def deploy_production(ctx: click.Context) -> None: """Deploys the production environment""" - _deploy_production(ctx.obj["github"]) + if not _deploy_production(ctx.obj["github"]): + raise CICDBotError( + "Failed to deploy to production. See Pull Requests Checks for more information." 
+ ) def _run_all(controller: GithubController) -> None: + click.echo(f"SQLMesh Version: {controller.version_info}") + has_required_approval = False is_auto_deploying_prod = ( controller.deploy_command_enabled or controller.do_required_approval_check - ) + ) and controller.pr_targets_prod_branch if controller.is_comment_added: if not controller.deploy_command_enabled: # We aren't using commands so we can just return @@ -175,15 +238,17 @@ def _run_all(controller: GithubController) -> None: if command.is_invalid: # Probably a comment unrelated to SQLMesh so we do nothing return - elif command.is_deploy_prod: + if command.is_deploy_prod: has_required_approval = True else: raise CICDBotError(f"Unsupported command: {command}") + controller.update_linter_check(status=GithubCheckStatus.QUEUED) controller.update_pr_environment_check(status=GithubCheckStatus.QUEUED) controller.update_prod_plan_preview_check(status=GithubCheckStatus.QUEUED) controller.update_test_check(status=GithubCheckStatus.QUEUED) if is_auto_deploying_prod: controller.update_prod_environment_check(status=GithubCheckStatus.QUEUED) + linter_passed = _run_linter(controller) tests_passed = _run_tests(controller) if controller.do_required_approval_check: if has_required_approval: @@ -193,23 +258,27 @@ def _run_all(controller: GithubController) -> None: else: controller.update_required_approval_check(status=GithubCheckStatus.QUEUED) has_required_approval = _check_required_approvers(controller) - if not tests_passed: + if not tests_passed or not linter_passed: controller.update_pr_environment_check( status=GithubCheckStatus.COMPLETED, - exception=TestFailure(), + exception=LinterError("") if not linter_passed else TestFailure(), ) controller.update_prod_plan_preview_check( status=GithubCheckStatus.COMPLETED, conclusion=GithubCheckConclusion.SKIPPED, - summary="Unit Test(s) Failed so skipping creating prod plan", + summary="Linter or Unit Test(s) failed so skipping creating prod plan", ) if 
is_auto_deploying_prod: controller.update_prod_environment_check( status=GithubCheckStatus.COMPLETED, conclusion=GithubCheckConclusion.SKIPPED, - skip_reason="Unit Test(s) Failed so skipping deploying to production", + skip_reason="Linter or Unit Test(s) failed so skipping deploying to production", ) - return + + raise CICDBotError( + "Linter or Unit Test(s) failed. See Pull Requests Checks for more information." + ) + pr_environment_updated = _update_pr_environment(controller) prod_plan_generated = False if pr_environment_updated: @@ -218,10 +287,13 @@ def _run_all(controller: GithubController) -> None: controller.update_prod_plan_preview_check( status=GithubCheckStatus.COMPLETED, conclusion=GithubCheckConclusion.SKIPPED ) - if tests_passed and has_required_approval and pr_environment_updated and prod_plan_generated: - _deploy_production(controller) + deployed_to_prod = False + if has_required_approval and prod_plan_generated and controller.pr_targets_prod_branch: + deployed_to_prod = _deploy_production(controller) elif is_auto_deploying_prod: - if not has_required_approval: + if controller.deploy_command_enabled and not has_required_approval: + skip_reason = "Skipped Deploying to Production because a `/deploy` command has not been detected yet" + elif controller.do_required_approval_check and not has_required_approval: skip_reason = ( "Skipped Deploying to Production because a required approver has not approved" ) @@ -240,6 +312,14 @@ def _run_all(controller: GithubController) -> None: conclusion=GithubCheckConclusion.SKIPPED, skip_reason=skip_reason, ) + if ( + not pr_environment_updated + or not prod_plan_generated + or (has_required_approval and controller.pr_targets_prod_branch and not deployed_to_prod) + ): + raise CICDBotError( + "A step of the run-all check failed. See Pull Requests Checks for more information." 
+ ) @github.command() diff --git a/sqlmesh/integrations/github/cicd/config.py b/sqlmesh/integrations/github/cicd/config.py index 95b090af8d..7fb3a0f5b6 100644 --- a/sqlmesh/integrations/github/cicd/config.py +++ b/sqlmesh/integrations/github/cicd/config.py @@ -1,4 +1,3 @@ -import sys import typing as t from enum import Enum @@ -7,12 +6,8 @@ from sqlmesh.core.config import CategorizerConfig from sqlmesh.core.config.base import BaseConfig from sqlmesh.utils.date import TimeLike -from sqlmesh.utils.pydantic import model_validator, model_validator_v1_args - -if sys.version_info >= (3, 9): - from typing import Literal -else: - from typing_extensions import Literal +from sqlmesh.utils.pydantic import model_validator +from sqlmesh.core.console import get_console class MergeMethod(str, Enum): @@ -22,27 +17,69 @@ class MergeMethod(str, Enum): class GithubCICDBotConfig(BaseConfig): - type_: Literal["github"] = Field(alias="type", default="github") + type_: t.Literal["github"] = Field(alias="type", default="github") invalidate_environment_after_deploy: bool = True enable_deploy_command: bool = False merge_method: t.Optional[MergeMethod] = None command_namespace: t.Optional[str] = None - auto_categorize_changes: CategorizerConfig = CategorizerConfig.all_off() + auto_categorize_changes_: t.Optional[CategorizerConfig] = Field( + default=None, alias="auto_categorize_changes" + ) default_pr_start: t.Optional[TimeLike] = None - skip_pr_backfill: bool = True - pr_include_unmodified: t.Optional[bool] = None - run_on_deploy_to_prod: bool = True + skip_pr_backfill_: t.Optional[bool] = Field(default=None, alias="skip_pr_backfill") + pr_include_unmodified_: t.Optional[bool] = Field(default=None, alias="pr_include_unmodified") + run_on_deploy_to_prod: bool = False pr_environment_name: t.Optional[str] = None + pr_min_intervals: t.Optional[int] = None + prod_branch_names_: t.Optional[str] = Field(default=None, alias="prod_branch_name") + forward_only_branch_suffix_: t.Optional[str] = Field( 
+ default=None, alias="forward_only_branch_suffix" + ) + check_if_blocked_on_deploy_to_prod: bool = True @model_validator(mode="before") - @model_validator_v1_args - def _validate(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: - if values.get("enable_deploy_command") and not values.get("merge_method"): + @classmethod + def _validate(cls, data: t.Any) -> t.Any: + if not isinstance(data, dict): + return data + + if data.get("enable_deploy_command") and not data.get("merge_method"): raise ValueError("merge_method must be set if enable_deploy_command is True") - if values.get("command_namespace") and not values.get("enable_deploy_command"): + if data.get("command_namespace") and not data.get("enable_deploy_command"): raise ValueError("enable_deploy_command must be set if command_namespace is set") - return values + + return data + + @property + def prod_branch_names(self) -> t.List[str]: + if self.prod_branch_names_: + return [self.prod_branch_names_] + return ["main", "master"] + + @property + def auto_categorize_changes(self) -> CategorizerConfig: + return self.auto_categorize_changes_ or CategorizerConfig.all_off() + + @property + def pr_include_unmodified(self) -> bool: + return self.pr_include_unmodified_ or False + + @property + def skip_pr_backfill(self) -> bool: + if self.skip_pr_backfill_ is None: + get_console().log_warning( + "`skip_pr_backfill` is unset, defaulting it to `true` (no data will be backfilled).\n" + "Future versions of SQLMesh will default to `skip_pr_backfill: false` to align with the CLI default behaviour.\n" + "If you would like to preserve the current behaviour and remove this warning, please explicitly set `skip_pr_backfill: true` in the bot config.\n\n" + "For more information on configuring the bot, see: https://sqlmesh.readthedocs.io/en/stable/integrations/github/" + ) + return True + return self.skip_pr_backfill_ + + @property + def forward_only_branch_suffix(self) -> str: + return self.forward_only_branch_suffix_ or 
"-forward-only" FIELDS_FOR_ANALYTICS: t.ClassVar[t.Set[str]] = { "invalidate_environment_after_deploy", @@ -54,4 +91,6 @@ def _validate(cls, values: t.Dict[str, t.Any]) -> t.Dict[str, t.Any]: "skip_pr_backfill", "pr_include_unmodified", "run_on_deploy_to_prod", + "pr_min_intervals", + "forward_only_branch_suffix", } diff --git a/sqlmesh/integrations/github/cicd/controller.py b/sqlmesh/integrations/github/cicd/controller.py index 3a9e680882..40102b97e8 100644 --- a/sqlmesh/integrations/github/cicd/controller.py +++ b/sqlmesh/integrations/github/cicd/controller.py @@ -8,37 +8,40 @@ import re import traceback import typing as t -import unittest from enum import Enum -from typing import List +from pathlib import Path +from dataclasses import dataclass +from functools import cached_property import requests -from hyperscript import Element, h -from rich.console import Console from sqlglot.helper import seq_get from sqlmesh.core import constants as c -from sqlmesh.core.console import SNAPSHOT_CHANGE_CATEGORY_STR, MarkdownConsole +from sqlmesh.core.console import SNAPSHOT_CHANGE_CATEGORY_STR, get_console, MarkdownConsole from sqlmesh.core.context import Context +from sqlmesh.core.test.result import ModelTextTestResult from sqlmesh.core.environment import Environment -from sqlmesh.core.plan import Plan, PlanBuilder +from sqlmesh.core.plan import Plan, PlanBuilder, SnapshotIntervals +from sqlmesh.core.plan.definition import UserProvidedFlags from sqlmesh.core.snapshot.definition import ( Snapshot, SnapshotChangeCategory, SnapshotId, SnapshotTableInfo, - format_intervals, ) +from sqlglot.errors import SqlglotError from sqlmesh.core.user import User +from sqlmesh.core.config import Config from sqlmesh.integrations.github.cicd.config import GithubCICDBotConfig -from sqlmesh.utils import word_characters_only -from sqlmesh.utils.concurrency import NodeExecutionFailedError +from sqlmesh.utils import word_characters_only, Verbosity from sqlmesh.utils.date import now from 
sqlmesh.utils.errors import ( CICDBotError, NoChangesPlanError, PlanError, UncategorizedPlanError, + LinterError, + SQLMeshError, ) from sqlmesh.utils.pydantic import PydanticModel @@ -51,8 +54,6 @@ from github.PullRequestReview import PullRequestReview from github.Repository import Repository - from sqlmesh.core.config import Config - logger = logging.getLogger(__name__) @@ -77,7 +78,7 @@ def create_from_pull_request_url(cls, pull_request_url: str) -> PullRequestInfo: return cls( owner=owner, repo=repo, - pr_number=pr_number, + pr_number=int(pr_number), ) @@ -284,11 +285,12 @@ class GithubController: def __init__( self, - paths: t.Union[str, t.Iterable[str]], + paths: t.Union[Path, t.Iterable[Path]], token: str, config: t.Optional[t.Union[Config, str]] = None, event: t.Optional[GithubEvent] = None, client: t.Optional[Github] = None, + context: t.Optional[Context] = None, ) -> None: from github import Github @@ -303,11 +305,18 @@ def __init__( self._prod_plan_builder: t.Optional[PlanBuilder] = None self._prod_plan_with_gaps_builder: t.Optional[PlanBuilder] = None self._check_run_mapping: t.Dict[str, CheckRun] = {} - self._console = MarkdownConsole(console=Console(no_color=True)) + + if not isinstance(get_console(), MarkdownConsole): + raise CICDBotError("Console must be a markdown console.") + self._console = t.cast(MarkdownConsole, get_console()) + + from github.Consts import DEFAULT_BASE_URL + from github.Auth import Token + self._client: Github = client or Github( - base_url=os.environ["GITHUB_API_URL"], - login_or_token=self._token, + base_url=os.environ.get("GITHUB_API_URL", DEFAULT_BASE_URL), auth=Token(self._token) ) + self._repo: Repository = self._client.get_repo( self._event.pull_request_info.full_repo_path, lazy=True ) @@ -323,11 +332,10 @@ def __init__( if review.state.lower() == "approved" } logger.debug(f"Approvers: {', '.join(self._approvers)}") - self._context: Context = Context( - paths=self._paths, - config=self.config, - console=self._console, - 
) + self._context: Context = context or Context(paths=self._paths, config=self.config) + + # Bot config needs the context to be initialized + logger.debug(f"Bot config: {self.bot_config.json(indent=2)}") @property def deploy_command_enabled(self) -> bool: @@ -392,14 +400,32 @@ def pr_plan(self) -> Plan: self._pr_plan_builder = self._context.plan_builder( environment=self.pr_environment_name, skip_tests=True, + skip_linter=True, categorizer_config=self.bot_config.auto_categorize_changes, start=self.bot_config.default_pr_start, + min_intervals=self.bot_config.pr_min_intervals, skip_backfill=self.bot_config.skip_pr_backfill, include_unmodified=self.bot_config.pr_include_unmodified, + forward_only=self.forward_only_plan, ) assert self._pr_plan_builder return self._pr_plan_builder.build() + @property + def pr_plan_or_none(self) -> t.Optional[Plan]: + try: + return self.pr_plan + except: + return None + + @property + def pr_plan_flags(self) -> t.Optional[t.Dict[str, UserProvidedFlags]]: + if pr_plan := self.pr_plan_or_none: + return pr_plan.user_provided_flags + if pr_plan_builder := self._pr_plan_builder: + return pr_plan_builder._user_provided_flags + return None + @property def prod_plan(self) -> Plan: if not self._prod_plan_builder: @@ -407,8 +433,10 @@ def prod_plan(self) -> Plan: c.PROD, no_gaps=True, skip_tests=True, + skip_linter=True, categorizer_config=self.bot_config.auto_categorize_changes, run=self.bot_config.run_on_deploy_to_prod, + forward_only=self.forward_only_plan, ) assert self._prod_plan_builder return self._prod_plan_builder.build() @@ -418,10 +446,13 @@ def prod_plan_with_gaps(self) -> Plan: if not self._prod_plan_with_gaps_builder: self._prod_plan_with_gaps_builder = self._context.plan_builder( c.PROD, + # this is required to highlight any data gaps between this PR environment and prod (since PR environments may only contain a subset of data) no_gaps=False, - no_auto_categorization=True, skip_tests=True, + skip_linter=True, + 
categorizer_config=self.bot_config.auto_categorize_changes, run=self.bot_config.run_on_deploy_to_prod, + forward_only=self.forward_only_plan, ) assert self._prod_plan_with_gaps_builder return self._prod_plan_with_gaps_builder.build() @@ -431,7 +462,6 @@ def bot_config(self) -> GithubCICDBotConfig: bot_config = self._context.config.cicd_bot or GithubCICDBotConfig( auto_categorize_changes=self._context.auto_categorize_changes ) - logger.debug(f"Bot config: {bot_config.json(indent=2)}") return bot_config @property @@ -442,40 +472,208 @@ def modified_snapshots(self) -> t.Dict[SnapshotId, t.Union[Snapshot, SnapshotTab def removed_snapshots(self) -> t.Set[SnapshotId]: return set(self.prod_plan_with_gaps.context_diff.removed_snapshots) + @property + def pr_targets_prod_branch(self) -> bool: + return self._pull_request.base.ref in self.bot_config.prod_branch_names + + @property + def forward_only_plan(self) -> bool: + default = self._context.config.plan.forward_only + head_ref = self._pull_request.head.ref + if isinstance(head_ref, str): + return head_ref.endswith(self.bot_config.forward_only_branch_suffix) or default + return default + @classmethod def _append_output(cls, key: str, value: str) -> None: """ Appends the given key/value to output so they can be read by following steps """ logger.debug(f"Setting output. 
Key: {key}, Value: {value}") - with open(os.environ["GITHUB_OUTPUT"], "a", encoding="utf-8") as fh: - print(f"{key}={value}", file=fh) + + # GitHub Actions sets this environment variable + if output_file := os.environ.get("GITHUB_OUTPUT"): + with open(output_file, "a", encoding="utf-8") as fh: + print(f"{key}={value}", file=fh) + + def get_forward_only_plan_post_deployment_tip(self, plan: Plan) -> str: + if not plan.forward_only: + return "" + + example_model_name = "" + for snapshot_id in sorted(plan.snapshots): + snapshot = plan.snapshots[snapshot_id] + if snapshot.is_incremental: + example_model_name = snapshot.node.name + break + + return ( + "> [!TIP]\n" + "> In order to see this forward-only plan retroactively apply to historical intervals on the production model, run the below for date ranges in scope:\n" + "> \n" + f"> `$ sqlmesh plan --restate-model {example_model_name} --start YYYY-MM-DD --end YYYY-MM-DD`\n" + ">\n" + "> Learn more: https://sqlmesh.readthedocs.io/en/stable/concepts/plans/?h=restate#restatement-plans" + ) def get_plan_summary(self, plan: Plan) -> str: + # use Verbosity.VERY_VERBOSE to prevent the list of models from being truncated + # this is particularly important for the "Models needing backfill" list because + # there is no easy way to tell this otherwise + orig_verbosity = self._console.verbosity + self._console.verbosity = Verbosity.VERY_VERBOSE + try: # Clear out any output that might exist from prior steps - self._console.clear_captured_outputs() - self._console.show_model_difference_summary( - context_diff=plan.context_diff, - environment_naming_info=plan.environment_naming_info, - default_catalog=self._context.default_catalog, - no_diff=False, - ignored_snapshot_ids=plan.ignored, - ) + self._console.consume_captured_output() + if plan.restatements: + self._console._print("\n**Restating models**\n") + else: + self._console.show_environment_difference_summary( + context_diff=plan.context_diff, + no_diff=False, + ) + if 
plan.context_diff.has_changes: + self._console.show_model_difference_summary( + context_diff=plan.context_diff, + environment_naming_info=plan.environment_naming_info, + default_catalog=self._context.default_catalog, + no_diff=False, + ) difference_summary = self._console.consume_captured_output() self._console._show_missing_dates(plan, self._context.default_catalog) missing_dates = self._console.consume_captured_output() + + plan_flags_section = ( + f"\n\n{self._generate_plan_flags_section(plan.user_provided_flags)}" + if plan.user_provided_flags + else "" + ) + if not difference_summary and not missing_dates: - return "No changes to apply." - return f"{difference_summary}\n{missing_dates}" + return f"No changes to apply.{plan_flags_section}" + + warnings_block = self._console.consume_captured_warnings() + errors_block = self._console.consume_captured_errors() + + return f"{warnings_block}{errors_block}{difference_summary}\n{missing_dates}{plan_flags_section}" except PlanError as e: + logger.exception("Plan failed to generate") return f"Plan failed to generate. Check for pending or unresolved changes. 
Error: {e}" + finally: + self._console.verbosity = orig_verbosity + + def get_pr_environment_summary( + self, conclusion: GithubCheckConclusion, exception: t.Optional[Exception] = None + ) -> str: + heading = "" + summary = "" + + if conclusion.is_success: + summary = self._get_pr_environment_summary_success() + elif conclusion.is_action_required: + heading = f":warning: Action Required to create or update PR Environment `{self.pr_environment_name}` :warning:" + summary = self._get_pr_environment_summary_action_required(exception) + elif conclusion.is_failure: + heading = ( + f":x: Failed to create or update PR Environment `{self.pr_environment_name}` :x:" + ) + summary = self._get_pr_environment_summary_failure(exception) + elif conclusion.is_skipped: + heading = f":next_track_button: Skipped creating or updating PR Environment `{self.pr_environment_name}` :next_track_button:" + summary = self._get_pr_environment_summary_skipped(exception) + else: + heading = f":interrobang: Got an unexpected conclusion: {conclusion.value}" + + # note: we just add warnings here, errors will be covered by the "failure" conclusion + if warnings := self._console.consume_captured_warnings(): + summary = f"{warnings}\n{summary}" - def run_tests(self) -> t.Tuple[unittest.result.TestResult, str]: + return f"{heading}\n\n{summary}".strip() + + def _get_pr_environment_summary_success(self) -> str: + prod_plan = self.prod_plan_with_gaps + + if not prod_plan.has_changes: + summary = "No models were modified in this PR.\n" + else: + intro = self._generate_pr_environment_summary_intro() + summary = intro + self._generate_pr_environment_summary_list(prod_plan) + + if prod_plan.user_provided_flags: + summary += self._generate_plan_flags_section(prod_plan.user_provided_flags) + + return summary + + def _get_pr_environment_summary_skipped(self, exception: t.Optional[Exception] = None) -> str: + if isinstance(exception, NoChangesPlanError): + skip_reason = "No changes were detected compared to the 
prod environment." + elif isinstance(exception, TestFailure): + skip_reason = "Unit Test(s) Failed so skipping PR creation" + else: + skip_reason = "A prior stage failed resulting in skipping PR creation." + + return skip_reason + + def _get_pr_environment_summary_action_required( + self, exception: t.Optional[Exception] = None + ) -> str: + plan = self.pr_plan_or_none + if isinstance(exception, UncategorizedPlanError) and plan: + failure_msg = f"The following models could not be categorized automatically:\n" + for snapshot in plan.uncategorized: + failure_msg += f"- {snapshot.name}\n" + failure_msg += ( + f"\nRun `sqlmesh plan {self.pr_environment_name}` locally to apply these changes.\n\n" + "If you would like the bot to automatically categorize changes, check the [documentation](https://sqlmesh.readthedocs.io/en/stable/integrations/github/) for more information." + ) + else: + failure_msg = "Please check the Actions Workflow logs for more information." + + return failure_msg + + def _get_pr_environment_summary_failure(self, exception: t.Optional[Exception] = None) -> str: + console_output = self._console.consume_captured_output() + failure_msg = "" + + if isinstance(exception, PlanError): + if exception.args and (msg := exception.args[0]) and isinstance(msg, str): + failure_msg += f"*{msg}*\n" + if console_output: + failure_msg += f"\n{console_output}" + elif isinstance(exception, (SQLMeshError, SqlglotError, ValueError)): + # this logic is taken from the global error handler attached to the CLI, which uses `click.echo()` to output the message + # so cant be re-used here because it bypasses the Console + failure_msg = f"**Error:** {str(exception)}" + elif exception: + logger.debug( + "Got unexpected error. 
Error Type: " + + str(type(exception)) + + " Stack trace: " + + traceback.format_exc() + ) + failure_msg = f"This is an unexpected error.\n\n**Exception:**\n```\n{traceback.format_exc()}\n```" + + if captured_errors := self._console.consume_captured_errors(): + failure_msg = f"{captured_errors}\n{failure_msg}" + + if plan_flags := self.pr_plan_flags: + failure_msg += f"\n\n{self._generate_plan_flags_section(plan_flags)}" + + return failure_msg + + def run_tests(self) -> t.Tuple[ModelTextTestResult, str]: """ Run tests for the PR """ - return self._context._run_tests(verbose=True) + return self._context._run_tests(verbosity=Verbosity.VERBOSE) + + def run_linter(self) -> None: + """ + Run linter for the PR + """ + self._console.consume_captured_output() + self._context.lint_models() def _get_or_create_comment(self, header: str = BOT_HEADER_MSG) -> IssueComment: comment = seq_get( @@ -546,7 +744,22 @@ def update_pr_environment(self) -> None: Creates a PR environment from the logic present in the PR. If the PR contains changes that are uncategorized, then an error will be raised. 
""" - self._context.apply(self.pr_plan) + self._console.consume_captured_output() # clear output buffer + self._context.apply(self.pr_plan) # will raise if PR environment creation fails + + # update PR info comment + vde_title = "- :eyes: To **review** this PR's changes, use virtual data environment:" + comment_value = f"{vde_title}\n - `{self.pr_environment_name}`" + if self.bot_config.enable_deploy_command: + full_command = f"{self.bot_config.command_namespace or ''}/deploy" + comment_value += f"\n- :arrow_forward: To **apply** this PR's plan to prod, comment:\n - `{full_command}`" + dedup_regex = vde_title.replace("*", r"\*") + r".*" + updated_comment, _ = self.update_sqlmesh_comment_info( + value=comment_value, + dedup_regex=dedup_regex, + ) + if updated_comment: + self._append_output("created_pr_environment", "true") def deploy_to_prod(self) -> None: """ @@ -559,6 +772,11 @@ def deploy_to_prod(self) -> None: "PR is already merged and this event was triggered prior to the merge." ) merge_status = self._get_merge_state_status() + if self.bot_config.check_if_blocked_on_deploy_to_prod and merge_status.is_blocked: + raise CICDBotError( + "Branch protection or ruleset requirement is likely not satisfied, e.g. missing CODEOWNERS approval. " + "Please check PR and resolve any issues. To disable this check, set `check_if_blocked_on_deploy_to_prod` to false in the bot configuration." + ) if merge_status.is_dirty: raise CICDBotError( "Merge commit cannot be cleanly created. Likely from a merge conflict. 
" @@ -571,6 +789,11 @@ def deploy_to_prod(self) -> None: """ + if self.forward_only_plan: + plan_summary = ( + f"{self.get_forward_only_plan_post_deployment_tip(self.prod_plan)}\n{plan_summary}" + ) + self.update_sqlmesh_comment_info( value=plan_summary, dedup_regex=None, @@ -617,15 +840,31 @@ def _update_check( if text: kwargs["output"]["text"] = text logger.debug(f"Updating check with kwargs: {kwargs}") - if name in self._check_run_mapping: - logger.debug(f"Found check run in mapping so updating it. Name: {name}") - check_run = self._check_run_mapping[name] - check_run.edit( - **{k: v for k, v in kwargs.items() if k not in ("name", "head_sha", "started_at")} - ) + + if self.running_in_github_actions: + # Only make the API call to update the checks if we are running within GitHub Actions + # One very annoying limitation of the Pull Request Checks API is that its only available to GitHub Apps + # and not personal access tokens, which makes it unable to be utilized during local development + if name in self._check_run_mapping: + logger.debug(f"Found check run in mapping so updating it. Name: {name}") + check_run = self._check_run_mapping[name] + check_run.edit( + **{ + k: v + for k, v in kwargs.items() + if k not in ("name", "head_sha", "started_at") + } + ) + else: + logger.debug(f"Did not find check run in mapping so creating it. Name: {name}") + self._check_run_mapping[name] = self._repo.create_check_run(**kwargs) else: - logger.debug(f"Did not find check run in mapping so creating it. 
Name: {name}") - self._check_run_mapping[name] = self._repo.create_check_run(**kwargs) + # Output the summary using print() so the newlines are resolved and the result can easily + # be disambiguated from the rest of the console output and copy+pasted into a Markdown renderer + print( + f"---CHECK OUTPUT START: {kwargs['output']['title']} ---\n{kwargs['output']['summary']}\n---CHECK OUTPUT END---\n" + ) + if conclusion: self._append_output( word_characters_only(name.replace("SQLMesh - ", "").lower()), conclusion.value @@ -653,25 +892,58 @@ def _update_check_handler( full_summary=summary, ) + def update_linter_check( + self, + status: GithubCheckStatus, + conclusion: t.Optional[GithubCheckConclusion] = None, + ) -> None: + if not self._context.config.linter.enabled: + return + + def conclusion_handler( + conclusion: GithubCheckConclusion, + ) -> t.Tuple[GithubCheckConclusion, str, t.Optional[str]]: + linter_summary = self._console.consume_captured_output() or "Linter Success" + + title = "Linter results" + + return conclusion, title, linter_summary + + self._update_check_handler( + check_name="SQLMesh - Linter", + status=status, + conclusion=conclusion, + status_handler=lambda status: ( + { + GithubCheckStatus.IN_PROGRESS: "Running linter", + GithubCheckStatus.QUEUED: "Waiting to Run linter", + }[status], + None, + ), + conclusion_handler=conclusion_handler, + ) + def update_test_check( self, status: GithubCheckStatus, conclusion: t.Optional[GithubCheckConclusion] = None, - result: t.Optional[unittest.result.TestResult] = None, - output: t.Optional[str] = None, + result: t.Optional[ModelTextTestResult] = None, + traceback: t.Optional[str] = None, ) -> None: """ Updates the status of tests for code in the PR """ def conclusion_handler( - conclusion: GithubCheckConclusion, result: unittest.result.TestResult, output: str + conclusion: GithubCheckConclusion, + result: t.Optional[ModelTextTestResult], ) -> t.Tuple[GithubCheckConclusion, str, t.Optional[str]]: if result: 
# Clear out console self._console.consume_captured_output() self._console.log_test_results( - result, output, self._context._test_connection_config._engine_adapter.DIALECT + result, + self._context.test_connection_config._engine_adapter.DIALECT, ) test_summary = self._console.consume_captured_output() test_title = "Tests Passed" if result.wasSuccessful() else "Tests Failed" @@ -681,8 +953,11 @@ def conclusion_handler( else GithubCheckConclusion.FAILURE ) return test_conclusion, test_title, test_summary + if traceback: + self._console._print(traceback) + test_title = "Skipped Tests" if conclusion.is_skipped else "Tests Failed" - return conclusion, test_title, output + return conclusion, test_title, traceback self._update_check_handler( check_name="SQLMesh - Run Unit Tests", @@ -695,7 +970,7 @@ def conclusion_handler( }[status], None, ), - conclusion_handler=functools.partial(conclusion_handler, result=result, output=output), + conclusion_handler=functools.partial(conclusion_handler, result=result), ) def update_required_approval_check( @@ -738,15 +1013,13 @@ def conclusion_handler( ) def update_pr_environment_check( - self, - status: GithubCheckStatus, - exception: t.Optional[Exception] = None, + self, status: GithubCheckStatus, exception: t.Optional[Exception] = None ) -> t.Optional[GithubCheckConclusion]: """ Updates the status of the merge commit for the PR environment. 
""" conclusion: t.Optional[GithubCheckConclusion] = None - if isinstance(exception, (NoChangesPlanError, TestFailure)): + if isinstance(exception, (NoChangesPlanError, TestFailure, LinterError)): conclusion = GithubCheckConclusion.SKIPPED elif isinstance(exception, UncategorizedPlanError): conclusion = GithubCheckConclusion.ACTION_REQUIRED @@ -761,106 +1034,7 @@ def update_pr_environment_check( def conclusion_handler( conclusion: GithubCheckConclusion, exception: t.Optional[Exception] ) -> t.Tuple[GithubCheckConclusion, str, t.Optional[str]]: - if conclusion.is_success: - if not self.modified_snapshots: - summary = "No models were modified in this PR.\n" - else: - header_rows = [ - h("th", {"colspan": "3"}, "PR Environment Summary"), - [ - h("th", "Model"), - h("th", "Change Type"), - h("th", "Dates Loaded"), - ], - ] - body_rows: List[Element | List[Element]] = [] - for modified_snapshot in self.modified_snapshots.values(): - # We don't want to display indirect non-breaking since to users these are effectively no-op changes - if modified_snapshot.is_indirect_non_breaking: - continue - if modified_snapshot.snapshot_id in self.removed_snapshots: - # This will be an FQN since we don't have access to node name from a snapshot table info - # which is what a removed snapshot is - model_name = modified_snapshot.name - change_category = SNAPSHOT_CHANGE_CATEGORY_STR[ - SnapshotChangeCategory.BREAKING - ] - interval_output = "REMOVED" - else: - assert isinstance(modified_snapshot, Snapshot) - model_name = modified_snapshot.node.name - change_category = ( - "Uncategorized" - if not modified_snapshot.change_category - else SNAPSHOT_CHANGE_CATEGORY_STR[modified_snapshot.change_category] - ) - intervals = ( - modified_snapshot.dev_intervals - if modified_snapshot.is_forward_only - else modified_snapshot.intervals - ) - interval_output = ( - format_intervals(intervals, modified_snapshot.node.interval_unit) - if intervals - else "N/A" - ) - body_rows.append( - [ - h("td", 
model_name, autoescape=False), - h("td", change_category), - h("td", interval_output), - ] - ) - table_header = h("thead", [h("tr", row) for row in header_rows]) - table_body = h("tbody", [h("tr", row) for row in body_rows]) - summary = str(h("table", [table_header, table_body])) - vde_title = ( - "- :eyes: To **review** this PR's changes, use virtual data environment:" - ) - comment_value = f"{vde_title}\n - `{self.pr_environment_name}`" - if self.bot_config.enable_deploy_command: - comment_value += "\n- :arrow_forward: To **apply** this PR's plan to prod, comment:\n - `/deploy`" - dedup_regex = vde_title.replace("*", r"\*") + r".*" - updated_comment, _ = self.update_sqlmesh_comment_info( - value=comment_value, - dedup_regex=dedup_regex, - ) - if updated_comment: - self._append_output("created_pr_environment", "true") - else: - if isinstance(exception, NoChangesPlanError): - skip_reason = "No changes were detected compared to the prod environment." - elif isinstance(exception, TestFailure): - skip_reason = "Unit Test(s) Failed so skipping PR creation" - else: - skip_reason = "A prior stage failed resulting in skipping PR creation." - - captured_errors = self._console.consume_captured_errors() - if captured_errors: - logger.debug(f"Captured errors: {captured_errors}") - failure_msg = f"**Errors:**\n{captured_errors}\n" - elif isinstance(exception, NodeExecutionFailedError): - logger.debug( - "Got Node Execution Failed Error. Stack trace: " + traceback.format_exc() - ) - failure_msg = f"Node `{exception.node.name}` failed to apply.\n\n**Stack Trace:**\n```\n{traceback.format_exc()}\n```" - else: - logger.debug( - "Got unexpected error. 
Error Type: " - + str(type(exception)) - + " Stack trace: " - + traceback.format_exc() - ) - failure_msg = f"This is an unexpected error.\n\n**Exception:**\n```\n{traceback.format_exc()}\n```" - conclusion_to_summary = { - GithubCheckConclusion.SKIPPED: f":next_track_button: Skipped creating or updating PR Environment `{self.pr_environment_name}`. {skip_reason}", - GithubCheckConclusion.FAILURE: f":x: Failed to create or update PR Environment `{self.pr_environment_name}`.\n{failure_msg}", - GithubCheckConclusion.CANCELLED: f":stop_sign: Cancelled creating or updating PR Environment `{self.pr_environment_name}`", - GithubCheckConclusion.ACTION_REQUIRED: f":warning: Action Required to create or update PR Environment `{self.pr_environment_name}`. There are likely uncateogrized changes. Run `plan` locally to apply these changes. If you want the bot to automatically categorize changes, then check documentation (https://sqlmesh.readthedocs.io/en/stable/integrations/github/) for more information.", - } - summary = conclusion_to_summary.get( - conclusion, f":interrobang: Got an unexpected conclusion: {conclusion.value}" - ) + summary = self.get_pr_environment_summary(conclusion, exception) self._append_output("pr_environment_name", self.pr_environment_name) return conclusion, check_title, summary @@ -901,6 +1075,12 @@ def conclusion_handler( title = conclusion_to_title.get( conclusion, f"Got an unexpected conclusion: {conclusion.value}" ) + if conclusion == GithubCheckConclusion.SUCCESS and summary: + summary = ( + f"This is a preview that shows the differences between this PR environment `{self.pr_environment_name}` and `prod`.\n\n" + "These are the changes that would be deployed.\n\n" + ) + summary + return conclusion, title, summary self._update_check_handler( @@ -922,6 +1102,7 @@ def update_prod_environment_check( status: GithubCheckStatus, conclusion: t.Optional[GithubCheckConclusion] = None, skip_reason: t.Optional[str] = None, + plan_error: t.Optional[PlanError] = 
None, ) -> None: """ Updates the status of the merge commit for the prod environment. @@ -933,22 +1114,29 @@ def conclusion_handler( conclusion_to_title = { GithubCheckConclusion.SUCCESS: "Deployed to Prod", GithubCheckConclusion.CANCELLED: "Cancelled deploying to prod", - GithubCheckConclusion.SKIPPED: skip_reason, + GithubCheckConclusion.SKIPPED: "Skipped deployment", GithubCheckConclusion.FAILURE: "Failed to deploy to prod", + GithubCheckConclusion.ACTION_REQUIRED: "Failed due to error applying plan", } title = ( conclusion_to_title.get(conclusion) or f"Got an unexpected conclusion: {conclusion.value}" ) if conclusion.is_skipped: - summary = title + summary = skip_reason elif conclusion.is_failure: captured_errors = self._console.consume_captured_errors() summary = ( captured_errors or f"{title}\n\n**Error:**\n```\n{traceback.format_exc()}\n```" ) + elif conclusion.is_action_required: + if plan_error: + summary = f"**Plan error:**\n```\n{plan_error}\n```" + else: + summary = "Got an action required conclusion but no plan error was provided. This is unexpected." 
else: summary = "**Generated Prod Plan**\n" + self.get_plan_summary(self.prod_plan) + return conclusion, title, summary self._update_check_handler( @@ -999,3 +1187,236 @@ def _chunk_up_api_message(self, message: str) -> t.List[str]: message_encoded[i : i + self.MAX_BYTE_LENGTH].decode("utf-8", "ignore") for i in range(0, len(message_encoded), self.MAX_BYTE_LENGTH) ] + + @property + def running_in_github_actions(self) -> bool: + return os.environ.get("GITHUB_ACTIONS", None) == "true" + + @property + def version_info(self) -> str: + from sqlmesh.cli.main import _sqlmesh_version + + return _sqlmesh_version() + + def _generate_plan_flags_section( + self, user_provided_flags: t.Dict[str, UserProvidedFlags] + ) -> str: + # collapsed section syntax: + # https://docs.github.com/en/get-started/writing-on-github/working-with-advanced-formatting/organizing-information-with-collapsed-sections#creating-a-collapsed-section + section = "
\n\nPlan flags\n\n" + for flag_name, flag_value in user_provided_flags.items(): + section += f"- `{flag_name}` = `{flag_value}`\n" + section += "\n
" + + return section + + def _generate_pr_environment_summary_intro(self) -> str: + note = "" + subset_reasons = [] + + if self.bot_config.skip_pr_backfill: + subset_reasons.append("`skip_pr_backfill` is enabled") + + if default_pr_start := self.bot_config.default_pr_start: + subset_reasons.append(f"`default_pr_start` is set to `{default_pr_start}`") + + if subset_reasons: + note = ( + "> [!IMPORTANT]\n" + f"> This PR environment may only contain a subset of data because:\n" + + "\n".join(f"> - {r}" for r in subset_reasons) + + "\n" + "> \n" + "> This means that deploying to `prod` may not be a simple virtual update if there is still some data to load.\n" + "> See `Dates not loaded in PR` below or the `Prod Plan Preview` check for more information.\n\n" + ) + + return ( + f"Here is a summary of data that has been loaded into the PR environment `{self.pr_environment_name}` and could be deployed to `prod`.\n\n" + + note + ) + + def _generate_pr_environment_summary_list(self, plan: Plan) -> str: + added_snapshot_ids = set(plan.context_diff.added) + modified_snapshot_ids = set( + s.snapshot_id for s, _ in plan.context_diff.modified_snapshots.values() + ) + removed_snapshot_ids = set(plan.context_diff.removed_snapshots.keys()) + + # note: we sort these to get a deterministic order for the output tests + table_records = sorted( + [ + SnapshotSummaryRecord(snapshot_id=snapshot_id, plan=plan) + for snapshot_id in ( + added_snapshot_ids | modified_snapshot_ids | removed_snapshot_ids + ) + ], + key=lambda r: r.display_name, + ) + + sections = [ + ("### Added", [r for r in table_records if r.is_added]), + ("### Removed", [r for r in table_records if r.is_removed]), + ("### Directly Modified", [r for r in table_records if r.is_directly_modified]), + ("### Indirectly Modified", [r for r in table_records if r.is_indirectly_modified]), + ( + "### Metadata Updated", + [r for r in table_records if r.is_metadata_updated and not r.is_modified], + ), + ] + + summary = "" + for title, 
records in sections: + if records: + summary += f"\n{title}\n" + + for record in records: + summary += f"{record.as_markdown_list_item}\n" + + return summary + + +@dataclass +class SnapshotSummaryRecord: + snapshot_id: SnapshotId + plan: Plan + + @property + def snapshot(self) -> Snapshot: + if self.is_removed: + raise ValueError("Removed snapshots only have SnapshotTableInfo available") + return self.plan.snapshots[self.snapshot_id] + + @cached_property + def snapshot_table_info(self) -> SnapshotTableInfo: + if self.is_removed: + return self.plan.modified_snapshots[self.snapshot_id].table_info + return self.plan.snapshots[self.snapshot_id].table_info + + @property + def display_name(self) -> str: + dialect = None if self.is_removed else self.snapshot.node.dialect + return self.snapshot_table_info.display_name( + self.plan.environment_naming_info, default_catalog=None, dialect=dialect + ) + + @property + def change_category(self) -> str: + if self.is_removed: + return SNAPSHOT_CHANGE_CATEGORY_STR[SnapshotChangeCategory.BREAKING] + + if change_category := self.snapshot.change_category: + return SNAPSHOT_CHANGE_CATEGORY_STR[change_category] + + return "Uncategorized" + + @property + def is_added(self) -> bool: + return self.snapshot_id in self.plan.context_diff.added + + @property + def is_removed(self) -> bool: + return self.snapshot_id in self.plan.context_diff.removed_snapshots + + @property + def is_dev_preview(self) -> bool: + return not self.plan.deployability_index.is_deployable(self.snapshot_id) + + @property + def is_directly_modified(self) -> bool: + return self.plan.context_diff.directly_modified(self.snapshot_table_info.name) + + @property + def is_indirectly_modified(self) -> bool: + return self.plan.context_diff.indirectly_modified(self.snapshot_table_info.name) + + @property + def is_modified(self) -> bool: + return self.is_directly_modified or self.is_indirectly_modified + + @property + def is_metadata_updated(self) -> bool: + return 
self.plan.context_diff.metadata_updated(self.snapshot_table_info.name) + + @property + def is_incremental(self) -> bool: + return self.snapshot_table_info.is_incremental + + @property + def modification_type(self) -> str: + if self.is_directly_modified: + return "Directly modified" + if self.is_indirectly_modified: + return "Indirectly modified" + if self.is_metadata_updated: + return "Metadata updated" + + return "Unknown" + + @property + def loaded_intervals(self) -> SnapshotIntervals: + if self.is_removed: + raise ValueError("Removed snapshots dont have loaded intervals available") + + return SnapshotIntervals( + snapshot_id=self.snapshot_id, + intervals=( + self.snapshot.dev_intervals + if self.snapshot.is_forward_only + else self.snapshot.intervals + ), + ) + + @property + def loaded_intervals_rendered(self) -> str: + if self.is_removed: + return "REMOVED" + + return self._format_intervals(self.loaded_intervals) + + @property + def missing_intervals(self) -> t.Optional[SnapshotIntervals]: + return next( + (si for si in self.plan.missing_intervals if si.snapshot_id == self.snapshot_id), + None, + ) + + @property + def missing_intervals_formatted(self) -> str: + if not self.is_removed and (intervals := self.missing_intervals): + return self._format_intervals(intervals) + + return "N/A" + + @property + def as_markdown_list_item(self) -> str: + if self.is_removed: + return f"- `{self.display_name}` ({self.change_category})" + + how_applied = "" + + if not self.is_incremental: + from sqlmesh.core.console import _format_missing_intervals + + # note: this is to re-use the '[recreate view]' and '[full refresh]' text and keep it in sync with updates to the CLI + # it doesnt actually use the passed intervals, those are handled differently + how_applied = _format_missing_intervals(self.snapshot, self.loaded_intervals) + + how_applied_str = f" [{how_applied}]" if how_applied else "" + + item = f"- `{self.display_name}` ({self.change_category})\n" + + if 
self.snapshot_table_info.model_kind_name: + item += f" **Kind:** {self.snapshot_table_info.model_kind_name}{how_applied_str}\n" + + if self.is_incremental: + # in-depth interval info is only relevant for incremental models + item += f" **Dates loaded in PR:** [{self.loaded_intervals_rendered}]\n" + if self.missing_intervals: + item += f" **Dates *not* loaded in PR:** [{self.missing_intervals_formatted}]\n" + + return item + + def _format_intervals(self, intervals: SnapshotIntervals) -> str: + preview_modifier = " (**preview**)" if self.is_dev_preview else "" + return f"{intervals.format_intervals(self.snapshot.node.interval_unit)}{preview_modifier}" diff --git a/sqlmesh/integrations/llm.py b/sqlmesh/integrations/llm.py deleted file mode 100644 index eb1c7148d5..0000000000 --- a/sqlmesh/integrations/llm.py +++ /dev/null @@ -1,53 +0,0 @@ -from __future__ import annotations - -import typing as t - -from langchain import LLMChain, PromptTemplate -from langchain.chat_models import ChatOpenAI - -from sqlmesh.core.model import Model - -_QUERY_PROMPT_TEMPLATE = """Given an input request, create a syntactically correct {dialect} SQL query. -Use full table names. -Convert string operands to lowercase in the WHERE clause. -Reply with a SQL query and nothing else. 
- -Use the following tables and columns: - -{table_info} - -Request: {input}""" - - -class LLMIntegration: - def __init__( - self, - models: t.Iterable[Model], - dialect: str, - temperature: float = 0.7, - verbose: bool = False, - ): - query_prompt_template = PromptTemplate.from_template(_QUERY_PROMPT_TEMPLATE).partial( - dialect=dialect, table_info=_to_table_info(models) - ) - llm = ChatOpenAI(temperature=temperature) # type: ignore - self._query_chain = LLMChain(llm=llm, prompt=query_prompt_template, verbose=verbose) - - def query(self, prompt: str) -> str: - result = self._query_chain.predict(input=prompt).strip() - select_pos = result.find("SELECT") - if select_pos >= 0: - return result[select_pos:] - return result - - -def _to_table_info(models: t.Iterable[Model]) -> str: - infos = [] - for model in models: - if not model.kind.is_materialized: - continue - - columns_csv = ", ".join(model.columns_to_types_or_raise) - infos.append(f"Table: {model.name}\nColumns: {columns_csv}\n") - - return "\n".join(infos) diff --git a/sqlmesh/integrations/slack.py b/sqlmesh/integrations/slack.py index 026cfe6912..495978b984 100644 --- a/sqlmesh/integrations/slack.py +++ b/sqlmesh/integrations/slack.py @@ -1,16 +1,10 @@ """Helpers for building robust Slack messages""" import json -import sys import typing as t from enum import Enum from textwrap import dedent -if sys.version_info >= (3, 8): - from typing import TypedDict -else: - from typing_extensions import TypedDict - SLACK_MAX_TEXT_LENGTH = 3000 SLACK_MAX_ALERT_PREVIEW_BLOCKS = 5 @@ -21,12 +15,13 @@ TSlackBlock = t.Dict[str, t.Any] -class TSlackBlocks(TypedDict): +class TSlackBlocks(t.TypedDict): blocks: t.List[TSlackBlock] class TSlackMessage(TSlackBlocks): attachments: t.List[TSlackBlocks] + text: str class SlackMessageComposer: @@ -34,7 +29,11 @@ class SlackMessageComposer: def __init__(self, initial_message: t.Optional[TSlackMessage] = None) -> None: """Initialize the Slack message builder""" - self.slack_message = 
initial_message or {"blocks": [], "attachments": [{"blocks": []}]} + self.slack_message: TSlackMessage = initial_message or { + "text": "", + "blocks": [], + "attachments": [{"blocks": []}], + } def add_primary_blocks(self, *blocks: TSlackBlock) -> "SlackMessageComposer": """Add blocks to the message. Blocks are always displayed""" @@ -52,6 +51,14 @@ def add_secondary_blocks(self, *blocks: TSlackBlock) -> "SlackMessageComposer": raise ValueError("Too many attachments") return self + def add_text(self, text: str) -> "SlackMessageComposer": + """Add text to the message + + This text is used in places where content cannot be rendered such as: system push notifications, assistive technology such as screen readers, etc. + """ + self.slack_message["text"] = normalize_message(text) + return self + def _introspect(self) -> "SlackMessageComposer": """Print the message to stdout diff --git a/sqlmesh/lsp/api.py b/sqlmesh/lsp/api.py new file mode 100644 index 0000000000..882ca9825b --- /dev/null +++ b/sqlmesh/lsp/api.py @@ -0,0 +1,84 @@ +""" +This module maps the LSP custom API calls to the SQLMesh web api. + +Allowing the LSP to call the web api without having to know the details of the web api +and thus passing through the details of the web api to the LSP, so that both the LSP +and the web api can communicate with the same process, avoiding the need to have a +separate process for the web api. +""" + +import typing as t +from pydantic import field_validator +from sqlmesh.lsp.custom import ( + CustomMethodRequestBaseClass, + CustomMethodResponseBaseClass, +) +from web.server.models import LineageColumn, Model, TableDiff + +API_FEATURE = "sqlmesh/api" + + +class ApiRequest(CustomMethodRequestBaseClass): + """ + Request to call the SQLMesh API. + This is a generic request that can be used to call any API endpoint. 
+ """ + + requestId: str + url: str + method: t.Optional[str] = "GET" + params: t.Optional[t.Dict[str, t.Any]] = None + body: t.Optional[t.Dict[str, t.Any]] = None + + +class BaseAPIResponse(CustomMethodResponseBaseClass): + error: t.Optional[str] = None + + +class ApiResponseGetModels(BaseAPIResponse): + """ + Response from the SQLMesh API for the get_models endpoint. + """ + + data: t.List[Model] + + @field_validator("data", mode="before") + def sanitize_datetime_fields(cls, data: t.List[Model]) -> t.List[Model]: + """ + Convert datetime objects to None to avoid serialization issues. + """ + if isinstance(data, list): + for model in data: + if hasattr(model, "details") and model.details: + # Convert datetime fields to None to avoid serialization issues + for field in ["stamp", "start", "cron_prev", "cron_next"]: + if ( + hasattr(model.details, field) + and getattr(model.details, field) is not None + ): + setattr(model.details, field, None) + return data + + +class ApiResponseGetLineage(BaseAPIResponse): + """ + Response from the SQLMesh API for the get_lineage endpoint. + """ + + data: t.Dict[str, t.List[str]] + + +class ApiResponseGetColumnLineage(BaseAPIResponse): + """ + Response from the SQLMesh API for the get_column_lineage endpoint. + """ + + data: t.Dict[str, t.Dict[str, LineageColumn]] + + +class ApiResponseGetTableDiff(BaseAPIResponse): + """ + Response from the SQLMesh API for the get_table_diff endpoint. 
+ """ + + data: t.Optional[TableDiff] diff --git a/sqlmesh/lsp/commands.py b/sqlmesh/lsp/commands.py new file mode 100644 index 0000000000..bea81f898a --- /dev/null +++ b/sqlmesh/lsp/commands.py @@ -0,0 +1 @@ +EXTERNAL_MODEL_UPDATE_COLUMNS = "sqlmesh.external_model_update_columns" diff --git a/sqlmesh/lsp/completions.py b/sqlmesh/lsp/completions.py new file mode 100644 index 0000000000..93162b15a4 --- /dev/null +++ b/sqlmesh/lsp/completions.py @@ -0,0 +1,194 @@ +from functools import lru_cache +from sqlglot import Dialect, Tokenizer +from sqlmesh.lsp.custom import ( + AllModelsResponse, + MacroCompletion, + ModelCompletion, +) +from sqlmesh import macro +import typing as t +from sqlmesh.lsp.context import AuditTarget, LSPContext, ModelTarget +from sqlmesh.lsp.uri import URI +from sqlmesh.utils.lineage import generate_markdown_description + + +def get_sql_completions( + context: t.Optional[LSPContext] = None, + file_uri: t.Optional[URI] = None, + content: t.Optional[str] = None, +) -> AllModelsResponse: + """ + Return a list of completions for a given file. + """ + # Get SQL keywords for the dialect + sql_keywords = get_keywords(context, file_uri) + + # Get keywords from file content if provided + file_keywords = set() + if content: + file_keywords = extract_keywords_from_content(content, get_dialect(context, file_uri)) + + # Combine keywords - SQL keywords first, then file keywords + all_keywords = list(sql_keywords) + list(file_keywords - sql_keywords) + + models = list(get_models(context, file_uri)) + return AllModelsResponse( + models=[m.name for m in models], + model_completions=models, + keywords=all_keywords, + macros=list(get_macros(context, file_uri)), + ) + + +def get_models( + context: t.Optional[LSPContext], file_uri: t.Optional[URI] +) -> t.List[ModelCompletion]: + """ + Return a list of models for a given file. + + If there is no context, return an empty list. + If there is a context, return a list of all models bar the ones the file itself defines. 
+ """ + if context is None: + return [] + + current_path = file_uri.to_path() if file_uri is not None else None + + completions: t.List[ModelCompletion] = [] + for model in context.context.models.values(): + if current_path is not None and model._path == current_path: + continue + description = None + try: + description = generate_markdown_description(model) + except Exception: + description = getattr(model, "description", None) + + completions.append(ModelCompletion(name=model.name, description=description)) + + return completions + + +def get_macros( + context: t.Optional[LSPContext], file_uri: t.Optional[URI] +) -> t.List[MacroCompletion]: + """Return a list of macros with optional descriptions.""" + macros: t.Dict[str, t.Optional[str]] = {} + + for name, m in macro.get_registry().items(): + macros[name] = getattr(m.func, "__doc__", None) + + try: + if context is not None: + for name, m in context.context._macros.items(): + macros[name] = getattr(m.func, "__doc__", None) + except Exception: + pass + + return [MacroCompletion(name=name, description=doc) for name, doc in macros.items()] + + +def get_keywords(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Set[str]: + """ + Return a list of sql keywords for a given file. + If no context is provided, return ANSI SQL keywords. + + If a context is provided but no file_uri is provided, returns the keywords + for the default dialect of the context. + + If both a context and a file_uri are provided, returns the keywords + for the dialect of the model that the file belongs to. 
+ """ + if file_uri is not None and context is not None and file_uri.to_path() in context.map: + file_info = context.map[file_uri.to_path()] + + # Handle ModelInfo objects + if isinstance(file_info, ModelTarget) and file_info.names: + model_name = file_info.names[0] + model_from_context = context.context.get_model(model_name) + if model_from_context is not None and model_from_context.dialect: + return get_keywords_from_tokenizer(model_from_context.dialect) + + # Handle AuditInfo objects + elif isinstance(file_info, AuditTarget) and file_info.name: + audit = context.context.standalone_audits.get(file_info.name) + if audit is not None and audit.dialect: + return get_keywords_from_tokenizer(audit.dialect) + + if context is not None: + return get_keywords_from_tokenizer(context.context.default_dialect) + + return get_keywords_from_tokenizer(None) + + +@lru_cache() +def get_keywords_from_tokenizer(dialect: t.Optional[str] = None) -> t.Set[str]: + """ + Return a list of sql keywords for a given dialect. This is separate from + the direct use of Tokenizer.KEYWORDS.keys() because that returns a set of + keywords that are expanded, e.g. "ORDER BY" -> ["ORDER", "BY"]. + """ + tokenizer = Tokenizer + if dialect is not None: + try: + tokenizer = Dialect.get_or_raise(dialect).tokenizer_class + except Exception: + pass + + expanded_keywords = set() + for keyword in tokenizer.KEYWORDS.keys(): + parts = keyword.split(" ") + expanded_keywords.update(parts) + return expanded_keywords + + +def get_dialect(context: t.Optional[LSPContext], file_uri: t.Optional[URI]) -> t.Optional[str]: + """ + Get the dialect for a given file. 
+ """ + if file_uri is not None and context is not None and file_uri.to_path() in context.map: + file_info = context.map[file_uri.to_path()] + + # Handle ModelInfo objects + if isinstance(file_info, ModelTarget) and file_info.names: + model_name = file_info.names[0] + model_from_context = context.context.get_model(model_name) + return model_from_context.dialect + + # Handle AuditInfo objects + if isinstance(file_info, AuditTarget) and file_info.name: + audit = context.context.standalone_audits.get(file_info.name) + if audit is not None and audit.dialect: + return audit.dialect + + if context is not None: + return context.context.default_dialect + + return None + + +def extract_keywords_from_content(content: str, dialect: t.Optional[str] = None) -> t.Set[str]: + """ + Extract identifiers from SQL content using the tokenizer. + + Only extracts identifiers (variable names, table names, column names, etc.) + that are not SQL keywords. + """ + if not content: + return set() + + tokenizer_class = Dialect.get_or_raise(dialect).tokenizer_class + keywords = set() + try: + tokenizer = tokenizer_class() + tokens = tokenizer.tokenize(content) + for token in tokens: + # Don't include keywords in the set + if token.text.upper() not in tokenizer_class.KEYWORDS: + keywords.add(token.text) + + except Exception: + # If tokenization fails, return an empty set + pass + + return keywords diff --git a/sqlmesh/lsp/context.py b/sqlmesh/lsp/context.py new file mode 100644 index 0000000000..a94db7c421 --- /dev/null +++ b/sqlmesh/lsp/context.py @@ -0,0 +1,544 @@ +from dataclasses import dataclass +from pathlib import Path +from pygls.server import LanguageServer +from sqlmesh.core.context import Context +import typing as t +from sqlmesh.core.linter.rule import Range +from sqlmesh.core.model.definition import SqlModel, ExternalModel +from sqlmesh.core.linter.definition import AnnotatedRuleViolation +from sqlmesh.core.schema_loader import get_columns +from sqlmesh.lsp.commands import 
EXTERNAL_MODEL_UPDATE_COLUMNS +from sqlmesh.lsp.custom import ModelForRendering, TestEntry, RunTestResponse +from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry +from sqlmesh.lsp.tests_ranges import get_test_ranges +from sqlmesh.lsp.helpers import to_lsp_range +from sqlmesh.lsp.uri import URI +from lsprotocol import types +from sqlmesh.utils import yaml +from sqlmesh.utils.lineage import get_yaml_model_name_ranges + + +@dataclass +class ModelTarget: + """Information about models in a file.""" + + names: t.List[str] + + +@dataclass +class AuditTarget: + """Information about standalone audits in a file.""" + + name: str + + +class LSPContext: + """ + A context that is used for linting. It contains the context and a reverse map of file uri to + model names and standalone audit names. + """ + + map: t.Dict[Path, t.Union[ModelTarget, AuditTarget]] + _render_cache: t.Dict[Path, t.List[RenderModelEntry]] + _lint_cache: t.Dict[Path, t.List[AnnotatedRuleViolation]] + + def __init__(self, context: Context) -> None: + self.context = context + self._render_cache = {} + self._lint_cache = {} + + # Add models to the map + model_map: t.Dict[Path, ModelTarget] = {} + for model in context.models.values(): + if model._path is not None: + uri = model._path + if uri in model_map: + model_map[uri].names.append(model.name) + else: + model_map[uri] = ModelTarget(names=[model.name]) + + # Add standalone audits to the map + audit_map: t.Dict[Path, AuditTarget] = {} + for audit in context.standalone_audits.values(): + if audit._path is not None: + uri = audit._path + if uri not in audit_map: + audit_map[uri] = AuditTarget(name=audit.name) + + self.map: t.Dict[Path, t.Union[ModelTarget, AuditTarget]] = { + **model_map, + **audit_map, + } + + def list_workspace_tests(self) -> t.List[TestEntry]: + """List all tests in the workspace.""" + tests = self.context.select_tests() + + # Use a set to ensure unique URIs + unique_test_uris = {URI.from_path(test.path).value for test in 
tests} + test_uris: t.Dict[str, t.Dict[str, Range]] = {} + for uri in unique_test_uris: + test_ranges = get_test_ranges(URI(uri).to_path()) + if uri not in test_uris: + test_uris[uri] = {} + + test_uris[uri].update(test_ranges) + + return [ + TestEntry( + name=test.test_name, + uri=URI.from_path(test.path).value, + range=test_uris.get(URI.from_path(test.path).value, {}).get(test.test_name), + ) + for test in tests + ] + + def get_document_tests(self, uri: URI) -> t.List[TestEntry]: + """Get tests for a specific document. + + Args: + uri: The URI of the file to get tests for. + + Returns: + List of TestEntry objects for the specified document. + """ + tests = self.context.select_tests(tests=[str(uri.to_path())]) + test_ranges = get_test_ranges(uri.to_path()) + return [ + TestEntry( + name=test.test_name, + uri=URI.from_path(test.path).value, + range=test_ranges.get(test.test_name), + ) + for test in tests + ] + + def run_test(self, uri: URI, test_name: str) -> RunTestResponse: + """Run a specific test for a model. + + Args: + uri: The URI of the file containing the test. + test_name: The name of the test to run. + + Returns: + List of annotated rule violations from the test run. + """ + path = uri.to_path() + results = self.context.test( + tests=[str(path)], + match_patterns=[test_name], + ) + if results.testsRun != 1: + raise ValueError(f"Expected to run 1 test, but ran {results.testsRun} tests.") + if len(results.successes) == 1: + return RunTestResponse(success=True) + return RunTestResponse( + success=False, + error_message=str(results.failures[0][1]), + ) + + def render_model(self, uri: URI) -> t.List[RenderModelEntry]: + """Get rendered models for a file, using cache when available. + + Args: + uri: The URI of the file to render. + + Returns: + List of rendered model entries. 
+ """ + path = uri.to_path() + + # Check cache first + if path in self._render_cache: + return self._render_cache[path] + + # If not cached, render and cache + entries: t.List[RenderModelEntry] = [] + target = self.map.get(path) + + if isinstance(target, AuditTarget): + audit = self.context.standalone_audits[target.name] + definition = audit.render_definition( + include_python=False, + render_query=True, + ) + rendered_query = [ + render.sql(dialect=audit.dialect, pretty=True) for render in definition + ] + entry = RenderModelEntry( + name=audit.name, + fqn=audit.fqn, + description=audit.description, + rendered_query="\n\n".join(rendered_query), + ) + entries.append(entry) + + elif isinstance(target, ModelTarget): + for name in target.names: + model = self.context.get_model(name) + if isinstance(model, SqlModel): + rendered_query = [ + render.sql(dialect=model.dialect, pretty=True) + for render in model.render_definition( + include_python=False, + render_query=True, + ) + ] + entry = RenderModelEntry( + name=model.name, + fqn=model.fqn, + description=model.description, + rendered_query="\n\n".join(rendered_query), + ) + entries.append(entry) + + # Store in cache + self._render_cache[path] = entries + return entries + + def lint_model(self, uri: URI) -> t.List[AnnotatedRuleViolation]: + """Get lint diagnostics for a model, using cache when available. + + Args: + uri: The URI of the file to lint. + + Returns: + List of annotated rule violations. 
+ """ + path = uri.to_path() + + # Check cache first + if path in self._lint_cache: + return self._lint_cache[path] + + # If not cached, lint and cache + target = self.map.get(path) + if target is None or not isinstance(target, ModelTarget): + return [] + + diagnostics = self.context.lint_models( + target.names, + raise_on_error=False, + ) + + # Store in cache + self._lint_cache[path] = diagnostics + return diagnostics + + def get_code_actions( + self, uri: URI, params: types.CodeActionParams + ) -> t.Optional[t.List[t.Union[types.Command, types.CodeAction]]]: + """Get code actions for a file.""" + + # Get the violations (which contain the fixes) + violations = self.lint_model(uri) + + # Convert violations to a map for quick lookup + # Use a hashable representation of Range as the key + violation_map: t.Dict[ + t.Tuple[str, t.Tuple[int, int, int, int]], AnnotatedRuleViolation + ] = {} + for violation in violations: + if violation.violation_range: + lsp_diagnostic = self.diagnostic_to_lsp_diagnostic(violation) + if lsp_diagnostic: + # Create a hashable key from the diagnostic message and range + key = ( + lsp_diagnostic.message, + ( + lsp_diagnostic.range.start.line, + lsp_diagnostic.range.start.character, + lsp_diagnostic.range.end.line, + lsp_diagnostic.range.end.character, + ), + ) + violation_map[key] = violation + + # Get diagnostics in the requested range + diagnostics = params.context.diagnostics if params.context else [] + + code_actions: t.List[t.Union[types.Command, types.CodeAction]] = [] + + for diagnostic in diagnostics: + # Find the corresponding violation + key = ( + diagnostic.message, + ( + diagnostic.range.start.line, + diagnostic.range.start.character, + diagnostic.range.end.line, + diagnostic.range.end.character, + ), + ) + found_violation = violation_map.get(key) + + if found_violation is not None and found_violation.fixes: + # Create code actions for each fix + for fix in found_violation.fixes: + changes: t.Dict[str, t.List[types.TextEdit]] = 
{} + document_changes: t.List[ + t.Union[ + types.TextDocumentEdit, + types.CreateFile, + types.RenameFile, + types.DeleteFile, + ] + ] = [] + + for create in fix.create_files: + create_uri = URI.from_path(create.path).value + document_changes.append(types.CreateFile(uri=create_uri)) + document_changes.append( + types.TextDocumentEdit( + text_document=types.OptionalVersionedTextDocumentIdentifier( + uri=create_uri, + version=None, + ), + edits=[ + types.TextEdit( + range=types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=0, character=0), + ), + new_text=create.text, + ) + ], + ) + ) + + for edit in fix.edits: + uri_key = URI.from_path(edit.path).value + if uri_key not in changes: + changes[uri_key] = [] + changes[uri_key].append( + types.TextEdit( + range=types.Range( + start=types.Position( + line=edit.range.start.line, + character=edit.range.start.character, + ), + end=types.Position( + line=edit.range.end.line, + character=edit.range.end.character, + ), + ), + new_text=edit.new_text, + ) + ) + + workspace_edit = types.WorkspaceEdit( + changes=changes if changes else None, + document_changes=document_changes if document_changes else None, + ) + code_action = types.CodeAction( + title=fix.title, + kind=types.CodeActionKind.QuickFix, + diagnostics=[diagnostic], + edit=workspace_edit, + ) + code_actions.append(code_action) + + return code_actions if code_actions else None + + def get_code_lenses(self, uri: URI) -> t.Optional[t.List[types.CodeLens]]: + models_in_file = self.map.get(uri.to_path()) + if isinstance(models_in_file, ModelTarget): + models = [self.context.get_model(model) for model in models_in_file.names] + if any(isinstance(model, ExternalModel) for model in models): + code_lenses = self._get_external_model_code_lenses(uri) + if code_lenses: + return code_lenses + + return None + + def _get_external_model_code_lenses(self, uri: URI) -> t.List[types.CodeLens]: + """Get code lenses for external models YAML files.""" + 
ranges = get_yaml_model_name_ranges(uri.to_path()) + if ranges is None: + return [] + return [ + types.CodeLens( + range=to_lsp_range(range), + command=types.Command( + title="Update Columns", + command=EXTERNAL_MODEL_UPDATE_COLUMNS, + arguments=[ + name, + ], + ), + ) + for name, range in ranges.items() + ] + + def list_of_models_for_rendering(self) -> t.List[ModelForRendering]: + """Get a list of models for rendering. + + Returns: + List of ModelForRendering objects. + """ + return [ + ModelForRendering( + name=model.name, + fqn=model.fqn, + description=model.description, + uri=URI.from_path(model._path).value, + ) + for model in self.context.models.values() + if isinstance(model, SqlModel) and model._path is not None + ] + [ + ModelForRendering( + name=audit.name, + fqn=audit.fqn, + description=audit.description, + uri=URI.from_path(audit._path).value, + ) + for audit in self.context.standalone_audits.values() + if audit._path is not None + ] + + @staticmethod + def get_completions( + self: t.Optional["LSPContext"] = None, + uri: t.Optional[URI] = None, + file_content: t.Optional[str] = None, + ) -> AllModelsResponse: + """Get completion suggestions for a file""" + from sqlmesh.lsp.completions import get_sql_completions + + return get_sql_completions(self, uri, file_content) + + @staticmethod + def diagnostics_to_lsp_diagnostics( + diagnostics: t.List[AnnotatedRuleViolation], + ) -> t.List[types.Diagnostic]: + """ + Converts a list of AnnotatedRuleViolations to a list of LSP diagnostics. It will remove duplicates based on the message and range. 
+ """ + lsp_diagnostics = {} + for diagnostic in diagnostics: + lsp_diagnostic = LSPContext.diagnostic_to_lsp_diagnostic(diagnostic) + if lsp_diagnostic is not None: + # Create a unique key combining message and range + diagnostic_key = ( + lsp_diagnostic.message, + lsp_diagnostic.range.start.line, + lsp_diagnostic.range.start.character, + lsp_diagnostic.range.end.line, + lsp_diagnostic.range.end.character, + ) + if diagnostic_key not in lsp_diagnostics: + lsp_diagnostics[diagnostic_key] = lsp_diagnostic + return list(lsp_diagnostics.values()) + + @staticmethod + def diagnostic_to_lsp_diagnostic( + diagnostic: AnnotatedRuleViolation, + ) -> t.Optional[types.Diagnostic]: + if diagnostic.model._path is None: + return None + if not diagnostic.violation_range: + with open(diagnostic.model._path, "r", encoding="utf-8") as file: + lines = file.readlines() + diagnostic_range = types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=len(lines) - 1, character=len(lines[-1])), + ) + else: + diagnostic_range = types.Range( + start=types.Position( + line=diagnostic.violation_range.start.line, + character=diagnostic.violation_range.start.character, + ), + end=types.Position( + line=diagnostic.violation_range.end.line, + character=diagnostic.violation_range.end.character, + ), + ) + + # Get rule definition location for diagnostics link + rule_location = diagnostic.rule.get_definition_location() + rule_uri_wihout_extension = URI.from_path(rule_location.file_path) + rule_uri = f"{rule_uri_wihout_extension.value}#L{rule_location.start_line}" + + # Use URI format to create a link for "related information" + return types.Diagnostic( + range=diagnostic_range, + message=diagnostic.violation_msg, + severity=types.DiagnosticSeverity.Error + if diagnostic.violation_type == "error" + else types.DiagnosticSeverity.Warning, + source="sqlmesh", + code=diagnostic.rule.name, + code_description=types.CodeDescription(href=rule_uri), + ) + + def 
update_external_model_columns(self, ls: LanguageServer, uri: URI, model_name: str) -> bool: + """ + Update the columns for an external model in the YAML file. Returns True if changed, False if didn't because + of the columns already being up to date. + + In this case, the model name is the name of the external model as is defined in the YAML file, not any other version of it. + + Errors still throw exceptions to be handled by the caller. + """ + models = yaml.load(uri.to_path()) + if not isinstance(models, list): + raise ValueError( + f"Expected a list of models in {uri.to_path()}, but got {type(models).__name__}" + ) + + existing_model = next((model for model in models if model.get("name") == model_name), None) + if existing_model is None: + raise ValueError(f"Could not find model {model_name} in {uri.to_path()}") + + existing_model_columns = existing_model.get("columns") + + # Get the adapter and fetch columns + adapter = self.context.engine_adapter + # Get columns for the model + new_columns = get_columns( + adapter=adapter, + dialect=self.context.config.model_defaults.dialect, + table=model_name, + strict=True, + ) + # Compare existing columns and matching types and if they are the same, do not update + if existing_model_columns is not None: + if existing_model_columns == new_columns: + return False + + # Model index to update + model_index = next( + (i for i, model in enumerate(models) if model.get("name") == model_name), None + ) + if model_index is None: + raise ValueError(f"Could not find model {model_name} in {uri.to_path()}") + + # Get end of the file to set the edit range + with open(uri.to_path(), "r", encoding="utf-8") as file: + read_file = file.read() + + end_line = read_file.count("\n") + end_character = len(read_file.splitlines()[-1]) if end_line > 0 else 0 + + models[model_index]["columns"] = new_columns + edit = types.TextDocumentEdit( + text_document=types.OptionalVersionedTextDocumentIdentifier( + uri=uri.value, + version=None, + ), + edits=[ + 
types.TextEdit( + range=types.Range( + start=types.Position(line=0, character=0), + end=types.Position( + line=end_line, + character=end_character, + ), + ), + new_text=yaml.dump(models), + ) + ], + ) + ls.apply_edit(types.WorkspaceEdit(document_changes=[edit])) + return True diff --git a/sqlmesh/lsp/custom.py b/sqlmesh/lsp/custom.py new file mode 100644 index 0000000000..84be43ee0e --- /dev/null +++ b/sqlmesh/lsp/custom.py @@ -0,0 +1,257 @@ +from lsprotocol import types +import typing as t + +from sqlmesh.core.linter.rule import Range +from sqlmesh.utils.pydantic import PydanticModel + + +class CustomMethodRequestBaseClass(PydanticModel): + pass + + +class CustomMethodResponseBaseClass(PydanticModel): + # Prefixing, so guaranteed not to collide + response_error: t.Optional[str] = None + + +ALL_MODELS_FEATURE = "sqlmesh/all_models" + + +class AllModelsRequest(CustomMethodRequestBaseClass): + """ + Request to get all the models that are in the current project. + """ + + textDocument: types.TextDocumentIdentifier + + +class MacroCompletion(PydanticModel): + """Information about a macro for autocompletion.""" + + name: str + description: t.Optional[str] = None + + +class ModelCompletion(PydanticModel): + """Information about a model for autocompletion.""" + + name: str + description: t.Optional[str] = None + + +class AllModelsResponse(CustomMethodResponseBaseClass): + """Response to get all models that are in the current project.""" + + #: Deprecated: use ``model_completions`` instead + models: t.List[str] + model_completions: t.List[ModelCompletion] + keywords: t.List[str] + macros: t.List[MacroCompletion] + + +RENDER_MODEL_FEATURE = "sqlmesh/render_model" + + +class RenderModelRequest(CustomMethodRequestBaseClass): + textDocumentUri: str + + +class RenderModelEntry(PydanticModel): + """ + An entry in the rendered model. 
+ """ + + name: str + fqn: str + description: t.Optional[str] = None + rendered_query: str + + +class RenderModelResponse(CustomMethodResponseBaseClass): + """ + Response to render a model. + """ + + models: t.List[RenderModelEntry] + + +ALL_MODELS_FOR_RENDER_FEATURE = "sqlmesh/all_models_for_render" + + +class ModelForRendering(PydanticModel): + """ + A model that is available for rendering. + """ + + name: str + fqn: str + description: t.Optional[str] = None + uri: str + + +class AllModelsForRenderRequest(CustomMethodRequestBaseClass): + pass + + +class AllModelsForRenderResponse(CustomMethodResponseBaseClass): + """ + Response to get all the models that are in the current project for rendering purposes. + """ + + models: t.List[ModelForRendering] + + +SUPPORTED_METHODS_FEATURE = "sqlmesh/supported_methods" + + +class SupportedMethodsRequest(PydanticModel): + """ + Request to get all supported custom LSP methods. + """ + + pass + + +class CustomMethod(PydanticModel): + """ + Information about a custom LSP method. + """ + + name: str + + +class SupportedMethodsResponse(CustomMethodResponseBaseClass): + """ + Response containing all supported custom LSP methods. + """ + + methods: t.List[CustomMethod] + + +FORMAT_PROJECT_FEATURE = "sqlmesh/format_project" + + +class FormatProjectRequest(CustomMethodRequestBaseClass): + """ + Request to format all models in the current project. + """ + + pass + + +class FormatProjectResponse(CustomMethodResponseBaseClass): + """ + Response to format project request. + """ + + pass + + +LIST_WORKSPACE_TESTS_FEATURE = "sqlmesh/list_workspace_tests" + + +class ListWorkspaceTestsRequest(CustomMethodRequestBaseClass): + """ + Request to list all tests in the current project. + """ + + pass + + +GET_ENVIRONMENTS_FEATURE = "sqlmesh/get_environments" + + +class GetEnvironmentsRequest(CustomMethodRequestBaseClass): + """ + Request to get all environments in the current project. 
+ """ + + pass + + +class TestEntry(PydanticModel): + """ + An entry representing a test in the workspace. + """ + + name: str + uri: str + range: Range + + +class ListWorkspaceTestsResponse(CustomMethodResponseBaseClass): + tests: t.List[TestEntry] + + +LIST_DOCUMENT_TESTS_FEATURE = "sqlmesh/list_document_tests" + + +class ListDocumentTestsRequest(CustomMethodRequestBaseClass): + textDocument: types.TextDocumentIdentifier + + +class ListDocumentTestsResponse(CustomMethodResponseBaseClass): + tests: t.List[TestEntry] + + +RUN_TEST_FEATURE = "sqlmesh/run_test" + + +class RunTestRequest(CustomMethodRequestBaseClass): + textDocument: types.TextDocumentIdentifier + testName: str + + +class RunTestResponse(CustomMethodResponseBaseClass): + success: bool + error_message: t.Optional[str] = None + + +class EnvironmentInfo(PydanticModel): + """ + Information about an environment. + """ + + name: str + snapshots: t.List[str] + start_at: str + plan_id: str + + +class GetEnvironmentsResponse(CustomMethodResponseBaseClass): + """ + Response containing all environments in the current project. + """ + + environments: t.Dict[str, EnvironmentInfo] + pinned_environments: t.Set[str] + default_target_environment: str + + +GET_MODELS_FEATURE = "sqlmesh/get_models" + + +class GetModelsRequest(CustomMethodRequestBaseClass): + """ + Request to get all models available for table diff. + """ + + pass + + +class ModelInfo(PydanticModel): + """ + Information about a model for table diff. + """ + + name: str + fqn: str + description: t.Optional[str] = None + + +class GetModelsResponse(CustomMethodResponseBaseClass): + """ + Response containing all models available for table diff. 
+ """ + + models: t.List[ModelInfo] diff --git a/sqlmesh/lsp/errors.py b/sqlmesh/lsp/errors.py new file mode 100644 index 0000000000..a9e778a555 --- /dev/null +++ b/sqlmesh/lsp/errors.py @@ -0,0 +1,51 @@ +from lsprotocol.types import Diagnostic, DiagnosticSeverity, Range, Position + +from sqlmesh.lsp.uri import URI +from sqlmesh.utils.errors import ( + ConfigError, +) +import typing as t + +ContextFailedError = t.Union[str, ConfigError, Exception] + + +def context_error_to_diagnostic( + error: t.Union[Exception, ContextFailedError], + uri_filter: t.Optional[URI] = None, +) -> t.Tuple[t.Optional[t.Tuple[str, Diagnostic]], ContextFailedError]: + """ + Convert an error to a diagnostic message. + If the error is a ConfigError, it will be converted to a diagnostic message. + + uri_filter is used to filter diagnostics by URI. If present, only diagnostics + with a matching URI will be returned. + """ + if isinstance(error, ConfigError): + return config_error_to_diagnostic(error), error + return None, str(error) + + +def config_error_to_diagnostic( + error: ConfigError, + uri_filter: t.Optional[URI] = None, +) -> t.Optional[t.Tuple[str, Diagnostic]]: + if error.location is None: + return None + uri = URI.from_path(error.location).value + if uri_filter and uri != uri_filter.value: + return None + return uri, Diagnostic( + range=Range( + start=Position( + line=0, + character=0, + ), + end=Position( + line=0, + character=0, + ), + ), + message=str(error), + severity=DiagnosticSeverity.Error, + source="SQLMesh", + ) diff --git a/sqlmesh/lsp/helpers.py b/sqlmesh/lsp/helpers.py new file mode 100644 index 0000000000..920a93f5c7 --- /dev/null +++ b/sqlmesh/lsp/helpers.py @@ -0,0 +1,40 @@ +from lsprotocol.types import Range, Position + +from sqlmesh.core.linter.helpers import ( + Range as SQLMeshRange, + Position as SQLMeshPosition, +) + + +def to_sqlmesh_position(position: Position) -> SQLMeshPosition: + """ + Converts an LSP Position to a SQLMesh Position. 
+ """ + return SQLMeshPosition(line=position.line, character=position.character) + + +def to_lsp_position(position: SQLMeshPosition) -> Position: + """ + Converts a SQLMesh Position to an LSP Position. + """ + return Position(line=position.line, character=position.character) + + +def to_sqlmesh_range(range: Range) -> SQLMeshRange: + """ + Converts an LSP Range to a SQLMesh Range. + """ + return SQLMeshRange( + start=to_sqlmesh_position(range.start), + end=to_sqlmesh_position(range.end), + ) + + +def to_lsp_range(range: SQLMeshRange) -> Range: + """ + Converts a SQLMesh Range to an LSP Range. + """ + return Range( + start=to_lsp_position(range.start), + end=to_lsp_position(range.end), + ) diff --git a/sqlmesh/lsp/hints.py b/sqlmesh/lsp/hints.py new file mode 100644 index 0000000000..a8d56e2f31 --- /dev/null +++ b/sqlmesh/lsp/hints.py @@ -0,0 +1,136 @@ +"""Type hinting on SQLMesh models""" + +import typing as t + +from lsprotocol import types + +from sqlglot import exp +from sqlglot.expressions import Expression +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers +from sqlmesh.core.model.definition import SqlModel +from sqlmesh.lsp.context import LSPContext, ModelTarget +from sqlmesh.lsp.uri import URI + + +def get_hints( + lsp_context: LSPContext, + document_uri: URI, + start_line: int, + end_line: int, +) -> t.List[types.InlayHint]: + """ + Get type hints for certain lines in a document + + Args: + lint_context: The LSP context + document_uri: The URI of the document + start_line: the starting line to get hints for + end_line: the ending line to get hints for + + Returns: + A list of hints to apply to the document + """ + path = document_uri.to_path() + if path.suffix != ".sql": + return [] + + if path not in lsp_context.map: + return [] + + file_info = lsp_context.map[path] + + # Process based on whether it's a model or standalone audit + if not isinstance(file_info, ModelTarget): + return [] + + # It's a model + model = 
lsp_context.context.get_model( + model_or_snapshot=file_info.names[0], raise_if_missing=False + ) + if not isinstance(model, SqlModel): + return [] + + query = model.query + dialect = model.dialect + columns_to_types = model.columns_to_types or {} + + return _get_type_hints_for_model_from_query( + query, dialect, columns_to_types, start_line, end_line + ) + + +def _get_type_hints_for_select( + expression: exp.Expression, + dialect: str, + columns_to_types: t.Dict[str, exp.DataType], + start_line: int, + end_line: int, +) -> t.List[types.InlayHint]: + hints: t.List[types.InlayHint] = [] + + for select_exp in expression.expressions: + if isinstance(select_exp, exp.Alias): + if isinstance(select_exp.this, exp.Cast): + continue + + meta = select_exp.args["alias"].meta + + elif isinstance(select_exp, exp.Column): + meta = select_exp.parts[-1].meta + else: + continue + + if "line" not in meta or "col" not in meta: + continue + + line = meta["line"] + col = meta["col"] + + # Lines from sqlglot are 1 based + line -= 1 + + if line < start_line or line > end_line: + continue + + name = select_exp.alias_or_name + data_type = columns_to_types.get(name) + + if not data_type or data_type.is_type(exp.DataType.Type.UNKNOWN): + continue + + type_label = data_type.sql(dialect) + hints.append( + types.InlayHint( + label=f"::{type_label}", + kind=types.InlayHintKind.Type, + padding_left=False, + padding_right=True, + position=types.Position(line=line, character=col), + ) + ) + + return hints + + +def _get_type_hints_for_model_from_query( + query: Expression, + dialect: str, + columns_to_types: t.Dict[str, exp.DataType], + start_line: int, + end_line: int, +) -> t.List[types.InlayHint]: + hints: t.List[types.InlayHint] = [] + try: + query = normalize_identifiers(query.copy(), dialect=dialect) + + # Return the hints for top level selects (model definition columns only) + return [ + hint + for q in query.walk(prune=lambda n: not isinstance(n, exp.SetOperation)) + if isinstance(select := 
q.unnest(), exp.Select) + for hint in _get_type_hints_for_select( + q, dialect, columns_to_types, start_line, end_line + ) + ] + except Exception: + return [] diff --git a/sqlmesh/lsp/main.py b/sqlmesh/lsp/main.py new file mode 100755 index 0000000000..71dc5e1e2b --- /dev/null +++ b/sqlmesh/lsp/main.py @@ -0,0 +1,1204 @@ +#!/usr/bin/env python +"""A Language Server Protocol (LSP) server for SQL with SQLMesh integration, refactored without globals.""" + +from itertools import chain +import logging +import typing as t +from pathlib import Path +import urllib.parse +import uuid + +from lsprotocol import types +from lsprotocol.types import ( + WorkspaceDiagnosticRefreshRequest, + WorkspaceInlayHintRefreshRequest, +) +from pygls.server import LanguageServer +from sqlglot import exp +from sqlmesh._version import __version__ +from sqlmesh.core.context import Context +from sqlmesh.utils.date import to_timestamp +from sqlmesh.lsp.api import ( + API_FEATURE, + ApiRequest, + ApiResponseGetColumnLineage, + ApiResponseGetLineage, + ApiResponseGetModels, + ApiResponseGetTableDiff, +) + +from sqlmesh.lsp.commands import EXTERNAL_MODEL_UPDATE_COLUMNS +from sqlmesh.lsp.completions import get_sql_completions +from sqlmesh.lsp.context import ( + LSPContext, + ModelTarget, +) +from sqlmesh.lsp.custom import ( + ALL_MODELS_FEATURE, + ALL_MODELS_FOR_RENDER_FEATURE, + RENDER_MODEL_FEATURE, + SUPPORTED_METHODS_FEATURE, + FORMAT_PROJECT_FEATURE, + GET_ENVIRONMENTS_FEATURE, + GET_MODELS_FEATURE, + AllModelsRequest, + AllModelsResponse, + AllModelsForRenderRequest, + AllModelsForRenderResponse, + CustomMethodResponseBaseClass, + RenderModelRequest, + RenderModelResponse, + SupportedMethodsRequest, + SupportedMethodsResponse, + FormatProjectRequest, + FormatProjectResponse, + CustomMethod, + LIST_WORKSPACE_TESTS_FEATURE, + ListWorkspaceTestsRequest, + ListWorkspaceTestsResponse, + LIST_DOCUMENT_TESTS_FEATURE, + ListDocumentTestsRequest, + ListDocumentTestsResponse, + RUN_TEST_FEATURE, + 
RunTestRequest, + RunTestResponse, + GetEnvironmentsRequest, + GetEnvironmentsResponse, + EnvironmentInfo, + GetModelsRequest, + GetModelsResponse, + ModelInfo, +) +from sqlmesh.lsp.errors import ContextFailedError, context_error_to_diagnostic +from sqlmesh.lsp.helpers import to_lsp_range, to_sqlmesh_position +from sqlmesh.lsp.hints import get_hints +from sqlmesh.lsp.reference import ( + CTEReference, + ModelReference, + get_references, + get_all_references, +) +from sqlmesh.lsp.rename import prepare_rename, rename_symbol, get_document_highlights +from sqlmesh.lsp.uri import URI +from sqlmesh.utils.errors import ConfigError +from sqlmesh.utils.lineage import ExternalModelReference +from sqlmesh.utils.pydantic import PydanticModel +from web.server.api.endpoints.lineage import column_lineage, model_lineage +from web.server.api.endpoints.models import get_models +from web.server.api.endpoints.table_diff import _process_sample_data +from typing import Union +from dataclasses import dataclass, field + +from web.server.models import RowDiff, SchemaDiff, TableDiff + + +class InitializationOptions(PydanticModel): + """Initialization options for the SQLMesh Language Server, that + are passed from the client to the server.""" + + project_paths: t.Optional[t.List[str]] = None + + +@dataclass +class NoContext: + """State when no context has been attempted to load.""" + + version_id: str = field(default_factory=lambda: str(uuid.uuid4())) + + +@dataclass +class ContextLoaded: + """State when context has been successfully loaded.""" + + lsp_context: LSPContext + version_id: str = field(default_factory=lambda: str(uuid.uuid4())) + + +@dataclass +class ContextFailed: + """State when context failed to load with an error message.""" + + error: ContextFailedError + context: t.Optional[Context] = None + version_id: str = field(default_factory=lambda: str(uuid.uuid4())) + + +ContextState = Union[NoContext, ContextLoaded, ContextFailed] + + +class SQLMeshLanguageServer: + # Specified 
folders take precedence over workspace folders or looking + # for a config files. They are explicitly set by the user and optionally + # pass in at init + specified_paths: t.Optional[t.List[Path]] = None + + def __init__( + self, + context_class: t.Type[Context], + server_name: str = "sqlmesh_lsp", + version: str = __version__, + ): + """ + :param context_class: A class that inherits from `Context`. + :param server_name: Name for the language server. + :param version: Version string. + """ + self.server = LanguageServer(server_name, version, max_workers=1) + self.context_class = context_class + self.context_state: ContextState = NoContext() + self.workspace_folders: t.List[Path] = [] + + self.has_raised_loading_error: bool = False + + self.client_supports_pull_diagnostics = False + self._supported_custom_methods: t.Dict[ + str, + t.Callable[ + # mypy unable to recognize the base class + [LanguageServer, t.Any], + t.Any, + ], + ] = { + ALL_MODELS_FEATURE: self._custom_all_models, + RENDER_MODEL_FEATURE: self._custom_render_model, + ALL_MODELS_FOR_RENDER_FEATURE: self._custom_all_models_for_render, + API_FEATURE: self._custom_api, + SUPPORTED_METHODS_FEATURE: self._custom_supported_methods, + FORMAT_PROJECT_FEATURE: self._custom_format_project, + LIST_WORKSPACE_TESTS_FEATURE: self._list_workspace_tests, + LIST_DOCUMENT_TESTS_FEATURE: self._list_document_tests, + RUN_TEST_FEATURE: self._run_test, + GET_ENVIRONMENTS_FEATURE: self._custom_get_environments, + GET_MODELS_FEATURE: self._custom_get_models, + } + + # Register LSP features (e.g., formatting, hover, etc.) 
+ self._register_features() + + def _list_workspace_tests( + self, + ls: LanguageServer, + params: ListWorkspaceTestsRequest, + ) -> ListWorkspaceTestsResponse: + """List all tests in the current workspace.""" + try: + context = self._context_get_or_load() + tests = context.list_workspace_tests() + return ListWorkspaceTestsResponse(tests=tests) + except Exception as e: + ls.log_trace(f"Error listing workspace tests: {e}") + return ListWorkspaceTestsResponse(tests=[]) + + def _list_document_tests( + self, + ls: LanguageServer, + params: ListDocumentTestsRequest, + ) -> ListDocumentTestsResponse: + """List tests for a specific document.""" + try: + uri = URI(params.textDocument.uri) + context = self._context_get_or_load(uri) + tests = context.get_document_tests(uri) + return ListDocumentTestsResponse(tests=tests) + except Exception as e: + ls.log_trace(f"Error listing document tests: {e}") + return ListDocumentTestsResponse(tests=[]) + + def _run_test( + self, + ls: LanguageServer, + params: RunTestRequest, + ) -> RunTestResponse: + """Run a specific test.""" + try: + uri = URI(params.textDocument.uri) + context = self._context_get_or_load(uri) + result = context.run_test(uri, params.testName) + return result + except Exception as e: + ls.log_trace(f"Error running test: {e}") + return RunTestResponse(success=False, response_error=str(e)) + + # All the custom LSP methods are registered here and prefixed with _custom + def _custom_all_models(self, ls: LanguageServer, params: AllModelsRequest) -> AllModelsResponse: + uri = URI(params.textDocument.uri) + # Get the document content + content = None + try: + document = ls.workspace.get_text_document(params.textDocument.uri) + content = document.source + except Exception: + pass + try: + context = self._context_get_or_load(uri) + return LSPContext.get_completions(context, uri, content) + except Exception as e: + from sqlmesh.lsp.completions import get_sql_completions + + return get_sql_completions(None, 
URI(params.textDocument.uri), content) + + def _custom_render_model( + self, ls: LanguageServer, params: RenderModelRequest + ) -> RenderModelResponse: + uri = URI(params.textDocumentUri) + context = self._context_get_or_load(uri) + return RenderModelResponse(models=context.render_model(uri)) + + def _custom_all_models_for_render( + self, ls: LanguageServer, params: AllModelsForRenderRequest + ) -> AllModelsForRenderResponse: + context = self._context_get_or_load() + return AllModelsForRenderResponse(models=context.list_of_models_for_rendering()) + + def _custom_format_project( + self, ls: LanguageServer, params: FormatProjectRequest + ) -> FormatProjectResponse: + """Format all models in the current project.""" + try: + context = self._context_get_or_load() + context.context.format() + return FormatProjectResponse() + except Exception as e: + ls.log_trace(f"Error formatting project: {e}") + return FormatProjectResponse() + + def _custom_get_environments( + self, ls: LanguageServer, params: GetEnvironmentsRequest + ) -> GetEnvironmentsResponse: + """Get all environments in the current project.""" + try: + context = self._context_get_or_load() + environments = {} + + # Get environments from state + for env in context.context.state_reader.get_environments(): + environments[env.name] = EnvironmentInfo( + name=env.name, + snapshots=[s.fingerprint.to_identifier() for s in env.snapshots], + start_at=str(to_timestamp(env.start_at)), + plan_id=env.plan_id or "", + ) + + return GetEnvironmentsResponse( + environments=environments, + pinned_environments=context.context.config.pinned_environments, + default_target_environment=context.context.config.default_target_environment, + ) + except Exception as e: + ls.log_trace(f"Error getting environments: {e}") + return GetEnvironmentsResponse( + response_error=str(e), + environments={}, + pinned_environments=set(), + default_target_environment="", + ) + + def _custom_get_models(self, ls: LanguageServer, params: GetModelsRequest) -> 
GetModelsResponse: + """Get all models available for table diff.""" + try: + context = self._context_get_or_load() + models = [ + ModelInfo( + name=model.name, + fqn=model.fqn, + description=model.description, + ) + for model in context.context.models.values() + # Filter for models that are suitable for table diff + if model._path is not None # Has a file path + ] + return GetModelsResponse(models=models) + except Exception as e: + ls.log_trace(f"Error getting table diff models: {e}") + return GetModelsResponse( + response_error=str(e), + models=[], + ) + + def _custom_api( + self, ls: LanguageServer, request: ApiRequest + ) -> t.Union[ + ApiResponseGetModels, + ApiResponseGetColumnLineage, + ApiResponseGetLineage, + ApiResponseGetTableDiff, + ]: + ls.log_trace(f"API request: {request}") + context = self._context_get_or_load() + + parsed_url = urllib.parse.urlparse(request.url) + path_parts = parsed_url.path.strip("/").split("/") + + if request.method == "GET": + if path_parts == ["api", "models"]: + # /api/models + return ApiResponseGetModels(data=get_models(context.context)) + + if path_parts[:2] == ["api", "lineage"]: + if len(path_parts) == 3: + # /api/lineage/{model} + model_name = urllib.parse.unquote(path_parts[2]) + lineage = model_lineage(model_name, context.context) + non_set_lineage = {k: v for k, v in lineage.items() if v is not None} + return ApiResponseGetLineage(data=non_set_lineage) + + if len(path_parts) == 4: + # /api/lineage/{model}/{column} + model_name = urllib.parse.unquote(path_parts[2]) + column = urllib.parse.unquote(path_parts[3]) + models_only = False + if hasattr(request, "params"): + models_only = bool(getattr(request.params, "models_only", False)) + column_lineage_response = column_lineage( + model_name, column, models_only, context.context + ) + return ApiResponseGetColumnLineage(data=column_lineage_response) + + if path_parts[:2] == ["api", "table_diff"]: + import numpy as np + + # /api/table_diff + params = request.params + 
table_diff_result: t.Optional[TableDiff] = None + if params := request.params: + source = getattr(params, "source", "") if params else "" + target = getattr(params, "target", "") if params else "" + on = getattr(params, "on", None) if params else None + model_or_snapshot = ( + getattr(params, "model_or_snapshot", None) if params else None + ) + where = getattr(params, "where", None) if params else None + temp_schema = getattr(params, "temp_schema", None) if params else None + limit = getattr(params, "limit", 20) if params else 20 + + table_diffs = context.context.table_diff( + source=source, + target=target, + on=exp.condition(on) if on else None, + select_models={model_or_snapshot} if model_or_snapshot else None, + where=where, + limit=limit, + show=False, + ) + + if table_diffs: + diff = table_diffs[0] if isinstance(table_diffs, list) else table_diffs + + _schema_diff = diff.schema_diff() + _row_diff = diff.row_diff(temp_schema=temp_schema) + schema_diff = SchemaDiff( + source=_schema_diff.source, + target=_schema_diff.target, + source_schema=_schema_diff.source_schema, + target_schema=_schema_diff.target_schema, + added=_schema_diff.added, + removed=_schema_diff.removed, + modified=_schema_diff.modified, + ) + + # create a readable column-centric sample data structure + processed_sample_data = _process_sample_data(_row_diff, source, target) + + row_diff = RowDiff( + source=_row_diff.source, + target=_row_diff.target, + stats=_row_diff.stats, + sample=_row_diff.sample.replace({np.nan: None}).to_dict(), + joined_sample=_row_diff.joined_sample.replace({np.nan: None}).to_dict(), + s_sample=_row_diff.s_sample.replace({np.nan: None}).to_dict(), + t_sample=_row_diff.t_sample.replace({np.nan: None}).to_dict(), + column_stats=_row_diff.column_stats.replace({np.nan: None}).to_dict(), + source_count=_row_diff.source_count, + target_count=_row_diff.target_count, + count_pct_change=_row_diff.count_pct_change, + decimals=getattr(_row_diff, "decimals", 3), + 
processed_sample_data=processed_sample_data, + ) + + s_index, t_index, _ = diff.key_columns + table_diff_result = TableDiff( + schema_diff=schema_diff, + row_diff=row_diff, + on=[(s.name, t.name) for s, t in zip(s_index, t_index)], + ) + return ApiResponseGetTableDiff(data=table_diff_result) + + raise NotImplementedError(f"API request not implemented: {request.url}") + + def _custom_supported_methods( + self, ls: LanguageServer, params: SupportedMethodsRequest + ) -> SupportedMethodsResponse: + """Return all supported custom LSP methods.""" + return SupportedMethodsResponse( + methods=[ + CustomMethod( + name=name, + ) + for name in self._supported_custom_methods + ] + ) + + def _reload_context_and_publish_diagnostics( + self, ls: LanguageServer, uri: URI, document_uri: str + ) -> None: + """Helper method to reload context and publish diagnostics.""" + if isinstance(self.context_state, NoContext): + pass + elif isinstance(self.context_state, ContextFailed): + if self.context_state.context: + try: + self.context_state.context.load() + # Creating a new LSPContext will naturally create fresh caches + self.context_state = ContextLoaded( + lsp_context=LSPContext(self.context_state.context) + ) + except Exception as e: + ls.log_trace(f"Error loading context: {e}") + context = ( + self.context_state.context + if hasattr(self.context_state, "context") + else None + ) + self.context_state = ContextFailed(error=e, context=context) + else: + # If there's no context, reset to NoContext and try to create one from scratch + ls.log_trace("No partial context available, attempting fresh creation") + self.context_state = NoContext() + self.has_raised_loading_error = False # Reset error flag to show new errors + try: + self._ensure_context_for_document(uri) + # If successful, context_state will be ContextLoaded + if isinstance(self.context_state, ContextLoaded): + loaded_sqlmesh_message(ls) + except Exception as e: + ls.log_trace(f"Still cannot load context: {e}") + # The error will 
be stored in context_state by _ensure_context_for_document + else: + # Reload the context if it was successfully loaded + try: + context = self.context_state.lsp_context.context + context.load() + # Create new LSPContext which will have fresh, empty caches + self.context_state = ContextLoaded(lsp_context=LSPContext(context)) + except Exception as e: + ls.log_trace(f"Error loading context: {e}") + self.context_state = ContextFailed( + error=e, context=self.context_state.lsp_context.context + ) + + # Send a workspace diagnostic refresh request to the client. This is used to notify the client that the diagnostics have changed. + ls.lsp.send_request( + types.WORKSPACE_DIAGNOSTIC_REFRESH, + WorkspaceDiagnosticRefreshRequest( + id=self.context_state.version_id, + ), + ) + ls.lsp.send_request( + types.WORKSPACE_INLAY_HINT_REFRESH, + WorkspaceInlayHintRefreshRequest( + id=self.context_state.version_id, + ), + ) + + # Only publish diagnostics if client doesn't support pull diagnostics + if not self.client_supports_pull_diagnostics: + if hasattr(self.context_state, "lsp_context"): + diagnostics = self.context_state.lsp_context.lint_model(uri) + ls.publish_diagnostics( + document_uri, + LSPContext.diagnostics_to_lsp_diagnostics(diagnostics), + ) + + def _register_features(self) -> None: + """Register LSP features on the internal LanguageServer instance.""" + for name, method in self._supported_custom_methods.items(): + + def create_function_call(method_func: t.Callable) -> t.Callable: + def function_call(ls: LanguageServer, params: t.Any) -> t.Dict[str, t.Any]: + try: + response = method_func(ls, params) + except Exception as e: + response = CustomMethodResponseBaseClass(response_error=str(e)) + return response.model_dump(mode="json") + + return function_call + + self.server.feature(name)(create_function_call(method)) + + @self.server.command(EXTERNAL_MODEL_UPDATE_COLUMNS) + def command_external_models_update_columns(ls: LanguageServer, raw: t.Any) -> None: + try: + if not 
isinstance(raw, list): + raise ValueError("Invalid command parameters") + if len(raw) != 1: + raise ValueError("Command expects exactly one parameter") + model_name = raw[0] + if not isinstance(model_name, str): + raise ValueError("Command parameter must be a string") + + context = self._context_get_or_load() + if not isinstance(context, LSPContext): + raise ValueError("Context is not loaded or invalid") + model = context.context.get_model(model_name) + if model is None: + raise ValueError(f"External model '{model_name}' not found") + if model._path is None: + raise ValueError(f"External model '{model_name}' does not have a file path") + uri = URI.from_path(model._path) + updated = context.update_external_model_columns( + ls=ls, + uri=uri, + model_name=model_name, + ) + if updated: + ls.show_message( + f"Updated columns for '{model_name}'", + types.MessageType.Info, + ) + else: + ls.show_message( + f"Columns for '{model_name}' are already up to date", + ) + except Exception as e: + ls.show_message(f"Error executing command: {e}", types.MessageType.Error) + return None + + @self.server.feature(types.INITIALIZE) + def initialize(ls: LanguageServer, params: types.InitializeParams) -> None: + """Initialize the server when the client connects.""" + try: + # Check the custom options + if params.initialization_options: + options = InitializationOptions.model_validate(params.initialization_options) + if options.project_paths is not None: + self.specified_paths = [Path(path) for path in options.project_paths] + + # Check if the client supports pull diagnostics + if params.capabilities and params.capabilities.text_document: + diagnostics = getattr(params.capabilities.text_document, "diagnostic", None) + if diagnostics: + self.client_supports_pull_diagnostics = True + ls.log_trace("Client supports pull diagnostics") + else: + self.client_supports_pull_diagnostics = False + ls.log_trace("Client does not support pull diagnostics") + else: + self.client_supports_pull_diagnostics 
= False + + if params.workspace_folders: + # Store all workspace folders for later use + self.workspace_folders = [ + Path(self._uri_to_path(folder.uri)) for folder in params.workspace_folders + ] + + # Try to find a SQLMesh config file in any workspace folder (only at the root level) + for folder_path in self.workspace_folders: + # Only check for config files directly in the workspace directory + for ext in ("py", "yml", "yaml"): + config_path = folder_path / f"config.{ext}" + if config_path.exists(): + if self._create_lsp_context([folder_path]): + loaded_sqlmesh_message(ls) + return # Exit after successfully loading any config + except Exception as e: + ls.log_trace( + f"Error initializing SQLMesh context: {e}", + ) + + @self.server.feature(types.TEXT_DOCUMENT_DID_OPEN) + def did_open(ls: LanguageServer, params: types.DidOpenTextDocumentParams) -> None: + uri = URI(params.text_document.uri) + context = self._context_get_or_load(uri) + + # Only publish diagnostics if client doesn't support pull diagnostics + if not self.client_supports_pull_diagnostics: + diagnostics = context.lint_model(uri) + ls.publish_diagnostics( + params.text_document.uri, + LSPContext.diagnostics_to_lsp_diagnostics(diagnostics), + ) + + @self.server.feature(types.TEXT_DOCUMENT_DID_SAVE) + def did_save(ls: LanguageServer, params: types.DidSaveTextDocumentParams) -> None: + uri = URI(params.text_document.uri) + self._reload_context_and_publish_diagnostics(ls, uri, params.text_document.uri) + + @self.server.feature(types.TEXT_DOCUMENT_FORMATTING) + def formatting( + ls: LanguageServer, params: types.DocumentFormattingParams + ) -> t.List[types.TextEdit]: + """Format the document using SQLMesh `format_model_expressions`.""" + try: + uri = URI(params.text_document.uri) + context = self._context_get_or_load(uri) + document = ls.workspace.get_text_document(params.text_document.uri) + before = document.source + + target = next( + ( + target + for target in chain( + context.context._models.values(), 
+ context.context._audits.values(), + ) + if target._path is not None + and target._path.suffix == ".sql" + and (target._path.samefile(uri.to_path())) + ), + None, + ) + if target is None: + return [] + after = context.context._format( + target=target, + before=before, + ) + return [ + types.TextEdit( + range=types.Range( + start=types.Position(line=0, character=0), + end=types.Position( + line=len(document.lines), + character=len(document.lines[-1]) if document.lines else 0, + ), + ), + new_text=after, + ) + ] + except Exception as e: + ls.show_message(f"Error formatting SQL: {e}", types.MessageType.Error) + return [] + + @self.server.feature(types.TEXT_DOCUMENT_HOVER) + def hover(ls: LanguageServer, params: types.HoverParams) -> t.Optional[types.Hover]: + """Provide hover information for an object.""" + try: + uri = URI(params.text_document.uri) + context = self._context_get_or_load(uri) + document = ls.workspace.get_text_document(params.text_document.uri) + + references = get_references(context, uri, to_sqlmesh_position(params.position)) + if not references: + return None + reference = references[0] + if isinstance(reference, CTEReference) or not reference.markdown_description: + return None + return types.Hover( + contents=types.MarkupContent( + kind=types.MarkupKind.Markdown, + value=reference.markdown_description, + ), + range=to_lsp_range(reference.range), + ) + + except Exception as e: + ls.log_trace( + f"Error getting hover information: {e}", + ) + return None + + @self.server.feature(types.TEXT_DOCUMENT_INLAY_HINT) + def inlay_hint( + ls: LanguageServer, params: types.InlayHintParams + ) -> t.List[types.InlayHint]: + """Implement type hints for sql columns as inlay hints""" + try: + uri = URI(params.text_document.uri) + context = self._context_get_or_load(uri) + + start_line = params.range.start.line + end_line = params.range.end.line + hints = get_hints(context, uri, start_line, end_line) + return hints + + except Exception as e: + return [] + + 
@self.server.feature(types.TEXT_DOCUMENT_DEFINITION) + def goto_definition( + ls: LanguageServer, params: types.DefinitionParams + ) -> t.List[types.LocationLink]: + """Jump to an object's definition.""" + try: + uri = URI(params.text_document.uri) + context = self._context_get_or_load(uri) + + references = get_references(context, uri, to_sqlmesh_position(params.position)) + location_links = [] + for reference in references: + # Use target_range if available (CTEs, Macros, and external models in YAML) + if isinstance(reference, ModelReference): + # Regular SQL models - default to start of file + target_range = types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=0, character=0), + ) + target_selection_range = types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=0, character=0), + ) + elif isinstance(reference, ExternalModelReference): + # External models may have target_range set for YAML files + target_range = types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=0, character=0), + ) + target_selection_range = types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=0, character=0), + ) + if reference.target_range is not None: + target_range = to_lsp_range(reference.target_range) + target_selection_range = to_lsp_range(reference.target_range) + else: + # CTEs and Macros always have target_range + target_range = to_lsp_range(reference.target_range) + target_selection_range = to_lsp_range(reference.target_range) + + if reference.path is not None: + location_links.append( + types.LocationLink( + target_uri=URI.from_path(reference.path).value, + target_selection_range=target_selection_range, + target_range=target_range, + origin_selection_range=to_lsp_range(reference.range), + ) + ) + return location_links + except Exception as e: + ls.show_message(f"Error getting references: {e}", types.MessageType.Error) + return [] + + 
@self.server.feature(types.TEXT_DOCUMENT_REFERENCES) + def find_references( + ls: LanguageServer, params: types.ReferenceParams + ) -> t.Optional[t.List[types.Location]]: + """Find all references of a symbol (supporting CTEs, models for now)""" + try: + uri = URI(params.text_document.uri) + context = self._context_get_or_load(uri) + + all_references = get_all_references( + context, uri, to_sqlmesh_position(params.position) + ) + + # Convert references to Location objects + locations = [ + types.Location(uri=URI.from_path(ref.path).value, range=to_lsp_range(ref.range)) + for ref in all_references + if ref.path is not None + ] + + return locations if locations else None + except Exception as e: + ls.show_message(f"Error getting locations: {e}", types.MessageType.Error) + return None + + @self.server.feature(types.TEXT_DOCUMENT_PREPARE_RENAME) + def prepare_rename_handler( + ls: LanguageServer, params: types.PrepareRenameParams + ) -> t.Optional[types.PrepareRenameResult]: + """Prepare for rename operation by checking if the symbol can be renamed.""" + try: + uri = URI(params.text_document.uri) + context = self._context_get_or_load(uri) + result = prepare_rename(context, uri, params.position) + return result + except Exception as e: + ls.log_trace(f"Error preparing rename: {e}") + return None + + @self.server.feature(types.TEXT_DOCUMENT_RENAME) + def rename_handler( + ls: LanguageServer, params: types.RenameParams + ) -> t.Optional[types.WorkspaceEdit]: + """Perform rename operation on the symbol at the given position.""" + try: + uri = URI(params.text_document.uri) + context = self._context_get_or_load(uri) + workspace_edit = rename_symbol(context, uri, params.position, params.new_name) + return workspace_edit + except Exception as e: + ls.show_message(f"Error performing rename: {e}", types.MessageType.Error) + return None + + @self.server.feature(types.TEXT_DOCUMENT_DOCUMENT_HIGHLIGHT) + def document_highlight_handler( + ls: LanguageServer, params: 
types.DocumentHighlightParams + ) -> t.Optional[t.List[types.DocumentHighlight]]: + """Highlight all occurrences of the symbol at the given position.""" + try: + uri = URI(params.text_document.uri) + context = self._context_get_or_load(uri) + highlights = get_document_highlights(context, uri, params.position) + return highlights + except Exception as e: + ls.log_trace(f"Error getting document highlights: {e}") + return None + + @self.server.feature(types.TEXT_DOCUMENT_DIAGNOSTIC) + def diagnostic( + ls: LanguageServer, params: types.DocumentDiagnosticParams + ) -> types.DocumentDiagnosticReport: + """Handle diagnostic pull requests from the client.""" + try: + uri = URI(params.text_document.uri) + diagnostics, result_id = self._get_diagnostics_for_uri(uri) + + # Check if client provided a previous result ID + if hasattr(params, "previous_result_id") and params.previous_result_id == result_id: + # Return unchanged report if diagnostics haven't changed + return types.RelatedUnchangedDocumentDiagnosticReport( + kind=types.DocumentDiagnosticReportKind.Unchanged, + result_id=str(result_id), + ) + + return types.RelatedFullDocumentDiagnosticReport( + kind=types.DocumentDiagnosticReportKind.Full, + items=diagnostics, + result_id=str(result_id), + ) + except Exception as e: + ls.log_trace( + f"Error getting diagnostics: {e}", + ) + return types.RelatedFullDocumentDiagnosticReport( + kind=types.DocumentDiagnosticReportKind.Full, + items=[], + ) + + @self.server.feature(types.WORKSPACE_DIAGNOSTIC) + def workspace_diagnostic( + ls: LanguageServer, params: types.WorkspaceDiagnosticParams + ) -> types.WorkspaceDiagnosticReport: + """Handle workspace-wide diagnostic pull requests from the client.""" + try: + context = self._context_get_or_load() + + items: t.List[ + t.Union[ + types.WorkspaceFullDocumentDiagnosticReport, + types.WorkspaceUnchangedDocumentDiagnosticReport, + ] + ] = [] + + # Get all SQL and Python model files from the context + for path, target in 
context.map.items(): + if isinstance(target, ModelTarget): + uri = URI.from_path(path) + diagnostics, result_id = self._get_diagnostics_for_uri(uri) + + # Check if we have a previous result ID for this file + previous_result_id = None + if hasattr(params, "previous_result_ids") and params.previous_result_ids: + for prev in params.previous_result_ids: + if prev.uri == uri.value: + previous_result_id = prev.value + break + + if previous_result_id and previous_result_id == result_id: + # File hasn't changed + items.append( + types.WorkspaceUnchangedDocumentDiagnosticReport( + kind=types.DocumentDiagnosticReportKind.Unchanged, + result_id=str(result_id), + uri=uri.value, + ) + ) + else: + # File has changed or is new + items.append( + types.WorkspaceFullDocumentDiagnosticReport( + kind=types.DocumentDiagnosticReportKind.Full, + result_id=str(result_id), + uri=uri.value, + items=diagnostics, + ) + ) + + return types.WorkspaceDiagnosticReport(items=items) + + except Exception as e: + ls.log_trace(f"Error getting workspace diagnostics: {e}") + error_diagnostic, error = context_error_to_diagnostic(e) + if error_diagnostic: + uri_value, unpacked_diagnostic = error_diagnostic + return types.WorkspaceDiagnosticReport( + items=[ + types.WorkspaceFullDocumentDiagnosticReport( + kind=types.DocumentDiagnosticReportKind.Full, + result_id=self.context_state.version_id, # No versioning, always fresh + uri=uri_value, + items=[unpacked_diagnostic], + ) + ] + ) + + return types.WorkspaceDiagnosticReport(items=[]) + + @self.server.feature(types.TEXT_DOCUMENT_CODE_ACTION) + def code_action( + ls: LanguageServer, params: types.CodeActionParams + ) -> t.Optional[t.List[t.Union[types.Command, types.CodeAction]]]: + try: + ls.log_trace(f"Codeactionrequest: {params}") + uri = URI(params.text_document.uri) + context = self._context_get_or_load(uri) + code_actions = context.get_code_actions(uri, params) + return code_actions + + except Exception as e: + ls.log_trace(f"Error getting code 
actions: {e}") + return None + + @self.server.feature(types.TEXT_DOCUMENT_CODE_LENS) + def code_lens(ls: LanguageServer, params: types.CodeLensParams) -> t.List[types.CodeLens]: + try: + uri = URI(params.text_document.uri) + context = self._context_get_or_load(uri) + code_lenses = context.get_code_lenses(uri) + return code_lenses if code_lenses else [] + except Exception as e: + ls.log_trace(f"Error getting code lenses: {e}") + return [] + + @self.server.feature( + types.TEXT_DOCUMENT_COMPLETION, + types.CompletionOptions(trigger_characters=["@"]), # advertise "@" for macros + ) + def completion( + ls: LanguageServer, params: types.CompletionParams + ) -> t.Optional[types.CompletionList]: + """Handle completion requests from the client.""" + try: + uri = URI(params.text_document.uri) + context = self._context_get_or_load(uri) + + # Get the document content + content = None + try: + document = ls.workspace.get_text_document(params.text_document.uri) + content = document.source + except Exception: + pass + + # Get completions using the existing completions module + completion_response = LSPContext.get_completions(context, uri, content) + + completion_items = [] + # Add model completions + for model in completion_response.model_completions: + completion_items.append( + types.CompletionItem( + label=model.name, + kind=types.CompletionItemKind.Reference, + detail="SQLMesh Model", + documentation=types.MarkupContent( + kind=types.MarkupKind.Markdown, + value=model.description or "No description available", + ) + if model.description + else None, + ) + ) + # Add macro completions + triggered_by_at = ( + params.context is not None + and getattr(params.context, "trigger_character", None) == "@" + ) + + for macro in completion_response.macros: + macro_name = macro.name + insert_text = macro_name if triggered_by_at else f"@{macro_name}" + + completion_items.append( + types.CompletionItem( + label=f"@{macro_name}", + insert_text=insert_text, + 
insert_text_format=types.InsertTextFormat.PlainText, + filter_text=macro_name, + kind=types.CompletionItemKind.Function, + detail="SQLMesh Macro", + documentation=macro.description, + ) + ) + + for keyword in completion_response.keywords: + completion_items.append( + types.CompletionItem( + label=keyword, + kind=types.CompletionItemKind.Keyword, + detail="SQL Keyword", + ) + ) + + return types.CompletionList( + is_incomplete=False, + items=completion_items, + ) + + except Exception as e: + get_sql_completions(None, URI(params.text_document.uri)) + return None + + def _get_diagnostics_for_uri(self, uri: URI) -> t.Tuple[t.List[types.Diagnostic], str]: + """Get diagnostics for a specific URI, returning (diagnostics, result_id). + + Since we no longer track version numbers, we always return 0 as the result_id. + This means pull diagnostics will always fetch fresh results. + """ + try: + context = self._context_get_or_load(uri) + diagnostics = context.lint_model(uri) + return LSPContext.diagnostics_to_lsp_diagnostics( + diagnostics + ), self.context_state.version_id + except ConfigError as config_error: + diagnostic, error = context_error_to_diagnostic(config_error, uri_filter=uri) + if diagnostic: + location, diag = diagnostic + if location == uri.value: + return [diag], self.context_state.version_id + return [], self.context_state.version_id + + def _context_get_or_load(self, document_uri: t.Optional[URI] = None) -> LSPContext: + state = self.context_state + if isinstance(state, ContextFailed): + if isinstance(state.error, str): + raise Exception(state.error) + raise state.error + if isinstance(state, NoContext): + if self.specified_paths is not None: + # If specified paths are provided, create context from them + if self._create_lsp_context(self.specified_paths): + loaded_sqlmesh_message(self.server) + else: + self._ensure_context_for_document(document_uri) + if isinstance(state, ContextLoaded): + return state.lsp_context + raise RuntimeError("Context failed to 
load") + + def _ensure_context_for_document( + self, + document_uri: t.Optional[URI] = None, + ) -> None: + """ + Ensure that a context exists for the given document if applicable by searching + for a config.py or config.yml file in the parent directories. + """ + if document_uri is not None: + document_path = document_uri.to_path() + if document_path.is_file() and document_path.suffix in (".sql", ".py"): + document_folder = document_path.parent + if document_folder.is_dir(): + self._ensure_context_in_folder(document_folder) + return + + self._ensure_context_in_folder() + + def _ensure_context_in_folder(self, folder_path: t.Optional[Path] = None) -> None: + if not isinstance(self.context_state, NoContext): + return + + # If not found in the provided folder, search through all workspace folders + for workspace_folder in self.workspace_folders: + for ext in ("py", "yml", "yaml"): + config_path = workspace_folder / f"config.{ext}" + if config_path.exists(): + if self._create_lsp_context([workspace_folder]): + loaded_sqlmesh_message(self.server) + return + + # Then , check the provided folder recursively + path = folder_path + if path is None: + path = Path.cwd() + while path.is_dir(): + for ext in ("py", "yml", "yaml"): + config_path = path / f"config.{ext}" + if config_path.exists(): + if self._create_lsp_context([path]): + loaded_sqlmesh_message(self.server) + return + + path = path.parent + if path == path.parent: + break + + raise RuntimeError( + f"No context found in workspaces folders {self.workspace_folders}" + + (f" or in {folder_path}" if folder_path else "") + ) + + def _create_lsp_context(self, paths: t.List[Path]) -> t.Optional[LSPContext]: + """Create a new LSPContext instance using the configured context class. + + On success, sets self.context_state to ContextLoaded and returns the created context. 
+ + Args: + paths: List of paths to pass to the context constructor + + Returns: + A new LSPContext instance wrapping the created context, or None if creation fails + """ + try: + if isinstance(self.context_state, NoContext): + context = self.context_class(paths=paths) + elif isinstance(self.context_state, ContextFailed): + if self.context_state.context: + context = self.context_state.context + context.load() + else: + # If there's no context (initial creation failed), try creating again + context = self.context_class(paths=paths) + else: + context = self.context_state.lsp_context.context + context.load() + self.context_state = ContextLoaded(lsp_context=LSPContext(context)) + return self.context_state.lsp_context + except Exception as e: + # Only show the error message once + if not self.has_raised_loading_error: + self.server.show_message( + f"Error creating context: {e}", + types.MessageType.Error, + ) + self.has_raised_loading_error = True + + self.server.log_trace(f"Error creating context: {e}") + # Store the error in context state so subsequent requests show the actual error + # Try to preserve any partially loaded context if it exists + context = None + if isinstance(self.context_state, ContextLoaded): + context = self.context_state.lsp_context.context + elif isinstance(self.context_state, ContextFailed) and self.context_state.context: + context = self.context_state.context + self.context_state = ContextFailed(error=e, context=context) + return None + + @staticmethod + def _uri_to_path(uri: str) -> Path: + """Convert a URI to a path.""" + return URI(uri).to_path() + + def start(self) -> None: + """Start the server with I/O transport.""" + logging.basicConfig(level=logging.DEBUG) + self.server.start_io() + + +def loaded_sqlmesh_message(ls: LanguageServer) -> None: + ls.show_message( + f"Loaded SQLMesh Context", + types.MessageType.Info, + ) + + +def main() -> None: + # Example instantiator that just uses the same signature as your original `Context` usage. 
+ sqlmesh_server = SQLMeshLanguageServer(context_class=Context) + sqlmesh_server.start() + + +if __name__ == "__main__": + main() diff --git a/sqlmesh/lsp/reference.py b/sqlmesh/lsp/reference.py new file mode 100644 index 0000000000..80d401f79c --- /dev/null +++ b/sqlmesh/lsp/reference.py @@ -0,0 +1,580 @@ +import typing as t +from pathlib import Path + +from sqlmesh.core.audit import StandaloneAudit +from sqlmesh.core.linter.helpers import ( + TokenPositionDetails, +) +from sqlmesh.core.linter.rule import Range, Position +from sqlmesh.core.model.definition import SqlModel +from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget +from sqlglot import exp + +from sqlmesh.lsp.uri import URI +from sqlmesh.utils.lineage import ( + MacroReference, + CTEReference, + Reference, + ModelReference, + extract_references_from_query, +) +import ast +from sqlmesh.core.model import Model +from sqlmesh import macro +import inspect + + +def by_position(position: Position) -> t.Callable[[Reference], bool]: + """ + Filter reference to only filter references that contain the given position. + + Args: + position: The cursor position to check + + Returns: + A function that returns True if the reference contains the position, False otherwise + """ + + def contains_position(r: Reference) -> bool: + return _position_within_range(position, r.range) + + return contains_position + + +def get_references( + lint_context: LSPContext, document_uri: URI, position: Position +) -> t.List[Reference]: + """ + Get references at a specific position in a document. + + Used for hover information. 
+ + Args: + lint_context: The LSP context + document_uri: The URI of the document + position: The position to check for references + + Returns: + A list of references at the given position + """ + references = get_model_definitions_for_a_path(lint_context, document_uri) + + # Get macro references before filtering by position + macro_references = get_macro_definitions_for_a_path(lint_context, document_uri) + references.extend(macro_references) + + filtered_references = list(filter(by_position(position), references)) + return filtered_references + + +def get_model_definitions_for_a_path( + lint_context: LSPContext, document_uri: URI +) -> t.List[Reference]: + """ + Get the model references for a given path. + + Works for models and standalone audits. + Works for targeting sql and python models. + + Steps: + - Get the parsed query + - Find all table objects using find_all exp.Table + - Match the string against all model names + - Need to normalize it before matching + - Try get_model before normalization + - Match to models that the model refers to + - Also find CTE references within the query + """ + path = document_uri.to_path() + if path.suffix != ".sql": + return [] + # Get the file info from the context map + if path not in lint_context.map: + return [] + + file_info = lint_context.map[path] + # Process based on whether it's a model or standalone audit + if isinstance(file_info, ModelTarget): + # It's a model + model = lint_context.context.get_model( + model_or_snapshot=file_info.names[0], raise_if_missing=False + ) + if model is None or not isinstance(model, SqlModel): + return [] + + query = model.query + dialect = model.dialect + depends_on = model.depends_on + file_path = model._path + elif isinstance(file_info, AuditTarget): + # It's a standalone audit + audit = lint_context.context.standalone_audits.get(file_info.name) + if audit is None: + return [] + query = audit.query + dialect = audit.dialect + depends_on = audit.depends_on + file_path = audit._path + 
else: + return [] + + if file_path is None: + return [] + + with open(file_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + return extract_references_from_query( + query=query, + context=lint_context.context, + document_path=document_uri.to_path(), + read_file=read_file, + depends_on=depends_on, + dialect=dialect, + ) + + +def get_macro_definitions_for_a_path( + lsp_context: LSPContext, document_uri: URI +) -> t.List[Reference]: + """ + Get macro references for a given path. + + This function finds all macro invocations (e.g., @ADD_ONE, @MULTIPLY) in a SQL file + and creates references to their definitions in the Python macro files. + + Args: + lsp_context: The LSP context containing macro definitions + document_uri: The URI of the document to search for macro invocations + + Returns: + A list of Reference objects for each macro invocation found + """ + path = document_uri.to_path() + if path.suffix != ".sql": + return [] + + # Get the file info from the context map + if path not in lsp_context.map: + return [] + + file_info = lsp_context.map[path] + # Process based on whether it's a model or standalone audit + if isinstance(file_info, ModelTarget): + # It's a model + target: t.Optional[t.Union[Model, StandaloneAudit]] = lsp_context.context.get_model( + model_or_snapshot=file_info.names[0], raise_if_missing=False + ) + if target is None or not isinstance(target, SqlModel): + return [] + query = target.query + file_path = target._path + elif isinstance(file_info, AuditTarget): + # It's a standalone audit + target = lsp_context.context.standalone_audits.get(file_info.name) + if target is None: + return [] + query = target.query + file_path = target._path + else: + return [] + + if file_path is None: + return [] + + references = [] + _, config_path = lsp_context.context.config_for_path( + file_path, + ) + + with open(file_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + for node in query.find_all(exp.Anonymous): + 
macro_name = node.name.lower() + reference = get_macro_reference( + node=node, + target=target, + read_file=read_file, + config_path=config_path, + macro_name=macro_name, + ) + if reference is not None: + references.append(reference) + + return references + + +def get_macro_reference( + target: t.Union[Model, StandaloneAudit], + read_file: t.List[str], + config_path: t.Optional[Path], + node: exp.Expression, + macro_name: str, +) -> t.Optional[Reference]: + # Get the file path where the macro is defined + try: + # Get the position of the macro invocation in the source file first + if hasattr(node, "meta") and node.meta: + macro_range = TokenPositionDetails.from_meta(node.meta).to_range(read_file) + + # Check if it's a built-in method + if builtin := get_built_in_macro_reference(macro_name, macro_range): + return builtin + else: + # Skip if we can't get the position + return None + + # Find the macro definition information + macro_def = target.python_env.get(macro_name) + if macro_def is None: + return None + + function_name = macro_def.name + if not function_name: + return None + if not macro_def.path: + return None + if not config_path: + return None + path = Path(config_path).joinpath(macro_def.path) + + # Parse the Python file to find the function definition + with open(path, "r") as f: + tree = ast.parse(f.read()) + with open(path, "r") as f: + output_read_line = f.readlines() + + # Find the function definition by name + start_line = None + end_line = None + get_length_of_end_line = None + docstring = None + for ast_node in ast.walk(tree): + if isinstance(ast_node, ast.FunctionDef) and ast_node.name == function_name: + start_line = ast_node.lineno + end_line = ast_node.end_lineno + get_length_of_end_line = ( + len(output_read_line[end_line - 1]) + if end_line is not None and end_line - 1 < len(read_file) + else 0 + ) + # Extract docstring if present + docstring = ast.get_docstring(ast_node) + break + + if start_line is None or end_line is None or 
get_length_of_end_line is None: + return None + + # Create a reference to the macro definition + + return MacroReference( + path=path, + range=macro_range, + target_range=Range( + start=Position(line=start_line - 1, character=0), + end=Position(line=end_line - 1, character=get_length_of_end_line), + ), + markdown_description=docstring, + ) + except Exception: + return None + + +def get_built_in_macro_reference(macro_name: str, macro_range: Range) -> t.Optional[Reference]: + """ + Get a reference to a built-in macro by its name. + + Args: + macro_name: The name of the built-in macro (e.g., 'each', 'sql_literal') + macro_range: The range of the macro invocation in the source file + """ + built_in_macros = macro.get_registry() + built_in_macro = built_in_macros.get(macro_name) + if built_in_macro is None: + return None + + func = built_in_macro.func + filename = inspect.getfile(func) + source_lines, line_number = inspect.getsourcelines(func) + + # Calculate the end line number by counting the number of source lines + end_line_number = line_number + len(source_lines) - 1 + + return MacroReference( + path=Path(filename), + range=macro_range, + target_range=Range( + start=Position(line=line_number - 1, character=0), + end=Position(line=end_line_number - 1, character=0), + ), + markdown_description=func.__doc__ if func.__doc__ else None, + ) + + +def get_model_find_all_references( + lint_context: LSPContext, document_uri: URI, position: Position +) -> t.List[ModelReference]: + """ + Get all references to a model across the entire project. + + This function finds all usages of a model in other files by searching through + all models in the project and checking their dependencies. 
+ + Args: + lint_context: The LSP context + document_uri: The URI of the document + position: The position to check for model references + + Returns: + A list of references to the model across all files + """ + # Find the model reference at the cursor position + model_at_position = next( + filter( + lambda ref: isinstance(ref, ModelReference) + and _position_within_range(position, ref.range), + get_model_definitions_for_a_path(lint_context, document_uri), + ), + None, + ) + + if not model_at_position: + return [] + + assert isinstance(model_at_position, ModelReference) # for mypy + + target_model_path = model_at_position.path + + # Start with the model definition + all_references: t.List[ModelReference] = [ + ModelReference( + path=model_at_position.path, + range=Range( + start=Position(line=0, character=0), + end=Position(line=0, character=0), + ), + markdown_description=model_at_position.markdown_description, + ) + ] + + # Then add references from the current file + current_file_refs = filter( + lambda ref: isinstance(ref, ModelReference) and ref.path == target_model_path, + get_model_definitions_for_a_path(lint_context, document_uri), + ) + + for ref in current_file_refs: + assert isinstance(ref, ModelReference) # for mypy + + all_references.append( + ModelReference( + path=document_uri.to_path(), + range=ref.range, + markdown_description=ref.markdown_description, + ) + ) + + # Search through the models in the project + for path, _ in lint_context.map.items(): + file_uri = URI.from_path(path) + + # Skip current file, already processed + if file_uri.value == document_uri.value: + continue + + # Get model references that point to the target model + matching_refs = filter( + lambda ref: isinstance(ref, ModelReference) and ref.path == target_model_path, + get_model_definitions_for_a_path(lint_context, file_uri), + ) + + for ref in matching_refs: + assert isinstance(ref, ModelReference) # for mypy + + all_references.append( + ModelReference( + path=path, + 
range=ref.range, + markdown_description=ref.markdown_description, + ) + ) + + return all_references + + +def get_cte_references( + lint_context: LSPContext, document_uri: URI, position: Position +) -> t.List[CTEReference]: + """ + Get all references to a CTE at a specific position in a document. + + This function finds both the definition and all usages of a CTE within the same file. + + Args: + lint_context: The LSP context + document_uri: The URI of the document + position: The position to check for CTE references + + Returns: + A list of references to the CTE (including its definition and all usages) + """ + + # Filter to get the CTE references + cte_references: t.List[CTEReference] = [ + ref + for ref in get_model_definitions_for_a_path(lint_context, document_uri) + if isinstance(ref, CTEReference) + ] + + if not cte_references: + return [] + + target_cte_definition_range = None + for ref in cte_references: + # Check if cursor is on a CTE usage + if _position_within_range(position, ref.range): + target_cte_definition_range = ref.target_range + break + # Check if cursor is on the CTE definition + elif _position_within_range(position, ref.target_range): + target_cte_definition_range = ref.target_range + break + + if target_cte_definition_range is None: + return [] + + # Add the CTE definition + matching_references = [ + CTEReference( + path=document_uri.to_path(), + range=target_cte_definition_range, + target_range=target_cte_definition_range, + ) + ] + + # Add all usages + for ref in cte_references: + if ref.target_range == target_cte_definition_range: + matching_references.append( + CTEReference( + path=document_uri.to_path(), + range=ref.range, + target_range=ref.target_range, + ) + ) + + return matching_references + + +def get_macro_find_all_references( + lsp_context: LSPContext, document_uri: URI, position: Position +) -> t.List[MacroReference]: + """ + Get all references to a macro at a specific position in a document. 
+ + This function finds all usages of a macro across the entire project. + + Args: + lsp_context: The LSP context + document_uri: The URI of the document + position: The position to check for macro references + + Returns: + A list of references to the macro across all files + """ + # Find the macro reference at the cursor position + macro_at_position = next( + filter( + lambda ref: isinstance(ref, MacroReference) + and _position_within_range(position, ref.range), + get_macro_definitions_for_a_path(lsp_context, document_uri), + ), + None, + ) + + if not macro_at_position: + return [] + + assert isinstance(macro_at_position, MacroReference) # for mypy + + target_macro_path = macro_at_position.path + target_macro_target_range = macro_at_position.target_range + + # Start with the macro definition + all_references: t.List[MacroReference] = [ + MacroReference( + path=target_macro_path, + range=target_macro_target_range, + target_range=target_macro_target_range, + markdown_description=None, + ) + ] + + # Search through all SQL and audit files in the project + for path, _ in lsp_context.map.items(): + file_uri = URI.from_path(path) + + # Get macro references that point to the same macro definition + matching_refs = filter( + lambda ref: isinstance(ref, MacroReference) + and ref.path == target_macro_path + and ref.target_range == target_macro_target_range, + get_macro_definitions_for_a_path(lsp_context, file_uri), + ) + + for ref in matching_refs: + assert isinstance(ref, MacroReference) # for mypy + all_references.append( + MacroReference( + path=path, + range=ref.range, + target_range=ref.target_range, + markdown_description=ref.markdown_description, + ) + ) + + return all_references + + +def get_all_references( + lint_context: LSPContext, document_uri: URI, position: Position +) -> t.Sequence[Reference]: + """ + Get all references of a symbol at a specific position in a document. 
+ + This function determines the type of reference (CTE, model or macro) at the cursor + position and returns all references to that symbol across the project. + + Args: + lint_context: The LSP context + document_uri: The URI of the document + position: The position to check for references + + Returns: + A list of references to the symbol at the given position + """ + # First try CTE references (within same file) + if cte_references := get_cte_references(lint_context, document_uri, position): + return cte_references + + # Then try model references (across files) + if model_references := get_model_find_all_references(lint_context, document_uri, position): + return model_references + + # Finally try macro references (across files) + if macro_references := get_macro_find_all_references(lint_context, document_uri, position): + return macro_references + + return [] + + +def _position_within_range(position: Position, range: Range) -> bool: + """Check if a position is within a given range.""" + return ( + range.start.line < position.line + or (range.start.line == position.line and range.start.character <= position.character) + ) and ( + range.end.line > position.line + or (range.end.line == position.line and range.end.character >= position.character) + ) diff --git a/sqlmesh/lsp/rename.py b/sqlmesh/lsp/rename.py new file mode 100644 index 0000000000..5675c4efca --- /dev/null +++ b/sqlmesh/lsp/rename.py @@ -0,0 +1,143 @@ +import typing as t +from lsprotocol.types import ( + Position, + TextEdit, + WorkspaceEdit, + PrepareRenameResult_Type1, + DocumentHighlight, + DocumentHighlightKind, +) + +from sqlmesh.lsp.context import LSPContext +from sqlmesh.lsp.helpers import to_sqlmesh_position, to_lsp_range +from sqlmesh.lsp.reference import ( + _position_within_range, + get_cte_references, + CTEReference, +) +from sqlmesh.lsp.uri import URI + + +def prepare_rename( + lsp_context: LSPContext, document_uri: URI, lsp_position: Position +) -> t.Optional[PrepareRenameResult_Type1]: + 
""" + Prepare for rename operation by checking if the symbol at the position can be renamed. + + Args: + lsp_context: The LSP context + document_uri: The URI of the document + position: The position in the document + + Returns: + PrepareRenameResult if the symbol can be renamed, None otherwise + """ + # Check if there's a CTE at this position + position = to_sqlmesh_position(lsp_position) + cte_references = get_cte_references(lsp_context, document_uri, position) + if cte_references: + # Find the target CTE definition to get its range + target_range = None + for ref in cte_references: + # Check if cursor is on a CTE usage + if _position_within_range(position, ref.range): + target_range = ref.target_range + break + # Check if cursor is on the CTE definition + elif _position_within_range(position, ref.target_range): + target_range = ref.target_range + break + if target_range: + return PrepareRenameResult_Type1( + range=to_lsp_range(target_range), placeholder="cte_name" + ) + + # For now, only CTEs are supported + return None + + +def rename_symbol( + lsp_context: LSPContext, document_uri: URI, lsp_position: Position, new_name: str +) -> t.Optional[WorkspaceEdit]: + """ + Perform rename operation on the symbol at the given position. + + Args: + lsp_context: The LSP context + document_uri: The URI of the document + position: The position in the document + new_name: The new name for the symbol + + Returns: + WorkspaceEdit with the changes, or None if no symbol to rename + """ + # Check if there's a CTE at this position + cte_references = get_cte_references( + lsp_context, document_uri, to_sqlmesh_position(lsp_position) + ) + if cte_references: + return _rename_cte(cte_references, new_name) + + # For now, only CTEs are supported + return None + + +def _rename_cte(cte_references: t.List[CTEReference], new_name: str) -> WorkspaceEdit: + """ + Create a WorkspaceEdit for renaming a CTE. 
+ + Args: + cte_references: List of CTE references (definition and usages) + new_name: The new name for the CTE + + Returns: + WorkspaceEdit with the text edits for renaming the CTE + """ + changes: t.Dict[str, t.List[TextEdit]] = {} + + for ref in cte_references: + uri = URI.from_path(ref.path).value + if uri not in changes: + changes[uri] = [] + + # Create a text edit for this reference + text_edit = TextEdit(range=to_lsp_range(ref.range), new_text=new_name) + changes[uri].append(text_edit) + + return WorkspaceEdit(changes=changes) + + +def get_document_highlights( + lsp_context: LSPContext, document_uri: URI, position: Position +) -> t.Optional[t.List[DocumentHighlight]]: + """ + Get document highlights for all occurrences of the symbol at the given position. + + This function finds all occurrences of a symbol (CTE) within the current document + and returns them as DocumentHighlight objects for "Change All Occurrences" feature. + + Args: + lsp_context: The LSP context + document_uri: The URI of the document + position: The position in the document to find highlights for + + Returns: + List of DocumentHighlight objects or None if no symbol found + """ + # Check if there's a CTE at this position + cte_references = get_cte_references(lsp_context, document_uri, to_sqlmesh_position(position)) + if cte_references: + highlights = [] + for ref in cte_references: + # Determine the highlight kind based on whether it's a definition or usage + kind = ( + DocumentHighlightKind.Write + if ref.range == ref.target_range + else DocumentHighlightKind.Read + ) + + highlights.append(DocumentHighlight(range=to_lsp_range(ref.range), kind=kind)) + return highlights + + # For now, only CTEs are supported + return None diff --git a/sqlmesh/lsp/tests_ranges.py b/sqlmesh/lsp/tests_ranges.py new file mode 100644 index 0000000000..cbcb33d8b6 --- /dev/null +++ b/sqlmesh/lsp/tests_ranges.py @@ -0,0 +1,65 @@ +""" +Provides helper functions to get ranges of tests in SQLMesh LSP. 
+""" + +from pathlib import Path + +from sqlmesh.core.linter.rule import Range, Position +from ruamel import yaml +from ruamel.yaml.comments import CommentedMap +import typing as t + + +def get_test_ranges( + path: Path, +) -> t.Dict[str, Range]: + """ + Test files are yaml files with a stucture of dict to test information. This returns a dictionary + with the test name as the key and the range of the test in the file as the value. + """ + test_ranges: t.Dict[str, Range] = {} + + with open(path, "r", encoding="utf-8") as file: + content = file.read() + + # Parse YAML to get line numbers + yaml_obj = yaml.YAML() + yaml_obj.preserve_quotes = True + data = yaml_obj.load(content) + + if not isinstance(data, dict): + raise ValueError("Invalid test file format: expected a dictionary at the top level.") + + # For each top-level key (test name), find its range + for test_name in data: + if isinstance(data, CommentedMap) and test_name in data.lc.data: + # Get line and column info from ruamel yaml + line_info = data.lc.data[test_name] + start_line = line_info[0] # 0-based line number + start_col = line_info[1] # 0-based column number + + # Find the end of this test by looking for the next test or end of file + lines = content.splitlines() + end_line = start_line + + # Find where this test ends by looking for the next top-level key + # or the end of the file + for i in range(start_line + 1, len(lines)): + line = lines[i] + # Check if this line starts a new top-level key (no leading spaces) + if line and not line[0].isspace() and ":" in line: + end_line = i - 1 + break + else: + # This test goes to the end of the file + end_line = len(lines) - 1 + + # Create the range + test_ranges[test_name] = Range( + start=Position(line=start_line, character=start_col), + end=Position( + line=end_line, character=len(lines[end_line]) if end_line < len(lines) else 0 + ), + ) + + return test_ranges diff --git a/sqlmesh/lsp/uri.py b/sqlmesh/lsp/uri.py new file mode 100644 index 
0000000000..f8f0a495db --- /dev/null +++ b/sqlmesh/lsp/uri.py @@ -0,0 +1,33 @@ +from pathlib import Path +from pygls.uris import from_fs_path, to_fs_path +import typing as t + + +class URI: + """ + A URI is a unique identifier for a file used in the LSP. + """ + + def __init__(self, uri: str): + self.value: str = uri + + def __hash__(self) -> int: + return hash(self.value) + + def __eq__(self, other: object) -> bool: + if not isinstance(other, URI): + return False + return self.value == other.value + + def __repr__(self) -> str: + return f"URI({self.value})" + + def to_path(self) -> Path: + p = to_fs_path(self.value) + return Path(p) + + @staticmethod + def from_path(path: t.Union[str, Path]) -> "URI": + if isinstance(path, Path): + path = path.as_posix() + return URI(from_fs_path(path)) diff --git a/sqlmesh/magics.py b/sqlmesh/magics.py index 51652fc00d..0a433360df 100644 --- a/sqlmesh/magics.py +++ b/sqlmesh/magics.py @@ -1,13 +1,22 @@ from __future__ import annotations +from io import StringIO + import functools import logging import typing as t -from argparse import Namespace +from argparse import Namespace, SUPPRESS from collections import defaultdict +from copy import deepcopy +from pathlib import Path from hyperscript import h -from IPython.core.display import display + +try: + from IPython.core.display import display # type: ignore +except ImportError: + from IPython.display import display + from IPython.core.magic import ( Magics, cell_magic, @@ -18,17 +27,16 @@ from IPython.core.magic_arguments import argument, magic_arguments, parse_argstring from IPython.utils.process import arg_split from rich.jupyter import JupyterRenderable - -from sqlmesh.cli.example_project import ProjectTemplate, init_example_project +from sqlmesh.cli.project_init import ProjectTemplate, init_example_project from sqlmesh.core import analytics -from sqlmesh.core import constants as c from sqlmesh.core.config import load_configs -from sqlmesh.core.console import get_console +from 
sqlmesh.core.config.connection import INIT_DISPLAY_INFO_TO_TYPE +from sqlmesh.core.console import create_console, set_console, configure_console from sqlmesh.core.context import Context from sqlmesh.core.dialect import format_model_expressions, parse from sqlmesh.core.model import load_sql_based_model -from sqlmesh.core.test import ModelTestMetadata, get_all_model_tests -from sqlmesh.utils import sqlglot_dialects, yaml +from sqlmesh.core.test import ModelTestMetadata +from sqlmesh.utils import yaml, Verbosity, optional_import from sqlmesh.utils.errors import MagicError, MissingContextException, SQLMeshError logger = logging.getLogger(__name__) @@ -52,7 +60,9 @@ def wrapper(self: SQLMeshMagics, *args: t.Any, **kwargs: t.Any) -> None: f"Context must be defined and initialized with one of these names: {', '.join(CONTEXT_VARIABLE_NAMES)}" ) old_console = context.console - context.console = get_console(display=self.display) + new_console = create_console(display=self.display) + context.console = new_console + set_console(new_console) context.refresh() magic_name = func.__name__ @@ -60,8 +70,19 @@ def wrapper(self: SQLMeshMagics, *args: t.Any, **kwargs: t.Any) -> None: if bound_method: args_split = arg_split(args[0]) parser = bound_method.parser - # Calling the private method to bypass setting of defaults. 
- parsed_args, _ = parser._parse_known_args(args_split, Namespace()) + + original_parser_actions = deepcopy(parser._actions) + original_parser_defaults = parser._defaults + + # Temporarily suppress default values, otherwise any missing arg would be set and affect analytics + parser._defaults = {} + for action in parser._actions: + action.default = SUPPRESS + + parsed_args, _ = parser.parse_known_args(args_split, Namespace()) + + parser._actions = original_parser_actions + parser._defaults = original_parser_defaults command_args = {k for k, v in parsed_args.__dict__.items() if v is not None} analytics.collector.on_magic_command(command_name=magic_name, command_args=command_args) @@ -69,10 +90,48 @@ def wrapper(self: SQLMeshMagics, *args: t.Any, **kwargs: t.Any) -> None: func(self, context, *args, **kwargs) context.console = old_console + set_console(old_console) return wrapper + +def format_arguments(func: t.Callable) -> t.Callable: + """Decorator to add common format arguments to magic commands.""" + func = argument( + "--normalize", + action="store_true", + help="Whether or not to normalize identifiers to lowercase.", + default=None, + )(func) + func = argument( + "--pad", + type=int, + help="Determines the pad size in a formatted string.", + )(func) + func = argument( + "--indent", + type=int, + help="Determines the indentation size in a formatted string.", + )(func) + func = argument( + "--normalize-functions", + type=str, + help="Whether or not to normalize all function names. Possible values are: 'upper', 'lower'", + )(func) + func = argument( + "--leading-comma", + action="store_true", + help="Determines whether or not the comma is leading or trailing in select expressions. 
Default is trailing.", + default=None, + )(func) + func = argument( + "--max-text-width", + type=int, + help="The max number of characters in a segment before creating new lines in pretty mode.", + )(func) + return func + + @magics_class class SQLMeshMagics(Magics): @property @@ -108,17 +167,32 @@ def _shell(self) -> t.Any: @argument("--ignore-warnings", action="store_true", help="Ignore warnings.") @argument("--debug", action="store_true", help="Enable debug mode.") @argument("--log-file-dir", type=str, help="The directory to write the log file to.") + @argument( + "--dotenv", type=str, help="Path to a custom .env file to load environment variables from." + ) @line_magic def context(self, line: str) -> None: """Sets the context in the user namespace.""" - from sqlmesh import configure_logging + from sqlmesh import configure_logging, remove_excess_logs args = parse_argstring(self.context, line) - configs = load_configs(args.config, Context.CONFIG_TYPE, args.paths) - log_limit = list(configs.values())[0].log_limit + log_file_dir = args.log_file_dir + configure_logging( - args.debug, args.ignore_warnings, log_limit=log_limit, log_file_dir=args.log_file_dir + args.debug, + log_file_dir=log_file_dir, + ignore_warnings=args.ignore_warnings, + ) + configure_console(ignore_warnings=args.ignore_warnings) + + dotenv_path = Path(args.dotenv) if args.dotenv else None + configs = load_configs( + args.config, Context.CONFIG_TYPE, args.paths, dotenv_path=dotenv_path ) + log_limit = list(configs.values())[0].log_limit + + remove_excess_logs(log_file_dir, log_limit) + try: context = Context(paths=args.paths, config=configs, gateway=args.gateway) self._shell.user_ns["context"] = context @@ -126,20 +200,31 @@ def context(self, line: str) -> None: if args.debug: logger.exception("Failed to initialize SQLMesh context") raise + context.console.log_success(f"SQLMesh project context set to: {', '.join(args.paths)}") @magic_arguments() @argument("path", type=str, help="The path where the 
new SQLMesh project should be created.") @argument( - "sql_dialect", + "engine", type=str, - help=f"Default model SQL dialect. Supported values: {sqlglot_dialects()}.", + help=f"Project SQL engine. Supported values: '{', '.join([info[1] for info in sorted(INIT_DISPLAY_INFO_TO_TYPE.values(), key=lambda x: x[0])])}'.", # type: ignore ) @argument( "--template", "-t", type=str, - help="Project template. Supported values: airflow, dbt, default, empty.", + help="Project template. Supported values: dbt, default, empty.", + ) + @argument( + "--dlt-pipeline", + type=str, + help="DLT pipeline for which to generate a SQLMesh project. Use alongside template: dlt", + ) + @argument( + "--dlt-path", + type=str, + help="The directory where the DLT pipeline resides. Use alongside template: dlt", ) @line_magic def init(self, line: str) -> None: @@ -151,7 +236,14 @@ def init(self, line: str) -> None: ) except ValueError: raise MagicError(f"Invalid project template '{args.template}'") - init_example_project(args.path, args.sql_dialect, project_template) + init_example_project( + path=args.path, + engine_type=args.engine, + dialect=None, + template=project_template, + pipeline=args.dlt_pipeline, + dlt_path=args.dlt_path, + ) html = str( h( "div", @@ -188,18 +280,22 @@ def model(self, context: Context, line: str, sql: t.Optional[str] = None) -> Non path=model._path, dialect=config.dialect, time_column_format=config.time_column_format, - physical_schema_override=context.config.physical_schema_override, + physical_schema_mapping=context.config.physical_schema_mapping, default_catalog=context.default_catalog, ) if loaded.name == args.model: model = loaded else: - with open(model._path, "r", encoding="utf-8") as file: - expressions = parse(file.read(), default_dialect=config.dialect) + if model._path: + with open(model._path, "r", encoding="utf-8") as file: + expressions = parse(file.read(), default_dialect=config.dialect) formatted = format_model_expressions( - expressions, model.dialect, 
**config.format.generator_options + expressions, + model.dialect, + rewrite_casts=not config.format.no_rewrite_casts, + **config.format.generator_options, ) self._shell.set_next_input( @@ -212,8 +308,9 @@ def model(self, context: Context, line: str, sql: t.Optional[str] = None) -> Non replace=True, ) - with open(model._path, "w", encoding="utf-8") as file: - file.write(formatted) + if model._path: + with open(model._path, "w", encoding="utf-8") as file: + file.write(formatted) if sql: context.console.log_success(f"Model `{args.model}` updated") @@ -240,15 +337,7 @@ def test(self, context: Context, line: str, test_def_raw: t.Optional[str] = None if not args.test_name and not args.ls: raise MagicError("Must provide either test name or `--ls` to list tests") - test_meta = [] - - for path, config in context.configs.items(): - test_meta.extend( - get_all_model_tests( - path / c.TESTS, - ignore_patterns=config.ignore_patterns, - ) - ) + test_meta = context.select_tests() tests: t.Dict[str, t.Dict[str, ModelTestMetadata]] = defaultdict(dict) for model_test_metadata in test_meta: @@ -310,6 +399,11 @@ def test(self, context: Context, line: str, test_def_raw: t.Optional[str] = None action="store_true", help="Skip the unit tests defined for the model.", ) + @argument( + "--skip-linter", + action="store_true", + help="Skip the linter for the model.", + ) @argument( "--restate-model", "-r", @@ -325,8 +419,14 @@ def test(self, context: Context, line: str, test_def_raw: t.Optional[str] = None ) @argument( "--skip-backfill", + "--dry-run", + action="store_true", + help="Skip the backfill step and only create a virtual update for the plan.", + ) + @argument( + "--empty-backfill", action="store_true", - help="Skip the backfill step.", + help="Produce empty backfill. 
Like --skip-backfill no models will be backfilled, unlike --skip-backfill missing intervals will be recorded as if they were backfilled.", ) @argument( "--forward-only", @@ -373,7 +473,7 @@ def test(self, context: Context, line: str, test_def_raw: t.Optional[str] = None "--backfill-model", type=str, nargs="*", - help="Backfill only the models whose names match the expression. This is supported only when targeting a development environment.", + help="Backfill only the models whose names match the expression.", ) @argument( "--no-diff", @@ -386,18 +486,38 @@ def test(self, context: Context, line: str, test_def_raw: t.Optional[str] = None action="store_true", help="Run latest intervals as part of the plan application (prod environment only).", ) + @argument( + "--ignore-cron", + action="store_true", + help="Run for all missing intervals, ignoring individual cron schedules. Only applies if --run is set.", + default=None, + ) @argument( "--enable-preview", action="store_true", help="Enable preview for forward-only models when targeting a development environment.", default=None, ) + @argument( + "--diff-rendered", + action="store_true", + help="Output text differences for the rendered versions of the models and standalone audits", + ) + @argument( + "--verbose", + "-v", + action="count", + default=0, + help="Verbose output. 
Use -vv for very verbose. 
", + ) @line_magic @pass_sqlmesh_context def plan(self, context: Context, line: str) -> None: """Goes through a set of prompts to both establish a plan and apply it""" args = parse_argstring(self.plan, line) + setattr(context.console, "verbosity", Verbosity(args.verbose)) + context.plan( args.environment, start=args.start, @@ -409,6 +529,7 @@ def plan(self, context: Context, line: str) -> None: backfill_models=args.backfill_model, no_gaps=args.no_gaps, skip_backfill=args.skip_backfill, + empty_backfill=args.empty_backfill, forward_only=args.forward_only, no_prompts=args.no_prompts, auto_apply=args.auto_apply, @@ -418,7 +539,9 @@ select_models=args.select_model, no_diff=args.no_diff, run=args.run, + ignore_cron=args.ignore_cron, enable_preview=args.enable_preview, + diff_rendered=args.diff_rendered, ) @magic_arguments() @@ -436,20 +559,39 @@ def plan(self, context: Context, line: str) -> None: action="store_true", help="Run for all missing intervals, ignoring individual cron schedules.", ) + @argument( + "--select-model", + type=str, + nargs="*", + help="Select specific models to run. Note: this always includes upstream dependencies.", + ) + @argument( + "--exit-on-env-update", + type=int, + help="If set, the command will exit with the specified code if the run is interrupted by an update to the target environment.", + ) + @argument( + "--no-auto-upstream", + action="store_true", + help="Do not automatically include upstream models. Only applicable when --select-model is used. 
Note: this may result in missing / invalid data for the selected models.", + ) @line_magic @pass_sqlmesh_context def run_dag(self, context: Context, line: str) -> None: """Evaluate the DAG of models using the built-in scheduler.""" args = parse_argstring(self.run_dag, line) - success = context.run( + completion_status = context.run( args.environment, start=args.start, end=args.end, skip_janitor=args.skip_janitor, ignore_cron=args.ignore_cron, + select_models=args.select_model, + exit_on_env_update=args.exit_on_env_update, + no_auto_upstream=args.no_auto_upstream, ) - if not success: + if completion_status.is_failure: raise SQLMeshError("Error Running DAG. Check logs for details.") @magic_arguments() @@ -467,6 +609,8 @@ def run_dag(self, context: Context, line: str) -> None: def evaluate(self, context: Context, line: str) -> None: """Evaluate a model query and fetches a dataframe.""" context.refresh() + + snowpark = optional_import("snowflake.snowpark") args = parse_argstring(self.evaluate, line) df = context.evaluate( @@ -476,6 +620,10 @@ def evaluate(self, context: Context, line: str) -> None: execution_time=args.execution_time, limit=args.limit, ) + + if snowpark and isinstance(df, snowpark.DataFrame): + df = df.limit(args.limit or 100).to_pandas() + self.display(df) @magic_arguments() @@ -490,23 +638,41 @@ def evaluate(self, context: Context, line: str) -> None: ) @argument("--dialect", type=str, help="SQL dialect to render.") @argument("--no-format", action="store_true", help="Disable fancy formatting of the query.") + @format_arguments @line_magic @pass_sqlmesh_context def render(self, context: Context, line: str) -> None: """Renders a model's query, optionally expanding referenced models.""" context.refresh() - args = parse_argstring(self.render, line) + render_opts = vars(parse_argstring(self.render, line)) + model = render_opts.pop("model") + dialect = render_opts.pop("dialect", None) + + model = context.get_model(model, raise_if_missing=True) query = 
context.render( - args.model, - start=args.start, - end=args.end, - execution_time=args.execution_time, - expand=args.expand, + model, + start=render_opts.pop("start", None), + end=render_opts.pop("end", None), + execution_time=render_opts.pop("execution_time", None), + expand=render_opts.pop("expand", False), + ) + + no_format = render_opts.pop("no_format", False) + + format_config = context.config_for_node(model).format + format_options = { + **format_config.generator_options, + **{k: v for k, v in render_opts.items() if v is not None}, + } + + sql = query.sql( + pretty=True, + dialect=context.config.dialect if dialect is None else dialect, + **format_options, ) - sql = query.sql(pretty=True, dialect=args.dialect or context.config.dialect) - if args.no_format: + if no_format: context.console.log_status_update(sql) else: context.console.show_sql(sql) @@ -616,11 +782,27 @@ def create_external_models(self, context: Context, line: str) -> None: default=3, help="The number of decimal places to keep when comparing floating point columns. Default: 3", ) + @argument( + "--select-model", + type=str, + nargs="*", + help="Specify one or more models to data diff. Use wildcards to diff multiple models. Ex: '*' (all models with applied plan diffs), 'demo.model+' (this and downstream models), 'git:feature_branch' (models with direct modifications in this branch only)", + ) @argument( "--skip-grain-check", action="store_true", help="Disable the check for a primary key (grain) that is missing or is not unique.", ) + @argument( + "--warn-grain-check", + action="store_true", + help="Warn if any selected model is missing a grain, and compute diffs for the remaining models.", + ) + @argument( + "--schema-diff-ignore-case", + action="store_true", + help="If set, when performing a schema diff the case of column names is ignored when matching between the two schemas. 
For example, 'col_a' in the source schema and 'COL_A' in the target schema will be treated as the same column.", + ) @line_magic @pass_sqlmesh_context def table_diff(self, context: Context, line: str) -> None: @@ -630,17 +812,20 @@ def table_diff(self, context: Context, line: str) -> None: """ args = parse_argstring(self.table_diff, line) source, target = args.source_to_target.split(":") + select_models = {args.model} if args.model else args.select_model or None context.table_diff( source=source, target=target, on=args.on, skip_columns=args.skip_columns, - model_or_snapshot=args.model, + select_models=select_models, where=args.where, limit=args.limit, show_sample=args.show_sample, decimals=args.decimals, skip_grain_check=args.skip_grain_check, + warn_grain_check=args.warn_grain_check, + schema_diff_ignore_case=args.schema_diff_ignore_case, ) @magic_arguments() @@ -651,16 +836,64 @@ def table_diff(self, context: Context, line: str) -> None: help="The name of the model to get the table name for.", ) @argument( - "--dev", + "--environment", + type=str, + help="The environment to source the model version from.", + ) + @argument( + "--prod", action="store_true", - help="Print the name of the snapshot table used for previews in development environments.", + help="If set, return the name of the physical table that will be used in production for the model version promoted in the target environment.", ) @line_magic @pass_sqlmesh_context def table_name(self, context: Context, line: str) -> None: """Prints the name of the physical table for the given model.""" args = parse_argstring(self.table_name, line) - context.console.log_status_update(context.table_name(args.model_name, args.dev)) + context.console.log_status_update( + context.table_name(args.model_name, args.environment, args.prod) + ) + + @magic_arguments() + @argument( + "pipeline", + nargs="?", + type=str, + help="The dlt pipeline to attach for this SQLMesh project.", + ) + @argument( + "--table", + "-t", + 
type=str, + nargs="*", + help="The specific dlt tables to refresh in the SQLMesh models.", + ) + @argument( + "--force", + "-f", + action="store_true", + help="If set, existing models are overwritten with the new DLT tables.", + ) + @argument( + "--dlt-path", + type=str, + help="The directory where the DLT pipeline resides.", + ) + @line_magic + @pass_sqlmesh_context + def dlt_refresh(self, context: Context, line: str) -> None: + """Attaches to a DLT pipeline with the option to update specific or all missing tables in the SQLMesh project.""" + from sqlmesh.integrations.dlt import generate_dlt_models + + args = parse_argstring(self.dlt_refresh, line) + sqlmesh_models = generate_dlt_models( + context, args.pipeline, list(args.table or []), args.force, args.dlt_path + ) + if sqlmesh_models: + model_names = "\n".join([f"- {model_name}" for model_name in sqlmesh_models]) + context.console.log_success(f"Updated SQLMesh project with models:\n{model_names}") + else: + context.console.log_success("All SQLMesh models are up to date.") @magic_arguments() @argument( @@ -697,49 +930,33 @@ def rewrite(self, context: Context, line: str, sql: str) -> None: help="Transpile project models to the specified dialect.", ) @argument( - "--append-newline", + "--check", action="store_true", - help="Whether or not to append a newline to the end of the file.", + help="Whether or not to check formatting (but not actually format anything).", default=None, ) @argument( - "--normalize", + "--append-newline", action="store_true", - help="Whether or not to normalize identifiers to lowercase.", + help="Include a newline at the end of the output.", default=None, ) @argument( - "--pad", - type=int, - help="Determines the pad size in a formatted string.", - ) - @argument( - "--indent", - type=int, - help="Determines the indentation size in a formatted string.", - ) - @argument( - "--normalize-functions", - type=str, - help="Whether or not to normalize all function names. 
Possible values are: 'upper', 'lower'", - ) - @argument( - "--leading-comma", + "--no-rewrite-casts", action="store_true", - help="Determines whether or not the comma is leading or trailing in select expressions. Default is trailing.", + help="Preserve the existing casts, without rewriting them to use the :: syntax.", default=None, ) - @argument( - "--max-text-width", - type=int, - help="The max number of characters in a segment before creating new lines in pretty mode.", - ) + @format_arguments @line_magic @pass_sqlmesh_context - def format(self, context: Context, line: str) -> None: + def format(self, context: Context, line: str) -> bool: """Format all SQL models and audits.""" - args = parse_argstring(self.format, line) - context.format(**{k: v for k, v in vars(args).items() if v is not None}) + format_opts = vars(parse_argstring(self.format, line)) + if format_opts.pop("no_rewrite_casts", None): + format_opts["rewrite_casts"] = False + + return context.format(**{k: v for k, v in format_opts.items() if v is not None}) @magic_arguments() @argument("environment", type=str, help="The environment to diff local state against.") @@ -779,7 +996,7 @@ def janitor(self, context: Context, line: str) -> None: "-q", type=str, nargs="+", - required=True, + default=[], help="Queries that will be used to generate data for the model's dependencies.", ) @argument( @@ -840,7 +1057,13 @@ def create_test(self, context: Context, line: str) -> None: type=str, help="Only run tests that match the pattern of substring.", ) - @argument("--verbose", "-v", action="store_true", help="Verbose output.") + @argument( + "--verbose", + "-v", + action="count", + default=0, + help="Verbose output. 
Use -vv for very verbose.", + ) @argument( "--preserve-fixtures", action="store_true", @@ -851,11 +1074,13 @@ def create_test(self, context: Context, line: str) -> None: def run_test(self, context: Context, line: str) -> None: """Run unit test(s).""" args = parse_argstring(self.run_test, line) + context.test( match_patterns=args.pattern, tests=args.tests, - verbose=args.verbose, + verbosity=Verbosity(args.verbose), preserve_fixtures=args.preserve_fixtures, + stream=StringIO(), # consume the output instead of redirecting to stdout ) @magic_arguments() @@ -867,13 +1092,45 @@ def run_test(self, context: Context, line: str) -> None: @argument("--execution-time", type=str, help="Execution time.") @line_magic @pass_sqlmesh_context - def audit(self, context: Context, line: str) -> None: + def audit(self, context: Context, line: str) -> bool: """Run audit(s)""" args = parse_argstring(self.audit, line) - context.audit( + return context.audit( models=args.models, start=args.start, end=args.end, execution_time=args.execution_time ) + @magic_arguments() + @argument("environment", nargs="?", type=str, help="The environment to check intervals for.") + @argument( + "--no-signals", + action="store_true", + help="Disable signal checks and only show missing intervals.", + default=False, + ) + @argument( + "--select-model", + type=str, + nargs="*", + help="Select specific model changes that should be included in the plan.", + ) + @argument("--start", "-s", type=str, help="Start date of intervals to check for.") + @argument("--end", "-e", type=str, help="End date of intervals to check for.") + @line_magic + @pass_sqlmesh_context + def check_intervals(self, context: Context, line: str) -> None: + """Show missing intervals in an environment, respecting signals.""" + args = parse_argstring(self.check_intervals, line) + + context.console.show_intervals( + context.check_intervals( + environment=args.environment, + no_signals=args.no_signals, + select_models=args.select_model, + 
start=args.start, + end=args.end, + ) + ) + @magic_arguments() @argument( "--skip-connection", @@ -881,12 +1138,19 @@ def audit(self, context: Context, line: str) -> None: help="Skip the connection test.", default=False, ) + @argument( + "--verbose", + "-v", + action="count", + default=0, + help="Verbose output. Use -vv for very verbose.", + ) @line_magic @pass_sqlmesh_context def info(self, context: Context, line: str) -> None: """Display SQLMesh project information.""" args = parse_argstring(self.info, line) - context.print_info(skip_connection=args.skip_connection) + context.print_info(skip_connection=args.skip_connection, verbosity=Verbosity(args.verbose)) @magic_arguments() @line_magic @@ -903,6 +1167,35 @@ def clean(self, context: Context, line: str) -> None: context.clear_caches() context.console.log_success("SQLMesh cache and build artifacts cleared") + @magic_arguments() + @line_magic + @pass_sqlmesh_context + def environments(self, context: Context, line: str) -> None: + """Prints the list of SQLMesh environments with its expiry datetime.""" + context.print_environment_names() + + @magic_arguments() + @argument( + "--models", + "--model", + type=str, + nargs="*", + help="A model to lint. Multiple models can be linted. 
If no models are specified, every model will be linted.", + ) + @line_magic + @pass_sqlmesh_context + def lint(self, context: Context, line: str) -> None: + """Run linter for target model(s)""" + args = parse_argstring(self.lint, line) + context.lint_models(args.models) + + @magic_arguments() + @line_magic + @pass_sqlmesh_context + def destroy(self, context: Context, line: str) -> None: + """Removes all project resources, engine-managed objects, state tables and clears the SQLMesh cache.""" + context.destroy() + def register_magics() -> None: try: diff --git a/sqlmesh/migrations/v0000_baseline.py b/sqlmesh/migrations/v0000_baseline.py new file mode 100644 index 0000000000..abd316fcfe --- /dev/null +++ b/sqlmesh/migrations/v0000_baseline.py @@ -0,0 +1,95 @@ +"""The baseline migration script that sets up the initial state tables.""" + +from sqlglot import exp +from sqlmesh.utils.migration import blob_text_type, index_text_type + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + intervals_table = "_intervals" + snapshots_table = "_snapshots" + environments_table = "_environments" + versions_table = "_versions" + if schema: + engine_adapter.create_schema(schema) + intervals_table = f"{schema}.{intervals_table}" + snapshots_table = f"{schema}.{snapshots_table}" + environments_table = f"{schema}.{environments_table}" + versions_table = f"{schema}.{versions_table}" + + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + + snapshots_columns_to_types = { + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + } + + environments_columns_to_types 
= { + "name": exp.DataType.build(index_type), + "snapshots": exp.DataType.build(blob_type), + "start_at": exp.DataType.build("text"), + "end_at": exp.DataType.build("text"), + "plan_id": exp.DataType.build("text"), + "previous_plan_id": exp.DataType.build("text"), + "expiration_ts": exp.DataType.build("bigint"), + "finalized_ts": exp.DataType.build("bigint"), + "promoted_snapshot_ids": exp.DataType.build(blob_type), + "suffix_target": exp.DataType.build("text"), + "catalog_name_override": exp.DataType.build("text"), + "previous_finalized_snapshots": exp.DataType.build(blob_type), + "normalize_name": exp.DataType.build("boolean"), + "requirements": exp.DataType.build(blob_type), + } + + intervals_columns_to_types = { + "id": exp.DataType.build(index_type), + "created_ts": exp.DataType.build("bigint"), + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "start_ts": exp.DataType.build("bigint"), + "end_ts": exp.DataType.build("bigint"), + "is_dev": exp.DataType.build("boolean"), + "is_removed": exp.DataType.build("boolean"), + "is_compacted": exp.DataType.build("boolean"), + } + + versions_columns_to_types = { + "schema_version": exp.DataType.build("int"), + "sqlglot_version": exp.DataType.build(index_type), + "sqlmesh_version": exp.DataType.build(index_type), + } + + # Create the versions table. + engine_adapter.create_state_table(versions_table, versions_columns_to_types) + + # Create the snapshots table and its indexes. + engine_adapter.create_state_table( + snapshots_table, snapshots_columns_to_types, primary_key=("name", "identifier") + ) + engine_adapter.create_index(snapshots_table, "_snapshots_name_version_idx", ("name", "version")) + + # Create the environments table and its indexes. + engine_adapter.create_state_table( + environments_table, environments_columns_to_types, primary_key=("name",) + ) + + # Create the intervals table and its indexes. 
+ engine_adapter.create_state_table( + intervals_table, intervals_columns_to_types, primary_key=("id",) + ) + engine_adapter.create_index( + intervals_table, "_intervals_name_identifier_idx", ("name", "identifier") + ) + engine_adapter.create_index(intervals_table, "_intervals_name_version_idx", ("name", "version")) + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0001_init.py b/sqlmesh/migrations/v0001_init.py deleted file mode 100644 index 778c36bc23..0000000000 --- a/sqlmesh/migrations/v0001_init.py +++ /dev/null @@ -1,60 +0,0 @@ -"""All migrations should be named _XXXX.py, they will be executed sequentially. - -If a migration alters the payload of any pydantic models, you should not actually use them because -the running model may not be able to load them. Make sure that these migration files are standalone. -""" - -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - environments_table = "_environments" - versions_table = "_versions" - - if schema: - engine_adapter.create_schema(schema) - snapshots_table = f"{schema}.{snapshots_table}" - environments_table = f"{schema}.{environments_table}" - versions_table = f"{schema}.{versions_table}" - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.create_state_table( - snapshots_table, - { - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - }, - primary_key=("name", "identifier"), - ) - - engine_adapter.create_index(snapshots_table, "name_version_idx", ("name", "version")) - - engine_adapter.create_state_table( - environments_table, - { - "name": exp.DataType.build(index_type), - "snapshots": exp.DataType.build("text"), - 
"start_at": exp.DataType.build("text"), - "end_at": exp.DataType.build("text"), - "plan_id": exp.DataType.build("text"), - "previous_plan_id": exp.DataType.build("text"), - "expiration_ts": exp.DataType.build("bigint"), - }, - primary_key=("name",), - ) - - engine_adapter.create_state_table( - versions_table, - { - "schema_version": exp.DataType.build("int"), - "sqlglot_version": exp.DataType.build("text"), - }, - ) diff --git a/sqlmesh/migrations/v0002_remove_identify.py b/sqlmesh/migrations/v0002_remove_identify.py deleted file mode 100644 index 0152e719f7..0000000000 --- a/sqlmesh/migrations/v0002_remove_identify.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Remove identify=True kwarg for rendering sql""" - - -def migrate(state_sync, **kwargs): # type: ignore - pass diff --git a/sqlmesh/migrations/v0003_move_batch_size.py b/sqlmesh/migrations/v0003_move_batch_size.py deleted file mode 100644 index 8148325750..0000000000 --- a/sqlmesh/migrations/v0003_move_batch_size.py +++ /dev/null @@ -1,34 +0,0 @@ -"""Move batch_size from the model and into the kind.""" - -import json - -from sqlglot import exp - - -def migrate(state_sync, **kwargs): # type: ignore - snapshots_table = "_snapshots" - if state_sync.schema: - snapshots_table = f"{state_sync.schema}.{snapshots_table}" - - for row in state_sync.engine_adapter.fetchall( - exp.select("*").from_(snapshots_table), quote_identifiers=True - ): - name, identifier, _, snapshot = row - snapshot = json.loads(snapshot) - model = snapshot["model"] - if "batch_size" in model: - batch_size = model.pop("batch_size") - kind = model.get("kind") - - if kind: - if kind["name"] in ("INCREMENTAL_BY_TIME_RANGE", "INCREMENTAL_BY_UNIQUE_KEY"): - kind["batch_size"] = batch_size - - # this is not efficient, i'm doing this because i'm lazy and no one has snapshots at the time of writing this migration - # do not copy this code in future migrations - - state_sync.engine_adapter.update_table( - snapshots_table, - {"snapshot": json.dumps(snapshot)}, - 
where=f"name = '{name}' and identifier = '{identifier}'", - ) diff --git a/sqlmesh/migrations/v0004_environmnent_add_finalized_at.py b/sqlmesh/migrations/v0004_environmnent_add_finalized_at.py deleted file mode 100644 index ad228abbc9..0000000000 --- a/sqlmesh/migrations/v0004_environmnent_add_finalized_at.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Add support for environment finalization.""" - -from sqlglot import exp - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - environments_table = "_environments" - if state_sync.schema: - environments_table = f"{state_sync.schema}.{environments_table}" - - alter_table_exp = exp.AlterTable( - this=exp.to_table(environments_table), - actions=[ - exp.ColumnDef( - this=exp.to_column("finalized_ts"), - kind=exp.DataType.build("bigint"), - ) - ], - ) - - engine_adapter.execute(alter_table_exp) diff --git a/sqlmesh/migrations/v0005_create_seed_table.py b/sqlmesh/migrations/v0005_create_seed_table.py deleted file mode 100644 index 1e1e7dc34e..0000000000 --- a/sqlmesh/migrations/v0005_create_seed_table.py +++ /dev/null @@ -1,24 +0,0 @@ -"""Create a dedicated table to store the content of seeds.""" - -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - seeds_table = "_seeds" - if state_sync.schema: - seeds_table = f"{state_sync.schema}.{seeds_table}" - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.create_state_table( - seeds_table, - { - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "content": exp.DataType.build("text"), - }, - primary_key=("name", "identifier"), - ) diff --git a/sqlmesh/migrations/v0006_change_seed_hash.py b/sqlmesh/migrations/v0006_change_seed_hash.py deleted file mode 100644 index d6d4e1bf9c..0000000000 --- a/sqlmesh/migrations/v0006_change_seed_hash.py +++ /dev/null @@ 
-1,5 +0,0 @@ -"""Seed hashes moved from to_string to to_json for performance.""" - - -def migrate(state_sync, **kwargs): # type: ignore - pass diff --git a/sqlmesh/migrations/v0007_env_table_info_to_kind.py b/sqlmesh/migrations/v0007_env_table_info_to_kind.py deleted file mode 100644 index 1afffa1ca5..0000000000 --- a/sqlmesh/migrations/v0007_env_table_info_to_kind.py +++ /dev/null @@ -1,98 +0,0 @@ -"""Change environments because snapshot table info now stores model kind name.""" - -import json -import zlib - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def _hash(data): # type: ignore - return str(zlib.crc32(";".join("" if d is None else d for d in data).encode("utf-8"))) - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - environments_table = "_environments" - snapshots_table = "_snapshots" - if schema: - environments_table = f"{schema}.{environments_table}" - snapshots_table = f"{schema}.{snapshots_table}" - snapshots_to_kind = {} - - for name, identifier, snapshot in engine_adapter.fetchall( - exp.select("name", "identifier", "snapshot").from_(snapshots_table), - quote_identifiers=True, - ): - snapshot = json.loads(snapshot) - snapshots_to_kind[(name, identifier)] = snapshot["model"]["kind"]["name"] - - environments = engine_adapter.fetchall( - exp.select("*").from_(environments_table), quote_identifiers=True - ) - new_environments = [] - - for ( - name, - snapshots, - start_at, - end_at, - plan_id, - previous_plan_id, - expiration_ts, - finalized_ts, - ) in environments: - new_snapshots = [] - - for snapshot in json.loads(snapshots): - snapshot.pop("is_materialized", None) - snapshot.pop("is_embedded_kind", None) - - fingerprint = snapshot["fingerprint"] - identifier = _hash( - [ - fingerprint["data_hash"], - fingerprint["metadata_hash"], - fingerprint["parent_data_hash"], - fingerprint["parent_metadata_hash"], - ] - ) - - 
snapshot["kind_name"] = snapshots_to_kind.get((snapshot["name"], identifier), "VIEW") - new_snapshots.append(snapshot) - - new_environments.append( - { - "name": name, - "snapshots": json.dumps(new_snapshots), - "start_at": start_at, - "end_at": end_at, - "plan_id": plan_id, - "previous_plan_id": previous_plan_id, - "expiration_ts": expiration_ts, - "finalized_ts": finalized_ts, - } - ) - - if new_environments: - engine_adapter.delete_from(environments_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - environments_table, - pd.DataFrame(new_environments), - columns_to_types={ - "name": exp.DataType.build(index_type), - "snapshots": exp.DataType.build("text"), - "start_at": exp.DataType.build("text"), - "end_at": exp.DataType.build("text"), - "plan_id": exp.DataType.build("text"), - "previous_plan_id": exp.DataType.build("text"), - "expiration_ts": exp.DataType.build("bigint"), - "finalized_ts": exp.DataType.build("bigint"), - }, - ) diff --git a/sqlmesh/migrations/v0008_create_intervals_table.py b/sqlmesh/migrations/v0008_create_intervals_table.py deleted file mode 100644 index 0746febcaa..0000000000 --- a/sqlmesh/migrations/v0008_create_intervals_table.py +++ /dev/null @@ -1,38 +0,0 @@ -"""Create a dedicated table to store snapshot intervals.""" - -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - intervals_table = "_intervals" - if state_sync.schema: - intervals_table = f"{state_sync.schema}.{intervals_table}" - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.create_state_table( - intervals_table, - { - "id": exp.DataType.build(index_type), - "created_ts": exp.DataType.build("bigint"), - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "start_ts": 
exp.DataType.build("bigint"), - "end_ts": exp.DataType.build("bigint"), - "is_dev": exp.DataType.build("boolean"), - "is_removed": exp.DataType.build("boolean"), - "is_compacted": exp.DataType.build("boolean"), - }, - primary_key=("id",), - ) - - engine_adapter.create_index( - intervals_table, "name_version_idx", ("name", "version", "created_ts") - ) - engine_adapter.create_index( - intervals_table, "name_identifier_idx", ("name", "identifier", "created_ts") - ) diff --git a/sqlmesh/migrations/v0009_remove_pre_post_hooks.py b/sqlmesh/migrations/v0009_remove_pre_post_hooks.py deleted file mode 100644 index 90b39bcf72..0000000000 --- a/sqlmesh/migrations/v0009_remove_pre_post_hooks.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Remove pre- / post- hooks from existing snapshots.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - - for name, identifier, version, snapshopt in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot").from_(snapshots_table), - quote_identifiers=True, - ): - snapshot = json.loads(snapshopt) - pre_hooks = snapshot["model"].pop("pre", []) - post_hooks = snapshot["model"].pop("post", []) - - expressions = snapshot["model"].pop("expressions", None) - if expressions and snapshot["model"]["source_type"] == "sql": - snapshot["model"]["pre_statements"] = expressions - - if pre_hooks or post_hooks: - print( - "WARNING: Hooks are no longer supported by SQLMesh, use pre and post SQL statements instead. 
" - f"Removing 'pre' and 'post' attributes from snapshot name='{name}', identifier='{identifier}'" - ) - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(snapshot), - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - }, - ) diff --git a/sqlmesh/migrations/v0010_seed_hash_batch_size.py b/sqlmesh/migrations/v0010_seed_hash_batch_size.py deleted file mode 100644 index 2f73e73161..0000000000 --- a/sqlmesh/migrations/v0010_seed_hash_batch_size.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Seed metadata hashes now correctly include the batch_size.""" - - -def migrate(state_sync, **kwargs): # type: ignore - pass diff --git a/sqlmesh/migrations/v0011_add_model_kind_name.py b/sqlmesh/migrations/v0011_add_model_kind_name.py deleted file mode 100644 index 2d600dae4f..0000000000 --- a/sqlmesh/migrations/v0011_add_model_kind_name.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Add the kind_name column to the snapshots table.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - index_type = index_text_type(engine_adapter.dialect) - - alter_table_exp = exp.AlterTable( - this=exp.to_table(snapshots_table), - actions=[ - exp.ColumnDef( - this=exp.to_column("kind_name"), - kind=exp.DataType.build(index_type), - ) - ], - ) - engine_adapter.execute(alter_table_exp) - - 
new_snapshots = [] - - for name, identifier, version, snapshot in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": snapshot, - "kind_name": parsed_snapshot["model"]["kind"]["name"], - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) diff --git a/sqlmesh/migrations/v0012_update_jinja_expressions.py b/sqlmesh/migrations/v0012_update_jinja_expressions.py deleted file mode 100644 index aa7bcd375c..0000000000 --- a/sqlmesh/migrations/v0012_update_jinja_expressions.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Fix expressions that contain jinja.""" - -import json -import typing as t - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.jinja import has_jinja -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - audits = parsed_snapshot.get("audits", []) - model = parsed_snapshot["model"] - - if "query" in model and has_jinja(model["query"]): - model["query"] = 
_wrap_query(model["query"]) - - _wrap_statements(model, "pre_statements") - _wrap_statements(model, "post_statements") - - for audit in audits: - if has_jinja(audit["query"]): - audit["query"] = _wrap_query(audit["query"]) - _wrap_statements(audit, "expressions") - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) - - -def _wrap_statements(obj: t.Dict, key: str) -> None: - updated_statements = [] - for statement in obj.get(key, []): - if has_jinja(statement): - statement = _wrap_statement(statement) - updated_statements.append(statement) - - if updated_statements: - obj[key] = updated_statements - - -def _wrap_query(sql: str) -> str: - return f"JINJA_QUERY_BEGIN;\n{sql}\nJINJA_END;" - - -def _wrap_statement(sql: str) -> str: - return f"JINJA_STATEMENT_BEGIN;\n{sql}\nJINJA_END;" diff --git a/sqlmesh/migrations/v0013_serde_using_model_dialects.py b/sqlmesh/migrations/v0013_serde_using_model_dialects.py deleted file mode 100644 index 284c8026dd..0000000000 --- a/sqlmesh/migrations/v0013_serde_using_model_dialects.py +++ /dev/null @@ -1,86 +0,0 @@ -"""Serialize SQL using the dialect of each model.""" - -import json -import typing as t - -import pandas as pd -from sqlglot import exp, parse_one - -from sqlmesh.utils.jinja import has_jinja -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - 
schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - model = parsed_snapshot["model"] - dialect = model["dialect"] - - _update_expression(model, "query", dialect) - _update_expression_list(model, "pre_statements", dialect) - _update_expression_list(model, "post_statements", dialect) - - for audit in parsed_snapshot.get("audits", []): - dialect = audit["dialect"] - _update_expression(audit, "query", dialect) - _update_expression_list(audit, "expressions", dialect) - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) - - -# Note: previously we used to do serde using the SQLGlot dialect, so we need to parse the -# stored queries using that dialect and then write them back using the correct dialect. 
- - -def _update_expression(obj: t.Dict, key: str, dialect: str) -> None: - if key in obj and not has_jinja(obj[key]): - obj[key] = parse_one(obj[key]).sql(dialect=dialect) - - -def _update_expression_list(obj: t.Dict, key: str, dialect: str) -> None: - if key in obj: - obj[key] = [ - ( - parse_one(expression).sql(dialect=dialect) - if not has_jinja(expression) - else expression - ) - for expression in obj[key] - if expression - ] diff --git a/sqlmesh/migrations/v0014_fix_dev_intervals.py b/sqlmesh/migrations/v0014_fix_dev_intervals.py deleted file mode 100644 index f0e922783c..0000000000 --- a/sqlmesh/migrations/v0014_fix_dev_intervals.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Fix snapshot intervals that have been erroneously marked as dev.""" - - -def migrate(state_sync, **kwargs): # type: ignore - schema = state_sync.schema - intervals_table = "_intervals" - if schema: - intervals_table = f"{schema}.{intervals_table}" - - state_sync.engine_adapter.update_table( - intervals_table, - {"is_dev": False}, - where="1=1", - ) diff --git a/sqlmesh/migrations/v0015_environment_add_promoted_snapshot_ids.py b/sqlmesh/migrations/v0015_environment_add_promoted_snapshot_ids.py deleted file mode 100644 index d8d963cb5b..0000000000 --- a/sqlmesh/migrations/v0015_environment_add_promoted_snapshot_ids.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Include a set of snapshot IDs filtered for promotion.""" - -from sqlglot import exp - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - environments_table = "_environments" - if state_sync.schema: - environments_table = f"{state_sync.schema}.{environments_table}" - - alter_table_exp = exp.AlterTable( - this=exp.to_table(environments_table), - actions=[ - exp.ColumnDef( - this=exp.to_column("promoted_snapshot_ids"), - kind=exp.DataType.build("text"), - ) - ], - ) - - engine_adapter.execute(alter_table_exp) diff --git a/sqlmesh/migrations/v0016_fix_windows_path.py 
b/sqlmesh/migrations/v0016_fix_windows_path.py deleted file mode 100644 index 46c85a0d5d..0000000000 --- a/sqlmesh/migrations/v0016_fix_windows_path.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Fix paths that have a Windows forward slash in them.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - model = parsed_snapshot["model"] - python_env = model.get("python_env") - if python_env: - for py_definition in python_env.values(): - path = py_definition.get("path") - if path: - py_definition["path"] = path.replace("\\", "/") - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) diff --git a/sqlmesh/migrations/v0017_fix_windows_seed_path.py b/sqlmesh/migrations/v0017_fix_windows_seed_path.py deleted file mode 100644 index f780b216de..0000000000 --- a/sqlmesh/migrations/v0017_fix_windows_seed_path.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Fix 
seed paths that have a Windows forward slash in them.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - model_kind = parsed_snapshot["model"]["kind"] - if "path" in model_kind: - model_kind["path"] = model_kind["path"].replace("\\", "/") - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) diff --git a/sqlmesh/migrations/v0018_rename_snapshot_model_to_node.py b/sqlmesh/migrations/v0018_rename_snapshot_model_to_node.py deleted file mode 100644 index 9b342962cc..0000000000 --- a/sqlmesh/migrations/v0018_rename_snapshot_model_to_node.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Replace snapshot model field with node.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - 
schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - parsed_snapshot["node"] = parsed_snapshot.pop("model") - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) diff --git a/sqlmesh/migrations/v0019_add_env_suffix_target.py b/sqlmesh/migrations/v0019_add_env_suffix_target.py deleted file mode 100644 index b3007b45e1..0000000000 --- a/sqlmesh/migrations/v0019_add_env_suffix_target.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Add support for environment suffix target.""" - -from sqlglot import exp - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - environments_table = "_environments" - if state_sync.schema: - environments_table = f"{state_sync.schema}.{environments_table}" - - alter_table_exp = exp.AlterTable( - this=exp.to_table(environments_table), - actions=[ - exp.ColumnDef( - this=exp.to_column("suffix_target"), - kind=exp.DataType.build("text"), - ) - ], - ) - engine_adapter.execute(alter_table_exp) - - state_sync.engine_adapter.update_table( - environments_table, - {"suffix_target": "schema"}, - 
where="1=1", - ) diff --git a/sqlmesh/migrations/v0020_remove_redundant_attributes_from_dbt_models.py b/sqlmesh/migrations/v0020_remove_redundant_attributes_from_dbt_models.py deleted file mode 100644 index cf9d7a145f..0000000000 --- a/sqlmesh/migrations/v0020_remove_redundant_attributes_from_dbt_models.py +++ /dev/null @@ -1,79 +0,0 @@ -"""Remove redundant attributes from dbt models.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - jinja_macros_global_objs = parsed_snapshot["node"]["jinja_macros"]["global_objs"] - if "config" in jinja_macros_global_objs and isinstance( - jinja_macros_global_objs["config"], dict - ): - for key in CONFIG_ATTRIBUTE_KEYS_TO_REMOVE: - jinja_macros_global_objs["config"].pop(key, None) - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) - - -CONFIG_ATTRIBUTE_KEYS_TO_REMOVE = [ - "config", - 
"config_call_dict", - "depends_on", - "dependencies", - "metrics", - "original_file_path", - "packages", - "patch_path", - "path", - "post-hook", - "pre-hook", - "raw_code", - "refs", - "resource_type", - "sources", - "sql", - "tests", - "unrendered_config", -] diff --git a/sqlmesh/migrations/v0021_fix_table_properties.py b/sqlmesh/migrations/v0021_fix_table_properties.py deleted file mode 100644 index 7889b59875..0000000000 --- a/sqlmesh/migrations/v0021_fix_table_properties.py +++ /dev/null @@ -1,61 +0,0 @@ -"""Fix table properties that have extra quoting due to a bug.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.core import dialect as d -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - found_table_properties = False - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - table_properties = parsed_snapshot["node"].get("table_properties") - if table_properties: - found_table_properties = True - dialect = parsed_snapshot["node"].get("dialect") - parsed_snapshot["node"]["table_properties"] = exp.Tuple( - expressions=[ - exp.Literal.string(k).eq(d.parse_one(v)) for k, v in table_properties.items() - ] - ).sql(dialect=dialect) - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - } - ) - - if found_table_properties: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - 
pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) diff --git a/sqlmesh/migrations/v0022_move_project_to_model.py b/sqlmesh/migrations/v0022_move_project_to_model.py deleted file mode 100644 index ec8cba5762..0000000000 --- a/sqlmesh/migrations/v0022_move_project_to_model.py +++ /dev/null @@ -1,53 +0,0 @@ -"""Move project attr from snapshot to model.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - - parsed_snapshot["node"]["project"] = parsed_snapshot.pop("project", "") - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - } - ) - - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - if new_snapshots: - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) diff --git 
a/sqlmesh/migrations/v0023_fix_added_models_with_forward_only_parents.py b/sqlmesh/migrations/v0023_fix_added_models_with_forward_only_parents.py deleted file mode 100644 index 6ae64955b8..0000000000 --- a/sqlmesh/migrations/v0023_fix_added_models_with_forward_only_parents.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Fix snapshots of added models with forward only parents.""" - -import json -import typing as t - -from sqlglot import exp - -from sqlmesh.utils.dag import DAG - - -def migrate(state_sync: t.Any, **kwargs) -> None: # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - environments_table = "_environments" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - environments_table = f"{schema}.{environments_table}" - - dag: DAG[t.Tuple[str, str]] = DAG() - snapshot_mapping: t.Dict[t.Tuple[str, str], t.Dict[str, t.Any]] = {} - - for identifier, snapshot in engine_adapter.fetchall( - exp.select("identifier", "snapshot").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - - snapshot_id = (parsed_snapshot["name"], identifier) - snapshot_mapping[snapshot_id] = parsed_snapshot - - parent_ids = [ - (parent["name"], parent["identifier"]) for parent in parsed_snapshot["parents"] - ] - dag.add(snapshot_id, parent_ids) - - snapshots_to_delete = set() - - for snapshot_id in dag: - if snapshot_id not in snapshot_mapping: - continue - parsed_snapshot = snapshot_mapping[snapshot_id] - is_breaking = parsed_snapshot.get("change_category") == 1 - has_previous_versions = bool(parsed_snapshot.get("previous_versions", [])) - - has_paused_forward_only_parent = False - if is_breaking and not has_previous_versions: - for upstream_id in dag.upstream(snapshot_id): - if upstream_id not in snapshot_mapping: - continue - upstream_snapshot = snapshot_mapping[upstream_id] - upstream_change_category = upstream_snapshot.get("change_category") - is_forward_only_upstream = 
upstream_change_category == 3 - if is_forward_only_upstream and not upstream_snapshot.get("unpaused_ts"): - has_paused_forward_only_parent = True - break - - if has_paused_forward_only_parent: - snapshots_to_delete.add(snapshot_id) - - if snapshots_to_delete: - where = t.cast(exp.Tuple, exp.convert((exp.column("name"), exp.column("identifier")))).isin( - *snapshots_to_delete - ) - engine_adapter.delete_from(snapshots_table, where) diff --git a/sqlmesh/migrations/v0024_replace_model_kind_name_enum_with_value.py b/sqlmesh/migrations/v0024_replace_model_kind_name_enum_with_value.py deleted file mode 100644 index 1c55b93f7c..0000000000 --- a/sqlmesh/migrations/v0024_replace_model_kind_name_enum_with_value.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Replace snapshot model_kind_name enum with value.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - corrected_kind_name = None - parsed_snapshot = json.loads(snapshot) - if "kind" in parsed_snapshot["node"]: - corrected_kind_name = parsed_snapshot["node"]["kind"].get("name") - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": snapshot, - "kind_name": corrected_kind_name, - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), 
- "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) diff --git a/sqlmesh/migrations/v0025_fix_intervals_and_missing_change_category.py b/sqlmesh/migrations/v0025_fix_intervals_and_missing_change_category.py deleted file mode 100644 index 884b8c4067..0000000000 --- a/sqlmesh/migrations/v0025_fix_intervals_and_missing_change_category.py +++ /dev/null @@ -1,116 +0,0 @@ -"""Normalize intervals and fix missing change category.""" - -import json -import zlib - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils import random_id -from sqlmesh.utils.date import now_timestamp -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - intervals_table = "_intervals" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - intervals_table = f"{schema}.{intervals_table}" - - migration_required = False - new_snapshots = [] - new_intervals = [] - - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - - if not parsed_snapshot.get("change_category"): - fingerprint = parsed_snapshot.get("fingerprint") - version = _hash( - [ - fingerprint["data_hash"], - fingerprint["parent_data_hash"], - ] - ) - parsed_snapshot["change_category"] = ( - 4 if version == parsed_snapshot.get("version") else 5 - ) - migration_required = True - - def _add_interval(start_ts: int, end_ts: int, is_dev: bool) -> None: - new_intervals.append( - { - "id": random_id(), - "created_ts": now_timestamp(), - "name": name, - "identifier": identifier, - "version": version, - "start_ts": start_ts, - "end_ts": 
end_ts, - "is_dev": is_dev, - "is_removed": False, - "is_compacted": True, - } - ) - - for interval in parsed_snapshot.pop("intervals", []): - _add_interval(interval[0], interval[1], False) - migration_required = True - - for interval in parsed_snapshot.pop("dev_intervals", []): - _add_interval(interval[0], interval[1], True) - migration_required = True - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - } - ) - - if migration_required: - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.delete_from(snapshots_table, "TRUE") - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) - - if new_intervals: - engine_adapter.insert_append( - intervals_table, - pd.DataFrame(new_intervals), - columns_to_types={ - "id": exp.DataType.build(index_type), - "created_ts": exp.DataType.build("bigint"), - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "start_ts": exp.DataType.build("bigint"), - "end_ts": exp.DataType.build("bigint"), - "is_dev": exp.DataType.build("boolean"), - "is_removed": exp.DataType.build("boolean"), - "is_compacted": exp.DataType.build("boolean"), - }, - ) - - -def _hash(data): # type: ignore - return str(zlib.crc32(";".join("" if d is None else d for d in data).encode("utf-8"))) diff --git a/sqlmesh/migrations/v0026_remove_dialect_from_seed.py b/sqlmesh/migrations/v0026_remove_dialect_from_seed.py deleted file mode 100644 index 509c87947c..0000000000 --- a/sqlmesh/migrations/v0026_remove_dialect_from_seed.py +++ /dev/null @@ -1,54 +0,0 @@ -"""Remove dialect from 
seeds.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - node = parsed_snapshot["node"] - if "seed" in node: - node["seed"].pop("dialect", None) - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) diff --git a/sqlmesh/migrations/v0027_minute_interval_to_five.py b/sqlmesh/migrations/v0027_minute_interval_to_five.py deleted file mode 100644 index 10d58fbeb1..0000000000 --- a/sqlmesh/migrations/v0027_minute_interval_to_five.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Change any interval unit of minute to five_minute.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = 
f"{schema}.{snapshots_table}" - - new_snapshots = [] - - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - - node = parsed_snapshot["node"] - - if node.get("interval_unit") == "minute": - node["interval_unit"] = "five_minute" - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) diff --git a/sqlmesh/migrations/v0028_add_plan_dags_table.py b/sqlmesh/migrations/v0028_add_plan_dags_table.py deleted file mode 100644 index d8e67f6045..0000000000 --- a/sqlmesh/migrations/v0028_add_plan_dags_table.py +++ /dev/null @@ -1,29 +0,0 @@ -"""Creates the '_plan_dags' table if Airflow is used.""" - -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - plan_dags_table = "_plan_dags" - - if schema: - engine_adapter.create_schema(schema) - plan_dags_table = f"{schema}.{plan_dags_table}" - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.create_state_table( - plan_dags_table, - { - "request_id": exp.DataType.build(index_type), - "dag_id": exp.DataType.build(index_type), - "dag_spec": exp.DataType.build("text"), - }, - 
primary_key=("request_id",), - ) - - engine_adapter.create_index(plan_dags_table, "dag_id_idx", ("dag_id",)) diff --git a/sqlmesh/migrations/v0029_generate_schema_types_using_dialect.py b/sqlmesh/migrations/v0029_generate_schema_types_using_dialect.py deleted file mode 100644 index aab3ec7426..0000000000 --- a/sqlmesh/migrations/v0029_generate_schema_types_using_dialect.py +++ /dev/null @@ -1,68 +0,0 @@ -"""Generate mapping schema data types using the corresponding model's dialect.""" - -import json - -import pandas as pd -from sqlglot import exp, parse_one - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - node = parsed_snapshot["node"] - - mapping_schema = node.get("mapping_schema") - if mapping_schema: - node["mapping_schema"] = _convert_schema_types(mapping_schema, node["dialect"]) - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) - - -def _convert_schema_types(schema, dialect): # type: 
ignore - if not schema: - return schema - - for k, v in schema.items(): - if isinstance(v, dict): - _convert_schema_types(v, dialect) - else: - schema[k] = parse_one(v).sql(dialect=dialect) - - return schema diff --git a/sqlmesh/migrations/v0030_update_unrestorable_snapshots.py b/sqlmesh/migrations/v0030_update_unrestorable_snapshots.py deleted file mode 100644 index 95e6b36704..0000000000 --- a/sqlmesh/migrations/v0030_update_unrestorable_snapshots.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Update unrestorable snapshots.""" - -import json -import typing as t -from collections import defaultdict - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync: t.Any, **kwargs: t.Any) -> None: # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - snapshots_by_version = defaultdict(list) - - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - snapshots_by_version[(name, version)].append((identifier, kind_name, parsed_snapshot)) - - for (name, version), snapshots in snapshots_by_version.items(): - has_forward_only = any(s["change_category"] == 3 for _, _, s in snapshots) - for identifier, kind_name, snapshot in snapshots: - if ( - has_forward_only - and snapshot["change_category"] != 3 - and not snapshot.get("unpaused_ts") - ): - snapshot["unrestorable"] = True - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(snapshot), - "kind_name": kind_name, - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - 
engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) diff --git a/sqlmesh/migrations/v0031_remove_dbt_target_fields.py b/sqlmesh/migrations/v0031_remove_dbt_target_fields.py deleted file mode 100644 index 7a1953707c..0000000000 --- a/sqlmesh/migrations/v0031_remove_dbt_target_fields.py +++ /dev/null @@ -1,64 +0,0 @@ -"""Remove dbt target fields from snapshots outside of limited list of approved fields""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - found_dbt_target = False - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - node = parsed_snapshot["node"] - dbt_target = node.get("jinja_macros", {}).get("global_objs", {}).get("target", {}) - # Double check that `target_name` exists as a field since we know that all dbt targets have `target_name` - # We do this in case someone has a target macro defined that is not related to dbt - if dbt_target and dbt_target.get("target_name"): - found_dbt_target = True - node["jinja_macros"]["global_objs"]["target"] = { - "type": dbt_target.get("type", "None"), - "name": dbt_target.get("name", "None"), - "schema": dbt_target.get("schema", "None"), - "database": dbt_target.get("database", "None"), - "target_name": dbt_target["target_name"], - } - 
- new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - } - ) - - if found_dbt_target: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) diff --git a/sqlmesh/migrations/v0032_add_sqlmesh_version.py b/sqlmesh/migrations/v0032_add_sqlmesh_version.py deleted file mode 100644 index 0b17f4a0f4..0000000000 --- a/sqlmesh/migrations/v0032_add_sqlmesh_version.py +++ /dev/null @@ -1,22 +0,0 @@ -"""Add new 'sqlmesh_version' column to the version state table.""" - -from sqlglot import exp - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - versions_table = "_versions" - if state_sync.schema: - versions_table = f"{state_sync.schema}.{versions_table}" - - alter_table_exp = exp.AlterTable( - this=exp.to_table(versions_table), - actions=[ - exp.ColumnDef( - this=exp.to_column("sqlmesh_version"), - kind=exp.DataType.build("text"), - ) - ], - ) - - engine_adapter.execute(alter_table_exp) diff --git a/sqlmesh/migrations/v0033_mysql_fix_blob_text_type.py b/sqlmesh/migrations/v0033_mysql_fix_blob_text_type.py deleted file mode 100644 index 5660240d4b..0000000000 --- a/sqlmesh/migrations/v0033_mysql_fix_blob_text_type.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Use LONGTEXT type for blob fields in MySQL.""" - -from sqlglot import exp - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - if engine_adapter.dialect != "mysql": - return - - schema = state_sync.schema - environments_table = 
"_environments" - snapshots_table = "_snapshots" - seeds_table = "_seeds" - plan_dags_table = "_plan_dags" - - if schema: - environments_table = f"{schema}.{environments_table}" - snapshots_table = f"{schema}.{snapshots_table}" - seeds_table = f"{state_sync.schema}.{seeds_table}" - plan_dags_table = f"{schema}.{plan_dags_table}" - - targets = [ - (environments_table, "snapshots"), - (snapshots_table, "snapshot"), - (seeds_table, "content"), - (plan_dags_table, "dag_spec"), - ] - - for table_name, column_name in targets: - alter_table_exp = exp.AlterTable( - this=exp.to_table(table_name), - actions=[ - exp.AlterColumn( - this=exp.to_column(column_name), - dtype=exp.DataType.build("longtext"), - ) - ], - ) - - engine_adapter.execute(alter_table_exp) diff --git a/sqlmesh/migrations/v0034_add_default_catalog.py b/sqlmesh/migrations/v0034_add_default_catalog.py deleted file mode 100644 index 6d13ae96bf..0000000000 --- a/sqlmesh/migrations/v0034_add_default_catalog.py +++ /dev/null @@ -1,364 +0,0 @@ -"""Add default catalog to snapshots and update names to match new normalization rules.""" - -from __future__ import annotations - -import json -import typing as t - -import pandas as pd -from sqlglot import exp -from sqlglot.dialects.dialect import DialectType -from sqlglot.helper import dict_depth, seq_get -from sqlglot.optimizer.normalize_identifiers import normalize_identifiers - -from sqlmesh.utils.migration import index_text_type - - -def set_default_catalog( - table: exp.Table, - default_catalog: t.Optional[str], -) -> exp.Table: - if default_catalog and not table.catalog and table.db: - table.set("catalog", exp.parse_identifier(default_catalog)) - - return table - - -def normalize_model_name( - table: str | exp.Table, - default_catalog: t.Optional[str], - dialect: DialectType = None, -) -> str: - table = exp.to_table(table, dialect=dialect) - - table = set_default_catalog(table, default_catalog) - return exp.table_name(normalize_identifiers(table, dialect=dialect), 
identify=True) - - -def normalize_mapping_schema(mapping_schema: t.Dict, dialect: str) -> t.Dict: - # Example input: {'"catalog"': {'schema': {'table': {'column': 'INT'}}}} - # Example output: {'"catalog"': {'"schema"': {'"table"': {'column': 'INT'}}}} - normalized_mapping_schema = {} - for key, value in mapping_schema.items(): - if isinstance(value, dict): - normalized_mapping_schema[normalize_model_name(key, None, dialect)] = ( - normalize_mapping_schema(value, dialect) - ) - else: - normalized_mapping_schema[key] = value - return normalized_mapping_schema - - -def update_dbt_relations( - source: t.Optional[t.Dict], keys: t.List[str], default_catalog: t.Optional[str] -) -> None: - if not default_catalog or not source: - return - for key in keys: - relations = source.get(key) - if relations: - relations = [relations] if "database" in relations else relations.values() - for relation in relations: - if not relation["database"]: - relation["database"] = default_catalog - - -def migrate(state_sync, default_catalog: t.Optional[str], **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - environments_table = "_environments" - intervals_table = "_intervals" - seeds_table = "_seeds" - - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - environments_table = f"{schema}.{environments_table}" - intervals_table = f"{schema}.{intervals_table}" - seeds_table = f"{schema}.{seeds_table}" - - new_snapshots = [] - snapshot_to_dialect = {} - index_type = index_text_type(engine_adapter.dialect) - - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - # This is here in the case where the user originally had catalog in this model name, and therefore - # we would have before created the table with the 
catalog in the name. New logic removes the catalog, - # and therefore we need to make sure the table name is the same as the original table name, so we include - # this override - parsed_snapshot["base_table_name_override"] = parsed_snapshot["name"] - node = parsed_snapshot["node"] - dialect = node.get("dialect") - normalized_name = ( - normalize_model_name(name, default_catalog=default_catalog, dialect=dialect) - if node["source_type"] != "audit" - else name - ) - parsed_snapshot["name"] = normalized_name - # At the time of migration all nodes had default catalog, so we don't have to check type - node["default_catalog"] = default_catalog - snapshot_to_dialect[name] = dialect - mapping_schema = node.get("mapping_schema", {}) - if mapping_schema: - normalized_default_catalog = ( - normalize_model_name(default_catalog, default_catalog=None, dialect=dialect) - if default_catalog - else None - ) - mapping_schema_depth = dict_depth(mapping_schema) - if mapping_schema_depth == 3 and normalized_default_catalog: - mapping_schema = {normalized_default_catalog: mapping_schema} - node["mapping_schema"] = normalize_mapping_schema(mapping_schema, dialect) - depends_on = node.get("depends_on", []) - if depends_on: - node["depends_on"] = [ - normalize_model_name(dep, default_catalog, dialect) for dep in depends_on - ] - if parsed_snapshot["parents"]: - parsed_snapshot["parents"] = [ - { - "name": normalize_model_name(parent["name"], default_catalog, dialect), - "identifier": parent["identifier"], - } - for parent in parsed_snapshot["parents"] - ] - if parsed_snapshot["indirect_versions"]: - parsed_snapshot["indirect_versions"] = { - normalize_model_name(name, default_catalog, dialect): snapshot_data_versions - for name, snapshot_data_versions in parsed_snapshot["indirect_versions"].items() - } - # dbt specific migration - jinja_macros = node.get("jinja_macros") - if ( - default_catalog - and jinja_macros - and jinja_macros.get("create_builtins_module") == "sqlmesh.dbt" - ): - 
update_dbt_relations( - jinja_macros.get("global_objs"), ["refs", "sources", "this"], default_catalog - ) - - new_snapshots.append( - { - "name": normalized_name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) - - new_environments = [] - default_dialect = seq_get(list(snapshot_to_dialect.values()), 0) - for ( - name, - snapshots, - start_at, - end_at, - plan_id, - previous_plan_id, - expiration_ts, - finalized_ts, - promoted_snapshot_ids, - suffix_target, - ) in engine_adapter.fetchall( - exp.select( - "name", - "snapshots", - "start_at", - "end_at", - "plan_id", - "previous_plan_id", - "expiration_ts", - "finalized_ts", - "promoted_snapshot_ids", - "suffix_target", - ).from_(environments_table), - quote_identifiers=True, - ): - new_snapshots = [] - for snapshot in json.loads(snapshots): - snapshot_name = snapshot["name"] - snapshot["base_table_name_override"] = snapshot_name - dialect = snapshot_to_dialect.get(snapshot_name, default_dialect) - node_type = snapshot.get("node_type") - normalized_name = ( - normalize_model_name(snapshot_name, default_catalog, dialect) - if node_type is None or node_type == "model" - else snapshot_name - ) - snapshot["name"] = normalized_name - if snapshot["parents"]: - snapshot["parents"] = [ - { - "name": normalize_model_name(parent["name"], default_catalog, dialect), - "identifier": parent["identifier"], - } - for parent in snapshot["parents"] - ] - new_snapshots.append(snapshot) - - new_environments.append( - { - "name": name, - "snapshots": 
json.dumps(new_snapshots), - "start_at": start_at, - "end_at": end_at, - "plan_id": plan_id, - "previous_plan_id": previous_plan_id, - "expiration_ts": expiration_ts, - "finalized_ts": finalized_ts, - "promoted_snapshot_ids": promoted_snapshot_ids, - "suffix_target": suffix_target, - } - ) - - if new_environments: - engine_adapter.delete_from(environments_table, "TRUE") - - engine_adapter.insert_append( - environments_table, - pd.DataFrame(new_environments), - columns_to_types={ - "name": exp.DataType.build(index_type), - "snapshots": exp.DataType.build("text"), - "start_at": exp.DataType.build("text"), - "end_at": exp.DataType.build("text"), - "plan_id": exp.DataType.build("text"), - "previous_plan_id": exp.DataType.build("text"), - "expiration_ts": exp.DataType.build("bigint"), - "finalized_ts": exp.DataType.build("bigint"), - "promoted_snapshot_ids": exp.DataType.build("text"), - "suffix_target": exp.DataType.build("text"), - }, - ) - - # We update environment to not be finalized in order to force them to update their views - # in order to make sure the views now have the fully qualified names - # We only do this if a default catalog was applied otherwise the current views are fine - # We do this post creating the new environments in order to avoid having to find a way to - # expression a null timestamp value in pandas that works across all engines - if default_catalog: - engine_adapter.execute( - exp.update(environments_table, {"finalized_ts": None}, where="1=1"), - quote_identifiers=True, - ) - - new_intervals = [] - for ( - id, - created_ts, - name, - identifier, - version, - start_ts, - end_ts, - is_dev, - is_removed, - is_compacted, - ) in engine_adapter.fetchall( - exp.select( - "id", - "created_ts", - "name", - "identifier", - "version", - "start_ts", - "end_ts", - "is_dev", - "is_removed", - "is_compacted", - ).from_(intervals_table), - quote_identifiers=True, - ): - dialect = snapshot_to_dialect.get(name, default_dialect) - normalized_name = 
normalize_model_name(name, default_catalog, dialect) - new_intervals.append( - { - "id": id, - "created_ts": created_ts, - "name": normalized_name, - "identifier": identifier, - "version": version, - "start_ts": start_ts, - "end_ts": end_ts, - "is_dev": is_dev, - "is_removed": is_removed, - "is_compacted": is_compacted, - } - ) - - if new_intervals: - engine_adapter.delete_from(intervals_table, "TRUE") - - engine_adapter.insert_append( - intervals_table, - pd.DataFrame(new_intervals), - columns_to_types={ - "id": exp.DataType.build(index_type), - "created_ts": exp.DataType.build("bigint"), - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "start_ts": exp.DataType.build("bigint"), - "end_ts": exp.DataType.build("bigint"), - "is_dev": exp.DataType.build("boolean"), - "is_removed": exp.DataType.build("boolean"), - "is_compacted": exp.DataType.build("boolean"), - }, - ) - - new_seeds = [] - for ( - name, - identifier, - content, - ) in engine_adapter.fetchall( - exp.select( - "name", - "identifier", - "content", - ).from_(seeds_table), - quote_identifiers=True, - ): - dialect = snapshot_to_dialect.get(name, default_dialect) - normalized_name = normalize_model_name(name, default_catalog, dialect) - new_seeds.append( - { - "name": normalized_name, - "identifier": identifier, - "content": content, - } - ) - - if new_seeds: - engine_adapter.delete_from(seeds_table, "TRUE") - - engine_adapter.insert_append( - seeds_table, - pd.DataFrame(new_seeds), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "content": exp.DataType.build("text"), - }, - ) diff --git a/sqlmesh/migrations/v0035_add_catalog_name_override.py b/sqlmesh/migrations/v0035_add_catalog_name_override.py deleted file mode 100644 index 34cfc8f34e..0000000000 --- a/sqlmesh/migrations/v0035_add_catalog_name_override.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Add support for 
environment catalog name override.""" - -from sqlglot import exp - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - environments_table = "_environments" - if state_sync.schema: - environments_table = f"{state_sync.schema}.{environments_table}" - - alter_table_exp = exp.AlterTable( - this=exp.to_table(environments_table), - actions=[ - exp.ColumnDef( - this=exp.to_column("catalog_name_override"), - kind=exp.DataType.build("text"), - ) - ], - ) - engine_adapter.execute(alter_table_exp) diff --git a/sqlmesh/migrations/v0036_delete_plan_dags_bug_fix.py b/sqlmesh/migrations/v0036_delete_plan_dags_bug_fix.py deleted file mode 100644 index 7f9f49d61d..0000000000 --- a/sqlmesh/migrations/v0036_delete_plan_dags_bug_fix.py +++ /dev/null @@ -1,14 +0,0 @@ -"""Add missing delete from migration #34.""" - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - plan_dags_table = "_plan_dags" - if state_sync.schema: - plan_dags_table = f"{schema}.{plan_dags_table}" - - # At the time of migration plan_dags table is only needed for in-flight DAGs and therefore we can safely - # just delete it instead of migrating it - # If reusing this code verify that this is still the case - engine_adapter.delete_from(plan_dags_table, "TRUE") diff --git a/sqlmesh/migrations/v0037_remove_dbt_is_incremental_macro.py b/sqlmesh/migrations/v0037_remove_dbt_is_incremental_macro.py deleted file mode 100644 index 0302f8b575..0000000000 --- a/sqlmesh/migrations/v0037_remove_dbt_is_incremental_macro.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Remove dbt is_incremental macro""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = 
f"{schema}.{snapshots_table}" - - new_snapshots = [] - found_dbt_package = False - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - node = parsed_snapshot["node"] - dbt_package = node.get("jinja_macros", {}).get("packages", {}).get("dbt", {}) - - if dbt_package: - found_dbt_package = True - dbt_package.pop("is_incremental", None) - dbt_package.pop("should_full_refresh", None) - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - } - ) - - if found_dbt_package: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - }, - ) diff --git a/sqlmesh/migrations/v0038_add_expiration_ts_to_snapshot.py b/sqlmesh/migrations/v0038_add_expiration_ts_to_snapshot.py deleted file mode 100644 index c96efab6be..0000000000 --- a/sqlmesh/migrations/v0038_add_expiration_ts_to_snapshot.py +++ /dev/null @@ -1,69 +0,0 @@ -"""Add the expiration_ts column to the snapshots table.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.date import to_datetime, to_timestamp -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - index_type = 
index_text_type(engine_adapter.dialect) - - alter_table_exp = exp.AlterTable( - this=exp.to_table(snapshots_table), - actions=[ - exp.ColumnDef( - this=exp.to_column("expiration_ts"), - kind=exp.DataType.build("bigint"), - ) - ], - ) - engine_adapter.execute(alter_table_exp) - - new_snapshots = [] - - for name, identifier, version, snapshot, kind_name in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name").from_(snapshots_table), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - - updated_ts = parsed_snapshot["updated_ts"] - ttl = parsed_snapshot["ttl"] - expiration_ts = to_timestamp(ttl, relative_base=to_datetime(updated_ts)) - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": snapshot, - "kind_name": kind_name, - "expiration_ts": expiration_ts, - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - "expiration_ts": exp.DataType.build("bigint"), - }, - ) diff --git a/sqlmesh/migrations/v0039_include_environment_in_plan_dag_spec.py b/sqlmesh/migrations/v0039_include_environment_in_plan_dag_spec.py deleted file mode 100644 index 597280fee8..0000000000 --- a/sqlmesh/migrations/v0039_include_environment_in_plan_dag_spec.py +++ /dev/null @@ -1,65 +0,0 @@ -"""Include environment in plan dag spec.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - plan_dags_table = "_plan_dags" - if 
state_sync.schema: - plan_dags_table = f"{schema}.{plan_dags_table}" - - new_specs = [] - - for request_id, dag_id, dag_spec in engine_adapter.fetchall( - exp.select("request_id", "dag_id", "dag_spec").from_(plan_dags_table), - quote_identifiers=True, - ): - parsed_dag_spec = json.loads(dag_spec) - - environment_naming_info = parsed_dag_spec.pop("environment_naming_info") - promoted_snapshots = parsed_dag_spec.pop("promoted_snapshots", []) - start = parsed_dag_spec.pop("start") - parsed_dag_spec.pop("end", None) - plan_id = parsed_dag_spec.pop("plan_id") - previous_plan_id = parsed_dag_spec.pop("previous_plan_id", None) - expiration_ts = parsed_dag_spec.pop("environment_expiration_ts", None) - - parsed_dag_spec["environment"] = { - **environment_naming_info, - "snapshots": promoted_snapshots, - "start_at": start, - "end_at": start, - "plan_id": plan_id, - "previous_plan_id": previous_plan_id, - "expiration_ts": expiration_ts, - } - - new_specs.append( - { - "request_id": request_id, - "dag_id": dag_id, - "dag_spec": json.dumps(parsed_dag_spec), - } - ) - - if new_specs: - engine_adapter.delete_from(plan_dags_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - plan_dags_table, - pd.DataFrame(new_specs), - columns_to_types={ - "request_id": exp.DataType.build(index_type), - "dag_id": exp.DataType.build(index_type), - "dag_spec": exp.DataType.build("text"), - }, - ) diff --git a/sqlmesh/migrations/v0040_add_previous_finalized_snapshots.py b/sqlmesh/migrations/v0040_add_previous_finalized_snapshots.py deleted file mode 100644 index 0ac1417535..0000000000 --- a/sqlmesh/migrations/v0040_add_previous_finalized_snapshots.py +++ /dev/null @@ -1,21 +0,0 @@ -"""Add support for environment previous finalized snapshots.""" - -from sqlglot import exp - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - environments_table = "_environments" - if state_sync.schema: - 
environments_table = f"{state_sync.schema}.{environments_table}" - - alter_table_exp = exp.AlterTable( - this=exp.to_table(environments_table), - actions=[ - exp.ColumnDef( - this=exp.to_column("previous_finalized_snapshots"), - kind=exp.DataType.build("text"), - ) - ], - ) - engine_adapter.execute(alter_table_exp) diff --git a/sqlmesh/migrations/v0041_remove_hash_raw_query_attribute.py b/sqlmesh/migrations/v0041_remove_hash_raw_query_attribute.py deleted file mode 100644 index 953060554c..0000000000 --- a/sqlmesh/migrations/v0041_remove_hash_raw_query_attribute.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Remove hash_raw_query from existing snapshots.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - - for name, identifier, version, snapshot, kind_name, expiration_ts in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name", "expiration_ts").from_( - snapshots_table - ), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - parsed_snapshot["node"].pop("hash_raw_query", None) - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - "expiration_ts": expiration_ts, - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - 
"kind_name": exp.DataType.build(index_type), - "expiration_ts": exp.DataType.build("bigint"), - }, - ) diff --git a/sqlmesh/migrations/v0042_trim_indirect_versions.py b/sqlmesh/migrations/v0042_trim_indirect_versions.py deleted file mode 100644 index 9233036ffd..0000000000 --- a/sqlmesh/migrations/v0042_trim_indirect_versions.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Trim irrelevant attributes from indirect versions.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - - for name, identifier, version, snapshot, kind_name, expiration_ts in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name", "expiration_ts").from_( - snapshots_table - ), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - for indirect_versions in parsed_snapshot["indirect_versions"].values(): - for indirect_version in indirect_versions: - # Only keep version and change_category. 
- version = indirect_version.get("version") - change_category = indirect_version.get("change_category") - indirect_version.clear() - indirect_version["version"] = version - indirect_version["change_category"] = change_category - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - "expiration_ts": expiration_ts, - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - "expiration_ts": exp.DataType.build("bigint"), - }, - ) diff --git a/sqlmesh/migrations/v0043_fix_remove_obsolete_attributes_in_plan_dags.py b/sqlmesh/migrations/v0043_fix_remove_obsolete_attributes_in_plan_dags.py deleted file mode 100644 index 73aa5b24ac..0000000000 --- a/sqlmesh/migrations/v0043_fix_remove_obsolete_attributes_in_plan_dags.py +++ /dev/null @@ -1,58 +0,0 @@ -"""Trim irrelevant attributes from the plan DAGs state.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - plan_dags_table = "_plan_dags" - if schema: - plan_dags_table = f"{schema}.{plan_dags_table}" - - new_dag_specs = [] - - for request_id, dag_id, dag_spec in engine_adapter.fetchall( - exp.select("request_id", "dag_id", "dag_spec").from_(plan_dags_table), - quote_identifiers=True, - ): - parsed_dag_spec = json.loads(dag_spec) - for snapshot in parsed_dag_spec.get("new_snapshots", []): - 
snapshot["node"].pop("hash_raw_query", None) - - for indirect_versions in snapshot.get("indirect_versions", {}).values(): - for indirect_version in indirect_versions: - # Only keep version and change_category. - version = indirect_version.get("version") - change_category = indirect_version.get("change_category") - indirect_version.clear() - indirect_version["version"] = version - indirect_version["change_category"] = change_category - - new_dag_specs.append( - { - "request_id": request_id, - "dag_id": dag_id, - "dag_spec": json.dumps(parsed_dag_spec), - } - ) - - if new_dag_specs: - engine_adapter.delete_from(plan_dags_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - plan_dags_table, - pd.DataFrame(new_dag_specs), - columns_to_types={ - "request_id": exp.DataType.build(index_type), - "dag_id": exp.DataType.build(index_type), - "dag_spec": exp.DataType.build("text"), - }, - ) diff --git a/sqlmesh/migrations/v0044_quote_identifiers_in_model_attributes.py b/sqlmesh/migrations/v0044_quote_identifiers_in_model_attributes.py deleted file mode 100644 index 82eae3db3b..0000000000 --- a/sqlmesh/migrations/v0044_quote_identifiers_in_model_attributes.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Quoted identifiers in model SQL attributes.""" - - -def migrate(state_sync, **kwargs): # type: ignore - pass diff --git a/sqlmesh/migrations/v0045_move_gateway_variable.py b/sqlmesh/migrations/v0045_move_gateway_variable.py deleted file mode 100644 index f14151a6da..0000000000 --- a/sqlmesh/migrations/v0045_move_gateway_variable.py +++ /dev/null @@ -1,67 +0,0 @@ -"""Move the gateway variable.""" - -import ast -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = 
f"{schema}.{snapshots_table}" - - migration_needed = False - new_snapshots = [] - - for name, identifier, version, snapshot, kind_name, expiration_ts in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name", "expiration_ts").from_( - snapshots_table - ), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - python_env = parsed_snapshot["node"].get("python_env") - if python_env: - gateway = python_env.pop("gateway", None) - if gateway is not None: - migration_needed = True - sqlmesh_vars = {"gateway": ast.literal_eval(gateway["payload"])} - python_env["__sqlmesh__vars__"] = { - "payload": repr(sqlmesh_vars), - "kind": "value", - } - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - "expiration_ts": expiration_ts, - } - ) - - if migration_needed and new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - "expiration_ts": exp.DataType.build("bigint"), - }, - ) diff --git a/sqlmesh/migrations/v0046_add_batch_concurrency.py b/sqlmesh/migrations/v0046_add_batch_concurrency.py deleted file mode 100644 index a76dc358b5..0000000000 --- a/sqlmesh/migrations/v0046_add_batch_concurrency.py +++ /dev/null @@ -1,8 +0,0 @@ -"""Add the batch_concurrency attribute to the incremental model kinds. - -This results in a change to the metadata hash. 
-""" - - -def migrate(state_sync, **kwargs): # type: ignore - pass diff --git a/sqlmesh/migrations/v0047_change_scd_string_to_column.py b/sqlmesh/migrations/v0047_change_scd_string_to_column.py deleted file mode 100644 index 72ebbf2654..0000000000 --- a/sqlmesh/migrations/v0047_change_scd_string_to_column.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Changes the SCD Type 2 columns from strings to columns.""" - - -def migrate(state_sync, **kwargs): # type: ignore - pass diff --git a/sqlmesh/migrations/v0048_drop_indirect_versions.py b/sqlmesh/migrations/v0048_drop_indirect_versions.py deleted file mode 100644 index dde1f4ebb1..0000000000 --- a/sqlmesh/migrations/v0048_drop_indirect_versions.py +++ /dev/null @@ -1,56 +0,0 @@ -"""Drop the indirect_versions attribute in snapshots.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - - for name, identifier, version, snapshot, kind_name, expiration_ts in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name", "expiration_ts").from_( - snapshots_table - ), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - parsed_snapshot.pop("indirect_versions", None) - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - "expiration_ts": expiration_ts, - } - ) - - if new_snapshots: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": 
exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - "expiration_ts": exp.DataType.build("bigint"), - }, - ) diff --git a/sqlmesh/migrations/v0049_replace_identifier_with_version_in_seeds_table.py b/sqlmesh/migrations/v0049_replace_identifier_with_version_in_seeds_table.py deleted file mode 100644 index 186b5f7856..0000000000 --- a/sqlmesh/migrations/v0049_replace_identifier_with_version_in_seeds_table.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Use version instead of identifier in the seeds table.""" - -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - - snapshots_table = "_snapshots" - seeds_table = "_seeds" - new_seeds_table = f"{seeds_table}_v49" - - if state_sync.schema: - snapshots_table = f"{state_sync.schema}.{snapshots_table}" - seeds_table = f"{state_sync.schema}.{seeds_table}" - new_seeds_table = f"{state_sync.schema}.{new_seeds_table}" - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.drop_table(new_seeds_table) - engine_adapter.create_state_table( - new_seeds_table, - { - "name": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "content": exp.DataType.build("text"), - }, - primary_key=("name", "version"), - ) - - name_col = exp.column("name", table="seeds") - version_col = exp.column("version", table="snapshots") - query = ( - exp.select( - name_col, - version_col, - exp.func("MAX", exp.column("content", table="seeds")).as_("content"), - ) - .from_(exp.to_table(seeds_table).as_("seeds")) - .join( - exp.to_table(snapshots_table).as_("snapshots"), - on=exp.and_( - exp.column("name", table="seeds").eq(exp.column("name", table="snapshots")), - exp.column("identifier", table="seeds").eq( - exp.column("identifier", table="snapshots") - ), - ), - ) - 
.where(exp.column("version", table="snapshots").is_(exp.null()).not_()) - .group_by(name_col, version_col) - ) - - engine_adapter.insert_append(new_seeds_table, query) - engine_adapter.drop_table(seeds_table) - engine_adapter.rename_table(new_seeds_table, seeds_table) diff --git a/sqlmesh/migrations/v0050_drop_seeds_table.py b/sqlmesh/migrations/v0050_drop_seeds_table.py deleted file mode 100644 index 706fae63ed..0000000000 --- a/sqlmesh/migrations/v0050_drop_seeds_table.py +++ /dev/null @@ -1,11 +0,0 @@ -"""Drop the seeds table.""" - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - - seeds_table = "_seeds" - if state_sync.schema: - seeds_table = f"{state_sync.schema}.{seeds_table}" - - engine_adapter.drop_table(seeds_table) diff --git a/sqlmesh/migrations/v0051_rename_column_descriptions.py b/sqlmesh/migrations/v0051_rename_column_descriptions.py deleted file mode 100644 index c624240bdc..0000000000 --- a/sqlmesh/migrations/v0051_rename_column_descriptions.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Rename the node attribute `column_descriptions_` to `column_descriptions` in snapshots.""" - -import json - -import pandas as pd -from sqlglot import exp - -from sqlmesh.utils.migration import index_text_type - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - schema = state_sync.schema - snapshots_table = "_snapshots" - if schema: - snapshots_table = f"{schema}.{snapshots_table}" - - new_snapshots = [] - found_col_descriptions = False - - for name, identifier, version, snapshot, kind_name, expiration_ts in engine_adapter.fetchall( - exp.select("name", "identifier", "version", "snapshot", "kind_name", "expiration_ts").from_( - snapshots_table - ), - quote_identifiers=True, - ): - parsed_snapshot = json.loads(snapshot) - - if "column_descriptions_" in parsed_snapshot["node"]: - found_col_descriptions = True - parsed_snapshot["node"]["column_descriptions"] = 
parsed_snapshot["node"].pop( - "column_descriptions_" - ) - - new_snapshots.append( - { - "name": name, - "identifier": identifier, - "version": version, - "snapshot": json.dumps(parsed_snapshot), - "kind_name": kind_name, - "expiration_ts": expiration_ts, - } - ) - - if found_col_descriptions: - engine_adapter.delete_from(snapshots_table, "TRUE") - - index_type = index_text_type(engine_adapter.dialect) - - engine_adapter.insert_append( - snapshots_table, - pd.DataFrame(new_snapshots), - columns_to_types={ - "name": exp.DataType.build(index_type), - "identifier": exp.DataType.build(index_type), - "version": exp.DataType.build(index_type), - "snapshot": exp.DataType.build("text"), - "kind_name": exp.DataType.build(index_type), - "expiration_ts": exp.DataType.build("bigint"), - }, - ) diff --git a/sqlmesh/migrations/v0052_add_normalize_name_in_environment_naming_info.py b/sqlmesh/migrations/v0052_add_normalize_name_in_environment_naming_info.py deleted file mode 100644 index 617c298544..0000000000 --- a/sqlmesh/migrations/v0052_add_normalize_name_in_environment_naming_info.py +++ /dev/null @@ -1,27 +0,0 @@ -"""Add flag that controls whether environment names will be normalized.""" - -from sqlglot import exp - - -def migrate(state_sync, **kwargs): # type: ignore - engine_adapter = state_sync.engine_adapter - environments_table = "_environments" - if state_sync.schema: - environments_table = f"{state_sync.schema}.{environments_table}" - - alter_table_exp = exp.AlterTable( - this=exp.to_table(environments_table), - actions=[ - exp.ColumnDef( - this=exp.to_column("normalize_name"), - kind=exp.DataType.build("boolean"), - ) - ], - ) - engine_adapter.execute(alter_table_exp) - - state_sync.engine_adapter.update_table( - environments_table, - {"normalize_name": False}, - where=exp.true(), - ) diff --git a/sqlmesh/migrations/v0053_custom_model_kind_extra_attributes.py b/sqlmesh/migrations/v0053_custom_model_kind_extra_attributes.py deleted file mode 100644 index 
bc242964a5..0000000000 --- a/sqlmesh/migrations/v0053_custom_model_kind_extra_attributes.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Add batch_size, batch_concurrency, and batch_interval to the CUSTOM model kind.""" - - -def migrate(state_sync, **kwargs): # type: ignore - pass diff --git a/sqlmesh/migrations/v0054_fix_trailing_comments.py b/sqlmesh/migrations/v0054_fix_trailing_comments.py deleted file mode 100644 index 0084626e3d..0000000000 --- a/sqlmesh/migrations/v0054_fix_trailing_comments.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Fix support for trailing comments in SQL model definitions.""" - - -def migrate(state_sync, **kwargs): # type: ignore - pass diff --git a/sqlmesh/migrations/v0061_mysql_fix_blob_text_type.py b/sqlmesh/migrations/v0061_mysql_fix_blob_text_type.py new file mode 100644 index 0000000000..897974f09a --- /dev/null +++ b/sqlmesh/migrations/v0061_mysql_fix_blob_text_type.py @@ -0,0 +1,47 @@ +"""Duplicate of v0033, Use LONGTEXT type for blob fields in MySQL. + +Seeds table has since been dropped. +Environments table now has a requirements column. 
+""" + +from sqlglot import exp + +from sqlmesh.utils.migration import blob_text_type + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + if engine_adapter.dialect != "mysql": + return + environments_table = "_environments" + snapshots_table = "_snapshots" + + if schema: + environments_table = f"{schema}.{environments_table}" + snapshots_table = f"{schema}.{snapshots_table}" + + targets = [ + (environments_table, "snapshots"), + (environments_table, "promoted_snapshot_ids"), + (environments_table, "previous_finalized_snapshots"), + (environments_table, "requirements"), + (snapshots_table, "snapshot"), + ] + + for table_name, column_name in targets: + blob_type = blob_text_type(engine_adapter.dialect) + alter_table_exp = exp.Alter( + this=exp.to_table(table_name), + kind="TABLE", + actions=[ + exp.AlterColumn( + this=exp.to_column(column_name), + dtype=exp.DataType.build(blob_type), + ) + ], + ) + + engine_adapter.execute(alter_table_exp) + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0062_add_model_gateway.py b/sqlmesh/migrations/v0062_add_model_gateway.py new file mode 100644 index 0000000000..f65d8224ec --- /dev/null +++ b/sqlmesh/migrations/v0062_add_model_gateway.py @@ -0,0 +1,9 @@ +"""Add the gateway model attribute.""" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0063_change_signals.py b/sqlmesh/migrations/v0063_change_signals.py new file mode 100644 index 0000000000..bbced547fd --- /dev/null +++ b/sqlmesh/migrations/v0063_change_signals.py @@ -0,0 +1,100 @@ +"""Change serialization of signals to allow for function calls.""" + +import json + +from sqlglot import exp, parse_one + +from sqlmesh.utils.migration import index_text_type, blob_text_type + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + 
pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + import pandas as pd + + snapshots_table = "_snapshots" + index_type = index_text_type(engine_adapter.dialect) + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + new_snapshots = [] + + signal_change = False + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + node = parsed_snapshot["node"] + signals = node.get("signals") + + if signals: + signal_change = True + node["signals"] = [] + + for signal in signals: + node["signals"].append( + ( + "", + { + eq.left.name: eq.right.sql() + for eq in parse_one(signal, into=exp.Tuple).expressions + }, + ) + ) + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + } + ) + + if signal_change and new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + blob_type = blob_text_type(engine_adapter.dialect) + + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + target_columns_to_types={ + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + }, + ) diff --git 
a/sqlmesh/migrations/v0064_join_when_matched_strings.py b/sqlmesh/migrations/v0064_join_when_matched_strings.py new file mode 100644 index 0000000000..ffd4c94913 --- /dev/null +++ b/sqlmesh/migrations/v0064_join_when_matched_strings.py @@ -0,0 +1,87 @@ +"""Join list of `WHEN [NOT] MATCHED` strings into a single string.""" + +import json + +from sqlglot import exp + +from sqlmesh.utils.migration import index_text_type, blob_text_type + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + import pandas as pd + + snapshots_table = "_snapshots" + index_type = index_text_type(engine_adapter.dialect) + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + new_snapshots = [] + + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + node = parsed_snapshot["node"] + + kind = node.get("kind") + if kind and isinstance(when_matched := kind.get("when_matched"), list): + kind["when_matched"] = " ".join(when_matched) + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + } + ) + + if new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + blob_type = blob_text_type(engine_adapter.dialect) + + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + target_columns_to_types={ + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": 
exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + }, + ) diff --git a/sqlmesh/migrations/v0065_add_model_optimize.py b/sqlmesh/migrations/v0065_add_model_optimize.py new file mode 100644 index 0000000000..e9bc646666 --- /dev/null +++ b/sqlmesh/migrations/v0065_add_model_optimize.py @@ -0,0 +1,9 @@ +"""Add the optimize_query model attribute.""" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0066_add_auto_restatements.py b/sqlmesh/migrations/v0066_add_auto_restatements.py new file mode 100644 index 0000000000..9eea773573 --- /dev/null +++ b/sqlmesh/migrations/v0066_add_auto_restatements.py @@ -0,0 +1,51 @@ +"""Add the auto restatements table.""" + +from sqlglot import exp + +from sqlmesh.utils.migration import index_text_type + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + auto_restatements_table = "_auto_restatements" + intervals_table = "_intervals" + + if schema: + auto_restatements_table = f"{schema}.{auto_restatements_table}" + intervals_table = f"{schema}.{intervals_table}" + + index_type = index_text_type(engine_adapter.dialect) + + engine_adapter.create_state_table( + auto_restatements_table, + { + "snapshot_name": exp.DataType.build(index_type), + "snapshot_version": exp.DataType.build(index_type), + "next_auto_restatement_ts": exp.DataType.build("bigint"), + }, + primary_key=("snapshot_name", "snapshot_version"), + ) + + alter_table_exp = exp.Alter( + this=exp.to_table(intervals_table), + kind="TABLE", + actions=[ + exp.ColumnDef( + this=exp.to_column("is_pending_restatement"), + kind=exp.DataType.build("boolean"), + ) + ], + ) + 
engine_adapter.execute(alter_table_exp) + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + intervals_table = "_intervals" + + if schema: + intervals_table = f"{schema}.{intervals_table}" + + engine_adapter.update_table( + intervals_table, + {"is_pending_restatement": False}, + where=exp.true(), + ) diff --git a/sqlmesh/migrations/v0067_add_tsql_date_full_precision.py b/sqlmesh/migrations/v0067_add_tsql_date_full_precision.py new file mode 100644 index 0000000000..1243118df0 --- /dev/null +++ b/sqlmesh/migrations/v0067_add_tsql_date_full_precision.py @@ -0,0 +1,9 @@ +"""Add full precision for tsql to support nanoseconds.""" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0068_include_unrendered_query_in_metadata_hash.py b/sqlmesh/migrations/v0068_include_unrendered_query_in_metadata_hash.py new file mode 100644 index 0000000000..35142e9aeb --- /dev/null +++ b/sqlmesh/migrations/v0068_include_unrendered_query_in_metadata_hash.py @@ -0,0 +1,9 @@ +"""Include the unrendered query in the metadata hash.""" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0069_update_dev_table_suffix.py b/sqlmesh/migrations/v0069_update_dev_table_suffix.py new file mode 100644 index 0000000000..f69aac434e --- /dev/null +++ b/sqlmesh/migrations/v0069_update_dev_table_suffix.py @@ -0,0 +1,168 @@ +"""Update the dev table suffix to be 'dev'. 
Rename temp_version to dev_version.""" + +import json + +from sqlglot import exp + +from sqlmesh.utils.migration import index_text_type, blob_text_type + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + import pandas as pd + + snapshots_table = "_snapshots" + environments_table = "_environments" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + environments_table = f"{schema}.{environments_table}" + + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + snapshots_columns_to_types = { + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + } + environments_columns_to_types = { + "name": exp.DataType.build(index_type), + "snapshots": exp.DataType.build(blob_type), + "start_at": exp.DataType.build("text"), + "end_at": exp.DataType.build("text"), + "plan_id": exp.DataType.build("text"), + "previous_plan_id": exp.DataType.build("text"), + "expiration_ts": exp.DataType.build("bigint"), + "finalized_ts": exp.DataType.build("bigint"), + "promoted_snapshot_ids": exp.DataType.build(blob_type), + "suffix_target": exp.DataType.build("text"), + "catalog_name_override": exp.DataType.build("text"), + "previous_finalized_snapshots": exp.DataType.build(blob_type), + "normalize_name": exp.DataType.build("boolean"), + "requirements": exp.DataType.build(blob_type), + } + + new_snapshots = [] + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + ) in engine_adapter.fetchall( + 
exp.select(*snapshots_columns_to_types).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + parsed_snapshot = _update_snapshot(parsed_snapshot) + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + } + ) + + if new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + target_columns_to_types=snapshots_columns_to_types, + ) + + new_environments = [] + for ( + name, + snapshots, + start_at, + end_at, + plan_id, + previous_plan_id, + expiration_ts, + finalized_ts, + promoted_snapshot_ids, + suffix_target, + catalog_name_override, + previous_finalized_snapshots, + normalize_name, + requirements, + ) in engine_adapter.fetchall( + exp.select(*environments_columns_to_types).from_(environments_table), + quote_identifiers=True, + ): + if snapshots: + parsed_snapshots = json.loads(snapshots) + for s in parsed_snapshots: + _update_snapshot(s) + + if previous_finalized_snapshots: + parsed_previous_finalized_snapshots = json.loads(previous_finalized_snapshots) + for s in parsed_previous_finalized_snapshots: + _update_snapshot(s) + + new_environments.append( + { + "name": name, + "snapshots": json.dumps(parsed_snapshots) if snapshots else None, + "start_at": start_at, + "end_at": end_at, + "plan_id": plan_id, + "previous_plan_id": previous_plan_id, + "expiration_ts": expiration_ts, + "finalized_ts": finalized_ts, + "promoted_snapshot_ids": promoted_snapshot_ids, + "suffix_target": suffix_target, + "catalog_name_override": catalog_name_override, + "previous_finalized_snapshots": json.dumps(parsed_previous_finalized_snapshots) + if previous_finalized_snapshots + else None, + "normalize_name": normalize_name, + 
"requirements": requirements, + } + ) + + if new_environments: + engine_adapter.delete_from(environments_table, "TRUE") + engine_adapter.insert_append( + environments_table, + pd.DataFrame(new_environments), + target_columns_to_types=environments_columns_to_types, + ) + + +def _update_snapshot(snapshot: dict) -> dict: + snapshot = _update_fields(snapshot) + + if "previous_versions" in snapshot: + for previous_version in snapshot["previous_versions"]: + _update_fields(previous_version) + + return snapshot + + +def _update_fields(target: dict) -> dict: + # Setting the old suffix to match the names of existing tables. + target["dev_table_suffix"] = "temp" + if "temp_version" in target: + target["dev_version"] = target.pop("temp_version") + return target diff --git a/sqlmesh/migrations/v0070_include_grains_in_metadata_hash.py b/sqlmesh/migrations/v0070_include_grains_in_metadata_hash.py new file mode 100644 index 0000000000..d0dbdd5563 --- /dev/null +++ b/sqlmesh/migrations/v0070_include_grains_in_metadata_hash.py @@ -0,0 +1,9 @@ +"""Include grains in the metadata hash.""" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0071_add_dev_version_to_intervals.py b/sqlmesh/migrations/v0071_add_dev_version_to_intervals.py new file mode 100644 index 0000000000..61a49dc0b9 --- /dev/null +++ b/sqlmesh/migrations/v0071_add_dev_version_to_intervals.py @@ -0,0 +1,251 @@ +"""Add dev version to the intervals table.""" + +import typing as t +import json +import zlib + +from sqlglot import exp +from sqlmesh.utils.migration import index_text_type, blob_text_type + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + intervals_table = "_intervals" + if schema: + intervals_table = f"{schema}.{intervals_table}" + + index_type = index_text_type(engine_adapter.dialect) + alter_table_exp = exp.Alter( + 
this=exp.to_table(intervals_table), + kind="TABLE", + actions=[ + exp.ColumnDef( + this=exp.to_column("dev_version"), + kind=exp.DataType.build(index_type), + ) + ], + ) + engine_adapter.execute(alter_table_exp) + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + intervals_table = "_intervals" + snapshots_table = "_snapshots" + if schema: + intervals_table = f"{schema}.{intervals_table}" + snapshots_table = f"{schema}.{snapshots_table}" + + used_dev_versions: t.Set[t.Tuple[str, str]] = set() + used_versions: t.Set[t.Tuple[str, str]] = set() + used_snapshot_ids: t.Set[t.Tuple[str, str]] = set() + snapshot_ids_to_dev_versions: t.Dict[t.Tuple[str, str], str] = {} + + _migrate_snapshots( + engine_adapter, + snapshots_table, + used_dev_versions, + used_versions, + used_snapshot_ids, + snapshot_ids_to_dev_versions, + ) + _migrate_intervals( + engine_adapter, + intervals_table, + used_dev_versions, + used_versions, + used_snapshot_ids, + snapshot_ids_to_dev_versions, + ) + + +def _migrate_intervals( + engine_adapter: t.Any, + intervals_table: str, + used_dev_versions: t.Set[t.Tuple[str, str]], + used_versions: t.Set[t.Tuple[str, str]], + used_snapshot_ids: t.Set[t.Tuple[str, str]], + snapshot_ids_to_dev_versions: t.Dict[t.Tuple[str, str], str], +) -> None: + import pandas as pd + + index_type = index_text_type(engine_adapter.dialect) + intervals_columns_to_types = { + "id": exp.DataType.build(index_type), + "created_ts": exp.DataType.build("bigint"), + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build("text"), + "version": exp.DataType.build(index_type), + "dev_version": exp.DataType.build(index_type), + "start_ts": exp.DataType.build("bigint"), + "end_ts": exp.DataType.build("bigint"), + "is_dev": exp.DataType.build("boolean"), + "is_removed": exp.DataType.build("boolean"), + "is_compacted": exp.DataType.build("boolean"), + "is_pending_restatement": exp.DataType.build("boolean"), + } + + new_intervals = [] + for ( + 
interval_id, + created_ts, + name, + identifier, + version, + _, + start_ts, + end_ts, + is_dev, + is_removed, + is_compacted, + is_pending_restatement, + ) in engine_adapter.fetchall( + exp.select(*intervals_columns_to_types).from_(intervals_table), + quote_identifiers=True, + ): + if (name, version) not in used_versions: + # If the interval's version is no longer used, we can safely delete it + continue + + dev_version = snapshot_ids_to_dev_versions.get((name, identifier)) + if dev_version not in used_dev_versions and is_dev: + # If the interval's dev version is no longer used and this is a dev interval, we can safely delete it + continue + + if (name, identifier) not in used_snapshot_ids: + # If the snapshot associated with this interval no longer exists, we can nullify the interval's identifier + # to improve compaction + is_compacted = False + identifier = None + if not is_dev: + # If the interval is not dev, we can safely nullify the dev version as well + dev_version = None + + new_intervals.append( + { + "id": interval_id, + "created_ts": created_ts, + "name": name, + "identifier": identifier, + "version": version, + "dev_version": dev_version, + "start_ts": start_ts, + "end_ts": end_ts, + "is_dev": is_dev, + "is_removed": is_removed, + "is_compacted": is_compacted, + "is_pending_restatement": is_pending_restatement, + } + ) + + if new_intervals: + engine_adapter.delete_from(intervals_table, "TRUE") + engine_adapter.insert_append( + intervals_table, + pd.DataFrame(new_intervals), + target_columns_to_types=intervals_columns_to_types, + ) + + +def _migrate_snapshots( + engine_adapter: t.Any, + snapshots_table: str, + used_dev_versions: t.Set[t.Tuple[str, str]], + used_versions: t.Set[t.Tuple[str, str]], + used_snapshot_ids: t.Set[t.Tuple[str, str]], + snapshot_ids_to_dev_versions: t.Dict[t.Tuple[str, str], str], +) -> None: + import pandas as pd + + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + 
snapshots_columns_to_types = { + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + } + + new_snapshots = [] + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + ) in engine_adapter.fetchall( + exp.select(*snapshots_columns_to_types).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + version = parsed_snapshot.get("version") or version + dev_version = get_dev_version(parsed_snapshot) + parsed_snapshot["dev_version"] = dev_version + parsed_snapshot["version"] = version + + used_dev_versions.add((name, dev_version)) + used_versions.add((name, version)) + used_snapshot_ids.add((name, identifier)) + snapshot_ids_to_dev_versions[(name, identifier)] = dev_version + + for previous_version in parsed_snapshot.get("previous_versions", []): + previous_identifier = get_identifier(previous_version) + previous_dev_version = get_dev_version(previous_version) + snapshot_ids_to_dev_versions[(name, previous_identifier)] = previous_dev_version + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + } + ) + + if new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + target_columns_to_types=snapshots_columns_to_types, + ) + + +def get_identifier(snapshot: t.Dict[str, t.Any]) -> str: + fingerprint = 
snapshot["fingerprint"] + return crc32( + [ + fingerprint["data_hash"], + fingerprint["metadata_hash"], + fingerprint["parent_data_hash"], + fingerprint["parent_metadata_hash"], + ] + ) + + +def get_dev_version(snapshot: t.Dict[str, t.Any]) -> str: + dev_version = snapshot.get("dev_version") + if dev_version: + return dev_version + fingerprint = snapshot["fingerprint"] + return crc32([fingerprint["data_hash"], fingerprint["parent_data_hash"]]) + + +def crc32(data: t.Iterable[t.Optional[str]]) -> str: + return str(zlib.crc32(safe_concat(data))) + + +def safe_concat(data: t.Iterable[t.Optional[str]]) -> bytes: + return ";".join("" if d is None else d for d in data).encode("utf-8") diff --git a/sqlmesh/migrations/v0072_add_environment_statements.py b/sqlmesh/migrations/v0072_add_environment_statements.py new file mode 100644 index 0000000000..4ed52b5c47 --- /dev/null +++ b/sqlmesh/migrations/v0072_add_environment_statements.py @@ -0,0 +1,29 @@ +"""Add the environment statements table.""" + +from sqlglot import exp + +from sqlmesh.utils.migration import blob_text_type, index_text_type + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + environment_statements_table = "_environment_statements" + + if schema: + environment_statements_table = f"{schema}.{environment_statements_table}" + + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + + engine_adapter.create_state_table( + environment_statements_table, + { + "environment_name": exp.DataType.build(index_type), + "plan_id": exp.DataType.build("text"), + "environment_statements": exp.DataType.build(blob_type), + }, + primary_key=("environment_name",), + ) + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0073_remove_symbolic_disable_restatement.py b/sqlmesh/migrations/v0073_remove_symbolic_disable_restatement.py new file mode 100644 index 0000000000..708693ed61 --- /dev/null +++ 
b/sqlmesh/migrations/v0073_remove_symbolic_disable_restatement.py @@ -0,0 +1,75 @@ +"""Remove disable restatement from external and embedded models.""" + +import json + +from sqlglot import exp +from sqlmesh.utils.migration import index_text_type, blob_text_type + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + import pandas as pd + + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + snapshots_columns_to_types = { + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + } + + new_snapshots = [] + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + ) in engine_adapter.fetchall( + exp.select(*snapshots_columns_to_types).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + kind = parsed_snapshot["node"].get("kind") + + if kind and kind_name in ("EMBEDDED", "EXTERNAL"): + kind.pop("disable_restatement", None) + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + } + ) + + if new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + 
target_columns_to_types=snapshots_columns_to_types, + ) diff --git a/sqlmesh/migrations/v0074_add_partition_by_time_column_property.py b/sqlmesh/migrations/v0074_add_partition_by_time_column_property.py new file mode 100644 index 0000000000..acd349c888 --- /dev/null +++ b/sqlmesh/migrations/v0074_add_partition_by_time_column_property.py @@ -0,0 +1,10 @@ +"""Add 'partition_by_time_column' property to the IncrementalByTimeRange model kind +(default: True to keep the original behaviour)""" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0075_remove_validate_query.py b/sqlmesh/migrations/v0075_remove_validate_query.py new file mode 100644 index 0000000000..9fdcca7ea6 --- /dev/null +++ b/sqlmesh/migrations/v0075_remove_validate_query.py @@ -0,0 +1,85 @@ +"""Remove validate_query from existing snapshots.""" + +import json + +from sqlglot import exp + +from sqlmesh.utils.migration import index_text_type +from sqlmesh.utils.migration import blob_text_type + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + import pandas as pd + + snapshots_table = "_snapshots" + index_type = index_text_type(engine_adapter.dialect) + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + new_snapshots = [] + + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + + parsed_snapshot["node"].pop("validate_query", None) + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + 
"version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + } + ) + + if new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + blob_type = blob_text_type(engine_adapter.dialect) + + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + target_columns_to_types={ + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + }, + ) diff --git a/sqlmesh/migrations/v0076_add_cron_tz.py b/sqlmesh/migrations/v0076_add_cron_tz.py new file mode 100644 index 0000000000..909017c8cd --- /dev/null +++ b/sqlmesh/migrations/v0076_add_cron_tz.py @@ -0,0 +1,9 @@ +"""Add 'cron_tz' property to node definition.""" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0077_fix_column_type_hash_calculation.py b/sqlmesh/migrations/v0077_fix_column_type_hash_calculation.py new file mode 100644 index 0000000000..68953836bd --- /dev/null +++ b/sqlmesh/migrations/v0077_fix_column_type_hash_calculation.py @@ -0,0 +1,9 @@ +"""Use the model's dialect when calculating the hash for the column types.""" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0078_warn_if_non_migratable_python_env.py b/sqlmesh/migrations/v0078_warn_if_non_migratable_python_env.py new file mode 100644 index 
0000000000..adf1e96dd0 --- /dev/null +++ b/sqlmesh/migrations/v0078_warn_if_non_migratable_python_env.py @@ -0,0 +1,144 @@ +""" +This script's goal is to warn users if there is both a metadata and non-metadata reference in +the python environment of a model. Additionally, it warns them if there's a macro referenced +in a used audit's query, in the argument list of the audits and signals properties, or in an +on_virtual_update statement. + +Context: + +The metadata status for macros and signals is now transitive, i.e. every dependency of a +metadata macro or signal is also metadata, unless it is referenced by a non-metadata object. + +This means that global references of metadata objects may now be excluded from the data hash +calculation because of their new metadata status, which would lead to a diff. + +Additionally, we now implicitly treat macro refs in the aforementioned statements as "metadata-only", +even though they may not be marked as such by a user. This may also lead to a diff. +""" + +import json + +from sqlglot import exp + +import sqlmesh.core.dialect as d +from sqlmesh.core.console import get_console + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + warning = ( + "SQLMesh detected that it may not be able to fully migrate the state database. This should not impact " + "the migration process, but may result in unexpected changes being reported by the next `sqlmesh plan` " + "command. Please run `sqlmesh diff prod` after the migration has completed, before making any new " + "changes. If any unexpected changes are reported, consider running a forward-only plan to apply these " + "changes and avoid unnecessary backfills: sqlmesh plan prod --forward-only. 
" + "See https://sqlmesh.readthedocs.io/en/stable/concepts/plans/#forward-only-plans for more details.\n" + ) + + for (snapshot,) in engine_adapter.fetchall( + exp.select("snapshot").from_(snapshots_table), quote_identifiers=True + ): + parsed_snapshot = json.loads(snapshot) + node = parsed_snapshot["node"] + + # Standalone audits don't have a data hash, so they're unaffected + if node.get("source_type") == "audit": + continue + + python_env = node.get("python_env") or {} + + has_metadata = False + has_non_metadata = False + + for k, v in python_env.items(): + if v.get("is_metadata"): + has_metadata = True + else: + has_non_metadata = True + + if has_metadata and has_non_metadata: + get_console().log_warning(warning) + return + + dialect = node.get("dialect") + metadata_hash_statements = [] + + # We use try-except here as a conservative measure to avoid any unexpected exceptions + try: + if on_virtual_update := node.get("on_virtual_update"): + metadata_hash_statements.extend(parse_expression(on_virtual_update, dialect)) + + for _, audit_args in func_call_validator(node.get("audits") or []): + metadata_hash_statements.extend(audit_args.values()) + + for signal_name, signal_args in func_call_validator( + node.get("signals") or [], is_signal=True + ): + metadata_hash_statements.extend(signal_args.values()) + + if audit_definitions := node.get("audit_definitions"): + audit_queries = [ + parse_expression(audit["query"], audit["dialect"]) + for audit in audit_definitions.values() + ] + metadata_hash_statements.extend(audit_queries) + + for macro_name in extract_used_macros(metadata_hash_statements): + serialized_macro = python_env.get(macro_name) + if isinstance(serialized_macro, dict) and not serialized_macro.get("is_metadata"): + get_console().log_warning(warning) + return + except Exception: + pass + + +def extract_used_macros(expressions): + used_macros = set() + for expression in expressions: + if isinstance(expression, d.Jinja): + continue + + for macro_func in 
expression.find_all(d.MacroFunc): + if macro_func.__class__ is d.MacroFunc: + used_macros.add(macro_func.this.name.lower()) + + return used_macros + + +def func_call_validator(v, is_signal=False): + assert isinstance(v, list) + + audits = [] + for entry in v: + if isinstance(entry, dict): + args = entry + name = "" if is_signal else entry.pop("name") + else: + assert isinstance(entry, (tuple, list)) + name, args = entry + + parsed_audit = { + key: d.parse_one(value) if isinstance(value, str) else value + for key, value in args.items() + } + audits.append((name.lower(), parsed_audit)) + + return audits + + +def parse_expression(v, dialect): + if v is None: + return None + + if isinstance(v, list): + return [d.parse_one(e, dialect=dialect) for e in v] + + assert isinstance(v, str) + return d.parse_one(v, dialect=dialect) diff --git a/sqlmesh/migrations/v0079_add_gateway_managed_property.py b/sqlmesh/migrations/v0079_add_gateway_managed_property.py new file mode 100644 index 0000000000..7650d6d765 --- /dev/null +++ b/sqlmesh/migrations/v0079_add_gateway_managed_property.py @@ -0,0 +1,33 @@ +"""Add flag that controls whether the virtual layer's views will be created by the model specified gateway rather than the default gateway.""" + +from sqlglot import exp + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + environments_table = "_environments" + if schema: + environments_table = f"{schema}.{environments_table}" + + alter_table_exp = exp.Alter( + this=exp.to_table(environments_table), + kind="TABLE", + actions=[ + exp.ColumnDef( + this=exp.to_column("gateway_managed"), + kind=exp.DataType.build("boolean"), + ) + ], + ) + engine_adapter.execute(alter_table_exp) + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + environments_table = "_environments" + if schema: + environments_table = f"{schema}.{environments_table}" + + engine_adapter.update_table( + environments_table, + {"gateway_managed": False}, + where=exp.true(), + ) 
diff --git a/sqlmesh/migrations/v0080_add_batch_size_to_scd_type_2_models.py b/sqlmesh/migrations/v0080_add_batch_size_to_scd_type_2_models.py new file mode 100644 index 0000000000..35cb3977cc --- /dev/null +++ b/sqlmesh/migrations/v0080_add_batch_size_to_scd_type_2_models.py @@ -0,0 +1,9 @@ +"""Add batch_size to SCD Type 2 models and add updated_at_name to by time which changes their data hash.""" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0081_update_partitioned_by.py b/sqlmesh/migrations/v0081_update_partitioned_by.py new file mode 100644 index 0000000000..8740285bf0 --- /dev/null +++ b/sqlmesh/migrations/v0081_update_partitioned_by.py @@ -0,0 +1,94 @@ +"""Remove superfluous exp.Paren references from partitioned_by""" + +import json + +from sqlglot import exp + +from sqlmesh.utils.migration import index_text_type +from sqlmesh.utils.migration import blob_text_type + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + import pandas as pd + + snapshots_table = "_snapshots" + index_type = index_text_type(engine_adapter.dialect) + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + new_snapshots = [] + updated = False + + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + + if partitioned_by := parsed_snapshot["node"].get("partitioned_by"): + new_partitioned_by = [] + for item in partitioned_by: + # rewrite '(foo)' to 'foo' + if item.startswith("(") and 
item.endswith(")"): + item = item[1:-1] + updated = True + new_partitioned_by.append(item) + parsed_snapshot["node"]["partitioned_by"] = new_partitioned_by + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + } + ) + + if new_snapshots and updated: + engine_adapter.delete_from(snapshots_table, "TRUE") + blob_type = blob_text_type(engine_adapter.dialect) + + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + target_columns_to_types={ + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + }, + ) diff --git a/sqlmesh/migrations/v0082_warn_if_incorrectly_duplicated_statements.py b/sqlmesh/migrations/v0082_warn_if_incorrectly_duplicated_statements.py new file mode 100644 index 0000000000..5565b099cd --- /dev/null +++ b/sqlmesh/migrations/v0082_warn_if_incorrectly_duplicated_statements.py @@ -0,0 +1,70 @@ +""" +This script's goal is to warn users if there are two adjacent expressions in a SQL +model that are equivalent. + +Context: + +We used to include `Semicolon` expressions in the model's state, which led to a bug +where the expression preceding the semicolon would be duplicated in pre_statements +or post_statements. 
For example, the query in the model below would be incorrectly +included in its post_statements list: + +``` +MODEL ( + name test +); + +SELECT 1 AS c; + +-- foo +``` + +We now don't include `Semicolon` expressions in the model's state, which fixes this +issue, but unfortunately migrating existing snapshots is not possible because we do +not have a signal in state to detect whether an expression was incorrectly duplicated. + +If a SQL model suffered from this issue, then there would be two adjacent equivalent +expressions in it, so we use that as a heuristic to warn the user accordingly. +""" + +import json + +from sqlglot import exp + +from sqlmesh.core.console import get_console + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + warning = ( + "SQLMesh detected that it may not be able to fully migrate the state database. This should not impact " + "the migration process, but may result in unexpected changes being reported by the next `sqlmesh plan` " + "command. Please run `sqlmesh diff prod` after the migration has completed, before making any new " + "changes. If any unexpected changes are reported, consider running a forward-only plan to apply these " + "changes and avoid unnecessary backfills: sqlmesh plan prod --forward-only. 
" + "See https://sqlmesh.readthedocs.io/en/stable/concepts/plans/#forward-only-plans for more details.\n" + ) + + for (snapshot,) in engine_adapter.fetchall( + exp.select("snapshot").from_(snapshots_table), quote_identifiers=True + ): + parsed_snapshot = json.loads(snapshot) + node = parsed_snapshot["node"] + + if node.get("source_type") == "sql": + expressions = [ + *node.get("pre_statements", []), + node["query"], + *node.get("post_statements", []), + ] + for e1, e2 in zip(expressions, expressions[1:]): + if e1 == e2: + get_console().log_warning(warning) + return diff --git a/sqlmesh/migrations/v0083_use_sql_for_scd_time_data_type_data_hash.py b/sqlmesh/migrations/v0083_use_sql_for_scd_time_data_type_data_hash.py new file mode 100644 index 0000000000..5dbe0847f9 --- /dev/null +++ b/sqlmesh/migrations/v0083_use_sql_for_scd_time_data_type_data_hash.py @@ -0,0 +1,9 @@ +"""Use sql(...) instead of gen when computing the data hash of the time data type.""" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0084_normalize_quote_when_matched_and_merge_filter.py b/sqlmesh/migrations/v0084_normalize_quote_when_matched_and_merge_filter.py new file mode 100644 index 0000000000..9edb0051ba --- /dev/null +++ b/sqlmesh/migrations/v0084_normalize_quote_when_matched_and_merge_filter.py @@ -0,0 +1,13 @@ +""" +Normalize and quote the when_matched and merge_filter properties of IncrementalByUniqueKeyKind +to match how other properties (such as time_column and partitioned_by) are handled and to +prevent un-normalized identifiers being quoted at the EngineAdapter level +""" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0085_deterministic_repr.py b/sqlmesh/migrations/v0085_deterministic_repr.py new 
file mode 100644 index 0000000000..81cb0f194e --- /dev/null +++ b/sqlmesh/migrations/v0085_deterministic_repr.py @@ -0,0 +1,133 @@ +""" +When serializing some objects, like `__sqlmesh__vars__`, the order of keys in the dictionary were not deterministic +and therefore this migration applies deterministic sorting to the keys of the dictionary. +""" + +import json +import logging +import typing as t +from dataclasses import dataclass + +from sqlglot import exp + +from sqlmesh.utils.migration import index_text_type, blob_text_type + + +logger = logging.getLogger(__name__) + + +KEYS_TO_MAKE_DETERMINISTIC = ["__sqlmesh__vars__", "__sqlmesh__blueprint__vars__"] + + +# Make sure `SqlValue` is defined so it can be used by `eval` call in the migration +@dataclass +class SqlValue: + """A SQL string representing a generated SQLGlot AST.""" + + sql: str + + +def _dict_sort(obj: t.Any) -> str: + try: + if isinstance(obj, dict): + obj = dict(sorted(obj.items(), key=lambda x: str(x[0]))) + except Exception: + logger.warning("Failed to sort non-recursive dict", exc_info=True) + return repr(obj) + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + import pandas as pd + + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + migration_needed = False + new_snapshots = [] + + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + python_env = parsed_snapshot["node"].get("python_env") + + if python_env: + for key, executable in python_env.items(): + if key not in KEYS_TO_MAKE_DETERMINISTIC: + continue + if 
isinstance(executable, dict) and executable.get("kind") == "value": + old_payload = executable["payload"] + try: + # Try to parse the old payload and re-serialize it deterministically + parsed_value = eval(old_payload) + new_payload = _dict_sort(parsed_value) + + # Only update if the representation changed + if old_payload != new_payload: + executable["payload"] = new_payload + migration_needed = True + except Exception: + # If we still can't eval it, leave it as-is + logger.warning("Exception trying to eval payload", exc_info=True) + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + } + ) + + if migration_needed and new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + target_columns_to_types={ + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build("text"), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + }, + ) diff --git a/sqlmesh/migrations/v0086_check_deterministic_bug.py b/sqlmesh/migrations/v0086_check_deterministic_bug.py new file mode 100644 index 0000000000..f44e5b8e33 --- /dev/null +++ b/sqlmesh/migrations/v0086_check_deterministic_bug.py @@ -0,0 +1,84 @@ +import json +import logging + +from sqlglot import exp + +from sqlmesh.core.console import get_console + + +logger = logging.getLogger(__name__) +KEYS_TO_MAKE_DETERMINISTIC = 
["__sqlmesh__vars__", "__sqlmesh__blueprint__vars__"] + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + snapshots_table = "_snapshots" + versions_table = "_versions" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + versions_table = f"{schema}.{versions_table}" + + result = engine_adapter.fetchone( + exp.select("schema_version").from_(versions_table), quote_identifiers=True + ) + if not result: + # This must be the first migration, so we can skip the check since the project was not exposed to 85 migration bug + return + schema_version = result[0] + if schema_version < 85: + # The project was not exposed to the bugged 85 migration, so we can skip it. + return + + warning = ( + "SQLMesh detected that it may not be able to fully migrate the state database. This should not impact " + "the migration process, but may result in unexpected changes being reported by the next `sqlmesh plan` " + "command. Please run `sqlmesh diff prod` after the migration has completed, before making any new " + "changes. If any unexpected changes are reported, consider running a forward-only plan to apply these " + "changes and avoid unnecessary backfills: sqlmesh plan prod --forward-only. 
" + "See https://sqlmesh.readthedocs.io/en/stable/concepts/plans/#forward-only-plans for more details.\n" + ) + + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + python_env = parsed_snapshot["node"].get("python_env") + + if python_env: + for key, executable in python_env.items(): + if ( + key not in KEYS_TO_MAKE_DETERMINISTIC + and isinstance(executable, dict) + and executable.get("kind") == "value" + ): + try: + parsed_value = eval(executable["payload"]) + if isinstance(parsed_value, dict): + get_console().log_warning(warning) + return + except Exception: + logger.warning("Exception trying to eval payload", exc_info=True) diff --git a/sqlmesh/migrations/v0087_normalize_blueprint_variables.py b/sqlmesh/migrations/v0087_normalize_blueprint_variables.py new file mode 100644 index 0000000000..fe737861c2 --- /dev/null +++ b/sqlmesh/migrations/v0087_normalize_blueprint_variables.py @@ -0,0 +1,140 @@ +""" +Normalizes blueprint variables, so Customer_Field is stored as customer_field in the `python_env`: + +MODEL ( + ... 
+ blueprints ( + Customer_Field := 1 + ) +); + +SELECT + @customer_field AS col +""" + +import json +import logging +from dataclasses import dataclass + +from sqlglot import exp +from sqlmesh.core.console import get_console +from sqlmesh.utils.migration import index_text_type, blob_text_type + + +logger = logging.getLogger(__name__) + + +SQLMESH_BLUEPRINT_VARS = "__sqlmesh__blueprint__vars__" + + +# Make sure `SqlValue` is defined so it can be used by `eval` call in the migration +@dataclass +class SqlValue: + """A SQL string representing a generated SQLGlot AST.""" + + sql: str + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + import pandas as pd + + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + migration_needed = False + new_snapshots = [] + + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + node = parsed_snapshot["node"] + python_env = node.get("python_env") or {} + + migrate_snapshot = False + + if blueprint_vars_executable := python_env.get(SQLMESH_BLUEPRINT_VARS): + blueprint_vars = eval(blueprint_vars_executable["payload"]) + + for var, value in dict(blueprint_vars).items(): + lowercase_var = var.lower() + if var != lowercase_var: + if lowercase_var in blueprint_vars: + get_console().log_warning( + "SQLMesh is unable to fully migrate the state database, because the " + f"model '{node['name']}' contains two blueprint variables ('{var}' and " + f"'{lowercase_var}') that resolve to the same value ('{lowercase_var}'). 
" + "This may result in unexpected changes being reported by the next " + "`sqlmesh plan` command. If this happens, consider renaming either variable, " + "so that the lowercase version of their names are different." + ) + else: + del blueprint_vars[var] + blueprint_vars[lowercase_var] = value + migrate_snapshot = True + + if migrate_snapshot: + migration_needed = True + blueprint_vars_executable["payload"] = repr(blueprint_vars) + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + } + ) + + if migration_needed and new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + target_columns_to_types={ + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build("text"), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + }, + ) diff --git a/sqlmesh/migrations/v0088_warn_about_variable_python_env_diffs.py b/sqlmesh/migrations/v0088_warn_about_variable_python_env_diffs.py new file mode 100644 index 0000000000..0aa7171821 --- /dev/null +++ b/sqlmesh/migrations/v0088_warn_about_variable_python_env_diffs.py @@ -0,0 +1,76 @@ +""" +This script's goal is to warn users about two situations that could lead to a diff: + +- They have blueprint models and some of their variables may be trimmed from `python_env` +- Variables are used in metadata-only contexts, e.g., within 
metadata-only macros + +Context: + +We used to store *all* blueprint variables in `python_env`, even though some of them were +redundant. For example, if a blueprint variable is only used in the model's `name` property, +then it is rendered once, at load time, and after that point it's not needed elsewhere. + +This behavior is now different: we only store the blueprint variables that are required to render +expressions at runtime, such as model query or runtime-rendered properties, like `merge_filter`. + +Additionally, variables were previously treated as non-metadata, regardless of how they were used. +This behavior changed as well: SQLMesh now analyzes variable references and tracks the data flow, +in order to detect whether changing them will result in a metadata diff for a given model. + +Some examples where variables can be treated as metadata-only `python_env` executables are: + +- A variable is referenced in metadata-only macros +- A variable is referenced in metadata-only expressions, such as virtual update statements +- A variable is passed as argument to metadata-only macros +""" + +import json + +from sqlglot import exp + +from sqlmesh.core.console import get_console + +SQLMESH_VARS = "__sqlmesh__vars__" +SQLMESH_BLUEPRINT_VARS = "__sqlmesh__blueprint__vars__" +METADATA_HASH_EXPRESSIONS = {"on_virtual_update", "audits", "signals", "audit_definitions"} + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + warning = ( + "SQLMesh detected that it may not be able to fully migrate the state database. This should not impact " + "the migration process, but may result in unexpected changes being reported by the next `sqlmesh plan` " + "command. Please run `sqlmesh diff prod` after the migration has completed, before making any new " + "changes. 
If any unexpected changes are reported, consider running a forward-only plan to apply these " + "changes and avoid unnecessary backfills: sqlmesh plan prod --forward-only. " + "See https://sqlmesh.readthedocs.io/en/stable/concepts/plans/#forward-only-plans for more details.\n" + ) + + for (snapshot,) in engine_adapter.fetchall( + exp.select("snapshot").from_(snapshots_table), quote_identifiers=True + ): + parsed_snapshot = json.loads(snapshot) + node = parsed_snapshot["node"] + + # Standalone audits don't have a data hash, so they're unaffected + if node.get("source_type") == "audit": + continue + + python_env = node.get("python_env") or {} + + if SQLMESH_BLUEPRINT_VARS in python_env or ( + SQLMESH_VARS in python_env + and ( + any(v.get("is_metadata") for v in python_env.values()) + or any(node.get(k) for k in METADATA_HASH_EXPRESSIONS) + ) + ): + get_console().log_warning(warning) + return diff --git a/sqlmesh/migrations/v0089_add_virtual_environment_mode.py b/sqlmesh/migrations/v0089_add_virtual_environment_mode.py new file mode 100644 index 0000000000..88126c76d7 --- /dev/null +++ b/sqlmesh/migrations/v0089_add_virtual_environment_mode.py @@ -0,0 +1,9 @@ +"""Add virtual_environment_mode to the model definition.""" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0090_add_forward_only_column.py b/sqlmesh/migrations/v0090_add_forward_only_column.py new file mode 100644 index 0000000000..48253691ec --- /dev/null +++ b/sqlmesh/migrations/v0090_add_forward_only_column.py @@ -0,0 +1,104 @@ +"""Add forward_only column to the snapshots table.""" + +import json + +from sqlglot import exp + +from sqlmesh.utils.migration import index_text_type, blob_text_type + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + 
alter_table_exp = exp.Alter( + this=exp.to_table(snapshots_table), + kind="TABLE", + actions=[ + exp.ColumnDef( + this=exp.to_column("forward_only"), + kind=exp.DataType.build("boolean"), + ) + ], + ) + engine_adapter.execute(alter_table_exp) + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + import pandas as pd + + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + new_snapshots = [] + + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + forward_only, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + "forward_only", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + + forward_only = parsed_snapshot.get("forward_only") + if forward_only is None: + forward_only = parsed_snapshot.get("change_category") == 3 + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + "forward_only": forward_only, + } + ) + + if new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + target_columns_to_types={ + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": 
exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + "forward_only": exp.DataType.build("boolean"), + }, + ) diff --git a/sqlmesh/migrations/v0091_on_additive_change.py b/sqlmesh/migrations/v0091_on_additive_change.py new file mode 100644 index 0000000000..e24b9b4122 --- /dev/null +++ b/sqlmesh/migrations/v0091_on_additive_change.py @@ -0,0 +1,9 @@ +"""Add on_additive_change to incremental model metadata hash.""" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0092_warn_about_dbt_data_type_diff.py b/sqlmesh/migrations/v0092_warn_about_dbt_data_type_diff.py new file mode 100644 index 0000000000..5407e5a99a --- /dev/null +++ b/sqlmesh/migrations/v0092_warn_about_dbt_data_type_diff.py @@ -0,0 +1,50 @@ +""" +Warns dbt users about potential diffs due to corrected data_type handling. + +SQLMesh previously treated dbt's schema.yml data_type field as columns_to_types, which +doesn't match dbt's behavior. dbt only uses data_type for contracts/validation, not DDL. +This fix may cause diffs if tables were created with incorrect types. + +More context: https://github.com/SQLMesh/sqlmesh/pull/5231 +""" + +import json + +from sqlglot import exp + +from sqlmesh.core.console import get_console + +SQLMESH_DBT_PACKAGE = "sqlmesh.dbt" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + warning = ( + "SQLMesh previously misinterpreted dbt's schema.yml 'data_type' field as actual " + "column types, but dbt only uses these for contracts/validation, not in actual " + "DDL statements. This has been fixed to match dbt's actual behavior. Your existing " + "tables may have been created with incorrect column types. 
After this migration, run " + "'sqlmesh diff prod' to check for column type differences, and if any are found, " + "apply a plan to correct the table schemas. For more details, see: " + "https://github.com/SQLMesh/sqlmesh/pull/5231." + ) + + for (snapshot,) in engine_adapter.fetchall( + exp.select("snapshot").from_(snapshots_table), quote_identifiers=True + ): + parsed_snapshot = json.loads(snapshot) + node = parsed_snapshot["node"] + + jinja_macros = node.get("jinja_macros") or {} + create_builtins_module = jinja_macros.get("create_builtins_module") or "" + + if create_builtins_module == SQLMESH_DBT_PACKAGE and node.get("columns"): + get_console().log_warning(warning) + return diff --git a/sqlmesh/migrations/v0093_use_raw_sql_in_fingerprint.py b/sqlmesh/migrations/v0093_use_raw_sql_in_fingerprint.py new file mode 100644 index 0000000000..aaaacf3a91 --- /dev/null +++ b/sqlmesh/migrations/v0093_use_raw_sql_in_fingerprint.py @@ -0,0 +1,9 @@ +"""Use the raw SQL when computing the model fingerprint.""" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0094_add_dev_version_and_fingerprint_columns.py b/sqlmesh/migrations/v0094_add_dev_version_and_fingerprint_columns.py new file mode 100644 index 0000000000..9d7adf21a3 --- /dev/null +++ b/sqlmesh/migrations/v0094_add_dev_version_and_fingerprint_columns.py @@ -0,0 +1,123 @@ +"""Add dev_version and fingerprint columns to the snapshots table.""" + +import json + +from sqlglot import exp + +from sqlmesh.utils.migration import index_text_type, blob_text_type + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + + add_dev_version_exp = exp.Alter( + 
this=exp.to_table(snapshots_table), + kind="TABLE", + actions=[ + exp.ColumnDef( + this=exp.to_column("dev_version"), + kind=exp.DataType.build(index_type), + ) + ], + ) + engine_adapter.execute(add_dev_version_exp) + + add_fingerprint_exp = exp.Alter( + this=exp.to_table(snapshots_table), + kind="TABLE", + actions=[ + exp.ColumnDef( + this=exp.to_column("fingerprint"), + kind=exp.DataType.build(blob_type), + ) + ], + ) + engine_adapter.execute(add_fingerprint_exp) + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + import pandas as pd + + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + + new_snapshots = [] + + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + forward_only, + _, + _, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + "forward_only", + "dev_version", + "fingerprint", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": snapshot, + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + "forward_only": forward_only, + "dev_version": parsed_snapshot.get("dev_version"), + "fingerprint": json.dumps(parsed_snapshot.get("fingerprint")), + } + ) + + if new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + + engine_adapter.insert_append( + snapshots_table, + pd.DataFrame(new_snapshots), + target_columns_to_types={ + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": 
exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + "forward_only": exp.DataType.build("boolean"), + "dev_version": exp.DataType.build(index_type), + "fingerprint": exp.DataType.build(blob_type), + }, + ) diff --git a/sqlmesh/migrations/v0095_warn_about_dbt_raw_sql_diff.py b/sqlmesh/migrations/v0095_warn_about_dbt_raw_sql_diff.py new file mode 100644 index 0000000000..0fa9fd51b8 --- /dev/null +++ b/sqlmesh/migrations/v0095_warn_about_dbt_raw_sql_diff.py @@ -0,0 +1,49 @@ +""" +Warns dbt users about potential diffs due to inclusion of {{ config(...) }} blocks in model SQL. + +Prior to this fix, SQLMesh wasn't including the {{ config(...) }} block in the model's SQL payload +when processing dbt models. Now these config blocks are properly included in the raw SQL, which +may cause diffs to appear for existing dbt models even though the actual SQL logic hasn't changed. + +This is a one-time diff that will appear after upgrading, and applying a plan will resolve it. +""" + +import json + +from sqlglot import exp + +from sqlmesh.core.console import get_console + +SQLMESH_DBT_PACKAGE = "sqlmesh.dbt" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + warning = ( + "SQLMesh detected that it may not be able to fully migrate the state database. This should not impact " + "the migration process, but may result in unexpected changes being reported by the next `sqlmesh plan` " + "command. Please run `sqlmesh diff prod` after the migration has completed, before making any new " + "changes. 
If any unexpected changes are reported, consider running a forward-only plan to apply these " + "changes and avoid unnecessary backfills: sqlmesh plan prod --forward-only. " + "See https://sqlmesh.readthedocs.io/en/stable/concepts/plans/#forward-only-plans for more details.\n" + ) + + for (snapshot,) in engine_adapter.fetchall( + exp.select("snapshot").from_(snapshots_table), quote_identifiers=True + ): + parsed_snapshot = json.loads(snapshot) + node = parsed_snapshot["node"] + + jinja_macros = node.get("jinja_macros") or {} + create_builtins_module = jinja_macros.get("create_builtins_module") or "" + + if create_builtins_module == SQLMESH_DBT_PACKAGE: + get_console().log_warning(warning) + return diff --git a/sqlmesh/migrations/v0096_remove_plan_dags_table.py b/sqlmesh/migrations/v0096_remove_plan_dags_table.py new file mode 100644 index 0000000000..8eb674ead0 --- /dev/null +++ b/sqlmesh/migrations/v0096_remove_plan_dags_table.py @@ -0,0 +1,13 @@ +"""Remove the obsolete _plan_dags table.""" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + plan_dags_table = "_plan_dags" + if schema: + plan_dags_table = f"{schema}.{plan_dags_table}" + + engine_adapter.drop_table(plan_dags_table) + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0097_add_dbt_name_in_node.py b/sqlmesh/migrations/v0097_add_dbt_name_in_node.py new file mode 100644 index 0000000000..cd548977ef --- /dev/null +++ b/sqlmesh/migrations/v0097_add_dbt_name_in_node.py @@ -0,0 +1,9 @@ +"""Add 'dbt_name' property to node definition.""" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0098_add_dbt_node_info_in_node.py b/sqlmesh/migrations/v0098_add_dbt_node_info_in_node.py new file mode 100644 index 0000000000..b69ba8fa6f --- /dev/null +++ 
b/sqlmesh/migrations/v0098_add_dbt_node_info_in_node.py @@ -0,0 +1,103 @@ +"""Replace 'dbt_name' with 'dbt_node_info' in the snapshot definition""" + +import json +from sqlglot import exp +from sqlmesh.utils.migration import index_text_type, blob_text_type + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + import pandas as pd + + snapshots_table = "_snapshots" + if schema: + snapshots_table = f"{schema}.{snapshots_table}" + + index_type = index_text_type(engine_adapter.dialect) + blob_type = blob_text_type(engine_adapter.dialect) + + new_snapshots = [] + migration_needed = False + + for ( + name, + identifier, + version, + snapshot, + kind_name, + updated_ts, + unpaused_ts, + ttl_ms, + unrestorable, + forward_only, + dev_version, + fingerprint, + ) in engine_adapter.fetchall( + exp.select( + "name", + "identifier", + "version", + "snapshot", + "kind_name", + "updated_ts", + "unpaused_ts", + "ttl_ms", + "unrestorable", + "forward_only", + "dev_version", + "fingerprint", + ).from_(snapshots_table), + quote_identifiers=True, + ): + parsed_snapshot = json.loads(snapshot) + if dbt_name := parsed_snapshot["node"].get("dbt_name"): + parsed_snapshot["node"].pop("dbt_name") + parsed_snapshot["node"]["dbt_node_info"] = { + "unique_id": dbt_name, + # these will get populated as metadata-only changes on the next plan + "name": "", + "fqn": "", + } + migration_needed = True + + new_snapshots.append( + { + "name": name, + "identifier": identifier, + "version": version, + "snapshot": json.dumps(parsed_snapshot), + "kind_name": kind_name, + "updated_ts": updated_ts, + "unpaused_ts": unpaused_ts, + "ttl_ms": ttl_ms, + "unrestorable": unrestorable, + "forward_only": forward_only, + "dev_version": dev_version, + "fingerprint": fingerprint, + } + ) + + if migration_needed and new_snapshots: + engine_adapter.delete_from(snapshots_table, "TRUE") + + engine_adapter.insert_append( + 
snapshots_table, + pd.DataFrame(new_snapshots), + target_columns_to_types={ + "name": exp.DataType.build(index_type), + "identifier": exp.DataType.build(index_type), + "version": exp.DataType.build(index_type), + "snapshot": exp.DataType.build(blob_type), + "kind_name": exp.DataType.build(index_type), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + "forward_only": exp.DataType.build("boolean"), + "dev_version": exp.DataType.build(index_type), + "fingerprint": exp.DataType.build(blob_type), + }, + ) diff --git a/sqlmesh/migrations/v0099_add_last_altered_to_intervals.py b/sqlmesh/migrations/v0099_add_last_altered_to_intervals.py new file mode 100644 index 0000000000..b80ed35a35 --- /dev/null +++ b/sqlmesh/migrations/v0099_add_last_altered_to_intervals.py @@ -0,0 +1,25 @@ +"""Add dev version to the intervals table.""" + +from sqlglot import exp + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + intervals_table = "_intervals" + if schema: + intervals_table = f"{schema}.{intervals_table}" + + alter_table_exp = exp.Alter( + this=exp.to_table(intervals_table), + kind="TABLE", + actions=[ + exp.ColumnDef( + this=exp.to_column("last_altered_ts"), + kind=exp.DataType.build("BIGINT", dialect=engine_adapter.dialect), + ) + ], + ) + engine_adapter.execute(alter_table_exp) + + +def migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/migrations/v0100_add_grants_and_grants_target_layer.py b/sqlmesh/migrations/v0100_add_grants_and_grants_target_layer.py new file mode 100644 index 0000000000..9ff64c5e57 --- /dev/null +++ b/sqlmesh/migrations/v0100_add_grants_and_grants_target_layer.py @@ -0,0 +1,9 @@ +"""Add grants and grants_target_layer to incremental model metadata hash.""" + + +def migrate_schemas(engine_adapter, schema, **kwargs): # type: ignore + pass + + +def 
migrate_rows(engine_adapter, schema, **kwargs): # type: ignore + pass diff --git a/sqlmesh/schedulers/airflow/__init__.py b/sqlmesh/schedulers/airflow/__init__.py deleted file mode 100644 index f5c62217b8..0000000000 --- a/sqlmesh/schedulers/airflow/__init__.py +++ /dev/null @@ -1 +0,0 @@ -NO_DEFAULT_CATALOG = "NO_DEFAULT_CATALOG" diff --git a/sqlmesh/schedulers/airflow/api.py b/sqlmesh/schedulers/airflow/api.py deleted file mode 100644 index 9744efbb0c..0000000000 --- a/sqlmesh/schedulers/airflow/api.py +++ /dev/null @@ -1,183 +0,0 @@ -from __future__ import annotations - -import json -import logging -import typing as t -from functools import wraps - -from airflow.api_connexion import security -from airflow.www.app import csrf -from flask import Blueprint, Response, jsonify, make_response, request - -from sqlmesh.core import constants as c -from sqlmesh.core.snapshot import SnapshotId, SnapshotNameVersion -from sqlmesh.schedulers.airflow import common, util -from sqlmesh.schedulers.airflow.plan import PlanDagState, create_plan_dag_spec -from sqlmesh.utils.errors import SQLMeshError -from sqlmesh.utils.pydantic import PydanticModel - -logger = logging.getLogger(__name__) - - -sqlmesh_api_v1 = Blueprint( - c.SQLMESH, - __name__, - url_prefix=f"/{common.SQLMESH_API_BASE_PATH}", -) - - -def check_authentication(func: t.Callable) -> t.Callable: - @wraps(func) - def wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any: - security.check_authentication() - return func(*args, **kwargs) - - return wrapper - - -@sqlmesh_api_v1.route("/plans", methods=["POST"]) -@csrf.exempt -@check_authentication -def apply_plan() -> Response: - try: - plan = common.PlanApplicationRequest.parse_obj(request.json or {}) - with util.scoped_state_sync() as state_sync: - spec = create_plan_dag_spec(plan, state_sync) - PlanDagState.from_state_sync(state_sync).add_dag_spec(spec) - return make_response(jsonify(request_id=spec.request_id), 201) - except Exception as ex: - logger.exception("Failed to 
create a plan DAG spec from request:\n%s", request.json) - return _error(str(ex)) - - -@sqlmesh_api_v1.route("/environments/") -@csrf.exempt -@check_authentication -def get_environment(name: str) -> Response: - with util.scoped_state_sync() as state_sync: - environment = state_sync.get_environment(name) - if environment is None: - return _error(f"Environment '{name}' was not found", 404) - return _success(environment) - - -@sqlmesh_api_v1.route("/environments") -@csrf.exempt -@check_authentication -def get_environments() -> Response: - with util.scoped_state_sync() as state_sync: - environments = state_sync.get_environments() - return _success(common.EnvironmentsResponse(environments=environments)) - - -@sqlmesh_api_v1.route("/environments//max_interval_end") -@csrf.exempt -@check_authentication -def get_max_interval_end(name: str) -> Response: - with util.scoped_state_sync() as state_sync: - ensure_finalized_snapshots = "ensure_finalized_snapshots" in request.args - max_interval_end = state_sync.max_interval_end_for_environment( - name, ensure_finalized_snapshots=ensure_finalized_snapshots - ) - response = common.IntervalEndResponse(environment=name, max_interval_end=max_interval_end) - return _success(response) - - -@sqlmesh_api_v1.route("/environments//greatest_common_interval_end") -@csrf.exempt -@check_authentication -def get_greatest_common_interval_end(name: str) -> Response: - with util.scoped_state_sync() as state_sync: - models = json.loads(request.args["models"]) if "models" in request.args else [] - ensure_finalized_snapshots = "ensure_finalized_snapshots" in request.args - max_interval_end = state_sync.greatest_common_interval_end( - name, set(models), ensure_finalized_snapshots=ensure_finalized_snapshots - ) - response = common.IntervalEndResponse(environment=name, max_interval_end=max_interval_end) - return _success(response) - - -@sqlmesh_api_v1.route("/environments/", methods=["DELETE"]) -@csrf.exempt -@check_authentication -def 
invalidate_environment(name: str) -> Response: - with util.scoped_state_sync() as state_sync: - try: - state_sync.invalidate_environment(name) - except SQLMeshError as ex: - return _error(str(ex), 400) - - return _success(common.InvalidateEnvironmentResponse(name=name)) - - -@sqlmesh_api_v1.route("/snapshots") -@csrf.exempt -@check_authentication -def get_snapshots() -> Response: - with util.scoped_state_sync() as state_sync: - snapshot_ids = _snapshot_ids_from_request() - - if "check_existence" in request.args: - existing_snapshot_ids = ( - state_sync.snapshots_exist(snapshot_ids) if snapshot_ids is not None else set() - ) - return _success(common.SnapshotIdsResponse(snapshot_ids=existing_snapshot_ids)) - - snapshots = list(state_sync.get_snapshots(snapshot_ids).values()) - return _success(common.SnapshotsResponse(snapshots=snapshots)) - - -@sqlmesh_api_v1.route("/models") -@csrf.exempt -@check_authentication -def nodes_exist() -> Response: - with util.scoped_state_sync() as state_sync: - names = _csv_arg("names") - exclude_external = "exclude_external" in request.args - existing_models = state_sync.nodes_exist(names, exclude_external=exclude_external) - return _success(common.ExistingModelsResponse(names=list(existing_models))) - - -@sqlmesh_api_v1.route("/versions") -@csrf.exempt -@check_authentication -def get_versions() -> Response: - with util.scoped_state_sync() as state_sync: - versions = state_sync.get_versions() - assert versions - return _success(versions) - - -T = t.TypeVar("T", bound=PydanticModel) - - -def _success(data: T, status_code: int = 200) -> Response: - response = make_response(data.json(), status_code) - response.mimetype = "application/json" - return response - - -def _error(message: str, status_code: int = 400) -> Response: - return make_response(jsonify(message=message), status_code) - - -def _snapshot_ids_from_request() -> t.Optional[t.List[SnapshotId]]: - if "ids" not in request.args: - return None - - raw_ids = 
json.loads(request.args["ids"]) - return [SnapshotId.parse_obj(i) for i in raw_ids] - - -def _snapshot_name_versions_from_request() -> t.Optional[t.List[SnapshotNameVersion]]: - if "versions" not in request.args: - return None - - raw_versions = json.loads(request.args["versions"]) - return [SnapshotNameVersion.parse_obj(v) for v in raw_versions] - - -def _csv_arg(arg: str) -> t.List[str]: - if arg not in request.args: - return [] - return [v.strip() for v in request.args[arg].split(",")] diff --git a/sqlmesh/schedulers/airflow/client.py b/sqlmesh/schedulers/airflow/client.py deleted file mode 100644 index 68ff64d0a1..0000000000 --- a/sqlmesh/schedulers/airflow/client.py +++ /dev/null @@ -1,372 +0,0 @@ -import abc -import json -import time -import typing as t -import uuid -from urllib.parse import urlencode, urljoin - -import requests - -from sqlmesh.core.console import Console -from sqlmesh.core.environment import Environment -from sqlmesh.core.notification_target import NotificationTarget -from sqlmesh.core.snapshot import Snapshot, SnapshotId -from sqlmesh.core.snapshot.definition import Interval -from sqlmesh.core.state_sync import Versions -from sqlmesh.core.user import User -from sqlmesh.schedulers.airflow import common, NO_DEFAULT_CATALOG -from sqlmesh.utils import unique -from sqlmesh.utils.date import TimeLike -from sqlmesh.utils.errors import ( - ApiServerError, - NotFoundError, - SQLMeshError, - raise_for_status, -) -from sqlmesh.utils.pydantic import PydanticModel - -DAG_RUN_PATH_TEMPLATE = "api/v1/dags/{}/dagRuns" - - -PLANS_PATH = f"{common.SQLMESH_API_BASE_PATH}/plans" -ENVIRONMENTS_PATH = f"{common.SQLMESH_API_BASE_PATH}/environments" -SNAPSHOTS_PATH = f"{common.SQLMESH_API_BASE_PATH}/snapshots" -SEEDS_PATH = f"{common.SQLMESH_API_BASE_PATH}/seeds" -INTERVALS_PATH = f"{common.SQLMESH_API_BASE_PATH}/intervals" -MODELS_PATH = f"{common.SQLMESH_API_BASE_PATH}/models" -VERSIONS_PATH = f"{common.SQLMESH_API_BASE_PATH}/versions" - - -class 
BaseAirflowClient(abc.ABC): - def __init__(self, airflow_url: str, console: t.Optional[Console]): - self._airflow_url = airflow_url - if not self._airflow_url.endswith("/"): - self._airflow_url += "/" - self._console = console - - @property - def default_catalog(self) -> t.Optional[str]: - default_catalog = self.get_variable(common.DEFAULT_CATALOG_VARIABLE_NAME) - if not default_catalog: - raise SQLMeshError( - "Must define `default_catalog` when creating `SQLMeshAirflow` object. See docs for more info: https://sqlmesh.readthedocs.io/en/stable/integrations/airflow/#airflow-cluster-configuration" - ) - if default_catalog == NO_DEFAULT_CATALOG: - return None - return default_catalog - - def print_tracking_url(self, dag_id: str, dag_run_id: str, op_name: str) -> None: - if not self._console: - return - - tracking_url = self.dag_run_tracking_url(dag_id, dag_run_id) - # TODO: Figure out generalized solution for links - self._console.log_status_update( - f"Track [green]{op_name}[/green] progress using [link={tracking_url}]link[/link]" - ) - - def dag_run_tracking_url(self, dag_id: str, dag_run_id: str) -> str: - url_params = urlencode( - dict( - dag_id=dag_id, - run_id=dag_run_id, - ) - ) - return urljoin(self._airflow_url, f"dagrun_details?{url_params}") - - def wait_for_dag_run_completion( - self, dag_id: str, dag_run_id: str, poll_interval_secs: int - ) -> bool: - """Blocks until the given DAG Run completes. - - Args: - dag_id: The DAG ID. - dag_run_id: The DAG Run ID. - poll_interval_secs: The number of seconds to wait between polling for the DAG Run state. - - Returns: - True if the DAG Run completed successfully, False otherwise. 
- """ - loading_id = self._console_loading_start() - - while True: - state = self.get_dag_run_state(dag_id, dag_run_id) - if state in ("failed", "success"): - if self._console and loading_id: - self._console.loading_stop(loading_id) - return state == "success" - - time.sleep(poll_interval_secs) - - def wait_for_first_dag_run(self, dag_id: str, poll_interval_secs: int, max_retries: int) -> str: - """Blocks until the first DAG Run for the given DAG ID is created. - - Args: - dag_id: The DAG ID. - poll_interval_secs: The number of seconds to wait between polling for the DAG Run. - max_retries: The maximum number of retries. - - Returns: - The ID of the first DAG Run for the given DAG ID. - """ - - loading_id = self._console_loading_start() - - attempt_num = 1 - - try: - while True: - try: - first_dag_run_id = self.get_first_dag_run_id(dag_id) - if first_dag_run_id is None: - raise SQLMeshError(f"Missing a DAG Run for DAG '{dag_id}'") - return first_dag_run_id - except ApiServerError: - raise - except SQLMeshError: - if attempt_num > max_retries: - raise - - attempt_num += 1 - time.sleep(poll_interval_secs) - finally: - if self._console and loading_id: - self._console.loading_stop(loading_id) - - @abc.abstractmethod - def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]: - """Returns the ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist. - - Args: - dag_id: The DAG ID. - - Returns: - The ID of the first DAG Run for the given DAG ID, or None if no DAG Runs exist. - """ - - @abc.abstractmethod - def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str: - """Returns the state of the given DAG Run. - - Args: - dag_id: The DAG ID. - dag_run_id: The DAG Run ID. - - Returns: - The state of the given DAG Run. - """ - - @abc.abstractmethod - def get_variable(self, key: str) -> t.Optional[str]: - """Returns the value of an Airflow variable with the given key. - - Args: - key: The variable key. 
- - Returns: - The variable value or None if no variable with the given key exists. - """ - - def _console_loading_start(self) -> t.Optional[uuid.UUID]: - if self._console: - return self._console.loading_start() - return None - - -class AirflowClient(BaseAirflowClient): - def __init__( - self, - session: requests.Session, - airflow_url: str, - console: t.Optional[Console] = None, - snapshot_ids_batch_size: t.Optional[int] = None, - ): - super().__init__(airflow_url, console) - self._session = session - self._snapshot_ids_batch_size = snapshot_ids_batch_size - - def apply_plan( - self, - new_snapshots: t.Iterable[Snapshot], - environment: Environment, - request_id: str, - no_gaps: bool = False, - skip_backfill: bool = False, - restatements: t.Optional[t.Dict[SnapshotId, Interval]] = None, - notification_targets: t.Optional[t.List[NotificationTarget]] = None, - backfill_concurrent_tasks: int = 1, - ddl_concurrent_tasks: int = 1, - users: t.Optional[t.List[User]] = None, - is_dev: bool = False, - allow_destructive_snapshots: t.Set[str] = set(), - forward_only: bool = False, - models_to_backfill: t.Optional[t.Set[str]] = None, - end_bounded: bool = False, - ensure_finalized_snapshots: bool = False, - directly_modified_snapshots: t.Optional[t.List[SnapshotId]] = None, - indirectly_modified_snapshots: t.Optional[t.Dict[str, t.List[SnapshotId]]] = None, - removed_snapshots: t.Optional[t.List[SnapshotId]] = None, - execution_time: t.Optional[TimeLike] = None, - ) -> None: - request = common.PlanApplicationRequest( - new_snapshots=list(new_snapshots), - environment=environment, - no_gaps=no_gaps, - skip_backfill=skip_backfill, - request_id=request_id, - restatements={s.name: i for s, i in (restatements or {}).items()}, - notification_targets=notification_targets or [], - backfill_concurrent_tasks=backfill_concurrent_tasks, - ddl_concurrent_tasks=ddl_concurrent_tasks, - users=users or [], - is_dev=is_dev, - allow_destructive_snapshots=allow_destructive_snapshots, - 
forward_only=forward_only, - models_to_backfill=models_to_backfill, - end_bounded=end_bounded, - ensure_finalized_snapshots=ensure_finalized_snapshots, - directly_modified_snapshots=directly_modified_snapshots or [], - indirectly_modified_snapshots=indirectly_modified_snapshots or {}, - removed_snapshots=removed_snapshots or [], - execution_time=execution_time, - ) - - response = self._session.post( - urljoin(self._airflow_url, PLANS_PATH), - data=request.json(), - headers={"Content-Type": "application/json"}, - ) - raise_for_status(response) - - def get_snapshots(self, snapshot_ids: t.Optional[t.List[SnapshotId]]) -> t.List[Snapshot]: - output = [] - - if snapshot_ids is not None: - for ids_batch in _list_to_json( - unique(snapshot_ids), batch_size=self._snapshot_ids_batch_size - ): - output.extend( - common.SnapshotsResponse.parse_obj( - self._get(SNAPSHOTS_PATH, ids=ids_batch) - ).snapshots - ) - return output - - return common.SnapshotsResponse.parse_obj(self._get(SNAPSHOTS_PATH)).snapshots - - def snapshots_exist(self, snapshot_ids: t.List[SnapshotId]) -> t.Set[SnapshotId]: - output = set() - for ids_batch in _list_to_json( - unique(snapshot_ids), batch_size=self._snapshot_ids_batch_size - ): - output |= set( - common.SnapshotIdsResponse.parse_obj( - self._get(SNAPSHOTS_PATH, "check_existence", ids=ids_batch) - ).snapshot_ids - ) - - return output - - def nodes_exist(self, names: t.Iterable[str], exclude_external: bool = False) -> t.Set[str]: - flags = ["exclude_external"] if exclude_external else [] - return set( - common.ExistingModelsResponse.parse_obj( - self._get(MODELS_PATH, *flags, names=",".join(names)) - ).names - ) - - def get_environment(self, environment: str) -> t.Optional[Environment]: - try: - response = self._get(f"{ENVIRONMENTS_PATH}/{environment}") - return Environment.parse_obj(response) - except NotFoundError: - return None - - def get_environments(self) -> t.List[Environment]: - response = self._get(ENVIRONMENTS_PATH) - return 
common.EnvironmentsResponse.parse_obj(response).environments - - def max_interval_end_for_environment( - self, environment: str, ensure_finalized_snapshots: bool - ) -> t.Optional[int]: - flags = ["ensure_finalized_snapshots"] if ensure_finalized_snapshots else [] - response = self._get(f"{ENVIRONMENTS_PATH}/{environment}/max_interval_end", *flags) - return common.IntervalEndResponse.parse_obj(response).max_interval_end - - def greatest_common_interval_end( - self, environment: str, models: t.Collection[str], ensure_finalized_snapshots: bool - ) -> t.Optional[int]: - flags = ["ensure_finalized_snapshots"] if ensure_finalized_snapshots else [] - response = self._get( - f"{ENVIRONMENTS_PATH}/{environment}/greatest_common_interval_end", - *flags, - models=_json_query_param(list(models)), - ) - return common.IntervalEndResponse.parse_obj(response).max_interval_end - - def invalidate_environment(self, environment: str) -> None: - response = self._session.delete( - urljoin(self._airflow_url, f"{ENVIRONMENTS_PATH}/{environment}") - ) - raise_for_status(response) - - def get_versions(self) -> Versions: - return Versions.parse_obj(self._get(VERSIONS_PATH)) - - def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str: - url = f"{DAG_RUN_PATH_TEMPLATE.format(dag_id)}/{dag_run_id}" - return self._get(url)["state"].lower() - - def get_janitor_dag(self) -> t.Dict[str, t.Any]: - return self._get_dag(common.JANITOR_DAG_ID) - - def get_snapshot_dag(self, name: str, version: str) -> t.Dict[str, t.Any]: - return self._get_dag(common.dag_id_for_name_version(name, version)) - - def get_all_dags(self) -> t.Dict[str, t.Any]: - return self._get("api/v1/dags") - - def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]: - dag_runs_response = self._get(f"{DAG_RUN_PATH_TEMPLATE.format(dag_id)}", limit="1") - dag_runs = dag_runs_response["dag_runs"] - if not dag_runs: - return None - return dag_runs[0]["dag_run_id"] - - def get_variable(self, key: str) -> t.Optional[str]: - 
try: - variables_response = self._get(f"api/v1/variables/{key}") - return variables_response["value"] - except NotFoundError: - return None - - def close(self) -> None: - self._session.close() - - def _get_dag(self, dag_id: str) -> t.Dict[str, t.Any]: - return self._get(f"api/v1/dags/{dag_id}") - - def _get(self, path: str, *flags: str, **params: str) -> t.Dict[str, t.Any]: - all_params = [*flags, *([urlencode(params)] if params else [])] - query_string = "&".join(all_params) - if query_string: - path = f"{path}?{query_string}" - response = self._session.get(urljoin(self._airflow_url, path)) - raise_for_status(response) - return response.json() - - -T = t.TypeVar("T", bound=PydanticModel) - - -def _list_to_json(models: t.Collection[T], batch_size: t.Optional[int] = None) -> t.List[str]: - serialized = [m.dict() for m in models] - if batch_size is not None: - batches = [serialized[i : i + batch_size] for i in range(0, len(serialized), batch_size)] - else: - batches = [serialized] - return [_json_query_param(batch) for batch in batches] - - -def _json_query_param(value: t.Any) -> str: - return json.dumps(value, separators=(",", ":")) diff --git a/sqlmesh/schedulers/airflow/common.py b/sqlmesh/schedulers/airflow/common.py deleted file mode 100644 index 89c227913f..0000000000 --- a/sqlmesh/schedulers/airflow/common.py +++ /dev/null @@ -1,137 +0,0 @@ -from __future__ import annotations - -import typing as t - -from sqlmesh.core import constants as c -from sqlmesh.core.environment import Environment -from sqlmesh.core.notification_target import NotificationTarget -from sqlmesh.core.scheduler import Interval -from sqlmesh.core.snapshot import ( - DeployabilityIndex, - Snapshot, - SnapshotId, - SnapshotInfoLike, - SnapshotIntervals, - SnapshotTableInfo, -) -from sqlmesh.core.snapshot.definition import Interval as SnapshotInterval -from sqlmesh.core.user import User -from sqlmesh.utils import sanitize_name -from sqlmesh.utils.date import TimeLike -from 
sqlmesh.utils.pydantic import PydanticModel - -JANITOR_DAG_ID = "sqlmesh_janitor_dag" -JANITOR_TASK_ID = "janitor_task" - -SQLMESH_AIRFLOW_TAG = "sqlmesh" -SNAPSHOT_AIRFLOW_TAG = "sqlmesh_snapshot" -PLAN_AIRFLOW_TAG = "sqlmesh_plan" - -SNAPSHOT_CLEANUP_COMMAND_XCOM_KEY = "snapshot_cleanup_command" - -DEFAULT_CATALOG_VARIABLE_NAME = "sqlmesh_default_catalog" - -AIRFLOW_LOCAL_URL = "http://localhost:8080/" - -SQLMESH_API_BASE_PATH: str = f"{c.SQLMESH}/api/v1" - - -class PlanApplicationRequest(PydanticModel): - request_id: str - new_snapshots: t.List[Snapshot] - environment: Environment - no_gaps: bool - skip_backfill: bool - restatements: t.Dict[str, SnapshotInterval] - notification_targets: t.List[NotificationTarget] - backfill_concurrent_tasks: int - ddl_concurrent_tasks: int - users: t.List[User] - is_dev: bool - allow_destructive_snapshots: t.Set[str] = set() - forward_only: bool - models_to_backfill: t.Optional[t.Set[str]] = None - end_bounded: bool - ensure_finalized_snapshots: bool - directly_modified_snapshots: t.List[SnapshotId] - indirectly_modified_snapshots: t.Dict[str, t.List[SnapshotId]] - removed_snapshots: t.List[SnapshotId] - execution_time: t.Optional[TimeLike] = None - - def is_selected_for_backfill(self, model_fqn: str) -> bool: - return self.models_to_backfill is None or model_fqn in self.models_to_backfill - - -class BackfillIntervalsPerSnapshot(PydanticModel): - snapshot_id: SnapshotId - intervals: t.List[Interval] - before_promote: bool = True - - -class PlanDagSpec(PydanticModel): - request_id: str - environment: Environment - new_snapshots: t.List[Snapshot] - backfill_intervals_per_snapshot: t.List[BackfillIntervalsPerSnapshot] - demoted_snapshots: t.List[SnapshotTableInfo] - unpaused_dt: t.Optional[TimeLike] = None - no_gaps: bool - notification_targets: t.List[NotificationTarget] - backfill_concurrent_tasks: int - ddl_concurrent_tasks: int - users: t.List[User] - is_dev: bool - allow_destructive_snapshots: t.Set[str] - forward_only: 
t.Optional[bool] = None - dag_start_ts: t.Optional[int] = None - deployability_index: DeployabilityIndex = DeployabilityIndex.all_deployable() - deployability_index_for_creation: DeployabilityIndex = DeployabilityIndex.all_deployable() - no_gaps_snapshot_names: t.Optional[t.Set[str]] = None - models_to_backfill: t.Optional[t.Set[str]] = None - ensure_finalized_snapshots: bool = False - directly_modified_snapshots: t.Optional[t.List[SnapshotId]] = None - indirectly_modified_snapshots: t.Optional[t.Dict[str, t.List[SnapshotId]]] = None - removed_snapshots: t.Optional[t.List[SnapshotId]] = None - execution_time: t.Optional[TimeLike] = None - - -class EnvironmentsResponse(PydanticModel): - environments: t.List[Environment] - - -class SnapshotsResponse(PydanticModel): - snapshots: t.List[Snapshot] - - -class SnapshotIntervalsResponse(PydanticModel): - snapshot_intervals: t.List[SnapshotIntervals] - - -class SnapshotIdsResponse(PydanticModel): - snapshot_ids: t.List[SnapshotId] - - -class ExistingModelsResponse(PydanticModel): - names: t.List[str] - - -class InvalidateEnvironmentResponse(PydanticModel): - name: str - - -class IntervalEndResponse(PydanticModel): - environment: str - max_interval_end: t.Optional[int] = None - - -def dag_id_for_snapshot_info(info: SnapshotInfoLike) -> str: - assert info.version - return dag_id_for_name_version(info.name, info.version) - - -def dag_id_for_name_version(name: str, version: str) -> str: - return f"sqlmesh_snapshot_{sanitize_name(name)}_{version}_dag" - - -def plan_application_dag_id(environment: str, request_id: str) -> str: - return f"sqlmesh_plan_application__{environment}__{request_id}" diff --git a/sqlmesh/schedulers/airflow/dag_generator.py b/sqlmesh/schedulers/airflow/dag_generator.py deleted file mode 100644 index 04b6c02494..0000000000 --- a/sqlmesh/schedulers/airflow/dag_generator.py +++ /dev/null @@ -1,715 +0,0 @@ -from __future__ import annotations - -import logging -import os -import typing as t - -import pendulum 
-from airflow import DAG -from airflow.models import BaseOperator -from airflow.operators.python import PythonOperator -from airflow.sensors.base import BaseSensorOperator - -from sqlmesh.core.environment import Environment, EnvironmentNamingInfo -from sqlmesh.core.notification_target import NotificationTarget -from sqlmesh.core.plan import PlanStatus -from sqlmesh.core.snapshot import ( - DeployabilityIndex, - Snapshot, - SnapshotId, - SnapshotIdLike, - SnapshotTableInfo, -) -from sqlmesh.core.state_sync import StateReader -from sqlmesh.schedulers.airflow import common, util -from sqlmesh.schedulers.airflow.operators import targets -from sqlmesh.schedulers.airflow.operators.sensor import ( - ExternalSensor, - HighWaterMarkSensor, -) -from sqlmesh.schedulers.airflow.operators.notification import ( - BaseNotificationOperatorProvider, -) -from sqlmesh.utils import sanitize_name -from sqlmesh.utils.date import TimeLike, to_datetime, yesterday_timestamp -from sqlmesh.utils.errors import SQLMeshError - -try: - from airflow.operators.empty import EmptyOperator -except ImportError: - from airflow.operators.dummy import DummyOperator as EmptyOperator # type: ignore - -logger = logging.getLogger(__name__) - - -TASK_ID_DATE_FORMAT = "%Y-%m-%d_%H-%M-%S" - -NOTIFICATION_TARGET_TO_OPERATOR_PROVIDER: t.Dict[ - t.Type[NotificationTarget], BaseNotificationOperatorProvider -] = {} - -DAG_DEFAULT_ARGS = { - # `AIRFLOW__CORE__DEFAULT_TASK_RETRY_DELAY` support added in 2.4.0 - # We can't use `AIRFLOW__CORE__DEFAULT_TASK_RETRY_DELAY` because cloud composer doesn't allow you to set config - # from an environment variable - "retry_delay": int( - os.getenv( - "SQLMESH_AIRFLOW_DEFAULT_TASK_RETRY_DELAY", - os.getenv("AIRFLOW__CORE__DEFAULT_TASK_RETRY_DELAY", "300"), - ) - ), -} - -AIRFLOW_TAG_CHARACTER_LIMIT = 100 - - -class SnapshotDagGenerator: - def __init__( - self, - engine_operator: t.Type[BaseOperator], - engine_operator_args: t.Optional[t.Dict[str, t.Any]], - ddl_engine_operator: 
t.Type[BaseOperator], - ddl_engine_operator_args: t.Optional[t.Dict[str, t.Any]], - external_table_sensor_factory: t.Optional[ - t.Callable[[t.Dict[str, t.Any]], BaseSensorOperator] - ], - sensor_mode: str, - high_water_mark_sensor_args: t.Optional[t.Dict[str, t.Any]], - external_sensor_args: t.Optional[t.Dict[str, t.Any]], - state_reader: StateReader, - ): - self._engine_operator = engine_operator - self._engine_operator_args = engine_operator_args or {} - self._ddl_engine_operator = ddl_engine_operator - self._ddl_engine_operator_args = ddl_engine_operator_args or {} - self._external_table_sensor_factory = external_table_sensor_factory - self._state_reader = state_reader - self._sensor_mode = sensor_mode - self._high_water_mark_sensor_args = high_water_mark_sensor_args or {} - self._external_sensor_args = external_sensor_args or {} - - def generate_cadence_dags(self, snapshots: t.Iterable[SnapshotIdLike]) -> t.List[DAG]: - dags = [] - snapshots = self._state_reader.get_snapshots(snapshots) - for snapshot in snapshots.values(): - if snapshot.unpaused_ts and not snapshot.is_symbolic and not snapshot.is_seed: - dags.append(self._create_cadence_dag_for_snapshot(snapshot, snapshots)) - return dags - - def generate_plan_application_dag(self, spec: common.PlanDagSpec) -> t.Optional[DAG]: - try: - return self._create_plan_application_dag(spec) - except Exception: - logger.exception("Failed to generate the plan application DAG '%s'", spec.request_id) - return None - - def _create_cadence_dag_for_snapshot( - self, snapshot: Snapshot, snapshots: t.Dict[SnapshotId, Snapshot] - ) -> DAG: - dag_id = common.dag_id_for_snapshot_info(snapshot.table_info) - logger.info( - "Generating the cadence DAG '%s' for snapshot %s", - dag_id, - snapshot.snapshot_id, - ) - - if not snapshot.unpaused_ts: - raise SQLMeshError( - f"Can't create a cadence DAG for the paused snapshot {snapshot.snapshot_id}" - ) - - end_date = None - if snapshot.node.end: - end_date = 
pendulum.instance(to_datetime(snapshot.node.end)) - - with DAG( - dag_id=dag_id, - schedule_interval=snapshot.node.cron, - start_date=pendulum.instance(to_datetime(snapshot.unpaused_ts)), - end_date=end_date, - max_active_runs=1, - catchup=True, - is_paused_upon_creation=False, - tags=[ - common.SQLMESH_AIRFLOW_TAG, - common.SNAPSHOT_AIRFLOW_TAG, - snapshot.node.name[-AIRFLOW_TAG_CHARACTER_LIMIT:], - ], - default_args={ - **DAG_DEFAULT_ARGS, - "email": snapshot.node.owner, - "email_on_failure": True, - }, - ) as dag: - hwm_sensor_tasks = self._create_hwm_sensors(snapshot, snapshots) - - evaluator_task = self._create_snapshot_evaluation_operator( - snapshots=snapshots, - snapshot=snapshot, - task_id="snapshot_evaluator", - ) - - hwm_sensor_tasks >> evaluator_task - - return dag - - def _create_plan_application_dag(self, plan_dag_spec: common.PlanDagSpec) -> DAG: - dag_id = common.plan_application_dag_id( - plan_dag_spec.environment.name, plan_dag_spec.request_id - ) - logger.info( - "Generating the plan application DAG '%s' for environment '%s'", - dag_id, - plan_dag_spec.environment.name, - ) - - all_snapshots = { - **{s.snapshot_id: s for s in plan_dag_spec.new_snapshots}, - **self._state_reader.get_snapshots(plan_dag_spec.environment.snapshots), - } - - snapshots_to_create = [ - all_snapshots[snapshot.snapshot_id] - for snapshot in plan_dag_spec.environment.snapshots - if snapshot.snapshot_id in all_snapshots - and ( - plan_dag_spec.models_to_backfill is None - or snapshot.name in plan_dag_spec.models_to_backfill - ) - ] - - with DAG( - dag_id=dag_id, - schedule_interval="@once", - start_date=pendulum.instance( - to_datetime(plan_dag_spec.dag_start_ts or yesterday_timestamp()) - ), - max_active_tasks=plan_dag_spec.backfill_concurrent_tasks, - catchup=False, - is_paused_upon_creation=False, - default_args=DAG_DEFAULT_ARGS, - tags=[ - common.SQLMESH_AIRFLOW_TAG, - common.PLAN_AIRFLOW_TAG, - plan_dag_spec.environment.name, - ], - ) as dag: - start_task = 
EmptyOperator(task_id="plan_application_start") - end_task = EmptyOperator(task_id="plan_application_end") - - (create_start_task, create_end_task) = self._create_creation_tasks( - snapshots_to_create, - plan_dag_spec.new_snapshots, - plan_dag_spec.ddl_concurrent_tasks, - plan_dag_spec.deployability_index_for_creation, - plan_dag_spec.allow_destructive_snapshots, - plan_dag_spec.request_id, - ) - - ( - backfill_before_promote_start_task, - backfill_before_promote_end_task, - ) = self._create_backfill_tasks( - [i for i in plan_dag_spec.backfill_intervals_per_snapshot if i.before_promote], - all_snapshots, - plan_dag_spec.deployability_index, - plan_dag_spec.environment.plan_id, - "before_promote", - plan_dag_spec.execution_time, - ) - - ( - backfill_after_promote_start_task, - backfill_after_promote_end_task, - ) = self._create_backfill_tasks( - [i for i in plan_dag_spec.backfill_intervals_per_snapshot if not i.before_promote], - all_snapshots, - plan_dag_spec.deployability_index, - plan_dag_spec.environment.plan_id, - "after_promote", - plan_dag_spec.execution_time, - ) - - ( - promote_start_task, - promote_end_task, - ) = self._create_promotion_demotion_tasks(plan_dag_spec, all_snapshots) - - start_task >> create_start_task - create_end_task >> backfill_before_promote_start_task - backfill_before_promote_end_task >> promote_start_task - - update_views_task_pair = self._create_update_views_tasks(plan_dag_spec, all_snapshots) - if update_views_task_pair: - backfill_after_promote_end_task >> update_views_task_pair[0] - before_finalize_task = update_views_task_pair[1] - else: - before_finalize_task = backfill_after_promote_end_task - - unpause_snapshots_task = self._create_unpause_snapshots_task(plan_dag_spec) - if unpause_snapshots_task: - if not plan_dag_spec.ensure_finalized_snapshots: - # Only unpause right after updatign the environment record if we don't - # have to use the finalized snapshots for subsequent plan applications. 
- promote_end_task >> unpause_snapshots_task - unpause_snapshots_task >> backfill_after_promote_start_task - else: - # Otherwise, unpause right before finalizing the environment. - promote_end_task >> backfill_after_promote_start_task - before_finalize_task >> unpause_snapshots_task - before_finalize_task = unpause_snapshots_task - else: - promote_end_task >> backfill_after_promote_start_task - - finalize_task = self._create_finalize_task(plan_dag_spec.environment) - before_finalize_task >> finalize_task - finalize_task >> end_task - - on_plan_apply_end_task = PythonOperator( - task_id="on_plan_apply_end", - python_callable=on_plan_apply_end, - op_kwargs={"plan_id": plan_dag_spec.environment.plan_id}, - trigger_rule="all_done", - ) - finalize_task >> on_plan_apply_end_task - - self._add_notification_target_tasks(plan_dag_spec, start_task, finalize_task) - return dag - - def _add_notification_target_tasks( - self, - request: common.PlanDagSpec, - start_task: BaseOperator, - end_task: BaseOperator, - ) -> None: - for notification_target in request.notification_targets: - notification_operator_provider = NOTIFICATION_TARGET_TO_OPERATOR_PROVIDER.get( - type(notification_target) - ) - if not notification_operator_provider: - continue - plan_start_notification_task = notification_operator_provider.operator( - notification_target, PlanStatus.STARTED, request - ) - plan_success_notification_task = notification_operator_provider.operator( - notification_target, PlanStatus.FINISHED, request - ) - plan_failed_notification_task = notification_operator_provider.operator( - notification_target, PlanStatus.FAILED, request - ) - if plan_start_notification_task: - start_task >> plan_start_notification_task - if plan_success_notification_task: - end_task >> plan_success_notification_task - if plan_failed_notification_task: - end_task >> plan_failed_notification_task - - def _create_creation_tasks( - self, - snapshots_to_create: t.List[Snapshot], - new_snapshots: t.List[Snapshot], - 
ddl_concurrent_tasks: int, - deployability_index: DeployabilityIndex, - allow_destructive_snapshots: t.Set[str], - request_id: str, - ) -> t.Tuple[BaseOperator, BaseOperator]: - start_task = EmptyOperator(task_id="snapshot_creation_start") - end_task = EmptyOperator(task_id="snapshot_creation_end", trigger_rule="none_failed") - - current_task: BaseOperator = start_task - - if snapshots_to_create: - creation_task = self._create_snapshot_create_tables_operator( - snapshots_to_create, - ddl_concurrent_tasks, - deployability_index, - allow_destructive_snapshots, - "snapshot_creation__create_tables", - ) - current_task >> creation_task - current_task = creation_task - - if new_snapshots: - update_state_task = PythonOperator( - task_id="snapshot_creation__update_state", - python_callable=creation_update_state_task, - op_kwargs={"new_snapshots": new_snapshots, "request_id": request_id}, - ) - current_task >> update_state_task - current_task = update_state_task - - current_task >> end_task - - return (start_task, end_task) - - def _create_promotion_demotion_tasks( - self, - request: common.PlanDagSpec, - snapshots: t.Dict[SnapshotId, Snapshot], - ) -> t.Tuple[BaseOperator, BaseOperator]: - update_state_task = PythonOperator( - task_id="snapshot_promotion_update_state", - python_callable=promotion_update_state_task, - op_kwargs={ - "environment": request.environment, - "no_gaps_snapshot_names": ( - request.no_gaps_snapshot_names if request.no_gaps else set() - ), - }, - ) - - start_task = update_state_task - end_task: BaseOperator = update_state_task - - if request.environment.promoted_snapshots and not request.is_dev and request.unpaused_dt: - migrate_tables_task = self._create_snapshot_migrate_tables_operator( - [ - snapshots[s.snapshot_id] - for s in request.environment.promoted_snapshots - if snapshots[s.snapshot_id].is_paused - ], - request.ddl_concurrent_tasks, - request.allow_destructive_snapshots, - "snapshot_promotion_migrate_tables", - ) - update_state_task >> 
migrate_tables_task - end_task = migrate_tables_task - - return (start_task, end_task) - - def _create_unpause_snapshots_task( - self, request: common.PlanDagSpec - ) -> t.Optional[BaseOperator]: - if request.is_dev or not request.unpaused_dt: - return None - return PythonOperator( - task_id="snapshot_promotion_unpause_snapshots", - python_callable=promotion_unpause_snapshots_task, - op_kwargs={ - "environment": request.environment, - "unpaused_dt": request.unpaused_dt, - }, - trigger_rule="none_failed", - ) - - def _create_update_views_tasks( - self, request: common.PlanDagSpec, snapshots: t.Dict[SnapshotId, Snapshot] - ) -> t.Optional[t.Tuple[BaseOperator, BaseOperator]]: - create_views_task = None - delete_views_task = None - - environment_naming_info = request.environment.naming_info - - if request.environment.promoted_snapshots: - create_views_task = self._create_snapshot_promotion_operator( - [snapshots[x.snapshot_id] for x in request.environment.promoted_snapshots], - environment_naming_info, - request.ddl_concurrent_tasks, - request.deployability_index, - "snapshot_promotion_create_views", - ) - - if request.demoted_snapshots: - delete_views_task = self._create_snapshot_demotion_operator( - request.demoted_snapshots, - environment_naming_info, - request.ddl_concurrent_tasks, - "snapshot_promotion_delete_views", - ) - - if create_views_task and delete_views_task: - create_views_task >> delete_views_task - return create_views_task, delete_views_task - if create_views_task: - return create_views_task, create_views_task - if delete_views_task: - return delete_views_task, delete_views_task - return None - - def _create_finalize_task(self, environment: Environment) -> BaseOperator: - return PythonOperator( - task_id="snapshot_promotion_finalize", - python_callable=promotion_finalize_task, - op_kwargs={"environment": environment}, - ) - - def _create_backfill_tasks( - self, - backfill_intervals: t.List[common.BackfillIntervalsPerSnapshot], - snapshots: 
t.Dict[SnapshotId, Snapshot], - deployability_index: DeployabilityIndex, - plan_id: str, - task_id_suffix: str, - execution_time: t.Optional[TimeLike], - ) -> t.Tuple[BaseOperator, BaseOperator]: - snapshot_to_tasks = {} - for intervals_per_snapshot in backfill_intervals: - sid = intervals_per_snapshot.snapshot_id - - if not intervals_per_snapshot.intervals: - logger.info("Skipping backfill for snapshot %s", sid) - continue - - snapshot = snapshots[sid] - sanitized_model_name = sanitize_name(snapshot.node.name) - - snapshot_task_pairs: t.List[t.Tuple[BaseOperator, BaseOperator]] = [] - - snapshot_start_task = EmptyOperator( - task_id=f"snapshot_backfill__{sanitized_model_name}__{snapshot.identifier}__start" - ) - snapshot_end_task = EmptyOperator( - task_id=f"snapshot_backfill__{sanitized_model_name}__{snapshot.identifier}__end" - ) - - task_id_prefix = f"snapshot_backfill__{sanitized_model_name}__{snapshot.identifier}" - for batch_idx, (start, end) in enumerate(intervals_per_snapshot.intervals): - evaluation_task = self._create_snapshot_evaluation_operator( - snapshots=snapshots, - snapshot=snapshot, - task_id=f"{task_id_prefix}__{start.strftime(TASK_ID_DATE_FORMAT)}__{end.strftime(TASK_ID_DATE_FORMAT)}", - start=start, - end=end, - deployability_index=deployability_index, - plan_id=plan_id, - execution_time=execution_time, - batch_index=batch_idx, - ) - external_sensor_task = self._create_external_sensor(snapshot, start=start, end=end) - if external_sensor_task: - ( - snapshot_start_task - >> external_sensor_task - >> evaluation_task - >> snapshot_end_task - ) - snapshot_task_pairs.append((external_sensor_task, evaluation_task)) - else: - snapshot_start_task >> evaluation_task >> snapshot_end_task - snapshot_task_pairs.append((evaluation_task, evaluation_task)) - - batch_concurrency = snapshot.node.batch_concurrency - if snapshot.depends_on_past: - batch_concurrency = 1 - - if not intervals_per_snapshot.intervals: - snapshot_start_task >> snapshot_end_task - elif 
batch_concurrency: - for i in range(batch_concurrency, len(snapshot_task_pairs)): - snapshot_task_pairs[i - batch_concurrency][1] >> snapshot_task_pairs[i][0] - - snapshot_to_tasks[snapshot.snapshot_id] = ( - snapshot_start_task, - snapshot_end_task, - ) - - backfill_start_task = EmptyOperator(task_id=f"snapshot_backfill_{task_id_suffix}_start") - backfill_end_task = EmptyOperator(task_id=f"snapshot_backfill_{task_id_suffix}_end") - - if not snapshot_to_tasks: - backfill_start_task >> backfill_end_task - return (backfill_start_task, backfill_end_task) - - entry_tasks = [] - parent_ids_to_backfill = set() - for sid, (start_task, _) in snapshot_to_tasks.items(): - has_parents_to_backfill = False - for p_sid in snapshots[sid].parents: - if p_sid in snapshot_to_tasks: - snapshot_to_tasks[p_sid][1] >> start_task - parent_ids_to_backfill.add(p_sid) - has_parents_to_backfill = True - - if not has_parents_to_backfill: - entry_tasks.append(start_task) - - backfill_start_task >> entry_tasks - - exit_tasks = [ - end_task - for sid, (_, end_task) in snapshot_to_tasks.items() - if sid not in parent_ids_to_backfill - ] - for task in exit_tasks: - task >> backfill_end_task - - return (backfill_start_task, backfill_end_task) - - def _create_snapshot_promotion_operator( - self, - snapshots: t.List[Snapshot], - environment_naming_info: EnvironmentNamingInfo, - ddl_concurrent_tasks: int, - deployability_index: DeployabilityIndex, - task_id: str, - ) -> BaseOperator: - return self._ddl_engine_operator( - **self._ddl_engine_operator_args, - target=targets.SnapshotPromotionTarget( - snapshots=snapshots, - environment_naming_info=environment_naming_info, - ddl_concurrent_tasks=ddl_concurrent_tasks, - deployability_index=deployability_index, - ), - task_id=task_id, - ) - - def _create_snapshot_demotion_operator( - self, - snapshots: t.List[SnapshotTableInfo], - environment_naming_info: EnvironmentNamingInfo, - ddl_concurrent_tasks: int, - task_id: str, - ) -> BaseOperator: - return 
self._ddl_engine_operator( - **self._ddl_engine_operator_args, - target=targets.SnapshotDemotionTarget( - snapshots=snapshots, - environment_naming_info=environment_naming_info, - ddl_concurrent_tasks=ddl_concurrent_tasks, - ), - task_id=task_id, - ) - - def _create_snapshot_create_tables_operator( - self, - new_snapshots: t.List[Snapshot], - ddl_concurrent_tasks: int, - deployability_index: DeployabilityIndex, - allow_destructive_snapshots: t.Set[str], - task_id: str, - ) -> BaseOperator: - return self._ddl_engine_operator( - **self._ddl_engine_operator_args, - target=targets.SnapshotCreateTablesTarget( - new_snapshots=new_snapshots, - ddl_concurrent_tasks=ddl_concurrent_tasks, - deployability_index=deployability_index, - allow_destructive_snapshots=allow_destructive_snapshots, - ), - task_id=task_id, - ) - - def _create_snapshot_migrate_tables_operator( - self, - snapshots: t.List[Snapshot], - ddl_concurrent_tasks: int, - allow_destructive_snapshots: t.Set[str], - task_id: str, - ) -> BaseOperator: - return self._ddl_engine_operator( - **self._ddl_engine_operator_args, - target=targets.SnapshotMigrateTablesTarget( - snapshots=snapshots, - ddl_concurrent_tasks=ddl_concurrent_tasks, - allow_destructive_snapshots=allow_destructive_snapshots, - ), - task_id=task_id, - ) - - def _create_snapshot_evaluation_operator( - self, - snapshots: t.Dict[SnapshotId, Snapshot], - snapshot: Snapshot, - task_id: str, - start: t.Optional[TimeLike] = None, - end: t.Optional[TimeLike] = None, - execution_time: t.Optional[TimeLike] = None, - deployability_index: t.Optional[DeployabilityIndex] = None, - plan_id: t.Optional[str] = None, - batch_index: int = 0, - ) -> BaseOperator: - parent_snapshots = {snapshots[sid].name: snapshots[sid] for sid in snapshot.parents} - - return self._engine_operator( - **self._engine_operator_args, - target=targets.SnapshotEvaluationTarget( - snapshot=snapshot, - parent_snapshots=parent_snapshots, - start=start, - end=end, - 
deployability_index=deployability_index or DeployabilityIndex.all_deployable(), - plan_id=plan_id, - execution_time=execution_time, - batch_index=batch_index, - ), - task_id=task_id, - ) - - def _create_hwm_sensors( - self, snapshot: Snapshot, snapshots: t.Dict[SnapshotId, Snapshot] - ) -> t.List[BaseSensorOperator]: - output: t.List[BaseSensorOperator] = [] - for upstream_snapshot_id in snapshot.parents: - upstream_snapshot = snapshots[upstream_snapshot_id] - if not upstream_snapshot.is_symbolic and not upstream_snapshot.is_seed: - output.append( - HighWaterMarkSensor( - target_snapshot_info=upstream_snapshot.table_info, - this_snapshot=snapshot, - task_id=f"{sanitize_name(upstream_snapshot.node.name)}_{upstream_snapshot.version}_high_water_mark_sensor", - mode=self._sensor_mode, - **self._high_water_mark_sensor_args, - ) - ) - - external_sensor = self._create_external_sensor(snapshot) - if external_sensor: - output.append(external_sensor) - - return output - - def _create_external_sensor( - self, - snapshot: Snapshot, - start: t.Optional[TimeLike] = None, - end: t.Optional[TimeLike] = None, - ) -> t.Optional[BaseSensorOperator]: - if self._external_table_sensor_factory and snapshot.model.signals: - return ExternalSensor( - snapshot=snapshot, - external_table_sensor_factory=self._external_table_sensor_factory, - task_id="external_high_water_mark_sensor", - mode=self._sensor_mode, - start=start, - end=end, - **self._external_sensor_args, - ) - return None - - -def creation_update_state_task(new_snapshots: t.Collection[Snapshot], request_id: str) -> None: - with util.scoped_state_sync() as state_sync: - state_sync.push_snapshots(new_snapshots) - - from sqlmesh.core.analytics import collector - - collector.on_snapshots_created(new_snapshots=new_snapshots, plan_id=request_id) - - -def promotion_update_state_task( - environment: Environment, - no_gaps_snapshot_names: t.Optional[t.Set[str]], -) -> None: - with util.scoped_state_sync() as state_sync: - 
state_sync.promote(environment, no_gaps_snapshot_names=no_gaps_snapshot_names) - - -def promotion_unpause_snapshots_task( - environment: Environment, - unpaused_dt: t.Optional[TimeLike], -) -> None: - if environment.snapshots and unpaused_dt: - with util.scoped_state_sync() as state_sync: - state_sync.unpause_snapshots(environment.snapshots, unpaused_dt) - - -def promotion_finalize_task(environment: Environment) -> None: - with util.scoped_state_sync() as state_sync: - state_sync.finalize(environment) - - -def on_plan_apply_end(plan_id: str) -> None: - from sqlmesh.core.analytics import collector - - collector.on_plan_apply_end(plan_id=plan_id) diff --git a/sqlmesh/schedulers/airflow/hooks/bigquery.py b/sqlmesh/schedulers/airflow/hooks/bigquery.py deleted file mode 100644 index a12a7d8b2e..0000000000 --- a/sqlmesh/schedulers/airflow/hooks/bigquery.py +++ /dev/null @@ -1,61 +0,0 @@ -from __future__ import annotations - -import typing as t - -from airflow.providers.common.sql.hooks.sql import DbApiHook -from airflow.providers.google.common.hooks.base_google import GoogleBaseHook - -if t.TYPE_CHECKING: - from google.cloud.bigquery.dbapi import Connection - - -class SQLMeshBigQueryHook(GoogleBaseHook, DbApiHook): - """ - Interact with BigQuery. This hook uses the Google Cloud connection. We didn't use the Airflow BigQueryHook - because it implements an Airflow specific version of the BigQuery DB API that is different then the DB API - provided from Google's python package. - - :param gcp_conn_id: The Airflow connection used for GCP credentials. - :param delegate_to: This performs a task on one host with reference to other hosts. - :param impersonation_chain: This is the optional service account to impersonate using short term - credentials. 
- """ - - conn_name_attr = "sqlmesh_gcp_conn_id" - default_conn_name = "sqlmesh_google_cloud_bigquery_default" - conn_type = "sqlmeshgcpbigquery" - hook_name = "SQLMesh Google Bigquery" - - def __init__( - self, - gcp_conn_id: str = default_conn_name, - delegate_to: t.Optional[str] = None, - impersonation_chain: t.Optional[t.Union[str, t.Sequence[str]]] = None, - location: t.Optional[str] = None, - ) -> None: - GoogleBaseHook.__init__( - self, - gcp_conn_id=gcp_conn_id, - delegate_to=delegate_to, - impersonation_chain=impersonation_chain, - ) - self.location = location - - def get_conn(self) -> Connection: - """Returns a BigQuery DBAPI connection object.""" - from google.api_core.client_info import ClientInfo - from google.cloud.bigquery import Client - from google.cloud.bigquery.dbapi import Connection - - # This method is private in older versions of the BigQuery library and public later. So we check for both - try: - creds, project_id = self._get_credentials_and_project_id() # type: ignore - except AttributeError: - creds, project_id = self.get_credentials_and_project_id() # type: ignore - client = Client( - project=project_id, - credentials=creds, - location=self.location, - client_info=ClientInfo(user_agent="sqlmesh"), - ) - return Connection(client=client) diff --git a/sqlmesh/schedulers/airflow/hooks/redshift.py b/sqlmesh/schedulers/airflow/hooks/redshift.py deleted file mode 100644 index ba52071ba8..0000000000 --- a/sqlmesh/schedulers/airflow/hooks/redshift.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - -import typing as t - -import redshift_connector -from airflow.providers.common.sql.hooks.sql import DbApiHook - - -class SQLMeshRedshiftHook(DbApiHook): - """ - Uses the Redshift Python DB API connector. 
- """ - - conn_name_attr = "sqlmesh_redshift_conn_id" - default_conn_name = "sqlmesh_redshift_default" - conn_type = "sqlmesh_redshift" - hook_name = "SQLMesh Redshift" - connector = redshift_connector - - def get_conn(self) -> redshift_connector.Connection: - """Returns a Redshift connection object""" - db = self.get_connection(getattr(self, t.cast(str, self.conn_name_attr))) - - return self.connector.connect( - host=db.host, - port=db.port, - user=db.login, - password=db.password, - database=db.schema, - **db.extra_dejson, - ) diff --git a/sqlmesh/schedulers/airflow/integration.py b/sqlmesh/schedulers/airflow/integration.py deleted file mode 100644 index 06f5db2bbd..0000000000 --- a/sqlmesh/schedulers/airflow/integration.py +++ /dev/null @@ -1,250 +0,0 @@ -from __future__ import annotations - -import logging -import typing as t -from datetime import datetime, timedelta - -from airflow import DAG -from airflow.models import BaseOperator, TaskInstance, Variable -from airflow.operators.python import PythonOperator -from airflow.sensors.base import BaseSensorOperator -from airflow.utils.session import provide_session -from sqlalchemy.orm import Session - -from sqlmesh.core import constants as c -from sqlmesh.core.state_sync import StateReader -from sqlmesh.engines import commands -from sqlmesh.schedulers.airflow import common, util -from sqlmesh.schedulers.airflow.dag_generator import SnapshotDagGenerator -from sqlmesh.schedulers.airflow.operators import targets -from sqlmesh.schedulers.airflow.plan import PlanDagState - -if t.TYPE_CHECKING: - pass - -logger = logging.getLogger(__name__) - - -class SQLMeshAirflow: - """The entry point for the SQLMesh integration with Airflow. - - The instance of this class should be created in a module that is part of the - Airflow DAGs folder. Its primary purpose is to create DAG objects for the operational - needs of the platform, as well as for model evaluation and backfills. 
- - Please note that the user must pass created DAGs into the - Airflow scheduler. See the example below: - - Example: - Create a new python module in the Airflow DAGs folder called "sqlmesh_integration.py" - with the following content: - - from sqlmesh.schedulers.airflow.integration import SQLMeshAirflow - - for dag in SQLMeshAirflow("spark").dags: - globals()[dag.dag_id] = dag - - Args: - engine_operator: The type of the Airflow operator that will be used for model evaluation. - If a string value is passed, an automatic operator discovery is attempted based - on the engine name specified in the string. - default_catalog: The default catalog to use when models are defined that do not contain a catalog in their name. This should match the default catalog applied by the connection. - engine_operator_args: The dictionary of arguments that will be passed into the evaluate engine - operator during its construction. - This can be used to customize parameters such as connection ID. - ddl_engine_operator: The type of the Airflow operator that will be used for environment management. - These operations are SQL only. - If a string value is passed, an automatic operator discovery is attempted based - on the engine name specified in the string. - ddl_engine_operator_args: Args to be passed into just the environment management operator. - This can be used to customize parameters such as connection ID. - If not specified, and the operator is the same as `engine_operator`, falls back to using `engine_operator_args`. - janitor_interval: Defines how often the janitor DAG runs. - The janitor DAG removes platform-managed DAG instances that are pending - deletion from Airflow. Default: 1 hour. - plan_application_dag_ttl: Determines the time-to-live period for finished plan application DAGs. - Once this period is exceeded, finished plan application DAGs are deleted by the janitor. Default: 2 days. 
- external_table_sensor_factory: A factory function that creates a sensor operator for a given signal payload. - sensor_mode: The mode to use for SQLMesh sensors. Supported values are "poke" and "reschedule". - See https://airflow.apache.org/docs/apache-airflow/stable/core-concepts/sensors.html for more details. Default: "reschedule". - high_water_mark_sensor_args: The dictionary of arguments that will be passed into the high water mark sensor operator during its construction. - external_sensor_args: The dictionary of arguments that will be passed into the external sensor operator during its construction. - generate_cadence_dags: Whether to generate cadence DAGs for model versions that are currently deployed to production. - """ - - def __init__( - self, - engine_operator: t.Union[str, t.Type[BaseOperator]], - default_catalog: str, - engine_operator_args: t.Optional[t.Dict[str, t.Any]] = None, - ddl_engine_operator: t.Optional[t.Union[str, t.Type[BaseOperator]]] = None, - ddl_engine_operator_args: t.Optional[t.Dict[str, t.Any]] = None, - janitor_interval: timedelta = timedelta(hours=1), - plan_application_dag_ttl: timedelta = timedelta(days=2), - external_table_sensor_factory: t.Optional[ - t.Callable[[t.Dict[str, t.Any]], BaseSensorOperator] - ] = None, - sensor_mode: str = "reschedule", - high_water_mark_sensor_args: t.Optional[t.Dict[str, t.Any]] = None, - external_sensor_args: t.Optional[t.Dict[str, t.Any]] = None, - generate_cadence_dags: bool = True, - ): - if isinstance(engine_operator, str): - if not ddl_engine_operator: - ddl_engine_operator = util.discover_engine_operator(engine_operator, sql_only=True) - engine_operator = util.discover_engine_operator(engine_operator, sql_only=False) - - if isinstance(ddl_engine_operator, str): - ddl_engine_operator = util.discover_engine_operator(ddl_engine_operator, sql_only=True) - - engine_operator_args = engine_operator_args or {} - ddl_engine_operator_args = ddl_engine_operator_args or {} - - self._engine_operator 
= engine_operator - self._engine_operator_args = engine_operator_args - self._ddl_engine_operator = ddl_engine_operator or engine_operator - if self._engine_operator == self._ddl_engine_operator: - self._ddl_engine_operator_args = {**engine_operator_args, **ddl_engine_operator_args} - else: - self._ddl_engine_operator_args = ddl_engine_operator_args or {} - self._janitor_interval = janitor_interval - self._plan_application_dag_ttl = plan_application_dag_ttl - self._external_table_sensor_factory = external_table_sensor_factory - self._generate_cadence_dags = generate_cadence_dags - self._default_catalog = default_catalog - self._sensor_mode = sensor_mode - self._high_water_mark_sensor_args = high_water_mark_sensor_args or {} - self._external_sensor_args = external_sensor_args or {} - - @classmethod - def set_default_catalog(cls, default_catalog: str) -> None: - current_value = Variable.get(common.DEFAULT_CATALOG_VARIABLE_NAME, default_var=None) - if not current_value: - Variable.set(common.DEFAULT_CATALOG_VARIABLE_NAME, default_catalog) - if current_value != default_catalog: - Variable.update(common.DEFAULT_CATALOG_VARIABLE_NAME, default_catalog) - - @property - def dags(self) -> t.List[DAG]: - """Returns all DAG instances that must be registered with the Airflow scheduler - for the integration to work. - - Returns: - The list of DAG instances managed by the platform. 
- """ - self.set_default_catalog(self._default_catalog) - with util.scoped_state_sync() as state_sync: - dag_generator = self._create_dag_generator(state_sync) - - if self._generate_cadence_dags: - prod_env = state_sync.get_environment(c.PROD) - cadence_dags = ( - dag_generator.generate_cadence_dags(prod_env.snapshots) if prod_env else [] - ) - _delete_orphaned_snapshot_dags({d.dag_id for d in cadence_dags}) - else: - cadence_dags = [] - - plan_dag_specs = PlanDagState.from_state_sync(state_sync).get_dag_specs() - plan_application_dags = [ - dag_generator.generate_plan_application_dag(s) for s in plan_dag_specs - ] - - system_dags = [ - self._create_janitor_dag(), - ] - - return system_dags + cadence_dags + [d for d in plan_application_dags if d] - - def _create_janitor_dag(self) -> DAG: - dag = self._create_system_dag(common.JANITOR_DAG_ID, self._janitor_interval) - janitor_task_op = PythonOperator( - task_id=common.JANITOR_TASK_ID, - python_callable=_janitor_task, - op_kwargs={"plan_application_dag_ttl": self._plan_application_dag_ttl}, - dag=dag, - ) - - table_cleanup_task_op = self._ddl_engine_operator( - **self._ddl_engine_operator_args, - target=targets.SnapshotCleanupTarget(), - task_id="snapshot_table_cleanup_task", - dag=dag, - ) - - janitor_task_op >> table_cleanup_task_op - - return dag - - def _create_system_dag(self, dag_id: str, schedule_interval: t.Optional[timedelta]) -> DAG: - return DAG( - dag_id=dag_id, - default_args=dict( - execution_timeout=timedelta(minutes=10), - retries=0, - ), - schedule_interval=schedule_interval, - start_date=datetime(2023, 1, 1), - max_active_runs=1, - catchup=False, - is_paused_upon_creation=False, - tags=[common.SQLMESH_AIRFLOW_TAG], - ) - - def _create_dag_generator(self, state_reader: StateReader) -> SnapshotDagGenerator: - return SnapshotDagGenerator( - self._engine_operator, - self._engine_operator_args, - self._ddl_engine_operator, - self._ddl_engine_operator_args, - self._external_table_sensor_factory, - 
self._sensor_mode, - self._high_water_mark_sensor_args, - self._external_sensor_args, - state_reader, - ) - - -@provide_session -def _janitor_task( - plan_application_dag_ttl: timedelta, - ti: TaskInstance, - session: Session = util.PROVIDED_SESSION, -) -> None: - with util.scoped_state_sync() as state_sync: - expired_environments = state_sync.delete_expired_environments() - expired_snapshots = state_sync.delete_expired_snapshots() - ti.xcom_push( - key=common.SNAPSHOT_CLEANUP_COMMAND_XCOM_KEY, - value=commands.CleanupCommandPayload( - environments=expired_environments, - tasks=expired_snapshots, - ).json(), - session=session, - ) - - prod_env = state_sync.get_environment(c.PROD) - if prod_env: - active_snapshot_dag_ids = { - common.dag_id_for_snapshot_info(s) for s in prod_env.snapshots - } - _delete_orphaned_snapshot_dags(active_snapshot_dag_ids, session=session) - - plan_application_dag_ids = util.get_finished_plan_application_dag_ids( - ttl=plan_application_dag_ttl, session=session - ) - logger.info("Deleting expired Plan Application DAGs: %s", plan_application_dag_ids) - PlanDagState.from_state_sync(state_sync).delete_dag_specs(plan_application_dag_ids) - util.delete_dags(plan_application_dag_ids, session=session) - - state_sync.compact_intervals() - - -@provide_session -def _delete_orphaned_snapshot_dags( - active_snapshot_dag_ids: t.Set[str], session: Session = util.PROVIDED_SESSION -) -> None: - all_snapshot_dag_ids = set(util.get_snapshot_dag_ids(session=session)) - orphaned_snapshot_dag_ids = all_snapshot_dag_ids - active_snapshot_dag_ids - logger.info("Deleting orphaned Snapshot DAGs: %s", orphaned_snapshot_dag_ids) - util.delete_dags(orphaned_snapshot_dag_ids, session=session) diff --git a/sqlmesh/schedulers/airflow/mwaa_client.py b/sqlmesh/schedulers/airflow/mwaa_client.py deleted file mode 100644 index 4a4028fec0..0000000000 --- a/sqlmesh/schedulers/airflow/mwaa_client.py +++ /dev/null @@ -1,91 +0,0 @@ -from __future__ import annotations - -import 
base64 -import json -import logging -import typing as t -from urllib.parse import urljoin - -from requests import Session - -from sqlmesh.core.console import Console -from sqlmesh.schedulers.airflow.client import BaseAirflowClient, raise_for_status -from sqlmesh.utils.date import now_timestamp -from sqlmesh.utils.errors import NotFoundError - -logger = logging.getLogger(__name__) - - -TOKEN_TTL_MS = 30 * 1000 - - -class MWAAClient(BaseAirflowClient): - def __init__(self, environment: str, console: t.Optional[Console] = None): - airflow_url, auth_token = url_and_auth_token_for_environment(environment) - super().__init__(airflow_url, console) - - self._environment = environment - self._last_token_refresh_ts = now_timestamp() - self.__session: Session = _create_session(auth_token) - - def get_first_dag_run_id(self, dag_id: str) -> t.Optional[str]: - dag_runs = self._list_dag_runs(dag_id) - if dag_runs: - return dag_runs[-1]["run_id"] - return None - - def get_dag_run_state(self, dag_id: str, dag_run_id: str) -> str: - dag_runs = self._list_dag_runs(dag_id) or [] - for dag_run in dag_runs: - if dag_run["run_id"] == dag_run_id: - return dag_run["state"].lower() - raise NotFoundError(f"DAG run '{dag_run_id}' was not found for DAG '{dag_id}'") - - def get_variable(self, key: str) -> t.Optional[str]: - stdout, stderr = self._post(f"variables get {key}") - if "does not exist" in stderr: - return None - return stdout - - def _list_dag_runs(self, dag_id: str) -> t.Optional[t.List[t.Dict[str, t.Any]]]: - stdout, stderr = self._post(f"dags list-runs -o json -d {dag_id}") - if stdout: - return json.loads(stdout) - return None - - def _post(self, data: str) -> t.Tuple[str, str]: - response = self._session.post(urljoin(self._airflow_url, "aws_mwaa/cli"), data=data) - raise_for_status(response) - response_body = response.json() - - cli_stdout = base64.b64decode(response_body["stdout"]).decode("utf8").strip() - cli_stderr = 
base64.b64decode(response_body["stderr"]).decode("utf8").strip() - return cli_stdout, cli_stderr - - @property - def _session(self) -> Session: - current_ts = now_timestamp() - if current_ts - self._last_token_refresh_ts > TOKEN_TTL_MS: - _, auth_token = url_and_auth_token_for_environment(self._environment) - self.__session = _create_session(auth_token) - self._last_token_refresh_ts = current_ts - return self.__session - - -def _create_session(auth_token: str) -> Session: - session = Session() - session.headers.update({"Authorization": f"Bearer {auth_token}", "Content-Type": "text/plain"}) - return session - - -def url_and_auth_token_for_environment(environment_name: str) -> t.Tuple[str, str]: - import boto3 - - logger.info("Fetching the MWAA CLI token") - - client = boto3.client("mwaa") - cli_token = client.create_cli_token(Name=environment_name) - - url = f"https://{cli_token['WebServerHostname']}/" - auth_token = cli_token["CliToken"] - return url, auth_token diff --git a/sqlmesh/schedulers/airflow/operators/base.py b/sqlmesh/schedulers/airflow/operators/base.py deleted file mode 100644 index b4d22b47ec..0000000000 --- a/sqlmesh/schedulers/airflow/operators/base.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -import typing as t - -from airflow.models import BaseOperator -from airflow.utils.context import Context -from airflow.providers.common.sql.hooks.sql import DbApiHook - - -from sqlmesh.schedulers.airflow.operators.targets import BaseTarget - - -class BaseDbApiOperator(BaseOperator): - """The base class for DB API operators. - - Args: - target: The target that will be executed by this operator instance. - conn_id: The Airflow connection id. - dialect: The target SQL dialect. - hook_type: The type of the DB API hook. 
- """ - - def __init__( - self, - *, - target: BaseTarget, - conn_id: str, - dialect: str, - hook_type: t.Type[DbApiHook], - **kwargs: t.Any, - ) -> None: - super().__init__(**kwargs) - self._hook_type = hook_type - self._target = target - self._conn_id = conn_id - self._dialect = dialect - self._hook_params = kwargs - - def get_db_hook(self) -> DbApiHook: - """Gets the Hook which contains the DB API connection object.""" - return self._hook_type(self._conn_id, **self._hook_params) - - def execute(self, context: Context) -> None: - """Executes the desired target against the configured connection.""" - self._target.execute(context, lambda: self.get_db_hook().get_conn(), self._dialect) diff --git a/sqlmesh/schedulers/airflow/operators/bigquery.py b/sqlmesh/schedulers/airflow/operators/bigquery.py deleted file mode 100644 index ab261717ee..0000000000 --- a/sqlmesh/schedulers/airflow/operators/bigquery.py +++ /dev/null @@ -1,53 +0,0 @@ -from __future__ import annotations - -import typing as t - -from airflow.models import BaseOperator -from airflow.utils.context import Context - -from sqlmesh.schedulers.airflow.hooks.bigquery import SQLMeshBigQueryHook -from sqlmesh.schedulers.airflow.operators.targets import BaseTarget - - -class SQLMeshBigQueryOperator(BaseOperator): - """The operator that evaluates a SQLMesh model snapshot on Bigquery - - Args: - target: The target that will be executed by this operator instance. - bigquery_conn_id: The Airflow connection id for the bigquery target. 
- """ - - def __init__( - self, - *, - target: BaseTarget, - bigquery_conn_id: str = SQLMeshBigQueryHook.default_conn_name, - delegate_to: str | None = None, - impersonation_chain: str | t.Sequence[str] | None = None, - location: t.Optional[str] = None, - **kwargs: t.Any, - ) -> None: - super().__init__(**kwargs) - self._target = target - self._bigquery_conn_id = bigquery_conn_id - self._delegate_to = delegate_to - self._impersonation_chain = impersonation_chain - self._location = location - - def get_db_hook(self) -> SQLMeshBigQueryHook: - """Gets the BigQuery Hook which contains the DB API connection object""" - return SQLMeshBigQueryHook( - self._bigquery_conn_id, - delegate_to=self._delegate_to, - impersonation_chain=self._impersonation_chain, - location=self._location, - ) - - def execute(self, context: Context) -> None: - """Executes the desired target against the configured BigQuery connection""" - self._target.execute( - context, - lambda: self.get_db_hook().get_conn(), - "bigquery", - job_retries=self.get_db_hook().num_retries, - ) diff --git a/sqlmesh/schedulers/airflow/operators/databricks.py b/sqlmesh/schedulers/airflow/operators/databricks.py deleted file mode 100644 index 1762562aba..0000000000 --- a/sqlmesh/schedulers/airflow/operators/databricks.py +++ /dev/null @@ -1,139 +0,0 @@ -from __future__ import annotations - -import os -import tempfile -import typing as t - -from airflow.providers.databricks.hooks.databricks_base import BaseDatabricksHook -from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook -from airflow.providers.databricks.operators.databricks import ( - DatabricksSubmitRunOperator, -) -from airflow.utils.context import Context - -import sqlmesh -from sqlmesh import utils -from sqlmesh.engines import commands -from sqlmesh.schedulers.airflow.operators.base import BaseDbApiOperator -from sqlmesh.schedulers.airflow.operators.targets import BaseTarget - - -class 
SQLMeshDatabricksSubmitOperator(DatabricksSubmitRunOperator): - """Operator for submitting Databricks jobs to a Databricks cluster using the submit run API. - - Args: - target: The target that will be executed by this operator instance. - dbfs_location: The dbfs location where the app.py file and payload will be copied to. - existing_cluster_id: The id of the cluster to run the job on. Either this or new_cluster must be specified. - new_cluster: The specification for a new cluster to run the job on. Either this or existing_cluster_id must be specified. - """ - - def __init__( - self, - target: BaseTarget, - dbfs_location: t.Optional[str] = None, - **kwargs: t.Any, - ) -> None: - if not dbfs_location: - raise ValueError( - "dbfs_location is required for Databricks connections. See documentation for more details." - ) - if not dbfs_location.startswith("dbfs:/"): - raise ValueError( - "dbfs_location must start with 'dbfs:/'. See documentation for more details." - ) - super().__init__(**kwargs) - self._target = target - self._dbfs_location = os.path.join(dbfs_location, utils.random_id()) - - def execute(self, context: Context) -> None: - """Executes the target against the configured databricks cluster using the submit run API. - - SQLMesh copies the app.py file to the dbfs location specified in the operator. It also copies a file containing - the target's payload to the dbfs location. The app.py file is then executed with the target's payload as an - argument. - - TODO: Add support for `idempotency token`. This would allow this operator to reattach to an existing run if it - exists instead of creating a new one. We would need to make sure this is done correctly by have the dbfs - path use that token for the path instead of a random ID. Consider using a plan ID but also consider how - cadence runs and restatements will also work. 
- """ - from databricks_cli.dbfs.api import DbfsApi - from databricks_cli.sdk.api_client import ApiClient - - if "new_cluster" not in self.json and "existing_cluster_id" not in self.json: - http_path = self._hook.databricks_conn.extra_dejson.get("http_path") - if not http_path: - raise ValueError( - "Must provide a cluster to run on or new cluster specification. See documentation for more details." - ) - cluster_id = http_path.split("/")[-1] - if "-" not in cluster_id: - raise ValueError( - "Must provide a non-DBSQL cluster to execute against. See documentation for more details." - ) - self.json["existing_cluster_id"] = cluster_id - - api_client = ApiClient( - host=f"https://{self._hook.host}", - token=self._hook._get_token(raise_error=False), - user=self._hook.databricks_conn.login, - password=self._hook.databricks_conn.password, - ) - dbfs_api = DbfsApi(api_client) - - local_app_path = os.path.join( - os.path.dirname(os.path.abspath(sqlmesh.__file__)), "engines/spark/app.py" - ) - remote_app_path = os.path.join(self._dbfs_location, "app.py") - dbfs_api.cp(recursive=False, overwrite=True, src=local_app_path, dst=remote_app_path) - - command_payload = self._target.serialized_command_payload(context) - with tempfile.TemporaryDirectory() as tmp: - local_payload_path = os.path.join(tmp, commands.COMMAND_PAYLOAD_FILE_NAME) - with open(local_payload_path, "w", encoding="utf-8") as payload_fd: - payload_fd.write(command_payload) - remote_payload_path = os.path.join( - self._dbfs_location, commands.COMMAND_PAYLOAD_FILE_NAME - ) - dbfs_api.cp( - recursive=False, overwrite=True, src=local_payload_path, dst=remote_payload_path - ) - task_arguments = { - "dialect": "databricks", - "default_catalog": self._target.default_catalog, - "command_type": self._target.command_type.value if self._target.command_type else None, - "ddl_concurrent_tasks": self._target.ddl_concurrent_tasks, - "payload_path": remote_payload_path, - } - python_task = { - "python_file": remote_app_path, - 
"parameters": [f"--{k}={v}" for k, v in task_arguments.items() if v is not None], - } - self.json["spark_python_task"] = python_task - super().execute(context) - self._target.post_hook(context) - - -class SQLMeshDatabricksSQLOperator(BaseDbApiOperator): - """Operator for running just SQL operations against Databricks. - - Args: - target: The target that will be executed by this operator instance. - databricks_conn_id: The Airflow connection id for the databricks target. - """ - - def __init__( - self, - *, - target: BaseTarget, - databricks_conn_id: str = BaseDatabricksHook.default_conn_name, - **kwargs: t.Any, - ) -> None: - super().__init__( - target=target, - conn_id=databricks_conn_id, - dialect="databricks", - hook_type=DatabricksSqlHook, - **kwargs, - ) diff --git a/sqlmesh/schedulers/airflow/operators/mssql.py b/sqlmesh/schedulers/airflow/operators/mssql.py deleted file mode 100644 index d9d84eceba..0000000000 --- a/sqlmesh/schedulers/airflow/operators/mssql.py +++ /dev/null @@ -1,28 +0,0 @@ -from __future__ import annotations - -import typing as t - -from airflow.providers.microsoft.mssql.hooks.mssql import MsSqlHook - -from sqlmesh.schedulers.airflow.operators.base import BaseDbApiOperator -from sqlmesh.schedulers.airflow.operators.targets import BaseTarget - - -class SQLMeshMsSqlOperator(BaseDbApiOperator): - """The operator that evaluates a SQLMesh model snapshot on a mssql target - - Args: - target: The target that will be executed by this operator instance. - mssql_conn_id: The Airflow connection id for the mssql target. 
- """ - - def __init__( - self, - *, - target: BaseTarget, - mssql_conn_id: str = MsSqlHook.default_conn_name, - **kwargs: t.Any, - ) -> None: - super().__init__( - target=target, conn_id=mssql_conn_id, dialect="mssql", hook_type=MsSqlHook, **kwargs - ) diff --git a/sqlmesh/schedulers/airflow/operators/mysql.py b/sqlmesh/schedulers/airflow/operators/mysql.py deleted file mode 100644 index 7dd5ae855e..0000000000 --- a/sqlmesh/schedulers/airflow/operators/mysql.py +++ /dev/null @@ -1,28 +0,0 @@ -from __future__ import annotations - -import typing as t - -from airflow.providers.mysql.hooks.mysql import MySqlHook - -from sqlmesh.schedulers.airflow.operators.base import BaseDbApiOperator -from sqlmesh.schedulers.airflow.operators.targets import BaseTarget - - -class SQLMeshMySqlOperator(BaseDbApiOperator): - """The operator that evaluates a SQLMesh model snapshot on a mysql target - - Args: - target: The target that will be executed by this operator instance. - mysql_conn_id: The Airflow connection id for the mysql target. 
- """ - - def __init__( - self, - *, - target: BaseTarget, - mysql_conn_id: str = MySqlHook.default_conn_name, - **kwargs: t.Any, - ) -> None: - super().__init__( - target=target, conn_id=mysql_conn_id, dialect="mysql", hook_type=MySqlHook, **kwargs - ) diff --git a/sqlmesh/schedulers/airflow/operators/notification.py b/sqlmesh/schedulers/airflow/operators/notification.py deleted file mode 100644 index 5e6eedab88..0000000000 --- a/sqlmesh/schedulers/airflow/operators/notification.py +++ /dev/null @@ -1,30 +0,0 @@ -import abc -import typing as t - -from airflow.models import BaseOperator - -from sqlmesh.core.notification_target import BaseNotificationTarget -from sqlmesh.core.plan import PlanStatus -from sqlmesh.schedulers.airflow import common - -NT = t.TypeVar("NT", bound=BaseNotificationTarget) - - -class BaseNotificationOperatorProvider(abc.ABC, t.Generic[NT]): - @abc.abstractmethod - def operator( - self, - target: NT, - plan_status: PlanStatus, - plan_dag_spec: common.PlanDagSpec, - **dag_kwargs: t.Any, - ) -> t.Optional[BaseOperator]: - pass - - def get_trigger_rule(self, plan_status: PlanStatus) -> str: - if plan_status.is_failed: - return "one_failed" - return "all_success" - - def get_task_id(self, target: NT, plan_status: PlanStatus) -> str: - return f"plan_{plan_status.value}_{target.type_}_notification" diff --git a/sqlmesh/schedulers/airflow/operators/postgres.py b/sqlmesh/schedulers/airflow/operators/postgres.py deleted file mode 100644 index 49dbb1686e..0000000000 --- a/sqlmesh/schedulers/airflow/operators/postgres.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -import typing as t - -from airflow.providers.postgres.hooks.postgres import PostgresHook - -from sqlmesh.schedulers.airflow.operators.base import BaseDbApiOperator -from sqlmesh.schedulers.airflow.operators.targets import BaseTarget - - -class SQLMeshPostgresOperator(BaseDbApiOperator): - """The operator that evaluates a SQLMesh model snapshot on a Postgres target - - 
Args: - target: The target that will be executed by this operator instance. - postgres_conn_id: The Airflow connection id for the postgres target. - """ - - def __init__( - self, - *, - target: BaseTarget, - postgres_conn_id: str = PostgresHook.default_conn_name, - **kwargs: t.Any, - ) -> None: - super().__init__( - target=target, - conn_id=postgres_conn_id, - dialect="postgres", - hook_type=PostgresHook, - **kwargs, - ) diff --git a/sqlmesh/schedulers/airflow/operators/redshift.py b/sqlmesh/schedulers/airflow/operators/redshift.py deleted file mode 100644 index 0d1f5c9240..0000000000 --- a/sqlmesh/schedulers/airflow/operators/redshift.py +++ /dev/null @@ -1,31 +0,0 @@ -from __future__ import annotations - -import typing as t - - -from sqlmesh.schedulers.airflow.hooks.redshift import SQLMeshRedshiftHook -from sqlmesh.schedulers.airflow.operators.base import BaseDbApiOperator -from sqlmesh.schedulers.airflow.operators.targets import BaseTarget - - -class SQLMeshRedshiftOperator(BaseDbApiOperator): - """The operator that evaluates a SQLMesh model snapshot on Redshift cluster - - Args: - target: The target that will be executed by this operator instance. - redshift_conn_id: The Airflow connection id for the Redshift target. 
- """ - - def __init__( - self, - target: BaseTarget, - redshift_conn_id: str = SQLMeshRedshiftHook.default_conn_name, - **kwargs: t.Any, - ) -> None: - super().__init__( - target=target, - conn_id=redshift_conn_id, - dialect="redshift", - hook_type=SQLMeshRedshiftHook, - **kwargs, - ) diff --git a/sqlmesh/schedulers/airflow/operators/sensor.py b/sqlmesh/schedulers/airflow/operators/sensor.py deleted file mode 100644 index 3a7e72d8da..0000000000 --- a/sqlmesh/schedulers/airflow/operators/sensor.py +++ /dev/null @@ -1,100 +0,0 @@ -from __future__ import annotations - -import logging -import typing as t -from datetime import datetime - -from airflow.models import DagRun -from airflow.sensors.base import BaseSensorOperator -from airflow.utils.context import Context - -from sqlmesh.core.snapshot import Snapshot, SnapshotTableInfo -from sqlmesh.schedulers.airflow import util -from sqlmesh.utils.date import TimeLike, now, to_datetime - -if t.TYPE_CHECKING: - from airflow.sensors.base import PokeReturnValue - -logger = logging.getLogger(__name__) - - -class HighWaterMarkSensor(BaseSensorOperator): - def __init__( - self, - target_snapshot_info: SnapshotTableInfo, - this_snapshot: Snapshot, - mode: str = "reschedule", - **kwargs: t.Any, - ) -> None: - super().__init__( - mode=mode, - **kwargs, - ) - self.target_snapshot_info = target_snapshot_info - self.this_snapshot = this_snapshot - - def poke(self, context: Context) -> bool: - dag_run = context["dag_run"] - - with util.scoped_state_sync() as state_sync: - target_snapshot = state_sync.get_snapshots([self.target_snapshot_info])[ - self.target_snapshot_info.snapshot_id - ] - if target_snapshot.intervals: - current_high_water_mark = to_datetime(target_snapshot.intervals[-1][1]) - else: - current_high_water_mark = None - - target_high_water_mark = self._compute_target_high_water_mark( - dag_run, # type: ignore - target_snapshot, - ) - - logger.info( - "The current high water mark for snapshot %s is '%s' (target is '%s')", - 
self.target_snapshot_info.snapshot_id, - current_high_water_mark, - target_high_water_mark, - ) - if current_high_water_mark is not None: - return current_high_water_mark >= target_high_water_mark - return False - - def _compute_target_high_water_mark( - self, dag_run: DagRun, target_snapshot: Snapshot - ) -> datetime: - target_date = to_datetime(dag_run.data_interval_end) - target_prev = to_datetime(target_snapshot.node.interval_unit.cron_floor(target_date)) - this_prev = to_datetime(self.this_snapshot.node.interval_unit.cron_floor(target_date)) - return min(target_prev, this_prev) - - -class ExternalSensor(BaseSensorOperator): - def __init__( - self, - snapshot: Snapshot, - external_table_sensor_factory: t.Callable[[t.Dict[str, t.Any]], BaseSensorOperator], - mode: str = "reschedule", - start: t.Optional[TimeLike] = None, - end: t.Optional[TimeLike] = None, - **kwargs: t.Any, - ): - super().__init__( - mode=mode, - **kwargs, - ) - self.snapshot = snapshot - self.external_table_sensor_factory = external_table_sensor_factory - self.start = start - self.end = end - - def poke(self, context: Context) -> t.Union[bool, PokeReturnValue]: - interval_unit = self.snapshot.node.interval_unit - dag_run = context["dag_run"] - signals = self.snapshot.model.render_signals( - start=interval_unit.cron_floor(self.start or dag_run.data_interval_start), # type: ignore - end=interval_unit.cron_floor(self.end or dag_run.data_interval_end), # type: ignore - execution_time=now(minute_floor=False), - ) - delegates = [self.external_table_sensor_factory(signal) for signal in signals] - return all(d.poke(context) for d in delegates) diff --git a/sqlmesh/schedulers/airflow/operators/snowflake.py b/sqlmesh/schedulers/airflow/operators/snowflake.py deleted file mode 100644 index b015d3f1d2..0000000000 --- a/sqlmesh/schedulers/airflow/operators/snowflake.py +++ /dev/null @@ -1,32 +0,0 @@ -from __future__ import annotations - -import typing as t - -from 
airflow.providers.snowflake.hooks.snowflake import SnowflakeHook - -from sqlmesh.schedulers.airflow.operators.base import BaseDbApiOperator -from sqlmesh.schedulers.airflow.operators.targets import BaseTarget - - -class SQLMeshSnowflakeOperator(BaseDbApiOperator): - """The operator that evaluates a SQLMesh model snapshot on a Snowflake target - - Args: - target: The target that will be executed by this operator instance. - databricks_conn_id: The Airflow connection id for the snowflake target. - """ - - def __init__( - self, - *, - target: BaseTarget, - snowflake_conn_id: str = SnowflakeHook.default_conn_name, - **kwargs: t.Any, - ) -> None: - super().__init__( - target=target, - conn_id=snowflake_conn_id, - dialect="snowflake", - hook_type=SnowflakeHook, - **kwargs, - ) diff --git a/sqlmesh/schedulers/airflow/operators/spark_submit.py b/sqlmesh/schedulers/airflow/operators/spark_submit.py deleted file mode 100644 index 720f448381..0000000000 --- a/sqlmesh/schedulers/airflow/operators/spark_submit.py +++ /dev/null @@ -1,160 +0,0 @@ -import os -import tempfile -import typing as t - -from airflow.models import BaseOperator -from airflow.providers.apache.spark.hooks.spark_submit import SparkSubmitHook -from airflow.utils.context import Context - -import sqlmesh -from sqlmesh.engines import commands -from sqlmesh.schedulers.airflow.operators.targets import ( - BaseTarget, - SnapshotEvaluationTarget, -) - - -class SQLMeshSparkSubmitOperator(BaseOperator): - """The operator which evaluates a SQLMesh model snapshot using a dedicated Spark job instance. - - It requires the "spark-submit" binary to be available in the PATH or the spark_home - attribute to be set in the connection extras. - - Args: - target: The target that will be executed by this operator instance. - application_name: The name of the submitted application (default sqlmesh-spark). - spark_conf: Spark configuration properties. 
- connection_id: The Airflow connection ID as described in - https://airflow.apache.org/docs/apache-airflow-providers-apache-spark/stable/connections/spark.html - (default spark_default). - total_executor_cores: (Srandalone & Mesos only) The total number of cores for all executors. - executor_cores: (Standalone, YARN and Kubernetes only) The number of cores per executor. - executor_memory: The amount of memory allocated to each executor (e.g. 1024M, 2G). - driver_memory: The amount of memory allocated to the driver (e.g. 1024M, 2G). - keytab: The full path to the file that contains the keytab. - principal: The name of the Kerberos principal used for the keytab. - proxy_user: The name of a user which should be impersonated when submitting the application. - num_executors: The number of executors that will be allocateed to the application. - """ - - def __init__( - self, - *, - target: BaseTarget, - application_name: str = "sqlmesh-spark", - spark_conf: t.Optional[t.Dict[str, t.Any]] = None, - connection_id: str = "spark_default", - total_executor_cores: t.Optional[int] = None, - executor_cores: t.Optional[int] = None, - executor_memory: t.Optional[str] = None, - driver_memory: t.Optional[str] = None, - keytab: t.Optional[str] = None, - principal: t.Optional[str] = None, - proxy_user: t.Optional[str] = None, - num_executors: t.Optional[int] = None, - **kwargs: t.Any, - ) -> None: - super().__init__(**kwargs) - self._target = target - self._application_name = application_name - self._spark_conf = spark_conf or {} - self._total_executor_cores = total_executor_cores - self._executor_cores = executor_cores - self._executor_memory = executor_memory - self._driver_memory = driver_memory - self._keytab = keytab - self._principal = principal - self._proxy_user = proxy_user - self._num_executors = num_executors - self._connection_id = connection_id - self._application = os.path.join( - os.path.dirname(os.path.abspath(sqlmesh.__file__)), "engines/spark/app.py" - ) - 
self._hook: t.Optional[SparkSubmitHook] = None - - def execute(self, context: Context) -> None: - command_payload = self._target.serialized_command_payload(context) - with tempfile.TemporaryDirectory() as tmp: - payload_file_path = os.path.join(tmp, commands.COMMAND_PAYLOAD_FILE_NAME) - with open(payload_file_path, "w", encoding="utf-8") as payload_fd: - payload_fd.write(command_payload) - - if self._hook is None: - if ( - isinstance(self._target, SnapshotEvaluationTarget) - and self._target.snapshot.is_model - ): - session_properties = self._target.snapshot.model.session_properties - executor_cores: t.Optional[int] = session_properties.pop( # type: ignore - "spark.executor.cores", self._executor_cores - ) - executor_memory: t.Optional[str] = session_properties.pop( # type: ignore - "spark.executor.memory", self._executor_memory - ) - driver_memory: t.Optional[str] = session_properties.pop( # type: ignore - "spark.driver.memory", self._driver_memory - ) - num_executors: t.Optional[int] = session_properties.pop( # type: ignore - "spark.executor.instances", self._num_executors - ) - spark_conf: t.Dict[str, t.Any] = {**self._spark_conf, **session_properties} - else: - executor_cores = self._executor_cores - executor_memory = self._executor_memory - driver_memory = self._driver_memory - num_executors = self._num_executors - spark_conf = self._spark_conf - - self._hook = self._get_hook( - self._target.command_type, - payload_file_path, - self._target.ddl_concurrent_tasks, - spark_conf, - executor_cores, - executor_memory, - driver_memory, - num_executors, - ) - self._hook.submit(self._application) - self._target.post_hook(context) - - def on_kill(self) -> None: - if self._hook is None: - self._hook = self._get_hook(None, None, None, None, None, None, None, None) - self._hook.on_kill() - - def _get_hook( - self, - command_type: t.Optional[commands.CommandType], - command_payload_file_path: t.Optional[str], - ddl_concurrent_tasks: t.Optional[int], - spark_conf: 
t.Optional[t.Dict[str, t.Any]], - executor_cores: t.Optional[int], - executor_memory: t.Optional[str], - driver_memory: t.Optional[str], - num_executors: t.Optional[int], - ) -> SparkSubmitHook: - application_args = { - "dialect": "spark", - "default_catalog": self._target.default_catalog, - "command_type": command_type.value if command_type else None, - "ddl_concurrent_tasks": ddl_concurrent_tasks, - "payload_path": ( - command_payload_file_path.split("/")[-1] if command_payload_file_path else None - ), - } - return SparkSubmitHook( - conf=spark_conf, - conn_id=self._connection_id, - total_executor_cores=self._total_executor_cores, - executor_cores=executor_cores, - executor_memory=executor_memory, - driver_memory=driver_memory, - keytab=self._keytab, - principal=self._principal, - proxy_user=self._proxy_user, - name=self._application_name, - num_executors=num_executors, - application_args=[f"--{k}={v}" for k, v in application_args.items() if v is not None], - files=command_payload_file_path, - ) diff --git a/sqlmesh/schedulers/airflow/operators/targets.py b/sqlmesh/schedulers/airflow/operators/targets.py deleted file mode 100644 index b9e924b109..0000000000 --- a/sqlmesh/schedulers/airflow/operators/targets.py +++ /dev/null @@ -1,349 +0,0 @@ -import abc -import logging -import typing as t - -from airflow.exceptions import AirflowSkipException -from airflow.models import Variable -from airflow.utils.context import Context -from airflow.utils.session import provide_session -from sqlalchemy.orm import Session - -from sqlmesh.core.engine_adapter import create_engine_adapter -from sqlmesh.core.environment import EnvironmentNamingInfo -from sqlmesh.core.snapshot import ( - DeployabilityIndex, - Snapshot, - SnapshotEvaluator, - SnapshotTableInfo, -) -from sqlmesh.engines import commands -from sqlmesh.schedulers.airflow import common, util, NO_DEFAULT_CATALOG -from sqlmesh.utils.date import TimeLike -from sqlmesh.utils.pydantic import PydanticModel - -CP = 
t.TypeVar("CP", bound=PydanticModel) - - -class BaseTarget(abc.ABC, t.Generic[CP]): - command_type: commands.CommandType - command_handler: t.Callable[[SnapshotEvaluator, CP], None] - ddl_concurrent_tasks: int - - @property - def default_catalog(self) -> t.Optional[str]: - default_catalog = Variable.get(common.DEFAULT_CATALOG_VARIABLE_NAME) - if default_catalog == NO_DEFAULT_CATALOG: - return None - return default_catalog - - def serialized_command_payload(self, context: Context) -> str: - """Returns the serialized command payload for the Spark application. - - Args: - context: Airflow task context. - - Returns: - The serialized command payload. - """ - return self._get_command_payload_or_skip(context).json() - - def execute( - self, - context: Context, - connection_factory: t.Callable[[], t.Any], - dialect: str, - **kwargs: t.Any, - ) -> None: - """Executes this target. - - Args: - context: Airflow task context. - connection_factory: a callable which produces a new Database API compliant - connection on every call. - dialect: The dialect with which this adapter is associated. - """ - payload = self._get_command_payload_or_skip(context) - snapshot_evaluator = SnapshotEvaluator( - create_engine_adapter( - connection_factory, - dialect, - multithreaded=self.ddl_concurrent_tasks > 1, - execute_log_level=logging.INFO, - default_catalog=self.default_catalog, - **kwargs, - ), - ddl_concurrent_tasks=self.ddl_concurrent_tasks, - ) - try: - self.command_handler(snapshot_evaluator, payload) - self.post_hook(context) - finally: - snapshot_evaluator.close() - - def post_hook(self, context: Context, **kwargs: t.Any) -> None: - """The hook that should be invoked once the processing of this target - is complete. - - Args: - context: Airflow task context. - """ - - @abc.abstractmethod - def _get_command_payload(self, context: Context) -> t.Optional[CP]: - """Constructs the command payload. - - Args: - context: Airflow task context. 
- - Returns: - The command payload or None if there is no command to execute - and the target must be skipped. - """ - - def _get_command_payload_or_skip(self, context: Context) -> CP: - payload = self._get_command_payload(context) - if not payload: - self.post_hook(context) - raise AirflowSkipException - return payload - - -class SnapshotEvaluationTarget(PydanticModel, BaseTarget[commands.EvaluateCommandPayload]): - """The target which contains attributes necessary to evaluate a given snapshot. - - Args: - snapshot: The snapshot which should be evaluated. - parent_snapshots: All upstream snapshots to use for expansion and mapping of physical locations. - start: The start of the interval to evaluate. - end: The end of the interval to evaluate. - execution_time: The date/time time reference to use for execution time. Defaults to now. - deployability_index: Determines snapshots that are deployable in the context of this evaluation. - batch_index: For snapshots that are part of a batch, this is their position in the batch - """ - - command_type: commands.CommandType = commands.CommandType.EVALUATE - command_handler: t.Callable[[SnapshotEvaluator, commands.EvaluateCommandPayload], None] = ( - commands.evaluate - ) - ddl_concurrent_tasks: int = 1 - - snapshot: Snapshot - parent_snapshots: t.Dict[str, Snapshot] - start: t.Optional[TimeLike] = None - end: t.Optional[TimeLike] = None - execution_time: t.Optional[TimeLike] = None - deployability_index: DeployabilityIndex - plan_id: t.Optional[str] = None - batch_index: int = 0 - - def post_hook( - self, - context: Context, - **kwargs: t.Any, - ) -> None: - with util.scoped_state_sync() as state_sync: - state_sync.add_interval( - self.snapshot, - self._get_start(context), - self._get_end(context), - is_dev=not self.deployability_index.is_deployable(self.snapshot), - ) - - def _get_command_payload(self, context: Context) -> t.Optional[commands.EvaluateCommandPayload]: - return commands.EvaluateCommandPayload( - 
snapshot=self.snapshot, - parent_snapshots=self.parent_snapshots, - start=self._get_start(context), - end=self._get_end(context), - execution_time=self._get_execution_time(context), - deployability_index=self.deployability_index, - batch_index=self.batch_index, - ) - - def _get_start(self, context: Context) -> TimeLike: - if self.start: - return self.start - - start = self.snapshot.node.interval_unit.cron_floor(context["dag_run"].data_interval_start) # type: ignore - if not self.snapshot.is_model: - return start - - return self.snapshot.model.lookback_start(start) - - def _get_end(self, context: Context) -> TimeLike: - return self.end or self.snapshot.node.interval_unit.cron_floor( - context["dag_run"].data_interval_end # type: ignore - ) - - def _get_execution_time(self, context: Context) -> TimeLike: - return self.execution_time or context["dag_run"].logical_date - - -class SnapshotPromotionTarget(PydanticModel, BaseTarget[commands.PromoteCommandPayload]): - """The target which contains attributes necessary to perform snapshot promotion in a given environment. - - The promotion means creation of views associated with the environment which target physical tables - associated with the given list of snapshots. - - Args: - snapshots: The list of snapshots that should be promoted in the target environment. - environment_naming_info: Naming information for the target environment. - ddl_concurrent_tasks: The number of concurrent tasks used for DDL - operations (table / view creation, deletion, etc). Default: 1. - deployability_index: Determines snapshots that are deployable in the context of this promotion. 
- """ - - command_type: commands.CommandType = commands.CommandType.PROMOTE - command_handler: t.Callable[[SnapshotEvaluator, commands.PromoteCommandPayload], None] = ( - commands.promote - ) - - snapshots: t.List[Snapshot] - environment_naming_info: EnvironmentNamingInfo - ddl_concurrent_tasks: int - deployability_index: DeployabilityIndex - - def _get_command_payload(self, context: Context) -> t.Optional[commands.PromoteCommandPayload]: - return commands.PromoteCommandPayload( - snapshots=self.snapshots, - environment_naming_info=self.environment_naming_info, - deployability_index=self.deployability_index, - ) - - -class SnapshotDemotionTarget(PydanticModel, BaseTarget[commands.DemoteCommandPayload]): - """The target which contains attributes necessary to perform snapshot demotion in a given environment. - - The demotion means deletion of views that match names of provided snapshots in the target environment. - - Args: - snapshots: The list of snapshots that should be demoted in the target environment. - environment_naming_info: Naming information for the target environment. 
- """ - - command_type: commands.CommandType = commands.CommandType.DEMOTE - command_handler: t.Callable[[SnapshotEvaluator, commands.DemoteCommandPayload], None] = ( - commands.demote - ) - - snapshots: t.List[SnapshotTableInfo] - environment_naming_info: EnvironmentNamingInfo - ddl_concurrent_tasks: int - - def _get_command_payload(self, context: Context) -> t.Optional[commands.DemoteCommandPayload]: - return commands.DemoteCommandPayload( - snapshots=self.snapshots, - environment_naming_info=self.environment_naming_info, - ) - - -class SnapshotCleanupTarget(PydanticModel, BaseTarget[commands.CleanupCommandPayload]): - """The target which contains attributes necessary to perform table cleanup of expired snapshots""" - - command_type: commands.CommandType = commands.CommandType.CLEANUP - command_handler: t.Callable[[SnapshotEvaluator, commands.CleanupCommandPayload], None] = ( - commands.cleanup - ) - ddl_concurrent_tasks: int = 1 - - @provide_session - def post_hook( - self, - context: Context, - session: Session = util.PROVIDED_SESSION, - **kwargs: t.Any, - ) -> None: - _delete_xcom( - common.SNAPSHOT_CLEANUP_COMMAND_XCOM_KEY, - common.JANITOR_TASK_ID, - context, - session, - ) - - def _get_command_payload(self, context: Context) -> t.Optional[commands.CleanupCommandPayload]: - command = commands.CleanupCommandPayload.parse_raw( - context["ti"].xcom_pull(key=common.SNAPSHOT_CLEANUP_COMMAND_XCOM_KEY) - ) - if not command.tasks and not command.environments: - return None - return command - - -class SnapshotCreateTablesTarget(PydanticModel, BaseTarget[commands.CreateTablesCommandPayload]): - """The target which creates physical tables for the given set of new snapshots.""" - - command_type: commands.CommandType = commands.CommandType.CREATE_TABLES - command_handler: t.Callable[[SnapshotEvaluator, commands.CreateTablesCommandPayload], None] = ( - commands.create_tables - ) - - new_snapshots: t.List[Snapshot] - ddl_concurrent_tasks: int - deployability_index: 
DeployabilityIndex - allow_destructive_snapshots: t.Set[str] - - def _get_command_payload( - self, context: Context - ) -> t.Optional[commands.CreateTablesCommandPayload]: - if not self.new_snapshots: - return None - - return commands.CreateTablesCommandPayload( - target_snapshot_ids=[s.snapshot_id for s in self.new_snapshots], - snapshots=_get_snapshots_with_parents(self.new_snapshots), - deployability_index=self.deployability_index, - allow_destructive_snapshots=self.allow_destructive_snapshots, - ) - - -class SnapshotMigrateTablesTarget(PydanticModel, BaseTarget[commands.MigrateTablesCommandPayload]): - """The target which updates schemas of existing physical tables to bring them in correspondance - with schemas of target snapshots. - """ - - command_type: commands.CommandType = commands.CommandType.MIGRATE_TABLES - command_handler: t.Callable[[SnapshotEvaluator, commands.MigrateTablesCommandPayload], None] = ( - commands.migrate_tables - ) - - snapshots: t.List[Snapshot] - ddl_concurrent_tasks: int - allow_destructive_snapshots: t.Set[str] - - def _get_command_payload( - self, context: Context - ) -> t.Optional[commands.MigrateTablesCommandPayload]: - if not self.snapshots: - return None - - return commands.MigrateTablesCommandPayload( - target_snapshot_ids=[s.snapshot_id for s in self.snapshots], - snapshots=_get_snapshots_with_parents(self.snapshots), - allow_destructive_snapshots=self.allow_destructive_snapshots, - ) - - -def _get_snapshots_with_parents(snapshots: t.Iterable[Snapshot]) -> t.List[Snapshot]: - snapshots_by_id = {s.snapshot_id: s for s in snapshots} - - parent_snapshot_ids = {p_sid for snapshot in snapshots for p_sid in snapshot.parents} - missing_parent_ids = parent_snapshot_ids - set(snapshots_by_id.keys()) - - existing_snapshots = list(snapshots_by_id.values()) - - if not missing_parent_ids: - return existing_snapshots - - with util.scoped_state_sync() as state_sync: - return existing_snapshots + 
list(state_sync.get_snapshots(missing_parent_ids).values()) - - -def _delete_xcom(key: str, task_id: str, context: Context, session: Session) -> None: - ti = context["ti"] - util.delete_xcoms( - ti.dag_id, - {key}, - task_id=task_id, - run_id=ti.run_id, - session=session, - ) diff --git a/sqlmesh/schedulers/airflow/operators/trino.py b/sqlmesh/schedulers/airflow/operators/trino.py deleted file mode 100644 index a6267fe117..0000000000 --- a/sqlmesh/schedulers/airflow/operators/trino.py +++ /dev/null @@ -1,28 +0,0 @@ -from __future__ import annotations - -import typing as t - -from airflow.providers.trino.hooks.trino import TrinoHook - -from sqlmesh.schedulers.airflow.operators.base import BaseDbApiOperator -from sqlmesh.schedulers.airflow.operators.targets import BaseTarget - - -class SQLMeshTrinoOperator(BaseDbApiOperator): - """The operator that evaluates a SQLMesh model snapshot on a Trino target - - Args: - target: The target that will be executed by this operator instance. - trino_conn_id: The Airflow connection id for the trino target. 
- """ - - def __init__( - self, - *, - target: BaseTarget, - trino_conn_id: str = TrinoHook.default_conn_name, - **kwargs: t.Any, - ) -> None: - super().__init__( - target=target, conn_id=trino_conn_id, dialect="trino", hook_type=TrinoHook, **kwargs - ) diff --git a/sqlmesh/schedulers/airflow/plan.py b/sqlmesh/schedulers/airflow/plan.py deleted file mode 100644 index e5941f5418..0000000000 --- a/sqlmesh/schedulers/airflow/plan.py +++ /dev/null @@ -1,198 +0,0 @@ -from __future__ import annotations - -import typing as t - -import pandas as pd -from sqlglot import exp - -from sqlmesh.core import scheduler -from sqlmesh.core.engine_adapter import EngineAdapter -from sqlmesh.core.environment import Environment -from sqlmesh.core.plan import update_intervals_for_new_snapshots -from sqlmesh.core.snapshot import DeployabilityIndex, SnapshotTableInfo -from sqlmesh.core.state_sync import EngineAdapterStateSync, StateSync -from sqlmesh.core.state_sync.base import DelegatingStateSync -from sqlmesh.schedulers.airflow import common -from sqlmesh.utils.date import now, to_timestamp -from sqlmesh.utils.errors import SQLMeshError - - -class PlanDagState: - def __init__(self, engine_adapter: EngineAdapter, plan_dags_table: exp.Table): - self.engine_adapter = engine_adapter - - self._plan_dags_table = plan_dags_table - - self._plan_dag_columns_to_types = { - "request_id": exp.DataType.build("text"), - "dag_id": exp.DataType.build("text"), - "dag_spec": exp.DataType.build("text"), - } - - @classmethod - def from_state_sync(cls, state_sync: StateSync) -> PlanDagState: - while isinstance(state_sync, DelegatingStateSync): - state_sync = state_sync.state_sync - if not isinstance(state_sync, EngineAdapterStateSync): - raise SQLMeshError(f"Unsupported state sync {state_sync.__class__.__name__}") - return cls(state_sync.engine_adapter, state_sync.plan_dags_table) - - def add_dag_spec(self, spec: common.PlanDagSpec) -> None: - """Adds a new DAG spec to the state. 
- - Args: - spec: the plan DAG spec to add. - """ - df = pd.DataFrame( - [ - { - "request_id": spec.request_id, - "dag_id": common.plan_application_dag_id( - spec.environment.name, spec.request_id - ), - "dag_spec": spec.json(), - } - ] - ) - self.engine_adapter.insert_append( - self._plan_dags_table, - df, - columns_to_types=self._plan_dag_columns_to_types, - ) - - def get_dag_specs(self) -> t.List[common.PlanDagSpec]: - """Returns all DAG specs in the state.""" - query = exp.select("dag_spec").from_(self._plan_dags_table) - return [ - common.PlanDagSpec.parse_raw(row[0]) - for row in self.engine_adapter.fetchall( - query, ignore_unsupported_errors=True, quote_identifiers=True - ) - ] - - def delete_dag_specs(self, dag_ids: t.Collection[str]) -> None: - """Deletes the DAG specs with the given DAG IDs.""" - if not dag_ids: - return - self.engine_adapter.delete_from( - self._plan_dags_table, - where=exp.column("dag_id").isin(*dag_ids), - ) - - -def create_plan_dag_spec( - request: common.PlanApplicationRequest, state_sync: StateSync -) -> common.PlanDagSpec: - new_snapshots = {s.snapshot_id: s for s in request.new_snapshots} - stored_snapshots = state_sync.get_snapshots([*new_snapshots, *request.environment.snapshots]) - all_snapshots = {**new_snapshots, **stored_snapshots} - - all_snaphots_by_name = {s.name: s for s in all_snapshots.values()} - restatements = { - all_snaphots_by_name[n].snapshot_id: i - for n, i in request.restatements.items() - if n in all_snaphots_by_name - } - - duplicated_snapshots = set(stored_snapshots).intersection(new_snapshots) - if duplicated_snapshots: - raise SQLMeshError( - f"Snapshots {duplicated_snapshots} already exist. 
" - "Make sure your code base is up to date and try re-creating the plan" - ) - - update_intervals_for_new_snapshots(new_snapshots.values(), state_sync) - - now_dt = now() - end = request.environment.end_at or now_dt - unpaused_dt = end if not request.is_dev and not request.restatements else None - - if request.restatements: - intervals_to_remove = [ - (s, restatements[s.snapshot_id]) - for s in all_snapshots.values() - if s.snapshot_id in restatements and s.snapshot_id not in new_snapshots - ] - state_sync.remove_interval( - intervals_to_remove, - remove_shared_versions=not request.is_dev, - ) - for s, interval in intervals_to_remove: - all_snapshots[s.snapshot_id].remove_interval(interval) - - deployability_index_for_creation = DeployabilityIndex.create(all_snapshots) - deployability_index_for_evaluation = ( - deployability_index_for_creation if request.is_dev else DeployabilityIndex.all_deployable() - ) - - if not request.skip_backfill: - backfill_batches = scheduler.compute_interval_params( - [s for s in all_snapshots.values() if request.is_selected_for_backfill(s.name)], - start=request.environment.start_at, - end=end, - execution_time=request.execution_time or now(), - deployability_index=deployability_index_for_evaluation, - restatements=restatements, - end_bounded=request.end_bounded, - signal_factory=None, - ) - else: - backfill_batches = {} - - backfill_intervals_per_snapshot = [ - common.BackfillIntervalsPerSnapshot( - snapshot_id=s.snapshot_id, - intervals=intervals, - before_promote=request.is_dev or deployability_index_for_creation.is_representative(s), - ) - for s, intervals in backfill_batches.items() - ] - - no_gaps_snapshot_names = ( - { - s.name - for s in all_snapshots.values() - if deployability_index_for_creation.is_representative(s) - } - if request.no_gaps and not request.is_dev - else None - if request.no_gaps - else set() - ) - - return common.PlanDagSpec( - request_id=request.request_id, - environment=request.environment, - 
new_snapshots=request.new_snapshots, - backfill_intervals_per_snapshot=backfill_intervals_per_snapshot, - demoted_snapshots=_get_demoted_snapshots(request.environment, state_sync), - unpaused_dt=unpaused_dt, - no_gaps=request.no_gaps, - notification_targets=request.notification_targets, - backfill_concurrent_tasks=request.backfill_concurrent_tasks, - ddl_concurrent_tasks=request.ddl_concurrent_tasks, - users=request.users, - is_dev=request.is_dev, - allow_destructive_snapshots=request.allow_destructive_snapshots, - forward_only=request.forward_only, - dag_start_ts=to_timestamp(now_dt), - deployability_index=deployability_index_for_evaluation, - deployability_index_for_creation=deployability_index_for_creation, - no_gaps_snapshot_names=no_gaps_snapshot_names, - models_to_backfill=request.models_to_backfill, - ensure_finalized_snapshots=request.ensure_finalized_snapshots, - directly_modified_snapshots=request.directly_modified_snapshots, - indirectly_modified_snapshots=request.indirectly_modified_snapshots, - removed_snapshots=request.removed_snapshots, - execution_time=request.execution_time, - ) - - -def _get_demoted_snapshots( - new_environment: Environment, state_sync: StateSync -) -> t.List[SnapshotTableInfo]: - current_environment = state_sync.get_environment(new_environment.name) - if current_environment: - preserved_snapshot_names = {s.name for s in new_environment.snapshots} - return [s for s in current_environment.snapshots if s.name not in preserved_snapshot_names] - return [] diff --git a/sqlmesh/schedulers/airflow/plugin.py b/sqlmesh/schedulers/airflow/plugin.py deleted file mode 100644 index 3f7e303ed6..0000000000 --- a/sqlmesh/schedulers/airflow/plugin.py +++ /dev/null @@ -1,63 +0,0 @@ -from __future__ import annotations - -import logging -import os -import time -import typing as t - -from airflow.models import Variable -from airflow.plugins_manager import AirflowPlugin - -from sqlmesh.core import constants as c -from sqlmesh.schedulers.airflow import 
util, NO_DEFAULT_CATALOG -from sqlmesh.schedulers.airflow.api import sqlmesh_api_v1 -from sqlmesh.schedulers.airflow.common import ( - DEFAULT_CATALOG_VARIABLE_NAME, -) -from sqlmesh.utils.errors import SQLMeshError - -logger = logging.getLogger(__name__) - - -class SqlmeshAirflowPlugin(AirflowPlugin): - name = c.SQLMESH - flask_blueprints = [sqlmesh_api_v1] - - @classmethod - def on_load(cls, *args: t.Any, **kwargs: t.Any) -> None: - if os.environ.get("MWAA_AIRFLOW_COMPONENT", "").lower() == "webserver": - # When using MWAA, the Webserver instance might not have access to the external state database. - logger.info("MWAA Webserver instance detected. Skipping SQLMesh state migration...") - return - - # We want to different an expected None default catalog (where the user set `NO_DEFAULT_CATALOG`) - # and where the default catalog is not set at all. - default_catalog = Variable.get( - DEFAULT_CATALOG_VARIABLE_NAME, default_var="MISSING_REQUIRED_CATALOG" - ) - if default_catalog == NO_DEFAULT_CATALOG: - # If the user explicitly set `NO_DEFAULT_CATALOG` we want to set the default catalog to None. - default_catalog = None - - with util.scoped_state_sync() as state_sync: - try: - # If default catalog is required but missing (and not explicitly set to None) we want to raise unless - # this is a fresh install since we know nothing needs to be migrated and - # the client will prevent making any changes until the default catalog is set. - if default_catalog == "MISSING_REQUIRED_CATALOG": - versions = state_sync.get_versions(validate=False) - if versions.schema_version != 0: - raise SQLMeshError( - "Must define `default_catalog` when creating `SQLMeshAirflow` object. 
See docs for more info: https://sqlmesh.readthedocs.io/en/stable/integrations/airflow/#airflow-cluster-configuration" - ) - logger.info("Migrating SQLMesh state ...") - state_sync.migrate(default_catalog=default_catalog) - except Exception as ex: - # This method is called once for each Gunicorn worker spawned by the Airflow Webserver, - # which leads to SQLMesh schema being initialized concurrently from multiple processes. - # There is a known issue in Postgres (https://stackoverflow.com/a/29908840) which occurs - # due to a race condition when a new schema is being created concurrently. Here we retry - # the schema initialization once as a workaround. - logger.warning("Failed to initialize the SQLMesh State Sync: %s. Retrying...", ex) - time.sleep(1) - state_sync.migrate(default_catalog=default_catalog) diff --git a/sqlmesh/schedulers/airflow/state_sync.py b/sqlmesh/schedulers/airflow/state_sync.py deleted file mode 100644 index 1cb89cbd9f..0000000000 --- a/sqlmesh/schedulers/airflow/state_sync.py +++ /dev/null @@ -1,349 +0,0 @@ -from __future__ import annotations - -import logging -import typing as t - -from sqlmesh.core.console import Console -from sqlmesh.core.environment import Environment -from sqlmesh.core.snapshot import ( - Snapshot, - SnapshotId, - SnapshotIdLike, - SnapshotInfoLike, - SnapshotTableCleanupTask, -) -from sqlmesh.core.snapshot.definition import Interval, SnapshotIntervals -from sqlmesh.core.state_sync import StateSync, Versions -from sqlmesh.core.state_sync.base import PromotionResult -from sqlmesh.schedulers.airflow.client import AirflowClient - -if t.TYPE_CHECKING: - from sqlmesh.utils.date import TimeLike - -logger = logging.getLogger(__name__) - - -class HttpStateSync(StateSync): - """Reads state of models and snapshot through the Airflow REST API. - - Args: - airflow_url: URL pointing to the airflow rest api. - username: Username for Airflow. - password: Password for Airflow. 
- blocking_updates: Indicates whether calls that cause state updates should be blocking. - dag_run_poll_interval_secs: Determines how frequently the state of a DAG run should be checked. - Used to block on calls that update the state. - console: Used to print out tracking URLs. - """ - - def __init__( - self, - client: AirflowClient, - blocking_updates: bool = True, - dag_run_poll_interval_secs: int = 2, - console: t.Optional[Console] = None, - ): - self._client = client - self.blocking_updates = blocking_updates - self.dag_run_poll_interval_secs = dag_run_poll_interval_secs - self.console = console - - def get_environment(self, environment: str) -> t.Optional[Environment]: - """Fetches the environment if it exists. - - Args: - environment: The environment - - Returns: - The environment object. - """ - return self._client.get_environment(environment) - - def get_environments(self) -> t.List[Environment]: - """Fetches all environments. - - Returns: - A list of all environments. - """ - return self._client.get_environments() - - def max_interval_end_for_environment( - self, environment: str, ensure_finalized_snapshots: bool = False - ) -> t.Optional[int]: - """Returns the max interval end for the given environment. - - Args: - environment: The environment. - ensure_finalized_snapshots: Whether to use snapshots from the latest finalized environment state, - or to use whatever snapshots are in the current environment state even if the environment is not finalized. - - Returns: - A timestamp or None if no interval or environment exists. - """ - return self._client.max_interval_end_for_environment( - environment, ensure_finalized_snapshots - ) - - def greatest_common_interval_end( - self, environment: str, models: t.Set[str], ensure_finalized_snapshots: bool = False - ) -> t.Optional[int]: - """Returns the greatest common interval end for given models in the target environment. - - Args: - environment: The environment. - models: The model FQNs to select intervals from. 
- ensure_finalized_snapshots: Whether to use snapshots from the latest finalized environment state, - or to use whatever snapshots are in the current environment state even if the environment is not finalized. - - Returns: - A timestamp or None if no interval or environment exists. - """ - return self._client.greatest_common_interval_end( - environment, models, ensure_finalized_snapshots - ) - - def get_snapshots( - self, - snapshot_ids: t.Optional[t.Iterable[SnapshotIdLike]], - ) -> t.Dict[SnapshotId, Snapshot]: - """Gets multiple snapshots from the rest api. - - Because of the limitations of the Airflow API, this method is inherently inefficient. - It's impossible to bulkfetch the snapshots and thus every snapshot needs to make an individual - call to the rest api. Multiple threads can be used, but it could possibly have detrimental effects - on the production server. - """ - snapshots = self._client.get_snapshots( - [s.snapshot_id for s in snapshot_ids] if snapshot_ids is not None else None - ) - return {snapshot.snapshot_id: snapshot for snapshot in snapshots} - - def snapshots_exist(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> t.Set[SnapshotId]: - """Checks if multiple snapshots exist in the state sync. - - Args: - snapshot_ids: Iterable of snapshot ids to bulk check. - - Returns: - A set of existing snapshot IDs. - """ - if not snapshot_ids: - return set() - return self._client.snapshots_exist([s.snapshot_id for s in snapshot_ids]) - - def nodes_exist(self, names: t.Iterable[str], exclude_external: bool = False) -> t.Set[str]: - """Returns the node names that exist in the state sync. - - Args: - names: Iterable of node names to check. - exclude_external: Whether to exclude external models from the output. - - Returns: - A set of all the existing node names. - """ - return self._client.nodes_exist(names, exclude_external=exclude_external) - - def _get_versions(self, lock_for_update: bool = False) -> Versions: - """Queries the store to get the migration. 
- - Args: - lock_for_update: Whether or not the usage of this method plans to update the row. - - Returns: - The versions object. - """ - return self._client.get_versions() - - def push_snapshots(self, snapshots: t.Iterable[Snapshot]) -> None: - """Push snapshots into the state sync. - - This method only allows for pushing new snapshots. If existing snapshots are found, - this method should raise an error. - - Raises: - SQLMeshError when existing snapshots are pushed. - - Args: - snapshots: A list of snapshots to save in the state sync. - """ - raise NotImplementedError("Pushing snapshots is not supported by the Airflow state sync.") - - def delete_snapshots(self, snapshot_ids: t.Iterable[SnapshotIdLike]) -> None: - """Delete snapshots from the state sync. - - Args: - snapshot_ids: A list of snapshot like objects to delete. - """ - raise NotImplementedError("Deleting snapshots is not supported by the Airflow state sync.") - - def delete_expired_snapshots( - self, ignore_ttl: bool = False - ) -> t.List[SnapshotTableCleanupTask]: - """Removes expired snapshots. - - Expired snapshots are snapshots that have exceeded their time-to-live - and are no longer in use within an environment. - - Returns: - The list of table cleanup tasks. - """ - raise NotImplementedError( - "Deleting expired snapshots is not supported by the Airflow state sync." - ) - - def invalidate_environment(self, name: str) -> None: - """Invalidates the target environment by setting its expiration timestamp to now. - - Args: - name: The name of the environment to invalidate. - """ - self._client.invalidate_environment(name) - - def add_interval( - self, - snapshot: Snapshot, - start: TimeLike, - end: TimeLike, - is_dev: bool = False, - ) -> None: - """Add an interval to a snapshot and sync it to the store. - - Snapshots must be pushed before adding intervals to them. - - Args: - snapshot: The snapshot like object to add an interval to. - start: The start of the interval to add. 
- end: The end of the interval to add. - is_dev: Indicates whether the given interval is being added while in - development mode. - """ - raise NotImplementedError("Adding intervals is not supported by the Airflow state sync.") - - def _add_snapshot_intervals(self, snapshot_intervals: SnapshotIntervals) -> None: - raise NotImplementedError("Adding intervals is not supported by the Airflow state sync.") - - def remove_interval( - self, - snapshot_intervals: t.Sequence[t.Tuple[SnapshotInfoLike, Interval]], - remove_shared_versions: bool = False, - ) -> None: - """Remove an interval from a list of snapshots and sync it to the store. - - Because multiple snapshots can be pointing to the same version or physical table, this method - can also grab all snapshots tied to the passed in version. - - Args: - snapshots: The snapshot info like object to remove intervals from. - start: The start of the interval to add. - end: The end of the interval to add. - all_snapshots: All snapshots can be passed in to skip fetching matching snapshot versions. - """ - raise NotImplementedError("Removing intervals is not supported by the Airflow state sync.") - - def refresh_snapshot_intervals(self, snapshots: t.Collection[Snapshot]) -> t.List[Snapshot]: - """Updates given snapshots with latest intervals from the state. - - Args: - snapshots: The snapshots to refresh. - - Returns: - The updated snapshots. - """ - raise NotImplementedError( - "Refreshing snapshot intervals is not supported by the Airflow state sync." - ) - - def promote( - self, - environment: Environment, - no_gaps_snapshot_names: t.Optional[t.Set[str]] = None, - ) -> PromotionResult: - """Update the environment to reflect the current state. - - This method verifies that snapshots have been pushed. - - Args: - environment: The environment to promote. - no_gaps_snapshot_names: A set of snapshot names to check for data gaps. If None, - all snapshots will be checked. 
The data gap check ensures that models that are already a - part of the target environment have no data gaps when compared against previous - snapshots for same models. - - Returns: - A promotion result object containing added, removed, and removed environment naming info - """ - raise NotImplementedError( - "Promoting environments is not supported by the Airflow state sync." - ) - - def finalize(self, environment: Environment) -> None: - """Finalize the target environment, indicating that this environment has been - fully promoted and is ready for use. - - Args: - environment: The target environment to finalize. - """ - raise NotImplementedError( - "Finalizing environments is not supported by the Airflow state sync." - ) - - def delete_expired_environments(self) -> t.List[Environment]: - """Removes expired environments. - - Expired environments are environments that have exceeded their time-to-live value. - - Returns: - The list of removed environments. - """ - raise NotImplementedError( - "Deleting expired environments is not supported by the Airflow state sync." - ) - - def unpause_snapshots( - self, snapshots: t.Collection[SnapshotInfoLike], unpaused_dt: TimeLike - ) -> None: - """Unpauses target snapshots. - - Unpaused snapshots are scheduled for evaluation on a recurring basis. - Once unpaused a snapshot can't be paused again. - - Args: - snapshots: Target snapshots. - unpaused_dt: The datetime object which indicates when target snapshots - were unpaused. - """ - raise NotImplementedError("Unpausing snapshots is not supported by the Airflow state sync.") - - def compact_intervals(self) -> None: - """Compacts intervals for all snapshots. - - Compaction process involves merging of existing interval records into new records and - then deleting the old ones. - """ - raise NotImplementedError( - "Compacting intervals is not supported by the Airflow state sync." 
- ) - - def migrate( - self, - default_catalog: t.Optional[str], - skip_backup: bool = False, - promoted_snapshots_only: bool = True, - ) -> None: - """Migrate the state sync to the latest SQLMesh / SQLGlot version.""" - raise NotImplementedError("Migration is not supported by the Airflow state sync.") - - def rollback(self) -> None: - """Rollback to previous backed up state.""" - raise NotImplementedError("Rollback is not supported by the Airflow state sync.") - - def recycle(self) -> None: - """Closes all open connections and releases all allocated resources associated with any thread - except the calling one.""" - - def close(self) -> None: - """Closes all open connections and releases all allocated resources.""" - - def state_type(self) -> str: - """Returns the type of state sync.""" - return "airflow_http" diff --git a/sqlmesh/schedulers/airflow/util.py b/sqlmesh/schedulers/airflow/util.py deleted file mode 100644 index 1eeb9e620e..0000000000 --- a/sqlmesh/schedulers/airflow/util.py +++ /dev/null @@ -1,184 +0,0 @@ -from __future__ import annotations - -import contextlib -import json -import logging -import typing as t -from datetime import timedelta - -from airflow import settings -from airflow.api.common.experimental.delete_dag import delete_dag -from airflow.exceptions import AirflowException, DagNotFound -from airflow.models import BaseOperator, DagRun, DagTag, XCom -from airflow.models.connection import Connection -from airflow.utils.session import provide_session -from airflow.utils.state import DagRunState -from sqlalchemy.orm import Session - -from sqlmesh.core import constants as c -from sqlmesh.core.config import parse_connection_config -from sqlmesh.core.engine_adapter import create_engine_adapter -from sqlmesh.core.state_sync import CachingStateSync, EngineAdapterStateSync, StateSync -from sqlmesh.schedulers.airflow import common -from sqlmesh.utils.date import now -from sqlmesh.utils.errors import SQLMeshError - -logger = 
logging.getLogger(__name__) - - -# Used to omit Optional for session instances supplied by -# Airflow at runtime. This makes the type signature cleaner -# and prevents mypy from complaining. -PROVIDED_SESSION: Session = t.cast(Session, None) - - -SQLMESH_STATE_CONN_ID = "sqlmesh_state_db" - - -@contextlib.contextmanager -def scoped_state_sync() -> t.Iterator[StateSync]: - state_schema = c.SQLMESH - try: - connection = Connection.get_connection_from_secrets(SQLMESH_STATE_CONN_ID) - - connection_config_dict = json.loads(connection.extra) - state_schema = connection_config_dict.pop("state_schema", state_schema) - if "type" not in connection_config_dict: - logger.info( - "SQLMesh connection in Airflow did not have type defined. " - "Therefore using Airflow database connection" - ) - raise AirflowException - - logger.info("Using connection '%s' for state sync", connection.conn_id) - - connection_config = parse_connection_config(connection_config_dict) - engine_adapter = connection_config.create_engine_adapter() - except AirflowException: - logger.info("Using the Airflow database connection for state sync") - - dialect = settings.engine.dialect.name - engine_adapter = create_engine_adapter( - settings.engine.raw_connection, dialect, multithreaded=True - ) - - try: - yield CachingStateSync(EngineAdapterStateSync(engine_adapter, state_schema)) # type: ignore - finally: - engine_adapter.close() - - -@provide_session -def get_snapshot_dag_ids(session: Session = PROVIDED_SESSION) -> t.List[str]: - dag_tags = session.query(DagTag).filter(DagTag.name == common.SNAPSHOT_AIRFLOW_TAG).all() - return [tag.dag_id for tag in dag_tags] - - -@provide_session -def get_finished_plan_application_dag_ids( - ttl: t.Optional[timedelta] = None, session: Session = PROVIDED_SESSION -) -> t.Set[str]: - dag_ids = ( - session.query(DagTag.dag_id) - .join(DagRun, DagTag.dag_id == DagRun.dag_id) - .filter( - DagTag.name == common.PLAN_AIRFLOW_TAG, - DagRun.state.in_((DagRunState.SUCCESS, 
DagRunState.FAILED)), - ) - ) - if ttl is not None: - dag_ids = dag_ids.filter(DagRun.last_scheduling_decision <= now() - ttl) - return {dag_id[0] for dag_id in dag_ids.all()} - - -@provide_session -def delete_dags(dag_ids: t.Set[str], session: Session = PROVIDED_SESSION) -> None: - for dag_id in dag_ids: - try: - delete_dag(dag_id, session=session) - except DagNotFound: - logger.warning("DAG '%s' was not found", dag_id) - except AirflowException: - logger.warning("Failed to delete DAG '%s'", dag_id, exc_info=True) - - -@provide_session -def delete_xcoms( - dag_id: str, - keys: t.Set[str], - task_id: t.Optional[str] = None, - run_id: t.Optional[str] = None, - session: Session = PROVIDED_SESSION, -) -> None: - query = session.query(XCom).filter(XCom.dag_id == dag_id, XCom.key.in_(keys)) - if task_id is not None: - query = query.filter_by(task_id=task_id) - if run_id is not None: - query = query.filter_by(run_id=run_id) - query.delete(synchronize_session=False) - - -def discover_engine_operator(name: str, sql_only: bool = False) -> t.Type[BaseOperator]: - name = name.lower() - - try: - if name == "spark": - from sqlmesh.schedulers.airflow.operators.spark_submit import ( - SQLMeshSparkSubmitOperator, - ) - - return SQLMeshSparkSubmitOperator - if name in ("databricks", "databricks-submit", "databricks-sql"): - if name == "databricks-submit" or (name == "databricks" and not sql_only): - from sqlmesh.schedulers.airflow.operators.databricks import ( - SQLMeshDatabricksSubmitOperator, - ) - - return SQLMeshDatabricksSubmitOperator - if name == "databricks-sql" or (name == "databricks" and sql_only): - from sqlmesh.schedulers.airflow.operators.databricks import ( - SQLMeshDatabricksSQLOperator, - ) - - return SQLMeshDatabricksSQLOperator - if name == "snowflake": - from sqlmesh.schedulers.airflow.operators.snowflake import ( - SQLMeshSnowflakeOperator, - ) - - return SQLMeshSnowflakeOperator - if name == "bigquery": - from sqlmesh.schedulers.airflow.operators.bigquery 
import ( - SQLMeshBigQueryOperator, - ) - - return SQLMeshBigQueryOperator - if name == "redshift": - from sqlmesh.schedulers.airflow.operators.redshift import ( - SQLMeshRedshiftOperator, - ) - - return SQLMeshRedshiftOperator - if name in ("postgres", "postgresql"): - from sqlmesh.schedulers.airflow.operators.postgres import ( - SQLMeshPostgresOperator, - ) - - return SQLMeshPostgresOperator - if name == "trino": - from sqlmesh.schedulers.airflow.operators.trino import SQLMeshTrinoOperator - - return SQLMeshTrinoOperator - if name == "mssql": - from sqlmesh.schedulers.airflow.operators.mssql import SQLMeshMsSqlOperator - - return SQLMeshMsSqlOperator - - if name == "mysql": - from sqlmesh.schedulers.airflow.operators.mysql import SQLMeshMySqlOperator - - return SQLMeshMySqlOperator - except ImportError: - raise SQLMeshError(f"Failed to automatically discover an operator for '{name}'.'") - - raise ValueError(f"Unsupported engine name '{name}'.") diff --git a/sqlmesh/utils/__init__.py b/sqlmesh/utils/__init__.py index f25dfb5102..5b1b077216 100644 --- a/sqlmesh/utils/__init__.py +++ b/sqlmesh/utils/__init__.py @@ -13,12 +13,15 @@ import types import typing as t import uuid +from dataclasses import dataclass from collections import defaultdict from contextlib import contextmanager from copy import deepcopy +from enum import IntEnum, Enum from functools import lru_cache, reduce, wraps from pathlib import Path +import unicodedata from sqlglot import exp from sqlglot.dialects.dialect import Dialects @@ -65,7 +68,7 @@ def random_id(short: bool = False) -> str: return uuid.uuid4().hex -class UniqueKeyDict(dict, t.Mapping[KEY, VALUE]): +class UniqueKeyDict(t.Dict[KEY, VALUE]): """Dict that raises when a duplicate key is set.""" def __init__(self, name: str, *args: t.Dict[KEY, VALUE], **kwargs: VALUE) -> None: @@ -81,7 +84,10 @@ def __setitem__(self, k: KEY, v: VALUE) -> None: class AttributeDict(dict, t.Mapping[KEY, VALUE]): - __getattr__ = dict.get + def 
__getattr__(self, key: t.Any) -> t.Optional[VALUE]: + if key.startswith("__") and not hasattr(self, key): + raise AttributeError + return self.get(key) def set(self, field: str, value: t.Any) -> str: self[field] = value @@ -100,6 +106,9 @@ def __call__(self, **kwargs: t.Dict[str, t.Any]) -> str: # Return an empty string, so that this method can be used within Jinja return "" + def __getstate__(self) -> t.Optional[t.Dict[t.Any, t.Any]]: + return None + class registry_decorator: """A decorator that registers itself.""" @@ -129,7 +138,7 @@ def __call__( except ValueError: # No need to raise due to duplicate key if the functions are identical if func.__code__.co_code != registry[func_name].func.__code__.co_code: - raise + raise ValueError(f"Duplicate name: '{func_name}'.") @wraps(func) def wrapper(*args: t.Any, **kwargs: t.Any) -> DECORATOR_RETURN_TYPE: @@ -170,8 +179,7 @@ def sys_path(*paths: Path) -> t.Iterator[None]: def format_exception(exception: BaseException) -> t.List[str]: if sys.version_info < (3, 10): return traceback.format_exception(type(exception), exception, exception.__traceback__) # type: ignore - else: - return traceback.format_exception(exception) # type: ignore + return traceback.format_exception(exception) # type: ignore def word_characters_only(s: str, replacement_char: str = "_") -> str: @@ -284,8 +292,14 @@ def sqlglot_dialects() -> str: NON_ALNUM = re.compile(r"[^a-zA-Z0-9_]") +NON_ALUM_INCLUDE_UNICODE = re.compile(r"\W", flags=re.UNICODE) -def sanitize_name(name: str) -> str: + +def sanitize_name(name: str, *, include_unicode: bool = False) -> str: + if include_unicode: + s = unicodedata.normalize("NFC", name) + s = NON_ALUM_INCLUDE_UNICODE.sub("_", s) + return s return NON_ALNUM.sub("_", name) @@ -332,3 +346,79 @@ def type_is_known(d_type: t.Union[exp.DataType, exp.ColumnDef]) -> bool: def columns_to_types_all_known(columns_to_types: t.Dict[str, exp.DataType]) -> bool: """Checks that all column types are known and not NULL.""" return 
all(type_is_known(expression) for expression in columns_to_types.values()) + + +class Verbosity(IntEnum): + """Verbosity levels for SQLMesh output.""" + + DEFAULT = 0 + VERBOSE = 1 + VERY_VERBOSE = 2 + + @property + def is_default(self) -> bool: + return self == Verbosity.DEFAULT + + @property + def is_verbose(self) -> bool: + return self == Verbosity.VERBOSE + + @property + def is_very_verbose(self) -> bool: + return self == Verbosity.VERY_VERBOSE + + +class CompletionStatus(Enum): + SUCCESS = "success" + FAILURE = "failure" + NOTHING_TO_DO = "nothing_to_do" + + @property + def is_success(self) -> bool: + return self == CompletionStatus.SUCCESS + + @property + def is_failure(self) -> bool: + return self == CompletionStatus.FAILURE + + @property + def is_nothing_to_do(self) -> bool: + return self == CompletionStatus.NOTHING_TO_DO + + +def to_snake_case(name: str) -> str: + return "".join( + f"_{c.lower()}" if c.isupper() and idx != 0 else c.lower() for idx, c in enumerate(name) + ) + + +class JobType(Enum): + PLAN = "SQLMESH_PLAN" + RUN = "SQLMESH_RUN" + + +@dataclass(frozen=True) +class CorrelationId: + """ID that is added to each query in order to identify the job that created it.""" + + job_type: JobType + job_id: str + + def __str__(self) -> str: + return f"{self.job_type.value}: {self.job_id}" + + @classmethod + def from_plan_id(cls, plan_id: str) -> CorrelationId: + return CorrelationId(JobType.PLAN, plan_id) + + +def get_source_columns_to_types( + columns_to_types: t.Dict[str, exp.DataType], + source_columns: t.Optional[t.List[str]], +) -> t.Dict[str, exp.DataType]: + source_column_lookup = set(source_columns) if source_columns else None + return { + k: v + for k, v in columns_to_types.items() + if not source_column_lookup or k in source_column_lookup + } diff --git a/sqlmesh/utils/aws.py b/sqlmesh/utils/aws.py new file mode 100644 index 0000000000..ed9c4f723c --- /dev/null +++ b/sqlmesh/utils/aws.py @@ -0,0 +1,39 @@ +import typing as t +from urllib.parse 
import urlparse + +from sqlmesh.utils.errors import SQLMeshError + + +def validate_s3_uri( + value: str, base: bool = False, error_type: t.Type[Exception] = SQLMeshError +) -> str: + if not value.startswith("s3://"): + raise error_type(f"Location '{value}' must be a s3:// URI") + + if base and not value.endswith("/"): + value = value + "/" + + # To avoid HIVE_METASTORE_ERROR: S3 resource path length must be less than or equal to 700. + if len(value) > 700: + raise error_type(f"Location '{value}' cannot be more than 700 characters") + + return value + + +def parse_s3_uri(s3_uri: str) -> t.Tuple[str, str]: + """ + Given a s3:// URI, parse it into a pair of (bucket, key) + + Note that this could be any URI, including a file key, so we dont add a trailing / unlike validate_s3_base_uri + """ + validate_s3_uri(s3_uri) + + parsed_uri = urlparse(s3_uri) + + bucket = parsed_uri.netloc + key = parsed_uri.path + + if key: + key = key[1:] # trim off leading / + + return bucket, key diff --git a/sqlmesh/utils/cache.py b/sqlmesh/utils/cache.py index 79ada3c2a2..e72c34f632 100644 --- a/sqlmesh/utils/cache.py +++ b/sqlmesh/utils/cache.py @@ -3,6 +3,7 @@ import gzip import logging import pickle +import shutil import typing as t from pathlib import Path @@ -11,11 +12,11 @@ from sqlmesh.utils import sanitize_name from sqlmesh.utils.date import to_datetime from sqlmesh.utils.errors import SQLMeshError -from sqlmesh.utils.pydantic import PydanticModel +from sqlmesh.utils.windows import IS_WINDOWS, fix_windows_path logger = logging.getLogger(__name__) -T = t.TypeVar("T", bound=PydanticModel) +T = t.TypeVar("T") SQLGLOT_VERSION_TUPLE = tuple(SQLGLOT_VERSION.split(".")) @@ -33,14 +34,8 @@ class FileCache(t.Generic[T]): stored in the same cache folder. 
""" - def __init__( - self, - path: Path, - entry_class: t.Type[T], - prefix: t.Optional[str] = None, - ): + def __init__(self, path: Path, prefix: t.Optional[str] = None): self._path = path / prefix if prefix else path - self._entry_class = entry_class from sqlmesh.core.state_sync.base import SCHEMA_VERSION @@ -64,7 +59,11 @@ def __init__( threshold = to_datetime("1 week ago").timestamp() # delete all old cache files for file in self._path.glob("*"): - if not file.stem.startswith(self._cache_version) or file.stat().st_mtime < threshold: + if IS_WINDOWS: + # the file.stat() call below will fail on windows if the :file name is longer than 260 chars + file = fix_windows_path(file) + + if not file.stem.startswith(self._cache_version) or file.stat().st_atime < threshold: file.unlink(missing_ok=True) def get_or_load(self, name: str, entry_id: str = "", *, loader: t.Callable[[], T]) -> T: @@ -79,7 +78,7 @@ def get_or_load(self, name: str, entry_id: str = "", *, loader: t.Callable[[], T The entry. """ cached_entry = self.get(name, entry_id) - if cached_entry: + if cached_entry is not None: return cached_entry loaded_entry = loader() @@ -100,7 +99,7 @@ def get(self, name: str, entry_id: str = "") -> t.Optional[T]: if cache_entry_path.exists(): with gzip.open(cache_entry_path, "rb") as fd: try: - return self._entry_class.parse_obj(pickle.load(fd)) + return pickle.load(fd) except Exception as ex: logger.warning("Failed to load a cache entry '%s': %s", name, ex) @@ -119,8 +118,27 @@ def put(self, name: str, entry_id: str = "", *, value: T) -> None: raise SQLMeshError(f"Cache path '{self._path}' is not a directory.") with gzip.open(self._cache_entry_path(name, entry_id), "wb", compresslevel=1) as fd: - pickle.dump(value.dict(exclude_none=False), fd) + pickle.dump(value, fd) + + def exists(self, name: str, entry_id: str = "") -> bool: + """Returns true if the cache entry with the given name and ID exists, false otherwise. + + Args: + name: The name of the entry. 
+ entry_id: The unique entry identifier. Used for cache invalidation. + """ + return self._cache_entry_path(name, entry_id).exists() + + def clear(self) -> None: + try: + shutil.rmtree(str(self._path.absolute())) + except Exception: + pass def _cache_entry_path(self, name: str, entry_id: str = "") -> Path: entry_file_name = "__".join(p for p in (self._cache_version, name, entry_id) if p) - return self._path / sanitize_name(entry_file_name) + full_path = self._path / sanitize_name(entry_file_name, include_unicode=True) + if IS_WINDOWS: + # handle paths longer than 260 chars + full_path = fix_windows_path(full_path) + return full_path diff --git a/sqlmesh/utils/concurrency.py b/sqlmesh/utils/concurrency.py index 9881bb5b0a..c5f78645f6 100644 --- a/sqlmesh/utils/concurrency.py +++ b/sqlmesh/utils/concurrency.py @@ -105,16 +105,23 @@ def _skip_next_nodes(self, parent: H) -> None: self._finished_future.set_result(None) return - skipped_nodes = [node for node, deps in self._unprocessed_nodes.items() if parent in deps] + skipped_nodes = {node for node, deps in self._unprocessed_nodes.items() if parent in deps} - self._skipped_nodes.extend(skipped_nodes) + while skipped_nodes: + self._skipped_nodes.extend(skipped_nodes) - for skipped_node in skipped_nodes: - self._unprocessed_nodes_num -= 1 - self._unprocessed_nodes.pop(skipped_node) + for skipped_node in skipped_nodes: + self._unprocessed_nodes_num -= 1 + self._unprocessed_nodes.pop(skipped_node) + + skipped_nodes = { + node + for node, deps in self._unprocessed_nodes.items() + if skipped_nodes.intersection(deps) + } - for skipped_node in skipped_nodes: - self._skip_next_nodes(skipped_node) + if not self._unprocessed_nodes_num: + self._finished_future.set_result(None) def _init_state(self) -> None: self._unprocessed_nodes = self.dag.graph @@ -226,12 +233,12 @@ def sequential_apply_to_dag( try: fn(node) except Exception as ex: - if raise_on_error: - raise NodeExecutionFailedError(node) from ex - error = 
NodeExecutionFailedError(node) error.__cause__ = ex + if raise_on_error: + raise error + node_errors.append(error) failed_or_skipped_nodes.add(node) diff --git a/sqlmesh/utils/config.py b/sqlmesh/utils/config.py new file mode 100644 index 0000000000..248f3adcc7 --- /dev/null +++ b/sqlmesh/utils/config.py @@ -0,0 +1,76 @@ +from typing import Any, Optional, Set + +from sqlmesh.core.config.connection import ConnectionConfig +from sqlmesh.utils import yaml + + +# Fields that should be excluded from the configuration hash +excluded_fields: Set[str] = { + "concurrent_tasks", + "pre_ping", + "register_comments", +} + +# Sensitive fields that should be masked in the configuration print or hash +sensitive_fields: Set[str] = { + "access_token", + "api_key", + "auth_token", + "client_secret", + "certificate", + "credentials", + "user", + "password", + "keytab", + "keyfile", + "keyfile_json", + "principal", + "private_key", + "private_key_passphrase", + "private_key_path", + "refresh_token", + "secret", + "ssh_key", + "token", +} + + +def is_sensitive_field(field_name: str, sensitive_fields: Set[str]) -> bool: + """ + Check if a field name contains any sensitive keywords + """ + field_lower = field_name.lower() + return any(sensitive in field_lower for sensitive in sensitive_fields) + + +def mask_sensitive_value(value: Any) -> str: + """ + Mask sensitive values with a placeholder + Returns '****' for non-empty values and '' for empty ones + """ + if value and str(value).strip(): + return "****" + return "None" + + +def print_config(config: Optional[ConnectionConfig], console: Any, title: str) -> None: + """ + Print configuration while masking sensitive information + + Args: + config: Pydantic model containing configuration + console: Console object with log_status_update method + """ + if not config: + return + + config_dict = config.dict(mode="json") + + for field_name in config_dict: + if is_sensitive_field(field_name, sensitive_fields): + config_dict[field_name] = 
mask_sensitive_value(config_dict[field_name]) + + configWithTitle = {title: config_dict} + yaml_output = yaml.dump(configWithTitle) + + console.log_status_update(yaml_output) diff --git a/sqlmesh/utils/connection_pool.py b/sqlmesh/utils/connection_pool.py index e3eb806810..9a70db6885 100644 --- a/sqlmesh/utils/connection_pool.py +++ b/sqlmesh/utils/connection_pool.py @@ -48,6 +48,17 @@ def set_attribute(self, key: str, value: t.Any) -> None: value: Attribute value. """ + @abc.abstractmethod + def get_all_attributes(self, key: str) -> t.List[t.Any]: + """Returns all attributes with the given key across all connections/threads. + + Args: + key: Attribute key. + + Returns: + List of attribute values from all connections/threads. + """ + @abc.abstractmethod def begin(self) -> None: """Starts a new transaction.""" @@ -111,40 +122,29 @@ def _do_rollback(self) -> None: self.get().rollback() -class ThreadLocalConnectionPool(_TransactionManagementMixin): +class _ThreadLocalBase(_TransactionManagementMixin): def __init__( self, connection_factory: t.Callable[[], t.Any], - cursor_kwargs: t.Optional[t.Dict[str, t.Any]] = None, cursor_init: t.Optional[t.Callable[[t.Any], None]] = None, ): self._connection_factory = connection_factory - self._thread_connections: t.Dict[t.Hashable, t.Any] = {} self._thread_cursors: t.Dict[t.Hashable, t.Any] = {} self._thread_transactions: t.Set[t.Hashable] = set() self._thread_attributes: t.Dict[t.Hashable, t.Dict[str, t.Any]] = defaultdict(dict) - self._thread_connections_lock = Lock() self._thread_cursors_lock = Lock() self._thread_transactions_lock = Lock() - self._cursor_kwargs = cursor_kwargs or {} self._cursor_init = cursor_init def get_cursor(self) -> t.Any: thread_id = get_ident() with self._thread_cursors_lock: if thread_id not in self._thread_cursors: - self._thread_cursors[thread_id] = self.get().cursor(**self._cursor_kwargs) + self._thread_cursors[thread_id] = self.get().cursor() if self._cursor_init: 
self._cursor_init(self._thread_cursors[thread_id]) return self._thread_cursors[thread_id] - def get(self) -> t.Any: - thread_id = get_ident() - with self._thread_connections_lock: - if thread_id not in self._thread_connections: - self._thread_connections[thread_id] = self._connection_factory() - return self._thread_connections[thread_id] - def get_attribute(self, key: str) -> t.Optional[t.Any]: thread_id = get_ident() return self._thread_attributes[thread_id].get(key) @@ -153,6 +153,14 @@ def set_attribute(self, key: str, value: t.Any) -> None: thread_id = get_ident() self._thread_attributes[thread_id][key] = value + def get_all_attributes(self, key: str) -> t.List[t.Any]: + """Returns all attributes with the given key across all threads.""" + return [ + thread_attrs[key] + for thread_attrs in self._thread_attributes.values() + if key in thread_attrs + ] + def begin(self) -> None: self._do_begin() with self._thread_transactions_lock: @@ -178,6 +186,28 @@ def close_cursor(self) -> None: _try_close(self._thread_cursors[thread_id], "cursor") self._thread_cursors.pop(thread_id) + def _discard_transaction(self, thread_id: t.Hashable) -> None: + with self._thread_transactions_lock: + self._thread_transactions.discard(thread_id) + + +class ThreadLocalConnectionPool(_ThreadLocalBase): + def __init__( + self, + connection_factory: t.Callable[[], t.Any], + cursor_init: t.Optional[t.Callable[[t.Any], None]] = None, + ): + super().__init__(connection_factory, cursor_init) + self._thread_connections: t.Dict[t.Hashable, t.Any] = {} + self._thread_connections_lock = Lock() + + def get(self) -> t.Any: + thread_id = get_ident() + with self._thread_connections_lock: + if thread_id not in self._thread_connections: + self._thread_connections[thread_id] = self._connection_factory() + return self._thread_connections[thread_id] + def close(self) -> None: thread_id = get_ident() with self._thread_cursors_lock, self._thread_connections_lock: @@ -193,36 +223,70 @@ def close_all(self, 
exclude_calling_thread: bool = False) -> None: with self._thread_cursors_lock, self._thread_connections_lock: for thread_id, connection in self._thread_connections.copy().items(): if not exclude_calling_thread or thread_id != calling_thread_id: - # NOTE: the access to the connection instance itself is not thread-safe here. _try_close(connection, "connection") self._thread_connections.pop(thread_id) self._thread_cursors.pop(thread_id, None) self._discard_transaction(thread_id) + + self._thread_attributes.clear() + + +class ThreadLocalSharedConnectionPool(_ThreadLocalBase): + def __init__( + self, + connection_factory: t.Callable[[], t.Any], + cursor_init: t.Optional[t.Callable[[t.Any], None]] = None, + ): + super().__init__(connection_factory, cursor_init) + self._connection: t.Optional[t.Any] = None + self._connection_lock = Lock() + + def get(self) -> t.Any: + with self._connection_lock: + if self._connection is None: + self._connection = self._connection_factory() + return self._connection + + def close(self) -> None: + thread_id = get_ident() + with self._thread_cursors_lock, self._connection_lock: + if thread_id in self._thread_cursors: + _try_close(self._thread_cursors[thread_id], "cursor") + self._thread_cursors.pop(thread_id) + self._discard_transaction(thread_id) + self._thread_attributes.pop(thread_id, None) + + def close_all(self, exclude_calling_thread: bool = False) -> None: + calling_thread_id = get_ident() + with self._thread_cursors_lock, self._connection_lock: + for thread_id, cursor in self._thread_cursors.copy().items(): + if not exclude_calling_thread or thread_id != calling_thread_id: + _try_close(cursor, "cursor") + self._thread_cursors.pop(thread_id) + self._discard_transaction(thread_id) self._thread_attributes.pop(thread_id, None) - def _discard_transaction(self, thread_id: t.Hashable) -> None: - with self._thread_transactions_lock: - self._thread_transactions.discard(thread_id) + if not exclude_calling_thread: + _try_close(self._connection, 
"connection") + self._connection = None class SingletonConnectionPool(_TransactionManagementMixin): def __init__( self, connection_factory: t.Callable[[], t.Any], - cursor_kwargs: t.Optional[t.Dict[str, t.Any]] = None, cursor_init: t.Optional[t.Callable[[t.Any], None]] = None, ): self._connection_factory = connection_factory self._connection: t.Optional[t.Any] = None self._cursor: t.Optional[t.Any] = None - self._cursor_kwargs = cursor_kwargs or {} self._attributes: t.Dict[str, t.Any] = {} self._is_transaction_active: bool = False self._cursor_init = cursor_init def get_cursor(self) -> t.Any: if not self._cursor: - self._cursor = self.get().cursor(**self._cursor_kwargs) + self._cursor = self.get().cursor() if self._cursor_init: self._cursor_init(self._cursor) return self._cursor @@ -238,6 +302,11 @@ def get_attribute(self, key: str) -> t.Optional[t.Any]: def set_attribute(self, key: str, value: t.Any) -> None: self._attributes[key] = value + def get_all_attributes(self, key: str) -> t.List[t.Any]: + """Returns all attributes with the given key (single-threaded pool has at most one).""" + value = self._attributes.get(key) + return [value] if value is not None else [] + def begin(self) -> None: self._do_begin() self._is_transaction_active = True @@ -273,18 +342,17 @@ def close_all(self, exclude_calling_thread: bool = False) -> None: def create_connection_pool( connection_factory: t.Callable[[], t.Any], multithreaded: bool, - cursor_kwargs: t.Optional[t.Dict[str, t.Any]] = None, + shared_connection: bool = False, cursor_init: t.Optional[t.Callable[[t.Any], None]] = None, ) -> ConnectionPool: - return ( - ThreadLocalConnectionPool( - connection_factory, cursor_kwargs=cursor_kwargs, cursor_init=cursor_init - ) + pool_class = ( + ThreadLocalSharedConnectionPool + if multithreaded and shared_connection + else ThreadLocalConnectionPool if multithreaded - else SingletonConnectionPool( - connection_factory, cursor_kwargs=cursor_kwargs, cursor_init=cursor_init - ) + else 
SingletonConnectionPool ) + return pool_class(connection_factory, cursor_init=cursor_init) def _try_close(closeable: t.Any, kind: str) -> None: diff --git a/sqlmesh/utils/conversions.py b/sqlmesh/utils/conversions.py index 2b92772022..411f3c8ab1 100644 --- a/sqlmesh/utils/conversions.py +++ b/sqlmesh/utils/conversions.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing as t +from datetime import date, datetime def ensure_bool(val: t.Any) -> bool: @@ -19,3 +20,13 @@ def try_str_to_bool(val: str) -> t.Union[str, bool]: return maybe_bool == "true" return val + + +def make_serializable(obj: t.Any) -> t.Any: + if isinstance(obj, (date, datetime)): + return obj.isoformat() + if isinstance(obj, dict): + return {k: make_serializable(v) for k, v in obj.items()} + if isinstance(obj, list): + return [make_serializable(item) for item in obj] + return obj diff --git a/sqlmesh/utils/cron.py b/sqlmesh/utils/cron.py index 7950f87df2..904202db7c 100644 --- a/sqlmesh/utils/cron.py +++ b/sqlmesh/utils/cron.py @@ -1,7 +1,7 @@ from __future__ import annotations import typing as t -from datetime import datetime, timedelta +from datetime import datetime, timedelta, tzinfo from functools import lru_cache from croniter import croniter @@ -10,7 +10,7 @@ from sqlmesh.utils.date import TimeLike, now, to_datetime -@lru_cache(maxsize=None) +@lru_cache(maxsize=16384) def interval_seconds(cron: str) -> int: """Computes the interval seconds of a cron statement if it is deterministic. 
@@ -34,21 +34,22 @@ def interval_seconds(cron: str) -> int: class CroniterCache: - def __init__(self, cron: str, time: t.Optional[TimeLike] = None): + def __init__(self, cron: str, time: t.Optional[TimeLike] = None, tz: t.Optional[tzinfo] = None): self.cron = cron - self.curr: datetime = to_datetime(now() if time is None else time) + self.tz = tz + self.curr: datetime = to_datetime(now() if time is None else time, tz=self.tz) self.interval_seconds = interval_seconds(self.cron) def get_next(self, estimate: bool = False) -> datetime: if estimate and self.interval_seconds: self.curr = self.curr + timedelta(seconds=self.interval_seconds) else: - self.curr = to_datetime(croniter(self.cron, self.curr).get_next() * 1000) + self.curr = to_datetime(croniter(self.cron, self.curr).get_next() * 1000, tz=self.tz) return self.curr def get_prev(self, estimate: bool = False) -> datetime: if estimate and self.interval_seconds: self.curr = self.curr - timedelta(seconds=self.interval_seconds) else: - self.curr = to_datetime(croniter(self.cron, self.curr).get_prev() * 1000) + self.curr = to_datetime(croniter(self.cron, self.curr).get_prev() * 1000, tz=self.tz) return self.curr diff --git a/sqlmesh/utils/dag.py b/sqlmesh/utils/dag.py index 69c4585597..c39fd2a1d2 100644 --- a/sqlmesh/utils/dag.py +++ b/sqlmesh/utils/dag.py @@ -19,6 +19,7 @@ class DAG(t.Generic[T]): def __init__(self, graph: t.Optional[t.Dict[T, t.Set[T]]] = None): self._dag: t.Dict[T, t.Set[T]] = {} self._sorted: t.Optional[t.List[T]] = None + self._upstream: t.Dict[T, t.Set[T]] = {} for node, dependencies in (graph or {}).items(): self.add(node, dependencies) @@ -31,6 +32,7 @@ def add(self, node: T, dependencies: t.Optional[t.Iterable[T]] = None) -> None: dependencies: Optional dependencies to add to the node. 
""" self._sorted = None + self._upstream.clear() if node not in self._dag: self._dag[node] = set() if dependencies: @@ -60,15 +62,15 @@ def subdag(self, *nodes: T) -> DAG[T]: A new dag consisting of the specified nodes and upstream. """ queue = set(nodes) - graph = {} + dag: DAG[T] = DAG() while queue: node = queue.pop() deps = self._dag.get(node, set()) - graph[node] = deps + dag.add(node, deps) queue.update(deps) - return DAG(graph) + return dag def prune(self, *nodes: T) -> DAG[T]: """Create a dag keeping only the included nodes. @@ -79,17 +81,70 @@ def prune(self, *nodes: T) -> DAG[T]: Returns: A new dag consisting of the specified nodes. """ - graph = {} + dag: DAG[T] = DAG() for node, deps in self._dag.items(): if node in nodes: - graph[node] = {dep for dep in deps if dep in nodes} + dag.add(node, (dep for dep in deps if dep in nodes)) + + return dag + + def upstream(self, node: T) -> t.Set[T]: + """Returns all upstream dependencies.""" + if node not in self._upstream: + deps = self._dag.get(node, set()) + self._upstream[node] = { + upstream for dep in deps for upstream in self.upstream(dep) + } | deps + + return self._upstream[node] - return DAG(graph) + def _find_cycle_path(self, nodes_in_cycle: t.Dict[T, t.Set[T]]) -> t.Optional[t.List[T]]: + """Find the exact cycle path using DFS when a cycle is detected. 
+ + Args: + nodes_in_cycle: Dictionary of nodes that are part of the cycle and their dependencies + + Returns: + List of nodes forming the cycle path, or None if no cycle found + """ + if not nodes_in_cycle: + return None - def upstream(self, node: T) -> t.List[T]: - """Returns all upstream dependencies in topologically sorted order.""" - return self.subdag(node).sorted[:-1] + # Use DFS to find a cycle path + visited: t.Set[T] = set() + path: t.List[T] = [] + + def dfs(node: T) -> t.Optional[t.List[T]]: + if node in path: + # Found a cycle - extract the cycle path + cycle_start = path.index(node) + return path[cycle_start:] + [node] + + if node in visited: + return None + + visited.add(node) + path.append(node) + + # Only follow edges to nodes that are still in the unprocessed set + for neighbor in nodes_in_cycle.get(node, set()): + if neighbor in nodes_in_cycle: + cycle = dfs(neighbor) + if cycle: + return cycle + + path.pop() + return None + + # Try starting DFS from each unvisited node + for start_node in nodes_in_cycle: + if start_node not in visited: + cycle = dfs(start_node) + if cycle: + return cycle[:-1] # Remove the duplicate node at the end + + return None @property def roots(self) -> t.Set[T]: @@ -108,7 +163,6 @@ def sorted(self) -> t.List[T]: """Returns a list of nodes sorted in topological order.""" if self._sorted is None: self._sorted = [] - unprocessed_nodes = self.graph last_processed_nodes: t.Set[T] = set() @@ -118,23 +172,31 @@ def sorted(self) -> t.List[T]: next_nodes = {node for node, deps in unprocessed_nodes.items() if not deps} if not next_nodes: - # Sort cycle candidates to make the order deterministic - cycle_candidates_msg = ( - "\nPossible candidates to check for circular references: " - + ", ".join(str(node) for node in sorted(cycle_candidates)) - ) + # A cycle was detected - find the exact cycle path + cycle_path = self._find_cycle_path(unprocessed_nodes) - if last_processed_nodes: - last_processed_msg = "\nLast nodes added to the DAG: 
" + ", ".join( - str(node) for node in last_processed_nodes + last_processed_msg = "" + if cycle_path: + node_output = " ->\n".join( + str(node) for node in (cycle_path + [cycle_path[0]]) ) + cycle_msg = f"\nCycle:\n{node_output}" else: - last_processed_msg = "" + # Fallback message in case a cycle can't be found + cycle_candidates_msg = ( + "\nPossible candidates to check for circular references: " + + ", ".join(str(node) for node in sorted(cycle_candidates)) + ) + cycle_msg = cycle_candidates_msg + if last_processed_nodes: + last_processed_msg = "\nLast nodes added to the DAG: " + ", ".join( + str(node) for node in last_processed_nodes + ) raise SQLMeshError( "Detected a cycle in the DAG. " "Please make sure there are no circular references between nodes." - f"{last_processed_msg}{cycle_candidates_msg}" + f"{last_processed_msg}{cycle_msg}" ) for node in next_nodes: diff --git a/sqlmesh/utils/date.py b/sqlmesh/utils/date.py index 1f8bf421bf..c9bb19c835 100644 --- a/sqlmesh/utils/date.py +++ b/sqlmesh/utils/date.py @@ -5,24 +5,26 @@ import typing as t import warnings -from pandas.api.types import is_datetime64_any_dtype # type: ignore - -from datetime import date, datetime, timedelta, timezone +from datetime import date, datetime, timedelta, timezone, tzinfo import dateparser -import pandas as pd from dateparser import freshness_date_parser as freshness_date_parser_module from dateparser.freshness_date_parser import freshness_date_parser from sqlglot import exp from sqlmesh.utils import ttl_cache +if t.TYPE_CHECKING: + import pandas as pd + + from sqlglot.dialects.dialect import DialectType + UTC = timezone.utc TimeLike = t.Union[date, datetime, str, int, float] +DatetimeRange = t.Tuple[datetime, datetime] +DatetimeRanges = t.List[DatetimeRange] DATE_INT_FMT = "%Y%m%d" -if t.TYPE_CHECKING: - from sqlmesh.core.scheduler import Interval warnings.filterwarnings( "ignore", @@ -31,7 +33,7 @@ # The Freshness Date Data Parser doesn't support plural units so we add the 
`s?` to the expression -freshness_date_parser_module.PATTERN = re.compile( +freshness_date_parser_module.PATTERN = re.compile( # type: ignore r"(\d+[.,]?\d*)\s*(%s)s?\b" % freshness_date_parser_module._UNITS, # type: ignore re.I | re.S | re.U, # type: ignore ) @@ -85,63 +87,82 @@ def now_ds() -> str: return to_ds(now()) -def yesterday() -> datetime: +def yesterday(relative_base: t.Optional[datetime] = None) -> datetime: """ Yesterday utc datetime. Returns: A datetime object with tz utc representing yesterday's date """ - return to_datetime("yesterday") + return to_datetime("yesterday", relative_base=relative_base) -def yesterday_ds() -> str: +def yesterday_ds(relative_base: t.Optional[datetime] = None) -> str: """ Yesterday utc ds. Returns: Yesterday's ds string. """ - return to_ds("yesterday") + return to_ds("yesterday", relative_base=relative_base) -def yesterday_timestamp() -> int: +def yesterday_timestamp(relative_base: t.Optional[datetime] = None) -> int: """ Yesterday utc timestamp. Returns: UTC epoch millis timestamp of yesterday """ - return to_timestamp(yesterday()) + return to_timestamp(yesterday(relative_base=relative_base)) -def to_timestamp(value: TimeLike, relative_base: t.Optional[datetime] = None) -> int: +def to_timestamp( + value: TimeLike, + relative_base: t.Optional[datetime] = None, + check_categorical_relative_expression: bool = True, +) -> int: """ Converts a value into an epoch millis timestamp. Args: value: A variety of date formats. If value is a string, it must be in iso format. relative_base: The datetime to reference for time expressions that are using relative terms + check_categorical_relative_expression: If True, takes into account the relative expressions that are categorical. Returns: Epoch millis timestamp. 
""" - return int(to_datetime(value, relative_base=relative_base).timestamp() * 1000) + return int( + to_datetime( + value, + relative_base=relative_base, + check_categorical_relative_expression=check_categorical_relative_expression, + ).timestamp() + * 1000 + ) @ttl_cache() -def to_datetime(value: TimeLike, relative_base: t.Optional[datetime] = None) -> datetime: +def to_datetime( + value: TimeLike, + relative_base: t.Optional[datetime] = None, + check_categorical_relative_expression: bool = True, + tz: t.Optional[tzinfo] = None, +) -> datetime: """Converts a value into a UTC datetime object. Args: value: A variety of date formats. If the value is number-like, it is assumed to be millisecond epochs. - relative_base: The datetime to reference for time expressions that are using relative terms + relative_base: The datetime to reference for time expressions that are using relative terms. + check_categorical_relative_expression: If True, takes into account the relative expressions that are categorical. + tz: Timezone to convert datetime to, defaults to utc Raises: ValueError if value cannot be converted to a datetime. Returns: - A datetime object with tz utc. + A datetime object with tz (default UTC). 
""" if isinstance(value, datetime): dt: t.Optional[datetime] = value @@ -158,9 +179,17 @@ def to_datetime(value: TimeLike, relative_base: t.Optional[datetime] = None) -> if epoch is None: relative_base = relative_base or now() expression = str(value) - if is_catagorical_relative_expression(expression): + if check_categorical_relative_expression and is_categorical_relative_expression( + expression + ): relative_base = relative_base.replace(hour=0, minute=0, second=0, microsecond=0) - dt = dateparser.parse(expression, settings={"RELATIVE_BASE": relative_base}) + + # note: we hardcode TIMEZONE: UTC to work around this bug: https://github.com/scrapinghub/dateparser/issues/896 + # where dateparser just silently fails if it cant interpret the contents of /etc/localtime + # this works because SQLMesh only deals with UTC, there is no concept of user local time + dt = dateparser.parse( + expression, settings={"RELATIVE_BASE": relative_base, "TIMEZONE": "UTC"} + ) else: try: dt = datetime.strptime(str(value), DATE_INT_FMT) @@ -170,9 +199,11 @@ def to_datetime(value: TimeLike, relative_base: t.Optional[datetime] = None) -> if dt is None: raise ValueError(f"Could not convert `{value}` to datetime.") + tz = tz or UTC + if dt.tzinfo: - return dt if dt.tzinfo == UTC else dt.astimezone(UTC) - return dt.replace(tzinfo=UTC) + return dt if dt.tzinfo == tz else dt.astimezone(tz) + return dt.replace(tzinfo=tz) def to_date(value: TimeLike, relative_base: t.Optional[datetime] = None) -> date: @@ -191,8 +222,10 @@ def to_date(value: TimeLike, relative_base: t.Optional[datetime] = None) -> date def date_dict( - execution_time: TimeLike, start: t.Optional[TimeLike], end: t.Optional[TimeLike] -) -> t.Dict[str, t.Union[str, datetime, date, float, int]]: + execution_time: TimeLike, + start: t.Optional[TimeLike], + end: t.Optional[TimeLike], +) -> t.Dict[str, TimeLike]: """Creates a kwarg dictionary of datetime variables for use in SQL Contexts. 
Keys are like start_date, start_ds, end_date, end_ds... @@ -220,8 +253,12 @@ def date_dict( for prefix, time_like in prefixes: dt = to_datetime(time_like) + dtntz = dt.replace(tzinfo=None) + millis = to_timestamp(time_like) + kwargs[f"{prefix}_dt"] = dt + kwargs[f"{prefix}_dtntz"] = dtntz kwargs[f"{prefix}_date"] = to_date(dt) kwargs[f"{prefix}_ds"] = to_ds(time_like) kwargs[f"{prefix}_ts"] = to_ts(dt) @@ -229,22 +266,23 @@ def date_dict( kwargs[f"{prefix}_epoch"] = millis / 1000 kwargs[f"{prefix}_millis"] = millis kwargs[f"{prefix}_hour"] = dt.hour + return kwargs -def to_ds(obj: TimeLike) -> str: +def to_ds(obj: TimeLike, relative_base: t.Optional[datetime] = None) -> str: """Converts a TimeLike object into YYYY-MM-DD formatted string.""" - return to_ts(obj)[0:10] + return to_ts(obj, relative_base=relative_base)[0:10] -def to_ts(obj: TimeLike) -> str: +def to_ts(obj: TimeLike, relative_base: t.Optional[datetime] = None) -> str: """Converts a TimeLike object into YYYY-MM-DD HH:MM:SS formatted string.""" - return to_datetime(obj).replace(tzinfo=None).isoformat(sep=" ") + return to_datetime(obj, relative_base=relative_base).replace(tzinfo=None).isoformat(sep=" ") -def to_tstz(obj: TimeLike) -> str: +def to_tstz(obj: TimeLike, relative_base: t.Optional[datetime] = None) -> str: """Converts a TimeLike object into YYYY-MM-DD HH:MM:SS+00:00 formatted string.""" - return to_datetime(obj).isoformat(sep=" ") + return to_datetime(obj, relative_base=relative_base).isoformat(sep=" ") def is_date(obj: TimeLike) -> bool: @@ -259,7 +297,9 @@ def is_date(obj: TimeLike) -> bool: return False -def make_inclusive(start: TimeLike, end: TimeLike) -> Interval: +def make_inclusive( + start: TimeLike, end: TimeLike, dialect: t.Optional[DialectType] = "" +) -> DatetimeRange: """Adjust start and end times to to become inclusive datetimes. 
SQLMesh treats start and end times as inclusive so that filters can be written as @@ -270,7 +310,8 @@ def make_inclusive(start: TimeLike, end: TimeLike) -> Interval: In the ds ('2020-01-01') case, because start_ds and end_ds are categorical, between works even if start_ds and end_ds are equivalent. However, when we move to ts ('2022-01-01 12:00:00'), because timestamps are numeric, using simple equality doesn't make sense. When the end is not a categorical date, then it is - treated as an exclusive range and converted to inclusive by subtracting 1 millisecond. + treated as an exclusive range and converted to inclusive by subtracting 1 microsecond. If the dialect is + T-SQL then 1 nanoseconds is subtracted to account for the increased precision. Args: start: Start timelike object. @@ -283,14 +324,38 @@ def make_inclusive(start: TimeLike, end: TimeLike) -> Interval: Returns: A tuple of inclusive datetime objects. """ - return (to_datetime(start), make_inclusive_end(end)) + return (to_datetime(start), make_inclusive_end(end, dialect=dialect)) + + +def make_inclusive_end(end: TimeLike, dialect: t.Optional[DialectType] = "") -> datetime: + import pandas as pd + exclusive_end = make_exclusive(end) + if dialect == "tsql": + return to_utc_timestamp(exclusive_end) - pd.Timedelta(1, unit="ns") + return exclusive_end - timedelta(microseconds=1) -def make_inclusive_end(end: TimeLike) -> datetime: - end_dt = to_datetime(end) - if is_date(end): - end_dt = end_dt + timedelta(days=1) - return end_dt - timedelta(microseconds=1) + +def make_exclusive(time: TimeLike) -> datetime: + dt = to_datetime(time) + if is_date(time): + dt = dt + timedelta(days=1) + return dt + + +def make_ts_exclusive(time: TimeLike, dialect: DialectType) -> datetime: + ts = to_datetime(time) + if dialect == "tsql": + return to_utc_timestamp(ts) - pd.Timedelta(1, unit="ns") + return ts + timedelta(microseconds=1) + + +def to_utc_timestamp(time: datetime) -> pd.Timestamp: + import pandas as pd + + if 
time.tzinfo is not None: + return pd.Timestamp(time).tz_convert("utc") + return pd.Timestamp(time, tz="utc") def validate_date_range( @@ -311,7 +376,7 @@ def time_like_to_str(time_like: TimeLike) -> str: return to_ts(time_like) -def is_catagorical_relative_expression(expression: str) -> bool: +def is_categorical_relative_expression(expression: str) -> bool: if expression.strip().lower() in DAY_SHORTCUT_EXPRESSIONS: return True grain_kwargs = freshness_date_parser.get_kwargs(expression) @@ -320,33 +385,70 @@ def is_catagorical_relative_expression(expression: str) -> bool: return not any(k in TIME_UNITS for k in grain_kwargs) +def is_relative(value: TimeLike) -> bool: + """ + Tests a TimeLike object to see if it is a relative expression, eg '1 week ago' as opposed to an absolute timestamp + """ + if isinstance(value, str): + return is_categorical_relative_expression(value) + + return False + + def to_time_column( time_column: t.Union[TimeLike, exp.Null], time_column_type: exp.DataType, + dialect: str, time_column_format: t.Optional[str] = None, + nullable: bool = False, ) -> exp.Expression: """Convert a TimeLike object to the same time format and type as the model's time column.""" + if dialect == "clickhouse" and time_column_type.is_type( + *(exp.DataType.TEMPORAL_TYPES - {exp.DataType.Type.DATE, exp.DataType.Type.DATE32}) + ): + if time_column_type.is_type(exp.DataType.Type.DATETIME64): + if nullable: + time_column_type.set("nullable", nullable) + else: + # Clickhouse will error if we pass fractional seconds to DateTime, so we always + # use DateTime64 for timestamps. + # + # `datetime` objects have microsecond precision, so we specify the type precision as 6. + # If a timezone is present in the passed type object, it is included in the DateTime64 type + # via the `expressions` arg. 
+ time_column_type = exp.DataType.build( + exp.DataType.Type.DATETIME64, + expressions=[ + exp.DataTypeParam(this=exp.Literal(this=6, is_string=False)), + *time_column_type.expressions, + ], + nullable=nullable or time_column_type.args.get("nullable", False), + ) + if isinstance(time_column, exp.Null): return exp.cast(time_column, to=time_column_type) - if time_column_type.is_type(exp.DataType.Type.DATE): + if time_column_type.is_type(exp.DataType.Type.DATE, exp.DataType.Type.DATE32): return exp.cast(exp.Literal.string(to_ds(time_column)), to="date") - if time_column_type.this in TEMPORAL_TZ_TYPES: - return exp.cast(exp.Literal.string(to_tstz(time_column)), to=time_column_type.this) - if time_column_type.this in exp.DataType.TEMPORAL_TYPES: - return exp.cast(exp.Literal.string(to_ts(time_column)), to=time_column_type.this) + if time_column_type.is_type(*TEMPORAL_TZ_TYPES): + return exp.cast(exp.Literal.string(to_tstz(time_column)), to=time_column_type) + if time_column_type.is_type(*exp.DataType.TEMPORAL_TYPES): + return exp.cast(exp.Literal.string(to_ts(time_column)), to=time_column_type) if time_column_format: time_column = to_datetime(time_column).strftime(time_column_format) - if time_column_type.this in exp.DataType.TEXT_TYPES: + if time_column_type.is_type(*exp.DataType.TEXT_TYPES): return exp.Literal.string(time_column) - if time_column_type.this in exp.DataType.NUMERIC_TYPES: + if time_column_type.is_type(*exp.DataType.NUMERIC_TYPES): return exp.Literal.number(time_column) return exp.convert(time_column) def pandas_timestamp_to_pydatetime( - df: pd.DataFrame, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] + df: pd.DataFrame, columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None ) -> pd.DataFrame: + import pandas as pd + from pandas.api.types import is_datetime64_any_dtype # type: ignore + for column in df.columns: if is_datetime64_any_dtype(df.dtypes[column]): # We must use `pd.Series` and dtype or pandas will convert it back to 
pd.Timestamp during assignment @@ -363,3 +465,19 @@ def pandas_timestamp_to_pydatetime( ) return df + + +def format_tz_datetime( + time: TimeLike, + format_string: t.Optional[str] = "%Y-%m-%d %I:%M%p %Z", + use_local_timezone: bool = False, +) -> str: + output_datetime = to_datetime(time) + if use_local_timezone: + local_timezone = datetime.now().astimezone().tzinfo + output_datetime = output_datetime.astimezone(local_timezone) + return ( + output_datetime.strftime(format_string) + if format_string + else output_datetime.isoformat(sep=" ") + ) diff --git a/sqlmesh/utils/errors.py b/sqlmesh/utils/errors.py index d2267aa6e6..ca3e1bfb05 100644 --- a/sqlmesh/utils/errors.py +++ b/sqlmesh/utils/errors.py @@ -11,6 +11,7 @@ from requests.models import Response from sqlmesh.core.model import Model + from sqlmesh.core.schema_diff import TableAlterOperation class ErrorLevel(AutoName): @@ -24,7 +25,25 @@ class SQLMeshError(Exception): class ConfigError(SQLMeshError): - pass + location: t.Optional[Path] = None + + def __init__(self, message: str | Exception, location: t.Optional[Path] = None) -> None: + super().__init__(message) + if location: + self.location = Path(location) if isinstance(location, str) else location + + +class BaseMissingReferenceError(ConfigError): + def __init__(self, ref: str) -> None: + self.ref = ref + + +class MissingModelError(BaseMissingReferenceError): + """Raised when a model that is referenced is missing.""" + + +class MissingSourceError(BaseMissingReferenceError): + """Raised when a source that is referenced is missing.""" class MissingDependencyError(SQLMeshError): @@ -47,6 +66,10 @@ class UncategorizedPlanError(PlanError): pass +class ConflictingPlanError(PlanError): + pass + + class MissingContextException(Exception): pass @@ -63,10 +86,15 @@ class AuditConfigError(ConfigError): pass +class StateMigrationError(SQLMeshError): + pass + + class AuditError(SQLMeshError): def __init__( self, audit_name: str, + audit_args: t.Dict[t.Any, t.Any], 
count: int, query: exp.Query, model: t.Optional[Model] = None, @@ -74,14 +102,15 @@ def __init__( adapter_dialect: t.Optional[str] = None, ) -> None: self.audit_name = audit_name + self.audit_args = audit_args self.model = model self.count = count self.query = query self.adapter_dialect = adapter_dialect - def __str__(self) -> str: - model_str = f" for model '{self.model_name}'" if self.model_name else "" - return f"Audit '{self.audit_name}'{model_str} failed.\nGot {self.count} results, expected 0.\n{self.sql()}" + super().__init__( + f"'{self.audit_name}' audit error: {self.count} {'row' if self.count == 1 else 'rows'} failed" + ) @property def model_name(self) -> t.Optional[str]: @@ -102,11 +131,30 @@ def sql(self, dialect: t.Optional[str] = None, **opts: t.Any) -> str: return self.query.sql(dialect=dialect or self.adapter_dialect, **opts) +class NodeAuditsErrors(SQLMeshError): + def __init__(self, errors: t.List[AuditError]) -> None: + self.errors = errors + + super().__init__(f"Audits failed: {', '.join([e.audit_name for e in errors])}") + + class TestError(SQLMeshError): __test__ = False # prevent pytest trying to collect this as a test class pass +class DestructiveChangeError(SQLMeshError): + pass + + +class AdditiveChangeError(SQLMeshError): + pass + + +class MigrationNotSupportedError(SQLMeshError): + pass + + class NotificationTargetError(SQLMeshError): pass @@ -148,17 +196,35 @@ class UnsupportedCatalogOperationError(EngineAdapterError): class CircuitBreakerError(SQLMeshError): def __init__(self) -> None: - super().__init__("Circuit breaker has been triggered.") + super().__init__("Circuit breaker triggered.") + + +class PythonModelEvalError(SQLMeshError): + pass + + +class MissingDefaultCatalogError(SQLMeshError): + pass + + +class LinterError(SQLMeshError): + pass + + +class SignalEvalError(SQLMeshError): + """Errors when evaluating a signal that is because of a user mistake and not a SQLMesh bug.""" + + pass def raise_config_error( msg: str, - 
location: t.Optional[str | Path] = None, + location: t.Optional[Path] = None, error_type: t.Type[ConfigError] = ConfigError, ) -> None: if location: - raise error_type(f"{msg} at '{location}'") - raise error_type(msg) + raise error_type(f"{msg} at '{location}'", location) + raise error_type(msg, location=location) def raise_for_status(response: Response) -> None: @@ -168,3 +234,89 @@ def raise_for_status(response: Response) -> None: raise ApiClientError(response.text, response.status_code) if 500 <= response.status_code < 600: raise ApiServerError(response.text, response.status_code) + + +def _format_schema_change_msg( + snapshot_name: str, + is_destructive: bool, + alter_operations: t.List[TableAlterOperation], + dialect: str, + error: bool = True, +) -> str: + """ + Common function to format schema change messages. + + Args: + snapshot_name: Name of the model/snapshot + is_destructive: if change is destructive else it would be additive + alter_operations: List of table alter operations + dialect: SQL dialect for formatting + error: Whether this is an error or warning + """ + from sqlmesh.core.schema_diff import get_dropped_column_names, get_additive_column_names + + change_type = "destructive" if is_destructive else "additive" + setting_name = "on_destructive_change" if is_destructive else "on_additive_change" + action_verb = "drops" if is_destructive else "adds" + cli_flag = "--allow-destructive-model" if is_destructive else "--allow-additive-model" + + column_names = ( + get_dropped_column_names(alter_operations) + if is_destructive + else get_additive_column_names(alter_operations) + ) + column_str = "', '".join(column_names) + column_msg = ( + f" that {action_verb} column{'s' if column_names and len(column_names) > 1 else ''} '{column_str}'" + if column_str + else "" + ) + + # Format ALTER expressions + alter_expr_msg = "\n\nSchema changes:\n " + "\n ".join( + [alter.expression.sql(dialect) for alter in alter_operations] + ) + + # Main warning message + 
warning_msg = ( + f"Plan requires {change_type} change to forward-only model '{snapshot_name}'s schema" + ) + + if error: + permissive_values = "`warn`, `allow`, or `ignore`" + cli_part = f" or include the model in the plan's `{cli_flag}` option" + err_msg = f"\n\nTo allow the {change_type} change, set the model's `{setting_name}` setting to {permissive_values}{cli_part}.\n" + else: + err_msg = "" + + return f"\n{warning_msg}{column_msg}.{alter_expr_msg}{err_msg}" + + +def format_destructive_change_msg( + snapshot_name: str, + alter_expressions: t.List[TableAlterOperation], + dialect: str, + error: bool = True, +) -> str: + return _format_schema_change_msg( + snapshot_name=snapshot_name, + is_destructive=True, + alter_operations=alter_expressions, + dialect=dialect, + error=error, + ) + + +def format_additive_change_msg( + snapshot_name: str, + alter_operations: t.List[TableAlterOperation], + dialect: str, + error: bool = True, +) -> str: + return _format_schema_change_msg( + snapshot_name=snapshot_name, + is_destructive=False, + alter_operations=alter_operations, + dialect=dialect, + error=error, + ) diff --git a/sqlmesh/utils/git.py b/sqlmesh/utils/git.py index 9a558dec9a..cdb9d4e2d5 100644 --- a/sqlmesh/utils/git.py +++ b/sqlmesh/utils/git.py @@ -16,7 +16,9 @@ def list_untracked_files(self) -> t.List[Path]: ) def list_uncommitted_changed_files(self) -> t.List[Path]: - return self._execute_list_output(["diff", "--name-only", "--diff-filter=d"], self._git_root) + return self._execute_list_output( + ["diff", "--name-only", "--diff-filter=d", "HEAD"], self._git_root + ) def list_committed_changed_files(self, target_branch: str = "main") -> t.List[Path]: return self._execute_list_output( @@ -27,7 +29,23 @@ def _execute_list_output(self, commands: t.List[str], base_path: Path) -> t.List return [(base_path / o).absolute() for o in self._execute(commands).split("\n") if o] def _execute(self, commands: t.List[str]) -> str: - result = subprocess.run(["git"] + commands, 
cwd=self._work_dir, stdout=subprocess.PIPE) + result = subprocess.run( + ["git"] + commands, + cwd=self._work_dir, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + check=False, + ) + + # If the Git command failed, extract and raise the error message in the console + if result.returncode != 0: + stderr_output = result.stderr.decode("utf-8").strip() + error_message = next( + (line for line in stderr_output.splitlines() if line.lower().startswith("fatal:")), + stderr_output, + ) + raise RuntimeError(f"Git error: {error_message}") + return result.stdout.decode("utf-8").strip() @cached_property diff --git a/sqlmesh/utils/hashing.py b/sqlmesh/utils/hashing.py index 1bccd987bc..a166d36bec 100644 --- a/sqlmesh/utils/hashing.py +++ b/sqlmesh/utils/hashing.py @@ -9,7 +9,9 @@ def crc32(data: t.Iterable[t.Optional[str]]) -> str: return str(zlib.crc32(_safe_concat(data))) -def md5(data: t.Iterable[t.Optional[str]]) -> str: +def md5(data: t.Union[str, t.Iterable[t.Optional[str]]]) -> str: + if isinstance(data, str): + data = [data] return hashlib.md5(_safe_concat(data)).hexdigest() diff --git a/sqlmesh/utils/jinja.py b/sqlmesh/utils/jinja.py index 69aee6b8d4..240b183391 100644 --- a/sqlmesh/utils/jinja.py +++ b/sqlmesh/utils/jinja.py @@ -4,16 +4,25 @@ import json import re import typing as t +import zlib from collections import defaultdict from enum import Enum +from sys import exc_info +from traceback import walk_tb -from jinja2 import Environment, Template, nodes +from jinja2 import Environment, Template, nodes, UndefinedError +from jinja2.runtime import Macro from sqlglot import Dialect, Expression, Parser, TokenType from sqlmesh.core import constants as c from sqlmesh.core import dialect as d from sqlmesh.utils import AttributeDict -from sqlmesh.utils.pydantic import PydanticModel, field_serializer, field_validator +from sqlmesh.utils.pydantic import PRIVATE_FIELDS, PydanticModel, field_serializer, field_validator +from sqlmesh.utils.metaprogramming import SqlValue + + 
+if t.TYPE_CHECKING: + CallNames = t.Tuple[t.Tuple[str, ...], t.Union[nodes.Call, nodes.Getattr]] SQLMESH_JINJA_PACKAGE = "sqlmesh.utils.jinja" @@ -47,6 +56,7 @@ class MacroInfo(PydanticModel): definition: str depends_on: t.List[MacroReference] + is_top_level: bool = False class MacroReturnVal(Exception): @@ -67,7 +77,7 @@ def extract(self, jinja: str, dialect: str = "") -> t.Dict[str, MacroInfo]: """ self.reset() self.sql = jinja - self._tokens = Dialect.get_or_raise(dialect).tokenizer.tokenize(jinja) + self._tokens = Dialect.get_or_raise(dialect).tokenize(jinja) self._index = -1 self._advance() @@ -119,12 +129,16 @@ def render_jinja(query: str, methods: t.Optional[t.Dict[str, t.Any]] = None) -> return ENVIRONMENT.from_string(query).render(methods or {}) -def find_call_names( - node: nodes.Node, vars_in_scope: t.Set[str] -) -> t.Iterator[t.Tuple[t.Tuple[str, ...], nodes.Call]]: +def find_call_names(node: nodes.Node, vars_in_scope: t.Set[str]) -> t.Iterator[CallNames]: vars_in_scope = vars_in_scope.copy() for child_node in node.iter_child_nodes(): if "target" in child_node.fields: + # For nodes with assignment targets (Assign, AssignBlock, For, Import), + # the target name could shadow a reference in the right hand side. + # So we need to process the RHS before adding the target to scope. + # For example: {% set model = model.path %} should track model.path. 
+ yield from find_call_names(child_node, vars_in_scope) + target = getattr(child_node, "target") if isinstance(target, nodes.Name): vars_in_scope.add(target.name) @@ -135,15 +149,40 @@ def find_call_names( elif isinstance(child_node, nodes.Macro): for arg in child_node.args: vars_in_scope.add(arg.name) - elif isinstance(child_node, nodes.Call): + elif isinstance(child_node, nodes.Call) or ( + isinstance(child_node, nodes.Getattr) and not isinstance(child_node.node, nodes.Getattr) + ): name = call_name(child_node) if name[0][0] != "'" and name[0] not in vars_in_scope: yield (name, child_node) - yield from find_call_names(child_node, vars_in_scope) + + if "target" not in child_node.fields: + yield from find_call_names(child_node, vars_in_scope) + + +def extract_call_names( + jinja_str: str, cache: t.Optional[t.Dict[str, t.Tuple[t.List[CallNames], bool]]] = None +) -> t.List[CallNames]: + def parse() -> t.List[CallNames]: + return list(find_call_names(ENVIRONMENT.parse(jinja_str), set())) + + if cache is not None: + key = str(zlib.crc32(jinja_str.encode("utf-8"))) + if key in cache: + names = cache[key][0] + else: + names = parse() + cache[key] = (names, True) + return names + return parse() -def extract_call_names(jinja_str: str) -> t.List[t.Tuple[t.Tuple[str, ...], nodes.Call]]: - return list(find_call_names(ENVIRONMENT.parse(jinja_str), set())) +def is_variable_node(n: nodes.Node) -> bool: + return ( + isinstance(n, nodes.Call) + and isinstance(n.node, nodes.Name) + and n.node.name in (c.VAR, c.BLUEPRINT_VAR) + ) def extract_macro_references_and_variables( @@ -153,7 +192,16 @@ def extract_macro_references_and_variables( variables = set() for jinja_str in jinja_strs: for call_name, node in extract_call_names(jinja_str): - if call_name[0] == c.VAR: + if call_name[0] in (c.VAR, c.BLUEPRINT_VAR): + if not is_variable_node(node): + # Find the variable node which could be nested + for n in node.find_all(nodes.Call): + if is_variable_node(n): + node = n + break + else: + 
raise ValueError(f"Could not find variable name in {jinja_str}") + node = t.cast(nodes.Call, node) args = [jinja_call_arg_name(arg) for arg in node.args] if args and args[0]: variables.add(args[0].lower()) @@ -166,6 +214,20 @@ def extract_macro_references_and_variables( return macro_references, variables +def sort_dict_recursive( + item: t.Dict[str, t.Any], +) -> t.Dict[str, t.Any]: + sorted_dict: t.Dict[str, t.Any] = {} + for k, v in sorted(item.items()): + if isinstance(v, list): + sorted_dict[k] = sorted(v) + elif isinstance(v, dict): + sorted_dict[k] = sort_dict_recursive(v) + else: + sorted_dict[k] = v + return sorted_dict + + JinjaGlobalAttribute = t.Union[str, int, float, bool, AttributeDict] @@ -192,8 +254,16 @@ class JinjaMacroRegistry(PydanticModel): top_level_packages: t.List[str] = [] _parser_cache: t.Dict[t.Tuple[t.Optional[str], str], Template] = {} + _trimmed: bool = False __environment: t.Optional[Environment] = None + def __getstate__(self) -> t.Dict[t.Any, t.Any]: + state = super().__getstate__() + private = state[PRIVATE_FIELDS] + private["_parser_cache"] = {} + private["_JinjaMacroRegistry__environment"] = None + return state + @field_validator("global_objs", mode="before") @classmethod def _validate_global_objs(cls, value: t.Any) -> t.Any: @@ -222,6 +292,10 @@ def _convert( return _convert(value) + @property + def trimmed(self) -> bool: + return self._trimmed + def add_macros(self, macros: t.Dict[str, MacroInfo], package: t.Optional[str] = None) -> None: """Adds macros to the target package. @@ -244,6 +318,9 @@ def add_globals(self, globals: t.Dict[str, JinjaGlobalAttribute]) -> None: Args: globals: The global objects that should be added. 
""" + # Keep the registry lightweight when the graph is not needed + if not "graph" in self.packages: + globals.pop("flat_graph", None) self.global_objs.update(**self._validate_global_objs(globals)) def build_macro(self, reference: MacroReference, **kwargs: t.Any) -> t.Optional[t.Callable]: @@ -272,18 +349,22 @@ def build_environment(self, **kwargs: t.Any) -> Environment: package_macros: t.Dict[str, t.Any] = defaultdict(AttributeDict) for package_name, macros in self.packages.items(): - for macro_name in macros: - package_macros[package_name][macro_name] = self._MacroWrapper( - macro_name, package_name, self, context - ) + for macro_name, macro in macros.items(): + macro_wrapper = self._MacroWrapper(macro_name, package_name, self, context) + package_macros[package_name][macro_name] = macro_wrapper + if macro.is_top_level and macro_name not in root_macros: + root_macros[macro_name] = macro_wrapper + + top_level_packages = self.top_level_packages.copy() if self.root_package_name is not None: package_macros[self.root_package_name].update(root_macros) + top_level_packages.append(self.root_package_name) env = environment() builtin_globals = self._create_builtin_globals(kwargs) - for top_level_package_name in self.top_level_packages: + for top_level_package_name in top_level_packages: # Make sure that the top-level package doesn't fully override the same builtin package. 
package_macros[top_level_package_name] = AttributeDict( { @@ -296,6 +377,7 @@ def build_environment(self, **kwargs: t.Any) -> Environment: context.update(builtin_globals) context.update(root_macros) context.update(package_macros) + context["render"] = lambda input: env.from_string(input).render() env.globals.update(context) env.filters.update(self._environment.filters) @@ -330,6 +412,8 @@ def trim( for package, names in dependencies_by_package.items(): result = result.merge(self._trim_macros(names, package)) + result._trimmed = True + return result def merge(self, other: JinjaMacroRegistry) -> JinjaMacroRegistry: @@ -379,7 +463,7 @@ def to_expressions(self) -> t.List[Expression]: d.PythonCode( expressions=[ f"{k} = '{v}'" if isinstance(v, str) else f"{k} = {v}" - for k, v in sorted(filtered_objs.items()) + for k, v in sort_dict_recursive(filtered_objs).items() ] ) ) @@ -558,7 +642,10 @@ def jinja_call_arg_name(node: nodes.Node) -> str: def create_var(variables: t.Dict[str, t.Any]) -> t.Callable: def _var(var_name: str, default: t.Optional[t.Any] = None) -> t.Optional[t.Any]: - return variables.get(var_name.lower(), default) + value = variables.get(var_name.lower(), default) + if isinstance(value, SqlValue): + return value.sql + return value return _var @@ -568,8 +655,59 @@ def create_builtin_globals( ) -> t.Dict[str, t.Any]: global_vars.pop(c.GATEWAY, None) variables = global_vars.pop(c.SQLMESH_VARS, None) or {} + blueprint_variables = global_vars.pop(c.SQLMESH_BLUEPRINT_VARS, None) or {} return { + **global_vars, c.VAR: create_var(variables), c.GATEWAY: lambda: variables.get(c.GATEWAY, None), - **global_vars, + c.BLUEPRINT_VAR: create_var(blueprint_variables), } + + +def make_jinja_registry( + jinja_macros: JinjaMacroRegistry, package_name: str, jinja_references: t.Set[MacroReference] +) -> JinjaMacroRegistry: + """ + Creates a Jinja macro registry for a specific package. 
+ + This function takes an existing Jinja macro registry and returns a new + registry that includes only the macros associated with the specified + package and trims the registry to include only the macros referenced + in the provided set of macro references. + + Args: + jinja_macros: The original Jinja macro registry containing all macros. + package_name: The name of the package for which to create the registry. + jinja_references: A set of macro references to retain in the new registry. + + Returns: + A new JinjaMacroRegistry containing only the macros for the specified + package and the referenced macros. + """ + + jinja_registry = jinja_macros.copy() + jinja_registry.root_macros = jinja_registry.packages.get(package_name) or {} + jinja_registry = jinja_registry.trim(jinja_references) + + return jinja_registry + + +def extract_error_details(ex: Exception) -> str: + """Extracts a readable message from a Jinja2 error, to include missing name and macro.""" + + error_details = "" + if isinstance(ex, UndefinedError): + if match := re.search(r"'(\w+)'", str(ex)): + error_details += f"\nUndefined macro/variable: '{match.group(1)}'" + try: + _, _, exc_traceback = exc_info() + for frame, _ in walk_tb(exc_traceback): + if frame.f_code.co_name == "_invoke": + macro = frame.f_locals.get("self") + if isinstance(macro, Macro): + error_details += f" in macro: '{macro.name}'\n" + break + except: + # to fall back to the generic error message if frame analysis fails + pass + return error_details or str(ex) diff --git a/sqlmesh/utils/lineage.py b/sqlmesh/utils/lineage.py new file mode 100644 index 0000000000..f5b4506c68 --- /dev/null +++ b/sqlmesh/utils/lineage.py @@ -0,0 +1,424 @@ +import typing as t +from pathlib import Path + +from pydantic import Field + +from sqlmesh.core.dialect import normalize_model_name +from sqlmesh.core.linter.helpers import ( + TokenPositionDetails, +) +from sqlmesh.core.linter.rule import Range, Position +from sqlmesh.core.model.definition import 
SqlModel, ExternalModel, PythonModel, SeedModel +from sqlglot import exp +from sqlglot.optimizer.scope import build_scope + +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers +from ruamel.yaml import YAML + +from sqlmesh.utils.pydantic import PydanticModel + +if t.TYPE_CHECKING: + from sqlmesh.core.context import Context + from sqlmesh.core.context import GenericContext + + +class ModelReference(PydanticModel): + """A reference to a model, excluding external models.""" + + type: t.Literal["model"] = "model" + path: Path + range: Range + markdown_description: t.Optional[str] = None + + +class ExternalModelReference(PydanticModel): + """A reference to an external model.""" + + type: t.Literal["external_model"] = "external_model" + range: Range + target_range: t.Optional[Range] = None + path: t.Optional[Path] = None + """The path of the external model, typically a YAML file, it is optional because + external models can be unregistered and so the path is not available.""" + + markdown_description: t.Optional[str] = None + + +class CTEReference(PydanticModel): + """A reference to a CTE.""" + + type: t.Literal["cte"] = "cte" + path: Path + range: Range + target_range: Range + + +class MacroReference(PydanticModel): + """A reference to a macro.""" + + type: t.Literal["macro"] = "macro" + path: Path + range: Range + target_range: Range + markdown_description: t.Optional[str] = None + + +Reference = t.Annotated[ + t.Union[ModelReference, CTEReference, MacroReference, ExternalModelReference], + Field(discriminator="type"), +] + + +def extract_references_from_query( + query: exp.Expression, + context: t.Union["Context", "GenericContext[t.Any]"], + document_path: Path, + read_file: t.List[str], + depends_on: t.Set[str], + dialect: t.Optional[str] = None, +) -> t.List[Reference]: + # Build a scope tree to properly handle nested CTEs + try: + query = normalize_identifiers(query.copy(), dialect=dialect) + root_scope = build_scope(query) + except Exception: 
+ root_scope = None + + references: t.List[Reference] = [] + if not root_scope: + return references + + # Traverse all scopes to find CTE definitions and table references + for scope in root_scope.traverse(): + for table in scope.tables: + table_name = table.name + + # Check if this table reference is a CTE in the current scope + if cte_scope := scope.cte_sources.get(table_name): + cte = cte_scope.expression.parent + alias = cte.args["alias"] + if isinstance(alias, exp.TableAlias): + identifier = alias.this + if isinstance(identifier, exp.Identifier): + target_range_sqlmesh = TokenPositionDetails.from_meta( + identifier.meta + ).to_range(read_file) + table_range_sqlmesh = TokenPositionDetails.from_meta( + table.this.meta + ).to_range(read_file) + + references.append( + CTEReference( + path=document_path, # Same file + range=table_range_sqlmesh, + target_range=target_range_sqlmesh, + ) + ) + + column_references = _process_column_references( + scope=scope, + reference_name=table.name, + read_file=read_file, + referenced_model_path=document_path, + description="", + reference_type="cte", + cte_target_range=target_range_sqlmesh, + ) + references.extend(column_references) + continue + + # For non-CTE tables, process these as before (external model references) + # Normalize the table reference + unaliased = table.copy() + if unaliased.args.get("alias") is not None: + unaliased.set("alias", None) + reference_name = unaliased.sql(dialect=dialect) + try: + normalized_reference_name = normalize_model_name( + reference_name, + default_catalog=context.default_catalog, + dialect=dialect, + ) + if normalized_reference_name not in depends_on: + continue + except Exception: + # Skip references that cannot be normalized + continue + + # Get the referenced model uri + referenced_model = context.get_model( + model_or_snapshot=normalized_reference_name, raise_if_missing=False + ) + if referenced_model is None: + # Extract metadata for positioning + table_meta = 
TokenPositionDetails.from_meta(table.this.meta) + table_range_sqlmesh = table_meta.to_range(read_file) + start_pos_sqlmesh = table_range_sqlmesh.start + end_pos_sqlmesh = table_range_sqlmesh.end + + # If there's a catalog or database qualifier, adjust the start position + catalog_or_db = table.args.get("catalog") or table.args.get("db") + if catalog_or_db is not None: + catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta) + catalog_or_db_range_sqlmesh = catalog_or_db_meta.to_range(read_file) + start_pos_sqlmesh = catalog_or_db_range_sqlmesh.start + + references.append( + ExternalModelReference( + range=Range( + start=start_pos_sqlmesh, + end=end_pos_sqlmesh, + ), + markdown_description="Unregistered external model", + ) + ) + continue + referenced_model_path = referenced_model._path + if referenced_model_path is None: + continue + # Check whether the path exists + if not referenced_model_path.is_file(): + continue + + # Extract metadata for positioning + table_meta = TokenPositionDetails.from_meta(table.this.meta) + table_range_sqlmesh = table_meta.to_range(read_file) + start_pos_sqlmesh = table_range_sqlmesh.start + end_pos_sqlmesh = table_range_sqlmesh.end + + # If there's a catalog or database qualifier, adjust the start position + catalog_or_db = table.args.get("catalog") or table.args.get("db") + if catalog_or_db is not None: + catalog_or_db_meta = TokenPositionDetails.from_meta(catalog_or_db.meta) + catalog_or_db_range_sqlmesh = catalog_or_db_meta.to_range(read_file) + start_pos_sqlmesh = catalog_or_db_range_sqlmesh.start + + description = generate_markdown_description(referenced_model) + + # For external models in YAML files, find the specific model block + if isinstance(referenced_model, ExternalModel): + yaml_target_range: t.Optional[Range] = None + if ( + referenced_model_path.suffix in (".yaml", ".yml") + and referenced_model_path.is_file() + ): + yaml_target_range = _get_yaml_model_range( + referenced_model_path, referenced_model.name 
+ ) + references.append( + ExternalModelReference( + path=referenced_model_path, + range=Range( + start=start_pos_sqlmesh, + end=end_pos_sqlmesh, + ), + markdown_description=description, + target_range=yaml_target_range, + ) + ) + + column_references = _process_column_references( + scope=scope, + reference_name=normalized_reference_name, + read_file=read_file, + referenced_model_path=referenced_model_path, + description=description, + yaml_target_range=yaml_target_range, + reference_type="external_model", + default_catalog=context.default_catalog, + dialect=dialect, + ) + references.extend(column_references) + else: + references.append( + ModelReference( + path=referenced_model_path, + range=Range( + start=start_pos_sqlmesh, + end=end_pos_sqlmesh, + ), + markdown_description=description, + ) + ) + + column_references = _process_column_references( + scope=scope, + reference_name=normalized_reference_name, + read_file=read_file, + referenced_model_path=referenced_model_path, + description=description, + reference_type="model", + default_catalog=context.default_catalog, + dialect=dialect, + ) + references.extend(column_references) + + return references + + +def generate_markdown_description( + model: t.Union[SqlModel, ExternalModel, PythonModel, SeedModel], +) -> t.Optional[str]: + description = model.description + columns = model.columns_to_types + column_descriptions = model.column_descriptions + + if columns is None: + return description or None + + columns_table = "\n".join( + [ + f"| {column} | {column_type} | {column_descriptions.get(column, '')} |" + for column, column_type in columns.items() + ] + ) + + table_header = "| Column | Type | Description |\n|--------|------|-------------|\n" + columns_text = table_header + columns_table + return f"{description}\n\n{columns_text}" if description else columns_text + + +def _process_column_references( + scope: t.Any, + reference_name: str, + read_file: t.List[str], + referenced_model_path: Path, + description: 
t.Optional[str] = None, + yaml_target_range: t.Optional[Range] = None, + reference_type: t.Literal["model", "external_model", "cte"] = "model", + default_catalog: t.Optional[str] = None, + dialect: t.Optional[str] = None, + cte_target_range: t.Optional[Range] = None, +) -> t.List[Reference]: + """ + Process column references for a given table and create appropriate reference objects. + + Args: + scope: The SQL scope to search for columns + reference_name: The full reference name (may include database/catalog) + read_file: The file content as list of lines + referenced_model_path: Path of the referenced model + description: Markdown description for the reference + yaml_target_range: Target range for external models (YAML files) + reference_type: Type of reference - "model", "external_model", or "cte" + default_catalog: Default catalog for normalization + dialect: SQL dialect for normalization + cte_target_range: Target range for CTE references + + Returns: + List of table references for column usages + """ + + references: t.List[Reference] = [] + for column in scope.find_all(exp.Column): + if column.table: + if reference_type == "cte": + if column.table == reference_name: + table_range = _get_column_table_range(column, read_file) + references.append( + CTEReference( + path=referenced_model_path, + range=table_range, + target_range=cte_target_range, + ) + ) + else: + table_parts = [part.sql(dialect) for part in column.parts[:-1]] + table_ref = ".".join(table_parts) + normalized_reference_name = normalize_model_name( + table_ref, + default_catalog=default_catalog, + dialect=dialect, + ) + if normalized_reference_name == reference_name: + table_range = _get_column_table_range(column, read_file) + if reference_type == "external_model": + references.append( + ExternalModelReference( + path=referenced_model_path, + range=table_range, + markdown_description=description, + target_range=yaml_target_range, + ) + ) + else: + references.append( + ModelReference( + 
path=referenced_model_path, + range=table_range, + markdown_description=description, + ) + ) + + return references + + +def _get_column_table_range(column: exp.Column, read_file: t.List[str]) -> Range: + """ + Get the range for a column's table reference, handling both simple and qualified table names. + + Args: + column: The column expression + read_file: The file content as list of lines + + Returns: + The Range covering the table reference in the column + """ + + table_parts = column.parts[:-1] + + start_range = TokenPositionDetails.from_meta(table_parts[0].meta).to_range(read_file) + end_range = TokenPositionDetails.from_meta(table_parts[-1].meta).to_range(read_file) + + return Range( + start=start_range.start, + end=end_range.end, + ) + + +def _get_yaml_model_range(path: Path, model_name: str) -> t.Optional[Range]: + """ + Find the range of a specific model block in a YAML file. + + Args: + yaml_path: Path to the YAML file + model_name: Name of the model to find + + Returns: + The Range of the model block in the YAML file, or None if not found + """ + model_name_ranges = get_yaml_model_name_ranges(path) + if model_name_ranges is None: + return None + return model_name_ranges.get(model_name, None) + + +def get_yaml_model_name_ranges(path: Path) -> t.Optional[t.Dict[str, Range]]: + """ + Get the ranges of all model names in a YAML file. + + Args: + path: Path to the YAML file + + Returns: + A dictionary mapping model names to their ranges in the YAML file. 
+ """ + yaml = YAML() + with path.open("r", encoding="utf-8") as f: + data = yaml.load(f) + + if not isinstance(data, list): + return None + + model_name_ranges = {} + for item in data: + if isinstance(item, dict): + position_data = item.lc.data["name"] # type: ignore + start = Position(line=position_data[2], character=position_data[3]) + end = Position(line=position_data[2], character=position_data[3] + len(item["name"])) + name = item.get("name") + if not name: + continue + model_name_ranges[name] = Range(start=start, end=end) + + return model_name_ranges diff --git a/sqlmesh/utils/metaprogramming.py b/sqlmesh/utils/metaprogramming.py index 786e847da8..753db427f3 100644 --- a/sqlmesh/utils/metaprogramming.py +++ b/sqlmesh/utils/metaprogramming.py @@ -5,13 +5,16 @@ import importlib import inspect import linecache +import logging import os import re import sys import textwrap import types import typing as t +from dataclasses import dataclass from enum import Enum +from numbers import Number from pathlib import Path from astor import to_source @@ -21,11 +24,15 @@ from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.pydantic import PydanticModel -IGNORE_DECORATORS = {"macro", "model"} +logger = logging.getLogger(__name__) + + +IGNORE_DECORATORS = {"macro", "model", "signal"} +SERIALIZABLE_CALLABLES = (type, types.FunctionType) +LITERALS = (Number, str, bytes, tuple, list, dict, set, bool) def _is_relative_to(path: t.Optional[Path | str], other: t.Optional[Path | str]) -> bool: - """path.is_relative_to compatibility, was only supported >= 3.9""" if path is None or other is None: return False @@ -34,7 +41,7 @@ def _is_relative_to(path: t.Optional[Path | str], other: t.Optional[Path | str]) if isinstance(other, str): other = Path(other) - if "site-packages" in str(path): + if "site-packages" in str(path) or not path.exists() or not other.exists(): return False try: @@ -58,6 +65,16 @@ def _code_globals(code: types.CodeType) -> t.Dict[str, None]: return 
variables +def _globals_match(obj1: t.Any, obj2: t.Any) -> bool: + return type(obj1) == type(obj2) and ( + obj1 == obj2 + or ( + getattr(obj1, "__module__", None) == getattr(obj2, "__module__", None) + and getattr(obj1, "__name__", None) == getattr(obj2, "__name__", None) + ) + ) + + def func_globals(func: t.Callable) -> t.Dict[str, t.Any]: """Finds all global references and closures in a function and nested functions. @@ -72,12 +89,23 @@ def func_globals(func: t.Callable) -> t.Dict[str, t.Any]: variables = {} if hasattr(func, "__code__"): - code = func.__code__ + root_node = parse_source(func) + + func_args = next(node for node in ast.walk(root_node) if isinstance(node, ast.arguments)) + arg_defaults = (d for d in func_args.defaults + func_args.kw_defaults if d is not None) - for var in list(_code_globals(code)) + decorators(func): + # ast.Name corresponds to variable references, such as foo or x.foo. The former is + # represented as Name(id=foo), and the latter as Attribute(value=Name(id=x) attr=foo) + arg_globals = [ + n.id for default in arg_defaults for n in ast.walk(default) if isinstance(n, ast.Name) + ] + + code = func.__code__ + for var in ( + arg_globals + list(_code_globals(code)) + decorator_vars(func, root_node=root_node) + ): if var in func.__globals__: - ref = func.__globals__[var] - variables[var] = ref + variables[var] = func.__globals__[var] if func.__closure__: for var, value in zip(code.co_freevars, func.__closure__): @@ -120,6 +148,38 @@ def visit_ClassDef(self, node: ast.ClassDef) -> None: self.stack.pop() +class _DecoratorDependencyFinder(ast.NodeVisitor): + def __init__(self) -> None: + self.dependencies: t.List[str] = [] + + def _extract_dependencies(self, node: ast.ClassDef | ast.FunctionDef) -> None: + for decorator in node.decorator_list: + dependencies: t.List[str] = [] + for n in ast.walk(decorator): + if isinstance(n, ast.Attribute): + dep = n.attr + elif isinstance(n, ast.Name): + dep = n.id + else: + continue + + if dep in 
IGNORE_DECORATORS: + dependencies = [] + break + + dependencies.append(dep) + + self.dependencies.extend(dependencies) + + def visit_FunctionDef(self, node: ast.FunctionDef) -> None: + self._extract_dependencies(node) + + def visit_ClassDef(self, node: ast.ClassDef) -> None: + self._extract_dependencies(node) + + visit_AsyncFunctionDef = visit_FunctionDef # type: ignore + + def getsource(obj: t.Any) -> str: """Get the source of a function or class. @@ -170,25 +230,22 @@ def parse_source(func: t.Callable) -> ast.Module: def _decorator_name(decorator: ast.expr) -> str: + node = decorator if isinstance(decorator, ast.Call): - return decorator.func.id # type: ignore - if isinstance(decorator, ast.Name): - return decorator.id - return "" - + node = decorator.func + return node.id if isinstance(node, ast.Name) else "" -def decorators(func: t.Callable) -> t.List[str]: - """Finds a list of all the decorators of a callable.""" - root_node = parse_source(func) - decorators = [] - for node in ast.walk(root_node): - if isinstance(node, (ast.FunctionDef, ast.ClassDef)): - for decorator in node.decorator_list: - name = _decorator_name(decorator) - if name not in IGNORE_DECORATORS: - decorators.append(name) - return unique(decorators) +def decorator_vars(func: t.Callable, root_node: t.Optional[ast.Module] = None) -> t.List[str]: + """ + Returns a list of all the decorators of a callable, as well as names of objects that + are referenced in their argument list. These objects may be transitive dependencies + that we need to include in the serialized python environments. 
+ """ + root_node = root_node or parse_source(func) + finder = _DecoratorDependencyFinder() + finder.visit(root_node) + return unique(finder.dependencies) def normalize_source(obj: t.Any) -> str: @@ -223,9 +280,10 @@ def normalize_source(obj: t.Any) -> str: def build_env( obj: t.Any, *, - env: t.Dict[str, t.Any], + env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]], name: str, path: Path, + is_metadata_obj: bool = False, ) -> None: """Fills in env dictionary with all globals needed to execute the object. @@ -236,59 +294,105 @@ def build_env( env: Dictionary to store the env. name: Name of the object in the env. path: The module path to serialize. Other modules will not be walked and treated as imports. + is_metadata_obj: An optional flag that determines whether the input object is metadata-only. """ + # We don't rely on `env` to keep track of visited objects, because it's populated in post-order + visited: t.Set[str] = set() - obj_module = inspect.getmodule(obj) + def walk(obj: t.Any, name: str, is_metadata: bool = False) -> None: + obj_module = inspect.getmodule(obj) + if obj_module and obj_module.__name__ == "builtins": + return - if obj_module and obj_module.__name__ == "builtins": - return + if name in visited: + if name not in env or _globals_match(env[name][0], obj): + return - def walk(obj: t.Any) -> None: - if inspect.isclass(obj): - for decorator in decorators(obj): - if obj_module and decorator in obj_module.__dict__: - build_env( - obj_module.__dict__[decorator], - env=env, - name=decorator, - path=path, - ) - - for base in obj.__bases__: - build_env(base, env=env, name=base.__qualname__, path=path) - - for k, v in obj.__dict__.items(): - if k.startswith("__"): - continue - # traverse methods in a class to find global references - if isinstance(v, (classmethod, staticmethod)): - v = v.__func__ - if callable(v): - # if the method is a part of the object, walk it - # else it is a global function and we just store it - if 
v.__qualname__.startswith(obj.__qualname__): - walk(v) + raise SQLMeshError( + f"Cannot store {obj} in environment, duplicate definitions found for '{name}'" + ) + + visited.add(name) + name_missing_from_env = name not in env + + if name_missing_from_env or (not is_metadata and env[name] == (obj, True)): + if not name_missing_from_env: + # The existing object in the env is "metadata only" but we're walking it again as a + # non-"metadata only" dependency, so we update this flag to ensure all transitive + # dependencies are also not marked as "metadata only" + is_metadata = False + + if hasattr(obj, c.SQLMESH_MACRO): + # We only need to add the undecorated code of @macro() functions in env, which + # is accessible through the `__wrapped__` attribute added by functools.wraps + obj = obj.__wrapped__ + elif callable(obj) and not isinstance(obj, SERIALIZABLE_CALLABLES): + obj = getattr(obj, "__wrapped__", None) + name = getattr(obj, "__name__", "") + + # Callable class instances shouldn't be serialized (e.g. tenacity.Retrying). 
+ # We still want to walk the callables they decorate, though + if not isinstance(obj, SERIALIZABLE_CALLABLES) or name in env: + return + + if ( + not obj_module + or not hasattr(obj_module, "__file__") + or not _is_relative_to(obj_module.__file__, path) + ): + env[name] = (obj, is_metadata) + return + + if inspect.isclass(obj): + for var in decorator_vars(obj): + if obj_module and var in obj_module.__dict__: + walk(obj_module.__dict__[var], var, is_metadata) + + for base in obj.__bases__: + walk(base, base.__qualname__, is_metadata) + + for k, v in obj.__dict__.items(): + # skip dunder methods bar __init__ as it might contain user defined logic with cross class references + if k.startswith("__") and k != "__init__": + continue + + # Traverse methods in a class to find global references + if isinstance(v, (classmethod, staticmethod)): + v = v.__func__ + + if callable(v): + # Walk the method if it's part of the object, else it's a global function and we just store it + if v.__qualname__.startswith(obj.__qualname__): + try: + for k, v in func_globals(v).items(): + walk(v, k, is_metadata) + except (OSError, TypeError): + # __init__ may come from built-ins or wrapped callables + pass else: - build_env(v, env=env, name=v.__name__, path=path) - elif callable(obj): - for k, v in func_globals(obj).items(): - build_env(v, env=env, name=k, path=path) - - if name not in env: - # We only need to add the undecorated code of @macro() functions in env, which - # is accessible through the `__wrapped__` attribute added by functools.wraps - env[name] = obj.__wrapped__ if hasattr(obj, c.SQLMESH_MACRO) else obj - - if ( - obj_module - and hasattr(obj_module, "__file__") - and _is_relative_to(obj_module.__file__, path) - ): - walk(env[name]) - elif env[name] != obj: - raise SQLMeshError( - f"Cannot store {obj} in environment, duplicate definitions found for '{name}'" - ) + walk(v, k, is_metadata) + elif callable(obj): + for k, v in func_globals(obj).items(): + walk(v, k, is_metadata) + 
+ # We store the object in the environment after its dependencies, because otherwise we + # could crash at environment hydration time, since dicts are ordered and the top-level + # objects would be loaded before their dependencies. + env[name] = (obj, is_metadata) + elif not _globals_match(env[name][0], obj): + raise SQLMeshError( + f"Cannot store {obj} in environment, duplicate definitions found for '{name}'" + ) + + # The "metadata only" annotation of the object is transitive + walk(obj, name, is_metadata_obj or getattr(obj, c.SQLMESH_METADATA, False)) + + +@dataclass +class SqlValue: + """A SQL string representing a generated SQLGlot AST.""" + + sql: str class ExecutableKind(str, Enum): @@ -329,8 +433,15 @@ def is_value(self) -> bool: return self.kind == ExecutableKind.VALUE @classmethod - def value(cls, v: t.Any) -> Executable: - return Executable(payload=repr(v), kind=ExecutableKind.VALUE) + def value( + cls, v: t.Any, is_metadata: t.Optional[bool] = None, sort_root_dict: bool = False + ) -> Executable: + payload = _dict_sort(v) if sort_root_dict else repr(v) + return Executable( + payload=payload, + kind=ExecutableKind.VALUE, + is_metadata=is_metadata or None, + ) def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable]: @@ -344,16 +455,52 @@ def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable """ serialized = {} - for k, v in env.items(): - if callable(v): + for k, (v, is_metadata) in env.items(): + # We don't store `False` for `is_metadata` to reduce the pydantic model's payload size + is_metadata = is_metadata or None + + if isinstance(v, LITERALS) or v is None: + serialized[k] = Executable.value(v, is_metadata=is_metadata) + elif inspect.ismodule(v): + name = v.__name__ + if hasattr(v, "__file__") and _is_relative_to(v.__file__, path): + raise SQLMeshError( + f"Cannot serialize 'import {name}'. Use 'from {name} import ...' instead." 
+ ) + postfix = "" if name == k else f" as {k}" + serialized[k] = Executable( + payload=f"import {name}{postfix}", + kind=ExecutableKind.IMPORT, + is_metadata=is_metadata, + ) + elif callable(v): name = v.__name__ name = k if name == "" else name - # We can't call getfile on built-in callables + # getfile raises a `TypeError` for built-in modules, classes, or functions # https://docs.python.org/3/library/inspect.html#inspect.getfile - file_path = Path(inspect.getfile(v)) if not inspect.isbuiltin(v) else None - - if _is_relative_to(file_path, path): + try: + file_path = Path(inspect.getfile(v)) + relative_obj_file_path = _is_relative_to(file_path, path) + + # A callable can be a "wrapper" that is defined in a third-party library [1], in which case the file + # containing its definition won't be relative to the project's path. This can lead to serializing + # it as a "relative import", such as `from models.some_python_model import foo`, because the `wraps` + # decorator preserves the wrapped function's module [2]. Payloads like this are invalid, as they + # can result in `ModuleNotFoundError`s when hydrating python environments, e.g. if a project's files + # are not available during a scheduled cadence run. 
+ # + # [1]: https://github.com/jd/tenacity/blob/0d40e76f7d06d631fb127e1ec58c8bd776e70d49/tenacity/__init__.py#L322-L346 + # [2]: https://github.com/python/cpython/blob/f502c8f6a6db4be27c97a0e5466383d117859b7f/Lib/functools.py#L33-L57 + if not relative_obj_file_path and (wrapped := getattr(v, "__wrapped__", None)): + v = wrapped + file_path = Path(inspect.getfile(wrapped)) + relative_obj_file_path = _is_relative_to(file_path, path) + except TypeError: + file_path = None + relative_obj_file_path = False + + if relative_obj_file_path: serialized[k] = Executable( name=name, payload=normalize_source(v), @@ -361,26 +508,20 @@ def serialize_env(env: t.Dict[str, t.Any], path: Path) -> t.Dict[str, Executable # Do `as_posix` to serialize windows path back to POSIX path=t.cast(Path, file_path).relative_to(path.absolute()).as_posix(), alias=k if name != k else None, - is_metadata=getattr(v, c.SQLMESH_METADATA, None), + is_metadata=is_metadata, ) else: serialized[k] = Executable( payload=f"from {v.__module__} import {name}", kind=ExecutableKind.IMPORT, + is_metadata=is_metadata, ) - elif inspect.ismodule(v): - name = v.__name__ - if hasattr(v, "__file__") and _is_relative_to(v.__file__, path): - raise SQLMeshError( - f"Cannot serialize 'import {name}'. Use 'from {name} import ...' instead." - ) - postfix = "" if name == k else f" as {k}" - serialized[k] = Executable( - payload=f"import {name}{postfix}", - kind=ExecutableKind.IMPORT, - ) else: - serialized[k] = Executable.value(v) + raise SQLMeshError( + f"Object '{v}' cannot be serialized. If it's defined in a library, import the corresponding " + "module and reference the object using its fully-qualified name. For example, the datetime " + "module's 'UTC' object should be accessed as 'datetime.UTC'." 
+ ) return serialized @@ -407,19 +548,19 @@ def prepare_env( python_env.items(), key=lambda item: 0 if item[1].is_import else 1 ): if executable.is_value: - env[name] = ast.literal_eval(executable.payload) + env[name] = eval(executable.payload) else: exec(executable.payload, env) if executable.alias and executable.name: env[executable.alias] = env[executable.name] + return env -def print_exception( +def format_evaluated_code_exception( exception: Exception, python_env: t.Dict[str, Executable], - out: t.TextIO = sys.stderr, -) -> None: +) -> str: """Formats exceptions that occur from evaled code. Stack traces generated by evaled code lose code context and are difficult to debug. @@ -428,26 +569,41 @@ def print_exception( Args: exception: The exception to print the stack trace for. python_env: The environment containing stringified python code. - out: The output stream to write to. """ tb: t.List[str] = [] + indent = "" + + skip_patterns = re.compile( + r"Traceback \(most recent call last\):|" + r'File ".*?core/model/definition\.py|' + r'File ".*?core/snapshot/definition\.py|' + r'File ".*?core/macros\.py|' + r'File ".*?inspect\.py' + ) for error_line in format_exception(exception): - match = re.search('File "", line (.*), in (.*)', error_line) + if skip_patterns.search(error_line): + continue - if not match: - tb.append(error_line) + error_match = re.search("^.*?Error: ", error_line) + if error_match: + tb.append(f"{indent * 2} {error_line}") continue - line_num = int(match.group(1)) - func = match.group(2) + eval_code_match = re.search('File "", line (.*), in (.*)', error_line) + if not eval_code_match: + tb.append(f"{indent}{error_line}") + continue + + line_num = int(eval_code_match.group(1)) + func = eval_code_match.group(2) if func not in python_env: tb.append(error_line) continue executable = python_env[func] - indent = error_line[: match.start()] + indent = error_line[: eval_code_match.start()] error_line = ( f"{indent}File '{executable.path}' (or imported 
file), line {line_num}, in {func}" @@ -471,11 +627,38 @@ def print_exception( os.linesep.join(formatted), indent + " ", ), - os.linesep, ) ) - out.write(os.linesep.join(tb)) + return os.linesep.join(tb) + + +def print_exception( + exception: Exception, + python_env: t.Dict[str, Executable], + out: t.TextIO = sys.stderr, +) -> None: + """Prints exceptions that occur from evaled code. + + Stack traces generated by evaled code lose code context and are difficult to debug. + This intercepts the default stack trace and tries to make it debuggable. + + Args: + exception: The exception to print the stack trace for. + python_env: The environment containing stringified python code. + out: The output stream to write to. + """ + tb = format_evaluated_code_exception(exception, python_env) + out.write(tb) + + +def _dict_sort(obj: t.Any) -> str: + try: + if isinstance(obj, dict): + obj = dict(sorted(obj.items(), key=lambda x: str(x[0]))) + except Exception: + logger.warning("Failed to sort non-recursive dict", exc_info=True) + return repr(obj) def import_python_file(path: Path, relative_base: Path = Path()) -> types.ModuleType: diff --git a/sqlmesh/utils/pandas.py b/sqlmesh/utils/pandas.py index 86662f802f..43851e861a 100644 --- a/sqlmesh/utils/pandas.py +++ b/sqlmesh/utils/pandas.py @@ -1,40 +1,68 @@ from __future__ import annotations import typing as t +from functools import lru_cache -import numpy as np -import pandas as pd from sqlglot import exp +if t.TYPE_CHECKING: + import pandas as pd -PANDAS_TYPE_MAPPINGS = { - np.dtype("int8"): exp.DataType.build("tinyint"), - np.dtype("int16"): exp.DataType.build("smallint"), - np.dtype("int32"): exp.DataType.build("int"), - np.dtype("int64"): exp.DataType.build("bigint"), - np.dtype("float16"): exp.DataType.build("float"), - np.dtype("float32"): exp.DataType.build("float"), - np.dtype("float64"): exp.DataType.build("double"), - np.dtype("O"): exp.DataType.build("text"), - np.dtype("bool"): exp.DataType.build("boolean"), - 
np.dtype("datetime64"): exp.DataType.build("timestamp"), - np.dtype("datetime64[ns]"): exp.DataType.build("timestamp"), - np.dtype("datetime64[us]"): exp.DataType.build("timestamp"), - pd.Int8Dtype(): exp.DataType.build("tinyint"), - pd.Int16Dtype(): exp.DataType.build("smallint"), - pd.Int32Dtype(): exp.DataType.build("int"), - pd.Int64Dtype(): exp.DataType.build("bigint"), - pd.Float32Dtype(): exp.DataType.build("float"), - pd.Float64Dtype(): exp.DataType.build("double"), - pd.StringDtype(): exp.DataType.build("text"), # type: ignore - pd.BooleanDtype(): exp.DataType.build("boolean"), -} + +@lru_cache() +def get_pandas_type_mappings() -> t.Dict[t.Any, exp.DataType]: + import pandas as pd + import numpy as np + + mappings = { + np.dtype("int8"): exp.DataType.build("tinyint"), + np.dtype("int16"): exp.DataType.build("smallint"), + np.dtype("int32"): exp.DataType.build("int"), + np.dtype("int64"): exp.DataType.build("bigint"), + np.dtype("float16"): exp.DataType.build("float"), + np.dtype("float32"): exp.DataType.build("float"), + np.dtype("float64"): exp.DataType.build("double"), + np.dtype("O"): exp.DataType.build("text"), + np.dtype("bool"): exp.DataType.build("boolean"), + np.dtype("datetime64"): exp.DataType.build("timestamp"), + np.dtype("datetime64[ns]"): exp.DataType.build("timestamp"), + np.dtype("datetime64[us]"): exp.DataType.build("timestamp"), + pd.Int8Dtype(): exp.DataType.build("tinyint"), + pd.Int16Dtype(): exp.DataType.build("smallint"), + pd.Int32Dtype(): exp.DataType.build("int"), + pd.Int64Dtype(): exp.DataType.build("bigint"), + pd.Float32Dtype(): exp.DataType.build("float"), + pd.Float64Dtype(): exp.DataType.build("double"), + pd.StringDtype(): exp.DataType.build("text"), # type: ignore + pd.BooleanDtype(): exp.DataType.build("boolean"), + } + try: + import pyarrow # type: ignore # noqa + + # Only add this if pyarrow is installed + mappings[pd.StringDtype("pyarrow")] = exp.DataType.build("text") + except ImportError: + pass + + return mappings 
def columns_to_types_from_df(df: pd.DataFrame) -> t.Dict[str, exp.DataType]: + return columns_to_types_from_dtypes(df.dtypes.items()) + + +def columns_to_types_from_dtypes( + dtypes: t.Iterable[t.Tuple[t.Hashable, t.Any]], +) -> t.Dict[str, exp.DataType]: + import pandas as pd + result = {} - for column_name, column_type in df.dtypes.items(): - exp_type = PANDAS_TYPE_MAPPINGS.get(column_type) + for column_name, column_type in dtypes: + exp_type: t.Optional[exp.DataType] = None + if hasattr(pd, "DatetimeTZDtype") and isinstance(column_type, pd.DatetimeTZDtype): + exp_type = exp.DataType.build("timestamptz") + else: + exp_type = get_pandas_type_mappings().get(column_type) if not exp_type: raise ValueError(f"Unsupported pandas type '{column_type}'") result[str(column_name)] = exp_type diff --git a/sqlmesh/utils/process.py b/sqlmesh/utils/process.py new file mode 100644 index 0000000000..453fee78f5 --- /dev/null +++ b/sqlmesh/utils/process.py @@ -0,0 +1,73 @@ +# mypy: disable-error-code=no-untyped-def + +from concurrent.futures import Future, ProcessPoolExecutor +import typing as t +import multiprocessing as mp +from sqlmesh.utils.windows import IS_WINDOWS + + +class SynchronousPoolExecutor: + """A mock implementation of the ProcessPoolExecutor for synchronous use. + + This executor runs functions synchronously in the same process, avoiding the issues + with forking in test environments or when forking isn't possible (non-posix). + """ + + def __init__(self, max_workers=None, mp_context=None, initializer=None, initargs=()): + if initializer is not None: + try: + initializer(*initargs) + except BaseException as ex: + raise RuntimeError(f"Exception in initializer: {ex}") + + def __enter__(self): + return self + + def __exit__(self, *args): + self.shutdown(wait=True) + return False + + def shutdown(self, wait=True, cancel_futures=False): + """No-op method to match ProcessPoolExecutor API. 
+ + Since this executor runs synchronously, there are no background processes + or resources to shut down and all futures will have completed already. + """ + pass + + def submit(self, fn, *args, **kwargs): + """Execute the function synchronously and return a Future with the result.""" + future = Future() + try: + result = fn(*args, **kwargs) + future.set_result(result) + except Exception as e: + future.set_exception(e) + return future + + def map(self, fn, *iterables, timeout=None, chunksize=1): + """Synchronous implementation of ProcessPoolExecutor.map. + + This executes the function for each set of inputs from the iterables in the + current process using Python's built-in map, rather than distributing work. + """ + return map(fn, *iterables) + + +PoolExecutor = t.Union[SynchronousPoolExecutor, ProcessPoolExecutor] + + +def create_process_pool_executor( + initializer: t.Callable, initargs: t.Tuple, max_workers: t.Optional[int] +) -> PoolExecutor: + if max_workers == 1 or IS_WINDOWS: + return SynchronousPoolExecutor( + initializer=initializer, + initargs=initargs, + ) + return ProcessPoolExecutor( + mp_context=mp.get_context("fork"), + initializer=initializer, + initargs=initargs, + max_workers=max_workers, + ) diff --git a/sqlmesh/utils/pydantic.py b/sqlmesh/utils/pydantic.py index 63e53abac4..2c9c570e5b 100644 --- a/sqlmesh/utils/pydantic.py +++ b/sqlmesh/utils/pydantic.py @@ -1,11 +1,11 @@ from __future__ import annotations import json -import sys import typing as t -from functools import cached_property, wraps +from datetime import tzinfo import pydantic +from pydantic import ValidationInfo as ValidationInfo from pydantic.fields import FieldInfo from sqlglot import exp, parse_one from sqlglot.helper import ensure_list @@ -15,64 +15,30 @@ from sqlmesh.core import dialect as d from sqlmesh.utils import str_to_bool -if sys.version_info >= (3, 9): - from typing import Annotated -else: - from typing_extensions import Annotated - if t.TYPE_CHECKING: + from 
sqlglot._typing import E + Model = t.TypeVar("Model", bound="PydanticModel") T = t.TypeVar("T") DEFAULT_ARGS = {"exclude_none": True, "by_alias": True} +PRIVATE_FIELDS = "__pydantic_private__" PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION = [int(p) for p in pydantic.__version__.split(".")][ :2 ] -if PYDANTIC_MAJOR_VERSION >= 2: - - def field_validator(*args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Any], t.Any]: - # Pydantic v2 doesn't support "always" argument. The validator behaves as if "always" is True. - kwargs.pop("always", None) - return pydantic.field_validator(*args, **kwargs) # type: ignore - - def model_validator(*args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Any], t.Any]: - # Pydantic v2 doesn't support "always" argument. The validator behaves as if "always" is True. - kwargs.pop("always", None) - return pydantic.model_validator(*args, **kwargs) # type: ignore - - def field_serializer(*args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Any], t.Any]: - return pydantic.field_serializer(*args, **kwargs) # type: ignore - -else: - - def field_validator(*args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Any], t.Any]: - mode = kwargs.pop("mode", "after") - return pydantic.validator(*args, **kwargs, pre=mode.lower() == "before", allow_reuse=True) - - def model_validator(*args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Any], t.Any]: - mode = kwargs.pop("mode", "after") - return pydantic.root_validator( - *args, **kwargs, pre=mode.lower() == "before", allow_reuse=True - ) - - def field_serializer(*args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Any], t.Any]: - def _decorator(func: t.Callable[[t.Any], t.Any]) -> t.Callable[[t.Any], t.Any]: - @wraps(func) - def _wrapper(*args: t.Any, **kwargs: t.Any) -> t.Any: - return func(*args, **kwargs) +def field_validator(*args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Any], t.Any]: + return pydantic.field_validator(*args, **kwargs) - return _wrapper - return _decorator +def model_validator(*args: t.Any, **kwargs: t.Any) -> 
t.Callable[[t.Any], t.Any]: + return pydantic.model_validator(*args, **kwargs) -def parse_obj_as(type_: T, obj: t.Any) -> T: - if PYDANTIC_MAJOR_VERSION >= 2: - return pydantic.TypeAdapter(type_).validate_python(obj) # type: ignore - return pydantic.tools.parse_obj_as(type_, obj) # type: ignore +def field_serializer(*args: t.Any, **kwargs: t.Any) -> t.Callable[[t.Any], t.Any]: + return pydantic.field_serializer(*args, **kwargs) def get_dialect(values: t.Any) -> str: @@ -87,7 +53,7 @@ def get_dialect(values: t.Any) -> str: from sqlmesh.core.model import model dialect = (values if isinstance(values, dict) else values.data).get("dialect") - return model._dialect if dialect is None else dialect + return model._dialect if dialect is None else dialect # type: ignore def _expression_encoder(e: exp.Expression) -> str: @@ -99,93 +65,54 @@ def _expression_encoder(e: exp.Expression) -> str: class PydanticModel(pydantic.BaseModel): - if PYDANTIC_MAJOR_VERSION >= 2: - model_config = pydantic.ConfigDict( # type: ignore - arbitrary_types_allowed=True, - extra="forbid", # type: ignore - # Even though Pydantic v2 kept support for json_encoders, the functionality has been - # crippled badly. Here we need to enumerate all different ways of how sqlglot expressions - # show up in pydantic models. - json_encoders={ - exp.Expression: _expression_encoder, - exp.DataType: _expression_encoder, - exp.Tuple: _expression_encoder, - AuditQueryTypes: _expression_encoder, # type: ignore - ModelQueryTypes: _expression_encoder, # type: ignore - }, - protected_namespaces=(), - ) - else: - - class Config: - arbitrary_types_allowed = True - extra = "forbid" - json_encoders = {exp.Expression: _expression_encoder} - underscore_attrs_are_private = True - smart_union = True - keep_untouched = (cached_property,) + model_config = pydantic.ConfigDict( + # Even though Pydantic v2 kept support for json_encoders, the functionality has been + # crippled badly. 
Here we need to enumerate all different ways of how sqlglot expressions + # show up in pydantic models. + json_encoders={ + exp.Expression: _expression_encoder, + exp.DataType: _expression_encoder, + exp.Tuple: _expression_encoder, + AuditQueryTypes: _expression_encoder, # type: ignore + ModelQueryTypes: _expression_encoder, # type: ignore + tzinfo: lambda tz: tz.key, + }, + arbitrary_types_allowed=True, + extra="forbid", + protected_namespaces=(), + ) _hash_func_mapping: t.ClassVar[t.Dict[t.Type[t.Any], t.Callable[[t.Any], int]]] = {} - def dict( - self, - **kwargs: t.Any, - ) -> t.Dict[str, t.Any]: + def dict(self, **kwargs: t.Any) -> t.Dict[str, t.Any]: kwargs = {**DEFAULT_ARGS, **kwargs} - if PYDANTIC_MAJOR_VERSION >= 2: - return super().model_dump(**kwargs) # type: ignore - - include = kwargs.pop("include", None) - if include is None and self.__config__.extra != "allow": # type: ignore - # Workaround to support @cached_property in Pydantic v1. - include = {f.name for f in self.all_field_infos().values()} # type: ignore - - mode = kwargs.pop("mode", None) - if mode == "json": - # Pydantic v1 doesn't support the 'json' mode for dict(). - return json.loads(super().json(include=include, **kwargs)) - return super().dict(include=include, **kwargs) # type: ignore + return super().model_dump(**kwargs) # type: ignore def json( self, **kwargs: t.Any, ) -> str: kwargs = {**DEFAULT_ARGS, **kwargs} - if PYDANTIC_MAJOR_VERSION >= 2: - # Pydantic v2 doesn't support arbitrary arguments for json.dump(). - if kwargs.pop("sort_keys", False): - return json.dumps(super().model_dump(mode="json", **kwargs), sort_keys=True) # type: ignore - else: - return super().model_dump_json(**kwargs) # type: ignore + # Pydantic v2 doesn't support arbitrary arguments for json.dump(). 
+ if kwargs.pop("sort_keys", False): + return json.dumps(super().model_dump(mode="json", **kwargs), sort_keys=True) - include = kwargs.pop("include", None) - if include is None and self.__config__.extra != "allow": # type: ignore - # Workaround to support @cached_property in Pydantic v1. - include = {f.name for f in self.all_field_infos().values()} # type: ignore - return super().json(include=include, **kwargs) # type: ignore + return super().model_dump_json(**kwargs) def copy(self: "Model", **kwargs: t.Any) -> "Model": - return ( - super().model_copy(**kwargs) if PYDANTIC_MAJOR_VERSION >= 2 else super().copy(**kwargs) # type: ignore - ) + return super().model_copy(**kwargs) @property def fields_set(self: "Model") -> t.Set[str]: - return self.__pydantic_fields_set__ if PYDANTIC_MAJOR_VERSION >= 2 else self.__fields_set__ # type: ignore + return self.__pydantic_fields_set__ @classmethod def parse_obj(cls: t.Type["Model"], obj: t.Any) -> "Model": - return ( - super().model_validate(obj) if PYDANTIC_MAJOR_VERSION >= 2 else super().parse_obj(obj) # type: ignore - ) + return super().model_validate(obj) @classmethod def parse_raw(cls: t.Type["Model"], b: t.Union[str, bytes], **kwargs: t.Any) -> "Model": - return ( - super().model_validate_json(b, **kwargs) # type: ignore - if PYDANTIC_MAJOR_VERSION >= 2 - else super().parse_raw(b, **kwargs) - ) + return super().model_validate_json(b, **kwargs) @classmethod def missing_required_fields( @@ -203,13 +130,11 @@ def all_fields(cls: t.Type["PydanticModel"]) -> t.Set[str]: @classmethod def all_field_infos(cls: t.Type["PydanticModel"]) -> t.Dict[str, FieldInfo]: - return cls.model_fields if PYDANTIC_MAJOR_VERSION >= 2 else cls.__fields__ # type: ignore + return cls.model_fields @classmethod def required_fields(cls: t.Type["PydanticModel"]) -> t.Set[str]: - return cls._fields( - lambda field: field.is_required() if PYDANTIC_MAJOR_VERSION >= 2 else field.required - ) # type: ignore + return cls._fields(lambda field: 
field.is_required()) @classmethod def _fields( @@ -218,7 +143,7 @@ def _fields( ) -> t.Set[str]: return { field_info.alias if field_info.alias else field_name - for field_name, field_info in cls.all_field_infos().items() # type: ignore + for field_name, field_info in cls.all_field_infos().items() if predicate(field_info) } @@ -226,8 +151,7 @@ def __eq__(self, other: t.Any) -> bool: if (PYDANTIC_MAJOR_VERSION, PYDANTIC_MINOR_VERSION) < (2, 6): if isinstance(other, pydantic.BaseModel): return self.dict() == other.dict() - else: - return self.dict() == other + return self.dict() == other return super().__eq__(other) def __hash__(self) -> int: @@ -235,12 +159,11 @@ def __hash__(self) -> int: obj = {k: v for k, v in self.__dict__.items() if k in self.all_field_infos()} return hash(self.__class__) + hash(tuple(obj.values())) - from pydantic._internal._model_construction import ( # type: ignore - make_hash_func, - ) + from pydantic._internal._model_construction import make_hash_func # type: ignore if self.__class__ not in PydanticModel._hash_func_mapping: PydanticModel._hash_func_mapping[self.__class__] = make_hash_func(self.__class__) + return PydanticModel._hash_func_mapping[self.__class__](self) def __str__(self) -> str: @@ -258,30 +181,6 @@ def __repr__(self) -> str: return str(self) -def model_validator_v1_args(func: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: - @wraps(func) - def wrapper(cls: t.Type, values: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: - is_values_dict = isinstance(values, dict) - values_dict = values if is_values_dict else values.__dict__ - result = func(cls, values_dict, *args, **kwargs) - if is_values_dict: - return result - else: - values.__dict__.update(result) - return values - - return wrapper - - -def field_validator_v1_args(func: t.Callable[..., t.Any]) -> t.Callable[..., t.Any]: - @wraps(func) - def wrapper(cls: t.Type, v: t.Any, values: t.Any, *args: t.Any, **kwargs: t.Any) -> t.Any: - values_dict = values if isinstance(values, 
dict) else values.data - return func(cls, v, values_dict, *args, **kwargs) - - return wrapper - - def validate_list_of_strings(v: t.Any) -> t.List[str]: if isinstance(v, exp.Identifier): return [v.name] @@ -296,6 +195,12 @@ def validate_string(v: t.Any) -> str: return str(v) +def validate_expression(expression: E, dialect: str) -> E: + # this normalizes and quotes identifiers in the given expression according the specified dialect + # it also sets expression.meta["dialect"] so that when we serialize for state, the expression is serialized in the correct dialect + return _get_field(expression, {"dialect": dialect}) # type: ignore + + def bool_validator(v: t.Any) -> bool: if isinstance(v, exp.Boolean): return v.this @@ -314,6 +219,21 @@ def positive_int_validator(v: t.Any) -> int: return v +def validation_error_message(error: pydantic.ValidationError, base: str) -> str: + errors = "\n ".join(_formatted_validation_errors(error)) + return f"{base}\n {errors}" + + +def _formatted_validation_errors(error: pydantic.ValidationError) -> t.List[str]: + result = [] + for e in error.errors(): + msg = e["msg"] + loc: t.Optional[t.Tuple] = e.get("loc") + loc_str = ".".join(loc) if loc else None + result.append(f"Invalid field '{loc_str}':\n {msg}" if loc_str else msg) + return result + + def _get_field( v: t.Any, values: t.Any, @@ -369,13 +289,47 @@ def column_validator(v: t.Any, values: t.Any) -> exp.Column: return expression -def list_of_columns_or_star_validator( +def list_of_fields_or_star_validator( v: t.Any, values: t.Any -) -> t.Union[exp.Star, t.List[exp.Column]]: +) -> t.Union[exp.Star, t.List[exp.Expression]]: expressions = _get_fields(v, values) if len(expressions) == 1 and isinstance(expressions[0], exp.Star): return t.cast(exp.Star, expressions[0]) - return t.cast(t.List[exp.Column], expressions) + return t.cast(t.List[exp.Expression], expressions) + + +def cron_validator(v: t.Any) -> str: + if isinstance(v, exp.Expression): + v = v.name + + from croniter import 
CroniterBadCronError, croniter + + if not isinstance(v, str): + raise ValueError(f"Invalid cron expression '{v}'. Value must be a string.") + + try: + croniter(v) + except CroniterBadCronError: + raise ValueError(f"Invalid cron expression '{v}'") + return v + + +def get_concrete_types_from_typehint(typehint: type[t.Any]) -> set[type[t.Any]]: + concrete_types = set() + unpacked = t.get_origin(typehint) + if unpacked is None: + if type(typehint) == type(type): + return {typehint} + elif unpacked is t.Union: + for item in t.get_args(typehint): + if str(item).startswith("typing."): + concrete_types |= get_concrete_types_from_typehint(item) + else: + concrete_types.add(item) + else: + concrete_types.add(unpacked) + + return concrete_types if t.TYPE_CHECKING: @@ -385,47 +339,20 @@ def list_of_columns_or_star_validator( SQLGlotPositiveInt = int SQLGlotColumn = exp.Column SQLGlotListOfFields = t.List[exp.Expression] - SQLGlotListOfColumnsOrStar = t.Union[t.List[exp.Column], exp.Star] -elif PYDANTIC_MAJOR_VERSION >= 2: - from pydantic.functional_validators import BeforeValidator # type: ignore - - SQLGlotListOfStrings = Annotated[t.List[str], BeforeValidator(validate_list_of_strings)] - SQLGlotString = Annotated[str, BeforeValidator(validate_string)] - SQLGlotBool = Annotated[bool, BeforeValidator(bool_validator)] - SQLGlotPositiveInt = Annotated[int, BeforeValidator(positive_int_validator)] - SQLGlotColumn = Annotated[exp.Expression, BeforeValidator(column_validator)] - SQLGlotListOfFields = Annotated[ + SQLGlotListOfFieldsOrStar = t.Union[SQLGlotListOfFields, exp.Star] + SQLGlotCron = str +else: + from pydantic.functional_validators import BeforeValidator + + SQLGlotListOfStrings = t.Annotated[t.List[str], BeforeValidator(validate_list_of_strings)] + SQLGlotString = t.Annotated[str, BeforeValidator(validate_string)] + SQLGlotBool = t.Annotated[bool, BeforeValidator(bool_validator)] + SQLGlotPositiveInt = t.Annotated[int, BeforeValidator(positive_int_validator)] + 
SQLGlotColumn = t.Annotated[exp.Expression, BeforeValidator(column_validator)] + SQLGlotListOfFields = t.Annotated[ t.List[exp.Expression], BeforeValidator(list_of_fields_validator) ] - SQLGlotListOfColumnsOrStar = Annotated[ - t.Union[t.List[exp.Column], exp.Star], BeforeValidator(list_of_columns_or_star_validator) + SQLGlotListOfFieldsOrStar = t.Annotated[ + t.Union[SQLGlotListOfFields, exp.Star], BeforeValidator(list_of_fields_or_star_validator) ] -else: - - class PydanticTypeProxy(t.Generic[T]): - validate: t.Callable[[t.Any], T] - - @classmethod - def __get_validators__(cls) -> t.Iterator[t.Callable[[t.Any], T]]: - yield cls.validate - - class SQLGlotListOfStrings(PydanticTypeProxy[t.List[str]]): - validate = validate_list_of_strings - - class SQLGlotString(PydanticTypeProxy[str]): - validate = validate_string - - class SQLGlotBool(PydanticTypeProxy[bool]): - validate = bool_validator - - class SQLGlotPositiveInt(PydanticTypeProxy[int]): - validate = positive_int_validator - - class SQLGlotColumn(PydanticTypeProxy[exp.Column]): - validate = column_validator - - class SQLGlotListOfFields(PydanticTypeProxy[t.List[exp.Expression]]): - validate = list_of_fields_validator - - class SQLGlotListOfColumnsOrStar(PydanticTypeProxy[t.Union[exp.Star, t.List[exp.Column]]]): - validate = list_of_columns_or_star_validator + SQLGlotCron = t.Annotated[str, BeforeValidator(cron_validator)] diff --git a/sqlmesh/utils/rich.py b/sqlmesh/utils/rich.py index 6ebeab3114..0b43e3d87c 100644 --- a/sqlmesh/utils/rich.py +++ b/sqlmesh/utils/rich.py @@ -1,8 +1,17 @@ +from __future__ import annotations + import typing as t +import re + from rich.console import Console from rich.progress import Column, ProgressColumn, Task, Text from rich.theme import Theme +from rich.table import Table +from rich.align import Align + +if t.TYPE_CHECKING: + import pandas as pd theme = Theme( { @@ -46,3 +55,50 @@ def render(self, task: Task) -> Text: f"{completed:{total_width}d}{self.separator}{total}", 
style="progress.download", ) + + +def strip_ansi_codes(text: str) -> str: + """Strip ANSI color codes and styling from text.""" + ansi_escape = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]") + return ansi_escape.sub("", text).strip() + + +def df_to_table( + header: str, + df: pd.DataFrame, + show_index: bool = True, + index_name: str = "Row", +) -> Table: + """Convert a pandas.DataFrame obj into a rich.Table obj. + Args: + df (DataFrame): A Pandas DataFrame to be converted to a rich Table. + rich_table (Table): A rich Table that should be populated by the DataFrame values. + show_index (bool): Add a column with a row count to the table. Defaults to True. + index_name (str, optional): The column name to give to the index column. Defaults to None, showing no value. + Returns: + Table: The rich Table instance passed, populated with the DataFrame values.""" + + rich_table = Table(title=f"[bold red]{header}[/bold red]", show_lines=True, min_width=60) + if show_index: + index_name = str(index_name) if index_name else "" + rich_table.add_column(Align.center(index_name)) + + for column in df.columns: + column_name = column if isinstance(column, str) else ": ".join(str(col) for col in column) + + # Color coding unit test columns (expected/actual), can be removed or refactored if df_to_table is used elswhere too + lower = column_name.lower() + if "expected" in lower: + column_name = f"[green]{column_name}[/green]" + elif "actual" in lower: + column_name = f"[red]{column_name}[/red]" + + rich_table.add_column(Align.center(column_name)) + + for index, value_list in zip(df.index, df.values.tolist()): + row = [str(index)] if show_index else [] + row += [str(x) for x in value_list] + center = [Align.center(x) for x in row] + rich_table.add_row(*center) + + return rich_table diff --git a/sqlmesh/utils/windows.py b/sqlmesh/utils/windows.py new file mode 100644 index 0000000000..b2de5b8af9 --- /dev/null +++ b/sqlmesh/utils/windows.py @@ -0,0 +1,24 @@ +import platform +from pathlib import Path 
+ +IS_WINDOWS = platform.system() == "Windows" + +WINDOWS_LONGPATH_PREFIX = "\\\\?\\" + + +def fix_windows_path(path: Path) -> Path: + """ + Windows paths are limited to 260 characters: https://learn.microsoft.com/en-us/windows/win32/fileio/maximum-file-path-limitation + Users can change this by updating a registry entry but we cant rely on that. + + SQLMesh quite commonly generates cache file paths that exceed 260 characters and thus cause a FileNotFound error. + If we prefix paths with "\\?\" then we can have paths up to 32,767 characters. + + Note that this prefix also means that relative paths no longer work. From the above docs: + > Because you cannot use the "\\?\" prefix with a relative path, relative paths are always limited to a total of MAX_PATH characters. + + So we also call path.resolve() to resolve the relative sections so that operations like `path.read_text()` continue to work + """ + if path.parts and not path.parts[0].startswith(WINDOWS_LONGPATH_PREFIX): + path = Path(WINDOWS_LONGPATH_PREFIX + str(path.absolute())) + return path.resolve() diff --git a/sqlmesh/utils/yaml.py b/sqlmesh/utils/yaml.py index fcf2c19114..d72e9d49e5 100644 --- a/sqlmesh/utils/yaml.py +++ b/sqlmesh/utils/yaml.py @@ -1,11 +1,14 @@ from __future__ import annotations +import getpass import io import typing as t +from decimal import Decimal from os import getenv from pathlib import Path from ruamel import yaml +from ruamel.yaml.constructor import SafeConstructor from sqlmesh.core.constants import VAR from sqlmesh.utils.errors import SQLMeshError @@ -13,9 +16,38 @@ JINJA_METHODS = { "env_var": lambda key, default=None: getenv(key, default), + "user": lambda: getpass.getuser(), } -YAML = lambda: yaml.YAML(typ="safe") # noqa: E731 + +def YAML(typ: t.Optional[str] = "safe") -> yaml.YAML: + yaml_obj = yaml.YAML(typ=typ) + + # Ruamel doesn't know how to serialize Decimal values. 
This is problematic when, + # e.g., we're trying to auto-generate a unit test whose body contains Decimal data. + # This is a best-effort approach to solve this by serializing them as strings. + yaml_obj.representer.add_representer( + Decimal, lambda dumper, data: dumper.represent_str(str(data)) + ) + + return yaml_obj + + +class SafeConstructorOverride(SafeConstructor): + def check_mapping_key( + self, + node: t.Any, + key_node: t.Any, + mapping: t.Any, + key: t.Any, + value: t.Any, + ) -> bool: + """This function normally returns True if key is unique. + + It is only used by the construct_mapping function. By always returning True, + keys will always be updated and so the last value will be kept for mappings. + """ + return True def load( @@ -24,6 +56,7 @@ def load( render_jinja: bool = True, allow_duplicate_keys: bool = False, variables: t.Optional[t.Dict[str, t.Any]] = None, + keep_last_duplicate_key: bool = False, ) -> t.Dict: """Loads a YAML object from either a raw string or a file.""" path: t.Optional[Path] = None @@ -42,6 +75,8 @@ def load( ) yaml = YAML() + if allow_duplicate_keys and keep_last_duplicate_key: + yaml.Constructor = SafeConstructorOverride yaml.allow_duplicate_keys = allow_duplicate_keys contents = yaml.load(source) if contents is None: @@ -64,8 +99,5 @@ def dump(value: t.Any) -> str: ... 
def dump(value: t.Any, stream: t.Optional[io.IOBase] = None) -> t.Optional[str]: """Dumps a ruamel.yaml loaded object and converts it into a string or writes it to a stream.""" result = io.StringIO() - yaml.YAML().dump(value, stream or result) - - if stream: - return None - return result.getvalue() + YAML(typ=None).dump(value, stream or result) + return None if stream else result.getvalue() diff --git a/sqlmesh_dbt/__init__.py b/sqlmesh_dbt/__init__.py new file mode 100644 index 0000000000..984f083f5b --- /dev/null +++ b/sqlmesh_dbt/__init__.py @@ -0,0 +1,5 @@ +# Note: `sqlmesh_dbt` is deliberately in its own package from `sqlmesh` to avoid the upfront time overhead +# that comes from `import sqlmesh` +# +# Obviously we still have to `import sqlmesh` at some point but this allows us to defer it until needed, +# which means we can make the CLI feel more responsive by being able to output something immediately diff --git a/sqlmesh_dbt/cli.py b/sqlmesh_dbt/cli.py new file mode 100644 index 0000000000..278daa5370 --- /dev/null +++ b/sqlmesh_dbt/cli.py @@ -0,0 +1,220 @@ +import typing as t +import sys +import click +from sqlmesh_dbt.operations import DbtOperations, create +from sqlmesh_dbt.error import cli_global_error_handler, ErrorHandlingGroup +from pathlib import Path +from sqlmesh_dbt.options import YamlParamType +import functools + + +def _get_dbt_operations( + ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]], threads: t.Optional[int] = None +) -> DbtOperations: + if not isinstance(ctx.obj, functools.partial): + raise ValueError(f"Unexpected click context object: {type(ctx.obj)}") + + dbt_operations = ctx.obj(vars=vars, threads=threads) + + if not isinstance(dbt_operations, DbtOperations): + raise ValueError(f"Unexpected dbt operations type: {type(dbt_operations)}") + + @ctx.call_on_close + def _cleanup() -> None: + dbt_operations.close() + + return dbt_operations + + +vars_option = click.option( + "--vars", + type=YamlParamType(), + help="Supply variables 
to the project. This argument overrides variables defined in your dbt_project.yml file. This argument should be a YAML string, eg. '{my_variable: my_value}'", +) + + +select_option = click.option( + "-s", + "--select", + multiple=True, + help="Specify the nodes to include.", +) +model_option = click.option( + "-m", + "--models", + "--model", + multiple=True, + help="Specify the model nodes to include; other nodes are excluded.", +) +exclude_option = click.option("--exclude", multiple=True, help="Specify the nodes to exclude.") + +# TODO: expand this out into --resource-type/--resource-types and --exclude-resource-type/--exclude-resource-types +resource_types = [ + "metric", + "semantic_model", + "saved_query", + "source", + "analysis", + "model", + "test", + "unit_test", + "exposure", + "snapshot", + "seed", + "default", + "all", +] +resource_type_option = click.option( + "--resource-type", type=click.Choice(resource_types, case_sensitive=False) +) + + +@click.group(cls=ErrorHandlingGroup, invoke_without_command=True) +@click.option("--profile", help="Which existing profile to load. Overrides output.profile") +@click.option("-t", "--target", help="Which target to load for the given profile") +@click.option( + "-d", + "--debug/--no-debug", + default=False, + help="Display debug logging during dbt execution. Useful for debugging and making bug reports events to help when debugging.", +) +@click.option( + "--log-level", + default="info", + type=click.Choice(["debug", "info", "warn", "error", "none"]), + help="Specify the minimum severity of events that are logged to the console and the log file.", +) +@click.option( + "--profiles-dir", + type=click.Path(exists=True, file_okay=False, path_type=Path), + help="Which directory to look in for the profiles.yml file. 
If not set, dbt will look in the current working directory first, then HOME/.dbt/", +) +@click.option( + "--project-dir", + type=click.Path(exists=True, file_okay=False, path_type=Path), + help="Which directory to look in for the dbt_project.yml file. Default is the current working directory and its parents.", +) +@click.pass_context +@cli_global_error_handler +def dbt( + ctx: click.Context, + profile: t.Optional[str] = None, + target: t.Optional[str] = None, + debug: bool = False, + log_level: t.Optional[str] = None, + profiles_dir: t.Optional[Path] = None, + project_dir: t.Optional[Path] = None, +) -> None: + """ + An ELT tool for managing your SQL transformations and data models, powered by the SQLMesh engine. + """ + + if "--help" in sys.argv: + # we dont need to import sqlmesh/load the project for CLI help + return + + # we have a partially applied function here because subcommands might set extra options like --vars + # that need to be known before we attempt to load the project + ctx.obj = functools.partial( + create, + project_dir=project_dir, + profiles_dir=profiles_dir, + profile=profile, + target=target, + debug=debug, + log_level=log_level, + ) + + if not ctx.invoked_subcommand: + if profile or target: + # trigger a project load to validate the specified profile / target + ctx.obj() + + click.echo( + f"No command specified. Run `{ctx.info_name} --help` to see the available commands." 
+ ) + + +@dbt.command() +@select_option +@model_option +@exclude_option +@resource_type_option +@click.option( + "-f", + "--full-refresh", + is_flag=True, + default=False, + help="If specified, sqlmesh will drop incremental models and fully-recalculate the incremental table from the model definition.", +) +@click.option( + "--env", + "--environment", + help="Run against a specific Virtual Data Environment (VDE) instead of the main environment", +) +@click.option( + "--empty/--no-empty", default=False, help="If specified, limit input refs and sources" +) +@click.option( + "--threads", + type=int, + help="Specify number of threads to use while executing models. Overrides settings in profiles.yml.", +) +@vars_option +@click.pass_context +def run( + ctx: click.Context, + vars: t.Optional[t.Dict[str, t.Any]], + threads: t.Optional[int], + env: t.Optional[str] = None, + **kwargs: t.Any, +) -> None: + """Compile SQL and execute against the current target database.""" + _get_dbt_operations(ctx, vars, threads).run(environment=env, **kwargs) + + +@dbt.command(name="list") +@select_option +@model_option +@exclude_option +@resource_type_option +@vars_option +@click.pass_context +def list_(ctx: click.Context, vars: t.Optional[t.Dict[str, t.Any]], **kwargs: t.Any) -> None: + """List the resources in your project""" + _get_dbt_operations(ctx, vars).list_(**kwargs) + + +@dbt.command(name="ls", hidden=True) # hidden alias for list +@click.pass_context +def ls(ctx: click.Context) -> None: + """List the resources in your project""" + ctx.forward(list_) + + +def _not_implemented(name: str) -> None: + @dbt.command(name=name) + def _not_implemented() -> None: + """Not implemented""" + click.echo(f"dbt {name} not implemented") + + +for subcommand in ( + "build", + "clean", + "clone", + "compile", + "debug", + "deps", + "docs", + "init", + "parse", + "retry", + "run-operation", + "seed", + "show", + "snapshot", + "source", + "test", +): + _not_implemented(subcommand) diff --git 
a/sqlmesh_dbt/console.py b/sqlmesh_dbt/console.py new file mode 100644 index 0000000000..6bf7a1618f --- /dev/null +++ b/sqlmesh_dbt/console.py @@ -0,0 +1,35 @@ +import typing as t +from sqlmesh.core.console import TerminalConsole +from sqlmesh.core.model import Model +from sqlmesh.core.snapshot.definition import Node +from rich.tree import Tree + + +class DbtCliConsole(TerminalConsole): + def print(self, msg: str) -> None: + return self._print(msg) + + def list_models( + self, + models: t.List[Model], + all_nodes: t.Dict[str, Node], + list_parents: bool = True, + list_audits: bool = True, + ) -> None: + model_list = Tree("[bold]Models in project:[/bold]") + + for model in models: + model_tree = model_list.add(model.dbt_fqn or model.name) + + if list_parents: + for parent_name in model.depends_on: + if parent := all_nodes.get(parent_name): + parent_name = parent.dbt_fqn or parent_name + + model_tree.add(f"depends_on: {parent_name}") + + if list_audits: + for audit_name, audit in model.audit_definitions.items(): + model_tree.add(f"audit: {audit.dbt_fqn or audit_name}") + + self._print(model_list) diff --git a/sqlmesh_dbt/error.py b/sqlmesh_dbt/error.py new file mode 100644 index 0000000000..49a2f8195b --- /dev/null +++ b/sqlmesh_dbt/error.py @@ -0,0 +1,36 @@ +import typing as t +import logging +from functools import wraps +import click +import sys + +logger = logging.getLogger(__name__) + + +def cli_global_error_handler( + func: t.Callable[..., t.Any], +) -> t.Callable[..., t.Any]: + @wraps(func) + def wrapper(*args: t.List[t.Any], **kwargs: t.Any) -> t.Any: + try: + return func(*args, **kwargs) + except Exception as ex: + # these imports are deliberately deferred to avoid the penalty of importing the `sqlmesh` + # package up front for every CLI command + from sqlmesh.utils.errors import SQLMeshError + from sqlglot.errors import SqlglotError + + if isinstance(ex, (SQLMeshError, SqlglotError, ValueError)): + click.echo(click.style("Error: " + str(ex), fg="red")) + 
sys.exit(1) + else: + raise + + return wrapper + + +class ErrorHandlingGroup(click.Group): + def add_command(self, cmd: click.Command, name: t.Optional[str] = None) -> None: + if cmd.callback: + cmd.callback = cli_global_error_handler(cmd.callback) + super().add_command(cmd, name=name) diff --git a/sqlmesh_dbt/operations.py b/sqlmesh_dbt/operations.py new file mode 100644 index 0000000000..576d8e090b --- /dev/null +++ b/sqlmesh_dbt/operations.py @@ -0,0 +1,311 @@ +from __future__ import annotations +import typing as t +from rich.progress import Progress +from pathlib import Path +import logging +from sqlmesh_dbt import selectors + +if t.TYPE_CHECKING: + # important to gate these to be able to defer importing sqlmesh until we need to + from sqlmesh.core.context import Context + from sqlmesh.dbt.project import Project + from sqlmesh_dbt.console import DbtCliConsole + from sqlmesh.core.model import Model + from sqlmesh.core.plan import Plan, PlanBuilder + +logger = logging.getLogger(__name__) + + +class DbtOperations: + def __init__(self, sqlmesh_context: Context, dbt_project: Project, debug: bool = False): + self.context = sqlmesh_context + self.project = dbt_project + self.debug = debug + + def list_( + self, + select: t.Optional[t.List[str]] = None, + exclude: t.Optional[t.List[str]] = None, + models: t.Optional[t.List[str]] = None, + resource_type: t.Optional[str] = None, + ) -> None: + # dbt list prints: + # - models + # - "data tests" (audits) for those models + # it also applies selectors which is useful for testing selectors + selected_models = list( + self._selected_models(select, exclude, models, resource_type).values() + ) + self.console.list_models( + selected_models, {k: v.node for k, v in self.context.snapshots.items()} + ) + + def run( + self, + environment: t.Optional[str] = None, + select: t.Optional[t.List[str]] = None, + exclude: t.Optional[t.List[str]] = None, + models: t.Optional[t.List[str]] = None, + resource_type: t.Optional[str] = None, + 
full_refresh: bool = False, + empty: bool = False, + ) -> Plan: + consolidated_select, consolidated_exclude = selectors.consolidate( + select or [], exclude or [], models or [], resource_type + ) + + plan_builder = self._plan_builder( + environment=environment, + select=consolidated_select, + exclude=consolidated_exclude, + full_refresh=full_refresh, + empty=empty, + ) + + plan = plan_builder.build() + + self.console.plan( + plan_builder, + default_catalog=self.context.default_catalog, + # start doing work immediately (since no_diff is set, there isnt really anything for the user to say yes/no to) + auto_apply=True, + # dont output a diff of model changes + no_diff=True, + # don't throw up any prompts like "set the effective date" - use defaults + no_prompts=True, + ) + + return plan + + def _plan_builder( + self, + environment: t.Optional[str] = None, + select: t.Optional[t.List[str]] = None, + exclude: t.Optional[t.List[str]] = None, + full_refresh: bool = False, + empty: bool = False, + ) -> PlanBuilder: + return self.context.plan_builder( + **self._plan_builder_options( + environment=environment, + select=select, + exclude=exclude, + full_refresh=full_refresh, + empty=empty, + ) + ) + + def _selected_models( + self, + select: t.Optional[t.List[str]] = None, + exclude: t.Optional[t.List[str]] = None, + models: t.Optional[t.List[str]] = None, + resource_type: t.Optional[str] = None, + ) -> t.Dict[str, Model]: + if sqlmesh_selector := selectors.to_sqlmesh( + *selectors.consolidate(select or [], exclude or [], models or [], resource_type) + ): + if self.debug: + self.console.print(f"dbt --select: {select}") + self.console.print(f"dbt --exclude: {exclude}") + self.console.print(f"sqlmesh equivalent: '{sqlmesh_selector}'") + model_selector = self.context._new_selector() + selected_models = { + fqn: model + for fqn, model in self.context.models.items() + if fqn in model_selector.expand_model_selections([sqlmesh_selector]) + } + else: + selected_models = 
dict(self.context.models) + + return selected_models + + def _plan_builder_options( + self, + # upstream dbt options + select: t.Optional[t.List[str]] = None, + exclude: t.Optional[t.List[str]] = None, + empty: bool = False, + full_refresh: bool = False, + # sqlmesh extra options + environment: t.Optional[str] = None, + ) -> t.Dict[str, t.Any]: + import sqlmesh.core.constants as c + + # convert --select and --exclude to a selector expression for the SQLMesh selector engine + select_models = None + if sqlmesh_selector := selectors.to_sqlmesh(select or [], exclude or []): + select_models = [sqlmesh_selector] + + is_dev = environment and environment != c.PROD + is_prod = not is_dev + + options: t.Dict[str, t.Any] = {} + + if is_prod or (is_dev and select_models): + # prod plans should "catch up" before applying the changes so that after the command finishes prod is the latest it can be + # dev plans *with* selectors should do the same as the user is saying "specifically update these models to the latest" + # dev plans *without* selectors should just have the defaults of never exceeding prod as the user is saying "just create this env" without focusing on any specific models + options.update( + dict( + # always catch the data up to latest rather than only operating on what has been loaded before + run=True, + # don't taking cron schedules into account when deciding what models to run, do everything even if it just ran + ignore_cron=True, + ) + ) + + if is_dev: + options.update( + dict( + # don't create views for all of prod in the dev environment + include_unmodified=False, + # always plan from scratch against prod. note that this is coupled with the `always_recreate_environment=True` setting in the default config file. + # the result is that rather than planning against the previous state of an existing dev environment, the full scope of changes vs prod are always shown + create_from=c.PROD, + # Always enable dev previews for incremental / forward-only models. 
+ # Due to how DBT does incrementals (INCREMENTAL_UNMANAGED on the SQLMesh engine), this will result in the full model being refreshed + # with the entire dataset, which can potentially be very large. If this is undesirable, users have two options: + # - work around this using jinja to conditionally add extra filters to the WHERE clause or a LIMIT to the model query + # - upgrade to SQLMesh's incremental models, where we have variables for the start/end date and inject leak guards to + # limit the amount of data backfilled + # + # Note: enable_preview=True is *different* behaviour to the `sqlmesh` CLI, which uses enable_preview=None. + # This means the `sqlmesh` CLI will only enable dev previews for dbt projects if the target adapter supports cloning, + # whereas we enable it unconditionally here + enable_preview=True, + ) + ) + + if empty: + # `dbt --empty` adds LIMIT 0 to the queries, resulting in empty tables. In addition, it happily clobbers existing tables regardless of if they are populated. + # This *partially* lines up with --skip-backfill in SQLMesh, which indicates to not populate tables if they happened to be created/updated as part of this plan. + # However, if a table already exists and has data in it, there is no change so SQLMesh will not recreate the table and thus it will not be cleared. 
+ # Currently, SQLMesh has no way to say "restate with empty data", because --restate-model coupled with --skip-backfill ends up being a no-op + options["skip_backfill"] = True + + self.console.log_warning( + "dbt's `--empty` drops the tables for all selected models and replaces them with empty ones.\n" + "This can easily result in accidental data loss, so SQLMesh limits this to only new or modified models and leaves the tables for existing unmodified models alone.\n\n" + "If you were creating empty tables to preview model changes, please consider using `--environment` to preview these changes in an isolated Virtual Data Environment instead.\n\n" + "Otherwise, if you really do want dbt's `--empty` behaviour of clearing every selected table, please file an issue on GitHub so we can better understand the use-case.\n" + ) + + if full_refresh: + # --full-refresh is implemented in terms of "add every model as a restatement" + # however, `--empty` sets skip_backfill=True, which causes the BackfillStage of the plan to be skipped. + # the re-processing of data intervals happens in the BackfillStage, so if it gets skipped, restatements become a no-op + raise ValueError("`--full-refresh` alongside `--empty` is not currently supported.") + + if full_refresh: + options.update( + dict( + # Add every selected model as a restatement to force them to get repopulated from scratch + restate_models=[m.dbt_fqn for m in self.context.models.values() if m.dbt_fqn] + if not select_models + else select_models, + # by default in SQLMesh, restatements only operate on what has been committed to state. 
+ # in order to emulate dbt, we need to use the local filesystem instead, so we override this default + always_include_local_changes=True, + ) + ) + + return dict( + environment=environment, + select_models=select_models, + **options, + ) + + @property + def console(self) -> DbtCliConsole: + console = self.context.console + from sqlmesh_dbt.console import DbtCliConsole + + if not isinstance(console, DbtCliConsole): + raise ValueError(f"Expecting dbt cli console, got: {console}") + + return console + + def close(self) -> None: + self.context.close() + + +def create( + project_dir: t.Optional[Path] = None, + profiles_dir: t.Optional[Path] = None, + profile: t.Optional[str] = None, + target: t.Optional[str] = None, + vars: t.Optional[t.Dict[str, t.Any]] = None, + threads: t.Optional[int] = None, + debug: bool = False, + log_level: t.Optional[str] = None, +) -> DbtOperations: + with Progress(transient=True) as progress: + # Indeterminate progress bar before SQLMesh import to provide feedback to the user that something is indeed happening + load_task_id = progress.add_task("Loading engine", total=None) + + from sqlmesh import configure_logging + from sqlmesh.core.context import Context + from sqlmesh.dbt.loader import DbtLoader + from sqlmesh.core.console import set_console + from sqlmesh_dbt.console import DbtCliConsole + from sqlmesh.utils.errors import SQLMeshError + from sqlmesh.core.selector import DbtSelector + + # clear any existing handlers set up by click/rich as defaults so that once SQLMesh logging config is applied, + # we dont get duplicate messages logged from things like console.log_warning() + root_logger = logging.getLogger() + while root_logger.hasHandlers(): + root_logger.removeHandler(root_logger.handlers[0]) + + configure_logging(force_debug=debug, log_level=log_level) + set_console(DbtCliConsole()) + + progress.update(load_task_id, description="Loading project", total=None) + + project_dir = project_dir or Path.cwd() + 
init_project_if_required(project_dir) + + sqlmesh_context = Context( + paths=[project_dir], + config_loader_kwargs=dict( + profile=profile, + target=target, + variables=vars, + threads=threads, + profiles_dir=profiles_dir, + ), + load=True, + # DbtSelector selects based on dbt model fqn's rather than SQLMesh model names + selector=DbtSelector, + ) + + dbt_loader = sqlmesh_context._loaders[0] + if not isinstance(dbt_loader, DbtLoader): + raise SQLMeshError(f"Unexpected loader type: {type(dbt_loader)}") + + # so that DbtOperations can query information from the DBT project files in order to invoke SQLMesh correctly + dbt_project = dbt_loader._projects[0] + + return DbtOperations(sqlmesh_context, dbt_project, debug=debug) + + +def init_project_if_required(project_dir: Path, start: t.Optional[str] = None) -> None: + """ + SQLMesh needs a start date to as the starting point for calculating intervals on incremental models, amongst other things + + Rather than forcing the user to update their config manually or having a default that is not saved between runs, + we can generate a basic SQLMesh config if it doesnt exist. 
+ + This is preferable to trying to inject config into `dbt_project.yml` because it means we have full control over the file + and dont need to worry about accidentally reformatting it or accidentally clobbering other config + """ + from sqlmesh.cli.project_init import init_example_project, ProjectTemplate + from sqlmesh.core.config.common import ALL_CONFIG_FILENAMES + from sqlmesh.core.console import get_console + + if not any(f.exists() for f in [project_dir / file for file in ALL_CONFIG_FILENAMES]): + get_console().log_warning("No existing SQLMesh config detected; creating one") + init_example_project( + path=project_dir, engine_type=None, template=ProjectTemplate.DBT, start=start + ) diff --git a/sqlmesh_dbt/options.py b/sqlmesh_dbt/options.py new file mode 100644 index 0000000000..5a7cabe93b --- /dev/null +++ b/sqlmesh_dbt/options.py @@ -0,0 +1,25 @@ +import typing as t +import click +from click.core import Context, Parameter + + +class YamlParamType(click.ParamType): + name = "yaml" + + def convert( + self, value: t.Any, param: t.Optional[Parameter], ctx: t.Optional[Context] + ) -> t.Any: + if not isinstance(value, str): + self.fail(f"Input value '{value}' should be a string", param, ctx) + + from sqlmesh.utils import yaml + + try: + parsed = yaml.load(source=value, render_jinja=False) + except: + self.fail(f"String '{value}' is not valid YAML", param, ctx) + + if not isinstance(parsed, dict): + self.fail(f"String '{value}' did not evaluate to a dict, got: {parsed}", param, ctx) + + return parsed diff --git a/sqlmesh_dbt/selectors.py b/sqlmesh_dbt/selectors.py new file mode 100644 index 0000000000..5821586ad3 --- /dev/null +++ b/sqlmesh_dbt/selectors.py @@ -0,0 +1,174 @@ +import typing as t +import logging + +logger = logging.getLogger(__name__) + + +def consolidate( + select: t.List[str], + exclude: t.List[str], + models: t.List[str], + resource_type: t.Optional[str], +) -> t.Tuple[t.List[str], t.List[str]]: + """ + Given a bunch of dbt CLI arguments that 
may or may not be defined: + --select, --exclude, --models, --resource-type + + Combine them into a single set of --select/--exclude node selectors, throwing an error if mutually exclusive combinations are provided + Note that the returned value is still in dbt format, pass it to to_sqlmesh() to create a selector for the sqlmesh selector engine + """ + if models and select: + raise ValueError('"models" and "select" are mutually exclusive arguments') + + if models and resource_type: + raise ValueError('"models" and "resource_type" are mutually exclusive arguments') + + if models: + # --models implies resource_type:model + resource_type = "model" + + if resource_type: + resource_type_selector = f"resource_type:{resource_type}" + all_selectors = [*select, *models] + select = ( + [ + f"resource_type:{resource_type},{original_selector}" + for original_selector in all_selectors + ] + if all_selectors + else [resource_type_selector] + ) + + return select, exclude + + +def to_sqlmesh(dbt_select: t.List[str], dbt_exclude: t.List[str]) -> t.Optional[str]: + """ + Given selectors defined in the format of the dbt cli --select and --exclude arguments, convert them into a selector expression that + the SQLMesh selector engine can understand. + + The main things being mapped are: + - set union (" " between items within the same selector string OR multiple --select arguments) is mapped to " | " + - set intersection ("," between items within the same selector string) is mapped to " & " + - `--exclude`. The SQLMesh selector engine does not treat this as a separate parameter and rather treats exclusion as a normal selector + that just happens to contain negation syntax, so we generate these by negating each expression and then intersecting the result + with any --select expressions + + Things that are *not* currently being mapped include: + - selectors based on file paths + - selectors based on partially qualified names like "model_a". 
The SQLMesh selector engine requires either: + - wildcards, eg "*model_a*" + - the full model name qualified with the schema, eg "staging.model_a" + + Examples: + --select "model_a" + -> "model_a" + --select "main.model_a" + -> "main.model_a" + --select "main.model_a" --select "main.model_b" + -> "main.model_a | main.model_b" + --select "main.model_a main.model_b" + -> "main.model_a | main.model_b" + --select "(main.model_a+ & ^main.model_b)" + -> "(main.model_a+ & ^main.model_b)" + --select "+main.model_a" --exclude "raw.src_data" + -> "+main.model_a & ^(raw.src_data)" + --select "+main.model_a" --select "main.*b+" --exclude "raw.src_data" + -> "(+main.model_a | main.*b+) & ^(raw.src_data)" + --select "+main.model_a" --select "main.*b+" --exclude "raw.src_data" --exclude "main.model_c" + -> "(+main.model_a | main.*b+) & ^(raw.src_data | main.model_c)" + --select "+main.model_a main.*b+" --exclude "raw.src_data main.model_c" + -> "(+main.model_a | main.*b+) & ^(raw.src_data | main.model_c)" + """ + if not dbt_select and not dbt_exclude: + return None + + select_expr = " | ".join(_to_sqlmesh(expr) for expr in dbt_select) + select_expr = _wrap(select_expr) if dbt_exclude and len(dbt_select) > 1 else select_expr + + exclude_expr = "" + + if dbt_exclude: + exclude_expr = " | ".join(_to_sqlmesh(expr) for expr in dbt_exclude) + exclude_expr = _negate( + _wrap(exclude_expr) if dbt_select and len(dbt_exclude) > 1 else exclude_expr + ) + + main_expr = " & ".join([expr for expr in [select_expr, exclude_expr] if expr]) + + logger.debug( + f"Expanded dbt select: {dbt_select}, exclude: {dbt_exclude} into SQLMesh: {main_expr}" + ) + + return main_expr + + +def _to_sqlmesh(selector_str: str) -> str: + unions, intersections = _split_unions_and_intersections(selector_str) + + union_expr = " | ".join(unions) + intersection_expr = " & ".join(intersections) + + if len(unions) > 1 and intersections: + union_expr = f"({union_expr})" + + if len(intersections) > 1 and unions: + 
intersection_expr = f"({intersection_expr})" + + return " | ".join([expr for expr in [union_expr, intersection_expr] if expr]) + + +def _split_unions_and_intersections(selector_str: str) -> t.Tuple[t.List[str], t.List[str]]: + # break space-separated items like: "my_first_model my_second_model" into a list of selectors to union + # and comma-separated items like: "my_first_model,my_second_model" into a list of selectors to intersect + # but, take into account brackets, eg "(my_first_model & my_second_model)" should not be split + # also take into account both types in the same string, eg "my_first_model my_second_model model_3,model_4,model_5" + + def _split_by(input: str, delimiter: str) -> t.Iterator[str]: + buf = "" + depth = 0 + + for char in input: + if char == delimiter and depth <= 0: + # only split on a space if we are not within parenthesis + yield buf + buf = "" + continue + elif char == "(": + depth += 1 + elif char == ")": + depth -= 1 + + buf += char + + if buf: + yield buf + + # first, break up based on spaces + segments = list(_split_by(selector_str, " ")) + + # then, within each segment, identify the unions and intersections + unions = [] + intersections = [] + + for segment in segments: + maybe_intersections = list(_split_by(segment, ",")) + if len(maybe_intersections) > 1: + intersections.extend(maybe_intersections) + else: + unions.append(segment) + + return unions, intersections + + +def _negate(expr: str) -> str: + return f"^{_wrap(expr)}" + + +def _wrap(expr: str) -> str: + already_wrapped = expr.strip().startswith("(") and expr.strip().endswith(")") + + if expr and not already_wrapped: + return f"({expr})" + + return expr diff --git a/tests/.gitignore b/tests/.gitignore new file mode 100644 index 0000000000..fff33c1c74 --- /dev/null +++ b/tests/.gitignore @@ -0,0 +1 @@ +sqlmesh_pyproject.toml \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py index d64b8864ba..2b35b8c9c9 100644 --- a/tests/__init__.py +++ 
b/tests/__init__.py @@ -1,3 +1,5 @@ +from sqlmesh.core import constants as c from sqlmesh.core.analytics import disable_analytics +c.MAX_FORK_WORKERS = 1 disable_analytics() diff --git a/tests/cli/test_cli.py b/tests/cli/test_cli.py index 769518f65f..480d186fa1 100644 --- a/tests/cli/test_cli.py +++ b/tests/cli/test_cli.py @@ -1,41 +1,45 @@ -import logging -from contextlib import contextmanager -from os import path +import json +import os +import pytest +import string +import time_machine +from os import getcwd, path, remove from pathlib import Path +from shutil import rmtree +from unittest.mock import MagicMock -import pytest +from click import ClickException from click.testing import CliRunner -from freezegun import freeze_time - -from sqlmesh.cli.example_project import init_example_project +from sqlmesh import RuntimeEnv +from sqlmesh.cli.project_init import ProjectTemplate, init_example_project from sqlmesh.cli.main import cli +from sqlmesh.core.context import Context +from sqlmesh.integrations.dlt import generate_dlt_models +from sqlmesh.utils.date import now_ds, time_like_to_str, timedelta, to_datetime, yesterday_ds +from sqlmesh.core.config.connection import DIALECT_TO_TYPE -FREEZE_TIME = "2023-01-01 00:00:00" +FREEZE_TIME = "2023-01-01 00:00:00 UTC" pytestmark = pytest.mark.slow -@pytest.fixture(scope="session") -def runner() -> CliRunner: - return CliRunner() +@pytest.fixture(autouse=True) +def mock_runtime_env(monkeypatch): + monkeypatch.setattr("sqlmesh.RuntimeEnv.get", MagicMock(return_value=RuntimeEnv.TERMINAL)) -@contextmanager -def disable_logging(): - logging.disable(logging.CRITICAL) - try: - yield - finally: - logging.disable(logging.NOTSET) +@pytest.fixture(scope="session") +def runner() -> CliRunner: + return CliRunner(env={"COLUMNS": "80"}) -def create_example_project(temp_dir) -> None: +def create_example_project(temp_dir, template=ProjectTemplate.DEFAULT) -> None: """ Sets up CLI tests requiring a real SQLMesh project by: - Creating the 
SQLMesh example project in the temp_dir directory - Overwriting the config.yaml file so the duckdb database file will be created in the temp_dir directory """ - init_example_project(temp_dir, "duckdb") + init_example_project(temp_dir, engine_type="duckdb", template=template) with open(temp_dir / "config.yaml", "w", encoding="utf-8") as f: f.write( f"""gateways: @@ -48,6 +52,9 @@ def create_example_project(temp_dir) -> None: model_defaults: dialect: duckdb + +plan: + no_prompts: false """ ) @@ -113,26 +120,25 @@ def assert_duckdb_test(result) -> None: assert "Successfully Ran 1 tests against duckdb" in result.output -def assert_new_env(result, new_env="prod", from_env="prod") -> None: - assert f"New environment `{new_env}` will be created from `{from_env}`" in result.output - - -def assert_model_versions_created(result) -> None: - assert "All model versions have been created successfully" in result.output +def assert_new_env(result, new_env="prod", from_env="prod", initialize=True) -> None: + assert ( + f"`{new_env}` environment will be initialized" + if initialize + else f"New environment `{new_env}` will be created from `{from_env}`" + ) in result.output def assert_model_batches_executed(result) -> None: - assert "All model batches have been executed successfully" in result.output + assert "Model batches executed" in result.output -def assert_target_env_updated(result) -> None: - assert "The target environment has been updated successfully" in result.output +def assert_virtual_layer_updated(result) -> None: + assert "Virtual layer updated" in result.output def assert_backfill_success(result) -> None: - assert_model_versions_created(result) assert_model_batches_executed(result) - assert_target_env_updated(result) + assert_virtual_layer_updated(result) def assert_plan_success(result, new_env="prod", from_env="prod") -> None: @@ -142,10 +148,6 @@ def assert_plan_success(result, new_env="prod", from_env="prod") -> None: assert_backfill_success(result) -def 
assert_virtual_update(result) -> None: - assert "Virtual Update executed successfully" in result.output - - def test_version(runner, tmp_path): from sqlmesh import __version__ as SQLMESH_VERSION @@ -161,6 +163,7 @@ def test_plan_no_config(runner, tmp_path): assert "Error: SQLMesh project config could not be found" in result.output +@time_machine.travel(FREEZE_TIME) def test_plan(runner, tmp_path): create_example_project(tmp_path) @@ -171,6 +174,9 @@ def test_plan(runner, tmp_path): cli, ["--log-file-dir", tmp_path, "--paths", tmp_path, "plan"], input="y\n" ) assert_plan_success(result) + # 'Models needing backfill' section and eval progress bar should display the same inclusive intervals + assert "sqlmesh_example.incremental_model: [2020-01-01 - 2022-12-31]" in result.output + assert "sqlmesh_example.incremental_model [insert 2020-01-01 - 2022-12-31]" in result.output def test_plan_skip_tests(runner, tmp_path): @@ -187,6 +193,28 @@ def test_plan_skip_tests(runner, tmp_path): assert_backfill_success(result) +def test_plan_skip_linter(runner, tmp_path): + create_example_project(tmp_path) + + with open(tmp_path / "config.yaml", "a", encoding="utf-8") as f: + f.write( + """linter: + enabled: True + rules: "ALL" + """ + ) + + # Input: `y` to apply and backfill + result = runner.invoke( + cli, ["--log-file-dir", tmp_path, "--paths", tmp_path, "plan", "--skip-linter"], input="y\n" + ) + + assert result.exit_code == 0 + assert "Linter warnings" not in result.output + assert_new_env(result) + assert_backfill_success(result) + + def test_plan_restate_model(runner, tmp_path): create_example_project(tmp_path) init_prod_and_backfill(runner, tmp_path) @@ -208,34 +236,31 @@ def test_plan_restate_model(runner, tmp_path): ) assert result.exit_code == 0 assert_duckdb_test(result) - assert "No differences when compared to `prod`" in result.output - assert "sqlmesh_example.full_model evaluated in" in result.output - assert_backfill_success(result) + assert "Models selected for 
restatement" in result.output + assert "sqlmesh_example.full_model [full refresh" in result.output + assert_model_batches_executed(result) + assert "Virtual layer updated" not in result.output -def test_plan_skip_backfill(runner, tmp_path): +@pytest.mark.parametrize("flag", ["--skip-backfill", "--dry-run"]) +def test_plan_skip_backfill(runner, tmp_path, flag): create_example_project(tmp_path) # plan for `prod` errors if `--skip-backfill` is passed without --no-gaps - result = runner.invoke( - cli, ["--log-file-dir", tmp_path, "--paths", tmp_path, "plan", "--skip-backfill"] - ) + result = runner.invoke(cli, ["--log-file-dir", tmp_path, "--paths", tmp_path, "plan", flag]) assert result.exit_code == 1 - assert ( - "Error: When targeting the production environment either the backfill should not be skipped or the lack of data gaps should be enforced (--no-gaps flag)." - in result.output - ) + assert "Skipping the backfill stage for production can lead to unexpected" in result.output # plan executes virtual update without executing model batches # Input: `y` to perform virtual update result = runner.invoke( cli, - ["--log-file-dir", tmp_path, "--paths", tmp_path, "plan", "--skip-backfill", "--no-gaps"], + ["--log-file-dir", tmp_path, "--paths", tmp_path, "plan", flag, "--no-gaps"], input="y\n", ) assert result.exit_code == 0 - assert_virtual_update(result) - assert "All model batches have been executed successfully" not in result.output + assert_virtual_layer_updated(result) + assert "Model batches executed" not in result.output def test_plan_auto_apply(runner, tmp_path): @@ -249,7 +274,7 @@ def test_plan_auto_apply(runner, tmp_path): # confirm verbose output not present assert "sqlmesh_example.seed_model created" not in result.output - assert "sqlmesh_example.seed_model promoted" not in result.output + assert "sqlmesh_example.seed_model updated" not in result.output def test_plan_verbose(runner, tmp_path): @@ -260,8 +285,47 @@ def test_plan_verbose(runner, tmp_path): 
cli, ["--log-file-dir", tmp_path, "--paths", tmp_path, "plan", "--verbose"], input="y\n" ) assert_plan_success(result) - assert "sqlmesh_example.seed_model created" in result.output - assert "sqlmesh_example.seed_model promoted" in result.output + assert "sqlmesh_example.seed_model created" in result.output + assert "sqlmesh_example.full_model created" in result.output + + # confirm virtual layer action labels correct + update_incremental_model(tmp_path) + import os + + os.remove(tmp_path / "models" / "full_model.sql") + + # Input: `y` to apply and backfill + result = runner.invoke( + cli, ["--log-file-dir", tmp_path, "--paths", tmp_path, "plan", "--verbose"], input="y\n" + ) + assert result.exit_code == 0 + assert_backfill_success(result) + assert "sqlmesh_example.incremental_model updated" in result.output + assert "sqlmesh_example.full_model dropped" in result.output + + +def test_plan_very_verbose(runner, tmp_path, copy_to_temp_path): + temp_path = copy_to_temp_path("examples/sushi") + + # Input: `y` to apply and backfill + result = runner.invoke( + cli, + ["--log-file-dir", temp_path[0], "--paths", temp_path[0], "plan", "-v"], + input="y\n", + ) + assert result.exit_code == 0 + # models needing backfill list is still abbreviated with regular VERBOSE, so this should not be present + assert "sushi.customers: [full refresh]" not in result.output + + # Input: `y` to apply and backfill + result = runner.invoke( + cli, + ["--log-file-dir", temp_path[0], "--paths", temp_path[0], "plan", "-vv"], + input="y\n", + ) + assert result.exit_code == 0 + # models needing backfill list is complete with VERY_VERBOSE, so this should be present + assert "sushi.customers: [full refresh]" in result.output def test_plan_dev(runner, tmp_path): @@ -284,8 +348,8 @@ def test_plan_dev_start_date(runner, tmp_path): input="\ny\n", ) assert_plan_success(result, "dev") - assert "sqlmesh_example__dev.full_model: 2023-01-01" in result.output - assert "sqlmesh_example__dev.incremental_model: 
2023-01-01" in result.output + assert "sqlmesh_example__dev.full_model: [full refresh]" in result.output + assert "sqlmesh_example__dev.incremental_model: [2023-01-01" in result.output def test_plan_dev_end_date(runner, tmp_path): @@ -298,11 +362,11 @@ def test_plan_dev_end_date(runner, tmp_path): input="\ny\n", ) assert_plan_success(result, "dev") - assert "sqlmesh_example__dev.full_model: 2020-01-01 - 2023-01-01" in result.output - assert "sqlmesh_example__dev.incremental_model: 2020-01-01 - 2023-01-01" in result.output + assert "sqlmesh_example__dev.full_model: [full refresh]" in result.output + assert "sqlmesh_example__dev.incremental_model: [2020-01-01 - 2023-01-01]" in result.output -def test_plan_dev_create_from(runner, tmp_path): +def test_plan_dev_create_from_virtual(runner, tmp_path): create_example_project(tmp_path) # create dev environment and backfill @@ -337,26 +401,111 @@ def test_plan_dev_create_from(runner, tmp_path): ], input="y\n", ) + assert result.exit_code == 0 + assert_new_env(result, "dev2", "dev", initialize=False) + assert "SKIP: No physical layer updates to perform" in result.output + assert "SKIP: No model batches to execute" in result.output + assert_virtual_layer_updated(result) + + +def test_plan_dev_create_from(runner, tmp_path): + create_example_project(tmp_path) + + # create dev environment and backfill + runner.invoke( + cli, + [ + "--log-file-dir", + tmp_path, + "--paths", + tmp_path, + "plan", + "dev", + "--no-prompts", + "--auto-apply", + ], + ) + # make model change + update_incremental_model(tmp_path) + + # create dev2 environment from dev + result = runner.invoke( + cli, + [ + "--log-file-dir", + tmp_path, + "--paths", + tmp_path, + "plan", + "dev2", + "--create-from", + "dev", + "--no-prompts", + "--auto-apply", + ], + ) + + assert result.exit_code == 0 + assert_new_env(result, "dev2", "dev", initialize=False) + assert "Differences from the `dev` environment:" in result.output + + +def test_plan_dev_bad_create_from(runner, 
tmp_path): + create_example_project(tmp_path) + + # create dev environment and backfill + runner.invoke( + cli, + [ + "--log-file-dir", + tmp_path, + "--paths", + tmp_path, + "plan", + "dev", + "--no-prompts", + "--auto-apply", + ], + ) + # make model change + update_incremental_model(tmp_path) + + # create dev2 environment from non-existent dev3 + result = runner.invoke( + cli, + [ + "--log-file-dir", + tmp_path, + "--paths", + tmp_path, + "plan", + "dev2", + "--create-from", + "dev3", + "--no-prompts", + "--auto-apply", + ], + ) + assert result.exit_code == 0 assert_new_env(result, "dev2", "dev") - assert_model_versions_created(result) - assert_target_env_updated(result) - assert_virtual_update(result) + assert ( + "[WARNING] The environment name 'dev3' was passed to the `plan` command's `--create-from` argument, but 'dev3' does not exist. Initializing new environment 'dev2' from scratch." + in result.output.replace("\n", "") + ) def test_plan_dev_no_prompts(runner, tmp_path): create_example_project(tmp_path) - # plan for non-prod environment doesn't prompt to apply and doesn't - # backfill if only `--no-prompts` is passed + # plan for non-prod environment doesn't prompt for dates but prompts to apply result = runner.invoke( cli, ["--log-file-dir", tmp_path, "--paths", tmp_path, "plan", "dev", "--no-prompts"] ) - assert result.exit_code == 0 - assert "Apply - Backfill Tables [y/n]: " not in result.output - assert "All model versions have been created successfully" not in result.output - assert "All model batches have been executed successfully" not in result.output - assert "The target environment has been updated successfully" not in result.output + assert "Apply - Backfill Tables [y/n]: " in result.output + assert "Physical layer updated" not in result.output + assert "Model batches executed" not in result.output + assert "The target environment has been updated" not in result.output def test_plan_dev_auto_apply(runner, tmp_path): @@ -379,7 +528,7 @@ def 
test_plan_dev_no_changes(runner, tmp_path): result = runner.invoke(cli, ["--log-file-dir", tmp_path, "--paths", tmp_path, "plan", "dev"]) assert result.exit_code == 1 assert ( - "Error: No changes were detected. Make a change or run with --include-unmodified" + "Error: Creating a new environment requires a change, but project files match the `prod` environment. Make a change or use the --include-unmodified flag to create a new environment without changes." in result.output ) @@ -391,9 +540,58 @@ def test_plan_dev_no_changes(runner, tmp_path): input="y\n", ) assert result.exit_code == 0 - assert_new_env(result, "dev") - assert_target_env_updated(result) - assert_virtual_update(result) + assert_new_env(result, "dev", initialize=False) + assert_virtual_layer_updated(result) + + +def test_plan_dev_longnames(runner, tmp_path): + create_example_project(tmp_path) + + long_model_names = { + "full": f"full_{'a' * 80}", + "incremental": f"incremental_{'b' * 80}", + "seed": f"seed_{'c' * 80}", + } + for model_name in long_model_names: + with open(tmp_path / "models" / f"{model_name}_model.sql", "r") as f: + model_text = f.read() + for more_model_names in long_model_names: + model_text = model_text.replace( + f"sqlmesh_example.{more_model_names}_model", + f"sqlmesh_example.{long_model_names[more_model_names]}_model", + ) + with open(tmp_path / "models" / f"{model_name}_model.sql", "w") as f: + f.write(model_text) + + # Input: `y` to apply and backfill + result = runner.invoke( + cli, + [ + "--log-file-dir", + tmp_path, + "--paths", + tmp_path, + "plan", + "dev_butamuchlongerenvironmentname", + "--skip-tests", + "--no-prompts", + "--auto-apply", + ], + ) + assert result.exit_code == 0 + assert ( + "sqlmesh_example__dev_butamuchlongerenvironmentname.seed_cccccccccccccccccccccccc\ncccccccccccccccccccccccccccccccccccccccccccccccccccccccc_model [insert \nseed file]" + in result.output + ) + assert ( + 
"sqlmesh_example__dev_butamuchlongerenvironmentname.incremental_bbbbbbbbbbbbbbbbb\nbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb_model [insert " + in result.output + ) + assert ( + "sqlmesh_example__dev_butamuchlongerenvironmentname.full_aaaaaaaaaaaaaaaaaaaaaaaa\naaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_model [full \nrefresh" + in result.output + ) + assert_backfill_success(result) def test_plan_nonbreaking(runner, tmp_path): @@ -407,12 +605,12 @@ def test_plan_nonbreaking(runner, tmp_path): cli, ["--log-file-dir", tmp_path, "--paths", tmp_path, "plan"], input="y\n" ) assert result.exit_code == 0 - assert "Summary of differences against `prod`" in result.output + assert "Differences from the `prod` environment" in result.output assert "+ 'a' AS new_col" in result.output assert "Directly Modified: sqlmesh_example.incremental_model (Non-breaking)" in result.output assert "sqlmesh_example.full_model (Indirect Non-breaking)" in result.output - assert "sqlmesh_example.incremental_model evaluated in" in result.output - assert "sqlmesh_example.full_model evaluated in" not in result.output + assert "sqlmesh_example.incremental_model [insert" in result.output + assert "sqlmesh_example.full_model [full refresh" not in result.output assert_backfill_success(result) @@ -469,8 +667,8 @@ def test_plan_breaking(runner, tmp_path): assert result.exit_code == 0 assert "+ item_id + 1 AS item_id," in result.output assert "Directly Modified: sqlmesh_example.full_model (Breaking)" in result.output - assert "sqlmesh_example.full_model evaluated in" in result.output - assert "sqlmesh_example.incremental_model evaluated in" not in result.output + assert "sqlmesh_example.full_model [full refresh" in result.output + assert "sqlmesh_example.incremental_model [insert" not in result.output assert_backfill_success(result) @@ -508,8 +706,8 @@ def test_plan_dev_select(runner, tmp_path): assert "+ item_id + 1 AS item_id," not in result.output assert "Directly 
Modified: sqlmesh_example__dev.full_model (Breaking)" not in result.output # only incremental_model backfilled - assert "sqlmesh_example__dev.incremental_model evaluated in" in result.output - assert "sqlmesh_example__dev.full_model evaluated in" not in result.output + assert "sqlmesh_example__dev.incremental_model [insert" in result.output + assert "sqlmesh_example__dev.full_model [full refresh" not in result.output assert_backfill_success(result) @@ -538,7 +736,7 @@ def test_plan_dev_backfill(runner, tmp_path): input="\n\ny\n", ) assert result.exit_code == 0 - assert_new_env(result, "dev") + assert_new_env(result, "dev", initialize=False) # both model diffs present assert "+ item_id + 1 AS item_id," in result.output assert "Directly Modified: sqlmesh_example__dev.full_model (Breaking)" in result.output @@ -547,8 +745,8 @@ def test_plan_dev_backfill(runner, tmp_path): "Directly Modified: sqlmesh_example__dev.incremental_model (Non-breaking)" in result.output ) # only incremental_model backfilled - assert "sqlmesh_example__dev.incremental_model evaluated in" in result.output - assert "sqlmesh_example__dev.full_model evaluated in" not in result.output + assert "sqlmesh_example__dev.incremental_model [insert" in result.output + assert "sqlmesh_example__dev.full_model [full refresh" not in result.output assert_backfill_success(result) @@ -561,15 +759,16 @@ def test_run_no_prod(runner, tmp_path): assert "Error: Environment 'prod' was not found." 
in result.output -@freeze_time(FREEZE_TIME) -def test_run_dev(runner, tmp_path): +@pytest.mark.parametrize("flag", ["--skip-backfill", "--dry-run"]) +@time_machine.travel(FREEZE_TIME) +def test_run_dev(runner, tmp_path, flag): create_example_project(tmp_path) # Create dev environment but DO NOT backfill # Input: `y` for virtual update runner.invoke( cli, - ["--log-file-dir", tmp_path, "--paths", tmp_path, "plan", "dev", "--skip-backfill"], + ["--log-file-dir", tmp_path, "--paths", tmp_path, "plan", "dev", flag], input="y\n", ) @@ -579,27 +778,31 @@ def test_run_dev(runner, tmp_path): assert_model_batches_executed(result) -@freeze_time(FREEZE_TIME) +@time_machine.travel(FREEZE_TIME) def test_run_cron_not_elapsed(runner, tmp_path, caplog): create_example_project(tmp_path) init_prod_and_backfill(runner, tmp_path) - # No error and no output if `prod` environment exists and cron has not elapsed - with disable_logging(): - result = runner.invoke(cli, ["--log-file-dir", tmp_path, "--paths", tmp_path, "run"]) + # No error if `prod` environment exists and cron has not elapsed + result = runner.invoke(cli, ["--log-file-dir", tmp_path, "--paths", tmp_path, "run"]) assert result.exit_code == 0 - assert result.output.strip() == "Run finished for environment 'prod'" + + assert ( + "No models are ready to run. 
Please wait until a model `cron` interval has \nelapsed.\n\nNext run will be ready at " + in result.output.strip() + ) def test_run_cron_elapsed(runner, tmp_path): create_example_project(tmp_path) # Create and backfill `prod` environment - with freeze_time("2023-01-01 23:59:00"): + with time_machine.travel("2023-01-01 23:59:00 UTC", tick=False) as traveler: + runner = CliRunner() init_prod_and_backfill(runner, tmp_path) - # Run `prod` environment with daily cron elapsed - with freeze_time("2023-01-02 00:01:00"): + # Run `prod` environment with daily cron elapsed + traveler.move_to("2023-01-02 00:01:00 UTC") result = runner.invoke(cli, ["--log-file-dir", tmp_path, "--paths", tmp_path, "run"]) assert result.exit_code == 0 @@ -628,17 +831,1411 @@ def test_table_name(runner, tmp_path): # Create and backfill `prod` environment create_example_project(tmp_path) init_prod_and_backfill(runner, tmp_path) - with disable_logging(): - result = runner.invoke( - cli, - [ - "--log-file-dir", - tmp_path, - "--paths", - tmp_path, - "table_name", - "sqlmesh_example.full_model", - ], - ) + result = runner.invoke( + cli, + [ + "--log-file-dir", + tmp_path, + "--paths", + tmp_path, + "table_name", + "sqlmesh_example.full_model", + ], + ) assert result.exit_code == 0 assert result.output.startswith("db.sqlmesh__sqlmesh_example.sqlmesh_example__full_model__") + + +def test_info_on_new_project_does_not_create_state_sync(runner, tmp_path): + create_example_project(tmp_path) + + # Invoke the info command + result = runner.invoke(cli, ["--log-file-dir", tmp_path, "--paths", tmp_path, "info"]) + assert result.exit_code == 0 + + context = Context(paths=tmp_path) + + # Confirm that the state sync tables haven't been created + assert not context.engine_adapter.table_exists("sqlmesh._snapshots") + assert not context.engine_adapter.table_exists("sqlmesh._environments") + assert not context.engine_adapter.table_exists("sqlmesh._intervals") + assert not 
context.engine_adapter.table_exists("sqlmesh._versions") + + +def test_dlt_pipeline_errors(runner, tmp_path): + # Error if no pipeline is provided + result = runner.invoke(cli, ["--paths", tmp_path, "init", "-t", "dlt", "duckdb"]) + assert ( + "Error: Please provide a DLT pipeline with the `--dlt-pipeline` flag to generate a SQLMesh project from DLT" + in result.output + ) + + # Error if the pipeline provided is not correct + result = runner.invoke( + cli, + ["--paths", tmp_path, "init", "-t", "dlt", "--dlt-pipeline", "missing_pipeline", "duckdb"], + ) + assert "Error: Could not attach to pipeline" in result.output + + +@time_machine.travel(FREEZE_TIME) +def test_dlt_filesystem_pipeline(tmp_path): + import dlt + + root_dir = path.abspath(getcwd()) + storage_path = root_dir + "/temp_storage" + if path.exists(storage_path): + rmtree(storage_path) + + filesystem_pipeline = dlt.pipeline( + pipeline_name="filesystem_pipeline", + destination=dlt.destinations.filesystem("file://" + storage_path), + ) + info = filesystem_pipeline.run([{"item_id": 1}], table_name="equipment") + assert not info.has_failed_jobs + + init_example_project( + tmp_path, "athena", template=ProjectTemplate.DLT, pipeline="filesystem_pipeline" + ) + + # Validate generated sqlmesh config and models + config_path = tmp_path / "config.yaml" + equipment_model_path = tmp_path / "models/incremental_equipment.sql" + dlt_loads_model_path = tmp_path / "models/incremental__dlt_loads.sql" + + assert config_path.exists() + assert equipment_model_path.exists() + assert dlt_loads_model_path.exists() + + expected_incremental_model = """MODEL ( + name filesystem_pipeline_dataset_sqlmesh.incremental_equipment, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column _dlt_load_time, + ), +); + +SELECT + CAST(c.item_id AS BIGINT) AS item_id, + CAST(c._dlt_load_id AS VARCHAR) AS _dlt_load_id, + CAST(c._dlt_id AS VARCHAR) AS _dlt_id, + TO_TIMESTAMP(CAST(c._dlt_load_id AS DOUBLE)) as _dlt_load_time +FROM + 
filesystem_pipeline_dataset.equipment as c +WHERE + TO_TIMESTAMP(CAST(c._dlt_load_id AS DOUBLE)) BETWEEN @start_ds AND @end_ds +""" + + with open(equipment_model_path) as file: + incremental_model = file.read() + + assert incremental_model == expected_incremental_model + + expected_config = ( + "# --- Gateway Connection ---\n" + "gateways:\n" + " athena:\n" + " connection:\n" + " # For more information on configuring the connection to your execution engine, visit:\n" + " # https://sqlmesh.readthedocs.io/en/stable/reference/configuration/#connection\n" + " # https://sqlmesh.readthedocs.io/en/stable/integrations/engines/athena/#connection-options\n" + " type: athena\n" + " # concurrent_tasks: 4\n" + " # register_comments: False\n" + " # pre_ping: False\n" + " # pretty_sql: False\n" + " # schema_differ_overrides: \n" + " # catalog_type_overrides: \n" + " # aws_access_key_id: \n" + " # aws_secret_access_key: \n" + " # role_arn: \n" + " # role_session_name: \n" + " # region_name: \n" + " # work_group: \n" + " # s3_staging_dir: \n" + " # schema_name: \n" + " # catalog_name: \n" + " # s3_warehouse_location: \n\n" + "default_gateway: athena\n\n" + "# --- Model Defaults ---\n" + "# https://sqlmesh.readthedocs.io/en/stable/reference/model_configuration/#model-defaults\n\n" + "model_defaults:\n" + " dialect: athena\n" + f" start: {yesterday_ds()} # Start date for backfill history\n" + " cron: '@daily' # Run models daily at 12am UTC (can override per model)\n\n" + "# --- Linting Rules ---\n" + "# Enforce standards for your team\n" + "# https://sqlmesh.readthedocs.io/en/stable/guides/linter/\n\n" + "linter:\n" + " enabled: true\n" + " rules:\n" + " - ambiguousorinvalidcolumn\n" + " - invalidselectstarexpansion\n" + " - noambiguousprojections\n" + ) + + with open(config_path) as file: + config = file.read() + + assert config == expected_config + + if path.exists(storage_path): + rmtree(storage_path) + + +@time_machine.travel(FREEZE_TIME) +def test_dlt_pipeline(runner, tmp_path): 
+ from dlt.common.pipeline import get_dlt_pipelines_dir + + root_dir = path.abspath(getcwd()) + pipeline_path = root_dir + "/examples/sushi_dlt/sushi_pipeline.py" + dataset_path = root_dir + "/sushi.duckdb" + + if path.exists(dataset_path): + remove(dataset_path) + + with open(pipeline_path) as file: + exec(file.read()) + + # This should fail since it won't be able to locate the pipeline in this path + with pytest.raises(ClickException, match=r".*Could not attach to pipeline*"): + init_example_project( + tmp_path, + "duckdb", + template=ProjectTemplate.DLT, + pipeline="sushi", + dlt_path="./dlt2/pipelines", + ) + + # By setting the pipelines path where the pipeline directory is located, it should work + dlt_path = get_dlt_pipelines_dir() + init_example_project( + tmp_path, "duckdb", template=ProjectTemplate.DLT, pipeline="sushi", dlt_path=dlt_path + ) + + expected_config = f"""# --- Gateway Connection --- +gateways: + duckdb: + connection: + type: duckdb + database: {dataset_path} +default_gateway: duckdb + +# --- Model Defaults --- +# https://sqlmesh.readthedocs.io/en/stable/reference/model_configuration/#model-defaults + +model_defaults: + dialect: duckdb + start: {yesterday_ds()} # Start date for backfill history + cron: '@daily' # Run models daily at 12am UTC (can override per model) + +# --- Linting Rules --- +# Enforce standards for your team +# https://sqlmesh.readthedocs.io/en/stable/guides/linter/ + +linter: + enabled: true + rules: + - ambiguousorinvalidcolumn + - invalidselectstarexpansion + - noambiguousprojections +""" + + with open(tmp_path / "config.yaml") as file: + config = file.read() + + expected_incremental_model = """MODEL ( + name sushi_dataset_sqlmesh.incremental_sushi_types, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column _dlt_load_time, + ), + grain (id), +); + +SELECT + CAST(c.id AS BIGINT) AS id, + CAST(c.name AS TEXT) AS name, + CAST(c._dlt_load_id AS TEXT) AS _dlt_load_id, + CAST(c._dlt_id AS TEXT) AS _dlt_id, + 
TO_TIMESTAMP(CAST(c._dlt_load_id AS DOUBLE)) as _dlt_load_time +FROM + sushi_dataset.sushi_types as c +WHERE + TO_TIMESTAMP(CAST(c._dlt_load_id AS DOUBLE)) BETWEEN @start_ds AND @end_ds +""" + + dlt_sushi_types_model_path = tmp_path / "models/incremental_sushi_types.sql" + dlt_loads_model_path = tmp_path / "models/incremental__dlt_loads.sql" + dlt_waiters_model_path = tmp_path / "models/incremental_waiters.sql" + dlt_sushi_fillings_model_path = tmp_path / "models/incremental_sushi_menu__fillings.sql" + dlt_sushi_twice_nested_model_path = ( + tmp_path / "models/incremental_sushi_menu__details__ingredients.sql" + ) + + with open(dlt_sushi_types_model_path) as file: + incremental_model = file.read() + + expected_dlt_loads_model = """MODEL ( + name sushi_dataset_sqlmesh.incremental__dlt_loads, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column _dlt_load_time, + ), +); + +SELECT + CAST(c.load_id AS TEXT) AS load_id, + CAST(c.schema_name AS TEXT) AS schema_name, + CAST(c.status AS BIGINT) AS status, + CAST(c.inserted_at AS TIMESTAMP) AS inserted_at, + CAST(c.schema_version_hash AS TEXT) AS schema_version_hash, + TO_TIMESTAMP(CAST(c.load_id AS DOUBLE)) as _dlt_load_time +FROM + sushi_dataset._dlt_loads as c +WHERE + TO_TIMESTAMP(CAST(c.load_id AS DOUBLE)) BETWEEN @start_ds AND @end_ds +""" + + with open(dlt_loads_model_path) as file: + dlt_loads_model = file.read() + + expected_nested_fillings_model = """MODEL ( + name sushi_dataset_sqlmesh.incremental_sushi_menu__fillings, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column _dlt_load_time, + ), +); + +SELECT + CAST(c.value AS TEXT) AS value, + CAST(c._dlt_root_id AS TEXT) AS _dlt_root_id, + CAST(c._dlt_parent_id AS TEXT) AS _dlt_parent_id, + CAST(c._dlt_list_idx AS BIGINT) AS _dlt_list_idx, + CAST(c._dlt_id AS TEXT) AS _dlt_id, + TO_TIMESTAMP(CAST(p._dlt_load_id AS DOUBLE)) as _dlt_load_time +FROM + sushi_dataset.sushi_menu__fillings as c +JOIN + sushi_dataset.sushi_menu as p +ON + c._dlt_parent_id = p._dlt_id +WHERE + 
TO_TIMESTAMP(CAST(p._dlt_load_id AS DOUBLE)) BETWEEN @start_ds AND @end_ds +""" + + with open(dlt_sushi_fillings_model_path) as file: + nested_model = file.read() + + # Validate generated config and models + assert config == expected_config + assert dlt_loads_model_path.exists() + assert dlt_sushi_types_model_path.exists() + assert dlt_waiters_model_path.exists() + assert dlt_sushi_fillings_model_path.exists() + assert dlt_sushi_twice_nested_model_path.exists() + assert dlt_loads_model == expected_dlt_loads_model + assert incremental_model == expected_incremental_model + assert nested_model == expected_nested_fillings_model + + try: + # Plan prod and backfill + result = runner.invoke( + cli, ["--log-file-dir", tmp_path, "--paths", tmp_path, "plan", "--auto-apply"] + ) + + assert result.exit_code == 0 + assert_backfill_success(result) + + # Remove and update with missing model + remove(dlt_waiters_model_path) + assert not dlt_waiters_model_path.exists() + + # Update with force = False will generate only the missing model + context = Context(paths=tmp_path) + assert generate_dlt_models(context, "sushi", [], False) == [ + "sushi_dataset_sqlmesh.incremental_waiters" + ] + assert dlt_waiters_model_path.exists() + + # Remove all models + remove(dlt_waiters_model_path) + remove(dlt_loads_model_path) + remove(dlt_sushi_types_model_path) + remove(dlt_sushi_fillings_model_path) + remove(dlt_sushi_twice_nested_model_path) + + # Update to generate a specific model: sushi_types. + # Also validate using the dlt_path that the pipelines are located. 
+ assert generate_dlt_models(context, "sushi", ["sushi_types"], False, dlt_path) == [ + "sushi_dataset_sqlmesh.incremental_sushi_types" + ] + + # Only the sushi_types should be generated now + assert not dlt_waiters_model_path.exists() + assert not dlt_loads_model_path.exists() + assert not dlt_sushi_fillings_model_path.exists() + assert not dlt_sushi_twice_nested_model_path.exists() + assert dlt_sushi_types_model_path.exists() + + # Update with force = True will generate all models and overwrite existing ones + generate_dlt_models(context, "sushi", [], True) + assert dlt_loads_model_path.exists() + assert dlt_sushi_types_model_path.exists() + assert dlt_waiters_model_path.exists() + assert dlt_sushi_fillings_model_path.exists() + assert dlt_sushi_twice_nested_model_path.exists() + finally: + remove(dataset_path) + + +@time_machine.travel(FREEZE_TIME) +def test_environments(runner, tmp_path): + create_example_project(tmp_path) + ttl = time_like_to_str(to_datetime(now_ds()) + timedelta(days=7)) + + # create dev environment and backfill + runner.invoke( + cli, + [ + "--log-file-dir", + tmp_path, + "--paths", + tmp_path, + "plan", + "dev", + "--no-prompts", + "--auto-apply", + ], + ) + + result = runner.invoke( + cli, + [ + "--log-file-dir", + tmp_path, + "--paths", + tmp_path, + "environments", + ], + ) + assert result.exit_code == 0 + assert f"Number of SQLMesh environments are: 1\ndev - {ttl}\n" in result.output + + # # create dev2 environment from dev environment + # # Input: `y` to apply and virtual update + runner.invoke( + cli, + [ + "--log-file-dir", + tmp_path, + "--paths", + tmp_path, + "plan", + "dev2", + "--create-from", + "dev", + "--include-unmodified", + ], + input="y\n", + ) + + result = runner.invoke( + cli, + [ + "--log-file-dir", + tmp_path, + "--paths", + tmp_path, + "environments", + ], + ) + assert result.exit_code == 0 + assert f"Number of SQLMesh environments are: 2\ndev - {ttl}\ndev2 - {ttl}\n" in result.output + + # Example project models 
have start dates, so there are no date prompts + # for the `prod` environment. + # Input: `y` to apply and backfill + runner.invoke(cli, ["--log-file-dir", tmp_path, "--paths", tmp_path, "plan"], input="y\n") + result = runner.invoke( + cli, + [ + "--log-file-dir", + tmp_path, + "--paths", + tmp_path, + "environments", + ], + ) + assert result.exit_code == 0 + assert ( + f"Number of SQLMesh environments are: 3\ndev - {ttl}\ndev2 - {ttl}\nprod - No Expiry\n" + in result.output + ) + + +def test_lint(runner, tmp_path): + create_example_project(tmp_path) + + with open(tmp_path / "config.yaml", "a", encoding="utf-8") as f: + f.write( + """linter: + enabled: True + rules: "ALL" +""" + ) + + result = runner.invoke(cli, ["--paths", tmp_path, "lint"]) + assert result.output.count("Linter errors for") == 2 + assert result.exit_code == 1 + + # Test with specific model + result = runner.invoke( + cli, ["--paths", tmp_path, "lint", "--model", "sqlmesh_example.seed_model"] + ) + assert result.output.count("Linter errors for") == 1 + assert result.exit_code == 1 + + # Test with multiple models + result = runner.invoke( + cli, + [ + "--paths", + tmp_path, + "lint", + "--model", + "sqlmesh_example.seed_model", + "--model", + "sqlmesh_example.incremental_model", + ], + ) + assert result.output.count("Linter errors for") == 2 + assert result.exit_code == 1 + + +def test_state_export(runner: CliRunner, tmp_path: Path) -> None: + create_example_project(tmp_path) + + state_export_file = tmp_path / "state_export.json" + + # create some state + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "plan", + "--no-prompts", + "--auto-apply", + ], + ) + assert result.exit_code == 0 + + # export it + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "state", "export", "-o", str(state_export_file), "--no-confirm"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert "Gateway: local" in result.output + assert "Type: duckdb" in result.output + assert 
"Exporting versions" in result.output + assert "Exporting snapshots" in result.output + assert "Exporting environments" in result.output + assert "State exported successfully" in result.output + + assert state_export_file.exists() + assert len(state_export_file.read_text()) > 0 + + +def test_state_export_specific_environments(runner: CliRunner, tmp_path: Path) -> None: + create_example_project(tmp_path) + + state_export_file = tmp_path / "state_export.json" + + # create prod + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "plan", + "--no-prompts", + "--auto-apply", + ], + ) + assert result.exit_code == 0 + + (tmp_path / "models" / "new_model.sql").write_text( + """ + MODEL ( + name sqlmesh_example.new_model, + kind FULL + ); + + SELECT 1; + """ + ) + + # create dev env with new model + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "plan", + "dev", + "--no-prompts", + "--auto-apply", + ], + ) + assert result.exit_code == 0 + + # export non existent env - should fail + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "state", + "export", + "--environment", + "nonexist", + "-o", + str(state_export_file), + "--no-confirm", + ], + catch_exceptions=False, + ) + assert result.exit_code == 1 + assert "No such environment: nonexist" in result.output + + # export dev, should contain original snapshots + new one + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "state", + "export", + "--environment", + "dev", + "-o", + str(state_export_file), + "--no-confirm", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert "Environment: dev" in result.output + assert "State exported successfully" in result.output + + state = json.loads(state_export_file.read_text(encoding="utf8")) + assert len(state["snapshots"]) == 4 + assert any("new_model" in s["name"] for s in state["snapshots"]) + assert len(state["environments"]) == 1 + assert "dev" in state["environments"] + assert "prod" not in 
state["environments"] + + +def test_state_export_local(runner: CliRunner, tmp_path: Path) -> None: + create_example_project(tmp_path) + + state_export_file = tmp_path / "state_export.json" + + # note: we have not plan+applied at all, we are just exporting local state + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "state", + "export", + "--local", + "-o", + str(state_export_file), + "--no-confirm", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert "Exporting local state" in result.output + assert "the resulting file cannot be imported" in result.output + assert "State exported successfully" in result.output + + state = json.loads(state_export_file.read_text(encoding="utf8")) + assert len(state["snapshots"]) == 3 + assert not state["metadata"]["importable"] + assert len(state["environments"]) == 0 + + # test mutually exclusive with --environment + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "state", + "export", + "--environment", + "foo", + "--local", + "-o", + str(state_export_file), + "--no-confirm", + ], + catch_exceptions=False, + ) + assert result.exit_code == 1 + + assert "Cannot specify both --environment and --local" in result.output + + +def test_state_import(runner: CliRunner, tmp_path: Path) -> None: + create_example_project(tmp_path) + + state_export_file = tmp_path / "state_export.json" + + # create some state + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "plan", + "--no-prompts", + "--auto-apply", + ], + ) + assert result.exit_code == 0 + + # export it + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "state", "export", "-o", str(state_export_file), "--no-confirm"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + + # import it back + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "state", "import", "-i", str(state_export_file), "--no-confirm"], + catch_exceptions=False, + ) + assert result.exit_code == 0 + + assert 
"Gateway: local" in result.output + assert "Type: duckdb" in result.output + assert "Importing versions" in result.output + assert "Importing snapshots" in result.output + assert "Importing environments" in result.output + assert "State imported successfully" in result.output + + # plan should have no changes + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "plan", + ], + ) + assert result.exit_code == 0 + assert "No changes to plan" in result.output + + +def test_state_import_replace(runner: CliRunner, tmp_path: Path) -> None: + create_example_project(tmp_path) + + state_export_file = tmp_path / "state_export.json" + + # prod + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "plan", + "--no-prompts", + "--auto-apply", + ], + ) + assert result.exit_code == 0 + + (tmp_path / "models" / "new_model.sql").write_text( + """ + MODEL ( + name sqlmesh_example.new_model, + kind FULL + ); + + SELECT 1; + """ + ) + + # create dev with new model + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "plan", + "dev", + "--no-prompts", + "--auto-apply", + ], + ) + assert result.exit_code == 0 + + # prove both dev and prod exist + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "environments", + ], + ) + assert result.exit_code == 0 + assert "dev -" in result.output + assert "prod -" in result.output + + # export just prod + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "state", + "export", + "--environment", + "prod", + "-o", + str(state_export_file), + "--no-confirm", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + + # import it back with --replace + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "state", + "import", + "-i", + str(state_export_file), + "--replace", + "--no-confirm", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + assert "State imported successfully" in result.output + + # prove only prod exists now + result = 
runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "environments", + ], + ) + assert result.exit_code == 0 + assert "dev -" not in result.output + assert "prod -" in result.output + + +def test_state_import_local(runner: CliRunner, tmp_path: Path) -> None: + create_example_project(tmp_path) + + state_export_file = tmp_path / "state_export.json" + + # local state export + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "state", + "export", + "--local", + "-o", + str(state_export_file), + "--no-confirm", + ], + catch_exceptions=False, + ) + assert result.exit_code == 0 + + # import should fail - local state is not importable + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "state", "import", "-i", str(state_export_file), "--no-confirm"], + catch_exceptions=False, + ) + assert result.exit_code == 1 + assert "State file is marked as not importable" in result.output + assert "Aborting" in result.output + + +def test_ignore_warnings(runner: CliRunner, tmp_path: Path) -> None: + create_example_project(tmp_path) + + # Add non-blocking audit to generate WARNING + with open(tmp_path / "models" / "full_model.sql", "w", encoding="utf-8") as f: + f.write(""" +MODEL ( + name sqlmesh_example.full_model, + kind FULL, + cron '@daily', + grain item_id, + audits (full_nonblocking_audit), +); + +SELECT + item_id, + COUNT(DISTINCT id) AS num_orders, +FROM + sqlmesh_example.incremental_model +GROUP BY item_id; + +AUDIT ( + name full_nonblocking_audit, + blocking false, +); +select 1 as a; +""") + + audit_warning = "[WARNING] sqlmesh_example.full_model: 'full_nonblocking_audit' audit error: " + + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "plan", "--no-prompts", "--auto-apply", "--skip-tests"], + ) + assert result.exit_code == 0 + assert audit_warning in result.output + + result = runner.invoke( + cli, + [ + "--ignore-warnings", + "--paths", + str(tmp_path), + "plan", + "--no-prompts", + "--auto-apply", + "--skip-tests", + ], + ) + 
assert result.exit_code == 0 + assert audit_warning not in result.output + + +def test_table_diff_schema_diff_ignore_case(runner: CliRunner, tmp_path: Path): + from sqlmesh.core.engine_adapter import DuckDBEngineAdapter + + create_example_project(tmp_path) + + ctx = Context(paths=tmp_path) + assert isinstance(ctx.engine_adapter, DuckDBEngineAdapter) + + ctx.engine_adapter.execute('create table t1 (id int, "naME" varchar)') + ctx.engine_adapter.execute('create table t2 (id int, "name" varchar)') + + # default behavior (case sensitive) + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "table_diff", "t1:t2", "-o", "id"], + ) + assert result.exit_code == 0 + stripped_output = "".join((x for x in result.output if x in string.printable)) + assert "Added Columns:\n name (TEXT)" in stripped_output + assert "Removed Columns:\n naME (TEXT)" in stripped_output + + # ignore case + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "table_diff", "t1:t2", "-o", "id", "--schema-diff-ignore-case"], + ) + assert result.exit_code == 0 + stripped_output = "".join((x for x in result.output if x in string.printable)) + assert "Schema Diff Between 'T1' and 'T2':\n Schemas match" in stripped_output + + +# passing an invalid engine_type errors +def test_init_bad_engine_type(runner: CliRunner, tmp_path: Path): + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "init", "invalid"], + ) + assert result.exit_code == 1 + assert "Invalid engine 'invalid'. Please specify one of " in result.output + + +# passing an invalid template errors +def test_init_bad_template(runner: CliRunner, tmp_path: Path): + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "init", "-t", "invalid_template"], + ) + assert result.exit_code == 1 + assert "Invalid project template 'invalid_template'. 
Please specify one of " in result.output + + +# empty template should not produce example project files +def test_init_empty_template(runner: CliRunner, tmp_path: Path): + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "init", "duckdb", "-t", "empty"], + ) + assert result.exit_code == 0 + + # Directories should exist, but example project files should not. + assert (tmp_path / "models").exists() + assert not (tmp_path / "models" / "full_model.sql").exists() + assert not (tmp_path / "models" / "incremental_model.sql").exists() + assert not (tmp_path / "seeds" / "seed_data.csv").exists() + + +# interactive init begins when no engine_type is provided and template is not dbt +def test_init_interactive_start(runner: CliRunner, tmp_path: Path): + # Input: 1 (DEFAULT template), 1 (duckdb engine), 1 (DEFAULT CLI mode) + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "init"], + input="1\n1\n1\n", + ) + assert result.exit_code == 0 + assert "Choose your SQL engine" in result.output + + # dbt template passed, so no interactive + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "init", "-t", "dbt"], + ) + assert "Choose your SQL engine" not in result.output + + +# passing an invalid integer response displays error +def test_init_interactive_invalid_int(runner: CliRunner, tmp_path: Path): + # First response is invalid (0) followed by valid selections. 
+ # Input: 0 (invalid), 1 (DEFAULT template), 1 (duckdb engine), 1 (DEFAULT CLI mode) + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "init"], + input="0\n1\n1\n1\n", + ) + assert result.exit_code == 0 + assert ( + "'0' is not a valid project type number - please enter a number between 1" in result.output + ) + + +# interactive init template step should not appear if a template is passed +def test_init_interactive_template_passed(runner: CliRunner, tmp_path: Path): + # Input: 1 (duckdb engine), 1 (DEFAULT CLI mode) + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "init", "-t", "empty"], + input="1\n1\n", + ) + assert result.exit_code == 0 + assert "What type of project do you want to set up?" not in result.output + + +def test_init_interactive_cli_mode_default(runner: CliRunner, tmp_path: Path): + # Input: 1 (DEFAULT template), 1 (duckdb engine), 1 (DEFAULT CLI mode) + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "init"], + input="1\n1\n1\n", + ) + assert result.exit_code == 0 + + config_path = tmp_path / "config.yaml" + assert config_path.exists() + assert "no_diff: true" not in config_path.read_text() + + +def test_init_interactive_cli_mode_simple(runner: CliRunner, tmp_path: Path): + # Input: 1 (DEFAULT template), 1 (duckdb engine), 2 (SIMPLE CLI mode) + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "init"], + input="1\n1\n2\n", + ) + assert result.exit_code == 0 + + config_path = tmp_path / "config.yaml" + assert config_path.exists() + assert "no_diff: true" in config_path.read_text() + + +def test_init_interactive_engine_install_msg(runner: CliRunner, tmp_path: Path, monkeypatch): + monkeypatch.setattr("sqlmesh.utils.rich.console.width", 80) + + # Engine install text should not appear for built-in engines like DuckDB + # Input: 1 (DEFAULT template), 1 (duckdb engine), 1 (DEFAULT CLI mode) + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "init"], + input="1\n1\n1\n", + ) + assert 
result.exit_code == 0 + assert "Run command in CLI to install your SQL engine" not in result.output + + remove(tmp_path / "config.yaml") + + # Input: 1 (DEFAULT template), 13 (gcp postgres engine), 1 (DEFAULT CLI mode) + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "init"], + input="1\n13\n1\n", + ) + assert result.exit_code == 0 + assert ( + 'Run command in CLI to install your SQL engine\'s Python dependencies: pip \ninstall "sqlmesh[gcppostgres]"' + in result.output + ) + + +# dbt template without dbt_project.yml in directory should error +def test_init_dbt_template_no_dbt_project(runner: CliRunner, tmp_path: Path): + # template passed to init + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "init", "-t", "dbt"], + ) + assert result.exit_code == 1 + assert ( + "Required dbt project file 'dbt_project.yml' not found in the current directory." + in result.output + ) + + # interactive init + # Input: 2 (dbt template) + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "init"], + input="2\n", + ) + assert result.exit_code == 1 + assert ( + "Required dbt project file 'dbt_project.yml' not found in the current directory." 
+ in result.output + ) + + +def test_init_dbt_template(runner: CliRunner, tmp_path: Path): + Path(tmp_path / "dbt_project.yml").touch() + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "init"], + input="2\n", + ) + assert result.exit_code == 0 + + config_path = tmp_path / "sqlmesh.yaml" + assert config_path.exists() + + config = config_path.read_text() + + assert "model_defaults" in config + assert "start:" in config + + +@time_machine.travel(FREEZE_TIME) +def test_init_project_engine_configs(tmp_path): + engine_type_to_config = { + "redshift": "# concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # user: \n # password: \n # database: \n # host: \n # port: \n # source_address: \n # unix_sock: \n # ssl: \n # sslmode: \n # timeout: \n # tcp_keepalive: \n # application_name: \n # preferred_role: \n # principal_arn: \n # credentials_provider: \n # region: \n # cluster_identifier: \n # iam: \n # is_serverless: \n # serverless_acct_id: \n # serverless_work_group: \n # enable_merge: ", + "bigquery": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # method: oauth\n # project: \n # execution_project: \n # quota_project: \n # location: \n # keyfile: \n # keyfile_json: \n # token: \n # refresh_token: \n # client_id: \n # client_secret: \n # token_uri: \n # scopes: \n # impersonated_service_account: \n # job_creation_timeout_seconds: \n # job_execution_timeout_seconds: \n # job_retries: 1\n # job_retry_deadline_seconds: \n # priority: \n # maximum_bytes_billed: ", + "snowflake": "account: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # user: \n # password: \n # warehouse: \n # database: \n # role: \n # authenticator: \n # token: \n # host: \n # port: \n 
# application: Tobiko_SQLMesh\n # private_key: \n # private_key_path: \n # private_key_passphrase: \n # session_parameters: ", + "databricks": "# concurrent_tasks: 1\n # register_comments: True\n # pre_ping: False\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # server_hostname: \n # http_path: \n # access_token: \n # auth_type: \n # oauth_client_id: \n # oauth_client_secret: \n # catalog: \n # http_headers: \n # session_configuration: \n # databricks_connect_server_hostname: \n # databricks_connect_access_token: \n # databricks_connect_cluster_id: \n # databricks_connect_use_serverless: False\n # force_databricks_connect: False\n # disable_databricks_connect: False\n # disable_spark_session: False", + "postgres": "host: \n user: \n password: \n port: \n database: \n # concurrent_tasks: 4\n # register_comments: True\n # pre_ping: True\n # pretty_sql: False\n # schema_differ_overrides: \n # catalog_type_overrides: \n # keepalives_idle: \n # connect_timeout: 10\n # role: \n # sslmode: \n # application_name: ", + } + + for engine_type, expected_config in engine_type_to_config.items(): + init_example_project(tmp_path, engine_type=engine_type) + + config_start = f"# --- Gateway Connection ---\ngateways:\n {engine_type}:\n connection:\n # For more information on configuring the connection to your execution engine, visit:\n # https://sqlmesh.readthedocs.io/en/stable/reference/configuration/#connection\n # https://sqlmesh.readthedocs.io/en/stable/integrations/engines/{engine_type}/#connection-options\n type: {engine_type}\n " + config_end = f""" + +default_gateway: {engine_type} + +# --- Model Defaults --- +# https://sqlmesh.readthedocs.io/en/stable/reference/model_configuration/#model-defaults + +model_defaults: + dialect: {DIALECT_TO_TYPE.get(engine_type)} + start: {yesterday_ds()} # Start date for backfill history + cron: '@daily' # Run models daily at 12am UTC (can override per model) + +# --- Linting Rules --- +# Enforce standards 
for your team +# https://sqlmesh.readthedocs.io/en/stable/guides/linter/ + +linter: + enabled: true + rules: + - ambiguousorinvalidcolumn + - invalidselectstarexpansion + - noambiguousprojections +""" + + with open(tmp_path / "config.yaml") as file: + config = file.read() + + assert config == f"{config_start}{expected_config}{config_end}" + + remove(tmp_path / "config.yaml") + + +def test_render(runner: CliRunner, tmp_path: Path): + create_example_project(tmp_path) + + ctx = Context(paths=tmp_path) + + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "render", + "sqlmesh_example.full_model", + "--max-text-width", + "10", + ], + ) + assert result.exit_code == 0 + + cleaned_output = "\n".join(l.rstrip(" ") for l in result.output.split("\n")) + expected = """SELECT + "incremental_model"."item_id" AS "item_id", + COUNT( + DISTINCT "incremental_model"."id" + ) AS "num_orders" +FROM "db"."sqlmesh_example"."incremental_model" AS "incremental_model" +GROUP BY + "incremental_model"."item_id" +""" + + assert expected in cleaned_output + + +@time_machine.travel(FREEZE_TIME) +def test_signals(runner: CliRunner, tmp_path: Path): + create_example_project(tmp_path, template=ProjectTemplate.EMPTY) + + # Create signals module + signals_dir = tmp_path / "signals" + signals_dir.mkdir(exist_ok=True) + + # Create signal definitions + (signals_dir / "signal.py").write_text( + """from sqlmesh import signal +@signal() +def only_first_two_ready(batch): + if len(batch) > 2: + return batch[:2] + return batch + +@signal() +def none_ready(batch): + return False +""" + ) + + # Create model with signals + (tmp_path / "models" / "model_with_signals.sql").write_text( + """MODEL ( + name sqlmesh_example.model_with_signals, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds + ), + start '2022-12-28', + cron '@daily', + signals [ + only_first_two_ready() + ] +); + +SELECT + ds::DATE as ds, + 'test' as value +FROM VALUES + ('2022-12-28'), + ('2022-12-29'), + ('2022-12-30'), + 
('2022-12-31'), + ('2023-01-01') +AS t(ds) +WHERE ds::DATE BETWEEN @start_ds AND @end_ds +""" + ) + + # Create model with no ready intervals + (tmp_path / "models" / "model_with_unready.sql").write_text( + """MODEL ( + name sqlmesh_example.model_with_unready, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds + ), + start '2022-12-28', + cron '@daily', + signals [ + none_ready() + ] +); + +SELECT + ds::DATE as ds, + 'unready' as value +FROM VALUES + ('2022-12-28'), + ('2022-12-29'), + ('2022-12-30'), + ('2022-12-31'), + ('2023-01-01') +AS t(ds) +WHERE ds::DATE BETWEEN @start_ds AND @end_ds +""" + ) + + # Test 1: Normal plan flow with --no-prompts --auto-apply + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "plan", + "--no-prompts", + "--auto-apply", + ], + ) + assert result.exit_code == 0 + + assert "Checking signals for sqlmesh_example.model_with_signals" in result.output + assert "[1/1] only_first_two_ready" in result.output + assert "Check: 2022-12-28 - 2022-12-31" in result.output + assert "Some ready: 2022-12-28 - 2022-12-29" in result.output + + assert "Checking signals for sqlmesh_example.model_with_unready" in result.output + assert "[1/1] none_ready" in result.output + assert "None ready: no intervals" in result.output + + # Test 2: Run command with start and end dates + result = runner.invoke( + cli, + [ + "--paths", + str(tmp_path), + "run", + "--start", + "2022-12-29", + "--end", + "2022-12-31", + ], + ) + assert result.exit_code == 0 + + assert "Checking signals for sqlmesh_example.model_with_signals" in result.output + assert "[1/1] only_first_two_ready" in result.output + assert "Check: 2022-12-30 - 2022-12-31" in result.output + assert "All ready: 2022-12-30 - 2022-12-31" in result.output + + assert "Checking signals for sqlmesh_example.model_with_unready" in result.output + assert "[1/1] none_ready" in result.output + assert "Check: 2022-12-29 - 2022-12-31" in result.output + assert "None ready: no intervals" in result.output 
+ + # Only one model was executed + assert "100.0% • 1/1 • 0:00:00" in result.output + + rmtree(tmp_path) + tmp_path.mkdir(parents=True, exist_ok=True) + + create_example_project(tmp_path) + + # Example project models have start dates, so there are no date prompts + # for the `prod` environment. + # Input: `y` to apply and backfill + result = runner.invoke( + cli, ["--log-file-dir", str(tmp_path), "--paths", str(tmp_path), "plan"], input="y\n" + ) + assert_plan_success(result) + + assert "Checking signals" not in result.output + + +@pytest.mark.isolated +@time_machine.travel(FREEZE_TIME) +def test_format_leading_comma_default(runner: CliRunner, tmp_path: Path): + """Test that format command respects leading_comma environment variable.""" + create_example_project(tmp_path, template=ProjectTemplate.EMPTY) + + # Create a SQL file with trailing comma format + test_sql = tmp_path / "models" / "test_format.sql" + test_sql.write_text("""MODEL ( + name sqlmesh_example.test_format, + kind FULL +); + +SELECT + col1, + col2, + col3 +FROM table1""") + + # Test 1: Default behavior (no env var set) - should not change the file + result = runner.invoke(cli, ["--paths", str(tmp_path), "format", "--check"]) + assert result.exit_code == 0 + + # Test 2: Set env var to true - should require reformatting to leading comma + os.environ["SQLMESH__FORMAT__LEADING_COMMA"] = "true" + try: + result = runner.invoke(cli, ["--paths", str(tmp_path), "format", "--check"]) + # Should exit with 1 because formatting is needed + assert result.exit_code == 1 + + # Actually format the file + result = runner.invoke(cli, ["--paths", str(tmp_path), "format"]) + assert result.exit_code == 0 + + # Check that the file now has leading commas + formatted_content = test_sql.read_text() + assert ", col2" in formatted_content + assert ", col3" in formatted_content + + # Now check should pass + result = runner.invoke(cli, ["--paths", str(tmp_path), "format", "--check"]) + assert result.exit_code == 0 + finally: + # 
Clean up env var + del os.environ["SQLMESH__FORMAT__LEADING_COMMA"] + + # Test 3: Explicit command line flag overrides env var + os.environ["SQLMESH__FORMAT__LEADING_COMMA"] = "false" + try: + # Write file with leading commas + test_sql.write_text("""MODEL ( + name sqlmesh_example.test_format, + kind FULL +); + +SELECT + col1 + , col2 + , col3 +FROM table1""") + + # Check with --leading-comma flag (should pass) + result = runner.invoke( + cli, + ["--paths", str(tmp_path), "format", "--check", "--leading-comma"], + ) + assert result.exit_code == 0 + finally: + del os.environ["SQLMESH__FORMAT__LEADING_COMMA"] diff --git a/tests/cli/test_integration_cli.py b/tests/cli/test_integration_cli.py new file mode 100644 index 0000000000..5d000b9d8b --- /dev/null +++ b/tests/cli/test_integration_cli.py @@ -0,0 +1,376 @@ +import typing as t +from pathlib import Path +import pytest +import subprocess +from sqlmesh.cli.project_init import init_example_project +from sqlmesh.utils import yaml +import shutil +import site +import uuid + +pytestmark = pytest.mark.slow + + +class InvokeCliType(t.Protocol): + def __call__( + self, sqlmesh_args: t.List[str], **kwargs: t.Any + ) -> subprocess.CompletedProcess: ... + + +class CreateSitePackageType(t.Protocol): + def __call__(self, name: str) -> t.Tuple[str, Path]: ... 
+ + +@pytest.fixture +def invoke_cli(tmp_path: Path) -> InvokeCliType: + # Fetch the full path to the SQLMesh binary so that when we use `cwd` to run in the context of a test dir, the correct SQLMesh binary is executed + # this will be the current project because `make install-dev` installs an editable version of SQLMesh into the current python environment + sqlmesh_bin = subprocess.run( + ["which", "sqlmesh"], capture_output=True, text=True + ).stdout.strip() + + def _invoke(sqlmesh_args: t.List[str], **kwargs: t.Any) -> subprocess.CompletedProcess: + return subprocess.run( + args=[sqlmesh_bin] + sqlmesh_args, + # set the working directory to the isolated temp dir for this test + cwd=tmp_path, + # return text instead of binary from the output streams + text=True, + # combine stdout/stderr into a single stream + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + **kwargs, + ) + + return _invoke + + +@pytest.fixture +def duckdb_example_project(tmp_path: Path) -> Path: + init_example_project(tmp_path, engine_type="duckdb") + config_path = tmp_path / "config.yaml" + + # we need state to persist between invocations + config_dict = yaml.load(config_path) + config_dict["gateways"]["duckdb"]["state_connection"] = { + "type": "duckdb", + "database": str(tmp_path / "state.db"), + } + config_path.write_text(yaml.dump(config_dict)) + + return tmp_path + + +@pytest.fixture +def last_log_file_contents(tmp_path: Path) -> t.Callable[[], str]: + def _fetch() -> str: + log_file = sorted(list((tmp_path / "logs").iterdir()))[-1] + return log_file.read_text() + + return _fetch + + +@pytest.fixture +def create_site_package() -> t.Iterator[CreateSitePackageType]: + created_package_path = None + + def _create(name: str) -> t.Tuple[str, Path]: + nonlocal created_package_path + + unique_id = str(uuid.uuid4())[0:8] + package_name = f"{name}_{unique_id}" # so that multiple tests using the same name dont clobber each other + + site_packages = site.getsitepackages()[0] + package_path = 
Path(site_packages) / package_name + package_path.mkdir() + + created_package_path = package_path + + return package_name, package_path + + yield _create + + if created_package_path: + # cleanup + shutil.rmtree(created_package_path, ignore_errors=True) + + +def test_load_snapshots_that_reference_nonexistent_python_libraries( + invoke_cli: InvokeCliType, + duckdb_example_project: Path, + last_log_file_contents: t.Callable[[], str], + create_site_package: CreateSitePackageType, +) -> None: + """ + Scenario: + - A model is created using a macro that is imported from an external package + - That model is applied + snapshot committed to state + - The external package is removed locally and the import macro import is changed to an inline definition + + Outcome: + - `sqlmesh plan` should not exit with an ImportError when it tries to render the query of the snapshot stored in state + - Instead, it should log a warning and proceed with applying the new model version + """ + + project_path = duckdb_example_project + + # simulate a 3rd party library that provides a macro + package_name, package_path = create_site_package("sqlmesh_test_macros") + (package_path / "macros.py").write_text(""" +from sqlmesh import macro + +@macro() +def do_something(evaluator): + return "'value from site-packages'" +""") + + # reference the macro from site-packages + (project_path / "macros" / "__init__.py").write_text(f""" +from {package_name}.macros import do_something +""") + + (project_path / "models" / "example.sql").write_text(""" +MODEL ( + name example.test_model, + kind FULL +); + +select @do_something() as a +""") + + result = invoke_cli(["plan", "--no-prompts", "--auto-apply", "--skip-tests"]) + + assert result.returncode == 0 + assert "Virtual layer updated" in result.stdout + + # render the query to ensure our macro is being invoked + result = invoke_cli(["render", "example.test_model"]) + assert result.returncode == 0 + assert """SELECT 'value from site-packages' AS "a\"""" in " 
".join(result.stdout.split()) + + # clear cache to ensure we are forced to reload everything + assert invoke_cli(["clean"]).returncode == 0 + + # deleting this removes the 'do_something()' macro used by the version of the snapshot stored in state + # when loading the old snapshot from state in the local python env, this will create an ImportError + shutil.rmtree(package_path) + + # Move the macro inline so its no longer being loaded from a library but still exists with the same signature + (project_path / "macros" / "__init__.py").write_text(""" +from sqlmesh import macro + +@macro() +def do_something(evaluator): + return "'some value not from site-packages'" +""") + + # this should produce an error but not a fatal one. there will be an error rendering the optimized query of the old snapshot, which should be logged + result = invoke_cli( + [ + "plan", + "--no-prompts", + "--auto-apply", + "--skip-tests", + ] + ) + assert result.returncode == 0 + assert "Virtual layer updated" in result.stdout + + log_file_contents = last_log_file_contents() + assert f"ModuleNotFoundError: No module named '{package_name}'" in log_file_contents + assert ( + "ERROR - Failed to cache optimized query for model 'example.test_model'" + in log_file_contents + ) + assert ( + 'ERROR - Failed to cache snapshot SnapshotId<"db"."example"."test_model"' + in log_file_contents + ) + + +def test_model_selector_snapshot_references_nonexistent_python_libraries( + invoke_cli: InvokeCliType, + duckdb_example_project: Path, + last_log_file_contents: t.Callable[[], str], + create_site_package: CreateSitePackageType, +) -> None: + """ + Scenario: + - A model is created using a macro that is imported from an external package + - That model is applied + snapshot committed to state + - The external package is removed locally and the import macro import is changed to an inline definition + - Thus, local version of the model can be rendered but the remote version in state cannot + + Outcome: + - `sqlmesh plan 
--select-model ` should work as it picks up the local version + - `sqlmesh plan --select-model should exit with an error, because the plan needs a valid DAG and the remote version is no longer valid locally + """ + project_path = duckdb_example_project + + # simulate a 3rd party library that provides a macro + package_name, package_path = create_site_package("sqlmesh_test_macros") + (package_path / "macros.py").write_text(""" +from sqlmesh import macro + +@macro() +def do_something(evaluator): + return "'value from site-packages'" +""") + + # reference the macro from site-packages + (project_path / "macros" / "__init__.py").write_text(f""" +from {package_name}.macros import do_something +""") + + (project_path / "models" / "example.sql").write_text(""" +MODEL ( + name sqlmesh_example.test_model, + kind FULL +); + +select @do_something() as a +""") + + result = invoke_cli(["plan", "--no-prompts", "--auto-apply", "--skip-tests"]) + + assert result.returncode == 0 + assert "Virtual layer updated" in result.stdout + + # clear cache to ensure we are forced to reload everything + assert invoke_cli(["clean"]).returncode == 0 + + # deleting this removes the 'do_something()' macro used by the version of the snapshot stored in state + # when loading the old snapshot from state in the local python env, this will create an ImportError + shutil.rmtree(package_path) + + # Move the macro inline so its no longer being loaded from a library but still exists with the same signature + (project_path / "macros" / "__init__.py").write_text(""" +from sqlmesh import macro + +@macro() +def do_something(evaluator): + return "'some value not from site-packages'" +""") + + # the invalid snapshot is in state but is not preventing a plan + result = invoke_cli( + [ + "plan", + "--no-prompts", + "--skip-tests", + ], + input="n", # for the apply backfill (y/n) prompt + ) + assert result.returncode == 0 + assert "Apply - Backfill Tables [y/n]:" in result.stdout + + # the invalid snapshot in state 
should not prevent a plan if --select-model is used on it (since the local version can be rendered) + result = invoke_cli( + ["plan", "--select-model", "sqlmesh_example.test_model", "--no-prompts", "--skip-tests"], + input="n", # for the apply backfill (y/n) prompt + ) + assert result.returncode == 0 + assert "ModuleNotFoundError" not in result.stdout + assert "sqlmesh_example.test_model" in result.stdout + assert "Apply - Backfill Tables" in result.stdout + + # the invalid snapshot in state should prevent a plan if --select-model is used on another model + # (since this says to SQLMesh "source everything from state except this selected model" and the plan DAG must be valid to run the plan) + result = invoke_cli( + [ + "plan", + "--select-model", + "sqlmesh_example.full_model", + "--no-prompts", + "--skip-tests", + ], + input="n", # for the apply backfill (y/n) prompt + ) + assert result.returncode == 1 + assert ( + "Model 'sqlmesh_example.test_model' sourced from state cannot be rendered in the local environment" + in result.stdout + ) + assert f"No module named '{package_name}'" in result.stdout + assert ( + "If the model has been fixed locally, please ensure that the --select-model expression includes it" + in result.stdout + ) + + # verify the full stack trace was logged + log_file_contents = last_log_file_contents() + assert f"ModuleNotFoundError: No module named '{package_name}'" in log_file_contents + assert ( + "The above exception was the direct cause of the following exception:" in log_file_contents + ) + + +def test_model_selector_tags_picks_up_both_remote_and_local( + invoke_cli: InvokeCliType, duckdb_example_project: Path +) -> None: + """ + Scenario: + - A model that has already been applied to prod (so exists in state) has a tag added locally + - A new model is created locally that has the same tag + + Outcome: + - `sqlmesh plan --select-model tag:` should include both models + """ + project_path = duckdb_example_project + + # default state of 
full_model + (project_path / "models" / "full_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.full_model, + kind FULL, + cron '@daily', + grain item_id, + audits (assert_positive_order_ids), + ); + + SELECT + item_id, + COUNT(DISTINCT id) AS num_orders + FROM sqlmesh_example.incremental_model + GROUP BY item_id + """) + + # apply plan - starting point + result = invoke_cli(["plan", "--no-prompts", "--auto-apply", "--skip-tests"]) + + assert result.returncode == 0 + assert "Virtual layer updated" in result.stdout + + # add a new model locally with tag:a + (project_path / "models" / "new_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.new_model, + kind full, + tags (a) + ); + + SELECT 1; + """) + + # update full_model with tag:a + (project_path / "models" / "full_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.full_model, + kind FULL, + tags (a) + ); + + SELECT + item_id, + COUNT(DISTINCT id) AS num_orders + FROM sqlmesh_example.incremental_model + GROUP BY item_id + """) + + result = invoke_cli( + ["plan", "--select-model", "tag:a", "--no-prompts", "--skip-tests"], + input="n", # for the apply backfill (y/n) prompt + ) + assert result.returncode == 0 + assert "sqlmesh_example.full_model" in result.stdout # metadata update: tags + assert "sqlmesh_example.new_model" in result.stdout # added diff --git a/tests/cli/test_project_init.py b/tests/cli/test_project_init.py new file mode 100644 index 0000000000..12b42705e1 --- /dev/null +++ b/tests/cli/test_project_init.py @@ -0,0 +1,33 @@ +import pytest +from pathlib import Path +from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.cli.project_init import init_example_project, ProjectTemplate +from sqlmesh.utils import yaml +from sqlmesh.core.context import Context +from sqlmesh.core.config.common import VirtualEnvironmentMode + + +def test_project_init_dbt(tmp_path: Path): + assert not len(list(tmp_path.glob("**/*"))) + + with pytest.raises(SQLMeshError, match=r"Required dbt 
project file.*not found"): + init_example_project(path=tmp_path, engine_type=None, template=ProjectTemplate.DBT) + + with (tmp_path / "dbt_project.yml").open("w") as f: + yaml.dump({"name": "jaffle_shop"}, f) + + init_example_project(path=tmp_path, engine_type=None, template=ProjectTemplate.DBT) + files = [f for f in tmp_path.glob("**/*") if f.is_file()] + + assert set([f.name for f in files]) == set(["sqlmesh.yaml", "dbt_project.yml"]) + + sqlmesh_config = next(f for f in files if f.name == "sqlmesh.yaml") + assert "model_defaults" in sqlmesh_config.read_text() + assert "start: " in sqlmesh_config.read_text() + + with (tmp_path / "profiles.yml").open("w") as f: + yaml.dump({"jaffle_shop": {"target": "dev", "outputs": {"dev": {"type": "duckdb"}}}}, f) + + ctx = Context(paths=tmp_path) + assert ctx.config.model_defaults.start + assert ctx.config.virtual_environment_mode == VirtualEnvironmentMode.DEV_ONLY diff --git a/tests/common_fixtures.py b/tests/common_fixtures.py deleted file mode 100644 index 91d09d470f..0000000000 --- a/tests/common_fixtures.py +++ /dev/null @@ -1,9 +0,0 @@ -import pytest -from pytest_mock.plugin import MockerFixture - -from sqlmesh.schedulers.airflow.client import AirflowClient - - -@pytest.fixture(scope="function") -def mock_airflow_client(mocker: MockerFixture) -> AirflowClient: - return AirflowClient(airflow_url="", session=mocker.Mock()) diff --git a/tests/conftest.py b/tests/conftest.py index d543a904e9..b18271465d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,17 +1,21 @@ from __future__ import annotations + import datetime import logging import typing as t import uuid +from contextlib import nullcontext from pathlib import Path from shutil import copytree, rmtree from tempfile import TemporaryDirectory from unittest import mock from unittest.mock import PropertyMock +import os +import shutil -import duckdb -import pandas as pd +import duckdb # noqa: TID253 +import pandas as pd # noqa: TID253 import pytest from 
pytest_mock.plugin import MockerFixture from sqlglot import exp, maybe_parse, parse_one @@ -19,16 +23,19 @@ from sqlglot.helper import ensure_list from sqlglot.optimizer.normalize_identifiers import normalize_identifiers -from sqlmesh.core.config import DuckDBConnectionConfig +from sqlmesh.core.config import Config, BaseDuckDBConnectionConfig, DuckDBConnectionConfig +from sqlmesh.core.config.connection import ConnectionConfig from sqlmesh.core.context import Context -from sqlmesh.core.engine_adapter import SparkEngineAdapter +from sqlmesh.core.engine_adapter import MSSQLEngineAdapter, SparkEngineAdapter from sqlmesh.core.engine_adapter.base import EngineAdapter from sqlmesh.core.environment import EnvironmentNamingInfo +from sqlmesh.core import lineage from sqlmesh.core.macros import macro from sqlmesh.core.model import IncrementalByTimeRangeKind, SqlModel, model -from sqlmesh.core.model.kind import OnDestructiveChange -from sqlmesh.core.plan import BuiltInPlanEvaluator, Plan +from sqlmesh.core.model.kind import OnDestructiveChange, OnAdditiveChange +from sqlmesh.core.plan import BuiltInPlanEvaluator, Plan, stages as plan_stages from sqlmesh.core.snapshot import ( + DeployabilityIndex, Node, Snapshot, SnapshotChangeCategory, @@ -37,8 +44,8 @@ ) from sqlmesh.utils import random_id from sqlmesh.utils.date import TimeLike, to_date - -pytest_plugins = ["tests.common_fixtures"] +from sqlmesh.utils.windows import IS_WINDOWS, fix_windows_path +from sqlmesh.core.engine_adapter.shared import CatalogSupport T = t.TypeVar("T", bound=EngineAdapter) @@ -98,12 +105,25 @@ def qualified_views(self) -> t.List[exp.Table]: @property def schemas(self) -> t.List[str]: + return self.schemas_in_catalog(self.engine_adapter.get_current_catalog() or "") + + def schemas_in_catalog(self, catalog_name: str) -> t.List[str]: return self._get_single_col( - f"SELECT schema_name FROM information_schema.schemata WHERE catalog_name = '{self.engine_adapter.get_current_catalog()}' and 
{self._system_schema_filter('schema_name')}", + f"SELECT schema_name FROM information_schema.schemata WHERE catalog_name = '{catalog_name}' and {self._system_schema_filter('schema_name')}", "schema_name", self.engine_adapter, ) + @property + def catalogs(self) -> t.Set[str]: + return set( + self._get_single_col( + f"SELECT database_name FROM duckdb_databases() WHERE internal=false", + "database_name", + self.engine_adapter, + ) + ) + def _system_schema_filter(self, col: str) -> str: return f"{col} not in ('information_schema', 'pg_catalog', 'main')" @@ -113,12 +133,13 @@ def _get_single_col(query: str, col: str, engine_adapter: EngineAdapter) -> t.Li class SushiDataValidator: - def __init__(self, engine_adapter: EngineAdapter): + def __init__(self, engine_adapter: EngineAdapter, sushi_schema_name: str): self.engine_adapter = engine_adapter + self.sushi_schema_name = sushi_schema_name @classmethod - def from_context(cls, context: Context): - return cls(engine_adapter=context.engine_adapter) + def from_context(cls, context: Context, sushi_schema_name: str = "sushi"): + return cls(engine_adapter=context.engine_adapter, sushi_schema_name=sushi_schema_name) def validate( self, @@ -153,9 +174,12 @@ def validate( """ Both start and end are inclusive. 
""" - if model_name == "sushi.customer_revenue_lifetime": + if model_name in ( + f"{self.sushi_schema_name}.customer_revenue_lifetime", + "sushi.customer_revenue_lifetime", + ): env_name = f"__{env_name}" if env_name else "" - full_table_path = f"sushi{env_name}.customer_revenue_lifetime" + full_table_path = f"{self.sushi_schema_name}{env_name}.customer_revenue_lifetime" query = f"SELECT event_date, count(*) AS the_count FROM {full_table_path} group by event_date order by 2 desc, 1 desc" results = self.engine_adapter.fetchdf( parse_one(query), quote_identifiers=True @@ -168,18 +192,25 @@ def validate( expected_dates = [ pd.to_datetime(end_date - datetime.timedelta(days=x)) for x in range(num_days_diff) ] - # all engines but duckdb fetch dates as datetime.date objects - if dialect and dialect != "duckdb": + # all engines but duckdb and clickhouse fetch dates as datetime.date objects + if dialect and dialect not in ("duckdb", "clickhouse"): expected_dates = [x.date() for x in expected_dates] # type: ignore assert list(results["event_date"].values()) == expected_dates return results - else: - raise NotImplementedError(f"Unknown model_name: {model_name}") + raise NotImplementedError(f"Unknown model_name: {model_name}") def pytest_collection_modifyitems(items, *args, **kwargs): - test_type_markers = {"fast", "slow", "docker", "remote"} + test_type_markers = { + "fast", + "slow", + "docker", + "remote", + "isolated", + "registry_isolation", + "dialect_isolated", + } for item in items: for marker in item.iter_markers(): if marker.name in test_type_markers: @@ -212,16 +243,32 @@ def rescope_global_models(request): @pytest.fixture(scope="function", autouse=True) def rescope_duckdb_classvar(request): - DuckDBConnectionConfig._data_file_to_adapter = {} + BaseDuckDBConnectionConfig._data_file_to_adapter = {} yield -@pytest.fixture(scope="module", autouse=True) +@pytest.fixture(scope="function", autouse=True) def rescope_log_handlers(): logging.getLogger().handlers.clear() 
yield +@pytest.fixture(scope="function", autouse=True) +def rescope_lineage_cache(request): + lineage.CACHE.clear() + yield + + +@pytest.fixture(autouse=True) +def reset_console(): + from sqlmesh.core.console import set_console, NoopConsole, get_console + + orig_console = get_console() + set_console(NoopConsole()) + yield + set_console(orig_console) + + @pytest.fixture def duck_conn() -> duckdb.DuckDBPyConnection: return duckdb.connect() @@ -229,11 +276,32 @@ def duck_conn() -> duckdb.DuckDBPyConnection: def push_plan(context: Context, plan: Plan) -> None: plan_evaluator = BuiltInPlanEvaluator( - context.state_sync, context.snapshot_evaluator, context.default_catalog + context.state_sync, + context.snapshot_evaluator, + context.create_scheduler, + context.default_catalog, ) - plan_evaluator._push(plan) - promotion_result = plan_evaluator._promote(plan) - plan_evaluator._update_views(plan, promotion_result) + deployability_index = DeployabilityIndex.create(context.snapshots.values()) + evaluatable_plan = plan.to_evaluatable().copy(update={"skip_backfill": True}) + stages = plan_stages.build_plan_stages( + evaluatable_plan, context.state_sync, context.default_catalog + ) + for stage in stages: + if isinstance(stage, plan_stages.CreateSnapshotRecordsStage): + plan_evaluator.visit_create_snapshot_records_stage(stage, evaluatable_plan) + elif isinstance(stage, plan_stages.PhysicalLayerSchemaCreationStage): + stage.deployability_index = deployability_index + plan_evaluator.visit_physical_layer_schema_creation_stage(stage, evaluatable_plan) + elif isinstance(stage, plan_stages.PhysicalLayerUpdateStage): + stage.deployability_index = deployability_index + plan_evaluator.visit_physical_layer_update_stage(stage, evaluatable_plan) + elif isinstance(stage, plan_stages.EnvironmentRecordUpdateStage): + plan_evaluator.visit_environment_record_update_stage(stage, evaluatable_plan) + elif isinstance(stage, plan_stages.VirtualLayerUpdateStage): + stage.deployability_index = 
deployability_index + plan_evaluator.visit_virtual_layer_update_stage(stage, evaluatable_plan) + elif isinstance(stage, plan_stages.FinalizeEnvironmentStage): + plan_evaluator.visit_finalize_environment_stage(stage, evaluatable_plan) @pytest.fixture() @@ -352,8 +420,8 @@ def _make_function(node: Node, version: t.Optional[str] = None, **kwargs) -> Sna def make_snapshot_on_destructive_change(make_snapshot: t.Callable) -> t.Callable: def _make_function( name: str = "a", - old_query: str = "select '1' as one, '2022-01-01' ds", - new_query: str = "select 1 as one, '2022-01-01' ds", + old_query: str = "select '1' as one, '2' as two, '2022-01-01' ds", + new_query: str = "select 1 as one, 2 as two, '2022-01-01' ds", on_destructive_change: OnDestructiveChange = OnDestructiveChange.ERROR, ) -> t.Tuple[Snapshot, Snapshot]: snapshot_old = make_snapshot( @@ -384,7 +452,54 @@ def _make_function( metadata_hash="test_metadata_hash", ), version="test_version", - change_category=SnapshotChangeCategory.FORWARD_ONLY, + change_category=SnapshotChangeCategory.NON_BREAKING, + dev_table_suffix="dev", + ), + ) + + return snapshot_old, snapshot + + return _make_function + + +@pytest.fixture +def make_snapshot_on_additive_change(make_snapshot: t.Callable) -> t.Callable: + def _make_function( + name: str = "a", + old_query: str = "select '1' as one, '2' as two, '2022-01-01' ds", + new_query: str = "select '1' as one, '2' as two, '3' as three, '2022-01-01' ds", + on_additive_change: OnAdditiveChange = OnAdditiveChange.ERROR, + ) -> t.Tuple[Snapshot, Snapshot]: + snapshot_old = make_snapshot( + SqlModel( + name=name, + dialect="duckdb", + query=parse_one(old_query), + kind=IncrementalByTimeRangeKind( + time_column="ds", forward_only=True, on_additive_change=on_additive_change + ), + ) + ) + + snapshot = make_snapshot( + SqlModel( + name=name, + dialect="duckdb", + query=parse_one(new_query), + kind=IncrementalByTimeRangeKind( + time_column="ds", forward_only=True, 
on_additive_change=on_additive_change + ), + ) + ) + snapshot.previous_versions = ( + SnapshotDataVersion( + fingerprint=SnapshotFingerprint( + data_hash="test_data_hash", + metadata_hash="test_metadata_hash", + ), + version="test_version", + change_category=SnapshotChangeCategory.NON_BREAKING, + dev_table_suffix="dev", ), ) @@ -411,22 +526,36 @@ def sushi_fixed_date_data_validator(sushi_context_fixed_date: Context) -> SushiD @pytest.fixture def make_mocked_engine_adapter(mocker: MockerFixture) -> t.Callable: def _make_function( - klass: t.Type[T], dialect: t.Optional[str] = None, register_comments: bool = True + klass: t.Type[T], + dialect: t.Optional[str] = None, + register_comments: bool = True, + default_catalog: t.Optional[str] = None, + patch_get_data_objects: bool = True, + **kwargs: t.Any, ) -> T: connection_mock = mocker.NonCallableMock() cursor_mock = mocker.Mock() connection_mock.cursor.return_value = cursor_mock cursor_mock.connection.return_value = connection_mock adapter = klass( - lambda: connection_mock, + lambda *args, **kwargs: connection_mock, dialect=dialect or klass.DIALECT, register_comments=register_comments, + default_catalog=default_catalog, + **kwargs, ) if isinstance(adapter, SparkEngineAdapter): mocker.patch( "sqlmesh.engines.spark.db_api.spark_session.SparkSessionConnection._spark_major_minor", new_callable=PropertyMock(return_value=(3, 5)), ) + if isinstance(adapter, MSSQLEngineAdapter): + mocker.patch( + "sqlmesh.core.engine_adapter.mssql.MSSQLEngineAdapter.catalog_support", + new_callable=PropertyMock(return_value=CatalogSupport.REQUIRES_SET_CATALOG), + ) + if patch_get_data_objects: + mocker.patch.object(adapter, "_get_data_objects", return_value=[]) return adapter return _make_function @@ -435,18 +564,49 @@ def _make_function( @pytest.fixture def copy_to_temp_path(tmp_path: Path) -> t.Callable: def ignore(src, names): + # do not copy any sub-dirs if current dir is named one of these if Path(src).name in {".cache", "__pycache__", 
"logs", "data", "target"}: return names - return [] + # do not copy sub-dirs named ".cache" for any current dir + return [name for name in names if name == ".cache"] def _make_function( paths: t.Union[t.Union[str, Path], t.Collection[t.Union[str, Path]]], ) -> t.List[Path]: paths = ensure_list(paths) + all_paths = [Path(p) for p in paths] temp_dirs = [] - for path in paths: + for path in all_paths: temp_dir = Path(tmp_path) / uuid.uuid4().hex - copytree(path, temp_dir, symlinks=True, ignore=ignore) + + if IS_WINDOWS: + # shutil.copytree just doesnt work properly with the symlinks on Windows, regardless of the `symlinks` setting + src = str(path.absolute()) + dst = str(temp_dir.absolute()) + + # Robocopy flag reference: https://learn.microsoft.com/en-us/windows-server/administration/windows-commands/robocopy#copy-options + # /E: Copy subdirectories, including empty directories + # /COPY:D Copy "data" only. In particular, this avoids copying auditing information, which can throw + # an error like "ERROR : You do not have the Manage Auditing user right" + robocopy_cmd = f"robocopy {src} {dst} /E /COPY:D" + exit_code = os.system(robocopy_cmd) + + # exit code reference: https://learn.microsoft.com/en-us/windows-server/administration/windows-commands/robocopy#exit-return-codes + if exit_code > 8: + raise Exception( + f"robocopy command: '{robocopy_cmd}' failed with exit code: {exit_code}" + ) + + # after copying, delete the files that would have been ignored + for root, dirs, _ in os.walk(temp_dir): + for dir in dirs: + full_dir = fix_windows_path(Path(root) / dir) + for ignored in ignore(full_dir, [full_dir]): + shutil.rmtree(ignored) + + else: + copytree(path, temp_dir, symlinks=True, ignore=ignore) + temp_dirs.append(temp_dir) return temp_dirs @@ -469,3 +629,26 @@ def _make_function(table_name: str, random_id: str) -> exp.Table: return temp_table return _make_function + + +@pytest.fixture(scope="function", autouse=True) +def set_default_connection(request): + request 
= request.node.get_closest_marker("set_default_connection") + disable = request and request.kwargs.get("disable") + + if disable: + ctx = nullcontext() + else: + original_get_connection = Config.get_connection + + def _lax_get_connection(self, gateway_name: t.Optional[str] = None) -> ConnectionConfig: + try: + connection = original_get_connection(self, gateway_name) + except: + connection = DuckDBConnectionConfig() + return connection + + ctx = mock.patch("sqlmesh.core.config.Config.get_connection", _lax_get_connection) + + with ctx: + yield diff --git a/tests/core/analytics/test_collector.py b/tests/core/analytics/test_collector.py index 7c39081cb7..1a4c42cbe3 100644 --- a/tests/core/analytics/test_collector.py +++ b/tests/core/analytics/test_collector.py @@ -5,6 +5,7 @@ import pytest from pytest_mock.plugin import MockerFixture +from sqlmesh.core import constants as c from sqlmesh.core.analytics.collector import AnalyticsCollector from sqlmesh.core.snapshot import SnapshotChangeCategory from sqlmesh.integrations.github.cicd.config import GithubCICDBotConfig @@ -17,21 +18,17 @@ def collector(mocker: MockerFixture) -> AnalyticsCollector: return AnalyticsCollector(dispatcher=dispatcher_mock) -def test_on_project_loaded(collector: AnalyticsCollector, mocker: MockerFixture): +@pytest.mark.parametrize( + "project_type", + [ + "native", + "dbt", + "hybrid", + ], +) +def test_on_project_loaded(collector: AnalyticsCollector, mocker: MockerFixture, project_type): collector.on_project_loaded( - project_type="native", - models_count=1, - audits_count=2, - standalone_audits_count=3, - macros_count=4, - jinja_macros_count=5, - load_time_sec=1.123, - state_sync_fingerprint="test_fingerprint", - project_name="test_project", - ) - - collector.on_project_loaded( - project_type="dbt", + project_type=project_type, models_count=1, audits_count=2, standalone_audits_count=3, @@ -46,6 +43,7 @@ def test_on_project_loaded(collector: AnalyticsCollector, mocker: MockerFixture) from 
dbt.version import __version__ as dbt_version + version = ', "dbt_version": "' + dbt_version + '"' if project_type != c.NATIVE else "" collector._dispatcher.add_event.assert_has_calls( # type: ignore [ call( @@ -55,17 +53,11 @@ def test_on_project_loaded(collector: AnalyticsCollector, mocker: MockerFixture) "seq_num": 0, "event_type": "PROJECT_LOADED", "client_ts": mocker.ANY, - "event": '{"project_type": "native", "models_count": 1, "audits_count": 2, "standalone_audits_count": 3, "macros_count": 4, "jinja_macros_count": 5, "load_time_ms": 1123, "state_sync_fingerprint": "test_fingerprint", "project_name_hash": "6e72a69d5c5cca8f0400338441c022e4"}', - } - ), - call( - { - "user_id": mocker.ANY, - "process_id": collector._process_id, - "seq_num": 1, - "event_type": "PROJECT_LOADED", - "client_ts": mocker.ANY, - "event": f'{{"project_type": "dbt", "models_count": 1, "audits_count": 2, "standalone_audits_count": 3, "macros_count": 4, "jinja_macros_count": 5, "load_time_ms": 1123, "state_sync_fingerprint": "test_fingerprint", "project_name_hash": "6e72a69d5c5cca8f0400338441c022e4", "dbt_version": "{dbt_version}"}}', + "event": '{"project_type": "' + + project_type + + '", "models_count": 1, "audits_count": 2, "standalone_audits_count": 3, "macros_count": 4, "jinja_macros_count": 5, "load_time_ms": 1123, "state_sync_fingerprint": "test_fingerprint", "project_name_hash": "6e72a69d5c5cca8f0400338441c022e4"' + + version + + "}", } ), ] @@ -153,7 +145,7 @@ def test_on_cicd_command(collector: AnalyticsCollector, mocker: MockerFixture): { "seq_num": 1, "event_type": "CICD_COMMAND", - "event": '{"command_name": "test_cicd", "command_args": ["arg_1", "arg_2"], "parent_command_names": ["parent_a", "parent_b"], "cicd_bot_config": {"invalidate_environment_after_deploy": true, "enable_deploy_command": false, "auto_categorize_changes": {"external": "off", "python": "off", "sql": "off", "seed": "off"}, "skip_pr_backfill": true, "run_on_deploy_to_prod": true}}', + "event": 
'{"command_name": "test_cicd", "command_args": ["arg_1", "arg_2"], "parent_command_names": ["parent_a", "parent_b"], "cicd_bot_config": {"invalidate_environment_after_deploy": true, "enable_deploy_command": false, "run_on_deploy_to_prod": false}}', **common_fields, } ), @@ -169,7 +161,10 @@ def test_on_plan_apply( plan_id = plan.plan_id collector.on_plan_apply_start( - plan=plan, engine_type="bigquery", state_sync_type="mysql", scheduler_type="builtin" + plan=plan.to_evaluatable(), + engine_type="bigquery", + state_sync_type="mysql", + scheduler_type="builtin", ) collector.on_plan_apply_end(plan_id=plan_id) collector.on_plan_apply_end(plan_id=plan_id, error=SQLMeshError("test_error")) @@ -188,7 +183,7 @@ def test_on_plan_apply( { "seq_num": 0, "event_type": "PLAN_APPLY_START", - "event": f'{{"plan_id": "{plan_id}", "engine_type": "bigquery", "state_sync_type": "mysql", "scheduler_type": "builtin", "is_dev": false, "skip_backfill": false, "no_gaps": false, "forward_only": false, "ensure_finalized_snapshots": false, "has_restatements": false, "directly_modified_count": 18, "indirectly_modified_count": 0, "environment_name_hash": "d6e4a9b6646c62fc48baa6dd6150d1f7"}}', + "event": f'{{"plan_id": "{plan_id}", "engine_type": "bigquery", "state_sync_type": "mysql", "scheduler_type": "builtin", "is_dev": false, "skip_backfill": false, "no_gaps": false, "forward_only": false, "ensure_finalized_snapshots": false, "has_restatements": false, "directly_modified_count": 21, "indirectly_modified_count": 0, "environment_name_hash": "d6e4a9b6646c62fc48baa6dd6150d1f7"}}', **common_fields, } ), @@ -223,7 +218,7 @@ def test_on_snapshots_created( context.get_snapshot("sushi.waiter_revenue_by_day"), context.get_snapshot("sushi.top_waiters"), ] - new_snapshots[0].categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + new_snapshots[0].categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) new_snapshots[0].effective_from = "2024-01-01" new_snapshots[0].version = "test_version" @@ 
-244,7 +239,7 @@ def test_on_snapshots_created( "node_type": "model", "model_kind": "incremental_by_time_range", "is_sql": False, - "change_category": "forward_only", + "change_category": "breaking", "dialect": "duckdb", "audits_count": 0, "effective_from_set": True, @@ -289,8 +284,11 @@ def test_on_snapshots_created( def test_on_run(collector: AnalyticsCollector, mocker: MockerFixture): run_id = collector.on_run_start(engine_type="bigquery", state_sync_type="mysql") - collector.on_run_end(run_id=run_id, succeeded=True) - collector.on_run_end(run_id=run_id, succeeded=False, error=SQLMeshError("test_error")) + collector.on_run_end(run_id=run_id, succeeded=True, interrupted=False) + collector.on_run_end( + run_id=run_id, succeeded=False, interrupted=False, error=SQLMeshError("test_error") + ) + collector.on_run_end(run_id=run_id, succeeded=False, interrupted=True) collector.flush() @@ -314,7 +312,7 @@ def test_on_run(collector: AnalyticsCollector, mocker: MockerFixture): { "seq_num": 1, "event_type": "RUN_END", - "event": f'{{"run_id": "{run_id}", "succeeded": true, "error": null}}', + "event": f'{{"run_id": "{run_id}", "succeeded": true, "interrupted": false, "error": null}}', **common_fields, } ), @@ -322,7 +320,15 @@ def test_on_run(collector: AnalyticsCollector, mocker: MockerFixture): { "seq_num": 2, "event_type": "RUN_END", - "event": f'{{"run_id": "{run_id}", "succeeded": false, "error": "SQLMeshError"}}', + "event": f'{{"run_id": "{run_id}", "succeeded": false, "interrupted": false, "error": "SQLMeshError"}}', + **common_fields, + } + ), + call( + { + "seq_num": 3, + "event_type": "RUN_END", + "event": f'{{"run_id": "{run_id}", "succeeded": false, "interrupted": true, "error": null}}', **common_fields, } ), diff --git a/tests/core/engine_adapter/config.yaml b/tests/core/engine_adapter/config.yaml deleted file mode 100644 index 66be470ebc..0000000000 --- a/tests/core/engine_adapter/config.yaml +++ /dev/null @@ -1,73 +0,0 @@ -gateways: - inttest_duckdb: - 
connection: - type: duckdb - catalogs: - memory: ':memory:' - testing: 'testing.duckdb' - inttest_trino: - connection: - type: trino - host: localhost - port: 8080 - user: admin - catalog: datalake - http_scheme: http - retries: 20 - state_connection: - type: duckdb - inttest_trino_iceberg: - connection: - type: trino - host: localhost - port: 8080 - user: admin - catalog: datalake_iceberg - http_scheme: http - retries: 20 - state_connection: - type: duckdb - inttest_trino_delta: - connection: - type: trino - host: localhost - port: 8080 - user: admin - catalog: datalake_delta - http_scheme: http - retries: 20 - state_connection: - type: duckdb - inttest_spark: - connection: - type: spark - config: - spark.remote: sc://localhost - state_connection: - type: duckdb - inttest_mssql: - connection: - type: mssql - host: localhost - user: sa - password: 1StrongPwd@@ - inttest_postgres: - connection: - type: postgres - user: postgres - password: postgres - database: postgres - host: localhost - port: 5432 - concurrent_tasks: 1 - inttest_mysql: - connection: - type: mysql - host: localhost - user: root - password: mysql - port: 3306 - charset: utf8 - -model_defaults: - dialect: duckdb diff --git a/tests/core/engine_adapter/docker-compose.yaml b/tests/core/engine_adapter/docker-compose.yaml deleted file mode 100644 index dcba648e6a..0000000000 --- a/tests/core/engine_adapter/docker-compose.yaml +++ /dev/null @@ -1,189 +0,0 @@ -version: '3' - -x-hive-metastore-environments: &hive_metastore_environments - S3_ENDPOINT: http://minio:9000 - S3_ACCESS_KEY: minio - S3_SECRET_KEY: minio123 - S3_PATH_STYLE_ACCESS: "true" - REGION: "" - GOOGLE_CLOUD_KEY_FILE_PATH: "" - AZURE_ADL_CLIENT_ID: "" - AZURE_ADL_CREDENTIAL: "" - AZURE_ADL_REFRESH_URL: "" - AZURE_ABFS_STORAGE_ACCOUNT: "" - AZURE_ABFS_ACCESS_KEY: "" - AZURE_WASB_STORAGE_ACCOUNT: "" - AZURE_ABFS_OAUTH: "" - AZURE_ABFS_OAUTH_TOKEN_PROVIDER: "" - AZURE_ABFS_OAUTH_CLIENT_ID: "" - AZURE_ABFS_OAUTH_SECRET: "" - 
AZURE_ABFS_OAUTH_ENDPOINT: "" - AZURE_WASB_ACCESS_KEY: "" - -services: - mysql: - image: mysql:8.1 - ports: - - '3306:3306' - environment: - MYSQL_ROOT_PASSWORD: mysql - postgres: - image: postgres - ports: - - '5432:5432' - environment: - POSTGRES_PASSWORD: postgres - mssql: - image: mcr.microsoft.com/mssql/server:2019-latest - ports: - - '1433:1433' - environment: - SA_PASSWORD: 1StrongPwd@@ - ACCEPT_EULA: Y - - # Trino Stack - trino: - hostname: trino - container_name: trino - image: 'trinodb/trino:429' - ports: - - '8080:8080' - volumes: - - ./trino/catalog:/etc/trino/catalog - - trino_metastore_db: - image: postgres - hostname: trino_metastore_db - environment: - POSTGRES_USER: hive - POSTGRES_PASSWORD: hive - volumes: - - ./trino/initdb.sql:/docker-entrypoint-initdb.d/initdb.sql - - # A second metastore DB is needed because testing all of hive/iceberg/delta - # creates too many connections for a single postgres DB. - trino_iceberg_delta_metastore_db: - image: postgres - hostname: trino_iceberg_delta_metastore_db - environment: - POSTGRES_USER: hive - POSTGRES_PASSWORD: hive - volumes: - - ./trino/initdb.sql:/docker-entrypoint-initdb.d/initdb.sql - - trino-datalake-hive-metastore: - hostname: trino-datalake-hive-metastore - image: 'starburstdata/hive:3.1.2-e.15' - environment: - HIVE_METASTORE_DRIVER: org.postgresql.Driver - HIVE_METASTORE_JDBC_URL: jdbc:postgresql://trino_metastore_db:5432/datalake_metastore - HIVE_METASTORE_USER: hive - HIVE_METASTORE_PASSWORD: hive - HIVE_METASTORE_WAREHOUSE_DIR: s3://trino/datalake - <<: *hive_metastore_environments - depends_on: - - trino_metastore_db - - trino-testing-hive-metastore: - hostname: trino-testing-hive-metastore - image: 'starburstdata/hive:3.1.2-e.15' - environment: - HIVE_METASTORE_DRIVER: org.postgresql.Driver - HIVE_METASTORE_JDBC_URL: jdbc:postgresql://trino_metastore_db:5432/testing_metastore - HIVE_METASTORE_USER: hive - HIVE_METASTORE_PASSWORD: hive - HIVE_METASTORE_WAREHOUSE_DIR: s3://trino/testing - 
<<: *hive_metastore_environments - depends_on: - - trino_metastore_db - - trino-datalake-iceberg-hive-metastore: - hostname: trino-datalake-iceberg-hive-metastore - image: 'starburstdata/hive:3.1.2-e.15' - environment: - HIVE_METASTORE_DRIVER: org.postgresql.Driver - HIVE_METASTORE_JDBC_URL: jdbc:postgresql://trino_iceberg_delta_metastore_db:5432/datalake_iceberg_metastore - HIVE_METASTORE_USER: hive - HIVE_METASTORE_PASSWORD: hive - HIVE_METASTORE_WAREHOUSE_DIR: s3://trino/datalake_iceberg - <<: *hive_metastore_environments - depends_on: - - trino_iceberg_delta_metastore_db - - trino-datalake-delta-hive-metastore: - hostname: trino-datalake-delta-hive-metastore - image: 'starburstdata/hive:3.1.2-e.15' - environment: - HIVE_METASTORE_DRIVER: org.postgresql.Driver - HIVE_METASTORE_JDBC_URL: jdbc:postgresql://trino_iceberg_delta_metastore_db:5432/datalake_delta_metastore - HIVE_METASTORE_USER: hive - HIVE_METASTORE_PASSWORD: hive - HIVE_METASTORE_WAREHOUSE_DIR: s3://trino/datalake_delta - <<: *hive_metastore_environments - depends_on: - - trino_iceberg_delta_metastore_db - - # Spark Stack - spark: - build: - context: ./spark - command: /opt/bitnami/spark/sbin/start-connect-server.sh --packages org.apache.spark:spark-connect_2.12:3.5.0 - ports: - - '15000-15100:15000-15100' - volumes: - - ./spark/conf/spark-defaults.conf:/opt/bitnami/spark/conf/spark-defaults.conf - - ./spark/conf/hive-site.xml:/opt/bitnami/spark/conf/hive-site.xml - depends_on: - - spark-hive-metastore - - spark_metastore_db: - image: postgres:11 - hostname: spark_metastore_db - environment: - POSTGRES_USER: hive - POSTGRES_PASSWORD: hive - POSTGRES_DB: metastore - - spark-hive-metastore: - hostname: spark-hive-metastore - image: 'starburstdata/hive:3.1.2-e.15' - environment: - HIVE_METASTORE_DRIVER: org.postgresql.Driver - HIVE_METASTORE_JDBC_URL: jdbc:postgresql://spark_metastore_db:5432/metastore - HIVE_METASTORE_USER: hive - HIVE_METASTORE_PASSWORD: hive - HIVE_METASTORE_WAREHOUSE_DIR: 
s3://spark/ - <<: *hive_metastore_environments - depends_on: - - spark_metastore_db - - # Shared Spark/Trino S3 Storage - minio: - hostname: minio - image: 'minio/minio:RELEASE.2022-05-26T05-48-41Z' - ports: - - '9000:9000' - - '9001:9001' - environment: - MINIO_ACCESS_KEY: minio - MINIO_SECRET_KEY: minio123 - command: server /data --console-address ":9001" - - # This job will create the "spark/trino" buckets and sub paths - mc-job: - image: 'minio/mc:RELEASE.2022-05-09T04-08-26Z' - entrypoint: | - /bin/bash -c " - sleep 5; - /usr/bin/mc config --quiet host add myminio http://minio:9000 minio minio123; - /usr/bin/mc mb --quiet myminio/trino/datalake; - /usr/bin/mc mb --quiet myminio/trino/datalake_iceberg; - /usr/bin/mc mb --quiet myminio/trino/datalake_delta; - /usr/bin/mc mb --quiet myminio/trino/testing; - /usr/bin/mc mb --quiet myminio/trino/testing_iceberg; - /usr/bin/mc mb --quiet myminio/trino/testing_delta; - /usr/bin/mc mb --quiet myminio/spark/datalake; - /usr/bin/mc mb --quiet myminio/spark/testing - " - depends_on: - - minio diff --git a/tests/core/engine_adapter/integration/__init__.py b/tests/core/engine_adapter/integration/__init__.py new file mode 100644 index 0000000000..4ad6a17944 --- /dev/null +++ b/tests/core/engine_adapter/integration/__init__.py @@ -0,0 +1,862 @@ +from __future__ import annotations + +import os +import pathlib +import sys +import typing as t +import time +from contextlib import contextmanager + +import pandas as pd # noqa: TID253 +import pytest +from sqlglot import exp, parse_one +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers + +from sqlmesh import Config, Context, EngineAdapter +from sqlmesh.core.config import load_config_from_paths +from sqlmesh.core.config.connection import AthenaConnectionConfig +from sqlmesh.core.dialect import normalize_model_name +import sqlmesh.core.dialect as d +from sqlmesh.core.engine_adapter import SparkEngineAdapter, TrinoEngineAdapter, AthenaEngineAdapter +from 
sqlmesh.core.engine_adapter.shared import DataObject +from sqlmesh.core.model.definition import SqlModel, load_sql_based_model +from sqlmesh.utils import random_id +from sqlmesh.utils.date import to_ds +from sqlmesh.utils.pydantic import PydanticModel +from tests.utils.pandas import compare_dataframes +from dataclasses import dataclass +from _pytest.mark import MarkDecorator +from _pytest.mark.structures import ParameterSet + +if t.TYPE_CHECKING: + from sqlmesh.core._typing import TableName, SchemaName + from sqlmesh.core.engine_adapter._typing import Query + +TEST_SCHEMA = "test_schema" + + +@dataclass +class IntegrationTestEngine: + engine: str + catalog_types: t.Optional[t.List[str]] = None + native_dataframe_type: t.Optional[str] = None + cloud: bool = False + + @property + def dialect(self) -> str: + return self.engine.split("_", maxsplit=1)[0] + + @property + def pytest_marks(self) -> t.List[MarkDecorator]: + marks = [getattr(pytest.mark, self.engine), pytest.mark.engine] + if self.cloud: + marks.append(pytest.mark.remote) + else: + marks.append(pytest.mark.docker) + if self.engine == "duckdb": + marks.extend( + [ + # run the duckdb tests in `make cicd-test` as well + pytest.mark.slow, + # the duckdb tests cannot run concurrently because many of them point at the same files + # and duckdb does not support multi process read/write on the same files + # ref: https://duckdb.org/docs/connect/concurrency.html#writing-to-duckdb-from-multiple-processes + pytest.mark.xdist_group("engine_integration_duckdb"), + ] + ) + return marks + + +ENGINES = [ + # Docker engines that can be locally tested + IntegrationTestEngine("duckdb"), + IntegrationTestEngine("postgres"), + IntegrationTestEngine("mysql"), + IntegrationTestEngine("mssql"), + IntegrationTestEngine("trino", catalog_types=["hive", "iceberg", "delta", "nessie"]), + IntegrationTestEngine("spark", native_dataframe_type="pyspark"), + IntegrationTestEngine("clickhouse", catalog_types=["standalone", "cluster"]), + 
IntegrationTestEngine("risingwave"), + # Cloud engines that need paid accounts / special credentials + IntegrationTestEngine("clickhouse_cloud", cloud=True), + IntegrationTestEngine("redshift", cloud=True), + IntegrationTestEngine("athena", catalog_types=["hive", "iceberg"], cloud=True), + IntegrationTestEngine("bigquery", native_dataframe_type="bigframe", cloud=True), + IntegrationTestEngine("databricks", native_dataframe_type="pyspark", cloud=True), + IntegrationTestEngine("snowflake", native_dataframe_type="snowpark", cloud=True), + IntegrationTestEngine("fabric", cloud=True), + IntegrationTestEngine("gcp_postgres", cloud=True), +] + +ENGINES_BY_NAME = {e.engine: e for e in ENGINES} + + +def generate_pytest_params( + engines: t.Union[IntegrationTestEngine, t.List[IntegrationTestEngine]], + query: bool = True, + df: bool = False, + show_variant_in_test_id: bool = True, +) -> t.Iterable[ParameterSet]: + """ + The engine adapter tests have a bunch of variants: + - Per engine for engines that dont have pluggable catalogs + - Per engine per catalog type for engines that have pluggable catalogs + + In addition, many engine adapter functions take either a SQL Query or a DataFrame so we need to test both combinations. 
+ For the methods that take a DataFrame: + - Every engine takes a Pandas DataFrame + - A small subset of engines also take their own engine-specific DataFrame (eg Bigframe, Snowpark, Pyspark) + + This function controls the parameter generation so that: + - Tests that only need to test SQL queries only get called once per engine/catalog + - Tests that only need to test DataFrame's get called once for Pandas Dataframe's and once for each engine-specific DataFrame + - Tests that need to test both SQL Queries and DataFrame's get called once for every combination of (engine, catalog, *(query, pandas df, native df)) + + The goal is to prevent needing to code this kind of logic into tests: + + > if test_type == "df": + > pytest.skip("Test only needs to run for query") + + As well as make it easier to generate the right combinations for new databases / catalogs / DataFrame implementations + """ + if not isinstance(engines, list): + engines = [engines] + + for engine in engines: + catalogs = engine.catalog_types if engine.catalog_types else [""] + for catalog in catalogs: + gateway = ( + f"inttest_{engine.engine}_{catalog}" if catalog else f"inttest_{engine.engine}" + ) + if engine.engine == "athena": + # athena only has a single gateway defined, not a gateway per catalog + gateway = f"inttest_athena" + + variants = [] + if query: + variants.append("query") + if df: + variants.append("df-pandas") + if engine.native_dataframe_type: + variants.append(f"df-{engine.native_dataframe_type}") + + test_id = f"{engine.engine}_{catalog}" if catalog else f"{engine.engine}" + default_table_format = catalog + + for variant in variants: + yield pytest.param( + (engine, gateway, variant, default_table_format), + marks=engine.pytest_marks, + id=f"[{variant}]{test_id}" if show_variant_in_test_id else test_id, + ) + + +class MetadataResults(PydanticModel): + tables: t.List[str] = [] + views: t.List[str] = [] + materialized_views: t.List[str] = [] + managed_tables: t.List[str] = [] + + 
@classmethod + def from_data_objects(cls, data_objects: t.List[DataObject]) -> MetadataResults: + tables = [] + views = [] + materialized_views = [] + managed_tables = [] + for obj in data_objects: + if obj.type.is_table: + tables.append(obj.name) + elif obj.type.is_view: + views.append(obj.name) + elif obj.type.is_materialized_view: + materialized_views.append(obj.name) + elif obj.type.is_managed_table: + managed_tables.append(obj.name) + else: + raise ValueError(f"Unexpected object type: {obj.type}") + return MetadataResults( + tables=tables, + views=views, + materialized_views=materialized_views, + managed_tables=managed_tables, + ) + + @property + def non_temp_tables(self) -> t.List[str]: + return [x for x in self.tables if not x.startswith("__temp") and not x.startswith("temp")] + + +class TestContext: + __test__ = False # prevent pytest trying to collect this as a test class + + def __init__( + self, + test_type: str, + engine_adapter: EngineAdapter, + mark: str, + gateway: str, + tmp_path: pathlib.Path, + is_remote: bool = False, + columns_to_types: t.Optional[t.Dict[str, t.Union[str, exp.DataType]]] = None, + ): + self._test_type = test_type + self.engine_adapter = engine_adapter + self.mark = mark + self.gateway = gateway + self._columns_to_types = columns_to_types + self.test_id = random_id(short=True) + self._context: t.Optional[Context] = None + self.is_remote = is_remote + self._schemas: t.List[ + str + ] = [] # keep track of any schemas returned from self.schema() / self.table() so we can drop them at the end + self._catalogs: t.List[ + str + ] = [] # keep track of any catalogs created via self.create_catalog() so we can drop them at the end + self.tmp_path = tmp_path + + @property + def test_type(self) -> str: + return "df" if self._test_type.startswith("df") else "query" + + @property + def df_type(self) -> t.Optional[str]: + if self.test_type == "df": + # the 'pandas' part of 'df-pandas' + return self._test_type.split("-", maxsplit=1)[1] + return 
None + + @property + def engine_type(self) -> str: + if self.mark.startswith("gcp_postgres"): + return "gcp_postgres" + + return self.mark.split("_")[0] + + @property + def columns_to_types(self): + if self._columns_to_types is None: + self._columns_to_types = { + "id": exp.DataType.build("int"), + "ds": exp.DataType.build("string"), + } + return self._columns_to_types + + @columns_to_types.setter + def columns_to_types(self, value: t.Dict[str, t.Union[str, exp.DataType]]): + self._columns_to_types = { + k: exp.DataType.build(v, dialect=self.dialect) for k, v in value.items() + } + + @property + def time_columns(self) -> t.List[str]: + return [ + k + for k, v in self.columns_to_types.items() + if v.sql().lower().startswith("timestamp") + or v.sql().lower().startswith("date") + or k.lower() == "ds" + ] + + @property + def timestamp_columns(self) -> t.List[str]: + return [ + k + for k, v in self.columns_to_types.items() + if v.sql().lower().startswith("timestamp") + or (v.sql().lower() == "datetime" and self.dialect == "bigquery") + ] + + @property + def time_column(self) -> str: + return self.time_columns[0] + + @property + def time_formatter(self) -> t.Callable: + return lambda x, _: exp.Literal.string(to_ds(x)) + + @property + def partitioned_by(self) -> t.List[exp.Expression]: + return [parse_one(self.time_column)] + + @property + def dialect(self) -> str: + return self.engine_adapter.dialect + + @property + def current_catalog_type(self) -> str: + return self.engine_adapter.current_catalog_type + + @property + def supports_merge(self) -> bool: + if self.dialect == "spark": + assert isinstance(self.engine_adapter, SparkEngineAdapter) + # Spark supports MERGE on the Iceberg catalog (which is configured under "testing" in these integration tests) + return self.engine_adapter.default_catalog == "testing" + + if self.dialect == "trino": + assert isinstance(self.engine_adapter, TrinoEngineAdapter) + # Trino supports MERGE on Delta and Iceberg but not Hive + return ( + 
self.engine_adapter.get_catalog_type(self.engine_adapter.default_catalog) != "hive" + ) + + if self.dialect == "athena": + return "hive" not in self.mark + + if self.dialect == "risingwave": + return False + + return True + + @property + def default_table_format(self) -> t.Optional[str]: + if self.dialect in {"athena", "trino"} and "_" in self.mark: + return self.mark.split("_", 1)[-1] # take eg 'athena_iceberg' and return 'iceberg' + return None + + def add_test_suffix(self, value: str) -> str: + return f"{value}_{self.test_id}" + + def get_metadata_results(self, schema: t.Optional[SchemaName] = None) -> MetadataResults: + schema = schema if schema else self.schema(TEST_SCHEMA) + return MetadataResults.from_data_objects(self.engine_adapter.get_data_objects(schema)) + + def _init_engine_adapter(self) -> None: + schema = self.schema(TEST_SCHEMA) + self.engine_adapter.drop_schema(schema, ignore_if_not_exists=True, cascade=True) + self.engine_adapter.create_schema(schema) + + def _format_df(self, data: pd.DataFrame, to_datetime: bool = True) -> pd.DataFrame: + for timestamp_column in self.timestamp_columns: + if timestamp_column in data.columns: + value = data[timestamp_column] + if to_datetime: + value = pd.to_datetime(value) + data[timestamp_column] = value.astype("datetime64[ns]") + return data + + def init(self): + if self.df_type == "pyspark" and not hasattr(self.engine_adapter, "is_pyspark_df"): + pytest.skip(f"Engine adapter {self.engine_adapter} doesn't support pyspark") + self._init_engine_adapter() + + def input_data( + self, + data: pd.DataFrame, + columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, + ) -> t.Union[Query, pd.DataFrame]: + columns_to_types = columns_to_types or self.columns_to_types + if self.test_type == "query": + return self.engine_adapter._values_to_sql( + list(data.itertuples(index=False, name=None)), + batch_start=0, + batch_end=sys.maxsize, + target_columns_to_types=columns_to_types, + ) + if self.test_type == "df": + 
formatted_df = self._format_df(data, to_datetime=self.dialect != "trino") + if self.df_type == "pandas": + return formatted_df + if self.df_type == "pyspark": + return self.engine_adapter.spark.createDataFrame(formatted_df) # type: ignore + if self.df_type == "bigframe": + return self.engine_adapter.bigframe.read_pandas(formatted_df) # type: ignore + if self.df_type == "snowpark": + return self.engine_adapter.snowpark.create_dataframe(formatted_df) # type: ignore + + raise ValueError(f"Unknown DF type: {self.df_type}") + + raise ValueError(f"Unknown test type: {self.test_type}") + + def output_data(self, data: pd.DataFrame) -> pd.DataFrame: + return self._format_df(data) + + def table(self, table_name: TableName, schema: str = TEST_SCHEMA) -> exp.Table: + schema = self.add_test_suffix(schema) + self._schemas.append(schema) + + table = exp.to_table(table_name, dialect=self.dialect) + table.set("db", exp.parse_identifier(schema, dialect=self.dialect)) + + return exp.to_table( + normalize_model_name( + table, + default_catalog=self.engine_adapter.default_catalog, + dialect=self.dialect, + ) + ) + + def physical_properties( + self, properties_for_dialect: t.Dict[str, t.Dict[str, str | exp.Expression]] + ) -> t.Dict[str, exp.Expression]: + if props := properties_for_dialect.get(self.dialect): + return {k: exp.Literal.string(v) if isinstance(v, str) else v for k, v in props.items()} + return {} + + def schema(self, schema_name: str = TEST_SCHEMA, catalog_name: t.Optional[str] = None) -> str: + schema_name = exp.table_name( + normalize_model_name( + self.add_test_suffix( + ".".join( + p + for p in (catalog_name or self.engine_adapter.default_catalog, schema_name) + if p + ) + if "." 
not in schema_name + else schema_name + ), + default_catalog=None, + dialect=self.dialect, + ) + ) + self._schemas.append(schema_name) + return schema_name + + def get_current_data(self, table: exp.Table) -> pd.DataFrame: + df = self.engine_adapter.fetchdf(exp.select("*").from_(table), quote_identifiers=True) + if self.dialect == "snowflake" and "id" in df.columns: + df["id"] = df["id"].apply(lambda x: x if pd.isna(x) else int(x)) + return self._format_df(df) + + def compare_with_current(self, table: exp.Table, expected: pd.DataFrame) -> None: + compare_dataframes( + self.get_current_data(table), + self.output_data(expected), + check_dtype=False, + check_index_type=False, + ) + + def get_table_comment( + self, + schema_name: str, + table_name: str, + table_kind: str = "BASE TABLE", + snowflake_capitalize_ids: bool = True, + ) -> t.Optional[str]: + if self.dialect in ["postgres", "redshift"]: + query = f""" + SELECT + pgc.relname, + pg_catalog.obj_description(pgc.oid, 'pg_class') + FROM pg_catalog.pg_class pgc + INNER JOIN pg_catalog.pg_namespace n + ON pgc.relnamespace = n.oid + WHERE + n.nspname = '{schema_name}' + AND pgc.relname = '{table_name}' + AND pgc.relkind = '{"v" if table_kind == "VIEW" else "r"}' + ; + """ + elif self.dialect in ["mysql", "snowflake"]: + # Snowflake treats all identifiers as uppercase unless they are lowercase and quoted. + # They are lowercase and quoted in sushi but not in the inline tests. 
+ if self.dialect == "snowflake" and snowflake_capitalize_ids: + schema_name = schema_name.upper() + table_name = table_name.upper() + + comment_field_name = { + "mysql": "table_comment", + "snowflake": "comment", + } + + query = f""" + SELECT + table_name, + {comment_field_name[self.dialect]} + FROM INFORMATION_SCHEMA.TABLES + WHERE + table_schema='{schema_name}' + AND table_name='{table_name}' + AND table_type='{table_kind}' + """ + elif self.dialect == "bigquery": + query = f""" + SELECT + table_name, + option_value + FROM `region-us.INFORMATION_SCHEMA.TABLE_OPTIONS` + WHERE + table_schema='{schema_name}' + AND table_name='{table_name}' + AND option_name = 'description' + """ + elif self.dialect in ["spark", "databricks"]: + query = f"DESCRIBE TABLE EXTENDED {schema_name}.{table_name}" + elif self.dialect == "trino": + query = f""" + SELECT + table_name, + comment + FROM system.metadata.table_comments + WHERE + schema_name = '{schema_name}' + AND table_name = '{table_name}' + """ + elif self.dialect == "duckdb": + kind = "table" if table_kind == "BASE TABLE" else "view" + query = f""" + SELECT + {kind}_name, + comment + FROM duckdb_{kind}s() + WHERE + schema_name = '{schema_name}' + AND {kind}_name = '{table_name}' + """ + elif self.dialect == "clickhouse": + query = f"SELECT name, comment FROM system.tables WHERE database = '{schema_name}' AND name = '{table_name}'" + elif self.dialect == "risingwave": + query = f""" + SELECT + c.relname, + d.description + FROM pg_class c + INNER JOIN pg_description d ON c.oid = d.objoid AND d.objsubid = 0 + INNER JOIN pg_namespace n ON c.relnamespace = n.oid + WHERE + c.relname = '{table_name}' + AND n.nspname= '{schema_name}' + AND c.relkind = '{"v" if table_kind == "VIEW" else "r"}' + ; + """ + + result = self.engine_adapter.fetchall(query) + + if result: + if self.dialect == "bigquery": + comment = result[0][1].replace('"', "").replace("\\n", "\n") + elif self.dialect in ["spark", "databricks"]: + comment = [x for x in 
result if x[0] == "Comment"] + comment = comment[0][1] if comment else None + else: + comment = result[0][1] + + return comment + + return None + + def get_column_comments( + self, + schema_name: str, + table_name: str, + table_kind: str = "BASE TABLE", + snowflake_capitalize_ids: bool = True, + ) -> t.Dict[str, str]: + comment_index = 1 + if self.dialect in ["postgres", "redshift"]: + query = f""" + SELECT + cols.column_name, + pg_catalog.col_description(pgc.oid, cols.ordinal_position::int) AS column_comment + FROM pg_catalog.pg_class pgc + INNER JOIN pg_catalog.pg_namespace n + ON + pgc.relnamespace = n.oid + INNER JOIN information_schema.columns cols + ON + pgc.relname = cols.table_name + AND n.nspname = cols.table_schema + WHERE + n.nspname = '{schema_name}' + AND pgc.relname = '{table_name}' + AND pgc.relkind = '{"v" if table_kind == "VIEW" else "r"}' + ; + """ + elif self.dialect in ["mysql", "snowflake", "trino"]: + # Snowflake treats all identifiers as uppercase unless they are lowercase and quoted. + # They are lowercase and quoted in sushi but not in the inline tests. 
+ if self.dialect == "snowflake" and snowflake_capitalize_ids: + schema_name = schema_name.upper() + table_name = table_name.upper() + + comment_field_name = { + "mysql": "column_comment", + "snowflake": "comment", + "trino": "comment", + } + + query = f""" + SELECT + column_name, + {comment_field_name[self.dialect]} + FROM + information_schema.columns + WHERE + table_schema = '{schema_name}' + AND table_name = '{table_name}' + """ + elif self.dialect == "bigquery": + query = f""" + SELECT + column_name, + description + FROM + `region-us.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS` + WHERE + table_schema = '{schema_name}' + AND table_name = '{table_name}' + ; + """ + elif self.dialect in ["spark", "databricks", "clickhouse"]: + query = f"DESCRIBE TABLE {schema_name}.{table_name}" + comment_index = 2 if self.dialect in ["spark", "databricks"] else 4 + elif self.dialect == "duckdb": + query = f""" + SELECT + column_name, + comment + FROM duckdb_columns() + WHERE + schema_name = '{schema_name}' + AND table_name = '{table_name}' + """ + elif self.dialect == "risingwave": + query = f""" + SELECT + a.attname AS column_name, d.description + FROM + pg_class c + INNER JOIN pg_namespace n ON c.relnamespace = n.oid + INNER JOIN pg_attribute a ON c.oid = a.attrelid + INNER JOIN pg_description d + ON + a.attnum = d.objsubid + AND d.objoid = c.oid + WHERE + n.nspname = '{schema_name}' + AND c.relname = '{table_name}' + AND c.relkind = '{"v" if table_kind == "VIEW" else "r"}' + ; + """ + + result = self.engine_adapter.fetchall(query) + + comments = {} + if result: + if self.dialect in ["spark", "databricks"]: + result = list(set([x for x in result if not x[0].startswith("#")])) + + comments = { + x[0]: x[comment_index] + for x in result + if x[comment_index] is not None and x[comment_index].strip() != "" + } + + return comments + + def create_context( + self, + config_mutator: t.Optional[t.Callable[[str, Config], None]] = None, + path: t.Optional[pathlib.Path] = None, + 
ephemeral_state_connection: bool = True, + ) -> Context: + private_sqlmesh_dir = pathlib.Path(pathlib.Path().home(), ".sqlmesh") + config = load_config_from_paths( + Config, + project_paths=[ + pathlib.Path(os.path.join(os.path.dirname(__file__), "config.yaml")), + private_sqlmesh_dir / "config.yml", + private_sqlmesh_dir / "config.yaml", + ], + variables={"tmp_path": str(path or self.tmp_path)}, + ) + if config_mutator: + config_mutator(self.gateway, config) + config.gateways = {self.gateway: config.gateways[self.gateway]} + + gateway_config = config.gateways[self.gateway] + if ephemeral_state_connection: + # Override whatever state connection has been configured on the integration test config to use in-memory DuckDB instead + # This is so tests that initialize a SQLMesh context can run concurrently without clobbering each others state + from sqlmesh.core.config.connection import DuckDBConnectionConfig + + gateway_config.state_connection = DuckDBConnectionConfig() + + if "athena" in self.gateway: + conn = gateway_config.connection + assert isinstance(conn, AthenaConnectionConfig) + assert isinstance(self.engine_adapter, AthenaEngineAdapter) + # Ensure that s3_warehouse_location is propagated + conn.s3_warehouse_location = self.engine_adapter.s3_warehouse_location + + self._context = Context(paths=path or ".", config=config, gateway=self.gateway) + return self._context + + def create_catalog(self, catalog_name: str): + if self.dialect == "databricks": + self.engine_adapter.execute(f"CREATE CATALOG IF NOT EXISTS {catalog_name}") + elif self.dialect == "tsql": + self.engine_adapter.cursor.connection.autocommit(True) + try: + self.engine_adapter.cursor.execute(f"CREATE DATABASE {catalog_name}") + except Exception: + pass + self.engine_adapter.cursor.connection.autocommit(False) + elif self.dialect == "fabric": + # Use the engine adapter's built-in catalog creation functionality + self.engine_adapter.create_catalog(catalog_name) + elif self.dialect == "snowflake": + 
self.engine_adapter.execute(f'CREATE DATABASE IF NOT EXISTS "{catalog_name}"') + elif self.dialect == "duckdb": + try: + # Only applies to MotherDuck + self.engine_adapter.execute(f'CREATE DATABASE IF NOT EXISTS "{catalog_name}"') + except Exception: + pass + + self._catalogs.append(catalog_name) + + def drop_catalog(self, catalog_name: str): + if self.dialect == "bigquery": + return # bigquery cannot create/drop catalogs + if self.dialect == "databricks": + self.engine_adapter.execute(f"DROP CATALOG IF EXISTS {catalog_name} CASCADE") + elif self.dialect == "fabric": + # Use the engine adapter's built-in catalog dropping functionality + self.engine_adapter.drop_catalog(catalog_name) + else: + self.engine_adapter.execute(f'DROP DATABASE IF EXISTS "{catalog_name}"') + + def cleanup(self, ctx: t.Optional[Context] = None): + self._schemas.append(self.schema(TEST_SCHEMA)) + + ctx = ctx or self._context + if ctx and ctx.models: + for _, model in ctx.models.items(): + self._schemas.append(model.schema_name) + self._schemas.append(model.physical_schema) + + for schema_name in set(self._schemas): + self.engine_adapter.drop_schema( + schema_name=schema_name, ignore_if_not_exists=True, cascade=True + ) + + for catalog_name in set(self._catalogs): + self.drop_catalog(catalog_name) + + self.engine_adapter.close() + + def upsert_sql_model(self, model_definition: str) -> t.Tuple[Context, SqlModel]: + if not self._context: + self._context = self.create_context() + + model = load_sql_based_model(expressions=d.parse(model_definition)) + assert isinstance(model, SqlModel) + self._context.upsert_model(model) + return self._context, model + + def _get_create_user_or_role( + self, username: str, password: t.Optional[str] = None + ) -> t.Tuple[str, t.Optional[str]]: + password = password or random_id() + if self.dialect in ["postgres", "redshift"]: + return username, f"CREATE USER \"{username}\" WITH PASSWORD '{password}'" + if self.dialect == "snowflake": + return username, f"CREATE 
ROLE {username}" + if self.dialect == "databricks": + # Creating an account-level group in Databricks requires making REST API calls so we are going to + # use a pre-created group instead. We assume the suffix on the name is the unique id. + # In the Databricks UI, Workspace Settings -> Identity and Access, create the following groups: + # - test_user, test_analyst, test_etl_user, test_reader, test_writer, test_admin + # (there do not need to be any users assigned to these groups) + return "_".join(username.split("_")[:-1]), None + if self.dialect == "bigquery": + # BigQuery uses IAM service accounts that need to be pre-created + # Pre-created GCP service accounts: + # - sqlmesh-test-admin@{project-id}.iam.gserviceaccount.com + # - sqlmesh-test-analyst@{project-id}.iam.gserviceaccount.com + # - sqlmesh-test-etl-user@{project-id}.iam.gserviceaccount.com + # - sqlmesh-test-reader@{project-id}.iam.gserviceaccount.com + # - sqlmesh-test-user@{project-id}.iam.gserviceaccount.com + # - sqlmesh-test-writer@{project-id}.iam.gserviceaccount.com + role_name = ( + username.replace(f"_{self.test_id}", "").replace("test_", "").replace("_", "-") + ) + project_id = self.engine_adapter.get_current_catalog() + service_account = f"sqlmesh-test-{role_name}@{project_id}.iam.gserviceaccount.com" + return f"serviceAccount:{service_account}", None + raise ValueError(f"User creation not supported for dialect: {self.dialect}") + + def _create_user_or_role(self, username: str, password: t.Optional[str] = None) -> str: + username, create_user_sql = self._get_create_user_or_role(username, password) + if create_user_sql: + self.engine_adapter.execute(create_user_sql) + return username + + @contextmanager + def create_users_or_roles(self, *role_names: str) -> t.Iterator[t.Dict[str, str]]: + created_users = [] + roles = {} + + try: + for role_name in role_names: + user_name = normalize_identifiers( + self.add_test_suffix(f"test_{role_name}"), dialect=self.dialect + ).sql(dialect=self.dialect) + 
password = random_id() + if self.dialect == "redshift": + password += ( + "A" # redshift requires passwords to have at least one uppercase letter + ) + user_name = self._create_user_or_role(user_name, password) + created_users.append(user_name) + roles[role_name] = user_name + + yield roles + + finally: + for user_name in created_users: + self._cleanup_user_or_role(user_name) + + def get_select_privilege(self) -> str: + if self.dialect == "bigquery": + return "roles/bigquery.dataViewer" + return "SELECT" + + def get_insert_privilege(self) -> str: + if self.dialect == "databricks": + # This would really be "MODIFY" but for the purposes of having this be unique from UPDATE + # we return "MANAGE" instead + return "MANAGE" + if self.dialect == "bigquery": + return "roles/bigquery.dataEditor" + return "INSERT" + + def get_update_privilege(self) -> str: + if self.dialect == "databricks": + return "MODIFY" + if self.dialect == "bigquery": + return "roles/bigquery.dataOwner" + return "UPDATE" + + def _cleanup_user_or_role(self, user_name: str) -> None: + """Helper function to clean up a user/role and all their dependencies.""" + try: + if self.dialect in ["postgres", "redshift"]: + self.engine_adapter.execute(f""" + SELECT pg_terminate_backend(pid) + FROM pg_stat_activity + WHERE usename = '{user_name}' AND pid <> pg_backend_pid() + """) + self.engine_adapter.execute(f'DROP OWNED BY "{user_name}"') + self.engine_adapter.execute(f'DROP USER IF EXISTS "{user_name}"') + elif self.dialect == "snowflake": + self.engine_adapter.execute(f"DROP ROLE IF EXISTS {user_name}") + elif self.dialect in ["databricks", "bigquery"]: + # For Databricks and BigQuery, we use pre-created accounts that should not be deleted + pass + except Exception: + pass + + +def wait_until(fn: t.Callable[..., bool], attempts=3, wait=5) -> None: + current_attempt = 0 + while current_attempt < attempts: + current_attempt += 1 + result = fn() + if result: + return + time.sleep(wait) + + raise Exception(f"Wait 
function did not return True after {attempts} attempts") diff --git a/tests/core/engine_adapter/integration/config.yaml b/tests/core/engine_adapter/integration/config.yaml new file mode 100644 index 0000000000..0b1ecd8193 --- /dev/null +++ b/tests/core/engine_adapter/integration/config.yaml @@ -0,0 +1,221 @@ +gateways: + + inttest_duckdb: + connection: + type: duckdb + catalogs: + memory: ':memory:' + testing: "{{ var('tmp_path') }}/testing.duckdb" + + # Databases with docker images available + inttest_trino_hive: + connection: + type: trino + host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }} + port: 8080 + user: admin + catalog: datalake + http_scheme: http + retries: 20 + check_import: false + state_connection: + type: duckdb + inttest_trino_iceberg: + connection: + type: trino + host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }} + port: 8080 + user: admin + catalog: datalake_iceberg + http_scheme: http + retries: 20 + check_import: false + state_connection: + type: duckdb + inttest_trino_delta: + connection: + type: trino + host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }} + port: 8080 + user: admin + catalog: datalake_delta + http_scheme: http + retries: 20 + check_import: false + state_connection: + type: duckdb + inttest_trino_nessie: + connection: + type: trino + host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }} + port: 8080 + user: admin + catalog: datalake_nessie + http_scheme: http + retries: 20 + check_import: false + state_connection: + type: duckdb + inttest_spark: + connection: + type: spark + config: + spark.remote: sc://{{ env_var('DOCKER_HOSTNAME', 'localhost') }} + check_import: false + state_connection: + type: duckdb + inttest_mssql: + connection: + type: mssql + host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }} + user: sa + password: 1StrongPwd@@ + check_import: false + inttest_postgres: + connection: + type: postgres + user: postgres + password: postgres + database: postgres + host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }} + 
port: 5432 + check_import: false + inttest_mysql: + connection: + type: mysql + host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }} + user: root + password: mysql + port: 3306 + charset: utf8 + check_import: false + inttest_clickhouse_standalone: + connection: + type: clickhouse + host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }} + port: 8123 + username: clickhouse + password: clickhouse + check_import: false + state_connection: + type: duckdb + inttest_clickhouse_cluster: + connection: + type: clickhouse + host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }} + port: 8123 + username: clickhouse + password: clickhouse + cluster: cluster1 + check_import: false + state_connection: + type: duckdb + inttest_risingwave: + connection: + type: risingwave + user: root + database: dev + host: {{ env_var('DOCKER_HOSTNAME', 'localhost') }} + port: 4566 + check_import: false + + + # Cloud databases + inttest_snowflake: + connection: + type: snowflake + account: {{ env_var('SNOWFLAKE_ACCOUNT') }} + warehouse: {{ env_var('SNOWFLAKE_WAREHOUSE') }} + database: {{ env_var('SNOWFLAKE_DATABASE') }} + user: {{ env_var('SNOWFLAKE_USER') }} + private_key_path: {{ env_var('SNOWFLAKE_PRIVATE_KEY_FILE', 'tests/fixtures/snowflake/rsa_key_no_pass.p8') }} + check_import: false + state_connection: + type: duckdb + + inttest_databricks: + connection: + type: databricks + catalog: {{ env_var('DATABRICKS_CATALOG') }} + server_hostname: {{ env_var('DATABRICKS_SERVER_HOSTNAME') }} + http_path: {{ env_var('DATABRICKS_HTTP_PATH') }} + auth_type: {{ env_var('DATABRICKS_AUTH_TYPE', 'databricks-oauth') }} + oauth_client_id: {{ env_var('DATABRICKS_CLIENT_ID') }} + oauth_client_secret: {{ env_var('DATABRICKS_CLIENT_SECRET') }} + databricks_connect_use_serverless: true + check_import: false + + inttest_redshift: + connection: + type: redshift + host: {{ env_var('REDSHIFT_HOST') }} + user: {{ env_var('REDSHIFT_USER') }} + password: {{ env_var('REDSHIFT_PASSWORD') }} + database: {{ 
env_var('REDSHIFT_DATABASE') }} + check_import: false + + inttest_bigquery: + connection: + type: bigquery + method: service-account + keyfile: {{ env_var('BIGQUERY_KEYFILE') }} + check_import: false + state_connection: + type: duckdb + + inttest_clickhouse_cloud: + connection: + type: clickhouse + host: {{ env_var("CLICKHOUSE_CLOUD_HOST") }} + port: 8443 + username: {{ env_var("CLICKHOUSE_CLOUD_USERNAME") }} + password: {{ env_var("CLICKHOUSE_CLOUD_PASSWORD") }} + connect_timeout: 30 + connection_pool_options: + retries: 5 + check_import: false + state_connection: + type: duckdb + + inttest_athena: + connection: + type: athena + aws_access_key_id: {{ env_var("AWS_ACCESS_KEY_ID") }} + aws_secret_access_key: {{ env_var("AWS_SECRET_ACCESS_KEY") }} + region_name: {{ env_var("AWS_REGION") }} + work_group: {{ env_var("ATHENA_WORK_GROUP", "primary") }} + s3_warehouse_location: {{ env_var("ATHENA_S3_WAREHOUSE_LOCATION", "") }} + check_import: false + state_connection: + type: duckdb + + inttest_fabric: + connection: + type: fabric + driver: pyodbc + host: {{ env_var("FABRIC_HOST") }} + user: {{ env_var("FABRIC_CLIENT_ID") }} + password: {{ env_var("FABRIC_CLIENT_SECRET") }} + database: {{ env_var("FABRIC_DATABASE") }} + tenant_id: {{ env_var("FABRIC_TENANT_ID") }} + workspace_id: {{ env_var("FABRIC_WORKSPACE_ID") }} + odbc_properties: + Authentication: ActiveDirectoryServicePrincipal + check_import: false + state_connection: + type: duckdb + + inttest_gcp_postgres: + connection: + type: gcp_postgres + instance_connection_string: {{ env_var("GCP_POSTGRES_INSTANCE_CONNECTION_STRING") }} + user: {{ env_var("GCP_POSTGRES_USER") }} + password: {{ env_var("GCP_POSTGRES_PASSWORD") }} + keyfile_json: {{ env_var("GCP_POSTGRES_KEYFILE_JSON", "") }} + db: {{ env_var("GCP_POSTGRES_DATABASE") }} + enable_iam_auth: true + check_import: false + + +model_defaults: + dialect: duckdb diff --git a/tests/core/engine_adapter/integration/conftest.py 
b/tests/core/engine_adapter/integration/conftest.py new file mode 100644 index 0000000000..3fb4bc15f1 --- /dev/null +++ b/tests/core/engine_adapter/integration/conftest.py @@ -0,0 +1,156 @@ +from __future__ import annotations + +import typing as t +import pytest +import pathlib +import os +import logging +from pytest import FixtureRequest + +from sqlmesh import Config, EngineAdapter +from sqlmesh.core.constants import SQLMESH_PATH +from sqlmesh.core.config.connection import ( + ConnectionConfig, + AthenaConnectionConfig, + DuckDBConnectionConfig, +) +from sqlmesh.core.engine_adapter import AthenaEngineAdapter +from sqlmesh.core.config import load_config_from_paths + +from tests.core.engine_adapter.integration import ( + TestContext, + generate_pytest_params, + ENGINES, + IntegrationTestEngine, +) + +logger = logging.getLogger(__name__) + + +@pytest.fixture +def config(tmp_path: pathlib.Path) -> Config: + return load_config_from_paths( + Config, + project_paths=[ + pathlib.Path(os.path.join(os.path.dirname(__file__), "config.yaml")), + ], + personal_paths=[(SQLMESH_PATH / "config.yaml").expanduser()], + variables={"tmp_path": str(tmp_path)}, + ) + + +@pytest.fixture +def create_engine_adapter( + request: pytest.FixtureRequest, + testrun_uid: str, + config: Config, +) -> t.Callable[[str, str], EngineAdapter]: + def _create(engine_name: str, gateway: str) -> EngineAdapter: + assert gateway in config.gateways + connection_config = config.gateways[gateway].connection + assert isinstance(connection_config, ConnectionConfig) + + engine_adapter = connection_config.create_engine_adapter() + + if engine_name == "athena": + assert isinstance(connection_config, AthenaConnectionConfig) + assert isinstance(engine_adapter, AthenaEngineAdapter) + + # S3 files need to go into a unique location for each test run + # This is because DROP TABLE on a Hive table just drops the table from the metastore + # The files still exist in S3, so if you CREATE TABLE to the same location, the old 
data shows back up + # Note that the `testrun_uid` fixture comes from the xdist plugin + if connection_config.s3_warehouse_location: + engine_adapter.s3_warehouse_location = os.path.join( + connection_config.s3_warehouse_location, + f"testrun_{testrun_uid}", + request.node.originalname, + ) + + # Trino: If we batch up the requests then when running locally we get a table not found error after creating the + # table and then immediately after trying to insert rows into it. There seems to be a delay between when the + # metastore is made aware of the table and when it responds that it exists. I'm hoping this is not an issue + # in practice on production machines. + if engine_name == "trino": + engine_adapter.DEFAULT_BATCH_SIZE = 1 + + # Clear out any local db files that may have been left over from previous runs + if engine_name == "duckdb": + assert isinstance(connection_config, DuckDBConnectionConfig) + for raw_path in [ + v for v in (connection_config.catalogs or {}).values() if isinstance(v, str) + ]: + pathlib.Path(raw_path).unlink(missing_ok=True) + + return engine_adapter + + return _create + + +@pytest.fixture +def create_test_context( + request: FixtureRequest, + create_engine_adapter: t.Callable[[str, str], EngineAdapter], + tmp_path: pathlib.Path, +) -> t.Callable[[IntegrationTestEngine, str, str, str], t.Iterable[TestContext]]: + def _create( + engine: IntegrationTestEngine, gateway: str, test_type: str, table_format: str + ) -> t.Iterable[TestContext]: + is_remote = request.node.get_closest_marker("remote") is not None + + engine_adapter = create_engine_adapter(engine.engine, gateway) + + ctx = TestContext( + test_type, + engine_adapter, + f"{engine.engine}_{table_format}", + gateway, + tmp_path=tmp_path, + is_remote=is_remote, + ) + + try: + ctx.init() + except: + # pytest-retry doesn't work if there are errors in fixture setup (ref: https://github.com/str0zzapreti/pytest-retry/issues/33 ) + # what we can do is log the exception and return a 
partially-initialized context to the test, which should + # throw an exception when it tries to access something that didn't init properly and thus trigger pytest-retry to retry + logger.exception("Context init failed") + + with ctx.engine_adapter.session({}): + yield ctx + + try: + ctx.cleanup() + except: + # We need to catch this exception because if there is an error during teardown, pytest-retry aborts immediately + # instead of retrying + logger.exception("Context cleanup failed") + + return _create + + +@pytest.fixture( + params=list(generate_pytest_params(ENGINES, query=True, show_variant_in_test_id=False)) +) +def ctx( + request: FixtureRequest, + create_test_context: t.Callable[[IntegrationTestEngine, str, str, str], t.Iterable[TestContext]], +) -> t.Iterable[TestContext]: + yield from create_test_context(*request.param) + + +@pytest.fixture(params=list(generate_pytest_params(ENGINES, query=False, df=True))) +def ctx_df( + request: FixtureRequest, + create_test_context: t.Callable[[IntegrationTestEngine, str, str, str], t.Iterable[TestContext]], +) -> t.Iterable[TestContext]: + yield from create_test_context(*request.param) + + +@pytest.fixture(params=list(generate_pytest_params(ENGINES, query=True, df=True))) +def ctx_query_and_df( + request: FixtureRequest, + create_test_context: t.Callable[[IntegrationTestEngine, str, str, str], t.Iterable[TestContext]], +) -> t.Iterable[TestContext]: + yield from create_test_context(*request.param) diff --git a/tests/core/engine_adapter/integration/docker/_common-hive.yaml b/tests/core/engine_adapter/integration/docker/_common-hive.yaml new file mode 100644 index 0000000000..59d3cd80a0 --- /dev/null +++ b/tests/core/engine_adapter/integration/docker/_common-hive.yaml @@ -0,0 +1,42 @@ +services: + + # Postgres backing storage for Hive metastores + metastore: + image: postgres + environment: + POSTGRES_USER: hive + POSTGRES_PASSWORD: hive + volumes: + - ./init-metastore-db.sql:/docker-entrypoint-initdb.d/initdb.sql + command: -c max_connections=500 + 
+ # S3-style object storage + minio: + image: 'minio/minio:RELEASE.2022-05-26T05-48-41Z' + ports: + - '9000:9000' + - '9001:9001' + environment: + MINIO_ACCESS_KEY: minio + MINIO_SECRET_KEY: minio123 + command: server /data --console-address ":9001" + + # Set up minio with default buckets / paths + mc-job: + image: 'minio/mc:RELEASE.2022-05-09T04-08-26Z' + entrypoint: | + /bin/bash -c " + sleep 5; + /usr/bin/mc config --quiet host add myminio http://minio:9000 minio minio123; + /usr/bin/mc mb --quiet myminio/trino/datalake; + /usr/bin/mc mb --quiet myminio/trino/datalake_iceberg; + /usr/bin/mc mb --quiet myminio/trino/datalake_delta; + /usr/bin/mc mb --quiet myminio/trino/testing; + /usr/bin/mc mb --quiet myminio/trino/testing_iceberg; + /usr/bin/mc mb --quiet myminio/trino/testing_delta; + /usr/bin/mc mb --quiet myminio/spark/datalake; + /usr/bin/mc mb --quiet myminio/spark/testing; + /usr/bin/mc mb --quiet myminio/nessie/warehouse; + " + depends_on: + - minio \ No newline at end of file diff --git a/tests/core/engine_adapter/integration/docker/clickhouse/config.xml b/tests/core/engine_adapter/integration/docker/clickhouse/config.xml new file mode 100644 index 0000000000..0e8915cda4 --- /dev/null +++ b/tests/core/engine_adapter/integration/docker/clickhouse/config.xml @@ -0,0 +1,40 @@ + + + 1 + warning + + 0.0.0.0 + + + + + clickhouse-1 + 9000 + + + clickhouse-2 + 9000 + + + clickhouse-3 + 9000 + + + + + + 01 + + cluster1 + + + + keeper + 2181 + + + + jdbc-bridge + 9019 + + \ No newline at end of file diff --git a/tests/core/engine_adapter/integration/docker/clickhouse/keeper.xml b/tests/core/engine_adapter/integration/docker/clickhouse/keeper.xml new file mode 100644 index 0000000000..61a2ab62b2 --- /dev/null +++ b/tests/core/engine_adapter/integration/docker/clickhouse/keeper.xml @@ -0,0 +1,18 @@ + + + 1 + warning + + 0.0.0.0 + + 2181 + 1 + + + 1 + keeper + 9234 + + + + \ No newline at end of file diff --git 
a/tests/core/engine_adapter/integration/docker/compose.clickhouse.yaml b/tests/core/engine_adapter/integration/docker/compose.clickhouse.yaml new file mode 100644 index 0000000000..4ee3355afd --- /dev/null +++ b/tests/core/engine_adapter/integration/docker/compose.clickhouse.yaml @@ -0,0 +1,47 @@ +x-clickhouse-server: &clickhouse-server + image: 'clickhouse/clickhouse-server:24.7' + volumes: + # note: this is deliberately published as docker_related_config.xml to replace the file already in the image + # which tries to do things like configure ipv6 and throw thousands of errors + - ./clickhouse/config.xml:/etc/clickhouse-server/config.d/docker_related_config.xml + depends_on: + - keeper + - jdbc-bridge + +x-clickhouse-server-environment: &default-environment + CLICKHOUSE_DEFAULT_ACCESS_MANAGEMENT: '1' + CLICKHOUSE_USER: clickhouse + CLICKHOUSE_PASSWORD: clickhouse + +services: + standalone: + image: 'clickhouse/clickhouse-server:24.7' + environment: + <<: *default-environment + ports: + - 8122:8123 + clickhouse-1: + <<: *clickhouse-server + environment: + <<: *default-environment + MACRO_REPLICA: '01' + ports: + - 8123:8123 + clickhouse-2: + <<: *clickhouse-server + environment: + <<: *default-environment + MACRO_REPLICA: '02' + clickhouse-3: + <<: *clickhouse-server + environment: + <<: *default-environment + MACRO_REPLICA: '03' + keeper: + image: clickhouse/clickhouse-keeper:24.7 + volumes: + - ./clickhouse/keeper.xml:/etc/clickhouse-keeper/keeper_config.d/keeper.xml + # This is just so you can use a client like DBeaver without seeing this error everywhere: + # DB::Exception: clickhouse-jdbc-bridge is not running. 
Please, start it manually + jdbc-bridge: + image: clickhouse/jdbc-bridge:2.1 \ No newline at end of file diff --git a/tests/core/engine_adapter/integration/docker/compose.mssql.yaml b/tests/core/engine_adapter/integration/docker/compose.mssql.yaml new file mode 100644 index 0000000000..c4ead5dca7 --- /dev/null +++ b/tests/core/engine_adapter/integration/docker/compose.mssql.yaml @@ -0,0 +1,8 @@ +services: + mssql: + image: mcr.microsoft.com/mssql/server:2019-latest + ports: + - '1433:1433' + environment: + SA_PASSWORD: 1StrongPwd@@ + ACCEPT_EULA: "Y" \ No newline at end of file diff --git a/tests/core/engine_adapter/integration/docker/compose.mysql.yaml b/tests/core/engine_adapter/integration/docker/compose.mysql.yaml new file mode 100644 index 0000000000..1916a3dc44 --- /dev/null +++ b/tests/core/engine_adapter/integration/docker/compose.mysql.yaml @@ -0,0 +1,7 @@ +services: + mysql: + image: mysql:8.1 + ports: + - '3306:3306' + environment: + MYSQL_ROOT_PASSWORD: mysql \ No newline at end of file diff --git a/tests/core/engine_adapter/integration/docker/compose.postgres.yaml b/tests/core/engine_adapter/integration/docker/compose.postgres.yaml new file mode 100644 index 0000000000..941c245cd9 --- /dev/null +++ b/tests/core/engine_adapter/integration/docker/compose.postgres.yaml @@ -0,0 +1,7 @@ +services: + postgres: + image: postgres + ports: + - '5432:5432' + environment: + POSTGRES_PASSWORD: postgres \ No newline at end of file diff --git a/tests/core/engine_adapter/integration/docker/compose.risingwave.yaml b/tests/core/engine_adapter/integration/docker/compose.risingwave.yaml new file mode 100644 index 0000000000..6835b42501 --- /dev/null +++ b/tests/core/engine_adapter/integration/docker/compose.risingwave.yaml @@ -0,0 +1,5 @@ +services: + risingwave: + image: risingwavelabs/risingwave:nightly-20250225 + ports: + - "4566:4566" \ No newline at end of file diff --git a/tests/core/engine_adapter/integration/docker/compose.spark.yaml 
b/tests/core/engine_adapter/integration/docker/compose.spark.yaml new file mode 100644 index 0000000000..bf132474b0 --- /dev/null +++ b/tests/core/engine_adapter/integration/docker/compose.spark.yaml @@ -0,0 +1,51 @@ +# this needs to be duplicated here and in the Trino compose file because Docker +# refuses to implement support for YAMl anchors in the `include:` mechanism +# ref: https://github.com/docker/compose/issues/5621 +x-hive-metastore-environments: &hive_metastore_environments + S3_ENDPOINT: http://minio:9000 + S3_ACCESS_KEY: minio + S3_SECRET_KEY: minio123 + S3_PATH_STYLE_ACCESS: "true" + REGION: "" + GOOGLE_CLOUD_KEY_FILE_PATH: "" + AZURE_ADL_CLIENT_ID: "" + AZURE_ADL_CREDENTIAL: "" + AZURE_ADL_REFRESH_URL: "" + AZURE_ABFS_STORAGE_ACCOUNT: "" + AZURE_ABFS_ACCESS_KEY: "" + AZURE_WASB_STORAGE_ACCOUNT: "" + AZURE_ABFS_OAUTH: "" + AZURE_ABFS_OAUTH_TOKEN_PROVIDER: "" + AZURE_ABFS_OAUTH_CLIENT_ID: "" + AZURE_ABFS_OAUTH_SECRET: "" + AZURE_ABFS_OAUTH_ENDPOINT: "" + AZURE_WASB_ACCESS_KEY: "" + +include: + - ./_common-hive.yaml + +services: + spark: + build: + context: ./spark + command: /opt/bitnami/spark/sbin/start-connect-server.sh --packages org.apache.spark:spark-connect_2.12:3.5.0 + ports: + - '15000-15100:15000-15100' + volumes: + - ./spark/conf/spark-defaults.conf:/opt/bitnami/spark/conf/spark-defaults.conf + - ./spark/conf/hive-site.xml:/opt/bitnami/spark/conf/hive-site.xml + depends_on: + - spark-hive-metastore + + spark-hive-metastore: + hostname: spark-hive-metastore + image: 'starburstdata/hive:3.1.2-e.15' + environment: + HIVE_METASTORE_DRIVER: org.postgresql.Driver + HIVE_METASTORE_JDBC_URL: jdbc:postgresql://metastore:5432/metastore + HIVE_METASTORE_USER: hive + HIVE_METASTORE_PASSWORD: hive + HIVE_METASTORE_WAREHOUSE_DIR: s3://spark/ + <<: *hive_metastore_environments + depends_on: + - metastore diff --git a/tests/core/engine_adapter/integration/docker/compose.trino.yaml b/tests/core/engine_adapter/integration/docker/compose.trino.yaml new file 
mode 100644 index 0000000000..f5ae25fa4e --- /dev/null +++ b/tests/core/engine_adapter/integration/docker/compose.trino.yaml @@ -0,0 +1,110 @@ +# this needs to be duplicated here and in the Spark compose file because Docker +# refuses to implement support for YAMl anchors in the `include:` mechanism +# ref: https://github.com/docker/compose/issues/5621 +x-hive-metastore-environments: &hive_metastore_environments + S3_ENDPOINT: http://minio:9000 + S3_ACCESS_KEY: minio + S3_SECRET_KEY: minio123 + S3_PATH_STYLE_ACCESS: "true" + REGION: "" + GOOGLE_CLOUD_KEY_FILE_PATH: "" + AZURE_ADL_CLIENT_ID: "" + AZURE_ADL_CREDENTIAL: "" + AZURE_ADL_REFRESH_URL: "" + AZURE_ABFS_STORAGE_ACCOUNT: "" + AZURE_ABFS_ACCESS_KEY: "" + AZURE_WASB_STORAGE_ACCOUNT: "" + AZURE_ABFS_OAUTH: "" + AZURE_ABFS_OAUTH_TOKEN_PROVIDER: "" + AZURE_ABFS_OAUTH_CLIENT_ID: "" + AZURE_ABFS_OAUTH_SECRET: "" + AZURE_ABFS_OAUTH_ENDPOINT: "" + AZURE_WASB_ACCESS_KEY: "" + +include: + - ./_common-hive.yaml + +services: + + # Trino Stack + trino: + image: 'trinodb/trino:475' + ports: + - '8080:8080' + volumes: + - ./trino/catalog:/etc/trino/catalog + depends_on: + - minio + - metastore + + trino-datalake-hive-metastore: + image: 'starburstdata/hive:3.1.3-e.11' + environment: + HIVE_METASTORE_DRIVER: org.postgresql.Driver + HIVE_METASTORE_JDBC_URL: jdbc:postgresql://metastore:5432/datalake_metastore + HIVE_METASTORE_USER: hive + HIVE_METASTORE_PASSWORD: hive + HIVE_METASTORE_WAREHOUSE_DIR: s3://trino/datalake + <<: *hive_metastore_environments + depends_on: + - metastore + + trino-testing-hive-metastore: + image: 'starburstdata/hive:3.1.3-e.11' + environment: + HIVE_METASTORE_DRIVER: org.postgresql.Driver + HIVE_METASTORE_JDBC_URL: jdbc:postgresql://metastore:5432/testing_metastore + HIVE_METASTORE_USER: hive + HIVE_METASTORE_PASSWORD: hive + HIVE_METASTORE_WAREHOUSE_DIR: s3://trino/testing + <<: *hive_metastore_environments + depends_on: + - metastore + + trino-datalake-iceberg-hive-metastore: + image: 
'starburstdata/hive:3.1.3-e.11' + environment: + HIVE_METASTORE_DRIVER: org.postgresql.Driver + HIVE_METASTORE_JDBC_URL: jdbc:postgresql://metastore:5432/datalake_iceberg_metastore + HIVE_METASTORE_USER: hive + HIVE_METASTORE_PASSWORD: hive + HIVE_METASTORE_WAREHOUSE_DIR: s3://trino/datalake_iceberg + <<: *hive_metastore_environments + depends_on: + - metastore + + trino-datalake-delta-hive-metastore: + image: 'starburstdata/hive:3.1.3-e.11' + environment: + HIVE_METASTORE_DRIVER: org.postgresql.Driver + HIVE_METASTORE_JDBC_URL: jdbc:postgresql://metastore:5432/datalake_delta_metastore + HIVE_METASTORE_USER: hive + HIVE_METASTORE_PASSWORD: hive + HIVE_METASTORE_WAREHOUSE_DIR: s3://trino/datalake_delta + <<: *hive_metastore_environments + depends_on: + - metastore + + nessie: + image: ghcr.io/projectnessie/nessie:0.102.2 + restart: on-failure + ports: + - '19120:19120' + environment: + nessie.version.store.type: JDBC2 + nessie.version.store.persist.jdbc.datasource: postgresql + quarkus.datasource.postgresql.jdbc.url: jdbc:postgresql://metastore:5432/nessie + quarkus.datasource.postgresql.username: hive + quarkus.datasource.postgresql.password: hive + nessie.catalog.default-warehouse: warehouse + nessie.catalog.warehouses.warehouse.location: s3://nessie/warehouse + nessie.catalog.service.s3.default-options.region: us-east-1 + nessie.catalog.service.s3.default-options.path-style-access: 'true' + nessie.catalog.service.s3.default-options.access-key: urn:nessie-secret:quarkus:nessie.catalog.secrets.access-key + nessie.catalog.secrets.access-key.name: minio + nessie.catalog.secrets.access-key.secret: minio123 + nessie.catalog.service.s3.default-options.endpoint: http://minio:9000/ + + depends_on: + - metastore + diff --git a/tests/core/engine_adapter/integration/docker/init-metastore-db.sql b/tests/core/engine_adapter/integration/docker/init-metastore-db.sql new file mode 100644 index 0000000000..dd84670830 --- /dev/null +++ 
b/tests/core/engine_adapter/integration/docker/init-metastore-db.sql @@ -0,0 +1,14 @@ +-- Settings +alter system set max_connections to '500'; + +-- Spark +create database metastore; + +-- Trino +create database datalake_metastore; +create database datalake_iceberg_metastore; +create database datalake_delta_metastore; +create database testing_metastore; +create database testing_iceberg_metastore; +create database testing_delta_metastore; +create database nessie; \ No newline at end of file diff --git a/tests/core/engine_adapter/spark/Dockerfile b/tests/core/engine_adapter/integration/docker/spark/Dockerfile similarity index 94% rename from tests/core/engine_adapter/spark/Dockerfile rename to tests/core/engine_adapter/integration/docker/spark/Dockerfile index 7fb39b840c..cfbe7d1e88 100644 --- a/tests/core/engine_adapter/spark/Dockerfile +++ b/tests/core/engine_adapter/integration/docker/spark/Dockerfile @@ -1,4 +1,4 @@ -FROM docker.io/bitnami/spark:3.5 +FROM bitnamilegacy/spark:3.5.2 USER root RUN install_packages curl USER 1001 diff --git a/tests/core/engine_adapter/spark/conf/hive-site.xml b/tests/core/engine_adapter/integration/docker/spark/conf/hive-site.xml similarity index 100% rename from tests/core/engine_adapter/spark/conf/hive-site.xml rename to tests/core/engine_adapter/integration/docker/spark/conf/hive-site.xml diff --git a/tests/core/engine_adapter/spark/conf/spark-defaults.conf b/tests/core/engine_adapter/integration/docker/spark/conf/spark-defaults.conf similarity index 100% rename from tests/core/engine_adapter/spark/conf/spark-defaults.conf rename to tests/core/engine_adapter/integration/docker/spark/conf/spark-defaults.conf diff --git a/tests/core/engine_adapter/integration/docker/trino/catalog/datalake.properties b/tests/core/engine_adapter/integration/docker/trino/catalog/datalake.properties new file mode 100644 index 0000000000..6ee067fa61 --- /dev/null +++ b/tests/core/engine_adapter/integration/docker/trino/catalog/datalake.properties @@ -0,0 
+1,13 @@ +connector.name=hive + +hive.metastore.uri=thrift://trino-datalake-hive-metastore:9083 +hive.metastore.thrift.client.connect-timeout=10s +hive.metastore.thrift.client.read-timeout=10s +hive.storage-format=PARQUET + +fs.native-s3.enabled=true +s3.endpoint=http://minio:9000 +s3.path-style-access=true +s3.aws-access-key=minio +s3.aws-secret-key=minio123 +s3.region=us-east-1 \ No newline at end of file diff --git a/tests/core/engine_adapter/integration/docker/trino/catalog/datalake_delta.properties b/tests/core/engine_adapter/integration/docker/trino/catalog/datalake_delta.properties new file mode 100644 index 0000000000..23f5ec0835 --- /dev/null +++ b/tests/core/engine_adapter/integration/docker/trino/catalog/datalake_delta.properties @@ -0,0 +1,15 @@ +connector.name=delta_lake + +hive.metastore.uri=thrift://trino-datalake-delta-hive-metastore:9083 +hive.metastore.thrift.client.connect-timeout=10s +hive.metastore.thrift.client.read-timeout=10s + +delta.enable-non-concurrent-writes=true +delta.hive-catalog-name=datalake + +fs.native-s3.enabled=true +s3.endpoint=http://minio:9000 +s3.path-style-access=true +s3.aws-access-key=minio +s3.aws-secret-key=minio123 +s3.region=us-east-1 \ No newline at end of file diff --git a/tests/core/engine_adapter/integration/docker/trino/catalog/datalake_iceberg.properties b/tests/core/engine_adapter/integration/docker/trino/catalog/datalake_iceberg.properties new file mode 100644 index 0000000000..b2dc6ecb03 --- /dev/null +++ b/tests/core/engine_adapter/integration/docker/trino/catalog/datalake_iceberg.properties @@ -0,0 +1,19 @@ +connector.name=iceberg + +hive.metastore.uri=thrift://trino-datalake-iceberg-hive-metastore:9083 +hive.metastore.thrift.client.connect-timeout=10s +hive.metastore.thrift.client.read-timeout=10s + +# note: we have to use a Hive metastore instead of the REST catalog because +# as at 2024-02-16 its the only one that supports views +iceberg.catalog.type=hive_metastore +iceberg.file-format=PARQUET 
+iceberg.metadata-cache.enabled=false +iceberg.hive-catalog-name=datalake + +fs.native-s3.enabled=true +s3.endpoint=http://minio:9000 +s3.path-style-access=true +s3.aws-access-key=minio +s3.aws-secret-key=minio123 +s3.region=us-east-1 \ No newline at end of file diff --git a/tests/core/engine_adapter/integration/docker/trino/catalog/datalake_nessie.properties b/tests/core/engine_adapter/integration/docker/trino/catalog/datalake_nessie.properties new file mode 100644 index 0000000000..6ac1786252 --- /dev/null +++ b/tests/core/engine_adapter/integration/docker/trino/catalog/datalake_nessie.properties @@ -0,0 +1,14 @@ +connector.name=iceberg +iceberg.catalog.type=rest +iceberg.rest-catalog.security=NONE +iceberg.rest-catalog.uri=http://nessie:19120/iceberg/ +iceberg.rest-catalog.vended-credentials-enabled=false +iceberg.metadata-cache.enabled=false +iceberg.hive-catalog-name=datalake + +fs.native-s3.enabled=true +s3.endpoint=http://minio:9000 +s3.path-style-access=true +s3.aws-access-key=minio +s3.aws-secret-key=minio123 +s3.region=us-east-1 \ No newline at end of file diff --git a/tests/core/engine_adapter/integration/docker/trino/catalog/testing.properties b/tests/core/engine_adapter/integration/docker/trino/catalog/testing.properties new file mode 100644 index 0000000000..2fd90f8dcc --- /dev/null +++ b/tests/core/engine_adapter/integration/docker/trino/catalog/testing.properties @@ -0,0 +1,13 @@ +connector.name=hive + +hive.metastore.uri=thrift://trino-testing-hive-metastore:9083 +hive.metastore.thrift.client.connect-timeout=10s +hive.metastore.thrift.client.read-timeout=10s +hive.storage-format=PARQUET + +fs.native-s3.enabled=true +s3.endpoint=http://minio:9000 +s3.path-style-access=true +s3.aws-access-key=minio +s3.aws-secret-key=minio123 +s3.region=us-east-1 \ No newline at end of file diff --git a/tests/core/engine_adapter/integration/test_freshness.py b/tests/core/engine_adapter/integration/test_freshness.py new file mode 100644 index 0000000000..e5ee574e7e 
--- /dev/null +++ b/tests/core/engine_adapter/integration/test_freshness.py @@ -0,0 +1,488 @@ +# type: ignore +from __future__ import annotations + +import pathlib +import typing as t +from datetime import datetime, timedelta +from IPython.utils.capture import capture_output + +import time_machine +from pytest_mock.plugin import MockerFixture + +import pytest +import time_machine + +import sqlmesh +from sqlmesh import Config, Context +from sqlmesh.utils.date import now, to_datetime +from sqlmesh.utils.errors import SignalEvalError +from tests.core.engine_adapter.integration import ( + TestContext, + TEST_SCHEMA, +) +from tests.utils.test_helpers import use_terminal_console + +EVALUATION_SPY = None + + +@pytest.fixture(autouse=True) +def _skip_snowflake(ctx: TestContext): + if ctx.dialect == "snowflake": + # these tests use callbacks that need to run db queries within a time_travel context that changes the system time to be in the future + # this causes invalid JWT's to be generated when the callbacks try to run a db query + pytest.skip( + "snowflake.connector generates an invalid JWT when time_travel changes the system time" + ) + + +# Mock the snapshot evaluator's evaluate function to count the number of times it is called +@pytest.fixture(autouse=True, scope="function") +def _install_evaluation_spy(mocker: MockerFixture): + global EVALUATION_SPY + EVALUATION_SPY = mocker.spy(sqlmesh.core.snapshot.evaluator.SnapshotEvaluator, "evaluate") + yield + EVALUATION_SPY = None + + +def assert_snapshot_last_altered_ts( + context: Context, + snapshot_id: str, + last_altered_ts: datetime, + dev_last_altered_ts: t.Optional[datetime] = None, +): + """ + Ensure that prod and dev last altered timestamps of a snapshot are as expected. 
+ """ + snapshot = context.state_sync.get_snapshots([snapshot_id])[snapshot_id] + + if snapshot.is_external: + return + + assert to_datetime(snapshot.last_altered_ts).replace(microsecond=0) == last_altered_ts.replace( + microsecond=0 + ) + + if dev_last_altered_ts: + assert to_datetime(snapshot.dev_last_altered_ts).replace( + microsecond=0 + ) == dev_last_altered_ts.replace(microsecond=0) + + +def assert_model_evaluation( + lambda_func, was_evaluated: bool = True, day_delta: int = 0, model_evaluations: int = 1 +): + """ + Ensure that a model was evaluated by checking the freshness signal and that + the evaluation function was called the expected number of times. + """ + EVALUATION_SPY.reset_mock() + timestamp = now(minute_floor=False) + timedelta(days=day_delta) + with time_machine.travel(timestamp, tick=False): + with capture_output() as output: + plan_or_run_result = lambda_func() + + evaluate_function_called = EVALUATION_SPY.call_count == model_evaluations + signal_was_checked = "Checking signals for" in output.stdout + + assert signal_was_checked + if was_evaluated: + assert "All ready" in output.stdout + assert evaluate_function_called + else: + assert "None ready" in output.stdout + assert not evaluate_function_called + + return timestamp, plan_or_run_result + + +def create_model( + name: str, schema: str, query: str, path: pathlib.Path, signals: str = "freshness()" +): + """ + Create a freshness model with the given name, path, and query. 
+ """ + model_name = f"{schema}.{name}" + model_path = path / "models" / f"{name}.sql" + (path / "models").mkdir(parents=True, exist_ok=True) + model_path.write_text( + f""" + MODEL ( + name {model_name}, + start '2024-01-01', + kind FULL, + signals ( + {signals}, + ) + ); + + {query} + """ + ) + + return model_name, model_path + + +def initialize_context( + ctx: TestContext, tmp_path: pathlib.Path, num_external_models: int = 1 +) -> t.Tuple[Context, str, t.List[str]]: + """ + Initialize a context by creating a schema and external models. + """ + adapter = ctx.engine_adapter + if not adapter.SUPPORTS_METADATA_TABLE_LAST_MODIFIED_TS: + pytest.skip("This test only runs for engines that support metadata-based freshness") + + # Create & initialize schema + schema = ctx.add_test_suffix(TEST_SCHEMA) + ctx._schemas.append(schema) + adapter.create_schema(schema) + + # Create & initialize external models + external_tables = [] + + yaml_content = "" + for i in range(1, num_external_models + 1): + external_table = f"{schema}.external_table{i}" + external_tables.append(f"{schema}.external_table{i}") + adapter.execute( + f"CREATE TABLE {external_table} AS (SELECT {i} AS col{i})", + quote_identifiers=False, + ) + + yaml_content = ( + yaml_content + + f""" +- name: {external_table} + columns: + col{i}: int + +""" + ) + + external_models_yaml = tmp_path / "external_models.yaml" + external_models_yaml.write_text(yaml_content) + + # Initialize context + def _set_config(gateway: str, config: Config) -> None: + config.model_defaults.dialect = ctx.dialect + + context = ctx.create_context(path=tmp_path, config_mutator=_set_config) + + return context, schema, external_tables + + +@use_terminal_console +def test_external_model_freshness(ctx: TestContext, tmp_path: pathlib.Path, mocker: MockerFixture): + adapter = ctx.engine_adapter + context, schema, (external_table1, external_table2) = initialize_context( + ctx, tmp_path, num_external_models=2 + ) + + # Create model that depends on 
external models + model_name, model_path = create_model( + "new_model", + schema, + f"SELECT col1 * col2 AS col FROM {external_table1}, {external_table2}", + tmp_path, + ) + + context.load() + + # Case 1: Model is evaluated for the first plan + prod_plan_ts_1, prod_plan_1 = assert_model_evaluation( + lambda: context.plan(auto_apply=True, no_prompts=True) + ) + + prod_snapshot_id = next(iter(prod_plan_1.context_diff.new_snapshots)) + assert_snapshot_last_altered_ts(context, prod_snapshot_id, last_altered_ts=prod_plan_ts_1) + + # Case 2: Model is NOT evaluated on run if external models are not fresh + assert_model_evaluation(lambda: context.run(), was_evaluated=False, day_delta=1) + + # Case 3: Differentiate last_altered_ts between snapshots with shared version + # For instance, creating a FORWARD_ONLY change in dev (reusing the version but creating a dev preview) should not cause + # any side effects to the prod snapshot's last_altered_ts hydration + model_path.write_text(model_path.read_text().replace("col1 * col2", "col1 + col2")) + context.load() + dev_plan_ts = now(minute_floor=False) + timedelta(days=2) + with time_machine.travel(dev_plan_ts, tick=False): + dev_plan = context.plan( + environment="dev", forward_only=True, auto_apply=True, no_prompts=True + ) + + context.state_sync.clear_cache() + dev_snapshot_id = next(iter(dev_plan.context_diff.new_snapshots)) + assert_snapshot_last_altered_ts( + context, + dev_snapshot_id, + last_altered_ts=prod_plan_ts_1, + dev_last_altered_ts=dev_plan_ts, + ) + assert_snapshot_last_altered_ts(context, prod_snapshot_id, last_altered_ts=prod_plan_ts_1) + + # Case 4: Model is evaluated on run if any external model is fresh + adapter.execute(f"INSERT INTO {external_table2} (col2) VALUES (3)", quote_identifiers=False) + assert_model_evaluation(lambda: context.run(), day_delta=2) + + # Case 5: Model is evaluated if changed (case 3) even if the external model is not fresh + model_path.write_text(model_path.read_text().replace("col1 
+ col2", "col1 * col2 * 5")) + context.load() + assert_model_evaluation( + lambda: context.plan(auto_apply=True, no_prompts=True), + day_delta=3, + ) + + # Case 6: Model is evaluated on a restatement plan even if the external model is not fresh + assert_model_evaluation( + lambda: context.plan(restate_models=[model_name], auto_apply=True, no_prompts=True), + day_delta=4, + ) + + +@use_terminal_console +def test_mixed_model_freshness(ctx: TestContext, tmp_path: pathlib.Path): + """ + Scenario: Freshness for a model that depends on both external and SQLMesh models + """ + + adapter = ctx.engine_adapter + context, schema, (external_table,) = initialize_context(ctx, tmp_path, num_external_models=1) + + # Create parent model that depends on the external model + parent_model_name, _ = create_model( + "parent_model", + schema, + f"SELECT col1 AS new_col FROM {external_table}", + tmp_path, + ) + + # First child model depends only on the parent model + create_model( + "child_model1", + schema, + f"SELECT new_col FROM {parent_model_name}", + tmp_path, + ) + + # Second child model depends on the parent model and the external table + create_model( + "child_model2", + schema, + f"SELECT col1 + new_col FROM {parent_model_name}, {external_table}", + tmp_path, + ) + + # Third model does not depend on any models, so it should only be evaluated once + create_model( + "child_model3", + schema, + f"SELECT 1 AS col", + tmp_path, + ) + + context.load() + + # Case 1: New models are evaluated when introduced in a plan + prod_plan_ts_1, prod_plan_1 = assert_model_evaluation( + lambda: context.plan(auto_apply=True, no_prompts=True), + model_evaluations=4, + ) + + for new_snapshot in prod_plan_1.context_diff.new_snapshots: + assert_snapshot_last_altered_ts(context, new_snapshot, last_altered_ts=prod_plan_ts_1) + + # Case 2: Mixed models are evaluated if the upstream models (sqlmesh or external) become fresh + adapter.execute(f"INSERT INTO {external_table} (col1) VALUES (2)", 
quote_identifiers=False) + + assert_model_evaluation( + lambda: context.run(), was_evaluated=True, day_delta=1, model_evaluations=3 + ) + + # Case 3: Mixed models are still evaluated if breaking changes are introduced + create_model( + "child_model2", + schema, + f"SELECT col1 * new_col FROM {parent_model_name}, {external_table}", + tmp_path, + ) + + context.load() + + prod_plan_ts_2, prod_plan_2 = assert_model_evaluation( + lambda: context.plan(auto_apply=True, no_prompts=True), + day_delta=1, + model_evaluations=1, + ) + + assert prod_plan_2.context_diff.modified_snapshots + + assert_snapshot_last_altered_ts( + context, next(iter(prod_plan_2.context_diff.new_snapshots)), last_altered_ts=prod_plan_ts_2 + ) + + +def test_missing_external_model_freshness(ctx: TestContext, tmp_path: pathlib.Path): + """ + Scenario: Freshness for a model that depends on an external model that is missing + """ + adapter = ctx.engine_adapter + context, schema, (external_table,) = initialize_context(ctx, tmp_path) + + # Create model that depends on the external model + create_model( + "new_model", + schema, + f"SELECT * FROM {external_table}", + tmp_path, + ) + + context.load() + context.plan(auto_apply=True, no_prompts=True) + + # Case: By dropping the external table, the freshness signal should raise an error + # instead of silently succeeding/failing + adapter.execute(f"DROP TABLE {external_table}", quote_identifiers=False) + + with time_machine.travel(now() + timedelta(days=1)): + with pytest.raises(SignalEvalError): + context.run() + + +@use_terminal_console +def test_check_ready_intervals(ctx: TestContext, tmp_path: pathlib.Path): + """ + Scenario: Ensure that freshness evaluates the "ready" intervals of the parent snapshots i.e their + missing intervals plus their signals applied. 
+ + """ + + def _write_user_signal(signal: str, tmp_path: pathlib.Path): + signal_code = f""" +import typing as t +from sqlmesh import signal + +@signal() +{signal} + """ + + test_signals = tmp_path / "signals/test_signals.py" + test_signals.parent.mkdir(parents=True, exist_ok=True) + test_signals.write_text(signal_code) + + context, schema, _ = initialize_context(ctx, tmp_path, num_external_models=0) + + _write_user_signal( + """ +def my_signal(batch): + return True + """, + tmp_path, + ) + + # Parent model depends on a custom signal + parent_model, _ = create_model( + "parent_model", + schema, + f"SELECT 1 AS col", + tmp_path, + signals="my_signal()", + ) + + # Create a new model that depends on the parent model + create_model( + "child_model", + schema, + f"SELECT * FROM {parent_model}", + tmp_path, + ) + + # Case 1: Both models are evaluated when introduced in a plan and subsequent runs, + # given that `my_signal()` always returns True. + context.load() + context.plan(auto_apply=True, no_prompts=True) + + assert_model_evaluation( + lambda: context.run(), + day_delta=2, + model_evaluations=2, + ) + + # Case 2: By changing the signal to return False, both models should not be evaluated. 
+ _write_user_signal( + """ +def my_signal(batch): + return False + """, + tmp_path, + ) + + context.load() + context.plan(auto_apply=True, no_prompts=True) + + assert_model_evaluation( + lambda: context.run(), + day_delta=3, + was_evaluated=False, + ) + + +@use_terminal_console +def test_registered_and_unregistered_external_models( + ctx: TestContext, tmp_path: pathlib.Path, mocker: MockerFixture +): + """ + Scenario: Ensure that external models are queried for their last modified timestamp + regardless of whether they are present in the "external_models.yaml" file (registered) or not (unregistered) + """ + + adapter = ctx.engine_adapter + context, schema, (registered_external_table,) = initialize_context( + ctx, tmp_path, num_external_models=1 + ) + + current_catalog = ctx.engine_adapter.get_current_catalog() + + def normalize_external_table_name(external_table_name) -> str: + from sqlglot import exp + + normalized = exp.normalize_table_name( + f"{current_catalog}.{external_table_name}", dialect=ctx.dialect + ) + return exp.table_name(normalized, dialect=ctx.dialect, identify=True) + + unregistered_external_table = f"{schema}.unregistered_external_table" + + adapter.execute( + f"CREATE TABLE {unregistered_external_table} AS (SELECT 1 AS col)", + quote_identifiers=False, + ) + + create_model( + "new_model", + schema, + f"SELECT * FROM {unregistered_external_table}, {registered_external_table}", + tmp_path, + ) + + context.load() + context.plan(auto_apply=True, no_prompts=True) + + spy = mocker.spy( + sqlmesh.core.engine_adapter.SnowflakeEngineAdapter, "get_table_last_modified_ts" + ) + assert_model_evaluation( + lambda: context.run(), + day_delta=1, + was_evaluated=False, + ) + + assert spy.call_args_list + + # The first argument of "get_table_last_modified_ts" is a list of external table names in normalized form + # Ensure that this contains both external tables (registered and unregistered) + assert sorted(spy.call_args[0][1]) == sorted( + [ + 
normalize_external_table_name(registered_external_table), + normalize_external_table_name(unregistered_external_table), + ] + ) diff --git a/tests/core/engine_adapter/integration/test_integration.py b/tests/core/engine_adapter/integration/test_integration.py new file mode 100644 index 0000000000..1fba346db3 --- /dev/null +++ b/tests/core/engine_adapter/integration/test_integration.py @@ -0,0 +1,4082 @@ +# type: ignore +from __future__ import annotations + +import pathlib +import re +import sys +import typing as t +import shutil +from datetime import datetime, timedelta, date +from unittest import mock +from unittest.mock import patch +import logging + + +import time_machine + +import numpy as np # noqa: TID253 +import pandas as pd # noqa: TID253 +import pytest +import pytz +import time_machine +from sqlglot import exp, parse_one +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers +from sqlglot.optimizer.qualify_columns import quote_identifiers + +from sqlmesh import Config, Context +from sqlmesh.cli.project_init import init_example_project +from sqlmesh.core.config.common import VirtualEnvironmentMode +from sqlmesh.core.config.connection import ConnectionConfig +import sqlmesh.core.dialect as d +from sqlmesh.core.environment import EnvironmentSuffixTarget +from sqlmesh.core.dialect import select_from_values +from sqlmesh.core.model import Model, load_sql_based_model +from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType +from sqlmesh.core.engine_adapter.mixins import RowDiffMixin, LogicalMergeMixin +from sqlmesh.core.model.definition import create_sql_model +from sqlmesh.core.plan import Plan +from sqlmesh.core.state_sync.db import EngineAdapterStateSync +from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory +from sqlmesh.utils.date import now, to_date, to_time_column +from sqlmesh.core.table_diff import TableDiff +from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.utils.pydantic import PydanticModel 
+from tests.conftest import SushiDataValidator +from tests.core.engine_adapter.integration import ( + TestContext, + MetadataResults, + TEST_SCHEMA, + wait_until, +) + +DATA_TYPE = exp.DataType.Type +VARCHAR_100 = exp.DataType.build("varchar(100)") + + +class PlanResults(PydanticModel): + plan: Plan + ctx: TestContext + schema_metadata: MetadataResults + internal_schema_metadata: MetadataResults + + @classmethod + def create(cls, plan: Plan, ctx: TestContext, schema_name: str): + schema_metadata = ctx.get_metadata_results(schema_name) + internal_schema_metadata = ctx.get_metadata_results(f"sqlmesh__{schema_name}") + return PlanResults( + plan=plan, + ctx=ctx, + schema_metadata=schema_metadata, + internal_schema_metadata=internal_schema_metadata, + ) + + def snapshot_for(self, model: Model) -> Snapshot: + return next((s for s in list(self.plan.snapshots.values()) if s.name == model.fqn)) + + def modified_snapshot_for(self, model: Model) -> Snapshot: + return next((s for s in list(self.plan.modified_snapshots.values()) if s.name == model.fqn)) + + def table_name_for( + self, snapshot_or_model: Snapshot | Model, is_deployable: bool = True + ) -> str: + snapshot = ( + snapshot_or_model + if isinstance(snapshot_or_model, Snapshot) + else self.snapshot_for(snapshot_or_model) + ) + table_name = snapshot.table_name(is_deployable) + return exp.to_table(table_name).this.sql(dialect=self.ctx.dialect) + + def dev_table_name_for(self, snapshot: Snapshot) -> str: + return self.table_name_for(snapshot, is_deployable=False) + + +def test_connection(ctx: TestContext): + cursor_from_connection = ctx.engine_adapter.connection.cursor() + cursor_from_connection.execute("SELECT 1") + assert cursor_from_connection.fetchone()[0] == 1 + + +def test_catalog_operations(ctx: TestContext): + if ( + ctx.engine_adapter.catalog_support.is_unsupported + or ctx.engine_adapter.catalog_support.is_single_catalog_only + ): + pytest.skip( + f"Engine adapter {ctx.engine_adapter.dialect} doesn't support 
catalog operations" + ) + + # use a unique name so that integration tests on cloud databases can run in parallel + catalog_name = "testing" if not ctx.is_remote else ctx.add_test_suffix("testing") + + ctx.create_catalog(catalog_name) + + current_catalog = ctx.engine_adapter.get_current_catalog().lower() + ctx.engine_adapter.set_current_catalog(catalog_name) + assert ctx.engine_adapter.get_current_catalog().lower() == catalog_name + ctx.engine_adapter.set_current_catalog(current_catalog) + assert ctx.engine_adapter.get_current_catalog().lower() == current_catalog + + # cleanup cloud databases since they persist between runs + if ctx.is_remote: + ctx.drop_catalog(catalog_name) + + +def test_drop_schema_catalog(ctx: TestContext, caplog): + def drop_schema_and_validate(schema_name: str): + ctx.engine_adapter.drop_schema(schema_name, cascade=True) + results = ctx.get_metadata_results(schema_name) + assert ( + len(results.tables) + == len(results.views) + == len(results.materialized_views) + == len(results.non_temp_tables) + == 0 + ) + + def create_objects_and_validate(schema_name: str): + ctx.engine_adapter.create_schema(schema_name) + ctx.engine_adapter.create_view(f"{schema_name}.test_view", parse_one("SELECT 1 as col")) + ctx.engine_adapter.create_table( + f"{schema_name}.test_table", {"col": exp.DataType.build("int")} + ) + ctx.engine_adapter.create_table( + f"{schema_name}.replace_table", {"col": exp.DataType.build("int")} + ) + ctx.engine_adapter.replace_query( + f"{schema_name}.replace_table", + parse_one("SELECT 1 as col"), + {"col": exp.DataType.build("int")}, + ) + results = ctx.get_metadata_results(schema_name) + assert len(results.tables) == 2 + assert len(results.views) == 1 + assert len(results.materialized_views) == 0 + assert len(results.non_temp_tables) == 2 + + if ctx.engine_adapter.catalog_support.is_unsupported: + pytest.skip( + f"Engine adapter {ctx.engine_adapter.dialect} doesn't support catalog operations" + ) + if ctx.dialect == "spark": + 
pytest.skip( + "Currently local spark is configured to have iceberg be the testing catalog and drop cascade doesn't work on iceberg. Skipping until we have time to fix." + ) + + catalog_name = "testing" if not ctx.is_remote else ctx.add_test_suffix("testing") + if ctx.dialect == "bigquery": + catalog_name = ctx.engine_adapter.get_current_catalog() + + catalog_name = normalize_identifiers(catalog_name, dialect=ctx.dialect).sql(dialect=ctx.dialect) + + ctx.create_catalog(catalog_name) + + schema = ctx.schema("drop_schema_catalog_test", catalog_name) + if ctx.engine_adapter.catalog_support.is_single_catalog_only: + with pytest.raises( + SQLMeshError, match="requires that all catalog operations be against a single catalog" + ): + drop_schema_and_validate(schema) + create_objects_and_validate(schema) + return + drop_schema_and_validate(schema) + create_objects_and_validate(schema) + + if ctx.is_remote: + ctx.drop_catalog(catalog_name) + + +def test_temp_table(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01"}, + {"id": 2, "ds": "2022-01-02"}, + {"id": 3, "ds": "2022-01-03"}, + ] + ) + table = ctx.table("example") + + with ctx.engine_adapter.temp_table( + ctx.input_data(input_data), table.sql(), table_format=ctx.default_table_format + ) as table_name: + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.tables) == 1 + assert len(results.non_temp_tables) == 0 + assert len(results.materialized_views) == 0 + ctx.compare_with_current(table_name, input_data) + + results = ctx.get_metadata_results() + assert len(results.views) == len(results.tables) == len(results.non_temp_tables) == 0 + + +def test_create_table(ctx: TestContext): + table = ctx.table("test_table") + ctx.engine_adapter.create_table( + table, + {"id": exp.DataType.build("int")}, + table_description="test table description", + column_descriptions={"id": "test id column description"}, + 
table_format=ctx.default_table_format, + ) + results = ctx.get_metadata_results() + assert len(results.tables) == 1 + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert results.tables[0] == table.name + + if ctx.engine_adapter.COMMENT_CREATION_TABLE.is_supported: + table_description = ctx.get_table_comment(table.db, "test_table") + column_comments = ctx.get_column_comments(table.db, "test_table") + assert table_description == "test table description" + assert column_comments == {"id": "test id column description"} + + +def test_ctas(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + table = ctx.table("test_table") + + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01"}, + {"id": 2, "ds": "2022-01-02"}, + {"id": 3, "ds": "2022-01-03"}, + ] + ) + ctx.engine_adapter.ctas( + table, + ctx.input_data(input_data), + table_description="test table description", + column_descriptions={"id": "test id column description"}, + table_format=ctx.default_table_format, + ) + + results = ctx.get_metadata_results(schema=table.db) + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, input_data) + + if ctx.engine_adapter.COMMENT_CREATION_TABLE.is_supported: + table_description = ctx.get_table_comment(table.db, table.name) + column_comments = ctx.get_column_comments(table.db, table.name) + + assert table_description == "test table description" + assert column_comments == {"id": "test id column description"} + + # ensure we don't hit clickhouse INSERT with LIMIT 0 bug on CTAS + if ctx.dialect == "clickhouse": + ctx.engine_adapter.ctas(table, exp.select("1").limit(0)) + + +def test_ctas_source_columns(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + table = ctx.table("test_table") + + columns_to_types = ctx.columns_to_types.copy() + 
columns_to_types["ignored_column"] = exp.DataType.build("int") + + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01", "ignored_source": "ignored_value"}, + {"id": 2, "ds": "2022-01-02", "ignored_source": "ignored_value"}, + {"id": 3, "ds": "2022-01-03", "ignored_source": "ignored_value"}, + ] + ) + ctx.engine_adapter.ctas( + table, + ctx.input_data(input_data), + table_description="test table description", + column_descriptions={"id": "test id column description"}, + table_format=ctx.default_table_format, + target_columns_to_types=columns_to_types, + source_columns=["id", "ds", "ignored_source"], + ) + + expected_data = input_data.copy() + expected_data["ignored_column"] = pd.Series() + expected_data = expected_data.drop(columns=["ignored_source"]) + + results = ctx.get_metadata_results(schema=table.db) + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, expected_data) + + if ctx.engine_adapter.COMMENT_CREATION_TABLE.is_supported: + table_description = ctx.get_table_comment(table.db, table.name) + column_comments = ctx.get_column_comments(table.db, table.name) + + assert table_description == "test table description" + assert column_comments == {"id": "test id column description"} + + # ensure we don't hit clickhouse INSERT with LIMIT 0 bug on CTAS + if ctx.dialect == "clickhouse": + ctx.engine_adapter.ctas(table, exp.select("1").limit(0)) + + +def test_create_view(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01"}, + {"id": 2, "ds": "2022-01-02"}, + {"id": 3, "ds": "2022-01-03"}, + ] + ) + view = ctx.table("test_view") + ctx.engine_adapter.create_view( + view, + ctx.input_data(input_data), + table_description="test view description", + column_descriptions={"id": "test id column description"}, + ) + results = 
ctx.get_metadata_results() + assert len(results.tables) == 0 + assert len(results.views) == 1 + assert len(results.materialized_views) == 0 + assert results.views[0] == view.name + ctx.compare_with_current(view, input_data) + + if ctx.engine_adapter.COMMENT_CREATION_VIEW.is_supported: + table_description = ctx.get_table_comment(view.db, "test_view", table_kind="VIEW") + column_comments = ctx.get_column_comments(view.db, "test_view", table_kind="VIEW") + + # Query: + # In the query test, columns_to_types are not available when the view is created. Since we + # can only register column comments in the CREATE VIEW schema expression with columns_to_types + # available, the column comments must be registered via post-creation commands. Some engines, + # such as Spark and Snowflake, do not support view column comments via post-creation commands. + assert table_description == "test view description" + assert column_comments == ( + {} + if ( + ctx.test_type == "query" + and not ctx.engine_adapter.COMMENT_CREATION_VIEW.supports_column_comment_commands + ) + else {"id": "test id column description"} + ) + + +def test_create_view_source_columns(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + + columns_to_types = ctx.columns_to_types.copy() + columns_to_types["ignored_column"] = exp.DataType.build("int") + + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01", "ignored_source": "ignored_value"}, + {"id": 2, "ds": "2022-01-02", "ignored_source": "ignored_value"}, + {"id": 3, "ds": "2022-01-03", "ignored_source": "ignored_value"}, + ] + ) + view = ctx.table("test_view") + ctx.engine_adapter.create_view( + view, + ctx.input_data(input_data), + table_description="test view description", + column_descriptions={"id": "test id column description"}, + source_columns=["id", "ds", "ignored_source"], + target_columns_to_types=columns_to_types, + ) + + expected_data = input_data.copy() + expected_data["ignored_column"] = pd.Series() + expected_data = 
expected_data.drop(columns=["ignored_source"]) + + results = ctx.get_metadata_results() + assert len(results.tables) == 0 + assert len(results.views) == 1 + assert len(results.materialized_views) == 0 + assert results.views[0] == view.name + ctx.compare_with_current(view, expected_data) + + if ctx.engine_adapter.COMMENT_CREATION_VIEW.is_supported: + table_description = ctx.get_table_comment(view.db, "test_view", table_kind="VIEW") + column_comments = ctx.get_column_comments(view.db, "test_view", table_kind="VIEW") + + assert table_description == "test view description" + assert column_comments == {"id": "test id column description"} + + +def test_materialized_view(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + if not ctx.engine_adapter.SUPPORTS_MATERIALIZED_VIEWS: + pytest.skip(f"Engine adapter {ctx.engine_adapter} doesn't support materialized views") + if ctx.engine_adapter.dialect == "databricks": + pytest.skip( + "Databricks requires DBSQL Serverless or Pro warehouse to test materialized views which we do not have setup" + ) + if ctx.engine_adapter.dialect == "snowflake": + pytest.skip("Snowflake requires enterprise edition which we do not have setup") + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01"}, + {"id": 2, "ds": "2022-01-02"}, + {"id": 3, "ds": "2022-01-03"}, + ] + ) + source_table = ctx.table("source_table") + ctx.engine_adapter.ctas(source_table, ctx.input_data(input_data), ctx.columns_to_types) + view = ctx.table("test_view") + view_query = exp.select(*ctx.columns_to_types).from_(source_table) + ctx.engine_adapter.create_view(view, view_query, materialized=True) + results = ctx.get_metadata_results() + # Redshift considers the underlying dataset supporting materialized views as a table therefore we get 2 + # tables in the result + if ctx.engine_adapter.dialect == "redshift": + assert len(results.tables) == 2 + else: + assert len(results.tables) == 1 + assert len(results.views) == 0 + assert len(results.materialized_views) == 
1 + assert results.materialized_views[0] == view.name + ctx.compare_with_current(view, input_data) + # Make sure that dropping a materialized view also works + ctx.engine_adapter.drop_view(view, materialized=True) + results = ctx.get_metadata_results() + assert len(results.materialized_views) == 0 + + +def test_drop_schema(ctx: TestContext): + ctx.columns_to_types = {"one": "int"} + schema = ctx.schema(TEST_SCHEMA) + ctx.engine_adapter.drop_schema(schema, cascade=True) + results = ctx.get_metadata_results() + assert len(results.tables) == 0 + assert len(results.views) == 0 + + ctx.engine_adapter.create_schema(schema) + view = ctx.table("test_view") + view_query = exp.Select().select(exp.Literal.number(1).as_("one")) + ctx.engine_adapter.create_view(view, view_query, ctx.columns_to_types) + results = ctx.get_metadata_results() + assert len(results.tables) == 0 + assert len(results.views) == 1 + + ctx.engine_adapter.drop_schema(schema, cascade=True) + results = ctx.get_metadata_results() + assert len(results.tables) == 0 + assert len(results.views) == 0 + + +def test_nan_roundtrip(ctx_df: TestContext): + ctx = ctx_df + ctx.engine_adapter.DEFAULT_BATCH_SIZE = sys.maxsize + table = ctx.table("test_table") + # Initial Load + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01"}, + {"id": 2, "ds": "2022-01-02"}, + {"id": np.nan, "ds": np.nan}, + ] + ) + ctx.engine_adapter.create_table(table, ctx.columns_to_types) + ctx.engine_adapter.replace_query( + table, + ctx.input_data(input_data), + target_columns_to_types=ctx.columns_to_types, + ) + results = ctx.get_metadata_results() + assert not results.views + assert not results.materialized_views + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, input_data) + + +def test_replace_query(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + ctx.engine_adapter.DEFAULT_BATCH_SIZE = sys.maxsize + table = 
ctx.table("test_table") + # Initial Load + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01"}, + {"id": 2, "ds": "2022-01-02"}, + {"id": 3, "ds": "2022-01-03"}, + ] + ) + ctx.engine_adapter.create_table( + table, ctx.columns_to_types, table_format=ctx.default_table_format + ) + ctx.engine_adapter.replace_query( + table, + ctx.input_data(input_data), + # Spark based engines do a create table -> insert overwrite instead of replace. If columns to types aren't + # provided then it checks the table itself for types. This is fine within SQLMesh since we always know the tables + # exist prior to evaluation but when running these tests that isn't the case. As a result we just pass in + # columns_to_types for these two engines so we can still test inference on the other ones + target_columns_to_types=ctx.columns_to_types + if ctx.dialect in ["spark", "databricks"] + else None, + table_format=ctx.default_table_format, + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, input_data) + + # Replace that we only need to run once + if type == "df": + replace_data = pd.DataFrame( + [ + {"id": 4, "ds": "2022-01-04"}, + {"id": 5, "ds": "2022-01-05"}, + {"id": 6, "ds": "2022-01-06"}, + ] + ) + ctx.engine_adapter.replace_query( + table, + ctx.input_data(replace_data), + target_columns_to_types=( + ctx.columns_to_types if ctx.dialect in ["spark", "databricks"] else None + ), + table_format=ctx.default_table_format, + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, replace_data) + + +def test_replace_query_source_columns(ctx_query_and_df: 
TestContext): + ctx = ctx_query_and_df + ctx.engine_adapter.DEFAULT_BATCH_SIZE = sys.maxsize + table = ctx.table("test_table") + + columns_to_types = ctx.columns_to_types.copy() + columns_to_types["ignored_column"] = exp.DataType.build("int") + + # Initial Load + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01", "ignored_source": "ignored_value"}, + {"id": 2, "ds": "2022-01-02", "ignored_source": "ignored_value"}, + {"id": 3, "ds": "2022-01-03", "ignored_source": "ignored_value"}, + ] + ) + ctx.engine_adapter.create_table(table, columns_to_types, table_format=ctx.default_table_format) + ctx.engine_adapter.replace_query( + table, + ctx.input_data(input_data), + table_format=ctx.default_table_format, + source_columns=["id", "ds", "ignored_source"], + target_columns_to_types=columns_to_types, + ) + expected_data = input_data.copy() + expected_data["ignored_column"] = pd.Series() + expected_data = expected_data.drop(columns=["ignored_source"]) + + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, expected_data) + + # Replace that we only need to run once + if type == "df": + replace_data = pd.DataFrame( + [ + {"id": 4, "ds": "2022-01-04"}, + {"id": 5, "ds": "2022-01-05"}, + {"id": 6, "ds": "2022-01-06"}, + ] + ) + ctx.engine_adapter.replace_query( + table, + ctx.input_data(replace_data), + table_format=ctx.default_table_format, + source_columns=["id", "ds"], + target_columns_to_types=columns_to_types, + ) + expected_data = replace_data.copy() + expected_data["ignored_column"] = pd.Series() + + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + 
ctx.compare_with_current(table, expected_data) + + +def test_replace_query_batched(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + ctx.engine_adapter.DEFAULT_BATCH_SIZE = 1 + table = ctx.table("test_table") + # Initial Load + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01"}, + {"id": 2, "ds": "2022-01-02"}, + {"id": 3, "ds": "2022-01-03"}, + ] + ) + ctx.engine_adapter.create_table( + table, ctx.columns_to_types, table_format=ctx.default_table_format + ) + ctx.engine_adapter.replace_query( + table, + ctx.input_data(input_data), + # Spark based engines do a create table -> insert overwrite instead of replace. If columns to types aren't + # provided then it checks the table itself for types. This is fine within SQLMesh since we always know the tables + # exist prior to evaluation but when running these tests that isn't the case. As a result we just pass in + # columns_to_types for these two engines so we can still test inference on the other ones + target_columns_to_types=ctx.columns_to_types + if ctx.dialect in ["spark", "databricks"] + else None, + table_format=ctx.default_table_format, + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, input_data) + + # Replace that we only need to run once + if ctx.test_type == "df": + replace_data = pd.DataFrame( + [ + {"id": 4, "ds": "2022-01-04"}, + {"id": 5, "ds": "2022-01-05"}, + {"id": 6, "ds": "2022-01-06"}, + ] + ) + ctx.engine_adapter.replace_query( + table, + ctx.input_data(replace_data), + target_columns_to_types=( + ctx.columns_to_types if ctx.dialect in ["spark", "databricks"] else None + ), + table_format=ctx.default_table_format, + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert 
len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, replace_data) + + +def test_insert_append(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + table = ctx.table("test_table") + ctx.engine_adapter.create_table( + table, ctx.columns_to_types, table_format=ctx.default_table_format + ) + # Initial Load + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01"}, + {"id": 2, "ds": "2022-01-02"}, + {"id": 3, "ds": "2022-01-03"}, + ] + ) + ctx.engine_adapter.insert_append(table, ctx.input_data(input_data)) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, input_data) + + # Replace that we only need to run once + if ctx.test_type == "df": + append_data = pd.DataFrame( + [ + {"id": 4, "ds": "2022-01-04"}, + {"id": 5, "ds": "2022-01-05"}, + {"id": 6, "ds": "2022-01-06"}, + ] + ) + ctx.engine_adapter.insert_append(table, ctx.input_data(append_data)) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) in [1, 2, 3] + assert len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, pd.concat([input_data, append_data])) + + +def test_insert_append_source_columns(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + table = ctx.table("test_table") + columns_to_types = ctx.columns_to_types.copy() + columns_to_types["ignored_column"] = exp.DataType.build("int") + ctx.engine_adapter.create_table(table, columns_to_types, table_format=ctx.default_table_format) + # Initial Load + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01", "ignored_source": "ignored_value"}, + {"id": 2, "ds": "2022-01-02", 
"ignored_source": "ignored_value"}, + {"id": 3, "ds": "2022-01-03", "ignored_source": "ignored_value"}, + ] + ) + ctx.engine_adapter.insert_append( + table, + ctx.input_data(input_data), + source_columns=["id", "ds", "ignored_source"], + target_columns_to_types=columns_to_types, + ) + expected_data = input_data.copy() + expected_data["ignored_column"] = pd.Series() + expected_data = expected_data.drop(columns=["ignored_source"]) + + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, expected_data) + + # Replace that we only need to run once + if ctx.test_type == "df": + append_data = pd.DataFrame( + [ + {"id": 4, "ds": "2022-01-04", "ignored_source": "ignored_value"}, + {"id": 5, "ds": "2022-01-05", "ignored_source": "ignored_value"}, + {"id": 6, "ds": "2022-01-06", "ignored_source": "ignored_value"}, + ] + ) + ctx.engine_adapter.insert_append( + table, + ctx.input_data(append_data), + source_columns=["id", "ds", "ignored_source"], + target_columns_to_types=columns_to_types, + ) + append_expected_data = append_data.copy() + append_expected_data["ignored_column"] = pd.Series() + append_expected_data = append_expected_data.drop(columns=["ignored_source"]) + + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) in [1, 2, 3] + assert len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, pd.concat([expected_data, append_expected_data])) + + +def test_insert_overwrite_by_time_partition(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + ds_type = "string" + if ctx.dialect == "bigquery": + ds_type = "datetime" + if ctx.dialect == "tsql": + ds_type = "varchar(max)" + + ctx.columns_to_types = {"id": 
"int", "ds": ds_type} + table = ctx.table("test_table") + if ctx.dialect == "bigquery": + partitioned_by = ["DATE(ds)"] + else: + partitioned_by = ctx.partitioned_by # type: ignore + ctx.engine_adapter.create_table( + table, + ctx.columns_to_types, + partitioned_by=partitioned_by, + partition_interval_unit="DAY", + table_format=ctx.default_table_format, + ) + input_data = pd.DataFrame( + [ + {"id": 1, ctx.time_column: "2022-01-01"}, + {"id": 2, ctx.time_column: "2022-01-02"}, + {"id": 3, ctx.time_column: "2022-01-03"}, + ] + ) + ctx.engine_adapter.insert_overwrite_by_time_partition( + table, + ctx.input_data(input_data), + start="2022-01-02", + end="2022-01-03", + time_formatter=ctx.time_formatter, + time_column=ctx.time_column, + target_columns_to_types=ctx.columns_to_types, + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + + if ctx.dialect == "trino": + # trino has some lag between partitions being registered and data showing up + wait_until(lambda: len(ctx.get_current_data(table)) > 0) + + ctx.compare_with_current(table, input_data.iloc[1:]) + + if ctx.test_type == "df": + overwrite_data = pd.DataFrame( + [ + {"id": 10, ctx.time_column: "2022-01-03"}, + {"id": 4, ctx.time_column: "2022-01-04"}, + {"id": 5, ctx.time_column: "2022-01-05"}, + ] + ) + ctx.engine_adapter.insert_overwrite_by_time_partition( + table, + ctx.input_data(overwrite_data), + start="2022-01-03", + end="2022-01-05", + time_formatter=ctx.time_formatter, + time_column=ctx.time_column, + target_columns_to_types=ctx.columns_to_types, + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + + if ctx.dialect == "trino": + 
wait_until(lambda: len(ctx.get_current_data(table)) > 2) + + ctx.compare_with_current( + table, + pd.DataFrame( + [ + {"id": 2, ctx.time_column: "2022-01-02"}, + {"id": 10, ctx.time_column: "2022-01-03"}, + {"id": 4, ctx.time_column: "2022-01-04"}, + {"id": 5, ctx.time_column: "2022-01-05"}, + ] + ), + ) + + +def test_insert_overwrite_by_time_partition_source_columns(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + ds_type = "string" + if ctx.dialect == "bigquery": + ds_type = "datetime" + if ctx.dialect == "tsql": + ds_type = "varchar(max)" + + ctx.columns_to_types = {"id": "int", "ds": ds_type} + columns_to_types = { + "id": exp.DataType.build("int"), + "ignored_column": exp.DataType.build("int"), + "ds": exp.DataType.build(ds_type), + } + table = ctx.table("test_table") + if ctx.dialect == "bigquery": + partitioned_by = ["DATE(ds)"] + else: + partitioned_by = ctx.partitioned_by # type: ignore + ctx.engine_adapter.create_table( + table, + columns_to_types, + partitioned_by=partitioned_by, + partition_interval_unit="DAY", + table_format=ctx.default_table_format, + ) + input_data = pd.DataFrame( + [ + {"id": 1, ctx.time_column: "2022-01-01", "ignored_source": "ignored_value"}, + {"id": 2, ctx.time_column: "2022-01-02", "ignored_source": "ignored_value"}, + {"id": 3, ctx.time_column: "2022-01-03", "ignored_source": "ignored_value"}, + ] + ) + ctx.engine_adapter.insert_overwrite_by_time_partition( + table, + ctx.input_data(input_data), + start="2022-01-02", + end="2022-01-03", + time_formatter=ctx.time_formatter, + time_column=ctx.time_column, + target_columns_to_types=columns_to_types, + source_columns=["id", "ds", "ignored_source"], + ) + + expected_data = input_data.copy() + expected_data = expected_data.drop(columns=["ignored_source"]) + expected_data.insert(len(expected_data.columns) - 1, "ignored_column", pd.Series()) + + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert 
len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + + if ctx.dialect == "trino": + # trino has some lag between partitions being registered and data showing up + wait_until(lambda: len(ctx.get_current_data(table)) > 0) + + ctx.compare_with_current(table, expected_data.iloc[1:]) + + if ctx.test_type == "df": + overwrite_data = pd.DataFrame( + [ + {"id": 10, ctx.time_column: "2022-01-03", "ignored_source": "ignored_value"}, + {"id": 4, ctx.time_column: "2022-01-04", "ignored_source": "ignored_value"}, + {"id": 5, ctx.time_column: "2022-01-05", "ignored_source": "ignored_value"}, + ] + ) + ctx.engine_adapter.insert_overwrite_by_time_partition( + table, + ctx.input_data(overwrite_data), + start="2022-01-03", + end="2022-01-05", + time_formatter=ctx.time_formatter, + time_column=ctx.time_column, + target_columns_to_types=columns_to_types, + source_columns=["id", "ds", "ignored_source"], + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + + if ctx.dialect == "trino": + wait_until(lambda: len(ctx.get_current_data(table)) > 2) + + ctx.compare_with_current( + table, + pd.DataFrame( + [ + {"id": 2, "ignored_column": None, ctx.time_column: "2022-01-02"}, + {"id": 10, "ignored_column": None, ctx.time_column: "2022-01-03"}, + {"id": 4, "ignored_column": None, ctx.time_column: "2022-01-04"}, + {"id": 5, "ignored_column": None, ctx.time_column: "2022-01-05"}, + ] + ), + ) + + +def test_merge(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + if not ctx.supports_merge: + pytest.skip(f"{ctx.dialect} doesn't support merge") + + table = ctx.table("test_table") + + # Athena only supports MERGE on Iceberg tables + # And it cant fall back to a logical merge on Hive tables because it cant delete records + table_format = "iceberg" if 
ctx.dialect == "athena" else None + + ctx.engine_adapter.create_table(table, ctx.columns_to_types, table_format=table_format) + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01"}, + {"id": 2, "ds": "2022-01-02"}, + {"id": 3, "ds": "2022-01-03"}, + ] + ) + ctx.engine_adapter.merge( + table, + ctx.input_data(input_data), + target_columns_to_types=None, + unique_key=[exp.to_identifier("id")], + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, input_data) + + if ctx.test_type == "df": + merge_data = pd.DataFrame( + [ + {"id": 2, "ds": "2022-01-10"}, + {"id": 4, "ds": "2022-01-04"}, + {"id": 5, "ds": "2022-01-05"}, + ] + ) + ctx.engine_adapter.merge( + table, + ctx.input_data(merge_data), + target_columns_to_types=None, + unique_key=[exp.to_identifier("id")], + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current( + table, + pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01"}, + {"id": 2, "ds": "2022-01-10"}, + {"id": 3, "ds": "2022-01-03"}, + {"id": 4, "ds": "2022-01-04"}, + {"id": 5, "ds": "2022-01-05"}, + ] + ), + ) + + +def test_merge_source_columns(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + if not ctx.supports_merge: + pytest.skip(f"{ctx.dialect} doesn't support merge") + + table = ctx.table("test_table") + + # Athena only supports MERGE on Iceberg tables + # And it cant fall back to a logical merge on Hive tables because it cant delete records + table_format = "iceberg" if ctx.dialect == "athena" else None + + columns_to_types = ctx.columns_to_types.copy() + 
columns_to_types["ignored_column"] = exp.DataType.build("int") + + ctx.engine_adapter.create_table(table, columns_to_types, table_format=table_format) + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01", "ignored_source": "ignored_value"}, + {"id": 2, "ds": "2022-01-02", "ignored_source": "ignored_value"}, + {"id": 3, "ds": "2022-01-03", "ignored_source": "ignored_value"}, + ] + ) + ctx.engine_adapter.merge( + table, + ctx.input_data(input_data), + unique_key=[exp.to_identifier("id")], + target_columns_to_types=columns_to_types, + source_columns=["id", "ds", "ignored_source"], + ) + + expected_data = input_data.copy() + expected_data["ignored_column"] = pd.Series() + expected_data = expected_data.drop(columns=["ignored_source"]) + + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current(table, expected_data) + + if ctx.test_type == "df": + merge_data = pd.DataFrame( + [ + {"id": 2, "ds": "2022-01-10", "ignored_source": "ignored_value"}, + {"id": 4, "ds": "2022-01-04", "ignored_source": "ignored_value"}, + {"id": 5, "ds": "2022-01-05", "ignored_source": "ignored_value"}, + ] + ) + ctx.engine_adapter.merge( + table, + ctx.input_data(merge_data), + unique_key=[exp.to_identifier("id")], + target_columns_to_types=columns_to_types, + source_columns=["id", "ds", "ignored_source"], + ) + + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current( + table, + pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01", "ignored_column": None}, + {"id": 2, "ds": "2022-01-10", "ignored_column": None}, + {"id": 3, "ds": "2022-01-03", 
"ignored_column": None}, + {"id": 4, "ds": "2022-01-04", "ignored_column": None}, + {"id": 5, "ds": "2022-01-05", "ignored_column": None}, + ] + ), + ) + + +def test_scd_type_2_by_time(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + # Athena only supports the operations required for SCD models on Iceberg tables + if ctx.mark == "athena_hive": + pytest.skip("SCD Type 2 is only supported on Athena / Iceberg") + + time_type = exp.DataType.build("timestamp") + + ctx.columns_to_types = { + "id": "int", + "name": "string", + "updated_at": time_type, + "valid_from": time_type, + "valid_to": time_type, + } + table = ctx.table("test_table") + input_schema = { + k: v for k, v in ctx.columns_to_types.items() if k not in ("valid_from", "valid_to") + } + + ctx.engine_adapter.create_table( + table, ctx.columns_to_types, table_format=ctx.default_table_format + ) + input_data = pd.DataFrame( + [ + {"id": 1, "name": "a", "updated_at": "2022-01-01 00:00:00"}, + {"id": 2, "name": "b", "updated_at": "2022-01-02 00:00:00"}, + {"id": 3, "name": "c", "updated_at": "2022-01-03 00:00:00"}, + ] + ) + ctx.engine_adapter.scd_type_2_by_time( + table, + ctx.input_data(input_data, input_schema), + unique_key=[parse_one("COALESCE(id, -1)")], + valid_from_col=exp.column("valid_from", quoted=True), + valid_to_col=exp.column("valid_to", quoted=True), + updated_at_col=exp.column("updated_at", quoted=True), + execution_time="2023-01-01 00:00:00", + updated_at_as_valid_from=False, + target_columns_to_types=input_schema, + table_format=ctx.default_table_format, + truncate=True, + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current( + table, + pd.DataFrame( + [ + { + "id": 1, + "name": "a", + "updated_at": "2022-01-01 00:00:00", + "valid_from": 
"1970-01-01 00:00:00", + "valid_to": pd.NaT, + }, + { + "id": 2, + "name": "b", + "updated_at": "2022-01-02 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + }, + { + "id": 3, + "name": "c", + "updated_at": "2022-01-03 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + }, + ] + ), + ) + + if ctx.test_type == "query": + return + + current_data = pd.DataFrame( + [ + # Change `a` to `x` + {"id": 1, "name": "x", "updated_at": "2022-01-04 00:00:00"}, + # Delete + # {"id": 2, "name": "b", "updated_at": "2022-01-02 00:00:00"}, + # No change + {"id": 3, "name": "c", "updated_at": "2022-01-03 00:00:00"}, + # Add + {"id": 4, "name": "d", "updated_at": "2022-01-04 00:00:00"}, + ] + ) + ctx.engine_adapter.scd_type_2_by_time( + table, + ctx.input_data(current_data, input_schema), + unique_key=[exp.to_column("id")], + valid_from_col=exp.column("valid_from", quoted=True), + valid_to_col=exp.column("valid_to", quoted=True), + updated_at_col=exp.column("updated_at", quoted=True), + execution_time="2023-01-05 00:00:00", + updated_at_as_valid_from=False, + target_columns_to_types=input_schema, + table_format=ctx.default_table_format, + truncate=False, + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current( + table, + pd.DataFrame( + [ + { + "id": 1, + "name": "a", + "updated_at": "2022-01-01 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": "2022-01-04 00:00:00", + }, + { + "id": 1, + "name": "x", + "updated_at": "2022-01-04 00:00:00", + "valid_from": "2022-01-04 00:00:00", + "valid_to": pd.NaT, + }, + { + "id": 2, + "name": "b", + "updated_at": "2022-01-02 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": "2023-01-05 00:00:00", + }, + { + "id": 3, + "name": "c", + "updated_at": "2022-01-03 00:00:00", + 
"valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + }, + { + "id": 4, + "name": "d", + "updated_at": "2022-01-04 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + }, + ] + ), + ) + + +def test_scd_type_2_by_time_source_columns(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + # Athena only supports the operations required for SCD models on Iceberg tables + if ctx.mark == "athena_hive": + pytest.skip("SCD Type 2 is only supported on Athena / Iceberg") + + time_type = exp.DataType.build("timestamp") + + ctx.columns_to_types = { + "id": "int", + "name": "string", + "updated_at": time_type, + "valid_from": time_type, + "valid_to": time_type, + } + columns_to_types = ctx.columns_to_types.copy() + columns_to_types["ignored_column"] = exp.DataType.build("int") + + table = ctx.table("test_table") + input_schema = { + k: v for k, v in ctx.columns_to_types.items() if k not in ("valid_from", "valid_to") + } + + ctx.engine_adapter.create_table(table, columns_to_types, table_format=ctx.default_table_format) + input_data = pd.DataFrame( + [ + { + "id": 1, + "name": "a", + "updated_at": "2022-01-01 00:00:00", + "ignored_source": "ignored_value", + }, + { + "id": 2, + "name": "b", + "updated_at": "2022-01-02 00:00:00", + "ignored_source": "ignored_value", + }, + { + "id": 3, + "name": "c", + "updated_at": "2022-01-03 00:00:00", + "ignored_source": "ignored_value", + }, + ] + ) + ctx.engine_adapter.scd_type_2_by_time( + table, + ctx.input_data(input_data, input_schema), + unique_key=[parse_one("COALESCE(id, -1)")], + valid_from_col=exp.column("valid_from", quoted=True), + valid_to_col=exp.column("valid_to", quoted=True), + updated_at_col=exp.column("updated_at", quoted=True), + execution_time="2023-01-01 00:00:00", + updated_at_as_valid_from=False, + table_format=ctx.default_table_format, + truncate=True, + start="2022-01-01 00:00:00", + target_columns_to_types=columns_to_types, + source_columns=["id", "name", "updated_at", 
"ignored_source"], + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current( + table, + pd.DataFrame( + [ + { + "id": 1, + "name": "a", + "updated_at": "2022-01-01 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 2, + "name": "b", + "updated_at": "2022-01-02 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 3, + "name": "c", + "updated_at": "2022-01-03 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + ] + ), + ) + + if ctx.test_type == "query": + return + + current_data = pd.DataFrame( + [ + # Change `a` to `x` + { + "id": 1, + "name": "x", + "updated_at": "2022-01-04 00:00:00", + "ignored_source": "ignored_value", + }, + # Delete + # {"id": 2, "name": "b", "updated_at": "2022-01-02 00:00:00", "ignored_source": "ignored_value"}, + # No change + { + "id": 3, + "name": "c", + "updated_at": "2022-01-03 00:00:00", + "ignored_source": "ignored_value", + }, + # Add + { + "id": 4, + "name": "d", + "updated_at": "2022-01-04 00:00:00", + "ignored_source": "ignored_value", + }, + ] + ) + ctx.engine_adapter.scd_type_2_by_time( + table, + ctx.input_data(current_data, input_schema), + unique_key=[exp.to_column("id")], + valid_from_col=exp.column("valid_from", quoted=True), + valid_to_col=exp.column("valid_to", quoted=True), + updated_at_col=exp.column("updated_at", quoted=True), + execution_time="2023-01-05 00:00:00", + updated_at_as_valid_from=False, + table_format=ctx.default_table_format, + truncate=False, + start="2022-01-01 00:00:00", + target_columns_to_types=columns_to_types, + source_columns=["id", "name", "updated_at", "ignored_source"], + 
) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current( + table, + pd.DataFrame( + [ + { + "id": 1, + "name": "a", + "updated_at": "2022-01-01 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": "2022-01-04 00:00:00", + "ignored_column": None, + }, + { + "id": 1, + "name": "x", + "updated_at": "2022-01-04 00:00:00", + "valid_from": "2022-01-04 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 2, + "name": "b", + "updated_at": "2022-01-02 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": "2023-01-05 00:00:00", + "ignored_column": None, + }, + { + "id": 3, + "name": "c", + "updated_at": "2022-01-03 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 4, + "name": "d", + "updated_at": "2022-01-04 00:00:00", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + ] + ), + ) + + +def test_scd_type_2_by_column(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + # Athena only supports the operations required for SCD models on Iceberg tables + if ctx.mark == "athena_hive": + pytest.skip("SCD Type 2 is only supported on Athena / Iceberg") + + time_type = exp.DataType.build("timestamp") + + ctx.columns_to_types = { + "id": "int", + "name": "string", + "status": "string", + "valid_from": time_type, + "valid_to": time_type, + } + table = ctx.table("test_table") + input_schema = { + k: v for k, v in ctx.columns_to_types.items() if k not in ("valid_from", "valid_to") + } + + ctx.engine_adapter.create_table( + table, ctx.columns_to_types, table_format=ctx.default_table_format + ) + input_data = pd.DataFrame( + [ + {"id": 1, "name": "a", "status": "active"}, + {"id": 2, "name": "b", "status": "inactive"}, + {"id": 3, 
"name": "c", "status": "active"}, + {"id": 4, "name": "d", "status": "active"}, + ] + ) + ctx.engine_adapter.scd_type_2_by_column( + table, + ctx.input_data(input_data, input_schema), + unique_key=[exp.to_column("id")], + check_columns=[exp.to_column("name"), exp.to_column("status")], + valid_from_col=exp.column("valid_from", quoted=True), + valid_to_col=exp.column("valid_to", quoted=True), + execution_time="2023-01-01", + execution_time_as_valid_from=False, + target_columns_to_types=ctx.columns_to_types, + truncate=True, + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current( + table, + pd.DataFrame( + [ + { + "id": 1, + "name": "a", + "status": "active", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + }, + { + "id": 2, + "name": "b", + "status": "inactive", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + }, + { + "id": 3, + "name": "c", + "status": "active", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + }, + { + "id": 4, + "name": "d", + "status": "active", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + }, + ] + ), + ) + + if ctx.test_type == "query": + return + + current_data = pd.DataFrame( + [ + # Change `a` to `x` + {"id": 1, "name": "x", "status": "active"}, + # Delete + # {"id": 2, "name": "b", status: "inactive"}, + # No change + {"id": 3, "name": "c", "status": "active"}, + # Change status to inactive + {"id": 4, "name": "d", "status": "inactive"}, + # Add + {"id": 5, "name": "e", "status": "inactive"}, + ] + ) + ctx.engine_adapter.scd_type_2_by_column( + table, + ctx.input_data(current_data, input_schema), + unique_key=[exp.to_column("id")], + check_columns=[exp.to_column("name"), exp.to_column("status")], + 
valid_from_col=exp.column("valid_from", quoted=True), + valid_to_col=exp.column("valid_to", quoted=True), + execution_time="2023-01-05 00:00:00", + execution_time_as_valid_from=False, + target_columns_to_types=ctx.columns_to_types, + truncate=False, + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current( + table, + pd.DataFrame( + [ + { + "id": 1, + "name": "a", + "status": "active", + "valid_from": "1970-01-01 00:00:00", + "valid_to": "2023-01-05 00:00:00", + }, + { + "id": 1, + "name": "x", + "status": "active", + "valid_from": "2023-01-05 00:00:00", + "valid_to": pd.NaT, + }, + { + "id": 2, + "name": "b", + "status": "inactive", + "valid_from": "1970-01-01 00:00:00", + "valid_to": "2023-01-05 00:00:00", + }, + { + "id": 3, + "name": "c", + "status": "active", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + }, + { + "id": 4, + "name": "d", + "status": "active", + "valid_from": "1970-01-01 00:00:00", + "valid_to": "2023-01-05 00:00:00", + }, + { + "id": 4, + "name": "d", + "status": "inactive", + "valid_from": "2023-01-05 00:00:00", + "valid_to": pd.NaT, + }, + { + "id": 5, + "name": "e", + "status": "inactive", + "valid_from": "2023-01-05 00:00:00", + "valid_to": pd.NaT, + }, + ] + ), + ) + + +def test_scd_type_2_by_column_source_columns(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + # Athena only supports the operations required for SCD models on Iceberg tables + if ctx.mark == "athena_hive": + pytest.skip("SCD Type 2 is only supported on Athena / Iceberg") + + time_type = exp.DataType.build("timestamp") + + ctx.columns_to_types = { + "id": "int", + "name": "string", + "status": "string", + "valid_from": time_type, + "valid_to": time_type, + } + columns_to_types = ctx.columns_to_types.copy() + columns_to_types["ignored_column"] = 
exp.DataType.build("int") + + table = ctx.table("test_table") + input_schema = { + k: v for k, v in ctx.columns_to_types.items() if k not in ("valid_from", "valid_to") + } + + ctx.engine_adapter.create_table(table, columns_to_types, table_format=ctx.default_table_format) + input_data = pd.DataFrame( + [ + {"id": 1, "name": "a", "status": "active", "ignored_source": "ignored_value"}, + {"id": 2, "name": "b", "status": "inactive", "ignored_source": "ignored_value"}, + {"id": 3, "name": "c", "status": "active", "ignored_source": "ignored_value"}, + {"id": 4, "name": "d", "status": "active", "ignored_source": "ignored_value"}, + ] + ) + ctx.engine_adapter.scd_type_2_by_column( + table, + ctx.input_data(input_data, input_schema), + unique_key=[exp.to_column("id")], + check_columns=[exp.to_column("name"), exp.to_column("status")], + valid_from_col=exp.column("valid_from", quoted=True), + valid_to_col=exp.column("valid_to", quoted=True), + execution_time="2023-01-01", + execution_time_as_valid_from=False, + truncate=True, + start="2023-01-01", + target_columns_to_types=columns_to_types, + source_columns=["id", "name", "status", "ignored_source"], + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current( + table, + pd.DataFrame( + [ + { + "id": 1, + "name": "a", + "status": "active", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 2, + "name": "b", + "status": "inactive", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 3, + "name": "c", + "status": "active", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 4, + "name": "d", + "status": "active", + "valid_from": 
"1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + ] + ), + ) + + if ctx.test_type == "query": + return + + current_data = pd.DataFrame( + [ + # Change `a` to `x` + {"id": 1, "name": "x", "status": "active", "ignored_source": "ignored_value"}, + # Delete + # {"id": 2, "name": "b", status: "inactive", "ignored_source": "ignored_value"}, + # No change + {"id": 3, "name": "c", "status": "active", "ignored_source": "ignored_value"}, + # Change status to inactive + {"id": 4, "name": "d", "status": "inactive", "ignored_source": "ignored_value"}, + # Add + {"id": 5, "name": "e", "status": "inactive", "ignored_source": "ignored_value"}, + ] + ) + ctx.engine_adapter.scd_type_2_by_column( + table, + ctx.input_data(current_data, input_schema), + unique_key=[exp.to_column("id")], + check_columns=[exp.to_column("name"), exp.to_column("status")], + valid_from_col=exp.column("valid_from", quoted=True), + valid_to_col=exp.column("valid_to", quoted=True), + execution_time="2023-01-05 00:00:00", + execution_time_as_valid_from=False, + truncate=False, + start="2023-01-01", + target_columns_to_types=columns_to_types, + source_columns=["id", "name", "status", "ignored_source"], + ) + results = ctx.get_metadata_results() + assert len(results.views) == 0 + assert len(results.materialized_views) == 0 + assert len(results.tables) == len(results.non_temp_tables) == 1 + assert results.non_temp_tables[0] == table.name + ctx.compare_with_current( + table, + pd.DataFrame( + [ + { + "id": 1, + "name": "a", + "status": "active", + "valid_from": "1970-01-01 00:00:00", + "valid_to": "2023-01-05 00:00:00", + "ignored_column": None, + }, + { + "id": 1, + "name": "x", + "status": "active", + "valid_from": "2023-01-05 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 2, + "name": "b", + "status": "inactive", + "valid_from": "1970-01-01 00:00:00", + "valid_to": "2023-01-05 00:00:00", + "ignored_column": None, + }, + { + "id": 3, + "name": "c", + 
"status": "active", + "valid_from": "1970-01-01 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 4, + "name": "d", + "status": "active", + "valid_from": "1970-01-01 00:00:00", + "valid_to": "2023-01-05 00:00:00", + "ignored_column": None, + }, + { + "id": 4, + "name": "d", + "status": "inactive", + "valid_from": "2023-01-05 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + { + "id": 5, + "name": "e", + "status": "inactive", + "valid_from": "2023-01-05 00:00:00", + "valid_to": pd.NaT, + "ignored_column": None, + }, + ] + ), + ) + + +def test_get_data_objects(ctx_query_and_df: TestContext): + ctx = ctx_query_and_df + table = ctx.table("test_table") + view = ctx.table("test_view") + ctx.engine_adapter.create_table( + table, + {"id": exp.DataType.build("int")}, + table_description="test table description", + column_descriptions={"id": "test id column description"}, + table_format=ctx.default_table_format, + ) + ctx.engine_adapter.create_view( + view, + ctx.input_data(pd.DataFrame([{"id": 1, "ds": "2022-01-01"}])), + table_description="test view description", + column_descriptions={"id": "test id column description"}, + ) + + schema = ctx.schema(TEST_SCHEMA) + + assert sorted(ctx.engine_adapter.get_data_objects(schema), key=lambda o: o.name) == [ + DataObject( + name=table.name, + schema=table.db, + catalog=table.catalog or None, + type=DataObjectType.TABLE, + ), + DataObject( + name=view.name, + schema=view.db, + catalog=view.catalog or None, + type=DataObjectType.VIEW, + ), + ] + + assert sorted( + ctx.engine_adapter.get_data_objects(schema, {table.name, view.name}), + key=lambda o: o.name, + ) == [ + DataObject( + name=table.name, + schema=table.db, + catalog=table.catalog or None, + type=DataObjectType.TABLE, + ), + DataObject( + name=view.name, + schema=view.db, + catalog=view.catalog or None, + type=DataObjectType.VIEW, + ), + ] + + assert ctx.engine_adapter.get_data_objects(schema, {table.name}) == [ + DataObject( + 
name=table.name, + schema=table.db, + catalog=table.catalog or None, + type=DataObjectType.TABLE, + ), + ] + + assert ctx.engine_adapter.get_data_objects(schema, {view.name}) == [ + DataObject( + name=view.name, + schema=view.db, + catalog=view.catalog or None, + type=DataObjectType.VIEW, + ), + ] + + assert ctx.engine_adapter.get_data_objects(schema, {}) == [] + assert ctx.engine_adapter.get_data_objects("missing_schema") == [] + + +def test_truncate_table(ctx: TestContext): + table = ctx.table("test_table") + + ctx.engine_adapter.create_table( + table, ctx.columns_to_types, table_format=ctx.default_table_format + ) + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01"}, + {"id": 2, "ds": "2022-01-02"}, + {"id": 3, "ds": "2022-01-03"}, + ] + ) + ctx.engine_adapter.insert_append(table, ctx.input_data(input_data)) + ctx.compare_with_current(table, input_data) + ctx.engine_adapter._truncate_table(table) + assert ctx.engine_adapter.fetchone(exp.select("count(*)").from_(table))[0] == 0 + + +def test_transaction(ctx: TestContext): + if ctx.engine_adapter.SUPPORTS_TRANSACTIONS is False: + pytest.skip(f"Engine adapter {ctx.engine_adapter.dialect} doesn't support transactions") + + table = ctx.table("test_table") + input_data = pd.DataFrame( + [ + {"id": 1, "ds": "2022-01-01"}, + {"id": 2, "ds": "2022-01-02"}, + {"id": 3, "ds": "2022-01-03"}, + ] + ) + with ctx.engine_adapter.transaction(): + ctx.engine_adapter.create_table(table, ctx.columns_to_types) + ctx.engine_adapter.insert_append( + table, ctx.input_data(input_data, ctx.columns_to_types), ctx.columns_to_types + ) + ctx.compare_with_current(table, input_data) + with ctx.engine_adapter.transaction(): + ctx.engine_adapter._truncate_table(table) + ctx.engine_adapter._connection_pool.rollback() + ctx.compare_with_current(table, input_data) + + +@pytest.mark.parametrize( + "virtual_environment_mode", [VirtualEnvironmentMode.FULL, VirtualEnvironmentMode.DEV_ONLY] +) +def test_sushi( + ctx: TestContext, tmp_path: 
pathlib.Path, virtual_environment_mode: VirtualEnvironmentMode +): + if ctx.mark == "athena_hive": + pytest.skip( + "Sushi end-to-end tests only need to run once for Athena because sushi needs a hybrid of both Hive and Iceberg" + ) + + sushi_test_schema = ctx.add_test_suffix("sushi") + sushi_state_schema = ctx.add_test_suffix("sushi_state") + raw_test_schema = ctx.add_test_suffix("raw") + + # Copy sushi example to tmpdir + shutil.copytree(pathlib.Path("./examples/sushi"), tmp_path, dirs_exist_ok=True) + + # Rewrite schema references to test schema references + # Note that we deliberately do it at the filesystem level instead of messing with the Context to ensure + # that we are testing an actual Context rather than a doctored one + extensions = ["*.sql", "*.yaml", "*.py"] + replacements = { + "sushi.": f"{sushi_test_schema}.", + 'sushi".': f'{sushi_test_schema}".', + " raw.": f" {raw_test_schema}.", + "NOT EXISTS raw;": f" NOT EXISTS {raw_test_schema};", + } + for ext in extensions: + for f in tmp_path.rglob(ext): + if f.is_file(): + contents = f.read_text() + for search, replace in replacements.items(): + contents = contents.replace(search, replace) + f.write_text(contents) + + before_all = [ + f"CREATE SCHEMA IF NOT EXISTS {raw_test_schema}", + f"DROP VIEW IF EXISTS {raw_test_schema}.demographics", + f"CREATE VIEW {raw_test_schema}.demographics AS (SELECT 1 AS customer_id, '00000' AS zip)", + ] + + def _mutate_config(gateway: str, config: Config) -> None: + config.gateways[gateway].state_schema = sushi_state_schema + config.before_all = [ + quote_identifiers( + parse_one(e, dialect=config.model_defaults.dialect), + dialect=config.model_defaults.dialect, + ).sql(dialect=config.model_defaults.dialect) + for e in before_all + ] + config.virtual_environment_mode = virtual_environment_mode + + context = ctx.create_context(_mutate_config, path=tmp_path, ephemeral_state_connection=False) + + end = now() + start = to_date(end - timedelta(days=7)) + yesterday = 
to_date(end - timedelta(days=1)) + + # Databricks requires the table property `delta.columnMapping.mode = 'name'` for + # spaces in column names. Other engines error if it is set in the model definition, + # so we set it here. + if ctx.dialect == "databricks": + cust_rev_by_day_key = [key for key in context._models if "customer_revenue_by_day" in key][ + 0 + ] + + cust_rev_by_day_model_tbl_props = context._models[cust_rev_by_day_key].copy( + update={ + "physical_properties": { + "delta.columnMapping.mode": exp.Literal(this="name", is_string=True) + } + } + ) + + context._models.update({cust_rev_by_day_key: cust_rev_by_day_model_tbl_props}) + + # Clickhouse requires columns used as keys to be non-Nullable, but all transpiled columns are nullable by default + if ctx.dialect == "clickhouse": + models_to_modify = { + '"items"', + '"orders"', + '"order_items"', + '"customer_revenue_by_day"', + '"customer_revenue_lifetime"', + '"waiter_revenue_by_day"', + '"waiter_as_customer_by_day"', + } + for model in models_to_modify: + model_key = [key for key in context._models if model in key][0] + model_columns = context._models[model_key].columns_to_types + updated_model_columns = { + k: exp.DataType.build(v.sql(), dialect="clickhouse", nullable=False) + for k, v in model_columns.items() + } + + model_ch_cols_to_types = context._models[model_key].copy( + update={ + "columns_to_types": updated_model_columns, + "columns_to_types_": updated_model_columns, + "columns_to_types_or_raise": updated_model_columns, + } + ) + context._models.update({model_key: model_ch_cols_to_types}) + + # create raw schema and view + if ctx.gateway == "inttest_clickhouse_cluster": + context.engine_adapter.execute( + f"CREATE DATABASE IF NOT EXISTS {raw_test_schema} ON CLUSTER cluster1;" + ) + context.engine_adapter.execute( + f"DROP VIEW IF EXISTS {raw_test_schema}.demographics ON CLUSTER cluster1;" + ) + context.engine_adapter.execute( + f"CREATE VIEW {raw_test_schema}.demographics ON CLUSTER cluster1 
AS SELECT 1 AS customer_id, '00000' AS zip;" + ) + + # DuckDB parses TIMESTAMP into Type.TIMESTAMPNTZ which generates into TIMESTAMP_NTZ for + # Spark, but this type is not supported in Spark's DDL statements so we make it a TIMESTAMP + if ctx.dialect == "spark": + for model_key, model in context._models.items(): + model_columns = model.columns_to_types + + updated_model_columns = {} + for k, v in model_columns.items(): + updated_model_columns[k] = v + if v.this == exp.DataType.Type.TIMESTAMPNTZ: + v.set("this", exp.DataType.Type.TIMESTAMP) + + update_fields = { + "columns_to_types": updated_model_columns, + "columns_to_types_": updated_model_columns, + "columns_to_types_or_raise": updated_model_columns, + } + + # We get rid of the sushi.marketing post statement here because it asserts that + # updated_at is a 'timestamp', which is parsed using duckdb in assert_has_columns + # and the assertion fails because we now have TIMESTAMPs and not TIMESTAMPNTZs in + # the columns_to_types mapping + if '"marketing"' in model_key: + update_fields["post_statements_"] = [] + + context._models.update( + {model_key: context._models[model_key].copy(update=update_fields)} + ) + + if ctx.dialect == "athena": + for model_name in {"customer_revenue_lifetime"}: + model_key = next(k for k in context._models if model_name in k) + model = context._models[model_key].copy( + update={"table_format": ctx.default_table_format} + ) + context._models.update({model_key: model}) + + plan: Plan = context.plan( + environment="test_prod", + start=start, + end=end, + skip_tests=True, + no_prompts=True, + auto_apply=True, + ) + + data_validator = SushiDataValidator.from_context(context, sushi_schema_name=sushi_test_schema) + data_validator.validate( + f"{sushi_test_schema}.customer_revenue_lifetime", + start, + yesterday, + env_name="test_prod", + dialect=ctx.dialect, + environment_naming_info=plan.environment_naming_info, + ) + + # Ensure table and column comments were correctly registered with engine 
+ if ctx.engine_adapter.COMMENT_CREATION_TABLE.is_supported: + comments = { + "customer_revenue_by_day": { + "table": "Table of revenue from customers by day.", + "column": { + "customer_id": "Customer id", + "revenue": "Revenue from orders made by this customer", + "event_date": "Date", + }, + }, + "customer_revenue_lifetime": { + "table": """Table of lifetime customer revenue. + Date is available to get lifetime value up to a certain date. + Use latest date to get current lifetime value.""", + "column": { + "customer_id": "Customer id", + "revenue": "Lifetime revenue from this customer", + "event_date": "End date of the lifetime calculation", + }, + }, + "customers": { + "table": "Sushi customer data", + "column": {"customer_id": "customer_id uniquely identifies customers"}, + }, + "orders": { + "table": "Table of sushi orders.", + }, + "raw_marketing": { + "table": "Table of marketing status.", + "column": {"customer_id": "Unique identifier of the customer"}, + }, + "top_waiters": { + "table": "View of top waiters.", + }, + "waiter_names": { + "table": "List of waiter names", + }, + "waiter_revenue_by_day": { + "table": "Table of revenue generated by waiters by day.", + "column": { + "waiter_id": "Waiter id", + "revenue": "Revenue from orders taken by this waiter", + "event_date": "Date", + }, + }, + } + + def validate_comments( + schema_name: str, + expected_comments_dict: t.Dict[str, t.Any] = comments, + is_physical_layer: bool = True, + prod_schema_name: str = "sushi", + ) -> None: + layer_objects = context.engine_adapter.get_data_objects(schema_name) + layer_models = { + x.name.split("__")[1] if is_physical_layer else x.name: { + "table_name": x.name, + "is_view": x.type == DataObjectType.VIEW, + } + for x in layer_objects + if not x.name.endswith("__dev") + } + + for model_name, comment in comments.items(): + if not model_name in layer_models: + continue + layer_table_name = layer_models[model_name]["table_name"] + table_kind = "VIEW" if 
layer_models[model_name]["is_view"] else "BASE TABLE" + + # is this model in a physical layer or PROD environment? + is_physical_or_prod = is_physical_layer or ( + not is_physical_layer and schema_name == prod_schema_name + ) + # is this model a VIEW and the engine doesn't support VIEW comments? + is_view_and_comments_unsupported = ( + layer_models[model_name]["is_view"] + and ctx.engine_adapter.COMMENT_CREATION_VIEW.is_unsupported + ) + if is_physical_or_prod and not is_view_and_comments_unsupported: + expected_tbl_comment = comments.get(model_name).get("table", None) + if expected_tbl_comment: + actual_tbl_comment = ctx.get_table_comment( + schema_name, + layer_table_name, + table_kind=table_kind, + snowflake_capitalize_ids=False, + ) + assert expected_tbl_comment == actual_tbl_comment + + expected_col_comments = comments.get(model_name).get("column", None) + + # Trino: + # Trino on Hive COMMENT permissions are separate from standard SQL object permissions. + # Trino has a bug where CREATE SQL permissions are not passed to COMMENT permissions, + # which generates permissions errors when COMMENT commands are issued. + # + # The errors are thrown for both table and comments, but apparently the + # table comments are actually registered with the engine. Column comments are not. + # + # Query: + # In the query test, columns_to_types are not available when views are created. Since we + # can only register column comments in the CREATE VIEW schema expression with columns_to_types + # available, the column comments must be registered via post-creation commands. Some engines, + # such as Spark and Snowflake, do not support view column comments via post-creation commands. 
+ if ( + expected_col_comments + and not ctx.dialect == "trino" + and not ( + ctx.test_type == "query" + and layer_models[model_name]["is_view"] + and not ctx.engine_adapter.COMMENT_CREATION_VIEW.supports_column_comment_commands + ) + ): + actual_col_comments = ctx.get_column_comments( + schema_name, + layer_table_name, + table_kind=table_kind, + snowflake_capitalize_ids=False, + ) + for column_name, expected_col_comment in expected_col_comments.items(): + expected_col_comment = expected_col_comments.get(column_name, None) + actual_col_comment = actual_col_comments.get(column_name, None) + assert expected_col_comment == actual_col_comment + + return None + + def validate_no_comments( + schema_name: str, + expected_comments_dict: t.Dict[str, t.Any] = comments, + is_physical_layer: bool = True, + table_name_suffix: str = "", + check_temp_tables: bool = False, + prod_schema_name: str = "sushi", + ) -> None: + layer_objects = context.engine_adapter.get_data_objects(schema_name) + layer_models = { + x.name.split("__")[1] if is_physical_layer else x.name: { + "table_name": x.name, + "is_view": x.type == DataObjectType.VIEW, + } + for x in layer_objects + if x.name.endswith(table_name_suffix) + } + if not check_temp_tables: + layer_models = {k: v for k, v in layer_models.items() if not k.endswith("__dev")} + + for model_name, comment in comments.items(): + layer_table_name = layer_models[model_name]["table_name"] + table_kind = "VIEW" if layer_models[model_name]["is_view"] else "BASE TABLE" + + actual_tbl_comment = ctx.get_table_comment( + schema_name, + layer_table_name, + table_kind=table_kind, + snowflake_capitalize_ids=False, + ) + # MySQL doesn't support view comments and always returns "VIEW" as the table comment + if ctx.dialect == "mysql" and layer_models[model_name]["is_view"]: + assert actual_tbl_comment == "VIEW" + else: + assert actual_tbl_comment is None or actual_tbl_comment == "" + + # MySQL and Spark pass through the column comments from the underlying 
table to the view + # so always have view comments present + if not ( + ctx.dialect in ("mysql", "spark", "databricks") + and layer_models[model_name]["is_view"] + ): + expected_col_comments = comments.get(model_name).get("column", None) + if expected_col_comments: + actual_col_comments = ctx.get_column_comments( + schema_name, + layer_table_name, + table_kind=table_kind, + snowflake_capitalize_ids=False, + ) + for column_name in expected_col_comments: + actual_col_comment = actual_col_comments.get(column_name, None) + assert actual_col_comment is None or actual_col_comment == "" + + return None + + validate_comments(f"sqlmesh__{sushi_test_schema}", prod_schema_name=sushi_test_schema) + + # confirm view layer comments are not registered in non-PROD environment + env_name = "test_prod" + if plan.environment_naming_info and plan.environment_naming_info.normalize_name: + env_name = normalize_identifiers(env_name, dialect=ctx.dialect).name + validate_no_comments( + f"{sushi_test_schema}__{env_name}", + is_physical_layer=False, + prod_schema_name=sushi_test_schema, + ) + + # Ensure that the plan has been applied successfully. 
+ no_change_plan: Plan = context.plan_builder( + environment="test_dev", + start=start, + end=end, + skip_tests=True, + include_unmodified=True, + ).build() + assert not no_change_plan.requires_backfill + assert no_change_plan.context_diff.is_new_environment + + # make and validate unmodified dev environment + context.apply(no_change_plan) + + data_validator.validate( + f"{sushi_test_schema}.customer_revenue_lifetime", + start, + yesterday, + env_name="test_dev", + dialect=ctx.dialect, + environment_naming_info=no_change_plan.environment_naming_info, + ) + + # confirm view layer comments are registered in PROD + if ctx.engine_adapter.COMMENT_CREATION_VIEW.is_supported: + context.plan(skip_tests=True, no_prompts=True, auto_apply=True) + validate_comments(sushi_test_schema, is_physical_layer=False) + + # Register schemas for cleanup + for schema in [ + f"{sushi_test_schema}__test_prod", + f"{sushi_test_schema}__test_dev", + sushi_test_schema, + f"sqlmesh__{sushi_test_schema}", + sushi_state_schema, + raw_test_schema, + ]: + ctx._schemas.append(schema) + + +def test_init_project(ctx: TestContext, tmp_path: pathlib.Path): + schema_name = ctx.add_test_suffix(TEST_SCHEMA) + state_schema = ctx.add_test_suffix("sqlmesh_state") + + object_names = { + "view_schema": [schema_name], + "physical_schema": [f"sqlmesh__{schema_name}"], + "dev_schema": [f"{schema_name}__test_dev"], + "views": ["full_model", "incremental_model", "seed_model"], + } + + # normalize object names for snowflake + if ctx.dialect == "snowflake": + + def _normalize_snowflake(name: str, prefix_regex: str = "(sqlmesh__)(.*)"): + match = re.search(prefix_regex, name) + if match: + return f"{match.group(1)}{match.group(2).upper()}" + return name.upper() + + object_names = { + k: [_normalize_snowflake(name) for name in v] for k, v in object_names.items() + } + + init_example_project(tmp_path, ctx.engine_type, schema_name=schema_name) + + def _mutate_config(gateway: str, config: Config): + # ensure default 
dialect comes from init_example_project and not ~/.sqlmesh/config.yaml + if config.model_defaults.dialect != ctx.dialect: + config.model_defaults = config.model_defaults.copy(update={"dialect": ctx.dialect}) + + # Ensure the state schema is unique to this test (since we deliberately use the warehouse as the state connection) + config.gateways[gateway].state_schema = state_schema + + context = ctx.create_context(_mutate_config, path=tmp_path, ephemeral_state_connection=False) + + if ctx.default_table_format: + # if the default table format is explicitly set, ensure its being used + replacement_models = {} + for model_key, model in context._models.items(): + if not model.table_format: + replacement_models[model_key] = model.copy( + update={"table_format": ctx.default_table_format} + ) + context._models.update(replacement_models) + + # capture row counts for each evaluated snapshot + actual_execution_stats = {} + + def capture_execution_stats( + snapshot, + interval, + batch_idx, + duration_ms, + num_audits_passed, + num_audits_failed, + audit_only=False, + execution_stats=None, + auto_restatement_triggers=None, + ): + if execution_stats is not None: + actual_execution_stats[snapshot.model.name.replace(f"{schema_name}.", "")] = ( + execution_stats + ) + + # apply prod plan + with patch.object( + context.console, "update_snapshot_evaluation_progress", capture_execution_stats + ): + context.plan(auto_apply=True, no_prompts=True) + + prod_schema_results = ctx.get_metadata_results(object_names["view_schema"][0]) + assert sorted(prod_schema_results.views) == object_names["views"] + assert len(prod_schema_results.materialized_views) == 0 + assert len(prod_schema_results.tables) == len(prod_schema_results.non_temp_tables) == 0 + + physical_layer_results = ctx.get_metadata_results(object_names["physical_schema"][0]) + assert len(physical_layer_results.views) == 0 + assert len(physical_layer_results.materialized_views) == 0 + assert len(physical_layer_results.tables) == 
len(physical_layer_results.non_temp_tables) == 3 + + if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING: + assert actual_execution_stats["incremental_model"].total_rows_processed == 7 + # snowflake and redshift don't track rows for CTAS + assert actual_execution_stats["full_model"].total_rows_processed == ( + None if ctx.mark.startswith("snowflake") or ctx.mark.startswith("redshift") else 3 + ) + assert actual_execution_stats["seed_model"].total_rows_processed == ( + None if ctx.mark.startswith("snowflake") else 7 + ) + + if ctx.mark.startswith("bigquery"): + assert actual_execution_stats["incremental_model"].total_bytes_processed + assert actual_execution_stats["full_model"].total_bytes_processed + + # run that loads 0 rows in incremental model + # - some cloud DBs error because time travel messes up token expiration + if not ctx.is_remote: + actual_execution_stats = {} + with patch.object( + context.console, "update_snapshot_evaluation_progress", capture_execution_stats + ): + with time_machine.travel(date.today() + timedelta(days=1)): + context.run() + + if ctx.engine_adapter.SUPPORTS_QUERY_EXECUTION_TRACKING: + assert actual_execution_stats["incremental_model"].total_rows_processed == 0 + assert actual_execution_stats["full_model"].total_rows_processed == 3 + + # make and validate unmodified dev environment + no_change_plan: Plan = context.plan_builder( + environment="test_dev", + skip_tests=True, + include_unmodified=True, + ).build() + assert not no_change_plan.requires_backfill + assert no_change_plan.context_diff.is_new_environment + + context.apply(no_change_plan) + + environment = no_change_plan.environment + first_snapshot = no_change_plan.environment.snapshots[0] + schema_name = first_snapshot.qualified_view_name.schema_for_environment( + environment, dialect=ctx.dialect + ) + dev_schema_results = ctx.get_metadata_results(schema_name) + assert sorted(dev_schema_results.views) == object_names["views"] + assert 
len(dev_schema_results.materialized_views) == 0 + assert len(dev_schema_results.tables) == len(dev_schema_results.non_temp_tables) == 0 + + # register the schemas to be cleaned up + for schema in [ + state_schema, + *object_names["view_schema"], + *object_names["physical_schema"], + *object_names["dev_schema"], + ]: + ctx._schemas.append(schema) + + +def test_dialects(ctx: TestContext): + from sqlglot import Dialect, parse_one + + dialect = Dialect[ctx.dialect] + + if dialect.NORMALIZATION_STRATEGY == "CASE_INSENSITIVE": + a = '"a"' + b = '"b"' + c = '"c"' + d = '"d"' + elif dialect.NORMALIZATION_STRATEGY == "LOWERCASE": + a = '"a"' + b = '"B"' + c = '"c"' + d = '"d"' + # https://dev.mysql.com/doc/refman/8.0/en/identifier-case-sensitivity.html + # if these tests fail for mysql it means you're running on os x or windows + elif dialect.NORMALIZATION_STRATEGY == "CASE_SENSITIVE": + a = '"a"' + b = '"B"' + c = '"c"' + d = '"D"' + else: + a = '"a"' + b = '"B"' + c = '"C"' + d = '"D"' + + q = parse_one( + f""" + WITH + "a" AS (SELECT 1 w), + "B" AS (SELECT 1 x), + c AS (SELECT 1 y), + D AS (SELECT 1 z) + + SELECT * + FROM {a} + CROSS JOIN {b} + CROSS JOIN {c} + CROSS JOIN {d} + """ + ) + df = ctx.engine_adapter.fetchdf(q) + expected_columns = ["W", "X", "Y", "Z"] if ctx.dialect == "snowflake" else ["w", "x", "y", "z"] + pd.testing.assert_frame_equal( + df, pd.DataFrame([[1, 1, 1, 1]], columns=expected_columns), check_dtype=False + ) + + +@pytest.mark.parametrize( + "time_column, time_column_type, time_column_format, result", + [ + ( + exp.null(), + exp.DataType.build("TIMESTAMP", nullable=True), + None, + { + "default": None, + "bigquery": pd.NaT, + "clickhouse": pd.NaT, + "databricks": pd.NaT, + "duckdb": pd.NaT, + "motherduck": pd.NaT, + "snowflake": pd.NaT, + "spark": pd.NaT, + }, + ), + ( + "2020-01-01 00:00:00+00:00", + exp.DataType.build("DATE"), + None, + { + "default": datetime(2020, 1, 1).date(), + "clickhouse": pd.Timestamp("2020-01-01"), + "duckdb": 
pd.Timestamp("2020-01-01"), + }, + ), + ( + "2020-01-01 00:00:00+00:00", + exp.DataType.build("TIMESTAMPTZ"), + None, + { + "default": pd.Timestamp("2020-01-01 00:00:00+00:00"), + "clickhouse": pd.Timestamp("2020-01-01 00:00:00"), + "fabric": pd.Timestamp("2020-01-01 00:00:00"), + "mysql": pd.Timestamp("2020-01-01 00:00:00"), + "spark": pd.Timestamp("2020-01-01 00:00:00"), + "databricks": pd.Timestamp("2020-01-01 00:00:00"), + }, + ), + ( + "2020-01-01 00:00:00+00:00", + exp.DataType.build("TIMESTAMP"), + None, + {"default": pd.Timestamp("2020-01-01 00:00:00")}, + ), + ( + "2020-01-01 00:00:00+00:00", + exp.DataType.build("TEXT"), + "%Y-%m-%dT%H:%M:%S%z", + { + "default": "2020-01-01T00:00:00+0000", + }, + ), + ( + "2020-01-01 00:00:00+00:00", + exp.DataType.build("INT"), + "%Y%m%d", + { + "default": 20200101, + }, + ), + ], +) +def test_to_time_column( + ctx: TestContext, time_column, time_column_type, time_column_format, result +): + # TODO: can this be cleaned up after recent sqlglot updates? + if ctx.dialect == "clickhouse" and time_column_type.is_type(exp.DataType.Type.TIMESTAMPTZ): + # Clickhouse does not have natively timezone-aware types and does not accept timestrings + # with UTC offset "+XX:XX". Therefore, we remove the timezone offset and set a timezone- + # specific data type to validate what is returned. 
+ + time_column = re.match(r"^(.*?)\+", time_column).group(1) + time_column_type = exp.DataType.build("TIMESTAMP('UTC')", dialect="clickhouse") + + time_column = to_time_column(time_column, time_column_type, ctx.dialect, time_column_format) + df = ctx.engine_adapter.fetchdf(exp.select(time_column).as_("the_col")) + expected = result.get(ctx.dialect, result.get("default")) + col_name = "THE_COL" if ctx.dialect == "snowflake" else "the_col" + if expected is pd.NaT or expected is None: + assert df[col_name][0] is expected + else: + assert df[col_name][0] == expected + + +def test_batch_size_on_incremental_by_unique_key_model(ctx: TestContext): + if not ctx.supports_merge: + pytest.skip(f"{ctx.dialect} on {ctx.gateway} doesnt support merge") + + def _mutate_config(current_gateway_name: str, config: Config): + # make stepping through in the debugger easier + connection = config.gateways[current_gateway_name].connection + connection.concurrent_tasks = 1 + + context = ctx.create_context(_mutate_config) + assert context.default_dialect == "duckdb" + + schema = ctx.schema(TEST_SCHEMA) + seed_columns_to_types = { + "item_id": exp.DataType.build("integer"), + "event_date": exp.DataType.build("date"), + } + seed_query = ctx.input_data( + pd.DataFrame( + [ + [2, "2020-01-01"], + [1, "2020-01-01"], + [3, "2020-01-03"], + [1, "2020-01-04"], + [1, "2020-01-05"], + [1, "2020-01-06"], + [1, "2020-01-07"], + ], + columns=["item_id", "event_date"], + ), + columns_to_types=seed_columns_to_types, + ) + context.upsert_model( + create_sql_model( + name=f"{schema}.seed_model", + query=seed_query, + kind="FULL", + columns=seed_columns_to_types, + ) + ) + + table_format = "" + if ctx.dialect == "athena": + # INCREMENTAL_BY_UNIQUE_KEY uses MERGE which is only supported in Athena on Iceberg tables + table_format = "table_format iceberg," + + context.upsert_model( + load_sql_based_model( + d.parse( + f"""MODEL ( + name {schema}.test_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key 
item_id, + batch_size 1 + ), + {table_format} + start '2020-01-01', + end '2020-01-07', + cron '@daily' + ); + + select * from {schema}.seed_model + where event_date between @start_date and @end_date""", + ) + ) + ) + + try: + context.plan(auto_apply=True, no_prompts=True) + + test_model = context.get_model(f"{schema}.test_model") + normalized_schema_name = test_model.fully_qualified_table.db + results = ctx.get_metadata_results(normalized_schema_name) + assert "test_model" in results.views + + actual_df = ( + ctx.get_current_data(test_model.fqn).sort_values(by="event_date").reset_index(drop=True) + ) + actual_df["event_date"] = actual_df["event_date"].astype(str) + assert actual_df.count()[0] == 3 + + expected_df = pd.DataFrame( + [[2, "2020-01-01"], [3, "2020-01-03"], [1, "2020-01-07"]], + columns=actual_df.columns, + ).sort_values(by="event_date") + + pd.testing.assert_frame_equal( + actual_df, + expected_df, + check_dtype=False, + ) + + finally: + ctx.cleanup(context) + + +def test_incremental_by_unique_key_model_when_matched(ctx: TestContext): + if not ctx.supports_merge: + pytest.skip(f"{ctx.dialect} on {ctx.gateway} doesnt support merge") + + # DuckDB and some other engines use logical_merge which doesn't support when_matched + if isinstance(ctx.engine_adapter, LogicalMergeMixin): + pytest.skip( + f"{ctx.dialect} on {ctx.gateway} uses logical merge which doesn't support when_matched" + ) + + def _mutate_config(current_gateway_name: str, config: Config): + connection = config.gateways[current_gateway_name].connection + connection.concurrent_tasks = 1 + if current_gateway_name == "inttest_redshift": + connection.enable_merge = True + + context = ctx.create_context(_mutate_config) + schema = ctx.schema(TEST_SCHEMA) + + # Create seed data with multiple days + seed_query = ctx.input_data( + pd.DataFrame( + [ + [1, "item_a", 100, "2020-01-01"], + [2, "item_b", 200, "2020-01-01"], + [1, "item_a_changed", 150, "2020-01-02"], # Same item_id, different name and value 
+ [2, "item_b_changed", 250, "2020-01-02"], # Same item_id, different name and value + [3, "item_c", 300, "2020-01-02"], # New item on day 2 + ], + columns=["item_id", "name", "value", "event_date"], + ), + columns_to_types={ + "item_id": exp.DataType.build("integer"), + "name": exp.DataType.build("text"), + "value": exp.DataType.build("integer"), + "event_date": exp.DataType.build("date"), + }, + ) + context.upsert_model( + create_sql_model(name=f"{schema}.seed_model", query=seed_query, kind="FULL") + ) + + table_format = "" + if ctx.dialect == "athena": + # INCREMENTAL_BY_UNIQUE_KEY uses MERGE which is only supported in Athena on Iceberg tables + table_format = "table_format iceberg," + + # Create model with when_matched clause that only updates the value column + # BUT keeps the existing name column unchanged + # batch_size=1 is so that we trigger merge on second batch and verify behaviour of when_matched + context.upsert_model( + load_sql_based_model( + d.parse( + f"""MODEL ( + name {schema}.test_model_when_matched, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key item_id, + batch_size 1, + merge_filter source.event_date > target.event_date, + when_matched WHEN MATCHED THEN UPDATE SET target.value = source.value, target.event_date = source.event_date + ), + {table_format} + start '2020-01-01', + end '2020-01-02', + cron '@daily' + ); + + select item_id, name, value, event_date + from {schema}.seed_model + where event_date between @start_date and @end_date""", + ) + ) + ) + + try: + # Initial plan to create the model and run it + context.plan(auto_apply=True, no_prompts=True) + + test_model = context.get_model(f"{schema}.test_model_when_matched") + + # Verify that the model has the when_matched clause and merge_filter + assert test_model.kind.when_matched is not None + assert ( + test_model.kind.when_matched.sql() + == '(WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."value" = "__MERGE_SOURCE__"."value", "__MERGE_TARGET__"."event_date" = 
"__MERGE_SOURCE__"."event_date")' + ) + assert test_model.merge_filter is not None + assert ( + test_model.merge_filter.sql() + == '"__MERGE_SOURCE__"."event_date" > "__MERGE_TARGET__"."event_date"' + ) + + actual_df = ( + ctx.get_current_data(test_model.fqn).sort_values(by="item_id").reset_index(drop=True) + ) + + # Expected results after batch processing: + # - Day 1: Items 1 and 2 are inserted (first insert) + # - Day 2: Items 1 and 2 are merged (when_matched clause preserves names but updates values/dates) + # Item 3 is inserted as new + expected_df = ( + pd.DataFrame( + [ + [1, "item_a", 150, "2020-01-02"], # name from day 1, value and date from day 2 + [2, "item_b", 250, "2020-01-02"], # name from day 1, value and date from day 2 + [3, "item_c", 300, "2020-01-02"], # new item from day 2 + ], + columns=["item_id", "name", "value", "event_date"], + ) + .sort_values(by="item_id") + .reset_index(drop=True) + ) + + # Convert date columns to string for comparison + actual_df["event_date"] = actual_df["event_date"].astype(str) + expected_df["event_date"] = expected_df["event_date"].astype(str) + + pd.testing.assert_frame_equal( + actual_df, + expected_df, + check_dtype=False, + ) + + finally: + ctx.cleanup(context) + + +def test_managed_model_upstream_forward_only(ctx: TestContext): + """ + This scenario goes as follows: + - A managed model B is a downstream dependency of an incremental model A + (as a sidenote: this is an incorrect use of managed models, they should really only reference external models, but we dont prevent it specifically to be more user friendly) + - User plans a forward-only change against Model A in a virtual environment "dev" + - This causes a new non-deployable snapshot of Model B in "dev". 
+ - In these situations, we create a normal table for Model B, not a managed table + - User modifies model B and applies a plan in "dev" + - This should also result in a normal table + - User decides they want to deploy so they run their plan against prod + - We need to ensure we ignore the normal table for Model B (it was just a dev preview) and create a new managed table for prod + - Upon apply to prod, Model B should be completely recreated as a managed table + """ + + if not ctx.engine_adapter.SUPPORTS_MANAGED_MODELS: + pytest.skip("This test only runs for engines that support managed models") + + def _run_plan(sqlmesh_context: Context, environment: str = None) -> PlanResults: + plan: Plan = sqlmesh_context.plan(auto_apply=True, no_prompts=True, environment=environment) + return PlanResults.create(plan, ctx, schema) + + context = ctx.create_context() + schema = ctx.add_test_suffix(TEST_SCHEMA) + + model_a = load_sql_based_model( + d.parse( # type: ignore + f""" + MODEL ( + name {schema}.upstream_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ts, + forward_only True + ), + ); + + SELECT 1 as id, 'foo' as name, current_timestamp as ts; + """ + ) + ) + + model_b = load_sql_based_model( + d.parse( # type: ignore + f""" + MODEL ( + name {schema}.managed_model, + kind MANAGED, + physical_properties ( + target_lag = '5 minutes' + ) + ); + + SELECT * from {schema}.upstream_model; + """ + ) + ) + + context.upsert_model(model_a) + context.upsert_model(model_b) + + plan_1 = _run_plan(context) + + assert plan_1.snapshot_for(model_a).change_category == SnapshotChangeCategory.BREAKING + assert not plan_1.snapshot_for(model_a).is_forward_only + assert plan_1.snapshot_for(model_b).change_category == SnapshotChangeCategory.BREAKING + assert not plan_1.snapshot_for(model_b).is_forward_only + + # so far so good, model_a should exist as a normal table, model b should be a managed table and the prod views should exist + assert len(plan_1.schema_metadata.views) == 2 + 
assert plan_1.snapshot_for(model_a).model.view_name in plan_1.schema_metadata.views + assert plan_1.snapshot_for(model_b).model.view_name in plan_1.schema_metadata.views + + assert len(plan_1.internal_schema_metadata.tables) == 1 + + assert plan_1.table_name_for(model_a) in plan_1.internal_schema_metadata.tables + assert ( + plan_1.table_name_for(model_b) not in plan_1.internal_schema_metadata.tables + ) # because its a managed table + + assert len(plan_1.internal_schema_metadata.managed_tables) == 1 + assert plan_1.table_name_for(model_b) in plan_1.internal_schema_metadata.managed_tables + assert ( + plan_1.dev_table_name_for(model_b) not in plan_1.internal_schema_metadata.managed_tables + ) # the dev table should not be created as managed + + # Let's modify model A with a breaking change and plan it against a dev environment. This should trigger a forward-only plan + new_model_a = load_sql_based_model( + d.parse( # type: ignore + f""" + MODEL ( + name {schema}.upstream_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ts, + forward_only True + ), + ); + + SELECT 1 as id, 'foo' as name, 'bar' as extra, current_timestamp as ts; + """ + ) + ) + context.upsert_model(new_model_a) + + # apply plan to dev environment + plan_2 = _run_plan(context, "dev") + + assert plan_2.plan.has_changes + assert len(plan_2.plan.modified_snapshots) == 2 + assert plan_2.snapshot_for(new_model_a).change_category == SnapshotChangeCategory.NON_BREAKING + assert plan_2.snapshot_for(new_model_a).is_forward_only + assert plan_2.snapshot_for(model_b).change_category == SnapshotChangeCategory.NON_BREAKING + assert not plan_2.snapshot_for(model_b).is_forward_only + + # verify that the new snapshots were created correctly + # the forward-only change to model A should be in a new table separate from the one created in the first plan + # since model B depends on an upstream model with a forward-only change, it should also get recreated, but as a normal table, not a managed table + assert 
plan_2.table_name_for(model_a) == plan_1.table_name_for( + model_a + ) # no change in the main table because the dev preview changes go to the dev table + assert plan_2.dev_table_name_for(model_a) != plan_1.dev_table_name_for( + model_a + ) # it creates a new dev table to hold the dev preview + assert plan_2.dev_table_name_for(model_a) in plan_2.internal_schema_metadata.tables + + assert plan_2.table_name_for(model_b) != plan_1.table_name_for( + model_b + ) # model b gets a new table + assert plan_2.dev_table_name_for(model_b) != plan_1.dev_table_name_for( + model_b + ) # model b gets a new dev table as well + assert ( + plan_2.table_name_for(model_b) not in plan_2.internal_schema_metadata.tables + ) # the new main table is not actually created, because it was triggered by a forward-only change. downstream models use the dev table + assert plan_2.table_name_for(model_b) not in plan_2.internal_schema_metadata.managed_tables + assert ( + plan_2.dev_table_name_for(model_b) in plan_2.internal_schema_metadata.tables + ) # dev tables are always regular tables for managed models + + # modify model B, still in the dev environment + new_model_b = load_sql_based_model( + d.parse( # type: ignore + f""" + MODEL ( + name {schema}.managed_model, + kind MANAGED, + physical_properties ( + target_lag = '5 minutes' + ) + ); + + SELECT *, 'modified' as extra_b from {schema}.upstream_model; + """ + ) + ) + context.upsert_model(new_model_b) + + plan_3 = _run_plan(context, "dev") + + assert plan_3.plan.has_changes + assert len(plan_3.plan.modified_snapshots) == 1 + assert ( + plan_3.modified_snapshot_for(model_b).change_category == SnapshotChangeCategory.NON_BREAKING + ) + + # model A should be unchanged + # the new model B should be a normal table, not a managed table + assert plan_3.table_name_for(model_a) == plan_2.table_name_for(model_a) + assert plan_3.dev_table_name_for(model_a) == plan_2.dev_table_name_for(model_a) + assert plan_3.table_name_for(model_b) != 
plan_2.table_name_for(model_b) + assert plan_3.dev_table_name_for(model_b) != plan_2.table_name_for(model_b) + + assert ( + plan_3.table_name_for(model_b) not in plan_3.internal_schema_metadata.tables + ) # still using the dev table, no main table created + assert plan_3.dev_table_name_for(model_b) in plan_3.internal_schema_metadata.tables + assert ( + plan_3.table_name_for(model_b) not in plan_3.internal_schema_metadata.managed_tables + ) # still not a managed table + + # apply plan to prod + plan_4 = _run_plan(context) + + assert plan_4.plan.has_changes + assert plan_4.snapshot_for(model_a).change_category == SnapshotChangeCategory.NON_BREAKING + assert plan_4.snapshot_for(model_a).is_forward_only + assert plan_4.snapshot_for(model_b).change_category == SnapshotChangeCategory.NON_BREAKING + assert not plan_4.snapshot_for(model_b).is_forward_only + + # verify the Model B table is created as a managed table in prod + assert plan_4.table_name_for(model_b) == plan_3.table_name_for( + model_b + ) # the model didnt change; the table should still have the same name + assert ( + plan_4.table_name_for(model_b) not in plan_4.internal_schema_metadata.tables + ) # however, it should be a managed table, not a normal table + assert plan_4.table_name_for(model_b) in plan_4.internal_schema_metadata.managed_tables + + +@pytest.mark.parametrize( + "column_type, input_data, expected_results", + [ + (DATA_TYPE.BOOLEAN, (True, False, None), ("1", "0", None)), + ( + DATA_TYPE.DATE, + (datetime(2023, 1, 1).date(), datetime(2024, 12, 15, 5, 30, 0).date(), None), + ("2023-01-01", "2024-12-15", None), + ), + ( + DATA_TYPE.TIMESTAMP, + ( + datetime(2023, 1, 1), + datetime(2023, 1, 1, 13, 14, 15), + datetime(2023, 1, 1, 13, 14, 15, 123456), + None, + ), + ( + "2023-01-01 00:00:00.000000", + "2023-01-01 13:14:15.000000", + "2023-01-01 13:14:15.123456", + None, + ), + ), + ( + DATA_TYPE.DATETIME, + ( + datetime(2023, 1, 1), + datetime(2023, 1, 1, 13, 14, 15), + datetime(2023, 1, 1, 13, 14, 
15, 123456), + None, + ), + ( + "2023-01-01 00:00:00.000000", + "2023-01-01 13:14:15.000000", + "2023-01-01 13:14:15.123456", + None, + ), + ), + ( + DATA_TYPE.TIMESTAMPTZ, + ( + pytz.timezone("America/Los_Angeles").localize(datetime(2023, 1, 1)), + pytz.timezone("Europe/Athens").localize(datetime(2023, 1, 1, 13, 14, 15)), + pytz.timezone("Pacific/Auckland").localize( + datetime(2023, 1, 1, 13, 14, 15, 123456) + ), + None, + ), + ( + "2023-01-01 08:00:00.000000", + "2023-01-01 11:14:15.000000", + "2023-01-01 00:14:15.123456", + None, + ), + ), + ], +) +def test_value_normalization( + ctx: TestContext, + column_type: exp.DataType.Type, + input_data: t.Tuple[t.Any, ...], + expected_results: t.Tuple[str, ...], +) -> None: + # Skip TIMESTAMPTZ tests for engines that don't support it + if column_type == exp.DataType.Type.TIMESTAMPTZ: + if ctx.dialect == "trino" and ctx.engine_adapter.current_catalog_type == "hive": + pytest.skip("Trino on Hive doesn't support TIMESTAMP WITH TIME ZONE fields") + if ctx.dialect == "fabric": + pytest.skip("Fabric doesn't support TIMESTAMP WITH TIME ZONE fields") + + if not isinstance(ctx.engine_adapter, RowDiffMixin): + pytest.skip( + "Value normalization tests are only relevant for engines with row diffing implemented" + ) + + full_column_type = exp.DataType.build(column_type) + + # resolve dialect-specific types + if column_type in (DATA_TYPE.DATETIME, DATA_TYPE.TIMESTAMP, DATA_TYPE.TIMESTAMPTZ): + if ctx.dialect in ("mysql", "trino"): + # MySQL needs DATETIME(6) instead of DATETIME or subseconds will be truncated. 
+ # It also needs TIMESTAMP(6) as the column type for CREATE TABLE or the truncation will occur + full_column_type = exp.DataType.build( + column_type, + expressions=[ + exp.DataTypeParam( + this=exp.Literal.number(ctx.engine_adapter.MAX_TIMESTAMP_PRECISION) + ) + ], + ) + if ctx.dialect == "tsql" and column_type == exp.DataType.Type.DATETIME: + full_column_type = exp.DataType.build("DATETIME2", dialect="tsql") + + columns_to_types = { + "_idx": exp.DataType.build(DATA_TYPE.INT), + "value": full_column_type, + } + + input_data_with_idx = [(idx, value) for idx, value in enumerate(input_data)] + + test_table = normalize_identifiers( + exp.to_table(ctx.table("test_value_normalization")), dialect=ctx.dialect + ) + columns_to_types_normalized = { + normalize_identifiers(k, dialect=ctx.dialect).sql(dialect=ctx.dialect): v + for k, v in columns_to_types.items() + } + + ctx.engine_adapter.create_table( + table_name=test_table, target_columns_to_types=columns_to_types_normalized + ) + data_query = next(select_from_values(input_data_with_idx, columns_to_types_normalized)) + ctx.engine_adapter.insert_append( + table_name=test_table, + query_or_df=data_query, + target_columns_to_types=columns_to_types_normalized, + ) + + query = ( + exp.select( + ctx.engine_adapter.normalize_value( + normalize_identifiers(exp.to_column("value"), dialect=ctx.dialect), + columns_to_types["value"], + decimal_precision=3, + timestamp_precision=ctx.engine_adapter.MAX_TIMESTAMP_PRECISION, + ).as_("value") + ) + .from_(test_table) + .order_by(normalize_identifiers("_idx", dialect=ctx.dialect)) + ) + result = ctx.engine_adapter.fetchdf(query, quote_identifiers=True) + assert len(result) == len(expected_results) + + def truncate_timestamp(ts: str, precision: int) -> str: + if not ts: + return ts + + digits_to_truncate = 6 - precision + return ts[:-digits_to_truncate] if digits_to_truncate > 0 else ts + + if full_column_type.is_type(DATA_TYPE.DATETIME, DATA_TYPE.TIMESTAMP, DATA_TYPE.TIMESTAMPTZ): + # 
truncate our expected results to the engine precision + expected_results = tuple( + truncate_timestamp(e, ctx.engine_adapter.MAX_TIMESTAMP_PRECISION) + for e in expected_results + ) + + for idx, row in enumerate(result.itertuples(index=False)): + assert row.value == expected_results[idx] + + +def test_table_diff_grain_check_single_key(ctx: TestContext): + if not isinstance(ctx.engine_adapter, RowDiffMixin): + pytest.skip("table_diff tests are only relevant for engines with row diffing implemented") + + src_table = ctx.table("source") + target_table = ctx.table("target") + + columns_to_types = { + "key1": exp.DataType.build("int"), + "value": exp.DataType.build("varchar"), + } + + ctx.engine_adapter.create_table(src_table, columns_to_types) + ctx.engine_adapter.create_table(target_table, columns_to_types) + + src_data = [ + (1, "one"), + (2, "two"), + (None, "three"), + (4, "four"), # missing in target + ] + + target_data = [ + (1, "one"), + (2, "two"), + (None, "three"), + (5, "five"), # missing in src + (6, "six"), # missing in src + ] + + ctx.engine_adapter.replace_query( + src_table, pd.DataFrame(src_data, columns=columns_to_types.keys()), columns_to_types + ) + ctx.engine_adapter.replace_query( + target_table, pd.DataFrame(target_data, columns=columns_to_types.keys()), columns_to_types + ) + + table_diff = TableDiff( + adapter=ctx.engine_adapter, + source=exp.table_name(src_table), + target=exp.table_name(target_table), + on=['"key1"'], + ) + + row_diff = table_diff.row_diff() + + assert row_diff.full_match_count == 2 + assert row_diff.full_match_pct == 57.14 + assert row_diff.s_only_count == 1 + assert row_diff.t_only_count == 2 + assert row_diff.stats["key1_matches"] == 4 + assert row_diff.stats["value_matches"] == 2 + assert row_diff.stats["join_count"] == 2 + assert row_diff.stats["null_grain_count"] == 2 + assert row_diff.stats["s_count"] == 3 + assert row_diff.stats["distinct_count_s"] == 3 + assert row_diff.stats["t_count"] == 4 + assert 
row_diff.stats["distinct_count_t"] == 4 + assert row_diff.stats["s_only_count"] == 1 + assert row_diff.stats["t_only_count"] == 2 + assert row_diff.s_sample.shape == (1, 2) + assert row_diff.t_sample.shape == (2, 2) + + +def test_table_diff_grain_check_multiple_keys(ctx: TestContext): + if not isinstance(ctx.engine_adapter, RowDiffMixin): + pytest.skip("table_diff tests are only relevant for engines with row diffing implemented") + + src_table = ctx.table("source") + target_table = ctx.table("target") + + columns_to_types = { + "key1": exp.DataType.build("int"), + "key2": exp.DataType.build("varchar"), + "value": exp.DataType.build("varchar"), + } + + ctx.engine_adapter.create_table(src_table, columns_to_types) + ctx.engine_adapter.create_table(target_table, columns_to_types) + + src_data = [ + (1, 1, 1), + (7, 4, 2), + (None, 3, 3), + (None, None, 3), + (1, 2, 2), + (4, None, 3), + (2, 3, 2), + ] + + target_data = src_data + [(1, 6, 1), (1, 5, 3), (None, 2, 3)] + + ctx.engine_adapter.insert_append( + src_table, next(select_from_values(src_data, columns_to_types)), columns_to_types + ) + ctx.engine_adapter.insert_append( + target_table, next(select_from_values(target_data, columns_to_types)), columns_to_types + ) + + table_diff = TableDiff( + adapter=ctx.engine_adapter, + source=exp.table_name(src_table), + target=exp.table_name(target_table), + on=['"key1"', '"key2"'], + ) + + row_diff = table_diff.row_diff() + + assert row_diff.full_match_count == 7 + assert row_diff.full_match_pct == 82.35 + assert row_diff.s_only_count == 0 + assert row_diff.t_only_count == 3 + assert row_diff.stats["join_count"] == 7 + assert ( + row_diff.stats["null_grain_count"] == 4 + ) # null grain currently (2025-07-24) means "any key column is null" as opposed to "all key columns are null" + assert row_diff.stats["distinct_count_s"] == 7 + assert row_diff.stats["s_count"] == row_diff.stats["distinct_count_s"] + assert row_diff.stats["distinct_count_t"] == 10 + assert 
row_diff.stats["t_count"] == row_diff.stats["distinct_count_t"] + assert row_diff.s_sample.shape == (row_diff.s_only_count, 3) + assert row_diff.t_sample.shape == (row_diff.t_only_count, 3) + + +def test_table_diff_arbitrary_condition(ctx: TestContext): + if not isinstance(ctx.engine_adapter, RowDiffMixin): + pytest.skip("table_diff tests are only relevant for engines with row diffing implemented") + + src_table = ctx.table("source") + target_table = ctx.table("target") + + columns_to_types_src = { + "id": exp.DataType.build("int"), + "value": exp.DataType.build("varchar"), + "ts": exp.DataType.build("timestamp"), + } + + columns_to_types_target = { + "item_id": exp.DataType.build("int"), + "value": exp.DataType.build("varchar"), + "ts": exp.DataType.build("timestamp"), + } + + ctx.engine_adapter.create_table(src_table, columns_to_types_src) + ctx.engine_adapter.create_table(target_table, columns_to_types_target) + + src_data = [ + (1, "one", datetime(2023, 1, 1, 12, 13, 14)), + (2, "two", datetime(2023, 10, 1, 8, 13, 14)), + (3, "three", datetime(2024, 1, 1, 8, 13, 14)), + ] + + target_data = src_data + [(4, "four", datetime(2024, 2, 1, 8, 13, 14))] + + ctx.engine_adapter.replace_query( + src_table, pd.DataFrame(src_data, columns=columns_to_types_src.keys()), columns_to_types_src + ) + ctx.engine_adapter.replace_query( + target_table, + pd.DataFrame(target_data, columns=columns_to_types_target.keys()), + columns_to_types_target, + ) + + table_diff = TableDiff( + adapter=ctx.engine_adapter, + source=exp.table_name(src_table), + target=exp.table_name(target_table), + on=parse_one('"s"."id" = "t"."item_id"', into=exp.Condition), + where=parse_one("to_char(\"ts\", 'YYYY') = '2024'", dialect="postgres", into=exp.Condition), + ) + + row_diff = table_diff.row_diff() + + assert row_diff.full_match_count == 1 + assert row_diff.full_match_pct == 66.67 + assert row_diff.s_only_count == 0 + assert row_diff.t_only_count == 1 + assert row_diff.stats["value_matches"] == 1 + 
assert row_diff.stats["ts_matches"] == 1 + assert row_diff.stats["join_count"] == 1 + assert row_diff.stats["null_grain_count"] == 0 + assert row_diff.stats["s_count"] == 1 + assert row_diff.stats["distinct_count_s"] == 1 + assert row_diff.stats["t_count"] == 2 + assert row_diff.stats["distinct_count_t"] == 2 + assert row_diff.stats["s_only_count"] == 0 + assert row_diff.stats["t_only_count"] == 1 + assert row_diff.s_sample.shape == (0, 3) + assert row_diff.t_sample.shape == (1, 3) + + +def test_table_diff_identical_dataset(ctx: TestContext): + if not isinstance(ctx.engine_adapter, RowDiffMixin): + pytest.skip("table_diff tests are only relevant for engines with row diffing implemented") + + src_table = ctx.table("source") + target_table = ctx.table("target") + + columns_to_types = { + "key1": exp.DataType.build("int"), + "key2": exp.DataType.build("varchar"), + "value": exp.DataType.build("varchar"), + } + + ctx.engine_adapter.create_table(src_table, columns_to_types) + ctx.engine_adapter.create_table(target_table, columns_to_types) + + src_data = [ + (1, 1, 1), + (7, 4, 2), + (1, 2, 2), + (4, 1, 3), + (2, 3, 2), + ] + + target_data = src_data + + ctx.engine_adapter.insert_append( + src_table, next(select_from_values(src_data, columns_to_types)), columns_to_types + ) + ctx.engine_adapter.insert_append( + target_table, next(select_from_values(target_data, columns_to_types)), columns_to_types + ) + + table_diff = TableDiff( + adapter=ctx.engine_adapter, + source=exp.table_name(src_table), + target=exp.table_name(target_table), + on=['"key1"', '"key2"'], + ) + + row_diff = table_diff.row_diff() + + assert row_diff.full_match_count == 5 + assert row_diff.full_match_pct == 100 + assert row_diff.s_only_count == 0 + assert row_diff.t_only_count == 0 + assert row_diff.stats["join_count"] == 5 + assert row_diff.stats["null_grain_count"] == 0 + assert row_diff.stats["s_count"] == 5 + assert row_diff.stats["distinct_count_s"] == 5 + assert row_diff.stats["t_count"] == 5 + 
assert row_diff.stats["distinct_count_t"] == 5 + assert row_diff.stats["s_only_count"] == 0 + assert row_diff.stats["t_only_count"] == 0 + assert row_diff.s_sample.shape == (0, 3) + assert row_diff.t_sample.shape == (0, 3) + + +def test_state_migrate_from_scratch(ctx: TestContext): + test_schema = ctx.add_test_suffix("state") + ctx._schemas.append(test_schema) # so it gets cleaned up when the test finishes + + def _use_warehouse_as_state_connection(gateway_name: str, config: Config): + warehouse_connection = config.gateways[gateway_name].connection + assert isinstance(warehouse_connection, ConnectionConfig) + if warehouse_connection.is_forbidden_for_state_sync: + pytest.skip( + f"{warehouse_connection.type_} doesnt support being used as a state connection" + ) + + # this triggers the fallback to using the warehouse as a state connection + config.gateways[gateway_name].state_connection = None + assert config.get_state_connection(gateway_name) is None + + config.gateways[gateway_name].state_schema = test_schema + + sqlmesh_context = ctx.create_context( + config_mutator=_use_warehouse_as_state_connection, ephemeral_state_connection=False + ) + assert sqlmesh_context.config.get_state_schema(ctx.gateway) == test_schema + + state_sync = ( + sqlmesh_context._new_state_sync() + ) # this prevents migrate() being called which it does if you access the state_sync property + assert isinstance(state_sync, EngineAdapterStateSync) + assert state_sync.engine_adapter.dialect == ctx.dialect + + # will throw if one of the migrations produces an error, which can happen if we forget to take quoting or normalization into account + sqlmesh_context.migrate() + + +def test_python_model_column_order(ctx_df: TestContext, tmp_path: pathlib.Path): + ctx = ctx_df + + model_name = ctx.table("TEST") + + (tmp_path / "models").mkdir() + + # note: this model deliberately defines the columns in the @model definition to be in a different order than what + # is returned by the DataFrame within the 
model + model_path = tmp_path / "models" / "python_model.py" + + model_definitions = { + # python model that emits a Pandas dataframe + "pandas": """ +import pandas as pd # noqa: TID253 +import typing as t +from sqlmesh import ExecutionContext, model + +@model( + 'MODEL_NAME', + columns={ + "id": "int", + "name": "text" + }, + dialect='DIALECT', + TABLE_FORMAT +) +def execute( + context: ExecutionContext, + **kwargs: t.Any, +) -> pd.DataFrame: + record = { "name": "foo", "id": 1 } if context.engine_adapter.dialect != 'snowflake' else { "NAME": "foo", "ID": 1 } + return pd.DataFrame([ + record + ]) + """, + # python model that emits a PySpark dataframe + "pyspark": """ +from pyspark.sql import DataFrame, Row +import typing as t +from sqlmesh import ExecutionContext, model + +@model( + 'MODEL_NAME', + columns={ + "id": "int", + "name": "varchar" + }, + dialect='DIALECT' +) +def execute( + context: ExecutionContext, + **kwargs: t.Any, +) -> DataFrame: + return context.spark.createDataFrame([ + Row(name="foo", id=1) + ]) + """, + # python model that emits a BigFrame dataframe + "bigframe": """ +from bigframes.pandas import DataFrame +import typing as t +from sqlmesh import ExecutionContext, model + +@model( + 'MODEL_NAME', + columns={ + "id": "int", + "name": "varchar" + }, + dialect="DIALECT" +) +def execute( + context: ExecutionContext, + **kwargs: t.Any, +) -> DataFrame: + return DataFrame({'name': ['foo'], 'id': [1]}, session=context.bigframe) + """, + # python model that emits a Snowpark dataframe + "snowpark": """ +from snowflake.snowpark.dataframe import DataFrame +import typing as t +from sqlmesh import ExecutionContext, model + +@model( + 'MODEL_NAME', + columns={ + "id": "int", + "name": "varchar" + }, + dialect="DIALECT" +) +def execute( + context: ExecutionContext, + **kwargs: t.Any, +) -> DataFrame: + return context.snowpark.create_dataframe([["foo", 1]], schema=["NAME", "ID"]) + """, + } + + model_path.write_text( + ( + model_definitions[ctx.df_type] + 
.replace("MODEL_NAME", model_name.sql(dialect=ctx.dialect)) + .replace("DIALECT", ctx.dialect) + .replace( + "TABLE_FORMAT", + f"table_format='{ctx.default_table_format}'" if ctx.default_table_format else "", + ) + ) + ) + + sqlmesh_ctx = ctx.create_context(path=tmp_path) + + assert len(sqlmesh_ctx.models) == 1 + + plan = sqlmesh_ctx.plan(auto_apply=True) + assert len(plan.new_snapshots) == 1 + + engine_adapter = sqlmesh_ctx.engine_adapter + + query = exp.select("*").from_(plan.environment.snapshots[0].fully_qualified_table) + df = engine_adapter.fetchdf(query, quote_identifiers=True) + assert len(df) == 1 + + # This test uses the dialect= on the model. + # For dialect=snowflake, this means that the identifiers are all normalized to uppercase by default + expected_result = ( + {"id": 1, "name": "foo"} if ctx.dialect != "snowflake" else {"ID": 1, "NAME": "foo"} + ) + assert df.iloc[0].to_dict() == expected_result + + +def test_identifier_length_limit(ctx: TestContext): + adapter = ctx.engine_adapter + if adapter.MAX_IDENTIFIER_LENGTH is None: + pytest.skip(f"Engine {adapter.dialect} does not have identifier length limits set.") + + long_table_name = "a" * (adapter.MAX_IDENTIFIER_LENGTH + 1) + + match = f"Identifier name '{long_table_name}' (length {len(long_table_name)}) exceeds {adapter.dialect.capitalize()}'s max identifier limit of {adapter.MAX_IDENTIFIER_LENGTH} characters" + with pytest.raises( + SQLMeshError, + match=re.escape(match), + ): + adapter.create_table(long_table_name, {"col": exp.DataType.build("int")}) + + +@pytest.mark.parametrize( + "environment_suffix_target", + [ + EnvironmentSuffixTarget.TABLE, + EnvironmentSuffixTarget.SCHEMA, + EnvironmentSuffixTarget.CATALOG, + ], +) +@pytest.mark.xdist_group("serial") +def test_janitor( + ctx: TestContext, tmp_path: pathlib.Path, environment_suffix_target: EnvironmentSuffixTarget +): + if ( + environment_suffix_target == EnvironmentSuffixTarget.CATALOG + and not 
ctx.engine_adapter.SUPPORTS_CREATE_DROP_CATALOG + ): + pytest.skip("Engine does not support catalog-based virtual environments") + + schema = ctx.schema() # catalog.schema + parsed_schema = d.to_schema(schema) + + init_example_project(tmp_path, ctx.engine_type, schema_name=parsed_schema.db) + + def _set_config(gateway: str, config: Config) -> None: + config.environment_suffix_target = environment_suffix_target + config.model_defaults.dialect = ctx.dialect + config.gateways[gateway].connection.concurrent_tasks = 1 + + sqlmesh = ctx.create_context(path=tmp_path, config_mutator=_set_config) + + sqlmesh.plan(auto_apply=True) + + # create a new model in dev + (tmp_path / "models" / "new_model.sql").write_text(f""" + MODEL ( + name {schema}.new_model, + kind FULL + ); + + select * from {schema}.full_model + """) + sqlmesh.load() + + result = sqlmesh.plan(environment="dev", auto_apply=True) + assert result.context_diff.is_new_environment + assert len(result.context_diff.new_snapshots) == 1 + new_model = list(result.context_diff.new_snapshots.values())[0] + assert "new_model" in new_model.name.lower() + + # check physical objects + snapshot_table_name = exp.to_table(new_model.table_name(), dialect=ctx.dialect) + snapshot_schema = parsed_schema.copy() + snapshot_schema.set( + "db", exp.to_identifier(snapshot_table_name.db) + ) # we need this to be catalog.schema and not just schema for environment_suffix_target: catalog + + prod_schema = normalize_identifiers(d.to_schema(schema), dialect=ctx.dialect) + dev_env_schema = prod_schema.copy() + if environment_suffix_target == EnvironmentSuffixTarget.CATALOG: + dev_env_schema.set("catalog", exp.to_identifier(f"{prod_schema.catalog}__dev")) + else: + dev_env_schema.set("db", exp.to_identifier(f"{prod_schema.db}__dev")) + normalize_identifiers(dev_env_schema, dialect=ctx.dialect) + + md = ctx.get_metadata_results(prod_schema) + if environment_suffix_target == EnvironmentSuffixTarget.TABLE: + assert sorted([v.lower() for v in 
md.views]) == [ + "full_model", + "incremental_model", + "new_model__dev", + "seed_model", + ] + else: + assert sorted([v.lower() for v in md.views]) == [ + "full_model", + "incremental_model", + "seed_model", + ] + assert not md.tables + assert not md.managed_tables + + if environment_suffix_target != EnvironmentSuffixTarget.TABLE: + # note: this is "catalog__dev.schema" for EnvironmentSuffixTarget.CATALOG and "catalog.schema__dev" for EnvironmentSuffixTarget.SCHEMA + md = ctx.get_metadata_results(dev_env_schema) + assert [v.lower() for v in md.views] == ["new_model"] + assert not md.tables + assert not md.managed_tables + + md = ctx.get_metadata_results(snapshot_schema) + assert not md.views + assert not md.managed_tables + assert sorted(t.split("__")[1].lower() for t in md.tables) == [ + "full_model", + "incremental_model", + "new_model", + "seed_model", + ] + + # invalidate dev and run the janitor to clean it up + sqlmesh.invalidate_environment("dev") + assert sqlmesh.run_janitor( + ignore_ttl=True + ) # ignore_ttl to delete the new_model snapshot even though it hasnt expired yet + + # there should be no dev environment or dev tables / schemas + md = ctx.get_metadata_results(prod_schema) + assert sorted([v.lower() for v in md.views]) == [ + "full_model", + "incremental_model", + "seed_model", + ] + assert not md.tables + assert not md.managed_tables + + if environment_suffix_target != EnvironmentSuffixTarget.TABLE: + if environment_suffix_target == EnvironmentSuffixTarget.SCHEMA: + md = ctx.get_metadata_results(dev_env_schema) + else: + try: + md = ctx.get_metadata_results(dev_env_schema) + except Exception as e: + # Most engines will raise an error when @set_catalog tries to set a catalog that doesnt exist + # in this case, we just swallow the error. 
We know this call already worked before in the earlier checks + md = MetadataResults() + + assert not md.views + assert not md.tables + assert not md.managed_tables + + if ctx.dialect == "fabric": + # TestContext is using a different EngineAdapter instance / connection pool instance to the SQLMesh context + # When the SQLMesh context drops :snapshot_schema using its EngineAdapter, connections in TestContext are unaware + # and still have their threadlocal "target_catalog" attribute pointing to a catalog that no longer exists + # Trying to establish a connection to a nonexistant catalog produces an error, so we close all connections here + # to clear the threadlocal attributes + ctx.engine_adapter.close() + + md = ctx.get_metadata_results(snapshot_schema) + assert not md.views + assert not md.managed_tables + assert sorted(t.split("__")[1].lower() for t in md.tables) == [ + "full_model", + "incremental_model", + "seed_model", + ] + + +def test_materialized_view_evaluation(ctx: TestContext): + adapter = ctx.engine_adapter + dialect = ctx.dialect + + if not adapter.SUPPORTS_MATERIALIZED_VIEWS: + pytest.skip(f"Skipping engine {dialect} as it does not support materialized views") + elif dialect in ("snowflake", "databricks"): + pytest.skip(f"Skipping {dialect} as they're not enabled on standard accounts") + + model_name = ctx.table("test_tbl") + mview_name = ctx.table("test_mview") + + sqlmesh = ctx.create_context() + + sqlmesh.upsert_model( + load_sql_based_model( + d.parse( + f""" + MODEL (name {model_name}, kind FULL); + + SELECT 1 AS col + """ + ) + ) + ) + + sqlmesh.upsert_model( + load_sql_based_model( + d.parse( + f""" + MODEL (name {mview_name}, kind VIEW (materialized true)); + + SELECT * FROM {model_name} + """ + ) + ) + ) + + def _assert_mview_value(value: int): + df = adapter.fetchdf(f"SELECT * FROM {mview_name.sql(dialect=dialect)}") + assert df["col"][0] == value + + # Case 1: Ensure that plan is successful and we can query the materialized view + 
sqlmesh.plan(auto_apply=True, no_prompts=True) + + _assert_mview_value(value=1) + + # Case 2: Ensure that we can change the underlying table and the materialized view is recreated + sqlmesh.upsert_model( + load_sql_based_model(d.parse(f"""MODEL (name {model_name}, kind FULL); SELECT 2 AS col""")) + ) + + logger = logging.getLogger("sqlmesh.core.snapshot.evaluator") + + with mock.patch.object(logger, "info") as mock_logger: + sqlmesh.plan(auto_apply=True, no_prompts=True) + + assert any("Replacing view" in call[0][0] for call in mock_logger.call_args_list) + + _assert_mview_value(value=2) + + +def test_unicode_characters(ctx: TestContext, tmp_path: Path): + # Engines that don't quote identifiers in views are incompatible with unicode characters in model names + # at the time of writing this is Spark/Trino and they do this for compatibility reasons. + # I also think Spark may not support unicode in general but that would need to be verified. + if not ctx.engine_adapter.QUOTE_IDENTIFIERS_IN_VIEWS: + pytest.skip("Skipping as these engines have issues with unicode characters in model names") + + model_name = "客户数据" + table = ctx.table(model_name).sql(dialect=ctx.dialect) + (tmp_path / "models").mkdir(exist_ok=True) + + model_def = f""" + MODEL ( + name {table}, + kind FULL, + dialect '{ctx.dialect}' + ); + SELECT 1 as id + """ + + (tmp_path / "models" / "客户数据.sql").write_text(model_def) + + context = ctx.create_context(path=tmp_path) + context.plan(auto_apply=True, no_prompts=True) + + results = ctx.get_metadata_results() + assert len(results.views) == 1 + assert results.views[0].lower() == model_name + + schema = d.to_schema(ctx.schema(), dialect=ctx.dialect) + schema_name = schema.args["db"].this + schema.args["db"].set("this", "sqlmesh__" + schema_name) + table_results = ctx.get_metadata_results(schema) + assert len(table_results.tables) == 1 + assert table_results.tables[0].lower().startswith(schema_name.lower() + "________") + + +def test_sync_grants_config(ctx: 
TestContext) -> None: + if not ctx.engine_adapter.SUPPORTS_GRANTS: + pytest.skip( + f"Skipping Test since engine adapter {ctx.engine_adapter.dialect} doesn't support grants" + ) + + table = ctx.table("sync_grants_integration") + select_privilege = ctx.get_select_privilege() + insert_privilege = ctx.get_insert_privilege() + update_privilege = ctx.get_update_privilege() + with ctx.create_users_or_roles("reader", "writer", "admin") as roles: + ctx.engine_adapter.create_table(table, {"id": exp.DataType.build("INT")}) + + initial_grants = { + select_privilege: [roles["reader"]], + insert_privilege: [roles["writer"]], + } + ctx.engine_adapter.sync_grants_config(table, initial_grants) + + current_grants = ctx.engine_adapter._get_current_grants_config(table) + assert set(current_grants.get(select_privilege, [])) == {roles["reader"]} + assert set(current_grants.get(insert_privilege, [])) == {roles["writer"]} + + target_grants = { + select_privilege: [roles["writer"], roles["admin"]], + update_privilege: [roles["admin"]], + } + ctx.engine_adapter.sync_grants_config(table, target_grants) + + synced_grants = ctx.engine_adapter._get_current_grants_config(table) + assert set(synced_grants.get(select_privilege, [])) == { + roles["writer"], + roles["admin"], + } + assert set(synced_grants.get(update_privilege, [])) == {roles["admin"]} + assert synced_grants.get(insert_privilege, []) == [] + + +def test_grants_sync_empty_config(ctx: TestContext): + if not ctx.engine_adapter.SUPPORTS_GRANTS: + pytest.skip( + f"Skipping Test since engine adapter {ctx.engine_adapter.dialect} doesn't support grants" + ) + + table = ctx.table("grants_empty_test") + select_privilege = ctx.get_select_privilege() + insert_privilege = ctx.get_insert_privilege() + with ctx.create_users_or_roles("user") as roles: + ctx.engine_adapter.create_table(table, {"id": exp.DataType.build("INT")}) + + initial_grants = { + select_privilege: [roles["user"]], + insert_privilege: [roles["user"]], + } + 
ctx.engine_adapter.sync_grants_config(table, initial_grants) + + initial_current_grants = ctx.engine_adapter._get_current_grants_config(table) + assert roles["user"] in initial_current_grants.get(select_privilege, []) + assert roles["user"] in initial_current_grants.get(insert_privilege, []) + + ctx.engine_adapter.sync_grants_config(table, {}) + + final_grants = ctx.engine_adapter._get_current_grants_config(table) + assert final_grants == {} + + +def test_grants_case_insensitive_grantees(ctx: TestContext): + if not ctx.engine_adapter.SUPPORTS_GRANTS: + pytest.skip( + f"Skipping Test since engine adapter {ctx.engine_adapter.dialect} doesn't support grants" + ) + + with ctx.create_users_or_roles("reader", "writer") as roles: + table = ctx.table("grants_quoted_test") + ctx.engine_adapter.create_table(table, {"id": exp.DataType.build("INT")}) + + reader = roles["reader"] + writer = roles["writer"] + select_privilege = ctx.get_select_privilege() + + if ctx.dialect == "bigquery": + # BigQuery labels are case sensitive, e.g. 
serviceAccount + lablel, grantee = writer.split(":", 1) + upper_case_writer = f"{lablel}:{grantee.upper()}" + else: + upper_case_writer = writer.upper() + + grants_config = {select_privilege: [reader, upper_case_writer]} + ctx.engine_adapter.sync_grants_config(table, grants_config) + + # Grantees are still in lowercase + current_grants = ctx.engine_adapter._get_current_grants_config(table) + assert reader in current_grants.get(select_privilege, []) + assert writer in current_grants.get(select_privilege, []) + + # Revoke writer + grants_config = {select_privilege: [reader.upper()]} + ctx.engine_adapter.sync_grants_config(table, grants_config) + + current_grants = ctx.engine_adapter._get_current_grants_config(table) + assert reader in current_grants.get(select_privilege, []) + assert writer not in current_grants.get(select_privilege, []) + + +def test_grants_plan(ctx: TestContext, tmp_path: Path): + if not ctx.engine_adapter.SUPPORTS_GRANTS: + pytest.skip( + f"Skipping Test since engine adapter {ctx.engine_adapter.dialect} doesn't support grants" + ) + + table = ctx.table("grant_model").sql(dialect="duckdb") + select_privilege = ctx.get_select_privilege() + insert_privilege = ctx.get_insert_privilege() + with ctx.create_users_or_roles("analyst", "etl_user") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + + model_def = f""" + MODEL ( + name {table}, + kind FULL, + grants ( + '{select_privilege}' = ['{roles["analyst"]}'] + ), + grants_target_layer 'all' + ); + SELECT 1 as id, CURRENT_DATE as created_date + """ + + (tmp_path / "models" / "grant_model.sql").write_text(model_def) + + context = ctx.create_context(path=tmp_path) + plan_result = context.plan(auto_apply=True, no_prompts=True) + + assert len(plan_result.new_snapshots) == 1 + snapshot = plan_result.new_snapshots[0] + + # Physical layer w/ grants + table_name = snapshot.table_name() + view_name = snapshot.qualified_view_name.for_environment( + plan_result.environment_naming_info, dialect=ctx.dialect + ) 
+ current_grants = ctx.engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=ctx.dialect) + ) + assert current_grants == {select_privilege: [roles["analyst"]]} + + # Virtual layer (view) w/ grants + virtual_grants = ctx.engine_adapter._get_current_grants_config( + exp.to_table(view_name, dialect=ctx.dialect) + ) + assert virtual_grants == {select_privilege: [roles["analyst"]]} + + # Update model with query change and new grants + updated_model = load_sql_based_model( + d.parse( + f""" + MODEL ( + name {table}, + kind FULL, + grants ( + '{select_privilege}' = ['{roles["analyst"]}', '{roles["etl_user"]}'], + '{insert_privilege}' = ['{roles["etl_user"]}'] + ), + grants_target_layer 'all' + ); + SELECT 1 as id, CURRENT_DATE as created_date, 'v2' as version + """, + default_dialect=context.default_dialect, + ), + dialect=context.default_dialect, + ) + context.upsert_model(updated_model) + + plan = context.plan(auto_apply=True, no_prompts=True) + plan_result = PlanResults.create(plan, ctx, ctx.add_test_suffix(TEST_SCHEMA)) + assert len(plan_result.plan.directly_modified) == 1 + + new_snapshot = plan_result.snapshot_for(updated_model) + assert new_snapshot is not None + + new_table_name = new_snapshot.table_name() + final_grants = ctx.engine_adapter._get_current_grants_config( + exp.to_table(new_table_name, dialect=ctx.dialect) + ) + expected_final_grants = { + select_privilege: [roles["analyst"], roles["etl_user"]], + insert_privilege: [roles["etl_user"]], + } + assert set(final_grants.get(select_privilege, [])) == set( + expected_final_grants[select_privilege] + ) + assert final_grants.get(insert_privilege, []) == expected_final_grants[insert_privilege] + + # Virtual layer should also have the updated grants + updated_virtual_grants = ctx.engine_adapter._get_current_grants_config( + exp.to_table(view_name, dialect=ctx.dialect) + ) + assert set(updated_virtual_grants.get(select_privilege, [])) == set( + expected_final_grants[select_privilege] + ) 
+ assert ( + updated_virtual_grants.get(insert_privilege, []) + == expected_final_grants[insert_privilege] + ) diff --git a/tests/core/engine_adapter/integration/test_integration_athena.py b/tests/core/engine_adapter/integration/test_integration_athena.py new file mode 100644 index 0000000000..1c0ece6d78 --- /dev/null +++ b/tests/core/engine_adapter/integration/test_integration_athena.py @@ -0,0 +1,547 @@ +import typing as t +import pytest +from pytest import FixtureRequest +import pandas as pd # noqa: TID253 +import datetime +from sqlmesh.core.engine_adapter import AthenaEngineAdapter +from sqlmesh.utils.aws import parse_s3_uri +from sqlmesh.utils.pandas import columns_to_types_from_df +from sqlmesh.utils.date import to_ds, to_ts, TimeLike +import dataclasses +from tests.core.engine_adapter.integration import ( + TestContext, + generate_pytest_params, + ENGINES_BY_NAME, + IntegrationTestEngine, +) +from sqlglot import exp + +# The tests in this file dont need to be called twice, so we create a single instance of Athena +ENGINE_ATHENA = dataclasses.replace(ENGINES_BY_NAME["athena"], catalog_types=None) +assert isinstance(ENGINE_ATHENA, IntegrationTestEngine) + + +@pytest.fixture(params=list(generate_pytest_params(ENGINE_ATHENA))) +def ctx( + request: FixtureRequest, + create_test_context: t.Callable[[IntegrationTestEngine, str, str], t.Iterable[TestContext]], +) -> t.Iterable[TestContext]: + yield from create_test_context(*request.param) + + +@pytest.fixture +def engine_adapter(ctx: TestContext) -> AthenaEngineAdapter: + assert isinstance(ctx.engine_adapter, AthenaEngineAdapter) + return ctx.engine_adapter + + +@pytest.fixture +def s3(engine_adapter: AthenaEngineAdapter) -> t.Any: + return engine_adapter._s3_client + + +def s3_list_objects(s3: t.Any, location: str, **list_objects_kwargs: t.Any) -> t.List[str]: + bucket, prefix = parse_s3_uri(location) + lst = [] + for page in s3.get_paginator("list_objects_v2").paginate(Bucket=bucket, Prefix=prefix): + 
lst.extend([o["Key"] for o in page.get("Contents", [])]) + return lst + + +def test_clear_partition_data(ctx: TestContext, engine_adapter: AthenaEngineAdapter, s3: t.Any): + base_uri = engine_adapter.s3_warehouse_location_or_raise + assert len(s3_list_objects(s3, base_uri)) == 0 + + src_table = ctx.table("src_table") + test_table = ctx.table("test_table") + + base_data = pd.DataFrame( + [ + {"id": 1, "ts": datetime.datetime(2023, 1, 1, 12, 13, 14)}, + {"id": 2, "ts": datetime.datetime(2023, 1, 2, 8, 10, 0)}, + {"id": 3, "ts": datetime.datetime(2023, 1, 3, 16, 5, 14)}, + ] + ) + + engine_adapter.ctas( + table_name=src_table, + query_or_df=base_data, + ) + + sqlmesh_context, model = ctx.upsert_sql_model( + f""" + MODEL ( + name {test_table}, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds + ), + start '2023-01-01' + ); + + SELECT + id, ts, (ts::date)::varchar as ds + FROM {src_table} + WHERE ts BETWEEN @start_dt AND @end_dt + """ + ) + + plan = sqlmesh_context.plan(no_prompts=True, auto_apply=True) + assert len(plan.snapshots) == 1 + test_table_snapshot = list(plan.snapshots.values())[0] + + files_before = s3_list_objects(s3, base_uri) + assert len(files_before) > 0 + + # src_table should have no partitions + with pytest.raises(Exception, match=r".*TABLE_NOT_FOUND.*\$partitions"): + engine_adapter._list_partitions(src_table) + + # test_table physical snapshot table should have 3 partitions + test_table_physical_name = exp.to_table(test_table_snapshot.table_name()) + partitions = engine_adapter._list_partitions(test_table_physical_name, where=None) + assert len(partitions) == 3 + assert [p[0] for p in partitions] == [["2023-01-01"], ["2023-01-02"], ["2023-01-03"]] + + assert engine_adapter.fetchone(f"select count(*) from {test_table}")[0] == 3 # type: ignore + + # clear a partition + assert model.time_column + engine_adapter._clear_partition_data( + table=test_table_physical_name, + where=exp.Between( + this=model.time_column.column, + 
low=exp.Literal.string("2023-01-01"), + high=exp.Literal.string("2023-01-01"), + ), + ) + partitions = engine_adapter._list_partitions(test_table_physical_name, where=None) + assert len(partitions) == 2 + assert [p[0] for p in partitions] == [["2023-01-02"], ["2023-01-03"]] + + # test that only S3 data for that partition was affected + files_after = s3_list_objects(s3, base_uri) + assert len(files_after) == len(files_before) - 1 + assert len([f for f in files_before if "ds=2023-01-01" in f]) == 1 + assert len([f for f in files_after if "ds=2023-01-01" in f]) == 0 + + assert engine_adapter.fetchone(f"select count(*) from {test_table}")[0] == 2 # type: ignore + + +def test_clear_partition_data_multiple_columns( + ctx: TestContext, engine_adapter: AthenaEngineAdapter, s3: t.Any +): + base_uri = engine_adapter.s3_warehouse_location_or_raise + + src_table = ctx.table("src_table") + test_table = ctx.table("test_table") + + base_data = pd.DataFrame( + [ + {"id": 1, "ts": datetime.datetime(2023, 1, 1, 12, 13, 14), "system": "dev"}, + {"id": 2, "ts": datetime.datetime(2023, 1, 1, 8, 13, 14), "system": "prod"}, + {"id": 3, "ts": datetime.datetime(2023, 1, 2, 11, 10, 0), "system": "dev"}, + {"id": 4, "ts": datetime.datetime(2023, 1, 2, 8, 10, 0), "system": "dev"}, + {"id": 5, "ts": datetime.datetime(2023, 1, 3, 16, 5, 14), "system": "dev"}, + {"id": 6, "ts": datetime.datetime(2023, 1, 3, 16, 5, 14), "system": "prod"}, + ] + ) + + engine_adapter.ctas( + table_name=src_table, + query_or_df=base_data, + ) + + sqlmesh_context, model = ctx.upsert_sql_model( + f""" + MODEL ( + name {test_table}, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds + ), + partitioned_by (ds, system), + start '2023-01-01' + ); + + SELECT + id, ts, (ts::date)::varchar as ds, system + FROM {src_table} + WHERE ts BETWEEN @start_dt AND @end_dt + """ + ) + + plan = sqlmesh_context.plan(no_prompts=True, auto_apply=True) + assert len(plan.snapshots) == 1 + test_table_snapshot = 
list(plan.snapshots.values())[0] + test_table_physical_name = exp.to_table(test_table_snapshot.table_name()) + + partitions = engine_adapter._list_partitions(test_table_physical_name, where=None) + assert len(partitions) == 5 + assert [p[0] for p in partitions] == [ + ["2023-01-01", "dev"], + ["2023-01-01", "prod"], + ["2023-01-02", "dev"], + ["2023-01-03", "dev"], + ["2023-01-03", "prod"], + ] + + files_before = s3_list_objects(s3, base_uri) + assert len(files_before) > 0 + + assert engine_adapter.fetchone(f"select count(*) from {test_table}")[0] == 6 # type: ignore + + # this should clear 2 partitions, ["2023-01-01", "dev"] and ["2023-01-01", "prod"] + assert model.time_column + engine_adapter._clear_partition_data( + table=test_table_physical_name, + where=exp.Between( + this=model.time_column.column, + low=exp.Literal.string("2023-01-01"), + high=exp.Literal.string("2023-01-01"), + ), + ) + + partitions = engine_adapter._list_partitions(test_table_physical_name, where=None) + assert len(partitions) == 3 + assert [p[0] for p in partitions] == [ + ["2023-01-02", "dev"], + ["2023-01-03", "dev"], + ["2023-01-03", "prod"], + ] + + files_after = s3_list_objects(s3, base_uri) + assert len(files_after) == len(files_before) - 2 + + def _match_partition(location_list: t.List[str], match: str): + return any(match in location for location in location_list) + + assert _match_partition(files_before, "ds=2023-01-01/system=dev") + assert _match_partition(files_before, "ds=2023-01-01/system=prod") + assert not _match_partition(files_after, "ds=2023-01-01/system=dev") + assert not _match_partition(files_after, "ds=2023-01-01/system=prod") + + assert engine_adapter.fetchone(f"select count(*) from {test_table}")[0] == 4 # type: ignore + + +def test_hive_truncate_table(ctx: TestContext, engine_adapter: AthenaEngineAdapter, s3: t.Any): + base_uri = engine_adapter.s3_warehouse_location_or_raise + + table_1 = ctx.table("table_one") + table_2 = ctx.table("table_two") + + base_data = 
pd.DataFrame( + [ + {"id": 1, "ts": datetime.datetime(2023, 1, 1, 12, 13, 14), "system": "dev"}, + {"id": 2, "ts": datetime.datetime(2023, 1, 1, 8, 13, 14), "system": "prod"}, + {"id": 3, "ts": datetime.datetime(2023, 1, 2, 11, 10, 0), "system": "dev"}, + {"id": 4, "ts": datetime.datetime(2023, 1, 2, 8, 10, 0), "system": "dev"}, + {"id": 5, "ts": datetime.datetime(2023, 1, 3, 16, 5, 14), "system": "dev"}, + {"id": 6, "ts": datetime.datetime(2023, 1, 3, 16, 5, 14), "system": "prod"}, + ] + ) + + assert len(s3_list_objects(s3, base_uri)) == 0 + + engine_adapter.ctas(table_name=table_1, query_or_df=base_data) + + engine_adapter.ctas(table_name=table_2, query_or_df=base_data) + + all_files = s3_list_objects(s3, base_uri) + assert len(all_files) > 0 + + table_1_location = engine_adapter._query_table_s3_location(table_1) + table_2_location = engine_adapter._query_table_s3_location(table_2) + + table_1_files = s3_list_objects(s3, table_1_location) + table_2_files = s3_list_objects(s3, table_2_location) + + assert len(table_1_files) < len(all_files) + assert len(table_2_files) < len(all_files) + assert len(table_1_files) + len(table_2_files) == len(all_files) + + assert engine_adapter.fetchone(f"select count(*) from {table_1}")[0] == 6 # type: ignore + engine_adapter._truncate_table(table_1) + assert len(s3_list_objects(s3, table_1_location)) == 0 + assert len(s3_list_objects(s3, table_2_location)) == len(table_2_files) + + assert engine_adapter.fetchone(f"select count(*) from {table_1}")[0] == 0 # type: ignore + + # check truncating an empty table doesnt throw an error + engine_adapter._truncate_table(table_1) + + +def test_hive_drop_table_removes_data(ctx: TestContext, engine_adapter: AthenaEngineAdapter): + # check no exception with dropping a table that doesnt exist + engine_adapter.drop_table("nonexist") + + seed_table = ctx.table("seed") + + data = pd.DataFrame( + [ + {"id": 1, "name": "one"}, + ] + ) + + columns_to_types = columns_to_types_from_df(data) + + 
engine_adapter.create_table( + table_name=seed_table, target_columns_to_types=columns_to_types, exists=False + ) + engine_adapter.insert_append( + table_name=seed_table, query_or_df=data, target_columns_to_types=columns_to_types + ) + assert engine_adapter.fetchone(f"select count(*) from {seed_table}")[0] == 1 # type: ignore + + # By default, dropping a Hive table leaves its data in S3 so creating a new table with the same name / location picks up the old data + # This ensures that our drop table logic to delete the data from S3 is working + engine_adapter.drop_table(seed_table, exists=False) + engine_adapter.create_table( + table_name=seed_table, target_columns_to_types=columns_to_types, exists=False + ) + assert engine_adapter.fetchone(f"select count(*) from {seed_table}")[0] == 0 # type: ignore + + +def test_hive_replace_query_same_schema(ctx: TestContext, engine_adapter: AthenaEngineAdapter): + seed_table = ctx.table("seed") + + data = pd.DataFrame( + [ + {"id": 1, "name": "one"}, + {"id": 2, "name": "two"}, + ] + ) + + assert not engine_adapter.table_exists(seed_table) + + engine_adapter.replace_query(table_name=seed_table, query_or_df=data) + + assert engine_adapter.fetchone(f"select count(*) from {seed_table}")[0] == 2 # type: ignore + + data.loc[len(data)] = [3, "three"] # type: ignore + + engine_adapter.replace_query(table_name=seed_table, query_or_df=data) + + assert engine_adapter.fetchone(f"select count(*) from {seed_table}")[0] == 3 # type: ignore + + +def test_hive_replace_query_new_schema(ctx: TestContext, engine_adapter: AthenaEngineAdapter): + seed_table = ctx.table("seed") + + orig_data = pd.DataFrame( + [ + {"id": 1, "name": "one"}, + {"id": 2, "name": "two"}, + ] + ) + + new_data = pd.DataFrame( + [ + {"foo": 1, "bar": "one", "ts": datetime.datetime(2023, 1, 1)}, + ] + ) + + engine_adapter.replace_query(table_name=seed_table, query_or_df=orig_data) + + assert engine_adapter.fetchall(f"select id, name from {seed_table} order by id") == [ + (1, 
"one"), + (2, "two"), + ] + + engine_adapter.replace_query(table_name=seed_table, query_or_df=new_data) + + with pytest.raises(Exception, match=r".*COLUMN_NOT_FOUND.*"): + assert engine_adapter.fetchall(f"select id, name from {seed_table}") + + assert engine_adapter.fetchone(f"select foo, bar, ts from {seed_table}") == ( + 1, + "one", + datetime.datetime(2023, 1, 1), + ) + + +def test_insert_overwrite_by_time_partition_date_type( + ctx: TestContext, engine_adapter: AthenaEngineAdapter +): + table = ctx.table("test_table") + + data = pd.DataFrame( + [ + {"id": 1, "date": datetime.date(2023, 1, 1)}, + {"id": 2, "date": datetime.date(2023, 1, 2)}, + {"id": 3, "date": datetime.date(2023, 1, 3)}, + ] + ) + + columns_to_types = { + "id": exp.DataType.build("int"), + "date": exp.DataType.build( + "date" + ), # note: columns_to_types_from_df() would infer this as TEXT but we need a DATE type + } + + def time_formatter(time: TimeLike, _: t.Optional[t.Dict[str, exp.DataType]]) -> exp.Expression: + return exp.cast(exp.Literal.string(to_ds(time)), "date") + + engine_adapter.create_table( + table_name=table, + target_columns_to_types=columns_to_types, + partitioned_by=[exp.to_column("date")], + ) + engine_adapter.insert_overwrite_by_time_partition( + table_name=table, + query_or_df=data, + target_columns_to_types=columns_to_types, + time_column=exp.to_identifier("date"), + start="2023-01-01", + end="2023-01-03", + time_formatter=time_formatter, + ) + + assert len(engine_adapter.fetchdf(exp.select("*").from_(table))) == 3 + + new_data = pd.DataFrame( + [ + {"id": 4, "date": datetime.date(2023, 1, 3)}, # replaces the old entry for 2023-01-03 + {"id": 5, "date": datetime.date(2023, 1, 4)}, + ] + ) + + engine_adapter.insert_overwrite_by_time_partition( + table_name=table, + query_or_df=new_data, + target_columns_to_types=columns_to_types, + time_column=exp.to_identifier("date"), + start="2023-01-03", + end="2023-01-04", + time_formatter=time_formatter, + ) + + result = 
engine_adapter.fetchdf(exp.select("*").from_(table)) + assert len(result) == 4 + assert sorted(result["id"].tolist()) == [1, 2, 4, 5] + + +def test_insert_overwrite_by_time_partition_datetime_type( + ctx: TestContext, engine_adapter: AthenaEngineAdapter +): + table = ctx.table("test_table") + + data = pd.DataFrame( + [ + {"id": 1, "ts": datetime.datetime(2023, 1, 1, 1, 0, 0)}, + {"id": 2, "ts": datetime.datetime(2023, 1, 1, 2, 0, 0)}, + {"id": 3, "ts": datetime.datetime(2023, 1, 1, 3, 0, 0)}, + ] + ) + + columns_to_types = { + "id": exp.DataType.build("int"), + "ts": exp.DataType.build( + "datetime" + ), # note: columns_to_types_from_df() would infer this as TEXT but we need a DATETIME type + } + + def time_formatter(time: TimeLike, _: t.Optional[t.Dict[str, exp.DataType]]) -> exp.Expression: + return exp.cast(exp.Literal.string(to_ts(time)), "datetime") + + engine_adapter.create_table( + table_name=table, + target_columns_to_types=columns_to_types, + partitioned_by=[exp.to_column("ts")], + ) + engine_adapter.insert_overwrite_by_time_partition( + table_name=table, + query_or_df=data, + target_columns_to_types=columns_to_types, + time_column=exp.to_identifier("ts"), + start="2023-01-01 00:00:00", + end="2023-01-01 04:00:00", + time_formatter=time_formatter, + ) + + assert len(engine_adapter.fetchdf(exp.select("*").from_(table))) == 3 + + new_data = pd.DataFrame( + [ + { + "id": 4, + "ts": datetime.datetime(2023, 1, 1, 3, 0, 0), + }, # replaces the old entry for 2023-01-01 03:00:00 + {"id": 5, "ts": datetime.datetime(2023, 1, 1, 4, 0, 0)}, + ] + ) + + engine_adapter.insert_overwrite_by_time_partition( + table_name=table, + query_or_df=new_data, + target_columns_to_types=columns_to_types, + time_column=exp.to_identifier("ts"), + start="2023-01-01 03:00:00", + end="2023-01-01 05:00:00", + time_formatter=time_formatter, + ) + + result = engine_adapter.fetchdf(exp.select("*").from_(table)) + assert len(result) == 4 + assert sorted(result["id"].tolist()) == [1, 2, 4, 5] + 
+ +def test_scd_type_2_iceberg_timestamps( + ctx: TestContext, engine_adapter: AthenaEngineAdapter +) -> None: + src_table = ctx.table("src_table") + scd_model_table = ctx.table("scd_model") + + base_data = pd.DataFrame( + [ + {"id": 1, "ts": datetime.datetime(2023, 1, 1, 12, 13, 14)}, + {"id": 2, "ts": datetime.datetime(2023, 1, 1, 8, 13, 14)}, + {"id": 3, "ts": datetime.datetime(2023, 1, 2, 11, 10, 0)}, + {"id": 4, "ts": datetime.datetime(2023, 1, 2, 8, 10, 0)}, + {"id": 5, "ts": datetime.datetime(2023, 1, 3, 16, 5, 14)}, + {"id": 6, "ts": datetime.datetime(2023, 1, 3, 16, 5, 14)}, + ] + ) + + engine_adapter.ctas( + table_name=src_table, + query_or_df=base_data, + ) + + sqlmesh_context, model = ctx.upsert_sql_model( + f""" + MODEL ( + name {scd_model_table}, + kind SCD_TYPE_2_BY_TIME ( + unique_key id, + updated_at_name ts, + time_data_type timestamp(6) + ), + start '2020-01-01', + cron '@daily', + table_format iceberg + ); + + SELECT + id, ts::timestamp(6) as ts + FROM {src_table}; + """ + ) + + assert model.table_format == "iceberg" + + # throws if the temp tables created by the SCD Type 2 strategy are Hive tables instead of Iceberg + # because the Iceberg timestamp(6) type isnt supported in Hive + plan = sqlmesh_context.plan(auto_apply=True) + + assert len(plan.snapshots) == 1 + test_table_snapshot = list(plan.snapshots.values())[0] + test_table_physical_name = exp.to_table(test_table_snapshot.table_name()) + + assert engine_adapter._query_table_type(test_table_physical_name) == "iceberg" + timestamp_columns = [ + v + for k, v in engine_adapter.columns(test_table_physical_name).items() + if k in {"ts", "valid_from", "valid_to"} + ] + assert len(timestamp_columns) == 3 + assert all([v.sql(dialect="athena").lower() == "timestamp(6)" for v in timestamp_columns]) diff --git a/tests/core/engine_adapter/integration/test_integration_bigquery.py b/tests/core/engine_adapter/integration/test_integration_bigquery.py new file mode 100644 index 0000000000..0a6dd6b2a4 --- 
/dev/null +++ b/tests/core/engine_adapter/integration/test_integration_bigquery.py @@ -0,0 +1,471 @@ +import typing as t +import pytest +from pathlib import Path +from sqlglot import exp +from sqlglot.optimizer.qualify_columns import quote_identifiers +from sqlglot.helper import seq_get +from sqlmesh.cli.project_init import ProjectTemplate, init_example_project +from sqlmesh.core.config import Config +from sqlmesh.core.engine_adapter import BigQueryEngineAdapter +from sqlmesh.core.engine_adapter.mixins import ( + TableAlterDropClusterKeyOperation, + TableAlterChangeClusterKeyOperation, +) +from sqlmesh.core.engine_adapter.shared import DataObject +import sqlmesh.core.dialect as d +from sqlmesh.core.model import SqlModel, load_sql_based_model +from sqlmesh.core.plan import Plan, BuiltInPlanEvaluator +from sqlmesh.core.table_diff import TableDiff +from sqlmesh.utils import CorrelationId +from tests.core.engine_adapter.integration import TestContext +from pytest import FixtureRequest +from tests.core.engine_adapter.integration import ( + TestContext, + generate_pytest_params, + ENGINES_BY_NAME, + IntegrationTestEngine, +) + + +@pytest.fixture(params=list(generate_pytest_params(ENGINES_BY_NAME["bigquery"]))) +def ctx( + request: FixtureRequest, + create_test_context: t.Callable[[IntegrationTestEngine, str, str], t.Iterable[TestContext]], +) -> t.Iterable[TestContext]: + yield from create_test_context(*request.param) + + +@pytest.fixture +def engine_adapter(ctx: TestContext) -> BigQueryEngineAdapter: + assert isinstance(ctx.engine_adapter, BigQueryEngineAdapter) + return ctx.engine_adapter + + +def test_get_alter_expressions_includes_clustering( + ctx: TestContext, engine_adapter: BigQueryEngineAdapter +): + clustered_table = ctx.table("clustered_table") + clustered_differently_table = ctx.table("clustered_differently_table") + normal_table = ctx.table("normal_table") + + engine_adapter.execute( + f"CREATE TABLE {clustered_table.sql(dialect=ctx.dialect)} (c1 int, c2 
timestamp) CLUSTER BY c1" + ) + engine_adapter.execute( + f"CREATE TABLE {clustered_differently_table.sql(dialect=ctx.dialect)} (c1 int, c2 timestamp) CLUSTER BY c1, c2" + ) + engine_adapter.execute( + f"CREATE TABLE {normal_table.sql(dialect=ctx.dialect)} (c1 int, c2 timestamp)" + ) + + metadata = engine_adapter.get_data_objects( + normal_table.db, {clustered_table.name, clustered_differently_table.name, normal_table.name} + ) + clustered_table_metadata = next(md for md in metadata if md.name == clustered_table.name) + clustered_differently_table_metadata = next( + md for md in metadata if md.name == clustered_differently_table.name + ) + normal_table_metadata = next(md for md in metadata if md.name == normal_table.name) + + assert clustered_table_metadata.clustering_key == "(c1)" + assert clustered_differently_table_metadata.clustering_key == "(c1,c2)" + assert normal_table_metadata.clustering_key is None + + assert len(engine_adapter.get_alter_operations(normal_table, normal_table)) == 0 + assert len(engine_adapter.get_alter_operations(clustered_table, clustered_table)) == 0 + + # alter table drop clustered + clustered_to_normal = engine_adapter.get_alter_operations(clustered_table, normal_table) + assert len(clustered_to_normal) == 1 + assert isinstance(clustered_to_normal[0], TableAlterDropClusterKeyOperation) + assert clustered_to_normal[0].target_table == clustered_table + assert not hasattr(clustered_to_normal[0], "clustering_key") + + # alter table add clustered + normal_to_clustered = engine_adapter.get_alter_operations(normal_table, clustered_table) + assert len(normal_to_clustered) == 1 + operation = normal_to_clustered[0] + assert isinstance(operation, TableAlterChangeClusterKeyOperation) + assert operation.target_table == normal_table + assert operation.clustering_key == "(c1)" + + # alter table change clustering (c1 -> (c1, c2)) + clustered_to_clustered_differently = engine_adapter.get_alter_operations( + clustered_table, clustered_differently_table 
+ ) + assert len(clustered_to_clustered_differently) == 1 + operation = clustered_to_clustered_differently[0] + assert isinstance(operation, TableAlterChangeClusterKeyOperation) + assert operation.target_table == clustered_table + assert operation.clustering_key == "(c1,c2)" + + # alter table change clustering ((c1, c2) -> c1) + clustered_differently_to_clustered = engine_adapter.get_alter_operations( + clustered_differently_table, clustered_table + ) + assert len(clustered_differently_to_clustered) == 1 + operation = clustered_differently_to_clustered[0] + assert isinstance(operation, TableAlterChangeClusterKeyOperation) + assert operation.target_table == clustered_differently_table + assert operation.clustering_key == "(c1)" + + +def test_mutating_clustered_by_forward_only( + ctx: TestContext, engine_adapter: BigQueryEngineAdapter +): + model_name = ctx.table("TEST") + + sqlmesh = ctx.create_context() + + def _create_model(**kwargs: t.Any) -> SqlModel: + extra_props = "\n".join([f"{k} {v}," for k, v in kwargs.items()]) + return t.cast( + SqlModel, + load_sql_based_model( + d.parse( + f""" + MODEL ( + name {model_name.sql(dialect="bigquery")}, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column partitiondate + ), + {extra_props} + start '2021-01-01', + cron '@daily', + dialect 'bigquery' + ); + + select 1 as ID, current_date() as partitiondate + """ + ) + ), + ) + + def _get_data_object(table: exp.Table) -> DataObject: + data_object = seq_get(engine_adapter.get_data_objects(table.db, {table.name}), 0) + if not data_object: + raise ValueError(f"Expected metadata for {table}") + return data_object + + m1 = _create_model() + m2 = _create_model(clustered_by="partitiondate") + m3 = _create_model(clustered_by="(id, partitiondate)") + + # Initial plan - non-clustered table + sqlmesh.upsert_model(m1) + plan_1: Plan = sqlmesh.plan(auto_apply=True, no_prompts=True) + assert len(plan_1.snapshots) == 1 + target_table_1 = 
exp.to_table(list(plan_1.snapshots.values())[0].table_name()) + quote_identifiers(target_table_1) + + assert not _get_data_object(target_table_1).is_clustered + + # Next plan - add clustering key (non-clustered -> clustered) + sqlmesh.upsert_model(m2) + plan_2: Plan = sqlmesh.plan(auto_apply=True, no_prompts=True, forward_only=True) + assert len(plan_2.snapshots) == 1 + target_table_2 = exp.to_table(list(plan_2.snapshots.values())[0].table_name()) + quote_identifiers(target_table_2) + + assert target_table_1 == target_table_2 + + metadata = _get_data_object(target_table_1) + assert metadata.is_clustered + assert metadata.clustering_key == "(partitiondate)" + + # Next plan - change clustering key (clustered -> clustered differently) + sqlmesh.upsert_model(m3) + plan_3: Plan = sqlmesh.plan(auto_apply=True, no_prompts=True, forward_only=True) + assert len(plan_3.snapshots) == 1 + target_table_3 = exp.to_table(list(plan_3.snapshots.values())[0].table_name()) + quote_identifiers(target_table_3) + + assert target_table_1 == target_table_3 + + metadata = _get_data_object(target_table_1) + assert metadata.is_clustered + assert metadata.clustering_key == "(id,partitiondate)" + + # Next plan - drop clustering key + sqlmesh.upsert_model(m1) + plan_4: Plan = sqlmesh.plan(auto_apply=True, no_prompts=True, forward_only=True) + assert len(plan_4.snapshots) == 1 + target_table_4 = exp.to_table(list(plan_4.snapshots.values())[0].table_name()) + quote_identifiers(target_table_4) + + assert target_table_1 == target_table_4 + + metadata = _get_data_object(target_table_1) + assert not metadata.is_clustered + + +def test_information_schema_view_external_model(ctx: TestContext, tmp_path: Path): + # Information schema views are represented as: + # + # Table( + # this=Identifier(INFORMATION_SCHEMA.SOME_VIEW, quoted=True), + # db=Identifier(some_schema), + # catalog=Identifier(some_catalog)) + # + # This representation is produced by BigQuery's parser, so that the mapping schema + # nesting 
depth is consistent with other table references in a project, which will + # usually look like `project.dataset.table`. + information_schema_tables = ctx.table("INFORMATION_SCHEMA.TABLES") + assert len(information_schema_tables.parts) == 3 + + model_name = ctx.table("test") + dependency = f"`{'.'.join(part.name for part in information_schema_tables.parts)}`" + + init_example_project(tmp_path, engine_type="bigquery", template=ProjectTemplate.EMPTY) + with open(tmp_path / "models" / "test.sql", "w", encoding="utf-8") as f: + f.write( + f""" + MODEL ( + name {model_name.sql("bigquery")}, + kind FULL, + dialect 'bigquery' + ); + + SELECT * FROM {dependency} AS tables + """ + ) + + def _mutate_config(_: str, config: Config) -> None: + config.model_defaults.dialect = "bigquery" + + sqlmesh = ctx.create_context(_mutate_config, path=tmp_path) + sqlmesh.create_external_models() + sqlmesh.load() + + actual_columns_to_types = sqlmesh.get_model(information_schema_tables.sql()).columns_to_types + expected_columns_to_types = { + "table_catalog": exp.DataType.build("TEXT"), + "table_schema": exp.DataType.build("TEXT"), + "table_name": exp.DataType.build("TEXT"), + "table_type": exp.DataType.build("TEXT"), + } + + assert actual_columns_to_types is not None + assert actual_columns_to_types.items() >= expected_columns_to_types.items() + + rendered_query = sqlmesh.get_model(model_name.sql()).render_query() + assert isinstance(rendered_query, exp.Query) + assert not rendered_query.selects[0].is_star + + +def test_compare_nested_values_in_table_diff(ctx: TestContext): + src_table = ctx.table("source") + target_table = ctx.table("target") + + query: exp.Query = exp.maybe_parse( + """ + SELECT + 1 AS id, + STRUCT( + 'Main String' AS top_level_string, + [1, 2, 3, 4] AS top_level_array, + STRUCT( + 'Nested String' AS nested_string, + [STRUCT( + 'Inner Struct String 1' AS inner_string, + [10, 20, 30] AS inner_array + ), STRUCT( + 'Inner Struct String 2' AS inner_string, + [40, 50, 60] AS 
inner_array + )] AS nested_array_of_structs + ) AS nested_struct, + [STRUCT( + 'Array Struct String 1' AS array_struct_string, + STRUCT( + 'Deeper Nested String' AS deeper_nested_string, + [100, 200] AS deeper_nested_array + ) AS deeper_nested_struct + ), STRUCT( + 'Array Struct String 2' AS array_struct_string, + STRUCT( + 'Another Nested String' AS deeper_nested_string, + [300, 400] AS deeper_nested_array + ) AS deeper_nested_struct + )] AS array_of_structs_with_nested_structs, + ARRAY( + SELECT STRUCT( + CONCAT('Dynamic String ', CAST(num AS STRING)) AS dynamic_string, + ARRAY( + SELECT CAST(num * multiplier AS INT64) + FROM UNNEST([1, 2, 3]) AS multiplier + ) AS dynamic_array + ) + FROM UNNEST([1, 2, 3]) AS num + ) AS dynamically_generated_struct_array + ) AS nested_value + """, + dialect="bigquery", + ) + + ctx.engine_adapter.ctas(src_table, query) + ctx.engine_adapter.ctas(target_table, query) + + table_diff = TableDiff( + adapter=ctx.engine_adapter, + source=exp.table_name(src_table), + target=exp.table_name(target_table), + on=["id"], + ) + row_diff = table_diff.row_diff() + + assert row_diff.stats["join_count"] == 1 + assert row_diff.full_match_count == 1 + + ctx.engine_adapter.drop_table(src_table) + ctx.engine_adapter.drop_table(target_table) + + query1: exp.Query = exp.maybe_parse( + "SELECT 0 as id, [STRUCT(0 as struct_id, 'value1' as struct_value), STRUCT(1 as struct_id, 'value2' as struct_value)] as struct_array", + dialect="bigquery", + ) + query2: exp.Query = exp.maybe_parse( + "SELECT 0 as id, [STRUCT(0 as struct_id, 'value2' as struct_value), STRUCT(1 as struct_id, 'value1' as struct_value)] as struct_array", + dialect="bigquery", + ) + + ctx.engine_adapter.ctas(src_table, query1) + ctx.engine_adapter.ctas(target_table, query2) + + table_diff = TableDiff( + adapter=ctx.engine_adapter, + source=exp.table_name(src_table), + target=exp.table_name(target_table), + on=["id"], + ) + row_diff = table_diff.row_diff() + + assert 
row_diff.stats["join_count"] == 1 + assert row_diff.full_match_count == 0 + + ctx.engine_adapter.drop_table(src_table) + ctx.engine_adapter.drop_table(target_table) + + +def test_get_bq_schema(ctx: TestContext, engine_adapter: BigQueryEngineAdapter): + from google.cloud.bigquery import SchemaField + + table = ctx.table("test") + + engine_adapter.execute(f""" + CREATE TABLE {table.sql(dialect=ctx.dialect)} ( + id STRING NOT NULL, + user_data STRUCT, + tags ARRAY, + score NUMERIC, + created_at DATETIME + ) + """) + + bg_schema = engine_adapter.get_bq_schema(table) + assert len(bg_schema) == 5 + assert bg_schema[0] == SchemaField(name="id", field_type="STRING", mode="REQUIRED") + assert bg_schema[1] == SchemaField( + name="user_data", + field_type="RECORD", + mode="NULLABLE", + fields=[ + SchemaField(name="id", field_type="STRING", mode="REQUIRED"), + SchemaField(name="name", field_type="STRING", mode="REQUIRED"), + SchemaField(name="address", field_type="STRING", mode="NULLABLE"), + ], + ) + assert bg_schema[2] == SchemaField(name="tags", field_type="STRING", mode="REPEATED") + assert bg_schema[3] == SchemaField(name="score", field_type="NUMERIC", mode="NULLABLE") + assert bg_schema[4] == SchemaField(name="created_at", field_type="DATETIME", mode="NULLABLE") + + +def test_column_types(ctx: TestContext): + model_name = ctx.table("test") + sqlmesh = ctx.create_context() + + sqlmesh.upsert_model( + load_sql_based_model( + d.parse( + f""" + MODEL ( + name {model_name}, + ); + + SELECT + RANGE('01-01-1900'::DATE, '01-01-1902'::DATE) AS col1, + JSON '{{"id": 10}}' AS col2, + STRUCT([PARSE_JSON('{{"id": 10}}')] AS arr) AS col3; + """ + ) + ) + ) + + sqlmesh.plan(auto_apply=True, no_prompts=True) + + columns = sqlmesh.engine_adapter.columns(model_name) + + assert columns["col1"].is_type("RANGE") + assert columns["col2"].is_type("JSON") + + col3 = columns["col3"] + coldef = col3.find(exp.ColumnDef) + assert col3.is_type("STRUCT") + assert coldef and coldef.kind and 
coldef.kind.is_type("ARRAY") + + +def test_table_diff_table_name_matches_column_name(ctx: TestContext): + src_table = ctx.table("source") + target_table = ctx.table("target") + + # BigQuery has a quirk where if you do SELECT foo FROM project-id.schema.foo, the projection is + # interpreted as a struct column, reflecting the scanned table's schema, even if the table has + # a column with the same name (foo). + # + # This is a problem, because we compare the columns of the source and target tables using the + # equality operator (=), which is not defined for struct values in BigQuery, leading to an error. + query: exp.Query = exp.maybe_parse("SELECT 1 AS s, 2 AS source, 3 AS target") + + ctx.engine_adapter.ctas(src_table, query) + ctx.engine_adapter.ctas(target_table, query) + + table_diff = TableDiff( + adapter=ctx.engine_adapter, + source=exp.table_name(src_table), + target=exp.table_name(target_table), + on=["s"], + ) + + row_diff = table_diff.row_diff() + + assert row_diff.stats["join_count"] == 1 + assert row_diff.full_match_count == 1 + + +def test_correlation_id_in_job_labels(ctx: TestContext): + model_name = ctx.table("test") + + sqlmesh = ctx.create_context() + sqlmesh.upsert_model( + load_sql_based_model(d.parse(f"MODEL (name {model_name}, kind FULL); SELECT 1 AS col")) + ) + + # Create a plan evaluator and a plan to evaluate + plan_evaluator = BuiltInPlanEvaluator( + sqlmesh.state_sync, + sqlmesh.snapshot_evaluator, + sqlmesh.create_scheduler, + sqlmesh.default_catalog, + ) + plan: Plan = sqlmesh.plan_builder("prod", skip_tests=True).build() + + # Evaluate the plan and retrieve the plan evaluator's adapter + plan_evaluator.evaluate(plan.to_evaluatable()) + adapter = t.cast(BigQueryEngineAdapter, plan_evaluator.snapshot_evaluator.adapter) + + # Case 1: Ensure that the correlation id is set in the underlying adapter + assert adapter.correlation_id is not None + + # Case 2: Ensure that the correlation id is set in the job labels + labels = 
adapter._job_params.get("labels") + correlation_id = CorrelationId.from_plan_id(plan.plan_id) + assert labels == {correlation_id.job_type.value.lower(): correlation_id.job_id} diff --git a/tests/core/engine_adapter/integration/test_integration_clickhouse.py b/tests/core/engine_adapter/integration/test_integration_clickhouse.py new file mode 100644 index 0000000000..f09360c673 --- /dev/null +++ b/tests/core/engine_adapter/integration/test_integration_clickhouse.py @@ -0,0 +1,562 @@ +import typing as t +import pytest +from pytest import FixtureRequest +from tests.core.engine_adapter.integration import TestContext +from sqlmesh.core.engine_adapter.clickhouse import ClickhouseEngineAdapter +import pandas as pd # noqa: TID253 +from sqlglot import exp, parse_one +from sqlmesh.core.snapshot import SnapshotChangeCategory + +from tests.core.engine_adapter.integration import ( + TestContext, + generate_pytest_params, + ENGINES_BY_NAME, + IntegrationTestEngine, +) + + +@pytest.fixture( + params=list( + generate_pytest_params([ENGINES_BY_NAME["clickhouse"], ENGINES_BY_NAME["clickhouse_cloud"]]) + ) +) +def ctx( + request: FixtureRequest, + create_test_context: t.Callable[[IntegrationTestEngine, str, str], t.Iterable[TestContext]], +) -> t.Iterable[TestContext]: + yield from create_test_context(*request.param) + + +@pytest.fixture +def engine_adapter(ctx: TestContext) -> ClickhouseEngineAdapter: + assert isinstance(ctx.engine_adapter, ClickhouseEngineAdapter) + return ctx.engine_adapter + + +def _get_source_queries_and_columns_to_types( + ctx, + insert_table: exp.Table, + target_table: exp.Table, + columns_to_types: t.Dict[str, exp.DataType] = { + "id": exp.DataType.build("Int8", dialect="clickhouse"), + "ds": exp.DataType.build("Date", dialect="clickhouse"), + }, +): + return ctx.engine_adapter._get_source_queries_and_columns_to_types( + parse_one(f"SELECT * FROM {insert_table.sql()}"), # type: ignore + columns_to_types, + target_table=target_table.sql(), + ) + + +def 
_create_table_and_insert_existing_data( + ctx: TestContext, + existing_data: pd.DataFrame = pd.DataFrame( + [ + {"id": 1, "ds": "2024-01-01"}, + {"id": 2, "ds": "2024-02-01"}, + {"id": 3, "ds": "2024-02-28"}, + {"id": 4, "ds": "2024-03-01"}, + ] + ), + columns_to_types: t.Dict[str, exp.DataType] = { + "id": exp.DataType.build("Int8", "clickhouse"), + "ds": exp.DataType.build("Date", "clickhouse"), + }, + table_name: str = "data_existing", + partitioned_by: t.Optional[t.List[exp.Expression]] = [ + parse_one("toMonth(ds)", dialect="clickhouse") + ], +) -> exp.Table: + existing_data = existing_data + existing_table_name: exp.Table = ctx.table(table_name) + ctx.engine_adapter.ctas( + existing_table_name.sql(), + ctx.input_data(existing_data, columns_to_types), + columns_to_types, + partitioned_by=partitioned_by, + ) + return existing_table_name + + +def test_insert_overwrite_by_condition_replace_partitioned(ctx: TestContext): + existing_table_name = _create_table_and_insert_existing_data(ctx) + + # new data to insert + insert_data = pd.DataFrame( + [ + {"id": 5, "ds": "2024-02-29"}, + {"id": 6, "ds": "2024-04-01"}, + ] + ) + insert_table_name = ctx.table("data_insert") + + with ctx.engine_adapter.temp_table( + ctx.input_data(insert_data), insert_table_name.sql() + ) as insert_table: + source_queries, columns_to_types = _get_source_queries_and_columns_to_types( + ctx, insert_table, existing_table_name + ) + + ctx.engine_adapter._insert_overwrite_by_condition( + existing_table_name.sql(), + source_queries, + columns_to_types, + ) + + ctx.compare_with_current( + existing_table_name, + pd.DataFrame( + [ + {"id": 5, "ds": pd.Timestamp("2024-02-29")}, + {"id": 6, "ds": pd.Timestamp("2024-04-01")}, + ] + ), + ) + + +def test_insert_overwrite_by_condition_replace(ctx: TestContext): + existing_table_name = _create_table_and_insert_existing_data(ctx, partitioned_by=None) + + # new data to insert + insert_data = pd.DataFrame( + [ + {"id": 5, "ds": "2024-02-29"}, + {"id": 6, "ds": 
"2024-04-01"}, + ] + ) + insert_table_name = ctx.table("data_insert") + + with ctx.engine_adapter.temp_table( + ctx.input_data(insert_data), insert_table_name.sql() + ) as insert_table: + source_queries, columns_to_types = _get_source_queries_and_columns_to_types( + ctx, insert_table, existing_table_name + ) + + ctx.engine_adapter._insert_overwrite_by_condition( + existing_table_name.sql(), + source_queries, + columns_to_types, + ) + + ctx.compare_with_current( + existing_table_name, + pd.DataFrame( + [ + {"id": 5, "ds": pd.Timestamp("2024-02-29")}, + {"id": 6, "ds": pd.Timestamp("2024-04-01")}, + ] + ), + ) + + +def test_insert_overwrite_by_condition_where_partitioned(ctx: TestContext): + # `where` time window + start_date = "2024-02-15" + end_date = "2024-04-30" + + # data currently in target table + existing_table_name = _create_table_and_insert_existing_data(ctx) + + # new data to insert + insert_data = pd.DataFrame( + [ + {"id": 5, "ds": "2024-02-29"}, + {"id": 6, "ds": "2024-04-01"}, + ] + ) + insert_table_name = ctx.table("data_insert") + + with ctx.engine_adapter.temp_table( + ctx.input_data(insert_data), insert_table_name.sql() + ) as insert_table: + source_queries, columns_to_types = _get_source_queries_and_columns_to_types( + ctx, insert_table, existing_table_name + ) + ctx.engine_adapter._insert_overwrite_by_condition( + existing_table_name.sql(), + source_queries, + columns_to_types, + exp.Between( + this=exp.column("ds"), + low=parse_one(f"'{start_date}'"), + high=parse_one(f"'{end_date}'"), + ), + ) + + ctx.compare_with_current( + existing_table_name, + pd.DataFrame( + [ + {"id": 1, "ds": pd.Timestamp("2024-01-01")}, # retained + {"id": 2, "ds": pd.Timestamp("2024-02-01")}, # retained + {"id": 5, "ds": pd.Timestamp("2024-02-29")}, # inserted + {"id": 6, "ds": pd.Timestamp("2024-04-01")}, # inserted + ] + ), + ) + + +def test_insert_overwrite_by_condition_where_compound_partitioned(ctx: TestContext): + # `where` time window + start_date = "2024-02-15" 
+ end_date = "2024-04-30" + + compound_columns_to_types = { + "id": exp.DataType.build("Int8", ctx.dialect), + "ds": exp.DataType.build("Date", ctx.dialect), + "city": exp.DataType.build("String", ctx.dialect), + } + + # data currently in target table + existing_table_name = _create_table_and_insert_existing_data( + ctx, + existing_data=pd.DataFrame( + [ + {"id": 1, "ds": "2024-01-01", "city": "1"}, + {"id": 2, "ds": "2024-01-02", "city": "2"}, + {"id": 3, "ds": "2024-02-01", "city": "1"}, + {"id": 4, "ds": "2024-02-02", "city": "2"}, + {"id": 5, "ds": "2024-02-27", "city": "1"}, + {"id": 6, "ds": "2024-02-28", "city": "2"}, + {"id": 7, "ds": "2024-03-01", "city": "1"}, + {"id": 8, "ds": "2024-03-02", "city": "2"}, + ] + ), + columns_to_types=compound_columns_to_types, + partitioned_by=[parse_one("toMonth(ds)", dialect=ctx.dialect), exp.column("city")], + ) + + # new data to insert + insert_data = pd.DataFrame( + [ + {"id": 9, "ds": "2024-02-26", "city": "1"}, + {"id": 10, "ds": "2024-02-29", "city": "2"}, + {"id": 11, "ds": "2024-04-01", "city": "1"}, + {"id": 12, "ds": "2024-04-02", "city": "2"}, + ] + ) + insert_table_name = ctx.table("data_insert") + + with ctx.engine_adapter.temp_table( + ctx.input_data(insert_data, compound_columns_to_types), insert_table_name.sql() + ) as insert_table: + source_queries, columns_to_types = _get_source_queries_and_columns_to_types( + ctx, + insert_table, + existing_table_name, + compound_columns_to_types, + ) + ctx.engine_adapter._insert_overwrite_by_condition( + existing_table_name.sql(), + source_queries, + columns_to_types, + exp.Between( + this=exp.column("ds"), + low=parse_one(f"'{start_date}'"), + high=parse_one(f"'{end_date}'"), + ), + ) + + ctx.compare_with_current( + existing_table_name, + pd.DataFrame( + [ + {"id": 1, "ds": pd.Timestamp("2024-01-01"), "city": "1"}, + {"id": 2, "ds": pd.Timestamp("2024-01-02"), "city": "2"}, + {"id": 3, "ds": pd.Timestamp("2024-02-01"), "city": "1"}, + {"id": 4, "ds": 
pd.Timestamp("2024-02-02"), "city": "2"}, + {"id": 9, "ds": pd.Timestamp("2024-02-26"), "city": "1"}, + {"id": 10, "ds": pd.Timestamp("2024-02-29"), "city": "2"}, + {"id": 11, "ds": pd.Timestamp("2024-04-01"), "city": "1"}, + {"id": 12, "ds": pd.Timestamp("2024-04-02"), "city": "2"}, + ] + ), + ) + + +def test_insert_overwrite_by_condition_by_key(ctx: TestContext): + # key parameters + key = [exp.column("id")] + key_exp = key[0] + + # data currently in target table + existing_table_name = _create_table_and_insert_existing_data(ctx, partitioned_by=None) + + # new data to insert + insert_data = pd.DataFrame( + [ + {"id": 4, "ds": "2024-05-01"}, # will overwrite existing record + {"id": 4, "ds": "2024-05-02"}, # only inserted if unique_key = False + {"id": 5, "ds": "2024-02-29"}, + {"id": 6, "ds": "2024-04-01"}, + ] + ) + + # unique_key = True + with ctx.engine_adapter.temp_table( + ctx.input_data(insert_data), ctx.table("data_insert").sql() + ) as insert_table: + source_queries, columns_to_types = _get_source_queries_and_columns_to_types( + ctx, insert_table, existing_table_name + ) + ctx.engine_adapter._insert_overwrite_by_condition( + existing_table_name.sql(), + source_queries, + columns_to_types, + dynamic_key=key, + dynamic_key_exp=key_exp, + dynamic_key_unique=True, + ) + + ctx.compare_with_current( + existing_table_name, + pd.DataFrame( + [ + {"id": 1, "ds": pd.Timestamp("2024-01-01")}, + {"id": 2, "ds": pd.Timestamp("2024-02-01")}, + {"id": 3, "ds": pd.Timestamp("2024-02-28")}, + {"id": 4, "ds": pd.Timestamp("2024-05-01")}, + {"id": 5, "ds": pd.Timestamp("2024-02-29")}, + {"id": 6, "ds": pd.Timestamp("2024-04-01")}, + ] + ), + ) + ctx.engine_adapter.drop_table(existing_table_name.sql()) + + # unique_key = False + existing_table_name = _create_table_and_insert_existing_data( + ctx, table_name="data_existing_no_unique", partitioned_by=None + ) + + with ctx.engine_adapter.temp_table( + ctx.input_data(insert_data), ctx.table("data_insert_no_unique").sql() + ) as 
insert_table: + source_queries, columns_to_types = _get_source_queries_and_columns_to_types( + ctx, insert_table, existing_table_name + ) + ctx.engine_adapter._insert_overwrite_by_condition( + existing_table_name.sql(), + source_queries, + columns_to_types, + dynamic_key=key, + dynamic_key_exp=key_exp, + dynamic_key_unique=False, + ) + + ctx.compare_with_current( + existing_table_name, + pd.DataFrame( + [ + {"id": 1, "ds": pd.Timestamp("2024-01-01")}, + {"id": 2, "ds": pd.Timestamp("2024-02-01")}, + {"id": 3, "ds": pd.Timestamp("2024-02-28")}, + {"id": 4, "ds": pd.Timestamp("2024-05-01")}, + {"id": 4, "ds": pd.Timestamp("2024-05-02")}, + {"id": 5, "ds": pd.Timestamp("2024-02-29")}, + {"id": 6, "ds": pd.Timestamp("2024-04-01")}, + ] + ), + ) + + +def test_insert_overwrite_by_condition_by_key_partitioned(ctx: TestContext): + # key parameters + key = [exp.column("id")] + key_exp = key[0] + + # data currently in target table + existing_table_name = _create_table_and_insert_existing_data(ctx) + + # new data to insert + insert_data = pd.DataFrame( + [ + {"id": 4, "ds": "2024-05-01"}, # will overwrite existing record + {"id": 4, "ds": "2024-05-02"}, # only inserted if unique_key = False + {"id": 5, "ds": "2024-02-29"}, + {"id": 6, "ds": "2024-04-01"}, + ] + ) + + # unique_key = True + with ctx.engine_adapter.temp_table( + ctx.input_data(insert_data), ctx.table("data_insert").sql() + ) as insert_table: + source_queries, columns_to_types = _get_source_queries_and_columns_to_types( + ctx, insert_table, existing_table_name + ) + ctx.engine_adapter._insert_overwrite_by_condition( + existing_table_name.sql(), + source_queries, + columns_to_types, + dynamic_key=key, + dynamic_key_exp=key_exp, + dynamic_key_unique=True, + ) + + ctx.compare_with_current( + existing_table_name, + pd.DataFrame( + [ + {"id": 1, "ds": pd.Timestamp("2024-01-01")}, + {"id": 2, "ds": pd.Timestamp("2024-02-01")}, + {"id": 3, "ds": pd.Timestamp("2024-02-28")}, + {"id": 4, "ds": pd.Timestamp("2024-05-01")}, 
+ {"id": 5, "ds": pd.Timestamp("2024-02-29")}, + {"id": 6, "ds": pd.Timestamp("2024-04-01")}, + ] + ), + ) + ctx.engine_adapter.drop_table(existing_table_name.sql()) + + # unique_key = False + existing_table_name = _create_table_and_insert_existing_data( + ctx, table_name="data_existing_no_unique" + ) + + with ctx.engine_adapter.temp_table( + ctx.input_data(insert_data), ctx.table("data_insert_no_unique").sql() + ) as insert_table: + source_queries, columns_to_types = _get_source_queries_and_columns_to_types( + ctx, insert_table, existing_table_name + ) + ctx.engine_adapter._insert_overwrite_by_condition( + existing_table_name.sql(), + source_queries, + columns_to_types, + dynamic_key=key, + dynamic_key_exp=key_exp, + dynamic_key_unique=False, + ) + + ctx.compare_with_current( + existing_table_name, + pd.DataFrame( + [ + {"id": 1, "ds": pd.Timestamp("2024-01-01")}, + {"id": 2, "ds": pd.Timestamp("2024-02-01")}, + {"id": 3, "ds": pd.Timestamp("2024-02-28")}, + {"id": 4, "ds": pd.Timestamp("2024-05-01")}, + { + "id": 4, + "ds": pd.Timestamp("2024-05-02"), + }, # second ID=4 row because unique_key=False + {"id": 5, "ds": pd.Timestamp("2024-02-29")}, + {"id": 6, "ds": pd.Timestamp("2024-04-01")}, + ] + ), + ) + + +def test_insert_overwrite_by_condition_inc_by_partition(ctx: TestContext): + existing_table_name = _create_table_and_insert_existing_data(ctx) + + # new data to insert + insert_data = pd.DataFrame( + [ + {"id": 5, "ds": "2024-02-29"}, + {"id": 6, "ds": "2024-04-01"}, + ] + ) + insert_table_name = ctx.table("data_insert") + + with ctx.engine_adapter.temp_table( + ctx.input_data(insert_data), insert_table_name.sql() + ) as insert_table: + source_queries, columns_to_types = _get_source_queries_and_columns_to_types( + ctx, insert_table, existing_table_name + ) + ctx.engine_adapter._insert_overwrite_by_condition( + existing_table_name.sql(), + source_queries, + columns_to_types, + keep_existing_partition_rows=False, + ) + + ctx.compare_with_current( + 
existing_table_name, + pd.DataFrame( + [ + {"id": 1, "ds": pd.Timestamp("2024-01-01")}, + { + "id": 5, + "ds": pd.Timestamp("2024-02-29"), + }, # all existing Feb records overwritten by this row + {"id": 4, "ds": pd.Timestamp("2024-03-01")}, + {"id": 6, "ds": pd.Timestamp("2024-04-01")}, + ] + ), + ) + + +def test_inc_by_time_auto_partition_string(ctx: TestContext): + # ensure automatic time partitioning works when the time column is not a Date/DateTime type + existing_table_name = _create_table_and_insert_existing_data( + ctx, + columns_to_types={ + "id": exp.DataType.build("Int8", "clickhouse"), + "ds": exp.DataType.build("String", "clickhouse"), # String time column + }, + table_name="data_existing", + partitioned_by=None, + ) + + sqlmesh_context, model = ctx.upsert_sql_model( + f""" + MODEL ( + name test.inc_by_time_no_partition, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds + ), + dialect clickhouse, + start '2023-01-01' + ); + + SELECT + id::Int8, + ds::String + FROM {existing_table_name.sql()} + WHERE ds BETWEEN @start_ds AND @end_ds + """ + ) + + plan = sqlmesh_context.plan(no_prompts=True, auto_apply=True) + + physical_location = ctx.engine_adapter.get_data_objects( + plan.environment.snapshots[0].physical_schema + )[0] + + partitions = ctx.engine_adapter.fetchall( + exp.select("_partition_id") + .distinct() + .from_(f"{physical_location.schema_name}.{physical_location.name}") + ) + + # The automatic time partitioning creates one partition per week. The 4 input data points + # are located in three distinct weeks, which should have one partition each. 
+ assert len(partitions) == 3 + + +def test_diff_requires_dialect(ctx: TestContext): + sql = """ + MODEL ( + name test_schema.some_view, + kind VIEW, + dialect clickhouse + ); + + SELECT + maxIf('2020-01-01'::Date, 1={rhs})::Nullable(Date) as col + """ + + sqlmesh_context, model = ctx.upsert_sql_model(sql.format(rhs="1")) + sqlmesh_context.plan(no_prompts=True, auto_apply=True) + + _, model = ctx.upsert_sql_model(sql.format(rhs="2")) + sqlmesh_context.upsert_model(model) + + plan = sqlmesh_context.plan(no_prompts=True, auto_apply=True, no_diff=True) + + new_snapshot = plan.context_diff.modified_snapshots['"test_schema"."some_view"'][0] + assert new_snapshot.change_category == SnapshotChangeCategory.BREAKING diff --git a/tests/core/engine_adapter/integration/test_integration_duckdb.py b/tests/core/engine_adapter/integration/test_integration_duckdb.py new file mode 100644 index 0000000000..a53c559a55 --- /dev/null +++ b/tests/core/engine_adapter/integration/test_integration_duckdb.py @@ -0,0 +1,141 @@ +import typing as t +import pytest +from threading import current_thread, Thread +import random +from sqlglot import exp +from pathlib import Path +from concurrent.futures import ThreadPoolExecutor, as_completed + +from sqlmesh.core.config.connection import DuckDBConnectionConfig +from sqlmesh.utils.connection_pool import ThreadLocalSharedConnectionPool + +pytestmark = [pytest.mark.duckdb, pytest.mark.engine, pytest.mark.slow] + + +@pytest.mark.parametrize("database", [None, "db.db"]) +def test_multithread_concurrency(tmp_path: Path, database: t.Optional[str]): + num_threads = 100 + + if database: + database = str(tmp_path / database) + + config = DuckDBConnectionConfig(concurrent_tasks=8, database=database) + + adapter = config.create_engine_adapter() + + assert isinstance(adapter._connection_pool, ThreadLocalSharedConnectionPool) + + # this test loosely follows this example: https://duckdb.org/docs/guides/python/multiple_threads.html + adapter.execute( + "create table 
tbl (thread_name varchar, insert_time timestamp default current_timestamp)" + ) + + # list.append() is threadsafe + write_results = [] + read_results = [] + + def write_from_thread(): + thread_name = str(current_thread().name) + query = exp.insert( + exp.values([(exp.Literal.string(thread_name),)]), "tbl", columns=["thread_name"] + ) + adapter.execute(query) + adapter.execute(f"CREATE TABLE thread_{thread_name} (id int)") + write_results.append(thread_name) + + def read_from_thread(): + thread_name = str(current_thread().name) + query = exp.select( + exp.Literal.string(thread_name).as_("thread_name"), + exp.Count(this="*").as_("row_counter"), + exp.CurrentTimestamp(), + ).from_("tbl") + results = adapter.fetchall(query) + assert len(results) == 1 + read_results.append(results[0]) + + threads = [] + + for i in range(num_threads): + threads.append(Thread(target=write_from_thread, name=f"write_thread_{i}")) + threads.append(Thread(target=read_from_thread, name=f"read_thread_{i}")) + + random.seed(6) + random.shuffle(threads) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + assert len(read_results) == num_threads + assert len(write_results) == num_threads + + tables = adapter.fetchall("show tables") + assert len(tables) == num_threads + 1 + + +def test_secret_registration_from_multiple_connections(tmp_path: Path): + database = str(tmp_path / "db.db") + + config = DuckDBConnectionConfig( + database=database, + concurrent_tasks=2, + secrets={"s3": {"type": "s3", "region": "us-east-1", "key_id": "foo", "secret": "bar"}}, + ) + + adapter = config.create_engine_adapter() + pool = adapter._connection_pool + + assert isinstance(pool, ThreadLocalSharedConnectionPool) + + def _open_connection() -> bool: + # this triggers cursor_init() to be run again for the new connection from the new thread + # if the operations in cursor_init() are not idempotent, DuckDB will throw an error and this test will fail + cur = pool.get_cursor() + 
cur.execute("SELECT name FROM duckdb_secrets()") + secret_names = [name for name_row in cur.fetchall() for name in name_row] + assert secret_names == ["s3"] + return True + + thread_pool = ThreadPoolExecutor(max_workers=4) + futures = [] + for _ in range(10): + futures.append(thread_pool.submit(_open_connection)) + + for future in as_completed(futures): + assert future.result() + + +def test_connector_config_from_multiple_connections(tmp_path: Path): + config = DuckDBConnectionConfig( + concurrent_tasks=2, + extensions=["tpch"], + connector_config={"temp_directory": str(tmp_path), "memory_limit": "16mb"}, + ) + + adapter = config.create_engine_adapter() + pool = adapter._connection_pool + + assert isinstance(pool, ThreadLocalSharedConnectionPool) + + adapter.execute("CALL dbgen(sf = 0.1)") + + # check that temporary files exist so that calling "SET temp_directory = 'anything'" will throw an error + assert len(adapter.fetchall("select path from duckdb_temporary_files()")) > 0 + + def _open_connection() -> bool: + # This triggers cursor_init() which should only SET values if they have changed + pool.get_cursor() + return True + + thread_pool = ThreadPoolExecutor(max_workers=4) + futures = [] + for _ in range(4): + futures.append(thread_pool.submit(_open_connection)) + + for future in as_completed(futures): + assert future.result() + + pool.close_all() diff --git a/tests/core/engine_adapter/integration/test_integration_fabric.py b/tests/core/engine_adapter/integration/test_integration_fabric.py new file mode 100644 index 0000000000..41f399b3b8 --- /dev/null +++ b/tests/core/engine_adapter/integration/test_integration_fabric.py @@ -0,0 +1,117 @@ +import typing as t +import threading +import queue +import pytest +from pytest import FixtureRequest +from sqlmesh.core.engine_adapter import FabricEngineAdapter +from sqlmesh.utils.connection_pool import ThreadLocalConnectionPool +from tests.core.engine_adapter.integration import TestContext +from concurrent.futures import 
ThreadPoolExecutor + +from tests.core.engine_adapter.integration import ( + TestContext, + generate_pytest_params, + ENGINES_BY_NAME, + IntegrationTestEngine, +) + + +@pytest.fixture( + params=list(generate_pytest_params(ENGINES_BY_NAME["fabric"], show_variant_in_test_id=False)) +) +def ctx( + request: FixtureRequest, + create_test_context: t.Callable[[IntegrationTestEngine, str, str], t.Iterable[TestContext]], +) -> t.Iterable[TestContext]: + yield from create_test_context(*request.param) + + +@pytest.fixture +def engine_adapter(ctx: TestContext) -> FabricEngineAdapter: + assert isinstance(ctx.engine_adapter, FabricEngineAdapter) + return ctx.engine_adapter + + +def test_create_drop_catalog(ctx: TestContext, engine_adapter: FabricEngineAdapter): + catalog_name = ctx.add_test_suffix("test_catalog") + + try: + ctx.create_catalog(catalog_name) + # if already exists, should be no-op, not error + ctx.create_catalog(catalog_name) + ctx.drop_catalog(catalog_name) + finally: + # if doesnt exist, should be no-op, not error + ctx.drop_catalog(catalog_name) + + +def test_drop_catalog_clears_threadlocals_that_reference_it( + ctx: TestContext, engine_adapter: FabricEngineAdapter +): + catalog_name = ctx.add_test_suffix("test_drop_catalog") + default_catalog = engine_adapter.get_current_catalog() + + assert isinstance(engine_adapter._connection_pool, ThreadLocalConnectionPool) + + # sets the connection attribute for this thread + engine_adapter.create_catalog(catalog_name) + assert engine_adapter._target_catalog is None + engine_adapter.set_current_catalog(catalog_name) + assert engine_adapter.get_current_catalog() == catalog_name + assert engine_adapter._target_catalog == catalog_name + + lock = threading.RLock() + + def _set_and_return_catalog_in_another_thread( + q: queue.Queue, engine_adapter: FabricEngineAdapter + ) -> t.Optional[str]: + q.put("thread_started") + + assert engine_adapter.get_current_catalog() == default_catalog + assert engine_adapter._target_catalog is 
None + + engine_adapter.set_current_catalog(catalog_name) + assert engine_adapter.get_current_catalog() == catalog_name + assert engine_adapter._target_catalog == catalog_name + + q.put("catalog_set_in_thread") + + # block this thread while we drop the catalog in the main test thread + lock.acquire() + + # the current catalog should have been cleared from the threadlocal connection attributes + # when this catalog was dropped by the outer thread, causing it to fall back to the default catalog + try: + assert engine_adapter._target_catalog is None + return engine_adapter.get_current_catalog() + finally: + lock.release() + + q: queue.Queue = queue.Queue() + + with ThreadPoolExecutor() as executor: + lock.acquire() # we have the lock, thread will be blocked until we release it + + future = executor.submit(_set_and_return_catalog_in_another_thread, q, engine_adapter) + + assert q.get() == "thread_started" + assert not future.done() + + try: + assert q.get(timeout=20) == "catalog_set_in_thread" + except: + if exec := future.exception(): + raise exec + raise + + ctx.drop_catalog(catalog_name) + assert not future.done() + + lock.release() # yield the lock to the thread + + # block until thread complete + result = future.result() + + # both threads should be automatically using the default catalog now + assert result == default_catalog + assert engine_adapter.get_current_catalog() == default_catalog diff --git a/tests/core/engine_adapter/integration/test_integration_postgres.py b/tests/core/engine_adapter/integration/test_integration_postgres.py new file mode 100644 index 0000000000..f236fdebce --- /dev/null +++ b/tests/core/engine_adapter/integration/test_integration_postgres.py @@ -0,0 +1,1226 @@ +import typing as t +from contextlib import contextmanager +import pytest +from pytest import FixtureRequest +from pathlib import Path +from sqlmesh.core.engine_adapter import PostgresEngineAdapter +from sqlmesh.core.config import Config, DuckDBConnectionConfig +from 
sqlmesh.core.config.common import VirtualEnvironmentMode +from tests.core.engine_adapter.integration import TestContext +import time_machine +from datetime import timedelta +from sqlmesh.utils.date import to_ds +from sqlglot import exp +from sqlmesh.core.context import Context +from sqlmesh.core.state_sync import CachingStateSync, EngineAdapterStateSync +from sqlmesh.core.snapshot.definition import SnapshotId +from sqlmesh.utils import random_id + +from tests.core.engine_adapter.integration import ( + TestContext, + generate_pytest_params, + ENGINES_BY_NAME, + IntegrationTestEngine, + TEST_SCHEMA, +) + + +def _cleanup_user(engine_adapter: PostgresEngineAdapter, user_name: str) -> None: + """Helper function to clean up a PostgreSQL user and all their dependencies.""" + try: + engine_adapter.execute(f""" + SELECT pg_terminate_backend(pid) + FROM pg_stat_activity + WHERE usename = '{user_name}' AND pid <> pg_backend_pid() + """) + engine_adapter.execute(f'DROP OWNED BY "{user_name}"') + engine_adapter.execute(f'DROP USER IF EXISTS "{user_name}"') + except Exception: + pass + + +@contextmanager +def create_users( + engine_adapter: PostgresEngineAdapter, *role_names: str +) -> t.Iterator[t.Dict[str, t.Dict[str, str]]]: + """Create a set of Postgres users and yield their credentials.""" + created_users = [] + roles = {} + + try: + for role_name in role_names: + user_name = f"test_{role_name}" + _cleanup_user(engine_adapter, user_name) + + for role_name in role_names: + user_name = f"test_{role_name}" + password = random_id() + engine_adapter.execute(f"CREATE USER \"{user_name}\" WITH PASSWORD '{password}'") + engine_adapter.execute(f'GRANT USAGE ON SCHEMA public TO "{user_name}"') + created_users.append(user_name) + roles[role_name] = {"username": user_name, "password": password} + + yield roles + + finally: + for user_name in created_users: + _cleanup_user(engine_adapter, user_name) + + +def create_engine_adapter_for_role( + role_credentials: t.Dict[str, str], ctx: 
TestContext, config: Config +) -> PostgresEngineAdapter: + """Create a PostgreSQL adapter for a specific role to test authentication and permissions.""" + from sqlmesh.core.config import PostgresConnectionConfig + + gateway = ctx.gateway + assert gateway in config.gateways + connection_config = config.gateways[gateway].connection + assert isinstance(connection_config, PostgresConnectionConfig) + + role_connection_config = PostgresConnectionConfig( + host=connection_config.host, + port=connection_config.port, + database=connection_config.database, + user=role_credentials["username"], + password=role_credentials["password"], + keepalives_idle=connection_config.keepalives_idle, + connect_timeout=connection_config.connect_timeout, + role=connection_config.role, + sslmode=connection_config.sslmode, + application_name=connection_config.application_name, + ) + + return t.cast(PostgresEngineAdapter, role_connection_config.create_engine_adapter()) + + +@contextmanager +def engine_adapter_for_role( + role_credentials: t.Dict[str, str], ctx: TestContext, config: Config +) -> t.Iterator[PostgresEngineAdapter]: + """Context manager that yields a PostgresEngineAdapter and ensures it is closed.""" + adapter = create_engine_adapter_for_role(role_credentials, ctx, config) + try: + yield adapter + finally: + adapter.close() + + +@pytest.fixture(params=list(generate_pytest_params(ENGINES_BY_NAME["postgres"]))) +def ctx( + request: FixtureRequest, + create_test_context: t.Callable[[IntegrationTestEngine, str, str], t.Iterable[TestContext]], +) -> t.Iterable[TestContext]: + yield from create_test_context(*request.param) + + +@pytest.fixture +def engine_adapter(ctx: TestContext) -> PostgresEngineAdapter: + assert isinstance(ctx.engine_adapter, PostgresEngineAdapter) + return ctx.engine_adapter + + +def test_engine_adapter(ctx: TestContext): + assert isinstance(ctx.engine_adapter, PostgresEngineAdapter) + assert ctx.engine_adapter.fetchone("select 1") == (1,) + + +def 
test_server_version_psycopg(ctx: TestContext): + assert isinstance(ctx.engine_adapter, PostgresEngineAdapter) + assert ctx.engine_adapter.server_version != (0, 0) + + +def test_janitor_drop_cascade(ctx: TestContext, tmp_path: Path) -> None: + """ + Scenario: + Ensure that cleaning up expired table snapshots also cleans up any unexpired view snapshots that depend on them + - We create a A (table) <- B (view) + - In dev, we modify A - triggers new version of A and a dev preview of B that both expire in 7 days + - We advance time by 3 days + - In dev, we modify B - triggers a new version of B that depends on A but expires 3 days after A + - In dev, we create B(view) <- C(view) and B(view) <- D(table) + - We advance time by 5 days so that A has reached its expiry but B, C and D have not + - We expire dev so that none of these snapshots are promoted and are thus targets for cleanup + - We run the janitor + + Expected outcome: + - All the dev versions of A and B should be dropped + - C should be dropped as well because it's a view that depends on B which was dropped + - D should not be dropped because while it depends on B which was dropped, it's a table so is still valid after B is dropped + - We should NOT get a 'ERROR: cannot drop table x because other objects depend on it' + + Note that the references in state to the views that were cascade-dropped by postgres will still exist, this is considered ok + as applying a plan will recreate the physical objects + """ + + def _all_snapshot_ids(context: Context) -> t.List[SnapshotId]: + assert isinstance(context.state_sync, CachingStateSync) + assert isinstance(context.state_sync.state_sync, EngineAdapterStateSync) + + return [ + SnapshotId(name=name, identifier=identifier) + for name, identifier in context.state_sync.state_sync.engine_adapter.fetchall( + "select name, identifier from sqlmesh._snapshots" + ) + ] + + models_dir = tmp_path / "models" + models_dir.mkdir() + schema = exp.to_table(ctx.schema(TEST_SCHEMA)).this + + 
(models_dir / "model_a.sql").write_text(f""" + MODEL ( + name {schema}.model_a, + kind FULL + ); + SELECT 1 as a, 2 as b; + """) + + (models_dir / "model_b.sql").write_text(f""" + MODEL ( + name {schema}.model_b, + kind VIEW + ); + SELECT a from {schema}.model_a; + """) + + def _mutate_config(gateway: str, config: Config): + config.gateways[gateway].state_connection = DuckDBConnectionConfig( + database=str(tmp_path / "state.db") + ) + + with time_machine.travel("2020-01-01 00:00:00"): + sqlmesh = ctx.create_context( + path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False + ) + sqlmesh.plan(auto_apply=True) + + model_a_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_a" in n) + # expiry is last updated + ttl + assert timedelta(milliseconds=model_a_snapshot.ttl_ms) == timedelta(weeks=1) + assert to_ds(model_a_snapshot.updated_ts) == "2020-01-01" + assert to_ds(model_a_snapshot.expiration_ts) == "2020-01-08" + + model_b_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_b" in n) + assert timedelta(milliseconds=model_b_snapshot.ttl_ms) == timedelta(weeks=1) + assert to_ds(model_b_snapshot.updated_ts) == "2020-01-01" + assert to_ds(model_b_snapshot.expiration_ts) == "2020-01-08" + + model_a_prod_snapshot = model_a_snapshot + model_b_prod_snapshot = model_b_snapshot + + # move forward 1 days + # new dev environment - touch models to create new snapshots + # model a / b expiry in prod should remain unmodified + # model a / b expiry in dev should be as at today + with time_machine.travel("2020-01-02 00:00:00"): + (models_dir / "model_a.sql").write_text(f""" + MODEL ( + name {schema}.model_a, + kind FULL + ); + SELECT 1 as a, 2 as b, 3 as c; + """) + + sqlmesh = ctx.create_context( + path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False + ) + sqlmesh.plan(environment="dev", auto_apply=True) + + # should now have 4 snapshots in state - 2x model a and 2x model b + # the new model b is a dev 
preview because its upstream model changed + all_snapshot_ids = _all_snapshot_ids(sqlmesh) + assert len(all_snapshot_ids) == 4 + assert len([s for s in all_snapshot_ids if "model_a" in s.name]) == 2 + assert len([s for s in all_snapshot_ids if "model_b" in s.name]) == 2 + + # context just has the two latest + assert len(sqlmesh.snapshots) == 2 + + # these expire 1 day later than what's in prod + model_a_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_a" in n) + assert timedelta(milliseconds=model_a_snapshot.ttl_ms) == timedelta(weeks=1) + assert to_ds(model_a_snapshot.updated_ts) == "2020-01-02" + assert to_ds(model_a_snapshot.expiration_ts) == "2020-01-09" + + model_b_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_b" in n) + assert timedelta(milliseconds=model_b_snapshot.ttl_ms) == timedelta(weeks=1) + assert to_ds(model_b_snapshot.updated_ts) == "2020-01-02" + assert to_ds(model_b_snapshot.expiration_ts) == "2020-01-09" + + # move forward 3 days + # touch model b in dev but leave model a + # this bumps the model b expiry but model a remains unchanged, so will expire before model b even though model b depends on it + with time_machine.travel("2020-01-05 00:00:00"): + (models_dir / "model_b.sql").write_text(f""" + MODEL ( + name {schema}.model_b, + kind VIEW + ); + SELECT a, 'b' as b from {schema}.model_a; + """) + + (models_dir / "model_c.sql").write_text(f""" + MODEL ( + name {schema}.model_c, + kind VIEW + ); + SELECT a, 'c' as c from {schema}.model_b; + """) + + (models_dir / "model_d.sql").write_text(f""" + MODEL ( + name {schema}.model_d, + kind FULL + ); + SELECT a, 'd' as d from {schema}.model_b; + """) + + sqlmesh = ctx.create_context( + path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False + ) + # need run=True to prevent a "start date is greater than end date" error + # since dev cant exceed what is in prod, and prod has no cadence runs, + # without run=True this plan gets start=2020-01-04 
(now) end=2020-01-01 (last prod interval) which fails + sqlmesh.plan(environment="dev", auto_apply=True, run=True) + + # should now have 7 snapshots in state - 2x model a, 3x model b, 1x model c and 1x model d + all_snapshot_ids = _all_snapshot_ids(sqlmesh) + assert len(all_snapshot_ids) == 7 + assert len([s for s in all_snapshot_ids if "model_a" in s.name]) == 2 + assert len([s for s in all_snapshot_ids if "model_b" in s.name]) == 3 + assert len([s for s in all_snapshot_ids if "model_c" in s.name]) == 1 + assert len([s for s in all_snapshot_ids if "model_d" in s.name]) == 1 + + # context just has the 4 latest + assert len(sqlmesh.snapshots) == 4 + + # model a expiry should not have changed + model_a_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_a" in n) + assert timedelta(milliseconds=model_a_snapshot.ttl_ms) == timedelta(weeks=1) + assert to_ds(model_a_snapshot.updated_ts) == "2020-01-02" + assert to_ds(model_a_snapshot.expiration_ts) == "2020-01-09" + + # model b should now expire well after model a + model_b_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_b" in n) + assert timedelta(milliseconds=model_b_snapshot.ttl_ms) == timedelta(weeks=1) + assert to_ds(model_b_snapshot.updated_ts) == "2020-01-05" + assert to_ds(model_b_snapshot.expiration_ts) == "2020-01-12" + + # model c should expire at the same time as model b + model_c_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_c" in n) + assert to_ds(model_c_snapshot.updated_ts) == to_ds(model_b_snapshot.updated_ts) + assert to_ds(model_c_snapshot.expiration_ts) == to_ds(model_b_snapshot.expiration_ts) + + # model d should expire at the same time as model b + model_d_snapshot = next(s for n, s in sqlmesh.snapshots.items() if "model_d" in n) + assert to_ds(model_d_snapshot.updated_ts) == to_ds(model_b_snapshot.updated_ts) + assert to_ds(model_d_snapshot.expiration_ts) == to_ds(model_b_snapshot.expiration_ts) + + # move forward to date where after model a has 
expired but before model b has expired + # invalidate dev to trigger cleanups + # run janitor + # - table model a is expired so will be cleaned up and this will cascade to view model b + # - view model b is not expired, but because it got cascaded to, this will cascade again to view model c + # - table model d is a not a view, so even though its parent view model b got dropped, it doesnt need to be dropped + with time_machine.travel("2020-01-10 00:00:00"): + sqlmesh = ctx.create_context( + path=tmp_path, config_mutator=_mutate_config, ephemeral_state_connection=False + ) + + before_snapshot_ids = _all_snapshot_ids(sqlmesh) + + before_objects = ctx.get_metadata_results(f"sqlmesh__{schema}") + assert set(before_objects.tables) == set( + [ + exp.to_table(s.table_name()).text("this") + for s in (model_a_prod_snapshot, model_a_snapshot, model_d_snapshot) + ] + ) + assert set(before_objects.views).issuperset( + [ + exp.to_table(s.table_name()).text("this") + for s in (model_b_prod_snapshot, model_b_snapshot, model_c_snapshot) + ] + ) + + sqlmesh.invalidate_environment("dev") + sqlmesh.run_janitor(ignore_ttl=False) + + after_snapshot_ids = _all_snapshot_ids(sqlmesh) + + assert len(before_snapshot_ids) != len(after_snapshot_ids) + + # Everything should be left in state except the model_a snapshot, which expired + assert set(after_snapshot_ids) == set(before_snapshot_ids) - set( + [model_a_snapshot.snapshot_id] + ) + + # In the db, there should be: + # - the two original snapshots that were in prod, table model_a and view model_b + # - model d, even though its not promoted in any environment, because it's a table snapshot that hasnt expired yet + # the view snapshots that depended on model_a should be gone due to the cascading delete + after_objects = ctx.get_metadata_results(f"sqlmesh__{schema}") + assert set(after_objects.tables) == set( + [ + exp.to_table(s.table_name()).text("this") + for s in (model_a_prod_snapshot, model_d_snapshot) + ] + ) + assert 
after_objects.views == [ + exp.to_table(model_b_prod_snapshot.table_name()).text("this") + ] + + +# Grants Integration Tests + + +def test_grants_plan_target_layer_physical_only( + engine_adapter: PostgresEngineAdapter, ctx: TestContext, tmp_path: Path +): + with create_users(engine_adapter, "reader") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + + model_def = """ + MODEL ( + name test_schema.physical_grants_model, + kind FULL, + grants ( + 'select' = ['test_reader'] + ), + grants_target_layer 'physical' + ); + SELECT 1 as id, 'physical_only' as layer + """ + + (tmp_path / "models" / "physical_grants_model.sql").write_text(model_def) + + context = ctx.create_context(path=tmp_path) + plan_result = context.plan(auto_apply=True, no_prompts=True) + + assert len(plan_result.new_snapshots) == 1 + snapshot = plan_result.new_snapshots[0] + physical_table_name = snapshot.table_name() + + physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(physical_table_name, dialect=engine_adapter.dialect) + ) + assert physical_grants == {"SELECT": [roles["reader"]["username"]]} + + # Virtual layer should have no grants + virtual_view_name = f"test_schema.physical_grants_model" + virtual_grants = engine_adapter._get_current_grants_config( + exp.to_table(virtual_view_name, dialect=engine_adapter.dialect) + ) + assert virtual_grants == {} + + +def test_grants_plan_target_layer_virtual_only( + engine_adapter: PostgresEngineAdapter, ctx: TestContext, tmp_path: Path +): + with create_users(engine_adapter, "viewer") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + + model_def = """ + MODEL ( + name test_schema.virtual_grants_model, + kind FULL, + grants ( + 'select' = ['test_viewer'] + ), + grants_target_layer 'virtual' + ); + SELECT 1 as id, 'virtual_only' as layer + """ + + (tmp_path / "models" / "virtual_grants_model.sql").write_text(model_def) + + context = ctx.create_context(path=tmp_path) + plan_result = context.plan(auto_apply=True, 
no_prompts=True) + + assert len(plan_result.new_snapshots) == 1 + snapshot = plan_result.new_snapshots[0] + physical_table_name = snapshot.table_name() + + physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(physical_table_name, dialect=engine_adapter.dialect) + ) + # Physical table should have no grants + assert physical_grants == {} + + virtual_view_name = f"test_schema.virtual_grants_model" + virtual_grants = engine_adapter._get_current_grants_config( + exp.to_table(virtual_view_name, dialect=engine_adapter.dialect) + ) + assert virtual_grants == {"SELECT": [roles["viewer"]["username"]]} + + +def test_grants_plan_full_refresh_model_via_replace( + engine_adapter: PostgresEngineAdapter, ctx: TestContext, tmp_path: Path +): + with create_users(engine_adapter, "reader") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + (tmp_path / "models" / "full_refresh_model.sql").write_text( + f""" + MODEL ( + name test_schema.full_refresh_model, + kind FULL, + grants ( + 'SELECT' = ['{roles["reader"]["username"]}'] + ), + grants_target_layer 'all' + ); + SELECT 1 as id, 'test_data' as status + """ + ) + + context = ctx.create_context(path=tmp_path) + + plan_result = context.plan( + "dev", # this triggers _replace_query_for_model for FULL models + auto_apply=True, + no_prompts=True, + ) + + assert len(plan_result.new_snapshots) == 1 + snapshot = plan_result.new_snapshots[0] + table_name = snapshot.table_name() + + # Physical table + grants = engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=engine_adapter.dialect) + ) + assert grants == {"SELECT": [roles["reader"]["username"]]} + + # Virtual view + dev_view_name = "test_schema__dev.full_refresh_model" + dev_grants = engine_adapter._get_current_grants_config( + exp.to_table(dev_view_name, dialect=engine_adapter.dialect) + ) + assert dev_grants == {"SELECT": [roles["reader"]["username"]]} + + +def test_grants_plan_incremental_model( + engine_adapter: PostgresEngineAdapter, 
ctx: TestContext, tmp_path: Path +): + with create_users(engine_adapter, "reader", "writer") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + + model_name = "incr_model" + model_definition = f""" + MODEL ( + name test_schema.{model_name}, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ts + ), + grants ( + 'SELECT' = ['{roles["reader"]["username"]}'], + 'INSERT' = ['{roles["writer"]["username"]}'] + ), + grants_target_layer 'all' + ); + SELECT 1 as id, @start_ds::timestamp as ts, 'data' as value + """ + + (tmp_path / "models" / f"{model_name}.sql").write_text(model_definition) + + context = ctx.create_context(path=tmp_path) + + plan_result = context.plan( + "dev", start="2020-01-01", end="2020-01-01", auto_apply=True, no_prompts=True + ) + assert len(plan_result.new_snapshots) == 1 + + snapshot = plan_result.new_snapshots[0] + table_name = snapshot.table_name() + + physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=engine_adapter.dialect) + ) + assert physical_grants.get("SELECT", []) == [roles["reader"]["username"]] + assert physical_grants.get("INSERT", []) == [roles["writer"]["username"]] + + view_name = f"test_schema__dev.{model_name}" + view_grants = engine_adapter._get_current_grants_config( + exp.to_table(view_name, dialect=engine_adapter.dialect) + ) + assert view_grants.get("SELECT", []) == [roles["reader"]["username"]] + assert view_grants.get("INSERT", []) == [roles["writer"]["username"]] + + +def test_grants_plan_clone_environment( + engine_adapter: PostgresEngineAdapter, ctx: TestContext, tmp_path: Path +): + with create_users(engine_adapter, "reader") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + (tmp_path / "models" / "clone_model.sql").write_text( + f""" + MODEL ( + name test_schema.clone_model, + kind FULL, + grants ( + 'SELECT' = ['{roles["reader"]["username"]}'] + ), + grants_target_layer 'all' + ); + + SELECT 1 as id, 'data' as value + """ + ) + + context = 
ctx.create_context(path=tmp_path) + prod_plan_result = context.plan("prod", auto_apply=True, no_prompts=True) + + assert len(prod_plan_result.new_snapshots) == 1 + prod_snapshot = prod_plan_result.new_snapshots[0] + prod_table_name = prod_snapshot.table_name() + + # Prod physical table grants + prod_grants = engine_adapter._get_current_grants_config( + exp.to_table(prod_table_name, dialect=engine_adapter.dialect) + ) + assert prod_grants == {"SELECT": [roles["reader"]["username"]]} + + # Prod virtual view grants + prod_view_name = f"test_schema.clone_model" + prod_view_grants = engine_adapter._get_current_grants_config( + exp.to_table(prod_view_name, dialect=engine_adapter.dialect) + ) + assert prod_view_grants == {"SELECT": [roles["reader"]["username"]]} + + # Create dev environment (cloned from prod) + context.plan("dev", auto_apply=True, no_prompts=True, include_unmodified=True) + + # Physical table grants should remain unchanged + prod_grants_after_clone = engine_adapter._get_current_grants_config( + exp.to_table(prod_table_name, dialect=engine_adapter.dialect) + ) + assert prod_grants_after_clone == prod_grants + + # Dev virtual view should have the same grants as prod + dev_view_name = f"test_schema__dev.clone_model" + dev_view_grants = engine_adapter._get_current_grants_config( + exp.to_table(dev_view_name, dialect=engine_adapter.dialect) + ) + assert dev_view_grants == prod_grants + + +@pytest.mark.parametrize( + "model_name,kind_config,query,extra_config,needs_seed", + [ + ( + "grants_full", + "FULL", + "SELECT 1 as id, 'unchanged_query' as data", + "", + False, + ), + ( + "grants_view", + "VIEW", + "SELECT 1 as id, 'unchanged_query' as data", + "", + False, + ), + ( + "grants_incr_time", + "INCREMENTAL_BY_TIME_RANGE (time_column event_date)", + "SELECT '2025-09-01'::date as event_date, 1 as id, 'unchanged_query' as data", + "start '2025-09-01',", + False, + ), + ( + "grants_seed", + "SEED (path '../seeds/grants_seed.csv')", + "", + "", + True, + ), + ], 
+) +def test_grants_metadata_only_changes( + engine_adapter: PostgresEngineAdapter, + ctx: TestContext, + tmp_path: Path, + model_name: str, + kind_config: str, + query: str, + extra_config: str, + needs_seed: bool, +): + with create_users(engine_adapter, "reader", "writer", "admin") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + + if needs_seed: + (tmp_path / "seeds").mkdir(exist_ok=True) + csv_content = "id,data\\n1,unchanged_query" + (tmp_path / "seeds" / f"{model_name}.csv").write_text(csv_content) + + initial_model_def = f""" + MODEL ( + name test_schema.{model_name}, + kind {kind_config}, + {extra_config} + grants ( + 'select' = ['{roles["reader"]["username"]}'] + ), + grants_target_layer 'all' + ); + {query} + """ + (tmp_path / "models" / f"{model_name}.sql").write_text(initial_model_def) + + context = ctx.create_context(path=tmp_path) + initial_plan_result = context.plan(auto_apply=True, no_prompts=True) + + assert len(initial_plan_result.new_snapshots) == 1 + initial_snapshot = initial_plan_result.new_snapshots[0] + + physical_table_name = initial_snapshot.table_name() + virtual_view_name = f"test_schema.{model_name}" + + initial_physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(physical_table_name, dialect=engine_adapter.dialect) + ) + assert initial_physical_grants == {"SELECT": [roles["reader"]["username"]]} + + initial_virtual_grants = engine_adapter._get_current_grants_config( + exp.to_table(virtual_view_name, dialect=engine_adapter.dialect) + ) + assert initial_virtual_grants == {"SELECT": [roles["reader"]["username"]]} + + # Metadata-only change: update grants only using upsert_model + existing_model = context.get_model(f"test_schema.{model_name}") + context.upsert_model( + existing_model, + grants={ + "select": [roles["writer"]["username"], roles["admin"]["username"]], + "insert": [roles["admin"]["username"]], + }, + ) + second_plan_result = context.plan(auto_apply=True, no_prompts=True) + + expected_grants = { + 
"SELECT": [roles["writer"]["username"], roles["admin"]["username"]], + "INSERT": [roles["admin"]["username"]], + } + + # For seed models, grant changes rebuild the entire table, so it will create a new physical table + if model_name == "grants_seed" and second_plan_result.new_snapshots: + updated_snapshot = second_plan_result.new_snapshots[0] + physical_table_name = updated_snapshot.table_name() + + updated_physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(physical_table_name, dialect=engine_adapter.dialect) + ) + assert set(updated_physical_grants.get("SELECT", [])) == set(expected_grants["SELECT"]) + assert updated_physical_grants.get("INSERT", []) == expected_grants["INSERT"] + + updated_virtual_grants = engine_adapter._get_current_grants_config( + exp.to_table(virtual_view_name, dialect=engine_adapter.dialect) + ) + assert set(updated_virtual_grants.get("SELECT", [])) == set(expected_grants["SELECT"]) + assert updated_virtual_grants.get("INSERT", []) == expected_grants["INSERT"] + + +def _vde_dev_only_config(gateway: str, config: Config) -> None: + config.virtual_environment_mode = VirtualEnvironmentMode.DEV_ONLY + + +@pytest.mark.parametrize( + "grants_target_layer,model_kind", + [ + ("virtual", "FULL"), + ("physical", "FULL"), + ("all", "FULL"), + ("virtual", "VIEW"), + ("physical", "VIEW"), + ], +) +def test_grants_target_layer_with_vde_dev_only( + engine_adapter: PostgresEngineAdapter, + ctx: TestContext, + tmp_path: Path, + grants_target_layer: str, + model_kind: str, +): + with create_users(engine_adapter, "reader", "writer") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + + if model_kind == "VIEW": + grants_config = ( + f"'SELECT' = ['{roles['reader']['username']}', '{roles['writer']['username']}']" + ) + else: + grants_config = f""" + 'SELECT' = ['{roles["reader"]["username"]}', '{roles["writer"]["username"]}'], + 'INSERT' = ['{roles["writer"]["username"]}'] + """.strip() + + model_def = f""" + MODEL ( + name 
test_schema.vde_model_{grants_target_layer}_{model_kind.lower()}, + kind {model_kind}, + grants ( + {grants_config} + ), + grants_target_layer '{grants_target_layer}' + ); + SELECT 1 as id, '{grants_target_layer}_{model_kind}' as test_type + """ + ( + tmp_path / "models" / f"vde_model_{grants_target_layer}_{model_kind.lower()}.sql" + ).write_text(model_def) + + context = ctx.create_context(path=tmp_path, config_mutator=_vde_dev_only_config) + context.plan("prod", auto_apply=True, no_prompts=True) + + table_name = f"test_schema.vde_model_{grants_target_layer}_{model_kind.lower()}" + + # In VDE dev_only mode, VIEWs are created as actual views + assert context.engine_adapter.table_exists(table_name) + + grants = engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=engine_adapter.dialect) + ) + assert roles["reader"]["username"] in grants.get("SELECT", []) + assert roles["writer"]["username"] in grants.get("SELECT", []) + + if model_kind != "VIEW": + assert roles["writer"]["username"] in grants.get("INSERT", []) + + +def test_grants_incremental_model_with_vde_dev_only( + engine_adapter: PostgresEngineAdapter, ctx: TestContext, tmp_path: Path +): + with create_users(engine_adapter, "etl", "analyst") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + + model_def = f""" + MODEL ( + name test_schema.vde_incremental_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column event_date + ), + grants ( + 'SELECT' = ['{roles["analyst"]["username"]}'], + 'INSERT' = ['{roles["etl"]["username"]}'] + ), + grants_target_layer 'virtual' + ); + SELECT + 1 as id, + @start_date::date as event_date, + 'event' as event_type + """ + (tmp_path / "models" / "vde_incremental_model.sql").write_text(model_def) + + context = ctx.create_context(path=tmp_path, config_mutator=_vde_dev_only_config) + context.plan("prod", auto_apply=True, no_prompts=True) + + prod_table = "test_schema.vde_incremental_model" + prod_grants = engine_adapter._get_current_grants_config( + 
exp.to_table(prod_table, dialect=engine_adapter.dialect) + ) + assert roles["analyst"]["username"] in prod_grants.get("SELECT", []) + assert roles["etl"]["username"] in prod_grants.get("INSERT", []) + + +@pytest.mark.parametrize( + "change_type,initial_query,updated_query,expect_schema_change", + [ + # Metadata-only change (grants only) + ( + "metadata_only", + "SELECT 1 as id, 'same' as status", + "SELECT 1 as id, 'same' as status", + False, + ), + # Breaking change only + ( + "breaking_only", + "SELECT 1 as id, 'initial' as status, 100 as amount", + "SELECT 1 as id, 'updated' as status", # Removed column + True, + ), + # Both metadata and breaking changes + ( + "metadata_and_breaking", + "SELECT 1 as id, 'initial' as status, 100 as amount", + "SELECT 2 as id, 'changed' as new_status", # Different schema + True, + ), + ], +) +def test_grants_changes_with_vde_dev_only( + engine_adapter: PostgresEngineAdapter, + ctx: TestContext, + tmp_path: Path, + change_type: str, + initial_query: str, + updated_query: str, + expect_schema_change: bool, +): + with create_users(engine_adapter, "user1", "user2", "user3") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + model_path = tmp_path / "models" / f"vde_changes_{change_type}.sql" + + initial_model = f""" + MODEL ( + name test_schema.vde_changes_{change_type}, + kind FULL, + grants ( + 'SELECT' = ['{roles["user1"]["username"]}'] + ), + grants_target_layer 'virtual' + ); + {initial_query} + """ + model_path.write_text(initial_model) + + context = ctx.create_context(path=tmp_path, config_mutator=_vde_dev_only_config) + context.plan("prod", auto_apply=True, no_prompts=True) + + table_name = f"test_schema.vde_changes_{change_type}" + initial_grants = engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=engine_adapter.dialect) + ) + assert roles["user1"]["username"] in initial_grants.get("SELECT", []) + assert roles["user2"]["username"] not in initial_grants.get("SELECT", []) + + # Update model with 
new grants and potentially new query + updated_model = f""" + MODEL ( + name test_schema.vde_changes_{change_type}, + kind FULL, + grants ( + 'SELECT' = ['{roles["user1"]["username"]}', '{roles["user2"]["username"]}', '{roles["user3"]["username"]}'], + 'INSERT' = ['{roles["user3"]["username"]}'] + ), + grants_target_layer 'virtual' + ); + {updated_query} + """ + model_path.write_text(updated_model) + + # Get initial table columns + initial_columns = set( + col[0] + for col in engine_adapter.fetchall( + f"SELECT column_name FROM information_schema.columns WHERE table_schema = 'test_schema' AND table_name = 'vde_changes_{change_type}'" + ) + ) + + context.load() + plan = context.plan("prod", auto_apply=True, no_prompts=True) + + assert len(plan.new_snapshots) == 1 + + current_columns = set( + col[0] + for col in engine_adapter.fetchall( + f"SELECT column_name FROM information_schema.columns WHERE table_schema = 'test_schema' AND table_name = 'vde_changes_{change_type}'" + ) + ) + + if expect_schema_change: + assert current_columns != initial_columns + else: + # For metadata-only changes, schema should be the same + assert current_columns == initial_columns + + # Grants should be updated in all cases + updated_grants = engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=engine_adapter.dialect) + ) + assert roles["user1"]["username"] in updated_grants.get("SELECT", []) + assert roles["user2"]["username"] in updated_grants.get("SELECT", []) + assert roles["user3"]["username"] in updated_grants.get("SELECT", []) + assert roles["user3"]["username"] in updated_grants.get("INSERT", []) + + +@pytest.mark.parametrize( + "grants_target_layer,environment", + [ + ("virtual", "prod"), + ("virtual", "dev"), + ("physical", "prod"), + ("physical", "staging"), + ("all", "prod"), + ("all", "preview"), + ], +) +def test_grants_target_layer_plan_env_with_vde_dev_only( + engine_adapter: PostgresEngineAdapter, + ctx: TestContext, + tmp_path: Path, + 
grants_target_layer: str, + environment: str, +): + with create_users(engine_adapter, "grantee") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + + model_def = f""" + MODEL ( + name test_schema.vde_layer_model, + kind FULL, + grants ( + 'SELECT' = ['{roles["grantee"]["username"]}'] + ), + grants_target_layer '{grants_target_layer}' + ); + SELECT 1 as id, '{environment}' as env, '{grants_target_layer}' as layer + """ + (tmp_path / "models" / "vde_layer_model.sql").write_text(model_def) + + context = ctx.create_context(path=tmp_path, config_mutator=_vde_dev_only_config) + + if environment == "prod": + context.plan("prod", auto_apply=True, no_prompts=True) + table_name = "test_schema.vde_layer_model" + grants = engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=engine_adapter.dialect) + ) + assert roles["grantee"]["username"] in grants.get("SELECT", []) + else: + context.plan(environment, auto_apply=True, no_prompts=True, include_unmodified=True) + virtual_view = f"test_schema__{environment}.vde_layer_model" + assert context.engine_adapter.table_exists(virtual_view) + virtual_grants = engine_adapter._get_current_grants_config( + exp.to_table(virtual_view, dialect=engine_adapter.dialect) + ) + + data_objects = engine_adapter.get_data_objects("sqlmesh__test_schema") + physical_tables = [ + obj + for obj in data_objects + if "vde_layer_model" in obj.name + and obj.name.endswith("__dev") # Always __dev suffix in VDE dev_only + and "TABLE" in str(obj.type).upper() + ] + + if grants_target_layer == "virtual": + # Virtual layer should have grants, physical should not + assert roles["grantee"]["username"] in virtual_grants.get("SELECT", []) + + assert len(physical_tables) > 0 + for physical_table in physical_tables: + physical_table_name = f"sqlmesh__test_schema.{physical_table.name}" + physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(physical_table_name, dialect=engine_adapter.dialect) + ) + assert 
roles["grantee"]["username"] not in physical_grants.get("SELECT", []) + + elif grants_target_layer == "physical": + # Virtual layer should not have grants, physical should + assert roles["grantee"]["username"] not in virtual_grants.get("SELECT", []) + + assert len(physical_tables) > 0 + for physical_table in physical_tables: + physical_table_name = f"sqlmesh__test_schema.{physical_table.name}" + physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(physical_table_name, dialect=engine_adapter.dialect) + ) + assert roles["grantee"]["username"] in physical_grants.get("SELECT", []) + + else: # grants_target_layer == "all" + # Both layers should have grants + assert roles["grantee"]["username"] in virtual_grants.get("SELECT", []) + assert len(physical_tables) > 0 + for physical_table in physical_tables: + physical_table_name = f"sqlmesh__test_schema.{physical_table.name}" + physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(physical_table_name, dialect=engine_adapter.dialect) + ) + assert roles["grantee"]["username"] in physical_grants.get("SELECT", []) + + +@pytest.mark.parametrize( + "model_kind", + [ + "SCD_TYPE_2", + "SCD_TYPE_2_BY_TIME", + ], +) +def test_grants_plan_scd_type_2_models( + engine_adapter: PostgresEngineAdapter, + ctx: TestContext, + tmp_path: Path, + model_kind: str, +): + with create_users(engine_adapter, "reader", "writer", "analyst") as roles: + (tmp_path / "models").mkdir(exist_ok=True) + model_name = "scd_model" + + kind_config = f"{model_kind} (unique_key [id])" + model_definition = f""" + MODEL ( + name test_schema.{model_name}, + kind {kind_config}, + grants ( + 'SELECT' = ['{roles["reader"]["username"]}'], + 'INSERT' = ['{roles["writer"]["username"]}'] + ), + grants_target_layer 'all' + ); + SELECT 1 as id, 'initial_data' as name, CURRENT_TIMESTAMP as updated_at + """ + (tmp_path / "models" / f"{model_name}.sql").write_text(model_definition) + + context = ctx.create_context(path=tmp_path) + 
plan_result = context.plan( + "dev", start="2023-01-01", end="2023-01-01", auto_apply=True, no_prompts=True + ) + assert len(plan_result.new_snapshots) == 1 + + current_snapshot = plan_result.new_snapshots[0] + fingerprint_version = current_snapshot.fingerprint.to_version() + physical_table_name = ( + f"sqlmesh__test_schema.test_schema__{model_name}__{fingerprint_version}__dev" + ) + physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(physical_table_name, dialect=engine_adapter.dialect) + ) + assert physical_grants.get("SELECT", []) == [roles["reader"]["username"]] + assert physical_grants.get("INSERT", []) == [roles["writer"]["username"]] + + view_name = f"test_schema__dev.{model_name}" + view_grants = engine_adapter._get_current_grants_config( + exp.to_table(view_name, dialect=engine_adapter.dialect) + ) + assert view_grants.get("SELECT", []) == [roles["reader"]["username"]] + assert view_grants.get("INSERT", []) == [roles["writer"]["username"]] + + # Data change + updated_model_definition = f""" + MODEL ( + name test_schema.{model_name}, + kind {kind_config}, + grants ( + 'SELECT' = ['{roles["reader"]["username"]}'], + 'INSERT' = ['{roles["writer"]["username"]}'] + ), + grants_target_layer 'all' + ); + SELECT 1 as id, 'updated_data' as name, CURRENT_TIMESTAMP as updated_at + """ + (tmp_path / "models" / f"{model_name}.sql").write_text(updated_model_definition) + + context.load() + context.plan("dev", start="2023-01-02", end="2023-01-02", auto_apply=True, no_prompts=True) + + snapshot = context.get_snapshot(f"test_schema.{model_name}") + assert snapshot + fingerprint = snapshot.fingerprint.to_version() + table_name = f"sqlmesh__test_schema.test_schema__{model_name}__{fingerprint}__dev" + data_change_grants = engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=engine_adapter.dialect) + ) + assert data_change_grants.get("SELECT", []) == [roles["reader"]["username"]] + assert data_change_grants.get("INSERT", []) == 
[roles["writer"]["username"]] + + # Data + grants changes + grant_change_model_definition = f""" + MODEL ( + name test_schema.{model_name}, + kind {kind_config}, + grants ( + 'SELECT' = ['{roles["reader"]["username"]}', '{roles["analyst"]["username"]}'], + 'INSERT' = ['{roles["writer"]["username"]}'], + 'UPDATE' = ['{roles["analyst"]["username"]}'] + ), + grants_target_layer 'all' + ); + SELECT 1 as id, 'grant_changed_data' as name, CURRENT_TIMESTAMP as updated_at + """ + (tmp_path / "models" / f"{model_name}.sql").write_text(grant_change_model_definition) + + context.load() + context.plan("dev", start="2023-01-03", end="2023-01-03", auto_apply=True, no_prompts=True) + + snapshot = context.get_snapshot(f"test_schema.{model_name}") + assert snapshot + fingerprint = snapshot.fingerprint.to_version() + table_name = f"sqlmesh__test_schema.test_schema__{model_name}__{fingerprint}__dev" + final_grants = engine_adapter._get_current_grants_config( + exp.to_table(table_name, dialect=engine_adapter.dialect) + ) + expected_select_users = {roles["reader"]["username"], roles["analyst"]["username"]} + assert set(final_grants.get("SELECT", [])) == expected_select_users + assert final_grants.get("INSERT", []) == [roles["writer"]["username"]] + assert final_grants.get("UPDATE", []) == [roles["analyst"]["username"]] + + final_view_grants = engine_adapter._get_current_grants_config( + exp.to_table(view_name, dialect=engine_adapter.dialect) + ) + assert set(final_view_grants.get("SELECT", [])) == expected_select_users + assert final_view_grants.get("INSERT", []) == [roles["writer"]["username"]] + assert final_view_grants.get("UPDATE", []) == [roles["analyst"]["username"]] + + +@pytest.mark.parametrize( + "model_kind", + [ + "SCD_TYPE_2", + "SCD_TYPE_2_BY_TIME", + ], +) +def test_grants_plan_scd_type_2_with_vde_dev_only( + engine_adapter: PostgresEngineAdapter, + ctx: TestContext, + tmp_path: Path, + model_kind: str, +): + with create_users(engine_adapter, "etl_user", "analyst") as 
roles: + (tmp_path / "models").mkdir(exist_ok=True) + model_name = "vde_scd_model" + + model_def = f""" + MODEL ( + name test_schema.{model_name}, + kind {model_kind} (unique_key [customer_id]), + grants ( + 'SELECT' = ['{roles["analyst"]["username"]}'], + 'INSERT' = ['{roles["etl_user"]["username"]}'] + ), + grants_target_layer 'all' + ); + SELECT + 1 as customer_id, + 'active' as status, + CURRENT_TIMESTAMP as updated_at + """ + (tmp_path / "models" / f"{model_name}.sql").write_text(model_def) + + context = ctx.create_context(path=tmp_path, config_mutator=_vde_dev_only_config) + + # Prod + context.plan("prod", auto_apply=True, no_prompts=True) + prod_table = f"test_schema.{model_name}" + prod_grants = engine_adapter._get_current_grants_config( + exp.to_table(prod_table, dialect=engine_adapter.dialect) + ) + assert roles["analyst"]["username"] in prod_grants.get("SELECT", []) + assert roles["etl_user"]["username"] in prod_grants.get("INSERT", []) + + # Dev + context.plan("dev", auto_apply=True, no_prompts=True, include_unmodified=True) + dev_view = f"test_schema__dev.{model_name}" + dev_grants = engine_adapter._get_current_grants_config( + exp.to_table(dev_view, dialect=engine_adapter.dialect) + ) + assert roles["analyst"]["username"] in dev_grants.get("SELECT", []) + assert roles["etl_user"]["username"] in dev_grants.get("INSERT", []) + + snapshot = context.get_snapshot(f"test_schema.{model_name}") + assert snapshot + fingerprint_version = snapshot.fingerprint.to_version() + dev_physical_table_name = ( + f"sqlmesh__test_schema.test_schema__{model_name}__{fingerprint_version}__dev" + ) + + dev_physical_grants = engine_adapter._get_current_grants_config( + exp.to_table(dev_physical_table_name, dialect=engine_adapter.dialect) + ) + assert roles["analyst"]["username"] in dev_physical_grants.get("SELECT", []) + assert roles["etl_user"]["username"] in dev_physical_grants.get("INSERT", []) diff --git a/tests/core/engine_adapter/integration/test_integration_redshift.py 
b/tests/core/engine_adapter/integration/test_integration_redshift.py new file mode 100644 index 0000000000..be5a47e714 --- /dev/null +++ b/tests/core/engine_adapter/integration/test_integration_redshift.py @@ -0,0 +1,107 @@ +import typing as t +import pytest +from pytest import FixtureRequest +from tests.core.engine_adapter.integration import TestContext +from sqlmesh.core.engine_adapter.redshift import RedshiftEngineAdapter +from sqlglot import exp + +from tests.core.engine_adapter.integration import ( + TestContext, + generate_pytest_params, + ENGINES_BY_NAME, + IntegrationTestEngine, +) + + +@pytest.fixture(params=list(generate_pytest_params(ENGINES_BY_NAME["redshift"]))) +def ctx( + request: FixtureRequest, + create_test_context: t.Callable[[IntegrationTestEngine, str, str], t.Iterable[TestContext]], +) -> t.Iterable[TestContext]: + yield from create_test_context(*request.param) + + +@pytest.fixture +def engine_adapter(ctx: TestContext) -> RedshiftEngineAdapter: + assert isinstance(ctx.engine_adapter, RedshiftEngineAdapter) + return ctx.engine_adapter + + +def test_columns(ctx: TestContext): + ctx.init() + + table = ctx.table("column_types") + col_strings = { + "char": ["char", "character", "nchar"], + "varchar": ["varchar", "character varying", "nvarchar"], + "varbinary": ["varbyte", "varbinary", "binary varying"], + "decimal": ["decimal", "numeric"], + } + + # raw ddl + sql = f"CREATE TABLE {table} (" + sql += ( + ", ".join( + f"{col.replace(' ', '_')}10 {col}(10)" + for col in [*col_strings["char"], *col_strings["varchar"], *col_strings["varbinary"]] + ) + + ", " + ) + # bare types that should have their default lengths of 1 added by columns() + sql += ", ".join(f"{col.replace(' ', '_')}1 {col}" for col in col_strings["char"]) + ", " + # bare types that should have their default lengths of 256 added by columns() + sql += ", ".join(f"{col.replace(' ', '_')}256 {col}" for col in col_strings["varchar"]) + ", " + sql += ( + ", ".join(f"{col.replace(' ', '_')}172 
{col}(17, 2)" for col in col_strings["decimal"]) + + ")" + ) + + ctx.engine_adapter.cursor.execute(sql) + columns = ctx.engine_adapter.columns(table) + + # columns to types + cols_to_types = { + f"{col.replace(' ', '_')}10": exp.DataType.build(f"{col}(10)", dialect=ctx.dialect) + for col in [*col_strings["char"], *col_strings["varchar"], *col_strings["varbinary"]] + } + cols_to_types.update( + { + f"{col.replace(' ', '_')}1": exp.DataType.build(f"{col}(1)", dialect=ctx.dialect) + for col in col_strings["char"] + } + ) + cols_to_types.update( + { + f"{col.replace(' ', '_')}256": exp.DataType.build(f"{col}(256)", dialect=ctx.dialect) + for col in col_strings["varchar"] + } + ) + cols_to_types.update( + { + f"{col.replace(' ', '_')}172": exp.DataType.build(f"{col}(17, 2)", dialect=ctx.dialect) + for col in col_strings["decimal"] + } + ) + + # did we convert the types from redshift correctly? + assert [col.sql(ctx.dialect) for col in columns.values()] == [ + col.sql(ctx.dialect) for col in cols_to_types.values() + ] + + # did we replace default char/varchar lengths with MAX correctly? 
+ max_cols = [col for col in columns if col.endswith("1") or col.endswith("256")] + assert [ + col.sql(ctx.dialect) + for col in ctx.engine_adapter._default_precision_to_max( # type: ignore + {k: columns[k] for k in max_cols} + ).values() + ] == ["CHAR(max)", "CHAR(max)", "CHAR(max)", "VARCHAR(max)", "VARCHAR(max)", "VARCHAR(max)"] + + +def test_fetch_native_df_respects_case_sensitivity(ctx: TestContext): + adapter = ctx.engine_adapter + adapter.execute("SET enable_case_sensitive_identifier TO true") + assert adapter.fetchdf('WITH t AS (SELECT 1 AS "C", 2 AS "c") SELECT * FROM t').to_dict() == { + "C": {0: 1}, + "c": {0: 2}, + } diff --git a/tests/core/engine_adapter/integration/test_integration_risingwave.py b/tests/core/engine_adapter/integration/test_integration_risingwave.py new file mode 100644 index 0000000000..76b3d20a7c --- /dev/null +++ b/tests/core/engine_adapter/integration/test_integration_risingwave.py @@ -0,0 +1,82 @@ +import typing as t +import pytest +from sqlglot import exp +from pytest import FixtureRequest +from sqlmesh.core.engine_adapter import RisingwaveEngineAdapter +from tests.core.engine_adapter.integration import ( + TestContext, + generate_pytest_params, + ENGINES_BY_NAME, + IntegrationTestEngine, +) + + +@pytest.fixture(params=list(generate_pytest_params(ENGINES_BY_NAME["risingwave"]))) +def ctx( + request: FixtureRequest, + create_test_context: t.Callable[[IntegrationTestEngine, str, str], t.Iterable], +) -> t.Iterable[TestContext]: + yield from create_test_context(*request.param) + + +@pytest.fixture +def engine_adapter(ctx: TestContext) -> RisingwaveEngineAdapter: + assert isinstance(ctx.engine_adapter, RisingwaveEngineAdapter) + return ctx.engine_adapter + + +@pytest.fixture +def risingwave_columns_with_datatypes(ctx: TestContext) -> t.Dict[str, exp.DataType]: + base_types = { + "smallint_col": exp.DataType.build(exp.DataType.Type.SMALLINT, nested=False), + "int_col": exp.DataType.build(exp.DataType.Type.INT, nested=False), + 
"bigint_col": exp.DataType.build(exp.DataType.Type.BIGINT, nested=False), + "ts_col": exp.DataType.build(exp.DataType.Type.TIMESTAMP, nested=False), + "tstz_col": exp.DataType.build(exp.DataType.Type.TIMESTAMPTZ, nested=False), + "vchar_col": exp.DataType.build(exp.DataType.Type.VARCHAR, nested=False), + } + # generate all arrays of base types + arr_types = { + f"{type_name}_arr_col": exp.DataType.build( + exp.DataType.Type.ARRAY, + expressions=[base_type], + nested=True, + ) + for type_name, base_type in base_types.items() + } + # generate struct with all base types as nested columns + struct_types = { + "struct_col": exp.DataType.build( + exp.DataType.Type.STRUCT, + expressions=[ + exp.ColumnDef( + this=exp.Identifier(this=f"nested_{type_name}_col", quoted=False), + kind=base_type, + ) + for type_name, base_type in base_types.items() + ], + nested=True, + ) + } + return {**base_types, **arr_types, **struct_types} + + +def test_engine_adapter(ctx: TestContext): + assert isinstance(ctx.engine_adapter, RisingwaveEngineAdapter) + assert ctx.engine_adapter.fetchone("select 1") == (1,) + + +def test_engine_adapter_columns( + ctx: TestContext, risingwave_columns_with_datatypes: t.Dict[str, exp.DataType] +): + table = ctx.table("TEST_COLUMNS") + query = exp.select( + *[ + exp.cast(exp.null(), dtype).as_(name) + for name, dtype in risingwave_columns_with_datatypes.items() + ] + ) + ctx.engine_adapter.ctas(table, query) + + column_result = ctx.engine_adapter.columns(table) + assert column_result == risingwave_columns_with_datatypes diff --git a/tests/core/engine_adapter/integration/test_integration_snowflake.py b/tests/core/engine_adapter/integration/test_integration_snowflake.py new file mode 100644 index 0000000000..f9862c51cb --- /dev/null +++ b/tests/core/engine_adapter/integration/test_integration_snowflake.py @@ -0,0 +1,380 @@ +import pytest +import typing as t +from datetime import datetime +from pathlib import Path +from pytest import FixtureRequest +from 
pytest_mock import MockerFixture + +import sqlmesh.core.dialect as d +from sqlglot import exp +from sqlmesh import Config, ExecutionContext, model +from sqlglot.helper import seq_get +from sqlglot.optimizer.qualify_columns import quote_identifiers +from sqlmesh.core.config import ModelDefaultsConfig +from sqlmesh.core.engine_adapter import SnowflakeEngineAdapter +from sqlmesh.core.engine_adapter.shared import DataObject +from sqlmesh.core.model import ModelKindName, SqlModel, load_sql_based_model +from sqlmesh.core.plan import Plan +from sqlmesh.core.snapshot import SnapshotId, SnapshotIdBatch +from sqlmesh.core.snapshot.execution_tracker import ( + QueryExecutionContext, + QueryExecutionTracker, +) +from tests.core.engine_adapter.integration import ( + TestContext, + generate_pytest_params, + ENGINES_BY_NAME, + IntegrationTestEngine, +) + + +@pytest.fixture( + params=list(generate_pytest_params(ENGINES_BY_NAME["snowflake"], show_variant_in_test_id=False)) +) +def ctx( + request: FixtureRequest, + create_test_context: t.Callable[[IntegrationTestEngine, str, str], t.Iterable[TestContext]], +) -> t.Iterable[TestContext]: + yield from create_test_context(*request.param) + + +@pytest.fixture +def engine_adapter(ctx: TestContext) -> SnowflakeEngineAdapter: + assert isinstance(ctx.engine_adapter, SnowflakeEngineAdapter) + return ctx.engine_adapter + + +def test_get_alter_expressions_includes_clustering( + ctx: TestContext, engine_adapter: SnowflakeEngineAdapter +): + clustered_table = ctx.table("clustered_table") + clustered_differently_table = ctx.table("clustered_differently_table") + normal_table = ctx.table("normal_table") + + engine_adapter.execute(f"CREATE TABLE {clustered_table} (c1 int, c2 timestamp) CLUSTER BY (c1)") + engine_adapter.execute( + f"CREATE TABLE {clustered_differently_table} (c1 int, c2 timestamp) CLUSTER BY (c1, to_date(c2))" + ) + engine_adapter.execute(f"CREATE TABLE {normal_table} (c1 int, c2 timestamp)") + + assert 
len(engine_adapter.get_alter_operations(normal_table, normal_table)) == 0 + assert len(engine_adapter.get_alter_operations(clustered_table, clustered_table)) == 0 + + # alter table drop clustered + clustered_to_normal = engine_adapter.get_alter_operations(clustered_table, normal_table) + assert len(clustered_to_normal) == 1 + assert ( + clustered_to_normal[0].expression.sql(dialect=ctx.dialect) + == f"ALTER TABLE {clustered_table} DROP CLUSTERING KEY" + ) + + # alter table add clustered + normal_to_clustered = engine_adapter.get_alter_operations(normal_table, clustered_table) + assert len(normal_to_clustered) == 1 + assert ( + normal_to_clustered[0].expression.sql(dialect=ctx.dialect) + == f"ALTER TABLE {normal_table} CLUSTER BY (c1)" + ) + + # alter table change clustering + clustered_to_clustered_differently = engine_adapter.get_alter_operations( + clustered_table, clustered_differently_table + ) + assert len(clustered_to_clustered_differently) == 1 + assert ( + clustered_to_clustered_differently[0].expression.sql(dialect=ctx.dialect) + == f"ALTER TABLE {clustered_table} CLUSTER BY (c1, TO_DATE(c2))" + ) + + # alter table change clustering + clustered_differently_to_clustered = engine_adapter.get_alter_operations( + clustered_differently_table, clustered_table + ) + assert len(clustered_differently_to_clustered) == 1 + assert ( + clustered_differently_to_clustered[0].expression.sql(dialect=ctx.dialect) + == f"ALTER TABLE {clustered_differently_table} CLUSTER BY (c1)" + ) + + +def test_mutating_clustered_by_forward_only( + ctx: TestContext, engine_adapter: SnowflakeEngineAdapter +): + model_name = ctx.table("TEST") + + sqlmesh = ctx.create_context() + + def _create_model(**kwargs: t.Any) -> SqlModel: + extra_props = "\n".join([f"{k} {v}," for k, v in kwargs.items()]) + return t.cast( + SqlModel, + load_sql_based_model( + d.parse( + f""" + MODEL ( + name {model_name}, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column PARTITIONDATE + ), + {extra_props} + start 
'2021-01-01', + cron '@daily', + dialect 'snowflake' + ); + + select 1 as ID, current_timestamp() as PARTITIONDATE + """ + ) + ), + ) + + def _get_data_object(table: exp.Table) -> DataObject: + data_object = seq_get(engine_adapter.get_data_objects(table.db, {table.name}), 0) + if not data_object: + raise ValueError(f"Expected metadata for {table}") + return data_object + + m1 = _create_model() + m2 = _create_model(clustered_by="PARTITIONDATE") + m3 = _create_model(clustered_by="(ID, PARTITIONDATE)") + + # Initial plan - non-clustered table + sqlmesh.upsert_model(m1) + plan_1: Plan = sqlmesh.plan(auto_apply=True, no_prompts=True) + assert len(plan_1.snapshots) == 1 + target_table_1 = exp.to_table(list(plan_1.snapshots.values())[0].table_name()) + quote_identifiers(target_table_1) + + assert not _get_data_object(target_table_1).is_clustered + + # Next plan - add clustering key (non-clustered -> clustered) + sqlmesh.upsert_model(m2) + plan_2: Plan = sqlmesh.plan(auto_apply=True, no_prompts=True, forward_only=True) + assert len(plan_2.snapshots) == 1 + target_table_2 = exp.to_table(list(plan_2.snapshots.values())[0].table_name()) + quote_identifiers(target_table_2) + + assert target_table_1 == target_table_2 + + metadata = _get_data_object(target_table_1) + assert metadata.is_clustered + assert metadata.clustering_key == 'LINEAR("PARTITIONDATE")' + + # Next plan - change clustering key (clustered -> clustered differently) + sqlmesh.upsert_model(m3) + plan_3: Plan = sqlmesh.plan(auto_apply=True, no_prompts=True, forward_only=True) + assert len(plan_3.snapshots) == 1 + target_table_3 = exp.to_table(list(plan_3.snapshots.values())[0].table_name()) + quote_identifiers(target_table_3) + + assert target_table_1 == target_table_3 + + metadata = _get_data_object(target_table_1) + assert metadata.is_clustered + assert metadata.clustering_key == 'LINEAR("ID", "PARTITIONDATE")' + + # Next plan - drop clustering key + sqlmesh.upsert_model(m1) + plan_4: Plan = 
sqlmesh.plan(auto_apply=True, no_prompts=True, forward_only=True) + assert len(plan_4.snapshots) == 1 + target_table_4 = exp.to_table(list(plan_4.snapshots.values())[0].table_name()) + quote_identifiers(target_table_4) + + assert target_table_1 == target_table_4 + + metadata = _get_data_object(target_table_1) + assert not metadata.is_clustered + + +def test_create_iceberg_table(ctx: TestContext) -> None: + # Note: this test relies on a default Catalog and External Volume being configured in Snowflake + # ref: https://docs.snowflake.com/en/user-guide/tables-iceberg-configure-catalog-integration#set-a-default-catalog-at-the-account-database-or-schema-level + # ref: https://docs.snowflake.com/en/user-guide/tables-iceberg-configure-external-volume#set-a-default-external-volume-at-the-account-database-or-schema-level + # This has been done on the Snowflake account used by CI + + model_name = ctx.table("TEST") + managed_model_name = ctx.table("TEST_DYNAMIC") + sqlmesh = ctx.create_context() + + model = load_sql_based_model( + d.parse(f""" + MODEL ( + name {model_name}, + kind FULL, + table_format iceberg, + dialect 'snowflake' + ); + + select 1 as "ID", 'foo' as "NAME"; + """) + ) + + managed_model = load_sql_based_model( + d.parse(f""" + MODEL ( + name {managed_model_name}, + kind MANAGED, + physical_properties ( + target_lag = '20 minutes' + ), + table_format iceberg, + dialect 'snowflake' + ); + + select "ID", "NAME" from {model_name}; + """) + ) + + sqlmesh.upsert_model(model) + sqlmesh.upsert_model(managed_model) + + result = sqlmesh.plan(auto_apply=True) + + assert len(result.new_snapshots) == 2 + + +def test_snowpark_concurrency(ctx: TestContext) -> None: + from snowflake.snowpark import DataFrame + + table = ctx.table("my_model") + + # this model will insert 10 records in batches of 1, with 4 batches at a time running concurrently + @model( + name=table.sql(), + kind=dict( + name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, + time_column="ds", + batch_size=1, + 
batch_concurrency=4, + ), + columns={"id": "int", "ds": "date"}, + start="2020-01-01", + end="2020-01-10", + ) + def execute(context: ExecutionContext, start: datetime, **kwargs) -> DataFrame: + if snowpark := context.snowpark: + return snowpark.create_dataframe([(start.day, start.date())], schema=["id", "ds"]) + + raise ValueError("Snowpark not present!") + + m = model.get_registry()[table.sql().lower()].model( + module_path=Path("."), path=Path("."), dialect="snowflake" + ) + + sqlmesh = ctx.create_context() + + # verify that we are actually running in multithreaded mode + assert sqlmesh.concurrent_tasks > 1 + assert ctx.engine_adapter._multithreaded + + sqlmesh.upsert_model(m) + + plan = sqlmesh.plan(auto_apply=True) + + assert len(plan.new_snapshots) == 1 + + query = exp.select("*").from_(table) + df = ctx.engine_adapter.fetchdf(query, quote_identifiers=True) + assert len(df) == 10 + + +def test_create_drop_catalog(ctx: TestContext, engine_adapter: SnowflakeEngineAdapter): + non_sqlmesh_managed_catalog = ctx.add_test_suffix("external_catalog") + sqlmesh_managed_catalog = ctx.add_test_suffix("env_dev") + + initial_catalog = engine_adapter.get_current_catalog() + assert initial_catalog + + ctx.create_catalog( + non_sqlmesh_managed_catalog + ) # create via TestContext so the sqlmesh_managed comment doesnt get added + ctx._catalogs.append(sqlmesh_managed_catalog) # so it still gets cleaned up if the test fails + + engine_adapter.create_catalog( + sqlmesh_managed_catalog + ) # create via EngineAdapter so the sqlmesh_managed comment is added + + def fetch_database_names() -> t.Set[str]: + engine_adapter.set_current_catalog(initial_catalog) + return { + str(r[0]) + for r in engine_adapter.fetchall( + f"select database_name from information_schema.databases where database_name like '%{ctx.test_id}'" + ) + } + + assert fetch_database_names() == {non_sqlmesh_managed_catalog, sqlmesh_managed_catalog} + + engine_adapter.drop_catalog( + non_sqlmesh_managed_catalog + ) # 
no-op: catalog is not SQLMesh-managed + assert fetch_database_names() == {non_sqlmesh_managed_catalog, sqlmesh_managed_catalog} + + engine_adapter.drop_catalog(sqlmesh_managed_catalog) # works, catalog is SQLMesh-managed + assert fetch_database_names() == {non_sqlmesh_managed_catalog} + + +def test_rows_tracker( + ctx: TestContext, engine_adapter: SnowflakeEngineAdapter, mocker: MockerFixture +): + sqlmesh = ctx.create_context() + tracker = QueryExecutionTracker() + + add_execution_spy = mocker.spy(QueryExecutionContext, "add_execution") + + with tracker.track_execution( + SnapshotIdBatch(snapshot_id=SnapshotId(name="a", identifier="a"), batch_id=0) + ): + # Snowflake doesn't report row counts for CTAS, so this should not be tracked + engine_adapter._create_table("a", exp.select("1 as id")) + + assert add_execution_spy.call_count == 0 + + stats = tracker.get_execution_stats( + SnapshotIdBatch(snapshot_id=SnapshotId(name="a", identifier="a"), batch_id=0) + ) + assert stats is not None + assert stats.total_rows_processed is None + assert stats.total_bytes_processed is None + + +def test_unit_test(tmp_path: Path, ctx: TestContext): + models_path = tmp_path / "models" + tests_path = tmp_path / "tests" + + models_path.mkdir() + tests_path.mkdir() + + test_payload = """ +test_dummy_model: + model: s.dummy + inputs: + s.src_table: + rows: + - c: 1 + outputs: + query: + - c: 1 + """ + + (models_path / "dummy_model.sql").write_text(f"MODEL (name s.dummy); SELECT c FROM s.src_table") + (tests_path / "test_dummy_model.yaml").write_text(test_payload) + + def _config_mutator(gateway_name: str, config: Config): + config.model_defaults = ModelDefaultsConfig(dialect="snowflake") + test_connection = config.gateways[gateway_name].connection.copy() # type: ignore + + # Force the database to lowercase to test that we normalize (if we didn't, the test would fail) + test_connection.database = test_connection.database.lower() # type: ignore + config.gateways[gateway_name].test_connection 
= test_connection + + sqlmesh = ctx.create_context(path=tmp_path, config_mutator=_config_mutator) + + test_conn = sqlmesh.config.get_test_connection(ctx.gateway) + assert test_conn.type_ == "snowflake" + + catalog = test_conn.get_catalog() + assert catalog is not None and catalog.islower() + + test_results = sqlmesh.test() + assert not test_results.errors diff --git a/tests/core/engine_adapter/integration/test_integration_trino.py b/tests/core/engine_adapter/integration/test_integration_trino.py new file mode 100644 index 0000000000..81313b2a8d --- /dev/null +++ b/tests/core/engine_adapter/integration/test_integration_trino.py @@ -0,0 +1,84 @@ +import typing as t +import pytest +from pytest import FixtureRequest +from pathlib import Path +from sqlmesh.core.engine_adapter import TrinoEngineAdapter +from tests.core.engine_adapter.integration import TestContext +from sqlglot import parse_one, exp +from tests.core.engine_adapter.integration import ( + TestContext, + generate_pytest_params, + ENGINES_BY_NAME, + IntegrationTestEngine, +) + + +@pytest.fixture(params=list(generate_pytest_params(ENGINES_BY_NAME["trino"]))) +def ctx( + request: FixtureRequest, + create_test_context: t.Callable[[IntegrationTestEngine, str, str], t.Iterable[TestContext]], +) -> t.Iterable[TestContext]: + yield from create_test_context(*request.param) + + +@pytest.fixture +def engine_adapter(ctx: TestContext) -> TrinoEngineAdapter: + assert isinstance(ctx.engine_adapter, TrinoEngineAdapter) + return ctx.engine_adapter + + +def test_macros_in_physical_properties( + tmp_path: Path, ctx: TestContext, engine_adapter: TrinoEngineAdapter +): + if "iceberg" not in ctx.gateway: + pytest.skip("This test only needs to be run once") + + models_dir = tmp_path / "models" + models_dir.mkdir(parents=True) + + schema = ctx.schema() + + with open(models_dir / "test_model.sql", "w") as f: + f.write( + """ + MODEL ( + name SCHEMA.test, + kind FULL, + physical_properties ( + location = 
@resolve_template('s3://trino/@{catalog_name}/@{schema_name}/@{table_name}'), + sorted_by = @if(@gateway = 'inttest_trino_iceberg', ARRAY['col_a'], ARRAY['col_b']) + ) + ); + + select 1 as col_a, 2 as col_b; + """.replace("SCHEMA", schema) + ) + + context = ctx.create_context(path=tmp_path) + assert len(context.models) == 1 + + plan_result = context.plan(auto_apply=True, no_prompts=True) + + assert len(plan_result.new_snapshots) == 1 + + snapshot = plan_result.new_snapshots[0] + + physical_table_str = snapshot.table_name() + physical_table = exp.to_table(physical_table_str) + create_sql = list(engine_adapter.fetchone(f"show create table {physical_table}") or [])[0] + + parsed_create_sql = parse_one(create_sql, dialect="trino") + + location_property = parsed_create_sql.find(exp.LocationProperty) + assert location_property + + assert "@{table_name}" not in location_property.sql(dialect="trino") + assert ( + location_property.text("this") + == f"s3://trino/{physical_table.catalog}/{physical_table.db}/{physical_table.name}" + ) + + sorted_by_property = next( + p for p in parsed_create_sql.find_all(exp.Property) if "sorted_by" in p.sql(dialect="trino") + ) + assert sorted_by_property.sql(dialect="trino") == "sorted_by=ARRAY['col_a ASC NULLS FIRST']" diff --git a/tests/core/engine_adapter/test_athena.py b/tests/core/engine_adapter/test_athena.py new file mode 100644 index 0000000000..66e84ae025 --- /dev/null +++ b/tests/core/engine_adapter/test_athena.py @@ -0,0 +1,575 @@ +import typing as t +import pytest +from unittest.mock import Mock +from pytest_mock import MockerFixture +import pandas as pd # noqa: TID253 + +from sqlglot import exp, parse_one +import sqlmesh.core.dialect as d +from sqlmesh.core.engine_adapter import AthenaEngineAdapter +from sqlmesh.core.engine_adapter.shared import DataObject +from sqlmesh.core.model import load_sql_based_model +from sqlmesh.core.model.definition import SqlModel +from sqlmesh.utils.errors import SQLMeshError +from 
sqlmesh.core.table_diff import TableDiff + +from tests.core.engine_adapter import to_sql_calls + +pytestmark = [pytest.mark.athena, pytest.mark.engine] + + +@pytest.fixture +def adapter(make_mocked_engine_adapter: t.Callable) -> AthenaEngineAdapter: + return make_mocked_engine_adapter(AthenaEngineAdapter) + + +@pytest.fixture +def table_diff(adapter: AthenaEngineAdapter) -> TableDiff: + return TableDiff( + adapter=adapter, + source="source_table", + target="target_table", + on=["id"], + ) + + +@pytest.mark.parametrize( + "config_s3_warehouse_location,table_properties,table,expected_location", + [ + # No s3_warehouse_location in config + (None, None, exp.to_table("schema.table"), None), + (None, {}, exp.to_table("schema.table"), None), + ( + None, + {"s3_base_location": exp.Literal.string("s3://some/location/")}, + exp.to_table("schema.table"), + "s3://some/location/table/", + ), + # Location set to bucket + ("s3://bucket", None, exp.to_table("schema.table"), "s3://bucket/schema/table/"), + ("s3://bucket", {}, exp.to_table("schema.table"), "s3://bucket/schema/table/"), + ("s3://bucket", None, exp.to_table("schema.table"), "s3://bucket/schema/table/"), + ( + "s3://bucket", + {"s3_base_location": exp.Literal.string("s3://some/location/")}, + exp.to_table("schema.table"), + "s3://some/location/table/", + ), + ("s3://bucket", {}, exp.Table(db=exp.Identifier(this="test")), "s3://bucket/test/"), + # Location set to bucket with prefix + ( + "s3://bucket/subpath/", + None, + exp.to_table("schema.table"), + "s3://bucket/subpath/schema/table/", + ), + ("s3://bucket/subpath/", None, exp.to_table("table"), "s3://bucket/subpath/table/"), + ( + "s3://bucket/subpath/", + None, + exp.to_table("catalog.schema.table"), + "s3://bucket/subpath/catalog/schema/table/", + ), + ( + "s3://bucket/subpath/", + None, + exp.Table(db=exp.Identifier(this="test")), + "s3://bucket/subpath/test/", + ), + ], +) +def test_table_location( + adapter: AthenaEngineAdapter, + config_s3_warehouse_location: 
t.Optional[str], + table_properties: t.Optional[t.Dict[str, exp.Expression]], + table: exp.Table, + expected_location: t.Optional[str], +) -> None: + adapter.s3_warehouse_location = config_s3_warehouse_location + if expected_location is None: + with pytest.raises(SQLMeshError, match=r"Cannot figure out location for table.*"): + adapter._table_location_or_raise(table_properties, table) + else: + location = adapter._table_location_or_raise( + table_properties, table + ).this.name # extract the unquoted location value from the LocationProperty + assert location == expected_location + + if table_properties is not None: + # this get consumed by _table_location because we dont want it to end up in a TBLPROPERTIES clause + assert "s3_base_location" not in table_properties + + +def test_create_schema(adapter: AthenaEngineAdapter) -> None: + adapter.create_schema("test") + + adapter.s3_warehouse_location = "s3://base" + adapter.create_schema("test") + + assert to_sql_calls(adapter) == [ + "CREATE SCHEMA IF NOT EXISTS `test`", + "CREATE SCHEMA IF NOT EXISTS `test` LOCATION 's3://base/test/'", + ] + + +def test_create_table_hive(adapter: AthenaEngineAdapter) -> None: + expressions = d.parse( + """ + MODEL ( + name test_table, + kind FULL, + partitioned_by (cola, colb), + storage_format parquet, + physical_properties ( + s3_base_location = 's3://foo', + has_encrypted_data = 'true' + ) + ); + + SELECT 1::timestamp AS cola, 2::varchar as colb, 'foo' as colc; + """ + ) + model: SqlModel = t.cast(SqlModel, load_sql_based_model(expressions)) + + adapter.create_table( + model.name, + target_columns_to_types=model.columns_to_types_or_raise, + table_properties=model.physical_properties, + partitioned_by=model.partitioned_by, + storage_format=model.storage_format, + ) + + assert to_sql_calls(adapter) == [ + "CREATE EXTERNAL TABLE IF NOT EXISTS `test_table` (`colc` STRING) PARTITIONED BY (`cola` TIMESTAMP, `colb` STRING) STORED AS PARQUET LOCATION 's3://foo/test_table/' TBLPROPERTIES 
('has_encrypted_data'='true')" + ] + + +def test_create_table_iceberg(adapter: AthenaEngineAdapter) -> None: + expressions = d.parse( + """ + MODEL ( + name test_table, + kind FULL, + partitioned_by (colc, bucket(16, cola)), + table_format iceberg, + storage_format parquet, + physical_properties ( + s3_base_location = 's3://foo' + ) + ); + + SELECT 1::timestamp AS cola, 2::varchar as colb, 'foo' as colc; + """ + ) + model: SqlModel = t.cast(SqlModel, load_sql_based_model(expressions)) + + adapter.create_table( + model.name, + target_columns_to_types=model.columns_to_types_or_raise, + table_properties=model.physical_properties, + partitioned_by=model.partitioned_by, + table_format=model.table_format, + storage_format=model.storage_format, + ) + + assert to_sql_calls(adapter) == [ + "CREATE TABLE IF NOT EXISTS `test_table` (`cola` TIMESTAMP, `colb` STRING, `colc` STRING) PARTITIONED BY (`colc`, BUCKET(16, `cola`)) LOCATION 's3://foo/test_table/' TBLPROPERTIES ('table_type'='iceberg', 'format'='parquet')" + ] + + +def test_create_table_no_location(adapter: AthenaEngineAdapter) -> None: + expressions = d.parse( + """ + MODEL ( + name test_table, + kind FULL + ); + + SELECT a::int FROM foo; + """ + ) + model: SqlModel = t.cast(SqlModel, load_sql_based_model(expressions)) + + with pytest.raises(SQLMeshError, match=r"Cannot figure out location.*"): + adapter.create_table( + model.name, + target_columns_to_types=model.columns_to_types_or_raise, + table_properties=model.physical_properties, + ) + + adapter.s3_warehouse_location = "s3://bucket/prefix" + adapter.create_table( + model.name, + target_columns_to_types=model.columns_to_types_or_raise, + table_properties=model.physical_properties, + ) + + assert to_sql_calls(adapter) == [ + "CREATE EXTERNAL TABLE IF NOT EXISTS `test_table` (`a` INT) LOCATION 's3://bucket/prefix/test_table/'", + ] + + +def test_ctas_hive(adapter: AthenaEngineAdapter): + adapter.s3_warehouse_location = "s3://bucket/prefix/" + + adapter.ctas( + 
table_name="foo.bar", + target_columns_to_types={"a": exp.DataType.build("int")}, + query_or_df=parse_one("select 1", into=exp.Select), + ) + + assert to_sql_calls(adapter) == [ + 'CREATE TABLE IF NOT EXISTS "foo"."bar" WITH (external_location=\'s3://bucket/prefix/foo/bar/\') AS SELECT CAST("a" AS INTEGER) AS "a" FROM (SELECT 1) AS "_subquery"' + ] + + +def test_ctas_iceberg(adapter: AthenaEngineAdapter): + adapter.s3_warehouse_location = "s3://bucket/prefix/" + + adapter.ctas( + table_name="foo.bar", + target_columns_to_types={"a": exp.DataType.build("int")}, + query_or_df=parse_one("select 1", into=exp.Select), + table_format="iceberg", + ) + + assert to_sql_calls(adapter) == [ + 'CREATE TABLE IF NOT EXISTS "foo"."bar" WITH (table_type=\'iceberg\', location=\'s3://bucket/prefix/foo/bar/\', is_external=false) AS SELECT CAST("a" AS INTEGER) AS "a" FROM (SELECT 1) AS "_subquery"' + ] + + +def test_ctas_iceberg_no_specific_location(adapter: AthenaEngineAdapter): + with pytest.raises(SQLMeshError, match=r"Cannot figure out location.*"): + adapter.ctas( + table_name="foo.bar", + target_columns_to_types={"a": exp.DataType.build("int")}, + query_or_df=parse_one("select 1", into=exp.Select), + table_properties={"table_type": exp.Literal.string("iceberg")}, + ) + + assert to_sql_calls(adapter) == [] + + +def test_ctas_iceberg_partitioned(adapter: AthenaEngineAdapter): + expressions = d.parse( + """ + MODEL ( + name test_table, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column business_date + ), + table_format iceberg, + start '2025-01-15' + ); + + SELECT 1::timestamp AS business_date, 2::varchar as colb, 'foo' as colc; + """ + ) + model: SqlModel = t.cast(SqlModel, load_sql_based_model(expressions)) + + adapter.s3_warehouse_location = "s3://bucket/prefix/" + adapter.ctas( + table_name=model.name, + target_columns_to_types=model.columns_to_types, + partitioned_by=model.partitioned_by, + query_or_df=model.ctas_query(), + table_format=model.table_format, + ) + + assert 
to_sql_calls(adapter) == [ + """CREATE TABLE IF NOT EXISTS "test_table" WITH (table_type='iceberg', partitioning=ARRAY['business_date'], location='s3://bucket/prefix/test_table/', is_external=false) AS SELECT CAST("business_date" AS TIMESTAMP) AS "business_date", CAST("colb" AS VARCHAR) AS "colb", CAST("colc" AS VARCHAR) AS "colc" FROM (SELECT CAST(1 AS TIMESTAMP) AS "business_date", CAST(2 AS VARCHAR) AS "colb", 'foo' AS "colc" LIMIT 0) AS "_subquery\"""" + ] + + +def test_replace_query(adapter: AthenaEngineAdapter, mocker: MockerFixture): + mocker.patch( + "sqlmesh.core.engine_adapter.athena.AthenaEngineAdapter.table_exists", return_value=True + ) + mocker.patch( + "sqlmesh.core.engine_adapter.athena.AthenaEngineAdapter._query_table_type", + return_value="iceberg", + ) + mocker.patch.object( + adapter, + "_get_data_objects", + return_value=[DataObject(schema="", name="test", type="table")], + ) + + adapter.replace_query( + table_name="test", + query_or_df=parse_one("select 1 as a", into=exp.Select), + target_columns_to_types={"a": exp.DataType.build("int")}, + table_properties={}, + ) + + assert to_sql_calls(adapter) == [ + 'DELETE FROM "test" WHERE TRUE', + 'INSERT INTO "test" ("a") SELECT 1 AS "a"', + ] + + mocker.patch( + "sqlmesh.core.engine_adapter.athena.AthenaEngineAdapter.table_exists", return_value=False + ) + mocker.patch.object(adapter, "_get_data_objects", return_value=[]) + adapter.cursor.execute.reset_mock() + adapter._clear_data_object_cache() + + adapter.s3_warehouse_location = "s3://foo" + adapter.replace_query( + table_name="test", + query_or_df=parse_one("select 1 as a", into=exp.Select), + target_columns_to_types={"a": exp.DataType.build("int")}, + table_properties={}, + ) + + # gets recreated as a Hive table because table_exists=False and nothing in the properties indicates it should be Iceberg + assert to_sql_calls(adapter) == [ + 'CREATE TABLE IF NOT EXISTS "test" WITH (external_location=\'s3://foo/test/\') AS SELECT CAST("a" AS INTEGER) AS 
"a" FROM (SELECT 1 AS "a") AS "_subquery"' + ] + + +def test_columns(adapter: AthenaEngineAdapter, mocker: MockerFixture): + mock = mocker.patch( + "pandas.io.sql.read_sql_query", + return_value=pd.DataFrame( + data=[["col1", "int"], ["col2", "varchar"]], columns=["column_name", "data_type"] + ), + ) + + assert adapter.columns("foo.bar") == { + "col1": exp.DataType.build("int"), + "col2": exp.DataType.build("varchar"), + } + + assert ( + mock.call_args_list[0][0][0] + == """SELECT "column_name", "data_type" FROM "information_schema"."columns" WHERE "table_schema" = 'foo' AND "table_name" = 'bar' ORDER BY "ordinal_position" NULLS FIRST""" + ) + + +def test_truncate_table_iceberg(adapter: AthenaEngineAdapter, mocker: MockerFixture): + mocker.patch.object( + adapter, + "_query_table_type", + return_value="iceberg", + ) + mocker.patch.multiple( + adapter, _clear_partition_data=mocker.DEFAULT, _clear_s3_location=mocker.DEFAULT + ) + adapter._truncate_table(exp.to_table("foo.bar")) + + assert to_sql_calls(adapter) == ['DELETE FROM "foo"."bar" WHERE TRUE'] + t.cast(Mock, adapter._clear_partition_data).assert_not_called() + t.cast(Mock, adapter._clear_s3_location).assert_not_called() + + +def test_truncate_table_hive(adapter: AthenaEngineAdapter, mocker: MockerFixture): + mocker.patch.object( + adapter, + "_query_table_type", + return_value="hive", + ) + mocker.patch.object( + adapter, + "_is_hive_partitioned_table", + return_value=False, + ) + mocker.patch.object(adapter, "_query_table_s3_location", return_value="s3://foo/bar") + mocker.patch.multiple( + adapter, _clear_partition_data=mocker.DEFAULT, _clear_s3_location=mocker.DEFAULT + ) + + adapter._truncate_table(exp.to_table("foo.bar")) + + assert to_sql_calls(adapter) == [] + t.cast(Mock, adapter._clear_partition_data).assert_not_called() + t.cast(Mock, adapter._clear_s3_location).assert_called_with("s3://foo/bar") + + +def test_truncate_table_hive_partitioned(adapter: AthenaEngineAdapter, mocker: MockerFixture): + 
mocker.patch.object( + adapter, + "_query_table_type", + return_value="hive", + ) + mocker.patch.object( + adapter, + "_is_hive_partitioned_table", + return_value=True, + ) + mocker.patch.object(adapter, "_clear_partition_data") + mocker.patch.object(adapter, "_clear_s3_location") + adapter._truncate_table(exp.to_table("foo.bar")) + + assert to_sql_calls(adapter) == [] + t.cast(Mock, adapter._clear_partition_data).assert_called_with( + exp.to_table("foo.bar"), exp.true() + ) + t.cast(Mock, adapter._clear_s3_location).assert_not_called() + + +def test_create_state_table(adapter: AthenaEngineAdapter): + adapter.s3_warehouse_location = "s3://base" + adapter.create_state_table("_snapshots", {"name": exp.DataType.build("varchar")}) + + assert to_sql_calls(adapter) == [ + "CREATE TABLE IF NOT EXISTS `_snapshots` (`name` STRING) LOCATION 's3://base/_snapshots/' TBLPROPERTIES ('table_type'='iceberg')" + ] + + +def test_drop_partitions_from_metastore_uses_batches( + adapter: AthenaEngineAdapter, mocker: MockerFixture +): + glue_client_mock = mocker.patch.object(AthenaEngineAdapter, "_glue_client", autospec=True) + + glue_client_mock.batch_delete_partition.assert_not_called() + + partition_values = [] + + for i in range(63): + partition_values.append([str(i)]) + + adapter._drop_partitions_from_metastore( + table=exp.table_("foo"), partition_values=partition_values + ) + + glue_client_mock.batch_delete_partition.assert_called() + + # should have been called in batches of 25 + calls = glue_client_mock.batch_delete_partition.call_args_list + assert len(calls) == 3 + + assert len(calls[0][1]["PartitionsToDelete"]) == 25 + assert len(calls[1][1]["PartitionsToDelete"]) == 25 + assert len(calls[2][1]["PartitionsToDelete"]) == 13 + + # first call 0-24 + assert calls[0][1]["PartitionsToDelete"][0]["Values"][0] == "0" + assert calls[0][1]["PartitionsToDelete"][-1]["Values"][0] == "24" + + # second call 25-49 + assert calls[1][1]["PartitionsToDelete"][0]["Values"][0] == "25" + assert 
calls[1][1]["PartitionsToDelete"][-1]["Values"][0] == "49" + + # third call 50-62 + assert calls[2][1]["PartitionsToDelete"][0]["Values"][0] == "50" + assert calls[2][1]["PartitionsToDelete"][-1]["Values"][0] == "62" + + +def test_iceberg_partition_transforms(adapter: AthenaEngineAdapter): + expressions = d.parse( + """ + MODEL ( + name test_table, + kind FULL, + table_format iceberg, + partitioned_by (month(business_date), bucket(4, colb), colc) + ); + + SELECT 1::timestamp AS business_date, 2::varchar as colb, 'foo' as colc; + """ + ) + model: SqlModel = t.cast(SqlModel, load_sql_based_model(expressions)) + + assert model.partitioned_by == [ + exp.Month(this=exp.column("business_date", quoted=True)), + exp.PartitionedByBucket( + this=exp.column("colb", quoted=True), expression=exp.Literal.number(4) + ), + exp.column("colc", quoted=True), + ] + + adapter.s3_warehouse_location = "s3://bucket/prefix/" + + adapter.create_table( + table_name=model.name, + target_columns_to_types=model.columns_to_types_or_raise, + partitioned_by=model.partitioned_by, + table_format=model.table_format, + ) + + adapter.ctas( + table_name=model.name, + target_columns_to_types=model.columns_to_types_or_raise, + partitioned_by=model.partitioned_by, + query_or_df=model.ctas_query(), + table_format=model.table_format, + ) + + assert to_sql_calls(adapter) == [ + # Hive syntax - create table + """CREATE TABLE IF NOT EXISTS `test_table` (`business_date` TIMESTAMP, `colb` STRING, `colc` STRING) PARTITIONED BY (MONTH(`business_date`), BUCKET(4, `colb`), `colc`) LOCATION 's3://bucket/prefix/test_table/' TBLPROPERTIES ('table_type'='iceberg')""", + # Trino syntax - CTAS + """CREATE TABLE IF NOT EXISTS "test_table" WITH (table_type='iceberg', partitioning=ARRAY['MONTH(business_date)', 'BUCKET(colb, 4)', 'colc'], location='s3://bucket/prefix/test_table/', is_external=false) AS SELECT CAST("business_date" AS TIMESTAMP) AS "business_date", CAST("colb" AS VARCHAR) AS "colb", CAST("colc" AS VARCHAR) AS 
"colc" FROM (SELECT CAST(1 AS TIMESTAMP) AS "business_date", CAST(2 AS VARCHAR) AS "colb", 'foo' AS "colc" LIMIT 0) AS "_subquery\"""", + ] + + +@pytest.mark.parametrize( + "source_format, target_format, expected_temp_format, expect_error", + [ + ("hive", "hive", None, False), + ("iceberg", "hive", None, True), # Expect error for mismatched formats + ("hive", "iceberg", None, True), # Expect error for mismatched formats + ("iceberg", "iceberg", "iceberg", False), + (None, "iceberg", None, True), # Source doesn't exist or type unknown, target is iceberg + ( + "iceberg", + None, + "iceberg", + True, + ), # Target doesn't exist or type unknown, source is iceberg + (None, "hive", None, False), # Source doesn't exist or type unknown, target is hive + ("hive", None, None, False), # Target doesn't exist or type unknown, source is hive + (None, None, None, False), # Both don't exist or types unknown + ], +) +def test_table_diff_temp_table_format( + table_diff: TableDiff, + mocker: MockerFixture, + source_format: t.Optional[str], + target_format: t.Optional[str], + expected_temp_format: t.Optional[str], + expect_error: bool, +): + adapter = t.cast(AthenaEngineAdapter, table_diff.adapter) + + # Mock _query_table_type to return specified formats + def mock_query_table_type(table_name: exp.Table) -> t.Optional[str]: + if table_name.name == "source_table": + return source_format + if table_name.name == "target_table": + return target_format + return "hive" # Default for other tables if any + + mocker.patch.object(adapter, "_query_table_type", side_effect=mock_query_table_type) + + # Mock temp_table to capture kwargs + mock_temp_table = mocker.patch.object(adapter, "temp_table", autospec=True) + mock_temp_table.return_value.__enter__.return_value = exp.to_table("diff_table") + + # Mock fetchdf and other calls made within row_diff to avoid actual DB interaction + mocker.patch.object(adapter, "fetchdf", return_value=pd.DataFrame()) + mocker.patch.object(adapter, 
"get_data_objects", return_value=[]) + mocker.patch.object(adapter, "columns", return_value={"id": exp.DataType.build("int")}) + + if expect_error: + with pytest.raises( + SQLMeshError, + match="do not match for Athena. Diffing between different table formats is not supported.", + ): + table_diff.row_diff() + mock_temp_table.assert_not_called() # temp_table should not be called if formats mismatch + return + + try: + table_diff.row_diff() + except Exception: + pass # We only care about the temp_table call args for non-error cases + + mock_temp_table.assert_called_once() + _, called_kwargs = mock_temp_table.call_args + + if expected_temp_format: + assert called_kwargs.get("table_format") == expected_temp_format + else: + assert "table_format" not in called_kwargs diff --git a/tests/core/engine_adapter/test_base.py b/tests/core/engine_adapter/test_base.py index 9f6b2278be..2b9bcc665f 100644 --- a/tests/core/engine_adapter/test_base.py +++ b/tests/core/engine_adapter/test_base.py @@ -3,7 +3,7 @@ from datetime import datetime from unittest.mock import call -import pandas as pd +import pandas as pd # noqa: TID253 import pytest from pytest_mock.plugin import MockerFixture from sqlglot import expressions as exp @@ -13,13 +13,14 @@ from sqlmesh.core import dialect as d from sqlmesh.core.dialect import normalize_model_name from sqlmesh.core.engine_adapter import EngineAdapter, EngineAdapterWithIndexSupport -from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy -from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation +from sqlmesh.core.engine_adapter.shared import InsertOverwriteStrategy, DataObject +from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterOperation, NestedSupport from sqlmesh.utils import columns_to_types_to_struct from sqlmesh.utils.date import to_ds from sqlmesh.utils.errors import SQLMeshError, UnsupportedCatalogOperationError from tests.core.engine_adapter import to_sql_calls + pytestmark = pytest.mark.engine @@ -42,6 
+43,23 @@ def test_create_view(make_mocked_engine_adapter: t.Callable): ] +def test_create_view_existing_data_object_type_mismatch( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter) + mocker.patch.object( + adapter, + "_get_data_objects", + return_value=[DataObject(schema="", name="test_view", type="table")], + ) + adapter.create_view("test_view", parse_one("SELECT a FROM tbl")) + + assert to_sql_calls(adapter) == [ + 'DROP TABLE IF EXISTS "test_view"', + 'CREATE OR REPLACE VIEW "test_view" AS SELECT "a" FROM "tbl"', + ] + + def test_create_view_pandas(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) adapter.create_view("test_view", pd.DataFrame({"a": [1, 2, 3]}), replace=False) @@ -58,6 +76,39 @@ def test_create_view_pandas(make_mocked_engine_adapter: t.Callable): ] +def test_create_view_pandas_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + bigint_dtype = exp.DataType.build("BIGINT") + adapter.create_view( + "test_view", + pd.DataFrame({"a": [1, 2, 3], "ignored_source": [4, 5, 6]}), + target_columns_to_types={"a": bigint_dtype, "b": bigint_dtype}, + replace=False, + source_columns=["a", "ignored_source"], + ) + + assert to_sql_calls(adapter) == [ + 'CREATE VIEW "test_view" ("a", "b") AS SELECT "a", CAST(NULL AS BIGINT) AS "b" FROM (SELECT CAST("a" AS BIGINT) AS "a" FROM (VALUES (1), (2), (3)) AS "t"("a")) AS "select_source_columns"', + ] + + +def test_create_view_query_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.create_view( + "test_view", + parse_one("SELECT a, ignored_source FROM tbl"), + target_columns_to_types={ + "a": exp.DataType.build("BIGINT"), + "b": exp.DataType.build("BIGINT"), + }, + replace=False, + source_columns=["a", "ignored_source"], + ) + assert to_sql_calls(adapter) == [ + 'CREATE VIEW 
"test_view" ("a", "b") AS SELECT "a", CAST(NULL AS BIGINT) AS "b" FROM (SELECT "a", "ignored_source" FROM "tbl") AS "select_source_columns"', + ] + + def test_create_materialized_view(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) adapter.SUPPORTS_MATERIALIZED_VIEWS = True @@ -65,14 +116,14 @@ def test_create_materialized_view(make_mocked_engine_adapter: t.Callable): "test_view", parse_one("SELECT a FROM tbl"), materialized=True, - columns_to_types={"a": exp.DataType.build("INT")}, + target_columns_to_types={"a": exp.DataType.build("INT")}, ) adapter.create_view( "test_view", parse_one("SELECT a FROM tbl"), replace=False, materialized=True, - columns_to_types={"a": exp.DataType.build("INT")}, + target_columns_to_types={"a": exp.DataType.build("INT")}, ) adapter.cursor.execute.assert_has_calls( @@ -88,7 +139,7 @@ def test_create_materialized_view(make_mocked_engine_adapter: t.Callable): parse_one("SELECT a, b FROM tbl"), replace=False, materialized=True, - columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")}, + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")}, ) adapter.create_view( "test_view", parse_one("SELECT a, b FROM tbl"), replace=False, materialized=True @@ -172,7 +223,7 @@ def test_insert_overwrite_by_time_partition(make_mocked_engine_adapter: t.Callab end="2022-01-02", time_column="b", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, ) adapter.cursor.begin.assert_called_once() @@ -184,6 +235,37 @@ def test_insert_overwrite_by_time_partition(make_mocked_engine_adapter: t.Callab ] +def test_insert_overwrite_by_time_partition_missing_time_column_type( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = 
make_mocked_engine_adapter(EngineAdapter) + + columns_mock = mocker.patch.object(adapter, "columns") + columns_mock.return_value = {"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")} + + adapter.insert_overwrite_by_time_partition( + "test_table", + parse_one("SELECT a, b FROM tbl"), + start="2022-01-01", + end="2022-01-02", + time_column="b", + time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("UNKNOWN"), + }, + ) + + columns_mock.assert_called_once_with("test_table") + adapter.cursor.begin.assert_called_once() + adapter.cursor.commit.assert_called_once() + + assert to_sql_calls(adapter) == [ + """DELETE FROM "test_table" WHERE "b" BETWEEN '2022-01-01' AND '2022-01-02'""", + """INSERT INTO "test_table" ("a", "b") SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN '2022-01-01' AND '2022-01-02'""", + ] + + def test_insert_overwrite_by_time_partition_supports_insert_overwrite( make_mocked_engine_adapter: t.Callable, ): @@ -198,7 +280,7 @@ def test_insert_overwrite_by_time_partition_supports_insert_overwrite( end="2022-01-02", time_column="b", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, ) adapter.cursor.execute.assert_called_once_with( @@ -220,7 +302,10 @@ def test_insert_overwrite_by_time_partition_supports_insert_overwrite_pandas( end="2022-01-02", time_column="ds", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "ds": exp.DataType.build("STRING"), + }, ) assert to_sql_calls(adapter) == [ @@ -228,6 +313,53 @@ def 
test_insert_overwrite_by_time_partition_supports_insert_overwrite_pandas( ] +def test_insert_overwrite_by_time_partition_supports_insert_overwrite_pandas_source_columns( + make_mocked_engine_adapter: t.Callable, +): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.INSERT_OVERWRITE + df = pd.DataFrame({"a": [1, 2], "ignored_source": [3, 4]}) + adapter.insert_overwrite_by_time_partition( + "test_table", + df, + start="2022-01-01", + end="2022-01-02", + time_column="ds", + time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "ds": exp.DataType.build("STRING"), + }, + source_columns=["a", "ignored_source"], + ) + assert to_sql_calls(adapter) == [ + """INSERT OVERWRITE TABLE "test_table" ("a", "ds") SELECT "a", "ds" FROM (SELECT CAST("a" AS INT) AS "a", CAST(NULL AS TEXT) AS "ds" FROM (VALUES (1), (2)) AS "t"("a")) AS "_subquery" WHERE "ds" BETWEEN '2022-01-01' AND '2022-01-02'""" + ] + + +def test_insert_overwrite_by_time_partition_supports_insert_overwrite_query_source_columns( + make_mocked_engine_adapter: t.Callable, +): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.INSERT_OVERWRITE + adapter.insert_overwrite_by_time_partition( + "test_table", + parse_one("SELECT a, ignored_source FROM tbl"), + start="2022-01-01", + end="2022-01-02", + time_column="ds", + time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "ds": exp.DataType.build("STRING"), + }, + source_columns=["a", "ignored_source"], + ) + assert to_sql_calls(adapter) == [ + """INSERT OVERWRITE TABLE "test_table" ("a", "ds") SELECT "a", "ds" FROM (SELECT "a", CAST(NULL AS TEXT) AS "ds" FROM (SELECT "a", "ignored_source" FROM "tbl") AS "select_source_columns") AS "_subquery" WHERE "ds" BETWEEN '2022-01-01' AND '2022-01-02'""" + ] + + def 
test_insert_overwrite_by_time_partition_replace_where(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE @@ -239,7 +371,7 @@ def test_insert_overwrite_by_time_partition_replace_where(make_mocked_engine_ada end="2022-01-02", time_column="b", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, ) assert to_sql_calls(adapter) == [ @@ -262,7 +394,10 @@ def test_insert_overwrite_by_time_partition_replace_where_pandas( end="2022-01-02", time_column="ds", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "ds": exp.DataType.build("STRING"), + }, ) assert to_sql_calls(adapter) == [ @@ -270,6 +405,53 @@ def test_insert_overwrite_by_time_partition_replace_where_pandas( ] +def test_insert_overwrite_by_time_partition_replace_where_pandas_source_columns( + make_mocked_engine_adapter: t.Callable, +): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE + df = pd.DataFrame({"a": [1, 2], "ignored_source": [3, 4]}) + adapter.insert_overwrite_by_time_partition( + "test_table", + df, + start="2022-01-01", + end="2022-01-02", + time_column="ds", + time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "ds": exp.DataType.build("STRING"), + }, + source_columns=["a", "ignored_source"], + ) + assert to_sql_calls(adapter) == [ + """INSERT INTO "test_table" REPLACE WHERE "ds" BETWEEN '2022-01-01' AND '2022-01-02' SELECT "a", "ds" FROM (SELECT CAST("a" AS INT) AS "a", CAST(NULL AS TEXT) AS "ds" 
FROM (VALUES (1), (2)) AS "t"("a")) AS "_subquery" WHERE "ds" BETWEEN '2022-01-01' AND '2022-01-02'""" + ] + + +def test_insert_overwrite_by_time_partition_replace_where_query_source_columns( + make_mocked_engine_adapter: t.Callable, +): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE + adapter.insert_overwrite_by_time_partition( + "test_table", + parse_one("SELECT a, ignored_source FROM tbl"), + start="2022-01-01", + end="2022-01-02", + time_column="ds", + time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "ds": exp.DataType.build("STRING"), + }, + source_columns=["a", "ignored_source"], + ) + assert to_sql_calls(adapter) == [ + """INSERT INTO "test_table" REPLACE WHERE "ds" BETWEEN '2022-01-01' AND '2022-01-02' SELECT "a", "ds" FROM (SELECT "a", CAST(NULL AS TEXT) AS "ds" FROM (SELECT "a", "ignored_source" FROM "tbl") AS "select_source_columns") AS "_subquery" WHERE "ds" BETWEEN '2022-01-01' AND '2022-01-02'""" + ] + + def test_insert_overwrite_no_where(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) @@ -282,7 +464,7 @@ def test_insert_overwrite_no_where(make_mocked_engine_adapter: t.Callable): adapter._insert_overwrite_by_condition( "test_table", source_queries, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, ) adapter.cursor.begin.assert_called_once() @@ -294,13 +476,38 @@ def test_insert_overwrite_no_where(make_mocked_engine_adapter: t.Callable): ] +def test_insert_overwrite_by_condition_column_contains_unsafe_characters( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.MERGE + + source_queries, columns_to_types = adapter._get_source_queries_and_columns_to_types( + parse_one("SELECT 1 AS c"), None, 
target_table="test_table" + ) + + columns_mock = mocker.patch.object(adapter, "columns") + columns_mock.return_value = {"foo.bar.baz": exp.DataType.build("INT")} + + adapter._insert_overwrite_by_condition( + "test_table", + source_queries, + target_columns_to_types=None, + ) + + # The goal here is to assert that we don't parse `foo.bar.baz` into a qualified column + assert to_sql_calls(adapter) == [ + 'MERGE INTO "test_table" AS "__MERGE_TARGET__" USING (SELECT "foo.bar.baz" FROM (SELECT 1 AS "c") AS "_subquery") AS "__MERGE_SOURCE__" ON FALSE WHEN NOT MATCHED BY SOURCE THEN DELETE WHEN NOT MATCHED THEN INSERT ("foo.bar.baz") VALUES ("foo.bar.baz")' + ] + + def test_insert_append_query(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) adapter.insert_append( "test_table", parse_one("SELECT a FROM tbl"), - columns_to_types={"a": exp.DataType.build("INT")}, + target_columns_to_types={"a": exp.DataType.build("INT")}, ) assert to_sql_calls(adapter) == [ @@ -314,7 +521,7 @@ def test_insert_append_query_select_star(make_mocked_engine_adapter: t.Callable) adapter.insert_append( "test_table", parse_one("SELECT 1 AS a, * FROM tbl"), - columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")}, + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")}, ) assert to_sql_calls(adapter) == [ @@ -329,7 +536,7 @@ def test_insert_append_pandas(make_mocked_engine_adapter: t.Callable): adapter.insert_append( "test_table", df, - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("INT"), "b": exp.DataType.build("INT"), }, @@ -348,7 +555,7 @@ def test_insert_append_pandas_batches(make_mocked_engine_adapter: t.Callable): adapter.insert_append( "test_table", df, - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("INT"), "b": exp.DataType.build("INT"), }, @@ -364,6 +571,39 @@ def test_insert_append_pandas_batches(make_mocked_engine_adapter: 
t.Callable): ] +def test_insert_append_pandas_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + df = pd.DataFrame({"a": [1, 2, 3], "ignored_source": [4, 5, 6]}) + adapter.insert_append( + "test_table", + df, + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("INT"), + }, + source_columns=["a", "ignored_source"], + ) + assert to_sql_calls(adapter) == [ + 'INSERT INTO "test_table" ("a", "b") SELECT CAST("a" AS INT) AS "a", CAST(NULL AS INT) AS "b" FROM (VALUES (1), (2), (3)) AS "t"("a")', + ] + + +def test_insert_append_query_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.insert_append( + "test_table", + parse_one("SELECT a, ignored_source FROM tbl"), + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("INT"), + }, + source_columns=["a", "ignored_source"], + ) + assert to_sql_calls(adapter) == [ + 'INSERT INTO "test_table" ("a", "b") SELECT "a", CAST(NULL AS INT) AS "b" FROM (SELECT "a", "ignored_source" FROM "tbl") AS "select_source_columns"', + ] + + def test_create_table(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) @@ -473,7 +713,7 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture) ( { "support_positional_add": True, - "support_nested_operations": True, + "nested_support": NestedSupport.ALL, "array_element_selector": "element", }, { @@ -511,7 +751,27 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture) ), ( { - "support_nested_operations": True, + "coerceable_types": { + exp.DataType.build("FLOAT"): {exp.DataType.build("INT")}, + }, + }, + { + "a": "FLOAT", + "b": "TEXT", + }, + { + "a": "INT", + "b": "TEXT", + }, + { + "a": "FLOAT", + "b": "TEXT", + }, + [], + ), + ( + { + "nested_support": NestedSupport.ALL_BUT_DROP, "array_element_selector": 
"element", }, { @@ -629,7 +889,7 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture) ( { "support_positional_add": True, - "support_nested_operations": True, + "nested_support": NestedSupport.ALL, "array_element_selector": "element", }, { @@ -658,7 +918,7 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture) ( { "support_positional_add": True, - "support_nested_operations": True, + "nested_support": NestedSupport.ALL, "array_element_selector": "element", }, { @@ -714,7 +974,7 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture) # Test multiple operations on a column with no positional and nested features enabled ( { - "support_nested_operations": True, + "nested_support": NestedSupport.ALL, "array_element_selector": "element", }, { @@ -771,7 +1031,7 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture) # Test deeply nested structures ( { - "support_nested_operations": True, + "nested_support": NestedSupport.ALL, "array_element_selector": "element", }, { @@ -800,13 +1060,13 @@ def test_alter_table( ): adapter = make_mocked_engine_adapter(EngineAdapter) - adapter.SCHEMA_DIFFER = SchemaDiffer(**schema_differ_config) - original_from_structs = adapter.SCHEMA_DIFFER._from_structs + adapter.SCHEMA_DIFFER_KWARGS = schema_differ_config + original_from_structs = adapter.schema_differ._from_structs - def _from_structs( - current_struct: exp.DataType, new_struct: exp.DataType - ) -> t.List[TableAlterOperation]: - operations = original_from_structs(current_struct, new_struct) + def _from_structs(*args, **kwargs) -> t.List[TableAlterOperation]: + operations = original_from_structs(*args, **kwargs) + if not operations: + return operations assert ( operations[-1].expected_table_struct.sql() == columns_to_types_to_struct(expected_final_structure).sql() @@ -821,12 +1081,11 @@ def _from_structs( def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: if 
table_name == current_table_name: return {k: exp.DataType.build(v) for k, v in current_table.items()} - else: - return {k: exp.DataType.build(v) for k, v in target_table.items()} + return {k: exp.DataType.build(v) for k, v in target_table.items()} adapter.columns = table_columns - adapter.alter_table(adapter.get_alter_expressions(current_table_name, target_table_name)) + adapter.alter_table(adapter.get_alter_operations(current_table_name, target_table_name)) adapter.cursor.begin.assert_called_once() adapter.cursor.commit.assert_called_once() @@ -839,7 +1098,7 @@ def test_merge_upsert(make_mocked_engine_adapter: t.Callable, assert_exp_eq): adapter.merge( target_table="target", source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')), - columns_to_types={ + target_columns_to_types={ "ID": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), @@ -870,7 +1129,7 @@ def test_merge_upsert(make_mocked_engine_adapter: t.Callable, assert_exp_eq): adapter.merge( target_table="target", source_table=parse_one("SELECT id, ts, val FROM source"), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), @@ -887,11 +1146,11 @@ def test_merge_upsert(make_mocked_engine_adapter: t.Callable, assert_exp_eq): def test_merge_upsert_pandas(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) - df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df = pd.DataFrame({"id": [1, 2, 3], "ts": [4, 5, 6], "val": [1, 2, 3]}) adapter.merge( target_table="target", source_table=df, - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), @@ -899,7 +1158,7 @@ def test_merge_upsert_pandas(make_mocked_engine_adapter: t.Callable): unique_key=[exp.to_identifier("id")], ) 
adapter.cursor.execute.assert_called_once_with( - 'MERGE INTO "target" AS "__MERGE_TARGET__" USING (SELECT CAST("id" AS INT) AS "id", CAST("ts" AS TIMESTAMP) AS "ts", CAST("val" AS INT) AS "val" FROM (VALUES (1, 4), (2, 5), (3, 6)) AS "t"("id", "ts", "val")) AS "__MERGE_SOURCE__" ON "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id" ' + 'MERGE INTO "target" AS "__MERGE_TARGET__" USING (SELECT CAST("id" AS INT) AS "id", CAST("ts" AS TIMESTAMP) AS "ts", CAST("val" AS INT) AS "val" FROM (VALUES (1, 4, 1), (2, 5, 2), (3, 6, 3)) AS "t"("id", "ts", "val")) AS "__MERGE_SOURCE__" ON "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id" ' 'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id", "__MERGE_TARGET__"."ts" = "__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val" ' 'WHEN NOT MATCHED THEN INSERT ("id", "ts", "val") VALUES ("__MERGE_SOURCE__"."id", "__MERGE_SOURCE__"."ts", "__MERGE_SOURCE__"."val")' ) @@ -908,7 +1167,7 @@ def test_merge_upsert_pandas(make_mocked_engine_adapter: t.Callable): adapter.merge( target_table="target", source_table=df, - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), @@ -916,7 +1175,48 @@ def test_merge_upsert_pandas(make_mocked_engine_adapter: t.Callable): unique_key=[exp.to_identifier("id"), exp.to_identifier("ts")], ) adapter.cursor.execute.assert_called_once_with( - 'MERGE INTO "target" AS "__MERGE_TARGET__" USING (SELECT CAST("id" AS INT) AS "id", CAST("ts" AS TIMESTAMP) AS "ts", CAST("val" AS INT) AS "val" FROM (VALUES (1, 4), (2, 5), (3, 6)) AS "t"("id", "ts", "val")) AS "__MERGE_SOURCE__" ON "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id" AND "__MERGE_TARGET__"."ts" = "__MERGE_SOURCE__"."ts" ' + 'MERGE INTO "target" AS "__MERGE_TARGET__" USING (SELECT CAST("id" AS INT) AS "id", CAST("ts" AS TIMESTAMP) AS "ts", CAST("val" AS INT) AS "val" FROM (VALUES (1, 4, 1), (2, 5, 2), (3, 6, 3)) AS 
"t"("id", "ts", "val")) AS "__MERGE_SOURCE__" ON "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id" AND "__MERGE_TARGET__"."ts" = "__MERGE_SOURCE__"."ts" ' + 'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id", "__MERGE_TARGET__"."ts" = "__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val" ' + 'WHEN NOT MATCHED THEN INSERT ("id", "ts", "val") VALUES ("__MERGE_SOURCE__"."id", "__MERGE_SOURCE__"."ts", "__MERGE_SOURCE__"."val")' + ) + + +def test_merge_upsert_pandas_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + df = pd.DataFrame({"id": [1, 2, 3], "ts": [4, 5, 6], "ignored_source": [7, 8, 9]}) + adapter.merge( + target_table="target", + source_table=df, + target_columns_to_types={ + "id": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + "val": exp.DataType.build("int"), + }, + unique_key=[exp.to_identifier("id")], + source_columns=["id", "ignored_source", "ts"], + ) + adapter.cursor.execute.assert_called_once_with( + 'MERGE INTO "target" AS "__MERGE_TARGET__" USING (SELECT CAST("id" AS INT) AS "id", CAST("ts" AS TIMESTAMP) AS "ts", CAST(NULL AS INT) AS "val" FROM (VALUES (1, 4), (2, 5), (3, 6)) AS "t"("id", "ts")) AS "__MERGE_SOURCE__" ON "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id" ' + 'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id", "__MERGE_TARGET__"."ts" = "__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val" ' + 'WHEN NOT MATCHED THEN INSERT ("id", "ts", "val") VALUES ("__MERGE_SOURCE__"."id", "__MERGE_SOURCE__"."ts", "__MERGE_SOURCE__"."val")' + ) + + +def test_merge_upsert_query_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.merge( + target_table="target", + source_table=parse_one("SELECT id, ts, ignored_source FROM source"), + target_columns_to_types={ + "id": exp.DataType.build("int"), + "ts": 
exp.DataType.build("timestamp"), + "val": exp.DataType.build("int"), + }, + unique_key=[exp.to_identifier("id")], + source_columns=["id", "ts", "ignored_source"], + ) + adapter.cursor.execute.assert_called_once_with( + 'MERGE INTO "target" AS "__MERGE_TARGET__" USING (SELECT "id", "ts", CAST(NULL AS INT) AS "val" FROM (SELECT "id", "ts", "ignored_source" FROM "source") AS "select_source_columns") AS "__MERGE_SOURCE__" ON "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id" ' 'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."id" = "__MERGE_SOURCE__"."id", "__MERGE_TARGET__"."ts" = "__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val" ' 'WHEN NOT MATCHED THEN INSERT ("id", "ts", "val") VALUES ("__MERGE_SOURCE__"."id", "__MERGE_SOURCE__"."ts", "__MERGE_SOURCE__"."val")' ) @@ -928,26 +1228,32 @@ def test_merge_when_matched(make_mocked_engine_adapter: t.Callable, assert_exp_e adapter.merge( target_table="target", source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')), - columns_to_types={ + target_columns_to_types={ "ID": exp.DataType.build("int"), "ts": exp.DataType.build("timestamp"), "val": exp.DataType.build("int"), }, unique_key=[exp.to_identifier("ID", quoted=True)], - when_matched=exp.When( - matched=True, - source=False, - then=exp.Update( - expressions=[ - exp.column("val", "__MERGE_TARGET__").eq(exp.column("val", "__MERGE_SOURCE__")), - exp.column("ts", "__MERGE_TARGET__").eq( - exp.Coalesce( - this=exp.column("ts", "__MERGE_SOURCE__"), - expressions=[exp.column("ts", "__MERGE_TARGET__")], - ) + when_matched=exp.Whens( + expressions=[ + exp.When( + matched=True, + source=False, + then=exp.Update( + expressions=[ + exp.column("val", "__MERGE_TARGET__").eq( + exp.column("val", "__MERGE_SOURCE__") + ), + exp.column("ts", "__MERGE_TARGET__").eq( + exp.Coalesce( + this=exp.column("ts", "__MERGE_SOURCE__"), + expressions=[exp.column("ts", "__MERGE_TARGET__")], + ) + ), + ], ), - ], - ), + ) + ] ), ) @@ -969,40 +1275,184 @@ 
def test_merge_when_matched(make_mocked_engine_adapter: t.Callable, assert_exp_e ) -def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable): +def test_merge_when_matched_multiple(make_mocked_engine_adapter: t.Callable, assert_exp_eq): adapter = make_mocked_engine_adapter(EngineAdapter) - adapter.scd_type_2_by_time( + adapter.merge( target_table="target", - source_table=t.cast( - exp.Select, parse_one("SELECT id, name, price, test_UPDATED_at FROM source") - ), - unique_key=[ - parse_one("""COALESCE("id", '') || '|' || COALESCE("name", '')"""), - parse_one("""COALESCE("name", '')"""), - ], - valid_from_col=exp.column("test_valid_from", quoted=True), - valid_to_col=exp.column("test_valid_to", quoted=True), - updated_at_col=exp.column("test_UPDATED_at", quoted=True), - columns_to_types={ - "id": exp.DataType.build("INT"), - "name": exp.DataType.build("VARCHAR"), - "price": exp.DataType.build("DOUBLE"), - "test_UPDATED_at": exp.DataType.build("TIMESTAMP"), - "test_valid_from": exp.DataType.build("TIMESTAMP"), - "test_valid_to": exp.DataType.build("TIMESTAMP"), + source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')), + target_columns_to_types={ + "ID": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + "val": exp.DataType.build("int"), }, - execution_time=datetime(2020, 1, 1, 0, 0, 0), + unique_key=[exp.to_identifier("ID", quoted=True)], + when_matched=exp.Whens( + expressions=[ + exp.When( + matched=True, + condition=exp.column("ID", "__MERGE_SOURCE__").eq(exp.Literal.number(1)), + then=exp.Update( + expressions=[ + exp.column("val", "__MERGE_TARGET__").eq( + exp.column("val", "__MERGE_SOURCE__") + ), + exp.column("ts", "__MERGE_TARGET__").eq( + exp.Coalesce( + this=exp.column("ts", "__MERGE_SOURCE__"), + expressions=[exp.column("ts", "__MERGE_TARGET__")], + ) + ), + ], + ), + ), + exp.When( + matched=True, + source=False, + then=exp.Update( + expressions=[ + exp.column("val", "__MERGE_TARGET__").eq( + 
exp.column("val", "__MERGE_SOURCE__") + ), + exp.column("ts", "__MERGE_TARGET__").eq( + exp.Coalesce( + this=exp.column("ts", "__MERGE_SOURCE__"), + expressions=[exp.column("ts", "__MERGE_TARGET__")], + ) + ), + ], + ), + ), + ] + ), ) - assert ( - adapter.cursor.execute.call_args[0][0] - == parse_one( - """ -CREATE OR REPLACE TABLE "target" AS -WITH "source" AS ( - SELECT DISTINCT ON (COALESCE("id", '') || '|' || COALESCE("name", ''), COALESCE("name", '')) - TRUE AS "_exists", + assert_exp_eq( + adapter.cursor.execute.call_args[0][0], + """ +MERGE INTO "target" AS "__MERGE_TARGET__" USING ( + SELECT + "ID", + "ts", + "val" + FROM "source" +) AS "__MERGE_SOURCE__" + ON "__MERGE_TARGET__"."ID" = "__MERGE_SOURCE__"."ID" + WHEN MATCHED AND "__MERGE_SOURCE__"."ID" = 1 THEN UPDATE SET "__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val", "__MERGE_TARGET__"."ts" = COALESCE("__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."ts"), + WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val", "__MERGE_TARGET__"."ts" = COALESCE("__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."ts") + WHEN NOT MATCHED THEN INSERT ("ID", "ts", "val") + VALUES ("__MERGE_SOURCE__"."ID", "__MERGE_SOURCE__"."ts", "__MERGE_SOURCE__"."val") +""", + ) + + +def test_merge_filter(make_mocked_engine_adapter: t.Callable, assert_exp_eq): + adapter = make_mocked_engine_adapter(EngineAdapter) + + adapter.merge( + target_table="target", + source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')), + target_columns_to_types={ + "ID": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + "val": exp.DataType.build("int"), + }, + unique_key=[exp.to_identifier("ID", quoted=True)], + when_matched=exp.Whens( + expressions=[ + exp.When( + matched=True, + source=False, + then=exp.Update( + expressions=[ + exp.column("val", "__MERGE_TARGET__").eq( + exp.column("val", "__MERGE_SOURCE__") + ), + exp.column("ts", "__MERGE_TARGET__").eq( + exp.Coalesce( + this=exp.column("ts", 
"__MERGE_SOURCE__"), + expressions=[exp.column("ts", "__MERGE_TARGET__")], + ) + ), + ], + ), + ) + ] + ), + merge_filter=exp.And( + this=exp.GT( + this=exp.column("ID", "__MERGE_SOURCE__"), + expression=exp.Literal(this="0", is_string=False), + ), + expression=exp.LT( + this=exp.column("ts", "__MERGE_TARGET__"), + expression=exp.Timestamp(this=exp.column("2020-02-05", quoted=True)), + ), + ), + ) + + assert_exp_eq( + adapter.cursor.execute.call_args[0][0], + """ +MERGE INTO "target" AS "__MERGE_TARGET__" +USING ( + SELECT "ID", "ts", "val" + FROM "source" +) AS "__MERGE_SOURCE__" +ON ( + "__MERGE_SOURCE__"."ID" > 0 + AND "__MERGE_TARGET__"."ts" < TIMESTAMP("2020-02-05") +) +AND "__MERGE_TARGET__"."ID" = "__MERGE_SOURCE__"."ID" +WHEN MATCHED THEN + UPDATE SET + "__MERGE_TARGET__"."val" = "__MERGE_SOURCE__"."val", + "__MERGE_TARGET__"."ts" = COALESCE("__MERGE_SOURCE__"."ts", "__MERGE_TARGET__"."ts") +WHEN NOT MATCHED THEN + INSERT ("ID", "ts", "val") + VALUES ( + "__MERGE_SOURCE__"."ID", + "__MERGE_SOURCE__"."ts", + "__MERGE_SOURCE__"."val" + ); +""", + ) + + +def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + + adapter.scd_type_2_by_time( + target_table="target", + source_table=t.cast( + exp.Select, parse_one("SELECT id, name, price, test_UPDATED_at FROM source") + ), + unique_key=[ + parse_one("""COALESCE("id", '') || '|' || COALESCE("name", '')"""), + parse_one("""COALESCE("name", '')"""), + ], + valid_from_col=exp.column("test_valid_from", quoted=True), + valid_to_col=exp.column("test_valid_to", quoted=True), + updated_at_col=exp.column("test_UPDATED_at", quoted=True), + target_columns_to_types={ + "id": exp.DataType.build("INT"), + "name": exp.DataType.build("VARCHAR"), + "price": exp.DataType.build("DOUBLE"), + "test_UPDATED_at": exp.DataType.build("TIMESTAMP"), + "test_valid_from": exp.DataType.build("TIMESTAMP"), + "test_valid_to": exp.DataType.build("TIMESTAMP"), + }, + 
execution_time=datetime(2020, 1, 1, 0, 0, 0), + ) + + assert ( + adapter.cursor.execute.call_args[0][0] + == parse_one( + """ +CREATE OR REPLACE TABLE "target" AS +WITH "source" AS ( + SELECT DISTINCT ON (COALESCE("id", '') || '|' || COALESCE("name", ''), COALESCE("name", '')) + TRUE AS "_exists", "id", "name", "price", @@ -1069,7 +1519,7 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable): COALESCE("name", '') ), "joined" AS ( SELECT - "source"."_exists", + "source"."_exists" AS "_exists", "latest"."id" AS "t_id", "latest"."name" AS "t_name", "latest"."price" AS "t_price", @@ -1090,7 +1540,7 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable): AND COALESCE("latest"."name", '') = COALESCE("source"."name", '') UNION ALL SELECT - "source"."_exists", + "source"."_exists" AS "_exists", "latest"."id" AS "t_id", "latest"."name" AS "t_name", "latest"."price" AS "t_price", @@ -1129,8 +1579,8 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable): ELSE "t_test_valid_from" END AS "test_valid_from", CASE - WHEN "test_UPDATED_at" > "t_test_UPDATED_at" - THEN "test_UPDATED_at" + WHEN "joined"."test_UPDATED_at" > "joined"."t_test_UPDATED_at" + THEN "joined"."test_UPDATED_at" WHEN "joined"."_exists" IS NULL THEN CAST('2020-01-01 00:00:00' AS TIMESTAMP) ELSE "t_test_valid_to" @@ -1151,7 +1601,7 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable): CAST(NULL AS TIMESTAMP) AS "test_valid_to" FROM "joined" WHERE - "test_UPDATED_at" > "t_test_UPDATED_at" + "joined"."test_UPDATED_at" > "joined"."t_test_UPDATED_at" ) SELECT CAST("id" AS INT) AS "id", @@ -1170,56 +1620,69 @@ def test_scd_type_2_by_time(make_mocked_engine_adapter: t.Callable): ) -def test_scd_type_2_by_time_no_invalidate_hard_deletes(make_mocked_engine_adapter: t.Callable): +def test_scd_type_2_by_time_source_columns(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) - + df = pd.DataFrame( + { + "id": [1, 2, 
3], + "name": ["a", "b", "c"], + "test_UPDATED_at": [ + "2020-01-01 10:00:00", + "2020-01-02 15:00:00", + "2020-01-03 12:00:00", + ], + "ignored_source": [4, 5, 6], + } + ) adapter.scd_type_2_by_time( target_table="target", - source_table=t.cast( - exp.Select, parse_one("SELECT id, name, price, test_updated_at FROM source") - ), - unique_key=[exp.func("COALESCE", "id", "''")], + source_table=df, + unique_key=[exp.column("id")], valid_from_col=exp.column("test_valid_from", quoted=True), valid_to_col=exp.column("test_valid_to", quoted=True), - updated_at_col=exp.column("test_updated_at", quoted=True), - invalidate_hard_deletes=False, - columns_to_types={ + updated_at_col=exp.column("test_UPDATED_at", quoted=True), + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("VARCHAR"), "price": exp.DataType.build("DOUBLE"), - "test_updated_at": exp.DataType.build("TIMESTAMP"), + "test_UPDATED_at": exp.DataType.build("TIMESTAMP"), "test_valid_from": exp.DataType.build("TIMESTAMP"), "test_valid_to": exp.DataType.build("TIMESTAMP"), }, + source_columns=["id", "name", "test_UPDATED_at", "ignored_source"], execution_time=datetime(2020, 1, 1, 0, 0, 0), + start=datetime(2020, 1, 1, 0, 0, 0), + is_restatement=True, ) - + sql_calls = to_sql_calls(adapter) assert ( - adapter.cursor.execute.call_args[0][0] - == parse_one( - """ + parse_one(sql_calls[1]).sql() + == parse_one(""" CREATE OR REPLACE TABLE "target" AS WITH "source" AS ( - SELECT DISTINCT ON (COALESCE("id", '')) + SELECT DISTINCT ON ("id") TRUE AS "_exists", "id", "name", "price", - CAST("test_updated_at" AS TIMESTAMP) AS "test_updated_at" + CAST("test_UPDATED_at" AS TIMESTAMP) AS "test_UPDATED_at" FROM ( SELECT - "id", - "name", - "price", - "test_updated_at" - FROM "source" + CAST("id" AS INT) AS "id", + CAST("name" AS VARCHAR) AS "name", + CAST(NULL AS DOUBLE) AS "price", + CAST("test_UPDATED_at" AS TIMESTAMP) AS "test_UPDATED_at" + FROM (VALUES + (1, 'a', '2020-01-01 10:00:00'), + (2, 
'b', '2020-01-02 15:00:00'), + (3, 'c', '2020-01-03 12:00:00')) AS "t"("id", "name", "test_UPDATED_at") ) AS "raw_source" ), "static" AS ( SELECT "id", "name", "price", - "test_updated_at", + "test_UPDATED_at", "test_valid_from", "test_valid_to", TRUE AS "_exists" @@ -1231,7 +1694,7 @@ def test_scd_type_2_by_time_no_invalidate_hard_deletes(make_mocked_engine_adapte "id", "name", "price", - "test_updated_at", + "test_UPDATED_at", "test_valid_from", "test_valid_to", TRUE AS "_exists" @@ -1243,54 +1706,54 @@ def test_scd_type_2_by_time_no_invalidate_hard_deletes(make_mocked_engine_adapte "static"."id", "static"."name", "static"."price", - "static"."test_updated_at", + "static"."test_UPDATED_at", "static"."test_valid_from", "static"."test_valid_to" FROM "static" LEFT JOIN "latest" - ON COALESCE("static"."id", '') = COALESCE("latest"."id", '') + ON "static"."id" = "latest"."id" WHERE "latest"."test_valid_to" IS NULL ), "latest_deleted" AS ( SELECT TRUE AS "_exists", - COALESCE("id", '') AS "_key0", + "id" AS "_key0", MAX("test_valid_to") AS "test_valid_to" FROM "deleted" GROUP BY - COALESCE("id", '') + "id" ), "joined" AS ( SELECT - "source"."_exists", + "source"."_exists" AS "_exists", "latest"."id" AS "t_id", "latest"."name" AS "t_name", "latest"."price" AS "t_price", - "latest"."test_updated_at" AS "t_test_updated_at", + "latest"."test_UPDATED_at" AS "t_test_UPDATED_at", "latest"."test_valid_from" AS "t_test_valid_from", "latest"."test_valid_to" AS "t_test_valid_to", "source"."id" AS "id", "source"."name" AS "name", "source"."price" AS "price", - "source"."test_updated_at" AS "test_updated_at" + "source"."test_UPDATED_at" AS "test_UPDATED_at" FROM "latest" LEFT JOIN "source" - ON COALESCE("latest"."id", '') = COALESCE("source"."id", '') + ON "latest"."id" = "source"."id" UNION ALL SELECT - "source"."_exists", + "source"."_exists" AS "_exists", "latest"."id" AS "t_id", "latest"."name" AS "t_name", "latest"."price" AS "t_price", - "latest"."test_updated_at" AS 
"t_test_updated_at", + "latest"."test_UPDATED_at" AS "t_test_UPDATED_at", "latest"."test_valid_from" AS "t_test_valid_from", "latest"."test_valid_to" AS "t_test_valid_to", "source"."id" AS "id", "source"."name" AS "name", "source"."price" AS "price", - "source"."test_updated_at" AS "test_updated_at" + "source"."test_UPDATED_at" AS "test_UPDATED_at" FROM "latest" RIGHT JOIN "source" - ON COALESCE("latest"."id", '') = COALESCE("source"."id", '') + ON "latest"."id" = "source"."id" WHERE "latest"."_exists" IS NULL ), "updated_rows" AS ( @@ -1298,82 +1761,99 @@ def test_scd_type_2_by_time_no_invalidate_hard_deletes(make_mocked_engine_adapte COALESCE("joined"."t_id", "joined"."id") AS "id", COALESCE("joined"."t_name", "joined"."name") AS "name", COALESCE("joined"."t_price", "joined"."price") AS "price", - COALESCE("joined"."t_test_updated_at", "joined"."test_updated_at") AS "test_updated_at", + COALESCE("joined"."t_test_UPDATED_at", "joined"."test_UPDATED_at") AS "test_UPDATED_at", CASE WHEN "t_test_valid_from" IS NULL AND NOT "latest_deleted"."_exists" IS NULL THEN CASE - WHEN "latest_deleted"."test_valid_to" > "test_updated_at" + WHEN "latest_deleted"."test_valid_to" > "test_UPDATED_at" THEN "latest_deleted"."test_valid_to" - ELSE "test_updated_at" + ELSE "test_UPDATED_at" END WHEN "t_test_valid_from" IS NULL THEN CAST('1970-01-01 00:00:00' AS TIMESTAMP) ELSE "t_test_valid_from" END AS "test_valid_from", CASE - WHEN "test_updated_at" > "t_test_updated_at" - THEN "test_updated_at" + WHEN "joined"."test_UPDATED_at" > "joined"."t_test_UPDATED_at" + THEN "joined"."test_UPDATED_at" + WHEN "joined"."_exists" IS NULL + THEN CAST('2020-01-01 00:00:00' AS TIMESTAMP) ELSE "t_test_valid_to" END AS "test_valid_to" FROM "joined" LEFT JOIN "latest_deleted" - ON COALESCE("joined"."id", '') = "latest_deleted"."_key0" + ON "joined"."id" = "latest_deleted"."_key0" ), "inserted_rows" AS ( SELECT "id", "name", "price", - "test_updated_at", - "test_updated_at" AS "test_valid_from", + 
"test_UPDATED_at", + "test_UPDATED_at" AS "test_valid_from", CAST(NULL AS TIMESTAMP) AS "test_valid_to" FROM "joined" WHERE - "test_updated_at" > "t_test_updated_at" + "joined"."test_UPDATED_at" > "joined"."t_test_UPDATED_at" ) SELECT CAST("id" AS INT) AS "id", CAST("name" AS VARCHAR) AS "name", CAST("price" AS DOUBLE) AS "price", - CAST("test_updated_at" AS TIMESTAMP) AS "test_updated_at", + CAST("test_UPDATED_at" AS TIMESTAMP) AS "test_UPDATED_at", CAST("test_valid_from" AS TIMESTAMP) AS "test_valid_from", CAST("test_valid_to" AS TIMESTAMP) AS "test_valid_to" FROM ( - SELECT "id", "name", "price", "test_updated_at", "test_valid_from", "test_valid_to" FROM "static" - UNION ALL SELECT "id", "name", "price", "test_updated_at", "test_valid_from", "test_valid_to" FROM "updated_rows" - UNION ALL SELECT "id", "name", "price", "test_updated_at", "test_valid_from", "test_valid_to" FROM "inserted_rows" + SELECT + "id", + "name", + "price", + "test_UPDATED_at", + "test_valid_from", + "test_valid_to" + FROM "static" + UNION ALL + SELECT + "id", + "name", + "price", + "test_UPDATED_at", + "test_valid_from", + "test_valid_to" + FROM "updated_rows" + UNION ALL + SELECT + "id", + "name", + "price", + "test_UPDATED_at", + "test_valid_from", + "test_valid_to" + FROM "inserted_rows" ) AS "_subquery" - """ - ).sql() + """).sql() ) -def test_merge_scd_type_2_pandas(make_mocked_engine_adapter: t.Callable): +def test_scd_type_2_by_time_no_invalidate_hard_deletes(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) - df = pd.DataFrame( - { - "id1": [1, 2, 3], - "id2": [4, 5, 6], - "name": ["muffins", "chips", "soda"], - "price": [4.0, 5.0, 6.0], - "updated_at": ["2020-01-01 10:00:00", "2020-01-02 15:00:00", "2020-01-03 12:00:00"], - } - ) adapter.scd_type_2_by_time( target_table="target", - source_table=df, - unique_key=[exp.column("id1"), exp.column("id2")], + source_table=t.cast( + exp.Select, parse_one("SELECT id, name, price, test_updated_at 
FROM source") + ), + unique_key=[exp.func("COALESCE", "id", "''")], valid_from_col=exp.column("test_valid_from", quoted=True), valid_to_col=exp.column("test_valid_to", quoted=True), updated_at_col=exp.column("test_updated_at", quoted=True), - columns_to_types={ - "id1": exp.DataType.build("INT"), - "id2": exp.DataType.build("INT"), + invalidate_hard_deletes=False, + target_columns_to_types={ + "id": exp.DataType.build("INT"), "name": exp.DataType.build("VARCHAR"), "price": exp.DataType.build("DOUBLE"), - "test_updated_at": exp.DataType.build("TIMESTAMPTZ"), - "test_valid_from": exp.DataType.build("TIMESTAMPTZ"), - "test_valid_to": exp.DataType.build("TIMESTAMPTZ"), + "test_updated_at": exp.DataType.build("TIMESTAMP"), + "test_valid_from": exp.DataType.build("TIMESTAMP"), + "test_valid_to": exp.DataType.build("TIMESTAMP"), }, execution_time=datetime(2020, 1, 1, 0, 0, 0), ) @@ -1384,31 +1864,23 @@ def test_merge_scd_type_2_pandas(make_mocked_engine_adapter: t.Callable): """ CREATE OR REPLACE TABLE "target" AS WITH "source" AS ( - SELECT DISTINCT ON ("id1", "id2") + SELECT DISTINCT ON (COALESCE("id", '')) TRUE AS "_exists", - "id1", - "id2", + "id", "name", "price", - CAST("test_updated_at" AS TIMESTAMPTZ) AS "test_updated_at" + CAST("test_updated_at" AS TIMESTAMP) AS "test_updated_at" FROM ( SELECT - CAST("id1" AS INT) AS "id1", - CAST("id2" AS INT) AS "id2", - CAST("name" AS VARCHAR) AS "name", - CAST("price" AS DOUBLE) AS "price", - CAST("test_updated_at" AS TIMESTAMPTZ) AS "test_updated_at", - CAST("test_valid_from" AS TIMESTAMPTZ) AS "test_valid_from", - CAST("test_valid_to" AS TIMESTAMPTZ) AS "test_valid_to" - FROM (VALUES - (1, 4, 'muffins', 4.0, '2020-01-01 10:00:00'), - (2, 5, 'chips', 5.0, '2020-01-02 15:00:00'), - (3, 6, 'soda', 6.0, '2020-01-03 12:00:00')) AS "t"("id1", "id2", "name", "price", "test_updated_at", "test_valid_from", "test_valid_to") + "id", + "name", + "price", + "test_updated_at" + FROM "source" ) AS "raw_source" ), "static" AS ( SELECT - 
"id1", - "id2", + "id", "name", "price", "test_updated_at", @@ -1420,8 +1892,7 @@ def test_merge_scd_type_2_pandas(make_mocked_engine_adapter: t.Callable): NOT "test_valid_to" IS NULL ), "latest" AS ( SELECT - "id1", - "id2", + "id", "name", "price", "test_updated_at", @@ -1433,8 +1904,7 @@ def test_merge_scd_type_2_pandas(make_mocked_engine_adapter: t.Callable): "test_valid_to" IS NULL ), "deleted" AS ( SELECT - "static"."id1", - "static"."id2", + "static"."id", "static"."name", "static"."price", "static"."test_updated_at", @@ -1442,55 +1912,251 @@ def test_merge_scd_type_2_pandas(make_mocked_engine_adapter: t.Callable): "static"."test_valid_to" FROM "static" LEFT JOIN "latest" - ON "static"."id1" = "latest"."id1" AND "static"."id2" = "latest"."id2" + ON COALESCE("static"."id", '') = COALESCE("latest"."id", '') WHERE "latest"."test_valid_to" IS NULL ), "latest_deleted" AS ( SELECT TRUE AS "_exists", - "id1" AS "_key0", - "id2" AS "_key1", + COALESCE("id", '') AS "_key0", MAX("test_valid_to") AS "test_valid_to" FROM "deleted" GROUP BY - "id1", - "id2" + COALESCE("id", '') ), "joined" AS ( SELECT - "source"."_exists", - "latest"."id1" AS "t_id1", - "latest"."id2" AS "t_id2", + "source"."_exists" AS "_exists", + "latest"."id" AS "t_id", "latest"."name" AS "t_name", "latest"."price" AS "t_price", "latest"."test_updated_at" AS "t_test_updated_at", "latest"."test_valid_from" AS "t_test_valid_from", "latest"."test_valid_to" AS "t_test_valid_to", - "source"."id1" AS "id1", - "source"."id2" AS "id2", + "source"."id" AS "id", "source"."name" AS "name", "source"."price" AS "price", "source"."test_updated_at" AS "test_updated_at" FROM "latest" LEFT JOIN "source" - ON "latest"."id1" = "source"."id1" AND "latest"."id2" = "source"."id2" + ON COALESCE("latest"."id", '') = COALESCE("source"."id", '') UNION ALL SELECT - "source"."_exists", - "latest"."id1" AS "t_id1", - "latest"."id2" AS "t_id2", + "source"."_exists" AS "_exists", + "latest"."id" AS "t_id", "latest"."name" AS 
"t_name", "latest"."price" AS "t_price", "latest"."test_updated_at" AS "t_test_updated_at", "latest"."test_valid_from" AS "t_test_valid_from", "latest"."test_valid_to" AS "t_test_valid_to", - "source"."id1" AS "id1", - "source"."id2" AS "id2", + "source"."id" AS "id", "source"."name" AS "name", "source"."price" AS "price", "source"."test_updated_at" AS "test_updated_at" FROM "latest" RIGHT JOIN "source" - ON "latest"."id1" = "source"."id1" AND "latest"."id2" = "source"."id2" + ON COALESCE("latest"."id", '') = COALESCE("source"."id", '') + WHERE + "latest"."_exists" IS NULL +), "updated_rows" AS ( + SELECT + COALESCE("joined"."t_id", "joined"."id") AS "id", + COALESCE("joined"."t_name", "joined"."name") AS "name", + COALESCE("joined"."t_price", "joined"."price") AS "price", + COALESCE("joined"."t_test_updated_at", "joined"."test_updated_at") AS "test_updated_at", + CASE + WHEN "t_test_valid_from" IS NULL AND NOT "latest_deleted"."_exists" IS NULL + THEN CASE + WHEN "latest_deleted"."test_valid_to" > "test_updated_at" + THEN "latest_deleted"."test_valid_to" + ELSE "test_updated_at" + END + WHEN "t_test_valid_from" IS NULL + THEN CAST('1970-01-01 00:00:00' AS TIMESTAMP) + ELSE "t_test_valid_from" + END AS "test_valid_from", + CASE + WHEN "joined"."test_updated_at" > "joined"."t_test_updated_at" + THEN "joined"."test_updated_at" + ELSE "t_test_valid_to" + END AS "test_valid_to" + FROM "joined" + LEFT JOIN "latest_deleted" + ON COALESCE("joined"."id", '') = "latest_deleted"."_key0" +), "inserted_rows" AS ( + SELECT + "id", + "name", + "price", + "test_updated_at", + "test_updated_at" AS "test_valid_from", + CAST(NULL AS TIMESTAMP) AS "test_valid_to" + FROM "joined" + WHERE + "joined"."test_updated_at" > "joined"."t_test_updated_at" +) +SELECT + CAST("id" AS INT) AS "id", + CAST("name" AS VARCHAR) AS "name", + CAST("price" AS DOUBLE) AS "price", + CAST("test_updated_at" AS TIMESTAMP) AS "test_updated_at", + CAST("test_valid_from" AS TIMESTAMP) AS "test_valid_from", + 
CAST("test_valid_to" AS TIMESTAMP) AS "test_valid_to" +FROM ( + SELECT "id", "name", "price", "test_updated_at", "test_valid_from", "test_valid_to" FROM "static" + UNION ALL SELECT "id", "name", "price", "test_updated_at", "test_valid_from", "test_valid_to" FROM "updated_rows" + UNION ALL SELECT "id", "name", "price", "test_updated_at", "test_valid_from", "test_valid_to" FROM "inserted_rows" +) AS "_subquery" + """ + ).sql() + ) + + +def test_merge_scd_type_2_pandas(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + + df = pd.DataFrame( + { + "id1": [1, 2, 3], + "id2": [4, 5, 6], + "name": ["muffins", "chips", "soda"], + "price": [4.0, 5.0, 6.0], + "test_updated_at": [ + "2020-01-01 10:00:00", + "2020-01-02 15:00:00", + "2020-01-03 12:00:00", + ], + } + ) + adapter.scd_type_2_by_time( + target_table="target", + source_table=df, + unique_key=[exp.column("id1"), exp.column("id2")], + valid_from_col=exp.column("test_valid_from", quoted=True), + valid_to_col=exp.column("test_valid_to", quoted=True), + updated_at_col=exp.column("test_updated_at", quoted=True), + target_columns_to_types={ + "id1": exp.DataType.build("INT"), + "id2": exp.DataType.build("INT"), + "name": exp.DataType.build("VARCHAR"), + "price": exp.DataType.build("DOUBLE"), + "test_updated_at": exp.DataType.build("TIMESTAMPTZ"), + "test_valid_from": exp.DataType.build("TIMESTAMPTZ"), + "test_valid_to": exp.DataType.build("TIMESTAMPTZ"), + }, + execution_time=datetime(2020, 1, 1, 0, 0, 0), + ) + + assert ( + adapter.cursor.execute.call_args[0][0] + == parse_one( + """ +CREATE OR REPLACE TABLE "target" AS +WITH "source" AS ( + SELECT DISTINCT ON ("id1", "id2") + TRUE AS "_exists", + "id1", + "id2", + "name", + "price", + CAST("test_updated_at" AS TIMESTAMPTZ) AS "test_updated_at" + FROM ( + SELECT + CAST("id1" AS INT) AS "id1", + CAST("id2" AS INT) AS "id2", + CAST("name" AS VARCHAR) AS "name", + CAST("price" AS DOUBLE) AS "price", + CAST("test_updated_at" AS 
TIMESTAMPTZ) AS "test_updated_at", + FROM (VALUES + (1, 4, 'muffins', 4.0, '2020-01-01 10:00:00'), + (2, 5, 'chips', 5.0, '2020-01-02 15:00:00'), + (3, 6, 'soda', 6.0, '2020-01-03 12:00:00')) AS "t"("id1", "id2", "name", "price", "test_updated_at") + ) AS "raw_source" +), "static" AS ( + SELECT + "id1", + "id2", + "name", + "price", + "test_updated_at", + "test_valid_from", + "test_valid_to", + TRUE AS "_exists" + FROM "target" + WHERE + NOT "test_valid_to" IS NULL +), "latest" AS ( + SELECT + "id1", + "id2", + "name", + "price", + "test_updated_at", + "test_valid_from", + "test_valid_to", + TRUE AS "_exists" + FROM "target" + WHERE + "test_valid_to" IS NULL +), "deleted" AS ( + SELECT + "static"."id1", + "static"."id2", + "static"."name", + "static"."price", + "static"."test_updated_at", + "static"."test_valid_from", + "static"."test_valid_to" + FROM "static" + LEFT JOIN "latest" + ON "static"."id1" = "latest"."id1" AND "static"."id2" = "latest"."id2" + WHERE + "latest"."test_valid_to" IS NULL +), "latest_deleted" AS ( + SELECT + TRUE AS "_exists", + "id1" AS "_key0", + "id2" AS "_key1", + MAX("test_valid_to") AS "test_valid_to" + FROM "deleted" + GROUP BY + "id1", + "id2" +), "joined" AS ( + SELECT + "source"."_exists" AS "_exists", + "latest"."id1" AS "t_id1", + "latest"."id2" AS "t_id2", + "latest"."name" AS "t_name", + "latest"."price" AS "t_price", + "latest"."test_updated_at" AS "t_test_updated_at", + "latest"."test_valid_from" AS "t_test_valid_from", + "latest"."test_valid_to" AS "t_test_valid_to", + "source"."id1" AS "id1", + "source"."id2" AS "id2", + "source"."name" AS "name", + "source"."price" AS "price", + "source"."test_updated_at" AS "test_updated_at" + FROM "latest" + LEFT JOIN "source" + ON "latest"."id1" = "source"."id1" AND "latest"."id2" = "source"."id2" + UNION ALL + SELECT + "source"."_exists" AS "_exists", + "latest"."id1" AS "t_id1", + "latest"."id2" AS "t_id2", + "latest"."name" AS "t_name", + "latest"."price" AS "t_price", + 
"latest"."test_updated_at" AS "t_test_updated_at", + "latest"."test_valid_from" AS "t_test_valid_from", + "latest"."test_valid_to" AS "t_test_valid_to", + "source"."id1" AS "id1", + "source"."id2" AS "id2", + "source"."name" AS "name", + "source"."price" AS "price", + "source"."test_updated_at" AS "test_updated_at" + FROM "latest" + RIGHT JOIN "source" + ON "latest"."id1" = "source"."id1" AND "latest"."id2" = "source"."id2" WHERE "latest"."_exists" IS NULL ), "updated_rows" AS ( @@ -1512,8 +2178,8 @@ def test_merge_scd_type_2_pandas(make_mocked_engine_adapter: t.Callable): ELSE "t_test_valid_from" END AS "test_valid_from", CASE - WHEN "test_updated_at" > "t_test_updated_at" - THEN "test_updated_at" + WHEN "joined"."test_updated_at" > "joined"."t_test_updated_at" + THEN "joined"."test_updated_at" WHEN "joined"."_exists" IS NULL THEN CAST('2020-01-01 00:00:00+00:00' AS TIMESTAMPTZ) ELSE "t_test_valid_to" @@ -1533,7 +2199,7 @@ def test_merge_scd_type_2_pandas(make_mocked_engine_adapter: t.Callable): CAST(NULL AS TIMESTAMPTZ) AS "test_valid_to" FROM "joined" WHERE - "test_updated_at" > "t_test_updated_at" + "joined"."test_updated_at" > "joined"."t_test_updated_at" ) SELECT CAST("id1" AS INT) AS "id1", CAST("id2" AS INT) AS "id2", CAST("name" AS VARCHAR) AS "name", CAST("price" AS DOUBLE) AS "price", CAST("test_updated_at" AS TIMESTAMPTZ) AS "test_updated_at", CAST("test_valid_from" AS TIMESTAMPTZ) AS "test_valid_from", CAST("test_valid_to" AS TIMESTAMPTZ) AS "test_valid_to" FROM (SELECT "id1", "id2", "name", "price", "test_updated_at", "test_valid_from", "test_valid_to" FROM "static" UNION ALL SELECT "id1", "id2", "name", "price", "test_updated_at", "test_valid_from", "test_valid_to" FROM "updated_rows" UNION ALL SELECT "id1", "id2", "name", "price", "test_updated_at", "test_valid_from", "test_valid_to" FROM "inserted_rows") AS "_subquery" """ @@ -1551,7 +2217,7 @@ def test_scd_type_2_by_column(make_mocked_engine_adapter: t.Callable): 
valid_from_col=exp.column("test_VALID_from", quoted=True), valid_to_col=exp.column("test_valid_to", quoted=True), check_columns=[exp.column("name"), exp.column("price")], - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("VARCHAR"), "price": exp.DataType.build("DOUBLE"), @@ -1559,6 +2225,7 @@ def test_scd_type_2_by_column(make_mocked_engine_adapter: t.Callable): "test_valid_to": exp.DataType.build("TIMESTAMP"), }, execution_time=datetime(2020, 1, 1, 0, 0, 0), + extra_col_ignore="testing", ) assert ( @@ -1623,7 +2290,7 @@ def test_scd_type_2_by_column(make_mocked_engine_adapter: t.Callable): "id" ), "joined" AS ( SELECT - "source"."_exists", + "source"."_exists" AS "_exists", "latest"."id" AS "t_id", "latest"."name" AS "t_name", "latest"."price" AS "t_price", @@ -1637,7 +2304,7 @@ def test_scd_type_2_by_column(make_mocked_engine_adapter: t.Callable): ON "latest"."id" = "source"."id" UNION ALL SELECT - "source"."_exists", + "source"."_exists" AS "_exists", "latest"."id" AS "t_id", "latest"."name" AS "t_name", "latest"."price" AS "t_price", @@ -1656,27 +2323,27 @@ def test_scd_type_2_by_column(make_mocked_engine_adapter: t.Callable): COALESCE("joined"."t_id", "joined"."id") AS "id", COALESCE("joined"."t_name", "joined"."name") AS "name", COALESCE("joined"."t_price", "joined"."price") AS "price", - COALESCE("t_test_VALID_from", CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AS "test_VALID_from", + COALESCE("t_test_VALID_from", CAST('2020-01-01 00:00:00' AS TIMESTAMP)) AS "test_VALID_from", CASE WHEN "joined"."_exists" IS NULL OR ( ( - NOT "t_id" IS NULL AND NOT "id" IS NULL + NOT "joined"."t_id" IS NULL AND NOT "joined"."id" IS NULL ) AND ( - "name" <> "t_name" + "joined"."name" <> "joined"."t_name" OR ( - "t_name" IS NULL AND NOT "name" IS NULL + "joined"."t_name" IS NULL AND NOT "joined"."name" IS NULL ) OR ( - NOT "t_name" IS NULL AND "name" IS NULL + NOT "joined"."t_name" IS NULL AND "joined"."name" IS NULL ) - 
OR "price" <> "t_price" + OR "joined"."price" <> "joined"."t_price" OR ( - "t_price" IS NULL AND NOT "price" IS NULL + "joined"."t_price" IS NULL AND NOT "joined"."price" IS NULL ) OR ( - NOT "t_price" IS NULL AND "price" IS NULL + NOT "joined"."t_price" IS NULL AND "joined"."price" IS NULL ) ) ) @@ -1696,22 +2363,22 @@ def test_scd_type_2_by_column(make_mocked_engine_adapter: t.Callable): FROM "joined" WHERE ( - NOT "t_id" IS NULL AND NOT "id" IS NULL + NOT "joined"."t_id" IS NULL AND NOT "joined"."id" IS NULL ) AND ( - "name" <> "t_name" + "joined"."name" <> "joined"."t_name" OR ( - "t_name" IS NULL AND NOT "name" IS NULL + "joined"."t_name" IS NULL AND NOT "joined"."name" IS NULL ) OR ( - NOT "t_name" IS NULL AND "name" IS NULL + NOT "joined"."t_name" IS NULL AND "joined"."name" IS NULL ) - OR "price" <> "t_price" + OR "joined"."price" <> "joined"."t_price" OR ( - "t_price" IS NULL AND NOT "price" IS NULL + "joined"."t_price" IS NULL AND NOT "joined"."price" IS NULL ) OR ( - NOT "t_price" IS NULL AND "price" IS NULL + NOT "joined"."t_price" IS NULL AND "joined"."price" IS NULL ) ) ) @@ -1721,25 +2388,25 @@ def test_scd_type_2_by_column(make_mocked_engine_adapter: t.Callable): ) -def test_scd_type_2_truncate(make_mocked_engine_adapter: t.Callable): +def test_scd_type_2_by_column_composite_key(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) adapter.scd_type_2_by_column( target_table="target", - source_table=t.cast(exp.Select, parse_one("SELECT id, name, price FROM source")), - unique_key=[exp.column("id")], - valid_from_col=exp.column("test_valid_from", quoted=True), + source_table=t.cast(exp.Select, parse_one("SELECT id_a, id_b, name, price FROM source")), + unique_key=[exp.func("CONCAT", exp.column("id_a"), exp.column("id_b"))], + valid_from_col=exp.column("test_VALID_from", quoted=True), valid_to_col=exp.column("test_valid_to", quoted=True), check_columns=[exp.column("name"), exp.column("price")], - 
columns_to_types={ - "id": exp.DataType.build("INT"), + target_columns_to_types={ + "id_a": exp.DataType.build("VARCHAR"), + "id_b": exp.DataType.build("VARCHAR"), "name": exp.DataType.build("VARCHAR"), "price": exp.DataType.build("DOUBLE"), - "test_valid_from": exp.DataType.build("TIMESTAMP"), + "test_VALID_from": exp.DataType.build("TIMESTAMP"), "test_valid_to": exp.DataType.build("TIMESTAMP"), }, execution_time=datetime(2020, 1, 1, 0, 0, 0), - truncate=True, ) assert ( @@ -1748,118 +2415,126 @@ def test_scd_type_2_truncate(make_mocked_engine_adapter: t.Callable): """ CREATE OR REPLACE TABLE "target" AS WITH "source" AS ( - SELECT DISTINCT ON ("id") + SELECT DISTINCT ON (CONCAT("id_a", "id_b")) TRUE AS "_exists", - "id", + "id_a", + "id_b", "name", - "price" + "price", FROM ( SELECT - "id", + "id_a", + "id_b", "name", "price" FROM "source" ) AS "raw_source" ), "static" AS ( SELECT - "id", + "id_a", + "id_b", "name", "price", - "test_valid_from", + "test_VALID_from", "test_valid_to", TRUE AS "_exists" FROM "target" WHERE NOT "test_valid_to" IS NULL - LIMIT 0 ), "latest" AS ( SELECT - "id", + "id_a", + "id_b", "name", "price", - "test_valid_from", + "test_VALID_from", "test_valid_to", TRUE AS "_exists" FROM "target" WHERE "test_valid_to" IS NULL - LIMIT 0 ), "deleted" AS ( SELECT - "static"."id", + "static"."id_a", + "static"."id_b", "static"."name", "static"."price", - "static"."test_valid_from", + "static"."test_VALID_from", "static"."test_valid_to" FROM "static" LEFT JOIN "latest" - ON "static"."id" = "latest"."id" + ON CONCAT("static"."id_a", "static"."id_b") = CONCAT("latest"."id_a", "latest"."id_b") WHERE "latest"."test_valid_to" IS NULL ), "latest_deleted" AS ( SELECT TRUE AS "_exists", - "id" AS "_key0", + CONCAT("id_a", "id_b") AS "_key0", MAX("test_valid_to") AS "test_valid_to" FROM "deleted" GROUP BY - "id" + CONCAT("id_a", "id_b") ), "joined" AS ( SELECT - "source"."_exists", - "latest"."id" AS "t_id", + "source"."_exists" AS "_exists", + 
"latest"."id_a" AS "t_id_a", + "latest"."id_b" AS "t_id_b", "latest"."name" AS "t_name", "latest"."price" AS "t_price", - "latest"."test_valid_from" AS "t_test_valid_from", + "latest"."test_VALID_from" AS "t_test_VALID_from", "latest"."test_valid_to" AS "t_test_valid_to", - "source"."id" AS "id", + "source"."id_a" AS "id_a", + "source"."id_b" AS "id_b", "source"."name" AS "name", "source"."price" AS "price" FROM "latest" LEFT JOIN "source" - ON "latest"."id" = "source"."id" + ON CONCAT("latest"."id_a", "latest"."id_b") = CONCAT("source"."id_a", "source"."id_b") UNION ALL SELECT - "source"."_exists", - "latest"."id" AS "t_id", + "source"."_exists" AS "_exists", + "latest"."id_a" AS "t_id_a", + "latest"."id_b" AS "t_id_b", "latest"."name" AS "t_name", "latest"."price" AS "t_price", - "latest"."test_valid_from" AS "t_test_valid_from", + "latest"."test_VALID_from" AS "t_test_VALID_from", "latest"."test_valid_to" AS "t_test_valid_to", - "source"."id" AS "id", + "source"."id_a" AS "id_a", + "source"."id_b" AS "id_b", "source"."name" AS "name", "source"."price" AS "price" FROM "latest" RIGHT JOIN "source" - ON "latest"."id" = "source"."id" + ON CONCAT("latest"."id_a", "latest"."id_b") = CONCAT("source"."id_a", "source"."id_b") WHERE "latest"."_exists" IS NULL ), "updated_rows" AS ( SELECT - COALESCE("joined"."t_id", "joined"."id") AS "id", + COALESCE("joined"."t_id_a", "joined"."id_a") AS "id_a", + COALESCE("joined"."t_id_b", "joined"."id_b") AS "id_b", COALESCE("joined"."t_name", "joined"."name") AS "name", COALESCE("joined"."t_price", "joined"."price") AS "price", - COALESCE("t_test_valid_from", CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AS "test_valid_from", + COALESCE("t_test_VALID_from", CAST('2020-01-01 00:00:00' AS TIMESTAMP)) AS "test_VALID_from", CASE WHEN "joined"."_exists" IS NULL OR ( ( - NOT "t_id" IS NULL AND NOT "id" IS NULL + NOT CONCAT("t_id_a", "t_id_b") IS NULL AND NOT CONCAT("id_a", "id_b") IS NULL ) AND ( - "name" <> "t_name" + "joined"."name" <> 
"joined"."t_name" OR ( - "t_name" IS NULL AND NOT "name" IS NULL + "joined"."t_name" IS NULL AND NOT "joined"."name" IS NULL ) OR ( - NOT "t_name" IS NULL AND "name" IS NULL + NOT "joined"."t_name" IS NULL AND "joined"."name" IS NULL ) - OR "price" <> "t_price" + OR "joined"."price" <> "joined"."t_price" OR ( - "t_price" IS NULL AND NOT "price" IS NULL + "joined"."t_price" IS NULL AND NOT "joined"."price" IS NULL ) OR ( - NOT "t_price" IS NULL AND "price" IS NULL + NOT "joined"."t_price" IS NULL AND "joined"."price" IS NULL ) ) ) @@ -1868,43 +2543,44 @@ def test_scd_type_2_truncate(make_mocked_engine_adapter: t.Callable): END AS "test_valid_to" FROM "joined" LEFT JOIN "latest_deleted" - ON "joined"."id" = "latest_deleted"."_key0" + ON CONCAT("joined"."id_a", "joined"."id_b") = "latest_deleted"."_key0" ), "inserted_rows" AS ( SELECT - "id", + "id_a", + "id_b", "name", "price", - CAST('2020-01-01 00:00:00' AS TIMESTAMP) AS "test_valid_from", + CAST('2020-01-01 00:00:00' AS TIMESTAMP) AS "test_VALID_from", CAST(NULL AS TIMESTAMP) AS "test_valid_to" FROM "joined" WHERE ( - NOT "t_id" IS NULL AND NOT "id" IS NULL + NOT CONCAT("t_id_a", "t_id_b") IS NULL AND NOT CONCAT("id_a", "id_b") IS NULL ) AND ( - "name" <> "t_name" + "joined"."name" <> "joined"."t_name" OR ( - "t_name" IS NULL AND NOT "name" IS NULL + "joined"."t_name" IS NULL AND NOT "joined"."name" IS NULL ) OR ( - NOT "t_name" IS NULL AND "name" IS NULL + NOT "joined"."t_name" IS NULL AND "joined"."name" IS NULL ) - OR "price" <> "t_price" + OR "joined"."price" <> "joined"."t_price" OR ( - "t_price" IS NULL AND NOT "price" IS NULL + "joined"."t_price" IS NULL AND NOT "joined"."price" IS NULL ) OR ( - NOT "t_price" IS NULL AND "price" IS NULL + NOT "joined"."t_price" IS NULL AND "joined"."price" IS NULL ) ) ) -SELECT CAST("id" AS INT) AS "id", CAST("name" AS VARCHAR) AS "name", CAST("price" AS DOUBLE) AS "price", CAST("test_valid_from" AS TIMESTAMP) AS "test_valid_from", CAST("test_valid_to" AS TIMESTAMP) AS 
"test_valid_to" FROM (SELECT "id", "name", "price", "test_valid_from", "test_valid_to" FROM "static" UNION ALL SELECT "id", "name", "price", "test_valid_from", "test_valid_to" FROM "updated_rows" UNION ALL SELECT "id", "name", "price", "test_valid_from", "test_valid_to" FROM "inserted_rows") AS "_subquery" +SELECT CAST("id_a" AS VARCHAR) AS "id_a", CAST("id_b" AS VARCHAR) AS "id_b", CAST("name" AS VARCHAR) AS "name", CAST("price" AS DOUBLE) AS "price", CAST("test_VALID_from" AS TIMESTAMP) AS "test_VALID_from", CAST("test_valid_to" AS TIMESTAMP) AS "test_valid_to" FROM (SELECT "id_a", "id_b", "name", "price", "test_VALID_from", "test_valid_to" FROM "static" UNION ALL SELECT "id_a", "id_b", "name", "price", "test_VALID_from", "test_valid_to" FROM "updated_rows" UNION ALL SELECT "id_a", "id_b", "name", "price", "test_VALID_from", "test_valid_to" FROM "inserted_rows") AS "_subquery" """ ).sql() ) -def test_scd_type_2_by_column_star_check(make_mocked_engine_adapter: t.Callable): +def test_scd_type_2_truncate(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) adapter.scd_type_2_by_column( @@ -1913,8 +2589,8 @@ def test_scd_type_2_by_column_star_check(make_mocked_engine_adapter: t.Callable) unique_key=[exp.column("id")], valid_from_col=exp.column("test_valid_from", quoted=True), valid_to_col=exp.column("test_valid_to", quoted=True), - check_columns=exp.Star(), - columns_to_types={ + check_columns=[exp.column("name"), exp.column("price")], + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("VARCHAR"), "price": exp.DataType.build("DOUBLE"), @@ -1922,6 +2598,7 @@ def test_scd_type_2_by_column_star_check(make_mocked_engine_adapter: t.Callable) "test_valid_to": exp.DataType.build("TIMESTAMP"), }, execution_time=datetime(2020, 1, 1, 0, 0, 0), + truncate=True, ) assert ( @@ -1953,6 +2630,7 @@ def test_scd_type_2_by_column_star_check(make_mocked_engine_adapter: t.Callable) FROM "target" WHERE NOT 
"test_valid_to" IS NULL + LIMIT 0 ), "latest" AS ( SELECT "id", @@ -1964,6 +2642,7 @@ def test_scd_type_2_by_column_star_check(make_mocked_engine_adapter: t.Callable) FROM "target" WHERE "test_valid_to" IS NULL + LIMIT 0 ), "deleted" AS ( SELECT "static"."id", @@ -1986,7 +2665,7 @@ def test_scd_type_2_by_column_star_check(make_mocked_engine_adapter: t.Callable) "id" ), "joined" AS ( SELECT - "source"."_exists", + "source"."_exists" AS "_exists", "latest"."id" AS "t_id", "latest"."name" AS "t_name", "latest"."price" AS "t_price", @@ -2000,7 +2679,7 @@ def test_scd_type_2_by_column_star_check(make_mocked_engine_adapter: t.Callable) ON "latest"."id" = "source"."id" UNION ALL SELECT - "source"."_exists", + "source"."_exists" AS "_exists", "latest"."id" AS "t_id", "latest"."name" AS "t_name", "latest"."price" AS "t_price", @@ -2024,29 +2703,22 @@ def test_scd_type_2_by_column_star_check(make_mocked_engine_adapter: t.Callable) WHEN "joined"."_exists" IS NULL OR ( ( - NOT "t_id" IS NULL AND NOT "id" IS NULL + NOT "joined"."t_id" IS NULL AND NOT "joined"."id" IS NULL ) AND ( - "id" <> "t_id" - OR ( - "t_id" IS NULL AND NOT "id" IS NULL - ) - OR ( - NOT "t_id" IS NULL AND "id" IS NULL - ) - OR "name" <> "t_name" + "joined"."name" <> "joined"."t_name" OR ( - "t_name" IS NULL AND NOT "name" IS NULL + "joined"."t_name" IS NULL AND NOT "joined"."name" IS NULL ) OR ( - NOT "t_name" IS NULL AND "name" IS NULL + NOT "joined"."t_name" IS NULL AND "joined"."name" IS NULL ) - OR "price" <> "t_price" + OR "joined"."price" <> "joined"."t_price" OR ( - "t_price" IS NULL AND NOT "price" IS NULL + "joined"."t_price" IS NULL AND NOT "joined"."price" IS NULL ) OR ( - NOT "t_price" IS NULL AND "price" IS NULL + NOT "joined"."t_price" IS NULL AND "joined"."price" IS NULL ) ) ) @@ -2066,29 +2738,22 @@ def test_scd_type_2_by_column_star_check(make_mocked_engine_adapter: t.Callable) FROM "joined" WHERE ( - NOT "t_id" IS NULL AND NOT "id" IS NULL + NOT "joined"."t_id" IS NULL AND NOT 
"joined"."id" IS NULL ) AND ( - "id" <> "t_id" + "joined"."name" <> "joined"."t_name" OR ( - "t_id" IS NULL AND NOT "id" IS NULL + "joined"."t_name" IS NULL AND NOT "joined"."name" IS NULL ) OR ( - NOT "t_id" IS NULL AND "id" IS NULL + NOT "joined"."t_name" IS NULL AND "joined"."name" IS NULL ) - OR "name" <> "t_name" + OR "joined"."price" <> "joined"."t_price" OR ( - "t_name" IS NULL AND NOT "name" IS NULL + "joined"."t_price" IS NULL AND NOT "joined"."price" IS NULL ) OR ( - NOT "t_name" IS NULL AND "name" IS NULL - ) - OR "price" <> "t_price" - OR ( - "t_price" IS NULL AND NOT "price" IS NULL - ) - OR ( - NOT "t_price" IS NULL AND "price" IS NULL + NOT "joined"."t_price" IS NULL AND "joined"."price" IS NULL ) ) ) @@ -2098,7 +2763,7 @@ def test_scd_type_2_by_column_star_check(make_mocked_engine_adapter: t.Callable) ) -def test_scd_type_2_by_column_no_invalidate_hard_deletes(make_mocked_engine_adapter: t.Callable): +def test_scd_type_2_by_column_star_check(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) adapter.scd_type_2_by_column( @@ -2107,9 +2772,8 @@ def test_scd_type_2_by_column_no_invalidate_hard_deletes(make_mocked_engine_adap unique_key=[exp.column("id")], valid_from_col=exp.column("test_valid_from", quoted=True), valid_to_col=exp.column("test_valid_to", quoted=True), - invalidate_hard_deletes=False, - check_columns=[exp.column("name"), exp.column("price")], - columns_to_types={ + check_columns=exp.Star(), + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("VARCHAR"), "price": exp.DataType.build("DOUBLE"), @@ -2181,7 +2845,7 @@ def test_scd_type_2_by_column_no_invalidate_hard_deletes(make_mocked_engine_adap "id" ), "joined" AS ( SELECT - "source"."_exists", + "source"."_exists" AS "_exists", "latest"."id" AS "t_id", "latest"."name" AS "t_name", "latest"."price" AS "t_price", @@ -2195,7 +2859,7 @@ def test_scd_type_2_by_column_no_invalidate_hard_deletes(make_mocked_engine_adap 
ON "latest"."id" = "source"."id" UNION ALL SELECT - "source"."_exists", + "source"."_exists" AS "_exists", "latest"."id" AS "t_id", "latest"."name" AS "t_name", "latest"."price" AS "t_price", @@ -2214,25 +2878,220 @@ def test_scd_type_2_by_column_no_invalidate_hard_deletes(make_mocked_engine_adap COALESCE("joined"."t_id", "joined"."id") AS "id", COALESCE("joined"."t_name", "joined"."name") AS "name", COALESCE("joined"."t_price", "joined"."price") AS "price", - COALESCE("t_test_valid_from", CAST('1970-01-01 00:00:00' AS TIMESTAMP)) AS "test_valid_from", + COALESCE("t_test_valid_from", CAST('2020-01-01 00:00:00' AS TIMESTAMP)) AS "test_valid_from", + CASE + WHEN "joined"."_exists" IS NULL + OR ( + ( + NOT "joined"."t_id" IS NULL AND NOT "joined"."id" IS NULL + ) + AND ( + "joined"."id" <> "joined"."t_id" + OR ( + "joined"."t_id" IS NULL AND NOT "joined"."id" IS NULL + ) + OR ( + NOT "joined"."t_id" IS NULL AND "joined"."id" IS NULL + ) + OR "joined"."name" <> "joined"."t_name" + OR ( + "joined"."t_name" IS NULL AND NOT "joined"."name" IS NULL + ) + OR ( + NOT "joined"."t_name" IS NULL AND "joined"."name" IS NULL + ) + OR "joined"."price" <> "joined"."t_price" + OR ( + "joined"."t_price" IS NULL AND NOT "joined"."price" IS NULL + ) + OR ( + NOT "joined"."t_price" IS NULL AND "joined"."price" IS NULL + ) + ) + ) + THEN CAST('2020-01-01 00:00:00' AS TIMESTAMP) + ELSE "t_test_valid_to" + END AS "test_valid_to" + FROM "joined" + LEFT JOIN "latest_deleted" + ON "joined"."id" = "latest_deleted"."_key0" +), "inserted_rows" AS ( + SELECT + "id", + "name", + "price", + CAST('2020-01-01 00:00:00' AS TIMESTAMP) AS "test_valid_from", + CAST(NULL AS TIMESTAMP) AS "test_valid_to" + FROM "joined" + WHERE + ( + NOT "joined"."t_id" IS NULL AND NOT "joined"."id" IS NULL + ) + AND ( + "joined"."id" <> "joined"."t_id" + OR ( + "joined"."t_id" IS NULL AND NOT "joined"."id" IS NULL + ) + OR ( + NOT "joined"."t_id" IS NULL AND "joined"."id" IS NULL + ) + OR "joined"."name" <> 
"joined"."t_name" + OR ( + "joined"."t_name" IS NULL AND NOT "joined"."name" IS NULL + ) + OR ( + NOT "joined"."t_name" IS NULL AND "joined"."name" IS NULL + ) + OR "joined"."price" <> "joined"."t_price" + OR ( + "joined"."t_price" IS NULL AND NOT "joined"."price" IS NULL + ) + OR ( + NOT "joined"."t_price" IS NULL AND "joined"."price" IS NULL + ) + ) +) +SELECT CAST("id" AS INT) AS "id", CAST("name" AS VARCHAR) AS "name", CAST("price" AS DOUBLE) AS "price", CAST("test_valid_from" AS TIMESTAMP) AS "test_valid_from", CAST("test_valid_to" AS TIMESTAMP) AS "test_valid_to" FROM (SELECT "id", "name", "price", "test_valid_from", "test_valid_to" FROM "static" UNION ALL SELECT "id", "name", "price", "test_valid_from", "test_valid_to" FROM "updated_rows" UNION ALL SELECT "id", "name", "price", "test_valid_from", "test_valid_to" FROM "inserted_rows") AS "_subquery" + """ + ).sql() + ) + + +def test_scd_type_2_by_column_no_invalidate_hard_deletes(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + + adapter.scd_type_2_by_column( + target_table="target", + source_table=t.cast(exp.Select, parse_one("SELECT id, name, price FROM source")), + unique_key=[exp.column("id")], + valid_from_col=exp.column("test_valid_from", quoted=True), + valid_to_col=exp.column("test_valid_to", quoted=True), + invalidate_hard_deletes=False, + check_columns=[exp.column("name"), exp.column("price")], + target_columns_to_types={ + "id": exp.DataType.build("INT"), + "name": exp.DataType.build("VARCHAR"), + "price": exp.DataType.build("DOUBLE"), + "test_valid_from": exp.DataType.build("TIMESTAMP"), + "test_valid_to": exp.DataType.build("TIMESTAMP"), + }, + execution_time=datetime(2020, 1, 1, 0, 0, 0), + ) + + assert ( + adapter.cursor.execute.call_args[0][0] + == parse_one( + """ +CREATE OR REPLACE TABLE "target" AS +WITH "source" AS ( + SELECT DISTINCT ON ("id") + TRUE AS "_exists", + "id", + "name", + "price" + FROM ( + SELECT + "id", + "name", + "price" + 
FROM "source" + ) AS "raw_source" +), "static" AS ( + SELECT + "id", + "name", + "price", + "test_valid_from", + "test_valid_to", + TRUE AS "_exists" + FROM "target" + WHERE + NOT "test_valid_to" IS NULL +), "latest" AS ( + SELECT + "id", + "name", + "price", + "test_valid_from", + "test_valid_to", + TRUE AS "_exists" + FROM "target" + WHERE + "test_valid_to" IS NULL +), "deleted" AS ( + SELECT + "static"."id", + "static"."name", + "static"."price", + "static"."test_valid_from", + "static"."test_valid_to" + FROM "static" + LEFT JOIN "latest" + ON "static"."id" = "latest"."id" + WHERE + "latest"."test_valid_to" IS NULL +), "latest_deleted" AS ( + SELECT + TRUE AS "_exists", + "id" AS "_key0", + MAX("test_valid_to") AS "test_valid_to" + FROM "deleted" + GROUP BY + "id" +), "joined" AS ( + SELECT + "source"."_exists" AS "_exists", + "latest"."id" AS "t_id", + "latest"."name" AS "t_name", + "latest"."price" AS "t_price", + "latest"."test_valid_from" AS "t_test_valid_from", + "latest"."test_valid_to" AS "t_test_valid_to", + "source"."id" AS "id", + "source"."name" AS "name", + "source"."price" AS "price" + FROM "latest" + LEFT JOIN "source" + ON "latest"."id" = "source"."id" + UNION ALL + SELECT + "source"."_exists" AS "_exists", + "latest"."id" AS "t_id", + "latest"."name" AS "t_name", + "latest"."price" AS "t_price", + "latest"."test_valid_from" AS "t_test_valid_from", + "latest"."test_valid_to" AS "t_test_valid_to", + "source"."id" AS "id", + "source"."name" AS "name", + "source"."price" AS "price" + FROM "latest" + RIGHT JOIN "source" + ON "latest"."id" = "source"."id" + WHERE + "latest"."_exists" IS NULL +), "updated_rows" AS ( + SELECT + COALESCE("joined"."t_id", "joined"."id") AS "id", + COALESCE("joined"."t_name", "joined"."name") AS "name", + COALESCE("joined"."t_price", "joined"."price") AS "price", + COALESCE("t_test_valid_from", CAST('2020-01-01 00:00:00' AS TIMESTAMP)) AS "test_valid_from", CASE WHEN ( - NOT "t_id" IS NULL AND NOT "id" IS NULL + NOT 
"joined"."t_id" IS NULL AND NOT "joined"."id" IS NULL ) AND ( - "name" <> "t_name" + "joined"."name" <> "joined"."t_name" OR ( - "t_name" IS NULL AND NOT "name" IS NULL + "joined"."t_name" IS NULL AND NOT "joined"."name" IS NULL ) OR ( - NOT "t_name" IS NULL AND "name" IS NULL + NOT "joined"."t_name" IS NULL AND "joined"."name" IS NULL ) - OR "price" <> "t_price" + OR "joined"."price" <> "joined"."t_price" OR ( - "t_price" IS NULL AND NOT "price" IS NULL + "joined"."t_price" IS NULL AND NOT "joined"."price" IS NULL ) OR ( - NOT "t_price" IS NULL AND "price" IS NULL + NOT "joined"."t_price" IS NULL AND "joined"."price" IS NULL ) ) THEN CAST('2020-01-01 00:00:00' AS TIMESTAMP) @@ -2251,22 +3110,22 @@ def test_scd_type_2_by_column_no_invalidate_hard_deletes(make_mocked_engine_adap FROM "joined" WHERE ( - NOT "t_id" IS NULL AND NOT "id" IS NULL + NOT "joined"."t_id" IS NULL AND NOT "joined"."id" IS NULL ) AND ( - "name" <> "t_name" + "joined"."name" <> "joined"."t_name" OR ( - "t_name" IS NULL AND NOT "name" IS NULL + "joined"."t_name" IS NULL AND NOT "joined"."name" IS NULL ) OR ( - NOT "t_name" IS NULL AND "name" IS NULL + NOT "joined"."t_name" IS NULL AND "joined"."name" IS NULL ) - OR "price" <> "t_price" + OR "joined"."price" <> "joined"."t_price" OR ( - "t_price" IS NULL AND NOT "price" IS NULL + "joined"."t_price" IS NULL AND NOT "joined"."price" IS NULL ) OR ( - NOT "t_price" IS NULL AND "price" IS NULL + NOT "joined"."t_price" IS NULL AND "joined"."price" IS NULL ) ) ) @@ -2297,6 +3156,26 @@ def test_replace_query(make_mocked_engine_adapter: t.Callable, mocker: MockerFix ] +def test_replace_query_data_object_type_mismatch( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter) + mocker.patch.object( + adapter, + "_get_data_objects", + return_value=[DataObject(schema="", name="test_table", type="view")], + ) + + adapter.replace_query( + "test_table", parse_one("SELECT a FROM tbl"), {"a": 
exp.DataType.build("INT")} + ) + + assert to_sql_calls(adapter) == [ + 'DROP VIEW IF EXISTS "test_table"', + 'CREATE OR REPLACE TABLE "test_table" AS SELECT CAST("a" AS INT) AS "a" FROM (SELECT "a" FROM "tbl") AS "_subquery"', + ] + + def test_replace_query_pandas(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) adapter.DEFAULT_BATCH_SIZE = 1 @@ -2313,6 +3192,39 @@ def test_replace_query_pandas(make_mocked_engine_adapter: t.Callable): ] +def test_replace_query_pandas_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + df = pd.DataFrame({"a": [1, 2, 3], "ignored_source": [4, 5, 6]}) + adapter.replace_query( + "test_table", + df, + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("INT"), + }, + source_columns=["a", "ignored_source"], + ) + assert to_sql_calls(adapter) == [ + 'CREATE OR REPLACE TABLE "test_table" AS SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" FROM (SELECT CAST("a" AS INT) AS "a", CAST(NULL AS INT) AS "b" FROM (VALUES (1), (2), (3)) AS "t"("a")) AS "_subquery"', + ] + + +def test_replace_query_query_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.replace_query( + "test_table", + parse_one("SELECT a, ignored_source FROM tbl"), + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("INT"), + }, + source_columns=["a", "ignored_source"], + ) + assert to_sql_calls(adapter) == [ + 'CREATE OR REPLACE TABLE "test_table" AS SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" FROM (SELECT "a", CAST(NULL AS INT) AS "b" FROM (SELECT "a", "ignored_source" FROM "tbl") AS "select_source_columns") AS "_subquery"', + ] + + def test_replace_query_self_referencing_not_exists_unknown( make_mocked_engine_adapter: t.Callable, mocker: MockerFixture ): @@ -2330,7 +3242,7 @@ def 
test_replace_query_self_referencing_not_exists_unknown( adapter.replace_query( "test", parse_one("SELECT a FROM test"), - columns_to_types={"a": exp.DataType.build("UNKNOWN")}, + target_columns_to_types={"a": exp.DataType.build("UNKNOWN")}, ) @@ -2347,7 +3259,7 @@ def test_replace_query_self_referencing_exists( adapter.replace_query( "test", parse_one("SELECT a FROM test"), - columns_to_types={"a": exp.DataType.build("UNKNOWN")}, + target_columns_to_types={"a": exp.DataType.build("UNKNOWN")}, ) assert to_sql_calls(adapter) == [ @@ -2368,7 +3280,7 @@ def test_replace_query_self_referencing_not_exists_known( adapter.replace_query( "test", parse_one("SELECT a FROM test"), - columns_to_types={"a": exp.DataType.build("INT")}, + target_columns_to_types={"a": exp.DataType.build("INT")}, ) assert to_sql_calls(adapter) == [ @@ -2377,12 +3289,19 @@ def test_replace_query_self_referencing_not_exists_known( ] -def test_create_table_like(make_mocked_engine_adapter: t.Callable): +def test_create_table_like(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): adapter = make_mocked_engine_adapter(EngineAdapter) + columns_to_types = { + "cola": exp.DataType.build("INT"), + "colb": exp.DataType.build("TEXT"), + } + columns_mock = mocker.patch.object(adapter, "columns") + columns_mock.return_value = columns_to_types + adapter.create_table_like("target_table", "source_table") adapter.cursor.execute.assert_called_once_with( - 'CREATE TABLE IF NOT EXISTS "target_table" LIKE "source_table"' + 'CREATE TABLE IF NOT EXISTS "target_table" ("cola" INT, "colb" TEXT)' ) @@ -2428,7 +3347,7 @@ def test_clone_table(make_mocked_engine_adapter: t.Callable): adapter.clone_table("target_table", "source_table") adapter.cursor.execute.assert_called_once_with( - "CREATE TABLE `target_table` CLONE `source_table`" + "CREATE TABLE IF NOT EXISTS `target_table` CLONE `source_table`" ) @@ -2454,6 +3373,39 @@ def test_ctas_pandas(make_mocked_engine_adapter: t.Callable): ] +def 
test_ctas_pandas_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + df = pd.DataFrame({"a": [1, 2, 3], "ignored_source": [4, 5, 6]}) + adapter.ctas( + "test_table", + df, + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("INT"), + }, + source_columns=["a", "ignored_source"], + ) + assert to_sql_calls(adapter) == [ + 'CREATE TABLE IF NOT EXISTS "test_table" AS SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" FROM (SELECT CAST("a" AS INT) AS "a", CAST(NULL AS INT) AS "b" FROM (VALUES (1), (2), (3)) AS "t"("a")) AS "_subquery"', + ] + + +def test_ctas_query_source_columns(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.ctas( + "test_table", + parse_one("SELECT a, ignored_source FROM tbl"), + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("INT"), + }, + source_columns=["a", "ignored_source"], + ) + assert to_sql_calls(adapter) == [ + 'CREATE TABLE IF NOT EXISTS "test_table" AS SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" FROM (SELECT "a", CAST(NULL AS INT) AS "b" FROM (SELECT "a", "ignored_source" FROM "tbl") AS "select_source_columns") AS "_subquery"', + ] + + def test_drop_view(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) @@ -2504,6 +3456,7 @@ def test_drop_view(make_mocked_engine_adapter: t.Callable): ) def test_drop_schema(kwargs, expected, make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.SUPPORTED_DROP_CASCADE_OBJECT_KINDS = ["SCHEMA"] adapter.drop_schema(**kwargs) @@ -2593,7 +3546,7 @@ def test_insert_overwrite_by_partition_query( table_name, parse_one("SELECT a, ds, b FROM tbl"), partitioned_by=[d.parse_one(k) for k in partitioned_by], - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "ds": exp.DataType.build("DATETIME"), 
"b": exp.DataType.build("boolean"), @@ -2633,7 +3586,7 @@ def test_insert_overwrite_by_partition_query_insert_overwrite_strategy( d.parse_one("DATETIME_TRUNC(ds, MONTH)"), d.parse_one("b"), ], - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "ds": exp.DataType.build("DATETIME"), "b": exp.DataType.build("boolean"), @@ -2644,3 +3597,576 @@ def test_insert_overwrite_by_partition_query_insert_overwrite_strategy( assert sql_calls == [ 'INSERT OVERWRITE TABLE "test_schema"."test_table" ("a", "ds", "b") SELECT "a", "ds", "b" FROM "tbl"' ] + + +def test_log_sql(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(EngineAdapter) + + mock_logger = mocker.patch("sqlmesh.core.engine_adapter.base.logger") + + df = pd.DataFrame({"id": [1, 2, 3], "value": ["test1", "test2", "test3"]}) + + adapter.execute(parse_one("SELECT 1")) + adapter.execute(parse_one("INSERT INTO test SELECT * FROM source")) + adapter.execute(parse_one("INSERT INTO test (id, value) VALUES (1, 'test')")) + adapter.insert_append("test", df) + adapter.replace_query("test", df) + + assert mock_logger.log.call_count == 5 + assert mock_logger.log.call_args_list[0][0][2] == "SELECT 1" + assert mock_logger.log.call_args_list[1][0][2] == 'INSERT INTO "test" SELECT * FROM "source"' + assert ( + mock_logger.log.call_args_list[2][0][2] + == 'INSERT INTO "test" ("id", "value") VALUES ""' + ) + assert ( + mock_logger.log.call_args_list[3][0][2] + == 'INSERT INTO "test" ("id", "value") SELECT CAST("id" AS BIGINT) AS "id", CAST("value" AS TEXT) AS "value" FROM (VALUES "") AS "t"("id", "value")' + ) + assert ( + mock_logger.log.call_args_list[4][0][2] + == 'CREATE OR REPLACE TABLE "test" AS SELECT CAST("id" AS BIGINT) AS "id", CAST("value" AS TEXT) AS "value" FROM (SELECT CAST("id" AS BIGINT) AS "id", CAST("value" AS TEXT) AS "value" FROM (VALUES "") AS "t"("id", "value")) AS "_subquery"' + ) + + +@pytest.mark.parametrize( + "columns, 
source_columns, expected", + [ + (["a", "b"], None, 'SELECT "a", "b"'), + (["a", "b"], ["a"], 'SELECT "a", NULL AS "b"'), + (["a", "b"], ["a", "b"], 'SELECT "a", "b"'), + (["a", "b"], ["c", "d"], 'SELECT NULL AS "a", NULL AS "b"'), + (["a", "b"], [], 'SELECT "a", "b"'), + ], +) +def test_select_columns( + columns: t.List[str], source_columns: t.Optional[t.List[str]], expected: str +) -> None: + assert ( + EngineAdapter._select_columns( + columns, + source_columns, + ).sql() + == expected + ) + + +@pytest.mark.parametrize( + "columns_to_types, source_columns, expected", + [ + ( + { + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("TEXT"), + }, + None, + [ + 'CAST("a" AS INT) AS "a"', + 'CAST("b" AS TEXT) AS "b"', + ], + ), + ( + { + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("TEXT"), + }, + ["a"], + [ + 'CAST("a" AS INT) AS "a"', + 'CAST(NULL AS TEXT) AS "b"', + ], + ), + ( + { + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("TEXT"), + }, + ["b", "c"], + [ + 'CAST(NULL AS INT) AS "a"', + 'CAST("b" AS TEXT) AS "b"', + ], + ), + ], +) +def test_casted_columns( + columns_to_types: t.Dict[str, exp.DataType], source_columns: t.List[str], expected: t.List[str] +) -> None: + assert [ + x.sql() for x in EngineAdapter._casted_columns(columns_to_types, source_columns) + ] == expected + + +def test_data_object_cache_get_data_objects( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + table1 = DataObject(catalog=None, schema="test_schema", name="table1", type="table") + table2 = DataObject(catalog=None, schema="test_schema", name="table2", type="table") + + mock_get_data_objects = mocker.patch.object( + adapter, "_get_data_objects", return_value=[table1, table2] + ) + + result1 = adapter.get_data_objects("test_schema", {"table1", "table2"}, safe_to_cache=True) + assert len(result1) == 2 + assert mock_get_data_objects.call_count == 1 + + 
result2 = adapter.get_data_objects("test_schema", {"table1", "table2"}, safe_to_cache=True) + assert len(result2) == 2 + assert mock_get_data_objects.call_count == 1 # Should not increase + + result3 = adapter.get_data_objects("test_schema", {"table1"}) + assert len(result3) == 1 + assert result3[0].name == "table1" + assert mock_get_data_objects.call_count == 1 # Should not increase + + +def test_data_object_cache_get_data_objects_bypasses_cache( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + table1 = DataObject(catalog=None, schema="test_schema", name="table1", type="table") + table2 = DataObject(catalog=None, schema="test_schema", name="table2", type="table") + + mock_get_data_objects = mocker.patch.object( + adapter, "_get_data_objects", return_value=[table1, table2] + ) + + assert adapter.get_data_objects("test_schema") + assert adapter.get_data_objects("test_schema", {"table1", "table2"}) + assert adapter.get_data_objects("test_schema", {"table1", "table2"}) + assert adapter.get_data_objects("test_schema", {"table1"}) + assert adapter.get_data_object("test_schema.table1") is not None + + mock_get_data_objects.return_value = [] + assert not adapter.get_data_objects("test_schema") + assert not adapter.get_data_objects("test_schema", {"missing"}) + assert not adapter.get_data_objects("test_schema", {"missing"}) + assert adapter.get_data_object("test_schema.missing") is None + + # None of the calls should've been cached + assert mock_get_data_objects.call_count == 9 + assert not adapter._data_object_cache + + +def test_data_object_cache_get_data_objects_no_object_names( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + table1 = DataObject(catalog=None, schema="test_schema", name="table1", type="table") + table2 = DataObject(catalog=None, 
schema="test_schema", name="table2", type="table") + + mock_get_data_objects = mocker.patch.object( + adapter, "_get_data_objects", return_value=[table1, table2] + ) + + result1 = adapter.get_data_objects("test_schema", safe_to_cache=True) + assert len(result1) == 2 + assert mock_get_data_objects.call_count == 1 + + result2 = adapter.get_data_objects("test_schema", {"table1", "table2"}, safe_to_cache=True) + assert len(result2) == 2 + assert mock_get_data_objects.call_count == 1 # Should not increase + + +def test_data_object_cache_get_data_object( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + table = DataObject(catalog=None, schema="test_schema", name="test_table", type="table") + + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[table]) + + result1 = adapter.get_data_object("test_schema.test_table", safe_to_cache=True) + assert result1 is not None + assert result1.name == "test_table" + assert mock_get_data_objects.call_count == 1 + + result2 = adapter.get_data_object("test_schema.test_table", safe_to_cache=True) + assert result2 is not None + assert result2.name == "test_table" + assert mock_get_data_objects.call_count == 1 # Should not increase + + +def test_data_object_cache_cleared_on_drop_table( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + table = DataObject(catalog=None, schema="test_schema", name="test_table", type="table") + + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[table]) + + adapter.get_data_object("test_schema.test_table", safe_to_cache=True) + assert mock_get_data_objects.call_count == 1 + + adapter.drop_table("test_schema.test_table") + + mock_get_data_objects.return_value = [] + result = adapter.get_data_object("test_schema.test_table", 
safe_to_cache=True) + assert result is None + assert mock_get_data_objects.call_count == 2 + + +def test_data_object_cache_cleared_on_drop_view( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + view = DataObject(catalog=None, schema="test_schema", name="test_view", type="view") + + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[view]) + + adapter.get_data_object("test_schema.test_view", safe_to_cache=True) + assert mock_get_data_objects.call_count == 1 + + adapter.drop_view("test_schema.test_view") + + mock_get_data_objects.return_value = [] + result = adapter.get_data_object("test_schema.test_view", safe_to_cache=True) + assert result is None + assert mock_get_data_objects.call_count == 2 + + +def test_data_object_cache_cleared_on_drop_data_object( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + table = DataObject(catalog=None, schema="test_schema", name="test_table", type="table") + + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[table]) + + adapter.get_data_object("test_schema.test_table", safe_to_cache=True) + assert mock_get_data_objects.call_count == 1 + + adapter.drop_data_object(table) + + mock_get_data_objects.return_value = [] + result = adapter.get_data_object("test_schema.test_table", safe_to_cache=True) + assert result is None + assert mock_get_data_objects.call_count == 2 + + +def test_data_object_cache_cleared_on_create_table( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + from sqlglot import exp + + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + # Initially cache that table doesn't exist + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[]) + result = 
adapter.get_data_object("test_schema.test_table", safe_to_cache=True) + assert result is None + assert mock_get_data_objects.call_count == 1 + + # Create the table + table = DataObject(catalog=None, schema="test_schema", name="test_table", type="table") + mock_get_data_objects.return_value = [table] + adapter.create_table( + "test_schema.test_table", + {"col1": exp.DataType.build("INT")}, + ) + + # Cache should be cleared, so next get_data_object should call _get_data_objects again + result = adapter.get_data_object("test_schema.test_table", safe_to_cache=True) + assert result is not None + assert mock_get_data_objects.call_count == 2 + + +def test_data_object_cache_cleared_on_create_view( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + from sqlglot import parse_one + + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + # Initially cache that view doesn't exist + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[]) + result = adapter.get_data_object("test_schema.test_view", safe_to_cache=True) + assert result is None + assert mock_get_data_objects.call_count == 1 + + # Create the view + view = DataObject(catalog=None, schema="test_schema", name="test_view", type="view") + mock_get_data_objects.return_value = [view] + adapter.create_view( + "test_schema.test_view", + parse_one("SELECT 1 AS col1"), + ) + + # Cache should be cleared, so next get_data_object should call _get_data_objects again + result = adapter.get_data_object("test_schema.test_view", safe_to_cache=True) + assert result is not None + assert mock_get_data_objects.call_count == 2 + + +def test_data_object_cache_cleared_on_clone_table( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + from sqlmesh.core.engine_adapter.snowflake import SnowflakeEngineAdapter + + adapter = make_mocked_engine_adapter( + SnowflakeEngineAdapter, patch_get_data_objects=False, default_catalog="test_catalog" + ) + + 
# Initially cache that target table doesn't exist + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[]) + result = adapter.get_data_object("test_schema.test_target", safe_to_cache=True) + assert result is None + assert mock_get_data_objects.call_count == 1 + + # Clone the table + target_table = DataObject( + catalog="test_catalog", schema="test_schema", name="test_target", type="table" + ) + mock_get_data_objects.return_value = [target_table] + adapter.clone_table("test_schema.test_target", "test_schema.test_source") + + # Cache should be cleared, so next get_data_object should call _get_data_objects again + result = adapter.get_data_object("test_schema.test_target", safe_to_cache=True) + assert result is not None + assert mock_get_data_objects.call_count == 2 + + +def test_data_object_cache_with_catalog( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + from sqlmesh.core.engine_adapter.snowflake import SnowflakeEngineAdapter + + adapter = make_mocked_engine_adapter( + SnowflakeEngineAdapter, patch_get_data_objects=False, default_catalog="test_catalog" + ) + + table = DataObject( + catalog="test_catalog", schema="test_schema", name="test_table", type="table" + ) + + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[table]) + + result1 = adapter.get_data_object("test_catalog.test_schema.test_table", safe_to_cache=True) + assert result1 is not None + assert result1.catalog == "test_catalog" + assert mock_get_data_objects.call_count == 1 + + result2 = adapter.get_data_object("test_catalog.test_schema.test_table", safe_to_cache=True) + assert result2 is not None + assert result2.catalog == "test_catalog" + assert mock_get_data_objects.call_count == 1 # Should not increase + + +def test_data_object_cache_partial_cache_hit( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + 
+ table1 = DataObject(catalog=None, schema="test_schema", name="table1", type="table") + table2 = DataObject(catalog=None, schema="test_schema", name="table2", type="table") + table3 = DataObject(catalog=None, schema="test_schema", name="table3", type="table") + + mock_get_data_objects = mocker.patch.object( + adapter, "_get_data_objects", return_value=[table1, table2] + ) + + adapter.get_data_objects("test_schema", {"table1", "table2"}, safe_to_cache=True) + assert mock_get_data_objects.call_count == 1 + + mock_get_data_objects.return_value = [table3] + result = adapter.get_data_objects("test_schema", {"table1", "table3"}, safe_to_cache=True) + + assert len(result) == 2 + assert {obj.name for obj in result} == {"table1", "table3"} + assert mock_get_data_objects.call_count == 2 # Called again for table3 + + +def test_data_object_cache_get_data_objects_missing_objects( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + table1 = DataObject(catalog=None, schema="test_schema", name="table1", type="table") + table2 = DataObject(catalog=None, schema="test_schema", name="table2", type="table") + + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[]) + + result1 = adapter.get_data_objects("test_schema", {"table1", "table2"}, safe_to_cache=True) + assert not result1 + assert mock_get_data_objects.call_count == 1 + + result2 = adapter.get_data_objects("test_schema", {"table1", "table2"}, safe_to_cache=True) + assert not result2 + assert mock_get_data_objects.call_count == 1 # Should not increase + + result3 = adapter.get_data_objects("test_schema", {"table1"}, safe_to_cache=True) + assert not result3 + assert mock_get_data_objects.call_count == 1 # Should not increase + + +def test_data_object_cache_cleared_on_rename_table( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = 
make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + old_table = DataObject(catalog=None, schema="test_schema", name="old_table", type="table") + mock_get_data_objects = mocker.patch.object( + adapter, "_get_data_objects", return_value=[old_table] + ) + + result = adapter.get_data_object("test_schema.old_table", safe_to_cache=True) + assert result is not None + assert result.name == "old_table" + assert mock_get_data_objects.call_count == 1 + + new_table = DataObject(catalog=None, schema="test_schema", name="new_table", type="table") + mock_get_data_objects.return_value = [new_table] + adapter.rename_table("test_schema.old_table", "test_schema.new_table") + + mock_get_data_objects.return_value = [] + result = adapter.get_data_object("test_schema.old_table", safe_to_cache=True) + assert result is None + assert mock_get_data_objects.call_count == 2 + + mock_get_data_objects.return_value = [new_table] + result = adapter.get_data_object("test_schema.new_table", safe_to_cache=True) + assert result is not None + assert result.name == "new_table" + assert mock_get_data_objects.call_count == 3 + + +def test_data_object_cache_cleared_on_create_table_like( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + from sqlglot import exp + + adapter = make_mocked_engine_adapter(EngineAdapter, patch_get_data_objects=False) + + columns_to_types = { + "col1": exp.DataType.build("INT"), + "col2": exp.DataType.build("TEXT"), + } + mocker.patch.object(adapter, "columns", return_value=columns_to_types) + + mock_get_data_objects = mocker.patch.object(adapter, "_get_data_objects", return_value=[]) + result = adapter.get_data_object("test_schema.target_table", safe_to_cache=True) + assert result is None + assert mock_get_data_objects.call_count == 1 + + target_table = DataObject(catalog=None, schema="test_schema", name="target_table", type="table") + mock_get_data_objects.return_value = [target_table] + 
adapter.create_table_like("test_schema.target_table", "test_schema.source_table") + + result = adapter.get_data_object("test_schema.target_table", safe_to_cache=True) + assert result is not None + assert result.name == "target_table" + assert mock_get_data_objects.call_count == 2 + + +def test_diff_grants_configs(): + new = {"SELECT": ["u1", "u2"], "INSERT": ["u1"]} + old = {"SELECT": ["u1", "u3"], "update": ["u1"]} + + additions, removals = EngineAdapter._diff_grants_configs(new, old) + + assert additions.get("SELECT") and set(additions["SELECT"]) == {"u2"} + assert removals.get("SELECT") and set(removals["SELECT"]) == {"u3"} + + assert additions.get("INSERT") and set(additions["INSERT"]) == {"u1"} + assert removals.get("update") and set(removals["update"]) == {"u1"} + + for perm, grantees in additions.items(): + assert set(grantees).isdisjoint(set(old.get(perm, []))) + for perm, grantees in removals.items(): + assert set(grantees).isdisjoint(set(new.get(perm, []))) + + +def test_diff_grants_configs_empty_new(): + new = {} + old = {"SELECT": ["u1", "u2"], "INSERT": ["u3"]} + + additions, removals = EngineAdapter._diff_grants_configs(new, old) + + assert additions == {} + assert removals == old + + +def test_diff_grants_configs_empty_old(): + new = {"SELECT": ["u1", "u2"], "INSERT": ["u3"]} + old = {} + + additions, removals = EngineAdapter._diff_grants_configs(new, old) + + assert additions == new + assert removals == {} + + +def test_diff_grants_configs_identical(): + grants = {"SELECT": ["u1", "u2"], "INSERT": ["u3"]} + + additions, removals = EngineAdapter._diff_grants_configs(grants, grants) + + assert additions == {} + assert removals == {} + + +def test_diff_grants_configs_none_configs(): + grants = {"SELECT": ["u1"]} + + additions, removals = EngineAdapter._diff_grants_configs(grants, {}) + assert additions == grants + assert removals == {} + + additions, removals = EngineAdapter._diff_grants_configs({}, grants) + assert additions == {} + assert removals == 
grants + + additions, removals = EngineAdapter._diff_grants_configs({}, {}) + assert additions == {} + assert removals == {} + + +def test_diff_grants_configs_duplicate_grantees(): + new = {"SELECT": ["u1", "u2", "u1"]} + old = {"SELECT": ["u2", "u3", "u2"]} + + additions, removals = EngineAdapter._diff_grants_configs(new, old) + + assert additions["SELECT"] == ["u1", "u1"] + assert removals["SELECT"] == ["u3"] + + +def test_diff_grants_configs_case_sensitive(): + new = {"select": ["u1"], "SELECT": ["u2"]} + old = {"Select": ["u3"]} + + additions, removals = EngineAdapter._diff_grants_configs(new, old) + + assert set(additions.keys()) == {"select", "SELECT"} + assert set(removals.keys()) == {"Select"} + assert additions["select"] == ["u1"] + assert additions["SELECT"] == ["u2"] + assert removals["Select"] == ["u3"] + + +def test_sync_grants_config_unsupported_engine(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.SUPPORTS_GRANTS = False + + relation = exp.to_table("test_table") + grants_config = {"SELECT": ["user1"]} + + with pytest.raises(NotImplementedError, match="Engine does not support grants"): + adapter.sync_grants_config(relation, grants_config) + + +def test_get_current_grants_config_not_implemented(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(EngineAdapter) + relation = exp.to_table("test_table") + + with pytest.raises(NotImplementedError): + adapter._get_current_grants_config(relation) diff --git a/tests/core/engine_adapter/test_base_postgres.py b/tests/core/engine_adapter/test_base_postgres.py index e586dde7b3..f286c47c56 100644 --- a/tests/core/engine_adapter/test_base_postgres.py +++ b/tests/core/engine_adapter/test_base_postgres.py @@ -3,6 +3,7 @@ from unittest.mock import call import pytest +from pytest_mock.plugin import MockerFixture from sqlglot import exp, parse_one from sqlmesh.core.engine_adapter.base_postgres import BasePostgresEngineAdapter @@ -17,7 
+18,7 @@ def test_columns(make_mocked_engine_adapter: t.Callable): resp = adapter.columns("db.table") adapter.cursor.execute.assert_called_once_with( 'SELECT "attname" AS "column_name", ' - '"pg_catalog".FORMAT_TYPE("atttypid", "atttypmod") AS "data_type" ' + '"pg_catalog".format_type("atttypid", "atttypmod") AS "data_type" ' 'FROM "pg_catalog"."pg_attribute" ' 'JOIN "pg_catalog"."pg_class" ON "pg_class"."oid" = "attrelid" ' 'JOIN "pg_catalog"."pg_namespace" ON "pg_namespace"."oid" = "relnamespace" ' @@ -75,3 +76,26 @@ def test_drop_view(make_mocked_engine_adapter: t.Callable): call('DROP VIEW IF EXISTS "db"."view"'), ] ) + + +def test_get_current_schema(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(BasePostgresEngineAdapter) + + fetchone_mock = mocker.patch.object(adapter, "fetchone", return_value=("test_schema",)) + result = adapter._get_current_schema() + + assert result == "test_schema" + fetchone_mock.assert_called_once() + executed_query = fetchone_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="postgres") + assert executed_sql == "SELECT CURRENT_SCHEMA" + + fetchone_mock.reset_mock() + fetchone_mock.return_value = None + result = adapter._get_current_schema() + assert result == "public" + + fetchone_mock.reset_mock() + fetchone_mock.return_value = (None,) # search_path = '' or 'nonexistent_schema' + result = adapter._get_current_schema() + assert result == "public" diff --git a/tests/core/engine_adapter/test_bigquery.py b/tests/core/engine_adapter/test_bigquery.py index 2906e1fc43..9a6bc7d851 100644 --- a/tests/core/engine_adapter/test_bigquery.py +++ b/tests/core/engine_adapter/test_bigquery.py @@ -1,8 +1,8 @@ # type: ignore -import sys import typing as t +from datetime import datetime -import pandas as pd +import pandas as pd # noqa: TID253 import pytest from google.cloud import bigquery from pytest_mock.plugin import MockerFixture @@ -13,12 +13,21 @@ import sqlmesh.core.dialect as 
d from sqlmesh.core.engine_adapter import BigQueryEngineAdapter from sqlmesh.core.engine_adapter.bigquery import select_partitions_expr +from sqlmesh.core.engine_adapter.shared import DataObjectType from sqlmesh.core.node import IntervalUnit from sqlmesh.utils import AttributeDict +from sqlmesh.utils.errors import SQLMeshError pytestmark = [pytest.mark.bigquery, pytest.mark.engine] +@pytest.fixture +def adapter(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture) -> BigQueryEngineAdapter: + mocked_adapter = make_mocked_engine_adapter(BigQueryEngineAdapter) + mocker.patch("sqlmesh.core.engine_adapter.bigquery.BigQueryEngineAdapter.execute") + return mocked_adapter + + def test_insert_overwrite_by_time_partition_query( make_mocked_engine_adapter: t.Callable, mocker: MockerFixture ): @@ -33,7 +42,7 @@ def test_insert_overwrite_by_time_partition_query( end="2022-01-05", time_formatter=lambda x, _: exp.Literal.string(x.strftime("%Y-%m-%d")), time_column="ds", - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "ds": exp.DataType.build("string"), }, @@ -51,7 +60,7 @@ def test_insert_overwrite_by_partition_query( execute_mock = mocker.patch( "sqlmesh.core.engine_adapter.bigquery.BigQueryEngineAdapter.execute" ) - + adapter._default_catalog = "test_project" temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") table_name = "test_schema.test_table" temp_table_id = "abcdefgh" @@ -63,7 +72,7 @@ def test_insert_overwrite_by_partition_query( partitioned_by=[ d.parse_one("DATETIME_TRUNC(ds, MONTH)"), ], - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "ds": exp.DataType.build("DATETIME"), }, @@ -73,7 +82,7 @@ def test_insert_overwrite_by_partition_query( assert sql_calls == [ "CREATE SCHEMA IF NOT EXISTS `test_schema`", f"CREATE TABLE IF NOT EXISTS `test_schema`.`__temp_test_table_{temp_table_id}` PARTITION BY DATETIME_TRUNC(`ds`, MONTH) AS SELECT `a`, `ds` FROM `tbl`", - 
f"DECLARE _sqlmesh_target_partitions_ ARRAY DEFAULT (SELECT ARRAY_AGG(PARSE_DATETIME('%Y%m', partition_id)) FROM `test_schema`.INFORMATION_SCHEMA.PARTITIONS WHERE table_name = '__temp_test_table_{temp_table_id}' AND NOT partition_id IS NULL AND partition_id <> '__NULL__');", + f"DECLARE _sqlmesh_target_partitions_ ARRAY DEFAULT (SELECT ARRAY_AGG(PARSE_DATETIME('%Y%m', partition_id)) FROM `test_project`.`test_schema`.`INFORMATION_SCHEMA.PARTITIONS` AS PARTITIONS WHERE table_name = '__temp_test_table_{temp_table_id}' AND NOT partition_id IS NULL AND partition_id <> '__NULL__');", f"MERGE INTO `test_schema`.`test_table` AS `__MERGE_TARGET__` USING (SELECT `a`, `ds` FROM (SELECT * FROM `test_schema`.`__temp_test_table_{temp_table_id}`) AS `_subquery` WHERE DATETIME_TRUNC(`ds`, MONTH) IN UNNEST(`_sqlmesh_target_partitions_`)) AS `__MERGE_SOURCE__` ON FALSE WHEN NOT MATCHED BY SOURCE AND DATETIME_TRUNC(`ds`, MONTH) IN UNNEST(`_sqlmesh_target_partitions_`) THEN DELETE WHEN NOT MATCHED THEN INSERT (`a`, `ds`) VALUES (`a`, `ds`)", f"DROP TABLE IF EXISTS `test_schema`.`__temp_test_table_{temp_table_id}`", ] @@ -94,7 +103,7 @@ def test_insert_overwrite_by_partition_query_unknown_column_types( "a": exp.DataType.build("int"), "ds": exp.DataType.build("DATETIME"), } - + adapter._default_catalog = "test_project" temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") table_name = "test_schema.test_table" temp_table_id = "abcdefgh" @@ -106,21 +115,19 @@ def test_insert_overwrite_by_partition_query_unknown_column_types( partitioned_by=[ d.parse_one("DATETIME_TRUNC(ds, MONTH)"), ], - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("unknown"), "ds": exp.DataType.build("UNKNOWN"), }, ) - columns_mock.assert_called_once_with( - exp.to_table(f"test_schema.__temp_test_table_{temp_table_id}") - ) + columns_mock.assert_called_once_with(table_name) sql_calls = _to_sql_calls(execute_mock) assert sql_calls == [ "CREATE SCHEMA IF NOT 
EXISTS `test_schema`", f"CREATE TABLE IF NOT EXISTS `test_schema`.`__temp_test_table_{temp_table_id}` PARTITION BY DATETIME_TRUNC(`ds`, MONTH) AS SELECT `a`, `ds` FROM `tbl`", - f"DECLARE _sqlmesh_target_partitions_ ARRAY DEFAULT (SELECT ARRAY_AGG(PARSE_DATETIME('%Y%m', partition_id)) FROM `test_schema`.INFORMATION_SCHEMA.PARTITIONS WHERE table_name = '__temp_test_table_{temp_table_id}' AND NOT partition_id IS NULL AND partition_id <> '__NULL__');", + f"DECLARE _sqlmesh_target_partitions_ ARRAY DEFAULT (SELECT ARRAY_AGG(PARSE_DATETIME('%Y%m', partition_id)) FROM `test_project`.`test_schema`.`INFORMATION_SCHEMA.PARTITIONS` AS PARTITIONS WHERE table_name = '__temp_test_table_{temp_table_id}' AND NOT partition_id IS NULL AND partition_id <> '__NULL__');", f"MERGE INTO `test_schema`.`test_table` AS `__MERGE_TARGET__` USING (SELECT `a`, `ds` FROM (SELECT * FROM `test_schema`.`__temp_test_table_{temp_table_id}`) AS `_subquery` WHERE DATETIME_TRUNC(`ds`, MONTH) IN UNNEST(`_sqlmesh_target_partitions_`)) AS `__MERGE_SOURCE__` ON FALSE WHEN NOT MATCHED BY SOURCE AND DATETIME_TRUNC(`ds`, MONTH) IN UNNEST(`_sqlmesh_target_partitions_`) THEN DELETE WHEN NOT MATCHED THEN INSERT (`a`, `ds`) VALUES (`a`, `ds`)", f"DROP TABLE IF EXISTS `test_schema`.`__temp_test_table_{temp_table_id}`", ] @@ -173,7 +180,7 @@ def temp_table_exists(table: exp.Table) -> bool: end="2022-01-05", time_formatter=lambda x, _: exp.Literal.string(x.strftime("%Y-%m-%d")), time_column="ds", - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "ds": exp.DataType.build("string"), }, @@ -181,15 +188,13 @@ def temp_table_exists(table: exp.Table) -> bool: assert execute_mock.call_count == 2 assert retry_resp.call_count == 1 assert db_call_mock.call_count == 1 + create_temp_table = db_call_mock.call_args_list[0] load_temp_table = retry_resp.call_args_list[0] merge, drop_temp_table = execute_mock.call_args_list merge_sql = merge[0][0] drop_temp_table_sql = drop_temp_table[0][0] - if 
sys.version_info < (3, 8): - create_temp_table.kwargs = create_temp_table[1] - load_temp_table.kwargs = load_temp_table[1] - drop_temp_table.kwargs = drop_temp_table[1] + assert create_temp_table.kwargs == { "exists_ok": False, "table": get_temp_bq_table.return_value, @@ -204,7 +209,7 @@ def temp_table_exists(table: exp.Table) -> bool: assert load_temp_table.kwargs["job_config"].write_disposition is None assert ( merge_sql.sql(dialect="bigquery") - == "MERGE INTO test_table AS __MERGE_TARGET__ USING (SELECT `a`, `ds` FROM (SELECT `a`, `ds` FROM project.dataset.temp_table) AS _subquery WHERE ds BETWEEN '2022-01-01' AND '2022-01-05') AS __MERGE_SOURCE__ ON FALSE WHEN NOT MATCHED BY SOURCE AND ds BETWEEN '2022-01-01' AND '2022-01-05' THEN DELETE WHEN NOT MATCHED THEN INSERT (a, ds) VALUES (a, ds)" + == "MERGE INTO test_table AS __MERGE_TARGET__ USING (SELECT `a`, `ds` FROM (SELECT CAST(`a` AS INT64) AS `a`, CAST(`ds` AS STRING) AS `ds` FROM project.dataset.temp_table) AS _subquery WHERE ds BETWEEN '2022-01-01' AND '2022-01-05') AS __MERGE_SOURCE__ ON FALSE WHEN NOT MATCHED BY SOURCE AND ds BETWEEN '2022-01-01' AND '2022-01-05' THEN DELETE WHEN NOT MATCHED THEN INSERT (a, ds) VALUES (a, ds)" ) assert ( drop_temp_table_sql.sql(dialect="bigquery") @@ -294,7 +299,7 @@ def temp_table_exists(table: exp.Table) -> bool: ] sql_calls = _to_sql_calls(execute_mock) assert sql_calls == [ - "CREATE OR REPLACE TABLE `test_table` AS SELECT CAST(`a` AS INT64) AS `a`, CAST(`b` AS INT64) AS `b` FROM (SELECT `a`, `b` FROM `project`.`dataset`.`temp_table`) AS `_subquery`", + "CREATE OR REPLACE TABLE `test_table` AS SELECT CAST(`a` AS INT64) AS `a`, CAST(`b` AS INT64) AS `b` FROM (SELECT CAST(`a` AS INT64) AS `a`, CAST(`b` AS INT64) AS `b` FROM `project`.`dataset`.`temp_table`) AS `_subquery`", "DROP TABLE IF EXISTS `project`.`dataset`.`temp_table`", ] @@ -322,7 +327,7 @@ def test_create_table_date_partition( {"a": exp.DataType.build("int"), "b": exp.DataType.build("int")}, 
partitioned_by=partition_by_cols, partition_interval_unit=IntervalUnit.DAY, - clustered_by=["b"], + clustered_by=[exp.column("b")], ) sql_calls = _to_sql_calls(execute_mock) @@ -430,7 +435,7 @@ def test_merge(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): adapter.merge( target_table="target", source_table=parse_one("SELECT id, ts, val FROM source"), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.Type.INT, "ts": exp.DataType.Type.TIMESTAMP, "val": exp.DataType.Type.INT, @@ -483,11 +488,17 @@ def temp_table_exists(table: exp.Table) -> bool: retry_resp_call.errors = None retry_mock.return_value = retry_resp db_call_mock.return_value = AttributeDict({"errors": None}) - df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) + df = pd.DataFrame( + { + "id": [1, 2, 3], + "ts": ["2025-01-01 00:00:00", "2025-01-01 00:00:00", "2025-01-01 00:00:00"], + "val": [7, 8, 9], + } + ) adapter.merge( target_table="target", source_table=df, - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "ts": exp.DataType.build("TIMESTAMP"), "val": exp.DataType.build("INT"), @@ -497,7 +508,7 @@ def temp_table_exists(table: exp.Table) -> bool: sql_calls = _to_sql_calls(execute_mock, identify=False) assert sql_calls == [ - "MERGE INTO target AS __MERGE_TARGET__ USING (SELECT `id`, `ts`, `val` FROM project.dataset.temp_table) AS __MERGE_SOURCE__ ON __MERGE_TARGET__.id = __MERGE_SOURCE__.id " + "MERGE INTO target AS __MERGE_TARGET__ USING (SELECT CAST(`id` AS INT64) AS `id`, CAST(`ts` AS DATETIME) AS `ts`, CAST(`val` AS INT64) AS `val` FROM project.dataset.temp_table) AS __MERGE_SOURCE__ ON __MERGE_TARGET__.id = __MERGE_SOURCE__.id " "WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.id = __MERGE_SOURCE__.id, __MERGE_TARGET__.ts = __MERGE_SOURCE__.ts, __MERGE_TARGET__.val = __MERGE_SOURCE__.val " "WHEN NOT MATCHED THEN INSERT (id, ts, val) VALUES (__MERGE_SOURCE__.id, __MERGE_SOURCE__.ts, __MERGE_SOURCE__.val)", "DROP TABLE IF EXISTS 
project.dataset.temp_table", @@ -531,6 +542,7 @@ def test_begin_end_session(mocker: MockerFixture): adapter = BigQueryEngineAdapter(lambda: connection_mock, job_retries=0) + # starting a session without session properties with adapter.session({}): assert adapter._connection_pool.get_attribute("session_id") is not None adapter.execute("SELECT 2;") @@ -551,17 +563,40 @@ def test_begin_end_session(mocker: MockerFixture): assert execute_b_call[1]["query"] == "SELECT 3;" assert not execute_b_call[1]["job_config"].connection_properties + # starting a new session with session property query_label and array value + with adapter.session({"query_label": parse_one("[('key1', 'value1'), ('key2', 'value2')]")}): + adapter.execute("SELECT 4;") + begin_new_session_call = connection_mock._client.query.call_args_list[3] + assert begin_new_session_call[0][0] == 'SET @@query_label = "key1:value1,key2:value2";SELECT 1;' + + # starting a new session with session property query_label and Paren value + with adapter.session({"query_label": parse_one("(('key1', 'value1'))")}): + adapter.execute("SELECT 5;") + begin_new_session_call = connection_mock._client.query.call_args_list[5] + assert begin_new_session_call[0][0] == 'SET @@query_label = "key1:value1";SELECT 1;' + + # test invalid query_label value + with pytest.raises( + SQLMeshError, + match="Invalid value for `session_properties.query_label`. 
Must be an array or tuple.", + ): + with adapter.session({"query_label": parse_one("'key1:value1'")}): + adapter.execute("SELECT 6;") + def _to_sql_calls(execute_mock: t.Any, identify: bool = True) -> t.List[str]: + if isinstance(execute_mock, BigQueryEngineAdapter): + execute_mock = execute_mock.execute output = [] for call in execute_mock.call_args_list: - value = call[0][0] - sql = ( - value.sql(dialect="bigquery", identify=identify) - if isinstance(value, exp.Expression) - else str(value) - ) - output.append(sql) + values = ensure_list(call[0][0]) + for value in values: + sql = ( + value.sql(dialect="bigquery", identify=identify) + if isinstance(value, exp.Expression) + else str(value) + ) + output.append(sql) return output @@ -623,27 +658,167 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture) column_descriptions={"a": long_column_comment}, ) + adapter._create_table_comment( + "test_table", + long_table_comment, + ) + + # Only called if column comments are registered + get_table_mock = mocker.patch( + "sqlmesh.core.engine_adapter.bigquery.BigQueryEngineAdapter._get_table" + ) + + db_call_mock = mocker.patch( + "sqlmesh.core.engine_adapter.bigquery.BigQueryEngineAdapter._db_call" + ) + adapter.create_view( "test_table", parse_one("SELECT a, b FROM source_table"), table_description=long_table_comment, + column_descriptions={"a": long_column_comment}, ) + assert get_table_mock.called - adapter._create_table_comment( + # Bigquery doesn't support column comments for materialized views + db_call_mock.reset_mock() + + adapter.create_view( "test_table", - long_table_comment, + parse_one("SELECT a, b FROM source_table"), + table_description=long_table_comment, + column_descriptions={"a": long_column_comment}, + materialized=True, ) + assert not db_call_mock.called sql_calls = _to_sql_calls(execute_mock) assert sql_calls == [ f"CREATE TABLE IF NOT EXISTS `test_table` (`a` INT64 OPTIONS (description='{truncated_column_comment}'), `b` 
INT64) OPTIONS (description='{truncated_table_comment}')", "CREATE TABLE IF NOT EXISTS `test_table` (`a` INT64 OPTIONS (description='\\\\'), `b` INT64) OPTIONS (description='\\\\')", f"CREATE TABLE IF NOT EXISTS `test_table` (`a` INT64 OPTIONS (description='{truncated_column_comment}'), `b` INT64) OPTIONS (description='{truncated_table_comment}') AS SELECT CAST(`a` AS INT64) AS `a`, CAST(`b` AS INT64) AS `b` FROM (SELECT `a`, `b` FROM `source_table`) AS `_subquery`", - f"CREATE OR REPLACE VIEW `test_table` OPTIONS (description='{truncated_table_comment}') AS SELECT `a`, `b` FROM `source_table`", f"ALTER TABLE `test_table` SET OPTIONS(description = '{truncated_table_comment}')", + f"CREATE OR REPLACE VIEW `test_table` OPTIONS (description='{truncated_table_comment}') AS SELECT `a`, `b` FROM `source_table`", + f"CREATE OR REPLACE MATERIALIZED VIEW `test_table` OPTIONS (description='{truncated_table_comment}') AS SELECT `a`, `b` FROM `source_table`", ] +def test_nested_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(BigQueryEngineAdapter) + + allowed_column_comment_length = BigQueryEngineAdapter.MAX_COLUMN_COMMENT_LENGTH + + execute_mock = mocker.patch( + "sqlmesh.core.engine_adapter.bigquery.BigQueryEngineAdapter.execute" + ) + + nested_columns_to_types = { + "record with space": exp.DataType.build( + "STRUCT<`int_field` INT, `record_field` STRUCT<`sub_record_field` STRUCT<`nest_array` ARRAY>>>", + dialect="bigquery", + ), + "repeated_record": exp.DataType.build( + "ARRAY, `struct_field with space` STRUCT<`nested_field` INT>>>>", + dialect="bigquery", + ), + "same_name_": exp.DataType.build( + "ARRAY>>>", + dialect="bigquery", + ), + } + + long_column_descriptions = { + "record with space": "Top Record", + "record with space.int_field": "Record Int Field", + "record with space.record_field": "Record Nested Record Field", + "record with space.record_field.sub_record_field": "Record Nested Records 
Subfield", + "record with space.record_field.sub_record_field.nest_array": "Record Nested Records Nested Array", + "repeated_record": "Top Repeated Record", + "repeated_record.nested_repeated_record": "Nested Repeated Record", + "repeated_record.nested_repeated_record.int_field": "Nested Repeated Record Int Field", + "repeated_record.nested_repeated_record.array_field": "Nested Repeated Array Field", + "repeated_record.nested_repeated_record.struct_field with space": "Nested Repeated Struct Field", + "repeated_record.nested_repeated_record.struct_field with space.nested_field": "Nested Repeated Record Nested Field", + "same_name_": "Level 1", + "same_name_.same_name_.same_name_": "Level 3", + "same_name_.same_name_.same_name_.same_name_": "4" * allowed_column_comment_length + "X", + } + + adapter.create_table( + "test_table", + target_columns_to_types=nested_columns_to_types, + column_descriptions=long_column_descriptions, + ) + + adapter.ctas( + "test_table", + parse_one("SELECT * FROM source_table"), + target_columns_to_types=nested_columns_to_types, + column_descriptions=long_column_descriptions, + ) + + sql_calls = _to_sql_calls(execute_mock) + + # The comments should be added in the correct nested field with appropriate truncation + assert sql_calls[0] == ( + "CREATE TABLE IF NOT EXISTS `test_table` (" + "`record with space` STRUCT<" + "`int_field` INT64 OPTIONS (description='Record Int Field'), " + "`record_field` STRUCT<" + "`sub_record_field` STRUCT<" + "`nest_array` ARRAY OPTIONS (description='Record Nested Records Nested Array')> " + "OPTIONS (description='Record Nested Records Subfield')> " + "OPTIONS (description='Record Nested Record Field')> " + "OPTIONS (description='Top Record'), " + "`repeated_record` ARRAY OPTIONS (description='Nested Repeated Array Field'), " + "`struct_field with space` STRUCT<" + "`nested_field` INT64 OPTIONS (description='Nested Repeated Record Nested Field')> " + "OPTIONS (description='Nested Repeated Struct Field')>> " + 
"OPTIONS (description='Nested Repeated Record')>> " + "OPTIONS (description='Top Repeated Record'), " + "`same_name_` ARRAY " + "OPTIONS (description='Level 3')>>>> " + "OPTIONS (description='Level 1'))" + ) + + assert sql_calls[1] == ( + "CREATE TABLE IF NOT EXISTS `test_table` (" + "`record with space` STRUCT<" + "`int_field` INT64 OPTIONS (description='Record Int Field'), " + "`record_field` STRUCT<" + "`sub_record_field` STRUCT<" + "`nest_array` ARRAY OPTIONS (description='Record Nested Records Nested Array')> " + "OPTIONS (description='Record Nested Records Subfield')> " + "OPTIONS (description='Record Nested Record Field')> " + "OPTIONS (description='Top Record'), " + "`repeated_record` ARRAY OPTIONS (description='Nested Repeated Array Field'), " + "`struct_field with space` STRUCT<" + "`nested_field` INT64 OPTIONS (description='Nested Repeated Record Nested Field')> " + "OPTIONS (description='Nested Repeated Struct Field')>> " + "OPTIONS (description='Nested Repeated Record')>> " + "OPTIONS (description='Top Repeated Record'), " + "`same_name_` ARRAY " + "OPTIONS (description='Level 3')>>>> " + "OPTIONS (description='Level 1'))" + " AS SELECT CAST(`record with space` AS STRUCT<`int_field` INT64, `record_field` STRUCT<`sub_record_field` STRUCT<`nest_array` ARRAY>>>) AS `record with space`, " + "CAST(`repeated_record` AS ARRAY, `struct_field with space` STRUCT<`nested_field` INT64>>>>>) AS `repeated_record`, " + "CAST(`same_name_` AS ARRAY>>>>) AS `same_name_` " + "FROM (SELECT * FROM `source_table`) AS `_subquery`" + ) + + def test_select_partitions_expr(): assert ( select_partitions_expr( @@ -651,9 +826,9 @@ def test_select_partitions_expr(): "{{ adapter.resolve_identifier(this) }}", "date", granularity="day", - database="{{ target.database }}", + catalog="{{ target.database }}", ) - == "SELECT MAX(PARSE_DATE('%Y%m%d', partition_id)) FROM `{{ target.database }}.{{ adapter.resolve_schema(this) }}.INFORMATION_SCHEMA.PARTITIONS` WHERE table_name = '{{ 
adapter.resolve_identifier(this) }}' AND NOT partition_id IS NULL AND partition_id <> '__NULL__'" + == "SELECT MAX(PARSE_DATE('%Y%m%d', partition_id)) FROM `{{ target.database }}`.`{{ adapter.resolve_schema(this) }}`.`INFORMATION_SCHEMA.PARTITIONS` AS PARTITIONS WHERE table_name = '{{ adapter.resolve_identifier(this) }}' AND NOT partition_id IS NULL AND partition_id <> '__NULL__'" ) assert ( @@ -662,7 +837,7 @@ def test_select_partitions_expr(): "test_table", "int64", ) - == "SELECT MAX(CAST(partition_id AS INT64)) FROM `test_schema`.INFORMATION_SCHEMA.PARTITIONS WHERE table_name = 'test_table' AND NOT partition_id IS NULL AND partition_id <> '__NULL__'" + == "SELECT MAX(CAST(partition_id AS INT64)) FROM `test_schema`.`INFORMATION_SCHEMA.PARTITIONS` AS PARTITIONS WHERE table_name = 'test_table' AND NOT partition_id IS NULL AND partition_id <> '__NULL__'" ) @@ -755,3 +930,453 @@ def test_view_properties(make_mocked_engine_adapter: t.Callable, mocker: MockerF "CREATE OR REPLACE VIEW `test_table` OPTIONS (description='some description', labels=[('test-view-label', 'label-view-value')]) AS SELECT 1", "CREATE OR REPLACE VIEW `test_table` AS SELECT 1", ] + + +def test_materialized_view_properties( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(BigQueryEngineAdapter) + execute_mock = mocker.patch( + "sqlmesh.core.engine_adapter.bigquery.BigQueryEngineAdapter.execute" + ) + + adapter.create_view( + "test_table", + parse_one("SELECT 1"), + materialized=True, + materialized_properties={ + "partitioned_by": [exp.column("ds")], + "clustered_by": [exp.column("a")], + "partition_interval_unit": IntervalUnit.DAY, + }, + ) + + sql_calls = _to_sql_calls(execute_mock) + # https://cloud.google.com/bigquery/docs/materialized-views-create#example_1 + assert sql_calls == [ + "CREATE OR REPLACE MATERIALIZED VIEW `test_table` PARTITION BY `ds` CLUSTER BY `a` AS SELECT 1", + ] + + +def 
test_nested_fields_update(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(BigQueryEngineAdapter) + + current_schema = [ + bigquery.SchemaField( + "user", + "RECORD", + "NULLABLE", + fields=( + bigquery.SchemaField("name", "STRING", "NULLABLE"), + bigquery.SchemaField( + "orders", + "RECORD", + "REPEATED", + fields=([bigquery.SchemaField("id", "INT64", "NULLABLE")]), + ), + ), + ) + ] + new_nested_fields = [("year", "INT64", ["user", "orders"]), ("active", "BOOL", ["user"])] + expected = [ + bigquery.SchemaField( + "user", + "RECORD", + "NULLABLE", + fields=( + bigquery.SchemaField("name", "STRING", "NULLABLE"), + bigquery.SchemaField( + "orders", + "RECORD", + "REPEATED", + fields=( + bigquery.SchemaField("id", "INT64", "NULLABLE"), + bigquery.SchemaField("year", "INT64", "NULLABLE"), + ), + ), + bigquery.SchemaField("active", "BOOL", "NULLABLE"), + ), + ) + ] + assert adapter._build_nested_fields(current_schema, new_nested_fields) == expected + + current_schema = [ + bigquery.SchemaField( + "users", + "RECORD", + "REPEATED", + fields=( + [ + bigquery.SchemaField( + "user", + "RECORD", + "NULLABLE", + fields=(bigquery.SchemaField("name", "STRING", "NULLABLE"),), + ) + ] + ), + ) + ] + new_nested_fields = [ + ("orders", "ARRAY", ["users", "user"]), + ("tags", "STRING", ["users"]), + ("details", "ARRAY", []), + ] + expected = [ + bigquery.SchemaField( + "users", + "RECORD", + "REPEATED", + fields=( + bigquery.SchemaField( + "user", + "RECORD", + "NULLABLE", + fields=( + bigquery.SchemaField("name", "STRING", "NULLABLE"), + bigquery.SchemaField("orders", "INT64", "REPEATED"), + ), + ), + bigquery.SchemaField( + "tags", + "STRING", + "NULLABLE", + ), + ), + ), + bigquery.SchemaField("details", "STRING", "REPEATED"), + ] + assert adapter._build_nested_fields(current_schema, new_nested_fields) == expected + + +def test_get_alter_expressions_includes_catalog( + adapter: BigQueryEngineAdapter, mocker: MockerFixture +): 
+ adapter._default_catalog = "test_project" + + columns_mock = mocker.patch( + "sqlmesh.core.engine_adapter.bigquery.BigQueryEngineAdapter.columns" + ) + columns_mock.return_value = { + "a": exp.DataType.build("int"), + } + + get_data_objects_mock = mocker.patch( + "sqlmesh.core.engine_adapter.bigquery.BigQueryEngineAdapter.get_data_objects" + ) + get_data_objects_mock.return_value = [] + + adapter.get_alter_operations("catalog1.foo.bar", "catalog2.bar.bing") + + assert get_data_objects_mock.call_count == 2 + + schema, tables = get_data_objects_mock.call_args_list[0][0] + assert isinstance(schema, exp.Table) + assert isinstance(tables, set) + assert schema.catalog == "catalog1" + assert schema.db == "foo" + assert schema.sql(dialect="bigquery") == "catalog1.foo" + assert tables == {"bar"} + + schema, tables = get_data_objects_mock.call_args_list[1][0] + assert isinstance(schema, exp.Table) + assert isinstance(tables, set) + assert schema.catalog == "catalog2" + assert schema.db == "bar" + assert schema.sql(dialect="bigquery") == "catalog2.bar" + assert tables == {"bing"} + + +def test_job_cancellation_on_keyboard_interrupt_job_still_running(mocker: MockerFixture): + # Create a mock connection + connection_mock = mocker.NonCallableMock() + cursor_mock = mocker.Mock() + cursor_mock.connection = connection_mock + connection_mock.cursor.return_value = cursor_mock + + # Mock the query job + mock_job = mocker.Mock() + mock_job.project = "test-project" + mock_job.location = "us-central1" + mock_job.job_id = "test-job-123" + mock_job.done.return_value = False # Job is still running + mock_job.result.side_effect = KeyboardInterrupt() + mock_job._query_results = mocker.Mock() + mock_job._query_results.total_rows = 0 + mock_job._query_results.schema = [] + + # Set up the client to return our mock job + connection_mock._client.query.return_value = mock_job + + # Create adapter with the mocked connection + adapter = BigQueryEngineAdapter(lambda: connection_mock, job_retries=0) 
+ + # Execute a query and expect KeyboardInterrupt + with pytest.raises(KeyboardInterrupt): + adapter.execute("SELECT 1") + + # Ensure the adapter's closed, so that the job can be aborted + adapter.close() + + # Verify the job was created + connection_mock._client.query.assert_called_once() + + # Verify job status was checked and cancellation was called + mock_job.done.assert_called_once() + mock_job.cancel.assert_called_once() + + +def test_job_cancellation_on_keyboard_interrupt_job_already_done(mocker: MockerFixture): + # Create a mock connection + connection_mock = mocker.NonCallableMock() + cursor_mock = mocker.Mock() + cursor_mock.connection = connection_mock + connection_mock.cursor.return_value = cursor_mock + + # Mock the query job + mock_job = mocker.Mock() + mock_job.project = "test-project" + mock_job.location = "us-central1" + mock_job.job_id = "test-job-456" + mock_job.done.return_value = True # Job is already done + mock_job.result.side_effect = KeyboardInterrupt() + mock_job._query_results = mocker.Mock() + mock_job._query_results.total_rows = 0 + mock_job._query_results.schema = [] + + # Set up the client to return our mock job + connection_mock._client.query.return_value = mock_job + + # Create adapter with the mocked connection + adapter = BigQueryEngineAdapter(lambda: connection_mock, job_retries=0) + + # Execute a query and expect KeyboardInterrupt + with pytest.raises(KeyboardInterrupt): + adapter.execute("SELECT 1") + + # Ensure the adapter's closed, so that the job can be aborted + adapter.close() + + # Verify the job was created + connection_mock._client.query.assert_called_once() + + # Verify job status was checked but cancellation was NOT called + mock_job.done.assert_called_once() + mock_job.cancel.assert_not_called() + + +def test_drop_cascade(adapter: BigQueryEngineAdapter): + adapter.drop_table("foo", cascade=True) + adapter.drop_table("foo", cascade=False) + + # BigQuery doesnt support DROP CASCADE for tables + # ref: 
https://cloud.google.com/bigquery/docs/reference/standard-sql/data-definition-language#drop_table_statement + assert _to_sql_calls(adapter) == ["DROP TABLE IF EXISTS `foo`", "DROP TABLE IF EXISTS `foo`"] + adapter.execute.reset_mock() # type: ignore + + # But, it does for schemas + adapter.drop_schema("foo", cascade=True) + adapter.drop_schema("foo", cascade=False) + + assert _to_sql_calls(adapter) == [ + "DROP SCHEMA IF EXISTS `foo` CASCADE", + "DROP SCHEMA IF EXISTS `foo`", + ] + + +def test_scd_type_2_by_partitioning(adapter: BigQueryEngineAdapter): + adapter.scd_type_2_by_time( + target_table="target", + source_table=t.cast( + exp.Select, parse_one("SELECT id, name, price, test_UPDATED_at FROM source") + ), + unique_key=[ + exp.to_column("id"), + ], + updated_at_col=exp.column("test_UPDATED_at", quoted=True), + valid_from_col=exp.to_column("valid_from", quoted=True), + valid_to_col=exp.to_column("valid_to", quoted=True), + target_columns_to_types={ + "id": exp.DataType.build("INT"), + "name": exp.DataType.build("VARCHAR"), + "price": exp.DataType.build("DOUBLE"), + "test_UPDATED_at": exp.DataType.build("TIMESTAMP"), + "valid_from": exp.DataType.build("TIMESTAMP"), + "valid_to": exp.DataType.build("TIMESTAMP"), + }, + execution_time=datetime(2020, 1, 1, 0, 0, 0), + partitioned_by=[parse_one("TIMESTAMP_TRUNC(valid_from, DAY)")], + ) + + calls = _to_sql_calls(adapter) + + # Initial call to create the table and then another to replace since it is self-referencing + assert len(calls) == 2 + # Both calls should contain the partition logic (the scd logic is already covered by other tests) + assert "PARTITION BY TIMESTAMP_TRUNC(`valid_from`, DAY)" in calls[0] + assert "PARTITION BY TIMESTAMP_TRUNC(`valid_from`, DAY)" in calls[1] + + +def test_sync_grants_config(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(BigQueryEngineAdapter) + relation = exp.to_table("project.dataset.test_table", dialect="bigquery") + 
new_grants_config = { + "roles/bigquery.dataViewer": ["user:analyst@example.com", "group:data-team@example.com"], + "roles/bigquery.dataEditor": ["user:admin@example.com"], + } + current_grants = [ + ("roles/bigquery.dataViewer", "user:old_analyst@example.com"), + ("roles/bigquery.admin", "user:old_admin@example.com"), + ] + + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + execute_mock = mocker.patch.object(adapter, "execute") + mocker.patch.object(adapter, "get_current_catalog", return_value="project") + mocker.patch.object(adapter.client, "location", "us-central1") + + mock_dataset = mocker.Mock() + mock_dataset.location = "us-central1" + mocker.patch.object(adapter, "_db_call", return_value=mock_dataset) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="bigquery") + expected_sql = ( + "SELECT privilege_type, grantee FROM `project`.`region-us-central1`.`INFORMATION_SCHEMA.OBJECT_PRIVILEGES` AS OBJECT_PRIVILEGES " + "WHERE object_schema = 'dataset' AND object_name = 'test_table' AND SPLIT(grantee, ':')[OFFSET(1)] <> SESSION_USER()" + ) + assert executed_sql == expected_sql + + sql_calls = _to_sql_calls(execute_mock) + + assert len(sql_calls) == 4 + assert ( + "REVOKE `roles/bigquery.dataViewer` ON TABLE `project`.`dataset`.`test_table` FROM 'user:old_analyst@example.com'" + in sql_calls + ) + assert ( + "REVOKE `roles/bigquery.admin` ON TABLE `project`.`dataset`.`test_table` FROM 'user:old_admin@example.com'" + in sql_calls + ) + assert ( + "GRANT `roles/bigquery.dataViewer` ON TABLE `project`.`dataset`.`test_table` TO 'user:analyst@example.com', 'group:data-team@example.com'" + in sql_calls + ) + assert ( + "GRANT `roles/bigquery.dataEditor` ON TABLE `project`.`dataset`.`test_table` TO 'user:admin@example.com'" + in sql_calls + ) + + +def test_sync_grants_config_with_overlaps( + 
make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(BigQueryEngineAdapter) + relation = exp.to_table("project.dataset.test_table", dialect="bigquery") + new_grants_config = { + "roles/bigquery.dataViewer": [ + "user:analyst1@example.com", + "user:analyst2@example.com", + "user:analyst3@example.com", + ], + "roles/bigquery.dataEditor": ["user:analyst2@example.com", "user:editor@example.com"], + } + current_grants = [ + ("roles/bigquery.dataViewer", "user:analyst1@example.com"), # Keep + ("roles/bigquery.dataViewer", "user:old_analyst@example.com"), # Remove + ("roles/bigquery.dataEditor", "user:analyst2@example.com"), # Keep + ("roles/bigquery.admin", "user:admin@example.com"), # Remove + ] + + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + execute_mock = mocker.patch.object(adapter, "execute") + mocker.patch.object(adapter, "get_current_catalog", return_value="project") + mocker.patch.object(adapter.client, "location", "us-central1") + + mock_dataset = mocker.Mock() + mock_dataset.location = "us-central1" + mocker.patch.object(adapter, "_db_call", return_value=mock_dataset) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="bigquery") + expected_sql = ( + "SELECT privilege_type, grantee FROM `project`.`region-us-central1`.`INFORMATION_SCHEMA.OBJECT_PRIVILEGES` AS OBJECT_PRIVILEGES " + "WHERE object_schema = 'dataset' AND object_name = 'test_table' AND SPLIT(grantee, ':')[OFFSET(1)] <> SESSION_USER()" + ) + assert executed_sql == expected_sql + + sql_calls = _to_sql_calls(execute_mock) + + assert len(sql_calls) == 4 + assert ( + "REVOKE `roles/bigquery.dataViewer` ON TABLE `project`.`dataset`.`test_table` FROM 'user:old_analyst@example.com'" + in sql_calls + ) + assert ( + "REVOKE `roles/bigquery.admin` ON TABLE 
`project`.`dataset`.`test_table` FROM 'user:admin@example.com'" + in sql_calls + ) + assert ( + "GRANT `roles/bigquery.dataViewer` ON TABLE `project`.`dataset`.`test_table` TO 'user:analyst2@example.com', 'user:analyst3@example.com'" + in sql_calls + ) + assert ( + "GRANT `roles/bigquery.dataEditor` ON TABLE `project`.`dataset`.`test_table` TO 'user:editor@example.com'" + in sql_calls + ) + + +@pytest.mark.parametrize( + "table_type, expected_keyword", + [ + (DataObjectType.TABLE, "TABLE"), + (DataObjectType.VIEW, "VIEW"), + (DataObjectType.MATERIALIZED_VIEW, "MATERIALIZED VIEW"), + ], +) +def test_sync_grants_config_object_kind( + make_mocked_engine_adapter: t.Callable, + mocker: MockerFixture, + table_type: DataObjectType, + expected_keyword: str, +) -> None: + adapter = make_mocked_engine_adapter(BigQueryEngineAdapter) + relation = exp.to_table("project.dataset.test_object", dialect="bigquery") + + mocker.patch.object(adapter, "fetchall", return_value=[]) + execute_mock = mocker.patch.object(adapter, "execute") + mocker.patch.object(adapter, "get_current_catalog", return_value="project") + mocker.patch.object(adapter.client, "location", "us-central1") + + mock_dataset = mocker.Mock() + mock_dataset.location = "us-central1" + mocker.patch.object(adapter, "_db_call", return_value=mock_dataset) + + adapter.sync_grants_config( + relation, {"roles/bigquery.dataViewer": ["user:test@example.com"]}, table_type + ) + + executed_exprs = execute_mock.call_args[0][0] + sql_calls = [expr.sql(dialect="bigquery") for expr in executed_exprs] + assert sql_calls == [ + f"GRANT `roles/bigquery.dataViewer` ON {expected_keyword} project.dataset.test_object TO 'user:test@example.com'" + ] + + +def test_sync_grants_config_no_schema( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(BigQueryEngineAdapter) + relation = exp.to_table("test_table", dialect="bigquery") + new_grants_config = { + "roles/bigquery.dataViewer": 
["user:analyst@example.com"], + "roles/bigquery.dataEditor": ["user:editor@example.com"], + } + + with pytest.raises(ValueError, match="Table test_table does not have a schema \\(dataset\\)"): + adapter.sync_grants_config(relation, new_grants_config) diff --git a/tests/core/engine_adapter/test_clickhouse.py b/tests/core/engine_adapter/test_clickhouse.py new file mode 100644 index 0000000000..54fbe7c323 --- /dev/null +++ b/tests/core/engine_adapter/test_clickhouse.py @@ -0,0 +1,1376 @@ +import pytest +from sqlmesh.core.engine_adapter import ClickhouseEngineAdapter +from sqlmesh.core.model.definition import load_sql_based_model +from sqlmesh.core.model.kind import ModelKindName +from sqlmesh.core.engine_adapter.shared import EngineRunMode, DataObject +from tests.core.engine_adapter import to_sql_calls +from sqlmesh.core.dialect import parse +from sqlglot import exp, parse_one +import typing as t +from datetime import datetime +from pytest_mock.plugin import MockerFixture +from sqlmesh.core import dialect as d +from sqlglot.optimizer.qualify_columns import quote_identifiers + +pytestmark = [pytest.mark.clickhouse, pytest.mark.engine] + + +@pytest.fixture +def adapter(make_mocked_engine_adapter, mocker) -> ClickhouseEngineAdapter: + mocker.patch.object( + ClickhouseEngineAdapter, + "engine_run_mode", + new_callable=mocker.PropertyMock(return_value=EngineRunMode.STANDALONE), + ) + + adapter = make_mocked_engine_adapter(ClickhouseEngineAdapter) + + return adapter + + +def test_create_schema(adapter: ClickhouseEngineAdapter, mocker): + mocker.patch.object( + ClickhouseEngineAdapter, + "cluster", + new_callable=mocker.PropertyMock(return_value="default"), + ) + + # ON CLUSTER not added because engine_run_mode.is_cluster=False + adapter.create_schema("foo") + + mocker.patch.object( + ClickhouseEngineAdapter, + "engine_run_mode", + new_callable=mocker.PropertyMock(return_value=EngineRunMode.CLUSTER), + ) + adapter.create_schema("foo") + + assert to_sql_calls(adapter) == [ + 
'CREATE DATABASE IF NOT EXISTS "foo"', + 'CREATE DATABASE IF NOT EXISTS "foo" ON CLUSTER "default"', + ] + + +def test_drop_schema(adapter: ClickhouseEngineAdapter, mocker): + mocker.patch.object( + ClickhouseEngineAdapter, + "cluster", + new_callable=mocker.PropertyMock(return_value="default"), + ) + + # ON CLUSTER not added because engine_run_mode.is_cluster=False + adapter.drop_schema("foo") + + mocker.patch.object( + ClickhouseEngineAdapter, + "engine_run_mode", + new_callable=mocker.PropertyMock(return_value=EngineRunMode.CLUSTER), + ) + adapter.drop_schema("foo") + + assert to_sql_calls(adapter) == [ + 'DROP DATABASE IF EXISTS "foo"', + 'DROP DATABASE IF EXISTS "foo" ON CLUSTER "default"', + ] + + +def test_create_table(adapter: ClickhouseEngineAdapter, mocker): + mocker.patch.object( + ClickhouseEngineAdapter, + "cluster", + new_callable=mocker.PropertyMock(return_value="default"), + ) + + # ON CLUSTER not added because engine_run_mode.is_cluster=False + adapter.create_table("foo", {"a": exp.DataType.build("Int8", dialect=adapter.dialect)}) + # adapter.create_table_like("target", "source") + + mocker.patch.object( + ClickhouseEngineAdapter, + "engine_run_mode", + new_callable=mocker.PropertyMock(return_value=EngineRunMode.CLUSTER), + ) + adapter.create_table("foo", {"a": exp.DataType.build("Int8", dialect=adapter.dialect)}) + # adapter.create_table_like("target", "source") + + assert to_sql_calls(adapter) == [ + 'CREATE TABLE IF NOT EXISTS "foo" ("a" Int8) ENGINE=MergeTree ORDER BY ()', + # "CREATE TABLE IF NOT EXISTS target AS source", + 'CREATE TABLE IF NOT EXISTS "foo" ON CLUSTER "default" ("a" Int8) ENGINE=MergeTree ORDER BY ()', + # "CREATE TABLE IF NOT EXISTS target AS source", + ] + + +def test_rename_table(adapter: ClickhouseEngineAdapter, mocker): + mocker.patch.object( + ClickhouseEngineAdapter, + "cluster", + new_callable=mocker.PropertyMock(return_value="default"), + ) + + # ON CLUSTER not added because engine_run_mode.is_cluster=False + 
adapter.rename_table(exp.to_table("foo"), exp.to_table("bar")) + + mocker.patch.object( + ClickhouseEngineAdapter, + "engine_run_mode", + new_callable=mocker.PropertyMock(return_value=EngineRunMode.CLUSTER), + ) + adapter.rename_table(exp.to_table("foo"), exp.to_table("bar")) + + assert to_sql_calls(adapter) == [ + 'RENAME TABLE "foo" TO "bar"', + 'RENAME TABLE "foo" TO "bar" ON CLUSTER "default" ', + ] + + +def test_delete_from(adapter: ClickhouseEngineAdapter, mocker): + mocker.patch.object( + ClickhouseEngineAdapter, + "cluster", + new_callable=mocker.PropertyMock(return_value="default"), + ) + + # ON CLUSTER not added because engine_run_mode.is_cluster=False + adapter.delete_from(exp.to_table("foo"), "a = 1") + + mocker.patch.object( + ClickhouseEngineAdapter, + "engine_run_mode", + new_callable=mocker.PropertyMock(return_value=EngineRunMode.CLUSTER), + ) + adapter.delete_from(exp.to_table("foo"), "a = 1") + + assert to_sql_calls(adapter) == [ + 'DELETE FROM "foo" WHERE "a" = 1', + 'DELETE FROM "foo" ON CLUSTER "default" WHERE "a" = 1', + ] + + +def test_alter_table( + adapter: ClickhouseEngineAdapter, + mocker, +): + adapter.SCHEMA_DIFFER_KWARGS = {} + current_table_name = "test_table" + current_table = {"a": "Int8", "b": "String", "c": "Int8"} + target_table_name = "target_table" + target_table = { + "a": "Int8", + "b": "String", + "f": "String", + } + + def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: + if table_name == current_table_name: + return { + k: exp.DataType.build(v, dialect=adapter.dialect) for k, v in current_table.items() + } + return {k: exp.DataType.build(v, dialect=adapter.dialect) for k, v in target_table.items()} + + adapter.columns = table_columns # type: ignore + + # ON CLUSTER not added because engine_run_mode.is_cluster=False + adapter.alter_table(adapter.get_alter_operations(current_table_name, target_table_name)) + + mocker.patch.object( + ClickhouseEngineAdapter, + "cluster", + 
new_callable=mocker.PropertyMock(return_value="default"), + ) + mocker.patch.object( + ClickhouseEngineAdapter, + "engine_run_mode", + new_callable=mocker.PropertyMock(return_value=EngineRunMode.CLUSTER), + ) + + adapter.alter_table(adapter.get_alter_operations(current_table_name, target_table_name)) + + assert to_sql_calls(adapter) == [ + 'ALTER TABLE "test_table" DROP COLUMN "c"', + 'ALTER TABLE "test_table" ADD COLUMN "f" String', + 'ALTER TABLE "test_table" ON CLUSTER "default" DROP COLUMN "c"', + 'ALTER TABLE "test_table" ON CLUSTER "default" ADD COLUMN "f" String', + ] + + +def test_nullable_datatypes_in_model_kind(adapter: ClickhouseEngineAdapter): + model = load_sql_based_model( + parse( + """ + MODEL ( + name foo, + kind SCD_TYPE_2_BY_TIME(unique_key id, time_data_type Nullable(DateTime64)), + ); + + select 1; + """, + default_dialect="clickhouse", + ) + ) + + assert model.kind.name == ModelKindName.SCD_TYPE_2_BY_TIME + assert model.kind.time_data_type.sql(dialect="clickhouse") == "Nullable(DateTime64)" + + +def test_nullable_datatypes_in_model_columns(adapter: ClickhouseEngineAdapter): + model = load_sql_based_model( + parse( + """ + MODEL ( + name foo, + columns ( + id Int64, + data Nullable(JSON), + ts DateTime64, + other Tuple(UInt16, String) + ) + ); + + select 1, 2, 3, 4; + """, + default_dialect="clickhouse", + ) + ) + + rendered_columns_to_types = { + k: v.sql(dialect="clickhouse") for k, v in model.columns_to_types_or_raise.items() + } + + assert rendered_columns_to_types["id"] == "Int64" + assert rendered_columns_to_types["data"] == "Nullable(JSON)" + assert rendered_columns_to_types["ts"] == "DateTime64" + assert rendered_columns_to_types["other"] == "Tuple(UInt16, String)" + + +def test_model_properties(adapter: ClickhouseEngineAdapter): + def build_properties_sql(storage_format="", order_by="", primary_key="", properties=""): + model = load_sql_based_model( + parse( + f""" + MODEL ( + name foo, + dialect clickhouse, + {storage_format} + 
physical_properties ( + {order_by} + {primary_key} + {properties} + ), + ); + + select + * + from bar; + """, + default_dialect="clickhouse", + ) + ) + + return adapter._build_table_properties_exp( + storage_format=model.storage_format, table_properties=model.physical_properties + ).sql("clickhouse") + + # no order by or primary key because table engine is not part of "MergeTree" engine family + assert ( + build_properties_sql( + storage_format="storage_format Log,", + order_by="ORDER_BY = a,", + primary_key="PRIMARY_KEY = a,", + ) + == "ENGINE=Log" + ) + + assert ( + build_properties_sql( + storage_format="storage_format ReplicatedMergeTree('/clickhouse/tables/{shard}/table_name', '{replica}', ver),", + order_by="ORDER_BY = a,", + primary_key="PRIMARY_KEY = a,", + ) + == "ENGINE=ReplicatedMergeTree('/clickhouse/tables/{shard}/table_name', '{replica}', ver) ORDER BY (a) PRIMARY KEY (a)" + ) + + assert ( + build_properties_sql(order_by="ORDER_BY = a,", primary_key="PRIMARY_KEY = a,") + == "ENGINE=MergeTree ORDER BY (a) PRIMARY KEY (a)" + ) + + assert ( + build_properties_sql(order_by='ORDER_BY = "a",', primary_key='PRIMARY_KEY = "a",') + == 'ENGINE=MergeTree ORDER BY ("a") PRIMARY KEY ("a")' + ) + + assert ( + build_properties_sql(order_by="ORDER_BY = (a),", primary_key="PRIMARY_KEY = (a)") + == "ENGINE=MergeTree ORDER BY (a) PRIMARY KEY (a)" + ) + + assert build_properties_sql(order_by="ORDER_BY = a + 1,") == "ENGINE=MergeTree ORDER BY (a + 1)" + + assert ( + build_properties_sql(order_by="ORDER_BY = (a + 1),") == "ENGINE=MergeTree ORDER BY (a + 1)" + ) + + assert ( + build_properties_sql(order_by="ORDER_BY = (a, b + 1),", primary_key="PRIMARY_KEY = (a, b)") + == "ENGINE=MergeTree ORDER BY (a, b + 1) PRIMARY KEY (a, b)" + ) + + assert ( + build_properties_sql( + order_by="ORDER_BY = (a, b + 1),", + primary_key="PRIMARY_KEY = (a, b),", + properties="PROP1 = 1, PROP2 = '2'", + ) + == "ENGINE=MergeTree ORDER BY (a, b + 1) PRIMARY KEY (a, b) SETTINGS prop1 = 1 SETTINGS 
prop2 = '2'" + ) + + assert ( + build_properties_sql( + order_by="ORDER_BY = 'timestamp with fill to dateTrunc(\\'DAY\\', toDateTime64(\\'2024-07-11\\', 3)) step toIntervalDay(1) interpolate(price as price)'," + ) + == "ENGINE=MergeTree ORDER BY (timestamp WITH FILL TO dateTrunc('DAY', toDateTime64('2024-07-11', 3)) STEP toIntervalDay(1) INTERPOLATE (price AS price))" + ) + + assert ( + build_properties_sql( + order_by="ORDER_BY = (\"a\", 'timestamp with fill to dateTrunc(\\'DAY\\', toDateTime64(\\'2024-07-11\\', 3)) step toIntervalDay(1) interpolate(price as price)')," + ) + == "ENGINE=MergeTree ORDER BY (\"a\", timestamp WITH FILL TO dateTrunc('DAY', toDateTime64('2024-07-11', 3)) STEP toIntervalDay(1) INTERPOLATE (price AS price))" + ) + + assert ( + build_properties_sql(properties="TTL = time + INTERVAL 1 WEEK") + == "ENGINE=MergeTree ORDER BY () TTL time + INTERVAL '1' WEEK" + ) + + +def test_partitioned_by_expr(make_mocked_engine_adapter: t.Callable): + # user doesn't specify, unknown time column type + model = load_sql_based_model( + parse( + """ + MODEL ( + name foo, + dialect clickhouse, + kind INCREMENTAL_BY_TIME_RANGE( + time_column ds + ) + ); + + select + * + from bar; + """, + default_dialect="clickhouse", + ) + ) + + assert ( + model.partitioned_by[0].sql("clickhouse") + == """dateTrunc('WEEK', CAST("ds" AS DateTime64(9, 'UTC')))""" + ) + + # user specifies without time column, unknown time column type + model = load_sql_based_model( + parse( + """ + MODEL ( + name foo, + dialect clickhouse, + kind INCREMENTAL_BY_TIME_RANGE( + time_column ds + ), + partitioned_by x + ); + + select + * + from bar; + """, + default_dialect="clickhouse", + ) + ) + + assert [p.sql("clickhouse") for p in model.partitioned_by] == [ + """dateTrunc('WEEK', CAST("ds" AS DateTime64(9, 'UTC')))""", + '"x"', + ] + + # user doesn't specify, conformable date/datetime time column type + model = load_sql_based_model( + parse( + """ + MODEL ( + name foo, + dialect clickhouse, + kind 
INCREMENTAL_BY_TIME_RANGE( + time_column ds + ) + ); + + select + ds::DATE as ds + from bar; + """, + default_dialect="clickhouse", + ) + ) + + assert model.partitioned_by[0].sql("clickhouse") == """dateTrunc('WEEK', "ds")""" + + # user doesn't specify, non-conformable time column type + model = load_sql_based_model( + parse( + """ + MODEL ( + name foo, + dialect clickhouse, + kind INCREMENTAL_BY_TIME_RANGE( + time_column ds + ) + ); + + select + ds::String as ds + from bar; + """, + default_dialect="clickhouse", + ) + ) + + assert ( + model.partitioned_by[0].sql("clickhouse") + == """CAST(dateTrunc('WEEK', CAST("ds" AS DateTime64(9, 'UTC'))) AS String)""" + ) + + # user specifies partitioned_by with time column + model = load_sql_based_model( + parse( + """ + MODEL ( + name foo, + dialect clickhouse, + kind INCREMENTAL_BY_TIME_RANGE( + time_column ds + ), + partitioned_by toStartOfWeek(ds) + ); + + select + * + from bar; + """, + default_dialect="clickhouse", + ) + ) + + assert model.partitioned_by == [exp.func("toStartOfWeek", '"ds"')] + + +def test_nullable_partition_cols(make_mocked_engine_adapter: t.Callable, mocker): + adapter = make_mocked_engine_adapter(ClickhouseEngineAdapter) + + columns_to_types = { + "cola": exp.DataType.build("INT"), + "colb": exp.DataType.build("TEXT"), + } + + adapter.create_table( + "test_table", + columns_to_types, + ) + + adapter.create_table( + "test_table", + columns_to_types, + partitioned_by=[exp.to_column("colb")], + ) + + assert to_sql_calls(adapter) == [ + 'CREATE TABLE IF NOT EXISTS "test_table" ("cola" Nullable(Int32), "colb" Nullable(String)) ENGINE=MergeTree ORDER BY ()', + 'CREATE TABLE IF NOT EXISTS "test_table" ("cola" Nullable(Int32), "colb" String) ENGINE=MergeTree ORDER BY () PARTITION BY ("colb")', + ] + + +def test_create_table_properties(make_mocked_engine_adapter: t.Callable, mocker): + adapter = make_mocked_engine_adapter(ClickhouseEngineAdapter) + + mocker.patch( + 
"sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchone", + return_value="1", + ) + + columns_to_types = { + "cola": exp.DataType.build("INT", dialect="clickhouse"), + "colb": exp.DataType.build("TEXT", dialect="clickhouse"), + "colc": exp.DataType.build("TEXT", dialect="clickhouse"), + } + adapter.create_table( + "test_table", + columns_to_types, + partitioned_by=[exp.to_column("colb")], + storage_format="ReplicatedMergeTree", + table_properties={ + "ORDER_BY": [exp.to_column("cola"), exp.to_column("colb")], + "PRIMARY_KEY": [exp.to_column("cola"), exp.to_column("colb")], + }, + ) + + assert to_sql_calls(adapter) == [ + 'CREATE TABLE IF NOT EXISTS "test_table" ("cola" Int32, "colb" String, "colc" String) ENGINE=ReplicatedMergeTree ORDER BY ("cola", "colb") PRIMARY KEY ("cola", "colb") PARTITION BY ("colb")', + ] + + +def test_nulls_after_join(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(ClickhouseEngineAdapter) + + query = exp.select("col1").from_("table") + + assert ( + adapter.ensure_nulls_for_unmatched_after_join(query.copy()).sql(adapter.dialect) + == "SELECT col1 FROM table SETTINGS join_use_nulls = 1" + ) + + # User already set the setting, so we should not override it + query_with_setting = query.copy() + query_with_setting.set( + "settings", + [ + exp.EQ( + this=exp.var("join_use_nulls"), + expression=exp.Literal(this="0", is_string=False), + ) + ], + ) + + assert ( + adapter.use_server_nulls_for_unmatched_after_join(query_with_setting) == query_with_setting + ) + + # Server default of 0 != method default of 1, so we inject 0 + mocker.patch( + "sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchone", + return_value="0", + ) + + assert ( + adapter.use_server_nulls_for_unmatched_after_join(query).sql(adapter.dialect) + == "SELECT col1 FROM table SETTINGS join_use_nulls = 0" + ) + + +def test_scd_type_2_by_time( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, 
make_temp_table_name: t.Callable +): + adapter = make_mocked_engine_adapter(ClickhouseEngineAdapter) + + temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") + table_name = "target" + temp_table_mock.side_effect = [ + make_temp_table_name(table_name, "efgh"), + make_temp_table_name(table_name, "abcd"), + ] + + mocker.patch.object( + adapter, + "get_data_objects", + return_value=[DataObject(schema="", name=table_name, type="table")], + ) + + fetchone_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchone") + fetchone_mock.return_value = None + + # The SCD query we build must specify the setting join_use_nulls = 1. We need to ensure that our + # setting on the outer query doesn't override the value the user expects. + # + # This test's user query does not contain a setting "join_use_nulls", so we determine whether or not + # to inject it based on the current server value. The mocked server value is 1, so we should not + # inject. 
+ fetchone_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchone") + fetchone_mock.return_value = "1" + + adapter.scd_type_2_by_time( + target_table="target", + source_table=t.cast( + exp.Select, parse_one("SELECT id, name, price, test_UPDATED_at FROM source") + ), + unique_key=[ + parse_one("""COALESCE("id", '') || '|' || COALESCE("name", '')"""), + parse_one("""COALESCE("name", '')"""), + ], + valid_from_col=exp.column("test_valid_from", quoted=True), + valid_to_col=exp.column("test_valid_to", quoted=True), + updated_at_col=exp.column("test_UPDATED_at", quoted=True), + target_columns_to_types={ + "id": exp.DataType.build("INT"), + "name": exp.DataType.build("VARCHAR"), + "price": exp.DataType.build("DOUBLE"), + "test_UPDATED_at": exp.DataType.build("TIMESTAMP"), + "test_valid_from": exp.DataType.build("TIMESTAMP"), + "test_valid_to": exp.DataType.build("TIMESTAMP"), + }, + execution_time=datetime(2020, 1, 1, 0, 0, 0), + ) + + assert to_sql_calls(adapter)[3] == parse_one( + """ +INSERT INTO "__temp_target_abcd" ("id", "name", "price", "test_UPDATED_at", "test_valid_from", "test_valid_to") +WITH "source" AS ( + SELECT DISTINCT ON (COALESCE("id", '') || '|' || COALESCE("name", ''), COALESCE("name", '')) + TRUE AS "_exists", + "id", + "name", + "price", + CAST("test_UPDATED_at" AS Nullable(DateTime)) AS "test_UPDATED_at" + FROM ( + SELECT + "id", + "name", + "price", + "test_UPDATED_at" + FROM "source" + ) AS "raw_source" +), "static" AS ( + SELECT + "id", + "name", + "price", + "test_UPDATED_at", + "test_valid_from", + "test_valid_to", + TRUE AS "_exists" + FROM "__temp_target_efgh" + WHERE + NOT "test_valid_to" IS NULL +), "latest" AS ( + SELECT + "id", + "name", + "price", + "test_UPDATED_at", + "test_valid_from", + "test_valid_to", + TRUE AS "_exists" + FROM "__temp_target_efgh" + WHERE + "test_valid_to" IS NULL +), "deleted" AS ( + SELECT + "static"."id", + "static"."name", + "static"."price", + "static"."test_UPDATED_at", + 
"static"."test_valid_from", + "static"."test_valid_to" + FROM "static" + LEFT JOIN "latest" + ON ( + COALESCE("static"."id", '') || '|' || COALESCE("static"."name", '') + ) = ( + COALESCE("latest"."id", '') || '|' || COALESCE("latest"."name", '') + ) + AND COALESCE("static"."name", '') = COALESCE("latest"."name", '') + WHERE + "latest"."test_valid_to" IS NULL +), "latest_deleted" AS ( + SELECT + TRUE AS "_exists", + COALESCE("id", '') || '|' || COALESCE("name", '') AS "_key0", + COALESCE("name", '') AS "_key1", + MAX("test_valid_to") AS "test_valid_to" + FROM "deleted" + GROUP BY + COALESCE("id", '') || '|' || COALESCE("name", ''), + COALESCE("name", '') +), "joined" AS ( + SELECT + "source"."_exists" AS "_exists", + "latest"."id" AS "t_id", + "latest"."name" AS "t_name", + "latest"."price" AS "t_price", + "latest"."test_UPDATED_at" AS "t_test_UPDATED_at", + "latest"."test_valid_from" AS "t_test_valid_from", + "latest"."test_valid_to" AS "t_test_valid_to", + "source"."id" AS "id", + "source"."name" AS "name", + "source"."price" AS "price", + "source"."test_UPDATED_at" AS "test_UPDATED_at" + FROM "latest" + LEFT JOIN "source" + ON ( + COALESCE("latest"."id", '') || '|' || COALESCE("latest"."name", '') + ) = ( + COALESCE("source"."id", '') || '|' || COALESCE("source"."name", '') + ) + AND COALESCE("latest"."name", '') = COALESCE("source"."name", '') + UNION ALL + SELECT + "source"."_exists" AS "_exists", + "latest"."id" AS "t_id", + "latest"."name" AS "t_name", + "latest"."price" AS "t_price", + "latest"."test_UPDATED_at" AS "t_test_UPDATED_at", + "latest"."test_valid_from" AS "t_test_valid_from", + "latest"."test_valid_to" AS "t_test_valid_to", + "source"."id" AS "id", + "source"."name" AS "name", + "source"."price" AS "price", + "source"."test_UPDATED_at" AS "test_UPDATED_at" + FROM "latest" + RIGHT JOIN "source" + ON ( + COALESCE("latest"."id", '') || '|' || COALESCE("latest"."name", '') + ) = ( + COALESCE("source"."id", '') || '|' || COALESCE("source"."name", '') 
+ ) + AND COALESCE("latest"."name", '') = COALESCE("source"."name", '') + WHERE + "latest"."_exists" IS NULL +), "updated_rows" AS ( + SELECT + COALESCE("joined"."t_id", "joined"."id") AS "id", + COALESCE("joined"."t_name", "joined"."name") AS "name", + COALESCE("joined"."t_price", "joined"."price") AS "price", + COALESCE("joined"."t_test_UPDATED_at", "joined"."test_UPDATED_at") AS "test_UPDATED_at", + CASE + WHEN "t_test_valid_from" IS NULL AND NOT "latest_deleted"."_exists" IS NULL + THEN CASE + WHEN "latest_deleted"."test_valid_to" > "test_UPDATED_at" + THEN "latest_deleted"."test_valid_to" + ELSE "test_UPDATED_at" + END + WHEN "t_test_valid_from" IS NULL + THEN CAST('1970-01-01 00:00:00' AS Nullable(DateTime64(6))) + ELSE "t_test_valid_from" + END AS "test_valid_from", + CASE + WHEN "joined"."test_UPDATED_at" > "joined"."t_test_UPDATED_at" + THEN "joined"."test_UPDATED_at" + WHEN "joined"."_exists" IS NULL + THEN CAST('2020-01-01 00:00:00' AS Nullable(DateTime64(6))) + ELSE "t_test_valid_to" + END AS "test_valid_to" + FROM "joined" + LEFT JOIN "latest_deleted" + ON ( + COALESCE("joined"."id", '') || '|' || COALESCE("joined"."name", '') + ) = "latest_deleted"."_key0" + AND COALESCE("joined"."name", '') = "latest_deleted"."_key1" +), "inserted_rows" AS ( + SELECT + "id", + "name", + "price", + "test_UPDATED_at", + "test_UPDATED_at" AS "test_valid_from", + CAST(NULL AS Nullable(DateTime64(6))) AS "test_valid_to" + FROM "joined" + WHERE + "joined"."test_UPDATED_at" > "joined"."t_test_UPDATED_at" +) +SELECT "id", "name", "price", "test_UPDATED_at", "test_valid_from", "test_valid_to" FROM "static" +UNION ALL SELECT "id", "name", "price", "test_UPDATED_at", "test_valid_from", "test_valid_to" FROM "updated_rows" +UNION ALL SELECT "id", "name", "price", "test_UPDATED_at", "test_valid_from", "test_valid_to" FROM "inserted_rows" +SETTINGS join_use_nulls = 1 + """, + dialect=adapter.dialect, + ).sql(adapter.dialect) + + +def test_scd_type_2_by_column( + 
make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable +): + adapter = make_mocked_engine_adapter(ClickhouseEngineAdapter) + + temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") + table_name = "target" + temp_table_mock.side_effect = [ + make_temp_table_name(table_name, "efgh"), + make_temp_table_name(table_name, "abcd"), + ] + + mocker.patch.object( + adapter, + "get_data_objects", + return_value=[DataObject(schema="", name=table_name, type="table")], + ) + + fetchone_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchone") + fetchone_mock.return_value = None + + # The SCD query we build must specify the setting join_use_nulls = 1. We need to ensure that our + # setting on the outer query doesn't override the value the user expects. + # + # This test's user query does not contain a setting "join_use_nulls", so we determine whether or not + # to inject it based on the current server value. The mocked server value is 0, so we should inject. 
+ fetchone_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchone") + fetchone_mock.return_value = "0" + + adapter.scd_type_2_by_column( + target_table="target", + source_table=t.cast(exp.Select, parse_one("SELECT id, name, price FROM source")), + unique_key=[exp.column("id")], + valid_from_col=exp.column("test_VALID_from", quoted=True), + valid_to_col=exp.column("test_valid_to", quoted=True), + check_columns=[exp.column("name"), exp.column("price")], + target_columns_to_types={ + "id": exp.DataType.build("INT"), + "name": exp.DataType.build("VARCHAR"), + "price": exp.DataType.build("DOUBLE"), + "test_VALID_from": exp.DataType.build("TIMESTAMP"), + "test_valid_to": exp.DataType.build("TIMESTAMP"), + }, + execution_time=datetime(2020, 1, 1, 0, 0, 0), + ) + + assert to_sql_calls(adapter)[3] == parse_one( + """ +INSERT INTO "__temp_target_abcd" ("id", "name", "price", "test_VALID_from", "test_valid_to") +WITH "source" AS ( + SELECT DISTINCT ON ("id") + TRUE AS "_exists", + "id", + "name", + "price" + FROM ( + SELECT + "id", + "name", + "price" + FROM "source" + SETTINGS join_use_nulls = 0 + ) AS "raw_source" +), "static" AS ( + SELECT + "id", + "name", + "price", + "test_VALID_from", + "test_valid_to", + TRUE AS "_exists" + FROM "__temp_target_efgh" + WHERE + NOT "test_valid_to" IS NULL +), "latest" AS ( + SELECT + "id", + "name", + "price", + "test_VALID_from", + "test_valid_to", + TRUE AS "_exists" + FROM "__temp_target_efgh" + WHERE + "test_valid_to" IS NULL +), "deleted" AS ( + SELECT + "static"."id", + "static"."name", + "static"."price", + "static"."test_VALID_from", + "static"."test_valid_to" + FROM "static" + LEFT JOIN "latest" + ON "static"."id" = "latest"."id" + WHERE + "latest"."test_valid_to" IS NULL +), "latest_deleted" AS ( + SELECT + TRUE AS "_exists", + "id" AS "_key0", + MAX("test_valid_to") AS "test_valid_to" + FROM "deleted" + GROUP BY + "id" +), "joined" AS ( + SELECT + "source"."_exists" AS "_exists", + "latest"."id" AS 
"t_id", + "latest"."name" AS "t_name", + "latest"."price" AS "t_price", + "latest"."test_VALID_from" AS "t_test_VALID_from", + "latest"."test_valid_to" AS "t_test_valid_to", + "source"."id" AS "id", + "source"."name" AS "name", + "source"."price" AS "price" + FROM "latest" + LEFT JOIN "source" + ON "latest"."id" = "source"."id" + UNION ALL + SELECT + "source"."_exists" AS "_exists", + "latest"."id" AS "t_id", + "latest"."name" AS "t_name", + "latest"."price" AS "t_price", + "latest"."test_VALID_from" AS "t_test_VALID_from", + "latest"."test_valid_to" AS "t_test_valid_to", + "source"."id" AS "id", + "source"."name" AS "name", + "source"."price" AS "price" + FROM "latest" + RIGHT JOIN "source" + ON "latest"."id" = "source"."id" + WHERE + "latest"."_exists" IS NULL +), "updated_rows" AS ( + SELECT + COALESCE("joined"."t_id", "joined"."id") AS "id", + COALESCE("joined"."t_name", "joined"."name") AS "name", + COALESCE("joined"."t_price", "joined"."price") AS "price", + COALESCE("t_test_VALID_from", CAST('2020-01-01 00:00:00' AS Nullable(DateTime64(6)))) AS "test_VALID_from", + CASE + WHEN "joined"."_exists" IS NULL + OR ( + ( + NOT "joined"."t_id" IS NULL AND NOT "joined"."id" IS NULL + ) + AND ( + "joined"."name" <> "joined"."t_name" + OR ( + "joined"."t_name" IS NULL AND NOT "joined"."name" IS NULL + ) + OR ( + NOT "joined"."t_name" IS NULL AND "joined"."name" IS NULL + ) + OR "joined"."price" <> "joined"."t_price" + OR ( + "joined"."t_price" IS NULL AND NOT "joined"."price" IS NULL + ) + OR ( + NOT "joined"."t_price" IS NULL AND "joined"."price" IS NULL + ) + ) + ) + THEN CAST('2020-01-01 00:00:00' AS Nullable(DateTime64(6))) + ELSE "t_test_valid_to" + END AS "test_valid_to" + FROM "joined" + LEFT JOIN "latest_deleted" + ON "joined"."id" = "latest_deleted"."_key0" +), "inserted_rows" AS ( + SELECT + "id", + "name", + "price", + CAST('2020-01-01 00:00:00' AS Nullable(DateTime64(6))) AS "test_VALID_from", + CAST(NULL AS Nullable(DateTime64(6))) AS "test_valid_to" + 
FROM "joined" + WHERE + ( + NOT "joined"."t_id" IS NULL AND NOT "joined"."id" IS NULL + ) + AND ( + "joined"."name" <> "joined"."t_name" + OR ( + "joined"."t_name" IS NULL AND NOT "joined"."name" IS NULL + ) + OR ( + NOT "joined"."t_name" IS NULL AND "joined"."name" IS NULL + ) + OR "joined"."price" <> "joined"."t_price" + OR ( + "joined"."t_price" IS NULL AND NOT "joined"."price" IS NULL + ) + OR ( + NOT "joined"."t_price" IS NULL AND "joined"."price" IS NULL + ) + ) +) +SELECT "id", "name", "price", "test_VALID_from", "test_valid_to" FROM "static" UNION ALL SELECT "id", "name", "price", "test_VALID_from", "test_valid_to" FROM "updated_rows" UNION ALL SELECT "id", "name", "price", "test_VALID_from", "test_valid_to" FROM "inserted_rows" SETTINGS join_use_nulls = 1 + """, + dialect=adapter.dialect, + ).sql(adapter.dialect) + + +def test_insert_overwrite_by_condition_replace_partitioned( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable +): + adapter = make_mocked_engine_adapter(ClickhouseEngineAdapter) + + temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") + table_name = "target" + temp_table_mock.return_value = make_temp_table_name(table_name, "abcd") + + fetchone_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchone") + fetchone_mock.return_value = "dateTrunc('WEEK', ds)" + + insert_table_name = make_temp_table_name("new_records", "abcd") + existing_table_name = make_temp_table_name("existing_records", "abcd") + + source_queries, columns_to_types = adapter._get_source_queries_and_columns_to_types( + parse_one(f"SELECT * FROM {insert_table_name}"), + { + "id": exp.DataType.build("Int8", dialect="clickhouse"), + "ds": exp.DataType.build("Date", dialect="clickhouse"), + }, + existing_table_name, + ) + + adapter._insert_overwrite_by_condition( + existing_table_name.sql(), + source_queries, + columns_to_types, + ) + + assert to_sql_calls(adapter) == [ + 
"CREATE TABLE __temp_target_abcd AS __temp_existing_records_abcd", + 'INSERT INTO "__temp_target_abcd" ("id", "ds") SELECT "id", "ds" FROM (SELECT * FROM "__temp_new_records_abcd") AS "_subquery"', + 'EXCHANGE TABLES "__temp_existing_records_abcd" AND "__temp_target_abcd"', + 'DROP TABLE IF EXISTS "__temp_target_abcd"', + ] + + +def test_insert_overwrite_by_condition_replace( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable +): + adapter = make_mocked_engine_adapter(ClickhouseEngineAdapter) + + temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") + table_name = "target" + temp_table_mock.return_value = make_temp_table_name(table_name, "abcd") + + fetchone_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchone") + fetchone_mock.return_value = None + + insert_table_name = make_temp_table_name("new_records", "abcd") + existing_table_name = make_temp_table_name("existing_records", "abcd") + + source_queries, columns_to_types = adapter._get_source_queries_and_columns_to_types( + parse_one(f"SELECT * FROM {insert_table_name}"), + { + "id": exp.DataType.build("Int8", dialect="clickhouse"), + "ds": exp.DataType.build("Date", dialect="clickhouse"), + }, + existing_table_name, + ) + + adapter._insert_overwrite_by_condition( + existing_table_name.sql(), + source_queries, + columns_to_types, + ) + + to_sql_calls(adapter) == [ + "CREATE TABLE __temp_target_abcd AS __temp_existing_records_abcd", + 'INSERT INTO "__temp_target_abcd" ("id", "ds") SELECT "id", "ds" FROM (SELECT * FROM "__temp_new_records_abcd") AS "_subquery"', + 'EXCHANGE TABLES "__temp_existing_records_abcd" AND "__temp_target_abcd"', + 'DROP TABLE IF EXISTS "__temp_target_abcd"', + ] + + +def test_insert_overwrite_by_condition_where_partitioned( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable +): + adapter = 
make_mocked_engine_adapter(ClickhouseEngineAdapter) + + temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") + table_name = "target" + temp_table_mock.return_value = make_temp_table_name(table_name, "abcd") + + fetchone_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchone") + fetchone_mock.return_value = "dateTrunc('WEEK', ds)" + + fetchall_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchall") + fetchall_mock.side_effect = [ + [("1",), ("2",), ("3",), ("4",)], + ["1", "2", "4"], + ] + + insert_table_name = make_temp_table_name("new_records", "abcd") + existing_table_name = make_temp_table_name("existing_records", "abcd") + + source_queries, columns_to_types = adapter._get_source_queries_and_columns_to_types( + parse_one(f"SELECT * FROM {insert_table_name}"), + { + "id": exp.DataType.build("Int8", dialect="clickhouse"), + "ds": exp.DataType.build("Date", dialect="clickhouse"), + }, + existing_table_name, + ) + + adapter._insert_overwrite_by_condition( + existing_table_name.sql(), + source_queries, + columns_to_types, + exp.Between( + this=exp.column("ds"), + low=parse_one("'2024-02-15'"), + high=parse_one("'2024-04-30'"), + ), + ) + + to_sql_calls(adapter) == [ + "CREATE TABLE __temp_target_abcd AS __temp_existing_records_abcd", + """INSERT INTO "__temp_target_abcd" ("id", "ds") SELECT "id", "ds" FROM (SELECT * FROM "__temp_new_records_abcd") AS "_subquery" WHERE "ds" BETWEEN '2024-02-15' AND '2024-04-30'""", + """CREATE TABLE IF NOT EXISTS "__temp_target_abcd" ENGINE=MergeTree ORDER BY () AS SELECT DISTINCT "partition_id" FROM (SELECT "_partition_id" AS "partition_id" FROM "__temp_existing_records_abcd" WHERE "ds" BETWEEN '2024-02-15' AND '2024-04-30' UNION DISTINCT SELECT "_partition_id" AS "partition_id" FROM "__temp_target_abcd") AS "_affected_partitions\"""", + """INSERT INTO "__temp_target_abcd" SELECT "id", "ds" FROM "__temp_existing_records_abcd" WHERE NOT ("ds" 
BETWEEN '2024-02-15' AND '2024-04-30') AND "_partition_id" IN (SELECT "partition_id" FROM "__temp_target_abcd")""", + """ALTER TABLE "__temp_existing_records_abcd" REPLACE PARTITION ID '1' FROM "__temp_target_abcd", REPLACE PARTITION ID '2' FROM "__temp_target_abcd", REPLACE PARTITION ID '4' FROM "__temp_target_abcd", DROP PARTITION ID '3'""", + 'DROP TABLE IF EXISTS "__temp_target_abcd"', + ] + + +def test_insert_overwrite_by_condition_by_key( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable +): + adapter = make_mocked_engine_adapter(ClickhouseEngineAdapter) + + temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") + table_name = "target" + temp_table_mock.return_value = make_temp_table_name(table_name, "abcd") + + fetchone_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchone") + fetchone_mock.return_value = None + + insert_table_name = make_temp_table_name("new_records", "abcd") + existing_table_name = make_temp_table_name("existing_records", "abcd") + + source_queries, columns_to_types = adapter._get_source_queries_and_columns_to_types( + parse_one(f"SELECT * FROM {insert_table_name}"), + { + "id": exp.DataType.build("Int8", dialect="clickhouse"), + "ds": exp.DataType.build("Date", dialect="clickhouse"), + }, + existing_table_name, + ) + adapter._insert_overwrite_by_condition( + existing_table_name.sql(), + source_queries, + columns_to_types, + dynamic_key=[exp.column("id")], + dynamic_key_exp=exp.column("id"), + dynamic_key_unique=True, + ) + + adapter._insert_overwrite_by_condition( + existing_table_name.sql(), + source_queries, + columns_to_types, + dynamic_key=[exp.column("id")], + dynamic_key_exp=exp.column("id"), + dynamic_key_unique=False, + ) + + to_sql_calls(adapter) == [ + "CREATE TABLE __temp_target_abcd AS __temp_existing_records_abcd", + 'INSERT INTO "__temp_target_abcd" ("id", "ds") SELECT "id", "ds" FROM (SELECT DISTINCT ON ("id") * 
FROM "__temp_new_records_abcd") AS "_subquery"', + 'INSERT INTO "__temp_target_abcd" SELECT "id", "ds" FROM "__temp_existing_records_abcd" WHERE NOT ("id" IN (SELECT "id" FROM "__temp_target_abcd"))', + 'EXCHANGE TABLES "__temp_existing_records_abcd" AND "__temp_target_abcd"', + 'DROP TABLE IF EXISTS "__temp_target_abcd"', + "CREATE TABLE __temp_target_abcd AS __temp_existing_records_abcd", + 'INSERT INTO "__temp_target_abcd" ("id", "ds") SELECT "id", "ds" FROM (SELECT * FROM "__temp_new_records_abcd") AS "_subquery"', + 'INSERT INTO "__temp_target_abcd" SELECT "id", "ds" FROM "__temp_existing_records_abcd" WHERE NOT ("id" IN (SELECT "id" FROM "__temp_target_abcd"))', + 'EXCHANGE TABLES "__temp_existing_records_abcd" AND "__temp_target_abcd"', + 'DROP TABLE IF EXISTS "__temp_target_abcd"', + ] + + +def test_insert_overwrite_by_condition_by_key_partitioned( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable +): + adapter = make_mocked_engine_adapter(ClickhouseEngineAdapter) + + temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") + table_name = "target" + temp_table_mock.return_value = make_temp_table_name(table_name, "abcd") + + fetchone_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchone") + fetchone_mock.side_effect = ["dateTrunc('WEEK', ds)", "dateTrunc('WEEK', ds)"] + + fetchall_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchall") + fetchall_mock.side_effect = [ + [("1",), ("2",), ("3",), ("4",)], + ["1", "2", "4"], + [("1",), ("2",), ("3",), ("4",)], + ["1", "2", "4"], + ] + + insert_table_name = make_temp_table_name("new_records", "abcd") + existing_table_name = make_temp_table_name("existing_records", "abcd") + + source_queries, columns_to_types = adapter._get_source_queries_and_columns_to_types( + parse_one(f"SELECT * FROM {insert_table_name}"), + { + "id": exp.DataType.build("Int8", dialect="clickhouse"), + 
"ds": exp.DataType.build("Date", dialect="clickhouse"), + }, + existing_table_name, + ) + adapter._insert_overwrite_by_condition( + existing_table_name.sql(), + source_queries, + columns_to_types, + dynamic_key=[exp.column("id")], + dynamic_key_exp=exp.column("id"), + dynamic_key_unique=True, + ) + + adapter._insert_overwrite_by_condition( + existing_table_name.sql(), + source_queries, + columns_to_types, + dynamic_key=[exp.column("id")], + dynamic_key_exp=exp.column("id"), + dynamic_key_unique=False, + ) + + to_sql_calls(adapter) == [ + "CREATE TABLE __temp_target_abcd AS __temp_existing_records_abcd", + 'INSERT INTO "__temp_target_abcd" ("id", "ds") SELECT "id", "ds" FROM (SELECT DISTINCT ON ("id") * FROM "__temp_new_records_abcd") AS "_subquery"', + 'CREATE TABLE IF NOT EXISTS "__temp_target_abcd" ENGINE=MergeTree ORDER BY () AS SELECT DISTINCT "partition_id" FROM (SELECT "_partition_id" AS "partition_id" FROM "__temp_existing_records_abcd" WHERE "id" IN (SELECT "id" FROM "__temp_target_abcd") UNION DISTINCT SELECT "_partition_id" AS "partition_id" FROM "__temp_target_abcd") AS "_affected_partitions"', + 'INSERT INTO "__temp_target_abcd" SELECT "id", "ds" FROM "__temp_existing_records_abcd" WHERE NOT ("id" IN (SELECT "id" FROM "__temp_target_abcd")) AND "_partition_id" IN (SELECT "partition_id" FROM "__temp_target_abcd")', + """ALTER TABLE "__temp_existing_records_abcd" REPLACE PARTITION ID '2' FROM "__temp_target_abcd", REPLACE PARTITION ID '1' FROM "__temp_target_abcd", REPLACE PARTITION ID '4' FROM "__temp_target_abcd", DROP PARTITION ID '3'""", + 'DROP TABLE IF EXISTS "__temp_target_abcd"', + "CREATE TABLE __temp_target_abcd AS __temp_existing_records_abcd", + 'INSERT INTO "__temp_target_abcd" ("id", "ds") SELECT "id", "ds" FROM (SELECT * FROM "__temp_new_records_abcd") AS "_subquery"', + 'CREATE TABLE IF NOT EXISTS "__temp_target_abcd" ENGINE=MergeTree ORDER BY () AS SELECT DISTINCT "partition_id" FROM (SELECT "_partition_id" AS "partition_id" FROM 
"__temp_existing_records_abcd" WHERE "id" IN (SELECT "id" FROM "__temp_target_abcd") UNION DISTINCT SELECT "_partition_id" AS "partition_id" FROM "__temp_target_abcd") AS "_affected_partitions"', + 'INSERT INTO "__temp_target_abcd" SELECT "id", "ds" FROM "__temp_existing_records_abcd" WHERE NOT ("id" IN (SELECT "id" FROM "__temp_target_abcd")) AND "_partition_id" IN (SELECT "partition_id" FROM "__temp_target_abcd")', + """ALTER TABLE "__temp_existing_records_abcd" REPLACE PARTITION ID '2' FROM "__temp_target_abcd", REPLACE PARTITION ID '1' FROM "__temp_target_abcd", REPLACE PARTITION ID '4' FROM "__temp_target_abcd", DROP PARTITION ID '3'""", + 'DROP TABLE IF EXISTS "__temp_target_abcd"', + ] + + +def test_insert_overwrite_by_condition_inc_by_partition( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable +): + adapter = make_mocked_engine_adapter(ClickhouseEngineAdapter) + + temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") + table_name = "target" + temp_table_mock.return_value = make_temp_table_name(table_name, "abcd") + + fetchone_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchone") + fetchone_mock.return_value = "dateTrunc('WEEK', ds)" + + fetchall_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.fetchall") + fetchall_mock.return_value = [("1",), ("2",), ("4",)] + + insert_table_name = make_temp_table_name("new_records", "abcd") + existing_table_name = make_temp_table_name("existing_records", "abcd") + + source_queries, columns_to_types = adapter._get_source_queries_and_columns_to_types( + parse_one(f"SELECT * FROM {insert_table_name}"), + { + "id": exp.DataType.build("Int8", dialect="clickhouse"), + "ds": exp.DataType.build("Date", dialect="clickhouse"), + }, + existing_table_name, + ) + adapter._insert_overwrite_by_condition( + existing_table_name.sql(), + source_queries, + columns_to_types, + 
keep_existing_partition_rows=False, + ) + + to_sql_calls(adapter) == [ + "CREATE TABLE __temp_target_abcd AS __temp_existing_records_abcd", + 'INSERT INTO "__temp_target_abcd" ("id", "ds") SELECT "id", "ds" FROM (SELECT * FROM "__temp_new_records_abcd") AS "_subquery"', + """ALTER TABLE "__temp_existing_records_abcd" REPLACE PARTITION ID '1' FROM "__temp_target_abcd", REPLACE PARTITION ID '2' FROM "__temp_target_abcd", REPLACE PARTITION ID '4' FROM "__temp_target_abcd\"""", + 'DROP TABLE IF EXISTS "__temp_target_abcd"', + ] + + +def test_to_time_column(): + # we should get DateTime64(6) back for any temporal type other than explicit DateTime64 + expressions = d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_TIME_RANGE( + time_column (ds) + ), + dialect clickhouse + ); + + SELECT ds::datetime + """ + ) + model = load_sql_based_model(expressions) + assert ( + model.convert_to_time_column("2022-01-01 00:00:00.000001").sql("clickhouse") + == "CAST('2022-01-01 00:00:00.000001' AS DateTime64(6))" + ) + + # We should respect the user's DateTime64 precision if specified + expressions = d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_TIME_RANGE( + time_column (ds) + ), + dialect clickhouse + ); + + SELECT ds::DateTime64(4) + """ + ) + model = load_sql_based_model(expressions) + assert ( + model.convert_to_time_column("2022-01-01 00:00:00.000001").sql("clickhouse") + == "CAST('2022-01-01 00:00:00.000001' AS DateTime64(4))" + ) + + # We should respect the user's DateTime64 precision if specified, even if we're making it nullable + from sqlmesh.utils.date import to_time_column + + expressions = d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_TIME_RANGE( + time_column (ds) + ), + dialect clickhouse + ); + + SELECT ds::DateTime64(4) + """ + ) + model = load_sql_based_model(expressions) + assert ( + to_time_column( + "2022-01-01 00:00:00.000001", + exp.DataType.build("DateTime64(4)", dialect="clickhouse"), + dialect="clickhouse", + 
nullable=True, + ).sql("clickhouse") + == "CAST('2022-01-01 00:00:00.000001' AS Nullable(DateTime64(4)))" + ) + + +def test_exchange_tables( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable +): + from clickhouse_connect.driver.exceptions import DatabaseError # type: ignore + + adapter = make_mocked_engine_adapter(ClickhouseEngineAdapter) + + temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") + temp_table_mock.return_value = make_temp_table_name("table1", "abcd") + + execute_mock = mocker.patch("sqlmesh.core.engine_adapter.ClickhouseEngineAdapter.execute") + execute_mock.side_effect = [ + DatabaseError( + "DB::Exception: Moving tables between databases of different engines is not supported. (NOT_IMPLEMENTED)" + ), + None, + None, + None, + ] + + adapter._exchange_tables("table1", "table2") + + # The EXCHANGE TABLES call errored, so we RENAME TABLE instead + assert [ + quote_identifiers(call.args[0]).sql("clickhouse") + if isinstance(call.args[0], exp.Expression) + else call.args[0] + for call in execute_mock.call_args_list + ] == [ + 'EXCHANGE TABLES "table1" AND "table2"', + 'RENAME TABLE "table1" TO "__temp_table1_abcd"', + 'RENAME TABLE "table2" TO "table1"', + 'DROP TABLE IF EXISTS "__temp_table1_abcd"', + ] diff --git a/tests/core/engine_adapter/test_databricks.py b/tests/core/engine_adapter/test_databricks.py index a825af44e1..de91fd3b70 100644 --- a/tests/core/engine_adapter/test_databricks.py +++ b/tests/core/engine_adapter/test_databricks.py @@ -1,13 +1,15 @@ # type: ignore import typing as t -import pandas as pd +import pandas as pd # noqa: TID253 import pytest from pytest_mock import MockFixture from sqlglot import exp, parse_one from sqlmesh.core import dialect as d from sqlmesh.core.engine_adapter import DatabricksEngineAdapter +from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType +from sqlmesh.core.node import IntervalUnit from 
tests.core.engine_adapter import to_sql_calls pytestmark = [pytest.mark.databricks, pytest.mark.engine] @@ -18,7 +20,10 @@ def test_replace_query_not_exists(mocker: MockFixture, make_mocked_engine_adapte "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.table_exists", return_value=False, ) - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") adapter.replace_query( "test_table", parse_one("SELECT a FROM tbl"), {"a": exp.DataType.build("INT")} ) @@ -33,7 +38,15 @@ def test_replace_query_exists(mocker: MockFixture, make_mocked_engine_adapter: t "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.table_exists", return_value=True, ) - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") + mocker.patch.object( + adapter, + "_get_data_objects", + return_value=[DataObject(schema="", name="test_table", type="table")], + ) adapter.replace_query("test_table", parse_one("SELECT a FROM tbl"), {"a": "int"}) assert to_sql_calls(adapter) == [ @@ -48,7 +61,10 @@ def test_replace_query_pandas_not_exists( "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.table_exists", return_value=False, ) - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) adapter.replace_query( "test_table", df, {"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")} @@ -64,7 +80,15 @@ 
def test_replace_query_pandas_exists(mocker: MockFixture, make_mocked_engine_ada "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.table_exists", return_value=True, ) - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") + mocker.patch.object( + adapter, + "_get_data_objects", + return_value=[DataObject(schema="", name="test_table", type="table")], + ) df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) adapter.replace_query( "test_table", df, {"a": exp.DataType.build("int"), "b": exp.DataType.build("int")} @@ -75,41 +99,230 @@ def test_replace_query_pandas_exists(mocker: MockFixture, make_mocked_engine_ada ] -def test_clone_table(make_mocked_engine_adapter: t.Callable): - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) +def test_clone_table(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") adapter.clone_table("target_table", "source_table") adapter.cursor.execute.assert_called_once_with( - "CREATE TABLE `target_table` SHALLOW CLONE `source_table`" + "CREATE TABLE IF NOT EXISTS `target_table` SHALLOW CLONE `source_table`" ) -def test_set_current_catalog(make_mocked_engine_adapter: t.Callable): - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) - adapter.set_current_catalog("test_catalog") +def test_set_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") + adapter.set_current_catalog("test_catalog2") - assert to_sql_calls(adapter) == ["USE CATALOG `test_catalog`"] + assert 
to_sql_calls(adapter) == ["USE CATALOG `test_catalog2`"] -def test_get_current_catalog(make_mocked_engine_adapter: t.Callable): - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) +def test_get_current_catalog(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") adapter.cursor.fetchone.return_value = ("test_catalog",) assert adapter.get_current_catalog() == "test_catalog" assert to_sql_calls(adapter) == ["SELECT CURRENT_CATALOG()"] -def test_get_current_database(make_mocked_engine_adapter: t.Callable): - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) +def test_get_current_schema(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") adapter.cursor.fetchone.return_value = ("test_database",) - assert adapter.get_current_database() == "test_database" + assert adapter._get_current_schema() == "test_database" assert to_sql_calls(adapter) == ["SELECT CURRENT_DATABASE()"] +def test_sync_grants_config(make_mocked_engine_adapter: t.Callable, mocker: MockFixture): + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="main") + relation = exp.to_table("main.test_schema.test_table", dialect="databricks") + new_grants_config = { + "SELECT": ["group1", "group2"], + "MODIFY": ["writers"], + } + + current_grants = [ + ("SELECT", "legacy"), + ("REFRESH", "stale"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = 
executed_query.sql(dialect="databricks") + expected_sql = ( + "SELECT privilege_type, grantee FROM main.information_schema.table_privileges " + "WHERE table_catalog = 'main' AND table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = CURRENT_USER() AND grantee <> CURRENT_USER() AND inherited_from = 'NONE'" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 5 + + assert "GRANT SELECT ON TABLE `main`.`test_schema`.`test_table` TO `group1`" in sql_calls + assert "GRANT SELECT ON TABLE `main`.`test_schema`.`test_table` TO `group2`" in sql_calls + assert "GRANT MODIFY ON TABLE `main`.`test_schema`.`test_table` TO `writers`" in sql_calls + assert "REVOKE SELECT ON TABLE `main`.`test_schema`.`test_table` FROM `legacy`" in sql_calls + assert "REVOKE REFRESH ON TABLE `main`.`test_schema`.`test_table` FROM `stale`" in sql_calls + + +def test_sync_grants_config_with_overlaps( + make_mocked_engine_adapter: t.Callable, mocker: MockFixture +): + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="main") + relation = exp.to_table("main.test_schema.test_table", dialect="databricks") + new_grants_config = { + "SELECT": ["shared", "new_role"], + "MODIFY": ["shared", "writer"], + } + + current_grants = [ + ("SELECT", "shared"), + ("SELECT", "legacy"), + ("MODIFY", "shared"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="databricks") + expected_sql = ( + "SELECT privilege_type, grantee FROM main.information_schema.table_privileges " + "WHERE table_catalog = 'main' AND table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = CURRENT_USER() AND grantee <> CURRENT_USER() AND inherited_from = 'NONE'" + ) + assert executed_sql == 
expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 3 + + assert "GRANT SELECT ON TABLE `main`.`test_schema`.`test_table` TO `new_role`" in sql_calls + assert "GRANT MODIFY ON TABLE `main`.`test_schema`.`test_table` TO `writer`" in sql_calls + assert "REVOKE SELECT ON TABLE `main`.`test_schema`.`test_table` FROM `legacy`" in sql_calls + + +@pytest.mark.parametrize( + "table_type, expected_keyword", + [ + (DataObjectType.TABLE, "TABLE"), + (DataObjectType.VIEW, "VIEW"), + (DataObjectType.MATERIALIZED_VIEW, "MATERIALIZED VIEW"), + (DataObjectType.MANAGED_TABLE, "TABLE"), + ], +) +def test_sync_grants_config_object_kind( + make_mocked_engine_adapter: t.Callable, + mocker: MockFixture, + table_type: DataObjectType, + expected_keyword: str, +) -> None: + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="main") + relation = exp.to_table("main.test_schema.test_object", dialect="databricks") + + mocker.patch.object(adapter, "fetchall", return_value=[]) + + adapter.sync_grants_config(relation, {"SELECT": ["test"]}, table_type) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + f"GRANT SELECT ON {expected_keyword} `main`.`test_schema`.`test_object` TO `test`" + ] + + +def test_sync_grants_config_quotes(make_mocked_engine_adapter: t.Callable, mocker: MockFixture): + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="`test_db`") + relation = exp.to_table("`test_db`.`test_schema`.`test_table`", dialect="databricks") + new_grants_config = { + "SELECT": ["group1", "group2"], + "MODIFY": ["writers"], + } + + current_grants = [ + ("SELECT", "legacy"), + ("REFRESH", "stale"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="databricks") + expected_sql = ( + 
"SELECT privilege_type, grantee FROM `test_db`.information_schema.table_privileges " + "WHERE table_catalog = 'test_db' AND table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = CURRENT_USER() AND grantee <> CURRENT_USER() AND inherited_from = 'NONE'" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 5 + + assert "GRANT SELECT ON TABLE `test_db`.`test_schema`.`test_table` TO `group1`" in sql_calls + assert "GRANT SELECT ON TABLE `test_db`.`test_schema`.`test_table` TO `group2`" in sql_calls + assert "GRANT MODIFY ON TABLE `test_db`.`test_schema`.`test_table` TO `writers`" in sql_calls + assert "REVOKE SELECT ON TABLE `test_db`.`test_schema`.`test_table` FROM `legacy`" in sql_calls + assert "REVOKE REFRESH ON TABLE `test_db`.`test_schema`.`test_table` FROM `stale`" in sql_calls + + +def test_sync_grants_config_no_catalog_or_schema( + make_mocked_engine_adapter: t.Callable, mocker: MockFixture +): + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="main_catalog") + relation = exp.to_table("test_table", dialect="databricks") + new_grants_config = { + "SELECT": ["group1", "group2"], + "MODIFY": ["writers"], + } + + current_grants = [ + ("SELECT", "legacy"), + ("REFRESH", "stale"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + mocker.patch.object(adapter, "_get_current_schema", return_value="schema") + mocker.patch.object(adapter, "get_current_catalog", return_value="main_catalog") + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="databricks") + expected_sql = ( + "SELECT privilege_type, grantee FROM `main_catalog`.information_schema.table_privileges " + "WHERE table_catalog = 'main_catalog' AND table_schema = 'schema' AND table_name = 'test_table' " + "AND grantor = 
CURRENT_USER() AND grantee <> CURRENT_USER() AND inherited_from = 'NONE'" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 5 + + assert "GRANT SELECT ON TABLE `test_table` TO `group1`" in sql_calls + assert "GRANT SELECT ON TABLE `test_table` TO `group2`" in sql_calls + assert "GRANT MODIFY ON TABLE `test_table` TO `writers`" in sql_calls + assert "REVOKE SELECT ON TABLE `test_table` FROM `legacy`" in sql_calls + assert "REVOKE REFRESH ON TABLE `test_table` FROM `stale`" in sql_calls + + def test_insert_overwrite_by_partition_query( make_mocked_engine_adapter: t.Callable, mocker: MockFixture, make_temp_table_name: t.Callable ): - adapter = make_mocked_engine_adapter(DatabricksEngineAdapter) + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") table_name = "test_schema.test_table" @@ -123,7 +336,7 @@ def test_insert_overwrite_by_partition_query( d.parse_one("DATETIME_TRUNC(ds, MONTH)"), d.parse_one("b"), ], - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "ds": exp.DataType.build("DATETIME"), "b": exp.DataType.build("boolean"), @@ -136,3 +349,180 @@ def test_insert_overwrite_by_partition_query( "INSERT INTO `test_schema`.`test_table` REPLACE WHERE CONCAT_WS('__SQLMESH_DELIM__', DATE_TRUNC('MONTH', `ds`), `b`) IN (SELECT DISTINCT CONCAT_WS('__SQLMESH_DELIM__', DATE_TRUNC('MONTH', `ds`), `b`) FROM `test_schema`.`temp_test_table_abcdefgh`) SELECT `a`, `ds`, `b` FROM `test_schema`.`temp_test_table_abcdefgh`", "DROP TABLE IF EXISTS `test_schema`.`temp_test_table_abcdefgh`", ] + + +def test_materialized_view_properties(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): + mocker.patch( + 
"sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") + + adapter.create_view( + "test_table", + parse_one("SELECT 1"), + materialized=True, + materialized_properties={ + "partitioned_by": [exp.column("ds")], + # Clustered by is not supported so we are confirming it is ignored + "clustered_by": [exp.column("a")], + "partition_interval_unit": IntervalUnit.DAY, + }, + ) + + sql_calls = to_sql_calls(adapter) + # https://docs.databricks.com/en/sql/language-manual/sql-ref-syntax-ddl-create-materialized-view.html#syntax + assert sql_calls == [ + "CREATE OR REPLACE MATERIALIZED VIEW `test_table` PARTITIONED BY (`ds`) AS SELECT 1", + ] + + +def test_materialized_view_with_column_comments( + mocker: MockFixture, make_mocked_engine_adapter: t.Callable +): + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") + mocker.patch.object(adapter, "get_current_catalog", return_value="test_catalog") + + adapter.create_view( + "test_view", + parse_one("SELECT a, b FROM source_table"), + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("STRING"), + }, + materialized=True, + column_descriptions={ + "a": "column a description", + "b": "column b description", + }, + ) + + sql_calls = to_sql_calls(adapter) + # Databricks requires column types when column comments are present in materialized views + assert sql_calls == [ + "CREATE OR REPLACE MATERIALIZED VIEW `test_view` (`a` INT COMMENT 'column a description', `b` STRING COMMENT 'column b description') AS SELECT `a`, `b` FROM `source_table`", + ] + + +def test_regular_view_with_column_comments( + mocker: MockFixture, make_mocked_engine_adapter: t.Callable +): + mocker.patch( + 
"sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") + mocker.patch.object(adapter, "get_current_catalog", return_value="test_catalog") + + adapter.create_view( + "test_view", + parse_one("SELECT a, b FROM source_table"), + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("STRING"), + }, + materialized=False, + column_descriptions={ + "a": "column a description", + "b": "column b description", + }, + ) + + sql_calls = to_sql_calls(adapter) + # Regular views should NOT include column types even when column comments are present + assert sql_calls == [ + "CREATE OR REPLACE VIEW `test_view` (`a` COMMENT 'column a description', `b` COMMENT 'column b description') AS SELECT `a`, `b` FROM `source_table`", + ] + + +def test_create_table_clustered_by(mocker: MockFixture, make_mocked_engine_adapter: t.Callable): + mocker.patch( + "sqlmesh.core.engine_adapter.databricks.DatabricksEngineAdapter.set_current_catalog" + ) + adapter = make_mocked_engine_adapter(DatabricksEngineAdapter, default_catalog="test_catalog") + + columns_to_types = { + "cola": exp.DataType.build("INT"), + "colb": exp.DataType.build("TEXT"), + } + adapter.create_table( + "test_table", + columns_to_types, + clustered_by=[exp.column("cola")], + ) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + "CREATE TABLE IF NOT EXISTS `test_table` (`cola` INT, `colb` STRING) CLUSTER BY (`cola`)", + ] + + +def test_get_data_objects_distinguishes_view_types(mocker): + adapter = DatabricksEngineAdapter(lambda: None, default_catalog="test_catalog") + + # (Databricks requires DBSQL Serverless or Pro warehouse to test materialized views which we do not have setup) + # so this mocks the fetchdf call to simulate the response we would expect from the correct SQL query + mock_df = pd.DataFrame( + [ + { + "name": "regular_view", + "schema": 
"test_schema", + "catalog": "test_catalog", + "type": "view", + }, + { + "name": "mat_view", + "schema": "test_schema", + "catalog": "test_catalog", + "type": "materialized_view", + }, + { + "name": "regular_table", + "schema": "test_schema", + "catalog": "test_catalog", + "type": "table", + }, + ] + ) + + mocker.patch.object(adapter, "fetchdf", return_value=mock_df) + + data_objects = adapter._get_data_objects( + schema_name=exp.Table(db="test_schema", catalog="test_catalog") + ) + + adapter.fetchdf.assert_called_once() + call_args = adapter.fetchdf.call_args + sql_query_exp = call_args[0][0] + + # _get_data_objects query should distinguish between VIEW and MATERIALIZED_VIEW types + sql_query = sql_query_exp.sql(dialect="databricks") + assert ( + "CASE table_type WHEN 'VIEW' THEN 'view' WHEN 'MATERIALIZED_VIEW' THEN 'materialized_view' ELSE 'table' END AS type" + in sql_query + ) + + objects_by_name = {obj.name: obj for obj in data_objects} + assert objects_by_name["regular_view"].type == DataObjectType.VIEW + assert objects_by_name["mat_view"].type == DataObjectType.MATERIALIZED_VIEW + assert objects_by_name["regular_table"].type == DataObjectType.TABLE + + +def test_drop_data_object_materialized_view_calls_correct_drop(mocker: MockFixture): + adapter = DatabricksEngineAdapter(lambda: None, default_catalog="test_catalog") + + mv_data_object = DataObject( + catalog="test_catalog", + schema="test_schema", + name="test_mv", + type=DataObjectType.MATERIALIZED_VIEW, + ) + + drop_view_mock = mocker.patch.object(adapter, "drop_view") + adapter.drop_data_object(mv_data_object) + + # Ensure drop_view is called with materialized=True + drop_view_mock.assert_called_once_with( + mv_data_object.to_table(), ignore_if_not_exists=True, materialized=True + ) diff --git a/tests/core/engine_adapter/test_duckdb.py b/tests/core/engine_adapter/test_duckdb.py index 79cfd8d950..9fd65a6e66 100644 --- a/tests/core/engine_adapter/test_duckdb.py +++ 
b/tests/core/engine_adapter/test_duckdb.py @@ -1,10 +1,10 @@ import typing as t -import pandas as pd +import pandas as pd # noqa: TID253 import pytest +from pytest_mock.plugin import MockerFixture from sqlglot import expressions as exp from sqlglot import parse_one - from sqlmesh.core.engine_adapter import DuckDBEngineAdapter, EngineAdapter from tests.core.engine_adapter import to_sql_calls @@ -75,3 +75,82 @@ def test_set_current_catalog(make_mocked_engine_adapter: t.Callable, duck_conn): assert to_sql_calls(adapter) == [ 'USE "test_catalog"', ] + + +def test_temporary_table(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(DuckDBEngineAdapter) + + mocker.patch.object(adapter, "get_current_catalog", return_value="test_catalog") + mocker.patch.object(adapter, "fetchone", return_value=("test_catalog",)) + + adapter.create_table( + "test_table", + {"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")}, + table_properties={"creatable_type": exp.Column(this=exp.Identifier(this="Temporary"))}, + ) + + assert to_sql_calls(adapter) == [ + 'CREATE TEMPORARY TABLE IF NOT EXISTS "test_table" ("a" INT, "b" INT)', + ] + + +def test_create_catalog(make_mocked_engine_adapter: t.Callable) -> None: + adapter: DuckDBEngineAdapter = make_mocked_engine_adapter(DuckDBEngineAdapter) + adapter.create_catalog(exp.to_identifier("foo")) + + assert to_sql_calls(adapter) == ["ATTACH IF NOT EXISTS 'foo.db' AS \"foo\""] + + +def test_create_catalog_motherduck(make_mocked_engine_adapter: t.Callable) -> None: + adapter: DuckDBEngineAdapter = make_mocked_engine_adapter( + DuckDBEngineAdapter, is_motherduck=True + ) + adapter.create_catalog(exp.to_identifier("foo")) + + assert to_sql_calls(adapter) == ['CREATE DATABASE IF NOT EXISTS "foo"'] + + +def test_drop_catalog(make_mocked_engine_adapter: t.Callable) -> None: + adapter: DuckDBEngineAdapter = make_mocked_engine_adapter(DuckDBEngineAdapter) + 
adapter.drop_catalog(exp.to_identifier("foo")) + + assert to_sql_calls(adapter) == ['DETACH DATABASE IF EXISTS "foo"'] + + +def test_drop_catalog_motherduck(make_mocked_engine_adapter: t.Callable) -> None: + adapter: DuckDBEngineAdapter = make_mocked_engine_adapter( + DuckDBEngineAdapter, is_motherduck=True + ) + adapter.drop_catalog(exp.to_identifier("foo")) + + assert to_sql_calls(adapter) == ['DROP DATABASE IF EXISTS "foo" CASCADE'] + + +def test_ducklake_partitioning(adapter: EngineAdapter, duck_conn, tmp_path): + catalog = "a_ducklake_db" + + duck_conn.install_extension("ducklake") + duck_conn.load_extension("ducklake") + duck_conn.execute( + f"ATTACH 'ducklake:{tmp_path}/{catalog}.ducklake' AS {catalog} (DATA_PATH '{tmp_path}');" + ) + + # no partitions on catalog creation + partition_info = duck_conn.execute( + f"SELECT * FROM __ducklake_metadata_{catalog}.main.ducklake_partition_info" + ).fetchdf() + assert partition_info.empty + + adapter.set_current_catalog(catalog) + adapter.create_schema("test_schema") + adapter.create_table( + "test_schema.test_table", + {"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")}, + partitioned_by=[exp.to_column("a"), exp.to_column("b")], + ) + + # 1 partition after table creation + partition_info = duck_conn.execute( + f"SELECT * FROM __ducklake_metadata_{catalog}.main.ducklake_partition_info" + ).fetchdf() + assert partition_info.shape[0] == 1 diff --git a/tests/core/engine_adapter/test_fabric.py b/tests/core/engine_adapter/test_fabric.py new file mode 100644 index 0000000000..a52218a097 --- /dev/null +++ b/tests/core/engine_adapter/test_fabric.py @@ -0,0 +1,288 @@ +# type: ignore + +import typing as t + +import pandas as pd # noqa: TID253 +import pytest +from pytest_mock import MockerFixture +from sqlglot import exp, parse_one + +from sqlmesh.core.engine_adapter import FabricEngineAdapter +from tests.core.engine_adapter import to_sql_calls +from sqlmesh.core.engine_adapter.shared import DataObject + +pytestmark 
= [pytest.mark.engine, pytest.mark.fabric] + + +@pytest.fixture +def adapter(make_mocked_engine_adapter: t.Callable) -> FabricEngineAdapter: + return make_mocked_engine_adapter(FabricEngineAdapter) + + +def test_columns(adapter: FabricEngineAdapter): + adapter.cursor.fetchall.return_value = [ + ("decimal_ps", "decimal", None, 5, 4), + ("decimal", "decimal", None, 18, 0), + ("float", "float", None, 53, None), + ("char_n", "char", 10, None, None), + ("varchar_n", "varchar", 10, None, None), + ("nvarchar_max", "nvarchar", -1, None, None), + ] + + assert adapter.columns("db.table") == { + "decimal_ps": exp.DataType.build("decimal(5, 4)", dialect=adapter.dialect), + "decimal": exp.DataType.build("decimal(18, 0)", dialect=adapter.dialect), + "float": exp.DataType.build("float(53)", dialect=adapter.dialect), + "char_n": exp.DataType.build("char(10)", dialect=adapter.dialect), + "varchar_n": exp.DataType.build("varchar(10)", dialect=adapter.dialect), + "nvarchar_max": exp.DataType.build("nvarchar(max)", dialect=adapter.dialect), + } + + # Verify that the adapter queries the uppercase INFORMATION_SCHEMA + adapter.cursor.execute.assert_called_once_with( + """SELECT [COLUMN_NAME], [DATA_TYPE], [CHARACTER_MAXIMUM_LENGTH], [NUMERIC_PRECISION], [NUMERIC_SCALE] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_NAME] = 'table' AND [TABLE_SCHEMA] = 'db';""" + ) + + +def test_table_exists(adapter: FabricEngineAdapter): + adapter.cursor.fetchone.return_value = (1,) + assert adapter.table_exists("db.table") + # Verify that the adapter queries the uppercase INFORMATION_SCHEMA + adapter.cursor.execute.assert_called_once_with( + """SELECT 1 FROM [INFORMATION_SCHEMA].[TABLES] WHERE [TABLE_NAME] = 'table' AND [TABLE_SCHEMA] = 'db';""" + ) + + adapter.cursor.fetchone.return_value = None + assert not adapter.table_exists("db.table") + + +def test_insert_overwrite_by_time_partition(adapter: FabricEngineAdapter): + adapter.insert_overwrite_by_time_partition( + "test_table", + parse_one("SELECT 
a, b FROM tbl"), + start="2022-01-01", + end="2022-01-02", + time_column="b", + time_formatter=lambda x, _: exp.Literal.string(x.strftime("%Y-%m-%d")), + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, + ) + + # Fabric adapter should use DELETE/INSERT strategy, not MERGE. + assert to_sql_calls(adapter) == [ + """DELETE FROM [test_table] WHERE [b] BETWEEN '2022-01-01' AND '2022-01-02';""", + """INSERT INTO [test_table] ([a], [b]) SELECT [a], [b] FROM (SELECT [a] AS [a], [b] AS [b] FROM [tbl]) AS [_subquery] WHERE [b] BETWEEN '2022-01-01' AND '2022-01-02';""", + ] + + +def test_replace_query(adapter: FabricEngineAdapter, mocker: MockerFixture): + mocker.patch.object( + adapter, + "_get_data_objects", + return_value=[DataObject(schema="", name="test_table", type="table")], + ) + adapter.replace_query( + "test_table", parse_one("SELECT a FROM tbl"), {"a": exp.DataType.build("int")} + ) + + # This behavior is inherited from MSSQLEngineAdapter and should be TRUNCATE + INSERT + assert to_sql_calls(adapter) == [ + "TRUNCATE TABLE [test_table];", + "INSERT INTO [test_table] ([a]) SELECT [a] FROM [tbl];", + ] + + +def test_alter_table_column_type_workaround(adapter: FabricEngineAdapter, mocker: MockerFixture): + """ + Tests the alter_table method's workaround for changing a column's data type. 
+ """ + # Mock set_current_catalog to avoid connection pool side effects + set_catalog_mock = mocker.patch.object(adapter, "set_current_catalog") + # Mock random_id to have a predictable temporary column name + mocker.patch("sqlmesh.core.engine_adapter.fabric.random_id", return_value="abcdef") + + alter_expression = exp.Alter( + this=exp.to_table("my_db.my_schema.my_table"), + actions=[ + exp.AlterColumn( + this=exp.to_column("col_a"), + dtype=exp.DataType.build("BIGINT"), + ) + ], + ) + + adapter.alter_table([alter_expression]) + + set_catalog_mock.assert_called_once_with("my_db") + + expected_calls = [ + "ALTER TABLE [my_schema].[my_table] ADD [col_a__abcdef] BIGINT;", + "UPDATE [my_schema].[my_table] SET [col_a__abcdef] = CAST([col_a] AS BIGINT);", + "ALTER TABLE [my_schema].[my_table] DROP COLUMN [col_a];", + "EXEC sp_rename 'my_schema.my_table.col_a__abcdef', 'col_a', 'COLUMN'", + ] + + assert to_sql_calls(adapter) == expected_calls + + +def test_alter_table_direct_alteration(adapter: FabricEngineAdapter, mocker: MockerFixture): + """ + Tests the alter_table method for direct alterations like adding a column. 
+ """ + set_catalog_mock = mocker.patch.object(adapter, "set_current_catalog") + + alter_expression = exp.Alter( + this=exp.to_table("my_db.my_schema.my_table"), + actions=[exp.ColumnDef(this=exp.to_column("new_col"), kind=exp.DataType.build("INT"))], + ) + + adapter.alter_table([alter_expression]) + + set_catalog_mock.assert_called_once_with("my_db") + + expected_calls = [ + "ALTER TABLE [my_schema].[my_table] ADD [new_col] INT;", + ] + + assert to_sql_calls(adapter) == expected_calls + + +def test_merge_pandas( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable +): + mocker.patch( + "sqlmesh.core.engine_adapter.fabric.FabricEngineAdapter.table_exists", + return_value=False, + ) + + adapter = make_mocked_engine_adapter(FabricEngineAdapter) + + temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") + table_name = "target" + temp_table_id = "abcdefgh" + temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id) + + df = pd.DataFrame({"id": [1, 2, 3], "ts": [1, 2, 3], "val": [4, 5, 6]}) + + # 1 key + adapter.merge( + target_table=table_name, + source_table=df, + target_columns_to_types={ + "id": exp.DataType.build("int"), + "ts": exp.DataType.build("TIMESTAMP"), + "val": exp.DataType.build("int"), + }, + unique_key=[exp.to_identifier("id")], + ) + adapter._connection_pool.get().bulk_copy.assert_called_with( + f"__temp_target_{temp_table_id}", [(1, 1, 4), (2, 2, 5), (3, 3, 6)] + ) + + assert to_sql_calls(adapter) == [ + f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INT, [ts] DATETIME2(6), [val] INT)');""", + f"MERGE INTO [target] AS [__MERGE_TARGET__] USING (SELECT CAST([id] AS INT) AS [id], CAST([ts] AS DATETIME2(6)) AS [ts], CAST([val] AS INT) AS [val] FROM [__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON [__MERGE_TARGET__].[id] = 
[__MERGE_SOURCE__].[id] WHEN MATCHED THEN UPDATE SET [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts], [__MERGE_TARGET__].[val] = [__MERGE_SOURCE__].[val] WHEN NOT MATCHED THEN INSERT ([id], [ts], [val]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts], [__MERGE_SOURCE__].[val]);", + f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];", + ] + + # 2 keys + adapter.cursor.reset_mock() + adapter._connection_pool.get().reset_mock() + temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id) + adapter.merge( + target_table=table_name, + source_table=df, + target_columns_to_types={ + "id": exp.DataType.build("int"), + "ts": exp.DataType.build("TIMESTAMP"), + "val": exp.DataType.build("int"), + }, + unique_key=[exp.to_identifier("id"), exp.to_column("ts")], + ) + adapter._connection_pool.get().bulk_copy.assert_called_with( + f"__temp_target_{temp_table_id}", [(1, 1, 4), (2, 2, 5), (3, 3, 6)] + ) + + assert to_sql_calls(adapter) == [ + f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INT, [ts] DATETIME2(6), [val] INT)');""", + f"MERGE INTO [target] AS [__MERGE_TARGET__] USING (SELECT CAST([id] AS INT) AS [id], CAST([ts] AS DATETIME2(6)) AS [ts], CAST([val] AS INT) AS [val] FROM [__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id] AND [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts] WHEN MATCHED THEN UPDATE SET [__MERGE_TARGET__].[val] = [__MERGE_SOURCE__].[val] WHEN NOT MATCHED THEN INSERT ([id], [ts], [val]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts], [__MERGE_SOURCE__].[val]);", + f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];", + ] + + +def test_merge_exists( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable +): + mocker.patch( + 
"sqlmesh.core.engine_adapter.fabric.FabricEngineAdapter.table_exists", + return_value=False, + ) + + adapter = make_mocked_engine_adapter(FabricEngineAdapter) + + temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") + table_name = "target" + temp_table_id = "abcdefgh" + temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id) + + df = pd.DataFrame({"id": [1, 2, 3], "ts": [1, 2, 3], "val": [4, 5, 6]}) + + # regular implementation + adapter.merge( + target_table=table_name, + source_table=df, + target_columns_to_types={ + "id": exp.DataType.build("int"), + "ts": exp.DataType.build("TIMESTAMP"), + "val": exp.DataType.build("int"), + }, + unique_key=[exp.to_identifier("id")], + ) + + assert to_sql_calls(adapter) == [ + f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INT, [ts] DATETIME2(6), [val] INT)');""", + f"MERGE INTO [target] AS [__MERGE_TARGET__] USING (SELECT CAST([id] AS INT) AS [id], CAST([ts] AS DATETIME2(6)) AS [ts], CAST([val] AS INT) AS [val] FROM [__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id] WHEN MATCHED THEN UPDATE SET [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts], [__MERGE_TARGET__].[val] = [__MERGE_SOURCE__].[val] WHEN NOT MATCHED THEN INSERT ([id], [ts], [val]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts], [__MERGE_SOURCE__].[val]);", + f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];", + ] + + # merge exists implementation + adapter.cursor.reset_mock() + adapter._connection_pool.get().reset_mock() + temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id) + adapter.merge( + target_table=table_name, + source_table=df, + target_columns_to_types={ + "id": exp.DataType.build("int"), + "ts": exp.DataType.build("TIMESTAMP"), + "val": exp.DataType.build("int"), + }, + 
unique_key=[exp.to_identifier("id")], + physical_properties={"mssql_merge_exists": True}, + ) + + assert to_sql_calls(adapter) == [ + f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INT, [ts] DATETIME2(6), [val] INT)');""", + f"MERGE INTO [target] AS [__MERGE_TARGET__] USING (SELECT CAST([id] AS INT) AS [id], CAST([ts] AS DATETIME2(6)) AS [ts], CAST([val] AS INT) AS [val] FROM [__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id] WHEN MATCHED AND EXISTS(SELECT [__MERGE_TARGET__].[ts], [__MERGE_TARGET__].[val] EXCEPT SELECT [__MERGE_SOURCE__].[ts], [__MERGE_SOURCE__].[val]) THEN UPDATE SET [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts], [__MERGE_TARGET__].[val] = [__MERGE_SOURCE__].[val] WHEN NOT MATCHED THEN INSERT ([id], [ts], [val]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts], [__MERGE_SOURCE__].[val]);", + f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];", + ] + + # merge exists and all model columns are keys + adapter.cursor.reset_mock() + adapter._connection_pool.get().reset_mock() + temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id) + adapter.merge( + target_table=table_name, + source_table=df, + target_columns_to_types={ + "id": exp.DataType.build("int"), + "ts": exp.DataType.build("TIMESTAMP"), + }, + unique_key=[exp.to_identifier("id"), exp.to_column("ts")], + physical_properties={"mssql_merge_exists": True}, + ) + + assert to_sql_calls(adapter) == [ + f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INT, [ts] DATETIME2(6))');""", + f"MERGE INTO [target] AS [__MERGE_TARGET__] USING (SELECT CAST([id] AS INT) AS [id], CAST([ts] AS DATETIME2(6)) AS [ts] FROM [__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON 
[__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id] AND [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts] WHEN NOT MATCHED THEN INSERT ([id], [ts]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts]);", + f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];", + ] diff --git a/tests/core/engine_adapter/test_integration.py b/tests/core/engine_adapter/test_integration.py deleted file mode 100644 index 950e24e001..0000000000 --- a/tests/core/engine_adapter/test_integration.py +++ /dev/null @@ -1,2527 +0,0 @@ -# type: ignore -from __future__ import annotations - -import os -import pathlib -import sys -import typing as t -from datetime import datetime, timedelta - -import numpy as np -import pandas as pd -import pytest -from sqlglot import exp, parse_one -from sqlglot.optimizer.normalize_identifiers import normalize_identifiers - -from sqlmesh import Config, Context, EngineAdapter -from sqlmesh.cli.example_project import init_example_project -from sqlmesh.core.config import load_config_from_paths -from sqlmesh.core.dialect import normalize_model_name -import sqlmesh.core.dialect as d -from sqlmesh.core.engine_adapter import SparkEngineAdapter, TrinoEngineAdapter -from sqlmesh.core.model import Model, load_sql_based_model -from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType -from sqlmesh.core.model.definition import create_sql_model -from sqlmesh.core.plan import Plan -from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory -from sqlmesh.utils import random_id -from sqlmesh.utils.date import now, to_date, to_ds, to_time_column, yesterday -from sqlmesh.utils.pydantic import PydanticModel -from tests.conftest import SushiDataValidator -from tests.utils.pandas import compare_dataframes - -if t.TYPE_CHECKING: - from sqlmesh.core.engine_adapter._typing import Query - - -TEST_SCHEMA = "test_schema" - - -class TestContext: - __test__ = False # prevent pytest trying to collect this as a test class - - def __init__( - self, - test_type: str, 
- engine_adapter: EngineAdapter, - gateway: str, - columns_to_types: t.Optional[t.Dict[str, t.Union[str, exp.DataType]]] = None, - ): - self.test_type = test_type - self.engine_adapter = engine_adapter - self.gateway = gateway - self._columns_to_types = columns_to_types - self.test_id = random_id(short=True) - self._context = None - - @property - def columns_to_types(self): - if self._columns_to_types is None: - self._columns_to_types = { - "id": exp.DataType.build("int"), - "ds": exp.DataType.build("string"), - } - return self._columns_to_types - - @columns_to_types.setter - def columns_to_types(self, value: t.Dict[str, t.Union[str, exp.DataType]]): - self._columns_to_types = { - k: exp.DataType.build(v, dialect=self.dialect) for k, v in value.items() - } - - @property - def time_columns(self) -> t.List[str]: - return [ - k - for k, v in self.columns_to_types.items() - if v.sql().lower().startswith("timestamp") - or v.sql().lower().startswith("date") - or k.lower() == "ds" - ] - - @property - def timestamp_columns(self) -> t.List[str]: - return [ - k - for k, v in self.columns_to_types.items() - if v.sql().lower().startswith("timestamp") - or (v.sql().lower() == "datetime" and self.dialect == "bigquery") - ] - - @property - def time_column(self) -> str: - return self.time_columns[0] - - @property - def time_formatter(self) -> t.Callable: - return lambda x, _: exp.Literal.string(to_ds(x)) - - @property - def partitioned_by(self) -> t.List[exp.Expression]: - return [parse_one(self.time_column)] - - @property - def dialect(self) -> str: - return self.engine_adapter.dialect - - @property - def current_catalog_type(self) -> str: - return self.engine_adapter.current_catalog_type - - @property - def supports_merge(self) -> bool: - if self.dialect == "spark": - engine_adapter: SparkEngineAdapter = self.engine_adapter - # Spark supports MERGE on the Iceberg catalog (which is configured under "testing" in these integration tests) - return engine_adapter.default_catalog == 
"testing" - - if self.dialect == "trino": - engine_adapter: TrinoEngineAdapter = self.engine_adapter - # Trino supports MERGE on Delta and Iceberg but not Hive - return engine_adapter.get_catalog_type(engine_adapter.default_catalog) != "hive" - - return True - - def add_test_suffix(self, value: str) -> str: - return f"{value}_{self.test_id}" - - def get_metadata_results(self, schema: t.Optional[str] = None) -> MetadataResults: - schema = schema if schema else self.schema(TEST_SCHEMA) - return MetadataResults.from_data_objects(self.engine_adapter.get_data_objects(schema)) - - def _init_engine_adapter(self) -> None: - schema = self.schema(TEST_SCHEMA) - self.engine_adapter.drop_schema(schema, ignore_if_not_exists=True, cascade=True) - self.engine_adapter.create_schema(schema) - - def _format_df(self, data: pd.DataFrame, to_datetime: bool = True) -> pd.DataFrame: - for timestamp_column in self.timestamp_columns: - if timestamp_column in data.columns: - value = data[timestamp_column] - if to_datetime: - value = pd.to_datetime(value) - data[timestamp_column] = value.astype("datetime64[ns]") - return data - - def init(self): - if self.test_type == "pyspark" and not hasattr(self.engine_adapter, "is_pyspark_df"): - pytest.skip(f"Engine adapter {self.engine_adapter} doesn't support pyspark") - self._init_engine_adapter() - - def input_data( - self, - data: pd.DataFrame, - columns_to_types: t.Optional[t.Dict[str, exp.DataType]] = None, - ) -> t.Union[Query, pd.DataFrame]: - columns_to_types = columns_to_types or self.columns_to_types - if self.test_type == "query": - return self.engine_adapter._values_to_sql( - list(data.itertuples(index=False, name=None)), - batch_start=0, - batch_end=sys.maxsize, - columns_to_types=columns_to_types, - ) - elif self.test_type == "pyspark": - return self.engine_adapter.spark.createDataFrame(data) # type: ignore - return self._format_df(data, to_datetime=self.dialect != "trino") - - def output_data(self, data: pd.DataFrame) -> pd.DataFrame: - 
return self._format_df(data) - - def table(self, table_name: str, schema: str = TEST_SCHEMA) -> exp.Table: - schema = self.add_test_suffix(schema) - return exp.to_table( - normalize_model_name( - ".".join([schema, table_name]), - default_catalog=self.engine_adapter.default_catalog, - dialect=self.dialect, - ) - ) - - def schema(self, schema_name: str, catalog_name: t.Optional[str] = None) -> str: - return exp.table_name( - normalize_model_name( - self.add_test_suffix( - ".".join( - p - for p in (catalog_name or self.engine_adapter.default_catalog, schema_name) - if p - ) - if "." not in schema_name - else schema_name - ), - default_catalog=None, - dialect=self.dialect, - ) - ) - - def get_current_data(self, table: exp.Table) -> pd.DataFrame: - df = self.engine_adapter.fetchdf(exp.select("*").from_(table), quote_identifiers=True) - if self.dialect == "snowflake" and "id" in df.columns: - df["id"] = df["id"].apply(lambda x: x if pd.isna(x) else int(x)) - return self._format_df(df) - - def compare_with_current(self, table: exp.Table, expected: pd.DataFrame) -> None: - compare_dataframes( - self.get_current_data(table), - self.output_data(expected), - check_dtype=False, - check_index_type=False, - ) - - def get_table_comment( - self, - schema_name: str, - table_name: str, - table_kind: str = "BASE TABLE", - snowflake_capitalize_ids: bool = True, - ) -> str: - if self.dialect in ["postgres", "redshift"]: - query = f""" - SELECT - pgc.relname, - pg_catalog.obj_description(pgc.oid, 'pg_class') - FROM pg_catalog.pg_class pgc - INNER JOIN pg_catalog.pg_namespace n - ON pgc.relnamespace = n.oid - WHERE - n.nspname = '{schema_name}' - AND pgc.relname = '{table_name}' - AND pgc.relkind = '{'v' if table_kind == "VIEW" else 'r'}' - ; - """ - elif self.dialect in ["mysql", "snowflake"]: - # Snowflake treats all identifiers as uppercase unless they are lowercase and quoted. - # They are lowercase and quoted in sushi but not in the inline tests. 
- if self.dialect == "snowflake" and snowflake_capitalize_ids: - schema_name = schema_name.upper() - table_name = table_name.upper() - - comment_field_name = { - "mysql": "table_comment", - "snowflake": "comment", - } - - query = f""" - SELECT - table_name, - {comment_field_name[self.dialect]} - FROM INFORMATION_SCHEMA.TABLES - WHERE - table_schema='{schema_name}' - AND table_name='{table_name}' - AND table_type='{table_kind}' - """ - elif self.dialect == "bigquery": - query = f""" - SELECT - table_name, - option_value - FROM `region-us.INFORMATION_SCHEMA.TABLE_OPTIONS` - WHERE - table_schema='{schema_name}' - AND table_name='{table_name}' - AND option_name = 'description' - """ - elif self.dialect in ["spark", "databricks"]: - query = f"DESCRIBE TABLE EXTENDED {schema_name}.{table_name}" - elif self.dialect == "trino": - query = f""" - SELECT - table_name, - comment - FROM system.metadata.table_comments - WHERE - schema_name = '{schema_name}' - AND table_name = '{table_name}' - """ - elif self.dialect == "duckdb": - kind = "table" if table_kind == "BASE TABLE" else "view" - query = f""" - SELECT - {kind}_name, - comment - FROM duckdb_{kind}s() - WHERE - schema_name = '{schema_name}' - AND {kind}_name = '{table_name}' - """ - - result = self.engine_adapter.fetchall(query) - - if result: - if self.dialect == "bigquery": - comment = result[0][1].replace('"', "").replace("\\n", "\n") - elif self.dialect in ["spark", "databricks"]: - comment = [x for x in result if x[0] == "Comment"] - comment = comment[0][1] if comment else None - else: - comment = result[0][1] - - return comment - - return None - - def get_column_comments( - self, - schema_name: str, - table_name: str, - table_kind: str = "BASE TABLE", - snowflake_capitalize_ids: bool = True, - ) -> t.Dict[str, str]: - comment_index = 1 - if self.dialect in ["postgres", "redshift"]: - query = f""" - SELECT - cols.column_name, - pg_catalog.col_description(pgc.oid, cols.ordinal_position::int) AS column_comment - FROM 
pg_catalog.pg_class pgc - INNER JOIN pg_catalog.pg_namespace n - ON - pgc.relnamespace = n.oid - INNER JOIN information_schema.columns cols - ON - pgc.relname = cols.table_name - AND n.nspname = cols.table_schema - WHERE - n.nspname = '{schema_name}' - AND pgc.relname = '{table_name}' - AND pgc.relkind = '{'v' if table_kind == "VIEW" else 'r'}' - ; - """ - elif self.dialect in ["mysql", "snowflake"]: - # Snowflake treats all identifiers as uppercase unless they are lowercase and quoted. - # They are lowercase and quoted in sushi but not in the inline tests. - if self.dialect == "snowflake" and snowflake_capitalize_ids: - schema_name = schema_name.upper() - table_name = table_name.upper() - - comment_field_name = { - "mysql": "column_comment", - "snowflake": "comment", - } - - query = f""" - SELECT - column_name, - {comment_field_name[self.dialect]} - FROM - information_schema.columns - WHERE - table_schema = '{schema_name}' - AND table_name = '{table_name}' - ; - """ - elif self.dialect == "bigquery": - query = f""" - SELECT - column_name, - description - FROM - `region-us.INFORMATION_SCHEMA.COLUMN_FIELD_PATHS` - WHERE - table_schema = '{schema_name}' - AND table_name = '{table_name}' - ; - """ - elif self.dialect in ["spark", "databricks"]: - query = f"DESCRIBE TABLE {schema_name}.{table_name}" - comment_index = 2 - elif self.dialect == "trino": - query = f"SHOW COLUMNS FROM {schema_name}.{table_name}" - comment_index = 3 - elif self.dialect == "duckdb": - query = f""" - SELECT - column_name, - comment - FROM duckdb_columns() - WHERE - schema_name = '{schema_name}' - AND table_name = '{table_name}' - """ - - result = self.engine_adapter.fetchall(query) - - comments = {} - if result: - if self.dialect in ["spark", "databricks"]: - result = list(set([x for x in result if not x[0].startswith("#")])) - - comments = { - x[0]: x[comment_index] - for x in result - if x[comment_index] is not None and x[comment_index].strip() != "" - } - - return comments - - def 
create_context( - self, config_mutator: t.Optional[t.Callable[[str, Config], None]] = None - ) -> Context: - private_sqlmesh_dir = pathlib.Path(pathlib.Path().home(), ".sqlmesh") - config = load_config_from_paths( - Config, - project_paths=[ - pathlib.Path(os.path.join(os.path.dirname(__file__), "config.yaml")), - private_sqlmesh_dir / "config.yml", - private_sqlmesh_dir / "config.yaml", - ], - ) - if config_mutator: - config_mutator(self.gateway, config) - self._context = Context(paths=".", config=config, gateway=self.gateway) - return self._context - - def cleanup(self, ctx: t.Optional[Context] = None): - schemas = [self.schema(TEST_SCHEMA)] - - ctx = ctx or self._context - if ctx and ctx.models: - for _, model in ctx.models.items(): - schemas.append(model.schema_name) - schemas.append(model.physical_schema) - - for schema_name in set(schemas): - self.engine_adapter.drop_schema( - schema_name=schema_name, ignore_if_not_exists=True, cascade=True - ) - - -class MetadataResults(PydanticModel): - tables: t.List[str] = [] - views: t.List[str] = [] - materialized_views: t.List[str] = [] - managed_tables: t.List[str] = [] - - @classmethod - def from_data_objects(cls, data_objects: t.List[DataObject]) -> MetadataResults: - tables = [] - views = [] - materialized_views = [] - managed_tables = [] - for obj in data_objects: - if obj.type.is_table: - tables.append(obj.name) - elif obj.type.is_view: - views.append(obj.name) - elif obj.type.is_materialized_view: - materialized_views.append(obj.name) - elif obj.type.is_managed_table: - managed_tables.append(obj.name) - else: - raise ValueError(f"Unexpected object type: {obj.type}") - return MetadataResults( - tables=tables, - views=views, - materialized_views=materialized_views, - managed_tables=managed_tables, - ) - - @property - def non_temp_tables(self) -> t.List[str]: - return [x for x in self.tables if not x.startswith("__temp") and not x.startswith("temp")] - - -class PlanResults(PydanticModel): - plan: Plan - ctx: 
TestContext - schema_metadata: MetadataResults - internal_schema_metadata: MetadataResults - - @classmethod - def create(cls, plan: Plan, ctx: TestContext, schema_name: str): - schema_metadata = ctx.get_metadata_results(schema_name) - internal_schema_metadata = ctx.get_metadata_results(f"sqlmesh__{schema_name}") - return PlanResults( - plan=plan, - ctx=ctx, - schema_metadata=schema_metadata, - internal_schema_metadata=internal_schema_metadata, - ) - - def snapshot_for(self, model: Model) -> Snapshot: - return next((s for s in list(self.plan.snapshots.values()) if s.name == model.fqn)) - - def modified_snapshot_for(self, model: Model) -> Snapshot: - return next((s for s in list(self.plan.modified_snapshots.values()) if s.name == model.fqn)) - - def table_name_for( - self, snapshot_or_model: Snapshot | Model, is_deployable: bool = True - ) -> str: - snapshot = ( - snapshot_or_model - if isinstance(snapshot_or_model, Snapshot) - else self.snapshot_for(snapshot_or_model) - ) - table_name = snapshot.table_name(is_deployable) - return exp.to_table(table_name).this.sql(dialect=self.ctx.dialect) - - def dev_table_name_for(self, snapshot: Snapshot) -> str: - return self.table_name_for(snapshot, is_deployable=False) - - -@pytest.fixture(params=["df", "query", "pyspark"]) -def test_type(request): - return request.param - - -@pytest.fixture(scope="session") -def config() -> Config: - return load_config_from_paths( - Config, - project_paths=[ - pathlib.Path("examples/wursthall/config.yaml"), - pathlib.Path(os.path.join(os.path.dirname(__file__), "config.yaml")), - ], - personal_paths=[pathlib.Path("~/.sqlmesh/config.yaml").expanduser()], - ) - - -@pytest.fixture( - params=[ - pytest.param( - "duckdb", - marks=[ - pytest.mark.duckdb, - pytest.mark.engine, - pytest.mark.slow, - pytest.mark.xdist_group("engine_integration_duckdb"), - ], - ), - pytest.param( - "postgres", - marks=[ - pytest.mark.docker, - pytest.mark.engine, - pytest.mark.postgres, - 
pytest.mark.xdist_group("engine_integration_postgres"), - ], - ), - pytest.param( - "mysql", - marks=[ - pytest.mark.docker, - pytest.mark.engine, - pytest.mark.mysql, - pytest.mark.xdist_group("engine_integration_mysql"), - ], - ), - pytest.param( - "mssql", - marks=[ - pytest.mark.docker, - pytest.mark.engine, - pytest.mark.mssql, - pytest.mark.xdist_group("engine_integration_mssql"), - ], - ), - pytest.param( - "trino", - marks=[ - pytest.mark.docker, - pytest.mark.engine, - pytest.mark.trino, - pytest.mark.xdist_group("engine_integration_trino"), - ], - ), - pytest.param( - "trino_iceberg", - marks=[ - pytest.mark.docker, - pytest.mark.engine, - pytest.mark.trino_iceberg, - pytest.mark.xdist_group("engine_integration_trino_iceberg"), - ], - ), - pytest.param( - "trino_delta", - marks=[ - pytest.mark.docker, - pytest.mark.engine, - pytest.mark.trino_delta, - pytest.mark.xdist_group("engine_integration_trino_delta"), - ], - ), - pytest.param( - "spark", - marks=[ - pytest.mark.docker, - pytest.mark.engine, - pytest.mark.spark, - pytest.mark.xdist_group("engine_integration_spark"), - ], - ), - pytest.param( - "bigquery", - marks=[ - pytest.mark.bigquery, - pytest.mark.engine, - pytest.mark.remote, - pytest.mark.xdist_group("engine_integration_bigquery"), - ], - ), - pytest.param( - "databricks", - marks=[ - pytest.mark.databricks, - pytest.mark.engine, - pytest.mark.remote, - pytest.mark.xdist_group("engine_integration_databricks"), - ], - ), - # TODO: add motherduck tests once they support DuckDB>=0.10.0 - pytest.param( - "redshift", - marks=[ - pytest.mark.engine, - pytest.mark.remote, - pytest.mark.redshift, - pytest.mark.xdist_group("engine_integration_redshift"), - ], - ), - pytest.param( - "snowflake", - marks=[ - pytest.mark.engine, - pytest.mark.remote, - pytest.mark.snowflake, - pytest.mark.xdist_group("engine_integration_snowflake"), - ], - ), - ] -) -def mark_gateway(request) -> t.Tuple[str, str]: - return request.param, f"inttest_{request.param}" - - 
-@pytest.fixture -def engine_adapter(mark_gateway: t.Tuple[str, str], config) -> EngineAdapter: - mark, gateway = mark_gateway - if gateway not in config.gateways: - # TODO: Once everything is fully setup we want to error if a gateway is not configured that we expect - pytest.skip(f"Gateway {gateway} not configured") - connection_config = config.gateways[gateway].connection - engine_adapter = connection_config.create_engine_adapter() - # Trino: If we batch up the requests then when running locally we get a table not found error after creating the - # table and then immediately after trying to insert rows into it. There seems to be a delay between when the - # metastore is made aware of the table and when it responds that it exists. I'm hoping this is not an issue - # in practice on production machines. - if not mark.startswith("trino"): - engine_adapter.DEFAULT_BATCH_SIZE = 1 - # Clear our any local db files that may have been left over from previous runs - if mark == "duckdb": - for raw_path in (connection_config.catalogs or {}).values(): - pathlib.Path(raw_path).unlink(missing_ok=True) - return engine_adapter - - -@pytest.fixture -def default_columns_to_types(): - return {"id": exp.DataType.build("int"), "ds": exp.DataType.build("string")} - - -@pytest.fixture -def ctx(engine_adapter, test_type, mark_gateway): - _, gateway = mark_gateway - return TestContext(test_type, engine_adapter, gateway) - - -@pytest.fixture(autouse=True) -def cleanup(ctx: TestContext): - yield # run test - - if ctx: - ctx.cleanup() - - -def test_catalog_operations(ctx: TestContext): - if ( - ctx.engine_adapter.CATALOG_SUPPORT.is_unsupported - or ctx.engine_adapter.CATALOG_SUPPORT.is_single_catalog_only - ): - pytest.skip( - f"Engine adapter {ctx.engine_adapter.dialect} doesn't support catalog operations" - ) - if ctx.test_type != "query": - pytest.skip("Catalog operation tests only need to run once so we skip anything not query") - catalog_name = "testing" - if ctx.dialect == "databricks": 
- catalog_name = "catalogtest" - ctx.engine_adapter.execute(f"CREATE CATALOG IF NOT EXISTS {catalog_name}") - elif ctx.dialect == "tsql": - ctx.engine_adapter.cursor.connection.autocommit(True) - try: - ctx.engine_adapter.cursor.execute(f"CREATE DATABASE {catalog_name}") - except Exception: - pass - ctx.engine_adapter.cursor.connection.autocommit(False) - elif ctx.dialect == "snowflake": - ctx.engine_adapter.execute(f'CREATE DATABASE IF NOT EXISTS "{catalog_name}"') - elif ctx.dialect == "duckdb": - try: - # Only applies to MotherDuck - ctx.engine_adapter.execute(f'CREATE DATABASE IF NOT EXISTS "{catalog_name}"') - except Exception: - pass - current_catalog = ctx.engine_adapter.get_current_catalog().lower() - ctx.engine_adapter.set_current_catalog(catalog_name) - assert ctx.engine_adapter.get_current_catalog().lower() == catalog_name - ctx.engine_adapter.set_current_catalog(current_catalog) - assert ctx.engine_adapter.get_current_catalog().lower() == current_catalog - - -def test_drop_schema_catalog(ctx: TestContext, caplog): - def drop_schema_and_validate(schema_name: str): - ctx.engine_adapter.drop_schema(schema_name, cascade=True) - results = ctx.get_metadata_results(schema_name) - assert ( - len(results.tables) - == len(results.views) - == len(results.materialized_views) - == len(results.non_temp_tables) - == 0 - ) - - def create_objects_and_validate(schema_name: str): - ctx.engine_adapter.create_schema(schema_name) - ctx.engine_adapter.create_view(f"{schema_name}.test_view", parse_one("SELECT 1 as col")) - ctx.engine_adapter.create_table( - f"{schema_name}.test_table", {"col": exp.DataType.build("int")} - ) - ctx.engine_adapter.create_table( - f"{schema_name}.replace_table", {"col": exp.DataType.build("int")} - ) - ctx.engine_adapter.replace_query( - f"{schema_name}.replace_table", - parse_one("SELECT 1 as col"), - {"col": exp.DataType.build("int")}, - ) - results = ctx.get_metadata_results(schema_name) - assert len(results.tables) == 2 - assert 
len(results.views) == 1 - assert len(results.materialized_views) == 0 - assert len(results.non_temp_tables) == 2 - - if ctx.engine_adapter.CATALOG_SUPPORT.is_unsupported: - pytest.skip( - f"Engine adapter {ctx.engine_adapter.dialect} doesn't support catalog operations" - ) - if ctx.dialect == "spark": - pytest.skip( - "Currently local spark is configured to have iceberg be the testing catalog and drop cascade doesn't work on iceberg. Skipping until we have time to fix." - ) - if ctx.test_type != "query": - pytest.skip("Drop Schema Catalog tests only need to run once so we skip anything not query") - catalog_name = "testing" - if ctx.dialect == "databricks": - catalog_name = "catalogtest" - ctx.engine_adapter.execute(f"CREATE CATALOG IF NOT EXISTS {catalog_name}") - elif ctx.dialect == "tsql": - ctx.engine_adapter.cursor.connection.autocommit(True) - try: - ctx.engine_adapter.cursor.execute(f"CREATE DATABASE {catalog_name}") - except Exception: - pass - ctx.engine_adapter.cursor.connection.autocommit(False) - elif ctx.dialect == "snowflake": - ctx.engine_adapter.execute(f'CREATE DATABASE IF NOT EXISTS "{catalog_name}"') - elif ctx.dialect == "duckdb": - try: - # Only applies to MotherDuck - ctx.engine_adapter.execute(f'CREATE DATABASE IF NOT EXISTS "{catalog_name}"') - except Exception: - pass - elif ctx.dialect == "bigquery": - catalog_name = "tobiko-test" - - schema = ctx.schema("drop_schema_catalog_test", catalog_name) - if ctx.engine_adapter.CATALOG_SUPPORT.is_single_catalog_only: - drop_schema_and_validate(schema) - assert "requires that all catalog operations be against a single catalog" in caplog.text - return - drop_schema_and_validate(schema) - create_objects_and_validate(schema) - - -def test_temp_table(ctx: TestContext): - ctx.init() - input_data = pd.DataFrame( - [ - {"id": 1, "ds": "2022-01-01"}, - {"id": 2, "ds": "2022-01-02"}, - {"id": 3, "ds": "2022-01-03"}, - ] - ) - table = ctx.table("example") - - with 
ctx.engine_adapter.temp_table(ctx.input_data(input_data), table.sql()) as table_name: - results = ctx.get_metadata_results() - assert len(results.views) == 0 - assert len(results.tables) == 1 - assert len(results.non_temp_tables) == 0 - assert len(results.materialized_views) == 0 - ctx.compare_with_current(table_name, input_data) - - results = ctx.get_metadata_results() - assert len(results.views) == len(results.tables) == len(results.non_temp_tables) == 0 - - -def test_create_table(ctx: TestContext): - table = ctx.table("test_table") - ctx.init() - ctx.engine_adapter.create_table( - table, - {"id": exp.DataType.build("int")}, - table_description="test table description", - column_descriptions={"id": "test id column description"}, - ) - results = ctx.get_metadata_results() - assert len(results.tables) == 1 - assert len(results.views) == 0 - assert len(results.materialized_views) == 0 - assert results.tables[0] == table.name - - if ctx.engine_adapter.COMMENT_CREATION_TABLE.is_supported: - table_description = ctx.get_table_comment(table.db, "test_table") - column_comments = ctx.get_column_comments(table.db, "test_table") - assert table_description == "test table description" - assert column_comments == {"id": "test id column description"} - - -def test_ctas(ctx: TestContext): - ctx.init() - table = ctx.table("test_table") - input_data = pd.DataFrame( - [ - {"id": 1, "ds": "2022-01-01"}, - {"id": 2, "ds": "2022-01-02"}, - {"id": 3, "ds": "2022-01-03"}, - ] - ) - ctx.engine_adapter.ctas( - table, - ctx.input_data(input_data), - table_description="test table description", - column_descriptions={"id": "test id column description"}, - ) - results = ctx.get_metadata_results() - assert len(results.views) == 0 - assert len(results.materialized_views) == 0 - assert len(results.tables) == len(results.non_temp_tables) == 1 - assert results.non_temp_tables[0] == table.name - ctx.compare_with_current(table, input_data) - - if 
ctx.engine_adapter.COMMENT_CREATION_TABLE.is_supported: - table_description = ctx.get_table_comment(table.db, "test_table") - column_comments = ctx.get_column_comments(table.db, "test_table") - - # Trino on Hive COMMENT permissions are separate from standard SQL object permissions. - # Trino has a bug where CREATE SQL permissions are not passed to COMMENT permissions, - # which generates permissions errors when COMMENT commands are issued. - # - # The errors are thrown for both table and comments, but apparently the - # table comments are actually registered with the engine. Column comments are not. - assert table_description == "test table description" - assert column_comments == ( - {} - if (ctx.dialect == "trino" and ctx.current_catalog_type == "hive") - else {"id": "test id column description"} - ) - - -def test_create_view(ctx: TestContext): - input_data = pd.DataFrame( - [ - {"id": 1, "ds": "2022-01-01"}, - {"id": 2, "ds": "2022-01-02"}, - {"id": 3, "ds": "2022-01-03"}, - ] - ) - view = ctx.table("test_view") - ctx.init() - ctx.engine_adapter.create_view( - view, - ctx.input_data(input_data), - table_description="test view description", - column_descriptions={"id": "test id column description"}, - ) - results = ctx.get_metadata_results() - assert len(results.tables) == 0 - assert len(results.views) == 1 - assert len(results.materialized_views) == 0 - assert results.views[0] == view.name - ctx.compare_with_current(view, input_data) - - if ctx.engine_adapter.COMMENT_CREATION_VIEW.is_supported: - table_description = ctx.get_table_comment(view.db, "test_view", table_kind="VIEW") - column_comments = ctx.get_column_comments(view.db, "test_view", table_kind="VIEW") - - # Trino: - # Trino on Hive COMMENT permissions are separate from standard SQL object permissions. - # Trino has a bug where CREATE SQL permissions are not passed to COMMENT permissions, - # which generates permissions errors when COMMENT commands are issued. 
- # - # The errors are thrown for both table and comments, but apparently the - # table comments are actually registered with the engine. Column comments are not. - # - # Query: - # In the query test, columns_to_types are not available when the view is created. Since we - # can only register column comments in the CREATE VIEW schema expression with columns_to_types - # available, the column comments must be registered via post-creation commands. Some engines, - # such as Spark and Snowflake, do not support view column comments via post-creation commands. - assert table_description == "test view description" - assert column_comments == ( - {} - if (ctx.dialect == "trino" and ctx.current_catalog_type == "hive") - or ( - ctx.test_type == "query" - and not ctx.engine_adapter.COMMENT_CREATION_VIEW.supports_column_comment_commands - ) - else {"id": "test id column description"} - ) - - -def test_materialized_view(ctx: TestContext): - if not ctx.engine_adapter.SUPPORTS_MATERIALIZED_VIEWS: - pytest.skip(f"Engine adapter {ctx.engine_adapter} doesn't support materialized views") - if ctx.engine_adapter.dialect == "databricks": - pytest.skip( - "Databricks requires DBSQL Serverless or Pro warehouse to test materialized views which we do not have setup" - ) - if ctx.engine_adapter.dialect == "snowflake": - pytest.skip("Snowflake requires enterprise edition which we do not have setup") - input_data = pd.DataFrame( - [ - {"id": 1, "ds": "2022-01-01"}, - {"id": 2, "ds": "2022-01-02"}, - {"id": 3, "ds": "2022-01-03"}, - ] - ) - ctx.init() - source_table = ctx.table("source_table") - ctx.engine_adapter.ctas(source_table, ctx.input_data(input_data), ctx.columns_to_types) - view = ctx.table("test_view") - view_query = exp.select(*ctx.columns_to_types).from_(source_table) - ctx.engine_adapter.create_view(view, view_query, materialized=True) - results = ctx.get_metadata_results() - # Redshift considers the underlying dataset supporting materialized views as a table therefore we get 2 - 
# tables in the result - if ctx.engine_adapter.dialect == "redshift": - assert len(results.tables) == 2 - else: - assert len(results.tables) == 1 - assert len(results.views) == 0 - assert len(results.materialized_views) == 1 - assert results.materialized_views[0] == view.name - ctx.compare_with_current(view, input_data) - # Make sure that dropping a materialized view also works - ctx.engine_adapter.drop_view(view, materialized=True) - results = ctx.get_metadata_results() - assert len(results.materialized_views) == 0 - - -def test_drop_schema(ctx: TestContext): - if ctx.test_type != "query": - pytest.skip("Drop Schema tests only need to run once so we skip anything not query") - ctx.columns_to_types = {"one": "int"} - schema = ctx.schema(TEST_SCHEMA) - ctx.engine_adapter.drop_schema(schema, cascade=True) - results = ctx.get_metadata_results() - assert len(results.tables) == 0 - assert len(results.views) == 0 - - ctx.engine_adapter.create_schema(schema) - view = ctx.table("test_view") - view_query = exp.Select().select(exp.Literal.number(1).as_("one")) - ctx.engine_adapter.create_view(view, view_query, ctx.columns_to_types) - results = ctx.get_metadata_results() - assert len(results.tables) == 0 - assert len(results.views) == 1 - - ctx.engine_adapter.drop_schema(schema, cascade=True) - results = ctx.get_metadata_results() - assert len(results.tables) == 0 - assert len(results.views) == 0 - - -def test_nan_roundtrip(ctx: TestContext): - if ctx.test_type != "df": - pytest.skip("NaN roundtrip test only relevant for dataframes.") - ctx.engine_adapter.DEFAULT_BATCH_SIZE = sys.maxsize - ctx.init() - table = ctx.table("test_table") - # Initial Load - input_data = pd.DataFrame( - [ - {"id": 1, "ds": "2022-01-01"}, - {"id": 2, "ds": "2022-01-02"}, - {"id": np.nan, "ds": np.nan}, - ] - ) - ctx.engine_adapter.create_table(table, ctx.columns_to_types) - ctx.engine_adapter.replace_query( - table, - ctx.input_data(input_data), - columns_to_types=ctx.columns_to_types, - ) - results 
= ctx.get_metadata_results() - assert not results.views - assert not results.materialized_views - assert len(results.tables) == len(results.non_temp_tables) == 1 - assert results.non_temp_tables[0] == table.name - ctx.compare_with_current(table, input_data) - - -def test_replace_query(ctx: TestContext): - ctx.engine_adapter.DEFAULT_BATCH_SIZE = sys.maxsize - ctx.init() - table = ctx.table("test_table") - # Initial Load - input_data = pd.DataFrame( - [ - {"id": 1, "ds": "2022-01-01"}, - {"id": 2, "ds": "2022-01-02"}, - {"id": 3, "ds": "2022-01-03"}, - ] - ) - ctx.engine_adapter.create_table(table, ctx.columns_to_types) - ctx.engine_adapter.replace_query( - table, - ctx.input_data(input_data), - # Spark based engines do a create table -> insert overwrite instead of replace. If columns to types aren't - # provided then it checks the table itself for types. This is fine within SQLMesh since we always know the tables - # exist prior to evaluation but when running these tests that isn't the case. 
As a result we just pass in - # columns_to_types for these two engines so we can still test inference on the other ones - columns_to_types=ctx.columns_to_types if ctx.dialect in ["spark", "databricks"] else None, - ) - results = ctx.get_metadata_results() - assert len(results.views) == 0 - assert len(results.materialized_views) == 0 - assert len(results.tables) == len(results.non_temp_tables) == 1 - assert results.non_temp_tables[0] == table.name - ctx.compare_with_current(table, input_data) - - # Replace that we only need to run once - if type == "df": - replace_data = pd.DataFrame( - [ - {"id": 4, "ds": "2022-01-04"}, - {"id": 5, "ds": "2022-01-05"}, - {"id": 6, "ds": "2022-01-06"}, - ] - ) - ctx.engine_adapter.replace_query( - table, - ctx.input_data(replace_data), - columns_to_types=( - ctx.columns_to_types if ctx.dialect in ["spark", "databricks"] else None - ), - ) - results = ctx.get_metadata_results() - assert len(results.views) == 0 - assert len(results.materialized_views) == 0 - assert len(results.tables) == len(results.non_temp_tables) == 1 - assert results.non_temp_tables[0] == table.name - ctx.compare_with_current(table, replace_data) - - -def test_replace_query_batched(ctx: TestContext): - ctx.engine_adapter.DEFAULT_BATCH_SIZE = 1 - ctx.init() - table = ctx.table("test_table") - # Initial Load - input_data = pd.DataFrame( - [ - {"id": 1, "ds": "2022-01-01"}, - {"id": 2, "ds": "2022-01-02"}, - {"id": 3, "ds": "2022-01-03"}, - ] - ) - ctx.engine_adapter.create_table(table, ctx.columns_to_types) - ctx.engine_adapter.replace_query( - table, - ctx.input_data(input_data), - # Spark based engines do a create table -> insert overwrite instead of replace. If columns to types aren't - # provided then it checks the table itself for types. This is fine within SQLMesh since we always know the tables - # exist prior to evaluation but when running these tests that isn't the case. 
As a result we just pass in - # columns_to_types for these two engines so we can still test inference on the other ones - columns_to_types=ctx.columns_to_types if ctx.dialect in ["spark", "databricks"] else None, - ) - results = ctx.get_metadata_results() - assert len(results.views) == 0 - assert len(results.materialized_views) == 0 - assert len(results.tables) == len(results.non_temp_tables) == 1 - assert results.non_temp_tables[0] == table.name - ctx.compare_with_current(table, input_data) - - # Replace that we only need to run once - if type == "df": - replace_data = pd.DataFrame( - [ - {"id": 4, "ds": "2022-01-04"}, - {"id": 5, "ds": "2022-01-05"}, - {"id": 6, "ds": "2022-01-06"}, - ] - ) - ctx.engine_adapter.replace_query( - table, - ctx.input_data(replace_data), - columns_to_types=( - ctx.columns_to_types if ctx.dialect in ["spark", "databricks"] else None - ), - ) - results = ctx.get_metadata_results() - assert len(results.views) == 0 - assert len(results.materialized_views) == 0 - assert len(results.tables) == len(results.non_temp_tables) == 1 - assert results.non_temp_tables[0] == table.name - ctx.compare_with_current(table, replace_data) - - -def test_insert_append(ctx: TestContext): - ctx.init() - table = ctx.table("test_table") - ctx.engine_adapter.create_table(table, ctx.columns_to_types) - # Initial Load - input_data = pd.DataFrame( - [ - {"id": 1, "ds": "2022-01-01"}, - {"id": 2, "ds": "2022-01-02"}, - {"id": 3, "ds": "2022-01-03"}, - ] - ) - ctx.engine_adapter.insert_append(table, ctx.input_data(input_data)) - results = ctx.get_metadata_results() - assert len(results.views) == 0 - assert len(results.materialized_views) == 0 - assert len(results.tables) == len(results.non_temp_tables) == 1 - assert results.non_temp_tables[0] == table.name - ctx.compare_with_current(table, input_data) - - # Replace that we only need to run once - if type == "df": - append_data = pd.DataFrame( - [ - {"id": 4, "ds": "2022-01-04"}, - {"id": 5, "ds": "2022-01-05"}, - 
{"id": 6, "ds": "2022-01-06"}, - ] - ) - ctx.engine_adapter.insert_append(table, ctx.input_data(append_data)) - results = ctx.get_metadata_results() - assert len(results.views) == 0 - assert len(results.materialized_views) == 0 - assert len(results.tables) in [1, 2, 3] - assert len(results.non_temp_tables) == 1 - assert results.non_temp_tables[0] == table.name - ctx.compare_with_current(table, pd.concat([input_data, append_data])) - - -def test_insert_overwrite_by_time_partition(ctx: TestContext): - ds_type = "string" - if ctx.dialect == "bigquery": - ds_type = "datetime" - if ctx.dialect == "tsql": - ds_type = "varchar(max)" - - ctx.columns_to_types = {"id": "int", "ds": ds_type} - ctx.init() - table = ctx.table("test_table") - if ctx.dialect == "bigquery": - partitioned_by = ["DATE(ds)"] - else: - partitioned_by = ctx.partitioned_by # type: ignore - ctx.engine_adapter.create_table( - table, - ctx.columns_to_types, - partitioned_by=partitioned_by, - partition_interval_unit="DAY", - ) - input_data = pd.DataFrame( - [ - {"id": 1, ctx.time_column: "2022-01-01"}, - {"id": 2, ctx.time_column: "2022-01-02"}, - {"id": 3, ctx.time_column: "2022-01-03"}, - ] - ) - ctx.engine_adapter.insert_overwrite_by_time_partition( - table, - ctx.input_data(input_data), - start="2022-01-02", - end="2022-01-03", - time_formatter=ctx.time_formatter, - time_column=ctx.time_column, - columns_to_types=ctx.columns_to_types, - ) - results = ctx.get_metadata_results() - assert len(results.views) == 0 - assert len(results.materialized_views) == 0 - assert len(results.tables) == len(results.non_temp_tables) == 1 - assert len(results.non_temp_tables) == 1 - assert results.non_temp_tables[0] == table.name - ctx.compare_with_current(table, input_data.iloc[1:]) - - if test_type == "df": - overwrite_data = pd.DataFrame( - [ - {"id": 10, ctx.time_column: "2022-01-03"}, - {"id": 4, ctx.time_column: "2022-01-04"}, - {"id": 5, ctx.time_column: "2022-01-05"}, - ] - ) - 
ctx.engine_adapter.insert_overwrite_by_time_partition( - table, - ctx.input_data(overwrite_data), - start="2022-01-03", - end="2022-01-05", - time_formatter=ctx.time_formatter, - time_column=ctx.time_column, - columns_to_types=ctx.columns_to_types, - ) - results = ctx.get_metadata_results() - assert len(results.views) == 0 - assert len(results.materialized_views) == 0 - assert len(results.tables) == len(results.non_temp_tables) == 1 - assert results.non_temp_tables[0] == table.name - ctx.compare_with_current( - table, - pd.DataFrame( - [ - {"id": 2, ctx.time_column: "2022-01-02"}, - {"id": 10, ctx.time_column: "2022-01-03"}, - {"id": 4, ctx.time_column: "2022-01-04"}, - {"id": 5, ctx.time_column: "2022-01-05"}, - ] - ), - ) - - -def test_merge(ctx: TestContext): - if not ctx.supports_merge: - pytest.skip(f"{ctx.dialect} doesn't support merge") - - ctx.init() - table = ctx.table("test_table") - ctx.engine_adapter.create_table(table, ctx.columns_to_types) - input_data = pd.DataFrame( - [ - {"id": 1, "ds": "2022-01-01"}, - {"id": 2, "ds": "2022-01-02"}, - {"id": 3, "ds": "2022-01-03"}, - ] - ) - ctx.engine_adapter.merge( - table, - ctx.input_data(input_data), - columns_to_types=None, - unique_key=[exp.to_identifier("id")], - ) - results = ctx.get_metadata_results() - assert len(results.views) == 0 - assert len(results.materialized_views) == 0 - assert len(results.tables) == len(results.non_temp_tables) == 1 - assert len(results.non_temp_tables) == 1 - assert results.non_temp_tables[0] == table.name - ctx.compare_with_current(table, input_data) - - if test_type == "df": - merge_data = pd.DataFrame( - [ - {"id": 2, "ds": "2022-01-10"}, - {"id": 4, "ds": "2022-01-04"}, - {"id": 5, "ds": "2022-01-05"}, - ] - ) - ctx.engine_adapter.merge( - table, - ctx.input_data(merge_data), - columns_to_types=None, - unique_key=[exp.to_identifier("id")], - ) - results = ctx.get_metadata_results() - assert len(results.views) == 0 - assert len(results.materialized_views) == 0 - assert 
len(results.tables) == len(results.non_temp_tables) == 1 - assert results.non_temp_tables[0] == table.name - ctx.compare_with_current( - table, - pd.DataFrame( - [ - {"id": 1, "ds": "2022-01-01"}, - {"id": 2, "ds": "2022-01-10"}, - {"id": 3, "ds": "2022-01-03"}, - {"id": 4, "ds": "2022-01-04"}, - {"id": 5, "ds": "2022-01-05"}, - ] - ), - ) - - -def test_scd_type_2_by_time(ctx: TestContext): - time_type = exp.DataType.build("timestamp") - - ctx.columns_to_types = { - "id": "int", - "name": "string", - "updated_at": time_type, - "valid_from": time_type, - "valid_to": time_type, - } - ctx.init() - table = ctx.table("test_table") - input_schema = { - k: v for k, v in ctx.columns_to_types.items() if k not in ("valid_from", "valid_to") - } - ctx.engine_adapter.create_table(table, ctx.columns_to_types) - input_data = pd.DataFrame( - [ - {"id": 1, "name": "a", "updated_at": "2022-01-01 00:00:00"}, - {"id": 2, "name": "b", "updated_at": "2022-01-02 00:00:00"}, - {"id": 3, "name": "c", "updated_at": "2022-01-03 00:00:00"}, - ] - ) - ctx.engine_adapter.scd_type_2_by_time( - table, - ctx.input_data(input_data, input_schema), - unique_key=[parse_one("COALESCE(id, -1)")], - valid_from_col=exp.column("valid_from", quoted=True), - valid_to_col=exp.column("valid_to", quoted=True), - updated_at_col=exp.column("updated_at", quoted=True), - execution_time="2023-01-01", - updated_at_as_valid_from=False, - columns_to_types=input_schema, - ) - results = ctx.get_metadata_results() - assert len(results.views) == 0 - assert len(results.materialized_views) == 0 - assert len(results.tables) == len(results.non_temp_tables) == 1 - assert len(results.non_temp_tables) == 1 - assert results.non_temp_tables[0] == table.name - ctx.compare_with_current( - table, - pd.DataFrame( - [ - { - "id": 1, - "name": "a", - "updated_at": "2022-01-01 00:00:00", - "valid_from": "1970-01-01 00:00:00", - "valid_to": pd.NaT, - }, - { - "id": 2, - "name": "b", - "updated_at": "2022-01-02 00:00:00", - "valid_from": 
"1970-01-01 00:00:00", - "valid_to": pd.NaT, - }, - { - "id": 3, - "name": "c", - "updated_at": "2022-01-03 00:00:00", - "valid_from": "1970-01-01 00:00:00", - "valid_to": pd.NaT, - }, - ] - ), - ) - - if ctx.test_type == "query": - return - current_data = pd.DataFrame( - [ - # Change `a` to `x` - {"id": 1, "name": "x", "updated_at": "2022-01-04 00:00:00"}, - # Delete - # {"id": 2, "name": "b", "updated_at": "2022-01-02 00:00:00"}, - # No change - {"id": 3, "name": "c", "updated_at": "2022-01-03 00:00:00"}, - # Add - {"id": 4, "name": "d", "updated_at": "2022-01-04 00:00:00"}, - ] - ) - ctx.engine_adapter.scd_type_2_by_time( - table, - ctx.input_data(current_data, input_schema), - unique_key=[exp.to_column("id")], - valid_from_col=exp.column("valid_from", quoted=True), - valid_to_col=exp.column("valid_to", quoted=True), - updated_at_col=exp.column("updated_at", quoted=True), - execution_time="2023-01-05", - updated_at_as_valid_from=False, - columns_to_types=input_schema, - ) - results = ctx.get_metadata_results() - assert len(results.views) == 0 - assert len(results.materialized_views) == 0 - assert len(results.tables) == len(results.non_temp_tables) == 1 - assert results.non_temp_tables[0] == table.name - ctx.compare_with_current( - table, - pd.DataFrame( - [ - { - "id": 1, - "name": "a", - "updated_at": "2022-01-01 00:00:00", - "valid_from": "1970-01-01 00:00:00", - "valid_to": "2022-01-04 00:00:00", - }, - { - "id": 1, - "name": "x", - "updated_at": "2022-01-04 00:00:00", - "valid_from": "2022-01-04 00:00:00", - "valid_to": pd.NaT, - }, - { - "id": 2, - "name": "b", - "updated_at": "2022-01-02 00:00:00", - "valid_from": "1970-01-01 00:00:00", - "valid_to": "2023-01-05 00:00:00", - }, - { - "id": 3, - "name": "c", - "updated_at": "2022-01-03 00:00:00", - "valid_from": "1970-01-01 00:00:00", - "valid_to": pd.NaT, - }, - { - "id": 4, - "name": "d", - "updated_at": "2022-01-04 00:00:00", - "valid_from": "1970-01-01 00:00:00", - "valid_to": pd.NaT, - }, - ] - ), - ) 
- - -def test_scd_type_2_by_column(ctx: TestContext): - time_type = exp.DataType.build("timestamp") - - ctx.columns_to_types = { - "id": "int", - "name": "string", - "status": "string", - "valid_from": time_type, - "valid_to": time_type, - } - ctx.init() - table = ctx.table("test_table") - input_schema = { - k: v for k, v in ctx.columns_to_types.items() if k not in ("valid_from", "valid_to") - } - ctx.engine_adapter.create_table(table, ctx.columns_to_types) - input_data = pd.DataFrame( - [ - {"id": 1, "name": "a", "status": "active"}, - {"id": 2, "name": "b", "status": "inactive"}, - {"id": 3, "name": "c", "status": "active"}, - {"id": 4, "name": "d", "status": "active"}, - ] - ) - ctx.engine_adapter.scd_type_2_by_column( - table, - ctx.input_data(input_data, input_schema), - unique_key=[exp.to_column("id")], - check_columns=[exp.to_column("name"), exp.to_column("status")], - valid_from_col=exp.column("valid_from", quoted=True), - valid_to_col=exp.column("valid_to", quoted=True), - execution_time="2023-01-01", - execution_time_as_valid_from=False, - columns_to_types=input_schema, - ) - results = ctx.get_metadata_results() - assert len(results.views) == 0 - assert len(results.materialized_views) == 0 - assert len(results.tables) == len(results.non_temp_tables) == 1 - assert len(results.non_temp_tables) == 1 - assert results.non_temp_tables[0] == table.name - ctx.compare_with_current( - table, - pd.DataFrame( - [ - { - "id": 1, - "name": "a", - "status": "active", - "valid_from": "1970-01-01 00:00:00", - "valid_to": pd.NaT, - }, - { - "id": 2, - "name": "b", - "status": "inactive", - "valid_from": "1970-01-01 00:00:00", - "valid_to": pd.NaT, - }, - { - "id": 3, - "name": "c", - "status": "active", - "valid_from": "1970-01-01 00:00:00", - "valid_to": pd.NaT, - }, - { - "id": 4, - "name": "d", - "status": "active", - "valid_from": "1970-01-01 00:00:00", - "valid_to": pd.NaT, - }, - ] - ), - ) - - if ctx.test_type == "query": - return - current_data = pd.DataFrame( - [ 
- # Change `a` to `x` - {"id": 1, "name": "x", "status": "active"}, - # Delete - # {"id": 2, "name": "b", status: "inactive"}, - # No change - {"id": 3, "name": "c", "status": "active"}, - # Change status to inactive - {"id": 4, "name": "d", "status": "inactive"}, - # Add - {"id": 5, "name": "e", "status": "inactive"}, - ] - ) - ctx.engine_adapter.scd_type_2_by_column( - table, - ctx.input_data(current_data, input_schema), - unique_key=[exp.to_column("id")], - check_columns=[exp.to_column("name"), exp.to_column("status")], - valid_from_col=exp.column("valid_from", quoted=True), - valid_to_col=exp.column("valid_to", quoted=True), - execution_time="2023-01-05", - execution_time_as_valid_from=False, - columns_to_types=input_schema, - ) - results = ctx.get_metadata_results() - assert len(results.views) == 0 - assert len(results.materialized_views) == 0 - assert len(results.tables) == len(results.non_temp_tables) == 1 - assert results.non_temp_tables[0] == table.name - ctx.compare_with_current( - table, - pd.DataFrame( - [ - { - "id": 1, - "name": "a", - "status": "active", - "valid_from": "1970-01-01 00:00:00", - "valid_to": "2023-01-05 00:00:00", - }, - { - "id": 1, - "name": "x", - "status": "active", - "valid_from": "2023-01-05 00:00:00", - "valid_to": pd.NaT, - }, - { - "id": 2, - "name": "b", - "status": "inactive", - "valid_from": "1970-01-01 00:00:00", - "valid_to": "2023-01-05 00:00:00", - }, - { - "id": 3, - "name": "c", - "status": "active", - "valid_from": "1970-01-01 00:00:00", - "valid_to": pd.NaT, - }, - { - "id": 4, - "name": "d", - "status": "active", - "valid_from": "1970-01-01 00:00:00", - "valid_to": "2023-01-05 00:00:00", - }, - { - "id": 4, - "name": "d", - "status": "inactive", - "valid_from": "2023-01-05 00:00:00", - "valid_to": pd.NaT, - }, - { - "id": 5, - "name": "e", - "status": "inactive", - "valid_from": "1970-01-01 00:00:00", - "valid_to": pd.NaT, - }, - ] - ), - ) - - -def test_get_data_objects(ctx: TestContext): - table = 
ctx.table("test_table") - view = ctx.table("test_view") - ctx.init() - ctx.engine_adapter.create_table( - table, - {"id": exp.DataType.build("int")}, - table_description="test table description", - column_descriptions={"id": "test id column description"}, - ) - ctx.engine_adapter.create_view( - view, - ctx.input_data(pd.DataFrame([{"id": 1, "ds": "2022-01-01"}])), - table_description="test view description", - column_descriptions={"id": "test id column description"}, - ) - - schema = ctx.schema(TEST_SCHEMA) - - assert sorted(ctx.engine_adapter.get_data_objects(schema), key=lambda o: o.name) == [ - DataObject( - name=table.name, - schema=table.db, - catalog=table.catalog or None, - type=DataObjectType.TABLE, - ), - DataObject( - name=view.name, - schema=view.db, - catalog=view.catalog or None, - type=DataObjectType.VIEW, - ), - ] - - assert sorted( - ctx.engine_adapter.get_data_objects(schema, {table.name, view.name}), - key=lambda o: o.name, - ) == [ - DataObject( - name=table.name, - schema=table.db, - catalog=table.catalog or None, - type=DataObjectType.TABLE, - ), - DataObject( - name=view.name, - schema=view.db, - catalog=view.catalog or None, - type=DataObjectType.VIEW, - ), - ] - - assert ctx.engine_adapter.get_data_objects(schema, {table.name}) == [ - DataObject( - name=table.name, - schema=table.db, - catalog=table.catalog or None, - type=DataObjectType.TABLE, - ), - ] - - assert ctx.engine_adapter.get_data_objects(schema, {view.name}) == [ - DataObject( - name=view.name, - schema=view.db, - catalog=view.catalog or None, - type=DataObjectType.VIEW, - ), - ] - - assert ctx.engine_adapter.get_data_objects(schema, {}) == [] - assert ctx.engine_adapter.get_data_objects("missing_schema") == [] - - -def test_truncate_table(ctx: TestContext): - if ctx.test_type != "query": - pytest.skip("Truncate table test does not change based on input data type") - - ctx.init() - table = ctx.table("test_table") - ctx.engine_adapter.create_table(table, ctx.columns_to_types) - 
input_data = pd.DataFrame( - [ - {"id": 1, "ds": "2022-01-01"}, - {"id": 2, "ds": "2022-01-02"}, - {"id": 3, "ds": "2022-01-03"}, - ] - ) - ctx.engine_adapter.insert_append(table, ctx.input_data(input_data)) - ctx.compare_with_current(table, input_data) - ctx.engine_adapter._truncate_table(table) - assert ctx.engine_adapter.fetchone(exp.select("count(*)").from_(table))[0] == 0 - - -def test_transaction(ctx: TestContext): - if ctx.engine_adapter.SUPPORTS_TRANSACTIONS is False: - pytest.skip(f"Engine adapter {ctx.engine_adapter.dialect} doesn't support transactions") - if ctx.test_type != "query": - pytest.skip("Transaction test can just run for query") - - ctx.init() - table = ctx.table("test_table") - input_data = pd.DataFrame( - [ - {"id": 1, "ds": "2022-01-01"}, - {"id": 2, "ds": "2022-01-02"}, - {"id": 3, "ds": "2022-01-03"}, - ] - ) - with ctx.engine_adapter.transaction(): - ctx.engine_adapter.create_table(table, ctx.columns_to_types) - ctx.engine_adapter.insert_append( - table, ctx.input_data(input_data, ctx.columns_to_types), ctx.columns_to_types - ) - ctx.compare_with_current(table, input_data) - with ctx.engine_adapter.transaction(): - ctx.engine_adapter._truncate_table(table) - ctx.engine_adapter._connection_pool.rollback() - ctx.compare_with_current(table, input_data) - - -def test_sushi(mark_gateway: t.Tuple[str, str], ctx: TestContext): - if ctx.test_type != "query": - pytest.skip("Sushi end-to-end tests only need to run for query") - - config = load_config_from_paths( - Config, - project_paths=[ - pathlib.Path(os.path.join(os.path.dirname(__file__), "config.yaml")), - ], - personal_paths=[pathlib.Path("~/.sqlmesh/config.yaml").expanduser()], - ) - _, gateway = mark_gateway - - # clear cache from prior runs - cache_dir = pathlib.Path("./examples/sushi/.cache") - if cache_dir.exists(): - import shutil - - shutil.rmtree(cache_dir) - - context = Context(paths="./examples/sushi", config=config, gateway=gateway) - - # clean up any leftover schemas from 
previous runs (requires context) - for schema in [ - "sushi__test_prod", - "sushi__test_dev", - "sushi", - "sqlmesh__sushi", - "sqlmesh", - "raw", - ]: - context.engine_adapter.drop_schema(schema, ignore_if_not_exists=True, cascade=True) - - start = to_date(now() - timedelta(days=7)) - end = now() - - # Databricks requires the table property `delta.columnMapping.mode = 'name'` for - # spaces in column names. Other engines error if it is set in the model definition, - # so we set it here. - if ctx.dialect == "databricks": - cust_rev_by_day_key = [key for key in context._models if "customer_revenue_by_day" in key][ - 0 - ] - - cust_rev_by_day_model_tbl_props = context._models[cust_rev_by_day_key].copy( - update={ - "physical_properties": { - "delta.columnMapping.mode": exp.Literal(this="name", is_string=True) - } - } - ) - - context._models.update({cust_rev_by_day_key: cust_rev_by_day_model_tbl_props}) - - plan: Plan = context.plan( - environment="test_prod", - start=start, - end=end, - skip_tests=True, - no_prompts=True, - auto_apply=True, - ) - - data_validator = SushiDataValidator.from_context(context) - data_validator.validate( - "sushi.customer_revenue_lifetime", - start, - yesterday(), - env_name="test_prod", - dialect=ctx.dialect, - environment_naming_info=plan.environment_naming_info, - ) - - # Ensure table and column comments were correctly registered with engine - if ctx.engine_adapter.COMMENT_CREATION_TABLE.is_supported: - comments = { - "customer_revenue_by_day": { - "table": "Table of revenue from customers by day.", - "column": { - "customer_id": "Customer id", - "revenue": "Revenue from orders made by this customer", - "event_date": "Date", - }, - }, - "customer_revenue_lifetime": { - "table": """Table of lifetime customer revenue. - Date is available to get lifetime value up to a certain date. 
- Use latest date to get current lifetime value.""", - "column": { - "customer_id": "Customer id", - "revenue": "Lifetime revenue from this customer", - "event_date": "End date of the lifetime calculation", - }, - }, - "customers": { - "table": "Sushi customer data", - "column": {"customer_id": "customer_id uniquely identifies customers"}, - }, - "marketing": { - "table": "Sushi marketing data", - "column": {"customer_id": "customer_id uniquely identifies customers \\"}, - }, - "orders": { - "table": "Table of sushi orders.", - }, - "raw_marketing": { - "table": "Table of marketing status.", - "column": {"customer_id": "Unique identifier of the customer"}, - }, - "top_waiters": { - "table": "View of top waiters.", - }, - "waiter_names": { - "table": "List of waiter names", - }, - "waiter_revenue_by_day": { - "table": "Table of revenue generated by waiters by day.", - "column": { - "waiter_id": "Waiter id", - "revenue": "Revenue from orders taken by this waiter", - "event_date": "Date", - }, - }, - } - - def validate_comments( - schema_name: str, - expected_comments_dict: t.Dict[str, t.Any] = comments, - is_physical_layer: bool = True, - prod_schema_name: str = "sushi", - ) -> None: - layer_objects = context.engine_adapter.get_data_objects(schema_name) - layer_models = { - x.name.split("__")[1] if is_physical_layer else x.name: { - "table_name": x.name, - "is_view": x.type == DataObjectType.VIEW, - } - for x in layer_objects - if not x.name.endswith("__temp") - } - - for model_name, comment in comments.items(): - layer_table_name = layer_models[model_name]["table_name"] - table_kind = "VIEW" if layer_models[model_name]["is_view"] else "BASE TABLE" - - # is this model in a physical layer or PROD environment? - is_physical_or_prod = is_physical_layer or ( - not is_physical_layer and schema_name == prod_schema_name - ) - # is this model a VIEW and the engine doesn't support VIEW comments? 
- is_view_and_comments_unsupported = ( - layer_models[model_name]["is_view"] - and ctx.engine_adapter.COMMENT_CREATION_VIEW.is_unsupported - ) - if is_physical_or_prod and not is_view_and_comments_unsupported: - expected_tbl_comment = comments.get(model_name).get("table", None) - if expected_tbl_comment: - actual_tbl_comment = ctx.get_table_comment( - schema_name, - layer_table_name, - table_kind=table_kind, - snowflake_capitalize_ids=False, - ) - assert expected_tbl_comment == actual_tbl_comment - - expected_col_comments = comments.get(model_name).get("column", None) - - # Trino: - # Trino on Hive COMMENT permissions are separate from standard SQL object permissions. - # Trino has a bug where CREATE SQL permissions are not passed to COMMENT permissions, - # which generates permissions errors when COMMENT commands are issued. - # - # The errors are thrown for both table and comments, but apparently the - # table comments are actually registered with the engine. Column comments are not. - # - # Query: - # In the query test, columns_to_types are not available when views are created. Since we - # can only register column comments in the CREATE VIEW schema expression with columns_to_types - # available, the column comments must be registered via post-creation commands. Some engines, - # such as Spark and Snowflake, do not support view column comments via post-creation commands. 
- if ( - expected_col_comments - and not ctx.dialect == "trino" - and not ( - ctx.test_type == "query" - and layer_models[model_name]["is_view"] - and not ctx.engine_adapter.COMMENT_CREATION_VIEW.supports_column_comment_commands - ) - ): - actual_col_comments = ctx.get_column_comments( - schema_name, - layer_table_name, - table_kind=table_kind, - snowflake_capitalize_ids=False, - ) - for column_name, expected_col_comment in expected_col_comments.items(): - expected_col_comment = expected_col_comments.get(column_name, None) - actual_col_comment = actual_col_comments.get(column_name, None) - assert expected_col_comment == actual_col_comment - - return None - - def validate_no_comments( - schema_name: str, - expected_comments_dict: t.Dict[str, t.Any] = comments, - is_physical_layer: bool = True, - table_name_suffix: str = "", - check_temp_tables: bool = False, - prod_schema_name: str = "sushi", - ) -> None: - layer_objects = context.engine_adapter.get_data_objects(schema_name) - layer_models = { - x.name.split("__")[1] if is_physical_layer else x.name: { - "table_name": x.name, - "is_view": x.type == DataObjectType.VIEW, - } - for x in layer_objects - if x.name.endswith(table_name_suffix) - } - if not check_temp_tables: - layer_models = {k: v for k, v in layer_models.items() if not k.endswith("__temp")} - - for model_name, comment in comments.items(): - layer_table_name = layer_models[model_name]["table_name"] - table_kind = "VIEW" if layer_models[model_name]["is_view"] else "BASE TABLE" - - actual_tbl_comment = ctx.get_table_comment( - schema_name, - layer_table_name, - table_kind=table_kind, - snowflake_capitalize_ids=False, - ) - # MySQL doesn't support view comments and always returns "VIEW" as the table comment - if ctx.dialect == "mysql" and layer_models[model_name]["is_view"]: - assert actual_tbl_comment == "VIEW" - else: - assert actual_tbl_comment is None or actual_tbl_comment == "" - - # MySQL and Spark pass through the column comments from the underlying 
table to the view - # so always have view comments present - if not ( - ctx.dialect in ("mysql", "spark", "databricks") - and layer_models[model_name]["is_view"] - ): - expected_col_comments = comments.get(model_name).get("column", None) - if expected_col_comments: - actual_col_comments = ctx.get_column_comments( - schema_name, - layer_table_name, - table_kind=table_kind, - snowflake_capitalize_ids=False, - ) - for column_name in expected_col_comments: - actual_col_comment = actual_col_comments.get(column_name, None) - assert actual_col_comment is None or actual_col_comment == "" - - return None - - # confirm physical layer comments are registered - validate_comments("sqlmesh__sushi") - # confirm physical temp table comments are not registered - validate_no_comments("sqlmesh__sushi", table_name_suffix="__temp", check_temp_tables=True) - # confirm view layer comments are not registered in non-PROD environment - env_name = "test_prod" - if plan.environment_naming_info and plan.environment_naming_info.normalize_name: - env_name = normalize_identifiers(env_name, dialect=ctx.dialect).name - validate_no_comments(f"sushi__{env_name}", is_physical_layer=False) - - # Ensure that the plan has been applied successfully. 
- no_change_plan: Plan = context.plan( - environment="test_dev", - start=start, - end=end, - skip_tests=True, - no_prompts=True, - include_unmodified=True, - ) - assert not no_change_plan.requires_backfill - assert no_change_plan.context_diff.is_new_environment - - # make and validate unmodified dev environment - context.apply(no_change_plan) - - data_validator.validate( - "sushi.customer_revenue_lifetime", - start, - yesterday(), - env_name="test_dev", - dialect=ctx.dialect, - environment_naming_info=no_change_plan.environment_naming_info, - ) - - # confirm view layer comments are registered in PROD - if ctx.engine_adapter.COMMENT_CREATION_VIEW.is_supported: - context.plan(skip_tests=True, no_prompts=True, auto_apply=True) - validate_comments("sushi", is_physical_layer=False) - - -def test_init_project(ctx: TestContext, mark_gateway: t.Tuple[str, str], tmp_path: pathlib.Path): - if ctx.test_type != "query": - pytest.skip("Init example project end-to-end tests only need to run for query") - - init_example_project(tmp_path, ctx.dialect) - config = load_config_from_paths( - Config, - project_paths=[ - pathlib.Path(os.path.join(os.path.dirname(__file__), "config.yaml")), - ], - personal_paths=[pathlib.Path("~/.sqlmesh/config.yaml").expanduser()], - ) - _, gateway = mark_gateway - context = Context(paths=tmp_path, config=config, gateway=gateway) - ctx.engine_adapter = context.engine_adapter - - # clean up any leftover schemas from previous runs (requires context) - for schema in [ - "sqlmesh_example", - "sqlmesh_example__test_dev", - "sqlmesh__sqlmesh_example", - "sqlmesh", - ]: - context.engine_adapter.drop_schema(schema, ignore_if_not_exists=True, cascade=True) - - # apply prod plan - context.plan(auto_apply=True, no_prompts=True) - - prod_schema_results = ctx.get_metadata_results("sqlmesh_example") - assert sorted(prod_schema_results.views) == [ - "full_model", - "incremental_model", - "seed_model", - ] - assert len(prod_schema_results.materialized_views) == 0 - 
assert len(prod_schema_results.tables) == len(prod_schema_results.non_temp_tables) == 0 - - physical_layer_results = ctx.get_metadata_results("sqlmesh__sqlmesh_example") - assert len(physical_layer_results.views) == 0 - assert len(physical_layer_results.materialized_views) == 0 - assert len(physical_layer_results.tables) == len(physical_layer_results.non_temp_tables) == 6 - - # make and validate unmodified dev environment - no_change_plan: Plan = context.plan( - environment="test_dev", - skip_tests=True, - no_prompts=True, - include_unmodified=True, - ) - assert not no_change_plan.requires_backfill - assert no_change_plan.context_diff.is_new_environment - - context.apply(no_change_plan) - - environment = no_change_plan.environment - first_snapshot = no_change_plan.environment.snapshots[0] - schema_name = first_snapshot.qualified_view_name.schema_for_environment( - environment, dialect=ctx.dialect - ) - dev_schema_results = ctx.get_metadata_results(schema_name) - assert sorted(dev_schema_results.views) == [ - "full_model", - "incremental_model", - "seed_model", - ] - assert len(dev_schema_results.materialized_views) == 0 - assert len(dev_schema_results.tables) == len(dev_schema_results.non_temp_tables) == 0 - - -def test_dialects(ctx: TestContext): - if ctx.test_type != "query": - pytest.skip("Dialect tests only need to run once so we skip anything not query") - - from sqlglot import Dialect, parse_one - - dialect = Dialect[ctx.dialect] - - if dialect.NORMALIZATION_STRATEGY == "CASE_INSENSITIVE": - a = '"a"' - b = '"b"' - c = '"c"' - d = '"d"' - elif dialect.NORMALIZATION_STRATEGY == "LOWERCASE": - a = '"a"' - b = '"B"' - c = '"c"' - d = '"d"' - # https://dev.mysql.com/doc/refman/8.0/en/identifier-case-sensitivity.html - # if these tests fail for mysql it means you're running on os x or windows - elif dialect.NORMALIZATION_STRATEGY == "CASE_SENSITIVE": - a = '"a"' - b = '"B"' - c = '"c"' - d = '"D"' - else: - a = '"a"' - b = '"B"' - c = '"C"' - d = '"D"' - - q = 
parse_one( - f""" - WITH - "a" AS (SELECT 1 w), - "B" AS (SELECT 1 x), - c AS (SELECT 1 y), - D AS (SELECT 1 z) - - SELECT * - FROM {a} - CROSS JOIN {b} - CROSS JOIN {c} - CROSS JOIN {d} - """ - ) - df = ctx.engine_adapter.fetchdf(q) - expected_columns = ["W", "X", "Y", "Z"] if ctx.dialect == "snowflake" else ["w", "x", "y", "z"] - pd.testing.assert_frame_equal( - df, pd.DataFrame([[1, 1, 1, 1]], columns=expected_columns), check_dtype=False - ) - - -@pytest.mark.parametrize( - "time_column, time_column_type, time_column_format, result", - [ - ( - exp.null(), - exp.DataType.build("TIMESTAMP"), - None, - { - "default": None, - "bigquery": pd.NaT, - "databricks": pd.NaT, - "duckdb": pd.NaT, - "motherduck": pd.NaT, - "snowflake": pd.NaT, - "spark": pd.NaT, - }, - ), - ( - "2020-01-01 00:00:00+00:00", - exp.DataType.build("DATE"), - None, - { - "default": datetime(2020, 1, 1).date(), - "duckdb": pd.Timestamp("2020-01-01"), - }, - ), - ( - "2020-01-01 00:00:00+00:00", - exp.DataType.build("TIMESTAMPTZ"), - None, - { - "default": pd.Timestamp("2020-01-01 00:00:00+00:00"), - "mysql": pd.Timestamp("2020-01-01 00:00:00"), - "spark": pd.Timestamp("2020-01-01 00:00:00"), - }, - ), - ( - "2020-01-01 00:00:00+00:00", - exp.DataType.build("TIMESTAMP"), - None, - { - "default": pd.Timestamp("2020-01-01 00:00:00"), - # Databricks' timestamp type is tz-aware: - # "Represents values comprising values of fields year, month, day, hour, minute, and second, - # with the session local time-zone. - # The timestamp value represents an absolute point in time." - # https://docs.databricks.com/en/sql/language-manual/data-types/timestamp-type.html - # - # They are adding a non-aware version TIMESTAMP_NTZ that's currently in public preview - - # you have to specify a table option to use it: - # "Feature support is enabled automatically when you create a new Delta table with a column of - # TIMESTAMP_NTZ type. 
It is not enabled automatically when you add a column of - # TIMESTAMP_NTZ type to an existing table. - # To enable support for TIMESTAMP_NTZ columns, support for the feature must be explicitly enabled for - # the existing table." - # https://docs.databricks.com/en/sql/language-manual/data-types/timestamp-ntz-type.html - "databricks": pd.Timestamp("2020-01-01 00:00:00+00:00"), - }, - ), - ( - "2020-01-01 00:00:00+00:00", - exp.DataType.build("TEXT"), - "%Y-%m-%dT%H:%M:%S%z", - { - "default": "2020-01-01T00:00:00+0000", - }, - ), - ( - "2020-01-01 00:00:00+00:00", - exp.DataType.build("INT"), - "%Y%m%d", - { - "default": 20200101, - }, - ), - ], -) -def test_to_time_column( - ctx: TestContext, time_column, time_column_type, time_column_format, result -): - if ctx.test_type != "query": - pytest.skip("Time column tests only need to run for query") - - time_column = to_time_column(time_column, time_column_type, time_column_format) - df = ctx.engine_adapter.fetchdf(exp.select(time_column).as_("the_col")) - expected = result.get(ctx.dialect, result.get("default")) - col_name = "THE_COL" if ctx.dialect == "snowflake" else "the_col" - if expected is pd.NaT or expected is None: - assert df[col_name][0] is expected - else: - assert df[col_name][0] == expected - - -def test_batch_size_on_incremental_by_unique_key_model( - ctx: TestContext, mark_gateway: t.Tuple[str, str] -): - if ctx.test_type != "query": - pytest.skip("This only needs to run once so we skip anything not query") - - if not ctx.supports_merge: - _, gateway = mark_gateway - pytest.skip(f"{ctx.dialect} on {gateway} doesnt support merge") - - def _mutate_config(current_gateway_name: str, config: Config): - # make stepping through in the debugger easier - connection = config.gateways[current_gateway_name].connection - connection.concurrent_tasks = 1 - - context = ctx.create_context(_mutate_config) - assert context.default_dialect == "duckdb" - - schema = ctx.schema(TEST_SCHEMA) - seed_query = ctx.input_data( - 
pd.DataFrame( - [ - [2, "2020-01-01"], - [1, "2020-01-01"], - [3, "2020-01-03"], - [1, "2020-01-04"], - [1, "2020-01-05"], - [1, "2020-01-06"], - [1, "2020-01-07"], - ], - columns=["item_id", "event_date"], - ), - columns_to_types={ - "item_id": exp.DataType.build("integer"), - "event_date": exp.DataType.build("date"), - }, - ) - context.upsert_model(create_sql_model(name=f"{schema}.seed_model", query=seed_query)) - context.upsert_model( - load_sql_based_model( - d.parse( - f"""MODEL ( - name {schema}.test_model, - kind INCREMENTAL_BY_UNIQUE_KEY ( - unique_key item_id, - batch_size 1 - ), - start '2020-01-01', - end '2020-01-07', - cron '@daily' - ); - - select * from {schema}.seed_model - where event_date between @start_date and @end_date""", - ) - ) - ) - - try: - context.plan(auto_apply=True, no_prompts=True) - - test_model = context.get_model(f"{schema}.test_model") - normalized_schema_name = test_model.fully_qualified_table.db - results = ctx.get_metadata_results(normalized_schema_name) - assert "test_model" in results.views - - actual_df = ( - ctx.get_current_data(test_model.fqn).sort_values(by="event_date").reset_index(drop=True) - ) - actual_df["event_date"] = actual_df["event_date"].astype(str) - assert actual_df.count()[0] == 3 - - expected_df = pd.DataFrame( - [[2, "2020-01-01"], [3, "2020-01-03"], [1, "2020-01-07"]], - columns=actual_df.columns, - ).sort_values(by="event_date") - - pd.testing.assert_frame_equal( - actual_df, - expected_df, - check_dtype=False, - ) - - finally: - ctx.cleanup(context) - - -def test_managed_model_upstream_forward_only(ctx: TestContext): - """ - This scenario goes as follows: - - A managed model B is a downstream dependency of an incremental model A - (as a sidenote: this is an incorrect use of managed models, they should really only reference external models, but we dont prevent it specifically to be more user friendly) - - User plans a forward-only change against Model A in a virtual environment "dev" - - This causes a 
new non-deployable snapshot of Model B in "dev". - - In these situations, we create a normal table for Model B, not a managed table - - User modifies model B and applies a plan in "dev" - - This should also result in a normal table - - User decides they want to deploy so they run their plan against prod - - We need to ensure we ignore the normal table for Model B (it was just a dev preview) and create a new managed table for prod - - Upon apply to prod, Model B should be completely recreated as a managed table - """ - - if ctx.test_type != "query": - pytest.skip("This only needs to run once so we skip anything not query") - - if not ctx.engine_adapter.SUPPORTS_MANAGED_MODELS: - pytest.skip("This test only runs for engines that support managed models") - - def _run_plan(sqlmesh_context: Context, environment: str = None) -> PlanResults: - plan: Plan = sqlmesh_context.plan(auto_apply=True, no_prompts=True, environment=environment) - return PlanResults.create(plan, ctx, schema) - - context = ctx.create_context() - schema = ctx.add_test_suffix(TEST_SCHEMA) - - model_a = load_sql_based_model( - d.parse( # type: ignore - f""" - MODEL ( - name {schema}.upstream_model, - kind INCREMENTAL_BY_TIME_RANGE ( - time_column ts, - forward_only True - ), - ); - - SELECT 1 as id, 'foo' as name, current_timestamp as ts; - """ - ) - ) - - model_b = load_sql_based_model( - d.parse( # type: ignore - f""" - MODEL ( - name {schema}.managed_model, - kind MANAGED, - physical_properties ( - target_lag = '5 minutes' - ) - ); - - SELECT * from {schema}.upstream_model; - """ - ) - ) - - context.upsert_model(model_a) - context.upsert_model(model_b) - - plan_1 = _run_plan(context) - - assert plan_1.snapshot_for(model_a).change_category == SnapshotChangeCategory.BREAKING - assert plan_1.snapshot_for(model_b).change_category == SnapshotChangeCategory.BREAKING - - # so far so good, model_a should exist as a normal table, model b should be a managed table and the prod views should exist - assert 
len(plan_1.schema_metadata.views) == 2 - assert plan_1.snapshot_for(model_a).model.view_name in plan_1.schema_metadata.views - assert plan_1.snapshot_for(model_b).model.view_name in plan_1.schema_metadata.views - - assert len(plan_1.internal_schema_metadata.tables) == 3 - assert plan_1.table_name_for(model_a) in plan_1.internal_schema_metadata.tables - assert plan_1.dev_table_name_for(model_a) in plan_1.internal_schema_metadata.tables - assert ( - plan_1.table_name_for(model_b) not in plan_1.internal_schema_metadata.tables - ) # because its a managed table - assert ( - plan_1.dev_table_name_for(model_b) in plan_1.internal_schema_metadata.tables - ) # its dev table is a normal table however - - assert len(plan_1.internal_schema_metadata.managed_tables) == 1 - assert plan_1.table_name_for(model_b) in plan_1.internal_schema_metadata.managed_tables - assert ( - plan_1.dev_table_name_for(model_b) not in plan_1.internal_schema_metadata.managed_tables - ) # the dev table should not be created as managed - - # Let's modify model A with a breaking change and plan it against a dev environment. 
This should trigger a forward-only plan - new_model_a = load_sql_based_model( - d.parse( # type: ignore - f""" - MODEL ( - name {schema}.upstream_model, - kind INCREMENTAL_BY_TIME_RANGE ( - time_column ts, - forward_only True - ), - ); - - SELECT 1 as id, 'foo' as name, 'bar' as extra, current_timestamp as ts; - """ - ) - ) - context.upsert_model(new_model_a) - - # apply plan to dev environment - plan_2 = _run_plan(context, "dev") - - assert plan_2.plan.has_changes - assert len(plan_2.plan.modified_snapshots) == 2 - assert plan_2.snapshot_for(new_model_a).change_category == SnapshotChangeCategory.FORWARD_ONLY - assert plan_2.snapshot_for(model_b).change_category == SnapshotChangeCategory.NON_BREAKING - - # verify that the new snapshots were created correctly - # the forward-only change to model A should be in a new table separate from the one created in the first plan - # since model B depends on an upstream model with a forward-only change, it should also get recreated, but as a normal table, not a managed table - assert plan_2.table_name_for(model_a) == plan_1.table_name_for( - model_a - ) # no change in the main table because the dev preview changes go to the dev table - assert plan_2.dev_table_name_for(model_a) != plan_1.dev_table_name_for( - model_a - ) # it creates a new dev table to hold the dev preview - assert plan_2.dev_table_name_for(model_a) in plan_2.internal_schema_metadata.tables - - assert plan_2.table_name_for(model_b) != plan_1.table_name_for( - model_b - ) # model b gets a new table - assert plan_2.dev_table_name_for(model_b) != plan_1.dev_table_name_for( - model_b - ) # model b gets a new dev table as well - assert ( - plan_2.table_name_for(model_b) not in plan_2.internal_schema_metadata.tables - ) # the new main table is not actually created, because it was triggered by a forward-only change. 
downstream models use the dev table - assert plan_2.table_name_for(model_b) not in plan_2.internal_schema_metadata.managed_tables - assert ( - plan_2.dev_table_name_for(model_b) in plan_2.internal_schema_metadata.tables - ) # dev tables are always regular tables for managed models - - # modify model B, still in the dev environment - new_model_b = load_sql_based_model( - d.parse( # type: ignore - f""" - MODEL ( - name {schema}.managed_model, - kind MANAGED, - physical_properties ( - target_lag = '5 minutes' - ) - ); - - SELECT *, 'modified' as extra_b from {schema}.upstream_model; - """ - ) - ) - context.upsert_model(new_model_b) - - plan_3 = _run_plan(context, "dev") - - assert plan_3.plan.has_changes - assert len(plan_3.plan.modified_snapshots) == 1 - assert ( - plan_3.modified_snapshot_for(model_b).change_category == SnapshotChangeCategory.NON_BREAKING - ) - - # model A should be unchanged - # the new model B should be a normal table, not a managed table - assert plan_3.table_name_for(model_a) == plan_2.table_name_for(model_a) - assert plan_3.dev_table_name_for(model_a) == plan_2.dev_table_name_for(model_a) - assert plan_3.table_name_for(model_b) != plan_2.table_name_for(model_b) - assert plan_3.dev_table_name_for(model_b) != plan_2.table_name_for(model_b) - - assert ( - plan_3.table_name_for(model_b) not in plan_3.internal_schema_metadata.tables - ) # still using the dev table, no main table created - assert plan_3.dev_table_name_for(model_b) in plan_3.internal_schema_metadata.tables - assert ( - plan_3.table_name_for(model_b) not in plan_3.internal_schema_metadata.managed_tables - ) # still not a managed table - - # apply plan to prod - plan_4 = _run_plan(context) - - assert plan_4.plan.has_changes - assert plan_4.snapshot_for(model_a).change_category == SnapshotChangeCategory.FORWARD_ONLY - assert plan_4.snapshot_for(model_b).change_category == SnapshotChangeCategory.NON_BREAKING - - # verify the Model B table is created as a managed table in prod - assert 
plan_4.table_name_for(model_b) == plan_3.table_name_for( - model_b - ) # the model didnt change; the table should still have the same name - assert ( - plan_4.table_name_for(model_b) not in plan_4.internal_schema_metadata.tables - ) # however, it should be a managed table, not a normal table - assert plan_4.table_name_for(model_b) in plan_4.internal_schema_metadata.managed_tables diff --git a/tests/core/engine_adapter/test_mixins.py b/tests/core/engine_adapter/test_mixins.py index 57803427d4..50bef59d6e 100644 --- a/tests/core/engine_adapter/test_mixins.py +++ b/tests/core/engine_adapter/test_mixins.py @@ -23,7 +23,7 @@ def test_logical_merge(make_mocked_engine_adapter: t.Callable, mocker: MockerFix adapter.merge( target_table="target", source_table=t.cast(exp.Select, parse_one("SELECT id, ts, val FROM source")), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType(this=exp.DataType.Type.INT), "ts": exp.DataType(this=exp.DataType.Type.TIMESTAMP), "val": exp.DataType(this=exp.DataType.Type.INT), @@ -48,7 +48,7 @@ def test_logical_merge(make_mocked_engine_adapter: t.Callable, mocker: MockerFix adapter.merge( target_table="target", source_table=t.cast(exp.Select, parse_one("SELECT id, ts, val FROM source")), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType(this=exp.DataType.Type.INT), "ts": exp.DataType(this=exp.DataType.Type.TIMESTAMP), "val": exp.DataType(this=exp.DataType.Type.INT), diff --git a/tests/core/engine_adapter/test_mssql.py b/tests/core/engine_adapter/test_mssql.py index 47d0847524..ec6a4ba3e8 100644 --- a/tests/core/engine_adapter/test_mssql.py +++ b/tests/core/engine_adapter/test_mssql.py @@ -3,27 +3,32 @@ from datetime import date from unittest import mock -import pandas as pd +import pandas as pd # noqa: TID253 import pytest from pytest_mock.plugin import MockerFixture from sqlglot import expressions as exp from sqlglot import parse_one +from pathlib import Path +from sqlmesh import model from 
sqlmesh.core.engine_adapter.mssql import MSSQLEngineAdapter -from sqlmesh.core.engine_adapter.shared import ( - DataObject, - DataObjectType, - InsertOverwriteStrategy, -) +from sqlmesh.core.snapshot import SnapshotEvaluator, SnapshotChangeCategory, Snapshot +from sqlmesh.core.model import load_sql_based_model +from sqlmesh.core.model.kind import SCDType2ByTimeKind +from sqlmesh.core import dialect as d +from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType, SourceQuery from sqlmesh.utils.date import to_ds from tests.core.engine_adapter import to_sql_calls pytestmark = [pytest.mark.engine, pytest.mark.mssql] -def test_columns(make_mocked_engine_adapter: t.Callable): - adapter = make_mocked_engine_adapter(MSSQLEngineAdapter) +@pytest.fixture +def adapter(make_mocked_engine_adapter: t.Callable) -> MSSQLEngineAdapter: + return make_mocked_engine_adapter(MSSQLEngineAdapter) + +def test_columns(adapter: MSSQLEngineAdapter): adapter.cursor.fetchall.return_value = [ ("decimal_ps", "decimal", None, 5, 4), ("decimal", "decimal", None, 18, 0), @@ -73,7 +78,7 @@ def test_columns(make_mocked_engine_adapter: t.Callable): } adapter.cursor.execute.assert_called_once_with( - """SELECT [column_name], [data_type], [character_maximum_length], [numeric_precision], [numeric_scale] FROM [information_schema].[columns] WHERE [table_name] = 'table' AND [table_schema] = 'db';""" + """SELECT [COLUMN_NAME], [DATA_TYPE], [CHARACTER_MAXIMUM_LENGTH], [NUMERIC_PRECISION], [NUMERIC_SCALE] FROM [INFORMATION_SCHEMA].[COLUMNS] WHERE [TABLE_NAME] = 'table' AND [TABLE_SCHEMA] = 'db';""" ) @@ -143,8 +148,8 @@ def test_table_exists(make_mocked_engine_adapter: t.Callable): resp = adapter.table_exists("db.table") adapter.cursor.execute.assert_called_once_with( """SELECT 1 """ - """FROM [information_schema].[tables] """ - """WHERE [table_name] = 'table' AND [table_schema] = 'db';""" + """FROM [INFORMATION_SCHEMA].[TABLES] """ + """WHERE [TABLE_NAME] = 'table' AND [TABLE_SCHEMA] = 
'db';""" ) assert resp adapter.cursor.fetchone.return_value = None @@ -152,6 +157,114 @@ def test_table_exists(make_mocked_engine_adapter: t.Callable): assert not resp +@pytest.mark.parametrize( + "select_expr, input_time, expected_sql", + [ + # Respect the user's precision for datetimeoffset, time, datetime2 + ( + "SELECT ds::datetime2", + pd.Timestamp("2022-01-01 00:00:00.1234567"), + "CAST('2022-01-01 00:00:00.123456700' AS DATETIME2)", + ), + ( + "SELECT ds::datetimeoffset(4)", + pd.Timestamp("2022-01-01 00:00:00.1234567"), + "CAST('2022-01-01 00:00:00.123456700+00:00' AS DATETIMEOFFSET(4))", + ), + ( + "SELECT ds::time", + pd.Timestamp("2022-01-01 00:00:00.1234567"), + "CAST('2022-01-01 00:00:00.123456700' AS TIME)", + ), + # Respecting precision in datetimeoffset with time zone offsets + ( + "SELECT ds::time(7)", + pd.Timestamp("2022-01-01 00:00:00.1234567+00:00"), + "CAST('2022-01-01 00:00:00.123456700' AS TIME(7))", + ), + ( + "SELECT ds::datetimeoffset(6)", + pd.Timestamp("2022-01-01 00:00:00.1234567+02:00"), + "CAST('2021-12-31 22:00:00.123456700+00:00' AS DATETIMEOFFSET(6))", + ), + # For date types without nano-second precision, truncate as usual + ( + "SELECT ds::datetime", + "2022-01-01 00:00:00.1234567+01:00", + "CAST('2021-12-31 23:00:00.123456' AS DATETIME)", + ), + ( + "SELECT ds::smalldatetime", + "2022-01-01 00:00:00.1234567+00:00", + "CAST('2022-01-01 00:00:00.123456' AS SMALLDATETIME)", + ), + ("SELECT ds::date", "2022-01-01 00:00:00.001", "CAST('2022-01-01' AS DATE)"), + ], +) +def test_to_time_column(select_expr, input_time, expected_sql): + expressions = d.parse( + f""" + MODEL ( + name db.table, + kind INCREMENTAL_BY_TIME_RANGE( + time_column (ds) + ), + dialect tsql + ); + + {select_expr} + """ + ) + model = load_sql_based_model(expressions) + assert model.convert_to_time_column(input_time).sql("tsql") == expected_sql + + +def test_incremental_by_time_datetimeoffset_precision( + make_mocked_engine_adapter: t.Callable, mocker: 
MockerFixture, make_snapshot +): + adapter = make_mocked_engine_adapter(MSSQLEngineAdapter) + adapter.get_current_catalog = mocker.MagicMock(return_value="other_catalog") + + evaluator = SnapshotEvaluator(adapter) + parsed = d.parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind INCREMENTAL_BY_TIME_RANGE (time_column ds), + ); + + SELECT a::int, ds::datetimeoffset FROM tbl as t WHERE t.ds BETWEEN @start_dt and @end_dt; + """, + ) + + model = load_sql_based_model(parsed, dialect="tsql") + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.evaluate( + snapshot, + start="2020-01-01", + end="2020-01-02", + execution_time="2020-01-02", + snapshots={}, + target_table_exists=True, + ) + + assert adapter.cursor.execute.call_args_list[0][0][0] == ( + f"MERGE INTO [sqlmesh__test_schema].[test_schema__test_model__{snapshot.version}] AS [__MERGE_TARGET__] USING " + "(SELECT [a] AS [a], [ds] AS [ds] FROM (SELECT CAST([a] AS INTEGER) AS [a], " + "CAST([ds] AS DATETIMEOFFSET) AS [ds] FROM [tbl] AS [t] WHERE [t].[ds] BETWEEN " + "CAST('2020-01-01 00:00:00+00:00' AS DATETIMEOFFSET) AT TIME ZONE 'UTC' AND " + "CAST('2020-01-02 23:59:59.999999999+00:00' AS DATETIMEOFFSET) AT TIME ZONE 'UTC') AS [_subquery] " + "WHERE [ds] BETWEEN CAST('2020-01-01 00:00:00+00:00' AS DATETIMEOFFSET) AND " + "CAST('2020-01-02 23:59:59.999999999+00:00' AS DATETIMEOFFSET)) AS [__MERGE_SOURCE__] ON (1 = 0) WHEN NOT " + "MATCHED BY SOURCE AND [ds] BETWEEN CAST('2020-01-01 00:00:00+00:00' AS DATETIMEOFFSET) AND " + "CAST('2020-01-02 23:59:59.999999999+00:00' AS DATETIMEOFFSET) THEN DELETE WHEN NOT MATCHED THEN INSERT " + "([a], [ds]) VALUES ([a], [ds]);" + ) + + def test_insert_overwrite_by_time_partition_supports_insert_overwrite_pandas_not_exists( make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable ): @@ -177,13 +290,16 @@ def 
test_insert_overwrite_by_time_partition_supports_insert_overwrite_pandas_not end="2022-01-02", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), time_column="ds", - columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "ds": exp.DataType.build("STRING"), + }, ) adapter._connection_pool.get().bulk_copy.assert_called_with( f"__temp_test_table_{temp_table_id}", [(1, "2022-01-01"), (2, "2022-01-02")] ) assert to_sql_calls(adapter) == [ - f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = '__temp_test_table_{temp_table_id}') EXEC('CREATE TABLE [__temp_test_table_{temp_table_id}] ([a] INTEGER, [ds] VARCHAR(MAX))');""", + f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_test_table_{temp_table_id}') EXEC('CREATE TABLE [__temp_test_table_{temp_table_id}] ([a] INTEGER, [ds] VARCHAR(MAX))');""", f"""MERGE INTO [test_table] AS [__MERGE_TARGET__] USING (SELECT [a] AS [a], [ds] AS [ds] FROM (SELECT CAST([a] AS INTEGER) AS [a], CAST([ds] AS VARCHAR(MAX)) AS [ds] FROM [__temp_test_table_{temp_table_id}]) AS [_subquery] WHERE [ds] BETWEEN '2022-01-01' AND '2022-01-02') AS [__MERGE_SOURCE__] ON (1 = 0) WHEN NOT MATCHED BY SOURCE AND [ds] BETWEEN '2022-01-01' AND '2022-01-02' THEN DELETE WHEN NOT MATCHED THEN INSERT ([a], [ds]) VALUES ([a], [ds]);""", f"DROP TABLE IF EXISTS [__temp_test_table_{temp_table_id}];", ] @@ -214,46 +330,12 @@ def test_insert_overwrite_by_time_partition_supports_insert_overwrite_pandas_exi end="2022-01-02", time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), time_column="ds", - columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, - ) - assert to_sql_calls(adapter) == [ - f"""MERGE INTO [test_table] AS [__MERGE_TARGET__] USING (SELECT [a] AS [a], [ds] AS [ds] FROM (SELECT CAST([a] AS INTEGER) AS [a], CAST([ds] AS VARCHAR(MAX)) AS [ds] FROM 
[__temp_test_table_{temp_table_id}]) AS [_subquery] WHERE [ds] BETWEEN '2022-01-01' AND '2022-01-02') AS [__MERGE_SOURCE__] ON (1 = 0) WHEN NOT MATCHED BY SOURCE AND [ds] BETWEEN '2022-01-01' AND '2022-01-02' THEN DELETE WHEN NOT MATCHED THEN INSERT ([a], [ds]) VALUES ([a], [ds]);""", - f"DROP TABLE IF EXISTS [__temp_test_table_{temp_table_id}];", - ] - - -def test_insert_overwrite_by_time_partition_replace_where_pandas( - make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable -): - mocker.patch( - "sqlmesh.core.engine_adapter.mssql.MSSQLEngineAdapter.table_exists", - return_value=False, - ) - - adapter = make_mocked_engine_adapter(MSSQLEngineAdapter) - adapter.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.REPLACE_WHERE - - temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") - table_name = "test_table" - temp_table_id = "abcdefgh" - temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id) - - df = pd.DataFrame({"a": [1, 2], "ds": ["2022-01-01", "2022-01-02"]}) - adapter.insert_overwrite_by_time_partition( - table_name, - df, - start="2022-01-01", - end="2022-01-02", - time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), - time_column="ds", - columns_to_types={"a": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING")}, - ) - adapter._connection_pool.get().bulk_copy.assert_called_with( - f"__temp_test_table_{temp_table_id}", [(1, "2022-01-01"), (2, "2022-01-02")] + target_columns_to_types={ + "a": exp.DataType.build("INT"), + "ds": exp.DataType.build("STRING"), + }, ) - assert to_sql_calls(adapter) == [ - f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = '__temp_test_table_{temp_table_id}') EXEC('CREATE TABLE [__temp_test_table_{temp_table_id}] ([a] INTEGER, [ds] VARCHAR(MAX))');""", f"""MERGE INTO [test_table] AS [__MERGE_TARGET__] USING (SELECT [a] AS [a], [ds] AS [ds] FROM (SELECT CAST([a] AS INTEGER) AS [a], 
CAST([ds] AS VARCHAR(MAX)) AS [ds] FROM [__temp_test_table_{temp_table_id}]) AS [_subquery] WHERE [ds] BETWEEN '2022-01-01' AND '2022-01-02') AS [__MERGE_SOURCE__] ON (1 = 0) WHEN NOT MATCHED BY SOURCE AND [ds] BETWEEN '2022-01-01' AND '2022-01-02' THEN DELETE WHEN NOT MATCHED THEN INSERT ([a], [ds]) VALUES ([a], [ds]);""", f"DROP TABLE IF EXISTS [__temp_test_table_{temp_table_id}];", ] @@ -278,7 +360,7 @@ def test_insert_append_pandas( adapter.insert_append( table_name, df, - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("INT"), "b": exp.DataType.build("INT"), }, @@ -288,7 +370,7 @@ def test_insert_append_pandas( ) assert to_sql_calls(adapter) == [ - f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = '__temp_test_table_{temp_table_id}') EXEC('CREATE TABLE [__temp_test_table_{temp_table_id}] ([a] INTEGER, [b] INTEGER)');""", + f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_test_table_{temp_table_id}') EXEC('CREATE TABLE [__temp_test_table_{temp_table_id}] ([a] INTEGER, [b] INTEGER)');""", f"INSERT INTO [test_table] ([a], [b]) SELECT CAST([a] AS INTEGER) AS [a], CAST([b] AS INTEGER) AS [b] FROM [__temp_test_table_{temp_table_id}];", f"DROP TABLE IF EXISTS [__temp_test_table_{temp_table_id}];", ] @@ -304,7 +386,7 @@ def test_create_table(make_mocked_engine_adapter: t.Callable): adapter.create_table("test_table", columns_to_types) adapter.cursor.execute.assert_called_once_with( - """IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = 'test_table') EXEC('CREATE TABLE [test_table] ([cola] INTEGER, [colb] VARCHAR(MAX))');""" + """IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'test_table') EXEC('CREATE TABLE [test_table] ([cola] INTEGER, [colb] VARCHAR(MAX))');""" ) @@ -323,7 +405,7 @@ def test_create_physical_properties(make_mocked_engine_adapter: t.Callable): ) adapter.cursor.execute.assert_called_once_with( - """IF NOT EXISTS 
(SELECT * FROM information_schema.tables WHERE table_name = 'test_table') EXEC('CREATE TABLE [test_table] ([cola] INTEGER, [colb] VARCHAR(MAX))');""" + """IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'test_table') EXEC('CREATE TABLE [test_table] ([cola] INTEGER, [colb] VARCHAR(MAX))');""" ) @@ -343,10 +425,12 @@ def test_merge_pandas( temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id) df = pd.DataFrame({"id": [1, 2, 3], "ts": [1, 2, 3], "val": [4, 5, 6]}) + + # 1 key adapter.merge( target_table=table_name, source_table=df, - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("int"), "ts": exp.DataType.build("TIMESTAMP"), "val": exp.DataType.build("int"), @@ -358,18 +442,19 @@ def test_merge_pandas( ) assert to_sql_calls(adapter) == [ - f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INTEGER, [ts] DATETIME2, [val] INTEGER)');""", - f"MERGE INTO [target] AS [__MERGE_TARGET__] USING (SELECT CAST([id] AS INTEGER) AS [id], CAST([ts] AS DATETIME2) AS [ts], CAST([val] AS INTEGER) AS [val] FROM [__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id] WHEN MATCHED THEN UPDATE SET [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id], [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts], [__MERGE_TARGET__].[val] = [__MERGE_SOURCE__].[val] WHEN NOT MATCHED THEN INSERT ([id], [ts], [val]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts], [__MERGE_SOURCE__].[val]);", + f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INTEGER, [ts] DATETIME2, [val] INTEGER)');""", + f"MERGE INTO [target] AS [__MERGE_TARGET__] USING (SELECT CAST([id] AS INTEGER) AS [id], CAST([ts] AS DATETIME2) AS [ts], CAST([val] AS INTEGER) AS [val] 
FROM [__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id] WHEN MATCHED THEN UPDATE SET [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts], [__MERGE_TARGET__].[val] = [__MERGE_SOURCE__].[val] WHEN NOT MATCHED THEN INSERT ([id], [ts], [val]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts], [__MERGE_SOURCE__].[val]);", f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];", ] + # 2 keys adapter.cursor.reset_mock() adapter._connection_pool.get().reset_mock() temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id) adapter.merge( target_table=table_name, source_table=df, - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("int"), "ts": exp.DataType.build("TIMESTAMP"), "val": exp.DataType.build("int"), @@ -381,20 +466,103 @@ def test_merge_pandas( ) assert to_sql_calls(adapter) == [ - f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INTEGER, [ts] DATETIME2, [val] INTEGER)');""", - f"MERGE INTO [target] AS [__MERGE_TARGET__] USING (SELECT CAST([id] AS INTEGER) AS [id], CAST([ts] AS DATETIME2) AS [ts], CAST([val] AS INTEGER) AS [val] FROM [__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id] AND [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts] WHEN MATCHED THEN UPDATE SET [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id], [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts], [__MERGE_TARGET__].[val] = [__MERGE_SOURCE__].[val] WHEN NOT MATCHED THEN INSERT ([id], [ts], [val]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts], [__MERGE_SOURCE__].[val]);", + f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INTEGER, [ts] DATETIME2, [val] INTEGER)');""", + f"MERGE INTO [target] AS 
[__MERGE_TARGET__] USING (SELECT CAST([id] AS INTEGER) AS [id], CAST([ts] AS DATETIME2) AS [ts], CAST([val] AS INTEGER) AS [val] FROM [__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id] AND [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts] WHEN MATCHED THEN UPDATE SET [__MERGE_TARGET__].[val] = [__MERGE_SOURCE__].[val] WHEN NOT MATCHED THEN INSERT ([id], [ts], [val]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts], [__MERGE_SOURCE__].[val]);", f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];", ] -def test_replace_query(make_mocked_engine_adapter: t.Callable): +def test_merge_exists( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable +): + mocker.patch( + "sqlmesh.core.engine_adapter.mssql.MSSQLEngineAdapter.table_exists", + return_value=False, + ) + adapter = make_mocked_engine_adapter(MSSQLEngineAdapter) - adapter.cursor.fetchone.return_value = (1,) + + temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") + table_name = "target" + temp_table_id = "abcdefgh" + temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id) + + df = pd.DataFrame({"id": [1, 2, 3], "ts": [1, 2, 3], "val": [4, 5, 6]}) + + # regular implementation + adapter.merge( + target_table=table_name, + source_table=df, + target_columns_to_types={ + "id": exp.DataType.build("int"), + "ts": exp.DataType.build("TIMESTAMP"), + "val": exp.DataType.build("int"), + }, + unique_key=[exp.to_identifier("id")], + ) + + assert to_sql_calls(adapter) == [ + f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INTEGER, [ts] DATETIME2, [val] INTEGER)');""", + f"MERGE INTO [target] AS [__MERGE_TARGET__] USING (SELECT CAST([id] AS INTEGER) AS [id], CAST([ts] AS DATETIME2) AS [ts], CAST([val] AS INTEGER) AS [val] FROM 
[__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id] WHEN MATCHED THEN UPDATE SET [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts], [__MERGE_TARGET__].[val] = [__MERGE_SOURCE__].[val] WHEN NOT MATCHED THEN INSERT ([id], [ts], [val]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts], [__MERGE_SOURCE__].[val]);", + f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];", + ] + + # merge exists implementation + adapter.cursor.reset_mock() + adapter._connection_pool.get().reset_mock() + temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id) + adapter.merge( + target_table=table_name, + source_table=df, + target_columns_to_types={ + "id": exp.DataType.build("int"), + "ts": exp.DataType.build("TIMESTAMP"), + "val": exp.DataType.build("int"), + }, + unique_key=[exp.to_identifier("id")], + physical_properties={"mssql_merge_exists": True}, + ) + + assert to_sql_calls(adapter) == [ + f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INTEGER, [ts] DATETIME2, [val] INTEGER)');""", + f"MERGE INTO [target] AS [__MERGE_TARGET__] USING (SELECT CAST([id] AS INTEGER) AS [id], CAST([ts] AS DATETIME2) AS [ts], CAST([val] AS INTEGER) AS [val] FROM [__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id] WHEN MATCHED AND EXISTS(SELECT [__MERGE_TARGET__].[ts], [__MERGE_TARGET__].[val] EXCEPT SELECT [__MERGE_SOURCE__].[ts], [__MERGE_SOURCE__].[val]) THEN UPDATE SET [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts], [__MERGE_TARGET__].[val] = [__MERGE_SOURCE__].[val] WHEN NOT MATCHED THEN INSERT ([id], [ts], [val]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts], [__MERGE_SOURCE__].[val]);", + f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];", + ] + + # merge exists and all model columns are keys + 
adapter.cursor.reset_mock() + adapter._connection_pool.get().reset_mock() + temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id) + adapter.merge( + target_table=table_name, + source_table=df, + target_columns_to_types={ + "id": exp.DataType.build("int"), + "ts": exp.DataType.build("TIMESTAMP"), + }, + unique_key=[exp.to_identifier("id"), exp.to_column("ts")], + physical_properties={"mssql_merge_exists": True}, + ) + + assert to_sql_calls(adapter) == [ + f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '__temp_target_{temp_table_id}') EXEC('CREATE TABLE [__temp_target_{temp_table_id}] ([id] INTEGER, [ts] DATETIME2)');""", + f"MERGE INTO [target] AS [__MERGE_TARGET__] USING (SELECT CAST([id] AS INTEGER) AS [id], CAST([ts] AS DATETIME2) AS [ts] FROM [__temp_target_{temp_table_id}]) AS [__MERGE_SOURCE__] ON [__MERGE_TARGET__].[id] = [__MERGE_SOURCE__].[id] AND [__MERGE_TARGET__].[ts] = [__MERGE_SOURCE__].[ts] WHEN NOT MATCHED THEN INSERT ([id], [ts]) VALUES ([__MERGE_SOURCE__].[id], [__MERGE_SOURCE__].[ts]);", + f"DROP TABLE IF EXISTS [__temp_target_{temp_table_id}];", + ] + + +def test_replace_query(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(MSSQLEngineAdapter) + mocker.patch.object( + adapter, + "_get_data_objects", + return_value=[DataObject(schema="", name="test_table", type="table")], + ) adapter.replace_query("test_table", parse_one("SELECT a FROM tbl"), {"a": "int"}) assert to_sql_calls(adapter) == [ - """SELECT 1 FROM [information_schema].[tables] WHERE [table_name] = 'test_table';""", - "MERGE INTO [test_table] AS [__MERGE_TARGET__] USING (SELECT [a] AS [a] FROM [tbl]) AS [__MERGE_SOURCE__] ON (1 = 0) WHEN NOT MATCHED BY SOURCE THEN DELETE WHEN NOT MATCHED THEN INSERT ([a]) VALUES ([a]);", + "TRUNCATE TABLE [test_table];", + "INSERT INTO [test_table] ([a]) SELECT [a] FROM [tbl];", ] @@ -409,6 +577,11 @@ def test_replace_query_pandas( ) adapter = 
make_mocked_engine_adapter(MSSQLEngineAdapter) + mocker.patch.object( + adapter, + "_get_data_objects", + return_value=[DataObject(schema="", name="test_table", type="table")], + ) adapter.cursor.fetchone.return_value = (1,) temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") @@ -440,8 +613,9 @@ def temp_table_exists(table: exp.Table) -> bool: ) assert to_sql_calls(adapter) == [ - f"""IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = '{temp_table_name}') EXEC('CREATE TABLE [{temp_table_name}] ([a] INTEGER, [b] INTEGER)');""", - "MERGE INTO [test_table] AS [__MERGE_TARGET__] USING (SELECT CAST([a] AS INTEGER) AS [a], CAST([b] AS INTEGER) AS [b] FROM [__temp_test_table_abcdefgh]) AS [__MERGE_SOURCE__] ON (1 = 0) WHEN NOT MATCHED BY SOURCE THEN DELETE WHEN NOT MATCHED THEN INSERT ([a], [b]) VALUES ([a], [b]);", + f"""IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '{temp_table_name}') EXEC('CREATE TABLE [{temp_table_name}] ([a] INTEGER, [b] INTEGER)');""", + "TRUNCATE TABLE [test_table];", + f"INSERT INTO [test_table] ([a], [b]) SELECT CAST([a] AS INTEGER) AS [a], CAST([b] AS INTEGER) AS [b] FROM [{temp_table_name}];", f"DROP TABLE IF EXISTS [{temp_table_name}];", ] @@ -456,7 +630,7 @@ def test_create_table_primary_key(make_mocked_engine_adapter: t.Callable): adapter.create_table("test_table", columns_to_types, primary_key=("cola", "colb")) adapter.cursor.execute.assert_called_once_with( - """IF NOT EXISTS (SELECT * FROM information_schema.tables WHERE table_name = 'test_table') EXEC('CREATE TABLE [test_table] ([cola] INTEGER, [colb] VARCHAR(MAX), PRIMARY KEY ([cola], [colb]))');""" + """IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'test_table') EXEC('CREATE TABLE [test_table] ([cola] INTEGER, [colb] VARCHAR(MAX), PRIMARY KEY ([cola], [colb]))');""" ) @@ -485,7 +659,7 @@ def test_drop_schema_with_catalog(make_mocked_engine_adapter: t.Callable, mocker 
def test_get_data_objects_catalog(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): - adapter = make_mocked_engine_adapter(MSSQLEngineAdapter) + adapter = make_mocked_engine_adapter(MSSQLEngineAdapter, patch_get_data_objects=False) original_set_current_catalog = adapter.set_current_catalog local_state = {} @@ -540,6 +714,36 @@ def test_drop_schema(make_mocked_engine_adapter: t.Callable): ] +def test_drop_schema_with_special_identifiers(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(MSSQLEngineAdapter) + + adapter._get_data_objects = mock.Mock() + adapter._get_data_objects.return_value = [ + DataObject( + catalog="test_catalog", + schema="test schema", # Schema with space + name="test view", # Object with space + type=DataObjectType.from_str("VIEW"), + ), + DataObject( + catalog="test_catalog", + schema="test schema", + name="test table", # Table with space + type=DataObjectType.from_str("TABLE"), + ), + ] + + schema_name = exp.to_table("[test schema]", dialect="tsql") + adapter.drop_schema(schema_name, cascade=True) + + # Validate that names with spaces/special chars are properly quoted with square brackets + assert to_sql_calls(adapter) == [ + """DROP VIEW IF EXISTS [test schema].[test view];""", + """DROP TABLE IF EXISTS [test schema].[test table];""", + """DROP SCHEMA IF EXISTS [test schema];""", + ] + + def test_df_dates(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(MSSQLEngineAdapter) @@ -620,7 +824,6 @@ def test_create_table_from_query(make_mocked_engine_adapter: t.Callable, mocker: ), exists=False, ) - assert to_sql_calls(adapter) == [ "CREATE VIEW [__temp_ctas_test_random_id] AS SELECT [a], [b], [x] + 1 AS [c], [d] AS [d], [e] FROM (SELECT * FROM [table]);", "DROP VIEW IF EXISTS [__temp_ctas_test_random_id];", @@ -628,3 +831,174 @@ def test_create_table_from_query(make_mocked_engine_adapter: t.Callable, mocker: ] 
columns_mock.assert_called_once_with(exp.table_("__temp_ctas_test_random_id", quoted=True)) + + # We don't want to drop anything other than LIMIT 0 + # See https://github.com/SQLMesh/sqlmesh/issues/4048 + adapter.ctas( + table_name="test_schema.test_table", + query_or_df=parse_one( + "SELECT * FROM (SELECT * FROM t WHERE FALSE LIMIT 1) WHERE FALSE LIMIT 0" + ), + exists=False, + ) + assert ( + "CREATE VIEW [__temp_ctas_test_random_id] AS SELECT * FROM (SELECT TOP 1 * FROM [t]);" + in to_sql_calls(adapter) + ) + + +def test_replace_query_strategy(adapter: MSSQLEngineAdapter, mocker: MockerFixture): + # ref issue 4472: https://github.com/SQLMesh/sqlmesh/issues/4472 + # The FULL strategy calls EngineAdapter.replace_query() which calls _insert_overwrite_by_condition() should use DELETE+INSERT and not MERGE + expressions = d.parse( + f""" + MODEL ( + name db.table, + kind FULL, + dialect tsql + ); + + select a, b from db.upstream_table; + """ + ) + model = load_sql_based_model(expressions) + + exists_mock = mocker.patch( + "sqlmesh.core.engine_adapter.mssql.MSSQLEngineAdapter.table_exists", + return_value=False, + ) + + assert not adapter.table_exists("test_table") + + # initial - table doesnt exist + adapter.replace_query( + "test_table", + model.render_query_or_raise(), + table_format=model.table_format, + storage_format=model.storage_format, + partitioned_by=model.partitioned_by, + partition_interval_unit=model.partition_interval_unit, + clustered_by=model.clustered_by, + table_properties=model.physical_properties, + table_description=model.description, + column_descriptions=model.column_descriptions, + target_columns_to_types=model.columns_to_types_or_raise, + ) + + # subsequent - table exists + exists_mock.return_value = True + assert adapter.table_exists("test_table") + + mocker.patch.object( + adapter, + "_get_data_objects", + return_value=[DataObject(schema="", name="test_table", type="table")], + ) + + adapter.replace_query( + "test_table", + 
model.render_query_or_raise(), + table_format=model.table_format, + storage_format=model.storage_format, + partitioned_by=model.partitioned_by, + partition_interval_unit=model.partition_interval_unit, + clustered_by=model.clustered_by, + table_properties=model.physical_properties, + table_description=model.description, + column_descriptions=model.column_descriptions, + target_columns_to_types=model.columns_to_types_or_raise, + ) + + assert to_sql_calls(adapter) == [ + # initial - create table if not exists + "IF NOT EXISTS (SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'test_table') EXEC('SELECT * INTO [test_table] FROM (SELECT [a] AS [a], [b] AS [b] FROM [db].[upstream_table] AS [upstream_table]) AS temp');", + # subsequent - truncate + insert + "TRUNCATE TABLE [test_table];", + "INSERT INTO [test_table] ([a], [b]) SELECT [a] AS [a], [b] AS [b] FROM [db].[upstream_table] AS [upstream_table];", + ] + + +def test_mssql_merge_exists_switches_strategy_from_truncate_to_merge( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(MSSQLEngineAdapter) + + query = exp.select("*").from_("source") + source_queries = [SourceQuery(query_factory=lambda: query)] + + # Test WITHOUT mssql_merge_exists, should use DELETE+INSERT strategy + base_insert_overwrite = mocker.patch( + "sqlmesh.core.engine_adapter.base.EngineAdapter._insert_overwrite_by_condition" + ) + + adapter._insert_overwrite_by_condition( + table_name="target", + source_queries=source_queries, + target_columns_to_types={ + "id": exp.DataType.build("INT"), + "value": exp.DataType.build("VARCHAR"), + }, + where=None, + ) + + # Should call base DELETE+INSERT strategy + assert base_insert_overwrite.called + base_insert_overwrite.reset_mock() + + # Test WITH mssql_merge_exists uses MERGE strategy + super_insert_overwrite = mocker.patch( + "sqlmesh.core.engine_adapter.base.EngineAdapterWithIndexSupport._insert_overwrite_by_condition" + ) + + 
adapter._insert_overwrite_by_condition( + table_name="target", + source_queries=source_queries, + target_columns_to_types={ + "id": exp.DataType.build("INT"), + "value": exp.DataType.build("VARCHAR"), + }, + where=None, + table_properties={"mssql_merge_exists": True}, + ) + + # Should call super's MERGE strategy, not base DELETE+INSERT + assert super_insert_overwrite.called + assert not base_insert_overwrite.called + + +def test_python_scd2_model_preserves_physical_properties(make_snapshot): + @model( + "test_schema.python_scd2_with_mssql_merge", + kind=SCDType2ByTimeKind( + unique_key=["id"], + valid_from_name="valid_from", + valid_to_name="valid_to", + updated_at_name="updated_at", + ), + columns={ + "id": "INT", + "value": "VARCHAR", + "updated_at": "TIMESTAMP", + "valid_from": "TIMESTAMP", + "valid_to": "TIMESTAMP", + }, + physical_properties={"mssql_merge_exists": True}, + ) + def python_scd2_model(context, **kwargs): + import pandas as pd + + return pd.DataFrame( + {"id": [1, 2], "value": ["a", "b"], "updated_at": ["2024-01-01", "2024-01-02"]} + ) + + m = model.get_registry()["test_schema.python_scd2_with_mssql_merge"].model( + module_path=Path("."), + path=Path("."), + dialect="tsql", + ) + + # verify model has physical_properties that trigger merge strategy + assert "mssql_merge_exists" in m.physical_properties + snapshot: Snapshot = make_snapshot(m) + assert snapshot.node.physical_properties == m.physical_properties + assert snapshot.node.physical_properties.get("mssql_merge_exists") diff --git a/tests/core/engine_adapter/test_mysql.py b/tests/core/engine_adapter/test_mysql.py index f00c533262..f9fe140892 100644 --- a/tests/core/engine_adapter/test_mysql.py +++ b/tests/core/engine_adapter/test_mysql.py @@ -75,3 +75,12 @@ def test_pre_ping(mocker: MockerFixture, make_mocked_engine_adapter: t.Callable) ] adapter._connection_pool.get().ping.assert_called_once_with(reconnect=False) + + +def test_create_table_like(make_mocked_engine_adapter: t.Callable): + 
adapter = make_mocked_engine_adapter(MySQLEngineAdapter) + + adapter.create_table_like("target_table", "source_table") + adapter.cursor.execute.assert_called_once_with( + "CREATE TABLE IF NOT EXISTS `target_table` LIKE `source_table`" + ) diff --git a/tests/core/engine_adapter/test_postgres.py b/tests/core/engine_adapter/test_postgres.py index b320385a1e..ebcdd03f55 100644 --- a/tests/core/engine_adapter/test_postgres.py +++ b/tests/core/engine_adapter/test_postgres.py @@ -3,10 +3,11 @@ import pytest from pytest_mock import MockFixture from pytest_mock.plugin import MockerFixture -from sqlglot import exp +from sqlglot import exp, parse_one from sqlglot.helper import ensure_list from sqlmesh.core.engine_adapter import PostgresEngineAdapter +from sqlmesh.utils.errors import SQLMeshError from tests.core.engine_adapter import to_sql_calls pytestmark = [pytest.mark.engine, pytest.mark.postgres] @@ -53,15 +54,15 @@ def test_drop_schema(kwargs, expected, make_mocked_engine_adapter: t.Callable): assert to_sql_calls(adapter) == ensure_list(expected) -def test_drop_schema_with_catalog( - make_mocked_engine_adapter: t.Callable, mocker: MockFixture, caplog -): +def test_drop_schema_with_catalog(make_mocked_engine_adapter: t.Callable, mocker: MockFixture): adapter = make_mocked_engine_adapter(PostgresEngineAdapter) adapter.get_current_catalog = mocker.MagicMock(return_value="other_catalog") - adapter.drop_schema("test_catalog.test_schema") - assert "requires that all catalog operations be against a single catalog" in caplog.text + with pytest.raises( + SQLMeshError, match="requires that all catalog operations be against a single catalog" + ): + adapter.drop_schema("test_catalog.test_schema") def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): @@ -80,3 +81,204 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture) """COMMENT ON TABLE "test_table" IS '\\'""", """COMMENT ON COLUMN "test_table"."a" IS '\\'""", ] + + +def 
test_create_table_like(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(PostgresEngineAdapter) + + adapter.create_table_like("target_table", "source_table") + adapter.cursor.execute.assert_called_once_with( + 'CREATE TABLE IF NOT EXISTS "target_table" (LIKE "source_table" INCLUDING ALL)' + ) + + +def test_merge_version_gte_15(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(PostgresEngineAdapter) + adapter.server_version = (15, 0) + + adapter.merge( + target_table="target", + source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')), + target_columns_to_types={ + "ID": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + "val": exp.DataType.build("int"), + }, + unique_key=[exp.to_identifier("ID", quoted=True)], + ) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + """MERGE INTO "target" AS "__MERGE_TARGET__" USING (SELECT "ID", "ts", "val" FROM "source") AS "__MERGE_SOURCE__" ON "__MERGE_TARGET__"."ID" = "__MERGE_SOURCE__"."ID" WHEN MATCHED THEN UPDATE SET "ID" = "__MERGE_SOURCE__"."ID", "ts" = "__MERGE_SOURCE__"."ts", "val" = "__MERGE_SOURCE__"."val" WHEN NOT MATCHED THEN INSERT ("ID", "ts", "val") VALUES ("__MERGE_SOURCE__"."ID", "__MERGE_SOURCE__"."ts", "__MERGE_SOURCE__"."val")""" + ] + + +def test_merge_version_lt_15( + make_mocked_engine_adapter: t.Callable, make_temp_table_name: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(PostgresEngineAdapter) + adapter.server_version = (14, 0) + + temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") + table_name = "test" + temp_table_id = "abcdefgh" + temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id) + + adapter.merge( + target_table="target", + source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')), + target_columns_to_types={ + "ID": exp.DataType.build("int"), + "ts": 
exp.DataType.build("timestamp"), + "val": exp.DataType.build("int"), + }, + unique_key=[exp.to_identifier("ID", quoted=True)], + ) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + 'CREATE TABLE "__temp_test_abcdefgh" AS SELECT CAST("ID" AS INT) AS "ID", CAST("ts" AS TIMESTAMP) AS "ts", CAST("val" AS INT) AS "val" FROM (SELECT "ID", "ts", "val" FROM "source") AS "_subquery"', + 'DELETE FROM "target" WHERE "ID" IN (SELECT "ID" FROM "__temp_test_abcdefgh")', + 'INSERT INTO "target" ("ID", "ts", "val") SELECT DISTINCT ON ("ID") "ID", "ts", "val" FROM "__temp_test_abcdefgh"', + 'DROP TABLE IF EXISTS "__temp_test_abcdefgh"', + ] + + +def test_alter_table_drop_column_cascade(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(PostgresEngineAdapter) + + current_table_name = "test_table" + target_table_name = "target_table" + + def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: + if table_name == current_table_name: + return {"id": exp.DataType.build("int"), "test_column": exp.DataType.build("int")} + return {"id": exp.DataType.build("int")} + + adapter.columns = table_columns + + adapter.alter_table(adapter.get_alter_operations(current_table_name, target_table_name)) + assert to_sql_calls(adapter) == [ + 'ALTER TABLE "test_table" DROP COLUMN "test_column" CASCADE', + ] + + +def test_server_version(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(PostgresEngineAdapter) + + fetchone_mock = mocker.patch.object(adapter, "fetchone") + fetchone_mock.return_value = ("14.0",) + assert adapter.server_version == (14, 0) + + del adapter.server_version + fetchone_mock.return_value = ("15.8",) + assert adapter.server_version == (15, 8) + + del adapter.server_version + fetchone_mock.return_value = ("15.13 (Debian 15.13-1.pgdg120+1)",) + assert adapter.server_version == (15, 13) + + +def test_sync_grants_config(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + 
adapter = make_mocked_engine_adapter(PostgresEngineAdapter) + relation = exp.to_table("test_schema.test_table", dialect="postgres") + new_grants_config = {"SELECT": ["user1", "user2"], "INSERT": ["user3"]} + + current_grants = [("SELECT", "old_user"), ("UPDATE", "admin_user")] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="postgres") + + assert executed_sql == ( + "SELECT privilege_type, grantee FROM information_schema.role_table_grants " + "WHERE table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = current_role AND grantee <> current_role" + ) + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 4 + + assert 'GRANT SELECT ON "test_schema"."test_table" TO "user1", "user2"' in sql_calls + assert 'GRANT INSERT ON "test_schema"."test_table" TO "user3"' in sql_calls + assert 'REVOKE SELECT ON "test_schema"."test_table" FROM "old_user"' in sql_calls + assert 'REVOKE UPDATE ON "test_schema"."test_table" FROM "admin_user"' in sql_calls + + +def test_sync_grants_config_with_overlaps( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(PostgresEngineAdapter) + relation = exp.to_table("test_schema.test_table", dialect="postgres") + new_grants_config = {"SELECT": ["user1", "user2", "user3"], "INSERT": ["user2", "user4"]} + + current_grants = [ + ("SELECT", "user1"), + ("SELECT", "user5"), + ("INSERT", "user2"), + ("UPDATE", "user3"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="postgres") + + assert executed_sql == ( + 
"SELECT privilege_type, grantee FROM information_schema.role_table_grants " + "WHERE table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = current_role AND grantee <> current_role" + ) + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 4 + + assert 'GRANT SELECT ON "test_schema"."test_table" TO "user2", "user3"' in sql_calls + assert 'GRANT INSERT ON "test_schema"."test_table" TO "user4"' in sql_calls + assert 'REVOKE SELECT ON "test_schema"."test_table" FROM "user5"' in sql_calls + assert 'REVOKE UPDATE ON "test_schema"."test_table" FROM "user3"' in sql_calls + + +def test_diff_grants_configs(make_mocked_engine_adapter: t.Callable): + new_grants = {"select": ["USER1", "USER2"], "insert": ["user3"]} + old_grants = {"SELECT": ["user1", "user4"], "UPDATE": ["user5"]} + + adapter = make_mocked_engine_adapter(PostgresEngineAdapter) + additions, removals = adapter._diff_grants_configs(new_grants, old_grants) + + assert additions["select"] == ["USER2"] + assert additions["insert"] == ["user3"] + + assert removals["SELECT"] == ["user4"] + assert removals["UPDATE"] == ["user5"] + + +def test_sync_grants_config_with_default_schema( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(PostgresEngineAdapter) + relation = exp.to_table("test_table", dialect="postgres") # No schema + new_grants_config = {"SELECT": ["user1"], "INSERT": ["user2"]} + + currrent_grants = [("UPDATE", "old_user")] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=currrent_grants) + get_schema_mock = mocker.patch.object(adapter, "_get_current_schema", return_value="public") + + adapter.sync_grants_config(relation, new_grants_config) + + get_schema_mock.assert_called_once() + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="postgres") + + assert executed_sql == ( + "SELECT privilege_type, grantee FROM 
information_schema.role_table_grants " + "WHERE table_schema = 'public' AND table_name = 'test_table' " + "AND grantor = current_role AND grantee <> current_role" + ) diff --git a/tests/core/engine_adapter/test_redshift.py b/tests/core/engine_adapter/test_redshift.py index b7990be043..5438943556 100644 --- a/tests/core/engine_adapter/test_redshift.py +++ b/tests/core/engine_adapter/test_redshift.py @@ -1,13 +1,16 @@ # type: ignore import typing as t -import pandas as pd +import pandas as pd # noqa: TID253 import pytest from pytest_mock.plugin import MockerFixture +from unittest.mock import PropertyMock from sqlglot import expressions as exp from sqlglot import parse_one from sqlmesh.core.engine_adapter import RedshiftEngineAdapter +from sqlmesh.core.engine_adapter.shared import DataObject, DataObjectType +from sqlmesh.utils.errors import SQLMeshError from tests.core.engine_adapter import to_sql_calls pytestmark = [pytest.mark.engine, pytest.mark.redshift] @@ -24,7 +27,7 @@ def test_columns(adapter: t.Callable): adapter.cursor.fetchall.return_value = [("col", "INT")] resp = adapter.columns("db.table") adapter.cursor.execute.assert_called_once_with( - """SELECT "column_name", "data_type" FROM "svv_columns" WHERE "table_name" = 'table' AND "table_schema" = 'db'""" + """SELECT "column_name", "data_type", "character_maximum_length", "numeric_precision", "numeric_scale" FROM "svv_columns" WHERE "table_name" = 'table' AND "table_schema" = 'db'""" ) assert resp == {"col": exp.DataType.build("INT")} @@ -74,12 +77,160 @@ def test_varchar_size_workaround(make_mocked_engine_adapter: t.Callable, mocker: ) assert to_sql_calls(adapter) == [ - 'CREATE VIEW "__temp_ctas_test_random_id" AS SELECT "char", "char1" + 1 AS "char1", "char2" AS "char2", "varchar", "varchar256", "varchar2" FROM (SELECT * FROM "table")', + 'CREATE VIEW "__temp_ctas_test_random_id" AS SELECT "char", "char1" + 1 AS "char1", "char2" AS "char2", "varchar", "varchar256", "varchar2" FROM (SELECT * FROM "table") 
WITH NO SCHEMA BINDING', 'DROP VIEW IF EXISTS "__temp_ctas_test_random_id" CASCADE', 'CREATE TABLE "test_schema"."test_table" ("char" CHAR, "char1" CHAR(max), "char2" CHAR(2), "varchar" VARCHAR, "varchar256" VARCHAR(max), "varchar2" VARCHAR(2))', ] +def test_sync_grants_config(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(RedshiftEngineAdapter) + relation = exp.to_table("test_schema.test_table", dialect="redshift") + new_grants_config = {"SELECT": ["user1", "user2"], "INSERT": ["user3"]} + + current_grants = [("SELECT", "old_user"), ("UPDATE", "legacy_user")] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="redshift") + expected_sql = ( + "SELECT privilege_type, grantee FROM information_schema.table_privileges " + "WHERE table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = CURRENT_USER AND grantee <> CURRENT_USER" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 4 + assert 'REVOKE SELECT ON "test_schema"."test_table" FROM "old_user"' in sql_calls + assert 'REVOKE UPDATE ON "test_schema"."test_table" FROM "legacy_user"' in sql_calls + assert 'GRANT SELECT ON "test_schema"."test_table" TO "user1", "user2"' in sql_calls + assert 'GRANT INSERT ON "test_schema"."test_table" TO "user3"' in sql_calls + + +def test_sync_grants_config_with_overlaps( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(RedshiftEngineAdapter) + relation = exp.to_table("test_schema.test_table", dialect="redshift") + new_grants_config = { + "SELECT": ["user_shared", "user_new"], + "INSERT": ["user_shared", "user_writer"], + } + + current_grants = [ + ("SELECT", "user_shared"), 
+ ("SELECT", "user_legacy"), + ("INSERT", "user_shared"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="redshift") + expected_sql = ( + "SELECT privilege_type, grantee FROM information_schema.table_privileges " + "WHERE table_schema = 'test_schema' AND table_name = 'test_table' " + "AND grantor = CURRENT_USER AND grantee <> CURRENT_USER" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 3 + assert 'REVOKE SELECT ON "test_schema"."test_table" FROM "user_legacy"' in sql_calls + assert 'GRANT SELECT ON "test_schema"."test_table" TO "user_new"' in sql_calls + assert 'GRANT INSERT ON "test_schema"."test_table" TO "user_writer"' in sql_calls + + +@pytest.mark.parametrize( + "table_type", + [ + (DataObjectType.TABLE), + (DataObjectType.VIEW), + (DataObjectType.MATERIALIZED_VIEW), + ], +) +def test_sync_grants_config_object_kind( + make_mocked_engine_adapter: t.Callable, + mocker: MockerFixture, + table_type: DataObjectType, +) -> None: + adapter = make_mocked_engine_adapter(RedshiftEngineAdapter) + relation = exp.to_table("test_schema.test_object", dialect="redshift") + + mocker.patch.object(adapter, "fetchall", return_value=[]) + + adapter.sync_grants_config(relation, {"SELECT": ["user_test"]}, table_type) + + sql_calls = to_sql_calls(adapter) + # we don't need to explicitly specify object_type for tables and views + assert sql_calls == [f'GRANT SELECT ON "test_schema"."test_object" TO "user_test"'] + + +def test_sync_grants_config_quotes(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(RedshiftEngineAdapter) + relation = exp.to_table('"TestSchema"."TestTable"', dialect="redshift") + new_grants_config = {"SELECT": ["user1", 
"user2"], "INSERT": ["user3"]} + + current_grants = [("SELECT", "user_old"), ("UPDATE", "user_legacy")] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="redshift") + expected_sql = ( + "SELECT privilege_type, grantee FROM information_schema.table_privileges " + "WHERE table_schema = 'TestSchema' AND table_name = 'TestTable' " + "AND grantor = CURRENT_USER AND grantee <> CURRENT_USER" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 4 + assert 'REVOKE SELECT ON "TestSchema"."TestTable" FROM "user_old"' in sql_calls + assert 'REVOKE UPDATE ON "TestSchema"."TestTable" FROM "user_legacy"' in sql_calls + assert 'GRANT SELECT ON "TestSchema"."TestTable" TO "user1", "user2"' in sql_calls + assert 'GRANT INSERT ON "TestSchema"."TestTable" TO "user3"' in sql_calls + + +def test_sync_grants_config_no_schema( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(RedshiftEngineAdapter) + relation = exp.to_table("test_table", dialect="redshift") + new_grants_config = {"SELECT": ["user1"], "INSERT": ["user2"]} + + current_grants = [("UPDATE", "user_old")] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + get_schema_mock = mocker.patch.object(adapter, "_get_current_schema", return_value="public") + + adapter.sync_grants_config(relation, new_grants_config) + + get_schema_mock.assert_called_once() + + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="redshift") + expected_sql = ( + "SELECT privilege_type, grantee FROM information_schema.table_privileges " + "WHERE table_schema = 'public' AND table_name = 'test_table' " + "AND grantor = CURRENT_USER AND grantee <> 
CURRENT_USER" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 3 + assert 'REVOKE UPDATE ON "test_table" FROM "user_old"' in sql_calls + assert 'GRANT SELECT ON "test_table" TO "user1"' in sql_calls + assert 'GRANT INSERT ON "test_table" TO "user2"' in sql_calls + + def test_create_table_from_query_exists_no_if_not_exists( adapter: t.Callable, mocker: MockerFixture ): @@ -113,7 +264,47 @@ def test_create_table_from_query_exists_no_if_not_exists( ) assert to_sql_calls(adapter) == [ - 'CREATE VIEW "__temp_ctas_test_random_id" AS SELECT "a", "b", "x" + 1 AS "c", "d" AS "d", "e" FROM (SELECT * FROM "table")', + 'CREATE VIEW "__temp_ctas_test_random_id" AS SELECT "a", "b", "x" + 1 AS "c", "d" AS "d", "e" FROM (SELECT * FROM "table") WITH NO SCHEMA BINDING', + 'DROP VIEW IF EXISTS "__temp_ctas_test_random_id" CASCADE', + 'CREATE TABLE "test_schema"."test_table" ("a" VARCHAR(MAX), "b" VARCHAR(60), "c" VARCHAR(MAX), "d" VARCHAR(MAX), "e" TIMESTAMP)', + ] + + columns_mock.assert_called_once_with(exp.table_("__temp_ctas_test_random_id", quoted=True)) + + +def test_create_table_recursive_cte(adapter: t.Callable, mocker: MockerFixture): + mocker.patch( + "sqlmesh.core.engine_adapter.base.random_id", + return_value="test_random_id", + ) + + mocker.patch( + "sqlmesh.core.engine_adapter.redshift.RedshiftEngineAdapter.table_exists", + return_value=True, + ) + + columns_mock = mocker.patch( + "sqlmesh.core.engine_adapter.redshift.RedshiftEngineAdapter.columns", + return_value={ + "a": exp.DataType.build("VARCHAR(MAX)", dialect="redshift"), + "b": exp.DataType.build("VARCHAR(60)", dialect="redshift"), + "c": exp.DataType.build("VARCHAR(MAX)", dialect="redshift"), + "d": exp.DataType.build("VARCHAR(MAX)", dialect="redshift"), + "e": exp.DataType.build("TIMESTAMP", dialect="redshift"), + }, + ) + + adapter.ctas( + table_name="test_schema.test_table", + query_or_df=parse_one( + "WITH RECURSIVE cte AS (SELECT * FROM table WHERE 
FALSE LIMIT 0) SELECT a, b, x + 1 AS c, d AS d, e FROM cte WHERE d > 0 AND FALSE LIMIT 0", + dialect="redshift", + ), + exists=False, + ) + + assert to_sql_calls(adapter) == [ + 'CREATE VIEW "__temp_ctas_test_random_id" AS WITH RECURSIVE "cte" AS (SELECT * FROM "table") SELECT "a", "b", "x" + 1 AS "c", "d" AS "d", "e" FROM "cte"', 'DROP VIEW IF EXISTS "__temp_ctas_test_random_id" CASCADE', 'CREATE TABLE "test_schema"."test_table" ("a" VARCHAR(MAX), "b" VARCHAR(60), "c" VARCHAR(MAX), "d" VARCHAR(MAX), "e" TIMESTAMP)', ] @@ -177,7 +368,7 @@ def test_values_to_sql(adapter: t.Callable, mocker: MockerFixture): df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) result = adapter._values_to_sql( values=list(df.itertuples(index=False, name=None)), - columns_to_types={"a": exp.DataType.build("int"), "b": exp.DataType.build("int")}, + target_columns_to_types={"a": exp.DataType.build("int"), "b": exp.DataType.build("int")}, batch_start=0, batch_end=2, ) @@ -220,11 +411,16 @@ def mock_table(*args, **kwargs): mock_temp_table = mocker.MagicMock(side_effect=mock_table) mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table", mock_temp_table) + mocker.patch.object( + adapter, + "_get_data_objects", + return_value=[DataObject(schema="", name="test_table", type="table")], + ) adapter.replace_query( table_name="test_table", query_or_df=df, - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "b": exp.DataType.build("int"), }, @@ -251,7 +447,7 @@ def test_replace_query_with_df_table_not_exists(adapter: t.Callable, mocker: Moc adapter.replace_query( table_name="test_table", query_or_df=df, - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "b": exp.DataType.build("int"), }, @@ -294,7 +490,7 @@ def test_create_view(adapter: t.Callable): adapter.create_view( view_name="test_view", query_or_df=parse_one("SELECT cola FROM table"), - columns_to_types={ + target_columns_to_types={ "a": exp.DataType.build("int"), "b": 
exp.DataType.build("int"), }, @@ -304,3 +500,211 @@ def test_create_view(adapter: t.Callable): 'DROP VIEW IF EXISTS "test_view" CASCADE', 'CREATE VIEW "test_view" ("a", "b") AS SELECT "cola" FROM "table" WITH NO SCHEMA BINDING', ] + + +def test_alter_table_drop_column_cascade(adapter: t.Callable): + current_table_name = "test_table" + target_table_name = "target_table" + + def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: + if table_name == current_table_name: + return {"id": exp.DataType.build("int"), "test_column": exp.DataType.build("int")} + return {"id": exp.DataType.build("int")} + + adapter.columns = table_columns + + adapter.alter_table(adapter.get_alter_operations(current_table_name, target_table_name)) + assert to_sql_calls(adapter) == [ + 'ALTER TABLE "test_table" DROP COLUMN "test_column" CASCADE', + ] + + +def test_alter_table_precision_increase_varchar(adapter: t.Callable): + current_table_name = "test_table" + target_table_name = "target_table" + + def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: + if table_name == current_table_name: + return { + "id": exp.DataType.build("int"), + "test_column": exp.DataType.build("VARCHAR(10)"), + } + return { + "id": exp.DataType.build("int"), + "test_column": exp.DataType.build("VARCHAR(20)"), + } + + adapter.columns = table_columns + + adapter.alter_table(adapter.get_alter_operations(current_table_name, target_table_name)) + assert to_sql_calls(adapter) == [ + 'ALTER TABLE "test_table" ALTER COLUMN "test_column" TYPE VARCHAR(20)', + ] + + +def test_alter_table_precision_increase_decimal(adapter: t.Callable): + current_table_name = "test_table" + target_table_name = "target_table" + + def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: + if table_name == current_table_name: + return { + "id": exp.DataType.build("int"), + "test_column": exp.DataType.build("DECIMAL(10, 10)"), + } + return { + "id": exp.DataType.build("int"), + "test_column": 
exp.DataType.build("DECIMAL(25, 10)"), + } + + adapter.columns = table_columns + + adapter.alter_table(adapter.get_alter_operations(current_table_name, target_table_name)) + assert to_sql_calls(adapter) == [ + 'ALTER TABLE "test_table" DROP COLUMN "test_column" CASCADE', + 'ALTER TABLE "test_table" ADD COLUMN "test_column" DECIMAL(25, 10)', + ] + + +def test_merge(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(RedshiftEngineAdapter) + mocker.patch( + "sqlmesh.core.engine_adapter.redshift.RedshiftEngineAdapter.enable_merge", + new_callable=PropertyMock(return_value=True), + ) + + adapter.merge( + target_table=exp.to_table("target_table_name"), + source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')), + target_columns_to_types={ + "ID": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + "val": exp.DataType.build("int"), + }, + unique_key=[exp.to_identifier("ID", quoted=True)], + ) + + # Test additional predicates in the merge_filter + adapter.merge( + target_table=exp.to_table("target_table_name"), + source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts, val FROM source')), + target_columns_to_types={ + "ID": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + "val": exp.DataType.build("int"), + }, + unique_key=[exp.to_identifier("ID", quoted=True)], + merge_filter=exp.and_( + exp.and_(exp.column("ID", "__MERGE_SOURCE__") > 0), + exp.column("ts", "__MERGE_TARGET__") < exp.column("ts", "__MERGE_SOURCE__"), + ), + ) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + 'MERGE INTO "target_table_name" USING (SELECT "ID", "ts", "val" FROM "source") AS "__MERGE_SOURCE__" ON "target_table_name"."ID" = "__MERGE_SOURCE__"."ID" WHEN MATCHED THEN UPDATE SET "ID" = "__MERGE_SOURCE__"."ID", "ts" = "__MERGE_SOURCE__"."ts", "val" = "__MERGE_SOURCE__"."val" WHEN NOT MATCHED THEN INSERT ("ID", "ts", "val") VALUES ("__MERGE_SOURCE__"."ID", 
"__MERGE_SOURCE__"."ts", "__MERGE_SOURCE__"."val")', + 'MERGE INTO "target_table_name" USING (SELECT "ID", "ts", "val" FROM "source") AS "__MERGE_SOURCE__" ON ("__MERGE_SOURCE__"."ID" > 0 AND "target_table_name"."ts" < "__MERGE_SOURCE__"."ts") AND "target_table_name"."ID" = "__MERGE_SOURCE__"."ID" WHEN MATCHED THEN UPDATE SET "ID" = "__MERGE_SOURCE__"."ID", "ts" = "__MERGE_SOURCE__"."ts", "val" = "__MERGE_SOURCE__"."val" WHEN NOT MATCHED THEN INSERT ("ID", "ts", "val") VALUES ("__MERGE_SOURCE__"."ID", "__MERGE_SOURCE__"."ts", "__MERGE_SOURCE__"."val")', + ] + + +def test_merge_when_matched_error(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(RedshiftEngineAdapter) + mocker.patch( + "sqlmesh.core.engine_adapter.redshift.RedshiftEngineAdapter.enable_merge", + new_callable=PropertyMock(return_value=True), + ) + + with pytest.raises( + SQLMeshError, + match=r".*Redshift only supports a single WHEN MATCHED and WHEN NOT MATCHED clause*", + ): + adapter.merge( + target_table=exp.to_table("target_table_name"), + source_table=t.cast(exp.Select, parse_one('SELECT "ID", val FROM source')), + target_columns_to_types={ + "ID": exp.DataType.build("int"), + "val": exp.DataType.build("int"), + }, + unique_key=[exp.to_identifier("ID", quoted=True)], + when_matched=exp.Whens( + expressions=[ + exp.When( + matched=True, + condition=exp.column("ID", "__MERGE_SOURCE__").eq(exp.Literal.number(1)), + then=exp.Update( + expressions=[ + exp.column("val", "__MERGE_TARGET__").eq( + exp.column("val", "__MERGE_SOURCE__") + ), + ], + ), + ), + exp.When( + matched=True, + source=False, + then=exp.Update( + expressions=[ + exp.column("val", "__MERGE_TARGET__").eq( + exp.column("val", "__MERGE_SOURCE__") + ), + ], + ), + ), + ] + ), + ) + + +def test_merge_logical_filter_error(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(RedshiftEngineAdapter) + mocker.patch( + 
"sqlmesh.core.engine_adapter.redshift.RedshiftEngineAdapter.enable_merge", + new_callable=PropertyMock(return_value=False), + ) + + with pytest.raises( + SQLMeshError, + match=r".*This engine does not support MERGE expressions and therefore `merge_filter` is not supported.*", + ): + adapter.merge( + target_table=exp.to_table("target_table_name_2"), + source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts FROM source')), + target_columns_to_types={ + "ID": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + }, + unique_key=[exp.to_identifier("ID", quoted=True)], + merge_filter=exp.and_( + exp.and_(exp.column("ID", "__MERGE_SOURCE__") > 0), + exp.column("ts", "__MERGE_TARGET__") < exp.column("ts", "__MERGE_SOURCE__"), + ), + ) + + +def test_merge_logical( + make_mocked_engine_adapter: t.Callable, make_temp_table_name: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(RedshiftEngineAdapter) + + temp_table_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter._get_temp_table") + table_name = "test" + temp_table_id = "abcdefgh" + temp_table_mock.return_value = make_temp_table_name(table_name, temp_table_id) + + adapter.merge( + target_table=exp.to_table("target"), + source_table=t.cast(exp.Select, parse_one('SELECT "ID", ts FROM source')), + target_columns_to_types={ + "ID": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + }, + unique_key=[exp.to_identifier("ID", quoted=True)], + ) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + 'CREATE TABLE "__temp_test_abcdefgh" AS SELECT CAST("ID" AS INTEGER) AS "ID", CAST("ts" AS TIMESTAMP) AS "ts" FROM (SELECT "ID", "ts" FROM "source") AS "_subquery"', + 'DELETE FROM "target" WHERE "ID" IN (SELECT "ID" FROM "__temp_test_abcdefgh")', + 'INSERT INTO "target" ("ID", "ts") SELECT "ID", "ts" FROM (SELECT "ID" AS "ID", "ts" AS "ts", ROW_NUMBER() OVER (PARTITION BY "ID" ORDER BY "ID") AS _row_number FROM "__temp_test_abcdefgh") AS _t WHERE 
_row_number = 1', + 'DROP TABLE IF EXISTS "__temp_test_abcdefgh"', + ] diff --git a/tests/core/engine_adapter/test_risingwave.py b/tests/core/engine_adapter/test_risingwave.py new file mode 100644 index 0000000000..ed3cd77a3f --- /dev/null +++ b/tests/core/engine_adapter/test_risingwave.py @@ -0,0 +1,86 @@ +# type: ignore +import typing as t +from unittest.mock import call + +import pytest +from sqlglot import parse_one, exp +from sqlmesh.core.engine_adapter.risingwave import RisingwaveEngineAdapter + +pytestmark = [pytest.mark.engine, pytest.mark.risingwave] + + +@pytest.fixture +def adapter(make_mocked_engine_adapter): + adapter = make_mocked_engine_adapter(RisingwaveEngineAdapter) + return adapter + + +def test_columns(adapter: t.Callable): + adapter.cursor.fetchall.return_value = [ + ("smallint_col", "smallint"), + ("int_col", "integer"), + ("bigint_col", "bigint"), + ("ts_col", "timestamp without time zone"), + ("tstz_col", "timestamp with time zone"), + ("int_array_col", "integer[]"), + ("vchar_col", "character varying"), + ("struct_col", "struct"), + ] + resp = adapter.columns("db.table") + assert resp == { + "smallint_col": exp.DataType.build(exp.DataType.Type.SMALLINT, nested=False), + "int_col": exp.DataType.build(exp.DataType.Type.INT, nested=False), + "bigint_col": exp.DataType.build(exp.DataType.Type.BIGINT, nested=False), + "ts_col": exp.DataType.build(exp.DataType.Type.TIMESTAMP, nested=False), + "tstz_col": exp.DataType.build(exp.DataType.Type.TIMESTAMPTZ, nested=False), + "int_array_col": exp.DataType.build( + exp.DataType.Type.ARRAY, + expressions=[exp.DataType.build(exp.DataType.Type.INT, nested=False)], + nested=True, + ), + "vchar_col": exp.DataType.build(exp.DataType.Type.VARCHAR), + "struct_col": exp.DataType.build( + exp.DataType.Type.STRUCT, + expressions=[ + exp.ColumnDef( + this=exp.Identifier(this="nested_col", quoted=False), + kind=exp.DataType.build(exp.DataType.Type.INT, nested=False), + ) + ], + nested=True, + ), + } + + +def 
test_create_view(adapter: t.Callable): + adapter.create_view("db.view", parse_one("SELECT 1"), replace=True) + adapter.create_view("db.view", parse_one("SELECT 1"), replace=False) + + adapter.cursor.execute.assert_has_calls( + [ + # 1st call + call('DROP VIEW IF EXISTS "db"."view" CASCADE'), + call('CREATE VIEW "db"."view" AS SELECT 1'), + # 2nd call + call('CREATE VIEW "db"."view" AS SELECT 1'), + ] + ) + + +def test_drop_view(adapter: t.Callable): + adapter.drop_view("db.view") + + adapter.drop_view("db.view", materialized=True) + + adapter.drop_view("db.view", cascade=False) + + adapter.cursor.execute.assert_has_calls( + [ + # 1st call + call('DROP VIEW IF EXISTS "db"."view" CASCADE'), + # 2nd call + call('DROP MATERIALIZED VIEW IF EXISTS "db"."view" CASCADE'), + # 3rd call + call('DROP VIEW IF EXISTS "db"."view"'), + ] + ) diff --git a/tests/core/engine_adapter/test_snowflake.py b/tests/core/engine_adapter/test_snowflake.py index 44b8ffe85f..60f6d38e5f 100644 --- a/tests/core/engine_adapter/test_snowflake.py +++ b/tests/core/engine_adapter/test_snowflake.py @@ -1,21 +1,34 @@ import typing as t -import pandas as pd +import pandas as pd # noqa: TID253 import pytest from pytest_mock.plugin import MockerFixture from sqlglot import exp, parse_one +from sqlglot.optimizer.normalize_identifiers import normalize_identifiers import sqlmesh.core.dialect as d from sqlmesh.core.dialect import normalize_model_name -from sqlmesh.core.model import load_sql_based_model from sqlmesh.core.engine_adapter import SnowflakeEngineAdapter +from sqlmesh.core.engine_adapter.base import EngineAdapter +from sqlmesh.core.engine_adapter.shared import DataObjectType +from sqlmesh.core.model import load_sql_based_model from sqlmesh.core.model.definition import SqlModel +from sqlmesh.core.node import IntervalUnit from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.utils import optional_import from tests.core.engine_adapter import to_sql_calls +from sqlmesh.core.model.kind import 
ViewKind pytestmark = [pytest.mark.engine, pytest.mark.snowflake] +@pytest.fixture +def snowflake_mocked_engine_adapter( + make_mocked_engine_adapter: t.Callable, +) -> SnowflakeEngineAdapter: + return make_mocked_engine_adapter(SnowflakeEngineAdapter) + + def test_get_temp_table(mocker: MockerFixture, make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) @@ -28,6 +41,38 @@ def test_get_temp_table(mocker: MockerFixture, make_mocked_engine_adapter: t.Cal assert value.sql(dialect=adapter.dialect) == '"CATALOG"."DB"."__temp_TEST_TABLE_abcdefgh"' +def test_get_data_objects_lowercases_columns( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +) -> None: + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter, patch_get_data_objects=False) + + adapter.get_current_catalog = mocker.Mock(return_value="TEST_CATALOG") + + adapter.fetchdf = mocker.Mock( + return_value=pd.DataFrame( # type: ignore[assignment] + [ + { + "CATALOG": "TEST_CATALOG", + "NAME": "MY_TABLE", + "SCHEMA_NAME": "PUBLIC", + "TYPE": "TABLE", + "CLUSTERING_KEY": "ID", + } + ] + ) + ) + + data_objects = adapter._get_data_objects("TEST_CATALOG.PUBLIC") + + assert len(data_objects) == 1 + data_object = data_objects[0] + assert data_object.catalog == "TEST_CATALOG" + assert data_object.schema_name == "PUBLIC" + assert data_object.name == "MY_TABLE" + assert data_object.type == DataObjectType.TABLE + assert data_object.clustering_key == "ID" + + @pytest.mark.parametrize( "current_warehouse, current_warehouse_exp, configured_warehouse, configured_warehouse_exp, should_change", [ @@ -87,6 +132,7 @@ def test_session( adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) adapter.cursor.fetchone.return_value = (current_warehouse,) + # Test normal execution with adapter.session({"warehouse": configured_warehouse}): pass @@ -103,6 +149,27 @@ def test_session( assert to_sql_calls(adapter) == expected_calls + # Test exception handling - warehouse 
should still be reset + if should_change: + adapter.cursor.execute.reset_mock() + adapter.cursor.fetchone.return_value = (current_warehouse,) + + try: + with adapter.session({"warehouse": configured_warehouse}): + adapter.execute("SELECT 1") + raise RuntimeError("Test exception") + except RuntimeError: + pass + + expected_exception_calls = [ + "SELECT CURRENT_WAREHOUSE()", + f"USE WAREHOUSE {configured_warehouse_exp}", + "SELECT 1", + f"USE WAREHOUSE {current_warehouse_exp}", + ] + + assert to_sql_calls(adapter) == expected_exception_calls + def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) @@ -143,13 +210,240 @@ def test_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture) assert sql_calls == [ """CREATE TABLE IF NOT EXISTS "test_table" ("a" INT COMMENT 'a column description', "b" INT) COMMENT='table description'""", """CREATE TABLE IF NOT EXISTS "test_table" ("a" INT COMMENT 'a column description', "b" INT) COMMENT='table description' AS SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" FROM (SELECT "a", "b" FROM "source_table") AS "_subquery\"""", - """CREATE OR REPLACE VIEW "test_view" COMMENT='table description' AS SELECT "a", "b" FROM "source_table\"""", + """CREATE OR REPLACE VIEW "test_view" COPY GRANTS COMMENT='table description' AS SELECT "a", "b" FROM "source_table\"""", """ALTER VIEW "test_view" ALTER COLUMN "a" COMMENT 'a column description'""", """COMMENT ON TABLE "test_table" IS 'table description'""", """ALTER TABLE "test_table" ALTER COLUMN "a" COMMENT 'a column description'""", ] +def test_multiple_column_comments(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + + adapter.create_table( + "test_table", + {"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")}, + column_descriptions={"a": "a column description", "b": "b column description"}, + ) 
+ + adapter.create_view( + "test_view", + parse_one("SELECT a, b FROM test_table"), + column_descriptions={"a": "a column description", "b": "b column description"}, + ) + + adapter._create_column_comments( + "test_table", + {"a": "a column description changed", "b": "b column description changed"}, + ) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + """CREATE TABLE IF NOT EXISTS "test_table" ("a" INT COMMENT 'a column description', "b" INT COMMENT 'b column description')""", + """CREATE OR REPLACE VIEW "test_view" COPY GRANTS AS SELECT "a", "b" FROM "test_table\"""", + """ALTER VIEW "test_view" ALTER COLUMN "a" COMMENT 'a column description', COLUMN "b" COMMENT 'b column description'""", + """ALTER TABLE "test_table" ALTER COLUMN "a" COMMENT 'a column description changed', COLUMN "b" COMMENT 'b column description changed'""", + ] + + +def test_sync_grants_config(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + relation = normalize_identifiers( + exp.to_table("test_db.test_schema.test_table", dialect="snowflake"), dialect="snowflake" + ) + new_grants_config = {"SELECT": ["ROLE role1", "ROLE role2"], "INSERT": ["ROLE role3"]} + + current_grants = [ + ("SELECT", "ROLE old_role"), + ("UPDATE", "ROLE legacy_role"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="snowflake") + expected_sql = ( + "SELECT privilege_type, grantee FROM TEST_DB.INFORMATION_SCHEMA.TABLE_PRIVILEGES " + "WHERE table_catalog = 'TEST_DB' AND table_schema = 'TEST_SCHEMA' AND table_name = 'TEST_TABLE' " + "AND grantor = CURRENT_ROLE() AND grantee <> CURRENT_ROLE()" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 5 
+ + assert 'GRANT SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE "ROLE1"' in sql_calls + assert 'GRANT SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE "ROLE2"' in sql_calls + assert 'GRANT INSERT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE "ROLE3"' in sql_calls + assert ( + 'REVOKE SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" FROM ROLE "OLD_ROLE"' + in sql_calls + ) + assert ( + 'REVOKE UPDATE ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" FROM ROLE "LEGACY_ROLE"' + in sql_calls + ) + + +def test_sync_grants_config_with_overlaps( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + relation = normalize_identifiers( + exp.to_table("test_db.test_schema.test_table", dialect="snowflake"), dialect="snowflake" + ) + new_grants_config = { + "SELECT": ["ROLE shared", "ROLE new_role"], + "INSERT": ["ROLE shared", "ROLE writer"], + } + + current_grants = [ + ("SELECT", "ROLE shared"), + ("SELECT", "ROLE legacy"), + ("INSERT", "ROLE shared"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="snowflake") + expected_sql = ( + """SELECT privilege_type, grantee FROM TEST_DB.INFORMATION_SCHEMA.TABLE_PRIVILEGES """ + "WHERE table_catalog = 'TEST_DB' AND table_schema = 'TEST_SCHEMA' AND table_name = 'TEST_TABLE' " + "AND grantor = CURRENT_ROLE() AND grantee <> CURRENT_ROLE()" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 3 + + assert ( + 'GRANT SELECT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE "NEW_ROLE"' in sql_calls + ) + assert ( + 'GRANT INSERT ON TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" TO ROLE "WRITER"' in sql_calls + ) + assert ( + 'REVOKE SELECT ON 
TABLE "TEST_DB"."TEST_SCHEMA"."TEST_TABLE" FROM ROLE "LEGACY"' + in sql_calls + ) + + +@pytest.mark.parametrize( + "table_type, expected_keyword", + [ + (DataObjectType.TABLE, "TABLE"), + (DataObjectType.VIEW, "VIEW"), + (DataObjectType.MATERIALIZED_VIEW, "MATERIALIZED VIEW"), + (DataObjectType.MANAGED_TABLE, "DYNAMIC TABLE"), + ], +) +def test_sync_grants_config_object_kind( + make_mocked_engine_adapter: t.Callable, + mocker: MockerFixture, + table_type: DataObjectType, + expected_keyword: str, +) -> None: + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + relation = normalize_identifiers( + exp.to_table("test_db.test_schema.test_object", dialect="snowflake"), dialect="snowflake" + ) + + mocker.patch.object(adapter, "fetchall", return_value=[]) + + adapter.sync_grants_config(relation, {"SELECT": ["ROLE test"]}, table_type) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + f'GRANT SELECT ON {expected_keyword} "TEST_DB"."TEST_SCHEMA"."TEST_OBJECT" TO ROLE "TEST"' + ] + + +def test_sync_grants_config_quotes(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + relation = normalize_identifiers( + exp.to_table('"test_db"."test_schema"."test_table"', dialect="snowflake"), + dialect="snowflake", + ) + new_grants_config = {"SELECT": ["ROLE role1", "ROLE role2"], "INSERT": ["ROLE role3"]} + + current_grants = [ + ("SELECT", "ROLE old_role"), + ("UPDATE", "ROLE legacy_role"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="snowflake") + expected_sql = ( + """SELECT privilege_type, grantee FROM "test_db".INFORMATION_SCHEMA.TABLE_PRIVILEGES """ + "WHERE table_catalog = 'test_db' AND table_schema = 'test_schema' AND table_name = 'test_table' " 
+ "AND grantor = CURRENT_ROLE() AND grantee <> CURRENT_ROLE()" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 5 + + assert 'GRANT SELECT ON TABLE "test_db"."test_schema"."test_table" TO ROLE "ROLE1"' in sql_calls + assert 'GRANT SELECT ON TABLE "test_db"."test_schema"."test_table" TO ROLE "ROLE2"' in sql_calls + assert 'GRANT INSERT ON TABLE "test_db"."test_schema"."test_table" TO ROLE "ROLE3"' in sql_calls + assert ( + 'REVOKE SELECT ON TABLE "test_db"."test_schema"."test_table" FROM ROLE "OLD_ROLE"' + in sql_calls + ) + assert ( + 'REVOKE UPDATE ON TABLE "test_db"."test_schema"."test_table" FROM ROLE "LEGACY_ROLE"' + in sql_calls + ) + + +def test_sync_grants_config_no_catalog_or_schema( + make_mocked_engine_adapter: t.Callable, mocker: MockerFixture +): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + relation = normalize_identifiers( + exp.to_table('"TesT_Table"', dialect="snowflake"), dialect="snowflake" + ) + new_grants_config = {"SELECT": ["ROLE role1", "ROLE role2"], "INSERT": ["ROLE role3"]} + + current_grants = [ + ("SELECT", "ROLE old_role"), + ("UPDATE", "ROLE legacy_role"), + ] + fetchall_mock = mocker.patch.object(adapter, "fetchall", return_value=current_grants) + mocker.patch.object(adapter, "get_current_catalog", return_value="caTalog") + mocker.patch.object(adapter, "_get_current_schema", return_value="sChema") + + adapter.sync_grants_config(relation, new_grants_config) + + fetchall_mock.assert_called_once() + executed_query = fetchall_mock.call_args[0][0] + executed_sql = executed_query.sql(dialect="snowflake") + expected_sql = ( + """SELECT privilege_type, grantee FROM "caTalog".INFORMATION_SCHEMA.TABLE_PRIVILEGES """ + "WHERE table_catalog = 'caTalog' AND table_schema = 'sChema' AND table_name = 'TesT_Table' " + "AND grantor = CURRENT_ROLE() AND grantee <> CURRENT_ROLE()" + ) + assert executed_sql == expected_sql + + sql_calls = to_sql_calls(adapter) + assert 
len(sql_calls) == 5 + + assert 'GRANT SELECT ON TABLE "TesT_Table" TO ROLE "ROLE1"' in sql_calls + assert 'GRANT SELECT ON TABLE "TesT_Table" TO ROLE "ROLE2"' in sql_calls + assert 'GRANT INSERT ON TABLE "TesT_Table" TO ROLE "ROLE3"' in sql_calls + assert 'REVOKE SELECT ON TABLE "TesT_Table" FROM ROLE "OLD_ROLE"' in sql_calls + assert 'REVOKE UPDATE ON TABLE "TesT_Table" FROM ROLE "LEGACY_ROLE"' in sql_calls + + def test_df_to_source_queries_use_schema( make_mocked_engine_adapter: t.Callable, mocker: MockerFixture ): @@ -192,14 +486,14 @@ def test_create_managed_table(make_mocked_engine_adapter: t.Callable, mocker: Mo adapter.create_managed_table( table_name="test_table", query=query, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, ) # warehouse not specified, should default to current_warehouse() adapter.create_managed_table( table_name="test_table", query=query, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, table_properties={"target_lag": exp.Literal.string("20 minutes")}, ) @@ -207,7 +501,7 @@ def test_create_managed_table(make_mocked_engine_adapter: t.Callable, mocker: Mo adapter.create_managed_table( table_name="test_table", query=query, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, table_properties={ "target_lag": exp.Literal.string("20 minutes"), "warehouse": exp.to_identifier("foo"), @@ -218,11 +512,11 @@ def test_create_managed_table(make_mocked_engine_adapter: t.Callable, mocker: Mo adapter.create_managed_table( table_name="test_table", query=query, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, table_properties={ "target_lag": exp.Literal.string("20 minutes"), }, - clustered_by=["a"], + clustered_by=[exp.column("a")], partitioned_by=["b"], ) @@ -230,7 +524,7 @@ def test_create_managed_table(make_mocked_engine_adapter: t.Callable, mocker: Mo adapter.create_managed_table( table_name="test_table", query=query, - 
columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, table_properties={ "target_lag": exp.Literal.string("20 minutes"), "refresh_mode": exp.Literal.string("auto"), @@ -238,23 +532,37 @@ def test_create_managed_table(make_mocked_engine_adapter: t.Callable, mocker: Mo }, ) + # table_format=iceberg + adapter.create_managed_table( + table_name="test_table", + query=query, + target_columns_to_types=columns_to_types, + table_properties={ + "target_lag": exp.Literal.string("20 minutes"), + "catalog": exp.Literal.string("snowflake"), + "external_volume": exp.Literal.string("test"), + }, + table_format="iceberg", + ) + assert to_sql_calls(adapter) == [ """CREATE OR REPLACE DYNAMIC TABLE "test_table" TARGET_LAG='20 minutes' WAREHOUSE="default_warehouse" AS SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" FROM (SELECT "a", "b" FROM "source_table") AS "_subquery\"""", """CREATE OR REPLACE DYNAMIC TABLE "test_table" TARGET_LAG='20 minutes' WAREHOUSE="foo" AS SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" FROM (SELECT "a", "b" FROM "source_table") AS "_subquery\"""", """CREATE OR REPLACE DYNAMIC TABLE "test_table" CLUSTER BY ("a") TARGET_LAG='20 minutes' WAREHOUSE="default_warehouse" AS SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" FROM (SELECT "a", "b" FROM "source_table") AS "_subquery\"""", """CREATE OR REPLACE DYNAMIC TABLE "test_table" TARGET_LAG='20 minutes' REFRESH_MODE='auto' INITIALIZE='on_create' WAREHOUSE="default_warehouse" AS SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" FROM (SELECT "a", "b" FROM "source_table") AS "_subquery\"""", + """CREATE OR REPLACE DYNAMIC ICEBERG TABLE "test_table" TARGET_LAG='20 minutes' CATALOG='snowflake' EXTERNAL_VOLUME='test' WAREHOUSE="default_warehouse" AS SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" FROM (SELECT "a", "b" FROM "source_table") AS "_subquery\"""", ] def test_drop_managed_table(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): adapter 
= make_mocked_engine_adapter(SnowflakeEngineAdapter) - adapter.drop_managed_table(table_name=exp.parse_identifier("foo"), exists=False) - adapter.drop_managed_table(table_name=exp.parse_identifier("foo"), exists=True) + adapter.drop_managed_table(table_name="foo.bar", exists=False) + adapter.drop_managed_table(table_name="foo.bar", exists=True) assert to_sql_calls(adapter) == [ - 'DROP DYNAMIC TABLE "foo"', - 'DROP DYNAMIC TABLE IF EXISTS "foo"', + 'DROP DYNAMIC TABLE "foo"."bar"', + 'DROP DYNAMIC TABLE IF EXISTS "foo"."bar"', ] @@ -267,7 +575,7 @@ def test_ctas_skips_dynamic_table_properties(make_mocked_engine_adapter: t.Calla adapter.ctas( table_name="test_table", query_or_df=query, - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, table_properties={ "warehouse": exp.to_identifier("foo"), "target_lag": exp.Literal.string("20 minutes"), @@ -359,3 +667,404 @@ def test_set_current_schema(make_mocked_engine_adapter: t.Callable): 'USE SCHEMA "FOO"."foo"', 'USE SCHEMA "FOO"."fOo"', ] + + +def test_replace_query_snowpark_dataframe( + mocker: MockerFixture, make_mocked_engine_adapter: t.Callable +): + if not optional_import("snowflake.snowpark"): + pytest.skip("Snowpark not available in this environment") + + from snowflake.snowpark.session import Session + from snowflake.snowpark.dataframe import DataFrame as SnowparkDataFrame + + session = Session.builder.config("local_testing", True).create() + # df.createOrReplaceTempView() throws "[Local Testing] Mocking SnowflakePlan Rename is not supported" when used against the Snowflake local_testing session + # since we cant trace any queries from the Snowpark library anyway, we just suppress this and verify the cleanup queries issued by our EngineAdapter + session._conn._suppress_not_implemented_error = True + + df: SnowparkDataFrame = session.create_dataframe([(1, "name")], schema=["ID", "NAME"]) + assert isinstance(df, SnowparkDataFrame) + + 
mocker.patch("sqlmesh.core.engine_adapter.base.random_id", return_value="e6wjkjj6") + spy = mocker.spy(df, "createOrReplaceTempView") + + adapter: SnowflakeEngineAdapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + adapter._default_catalog = "foo" + + adapter.replace_query( + table_name="foo", + query_or_df=df, + target_columns_to_types={ + "ID": exp.DataType.build("INT"), + "NAME": exp.DataType.build("VARCHAR"), + }, + ) + + # verify that DROP VIEW is called instead of DROP TABLE + assert to_sql_calls(adapter) == [ + 'CREATE OR REPLACE TABLE "foo" AS SELECT CAST("ID" AS INT) AS "ID", CAST("NAME" AS VARCHAR) AS "NAME" FROM (SELECT CAST("ID" AS INT) AS "ID", CAST("NAME" AS VARCHAR) AS "NAME" FROM "__temp_foo_e6wjkjj6") AS "_subquery"', + 'DROP VIEW IF EXISTS "__temp_foo_e6wjkjj6"', + ] + + +def test_creatable_type_materialized_view_properties(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + + adapter.create_view( + "test_table", + parse_one("SELECT 1"), + materialized=True, + materialized_properties={ + # Partitioned by is not supported so we are confirming it is ignored + "partitioned_by": [exp.column("ds")], + "clustered_by": [exp.column("a")], + "partition_interval_unit": IntervalUnit.DAY, + }, + ) + + sql_calls = to_sql_calls(adapter) + # https://docs.snowflake.com/en/sql-reference/sql/create-materialized-view#syntax + assert sql_calls == [ + 'CREATE OR REPLACE MATERIALIZED VIEW "test_table" COPY GRANTS CLUSTER BY ("a") AS SELECT 1', + ] + + +def test_creatable_type_secure_view(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + + adapter.create_view( + "test_table", + parse_one("SELECT 1"), + view_properties={ + "creatable_type": exp.Column(this=exp.Identifier(this="secure")), + }, + ) + + sql_calls = to_sql_calls(adapter) + # https://docs.snowflake.com/en/sql-reference/sql/create-view.html + assert sql_calls == [ + 'CREATE OR REPLACE SECURE 
VIEW "test_table" COPY GRANTS AS SELECT 1', + ] + + +def test_creatable_type_secure_materialized_view(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + + adapter.create_view( + "test_table", + parse_one("SELECT 1"), + materialized=True, + view_properties={ + "creatable_type": exp.Column(this=exp.Identifier(this="secure")), + }, + ) + + sql_calls = to_sql_calls(adapter) + # https://docs.snowflake.com/en/sql-reference/sql/create-view.html + assert sql_calls == [ + 'CREATE OR REPLACE SECURE MATERIALIZED VIEW "test_table" COPY GRANTS AS SELECT 1', + ] + + +def test_creatable_type_temporary_view(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + + adapter.create_view( + "test_table", + parse_one("SELECT 1"), + view_properties={ + "creatable_type": exp.column("temporary"), + }, + ) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + 'CREATE OR REPLACE TEMPORARY VIEW "test_table" COPY GRANTS AS SELECT 1', + ] + + +def test_creatable_type_temporary_table(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + + adapter.create_table( + "test_table", + {"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")}, + table_properties={ + "creatable_type": exp.column("temporary"), + }, + ) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + 'CREATE TEMPORARY TABLE IF NOT EXISTS "test_table" ("a" INT, "b" INT)', + ] + + +def test_creatable_type_transient_table(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + + adapter.create_table( + "test_table", + {"a": exp.DataType.build("INT"), "b": exp.DataType.build("INT")}, + table_properties={ + "creatable_type": exp.column("transient"), + }, + ) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + 'CREATE TRANSIENT TABLE IF NOT EXISTS "test_table" ("a" INT, "b" INT)', + ] + + +def 
test_creatable_type_materialize_creatable_type_raise_error( + make_mocked_engine_adapter: t.Callable, +): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + + with pytest.raises(SQLMeshError): + adapter.create_view( + "test_view", + parse_one("SELECT 1"), + view_properties={ + "creatable_type": exp.column("materialized"), + }, + ) + + +def test_creatable_type_transient_type_from_model_definition( + make_mocked_engine_adapter: t.Callable, +): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + + model: SqlModel = t.cast( + SqlModel, + load_sql_based_model( + d.parse( + """ +MODEL ( + name external.test.table, + kind full, + physical_properties ( + creatable_type = transient + ) +); +SELECT a::INT; + """ + ) + ), + ) + adapter.create_table( + model.name, + target_columns_to_types=model.columns_to_types_or_raise, + table_properties=model.physical_properties, + ) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + 'CREATE TRANSIENT TABLE IF NOT EXISTS "external"."test"."table" ("a" INT)', + ] + + +def test_creatable_type_transient_type_from_model_definition_with_other_property( + make_mocked_engine_adapter: t.Callable, +): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + + model: SqlModel = t.cast( + SqlModel, + load_sql_based_model( + d.parse( + """ +MODEL ( + name external.test.table, + kind full, + physical_properties ( + creatable_type = transient, + require_partition_filter = true + ) +); +SELECT a::INT; + """ + ) + ), + ) + adapter.create_table( + model.name, + target_columns_to_types=model.columns_to_types_or_raise, + table_properties=model.physical_properties, + ) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + 'CREATE TRANSIENT TABLE IF NOT EXISTS "external"."test"."table" ("a" INT) REQUIRE_PARTITION_FILTER=TRUE' + ] + + +def test_create_view(make_mocked_engine_adapter: t.Callable): + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter) + + adapter.create_view("test_view", 
parse_one("SELECT 1")) + adapter.create_view("test_view", parse_one("SELECT 1"), replace=False) + + sql_calls = to_sql_calls(adapter) + assert sql_calls == [ + 'CREATE OR REPLACE VIEW "test_view" COPY GRANTS AS SELECT 1', + 'CREATE VIEW "test_view" AS SELECT 1', + ] + + +def test_clone_table(mocker: MockerFixture, make_mocked_engine_adapter: t.Callable): + mocker.patch("sqlmesh.core.engine_adapter.snowflake.SnowflakeEngineAdapter.set_current_catalog") + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter, default_catalog="test_catalog") + adapter.clone_table("target_table", "source_table") + adapter.cursor.execute.assert_called_once_with( + 'CREATE TABLE IF NOT EXISTS "target_table" CLONE "source_table"' + ) + + # Validate with transient type we create the clone table accordingly + rendered_physical_properties = { + "creatable_type": exp.column("transient"), + } + adapter = make_mocked_engine_adapter(SnowflakeEngineAdapter, default_catalog="test_catalog") + adapter.clone_table( + "target_table", "source_table", rendered_physical_properties=rendered_physical_properties + ) + adapter.cursor.execute.assert_called_once_with( + 'CREATE TRANSIENT TABLE IF NOT EXISTS "target_table" CLONE "source_table"' + ) + + # Validate other engine adapters would work as usual even when we pass the properties + adapter = make_mocked_engine_adapter(EngineAdapter, default_catalog="test_catalog") + adapter.SUPPORTS_CLONING = True + adapter.clone_table( + "target_table", "source_table", rendered_physical_properties=rendered_physical_properties + ) + adapter.cursor.execute.assert_called_once_with( + 'CREATE TABLE IF NOT EXISTS "target_table" CLONE "source_table"' + ) + + +def test_table_format_iceberg(snowflake_mocked_engine_adapter: SnowflakeEngineAdapter) -> None: + adapter = snowflake_mocked_engine_adapter + + model = load_sql_based_model( + expressions=d.parse(""" + MODEL ( + name test.table, + kind full, + table_format iceberg, + physical_properties ( + catalog = 'snowflake', + 
external_volume = 'test' + ) + ); + SELECT a::INT; + """) + ) + assert isinstance(model, SqlModel) + assert model.table_format == "iceberg" + + adapter.create_table( + table_name=model.name, + target_columns_to_types=model.columns_to_types_or_raise, + table_format=model.table_format, + table_properties=model.physical_properties, + ) + + adapter.ctas( + table_name=model.name, + query_or_df=model.render_query_or_raise(), + target_columns_to_types=model.columns_to_types_or_raise, + table_format=model.table_format, + table_properties=model.physical_properties, + ) + + assert to_sql_calls(adapter) == [ + 'CREATE ICEBERG TABLE IF NOT EXISTS "test"."table" ("a" INT) CATALOG=\'snowflake\' EXTERNAL_VOLUME=\'test\'', + 'CREATE ICEBERG TABLE IF NOT EXISTS "test"."table" CATALOG=\'snowflake\' EXTERNAL_VOLUME=\'test\' AS SELECT CAST("a" AS INT) AS "a" FROM (SELECT CAST("a" AS INT) AS "a") AS "_subquery"', + ] + + +def test_create_view_with_schema_and_grants( + snowflake_mocked_engine_adapter: SnowflakeEngineAdapter, +): + adapter = snowflake_mocked_engine_adapter + + model_v = load_sql_based_model( + d.parse(f""" + MODEL ( + name test.v, + kind VIEW, + description 'normal **view** from integration test', + dialect 'snowflake' + ); + + select 1 as "ID", 'foo' as "NAME"; + """) + ) + + model_mv = load_sql_based_model( + d.parse(f""" + MODEL ( + name test.mv, + kind VIEW ( + materialized true + ), + description 'materialized **view** from integration test', + dialect 'snowflake' + ); + + select 1 as "ID", 'foo' as "NAME"; + """) + ) + + assert isinstance(model_v.kind, ViewKind) + assert isinstance(model_mv.kind, ViewKind) + + adapter.create_view( + "target_view", + model_v.render_query_or_raise(), + model_v.columns_to_types, + materialized=model_v.kind.materialized, + view_properties=model_v.render_physical_properties(), + table_description=model_v.description, + column_descriptions=model_v.column_descriptions, + ) + + adapter.create_view( + "target_materialized_view", + 
model_mv.render_query_or_raise(), + model_mv.columns_to_types, + materialized=model_mv.kind.materialized, + view_properties=model_mv.render_physical_properties(), + table_description=model_mv.description, + column_descriptions=model_mv.column_descriptions, + ) + + assert to_sql_calls(adapter) == [ + # normal view - COPY GRANTS goes after the column list + """CREATE OR REPLACE VIEW "target_view" ("ID", "NAME") COPY GRANTS COMMENT='normal **view** from integration test' AS SELECT 1 AS "ID", 'foo' AS "NAME\"""", + # materialized view - COPY GRANTS goes before the column list + """CREATE OR REPLACE MATERIALIZED VIEW "target_materialized_view" COPY GRANTS ("ID", "NAME") COMMENT='materialized **view** from integration test' AS SELECT 1 AS "ID", 'foo' AS "NAME\"""", + ] + + +def test_create_catalog(snowflake_mocked_engine_adapter: SnowflakeEngineAdapter) -> None: + adapter = snowflake_mocked_engine_adapter + adapter.create_catalog(exp.to_identifier("foo")) + + assert to_sql_calls(adapter) == [ + "CREATE DATABASE IF NOT EXISTS \"foo\" COMMENT='sqlmesh_managed'" + ] + + +def test_drop_catalog(snowflake_mocked_engine_adapter: SnowflakeEngineAdapter) -> None: + adapter = snowflake_mocked_engine_adapter + adapter.drop_catalog(exp.to_identifier("foo")) + + assert to_sql_calls(adapter) == [ + """SELECT 1 FROM "INFORMATION_SCHEMA"."DATABASES" WHERE "DATABASE_NAME" = 'foo' AND "COMMENT" = 'sqlmesh_managed'""", + 'DROP DATABASE IF EXISTS "foo"', + ] diff --git a/tests/core/engine_adapter/test_spark.py b/tests/core/engine_adapter/test_spark.py index 152c412c55..d7c3127f05 100644 --- a/tests/core/engine_adapter/test_spark.py +++ b/tests/core/engine_adapter/test_spark.py @@ -3,7 +3,6 @@ from datetime import datetime from unittest.mock import call -import pandas as pd import pytest from pyspark.sql import types as spark_types from pytest_mock.plugin import MockerFixture @@ -11,12 +10,21 @@ from sqlglot import parse_one from sqlmesh.core.engine_adapter import SparkEngineAdapter +from 
sqlmesh.core.engine_adapter.shared import DataObject from sqlmesh.utils.errors import SQLMeshError from tests.core.engine_adapter import to_sql_calls +import sqlmesh.core.dialect as d +from sqlmesh.core.model import load_sql_based_model +from sqlmesh.core.model.definition import SqlModel pytestmark = [pytest.mark.engine, pytest.mark.spark] +@pytest.fixture +def adapter(make_mocked_engine_adapter: t.Callable) -> SparkEngineAdapter: + return make_mocked_engine_adapter(SparkEngineAdapter) + + def test_create_table_properties(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(SparkEngineAdapter) @@ -58,14 +66,15 @@ def test_create_table_properties(make_mocked_engine_adapter: t.Callable): ) +@pytest.mark.parametrize("wap_enabled", [True, False]) def test_replace_query_table_properties_not_exists( - mocker: MockerFixture, make_mocked_engine_adapter: t.Callable + mocker: MockerFixture, make_mocked_engine_adapter: t.Callable, wap_enabled: bool ): mocker.patch( "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.table_exists", return_value=False, ) - adapter = make_mocked_engine_adapter(SparkEngineAdapter) + adapter = make_mocked_engine_adapter(SparkEngineAdapter, wap_enabled=wap_enabled) columns_to_types = { "cola": exp.DataType.build("INT"), @@ -75,16 +84,19 @@ def test_replace_query_table_properties_not_exists( adapter.replace_query( "test_table", parse_one("SELECT 1 AS cola, '2' AS colb, '3' AS colc"), - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, partitioned_by=[exp.to_column("colb")], storage_format="ICEBERG", table_properties={"a": exp.convert(1)}, ) - assert to_sql_calls(adapter) == [ + expected_sql_calls = [ "CREATE TABLE IF NOT EXISTS `test_table` USING ICEBERG PARTITIONED BY (`colb`) TBLPROPERTIES ('a'=1) AS SELECT CAST(`cola` AS INT) AS `cola`, CAST(`colb` AS STRING) AS `colb`, CAST(`colc` AS STRING) AS `colc` FROM (SELECT 1 AS `cola`, '2' AS `colb`, '3' AS `colc`) AS `_subquery`", - "INSERT INTO 
`test_table` SELECT * FROM `test_table`", ] + if wap_enabled: + expected_sql_calls.append("INSERT INTO `test_table` SELECT * FROM `test_table`") + + assert to_sql_calls(adapter) == expected_sql_calls def test_replace_query_table_properties_exists( @@ -95,6 +107,11 @@ def test_replace_query_table_properties_exists( return_value=True, ) adapter = make_mocked_engine_adapter(SparkEngineAdapter) + mocker.patch.object( + adapter, + "_get_data_objects", + return_value=[DataObject(schema="", name="test_table", type="table")], + ) columns_to_types = { "cola": exp.DataType.build("INT"), @@ -104,7 +121,7 @@ def test_replace_query_table_properties_exists( adapter.replace_query( "test_table", parse_one("SELECT 1 AS cola, '2' AS colb, '3' AS colc"), - columns_to_types=columns_to_types, + target_columns_to_types=columns_to_types, partitioned_by=[exp.to_column("colb")], storage_format="ICEBERG", table_properties={"a": exp.convert(1)}, @@ -140,17 +157,16 @@ def table_columns(table_name: str) -> t.Dict[str, exp.DataType]: "complex": exp.DataType.build("STRUCT"), "ds": exp.DataType.build("STRING"), } - else: - return { - "id": exp.DataType.build("BIGINT"), - "a": exp.DataType.build("STRING"), - "complex": exp.DataType.build("STRUCT"), - "ds": exp.DataType.build("INT"), - } + return { + "id": exp.DataType.build("BIGINT"), + "a": exp.DataType.build("STRING"), + "complex": exp.DataType.build("STRUCT"), + "ds": exp.DataType.build("INT"), + } adapter.columns = table_columns - adapter.alter_table(adapter.get_alter_expressions(current_table_name, target_table_name)) + adapter.alter_table(adapter.get_alter_operations(current_table_name, target_table_name)) adapter.cursor.execute.assert_has_calls( [ @@ -188,6 +204,11 @@ def test_replace_query_exists(mocker: MockerFixture, make_mocked_engine_adapter: return_value=True, ) adapter = make_mocked_engine_adapter(SparkEngineAdapter) + mocker.patch.object( + adapter, + "_get_data_objects", + return_value=[DataObject(schema="", name="test_table", 
type="table")], + ) adapter.replace_query("test_table", parse_one("SELECT a FROM tbl"), {"a": "int"}) assert to_sql_calls(adapter) == [ @@ -195,48 +216,6 @@ def test_replace_query_exists(mocker: MockerFixture, make_mocked_engine_adapter: ] -def test_replace_query_pandas_not_exists( - make_mocked_engine_adapter: t.Callable, mocker: MockerFixture -): - mocker.patch( - "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.table_exists", - return_value=False, - ) - mocker.patch("sqlmesh.core.engine_adapter.spark.SparkEngineAdapter._use_spark_session", False) - mocker.patch( - "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter._ensure_fqn", side_effect=lambda x: x - ) - adapter = make_mocked_engine_adapter(SparkEngineAdapter) - df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) - adapter.replace_query( - "test_table", df, {"a": exp.DataType.build("int"), "b": exp.DataType.build("int")} - ) - - assert to_sql_calls(adapter) == [ - "CREATE TABLE IF NOT EXISTS `test_table` AS SELECT CAST(`a` AS INT) AS `a`, CAST(`b` AS INT) AS `b` FROM (SELECT CAST(`a` AS INT) AS `a`, CAST(`b` AS INT) AS `b` FROM VALUES (1, 4), (2, 5), (3, 6) AS `t`(`a`, `b`)) AS `_subquery`", - ] - - -def test_replace_query_pandas_exists(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): - mocker.patch( - "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.table_exists", - return_value=True, - ) - mocker.patch("sqlmesh.core.engine_adapter.spark.SparkEngineAdapter._use_spark_session", False) - mocker.patch( - "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter._ensure_fqn", side_effect=lambda x: x - ) - adapter = make_mocked_engine_adapter(SparkEngineAdapter) - df = pd.DataFrame({"a": [1, 2, 3], "b": [4, 5, 6]}) - adapter.replace_query( - "test_table", df, {"a": exp.DataType.build("int"), "b": exp.DataType.build("int")} - ) - - assert to_sql_calls(adapter) == [ - "INSERT OVERWRITE TABLE `test_table` (`a`, `b`) SELECT CAST(`a` AS INT) AS `a`, CAST(`b` AS INT) AS `b` FROM VALUES (1, 4), 
(2, 5), (3, 6) AS `t`(`a`, `b`)", - ] - - def test_replace_query_self_ref_not_exists( make_mocked_engine_adapter: t.Callable, mocker: MockerFixture, make_temp_table_name: t.Callable ): @@ -245,7 +224,7 @@ def test_replace_query_self_ref_not_exists( lambda self: "spark_catalog", ) mocker.patch( - "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.get_current_database", + "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter._get_current_schema", side_effect=lambda: "default", ) @@ -275,6 +254,12 @@ def check_table_exists(table_name: exp.Table) -> bool: side_effect=check_table_exists, ) + mocker.patch.object( + adapter, + "_get_data_objects", + return_value=[DataObject(schema="db", name="table", type="table")], + ) + adapter.replace_query(table_name, parse_one(f"SELECT col + 1 AS col FROM {table_name}")) assert to_sql_calls(adapter) == [ @@ -298,12 +283,17 @@ def test_replace_query_self_ref_exists( return_value="spark_catalog", ) mocker.patch( - "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.get_current_database", + "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter._get_current_schema", return_value="default", ) adapter = make_mocked_engine_adapter(SparkEngineAdapter) adapter.cursor.fetchone.return_value = (1,) + mocker.patch.object( + adapter, + "_get_data_objects", + return_value=[DataObject(schema="db", name="table", type="table")], + ) table_name = "db.table" temp_table_id = "abcdefgh" @@ -369,7 +359,6 @@ def test_create_state_table(make_mocked_engine_adapter: t.Callable): ("double", spark_types.DoubleType()), ("decimal", spark_types.DecimalType()), ("text", spark_types.StringType()), - # Spark supports VARCHAR and CHAR but SQLGlot currently converts them to strings ("varchar(25)", spark_types.StringType()), ("char(30)", spark_types.StringType()), ("binary", spark_types.BinaryType()), @@ -536,7 +525,13 @@ def test_spark_struct_primitives_to_col_to_types(type_name, spark_type): actual = SparkEngineAdapter.spark_to_sqlglot_types( 
spark_types.StructType([spark_types.StructField(f"col_{type_name}", spark_type)]) ) - expected = {f"col_{type_name}": exp.DataType.build(type_name, dialect="spark")} + + expected_type = ( + exp.DataType.build("string") + if "char" in type_name + else exp.DataType.build(type_name, dialect="spark") + ) + expected = {f"col_{type_name}": expected_type} assert actual == expected @@ -556,12 +551,8 @@ def test_spark_struct_complex_to_col_to_types(type_name, spark_type): def test_scd_type_2_by_time( make_mocked_engine_adapter: t.Callable, make_temp_table_name: t.Callable, mocker: MockerFixture ): - mocker.patch( - "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.table_exists", - return_value=False, - ) - adapter = make_mocked_engine_adapter(SparkEngineAdapter) + adapter._default_catalog = "spark_catalog" adapter.spark.catalog.currentCatalog.return_value = "spark_catalog" adapter.spark.catalog.currentDatabase.return_value = "default" @@ -580,6 +571,11 @@ def check_table_exists(table_name: exp.Table) -> bool: "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.table_exists", side_effect=check_table_exists, ) + mocker.patch.object( + adapter, + "_get_data_objects", + return_value=[DataObject(schema="db", name="target", type="table")], + ) adapter.scd_type_2_by_time( target_table="db.target", @@ -590,7 +586,7 @@ def check_table_exists(table_name: exp.Table) -> bool: valid_from_col=exp.column("test_valid_from", quoted=True), valid_to_col=exp.column("test_valid_to", quoted=True), updated_at_col=exp.column("test_updated_at", quoted=True), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("VARCHAR"), "price": exp.DataType.build("DOUBLE"), @@ -608,17 +604,17 @@ def check_table_exists(table_name: exp.Table) -> bool: parse_one( """WITH `source` AS ( SELECT - TRUE AS `_exists`, + `_exists`, `id`, `name`, `price`, - CAST(`test_updated_at` AS TIMESTAMP) AS `test_updated_at` + `test_updated_at` FROM ( SELECT TRUE AS 
`_exists`, - `id`, - `name`, - `price`, + `id` AS `id`, + `name` AS `name`, + `price` AS `price`, CAST(`test_updated_at` AS TIMESTAMP) AS `test_updated_at`, ROW_NUMBER() OVER (PARTITION BY COALESCE(`id`, '') ORDER BY COALESCE(`id`, '')) AS _row_number FROM ( @@ -679,7 +675,7 @@ def check_table_exists(table_name: exp.Table) -> bool: COALESCE(`id`, '') ), `joined` AS ( SELECT - `source`.`_exists`, + `source`.`_exists` AS `_exists`, `latest`.`id` AS `t_id`, `latest`.`name` AS `t_name`, `latest`.`price` AS `t_price`, @@ -695,7 +691,7 @@ def check_table_exists(table_name: exp.Table) -> bool: ON COALESCE(`latest`.`id`, '') = COALESCE(`source`.`id`, '') UNION ALL SELECT - `source`.`_exists`, + `source`.`_exists` AS `_exists`, `latest`.`id` AS `t_id`, `latest`.`name` AS `t_name`, `latest`.`price` AS `t_price`, @@ -729,8 +725,8 @@ def check_table_exists(table_name: exp.Table) -> bool: ELSE `t_test_valid_from` END AS `test_valid_from`, CASE - WHEN `test_updated_at` > `t_test_updated_at` - THEN `test_updated_at` + WHEN `joined`.`test_updated_at` > `joined`.`t_test_updated_at` + THEN `joined`.`test_updated_at` WHEN `joined`.`_exists` IS NULL THEN CAST('2020-01-01 00:00:00' AS TIMESTAMP) ELSE `t_test_valid_to` @@ -748,7 +744,7 @@ def check_table_exists(table_name: exp.Table) -> bool: CAST(NULL AS TIMESTAMP) AS `test_valid_to` FROM `joined` WHERE - `test_updated_at` > `t_test_updated_at` + `joined`.`test_updated_at` > `joined`.`t_test_updated_at` ) INSERT OVERWRITE TABLE `db`.`target` ( `id`, @@ -833,13 +829,16 @@ def test_wap_publish(make_mocked_engine_adapter: t.Callable, mocker: MockerFixtu ) -def test_create_table_iceberg(mocker: MockerFixture, make_mocked_engine_adapter: t.Callable): +@pytest.mark.parametrize("wap_enabled", [True, False]) +def test_create_table_iceberg( + mocker: MockerFixture, make_mocked_engine_adapter: t.Callable, wap_enabled: bool +): mocker.patch( "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.table_exists", return_value=False, ) - adapter = 
make_mocked_engine_adapter(SparkEngineAdapter) + adapter = make_mocked_engine_adapter(SparkEngineAdapter, wap_enabled=wap_enabled) columns_to_types = { "cola": exp.DataType.build("INT"), @@ -854,10 +853,13 @@ def test_create_table_iceberg(mocker: MockerFixture, make_mocked_engine_adapter: storage_format="ICEBERG", ) - assert to_sql_calls(adapter) == [ + expected_sql_calls = [ "CREATE TABLE IF NOT EXISTS `test_table` (`cola` INT, `colb` STRING, `colc` STRING) USING ICEBERG PARTITIONED BY (`colb`)", - "INSERT INTO `test_table` SELECT * FROM `test_table`", ] + if wap_enabled: + expected_sql_calls.append("INSERT INTO `test_table` SELECT * FROM `test_table`") + + assert to_sql_calls(adapter) == expected_sql_calls def test_comments_hive(mocker: MockerFixture, make_mocked_engine_adapter: t.Callable): @@ -981,7 +983,7 @@ def test_create_table_with_wap(make_mocked_engine_adapter: t.Callable, mocker: M "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.table_exists", return_value=False, ) - adapter = make_mocked_engine_adapter(SparkEngineAdapter) + adapter = make_mocked_engine_adapter(SparkEngineAdapter, wap_enabled=True) adapter.create_table( "catalog.schema.table.branch_wap_12345", @@ -1009,19 +1011,98 @@ def test_replace_query_with_wap_self_reference( ) adapter = make_mocked_engine_adapter(SparkEngineAdapter) + mocker.patch.object( + adapter, + "_get_data_objects", + return_value=[DataObject(schema="schema", name="table", type="table")], + ) adapter.replace_query( "catalog.schema.table.branch_wap_12345", parse_one("SELECT 1 as a FROM catalog.schema.table.branch_wap_12345"), - columns_to_types={"a": exp.DataType.build("INT")}, + target_columns_to_types={"a": exp.DataType.build("INT")}, storage_format="ICEBERG", ) sql_calls = to_sql_calls(adapter) assert sql_calls == [ - "CREATE TABLE IF NOT EXISTS `catalog`.`schema`.`table` (`a` INT)", + "CREATE TABLE IF NOT EXISTS `catalog`.`schema`.`table` (`a` INT) USING ICEBERG", "CREATE SCHEMA IF NOT EXISTS `catalog`.`schema`", 
"CREATE TABLE IF NOT EXISTS `catalog`.`schema`.`temp_branch_wap_12345_abcdefgh` USING ICEBERG AS SELECT CAST(`a` AS INT) AS `a` FROM (SELECT `a` FROM `catalog`.`schema`.`table`.`branch_wap_12345`) AS `_subquery`", "INSERT OVERWRITE TABLE `catalog`.`schema`.`table`.`branch_wap_12345` (`a`) SELECT 1 AS `a` FROM `catalog`.`schema`.`temp_branch_wap_12345_abcdefgh`", "DROP TABLE IF EXISTS `catalog`.`schema`.`temp_branch_wap_12345_abcdefgh`", ] + + +def test_table_format(adapter: SparkEngineAdapter, mocker: MockerFixture): + mocker.patch( + "sqlmesh.core.engine_adapter.spark.SparkEngineAdapter.table_exists", + return_value=True, + ) + + expressions = d.parse( + """ + MODEL ( + name test_table, + kind FULL, + table_format iceberg, + storage_format orc + ); + + SELECT 1::timestamp AS cola, 2::varchar as colb, 'foo' as colc; + """ + ) + model: SqlModel = t.cast(SqlModel, load_sql_based_model(expressions)) + + # both table_format and storage_format + adapter.create_table( + table_name=model.name, + target_columns_to_types=model.columns_to_types_or_raise, + table_format=model.table_format, + storage_format=model.storage_format, + ) + + # just table_format + adapter.create_table( + table_name=model.name, + target_columns_to_types=model.columns_to_types_or_raise, + table_format=model.table_format, + ) + + # just storage_format set to a table format (test for backwards compatibility) + adapter.create_table( + table_name=model.name, + target_columns_to_types=model.columns_to_types_or_raise, + storage_format=model.table_format, + ) + + adapter.ctas( + table_name=model.name, + query_or_df=model.query, + target_columns_to_types=model.columns_to_types_or_raise, + table_format=model.table_format, + storage_format=model.storage_format, + ) + + assert to_sql_calls(adapter) == [ + "CREATE TABLE IF NOT EXISTS `test_table` (`cola` TIMESTAMP, `colb` STRING, `colc` STRING) USING ICEBERG TBLPROPERTIES ('write.format.default'='orc')", + "CREATE TABLE IF NOT EXISTS `test_table` (`cola` 
TIMESTAMP, `colb` STRING, `colc` STRING) USING ICEBERG", + "CREATE TABLE IF NOT EXISTS `test_table` (`cola` TIMESTAMP, `colb` STRING, `colc` STRING) USING ICEBERG", + "CREATE TABLE IF NOT EXISTS `test_table` USING ICEBERG TBLPROPERTIES ('write.format.default'='orc') AS SELECT CAST(`cola` AS TIMESTAMP) AS `cola`, CAST(`colb` AS STRING) AS `colb`, CAST(`colc` AS STRING) AS `colc` FROM (SELECT CAST(1 AS TIMESTAMP) AS `cola`, CAST(2 AS STRING) AS `colb`, 'foo' AS `colc`) AS `_subquery`", + ] + + +def test_get_data_object_wap_branch(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + adapter = make_mocked_engine_adapter(SparkEngineAdapter, patch_get_data_objects=False) + mocker.patch.object(adapter, "_get_data_objects", return_value=[]) + + table = exp.to_table( + "`catalog`.`sqlmesh__test`.`test__test_view__630979748`.`branch_wap_472234d7`", + dialect="spark", + ) + adapter.get_data_object(table) + + adapter._get_data_objects.assert_called_once_with( + d.schema_("sqlmesh__test", "catalog"), + {"test__test_view__630979748"}, + ) diff --git a/tests/core/engine_adapter/test_trino.py b/tests/core/engine_adapter/test_trino.py index 0ad7662e3b..1bfe82b858 100644 --- a/tests/core/engine_adapter/test_trino.py +++ b/tests/core/engine_adapter/test_trino.py @@ -6,9 +6,13 @@ from sqlglot import exp, parse_one import sqlmesh.core.dialect as d +from sqlmesh.core.config.connection import TrinoConnectionConfig from sqlmesh.core.engine_adapter import TrinoEngineAdapter from sqlmesh.core.model import load_sql_based_model from sqlmesh.core.model.definition import SqlModel +from sqlmesh.core.dialect import schema_ +from sqlmesh.utils.date import to_ds +from sqlmesh.utils.errors import SQLMeshError from tests.core.engine_adapter import to_sql_calls pytestmark = [pytest.mark.engine, pytest.mark.trino] @@ -21,8 +25,8 @@ def trino_mocked_engine_adapter( def mock_catalog_type(catalog_name): if "iceberg" in catalog_name: return "iceberg" - if "delta" in catalog_name: - return 
"delta" + if "delta_lake" in catalog_name: + return "delta_lake" return "hive" mocker.patch( @@ -30,6 +34,11 @@ def mock_catalog_type(catalog_name): side_effect=mock_catalog_type, ) + mocker.patch( + "sqlmesh.core.engine_adapter.trino.TrinoEngineAdapter._block_until_table_exists", + return_value=True, + ) + return make_mocked_engine_adapter(TrinoEngineAdapter) @@ -42,9 +51,7 @@ def test_set_current_catalog(trino_mocked_engine_adapter: TrinoEngineAdapter): ] -@pytest.mark.trino_iceberg -@pytest.mark.trino_delta -@pytest.mark.parametrize("storage_type", ["iceberg", "delta"]) +@pytest.mark.parametrize("storage_type", ["iceberg", "delta_lake"]) def test_get_catalog_type( trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture, storage_type: str ): @@ -58,7 +65,7 @@ def test_get_catalog_type( assert adapter.get_catalog_type("foo") == TrinoEngineAdapter.DEFAULT_CATALOG_TYPE assert adapter.get_catalog_type("datalake_hive") == "hive" assert adapter.get_catalog_type("datalake_iceberg") == "iceberg" - assert adapter.get_catalog_type("datalake_delta") == "delta" + assert adapter.get_catalog_type("datalake_delta_lake") == "delta_lake" mocker.patch( "sqlmesh.core.engine_adapter.trino.TrinoEngineAdapter.get_current_catalog", @@ -97,8 +104,7 @@ def mock_fetchone(sql): assert fetchone_mock.call_count == 2 -@pytest.mark.trino_delta -@pytest.mark.parametrize("storage_type", ["hive", "delta"]) +@pytest.mark.parametrize("storage_type", ["hive", "delta_lake"]) def test_partitioned_by_hive_delta( trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture, storage_type: str ): @@ -125,7 +131,6 @@ def test_partitioned_by_hive_delta( ] -@pytest.mark.trino_iceberg def test_partitioned_by_iceberg( trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture ): @@ -152,7 +157,6 @@ def test_partitioned_by_iceberg( ] -@pytest.mark.trino_iceberg def test_partitioned_by_iceberg_transforms( trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture 
): @@ -180,7 +184,7 @@ def test_partitioned_by_iceberg_transforms( adapter.create_table( table_name=model.view_name, - columns_to_types=model.columns_to_types_or_raise, + target_columns_to_types=model.columns_to_types_or_raise, partitioned_by=model.partitioned_by, ) @@ -196,7 +200,6 @@ def test_partitioned_by_iceberg_transforms( ] -@pytest.mark.trino_iceberg def test_partitioned_by_with_multiple_catalogs_same_server( trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture ): @@ -301,7 +304,9 @@ def test_comments_hive(mocker: MockerFixture, make_mocked_engine_adapter: t.Call sql_calls = to_sql_calls(adapter) assert sql_calls == [ f"""CREATE TABLE IF NOT EXISTS "test_table" ("a" INTEGER COMMENT '{truncated_column_comment}', "b" INTEGER) COMMENT '{truncated_table_comment}'""", + 'DESCRIBE "test_table"', f"""CREATE TABLE IF NOT EXISTS "test_table" COMMENT '{truncated_table_comment}' AS SELECT CAST("a" AS INTEGER) AS "a", CAST("b" AS INTEGER) AS "b" FROM (SELECT "a", "b" FROM "source_table") AS "_subquery\"""", + 'DESCRIBE "test_table"', f"""COMMENT ON COLUMN "test_table"."a" IS '{truncated_column_comment}'""", """CREATE OR REPLACE VIEW test_view AS SELECT a, b FROM source_table""", f"""COMMENT ON VIEW "test_view" IS '{truncated_table_comment}'""", @@ -310,9 +315,7 @@ def test_comments_hive(mocker: MockerFixture, make_mocked_engine_adapter: t.Call ] -@pytest.mark.trino_iceberg -@pytest.mark.trino_delta -@pytest.mark.parametrize("storage_type", ["iceberg", "delta"]) +@pytest.mark.parametrize("storage_type", ["iceberg", "delta_lake"]) def test_comments_iceberg_delta( mocker: MockerFixture, make_mocked_engine_adapter: t.Callable, storage_type: str ): @@ -378,7 +381,6 @@ def test_comments_iceberg_delta( ] -@pytest.mark.trino_delta def test_delta_timestamps(make_mocked_engine_adapter: t.Callable): adapter = make_mocked_engine_adapter(TrinoEngineAdapter) @@ -400,3 +402,547 @@ def test_delta_timestamps(make_mocked_engine_adapter: t.Callable): "ts_tz": ts3_tz, 
"ts_tz_1": ts3_tz, } + + +def test_timestamp_mapping(): + """Test that timestamp_mapping config property is properly defined and accessible.""" + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + ) + + assert config._connection_factory_with_kwargs.keywords["source"] == "sqlmesh" + + adapter = config.create_engine_adapter() + assert adapter.timestamp_mapping is None + + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + source="my_source", + timestamp_mapping={ + "TIMESTAMP": "TIMESTAMP(6)", + "TIMESTAMP(3)": "TIMESTAMP WITH TIME ZONE", + }, + ) + assert config._connection_factory_with_kwargs.keywords["source"] == "my_source" + adapter = config.create_engine_adapter() + assert adapter.timestamp_mapping is not None + assert adapter.timestamp_mapping[exp.DataType.build("TIMESTAMP")] == exp.DataType.build( + "TIMESTAMP(6)" + ) + + +def test_delta_timestamps_with_custom_mapping(make_mocked_engine_adapter: t.Callable): + """Test that _apply_timestamp_mapping + _to_delta_ts respects custom timestamp_mapping.""" + # Create config with custom timestamp mapping + # Mapped columns are skipped by _to_delta_ts + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + timestamp_mapping={ + "TIMESTAMP": "TIMESTAMP(3)", + "TIMESTAMP(1)": "TIMESTAMP(3)", + "TIMESTAMP WITH TIME ZONE": "TIMESTAMP(6) WITH TIME ZONE", + "TIMESTAMP(1) WITH TIME ZONE": "TIMESTAMP(6) WITH TIME ZONE", + }, + ) + + adapter = make_mocked_engine_adapter( + TrinoEngineAdapter, timestamp_mapping=config.timestamp_mapping + ) + + ts3 = exp.DataType.build("timestamp(3)") + ts6_tz = exp.DataType.build("timestamp(6) with time zone") + + columns_to_types = { + "ts": exp.DataType.build("TIMESTAMP"), + "ts_1": exp.DataType.build("TIMESTAMP(1)"), + "ts_tz": exp.DataType.build("TIMESTAMP WITH TIME ZONE"), + "ts_tz_1": exp.DataType.build("TIMESTAMP(1) WITH TIME ZONE"), + } + + # Apply mapping first, then convert to delta 
types (skipping mapped columns) + mapped_columns_to_types, mapped_column_names = adapter._apply_timestamp_mapping( + columns_to_types + ) + delta_columns_to_types = adapter._to_delta_ts(mapped_columns_to_types, mapped_column_names) + + # All types were mapped, so _to_delta_ts skips them - they keep their mapped types + assert delta_columns_to_types == { + "ts": ts3, + "ts_1": ts3, + "ts_tz": ts6_tz, + "ts_tz_1": ts6_tz, + } + + +def test_delta_timestamps_with_partial_mapping(make_mocked_engine_adapter: t.Callable): + """Test that _apply_timestamp_mapping + _to_delta_ts uses custom mapping for specified types.""" + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + timestamp_mapping={ + "TIMESTAMP": "TIMESTAMP(3)", + }, + ) + + adapter = make_mocked_engine_adapter( + TrinoEngineAdapter, timestamp_mapping=config.timestamp_mapping + ) + + ts3 = exp.DataType.build("TIMESTAMP(3)") + ts6 = exp.DataType.build("timestamp(6)") + ts3_tz = exp.DataType.build("timestamp(3) with time zone") + + columns_to_types = { + "ts": exp.DataType.build("TIMESTAMP"), + "ts_1": exp.DataType.build("TIMESTAMP(1)"), + "ts_tz": exp.DataType.build("TIMESTAMP WITH TIME ZONE"), + } + + # Apply mapping first, then convert to delta types (skipping mapped columns) + mapped_columns_to_types, mapped_column_names = adapter._apply_timestamp_mapping( + columns_to_types + ) + delta_columns_to_types = adapter._to_delta_ts(mapped_columns_to_types, mapped_column_names) + + # TIMESTAMP is in mapping → TIMESTAMP(3), skipped by _to_delta_ts + # TIMESTAMP(1) is NOT in mapping, uses default TIMESTAMP → ts6 + # TIMESTAMP WITH TIME ZONE is NOT in mapping, uses default TIMESTAMPTZ → ts3_tz + assert delta_columns_to_types == { + "ts": ts3, # Mapped to TIMESTAMP(3), skipped by _to_delta_ts + "ts_1": ts6, # Not in mapping, uses default + "ts_tz": ts3_tz, # Not in mapping, uses default + } + + +def test_table_format(trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture): 
+ adapter = trino_mocked_engine_adapter + mocker.patch( + "sqlmesh.core.engine_adapter.trino.TrinoEngineAdapter.get_current_catalog", + return_value="iceberg", + ) + + expressions = d.parse( + """ + MODEL ( + name iceberg.test_table, + kind FULL, + table_format iceberg, + storage_format orc + ); + + SELECT 1::timestamp AS cola, 2::varchar as colb, 'foo' as colc; + """ + ) + model: SqlModel = t.cast(SqlModel, load_sql_based_model(expressions)) + + adapter.create_table( + table_name=model.name, + target_columns_to_types=model.columns_to_types_or_raise, + table_format=model.table_format, + storage_format=model.storage_format, + ) + + adapter.ctas( + table_name=model.name, + query_or_df=t.cast(exp.Query, model.query), + target_columns_to_types=model.columns_to_types_or_raise, + table_format=model.table_format, + storage_format=model.storage_format, + ) + + # Trino needs to ignore the `table_format` property because to create Iceberg tables, you target an Iceberg catalog + # rather than explicitly telling it to create an Iceberg table. 
So this is testing that `FORMAT='ORC'` is output + # instead of `FORMAT='ICEBERG'` which would be invalid + assert to_sql_calls(adapter) == [ + """CREATE TABLE IF NOT EXISTS "iceberg"."test_table" ("cola" TIMESTAMP, "colb" VARCHAR, "colc" VARCHAR) WITH (format='orc')""", + '''CREATE TABLE IF NOT EXISTS "iceberg"."test_table" WITH (format='orc') AS SELECT CAST("cola" AS TIMESTAMP) AS "cola", CAST("colb" AS VARCHAR) AS "colb", CAST("colc" AS VARCHAR) AS "colc" FROM (SELECT CAST(1 AS TIMESTAMP) AS "cola", CAST(2 AS VARCHAR) AS "colb", \'foo\' AS "colc") AS "_subquery"''', + ] + + +def test_table_location(trino_mocked_engine_adapter: TrinoEngineAdapter, mocker: MockerFixture): + adapter = trino_mocked_engine_adapter + mocker.patch( + "sqlmesh.core.engine_adapter.trino.TrinoEngineAdapter.get_current_catalog", + return_value="iceberg", + ) + + expressions = d.parse( + """ + MODEL ( + name iceberg.test_table, + kind FULL, + physical_properties ( + location = 'hdfs://some/table/location' + ) + ); + + SELECT 1::timestamp AS cola, 2::varchar as colb, 'foo' as colc; + """ + ) + model: SqlModel = t.cast(SqlModel, load_sql_based_model(expressions)) + + adapter.create_table( + table_name=model.name, + target_columns_to_types=model.columns_to_types_or_raise, + table_properties=model.physical_properties, + ) + + adapter.ctas( + table_name=model.name, + query_or_df=t.cast(exp.Query, model.query), + target_columns_to_types=model.columns_to_types_or_raise, + table_properties=model.physical_properties, + ) + + assert to_sql_calls(adapter) == [ + 'CREATE TABLE IF NOT EXISTS "iceberg"."test_table" ("cola" TIMESTAMP, "colb" VARCHAR, "colc" VARCHAR) WITH (location=\'hdfs://some/table/location\')', + 'CREATE TABLE IF NOT EXISTS "iceberg"."test_table" WITH (location=\'hdfs://some/table/location\') AS SELECT CAST("cola" AS TIMESTAMP) AS "cola", CAST("colb" AS VARCHAR) AS "colb", CAST("colc" AS VARCHAR) AS "colc" FROM (SELECT CAST(1 AS TIMESTAMP) AS "cola", CAST(2 AS VARCHAR) AS "colb", 
\'foo\' AS "colc") AS "_subquery"', + ] + + +def test_schema_location_mapping(): + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + ) + + adapter = config.create_engine_adapter() + assert adapter.schema_location_mapping is None + assert adapter._schema_location("foo") is None + + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + schema_location_mapping={ + "^utils$": "s3://utils-bucket/@{schema_name}", + "^landing\\..*$": "s3://raw-data/@{catalog_name}/@{schema_name}", + "^staging.*$": "s3://bucket/@{schema_name}_dev", + "^sqlmesh.*$": "s3://sqlmesh-internal/dev/@{schema_name}", + }, + ) + adapter = config.create_engine_adapter() + assert adapter.schema_location_mapping is not None + assert adapter._schema_location("foo") is None + assert adapter._schema_location("utils_dev") is None + assert adapter._schema_location("utils") == "s3://utils-bucket/utils" + assert adapter._schema_location("staging_customers") == "s3://bucket/staging_customers_dev" + assert adapter._schema_location("staging_accounts") == "s3://bucket/staging_accounts_dev" + assert ( + adapter._schema_location("sqlmesh__staging_customers") + == "s3://sqlmesh-internal/dev/sqlmesh__staging_customers" + ) + assert ( + adapter._schema_location("sqlmesh__staging_utils") + == "s3://sqlmesh-internal/dev/sqlmesh__staging_utils" + ) + assert adapter._schema_location("landing.transactions") == "s3://raw-data/landing/transactions" + assert ( + adapter._schema_location(schema_("transactions", "landing")) + == "s3://raw-data/landing/transactions" + ) + assert ( + adapter._schema_location('"landing"."transactions"') == "s3://raw-data/landing/transactions" + ) + + +def test_create_schema_sets_location(make_mocked_engine_adapter: t.Callable, mocker: MockerFixture): + mocker.patch( + "sqlmesh.core.engine_adapter.trino.TrinoEngineAdapter.get_catalog_type", + return_value="iceberg", + ) + + mocker.patch( + 
"sqlmesh.core.engine_adapter.trino.TrinoEngineAdapter._block_until_table_exists", + return_value=True, + ) + + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + schema_location_mapping={ + "^utils$": "s3://utils-bucket/@{schema_name}", + "^landing\\..*$": "s3://raw-data/@{catalog_name}/@{schema_name}", + "^staging.*$": "s3://bucket/@{schema_name}_dev", + "^sqlmesh.*$": "s3://sqlmesh-internal/dev/@{schema_name}", + "^iceberg\\.staging.*$": "s3://iceberg-catalog/foo_@{schema_name}", + }, + ) + + adapter: TrinoEngineAdapter = make_mocked_engine_adapter( + TrinoEngineAdapter, schema_location_mapping=config.schema_location_mapping + ) + + adapter.create_schema("foo") + adapter.create_schema(schema_("utils_dev", "db")) + adapter.create_schema(schema_("utils", "db")) + adapter.create_schema(schema_("utils")) + adapter.create_schema(schema_("sqlmesh")) + adapter.create_schema("sqlmesh__staging") + adapter.create_schema(schema_("snapshots", "sqlmesh")) + adapter.create_schema(schema_("staging_foo")) + adapter.create_schema(schema_("staging_bar", "iceberg")) + adapter.create_schema('"catalog"."staging_customers"') + adapter.create_schema(schema_("transactions", "landing")) + + assert ( + to_sql_calls(adapter) + == [ + 'CREATE SCHEMA IF NOT EXISTS "foo"', # no match + 'CREATE SCHEMA IF NOT EXISTS "db"."utils_dev"', # no match + 'CREATE SCHEMA IF NOT EXISTS "db"."utils"', # no match on '^utils$' because of catalog + "CREATE SCHEMA IF NOT EXISTS \"utils\" WITH (LOCATION='s3://utils-bucket/utils')", # match '^utils$' + "CREATE SCHEMA IF NOT EXISTS \"sqlmesh\" WITH (LOCATION='s3://sqlmesh-internal/dev/sqlmesh')", # match '^sqlmesh.*$' + "CREATE SCHEMA IF NOT EXISTS \"sqlmesh__staging\" WITH (LOCATION='s3://sqlmesh-internal/dev/sqlmesh__staging')", # match '^sqlmesh.*$' + 'CREATE SCHEMA IF NOT EXISTS "sqlmesh"."snapshots" WITH (LOCATION=\'s3://sqlmesh-internal/dev/snapshots\')', # match '^sqlmesh.*$' on the catalog + "CREATE SCHEMA IF NOT EXISTS 
\"staging_foo\" WITH (LOCATION='s3://bucket/staging_foo_dev')", # match '^staging.*$' + 'CREATE SCHEMA IF NOT EXISTS "iceberg"."staging_bar" WITH (LOCATION=\'s3://iceberg-catalog/foo_staging_bar\')', # match '^iceberg\.staging.*$' + 'CREATE SCHEMA IF NOT EXISTS "catalog"."staging_customers"', # no match + 'CREATE SCHEMA IF NOT EXISTS "landing"."transactions" WITH (LOCATION=\'s3://raw-data/landing/transactions\')', # match '^landing\..*$' + ] + ) + + +def test_session_authorization(trino_mocked_engine_adapter: TrinoEngineAdapter): + adapter = trino_mocked_engine_adapter + + # Test 1: No authorization property - should not execute any authorization commands + with adapter.session({}): + pass + + assert to_sql_calls(adapter) == [] + + # Test 2: String authorization + with adapter.session({"authorization": "test_user"}): + adapter.execute("SELECT 1") + + assert to_sql_calls(adapter) == [ + "SET SESSION AUTHORIZATION 'test_user'", + "SELECT 1", + "RESET SESSION AUTHORIZATION", + ] + + # Test 3: Expression authorization + adapter.cursor.execute.reset_mock() + with adapter.session({"authorization": exp.Literal.string("another_user")}): + adapter.execute("SELECT 2") + + assert to_sql_calls(adapter) == [ + "SET SESSION AUTHORIZATION 'another_user'", + "SELECT 2", + "RESET SESSION AUTHORIZATION", + ] + + # Test 4: RESET is called even if exception occurs during session + adapter.cursor.execute.reset_mock() + try: + with adapter.session({"authorization": "test_user"}): + adapter.execute("SELECT 1") + raise RuntimeError("Test exception") + except RuntimeError: + pass + + # Test 5: Invalid authorization value + with pytest.raises( + SQLMeshError, + match="Invalid value for `session_properties.authorization`. 
Must be a string literal.", + ): + with adapter.session({"authorization": exp.Literal.number(1)}): + adapter.execute("SELECT 1") + + assert to_sql_calls(adapter) == [ + "SET SESSION AUTHORIZATION 'test_user'", + "SELECT 1", + "RESET SESSION AUTHORIZATION", + ] + + +@pytest.mark.parametrize( + "catalog_name,expected_replace", + [ + ("hive_catalog", False), + ("iceberg_catalog", True), + ("delta_catalog", False), + ("acme_delta_lake", True), + ("acme_iceberg", True), + ("custom_delta_lake_something", True), + ("my_iceberg_store", True), + ("plain_catalog", False), + ], +) +def test_replace_table_catalog_support( + trino_mocked_engine_adapter: TrinoEngineAdapter, catalog_name, expected_replace +): + adapter = trino_mocked_engine_adapter + + adapter.replace_query( + table_name=".".join([catalog_name, "schema", "test_table"]), + query_or_df=t.cast(exp.Query, parse_one("SELECT 1 AS col")), + ) + + sql_calls = to_sql_calls(adapter) + assert len(sql_calls) == 1 + if expected_replace: + assert ( + sql_calls[0] + == f'CREATE OR REPLACE TABLE "{catalog_name}"."schema"."test_table" AS SELECT 1 AS "col"' + ) + else: + assert ( + sql_calls[0] + == f'CREATE TABLE IF NOT EXISTS "{catalog_name}"."schema"."test_table" AS SELECT 1 AS "col"' + ) + + +@pytest.mark.parametrize( + "catalog_type_overrides", [{}, {"my_catalog": "hive"}, {"other_catalog": "iceberg"}] +) +def test_insert_overwrite_time_partition_hive( + make_mocked_engine_adapter: t.Callable, catalog_type_overrides: t.Dict[str, str] +): + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + catalog_type_overrides=catalog_type_overrides, + ) + adapter: TrinoEngineAdapter = make_mocked_engine_adapter( + TrinoEngineAdapter, catalog_type_overrides=config.catalog_type_overrides + ) + adapter.fetchone = MagicMock(return_value=None) # type: ignore + + adapter.insert_overwrite_by_time_partition( + table_name=".".join(["my_catalog", "schema", "test_table"]), + query_or_df=t.cast(exp.Query, 
parse_one("SELECT a, b FROM tbl")), + start="2022-01-01", + end="2022-01-02", + time_column="b", + time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, + ) + + assert to_sql_calls(adapter) == [ + "SET SESSION my_catalog.insert_existing_partitions_behavior='OVERWRITE'", + 'INSERT INTO "my_catalog"."schema"."test_table" ("a", "b") SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'', + "SET SESSION my_catalog.insert_existing_partitions_behavior='APPEND'", + ] + + +@pytest.mark.parametrize( + "catalog_type_overrides", + [ + {"my_catalog": "iceberg"}, + {"my_catalog": "unknown"}, + ], +) +def test_insert_overwrite_time_partition_iceberg( + make_mocked_engine_adapter: t.Callable, catalog_type_overrides: t.Dict[str, str] +): + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + catalog_type_overrides=catalog_type_overrides, + ) + adapter: TrinoEngineAdapter = make_mocked_engine_adapter( + TrinoEngineAdapter, catalog_type_overrides=config.catalog_type_overrides + ) + adapter.fetchone = MagicMock(return_value=None) # type: ignore + + adapter.insert_overwrite_by_time_partition( + table_name=".".join(["my_catalog", "schema", "test_table"]), + query_or_df=t.cast(exp.Query, parse_one("SELECT a, b FROM tbl")), + start="2022-01-01", + end="2022-01-02", + time_column="b", + time_formatter=lambda x, _: exp.Literal.string(to_ds(x)), + target_columns_to_types={"a": exp.DataType.build("INT"), "b": exp.DataType.build("STRING")}, + ) + + assert to_sql_calls(adapter) == [ + 'DELETE FROM "my_catalog"."schema"."test_table" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'', + 'INSERT INTO "my_catalog"."schema"."test_table" ("a", "b") SELECT "a", "b" FROM (SELECT "a", "b" FROM "tbl") AS "_subquery" WHERE "b" BETWEEN \'2022-01-01\' AND \'2022-01-02\'', + ] + + +def 
test_delta_timestamps_with_non_timestamp_columns(make_mocked_engine_adapter: t.Callable): + """Test that _apply_timestamp_mapping + _to_delta_ts handles non-timestamp columns.""" + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + timestamp_mapping={ + "TIMESTAMP": "TIMESTAMP(3)", + }, + ) + + adapter = make_mocked_engine_adapter( + TrinoEngineAdapter, timestamp_mapping=config.timestamp_mapping + ) + + ts3 = exp.DataType.build("TIMESTAMP(3)") + ts6 = exp.DataType.build("timestamp(6)") + + columns_to_types = { + "ts": exp.DataType.build("TIMESTAMP"), + "ts_1": exp.DataType.build("TIMESTAMP(1)"), + "int_col": exp.DataType.build("INT"), + "varchar_col": exp.DataType.build("VARCHAR(100)"), + "decimal_col": exp.DataType.build("DECIMAL(10,2)"), + } + + # Apply mapping first, then convert to delta types (skipping mapped columns) + mapped_columns_to_types, mapped_column_names = adapter._apply_timestamp_mapping( + columns_to_types + ) + delta_columns_to_types = adapter._to_delta_ts(mapped_columns_to_types, mapped_column_names) + + # TIMESTAMP is in mapping → TIMESTAMP(3), skipped by _to_delta_ts + # TIMESTAMP(1) is NOT in mapping (exact match), uses default TIMESTAMP → ts6 + # Non-timestamp columns should pass through unchanged + assert delta_columns_to_types == { + "ts": ts3, # Mapped to TIMESTAMP(3), skipped by _to_delta_ts + "ts_1": ts6, # Not in mapping, uses default + "int_col": exp.DataType.build("INT"), + "varchar_col": exp.DataType.build("VARCHAR(100)"), + "decimal_col": exp.DataType.build("DECIMAL(10,2)"), + } + + +def test_delta_timestamps_with_empty_mapping(make_mocked_engine_adapter: t.Callable): + """Test that _to_delta_ts handles empty custom mapping dictionary.""" + config = TrinoConnectionConfig( + user="user", + host="host", + catalog="catalog", + timestamp_mapping={}, + ) + + adapter = make_mocked_engine_adapter( + TrinoEngineAdapter, timestamp_mapping=config.timestamp_mapping + ) + + ts6 = 
exp.DataType.build("timestamp(6)") + ts3_tz = exp.DataType.build("timestamp(3) with time zone") + + columns_to_types = { + "ts": exp.DataType.build("TIMESTAMP"), + "ts_tz": exp.DataType.build("TIMESTAMP WITH TIME ZONE"), + } + + delta_columns_to_types = adapter._to_delta_ts(columns_to_types) + + # With empty custom mapping, should fall back to defaults + assert delta_columns_to_types == { + "ts": ts6, + "ts_tz": ts3_tz, + } diff --git a/tests/core/engine_adapter/trino/catalog/datalake.properties b/tests/core/engine_adapter/trino/catalog/datalake.properties deleted file mode 100644 index d960895a66..0000000000 --- a/tests/core/engine_adapter/trino/catalog/datalake.properties +++ /dev/null @@ -1,15 +0,0 @@ -connector.name=hive -hive.metastore.uri=thrift://trino-datalake-hive-metastore:9083 -hive.s3.endpoint=http://minio:9000 -hive.s3.path-style-access=true -hive.s3.aws-access-key=minio -hive.s3.aws-secret-key=minio123 -hive.metastore-cache-ttl=0s -hive.metastore-refresh-interval=5s -hive.metastore-timeout=10s -hive.allow-drop-table=true -hive.allow-add-column=true -hive.allow-drop-column=true -hive.allow-rename-column=true -hive.allow-rename-table=true -hive.storage-format=PARQUET diff --git a/tests/core/engine_adapter/trino/catalog/datalake_delta.properties b/tests/core/engine_adapter/trino/catalog/datalake_delta.properties deleted file mode 100644 index a2ee5c99fe..0000000000 --- a/tests/core/engine_adapter/trino/catalog/datalake_delta.properties +++ /dev/null @@ -1,11 +0,0 @@ -connector.name=delta_lake -hive.metastore.uri=thrift://trino-datalake-delta-hive-metastore:9083 -hive.metastore-cache-ttl=0s -hive.metastore-refresh-interval=5s -hive.metastore-timeout=10s -hive.s3.endpoint=http://minio:9000 -hive.s3.aws-access-key=minio -hive.s3.aws-secret-key=minio123 -hive.s3.path-style-access=true -delta.enable-non-concurrent-writes=true -delta.hive-catalog-name=datalake_delta \ No newline at end of file diff --git 
a/tests/core/engine_adapter/trino/catalog/datalake_iceberg.properties b/tests/core/engine_adapter/trino/catalog/datalake_iceberg.properties deleted file mode 100644 index caee1604cc..0000000000 --- a/tests/core/engine_adapter/trino/catalog/datalake_iceberg.properties +++ /dev/null @@ -1,13 +0,0 @@ -connector.name=iceberg -# note: we have to use a Hive metastore instead of the REST catalog because -# as at 2024-02-16 its the only one that supports views -iceberg.catalog.type=hive_metastore -iceberg.file-format=PARQUET -hive.metastore.uri=thrift://trino-datalake-iceberg-hive-metastore:9083 -hive.metastore-cache-ttl=0s -hive.metastore-refresh-interval=5s -hive.metastore-timeout=10s -hive.s3.endpoint=http://minio:9000 -hive.s3.path-style-access=true -hive.s3.aws-access-key=minio -hive.s3.aws-secret-key=minio123 diff --git a/tests/core/engine_adapter/trino/catalog/testing.properties b/tests/core/engine_adapter/trino/catalog/testing.properties deleted file mode 100644 index bdc39d6e4b..0000000000 --- a/tests/core/engine_adapter/trino/catalog/testing.properties +++ /dev/null @@ -1,15 +0,0 @@ -connector.name=hive -hive.metastore.uri=thrift://trino-testing-hive-metastore:9083 -hive.s3.endpoint=http://minio:9000 -hive.s3.path-style-access=true -hive.s3.aws-access-key=minio -hive.s3.aws-secret-key=minio123 -hive.metastore-cache-ttl=0s -hive.metastore-refresh-interval=5s -hive.metastore-timeout=10s -hive.allow-drop-table=true -hive.allow-add-column=true -hive.allow-drop-column=true -hive.allow-rename-column=true -hive.allow-rename-table=true -hive.storage-format=PARQUET diff --git a/tests/core/engine_adapter/trino/initdb.sql b/tests/core/engine_adapter/trino/initdb.sql deleted file mode 100644 index ae71ef668b..0000000000 --- a/tests/core/engine_adapter/trino/initdb.sql +++ /dev/null @@ -1,6 +0,0 @@ -create database datalake_metastore; -create database datalake_iceberg_metastore; -create database datalake_delta_metastore; -create database testing_metastore; -create database 
testing_iceberg_metastore; -create database testing_delta_metastore; \ No newline at end of file diff --git a/sqlmesh/schedulers/airflow/operators/__init__.py b/tests/core/integration/__init__.py similarity index 100% rename from sqlmesh/schedulers/airflow/operators/__init__.py rename to tests/core/integration/__init__.py diff --git a/tests/core/integration/conftest.py b/tests/core/integration/conftest.py new file mode 100644 index 0000000000..99875e5974 --- /dev/null +++ b/tests/core/integration/conftest.py @@ -0,0 +1,8 @@ +import pytest +from pytest_mock.plugin import MockerFixture + + +@pytest.fixture(autouse=True) +def mock_choices(mocker: MockerFixture): + mocker.patch("sqlmesh.core.console.TerminalConsole._get_snapshot_change_category") + mocker.patch("sqlmesh.core.console.TerminalConsole._prompt_backfill") diff --git a/tests/core/integration/test_audits.py b/tests/core/integration/test_audits.py new file mode 100644 index 0000000000..457974fdac --- /dev/null +++ b/tests/core/integration/test_audits.py @@ -0,0 +1,348 @@ +from __future__ import annotations + +import typing as t +from textwrap import dedent +import pytest +from pathlib import Path +import time_machine +from sqlglot import exp +from IPython.utils.capture import capture_output + +from sqlmesh.core.config import ( + Config, + ModelDefaultsConfig, +) +from sqlmesh.core.context import Context +from sqlmesh.utils.errors import ( + PlanError, +) +from tests.utils.test_helpers import use_terminal_console +from tests.utils.test_filesystem import create_temp_file + +pytestmark = pytest.mark.slow + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +@use_terminal_console +def test_audit_only_metadata_change(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + # Add a new audit + model = context.get_model("sushi.waiter_revenue_by_day") + audits = model.audits.copy() + audits.append(("number_of_rows", {"threshold": exp.Literal.number(1)})) + 
model = model.copy(update={"audits": audits}) + context.upsert_model(model) + + plan = context.plan_builder("prod", skip_tests=True).build() + assert len(plan.new_snapshots) == 2 + assert all(s.change_category.is_metadata for s in plan.new_snapshots) + assert not plan.missing_intervals + + with capture_output() as output: + context.apply(plan) + + assert "Auditing models" in output.stdout + assert model.name in output.stdout + + +@use_terminal_console +def test_audits_running_on_metadata_changes(tmp_path: Path): + def setup_senario(model_before: str, model_after: str): + models_dir = Path("models") + create_temp_file(tmp_path, models_dir / "test.sql", model_before) + + # Create first snapshot + context = Context(paths=tmp_path, config=Config()) + context.plan("prod", no_prompts=True, auto_apply=True) + + # Create second (metadata) snapshot + create_temp_file(tmp_path, models_dir / "test.sql", model_after) + context.load() + + with capture_output() as output: + with pytest.raises(PlanError): + context.plan("prod", no_prompts=True, auto_apply=True) + + assert 'Failed models\n\n "model"' in output.stdout + + return output + + # Ensure incorrect audits (bad data, incorrect definition etc) are evaluated immediately + output = setup_senario( + "MODEL (name model); SELECT NULL AS col", + "MODEL (name model, audits (not_null(columns=[col]))); SELECT NULL AS col", + ) + assert "'not_null' audit error: 1 row failed" in output.stdout + + output = setup_senario( + "MODEL (name model); SELECT NULL AS col", + "MODEL (name model, audits (not_null(columns=[this_col_does_not_exist]))); SELECT NULL AS col", + ) + assert ( + 'Binder Error: Referenced column "this_col_does_not_exist" not found in \nFROM clause!' 
+ in output.stdout + ) + + +@pytest.mark.slow +def test_default_audits_applied_in_plan(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir(exist_ok=True) + + # Create a model with data that will pass the audits + create_temp_file( + tmp_path, + models_dir / "orders.sql", + dedent(""" + MODEL ( + name test.orders, + kind FULL + ); + + SELECT + 1 AS order_id, + 'customer_1' AS customer_id, + 100.50 AS amount, + '2024-01-01'::DATE AS order_date + UNION ALL + SELECT + 2 AS order_id, + 'customer_2' AS customer_id, + 200.75 AS amount, + '2024-01-02'::DATE AS order_date + """), + ) + + config = Config( + model_defaults=ModelDefaultsConfig( + dialect="duckdb", + audits=[ + "not_null(columns := [order_id, customer_id])", + "unique_values(columns := [order_id])", + ], + ) + ) + + context = Context(paths=tmp_path, config=config) + + # Create and apply plan, here audits should pass + plan = context.plan("prod", no_prompts=True) + context.apply(plan) + + # Verify model has the default audits + model = context.get_model("test.orders") + assert len(model.audits) == 2 + + audit_names = [audit[0] for audit in model.audits] + assert "not_null" in audit_names + assert "unique_values" in audit_names + + # Verify audit arguments are preserved + for audit_name, audit_args in model.audits: + if audit_name == "not_null": + assert "columns" in audit_args + columns = [col.name for col in audit_args["columns"].expressions] + assert "order_id" in columns + assert "customer_id" in columns + elif audit_name == "unique_values": + assert "columns" in audit_args + columns = [col.name for col in audit_args["columns"].expressions] + assert "order_id" in columns + + +@pytest.mark.slow +def test_default_audits_fail_on_bad_data(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir(exist_ok=True) + + # Create a model with data that violates NOT NULL constraint + create_temp_file( + tmp_path, + models_dir / "bad_orders.sql", + dedent(""" + MODEL ( + name 
test.bad_orders, + kind FULL + ); + + SELECT + 1 AS order_id, + NULL AS customer_id, -- This violates NOT NULL + 100.50 AS amount, + '2024-01-01'::DATE AS order_date + UNION ALL + SELECT + 2 AS order_id, + 'customer_2' AS customer_id, + 200.75 AS amount, + '2024-01-02'::DATE AS order_date + """), + ) + + config = Config( + model_defaults=ModelDefaultsConfig( + dialect="duckdb", audits=["not_null(columns := [customer_id])"] + ) + ) + + context = Context(paths=tmp_path, config=config) + + # Plan should fail due to audit failure + with pytest.raises(PlanError): + context.plan("prod", no_prompts=True, auto_apply=True) + + +@pytest.mark.slow +def test_default_audits_with_model_specific_audits(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir(exist_ok=True) + audits_dir = tmp_path / "audits" + audits_dir.mkdir(exist_ok=True) + + create_temp_file( + tmp_path, + audits_dir / "range_check.sql", + dedent(""" + AUDIT ( + name range_check + ); + + SELECT * FROM @this_model + WHERE @column < @min_value OR @column > @max_value + """), + ) + + # Create a model with its own audits in addition to defaults + create_temp_file( + tmp_path, + models_dir / "products.sql", + dedent(""" + MODEL ( + name test.products, + kind FULL, + audits ( + range_check(column := price, min_value := 0, max_value := 10000) + ) + ); + + SELECT + 1 AS product_id, + 'Widget' AS product_name, + 99.99 AS price + UNION ALL + SELECT + 2 AS product_id, + 'Gadget' AS product_name, + 149.99 AS price + """), + ) + + config = Config( + model_defaults=ModelDefaultsConfig( + dialect="duckdb", + audits=[ + "not_null(columns := [product_id, product_name])", + "unique_values(columns := [product_id])", + ], + ) + ) + + context = Context(paths=tmp_path, config=config) + + # Create and apply plan + plan = context.plan("prod", no_prompts=True) + context.apply(plan) + + # Verify model has both default and model-specific audits + model = context.get_model("test.products") + assert len(model.audits) == 3 + 
+ audit_names = [audit[0] for audit in model.audits] + assert "not_null" in audit_names + assert "unique_values" in audit_names + assert "range_check" in audit_names + + # Verify audit execution order, default audits first then model-specific + assert model.audits[0][0] == "not_null" + assert model.audits[1][0] == "unique_values" + assert model.audits[2][0] == "range_check" + + +@pytest.mark.slow +def test_default_audits_with_custom_audit_definitions(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir(exist_ok=True) + audits_dir = tmp_path / "audits" + audits_dir.mkdir(exist_ok=True) + + # Create custom audit definition + create_temp_file( + tmp_path, + audits_dir / "positive_amount.sql", + dedent(""" + AUDIT ( + name positive_amount + ); + + SELECT * FROM @this_model + WHERE @column <= 0 + """), + ) + + # Create a model + create_temp_file( + tmp_path, + models_dir / "transactions.sql", + dedent(""" + MODEL ( + name test.transactions, + kind FULL + ); + + SELECT + 1 AS transaction_id, + 'TXN001' AS transaction_code, + 250.00 AS amount, + '2024-01-01'::DATE AS transaction_date + UNION ALL + SELECT + 2 AS transaction_id, + 'TXN002' AS transaction_code, + 150.00 AS amount, + '2024-01-02'::DATE AS transaction_date + """), + ) + + config = Config( + model_defaults=ModelDefaultsConfig( + dialect="duckdb", + audits=[ + "not_null(columns := [transaction_id, transaction_code])", + "unique_values(columns := [transaction_id])", + "positive_amount(column := amount)", + ], + ) + ) + + context = Context(paths=tmp_path, config=config) + + # Create and apply plan + plan = context.plan("prod", no_prompts=True) + context.apply(plan) + + # Verify model has all default audits including custom + model = context.get_model("test.transactions") + assert len(model.audits) == 3 + + audit_names = [audit[0] for audit in model.audits] + assert "not_null" in audit_names + assert "unique_values" in audit_names + assert "positive_amount" in audit_names + + # Verify custom audit 
arguments + for audit_name, audit_args in model.audits: + if audit_name == "positive_amount": + assert "column" in audit_args + assert audit_args["column"].name == "amount" diff --git a/tests/core/integration/test_auto_restatement.py b/tests/core/integration/test_auto_restatement.py new file mode 100644 index 0000000000..70ca227fd3 --- /dev/null +++ b/tests/core/integration/test_auto_restatement.py @@ -0,0 +1,219 @@ +from __future__ import annotations + +import typing as t +import pandas as pd # noqa: TID253 +import pytest +import time_machine +from sqlglot import exp + +from sqlmesh.core import dialect as d +from sqlmesh.core.macros import macro +from sqlmesh.core.model import ( + load_sql_based_model, +) +from sqlmesh.core.plan import SnapshotIntervals +from sqlmesh.utils.date import to_timestamp + +pytestmark = pytest.mark.slow + + +@time_machine.travel("2023-01-08 01:00:00 UTC") +def test_run_auto_restatement(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + context.engine_adapter.execute( + "CREATE TABLE _test_auto_restatement_intervals (name STRING, start_ds STRING, end_ds STRING)" + ) + + @macro() + def record_intervals( + evaluator, name: exp.Expression, start: exp.Expression, end: exp.Expression, **kwargs: t.Any + ) -> None: + if evaluator.runtime_stage == "evaluating": + evaluator.engine_adapter.insert_append( + "_test_auto_restatement_intervals", + pd.DataFrame({"name": [name.name], "start_ds": [start.name], "end_ds": [end.name]}), + ) + + new_model_expr = d.parse( + """ + MODEL ( + name memory.sushi.new_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + auto_restatement_cron '0 6 * * 7', -- At 6am every Sunday + auto_restatement_intervals 3, + ), + start '2023-01-01', + ); + + @record_intervals('new_model', @start_ds, @end_ds); + + SELECT '2023-01-07' AS ds, 1 AS a; + """ + ) + new_model = load_sql_based_model(new_model_expr) + context.upsert_model(new_model) + + new_model_downstream_expr = 
d.parse( + """ + MODEL ( + name memory.sushi.new_model_downstream, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + ), + cron '@hourly', + ); + + @record_intervals('new_model_downstream', @start_ts, @end_ts); + + SELECT * FROM memory.sushi.new_model; + """ + ) + new_model_downstream = load_sql_based_model(new_model_downstream_expr) + context.upsert_model(new_model_downstream) + + plan = context.plan_builder("prod").build() + context.apply(plan) + + with time_machine.travel("2023-01-08 06:01:00 UTC"): + assert context.run() + + recorded_intervals_df = context.engine_adapter.fetchdf( + "SELECT start_ds, end_ds FROM _test_auto_restatement_intervals WHERE name = 'new_model'" + ) + # The first interval is the first backfill and the second interval should be the 3 auto restated intervals + assert recorded_intervals_df.to_dict() == { + "start_ds": {0: "2023-01-01", 1: "2023-01-05"}, + "end_ds": {0: "2023-01-07", 1: "2023-01-07"}, + } + recorded_intervals_downstream_df = context.engine_adapter.fetchdf( + "SELECT start_ds, end_ds FROM _test_auto_restatement_intervals WHERE name = 'new_model_downstream'" + ) + # The first interval is the first backfill, the second interval should be the 3 days of restated intervals, and + # the third interval should catch up to the current hour + assert recorded_intervals_downstream_df.to_dict() == { + "start_ds": { + 0: "2023-01-01 00:00:00", + 1: "2023-01-05 00:00:00", + 2: "2023-01-08 01:00:00", + }, + "end_ds": { + 0: "2023-01-08 00:59:59.999999", + 1: "2023-01-07 23:59:59.999999", + 2: "2023-01-08 05:59:59.999999", + }, + } + + snapshot = context.get_snapshot(new_model.name) + snapshot = context.state_sync.state_sync.get_snapshots([snapshot.snapshot_id])[ + snapshot.snapshot_id + ] + assert snapshot.next_auto_restatement_ts == to_timestamp("2023-01-15 06:00:00") + assert not snapshot.pending_restatement_intervals + + snapshot_downstream = context.get_snapshot(new_model_downstream.name) + snapshot_downstream = 
context.state_sync.state_sync.get_snapshots( + [snapshot_downstream.snapshot_id] + )[snapshot_downstream.snapshot_id] + assert not snapshot_downstream.next_auto_restatement_ts + assert not snapshot_downstream.pending_restatement_intervals + + +@time_machine.travel("2023-01-08 01:00:00 UTC") +def test_run_auto_restatement_plan_preview(init_and_plan_context: t.Callable): + context, init_plan = init_and_plan_context("examples/sushi") + context.apply(init_plan) + + new_model_expr = d.parse( + """ + MODEL ( + name memory.sushi.new_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + auto_restatement_cron '0 6 * * 7', + ), + start '2023-01-01', + ); + + SELECT '2023-01-07' AS ds, 1 AS a; + """ + ) + new_model = load_sql_based_model(new_model_expr) + context.upsert_model(new_model) + snapshot = context.get_snapshot(new_model.name) + + plan_dev = context.plan_builder("dev").build() + # Make sure that a limited preview is computed by default + assert to_timestamp(plan_dev.start) == to_timestamp("2023-01-07") + assert plan_dev.missing_intervals == [ + SnapshotIntervals( + snapshot.snapshot_id, + [(to_timestamp("2023-01-07"), to_timestamp("2023-01-08"))], + ) + ] + assert not plan_dev.deployability_index.is_deployable(snapshot.snapshot_id) + context.apply(plan_dev) + + plan_prod = context.plan_builder("prod").build() + assert plan_prod.missing_intervals == [ + SnapshotIntervals( + context.get_snapshot(new_model.name).snapshot_id, + [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ) + ] + context.apply(plan_prod) + + +@time_machine.travel("2023-01-08 01:00:00 UTC") +def 
test_run_auto_restatement_failure(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + @macro() + def fail_auto_restatement(evaluator, start: exp.Expression, **kwargs: t.Any) -> None: + if evaluator.runtime_stage == "evaluating" and start.name != "2023-01-01": + raise Exception("Failed") + + new_model_expr = d.parse( + """ + MODEL ( + name memory.sushi.new_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + auto_restatement_cron '0 6 * * 7', -- At 6am every Sunday + auto_restatement_intervals 3, + ), + start '2023-01-01', + ); + + @fail_auto_restatement(@start_ds); + + SELECT '2023-01-07' AS ds, 1 AS a; + """ + ) + new_model = load_sql_based_model(new_model_expr) + context.upsert_model(new_model) + + plan = context.plan_builder("prod").build() + context.apply(plan) + + with time_machine.travel("2023-01-08 06:01:00 UTC"): + run_status = context.run() + assert run_status.is_failure + + snapshot = context.get_snapshot(new_model.name) + snapshot = context.state_sync.state_sync.get_snapshots([snapshot.snapshot_id])[ + snapshot.snapshot_id + ] + assert snapshot.next_auto_restatement_ts == to_timestamp("2023-01-15 06:00:00") + assert snapshot.pending_restatement_intervals == [ + (to_timestamp("2023-01-05"), to_timestamp("2023-01-08")) + ] diff --git a/tests/core/integration/test_aux_commands.py b/tests/core/integration/test_aux_commands.py new file mode 100644 index 0000000000..326e81e0c1 --- /dev/null +++ b/tests/core/integration/test_aux_commands.py @@ -0,0 +1,367 @@ +from __future__ import annotations + +import typing as t +from unittest.mock import patch +import pytest +from pathlib import Path +from sqlmesh.core.config.naming import NameInferenceConfig +from sqlmesh.core.model.common import ParsableSql +import time_machine +from pytest_mock.plugin import MockerFixture + +from sqlmesh.core.config import ( + Config, + GatewayConfig, + ModelDefaultsConfig, + DuckDBConnectionConfig, +) +from sqlmesh.core.context import 
Context +from sqlmesh.core.model import ( + SqlModel, +) +from sqlmesh.utils.errors import ( + SQLMeshError, +) +from sqlmesh.utils.date import now +from tests.conftest import DuckDBMetadata +from tests.utils.test_helpers import use_terminal_console +from tests.utils.test_filesystem import create_temp_file +from tests.core.integration.utils import add_projection_to_model, apply_to_environment + +pytestmark = pytest.mark.slow + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_table_name(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + snapshot = context.get_snapshot("sushi.waiter_revenue_by_day") + assert snapshot + assert ( + context.table_name("sushi.waiter_revenue_by_day", "prod") + == f"memory.sqlmesh__sushi.sushi__waiter_revenue_by_day__{snapshot.version}" + ) + + with pytest.raises(SQLMeshError, match="Environment 'dev' was not found."): + context.table_name("sushi.waiter_revenue_by_day", "dev") + + with pytest.raises( + SQLMeshError, match="Model 'sushi.missing' was not found in environment 'prod'." 
+ ): + context.table_name("sushi.missing", "prod") + + # Add a new projection + model = context.get_model("sushi.waiter_revenue_by_day") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + + context.plan("dev_a", auto_apply=True, no_prompts=True, skip_tests=True) + + new_snapshot = context.get_snapshot("sushi.waiter_revenue_by_day") + assert new_snapshot.version != snapshot.version + + assert ( + context.table_name("sushi.waiter_revenue_by_day", "dev_a") + == f"memory.sqlmesh__sushi.sushi__waiter_revenue_by_day__{new_snapshot.version}" + ) + + # Make a forward-only change + context.upsert_model(model, stamp="forward_only") + + context.plan("dev_b", auto_apply=True, no_prompts=True, skip_tests=True, forward_only=True) + + forward_only_snapshot = context.get_snapshot("sushi.waiter_revenue_by_day") + assert forward_only_snapshot.version == snapshot.version + assert forward_only_snapshot.dev_version != snapshot.version + + assert ( + context.table_name("sushi.waiter_revenue_by_day", "dev_b") + == f"memory.sqlmesh__sushi.sushi__waiter_revenue_by_day__{forward_only_snapshot.dev_version}__dev" + ) + + assert ( + context.table_name("sushi.waiter_revenue_by_day", "dev_b", prod=True) + == f"memory.sqlmesh__sushi.sushi__waiter_revenue_by_day__{snapshot.version}" + ) + + +def test_janitor_cleanup_order(mocker: MockerFixture, tmp_path: Path): + def setup_scenario(): + models_dir = tmp_path / "models" + + if not models_dir.exists(): + models_dir.mkdir() + + model1_path = models_dir / "model1.sql" + + with open(model1_path, "w") as f: + f.write("MODEL(name test.model1, kind FULL); SELECT 1 AS col") + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + ctx = Context(paths=[tmp_path], config=config) + + ctx.plan("dev", no_prompts=True, auto_apply=True) + + model1_snapshot = ctx.get_snapshot("test.model1") + + # Delete the model file to cause a snapshot expiration + model1_path.unlink() + + ctx.load() + + ctx.plan("dev", 
no_prompts=True, auto_apply=True) + + # Invalidate the environment to cause an environment cleanup + ctx.invalidate_environment("dev") + + try: + ctx._run_janitor(ignore_ttl=True) + except: + pass + + return ctx, model1_snapshot + + # Case 1: Assume that the snapshot cleanup yields an error, the snapshot records + # should still exist in the state sync so the next janitor can retry + mocker.patch( + "sqlmesh.core.snapshot.evaluator.SnapshotEvaluator.cleanup", + side_effect=Exception("snapshot cleanup error"), + ) + ctx, model1_snapshot = setup_scenario() + + # - Check that the snapshot record exists in the state sync + state_snapshot = ctx.state_sync.state_sync.get_snapshots([model1_snapshot.snapshot_id]) + assert state_snapshot + + # - Run the janitor again, this time it should succeed + mocker.patch("sqlmesh.core.snapshot.evaluator.SnapshotEvaluator.cleanup") + ctx._run_janitor(ignore_ttl=True) + + # - Check that the snapshot record does not exist in the state sync anymore + state_snapshot = ctx.state_sync.state_sync.get_snapshots([model1_snapshot.snapshot_id]) + assert not state_snapshot + + # Case 2: Assume that the view cleanup yields an error, the enviroment + # record should still exist + mocker.patch( + "sqlmesh.core.context.cleanup_expired_views", side_effect=Exception("view cleanup error") + ) + ctx, model1_snapshot = setup_scenario() + + views = ctx.fetchdf("FROM duckdb_views() SELECT * EXCLUDE(sql) WHERE NOT internal") + assert views.empty + + # - Check that the environment record exists in the state sync + assert ctx.state_sync.get_environment("dev") + + # - Run the janitor again, this time it should succeed + mocker.patch("sqlmesh.core.context.cleanup_expired_views") + ctx._run_janitor(ignore_ttl=True) + + # - Check that the environment record does not exist in the state sync anymore + assert not ctx.state_sync.get_environment("dev") + + +@use_terminal_console +def test_destroy(copy_to_temp_path): + # Testing project with two gateways to verify 
cleanup is performed across engines + paths = copy_to_temp_path("tests/fixtures/multi_virtual_layer") + path = Path(paths[0]) + first_db_path = str(path / "db_1.db") + second_db_path = str(path / "db_2.db") + + config = Config( + gateways={ + "first": GatewayConfig( + connection=DuckDBConnectionConfig(database=first_db_path), + variables={"overriden_var": "gateway_1"}, + ), + "second": GatewayConfig( + connection=DuckDBConnectionConfig(database=second_db_path), + variables={"overriden_var": "gateway_2"}, + ), + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + model_naming=NameInferenceConfig(infer_names=True), + default_gateway="first", + gateway_managed_virtual_layer=True, + variables={"overriden_var": "global", "global_one": 88}, + ) + + context = Context(paths=paths, config=config) + plan = context.plan_builder().build() + assert len(plan.new_snapshots) == 4 + context.apply(plan) + + # Confirm cache exists + cache_path = Path(path) / ".cache" + assert cache_path.exists() + assert len(list(cache_path.iterdir())) > 0 + + model = context.get_model("db_1.first_schema.model_one") + + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'c' AS extra").sql(dialect=model.dialect) + ) + } + ) + ) + plan = context.plan_builder().build() + context.apply(plan) + + state_environments = context.state_reader.get_environments() + state_snapshots = context.state_reader.get_snapshots(context.snapshots.values()) + + assert len(state_snapshots) == len(state_environments[0].snapshots) + + # Create dev environment with changed models + model = context.get_model("db_2.second_schema.model_one") + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'d' AS extra").sql(dialect=model.dialect) + ) + } + ) + ) + model = context.get_model("first_schema.model_two") + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'d2' AS 
col").sql(dialect=model.dialect) + ) + } + ) + ) + plan = context.plan_builder("dev").build() + context.apply(plan) + + dev_environment = context.state_sync.get_environment("dev") + assert dev_environment is not None + + state_environments = context.state_reader.get_environments() + state_snapshots = context.state_reader.get_snapshots(context.snapshots.values()) + assert ( + len(state_snapshots) + == len(state_environments[0].snapshots) + == len(state_environments[1].snapshots) + ) + + # The state tables at this point should be able to be retrieved + state_tables = { + "_environments", + "_snapshots", + "_intervals", + "_auto_restatements", + "_environment_statements", + "_intervals", + "_versions", + } + for table_name in state_tables: + context.fetchdf(f"SELECT * FROM db_1.sqlmesh.{table_name}") + + # The actual tables as well + context.engine_adapters["second"].fetchdf(f"SELECT * FROM db_2.second_schema.model_one") + context.engine_adapters["second"].fetchdf(f"SELECT * FROM db_2.second_schema.model_two") + context.fetchdf(f"SELECT * FROM db_1.first_schema.model_one") + context.fetchdf(f"SELECT * FROM db_1.first_schema.model_two") + + # Use the destroy command to remove all data objects and state + # Mock the console confirmation to automatically return True + with patch.object(context.console, "_confirm", return_value=True): + context._destroy() + + # Ensure all tables have been removed + for table_name in state_tables: + with pytest.raises( + Exception, match=f"Catalog Error: Table with name {table_name} does not exist!" 
+ ): + context.fetchdf(f"SELECT * FROM db_1.sqlmesh.{table_name}") + + # Validate tables have been deleted as well + with pytest.raises( + Exception, match=r"Catalog Error: Table with name.*model_two.*does not exist" + ): + context.fetchdf("SELECT * FROM db_1.first_schema.model_two") + with pytest.raises( + Exception, match=r"Catalog Error: Table with name.*model_one.*does not exist" + ): + context.fetchdf("SELECT * FROM db_1.first_schema.model_one") + + with pytest.raises( + Exception, match=r"Catalog Error: Table with name.*model_two.*does not exist" + ): + context.engine_adapters["second"].fetchdf("SELECT * FROM db_2.second_schema.model_two") + with pytest.raises( + Exception, match=r"Catalog Error: Table with name.*model_one.*does not exist" + ): + context.engine_adapters["second"].fetchdf("SELECT * FROM db_2.second_schema.model_one") + + # Ensure the cache has been removed + assert not cache_path.exists() + + +@use_terminal_console +def test_render_path_instead_of_model(tmp_path: Path): + create_temp_file(tmp_path, Path("models/test.sql"), "MODEL (name test_model); SELECT 1 AS col") + ctx = Context(paths=tmp_path, config=Config()) + + # Case 1: Fail gracefully when the user is passing in a path instead of a model name + for test_model in ["models/test.sql", "models/test.py"]: + with pytest.raises( + SQLMeshError, + match="Resolving models by path is not supported, please pass in the model name instead.", + ): + ctx.render(test_model) + + # Case 2: Fail gracefully when the model name is not found + with pytest.raises(SQLMeshError, match="Cannot find model with name 'incorrect_model'"): + ctx.render("incorrect_model") + + # Case 3: Render the model successfully + assert ctx.render("test_model").sql() == 'SELECT 1 AS "col"' + + +def test_invalidating_environment(sushi_context: Context): + apply_to_environment(sushi_context, "dev") + start_environment = sushi_context.state_sync.get_environment("dev") + assert start_environment is not None + metadata = 
DuckDBMetadata.from_context(sushi_context) + start_schemas = set(metadata.schemas) + assert "sushi__dev" in start_schemas + sushi_context.invalidate_environment("dev") + invalidate_environment = sushi_context.state_sync.get_environment("dev") + assert invalidate_environment is not None + schemas_prior_to_janitor = set(metadata.schemas) + assert invalidate_environment.expiration_ts < start_environment.expiration_ts # type: ignore + assert start_schemas == schemas_prior_to_janitor + sushi_context._run_janitor() + schemas_after_janitor = set(metadata.schemas) + assert sushi_context.state_sync.get_environment("dev") is None + assert start_schemas - schemas_after_janitor == {"sushi__dev"} + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_evaluate_uncategorized_snapshot(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + # Add a new projection + model = context.get_model("sushi.waiter_revenue_by_day") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + + # Downstream model references the new projection + downstream_model = context.get_model("sushi.top_waiters") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, downstream_model), literal=False)) + + df = context.evaluate( + "sushi.top_waiters", start="2023-01-05", end="2023-01-06", execution_time=now() + ) + assert set(df["one"].tolist()) == {1} diff --git a/tests/core/integration/test_change_scenarios.py b/tests/core/integration/test_change_scenarios.py new file mode 100644 index 0000000000..fb1762220f --- /dev/null +++ b/tests/core/integration/test_change_scenarios.py @@ -0,0 +1,1517 @@ +from __future__ import annotations + +import typing as t +import json +from datetime import timedelta +from unittest import mock +import pandas as pd # noqa: TID253 +import pytest +from pathlib import Path +from sqlmesh.core.model.common import ParsableSql +import time_machine +from sqlglot.expressions import DataType 
+import re + +from sqlmesh.cli.project_init import init_example_project +from sqlmesh.core import constants as c +from sqlmesh.core import dialect as d +from sqlmesh.core.config import ( + AutoCategorizationMode, + Config, + GatewayConfig, + ModelDefaultsConfig, + DuckDBConnectionConfig, +) +from sqlmesh.core.context import Context +from sqlmesh.core.config.categorizer import CategorizerConfig +from sqlmesh.core.model import ( + FullKind, + ModelKind, + ModelKindName, + SqlModel, + PythonModel, + ViewKind, + load_sql_based_model, +) +from sqlmesh.core.model.kind import model_kind_type_from_name +from sqlmesh.core.plan import Plan, SnapshotIntervals +from sqlmesh.core.snapshot import ( + SnapshotChangeCategory, +) +from sqlmesh.utils.date import now, to_timestamp +from sqlmesh.utils.errors import ( + SQLMeshError, +) +from tests.core.integration.utils import ( + apply_to_environment, + add_projection_to_model, + initial_add, + change_data_type, + validate_apply_basics, + change_model_kind, + validate_model_kind_change, + validate_query_change, + validate_plan_changes, +) + +pytestmark = pytest.mark.slow + + +def test_auto_categorization(sushi_context: Context): + environment = "dev" + for config in sushi_context.configs.values(): + config.plan.auto_categorize_changes.sql = AutoCategorizationMode.FULL + initial_add(sushi_context, environment) + + version = sushi_context.get_snapshot( + "sushi.waiter_as_customer_by_day", raise_if_missing=True + ).version + fingerprint = sushi_context.get_snapshot( + "sushi.waiter_as_customer_by_day", raise_if_missing=True + ).fingerprint + + model = t.cast(SqlModel, sushi_context.get_model("sushi.customers", raise_if_missing=True)) + sushi_context.upsert_model( + "sushi.customers", + query_=ParsableSql(sql=model.query.select("'foo' AS foo").sql(dialect=model.dialect)), # type: ignore + ) + apply_to_environment(sushi_context, environment) + + assert ( + sushi_context.get_snapshot( + "sushi.waiter_as_customer_by_day", 
raise_if_missing=True + ).change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert ( + sushi_context.get_snapshot( + "sushi.waiter_as_customer_by_day", raise_if_missing=True + ).fingerprint + != fingerprint + ) + assert ( + sushi_context.get_snapshot("sushi.waiter_as_customer_by_day", raise_if_missing=True).version + == version + ) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_breaking_only_impacts_immediate_children(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + context.upsert_model(context.get_model("sushi.top_waiters").copy(update={"kind": FullKind()})) + context.plan("prod", skip_tests=True, auto_apply=True, no_prompts=True) + + breaking_model = context.get_model("sushi.orders") + breaking_model = breaking_model.copy(update={"stamp": "force new version"}) + context.upsert_model(breaking_model) + breaking_snapshot = context.get_snapshot(breaking_model, raise_if_missing=True) + + non_breaking_model = context.get_model("sushi.waiter_revenue_by_day") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, non_breaking_model))) + non_breaking_snapshot = context.get_snapshot(non_breaking_model, raise_if_missing=True) + top_waiter_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + + plan_builder = context.plan_builder("dev", skip_tests=True, enable_preview=False) + plan_builder.set_choice(breaking_snapshot, SnapshotChangeCategory.BREAKING) + plan = plan_builder.build() + assert ( + plan.context_diff.snapshots[breaking_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.BREAKING + ) + assert ( + plan.context_diff.snapshots[non_breaking_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert ( + plan.context_diff.snapshots[top_waiter_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert plan.start == to_timestamp("2023-01-01") + assert not any(i.snapshot_id 
== top_waiter_snapshot.snapshot_id for i in plan.missing_intervals) + + context.apply(plan) + assert ( + not context.plan_builder("dev", skip_tests=True, enable_preview=False) + .build() + .requires_backfill + ) + + # Deploy everything to prod. + plan = context.plan_builder("prod", skip_tests=True).build() + assert not plan.missing_intervals + + context.apply(plan) + assert ( + not context.plan_builder("prod", skip_tests=True, enable_preview=False) + .build() + .requires_backfill + ) + + +@pytest.mark.parametrize( + "context_fixture", + ["sushi_context", "sushi_dbt_context", "sushi_test_dbt_context", "sushi_no_default_catalog"], +) +def test_model_add(context_fixture: Context, request): + initial_add(request.getfixturevalue(context_fixture), "dev") + + +def test_model_removed(sushi_context: Context): + environment = "dev" + initial_add(sushi_context, environment) + + top_waiters_snapshot_id = sushi_context.get_snapshot( + "sushi.top_waiters", raise_if_missing=True + ).snapshot_id + + sushi_context._models.pop('"memory"."sushi"."top_waiters"') + + def _validate_plan(context, plan): + validate_plan_changes(plan, removed=[top_waiters_snapshot_id]) + assert not plan.missing_intervals + + def _validate_apply(context): + assert not sushi_context.get_snapshot("sushi.top_waiters", raise_if_missing=False) + assert sushi_context.state_reader.get_snapshots([top_waiters_snapshot_id]) + env = sushi_context.state_reader.get_environment(environment) + assert env + assert all(snapshot.name != '"memory"."sushi"."top_waiters"' for snapshot in env.snapshots) + + apply_to_environment( + sushi_context, + environment, + SnapshotChangeCategory.BREAKING, + plan_validators=[_validate_plan], + apply_validators=[_validate_apply], + ) + + +def test_non_breaking_change(sushi_context: Context): + environment = "dev" + initial_add(sushi_context, environment) + validate_query_change(sushi_context, environment, SnapshotChangeCategory.NON_BREAKING, False) + + +def 
test_breaking_change(sushi_context: Context): + environment = "dev" + initial_add(sushi_context, environment) + validate_query_change(sushi_context, environment, SnapshotChangeCategory.BREAKING, False) + + +def test_logical_change(sushi_context: Context): + environment = "dev" + initial_add(sushi_context, environment) + previous_sushi_items_version = sushi_context.get_snapshot( + "sushi.items", raise_if_missing=True + ).version + + change_data_type( + sushi_context, + "sushi.items", + DataType.Type.DOUBLE, + DataType.Type.FLOAT, + ) + apply_to_environment(sushi_context, environment, SnapshotChangeCategory.NON_BREAKING) + + change_data_type( + sushi_context, + "sushi.items", + DataType.Type.FLOAT, + DataType.Type.DOUBLE, + ) + apply_to_environment(sushi_context, environment, SnapshotChangeCategory.NON_BREAKING) + + assert ( + sushi_context.get_snapshot("sushi.items", raise_if_missing=True).version + == previous_sushi_items_version + ) + + +@pytest.mark.parametrize( + "from_, to", + [ + (ModelKindName.INCREMENTAL_BY_TIME_RANGE, ModelKindName.FULL), + (ModelKindName.FULL, ModelKindName.INCREMENTAL_BY_TIME_RANGE), + ], +) +def test_model_kind_change(from_: ModelKindName, to: ModelKindName, sushi_context: Context): + environment = f"test_model_kind_change__{from_.value.lower()}__{to.value.lower()}" + incremental_snapshot = sushi_context.get_snapshot("sushi.items", raise_if_missing=True).copy() + + if from_ != ModelKindName.INCREMENTAL_BY_TIME_RANGE: + change_model_kind(sushi_context, from_) + apply_to_environment(sushi_context, environment, SnapshotChangeCategory.NON_BREAKING) + + if to == ModelKindName.INCREMENTAL_BY_TIME_RANGE: + sushi_context.upsert_model(incremental_snapshot.model) + else: + change_model_kind(sushi_context, to) + + logical = to in (ModelKindName.INCREMENTAL_BY_TIME_RANGE, ModelKindName.EMBEDDED) + validate_model_kind_change(to, sushi_context, environment, logical=logical) + + +def test_environment_isolation(sushi_context: Context): + prod_snapshots 
= sushi_context.snapshots.values() + + change_data_type( + sushi_context, + "sushi.items", + DataType.Type.DOUBLE, + DataType.Type.FLOAT, + ) + directly_modified = ['"memory"."sushi"."items"'] + indirectly_modified = [ + '"memory"."sushi"."order_items"', + '"memory"."sushi"."waiter_revenue_by_day"', + '"memory"."sushi"."customer_revenue_by_day"', + '"memory"."sushi"."customer_revenue_lifetime"', + '"memory"."sushi"."top_waiters"', + "assert_item_price_above_zero", + ] + + apply_to_environment(sushi_context, "dev", SnapshotChangeCategory.BREAKING) + + # Verify prod unchanged + validate_apply_basics(sushi_context, "prod", prod_snapshots) + + def _validate_plan(context, plan): + validate_plan_changes(plan, modified=directly_modified + indirectly_modified) + assert not plan.missing_intervals + + apply_to_environment( + sushi_context, + "prod", + SnapshotChangeCategory.BREAKING, + plan_validators=[_validate_plan], + ) + + +def test_environment_promotion(sushi_context: Context): + initial_add(sushi_context, "dev") + + # Simulate prod "ahead" + change_data_type(sushi_context, "sushi.items", DataType.Type.DOUBLE, DataType.Type.FLOAT) + apply_to_environment(sushi_context, "prod", SnapshotChangeCategory.BREAKING) + + # Simulate rebase + apply_to_environment(sushi_context, "dev", SnapshotChangeCategory.BREAKING) + + # Make changes in dev + change_data_type(sushi_context, "sushi.items", DataType.Type.FLOAT, DataType.Type.DECIMAL) + apply_to_environment(sushi_context, "dev", SnapshotChangeCategory.NON_BREAKING) + + change_data_type(sushi_context, "sushi.top_waiters", DataType.Type.DOUBLE, DataType.Type.INT) + apply_to_environment(sushi_context, "dev", SnapshotChangeCategory.BREAKING) + + change_data_type( + sushi_context, + "sushi.customer_revenue_by_day", + DataType.Type.DOUBLE, + DataType.Type.FLOAT, + ) + apply_to_environment( + sushi_context, + "dev", + SnapshotChangeCategory.FORWARD_ONLY, + allow_destructive_models=['"memory"."sushi"."customer_revenue_by_day"'], + ) + + # 
Promote to prod + def _validate_plan(context, plan): + sushi_items_snapshot = context.get_snapshot("sushi.items", raise_if_missing=True) + sushi_top_waiters_snapshot = context.get_snapshot( + "sushi.top_waiters", raise_if_missing=True + ) + sushi_customer_revenue_by_day_snapshot = context.get_snapshot( + "sushi.customer_revenue_by_day", raise_if_missing=True + ) + + assert ( + plan.context_diff.modified_snapshots[sushi_items_snapshot.name][0].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert ( + plan.context_diff.modified_snapshots[sushi_top_waiters_snapshot.name][0].change_category + == SnapshotChangeCategory.BREAKING + ) + assert ( + plan.context_diff.modified_snapshots[sushi_customer_revenue_by_day_snapshot.name][ + 0 + ].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert plan.context_diff.snapshots[ + sushi_customer_revenue_by_day_snapshot.snapshot_id + ].is_forward_only + + apply_to_environment( + sushi_context, + "prod", + SnapshotChangeCategory.NON_BREAKING, + plan_validators=[_validate_plan], + allow_destructive_models=['"memory"."sushi"."customer_revenue_by_day"'], + ) + + +def test_no_override(sushi_context: Context) -> None: + change_data_type( + sushi_context, + "sushi.items", + DataType.Type.INT, + DataType.Type.BIGINT, + ) + + change_data_type( + sushi_context, + "sushi.order_items", + DataType.Type.INT, + DataType.Type.BIGINT, + ) + + plan_builder = sushi_context.plan_builder("prod") + plan = plan_builder.build() + + sushi_items_snapshot = sushi_context.get_snapshot("sushi.items", raise_if_missing=True) + sushi_order_items_snapshot = sushi_context.get_snapshot( + "sushi.order_items", raise_if_missing=True + ) + sushi_water_revenue_by_day_snapshot = sushi_context.get_snapshot( + "sushi.waiter_revenue_by_day", raise_if_missing=True + ) + + items = plan.context_diff.snapshots[sushi_items_snapshot.snapshot_id] + order_items = plan.context_diff.snapshots[sushi_order_items_snapshot.snapshot_id] + waiter_revenue = 
plan.context_diff.snapshots[sushi_water_revenue_by_day_snapshot.snapshot_id] + + plan_builder.set_choice(items, SnapshotChangeCategory.BREAKING).set_choice( + order_items, SnapshotChangeCategory.NON_BREAKING + ) + plan_builder.build() + assert items.is_new_version + assert waiter_revenue.is_new_version + plan_builder.set_choice(items, SnapshotChangeCategory.NON_BREAKING) + plan_builder.build() + assert not waiter_revenue.is_new_version + + +@pytest.mark.parametrize( + "change_categories, expected", + [ + ([SnapshotChangeCategory.NON_BREAKING], SnapshotChangeCategory.BREAKING), + ([SnapshotChangeCategory.BREAKING], SnapshotChangeCategory.BREAKING), + ( + [SnapshotChangeCategory.NON_BREAKING, SnapshotChangeCategory.NON_BREAKING], + SnapshotChangeCategory.BREAKING, + ), + ( + [SnapshotChangeCategory.NON_BREAKING, SnapshotChangeCategory.BREAKING], + SnapshotChangeCategory.BREAKING, + ), + ( + [SnapshotChangeCategory.BREAKING, SnapshotChangeCategory.NON_BREAKING], + SnapshotChangeCategory.BREAKING, + ), + ( + [SnapshotChangeCategory.BREAKING, SnapshotChangeCategory.BREAKING], + SnapshotChangeCategory.BREAKING, + ), + ], +) +def test_revert( + sushi_context: Context, + change_categories: t.List[SnapshotChangeCategory], + expected: SnapshotChangeCategory, +): + environment = "prod" + original_snapshot_id = sushi_context.get_snapshot("sushi.items", raise_if_missing=True) + + types = (DataType.Type.DOUBLE, DataType.Type.FLOAT, DataType.Type.DECIMAL) + assert len(change_categories) < len(types) + + for i, category in enumerate(change_categories): + change_data_type(sushi_context, "sushi.items", *types[i : i + 2]) + apply_to_environment(sushi_context, environment, category) + assert ( + sushi_context.get_snapshot("sushi.items", raise_if_missing=True) != original_snapshot_id + ) + + change_data_type(sushi_context, "sushi.items", types[len(change_categories)], types[0]) + + def _validate_plan(_, plan): + snapshot = next(s for s in plan.snapshots.values() if s.name == 
'"memory"."sushi"."items"') + assert snapshot.change_category == expected + assert not plan.missing_intervals + + apply_to_environment( + sushi_context, + environment, + change_categories[-1], + plan_validators=[_validate_plan], + ) + assert sushi_context.get_snapshot("sushi.items", raise_if_missing=True) == original_snapshot_id + + +def test_revert_after_downstream_change(sushi_context: Context): + environment = "prod" + change_data_type(sushi_context, "sushi.items", DataType.Type.DOUBLE, DataType.Type.FLOAT) + apply_to_environment(sushi_context, environment, SnapshotChangeCategory.BREAKING) + + change_data_type( + sushi_context, + "sushi.waiter_revenue_by_day", + DataType.Type.DOUBLE, + DataType.Type.FLOAT, + ) + apply_to_environment(sushi_context, environment, SnapshotChangeCategory.NON_BREAKING) + + change_data_type(sushi_context, "sushi.items", DataType.Type.FLOAT, DataType.Type.DOUBLE) + + def _validate_plan(_, plan): + snapshot = next(s for s in plan.snapshots.values() if s.name == '"memory"."sushi"."items"') + assert snapshot.change_category == SnapshotChangeCategory.BREAKING + assert plan.missing_intervals + + apply_to_environment( + sushi_context, + environment, + SnapshotChangeCategory.BREAKING, + plan_validators=[_validate_plan], + ) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_indirect_non_breaking_change_after_forward_only_in_dev(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + # Make sure that the most downstream model is a materialized model. + model = context.get_model("sushi.top_waiters") + model = model.copy(update={"kind": FullKind()}) + context.upsert_model(model) + context.plan("prod", skip_tests=True, auto_apply=True, no_prompts=True) + + # Make sushi.orders a forward-only model. 
+ model = context.get_model("sushi.orders") + updated_model_kind = model.kind.copy(update={"forward_only": True}) + model = model.copy(update={"stamp": "force new version", "kind": updated_model_kind}) + context.upsert_model(model) + snapshot = context.get_snapshot(model, raise_if_missing=True) + + plan = context.plan_builder( + "dev", + skip_tests=True, + enable_preview=False, + categorizer_config=CategorizerConfig.all_full(), + ).build() + assert ( + plan.context_diff.snapshots[snapshot.snapshot_id].change_category + == SnapshotChangeCategory.BREAKING + ) + assert plan.context_diff.snapshots[snapshot.snapshot_id].is_forward_only + assert not plan.requires_backfill + context.apply(plan) + + # Make a non-breaking change to a model. + model = context.get_model("sushi.top_waiters") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + + plan = context.plan_builder("dev", skip_tests=True, enable_preview=False).build() + assert len(plan.new_snapshots) == 1 + assert ( + plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert plan.start == to_timestamp("2023-01-01") + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiters_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + # Apply the non-breaking changes. + context.apply(plan) + + # Make a non-breaking change upstream from the previously modified model. 
+ model = context.get_model("sushi.waiter_revenue_by_day") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + waiter_revenue_by_day_snapshot = context.get_snapshot( + "sushi.waiter_revenue_by_day", raise_if_missing=True + ) + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + + plan = context.plan_builder("dev", skip_tests=True, enable_preview=False).build() + assert len(plan.new_snapshots) == 2 + assert ( + plan.context_diff.snapshots[waiter_revenue_by_day_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert ( + plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert plan.start == to_timestamp("2023-01-01") + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=waiter_revenue_by_day_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + # Apply the upstream non-breaking changes. + context.apply(plan) + assert not context.plan_builder("dev", skip_tests=True).build().requires_backfill + + # Deploy everything to prod. 
+ plan = context.plan_builder("prod", skip_tests=True, enable_preview=False).build() + assert plan.start == to_timestamp("2023-01-01") + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiters_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + SnapshotIntervals( + snapshot_id=waiter_revenue_by_day_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + context.apply(plan) + assert ( + not context.plan_builder("prod", skip_tests=True, enable_preview=False) + .build() + .requires_backfill + ) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +@pytest.mark.parametrize("forward_only", [False, True]) +def test_plan_repairs_unrenderable_snapshot_state( + init_and_plan_context: t.Callable, forward_only: bool +): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + target_snapshot = context.get_snapshot("sushi.waiter_revenue_by_day") + assert target_snapshot + + # Manually corrupt the snapshot's query + raw_snapshot = context.state_sync.state_sync.engine_adapter.fetchone( + f"SELECT snapshot FROM sqlmesh._snapshots WHERE name = '{target_snapshot.name}' AND identifier = 
'{target_snapshot.identifier}'" + )[0] # type: ignore + parsed_snapshot = json.loads(raw_snapshot) + parsed_snapshot["node"]["query"] = "SELECT @missing_macro()" + context.state_sync.state_sync.engine_adapter.update_table( + "sqlmesh._snapshots", + {"snapshot": json.dumps(parsed_snapshot)}, + f"name = '{target_snapshot.name}' AND identifier = '{target_snapshot.identifier}'", + ) + + context.clear_caches() + target_snapshot_in_state = context.state_sync.get_snapshots([target_snapshot.snapshot_id])[ + target_snapshot.snapshot_id + ] + + with pytest.raises(Exception): + target_snapshot_in_state.model.render_query_or_raise() + + # Repair the snapshot by creating a new version of it + context.upsert_model(target_snapshot.model.name, stamp="repair") + target_snapshot = context.get_snapshot(target_snapshot.name) + + plan_builder = context.plan_builder("prod", forward_only=forward_only) + plan = plan_builder.build() + if not forward_only: + assert target_snapshot.snapshot_id in {i.snapshot_id for i in plan.missing_intervals} + assert plan.directly_modified == {target_snapshot.snapshot_id} + plan_builder.set_choice(target_snapshot, SnapshotChangeCategory.NON_BREAKING) + plan = plan_builder.build() + + context.apply(plan) + + context.clear_caches() + assert context.get_snapshot(target_snapshot.name).model.render_query_or_raise() + target_snapshot_in_state = context.state_sync.get_snapshots([target_snapshot.snapshot_id])[ + target_snapshot.snapshot_id + ] + assert target_snapshot_in_state.model.render_query_or_raise() + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_no_backfill_for_model_downstream_of_metadata_change(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + # Make sushi.waiter_revenue_by_day a forward-only model. 
+ forward_only_model = context.get_model("sushi.waiter_revenue_by_day") + updated_model_kind = forward_only_model.kind.copy(update={"forward_only": True}) + forward_only_model = forward_only_model.copy(update={"kind": updated_model_kind}) + context.upsert_model(forward_only_model) + + context.plan("prod", auto_apply=True, no_prompts=True, skip_tests=True) + + # Make a metadata change upstream of the forward-only model. + context.upsert_model("sushi.orders", owner="new_owner") + + plan = context.plan_builder("test_dev").build() + assert plan.has_changes + assert not plan.directly_modified + assert not plan.indirectly_modified + assert not plan.missing_intervals + assert all( + snapshot.change_category == SnapshotChangeCategory.METADATA + for snapshot in plan.new_snapshots + ) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_plan_set_choice_is_reflected_in_missing_intervals(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + context.upsert_model(context.get_model("sushi.top_waiters").copy(update={"kind": FullKind()})) + context.plan("prod", skip_tests=True, no_prompts=True, auto_apply=True) + + model_name = "sushi.waiter_revenue_by_day" + + model = context.get_model(model_name) + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + snapshot = context.get_snapshot(model, raise_if_missing=True) + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + + plan_builder = context.plan_builder("dev", skip_tests=True) + plan = plan_builder.build() + assert len(plan.new_snapshots) == 2 + assert ( + plan.context_diff.snapshots[snapshot.snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert ( + plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert plan.start == to_timestamp("2023-01-01") + assert plan.missing_intervals == [ + SnapshotIntervals( + 
snapshot_id=snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + # Change the category to BREAKING + plan = plan_builder.set_choice( + plan.context_diff.snapshots[snapshot.snapshot_id], SnapshotChangeCategory.BREAKING + ).build() + assert ( + plan.context_diff.snapshots[snapshot.snapshot_id].change_category + == SnapshotChangeCategory.BREAKING + ) + assert ( + plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_BREAKING + ) + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiters_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + SnapshotIntervals( + snapshot_id=snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + # Change the category back 
to NON_BREAKING + plan = plan_builder.set_choice( + plan.context_diff.snapshots[snapshot.snapshot_id], SnapshotChangeCategory.NON_BREAKING + ).build() + assert ( + plan.context_diff.snapshots[snapshot.snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert ( + plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + context.apply(plan) + + dev_df = context.engine_adapter.fetchdf( + "SELECT DISTINCT event_date FROM sushi__dev.waiter_revenue_by_day ORDER BY event_date" + ) + assert dev_df["event_date"].tolist() == [ + pd.to_datetime(x) + for x in [ + "2023-01-01", + "2023-01-02", + "2023-01-03", + "2023-01-04", + "2023-01-05", + "2023-01-06", + "2023-01-07", + ] + ] + + # Promote changes to prod + prod_plan = context.plan_builder(skip_tests=True).build() + assert not prod_plan.missing_intervals + + context.apply(prod_plan) + prod_df = context.engine_adapter.fetchdf( + "SELECT DISTINCT event_date FROM sushi.waiter_revenue_by_day WHERE one IS NOT NULL ORDER BY event_date" + ) + assert prod_df["event_date"].tolist() == [ + pd.to_datetime(x) + for x in [ + "2023-01-01", + "2023-01-02", + "2023-01-03", + "2023-01-04", + "2023-01-05", + "2023-01-06", + "2023-01-07", + ] + ] + + +def test_plan_production_environment_statements(tmp_path: Path): + model_a = """ + MODEL ( + name test_schema.a, + kind FULL, + ); + + @IF( + @runtime_stage IN 
('evaluating', 'creating'), + INSERT INTO schema_names_for_prod (physical_schema_name) VALUES (@resolve_template('@{schema_name}')) + ); + + SELECT 1 AS account_id + """ + + models_dir = tmp_path / "models" + models_dir.mkdir() + + for path, defn in {"a.sql": model_a}.items(): + with open(models_dir / path, "w") as f: + f.write(defn) + + before_all = [ + "CREATE TABLE IF NOT EXISTS schema_names_for_@this_env (physical_schema_name VARCHAR)", + "@IF(@runtime_stage = 'before_all', CREATE TABLE IF NOT EXISTS should_create AS SELECT @runtime_stage)", + ] + after_all = [ + "@IF(@this_env = 'prod', CREATE TABLE IF NOT EXISTS after_t AS SELECT @var_5)", + "@IF(@runtime_stage = 'before_all', CREATE TABLE IF NOT EXISTS not_create AS SELECT @runtime_stage)", + ] + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + before_all=before_all, + after_all=after_all, + variables={"var_5": 5}, + ) + ctx = Context(paths=[tmp_path], config=config) + ctx.plan(auto_apply=True, no_prompts=True) + + before_t = ctx.fetchdf("select * from schema_names_for_prod").to_dict() + assert before_t["physical_schema_name"][0] == "sqlmesh__test_schema" + + after_t = ctx.fetchdf("select * from after_t").to_dict() + assert after_t["5"][0] == 5 + + environment_statements = ctx.state_reader.get_environment_statements(c.PROD) + assert environment_statements[0].before_all == before_all + assert environment_statements[0].after_all == after_all + assert environment_statements[0].python_env.keys() == {"__sqlmesh__vars__"} + assert environment_statements[0].python_env["__sqlmesh__vars__"].payload == "{'var_5': 5}" + + should_create = ctx.fetchdf("select * from should_create").to_dict() + assert should_create["before_all"][0] == "before_all" + + with pytest.raises( + Exception, match=r"Catalog Error: Table with name not_create does not exist!" 
+ ): + ctx.fetchdf("select * from not_create") + + +def test_environment_statements_error_handling(tmp_path: Path): + model_a = """ + MODEL ( + name test_schema.a, + kind FULL, + ); + + SELECT 1 AS account_id + """ + + models_dir = tmp_path / "models" + models_dir.mkdir() + + for path, defn in {"a.sql": model_a}.items(): + with open(models_dir / path, "w") as f: + f.write(defn) + + before_all = [ + "CREATE TABLE identical_table (physical_schema_name VARCHAR)", + "CREATE TABLE identical_table (physical_schema_name VARCHAR)", + ] + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + before_all=before_all, + ) + ctx = Context(paths=[tmp_path], config=config) + + expected_error_message = re.escape( + """An error occurred during execution of the following 'before_all' statement: + +CREATE TABLE identical_table (physical_schema_name TEXT) + +Catalog Error: Table with name "identical_table" already exists!""" + ) + + with pytest.raises(SQLMeshError, match=expected_error_message): + ctx.plan(auto_apply=True, no_prompts=True) + + after_all = [ + "@bad_macro()", + ] + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + after_all=after_all, + ) + ctx = Context(paths=[tmp_path], config=config) + + expected_error_message = re.escape( + """An error occurred during rendering of the 'after_all' statements: + +Failed to resolve macros for + +@bad_macro() + +Macro 'bad_macro' does not exist.""" + ) + + with pytest.raises(SQLMeshError, match=expected_error_message): + ctx.plan(auto_apply=True, no_prompts=True) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_full_model_change_with_plan_start_not_matching_model_start( + init_and_plan_context: t.Callable, +): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + model = context.get_model("sushi.top_waiters") + context.upsert_model(model, kind=model_kind_type_from_name("FULL")()) # type: ignore + + # Apply the change with --skip-backfill first 
and no plan start + context.plan("dev", skip_tests=True, skip_backfill=True, no_prompts=True, auto_apply=True) + + # Apply the plan again but this time don't skip backfill and set start + # to be later than the model start + context.plan("dev", skip_tests=True, no_prompts=True, auto_apply=True, start="1 day ago") + + # Check that the number of rows is not 0 + row_num = context.engine_adapter.fetchone(f"SELECT COUNT(*) FROM sushi__dev.top_waiters")[0] + assert row_num > 0 + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_hourly_model_with_lookback_no_backfill_in_dev(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + + model_name = "sushi.waiter_revenue_by_day" + + model = context.get_model(model_name) + model = SqlModel.parse_obj( + { + **model.dict(), + "kind": model.kind.copy(update={"lookback": 1}), + "cron": "@hourly", + "audits": [], + } + ) + context.upsert_model(model) + + plan = context.plan_builder("prod", skip_tests=True).build() + context.apply(plan) + + top_waiters_model = context.get_model("sushi.top_waiters") + top_waiters_model = add_projection_to_model(t.cast(SqlModel, top_waiters_model), literal=True) + context.upsert_model(top_waiters_model) + + context.get_snapshot(model, raise_if_missing=True) + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + + with time_machine.travel(now() + timedelta(hours=2)): + plan = context.plan_builder("dev", skip_tests=True).build() + # Make sure the waiter_revenue_by_day model is not backfilled. 
+ assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiters_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_max_interval_end_per_model_not_applied_when_end_is_provided( + init_and_plan_context: t.Callable, +): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + with time_machine.travel("2023-01-09 00:00:00 UTC"): + context.run() + + plan = context.plan_builder( + restate_models=["*"], start="2023-01-09", end="2023-01-09" + ).build() + context.apply(plan) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_plan_against_expired_environment(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + model = context.get_model("sushi.waiter_revenue_by_day") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + + modified_models = {model.fqn, context.get_model("sushi.top_waiters").fqn} + + plan = context.plan_builder("dev").build() + assert plan.has_changes + assert set(plan.context_diff.modified_snapshots) == modified_models + assert plan.missing_intervals + context.apply(plan) + + # Make sure there are no changes when comparing against the existing environment. + plan = context.plan_builder("dev").build() + assert not plan.has_changes + assert not plan.context_diff.modified_snapshots + assert not plan.missing_intervals + + # Invalidate the environment and make sure that the plan detects the changes. 
+ context.invalidate_environment("dev") + plan = context.plan_builder("dev").build() + assert plan.has_changes + assert set(plan.context_diff.modified_snapshots) == modified_models + assert not plan.missing_intervals + context.apply(plan) + + +def test_plan_environment_statements_doesnt_cause_extra_diff(tmp_path: Path): + model_a = """ + MODEL ( + name test_schema.a, + kind FULL, + ); + + SELECT 1; + """ + + models_dir = tmp_path / "models" + models_dir.mkdir() + + (models_dir / "a.sql").write_text(model_a) + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + before_all=["select 1 as before_all"], + after_all=["select 2 as after_all"], + ) + ctx = Context(paths=[tmp_path], config=config) + + # first plan - should apply changes + assert ctx.plan(auto_apply=True, no_prompts=True).has_changes + + # second plan - nothing has changed so should report no changes + assert not ctx.plan(auto_apply=True, no_prompts=True).has_changes + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_plan_snapshot_table_exists_for_promoted_snapshot(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + model = context.get_model("sushi.waiter_revenue_by_day") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + + context.plan("dev", auto_apply=True, no_prompts=True, skip_tests=True) + + # Drop the views and make sure SQLMesh recreates them later + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + context.engine_adapter.drop_view(top_waiters_snapshot.table_name()) + context.engine_adapter.drop_view(top_waiters_snapshot.table_name(False)) + + # Make the environment unfinalized to force recreation of all views in the virtual layer + context.state_sync.state_sync.engine_adapter.execute( + "UPDATE sqlmesh._environments SET finalized_ts = NULL WHERE name = 'dev'" + ) + + context.plan( + "prod", + restate_models=["sushi.top_waiters"], + 
auto_apply=True, + no_prompts=True, + skip_tests=True, + ) + assert context.engine_adapter.table_exists(top_waiters_snapshot.table_name()) + + +def test_plan_twice_with_star_macro_yields_no_diff(tmp_path: Path): + init_example_project(tmp_path, engine_type="duckdb") + + star_model_definition = """ + MODEL ( + name sqlmesh_example.star_model, + kind FULL + ); + + SELECT @STAR(sqlmesh_example.full_model) FROM sqlmesh_example.full_model + """ + + star_model_path = tmp_path / "models" / "star_model.sql" + star_model_path.write_text(star_model_definition) + + db_path = str(tmp_path / "db.db") + config = Config( + gateways={"main": GatewayConfig(connection=DuckDBConnectionConfig(database=db_path))}, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + context = Context(paths=tmp_path, config=config) + context.plan(auto_apply=True, no_prompts=True) + + # Instantiate new context to remove caches etc + new_context = Context(paths=tmp_path, config=config) + + star_model = new_context.get_model("sqlmesh_example.star_model") + assert ( + star_model.render_query_or_raise().sql() + == 'SELECT CAST("full_model"."item_id" AS INT) AS "item_id", CAST("full_model"."num_orders" AS BIGINT) AS "num_orders" FROM "db"."sqlmesh_example"."full_model" AS "full_model"' + ) + + new_plan = new_context.plan_builder().build() + assert not new_plan.has_changes + assert not new_plan.new_snapshots + + +class OldPythonModel(PythonModel): + kind: ModelKind = ViewKind() + + +def test_python_model_default_kind_change(init_and_plan_context: t.Callable): + """ + Around 2024-07-17 Python models had their default Kind changed from VIEW to FULL in order to + avoid some edge cases where the views might not get updated in certain situations. 
+ + This test ensures that if a user had a Python `kind: VIEW` model stored in state, + it can still be loaded without error and just show as a breaking change from `kind: VIEW` + to `kind: FULL` + """ + + # note: we deliberately dont specify a Kind here to allow the defaults to be picked up + python_model_file = """import typing as t +import pandas as pd # noqa: TID253 +from sqlmesh import ExecutionContext, model + +@model( + "sushi.python_view_model", + columns={ + "id": "int", + } +) +def execute( + context: ExecutionContext, + **kwargs: t.Any, +) -> pd.DataFrame: + return pd.DataFrame([ + {"id": 1} + ]) +""" + + context: Context + context, _ = init_and_plan_context("examples/sushi") + + with open(context.path / "models" / "python_view_model.py", mode="w", encoding="utf8") as f: + f.write(python_model_file) + + # monkey-patch PythonModel to default to kind: View again + # and ViewKind to allow python models again + with ( + mock.patch.object(ViewKind, "supports_python_models", return_value=True), + mock.patch("sqlmesh.core.model.definition.PythonModel", OldPythonModel), + ): + context.load() + + # check the monkey-patching worked + model = context.get_model("sushi.python_view_model") + assert model.kind.name == ModelKindName.VIEW + assert model.source_type == "python" + + # apply plan + plan: Plan = context.plan(auto_apply=True) + + # check that run() still works even though we have a Python model with kind: View in the state + snapshot_ids = [s for s in plan.directly_modified if "python_view_model" in s.name] + snapshot_from_state = list(context.state_sync.get_snapshots(snapshot_ids).values())[0] + assert snapshot_from_state.model.kind.name == ModelKindName.VIEW + assert snapshot_from_state.model.source_type == "python" + context.run() + + # reload context to load model with new defaults + # this also shows the earlier monkey-patching is no longer in effect + context.load() + model = context.get_model("sushi.python_view_model") + assert model.kind.name == 
ModelKindName.FULL + assert model.source_type == "python" + + plan = context.plan( + categorizer_config=CategorizerConfig.all_full() + ) # the default categorizer_config doesnt auto-categorize python models + + assert plan.has_changes + assert not plan.indirectly_modified + + assert len(plan.directly_modified) == 1 + snapshot_id = list(plan.directly_modified)[0] + assert snapshot_id.name == '"memory"."sushi"."python_view_model"' + assert plan.modified_snapshots[snapshot_id].change_category == SnapshotChangeCategory.BREAKING + + context.apply(plan) + + df = context.engine_adapter.fetchdf("SELECT id FROM sushi.python_view_model") + assert df["id"].to_list() == [1] + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +@pytest.mark.parametrize( + "parent_a_category,parent_b_category,expected_child_category", + [ + ( + SnapshotChangeCategory.BREAKING, + SnapshotChangeCategory.BREAKING, + SnapshotChangeCategory.INDIRECT_BREAKING, + ), + ( + SnapshotChangeCategory.NON_BREAKING, + SnapshotChangeCategory.NON_BREAKING, + SnapshotChangeCategory.INDIRECT_NON_BREAKING, + ), + ( + SnapshotChangeCategory.BREAKING, + SnapshotChangeCategory.NON_BREAKING, + SnapshotChangeCategory.INDIRECT_NON_BREAKING, + ), + ( + SnapshotChangeCategory.NON_BREAKING, + SnapshotChangeCategory.BREAKING, + SnapshotChangeCategory.INDIRECT_BREAKING, + ), + ( + SnapshotChangeCategory.NON_BREAKING, + SnapshotChangeCategory.METADATA, + SnapshotChangeCategory.METADATA, + ), + ( + SnapshotChangeCategory.BREAKING, + SnapshotChangeCategory.METADATA, + SnapshotChangeCategory.METADATA, + ), + ( + SnapshotChangeCategory.METADATA, + SnapshotChangeCategory.BREAKING, + SnapshotChangeCategory.INDIRECT_BREAKING, + ), + ( + SnapshotChangeCategory.METADATA, + SnapshotChangeCategory.NON_BREAKING, + SnapshotChangeCategory.INDIRECT_NON_BREAKING, + ), + ( + SnapshotChangeCategory.METADATA, + SnapshotChangeCategory.METADATA, + SnapshotChangeCategory.METADATA, + ), + ], +) +def test_rebase_two_changed_parents( + 
init_and_plan_context: t.Callable, + parent_a_category: SnapshotChangeCategory, # This change is deployed to prod first + parent_b_category: SnapshotChangeCategory, # This change is deployed to prod second + expected_child_category: SnapshotChangeCategory, +): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + initial_model_a = context.get_model("sushi.orders") + initial_model_b = context.get_model("sushi.items") + + # Make change A and deploy it to dev_a + context.upsert_model(initial_model_a.name, stamp="1") + plan_builder = context.plan_builder("dev_a", skip_tests=True) + plan_builder.set_choice(context.get_snapshot(initial_model_a.name), parent_a_category) + context.apply(plan_builder.build()) + + # Make change B and deploy it to dev_b + context.upsert_model(initial_model_a) + context.upsert_model(initial_model_b.name, stamp="1") + plan_builder = context.plan_builder("dev_b", skip_tests=True) + plan_builder.set_choice(context.get_snapshot(initial_model_b.name), parent_b_category) + context.apply(plan_builder.build()) + + # Deploy change A to prod + context.upsert_model(initial_model_a.name, stamp="1") + context.upsert_model(initial_model_b) + context.plan("prod", auto_apply=True, no_prompts=True, skip_tests=True) + + # Apply change B in addition to A and plan against prod + context.upsert_model(initial_model_b.name, stamp="1") + plan = context.plan_builder("prod", skip_tests=True).build() + + # Validate the category of child snapshots + direct_child_snapshot = plan.snapshots[context.get_snapshot("sushi.order_items").snapshot_id] + assert direct_child_snapshot.change_category == expected_child_category + + indirect_child_snapshot = plan.snapshots[context.get_snapshot("sushi.top_waiters").snapshot_id] + assert indirect_child_snapshot.change_category == expected_child_category + + +@pytest.mark.parametrize( + "context_fixture", + ["sushi_context", "sushi_no_default_catalog"], +) +def test_unaligned_start_snapshots(context_fixture: 
Context, request): + context = request.getfixturevalue(context_fixture) + environment = "dev" + apply_to_environment(context, environment) + # Make breaking change to model upstream of a depends_on_self model + context.upsert_model("sushi.order_items", stamp="1") + # Apply the change starting at a date later then the beginning of the downstream depends_on_self model + plan = apply_to_environment( + context, + environment, + choice=SnapshotChangeCategory.BREAKING, + plan_start="2 days ago", + enable_preview=True, + ) + revenue_lifetime_snapshot = context.get_snapshot( + "sushi.customer_revenue_lifetime", raise_if_missing=True + ) + # Validate that the depends_on_self model is non-deployable + assert not plan.deployability_index.is_deployable(revenue_lifetime_snapshot) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_unaligned_start_snapshot_with_non_deployable_downstream(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + downstream_model_name = "memory.sushi.customer_max_revenue" + + expressions = d.parse( + f""" + MODEL ( + name {downstream_model_name}, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key customer_id, + forward_only true, + ), + ); + + SELECT + customer_id, MAX(revenue) AS max_revenue + FROM memory.sushi.customer_revenue_lifetime + GROUP BY 1; + """ + ) + + downstream_model = load_sql_based_model(expressions) + assert downstream_model.forward_only + context.upsert_model(downstream_model) + + context.plan(auto_apply=True, no_prompts=True) + + customer_revenue_lifetime_model = context.get_model("sushi.customer_revenue_lifetime") + kwargs = { + **customer_revenue_lifetime_model.dict(), + "name": "memory.sushi.customer_revenue_lifetime_new", + "kind": dict( + name="INCREMENTAL_UNMANAGED" + ), # Make it incremental unmanaged to ensure the depends_on_past behavior. 
+ } + context.upsert_model(SqlModel.parse_obj(kwargs)) + context.upsert_model( + downstream_model_name, + query_=ParsableSql( + sql="SELECT customer_id, MAX(revenue) AS max_revenue FROM memory.sushi.customer_revenue_lifetime_new GROUP BY 1" + ), + ) + + plan = context.plan_builder("dev", enable_preview=True).build() + assert {s.name for s in plan.new_snapshots} == { + '"memory"."sushi"."customer_revenue_lifetime_new"', + '"memory"."sushi"."customer_max_revenue"', + } + for snapshot_interval in plan.missing_intervals: + assert not plan.deployability_index.is_deployable(snapshot_interval.snapshot_id) + assert snapshot_interval.intervals[0][0] == to_timestamp("2023-01-07") + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_indirect_non_breaking_view_is_updated_with_new_table_references( + init_and_plan_context: t.Callable, +): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + # Add a new projection to the base model + model = context.get_model("sushi.waiter_revenue_by_day") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + + context.plan("prod", auto_apply=True, no_prompts=True, skip_tests=True) + + # Run the janitor to delete the old snapshot record + context.run_janitor(ignore_ttl=True) + + # Check the downstream view and make sure it's still queryable + assert context.get_model("sushi.top_waiters").kind.is_view + row_num = context.engine_adapter.fetchone(f"SELECT COUNT(*) FROM sushi.top_waiters")[0] + assert row_num > 0 + + +@time_machine.travel("2023-01-08 00:00:00 UTC") +def test_annotated_self_referential_model(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + # Projections are fully annotated in the query but columns were not specified explicitly + expressions = d.parse( + f""" + MODEL ( + name memory.sushi.test_self_ref, + kind FULL, + start '2023-01-01', + ); + + SELECT 1::INT AS one FROM memory.sushi.test_self_ref; + """ + ) + model = 
load_sql_based_model(expressions) + assert model.depends_on_self + context.upsert_model(model) + + context.plan("prod", skip_tests=True, no_prompts=True, auto_apply=True) + + df = context.fetchdf("SELECT one FROM memory.sushi.test_self_ref") + assert len(df) == 0 + + +@time_machine.travel("2023-01-08 00:00:00 UTC") +def test_creating_stage_for_first_batch_only(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + expressions = d.parse( + """ + MODEL ( + name memory.sushi.test_batch_size, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key one, + batch_size 1, + ), + + start '2023-01-01', + ); + + CREATE SCHEMA IF NOT EXISTS test_schema; + CREATE TABLE IF NOT EXISTS test_schema.creating_counter (a INT); + + SELECT 1::INT AS one; + + @IF(@runtime_stage = 'creating', INSERT INTO test_schema.creating_counter (a) VALUES (1)); + """ + ) + model = load_sql_based_model(expressions) + context.upsert_model(model) + + context.plan("prod", skip_tests=True, no_prompts=True, auto_apply=True) + assert ( + context.engine_adapter.fetchone("SELECT COUNT(*) FROM test_schema.creating_counter")[0] == 1 + ) diff --git a/tests/core/integration/test_config.py b/tests/core/integration/test_config.py new file mode 100644 index 0000000000..5d571cd7c5 --- /dev/null +++ b/tests/core/integration/test_config.py @@ -0,0 +1,580 @@ +from __future__ import annotations + +import typing as t +from unittest.mock import patch +import logging +import pytest +from pytest import MonkeyPatch +from pathlib import Path +from pytest_mock.plugin import MockerFixture +from sqlglot import exp +from IPython.utils.capture import capture_output + +from sqlmesh.core.config import ( + Config, + GatewayConfig, + ModelDefaultsConfig, + DuckDBConnectionConfig, + TableNamingConvention, + AutoCategorizationMode, +) +from sqlmesh.core.config.common import EnvironmentSuffixTarget +from sqlmesh.core.context import Context +from sqlmesh.core.config.plan import PlanConfig +from 
sqlmesh.core.engine_adapter import DuckDBEngineAdapter +from sqlmesh.core.model import SqlModel +from sqlmesh.core.model.common import ParsableSql +from sqlmesh.core.snapshot import ( + SnapshotChangeCategory, +) +from sqlmesh.utils.errors import ( + ConfigError, +) +from tests.conftest import DuckDBMetadata +from tests.utils.test_helpers import use_terminal_console +from tests.utils.test_filesystem import create_temp_file +from tests.core.integration.utils import apply_to_environment, initial_add + +pytestmark = pytest.mark.slow + + +@pytest.mark.set_default_connection(disable=True) +def test_missing_connection_config(): + # This is testing the actual implementation of Config.get_connection + # To make writing tests easier, it's patched by the autouse fixture provide_sqlmesh_default_connection + # Case 1: No default_connection or gateways specified should raise a ConfigError + with pytest.raises(ConfigError): + ctx = Context(config=Config()) + + # Case 2: No connection specified in the gateway should raise a ConfigError + with pytest.raises(ConfigError): + ctx = Context(config=Config(gateways={"incorrect": GatewayConfig()})) + + # Case 3: Specifying a default_connection or connection in the gateway should work + ctx = Context(config=Config(default_connection=DuckDBConnectionConfig())) + ctx = Context( + config=Config(gateways={"default": GatewayConfig(connection=DuckDBConnectionConfig())}) + ) + + +def test_physical_table_naming_strategy_table_only(copy_to_temp_path: t.Callable): + sushi_context = Context( + paths=copy_to_temp_path("examples/sushi"), + config="table_only_naming_config", + ) + + assert sushi_context.config.physical_table_naming_convention == TableNamingConvention.TABLE_ONLY + sushi_context.plan(auto_apply=True) + + adapter = sushi_context.engine_adapter + + snapshot_tables = [ + dict(catalog=str(r[0]), schema=str(r[1]), table=str(r[2])) + for r in adapter.fetchall( + "select table_catalog, table_schema, table_name from information_schema.tables 
where table_type='BASE TABLE'" + ) + ] + + assert all([not t["table"].startswith("sushi") for t in snapshot_tables]) + + prod_env = sushi_context.state_reader.get_environment("prod") + assert prod_env + + prod_env_snapshots = sushi_context.state_reader.get_snapshots(prod_env.snapshots) + + assert all( + s.table_naming_convention == TableNamingConvention.TABLE_ONLY + for s in prod_env_snapshots.values() + ) + + +def test_physical_table_naming_strategy_hash_md5(copy_to_temp_path: t.Callable): + sushi_context = Context( + paths=copy_to_temp_path("examples/sushi"), + config="hash_md5_naming_config", + ) + + assert sushi_context.config.physical_table_naming_convention == TableNamingConvention.HASH_MD5 + sushi_context.plan(auto_apply=True) + + adapter = sushi_context.engine_adapter + + snapshot_tables = [ + dict(catalog=str(r[0]), schema=str(r[1]), table=str(r[2])) + for r in adapter.fetchall( + "select table_catalog, table_schema, table_name from information_schema.tables where table_type='BASE TABLE'" + ) + ] + + assert all([not t["table"].startswith("sushi") for t in snapshot_tables]) + assert all([t["table"].startswith("sqlmesh_md5") for t in snapshot_tables]) + + prod_env = sushi_context.state_reader.get_environment("prod") + assert prod_env + + prod_env_snapshots = sushi_context.state_reader.get_snapshots(prod_env.snapshots) + + assert all( + s.table_naming_convention == TableNamingConvention.HASH_MD5 + for s in prod_env_snapshots.values() + ) + + +def test_environment_suffix_target_table(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context( + "examples/sushi", config="environment_suffix_table_config" + ) + context.apply(plan) + metadata = DuckDBMetadata.from_context(context) + environments_schemas = {"sushi"} + internal_schemas = {"sqlmesh", "sqlmesh__sushi"} + starting_schemas = environments_schemas | internal_schemas + # Make sure no new schemas are created + assert set(metadata.schemas) - starting_schemas == {"raw"} + prod_views = {x for 
x in metadata.qualified_views if x.db in environments_schemas} + # Make sure that all models are present + assert len(prod_views) == 16 + apply_to_environment(context, "dev") + # Make sure no new schemas are created + assert set(metadata.schemas) - starting_schemas == {"raw"} + dev_views = { + x for x in metadata.qualified_views if x.db in environments_schemas and "__dev" in x.name + } + # Make sure that there is a view with `__dev` for each view that exists in prod + assert len(dev_views) == len(prod_views) + assert {x.name.replace("__dev", "") for x in dev_views} - {x.name for x in prod_views} == set() + context.invalidate_environment("dev") + context._run_janitor() + views_after_janitor = metadata.qualified_views + # Make sure that the number of views after the janitor is the same as when you subtract away dev views + assert len(views_after_janitor) == len( + {x.sql(dialect="duckdb") for x in views_after_janitor} + - {x.sql(dialect="duckdb") for x in dev_views} + ) + # Double check there are no dev views + assert len({x for x in views_after_janitor if "__dev" in x.name}) == 0 + # Make sure prod views were not removed + assert {x.sql(dialect="duckdb") for x in prod_views} - { + x.sql(dialect="duckdb") for x in views_after_janitor + } == set() + + +def test_environment_suffix_target_catalog(tmp_path: Path, monkeypatch: MonkeyPatch) -> None: + monkeypatch.chdir(tmp_path) + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(catalogs={"main_warehouse": ":memory:"}), + environment_suffix_target=EnvironmentSuffixTarget.CATALOG, + ) + + assert config.default_connection + + models_dir = tmp_path / "models" + models_dir.mkdir() + + (models_dir / "model.sql").write_text(""" + MODEL ( + name example_schema.test_model, + kind FULL + ); + + SELECT '1' as a""") + + (models_dir / "fqn_model.sql").write_text(""" + MODEL ( + name memory.example_fqn_schema.test_model_fqn, + kind FULL + ); + + SELECT '1' as a""") + 
+ ctx = Context(config=config, paths=tmp_path) + + metadata = DuckDBMetadata.from_context(ctx) + assert ctx.default_catalog == "main_warehouse" + assert metadata.catalogs == {"main_warehouse", "memory"} + + ctx.plan(auto_apply=True) + + # prod should go to the default catalog and not be overridden to a catalog called 'prod' + assert ( + ctx.engine_adapter.fetchone("select * from main_warehouse.example_schema.test_model")[0] # type: ignore + == "1" + ) + assert ( + ctx.engine_adapter.fetchone("select * from memory.example_fqn_schema.test_model_fqn")[0] # type: ignore + == "1" + ) + assert metadata.catalogs == {"main_warehouse", "memory"} + assert metadata.schemas_in_catalog("main_warehouse") == [ + "example_schema", + "sqlmesh__example_schema", + ] + assert metadata.schemas_in_catalog("memory") == [ + "example_fqn_schema", + "sqlmesh__example_fqn_schema", + ] + + # dev should be overridden to go to a catalogs called 'main_warehouse__dev' and 'memory__dev' + ctx.plan(environment="dev", include_unmodified=True, auto_apply=True) + assert ( + ctx.engine_adapter.fetchone("select * from main_warehouse__dev.example_schema.test_model")[ + 0 + ] # type: ignore + == "1" + ) + assert ( + ctx.engine_adapter.fetchone("select * from memory__dev.example_fqn_schema.test_model_fqn")[ + 0 + ] # type: ignore + == "1" + ) + assert metadata.catalogs == {"main_warehouse", "main_warehouse__dev", "memory", "memory__dev"} + + # schemas in dev envs should match prod and not have a suffix + assert metadata.schemas_in_catalog("main_warehouse") == [ + "example_schema", + "sqlmesh__example_schema", + ] + assert metadata.schemas_in_catalog("main_warehouse__dev") == ["example_schema"] + assert metadata.schemas_in_catalog("memory") == [ + "example_fqn_schema", + "sqlmesh__example_fqn_schema", + ] + assert metadata.schemas_in_catalog("memory__dev") == ["example_fqn_schema"] + + ctx.invalidate_environment("dev", sync=True) + + # dev catalogs cleaned up + assert metadata.catalogs == {"main_warehouse", 
"memory"} + + # prod catalogs still contain physical layer and views still work + assert metadata.schemas_in_catalog("main_warehouse") == [ + "example_schema", + "sqlmesh__example_schema", + ] + assert metadata.schemas_in_catalog("memory") == [ + "example_fqn_schema", + "sqlmesh__example_fqn_schema", + ] + + assert ( + ctx.engine_adapter.fetchone("select * from main_warehouse.example_schema.test_model")[0] # type: ignore + == "1" + ) + assert ( + ctx.engine_adapter.fetchone("select * from memory.example_fqn_schema.test_model_fqn")[0] # type: ignore + == "1" + ) + + +def test_environment_catalog_mapping(init_and_plan_context: t.Callable): + environments_schemas = {"raw", "sushi"} + + def get_prod_dev_views(metadata: DuckDBMetadata) -> t.Tuple[t.Set[exp.Table], t.Set[exp.Table]]: + views = metadata.qualified_views + prod_views = { + x for x in views if x.catalog == "prod_catalog" if x.db in environments_schemas + } + dev_views = {x for x in views if x.catalog == "dev_catalog" if x.db in environments_schemas} + return prod_views, dev_views + + def get_default_catalog_and_non_tables( + metadata: DuckDBMetadata, default_catalog: t.Optional[str] + ) -> t.Tuple[t.Set[exp.Table], t.Set[exp.Table]]: + tables = metadata.qualified_tables + user_default_tables = { + x for x in tables if x.catalog == default_catalog and x.db != "sqlmesh" + } + non_default_tables = {x for x in tables if x.catalog != default_catalog} + return user_default_tables, non_default_tables + + context, plan = init_and_plan_context( + "examples/sushi", config="environment_catalog_mapping_config" + ) + context.apply(plan) + metadata = DuckDBMetadata(context.engine_adapter) + state_metadata = DuckDBMetadata.from_context(context.state_sync.state_sync) + prod_views, dev_views = get_prod_dev_views(metadata) + ( + user_default_tables, + non_default_tables, + ) = get_default_catalog_and_non_tables(metadata, context.default_catalog) + assert len(prod_views) == 16 + assert len(dev_views) == 0 + assert 
len(user_default_tables) == 15 + assert state_metadata.schemas == ["sqlmesh"] + assert {x.sql() for x in state_metadata.qualified_tables}.issuperset( + { + "physical.sqlmesh._environments", + "physical.sqlmesh._intervals", + "physical.sqlmesh._snapshots", + "physical.sqlmesh._versions", + } + ) + apply_to_environment(context, "dev") + prod_views, dev_views = get_prod_dev_views(metadata) + ( + user_default_tables, + non_default_tables, + ) = get_default_catalog_and_non_tables(metadata, context.default_catalog) + assert len(prod_views) == 16 + assert len(dev_views) == 16 + assert len(user_default_tables) == 16 + assert len(non_default_tables) == 0 + assert state_metadata.schemas == ["sqlmesh"] + assert {x.sql() for x in state_metadata.qualified_tables}.issuperset( + { + "physical.sqlmesh._environments", + "physical.sqlmesh._intervals", + "physical.sqlmesh._snapshots", + "physical.sqlmesh._versions", + } + ) + apply_to_environment(context, "prodnot") + prod_views, dev_views = get_prod_dev_views(metadata) + ( + user_default_tables, + non_default_tables, + ) = get_default_catalog_and_non_tables(metadata, context.default_catalog) + assert len(prod_views) == 16 + assert len(dev_views) == 32 + assert len(user_default_tables) == 16 + assert len(non_default_tables) == 0 + assert state_metadata.schemas == ["sqlmesh"] + assert {x.sql() for x in state_metadata.qualified_tables}.issuperset( + { + "physical.sqlmesh._environments", + "physical.sqlmesh._intervals", + "physical.sqlmesh._snapshots", + "physical.sqlmesh._versions", + } + ) + context.invalidate_environment("dev") + context._run_janitor() + prod_views, dev_views = get_prod_dev_views(metadata) + ( + user_default_tables, + non_default_tables, + ) = get_default_catalog_and_non_tables(metadata, context.default_catalog) + assert len(prod_views) == 16 + assert len(dev_views) == 16 + assert len(user_default_tables) == 16 + assert len(non_default_tables) == 0 + assert state_metadata.schemas == ["sqlmesh"] + assert {x.sql() for 
x in state_metadata.qualified_tables}.issuperset( + { + "physical.sqlmesh._environments", + "physical.sqlmesh._intervals", + "physical.sqlmesh._snapshots", + "physical.sqlmesh._versions", + } + ) + + +@use_terminal_console +def test_plan_always_recreate_environment(tmp_path: Path): + def plan_with_output(ctx: Context, environment: str): + with patch.object(logger, "info") as mock_logger: + with capture_output() as output: + ctx.load() + ctx.plan(environment, no_prompts=True, auto_apply=True) + + # Facade logs info "Promoting environment {environment}" + assert mock_logger.call_args[0][1] == environment + + return output + + models_dir = tmp_path / "models" + + logger = logging.getLogger("sqlmesh.core.state_sync.db.facade") + + create_temp_file( + tmp_path, models_dir / "a.sql", "MODEL (name test.a, kind FULL); SELECT 1 AS col" + ) + + config = Config(plan=PlanConfig(always_recreate_environment=True)) + ctx = Context(paths=[tmp_path], config=config) + + # Case 1: Neither prod nor dev exists, so dev is initialized + output = plan_with_output(ctx, "dev") + + assert """`dev` environment will be initialized""" in output.stdout + + # Case 2: Prod does not exist, so dev is updated + create_temp_file( + tmp_path, models_dir / "a.sql", "MODEL (name test.a, kind FULL); SELECT 5 AS col" + ) + + output = plan_with_output(ctx, "dev") + assert "`dev` environment will be initialized" in output.stdout + + # Case 3: Prod is initialized, so plan comparisons moving forward should be against prod + output = plan_with_output(ctx, "prod") + assert "`prod` environment will be initialized" in output.stdout + + # Case 4: Dev is updated with a breaking change. 
Prod exists now so plan comparisons moving forward should be against prod + create_temp_file( + tmp_path, models_dir / "a.sql", "MODEL (name test.a, kind FULL); SELECT 10 AS col" + ) + ctx.load() + + plan = ctx.plan_builder("dev").build() + + assert ( + next(iter(plan.context_diff.snapshots.values())).change_category + == SnapshotChangeCategory.BREAKING + ) + + output = plan_with_output(ctx, "dev") + assert "New environment `dev` will be created from `prod`" in output.stdout + assert "Differences from the `prod` environment" in output.stdout + + # Case 5: Dev is updated with a metadata change, but comparison against prod shows both the previous and the current changes + # so it's still classified as a breaking change + create_temp_file( + tmp_path, + models_dir / "a.sql", + "MODEL (name test.a, kind FULL, owner 'test'); SELECT 10 AS col", + ) + ctx.load() + + plan = ctx.plan_builder("dev").build() + + assert ( + next(iter(plan.context_diff.snapshots.values())).change_category + == SnapshotChangeCategory.BREAKING + ) + + output = plan_with_output(ctx, "dev") + assert "New environment `dev` will be created from `prod`" in output.stdout + assert "Differences from the `prod` environment" in output.stdout + + stdout_rstrip = "\n".join([line.rstrip() for line in output.stdout.split("\n")]) + assert ( + """MODEL ( + name test.a, ++ owner test, + kind FULL + ) + SELECT +- 5 AS col ++ 10 AS col""" + in stdout_rstrip + ) + + # Case 6: Ensure that target environment and create_from environment are not the same + output = plan_with_output(ctx, "prod") + assert not "New environment `prod` will be created from `prod`" in output.stdout + + # Case 7: Check that we can still run Context::diff() against any environment + for environment in ["dev", "prod"]: + context_diff = ctx._context_diff(environment) + assert context_diff.environment == environment + + +def test_before_all_after_all_execution_order(tmp_path: Path, mocker: MockerFixture): + model = """ + MODEL ( + name 
test_schema.model_that_depends_on_before_all, + kind FULL, + ); + + SELECT id, value FROM before_all_created_table + """ + + models_dir = tmp_path / "models" + models_dir.mkdir() + + with open(models_dir / "model.sql", "w") as f: + f.write(model) + + # before_all statement that creates a table that the above model depends on + before_all_statement = ( + "CREATE TABLE IF NOT EXISTS before_all_created_table AS SELECT 1 AS id, 'test' AS value" + ) + + # after_all that depends on the model + after_all_statement = "CREATE TABLE IF NOT EXISTS after_all_created_table AS SELECT id, value FROM test_schema.model_that_depends_on_before_all" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + before_all=[before_all_statement], + after_all=[after_all_statement], + ) + + execute_calls: t.List[str] = [] + + original_duckdb_execute = DuckDBEngineAdapter.execute + + def track_duckdb_execute(self, expression, **kwargs): + sql = expression if isinstance(expression, str) else expression.sql(dialect="duckdb") + state_tables = [ + "_snapshots", + "_environments", + "_versions", + "_intervals", + "_auto_restatements", + "_environment_statements", + ] + + # to ignore the state queries + if not any(table in sql.lower() for table in state_tables): + execute_calls.append(sql) + + return original_duckdb_execute(self, expression, **kwargs) + + ctx = Context(paths=[tmp_path], config=config) + + # the plan would fail if the execution order ever changes and before_all statements dont execute first + ctx.plan(auto_apply=True, no_prompts=True) + + mocker.patch.object(DuckDBEngineAdapter, "execute", track_duckdb_execute) + + # run with the patched execute + ctx.run("prod", start="2023-01-01", end="2023-01-02") + + # validate explicitly that the first execute is for the before_all + assert "before_all_created_table" in execute_calls[0] + + # and that the last is the sole after all that depends on the model + assert "after_all_created_table" in execute_calls[-1] + + +def 
test_auto_categorization(sushi_context: Context): + environment = "dev" + for config in sushi_context.configs.values(): + config.plan.auto_categorize_changes.sql = AutoCategorizationMode.FULL + initial_add(sushi_context, environment) + + version = sushi_context.get_snapshot( + "sushi.waiter_as_customer_by_day", raise_if_missing=True + ).version + fingerprint = sushi_context.get_snapshot( + "sushi.waiter_as_customer_by_day", raise_if_missing=True + ).fingerprint + + model = t.cast(SqlModel, sushi_context.get_model("sushi.customers", raise_if_missing=True)) + sushi_context.upsert_model( + "sushi.customers", + query_=ParsableSql(sql=model.query.select("'foo' AS foo").sql(dialect=model.dialect)), # type: ignore + ) + apply_to_environment(sushi_context, environment) + + assert ( + sushi_context.get_snapshot( + "sushi.waiter_as_customer_by_day", raise_if_missing=True + ).change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert ( + sushi_context.get_snapshot( + "sushi.waiter_as_customer_by_day", raise_if_missing=True + ).fingerprint + != fingerprint + ) + assert ( + sushi_context.get_snapshot("sushi.waiter_as_customer_by_day", raise_if_missing=True).version + == version + ) diff --git a/tests/core/integration/test_cron.py b/tests/core/integration/test_cron.py new file mode 100644 index 0000000000..fa327ac36f --- /dev/null +++ b/tests/core/integration/test_cron.py @@ -0,0 +1,247 @@ +from __future__ import annotations + +import typing as t +import pytest +import time_machine + +from sqlmesh.core import dialect as d +from sqlmesh.core.model import ( + SqlModel, + load_sql_based_model, +) +from sqlmesh.core.plan import SnapshotIntervals +from sqlmesh.utils.date import to_timestamp +from tests.core.integration.utils import add_projection_to_model + +pytestmark = pytest.mark.slow + + +@time_machine.travel("2023-01-08 00:00:00 UTC") +@pytest.mark.parametrize( + "forward_only, expected_intervals", + [ + ( + False, + [ + (to_timestamp("2023-01-01"), 
to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + ], + ), + ( + True, + [ + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + ], + ), + ], +) +def test_cron_not_aligned_with_day_boundary( + init_and_plan_context: t.Callable, + forward_only: bool, + expected_intervals: t.List[t.Tuple[int, int]], +): + context, plan = init_and_plan_context("examples/sushi") + + model = context.get_model("sushi.waiter_revenue_by_day") + model = SqlModel.parse_obj( + { + **model.dict(), + "kind": model.kind.copy(update={"forward_only": forward_only}), + "cron": "0 12 * * *", + } + ) + context.upsert_model(model) + + plan = context.plan_builder("prod", skip_tests=True).build() + context.apply(plan) + + waiter_revenue_by_day_snapshot = context.get_snapshot(model.name, raise_if_missing=True) + assert waiter_revenue_by_day_snapshot.intervals == [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-07")) + ] + + model = add_projection_to_model(t.cast(SqlModel, model), literal=True) + context.upsert_model(model) + + waiter_revenue_by_day_snapshot = context.get_snapshot( + "sushi.waiter_revenue_by_day", raise_if_missing=True + ) + + with time_machine.travel("2023-01-08 00:10:00 UTC"): # Past model's cron. 
+ plan = context.plan_builder( + "dev", select_models=[model.name], skip_tests=True, enable_preview=True + ).build() + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=waiter_revenue_by_day_snapshot.snapshot_id, + intervals=expected_intervals, + ), + ] + + +@time_machine.travel("2023-01-08 00:00:00 UTC") +def test_cron_not_aligned_with_day_boundary_new_model(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + existing_model = context.get_model("sushi.waiter_revenue_by_day") + existing_model = SqlModel.parse_obj( + { + **existing_model.dict(), + "kind": existing_model.kind.copy(update={"forward_only": True}), + } + ) + context.upsert_model(existing_model) + + plan = context.plan_builder("prod", skip_tests=True).build() + context.apply(plan) + + # Add a new model and make a change to a forward-only model. + # The cron of the new model is not aligned with the day boundary. + new_model = load_sql_based_model( + d.parse( + """ + MODEL ( + name memory.sushi.new_model, + kind FULL, + cron '0 8 * * *', + start '2023-01-01', + ); + + SELECT 1 AS one; + """ + ) + ) + context.upsert_model(new_model) + + existing_model = add_projection_to_model(t.cast(SqlModel, existing_model), literal=True) + context.upsert_model(existing_model) + + plan = context.plan_builder("dev", skip_tests=True, enable_preview=True).build() + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=context.get_snapshot( + "memory.sushi.new_model", raise_if_missing=True + ).snapshot_id, + intervals=[(to_timestamp("2023-01-06"), to_timestamp("2023-01-07"))], + ), + SnapshotIntervals( + snapshot_id=context.get_snapshot( + "sushi.waiter_revenue_by_day", raise_if_missing=True + ).snapshot_id, + intervals=[ + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + +@time_machine.travel("2023-01-08 00:00:00 UTC", tick=False) +def 
test_parent_cron_after_child(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + + model = context.get_model("sushi.waiter_revenue_by_day") + model = SqlModel.parse_obj( + { + **model.dict(), + "cron": "50 23 * * *", + } + ) + context.upsert_model(model) + + plan = context.plan_builder("prod", skip_tests=True).build() + context.apply(plan) + + waiter_revenue_by_day_snapshot = context.get_snapshot(model.name, raise_if_missing=True) + assert waiter_revenue_by_day_snapshot.intervals == [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-07")) + ] + + top_waiters_model = context.get_model("sushi.top_waiters") + top_waiters_model = add_projection_to_model(t.cast(SqlModel, top_waiters_model), literal=True) + context.upsert_model(top_waiters_model) + + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + + with time_machine.travel("2023-01-08 23:55:00 UTC"): # Past parent's cron, but before child's + plan = context.plan_builder("dev", skip_tests=True).build() + # Make sure the waiter_revenue_by_day model is not backfilled. 
+ assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiters_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + +@time_machine.travel("2025-03-08 00:00:00 UTC") +def test_tz(init_and_plan_context): + context, _ = init_and_plan_context("examples/sushi") + + model = context.get_model("sushi.waiter_revenue_by_day") + context.upsert_model( + SqlModel.parse_obj( + {**model.dict(), "cron_tz": "America/Los_Angeles", "start": "2025-03-07"} + ) + ) + + def assert_intervals(plan, intervals): + assert ( + next( + intervals.intervals + for intervals in plan.missing_intervals + if intervals.snapshot_id.name == model.fqn + ) + == intervals + ) + + plan = context.plan_builder("prod", skip_tests=True).build() + + # we have missing intervals but not waiter_revenue_by_day because it's not midnight pacific yet + assert plan.missing_intervals + + with pytest.raises(StopIteration): + assert_intervals(plan, []) + + # now we're ready 8AM UTC == midnight PST + with time_machine.travel("2025-03-08 08:00:00 UTC"): + plan = context.plan_builder("prod", skip_tests=True).build() + assert_intervals(plan, [(to_timestamp("2025-03-07"), to_timestamp("2025-03-08"))]) + + with time_machine.travel("2025-03-09 07:00:00 UTC"): + plan = context.plan_builder("prod", skip_tests=True).build() + + assert_intervals( + plan, + [ + (to_timestamp("2025-03-07"), to_timestamp("2025-03-08")), + ], + ) + + with time_machine.travel("2025-03-09 08:00:00 UTC"): + plan = context.plan_builder("prod", skip_tests=True).build() + + assert_intervals( + plan, + [ + 
(to_timestamp("2025-03-07"), to_timestamp("2025-03-08")), + (to_timestamp("2025-03-08"), to_timestamp("2025-03-09")), + ], + ) + + context.apply(plan) + + plan = context.plan_builder("prod", skip_tests=True).build() + assert not plan.missing_intervals diff --git a/tests/core/integration/test_dbt.py b/tests/core/integration/test_dbt.py new file mode 100644 index 0000000000..6f23acb97e --- /dev/null +++ b/tests/core/integration/test_dbt.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +import typing as t +import pytest +from sqlmesh.core.model.common import ParsableSql +import time_machine + +from sqlmesh.core.context import Context +from sqlmesh.core.model import ( + IncrementalUnmanagedKind, +) +from sqlmesh.core.snapshot import ( + DeployabilityIndex, + SnapshotChangeCategory, +) + +if t.TYPE_CHECKING: + pass + +pytestmark = pytest.mark.slow + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_dbt_select_star_is_directly_modified(sushi_test_dbt_context: Context): + context = sushi_test_dbt_context + + model = context.get_model("sushi.simple_model_a") + context.upsert_model( + model, + query_=ParsableSql(sql="SELECT 1 AS a, 2 AS b"), + ) + + snapshot_a_id = context.get_snapshot("sushi.simple_model_a").snapshot_id # type: ignore + snapshot_b_id = context.get_snapshot("sushi.simple_model_b").snapshot_id # type: ignore + + plan = context.plan_builder("dev", skip_tests=True).build() + assert plan.directly_modified == {snapshot_a_id, snapshot_b_id} + assert {i.snapshot_id for i in plan.missing_intervals} == {snapshot_a_id, snapshot_b_id} + + assert plan.snapshots[snapshot_a_id].change_category == SnapshotChangeCategory.NON_BREAKING + assert plan.snapshots[snapshot_b_id].change_category == SnapshotChangeCategory.NON_BREAKING + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_dbt_is_incremental_table_is_missing(sushi_test_dbt_context: Context): + context = sushi_test_dbt_context + + model = 
context.get_model("sushi.waiter_revenue_by_day_v2") + model = model.copy(update={"kind": IncrementalUnmanagedKind(), "start": "2023-01-01"}) + context.upsert_model(model) + context._standalone_audits["sushi.test_top_waiters"].start = "2023-01-01" + + context.plan("prod", auto_apply=True, no_prompts=True, skip_tests=True) + + snapshot = context.get_snapshot("sushi.waiter_revenue_by_day_v2") + assert snapshot + + # Manually drop the table + context.engine_adapter.drop_table(snapshot.table_name()) + + context.snapshot_evaluator.evaluate( + snapshot, + start="2023-01-01", + end="2023-01-08", + execution_time="2023-01-08 15:00:00", + snapshots={s.name: s for s in context.snapshots.values()}, + deployability_index=DeployabilityIndex.all_deployable(), + ) + + # Make sure the table was recreated + assert context.engine_adapter.table_exists(snapshot.table_name()) + + +def test_model_attr(sushi_test_dbt_context: Context, assert_exp_eq): + context = sushi_test_dbt_context + model = context.get_model("sushi.top_waiters") + assert_exp_eq( + model.render_query(), + """ + SELECT + CAST("waiter_id" AS INT) AS "waiter_id", + CAST("revenue" AS DOUBLE) AS "revenue", + 3 AS "model_columns" + FROM "memory"."sushi"."waiter_revenue_by_day_v2" AS "waiter_revenue_by_day_v2" + WHERE + "ds" = ( + SELECT + MAX("ds") + FROM "memory"."sushi"."waiter_revenue_by_day_v2" AS "waiter_revenue_by_day_v2" + ) + ORDER BY + "revenue" DESC NULLS FIRST + LIMIT 10 + """, + ) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_dbt_requirements(sushi_dbt_context: Context): + assert set(sushi_dbt_context.requirements) == {"dbt-core", "dbt-duckdb"} + assert sushi_dbt_context.requirements["dbt-core"].startswith("1.") + assert sushi_dbt_context.requirements["dbt-duckdb"].startswith("1.") + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_dbt_dialect_with_normalization_strategy(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context( + "tests/fixtures/dbt/sushi_test", 
config="test_config_with_normalization_strategy" + ) + assert context.default_dialect == "duckdb,normalization_strategy=LOWERCASE" + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_dbt_before_all_with_var_ref_source(init_and_plan_context: t.Callable): + _, plan = init_and_plan_context( + "tests/fixtures/dbt/sushi_test", config="test_config_with_normalization_strategy" + ) + environment_statements = plan.to_evaluatable().environment_statements + assert environment_statements + rendered_statements = [e.render_before_all(dialect="duckdb") for e in environment_statements] + assert rendered_statements[0] == [ + "CREATE TABLE IF NOT EXISTS analytic_stats (physical_table TEXT, evaluation_time TEXT)", + "CREATE TABLE IF NOT EXISTS to_be_executed_last (col TEXT)", + "SELECT 1 AS var, 'items' AS src, 'waiters' AS ref", + ] diff --git a/tests/core/integration/test_dev_only_vde.py b/tests/core/integration/test_dev_only_vde.py new file mode 100644 index 0000000000..611e207771 --- /dev/null +++ b/tests/core/integration/test_dev_only_vde.py @@ -0,0 +1,477 @@ +from __future__ import annotations + +import typing as t +import pytest +from sqlmesh.core.model.common import ParsableSql +import time_machine + +from sqlmesh.core import dialect as d +from sqlmesh.core.config.common import VirtualEnvironmentMode +from sqlmesh.core.model import ( + FullKind, + IncrementalUnmanagedKind, + SqlModel, + ViewKind, + load_sql_based_model, +) +from sqlmesh.core.plan import SnapshotIntervals +from sqlmesh.core.snapshot import ( + SnapshotChangeCategory, +) +from sqlmesh.utils.date import to_date, to_timestamp +from tests.core.integration.utils import add_projection_to_model + +pytestmark = pytest.mark.slow + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_virtual_environment_mode_dev_only(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context( + "examples/sushi", config="test_config_virtual_environment_mode_dev_only" + ) + + assert all( + 
s.virtual_environment_mode.is_dev_only or not s.is_model or s.is_symbolic + for s in context.snapshots.values() + ) + + # Init prod + context.plan("prod", auto_apply=True, no_prompts=True) + + # Make a change in dev + original_model = context.get_model("sushi.waiter_revenue_by_day") + original_fingerprint = context.get_snapshot(original_model.name).fingerprint + model = original_model.copy( + update={ + "query_": ParsableSql( + sql=original_model.query.order_by("waiter_id").sql(dialect=original_model.dialect) + ) + } + ) + model = add_projection_to_model(t.cast(SqlModel, model)) + context.upsert_model(model) + + plan_dev = context.plan_builder("dev").build() + assert to_timestamp(plan_dev.start) == to_timestamp("2023-01-07") + assert plan_dev.requires_backfill + assert plan_dev.missing_intervals == [ + SnapshotIntervals( + snapshot_id=context.get_snapshot("sushi.top_waiters").snapshot_id, + intervals=[(to_timestamp("2023-01-07"), to_timestamp("2023-01-08"))], + ), + SnapshotIntervals( + snapshot_id=context.get_snapshot("sushi.waiter_revenue_by_day").snapshot_id, + intervals=[(to_timestamp("2023-01-07"), to_timestamp("2023-01-08"))], + ), + ] + assert plan_dev.context_diff.snapshots[context.get_snapshot(model.name).snapshot_id].intervals + assert plan_dev.context_diff.snapshots[ + context.get_snapshot("sushi.top_waiters").snapshot_id + ].intervals + assert plan_dev.context_diff.snapshots[ + context.get_snapshot(model.name).snapshot_id + ].dev_intervals + assert plan_dev.context_diff.snapshots[ + context.get_snapshot("sushi.top_waiters").snapshot_id + ].dev_intervals + context.apply(plan_dev) + + # Make sure the waiter_revenue_by_day model is a table in prod and a view in dev + table_types_df = context.engine_adapter.fetchdf( + "SELECT table_schema, table_type FROM INFORMATION_SCHEMA.TABLES WHERE table_name = 'waiter_revenue_by_day'" + ) + assert table_types_df.to_dict("records") == [ + {"table_schema": "sushi", "table_type": "BASE TABLE"}, + {"table_schema": 
"sushi__dev", "table_type": "VIEW"}, + ] + + # Check that the specified dates were backfilled + min_event_date = context.engine_adapter.fetchone( + "SELECT MIN(event_date) FROM sushi__dev.waiter_revenue_by_day" + )[0] + assert min_event_date == to_date("2023-01-07") + + # Make sure the changes are applied without backfill in prod + plan_prod = context.plan_builder("prod").build() + assert not plan_prod.requires_backfill + assert not plan_prod.missing_intervals + context.apply(plan_prod) + assert "one" in context.engine_adapter.columns("sushi.waiter_revenue_by_day") + + # Make sure the revert of a breaking changes results in a full rebuild + context.upsert_model(original_model) + assert context.get_snapshot(original_model.name).fingerprint == original_fingerprint + + plan_prod = context.plan_builder( + "prod", allow_destructive_models=["sushi.waiter_revenue_by_day"] + ).build() + assert not plan_prod.requires_backfill + assert not plan_prod.missing_intervals + context.apply(plan_prod) + assert "one" not in context.engine_adapter.columns("sushi.waiter_revenue_by_day") + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_virtual_environment_mode_dev_only_model_kind_change(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context( + "examples/sushi", config="test_config_virtual_environment_mode_dev_only" + ) + context.apply(plan) + + # Change to full kind + model = context.get_model("sushi.top_waiters") + model = model.copy(update={"kind": FullKind()}) + context.upsert_model(model) + prod_plan = context.plan_builder("prod", skip_tests=True).build() + assert prod_plan.missing_intervals + assert prod_plan.requires_backfill + assert not prod_plan.context_diff.snapshots[ + context.get_snapshot(model.name).snapshot_id + ].intervals + context.apply(prod_plan) + data_objects = context.engine_adapter.get_data_objects("sushi", {"top_waiters"}) + assert len(data_objects) == 1 + assert data_objects[0].type == "table" + + # Change back to view + model = 
context.get_model("sushi.top_waiters") + model = model.copy(update={"kind": ViewKind()}) + context.upsert_model(model) + prod_plan = context.plan_builder("prod", skip_tests=True).build() + assert prod_plan.requires_backfill + assert prod_plan.missing_intervals + assert not prod_plan.context_diff.snapshots[ + context.get_snapshot(model.name).snapshot_id + ].intervals + context.apply(prod_plan) + data_objects = context.engine_adapter.get_data_objects("sushi", {"top_waiters"}) + assert len(data_objects) == 1 + assert data_objects[0].type == "view" + + # Change to incremental + model = context.get_model("sushi.top_waiters") + model = model.copy(update={"kind": IncrementalUnmanagedKind()}) + context.upsert_model(model) + prod_plan = context.plan_builder("prod", skip_tests=True).build() + assert prod_plan.requires_backfill + assert prod_plan.missing_intervals + assert not prod_plan.context_diff.snapshots[ + context.get_snapshot(model.name).snapshot_id + ].intervals + context.apply(prod_plan) + data_objects = context.engine_adapter.get_data_objects("sushi", {"top_waiters"}) + assert len(data_objects) == 1 + assert data_objects[0].type == "table" + + # Change back to full + model = context.get_model("sushi.top_waiters") + model = model.copy(update={"kind": FullKind()}) + context.upsert_model(model) + prod_plan = context.plan_builder("prod", skip_tests=True).build() + assert prod_plan.requires_backfill + assert prod_plan.missing_intervals + assert not prod_plan.context_diff.snapshots[ + context.get_snapshot(model.name).snapshot_id + ].intervals + context.apply(prod_plan) + data_objects = context.engine_adapter.get_data_objects("sushi", {"top_waiters"}) + assert len(data_objects) == 1 + assert data_objects[0].type == "table" + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_virtual_environment_mode_dev_only_model_kind_change_incremental( + init_and_plan_context: t.Callable, +): + context, _ = init_and_plan_context( + "examples/sushi", 
config="test_config_virtual_environment_mode_dev_only" + ) + + forward_only_model_name = "memory.sushi.test_forward_only_model" + forward_only_model_expressions = d.parse( + f""" + MODEL ( + name {forward_only_model_name}, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + ), + ); + + SELECT '2023-01-01' AS ds, 'value' AS value; + """ + ) + forward_only_model = load_sql_based_model(forward_only_model_expressions) + forward_only_model = forward_only_model.copy( + update={"virtual_environment_mode": VirtualEnvironmentMode.DEV_ONLY} + ) + context.upsert_model(forward_only_model) + + context.plan("prod", auto_apply=True, no_prompts=True) + + # Change to view + model = context.get_model(forward_only_model_name) + original_kind = model.kind + model = model.copy(update={"kind": ViewKind()}) + context.upsert_model(model) + prod_plan = context.plan_builder("prod", skip_tests=True).build() + assert prod_plan.requires_backfill + assert prod_plan.missing_intervals + assert not prod_plan.context_diff.snapshots[ + context.get_snapshot(model.name).snapshot_id + ].intervals + context.apply(prod_plan) + data_objects = context.engine_adapter.get_data_objects("sushi", {"test_forward_only_model"}) + assert len(data_objects) == 1 + assert data_objects[0].type == "view" + + model = model.copy(update={"kind": original_kind}) + context.upsert_model(model) + prod_plan = context.plan_builder("prod", skip_tests=True).build() + assert prod_plan.requires_backfill + assert prod_plan.missing_intervals + assert not prod_plan.context_diff.snapshots[ + context.get_snapshot(model.name).snapshot_id + ].intervals + context.apply(prod_plan) + data_objects = context.engine_adapter.get_data_objects("sushi", {"test_forward_only_model"}) + assert len(data_objects) == 1 + assert data_objects[0].type == "table" + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_virtual_environment_mode_dev_only_model_kind_change_with_follow_up_changes_in_dev( + init_and_plan_context: 
t.Callable, +): + context, plan = init_and_plan_context( + "examples/sushi", config="test_config_virtual_environment_mode_dev_only" + ) + context.apply(plan) + + # Make sure the initial state is a view + data_objects = context.engine_adapter.get_data_objects("sushi", {"top_waiters"}) + assert len(data_objects) == 1 + assert data_objects[0].type == "view" + + # Change to incremental unmanaged kind + model = context.get_model("sushi.top_waiters") + model = model.copy(update={"kind": IncrementalUnmanagedKind()}) + context.upsert_model(model) + dev_plan = context.plan_builder("dev", skip_tests=True).build() + assert dev_plan.missing_intervals + assert dev_plan.requires_backfill + context.apply(dev_plan) + + # Make a follow-up forward-only change + model = add_projection_to_model(t.cast(SqlModel, model)) + context.upsert_model(model) + dev_plan = context.plan_builder("dev", skip_tests=True, forward_only=True).build() + context.apply(dev_plan) + + # Deploy to prod + prod_plan = context.plan_builder("prod", skip_tests=True).build() + assert prod_plan.requires_backfill + assert prod_plan.missing_intervals + assert not prod_plan.context_diff.snapshots[ + context.get_snapshot(model.name).snapshot_id + ].intervals + context.apply(prod_plan) + data_objects = context.engine_adapter.get_data_objects("sushi", {"top_waiters"}) + assert len(data_objects) == 1 + assert data_objects[0].type == "table" + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_virtual_environment_mode_dev_only_model_kind_change_manual_categorization( + init_and_plan_context: t.Callable, +): + context, plan = init_and_plan_context( + "examples/sushi", config="test_config_virtual_environment_mode_dev_only" + ) + context.apply(plan) + + model = context.get_model("sushi.top_waiters") + model = model.copy(update={"kind": FullKind()}) + context.upsert_model(model) + dev_plan_builder = context.plan_builder("dev", skip_tests=True, no_auto_categorization=True) + dev_plan_builder.set_choice( + 
dev_plan_builder._context_diff.snapshots[context.get_snapshot(model.name).snapshot_id], + SnapshotChangeCategory.NON_BREAKING, + ) + dev_plan = dev_plan_builder.build() + assert dev_plan.requires_backfill + assert len(dev_plan.missing_intervals) == 1 + context.apply(dev_plan) + + prod_plan = context.plan_builder("prod", skip_tests=True).build() + assert prod_plan.requires_backfill + assert prod_plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=context.get_snapshot("sushi.top_waiters").snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_virtual_environment_mode_dev_only_seed_model_change( + init_and_plan_context: t.Callable, +): + context, _ = init_and_plan_context( + "examples/sushi", config="test_config_virtual_environment_mode_dev_only" + ) + context.load() + context.plan("prod", auto_apply=True, no_prompts=True) + + seed_model = context.get_model("sushi.waiter_names") + with open(seed_model.seed_path, "a") as fd: + fd.write("\n123,New Test Name") + + context.load() + seed_model_snapshot = context.get_snapshot("sushi.waiter_names") + plan = context.plan_builder("dev").build() + assert plan.directly_modified == {seed_model_snapshot.snapshot_id} + assert len(plan.missing_intervals) == 2 + context.apply(plan) + + actual_seed_df_in_dev = context.fetchdf("SELECT * FROM sushi__dev.waiter_names WHERE id = 123") + assert actual_seed_df_in_dev.to_dict("records") == [{"id": 123, "name": "New Test Name"}] + actual_seed_df_in_prod = context.fetchdf("SELECT * FROM sushi.waiter_names WHERE id = 123") 
+ assert actual_seed_df_in_prod.empty + + plan = context.plan_builder("prod").build() + assert plan.directly_modified == {seed_model_snapshot.snapshot_id} + assert len(plan.missing_intervals) == 1 + assert plan.missing_intervals[0].snapshot_id == seed_model_snapshot.snapshot_id + context.apply(plan) + + actual_seed_df_in_prod = context.fetchdf("SELECT * FROM sushi.waiter_names WHERE id = 123") + assert actual_seed_df_in_prod.to_dict("records") == [{"id": 123, "name": "New Test Name"}] + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_virtual_environment_mode_dev_only_model_change_downstream_of_seed( + init_and_plan_context: t.Callable, +): + """This test covers a scenario when a model downstream of a seed model is modified and explicitly selected + causing an (unhydrated) seed model to sourced from the state. If SQLMesh attempts to create + a table for the unchanged seed model, it will fail because the seed model is not hydrated. + """ + context, _ = init_and_plan_context( + "examples/sushi", config="test_config_virtual_environment_mode_dev_only" + ) + context.load() + context.plan("prod", auto_apply=True, no_prompts=True) + + # Make sure that a different version of the seed model is loaded + seed_model = context.get_model("sushi.waiter_names") + seed_model = seed_model.copy(update={"stamp": "force new version"}) + context.upsert_model(seed_model) + + # Make a change to the downstream model + model = context.get_model("sushi.waiter_as_customer_by_day") + model = model.copy(update={"stamp": "force new version"}) + context.upsert_model(model) + + # It is important to clear the cache so that the hydrated seed model is not sourced from the cache + context.clear_caches() + + # Make sure to use the selector so that the seed model is sourced from the state + plan = context.plan_builder("dev", select_models=[model.name]).build() + assert len(plan.directly_modified) == 1 + assert list(plan.directly_modified)[0].name == model.fqn + assert 
len(plan.missing_intervals) == 1 + assert plan.missing_intervals[0].snapshot_id.name == model.fqn + + # Make sure there's no error when applying the plan + context.apply(plan) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_virtual_environment_mode_dev_only_model_change_standalone_audit( + init_and_plan_context: t.Callable, +): + context, plan = init_and_plan_context( + "examples/sushi", config="test_config_virtual_environment_mode_dev_only" + ) + context.apply(plan) + + # Change a model upstream from a standalone audit + model = context.get_model("sushi.items") + model = model.copy(update={"stamp": "force new version"}) + context.upsert_model(model) + + plan = context.plan_builder("prod", skip_tests=True).build() + + # Make sure the standalone audit is among modified + assert ( + context.get_snapshot("assert_item_price_above_zero").snapshot_id + in plan.indirectly_modified[context.get_snapshot("sushi.items").snapshot_id] + ) + + # Make sure there's no error when applying the plan + context.apply(plan) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_virtual_environment_mode_dev_only_seed_model_change_schema( + init_and_plan_context: t.Callable, +): + context, plan = init_and_plan_context( + "examples/sushi", config="test_config_virtual_environment_mode_dev_only" + ) + context.apply(plan) + + new_csv = [] + with open(context.path / "seeds" / "waiter_names.csv", "r") as fd: + is_header = True + for idx, line in enumerate(fd): + line = line.strip() + if not line: + continue + if is_header: + new_csv.append(line + ",new_column") + is_header = False + else: + new_csv.append(line + f",v{idx}") + + with open(context.path / "seeds" / "waiter_names.csv", "w") as fd: + fd.write("\n".join(new_csv)) + + context.load() + + downstream_model = context.get_model("sushi.waiter_as_customer_by_day") + downstream_model_kind = downstream_model.kind.dict() + downstream_model_kwargs = { + **downstream_model.dict(), + "kind": { + **downstream_model_kind, + 
"on_destructive_change": "allow", + }, + "audits": [], + # Use the new column + "query": "SELECT '2023-01-07' AS event_date, new_column AS new_column FROM sushi.waiter_names", + } + context.upsert_model(SqlModel.parse_obj(downstream_model_kwargs)) + + context.plan("dev", auto_apply=True, no_prompts=True, skip_tests=True, enable_preview=True) + + assert ( + context.engine_adapter.fetchone( + "SELECT COUNT(*) FROM sushi__dev.waiter_as_customer_by_day" + )[0] + == len(new_csv) - 1 + ) + + # Deploy to prod + context.clear_caches() + context.plan("prod", auto_apply=True, no_prompts=True, skip_tests=True) + assert "new_column" in context.engine_adapter.columns("sushi.waiter_as_customer_by_day") diff --git a/tests/core/integration/test_forward_only.py b/tests/core/integration/test_forward_only.py new file mode 100644 index 0000000000..2dddf18efd --- /dev/null +++ b/tests/core/integration/test_forward_only.py @@ -0,0 +1,1497 @@ +from __future__ import annotations + +import typing as t +import numpy as np # noqa: TID253 +import pandas as pd # noqa: TID253 +import pytest +import time_machine + +from sqlmesh.core import dialect as d +from sqlmesh.core.context import Context +from sqlmesh.core.config.categorizer import CategorizerConfig +from sqlmesh.core.model import ( + FullKind, + SqlModel, + load_sql_based_model, +) +from sqlmesh.core.plan import SnapshotIntervals +from sqlmesh.core.snapshot import ( + SnapshotChangeCategory, +) +from sqlmesh.utils.date import to_datetime, to_timestamp +from tests.core.integration.utils import add_projection_to_model + +pytestmark = pytest.mark.slow + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +@pytest.mark.parametrize( + "context_fixture", + ["sushi_context", "sushi_no_default_catalog"], +) +def test_forward_only_plan_with_effective_date(context_fixture: Context, request): + context = request.getfixturevalue(context_fixture) + model_name = "sushi.waiter_revenue_by_day" + model = context.get_model(model_name) + 
context.upsert_model(add_projection_to_model(t.cast(SqlModel, model)), start="2023-01-01") + snapshot = context.get_snapshot(model, raise_if_missing=True) + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + + plan_builder = context.plan_builder("dev", skip_tests=True, forward_only=True) + plan = plan_builder.build() + assert len(plan.new_snapshots) == 2 + assert ( + plan.context_diff.snapshots[snapshot.snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert ( + plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert plan.context_diff.snapshots[snapshot.snapshot_id].is_forward_only + assert plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].is_forward_only + + assert to_timestamp(plan.start) == to_timestamp("2023-01-07") + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=snapshot.snapshot_id, + intervals=[(to_timestamp("2023-01-07"), to_timestamp("2023-01-08"))], + ), + ] + + plan = plan_builder.set_effective_from("2023-01-05").build() + # Default start should be set to effective_from + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiters_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + SnapshotIntervals( + snapshot_id=snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + plan = plan_builder.set_start("2023-01-06").build() + # Start override should take precedence + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiters_snapshot.snapshot_id, + intervals=[ + 
(to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + SnapshotIntervals( + snapshot_id=snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + plan = plan_builder.set_effective_from("2023-01-04").build() + # Start should remain unchanged + assert plan.start == "2023-01-06" + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiters_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + SnapshotIntervals( + snapshot_id=snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + context.apply(plan) + + dev_df = context.engine_adapter.fetchdf( + "SELECT DISTINCT event_date FROM sushi__dev.waiter_revenue_by_day ORDER BY event_date" + ) + assert dev_df["event_date"].tolist() == [ + pd.to_datetime("2023-01-06"), + pd.to_datetime("2023-01-07"), + ] + + prod_plan = context.plan_builder(skip_tests=True).build() + # Make sure that the previously set effective_from is respected + assert prod_plan.start == to_timestamp("2023-01-04") + assert prod_plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiters_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + SnapshotIntervals( + snapshot_id=snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), 
to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + context.apply(prod_plan) + + prod_df = context.engine_adapter.fetchdf( + "SELECT DISTINCT event_date FROM sushi.waiter_revenue_by_day WHERE one IS NOT NULL ORDER BY event_date" + ) + assert prod_df["event_date"].tolist() == [ + pd.to_datetime(x) for x in ["2023-01-04", "2023-01-05", "2023-01-06", "2023-01-07"] + ] + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_forward_only_model_regular_plan(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + model_name = "sushi.waiter_revenue_by_day" + + model = context.get_model(model_name) + model = add_projection_to_model(t.cast(SqlModel, model)) + forward_only_kind = model.kind.copy(update={"forward_only": True}) + model = model.copy(update={"kind": forward_only_kind}) + + context.upsert_model(model) + snapshot = context.get_snapshot(model, raise_if_missing=True) + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + + plan = context.plan_builder("dev", skip_tests=True, enable_preview=False).build() + assert len(plan.new_snapshots) == 2 + assert ( + plan.context_diff.snapshots[snapshot.snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert ( + plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert plan.context_diff.snapshots[snapshot.snapshot_id].is_forward_only + assert plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].is_forward_only + + assert plan.start == to_datetime("2023-01-01") + assert not plan.missing_intervals + + context.apply(plan) + + dev_df = context.engine_adapter.fetchdf( + "SELECT DISTINCT event_date FROM sushi__dev.waiter_revenue_by_day ORDER BY event_date" + ) + assert not dev_df["event_date"].tolist() + + # Run a restatement plan to preview changes + 
plan_builder = context.plan_builder( + "dev", skip_tests=True, restate_models=[model_name], enable_preview=False + ) + plan_builder.set_start("2023-01-06") + assert plan_builder.build().missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiters_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + SnapshotIntervals( + snapshot_id=snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + # Make sure that changed start is reflected in missing intervals + plan_builder.set_start("2023-01-07") + assert plan_builder.build().missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiters_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + SnapshotIntervals( + snapshot_id=snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + context.apply(plan_builder.build()) + + dev_df = context.engine_adapter.fetchdf( + "SELECT DISTINCT event_date FROM sushi__dev.waiter_revenue_by_day ORDER BY event_date" + ) + assert dev_df["event_date"].tolist() == [pd.to_datetime("2023-01-07")] + 
+ # Promote changes to prod + prod_plan = context.plan_builder(skip_tests=True).build() + assert not prod_plan.missing_intervals + + context.apply(prod_plan) + + # The change was applied in a forward-only manner so no values in the new column should be populated + prod_df = context.engine_adapter.fetchdf( + "SELECT DISTINCT event_date FROM sushi.waiter_revenue_by_day WHERE one IS NOT NULL ORDER BY event_date" + ) + assert not prod_df["event_date"].tolist() + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_forward_only_model_regular_plan_preview_enabled(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + model_name = "sushi.waiter_revenue_by_day" + + model = context.get_model(model_name) + model = add_projection_to_model(t.cast(SqlModel, model)) + forward_only_kind = model.kind.copy(update={"forward_only": True}) + model = model.copy(update={"kind": forward_only_kind}) + + context.upsert_model(model) + snapshot = context.get_snapshot(model, raise_if_missing=True) + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + + plan = context.plan_builder("dev", skip_tests=True, enable_preview=True).build() + assert len(plan.new_snapshots) == 2 + assert ( + plan.context_diff.snapshots[snapshot.snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert ( + plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert plan.context_diff.snapshots[snapshot.snapshot_id].is_forward_only + assert plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].is_forward_only + + assert to_timestamp(plan.start) == to_timestamp("2023-01-07") + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + context.apply(plan) + + dev_df = 
context.engine_adapter.fetchdf( + "SELECT DISTINCT event_date FROM sushi__dev.waiter_revenue_by_day ORDER BY event_date" + ) + assert dev_df["event_date"].tolist() == [pd.to_datetime("2023-01-07")] + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_forward_only_model_restate_full_history_in_dev(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + model_name = "memory.sushi.customer_max_revenue" + expressions = d.parse( + f""" + MODEL ( + name {model_name}, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key customer_id, + forward_only true, + ), + ); + + SELECT + customer_id, MAX(revenue) AS max_revenue + FROM memory.sushi.customer_revenue_lifetime + GROUP BY 1; + """ + ) + + model = load_sql_based_model(expressions) + assert model.forward_only + assert model.kind.full_history_restatement_only + context.upsert_model(model) + + context.plan("prod", skip_tests=True, auto_apply=True, enable_preview=False) + + model_kwargs = { + **model.dict(), + # Make a breaking change. 
+ "query": model.query.order_by("customer_id"), # type: ignore + } + context.upsert_model(SqlModel.parse_obj(model_kwargs)) + + # Apply the model change in dev + plan = context.plan_builder( + "dev", + skip_tests=True, + enable_preview=False, + categorizer_config=CategorizerConfig.all_full(), + ).build() + assert not plan.missing_intervals + context.apply(plan) + + snapshot = context.get_snapshot(model, raise_if_missing=True) + snapshot_table_name = snapshot.table_name(False) + + # Manually insert a dummy value to check that the table is recreated during the restatement + context.engine_adapter.insert_append( + snapshot_table_name, + pd.DataFrame({"customer_id": [-1], "max_revenue": [100]}), + ) + df = context.engine_adapter.fetchdf( + "SELECT COUNT(*) AS cnt FROM sushi__dev.customer_max_revenue WHERE customer_id = -1" + ) + assert df["cnt"][0] == 1 + + # Apply a restatement plan in dev + plan = context.plan("dev", restate_models=[model.name], auto_apply=True, enable_preview=False) + assert len(plan.missing_intervals) == 1 + + # Check that the dummy value is not present + df = context.engine_adapter.fetchdf( + "SELECT COUNT(*) AS cnt FROM sushi__dev.customer_max_revenue WHERE customer_id = -1" + ) + assert df["cnt"][0] == 0 + + # Check that the table is not empty + df = context.engine_adapter.fetchdf( + "SELECT COUNT(*) AS cnt FROM sushi__dev.customer_max_revenue" + ) + assert df["cnt"][0] > 0 + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_full_history_restatement_model_regular_plan_preview_enabled( + init_and_plan_context: t.Callable, +): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + model_name = "sushi.marketing" # SCD2 model + + model = context.get_model(model_name) + model = add_projection_to_model(t.cast(SqlModel, model)) + + context.upsert_model(model) + snapshot = context.get_snapshot(model, raise_if_missing=True) + customers_snapshot = context.get_snapshot("sushi.customers", raise_if_missing=True) + 
active_customers_snapshot = context.get_snapshot( + "sushi.active_customers", raise_if_missing=True + ) + waiter_as_customer_snapshot = context.get_snapshot( + "sushi.waiter_as_customer_by_day", raise_if_missing=True + ) + + plan = context.plan_builder("dev", skip_tests=True, enable_preview=True).build() + + assert len(plan.new_snapshots) == 6 + assert ( + plan.context_diff.snapshots[snapshot.snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert ( + plan.context_diff.snapshots[customers_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert ( + plan.context_diff.snapshots[active_customers_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert ( + plan.context_diff.snapshots[waiter_as_customer_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert all(s.is_forward_only for s in plan.new_snapshots) + + assert to_timestamp(plan.start) == to_timestamp("2023-01-07") + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + context.apply(plan) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_metadata_changed_regular_plan_preview_enabled(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + model_name = "sushi.waiter_revenue_by_day" + + model = context.get_model(model_name) + model = model.copy(update={"owner": "new_owner"}) + + context.upsert_model(model) + snapshot = context.get_snapshot(model, raise_if_missing=True) + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + + plan = context.plan_builder("dev", skip_tests=True, enable_preview=True).build() + assert len(plan.new_snapshots) == 2 + assert ( + 
plan.context_diff.snapshots[snapshot.snapshot_id].change_category + == SnapshotChangeCategory.METADATA + ) + assert ( + plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.METADATA + ) + assert not plan.missing_intervals + assert not plan.restatements + + +@time_machine.travel("2023-01-08 00:00:00 UTC") +def test_forward_only_preview_child_that_runs_before_parent(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + # This model runs at minute 30 of every hour + upstream_model = load_sql_based_model( + d.parse( + """ + MODEL ( + name memory.sushi.upstream_model, + kind FULL, + cron '30 * * * *', + start '2023-01-01', + ); + + SELECT 1 AS a; + """ + ) + ) + context.upsert_model(upstream_model) + + # This model runs at minute 0 of every hour, so it runs before the upstream model + downstream_model = load_sql_based_model( + d.parse( + """ + MODEL ( + name memory.sushi.downstream_model, + kind INCREMENTAL_BY_TIME_RANGE( + time_column event_date, + forward_only True, + ), + cron '0 * * * *', + start '2023-01-01', + ); + + SELECT a, '2023-01-06' AS event_date FROM memory.sushi.upstream_model; + """ + ) + ) + context.upsert_model(downstream_model) + + context.plan("prod", skip_tests=True, auto_apply=True) + + with time_machine.travel("2023-01-08 00:05:00 UTC"): + # The downstream model runs but not the upstream model + context.run("prod") + + # Now it's time for the upstream model to run but it hasn't run yet + with time_machine.travel("2023-01-08 00:35:00 UTC"): + # Make a change to the downstream model. 
+ downstream_model = add_projection_to_model(t.cast(SqlModel, downstream_model), literal=True) + context.upsert_model(downstream_model) + + # The plan should only backfill the downstream model despite upstream missing intervals + plan = context.plan_builder("dev", skip_tests=True, enable_preview=True).build() + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=context.get_snapshot( + downstream_model.name, raise_if_missing=True + ).snapshot_id, + intervals=[ + (to_timestamp("2023-01-07 23:00:00"), to_timestamp("2023-01-08 00:00:00")) + ], + ), + ] + + +@time_machine.travel("2023-01-08 00:00:00 UTC") +def test_forward_only_monthly_model(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + model = context.get_model("sushi.waiter_revenue_by_day") + model = SqlModel.parse_obj( + { + **model.dict(), + "kind": model.kind.copy(update={"forward_only": True}), + "cron": "0 0 1 * *", + "start": "2022-01-01", + "audits": [], + } + ) + context.upsert_model(model) + + plan = context.plan_builder("prod", skip_tests=True).build() + context.apply(plan) + + waiter_revenue_by_day_snapshot = context.get_snapshot(model.name, raise_if_missing=True) + assert waiter_revenue_by_day_snapshot.intervals == [ + (to_timestamp("2022-01-01"), to_timestamp("2023-01-01")) + ] + + model = add_projection_to_model(t.cast(SqlModel, model), literal=True) + context.upsert_model(model) + + waiter_revenue_by_day_snapshot = context.get_snapshot( + "sushi.waiter_revenue_by_day", raise_if_missing=True + ) + + plan = context.plan_builder( + "dev", select_models=[model.name], skip_tests=True, enable_preview=True + ).build() + assert to_timestamp(plan.start) == to_timestamp("2022-12-01") + assert to_timestamp(plan.end) == to_timestamp("2023-01-08") + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=waiter_revenue_by_day_snapshot.snapshot_id, + intervals=[(to_timestamp("2022-12-01"), to_timestamp("2023-01-01"))], + ), + ] + + 
+@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_forward_only_parent_created_in_dev_child_created_in_prod( + init_and_plan_context: t.Callable, +): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + waiter_revenue_by_day_model = context.get_model("sushi.waiter_revenue_by_day") + waiter_revenue_by_day_model = add_projection_to_model( + t.cast(SqlModel, waiter_revenue_by_day_model) + ) + forward_only_kind = waiter_revenue_by_day_model.kind.copy(update={"forward_only": True}) + waiter_revenue_by_day_model = waiter_revenue_by_day_model.copy( + update={"kind": forward_only_kind} + ) + context.upsert_model(waiter_revenue_by_day_model) + + waiter_revenue_by_day_snapshot = context.get_snapshot( + waiter_revenue_by_day_model, raise_if_missing=True + ) + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + + plan = context.plan_builder("dev", skip_tests=True, enable_preview=False).build() + assert len(plan.new_snapshots) == 2 + assert ( + plan.context_diff.snapshots[waiter_revenue_by_day_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert ( + plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert all(s.is_forward_only for s in plan.new_snapshots) + assert plan.start == to_datetime("2023-01-01") + assert not plan.missing_intervals + + context.apply(plan) + + # Update the child to refer to a newly added column. 
+ top_waiters_model = context.get_model("sushi.top_waiters") + top_waiters_model = add_projection_to_model(t.cast(SqlModel, top_waiters_model), literal=False) + context.upsert_model(top_waiters_model) + + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + + plan = context.plan_builder("prod", skip_tests=True, enable_preview=False).build() + assert len(plan.new_snapshots) == 1 + assert ( + plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + + context.apply(plan) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_forward_only_view_migration( + init_and_plan_context: t.Callable, +): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + model = context.get_model("sushi.top_waiters") + assert model.kind.is_view + model = add_projection_to_model(t.cast(SqlModel, model)) + context.upsert_model(model) + + # Apply a forward-only plan + context.plan("prod", skip_tests=True, no_prompts=True, auto_apply=True, forward_only=True) + + # Make sure that the new column got reflected in the view schema + df = context.fetchdf("SELECT one FROM sushi.top_waiters LIMIT 1") + assert len(df) == 1 + + +@time_machine.travel("2023-01-08 00:00:00 UTC") +def test_new_forward_only_model(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + context.plan("dev", skip_tests=True, no_prompts=True, auto_apply=True, enable_preview=False) + + snapshot = context.get_snapshot("sushi.marketing") + + # The deployable table should not exist yet + assert not context.engine_adapter.table_exists(snapshot.table_name()) + assert context.engine_adapter.table_exists(snapshot.table_name(is_deployable=False)) + + context.plan("prod", skip_tests=True, no_prompts=True, auto_apply=True) + + assert context.engine_adapter.table_exists(snapshot.table_name()) + assert 
context.engine_adapter.table_exists(snapshot.table_name(is_deployable=False)) + + +@time_machine.travel("2023-01-08 15:00:00 UTC", tick=True) +@pytest.mark.parametrize("has_view_binding", [False, True]) +def test_non_breaking_change_after_forward_only_in_dev( + init_and_plan_context: t.Callable, has_view_binding: bool +): + context, plan = init_and_plan_context("examples/sushi") + context.snapshot_evaluator.adapter.HAS_VIEW_BINDING = has_view_binding + context.apply(plan) + + model = context.get_model("sushi.waiter_revenue_by_day") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + waiter_revenue_by_day_snapshot = context.get_snapshot( + "sushi.waiter_revenue_by_day", raise_if_missing=True + ) + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + + plan = context.plan_builder("dev", skip_tests=True, forward_only=True).build() + assert len(plan.new_snapshots) == 2 + assert ( + plan.context_diff.snapshots[waiter_revenue_by_day_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert ( + plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert all(s.is_forward_only for s in plan.new_snapshots) + assert to_timestamp(plan.start) == to_timestamp("2023-01-07") + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=waiter_revenue_by_day_snapshot.snapshot_id, + intervals=[(to_timestamp("2023-01-07"), to_timestamp("2023-01-08"))], + ), + ] + + # Apply the forward-only changes first. + context.apply(plan) + + dev_df = context.engine_adapter.fetchdf( + "SELECT DISTINCT event_date FROM sushi__dev.waiter_revenue_by_day ORDER BY event_date" + ) + assert dev_df["event_date"].tolist() == [pd.to_datetime("2023-01-07")] + + # Make a non-breaking change to a model downstream. + model = context.get_model("sushi.top_waiters") + # Select 'one' column from the updated upstream model. 
+ context.upsert_model(add_projection_to_model(t.cast(SqlModel, model), literal=False)) + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + + plan = context.plan_builder("dev", skip_tests=True).build() + assert len(plan.new_snapshots) == 1 + assert ( + plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert to_timestamp(plan.start) == to_timestamp("2023-01-01") + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiters_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + # Apply the non-breaking changes. + context.apply(plan) + + dev_df = context.engine_adapter.fetchdf( + "SELECT DISTINCT waiter_id FROM sushi__dev.top_waiters WHERE one IS NOT NULL" + ) + assert not dev_df.empty + + prod_df = context.engine_adapter.fetchdf("DESCRIBE sushi.top_waiters") + assert "one" not in prod_df["column_name"].tolist() + + # Deploy both changes to prod. 
+ plan = context.plan_builder("prod", skip_tests=True).build() + assert plan.start == to_timestamp("2023-01-01") + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiters_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + context.apply(plan) + + prod_df = context.engine_adapter.fetchdf( + "SELECT DISTINCT event_date FROM sushi.waiter_revenue_by_day WHERE one IS NOT NULL ORDER BY event_date" + ) + assert prod_df.empty + + prod_df = context.engine_adapter.fetchdf( + "SELECT DISTINCT waiter_id FROM sushi.top_waiters WHERE one IS NOT NULL" + ) + assert prod_df.empty + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_indirect_non_breaking_change_after_forward_only_in_dev(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + # Make sure that the most downstream model is a materialized model. + model = context.get_model("sushi.top_waiters") + model = model.copy(update={"kind": FullKind()}) + context.upsert_model(model) + context.plan("prod", skip_tests=True, auto_apply=True, no_prompts=True) + + # Make sushi.orders a forward-only model. 
+ model = context.get_model("sushi.orders") + updated_model_kind = model.kind.copy(update={"forward_only": True}) + model = model.copy(update={"stamp": "force new version", "kind": updated_model_kind}) + context.upsert_model(model) + snapshot = context.get_snapshot(model, raise_if_missing=True) + + plan = context.plan_builder( + "dev", + skip_tests=True, + enable_preview=False, + categorizer_config=CategorizerConfig.all_full(), + ).build() + assert ( + plan.context_diff.snapshots[snapshot.snapshot_id].change_category + == SnapshotChangeCategory.BREAKING + ) + assert plan.context_diff.snapshots[snapshot.snapshot_id].is_forward_only + assert not plan.requires_backfill + context.apply(plan) + + # Make a non-breaking change to a model. + model = context.get_model("sushi.top_waiters") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + + plan = context.plan_builder("dev", skip_tests=True, enable_preview=False).build() + assert len(plan.new_snapshots) == 1 + assert ( + plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert plan.start == to_timestamp("2023-01-01") + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiters_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + # Apply the non-breaking changes. + context.apply(plan) + + # Make a non-breaking change upstream from the previously modified model. 
+ model = context.get_model("sushi.waiter_revenue_by_day") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + waiter_revenue_by_day_snapshot = context.get_snapshot( + "sushi.waiter_revenue_by_day", raise_if_missing=True + ) + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + + plan = context.plan_builder("dev", skip_tests=True, enable_preview=False).build() + assert len(plan.new_snapshots) == 2 + assert ( + plan.context_diff.snapshots[waiter_revenue_by_day_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert ( + plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert plan.start == to_timestamp("2023-01-01") + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=waiter_revenue_by_day_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + # Apply the upstream non-breaking changes. + context.apply(plan) + assert not context.plan_builder("dev", skip_tests=True).build().requires_backfill + + # Deploy everything to prod. 
+ plan = context.plan_builder("prod", skip_tests=True, enable_preview=False).build() + assert plan.start == to_timestamp("2023-01-01") + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiters_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + SnapshotIntervals( + snapshot_id=waiter_revenue_by_day_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + context.apply(plan) + assert ( + not context.plan_builder("prod", skip_tests=True, enable_preview=False) + .build() + .requires_backfill + ) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_changes_downstream_of_indirect_non_breaking_snapshot_without_intervals( + init_and_plan_context: t.Callable, +): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + # Make a breaking change first but don't backfill it + model = context.get_model("sushi.orders") + model = model.copy(update={"stamp": "force new version"}) + context.upsert_model(model) + plan_builder = context.plan_builder( + "dev", skip_backfill=True, skip_tests=True, no_auto_categorization=True + ) + plan_builder.set_choice(context.get_snapshot(model), 
SnapshotChangeCategory.BREAKING) + context.apply(plan_builder.build()) + + # Now make a non-breaking change to the same snapshot. + model = model.copy(update={"stamp": "force another new version"}) + context.upsert_model(model) + plan_builder = context.plan_builder( + "dev", skip_backfill=True, skip_tests=True, no_auto_categorization=True + ) + plan_builder.set_choice(context.get_snapshot(model), SnapshotChangeCategory.NON_BREAKING) + context.apply(plan_builder.build()) + + # Now make a change to a model downstream of the above model. + downstream_model = context.get_model("sushi.top_waiters") + downstream_model = downstream_model.copy(update={"stamp": "yet another new version"}) + context.upsert_model(downstream_model) + plan = context.plan_builder("dev", skip_tests=True).build() + + # If the parent is not representative then the child cannot be deployable + deployability_index = plan.deployability_index + assert not deployability_index.is_representative( + context.get_snapshot("sushi.waiter_revenue_by_day") + ) + assert not deployability_index.is_deployable(context.get_snapshot("sushi.top_waiters")) + + +@time_machine.travel("2023-01-08 15:00:00 UTC", tick=True) +def test_metadata_change_after_forward_only_results_in_migration(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + # Make a forward-only change + model = context.get_model("sushi.waiter_revenue_by_day") + model = model.copy(update={"kind": model.kind.copy(update={"forward_only": True})}) + model = add_projection_to_model(t.cast(SqlModel, model)) + context.upsert_model(model) + plan = context.plan("dev", skip_tests=True, auto_apply=True, no_prompts=True) + assert len(plan.new_snapshots) == 2 + assert all(s.is_forward_only for s in plan.new_snapshots) + + # Follow-up with a metadata change in the same environment + model = model.copy(update={"owner": "new_owner"}) + context.upsert_model(model) + plan = context.plan("dev", 
skip_tests=True, auto_apply=True, no_prompts=True) + assert len(plan.new_snapshots) == 2 + assert all(s.change_category == SnapshotChangeCategory.METADATA for s in plan.new_snapshots) + + # Deploy the latest change to prod + context.plan("prod", skip_tests=True, auto_apply=True, no_prompts=True) + + # Check that the new column was added in prod + columns = context.engine_adapter.columns("sushi.waiter_revenue_by_day") + assert "one" in columns + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_indirect_non_breaking_downstream_of_forward_only(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + # Make sushi.orders a forward-only model. + forward_only_model = context.get_model("sushi.orders") + updated_model_kind = forward_only_model.kind.copy(update={"forward_only": True}) + forward_only_model = forward_only_model.copy( + update={"stamp": "force new version", "kind": updated_model_kind} + ) + context.upsert_model(forward_only_model) + forward_only_snapshot = context.get_snapshot(forward_only_model, raise_if_missing=True) + + non_breaking_model = context.get_model("sushi.waiter_revenue_by_day") + non_breaking_model = non_breaking_model.copy(update={"start": "2023-01-01"}) + context.upsert_model(add_projection_to_model(t.cast(SqlModel, non_breaking_model))) + non_breaking_snapshot = context.get_snapshot(non_breaking_model, raise_if_missing=True) + top_waiter_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) + + plan = context.plan_builder( + "dev", + skip_tests=True, + enable_preview=False, + categorizer_config=CategorizerConfig.all_full(), + ).build() + assert ( + plan.context_diff.snapshots[forward_only_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.BREAKING + ) + assert ( + plan.context_diff.snapshots[non_breaking_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert ( + 
plan.context_diff.snapshots[top_waiter_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert plan.context_diff.snapshots[forward_only_snapshot.snapshot_id].is_forward_only + assert not plan.context_diff.snapshots[non_breaking_snapshot.snapshot_id].is_forward_only + assert not plan.context_diff.snapshots[top_waiter_snapshot.snapshot_id].is_forward_only + + assert plan.start == to_timestamp("2023-01-01") + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiter_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + SnapshotIntervals( + snapshot_id=non_breaking_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + context.apply(plan) + assert ( + not context.plan_builder("dev", skip_tests=True, enable_preview=False) + .build() + .requires_backfill + ) + + # Deploy everything to prod. 
+ plan = context.plan_builder("prod", skip_tests=True).build() + assert plan.start == to_timestamp("2023-01-01") + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=top_waiter_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + SnapshotIntervals( + snapshot_id=non_breaking_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + ] + + context.apply(plan) + assert ( + not context.plan_builder("prod", skip_tests=True, enable_preview=False) + .build() + .requires_backfill + ) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_indirect_non_breaking_view_model_non_representative_snapshot( + init_and_plan_context: t.Callable, +): + context, _ = init_and_plan_context("examples/sushi") + + # Forward-only parent + forward_only_model_name = "memory.sushi.test_forward_only_model" + forward_only_model_expressions = d.parse( + f""" + MODEL ( + name {forward_only_model_name}, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + ), + ); + + SELECT '2023-01-01' AS ds, 'value' AS value; + """ + ) + forward_only_model = load_sql_based_model(forward_only_model_expressions) + assert forward_only_model.forward_only + 
context.upsert_model(forward_only_model) + + # FULL downstream model. + full_downstream_model_name = "memory.sushi.test_full_downstream_model" + full_downstream_model_expressions = d.parse( + f""" + MODEL ( + name {full_downstream_model_name}, + kind FULL, + ); + + SELECT ds, value FROM {forward_only_model_name}; + """ + ) + full_downstream_model = load_sql_based_model(full_downstream_model_expressions) + context.upsert_model(full_downstream_model) + + # VIEW downstream of the previous FULL model. + view_downstream_model_name = "memory.sushi.test_view_downstream_model" + view_downstream_model_expressions = d.parse( + f""" + MODEL ( + name {view_downstream_model_name}, + kind VIEW, + ); + + SELECT ds, value FROM {full_downstream_model_name}; + """ + ) + view_downstream_model = load_sql_based_model(view_downstream_model_expressions) + context.upsert_model(view_downstream_model) + + # Apply the initial plan with all 3 models. + context.plan(auto_apply=True, no_prompts=True) + + # Make a change to the forward-only model and apply it in dev. 
+ context.upsert_model(add_projection_to_model(t.cast(SqlModel, forward_only_model))) + forward_only_model_snapshot_id = context.get_snapshot(forward_only_model_name).snapshot_id + full_downstream_model_snapshot_id = context.get_snapshot(full_downstream_model_name).snapshot_id + view_downstream_model_snapshot_id = context.get_snapshot(view_downstream_model_name).snapshot_id + dev_plan = context.plan("dev", auto_apply=True, no_prompts=True, enable_preview=False) + assert ( + dev_plan.snapshots[forward_only_model_snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert ( + dev_plan.snapshots[full_downstream_model_snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert ( + dev_plan.snapshots[view_downstream_model_snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + assert not dev_plan.missing_intervals + + # Make a follow-up breaking change to the downstream full model. + new_full_downstream_model_expressions = d.parse( + f""" + MODEL ( + name {full_downstream_model_name}, + kind FULL, + ); + + SELECT ds, 'new_value' AS value FROM {forward_only_model_name}; + """ + ) + new_full_downstream_model = load_sql_based_model(new_full_downstream_model_expressions) + context.upsert_model(new_full_downstream_model) + full_downstream_model_snapshot_id = context.get_snapshot(full_downstream_model_name).snapshot_id + view_downstream_model_snapshot_id = context.get_snapshot(view_downstream_model_name).snapshot_id + dev_plan = context.plan( + "dev", + categorizer_config=CategorizerConfig.all_full(), + auto_apply=True, + no_prompts=True, + enable_preview=False, + ) + assert ( + dev_plan.snapshots[full_downstream_model_snapshot_id].change_category + == SnapshotChangeCategory.BREAKING + ) + assert ( + dev_plan.snapshots[view_downstream_model_snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_BREAKING + ) + assert len(dev_plan.missing_intervals) == 2 + assert 
dev_plan.missing_intervals[0].snapshot_id == full_downstream_model_snapshot_id + assert dev_plan.missing_intervals[1].snapshot_id == view_downstream_model_snapshot_id + + # Check that the representative view hasn't been created yet. + assert not context.engine_adapter.table_exists( + context.get_snapshot(view_downstream_model_name).table_name() + ) + + # Now promote the very first change to prod without promoting the 2nd breaking change. + context.upsert_model(full_downstream_model) + context.plan(auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full()) + + # Finally, make a non-breaking change to the full model in the same dev environment. + context.upsert_model(add_projection_to_model(t.cast(SqlModel, new_full_downstream_model))) + full_downstream_model_snapshot_id = context.get_snapshot(full_downstream_model_name).snapshot_id + view_downstream_model_snapshot_id = context.get_snapshot(view_downstream_model_name).snapshot_id + dev_plan = context.plan( + "dev", + categorizer_config=CategorizerConfig.all_full(), + auto_apply=True, + no_prompts=True, + enable_preview=False, + ) + assert ( + dev_plan.snapshots[full_downstream_model_snapshot_id].change_category + == SnapshotChangeCategory.NON_BREAKING + ) + assert ( + dev_plan.snapshots[view_downstream_model_snapshot_id].change_category + == SnapshotChangeCategory.INDIRECT_NON_BREAKING + ) + + # Deploy changes to prod + context.plan("prod", auto_apply=True, no_prompts=True) + + # Check that the representative view has been created. 
+ assert context.engine_adapter.table_exists( + context.get_snapshot(view_downstream_model_name).table_name() + ) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_indirect_non_breaking_view_model_non_representative_snapshot_migration( + init_and_plan_context: t.Callable, +): + context, _ = init_and_plan_context("examples/sushi") + + forward_only_model_expr = d.parse( + """ + MODEL ( + name memory.sushi.forward_only_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only TRUE, + on_destructive_change 'allow', + ), + ); + + SELECT '2023-01-07' AS ds, 1 AS a; + """ + ) + forward_only_model = load_sql_based_model(forward_only_model_expr) + context.upsert_model(forward_only_model) + + downstream_view_a_expr = d.parse( + """ + MODEL ( + name memory.sushi.downstream_view_a, + kind VIEW, + ); + + SELECT a from memory.sushi.forward_only_model; + """ + ) + downstream_view_a = load_sql_based_model(downstream_view_a_expr) + context.upsert_model(downstream_view_a) + + downstream_view_b_expr = d.parse( + """ + MODEL ( + name memory.sushi.downstream_view_b, + kind VIEW, + ); + + SELECT a from memory.sushi.downstream_view_a; + """ + ) + downstream_view_b = load_sql_based_model(downstream_view_b_expr) + context.upsert_model(downstream_view_b) + + context.plan(auto_apply=True, no_prompts=True, skip_tests=True) + + # Make a forward-only change + context.upsert_model(add_projection_to_model(t.cast(SqlModel, forward_only_model))) + # Make a non-breaking change downstream + context.upsert_model(add_projection_to_model(t.cast(SqlModel, downstream_view_a))) + + context.plan(auto_apply=True, no_prompts=True, skip_tests=True) + + # Make sure the downstrean indirect non-breaking view is available in prod + count = context.engine_adapter.fetchone("SELECT COUNT(*) FROM memory.sushi.downstream_view_b")[ + 0 + ] + assert count > 0 + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_new_forward_only_model_concurrent_versions(init_and_plan_context: 
t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + new_model_expr = d.parse( + """ + MODEL ( + name memory.sushi.new_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only TRUE, + on_destructive_change 'allow', + ), + ); + + SELECT '2023-01-07' AS ds, 1 AS a; + """ + ) + new_model = load_sql_based_model(new_model_expr) + + # Add the first version of the model and apply it to dev_a. + context.upsert_model(new_model) + snapshot_a = context.get_snapshot(new_model.name) + plan_a = context.plan_builder("dev_a").build() + snapshot_a = plan_a.snapshots[snapshot_a.snapshot_id] + + assert snapshot_a.snapshot_id in plan_a.context_diff.new_snapshots + assert snapshot_a.snapshot_id in plan_a.context_diff.added + assert snapshot_a.change_category == SnapshotChangeCategory.BREAKING + + context.apply(plan_a) + + new_model_alt_expr = d.parse( + """ + MODEL ( + name memory.sushi.new_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only TRUE, + on_destructive_change 'allow', + ), + ); + + SELECT '2023-01-07' AS ds, 1 AS b; + """ + ) + new_model_alt = load_sql_based_model(new_model_alt_expr) + + # Add the second version of the model but don't apply it yet + context.upsert_model(new_model_alt) + snapshot_b = context.get_snapshot(new_model_alt.name) + plan_b = context.plan_builder("dev_b").build() + snapshot_b = plan_b.snapshots[snapshot_b.snapshot_id] + + assert snapshot_b.snapshot_id in plan_b.context_diff.new_snapshots + assert snapshot_b.snapshot_id in plan_b.context_diff.added + assert snapshot_b.change_category == SnapshotChangeCategory.BREAKING + + assert snapshot_b.fingerprint != snapshot_a.fingerprint + assert snapshot_b.version == snapshot_a.version + + # Apply the 1st version to prod + context.upsert_model(new_model) + plan_prod_a = context.plan_builder("prod").build() + assert snapshot_a.snapshot_id in plan_prod_a.snapshots + assert ( + 
plan_prod_a.snapshots[snapshot_a.snapshot_id].change_category + == SnapshotChangeCategory.BREAKING + ) + context.apply(plan_prod_a) + + df = context.fetchdf("SELECT * FROM memory.sushi.new_model") + assert df.to_dict() == {"ds": {0: "2023-01-07"}, "a": {0: 1}} + + # Modify the 1st version in prod to trigger a forward-only change + new_model = add_projection_to_model(t.cast(SqlModel, new_model)) + context.upsert_model(new_model) + context.plan("prod", auto_apply=True, no_prompts=True, skip_tests=True) + + # Apply the 2nd version to dev_b. + # At this point the snapshot of the 2nd version has already been categorized but not + # persisted in the state. This means that when the snapshot of the 1st version was + # being unpaused during promotion to prod, the state of the 2nd version snapshot was not updated + context.apply(plan_b) + + # Apply the 2nd version to prod + context.upsert_model(new_model_alt) + plan_prod_b = context.plan_builder("prod").build() + assert ( + plan_prod_b.snapshots[snapshot_b.snapshot_id].change_category + == SnapshotChangeCategory.BREAKING + ) + assert not plan_prod_b.requires_backfill + context.apply(plan_prod_b) + + df = context.fetchdf("SELECT * FROM memory.sushi.new_model").replace({np.nan: None}) + assert df.to_dict() == {"ds": {0: "2023-01-07"}, "b": {0: None}} + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_new_forward_only_model_same_dev_environment(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + new_model_expr = d.parse( + """ + MODEL ( + name memory.sushi.new_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only TRUE, + on_destructive_change 'allow', + ), + ); + + SELECT '2023-01-07' AS ds, 1 AS a; + """ + ) + new_model = load_sql_based_model(new_model_expr) + + # Add the first version of the model and apply it to dev. 
+ context.upsert_model(new_model) + snapshot_a = context.get_snapshot(new_model.name) + plan_a = context.plan_builder("dev").build() + snapshot_a = plan_a.snapshots[snapshot_a.snapshot_id] + + assert snapshot_a.snapshot_id in plan_a.context_diff.new_snapshots + assert snapshot_a.snapshot_id in plan_a.context_diff.added + assert snapshot_a.change_category == SnapshotChangeCategory.BREAKING + + context.apply(plan_a) + + df = context.fetchdf("SELECT * FROM memory.sushi__dev.new_model") + assert df.to_dict() == {"ds": {0: "2023-01-07"}, "a": {0: 1}} + + new_model_alt_expr = d.parse( + """ + MODEL ( + name memory.sushi.new_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only TRUE, + on_destructive_change 'allow', + ), + ); + + SELECT '2023-01-07' AS ds, 1 AS b; + """ + ) + new_model_alt = load_sql_based_model(new_model_alt_expr) + + # Add the second version of the model and apply it to the same environment. + context.upsert_model(new_model_alt) + snapshot_b = context.get_snapshot(new_model_alt.name) + + context.invalidate_environment("dev", sync=True) + plan_b = context.plan_builder("dev").build() + snapshot_b = plan_b.snapshots[snapshot_b.snapshot_id] + + context.apply(plan_b) + + df = context.fetchdf("SELECT * FROM memory.sushi__dev.new_model").replace({np.nan: None}) + assert df.to_dict() == {"ds": {0: "2023-01-07"}, "b": {0: 1}} diff --git a/tests/core/integration/test_model_kinds.py b/tests/core/integration/test_model_kinds.py new file mode 100644 index 0000000000..1cc1bf7aeb --- /dev/null +++ b/tests/core/integration/test_model_kinds.py @@ -0,0 +1,2644 @@ +from __future__ import annotations + +import typing as t +from collections import Counter +from datetime import timedelta +from unittest import mock +import pandas as pd # noqa: TID253 +import pytest +from pathlib import Path +import time_machine +from pytest_mock.plugin import MockerFixture +from sqlglot import exp + +from sqlmesh import CustomMaterialization +from sqlmesh.core import 
dialect as d +from sqlmesh.core.config import ( + Config, + ModelDefaultsConfig, + DuckDBConnectionConfig, + GatewayConfig, +) +from sqlmesh.core.console import Console +from sqlmesh.core.context import Context +from sqlmesh.core.config.categorizer import CategorizerConfig +from sqlmesh.core.model import ( + Model, + SqlModel, + CustomKind, + load_sql_based_model, +) +from sqlmesh.core.plan import SnapshotIntervals +from sqlmesh.utils.date import to_date, to_timestamp +from sqlmesh.utils.pydantic import validate_string +from tests.conftest import SushiDataValidator +from sqlmesh.utils import CorrelationId +from tests.utils.test_filesystem import create_temp_file + +if t.TYPE_CHECKING: + from sqlmesh import QueryOrDF + +pytestmark = pytest.mark.slow + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_incremental_by_partition(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + source_name = "raw.test_incremental_by_partition" + model_name = "memory.sushi.test_incremental_by_partition" + + expressions = d.parse( + f""" + MODEL ( + name {model_name}, + kind INCREMENTAL_BY_PARTITION (disable_restatement false), + partitioned_by [key], + allow_partials true, + start '2023-01-07', + ); + + SELECT key, value FROM {source_name}; + """ + ) + model = load_sql_based_model(expressions) + context.upsert_model(model) + + context.engine_adapter.ctas( + source_name, + d.parse_one("SELECT 'key_a' AS key, 1 AS value"), + ) + + context.plan(auto_apply=True, no_prompts=True) + assert context.engine_adapter.fetchall(f"SELECT * FROM {model_name}") == [ + ("key_a", 1), + ] + + context.engine_adapter.replace_query( + source_name, + d.parse_one("SELECT 'key_b' AS key, 1 AS value"), + ) + context.run(ignore_cron=True) + assert context.engine_adapter.fetchall(f"SELECT * FROM {model_name}") == [ + ("key_a", 1), + ("key_b", 1), + ] + + context.engine_adapter.replace_query( + source_name, + d.parse_one("SELECT 'key_a' 
AS key, 2 AS value"), + ) + # Run 1 minute later. + with time_machine.travel("2023-01-08 15:01:00 UTC"): + context.run(ignore_cron=True) + assert context.engine_adapter.fetchall(f"SELECT * FROM {model_name}") == [ + ("key_b", 1), + ("key_a", 2), + ] + + # model should fully refresh on restatement + context.engine_adapter.replace_query( + source_name, + d.parse_one("SELECT 'key_c' AS key, 3 AS value"), + ) + context.plan(auto_apply=True, no_prompts=True, restate_models=[model_name]) + assert context.engine_adapter.fetchall(f"SELECT * FROM {model_name}") == [ + ("key_c", 3), + ] + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_custom_materialization(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + custom_insert_called = False + + class CustomFullMaterialization(CustomMaterialization): + NAME = "test_custom_full" + + def insert( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + nonlocal custom_insert_called + custom_insert_called = True + + self._replace_query_for_model(model, table_name, query_or_df, render_kwargs) + + model = context.get_model("sushi.top_waiters") + kwargs = { + **model.dict(), + # Make a breaking change. + "kind": dict(name="CUSTOM", materialization="test_custom_full"), + } + context.upsert_model(SqlModel.parse_obj(kwargs)) + + context.plan(auto_apply=True, no_prompts=True) + + assert custom_insert_called + + +# needs to be defined at the top level. 
If its defined within the test body, +# adding to the snapshot cache fails with: AttributeError: Can't pickle local object +class TestCustomKind(CustomKind): + __test__ = False # prevent pytest warning since this isnt a class containing tests + + @property + def custom_property(self) -> str: + return validate_string(self.materialization_properties.get("custom_property")) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_custom_materialization_with_custom_kind(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + custom_insert_calls = [] + + class CustomFullMaterialization(CustomMaterialization[TestCustomKind]): + NAME = "test_custom_full_with_custom_kind" + + def insert( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + assert isinstance(model.kind, TestCustomKind) + + nonlocal custom_insert_calls + custom_insert_calls.append(model.kind.custom_property) + + self._replace_query_for_model(model, table_name, query_or_df, render_kwargs) + + model = context.get_model("sushi.top_waiters") + kwargs = { + **model.dict(), + # Make a breaking change. 
+ "kind": dict( + name="CUSTOM", + materialization="test_custom_full_with_custom_kind", + materialization_properties={"custom_property": "pytest"}, + ), + } + context.upsert_model(SqlModel.parse_obj(kwargs)) + + context.plan(auto_apply=True) + + assert custom_insert_calls == ["pytest"] + + # no changes + context.plan(auto_apply=True) + + assert custom_insert_calls == ["pytest"] + + # change a property on the custom kind, breaking change + kwargs["kind"]["materialization_properties"]["custom_property"] = "some value" + context.upsert_model(SqlModel.parse_obj(kwargs)) + context.plan(auto_apply=True) + + assert custom_insert_calls == ["pytest", "some value"] + + +def test_incremental_time_self_reference( + mocker: MockerFixture, sushi_context: Context, sushi_data_validator: SushiDataValidator +): + start_ts = to_timestamp("1 week ago") + start_date, end_date = to_date("1 week ago"), to_date("yesterday") + if to_timestamp(start_date) < start_ts: + # The start date must be aligned by the interval unit. 
+ start_date += timedelta(days=1) + + df = sushi_context.engine_adapter.fetchdf( + "SELECT MIN(event_date) FROM sushi.customer_revenue_lifetime" + ) + assert df.iloc[0, 0] == pd.to_datetime(start_date) + df = sushi_context.engine_adapter.fetchdf( + "SELECT MAX(event_date) FROM sushi.customer_revenue_lifetime" + ) + assert df.iloc[0, 0] == pd.to_datetime(end_date) + results = sushi_data_validator.validate("sushi.customer_revenue_lifetime", start_date, end_date) + plan = sushi_context.plan_builder( + restate_models=["sushi.customer_revenue_lifetime", "sushi.customer_revenue_by_day"], + start=start_date, + end="5 days ago", + ).build() + revenue_lifeteime_snapshot = sushi_context.get_snapshot( + "sushi.customer_revenue_lifetime", raise_if_missing=True + ) + revenue_by_day_snapshot = sushi_context.get_snapshot( + "sushi.customer_revenue_by_day", raise_if_missing=True + ) + assert sorted(plan.missing_intervals, key=lambda x: x.snapshot_id) == sorted( + [ + SnapshotIntervals( + snapshot_id=revenue_lifeteime_snapshot.snapshot_id, + intervals=[ + (to_timestamp(to_date("7 days ago")), to_timestamp(to_date("6 days ago"))), + (to_timestamp(to_date("6 days ago")), to_timestamp(to_date("5 days ago"))), + (to_timestamp(to_date("5 days ago")), to_timestamp(to_date("4 days ago"))), + (to_timestamp(to_date("4 days ago")), to_timestamp(to_date("3 days ago"))), + (to_timestamp(to_date("3 days ago")), to_timestamp(to_date("2 days ago"))), + (to_timestamp(to_date("2 days ago")), to_timestamp(to_date("1 days ago"))), + (to_timestamp(to_date("1 day ago")), to_timestamp(to_date("today"))), + ], + ), + SnapshotIntervals( + snapshot_id=revenue_by_day_snapshot.snapshot_id, + intervals=[ + (to_timestamp(to_date("7 days ago")), to_timestamp(to_date("6 days ago"))), + (to_timestamp(to_date("6 days ago")), to_timestamp(to_date("5 days ago"))), + ], + ), + ], + key=lambda x: x.snapshot_id, + ) + sushi_context.console = mocker.Mock(spec=Console) + sushi_context.apply(plan) + num_batch_calls = 
Counter( + [x[0][0] for x in sushi_context.console.update_snapshot_evaluation_progress.call_args_list] # type: ignore + ) + # Validate that we made 7 calls to the customer_revenue_lifetime snapshot and 1 call to the customer_revenue_by_day snapshot + assert num_batch_calls == { + sushi_context.get_snapshot("sushi.customer_revenue_lifetime", raise_if_missing=True): 7, + sushi_context.get_snapshot("sushi.customer_revenue_by_day", raise_if_missing=True): 1, + } + # Validate that the results are the same as before the restate + assert results == sushi_data_validator.validate( + "sushi.customer_revenue_lifetime", start_date, end_date + ) + + +def test_incremental_by_time_model_ignore_destructive_change(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + @start_ds as ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert 
"source_id" in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 2 as id, + 3 as new_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + assert updated_df["new_column"].dropna().tolist() == [3] + + with time_machine.travel("2023-01-11 00:00:00 UTC"): + updated_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + 
forward_only true, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 2 as id, + CAST(4 AS STRING) as new_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(updated_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True, run=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 3 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + # The destructive change was ignored but this change is coercable and therefore we still return ints + assert updated_df["new_column"].dropna().tolist() == [3, 4] + + with time_machine.travel("2023-01-12 00:00:00 UTC"): + updated_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 2 as id, + CAST(5 AS STRING) as new_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(updated_model) + + context = Context(paths=[tmp_path], config=config) + # Make the change compatible since that means we will attempt and alter now that is considered additive + context.engine_adapter.SCHEMA_DIFFER_KWARGS["compatible_types"] = { + exp.DataType.build("INT"): {exp.DataType.build("STRING")} + } + context.plan("prod", auto_apply=True, no_prompts=True, run=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + 
updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 4 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + # The change is now reflected since an additive alter could be performed + assert updated_df["new_column"].dropna().tolist() == ["3", "4", "5"] + + context.close() + + +def test_incremental_by_time_model_ignore_additive_change(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + on_destructive_change allow, + on_additive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + 'other' as other_column, + @start_ds as ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "source_id" 
in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column to the source table + initial_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + on_destructive_change allow, + on_additive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'other' as other_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("ALTER TABLE source_table ADD COLUMN new_column INT") + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is removed since destructive is allowed + assert "name" not in updated_df.columns + # new_column is not added since additive is ignored + assert "new_column" not in updated_df.columns + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is not still in table since destructive was applied + assert "name" not in updated_df.columns + # new_column is still not added since additive is ignored + assert "new_column" not in updated_df.columns + + with time_machine.travel("2023-01-11 00:00:00 UTC"): + updated_model = """ + MODEL 
( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + on_destructive_change allow, + on_additive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + CAST(1 AS STRING) as id, + 'other' as other_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(updated_model) + + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.SCHEMA_DIFFER_KWARGS["compatible_types"] = { + exp.DataType.build("INT"): {exp.DataType.build("STRING")} + } + context.plan("prod", auto_apply=True, no_prompts=True, run=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 3 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is not still in table since destructive was allowed + assert "name" not in updated_df.columns + # new_column is still not added since additive is ignored + assert "new_column" not in updated_df.columns + # The additive change was ignored since we set the change as compatible therefore + # instead of getting strings in the result we still return ints + assert updated_df["id"].tolist() == [1, 1, 1] + + with time_machine.travel("2023-01-12 00:00:00 UTC"): + updated_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + on_destructive_change allow, + on_additive_change allow + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + CAST(1 AS STRING) as id, + 'other' as other_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(updated_model) + + context = Context(paths=[tmp_path], config=config) + # Make the change compatible since that means we will attempt and alter now that is considered 
additive + context.engine_adapter.SCHEMA_DIFFER_KWARGS["compatible_types"] = { + exp.DataType.build("INT"): {exp.DataType.build("STRING")} + } + context.plan("prod", auto_apply=True, no_prompts=True, run=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 4 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is not still in table since destructive was allowed + assert "name" not in updated_df.columns + # new_column is now added since it is additive is now allowed + assert "new_column" in updated_df.columns + # The change is now reflected since an additive alter could be performed + assert updated_df["id"].dropna().tolist() == ["1", "1", "1", "1"] + + context.close() + + +def test_incremental_by_unique_key_model_ignore_destructive_change(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key id, + forward_only true, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + @start_ds as ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + 
context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key id, + forward_only true, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 2 as id, + 3 as new_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is 
added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + +def test_incremental_by_unique_key_model_ignore_additive_change(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key id, + forward_only true, + on_destructive_change allow, + on_additive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + @start_ds as ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key id, + forward_only true, + on_destructive_change allow, + on_additive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 2 as id, + 3 as new_column, + @start_ds as ds + 
FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is not in table since destructive was allowed + assert "name" not in updated_df.columns + # new_column is not added since it is additive and ignored + assert "new_column" not in updated_df.columns + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still not in table since destructive was allowed + assert "name" not in updated_df.columns + # new_column is not added since it is additive and ignored + assert "new_column" not in updated_df.columns + + context.close() + + +def test_incremental_unmanaged_model_ignore_destructive_change(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind INCREMENTAL_UNMANAGED( + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + @start_ds as ds + FROM + 
source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_UNMANAGED( + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 2 as id, + 3 as new_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = 
context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + +def test_incremental_unmanaged_model_ignore_additive_change(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind INCREMENTAL_UNMANAGED( + on_destructive_change allow, + on_additive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + @start_ds as ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + 
name test_model, + kind INCREMENTAL_UNMANAGED( + on_destructive_change allow, + on_additive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 2 as id, + 3 as new_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is not in table since destructive was allowed + assert "name" not in updated_df.columns + # new_column is not added since it is additive and ignored + assert "new_column" not in updated_df.columns + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is not still in table since destructive was allowed + assert "name" not in updated_df.columns + # new_column is not added since it is additive and ignored + assert "new_column" not in updated_df.columns + + context.close() + + +def test_scd_type_2_by_time_ignore_destructive_change(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + 
MODEL ( + name test_model, + kind SCD_TYPE_2_BY_TIME ( + unique_key id, + updated_at_name ds, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + @start_dt as ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind SCD_TYPE_2_BY_TIME ( + unique_key id, + updated_at_name ds, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 3 as new_column, + @start_dt as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in 
updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + +def test_scd_type_2_by_time_ignore_additive_change(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind SCD_TYPE_2_BY_TIME ( + unique_key id, + updated_at_name ds, + on_destructive_change allow, + on_additive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + @start_dt as ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * 
FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind SCD_TYPE_2_BY_TIME ( + unique_key id, + updated_at_name ds, + on_destructive_change allow, + on_additive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 3 as new_column, + @start_dt as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is not still in table since destructive was allowed + assert "name" not in updated_df.columns + # new_column is not added since it is additive and ignored + assert "new_column" not in updated_df.columns + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is not still in table since destructive was allowed + assert "name" not in updated_df.columns + # new_column is not added since it is additive and ignored + assert "new_column" not in updated_df.columns + + context.close() + + +def test_scd_type_2_by_column_ignore_destructive_change(tmp_path: Path): + 
models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind SCD_TYPE_2_BY_COLUMN ( + unique_key id, + columns [name], + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + @start_ds as ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind SCD_TYPE_2_BY_COLUMN ( + unique_key id, + columns [new_column], + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 3 as new_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing 
data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + +def test_scd_type_2_by_column_ignore_additive_change(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind SCD_TYPE_2_BY_COLUMN ( + unique_key id, + columns [stable], + on_destructive_change allow, + on_additive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + 'stable' as stable, + @start_ds as ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + 
context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind SCD_TYPE_2_BY_COLUMN ( + unique_key id, + columns [stable], + on_destructive_change allow, + on_additive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'stable2' as stable, + 3 as new_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is not still in table since destructive was ignored + assert "name" not in updated_df.columns + # new_column is not added since it is additive and ignored + assert "new_column" not in updated_df.columns + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in 
initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is not still in table since destructive was allowed + assert "name" not in updated_df.columns + # new_column is not added since it is additive and ignored + assert "new_column" not in updated_df.columns + + context.close() + + +def test_incremental_partition_ignore_destructive_change(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind INCREMENTAL_BY_PARTITION ( + on_destructive_change ignore + ), + partitioned_by [ds], + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + @start_ds as ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_BY_PARTITION ( + on_destructive_change ignore + ), + 
partitioned_by [ds], + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 3 as new_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + + context.close() + + +def test_incremental_partition_ignore_additive_change(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind INCREMENTAL_BY_PARTITION ( + on_destructive_change allow, + on_additive_change ignore 
+ ), + partitioned_by [ds], + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 'test_name' as name, + @start_ds as ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("CREATE TABLE source_table (source_id INT)") + context.engine_adapter.execute("INSERT INTO source_table VALUES (1)") + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_BY_PARTITION ( + on_destructive_change allow, + on_additive_change ignore + ), + partitioned_by [ds], + start '2023-01-01', + cron '@daily' + ); + + SELECT + *, + 1 as id, + 3 as new_column, + @start_ds as ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + + assert len(updated_df) == 1 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is not still in table since destructive was allowed + assert "name" not in updated_df.columns + # new_column is not added since it is additive and ignored + assert 
"new_column" not in updated_df.columns + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.run() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "source_id" in initial_df.columns + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is not still in table since destructive was allowed + assert "name" not in updated_df.columns + # new_column is not added since it is additive and ignored + assert "new_column" not in updated_df.columns + + context.close() + + +def test_incremental_by_time_model_ignore_destructive_change_unit_test(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + test_dir = tmp_path / "tests" + test_dir.mkdir() + test_filepath = test_dir / "test_test_model.yaml" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + id, + name, + ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + initial_test = f""" + +test_test_model: + model: test_model + inputs: + source_table: + - id: 1 + name: 'test_name' + ds: '2025-01-01' + outputs: + query: + - id: 1 + name: 'test_name' + ds: '2025-01-01' +""" + + # Write initial test + test_filepath.write_text(initial_test) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute( + "CREATE TABLE 
source_table (id INT, name STRING, new_column INT, ds STRING)" + ) + context.engine_adapter.execute( + "INSERT INTO source_table VALUES (1, 'test_name', NULL, '2023-01-01')" + ) + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True, skip_tests=True) + test_result = context.test() + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + assert len(test_result.successes) == 1 + assert test_result.testsRun == len(test_result.successes) + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + on_destructive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + id, + new_column, + ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + updated_test = f""" + + test_test_model: + model: test_model + inputs: + source_table: + - id: 1 + new_column: 3 + ds: '2025-01-01' + outputs: + query: + - id: 1 + new_column: 3 + ds: '2025-01-01' + """ + + # Write initial test + test_filepath.write_text(updated_test) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True, skip_tests=True) + test_result = context.test() + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 1 + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + assert 
len(test_result.successes) == 1 + assert test_result.testsRun == len(test_result.successes) + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("INSERT INTO source_table VALUES (2, NULL, 3, '2023-01-09')") + context.run() + test_result = context.test() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still in table since destructive was ignored + assert "name" in updated_df.columns + # new_column is added since it is additive and allowed + assert "new_column" in updated_df.columns + assert len(test_result.successes) == 1 + assert test_result.testsRun == len(test_result.successes) + + context.close() + + +def test_incremental_by_time_model_ignore_additive_change_unit_test(tmp_path: Path): + models_dir = tmp_path / "models" + models_dir.mkdir() + data_dir = tmp_path / "data" + data_dir.mkdir() + data_filepath = data_dir / "test.duckdb" + test_dir = tmp_path / "tests" + test_dir.mkdir() + test_filepath = test_dir / "test_test_model.yaml" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(database=str(data_filepath)), + ) + + # Initial model with 3 columns + initial_model = f""" + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + on_destructive_change allow, + on_additive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + id, + name, + ds + FROM + source_table; + """ + + # Write initial model + (models_dir / "test_model.sql").write_text(initial_model) + + initial_test = f""" + +test_test_model: + model: test_model + inputs: + source_table: + - id: 1 + name: 'test_name' + ds: '2025-01-01' + outputs: + query: + - id: 1 + name: 'test_name' + ds: '2025-01-01' +""" + + # Write initial 
test + test_filepath.write_text(initial_test) + + with time_machine.travel("2023-01-08 00:00:00 UTC"): + # Create context and apply initial model + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute( + "CREATE TABLE source_table (id INT, name STRING, new_column INT, ds STRING)" + ) + context.engine_adapter.execute( + "INSERT INTO source_table VALUES (1, 'test_name', NULL, '2023-01-01')" + ) + + # Apply initial plan and load data + context.plan("prod", auto_apply=True, no_prompts=True, skip_tests=True) + test_result = context.test() + + # Verify initial data was loaded + initial_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(initial_df) == 1 + assert "id" in initial_df.columns + assert "name" in initial_df.columns + assert "ds" in initial_df.columns + assert len(test_result.successes) == 1 + assert test_result.testsRun == len(test_result.successes) + + context.close() + + # remove `name` column and add new column + initial_model = """ + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + on_destructive_change allow, + on_additive_change ignore + ), + start '2023-01-01', + cron '@daily' + ); + + SELECT + id, + new_column, + ds + FROM + source_table; + """ + (models_dir / "test_model.sql").write_text(initial_model) + + # `new_column` is in the output since unit tests are based on the model definition that currently + # exists and doesn't take into account the historical changes to the table. 
Therefore `new_column` is + # not actually in the table but it is represented in the test + updated_test = f""" + test_test_model: + model: test_model + inputs: + source_table: + - id: 1 + new_column: 3 + ds: '2025-01-01' + outputs: + query: + - id: 1 + new_column: 3 + ds: '2025-01-01' + """ + + # Write initial test + test_filepath.write_text(updated_test) + + context = Context(paths=[tmp_path], config=config) + context.plan("prod", auto_apply=True, no_prompts=True, skip_tests=True) + test_result = context.test() + + # Verify data loading continued to work + # The existing data should still be there and new data should be loaded + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 1 + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is not in table since destructive was ignored + assert "name" not in updated_df.columns + # new_column is not added since it is additive and ignored + assert "new_column" not in updated_df.columns + assert len(test_result.successes) == 1 + assert test_result.testsRun == len(test_result.successes) + + context.close() + + with time_machine.travel("2023-01-10 00:00:00 UTC"): + context = Context(paths=[tmp_path], config=config) + context.engine_adapter.execute("INSERT INTO source_table VALUES (2, NULL, 3, '2023-01-09')") + context.run() + test_result = context.test() + updated_df = context.fetchdf('SELECT * FROM "default"."test_model"') + assert len(updated_df) == 2 + assert "id" in updated_df.columns + assert "ds" in updated_df.columns + # name is still not in table since destructive was allowed + assert "name" not in updated_df.columns + # new_column is not added since it is additive and ignored + assert "new_column" not in updated_df.columns + assert len(test_result.successes) == 1 + assert test_result.testsRun == len(test_result.successes) + + context.close() + + +@time_machine.travel("2020-01-01 00:00:00 UTC") +def 
test_scd_type_2_full_restatement_no_start_date(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + # Initial product catalog of 3 products + raw_products = d.parse(""" + MODEL ( + name memory.store.raw_products, + kind FULL + ); + + SELECT * FROM VALUES + (101, 'Laptop Pro', 1299.99, 'Electronics', '2020-01-01 00:00:00'::TIMESTAMP), + (102, 'Wireless Mouse', 49.99, 'Electronics', '2020-01-01 00:00:00'::TIMESTAMP), + (103, 'Office Chair', 199.99, 'Furniture', '2020-01-01 00:00:00'::TIMESTAMP) + AS t(product_id, product_name, price, category, last_updated); + """) + + # SCD Type 2 model for product history tracking + product_history = d.parse(""" + MODEL ( + name memory.store.product_history, + kind SCD_TYPE_2_BY_TIME ( + unique_key product_id, + updated_at_name last_updated, + disable_restatement false + ), + owner catalog_team, + cron '0 */6 * * *', + grain product_id, + description 'Product catalog change history' + ); + + SELECT + product_id::INT AS product_id, + product_name::TEXT AS product_name, + price::DECIMAL(10,2) AS price, + category::TEXT AS category, + last_updated AS last_updated + FROM + memory.store.raw_products; + """) + + raw_products_model = load_sql_based_model(raw_products) + product_history_model = load_sql_based_model(product_history) + context.upsert_model(raw_products_model) + context.upsert_model(product_history_model) + + # Initial plan and apply + plan = context.plan_builder("prod", skip_tests=True).build() + context.apply(plan) + + query = "SELECT product_id, product_name, price, category, last_updated, valid_from, valid_to FROM memory.store.product_history ORDER BY product_id, valid_from" + initial_data = context.engine_adapter.fetchdf(query) + + # Validate initial state of 3 products all active + assert len(initial_data) == 3 + assert initial_data["valid_to"].isna().all() + initial_product_names = set(initial_data["product_name"].tolist()) + assert 
initial_product_names == {"Laptop Pro", "Wireless Mouse", "Office Chair"} + + # Price update and category change + with time_machine.travel("2020-01-15 12:00:00 UTC"): + raw_products_v2 = d.parse(""" + MODEL ( + name memory.store.raw_products, + kind FULL + ); + + SELECT * FROM VALUES + (101, 'Laptop Pro', 1199.99, 'Electronics', '2020-01-15 00:00:00'::TIMESTAMP), + (102, 'Wireless Mouse', 49.99, 'Electronics', '2020-01-01 00:00:00'::TIMESTAMP), + (103, 'Ergonomic Office Chair', 229.99, 'Office Furniture', '2020-01-15 00:00:00'::TIMESTAMP) + AS t(product_id, product_name, price, category, last_updated); + """) + raw_products_v2_model = load_sql_based_model(raw_products_v2) + context.upsert_model(raw_products_v2_model) + context.plan( + auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full() + ) + context.run() + + data_after_first_change = context.engine_adapter.fetchdf(query) + + # Should have 5 records (3 original closed, 2 new active, 1 unchanged) + assert len(data_after_first_change) == 5 + + # Second change + with time_machine.travel("2020-02-01 10:00:00 UTC"): + raw_products_v3 = d.parse(""" + MODEL ( + name memory.store.raw_products, + kind FULL + ); + + SELECT * FROM VALUES + (101, 'Laptop Pro Max', 1399.99, 'Electronics', '2020-02-01 00:00:00'::TIMESTAMP), + (103, 'Ergonomic Office Chair', 229.99, 'Office Furniture', '2020-01-15 00:00:00'::TIMESTAMP), + (102, 'Wireless Mouse', 49.99, 'Electronics', '2020-01-01 00:00:00'::TIMESTAMP) + AS t(product_id, product_name, price, category, last_updated); + """) + raw_products_v3_model = load_sql_based_model(raw_products_v3) + context.upsert_model(raw_products_v3_model) + context.plan( + auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full() + ) + context.run() + data_after_second_change = context.engine_adapter.fetchdf(query) + assert len(data_after_second_change) == 6 + + # Store the current state before full restatement + data_before_full_restatement = 
data_after_second_change.copy() + + # Perform full restatement (no start date provided) + with time_machine.travel("2020-02-01 15:00:00 UTC"): + plan = context.plan_builder( + "prod", skip_tests=True, restate_models=["memory.store.product_history"] + ).build() + context.apply(plan) + data_after_full_restatement = context.engine_adapter.fetchdf(query) + assert len(data_after_full_restatement) == 3 + + # Check that all currently active products before restatement are still active after restatement + active_before = data_before_full_restatement[ + data_before_full_restatement["valid_to"].isna() + ] + active_after = data_after_full_restatement + assert set(active_before["product_id"]) == set(active_after["product_id"]) + + expected_products = { + 101: { + "product_name": "Laptop Pro Max", + "price": 1399.99, + "category": "Electronics", + "last_updated": "2020-02-01", + }, + 102: { + "product_name": "Wireless Mouse", + "price": 49.99, + "category": "Electronics", + "last_updated": "2020-01-01", + }, + 103: { + "product_name": "Ergonomic Office Chair", + "price": 229.99, + "category": "Office Furniture", + "last_updated": "2020-01-15", + }, + } + for _, row in data_after_full_restatement.iterrows(): + pid = row["product_id"] + assert pid in expected_products + expected = expected_products[pid] + assert row["product_name"] == expected["product_name"] + assert float(row["price"]) == expected["price"] + assert row["category"] == expected["category"] + + # valid_from should be the epoch, valid_to should be NaT + assert str(row["valid_from"]) == "1970-01-01 00:00:00" + assert pd.isna(row["valid_to"]) + + +def test_plan_evaluator_correlation_id(tmp_path: Path): + def _correlation_id_in_sqls(correlation_id: CorrelationId, mock_logger): + sqls = [call[0][0] for call in mock_logger.call_args_list] + return any(f"/* {correlation_id} */" in sql for sql in sqls) + + ctx = Context(paths=[tmp_path], config=Config()) + + # Case: Ensure that the correlation id (plan_id) is included in 
the SQL for each plan + for i in range(2): + create_temp_file( + tmp_path, + Path("models", "test.sql"), + f"MODEL (name test.a, kind FULL); SELECT {i} AS col", + ) + + with mock.patch("sqlmesh.core.engine_adapter.base.EngineAdapter._log_sql") as mock_logger: + ctx.load() + plan = ctx.plan(auto_apply=True, no_prompts=True) + + correlation_id = CorrelationId.from_plan_id(plan.plan_id) + assert str(correlation_id) == f"SQLMESH_PLAN: {plan.plan_id}" + + assert _correlation_id_in_sqls(correlation_id, mock_logger) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_scd_type_2_regular_run_with_offset(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + raw_employee_status = d.parse(""" + MODEL ( + name memory.hr_system.raw_employee_status, + kind FULL + ); + + SELECT + 1001 AS employee_id, + 'engineering' AS department, + 'EMEA' AS region, + '2023-01-08 15:00:00 UTC' AS last_modified; + """) + + employee_history = d.parse(""" + MODEL ( + name memory.hr_system.employee_history, + kind SCD_TYPE_2_BY_TIME ( + unique_key employee_id, + updated_at_name last_modified, + disable_restatement false + ), + owner hr_analytics, + cron '0 7 * * *', + grain employee_id, + description 'Historical tracking of employee status changes' + ); + + SELECT + employee_id::INT AS employee_id, + department::TEXT AS department, + region::TEXT AS region, + last_modified AS last_modified + FROM + memory.hr_system.raw_employee_status; + """) + + raw_employee_status_model = load_sql_based_model(raw_employee_status) + employee_history_model = load_sql_based_model(employee_history) + context.upsert_model(raw_employee_status_model) + context.upsert_model(employee_history_model) + + # Initial plan and apply + plan = context.plan_builder("prod", skip_tests=True).build() + context.apply(plan) + + query = "SELECT employee_id, department, region, valid_from, valid_to FROM memory.hr_system.employee_history ORDER BY employee_id, 
valid_from" + initial_data = context.engine_adapter.fetchdf(query) + + assert len(initial_data) == 1 + assert initial_data["valid_to"].isna().all() + assert initial_data["department"].tolist() == ["engineering"] + assert initial_data["region"].tolist() == ["EMEA"] + + # Apply a future plan with source changes a few hours before the cron time of the SCD Type 2 model BUT on the same day + with time_machine.travel("2023-01-09 00:10:00 UTC"): + raw_employee_status_v2 = d.parse(""" + MODEL ( + name memory.hr_system.raw_employee_status, + kind FULL + ); + + SELECT + 1001 AS employee_id, + 'engineering' AS department, + 'AMER' AS region, + '2023-01-09 00:10:00 UTC' AS last_modified; + """) + raw_employee_status_v2_model = load_sql_based_model(raw_employee_status_v2) + context.upsert_model(raw_employee_status_v2_model) + context.plan( + auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full() + ) + + # The 7th hour of the day the run is kicked off for the SCD Type 2 model + with time_machine.travel("2023-01-09 07:00:01 UTC"): + context.run() + data_after_change = context.engine_adapter.fetchdf(query) + + # Validate the SCD2 records for employee 1001 + assert len(data_after_change) == 2 + assert data_after_change.iloc[0]["employee_id"] == 1001 + assert data_after_change.iloc[0]["department"] == "engineering" + assert data_after_change.iloc[0]["region"] == "EMEA" + assert str(data_after_change.iloc[0]["valid_from"]) == "1970-01-01 00:00:00" + assert str(data_after_change.iloc[0]["valid_to"]) == "2023-01-09 00:10:00" + assert data_after_change.iloc[1]["employee_id"] == 1001 + assert data_after_change.iloc[1]["department"] == "engineering" + assert data_after_change.iloc[1]["region"] == "AMER" + assert str(data_after_change.iloc[1]["valid_from"]) == "2023-01-09 00:10:00" + assert pd.isna(data_after_change.iloc[1]["valid_to"]) + + # Update source model again a bit later on the same day + raw_employee_status_v2 = d.parse(""" + MODEL ( + name 
memory.hr_system.raw_employee_status, + kind FULL + ); + + SELECT + 1001 AS employee_id, + 'sales' AS department, + 'ANZ' AS region, + '2023-01-09 07:26:00 UTC' AS last_modified; + """) + raw_employee_status_v2_model = load_sql_based_model(raw_employee_status_v2) + context.upsert_model(raw_employee_status_v2_model) + context.plan( + auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full() + ) + + # A day later the run is kicked off for the SCD Type 2 model again + with time_machine.travel("2023-01-10 07:00:00 UTC"): + context.run() + data_after_change = context.engine_adapter.fetchdf(query) + + # Validate the SCD2 history for employee 1001 after second change with the historical records intact + assert len(data_after_change) == 3 + assert data_after_change.iloc[0]["employee_id"] == 1001 + assert data_after_change.iloc[0]["department"] == "engineering" + assert data_after_change.iloc[0]["region"] == "EMEA" + assert str(data_after_change.iloc[0]["valid_from"]) == "1970-01-01 00:00:00" + assert str(data_after_change.iloc[0]["valid_to"]) == "2023-01-09 00:10:00" + assert data_after_change.iloc[1]["employee_id"] == 1001 + assert data_after_change.iloc[1]["department"] == "engineering" + assert data_after_change.iloc[1]["region"] == "AMER" + assert str(data_after_change.iloc[1]["valid_from"]) == "2023-01-09 00:10:00" + assert str(data_after_change.iloc[1]["valid_to"]) == "2023-01-09 07:26:00" + assert data_after_change.iloc[2]["employee_id"] == 1001 + assert data_after_change.iloc[2]["department"] == "sales" + assert data_after_change.iloc[2]["region"] == "ANZ" + assert str(data_after_change.iloc[2]["valid_from"]) == "2023-01-09 07:26:00" + assert pd.isna(data_after_change.iloc[2]["valid_to"]) + + # Now test restatement works (full restatement support currently) + with time_machine.travel("2023-01-10 07:38:00 UTC"): + plan = context.plan_builder( + "prod", + skip_tests=True, + restate_models=["memory.hr_system.employee_history"], + 
start="2023-01-09 00:10:00", + ).build() + context.apply(plan) + restated_data = context.engine_adapter.fetchdf(query) + + # Validate the SCD2 history after restatement has been wiped bar one + assert len(restated_data) == 1 + assert restated_data.iloc[0]["employee_id"] == 1001 + assert restated_data.iloc[0]["department"] == "sales" + assert restated_data.iloc[0]["region"] == "ANZ" + assert str(restated_data.iloc[0]["valid_from"]) == "1970-01-01 00:00:00" + assert pd.isna(restated_data.iloc[0]["valid_to"]) + + +def test_seed_model_metadata_update_does_not_trigger_backfill(tmp_path: Path): + """ + Scenario: + - Create a seed model; perform initial population + - Modify the model with a metadata-only change and trigger a plan + + Outcome: + - The seed model is modified (metadata-only) but this should NOT trigger backfill + - There should be no missing_intervals on the plan to backfill + """ + + models_path = tmp_path / "models" + seeds_path = tmp_path / "seeds" + models_path.mkdir() + seeds_path.mkdir() + + seed_model_path = models_path / "seed.sql" + seed_path = seeds_path / "seed_data.csv" + + seed_path.write_text("\n".join(["id,name", "1,test"])) + + seed_model_path.write_text(""" + MODEL ( + name test.source_data, + kind SEED ( + path '../seeds/seed_data.csv' + ) + ); + """) + + config = Config( + gateways={"": GatewayConfig(connection=DuckDBConnectionConfig())}, + model_defaults=ModelDefaultsConfig(dialect="duckdb", start="2024-01-01"), + ) + ctx = Context(paths=tmp_path, config=config) + + plan = ctx.plan(auto_apply=True) + + original_seed_snapshot = ctx.snapshots['"memory"."test"."source_data"'] + assert plan.directly_modified == {original_seed_snapshot.snapshot_id} + assert plan.metadata_updated == set() + assert plan.missing_intervals + + # prove data loaded + assert ctx.engine_adapter.fetchall("select id, name from memory.test.source_data") == [ + (1, "test") + ] + + # prove no diff + ctx.load() + plan = ctx.plan(auto_apply=True) + assert not 
plan.has_changes + assert not plan.missing_intervals + + # make metadata-only change + seed_model_path.write_text(""" + MODEL ( + name test.source_data, + kind SEED ( + path '../seeds/seed_data.csv' + ), + description 'updated by test' + ); + """) + + ctx.load() + plan = ctx.plan(auto_apply=True) + assert plan.has_changes + + new_seed_snapshot = ctx.snapshots['"memory"."test"."source_data"'] + assert ( + new_seed_snapshot.version == original_seed_snapshot.version + ) # should be using the same physical table + assert ( + new_seed_snapshot.snapshot_id != original_seed_snapshot.snapshot_id + ) # but still be different due to the metadata change + assert plan.directly_modified == set() + assert plan.metadata_updated == {new_seed_snapshot.snapshot_id} + + # there should be no missing intervals to backfill since all we did is update a description + assert not plan.missing_intervals + + # there should still be no diff or missing intervals in 3 days time + assert new_seed_snapshot.model.interval_unit.is_day + with time_machine.travel(timedelta(days=3)): + ctx.clear_caches() + ctx.load() + plan = ctx.plan(auto_apply=True) + assert not plan.has_changes + assert not plan.missing_intervals + + # change seed data + seed_path.write_text("\n".join(["id,name", "1,test", "2,updated"])) + + # new plan - NOW we should backfill because data changed + ctx.load() + plan = ctx.plan(auto_apply=True) + assert plan.has_changes + + updated_seed_snapshot = ctx.snapshots['"memory"."test"."source_data"'] + + assert ( + updated_seed_snapshot.snapshot_id + != new_seed_snapshot.snapshot_id + != original_seed_snapshot.snapshot_id + ) + assert not updated_seed_snapshot.forward_only + assert plan.directly_modified == {updated_seed_snapshot.snapshot_id} + assert plan.metadata_updated == set() + assert plan.missing_intervals + + # prove backfilled data loaded + assert ctx.engine_adapter.fetchall("select id, name from memory.test.source_data") == [ + (1, "test"), + (2, "updated"), + ] + + 
+@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_seed_model_promote_to_prod_after_dev( + init_and_plan_context: t.Callable, +): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + with open(context.path / "seeds" / "waiter_names.csv", "a") as f: + f.write("\n10,New Waiter") + + context.load() + + waiter_names_snapshot = context.get_snapshot("sushi.waiter_names") + plan = context.plan("dev", skip_tests=True, auto_apply=True, no_prompts=True) + assert waiter_names_snapshot.snapshot_id in plan.directly_modified + + # Trigger a metadata change to reuse the previous version + waiter_names_model = waiter_names_snapshot.model.copy( + update={"description": "Updated description"} + ) + context.upsert_model(waiter_names_model) + context.plan("dev", skip_tests=True, auto_apply=True, no_prompts=True) + + # Promote all changes to prod + waiter_names_snapshot = context.get_snapshot("sushi.waiter_names") + plan = context.plan_builder("prod", skip_tests=True).build() + # Clear the cache to source the dehydrated model instance from the state + context.clear_caches() + context.apply(plan) + + assert ( + context.engine_adapter.fetchone("SELECT COUNT(*) FROM sushi.waiter_names WHERE id = 10")[0] + == 1 + ) diff --git a/tests/core/integration/test_multi_repo.py b/tests/core/integration/test_multi_repo.py new file mode 100644 index 0000000000..4d72d137b3 --- /dev/null +++ b/tests/core/integration/test_multi_repo.py @@ -0,0 +1,561 @@ +from __future__ import annotations + +from unittest.mock import patch +from textwrap import dedent +import os +import pytest +from pathlib import Path +from sqlmesh.core.console import ( + get_console, +) +from sqlmesh.core.config.naming import NameInferenceConfig +from sqlmesh.core.model.common import ParsableSql +from sqlmesh.utils.concurrency import NodeExecutionFailedError + +from sqlmesh.core import constants as c +from sqlmesh.core.config import ( + Config, + GatewayConfig, + ModelDefaultsConfig, + 
DuckDBConnectionConfig, +) +from sqlmesh.core.console import get_console +from sqlmesh.core.context import Context +from sqlmesh.utils.date import now +from tests.conftest import DuckDBMetadata +from tests.utils.test_helpers import use_terminal_console +from tests.core.integration.utils import validate_apply_basics + + +pytestmark = pytest.mark.slow + + +@use_terminal_console +def test_multi(mocker): + context = Context(paths=["examples/multi/repo_1", "examples/multi/repo_2"], gateway="memory") + + with patch.object(get_console(), "log_warning") as mock_logger: + context.plan_builder(environment="dev") + warnings = mock_logger.call_args[0][0] + repo1_path, repo2_path = context.configs.keys() + assert f"Linter warnings for {repo1_path}" in warnings + assert f"Linter warnings for {repo2_path}" not in warnings + + assert ( + context.render("bronze.a").sql() + == '''SELECT 1 AS "col_a", 'b' AS "col_b", 1 AS "one", 'repo_1' AS "dup"''' + ) + assert ( + context.render("silver.d").sql() + == '''SELECT "c"."col_a" AS "col_a", 2 AS "two", 'repo_2' AS "dup" FROM "memory"."silver"."c" AS "c"''' + ) + context._new_state_sync().reset(default_catalog=context.default_catalog) + plan = context.plan_builder().build() + assert len(plan.new_snapshots) == 5 + context.apply(plan) + + # Ensure before_all, after_all statements for multiple repos have executed + environment_statements = context.state_reader.get_environment_statements(c.PROD) + assert len(environment_statements) == 2 + assert context.fetchdf("select * from before_1").to_dict()["1"][0] == 1 + assert context.fetchdf("select * from before_2").to_dict()["2"][0] == 2 + assert context.fetchdf("select * from after_1").to_dict()["repo_1"][0] == "repo_1" + assert context.fetchdf("select * from after_2").to_dict()["repo_2"][0] == "repo_2" + + old_context = context + context = Context( + paths=["examples/multi/repo_1"], + state_sync=old_context.state_sync, + gateway="memory", + ) + context._engine_adapter = old_context.engine_adapter 
+ del context.engine_adapters + + model = context.get_model("bronze.a") + assert model.project == "repo_1" + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql(sql=model.query.select("'c' AS c").sql(dialect=model.dialect)) + } + ) + ) + plan = context.plan_builder().build() + + assert set(snapshot.name for snapshot in plan.directly_modified) == { + '"memory"."bronze"."a"', + '"memory"."bronze"."b"', + '"memory"."silver"."e"', + } + assert sorted([x.name for x in list(plan.indirectly_modified.values())[0]]) == [ + '"memory"."silver"."c"', + '"memory"."silver"."d"', + ] + assert len(plan.missing_intervals) == 3 + context.apply(plan) + validate_apply_basics(context, c.PROD, plan.snapshots.values()) + + # Ensure that before_all and after_all statements of both repos are there despite planning with repo_1 + environment_statements = context.state_reader.get_environment_statements(c.PROD) + assert len(environment_statements) == 2 + + # Ensure that environment statements have the project field set correctly + sorted_env_statements = sorted(environment_statements, key=lambda es: es.project) + assert sorted_env_statements[0].project == "repo_1" + assert sorted_env_statements[1].project == "repo_2" + + # Assert before_all and after_all for each project + assert sorted_env_statements[0].before_all == [ + "CREATE TABLE IF NOT EXISTS before_1 AS select @one()" + ] + assert sorted_env_statements[0].after_all == [ + "CREATE TABLE IF NOT EXISTS after_1 AS select @dup()" + ] + assert sorted_env_statements[1].before_all == [ + "CREATE TABLE IF NOT EXISTS before_2 AS select @two()" + ] + assert sorted_env_statements[1].after_all == [ + "CREATE TABLE IF NOT EXISTS after_2 AS select @dup()" + ] + + +@use_terminal_console +def test_multi_repo_single_project_environment_statements_update(copy_to_temp_path): + paths = copy_to_temp_path("examples/multi") + repo_1_path = f"{paths[0]}/repo_1" + repo_2_path = f"{paths[0]}/repo_2" + + context = Context(paths=[repo_1_path, 
repo_2_path], gateway="memory") + context._new_state_sync().reset(default_catalog=context.default_catalog) + + initial_plan = context.plan_builder().build() + context.apply(initial_plan) + + # Get initial statements + initial_statements = context.state_reader.get_environment_statements(c.PROD) + assert len(initial_statements) == 2 + + # Modify repo_1's config to add a new before_all statement + repo_1_config_path = f"{repo_1_path}/config.yaml" + with open(repo_1_config_path, "r") as f: + config_content = f.read() + + # Add a new before_all statement to repo_1 only + modified_config = config_content.replace( + "CREATE TABLE IF NOT EXISTS before_1 AS select @one()", + "CREATE TABLE IF NOT EXISTS before_1 AS select @one()\n - CREATE TABLE IF NOT EXISTS before_1_modified AS select 999", + ) + + with open(repo_1_config_path, "w") as f: + f.write(modified_config) + + # Create new context with modified config but only for repo_1 + context_repo_1_only = Context( + paths=[repo_1_path], state_sync=context.state_sync, gateway="memory" + ) + + # Plan with only repo_1, this should preserve repo_2's statements from state + repo_1_plan = context_repo_1_only.plan_builder(environment="dev").build() + context_repo_1_only.apply(repo_1_plan) + updated_statements = context_repo_1_only.state_reader.get_environment_statements("dev") + + # Should still have statements from both projects + assert len(updated_statements) == 2 + + # Sort by project + sorted_updated = sorted(updated_statements, key=lambda es: es.project or "") + + # Verify repo_1 has the new statement + repo_1_updated = sorted_updated[0] + assert repo_1_updated.project == "repo_1" + assert len(repo_1_updated.before_all) == 2 + assert "CREATE TABLE IF NOT EXISTS before_1_modified" in repo_1_updated.before_all[1] + + # Verify repo_2 statements are preserved from state + repo_2_preserved = sorted_updated[1] + assert repo_2_preserved.project == "repo_2" + assert len(repo_2_preserved.before_all) == 1 + assert "CREATE TABLE IF NOT 
EXISTS before_2" in repo_2_preserved.before_all[0] + assert "CREATE TABLE IF NOT EXISTS after_2 AS select @dup()" in repo_2_preserved.after_all[0] + + +@use_terminal_console +def test_multi_virtual_layer(copy_to_temp_path): + paths = copy_to_temp_path("tests/fixtures/multi_virtual_layer") + path = Path(paths[0]) + first_db_path = str(path / "db_1.db") + second_db_path = str(path / "db_2.db") + + config = Config( + gateways={ + "first": GatewayConfig( + connection=DuckDBConnectionConfig(database=first_db_path), + variables={"overriden_var": "gateway_1"}, + ), + "second": GatewayConfig( + connection=DuckDBConnectionConfig(database=second_db_path), + variables={"overriden_var": "gateway_2"}, + ), + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + model_naming=NameInferenceConfig(infer_names=True), + default_gateway="first", + gateway_managed_virtual_layer=True, + variables={"overriden_var": "global", "global_one": 88}, + ) + + context = Context(paths=paths, config=config) + assert context.default_catalog_per_gateway == {"first": "db_1", "second": "db_2"} + assert len(context.engine_adapters) == 2 + + # For the model without gateway the default should be used and the gateway variable should override the global + assert ( + context.render("first_schema.model_one").sql() + == 'SELECT \'gateway_1\' AS "item_id", 88 AS "global_one", 1 AS "macro_one"' + ) + + # For model with gateway specified the appropriate variable should be used to override + assert ( + context.render("db_2.second_schema.model_one").sql() + == 'SELECT \'gateway_2\' AS "item_id", 88 AS "global_one", 1 AS "macro_one"' + ) + + plan = context.plan_builder().build() + assert len(plan.new_snapshots) == 4 + context.apply(plan) + + # Validate the tables that source from the first tables are correct as well with evaluate + assert ( + context.evaluate( + "first_schema.model_two", start=now(), end=now(), execution_time=now() + ).to_string() + == " item_id global_one\n0 gateway_1 88" + ) + assert ( + 
context.evaluate( + "db_2.second_schema.model_two", start=now(), end=now(), execution_time=now() + ).to_string() + == " item_id global_one\n0 gateway_2 88" + ) + + assert sorted(set(snapshot.name for snapshot in plan.directly_modified)) == [ + '"db_1"."first_schema"."model_one"', + '"db_1"."first_schema"."model_two"', + '"db_2"."second_schema"."model_one"', + '"db_2"."second_schema"."model_two"', + ] + + model = context.get_model("db_1.first_schema.model_one") + + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'c' AS extra").sql(dialect=model.dialect) + ) + } + ) + ) + plan = context.plan_builder().build() + context.apply(plan) + + state_environments = context.state_reader.get_environments() + state_snapshots = context.state_reader.get_snapshots(context.snapshots.values()) + + assert state_environments[0].gateway_managed + assert len(state_snapshots) == len(state_environments[0].snapshots) + assert [snapshot.name for snapshot in plan.directly_modified] == [ + '"db_1"."first_schema"."model_one"' + ] + assert [x.name for x in list(plan.indirectly_modified.values())[0]] == [ + '"db_1"."first_schema"."model_two"' + ] + + assert len(plan.missing_intervals) == 1 + assert ( + context.evaluate( + "db_1.first_schema.model_one", start=now(), end=now(), execution_time=now() + ).to_string() + == " item_id global_one macro_one extra\n0 gateway_1 88 1 c" + ) + + # Create dev environment with changed models + model = context.get_model("db_2.second_schema.model_one") + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'d' AS extra").sql(dialect=model.dialect) + ) + } + ) + ) + model = context.get_model("first_schema.model_two") + context.upsert_model( + model.copy( + update={ + "query_": ParsableSql( + sql=model.query.select("'d2' AS col").sql(dialect=model.dialect) + ) + } + ) + ) + plan = context.plan_builder("dev").build() + context.apply(plan) + + dev_environment = 
context.state_sync.get_environment("dev") + assert dev_environment is not None + + metadata_engine_1 = DuckDBMetadata.from_context(context) + start_schemas_1 = set(metadata_engine_1.schemas) + assert sorted(start_schemas_1) == sorted( + {"first_schema__dev", "sqlmesh", "first_schema", "sqlmesh__first_schema"} + ) + + metadata_engine_2 = DuckDBMetadata(context._get_engine_adapter("second")) + start_schemas_2 = set(metadata_engine_2.schemas) + assert sorted(start_schemas_2) == sorted( + {"sqlmesh__second_schema", "second_schema", "second_schema__dev"} + ) + + # Invalidate dev environment + context.invalidate_environment("dev") + invalidate_environment = context.state_sync.get_environment("dev") + assert invalidate_environment is not None + assert invalidate_environment.expiration_ts < dev_environment.expiration_ts # type: ignore + assert sorted(start_schemas_1) == sorted(set(metadata_engine_1.schemas)) + assert sorted(start_schemas_2) == sorted(set(metadata_engine_2.schemas)) + + # Run janitor + context._run_janitor() + assert context.state_sync.get_environment("dev") is None + removed_schemas = start_schemas_1 - set(metadata_engine_1.schemas) + assert removed_schemas == {"first_schema__dev"} + removed_schemas = start_schemas_2 - set(metadata_engine_2.schemas) + assert removed_schemas == {"second_schema__dev"} + prod_environment = context.state_sync.get_environment("prod") + + # Remove the second gateway's second model and apply plan + second_model = path / "models/second_schema/model_two.sql" + os.remove(second_model) + assert not second_model.exists() + context = Context(paths=paths, config=config) + plan = context.plan_builder().build() + context.apply(plan) + prod_environment = context.state_sync.get_environment("prod") + assert len(prod_environment.snapshots_) == 3 + + # Changing the flag should show a diff + context.config.gateway_managed_virtual_layer = False + plan = context.plan_builder().build() + assert not plan.requires_backfill + assert ( + 
plan.context_diff.previous_gateway_managed_virtual_layer + != plan.context_diff.gateway_managed_virtual_layer + ) + assert plan.context_diff.has_changes + + # This should error since the default_gateway won't have access to create the view on a non-shared catalog + with pytest.raises(NodeExecutionFailedError, match=r"Execution failed for node SnapshotId*"): + context.apply(plan) + + +def test_multi_dbt(mocker): + context = Context(paths=["examples/multi_dbt/bronze", "examples/multi_dbt/silver"]) + context._new_state_sync().reset(default_catalog=context.default_catalog) + plan = context.plan_builder().build() + assert len(plan.new_snapshots) == 4 + context.apply(plan) + validate_apply_basics(context, c.PROD, plan.snapshots.values()) + + environment_statements = context.state_sync.get_environment_statements(c.PROD) + assert len(environment_statements) == 2 + bronze_statements = environment_statements[0] + assert bronze_statements.before_all == [ + "JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;" + ] + assert not bronze_statements.after_all + silver_statements = environment_statements[1] + assert not silver_statements.before_all + assert silver_statements.after_all == [ + "JINJA_STATEMENT_BEGIN;\n{{ store_schemas(schemas) }}\nJINJA_END;" + ] + assert "store_schemas" in silver_statements.jinja_macros.root_macros + analytics_table = context.fetchdf("select * from analytic_stats;") + assert sorted(analytics_table.columns) == sorted(["physical_table", "evaluation_time"]) + schema_table = context.fetchdf("select * from schema_table;") + assert sorted(schema_table.all_schemas[0]) == sorted(["bronze", "silver"]) + + +def test_multi_hybrid(mocker): + context = Context( + paths=["examples/multi_hybrid/dbt_repo", "examples/multi_hybrid/sqlmesh_repo"] + ) + context._new_state_sync().reset(default_catalog=context.default_catalog) + plan = context.plan_builder().build() + + assert 
len(plan.new_snapshots) == 5 + assert context.dag.roots == {'"memory"."dbt_repo"."e"'} + assert context.dag.graph['"memory"."dbt_repo"."c"'] == {'"memory"."sqlmesh_repo"."b"'} + assert context.dag.graph['"memory"."sqlmesh_repo"."b"'] == {'"memory"."sqlmesh_repo"."a"'} + assert context.dag.graph['"memory"."sqlmesh_repo"."a"'] == {'"memory"."dbt_repo"."e"'} + assert context.dag.downstream('"memory"."dbt_repo"."e"') == [ + '"memory"."sqlmesh_repo"."a"', + '"memory"."sqlmesh_repo"."b"', + '"memory"."dbt_repo"."c"', + '"memory"."dbt_repo"."d"', + ] + + sqlmesh_model_a = context.get_model("sqlmesh_repo.a") + dbt_model_c = context.get_model("dbt_repo.c") + assert sqlmesh_model_a.project == "sqlmesh_repo" + + sqlmesh_rendered = ( + 'SELECT "e"."col_a" AS "col_a", "e"."col_b" AS "col_b" FROM "memory"."dbt_repo"."e" AS "e"' + ) + dbt_rendered = 'SELECT DISTINCT ROUND(CAST(("b"."col_a" / NULLIF(100, 0)) AS DECIMAL(16, 2)), 2) AS "rounded_col_a" FROM "memory"."sqlmesh_repo"."b" AS "b"' + assert sqlmesh_model_a.render_query().sql() == sqlmesh_rendered + assert dbt_model_c.render_query().sql() == dbt_rendered + + context.apply(plan) + validate_apply_basics(context, c.PROD, plan.snapshots.values()) + + +def test_multi_repo_no_project_to_project(copy_to_temp_path): + paths = copy_to_temp_path("examples/multi") + repo_1_path = f"{paths[0]}/repo_1" + repo_1_config_path = f"{repo_1_path}/config.yaml" + with open(repo_1_config_path, "r") as f: + config_content = f.read() + with open(repo_1_config_path, "w") as f: + f.write(config_content.replace("project: repo_1\n", "")) + + context = Context(paths=[repo_1_path], gateway="memory") + context._new_state_sync().reset(default_catalog=context.default_catalog) + plan = context.plan_builder().build() + context.apply(plan) + + # initially models in prod have no project + prod_snapshots = context.state_reader.get_snapshots( + context.state_reader.get_environment(c.PROD).snapshots + ) + for snapshot in prod_snapshots.values(): + assert 
snapshot.node.project == "" + + # we now adopt multi project by adding a project name + with open(repo_1_config_path, "r") as f: + config_content = f.read() + with open(repo_1_config_path, "w") as f: + f.write("project: repo_1\n" + config_content) + + context_with_project = Context( + paths=[repo_1_path], + state_sync=context.state_sync, + gateway="memory", + ) + context_with_project._engine_adapter = context.engine_adapter + del context_with_project.engine_adapters + + # local models should take precedence to pick up the new project name + local_model_a = context_with_project.get_model("bronze.a") + assert local_model_a.project == "repo_1" + local_model_b = context_with_project.get_model("bronze.b") + assert local_model_b.project == "repo_1" + + # also verify the plan works + plan = context_with_project.plan_builder().build() + context_with_project.apply(plan) + validate_apply_basics(context_with_project, c.PROD, plan.snapshots.values()) + + +def test_multi_repo_local_model_overrides_prod_from_other_project(copy_to_temp_path): + paths = copy_to_temp_path("examples/multi") + repo_1_path = f"{paths[0]}/repo_1" + repo_2_path = f"{paths[0]}/repo_2" + + context = Context(paths=[repo_1_path, repo_2_path], gateway="memory") + context._new_state_sync().reset(default_catalog=context.default_catalog) + plan = context.plan_builder().build() + assert len(plan.new_snapshots) == 5 + context.apply(plan) + + prod_model_c = context.get_model("silver.c") + assert prod_model_c.project == "repo_2" + + with open(f"{repo_1_path}/models/c.sql", "w") as f: + f.write( + dedent("""\ + MODEL ( + name silver.c, + kind FULL + ); + + SELECT DISTINCT col_a, col_b + FROM bronze.a + """) + ) + + # silver.c exists locally in repo 1 now AND in prod under repo_2 + context_repo1 = Context( + paths=[repo_1_path], + state_sync=context.state_sync, + gateway="memory", + ) + context_repo1._engine_adapter = context.engine_adapter + del context_repo1.engine_adapters + + # local model should take precedence 
and its project should reflect the new project name + local_model_c = context_repo1.get_model("silver.c") + assert local_model_c.project == "repo_1" + + rendered = context_repo1.render("silver.c").sql() + assert "col_b" in rendered + + # its downstream dependencies though should still be picked up + plan = context_repo1.plan_builder().build() + directly_modified_names = {snapshot.name for snapshot in plan.directly_modified} + assert '"memory"."silver"."c"' in directly_modified_names + assert '"memory"."silver"."d"' in directly_modified_names + missing_interval_names = {s.snapshot_id.name for s in plan.missing_intervals} + assert '"memory"."silver"."c"' in missing_interval_names + assert '"memory"."silver"."d"' in missing_interval_names + + context_repo1.apply(plan) + validate_apply_basics(context_repo1, c.PROD, plan.snapshots.values()) + result = context_repo1.fetchdf("SELECT * FROM memory.silver.c") + assert "col_b" in result.columns + + +def test_engine_adapters_multi_repo_all_gateways_gathered(copy_to_temp_path): + paths = copy_to_temp_path("examples/multi") + repo_1_path = paths[0] / "repo_1" + repo_2_path = paths[0] / "repo_2" + + # Add an extra gateway to repo_2's config + repo_2_config_path = repo_2_path / "config.yaml" + config_content = repo_2_config_path.read_text() + + modified_config = config_content.replace( + "default_gateway: local", + dedent(""" + extra: + connection: + type: duckdb + database: extra.duckdb + + default_gateway: local + """), + ) + + repo_2_config_path.write_text(modified_config) + + # Create context with both repos but using the repo_1 path first + context = Context( + paths=(repo_1_path, repo_2_path), + gateway="memory", + ) + + # Verify all gateways from both repos are present + gathered_gateways = context.engine_adapters.keys() + expected_gateways = {"local", "memory", "extra"} + assert gathered_gateways == expected_gateways diff --git a/tests/core/integration/test_plan_options.py b/tests/core/integration/test_plan_options.py new 
file mode 100644 index 0000000000..a50dc145cd --- /dev/null +++ b/tests/core/integration/test_plan_options.py @@ -0,0 +1,529 @@ +from __future__ import annotations + +import typing as t +import pytest +from sqlmesh.core.console import ( + set_console, + get_console, + TerminalConsole, +) +import time_machine + +from sqlmesh.core import dialect as d +from sqlmesh.core.console import get_console +from sqlmesh.core.model import ( + SqlModel, + load_sql_based_model, +) +from sqlmesh.core.plan import SnapshotIntervals +from sqlmesh.core.snapshot import ( + SnapshotChangeCategory, +) +from sqlmesh.utils.date import to_datetime, to_timestamp +from sqlmesh.utils.errors import ( + NoChangesPlanError, +) +from tests.core.integration.utils import ( + add_projection_to_model, +) + +pytestmark = pytest.mark.slow + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_empty_backfill(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + plan = context.plan_builder("prod", skip_tests=True, empty_backfill=True).build() + assert plan.missing_intervals + assert plan.empty_backfill + assert not plan.requires_backfill + + context.apply(plan) + + for model in context.models.values(): + if model.is_seed or model.kind.is_symbolic: + continue + row_num = context.engine_adapter.fetchone(f"SELECT COUNT(*) FROM {model.name}")[0] + assert row_num == 0 + + plan = context.plan_builder("prod", skip_tests=True).build() + assert not plan.requires_backfill + assert not plan.has_changes + assert not plan.missing_intervals + + snapshots = plan.snapshots + for snapshot in snapshots.values(): + if not snapshot.intervals: + continue + assert snapshot.intervals[-1][1] <= to_timestamp("2023-01-08") + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_empty_backfill_new_model(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + new_model = load_sql_based_model( + d.parse( + """ + 
MODEL ( + name memory.sushi.new_model, + kind FULL, + cron '0 8 * * *', + start '2023-01-01', + ); + + SELECT 1 AS one; + """ + ) + ) + new_model_name = context.upsert_model(new_model).fqn + + with time_machine.travel("2023-01-09 00:00:00 UTC"): + plan = context.plan_builder("dev", skip_tests=True, empty_backfill=True).build() + assert plan.end == to_datetime("2023-01-09") + assert plan.missing_intervals + assert plan.empty_backfill + assert not plan.requires_backfill + + context.apply(plan) + + for model in context.models.values(): + if model.is_seed or model.kind.is_symbolic: + continue + row_num = context.engine_adapter.fetchone(f"SELECT COUNT(*) FROM sushi__dev.new_model")[ + 0 + ] + assert row_num == 0 + + plan = context.plan_builder("prod", skip_tests=True).build() + assert not plan.requires_backfill + assert not plan.missing_intervals + + snapshots = plan.snapshots + for snapshot in snapshots.values(): + if not snapshot.intervals: + continue + elif snapshot.name == new_model_name: + assert snapshot.intervals[-1][1] == to_timestamp("2023-01-09") + else: + assert snapshot.intervals[-1][1] <= to_timestamp("2023-01-08") + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_plan_explain(init_and_plan_context: t.Callable): + old_console = get_console() + set_console(TerminalConsole()) + + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + waiter_revenue_by_day_model = context.get_model("sushi.waiter_revenue_by_day") + waiter_revenue_by_day_model = add_projection_to_model( + t.cast(SqlModel, waiter_revenue_by_day_model) + ) + context.upsert_model(waiter_revenue_by_day_model) + + waiter_revenue_by_day_snapshot = context.get_snapshot(waiter_revenue_by_day_model.name) + top_waiters_snapshot = context.get_snapshot("sushi.top_waiters") + + common_kwargs = dict(skip_tests=True, no_prompts=True, explain=True) + + # For now just making sure the plan doesn't error + context.plan("dev", **common_kwargs) + context.plan("dev", 
**common_kwargs, skip_backfill=True) + context.plan("dev", **common_kwargs, empty_backfill=True) + context.plan("dev", **common_kwargs, forward_only=True, enable_preview=True) + context.plan("prod", **common_kwargs) + context.plan("prod", **common_kwargs, forward_only=True) + context.plan("prod", **common_kwargs, restate_models=[waiter_revenue_by_day_model.name]) + + set_console(old_console) + + # Make sure that the now changes were actually applied + for target_env in ("dev", "prod"): + plan = context.plan_builder(target_env, skip_tests=True).build() + assert plan.has_changes + assert plan.missing_intervals + assert plan.directly_modified == {waiter_revenue_by_day_snapshot.snapshot_id} + assert len(plan.new_snapshots) == 2 + assert {s.snapshot_id for s in plan.new_snapshots} == { + waiter_revenue_by_day_snapshot.snapshot_id, + top_waiters_snapshot.snapshot_id, + } + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_plan_ignore_cron( + init_and_plan_context: t.Callable, +): + context, _ = init_and_plan_context("examples/sushi") + + expressions = d.parse( + f""" + MODEL ( + name memory.sushi.test_allow_partials, + kind INCREMENTAL_UNMANAGED, + allow_partials true, + start '2023-01-01', + ); + + SELECT @end_ts AS end_ts + """ + ) + model = load_sql_based_model(expressions) + + context.upsert_model(model) + context.plan("prod", skip_tests=True, auto_apply=True, no_prompts=True) + + assert ( + context.engine_adapter.fetchone("SELECT MAX(end_ts) FROM memory.sushi.test_allow_partials")[ + 0 + ] + == "2023-01-07 23:59:59.999999" + ) + + plan_no_ignore_cron = context.plan_builder( + "prod", run=True, ignore_cron=False, skip_tests=True + ).build() + assert not plan_no_ignore_cron.missing_intervals + + plan = context.plan_builder("prod", run=True, ignore_cron=True, skip_tests=True).build() + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=context.get_snapshot(model, raise_if_missing=True).snapshot_id, + intervals=[ + 
(to_timestamp("2023-01-08"), to_timestamp("2023-01-08 15:00:00")), + ], + ) + ] + context.apply(plan) + + assert ( + context.engine_adapter.fetchone("SELECT MAX(end_ts) FROM memory.sushi.test_allow_partials")[ + 0 + ] + == "2023-01-08 14:59:59.999999" + ) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_plan_with_run( + init_and_plan_context: t.Callable, +): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + model = context.get_model("sushi.waiter_revenue_by_day") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + + with time_machine.travel("2023-01-09 00:00:00 UTC"): + plan = context.plan(run=True) + assert plan.has_changes + assert plan.missing_intervals + + context.apply(plan) + + snapshots = context.state_sync.state_sync.get_snapshots(context.snapshots.values()) + assert {s.name: s.intervals[0][1] for s in snapshots.values() if s.intervals} == { + '"memory"."sushi"."waiter_revenue_by_day"': to_timestamp("2023-01-09"), + '"memory"."sushi"."order_items"': to_timestamp("2023-01-09"), + '"memory"."sushi"."orders"': to_timestamp("2023-01-09"), + '"memory"."sushi"."items"': to_timestamp("2023-01-09"), + '"memory"."sushi"."customer_revenue_lifetime"': to_timestamp("2023-01-09"), + '"memory"."sushi"."customer_revenue_by_day"': to_timestamp("2023-01-09"), + '"memory"."sushi"."latest_order"': to_timestamp("2023-01-09"), + '"memory"."sushi"."waiter_names"': to_timestamp("2023-01-08"), + '"memory"."sushi"."raw_marketing"': to_timestamp("2023-01-09"), + '"memory"."sushi"."marketing"': to_timestamp("2023-01-09"), + '"memory"."sushi"."waiter_as_customer_by_day"': to_timestamp("2023-01-09"), + '"memory"."sushi"."top_waiters"': to_timestamp("2023-01-09"), + '"memory"."raw"."demographics"': to_timestamp("2023-01-09"), + "assert_item_price_above_zero": to_timestamp("2023-01-09"), + '"memory"."sushi"."active_customers"': to_timestamp("2023-01-09"), + '"memory"."sushi"."customers"': to_timestamp("2023-01-09"), + 
'"memory"."sushi"."count_customers_active"': to_timestamp("2023-01-09"), + '"memory"."sushi"."count_customers_inactive"': to_timestamp("2023-01-09"), + } + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_select_models(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + # Modify 2 models. + model = context.get_model("sushi.waiter_revenue_by_day") + kwargs = { + **model.dict(), + # Make a breaking change. + "query": model.query.order_by("waiter_id"), # type: ignore + } + context.upsert_model(SqlModel.parse_obj(kwargs)) + + model = context.get_model("sushi.customer_revenue_by_day") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + + expected_intervals = [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ] + + waiter_revenue_by_day_snapshot_id = context.get_snapshot( + "sushi.waiter_revenue_by_day", raise_if_missing=True + ).snapshot_id + + # Select one of the modified models. 
+ plan_builder = context.plan_builder( + "dev", select_models=["*waiter_revenue_by_day"], skip_tests=True + ) + snapshot = plan_builder._context_diff.snapshots[waiter_revenue_by_day_snapshot_id] + plan_builder.set_choice(snapshot, SnapshotChangeCategory.BREAKING) + plan = plan_builder.build() + + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=waiter_revenue_by_day_snapshot_id, + intervals=expected_intervals, + ), + ] + + context.apply(plan) + + dev_df = context.engine_adapter.fetchdf( + "SELECT DISTINCT event_date FROM sushi__dev.waiter_revenue_by_day ORDER BY event_date" + ) + assert len(dev_df) == 7 + + # Make sure that we only create a view for the selected model. + schema_objects = context.engine_adapter.get_data_objects("sushi__dev") + assert len(schema_objects) == 1 + assert schema_objects[0].name == "waiter_revenue_by_day" + + # Validate the other modified model. + assert not context.get_snapshot("sushi.customer_revenue_by_day").change_category + assert not context.get_snapshot("sushi.customer_revenue_by_day").version + + # Validate the downstream model. + assert not context.engine_adapter.table_exists( + context.get_snapshot("sushi.top_waiters").table_name() + ) + assert not context.engine_adapter.table_exists( + context.get_snapshot("sushi.top_waiters").table_name(False) + ) + + # Make sure that tables are created when deploying to prod. 
+ plan = context.plan("prod", skip_tests=True) + context.apply(plan) + assert context.engine_adapter.table_exists( + context.get_snapshot("sushi.top_waiters").table_name() + ) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_select_models_for_backfill(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + expected_intervals = [ + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ] + + plan = context.plan_builder( + "dev", backfill_models=["+*waiter_revenue_by_day"], skip_tests=True + ).build() + + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=context.get_snapshot("sushi.items", raise_if_missing=True).snapshot_id, + intervals=expected_intervals, + ), + SnapshotIntervals( + snapshot_id=context.get_snapshot( + "sushi.order_items", raise_if_missing=True + ).snapshot_id, + intervals=expected_intervals, + ), + SnapshotIntervals( + snapshot_id=context.get_snapshot("sushi.orders", raise_if_missing=True).snapshot_id, + intervals=expected_intervals, + ), + SnapshotIntervals( + snapshot_id=context.get_snapshot( + "sushi.waiter_revenue_by_day", raise_if_missing=True + ).snapshot_id, + intervals=expected_intervals, + ), + ] + + context.apply(plan) + + dev_df = context.engine_adapter.fetchdf( + "SELECT DISTINCT event_date FROM sushi__dev.waiter_revenue_by_day ORDER BY event_date" + ) + assert len(dev_df) == 1 + + schema_objects = context.engine_adapter.get_data_objects("sushi__dev") + assert {o.name for o in schema_objects} == { + "items", + "order_items", + "orders", + "waiter_revenue_by_day", + } + + assert not context.engine_adapter.table_exists( + context.get_snapshot("sushi.customer_revenue_by_day").table_name() + ) + + # Make sure that tables are created when deploying to prod. 
+ plan = context.plan("prod") + context.apply(plan) + assert context.engine_adapter.table_exists( + context.get_snapshot("sushi.customer_revenue_by_day").table_name() + ) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_select_unchanged_model_for_backfill(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + # Modify 2 models. + model = context.get_model("sushi.waiter_revenue_by_day") + kwargs = { + **model.dict(), + # Make a breaking change. + "query": d.parse_one( + f"{model.query.sql(dialect='duckdb')} ORDER BY waiter_id", dialect="duckdb" + ), + } + context.upsert_model(SqlModel.parse_obj(kwargs)) + + model = context.get_model("sushi.customer_revenue_by_day") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + + expected_intervals = [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ] + + waiter_revenue_by_day_snapshot_id = context.get_snapshot( + "sushi.waiter_revenue_by_day", raise_if_missing=True + ).snapshot_id + + # Select one of the modified models. + plan_builder = context.plan_builder( + "dev", select_models=["*waiter_revenue_by_day"], skip_tests=True + ) + snapshot = plan_builder._context_diff.snapshots[waiter_revenue_by_day_snapshot_id] + plan_builder.set_choice(snapshot, SnapshotChangeCategory.BREAKING) + plan = plan_builder.build() + + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=waiter_revenue_by_day_snapshot_id, + intervals=expected_intervals, + ), + ] + + context.apply(plan) + + # Make sure that we only create a view for the selected model. 
+ schema_objects = context.engine_adapter.get_data_objects("sushi__dev") + assert {o.name for o in schema_objects} == {"waiter_revenue_by_day"} + + # Now select a model downstream from the previously modified one in order to backfill it. + plan = context.plan_builder("dev", select_models=["*top_waiters"], skip_tests=True).build() + + assert not plan.has_changes + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=context.get_snapshot( + "sushi.top_waiters", raise_if_missing=True + ).snapshot_id, + intervals=expected_intervals, + ), + ] + + context.apply(plan) + + # Make sure that a view has been created for the downstream selected model. + schema_objects = context.engine_adapter.get_data_objects("sushi__dev") + assert {o.name for o in schema_objects} == {"waiter_revenue_by_day", "top_waiters"} + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_create_environment_no_changes_with_selector(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + with pytest.raises(NoChangesPlanError): + context.plan_builder("dev").build() + + plan = context.plan_builder("dev", select_models=["*top_waiters"]).build() + assert not plan.missing_intervals + context.apply(plan) + + schema_objects = context.engine_adapter.get_data_objects("sushi__dev") + assert {o.name for o in schema_objects} == {"top_waiters"} + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_include_unmodified(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + plan = context.plan_builder( + "dev", + include_unmodified=True, + skip_tests=True, + ).build() + + all_snapshots = context.snapshots + + assert len(plan.environment.snapshots) == len(all_snapshots) + assert plan.environment.promoted_snapshot_ids is None + + context.apply(plan) + + data_objs = context.engine_adapter.get_data_objects("sushi__dev") + assert len(data_objs) == len( + [s for s in 
all_snapshots.values() if s.is_model and not s.is_symbolic] + ) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_select_models_with_include_unmodified(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + plan = context.plan_builder( + "dev", + select_models=["*top_waiters", "*customer_revenue_by_day"], + include_unmodified=True, + skip_tests=True, + ).build() + + assert len(plan.environment.snapshots) == len(context.snapshots) + + promoted_set = {s_id.name for s_id in plan.environment.promoted_snapshot_ids} + assert promoted_set == { + '"memory"."sushi"."customer_revenue_by_day"', + '"memory"."sushi"."top_waiters"', + } + + context.apply(plan) + + data_objs = context.engine_adapter.get_data_objects("sushi__dev") + assert len(data_objs) == 2 + assert {o.name for o in data_objs} == {"customer_revenue_by_day", "top_waiters"} diff --git a/tests/core/integration/test_restatement.py b/tests/core/integration/test_restatement.py new file mode 100644 index 0000000000..3694efce31 --- /dev/null +++ b/tests/core/integration/test_restatement.py @@ -0,0 +1,1935 @@ +from __future__ import annotations + +import typing as t +import pandas as pd # noqa: TID253 +import pytest +from pathlib import Path +from sqlmesh.core.console import ( + MarkdownConsole, + set_console, + get_console, + CaptureTerminalConsole, +) +import time_machine +from sqlglot import exp +import re +from concurrent.futures import ThreadPoolExecutor, TimeoutError +import time +import queue + +from sqlmesh.core import constants as c +from sqlmesh.core.config import ( + Config, + GatewayConfig, + ModelDefaultsConfig, + DuckDBConnectionConfig, +) +from sqlmesh.core.context import Context +from sqlmesh.core.model import ( + IncrementalByTimeRangeKind, + IncrementalUnmanagedKind, + SqlModel, +) +from sqlmesh.core.plan import SnapshotIntervals +from sqlmesh.core.snapshot import ( + Snapshot, + SnapshotId, +) +from sqlmesh.utils.date import 
to_timestamp +from sqlmesh.utils.errors import ( + ConflictingPlanError, +) +from tests.core.integration.utils import add_projection_to_model + +pytestmark = pytest.mark.slow + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_restatement_plan_ignores_changes(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + restated_snapshot = context.get_snapshot("sushi.top_waiters") + + # Simulate a change. + model = context.get_model("sushi.waiter_revenue_by_day") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + + plan = context.plan_builder(restate_models=["sushi.top_waiters"]).build() + assert plan.snapshots != context.snapshots + + assert not plan.directly_modified + assert not plan.has_changes + assert not plan.new_snapshots + assert plan.requires_backfill + assert plan.restatements == { + restated_snapshot.snapshot_id: (to_timestamp("2023-01-01"), to_timestamp("2023-01-09")) + } + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=restated_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ) + ] + + context.apply(plan) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_restatement_plan_across_environments_snapshot_with_shared_version( + init_and_plan_context: t.Callable, +): + context, _ = init_and_plan_context("examples/sushi") + + # Change kind to incremental unmanaged + model = context.get_model("sushi.waiter_revenue_by_day") + previous_kind = model.kind.copy(update={"forward_only": True}) + assert isinstance(previous_kind, 
IncrementalByTimeRangeKind) + + model = model.copy( + update={ + "kind": IncrementalUnmanagedKind(), + "physical_version": "pinned_version_12345", + "partitioned_by_": [exp.column("event_date")], + } + ) + context.upsert_model(model) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Make some change and deploy it to both dev and prod environments + model = add_projection_to_model(t.cast(SqlModel, model)) + context.upsert_model(model) + context.plan("dev_a", auto_apply=True, no_prompts=True) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Change the kind back to incremental by time range and deploy to prod + model = model.copy(update={"kind": previous_kind}) + context.upsert_model(model) + context.plan("prod", auto_apply=True, no_prompts=True) + + # Restate the model and verify that the interval hasn't been expanded because of the old snapshot + # with the same version + context.plan( + restate_models=["sushi.waiter_revenue_by_day"], + start="2023-01-06", + end="2023-01-08", + auto_apply=True, + no_prompts=True, + ) + + assert ( + context.fetchdf( + "SELECT COUNT(*) AS cnt FROM sushi.waiter_revenue_by_day WHERE one IS NOT NULL AND event_date < '2023-01-06'" + )["cnt"][0] + == 0 + ) + plan = context.plan_builder("prod").build() + assert not plan.missing_intervals + + +def test_restatement_plan_hourly_with_downstream_daily_restates_correct_intervals(tmp_path: Path): + model_a = """ + MODEL ( + name test.a, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts" + ), + start '2024-01-01 00:00:00', + cron '@hourly' + ); + + select account_id, ts from test.external_table; + """ + + model_b = """ + MODEL ( + name test.b, + kind FULL, + cron '@daily' + ); + + select account_id, ts from test.a; + """ + + models_dir = tmp_path / "models" + models_dir.mkdir() + + for path, defn in {"a.sql": model_a, "b.sql": model_b}.items(): + with open(models_dir / path, "w") as f: + f.write(defn) + + config = 
Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + ctx = Context(paths=[tmp_path], config=config) + + engine_adapter = ctx.engine_adapter + engine_adapter.create_schema("test") + + # source data + df = pd.DataFrame( + { + "account_id": [1001, 1002, 1003, 1004], + "ts": [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ], + } + ) + columns_to_types = { + "account_id": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + } + external_table = exp.table_(table="external_table", db="test", quoted=True) + engine_adapter.create_table(table_name=external_table, target_columns_to_types=columns_to_types) + engine_adapter.insert_append( + table_name=external_table, query_or_df=df, target_columns_to_types=columns_to_types + ) + + # plan + apply + ctx.plan(auto_apply=True, no_prompts=True) + + def _dates_in_table(table_name: str) -> t.List[str]: + return [ + str(r[0]) for r in engine_adapter.fetchall(f"select ts from {table_name} order by ts") + ] + + # verify initial state + for tbl in ["test.a", "test.b"]: + assert _dates_in_table(tbl) == [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ] + + # restate A + engine_adapter.execute("delete from test.external_table where ts = '2024-01-01 01:30:00'") + ctx.plan( + restate_models=["test.a"], + start="2024-01-01 01:00:00", + end="2024-01-01 02:00:00", + auto_apply=True, + no_prompts=True, + ) + + # verify result + for tbl in ["test.a", "test.b"]: + assert _dates_in_table(tbl) == [ + "2024-01-01 00:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ], f"Table {tbl} wasnt cleared" + + # Put some data + df = pd.DataFrame( + { + "account_id": [1001, 1002, 1003, 1004], + "ts": [ + "2024-01-01 01:30:00", + "2024-01-01 23:30:00", + "2024-01-02 03:30:00", + "2024-01-03 12:30:00", + ], + } + ) + engine_adapter.replace_query( + table_name=external_table, query_or_df=df, 
target_columns_to_types=columns_to_types + ) + + # Restate A across a day boundary with the expectation that two day intervals in B are affected + ctx.plan( + restate_models=["test.a"], + start="2024-01-01 02:00:00", + end="2024-01-02 04:00:00", + auto_apply=True, + no_prompts=True, + ) + + for tbl in ["test.a", "test.b"]: + assert _dates_in_table(tbl) == [ + "2024-01-01 00:30:00", # present already + # "2024-01-01 02:30:00", #removed in last restatement + "2024-01-01 23:30:00", # added in last restatement + "2024-01-02 03:30:00", # added in last restatement + ], f"Table {tbl} wasnt cleared" + + +def test_restatement_plan_respects_disable_restatements(tmp_path: Path): + model_a = """ + MODEL ( + name test.a, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts" + ), + start '2024-01-01', + cron '@daily' + ); + + select account_id, ts from test.external_table; + """ + + model_b = """ + MODEL ( + name test.b, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts", + disable_restatement true, + ), + start '2024-01-01', + cron '@daily' + ); + + select account_id, ts from test.a; + """ + + models_dir = tmp_path / "models" + models_dir.mkdir() + + for path, defn in {"a.sql": model_a, "b.sql": model_b}.items(): + with open(models_dir / path, "w") as f: + f.write(defn) + + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + ctx = Context(paths=[tmp_path], config=config) + + engine_adapter = ctx.engine_adapter + engine_adapter.create_schema("test") + + # source data + df = pd.DataFrame( + { + "account_id": [1001, 1002, 1003, 1004], + "ts": [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ], + } + ) + columns_to_types = { + "account_id": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + } + external_table = exp.table_(table="external_table", db="test", quoted=True) + engine_adapter.create_table(table_name=external_table, target_columns_to_types=columns_to_types) + 
engine_adapter.insert_append( + table_name=external_table, query_or_df=df, target_columns_to_types=columns_to_types + ) + + # plan + apply + ctx.plan(auto_apply=True, no_prompts=True) + + def _dates_in_table(table_name: str) -> t.List[str]: + return [ + str(r[0]) for r in engine_adapter.fetchall(f"select ts from {table_name} order by ts") + ] + + def get_snapshot_intervals(snapshot_id): + return list(ctx.state_sync.get_snapshots([snapshot_id]).values())[0].intervals + + # verify initial state + for tbl in ["test.a", "test.b"]: + assert _dates_in_table(tbl) == [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ] + + # restate A and expect b to be ignored + starting_b_intervals = get_snapshot_intervals(ctx.snapshots['"memory"."test"."b"'].snapshot_id) + engine_adapter.execute("delete from test.external_table where ts = '2024-01-01 01:30:00'") + ctx.plan( + restate_models=["test.a"], + start="2024-01-01", + end="2024-01-02", + auto_apply=True, + no_prompts=True, + ) + + # verify A was changed and not b + assert _dates_in_table("test.a") == [ + "2024-01-01 00:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ] + assert _dates_in_table("test.b") == [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ] + + # Verify B intervals were not touched + b_intervals = get_snapshot_intervals(ctx.snapshots['"memory"."test"."b"'].snapshot_id) + assert starting_b_intervals == b_intervals + + +def test_restatement_plan_clears_correct_intervals_across_environments(tmp_path: Path): + model1 = """ + MODEL ( + name test.incremental_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "date" + ), + start '2024-01-01', + cron '@daily' + ); + + select account_id, date from test.external_table; + """ + + model2 = """ + MODEL ( + name test.downstream_of_incremental, + kind FULL + ); + + select account_id, date from test.incremental_model; + """ + + models_dir = tmp_path / 
"models" + models_dir.mkdir() + + with open(models_dir / "model1.sql", "w") as f: + f.write(model1) + + with open(models_dir / "model2.sql", "w") as f: + f.write(model2) + + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + ctx = Context(paths=[tmp_path], config=config) + + engine_adapter = ctx.engine_adapter + engine_adapter.create_schema("test") + + # source data + df = pd.DataFrame( + { + "account_id": [1001, 1002, 1003, 1004, 1005], + "name": ["foo", "bar", "baz", "bing", "bong"], + "date": ["2024-01-01", "2024-01-02", "2024-01-03", "2024-01-04", "2024-01-05"], + } + ) + columns_to_types = { + "account_id": exp.DataType.build("int"), + "name": exp.DataType.build("varchar"), + "date": exp.DataType.build("date"), + } + external_table = exp.table_(table="external_table", db="test", quoted=True) + engine_adapter.create_table(table_name=external_table, target_columns_to_types=columns_to_types) + engine_adapter.insert_append( + table_name=external_table, query_or_df=df, target_columns_to_types=columns_to_types + ) + + # first, create the prod models + ctx.plan(auto_apply=True, no_prompts=True) + assert engine_adapter.fetchone("select count(*) from test.incremental_model") == (5,) + assert engine_adapter.fetchone("select count(*) from test.downstream_of_incremental") == (5,) + assert not engine_adapter.table_exists("test__dev.incremental_model") + + # then, make a dev version + model1 = """ + MODEL ( + name test.incremental_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "date" + ), + start '2024-01-01', + cron '@daily' + ); + + select 1 as account_id, date from test.external_table; + """ + with open(models_dir / "model1.sql", "w") as f: + f.write(model1) + ctx.load() + + ctx.plan(environment="dev", auto_apply=True, no_prompts=True) + assert engine_adapter.table_exists("test__dev.incremental_model") + assert engine_adapter.fetchone("select count(*) from test__dev.incremental_model") == (5,) + + # drop some source data so when we restate 
the interval it essentially clears it which is easy to verify + engine_adapter.execute("delete from test.external_table where date = '2024-01-01'") + assert engine_adapter.fetchone("select count(*) from test.external_table") == (4,) + + # now, restate intervals in dev and verify prod is NOT affected + ctx.plan( + environment="dev", + start="2024-01-01", + end="2024-01-02", + restate_models=["test.incremental_model"], + auto_apply=True, + no_prompts=True, + ) + assert engine_adapter.fetchone("select count(*) from test.incremental_model") == (5,) + assert engine_adapter.fetchone( + "select count(*) from test.incremental_model where date = '2024-01-01'" + ) == (1,) + assert engine_adapter.fetchone("select count(*) from test__dev.incremental_model") == (4,) + assert engine_adapter.fetchone( + "select count(*) from test__dev.incremental_model where date = '2024-01-01'" + ) == (0,) + + # prod still should not be affected by a run because the restatement only happened in dev + ctx.run() + assert engine_adapter.fetchone("select count(*) from test.incremental_model") == (5,) + assert engine_adapter.fetchone( + "select count(*) from test.incremental_model where date = '2024-01-01'" + ) == (1,) + + # drop another interval from the source data + engine_adapter.execute("delete from test.external_table where date = '2024-01-02'") + + # now, restate intervals in prod and verify that dev IS affected + ctx.plan( + start="2024-01-01", + end="2024-01-03", + restate_models=["test.incremental_model"], + auto_apply=True, + no_prompts=True, + ) + assert engine_adapter.fetchone("select count(*) from test.incremental_model") == (3,) + assert engine_adapter.fetchone( + "select count(*) from test.incremental_model where date = '2024-01-01'" + ) == (0,) + assert engine_adapter.fetchone( + "select count(*) from test.incremental_model where date = '2024-01-02'" + ) == (0,) + assert engine_adapter.fetchone( + "select count(*) from test.incremental_model where date = '2024-01-03'" + ) == (1,) + + 
# dev not affected yet until `sqlmesh run` is run + assert engine_adapter.fetchone("select count(*) from test__dev.incremental_model") == (4,) + assert engine_adapter.fetchone( + "select count(*) from test__dev.incremental_model where date = '2024-01-01'" + ) == (0,) + assert engine_adapter.fetchone( + "select count(*) from test__dev.incremental_model where date = '2024-01-02'" + ) == (1,) + assert engine_adapter.fetchone( + "select count(*) from test__dev.incremental_model where date = '2024-01-03'" + ) == (1,) + + # the restatement plan for prod should have cleared dev intervals too, which means this `sqlmesh run` re-runs 2024-01-01 and 2024-01-02 + ctx.run(environment="dev") + assert engine_adapter.fetchone("select count(*) from test__dev.incremental_model") == (3,) + assert engine_adapter.fetchone( + "select count(*) from test__dev.incremental_model where date = '2024-01-01'" + ) == (0,) + assert engine_adapter.fetchone( + "select count(*) from test__dev.incremental_model where date = '2024-01-02'" + ) == (0,) + assert engine_adapter.fetchone( + "select count(*) from test__dev.incremental_model where date = '2024-01-03'" + ) == (1,) + + # the downstream full model should always reflect whatever the incremental model is showing + assert engine_adapter.fetchone("select count(*) from test.downstream_of_incremental") == (3,) + assert engine_adapter.fetchone("select count(*) from test__dev.downstream_of_incremental") == ( + 3, + ) + + +def test_prod_restatement_plan_clears_correct_intervals_in_derived_dev_tables(tmp_path: Path): + """ + Scenario: + I have models A[hourly] <- B[daily] <- C in prod + I create dev and add 2 new models D and E so that my dev DAG looks like A <- B <- C <- D[daily] <- E + I prod, I restate *one hour* of A + Outcome: + D and E should be restated in dev despite not being a part of prod + since B and D are daily, the whole day should be restated even though only 1hr of the upstream model was restated + """ + + model_a = """ + MODEL ( + name 
test.a, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts" + ), + start '2024-01-01 00:00:00', + cron '@hourly' + ); + + select account_id, ts from test.external_table; + """ + + def _derived_full_model_def(name: str, upstream: str) -> str: + return f""" + MODEL ( + name test.{name}, + kind FULL + ); + + select account_id, ts from test.{upstream}; + """ + + def _derived_incremental_model_def(name: str, upstream: str) -> str: + return f""" + MODEL ( + name test.{name}, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ts + ), + cron '@daily' + ); + + select account_id, ts from test.{upstream} where ts between @start_ts and @end_ts; + """ + + model_b = _derived_incremental_model_def("b", upstream="a") + model_c = _derived_full_model_def("c", upstream="b") + + models_dir = tmp_path / "models" + models_dir.mkdir() + + for path, defn in {"a.sql": model_a, "b.sql": model_b, "c.sql": model_c}.items(): + with open(models_dir / path, "w") as f: + f.write(defn) + + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + ctx = Context(paths=[tmp_path], config=config) + + engine_adapter = ctx.engine_adapter + engine_adapter.create_schema("test") + + # source data + df = pd.DataFrame( + { + "account_id": [1001, 1002, 1003, 1004], + "ts": [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ], + } + ) + columns_to_types = { + "account_id": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + } + external_table = exp.table_(table="external_table", db="test", quoted=True) + engine_adapter.create_table(table_name=external_table, target_columns_to_types=columns_to_types) + engine_adapter.insert_append( + table_name=external_table, query_or_df=df, target_columns_to_types=columns_to_types + ) + + # plan + apply A, B, C in prod + ctx.plan(auto_apply=True, no_prompts=True) + + # add D[daily], E in dev + model_d = _derived_incremental_model_def("d", upstream="c") + model_e = _derived_full_model_def("e", 
upstream="d") + + for path, defn in { + "d.sql": model_d, + "e.sql": model_e, + }.items(): + with open(models_dir / path, "w") as f: + f.write(defn) + + # plan + apply dev + ctx.load() + ctx.plan(environment="dev", auto_apply=True, no_prompts=True) + + def _dates_in_table(table_name: str) -> t.List[str]: + return [ + str(r[0]) for r in engine_adapter.fetchall(f"select ts from {table_name} order by ts") + ] + + # verify initial state + for tbl in ["test.a", "test.b", "test.c", "test__dev.d", "test__dev.e"]: + assert engine_adapter.table_exists(tbl) + assert _dates_in_table(tbl) == [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ] + + for tbl in ["test.d", "test.e"]: + assert not engine_adapter.table_exists(tbl) + + # restate A in prod + engine_adapter.execute("delete from test.external_table where ts = '2024-01-01 01:30:00'") + ctx.plan( + restate_models=["test.a"], + start="2024-01-01 01:00:00", + end="2024-01-01 02:00:00", + auto_apply=True, + no_prompts=True, + ) + + # verify result + for tbl in ["test.a", "test.b", "test.c"]: + assert _dates_in_table(tbl) == [ + "2024-01-01 00:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ], f"Table {tbl} wasnt cleared" + + # dev shouldnt have been affected yet + for tbl in ["test__dev.d", "test__dev.e"]: + assert _dates_in_table(tbl) == [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ], f"Table {tbl} was prematurely cleared" + + # run dev to trigger the processing of the prod restatement + ctx.run(environment="dev") + + # data should now be cleared from dev + # note that D is a daily model, so clearing an hour interval from A should have triggered the full day in D + for tbl in ["test__dev.d", "test__dev.e"]: + assert _dates_in_table(tbl) == [ + "2024-01-01 00:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ], f"Table {tbl} wasnt cleared" + + +def 
test_prod_restatement_plan_clears_unaligned_intervals_in_derived_dev_tables(tmp_path: Path): + """ + Scenario: + I have a model A[hourly] in prod + I create dev and add a model B[daily] + I prod, I restate *one hour* of A + + Outcome: + The whole day for B should be restated. The restatement plan for prod has no hints about B's cadence because + B only exists in dev and there are no other downstream models in prod that would cause the restatement intervals + to be widened. + + Therefore, this test checks that SQLMesh does the right thing when an interval is partially cleared + """ + + model_a = """ + MODEL ( + name test.a, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts" + ), + start '2024-01-01 00:00:00', + cron '@hourly' + ); + + select account_id, ts from test.external_table; + """ + + model_b = """ + MODEL ( + name test.b, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ts + ), + cron '@daily' + ); + + select account_id, ts from test.a where ts between @start_ts and @end_ts; + """ + + models_dir = tmp_path / "models" + models_dir.mkdir() + + with open(models_dir / "a.sql", "w") as f: + f.write(model_a) + + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + ctx = Context(paths=[tmp_path], config=config) + + engine_adapter = ctx.engine_adapter + engine_adapter.create_schema("test") + + # source data + df = pd.DataFrame( + { + "account_id": [1001, 1002, 1003, 1004], + "ts": [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ], + } + ) + columns_to_types = { + "account_id": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + } + external_table = exp.table_(table="external_table", db="test", quoted=True) + engine_adapter.create_table(table_name=external_table, target_columns_to_types=columns_to_types) + engine_adapter.insert_append( + table_name=external_table, query_or_df=df, target_columns_to_types=columns_to_types + ) + + # plan + apply A[hourly] in prod + 
ctx.plan(auto_apply=True, no_prompts=True) + + # add B[daily] in dev + with open(models_dir / "b.sql", "w") as f: + f.write(model_b) + + # plan + apply dev + ctx.load() + ctx.plan(environment="dev", auto_apply=True, no_prompts=True) + + def _dates_in_table(table_name: str) -> t.List[str]: + return [ + str(r[0]) for r in engine_adapter.fetchall(f"select ts from {table_name} order by ts") + ] + + # verify initial state + for tbl in ["test.a", "test__dev.b"]: + assert _dates_in_table(tbl) == [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ] + + # restate A in prod + engine_adapter.execute("delete from test.external_table where ts = '2024-01-01 01:30:00'") + ctx.plan( + restate_models=["test.a"], + start="2024-01-01 01:00:00", + end="2024-01-01 02:00:00", + auto_apply=True, + no_prompts=True, + ) + + # verify result + assert _dates_in_table("test.a") == [ + "2024-01-01 00:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ] + + # dev shouldnt have been affected yet + assert _dates_in_table("test__dev.b") == [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ] + + # mess with A independently of SQLMesh to prove a whole day gets restated for B instead of just 1hr + snapshot_table_name = ctx.table_name("test.a", "dev") + engine_adapter.execute( + f"delete from {snapshot_table_name} where cast(ts as date) == '2024-01-01'" + ) + engine_adapter.execute( + f"insert into {snapshot_table_name} (account_id, ts) values (1007, '2024-01-02 01:30:00')" + ) + + assert _dates_in_table("test.a") == ["2024-01-02 00:30:00", "2024-01-02 01:30:00"] + + # run dev to trigger the processing of the prod restatement + ctx.run(environment="dev") + + # B should now have no data for 2024-01-01 + # To prove a single day was restated vs the whole model, it also shouldnt have the '2024-01-02 01:30:00' record + assert _dates_in_table("test__dev.b") == ["2024-01-02 00:30:00"] + + +def 
test_prod_restatement_plan_causes_dev_intervals_to_be_processed_in_next_dev_plan( + tmp_path: Path, +): + """ + Scenario: + I have a model A[hourly] in prod + I create dev and add a model B[daily] + I prod, I restate *one hour* of A + In dev, I run a normal plan instead of a cadence run + + Outcome: + The whole day for B should be restated as part of a normal plan + """ + + model_a = """ + MODEL ( + name test.a, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts" + ), + start '2024-01-01 00:00:00', + cron '@hourly' + ); + + select account_id, ts from test.external_table; + """ + + model_b = """ + MODEL ( + name test.b, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ts + ), + cron '@daily' + ); + + select account_id, ts from test.a where ts between @start_ts and @end_ts; + """ + + models_dir = tmp_path / "models" + models_dir.mkdir() + + with open(models_dir / "a.sql", "w") as f: + f.write(model_a) + + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + ctx = Context(paths=[tmp_path], config=config) + + engine_adapter = ctx.engine_adapter + engine_adapter.create_schema("test") + + # source data + df = pd.DataFrame( + { + "account_id": [1001, 1002, 1003, 1004], + "ts": [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ], + } + ) + columns_to_types = { + "account_id": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + } + external_table = exp.table_(table="external_table", db="test", quoted=True) + engine_adapter.create_table(table_name=external_table, target_columns_to_types=columns_to_types) + engine_adapter.insert_append( + table_name=external_table, query_or_df=df, target_columns_to_types=columns_to_types + ) + + # plan + apply A[hourly] in prod + ctx.plan(auto_apply=True, no_prompts=True) + + # add B[daily] in dev + with open(models_dir / "b.sql", "w") as f: + f.write(model_b) + + # plan + apply dev + ctx.load() + ctx.plan(environment="dev", auto_apply=True, 
no_prompts=True) + + def _dates_in_table(table_name: str) -> t.List[str]: + return [ + str(r[0]) for r in engine_adapter.fetchall(f"select ts from {table_name} order by ts") + ] + + # verify initial state + for tbl in ["test.a", "test__dev.b"]: + assert _dates_in_table(tbl) == [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ] + + # restate A in prod + engine_adapter.execute("delete from test.external_table where ts = '2024-01-01 01:30:00'") + ctx.plan( + restate_models=["test.a"], + start="2024-01-01 01:00:00", + end="2024-01-01 02:00:00", + auto_apply=True, + no_prompts=True, + ) + + # verify result + assert _dates_in_table("test.a") == [ + "2024-01-01 00:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ] + + # dev shouldnt have been affected yet + assert _dates_in_table("test__dev.b") == [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ] + + # plan dev which should trigger the missing intervals to get repopulated + ctx.plan(environment="dev", auto_apply=True, no_prompts=True) + + # dev should have the restated data + for tbl in ["test.a", "test__dev.b"]: + assert _dates_in_table(tbl) == [ + "2024-01-01 00:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ] + + +def test_prod_restatement_plan_causes_dev_intervals_to_be_widened_on_full_restatement_only_model( + tmp_path, +): + """ + Scenario: + I have am INCREMENTAL_BY_TIME_RANGE model A[daily] in prod + I create dev and add a INCREMENTAL_BY_UNIQUE_KEY model B (which supports full restatement only) + I prod, I restate one day of A which should cause intervals in dev to be cleared (but not processed) + In dev, I run a plan + + Outcome: + In the dev plan, the entire model for B should be rebuilt because it does not support partial restatement + """ + + model_a = """ + MODEL ( + name test.a, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts" + ), + start '2024-01-01 00:00:00', + cron 
'@daily' + ); + + select account_id, ts from test.external_table where ts between @start_ts and @end_ts; + """ + + model_b = """ + MODEL ( + name test.b, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key (account_id, ts) + ), + cron '@daily' + ); + + select account_id, ts from test.a where ts between @start_ts and @end_ts; + """ + + models_dir = tmp_path / "models" + models_dir.mkdir() + + with open(models_dir / "a.sql", "w") as f: + f.write(model_a) + + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + ctx = Context(paths=[tmp_path], config=config) + + engine_adapter = ctx.engine_adapter + engine_adapter.create_schema("test") + + # source data + df = pd.DataFrame( + { + "account_id": [1001, 1002, 1003, 1004], + "ts": [ + "2024-01-01 00:30:00", + "2024-01-02 01:30:00", + "2024-01-03 02:30:00", + "2024-01-04 00:30:00", + ], + } + ) + columns_to_types = { + "account_id": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + } + external_table = exp.table_(table="external_table", db="test", quoted=True) + engine_adapter.create_table(table_name=external_table, target_columns_to_types=columns_to_types) + engine_adapter.insert_append( + table_name=external_table, query_or_df=df, target_columns_to_types=columns_to_types + ) + + # plan + apply A[daily] in prod + ctx.plan(auto_apply=True) + + # add B[daily] in dev + with open(models_dir / "b.sql", "w") as f: + f.write(model_b) + + # plan + apply dev + ctx.load() + ctx.plan(environment="dev", auto_apply=True) + + def _dates_in_table(table_name: str) -> t.List[str]: + return [ + str(r[0]) for r in engine_adapter.fetchall(f"select ts from {table_name} order by ts") + ] + + # verify initial state + for tbl in ["test.a", "test__dev.b"]: + assert _dates_in_table(tbl) == [ + "2024-01-01 00:30:00", + "2024-01-02 01:30:00", + "2024-01-03 02:30:00", + "2024-01-04 00:30:00", + ] + + # restate A in prod + engine_adapter.execute("delete from test.external_table where ts = '2024-01-02 01:30:00'") + 
ctx.plan( + restate_models=["test.a"], + start="2024-01-02 00:00:00", + end="2024-01-03 00:00:00", + auto_apply=True, + no_prompts=True, + ) + + # verify result + assert _dates_in_table("test.a") == [ + "2024-01-01 00:30:00", + "2024-01-03 02:30:00", + "2024-01-04 00:30:00", + ] + + # dev shouldnt have been affected yet + assert _dates_in_table("test__dev.b") == [ + "2024-01-01 00:30:00", + "2024-01-02 01:30:00", + "2024-01-03 02:30:00", + "2024-01-04 00:30:00", + ] + + # plan dev which should trigger the missing intervals to get repopulated + ctx.plan(environment="dev", auto_apply=True) + + # dev should have fully refreshed + # this is proven by the fact that INCREMENTAL_BY_UNIQUE_KEY cant propagate deletes, so if the + # model was not fully rebuilt, the deleted record would still be present + for tbl in ["test.a", "test__dev.b"]: + assert _dates_in_table(tbl) == [ + "2024-01-01 00:30:00", + "2024-01-03 02:30:00", + "2024-01-04 00:30:00", + ] + + +def test_prod_restatement_plan_missing_model_in_dev( + tmp_path: Path, +): + """ + Scenario: + I have a model B in prod but only model A in dev + I restate B in prod + + Outcome: + The A model should be ignore and the plan shouldn't fail + """ + + model_a = """ + MODEL ( + name test.a, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts" + ), + start '2024-01-01 00:00:00', + cron '@hourly' + ); + + select account_id, ts from test.external_table; + """ + + model_b = """ + MODEL ( + name test.b, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ts + ), + cron '@daily' + ); + + select account_id, ts from test.external_table where ts between @start_ts and @end_ts; + """ + + models_dir = tmp_path / "models" + models_dir.mkdir() + + with open(models_dir / "a.sql", "w") as f: + f.write(model_a) + + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + ctx = Context(paths=[tmp_path], config=config) + + engine_adapter = ctx.engine_adapter + engine_adapter.create_schema("test") + + # source data + df = 
pd.DataFrame( + { + "account_id": [1001, 1002, 1003, 1004], + "ts": [ + "2024-01-01 00:30:00", + "2024-01-01 01:30:00", + "2024-01-01 02:30:00", + "2024-01-02 00:30:00", + ], + } + ) + columns_to_types = { + "account_id": exp.DataType.build("int"), + "ts": exp.DataType.build("timestamp"), + } + external_table = exp.table_(table="external_table", db="test", quoted=True) + engine_adapter.create_table(table_name=external_table, target_columns_to_types=columns_to_types) + engine_adapter.insert_append( + table_name=external_table, query_or_df=df, target_columns_to_types=columns_to_types + ) + + # plan + apply A[hourly] in dev + ctx.plan("dev", auto_apply=True, no_prompts=True) + + # add B[daily] in prod and remove A + with open(models_dir / "b.sql", "w") as f: + f.write(model_b) + Path(models_dir / "a.sql").unlink() + + # plan + apply dev + ctx.load() + ctx.plan(auto_apply=True, no_prompts=True) + + # restate B in prod + ctx.plan( + restate_models=["test.b"], + start="2024-01-01", + end="2024-01-02", + auto_apply=True, + no_prompts=True, + ) + + +def test_prod_restatement_plan_includes_related_unpromoted_snapshots(tmp_path: Path): + """ + Scenario: + - I have models A <- B in prod + - I have models A <- B <- C in dev + - Both B and C have gone through a few iterations in dev so multiple snapshot versions exist + for them but not all of them are promoted / active + - I restate A in prod + + Outcome: + - Intervals should be cleared for all of the versions of B and C, regardless + of if they are active in any particular environment, in case they ever get made + active + """ + + models_dir = tmp_path / "models" + models_dir.mkdir() + + (models_dir / "a.sql").write_text(""" + MODEL ( + name test.a, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts" + ), + start '2024-01-01 00:00:00', + cron '@daily' + ); + + select 1 as a, now() as ts; + """) + + (models_dir / "b.sql").write_text(""" + MODEL ( + name test.b, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts" + ), + 
start '2024-01-01 00:00:00', + cron '@daily' + ); + + select a, ts from test.a + """) + + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb", start="2024-01-01")) + ctx = Context(paths=[tmp_path], config=config) + + def _all_snapshots() -> t.Dict[SnapshotId, Snapshot]: + all_snapshot_ids = [ + SnapshotId(name=name, identifier=identifier) + for (name, identifier) in ctx.state_sync.state_sync.engine_adapter.fetchall( # type: ignore + "select name, identifier from sqlmesh._snapshots" + ) + ] + return ctx.state_sync.get_snapshots(all_snapshot_ids) + + # plan + apply prod + ctx.plan(environment="prod", auto_apply=True) + assert len(_all_snapshots()) == 2 + + # create dev with new version of B + (models_dir / "b.sql").write_text(""" + MODEL ( + name test.b, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts" + ), + start '2024-01-01 00:00:00', + cron '@daily' + ); + + select a, ts, 'b dev 1' as change from test.a + """) + + ctx.load() + ctx.plan(environment="dev", auto_apply=True) + assert len(_all_snapshots()) == 3 + + # update B (new version) and create C + (models_dir / "b.sql").write_text(""" + MODEL ( + name test.b, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column "ts" + ), + start '2024-01-01 00:00:00', + cron '@daily' + ); + + select a, ts, 'b dev 2' as change from test.a + """) + + (models_dir / "c.sql").write_text(""" + MODEL ( + name test.c, + kind FULL, + cron '@daily' + ); + + select *, 'c initial' as val from test.b + """) + + ctx.load() + ctx.plan(environment="dev", auto_apply=True) + assert len(_all_snapshots()) == 5 + + # update C (new version), create D (unrelated) + (models_dir / "c.sql").write_text(""" + MODEL ( + name test.c, + kind FULL, + cron '@daily' + ); + + select *, 'c updated' as val from test.b + """) + + (models_dir / "d.sql").write_text(""" + MODEL ( + name test.d, + cron '@daily' + ); + + select 1 as unrelated + """) + + ctx.load() + ctx.plan(environment="dev", auto_apply=True) + all_snapshots_prior_to_restatement = 
_all_snapshots() + assert len(all_snapshots_prior_to_restatement) == 7 + + def _snapshot_instances(lst: t.Dict[SnapshotId, Snapshot], name_match: str) -> t.List[Snapshot]: + return [s for s_id, s in lst.items() if name_match in s_id.name] + + # verify initial state + + # 1 instance of A (prod) + assert len(_snapshot_instances(all_snapshots_prior_to_restatement, '"a"')) == 1 + + # 3 instances of B (original in prod + 2 updates in dev) + assert len(_snapshot_instances(all_snapshots_prior_to_restatement, '"b"')) == 3 + + # 2 instances of C (initial + update in dev) + assert len(_snapshot_instances(all_snapshots_prior_to_restatement, '"c"')) == 2 + + # 1 instance of D (initial - dev) + assert len(_snapshot_instances(all_snapshots_prior_to_restatement, '"d"')) == 1 + + # restate A in prod + ctx.plan(environment="prod", restate_models=['"memory"."test"."a"'], auto_apply=True) + + all_snapshots_after_restatement = _all_snapshots() + + # All versions of B and C in dev should have had intervals cleared + # D in dev should not be touched and A + B in prod shoud also not be touched + a = _snapshot_instances(all_snapshots_after_restatement, '"a"') + assert len(a) == 1 + + b = _snapshot_instances(all_snapshots_after_restatement, '"b"') + # the 1 B instance in prod should be populated and 2 in dev (1 active) should be cleared + assert len(b) == 3 + assert len([s for s in b if not s.intervals]) == 2 + + c = _snapshot_instances(all_snapshots_after_restatement, '"c"') + # the 2 instances of C in dev (1 active) should be cleared + assert len(c) == 2 + assert len([s for s in c if not s.intervals]) == 2 + + d = _snapshot_instances(all_snapshots_after_restatement, '"d"') + # D should not be touched since it's in no way downstream of A in prod + assert len(d) == 1 + assert d[0].intervals + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_dev_restatement_of_prod_model(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + 
context.apply(plan) + + model = context.get_model("sushi.waiter_revenue_by_day") + context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) + + context.plan("dev", auto_apply=True, no_prompts=True, skip_tests=True) + + restatement_plan = context.plan_builder("dev", restate_models=["*"]).build() + assert set(restatement_plan.restatements) == { + context.get_snapshot("sushi.waiter_revenue_by_day").snapshot_id, + context.get_snapshot("sushi.top_waiters").snapshot_id, + } + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_restatement_of_full_model_with_start(init_and_plan_context: t.Callable): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + restatement_plan = context.plan( + restate_models=["sushi.customers"], + start="2023-01-07", + auto_apply=True, + no_prompts=True, + ) + + sushi_customer_interval = restatement_plan.restatements[ + context.get_snapshot("sushi.customers").snapshot_id + ] + assert sushi_customer_interval == (to_timestamp("2023-01-01"), to_timestamp("2023-01-09")) + waiter_by_day_interval = restatement_plan.restatements[ + context.get_snapshot("sushi.waiter_as_customer_by_day").snapshot_id + ] + assert waiter_by_day_interval == (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_restatement_should_not_override_environment_statements(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + context.config.before_all = ["SELECT 'test_before_all';", *context.config.before_all] + context.load() + + context.plan("prod", auto_apply=True, no_prompts=True, skip_tests=True) + + prod_env_statements = context.state_reader.get_environment_statements(c.PROD) + assert prod_env_statements[0].before_all[0] == "SELECT 'test_before_all';" + + context.plan( + restate_models=["sushi.waiter_revenue_by_day"], + start="2023-01-07", + auto_apply=True, + no_prompts=True, + ) + + prod_env_statements = 
context.state_reader.get_environment_statements(c.PROD) + assert prod_env_statements[0].before_all[0] == "SELECT 'test_before_all';" + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_restatement_shouldnt_backfill_beyond_prod_intervals(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + model = context.get_model("sushi.top_waiters") + context.upsert_model(SqlModel.parse_obj({**model.dict(), "cron": "@hourly"})) + + context.plan("prod", auto_apply=True, no_prompts=True, skip_tests=True) + context.run() + + with time_machine.travel("2023-01-09 02:00:00 UTC"): + # It's time to backfill the waiter_revenue_by_day model but it hasn't run yet + restatement_plan = context.plan( + restate_models=["sushi.waiter_revenue_by_day"], + no_prompts=True, + skip_tests=True, + ) + intervals_by_id = {i.snapshot_id: i for i in restatement_plan.missing_intervals} + # Make sure the intervals don't go beyond the prod intervals + assert intervals_by_id[context.get_snapshot("sushi.top_waiters").snapshot_id].intervals[-1][ + 1 + ] == to_timestamp("2023-01-08 15:00:00 UTC") + assert intervals_by_id[ + context.get_snapshot("sushi.waiter_revenue_by_day").snapshot_id + ].intervals[-1][1] == to_timestamp("2023-01-08 00:00:00 UTC") + + +def test_restatement_plan_interval_external_visibility(tmp_path: Path): + """ + Scenario: + - `prod` environment exists, models A <- B + - `dev` environment created, models A <- B(dev) <- C (dev) + - Restatement plan is triggered against `prod` for model A + - During restatement, a new dev environment `dev_2` is created with a new version of B(dev_2) + + Outcome: + - At no point are the prod_intervals considered "missing" from state for A + - The intervals for B(dev) and C(dev) are cleared + - The intervals for B(dev_2) are also cleared even though the environment didnt exist at the time the plan was started, + because they are based on the data from a partially restated version of A + """ + + models_dir = 
tmp_path / "models" + models_dir.mkdir() + + lock_file_path = tmp_path / "test.lock" # python model blocks while this file is present + + evaluation_lock_file_path = ( + tmp_path / "evaluation.lock" + ) # python model creates this file if it's in the wait loop and deletes it once done + + # Note: to make execution block so we can test stuff, we use a Python model that blocks until it no longer detects the presence of a file + (models_dir / "model_a.py").write_text(f""" +from sqlmesh.core.model import model +from sqlmesh.core.macros import MacroEvaluator + +@model( + "test.model_a", + is_sql=True, + kind="FULL" +) +def entrypoint(evaluator: MacroEvaluator) -> str: + from pathlib import Path + import time + + if evaluator.runtime_stage == 'evaluating': + while True: + if Path("{str(lock_file_path)}").exists(): + Path("{str(evaluation_lock_file_path)}").touch() + print("lock exists; sleeping") + time.sleep(2) + else: + Path("{str(evaluation_lock_file_path)}").unlink(missing_ok=True) + break + + return "select 'model_a' as m" +""") + + (models_dir / "model_b.sql").write_text(""" + MODEL ( + name test.model_b, + kind FULL + ); + + select a.m as m, 'model_b' as mb from test.model_a as a + """) + + config = Config( + gateways={ + "": GatewayConfig( + connection=DuckDBConnectionConfig(database=str(tmp_path / "db.db")), + state_connection=DuckDBConnectionConfig(database=str(tmp_path / "state.db")), + ) + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb", start="2024-01-01"), + ) + ctx = Context(paths=[tmp_path], config=config) + + ctx.plan(environment="prod", auto_apply=True) + + assert len(ctx.snapshots) == 2 + assert all(s.intervals for s in ctx.snapshots.values()) + + prod_model_a_snapshot_id = ctx.snapshots['"db"."test"."model_a"'].snapshot_id + prod_model_b_snapshot_id = ctx.snapshots['"db"."test"."model_b"'].snapshot_id + + # dev models + # new version of B + (models_dir / "model_b.sql").write_text(""" + MODEL ( + name test.model_b, + kind FULL + ); + + select 
a.m as m, 'model_b' as mb, 'dev' as dev_version from test.model_a as a + """) + + # add C + (models_dir / "model_c.sql").write_text(""" + MODEL ( + name test.model_c, + kind FULL + ); + + select b.*, 'model_c' as mc from test.model_b as b + """) + + ctx.load() + ctx.plan(environment="dev", auto_apply=True) + + dev_model_b_snapshot_id = ctx.snapshots['"db"."test"."model_b"'].snapshot_id + dev_model_c_snapshot_id = ctx.snapshots['"db"."test"."model_c"'].snapshot_id + + assert dev_model_b_snapshot_id != prod_model_b_snapshot_id + + # now, we restate A in prod but touch the lockfile so it hangs during evaluation + # we also have to do it in its own thread due to the hang + lock_file_path.touch() + + def _run_restatement_plan(tmp_path: Path, config: Config, q: queue.Queue): + q.put("thread_started") + + # give this thread its own Context object to prevent segfaulting the Python interpreter + restatement_ctx = Context(paths=[tmp_path], config=config) + + # dev2 not present before the restatement plan starts + assert restatement_ctx.state_sync.get_environment("dev2") is None + + q.put("plan_started") + plan = restatement_ctx.plan( + environment="prod", restate_models=['"db"."test"."model_a"'], auto_apply=True + ) + q.put("plan_completed") + + # dev2 was created during the restatement plan + assert restatement_ctx.state_sync.get_environment("dev2") is not None + + return plan + + executor = ThreadPoolExecutor() + q: queue.Queue = queue.Queue() + restatement_plan_future = executor.submit(_run_restatement_plan, tmp_path, config, q) + assert q.get() == "thread_started" + + try: + if e := restatement_plan_future.exception(timeout=1): + # abort early if the plan thread threw an exception + raise e + except TimeoutError: + # that's ok, we dont actually expect the plan to have finished in 1 second + pass + + # while that restatement is running, we can simulate another process and check that it sees no empty intervals + assert q.get() == "plan_started" + + # dont check for 
potentially missing intervals until the plan is in the evaluation loop + attempts = 0 + while not evaluation_lock_file_path.exists(): + time.sleep(2) + attempts += 1 + if attempts > 10: + raise ValueError("Gave up waiting for evaluation loop") + + ctx.clear_caches() # get rid of the file cache so that data is re-fetched from state + prod_models_from_state = ctx.state_sync.get_snapshots( + snapshot_ids=[prod_model_a_snapshot_id, prod_model_b_snapshot_id] + ) + + # prod intervals should be present still + assert all(m.intervals for m in prod_models_from_state.values()) + + # so should dev intervals since prod restatement is still running + assert all(m.intervals for m in ctx.snapshots.values()) + + # now, lets create a new dev environment "dev2", while the prod restatement plan is still running, + # that changes model_b while still being based on the original version of model_a + (models_dir / "model_b.sql").write_text(""" + MODEL ( + name test.model_b, + kind FULL + ); + + select a.m as m, 'model_b' as mb, 'dev2' as dev_version from test.model_a as a + """) + ctx.load() + ctx.plan(environment="dev2", auto_apply=True) + + dev2_model_b_snapshot_id = ctx.snapshots['"db"."test"."model_b"'].snapshot_id + assert dev2_model_b_snapshot_id != dev_model_b_snapshot_id + assert dev2_model_b_snapshot_id != prod_model_b_snapshot_id + + # as at this point, everything still has intervals + ctx.clear_caches() + assert all( + s.intervals + for s in ctx.state_sync.get_snapshots( + snapshot_ids=[ + prod_model_a_snapshot_id, + prod_model_b_snapshot_id, + dev_model_b_snapshot_id, + dev_model_c_snapshot_id, + dev2_model_b_snapshot_id, + ] + ).values() + ) + + # now, we finally let that restatement plan complete + # first, verify it's still blocked where it should be + assert not restatement_plan_future.done() + + lock_file_path.unlink() # remove lock file, plan should be able to proceed now + + if e := restatement_plan_future.exception(): # blocks until future complete + raise e + + 
assert restatement_plan_future.result() + assert q.get() == "plan_completed" + + ctx.clear_caches() + + # check that intervals in prod are present + assert all( + s.intervals + for s in ctx.state_sync.get_snapshots( + snapshot_ids=[ + prod_model_a_snapshot_id, + prod_model_b_snapshot_id, + ] + ).values() + ) + + # check that intervals in dev have been cleared, including the dev2 env that + # was created after the restatement plan started + assert all( + not s.intervals + for s in ctx.state_sync.get_snapshots( + snapshot_ids=[ + dev_model_b_snapshot_id, + dev_model_c_snapshot_id, + dev2_model_b_snapshot_id, + ] + ).values() + ) + + executor.shutdown() + + +def test_restatement_plan_detects_prod_deployment_during_restatement(tmp_path: Path): + """ + Scenario: + - `prod` environment exists, model A + - `dev` environment created, model A(dev) + - Restatement plan is triggered against `prod` for model A + - During restatement, someone else deploys A(dev) to prod, replacing the model that is currently being restated. + + Outcome: + - The deployment plan for dev -> prod should succeed in deploying the new version of A + - The prod restatement plan should fail with a ConflictingPlanError and warn about the model that got updated while undergoing restatement + - The new version of A should have no intervals cleared. 
The user needs to rerun the restatement if the intervals should still be cleared + """ + orig_console = get_console() + console = CaptureTerminalConsole() + set_console(console) + + models_dir = tmp_path / "models" + models_dir.mkdir() + + lock_file_path = tmp_path / "test.lock" # python model blocks while this file is present + + evaluation_lock_file_path = ( + tmp_path / "evaluation.lock" + ) # python model creates this file if it's in the wait loop and deletes it once done + + # Note: to make execution block so we can test stuff, we use a Python model that blocks until it no longer detects the presence of a file + (models_dir / "model_a.py").write_text(f""" +from sqlmesh.core.model import model +from sqlmesh.core.macros import MacroEvaluator + +@model( + "test.model_a", + is_sql=True, + kind="FULL" +) +def entrypoint(evaluator: MacroEvaluator) -> str: + from pathlib import Path + import time + + if evaluator.runtime_stage == 'evaluating': + while True: + if Path("{str(lock_file_path)}").exists(): + Path("{str(evaluation_lock_file_path)}").touch() + print("lock exists; sleeping") + time.sleep(2) + else: + Path("{str(evaluation_lock_file_path)}").unlink(missing_ok=True) + break + + return "select 'model_a' as m" +""") + + config = Config( + gateways={ + "": GatewayConfig( + connection=DuckDBConnectionConfig(database=str(tmp_path / "db.db")), + state_connection=DuckDBConnectionConfig(database=str(tmp_path / "state.db")), + ) + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb", start="2024-01-01"), + ) + ctx = Context(paths=[tmp_path], config=config) + + # create prod + ctx.plan(environment="prod", auto_apply=True) + original_prod = ctx.state_sync.get_environment("prod") + assert original_prod + + # update model_a for dev + (models_dir / "model_a.py").unlink() + (models_dir / "model_a.sql").write_text(""" + MODEL ( + name test.model_a, + kind FULL + ); + + select 1 as changed + """) + + # create dev + ctx.load() + plan = ctx.plan(environment="dev", 
auto_apply=True) + assert len(plan.modified_snapshots) == 1 + new_model_a_snapshot_id = list(plan.modified_snapshots)[0] + + # now, trigger a prod restatement plan in a different thread and block it to simulate a long restatement + thread_console = None + + def _run_restatement_plan(tmp_path: Path, config: Config, q: queue.Queue): + nonlocal thread_console + q.put("thread_started") + + # Give this thread its own markdown console to avoid Rich LiveError + thread_console = MarkdownConsole() + set_console(thread_console) + + # give this thread its own Context object to prevent segfaulting the Python interpreter + restatement_ctx = Context(paths=[tmp_path], config=config) + + # ensure dev is present before the restatement plan starts + assert restatement_ctx.state_sync.get_environment("dev") is not None + + q.put("plan_started") + expected_error = None + try: + restatement_ctx.plan( + environment="prod", restate_models=['"db"."test"."model_a"'], auto_apply=True + ) + except ConflictingPlanError as e: + expected_error = e + + q.put("plan_completed") + return expected_error + + executor = ThreadPoolExecutor() + q: queue.Queue = queue.Queue() + lock_file_path.touch() + + restatement_plan_future = executor.submit(_run_restatement_plan, tmp_path, config, q) + restatement_plan_future.add_done_callback(lambda _: executor.shutdown()) + + assert q.get() == "thread_started" + + try: + if e := restatement_plan_future.exception(timeout=1): + # abort early if the plan thread threw an exception + raise e + except TimeoutError: + # that's ok, we dont actually expect the plan to have finished in 1 second + pass + + assert q.get() == "plan_started" + + # ok, now the prod restatement plan is running, let's deploy dev to prod + ctx.plan(environment="prod", auto_apply=True) + + new_prod = ctx.state_sync.get_environment("prod") + assert new_prod + assert new_prod.plan_id != original_prod.plan_id + assert new_prod.previous_plan_id == original_prod.plan_id + + # new prod is deployed but 
restatement plan is still running + assert not restatement_plan_future.done() + + # allow restatement plan to complete + lock_file_path.unlink() + + plan_error = restatement_plan_future.result() + assert isinstance(plan_error, ConflictingPlanError) + assert "please re-apply your plan" in repr(plan_error).lower() + + output = " ".join(re.split("\\s+", thread_console.captured_output, flags=re.UNICODE)) # type: ignore + assert ( + f"The following models had new versions deployed while data was being restated: └── test.model_a" + in output + ) + + # check that no intervals have been cleared from the model_a currently in prod + model_a = ctx.state_sync.get_snapshots(snapshot_ids=[new_model_a_snapshot_id])[ + new_model_a_snapshot_id + ] + assert isinstance(model_a.node, SqlModel) + assert model_a.node.render_query_or_raise().sql() == 'SELECT 1 AS "changed"' + assert len(model_a.intervals) + + set_console(orig_console) + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_restatement_plan_outside_parent_date_range(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + context.upsert_model("sushi.items", start="2023-01-06") + context.upsert_model("sushi.orders", start="2023-01-06") + # One of the parents should derive the start from its own parents for the issue + # to reproduce + context.upsert_model("sushi.order_items", start=None) + context.upsert_model("sushi.waiter_revenue_by_day", start="2023-01-01", audits=[]) + + context.plan("prod", auto_apply=True, no_prompts=True, skip_tests=True) + + restated_snapshot = context.get_snapshot("sushi.waiter_revenue_by_day") + downstream_snapshot = context.get_snapshot("sushi.top_waiters") + + plan = context.plan_builder( + restate_models=["sushi.waiter_revenue_by_day"], + start="2023-01-01", + end="2023-01-01", + min_intervals=0, + ).build() + assert plan.snapshots != context.snapshots + + assert plan.requires_backfill + assert plan.restatements == { + restated_snapshot.snapshot_id: 
(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + downstream_snapshot.snapshot_id: (to_timestamp("2023-01-01"), to_timestamp("2023-01-09")), + } + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=downstream_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + ], + ), + SnapshotIntervals( + snapshot_id=restated_snapshot.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + ], + ), + ] + + context.apply(plan) diff --git a/tests/core/integration/test_run.py b/tests/core/integration/test_run.py new file mode 100644 index 0000000000..c3e6626ad0 --- /dev/null +++ b/tests/core/integration/test_run.py @@ -0,0 +1,248 @@ +from __future__ import annotations + +import typing as t +import pytest +import time_machine +from pytest_mock.plugin import MockerFixture + +from sqlmesh.core import constants as c +from sqlmesh.core import dialect as d +from sqlmesh.core.config.categorizer import CategorizerConfig +from sqlmesh.core.model import ( + SqlModel, + PythonModel, + load_sql_based_model, +) +from sqlmesh.utils.date import to_timestamp + +if t.TYPE_CHECKING: + pass + +pytestmark = pytest.mark.slow + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_run_with_select_models( + init_and_plan_context: t.Callable, +): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + with time_machine.travel("2023-01-09 00:00:00 UTC"): + assert context.run(select_models=["*waiter_revenue_by_day"]) + + snapshots = context.state_sync.state_sync.get_snapshots(context.snapshots.values()) + # Only 
waiter_revenue_by_day and its parents should be backfilled up to 2023-01-09. + assert {s.name: s.intervals[0][1] for s in snapshots.values() if s.intervals} == { + '"memory"."sushi"."waiter_revenue_by_day"': to_timestamp("2023-01-09"), + '"memory"."sushi"."order_items"': to_timestamp("2023-01-09"), + '"memory"."sushi"."orders"': to_timestamp("2023-01-09"), + '"memory"."sushi"."items"': to_timestamp("2023-01-09"), + '"memory"."sushi"."customer_revenue_lifetime"': to_timestamp("2023-01-08"), + '"memory"."sushi"."customer_revenue_by_day"': to_timestamp("2023-01-08"), + '"memory"."sushi"."latest_order"': to_timestamp("2023-01-08"), + '"memory"."sushi"."waiter_names"': to_timestamp("2023-01-08"), + '"memory"."sushi"."raw_marketing"': to_timestamp("2023-01-08"), + '"memory"."sushi"."marketing"': to_timestamp("2023-01-08"), + '"memory"."sushi"."waiter_as_customer_by_day"': to_timestamp("2023-01-08"), + '"memory"."sushi"."top_waiters"': to_timestamp("2023-01-08"), + '"memory"."raw"."demographics"': to_timestamp("2023-01-08"), + "assert_item_price_above_zero": to_timestamp("2023-01-08"), + '"memory"."sushi"."active_customers"': to_timestamp("2023-01-08"), + '"memory"."sushi"."customers"': to_timestamp("2023-01-08"), + '"memory"."sushi"."count_customers_active"': to_timestamp("2023-01-08"), + '"memory"."sushi"."count_customers_inactive"': to_timestamp("2023-01-08"), + } + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_run_with_select_models_no_auto_upstream( + init_and_plan_context: t.Callable, +): + context, _ = init_and_plan_context("examples/sushi") + + model = context.get_model("sushi.waiter_revenue_by_day") + model = SqlModel.parse_obj({**model.dict(), "audits": []}) + context.upsert_model(model) + + context.plan("prod", no_prompts=True, skip_tests=True, auto_apply=True) + + with time_machine.travel("2023-01-09 00:00:00 UTC"): + assert context.run(select_models=["*waiter_revenue_by_day"], no_auto_upstream=True) + + snapshots = 
context.state_sync.state_sync.get_snapshots(context.snapshots.values()) + # Only waiter_revenue_by_day should be backfilled up to 2023-01-09. + assert {s.name: s.intervals[0][1] for s in snapshots.values() if s.intervals} == { + '"memory"."sushi"."waiter_revenue_by_day"': to_timestamp("2023-01-09"), + '"memory"."sushi"."order_items"': to_timestamp("2023-01-08"), + '"memory"."sushi"."orders"': to_timestamp("2023-01-08"), + '"memory"."sushi"."items"': to_timestamp("2023-01-08"), + '"memory"."sushi"."customer_revenue_lifetime"': to_timestamp("2023-01-08"), + '"memory"."sushi"."customer_revenue_by_day"': to_timestamp("2023-01-08"), + '"memory"."sushi"."latest_order"': to_timestamp("2023-01-08"), + '"memory"."sushi"."waiter_names"': to_timestamp("2023-01-08"), + '"memory"."sushi"."raw_marketing"': to_timestamp("2023-01-08"), + '"memory"."sushi"."marketing"': to_timestamp("2023-01-08"), + '"memory"."sushi"."waiter_as_customer_by_day"': to_timestamp("2023-01-08"), + '"memory"."sushi"."top_waiters"': to_timestamp("2023-01-08"), + '"memory"."raw"."demographics"': to_timestamp("2023-01-08"), + "assert_item_price_above_zero": to_timestamp("2023-01-08"), + '"memory"."sushi"."active_customers"': to_timestamp("2023-01-08"), + '"memory"."sushi"."customers"': to_timestamp("2023-01-08"), + '"memory"."sushi"."count_customers_active"': to_timestamp("2023-01-08"), + '"memory"."sushi"."count_customers_inactive"': to_timestamp("2023-01-08"), + } + + +@time_machine.travel("2023-01-08 15:00:00 UTC") +def test_run_respects_excluded_transitive_dependencies(init_and_plan_context: t.Callable): + context, _ = init_and_plan_context("examples/sushi") + + # Graph: C <- B <- A + # B is a transitive dependency linking A and C + # Note that the alphabetical ordering of the model names is intentional and helps + # surface the problem + expressions_a = d.parse( + f""" + MODEL ( + name memory.sushi.test_model_c, + kind FULL, + allow_partials true, + cron '@hourly', + start '2023-01-01', + ); + + SELECT 
@execution_ts AS execution_ts + """ + ) + model_c = load_sql_based_model(expressions_a) + context.upsert_model(model_c) + + # A VIEW model with no partials allowed and a daily cron instead of hourly. + expressions_b = d.parse( + f""" + MODEL ( + name memory.sushi.test_model_b, + kind VIEW, + allow_partials false, + cron '@daily', + ); + + SELECT * FROM memory.sushi.test_model_c + """ + ) + model_b = load_sql_based_model(expressions_b) + context.upsert_model(model_b) + + expressions_a = d.parse( + f""" + MODEL ( + name memory.sushi.test_model_a, + kind FULL, + allow_partials true, + cron '@hourly', + ); + + SELECT * FROM memory.sushi.test_model_b + """ + ) + model_a = load_sql_based_model(expressions_a) + context.upsert_model(model_a) + + context.plan("prod", skip_tests=True, auto_apply=True, no_prompts=True) + assert ( + context.fetchdf("SELECT execution_ts FROM memory.sushi.test_model_c")["execution_ts"].iloc[ + 0 + ] + == "2023-01-08 15:00:00" + ) + + with time_machine.travel("2023-01-08 17:00:00 UTC", tick=False): + context.run( + "prod", + select_models=["*test_model_c", "*test_model_a"], + no_auto_upstream=True, + ignore_cron=True, + ) + assert ( + context.fetchdf("SELECT execution_ts FROM memory.sushi.test_model_a")[ + "execution_ts" + ].iloc[0] + == "2023-01-08 17:00:00" + ) + + +@time_machine.travel("2023-01-08 00:00:00 UTC") +def test_snapshot_triggers(init_and_plan_context: t.Callable, mocker: MockerFixture): + context, plan = init_and_plan_context("examples/sushi") + context.apply(plan) + + # auto-restatement triggers + orders = context.get_model("sushi.orders") + orders_kind = { + **orders.kind.dict(), + "auto_restatement_cron": "@hourly", + } + orders_kwargs = { + **orders.dict(), + "kind": orders_kind, + } + context.upsert_model(PythonModel.parse_obj(orders_kwargs)) + + order_items = context.get_model("sushi.order_items") + order_items_kind = { + **order_items.kind.dict(), + "auto_restatement_cron": "@hourly", + } + order_items_kwargs = { + 
**order_items.dict(), + "kind": order_items_kind, + } + context.upsert_model(PythonModel.parse_obj(order_items_kwargs)) + + waiter_revenue_by_day = context.get_model("sushi.waiter_revenue_by_day") + waiter_revenue_by_day_kind = { + **waiter_revenue_by_day.kind.dict(), + "auto_restatement_cron": "@hourly", + } + waiter_revenue_by_day_kwargs = { + **waiter_revenue_by_day.dict(), + "kind": waiter_revenue_by_day_kind, + } + context.upsert_model(SqlModel.parse_obj(waiter_revenue_by_day_kwargs)) + + context.plan(auto_apply=True, no_prompts=True, categorizer_config=CategorizerConfig.all_full()) + + scheduler = context.scheduler() + + import sqlmesh + + spy = mocker.spy(sqlmesh.core.scheduler.Scheduler, "run_merged_intervals") + + with time_machine.travel("2023-01-09 00:00:01 UTC"): + scheduler.run( + environment=c.PROD, + start="2023-01-01", + auto_restatement_enabled=True, + ) + + assert spy.called + + actual_triggers = spy.call_args.kwargs["auto_restatement_triggers"] + actual_triggers = {k: v for k, v in actual_triggers.items() if v} + assert len(actual_triggers) == 12 + + for id, trigger in actual_triggers.items(): + model_name = id.name.replace('"memory"."sushi".', "").replace('"', "") + auto_restatement_triggers = [ + t.name.replace('"memory"."sushi".', "").replace('"', "") for t in trigger + ] + + if model_name in ("orders", "order_items", "waiter_revenue_by_day"): + assert auto_restatement_triggers == [model_name] + elif model_name in ("customer_revenue_lifetime", "customer_revenue_by_day"): + assert sorted(auto_restatement_triggers) == sorted(["orders", "order_items"]) + elif model_name == "top_waiters": + assert auto_restatement_triggers == ["waiter_revenue_by_day"] + else: + assert auto_restatement_triggers == ["orders"] diff --git a/tests/core/integration/utils.py b/tests/core/integration/utils.py new file mode 100644 index 0000000000..bc731e6cc8 --- /dev/null +++ b/tests/core/integration/utils.py @@ -0,0 +1,350 @@ +from __future__ import annotations + +import 
typing as t +from sqlmesh.core.model.common import ParsableSql +from sqlglot import exp +from sqlglot.expressions import DataType + +from sqlmesh.core import constants as c +from sqlmesh.core.context import Context +from sqlmesh.core.engine_adapter import EngineAdapter +from sqlmesh.core.environment import EnvironmentNamingInfo +from sqlmesh.core.model import ( + IncrementalByTimeRangeKind, + IncrementalByUniqueKeyKind, + ModelKind, + ModelKindName, + SqlModel, + TimeColumn, +) +from sqlmesh.core.model.kind import model_kind_type_from_name +from sqlmesh.core.plan import Plan, PlanBuilder +from sqlmesh.core.snapshot import ( + DeployabilityIndex, + Snapshot, + SnapshotChangeCategory, + SnapshotId, + SnapshotInfoLike, + SnapshotTableInfo, +) +from sqlmesh.utils.date import TimeLike + + +def select_all(table: str, adapter: EngineAdapter) -> t.Iterable: + return adapter.fetchall(f"select * from {table} order by 1") + + +def snapshots_to_versions(snapshots: t.Iterable[Snapshot]) -> t.Dict[str, str]: + return {snapshot.name: snapshot.version or "" for snapshot in snapshots} + + +def to_snapshot_info(snapshot: SnapshotInfoLike) -> SnapshotTableInfo: + return snapshot.table_info + + +def start(context: Context) -> TimeLike: + env = context.state_sync.get_environment("prod") + assert env + return env.start_at + + +def add_projection_to_model(model: SqlModel, literal: bool = True) -> SqlModel: + one_expr = exp.Literal.number(1).as_("one") if literal else exp.column("one") + kwargs = { + **model.dict(), + "query": model.query.select(one_expr), # type: ignore + } + return SqlModel.parse_obj(kwargs) + + +def plan_choice(plan_builder: PlanBuilder, choice: SnapshotChangeCategory) -> None: + for snapshot in plan_builder.build().snapshots.values(): + if not snapshot.version: + plan_builder.set_choice(snapshot, choice) + + +def apply_to_environment( + context: Context, + environment: str, + choice: t.Optional[SnapshotChangeCategory] = None, + plan_validators: 
t.Optional[t.Iterable[t.Callable]] = None, + apply_validators: t.Optional[t.Iterable[t.Callable]] = None, + plan_start: t.Optional[TimeLike] = None, + allow_destructive_models: t.Optional[t.List[str]] = None, + enable_preview: bool = False, +): + plan_validators = plan_validators or [] + apply_validators = apply_validators or [] + + plan_builder = context.plan_builder( + environment, + start=plan_start or start(context) if environment != c.PROD else None, + forward_only=choice == SnapshotChangeCategory.FORWARD_ONLY, + include_unmodified=True, + allow_destructive_models=allow_destructive_models if allow_destructive_models else [], + enable_preview=enable_preview, + ) + if environment != c.PROD: + plan_builder.set_start(plan_start or start(context)) + + if choice: + if choice == SnapshotChangeCategory.FORWARD_ONLY: + # FORWARD_ONLY is deprecated, fallback to NON_BREAKING to keep the existing tests + choice = SnapshotChangeCategory.NON_BREAKING + plan_choice(plan_builder, choice) + for validator in plan_validators: + validator(context, plan_builder.build()) + + plan = plan_builder.build() + context.apply(plan) + + validate_apply_basics(context, environment, plan.snapshots.values(), plan.deployability_index) + for validator in apply_validators: + validator(context) + return plan + + +def change_data_type( + context: Context, model_name: str, old_type: DataType.Type, new_type: DataType.Type +) -> None: + model = context.get_model(model_name) + assert model is not None + + if isinstance(model, SqlModel): + query = model.query.copy() + data_types = query.find_all(DataType) + for data_type in data_types: + if data_type.this == old_type: + data_type.set("this", new_type) + context.upsert_model(model_name, query_=ParsableSql(sql=query.sql(dialect=model.dialect))) + elif model.columns_to_types_ is not None: + for k, v in model.columns_to_types_.items(): + if v.this == old_type: + model.columns_to_types_[k] = DataType.build(new_type) + context.upsert_model(model_name, 
columns=model.columns_to_types_) + + +def validate_snapshots_in_state_sync(snapshots: t.Iterable[Snapshot], context: Context) -> None: + snapshot_infos = map(to_snapshot_info, snapshots) + state_sync_table_infos = map( + to_snapshot_info, context.state_reader.get_snapshots(snapshots).values() + ) + assert set(snapshot_infos) == set(state_sync_table_infos) + + +def validate_state_sync_environment( + snapshots: t.Iterable[Snapshot], env: str, context: Context +) -> None: + environment = context.state_reader.get_environment(env) + assert environment + snapshot_infos = map(to_snapshot_info, snapshots) + environment_table_infos = map(to_snapshot_info, environment.snapshots) + assert set(snapshot_infos) == set(environment_table_infos) + + +def validate_tables( + snapshots: t.Iterable[Snapshot], + context: Context, + deployability_index: t.Optional[DeployabilityIndex] = None, +) -> None: + adapter = context.engine_adapter + deployability_index = deployability_index or DeployabilityIndex.all_deployable() + for snapshot in snapshots: + is_deployable = deployability_index.is_representative(snapshot) + if not snapshot.is_model or snapshot.is_external: + continue + table_should_exist = not snapshot.is_embedded + assert adapter.table_exists(snapshot.table_name(is_deployable)) == table_should_exist + if table_should_exist: + assert select_all(snapshot.table_name(is_deployable), adapter) + + +def validate_environment_views( + snapshots: t.Iterable[Snapshot], + environment: str, + context: Context, + deployability_index: t.Optional[DeployabilityIndex] = None, +) -> None: + adapter = context.engine_adapter + deployability_index = deployability_index or DeployabilityIndex.all_deployable() + for snapshot in snapshots: + is_deployable = deployability_index.is_representative(snapshot) + if not snapshot.is_model or snapshot.is_symbolic: + continue + view_name = snapshot.qualified_view_name.for_environment( + EnvironmentNamingInfo.from_environment_catalog_mapping( + 
context.config.environment_catalog_mapping, + name=environment, + suffix_target=context.config.environment_suffix_target, + ) + ) + + assert adapter.table_exists(view_name) + assert select_all(snapshot.table_name(is_deployable), adapter) == select_all( + view_name, adapter + ) + + +def validate_apply_basics( + context: Context, + environment: str, + snapshots: t.Iterable[Snapshot], + deployability_index: t.Optional[DeployabilityIndex] = None, +) -> None: + validate_snapshots_in_state_sync(snapshots, context) + validate_state_sync_environment(snapshots, environment, context) + validate_tables(snapshots, context, deployability_index) + validate_environment_views(snapshots, environment, context, deployability_index) + + +def validate_plan_changes( + plan: Plan, + *, + added: t.Optional[t.Iterable[SnapshotId]] = None, + modified: t.Optional[t.Iterable[str]] = None, + removed: t.Optional[t.Iterable[SnapshotId]] = None, +) -> None: + added = added or [] + modified = modified or [] + removed = removed or [] + assert set(added) == plan.context_diff.added + assert set(modified) == set(plan.context_diff.modified_snapshots) + assert set(removed) == set(plan.context_diff.removed_snapshots) + + +def validate_versions_same( + model_names: t.List[str], + versions: t.Dict[str, str], + other_versions: t.Dict[str, str], +) -> None: + for name in model_names: + assert versions[name] == other_versions[name] + + +def validate_versions_different( + model_names: t.List[str], + versions: t.Dict[str, str], + other_versions: t.Dict[str, str], +) -> None: + for name in model_names: + assert versions[name] != other_versions[name] + + +def validate_query_change( + context: Context, + environment: str, + change_category: SnapshotChangeCategory, + logical: bool, +): + versions = snapshots_to_versions(context.snapshots.values()) + + change_data_type( + context, + "sushi.items", + DataType.Type.DOUBLE, + DataType.Type.FLOAT, + ) + + directly_modified = ['"memory"."sushi"."items"'] + 
indirectly_modified = [ + '"memory"."sushi"."order_items"', + '"memory"."sushi"."waiter_revenue_by_day"', + '"memory"."sushi"."customer_revenue_by_day"', + '"memory"."sushi"."customer_revenue_lifetime"', + '"memory"."sushi"."top_waiters"', + "assert_item_price_above_zero", + ] + not_modified = [ + snapshot.name + for snapshot in context.snapshots.values() + if snapshot.name not in directly_modified and snapshot.name not in indirectly_modified + ] + + if change_category == SnapshotChangeCategory.BREAKING and not logical: + models_same = not_modified + models_different = directly_modified + indirectly_modified + elif change_category == SnapshotChangeCategory.FORWARD_ONLY: + models_same = not_modified + directly_modified + indirectly_modified + models_different = [] + else: + models_same = not_modified + indirectly_modified + models_different = directly_modified + + def _validate_plan(context, plan): + validate_plan_changes(plan, modified=directly_modified + indirectly_modified) + assert bool(plan.missing_intervals) != logical + + def _validate_apply(context): + current_versions = snapshots_to_versions(context.snapshots.values()) + validate_versions_same(models_same, versions, current_versions) + validate_versions_different(models_different, versions, current_versions) + + apply_to_environment( + context, + environment, + change_category, + plan_validators=[_validate_plan], + apply_validators=[_validate_apply], + ) + + +def initial_add(context: Context, environment: str): + assert not context.state_reader.get_environment(environment) + + plan = context.plan(environment, start=start(context), create_from="nonexistent_env") + validate_plan_changes(plan, added={x.snapshot_id for x in context.snapshots.values()}) + + context.apply(plan) + validate_apply_basics(context, environment, plan.snapshots.values()) + + +def change_model_kind(context: Context, kind: ModelKindName): + if kind in (ModelKindName.VIEW, ModelKindName.EMBEDDED, ModelKindName.FULL): + 
context.upsert_model( + "sushi.items", + partitioned_by=[], + ) + context.upsert_model("sushi.items", kind=model_kind_type_from_name(kind)()) # type: ignore + + +def validate_model_kind_change( + kind_name: ModelKindName, + context: Context, + environment: str, + *, + logical: bool, +): + directly_modified = ['"memory"."sushi"."items"'] + indirectly_modified = [ + '"memory"."sushi"."order_items"', + '"memory"."sushi"."waiter_revenue_by_day"', + '"memory"."sushi"."customer_revenue_by_day"', + '"memory"."sushi"."customer_revenue_lifetime"', + '"memory"."sushi"."top_waiters"', + "assert_item_price_above_zero", + ] + if kind_name == ModelKindName.INCREMENTAL_BY_TIME_RANGE: + kind: ModelKind = IncrementalByTimeRangeKind(time_column=TimeColumn(column="event_date")) + elif kind_name == ModelKindName.INCREMENTAL_BY_UNIQUE_KEY: + kind = IncrementalByUniqueKeyKind(unique_key="id") + else: + kind = model_kind_type_from_name(kind_name)() # type: ignore + + def _validate_plan(context, plan): + validate_plan_changes(plan, modified=directly_modified + indirectly_modified) + assert ( + next( + snapshot + for snapshot in plan.snapshots.values() + if snapshot.name == '"memory"."sushi"."items"' + ).model.kind.name + == kind.name + ) + assert bool(plan.missing_intervals) != logical + + apply_to_environment( + context, + environment, + SnapshotChangeCategory.NON_BREAKING, + plan_validators=[_validate_plan], + ) diff --git a/tests/core/linter/test_builtin.py b/tests/core/linter/test_builtin.py new file mode 100644 index 0000000000..0ff91470ff --- /dev/null +++ b/tests/core/linter/test_builtin.py @@ -0,0 +1,234 @@ +import os + +from sqlmesh import Context +from sqlmesh.core.linter.rule import Position, Range + + +def test_no_missing_external_models(tmp_path, copy_to_temp_path) -> None: + """ + Tests that the linter correctly identifies unregistered external model dependencies. 
+ + This test removes the `external_models.yaml` file from the sushi example project, + enables the linter, and verifies that the linter raises a violation for a model + that depends on unregistered external models. + """ + sushi_paths = copy_to_temp_path("examples/sushi") + sushi_path = sushi_paths[0] + + # Remove the external_models.yaml file + os.remove(sushi_path / "external_models.yaml") + + # Override the config.py to turn on lint + with open(sushi_path / "config.py", "r") as f: + read_file = f.read() + + before = """ linter=LinterConfig( + enabled=False, + rules=[ + "ambiguousorinvalidcolumn", + "invalidselectstarexpansion", + "noselectstar", + "nomissingaudits", + "nomissingowner", + "nomissingexternalmodels", + ], + ),""" + after = """linter=LinterConfig(enabled=True, rules=["nomissingexternalmodels"]),""" + read_file = read_file.replace(before, after) + assert after in read_file + with open(sushi_path / "config.py", "w") as f: + f.writelines(read_file) + + # Load the context with the temporary sushi path + context = Context(paths=[sushi_path]) + + # Lint the models + lints = context.lint_models(raise_on_error=False) + assert len(lints) == 1 + lint = lints[0] + assert lint.violation_range is not None + assert ( + lint.violation_msg + == """Model '"memory"."sushi"."customers"' depends on unregistered external model '"memory"."raw"."demographics"'. Please register it in the external models file. 
This can be done by running 'sqlmesh create_external_models'.""" + ) + assert len(lint.fixes) == 1 + fix = lint.fixes[0] + assert len(fix.edits) == 0 + assert len(fix.create_files) == 1 + create = fix.create_files[0] + assert create.path == sushi_path / "external_models.yaml" + assert create.text == '- name: \'"memory"."raw"."demographics"\'\n' + + +def test_no_missing_external_models_with_existing_file_ending_in_newline( + tmp_path, copy_to_temp_path +) -> None: + sushi_paths = copy_to_temp_path("examples/sushi") + sushi_path = sushi_paths[0] + + # Overwrite the external_models.yaml file to end with a random file and a newline + os.remove(sushi_path / "external_models.yaml") + with open(sushi_path / "external_models.yaml", "w") as f: + f.write("- name: memory.raw.test\n") + + # Override the config.py to turn on lint + with open(sushi_path / "config.py", "r") as f: + read_file = f.read() + + before = """ linter=LinterConfig( + enabled=False, + rules=[ + "ambiguousorinvalidcolumn", + "invalidselectstarexpansion", + "noselectstar", + "nomissingaudits", + "nomissingowner", + "nomissingexternalmodels", + ], + ),""" + after = """linter=LinterConfig(enabled=True, rules=["nomissingexternalmodels"]),""" + read_file = read_file.replace(before, after) + assert after in read_file + with open(sushi_path / "config.py", "w") as f: + f.writelines(read_file) + + # Load the context with the temporary sushi path + context = Context(paths=[sushi_path]) + + # Lint the models + lints = context.lint_models(raise_on_error=False) + assert len(lints) == 1 + lint = lints[0] + assert lint.violation_range is not None + assert ( + lint.violation_msg + == """Model '"memory"."sushi"."customers"' depends on unregistered external model '"memory"."raw"."demographics"'. Please register it in the external models file. 
This can be done by running 'sqlmesh create_external_models'.""" + ) + assert len(lint.fixes) == 1 + fix = lint.fixes[0] + assert len(fix.edits) == 1 + edit = fix.edits[0] + assert edit.new_text == """- name: '"memory"."raw"."demographics"'\n""" + assert edit.range == Range( + start=Position(line=1, character=0), + end=Position(line=1, character=0), + ) + fix_path = sushi_path / "external_models.yaml" + assert edit.path == fix_path + + +def test_no_missing_external_models_with_existing_file_not_ending_in_newline( + tmp_path, copy_to_temp_path +) -> None: + sushi_paths = copy_to_temp_path("examples/sushi") + sushi_path = sushi_paths[0] + + # Overwrite the external_models.yaml file to end with a random file and a newline + os.remove(sushi_path / "external_models.yaml") + with open(sushi_path / "external_models.yaml", "w") as f: + f.write("- name: memory.raw.test") + + # Override the config.py to turn on lint + with open(sushi_path / "config.py", "r") as f: + read_file = f.read() + + before = """ linter=LinterConfig( + enabled=False, + rules=[ + "ambiguousorinvalidcolumn", + "invalidselectstarexpansion", + "noselectstar", + "nomissingaudits", + "nomissingowner", + "nomissingexternalmodels", + ], + ),""" + after = """linter=LinterConfig(enabled=True, rules=["nomissingexternalmodels"]),""" + read_file = read_file.replace(before, after) + assert after in read_file + with open(sushi_path / "config.py", "w") as f: + f.writelines(read_file) + + # Load the context with the temporary sushi path + context = Context(paths=[sushi_path]) + + # Lint the models + lints = context.lint_models(raise_on_error=False) + assert len(lints) == 1 + lint = lints[0] + assert lint.violation_range is not None + assert ( + lint.violation_msg + == """Model '"memory"."sushi"."customers"' depends on unregistered external model '"memory"."raw"."demographics"'. Please register it in the external models file. 
This can be done by running 'sqlmesh create_external_models'.""" + ) + assert len(lint.fixes) == 1 + fix = lint.fixes[0] + assert len(fix.edits) == 1 + edit = fix.edits[0] + assert edit.new_text == """\n- name: '"memory"."raw"."demographics"'\n""" + assert edit.range == Range( + start=Position(line=0, character=23), + end=Position(line=0, character=23), + ) + fix_path = sushi_path / "external_models.yaml" + assert edit.path == fix_path + + +def test_no_missing_unit_tests(tmp_path, copy_to_temp_path): + """ + Tests that the NoMissingUnitTest linter rule correctly identifies models + without corresponding unit tests in the tests/ directory + + This test checks the sushi example project, enables the linter, + and verifies that the linter raises a rule violation for the models + that do not have a unit test + """ + sushi_paths = copy_to_temp_path("examples/sushi") + sushi_path = sushi_paths[0] + + # Override the config.py to turn on lint + with open(sushi_path / "config.py", "r") as f: + read_file = f.read() + + before = """ linter=LinterConfig( + enabled=False, + rules=[ + "ambiguousorinvalidcolumn", + "invalidselectstarexpansion", + "noselectstar", + "nomissingaudits", + "nomissingowner", + "nomissingexternalmodels", + ], + ),""" + after = """linter=LinterConfig(enabled=True, rules=["nomissingunittest"]),""" + read_file = read_file.replace(before, after) + assert after in read_file + with open(sushi_path / "config.py", "w") as f: + f.writelines(read_file) + + # Load the context with the temporary sushi path + context = Context(paths=[sushi_path]) + + # Lint the models + lints = context.lint_models(raise_on_error=False) + + # Should have violations for models without tests (most models except customers) + assert len(lints) >= 1 + + # Check that we get violations for models without tests + violation_messages = [lint.violation_msg for lint in lints] + assert any("is missing unit test(s)" in msg for msg in violation_messages) + + # Check that models with existing tests 
don't have violations + models_with_tests = ["customer_revenue_by_day", "customer_revenue_lifetime", "order_items"] + + for model_name in models_with_tests: + model_violations = [ + lint + for lint in lints + if model_name in lint.violation_msg and "is missing unit test(s)" in lint.violation_msg + ] + assert len(model_violations) == 0, ( + f"Model {model_name} should not have a violation since it has a test" + ) diff --git a/tests/core/linter/test_helpers.py b/tests/core/linter/test_helpers.py new file mode 100644 index 0000000000..c3ba46f304 --- /dev/null +++ b/tests/core/linter/test_helpers.py @@ -0,0 +1,130 @@ +from sqlmesh import Context +from sqlmesh.core.linter.helpers import ( + read_range_from_file, + get_range_of_model_block, + get_range_of_a_key_in_model_block, +) +from sqlmesh.core.model import SqlModel + + +def test_get_position_of_model_block(): + context = Context(paths=["examples/sushi"]) + + sql_models = [ + model + for model in context.models.values() + if isinstance(model, SqlModel) + and model._path is not None + and str(model._path).endswith(".sql") + ] + assert len(sql_models) > 0 + + for model in sql_models: + dialect = model.dialect + assert dialect is not None + + path = model._path + assert path is not None + + with open(path, "r", encoding="utf-8") as file: + content = file.read() + + as_lines = content.splitlines() + + range = get_range_of_model_block(content, dialect) + assert range is not None + + # Check that the range starts with MODEL and ends with ; + read_range = read_range_from_file(path, range) + assert read_range.startswith("MODEL") + assert read_range.endswith(";") + + +def test_get_range_of_a_key_in_model_block_testing_on_sushi(): + context = Context(paths=["examples/sushi"]) + + sql_models = [ + model + for model in context.models.values() + if isinstance(model, SqlModel) + and model._path is not None + and str(model._path).endswith(".sql") + ] + assert len(sql_models) > 0 + + # Test that the function works for all keys in 
the model block + for model in sql_models: + possible_keys = [ + "name", + "tags", + "description", + "column_descriptions", + "owner", + "cron", + "dialect", + ] + + dialect = model.dialect + assert dialect is not None + + path = model._path + assert path is not None + + with open(path, "r", encoding="utf-8") as file: + content = file.read() + + count_properties_checked = 0 + + for key in possible_keys: + ranges = get_range_of_a_key_in_model_block(content, dialect, key) + + if ranges: + key_range, value_range = ranges + read_key = read_range_from_file(path, key_range) + assert read_key.lower() == key.lower() + # Value range should be non-empty + read_value = read_range_from_file(path, value_range) + assert len(read_value) > 0 + count_properties_checked += 1 + + assert count_properties_checked > 0 + + # Test that the function works for different kind of value blocks + tests = [ + ("sushi.customers", "name", "sushi.customers"), + ( + "sushi.customers", + "tags", + "(pii, fact)", + ), + ("sushi.customers", "description", "'Sushi customer data'"), + ( + "sushi.customers", + "column_descriptions", + "( customer_id = 'customer_id uniquely identifies customers' )", + ), + ("sushi.customers", "owner", "jen"), + ("sushi.customers", "cron", "'@daily'"), + ] + for model_name, key, value in tests: + model = context.get_model(model_name) + assert model is not None + + dialect = model.dialect + assert dialect is not None + + path = model._path + assert path is not None + + with open(path, "r", encoding="utf-8") as file: + content = file.read() + + ranges = get_range_of_a_key_in_model_block(content, dialect, key) + assert ranges is not None, f"Could not find key '{key}' in model '{model_name}'" + + key_range, value_range = ranges + read_key = read_range_from_file(path, key_range) + assert read_key.lower() == key.lower() + + read_value = read_range_from_file(path, value_range) + assert read_value == value diff --git a/tests/core/state_sync/test_export_import.py 
b/tests/core/state_sync/test_export_import.py new file mode 100644 index 0000000000..769fa2c2fa --- /dev/null +++ b/tests/core/state_sync/test_export_import.py @@ -0,0 +1,621 @@ +import pytest +from pathlib import Path +from sqlmesh.core.state_sync import StateSync, EngineAdapterStateSync, CachingStateSync +from sqlmesh.core.state_sync.export_import import export_state, import_state +from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.core import constants as c +from sqlmesh.cli.project_init import init_example_project +from sqlmesh.core.context import Context +from sqlmesh.core.environment import Environment +from sqlmesh.core.config import Config, GatewayConfig, DuckDBConnectionConfig, ModelDefaultsConfig + +import json + + +@pytest.fixture +def example_project_config(tmp_path: Path) -> Config: + return Config( + gateways={ + "main": GatewayConfig( + connection=DuckDBConnectionConfig(database=str(tmp_path / "warehouse.db")), + state_connection=DuckDBConnectionConfig(database=str(tmp_path / "state.db")), + ) + }, + default_gateway="main", + model_defaults=ModelDefaultsConfig( + dialect="duckdb", + ), + ) + + +@pytest.fixture +def state_sync(tmp_path: Path, example_project_config: Config) -> StateSync: + return EngineAdapterStateSync( + engine_adapter=example_project_config.get_state_connection("main").create_engine_adapter(), # type: ignore + schema=c.SQLMESH, + cache_dir=tmp_path / c.CACHE, + ) + + +def test_export_empty_state(tmp_path: Path, state_sync: StateSync) -> None: + output_file = tmp_path / "state_dump.json" + + # Cannot dump an un-migrated state database + with pytest.raises(SQLMeshError, match=r"Please run a migration"): + export_state(state_sync, output_file) + + state_sync.migrate() + + export_state(state_sync, output_file) + + state = json.loads(output_file.read_text(encoding="utf8")) + + assert "metadata" in state + metadata = state["metadata"] + assert "timestamp" in metadata + assert "file_version" in metadata + assert "importable" in 
metadata + + assert "versions" in state + versions = state["versions"] + assert "schema_version" in versions + assert "sqlglot_version" in versions + assert "sqlmesh_version" in versions + + assert "snapshots" in state + assert isinstance(state["snapshots"], list) + assert len(state["snapshots"]) == 0 + + assert "environments" in state + assert isinstance(state["environments"], dict) + assert len(state["environments"]) == 0 + + +def test_export_entire_project( + tmp_path: Path, example_project_config: Config, state_sync: StateSync +) -> None: + init_example_project(path=tmp_path, engine_type="duckdb") + context = Context(paths=tmp_path, config=example_project_config, state_sync=state_sync) + + # prod + plan = context.plan(auto_apply=True) + assert len(plan.modified_snapshots) > 0 + + # modify full_model + (tmp_path / c.MODELS / "full_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.full_model, + kind FULL, + cron '@daily' + ); + + SELECT + item_id, + COUNT(DISTINCT id) AS num_orders, + '1' as modified + FROM sqlmesh_example.incremental_model + GROUP BY item_id; + """) + + # add a new model + (tmp_path / c.MODELS / "new_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.new_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key id, + auto_restatement_cron '@daily' + ), + cron '@daily' + ); + + SELECT 1 as id; + """) + + # dev + context.load() + context.plan(environment="dev", auto_apply=True, skip_tests=True) + + output_file = tmp_path / "state_dump.json" + export_state(state_sync, output_file) + + state = json.loads(output_file.read_text(encoding="utf8")) + assert "metadata" in state + # full project dumps can always be imported back + assert state["metadata"]["importable"] + + assert "versions" in state + + assert len(state["snapshots"]) > 0 + snapshot_names = [s["name"] for s in state["snapshots"]] + assert len(snapshot_names) == 5 + assert '"warehouse"."sqlmesh_example"."full_model"' in snapshot_names # will be in here twice + assert 
'"warehouse"."sqlmesh_example"."incremental_model"' in snapshot_names + assert '"warehouse"."sqlmesh_example"."seed_model"' in snapshot_names + assert '"warehouse"."sqlmesh_example"."new_model"' in snapshot_names + + assert "prod" in state["environments"] + assert "dev" in state["environments"] + + prod = state["environments"]["prod"]["environment"] + assert len(prod["snapshots"]) == 3 + prod_snapshot_ids = [s.snapshot_id for s in Environment.model_validate(prod).snapshots] + + dev = state["environments"]["dev"]["environment"] + assert len(dev["snapshots"]) == 4 + dev_snapshot_ids = [s.snapshot_id for s in Environment.model_validate(dev).snapshots] + + full_model_id = next(s for s in dev_snapshot_ids if "full_model" in s.name) + incremental_model_id = next(s for s in dev_snapshot_ids if "incremental_model" in s.name) + seed_model_id = next(s for s in dev_snapshot_ids if "seed_model" in s.name) + new_model_id = next(s for s in dev_snapshot_ids if "new_model" in s.name) + + assert incremental_model_id in prod_snapshot_ids + assert seed_model_id in prod_snapshot_ids + assert full_model_id not in prod_snapshot_ids + assert new_model_id not in prod_snapshot_ids + + +def test_export_specific_environment( + tmp_path: Path, example_project_config: Config, state_sync: StateSync +) -> None: + output_file = tmp_path / "state_dump.json" + init_example_project(path=tmp_path, engine_type="duckdb") + context = Context(paths=tmp_path, config=example_project_config, state_sync=state_sync) + + # create prod + context.plan(auto_apply=True) + + with pytest.raises(SQLMeshError, match=r"No such environment"): + export_state(state_sync, output_file, environment_names=["FOO"]) + + # modify full_model + (tmp_path / c.MODELS / "full_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.full_model, + kind FULL, + cron '@daily' + ); + + SELECT + item_id, + COUNT(DISTINCT id) AS num_orders, + '1' as modified + FROM sqlmesh_example.incremental_model + GROUP BY item_id; + """) + + # plan 
dev + context.load() + plan = context.plan(environment="dev", auto_apply=True, skip_tests=True) + assert len(plan.modified_snapshots) == 1 + + # export dev - all models should be included + export_state(state_sync, output_file, environment_names=["dev"]) + + dev_state = json.loads(output_file.read_text(encoding="utf8")) + + assert "metadata" in dev_state + assert "versions" in dev_state + + assert len(dev_state["snapshots"]) == 3 + snapshot_names = [s["name"] for s in dev_state["snapshots"]] + assert any("full_model" in name for name in snapshot_names) + assert any("incremental_model" in name for name in snapshot_names) + assert any("seed_model" in name for name in snapshot_names) + dev_full_model = next(s for s in dev_state["snapshots"] if "full_model" in s["name"]) + + assert len(dev_state["environments"]) == 1 + assert "dev" in dev_state["environments"] + + # this state dump is still importable even though its just a subset + assert dev_state["metadata"]["importable"] + + # export prod - prod full_model should be a different version to dev + export_state(state_sync, output_file, environment_names=["prod"]) + + prod_state = json.loads(output_file.read_text(encoding="utf8")) + snapshot_names = [s["name"] for s in prod_state["snapshots"]] + assert any("full_model" in name for name in snapshot_names) + assert any("incremental_model" in name for name in snapshot_names) + assert any("seed_model" in name for name in snapshot_names) + prod_full_model = next(s for s in prod_state["snapshots"] if "full_model" in s["name"]) + + assert len(prod_state["environments"]) == 1 + assert "prod" in prod_state["environments"] + assert prod_state["metadata"]["importable"] + + assert dev_full_model["fingerprint"]["data_hash"] != prod_full_model["fingerprint"]["data_hash"] + + +def test_export_local_state( + tmp_path: Path, example_project_config: Config, state_sync: StateSync +) -> None: + output_file = tmp_path / "state_dump.json" + init_example_project(path=tmp_path, 
engine_type="duckdb") + context = Context(paths=tmp_path, config=example_project_config, state_sync=state_sync) + + # create prod + context.plan(auto_apply=True) + + # modify full_model - create local change + (tmp_path / c.MODELS / "full_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.full_model, + kind FULL, + cron '@daily' + ); + + SELECT + item_id, + COUNT(DISTINCT id) AS num_orders, + '1' as modified + FROM sqlmesh_example.incremental_model + GROUP BY item_id; + """) + + # add a new model + (tmp_path / c.MODELS / "new_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.new_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key id, + auto_restatement_cron '@daily' + ), + cron '@daily' + ); + + SELECT 1 as id; + """) + + assert len(context.snapshots) == 3 + + context.load() + + assert len(context.snapshots) == 4 + + export_state(state_sync, output_file, context.snapshots) + state = json.loads(output_file.read_text(encoding="utf8")) + assert "metadata" in state + assert "versions" in state + + # this state dump cannot be imported because its just local state + assert not state["metadata"]["importable"] + + # no environments because local state is just snapshots + assert len(state["environments"]) == 0 + + snapshots = state["snapshots"] + assert len(snapshots) == 4 + full_model = next(s for s in snapshots if "full_model" in s["name"]) + new_model = next(s for s in snapshots if "new_model" in s["name"]) + + assert "'1' as modified" in full_model["node"]["query"]["sql"] + assert "SELECT 1 as id" in new_model["node"]["query"]["sql"] + + +def test_import_invalid_file(tmp_path: Path, state_sync: StateSync) -> None: + state_file = tmp_path / "state.json" + state_file.write_text("invalid json file") + + with pytest.raises(Exception, match=r"Invalid JSON character"): + import_state(state_sync, state_file) + + state_file.write_text("[]") + with pytest.raises(SQLMeshError, match=r"Expected JSON object"): + import_state(state_sync, state_file) + + 
state_file.write_text("{}") + with pytest.raises(SQLMeshError, match=r"Expecting a 'metadata' key"): + import_state(state_sync, state_file) + + state_file.write_text('{ "metadata": [] }') + with pytest.raises(SQLMeshError, match=r"Expecting the 'metadata' key to contain an object"): + import_state(state_sync, state_file) + + state_file.write_text('{ "metadata": {} }') + with pytest.raises(SQLMeshError, match=r"Unable to determine state file format version"): + import_state(state_sync, state_file) + + state_file.write_text('{ "metadata": { "file_version": "blah" } }') + with pytest.raises(SQLMeshError, match=r"Unable to parse state file format version"): + import_state(state_sync, state_file) + + state_file.write_text('{ "metadata": { "file_version": 1, "importable": false } }') + with pytest.raises(SQLMeshError, match=r"not importable"): + import_state(state_sync, state_file) + + +def test_import_from_older_version_export_fails(tmp_path: Path, state_sync: StateSync) -> None: + state_sync.migrate() + current_version = state_sync.get_versions() + + major, minor = current_version.minor_sqlmesh_version + older_version = current_version.copy(update=dict(sqlmesh_version=f"{major}.{minor - 1}.0")) + + assert older_version.minor_sqlmesh_version < current_version.minor_sqlmesh_version + + state_file = tmp_path / "state.json" + state_versions = older_version.model_dump(mode="json") + state_file.write_text( + json.dumps( + { + "metadata": { + "timestamp": "2024-01-01 00:00:00", + "file_version": 1, + "importable": True, + }, + "versions": state_versions, + } + ) + ) + + with pytest.raises(SQLMeshError, match=r"SQLMesh version mismatch"): + import_state(state_sync, state_file) + + +def test_import_from_newer_version_export_fails(tmp_path: Path, state_sync: StateSync) -> None: + state_sync.migrate() + current_version = state_sync.get_versions() + + major, minor = current_version.minor_sqlmesh_version + newer_version = 
current_version.copy(update=dict(sqlmesh_version=f"{major}.{minor + 1}.0")) + + assert newer_version.minor_sqlmesh_version > current_version.minor_sqlmesh_version + + state_file = tmp_path / "state.json" + state_versions = newer_version.model_dump(mode="json") + state_file.write_text( + json.dumps( + { + "versions": state_versions, + "metadata": { + "timestamp": "2024-01-01 00:00:00", + "file_version": 1, + "importable": True, + }, + } + ) + ) + + with pytest.raises(SQLMeshError, match=r"SQLMesh version mismatch"): + import_state(state_sync, state_file) + + +def test_import_local_state_fails( + tmp_path: Path, example_project_config: Config, state_sync: StateSync +) -> None: + output_file = tmp_path / "state_dump.json" + init_example_project(path=tmp_path, engine_type="duckdb") + context = Context(paths=tmp_path, config=example_project_config, state_sync=state_sync) + + export_state(state_sync, output_file, context.snapshots) + state = json.loads(output_file.read_text(encoding="utf8")) + assert len(state["snapshots"]) == 3 + + with pytest.raises(SQLMeshError, match=r"not importable"): + import_state(state_sync, output_file) + + +def test_import_partial( + tmp_path: Path, example_project_config: Config, state_sync: StateSync +) -> None: + output_file = tmp_path / "state_dump.json" + init_example_project(path=tmp_path, engine_type="duckdb") + context = Context(paths=tmp_path, config=example_project_config, state_sync=state_sync) + + # create prod + context.plan(auto_apply=True) + + # add a new model + (tmp_path / c.MODELS / "new_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.new_model, + kind FULL, + cron '@daily' + ); + + SELECT 1 as id; + """) + + # create dev + context.load() + context.plan(environment="dev", auto_apply=True, skip_tests=True) + + # export just dev + export_state(state_sync, output_file, environment_names=["dev"]) + + state = json.loads(output_file.read_text(encoding="utf8")) + # mess with the file to rename "dev" to "dev2" + dev = 
state["environments"].pop("dev") + dev["environment"]["name"] = "dev2" + state["environments"]["dev2"] = dev + + assert list(state["environments"].keys()) == ["dev2"] + output_file.write_text(json.dumps(state), encoding="utf8") + + # import "dev2" + import_state(state_sync, output_file, clear=False) + + # StateSync should have "prod", "dev" and "dev2". + assert sorted(list(env.name for env in state_sync.get_environments_summary())) == [ + "dev", + "dev2", + "prod", + ] + + assert not context.plan(environment="dev", skip_tests=True).has_changes + assert not context.plan(environment="dev2", skip_tests=True).has_changes + assert context.plan( + environment="prod", skip_tests=True + ).has_changes # prod has changes the 'new_model' model hasnt been applied + + +def test_roundtrip(tmp_path: Path, example_project_config: Config, state_sync: StateSync) -> None: + state_file = tmp_path / "state_dump.json" + + init_example_project(path=tmp_path, engine_type="duckdb") + context = Context(paths=tmp_path, config=example_project_config, state_sync=state_sync) + + # populate initial state + plan = context.plan(auto_apply=True) + assert plan.has_changes + + # plan again to prove no changes + plan = context.plan(auto_apply=True) + assert not plan.has_changes + + export_state(state_sync, state_file) + assert len(state_file.read_text()) > 0 + + # destroy state + assert isinstance(state_sync, EngineAdapterStateSync) + state_sync.engine_adapter.drop_schema("sqlmesh", cascade=True) + + # state was destroyed, plan should have changes + state_sync.migrate() + plan = context.plan() + assert plan.has_changes + + # load in state dump + import_state(state_sync, state_file) + + # plan should have no changes now our state is back + plan = context.plan() + assert not plan.has_changes + + # add a new model + (tmp_path / c.MODELS / "new_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.new_model, + kind FULL, + cron '@daily' + ); + + SELECT 1 as id; + """) + + context.load() + plan = 
context.plan(environment="dev", auto_apply=True) + assert plan.has_changes + + plan = context.plan(environment="dev") + assert not plan.has_changes + + # dump new state that contains the 'dev' environment + export_state(state_sync, state_file) + + # show state destroyed + state_sync.engine_adapter.drop_schema("sqlmesh", cascade=True) + with pytest.raises(SQLMeshError, match=r"Please run a migration"): + state_sync.get_versions(validate=True) + + state_sync.migrate() + import_state(state_sync, state_file) + + # should be no changes in dev + assert not context.plan(environment="dev").has_changes + + # prod should show a change for adding 'new_model' + prod_plan = context.plan(environment="prod") + assert prod_plan.new_snapshots == [] + assert len(prod_plan.modified_snapshots) == 1 + assert "new_model" in list(prod_plan.modified_snapshots.values())[0].name + + +def test_roundtrip_includes_auto_restatements( + tmp_path: Path, example_project_config: Config, state_sync: StateSync +) -> None: + init_example_project(path=tmp_path, engine_type="duckdb") + + # add a model with auto restatements + (tmp_path / c.MODELS / "new_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.new_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key id, + auto_restatement_cron '@daily' + ), + cron '@daily' + ); + + SELECT 1 as id; + """) + + context = Context(paths=tmp_path, config=example_project_config, state_sync=state_sync) + context.plan(auto_apply=True) + + # dump state + output_file = tmp_path / "state_dump.json" + export_state(state_sync, output_file) + state = json.loads(output_file.read_text(encoding="utf8")) + + snapshots = state["snapshots"] + assert len(snapshots) == 4 + + # auto restatements only work after a cadence run + new_model_snapshot = next(s for s in snapshots if "new_model" in s["name"]) + assert "next_auto_restatement_ts" not in new_model_snapshot + + # trigger cadence run and re-dump show auto restatement dumped + context.run() + + export_state(state_sync, 
output_file) + state = json.loads(output_file.read_text()) + + new_model_snapshot = next(s for s in state["snapshots"] if "new_model" in s["name"]) + assert new_model_snapshot["next_auto_restatement_ts"] > 0 + + # import the state again and run a plan to show there is no changes / the auto restatement was imported + import_state(state_sync, output_file) + + plan = context.plan(skip_tests=True) + assert not plan.has_changes + + +def test_roundtrip_includes_environment_statements(tmp_path: Path) -> None: + config = Config( + gateways={ + "main": GatewayConfig( + connection=DuckDBConnectionConfig(database=str(tmp_path / "warehouse.db")), + state_connection=DuckDBConnectionConfig(database=str(tmp_path / "state.db")), + ) + }, + default_gateway="main", + model_defaults=ModelDefaultsConfig( + dialect="duckdb", + ), + before_all=["select 1 as before_all"], + after_all=["select 2 as after_all"], + ) + + context = Context(paths=tmp_path, config=config) + context.plan(auto_apply=True) + + state_file = tmp_path / "state_dump.json" + context.export_state(state_file) + + environments = json.loads(state_file.read_text(encoding="utf8"))["environments"] + + assert environments["prod"]["statements"][0]["before_all"][0] == "select 1 as before_all" + assert environments["prod"]["statements"][0]["after_all"][0] == "select 2 as after_all" + + assert not context.plan().has_changes + + state_sync = context.state_sync + assert isinstance(state_sync, CachingStateSync) + assert isinstance(state_sync.state_sync, EngineAdapterStateSync) + + # show state destroyed + state_sync.state_sync.engine_adapter.drop_schema("sqlmesh", cascade=True) # type: ignore + with pytest.raises(SQLMeshError, match=r"Please run a migration"): + state_sync.get_versions(validate=True) + + state_sync.migrate() + import_state(state_sync, state_file) + + assert not context.plan().has_changes + + environment_statements = state_sync.get_environment_statements("prod") + assert len(environment_statements) == 1 + assert 
environment_statements[0].before_all[0] == "select 1 as before_all" + assert environment_statements[0].after_all[0] == "select 2 as after_all" diff --git a/tests/core/state_sync/test_state_sync.py b/tests/core/state_sync/test_state_sync.py new file mode 100644 index 0000000000..88e168c216 --- /dev/null +++ b/tests/core/state_sync/test_state_sync.py @@ -0,0 +1,4051 @@ +import json +import logging +import re +import typing as t +from unittest.mock import call, patch + +import duckdb # noqa: TID253 +import pandas as pd # noqa: TID253 +import pytest +import time_machine +from pytest_mock.plugin import MockerFixture +from sqlglot import exp + +from sqlmesh.core import constants as c +from sqlmesh.core.config import EnvironmentSuffixTarget +from sqlmesh.core.dialect import parse_one +from sqlmesh.core.engine_adapter import create_engine_adapter +from sqlmesh.core.environment import Environment, EnvironmentStatements +from sqlmesh.core.model import ( + FullKind, + IncrementalByTimeRangeKind, + Seed, + SeedKind, + SeedModel, + SqlModel, +) +from sqlmesh.core.snapshot import ( + Snapshot, + SnapshotChangeCategory, + SnapshotId, + SnapshotIntervals, + SnapshotNameVersion, + SnapshotTableCleanupTask, + missing_intervals, +) +from sqlmesh.core.state_sync import ( + CachingStateSync, + EngineAdapterStateSync, +) +from sqlmesh.core.state_sync.base import ( + SCHEMA_VERSION, + SQLGLOT_VERSION, + Versions, +) +from sqlmesh.core.state_sync.common import ( + ExpiredBatchRange, + LimitBoundary, + PromotionResult, + RowBoundary, +) +from sqlmesh.utils.date import now_timestamp, to_datetime, to_timestamp +from sqlmesh.utils.errors import SQLMeshError, StateMigrationError + +pytestmark = pytest.mark.slow + + +def _get_cleanup_tasks( + state_sync: EngineAdapterStateSync, + *, + limit: int = 1000, + ignore_ttl: bool = False, +) -> t.List[SnapshotTableCleanupTask]: + batch = state_sync.get_expired_snapshots( + ignore_ttl=ignore_ttl, + 
batch_range=ExpiredBatchRange.init_batch_range(batch_size=limit), + ) + return [] if batch is None else batch.cleanup_tasks + + +@pytest.fixture +def state_sync(duck_conn, tmp_path): + state_sync = EngineAdapterStateSync( + create_engine_adapter(lambda: duck_conn, "duckdb"), + schema=c.SQLMESH, + cache_dir=tmp_path / c.CACHE, + ) + state_sync.migrate() + return state_sync + + +@pytest.fixture +def snapshots(make_snapshot: t.Callable) -> t.List[Snapshot]: + return [ + make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, ds"), + ), + version="a", + ), + make_snapshot( + SqlModel( + name="b", + query=parse_one("select 2, ds"), + ), + version="b", + ), + ] + + +def compare_snapshot_intervals(x: SnapshotIntervals) -> str: + return x.identifier or "" + + +def promote_snapshots( + state_sync: EngineAdapterStateSync, + snapshots: t.List[Snapshot], + environment: str, + no_gaps: bool = False, + no_gaps_snapshot_names: t.Optional[t.Set[str]] = None, + environment_suffix_target: EnvironmentSuffixTarget = EnvironmentSuffixTarget.SCHEMA, + environment_catalog_mapping: t.Optional[t.Dict[re.Pattern, str]] = None, +) -> PromotionResult: + env = Environment.from_environment_catalog_mapping( + environment_catalog_mapping or {}, + name=environment, + suffix_target=environment_suffix_target, + snapshots=[snapshot.table_info for snapshot in snapshots], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + ) + return state_sync.promote( + env, no_gaps_snapshot_names=no_gaps_snapshot_names if no_gaps else set() + ) + + +def delete_versions(state_sync: EngineAdapterStateSync) -> None: + state_sync.engine_adapter.drop_table(state_sync.version_state.versions_table) + + +def test_push_snapshots( + state_sync: EngineAdapterStateSync, + make_snapshot: t.Callable, +) -> None: + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, ds"), + ) + ) + snapshot_b = make_snapshot( + SqlModel( + name="b", 
+ query=parse_one("select 2, ds"), + ) + ) + + with pytest.raises( + SQLMeshError, + match=r".*has not been versioned.*", + ): + state_sync.push_snapshots([snapshot_a, snapshot_b]) + + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + snapshot_b.version = "2" + + state_sync.push_snapshots([snapshot_a, snapshot_b]) + + assert state_sync.get_snapshots([snapshot_a.snapshot_id, snapshot_b.snapshot_id]) == { + snapshot_a.snapshot_id: snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + } + + logger = logging.getLogger("sqlmesh.core.state_sync.db.facade") + with patch.object(logger, "error") as mock_logger: + state_sync.push_snapshots([snapshot_a]) + assert str({snapshot_a.snapshot_id}) == mock_logger.call_args[0][1] + state_sync.push_snapshots([snapshot_a, snapshot_b]) + assert str({snapshot_a.snapshot_id, snapshot_b.snapshot_id}) == mock_logger.call_args[0][1] + + # test serialization + state_sync.push_snapshots( + [ + make_snapshot( + SqlModel( + name="a", + kind=FullKind(), + query=parse_one( + """ + select 'x' + ' ' as y, + "z" + '\' as z, + """ + ), + ), + version="1", + ) + ] + ) + + +def test_duplicates(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable) -> None: + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, ds"), + ), + version="1", + ) + snapshot_b = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, ds"), + ), + version="1", + ) + snapshot_c = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, ds"), + ), + version="1", + ) + snapshot_b.updated_ts = snapshot_a.updated_ts + 1 + snapshot_c.updated_ts = 0 + state_sync.push_snapshots([snapshot_a]) + state_sync.snapshot_state.push_snapshots([snapshot_a]) + state_sync.snapshot_state.push_snapshots([snapshot_b]) + state_sync.snapshot_state.push_snapshots([snapshot_c]) + state_sync.snapshot_state.clear_cache() + assert ( + 
state_sync.get_snapshots([snapshot_a])[snapshot_a.snapshot_id].updated_ts + == snapshot_b.updated_ts + ) + + +def test_snapshots_exists(state_sync: EngineAdapterStateSync, snapshots: t.List[Snapshot]) -> None: + state_sync.push_snapshots(snapshots) + snapshot_ids = {snapshot.snapshot_id for snapshot in snapshots} + assert state_sync.snapshots_exist(snapshot_ids) == snapshot_ids + + +@pytest.fixture +def get_snapshot_intervals(state_sync) -> t.Callable[[Snapshot], t.Optional[SnapshotIntervals]]: + def _get_snapshot_intervals(snapshot: Snapshot) -> t.Optional[SnapshotIntervals]: + intervals = state_sync.interval_state.get_snapshot_intervals([snapshot]) + return intervals[0] if intervals else None + + return _get_snapshot_intervals + + +def test_add_interval( + state_sync: EngineAdapterStateSync, + make_snapshot: t.Callable, + get_snapshot_intervals: t.Callable, +) -> None: + snapshot = make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one("select 1, ds"), + ), + version="a", + ) + + state_sync.push_snapshots([snapshot]) + + state_sync.add_interval(snapshot, "2020-01-01", "20200101") + assert get_snapshot_intervals(snapshot).intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-02")), + ] + + state_sync.add_interval(snapshot, "20200101", to_datetime("2020-01-04")) + assert get_snapshot_intervals(snapshot).intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-04")), + ] + + state_sync.add_interval(snapshot, to_datetime("2020-01-05"), "2020-01-10") + assert get_snapshot_intervals(snapshot).intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-04")), + (to_timestamp("2020-01-05"), to_timestamp("2020-01-11")), + ] + + snapshot.change_category = SnapshotChangeCategory.BREAKING + snapshot.forward_only = True + state_sync.add_interval(snapshot, to_datetime("2020-01-16"), "2020-01-20", is_dev=True) + intervals = get_snapshot_intervals(snapshot) + assert intervals.intervals == [ + (to_timestamp("2020-01-01"), 
to_timestamp("2020-01-04")), + (to_timestamp("2020-01-05"), to_timestamp("2020-01-11")), + ] + assert intervals.dev_intervals == [ + (to_timestamp("2020-01-16"), to_timestamp("2020-01-21")), + ] + + +def test_add_interval_partial( + state_sync: EngineAdapterStateSync, + make_snapshot: t.Callable, + get_snapshot_intervals: t.Callable, +) -> None: + snapshot = make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one("select 1, ds"), + ), + version="a", + ) + + state_sync.push_snapshots([snapshot]) + + state_sync.add_interval(snapshot, "2023-01-01", to_timestamp("2023-01-01") + 1000) + assert get_snapshot_intervals(snapshot) is None + + state_sync.add_interval(snapshot, "2023-01-01", to_timestamp("2023-01-02") + 1000) + assert get_snapshot_intervals(snapshot).intervals == [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + ] + + +def test_remove_interval(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable) -> None: + snapshot_a = make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one("select 1, ds"), + ), + version="a", + ) + snapshot_b = make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one("select 2::INT, '2022-01-01'::TEXT AS ds"), + ), + version="a", + ) + state_sync.push_snapshots([snapshot_a, snapshot_b]) + state_sync.add_interval(snapshot_a, "2020-01-01", "2020-01-10") + state_sync.add_interval(snapshot_b, "2020-01-11", "2020-01-30") + + num_of_removals = 4 + for _ in range(num_of_removals): + state_sync.remove_intervals( + [(snapshot_a, snapshot_a.inclusive_exclusive("2020-01-15", "2020-01-17"))], + remove_shared_versions=True, + ) + + remove_records_count = state_sync.engine_adapter.fetchone( + "SELECT COUNT(*) FROM sqlmesh._intervals WHERE name = '\"a\"' AND version = 'a' AND is_removed" + )[0] # type: ignore + assert remove_records_count == num_of_removals * 2 # 2 * snapshots + + snapshots = state_sync.get_snapshots([snapshot_a, snapshot_b]) + + assert 
snapshots[snapshot_a.snapshot_id].intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-15")), + (to_timestamp("2020-01-18"), to_timestamp("2020-01-31")), + ] + assert snapshots[snapshot_b.snapshot_id].intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-15")), + (to_timestamp("2020-01-18"), to_timestamp("2020-01-31")), + ] + + +def test_remove_interval_missing_snapshot( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +) -> None: + snapshot_a = make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one("select 1, ds"), + ), + version="a", + ) + snapshot_b = make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one("select 2::INT, '2022-01-01'::TEXT AS ds"), + ), + version="a", + ) + state_sync.push_snapshots([snapshot_a, snapshot_b]) + state_sync.add_interval(snapshot_a, "2020-01-01", "2020-01-10") + state_sync.add_interval(snapshot_b, "2020-01-11", "2020-01-30") + # Remove snapshot b in order to test the scenario where it is missing + state_sync.delete_snapshots([snapshot_b.snapshot_id]) + + snapshots = state_sync.get_snapshots([snapshot_a, snapshot_b]) + assert len(snapshots) == 1 + assert snapshots[snapshot_a.snapshot_id].intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-31")), + ] + + state_sync.remove_intervals( + [(snapshot_a, snapshot_a.inclusive_exclusive("2020-01-15", "2020-01-17"))], + remove_shared_versions=True, + ) + + snapshots = state_sync.get_snapshots([snapshot_a, snapshot_b]) + assert len(snapshots) == 1 + assert snapshots[snapshot_a.snapshot_id].intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-15")), + (to_timestamp("2020-01-18"), to_timestamp("2020-01-31")), + ] + + +def test_refresh_snapshot_intervals( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +) -> None: + snapshot = make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one("select 1, ds"), + ), + version="a", + ) + + 
state_sync.push_snapshots([snapshot]) + state_sync.add_interval(snapshot, "2023-01-01", "2023-01-01") + assert not snapshot.intervals + + state_sync.refresh_snapshot_intervals([snapshot]) + assert snapshot.intervals == [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] + + +def test_get_snapshot_intervals( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable, get_snapshot_intervals +) -> None: + state_sync.interval_state.SNAPSHOT_BATCH_SIZE = 1 + + snapshot_a = make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one("select 1, ds"), + ), + version="a", + ) + + state_sync.push_snapshots([snapshot_a]) + state_sync.add_interval(snapshot_a, "2020-01-01", "2020-01-01") + + snapshot_b = make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one("select 2, ds"), + ), + version="a", + ) + state_sync.push_snapshots([snapshot_b]) + + snapshot_c = make_snapshot( + SqlModel( + name="c", + cron="@daily", + query=parse_one("select 3, ds"), + ), + version="c", + ) + state_sync.push_snapshots([snapshot_c]) + state_sync.add_interval(snapshot_c, "2020-01-03", "2020-01-03") + + a_intervals = get_snapshot_intervals(snapshot_a) + c_intervals = get_snapshot_intervals(snapshot_c) + assert a_intervals.intervals == [(to_timestamp("2020-01-01"), to_timestamp("2020-01-02"))] + assert c_intervals.intervals == [(to_timestamp("2020-01-03"), to_timestamp("2020-01-04"))] + + +def test_compact_intervals( + state_sync: EngineAdapterStateSync, + make_snapshot: t.Callable, + get_snapshot_intervals: t.Callable, +) -> None: + snapshot = make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one("select 1, ds"), + ), + version="a", + ) + + state_sync.push_snapshots([snapshot]) + + state_sync.add_interval(snapshot, "2020-01-01", "2020-01-10") + state_sync.add_interval(snapshot, "2020-01-11", "2020-01-15") + state_sync.remove_intervals( + [(snapshot, snapshot.inclusive_exclusive("2020-01-05", "2020-01-12"))] + ) + 
state_sync.add_interval(snapshot, "2020-01-12", "2020-01-16") + state_sync.remove_intervals( + [(snapshot, snapshot.inclusive_exclusive("2020-01-14", "2020-01-16"))] + ) + + expected_intervals = [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-05")), + (to_timestamp("2020-01-12"), to_timestamp("2020-01-14")), + ] + + assert get_snapshot_intervals(snapshot).intervals == expected_intervals + + state_sync.compact_intervals() + assert get_snapshot_intervals(snapshot).intervals == expected_intervals + + # Make sure compaction is idempotent. + state_sync.compact_intervals() + assert get_snapshot_intervals(snapshot).intervals == expected_intervals + + +def test_compact_intervals_delete_batches( + state_sync: EngineAdapterStateSync, + make_snapshot: t.Callable, + mocker: MockerFixture, +) -> None: + snapshot = make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one("select 1, ds"), + ), + version="a", + ) + + delete_from_mock = mocker.patch.object(state_sync.engine_adapter, "delete_from") + state_sync.interval_state.INTERVAL_BATCH_SIZE = 2 + + state_sync.push_snapshots([snapshot]) + + state_sync.add_interval(snapshot, "2020-01-01", "2020-01-11") + state_sync.add_interval(snapshot, "2020-01-01", "2020-01-12") + state_sync.add_interval(snapshot, "2020-01-01", "2020-01-13") + state_sync.add_interval(snapshot, "2020-01-01", "2020-01-14") + state_sync.add_interval(snapshot, "2020-01-01", "2020-01-15") + + state_sync.compact_intervals() + + delete_from_mock.assert_has_calls( + [call(state_sync.interval_state.intervals_table, mocker.ANY)] * 3 + ) + + +def test_promote_snapshots(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, ds"), + ), + ) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + + snapshot_b_old = make_snapshot( + SqlModel( + name="b", + kind=FullKind(), + query=parse_one("select 2 from a"), + ), + nodes={"a": snapshot_a.model}, + ) + 
snapshot_b_old.categorize_as(SnapshotChangeCategory.BREAKING) + + snapshot_b = make_snapshot( + SqlModel( + name="b", + kind=FullKind(), + query=parse_one("select * from a"), + ), + nodes={"a": snapshot_a.model}, + ) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + + snapshot_c = make_snapshot( + SqlModel( + name="c", + query=parse_one("select 3, ds"), + ), + ) + snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING) + + with pytest.raises( + SQLMeshError, + match=r"Missing snapshots.*", + ): + promote_snapshots(state_sync, [snapshot_a], "prod") + + state_sync.push_snapshots([snapshot_a, snapshot_b_old, snapshot_b, snapshot_c]) + + promotion_result = promote_snapshots(state_sync, [snapshot_a, snapshot_b_old], "prod") + + assert set(promotion_result.added) == set([snapshot_a.table_info, snapshot_b_old.table_info]) + assert not promotion_result.removed + assert not promotion_result.removed_environment_naming_info + promotion_result = promote_snapshots( + state_sync, + [snapshot_a, snapshot_b_old, snapshot_c], + "prod", + ) + assert set(promotion_result.added) == { + snapshot_a.table_info, + snapshot_b_old.table_info, + snapshot_c.table_info, + } + assert not promotion_result.removed + assert not promotion_result.removed_environment_naming_info + + prev_snapshot_b_old_updated_ts = snapshot_b_old.updated_ts + prev_snapshot_c_updated_ts = snapshot_c.updated_ts + + promotion_result = promote_snapshots( + state_sync, + [snapshot_a, snapshot_b], + "prod", + ) + assert set(promotion_result.added) == {snapshot_a.table_info, snapshot_b.table_info} + assert set(promotion_result.removed) == {snapshot_c.table_info} + assert promotion_result.removed_environment_naming_info + assert promotion_result.removed_environment_naming_info.suffix_target.is_schema + assert ( + state_sync.get_snapshots([snapshot_c])[snapshot_c.snapshot_id].updated_ts + > prev_snapshot_c_updated_ts + ) + assert ( + 
state_sync.get_snapshots([snapshot_b_old])[snapshot_b_old.snapshot_id].updated_ts + > prev_snapshot_b_old_updated_ts + ) + + snapshot_d = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 2, ds"), + ), + ) + snapshot_d.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot_d]) + promotion_result = promote_snapshots(state_sync, [snapshot_d], "prod") + assert set(promotion_result.added) == {snapshot_d.table_info} + assert set(promotion_result.removed) == {snapshot_b.table_info} + assert promotion_result.removed_environment_naming_info + assert promotion_result.removed_environment_naming_info.suffix_target.is_schema + + +def test_promote_snapshots_suffix_change( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, ds"), + ), + ) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + + snapshot_b = make_snapshot( + SqlModel( + name="b", + kind=FullKind(), + query=parse_one("select * from a"), + ), + nodes={"a": snapshot_a.model}, + ) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot_a, snapshot_b]) + + promotion_result = promote_snapshots( + state_sync, + [snapshot_a, snapshot_b], + "prod", + environment_suffix_target=EnvironmentSuffixTarget.TABLE, + ) + + assert set(promotion_result.added) == {snapshot_a.table_info, snapshot_b.table_info} + assert not promotion_result.removed + assert not promotion_result.removed_environment_naming_info + + snapshot_c = make_snapshot( + SqlModel( + name="c", + query=parse_one("select 3, ds"), + ), + ) + snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot_c]) + + promotion_result = promote_snapshots( + state_sync, + [snapshot_b, snapshot_c], + "prod", + environment_suffix_target=EnvironmentSuffixTarget.SCHEMA, + ) + + # We still only add the snapshots that are included in the promotion + 
assert set(promotion_result.added) == {snapshot_b.table_info, snapshot_c.table_info} + # B does not get removed because the suffix target change doesn't affect it due to running in prod. + assert set(promotion_result.removed) == {snapshot_a.table_info} + # Make sure the removed suffix target is correctly seen as table + assert promotion_result.removed_environment_naming_info is not None + assert promotion_result.removed_environment_naming_info.suffix_target.is_table + + promotion_result = promote_snapshots( + state_sync, + [snapshot_b, snapshot_c], + "dev", + environment_suffix_target=EnvironmentSuffixTarget.SCHEMA, + ) + + # We still only add the snapshots that are included in the promotion + assert set(promotion_result.added) == {snapshot_b.table_info, snapshot_c.table_info} + assert len(promotion_result.removed) == 0 + assert promotion_result.removed_environment_naming_info is None + + promotion_result = promote_snapshots( + state_sync, + [snapshot_b, snapshot_c], + "dev", + environment_suffix_target=EnvironmentSuffixTarget.TABLE, + ) + + # All snapshots are promoted due to suffix target change + assert set(promotion_result.added) == { + snapshot_b.table_info, + snapshot_c.table_info, + } + # All snapshots are removed due to suffix target change + assert set(promotion_result.removed) == { + snapshot_b.table_info, + snapshot_c.table_info, + } + assert promotion_result.removed_environment_naming_info is not None + assert promotion_result.removed_environment_naming_info.suffix_target.is_schema + + +def test_promote_snapshots_catalog_name_override_change( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + snapshot_a = make_snapshot( + SqlModel( + name="catalog1.schema.a", + query=parse_one("select 1, ds"), + ), + ) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + + snapshot_b = make_snapshot( + SqlModel( + name="catalog1.schema.b", + kind=FullKind(), + query=parse_one("select * from a"), + ), + nodes={"a": snapshot_a.model}, + ) + 
snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + + snapshot_c = make_snapshot( + SqlModel( + name="catalog2.schema.c", + kind=FullKind(), + query=parse_one("select * from a"), + ), + nodes={"a": snapshot_a.model}, + ) + snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot_a, snapshot_b, snapshot_c]) + + promotion_result = promote_snapshots( + state_sync, + [snapshot_a, snapshot_b, snapshot_c], + "prod", + environment_catalog_mapping={}, + ) + + assert set(promotion_result.added) == { + snapshot_a.table_info, + snapshot_b.table_info, + snapshot_c.table_info, + } + assert not promotion_result.removed + assert not promotion_result.removed_environment_naming_info + + snapshot_d = make_snapshot( + SqlModel( + name="catalog1.schema.d", + query=parse_one("select 3, ds"), + ), + ) + snapshot_d.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot_d]) + + promotion_result = promote_snapshots( + state_sync, + [snapshot_b, snapshot_c, snapshot_d], + "prod", + environment_catalog_mapping={ + re.compile("^prod$"): "catalog1", + }, + ) + + # We still only add the snapshots that are included in the promotion which means removing A + assert set(promotion_result.added) == { + snapshot_b.table_info, + snapshot_c.table_info, + snapshot_d.table_info, + } + # C is removed because of the catalog change. The new one will be created in the new catalog. + # B is not removed because it's catalog did not change and therefore removing would actually result + # in dropping what we just added. + # A is removed because it was explicitly removed from the promotion. 
+ assert set(promotion_result.removed) == {snapshot_a.table_info, snapshot_c.table_info} + # Make sure the removed suffix target correctly has the old catalog name set + assert promotion_result.removed_environment_naming_info + assert promotion_result.removed_environment_naming_info.catalog_name_override is None + + promotion_result = promote_snapshots( + state_sync, + [snapshot_b, snapshot_c, snapshot_d], + "prod", + environment_catalog_mapping={ + re.compile("^prod$"): "catalog2", + }, + ) + + # All are added since their catalog was changed + assert set(promotion_result.added) == { + snapshot_b.table_info, + snapshot_c.table_info, + snapshot_d.table_info, + } + # All are removed since there were moved from their old catalog location + # Note that C has a catalog set in the model definition of `catalog2` which is what we moved to so you might think + # it shouldn't be removed, but its actual catalog was `catalog1` because of the previous override so therefore + # it should be removed from `catalog1`. 
+ assert set(promotion_result.removed) == { + snapshot_b.table_info, + snapshot_c.table_info, + snapshot_d.table_info, + } + # Make sure the removed suffix target correctly has the old catalog name set + assert promotion_result.removed_environment_naming_info + assert promotion_result.removed_environment_naming_info.catalog_name_override == "catalog1" + + +def test_promote_snapshots_parent_plan_id_mismatch( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, ds"), + ), + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot]) + promote_snapshots(state_sync, [snapshot], "prod") + + new_environment = Environment( + name="prod", + snapshots=[snapshot.table_info], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="new_plan_id", + previous_plan_id="test_plan_id", + ) + + stale_new_environment = Environment( + name="prod", + snapshots=[snapshot.table_info], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="stale_new_plan_id", + previous_plan_id="test_plan_id", + ) + + state_sync.promote(new_environment) + + with pytest.raises( + SQLMeshError, + match=re.escape( + "Another plan (new_plan_id) was applied to the target environment 'prod' while your current plan (stale_new_plan_id) was still in progress, interrupting it. Please re-apply your plan to resolve this error." 
+ ), + ): + state_sync.promote(stale_new_environment) + + +@pytest.mark.parametrize("environment_name", ["dev", "prod"]) +def test_promote_environment_expired( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable, environment_name: str +): + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, ds"), + ), + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot]) + promote_snapshots(state_sync, [snapshot], "dev") + state_sync.finalize(state_sync.get_environment("dev")) + state_sync.invalidate_environment("dev") + + new_environment = Environment( + name=environment_name, + snapshots=[snapshot.table_info], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="new_plan_id", + previous_plan_id=None, # No previous plan ID since it's technically a new environment + expiration_ts=now_timestamp() + 3600, + ) + assert new_environment.expiration_ts + + # This call shouldn't fail. + promotion_result = state_sync.promote(new_environment) + assert promotion_result.added == [snapshot.table_info] + assert promotion_result.removed == [] + assert promotion_result.removed_environment_naming_info is None + + state_sync.finalize(new_environment) + + new_environment.previous_plan_id = new_environment.plan_id + new_environment.plan_id = "another_plan_id" + promotion_result = state_sync.promote(new_environment) + + # Should be empty since the environment is no longer expired and nothing has changed + assert promotion_result.removed == [] + assert promotion_result.removed_environment_naming_info is None + if environment_name == "prod": + assert promotion_result.added == [] + else: + # We should always recreate views in dev environments + assert promotion_result.added == [snapshot.table_info] + + +def test_promote_snapshots_no_gaps(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): + model = SqlModel( + name="a", + query=parse_one("select 1, ds"), + 
kind=IncrementalByTimeRangeKind(time_column="ds"), + start="2022-01-01", + ) + + snapshot = make_snapshot(model, version="a") + snapshot.change_category = SnapshotChangeCategory.BREAKING + state_sync.push_snapshots([snapshot]) + state_sync.add_interval(snapshot, "2022-01-01", "2022-01-02") + promote_snapshots(state_sync, [snapshot], "prod", no_gaps=True) + + new_snapshot_same_version = make_snapshot(model, version="a") + new_snapshot_same_version.change_category = SnapshotChangeCategory.INDIRECT_NON_BREAKING + new_snapshot_same_version.fingerprint = snapshot.fingerprint.copy( + update={"data_hash": "new_snapshot_same_version"} + ) + state_sync.push_snapshots([new_snapshot_same_version]) + state_sync.add_interval(new_snapshot_same_version, "2022-01-03", "2022-01-03") + promote_snapshots(state_sync, [new_snapshot_same_version], "prod", no_gaps=True) + + new_snapshot_missing_interval = make_snapshot(model, version="b") + new_snapshot_missing_interval.change_category = SnapshotChangeCategory.BREAKING + new_snapshot_missing_interval.fingerprint = snapshot.fingerprint.copy( + update={"data_hash": "new_snapshot_missing_interval"} + ) + state_sync.push_snapshots([new_snapshot_missing_interval]) + state_sync.add_interval(new_snapshot_missing_interval, "2022-01-01", "2022-01-02") + with pytest.raises( + SQLMeshError, + match=r".*Detected missing intervals for model .*, interrupting your current plan. 
Please re-apply your plan to resolve this error.*", + ): + promote_snapshots(state_sync, [new_snapshot_missing_interval], "prod", no_gaps=True) + + new_snapshot_same_interval = make_snapshot(model, version="c") + new_snapshot_same_interval.change_category = SnapshotChangeCategory.BREAKING + new_snapshot_same_interval.fingerprint = snapshot.fingerprint.copy( + update={"data_hash": "new_snapshot_same_interval"} + ) + state_sync.push_snapshots([new_snapshot_same_interval]) + state_sync.add_interval(new_snapshot_same_interval, "2022-01-01", "2022-01-03") + promote_snapshots(state_sync, [new_snapshot_same_interval], "prod", no_gaps=True) + + # We should skip the gaps check if the snapshot is not representative. + promote_snapshots( + state_sync, + [new_snapshot_missing_interval], + "prod", + no_gaps=True, + no_gaps_snapshot_names=set(), + ) + + +@time_machine.travel("2023-01-08 16:00:00 UTC", tick=False) +def test_promote_snapshots_no_gaps_lookback( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + model = SqlModel( + name="a", + cron="@hourly", + query=parse_one("select 1, ds"), + kind=IncrementalByTimeRangeKind(time_column="ds", lookback=1), + start="2023-01-01", + ) + + snapshot = make_snapshot(model, version="a") + snapshot.change_category = SnapshotChangeCategory.BREAKING + state_sync.push_snapshots([snapshot]) + state_sync.add_interval(snapshot, "2023-01-01", "2023-01-08 15:00:00") + promote_snapshots(state_sync, [snapshot], "prod", no_gaps=True) + + assert now_timestamp() == to_timestamp("2023-01-08 16:00:00") + + new_snapshot_same_version = make_snapshot(model, version="b") + new_snapshot_same_version.change_category = SnapshotChangeCategory.BREAKING + new_snapshot_same_version.fingerprint = snapshot.fingerprint.copy( + update={"data_hash": "new_snapshot_same_version"} + ) + state_sync.push_snapshots([new_snapshot_same_version]) + state_sync.add_interval(new_snapshot_same_version, "2023-01-01", "2023-01-08 15:00:00") + 
promote_snapshots(state_sync, [new_snapshot_same_version], "prod", no_gaps=True) + + +def test_finalize(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, ds"), + ), + ) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot_a]) + promote_snapshots(state_sync, [snapshot_a], "prod") + + env = state_sync.get_environment("prod") + assert env + state_sync.finalize(env) + + env = state_sync.get_environment("prod") + assert env + assert env.finalized_ts is not None + + env.plan_id = "different_plan_id" + with pytest.raises( + SQLMeshError, + match=re.escape( + "Another plan (test_plan_id) was applied to the target environment 'prod' while your current plan (different_plan_id) was still in progress, interrupting it. Please re-apply your plan to resolve this error." + ), + ): + state_sync.finalize(env) + + +def test_start_date_gap(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): + model = SqlModel( + name="a", + query=parse_one("select 1, ds"), + start="2022-01-01", + kind=IncrementalByTimeRangeKind(time_column="ds"), + cron="@daily", + ) + + snapshot = make_snapshot(model, version="a") + snapshot.change_category = SnapshotChangeCategory.BREAKING + state_sync.push_snapshots([snapshot]) + state_sync.add_interval(snapshot, "2022-01-01", "2022-01-03") + promote_snapshots(state_sync, [snapshot], "prod") + + model = SqlModel( + name="a", + query=parse_one("select 1, ds"), + start="2022-01-02", + kind=IncrementalByTimeRangeKind(time_column="ds"), + cron="@daily", + ) + + snapshot = make_snapshot(model, version="b") + snapshot.change_category = SnapshotChangeCategory.BREAKING + state_sync.push_snapshots([snapshot]) + state_sync.add_interval(snapshot, "2022-01-03", "2022-01-04") + with pytest.raises( + SQLMeshError, + match=r".*Detected missing intervals for model .*, interrupting your current plan. 
Please re-apply your plan to resolve this error.*", + ): + promote_snapshots(state_sync, [snapshot], "prod", no_gaps=True) + + state_sync.add_interval(snapshot, "2022-01-02", "2022-01-03") + promote_snapshots(state_sync, [snapshot], "prod", no_gaps=True) + + +def test_delete_expired_environments(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot]) + + now_ts = now_timestamp() + + env_a = Environment( + name="test_environment_a", + snapshots=[snapshot.table_info], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + expiration_ts=now_ts - 1000, + ) + + environment_statements = [ + EnvironmentStatements( + before_all=["CREATE OR REPLACE TABLE table_1 AS SELECT 'a'"], + after_all=["CREATE OR REPLACE TABLE table_2 AS SELECT 'b'"], + python_env={}, + ) + ] + + state_sync.promote(env_a, environment_statements=environment_statements) + + env_b = env_a.copy(update={"name": "test_environment_b", "expiration_ts": now_ts + 1000}) + state_sync.promote(env_b) + + env_a = Environment(**json.loads(env_a.json())) + env_b = Environment(**json.loads(env_b.json())) + + assert state_sync.get_environment(env_a.name) == env_a + assert state_sync.get_environment(env_b.name) == env_b + + assert not state_sync.get_environment_statements(env_b.name) + assert state_sync.get_environment_statements(env_a.name) == environment_statements + + deleted_environments = state_sync.delete_expired_environments() + assert deleted_environments == [env_a.summary] + + assert state_sync.get_environment(env_a.name) is None + assert state_sync.get_environment(env_b.name) == env_b + + # Deleting the environments should remove the corresponding environment's statements + assert state_sync.get_environment_statements(env_a.name) == [] + + +def 
test_delete_expired_snapshots(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): + now_ts = now_timestamp() + + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.updated_ts = now_ts - 15000 + + new_snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, b, ds"), + ), + ) + new_snapshot.ttl = "in 10 seconds" + new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + new_snapshot.version = snapshot.version + new_snapshot.updated_ts = now_ts - 11000 + + all_snapshots = [snapshot, new_snapshot] + state_sync.push_snapshots(all_snapshots) + assert set(state_sync.get_snapshots(all_snapshots)) == { + snapshot.snapshot_id, + new_snapshot.snapshot_id, + } + + assert _get_cleanup_tasks(state_sync) == [ + SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True), + SnapshotTableCleanupTask(snapshot=new_snapshot.table_info, dev_table_only=False), + ] + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) + + assert not state_sync.get_snapshots(all_snapshots) + + +def test_get_expired_snapshot_batch(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): + now_ts = now_timestamp() + + snapshots = [] + for idx in range(3): + snapshot = make_snapshot( + SqlModel( + name=f"model_{idx}", + query=parse_one("select 1 as a, ds"), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.updated_ts = now_ts - (20000 + idx * 1000) + snapshots.append(snapshot) + + state_sync.push_snapshots(snapshots) + + batch = state_sync.get_expired_snapshots( + batch_range=ExpiredBatchRange.init_batch_range(batch_size=2), + ) + assert batch is not None + assert len(batch.expired_snapshot_ids) == 2 + assert len(batch.cleanup_tasks) == 2 + + state_sync.delete_expired_snapshots( + 
batch_range=ExpiredBatchRange( + start=RowBoundary.lowest_boundary(), + end=batch.batch_range.end, + ), + ) + + next_batch = state_sync.get_expired_snapshots( + batch_range=ExpiredBatchRange( + start=batch.batch_range.end, + end=LimitBoundary(batch_size=2), + ), + ) + assert next_batch is not None + assert len(next_batch.expired_snapshot_ids) == 1 + + state_sync.delete_expired_snapshots( + batch_range=ExpiredBatchRange( + start=next_batch.batch_range.start, + end=next_batch.batch_range.end, + ), + ) + + assert ( + state_sync.get_expired_snapshots( + batch_range=ExpiredBatchRange( + start=next_batch.batch_range.end, + end=LimitBoundary(batch_size=2), + ), + ) + is None + ) + + +def test_get_expired_snapshot_batch_same_timestamp( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + """Test that pagination works correctly when multiple snapshots have the same updated_ts.""" + now_ts = now_timestamp() + same_timestamp = now_ts - 20000 + + snapshots = [] + for idx in range(5): + snapshot = make_snapshot( + SqlModel( + name=f"model_{idx:02d}", # Zero-padded to ensure deterministic name ordering + query=parse_one("select 1 as a, ds"), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + # All snapshots have the same updated_ts + snapshot.updated_ts = same_timestamp + snapshots.append(snapshot) + + state_sync.push_snapshots(snapshots) + + # Fetch first batch of 2 + batch1 = state_sync.get_expired_snapshots( + batch_range=ExpiredBatchRange.init_batch_range(batch_size=2), + ) + assert batch1 is not None + assert len(batch1.expired_snapshot_ids) == 2 + assert sorted([x.name for x in batch1.expired_snapshot_ids]) == [ + '"model_00"', + '"model_01"', + ] + + # Fetch second batch of 2 using cursor from batch1 + batch2 = state_sync.get_expired_snapshots( + batch_range=ExpiredBatchRange( + start=batch1.batch_range.end, + end=LimitBoundary(batch_size=2), + ), + ) + assert batch2 is not None + assert 
len(batch2.expired_snapshot_ids) == 2 + assert sorted([x.name for x in batch2.expired_snapshot_ids]) == [ + '"model_02"', + '"model_03"', + ] + + # Fetch third batch of 2 using cursor from batch2 + batch3 = state_sync.get_expired_snapshots( + batch_range=ExpiredBatchRange( + start=batch2.batch_range.end, + end=LimitBoundary(batch_size=2), + ), + ) + assert batch3 is not None + assert sorted([x.name for x in batch3.expired_snapshot_ids]) == [ + '"model_04"', + ] + + +def test_delete_expired_snapshots_batching_with_deletion( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + """Test that delete_expired_snapshots properly deletes batches as it pages through them.""" + now_ts = now_timestamp() + + # Create 5 expired snapshots with different timestamps + snapshots = [] + for idx in range(5): + snapshot = make_snapshot( + SqlModel( + name=f"model_{idx}", + query=parse_one("select 1 as a, ds"), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.updated_ts = now_ts - (20000 + idx * 1000) + snapshots.append(snapshot) + + state_sync.push_snapshots(snapshots) + + # Verify all 5 snapshots exist + assert len(state_sync.get_snapshots(snapshots)) == 5 + + # Get first batch of 2 + batch1 = state_sync.get_expired_snapshots( + batch_range=ExpiredBatchRange.init_batch_range(batch_size=2), + ) + assert batch1 is not None + assert len(batch1.expired_snapshot_ids) == 2 + + # Delete the first batch using batch_range + state_sync.delete_expired_snapshots( + batch_range=ExpiredBatchRange( + start=batch1.batch_range.start, + end=batch1.batch_range.end, + ), + ) + + # Verify first 2 snapshots (model_0 and model_1, the oldest) are deleted and last 3 remain + remaining = state_sync.get_snapshots(snapshots) + assert len(remaining) == 3 + assert snapshots[0].snapshot_id in remaining # model_0 (newest) + assert snapshots[1].snapshot_id in remaining # model_1 + assert snapshots[2].snapshot_id in remaining # model_2 + 
assert snapshots[3].snapshot_id not in remaining # model_3 + assert snapshots[4].snapshot_id not in remaining # model_4 (oldest) + + # Get next batch of 2 (should start after batch1's boundary) + batch2 = state_sync.get_expired_snapshots( + batch_range=ExpiredBatchRange( + start=batch1.batch_range.end, + end=LimitBoundary(batch_size=2), + ), + ) + assert batch2 is not None + assert len(batch2.expired_snapshot_ids) == 2 + + # Delete the second batch + state_sync.delete_expired_snapshots( + batch_range=ExpiredBatchRange( + start=batch2.batch_range.start, + end=batch2.batch_range.end, + ), + ) + + # Verify only the last snapshot remains + remaining = state_sync.get_snapshots(snapshots) + assert len(remaining) == 1 + assert snapshots[0].snapshot_id in remaining # model_0 (newest) + assert snapshots[1].snapshot_id not in remaining # model_1 + assert snapshots[2].snapshot_id not in remaining # model_2 + assert snapshots[3].snapshot_id not in remaining # model_3 + assert snapshots[4].snapshot_id not in remaining # model_4 (oldest) + + # Get final batch + batch3 = state_sync.get_expired_snapshots( + batch_range=ExpiredBatchRange( + start=batch2.batch_range.end, + end=LimitBoundary(batch_size=2), + ), + ) + assert batch3 is not None + assert len(batch3.expired_snapshot_ids) == 1 + + # Delete the final batch + state_sync.delete_expired_snapshots( + batch_range=ExpiredBatchRange( + start=batch3.batch_range.start, + end=batch3.batch_range.end, + ), + ) + + # Verify all snapshots are deleted + assert len(state_sync.get_snapshots(snapshots)) == 0 + + # Verify no more expired snapshots exist + assert ( + state_sync.get_expired_snapshots( + batch_range=ExpiredBatchRange( + start=batch3.batch_range.end, + end=LimitBoundary(batch_size=2), + ), + ) + is None + ) + + +def test_iterator_expired_snapshot_batch( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + """Test the for_each_expired_snapshot_batch helper function.""" + from sqlmesh.core.state_sync.common import 
iter_expired_snapshot_batches + + now_ts = now_timestamp() + + snapshots = [] + for idx in range(5): + snapshot = make_snapshot( + SqlModel( + name=f"model_{idx}", + query=parse_one("select 1 as a, ds"), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.updated_ts = now_ts - (20000 + idx * 1000) + snapshots.append(snapshot) + + state_sync.push_snapshots(snapshots) + + # Track all batches processed + batches_processed = [] + + # Process with batch size of 2 + for batch in iter_expired_snapshot_batches( + state_sync, + current_ts=now_ts, + ignore_ttl=False, + batch_size=2, + ): + batches_processed.append(batch) + + # Should have processed 3 batches (2 + 2 + 1) + assert len(batches_processed) == 3 + assert len(batches_processed[0].expired_snapshot_ids) == 2 + assert len(batches_processed[1].expired_snapshot_ids) == 2 + assert len(batches_processed[2].expired_snapshot_ids) == 1 + + # Verify all snapshots were processed + all_processed_ids = set() + for batch in batches_processed: + all_processed_ids.update(batch.expired_snapshot_ids) + + expected_ids = {s.snapshot_id for s in snapshots} + assert all_processed_ids == expected_ids + + +@pytest.mark.parametrize( + "start_boundary,end_boundary,expected_sql", + [ + # Test with GT only (when end is LimitBoundary) + ( + RowBoundary(updated_ts=0, name="", identifier=""), + LimitBoundary(batch_size=100), + "updated_ts > 0 OR (updated_ts = 0 AND name > '') OR (updated_ts = 0 AND name = '' AND identifier > '')", + ), + # Test with GT and LTE (when both are RowBoundary) + ( + RowBoundary(updated_ts=1000, name="model_a", identifier="abc"), + RowBoundary(updated_ts=2000, name="model_z", identifier="xyz"), + "(updated_ts > 1000 OR (updated_ts = 1000 AND name > 'model_a') OR (updated_ts = 1000 AND name = 'model_a' AND identifier > 'abc')) AND (updated_ts < 2000 OR (updated_ts = 2000 AND name < 'model_z') OR (updated_ts = 2000 AND name = 'model_z' AND identifier <= 'xyz'))", 
+ ), + # Test with zero timestamp + ( + RowBoundary(updated_ts=0, name="", identifier=""), + RowBoundary(updated_ts=1234567890, name="model_x", identifier="id_123"), + "(updated_ts > 0 OR (updated_ts = 0 AND name > '') OR (updated_ts = 0 AND name = '' AND identifier > '')) AND (updated_ts < 1234567890 OR (updated_ts = 1234567890 AND name < 'model_x') OR (updated_ts = 1234567890 AND name = 'model_x' AND identifier <= 'id_123'))", + ), + # Test with same timestamp, different names + ( + RowBoundary(updated_ts=5000, name="model_a", identifier="id_1"), + RowBoundary(updated_ts=5000, name="model_b", identifier="id_2"), + "(updated_ts > 5000 OR (updated_ts = 5000 AND name > 'model_a') OR (updated_ts = 5000 AND name = 'model_a' AND identifier > 'id_1')) AND (updated_ts < 5000 OR (updated_ts = 5000 AND name < 'model_b') OR (updated_ts = 5000 AND name = 'model_b' AND identifier <= 'id_2'))", + ), + # Test with same timestamp and name, different identifiers + ( + RowBoundary(updated_ts=7000, name="model_x", identifier="id_a"), + RowBoundary(updated_ts=7000, name="model_x", identifier="id_b"), + "(updated_ts > 7000 OR (updated_ts = 7000 AND name > 'model_x') OR (updated_ts = 7000 AND name = 'model_x' AND identifier > 'id_a')) AND (updated_ts < 7000 OR (updated_ts = 7000 AND name < 'model_x') OR (updated_ts = 7000 AND name = 'model_x' AND identifier <= 'id_b'))", + ), + # Test all_batch_range use case + ( + RowBoundary(updated_ts=0, name="", identifier=""), + RowBoundary(updated_ts=253_402_300_799_999, name="", identifier=""), + "(updated_ts > 0 OR (updated_ts = 0 AND name > '') OR (updated_ts = 0 AND name = '' AND identifier > '')) AND (updated_ts < 253402300799999 OR (updated_ts = 253402300799999 AND name < '') OR (updated_ts = 253402300799999 AND name = '' AND identifier <= ''))", + ), + ], +) +def test_expired_batch_range_where_filter(start_boundary, end_boundary, expected_sql): + """Test ExpiredBatchRange.where_filter generates correct SQL for various boundary 
combinations.""" + batch_range = ExpiredBatchRange(start=start_boundary, end=end_boundary) + result = batch_range.where_filter + assert result.sql() == expected_sql + + +def test_expired_batch_range_where_filter_with_limit(): + """Test that where_filter correctly handles LimitBoundary (only start condition, no end condition).""" + batch_range = ExpiredBatchRange( + start=RowBoundary(updated_ts=1000, name="model_a", identifier="abc"), + end=LimitBoundary(batch_size=50), + ) + result = batch_range.where_filter + # When end is LimitBoundary, should only have the start (GT) condition + assert ( + result.sql() + == "updated_ts > 1000 OR (updated_ts = 1000 AND name > 'model_a') OR (updated_ts = 1000 AND name = 'model_a' AND identifier > 'abc')" + ) + + +def test_delete_expired_snapshots_seed( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + now_ts = now_timestamp() + + snapshot = make_snapshot( + SeedModel( + name="a", + kind=SeedKind(path="./path/to/seed"), + seed=Seed(content="header\n1\n2"), + column_hashes={"header": "hash"}, + depends_on=set(), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.updated_ts = now_ts - 15000 + + all_snapshots = [snapshot] + state_sync.push_snapshots(all_snapshots) + assert set(state_sync.get_snapshots(all_snapshots)) == {snapshot.snapshot_id} + + assert _get_cleanup_tasks(state_sync) == [ + SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=False), + ] + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) + + assert not state_sync.get_snapshots(all_snapshots) + + +def test_delete_expired_snapshots_batching( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + state_sync.snapshot_state.SNAPSHOT_BATCH_SIZE = 1 + now_ts = now_timestamp() + + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot_a.ttl = "in 10 seconds" + 
snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_a.updated_ts = now_ts - 15000 + + snapshot_b = make_snapshot( + SqlModel( + name="b", + query=parse_one("select a, b, ds"), + ), + ) + snapshot_b.ttl = "in 10 seconds" + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.updated_ts = now_ts - 11000 + + all_snapshots = [snapshot_a, snapshot_b] + state_sync.push_snapshots(all_snapshots) + assert set(state_sync.get_snapshots(all_snapshots)) == { + snapshot_a.snapshot_id, + snapshot_b.snapshot_id, + } + + assert _get_cleanup_tasks(state_sync) == [ + SnapshotTableCleanupTask(snapshot=snapshot_a.table_info, dev_table_only=False), + SnapshotTableCleanupTask(snapshot=snapshot_b.table_info, dev_table_only=False), + ] + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) + + assert not state_sync.get_snapshots(all_snapshots) + + +def test_delete_expired_snapshots_promoted( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable, mocker: MockerFixture +): + now_ts = now_timestamp() + + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.updated_ts = now_ts - 15000 + + state_sync.push_snapshots([snapshot]) + + env = Environment( + name="test_environment", + snapshots=[snapshot.table_info], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + ) + state_sync.promote(env) + + all_snapshots = [snapshot] + assert not _get_cleanup_tasks(state_sync) + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) + assert set(state_sync.get_snapshots(all_snapshots)) == {snapshot.snapshot_id} + + env.snapshots_ = [] + state_sync.promote(env) + + now_timestamp_mock = mocker.patch("sqlmesh.core.state_sync.db.facade.now_timestamp") + now_timestamp_mock.return_value = now_timestamp() + 11000 
+ + assert _get_cleanup_tasks(state_sync) == [ + SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=False) + ] + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) + assert not state_sync.get_snapshots(all_snapshots) + + +def test_delete_expired_snapshots_previous_finalized_snapshots( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + """Test that expired snapshots are protected if they are part of previous finalized snapshots + in a non-finalized environment.""" + now_ts = now_timestamp() + + # Create an old snapshot that will be expired + old_snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + old_snapshot.ttl = "in 10 seconds" + old_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + # Create a new snapshot + new_snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, b, ds"), + ), + ) + new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([old_snapshot, new_snapshot]) + + # Promote the old snapshot to an environment and finalize it + env = Environment( + name="test_environment", + snapshots=[old_snapshot.table_info], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + ) + state_sync.promote(env) + state_sync.finalize(env) + + # Verify old snapshot is not cleaned up because it's in a finalized environment + assert not _get_cleanup_tasks(state_sync) + + # Now promote the new snapshot to the same environment (this simulates a new plan) + # The environment will have previous_finalized_snapshots set to the old snapshot + # and will not be finalized yet + env = Environment( + name="test_environment", + snapshots=[new_snapshot.table_info], + previous_finalized_snapshots=[old_snapshot.table_info], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="new_plan_id", + previous_plan_id="test_plan_id", + ) + 
state_sync.promote(env) + + # Manually update the snapshots updated_ts to simulate expiration + state_sync.engine_adapter.execute( + f"UPDATE sqlmesh._snapshots SET updated_ts = {now_ts - 15000} WHERE name = '{old_snapshot.name}' AND identifier = '{old_snapshot.identifier}'" + ) + + # The old snapshot should still not be cleaned up because it's part of + # previous_finalized_snapshots in a non-finalized environment + assert not _get_cleanup_tasks(state_sync) + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) + assert state_sync.snapshots_exist([old_snapshot.snapshot_id]) == {old_snapshot.snapshot_id} + + # Once the environment is finalized, the expired snapshot should be removed successfully + state_sync.finalize(env) + assert _get_cleanup_tasks(state_sync) == [ + SnapshotTableCleanupTask(snapshot=old_snapshot.table_info, dev_table_only=False), + ] + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) + assert not state_sync.snapshots_exist([old_snapshot.snapshot_id]) + + +def test_delete_expired_snapshots_dev_table_cleanup_only( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + now_ts = now_timestamp() + + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.updated_ts = now_ts - 15000 + + new_snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, b, ds"), + ), + ) + new_snapshot.ttl = "in 10 seconds" + new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + new_snapshot.version = snapshot.version + new_snapshot.updated_ts = now_ts - 5000 + + all_snapshots = [snapshot, new_snapshot] + state_sync.push_snapshots(all_snapshots) + assert set(state_sync.get_snapshots(all_snapshots)) == { + snapshot.snapshot_id, + new_snapshot.snapshot_id, + } + + assert _get_cleanup_tasks(state_sync) == [ + 
SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True) + ] + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) + + assert set(state_sync.get_snapshots(all_snapshots)) == {new_snapshot.snapshot_id} + + +def test_delete_expired_snapshots_shared_dev_table( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + now_ts = now_timestamp() + + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.updated_ts = now_ts - 15000 + + new_snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, b, ds"), + ), + ) + new_snapshot.ttl = "in 10 seconds" + new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + new_snapshot.version = snapshot.version + new_snapshot.dev_version_ = snapshot.dev_version + new_snapshot.updated_ts = now_ts - 5000 + + all_snapshots = [snapshot, new_snapshot] + state_sync.push_snapshots(all_snapshots) + assert set(state_sync.get_snapshots(all_snapshots)) == { + snapshot.snapshot_id, + new_snapshot.snapshot_id, + } + + assert not _get_cleanup_tasks(state_sync) # No dev table cleanup + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) + assert set(state_sync.get_snapshots(all_snapshots)) == {new_snapshot.snapshot_id} + + +def test_delete_expired_snapshots_ignore_ttl( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ) + ) + snapshot_a.categorize_as(SnapshotChangeCategory.NON_BREAKING) + + snapshot_b = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, b, ds"), + ), + version="2", + ) + snapshot_b.categorize_as(SnapshotChangeCategory.NON_BREAKING) + + snapshot_c = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, b, c, ds"), + ), + 
) + snapshot_c.categorize_as(SnapshotChangeCategory.NON_BREAKING) + + state_sync.push_snapshots([snapshot_a, snapshot_b, snapshot_c]) + + env = Environment( + name="test_environment", + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + ) + state_sync.promote(env) + + # default TTL = 1 week, nothing to clean up yet if we take TTL into account + assert not _get_cleanup_tasks(state_sync) + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) + assert state_sync.snapshots_exist([snapshot_c.snapshot_id]) == {snapshot_c.snapshot_id} + + # If we ignore TTL, only snapshot_c should get cleaned up because snapshot_a and snapshot_b are part of an environment + assert snapshot_a.table_info != snapshot_b.table_info != snapshot_c.table_info + assert _get_cleanup_tasks(state_sync, ignore_ttl=True) == [ + SnapshotTableCleanupTask(snapshot=snapshot_c.table_info, dev_table_only=False) + ] + state_sync.delete_expired_snapshots( + batch_range=ExpiredBatchRange.all_batch_range(), ignore_ttl=True + ) + assert not state_sync.snapshots_exist([snapshot_c.snapshot_id]) + + +def test_delete_expired_snapshots_cleanup_intervals( + state_sync: EngineAdapterStateSync, + make_snapshot: t.Callable, + get_snapshot_intervals: t.Callable, +): + now_ts = now_timestamp() + + # Expired snapshot + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.updated_ts = now_ts - 15000 + + # Another expired snapshot with the same version + new_snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, b, ds"), + ), + ) + new_snapshot.ttl = "in 10 seconds" + new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + new_snapshot.version = snapshot.version + new_snapshot.updated_ts = 
now_ts - 12000 + + state_sync.push_snapshots([snapshot, new_snapshot]) + state_sync.add_interval(snapshot, "2023-01-01", "2023-01-03") + state_sync.add_interval(snapshot, "2023-01-04", "2023-01-05", is_dev=True) + state_sync.add_interval(snapshot, "2023-01-06", "2023-01-07") + state_sync.remove_intervals( + [(snapshot, (to_timestamp("2023-01-06"), to_timestamp("2023-01-08")))] + ) + + state_sync.add_interval(new_snapshot, "2023-01-04", "2023-01-05") + state_sync.add_interval(new_snapshot, "2023-01-06", "2023-01-07") + state_sync.remove_intervals( + [(new_snapshot, (to_timestamp("2023-01-06"), to_timestamp("2023-01-08")))] + ) + + # Check old snapshot's intervals + stored_snapshot = state_sync.get_snapshots([snapshot])[snapshot.snapshot_id] + assert stored_snapshot.intervals == [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-06")), + ] + assert stored_snapshot.dev_intervals == [ + (to_timestamp("2023-01-04"), to_timestamp("2023-01-06")), + ] + + # Check new snapshot's intervals + stored_new_snapshot = state_sync.get_snapshots([new_snapshot])[new_snapshot.snapshot_id] + assert stored_new_snapshot.intervals == [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-06")), + ] + assert not stored_new_snapshot.dev_intervals + + assert _get_cleanup_tasks(state_sync) == [ + SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True), + SnapshotTableCleanupTask(snapshot=new_snapshot.table_info, dev_table_only=False), + ] + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) + + assert not get_snapshot_intervals(snapshot) + + +def test_delete_expired_snapshots_cleanup_intervals_shared_version( + state_sync: EngineAdapterStateSync, + make_snapshot: t.Callable, +): + now_ts = now_timestamp() + + # Expired snapshot + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.updated_ts = 
now_ts - 15000 + + # New non-expired snapshot with the same version + new_snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, b, ds"), + ), + ) + new_snapshot.ttl = "in 10 seconds" + new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + new_snapshot.version = snapshot.version + new_snapshot.updated_ts = now_ts - 5000 + + state_sync.push_snapshots([snapshot, new_snapshot]) + state_sync.add_interval(snapshot, "2023-01-01", "2023-01-03") + state_sync.add_interval(snapshot, "2023-01-01", "2023-01-03", is_dev=True) + state_sync.add_interval(new_snapshot, "2023-01-04", "2023-01-07") + state_sync.remove_intervals( + [(new_snapshot, (to_timestamp("2023-01-06"), to_timestamp("2023-01-08")))] + ) + + # Check new snapshot's intervals + stored_new_snapshot = state_sync.get_snapshots([new_snapshot])[new_snapshot.snapshot_id] + assert stored_new_snapshot.intervals == [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-06")), + ] + assert not stored_new_snapshot.dev_intervals + + # Check old snapshot's intervals + stored_snapshot = state_sync.get_snapshots([snapshot])[snapshot.snapshot_id] + assert stored_snapshot.intervals == [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-06")), + ] + assert stored_snapshot.dev_intervals == [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-04")), + ] + + # Check all intervals + assert sorted( + state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]), + key=compare_snapshot_intervals, + ) == sorted( + [ + SnapshotIntervals( + name='"a"', + identifier=snapshot.identifier, + version=snapshot.version, + dev_version=snapshot.dev_version, + intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], + dev_intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], + ), + SnapshotIntervals( + name='"a"', + identifier=new_snapshot.identifier, + version=snapshot.version, + dev_version=new_snapshot.dev_version, + 
intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-06"))], + ), + ], + key=compare_snapshot_intervals, + ) + + # Delete the expired snapshot + assert _get_cleanup_tasks(state_sync) == [ + SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True), + ] + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) + assert not state_sync.get_snapshots([snapshot]) + + # Check new snapshot's intervals + stored_new_snapshot = state_sync.get_snapshots([new_snapshot])[new_snapshot.snapshot_id] + assert stored_new_snapshot.intervals == [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-06")), + ] + assert not stored_new_snapshot.dev_intervals + + # Check all intervals + assert sorted( + state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]), + key=compare_snapshot_intervals, + ) == sorted( + [ + # The intervals of the old snapshot is preserved with the null identifier + SnapshotIntervals( + name='"a"', + identifier=None, + version=snapshot.version, + dev_version=None, + intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], + ), + # The intervals of the new snapshot has identifier + SnapshotIntervals( + name='"a"', + identifier=new_snapshot.identifier, + version=snapshot.version, + dev_version=new_snapshot.dev_version, + intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-06"))], + ), + ], + key=compare_snapshot_intervals, + ) + + +def test_delete_expired_snapshots_cleanup_intervals_shared_dev_version( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + now_ts = now_timestamp() + + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.updated_ts = now_ts - 15000 + + new_snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, b, ds"), + ), + ) + new_snapshot.ttl = "in 10 seconds" + 
new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + new_snapshot.version = snapshot.version + new_snapshot.dev_version_ = snapshot.dev_version + new_snapshot.updated_ts = now_ts - 5000 + + state_sync.push_snapshots([snapshot, new_snapshot]) + + state_sync.add_interval(snapshot, "2023-01-01", "2023-01-03") + state_sync.add_interval(snapshot, "2023-01-04", "2023-01-07", is_dev=True) + state_sync.add_interval(new_snapshot, "2023-01-08", "2023-01-10", is_dev=True) + state_sync.remove_intervals( + [(new_snapshot, (to_timestamp("2023-01-10"), to_timestamp("2023-01-11")))] + ) + + # Check new snapshot's intervals + stored_new_snapshot = state_sync.get_snapshots([new_snapshot])[new_snapshot.snapshot_id] + assert stored_new_snapshot.intervals == [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-04")), + ] + assert stored_new_snapshot.dev_intervals == [ + (to_timestamp("2023-01-04"), to_timestamp("2023-01-11")), + ] + + # Check old snapshot's intervals + stored_snapshot = state_sync.get_snapshots([snapshot])[snapshot.snapshot_id] + assert stored_snapshot.intervals == [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-04")), + ] + assert stored_snapshot.dev_intervals == [ + (to_timestamp("2023-01-04"), to_timestamp("2023-01-11")), + ] + + # Check all intervals + assert sorted( + state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]), + key=compare_snapshot_intervals, + ) == sorted( + [ + SnapshotIntervals( + name='"a"', + identifier=snapshot.identifier, + version=snapshot.version, + dev_version=snapshot.dev_version, + intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], + dev_intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-08"))], + ), + SnapshotIntervals( + name='"a"', + identifier=new_snapshot.identifier, + version=snapshot.version, + dev_version=new_snapshot.dev_version, + dev_intervals=[(to_timestamp("2023-01-08"), to_timestamp("2023-01-11"))], + ), + ], + 
key=compare_snapshot_intervals, + ) + + # Delete the expired snapshot + assert not _get_cleanup_tasks(state_sync) + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) + assert not state_sync.get_snapshots([snapshot]) + + # Check new snapshot's intervals + stored_new_snapshot = state_sync.get_snapshots([new_snapshot])[new_snapshot.snapshot_id] + assert stored_new_snapshot.intervals == [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-04")), + ] + assert stored_new_snapshot.dev_intervals == [ + (to_timestamp("2023-01-04"), to_timestamp("2023-01-11")), + ] + + # Check all intervals + assert sorted( + state_sync.interval_state.get_snapshot_intervals([snapshot, new_snapshot]), + key=compare_snapshot_intervals, + ) == sorted( + [ + SnapshotIntervals( + name='"a"', + identifier=None, + version=snapshot.version, + dev_version=None, + intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], + ), + SnapshotIntervals( + name='"a"', + identifier=None, + version=snapshot.version, + dev_version=snapshot.dev_version, + dev_intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-08"))], + ), + SnapshotIntervals( + name='"a"', + identifier=new_snapshot.identifier, + version=snapshot.version, + dev_version=new_snapshot.dev_version, + dev_intervals=[(to_timestamp("2023-01-08"), to_timestamp("2023-01-11"))], + ), + ], + key=compare_snapshot_intervals, + ) + + +def test_compact_intervals_after_cleanup( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + now_ts = now_timestamp() + + # Original expired snapshot + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot_a.ttl = "in 10 seconds" + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_a.updated_ts = now_ts - 15000 + + # A forward-only change on top of the original snapshot. Also expired. 
+ # This snapshot reuses only prod table + snapshot_b = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, b, ds"), + ), + ) + snapshot_b.previous_versions = snapshot_a.all_versions + snapshot_b.ttl = "in 10 seconds" + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + snapshot_b.updated_ts = now_ts - 12000 + + # An indirect non-breaking change on top of the forward-only change. Not expired. + # This snapshot reuses both prod and dev tables + snapshot_c = make_snapshot(snapshot_b.model.copy(update={"stamp": "1"})) + snapshot_c.previous_versions = snapshot_b.all_versions + snapshot_c.ttl = "in 10 seconds" + snapshot_c.change_category = SnapshotChangeCategory.INDIRECT_NON_BREAKING + snapshot_c.version = snapshot_b.version + snapshot_c.dev_version_ = snapshot_b.dev_version + snapshot_c.updated_ts = now_ts - 5000 + + state_sync.push_snapshots([snapshot_a, snapshot_b, snapshot_c]) + + state_sync.add_interval(snapshot_a, "2023-01-01", "2023-01-03") + state_sync.add_interval(snapshot_a, "2023-01-01", "2023-01-03", is_dev=True) + state_sync.add_interval(snapshot_b, "2023-01-04", "2023-01-06") + state_sync.add_interval(snapshot_b, "2023-01-04", "2023-01-06", is_dev=True) + state_sync.add_interval(snapshot_c, "2023-01-07", "2023-01-09") + state_sync.add_interval(snapshot_c, "2023-01-07", "2023-01-09", is_dev=True) + + # Only the dev table of the original snapshot should be deleted + assert _get_cleanup_tasks(state_sync) == [ + SnapshotTableCleanupTask(snapshot=snapshot_a.table_info, dev_table_only=True), + ] + state_sync.delete_expired_snapshots(batch_range=ExpiredBatchRange.all_batch_range()) + + assert state_sync.engine_adapter.fetchone("SELECT COUNT(*) FROM sqlmesh._intervals")[0] == 5 # type: ignore + + expected_intervals = [ + # Combined intervals from the original and the forward-only expired snapshots + SnapshotIntervals( + name='"a"', + identifier=None, + version=snapshot_a.version, + dev_version=None, + 
intervals=[(to_timestamp("2023-01-01"), to_timestamp("2023-01-07"))], + ), + # Dev intervals from the forward-only expired snapshot + SnapshotIntervals( + name='"a"', + identifier=None, + version=snapshot_b.version, + dev_version=snapshot_b.dev_version, + dev_intervals=[(to_timestamp("2023-01-04"), to_timestamp("2023-01-07"))], + ), + # Intervals from the indirect non-breaking snapshot + SnapshotIntervals( + name='"a"', + identifier=snapshot_c.identifier, + version=snapshot_c.version, + dev_version=snapshot_c.dev_version, + intervals=[(to_timestamp("2023-01-07"), to_timestamp("2023-01-10"))], + dev_intervals=[(to_timestamp("2023-01-07"), to_timestamp("2023-01-10"))], + ), + ] + + assert ( + sorted( + state_sync.interval_state.get_snapshot_intervals([snapshot_a, snapshot_b, snapshot_c]), + key=lambda x: (x.identifier or "", x.dev_version or ""), + ) + == expected_intervals + ) + + state_sync.compact_intervals() + + assert state_sync.engine_adapter.fetchone("SELECT COUNT(*) FROM sqlmesh._intervals")[0] == 4 # type: ignore + assert ( + sorted( + state_sync.interval_state.get_snapshot_intervals([snapshot_a, snapshot_b, snapshot_c]), + key=lambda x: (x.identifier or "", x.dev_version or ""), + ) + == expected_intervals + ) + + +def test_environment_start_as_timestamp( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot]) + + now_ts = now_timestamp() + + env = Environment( + name="test_environment_a", + snapshots=[snapshot.table_info], + start_at=now_ts, + end_at=None, + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + expiration_ts=now_ts - 1000, + ) + state_sync.promote(env) + + stored_env = state_sync.get_environment(env.name) + assert stored_env + assert stored_env.start_at == to_datetime(now_ts).replace(tzinfo=None).isoformat(sep=" ") + + +def 
test_unpause_snapshots(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): + snapshot = make_snapshot( + SqlModel( + name="test_snapshot", + query=parse_one("select 1, ds"), + cron="@daily", + ), + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.version = "a" + + assert not snapshot.unpaused_ts + state_sync.push_snapshots([snapshot]) + + unpaused_dt = "2022-01-01" + state_sync.unpause_snapshots([snapshot], unpaused_dt) + + actual_snapshot = state_sync.get_snapshots([snapshot])[snapshot.snapshot_id] + assert actual_snapshot.unpaused_ts + assert actual_snapshot.unpaused_ts == to_timestamp(unpaused_dt) + + new_snapshot = make_snapshot( + SqlModel(name="test_snapshot", query=parse_one("select 2, ds"), cron="@daily") + ) + new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + new_snapshot.version = "a" + + assert not new_snapshot.unpaused_ts + state_sync.push_snapshots([new_snapshot]) + state_sync.unpause_snapshots([new_snapshot], unpaused_dt) + + actual_snapshots = state_sync.get_snapshots([snapshot, new_snapshot]) + assert not actual_snapshots[snapshot.snapshot_id].unpaused_ts + assert actual_snapshots[new_snapshot.snapshot_id].unpaused_ts == to_timestamp(unpaused_dt) + + assert actual_snapshots[snapshot.snapshot_id].unrestorable + assert not actual_snapshots[new_snapshot.snapshot_id].unrestorable + + +def test_unrestorable_snapshot(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): + snapshot = make_snapshot( + SqlModel( + name="test_snapshot", + query=parse_one("select 1, ds"), + cron="@daily", + ), + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.version = "a" + + assert not snapshot.unpaused_ts + state_sync.push_snapshots([snapshot]) + + unpaused_dt = "2022-01-01" + state_sync.unpause_snapshots([snapshot], unpaused_dt) + + actual_snapshot = state_sync.get_snapshots([snapshot])[snapshot.snapshot_id] + assert actual_snapshot.unpaused_ts + assert 
actual_snapshot.unpaused_ts == to_timestamp(unpaused_dt) + + new_indirect_non_breaking_snapshot = make_snapshot( + SqlModel(name="test_snapshot", query=parse_one("select 2, ds"), cron="@daily") + ) + new_indirect_non_breaking_snapshot.categorize_as(SnapshotChangeCategory.INDIRECT_NON_BREAKING) + new_indirect_non_breaking_snapshot.version = "a" + + assert not new_indirect_non_breaking_snapshot.unpaused_ts + state_sync.push_snapshots([new_indirect_non_breaking_snapshot]) + state_sync.unpause_snapshots([new_indirect_non_breaking_snapshot], unpaused_dt) + + actual_snapshots = state_sync.get_snapshots([snapshot, new_indirect_non_breaking_snapshot]) + assert not actual_snapshots[snapshot.snapshot_id].unpaused_ts + assert actual_snapshots[ + new_indirect_non_breaking_snapshot.snapshot_id + ].unpaused_ts == to_timestamp(unpaused_dt) + + assert not actual_snapshots[snapshot.snapshot_id].unrestorable + assert not actual_snapshots[new_indirect_non_breaking_snapshot.snapshot_id].unrestorable + + new_forward_only_snapshot = make_snapshot( + SqlModel(name="test_snapshot", query=parse_one("select 3, ds"), cron="@daily") + ) + new_forward_only_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + new_forward_only_snapshot.version = "a" + + assert not new_forward_only_snapshot.unpaused_ts + state_sync.push_snapshots([new_forward_only_snapshot]) + state_sync.unpause_snapshots([new_forward_only_snapshot], unpaused_dt) + + actual_snapshots = state_sync.get_snapshots( + [snapshot, new_indirect_non_breaking_snapshot, new_forward_only_snapshot] + ) + assert not actual_snapshots[snapshot.snapshot_id].unpaused_ts + assert not actual_snapshots[new_indirect_non_breaking_snapshot.snapshot_id].unpaused_ts + assert actual_snapshots[new_forward_only_snapshot.snapshot_id].unpaused_ts == to_timestamp( + unpaused_dt + ) + + assert actual_snapshots[snapshot.snapshot_id].unrestorable + assert actual_snapshots[new_indirect_non_breaking_snapshot.snapshot_id].unrestorable + assert 
not actual_snapshots[new_forward_only_snapshot.snapshot_id].unrestorable + + +def test_unrestorable_snapshot_target_not_forward_only( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + snapshot = make_snapshot( + SqlModel( + name="test_snapshot", + query=parse_one("select 1, ds"), + cron="@daily", + ), + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + snapshot.version = "a" + + assert not snapshot.unpaused_ts + state_sync.push_snapshots([snapshot]) + + unpaused_dt = "2022-01-01" + state_sync.unpause_snapshots([snapshot], unpaused_dt) + + actual_snapshot = state_sync.get_snapshots([snapshot])[snapshot.snapshot_id] + assert actual_snapshot.unpaused_ts + assert actual_snapshot.unpaused_ts == to_timestamp(unpaused_dt) + + updated_snapshot = make_snapshot( + SqlModel(name="test_snapshot", query=parse_one("select 2, ds"), cron="@daily") + ) + updated_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=False) + updated_snapshot.version = "a" + + assert not updated_snapshot.unpaused_ts + state_sync.push_snapshots([updated_snapshot]) + state_sync.unpause_snapshots([updated_snapshot], unpaused_dt) + + actual_snapshots = state_sync.get_snapshots([snapshot, updated_snapshot]) + assert not actual_snapshots[snapshot.snapshot_id].unpaused_ts + assert actual_snapshots[updated_snapshot.snapshot_id].unpaused_ts == to_timestamp(unpaused_dt) + + assert actual_snapshots[snapshot.snapshot_id].unrestorable + assert not actual_snapshots[updated_snapshot.snapshot_id].unrestorable + + +def test_version_schema(state_sync: EngineAdapterStateSync, tmp_path) -> None: + from sqlmesh import __version__ as SQLMESH_VERSION + + # fresh install should not raise + assert state_sync.get_versions() == Versions( + schema_version=SCHEMA_VERSION, + sqlglot_version=SQLGLOT_VERSION, + sqlmesh_version=SQLMESH_VERSION, + ) + + # Start with a clean slate. 
+ state_sync = EngineAdapterStateSync( + create_engine_adapter(duckdb.connect, "duckdb"), + schema=c.SQLMESH, + cache_dir=tmp_path / c.CACHE, + ) + + with pytest.raises( + SQLMeshError, + match=rf"SQLMesh \(local\) is using version '{re.escape(SQLMESH_VERSION)}' which is ahead of '0.0.0' \(remote\). Please run a migration \('sqlmesh migrate' command\).", + ): + state_sync.get_versions() + + state_sync.migrate() + + # migration version is behind, always raise + state_sync.version_state.update_versions(schema_version=SCHEMA_VERSION + 1) + error = ( + rf"SQLMesh \(local\) is using version '{SCHEMA_VERSION}' which is behind '{SCHEMA_VERSION + 1}' \(remote\). " + rf"""Please upgrade SQLMesh \('pip install --upgrade "sqlmesh=={re.escape(SQLMESH_VERSION)}"' command\).""" + ) + + with pytest.raises(SQLMeshError, match=error): + state_sync.get_versions() + + # should no longer raise + state_sync.get_versions(validate=False) + + # migration version is ahead, only raise when validate is true + state_sync.version_state.update_versions(schema_version=SCHEMA_VERSION - 1) + with pytest.raises( + SQLMeshError, + match=rf"SQLMesh \(local\) is using version '{SCHEMA_VERSION}' which is ahead of '{SCHEMA_VERSION - 1}'", + ): + state_sync.get_versions() + state_sync.get_versions(validate=False) + + +def test_version_sqlmesh(state_sync: EngineAdapterStateSync) -> None: + from sqlmesh import __version__ as SQLMESH_VERSION + from sqlmesh import __version_tuple__ as SQLMESH_VERSION_TUPLE + + # patch version sqlmesh doesn't matter + major, minor, patch, *_ = SQLMESH_VERSION_TUPLE + new_patch = ( + f"dev{int(patch[3:]) + 1}" # type: ignore + if isinstance(patch, str) and patch.startswith("dev") + else f"{int(patch) + 1}" + ) + sqlmesh_version_patch_bump = f"{major}.{minor}.{new_patch}" + state_sync.version_state.update_versions(sqlmesh_version=sqlmesh_version_patch_bump) + state_sync.get_versions(validate=False) + + # sqlmesh version is behind + sqlmesh_version_minor_bump = 
f"{major}.{int(minor) + 1}.{patch}" + error = ( + rf"SQLMesh \(local\) is using version '{re.escape(SQLMESH_VERSION)}' which is behind '{sqlmesh_version_minor_bump}' \(remote\). " + rf"""Please upgrade SQLMesh \('pip install --upgrade "sqlmesh=={sqlmesh_version_minor_bump}"' command\).""" + ) + state_sync.version_state.update_versions(sqlmesh_version=sqlmesh_version_minor_bump) + with pytest.raises(SQLMeshError, match=error): + state_sync.get_versions() + state_sync.get_versions(validate=False) + + # sqlmesh version is ahead + sqlmesh_version_minor_decrease = f"{major}.{int(minor) - 1}.{patch}" + error = rf"SQLMesh \(local\) is using version '{re.escape(SQLMESH_VERSION)}' which is ahead of '{sqlmesh_version_minor_decrease}'" + state_sync.version_state.update_versions(sqlmesh_version=sqlmesh_version_minor_decrease) + with pytest.raises(SQLMeshError, match=error): + state_sync.get_versions() + state_sync.get_versions(validate=False) + + +def test_version_sqlglot(state_sync: EngineAdapterStateSync) -> None: + # patch version sqlglot doesn't matter + major, minor, patch, *_ = SQLGLOT_VERSION.split(".") + sqlglot_version = f"{major}.{minor}.{int(patch) + 1}" + state_sync.version_state.update_versions(sqlglot_version=sqlglot_version) + state_sync.get_versions(validate=False) + + # sqlglot version is behind + sqlglot_version = f"{major}.{int(minor) + 1}.{patch}" + error = ( + rf"SQLGlot \(local\) is using version '{SQLGLOT_VERSION}' which is behind '{sqlglot_version}' \(remote\). 
" + rf"""Please upgrade SQLGlot \('pip install --upgrade "sqlglot=={sqlglot_version}"' command\).""" + ) + state_sync.version_state.update_versions(sqlglot_version=sqlglot_version) + with pytest.raises(SQLMeshError, match=error): + state_sync.get_versions() + state_sync.get_versions(validate=False) + + # sqlglot version is ahead + sqlglot_version = f"{major}.{int(minor) - 1}.{patch}" + error = rf"SQLGlot \(local\) is using version '{SQLGLOT_VERSION}' which is ahead of '{sqlglot_version}'" + state_sync.version_state.update_versions(sqlglot_version=sqlglot_version) + with pytest.raises(SQLMeshError, match=error): + state_sync.get_versions() + state_sync.get_versions(validate=False) + + +def test_empty_versions() -> None: + for empty_versions in ( + Versions(), + Versions(schema_version=None, sqlglot_version=None, sqlmesh_version=None), + ): + assert empty_versions.schema_version == 0 + assert empty_versions.sqlglot_version == "0.0.0" + assert empty_versions.sqlmesh_version == "0.0.0" + + +def test_migrate(state_sync: EngineAdapterStateSync, mocker: MockerFixture, tmp_path) -> None: + from sqlmesh import __version__ as SQLMESH_VERSION + + migrate_rows_mock = mocker.patch( + "sqlmesh.core.state_sync.db.migrator.StateMigrator._migrate_rows" + ) + backup_state_mock = mocker.patch( + "sqlmesh.core.state_sync.db.migrator.StateMigrator._backup_state" + ) + state_sync.migrate() + migrate_rows_mock.assert_not_called() + backup_state_mock.assert_not_called() + + # Start with a clean slate. 
+ state_sync = EngineAdapterStateSync( + create_engine_adapter(duckdb.connect, "duckdb"), + schema=c.SQLMESH, + cache_dir=tmp_path / c.CACHE, + ) + + state_sync.migrate() + migrate_rows_mock.assert_called_once() + backup_state_mock.assert_called_once() + assert state_sync.get_versions() == Versions( + schema_version=SCHEMA_VERSION, + sqlglot_version=SQLGLOT_VERSION, + sqlmesh_version=SQLMESH_VERSION, + ) + + assert ( + state_sync.engine_adapter.fetchone( + "SELECT COUNT(*) FROM sqlmesh._snapshots WHERE ttl_ms IS NULL" + )[0] # type: ignore + == 0 + ) + + +def test_rollback(state_sync: EngineAdapterStateSync, mocker: MockerFixture) -> None: + with pytest.raises( + SQLMeshError, + match="There are no prior migrations to roll back to.", + ): + state_sync.rollback() + + restore_table_spy = mocker.spy(state_sync.migrator, "_restore_table") + state_sync.migrator._backup_state() + + state_sync.rollback() + calls = {(a.sql(), b.sql()) for (a, b), _ in restore_table_spy.call_args_list} + assert ( + f"{state_sync.schema}._snapshots", + f"{state_sync.schema}._snapshots_backup", + ) in calls + assert ( + f"{state_sync.schema}._environments", + f"{state_sync.schema}._environments_backup", + ) in calls + assert ( + f"{state_sync.schema}._versions", + f"{state_sync.schema}._versions_backup", + ) in calls + assert not state_sync.engine_adapter.table_exists(f"{state_sync.schema}._snapshots_backup") + assert not state_sync.engine_adapter.table_exists(f"{state_sync.schema}._environments_backup") + assert not state_sync.engine_adapter.table_exists(f"{state_sync.schema}._versions_backup") + + +def test_first_migration_failure(duck_conn, mocker: MockerFixture, tmp_path) -> None: + state_sync = EngineAdapterStateSync( + create_engine_adapter(lambda: duck_conn, "duckdb"), + schema=c.SQLMESH, + cache_dir=tmp_path / c.CACHE, + ) + mocker.patch.object(state_sync.migrator, "_migrate_rows", side_effect=Exception("mocked error")) + with pytest.raises( + SQLMeshError, + match="SQLMesh migration 
failed.", + ): + state_sync.migrate() + assert not state_sync.engine_adapter.table_exists(state_sync.snapshot_state.snapshots_table) + assert not state_sync.engine_adapter.table_exists( + state_sync.environment_state.environments_table + ) + assert not state_sync.engine_adapter.table_exists(state_sync.version_state.versions_table) + assert not state_sync.engine_adapter.table_exists(state_sync.interval_state.intervals_table) + + +def test_migrate_rows(state_sync: EngineAdapterStateSync, mocker: MockerFixture) -> None: + state_sync.engine_adapter.replace_query( + "sqlmesh._versions", + pd.read_json("tests/fixtures/migrations/versions.json"), + target_columns_to_types={ + "schema_version": exp.DataType.build("int"), + "sqlglot_version": exp.DataType.build("text"), + "sqlmesh_version": exp.DataType.build("text"), + }, + ) + + state_sync.engine_adapter.replace_query( + "sqlmesh._snapshots", + pd.read_json("tests/fixtures/migrations/snapshots.json"), + target_columns_to_types={ + "name": exp.DataType.build("text"), + "identifier": exp.DataType.build("text"), + "version": exp.DataType.build("text"), + "snapshot": exp.DataType.build("text"), + "kind_name": exp.DataType.build("text"), + "updated_ts": exp.DataType.build("bigint"), + "unpaused_ts": exp.DataType.build("bigint"), + "ttl_ms": exp.DataType.build("bigint"), + "unrestorable": exp.DataType.build("boolean"), + }, + ) + + state_sync.engine_adapter.replace_query( + "sqlmesh._environments", + pd.read_json("tests/fixtures/migrations/environments.json"), + target_columns_to_types={ + "name": exp.DataType.build("text"), + "snapshots": exp.DataType.build("text"), + "start_at": exp.DataType.build("text"), + "end_at": exp.DataType.build("text"), + "plan_id": exp.DataType.build("text"), + "previous_plan_id": exp.DataType.build("text"), + "expiration_ts": exp.DataType.build("bigint"), + "finalized_ts": exp.DataType.build("bigint"), + "promoted_snapshot_ids": exp.DataType.build("text"), + "suffix_target": 
exp.DataType.build("text"), + "catalog_name_override": exp.DataType.build("text"), + "previous_finalized_snapshots": exp.DataType.build("text"), + "normalize_name": exp.DataType.build("boolean"), + "requirements": exp.DataType.build("text"), + }, + ) + + state_sync.engine_adapter.replace_query( + "sqlmesh._intervals", + pd.read_json("tests/fixtures/migrations/intervals.json"), + target_columns_to_types={ + "id": exp.DataType.build("text"), + "created_ts": exp.DataType.build("bigint"), + "name": exp.DataType.build("text"), + "identifier": exp.DataType.build("text"), + "version": exp.DataType.build("text"), + "start_ts": exp.DataType.build("bigint"), + "end_ts": exp.DataType.build("bigint"), + "is_dev": exp.DataType.build("boolean"), + "is_removed": exp.DataType.build("boolean"), + "is_compacted": exp.DataType.build("boolean"), + }, + ) + + old_snapshots = state_sync.engine_adapter.fetchdf("select * from sqlmesh._snapshots") + old_environments = state_sync.engine_adapter.fetchdf("select * from sqlmesh._environments") + + state_sync.migrate(skip_backup=True) + + new_snapshots = state_sync.engine_adapter.fetchdf("select * from sqlmesh._snapshots") + new_environments = state_sync.engine_adapter.fetchdf("select * from sqlmesh._environments") + + assert len(old_snapshots) == 24 + assert len(new_snapshots) == 36 + assert len(old_environments) == len(new_environments) + + start = "2023-01-01" + end = "2023-01-07" + + assert not missing_intervals( + state_sync.get_snapshots( + t.cast(Environment, state_sync.get_environment("staging")).snapshots + ).values(), + start=start, + end=end, + ) + + dev_snapshots = state_sync.get_snapshots( + t.cast(Environment, state_sync.get_environment("dev")).snapshots + ).values() + + assert all(s.migrated for s in dev_snapshots) + assert all(s.change_category is not None for s in dev_snapshots) + + assert not missing_intervals(dev_snapshots, start=start, end=end) + + assert not missing_intervals(dev_snapshots, start="2023-01-08", 
end="2023-01-10") == 8 + + all_snapshot_ids = [ + SnapshotId(name=name, identifier=identifier) + for name, identifier in state_sync.engine_adapter.fetchall( + "SELECT name, identifier FROM sqlmesh._snapshots" + ) + ] + for s in state_sync.get_snapshots(all_snapshot_ids).values(): + if not s.is_symbolic: + assert s.intervals + + customer_revenue_by_day = new_snapshots.loc[ + new_snapshots["name"] == '"sushi"."customer_revenue_by_day"' + ].iloc[0] + assert json.loads(customer_revenue_by_day["snapshot"])["node"]["query"].startswith( + "JINJA_QUERY_BEGIN" + ) + + +def test_backup_state(state_sync: EngineAdapterStateSync, mocker: MockerFixture) -> None: + state_sync.engine_adapter.replace_query( + "sqlmesh._snapshots", + pd.read_json("tests/fixtures/migrations/snapshots.json"), + target_columns_to_types={ + "name": exp.DataType.build("text"), + "identifier": exp.DataType.build("text"), + "version": exp.DataType.build("text"), + "snapshot": exp.DataType.build("text"), + }, + ) + + state_sync.migrator._backup_state() + pd.testing.assert_frame_equal( + state_sync.engine_adapter.fetchdf("select * from sqlmesh._snapshots"), + state_sync.engine_adapter.fetchdf("select * from sqlmesh._snapshots_backup"), + ) + + +def test_restore_snapshots_table(state_sync: EngineAdapterStateSync) -> None: + snapshot_columns_to_types = { + "name": exp.DataType.build("text"), + "identifier": exp.DataType.build("text"), + "version": exp.DataType.build("text"), + "snapshot": exp.DataType.build("text"), + } + state_sync.engine_adapter.replace_query( + "sqlmesh._snapshots", + pd.read_json("tests/fixtures/migrations/snapshots.json"), + target_columns_to_types=snapshot_columns_to_types, + ) + + old_snapshots = state_sync.engine_adapter.fetchdf("select * from sqlmesh._snapshots") + old_snapshots_count = state_sync.engine_adapter.fetchone( + "select count(*) from sqlmesh._snapshots" + ) + assert old_snapshots_count == (24,) + state_sync.migrator._backup_state() + + 
state_sync.engine_adapter.delete_from("sqlmesh._snapshots", "TRUE") + snapshots_count = state_sync.engine_adapter.fetchone("select count(*) from sqlmesh._snapshots") + assert snapshots_count == (0,) + state_sync.migrator._restore_table( + table_name="sqlmesh._snapshots", + backup_table_name="sqlmesh._snapshots_backup", + ) + + new_snapshots = state_sync.engine_adapter.fetchdf("select * from sqlmesh._snapshots") + pd.testing.assert_frame_equal( + old_snapshots, + new_snapshots, + ) + + +def test_seed_hydration( + state_sync: EngineAdapterStateSync, + make_snapshot: t.Callable, +): + snapshot = make_snapshot( + SeedModel( + name="a", + kind=SeedKind(path="./path/to/seed"), + seed=Seed(content="header\n1\n2"), + column_hashes={"header": "hash"}, + depends_on=set(), + ) + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot]) + + assert snapshot.model.is_hydrated + assert snapshot.model.seed.content == "header\n1\n2" + + state_sync.snapshot_state.clear_cache() + stored_snapshot = state_sync.get_snapshots([snapshot.snapshot_id])[snapshot.snapshot_id] + assert isinstance(stored_snapshot.model, SeedModel) + assert not stored_snapshot.model.is_hydrated + assert stored_snapshot.model.seed.content == "" + + +def test_nodes_exist(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, ds"), + ) + ) + + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + assert not state_sync.nodes_exist([snapshot.name]) + + state_sync.push_snapshots([snapshot]) + + assert state_sync.nodes_exist([snapshot.name]) == {snapshot.name} + + +def test_invalidate_environment(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot]) + + original_expiration_ts = 
now_timestamp() + 100000 + + env = Environment( + name="test_environment", + snapshots=[snapshot.table_info], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + expiration_ts=original_expiration_ts, + ) + environment_statements = [ + EnvironmentStatements( + before_all=["CREATE OR REPLACE TABLE table_1 AS SELECT 'a'"], + after_all=["CREATE OR REPLACE TABLE table_2 AS SELECT 'b'"], + python_env={}, + ) + ] + + state_sync.promote(env, environment_statements=environment_statements) + + assert state_sync.get_environment_statements(env.name) == environment_statements + + assert not state_sync.delete_expired_environments() + state_sync.invalidate_environment("test_environment") + + stored_env = state_sync.get_environment("test_environment") + assert stored_env + assert stored_env.expiration_ts and stored_env.expiration_ts < original_expiration_ts + + deleted_environments = state_sync.delete_expired_environments() + assert len(deleted_environments) == 1 + assert deleted_environments[0].name == "test_environment" + assert state_sync.get_environment_statements(env.name) == [] + + with pytest.raises(SQLMeshError, match="Cannot invalidate the production environment."): + state_sync.invalidate_environment("prod") + + +def test_promote_environment_without_statements( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +): + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot]) + + original_expiration_ts = now_timestamp() + 100000 + + env = Environment( + name="test_environment", + snapshots=[snapshot.table_info], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + expiration_ts=original_expiration_ts, + ) + environment_statements = [ + EnvironmentStatements( + before_all=["CREATE OR REPLACE TABLE table_1 AS 
SELECT 'a'"], + after_all=["CREATE OR REPLACE TABLE table_2 AS SELECT 'b'"], + python_env={}, + ) + ] + + state_sync.promote(env, environment_statements=environment_statements) + + # Verify the environment statements table is populated with the statements + assert state_sync.get_environment_statements(env.name) == environment_statements + + # Scenario where the statements have been removed from the project and then + # If we promote the environment it doesn't contain before_all, after_all statements + state_sync.promote(env, environment_statements=[]) + + # This should trigger an internal update to the environment statements' table to be removed + assert state_sync.get_environment_statements(env.name) == [] + + +def test_cache(state_sync, make_snapshot, mocker): + cache = CachingStateSync(state_sync, ttl=10) + + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 'a', 'ds'"), + ), + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + now_timestamp = mocker.patch("sqlmesh.core.state_sync.cache.now_timestamp") + now_timestamp.return_value = to_timestamp("2023-01-01 00:00:00") + + # prime the cache with a cached missing snapshot + assert not cache.get_snapshots([snapshot.snapshot_id]) + + # item is cached and shouldn't hit state sync + with patch.object(state_sync, "get_snapshots") as mock: + assert not cache.get_snapshots([snapshot.snapshot_id]) + mock.assert_not_called() + + # prime the cache with a real snapshot + cache.push_snapshots([snapshot]) + assert cache.get_snapshots([snapshot.snapshot_id]) == {snapshot.snapshot_id: snapshot} + + # cache hit + with patch.object(state_sync, "get_snapshots") as mock: + assert cache.get_snapshots([snapshot.snapshot_id]) == {snapshot.snapshot_id: snapshot} + mock.assert_not_called() + + # clear the cache by adding intervals + cache.add_interval(snapshot, "2020-01-01", "2020-01-01") + with patch.object(state_sync, "get_snapshots") as mock: + assert not 
cache.get_snapshots([snapshot.snapshot_id]) + mock.assert_called() + + # clear the cache by removing intervals + cache.remove_intervals([(snapshot, snapshot.inclusive_exclusive("2020-01-01", "2020-01-01"))]) + + # prime the cache + assert cache.get_snapshots([snapshot.snapshot_id]) == {snapshot.snapshot_id: snapshot} + + # cache hit half way + now_timestamp.return_value = to_timestamp("2023-01-01 00:00:05") + with patch.object(state_sync, "get_snapshots") as mock: + assert cache.get_snapshots([snapshot.snapshot_id]) + mock.assert_not_called() + + # no cache hit + now_timestamp.return_value = to_timestamp("2023-01-01 00:00:11") + with patch.object(state_sync, "get_snapshots") as mock: + assert not cache.get_snapshots([snapshot.snapshot_id]) + mock.assert_called() + + +def test_max_interval_end_per_model( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +) -> None: + snapshot_a = make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one("select 1, ds"), + ), + ) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + + snapshot_b = make_snapshot( + SqlModel( + name="b", + cron="@daily", + query=parse_one("select 2, ds"), + ), + ) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot_a, snapshot_b]) + + state_sync.add_interval(snapshot_a, "2023-01-01", "2023-01-01") + state_sync.add_interval(snapshot_a, "2023-01-02", "2023-01-02") + state_sync.add_interval(snapshot_a, "2023-01-03", "2023-01-03") + state_sync.add_interval(snapshot_b, "2023-01-01", "2023-01-01") + state_sync.add_interval(snapshot_b, "2023-01-02", "2023-01-02") + + environment_name = "test_max_interval_end_for_environment" + + assert state_sync.max_interval_end_per_model(environment_name) == {} + assert state_sync.max_interval_end_per_model(environment_name, {snapshot_a.name}) == {} + + state_sync.promote( + Environment( + name=environment_name, + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + 
start_at="2023-01-01", + end_at="2023-01-03", + plan_id="test_plan_id", + previous_finalized_snapshots=[snapshot_b.table_info], + ) + ) + + assert state_sync.max_interval_end_per_model(environment_name, {snapshot_a.name}) == { + snapshot_a.name: to_timestamp("2023-01-04") + } + + assert state_sync.max_interval_end_per_model(environment_name, {snapshot_b.name}) == { + snapshot_b.name: to_timestamp("2023-01-03") + } + + assert state_sync.max_interval_end_per_model( + environment_name, {snapshot_a.name, snapshot_b.name} + ) == { + snapshot_a.name: to_timestamp("2023-01-04"), + snapshot_b.name: to_timestamp("2023-01-03"), + } + + assert state_sync.max_interval_end_per_model(environment_name) == { + snapshot_a.name: to_timestamp("2023-01-04"), + snapshot_b.name: to_timestamp("2023-01-03"), + } + + assert state_sync.max_interval_end_per_model(environment_name, {"missing"}) == {} + assert state_sync.max_interval_end_per_model(environment_name, set()) == {} + + +def test_max_interval_end_per_model_with_pending_restatements( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +) -> None: + snapshot = make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one("select 1, ds"), + ), + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot]) + + state_sync.add_interval(snapshot, "2023-01-01", "2023-01-01") + state_sync.add_interval(snapshot, "2023-01-02", "2023-01-02") + state_sync.add_interval(snapshot, "2023-01-03", "2023-01-03") + # Add a pending restatement interval + state_sync.add_snapshots_intervals( + [ + SnapshotIntervals( + name=snapshot.name, + identifier=snapshot.identifier, + version=snapshot.version, + dev_version=snapshot.dev_version, + intervals=[], + dev_intervals=[], + pending_restatement_intervals=[ + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")) + ], + ) + ] + ) + + snapshot = state_sync.get_snapshots([snapshot.snapshot_id])[snapshot.snapshot_id] + assert snapshot.intervals == 
[(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))] + assert snapshot.pending_restatement_intervals == [ + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")) + ] + + environment_name = "test_max_interval_end_for_environment" + + state_sync.promote( + Environment( + name=environment_name, + snapshots=[snapshot.table_info], + start_at="2023-01-01", + end_at="2023-01-03", + plan_id="test_plan_id", + previous_finalized_snapshots=[], + ) + ) + + assert state_sync.max_interval_end_per_model(environment_name) == { + snapshot.name: to_timestamp("2023-01-04") + } + + +def test_max_interval_end_per_model_ensure_finalized_snapshots( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable +) -> None: + snapshot_a = make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one("select 1, ds"), + ), + ) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + + snapshot_b = make_snapshot( + SqlModel( + name="b", + cron="@daily", + query=parse_one("select 2, ds"), + ), + ) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot_a, snapshot_b]) + + state_sync.add_interval(snapshot_a, "2023-01-01", "2023-01-01") + state_sync.add_interval(snapshot_a, "2023-01-02", "2023-01-02") + state_sync.add_interval(snapshot_a, "2023-01-03", "2023-01-03") + state_sync.add_interval(snapshot_b, "2023-01-01", "2023-01-01") + state_sync.add_interval(snapshot_b, "2023-01-02", "2023-01-02") + + environment_name = "test_max_interval_end_for_environment" + + assert state_sync.max_interval_end_per_model(environment_name) == {} + assert state_sync.max_interval_end_per_model(environment_name, {snapshot_a.name}) == {} + + state_sync.promote( + Environment( + name=environment_name, + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-03", + plan_id="test_plan_id", + previous_finalized_snapshots=[snapshot_b.table_info], + ) + ) + + assert ( + 
state_sync.max_interval_end_per_model( + environment_name, {snapshot_a.name}, ensure_finalized_snapshots=True + ) + == {} + ) + + assert state_sync.max_interval_end_per_model( + environment_name, {snapshot_b.name}, ensure_finalized_snapshots=True + ) == {snapshot_b.name: to_timestamp("2023-01-03")} + + assert state_sync.max_interval_end_per_model( + environment_name, {snapshot_a.name, snapshot_b.name}, ensure_finalized_snapshots=True + ) == {snapshot_b.name: to_timestamp("2023-01-03")} + + assert state_sync.max_interval_end_per_model( + environment_name, ensure_finalized_snapshots=True + ) == {snapshot_b.name: to_timestamp("2023-01-03")} + + assert ( + state_sync.max_interval_end_per_model( + environment_name, {"missing"}, ensure_finalized_snapshots=True + ) + == {} + ) + assert state_sync.max_interval_end_per_model(environment_name, set()) == {} + + +def test_get_snapshots(mocker): + mock = mocker.MagicMock() + cache = CachingStateSync(mock) + cache.get_snapshots([]) + mock.get_snapshots.assert_not_called() + + +def test_snapshot_batching(state_sync, mocker, make_snapshot): + mock = mocker.Mock() + + state_sync.snapshot_state.SNAPSHOT_BATCH_SIZE = 2 + state_sync.snapshot_state.engine_adapter = mock + + snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("select 1")), "1") + snapshot_b = make_snapshot(SqlModel(name="a", query=parse_one("select 2")), "2") + snapshot_c = make_snapshot(SqlModel(name="a", query=parse_one("select 3")), "3") + + state_sync.delete_snapshots( + ( + snapshot_a, + snapshot_b, + snapshot_c, + ) + ) + calls = mock.delete_from.call_args_list + identifiers = sorted([snapshot_a.identifier, snapshot_b.identifier, snapshot_c.identifier]) + assert mock.delete_from.call_args_list == [ + call( + exp.to_table("sqlmesh._snapshots"), + where=parse_one( + f"(name, identifier) in (('\"a\"', '{identifiers[0]}'), ('\"a\"', '{identifiers[1]}'))" + ), + ), + call( + exp.to_table("sqlmesh._snapshots"), + where=parse_one(f"(name, identifier) in 
(('\"a\"', '{identifiers[2]}'))"), + ), + ] + + mock.fetchall.side_effect = [ + [ + [ + make_snapshot( + SqlModel(name="a", query=parse_one("select 1")), + ).json(), + "a", + "1", + "1", + 1, + 1, + False, + False, + None, + ], + [ + make_snapshot( + SqlModel(name="a", query=parse_one("select 2")), + ).json(), + "a", + "2", + "2", + 1, + 1, + False, + False, + None, + ], + ], + [ + [ + make_snapshot( + SqlModel(name="a", query=parse_one("select 3")), + ).json(), + "a", + "3", + "3", + 1, + 1, + False, + False, + None, + ], + ], + ] + + snapshots = state_sync.snapshot_state.get_snapshots( + ( + SnapshotId(name="a", identifier="1"), + SnapshotId(name="a", identifier="2"), + SnapshotId(name="a", identifier="3"), + ), + ) + assert len(snapshots) == 3 + calls = mock.fetchall.call_args_list + assert len(calls) == 2 + + +def test_seed_model_metadata_update( + state_sync: EngineAdapterStateSync, + make_snapshot: t.Callable, +): + model = SeedModel( + name="a", + kind=SeedKind(path="./path/to/seed"), + seed=Seed(content="header\n1\n2"), + column_hashes={"header": "hash"}, + depends_on=set(), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot]) + + model = model.copy(update={"owner": "jen"}) + new_snapshot = make_snapshot(model) + new_snapshot.previous_versions = snapshot.all_versions + new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + + assert snapshot.fingerprint != new_snapshot.fingerprint + assert snapshot.version == new_snapshot.version + + state_sync.push_snapshots([new_snapshot]) + assert len(state_sync.get_snapshots([new_snapshot, snapshot])) == 2 + + +def test_snapshot_cache( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable, mocker: MockerFixture +): + cache_mock = mocker.Mock() + state_sync.snapshot_state._snapshot_cache = cache_mock + + snapshot = make_snapshot(SqlModel(name="a", query=parse_one("select 1"))) + 
cache_mock.get_or_load.return_value = ({snapshot.snapshot_id: snapshot}, {snapshot.snapshot_id}) + + state_sync.snapshot_state.push_snapshots([snapshot]) + + assert state_sync.get_snapshots([snapshot.snapshot_id]) == {snapshot.snapshot_id: snapshot} + cache_mock.get_or_load.assert_called_once_with({snapshot.snapshot_id}, mocker.ANY) + + # Update the snapshot in the state and make sure this update is reflected on the cached instance. + assert snapshot.unpaused_ts is None + assert not snapshot.unrestorable + state_sync.snapshot_state._update_snapshots( + [snapshot.snapshot_id], unpaused_ts=1, unrestorable=True + ) + new_snapshot = state_sync.get_snapshots([snapshot.snapshot_id])[snapshot.snapshot_id] + assert new_snapshot.unpaused_ts == 1 + assert new_snapshot.unrestorable + + # If the record was deleted from the state, the cached version should not be returned. + state_sync.delete_snapshots([snapshot.snapshot_id]) + assert state_sync.get_snapshots([snapshot.snapshot_id]) == {} + + +def test_update_auto_restatements(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): + snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("select 1")), version="1") + snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("select 2")), version="2") + snapshot_c = make_snapshot(SqlModel(name="c", query=parse_one("select 3")), version="3") + + state_sync.snapshot_state.push_snapshots([snapshot_a, snapshot_b, snapshot_c]) + + next_auto_restatement_ts: t.Dict[SnapshotNameVersion, t.Optional[int]] = { + snapshot_a.name_version: 1, + snapshot_b.name_version: 2, + snapshot_c.name_version: 3, + } + state_sync.update_auto_restatements(next_auto_restatement_ts) + + snapshots = state_sync.get_snapshots( + [snapshot_a.snapshot_id, snapshot_b.snapshot_id, snapshot_c.snapshot_id] + ) + assert snapshots[snapshot_a.snapshot_id].next_auto_restatement_ts == 1 + assert snapshots[snapshot_b.snapshot_id].next_auto_restatement_ts == 2 + assert 
snapshots[snapshot_c.snapshot_id].next_auto_restatement_ts == 3 + + next_auto_restatement_ts = { + snapshot_a.name_version: 4, + snapshot_b.name_version: 5, + snapshot_c.name_version: None, + } + state_sync.update_auto_restatements(next_auto_restatement_ts) + + snapshots = state_sync.get_snapshots( + [snapshot_a.snapshot_id, snapshot_b.snapshot_id, snapshot_c.snapshot_id] + ) + assert snapshots[snapshot_a.snapshot_id].next_auto_restatement_ts == 4 + assert snapshots[snapshot_b.snapshot_id].next_auto_restatement_ts == 5 + assert snapshots[snapshot_c.snapshot_id].next_auto_restatement_ts is None + + +@time_machine.travel("2020-01-05 00:00:00 UTC") +def test_compact_intervals_pending_restatement( + state_sync: EngineAdapterStateSync, + make_snapshot: t.Callable, + get_snapshot_intervals: t.Callable, +) -> None: + snapshot = make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one("select 1, ds"), + ), + version="a", + ) + + state_sync.push_snapshots([snapshot]) + + state_sync.add_interval(snapshot, "2020-01-01", "2020-01-01") + state_sync.add_interval(snapshot, "2020-01-02", "2020-01-02") + state_sync.add_interval(snapshot, "2020-01-03", "2020-01-03") + state_sync.add_interval(snapshot, "2020-01-04", "2020-01-04") + + pending_restatement_intervals = [ + (to_timestamp("2020-01-03"), to_timestamp("2020-01-05")), + ] + state_sync.add_snapshots_intervals( + [ + SnapshotIntervals( + name=snapshot.name, + identifier=snapshot.identifier, + version=snapshot.version, + dev_version=snapshot.dev_version, + intervals=[], + dev_intervals=[], + pending_restatement_intervals=pending_restatement_intervals, + ) + ] + ) + + with time_machine.travel("2020-01-05 01:00:00 UTC"): + # Backfill one of the pending restatement intervals. 
+ state_sync.add_interval(snapshot, "2020-01-03", "2020-01-03") + snapshot = state_sync.get_snapshots([snapshot])[snapshot.snapshot_id] + assert snapshot.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-05")), + ] + assert snapshot.pending_restatement_intervals == [ + (to_timestamp("2020-01-04"), to_timestamp("2020-01-05")), + ] + + state_sync.compact_intervals() + snapshot = state_sync.get_snapshots([snapshot])[snapshot.snapshot_id] + assert snapshot.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-05")), + ] + assert snapshot.pending_restatement_intervals == [ + (to_timestamp("2020-01-04"), to_timestamp("2020-01-05")), + ] + + # Make sure compaction is idempotent. + state_sync.compact_intervals() + snapshot = state_sync.get_snapshots([snapshot])[snapshot.snapshot_id] + assert snapshot.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-05")), + ] + assert snapshot.pending_restatement_intervals == [ + (to_timestamp("2020-01-04"), to_timestamp("2020-01-05")), + ] + + with time_machine.travel("2020-01-05 02:00:00 UTC"): + # Backfill the remaining pending restatement interval. + state_sync.add_interval(snapshot, "2020-01-04", "2020-01-04") + snapshot = state_sync.get_snapshots([snapshot])[snapshot.snapshot_id] + assert snapshot.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-05")), + ] + assert snapshot.pending_restatement_intervals == [] + + state_sync.compact_intervals() + snapshot = state_sync.get_snapshots([snapshot])[snapshot.snapshot_id] + assert snapshot.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-05")), + ] + assert snapshot.pending_restatement_intervals == [] + + # Make sure compaction is idempotent. 
+ state_sync.compact_intervals() + snapshot = state_sync.get_snapshots([snapshot])[snapshot.snapshot_id] + assert snapshot.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-05")), + ] + assert snapshot.pending_restatement_intervals == [] + + +@time_machine.travel("2020-01-05 00:00:00 UTC") +def test_compact_intervals_pending_restatement_shared_version( + state_sync: EngineAdapterStateSync, + make_snapshot: t.Callable, + get_snapshot_intervals: t.Callable, +) -> None: + snapshot_a = make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one("select 1, ds"), + ), + version="a", + ) + snapshot_b = make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one("select 2, ds"), + ), + version="a", + ) + + state_sync.push_snapshots([snapshot_a, snapshot_b]) + + state_sync.add_interval(snapshot_a, "2020-01-01", "2020-01-01") + state_sync.add_interval(snapshot_a, "2020-01-02", "2020-01-02") + state_sync.add_interval(snapshot_a, "2020-01-03", "2020-01-03") + state_sync.add_interval(snapshot_a, "2020-01-04", "2020-01-04") + state_sync.add_interval(snapshot_a, "2020-01-05", "2020-01-05") + state_sync.add_snapshots_intervals( + [ + SnapshotIntervals( + name=snapshot_a.name, + identifier=snapshot_a.identifier, + version=snapshot_a.version, + dev_version=snapshot_a.dev_version, + intervals=[], + dev_intervals=[], + pending_restatement_intervals=[ + (to_timestamp("2020-01-03"), to_timestamp("2020-01-06")), + ], + ) + ] + ) + + expected_intervals = [ + SnapshotIntervals( + name=snapshot_b.name, + identifier=None, + version=snapshot_b.version, + dev_version=None, + intervals=[], + dev_intervals=[], + pending_restatement_intervals=[ + (to_timestamp("2020-01-04"), to_timestamp("2020-01-06")), + ], + ), + SnapshotIntervals( + name=snapshot_a.name, + identifier=snapshot_a.identifier, + version=snapshot_a.version, + dev_version=snapshot_a.dev_version, + intervals=[ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), + ], + 
dev_intervals=[], + pending_restatement_intervals=[], + ), + SnapshotIntervals( + name=snapshot_b.name, + identifier=snapshot_b.identifier, + version=snapshot_b.version, + dev_version=snapshot_b.dev_version, + intervals=[ + (to_timestamp("2020-01-03"), to_timestamp("2020-01-04")), + ], + dev_intervals=[], + pending_restatement_intervals=[], + ), + ] + expected_intervals = sorted(expected_intervals, key=lambda x: (x.name, x.identifier or "")) + + with time_machine.travel("2020-01-05 01:00:00 UTC"): + # Add a new interval for the new snapshot + state_sync.add_interval(snapshot_b, "2020-01-03", "2020-01-03") + assert ( + sorted( + state_sync.interval_state.get_snapshot_intervals([snapshot_a, snapshot_b]), + key=lambda x: (x.name, x.identifier or ""), + ) + == expected_intervals + ) + snapshots = state_sync.get_snapshots([snapshot_a, snapshot_b]) + assert snapshots[snapshot_a.snapshot_id].pending_restatement_intervals == [ + (to_timestamp("2020-01-04"), to_timestamp("2020-01-06")), + ] + assert snapshots[snapshot_b.snapshot_id].pending_restatement_intervals == [ + (to_timestamp("2020-01-04"), to_timestamp("2020-01-06")), + ] + + state_sync.compact_intervals() + snapshots = state_sync.get_snapshots([snapshot_a, snapshot_b]) + assert snapshots[snapshot_a.snapshot_id].intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), + ] + assert snapshots[snapshot_a.snapshot_id].pending_restatement_intervals == [ + (to_timestamp("2020-01-04"), to_timestamp("2020-01-06")), + ] + assert snapshots[snapshot_b.snapshot_id].pending_restatement_intervals == [ + (to_timestamp("2020-01-04"), to_timestamp("2020-01-06")), + ] + assert snapshots[snapshot_b.snapshot_id].intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), + ] + + state_sync.compact_intervals() + snapshots = state_sync.get_snapshots([snapshot_a, snapshot_b]) + assert snapshots[snapshot_a.snapshot_id].intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), + ] + assert 
snapshots[snapshot_a.snapshot_id].pending_restatement_intervals == [ + (to_timestamp("2020-01-04"), to_timestamp("2020-01-06")), + ] + assert snapshots[snapshot_b.snapshot_id].pending_restatement_intervals == [ + (to_timestamp("2020-01-04"), to_timestamp("2020-01-06")), + ] + assert snapshots[snapshot_b.snapshot_id].intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), + ] + + expected_intervals = [ + SnapshotIntervals( + name=snapshot_a.name, + identifier=None, + version=snapshot_a.version, + dev_version=None, + intervals=[], + dev_intervals=[], + pending_restatement_intervals=[ + (to_timestamp("2020-01-05"), to_timestamp("2020-01-06")), + ], + ), + SnapshotIntervals( + name=snapshot_a.name, + identifier=snapshot_a.identifier, + version=snapshot_a.version, + dev_version=snapshot_a.dev_version, + intervals=[ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), + ], + dev_intervals=[], + pending_restatement_intervals=[], + ), + SnapshotIntervals( + name=snapshot_b.name, + identifier=snapshot_b.identifier, + version=snapshot_b.version, + dev_version=snapshot_b.dev_version, + intervals=[ + (to_timestamp("2020-01-03"), to_timestamp("2020-01-04")), + ], + dev_intervals=[], + pending_restatement_intervals=[], + ), + ] + expected_intervals = sorted(expected_intervals, key=lambda x: (x.name, x.identifier or "")) + + with time_machine.travel("2020-01-05 02:00:00 UTC"): + # Add a new interval for the previous snapshot + state_sync.add_interval(snapshot_a, "2020-01-04", "2020-01-04") + assert ( + sorted( + state_sync.interval_state.get_snapshot_intervals([snapshot_a, snapshot_b]), + key=lambda x: (x.name, x.identifier or ""), + ) + == expected_intervals + ) + snapshots = state_sync.get_snapshots([snapshot_a, snapshot_b]) + assert snapshots[snapshot_a.snapshot_id].pending_restatement_intervals == [ + (to_timestamp("2020-01-05"), to_timestamp("2020-01-06")), + ] + assert snapshots[snapshot_b.snapshot_id].pending_restatement_intervals == [ + 
(to_timestamp("2020-01-05"), to_timestamp("2020-01-06")), + ] + + state_sync.compact_intervals() + snapshots = state_sync.get_snapshots([snapshot_a, snapshot_b]) + assert snapshots[snapshot_a.snapshot_id].intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), + ] + assert snapshots[snapshot_a.snapshot_id].pending_restatement_intervals == [ + (to_timestamp("2020-01-05"), to_timestamp("2020-01-06")), + ] + assert snapshots[snapshot_b.snapshot_id].pending_restatement_intervals == [ + (to_timestamp("2020-01-05"), to_timestamp("2020-01-06")), + ] + assert snapshots[snapshot_b.snapshot_id].intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), + ] + + expected_intervals = [ + SnapshotIntervals( + name=snapshot_a.name, + identifier=snapshot_a.identifier, + version=snapshot_a.version, + dev_version=snapshot_a.dev_version, + intervals=[ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), + ], + dev_intervals=[], + pending_restatement_intervals=[], + ), + SnapshotIntervals( + name=snapshot_b.name, + identifier=snapshot_b.identifier, + version=snapshot_b.version, + dev_version=snapshot_b.dev_version, + intervals=[ + (to_timestamp("2020-01-03"), to_timestamp("2020-01-04")), + (to_timestamp("2020-01-05"), to_timestamp("2020-01-06")), + ], + dev_intervals=[], + pending_restatement_intervals=[], + ), + ] + expected_intervals = sorted(expected_intervals, key=lambda x: (x.name, x.identifier or "")) + + with time_machine.travel("2020-01-05 03:00:00 UTC"): + state_sync.add_interval(snapshot_b, "2020-01-05", "2020-01-05") + assert ( + sorted( + state_sync.interval_state.get_snapshot_intervals([snapshot_a, snapshot_b]), + key=lambda x: (x.name, x.identifier or ""), + ) + == expected_intervals + ) + snapshots = state_sync.get_snapshots([snapshot_a, snapshot_b]) + assert snapshots[snapshot_a.snapshot_id].pending_restatement_intervals == [] + assert snapshots[snapshot_b.snapshot_id].pending_restatement_intervals == [] + + 
state_sync.compact_intervals() + snapshots = state_sync.get_snapshots([snapshot_a, snapshot_b]) + assert snapshots[snapshot_a.snapshot_id].pending_restatement_intervals == [] + assert snapshots[snapshot_a.snapshot_id].intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), + ] + assert snapshots[snapshot_b.snapshot_id].pending_restatement_intervals == [] + assert snapshots[snapshot_b.snapshot_id].intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), + ] + + +def test_get_environments_summary( + state_sync: EngineAdapterStateSync, + make_snapshot: t.Callable, +) -> None: + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot]) + + now_ts = now_timestamp() + env_a_ttl = now_ts - 1000 + + env_a = Environment( + name="test_environment_a", + snapshots=[snapshot.table_info], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + expiration_ts=env_a_ttl, + ) + state_sync.promote(env_a) + + env_b_ttl = now_ts + 1000 + env_b = env_a.copy(update={"name": "test_environment_b", "expiration_ts": env_b_ttl}) + state_sync.promote(env_b) + + prod = Environment( + name="prod", + snapshots=[snapshot.table_info], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + ) + state_sync.promote(prod) + + actual = set(state_sync.get_environments_summary()) + expected = {prod.summary, env_a.summary, env_b.summary} + assert actual == expected + + +def test_get_environments_summary_only_prod( + state_sync: EngineAdapterStateSync, + make_snapshot: t.Callable, +) -> None: + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select a, ds"), + ), + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + state_sync.push_snapshots([snapshot]) + + prod = Environment( + name="prod", + 
snapshots=[snapshot.table_info], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + ) + state_sync.promote(prod) + actual = state_sync.get_environments_summary() + expected = [prod.summary] + assert actual == expected + + +def test_get_environments_summary_no_env(state_sync: EngineAdapterStateSync) -> None: + assert state_sync.get_environments_summary() == [] + + +@time_machine.travel("2020-01-05 00:00:00 UTC") +def test_compact_intervals_pending_restatement_many_snapshots_same_version( + state_sync: EngineAdapterStateSync, + make_snapshot: t.Callable, + get_snapshot_intervals: t.Callable, +) -> None: + snapshots = [ + make_snapshot( + SqlModel( + name="a", + cron="@daily", + query=parse_one(f"select {i}, ds"), + ), + version="a", + ) + for i in range(100) + ] + + state_sync.push_snapshots(snapshots) + + for snapshot in snapshots: + state_sync.add_interval(snapshot, "2020-01-01", "2020-01-01") + state_sync.add_interval(snapshot, "2020-01-02", "2020-01-02") + state_sync.add_interval(snapshot, "2020-01-03", "2020-01-03") + state_sync.add_interval(snapshot, "2020-01-04", "2020-01-04") + + pending_restatement_intervals = [ + (to_timestamp("2020-01-03"), to_timestamp("2020-01-05")), + ] + state_sync.add_snapshots_intervals( + [ + SnapshotIntervals( + name=snapshots[0].name, + identifier=snapshots[0].identifier, + version=snapshots[0].version, + dev_version=snapshots[0].dev_version, + intervals=[], + dev_intervals=[], + pending_restatement_intervals=pending_restatement_intervals, + ) + ] + ) + + # Because of the number of snapshots requiring compaction, some compacted records will have different creation + # timestamps. 
+ state_sync.compact_intervals() + + assert state_sync.get_snapshots([snapshots[0].snapshot_id])[ + snapshots[0].snapshot_id + ].pending_restatement_intervals == [ + (to_timestamp("2020-01-03"), to_timestamp("2020-01-05")), + ] + + +def test_update_environment_statements(state_sync: EngineAdapterStateSync): + assert state_sync.get_environment_statements(environment="dev") == [] + + environment = Environment( + name="dev", + snapshots=[], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + ) + environment_statements = [ + EnvironmentStatements( + before_all=["CREATE OR REPLACE TABLE table_1 AS SELECT 'a'"], + after_all=["CREATE OR REPLACE TABLE table_2 AS SELECT 'b'"], + python_env={}, + ) + ] + + state_sync.environment_state.update_environment(environment=environment) + state_sync.environment_state.update_environment_statements( + environment.name, environment.plan_id, environment_statements + ) + + environment_statements_dev = state_sync.get_environment_statements(environment="dev") + assert environment_statements_dev[0].before_all == [ + "CREATE OR REPLACE TABLE table_1 AS SELECT 'a'" + ] + assert environment_statements_dev[0].after_all == [ + "CREATE OR REPLACE TABLE table_2 AS SELECT 'b'" + ] + + environment_statements = [ + EnvironmentStatements( + before_all=["CREATE OR REPLACE TABLE table_1 AS SELECT 'a'"], + after_all=[ + "@grant_schema_usage()", + "@grant_select_privileges()", + ], + python_env={}, + ) + ] + + state_sync.environment_state.update_environment(environment=environment) + state_sync.environment_state.update_environment_statements( + environment.name, environment.plan_id, environment_statements + ) + + environment_statements_dev = state_sync.get_environment_statements(environment="dev") + assert environment_statements_dev[0].before_all == [ + "CREATE OR REPLACE TABLE table_1 AS SELECT 'a'" + ] + assert environment_statements_dev[0].after_all == [ + "@grant_schema_usage()", + "@grant_select_privileges()", + ] + + +def 
test_get_snapshots_by_names( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable[..., Snapshot] +): + assert state_sync.get_snapshots_by_names(snapshot_names=[]) == set() + + snap_a_v1, snap_a_v2 = ( + make_snapshot( + SqlModel( + name="a", + query=parse_one(f"select {i}, ds"), + ), + version="a", + ) + for i in range(2) + ) + + snap_b = make_snapshot( + SqlModel( + name="b", + query=parse_one(f"select 'b' as b, ds"), + ), + version="b", + ) + + state_sync.push_snapshots([snap_a_v1, snap_a_v2, snap_b]) + + assert {s.snapshot_id for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"'])} == { + snap_a_v1.snapshot_id, + snap_a_v2.snapshot_id, + } + assert { + s.snapshot_id for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"', '"b"']) + } == { + snap_a_v1.snapshot_id, + snap_a_v2.snapshot_id, + snap_b.snapshot_id, + } + + +def test_get_snapshots_by_names_include_expired( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable[..., Snapshot] +): + now_ts = now_timestamp() + + normal_a = make_snapshot( + SqlModel( + name="a", + query=parse_one(f"select 1, ds"), + ), + version="a", + ) + + expired_a = make_snapshot( + SqlModel( + name="a", + query=parse_one(f"select 2, ds"), + ), + version="a", + ttl="in 10 seconds", + ) + expired_a.updated_ts = now_ts - ( + 1000 * 15 + ) # last updated 15 seconds ago, expired 10 seconds from last updated = expired 5 seconds ago + + state_sync.push_snapshots([normal_a, expired_a]) + + assert { + s.snapshot_id + for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"'], current_ts=now_ts) + } == {normal_a.snapshot_id} + assert { + s.snapshot_id + for s in state_sync.get_snapshots_by_names(snapshot_names=['"a"'], exclude_expired=False) + } == { + normal_a.snapshot_id, + expired_a.snapshot_id, + } + + # wind back time to 10 seconds ago (before the expired snapshot is expired - it expired 5 seconds ago) to test it stil shows in a normal query + assert { + s.snapshot_id + for s in 
state_sync.get_snapshots_by_names( + snapshot_names=['"a"'], current_ts=(now_ts - (10 * 1000)) + ) + } == {normal_a.snapshot_id, expired_a.snapshot_id} + + +def test_state_version_is_too_old( + state_sync: EngineAdapterStateSync, mocker: MockerFixture +) -> None: + state_sync.engine_adapter.replace_query( + "sqlmesh._versions", + pd.DataFrame( + [{"schema_version": 59, "sqlmesh_version": "0.133.0", "sqlglot_version": "25.31.4"}] + ), + target_columns_to_types={ + "schema_version": exp.DataType.build("int"), + "sqlglot_version": exp.DataType.build("text"), + "sqlmesh_version": exp.DataType.build("text"), + }, + ) + + with pytest.raises( + StateMigrationError, + match="The current state belongs to an old version of SQLMesh that is no longer supported. Please upgrade to 0.134.0 first before upgrading to.*", + ): + state_sync.migrate(skip_backup=True) diff --git a/tests/core/test_audit.py b/tests/core/test_audit.py index 9ab97a43b4..66897ed088 100644 --- a/tests/core/test_audit.py +++ b/tests/core/test_audit.py @@ -1,18 +1,21 @@ +import json import pytest from sqlglot import exp, parse_one from sqlmesh.core import constants as c +from sqlmesh.core.config.model import ModelDefaultsConfig from sqlmesh.core.context import Context +from sqlmesh.core.node import DbtNodeInfo from sqlmesh.core.audit import ( - BUILT_IN_AUDITS, ModelAudit, StandaloneAudit, builtin, load_audit, load_multiple_audits, ) -from sqlmesh.core.dialect import parse +from sqlmesh.core.dialect import parse, jinja_query from sqlmesh.core.model import ( + FullKind, IncrementalByTimeRangeKind, Model, SeedModel, @@ -78,6 +81,8 @@ def test_load(assert_exp_eq): col IS NULL """, ) + assert audit.query_._parsed is not None + assert audit.query_._parsed_dialect == "spark" def test_load_standalone(assert_exp_eq): @@ -119,6 +124,8 @@ def test_load_standalone(assert_exp_eq): col IS NULL """, ) + assert audit.query_._parsed is not None + assert audit.query_._parsed_dialect == "spark" def 
test_load_standalone_default_catalog(assert_exp_eq): @@ -166,7 +173,7 @@ def test_load_standalone_default_catalog(assert_exp_eq): """, ) assert_exp_eq( - audit.render_query(audit), + audit.render_audit_query(), """ SELECT * @@ -303,6 +310,49 @@ def test_load_multiple(assert_exp_eq): ) +def test_load_with_dictionary_defaults(): + expressions = parse( + """ + AUDIT ( + name my_audit, + dialect spark, + defaults ( + field1 = some_column, + field2 = 3 + ), + ); + + SELECT 1 + """ + ) + + audit = load_audit(expressions, dialect="spark") + assert audit.defaults.keys() == {"field1", "field2"} + for value in audit.defaults.values(): + assert isinstance(value, exp.Expression) + + +def test_load_with_single_defaults(): + # testing it also works with a single default with no trailing comma + expressions = parse( + """ + AUDIT ( + name my_audit, + defaults ( + field1 = some_column + ), + ); + + SELECT 1 + """ + ) + + audit = load_audit(expressions, dialect="duckdb") + assert audit.defaults.keys() == {"field1"} + for value in audit.defaults.values(): + assert isinstance(value, exp.Expression) + + def test_no_audit_statement(): expressions = parse( """ @@ -347,7 +397,7 @@ def test_no_query(): def test_macro(model: Model): - expected_query = """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE "a" IS NULL""" + expected_query = """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE "a" IS NULL""" audit = ModelAudit( name="test_audit", @@ -359,8 +409,8 @@ def test_macro(model: Model): query="JINJA_QUERY_BEGIN; SELECT * FROM {{ this_model }} WHERE a IS NULL; JINJA_END;", ) - assert audit.render_query(model).sql() == expected_query - assert audit_jinja.render_query(model).sql() == expected_query + assert model.render_audit_query(audit).sql() == expected_query + assert model.render_audit_query(audit_jinja).sql() == expected_query def 
test_load_with_defaults(model, assert_exp_eq): @@ -394,190 +444,232 @@ def test_load_with_defaults(model, assert_exp_eq): "field4": exp.Literal.string("some string"), } assert_exp_eq( - audit.render_query(model, field4=exp.Literal.string("overridden")), + model.render_audit_query(audit, field4=exp.Literal.string("overridden")), 'SELECT * FROM "db"."table" AS "table" WHERE TRUE AND \'overridden\' IN (\'some string\', \'other string\') AND "some_column" = 3 AND "other_column" <> \'overridden\'', ) def test_not_null_audit(model: Model): - rendered_query_a = builtin.not_null_audit.render_query( - model, + rendered_query_a = model.render_audit_query( + builtin.not_null_audit, columns=[exp.to_column("a")], ) assert ( rendered_query_a.sql() - == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE "a" IS NULL AND TRUE""" + == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE "a" IS NULL AND TRUE""" ) - rendered_query_a_and_b = builtin.not_null_audit.render_query( - model, + rendered_query_a_and_b = model.render_audit_query( + builtin.not_null_audit, columns=[exp.to_column("a"), exp.to_column("b")], ) assert ( rendered_query_a_and_b.sql() - == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE ("a" IS NULL OR "b" IS NULL) AND TRUE""" + == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE ("a" IS NULL OR "b" IS NULL) AND TRUE""" ) def test_not_null_audit_default_catalog(model_default_catalog: Model): - rendered_query_a = builtin.not_null_audit.render_query( - model_default_catalog, + rendered_query_a = model_default_catalog.render_audit_query( + builtin.not_null_audit, columns=[exp.to_column("a")], ) assert ( rendered_query_a.sql() - == """SELECT * FROM 
(SELECT * FROM "test_catalog"."db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE "a" IS NULL AND TRUE""" + == """SELECT * FROM (SELECT * FROM "test_catalog"."db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE "a" IS NULL AND TRUE""" ) - rendered_query_a_and_b = builtin.not_null_audit.render_query( - model_default_catalog, + rendered_query_a_and_b = model_default_catalog.render_audit_query( + builtin.not_null_audit, columns=[exp.to_column("a"), exp.to_column("b")], ) assert ( rendered_query_a_and_b.sql() - == """SELECT * FROM (SELECT * FROM "test_catalog"."db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE ("a" IS NULL OR "b" IS NULL) AND TRUE""" + == """SELECT * FROM (SELECT * FROM "test_catalog"."db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE ("a" IS NULL OR "b" IS NULL) AND TRUE""" ) def test_unique_values_audit(model: Model): - rendered_query_a = builtin.unique_values_audit.render_query( - model, columns=[exp.to_column("a")], condition=parse_one("b IS NULL") + rendered_query_a = model.render_audit_query( + builtin.unique_values_audit, columns=[exp.to_column("a")], condition=parse_one("b IS NULL") ) assert ( rendered_query_a.sql() - == 'SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY "a" ORDER BY "a") AS "rank_a" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_q_0" WHERE "b" IS NULL) AS "_q_1" WHERE "rank_a" > 1' + == 'SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY "a" ORDER BY "a") AS "rank_a" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE "b" IS NULL) AS "_1" WHERE "rank_a" > 1' ) - rendered_query_a_and_b = builtin.unique_values_audit.render_query( - model, columns=[exp.to_column("a"), exp.to_column("b")] + 
rendered_query_a_and_b = model.render_audit_query( + builtin.unique_values_audit, columns=[exp.to_column("a"), exp.to_column("b")] ) assert ( rendered_query_a_and_b.sql() - == 'SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY "a" ORDER BY "a") AS "rank_a", ROW_NUMBER() OVER (PARTITION BY "b" ORDER BY "b") AS "rank_b" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_q_0" WHERE TRUE) AS "_q_1" WHERE "rank_a" > 1 OR "rank_b" > 1' + == 'SELECT * FROM (SELECT ROW_NUMBER() OVER (PARTITION BY "a" ORDER BY "a") AS "rank_a", ROW_NUMBER() OVER (PARTITION BY "b" ORDER BY "b") AS "rank_b" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE TRUE) AS "_1" WHERE "rank_a" > 1 OR "rank_b" > 1' ) def test_accepted_values_audit(model: Model): - rendered_query = builtin.accepted_values_audit.render_query( - model, + rendered_query = model.render_audit_query( + builtin.accepted_values_audit, column=exp.to_column("a"), is_in=["value_a", "value_b"], ) assert ( rendered_query.sql() - == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE NOT "a" IN ('value_a', 'value_b') AND TRUE""" + == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE NOT "a" IN ('value_a', 'value_b') AND TRUE""" ) def test_number_of_rows_audit(model: Model): - rendered_query = builtin.number_of_rows_audit.render_query( - model, + rendered_query = model.render_audit_query( + builtin.number_of_rows_audit, threshold=0, ) assert ( rendered_query.sql() - == """SELECT COUNT(*) FROM (SELECT 1 FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE TRUE LIMIT 0 + 1) AS "_q_1" HAVING COUNT(*) <= 0""" + == """SELECT COUNT(*) FROM (SELECT 1 FROM (SELECT * FROM 
"db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE TRUE LIMIT 0 + 1) AS "_1" HAVING COUNT(*) <= 0""" ) def test_forall_audit(model: Model): - rendered_query_a = builtin.forall_audit.render_query( - model, + rendered_query_a = model.render_audit_query( + builtin.forall_audit, criteria=[parse_one("a >= b")], ) assert ( rendered_query_a.sql() - == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE NOT ("a" >= "b") AND TRUE""" + == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE NOT ("a" >= "b") AND TRUE""" ) - rendered_query_a = builtin.forall_audit.render_query( - model, + rendered_query_a = model.render_audit_query( + builtin.forall_audit, criteria=[parse_one("a >= b"), parse_one("c + d - e < 1.0")], ) assert ( rendered_query_a.sql() - == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE (NOT ("a" >= "b") OR NOT ("c" + "d" - "e" < 1.0)) AND TRUE""" + == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE (NOT ("a" >= "b") OR NOT ("c" + "d" - "e" < 1.0)) AND TRUE""" ) - rendered_query_a = builtin.forall_audit.render_query( - model, + rendered_query_a = model.render_audit_query( + builtin.forall_audit, criteria=[parse_one("a >= b"), parse_one("c + d - e < 1.0")], condition=parse_one("f = 42"), ) assert ( rendered_query_a.sql() - == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE (NOT ("a" >= "b") OR NOT ("c" + "d" - "e" < 1.0)) AND "f" = 42""" + == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE (NOT ("a" >= "b") OR NOT ("c" + "d" - "e" < 
1.0)) AND "f" = 42""" ) def test_accepted_range_audit(model: Model): - rendered_query = builtin.accepted_range_audit.render_query( - model, column=exp.to_column("a"), min_v=0 + rendered_query = model.render_audit_query( + builtin.accepted_range_audit, column=exp.to_column("a"), min_v=0 ) assert ( rendered_query.sql() - == 'SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_q_0" WHERE "a" < 0 AND TRUE' + == 'SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE "a" < 0 AND TRUE' ) - rendered_query = builtin.accepted_range_audit.render_query( - model, column=exp.to_column("a"), max_v=100, inclusive=exp.false() + rendered_query = model.render_audit_query( + builtin.accepted_range_audit, column=exp.to_column("a"), max_v=100, inclusive=exp.false() ) assert ( rendered_query.sql() - == 'SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_q_0" WHERE "a" >= 100 AND TRUE' + == 'SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE "a" >= 100 AND TRUE' ) - rendered_query = builtin.accepted_range_audit.render_query( - model, column=exp.to_column("a"), min_v=100, max_v=100 + rendered_query = model.render_audit_query( + builtin.accepted_range_audit, column=exp.to_column("a"), min_v=100, max_v=100 ) assert ( rendered_query.sql() - == 'SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_q_0" WHERE ("a" < 100 OR "a" > 100) AND TRUE' + == 'SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE ("a" < 100 OR "a" > 100) AND TRUE' ) def test_at_least_one_audit(model: Model): - rendered_query = builtin.at_least_one_audit.render_query( - model, + 
rendered_query = model.render_audit_query( + builtin.at_least_one_audit, column=exp.to_column("a"), ) assert ( rendered_query.sql() - == 'SELECT 1 AS "1" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_q_0" WHERE TRUE GROUP BY 1 HAVING COUNT("a") = 0' + == 'SELECT 1 AS "1" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE TRUE GROUP BY 1 HAVING COUNT("a") = 0' ) def test_mutually_exclusive_ranges_audit(model: Model): - rendered_query = builtin.mutually_exclusive_ranges_audit.render_query( - model, + rendered_query = model.render_audit_query( + builtin.mutually_exclusive_ranges_audit, lower_bound_column=exp.to_column("a"), upper_bound_column=exp.to_column("a"), ) assert ( rendered_query.sql() - == '''WITH "window_functions" AS (SELECT "a" AS "lower_bound", "a" AS "upper_bound", LEAD("a") OVER (ORDER BY "a", "a") AS "next_lower_bound", ROW_NUMBER() OVER (ORDER BY "a" DESC, "a" DESC) = 1 AS "is_last_record" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE TRUE), "calc" AS (SELECT *, COALESCE("lower_bound" <= "upper_bound", FALSE) AS "lower_bound_lte_upper_bound", COALESCE("upper_bound" <= "next_lower_bound", "is_last_record", FALSE) AS "upper_bound_lte_next_lower_bound" FROM "window_functions" AS "window_functions"), "validation_errors" AS (SELECT * FROM "calc" AS "calc" WHERE NOT ("lower_bound_lte_upper_bound" AND "upper_bound_lte_next_lower_bound")) SELECT * FROM "validation_errors" AS "validation_errors"''' + == '''WITH "window_functions" AS (SELECT "a" AS "lower_bound", "a" AS "upper_bound", LEAD("a") OVER (ORDER BY "a", "a") AS "next_lower_bound", ROW_NUMBER() OVER (ORDER BY "a" DESC, "a" DESC) = 1 AS "is_last_record" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE TRUE), "calc" AS (SELECT 
*, COALESCE("lower_bound" <= "upper_bound", FALSE) AS "lower_bound_lte_upper_bound", COALESCE("upper_bound" <= "next_lower_bound", "is_last_record", FALSE) AS "upper_bound_lte_next_lower_bound" FROM "window_functions" AS "window_functions"), "validation_errors" AS (SELECT * FROM "calc" AS "calc" WHERE NOT ("lower_bound_lte_upper_bound" AND "upper_bound_lte_next_lower_bound")) SELECT * FROM "validation_errors" AS "validation_errors"''' ) def test_sequential_values_audit(model: Model): - rendered_query = builtin.sequential_values_audit.render_query( - model, + rendered_query = model.render_audit_query( + builtin.sequential_values_audit, column=exp.to_column("a"), ) assert ( rendered_query.sql() - == '''WITH "windowed" AS (SELECT "a", LAG("a") OVER (ORDER BY "a") AS "prv" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE TRUE), "validation_errors" AS (SELECT * FROM "windowed" AS "windowed" WHERE NOT ("a" = "prv" + 1)) SELECT * FROM "validation_errors" AS "validation_errors"''' + == '''WITH "windowed" AS (SELECT "a", LAG("a") OVER (ORDER BY "a") AS "prv" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE TRUE), "validation_errors" AS (SELECT * FROM "windowed" AS "windowed" WHERE NOT ("a" = "prv" + 1)) SELECT * FROM "validation_errors" AS "validation_errors"''' ) def test_chi_square_audit(model: Model): - rendered_query = builtin.chi_square_audit.render_query( - model, + rendered_query = model.render_audit_query( + builtin.chi_square_audit, column_a=exp.to_column("a"), column_b=exp.to_column("b"), critical_value=exp.convert(9.48773), ) assert ( rendered_query.sql() - == """WITH "samples" AS (SELECT "a" AS "x_a", "b" AS "x_b" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE (NOT "a" IS NULL AND NOT "b" IS NULL) AND TRUE), "contingency_table" AS (SELECT "x_a", "x_b", 
COUNT(*) AS "observed", (SELECT COUNT(*) FROM "samples" AS "t" WHERE "r"."x_a" = "t"."x_a") AS "tot_a", (SELECT COUNT(*) FROM "samples" AS "t" WHERE "r"."x_b" = "t"."x_b") AS "tot_b", (SELECT COUNT(*) FROM "samples" AS "samples") AS "g_t" /* g_t is the grand total */ FROM "samples" AS "r" GROUP BY "x_a", "x_b") SELECT ((SELECT COUNT(DISTINCT "x_a") FROM "contingency_table" AS "contingency_table") - 1) * ((SELECT COUNT(DISTINCT "x_b") FROM "contingency_table" AS "contingency_table") - 1) AS "degrees_of_freedom", SUM(("observed" - ("tot_a" * "tot_b" / "g_t")) * ("observed" - ("tot_a" * "tot_b" / "g_t")) / ("tot_a" * "tot_b" / "g_t")) AS "chi_square" FROM "contingency_table" AS "contingency_table" HAVING NOT "chi_square" > 9.48773""" + == """WITH "samples" AS (SELECT "a" AS "x_a", "b" AS "x_b" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE (NOT "a" IS NULL AND NOT "b" IS NULL) AND TRUE), "contingency_table" AS (SELECT "x_a", "x_b", COUNT(*) AS "observed", (SELECT COUNT(*) FROM "samples" AS "t" WHERE "r"."x_a" = "t"."x_a") AS "tot_a", (SELECT COUNT(*) FROM "samples" AS "t" WHERE "r"."x_b" = "t"."x_b") AS "tot_b", (SELECT COUNT(*) FROM "samples" AS "samples") AS "g_t" /* g_t is the grand total */ FROM "samples" AS "r" GROUP BY "x_a", "x_b") SELECT ((SELECT COUNT(DISTINCT "x_a") FROM "contingency_table" AS "contingency_table") - 1) * ((SELECT COUNT(DISTINCT "x_b") FROM "contingency_table" AS "contingency_table") - 1) AS "degrees_of_freedom", SUM(("observed" - ("tot_a" * "tot_b" / "g_t")) * ("observed" - ("tot_a" * "tot_b" / "g_t")) / ("tot_a" * "tot_b" / "g_t")) AS "chi_square" FROM "contingency_table" AS "contingency_table" /* H0: the two variables are independent */ /* H1: the two variables are dependent */ /* if chi_square > critical_value, reject H0 */ /* if chi_square <= critical_value, fail to reject H0 */ HAVING NOT "chi_square" > 9.48773""" + ) + + +def test_pattern_audits(model: Model): + 
rendered_query = model.render_audit_query( + builtin.match_regex_pattern_list_audit, + column=exp.to_column("a"), + patterns=[r"^\d.*", ".*!$"], + ) + assert ( + rendered_query.sql() + == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE (NOT REGEXP_LIKE("a", \'^\\d.*\') AND NOT REGEXP_LIKE("a", \'.*!$\')) AND TRUE""" + ) + + rendered_query = model.render_audit_query( + builtin.not_match_regex_pattern_list_audit, + column=exp.to_column("a"), + patterns=[r"^\d.*", ".*!$"], + ) + assert ( + rendered_query.sql() + == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE (REGEXP_LIKE("a", \'^\\d.*\') OR REGEXP_LIKE("a", \'.*!$\')) AND TRUE""" + ) + + rendered_query = model.render_audit_query( + builtin.match_like_pattern_list, + column=exp.to_column("a"), + patterns=["jim%", "pam%"], + ) + assert ( + rendered_query.sql() + == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE (NOT "a" LIKE \'jim%\' AND NOT "a" LIKE \'pam%\') AND TRUE""" + ) + + rendered_query = model.render_audit_query( + builtin.not_match_like_pattern_list_audit, + column=exp.to_column("a"), + patterns=["jim%", "pam%"], + ) + assert ( + rendered_query.sql() + == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN \'1970-01-01\' AND \'1970-01-01\') AS "_0" WHERE ("a" LIKE \'jim%\' OR "a" LIKE \'pam%\') AND TRUE""" ) @@ -588,11 +680,16 @@ def test_standalone_audit(model: Model, assert_exp_eq): assert audit.depends_on == {model.fqn} - rendered_query = audit.render_query(audit) + rendered_query = audit.render_audit_query() assert_exp_eq( rendered_query, """SELECT * FROM "db"."test_model" AS "test_model" WHERE "col" IS NULL""" ) + with pytest.raises(AuditConfigError) as ex: + StandaloneAudit(name="test_audit", 
query=parse_one("SELECT 1"), blocking=True) + + assert "Standalone audits cannot be blocking: 'test_audit'." in str(ex.value) + def test_render_definition(): expressions = parse( @@ -634,6 +731,27 @@ def test_render_definition(): assert "def test_macro(evaluator, v):" in format_model_expressions(audit.render_definition()) +def test_render_definition_dbt_node_info(): + node_info = DbtNodeInfo( + unique_id="test.project.my_audit", name="my_audit", fqn="project.my_audit" + ) + + audit = StandaloneAudit(name="my_audit", dbt_node_info=node_info, query=jinja_query("select 1")) + + assert ( + audit.render_definition()[0].sql(pretty=True) + == """AUDIT ( + name my_audit, + dbt_node_info ( + fqn := 'project.my_audit', + name := 'my_audit', + unique_id := 'test.project.my_audit' + ), + standalone TRUE +)""" + ) + + def test_text_diff(): expressions = parse( """ @@ -680,43 +798,45 @@ def test_text_diff(): def test_non_blocking_builtin(): + from sqlmesh.core.audit.builtin import BUILT_IN_AUDITS + assert BUILT_IN_AUDITS["not_null_non_blocking"].blocking is False assert BUILT_IN_AUDITS["not_null_non_blocking"].name == "not_null_non_blocking" assert BUILT_IN_AUDITS["not_null"].query == BUILT_IN_AUDITS["not_null_non_blocking"].query def test_string_length_between_audit(model: Model): - rendered_query = builtin.string_length_between_audit.render_query( - model, + rendered_query = model.render_audit_query( + builtin.string_length_between_audit, column=exp.column("x"), min_v=1, max_v=5, ) assert ( rendered_query.sql() - == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE (LENGTH("x") < 1 OR LENGTH("x") > 5) AND TRUE""" + == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE (LENGTH("x") < 1 OR LENGTH("x") > 5) AND TRUE""" ) def test_not_constant_audit(model: Model): - rendered_query = builtin.not_constant_audit.render_query( 
- model, column=exp.column("x"), condition=exp.condition("x > 1") + rendered_query = model.render_audit_query( + builtin.not_constant_audit, column=exp.column("x"), condition=exp.condition("x > 1") ) assert ( rendered_query.sql() - == """SELECT 1 AS "1" FROM (SELECT COUNT(DISTINCT "x") AS "t_cardinality" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE "x" > 1) AS "r" WHERE "r"."t_cardinality" <= 1""" + == """SELECT 1 AS "1" FROM (SELECT COUNT(DISTINCT "x") AS "t_cardinality" FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE "x" > 1) AS "r" WHERE "r"."t_cardinality" <= 1""" ) def test_condition_with_macro_var(model: Model): - rendered_query = builtin.not_null_audit.render_query( - model, + rendered_query = model.render_audit_query( + builtin.not_null_audit, columns=[exp.column("x")], condition=exp.condition("dt BETWEEN @start_dt AND @end_dt"), ) assert ( rendered_query.sql(dialect="duckdb") - == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_q_0" WHERE "x" IS NULL AND "dt" BETWEEN CAST('1970-01-01 00:00:00+00:00' AS TIMESTAMP) AND CAST('1970-01-01 23:59:59.999999+00:00' AS TIMESTAMP)""" + == """SELECT * FROM (SELECT * FROM "db"."test_model" AS "test_model" WHERE "ds" BETWEEN '1970-01-01' AND '1970-01-01') AS "_0" WHERE "x" IS NULL AND "dt" BETWEEN CAST('1970-01-01 00:00:00+00:00' AS TIMESTAMPTZ) AND CAST('1970-01-01 23:59:59.999999+00:00' AS TIMESTAMPTZ)""" ) @@ -746,7 +866,7 @@ def test_variables(assert_exp_eq): ) assert audit.python_env[c.SQLMESH_VARS] == Executable.value({"test_var": "test_val"}) assert ( - audit.render_query(audit).sql(dialect="bigquery") + audit.render_audit_query().sql(dialect="bigquery") == "SELECT * FROM `db`.`table` AS `table` WHERE `col` = 'test_val'" ) @@ -757,7 +877,7 @@ def test_load_inline_audits(assert_exp_eq): MODEL ( name 
db.table, dialect spark, - audits(does_not_exceed_threshold) + audits(does_not_exceed_threshold, assert_positive_id) ); SELECT id FROM tbl; @@ -779,18 +899,230 @@ def test_load_inline_audits(assert_exp_eq): ) model = load_sql_based_model(expressions) - assert len(model.audits) == 1 - assert len(model.inline_audits) == 2 - assert isinstance(model.inline_audits["assert_positive_id"], ModelAudit) - assert isinstance(model.inline_audits["does_not_exceed_threshold"], ModelAudit) + assert len(model.audits) == 2 + assert len(model.audits_with_args) == 2 + assert isinstance(model.audit_definitions["assert_positive_id"], ModelAudit) + assert isinstance(model.audit_definitions["does_not_exceed_threshold"], ModelAudit) def test_model_inline_audits(sushi_context: Context): model_name = "sushi.waiter_names" - expected_query = 'SELECT * FROM (SELECT * FROM "memory"."sushi"."waiter_names" AS "waiter_names") AS "_q_0" WHERE "id" < 0' + expected_query = 'SELECT * FROM (SELECT * FROM "memory"."sushi"."waiter_names" AS "waiter_names") AS "_0" WHERE "id" < 0' model = sushi_context.get_snapshot(model_name, raise_if_missing=True).node assert isinstance(model, SeedModel) - assert len(model.inline_audits) == 3 - assert isinstance(model.inline_audits["assert_valid_name"], ModelAudit) - assert model.inline_audits["assert_positive_id"].render_query(model).sql() == expected_query + assert len(model.audit_definitions) == 3 + assert isinstance(model.audit_definitions["assert_valid_name"], ModelAudit) + model.render_audit_query(model.audit_definitions["assert_positive_id"]).sql() == expected_query + + +def test_audit_query_normalization(): + model = create_sql_model( + "db.test_model", + parse_one("SELECT a, b, ds"), + kind=FullKind(), + dialect="snowflake", + ) + rendered_audit_query = model.render_audit_query( + builtin.not_null_audit, + columns=[exp.to_column("a")], + ) + assert ( + rendered_audit_query.sql("snowflake") + == """SELECT * FROM "DB"."TEST_MODEL" AS "TEST_MODEL" WHERE "A" IS NULL 
AND TRUE""" + ) + + +def test_rendered_diff(): + audit1 = StandaloneAudit( + name="test_audit", query=parse_one("SELECT * FROM 'test' WHERE @AND(TRUE, NULL) > 2") + ) + + audit2 = StandaloneAudit( + name="test_audit", query=parse_one("SELECT * FROM 'test' WHERE @OR(FALSE, NULL) > 2") + ) + + assert """@@ -6,4 +6,4 @@ + + * + FROM "test" AS "test" + WHERE +- TRUE > 2 ++ FALSE > 2""" in audit1.text_diff(audit2, rendered=True) + + +def test_multiple_audits_with_same_name(): + expressions = parse( + """ + MODEL ( + name db.table, + dialect spark, + audits( + does_not_exceed_threshold(column := id, threshold := 1000), + does_not_exceed_threshold(column := price, threshold := 100), + does_not_exceed_threshold(column := price, threshold := 100) + ) + ); + + SELECT id, price FROM tbl; + + AUDIT ( + name does_not_exceed_threshold, + ); + SELECT * FROM @this_model + WHERE @column >= @threshold; + """ + ) + model = load_sql_based_model(expressions) + assert len(model.audits) == 3 + assert len(model.audits_with_args) == 3 + assert len(model.audit_definitions) == 1 + + # Testing that audit names are identical + assert model.audits[0][0] == model.audits[1][0] == model.audits[2][0] + + # Testing that audit arguments are different for first and second audit + assert model.audits[0][1] != model.audits[1][1] + + # Testing that audit arguments are identical for second and third audit + # This establishes that identical audits are preserved + assert model.audits[1][1] == model.audits[2][1] + + +def test_default_audits_included_when_no_model_audits(): + expressions = parse(""" + MODEL ( + name test.basic_model + ); + SELECT 1 as id, 'test' as name; + """) + + model_defaults = ModelDefaultsConfig( + dialect="duckdb", audits=["not_null(columns := ['id'])", "unique_values(columns := ['id'])"] + ) + model = load_sql_based_model(expressions, defaults=model_defaults.dict()) + + assert len(model.audits) == 2 + audit_names = [audit[0] for audit in model.audits] + assert "not_null" in 
audit_names + assert "unique_values" in audit_names + + # Verify arguments are preserved + for audit_name, audit_args in model.audits: + if audit_name == "not_null": + assert "columns" in audit_args + assert audit_args["columns"].expressions[0].this == "id" + elif audit_name == "unique_values": + assert "columns" in audit_args + assert audit_args["columns"].expressions[0].this == "id" + + for audit_name, audit_args in model.audits_with_args: + if audit_name == "not_null": + assert "columns" in audit_args + assert audit_args["columns"].expressions[0].this == "id" + elif audit_name == "unique_values": + assert "columns" in audit_args + assert audit_args["columns"].expressions[0].this == "id" + + +def test_model_defaults_audits_with_same_name(): + expressions = parse( + """ + MODEL ( + name db.table, + dialect spark, + audits( + does_not_exceed_threshold(column := id, threshold := 1000), + does_not_exceed_threshold(column := price, threshold := 100), + unique_values(columns := ['id']) + ) + ); + + SELECT id, price FROM tbl; + + AUDIT ( + name does_not_exceed_threshold, + ); + SELECT * FROM @this_model + WHERE @column >= @threshold; + """ + ) + + model_defaults = ModelDefaultsConfig( + dialect="duckdb", + audits=[ + "does_not_exceed_threshold(column := price, threshold := 33)", + "does_not_exceed_threshold(column := id, threshold := 65)", + "not_null(columns := ['id'])", + ], + ) + model = load_sql_based_model(expressions, defaults=model_defaults.dict()) + assert len(model.audits) == 6 + assert len(model.audits_with_args) == 6 + assert len(model.audit_definitions) == 1 + + expected_audits = [ + ( + "does_not_exceed_threshold", + {"column": exp.column("price"), "threshold": exp.Literal.number(33)}, + ), + ( + "does_not_exceed_threshold", + {"column": exp.column("id"), "threshold": exp.Literal.number(65)}, + ), + ("not_null", {"columns": exp.convert(["id"])}), + ( + "does_not_exceed_threshold", + {"column": exp.column("id"), "threshold": exp.Literal.number(1000)}, + ), + 
( + "does_not_exceed_threshold", + {"column": exp.column("price"), "threshold": exp.Literal.number(100)}, + ), + ("unique_values", {"columns": exp.convert(["id"])}), + ] + + for (actual_name, actual_args), (expected_name, expected_args) in zip( + model.audits, expected_audits + ): + # Validate the audit names are preserved + assert actual_name == expected_name + for key in expected_args: + # comparing sql representaion is easier + assert actual_args[key].sql() == expected_args[key].sql() + + # Validate audits with args as well along with their arguments + for (actual_audit, actual_args), (expected_name, expected_args) in zip( + model.audits_with_args, expected_audits + ): + assert actual_audit.name == expected_name + for key in expected_args: + assert actual_args[key].sql() == expected_args[key].sql() + + +def test_audit_formatting_flag_serde(): + expressions = parse( + """ + AUDIT ( + name my_audit, + dialect bigquery, + formatting false, + ); + + SELECT * FROM db.table WHERE col = @VAR('test_var') + """ + ) + + audit = load_audit( + expressions, + path="/path/to/audit", + dialect="bigquery", + variables={"test_var": "test_val", "test_var_unused": "unused_val"}, + ) + + audit_json = audit.json() + + assert "formatting" not in json.loads(audit_json) + + deserialized_audit = ModelAudit.parse_raw(audit_json) + assert deserialized_audit.dict() == audit.dict() diff --git a/tests/core/test_config.py b/tests/core/test_config.py index d74efbd41b..9ae239f298 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -3,8 +3,10 @@ import re from pathlib import Path from unittest import mock +import typing as t import pytest +from pytest_mock import MockerFixture from sqlglot import exp from sqlmesh.core.config import ( @@ -12,17 +14,28 @@ DuckDBConnectionConfig, GatewayConfig, ModelDefaultsConfig, + BigQueryConnectionConfig, + MotherDuckConnectionConfig, + BuiltInSchedulerConfig, + EnvironmentSuffixTarget, + TableNamingConvention, ) -from 
sqlmesh.core.config.connection import DuckDBAttachOptions -from sqlmesh.core.config.feature_flag import DbtFeatureFlag, FeatureFlag +from sqlmesh.core.config.connection import DuckDBAttachOptions, RedshiftConnectionConfig from sqlmesh.core.config.loader import ( load_config_from_env, load_config_from_paths, load_config_from_python_module, + load_configs, ) +from sqlmesh.core.context import Context +from sqlmesh.core.engine_adapter.athena import AthenaEngineAdapter +from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter +from sqlmesh.core.engine_adapter.redshift import RedshiftEngineAdapter from sqlmesh.core.notification_target import ConsoleNotificationTarget from sqlmesh.core.user import User from sqlmesh.utils.errors import ConfigError +from sqlmesh.utils import yaml +from sqlmesh.dbt.loader import DbtLoader from tests.utils.test_filesystem import create_temp_file @@ -284,15 +297,32 @@ def test_load_config_from_env(): { "SQLMESH__GATEWAY__CONNECTION__TYPE": "duckdb", "SQLMESH__GATEWAY__CONNECTION__DATABASE": "test_db", - "SQLMESH__FEATURE_FLAGS__DBT__SCD_TYPE_2_SUPPORT": "false", }, ): assert Config.parse_obj(load_config_from_env()) == Config( gateways=GatewayConfig(connection=DuckDBConnectionConfig(database="test_db")), - feature_flags=FeatureFlag(dbt=DbtFeatureFlag(scd_type_2_support=False)), ) +def test_load_config_from_env_fails(): + with mock.patch.dict(os.environ, {"SQLMESH__GATEWAYS__ABCDEF__CONNECTION__PASSWORD": "..."}): + with pytest.raises( + ConfigError, + match="Missing connection type.\n\nVerify your config.yaml and environment variables.", + ): + Config.parse_obj(load_config_from_env()) + + +def test_load_config_from_env_no_config_vars(): + with mock.patch.dict( + os.environ, + { + "DUMMY_ENV_VAR": "dummy", + }, + ): + assert load_config_from_env() == {} + + def test_load_config_from_env_invalid_variable_name(): with mock.patch.dict( os.environ, @@ -307,6 +337,75 @@ def test_load_config_from_env_invalid_variable_name(): 
load_config_from_env() +def test_load_yaml_config_env_var_gateway_override(tmp_path_factory): + config_path = tmp_path_factory.mktemp("yaml_config") / "config.yaml" + with open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """ +gateways: + testing: + connection: + type: motherduck + database: blah +model_defaults: + dialect: bigquery + """ + ) + with mock.patch.dict( + os.environ, + { + "SQLMESH__GATEWAYS__TESTING__STATE_CONNECTION__TYPE": "bigquery", + "SQLMESH__GATEWAYS__TESTING__STATE_CONNECTION__CHECK_IMPORT": "false", + "SQLMESH__DEFAULT_GATEWAY": "testing", + }, + ): + assert load_config_from_paths( + Config, + project_paths=[config_path], + ) == Config( + gateways={ + "testing": GatewayConfig( + connection=MotherDuckConnectionConfig(database="blah"), + state_connection=BigQueryConnectionConfig(check_import=False), + ), + }, + model_defaults=ModelDefaultsConfig(dialect="bigquery"), + default_gateway="testing", + ) + + +def test_load_py_config_env_var_gateway_override(tmp_path_factory): + config_path = tmp_path_factory.mktemp("python_config") / "config.py" + with open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """from sqlmesh.core.config import Config, DuckDBConnectionConfig, GatewayConfig, ModelDefaultsConfig +config = Config(gateways={"duckdb_gateway": GatewayConfig(connection=DuckDBConnectionConfig())}, model_defaults=ModelDefaultsConfig(dialect='')) + """ + ) + with mock.patch.dict( + os.environ, + { + "SQLMESH__GATEWAYS__DUCKDB_GATEWAY__STATE_CONNECTION__TYPE": "bigquery", + "SQLMESH__GATEWAYS__DUCKDB_GATEWAY__STATE_CONNECTION__CHECK_IMPORT": "false", + "SQLMESH__DEFAULT_GATEWAY": "duckdb_gateway", + }, + ): + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + assert config == Config( + gateways={ # type: ignore + "duckdb_gateway": GatewayConfig( + connection=DuckDBConnectionConfig(), + state_connection=BigQueryConnectionConfig(check_import=False), + ), + }, + 
model_defaults=ModelDefaultsConfig(dialect=""), + default_gateway="duckdb_gateway", + ) + + def test_load_config_from_python_module_missing_config(tmp_path): config_path = tmp_path / "missing_config.py" with open(config_path, "w", encoding="utf-8") as fd: @@ -328,31 +427,6 @@ def test_load_config_from_python_module_invalid_config_object(tmp_path): load_config_from_python_module(Config, config_path) -def test_cloud_composer_scheduler_config(tmp_path_factory): - config_path = tmp_path_factory.mktemp("yaml_config") / "config.yaml" - with open(config_path, "w", encoding="utf-8") as fd: - fd.write( - """ -gateways: - another_gateway: - connection: - type: duckdb - database: test_db - scheduler: - type: cloud_composer - airflow_url: https://airflow.url - -model_defaults: - dialect: bigquery - """ - ) - - assert load_config_from_paths( - Config, - project_paths=[config_path], - ) - - @pytest.mark.parametrize( [ "mapping", @@ -417,33 +491,14 @@ def test_environment_catalog_mapping(tmp_path_factory, mapping, expected, dialec ) -def test_load_feature_flag(tmp_path_factory): - config_path = tmp_path_factory.mktemp("yaml_config") / "config.yaml" - with open(config_path, "w", encoding="utf-8") as fd: - fd.write( - """ -gateways: - duckdb_gateway: - connection: - type: duckdb -model_defaults: - dialect: bigquery -feature_flags: - dbt: - scd_type_2_support: false - """ - ) +def test_physical_schema_mapping_mutually_exclusive_with_physical_schema_override() -> None: + Config(physical_schema_override={"foo": "bar"}) # type: ignore + Config(physical_schema_mapping={"^foo$": "bar"}) - assert load_config_from_paths( - Config, - project_paths=[config_path], - ) == Config( - gateways={ - "duckdb_gateway": GatewayConfig(connection=DuckDBConnectionConfig()), - }, - model_defaults=ModelDefaultsConfig(dialect="bigquery"), - feature_flags=FeatureFlag(dbt=DbtFeatureFlag(scd_type_2_support=False)), - ) + with pytest.raises( + ConfigError, match=r"Only 
one.*physical_schema_override.*physical_schema_mapping" + ): + Config(physical_schema_override={"foo": "bar"}, physical_schema_mapping={"^foo$": "bar"}) # type: ignore def test_load_alternative_config_type(yaml_config_path: Path, python_config_path: Path): @@ -476,7 +531,10 @@ def test_connection_config_serialization(): "type": "duckdb", "extensions": [], "pre_ping": False, + "pretty_sql": False, "connector_config": {}, + "secrets": [], + "filesystems": [], "database": "my_db", } assert serialized["default_test_connection"] == { @@ -485,7 +543,10 @@ def test_connection_config_serialization(): "type": "duckdb", "extensions": [], "pre_ping": False, + "pretty_sql": False, "connector_config": {}, + "secrets": [], + "filesystems": [], "database": "my_test_db", } @@ -557,3 +618,899 @@ def test_load_duckdb_attach_config(tmp_path): assert attach_config_2.type == "postgres" assert attach_config_2.path == "dbname=postgres user=postgres host=127.0.0.1" assert attach_config_2.read_only is True + + +def test_load_model_defaults_audits(tmp_path): + config_path = tmp_path / "config_model_defaults_audits.yaml" + with open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """ +model_defaults: + dialect: '' + audits: + - assert_positive_order_ids + - does_not_exceed_threshold(column := id, threshold := 1000) + """ + ) + + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + + assert len(config.model_defaults.audits) == 2 + assert config.model_defaults.audits[0] == ("assert_positive_order_ids", {}) + assert config.model_defaults.audits[1][0] == "does_not_exceed_threshold" + assert type(config.model_defaults.audits[1][1]["column"]) == exp.Column + assert config.model_defaults.audits[1][1]["column"].this.this == "id" + assert type(config.model_defaults.audits[1][1]["threshold"]) == exp.Literal + assert config.model_defaults.audits[1][1]["threshold"].this == "1000" + + +def test_load_model_defaults_statements(tmp_path): + config_path = tmp_path / 
"config_model_defaults_statements.yaml" + with open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """ +model_defaults: + dialect: duckdb + pre_statements: + - SET memory_limit = '10GB' + - CREATE TEMP TABLE temp_data AS SELECT 1 as id + post_statements: + - DROP TABLE IF EXISTS temp_data + - ANALYZE @this_model + - SET memory_limit = '5GB' + on_virtual_update: + - UPDATE stats_table SET last_update = CURRENT_TIMESTAMP + """ + ) + + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + + assert config.model_defaults.pre_statements is not None + assert len(config.model_defaults.pre_statements) == 2 + assert isinstance(exp.maybe_parse(config.model_defaults.pre_statements[0]), exp.Set) + assert isinstance(exp.maybe_parse(config.model_defaults.pre_statements[1]), exp.Create) + + assert config.model_defaults.post_statements is not None + assert len(config.model_defaults.post_statements) == 3 + assert isinstance(exp.maybe_parse(config.model_defaults.post_statements[0]), exp.Drop) + assert isinstance(exp.maybe_parse(config.model_defaults.post_statements[1]), exp.Analyze) + assert isinstance(exp.maybe_parse(config.model_defaults.post_statements[2]), exp.Set) + + assert config.model_defaults.on_virtual_update is not None + assert len(config.model_defaults.on_virtual_update) == 1 + assert isinstance(exp.maybe_parse(config.model_defaults.on_virtual_update[0]), exp.Update) + + +def test_load_model_defaults_validation_statements(tmp_path): + config_path = tmp_path / "config_model_defaults_statements_wrong.yaml" + with open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """ +model_defaults: + dialect: duckdb + pre_statements: + - 313 + """ + ) + + with pytest.raises(TypeError, match=r"expected str instance, int found"): + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + + +def test_scheduler_config(tmp_path_factory): + config_path = tmp_path_factory.mktemp("yaml_config") / "config.yaml" + with 
open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """ +gateways: + builtin_gateway: + scheduler: + type: builtin + +default_scheduler: + type: builtin + +model_defaults: + dialect: bigquery + """ + ) + + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + assert isinstance(config.default_scheduler, BuiltInSchedulerConfig) + assert isinstance(config.get_gateway("builtin_gateway").scheduler, BuiltInSchedulerConfig) + + +def test_multi_gateway_config(tmp_path, mocker: MockerFixture): + config_path = tmp_path / "config_athena_redshift.yaml" + with open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """ +gateways: + redshift: + connection: + type: redshift + user: user + password: '1234' + host: host + database: db + test_connection: + type: redshift + database: test_db + state_connection: + type: duckdb + database: state.db + athena: + connection: + type: athena + aws_access_key_id: '1234' + aws_secret_access_key: accesskey + work_group: group + s3_warehouse_location: s3://location + duckdb: + connection: + type: duckdb + database: db.db + +default_gateway: redshift + +model_defaults: + dialect: redshift + """ + ) + + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + + ctx = Context(paths=tmp_path, config=config) + + assert isinstance(ctx.connection_config, RedshiftConnectionConfig) + assert len(ctx.engine_adapters) == 3 + assert isinstance(ctx.engine_adapters["athena"], AthenaEngineAdapter) + assert isinstance(ctx.engine_adapters["redshift"], RedshiftEngineAdapter) + assert isinstance(ctx.engine_adapters["duckdb"], DuckDBEngineAdapter) + assert ctx.engine_adapter == ctx._get_engine_adapter("redshift") + + # The duckdb engine adapter should be have been set as multithreaded as well + assert ctx.engine_adapters["duckdb"]._multithreaded + + +def test_multi_gateway_single_threaded_config(tmp_path): + config_path = tmp_path / "config_duck_athena.yaml" + with open(config_path, "w", 
encoding="utf-8") as fd: + fd.write( + """ +gateways: + duckdb: + connection: + type: duckdb + database: db.db + athena: + connection: + type: athena + aws_access_key_id: '1234' + aws_secret_access_key: accesskey + work_group: group + s3_warehouse_location: s3://location +default_gateway: duckdb +model_defaults: + dialect: duckdb + """ + ) + + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + + ctx = Context(paths=tmp_path, config=config) + assert isinstance(ctx.connection_config, DuckDBConnectionConfig) + assert len(ctx.engine_adapters) == 2 + assert ctx.engine_adapter == ctx._get_engine_adapter("duckdb") + assert isinstance(ctx.engine_adapters["athena"], AthenaEngineAdapter) + + # In this case athena should use 1 concurrent task as the default gateway is duckdb + assert not ctx.engine_adapters["athena"]._multithreaded + + +def test_trino_schema_location_mapping_syntax(tmp_path): + config_path = tmp_path / "config_trino.yaml" + with open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """ + gateways: + trino: + connection: + type: trino + user: trino + host: trino + catalog: trino + schema_location_mapping: + '^utils$': 's3://utils-bucket/@{schema_name}' + '^landing\\..*$': 's3://raw-data/@{catalog_name}/@{schema_name}' + + default_gateway: trino + + model_defaults: + dialect: trino + """ + ) + + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + + from sqlmesh.core.config.connection import TrinoConnectionConfig + + conn = config.gateways["trino"].connection + assert isinstance(conn, TrinoConnectionConfig) + + assert len(conn.schema_location_mapping) == 2 + + +def test_trino_source_option(tmp_path): + config_path = tmp_path / "config_trino_source.yaml" + with open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """ + gateways: + trino: + connection: + type: trino + user: trino + host: trino + catalog: trino + source: my_sqlmesh_source + + default_gateway: trino + + model_defaults: + 
dialect: trino + """ + ) + + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + + from sqlmesh.core.config.connection import TrinoConnectionConfig + + conn = config.gateways["trino"].connection + assert isinstance(conn, TrinoConnectionConfig) + assert conn.source == "my_sqlmesh_source" + + +def test_gcp_postgres_ip_and_scopes(tmp_path): + config_path = tmp_path / "config_gcp_postgres.yaml" + with open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """ + gateways: + gcp_postgres: + connection: + type: gcp_postgres + check_import: false + instance_connection_string: something + user: user + password: password + db: db + ip_type: private + scopes: + - https://www.googleapis.com/auth/cloud-platform + - https://www.googleapis.com/auth/sqlservice.admin + + default_gateway: gcp_postgres + + model_defaults: + dialect: postgres + """ + ) + + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + + from sqlmesh.core.config.connection import GCPPostgresConnectionConfig + + conn = config.gateways["gcp_postgres"].connection + assert isinstance(conn, GCPPostgresConnectionConfig) + + assert len(conn.scopes) == 2 + assert conn.scopes[0] == "https://www.googleapis.com/auth/cloud-platform" + assert conn.scopes[1] == "https://www.googleapis.com/auth/sqlservice.admin" + assert conn.ip_type == "private" + + +def test_gateway_model_defaults(tmp_path): + global_defaults = ModelDefaultsConfig( + dialect="snowflake", owner="foo", optimize_query=True, enabled=True, cron="@daily" + ) + gateway_defaults = ModelDefaultsConfig(dialect="duckdb", owner="baz", optimize_query=False) + + config = Config( + gateways={ + "duckdb": GatewayConfig( + connection=DuckDBConnectionConfig(database="db.db"), + model_defaults=gateway_defaults, + ) + }, + model_defaults=global_defaults, + default_gateway="duckdb", + ) + + ctx = Context(paths=tmp_path, config=config, gateway="duckdb") + + expected = ModelDefaultsConfig( + dialect="duckdb", 
owner="baz", optimize_query=False, enabled=True, cron="@daily" + ) + + assert ctx.config.model_defaults == expected + + +def test_redshift_merge_flag(tmp_path, mocker: MockerFixture): + config_path = tmp_path / "config_redshift_merge.yaml" + with open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """ +gateways: + redshift: + connection: + type: redshift + user: user + password: '1234' + host: host + database: db + enable_merge: true + default: + connection: + type: redshift + user: user + password: '1234' + host: host + database: db + +default_gateway: redshift + +model_defaults: + dialect: redshift + """ + ) + + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + redshift_connection = config.get_connection("redshift") + assert isinstance(redshift_connection, RedshiftConnectionConfig) + assert redshift_connection.enable_merge + adapter = redshift_connection.create_engine_adapter() + assert isinstance(adapter, RedshiftEngineAdapter) + assert adapter.enable_merge + + adapter_2 = config.get_connection("default").create_engine_adapter() + assert isinstance(adapter_2, RedshiftEngineAdapter) + assert not adapter_2.enable_merge + + +def test_environment_statements_config(tmp_path): + config_path = tmp_path / "config_before_after_all.yaml" + with open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """ + gateways: + postgres: + connection: + type: postgres + database: db + user: postgres + password: postgres + host: localhost + port: 5432 + + default_gateway: postgres + + before_all: + - CREATE TABLE IF NOT EXISTS custom_analytics (physical_table VARCHAR, evaluation_time VARCHAR); + after_all: + - "@grant_schema_privileges()" + - "GRANT REFERENCES ON FUTURE VIEWS IN DATABASE db TO ROLE admin_role;" + + model_defaults: + dialect: postgres + """ + ) + + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + + assert config.before_all == [ + "CREATE TABLE IF NOT EXISTS custom_analytics (physical_table 
VARCHAR, evaluation_time VARCHAR);" + ] + assert config.after_all == [ + "@grant_schema_privileges()", + "GRANT REFERENCES ON FUTURE VIEWS IN DATABASE db TO ROLE admin_role;", + ] + + +# https://github.com/SQLMesh/sqlmesh/pull/4049 +def test_pydantic_import_error() -> None: + class TestConfig(DuckDBConnectionConfig): + pass + + TestConfig() + + +def test_config_subclassing() -> None: + class ConfigSubclass(Config): ... + + ConfigSubclass() + + +def test_config_complex_types_supplied_as_json_strings_from_env(tmp_path: Path) -> None: + config_path = tmp_path / "config.yaml" + config_path.write_text(""" + gateways: + bigquery: + connection: + type: bigquery + project: unit-test + + default_gateway: bigquery + + model_defaults: + dialect: bigquery +""") + with mock.patch.dict( + os.environ, + { + "SQLMESH__GATEWAYS__BIGQUERY__CONNECTION__SCOPES": ' ["a","b","c"]', # note: leading whitespace is deliberate + "SQLMESH__GATEWAYS__BIGQUERY__CONNECTION__KEYFILE_JSON": '{ "foo": "bar" }', + }, + ): + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + + conn = config.gateways["bigquery"].connection + assert isinstance(conn, BigQueryConnectionConfig) + + assert conn.project == "unit-test" + assert conn.scopes == ("a", "b", "c") + assert conn.keyfile_json == {"foo": "bar"} + + +def test_config_user_macro_function(tmp_path: Path) -> None: + config_path = tmp_path / "config.yaml" + config_path.write_text(""" + gateways: + bigquery: + connection: + type: bigquery + project: unit-test + + default_gateway: bigquery + + model_defaults: + dialect: bigquery + + default_target_environment: dev_{{ user() }} +""") + + with mock.patch("getpass.getuser", return_value="test_user"): + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + + assert config.default_target_environment == "dev_test_user" + + +def test_environment_suffix_target_catalog(tmp_path: Path) -> None: + config_path = tmp_path / "config.yaml" + config_path.write_text(""" 
+ gateways: + warehouse: + connection: + type: duckdb + + default_gateway: warehouse + + model_defaults: + dialect: duckdb + + environment_suffix_target: catalog +""") + + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + + assert config.environment_suffix_target == EnvironmentSuffixTarget.CATALOG + assert not config.environment_catalog_mapping + + config_path.write_text(""" + gateways: + warehouse: + connection: + type: duckdb + + default_gateway: warehouse + + model_defaults: + dialect: duckdb + + environment_suffix_target: catalog + + environment_catalog_mapping: + '.*': "foo" +""") + + with pytest.raises(ConfigError, match=r"mutually exclusive"): + config = load_config_from_paths( + Config, + project_paths=[config_path], + ) + + +def test_load_python_config_dot_env_vars(tmp_path_factory): + main_dir = tmp_path_factory.mktemp("python_config") + config_path = main_dir / "config.py" + with open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """from sqlmesh.core.config import Config, DuckDBConnectionConfig, GatewayConfig, ModelDefaultsConfig +config = Config(gateways={"duckdb_gateway": GatewayConfig(connection=DuckDBConnectionConfig())}, model_defaults=ModelDefaultsConfig(dialect='')) + """ + ) + + # The environment variable value from the dot env file should be set + # SQLMESH__ variables override config fields directly if they follow the naming structure + dot_path = main_dir / ".env" + with open(dot_path, "w", encoding="utf-8") as fd: + fd.write( + """SQLMESH__GATEWAYS__DUCKDB_GATEWAY__STATE_CONNECTION__TYPE="bigquery" +SQLMESH__GATEWAYS__DUCKDB_GATEWAY__STATE_CONNECTION__CHECK_IMPORT="false" +SQLMESH__DEFAULT_GATEWAY="duckdb_gateway" + """ + ) + + # Use mock.patch.dict to isolate environment variables between the tests + with mock.patch.dict(os.environ, {}, clear=True): + configs = load_configs( + "config", + Config, + paths=[main_dir], + ) + + assert next(iter(configs.values())) == Config( + gateways={ + 
"duckdb_gateway": GatewayConfig( + connection=DuckDBConnectionConfig(), + state_connection=BigQueryConnectionConfig(check_import=False), + ), + }, + model_defaults=ModelDefaultsConfig(dialect=""), + default_gateway="duckdb_gateway", + ) + + +def test_load_yaml_config_dot_env_vars(tmp_path_factory): + main_dir = tmp_path_factory.mktemp("yaml_config") + config_path = main_dir / "config.yaml" + with open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """gateways: + duckdb_gateway: + connection: + type: duckdb + catalogs: + local: local.db + cloud_sales: {{ env_var('S3_BUCKET') }} + extensions: + - name: httpfs + secrets: + - type: "s3" + key_id: {{ env_var('S3_KEY') }} + secret: {{ env_var('S3_SECRET') }} +model_defaults: + dialect: "" +""" + ) + + # This test checks both using SQLMESH__ prefixed environment variables with underscores + # and setting a regular environment variable for use with env_var(). + dot_path = main_dir / ".env" + with open(dot_path, "w", encoding="utf-8") as fd: + fd.write( + """S3_BUCKET="s3://metrics_bucket/sales.db" +S3_KEY="S3_KEY_ID" +S3_SECRET="XXX_S3_SECRET_XXX" +SQLMESH__DEFAULT_GATEWAY="duckdb_gateway" +SQLMESH__MODEL_DEFAULTS__DIALECT="athena" +""" + ) + + # Use mock.patch.dict to isolate environment variables between the tests + with mock.patch.dict(os.environ, {}, clear=True): + configs = load_configs( + "config", + Config, + paths=[main_dir], + ) + + assert next(iter(configs.values())) == Config( + gateways={ + "duckdb_gateway": GatewayConfig( + connection=DuckDBConnectionConfig( + catalogs={ + "local": "local.db", + "cloud_sales": "s3://metrics_bucket/sales.db", + }, + extensions=[{"name": "httpfs"}], + secrets=[{"type": "s3", "key_id": "S3_KEY_ID", "secret": "XXX_S3_SECRET_XXX"}], + ), + ), + }, + default_gateway="duckdb_gateway", + model_defaults=ModelDefaultsConfig(dialect="athena"), + ) + + +def test_load_config_dotenv_directory_not_loaded(tmp_path_factory): + main_dir = tmp_path_factory.mktemp("config_with_env_dir") 
+ config_path = main_dir / "config.yaml" + with open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """gateways: + test_gateway: + connection: + type: duckdb + database: test.db +model_defaults: + dialect: duckdb +""" + ) + + # Create a .env directory instead of a file to simulate a Python virtual environment + env_dir = main_dir / ".env" + env_dir.mkdir() + (env_dir / "pyvenv.cfg").touch() + + # Also create a regular .env file in another project directory + other_dir = tmp_path_factory.mktemp("config_with_env_file") + other_config_path = other_dir / "config.yaml" + with open(other_config_path, "w", encoding="utf-8") as fd: + fd.write( + """gateways: + test_gateway: + connection: + type: duckdb + database: test.db +model_defaults: + dialect: duckdb +""" + ) + + env_file = other_dir / ".env" + with open(env_file, "w", encoding="utf-8") as fd: + fd.write('TEST_ENV_VAR="from_dotenv_file"') + + # Test that the .env directory doesn't cause an error and is skipped + with mock.patch.dict(os.environ, {}, clear=True): + load_configs( + "config", + Config, + paths=[main_dir], + ) + # Should succeed without loading any env vars from the directory + assert "TEST_ENV_VAR" not in os.environ + + # Test that a real .env file is still loaded properly + with mock.patch.dict(os.environ, {}, clear=True): + load_configs( + "config", + Config, + paths=[other_dir], + ) + # The env var should be loaded from the file + assert os.environ.get("TEST_ENV_VAR") == "from_dotenv_file" + + +def test_load_yaml_config_custom_dotenv_path(tmp_path_factory): + main_dir = tmp_path_factory.mktemp("yaml_config_2") + config_path = main_dir / "config.yaml" + with open(config_path, "w", encoding="utf-8") as fd: + fd.write( + """gateways: + test_gateway: + connection: + type: duckdb + database: {{ env_var('DB_NAME') }} +""" + ) + + # Create a custom dot env file in a different location + custom_env_dir = tmp_path_factory.mktemp("custom_env") + custom_env_path = custom_env_dir / ".my_env" + with 
open(custom_env_path, "w", encoding="utf-8") as fd: + fd.write( + """DB_NAME="custom_database.db" +SQLMESH__DEFAULT_GATEWAY="test_gateway" +SQLMESH__MODEL_DEFAULTS__DIALECT="postgres" +""" + ) + + # Test that without custom dotenv path, env vars are not loaded + with mock.patch.dict(os.environ, {}, clear=True): + with pytest.raises( + ConfigError, match=r"Default model SQL dialect is a required configuratio*" + ): + load_configs( + "config", + Config, + paths=[main_dir], + ) + + # Test that with custom dotenv path, env vars are loaded correctly + with mock.patch.dict(os.environ, {}, clear=True): + configs = load_configs( + "config", + Config, + paths=[main_dir], + dotenv_path=custom_env_path, + ) + + assert next(iter(configs.values())) == Config( + gateways={ + "test_gateway": GatewayConfig( + connection=DuckDBConnectionConfig( + database="custom_database.db", + ), + ), + }, + default_gateway="test_gateway", + model_defaults=ModelDefaultsConfig(dialect="postgres"), + ) + + +@pytest.mark.parametrize( + "convention_str, expected", + [ + (None, TableNamingConvention.SCHEMA_AND_TABLE), + ("schema_and_table", TableNamingConvention.SCHEMA_AND_TABLE), + ("table_only", TableNamingConvention.TABLE_ONLY), + ("hash_md5", TableNamingConvention.HASH_MD5), + ], +) +def test_physical_table_naming_convention( + convention_str: t.Optional[str], expected: t.Optional[TableNamingConvention], tmp_path: Path +): + config_part = f"physical_table_naming_convention: {convention_str}" if convention_str else "" + (tmp_path / "config.yaml").write_text(f""" +gateways: + test_gateway: + connection: + type: duckdb +model_defaults: + dialect: duckdb +{config_part} + """) + + config = load_config_from_paths(Config, project_paths=[tmp_path / "config.yaml"]) + assert config.physical_table_naming_convention == expected + + +def test_load_configs_includes_sqlmesh_yaml(tmp_path: Path): + for extension in ("yaml", "yml"): + config_file = tmp_path / f"sqlmesh.{extension}" + config_file.write_text(""" 
+model_defaults: + start: '2023-04-05' + dialect: bigquery""") + + configs = load_configs(config=None, config_type=Config, paths=[tmp_path]) + assert len(configs) == 1 + + config: Config = list(configs.values())[0] + + assert config.model_defaults.start == "2023-04-05" + assert config.model_defaults.dialect == "bigquery" + + config_file.unlink() + + +def test_load_configs_without_main_connection(tmp_path: Path): + # this is for DBT projects where the main connection is defined in profiles.yml + # but we also need to be able to specify the sqlmesh state connection without editing any DBT files + # and without also duplicating the main connection + config_file = tmp_path / "sqlmesh.yaml" + with config_file.open("w") as f: + yaml.dump( + { + "gateways": {"": {"state_connection": {"type": "duckdb", "database": "state.db"}}}, + "model_defaults": {"dialect": "duckdb", "start": "2020-01-01"}, + }, + f, + ) + + configs = list(load_configs(config=None, config_type=Config, paths=[tmp_path]).values()) + assert len(configs) == 1 + + config = configs[0] + state_connection_config = config.get_state_connection() + assert isinstance(state_connection_config, DuckDBConnectionConfig) + assert state_connection_config.database == "state.db" + + +def test_load_configs_in_dbt_project_without_config_py(tmp_path: Path): + # this is when someone either: + # - inits a dbt project for sqlmesh, which creates a sqlmesh.yaml file + # - uses the sqlmesh_dbt cli for the first time, which runs init if the config doesnt exist, which creates a config + # when in pure yaml mode, sqlmesh should be able to auto-detect the presence of DBT and select the DbtLoader instead + # of the main loader + (tmp_path / "dbt_project.yml").write_text(""" +name: jaffle_shop + """) + + (tmp_path / "profiles.yml").write_text(""" +jaffle_shop: + + target: dev + outputs: + dev: + type: duckdb + path: 'jaffle_shop.duckdb' + """) + + (tmp_path / "sqlmesh.yaml").write_text(""" +gateways: + dev: + state_connection: + type: 
duckdb + database: state.db +model_defaults: + start: '2020-01-01' +""") + + configs = list(load_configs(config=None, config_type=Config, paths=[tmp_path]).values()) + assert len(configs) == 1 + + config = configs[0] + assert config.loader == DbtLoader + + assert list(config.gateways) == ["dev"] + + # main connection + connection_config = config.get_connection() + assert connection_config + assert isinstance(connection_config, DuckDBConnectionConfig) + assert connection_config.database == "jaffle_shop.duckdb" # from dbt profiles.yml + + # state connection + state_connection_config = config.get_state_connection() + assert state_connection_config + assert isinstance(state_connection_config, DuckDBConnectionConfig) + assert state_connection_config.database == "state.db" # from sqlmesh.yaml + + # model_defaults + assert config.model_defaults.dialect == "duckdb" # from dbt profiles.yml + assert config.model_defaults.start == "2020-01-01" # from sqlmesh.yaml diff --git a/tests/core/test_connection_config.py b/tests/core/test_connection_config.py index 6fb158ceb6..dd979a2551 100644 --- a/tests/core/test_connection_config.py +++ b/tests/core/test_connection_config.py @@ -1,22 +1,34 @@ import base64 +import re import typing as t import pytest from _pytest.fixtures import FixtureRequest +from sqlglot import exp +from unittest.mock import patch, MagicMock from sqlmesh.core.config.connection import ( BigQueryConnectionConfig, + ClickhouseConnectionConfig, ConnectionConfig, + DatabricksConnectionConfig, DuckDBAttachOptions, + FabricConnectionConfig, DuckDBConnectionConfig, GCPPostgresConnectionConfig, + MotherDuckConnectionConfig, MySQLConnectionConfig, PostgresConnectionConfig, SnowflakeConnectionConfig, TrinoAuthenticationMethod, + AthenaConnectionConfig, + MSSQLConnectionConfig, _connection_config_validator, + _get_engine_import_validator, + INIT_DISPLAY_INFO_TO_TYPE, ) from sqlmesh.utils.errors import ConfigError +from sqlmesh.utils.pydantic import PydanticModel 
@pytest.fixture @@ -385,14 +397,248 @@ def test_trino(make_config): make_config(method="ldap", http_scheme="http", **required_kwargs) +def test_trino_schema_location_mapping(make_config): + required_kwargs = dict( + type="trino", + user="user", + host="host", + catalog="catalog", + ) + + with pytest.raises( + ConfigError, match=r".*needs to include the '@\{schema_name\}' placeholder.*" + ): + make_config(**required_kwargs, schema_location_mapping={".*": "s3://foo"}) + + config = make_config( + **required_kwargs, + schema_location_mapping={ + "^utils$": "s3://utils-bucket/@{schema_name}", + "^staging.*$": "s3://bucket/@{schema_name}_dev", + "^sqlmesh.*$": "s3://sqlmesh-internal/dev/@{schema_name}", + }, + ) + + assert config.schema_location_mapping is not None + assert len(config.schema_location_mapping) == 3 + + assert all((isinstance(k, re.Pattern) for k in config.schema_location_mapping)) + assert all((isinstance(v, str) for v in config.schema_location_mapping.values())) + + +def test_trino_catalog_type_override(make_config): + required_kwargs = dict( + type="trino", + user="user", + host="host", + catalog="catalog", + ) + + config = make_config( + **required_kwargs, + catalog_type_overrides={"my_catalog": "iceberg"}, + ) + + assert config.catalog_type_overrides is not None + assert len(config.catalog_type_overrides) == 1 + + assert config.catalog_type_overrides == {"my_catalog": "iceberg"} + + +def test_trino_timestamp_mapping(make_config): + required_kwargs = dict( + type="trino", + user="user", + host="host", + catalog="catalog", + ) + + # Test config without timestamp_mapping + config = make_config(**required_kwargs) + assert config.timestamp_mapping is None + + # Test config with timestamp_mapping + config = make_config( + **required_kwargs, + timestamp_mapping={ + "TIMESTAMP": "TIMESTAMP(6)", + "TIMESTAMP(3)": "TIMESTAMP WITH TIME ZONE", + }, + ) + + assert config.timestamp_mapping is not None + assert 
config.timestamp_mapping[exp.DataType.build("TIMESTAMP")] == exp.DataType.build( + "TIMESTAMP(6)" + ) + + # Test with invalid source type + with pytest.raises(ConfigError) as exc_info: + make_config( + **required_kwargs, + timestamp_mapping={ + "INVALID_TYPE": "TIMESTAMP", + }, + ) + assert "Invalid SQL type string" in str(exc_info.value) + assert "INVALID_TYPE" in str(exc_info.value) + + # Test with invalid target type (not a valid SQL type) + with pytest.raises(ConfigError) as exc_info: + make_config( + **required_kwargs, + timestamp_mapping={ + "TIMESTAMP": "INVALID_TARGET_TYPE", + }, + ) + assert "Invalid SQL type string" in str(exc_info.value) + assert "INVALID_TARGET_TYPE" in str(exc_info.value) + + # Test with empty mapping + config = make_config( + **required_kwargs, + timestamp_mapping={}, + ) + assert config.timestamp_mapping is not None + assert config.timestamp_mapping == {} + + def test_duckdb(make_config): config = make_config( type="duckdb", database="test", connector_config={"foo": "bar"}, + secrets=[ + { + "type": "s3", + "region": "aws_region", + "key_id": "aws_access_key", + "secret": "aws_secret", + } + ], + filesystems=[ + { + "protocol": "abfs", + "storage_options": { + "account_name": "onelake", + "account_host": "onelake.blob.fabric.microsoft.com", + "anon": False, + }, + } + ], ) + assert config.connector_config + assert config.secrets + assert config.filesystems assert isinstance(config, DuckDBConnectionConfig) - assert config.is_recommended_for_state_sync is True + assert not config.is_recommended_for_state_sync + + +@patch("duckdb.connect") +def test_duckdb_multiple_secrets(mock_connect, make_config): + """Test that multiple secrets are correctly converted to CREATE SECRET SQL statements.""" + mock_cursor = MagicMock() + mock_connection = MagicMock() + mock_connection.cursor.return_value = mock_cursor + mock_connection.execute = mock_cursor.execute + mock_connect.return_value = mock_connection + + # Create config with 2 secrets + config 
= make_config( + type="duckdb", + secrets=[ + { + "type": "s3", + "region": "us-east-1", + "key_id": "my_aws_key", + "secret": "my_aws_secret", + }, + { + "type": "azure", + "account_name": "myaccount", + "account_key": "myaccountkey", + }, + ], + ) + + assert isinstance(config, DuckDBConnectionConfig) + assert len(config.secrets) == 2 + + # Create cursor which triggers _cursor_init + cursor = config.create_engine_adapter().cursor + + execute_calls = [call[0][0] for call in mock_cursor.execute.call_args_list] + create_secret_calls = [ + call for call in execute_calls if call.startswith("CREATE OR REPLACE SECRET") + ] + + # Should have exactly 2 CREATE SECRET calls + assert len(create_secret_calls) == 2 + + # Verify the SQL for the first secret (S3) + assert ( + create_secret_calls[0] + == "CREATE OR REPLACE SECRET (type 's3', region 'us-east-1', key_id 'my_aws_key', secret 'my_aws_secret');" + ) + + # Verify the SQL for the second secret (Azure) + assert ( + create_secret_calls[1] + == "CREATE OR REPLACE SECRET (type 'azure', account_name 'myaccount', account_key 'myaccountkey');" + ) + + +@patch("duckdb.connect") +def test_duckdb_named_secrets(mock_connect, make_config): + """Test that named secrets are correctly converted to CREATE SECRET SQL statements.""" + mock_cursor = MagicMock() + mock_connection = MagicMock() + mock_connection.cursor.return_value = mock_cursor + mock_connection.execute = mock_cursor.execute + mock_connect.return_value = mock_connection + + # Create config with named secrets using dictionary format + config = make_config( + type="duckdb", + secrets={ + "my_s3_secret": { + "type": "s3", + "region": "us-east-1", + "key_id": "my_aws_key", + "secret": "my_aws_secret", + }, + "my_azure_secret": { + "type": "azure", + "account_name": "myaccount", + "account_key": "myaccountkey", + }, + }, + ) + + assert isinstance(config, DuckDBConnectionConfig) + assert len(config.secrets) == 2 + + # Create cursor which triggers _cursor_init + cursor = 
config.create_engine_adapter().cursor + + execute_calls = [call[0][0] for call in mock_cursor.execute.call_args_list] + create_secret_calls = [ + call for call in execute_calls if call.startswith("CREATE OR REPLACE SECRET") + ] + + # Should have exactly 2 CREATE SECRET calls + assert len(create_secret_calls) == 2 + + # Verify the SQL for the first secret (S3) includes the secret name + assert ( + create_secret_calls[0] + == "CREATE OR REPLACE SECRET my_s3_secret (type 's3', region 'us-east-1', key_id 'my_aws_key', secret 'my_aws_secret');" + ) + + # Verify the SQL for the second secret (Azure) includes the secret name + assert ( + create_secret_calls[1] + == "CREATE OR REPLACE SECRET my_azure_secret (type 'azure', account_name 'myaccount', account_key 'myaccountkey');" + ) @pytest.mark.parametrize( @@ -553,7 +799,54 @@ def test_duckdb_attach_catalog(make_config): assert config.catalogs.get("test2").read_only is False assert config.catalogs.get("test2").path == "test2.duckdb" - assert config.is_recommended_for_state_sync is True + assert not config.is_recommended_for_state_sync + + +def test_duckdb_attach_ducklake_catalog(make_config): + config = make_config( + type="duckdb", + catalogs={ + "ducklake": DuckDBAttachOptions( + type="ducklake", + path="catalog.ducklake", + data_path="/tmp/ducklake_data", + encrypted=True, + data_inlining_row_limit=10, + ), + }, + ) + assert isinstance(config, DuckDBConnectionConfig) + ducklake_catalog = config.catalogs.get("ducklake") + assert ducklake_catalog is not None + assert ducklake_catalog.type == "ducklake" + assert ducklake_catalog.path == "catalog.ducklake" + assert ducklake_catalog.data_path == "/tmp/ducklake_data" + assert ducklake_catalog.encrypted is True + assert ducklake_catalog.data_inlining_row_limit == 10 + # Check that the generated SQL includes DATA_PATH + generated_sql = ducklake_catalog.to_sql("ducklake") + assert "DATA_PATH '/tmp/ducklake_data'" in generated_sql + assert "ENCRYPTED" in generated_sql + assert 
"DATA_INLINING_ROW_LIMIT 10" in generated_sql + # Check that the ducklake: prefix is automatically added + assert "ATTACH IF NOT EXISTS 'ducklake:catalog.ducklake'" in generated_sql + + # Test that a path with existing ducklake: prefix is preserved + config_with_prefix = make_config( + type="duckdb", + catalogs={ + "ducklake": DuckDBAttachOptions( + type="ducklake", + path="ducklake:catalog.ducklake", + data_path="/tmp/ducklake_data", + ), + }, + ) + ducklake_catalog_with_prefix = config_with_prefix.catalogs.get("ducklake") + generated_sql_with_prefix = ducklake_catalog_with_prefix.to_sql("ducklake") + assert "ATTACH IF NOT EXISTS 'ducklake:catalog.ducklake'" in generated_sql_with_prefix + # Ensure we don't have double prefixes + assert "'ducklake:catalog.ducklake" in generated_sql_with_prefix def test_duckdb_attach_options(): @@ -563,12 +856,263 @@ def test_duckdb_attach_options(): assert ( options.to_sql(alias="db") - == "ATTACH 'dbname=postgres user=postgres host=127.0.0.1' AS db (TYPE POSTGRES, READ_ONLY)" + == "ATTACH IF NOT EXISTS 'dbname=postgres user=postgres host=127.0.0.1' AS db (TYPE POSTGRES, READ_ONLY)" ) options = DuckDBAttachOptions(type="duckdb", path="test.db", read_only=False) - assert options.to_sql(alias="db") == "ATTACH 'test.db' AS db" + assert options.to_sql(alias="db") == "ATTACH IF NOT EXISTS 'test.db' AS db" + + +def test_ducklake_attach_add_ducklake_prefix(): + # Test that ducklake: prefix is automatically added when missing + options = DuckDBAttachOptions(type="ducklake", path="catalog.ducklake") + assert ( + options.to_sql(alias="my_ducklake") + == "ATTACH IF NOT EXISTS 'ducklake:catalog.ducklake' AS my_ducklake" + ) + + # Test that ducklake: prefix is preserved when already present + options = DuckDBAttachOptions(type="ducklake", path="ducklake:catalog.ducklake") + assert ( + options.to_sql(alias="my_ducklake") + == "ATTACH IF NOT EXISTS 'ducklake:catalog.ducklake' AS my_ducklake" + ) + + +def test_ducklake_metadata_schema(): + # Test 
that metadata_schema parameter is included when specified + options = DuckDBAttachOptions( + type="ducklake", path="catalog.ducklake", metadata_schema="custom_schema" + ) + assert ( + options.to_sql(alias="my_ducklake") + == "ATTACH IF NOT EXISTS 'ducklake:catalog.ducklake' AS my_ducklake (METADATA_SCHEMA 'custom_schema')" + ) + + # Test that metadata_schema is not included when not specified (default behavior) + options = DuckDBAttachOptions(type="ducklake", path="catalog.ducklake") + assert ( + options.to_sql(alias="my_ducklake") + == "ATTACH IF NOT EXISTS 'ducklake:catalog.ducklake' AS my_ducklake" + ) + + # Test metadata_schema with other ducklake options + options = DuckDBAttachOptions( + type="ducklake", + path="catalog.ducklake", + data_path="/path/to/data", + encrypted=True, + metadata_schema="workspace_schema", + ) + assert ( + options.to_sql(alias="my_ducklake") + == "ATTACH IF NOT EXISTS 'ducklake:catalog.ducklake' AS my_ducklake (DATA_PATH '/path/to/data', ENCRYPTED, METADATA_SCHEMA 'workspace_schema')" + ) + + +def test_duckdb_config_json_strings(make_config): + config = make_config( + type="duckdb", + extensions='["foo","bar"]', + catalogs="""{ + "test1": "test1.duckdb", + "test2": { + "type": "duckdb", + "path": "test2.duckdb" + } + }""", + ) + assert isinstance(config, DuckDBConnectionConfig) + + assert config.extensions == ["foo", "bar"] + + assert config.get_catalog() == "test1" + assert config.catalogs.get("test1") == "test1.duckdb" + assert config.catalogs.get("test2").path == "test2.duckdb" + + +def test_motherduck_attach_catalog(make_config): + config = make_config( + type="motherduck", + catalogs={ + "test1": "md:test1", + "test2": DuckDBAttachOptions( + type="motherduck", + path="md:test2", + ), + }, + ) + assert isinstance(config, MotherDuckConnectionConfig) + assert config.get_catalog() == "test1" + + assert config.catalogs.get("test2").read_only is False + assert config.catalogs.get("test2").path == "md:test2" + assert not 
config.is_recommended_for_state_sync + + +def test_motherduck_attach_options(): + options = DuckDBAttachOptions( + type="postgres", path="dbname=postgres user=postgres host=127.0.0.1", read_only=True + ) + + assert ( + options.to_sql(alias="db") + == "ATTACH IF NOT EXISTS 'dbname=postgres user=postgres host=127.0.0.1' AS db (TYPE POSTGRES, READ_ONLY)" + ) + + options = DuckDBAttachOptions(type="motherduck", path="md:test.db", read_only=False) + + # Here the alias should be ignored compared to duckdb + assert options.to_sql(alias="db") == "ATTACH IF NOT EXISTS 'md:test.db'" + + +def test_duckdb_multithreaded_connection_factory(make_config): + from sqlmesh.core.engine_adapter import DuckDBEngineAdapter + from sqlmesh.utils.connection_pool import ThreadLocalSharedConnectionPool + from threading import Thread + + config = make_config(type="duckdb") + + # defaults to 1, no issue + assert config.concurrent_tasks == 1 + + # check that the connection factory always returns the same connection in multithreaded mode + # this sounds counter-intuitive but that's what DuckDB recommends here: https://duckdb.org/docs/guides/python/multiple_threads.html + config = make_config(type="duckdb", concurrent_tasks=8) + adapter = config.create_engine_adapter() + assert isinstance(adapter, DuckDBEngineAdapter) + assert isinstance(adapter._connection_pool, ThreadLocalSharedConnectionPool) + + threads = [] + connection_objects = [] + + def _thread_connection(): + connection_objects.append(adapter.connection) + + for _ in range(8): + threads.append(Thread(target=_thread_connection)) + + for thread in threads: + thread.start() + + for thread in threads: + thread.join() + + assert len(connection_objects) == 8 + assert len(set(connection_objects)) == 1 # they should all be the same object + + # test that recycling the pool means we dont end up with unusable connections (eg check we havent cached a closed connection) + assert adapter.fetchone("select 1") == (1,) + adapter.recycle() + assert 
adapter.fetchone("select 1") == (1,) + + +def test_motherduck_token_mask(make_config): + config_1 = make_config( + type="motherduck", + token="short", + database="whodunnit", + ) + config_2 = make_config( + type="motherduck", + token="longtoken123456789", + database="whodunnit", + ) + config_3 = make_config( + type="motherduck", + token="secret1235", + catalogs={ + "test1": DuckDBAttachOptions( + type="motherduck", + path="md:whodunnit", + ), + }, + ) + + assert isinstance(config_1, MotherDuckConnectionConfig) + assert isinstance(config_2, MotherDuckConnectionConfig) + assert isinstance(config_3, MotherDuckConnectionConfig) + + # motherduck format + assert config_1._mask_sensitive_data(config_1.database) == "whodunnit" + assert ( + config_1._mask_sensitive_data(f"md:{config_1.database}?motherduck_token={config_1.token}") + == "md:whodunnit?motherduck_token=********" + ) + assert ( + config_1._mask_sensitive_data( + f"md:{config_1.database}?attach_mode=single&motherduck_token={config_1.token}" + ) + == "md:whodunnit?attach_mode=single&motherduck_token=********" + ) + assert ( + config_2._mask_sensitive_data(f"md:{config_2.database}?motherduck_token={config_2.token}") + == "md:whodunnit?motherduck_token=********" + ) + assert ( + config_3._mask_sensitive_data(f"md:?motherduck_token={config_3.token}") + == "md:?motherduck_token=********" + ) + assert ( + config_1._mask_sensitive_data("?motherduck_token=secret1235") + == "?motherduck_token=********" + ) + assert ( + config_1._mask_sensitive_data("md:whodunnit?motherduck_token=short") + == "md:whodunnit?motherduck_token=********" + ) + assert ( + config_1._mask_sensitive_data("md:whodunnit?motherduck_token=longtoken123456789") + == "md:whodunnit?motherduck_token=********" + ) + assert ( + config_1._mask_sensitive_data("md:whodunnit?motherduck_token=") + == "md:whodunnit?motherduck_token=" + ) + assert config_1._mask_sensitive_data(":memory:") == ":memory:" + + # postgres format + assert ( + 
config_1._mask_sensitive_data( + "postgres:dbname=mydb user=myuser password=secret123 host=localhost" + ) + == "postgres:dbname=mydb user=myuser password=******** host=localhost" + ) + + assert ( + config_1._mask_sensitive_data( + "dbname=postgres user=postgres password=pg_secret host=127.0.0.1" + ) + == "dbname=postgres user=postgres password=******** host=127.0.0.1" + ) + assert ( + config_1._mask_sensitive_data( + "postgres:dbname=testdb password=verylongpassword123 user=admin" + ) + == "postgres:dbname=testdb password=******** user=admin" + ) + assert config_1._mask_sensitive_data("postgres:password=short") == "postgres:password=********" + assert ( + config_1._mask_sensitive_data("postgres:host=localhost password=p@ssw0rd! dbname=db") + == "postgres:host=localhost password=******** dbname=db" + ) + + assert ( + config_1._mask_sensitive_data("postgres:dbname=mydb user=myuser host=localhost") + == "postgres:dbname=mydb user=myuser host=localhost" + ) + + assert ( + config_1._mask_sensitive_data("md:db?motherduck_token=token123 postgres:password=secret") + == "md:db?motherduck_token=******** postgres:password=********" + ) + + # MySQL format + assert ( + config_1._mask_sensitive_data("host=localhost user=root password=mysql123 database=mydb") + == "host=localhost user=root password=******** database=mydb" + ) def test_bigquery(make_config): @@ -576,14 +1120,38 @@ def test_bigquery(make_config): type="bigquery", project="project", execution_project="execution_project", + quota_project="quota_project", + check_import=False, ) assert isinstance(config, BigQueryConnectionConfig) assert config.project == "project" assert config.execution_project == "execution_project" + assert config.quota_project == "quota_project" assert config.get_catalog() == "project" assert config.is_recommended_for_state_sync is False + with pytest.raises(ConfigError, match="you must also specify the `project` field"): + make_config(type="bigquery", execution_project="execution_project", 
check_import=False) + + with pytest.raises(ConfigError, match="you must also specify the `project` field"): + make_config(type="bigquery", quota_project="quota_project", check_import=False) + + +def test_bigquery_config_json_string(make_config): + config = make_config( + type="bigquery", + project="project", + # these can be present as strings if they came from env vars + scopes='["a","b","c"]', + keyfile_json='{"foo":"bar"}', + ) + + assert isinstance(config, BigQueryConnectionConfig) + + assert config.scopes == ("a", "b", "c") + assert config.keyfile_json == {"foo": "bar"} + def test_postgres(make_config): config = make_config( @@ -605,9 +1173,21 @@ def test_gcp_postgres(make_config): user="user", password="password", db="database", + check_import=False, ) assert isinstance(config, GCPPostgresConnectionConfig) assert config.is_recommended_for_state_sync is True + assert config.ip_type == "public" + config = make_config( + type="gcp_postgres", + instance_connection_string="something", + user="user", + password="password", + db="database", + ip_type="private", + check_import=False, + ) + assert config.ip_type == "private" def test_mysql(make_config): @@ -616,6 +1196,754 @@ def test_mysql(make_config): host="host", user="user", password="password", + check_import=False, ) assert isinstance(config, MySQLConnectionConfig) assert config.is_recommended_for_state_sync is True + + +def test_clickhouse(make_config): + from sqlmesh import __version__ + + config = make_config( + type="clickhouse", + host="localhost", + username="default", + password="default", + cluster="default", + use_compression=True, + connection_settings={"this_setting": "1"}, + server_host_name="server_host_name", + verify=True, + ca_cert="ca_cert", + client_cert="client_cert", + client_cert_key="client_cert_key", + https_proxy="https://proxy", + connection_pool_options={"pool_option": "value"}, + ) + assert isinstance(config, ClickhouseConnectionConfig) + assert config.cluster == "default" + assert 
config.use_compression + assert config._static_connection_kwargs["compress"] + assert config._static_connection_kwargs["client_name"] == f"SQLMesh/{__version__}" + assert config._static_connection_kwargs["this_setting"] == "1" + assert config.is_recommended_for_state_sync is False + assert config.is_forbidden_for_state_sync + + pool = config._connection_factory.keywords["pool_mgr"] + assert pool.connection_pool_kw["server_hostname"] == "server_host_name" + assert pool.connection_pool_kw["assert_hostname"] == "server_host_name" # because verify=True + assert pool.connection_pool_kw["ca_certs"] == "ca_cert" + assert pool.connection_pool_kw["cert_file"] == "client_cert" + assert pool.connection_pool_kw["key_file"] == "client_cert_key" + assert pool.connection_pool_kw["pool_option"] == "value" + + config2 = make_config( + type="clickhouse", + host="localhost", + username="default", + password="default", + compression_method="lz4", + ) + + assert config2.use_compression + assert config2._static_connection_kwargs["compress"] == "lz4" + + config3 = make_config( + type="clickhouse", + host="localhost", + username="default", + password="default", + use_compression=False, + compression_method="lz4", + ) + + assert not config3.use_compression + assert not config3._static_connection_kwargs["compress"] + + +def test_athena(make_config): + config = make_config(type="athena", work_group="primary") + assert isinstance(config, AthenaConnectionConfig) + + +def test_athena_catalog(make_config): + config = make_config(type="athena", work_group="primary", catalog_name="foo") + assert isinstance(config, AthenaConnectionConfig) + + assert config.catalog_name == "foo" + adapter = config.create_engine_adapter() + assert adapter.default_catalog == "foo" + + config = make_config(type="athena", work_group="primary") + assert isinstance(config, AthenaConnectionConfig) + assert config.catalog_name is None + adapter = config.create_engine_adapter() + assert adapter.default_catalog == 
"awsdatacatalog" + + +def test_athena_s3_staging_dir_or_workgroup(make_config): + with pytest.raises( + ConfigError, match=r"At least one of work_group or s3_staging_dir must be set" + ): + config = make_config(type="athena") + + config = make_config(type="athena", s3_staging_dir="s3://foo") + + assert isinstance(config, AthenaConnectionConfig) + assert config.work_group is None + assert config.s3_staging_dir == "s3://foo/" # validator appends trailing / + + config = make_config(type="athena", work_group="test") + + assert isinstance(config, AthenaConnectionConfig) + assert config.work_group == "test" + assert config.s3_staging_dir is None + + +def test_athena_s3_locations_valid(make_config): + with pytest.raises(ConfigError, match=r".*must be a s3:// URI.*"): + make_config( + type="athena", work_group="primary", s3_warehouse_location="hdfs://legacy/location" + ) + + with pytest.raises(ConfigError, match=r".*must be a s3:// URI.*"): + make_config(type="athena", s3_staging_dir="alskdjlskadgj") + + config = make_config( + type="athena", + s3_staging_dir="s3://bucket/query-results", + s3_warehouse_location="s3://bucket/prod/warehouse/", + ) + + assert isinstance(config, AthenaConnectionConfig) + assert config.s3_staging_dir == "s3://bucket/query-results/" + assert config.s3_warehouse_location == "s3://bucket/prod/warehouse/" + + config = make_config( + type="athena", work_group="primary", s3_staging_dir=None, s3_warehouse_location=None + ) + + assert isinstance(config, AthenaConnectionConfig) + assert config.s3_staging_dir is None + assert config.s3_warehouse_location is None + + +def test_databricks(make_config): + # Personal Access Token + oauth_pat_config = make_config( + type="databricks", + server_hostname="dbc-test.cloud.databricks.com", + http_path="sql/test/foo", + access_token="foo", + ) + assert isinstance(oauth_pat_config, DatabricksConnectionConfig) + assert oauth_pat_config.server_hostname == "dbc-test.cloud.databricks.com" + assert 
oauth_pat_config.http_path == "sql/test/foo" + assert oauth_pat_config.access_token == "foo" + assert oauth_pat_config.auth_type is None + assert oauth_pat_config.oauth_client_id is None + assert oauth_pat_config.oauth_client_secret is None + + # OAuth (M2M) + oauth_m2m_config = make_config( + type="databricks", + server_hostname="dbc-test.cloud.databricks.com", + http_path="sql/test/foo", + auth_type="databricks-oauth", + oauth_client_id="client-id", + oauth_client_secret="client-secret", + ) + assert isinstance(oauth_m2m_config, DatabricksConnectionConfig) + assert oauth_m2m_config.server_hostname == "dbc-test.cloud.databricks.com" + assert oauth_pat_config.http_path == "sql/test/foo" + assert oauth_m2m_config.access_token is None + assert oauth_m2m_config.auth_type == "databricks-oauth" + assert oauth_m2m_config.oauth_client_id == "client-id" + assert oauth_m2m_config.oauth_client_secret == "client-secret" + + # OAuth (U2M) + oauth_u2m_config = make_config( + type="databricks", + server_hostname="dbc-test.cloud.databricks.com", + http_path="sql/test/foo", + auth_type="databricks-oauth", + ) + assert isinstance(oauth_u2m_config, DatabricksConnectionConfig) + assert oauth_u2m_config.server_hostname == "dbc-test.cloud.databricks.com" + assert oauth_pat_config.http_path == "sql/test/foo" + assert oauth_u2m_config.access_token is None + assert oauth_u2m_config.auth_type == "databricks-oauth" + assert oauth_u2m_config.oauth_client_id is None + assert oauth_u2m_config.oauth_client_secret is None + + # auth_type must match the AuthType enum if specified + with pytest.raises(ConfigError, match=r".*nonexist does not match a valid option.*"): + make_config( + type="databricks", server_hostname="dbc-test.cloud.databricks.com", auth_type="nonexist" + ) + + # if client_secret is specified, client_id must also be specified + with pytest.raises(ConfigError, match=r"`oauth_client_id` is required.*"): + make_config( + type="databricks", + 
server_hostname="dbc-test.cloud.databricks.com", + auth_type="databricks-oauth", + oauth_client_secret="client-secret", + ) + + # http_path is still required when auth_type is specified + with pytest.raises(ConfigError, match=r"`http_path` is still required.*"): + make_config( + type="databricks", + server_hostname="dbc-test.cloud.databricks.com", + auth_type="databricks-oauth", + ) + + +def test_engine_import_validator(): + with pytest.raises( + ConfigError, + match=re.escape( + "Failed to import the 'bigquery' engine library. This may be due to a missing " + "or incompatible installation. Please ensure the required dependency is installed by " + 'running: `pip install "sqlmesh[bigquery]"`. For more details, check the logs ' + "in the 'logs/' folder, or rerun the command with the '--debug' flag." + ), + ): + + class TestConfigA(PydanticModel): + _engine_import_validator = _get_engine_import_validator("missing", "bigquery") + + TestConfigA() + + with pytest.raises( + ConfigError, + match=re.escape( + "Failed to import the 'bigquery' engine library. This may be due to a missing " + "or incompatible installation. Please ensure the required dependency is installed by " + 'running: `pip install "sqlmesh[bigquery_extra]"`. For more details, check the logs ' + "in the 'logs/' folder, or rerun the command with the '--debug' flag." + ), + ): + + class TestConfigB(PydanticModel): + _engine_import_validator = _get_engine_import_validator( + "missing", "bigquery", "bigquery_extra" + ) + + TestConfigB() + + class TestConfigC(PydanticModel): + _engine_import_validator = _get_engine_import_validator("sqlmesh", "bigquery") + + TestConfigC() + + +def test_engine_display_order(): + """ + Each engine's ConnectionConfig contains a display_order integer class var that is used to order the + interactive `sqlmesh init` engine choices. + + This test ensures that those integers begin with 1, are unique, and are sequential. 
+ """ + display_numbers = [ + info[0] for info in sorted(INIT_DISPLAY_INFO_TO_TYPE.values(), key=lambda x: x[0]) + ] + assert display_numbers == list(range(1, len(display_numbers) + 1)) + + +def test_mssql_engine_import_validator(): + """Test that MSSQL import validator respects driver configuration.""" + + # Test PyODBC driver suggests mssql-odbc extra when import fails + with pytest.raises(ConfigError, match=r"pip install \"sqlmesh\[mssql-odbc\]\""): + with patch("importlib.import_module") as mock_import: + mock_import.side_effect = ImportError("No module named 'pyodbc'") + MSSQLConnectionConfig(host="localhost", driver="pyodbc") + + # Test PyMSSQL driver suggests mssql extra when import fails + with pytest.raises(ConfigError, match=r"pip install \"sqlmesh\[mssql\]\""): + with patch("importlib.import_module") as mock_import: + mock_import.side_effect = ImportError("No module named 'pymssql'") + MSSQLConnectionConfig(host="localhost", driver="pymssql") + + # Test default driver (pymssql) suggests mssql extra when import fails + with pytest.raises(ConfigError, match=r"pip install \"sqlmesh\[mssql\]\""): + with patch("importlib.import_module") as mock_import: + mock_import.side_effect = ImportError("No module named 'pymssql'") + MSSQLConnectionConfig(host="localhost") # No driver specified + + # Test successful import works without error + with patch("importlib.import_module") as mock_import: + mock_import.return_value = None + config = MSSQLConnectionConfig(host="localhost", driver="pyodbc") + assert config.driver == "pyodbc" + + +def test_mssql_connection_config_parameter_validation(make_config): + """Test MSSQL connection config parameter validation.""" + # Test default driver is pymssql + config = make_config(type="mssql", host="localhost", check_import=False) + assert isinstance(config, MSSQLConnectionConfig) + assert config.driver == "pymssql" + + # Test explicit pyodbc driver + config = make_config(type="mssql", host="localhost", driver="pyodbc", 
check_import=False) + assert isinstance(config, MSSQLConnectionConfig) + assert config.driver == "pyodbc" + + # Test explicit pymssql driver + config = make_config(type="mssql", host="localhost", driver="pymssql", check_import=False) + assert isinstance(config, MSSQLConnectionConfig) + assert config.driver == "pymssql" + + # Test pyodbc specific parameters + config = make_config( + type="mssql", + host="localhost", + driver="pyodbc", + driver_name="ODBC Driver 18 for SQL Server", + trust_server_certificate=True, + encrypt=False, + odbc_properties={"Authentication": "ActiveDirectoryServicePrincipal"}, + check_import=False, + ) + assert isinstance(config, MSSQLConnectionConfig) + assert config.driver_name == "ODBC Driver 18 for SQL Server" + assert config.trust_server_certificate is True + assert config.encrypt is False + assert config.odbc_properties == {"Authentication": "ActiveDirectoryServicePrincipal"} + + # Test pymssql specific parameters + config = make_config( + type="mssql", + host="localhost", + driver="pymssql", + tds_version="7.4", + conn_properties=["SET ANSI_NULLS ON"], + check_import=False, + ) + assert isinstance(config, MSSQLConnectionConfig) + assert config.tds_version == "7.4" + assert config.conn_properties == ["SET ANSI_NULLS ON"] + + +def test_mssql_connection_kwargs_keys(): + """Test _connection_kwargs_keys returns correct keys for each driver variant.""" + # Test pymssql driver keys + config = MSSQLConnectionConfig(host="localhost", driver="pymssql", check_import=False) + pymssql_keys = config._connection_kwargs_keys + expected_pymssql_keys = { + "password", + "user", + "database", + "host", + "timeout", + "login_timeout", + "charset", + "appname", + "port", + "tds_version", + "conn_properties", + "autocommit", + } + assert pymssql_keys == expected_pymssql_keys + + # Test pyodbc driver keys + config = MSSQLConnectionConfig(host="localhost", driver="pyodbc", check_import=False) + pyodbc_keys = config._connection_kwargs_keys + 
expected_pyodbc_keys = { + "password", + "user", + "database", + "host", + "timeout", + "login_timeout", + "charset", + "appname", + "port", + "autocommit", + "driver_name", + "trust_server_certificate", + "encrypt", + "odbc_properties", + } + assert pyodbc_keys == expected_pyodbc_keys + + # Verify pyodbc keys don't include pymssql-specific parameters + assert "tds_version" not in pyodbc_keys + assert "conn_properties" not in pyodbc_keys + + +def test_mssql_pyodbc_connection_string_generation(): + """Test pyodbc.connect gets invoked with the correct ODBC connection string.""" + with patch("pyodbc.connect") as mock_pyodbc_connect: + # Mock the return value to have the methods we need + mock_connection = mock_pyodbc_connect.return_value + + # Create a pyodbc config + config = MSSQLConnectionConfig( + host="testserver.database.windows.net", + port=1433, + database="testdb", + user="testuser", + password="testpass", + driver="pyodbc", + driver_name="ODBC Driver 18 for SQL Server", + trust_server_certificate=True, + encrypt=True, + login_timeout=30, + check_import=False, + ) + + # Get the connection factory with kwargs and call it + factory_with_kwargs = config._connection_factory_with_kwargs + connection = factory_with_kwargs() + + # Verify pyodbc.connect was called with the correct connection string + mock_pyodbc_connect.assert_called_once() + call_args = mock_pyodbc_connect.call_args + + # Check the connection string (first argument) + conn_str = call_args[0][0] + expected_parts = [ + "DRIVER={ODBC Driver 18 for SQL Server}", + "SERVER=testserver.database.windows.net,1433", + "DATABASE=testdb", + "Encrypt=YES", + "TrustServerCertificate=YES", + "Connection Timeout=30", + "UID=testuser", + "PWD=testpass", + ] + + for part in expected_parts: + assert part in conn_str + + # Check autocommit parameter + assert call_args[1]["autocommit"] is False + + +def test_mssql_pyodbc_connection_string_with_odbc_properties(): + """Test pyodbc connection string includes custom ODBC 
properties.""" + with patch("pyodbc.connect") as mock_pyodbc_connect: + # Create a pyodbc config with custom ODBC properties + config = MSSQLConnectionConfig( + host="testserver.database.windows.net", + database="testdb", + user="client-id", + password="client-secret", + driver="pyodbc", + odbc_properties={ + "Authentication": "ActiveDirectoryServicePrincipal", + "ClientCertificate": "/path/to/cert.pem", + "TrustServerCertificate": "NO", # This should be ignored since we set it explicitly + }, + trust_server_certificate=True, # This should take precedence + check_import=False, + ) + + # Get the connection factory with kwargs and call it + factory_with_kwargs = config._connection_factory_with_kwargs + connection = factory_with_kwargs() + + # Verify pyodbc.connect was called + mock_pyodbc_connect.assert_called_once() + conn_str = mock_pyodbc_connect.call_args[0][0] + + # Check that custom ODBC properties are included + assert "Authentication=ActiveDirectoryServicePrincipal" in conn_str + assert "ClientCertificate=/path/to/cert.pem" in conn_str + + # Verify that explicit trust_server_certificate takes precedence + assert "TrustServerCertificate=YES" in conn_str + + # Should not have the conflicting property from odbc_properties + assert conn_str.count("TrustServerCertificate") == 1 + + +def test_mssql_pyodbc_connection_string_minimal(): + """Test pyodbc connection string with minimal configuration.""" + with patch("pyodbc.connect") as mock_pyodbc_connect: + config = MSSQLConnectionConfig( + host="localhost", + driver="pyodbc", + autocommit=True, + check_import=False, + ) + + factory_with_kwargs = config._connection_factory_with_kwargs + connection = factory_with_kwargs() + + mock_pyodbc_connect.assert_called_once() + conn_str = mock_pyodbc_connect.call_args[0][0] + + # Check basic required parts + assert "DRIVER={ODBC Driver 18 for SQL Server}" in conn_str + assert "SERVER=localhost,1433" in conn_str + assert "Encrypt=YES" in conn_str # Default encrypt=True + assert 
"Connection Timeout=60" in conn_str # Default timeout + + # Check autocommit parameter + assert mock_pyodbc_connect.call_args[1]["autocommit"] is True + + +def test_mssql_pymssql_connection_factory(): + """Test pymssql connection factory returns correct function.""" + # Mock the import of pymssql at the module level + import sys + from unittest.mock import MagicMock + + # Create a mock pymssql module + mock_pymssql = MagicMock() + sys.modules["pymssql"] = mock_pymssql + + try: + config = MSSQLConnectionConfig( + host="localhost", + driver="pymssql", + check_import=False, + ) + + factory = config._connection_factory + + # Verify the factory returns pymssql.connect + assert factory is mock_pymssql.connect + finally: + # Clean up the mock module + if "pymssql" in sys.modules: + del sys.modules["pymssql"] + + +def test_mssql_pyodbc_connection_datetimeoffset_handling(): + """Test that the MSSQL pyodbc connection properly handles DATETIMEOFFSET conversion.""" + from datetime import datetime, timezone, timedelta + import struct + from unittest.mock import Mock, patch + + with patch("pyodbc.connect") as mock_pyodbc_connect: + # Track calls to add_output_converter + converter_calls = [] + + def mock_add_output_converter(sql_type, converter_func): + converter_calls.append((sql_type, converter_func)) + + # Create a mock connection that will be returned by pyodbc.connect + mock_connection = Mock() + mock_connection.add_output_converter = mock_add_output_converter + mock_pyodbc_connect.return_value = mock_connection + + config = MSSQLConnectionConfig( + host="localhost", + driver="pyodbc", # DATETIMEOFFSET handling is pyodbc-specific + check_import=False, + ) + + # Get the connection factory and call it + factory_with_kwargs = config._connection_factory_with_kwargs + connection = factory_with_kwargs() + + # Verify that add_output_converter was called for SQL type -155 (DATETIMEOFFSET) + assert len(converter_calls) == 1 + sql_type, converter_func = converter_calls[0] + assert 
sql_type == -155 + + # Test the converter function with actual DATETIMEOFFSET binary data + # Create a test DATETIMEOFFSET value: 2023-12-25 15:30:45.123456789 +05:30 + year, month, day = 2023, 12, 25 + hour, minute, second = 15, 30, 45 + nanoseconds = 123456789 + tz_hour_offset, tz_minute_offset = 5, 30 + + # Pack the binary data according to the DATETIMEOFFSET format + binary_data = struct.pack( + "<6hI2h", + year, + month, + day, + hour, + minute, + second, + nanoseconds, + tz_hour_offset, + tz_minute_offset, + ) + + # Convert using the registered converter + result = converter_func(binary_data) + + # Verify the result + expected_dt = datetime( + 2023, + 12, + 25, + 15, + 30, + 45, + 123456, # microseconds = nanoseconds // 1000 + timezone(timedelta(hours=5, minutes=30)), + ) + assert result == expected_dt + assert result.tzinfo == timezone(timedelta(hours=5, minutes=30)) + + +def test_mssql_pyodbc_connection_negative_timezone_offset(): + """Test DATETIMEOFFSET handling with negative timezone offset at connection level.""" + from datetime import datetime, timezone, timedelta + import struct + from unittest.mock import Mock, patch + + with patch("pyodbc.connect") as mock_pyodbc_connect: + converter_calls = [] + + def mock_add_output_converter(sql_type, converter_func): + converter_calls.append((sql_type, converter_func)) + + mock_connection = Mock() + mock_connection.add_output_converter = mock_add_output_converter + mock_pyodbc_connect.return_value = mock_connection + + config = MSSQLConnectionConfig( + host="localhost", + driver="pyodbc", # DATETIMEOFFSET handling is pyodbc-specific + check_import=False, + ) + + factory_with_kwargs = config._connection_factory_with_kwargs + connection = factory_with_kwargs() + + # Get the converter function + _, converter_func = converter_calls[0] + + # Test with negative timezone offset: 2023-01-01 12:00:00.0 -08:00 + year, month, day = 2023, 1, 1 + hour, minute, second = 12, 0, 0 + nanoseconds = 0 + tz_hour_offset, 
tz_minute_offset = -8, 0 + + binary_data = struct.pack( + "<6hI2h", + year, + month, + day, + hour, + minute, + second, + nanoseconds, + tz_hour_offset, + tz_minute_offset, + ) + + result = converter_func(binary_data) + + expected_dt = datetime(2023, 1, 1, 12, 0, 0, 0, timezone(timedelta(hours=-8, minutes=0))) + assert result == expected_dt + assert result.tzinfo == timezone(timedelta(hours=-8)) + + +def test_fabric_connection_config_defaults(make_config): + """Test Fabric connection config defaults to pyodbc and autocommit=True.""" + config = make_config( + type="fabric", + host="localhost", + workspace_id="test-workspace-id", + tenant_id="test-tenant-id", + check_import=False, + ) + assert isinstance(config, FabricConnectionConfig) + assert config.driver == "pyodbc" + assert config.autocommit is True + + # Ensure it creates the FabricEngineAdapter + from sqlmesh.core.engine_adapter.fabric import FabricEngineAdapter + + assert isinstance(config.create_engine_adapter(), FabricEngineAdapter) + + +def test_fabric_connection_config_parameter_validation(make_config): + """Test Fabric connection config parameter validation.""" + # Test that FabricConnectionConfig correctly handles pyodbc-specific parameters. 
+ config = make_config( + type="fabric", + host="localhost", + driver_name="ODBC Driver 18 for SQL Server", + trust_server_certificate=True, + encrypt=False, + odbc_properties={"Authentication": "ActiveDirectoryServicePrincipal"}, + workspace_id="test-workspace-id", + tenant_id="test-tenant-id", + check_import=False, + ) + assert isinstance(config, FabricConnectionConfig) + assert config.driver == "pyodbc" # Driver is fixed to pyodbc + assert config.driver_name == "ODBC Driver 18 for SQL Server" + assert config.trust_server_certificate is True + assert config.encrypt is False + assert config.odbc_properties == {"Authentication": "ActiveDirectoryServicePrincipal"} + + # Test that specifying a different driver for Fabric raises an error + with pytest.raises(ConfigError, match=r"Input should be 'pyodbc'"): + make_config(type="fabric", host="localhost", driver="pymssql", check_import=False) + + +def test_fabric_pyodbc_connection_string_generation(): + """Test that the Fabric pyodbc connection gets invoked with the correct ODBC connection string.""" + with patch("pyodbc.connect") as mock_pyodbc_connect: + # Create a Fabric config + config = FabricConnectionConfig( + host="testserver.datawarehouse.fabric.microsoft.com", + port=1433, + database="testdb", + user="testuser", + password="testpass", + driver_name="ODBC Driver 18 for SQL Server", + trust_server_certificate=True, + encrypt=True, + login_timeout=30, + workspace_id="test-workspace-id", + tenant_id="test-tenant-id", + check_import=False, + ) + + # Get the connection factory with kwargs and call it + factory_with_kwargs = config._connection_factory_with_kwargs + connection = factory_with_kwargs() + + # Verify pyodbc.connect was called with the correct connection string + mock_pyodbc_connect.assert_called_once() + call_args = mock_pyodbc_connect.call_args + + # Check the connection string (first argument) + conn_str = call_args[0][0] + expected_parts = [ + "DRIVER={ODBC Driver 18 for SQL Server}", + 
"SERVER=testserver.datawarehouse.fabric.microsoft.com,1433", + "DATABASE=testdb", + "Encrypt=YES", + "TrustServerCertificate=YES", + "Connection Timeout=30", + "UID=testuser", + "PWD=testpass", + ] + + for part in expected_parts: + assert part in conn_str + + # Check autocommit parameter, should default to True for Fabric + assert call_args[1]["autocommit"] is True + + +def test_schema_differ_overrides(make_config) -> None: + default_config = make_config(type="duckdb") + assert default_config.schema_differ_overrides is None + default_adapter = default_config.create_engine_adapter() + assert default_adapter._schema_differ_overrides is None + assert default_adapter.schema_differ.parameterized_type_defaults != {} + + override: t.Dict[str, t.Any] = {"parameterized_type_defaults": {}} + config = make_config(type="duckdb", schema_differ_overrides=override) + assert config.schema_differ_overrides == override + adapter = config.create_engine_adapter() + assert adapter._schema_differ_overrides == override + assert adapter.schema_differ.parameterized_type_defaults == {} diff --git a/tests/core/test_console.py b/tests/core/test_console.py new file mode 100644 index 0000000000..f899713235 --- /dev/null +++ b/tests/core/test_console.py @@ -0,0 +1,131 @@ +from sqlmesh.core.console import MarkdownConsole + + +def test_markdown_console_warning_block(): + console = MarkdownConsole( + alert_block_max_content_length=100, alert_block_collapsible_threshold=45 + ) + assert console.consume_captured_warnings() == "" + + # single warning, within threshold + console.log_warning("First warning") + assert console.consume_captured_warnings() == "> [!WARNING]\n>\n> First warning\n\n" + + # multiple warnings, within threshold (list syntax) + console.log_warning("First warning") + console.log_warning("Second warning") + assert ( + console.consume_captured_warnings() + == "> [!WARNING]\n>\n> - First warning\n>\n> - Second warning\n\n" + ) + + # single warning, within max threshold but over 
collapsible section threshold + warning = "The snowflake engine is not recommended for storing SQLMesh state in production deployments" + assert len(warning) > console.alert_block_collapsible_threshold + assert len(warning) < console.alert_block_max_content_length + console.log_warning(warning) + assert ( + console.consume_captured_warnings() + == "> [!WARNING]\n>
\n>\n> The snowflake engine is not recommended for storing SQLMesh state in production deployments\n>
\n" + ) + + # single warning, over max threshold + warning = "The snowflake engine is not recommended for storing SQLMesh state in production deployments. Please see for a list of recommended engines and more information." + assert len(warning) > console.alert_block_collapsible_threshold + assert len(warning) > console.alert_block_max_content_length + console.log_warning(warning) + assert ( + console.consume_captured_warnings() + == "> [!WARNING]\n>
\n>\n> The snowflake engine is not re...\n>\n> Truncated. Please check the console for full information.\n>
\n" + ) + + # multiple warnings, within max threshold but over collapsible section threshold + warning_1 = "This is the first warning" + warning_2 = "This is the second warning" + assert (len(warning_1) + len(warning_2)) > console.alert_block_collapsible_threshold + assert (len(warning_1) + len(warning_2)) < console.alert_block_max_content_length + console.log_warning(warning_1) + console.log_warning(warning_2) + assert ( + console.consume_captured_warnings() + == "> [!WARNING]\n>
\n>\n> - This is the first warning\n>\n> - This is the second warning\n>
\n" + ) + + # multiple warnings, over max threshold + warning_1 = "This is the first warning and its really really long" + warning_2 = "This is the second warning and its also really really long" + assert (len(warning_1) + len(warning_2)) > console.alert_block_collapsible_threshold + assert (len(warning_1) + len(warning_2)) > console.alert_block_max_content_length + console.log_warning(warning_1) + console.log_warning(warning_2) + assert ( + console.consume_captured_warnings() + == "> [!WARNING]\n>
\n>\n> - This is the first warning an...\n>\n> Truncated. Please check the console for full information.\n>
\n" + ) + + assert console.consume_captured_warnings() == "" + + +def test_markdown_console_error_block(): + console = MarkdownConsole( + alert_block_max_content_length=100, alert_block_collapsible_threshold=40 + ) + assert console.consume_captured_errors() == "" + + # single error, within threshold + console.log_error("First error") + assert console.consume_captured_errors() == "> [!CAUTION]\n>\n> First error\n\n" + + # multiple errors, within threshold (list syntax) + console.log_error("First error") + console.log_error("Second error") + assert ( + console.consume_captured_errors() + == "> [!CAUTION]\n>\n> - First error\n>\n> - Second error\n\n" + ) + + # single error, within max threshold but over collapsible section threshold + error = "The snowflake engine is not recommended for storing SQLMesh state in production deployments" + assert len(error) > console.alert_block_collapsible_threshold + assert len(error) < console.alert_block_max_content_length + console.log_error(error) + assert ( + console.consume_captured_errors() + == "> [!CAUTION]\n>
\n>\n> The snowflake engine is not recommended for storing SQLMesh state in production deployments\n>
\n" + ) + + # single error, over max threshold + error = "The snowflake engine is not recommended for storing SQLMesh state in production deployments. Please see for a list of recommended engines and more information." + assert len(error) > console.alert_block_collapsible_threshold + assert len(error) > console.alert_block_max_content_length + console.log_error(error) + assert ( + console.consume_captured_errors() + == "> [!CAUTION]\n>
\n>\n> The snowflake engine is not re...\n>\n> Truncated. Please check the console for full information.\n>
\n" + ) + + # multiple errors, within max threshold but over collapsible section threshold + error_1 = "This is the first error" + error_2 = "This is the second error" + assert (len(error_1) + len(error_2)) > console.alert_block_collapsible_threshold + assert (len(error_1) + len(error_2)) < console.alert_block_max_content_length + console.log_error(error_1) + console.log_error(error_2) + assert ( + console.consume_captured_errors() + == "> [!CAUTION]\n>
\n>\n> - This is the first error\n>\n> - This is the second error\n>
\n" + ) + + # multiple errors, over max threshold + error_1 = "This is the first error and its really really long" + error_2 = "This is the second error and its also really really long" + assert (len(error_1) + len(error_2)) > console.alert_block_collapsible_threshold + assert (len(error_1) + len(error_2)) > console.alert_block_max_content_length + console.log_error(error_1) + console.log_error(error_2) + assert ( + console.consume_captured_errors() + == "> [!CAUTION]\n>
\n>\n> - This is the first error and ...\n>\n> Truncated. Please check the console for full information.\n>
\n" + ) + + assert console.consume_captured_errors() == "" diff --git a/tests/core/test_context.py b/tests/core/test_context.py index 6444dc1218..c3d88e205e 100644 --- a/tests/core/test_context.py +++ b/tests/core/test_context.py @@ -1,42 +1,69 @@ import logging import pathlib import typing as t -from datetime import date, timedelta +import re +from datetime import date, timedelta, datetime from tempfile import TemporaryDirectory from unittest.mock import PropertyMock, call, patch -import freezegun +import time_machine import pytest -import pandas as pd +import pandas as pd # noqa: TID253 from pathlib import Path from pytest_mock.plugin import MockerFixture -from sqlglot import parse_one +from sqlglot import ParseError, exp, parse_one, Dialect from sqlglot.errors import SchemaError import sqlmesh.core.constants -import sqlmesh.core.dialect as d +from sqlmesh.cli.project_init import init_example_project +from sqlmesh.core.console import TerminalConsole +from sqlmesh.core import dialect as d, constants as c from sqlmesh.core.config import ( + load_configs, + AutoCategorizationMode, + CategorizerConfig, Config, DuckDBConnectionConfig, EnvironmentSuffixTarget, + GatewayConfig, + LinterConfig, ModelDefaultsConfig, + PlanConfig, SnowflakeConnectionConfig, - load_configs, ) from sqlmesh.core.context import Context +from sqlmesh.core.console import create_console, get_console from sqlmesh.core.dialect import parse, schema_ -from sqlmesh.core.environment import Environment -from sqlmesh.core.model import load_sql_based_model, model +from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter +from sqlmesh.core.environment import Environment, EnvironmentNamingInfo, EnvironmentStatements +from sqlmesh.core.plan.definition import Plan +from sqlmesh.core.macros import MacroEvaluator, RuntimeStage +from sqlmesh.core.model import load_sql_based_model, model, SqlModel, Model +from sqlmesh.core.model.common import ParsableSql +from sqlmesh.core.model.cache import 
OptimizedQueryCache +from sqlmesh.core.renderer import render_statements from sqlmesh.core.model.kind import ModelKindName -from sqlmesh.core.plan import BuiltInPlanEvaluator, PlanBuilder +from sqlmesh.core.state_sync.cache import CachingStateSync +from sqlmesh.core.state_sync.db import EngineAdapterStateSync +from sqlmesh.utils.connection_pool import SingletonConnectionPool, ThreadLocalSharedConnectionPool from sqlmesh.utils.date import ( make_inclusive_end, now, to_date, + to_datetime, to_timestamp, yesterday_ds, ) -from sqlmesh.utils.errors import ConfigError, SQLMeshError +from sqlmesh.utils.errors import ( + ConfigError, + SQLMeshError, + LinterError, + PlanError, + NoChangesPlanError, +) +from sqlmesh.utils.metaprogramming import Executable +from sqlmesh.utils.windows import IS_WINDOWS, fix_windows_path +from tests.utils.test_helpers import use_terminal_console from tests.utils.test_filesystem import create_temp_file @@ -47,7 +74,7 @@ def test_global_config(copy_to_temp_path: t.Callable): def test_named_config(copy_to_temp_path: t.Callable): - context = Context(paths=copy_to_temp_path("examples/sushi"), config="local_config") + context = Context(paths=copy_to_temp_path("examples/sushi"), config="test_config") assert len(context.config.gateways) == 1 @@ -77,7 +104,9 @@ def test_generate_table_name_in_dialect(mocker: MockerFixture): "sqlmesh.core.context.GenericContext._model_tables", PropertyMock(return_value={'"project-id"."dataset"."table"': '"project-id".dataset.table'}), ) - assert context.table('"project-id"."dataset"."table"') == "`project-id`.dataset.table" + assert ( + context.resolve_table('"project-id"."dataset"."table"') == "`project-id`.`dataset`.`table`" + ) def test_config_not_found(copy_to_temp_path: t.Callable): @@ -85,7 +114,7 @@ def test_config_not_found(copy_to_temp_path: t.Callable): ConfigError, match=r".*config could not be found.*", ): - Context(paths="nonexistent/directory", config="local_config") + 
Context(paths="nonexistent/directory", config="config") def test_custom_macros(sushi_context): @@ -178,6 +207,50 @@ def test_render_sql_model(sushi_context, assert_exp_eq, copy_to_temp_path: t.Cal ) +@pytest.mark.slow +def test_render_non_deployable_parent(sushi_context, assert_exp_eq, copy_to_temp_path: t.Callable): + model = sushi_context.get_model("sushi.waiter_revenue_by_day") + forward_only_kind = model.kind.copy(update={"forward_only": True}) + model = model.copy(update={"kind": forward_only_kind, "stamp": "trigger forward-only change"}) + sushi_context.upsert_model(model) + sushi_context.plan("dev", no_prompts=True, auto_apply=True) + + expected_table_name = parse_one( + sushi_context.get_snapshot("sushi.waiter_revenue_by_day").table_name(is_deployable=False), + into=exp.Table, + ).this.this + + assert_exp_eq( + sushi_context.render( + "sushi.top_waiters", + start=date(2021, 1, 1), + end=date(2021, 1, 1), + ), + f""" + WITH "test_macros" AS ( + SELECT + 2 AS "lit_two", + "waiter_revenue_by_day"."revenue" * 2.0 AS "sql_exp", + CAST("waiter_revenue_by_day"."revenue" AS TEXT) AS "sql_lit" + FROM "memory"."sqlmesh__sushi"."{expected_table_name}" AS "waiter_revenue_by_day" /* memory.sushi.waiter_revenue_by_day */ + ) + SELECT + CAST("waiter_revenue_by_day"."waiter_id" AS INT) AS "waiter_id", + CAST("waiter_revenue_by_day"."revenue" AS DOUBLE) AS "revenue" + FROM "memory"."sqlmesh__sushi"."{expected_table_name}" AS "waiter_revenue_by_day" /* memory.sushi.waiter_revenue_by_day */ + WHERE + "waiter_revenue_by_day"."event_date" = ( + SELECT + MAX("waiter_revenue_by_day"."event_date") AS "_col_0" + FROM "memory"."sqlmesh__sushi"."{expected_table_name}" AS "waiter_revenue_by_day" /* memory.sushi.waiter_revenue_by_day */ + ) + ORDER BY + "revenue" DESC + LIMIT 10 + """, + ) + + @pytest.mark.slow def test_render_seed_model(sushi_context, assert_exp_eq): assert_exp_eq( @@ -207,24 +280,9 @@ def test_diff(sushi_context: Context, mocker: MockerFixture): yesterday = 
yesterday_ds() success = sushi_context.run(start=yesterday, end=yesterday) - plan_evaluator = BuiltInPlanEvaluator( - sushi_context.state_sync, sushi_context.snapshot_evaluator, sushi_context.default_catalog - ) - - plan = PlanBuilder( - context_diff=sushi_context._context_diff("prod"), - engine_schema_differ=sushi_context.engine_adapter.SCHEMA_DIFFER, - ).build() - - # stringify used to trigger an unhashable exception due to - # https://github.com/pydantic/pydantic/issues/8016 - assert str(plan) != "" - - promotion_result = plan_evaluator._promote(plan) - plan_evaluator._update_views(plan, promotion_result) - sushi_context.upsert_model("sushi.customers", query=parse_one("select 1 as customer_id")) sushi_context.diff("test") + assert mock_console.show_environment_difference_summary.called assert mock_console.show_model_difference_summary.called assert success @@ -262,6 +320,85 @@ def test_evaluate_limit(): assert context.evaluate("without_limit", "2020-01-01", "2020-01-02", "2020-01-02", 2).size == 2 +def test_gateway_specific_adapters(copy_to_temp_path, mocker): + path = copy_to_temp_path("examples/sushi") + ctx = Context(paths=path, config="isolated_systems_config", gateway="prod") + assert len(ctx.engine_adapters) == 3 + assert ctx.engine_adapter == ctx.engine_adapters["prod"] + assert ctx._get_engine_adapter("dev") == ctx.engine_adapters["dev"] + + ctx = Context(paths=path, config="isolated_systems_config") + assert len(ctx.engine_adapters) == 3 + assert ctx.engine_adapter == ctx.engine_adapters["dev"] + + ctx = Context(paths=path, config="isolated_systems_config") + assert len(ctx.engine_adapters) == 3 + assert ctx.engine_adapter == ctx._get_engine_adapter() + assert ctx._get_engine_adapter("test") == ctx.engine_adapters["test"] + + +def test_multiple_gateways(tmp_path: Path): + db_path = str(tmp_path / "db.db") + gateways = { + "staging": GatewayConfig(connection=DuckDBConnectionConfig(database=db_path)), + "final": 
GatewayConfig(connection=DuckDBConnectionConfig(database=db_path)), + } + + config = Config(gateways=gateways, default_gateway="final") + context = Context(config=config) + + gateway_model = load_sql_based_model( + parse( + """ + MODEL(name staging.stg_model, start '2024-01-01',kind FULL, gateway staging); + SELECT t.v as v FROM (VALUES (1), (2), (3), (4), (5)) AS t(v)""" + ), + default_catalog="db", + ) + + assert gateway_model.gateway == "staging" + context.upsert_model(gateway_model) + assert context.evaluate("staging.stg_model", "2020-01-01", "2020-01-02", "2020-01-02").size == 5 + + default_model = load_sql_based_model( + parse( + """ + MODEL(name main.final_model, start '2024-01-01',kind FULL); + SELECT v FROM staging.stg_model""" + ), + default_catalog="db", + ) + + assert not default_model.gateway + context.upsert_model(default_model) + + context.plan( + execution_time="2024-01-02", + auto_apply=True, + no_prompts=True, + ) + + sorted_snapshots = sorted(context.snapshots.values()) + + physical_schemas = [snapshot.physical_schema for snapshot in sorted_snapshots] + assert physical_schemas == ["sqlmesh__main", "sqlmesh__staging"] + + view_schemas = [snapshot.qualified_view_name.schema_name for snapshot in sorted_snapshots] + assert view_schemas == ["main", "staging"] + + assert ( + str(context.fetchdf("select * from staging.stg_model")) + == " v\n0 1\n1 2\n2 3\n3 4\n4 5" + ) + assert str(context.fetchdf("select * from final_model")) == " v\n0 1\n1 2\n2 3\n3 4\n4 5" + + assert ( + context.snapshots['"db"."main"."final_model"'].parents[0].name + == '"db"."staging"."stg_model"' + ) + assert context.dag._sorted == ['"db"."staging"."stg_model"', '"db"."main"."final_model"'] + + def test_plan_execution_time(): context = Context(config=Config()) context.upsert_model( @@ -292,7 +429,162 @@ def test_plan_execution_time(): ) +def test_plan_execution_time_start_end(): + context = Context(config=Config()) + context.upsert_model( + load_sql_based_model( + parse( + """ + 
MODEL( + name db.x, + start '2020-01-01', + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds + ), + cron '@daily' + ); + + SELECT id, ds FROM (VALUES + ('1', '2020-01-01'), + ('2', '2021-01-01'), + ('3', '2022-01-01'), + ('4', '2023-01-01'), + ('5', '2024-01-01') + ) data(id, ds) + WHERE ds BETWEEN @start_ds AND @end_ds + """ + ) + ) + ) + + # prod plan - no fixed execution time so it defaults to now() and reads all the data + prod_plan = context.plan(auto_apply=True) + + assert len(prod_plan.new_snapshots) == 1 + + context.upsert_model( + load_sql_based_model( + parse( + """ + MODEL( + name db.x, + start '2020-01-01', + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds + ), + cron '@daily' + ); + + SELECT id, ds, 'changed' as a FROM (VALUES + ('1', '2020-01-01'), + ('2', '2021-01-01'), + ('3', '2022-01-01'), + ('4', '2023-01-01'), + ('5', '2024-01-01') + ) data(id, ds) + WHERE ds BETWEEN @start_ds AND @end_ds + """ + ) + ) + ) + + # dev plan with an execution time in the past and no explicit start/end specified + # the plan end should be bounded to it and not exceed it even though in prod the last interval (used as a default end) + # is newer than the execution time + dev_plan = context.plan("dev", execution_time="2020-01-05") + + assert to_datetime(dev_plan.start) == to_datetime( + "2020-01-01" + ) # default start is the earliest prod interval + assert to_datetime(dev_plan.execution_time) == to_datetime("2020-01-05") + assert to_datetime(dev_plan.end) == to_datetime( + "2020-01-05" + ) # end should not be greater than execution_time + + # same as above but with a relative start + dev_plan = context.plan("dev", start="1 day ago", execution_time="2020-01-05") + + assert to_datetime(dev_plan.start) == to_datetime( + "2020-01-04" + ) # start relative to execution_time + assert to_datetime(dev_plan.execution_time) == to_datetime("2020-01-05") + assert to_datetime(dev_plan.end) == to_datetime( + "2020-01-05" + ) # end should not be greater than execution_time + + # 
same as above but with a relative start and a relative end + dev_plan = context.plan("dev", start="2 days ago", execution_time="2020-01-05", end="1 day ago") + + assert to_datetime(dev_plan.start) == to_datetime( + "2020-01-03" + ) # start relative to execution_time + assert to_datetime(dev_plan.execution_time) == to_datetime("2020-01-05") + assert to_datetime(dev_plan.end) == to_datetime("2020-01-04") # end relative to execution_time + + +def test_override_builtin_audit_blocking_mode(): + context = Context(config=Config()) + context.upsert_model( + load_sql_based_model( + parse( + """ + MODEL( + name db.x, + kind FULL, + audits ( + not_null(columns := [c], blocking := false), + unique_values(columns := [c]), + ) + ); + + SELECT NULL AS c + """ + ) + ) + ) + + with patch.object(context.console, "log_warning") as mock_logger: + plan = context.plan(auto_apply=True, no_prompts=True) + new_snapshot = next(iter(plan.context_diff.new_snapshots.values())) + + assert ( + mock_logger.call_args_list[0][0][0] == "\ndb.x: 'not_null' audit error: 1 row failed." 
+ ) + + # Even though there are two builtin audits referenced in the above definition, we only + # store the one that overrides `blocking` in the snapshot; the other one isn't needed + audits_with_args = new_snapshot.model.audits_with_args + assert len(audits_with_args) == 2 + audit, args = audits_with_args[0] + assert audit.name == "not_null" + assert list(args) == ["columns", "blocking"] + + context = Context(config=Config()) + context.upsert_model( + load_sql_based_model( + parse( + """ + MODEL( + name db.x, + kind FULL, + audits ( + not_null_non_blocking(columns := [c], blocking := true) + ) + ); + + SELECT NULL AS c + """ + ) + ) + ) + + with pytest.raises(SQLMeshError): + context.plan(auto_apply=True, no_prompts=True) + + def test_python_model_empty_df_raises(sushi_context, capsys): + sushi_context.console = create_console() + @model( "memory.sushi.test_model", columns={"col": "int"}, @@ -311,11 +603,8 @@ def entrypoint(context, **kwargs): sushi_context.plan(no_prompts=True, auto_apply=True) assert ( - "Cannot construct source query from an empty \nDataFrame. This error " - "is commonly related to Python models that produce no data.\nFor such " - "models, consider yielding from an empty generator if the resulting set " - "\nis empty, i.e. use `yield from ()`" - ) in capsys.readouterr().out + "Cannot construct source query from an empty DataFrame. This error is commonly related to Python models that produce no data. For such models, consider yielding from an empty generator if the resulting set is empty, i.e. 
use" + ) in capsys.readouterr().out.replace("\n", "") def test_env_and_default_schema_normalization(mocker: MockerFixture): @@ -343,6 +632,49 @@ def test_env_and_default_schema_normalization(mocker: MockerFixture): assert list(context.fetchdf('select c from "DEFAULT__DEV"."X"')["c"])[0] == 1 +def test_jinja_macro_undefined_variable_error(tmp_path: pathlib.Path): + models_dir = tmp_path / "models" + models_dir.mkdir(parents=True) + macros_dir = tmp_path / "macros" + macros_dir.mkdir(parents=True) + + macro_file = macros_dir / "my_macros.sql" + macro_file.write_text(""" +{%- macro generate_select(table_name) -%} + {%- if target.name == 'production' -%} + {%- set results = run_query('SELECT 1') -%} + {%- endif -%} + SELECT {{ results.columns[0].values()[0] }} FROM {{ table_name }} +{%- endmacro -%} +""") + + model_file = models_dir / "my_model.sql" + model_file.write_text(""" +MODEL ( + name my_schema.my_model, + kind FULL +); + +JINJA_QUERY_BEGIN; +{{ generate_select('users') }} +JINJA_END; +""") + + config_file = tmp_path / "config.yaml" + config_file.write_text(""" +model_defaults: + dialect: duckdb +""") + + with pytest.raises(ConfigError) as exc_info: + Context(paths=str(tmp_path)) + + error_message = str(exc_info.value) + assert "Failed to load model" in error_message + assert "Could not render jinja for" in error_message + assert "Undefined macro/variable: 'target' in macro: 'generate_select'" in error_message + + def test_clear_caches(tmp_path: pathlib.Path): models_dir = tmp_path / "models" @@ -360,6 +692,141 @@ def test_clear_caches(tmp_path: pathlib.Path): assert not cache_dir.exists() assert models_dir.exists() + # Ensure that we don't initialize a CachingStateSync only to clear its (empty) caches + assert context._state_sync is None + + # Test clearing caches when cache directory doesn't exist + # This should not raise an exception + context.clear_caches() + assert not cache_dir.exists() + + +def test_clear_caches_with_long_base_path(tmp_path: 
pathlib.Path): + base_path = tmp_path / ("abcde" * 50) + assert ( + len(str(base_path.absolute())) > 260 + ) # Paths longer than 260 chars trigger problems on Windows + + default_cache_dir = base_path / c.CACHE + custom_cache_dir = base_path / ".test_cache" + + # note: we create the Context here so it doesnt get passed any "fixed" paths + ctx = Context(config=Config(cache_dir=str(custom_cache_dir)), paths=base_path) + + if IS_WINDOWS: + # fix these so we can use them in this test + default_cache_dir = fix_windows_path(default_cache_dir) + custom_cache_dir = fix_windows_path(custom_cache_dir) + + default_cache_dir.mkdir(parents=True) + custom_cache_dir.mkdir(parents=True) + + default_cache_file = default_cache_dir / "cache.txt" + custom_cache_file = custom_cache_dir / "cache.txt" + + default_cache_file.write_text("test") + custom_cache_file.write_text("test") + + assert default_cache_file.exists() + assert custom_cache_file.exists() + assert default_cache_dir.exists() + assert custom_cache_dir.exists() + + ctx.clear_caches() + + assert not default_cache_file.exists() + assert not custom_cache_file.exists() + assert not default_cache_dir.exists() + assert not custom_cache_dir.exists() + + +def test_cache_path_configurations(tmp_path: pathlib.Path): + project_dir = tmp_path / "project" + project_dir.mkdir(parents=True) + config_file = project_dir / "config.yaml" + + # Test relative path + config_file.write_text("model_defaults:\n dialect: duckdb\ncache_dir: .my_cache") + context = Context(paths=str(project_dir)) + assert context.cache_dir == project_dir / ".my_cache" + + # Test absolute path + abs_cache = tmp_path / "abs_cache" + config_file.write_text(f"model_defaults:\n dialect: duckdb\ncache_dir: {abs_cache}") + context = Context(paths=str(project_dir)) + assert context.cache_dir == abs_cache + + # Test default + config_file.write_text("model_defaults:\n dialect: duckdb") + context = Context(paths=str(project_dir)) + assert context.cache_dir == project_dir / 
".cache" + + +def test_plan_apply_populates_cache(copy_to_temp_path, mocker): + sushi_paths = copy_to_temp_path("examples/sushi") + sushi_path = sushi_paths[0] + custom_cache_dir = sushi_path.parent / "custom_cache" + + # Modify the existing config.py to add cache_dir to test_config + config_py_path = sushi_path / "config.py" + with open(config_py_path, "r") as f: + config_content = f.read() + + # Add cache_dir to the test_config definition + config_content += f"""test_config_cache_dir = Config( + gateways={{"in_memory": GatewayConfig(connection=DuckDBConnectionConfig())}}, + default_gateway="in_memory", + plan=PlanConfig( + auto_categorize_changes=CategorizerConfig( + sql=AutoCategorizationMode.SEMI, python=AutoCategorizationMode.OFF + ) + ), + model_defaults=model_defaults, + cache_dir="{custom_cache_dir.as_posix()}", + before_all=before_all, +)""" + + with open(config_py_path, "w") as f: + f.write(config_content) + + # Create context with the test config + context = Context(paths=sushi_path, config="test_config_cache_dir") + custom_cache_dir = context.cache_dir + assert "custom_cache" in str(custom_cache_dir) + assert (custom_cache_dir / "optimized_query").exists() + assert (custom_cache_dir / "model_definition").exists() + assert not (custom_cache_dir / "snapshot").exists() + + # Clear the cache + context.clear_caches() + assert not custom_cache_dir.exists() + + plan = context.plan("dev", create_from="prod", skip_tests=True) + context.apply(plan) + + # Cache directory should now exist again + assert custom_cache_dir.exists() + assert any(custom_cache_dir.iterdir()) + + # Since the cache has been deleted post loading here only snapshot should exist + assert (custom_cache_dir / "snapshot").exists() + assert not (custom_cache_dir / "optimized_query").exists() + assert not (custom_cache_dir / "model_definition").exists() + + # New context should load same models and create the cache for optimized_query and model_definition + initial_model_count = 
len(context.models) + context2 = Context(paths=context.path, config="test_config_cache_dir") + cached_model_count = len(context2.models) + + assert initial_model_count == cached_model_count > 0 + assert (custom_cache_dir / "optimized_query").exists() + assert (custom_cache_dir / "model_definition").exists() + assert (custom_cache_dir / "snapshot").exists() + + # Clear caches should remove the custom cache directory + context.clear_caches() + assert not custom_cache_dir.exists() + def test_ignore_files(mocker: MockerFixture, tmp_path: pathlib.Path): mocker.patch.object( @@ -376,6 +843,11 @@ def test_ignore_files(mocker: MockerFixture, tmp_path: pathlib.Path): pathlib.Path(models_dir, "ignore", "ignore_model.sql"), "MODEL(name ignore.ignore_model); SELECT 1 AS cola", ) + create_temp_file( + tmp_path, + pathlib.Path(models_dir, "ignore", "inner_ignore", "inner_ignore_model.sql"), + "MODEL(name ignore.inner_ignore_model); SELECT 1 AS cola", + ) create_temp_file( tmp_path, pathlib.Path(macros_dir, "macro_ignore.py"), @@ -409,7 +881,7 @@ def test(): """, ) config = Config( - ignore_patterns=["models/ignore/*.sql", "macro_ignore.py", ".ipynb_checkpoints/*"] + ignore_patterns=["models/ignore/**/*.sql", "macro_ignore.py", ".ipynb_checkpoints/*"] ) context = Context(paths=tmp_path, config=config) @@ -475,6 +947,7 @@ def test_project_config_person_config_overrides(tmp_path: pathlib.Path): assert snowflake_connection.account == "abc123" assert snowflake_connection.user == "ABC" assert snowflake_connection.password == "XYZ" + assert snowflake_connection.application == "Tobiko_SQLMesh" @pytest.mark.slow @@ -500,12 +973,12 @@ def get_sushi_fingerprints(context: Context): project_path = copy_to_temp_path("examples/sushi") no_mapping_context = Context(paths=project_path) - assert no_mapping_context.config.physical_schema_override == {} + assert no_mapping_context.config.physical_schema_mapping == {} assert get_schemas(no_mapping_context) == {"sqlmesh__sushi", "sqlmesh__raw"} assert 
get_view_schemas(no_mapping_context) == {"sushi", "raw"} no_mapping_fingerprints = get_sushi_fingerprints(no_mapping_context) context = Context(paths=project_path, config="map_config") - assert context.config.physical_schema_override == {"sushi": "company_internal"} + assert context.config.physical_schema_mapping == {re.compile("^sushi$"): "company_internal"} assert get_schemas(context) == {"company_internal", "sqlmesh__raw"} assert get_view_schemas(context) == {"sushi", "raw"} sushi_fingerprints = get_sushi_fingerprints(context) @@ -516,14 +989,61 @@ def get_sushi_fingerprints(context: Context): ) +@pytest.mark.slow +def test_physical_schema_mapping(tmp_path: pathlib.Path) -> None: + create_temp_file( + tmp_path, + pathlib.Path(pathlib.Path("models"), "a.sql"), + "MODEL(name foo_staging.model_a); SELECT 1;", + ) + + create_temp_file( + tmp_path, + pathlib.Path(pathlib.Path("models"), "b.sql"), + "MODEL(name testone.model_b); SELECT 1;", + ) + + create_temp_file( + tmp_path, + pathlib.Path(pathlib.Path("models"), "c.sql"), + "MODEL(name untouched.model_c); SELECT 1;", + ) + + ctx = Context( + config=Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + physical_schema_mapping={ + # anything ending with 'staging' becomes 'overridden_staging' + "^.*staging$": "overridden_staging", + # anything starting with 'test' becomes 'testing' + "^test.*": "testing", + }, + ), + paths=tmp_path, + ) + + ctx.load() + + physical_schemas = [snapshot.physical_schema for snapshot in sorted(ctx.snapshots.values())] + + view_schemas = [ + snapshot.qualified_view_name.schema_name for snapshot in sorted(ctx.snapshots.values()) + ] + + assert len(physical_schemas) == len(view_schemas) == 3 + assert physical_schemas == ["overridden_staging", "testing", "sqlmesh__untouched"] + assert view_schemas == ["foo_staging", "testone", "untouched"] + + @pytest.mark.slow def test_janitor(sushi_context, mocker: MockerFixture) -> None: adapter_mock = mocker.MagicMock() adapter_mock.dialect = 
"duckdb" state_sync_mock = mocker.MagicMock() - state_sync_mock.delete_expired_environments.return_value = [ + + environments = [ Environment( - name="test_environment", + name="test_environment1", suffix_target=EnvironmentSuffixTarget.TABLE, snapshots=[x.table_info for x in sushi_context.snapshots.values()], start_at="2022-01-01", @@ -532,7 +1052,7 @@ def test_janitor(sushi_context, mocker: MockerFixture) -> None: previous_plan_id="test_plan_id", ), Environment( - name="test_environment", + name="test_environment2", suffix_target=EnvironmentSuffixTarget.SCHEMA, snapshots=[x.table_info for x in sushi_context.snapshots.values()], start_at="2022-01-01", @@ -541,15 +1061,24 @@ def test_janitor(sushi_context, mocker: MockerFixture) -> None: previous_plan_id="test_plan_id", ), ] + + state_sync_mock.get_expired_environments.return_value = [env.summary for env in environments] + state_sync_mock.get_environment = lambda name: next( + env for env in environments if env.name == name + ) + sushi_context._engine_adapter = adapter_mock + sushi_context.engine_adapters = {sushi_context.config.default_gateway: adapter_mock} sushi_context._state_sync = state_sync_mock + state_sync_mock.get_expired_snapshots.return_value = None + sushi_context._run_janitor() # Assert that the schemas are dropped just twice for the schema based environment # Make sure that external model schemas/tables are not dropped adapter_mock.drop_schema.assert_has_calls( [ call( - schema_("sushi__test_environment", "memory"), + schema_("sushi__test_environment2", "memory"), cascade=True, ignore_if_not_exists=True, ), @@ -557,11 +1086,11 @@ def test_janitor(sushi_context, mocker: MockerFixture) -> None: ) # Assert that the views are dropped for each snapshot just once and make sure that the name used is the # view name with the environment as a suffix - assert adapter_mock.drop_view.call_count == 13 + assert adapter_mock.drop_view.call_count == 16 adapter_mock.drop_view.assert_has_calls( [ call( - 
"memory.sushi.waiter_as_customer_by_day__test_environment", + "memory.sushi.waiter_as_customer_by_day__test_environment1", ignore_if_not_exists=True, ), ] @@ -582,26 +1111,27 @@ def test_plan_default_end(sushi_context_pre_scheduling: Context): assert dev_plan.end is not None assert to_date(make_inclusive_end(dev_plan.end)) == plan_end - forward_only_dev_plan = sushi_context_pre_scheduling.plan( - "test_env_forward_only", no_prompts=True, include_unmodified=True, forward_only=True - ) + forward_only_dev_plan = sushi_context_pre_scheduling.plan_builder( + "test_env_forward_only", include_unmodified=True, forward_only=True + ).build() assert forward_only_dev_plan.end is not None assert to_date(make_inclusive_end(forward_only_dev_plan.end)) == plan_end - assert forward_only_dev_plan.start == plan_end + assert to_timestamp(forward_only_dev_plan.start) == to_timestamp(plan_end) @pytest.mark.slow def test_plan_start_ahead_of_end(copy_to_temp_path): path = copy_to_temp_path("examples/sushi") - with freezegun.freeze_time("2024-01-02 00:00:00"): - context = Context(paths=path, config="local_config") + with time_machine.travel("2024-01-02 00:00:00 UTC"): + context = Context(paths=path, gateway="duckdb_persistent") context.plan("prod", no_prompts=True, auto_apply=True) - assert context.state_sync.max_interval_end_for_environment("prod") == to_timestamp( - "2024-01-02" + assert all( + i == to_timestamp("2024-01-02") + for i in context.state_sync.max_interval_end_per_model("prod").values() ) context.close() - with freezegun.freeze_time("2024-01-03 00:00:00"): - context = Context(paths=path, config="local_config") + with time_machine.travel("2024-01-03 00:00:00 UTC"): + context = Context(paths=path, gateway="duckdb_persistent") expression = d.parse( """ MODEL( @@ -619,13 +1149,80 @@ def test_plan_start_ahead_of_end(copy_to_temp_path): # Since the new start is ahead of the latest end loaded for prod, the table is deployed as empty # This isn't considered a gap since prod has not 
loaded these intervals yet # As a results the max interval end is unchanged and the table is empty - assert context.state_sync.max_interval_end_for_environment("prod") == to_timestamp( - "2024-01-02" + assert all( + i == to_timestamp("2024-01-02") + for i in context.state_sync.max_interval_end_per_model("prod").values() ) assert context.engine_adapter.fetchone("SELECT COUNT(*) FROM sushi.hourly")[0] == 0 context.close() +@pytest.mark.slow +def test_plan_seed_model_excluded_from_default_end(copy_to_temp_path: t.Callable): + path = copy_to_temp_path("examples/sushi") + with time_machine.travel("2024-06-01 00:00:00 UTC"): + context = Context(paths=path, gateway="duckdb_persistent") + context.plan("prod", no_prompts=True, auto_apply=True) + max_ends = context.state_sync.max_interval_end_per_model("prod") + seed_fqns = [k for k in max_ends if "waiter_names" in k] + assert len(seed_fqns) == 1 + assert max_ends[seed_fqns[0]] == to_timestamp("2024-06-01") + context.close() + + with time_machine.travel("2026-03-01 00:00:00 UTC"): + context = Context(paths=path, gateway="duckdb_persistent") + + # a model that depends on this seed but has no interval in prod yet so only the seed would contribute to max_interval_end_per_model + context.upsert_model( + load_sql_based_model( + parse( + """ + MODEL( + name sushi.waiter_summary, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds + ), + start '2025-01-01', + cron '@daily' + ); + + SELECT + id, + name, + @start_ds AS ds + FROM + sushi.waiter_names + WHERE + @start_ds BETWEEN @start_ds AND @end_ds + """ + ), + default_catalog=context.default_catalog, + ) + ) + + # the seed's interval end would still be 2024-06-01 + max_ends = context.state_sync.max_interval_end_per_model("prod") + seed_fqns = [k for k in max_ends if "waiter_names" in k] + assert len(seed_fqns) == 1 + assert max_ends[seed_fqns[0]] == to_timestamp("2024-06-01") + + # the plan start date 2025-01-01 is after the seeds end date but shouldnt cause the plan to fail + plan 
= context.plan( + "dev", start="2025-01-01", no_prompts=True, select_models=["*waiter_summary"] + ) + + # the end should fall back to execution_time rather than seeds end + assert plan.models_to_backfill == { + '"duckdb"."sushi"."waiter_names"', + '"duckdb"."sushi"."waiter_summary"', + } + assert plan.provided_end is None + assert plan.provided_start == "2025-01-01" + assert to_timestamp(plan.end) == to_timestamp("2026-03-01") + assert to_timestamp(plan.start) == to_timestamp("2025-01-01") + context.close() + + @pytest.mark.slow def test_schema_error_no_default(sushi_context_pre_scheduling) -> None: context = sushi_context_pre_scheduling @@ -678,6 +1275,7 @@ def test_unrestorable_snapshot(sushi_context: Context) -> None: no_prompts=True, forward_only=True, allow_destructive_models=["memory.sushi.test_unrestorable"], + categorizer_config=CategorizerConfig.all_full(), ) sushi_context.upsert_model(model_v1) @@ -686,6 +1284,7 @@ def test_unrestorable_snapshot(sushi_context: Context) -> None: no_prompts=True, forward_only=True, allow_destructive_models=["memory.sushi.test_unrestorable"], + categorizer_config=CategorizerConfig.all_full(), ) model_v1_new_snapshot = sushi_context.get_snapshot( "memory.sushi.test_unrestorable", raise_if_missing=True @@ -704,7 +1303,7 @@ def test_default_catalog_connections(copy_to_temp_path: t.Callable): "sqlmesh.core.engine_adapter.base.EngineAdapter.default_catalog", PropertyMock(return_value=None), ): - context = Context(paths="examples/sushi") + context = Context(paths=copy_to_temp_path("examples/sushi")) assert context.default_catalog is None # Verify that providing a catalog gets set as default catalog @@ -740,9 +1339,19 @@ def test_load_external_models(copy_to_temp_path): assert "prod_raw.model1" not in external_model_names # get physical table names of external models using table - assert context.table("raw.model1") == "memory.raw.model1" - assert context.table("raw.demographics") == "memory.raw.demographics" - assert 
context.table("raw.model2") == "memory.raw.model2" + assert context.resolve_table("raw.model1") == '"memory"."raw"."model1"' + assert context.resolve_table("raw.demographics") == '"memory"."raw"."demographics"' + assert context.resolve_table("raw.model2") == '"memory"."raw"."model2"' + + with patch.object(context.console, "log_warning") as mock_logger: + context.table("raw.model1") == '"memory"."raw"."model1"' + + assert mock_logger.mock_calls == [ + call( + "The SQLMesh context's `table` method is deprecated and will be removed " + "in a future release. Please use the `resolve_table` method instead." + ) + ] def test_load_gateway_specific_external_models(copy_to_temp_path): @@ -765,6 +1374,12 @@ def _get_external_model_names(gateway=None): # gateway explicitly set to prod; prod model should now show assert "prod_raw.model1" in _get_external_model_names(gateway="prod") + # test uppercase gateway name should match lowercase external model definition + assert "prod_raw.model1" in _get_external_model_names(gateway="PROD") + + # test mixed case gateway name should also work + assert "prod_raw.model1" in _get_external_model_names(gateway="Prod") + def test_disabled_model(copy_to_temp_path): path = copy_to_temp_path("examples/sushi") @@ -778,6 +1393,27 @@ def test_disabled_model(copy_to_temp_path): assert not context.get_model("sushi.disabled_py") +def test_disabled_model_python_macro(sushi_context): + @model( + "memory.sushi.disabled_model_2", + columns={"col": "int"}, + enabled="@IF(@gateway = 'dev', True, False)", + ) + def entrypoint(context, **kwargs): + yield pd.DataFrame({"col": []}) + + test_model = model.get_registry()["memory.sushi.disabled_model_2"].model( + module_path=Path("."), path=Path("."), variables={"gateway": "prod"} + ) + assert not test_model.enabled + + with pytest.raises( + SQLMeshError, + match="The disabled model 'memory.sushi.disabled_model_2' cannot be upserted", + ): + sushi_context.upsert_model(test_model) + + def 
test_get_model_mixed_dialects(copy_to_temp_path): path = copy_to_temp_path("examples/sushi") @@ -794,7 +1430,7 @@ def test_get_model_mixed_dialects(copy_to_temp_path): model = load_sql_based_model(expression, default_catalog=context.default_catalog) context.upsert_model(model) - assert context.get_model("sushi.snowflake_dialect") == model + assert context.get_model("sushi.snowflake_dialect").dict() == model.dict() def test_override_dialect_normalization_strategy(): @@ -812,3 +1448,1896 @@ def test_override_dialect_normalization_strategy(): # The above change is applied globally so we revert it to avoid breaking other tests DuckDB.NORMALIZATION_STRATEGY = NormalizationStrategy.CASE_INSENSITIVE + + +def test_different_gateway_normalization_strategy(tmp_path: pathlib.Path): + config = Config( + gateways={ + "duckdb": GatewayConfig( + connection=DuckDBConnectionConfig(database="db.db"), + model_defaults=ModelDefaultsConfig( + dialect="snowflake, normalization_strategy=case_insensitive" + ), + ) + }, + model_defaults=ModelDefaultsConfig(dialect="snowflake"), + default_gateway="duckdb", + ) + + from sqlglot.dialects import Snowflake + from sqlglot.dialects.dialect import NormalizationStrategy + + assert Snowflake.NORMALIZATION_STRATEGY == NormalizationStrategy.UPPERCASE + + ctx = Context(paths=tmp_path, config=config, gateway="duckdb") + + dialect = Dialect.get_or_raise(ctx.config.dialect) + + assert dialect == "snowflake" + assert Snowflake.NORMALIZATION_STRATEGY == NormalizationStrategy.CASE_INSENSITIVE + + Snowflake.NORMALIZATION_STRATEGY = NormalizationStrategy.UPPERCASE + + +def test_access_self_columns_to_types_in_macro(tmp_path: pathlib.Path): + create_temp_file( + tmp_path, + pathlib.Path(pathlib.Path("models"), "test.sql"), + "MODEL(name test); SELECT 1 AS c; @post_statement()", + ) + create_temp_file( + tmp_path, + pathlib.Path(pathlib.Path("macros"), "post_statement.py"), + """ +from sqlglot import exp +from sqlmesh.core.macros import macro + +@macro() +def 
post_statement(evaluator): + if evaluator.runtime_stage != 'loading': + assert evaluator.columns_to_types("test") == {"c": exp.DataType.build("int")} + return None +""", + ) + + context = Context(paths=tmp_path, config=Config()) + context.plan(auto_apply=True, no_prompts=True) + + +def test_wildcard(copy_to_temp_path: t.Callable): + parent_path = copy_to_temp_path("examples/multi")[0] + + context = Context(paths=f"{parent_path}/*", gateway="memory") + assert len(context.models) == 5 + + +def test_duckdb_state_connection_automatic_multithreaded_mode(tmp_path): + single_threaded_config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_gateway="duckdb", + gateways={ + "duckdb": GatewayConfig( + connection=DuckDBConnectionConfig(concurrent_tasks=1), + state_connection=DuckDBConnectionConfig(concurrent_tasks=1), + ) + }, + ) + + # main connection 4 concurrent tasks, state connection 1 concurrent task, + # context should adjust concurrent tasks on state connection to match main connection + multi_threaded_config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_gateway="duckdb", + gateways={ + "duckdb": GatewayConfig( + connection=DuckDBConnectionConfig(concurrent_tasks=4), + state_connection=DuckDBConnectionConfig(concurrent_tasks=1), + ) + }, + ) + + context = Context(paths=[tmp_path], config=single_threaded_config) + assert isinstance(context.state_sync, CachingStateSync) + state_sync = context.state_sync.state_sync + assert isinstance(state_sync, EngineAdapterStateSync) + assert isinstance(state_sync.engine_adapter, DuckDBEngineAdapter) + assert isinstance(state_sync.engine_adapter._connection_pool, SingletonConnectionPool) + + context = Context(paths=[tmp_path], config=multi_threaded_config) + assert isinstance(context.state_sync, CachingStateSync) + state_sync = context.state_sync.state_sync + assert isinstance(state_sync, EngineAdapterStateSync) + assert isinstance(state_sync.engine_adapter, DuckDBEngineAdapter) 
+ assert isinstance(state_sync.engine_adapter._connection_pool, ThreadLocalSharedConnectionPool) + + +def test_requirements(copy_to_temp_path: t.Callable): + from sqlmesh.utils.metaprogramming import Executable + + context_path = copy_to_temp_path("examples/sushi")[0] + + with open(context_path / c.REQUIREMENTS, "w") as f: + # Add pandas and test_package and exclude ruamel.yaml + f.write("pandas==2.2.2\ntest_package==1.0.0\n^ruamel.yaml\n^ruamel.yaml.clib") + + context = Context(paths=context_path) + + model = context.get_model("sushi.items") + model.python_env["ruamel"] = Executable(payload="import ruamel", kind="import") + model.python_env["Image"] = Executable( + payload="from ipywidgets.widgets.widget_media import Image", kind="import" + ) + + environment = context.plan( + "dev", no_prompts=True, skip_tests=True, skip_backfill=True, auto_apply=True + ).environment + requirements = {"ipywidgets", "numpy", "pandas", "test_package"} + if IS_WINDOWS: + requirements.add("pendulum") + assert environment.requirements["pandas"] == "2.2.2" + assert set(environment.requirements) == requirements + + context._requirements = {"numpy": "2.1.2", "pandas": "2.2.1"} + context._excluded_requirements = {"ipywidgets", "ruamel.yaml", "ruamel.yaml.clib"} + diff = context.plan_builder("dev", skip_tests=True, skip_backfill=True).build().context_diff + assert set(diff.previous_requirements) == requirements + reqs = {"numpy", "pandas"} + if IS_WINDOWS: + reqs.add("pendulum") + assert set(diff.requirements) == reqs + + +def test_deactivate_automatic_requirement_inference(copy_to_temp_path: t.Callable): + context_path = copy_to_temp_path("examples/sushi")[0] + config = next(iter(load_configs("config", Config, paths=context_path).values())) + + config.infer_python_dependencies = False + context = Context(paths=context_path, config=config) + environment = context.plan( + "dev", no_prompts=True, skip_tests=True, skip_backfill=True, auto_apply=True + ).environment + + assert 
environment.requirements == {"pandas": "2.2.2"} + + +@pytest.mark.slow +def test_rendered_diff(): + ctx = Context(config=Config()) + + ctx.upsert_model( + load_sql_based_model( + parse( + """ + MODEL ( + name test, + ); + + CREATE TABLE IF NOT EXISTS foo AS (SELECT @OR(FALSE, TRUE)); + + SELECT 4 + 2; + + CREATE TABLE IF NOT EXISTS foo2 AS (SELECT @AND(TRUE, FALSE)); + + ON_VIRTUAL_UPDATE_BEGIN; + DROP VIEW @this_model + ON_VIRTUAL_UPDATE_END; + + """ + ) + ) + ) + + ctx.plan("dev", auto_apply=True, no_prompts=True) + + # Alter the model's query and pre/post/virtual statements to cause the diff + ctx.upsert_model( + load_sql_based_model( + parse( + """ + MODEL ( + name test, + ); + + CREATE TABLE IF NOT EXISTS foo AS (SELECT @AND(TRUE, NULL)); + + SELECT 5 + 2; + + CREATE TABLE IF NOT EXISTS foo2 AS (SELECT @OR(TRUE, NULL)); + + ON_VIRTUAL_UPDATE_BEGIN; + DROP VIEW IF EXISTS @this_model + ON_VIRTUAL_UPDATE_END; + """ + ) + ) + ) + + plan = ctx.plan("dev", auto_apply=True, no_prompts=True, diff_rendered=True) + + assert plan.context_diff.text_diff('"test"') == ( + "--- \n\n" + "+++ \n\n" + "@@ -4,15 +4,15 @@\n\n" + ' CREATE TABLE IF NOT EXISTS "foo" AS\n' + " (\n" + " SELECT\n" + "- FALSE OR TRUE\n" + "+ TRUE\n" + " )\n" + " SELECT\n" + '- 6 AS "_col_0"\n' + '+ 7 AS "_col_0"\n' + ' CREATE TABLE IF NOT EXISTS "foo2" AS\n' + " (\n" + " SELECT\n" + "- TRUE AND FALSE\n" + "+ TRUE\n" + " )\n" + " ON_VIRTUAL_UPDATE_BEGIN;\n" + '-DROP VIEW "test";\n' + '+DROP VIEW IF EXISTS "test";\n' + " ON_VIRTUAL_UPDATE_END;" + ) + + +def test_plan_enable_preview_default(sushi_context: Context, sushi_dbt_context: Context): + assert sushi_context._plan_preview_enabled + assert not sushi_dbt_context._plan_preview_enabled + + sushi_dbt_context.engine_adapter.SUPPORTS_CLONING = True + assert sushi_dbt_context._plan_preview_enabled + + +@pytest.mark.slow +def test_raw_code_handling(sushi_test_dbt_context: Context): + model = 
sushi_test_dbt_context.models['"memory"."sushi"."model_with_raw_code"'] + assert "raw_code" not in model.jinja_macros.global_objs["model"] # type: ignore + + # logging "pre-hook" (in dbt_projects.yml) + the actual pre-hook in the model file + assert len(model.pre_statements) == 2 + + original_file_path = model.jinja_macros.global_objs["model"]["original_file_path"] # type: ignore + model_file_path = sushi_test_dbt_context.path / original_file_path + + raw_code_length = len(model_file_path.read_text()) - 1 + + hook = model.render_pre_statements()[0] + assert ( + hook.sql() + == f'''CREATE TABLE IF NOT EXISTS "t" AS SELECT 'Length is {raw_code_length}' AS "length_col"''' + ) + + +@pytest.mark.slow +def test_dbt_models_are_not_validated(sushi_test_dbt_context: Context): + model = sushi_test_dbt_context.models['"memory"."sushi"."non_validated_model"'] + + assert model.render_query_or_raise().sql(comments=False) == 'SELECT 1 AS "c", 2 AS "c"' + assert sushi_test_dbt_context.fetchdf( + 'SELECT * FROM "memory"."sushi"."non_validated_model"' + ).to_dict() == {"c": {0: 1}, "c_1": {0: 2}} + + # Write a new incremental model file that should fail validation + models_dir = sushi_test_dbt_context.path / "models" + incremental_model_path = models_dir / "invalid_incremental.sql" + incremental_model_content = """{{ + config( + materialized='incremental', + incremental_strategy='delete+insert', + ) +}} + +SELECT + 1 AS c""" + + incremental_model_path.write_text(incremental_model_content) + + # Reload the context - this should raise a validation error for the incremental model + with pytest.raises( + ConfigError, + match="Unmanaged incremental models with insert / overwrite enabled must specify the partitioned_by field", + ): + Context(paths=sushi_test_dbt_context.path, config="test_config") + + +def test_catalog_name_needs_to_be_quoted(): + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_connection=DuckDBConnectionConfig(catalogs={'"foo--bar"': 
":memory:"}), + ) + context = Context(config=config) + parsed_model = parse("MODEL(name db.x, kind FULL); SELECT 1 AS c") + context.upsert_model(load_sql_based_model(parsed_model, default_catalog='"foo--bar"')) + context.plan(auto_apply=True, no_prompts=True) + assert context.fetchdf('select * from "foo--bar".db.x').to_dict() == {"c": {0: 1}} + + +def test_plan_runs_audits_on_dev_previews(sushi_context: Context, capsys, caplog): + sushi_context.console = create_console() + + test_model = """ + MODEL ( + name sushi.test_audit_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column event_date, + forward_only true + ), + audits ( + number_of_rows(threshold := 10), + not_null(columns := id), + at_least_one_non_blocking(column := waiter_id) + ) + ); + + SELECT * FROM sushi.orders WHERE event_date BETWEEN @start_ts AND @end_ts + """ + + sushi_context.upsert_model( + load_sql_based_model(parse(test_model), default_catalog=sushi_context.default_catalog) + ) + plan = sushi_context.plan(auto_apply=True) + + assert plan.new_snapshots[0].name == '"memory"."sushi"."test_audit_model"' + assert plan.deployability_index.is_deployable(plan.new_snapshots[0]) + + # now, we mutate the model and run a plan in dev to create a dev preview + test_model = """ + MODEL ( + name sushi.test_audit_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column event_date, + forward_only true + ), + audits ( + not_null(columns := new_col), + at_least_one_non_blocking(column := new_col) + ) + ); + + SELECT *, null as new_col FROM sushi.orders WHERE event_date BETWEEN @start_ts AND @end_ts + """ + + sushi_context.upsert_model( + load_sql_based_model(parse(test_model), default_catalog=sushi_context.default_catalog) + ) + + capsys.readouterr() # clear output buffer + plan = sushi_context.plan(environment="dev", auto_apply=True) + + assert len(plan.new_snapshots) == 1 + dev_preview = plan.new_snapshots[0] + assert dev_preview.name == '"memory"."sushi"."test_audit_model"' + assert dev_preview.is_forward_only 
+ assert not plan.deployability_index.is_deployable( + dev_preview + ) # if something is not deployable to prod, then its by definiton a dev preview + + # we only see audit results if they fail + stdout = capsys.readouterr().out + log = caplog.text + assert "'not_null' audit error:" in log + assert "'at_least_one_non_blocking' audit error:" in log + assert "Virtual layer updated" in stdout + + +def test_environment_statements(tmp_path: pathlib.Path): + models_dir = pathlib.Path("models") + macros_dir = pathlib.Path("macros") + dialect = "postgres" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect=dialect), + before_all=[ + "CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR)" + ], + after_all=[ + "@grant_schema_usage()", + "@grant_usage_role(@schemas, 'admin')", + "@grant_select_privileges()", + ], + ) + + expression = """ +MODEL( + name db.test_after_model, + kind full +); + +SELECT 1 AS col_a; + """ + + create_temp_file( + tmp_path, + pathlib.Path(models_dir, "db", "test_after_model.sql"), + expression, + ) + + create_temp_file( + tmp_path, + pathlib.Path(macros_dir, "grant_select_privileges.py"), + """ +from sqlmesh.core.macros import macro +@macro() +def grant_select_privileges(evaluator): + if evaluator.this_env and evaluator.views: + return [ + f"GRANT SELECT ON VIEW {view_name} /* sqlglot.meta replace=false */ TO ROLE admin_role;" + for view_name in evaluator.views + ] +""", + ) + + create_temp_file( + tmp_path, + pathlib.Path(macros_dir, "grant_schema_file.py"), + """ +from sqlmesh import macro + +@macro() +def grant_schema_usage(evaluator): + if evaluator._environment_naming_info: + schemas = { + snapshot.qualified_view_name.schema_for_environment( + evaluator._environment_naming_info + ) + for snapshot in evaluator._snapshots.values() + if snapshot.is_model + } + return [ + f"GRANT USAGE ON SCHEMA {schema} TO user_role;" + for schema in schemas + ] +""", + ) + + create_temp_file( + tmp_path, + 
pathlib.Path(macros_dir, "grant_usage_file.py"), + """ +from sqlmesh import macro + +@macro() +def grant_usage_role(evaluator, schemas, role): + if evaluator._environment_naming_info: + return [ + f"GRANT USAGE ON SCHEMA {schema} TO {role};" + for schema in schemas + ] +""", + ) + + context = Context(paths=tmp_path, config=config) + snapshots = {s.name: s for s in context.snapshots.values()} + + environment_statements = context._environment_statements[0] + before_all = environment_statements.before_all + after_all = environment_statements.after_all + python_env = environment_statements.python_env + + assert isinstance(python_env["grant_schema_usage"], Executable) + assert isinstance(python_env["grant_usage_role"], Executable) + assert isinstance(python_env["grant_select_privileges"], Executable) + + before_all_rendered = render_statements( + statements=before_all, + dialect=dialect, + python_env=python_env, + runtime_stage=RuntimeStage.BEFORE_ALL, + ) + + assert before_all_rendered == [ + "CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR)" + ] + + after_all_rendered = render_statements( + statements=after_all, + dialect=dialect, + python_env=python_env, + snapshots=snapshots, + environment_naming_info=EnvironmentNamingInfo(name="prod"), + runtime_stage=RuntimeStage.AFTER_ALL, + ) + + assert sorted(after_all_rendered) == sorted( + [ + "GRANT USAGE ON SCHEMA db TO user_role", + 'GRANT USAGE ON SCHEMA "db" TO "admin"', + "GRANT SELECT ON VIEW memory.db.test_after_model /* sqlglot.meta replace=false */ TO ROLE admin_role", + ] + ) + + after_all_rendered_dev = render_statements( + statements=after_all, + dialect=dialect, + python_env=python_env, + snapshots=snapshots, + environment_naming_info=EnvironmentNamingInfo(name="dev"), + runtime_stage=RuntimeStage.AFTER_ALL, + ) + + assert sorted(after_all_rendered_dev) == sorted( + [ + "GRANT USAGE ON SCHEMA db__dev TO user_role", + 'GRANT USAGE ON SCHEMA "db__dev" TO "admin"', + 
"GRANT SELECT ON VIEW memory.db__dev.test_after_model /* sqlglot.meta replace=false */ TO ROLE admin_role", + ] + ) + + +def test_plan_environment_statements(tmp_path: pathlib.Path): + models_dir = pathlib.Path("models") + macros_dir = pathlib.Path("macros") + dialect = "duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect=dialect), + before_all=["@create_stats_table()", "@access_adapter()"], + after_all=["CREATE TABLE IF NOT EXISTS after_table AS SELECT @some_var"], + variables={"some_var": 5}, + ) + + model_file = """ +MODEL( + name db.test_stats_model, + kind full, +); + +@IF( + @runtime_stage IN ('evaluating', 'creating'), + SET VARIABLE stats_model_start = now() +); + +SELECT 1 AS cola; + +@IF( + @runtime_stage IN ('evaluating', 'creating'), + INSERT INTO analytic_stats (physical_table, evaluation_start, evaluation_end, evaluation_time) + VALUES (@resolve_template('@{schema_name}.@{table_name}'), getvariable('stats_model_start'), now(), now() - getvariable('stats_model_start')) +); + + """ + + create_temp_file( + tmp_path, + pathlib.Path(models_dir, "db", "test_stats_model.sql"), + model_file, + ) + + create_temp_file( + tmp_path, + pathlib.Path(macros_dir, "create_stats_table.py"), + """ +from sqlmesh.core.macros import macro + +@macro() +def create_stats_table(evaluator): + return "CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_start VARCHAR, evaluation_end VARCHAR, evaluation_time VARCHAR)" +""", + ) + + create_temp_file( + tmp_path, + pathlib.Path(macros_dir, "access_adapter.py"), + """ +from sqlmesh.core.macros import macro + +@macro() +def access_adapter(evaluator): + if evaluator.runtime_stage == 'before_all': + engine_adapter = evaluator.engine_adapter + for i in range(10): + try: + sql_inside_macro = f"CREATE TABLE IF NOT EXISTS db_connect AS SELECT {i} as 'access_attempt'" + engine_adapter.execute(sql_inside_macro) + return None + except Exception as e: + sleep(10) + raise Exception(f"Failed to 
connect to the database") + """, + ) + + context = Context(paths=tmp_path, config=config) + + assert context._environment_statements[0].before_all == [ + "@create_stats_table()", + "@access_adapter()", + ] + + assert context._environment_statements[0].after_all == [ + "CREATE TABLE IF NOT EXISTS after_table AS SELECT @some_var" + ] + assert context._environment_statements[0].python_env["create_stats_table"] + + context.plan(auto_apply=True, no_prompts=True) + + model = context.get_model("db.test_stats_model") + snapshot = context.get_snapshot("db.test_stats_model") + assert snapshot and snapshot.version + + assert ( + model.pre_statements[0].sql() + == "@IF(@runtime_stage IN ('evaluating', 'creating'), SET stats_model_start = NOW())" + ) + assert ( + model.post_statements[0].sql() + == "@IF(@runtime_stage IN ('evaluating', 'creating'), INSERT INTO analytic_stats (physical_table, evaluation_start, evaluation_end, evaluation_time) VALUES (@resolve_template('@{schema_name}.@{table_name}'), GETVARIABLE('stats_model_start'), NOW(), NOW() - GETVARIABLE('stats_model_start')))" + ) + + stats_table = context.fetchdf("select * from memory.analytic_stats").to_dict() + assert stats_table.keys() == { + "physical_table", + "evaluation_start", + "evaluation_end", + "evaluation_time", + } + assert ( + stats_table["physical_table"][0] == f"sqlmesh__db.db__test_stats_model__{snapshot.version}" + ) + + assert context.fetchdf("select * from memory.after_table").to_dict()["5"][0] == 5 + + state_table = context.state_reader.get_environment_statements(c.PROD) + assert state_table[0].before_all == context._environment_statements[0].before_all + assert state_table[0].after_all == context._environment_statements[0].after_all + assert state_table[0].python_env == context._environment_statements[0].python_env + + # This table will be created inside the macro by accessing the engine_adapter directly + inside_macro_execute = context.fetchdf("select * from memory.db_connect").to_dict() + assert 
(attempt_column := inside_macro_execute.get("access_attempt")) + assert isinstance(attempt_column, dict) and attempt_column[0] < 10 + + +def test_environment_statements_dialect(tmp_path: Path): + before_all = [ + "EXPORT DATA OPTIONS (URI='gs://path*.csv.gz', FORMAT='CSV') AS SELECT * FROM all_rows" + ] + after_all = ["@IF(@this_env = 'prod', CREATE TABLE IF NOT EXISTS after_t AS SELECT 1)"] + config = Config( + model_defaults=ModelDefaultsConfig(dialect="bigquery"), + before_all=before_all, + after_all=after_all, + ) + ctx = Context(paths=[tmp_path], config=config) + assert ctx._environment_statements == [ + EnvironmentStatements(before_all=before_all, after_all=after_all, python_env={}) + ] + + # Without the correct dialect this statement should error out instead + with pytest.raises(ParseError, match=r"Invalid expression / Unexpected token*"): + config.model_defaults.dialect = "duckdb" + ctx = Context(paths=[tmp_path], config=config) + + +@pytest.mark.slow +@use_terminal_console +def test_model_linting(tmp_path: pathlib.Path, sushi_context) -> None: + def assert_cached_violations_exist(cache: OptimizedQueryCache, model: Model): + model = t.cast(SqlModel, model) + cache_entry = cache._file_cache.get(cache._entry_name(model)) + assert cache_entry is not None + assert cache_entry.optimized_rendered_query is not None + assert cache_entry.renderer_violations is not None + + cfg = LinterConfig(enabled=True, rules="ALL") + ctx = Context( + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"), linter=cfg), + paths=tmp_path, + ) + + config_err = "Linter detected errors in the code. Please fix them before proceeding." 
+ + # Case: Ensure load DOES NOT work if linter is enabled + for query in ("SELECT * FROM tbl", "SELECT t.* FROM tbl"): + with pytest.raises(LinterError, match=config_err): + ctx.upsert_model(load_sql_based_model(d.parse(f"MODEL (name test); {query}"))) + ctx.plan(environment="dev", auto_apply=True, no_prompts=True) + + error_model = load_sql_based_model(d.parse("MODEL (name test); SELECT * FROM (SELECT 1)")) + with pytest.raises(LinterError, match=config_err): + ctx.upsert_model(error_model) + ctx.plan_builder("dev") + + # Case: Ensure error violations are cached if the model did not pass linting + cache = OptimizedQueryCache(ctx.cache_dir) + + assert_cached_violations_exist(cache, error_model) + + # Case: Ensure NoSelectStar only raises for top-level SELECTs, new model shouldn't raise + # and thus should also be cached + model2 = load_sql_based_model( + d.parse( + "MODEL (name test2, audits (at_least_one(column := col))); SELECT col FROM (SELECT * FROM tbl)" + ) + ) + ctx.upsert_model(model2) + + model2 = t.cast(SqlModel, model2) + assert cache._file_cache.exists(cache._entry_name(model2)) + + # Case: Ensure warning violations are found again even if the optimized query is cached + ctx.config.linter = LinterConfig(enabled=True, warn_rules="ALL") + ctx.load() + + for i in range(3): + with patch.object(get_console(), "log_warning") as mock_logger: + if i > 1: + # Model's violations have been cached from the previous upserts + assert_cached_violations_exist(cache, model2) + + ctx.upsert_model(error_model) + ctx.plan(environment="dev", auto_apply=True, no_prompts=True) + + assert ( + """noselectstar - Query should not contain SELECT * on its outer most projections""" + in mock_logger.call_args[0][0] + ) + + # Model's violations have been cached after the former upsert + assert_cached_violations_exist(cache, model2) + + # Case: Ensure load WORKS if linter is enabled but the rules are not + create_temp_file( + tmp_path, + pathlib.Path(pathlib.Path("models"), 
"test.sql"), + "MODEL(name test); SELECT * FROM (SELECT 1 AS col);", + ) + + ignore_or_warn_cfgs = [ + LinterConfig(enabled=True, warn_rules=["noselectstar"]), + LinterConfig(enabled=True, ignored_rules=["noselectstar"]), + ] + for cfg in ignore_or_warn_cfgs: + ctx.config.linter = cfg + ctx.load() + + # Case: Ensure load DOES NOT work if LinterConfig has overlapping rules + with pytest.raises( + ConfigError, + match=r"Rules cannot simultaneously warn and raise an error: \[noselectstar\]", + ): + ctx.config.linter = LinterConfig( + enabled=True, rules=["noselectstar"], warn_rules=["noselectstar"] + ) + ctx.load() + + # Case: Ensure model attribute overrides global config + ctx.config.linter = LinterConfig(enabled=True, rules=["noselectstar"]) + + create_temp_file( + tmp_path, + pathlib.Path(pathlib.Path("models"), "test.sql"), + "MODEL(name test, ignored_rules ['ALL']); SELECT * FROM (SELECT 1 AS col);", + ) + + create_temp_file( + tmp_path, + pathlib.Path(pathlib.Path("models"), "test2.sql"), + "MODEL(name test2, audits (at_least_one(column := col)), ignored_rules ['noselectstar']); SELECT * FROM (SELECT 1 AS col);", + ) + + ctx.plan(environment="dev", auto_apply=True, no_prompts=True) + + # Case: Ensure we can load & use the user-defined rules + sushi_context.config.linter = LinterConfig(enabled=True, rules=["aLl"]) + sushi_context.load() + sushi_context.upsert_model( + load_sql_based_model( + d.parse("MODEL (name sushi.test); SELECT col FROM (SELECT * FROM tbl)"), + default_catalog="memory", + ) + ) + + with pytest.raises(LinterError, match=config_err): + sushi_context.plan_builder(environment="dev") + + # Case: Ensure the Linter also picks up Python model violations + @model(name="memory.sushi.model3", is_sql=True, kind="full", dialect="snowflake") + def model3_entrypoint(evaluator: MacroEvaluator) -> str: + return "select * from model1" + + model3 = model.get_registry()["memory.sushi.model3"].model( + module_path=Path("."), path=Path(".") + ) + + 
@model(name="memory.sushi.model4", columns={"col": "int"}) + def model4_entrypoint(context, **kwargs): + yield pd.DataFrame({"col": []}) + + model4 = model.get_registry()["memory.sushi.model4"].model( + module_path=Path("."), path=Path(".") + ) + + for python_model in (model3, model4): + with pytest.raises(LinterError, match=config_err): + sushi_context.upsert_model(python_model) + sushi_context.plan(environment="dev", auto_apply=True, no_prompts=True) + + +def test_plan_selector_expression_no_match(sushi_context: Context) -> None: + with pytest.raises( + PlanError, + match="Selector did not return any models. Please check your model selection and try again.", + ): + sushi_context.plan("dev", select_models=["*missing*"]) + + with pytest.raises( + PlanError, + match="Selector did not return any models. Please check your model selection and try again.", + ): + sushi_context.plan("dev", backfill_models=["*missing*"]) + + with pytest.raises( + PlanError, + match="Selector did not return any models. 
Please check your model selection and try again.", + ): + sushi_context.plan("prod", restate_models=["*missing*"]) + + +def test_plan_on_virtual_update_this_model_in_macro(tmp_path: pathlib.Path): + models_dir = pathlib.Path("models") + macros_dir = pathlib.Path("macros") + dialect = "duckdb" + + config = Config( + model_defaults=ModelDefaultsConfig(dialect=dialect), + ) + + model_file = """ +MODEL( + name db.test_view_macro_this_model, + kind full, +); + + +SELECT 1 AS cola; + +ON_VIRTUAL_UPDATE_BEGIN; +CREATE OR REPLACE TABLE log_schema AS SELECT @resolve_template('@{schema_name}') as my_schema; +@create_log_view(@this_model); +ON_VIRTUAL_UPDATE_END; + + """ + + create_temp_file( + tmp_path, + pathlib.Path(models_dir, "db", "test_view_macro_this_model.sql"), + model_file, + ) + + create_temp_file( + tmp_path, + pathlib.Path(macros_dir, "create_log_view.py"), + """ +from sqlmesh.core.macros import macro + +@macro() +def create_log_view(evaluator, view_name): + return f"CREATE OR REPLACE TABLE log_view AS SELECT '{view_name}' as fqn_this_model, '{evaluator.this_model}' as evaluator_this_model;" +""", + ) + + context = Context(paths=tmp_path, config=config) + context.plan(environment="dev", auto_apply=True, no_prompts=True) + + model = context.get_model("db.test_view_macro_this_model") + assert ( + model.on_virtual_update[0].sql(dialect=dialect) + == "CREATE OR REPLACE TABLE log_schema AS SELECT @resolve_template('@{schema_name}') AS my_schema" + ) + assert model.on_virtual_update[1].sql(dialect=dialect) == "@create_log_view(@this_model)" + + snapshot = context.get_snapshot("db.test_view_macro_this_model") + assert snapshot and snapshot.version + + log_view = context.fetchdf("select * from log_view").to_dict() + log_schema = context.fetchdf("select * from log_schema").to_dict() + + # Validate that within macro for this_model we resolve to the environment-specific view + assert ( + log_view["fqn_this_model"][0] + == '"db__dev"."test_view_macro_this_model" /* 
memory.db.test_view_macro_this_model */' + ) + + # Validate that from the macro evaluator this_model we get the environment-specific fqn + assert log_view["evaluator_this_model"][0] == '"db__dev"."test_view_macro_this_model"' + + # Validate the schema is retrieved using resolve_template for the environment-specific schema + assert log_schema["my_schema"][0] == "db__dev" + + +def test_plan_audit_intervals(tmp_path: pathlib.Path, caplog): + ctx = Context( + paths=tmp_path, config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + ) + + ctx.upsert_model( + load_sql_based_model( + parse( + """ + MODEL ( + name sqlmesh_audit.date_example, + kind INCREMENTAL_BY_TIME_RANGE( + time_column(date_id, '%Y-%m-%d') + ), + cron '@daily', + partitioned_by (date_id), + audits [unique_combination_of_columns(columns=(date_id))] + ); + + WITH sample_table AS ( + SELECT + DATE('2025-02-01') as date_id, + ) + SELECT date_id FROM sample_table WHERE date_id BETWEEN @start_ds AND @end_ds + """ + ) + ) + ) + + ctx.upsert_model( + load_sql_based_model( + parse( + """ + MODEL ( + name sqlmesh_audit.timestamp_example, + kind INCREMENTAL_BY_TIME_RANGE( + time_column(timestamp_id, '%Y-%m-%d %H:%M:%S') + ), + cron '@daily', + partitioned_by (timestamp_id), + audits [unique_combination_of_columns(columns=(timestamp_id))] + ); + + WITH sample_table AS ( + SELECT + TIMESTAMP('2025-02-01') as timestamp_id, + ) + SELECT timestamp_id FROM sample_table WHERE timestamp_id BETWEEN @start_ts AND @end_ts + """ + ) + ) + ) + + plan = ctx.plan( + environment="dev", auto_apply=True, no_prompts=True, start="2025-02-01", end="2025-02-01" + ) + assert plan.missing_intervals + + date_snapshot = next(s for s in plan.new_snapshots if "date_example" in s.name) + timestamp_snapshot = next(s for s in plan.new_snapshots if "timestamp_example" in s.name) + + # Case 1: The timestamp audit should be in the inclusive range ['2025-02-01 00:00:00', '2025-02-01 23:59:59.999999'] + assert ( + f"""SELECT COUNT(*) 
FROM (SELECT "timestamp_id" AS "timestamp_id" FROM (SELECT * FROM "sqlmesh__sqlmesh_audit"."sqlmesh_audit__timestamp_example__{timestamp_snapshot.version}" AS "sqlmesh_audit__timestamp_example__{timestamp_snapshot.version}" WHERE "timestamp_id" BETWEEN CAST('2025-02-01 00:00:00' AS TIMESTAMP) AND CAST('2025-02-01 23:59:59.999999' AS TIMESTAMP)) AS "_0" WHERE TRUE GROUP BY "timestamp_id" HAVING COUNT(*) > 1) AS "audit\"""" + in caplog.text + ) + + # Case 2: The date audit should be in the inclusive range ['2025-02-01', '2025-02-01'] + assert ( + f"""SELECT COUNT(*) FROM (SELECT "date_id" AS "date_id" FROM (SELECT * FROM "sqlmesh__sqlmesh_audit"."sqlmesh_audit__date_example__{date_snapshot.version}" AS "sqlmesh_audit__date_example__{date_snapshot.version}" WHERE "date_id" BETWEEN CAST('2025-02-01' AS DATE) AND CAST('2025-02-01' AS DATE)) AS "_0" WHERE TRUE GROUP BY "date_id" HAVING COUNT(*) > 1) AS "audit\"""" + in caplog.text + ) + + +def test_check_intervals(sushi_context, mocker): + with pytest.raises( + SQLMeshError, + match="Environment 'dev' was not found", + ): + sushi_context.check_intervals(environment="dev", no_signals=False, select_models=[]) + + spy = mocker.spy(sqlmesh.core.snapshot.definition, "check_ready_intervals") + intervals = sushi_context.check_intervals(environment=None, no_signals=False, select_models=[]) + + min_intervals = 19 + assert spy.call_count == 2 + assert len(intervals) >= min_intervals + + for i in intervals.values(): + assert not i.intervals + + spy.reset_mock() + intervals = sushi_context.check_intervals(environment=None, no_signals=True, select_models=[]) + assert spy.call_count == 0 + assert len(intervals) >= min_intervals + + intervals = sushi_context.check_intervals( + environment=None, no_signals=False, select_models=["*waiter_as_customer*"] + ) + assert len(intervals) == 1 + + intervals = sushi_context.check_intervals( + environment=None, no_signals=False, select_models=["*waiter_as_customer*"], end="next week" + ) + assert 
tuple(intervals.values())[0].intervals + + +def test_audit(): + context = Context(config=Config()) + + parsed_model = parse( + """ + MODEL ( + name dummy, + audits ( + not_null_non_blocking(columns=[c]) + ) + ); + + SELECT NULL AS c + """ + ) + context.upsert_model(load_sql_based_model(parsed_model)) + context.plan(no_prompts=True, auto_apply=True) + + assert context.audit(models=["dummy"], start="2020-01-01", end="2020-01-01") is False + + parsed_model = parse( + """ + MODEL ( + name dummy, + audits ( + not_null_non_blocking(columns=[c]) + ) + ); + + SELECT 1 AS c + """ + ) + context.upsert_model(load_sql_based_model(parsed_model)) + context.plan(no_prompts=True, auto_apply=True) + + assert context.audit(models=["dummy"], start="2020-01-01", end="2020-01-01") is True + + +def test_prompt_if_uncategorized_snapshot(mocker: MockerFixture, tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb") + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + plan=PlanConfig( + auto_categorize_changes=CategorizerConfig( + external=AutoCategorizationMode.OFF, + python=AutoCategorizationMode.OFF, + sql=AutoCategorizationMode.OFF, + seed=AutoCategorizationMode.OFF, + ), + ), + ) + context = Context(paths=tmp_path, config=config) + context.plan(no_prompts=True, auto_apply=True) + + incremental_model = context.get_model("sqlmesh_example.incremental_model") + incremental_model_query = incremental_model.render_query() + new_incremental_model_query = t.cast(exp.Select, incremental_model_query).select("1 AS z") + context.upsert_model( + "sqlmesh_example.incremental_model", + query_=ParsableSql(sql=new_incremental_model_query.sql(dialect=incremental_model.dialect)), + ) + + mock_console = mocker.Mock() + spy_plan = mocker.spy(mock_console, "plan") + context.console = mock_console + + context.plan() + + calls = spy_plan.mock_calls + assert len(calls) == 1 + + # Show that the presence of uncategorized snapshots forces no_prompts to + # False 
instead of respecting the default plan config value, which is True + assert calls[0].kwargs["no_prompts"] == False + assert context.config.plan.no_prompts == True + + +def test_plan_explain_skips_tests(sushi_context: Context, mocker: MockerFixture) -> None: + sushi_context.console = TerminalConsole() + spy = mocker.spy(sushi_context, "_run_plan_tests") + sushi_context.plan(environment="dev", explain=True, no_prompts=True, include_unmodified=True) + spy.assert_called_once_with(skip_tests=True) + + +def test_dev_environment_virtual_update_with_environment_statements(tmp_path: Path) -> None: + models_dir = tmp_path / "models" + models_dir.mkdir() + model_sql = """ + MODEL ( + name db.test_model, + kind FULL + ); + + SELECT 1 as id, 'test' as name + """ + + with open(models_dir / "test_model.sql", "w") as f: + f.write(model_sql) + + # Create initial context without environment statements + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + gateways={"duckdb": GatewayConfig(connection=DuckDBConnectionConfig())}, + ) + + context = Context(paths=tmp_path, config=config) + + # First, apply to production + context.plan("prod", auto_apply=True, no_prompts=True) + + # Try to create dev environment without changes (should fail) + with pytest.raises(NoChangesPlanError, match="Creating a new environment requires a change"): + context.plan("dev", auto_apply=True, no_prompts=True) + + # Now create a new context with only new environment statements + config_with_statements = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + gateways={"duckdb": GatewayConfig(connection=DuckDBConnectionConfig())}, + before_all=["CREATE TABLE IF NOT EXISTS audit_log (id INT, action VARCHAR(100))"], + after_all=["INSERT INTO audit_log VALUES (1, 'environment_created')"], + ) + + context_with_statements = Context(paths=tmp_path, config=config_with_statements) + + # This should succeed because environment statements are different + 
context_with_statements.plan("dev", auto_apply=True, no_prompts=True) + env = context_with_statements.state_reader.get_environment("dev") + assert env is not None + assert env.name == "dev" + + # Verify the environment statements were stored + stored_statements = context_with_statements.state_reader.get_environment_statements("dev") + assert len(stored_statements) == 1 + assert stored_statements[0].before_all == [ + "CREATE TABLE IF NOT EXISTS audit_log (id INT, action VARCHAR(100))" + ] + assert stored_statements[0].after_all == [ + "INSERT INTO audit_log VALUES (1, 'environment_created')" + ] + + # Update environment statements and plan again (should trigger another virtual update) + config_updated_statements = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + gateways={"duckdb": GatewayConfig(connection=DuckDBConnectionConfig())}, + before_all=[ + "CREATE TABLE IF NOT EXISTS audit_log (id INT, action VARCHAR(100))", + "CREATE TABLE IF NOT EXISTS metrics (metric_name VARCHAR(50), value INT)", + ], + after_all=["INSERT INTO audit_log VALUES (1, 'environment_created')"], + ) + + context_updated = Context(paths=tmp_path, config=config_updated_statements) + context_updated.plan("dev", auto_apply=True, no_prompts=True) + + # Verify the updated statements were stored + updated_statements = context_updated.state_reader.get_environment_statements("dev") + assert len(updated_statements) == 1 + assert len(updated_statements[0].before_all) == 2 + assert ( + updated_statements[0].before_all[1] + == "CREATE TABLE IF NOT EXISTS metrics (metric_name VARCHAR(50), value INT)" + ) + + +def test_table_diff_ignores_extra_args(sushi_context: Context): + sushi_context.plan(environment="dev", auto_apply=True, include_unmodified=True) + + # the test fails if this call throws an exception + sushi_context.table_diff( + source="prod", + target="dev", + select_models=["sushi.customers"], + on=["customer_id"], + show_sample=True, + some_tcloud_option=1_000, + ) + + +def 
test_plan_min_intervals(tmp_path: Path): + init_example_project(tmp_path, engine_type="duckdb", dialect="duckdb") + + context = Context( + paths=tmp_path, config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + ) + + current_time = to_datetime("2020-02-01 00:00:01") + + # initial state of example project + context.plan(auto_apply=True, execution_time=current_time) + + (tmp_path / "models" / "daily_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.daily_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column start_dt + ), + start '2020-01-01', + cron '@daily' + ); + + select @start_ds as start_ds, @end_ds as end_ds, @start_dt as start_dt, @end_dt as end_dt; + """) + + (tmp_path / "models" / "weekly_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.weekly_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column start_dt + ), + start '2020-01-01', + cron '@weekly' + ); + + select @start_ds as start_ds, @end_ds as end_ds, @start_dt as start_dt, @end_dt as end_dt; + """) + + (tmp_path / "models" / "monthly_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.monthly_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column start_dt + ), + start '2020-01-01', + cron '@monthly' + ); + + select @start_ds as start_ds, @end_ds as end_ds, @start_dt as start_dt, @end_dt as end_dt; + """) + + (tmp_path / "models" / "ended_daily_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.ended_daily_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column start_dt + ), + start '2020-01-01', + end '2020-01-18', + cron '@daily' + ); + + select @start_ds as start_ds, @end_ds as end_ds, @start_dt as start_dt, @end_dt as end_dt; + """) + + context.load() + + # initial state - backfill from 2020-01-01 -> now() (2020-01-02 00:00:01) on new models + plan = context.plan(execution_time=current_time) + + assert to_datetime(plan.start) == to_datetime("2020-01-01 00:00:00") + assert to_datetime(plan.end) == to_datetime("2020-02-01 00:00:00") + assert 
to_datetime(plan.execution_time) == to_datetime("2020-02-01 00:00:01") + + def _get_missing_intervals(plan: Plan, name: str) -> t.List[t.Tuple[datetime, datetime]]: + snapshot_id = context.get_snapshot(name, raise_if_missing=True).snapshot_id + snapshot_intervals = next( + si for si in plan.missing_intervals if si.snapshot_id == snapshot_id + ) + return [(to_datetime(s), to_datetime(e)) for s, e in snapshot_intervals.merged_intervals] + + # check initial intervals - should be full time range between start and execution time + assert len(plan.missing_intervals) == 4 + + assert _get_missing_intervals(plan, "sqlmesh_example.daily_model") == [ + (to_datetime("2020-01-01 00:00:00"), to_datetime("2020-02-01 00:00:00")) + ] + assert _get_missing_intervals(plan, "sqlmesh_example.weekly_model") == [ + ( + to_datetime("2020-01-01 00:00:00"), + to_datetime("2020-01-26 00:00:00"), + ) # last week in 2020-01 hasnt fully elapsed yet + ] + assert _get_missing_intervals(plan, "sqlmesh_example.monthly_model") == [ + (to_datetime("2020-01-01 00:00:00"), to_datetime("2020-02-01 00:00:00")) + ] + assert _get_missing_intervals(plan, "sqlmesh_example.ended_daily_model") == [ + (to_datetime("2020-01-01 00:00:00"), to_datetime("2020-01-19 00:00:00")) + ] + + # now, create a dev env for "1 day ago" with min_intervals=1 + plan = context.plan( + environment="pr_env", + start="1 day ago", + execution_time=current_time, + min_intervals=1, + ) + + # this should pick up last day for daily model, last week for weekly model, last month for the monthly model and the last day of "ended_daily_model" before it ended + assert len(plan.missing_intervals) == 4 + + assert _get_missing_intervals(plan, "sqlmesh_example.daily_model") == [ + (to_datetime("2020-01-31 00:00:00"), to_datetime("2020-02-01 00:00:00")) + ] + assert _get_missing_intervals(plan, "sqlmesh_example.weekly_model") == [ + ( + to_datetime("2020-01-19 00:00:00"), # last completed week + to_datetime("2020-01-26 00:00:00"), + ) + ] + assert 
_get_missing_intervals(plan, "sqlmesh_example.monthly_model") == [ + ( + to_datetime("2020-01-01 00:00:00"), # last completed month + to_datetime("2020-02-01 00:00:00"), + ) + ] + assert _get_missing_intervals(plan, "sqlmesh_example.ended_daily_model") == [ + ( + to_datetime("2020-01-18 00:00:00"), # last day before the model end date + to_datetime("2020-01-19 00:00:00"), + ) + ] + + # run the plan for '1 day ago' but min_intervals=1 + context.apply(plan) + + # show that the data was created (which shows that when the Plan became an EvaluatablePlan and eventually evaluated, the start date overrides didnt get dropped) + assert context.engine_adapter.fetchall( + "select start_dt, end_dt from sqlmesh_example__pr_env.daily_model" + ) == [(to_datetime("2020-01-31 00:00:00"), to_datetime("2020-01-31 23:59:59.999999"))] + assert context.engine_adapter.fetchall( + "select start_dt, end_dt from sqlmesh_example__pr_env.weekly_model" + ) == [ + (to_datetime("2020-01-19 00:00:00"), to_datetime("2020-01-25 23:59:59.999999")), + ] + assert context.engine_adapter.fetchall( + "select start_dt, end_dt from sqlmesh_example__pr_env.monthly_model" + ) == [ + (to_datetime("2020-01-01 00:00:00"), to_datetime("2020-01-31 23:59:59.999999")), + ] + assert context.engine_adapter.fetchall( + "select start_dt, end_dt from sqlmesh_example__pr_env.ended_daily_model" + ) == [ + (to_datetime("2020-01-18 00:00:00"), to_datetime("2020-01-18 23:59:59.999999")), + ] + + +def test_plan_min_intervals_adjusted_for_downstream(tmp_path: Path): + """ + Scenario: + A(hourly) <- B(daily) <- C(weekly) + <- D(two-hourly) + E(monthly) + + We need to ensure that :min_intervals covers at least :min_intervals of all downstream models for the dag to be valid + In this scenario, if min_intervals=1: + - A would need to cover at least (7 days * 24 hours) because its downstream model C is weekly. 
It should also be unaffected by its sibling, E + - B would need to cover at least 7 days because its downstream model C is weekly + - C would need to cover at least 1 week because min_intervals: 1 + - D would need to cover at least 2 hours because min_intervals: 1 and should be unaffected by C + - E is unrelated to A, B, C and D so would need to cover 1 month satisfy min_intervals: 1. + - It also ensures that each tree branch has a unique cumulative date, because + if the dag is iterated purely in topological order with a global min date it would set A to to 1 month instead if 1 week + """ + + init_example_project(tmp_path, engine_type="duckdb", dialect="duckdb") + + context = Context( + paths=tmp_path, config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + ) + + current_time = to_datetime("2020-02-01 00:00:01") + + # initial state of example project + context.plan(auto_apply=True, execution_time=current_time) + + (tmp_path / "models" / "hourly_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.hourly_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column start_dt, + batch_size 1 + ), + start '2020-01-01', + cron '@hourly' + ); + + select @start_dt as start_dt, @end_dt as end_dt; + """) + + (tmp_path / "models" / "two_hourly_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.two_hourly_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column start_dt + ), + start '2020-01-01', + cron '0 */2 * * *' + ); + + select start_dt, end_dt from sqlmesh_example.hourly_model where start_dt between @start_dt and @end_dt; + """) + + (tmp_path / "models" / "unrelated_monthly_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.unrelated_monthly_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column start_dt + ), + start '2020-01-01', + cron '@monthly' + ); + + select @start_dt as start_dt, @end_dt as end_dt; + """) + + (tmp_path / "models" / "daily_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.daily_model, + kind 
INCREMENTAL_BY_TIME_RANGE ( + time_column start_dt + ), + start '2020-01-01', + cron '@daily' + ); + + select start_dt, end_dt from sqlmesh_example.hourly_model where start_dt between @start_dt and @end_dt; + """) + + (tmp_path / "models" / "weekly_model.sql").write_text(""" + MODEL ( + name sqlmesh_example.weekly_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column start_dt + ), + start '2020-01-01', + cron '@weekly' + ); + + select start_dt, end_dt from sqlmesh_example.daily_model where start_dt between @start_dt and @end_dt; + """) + + context.load() + + # create a dev env for "1 day ago" with min_intervals=1 + # this should force a weeks worth of intervals for every model + plan = context.plan( + environment="pr_env", + start="1 day ago", + execution_time=current_time, + min_intervals=1, + ) + + def _get_missing_intervals(name: str) -> t.List[t.Tuple[datetime, datetime]]: + snapshot_id = context.get_snapshot(name, raise_if_missing=True).snapshot_id + snapshot_intervals = next( + si for si in plan.missing_intervals if si.snapshot_id == snapshot_id + ) + return [(to_datetime(s), to_datetime(e)) for s, e in snapshot_intervals.merged_intervals] + + # We only operate on completed intervals, so given the current_time this is the range of the last completed week + _get_missing_intervals("sqlmesh_example.weekly_model") == [ + (to_datetime("2020-01-19 00:00:00"), to_datetime("2020-01-26 00:00:00")) + ] + + # The daily model needs to cover the week, so it gets its start date moved back to line up + _get_missing_intervals("sqlmesh_example.daily_model") == [ + (to_datetime("2020-01-19 00:00:00"), to_datetime("2020-02-01 00:00:00")) + ] + + # The hourly model needs to cover both the daily model and the weekly model, so it also gets its start date moved back to line up with the weekly model + assert _get_missing_intervals("sqlmesh_example.hourly_model") == [ + (to_datetime("2020-01-19 00:00:00"), to_datetime("2020-02-01 00:00:00")) + ] + + # The two-hourly model only 
needs to cover 2 hours and should be unaffected by the fact its sibling node has a weekly child node + # However it still gets backfilled for 24 hours because the plan start is 1 day and this satisfies min_intervals: 1 + assert _get_missing_intervals("sqlmesh_example.two_hourly_model") == [ + (to_datetime("2020-01-31 00:00:00"), to_datetime("2020-02-01 00:00:00")) + ] + + # The unrelated model has no upstream constraints, so its start date doesnt get moved to line up with the weekly model + # However it still gets backfilled for 24 hours because the plan start is 1 day and this satisfies min_intervals: 1 + _get_missing_intervals("sqlmesh_example.unrelated_monthly_model") == [ + (to_datetime("2020-01-01 00:00:00"), to_datetime("2020-02-01 00:00:00")) + ] + + # Check that actually running the plan produces the correct result, since missing intervals are re-calculated in the evaluator + context.apply(plan) + + assert context.engine_adapter.fetchall( + "select min(start_dt), max(end_dt) from sqlmesh_example__pr_env.weekly_model" + ) == [(to_datetime("2020-01-19 00:00:00"), to_datetime("2020-01-25 23:59:59.999999"))] + + assert context.engine_adapter.fetchall( + "select min(start_dt), max(end_dt) from sqlmesh_example__pr_env.daily_model" + ) == [(to_datetime("2020-01-19 00:00:00"), to_datetime("2020-01-31 23:59:59.999999"))] + + assert context.engine_adapter.fetchall( + "select min(start_dt), max(end_dt) from sqlmesh_example__pr_env.hourly_model" + ) == [(to_datetime("2020-01-19 00:00:00"), to_datetime("2020-01-31 23:59:59.999999"))] + + assert context.engine_adapter.fetchall( + "select min(start_dt), max(end_dt) from sqlmesh_example__pr_env.two_hourly_model" + ) == [(to_datetime("2020-01-31 00:00:00"), to_datetime("2020-01-31 23:59:59.999999"))] + + assert context.engine_adapter.fetchall( + "select min(start_dt), max(end_dt) from sqlmesh_example__pr_env.unrelated_monthly_model" + ) == [(to_datetime("2020-01-01 00:00:00"), to_datetime("2020-01-31 23:59:59.999999"))] + + 
+def test_defaults_pre_post_statements(tmp_path: Path): + config_path = tmp_path / "config.yaml" + models_path = tmp_path / "models" + models_path.mkdir() + + # Create config with default statements + config_path.write_text( + """ +model_defaults: + dialect: duckdb + pre_statements: + - SET memory_limit = '10GB' + - SET threads = @var1 + post_statements: + - ANALYZE @this_model +variables: + var1: 4 +""" + ) + + # Create a model + model_path = models_path / "test_model.sql" + model_path.write_text( + """ +MODEL ( + name test_model, + kind FULL +); + +SELECT 1 as id, 'test' as status; +""" + ) + + ctx = Context(paths=[tmp_path]) + + # Initial plan and apply + initial_plan = ctx.plan(auto_apply=True, no_prompts=True) + assert len(initial_plan.new_snapshots) == 1 + + snapshot = list(initial_plan.new_snapshots)[0] + model = snapshot.model + + # Verify statements are in the model and python environment has been popuplated + assert len(model.pre_statements) == 2 + assert len(model.post_statements) == 1 + assert model.python_env[c.SQLMESH_VARS].payload == "{'var1': 4}" + + # Verify the statements contain the expected SQL + assert model.pre_statements[0].sql() == "SET memory_limit = '10GB'" + assert model.render_pre_statements()[0].sql() == "SET \"memory_limit\" = '10GB'" + assert model.pre_statements[1].sql() == "SET threads = @var1" + assert model.render_pre_statements()[1].sql() == 'SET "threads" = 4' + + # Update config to change pre_statement + config_path.write_text( + """ +model_defaults: + dialect: duckdb + pre_statements: + - SET memory_limit = '5GB' # Changed value + post_statements: + - ANALYZE @this_model +""" + ) + + # Reload context and create new plan + ctx = Context(paths=[tmp_path]) + updated_plan = ctx.plan(no_prompts=True) + + # Should detect a change due to different pre_statements + assert len(updated_plan.directly_modified) == 1 + + # Apply the plan + ctx.apply(updated_plan) + + # Reload the models to get the updated version + ctx.load() + new_model = 
ctx.models['"test_model"'] + + # Verify updated statements + assert len(new_model.pre_statements) == 1 + assert new_model.pre_statements[0].sql() == "SET memory_limit = '5GB'" + assert new_model.render_pre_statements()[0].sql() == "SET \"memory_limit\" = '5GB'" + + # Verify the change was detected by the plan + assert len(updated_plan.directly_modified) == 1 + + +def test_model_defaults_statements_with_on_virtual_update(tmp_path: Path): + config_path = tmp_path / "config.yaml" + models_path = tmp_path / "models" + models_path.mkdir() + + # Create config with on_virtual_update + config_path.write_text( + """ +model_defaults: + dialect: duckdb + on_virtual_update: + - SELECT 'Model-defailt virtual update' AS message +""" + ) + + # Create a model with its own on_virtual_update as wel + model_path = models_path / "test_model.sql" + model_path.write_text( + """ +MODEL ( + name test_model, + kind FULL +); + +SELECT 1 as id, 'test' as name; + +ON_VIRTUAL_UPDATE_BEGIN; +SELECT 'Model-specific update' AS message; +ON_VIRTUAL_UPDATE_END; +""" + ) + + ctx = Context(paths=[tmp_path]) + + # Plan and apply + plan = ctx.plan(auto_apply=True, no_prompts=True) + + snapshot = list(plan.new_snapshots)[0] + model = snapshot.model + + # Verify both default and model-specific on_virtual_update statements + assert len(model.on_virtual_update) == 2 + + # Default statements should come first + assert model.on_virtual_update[0].sql() == "SELECT 'Model-defailt virtual update' AS message" + assert model.on_virtual_update[1].sql() == "SELECT 'Model-specific update' AS message" + + +def test_uppercase_gateway_external_models(tmp_path): + # Create a temporary SQLMesh project with uppercase gateway name + config_py = tmp_path / "config.py" + config_py.write_text(""" +from sqlmesh.core.config import Config, DuckDBConnectionConfig, GatewayConfig, ModelDefaultsConfig + +config = Config( + gateways={ + "UPPERCASE_GATEWAY": GatewayConfig( + connection=DuckDBConnectionConfig(), + ), + }, + 
default_gateway="UPPERCASE_GATEWAY", + model_defaults=ModelDefaultsConfig(dialect="duckdb"), +) +""") + + # Create external models file with lowercase gateway name (this should still match uppercase) + external_models_yaml = tmp_path / "external_models.yaml" + external_models_yaml.write_text(""" +- name: test_db.uppercase_gateway_table + description: Test external model with lowercase gateway name that should match uppercase gateway + gateway: uppercase_gateway # lowercase in external model, but config has UPPERCASE_GATEWAY + columns: + id: int + name: text + +- name: test_db.no_gateway_table + description: Test external model without gateway (should be available for all gateways) + columns: + id: int + name: text +""") + + # Create a model that references the external model + models_dir = tmp_path / "models" + models_dir.mkdir() + model_sql = models_dir / "test_model.sql" + model_sql.write_text(""" +MODEL ( + name test.my_model, + kind FULL, +); + +SELECT * FROM test_db.uppercase_gateway_table; +""") + + # Test with uppercase gateway name - this should find both models + context_uppercase = Context(paths=[tmp_path], gateway="UPPERCASE_GATEWAY") + + # Verify external model with lowercase gateway name in YAML is found when using uppercase gateway + gateway_specific_models = [ + model + for model in context_uppercase.models.values() + if model.name == "test_db.uppercase_gateway_table" + ] + assert len(gateway_specific_models) == 1, ( + f"External model with lowercase gateway name should be found with uppercase gateway. Found {len(gateway_specific_models)} models" + ) + + # Verify external model without gateway is also found + no_gateway_models = [ + model + for model in context_uppercase.models.values() + if model.name == "test_db.no_gateway_table" + ] + assert len(no_gateway_models) == 1, ( + f"External model without gateway should be found. 
Found {len(no_gateway_models)} models" + ) + + # Check that the column types are properly loaded (not UNKNOWN) + external_model = gateway_specific_models[0] + column_types = {name: str(dtype) for name, dtype in external_model.columns_to_types.items()} + assert column_types == { + "id": "INT", + "name": "TEXT", + }, f"External model column types should not be UNKNOWN, got: {column_types}" + + # Test that when using a different case for the gateway parameter, we get the same results + context_mixed_case = Context( + paths=[tmp_path], gateway="uppercase_gateway" + ) # lowercase parameter + + gateway_specific_models_mixed = [ + model + for model in context_mixed_case.models.values() + if model.name == "test_db.uppercase_gateway_table" + ] + # This should work but might fail if case sensitivity is not handled correctly + assert len(gateway_specific_models_mixed) == 1, ( + f"External model should be found regardless of gateway parameter case. Found {len(gateway_specific_models_mixed)} models" + ) + + # Test a case that should demonstrate the potential issue: + # Create another external model file with uppercase gateway name in the YAML + external_models_yaml_uppercase = tmp_path / "external_models_uppercase.yaml" + external_models_yaml_uppercase.write_text(""" +- name: test_db.uppercase_in_yaml + description: Test external model with uppercase gateway name in YAML + gateway: UPPERCASE_GATEWAY # uppercase in external model yaml + columns: + id: int + status: text +""") + + # Add the new external models file to the project + models_dir = tmp_path / "external_models" + models_dir.mkdir(exist_ok=True) + (models_dir / "uppercase_gateway_models.yaml").write_text(""" +- name: test_db.uppercase_in_yaml + description: Test external model with uppercase gateway name in YAML + gateway: UPPERCASE_GATEWAY # uppercase in external model yaml + columns: + id: int + status: text +""") + + # Reload context to pick up the new external models + context_reloaded = Context(paths=[tmp_path], 
gateway="UPPERCASE_GATEWAY") + + uppercase_in_yaml_models = [ + model + for model in context_reloaded.models.values() + if model.name == "test_db.uppercase_in_yaml" + ] + assert len(uppercase_in_yaml_models) == 1, ( + f"External model with uppercase gateway in YAML should be found. Found {len(uppercase_in_yaml_models)} models" + ) + + +def test_plan_no_start_configured(): + context = Context(config=Config()) + context.upsert_model( + load_sql_based_model( + parse( + """ + MODEL( + name db.xvg, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds + ), + cron '@daily' + ); + + SELECT id, ds FROM (VALUES + ('1', '2020-01-01'), + ) data(id, ds) + WHERE ds BETWEEN @start_ds AND @end_ds + """ + ) + ) + ) + + prod_plan = context.plan(auto_apply=True) + assert len(prod_plan.new_snapshots) == 1 + + context.upsert_model( + load_sql_based_model( + parse( + """ + MODEL( + name db.xvg, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds + ), + cron '@daily', + physical_properties ('some_prop' = 1), + ); + + SELECT id, ds FROM (VALUES + ('1', '2020-01-01'), + ) data(id, ds) + WHERE ds BETWEEN @start_ds AND @end_ds + """ + ) + ) + ) + + # This should raise an error because the model has no start configured and the end time is less than the start time which will be calculated from the intervals + with pytest.raises( + PlanError, + match=r"Model '.*xvg.*': Start date / time .* can't be greater than end date / time .*\.\nSet the `start` attribute in your project config model defaults to avoid this issue", + ): + context.plan("dev", execution_time="1999-01-05") + + +def test_lint_model_projections(tmp_path: Path): + init_example_project(tmp_path, engine_type="duckdb", dialect="duckdb") + + context = Context(paths=tmp_path) + context.upsert_model( + load_sql_based_model( + parse("""MODEL(name sqlmesh_example.m); SELECT 1 AS x, 2 AS x"""), + default_catalog="db", + ) + ) + + config_err = "Linter detected errors in the code. Please fix them before proceeding." 
+ + with pytest.raises(LinterError, match=config_err): + prod_plan = context.plan(no_prompts=True, auto_apply=True) + + +def test_grants_through_plan_apply(sushi_context, mocker): + from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter + from sqlmesh.core.model.meta import GrantsTargetLayer + + model = sushi_context.get_model("sushi.waiter_revenue_by_day") + + mocker.patch.object(DuckDBEngineAdapter, "SUPPORTS_GRANTS", True) + sync_grants_mock = mocker.patch.object(DuckDBEngineAdapter, "sync_grants_config") + + model_with_grants = model.copy( + update={ + "grants": {"select": ["analyst", "reporter"]}, + "grants_target_layer": GrantsTargetLayer.ALL, + } + ) + sushi_context.upsert_model(model_with_grants) + + sushi_context.plan("dev", no_prompts=True, auto_apply=True) + + # When planning for dev env w/ metadata only changes, + # only virtual layer is updated, so no physical grants are applied + assert sync_grants_mock.call_count == 1 + assert all( + call[0][1] == {"select": ["analyst", "reporter"]} + for call in sync_grants_mock.call_args_list + ) + + sync_grants_mock.reset_mock() + + new_grants = ({"select": ["analyst", "reporter", "manager"], "insert": ["etl_user"]},) + model_updated = model_with_grants.copy( + update={ + "query": parse_one(model.query.sql() + " LIMIT 1000"), + "grants": new_grants, + # force model update, hence new physical table creation + "stamp": "update model and grants", + } + ) + sushi_context.upsert_model(model_updated) + sushi_context.plan("dev", no_prompts=True, auto_apply=True) + + # Applies grants 2 times: 1 x physical, 1 x virtual + assert sync_grants_mock.call_count == 2 + assert all(call[0][1] == new_grants for call in sync_grants_mock.call_args_list) + + sync_grants_mock.reset_mock() + + # plan for prod + sushi_context.plan(no_prompts=True, auto_apply=True) + assert sync_grants_mock.call_count == 2 diff --git a/tests/core/test_dialect.py b/tests/core/test_dialect.py index 988cd34a26..02068b1c59 100644 --- 
a/tests/core/test_dialect.py +++ b/tests/core/test_dialect.py @@ -12,7 +12,11 @@ select_from_values_for_batch_range, text_diff, ) +import sqlmesh.core.dialect as d from sqlmesh.core.model import SqlModel, load_sql_based_model +from sqlmesh.core.config.connection import DIALECT_TO_TYPE + +pytestmark = pytest.mark.dialect_isolated def test_format_model_expressions(): @@ -26,6 +30,7 @@ def test_format_model_expressions(): a, (b, c) as d, ), -- c + @macro_prop_with_comment(proper := 'foo'), -- k audits [ not_null(columns=[ foo_id, @@ -76,7 +81,12 @@ def test_format_model_expressions(): sum(n + 1)::int as n, -- n o, p + 1, - CAST(x as int)::int, + CAST(x as int)::int,; + +@IF( + @runtime_stage = 'creating', + GRANT SELECT ON foo.bar TO "bla" +) """ ) ) @@ -86,6 +96,7 @@ def test_format_model_expressions(): name a.b, /* a */ kind FULL, /* b */ references (a, (b, c) AS d), /* c */ + @macro_prop_with_comment(proper := 'foo'), /* k */ audits ARRAY( NOT_NULL( columns = ARRAY( @@ -151,7 +162,9 @@ def test_format_model_expressions(): SUM(n + 1)::INT AS n, /* n */ o, p + 1, - x::INT::INT""" + x::INT::INT; + +@IF(@runtime_stage = 'creating', GRANT SELECT ON foo.bar TO "bla")""" ) x = format_model_expressions( @@ -194,9 +207,67 @@ def test_format_model_expressions(): SAFE_CAST('bla' AS INT64) AS FOO""" ) + x = format_model_expressions( + parse( + """ + MODEL(name foo); + SELECT CAST(1 AS INT) AS bla + """ + ), + rewrite_casts=False, + ) + assert ( + x + == """MODEL ( + name foo +); + +SELECT + CAST(1 AS INT) AS bla""" + ) + + x = format_model_expressions( + parse( + """MODEL(name foo); +SELECT CAST(1 AS INT) AS bla; + on_virtual_update_begin; +CREATE OR REPLACE VIEW test_view FROM demo_db.table;GRANT SELECT ON VIEW @this_model TO ROLE owner_name; +JINJA_STATEMENT_BEGIN; GRANT SELECT ON VIEW {{this_model}} TO ROLE admin; JINJA_END; + GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE demo_db TO ROLE owner_name; +@resolve_parent_name('parent');GRANT SELECT ON VIEW demo_db.table 
/* sqlglot.meta replace=false */ TO ROLE admin; +ON_VIRTUAL_update_end;""" + ) + ) + + assert ( + x + == """MODEL ( + name foo +); + +SELECT + 1::INT AS bla; + +ON_VIRTUAL_UPDATE_BEGIN; +CREATE OR REPLACE VIEW test_view AS +SELECT + * +FROM demo_db.table; +GRANT SELECT ON VIEW @this_model TO ROLE owner_name; +JINJA_STATEMENT_BEGIN; +GRANT SELECT ON VIEW {{this_model}} TO ROLE admin; +JINJA_END; +GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE demo_db TO ROLE owner_name; +@resolve_parent_name('parent'); +GRANT SELECT ON VIEW demo_db.table /* sqlglot.meta replace=false */ TO ROLE admin; +ON_VIRTUAL_UPDATE_END;""" + ) + def test_macro_format(): assert parse_one("@EACH(ARRAY(1,2), x -> x)").sql() == "@EACH(ARRAY(1, 2), x -> x)" + assert parse_one("INTERVAL @x DAY").sql() == "INTERVAL @x DAY" + assert parse_one("INTERVAL @'@{bar}' DAY").sql() == "INTERVAL @'@{bar}' DAY" def test_format_body_macros(): @@ -204,17 +275,21 @@ def test_format_body_macros(): format_model_expressions( parse( """ - Model ( name foo ); + Model ( name foo , @macro_dialect(), @properties_macro(prop_1 := 'max', prop_2 := 33)); @WITH(TRUE) x AS (SELECT 1) SELECT col::int - FROM foo @ORDER_BY(@include_order_by) + FROM foo + WHERE @MY_MACRO() /* my macro comment */ +@ORDER_BY(@include_order_by) @EACH( @columns, item -> @'@iteaoeuatnoehutoenahuoanteuhonateuhaoenthuaoentuhaeotnhaoem'), @'@foo' """ ) ) == """MODEL ( - name foo + name foo, + @macro_dialect(), + @properties_macro(prop_1 := 'max', prop_2 := 33) ); @WITH(TRUE) x AS ( @@ -224,6 +299,8 @@ def test_format_body_macros(): SELECT col::INT FROM foo +WHERE + @MY_MACRO() /* my macro comment */ @ORDER_BY(@include_order_by) @EACH(@columns, item -> @'@iteaoeuatnoehutoenahuoanteuhonateuhaoenthuaoentuhaeotnhaoem'), @'@foo'""" @@ -298,7 +375,7 @@ def test_parse(): dialect snowflake ); - SELECT a FROM @if(true, m2, m3) + SELECT a FROM @If(true, m2, m3) """, ) assert len(expressions) == 2 @@ -541,7 +618,7 @@ def test_model_normalization_multiple_serde( 
expressions = parse( f""" MODEL ( - name {table}, + name {exp.maybe_parse(table, into=exp.Table).sql(dialect=normalization_dialect)}, kind INCREMENTAL_BY_TIME_RANGE( time_column ds ), @@ -570,9 +647,6 @@ def test_model_normalization_quote_flexibility(): normalize_model_name('"catalog"."db"."table"', default_catalog=None, dialect="spark") == '"catalog"."db"."table"' ) - # It doesn't work the other way which is what we currently expect - with pytest.raises(ParseError): - normalize_model_name("`catalog`.`db`.`table`", default_catalog=None, dialect=None) def test_macro_parse(): @@ -606,3 +680,55 @@ def test_conditional_statement(): q.sql("snowflake") == "@IF(TRUE, COPY INTO 's3://example/data.csv' FROM EXTRA.EXAMPLE.TABLE STORAGE_INTEGRATION = S3_INTEGRATION FILE_FORMAT = (TYPE=CSV COMPRESSION=NONE NULL_IF=('') FIELD_OPTIONALLY_ENCLOSED_BY='\"') HEADER = TRUE OVERWRITE = TRUE SINGLE = TRUE /* this is a comment */)" ) + + q = parse_one("@IF(cond, VACUUM ANALYZE);", read="postgres") + assert q.sql(dialect="postgres") == "@IF(cond, VACUUM ANALYZE)" + + +def test_model_name_cannot_be_string(): + with pytest.raises(ParseError) as parse_error: + parse( + """ + MODEL( + name 'schema.table', + kind FULL + ); + + SELECT + 1 AS c + """ + ) + + assert "\\'name\\' property cannot be a string value" in str(parse_error) + + +def test_parse_snowflake_create_schema_ddl(): + assert parse_one("CREATE SCHEMA d.s", dialect="snowflake").sql() == "CREATE SCHEMA d.s" + + +@pytest.mark.parametrize("dialect", sorted(set(DIALECT_TO_TYPE.values()))) +def test_sqlglot_extended_correctly(dialect: str) -> None: + # MODEL is a SQLMesh extension and not part of SQLGlot + # If we can roundtrip an expression containing MODEL across every dialect, then the SQLMesh extensions have been registered correctly + ast = d.parse_one("MODEL (name foo)", dialect=dialect) + assert isinstance(ast, d.Model) + name_prop = ast.find(exp.Property) + assert isinstance(name_prop, exp.Property) + assert name_prop.this == 
"name" + value = name_prop.args["value"] + assert isinstance(value, exp.Table) + assert value.sql() == "foo" + assert ast.sql(dialect=dialect) == "MODEL (\nname foo\n)" + + +def test_connected_identifier(): + ast = d.parse_one("""SELECT ("x"at time zone 'utc')::timestamp as x""", "redshift") + assert ast.sql("redshift") == """SELECT CAST(("x" AT TIME ZONE 'utc') AS TIMESTAMP) AS x""" + + +def test_pipe_syntax(): + ast = d.parse_one("SELECT * FROM (FROM t2 |> SELECT id)", "bigquery") + assert ( + ast.sql("bigquery") + == "SELECT * FROM (WITH __tmp1 AS (SELECT id FROM t2) SELECT * FROM __tmp1)" + ) diff --git a/tests/core/test_environment.py b/tests/core/test_environment.py index 42b75b0c30..307f220c25 100644 --- a/tests/core/test_environment.py +++ b/tests/core/test_environment.py @@ -1,6 +1,8 @@ import pytest from sqlmesh.core.environment import Environment, EnvironmentNamingInfo +from sqlmesh.core.snapshot import SnapshotId, SnapshotTableInfo +from sqlmesh.core.state_sync.db.environment import _environment_to_df def test_sanitize_name(): @@ -42,3 +44,38 @@ def test_from_environment_catalog_mapping(mapping, name, expected): ).catalog_name_override == expected ) + + +def test_lazy_loading(sushi_context): + snapshot_ids = [s.snapshot_id for s in sushi_context.snapshots.values()] + snapshot_table_infos = [s.table_info for s in sushi_context.snapshots.values()] + env = Environment( + name="test", + start_at="now", + plan_id="plan_1", + snapshots=snapshot_table_infos, + promoted_snapshot_ids=snapshot_ids, + previous_finalized_snapshots=snapshot_table_infos, + ) + + df = _environment_to_df(env) + row = df.to_dict(orient="records")[0] + env = Environment(**{field: row[field] for field in Environment.all_fields()}) + + assert all(isinstance(snapshot, dict) for snapshot in env.snapshots_) + assert all(isinstance(snapshot, SnapshotTableInfo) for snapshot in env.snapshots) + assert all(isinstance(s_id, dict) for s_id in env.promoted_snapshot_ids_) + assert 
all(isinstance(s_id, SnapshotId) for s_id in env.promoted_snapshot_ids) + assert all(isinstance(snapshot, dict) for snapshot in env.previous_finalized_snapshots_) + assert all( + isinstance(snapshot, SnapshotTableInfo) for snapshot in env.previous_finalized_snapshots + ) + + with pytest.raises(ValueError, match="Must be a list of SnapshotTableInfo dicts or objects"): + Environment(**{**env.dict(), **{"snapshots": [1, 2, 3]}}) + + with pytest.raises(ValueError, match="Must be a list of SnapshotId dicts or objects"): + Environment(**{**env.dict(), **{"promoted_snapshot_ids": [1, 2, 3]}}) + + with pytest.raises(ValueError, match="Must be a list of SnapshotTableInfo dicts or objects"): + Environment(**{**env.dict(), **{"previous_finalized_snapshots": [1, 2, 3]}}) diff --git a/tests/core/test_execution_tracker.py b/tests/core/test_execution_tracker.py new file mode 100644 index 0000000000..0e58395bee --- /dev/null +++ b/tests/core/test_execution_tracker.py @@ -0,0 +1,50 @@ +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor + +from sqlmesh.core.snapshot.execution_tracker import QueryExecutionStats, QueryExecutionTracker +from sqlmesh.core.snapshot import SnapshotIdBatch, SnapshotId + + +def test_execution_tracker_thread_isolation() -> None: + def worker(id: SnapshotId, row_counts: list[int]) -> QueryExecutionStats: + with execution_tracker.track_execution(SnapshotIdBatch(snapshot_id=id, batch_id=0)) as ctx: + assert execution_tracker.is_tracking() + + for count in row_counts: + execution_tracker.record_execution("SELECT 1", count, None) + + assert ctx is not None + return ctx.get_execution_stats() + + execution_tracker = QueryExecutionTracker() + + with ThreadPoolExecutor() as executor: + futures = [ + executor.submit(worker, SnapshotId(name="batch_A", identifier="batch_A"), [10, 5]), + executor.submit(worker, SnapshotId(name="batch_B", identifier="batch_B"), [3, 7]), + ] + results = [f.result() for f in futures] + + # Main thread 
has no active tracking context + assert not execution_tracker.is_tracking() + + # Order of results is not deterministic, so look up by id + by_batch = {s.snapshot_id_batch: s for s in results} + + assert ( + by_batch[ + SnapshotIdBatch( + snapshot_id=SnapshotId(name="batch_A", identifier="batch_A"), batch_id=0 + ) + ].total_rows_processed + == 15 + ) + assert ( + by_batch[ + SnapshotIdBatch( + snapshot_id=SnapshotId(name="batch_B", identifier="batch_B"), batch_id=0 + ) + ].total_rows_processed + == 10 + ) diff --git a/tests/core/test_format.py b/tests/core/test_format.py index 907e0b4ff9..7d544eadf0 100644 --- a/tests/core/test_format.py +++ b/tests/core/test_format.py @@ -1,14 +1,17 @@ import pathlib +from pytest_mock.plugin import MockerFixture from sqlmesh.core.config import Config from sqlmesh.core.context import Context from sqlmesh.core.dialect import parse from sqlmesh.core.audit import ModelAudit from sqlmesh.core.model import SqlModel, load_sql_based_model from tests.utils.test_filesystem import create_temp_file +from unittest.mock import call +from sqlmesh.core.config import ModelDefaultsConfig -def test_format_files(tmp_path: pathlib.Path): +def test_format_files(tmp_path: pathlib.Path, mocker: MockerFixture): models_dir = pathlib.Path("models") audits_dir = pathlib.Path("audits") @@ -25,7 +28,7 @@ def test_format_files(tmp_path: pathlib.Path): f3 = create_temp_file( tmp_path, pathlib.Path(audits_dir, "audit_1.sql"), - "AUDIT(name assert_positive_id, dialect 'duckdb'); SELECT * FROM @this_model WHERE \"CaseSensitive\"_item_id < 0;", + "AUDIT(name assert_positive_id, dialect 'duckdb'); SELECT * FROM @this_model WHERE \"CaseSensitive_item_id\" < 0;", ) f4 = create_temp_file( tmp_path, @@ -35,6 +38,7 @@ def test_format_files(tmp_path: pathlib.Path): config = Config() context = Context(paths=tmp_path, config=config) + context.console = mocker.Mock() context.load() assert isinstance(context.get_model("this.model"), SqlModel) @@ -45,7 +49,7 @@ def 
test_format_files(tmp_path: pathlib.Path): assert context.get_model("other.model").query.sql() == "SELECT 2 AS another_column" # type: ignore assert context.get_model("audit.model").query.sql() == "SELECT 3 AS item_id" # type: ignore assert ( - context.get_model("audit.model").inline_audits["inline_audit"].query.sql() + context.get_model("audit.model").audit_definitions["inline_audit"].query.sql() == "SELECT * FROM @this_model WHERE item_id < 0" ) assert ( @@ -53,9 +57,24 @@ def test_format_files(tmp_path: pathlib.Path): == 'SELECT * FROM @this_model WHERE "CaseSensitive_item_id" < 0' ) + assert not context.format(check=True) + assert all( + c in context.console.log_status_update.mock_calls # type: ignore + for c in [ + call(f"{tmp_path / 'models/model_3.sql'} needs reformatting."), + call(f"{tmp_path / 'models/model_2.sql'} needs reformatting."), + call(f"{tmp_path / 'models/model_1.sql'} needs reformatting."), + call(f"{tmp_path / 'audits/audit_1.sql'} needs reformatting."), + call("\n4 file(s) need reformatting."), + ] + ) + # Transpile project to BigQuery context.format(transpile="bigquery") + # Ensure format check is successful + assert context.format(transpile="bigquery", check=True) + # Ensure transpilation success AND model specific dialect is mutated upd1 = f1.read_text(encoding="utf-8") assert ( @@ -82,3 +101,46 @@ def test_format_files(tmp_path: pathlib.Path): upd4 == "MODEL (\n name audit.model,\n audits (\n inline_audit\n )\n);\n\nSELECT\n 3 AS item_id;\n\nAUDIT (\n name inline_audit\n);\n\nSELECT\n *\nFROM @this_model\nWHERE\n item_id < 0" ) + + +def test_ignore_formating_files(tmp_path: pathlib.Path): + models_dir = pathlib.Path("models") + audits_dir = pathlib.Path("audits") + + # Case 1: Model and Audit are not formatted if the flag is set to false (overriding defaults) + model1_text = "MODEL(name this.model1, dialect 'duckdb', formatting false); SELECT 1 col" + model1 = create_temp_file(tmp_path, pathlib.Path(models_dir, "model_1.sql"), 
model1_text) + + audit1_text = "AUDIT(name audit1, dialect 'duckdb', formatting false); SELECT col1 col2 FROM @this_model WHERE foo < 0;" + audit1 = create_temp_file(tmp_path, pathlib.Path(audits_dir, "audit_1.sql"), audit1_text) + + audit2_text = "AUDIT(name audit2, dialect 'duckdb', standalone true, formatting false); SELECT col1 col2 FROM @this_model WHERE foo < 0;" + audit2 = create_temp_file(tmp_path, pathlib.Path(audits_dir, "audit_2.sql"), audit2_text) + + Context( + paths=tmp_path, config=Config(model_defaults=ModelDefaultsConfig(formatting=True)) + ).format() + + assert model1.read_text(encoding="utf-8") == model1_text + assert audit1.read_text(encoding="utf-8") == audit1_text + assert audit2.read_text(encoding="utf-8") == audit2_text + + # Case 2: Model is formatted (or not) based on it's flag and the defaults flag + model2_text = "MODEL(name this.model2, dialect 'duckdb'); SELECT 1 col" + model2 = create_temp_file(tmp_path, pathlib.Path(models_dir, "model_2.sql"), model2_text) + + model3_text = "MODEL(name this.model3, dialect 'duckdb', formatting true); SELECT 1 col" + model3 = create_temp_file(tmp_path, pathlib.Path(models_dir, "model_3.sql"), model3_text) + + Context( + paths=tmp_path, config=Config(model_defaults=ModelDefaultsConfig(formatting=False)) + ).format() + + # Case 2.1: Model is not formatted if the defaults flag is set to false + assert model2.read_text(encoding="utf-8") == model2_text + + # Case 2.2: Model is formatted if it's flag is set to true, overriding defaults + assert ( + model3.read_text(encoding="utf-8") + == "MODEL (\n name this.model3,\n dialect 'duckdb',\n formatting TRUE\n);\n\nSELECT\n 1 AS col" + ) diff --git a/tests/core/test_integration.py b/tests/core/test_integration.py deleted file mode 100644 index fc7c2417f8..0000000000 --- a/tests/core/test_integration.py +++ /dev/null @@ -1,2420 +0,0 @@ -from __future__ import annotations - -import typing as t -from collections import Counter -from datetime import timedelta - 
-import pandas as pd -import pytest -from pathlib import Path -from freezegun import freeze_time -from pytest_mock.plugin import MockerFixture -from sqlglot import exp -from sqlglot.expressions import DataType - -from sqlmesh import CustomMaterialization -from sqlmesh.cli.example_project import init_example_project -from sqlmesh.core import constants as c -from sqlmesh.core import dialect as d -from sqlmesh.core.config import ( - AutoCategorizationMode, - Config, - GatewayConfig, - ModelDefaultsConfig, - DuckDBConnectionConfig, -) -from sqlmesh.core.console import Console -from sqlmesh.core.context import Context -from sqlmesh.core.engine_adapter import EngineAdapter -from sqlmesh.core.environment import EnvironmentNamingInfo -from sqlmesh.core.model import ( - IncrementalByTimeRangeKind, - IncrementalByUniqueKeyKind, - Model, - ModelKind, - ModelKindName, - SqlModel, - TimeColumn, - load_sql_based_model, -) -from sqlmesh.core.model.kind import model_kind_type_from_name -from sqlmesh.core.plan import Plan, PlanBuilder, SnapshotIntervals -from sqlmesh.core.snapshot import ( - Snapshot, - SnapshotChangeCategory, - SnapshotId, - SnapshotInfoLike, - SnapshotTableInfo, -) -from sqlmesh.utils.date import TimeLike, now, to_date, to_datetime, to_timestamp -from tests.conftest import DuckDBMetadata, SushiDataValidator - - -if t.TYPE_CHECKING: - from sqlmesh import QueryOrDF - -pytestmark = pytest.mark.slow - - -@pytest.fixture(autouse=True) -def mock_choices(mocker: MockerFixture): - mocker.patch("sqlmesh.core.console.TerminalConsole._get_snapshot_change_category") - mocker.patch("sqlmesh.core.console.TerminalConsole._prompt_backfill") - - -def plan_choice(plan_builder: PlanBuilder, choice: SnapshotChangeCategory) -> None: - for snapshot in plan_builder.build().snapshots.values(): - if not snapshot.version: - plan_builder.set_choice(snapshot, choice) - - -@freeze_time("2023-01-08 15:00:00") -@pytest.mark.parametrize( - "context_fixture", - ["sushi_context", 
"sushi_no_default_catalog"], -) -def test_forward_only_plan_with_effective_date(context_fixture: Context, request): - context = request.getfixturevalue(context_fixture) - model_name = "sushi.waiter_revenue_by_day" - model = context.get_model(model_name) - context.upsert_model(add_projection_to_model(t.cast(SqlModel, model)), start="2023-01-01") - snapshot = context.get_snapshot(model, raise_if_missing=True) - top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - - plan_builder = context.plan_builder("dev", skip_tests=True, forward_only=True) - plan = plan_builder.build() - assert len(plan.new_snapshots) == 2 - assert ( - plan.context_diff.snapshots[snapshot.snapshot_id].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - assert ( - plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - assert plan.start == to_date("2023-01-07") - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=top_waiters_snapshot.snapshot_id, - intervals=[(to_timestamp("2023-01-07"), to_timestamp("2023-01-08"))], - ), - SnapshotIntervals( - snapshot_id=snapshot.snapshot_id, - intervals=[(to_timestamp("2023-01-07"), to_timestamp("2023-01-08"))], - ), - ] - - plan = plan_builder.set_effective_from("2023-01-05").build() - # Default start should be set to effective_from - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=top_waiters_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - SnapshotIntervals( - snapshot_id=snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - plan = 
plan_builder.set_start("2023-01-06").build() - # Start override should take precedence - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=top_waiters_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - SnapshotIntervals( - snapshot_id=snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - plan = plan_builder.set_effective_from("2023-01-04").build() - # Start should remain unchanged - assert plan.start == "2023-01-06" - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=top_waiters_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - SnapshotIntervals( - snapshot_id=snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - context.apply(plan) - - dev_df = context.engine_adapter.fetchdf( - "SELECT DISTINCT event_date FROM sushi__dev.waiter_revenue_by_day ORDER BY event_date" - ) - assert dev_df["event_date"].tolist() == [ - pd.to_datetime("2023-01-06"), - pd.to_datetime("2023-01-07"), - ] - - prod_plan = context.plan(no_prompts=True, skip_tests=True) - # Make sure that the previously set effective_from is respected - assert prod_plan.start == to_timestamp("2023-01-04") - assert prod_plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=top_waiters_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - SnapshotIntervals( - 
snapshot_id=snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - context.apply(prod_plan) - - prod_df = context.engine_adapter.fetchdf( - "SELECT DISTINCT event_date FROM sushi.waiter_revenue_by_day WHERE one IS NOT NULL ORDER BY event_date" - ) - assert prod_df["event_date"].tolist() == [ - pd.to_datetime(x) for x in ["2023-01-04", "2023-01-05", "2023-01-06", "2023-01-07"] - ] - - -@freeze_time("2023-01-08 15:00:00") -def test_forward_only_model_regular_plan(init_and_plan_context: t.Callable): - context, plan = init_and_plan_context("examples/sushi") - context.apply(plan) - - model_name = "sushi.waiter_revenue_by_day" - - model = context.get_model(model_name) - model = add_projection_to_model(t.cast(SqlModel, model)) - forward_only_kind = model.kind.copy(update={"forward_only": True}) - model = model.copy(update={"kind": forward_only_kind}) - - context.upsert_model(model) - snapshot = context.get_snapshot(model, raise_if_missing=True) - top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - - plan = context.plan("dev", no_prompts=True, skip_tests=True) - assert len(plan.new_snapshots) == 2 - assert ( - plan.context_diff.snapshots[snapshot.snapshot_id].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - assert ( - plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - assert plan.start == to_datetime("2023-01-01") - assert not plan.missing_intervals - - context.apply(plan) - - dev_df = context.engine_adapter.fetchdf( - "SELECT DISTINCT event_date FROM sushi__dev.waiter_revenue_by_day ORDER BY event_date" - ) - assert not dev_df["event_date"].tolist() - - # Run a restatement plan to preview changes - plan_builder = 
context.plan_builder("dev", skip_tests=True, restate_models=[model_name]) - plan_builder.set_start("2023-01-06") - assert plan_builder.build().missing_intervals == [ - SnapshotIntervals( - snapshot_id=top_waiters_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - SnapshotIntervals( - snapshot_id=snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - # Make sure that changed start is reflected in missing intervals - plan_builder.set_start("2023-01-07") - assert plan_builder.build().missing_intervals == [ - SnapshotIntervals( - snapshot_id=top_waiters_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - SnapshotIntervals( - snapshot_id=snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - context.apply(plan_builder.build()) - - dev_df = context.engine_adapter.fetchdf( - "SELECT DISTINCT event_date FROM sushi__dev.waiter_revenue_by_day ORDER BY event_date" - ) - assert dev_df["event_date"].tolist() == [pd.to_datetime("2023-01-07")] - - # Promote changes to prod - prod_plan = context.plan(no_prompts=True, skip_tests=True) - assert not prod_plan.missing_intervals - - context.apply(prod_plan) - - # The change was applied in a forward-only manner so no values in the new column should be populated - prod_df = context.engine_adapter.fetchdf( - "SELECT DISTINCT event_date FROM sushi.waiter_revenue_by_day WHERE one IS NOT NULL ORDER BY event_date" - ) - assert not prod_df["event_date"].tolist() - - -@freeze_time("2023-01-08 15:00:00") -def test_forward_only_model_regular_plan_preview_enabled(init_and_plan_context: t.Callable): - context, plan = init_and_plan_context("examples/sushi") - context.apply(plan) - - model_name = 
"sushi.waiter_revenue_by_day" - - model = context.get_model(model_name) - model = add_projection_to_model(t.cast(SqlModel, model)) - forward_only_kind = model.kind.copy(update={"forward_only": True}) - model = model.copy(update={"kind": forward_only_kind}) - - context.upsert_model(model) - snapshot = context.get_snapshot(model, raise_if_missing=True) - top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - - plan = context.plan("dev", no_prompts=True, skip_tests=True, enable_preview=True) - assert len(plan.new_snapshots) == 2 - assert ( - plan.context_diff.snapshots[snapshot.snapshot_id].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - assert ( - plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - assert plan.start == to_date("2023-01-07") - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=top_waiters_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - SnapshotIntervals( - snapshot_id=snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - context.apply(plan) - - dev_df = context.engine_adapter.fetchdf( - "SELECT DISTINCT event_date FROM sushi__dev.waiter_revenue_by_day ORDER BY event_date" - ) - assert dev_df["event_date"].tolist() == [pd.to_datetime("2023-01-07")] - - -@freeze_time("2023-01-08 15:00:00") -def test_full_history_restatement_model_regular_plan_preview_enabled( - init_and_plan_context: t.Callable, -): - context, plan = init_and_plan_context("examples/sushi") - context.apply(plan) - - model_name = "sushi.marketing" # SCD2 model - - model = context.get_model(model_name) - model = add_projection_to_model(t.cast(SqlModel, model)) - - context.upsert_model(model) - snapshot = context.get_snapshot(model, raise_if_missing=True) - customers_snapshot = context.get_snapshot("sushi.customers", 
raise_if_missing=True) - active_customers_snapshot = context.get_snapshot( - "sushi.active_customers", raise_if_missing=True - ) - waiter_as_customer_snapshot = context.get_snapshot( - "sushi.waiter_as_customer_by_day", raise_if_missing=True - ) - - plan = context.plan("dev", no_prompts=True, skip_tests=True, enable_preview=True) - - assert len(plan.new_snapshots) == 4 - assert ( - plan.context_diff.snapshots[snapshot.snapshot_id].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - assert ( - plan.context_diff.snapshots[customers_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - assert ( - plan.context_diff.snapshots[active_customers_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - assert ( - plan.context_diff.snapshots[waiter_as_customer_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - - assert plan.start == to_date("2023-01-07") - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=active_customers_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - SnapshotIntervals( - snapshot_id=customers_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - SnapshotIntervals( - snapshot_id=snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - SnapshotIntervals( - snapshot_id=waiter_as_customer_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - context.apply(plan) - - -@freeze_time("2023-01-08 15:00:00") -def test_metadata_changed_regular_plan_preview_enabled(init_and_plan_context: t.Callable): - context, plan = init_and_plan_context("examples/sushi") - context.apply(plan) - - model_name = "sushi.waiter_revenue_by_day" - - model = context.get_model(model_name) - model = model.copy(update={"owner": "new_owner"}) - - 
context.upsert_model(model) - snapshot = context.get_snapshot(model, raise_if_missing=True) - top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - - plan = context.plan("dev", no_prompts=True, skip_tests=True, enable_preview=True) - assert len(plan.new_snapshots) == 2 - assert ( - plan.context_diff.snapshots[snapshot.snapshot_id].change_category - == SnapshotChangeCategory.METADATA - ) - assert ( - plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.METADATA - ) - assert not plan.missing_intervals - assert not plan.restatements - - -@freeze_time("2023-01-08 15:00:00") -def test_hourly_model_with_lookback_no_backfill_in_dev(init_and_plan_context: t.Callable): - context, plan = init_and_plan_context("examples/sushi") - - model_name = "sushi.waiter_revenue_by_day" - - model = context.get_model(model_name) - model = SqlModel.parse_obj( - { - **model.dict(), - "kind": model.kind.copy(update={"lookback": 1}), - "cron": "@hourly", - "audits": [], - } - ) - context.upsert_model(model) - - plan = context.plan("prod", no_prompts=True, skip_tests=True) - context.apply(plan) - - top_waiters_model = context.get_model("sushi.top_waiters") - top_waiters_model = add_projection_to_model(t.cast(SqlModel, top_waiters_model), literal=True) - context.upsert_model(top_waiters_model) - - context.get_snapshot(model, raise_if_missing=True) - top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - - with freeze_time(now() + timedelta(hours=2)): - plan = context.plan("dev", no_prompts=True, skip_tests=True) - # Make sure the waiter_revenue_by_day model is not backfilled. 
- assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=top_waiters_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - -@freeze_time("2023-01-08 00:00:00") -def test_parent_cron_before_child(init_and_plan_context: t.Callable): - context, plan = init_and_plan_context("examples/sushi") - - model = context.get_model("sushi.waiter_revenue_by_day") - model = SqlModel.parse_obj( - { - **model.dict(), - "cron": "50 23 * * *", - } - ) - context.upsert_model(model) - - plan = context.plan("prod", no_prompts=True, skip_tests=True) - context.apply(plan) - - top_waiters_model = context.get_model("sushi.top_waiters") - top_waiters_model = add_projection_to_model(t.cast(SqlModel, top_waiters_model), literal=True) - context.upsert_model(top_waiters_model) - - top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - - with freeze_time("2023-01-08 23:55:00"): # Past parent's cron, but before child's - plan = context.plan("dev", no_prompts=True, skip_tests=True) - # Make sure the waiter_revenue_by_day model is not backfilled. 
- assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=top_waiters_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - -@freeze_time("2023-01-08 15:00:00") -def test_forward_only_parent_created_in_dev_child_created_in_prod( - init_and_plan_context: t.Callable, -): - context, plan = init_and_plan_context("examples/sushi") - context.apply(plan) - - waiter_revenue_by_day_model = context.get_model("sushi.waiter_revenue_by_day") - waiter_revenue_by_day_model = add_projection_to_model( - t.cast(SqlModel, waiter_revenue_by_day_model) - ) - forward_only_kind = waiter_revenue_by_day_model.kind.copy(update={"forward_only": True}) - waiter_revenue_by_day_model = waiter_revenue_by_day_model.copy( - update={"kind": forward_only_kind} - ) - context.upsert_model(waiter_revenue_by_day_model) - - waiter_revenue_by_day_snapshot = context.get_snapshot( - waiter_revenue_by_day_model, raise_if_missing=True - ) - top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - - plan = context.plan("dev", no_prompts=True, skip_tests=True) - assert len(plan.new_snapshots) == 2 - assert ( - plan.context_diff.snapshots[waiter_revenue_by_day_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - assert ( - plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - assert plan.start == to_datetime("2023-01-01") - assert not plan.missing_intervals - - context.apply(plan) - - # Update the child to refer to a newly added column. 
- top_waiters_model = context.get_model("sushi.top_waiters") - top_waiters_model = add_projection_to_model(t.cast(SqlModel, top_waiters_model), literal=False) - context.upsert_model(top_waiters_model) - - top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - - plan = context.plan("prod", no_prompts=True, skip_tests=True) - assert len(plan.new_snapshots) == 1 - assert ( - plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.NON_BREAKING - ) - - context.apply(plan) - - -@freeze_time("2023-01-08 15:00:00") -def test_plan_set_choice_is_reflected_in_missing_intervals(init_and_plan_context: t.Callable): - context, plan = init_and_plan_context("examples/sushi") - context.apply(plan) - - model_name = "sushi.waiter_revenue_by_day" - - model = context.get_model(model_name) - context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) - snapshot = context.get_snapshot(model, raise_if_missing=True) - top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - - plan_builder = context.plan_builder("dev", skip_tests=True) - plan = plan_builder.build() - assert len(plan.new_snapshots) == 2 - assert ( - plan.context_diff.snapshots[snapshot.snapshot_id].change_category - == SnapshotChangeCategory.NON_BREAKING - ) - assert ( - plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.INDIRECT_NON_BREAKING - ) - assert plan.start == to_timestamp("2023-01-01") - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), 
to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - # Change the category to BREAKING - plan = plan_builder.set_choice( - plan.context_diff.snapshots[snapshot.snapshot_id], SnapshotChangeCategory.BREAKING - ).build() - assert ( - plan.context_diff.snapshots[snapshot.snapshot_id].change_category - == SnapshotChangeCategory.BREAKING - ) - assert ( - plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.INDIRECT_BREAKING - ) - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=top_waiters_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - SnapshotIntervals( - snapshot_id=snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - # Change the category back to NON_BREAKING - plan = plan_builder.set_choice( - plan.context_diff.snapshots[snapshot.snapshot_id], SnapshotChangeCategory.NON_BREAKING - ).build() - assert ( - plan.context_diff.snapshots[snapshot.snapshot_id].change_category - == SnapshotChangeCategory.NON_BREAKING - ) - assert ( - plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category - == 
SnapshotChangeCategory.INDIRECT_NON_BREAKING - ) - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - context.apply(plan) - - dev_df = context.engine_adapter.fetchdf( - "SELECT DISTINCT event_date FROM sushi__dev.waiter_revenue_by_day ORDER BY event_date" - ) - assert dev_df["event_date"].tolist() == [ - pd.to_datetime(x) - for x in [ - "2023-01-01", - "2023-01-02", - "2023-01-03", - "2023-01-04", - "2023-01-05", - "2023-01-06", - "2023-01-07", - ] - ] - - # Promote changes to prod - prod_plan = context.plan(no_prompts=True, skip_tests=True) - assert not prod_plan.missing_intervals - - context.apply(prod_plan) - prod_df = context.engine_adapter.fetchdf( - "SELECT DISTINCT event_date FROM sushi.waiter_revenue_by_day WHERE one IS NOT NULL ORDER BY event_date" - ) - assert prod_df["event_date"].tolist() == [ - pd.to_datetime(x) - for x in [ - "2023-01-01", - "2023-01-02", - "2023-01-03", - "2023-01-04", - "2023-01-05", - "2023-01-06", - "2023-01-07", - ] - ] - - -@freeze_time("2023-01-08 15:00:00") -@pytest.mark.parametrize("has_view_binding", [False, True]) -def test_non_breaking_change_after_forward_only_in_dev( - init_and_plan_context: t.Callable, has_view_binding: bool -): - context, plan = init_and_plan_context("examples/sushi") - context.snapshot_evaluator.adapter.HAS_VIEW_BINDING = has_view_binding - context.apply(plan) - - model = context.get_model("sushi.waiter_revenue_by_day") - context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) - 
waiter_revenue_by_day_snapshot = context.get_snapshot( - "sushi.waiter_revenue_by_day", raise_if_missing=True - ) - top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - - plan = context.plan("dev", no_prompts=True, skip_tests=True, forward_only=True) - assert len(plan.new_snapshots) == 2 - assert ( - plan.context_diff.snapshots[waiter_revenue_by_day_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - assert ( - plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - assert plan.start == pd.to_datetime("2023-01-07") - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=top_waiters_snapshot.snapshot_id, - intervals=[(to_timestamp("2023-01-07"), to_timestamp("2023-01-08"))], - ), - SnapshotIntervals( - snapshot_id=waiter_revenue_by_day_snapshot.snapshot_id, - intervals=[(to_timestamp("2023-01-07"), to_timestamp("2023-01-08"))], - ), - ] - - # Apply the forward-only changes first. - context.apply(plan) - - dev_df = context.engine_adapter.fetchdf( - "SELECT DISTINCT event_date FROM sushi__dev.waiter_revenue_by_day ORDER BY event_date" - ) - assert dev_df["event_date"].tolist() == [pd.to_datetime("2023-01-07")] - - # FIXME: Due to freezgun freezing the time, all interval records have the same creation timestamp. - # As a result removal records are always being applied after any addition records. Running the plan repeatedly - # to make sure there are no missing intervals. - context._run_janitor() - context.plan("dev", no_prompts=True, skip_tests=True, auto_apply=True) - - # Make a non-breaking change to a model downstream. - model = context.get_model("sushi.top_waiters") - # Select 'one' column from the updated upstream model. 
- context.upsert_model(add_projection_to_model(t.cast(SqlModel, model), literal=False)) - top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - - plan = context.plan("dev", no_prompts=True, skip_tests=True) - assert len(plan.new_snapshots) == 1 - assert ( - plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.NON_BREAKING - ) - assert plan.start == to_timestamp("2023-01-01") - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=top_waiters_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - # Apply the non-breaking changes. - context.apply(plan) - - dev_df = context.engine_adapter.fetchdf( - "SELECT DISTINCT waiter_id FROM sushi__dev.top_waiters WHERE one IS NOT NULL" - ) - assert not dev_df.empty - - prod_df = context.engine_adapter.fetchdf("DESCRIBE sushi.top_waiters") - assert "one" not in prod_df["column_name"].tolist() - - # Deploy both changes to prod. 
- plan = context.plan("prod", no_prompts=True, skip_tests=True) - assert plan.start == to_timestamp("2023-01-01") - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=top_waiters_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - context.apply(plan) - - prod_df = context.engine_adapter.fetchdf( - "SELECT DISTINCT event_date FROM sushi.waiter_revenue_by_day WHERE one IS NOT NULL ORDER BY event_date" - ) - assert prod_df.empty - - prod_df = context.engine_adapter.fetchdf( - "SELECT DISTINCT waiter_id FROM sushi.top_waiters WHERE one IS NOT NULL" - ) - assert prod_df.empty - - -@freeze_time("2023-01-08 15:00:00") -def test_indirect_non_breaking_change_after_forward_only_in_dev(init_and_plan_context: t.Callable): - context, plan = init_and_plan_context("examples/sushi") - context.apply(plan) - - # Make sushi.orders a forward-only model. - model = context.get_model("sushi.orders") - updated_model_kind = model.kind.copy(update={"forward_only": True}) - model = model.copy(update={"stamp": "force new version", "kind": updated_model_kind}) - context.upsert_model(model) - snapshot = context.get_snapshot(model, raise_if_missing=True) - - plan = context.plan("dev", no_prompts=True, skip_tests=True) - assert ( - plan.context_diff.snapshots[snapshot.snapshot_id].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - assert not plan.requires_backfill - context.apply(plan) - - # Make a non-breaking change to a model. 
- model = context.get_model("sushi.top_waiters") - context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) - top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - - plan = context.plan("dev", no_prompts=True, skip_tests=True) - assert len(plan.new_snapshots) == 1 - assert ( - plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.NON_BREAKING - ) - assert plan.start == to_timestamp("2023-01-01") - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=top_waiters_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - # Apply the non-breaking changes. - context.apply(plan) - - # Make a non-breaking change upstream from the previously modified model. 
- model = context.get_model("sushi.waiter_revenue_by_day") - context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) - waiter_revenue_by_day_snapshot = context.get_snapshot( - "sushi.waiter_revenue_by_day", raise_if_missing=True - ) - top_waiters_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - - plan = context.plan("dev", no_prompts=True, skip_tests=True) - assert len(plan.new_snapshots) == 2 - assert ( - plan.context_diff.snapshots[waiter_revenue_by_day_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.NON_BREAKING - ) - assert ( - plan.context_diff.snapshots[top_waiters_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.INDIRECT_NON_BREAKING - ) - assert plan.start == to_timestamp("2023-01-01") - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=waiter_revenue_by_day_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - # Apply the upstream non-breaking changes. - context.apply(plan) - assert not context.plan("dev", no_prompts=True, skip_tests=True).requires_backfill - - # Deploy everything to prod. 
- plan = context.plan("prod", no_prompts=True, skip_tests=True) - assert plan.start == to_timestamp("2023-01-01") - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=top_waiters_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - SnapshotIntervals( - snapshot_id=waiter_revenue_by_day_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - context.apply(plan) - assert not context.plan("prod", no_prompts=True, skip_tests=True).requires_backfill - - -@freeze_time("2023-01-08 15:00:00") -def test_forward_only_precedence_over_indirect_non_breaking(init_and_plan_context: t.Callable): - context, plan = init_and_plan_context("examples/sushi") - context.apply(plan) - - # Make sushi.orders a forward-only model. 
- forward_only_model = context.get_model("sushi.orders") - updated_model_kind = forward_only_model.kind.copy(update={"forward_only": True}) - forward_only_model = forward_only_model.copy( - update={"stamp": "force new version", "kind": updated_model_kind} - ) - context.upsert_model(forward_only_model) - forward_only_snapshot = context.get_snapshot(forward_only_model, raise_if_missing=True) - - non_breaking_model = context.get_model("sushi.waiter_revenue_by_day") - context.upsert_model(add_projection_to_model(t.cast(SqlModel, non_breaking_model))) - non_breaking_snapshot = context.get_snapshot(non_breaking_model, raise_if_missing=True) - top_waiter_snapshot = context.get_snapshot("sushi.top_waiters", raise_if_missing=True) - - plan = context.plan("dev", no_prompts=True, skip_tests=True) - assert ( - plan.context_diff.snapshots[forward_only_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - assert ( - plan.context_diff.snapshots[non_breaking_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.NON_BREAKING - ) - assert ( - plan.context_diff.snapshots[top_waiter_snapshot.snapshot_id].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - assert plan.start == to_timestamp("2023-01-01") - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=non_breaking_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - context.apply(plan) - assert not context.plan("dev", no_prompts=True, skip_tests=True).requires_backfill - - # Deploy everything to prod. 
- plan = context.plan("prod", no_prompts=True, skip_tests=True) - assert plan.start == to_timestamp("2023-01-01") - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=non_breaking_snapshot.snapshot_id, - intervals=[ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ], - ), - ] - - context.apply(plan) - assert not context.plan("prod", no_prompts=True, skip_tests=True).requires_backfill - - -@freeze_time("2023-01-08 15:00:00") -def test_select_models(init_and_plan_context: t.Callable): - context, plan = init_and_plan_context("examples/sushi") - context.apply(plan) - - # Modify 2 models. - model = context.get_model("sushi.waiter_revenue_by_day") - kwargs = { - **model.dict(), - # Make a breaking change. - "query": model.query.order_by("waiter_id"), # type: ignore - } - context.upsert_model(SqlModel.parse_obj(kwargs)) - - model = context.get_model("sushi.customer_revenue_by_day") - context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) - - expected_intervals = [ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ] - - waiter_revenue_by_day_snapshot_id = context.get_snapshot( - "sushi.waiter_revenue_by_day", raise_if_missing=True - ).snapshot_id - - # Select one of the modified models. 
- plan_builder = context.plan_builder( - "dev", select_models=["+*waiter_revenue_by_day"], skip_tests=True - ) - snapshot = plan_builder._context_diff.snapshots[waiter_revenue_by_day_snapshot_id] - plan_builder.set_choice(snapshot, SnapshotChangeCategory.BREAKING) - plan = plan_builder.build() - - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=waiter_revenue_by_day_snapshot_id, - intervals=expected_intervals, - ), - ] - - context.apply(plan) - - dev_df = context.engine_adapter.fetchdf( - "SELECT DISTINCT event_date FROM sushi__dev.waiter_revenue_by_day ORDER BY event_date" - ) - assert len(dev_df) == 7 - - # Make sure that we only create a view for the selected model. - schema_objects = context.engine_adapter.get_data_objects("sushi__dev") - assert len(schema_objects) == 1 - assert schema_objects[0].name == "waiter_revenue_by_day" - - # Validate the other modified model. - assert not context.get_snapshot("sushi.customer_revenue_by_day").change_category - assert not context.get_snapshot("sushi.customer_revenue_by_day").version - - # Validate the downstream model. - assert not context.engine_adapter.table_exists( - context.get_snapshot("sushi.top_waiters").table_name() - ) - assert not context.engine_adapter.table_exists( - context.get_snapshot("sushi.top_waiters").table_name(False) - ) - - # Make sure that tables are created when deploying to prod. 
- plan = context.plan("prod", skip_tests=True) - context.apply(plan) - assert context.engine_adapter.table_exists( - context.get_snapshot("sushi.top_waiters").table_name() - ) - assert context.engine_adapter.table_exists( - context.get_snapshot("sushi.top_waiters").table_name(False) - ) - - -@freeze_time("2023-01-08 15:00:00") -def test_select_models_for_backfill(init_and_plan_context: t.Callable): - context, _ = init_and_plan_context("examples/sushi") - - expected_intervals = [ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), - (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), - ] - - plan = context.plan( - "dev", backfill_models=["*waiter_revenue_by_day"], no_prompts=True, skip_tests=True - ) - - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=context.get_snapshot("sushi.items", raise_if_missing=True).snapshot_id, - intervals=expected_intervals, - ), - SnapshotIntervals( - snapshot_id=context.get_snapshot( - "sushi.order_items", raise_if_missing=True - ).snapshot_id, - intervals=expected_intervals, - ), - SnapshotIntervals( - snapshot_id=context.get_snapshot("sushi.orders", raise_if_missing=True).snapshot_id, - intervals=expected_intervals, - ), - SnapshotIntervals( - snapshot_id=context.get_snapshot( - "sushi.waiter_revenue_by_day", raise_if_missing=True - ).snapshot_id, - intervals=expected_intervals, - ), - ] - - context.apply(plan) - - dev_df = context.engine_adapter.fetchdf( - "SELECT DISTINCT event_date FROM sushi__dev.waiter_revenue_by_day ORDER BY event_date" - ) - assert len(dev_df) == 7 - - schema_objects = context.engine_adapter.get_data_objects("sushi__dev") - assert {o.name for o in schema_objects} == { - "items", - 
"order_items", - "orders", - "waiter_revenue_by_day", - } - - assert not context.engine_adapter.table_exists( - context.get_snapshot("sushi.customer_revenue_by_day").table_name() - ) - - # Make sure that tables are created when deploying to prod. - plan = context.plan("prod") - context.apply(plan) - assert context.engine_adapter.table_exists( - context.get_snapshot("sushi.customer_revenue_by_day").table_name() - ) - - -@freeze_time("2023-01-08 15:00:00") -def test_dbt_select_star_is_directly_modified(sushi_test_dbt_context: Context): - context = sushi_test_dbt_context - - model = context.get_model("sushi.simple_model_a") - context.upsert_model( - SqlModel.parse_obj( - { - **model.dict(), - "query": d.parse_one("SELECT 1 AS a, 2 AS b"), - } - ) - ) - - snapshot_a_id = context.get_snapshot("sushi.simple_model_a").snapshot_id # type: ignore - snapshot_b_id = context.get_snapshot("sushi.simple_model_b").snapshot_id # type: ignore - - plan = context.plan_builder("dev", skip_tests=True).build() - assert plan.directly_modified == {snapshot_a_id, snapshot_b_id} - assert {i.snapshot_id for i in plan.missing_intervals} == {snapshot_a_id, snapshot_b_id} - - assert plan.snapshots[snapshot_a_id].change_category == SnapshotChangeCategory.NON_BREAKING - assert plan.snapshots[snapshot_b_id].change_category == SnapshotChangeCategory.NON_BREAKING - - -@freeze_time("2023-01-08 15:00:00") -def test_incremental_by_partition(init_and_plan_context: t.Callable): - context, plan = init_and_plan_context("examples/sushi") - context.apply(plan) - - source_name = "raw.test_incremental_by_partition" - model_name = "memory.sushi.test_incremental_by_partition" - - expressions = d.parse( - f""" - MODEL ( - name {model_name}, - kind INCREMENTAL_BY_PARTITION, - partitioned_by [key], - allow_partials true, - ); - - SELECT key, value FROM {source_name}; - """ - ) - model = load_sql_based_model(expressions) - context.upsert_model(model) - - context.engine_adapter.ctas( - source_name, - 
d.parse_one("SELECT 'key_a' AS key, 1 AS value"), - ) - - context.plan(auto_apply=True, no_prompts=True) - assert context.engine_adapter.fetchall(f"SELECT * FROM {model_name}") == [ - ("key_a", 1), - ] - - context.engine_adapter.replace_query( - source_name, - d.parse_one("SELECT 'key_b' AS key, 1 AS value"), - ) - context.run(ignore_cron=True) - assert context.engine_adapter.fetchall(f"SELECT * FROM {model_name}") == [ - ("key_a", 1), - ("key_b", 1), - ] - - context.engine_adapter.replace_query( - source_name, - d.parse_one("SELECT 'key_a' AS key, 2 AS value"), - ) - context.run(ignore_cron=True) - assert context.engine_adapter.fetchall(f"SELECT * FROM {model_name}") == [ - ("key_b", 1), - ("key_a", 2), - ] - - -@freeze_time("2023-01-08 15:00:00") -def test_custom_materialization(init_and_plan_context: t.Callable): - context, _ = init_and_plan_context("examples/sushi") - - custom_insert_called = False - - class CustomFullMaterialization(CustomMaterialization): - NAME = "test_custom_full" - - def insert( - self, - table_name: str, - query_or_df: QueryOrDF, - model: Model, - is_first_insert: bool, - **kwargs: t.Any, - ) -> None: - nonlocal custom_insert_called - custom_insert_called = True - - self._replace_query_for_model(model, table_name, query_or_df) - - model = context.get_model("sushi.top_waiters") - kwargs = { - **model.dict(), - # Make a breaking change. 
- "kind": dict(name="CUSTOM", materialization="test_custom_full"), - } - context.upsert_model(SqlModel.parse_obj(kwargs)) - - context.plan(auto_apply=True, no_prompts=True) - - assert custom_insert_called - - -@freeze_time("2023-01-08 15:00:00") -def test_ignored_snapshot_with_non_deployable_downstream(init_and_plan_context: t.Callable): - context, _ = init_and_plan_context("examples/sushi") - - downstream_model_name = "memory.sushi.customer_max_revenue" - - expressions = d.parse( - f""" - MODEL ( - name {downstream_model_name}, - kind INCREMENTAL_BY_UNIQUE_KEY ( - unique_key customer_id, - forward_only true, - ), - ); - - SELECT - customer_id, MAX(revenue) AS max_revenue - FROM memory.sushi.customer_revenue_lifetime - GROUP BY 1; - """ - ) - - downstream_model = load_sql_based_model(expressions) - assert downstream_model.forward_only - context.upsert_model(downstream_model) - - context.plan(auto_apply=True, no_prompts=True) - - customer_revenue_lifetime_model = context.get_model("sushi.customer_revenue_lifetime") - kwargs = { - **customer_revenue_lifetime_model.dict(), - "name": "memory.sushi.customer_revenue_lifetime_new", - "kind": dict( - name="INCREMENTAL_UNMANAGED" - ), # Make it incremental unmanaged to ensure the depends_on_past behavior. 
- } - context.upsert_model(SqlModel.parse_obj(kwargs)) - context.upsert_model( - downstream_model_name, - query=d.parse_one( - "SELECT customer_id, MAX(revenue) AS max_revenue FROM memory.sushi.customer_revenue_lifetime_new GROUP BY 1" - ), - ) - - plan = context.plan("dev", no_prompts=True, enable_preview=True) - assert {s.name for s in plan.ignored} == { - '"memory"."sushi"."customer_revenue_lifetime_new"', - '"memory"."sushi"."customer_max_revenue"', - } - assert not plan.new_snapshots - assert not plan.missing_intervals - - -@freeze_time("2023-01-08 15:00:00") -def test_restatement_plan_ignores_changes(init_and_plan_context: t.Callable): - context, plan = init_and_plan_context("examples/sushi") - context.apply(plan) - - restated_snapshot = context.get_snapshot("sushi.top_waiters") - - # Simulate a change. - model = context.get_model("sushi.waiter_revenue_by_day") - context.upsert_model(add_projection_to_model(t.cast(SqlModel, model))) - - plan = context.plan(no_prompts=True, restate_models=["sushi.top_waiters"], start="2023-01-07") - assert plan.snapshots != context.snapshots - - assert not plan.directly_modified - assert not plan.has_changes - assert not plan.new_snapshots - assert plan.requires_backfill - assert plan.restatements == { - restated_snapshot.snapshot_id: (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")) - } - assert plan.missing_intervals == [ - SnapshotIntervals( - snapshot_id=restated_snapshot.snapshot_id, - intervals=[(to_timestamp("2023-01-07"), to_timestamp("2023-01-08"))], - ) - ] - - context.apply(plan) - - -def test_plan_twice_with_star_macro_yields_no_diff(tmp_path: Path): - init_example_project(tmp_path, dialect="duckdb") - - star_model_definition = """ - MODEL ( - name sqlmesh_example.star_model, - kind FULL - ); - - SELECT @STAR(sqlmesh_example.full_model) FROM sqlmesh_example.full_model - """ - - star_model_path = tmp_path / "models" / "star_model.sql" - star_model_path.write_text(star_model_definition) - - db_path = 
str(tmp_path / "db.db") - config = Config( - gateways={"main": GatewayConfig(connection=DuckDBConnectionConfig(database=db_path))}, - model_defaults=ModelDefaultsConfig(dialect="duckdb"), - ) - context = Context(paths=tmp_path, config=config) - context.plan(auto_apply=True, no_prompts=True) - - # Instantiate new context to remove caches etc - new_context = Context(paths=tmp_path, config=config) - - star_model = new_context.get_model("sqlmesh_example.star_model") - assert ( - star_model.render_query_or_raise().sql() - == 'SELECT CAST("full_model"."item_id" AS INT) AS "item_id", CAST("full_model"."num_orders" AS BIGINT) AS "num_orders" FROM "db"."sqlmesh_example"."full_model" AS "full_model"' - ) - - new_plan = new_context.plan(no_prompts=True) - assert not new_plan.has_changes - assert not new_plan.new_snapshots - - -@pytest.mark.parametrize( - "context_fixture", - ["sushi_context", "sushi_dbt_context", "sushi_test_dbt_context", "sushi_no_default_catalog"], -) -def test_model_add(context_fixture: Context, request): - initial_add(request.getfixturevalue(context_fixture), "dev") - - -def test_model_removed(sushi_context: Context): - environment = "dev" - initial_add(sushi_context, environment) - - top_waiters_snapshot_id = sushi_context.get_snapshot( - "sushi.top_waiters", raise_if_missing=True - ).snapshot_id - - sushi_context._models.pop('"memory"."sushi"."top_waiters"') - - def _validate_plan(context, plan): - validate_plan_changes(plan, removed=[top_waiters_snapshot_id]) - assert not plan.missing_intervals - - def _validate_apply(context): - assert not sushi_context.get_snapshot("sushi.top_waiters", raise_if_missing=False) - assert sushi_context.state_reader.get_snapshots([top_waiters_snapshot_id]) - env = sushi_context.state_reader.get_environment(environment) - assert env - assert all(snapshot.name != '"memory"."sushi"."top_waiters"' for snapshot in env.snapshots) - - apply_to_environment( - sushi_context, - environment, - SnapshotChangeCategory.BREAKING, - 
plan_validators=[_validate_plan], - apply_validators=[_validate_apply], - ) - - -def test_non_breaking_change(sushi_context: Context): - environment = "dev" - initial_add(sushi_context, environment) - validate_query_change(sushi_context, environment, SnapshotChangeCategory.NON_BREAKING, False) - - -def test_breaking_change(sushi_context: Context): - environment = "dev" - initial_add(sushi_context, environment) - validate_query_change(sushi_context, environment, SnapshotChangeCategory.BREAKING, False) - - -def test_forward_only(sushi_context: Context): - environment = "dev" - initial_add(sushi_context, environment) - validate_query_change(sushi_context, environment, SnapshotChangeCategory.FORWARD_ONLY, False) - - -def test_logical_change(sushi_context: Context): - environment = "dev" - initial_add(sushi_context, environment) - previous_sushi_items_version = sushi_context.get_snapshot( - "sushi.items", raise_if_missing=True - ).version - - change_data_type( - sushi_context, - "sushi.items", - DataType.Type.DOUBLE, - DataType.Type.FLOAT, - ) - apply_to_environment(sushi_context, environment, SnapshotChangeCategory.NON_BREAKING) - - change_data_type( - sushi_context, - "sushi.items", - DataType.Type.FLOAT, - DataType.Type.DOUBLE, - ) - apply_to_environment(sushi_context, environment, SnapshotChangeCategory.NON_BREAKING) - - assert ( - sushi_context.get_snapshot("sushi.items", raise_if_missing=True).version - == previous_sushi_items_version - ) - - -def validate_query_change( - context: Context, - environment: str, - change_category: SnapshotChangeCategory, - logical: bool, -): - versions = snapshots_to_versions(context.snapshots.values()) - - change_data_type( - context, - "sushi.items", - DataType.Type.DOUBLE, - DataType.Type.FLOAT, - ) - - directly_modified = ['"memory"."sushi"."items"'] - indirectly_modified = [ - '"memory"."sushi"."order_items"', - '"memory"."sushi"."waiter_revenue_by_day"', - '"memory"."sushi"."customer_revenue_by_day"', - 
'"memory"."sushi"."customer_revenue_lifetime"', - '"memory"."sushi"."top_waiters"', - "assert_item_price_above_zero", - ] - not_modified = [ - snapshot.name - for snapshot in context.snapshots.values() - if snapshot.name not in directly_modified and snapshot.name not in indirectly_modified - ] - - if change_category == SnapshotChangeCategory.BREAKING and not logical: - models_same = not_modified - models_different = directly_modified + indirectly_modified - elif change_category == SnapshotChangeCategory.FORWARD_ONLY: - models_same = not_modified + directly_modified + indirectly_modified - models_different = [] - else: - models_same = not_modified + indirectly_modified - models_different = directly_modified - - def _validate_plan(context, plan): - validate_plan_changes(plan, modified=directly_modified + indirectly_modified) - assert bool(plan.missing_intervals) != logical - - def _validate_apply(context): - current_versions = snapshots_to_versions(context.snapshots.values()) - validate_versions_same(models_same, versions, current_versions) - validate_versions_different(models_different, versions, current_versions) - - apply_to_environment( - context, - environment, - change_category, - plan_validators=[_validate_plan], - apply_validators=[_validate_apply], - ) - - -@pytest.mark.parametrize( - "from_, to", - [ - (ModelKindName.INCREMENTAL_BY_TIME_RANGE, ModelKindName.VIEW), - (ModelKindName.INCREMENTAL_BY_TIME_RANGE, ModelKindName.EMBEDDED), - (ModelKindName.INCREMENTAL_BY_TIME_RANGE, ModelKindName.FULL), - (ModelKindName.VIEW, ModelKindName.EMBEDDED), - (ModelKindName.VIEW, ModelKindName.FULL), - (ModelKindName.VIEW, ModelKindName.INCREMENTAL_BY_TIME_RANGE), - (ModelKindName.EMBEDDED, ModelKindName.VIEW), - (ModelKindName.EMBEDDED, ModelKindName.FULL), - (ModelKindName.EMBEDDED, ModelKindName.INCREMENTAL_BY_TIME_RANGE), - (ModelKindName.FULL, ModelKindName.VIEW), - (ModelKindName.FULL, ModelKindName.EMBEDDED), - (ModelKindName.FULL, 
ModelKindName.INCREMENTAL_BY_TIME_RANGE), - ], -) -def test_model_kind_change(from_: ModelKindName, to: ModelKindName, sushi_context: Context): - environment = f"test_model_kind_change__{from_.value.lower()}__{to.value.lower()}" - incremental_snapshot = sushi_context.get_snapshot("sushi.items", raise_if_missing=True).copy() - - if from_ != ModelKindName.INCREMENTAL_BY_TIME_RANGE: - change_model_kind(sushi_context, from_) - apply_to_environment(sushi_context, environment, SnapshotChangeCategory.NON_BREAKING) - - if to == ModelKindName.INCREMENTAL_BY_TIME_RANGE: - sushi_context.upsert_model(incremental_snapshot.model) - else: - change_model_kind(sushi_context, to) - - logical = to in (ModelKindName.INCREMENTAL_BY_TIME_RANGE, ModelKindName.EMBEDDED) - validate_model_kind_change(to, sushi_context, environment, logical=logical) - - -def change_model_kind(context: Context, kind: ModelKindName): - if kind in (ModelKindName.VIEW, ModelKindName.EMBEDDED, ModelKindName.FULL): - context.upsert_model( - "sushi.items", - partitioned_by=[], - audits=[], - ) - context.upsert_model("sushi.items", kind=model_kind_type_from_name(kind)()) # type: ignore - - -def validate_model_kind_change( - kind_name: ModelKindName, - context: Context, - environment: str, - *, - logical: bool, -): - directly_modified = ['"memory"."sushi"."items"'] - indirectly_modified = [ - '"memory"."sushi"."order_items"', - '"memory"."sushi"."waiter_revenue_by_day"', - '"memory"."sushi"."customer_revenue_by_day"', - '"memory"."sushi"."customer_revenue_lifetime"', - '"memory"."sushi"."top_waiters"', - "assert_item_price_above_zero", - ] - if kind_name == ModelKindName.INCREMENTAL_BY_TIME_RANGE: - kind: ModelKind = IncrementalByTimeRangeKind(time_column=TimeColumn(column="event_date")) - elif kind_name == ModelKindName.INCREMENTAL_BY_UNIQUE_KEY: - kind = IncrementalByUniqueKeyKind(unique_key="id") - else: - kind = model_kind_type_from_name(kind_name)() # type: ignore - - def _validate_plan(context, plan): - 
validate_plan_changes(plan, modified=directly_modified + indirectly_modified) - assert ( - next( - snapshot - for snapshot in plan.snapshots.values() - if snapshot.name == '"memory"."sushi"."items"' - ).model.kind.name - == kind.name - ) - assert bool(plan.missing_intervals) != logical - - apply_to_environment( - context, - environment, - SnapshotChangeCategory.NON_BREAKING, - plan_validators=[_validate_plan], - ) - - -def test_environment_isolation(sushi_context: Context): - prod_snapshots = sushi_context.snapshots.values() - - change_data_type( - sushi_context, - "sushi.items", - DataType.Type.DOUBLE, - DataType.Type.FLOAT, - ) - directly_modified = ['"memory"."sushi"."items"'] - indirectly_modified = [ - '"memory"."sushi"."order_items"', - '"memory"."sushi"."waiter_revenue_by_day"', - '"memory"."sushi"."customer_revenue_by_day"', - '"memory"."sushi"."customer_revenue_lifetime"', - '"memory"."sushi"."top_waiters"', - "assert_item_price_above_zero", - ] - - apply_to_environment(sushi_context, "dev", SnapshotChangeCategory.BREAKING) - - # Verify prod unchanged - validate_apply_basics(sushi_context, "prod", prod_snapshots) - - def _validate_plan(context, plan): - validate_plan_changes(plan, modified=directly_modified + indirectly_modified) - assert not plan.missing_intervals - - apply_to_environment( - sushi_context, - "prod", - SnapshotChangeCategory.BREAKING, - plan_validators=[_validate_plan], - ) - - -def test_environment_promotion(sushi_context: Context): - initial_add(sushi_context, "dev") - - # Simulate prod "ahead" - change_data_type(sushi_context, "sushi.items", DataType.Type.DOUBLE, DataType.Type.FLOAT) - apply_to_environment(sushi_context, "prod", SnapshotChangeCategory.BREAKING) - - # Simulate rebase - apply_to_environment(sushi_context, "dev", SnapshotChangeCategory.BREAKING) - - # Make changes in dev - change_data_type(sushi_context, "sushi.items", DataType.Type.FLOAT, DataType.Type.DECIMAL) - apply_to_environment(sushi_context, "dev", 
SnapshotChangeCategory.NON_BREAKING) - - change_data_type(sushi_context, "sushi.top_waiters", DataType.Type.DOUBLE, DataType.Type.INT) - apply_to_environment(sushi_context, "dev", SnapshotChangeCategory.BREAKING) - - change_data_type( - sushi_context, - "sushi.customer_revenue_by_day", - DataType.Type.DOUBLE, - DataType.Type.FLOAT, - ) - apply_to_environment( - sushi_context, - "dev", - SnapshotChangeCategory.FORWARD_ONLY, - allow_destructive_models=['"memory"."sushi"."customer_revenue_by_day"'], - ) - - # Promote to prod - def _validate_plan(context, plan): - sushi_items_snapshot = context.get_snapshot("sushi.items", raise_if_missing=True) - sushi_top_waiters_snapshot = context.get_snapshot( - "sushi.top_waiters", raise_if_missing=True - ) - sushi_customer_revenue_by_day_snapshot = context.get_snapshot( - "sushi.customer_revenue_by_day", raise_if_missing=True - ) - - assert ( - plan.context_diff.modified_snapshots[sushi_items_snapshot.name][0].change_category - == SnapshotChangeCategory.NON_BREAKING - ) - assert ( - plan.context_diff.modified_snapshots[sushi_top_waiters_snapshot.name][0].change_category - == SnapshotChangeCategory.BREAKING - ) - assert ( - plan.context_diff.modified_snapshots[sushi_customer_revenue_by_day_snapshot.name][ - 0 - ].change_category - == SnapshotChangeCategory.FORWARD_ONLY - ) - - apply_to_environment( - sushi_context, - "prod", - SnapshotChangeCategory.NON_BREAKING, - plan_validators=[_validate_plan], - allow_destructive_models=['"memory"."sushi"."customer_revenue_by_day"'], - ) - - -def test_no_override(sushi_context: Context) -> None: - change_data_type( - sushi_context, - "sushi.items", - DataType.Type.INT, - DataType.Type.BIGINT, - ) - - change_data_type( - sushi_context, - "sushi.order_items", - DataType.Type.INT, - DataType.Type.BIGINT, - ) - - plan_builder = sushi_context.plan_builder("prod") - plan = plan_builder.build() - - sushi_items_snapshot = sushi_context.get_snapshot("sushi.items", raise_if_missing=True) - 
sushi_order_items_snapshot = sushi_context.get_snapshot( - "sushi.order_items", raise_if_missing=True - ) - sushi_water_revenue_by_day_snapshot = sushi_context.get_snapshot( - "sushi.waiter_revenue_by_day", raise_if_missing=True - ) - - items = plan.context_diff.snapshots[sushi_items_snapshot.snapshot_id] - order_items = plan.context_diff.snapshots[sushi_order_items_snapshot.snapshot_id] - waiter_revenue = plan.context_diff.snapshots[sushi_water_revenue_by_day_snapshot.snapshot_id] - - plan_builder.set_choice(items, SnapshotChangeCategory.BREAKING).set_choice( - order_items, SnapshotChangeCategory.NON_BREAKING - ) - assert items.is_new_version - assert waiter_revenue.is_new_version - plan_builder.set_choice(items, SnapshotChangeCategory.NON_BREAKING) - assert not waiter_revenue.is_new_version - - -@pytest.mark.parametrize( - "change_categories, expected", - [ - ([SnapshotChangeCategory.NON_BREAKING], SnapshotChangeCategory.BREAKING), - ([SnapshotChangeCategory.BREAKING], SnapshotChangeCategory.BREAKING), - ( - [SnapshotChangeCategory.NON_BREAKING, SnapshotChangeCategory.NON_BREAKING], - SnapshotChangeCategory.BREAKING, - ), - ( - [SnapshotChangeCategory.NON_BREAKING, SnapshotChangeCategory.BREAKING], - SnapshotChangeCategory.BREAKING, - ), - ( - [SnapshotChangeCategory.BREAKING, SnapshotChangeCategory.NON_BREAKING], - SnapshotChangeCategory.BREAKING, - ), - ( - [SnapshotChangeCategory.BREAKING, SnapshotChangeCategory.BREAKING], - SnapshotChangeCategory.BREAKING, - ), - ], -) -def test_revert( - sushi_context: Context, - change_categories: t.List[SnapshotChangeCategory], - expected: SnapshotChangeCategory, -): - environment = "prod" - original_snapshot_id = sushi_context.get_snapshot("sushi.items", raise_if_missing=True) - - types = (DataType.Type.DOUBLE, DataType.Type.FLOAT, DataType.Type.DECIMAL) - assert len(change_categories) < len(types) - - for i, category in enumerate(change_categories): - change_data_type(sushi_context, "sushi.items", *types[i : i + 2]) - 
apply_to_environment(sushi_context, environment, category) - assert ( - sushi_context.get_snapshot("sushi.items", raise_if_missing=True) != original_snapshot_id - ) - - change_data_type(sushi_context, "sushi.items", types[len(change_categories)], types[0]) - - def _validate_plan(_, plan): - snapshot = next(s for s in plan.snapshots.values() if s.name == '"memory"."sushi"."items"') - assert snapshot.change_category == expected - assert not plan.missing_intervals - - apply_to_environment( - sushi_context, - environment, - change_categories[-1], - plan_validators=[_validate_plan], - ) - assert sushi_context.get_snapshot("sushi.items", raise_if_missing=True) == original_snapshot_id - - -def test_revert_after_downstream_change(sushi_context: Context): - environment = "prod" - change_data_type(sushi_context, "sushi.items", DataType.Type.DOUBLE, DataType.Type.FLOAT) - apply_to_environment(sushi_context, environment, SnapshotChangeCategory.BREAKING) - - change_data_type( - sushi_context, - "sushi.waiter_revenue_by_day", - DataType.Type.DOUBLE, - DataType.Type.FLOAT, - ) - apply_to_environment(sushi_context, environment, SnapshotChangeCategory.NON_BREAKING) - - change_data_type(sushi_context, "sushi.items", DataType.Type.FLOAT, DataType.Type.DOUBLE) - - def _validate_plan(_, plan): - snapshot = next(s for s in plan.snapshots.values() if s.name == '"memory"."sushi"."items"') - assert snapshot.change_category == SnapshotChangeCategory.BREAKING - assert plan.missing_intervals - - apply_to_environment( - sushi_context, - environment, - SnapshotChangeCategory.BREAKING, - plan_validators=[_validate_plan], - ) - - -def test_auto_categorization(sushi_context: Context): - environment = "dev" - for config in sushi_context.configs.values(): - config.plan.auto_categorize_changes.sql = AutoCategorizationMode.FULL - initial_add(sushi_context, environment) - - version = sushi_context.get_snapshot( - "sushi.waiter_as_customer_by_day", raise_if_missing=True - ).version - fingerprint = 
sushi_context.get_snapshot( - "sushi.waiter_as_customer_by_day", raise_if_missing=True - ).fingerprint - - model = t.cast(SqlModel, sushi_context.get_model("sushi.customers", raise_if_missing=True)) - sushi_context.upsert_model("sushi.customers", query=model.query.select("'foo' AS foo")) # type: ignore - apply_to_environment(sushi_context, environment) - - assert ( - sushi_context.get_snapshot( - "sushi.waiter_as_customer_by_day", raise_if_missing=True - ).change_category - == SnapshotChangeCategory.INDIRECT_NON_BREAKING - ) - assert ( - sushi_context.get_snapshot( - "sushi.waiter_as_customer_by_day", raise_if_missing=True - ).fingerprint - != fingerprint - ) - assert ( - sushi_context.get_snapshot("sushi.waiter_as_customer_by_day", raise_if_missing=True).version - == version - ) - - -def test_multi(mocker): - context = Context(paths=["examples/multi/repo_1", "examples/multi/repo_2"], gateway="memory") - context._new_state_sync().reset(default_catalog=context.default_catalog) - plan = context.plan() - assert len(plan.new_snapshots) == 4 - context.apply(plan) - - context = Context( - paths=["examples/multi/repo_1"], - engine_adapter=context.engine_adapter, - state_sync=context.state_sync, - gateway="memory", - ) - model = context.get_model("bronze.a") - assert model.project == "repo_1" - context.upsert_model(model.copy(update={"query": model.query.select("'c' AS c")})) - plan = context.plan() - assert set(snapshot.name for snapshot in plan.directly_modified) == { - '"memory"."bronze"."a"', - '"memory"."bronze"."b"', - } - assert sorted([x.name for x in list(plan.indirectly_modified.values())[0]]) == [ - '"memory"."silver"."c"', - '"memory"."silver"."d"', - ] - assert len(plan.missing_intervals) == 2 - context.apply(plan) - validate_apply_basics(context, c.PROD, plan.snapshots.values()) - - -def test_multi_dbt(mocker): - context = Context(paths=["examples/multi_dbt/bronze", "examples/multi_dbt/silver"]) - plan = context.plan() - assert len(plan.new_snapshots) == 4 - 
context.apply(plan) - validate_apply_basics(context, c.PROD, plan.snapshots.values()) - - -def test_incremental_time_self_reference( - mocker: MockerFixture, sushi_context: Context, sushi_data_validator: SushiDataValidator -): - start_ts = to_timestamp("1 week ago") - start_date, end_date = to_date("1 week ago"), to_date("yesterday") - if to_timestamp(start_date) < start_ts: - # The start date must be aligned by the interval unit. - start_date += timedelta(days=1) - - df = sushi_context.engine_adapter.fetchdf( - "SELECT MIN(event_date) FROM sushi.customer_revenue_lifetime" - ) - assert df.iloc[0, 0] == pd.to_datetime(start_date) - df = sushi_context.engine_adapter.fetchdf( - "SELECT MAX(event_date) FROM sushi.customer_revenue_lifetime" - ) - assert df.iloc[0, 0] == pd.to_datetime(end_date) - results = sushi_data_validator.validate("sushi.customer_revenue_lifetime", start_date, end_date) - plan = sushi_context.plan( - restate_models=["sushi.customer_revenue_lifetime", "sushi.customer_revenue_by_day"], - no_prompts=True, - start=start_date, - end="5 days ago", - ) - revenue_lifeteime_snapshot = sushi_context.get_snapshot( - "sushi.customer_revenue_lifetime", raise_if_missing=True - ) - revenue_by_day_snapshot = sushi_context.get_snapshot( - "sushi.customer_revenue_by_day", raise_if_missing=True - ) - assert sorted(plan.missing_intervals, key=lambda x: x.snapshot_id) == sorted( - [ - SnapshotIntervals( - snapshot_id=revenue_lifeteime_snapshot.snapshot_id, - intervals=[ - (to_timestamp(to_date("7 days ago")), to_timestamp(to_date("6 days ago"))), - (to_timestamp(to_date("6 days ago")), to_timestamp(to_date("5 days ago"))), - (to_timestamp(to_date("5 days ago")), to_timestamp(to_date("4 days ago"))), - (to_timestamp(to_date("4 days ago")), to_timestamp(to_date("3 days ago"))), - (to_timestamp(to_date("3 days ago")), to_timestamp(to_date("2 days ago"))), - (to_timestamp(to_date("2 days ago")), to_timestamp(to_date("1 days ago"))), - (to_timestamp(to_date("1 day ago")), 
to_timestamp(to_date("today"))), - ], - ), - SnapshotIntervals( - snapshot_id=revenue_by_day_snapshot.snapshot_id, - intervals=[ - (to_timestamp(to_date("7 days ago")), to_timestamp(to_date("6 days ago"))), - (to_timestamp(to_date("6 days ago")), to_timestamp(to_date("5 days ago"))), - ], - ), - ], - key=lambda x: x.snapshot_id, - ) - sushi_context.console = mocker.Mock(spec=Console) - sushi_context.apply(plan) - num_batch_calls = Counter( - [x[0][0] for x in sushi_context.console.update_snapshot_evaluation_progress.call_args_list] # type: ignore - ) - # Validate that we made 7 calls to the customer_revenue_lifetime snapshot and 1 call to the customer_revenue_by_day snapshot - assert num_batch_calls == { - sushi_context.get_snapshot("sushi.customer_revenue_lifetime", raise_if_missing=True): 7, - sushi_context.get_snapshot("sushi.customer_revenue_by_day", raise_if_missing=True): 1, - } - # Validate that the results are the same as before the restate - assert results == sushi_data_validator.validate( - "sushi.customer_revenue_lifetime", start_date, end_date - ) - - -def test_invalidating_environment(sushi_context: Context): - apply_to_environment(sushi_context, "dev") - start_environment = sushi_context.state_sync.get_environment("dev") - assert start_environment is not None - metadata = DuckDBMetadata.from_context(sushi_context) - start_schemas = set(metadata.schemas) - assert "sushi__dev" in start_schemas - sushi_context.invalidate_environment("dev") - invalidate_environment = sushi_context.state_sync.get_environment("dev") - assert invalidate_environment is not None - schemas_prior_to_janitor = set(metadata.schemas) - assert invalidate_environment.expiration_ts < start_environment.expiration_ts # type: ignore - assert start_schemas == schemas_prior_to_janitor - sushi_context._run_janitor() - schemas_after_janitor = set(metadata.schemas) - assert sushi_context.state_sync.get_environment("dev") is None - assert start_schemas - schemas_after_janitor == {"sushi__dev"} 
- - -def test_environment_suffix_target_table(init_and_plan_context: t.Callable): - context, plan = init_and_plan_context("examples/sushi", config="environment_suffix_config") - context.apply(plan) - metadata = DuckDBMetadata.from_context(context) - environments_schemas = {"sushi"} - internal_schemas = {"sqlmesh", "sqlmesh__sushi"} - starting_schemas = environments_schemas | internal_schemas - # Make sure no new schemas are created - assert set(metadata.schemas) - starting_schemas == {"raw"} - prod_views = {x for x in metadata.qualified_views if x.db in environments_schemas} - # Make sure that all models are present - assert len(prod_views) == 13 - apply_to_environment(context, "dev") - # Make sure no new schemas are created - assert set(metadata.schemas) - starting_schemas == {"raw"} - dev_views = { - x for x in metadata.qualified_views if x.db in environments_schemas and "__dev" in x.name - } - # Make sure that there is a view with `__dev` for each view that exists in prod - assert len(dev_views) == len(prod_views) - assert {x.name.replace("__dev", "") for x in dev_views} - {x.name for x in prod_views} == set() - context.invalidate_environment("dev") - context._run_janitor() - views_after_janitor = metadata.qualified_views - # Make sure that the number of views after the janitor is the same as when you subtract away dev views - assert len(views_after_janitor) == len( - {x.sql(dialect="duckdb") for x in views_after_janitor} - - {x.sql(dialect="duckdb") for x in dev_views} - ) - # Double check there are no dev views - assert len({x for x in views_after_janitor if "__dev" in x.name}) == 0 - # Make sure prod views were not removed - assert {x.sql(dialect="duckdb") for x in prod_views} - { - x.sql(dialect="duckdb") for x in views_after_janitor - } == set() - - -def test_environment_catalog_mapping(init_and_plan_context: t.Callable): - environments_schemas = {"raw", "sushi"} - - def get_prod_dev_views(metadata: DuckDBMetadata) -> t.Tuple[t.Set[exp.Table], 
t.Set[exp.Table]]: - views = metadata.qualified_views - prod_views = { - x for x in views if x.catalog == "prod_catalog" if x.db in environments_schemas - } - dev_views = {x for x in views if x.catalog == "dev_catalog" if x.db in environments_schemas} - return prod_views, dev_views - - def get_default_catalog_and_non_tables( - metadata: DuckDBMetadata, default_catalog: t.Optional[str] - ) -> t.Tuple[t.Set[exp.Table], t.Set[exp.Table]]: - tables = metadata.qualified_tables - user_default_tables = { - x for x in tables if x.catalog == default_catalog and x.db != "sqlmesh" - } - non_default_tables = {x for x in tables if x.catalog != default_catalog} - return user_default_tables, non_default_tables - - context, plan = init_and_plan_context( - "examples/sushi", config="environment_catalog_mapping_config" - ) - context.apply(plan) - metadata = DuckDBMetadata(context.engine_adapter) - state_metadata = DuckDBMetadata.from_context(context.state_sync.state_sync) - prod_views, dev_views = get_prod_dev_views(metadata) - ( - user_default_tables, - non_default_tables, - ) = get_default_catalog_and_non_tables(metadata, context.default_catalog) - assert len(prod_views) == 13 - assert len(dev_views) == 0 - assert len(user_default_tables) == 24 - assert state_metadata.schemas == ["sqlmesh"] - assert {x.sql() for x in state_metadata.qualified_tables}.issuperset( - { - "physical.sqlmesh._environments", - "physical.sqlmesh._intervals", - "physical.sqlmesh._plan_dags", - "physical.sqlmesh._snapshots", - "physical.sqlmesh._versions", - } - ) - apply_to_environment(context, "dev") - prod_views, dev_views = get_prod_dev_views(metadata) - ( - user_default_tables, - non_default_tables, - ) = get_default_catalog_and_non_tables(metadata, context.default_catalog) - assert len(prod_views) == 13 - assert len(dev_views) == 13 - assert len(user_default_tables) == 24 - assert len(non_default_tables) == 0 - assert state_metadata.schemas == ["sqlmesh"] - assert {x.sql() for x in 
state_metadata.qualified_tables}.issuperset( - { - "physical.sqlmesh._environments", - "physical.sqlmesh._intervals", - "physical.sqlmesh._plan_dags", - "physical.sqlmesh._snapshots", - "physical.sqlmesh._versions", - } - ) - apply_to_environment(context, "prodnot") - prod_views, dev_views = get_prod_dev_views(metadata) - ( - user_default_tables, - non_default_tables, - ) = get_default_catalog_and_non_tables(metadata, context.default_catalog) - assert len(prod_views) == 13 - assert len(dev_views) == 26 - assert len(user_default_tables) == 24 - assert len(non_default_tables) == 0 - assert state_metadata.schemas == ["sqlmesh"] - assert {x.sql() for x in state_metadata.qualified_tables}.issuperset( - { - "physical.sqlmesh._environments", - "physical.sqlmesh._intervals", - "physical.sqlmesh._plan_dags", - "physical.sqlmesh._snapshots", - "physical.sqlmesh._versions", - } - ) - context.invalidate_environment("dev") - context._run_janitor() - prod_views, dev_views = get_prod_dev_views(metadata) - ( - user_default_tables, - non_default_tables, - ) = get_default_catalog_and_non_tables(metadata, context.default_catalog) - assert len(prod_views) == 13 - assert len(dev_views) == 13 - assert len(user_default_tables) == 24 - assert len(non_default_tables) == 0 - assert state_metadata.schemas == ["sqlmesh"] - assert {x.sql() for x in state_metadata.qualified_tables}.issuperset( - { - "physical.sqlmesh._environments", - "physical.sqlmesh._intervals", - "physical.sqlmesh._plan_dags", - "physical.sqlmesh._snapshots", - "physical.sqlmesh._versions", - } - ) - - -@pytest.mark.parametrize( - "context_fixture", - ["sushi_context", "sushi_no_default_catalog"], -) -def test_ignored_snapshots(context_fixture: Context, request): - context = request.getfixturevalue(context_fixture) - environment = "dev" - apply_to_environment(context, environment) - # Make breaking change to model upstream of a depends_on_self model - context.upsert_model("sushi.order_items", stamp="1") - # Apply the change 
starting at a date later then the beginning of the downstream depends_on_self model - plan = apply_to_environment( - context, environment, choice=SnapshotChangeCategory.BREAKING, plan_start="2 days ago" - ) - revenue_lifetime_snapshot = context.get_snapshot( - "sushi.customer_revenue_lifetime", raise_if_missing=True - ) - # Validate that the depends_on_self model is ignored - assert plan.ignored == {revenue_lifetime_snapshot.snapshot_id} - # Validate that the table was really ignored - metadata = DuckDBMetadata.from_context(context) - # Make sure prod view exists - catalog = context.default_catalog or "memory" - assert exp.table_("customer_revenue_lifetime", "sushi", catalog) in metadata.qualified_views - # Make sure dev view doesn't exist since it was ignored - assert ( - exp.table_("customer_revenue_lifetime", "sushi__dev", catalog) - not in metadata.qualified_views - ) - # Make sure that dev view for order items was created - assert exp.table_("order_items", "sushi__dev", catalog) in metadata.qualified_views - - -def initial_add(context: Context, environment: str): - assert not context.state_reader.get_environment(environment) - - plan = context.plan(environment, start=start(context), create_from="nonexistent_env") - validate_plan_changes(plan, added={x.snapshot_id for x in context.snapshots.values()}) - - context.apply(plan) - validate_apply_basics(context, environment, plan.snapshots.values()) - - -def apply_to_environment( - context: Context, - environment: str, - choice: t.Optional[SnapshotChangeCategory] = None, - plan_validators: t.Optional[t.Iterable[t.Callable]] = None, - apply_validators: t.Optional[t.Iterable[t.Callable]] = None, - plan_start: t.Optional[TimeLike] = None, - allow_destructive_models: t.Optional[t.List[str]] = None, -): - plan_validators = plan_validators or [] - apply_validators = apply_validators or [] - - plan_builder = context.plan_builder( - environment, - start=plan_start or start(context) if environment != c.PROD else None, - 
forward_only=choice == SnapshotChangeCategory.FORWARD_ONLY, - include_unmodified=True, - allow_destructive_models=allow_destructive_models if allow_destructive_models else [], - ) - if environment != c.PROD: - plan_builder.set_start(plan_start or start(context)) - - if choice: - plan_choice(plan_builder, choice) - for validator in plan_validators: - validator(context, plan_builder.build()) - - plan = plan_builder.build() - context.apply(plan) - - validate_apply_basics(context, environment, plan.snapshots.values()) - for validator in apply_validators: - validator(context) - return plan - - -def change_data_type( - context: Context, model_name: str, old_type: DataType.Type, new_type: DataType.Type -) -> None: - model = context.get_model(model_name) - assert model is not None - - if isinstance(model, SqlModel): - data_types = model.query.find_all(DataType) - for data_type in data_types: - if data_type.this == old_type: - data_type.set("this", new_type) - context.upsert_model(model_name, query=model.query) - elif model.columns_to_types_ is not None: - for k, v in model.columns_to_types_.items(): - if v.this == old_type: - model.columns_to_types_[k] = DataType.build(new_type) - context.upsert_model(model_name, columns=model.columns_to_types_) - - -def validate_plan_changes( - plan: Plan, - *, - added: t.Optional[t.Iterable[SnapshotId]] = None, - modified: t.Optional[t.Iterable[str]] = None, - removed: t.Optional[t.Iterable[SnapshotId]] = None, -) -> None: - added = added or [] - modified = modified or [] - removed = removed or [] - assert set(added) == plan.context_diff.added - assert set(modified) == set(plan.context_diff.modified_snapshots) - assert set(removed) == set(plan.context_diff.removed_snapshots) - - -def validate_versions_same( - model_names: t.List[str], - versions: t.Dict[str, str], - other_versions: t.Dict[str, str], -) -> None: - for name in model_names: - assert versions[name] == other_versions[name] - - -def validate_versions_different( - model_names: 
t.List[str], - versions: t.Dict[str, str], - other_versions: t.Dict[str, str], -) -> None: - for name in model_names: - assert versions[name] != other_versions[name] - - -def validate_apply_basics( - context: Context, environment: str, snapshots: t.Iterable[Snapshot] -) -> None: - validate_snapshots_in_state_sync(snapshots, context) - validate_state_sync_environment(snapshots, environment, context) - validate_tables(snapshots, context) - validate_environment_views(snapshots, environment, context) - - -def validate_snapshots_in_state_sync(snapshots: t.Iterable[Snapshot], context: Context) -> None: - snapshot_infos = map(to_snapshot_info, snapshots) - state_sync_table_infos = map( - to_snapshot_info, context.state_reader.get_snapshots(snapshots).values() - ) - assert set(snapshot_infos) == set(state_sync_table_infos) - - -def validate_state_sync_environment( - snapshots: t.Iterable[Snapshot], env: str, context: Context -) -> None: - environment = context.state_reader.get_environment(env) - assert environment - snapshot_infos = map(to_snapshot_info, snapshots) - environment_table_infos = map(to_snapshot_info, environment.snapshots) - assert set(snapshot_infos) == set(environment_table_infos) - - -def validate_tables(snapshots: t.Iterable[Snapshot], context: Context) -> None: - adapter = context.engine_adapter - for snapshot in snapshots: - if not snapshot.is_model or snapshot.is_external: - continue - table_should_exist = not snapshot.is_embedded - assert adapter.table_exists(snapshot.table_name()) == table_should_exist - if table_should_exist: - assert select_all(snapshot.table_name(), adapter) - - -def validate_environment_views( - snapshots: t.Iterable[Snapshot], environment: str, context: Context -) -> None: - adapter = context.engine_adapter - for snapshot in snapshots: - if not snapshot.is_model or snapshot.is_symbolic: - continue - view_name = snapshot.qualified_view_name.for_environment( - EnvironmentNamingInfo.from_environment_catalog_mapping( - 
context.config.environment_catalog_mapping, - name=environment, - suffix_target=context.config.environment_suffix_target, - ) - ) - - is_deployable = environment == c.PROD or not snapshot.is_paused_forward_only - - assert adapter.table_exists(view_name) - assert select_all(snapshot.table_name(is_deployable), adapter) == select_all( - view_name, adapter - ) - - -def select_all(table: str, adapter: EngineAdapter) -> t.Iterable: - return adapter.fetchall(f"select * from {table} order by 1") - - -def snapshots_to_versions(snapshots: t.Iterable[Snapshot]) -> t.Dict[str, str]: - return {snapshot.name: snapshot.version or "" for snapshot in snapshots} - - -def to_snapshot_info(snapshot: SnapshotInfoLike) -> SnapshotTableInfo: - return snapshot.table_info - - -def start(context: Context) -> TimeLike: - env = context.state_sync.get_environment("prod") - assert env - return env.start_at - - -def add_projection_to_model(model: SqlModel, literal: bool = True) -> SqlModel: - one_expr = exp.Literal.number(1).as_("one") if literal else exp.column("one") - kwargs = { - **model.dict(), - "query": model.query.select(one_expr), # type: ignore - } - return SqlModel.parse_obj(kwargs) diff --git a/tests/core/test_janitor.py b/tests/core/test_janitor.py new file mode 100644 index 0000000000..e5e209f2cc --- /dev/null +++ b/tests/core/test_janitor.py @@ -0,0 +1,282 @@ +import typing as t +from unittest.mock import call + +import pytest +from pytest_mock.plugin import MockerFixture + +from sqlmesh.core.config import EnvironmentSuffixTarget +from sqlmesh.core import constants as c +from sqlmesh.core.dialect import parse_one, schema_ +from sqlmesh.core.engine_adapter import create_engine_adapter +from sqlmesh.core.environment import Environment +from sqlmesh.core.model import ( + ModelKindName, + SqlModel, +) +from sqlmesh.core.model.definition import ExternalModel +from sqlmesh.core.snapshot import ( + SnapshotChangeCategory, +) +from sqlmesh.core.state_sync import ( + 
EngineAdapterStateSync, +) +from sqlmesh.core.janitor import cleanup_expired_views, delete_expired_snapshots +from sqlmesh.utils.date import now_timestamp +from sqlmesh.utils.errors import SQLMeshError + +pytestmark = pytest.mark.slow + + +@pytest.fixture +def state_sync(duck_conn, tmp_path): + state_sync = EngineAdapterStateSync( + create_engine_adapter(lambda: duck_conn, "duckdb"), + schema=c.SQLMESH, + cache_dir=tmp_path / c.CACHE, + ) + state_sync.migrate() + return state_sync + + +def test_cleanup_expired_views(mocker: MockerFixture, make_snapshot: t.Callable): + adapter = mocker.MagicMock() + adapter.dialect = None + snapshot_a = make_snapshot(SqlModel(name="catalog.schema.a", query=parse_one("select 1, ds"))) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b = make_snapshot(SqlModel(name="catalog.schema.b", query=parse_one("select 1, ds"))) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + # Make sure that we don't drop schemas from external models + snapshot_external_model = make_snapshot( + ExternalModel(name="catalog.external_schema.external_table", kind=ModelKindName.EXTERNAL) + ) + snapshot_external_model.categorize_as(SnapshotChangeCategory.BREAKING) + schema_environment = Environment( + name="test_environment", + suffix_target=EnvironmentSuffixTarget.SCHEMA, + snapshots=[ + snapshot_a.table_info, + snapshot_b.table_info, + snapshot_external_model.table_info, + ], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + catalog_name_override="catalog_override", + ) + snapshot_c = make_snapshot(SqlModel(name="catalog.schema.c", query=parse_one("select 1, ds"))) + snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_d = make_snapshot(SqlModel(name="catalog.schema.d", query=parse_one("select 1, ds"))) + snapshot_d.categorize_as(SnapshotChangeCategory.BREAKING) + table_environment = Environment( + name="test_environment", + 
suffix_target=EnvironmentSuffixTarget.TABLE, + snapshots=[ + snapshot_c.table_info, + snapshot_d.table_info, + snapshot_external_model.table_info, + ], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + catalog_name_override="catalog_override", + ) + cleanup_expired_views(adapter, {}, [schema_environment, table_environment]) + assert adapter.drop_schema.called + assert adapter.drop_view.called + assert adapter.drop_schema.call_args_list == [ + call( + schema_("schema__test_environment", "catalog_override"), + ignore_if_not_exists=True, + cascade=True, + ) + ] + assert sorted(adapter.drop_view.call_args_list) == [ + call("catalog_override.schema.c__test_environment", ignore_if_not_exists=True), + call("catalog_override.schema.d__test_environment", ignore_if_not_exists=True), + ] + + +@pytest.mark.parametrize( + "suffix_target", [EnvironmentSuffixTarget.SCHEMA, EnvironmentSuffixTarget.TABLE] +) +def test_cleanup_expired_environment_schema_warn_on_delete_failure( + mocker: MockerFixture, make_snapshot: t.Callable, suffix_target: EnvironmentSuffixTarget +): + adapter = mocker.MagicMock() + adapter.dialect = None + adapter.drop_schema.side_effect = Exception("Failed to drop the schema") + adapter.drop_view.side_effect = Exception("Failed to drop the view") + + snapshot = make_snapshot( + SqlModel(name="test_catalog.test_schema.test_model", query=parse_one("select 1, ds")) + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + schema_environment = Environment( + name="test_environment", + suffix_target=suffix_target, + snapshots=[snapshot.table_info], + start_at="2022-01-01", + end_at="2022-01-01", + plan_id="test_plan_id", + previous_plan_id="test_plan_id", + catalog_name_override="catalog_override", + ) + + with pytest.raises(SQLMeshError, match="Failed to drop the expired environment .*"): + cleanup_expired_views(adapter, {}, [schema_environment], warn_on_delete_failure=False) + + 
cleanup_expired_views(adapter, {}, [schema_environment], warn_on_delete_failure=True) + + if suffix_target == EnvironmentSuffixTarget.SCHEMA: + assert adapter.drop_schema.called + else: + assert adapter.drop_view.called + + +def test_delete_expired_snapshots_common_function_batching( + state_sync: EngineAdapterStateSync, make_snapshot: t.Callable, mocker: MockerFixture +): + """Test that the common delete_expired_snapshots function properly pages through batches and deletes them.""" + from sqlmesh.core.state_sync.common import ExpiredBatchRange, RowBoundary, LimitBoundary + from unittest.mock import MagicMock + + now_ts = now_timestamp() + + # Create 5 expired snapshots with different timestamps + snapshots = [] + for idx in range(5): + snapshot = make_snapshot( + SqlModel( + name=f"model_{idx}", + query=parse_one("select 1 as a, ds"), + ), + ) + snapshot.ttl = "in 10 seconds" + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.updated_ts = now_ts - (20000 + idx * 1000) + snapshots.append(snapshot) + + state_sync.push_snapshots(snapshots) + + # Spy on get_expired_snapshots and delete_expired_snapshots methods + get_expired_spy = mocker.spy(state_sync, "get_expired_snapshots") + delete_expired_spy = mocker.spy(state_sync, "delete_expired_snapshots") + + # Mock snapshot evaluator + mock_evaluator = MagicMock() + mock_evaluator.cleanup = MagicMock() + + # Run delete_expired_snapshots with batch_size=2 + delete_expired_snapshots( + state_sync, + mock_evaluator, + current_ts=now_ts, + batch_size=2, + ) + + # Verify get_expired_snapshots was called the correct number of times: + # - 3 batches (2+2+1): each batch triggers 2 calls (one from iter_expired_snapshot_batches, one from delete_expired_snapshots) + # - Plus 1 final call that returns empty to exit the loop + # Total: 3 * 2 + 1 = 7 calls + assert get_expired_spy.call_count == 7 + + # Verify the progression of batch_range calls from the iter_expired_snapshot_batches loop + # (calls at indices 0, 2, 
4, 6 are from iter_expired_snapshot_batches) + # (calls at indices 1, 3, 5 are from delete_expired_snapshots in facade.py) + calls = get_expired_spy.call_args_list + + # First call from iterator should have a batch_range starting from the beginning + first_call_kwargs = calls[0][1] + assert "batch_range" in first_call_kwargs + first_range = first_call_kwargs["batch_range"] + assert isinstance(first_range, ExpiredBatchRange) + assert isinstance(first_range.start, RowBoundary) + assert isinstance(first_range.end, LimitBoundary) + assert first_range.end.batch_size == 2 + assert first_range.start.updated_ts == 0 + assert first_range.start.name == "" + assert first_range.start.identifier == "" + + # Third call (second batch from iterator) should have a batch_range from the first batch's range + third_call_kwargs = calls[2][1] + assert "batch_range" in third_call_kwargs + second_range = third_call_kwargs["batch_range"] + assert isinstance(second_range, ExpiredBatchRange) + assert isinstance(second_range.start, RowBoundary) + assert isinstance(second_range.end, LimitBoundary) + assert second_range.end.batch_size == 2 + # Should have progressed from the first batch + assert second_range.start.updated_ts > 0 + assert second_range.start.name == '"model_3"' + + # Fifth call (third batch from iterator) should have a batch_range from the second batch's range + fifth_call_kwargs = calls[4][1] + assert "batch_range" in fifth_call_kwargs + third_range = fifth_call_kwargs["batch_range"] + assert isinstance(third_range, ExpiredBatchRange) + assert isinstance(third_range.start, RowBoundary) + assert isinstance(third_range.end, LimitBoundary) + assert third_range.end.batch_size == 2 + # Should have progressed from the second batch + assert third_range.start.updated_ts >= second_range.start.updated_ts + assert third_range.start.name == '"model_1"' + + # Seventh call (final call from iterator) should have a batch_range from the third batch's range + seventh_call_kwargs = calls[6][1] + 
assert "batch_range" in seventh_call_kwargs + fourth_range = seventh_call_kwargs["batch_range"] + assert isinstance(fourth_range, ExpiredBatchRange) + assert isinstance(fourth_range.start, RowBoundary) + assert isinstance(fourth_range.end, LimitBoundary) + assert fourth_range.end.batch_size == 2 + # Should have progressed from the third batch + assert fourth_range.start.updated_ts >= third_range.start.updated_ts + assert fourth_range.start.name == '"model_0"' + + # Verify delete_expired_snapshots was called 3 times (once per batch) + assert delete_expired_spy.call_count == 3 + + # Verify each delete call used a batch_range + delete_calls = delete_expired_spy.call_args_list + + # First call should have a batch_range matching the first batch + first_delete_kwargs = delete_calls[0][1] + assert "batch_range" in first_delete_kwargs + first_delete_range = first_delete_kwargs["batch_range"] + assert isinstance(first_delete_range, ExpiredBatchRange) + assert isinstance(first_delete_range.start, RowBoundary) + assert first_delete_range.start.updated_ts == 0 + assert isinstance(first_delete_range.end, RowBoundary) + assert first_delete_range.end.updated_ts == second_range.start.updated_ts + assert first_delete_range.end.name == second_range.start.name + assert first_delete_range.end.identifier == second_range.start.identifier + + second_delete_kwargs = delete_calls[1][1] + assert "batch_range" in second_delete_kwargs + second_delete_range = second_delete_kwargs["batch_range"] + assert isinstance(second_delete_range, ExpiredBatchRange) + assert isinstance(second_delete_range.start, RowBoundary) + assert second_delete_range.start.updated_ts == 0 + assert isinstance(second_delete_range.end, RowBoundary) + assert second_delete_range.end.updated_ts == third_range.start.updated_ts + assert second_delete_range.end.name == third_range.start.name + assert second_delete_range.end.identifier == third_range.start.identifier + + third_delete_kwargs = delete_calls[2][1] + assert 
"batch_range" in third_delete_kwargs + third_delete_range = third_delete_kwargs["batch_range"] + assert isinstance(third_delete_range, ExpiredBatchRange) + assert isinstance(third_delete_range.start, RowBoundary) + assert third_delete_range.start.updated_ts == 0 + assert isinstance(third_delete_range.end, RowBoundary) + assert third_delete_range.end.updated_ts == fourth_range.start.updated_ts + assert third_delete_range.end.name == fourth_range.start.name + assert third_delete_range.end.identifier == fourth_range.start.identifier + # Verify the cleanup method was called for each batch that had cleanup tasks + assert mock_evaluator.cleanup.call_count >= 1 + + # Verify all snapshots were deleted in the end + remaining = state_sync.get_snapshots(snapshots) + assert len(remaining) == 0 diff --git a/tests/core/test_loader.py b/tests/core/test_loader.py new file mode 100644 index 0000000000..14a20ec09a --- /dev/null +++ b/tests/core/test_loader.py @@ -0,0 +1,203 @@ +import pytest +from pathlib import Path +from sqlmesh.cli.project_init import init_example_project +from sqlmesh.core.config import Config, ModelDefaultsConfig +from sqlmesh.core.context import Context +from sqlmesh.utils.errors import ConfigError + + +@pytest.fixture +def sample_models(request): + models = { + "sql": { + "contents": """ +MODEL ( + name test_schema.test_model, + kind FULL, +); + +SELECT 1; +""", + "path": "models/sql_model.sql", + }, + "python": { + "contents": """import typing as t +import pandas as pd # noqa: TID253 +from sqlmesh import ExecutionContext, model + +@model( + "test_schema.test_model", + kind="FULL", + columns={ + "id": "int", + } +) +def execute( + context: ExecutionContext, + **kwargs: t.Any, +) -> pd.DataFrame: + return pd.DataFrame([ + {"id": 1} + ]) +""", + "path": "models/python_model.py", + }, + "external": { + "contents": """ +- name: test_schema.test_model + columns: + id: INT +""", + "path": "external_models/external_model.yaml", + }, + } + requested_models = 
request.param.split("_") + return [v for k, v in models.items() if k in requested_models] + + +@pytest.mark.parametrize( + "sample_models", + ["sql_python", "python_external", "sql_external", "sql_python_external"], + indirect=True, +) +def test_duplicate_model_names_different_kind(tmp_path: Path, sample_models): + """Test different (SQL, Python and external) models with duplicate model names raises ValueError.""" + model_1, *models = sample_models + if len(models) == 2: + model_2, model_3 = models + else: + model_2, model_3 = models[0], None + + init_example_project(tmp_path, engine_type="duckdb") + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + + path_1: Path = tmp_path / model_1["path"] + path_2: Path = tmp_path / model_2["path"] + + path_1.parent.mkdir(parents=True, exist_ok=True) + path_1.write_text(model_1["contents"]) + path_2.parent.mkdir(parents=True, exist_ok=True) + path_2.write_text(model_2["contents"]) + + if model_3: + path_3: Path = tmp_path / model_3["path"] + path_3.parent.mkdir(parents=True, exist_ok=True) + path_3.write_text(model_3["contents"]) + + with pytest.raises( + ConfigError, match=r'Duplicate model name\(s\) found: "memory"."test_schema"."test_model".' 
+ ): + Context(paths=tmp_path, config=config) + + +@pytest.mark.parametrize("sample_models", ["sql", "external"], indirect=True) +def test_duplicate_model_names_same_kind(tmp_path: Path, sample_models): + """Test same (SQL and external) models with duplicate model names raises ConfigError.""" + + def duplicate_model_path(fpath): + return Path(fpath).parent / ("duplicate" + Path(fpath).suffix) + + model = sample_models[0] + init_example_project(tmp_path, engine_type="duckdb") + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + + path_1: Path = tmp_path / model["path"] + path_1.parent.mkdir(parents=True, exist_ok=True) + path_1.write_text(model["contents"]) + + duplicate_fpath = tmp_path / duplicate_model_path(model["path"]) + duplicate_fpath.write_text(model["contents"]) + + with pytest.raises( + ConfigError, + match=r".*Duplicate .* model name: 'test_schema.test_model'", + ): + Context(paths=tmp_path, config=config) + + +@pytest.mark.registry_isolation +def test_duplicate_python_model_names_raise_error(tmp_path: Path) -> None: + """Test python models with duplicate model names raises ConfigError if the functions are not identical.""" + init_example_project(tmp_path, engine_type="duckdb") + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + model_name = "test_schema.test_model" + + path_a = tmp_path / "models/test_schema/test_model_a.py" + path_b = tmp_path / "models/test_schema/test_model_b.py" + + model_payload_a = f"""from sqlmesh import model +@model( + name="{model_name}", + columns={{'"COL"': "int"}}, +) +def my_model(context, **kwargs): + pass""" + + model_payload_b = f"""import typing as t +import pandas as pd # noqa: TID253 +from sqlmesh import ExecutionContext, model + +@model( + name="{model_name}", + kind="FULL", + columns={{ + "id": "int", + }} +) +def execute( + context: ExecutionContext, + **kwargs: t.Any, +) -> pd.DataFrame: + return pd.DataFrame([ + {{"id": 1}} + ]) +""" + + 
path_a.parent.mkdir(parents=True, exist_ok=True) + path_a.write_text(model_payload_a) + path_b.write_text(model_payload_b) + + with pytest.raises( + ConfigError, + match=r"Failed to load model from file '.*'.\n\n Duplicate name: 'test_schema.test_model'.", + ): + Context(paths=tmp_path, config=config) + + +@pytest.mark.slow +def test_duplicate_python_model_names_no_error(tmp_path: Path) -> None: + """Test python models with duplicate model names raises no error if the functions are identical.""" + init_example_project(tmp_path, engine_type="duckdb") + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + model_name = "test_schema.test_model" + + path_a = tmp_path / "models/test_schema1/test_model_a.py" + path_b = tmp_path / "models/test_schema2/test_model_b.py" + + model_payload_a = f"""from sqlmesh import model +@model( + name="{model_name}", + columns={{'"COL"': "int"}}, + description="model_payload_a", +) +def my_model(context, **kwargs): + pass""" + + model_payload_b = f"""from sqlmesh import model +@model( + name="{model_name}", + columns={{'"COL"': "int"}}, + description="model_payload_b", +) +def my_model(context, **kwargs): + pass""" + + path_a.parent.mkdir(parents=True, exist_ok=True) + path_b.parent.mkdir(parents=True, exist_ok=True) + path_a.write_text(model_payload_a) + context = Context(paths=tmp_path, config=config) + context.load() + model = context.get_model(f"{model_name}") + assert model.description == "model_payload_a" + path_b.write_text(model_payload_b) + context.load() # raise no error to duplicate key if the functions are identical (by registry class_method) diff --git a/tests/core/test_macros.py b/tests/core/test_macros.py index 6b1fba2060..fb10f64b27 100644 --- a/tests/core/test_macros.py +++ b/tests/core/test_macros.py @@ -1,4 +1,5 @@ import typing as t +from datetime import datetime, date import pytest from sqlglot import MappingSchema, ParseError, exp, parse_one @@ -6,8 +7,10 @@ from sqlmesh.core import constants as c, 
dialect as d from sqlmesh.core.dialect import StagedFilePath from sqlmesh.core.macros import SQL, MacroEvalError, MacroEvaluator, macro +from sqlmesh.utils.date import to_datetime, to_date from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.metaprogramming import Executable +from sqlmesh.core.macros import RuntimeStage @pytest.fixture @@ -89,6 +92,16 @@ def test_default_arg_coercion( ): return sum([a1, a2]) + @macro() + def test_select_macro(evaluator): + return "SELECT 1 AS col" + + @macro() + def test_literal_type(evaluator, a: t.Literal["test_literal_a", "test_literal_b", 1, True]): + if isinstance(a, exp.Expression): + raise SQLMeshError("Coercion failed") + return f"'{a}'" + return MacroEvaluator( "hive", {"test": Executable(name="test", payload="def test(_):\n return 'test'")}, @@ -97,7 +110,7 @@ def test_default_arg_coercion( def test_star(assert_exp_eq) -> None: sql = """SELECT @STAR(foo) FROM foo""" - expected_sql = """SELECT CAST([foo].[a] AS DATETIMEOFFSET) AS [a], CAST([foo].[b] AS INTEGER) AS [b] FROM foo""" + expected_sql = "SELECT CAST([foo].[a] AS DATETIMEOFFSET) AS [a], CAST([foo].[b] AS INTEGER) AS [b] FROM foo" schema = MappingSchema( { "foo": { @@ -110,7 +123,7 @@ def test_star(assert_exp_eq) -> None: evaluator = MacroEvaluator(schema=schema, dialect="tsql") assert_exp_eq(evaluator.transform(parse_one(sql, read="tsql")), expected_sql, dialect="tsql") - sql = """SELECT @STAR(foo, exclude := [SomeColumn]) FROM foo""" + sql = "SELECT @STAR(foo, exclude := [SomeColumn]) FROM foo" expected_sql = "SELECT CAST(`foo`.`a` AS STRING) AS `a` FROM foo" schema = MappingSchema( { @@ -128,6 +141,72 @@ def test_star(assert_exp_eq) -> None: dialect="databricks", ) + sql = "SELECT @STAR(foo, exclude := ARRAY(b)) FROM foo" + expected_sql = "SELECT [foo].[a] AS [a] FROM foo" + schema = MappingSchema( + { + "foo": { + "a": exp.DataType.build("unknown"), + "b": "int", + }, + }, + dialect="tsql", + ) + evaluator = MacroEvaluator(schema=schema, 
dialect="tsql") + assert_exp_eq(evaluator.transform(parse_one(sql, read="tsql")), expected_sql, dialect="tsql") + + sql = """SELECT @STAR(foo) FROM foo""" + expected_sql = ( + """SELECT CAST("FOO"."A" AS DATE) AS "A", CAST("FOO"."B" AS INTEGER) AS "B" FROM foo""" + ) + schema = MappingSchema( + { + "foo": { + "a": exp.DataType.build("date", dialect="snowflake"), + "b": "int", + }, + }, + dialect="snowflake", + ) + evaluator = MacroEvaluator(schema=schema, dialect="snowflake") + assert_exp_eq( + evaluator.transform(parse_one(sql, read="snowflake")), expected_sql, dialect="snowflake" + ) + + sql = """SELECT @STAR("foo") FROM "foo" """ + expected_sql = ( + """SELECT CAST("foo"."A" AS DATE) AS "A", CAST("foo"."B" AS INTEGER) AS "B" FROM "foo" """ + ) + schema = MappingSchema( + { + '"foo"': { + "a": exp.DataType.build("date", dialect="snowflake"), + "b": "int", + }, + }, + dialect="snowflake", + ) + evaluator = MacroEvaluator(schema=schema, dialect="snowflake") + assert_exp_eq( + evaluator.transform(parse_one(sql, read="snowflake")), expected_sql, dialect="snowflake" + ) + + sql = """SELECT @STAR(foo, alias := "bar") FROM foo "bar" """ + expected_sql = """SELECT CAST("bar"."A" AS DATE) AS "A", CAST("bar"."B" AS INTEGER) AS "B" FROM foo "bar" """ + schema = MappingSchema( + { + "foo": { + "a": exp.DataType.build("date", dialect="snowflake"), + "b": "int", + }, + }, + dialect="snowflake", + ) + evaluator = MacroEvaluator(schema=schema, dialect="snowflake") + assert_exp_eq( + evaluator.transform(parse_one(sql, read="snowflake")), expected_sql, dialect="snowflake" + ) + def test_start_no_column_types(assert_exp_eq) -> None: sql = """SELECT @STAR(foo) FROM foo""" @@ -155,7 +234,9 @@ def test_macro_var(macro_evaluator): # Check Snowflake-specific StagedFilePath / MacroVar behavior e = parse_one("select @x from @path, @y", dialect="snowflake") + macro_evaluator.locals = {"x": parse_one("a"), "y": parse_one("t2")} + macro_evaluator.dialect = "snowflake" assert 
e.find(StagedFilePath) is not None assert macro_evaluator.transform(e).sql(dialect="snowflake") == "SELECT a FROM @path, t2" @@ -171,6 +252,10 @@ def test_macro_var(macro_evaluator): # Parsing a "parameter" like Snowflake's $1 should not produce a MacroVar expression e = parse_one("select $1 from @path (file_format => bla.foo)", read="snowflake") assert e.find(exp.Parameter) is e.selects[0] + assert e.find(StagedFilePath) + # test no space + e = parse_one("select $1 from @path(file_format => bla.foo)", read="snowflake") + assert e.find(StagedFilePath) assert e.sql(dialect="snowflake") == "SELECT $1 FROM @path (FILE_FORMAT => bla.foo)" macro_evaluator.locals = {"x": 1} @@ -207,6 +292,16 @@ def test_ast_correctness(macro_evaluator): "SELECT 'a' + a_z + 'c' + c_a, 'b' + b_z + 'c' + c_b", {"y": "c"}, ), + ( + """select @each(['a'], x -> @X)""", + "SELECT 'a'", + {}, + ), + ( + """select @each(['a'], X -> @x)""", + "SELECT 'a'", + {}, + ), ( '"is_@{x}"', '"is_b"', @@ -278,11 +373,24 @@ def test_ast_correctness(macro_evaluator): "SELECT column LIKE a OR column LIKE b OR column LIKE c", {}, ), + ("SELECT @REDUCE([1], (x, y) -> x + y)", "SELECT 1", {}), + ("SELECT @REDUCE([1, 2], (x, y) -> x + y)", "SELECT 1 + 2", {}), + ("SELECT @REDUCE([[1]], (x, y) -> x + y)", "SELECT ARRAY(1)", {}), + ("SELECT @REDUCE([[1, 2]], (x, y) -> x + y)", "SELECT ARRAY(1, 2)", {}), ( """select @EACH([a, b, c], x -> column like x AS @SQL('@{x}_y', 'Identifier')), @x""", "SELECT column LIKE a AS a_y, column LIKE b AS b_y, column LIKE c AS c_y, '3'", {"x": "3"}, ), + ("SELECT @EACH([1], a -> [@a])", "SELECT ARRAY(1)", {}), + ("SELECT @EACH([1, 2], a -> [@a])", "SELECT ARRAY(1), ARRAY(2)", {}), + ("SELECT @REDUCE(@EACH([1], a -> [@a]), (x, y) -> x + y)", "SELECT ARRAY(1)", {}), + ( + "SELECT @REDUCE(@EACH([1, 2], a -> [@a]), (x, y) -> x + y)", + "SELECT ARRAY(1) + ARRAY(2)", + {}, + ), + ("SELECT @REDUCE([[1],[2]], (x, y) -> x + y)", "SELECT ARRAY(1) + ARRAY(2)", {}), ( """@WITH(@do_with) 
all_cities as (select * from city) select all_cities""", "WITH all_cities AS (SELECT * FROM city) SELECT all_cities", @@ -490,6 +598,26 @@ def test_ast_correctness(macro_evaluator): "SELECT 3", {}, ), + ( + "SELECT * FROM (VALUES @EACH([1, 2, 3], v -> (v)) ) AS v", + "SELECT * FROM (VALUES (1), (2), (3)) AS v", + {}, + ), + ( + "SELECT * FROM (VALUES (@EACH([1, 2, 3], v -> (v))) ) AS v", + "SELECT * FROM (VALUES ((1), (2), (3))) AS v", + {}, + ), + ( + "SELECT * FROM (VALUES @EACH([1, 2, 3], v -> (v, @EVAL(@v + 1))) ) AS v", + "SELECT * FROM (VALUES (1, 2), (2, 3), (3, 4)) AS v", + {}, + ), + ( + "SELECT * FROM (VALUES (@EACH([1, 2, 3], v -> (v, @EVAL(@v + 1)))) ) AS v", + "SELECT * FROM (VALUES ((1, 2), (2, 3), (3, 4))) AS v", + {}, + ), ], ) def test_macro_functions(macro_evaluator: MacroEvaluator, assert_exp_eq, sql, expected, args): @@ -513,6 +641,8 @@ def test_macro_coercion(macro_evaluator: MacroEvaluator, assert_exp_eq): assert coerce(exp.Literal.number(1.1), float) == 1.1 assert coerce(exp.Literal.string("Hi mom"), str) == "Hi mom" assert coerce(exp.true(), bool) is True + assert coerce(exp.Literal.string("2020-01-01"), datetime) == to_datetime("2020-01-01") + assert coerce(exp.Literal.string("2020-01-01"), date) == to_date("2020-01-01") # Coercing a string literal to a column should return a column with the same name assert_exp_eq(coerce(exp.Literal.string("order"), exp.Column), exp.column("order")) @@ -593,35 +723,26 @@ def test_positional_follows_kwargs(macro_evaluator): def test_macro_parameter_resolution(macro_evaluator): - with pytest.raises(MacroEvalError) as e: + with pytest.raises(MacroEvalError, match=".*missing a required argument: 'pos_only'"): macro_evaluator.evaluate(parse_one("@test_arg_resolution()")) - assert str(e.value.__cause__) == "missing a required argument: 'pos_only'" - with pytest.raises(MacroEvalError) as e: + with pytest.raises(MacroEvalError, match=".*missing a required argument: 'pos_only'"): 
macro_evaluator.evaluate(parse_one("@test_arg_resolution(a1 := 1)")) - assert str(e.value.__cause__) == "missing a required argument: 'pos_only'" - with pytest.raises(MacroEvalError) as e: + with pytest.raises(MacroEvalError, match=".*missing a required argument: 'a1'"): macro_evaluator.evaluate(parse_one("@test_arg_resolution(1)")) - assert str(e.value.__cause__) == "missing a required argument: 'a1'" - with pytest.raises(MacroEvalError) as e: + with pytest.raises(MacroEvalError, match=".*missing a required argument: 'a1'"): macro_evaluator.evaluate(parse_one("@test_arg_resolution(1, a2 := 2)")) - assert str(e.value.__cause__) == "missing a required argument: 'a1'" - with pytest.raises(MacroEvalError) as e: + with pytest.raises( + MacroEvalError, + match=".*'pos_only' parameter is positional only, but was passed as a keyword|.*missing a required positional-only argument: 'pos_only'|.*missing a required argument: 'a1'", + ): macro_evaluator.evaluate(parse_one("@test_arg_resolution(pos_only := 1)")) - # The CI was failing for Python 3.12 with the latter message, but other versions fail - # with the former one. This ensures we capture both. 
- assert str(e.value.__cause__) in ( - "'pos_only' parameter is positional only, but was passed as a keyword", - "missing a required argument: 'a1'", - ) - - with pytest.raises(MacroEvalError) as e: + with pytest.raises(MacroEvalError, match=".*too many positional arguments"): macro_evaluator.evaluate(parse_one("@test_arg_resolution(1, 2, 3)")) - assert str(e.value.__cause__) == "too many positional arguments" def test_macro_metadata_flag(): @@ -659,3 +780,389 @@ def test_macro_first_value_ignore_respect_nulls(assert_exp_eq) -> None: "SELECT FIRST_VALUE(@test(x) RESPECT NULLS) OVER (ORDER BY y) AS column_test" ) assert_exp_eq(evaluator.transform(actual_expr), expected_sql, dialect="duckdb") + + +DEDUPLICATE_SQL = """ +@deduplicate( + my_table, + [user_id, CAST(timestamp AS DATE)], + ['timestamp DESC', 'status ASC nulls last'] +) +""" + + +@pytest.mark.parametrize( + "dialect, sql, expected_sql", + [ + *[ + ( + dialect, + DEDUPLICATE_SQL, + """ + SELECT * + FROM my_table + QUALIFY ROW_NUMBER() OVER ( + PARTITION BY user_id, CAST(timestamp AS DATE) + ORDER BY timestamp DESC, status ASC NULLS LAST + ) = 1 + """, + ) + for dialect in ["bigquery", "databricks", "snowflake", "duckdb"] + ], + ( + "redshift", + DEDUPLICATE_SQL, + """ + SELECT * + FROM my_table + QUALIFY ROW_NUMBER() OVER ( + PARTITION BY user_id, CAST("timestamp" AS DATE) + ORDER BY "timestamp" DESC, status ASC NULLS LAST + ) = 1 + """, + ), + *[ + ( + dialect, + DEDUPLICATE_SQL, + """ + SELECT * + FROM ( + SELECT *, ROW_NUMBER() OVER ( + PARTITION BY user_id, CAST(timestamp AS DATE) + ORDER BY timestamp DESC, status ASC NULLS LAST + ) AS _w + FROM my_table + ) as _t + WHERE _w = 1 + """, + ) + for dialect in ["trino", "postgres"] + ], + ], +) +def test_deduplicate(assert_exp_eq, dialect, sql, expected_sql): + schema = MappingSchema({}, dialect=dialect) + evaluator = MacroEvaluator(schema=schema, dialect=dialect) + assert_exp_eq(evaluator.transform(parse_one(sql)), expected_sql, dialect=dialect) + + +def 
test_deduplicate_error_handling(macro_evaluator): + # Test error handling: non-list partition_by + with pytest.raises( + SQLMeshError, + match="partition_by must be a list of columns: \\[, cast\\( as \\)\\]", + ): + macro_evaluator.evaluate(parse_one("@deduplicate(my_table, user_id, ['timestamp DESC'])")) + + # Test error handling: non-list order_by + with pytest.raises( + SQLMeshError, + match="order_by must be a list of strings, optional - nulls ordering: \\[' nulls '\\]", + ): + macro_evaluator.evaluate(parse_one("@deduplicate(my_table, [user_id], 'timestamp DESC')")) + + # Test error handling: empty order_by + with pytest.raises( + SQLMeshError, + match="order_by must be a list of strings, optional - nulls ordering: \\[' nulls '\\]", + ): + macro_evaluator.evaluate(parse_one("@deduplicate(my_table, [user_id], [])")) + + +@pytest.mark.parametrize( + "dialect, date_part", + [ + (dialect, date_part) + for dialect in [ + "duckdb", + "snowflake", + "postgres", + "spark", + "bigquery", + "databricks", + "redshift", + ] + for date_part in ["day", "week", "month", "quarter", "year"] + ], +) +def test_date_spine(assert_exp_eq, dialect, date_part): + date_spine_macro = f""" + @date_spine( + '{date_part}', + '2022-01-01', + '2024-12-31' + ) + """ + schema = MappingSchema({}, dialect=dialect) + evaluator = MacroEvaluator(schema=schema, dialect=dialect) + + # Generate the expected SQL based on the dialect and date_part + if dialect == "duckdb": + interval = f"INTERVAL '1' {date_part.upper()}" + expected_sql = f""" + SELECT + date_{date_part} + FROM + UNNEST( + CAST( + GENERATE_SERIES( + CAST('2022-01-01' AS DATE), + CAST('2024-12-31' AS DATE), + {interval} + ) AS DATE[] + ) + ) AS _exploded(date_{date_part}) + """ + elif dialect == "snowflake": + expected_sql = f""" + SELECT + date_{date_part} + FROM ( + SELECT + DATEADD( + {date_part.upper()}, + CAST(date_{date_part} AS INT), + CAST('2022-01-01' AS DATE) + ) AS date_{date_part} + FROM + TABLE( + FLATTEN( + INPUT => 
ARRAY_GENERATE_RANGE( + 0, + ( + DATEDIFF( + {date_part.upper()}, + CAST('2022-01-01' AS DATE), + CAST('2024-12-31' AS DATE) + ) + 1 - 1 + ) + 1 + ) + ) + ) AS _exploded(seq, key, path, index, date_{date_part}, this) + ) AS _exploded(date_{date_part}) + """ + elif dialect == "postgres": + interval = "3 MONTH" if date_part == "quarter" else f"1 {date_part.upper()}" + expected_sql = f""" + SELECT + date_{date_part} + FROM ( + SELECT + CAST(value AS DATE) + FROM + GENERATE_SERIES( + CAST('2022-01-01' AS DATE), + CAST('2024-12-31' AS DATE), + INTERVAL '{interval}' + ) AS _t(value) + ) AS _exploded(date_{date_part}) + """ + elif dialect == "spark": + interval = "3 MONTH" if date_part == "quarter" else f"1 {date_part.upper()}" + expected_sql = f""" + SELECT + date_{date_part} + FROM + EXPLODE( + SEQUENCE( + CAST('2022-01-01' AS DATE), + CAST('2024-12-31' AS DATE), + INTERVAL {interval} + ) + ) AS _exploded(date_{date_part}) + """ + elif dialect == "databricks": + interval = "3 MONTH" if date_part == "quarter" else f"1 {date_part.upper()}" + expected_sql = f""" + SELECT + date_{date_part} + FROM + EXPLODE( + SEQUENCE( + CAST('2022-01-01' AS DATE), + CAST('2024-12-31' AS DATE), + INTERVAL {interval} + ) + ) AS _exploded(date_{date_part}) + """ + elif dialect == "bigquery": + expected_sql = f""" + SELECT + date_{date_part} + FROM + UNNEST( + GENERATE_DATE_ARRAY( + CAST('2022-01-01' AS DATE), + CAST('2024-12-31' AS DATE), + INTERVAL '1' {date_part} + ) + ) AS date_{date_part} + """ + elif dialect == "redshift": + expected_sql = f""" + WITH RECURSIVE _generated_dates(date_{date_part}) AS ( + SELECT + CAST('2022-01-01' AS DATE) AS date_{date_part} + UNION ALL + SELECT + CAST(DATEADD({date_part}, 1, date_{date_part}) AS DATE) + FROM _generated_dates + WHERE + CAST(DATEADD({date_part}, 1, date_{date_part}) AS DATE) <= CAST('2024-12-31' AS DATE) + ) + SELECT + date_{date_part} + FROM ( + SELECT + date_{date_part} + FROM _generated_dates + ) AS _generated_dates + """ + 
assert_exp_eq(evaluator.transform(parse_one(date_spine_macro)), expected_sql, dialect=dialect) + + +def test_date_spine_error_handling(macro_evaluator): + # Test error handling: invalid datepart + with pytest.raises( + MacroEvalError, + match=".*Invalid datepart 'invalid'. Expected: 'day', 'week', 'month', 'quarter', or 'year'", + ): + macro_evaluator.evaluate(parse_one("@date_spine('invalid', '2022-01-01', '2024-12-31')")) + + # Test error handling: invalid start_date format + with pytest.raises( + MacroEvalError, + match=".*Invalid date format - start_date and end_date must be in format: YYYY-MM-DD", + ): + macro_evaluator.evaluate(parse_one("@date_spine('day', '2022/01/01', '2024-12-31')")) + + # Test error handling: invalid end_date format + with pytest.raises( + MacroEvalError, + match=".*Invalid date format - start_date and end_date must be in format: YYYY-MM-DD", + ): + macro_evaluator.evaluate(parse_one("@date_spine('day', '2022-01-01', '2024/12/31')")) + + # Test error handling: start_date after end_date + with pytest.raises( + MacroEvalError, + match=".*Invalid date range - start_date '2024-12-31' is after end_date '2022-01-01'.", + ): + macro_evaluator.evaluate(parse_one("@date_spine('day', '2024-12-31', '2022-01-01')")) + + +def test_macro_union(assert_exp_eq, macro_evaluator: MacroEvaluator): + sql = "SELECT 1 AS col UNION ALL @TEST_SELECT_MACRO()" + expected_sql = "SELECT 1 AS col UNION ALL SELECT 1 AS col" + + assert_exp_eq(macro_evaluator.transform(parse_one(sql)), expected_sql) + + +def test_resolve_template_literal(): + parsed_sql = parse_one( + "@resolve_template('s3://data-bucket/prod/@{catalog_name}/@{schema_name}/@{table_name}')" + ) + + # Loading + # During loading, this should passthrough / no-op + # This is because SQLMesh renders everything on load to figure out model dependencies and we dont want to throw an error + evaluator = MacroEvaluator(runtime_stage=RuntimeStage.LOADING) + assert evaluator.transform(parsed_sql) == 
exp.Literal.string( + "s3://data-bucket/prod/@{catalog_name}/@{schema_name}/@{table_name}" + ) + + # Creating + # This macro can work during creating / evaluating but only if @this_model is present in the context + evaluator = MacroEvaluator(runtime_stage=RuntimeStage.CREATING) + with pytest.raises(MacroEvalError, match=".*this_model must be present"): + evaluator.transform(parsed_sql) + + evaluator.locals.update( + {"this_model": exp.to_table("test_catalog.sqlmesh__test.test__test_model__2517971505")} + ) + + assert ( + evaluator.transform(parsed_sql).sql() + == "'s3://data-bucket/prod/test_catalog/sqlmesh__test/test__test_model__2517971505'" + ) + + # Evaluating + evaluator = MacroEvaluator(runtime_stage=RuntimeStage.EVALUATING) + evaluator.locals.update( + {"this_model": exp.to_table("test_catalog.sqlmesh__test.test__test_model__2517971505")} + ) + assert ( + evaluator.transform(parsed_sql).sql() + == "'s3://data-bucket/prod/test_catalog/sqlmesh__test/test__test_model__2517971505'" + ) + + +def test_resolve_template_table(): + parsed_sql = parse_one( + "SELECT * FROM @resolve_template('@{catalog_name}.@{schema_name}.@{table_name}$partitions', mode := 'table')" + ) + + evaluator = MacroEvaluator(runtime_stage=RuntimeStage.CREATING) + evaluator.locals.update( + {"this_model": exp.to_table("test_catalog.sqlmesh__test.test__test_model__2517971505")} + ) + + assert ( + evaluator.transform(parsed_sql).sql(identify=True) + == 'SELECT * FROM "test_catalog"."sqlmesh__test"."test__test_model__2517971505$partitions"' + ) + + +def test_macro_with_spaces(): + evaluator = MacroEvaluator() + evaluator.evaluate(d.parse_one(""" @DEF(x, "a b") """)) + evaluator.evaluate(d.parse_one(""" @DEF(y, 'a b') """)) + evaluator.evaluate(d.parse_one(""" @DEF(z, a."b c") """)) + + for sql, expected in ( + ("@x", '"a b"'), + ("@X", '"a b"'), + ("@{x}", '"a b"'), + ("@{X}", '"a b"'), + ("a_@x", '"a_a b"'), + ("a.@x", 'a."a b"'), + ("@y", "'a b'"), + ("@{y}", '"a b"'), # a little tricky here as 
it's not a string + ("a_@y", '"a_a b"'), + ("a.@{y}", 'a."a b"'), + ("@z", 'a."b c"'), + ("d.@z", 'd.a."b c"'), + ("@'test_@{X}_suffix'", "'test_a b_suffix'"), + ): + assert evaluator.transform(parse_one(sql)).sql() == expected + + +def test_macro_coerce_literal_type(macro_evaluator): + expression = d.parse_one("@TEST_LITERAL_TYPE('test_literal_a')") + assert macro_evaluator.transform(expression).sql() == "'test_literal_a'" + + expression = d.parse_one("@TEST_LITERAL_TYPE('test_literal_b')") + assert macro_evaluator.transform(expression).sql() == "'test_literal_b'" + + expression = d.parse_one("@TEST_LITERAL_TYPE(1)") + assert macro_evaluator.transform(expression).sql() == "'1'" + + expression = d.parse_one("@TEST_LITERAL_TYPE(True)") + assert macro_evaluator.transform(expression).sql() == "'True'" + + expression = d.parse_one("@TEST_LITERAL_TYPE('test_literal_c')") + with pytest.raises(MacroEvalError, match=".*Coercion failed"): + macro_evaluator.transform(expression) + + expression = d.parse_one("@TEST_LITERAL_TYPE(2)") + with pytest.raises(MacroEvalError, match=".*Coercion failed"): + macro_evaluator.transform(expression) + + expression = d.parse_one("@TEST_LITERAL_TYPE(False)") + with pytest.raises(MacroEvalError, match=".*Coercion failed"): + macro_evaluator.transform(expression) + + expression = d.parse_one("@TEST_LITERAL_TYPE(1.0)") + with pytest.raises(MacroEvalError, match=".*Coercion failed"): + macro_evaluator.transform(expression) diff --git a/tests/core/test_model.py b/tests/core/test_model.py index bbe73352ac..cfcb843739 100644 --- a/tests/core/test_model.py +++ b/tests/core/test_model.py @@ -1,27 +1,43 @@ # ruff: noqa: F811 import json -import logging import typing as t +import re from datetime import date, datetime from pathlib import Path -from unittest.mock import patch +from unittest.mock import patch, PropertyMock -import pandas as pd +import time_machine +import pandas as pd # noqa: TID253 import pytest from pytest_mock.plugin import 
MockerFixture from sqlglot import exp, parse_one +from sqlglot.errors import ParseError from sqlglot.schema import MappingSchema -from sqlmesh.cli.example_project import init_example_project +from sqlmesh.cli.project_init import init_example_project, ProjectTemplate +from sqlmesh.core.environment import EnvironmentNamingInfo +from sqlmesh.core.model.kind import TimeColumn, ModelKindName, SeedKind +from sqlmesh import CustomMaterialization, CustomKind +from pydantic import model_validator, ValidationError from sqlmesh.core import constants as c from sqlmesh.core import dialect as d +from sqlmesh.core.console import get_console +from sqlmesh.core.audit import ModelAudit, load_audit +from sqlmesh.core.model.common import ParsableSql from sqlmesh.core.config import ( Config, + DuckDBConnectionConfig, + GatewayConfig, NameInferenceConfig, ModelDefaultsConfig, + LinterConfig, ) +from sqlmesh.core import constants as c from sqlmesh.core.context import Context, ExecutionContext from sqlmesh.core.dialect import parse +from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS +from sqlmesh.core.engine_adapter.duckdb import DuckDBEngineAdapter +from sqlmesh.core.engine_adapter.shared import DataObjectType from sqlmesh.core.macros import MacroEvaluator, macro from sqlmesh.core.model import ( CustomKind, @@ -29,27 +45,35 @@ FullKind, IncrementalByTimeRangeKind, IncrementalUnmanagedKind, + IncrementalByUniqueKeyKind, ModelCache, ModelMeta, SeedKind, SqlModel, TimeColumn, + ExternalKind, ViewKind, + EmbeddedKind, + SCDType2ByTimeKind, create_external_model, create_seed_model, create_sql_model, load_sql_based_model, + load_sql_based_models, model, ) from sqlmesh.core.model.common import parse_expression -from sqlmesh.core.model.kind import ModelKindName, _model_kind_validator +from sqlmesh.core.model.kind import _ModelKind, ModelKindName, _model_kind_validator from sqlmesh.core.model.seed import CsvSettings -from sqlmesh.core.node import IntervalUnit, 
_Node +from sqlmesh.core.node import IntervalUnit, _Node, DbtNodeInfo +from sqlmesh.core.signal import signal from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory from sqlmesh.utils.date import TimeLike, to_datetime, to_ds, to_timestamp -from sqlmesh.utils.errors import ConfigError, SQLMeshError -from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroInfo -from sqlmesh.utils.metaprogramming import Executable +from sqlmesh.utils.errors import ConfigError, SQLMeshError, LinterError +from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroInfo, MacroExtractor +from sqlmesh.utils.metaprogramming import Executable, SqlValue +from sqlmesh.core.macros import RuntimeStage +from tests.utils.test_helpers import use_terminal_console def missing_schema_warning_msg(model, deps): @@ -110,9 +134,10 @@ def test_load(assert_exp_eq): assert model.name == "db.table" assert model.owner == "owner_name" assert model.dialect == "spark" + assert model.table_format is None assert model.storage_format == "iceberg" assert [col.sql() for col in model.partitioned_by] == ['"a"', '"d"'] - assert model.clustered_by == ["e"] + assert [col.sql() for col in model.clustered_by] == ['"e"'] assert model.columns_to_types == { "a": exp.DataType.build("int"), "b": exp.DataType.build("double"), @@ -177,14 +202,48 @@ def test_model_multiple_select_statements(): load_sql_based_model(expressions) -@pytest.mark.parametrize( - "query, error", - [ - ("y::int, x::int AS y", "duplicate"), - ("* FROM db.table", "require inferrable column types"), - ], -) -def test_model_validation(query, error): +def test_model_validation(tmp_path): + expressions = d.parse( + f""" + MODEL ( + name db.table, + kind FULL, + ); + + SELECT + y::int, + x::int AS y + FROM db.ext + """ + ) + + ctx = Context( + config=Config(linter=LinterConfig(enabled=True, rules=["noambiguousprojections"])), + paths=tmp_path, + ) + ctx.upsert_model(load_sql_based_model(expressions, default_catalog="memory")) + + errors = 
ctx.lint_models(["db.table"], raise_on_error=False) + assert errors, "Expected NoAmbiguousProjections violation" + assert errors[0].violation_msg == "Found duplicate outer select name 'y'" + + expressions = d.parse( + """ + MODEL ( + name db.table, + kind FULL, + ); + + SELECT a, a UNION SELECT c, c + """ + ) + + ctx.upsert_model(load_sql_based_model(expressions, default_catalog="memory")) + + errors = ctx.lint_models(["db.table"], raise_on_error=False) + assert errors, "Expected NoAmbiguousProjections violation" + assert errors[0].violation_msg == "Found duplicate outer select name 'a'" + expressions = d.parse( f""" MODEL ( @@ -192,14 +251,15 @@ def test_model_validation(query, error): kind FULL, ); - SELECT {query} + SELECT * FROM db.table """ ) model = load_sql_based_model(expressions) with pytest.raises(ConfigError) as ex: model.validate_definition() - assert error in str(ex.value) + + assert "require inferrable column types" in str(ex.value) def test_model_union_query(sushi_context, assert_exp_eq): @@ -232,7 +292,7 @@ def test_model_union_query(sushi_context, assert_exp_eq): """SELECT CAST("marketing"."customer_id" AS INT) AS "customer_id", CAST("marketing"."status" AS TEXT) AS "status", - CAST("marketing"."updated_at" AS TIMESTAMP) AS "updated_at", + CAST("marketing"."updated_at" AS TIMESTAMPNTZ) AS "updated_at", CAST("marketing"."valid_from" AS TIMESTAMP) AS "valid_from", CAST("marketing"."valid_to" AS TIMESTAMP) AS "valid_to" FROM "memory"."sushi"."marketing" AS "marketing" @@ -240,7 +300,7 @@ def test_model_union_query(sushi_context, assert_exp_eq): SELECT CAST("marketing"."customer_id" AS INT) AS "customer_id", CAST("marketing"."status" AS TEXT) AS "status", - CAST("marketing"."updated_at" AS TIMESTAMP) AS "updated_at", + CAST("marketing"."updated_at" AS TIMESTAMPNTZ) AS "updated_at", CAST("marketing"."valid_from" AS TIMESTAMP) AS "valid_from", CAST("marketing"."valid_to" AS TIMESTAMP) AS "valid_to" FROM "memory"."sushi"."marketing" AS "marketing" @@ 
-248,26 +308,145 @@ def test_model_union_query(sushi_context, assert_exp_eq): ) -def test_model_validation_union_query(): +@time_machine.travel("1996-02-10 00:00:00 UTC") +@pytest.mark.parametrize( + "test_id, condition, union_type, table_count, expected_result", + [ + # Test case 1: Basic conditional union - True condition + ( + "test_1", + "@get_date() == '1996-02-10'", + "'all'", + 2, + lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\n", + ), + # Test case 2: False condition - should return just first table + ( + "test_2", + "@get_date() > '1996-02-10'", + "'all'", + 2, + lambda expected_select: f"{expected_select}\n", + ), + # Test case 3: Multiple tables in union + ( + "test_3", + "@get_date() == '1996-02-10'", + "'all'", + 3, + lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\nUNION ALL\n{expected_select}\n", + ), + # Test case 4: DISTINCT type + ( + "test_4", + "@get_date() == '1996-02-10'", + "'distinct'", + 2, + lambda expected_select: f"{expected_select}\nUNION\n{expected_select}\n", + ), + # Test case 5: Complex condition + ( + "test_5", + "@get_date() = '1996-02-10' and 1=1 or @get_date() > '1996-02-10'", + "'distinct'", + 2, + lambda expected_select: f"{expected_select}\nUNION\n{expected_select}\n", + ), + # Test case 6: Missing union type (defaults to ALL) + ( + "test_6", + "@get_date() == '1996-02-10'", + "", + 2, + lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\n", + ), + # Test case 7: Missing union type AND condition + ( + "test_7", + "", + "", + 2, + lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\n", + ), + # Test case 8: Missing union type AND condition multiple tables + ( + "test_8", + "", + "", + 3, + lambda expected_select: f"{expected_select}\nUNION ALL\n{expected_select}\n\nUNION ALL\n{expected_select}\n", + ), + # Test case 9: Missing union type AND condition one table + ( + "test_9", + "", + "", + 1, + lambda expected_select: 
f"{expected_select}", + ), + # Test case 10: Union type with one table + ( + "test_10", + "", + "'distinct'", + 1, + lambda expected_select: f"{expected_select}", + ), + # Test case 11: Condition with one table + ( + "test_9", + "True", + "", + 1, + lambda expected_select: f"{expected_select}", + ), + ], +) +def test_model_union_conditional( + sushi_context, assert_exp_eq, test_id, condition, union_type, table_count, expected_result +): + @macro() + def get_date(evaluator): + from sqlmesh.utils.date import now + + return f"'{now().date()}'" + + expected_select = """SELECT + CAST("marketing"."customer_id" AS INT) AS "customer_id", + CAST("marketing"."status" AS TEXT) AS "status", + CAST("marketing"."updated_at" AS TIMESTAMPNTZ) AS "updated_at", + CAST("marketing"."valid_from" AS TIMESTAMP) AS "valid_from", + CAST("marketing"."valid_to" AS TIMESTAMP) AS "valid_to" +FROM "memory"."sushi"."marketing" AS "marketing" +""" + + # Create tables argument list based on table_count + tables = ", ".join(["sushi.marketing"] * table_count) + + # Handle the missing union_type case + union_type_arg = f", {union_type}" if union_type else "" + expressions = d.parse( - """ + f""" MODEL ( - name db.table, + name sushi.{test_id}, kind FULL, ); - SELECT a, a UNION SELECT c, c + @union({condition}{union_type_arg}, {tables}) """ ) + sushi_context.upsert_model(load_sql_based_model(expressions, default_catalog="memory")) - model = load_sql_based_model(expressions) - with pytest.raises(ConfigError, match=r"Found duplicate outer select name 'a'"): - model.validate_definition() + assert_exp_eq( + sushi_context.get_model(f"sushi.{test_id}").render_query(), + expected_result(expected_select), + ) -def test_model_qualification(): - logger = logging.getLogger("sqlmesh.core.renderer") - with patch.object(logger, "warning") as mock_logger: +@use_terminal_console +def test_model_qualification(tmp_path: Path): + with patch.object(get_console(), "log_warning") as mock_logger: expressions = d.parse( """ 
MODEL ( @@ -279,14 +458,79 @@ def test_model_qualification(): """ ) - model = load_sql_based_model(expressions) - model.render_query(optimize=True) + ctx = Context( + config=Config(linter=LinterConfig(enabled=True, warn_rules=["ALL"])), paths=tmp_path + ) + ctx.upsert_model(load_sql_based_model(expressions)) + ctx.plan_builder("dev") + + warning_msg = mock_logger.call_args[0][0] + assert "ambiguousorinvalidcolumn:" in warning_msg + assert "could not be resolved" in warning_msg + assert "db.table" in warning_msg + + +@use_terminal_console +def test_model_missing_audits(tmp_path: Path): + with patch.object(get_console(), "log_warning") as mock_logger: + expressions = d.parse( + """ + MODEL ( + name db.table, + kind FULL, + ); + + SELECT a + """ + ) + + ctx = Context( + config=Config(linter=LinterConfig(enabled=True, warn_rules=["nomissingaudits"])), + paths=tmp_path, + ) + ctx.upsert_model(load_sql_based_model(expressions)) + ctx.plan_builder("plan") + assert ( - mock_logger.call_args[0][0] - == "%s for model '%s', the column may not exist or is ambiguous" + """Model `audits` must be configured to test data quality.""" + in mock_logger.call_args[0][0] ) +def test_project_is_set_in_standalone_audit(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + db_path = str(tmp_path / "db.db") + db_connection = DuckDBConnectionConfig(database=db_path) + + config = Config( + project="test", + gateways={"gw": GatewayConfig(connection=db_connection)}, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + + model = tmp_path / "models" / "some_model.sql" + model.parent.mkdir(parents=True, exist_ok=True) + model.write_text("MODEL (name m); SELECT 1 AS c") + + audit = tmp_path / "audits" / "a_standalone_audit.sql" + audit.parent.mkdir(parents=True, exist_ok=True) + audit.write_text("AUDIT (name a, standalone true); SELECT * FROM m WHERE c <= 0") + + context = Context(paths=tmp_path, config=config) + 
context.plan(no_prompts=True, auto_apply=True) + + model = tmp_path / "models" / "some_model.sql" + model.parent.mkdir(parents=True, exist_ok=True) + model.write_text("MODEL (name m); SELECT 2 AS c") + + assert context.fetchdf( + "select snapshot -> 'node' -> 'project' AS standalone_audit_project " + """from sqlmesh._snapshots where (snapshot -> 'node' -> 'source_type')::text = '"audit"'""" + ).to_dict()["standalone_audit_project"] == {0: '"test"'} + assert context.load().standalone_audits["a"].project == "test" + + @pytest.mark.parametrize( "partition_by_input, partition_by_output, output_dialect, expected_exception", [ @@ -320,7 +564,7 @@ def test_partitioned_by( ) model = load_sql_based_model(expressions) - assert model.clustered_by == ["c", "d"] + assert model.clustered_by == [exp.to_column('"c"'), exp.to_column('"d"')] if expected_exception: with pytest.raises(expected_exception): model.validate_definition() @@ -331,19 +575,161 @@ def test_partitioned_by( ] == partition_by_output -def test_no_model_statement(): +def test_opt_out_of_time_column_in_partitioned_by(): + expressions = d.parse( + """ + MODEL ( + name db.table, + dialect bigquery, + partitioned_by b, + kind INCREMENTAL_BY_TIME_RANGE( + time_column a, + partition_by_time_column false + ), + ); + + SELECT 1::int AS a, 2::int AS b; + """ + ) + + model = load_sql_based_model(expressions) + assert model.partitioned_by == [exp.to_column('"b"')] + + +def test_model_no_name(): + expressions = d.parse( + """ + MODEL ( + dialect bigquery, + ); + + SELECT 1::int AS a, 2::int AS b; + """ + ) + + with pytest.raises(ConfigError) as ex: + load_sql_based_model(expressions) + assert ( + str(ex.value) + == "Please add the required 'name' field to the MODEL block at the top of the file.\n\nLearn more at https://sqlmesh.readthedocs.io/en/stable/concepts/models/overview" + ) + + +def test_model_field_name_suggestions(): + # top-level field + expressions = d.parse( + """ + MODEL ( + name db.table, + dialects bigquery, + 
); + + SELECT 1::int AS a, 2::int AS b; + """ + ) + + with pytest.raises(ConfigError) as ex: + load_sql_based_model(expressions) + assert ( + str(ex.value) + == "Invalid field name present in the MODEL block: 'dialects'. Did you mean 'dialect'?" + ) + + # kind field + expressions = d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_TIME_RANGE( + time_column a, + batch_sizes 1 + ), + ); + + SELECT 1::int AS a, 2::int AS b; + """ + ) + + with pytest.raises(ConfigError) as ex: + load_sql_based_model(expressions) + assert ( + str(ex.value) + == "Invalid field name present in the MODEL block 'kind INCREMENTAL_BY_TIME_RANGE' field: 'batch_sizes'. Did you mean 'batch_size'?" + ) + + # multiple fields + expressions = d.parse( + """ + MODEL ( + name db.table, + dialects bigquery, + descriptions 'a', + asdfasdf true + ); + + SELECT 1::int AS a, 2::int AS b; + """ + ) + + with pytest.raises(ConfigError) as ex: + load_sql_based_model(expressions) + ex_str = str(ex.value) + # field order is non-deterministic, so we can't test the output string directly + assert "Invalid field names present in the MODEL block: " in ex_str + assert "'descriptions'" in ex_str + assert "'dialects'" in ex_str + assert "'asdfasdf'" in ex_str + assert "- descriptions: Did you mean 'description'?" in ex_str + assert "- dialects: Did you mean 'dialect'?" in ex_str + assert "- asdfasdf: Did you mean " not in ex_str + + +def test_model_required_field_missing(): expressions = d.parse( """ - SELECT 1 AS x + MODEL ( + name db.table, + kind INCREMENTAL_BY_TIME_RANGE (), + ); + + SELECT 1::int AS a, 2::int AS b; """ ) + with pytest.raises(ConfigError) as ex: + load_sql_based_model(expressions) + assert ( + str(ex.value) + == "Please add required field 'time_column' to the MODEL block 'kind INCREMENTAL_BY_TIME_RANGE' field." + ) + + +def test_no_model_statement(tmp_path: Path): + # No name inference => MODEL (...) 
is required + expressions = d.parse("SELECT 1 AS x") with pytest.raises( ConfigError, - match="MODEL statement is required as the first statement in the definition at '.", + match="Please add a MODEL block at the top of the file. Example:", ): load_sql_based_model(expressions) + # Name inference is enabled => MODEL (...) not required + init_example_project(tmp_path, engine_type="duckdb") + + test_sql_file = tmp_path / "models/test_schema/test_model.sql" + test_sql_file.parent.mkdir(parents=True, exist_ok=True) + test_sql_file.write_text("SELECT 1 AS c") + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + model_naming=NameInferenceConfig(infer_names=True), + ) + context = Context(paths=tmp_path, config=config) + + model = context.get_model("test_schema.test_model") + assert isinstance(model, SqlModel) + assert model.name == "test_schema.test_model" + def test_unordered_model_statements(): expressions = d.parse( @@ -360,7 +746,7 @@ def test_unordered_model_statements(): with pytest.raises(ConfigError) as ex: load_sql_based_model(expressions) - assert "MODEL statement is required" in str(ex.value) + assert "Please add a MODEL block at the top of the file. Example:" in str(ex.value) def test_no_query(): @@ -373,12 +759,40 @@ def test_no_query(): ); @DEF(x, 1) - """ + """ ) with pytest.raises(ConfigError) as ex: - load_sql_based_model(expressions, path=Path("test_location")) - assert "have a SELECT" in str(ex.value) + model = load_sql_based_model(expressions, path=Path("test_location")) + model.validate_definition() + + assert "Model query needs to be a SELECT or a UNION, got @DEF(x, 1)." 
in str(ex.value) + + +def test_single_macro_as_query(assert_exp_eq): + @macro() + def select_query(evaluator, *projections): + return exp.select(*[f'{p} AS "{p}"' for p in projections]) + + expressions = d.parse( + """ + MODEL ( + name test + ); + + @SELECT_QUERY(1, 2, 3) + """ + ) + model = load_sql_based_model(expressions) + assert_exp_eq( + model.render_query(), + """ + SELECT + 1 AS "1", + 2 AS "2", + 3 AS "3" + """, + ) def test_partition_key_is_missing_in_query(): @@ -484,7 +898,29 @@ def test_json_serde(): deserialized_model = SqlModel.parse_raw(model_json) - assert deserialized_model == model + assert deserialized_model.dict() == model.dict() + + expressions = parse( + """ + MODEL ( + name test_model, + kind FULL, + dialect duckdb, + ); + + SELECT + x ~ y AS c + """ + ) + + model = load_sql_based_model(expressions) + model_json = model.json() + model_json_parsed = json.loads(model.json()) + + assert ( + SqlModel.parse_obj(model_json_parsed).render_query().sql("duckdb") + == 'SELECT REGEXP_FULL_MATCH("x", "y") AS "c"' + ) def test_scd_type_2_by_col_serde(): @@ -513,7 +949,7 @@ def test_scd_type_2_by_col_serde(): model_json_parsed = json.loads(model.json()) assert model_json_parsed["kind"]["dialect"] == "bigquery" assert model_json_parsed["kind"]["unique_key"] == ["`a`"] - assert model_json_parsed["kind"]["columns"] == "*" + assert model_json_parsed["kind"]["columns"] == ["*"] # Bigquery converts TIMESTAMP -> DATETIME assert model_json_parsed["kind"]["time_data_type"] == "DATETIME" @@ -585,6 +1021,8 @@ def test_macro(**kwargs) -> None: def test_model_pre_post_statements(): + macro.registry().pop("foo", None) + @macro() def foo(**kwargs) -> None: pass @@ -618,14 +1056,72 @@ def foo(**kwargs) -> None: ] assert model.pre_statements == expected_pre - expected_post = [ - *d.parse("@foo(bar='x', val=@this)"), - *d.parse("DROP TABLE x2;"), - ] + expected_post = d.parse("@foo(bar='x', val=@this); DROP TABLE x2;") assert model.post_statements == expected_post assert 
model.query == d.parse("SELECT 1 AS x")[0] + @macro() + def multiple_statements(evaluator, t1_value=exp.Literal.number(1)): + return [f"CREATE TABLE t1 AS SELECT {t1_value} AS c", "CREATE TABLE t2 AS SELECT 2 AS c"] + + expressions = d.parse( + """ + MODEL (name db.table); + + SELECT 1 AS col; + + @multiple_statements() + """ + ) + model = load_sql_based_model(expressions) + + expected_post = d.parse( + 'CREATE TABLE "t1" AS SELECT 1 AS "c"; CREATE TABLE "t2" AS SELECT 2 AS "c"' + ) + assert model.render_post_statements() == expected_post + assert "exp" in model.python_env + + +@pytest.mark.parametrize("model_kind", ["FULL", "VIEW"]) +def test_model_pre_post_statements_start_end_are_always_available(model_kind: str): + macro.registry().pop("foo", None) + macro.registry().pop("bar", None) + + @macro() + def foo(evaluator: MacroEvaluator, start: str, end: str) -> str: + return f"'{start}, {end}'" + + @macro() + def bar(evaluator: MacroEvaluator, start: int, end: int) -> str: + return f"'{start}, {end}'" + + expressions = d.parse( + f""" + MODEL ( + name db.table, + kind {model_kind}, + ); + + @foo(@start_ds, @end_ds); + + SELECT 1 AS x; + + @bar(@start_millis, @end_millis); + """ + ) + model = load_sql_based_model(expressions) + + start = "2025-01-01" + end = "2025-01-02" + + assert model.render_pre_statements(start=start, end=end) == [ + exp.Literal.string(f"{start}, {end}") + ] + assert model.render_post_statements(start=start, end=end) == [ + exp.Literal.string(f"{to_timestamp(start)}, {to_timestamp('2025-01-03') - 1}") + ] + def test_seed_hydration(): expressions = d.parse( @@ -702,7 +1198,7 @@ def test_seed_model_creation_error(): ); """ ) - with pytest.raises(ConfigError, match="No such file or directory"): + with pytest.raises(FileNotFoundError, match="No such file or directory"): load_sql_based_model(expressions) @@ -737,6 +1233,63 @@ def test_seed_provided_columns(): } +def test_seed_case_sensitive_columns(tmp_path): + model_csv_path = (tmp_path / 
"model.csv").absolute() + + with open(model_csv_path, "w", encoding="utf-8") as fd: + fd.write( + """camelCaseId,camelCaseBool,camelCaseString,normalisedCaseDate,camelCaseTimestamp +1,false,Alice,2022-01-01,2022-01-01 +""" + ) + + expressions = d.parse( + f""" + MODEL ( + name db.seed, + dialect postgres, + kind SEED ( + path '{str(model_csv_path)}', + ), + columns ( + "camelCaseId" int, + "camelCaseBool" boolean, + "camelCaseString" text, + "camelCaseTimestamp" timestamp + ) + ); + """ + ) + + model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) + + assert isinstance(model.kind, SeedKind) + assert model.seed is not None + assert len(model.seed.content) > 0 + assert model.columns_to_types == { + "camelCaseId": exp.DataType.build("int"), + "camelCaseBool": exp.DataType.build("boolean"), + "camelCaseString": exp.DataType.build("text"), + "camelCaseTimestamp": exp.DataType.build("TIMESTAMP"), + } + df = next(model.render(context=None)) + + assert df["camelCaseId"].dtype == "int64" + assert df["camelCaseId"].iloc[0] == 1 + + assert df["camelCaseBool"].dtype == "bool" + assert not df["camelCaseBool"].iloc[0] + + assert df["camelCaseString"].dtype == "object" + assert df["camelCaseString"].iloc[0] == "Alice" + + assert df["normalisedcasedate"].dtype == "object" + assert df["normalisedcasedate"].iloc[0] == "2022-01-01" + + assert df["camelCaseTimestamp"].dtype == "datetime64[ns]" + assert df["camelCaseTimestamp"].iloc[0] == pd.Timestamp("2022-01-01 00:00:00") + + def test_seed_csv_settings(): expressions = d.parse( """ @@ -748,6 +1301,8 @@ def test_seed_csv_settings(): csv_settings ( quotechar = '''', escapechar = '\\', + keep_default_na = false, + na_values = (id = [1, '2', false, null], alias = ('foo')) ), ), columns ( @@ -761,7 +1316,39 @@ def test_seed_csv_settings(): model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) assert isinstance(model.kind, SeedKind) - assert 
model.kind.csv_settings == CsvSettings(quotechar="'", escapechar="\\") + assert model.kind.csv_settings == CsvSettings( + quotechar="'", + escapechar="\\", + na_values={"id": [1, "2", False, None], "alias": ["foo"]}, + keep_default_na=False, + ) + assert model.kind.data_hash_values == [ + "SEED", + "'", + "\\", + "{'id': [1, '2', False, None], 'alias': ['foo']}", + "False", + ] + + expressions = d.parse( + """ + MODEL ( + name db.seed, + kind SEED ( + path '../seeds/waiter_names.csv', + csv_settings ( + na_values = ('#N/A', 'other') + ), + ), + ); + """ + ) + + model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) + + assert isinstance(model.kind, SeedKind) + assert model.kind.csv_settings == CsvSettings(na_values=["#N/A", "other"]) + assert model.kind.data_hash_values == ["SEED", "['#N/A', 'other']"] def test_seed_marker_substitution(): @@ -784,12 +1371,14 @@ def test_seed_marker_substitution(): ) assert isinstance(model.kind, SeedKind) - assert model.kind.path == "examples/sushi/seeds/waiter_names.csv" + assert model.kind.path == str(Path("examples/sushi/seeds/waiter_names.csv")) assert model.seed is not None assert len(model.seed.content) > 0 def test_seed_pre_post_statements(): + macro.registry().pop("bar", None) + @macro() def bar(**kwargs) -> None: pass @@ -862,6 +1451,40 @@ def test_seed_pre_statements_only(): assert not model.post_statements +def test_seed_on_virtual_update_statements(): + expressions = d.parse( + """ + MODEL ( + name db.seed, + kind SEED ( + path '../seeds/waiter_names.csv', + batch_size 100, + ) + ); + + JINJA_STATEMENT_BEGIN; + CREATE TABLE x{{ 1 + 1 }}; + JINJA_END; + + ON_VIRTUAL_UPDATE_BEGIN; + JINJA_STATEMENT_BEGIN; + GRANT SELECT ON VIEW {{ this_model }} TO ROLE dev_role; + JINJA_END; + DROP TABLE x2; + ON_VIRTUAL_UPDATE_END; + + """ + ) + + model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) + + assert model.pre_statements == 
[d.jinja_statement("CREATE TABLE x{{ 1 + 1 }};")] + assert model.on_virtual_update == [ + d.jinja_statement("GRANT SELECT ON VIEW {{ this_model }} TO ROLE dev_role;"), + *d.parse("DROP TABLE x2;"), + ] + + def test_seed_model_custom_types(tmp_path): model_csv_path = (tmp_path / "model.csv").absolute() @@ -946,6 +1569,66 @@ def test_seed_with_special_characters_in_column(tmp_path, assert_exp_eq): ) +def test_python_model_jinja_pre_post_statements(): + macros = """ + {% macro test_macro(v) %}{{ v }}{% endmacro %} + {% macro extra_macro(v) %}{{ v + 1 }}{% endmacro %} + """ + + jinja_macros = JinjaMacroRegistry() + jinja_macros.add_macros(MacroExtractor().extract(macros)) + + @model( + "db.test_model", + kind="full", + columns={"id": "string", "name": "string"}, + pre_statements=[ + "JINJA_STATEMENT_BEGIN;\n{% set table_name = 'x' %}\nCREATE OR REPLACE TABLE {{table_name}}{{ 1 + 1 }};\nJINJA_END;" + ], + post_statements=[ + "JINJA_STATEMENT_BEGIN;\nCREATE INDEX {{test_macro('idx')}} ON db.test_model(id);\nJINJA_END;", + parse_one("DROP TABLE x2;"), + ], + ) + def model_with_statements(context, **kwargs): + return pd.DataFrame( + [ + { + "id": context.var("1"), + "name": context.var("var"), + } + ] + ) + + python_model = model.get_registry()["db.test_model"].model( + module_path=Path("."), path=Path("."), dialect="duckdb", jinja_macros=jinja_macros + ) + + assert len(jinja_macros.root_macros) == 2 + assert len(python_model.jinja_macros.root_macros) == 1 + assert "test_macro" in python_model.jinja_macros.root_macros + assert "extra_macro" not in python_model.jinja_macros.root_macros + + expected_pre = [ + d.jinja_statement( + "{% set table_name = 'x' %}\nCREATE OR REPLACE TABLE {{table_name}}{{ 1 + 1 }};" + ), + ] + assert python_model.pre_statements == expected_pre + assert python_model.render_pre_statements()[0].sql() == 'CREATE OR REPLACE TABLE "x2"' + + expected_post = [ + d.jinja_statement("CREATE INDEX {{test_macro('idx')}} ON db.test_model(id);"), + *d.parse("DROP 
TABLE x2;"), + ] + assert python_model.post_statements == expected_post + assert ( + python_model.render_post_statements()[0].sql() + == 'CREATE INDEX "idx" ON "db"."test_model"("id" NULLS LAST)' + ) + assert python_model.render_post_statements()[1].sql() == 'DROP TABLE "x2"' + + def test_audits(): expressions = d.parse( """ @@ -953,7 +1636,8 @@ def test_audits(): name db.seed, audits ( audit_a, - audit_b(key='value') + audit_b(key='value'), + audit_c(key=@start_ds) ), tags (foo) ); @@ -961,30 +1645,193 @@ def test_audits(): """ ) - model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) + audit_definitions = { + audit_name: load_audit( + d.parse(f"AUDIT (name {audit_name}); SELECT 1 WHERE FALSE"), dialect="duckdb" + ) + for audit_name in ("audit_a", "audit_b", "audit_c") + } + + model = load_sql_based_model( + expressions, + path=Path("./examples/sushi/models/test_model.sql"), + audit_definitions=audit_definitions, + ) assert model.audits == [ ("audit_a", {}), ("audit_b", {"key": exp.Literal.string("value")}), + ("audit_c", {"key": d.MacroVar(this="start_ds")}), ] assert model.tags == ["foo"] +def test_enable_audits_from_model_defaults(): + expressions = d.parse( + """ + MODEL ( + name db.audit_model, + ); + SELECT 1 as id; + + AUDIT ( + name assert_positive_order_ids, + ); + SELECT * + FROM @this_model + WHERE + id < 0; + """ + ) + + model_defaults = ModelDefaultsConfig(dialect="duckdb", audits=["assert_positive_order_ids"]) + + model = load_sql_based_model( + expressions, + path=Path("./examples/sushi/models/test_model.sql"), + defaults=model_defaults.dict(), + ) + + assert len(model.audits) == 1 + + config = Config(model_defaults=model_defaults) + assert config.model_defaults.audits[0] == ("assert_positive_order_ids", {}) == model.audits[0] + + audits_with_args = model.audits_with_args + assert len(audits_with_args) == 1 + audit, args = audits_with_args[0] + assert type(audit) == ModelAudit + assert args == {} + assert 
audit.query.sql() == "SELECT * FROM @this_model WHERE id < 0" + + def test_description(sushi_context): assert sushi_context.models['"memory"."sushi"."orders"'].description == "Table of sushi orders." +def test_model_defaults_statements_merge(): + model_defaults = ModelDefaultsConfig( + dialect="duckdb", + pre_statements=[ + "SET enable_progress_bar = true", + "CREATE TEMP TABLE default_temp AS SELECT 1", + ], + post_statements=[ + "DROP TABLE IF EXISTS default_temp", + "grant select on @this_model to group reporter", + ], + on_virtual_update=["ANALYZE"], + ) + + # Create a model with its own statements as well + expressions = parse( + """ + MODEL ( + name test_model, + kind FULL + ); + + CREATE TEMP TABLE model_temp AS SELECT 2; + + SELECT * FROM test_table; + + DROP TABLE IF EXISTS model_temp; + + ON_VIRTUAL_UPDATE_BEGIN; + UPDATE stats_table SET last_update = CURRENT_TIMESTAMP; + ON_VIRTUAL_UPDATE_END; + """ + ) + + model = load_sql_based_model( + expressions, + path=Path("./test_model.sql"), + defaults=model_defaults.dict(), + ) + + # Check that pre_statements contains both default and model-specific statements + assert len(model.pre_statements) == 3 + assert model.pre_statements[0].sql() == "SET enable_progress_bar = TRUE" + assert model.pre_statements[1].sql() == "CREATE TEMPORARY TABLE default_temp AS SELECT 1" + assert model.pre_statements[2].sql() == "CREATE TEMPORARY TABLE model_temp AS SELECT 2" + + # Check that post_statements contains both default and model-specific statements + assert len(model.post_statements) == 3 + assert model.post_statements[0].sql() == "DROP TABLE IF EXISTS default_temp" + assert model.post_statements[1].sql() == "GRANT SELECT ON @this_model TO GROUP reporter" + assert model.post_statements[2].sql() == "DROP TABLE IF EXISTS model_temp" + + # Check that the query is rendered correctly with @this_model resolved to table name + assert ( + model.render_post_statements()[1].sql() + == 'GRANT SELECT ON "test_model" TO GROUP "reporter"' 
+ ) + + # Check that on_virtual_update contains both default and model-specific statements + assert len(model.on_virtual_update) == 2 + assert model.on_virtual_update[0].sql() == "ANALYZE" + assert ( + model.on_virtual_update[1].sql() + == "UPDATE stats_table SET last_update = CURRENT_TIMESTAMP()" + ) + + +def test_model_defaults_statements_integration(): + config = Config( + model_defaults=ModelDefaultsConfig( + dialect="postgres", + pre_statements=["SET memory_limit = '10GB'"], + post_statements=["VACUUM ANALYZE"], + on_virtual_update=["GRANT SELECT ON @this_model TO GROUP public"], + ) + ) + + expressions = parse( + """ + MODEL ( + name test_model, + kind FULL + ); + + SELECT * FROM source_table; + """ + ) + + model = load_sql_based_model( + expressions, + path=Path("./test_model.sql"), + defaults=config.model_defaults.dict(), + ) + + # Verify defaults were applied + assert len(model.pre_statements) == 1 + assert model.pre_statements[0].sql() == "SET memory_limit = '10GB'" + + assert len(model.post_statements) == 1 + assert isinstance(model.post_statements[0], exp.Command) + + assert len(model.on_virtual_update) == 1 + assert model.on_virtual_update[0].sql() == "GRANT SELECT ON @this_model TO GROUP public" + assert ( + model.render_on_virtual_update()[0].sql() + == 'GRANT SELECT ON "test_model" TO GROUP "public"' + ) + + def test_render_definition(): expressions = d.parse( """ MODEL ( name db.table, owner owner_name, + cron_tz 'America/Los_Angeles', dialect spark, kind INCREMENTAL_BY_TIME_RANGE ( time_column (`a`, 'yyyymmdd'), + partition_by_time_column TRUE, forward_only FALSE, disable_restatement FALSE, - on_destructive_change 'ERROR' + on_destructive_change 'ERROR', + on_additive_change 'ALLOW' ), storage_format iceberg, partitioned_by `a`, @@ -1078,7 +1925,9 @@ def test_render_definition_with_defaults(): dialect spark, kind VIEW ( materialized FALSE - ) + ), + virtual_environment_mode 'full', + grants_target_layer 'virtual' ); {query} @@ -1091,33 +1940,309 
@@ def test_render_definition_with_defaults(): ) == d.format_model_expressions(expected_expressions) -def test_cron(): - daily = _Node(name="x", cron="@daily") - assert to_datetime(daily.cron_prev("2020-01-01")) == to_datetime("2019-12-31") - assert to_datetime(daily.cron_floor("2020-01-01")) == to_datetime("2020-01-01") - assert to_timestamp(daily.cron_floor("2020-01-01 10:00:00")) == to_timestamp("2020-01-01") - assert to_timestamp(daily.cron_next("2020-01-01 10:00:00")) == to_timestamp("2020-01-02") - interval = daily.interval_unit - assert to_datetime(interval.cron_prev("2020-01-01")) == to_datetime("2019-12-31") - assert to_datetime(interval.cron_floor("2020-01-01")) == to_datetime("2020-01-01") - assert to_timestamp(interval.cron_floor("2020-01-01 10:00:00")) == to_timestamp("2020-01-01") - assert to_timestamp(interval.cron_next("2020-01-01 10:00:00")) == to_timestamp("2020-01-02") +def test_render_definition_with_grants(): + from sqlmesh.core.model.meta import GrantsTargetLayer - offset = _Node(name="x", cron="1 0 * * *") - assert to_datetime(offset.cron_prev("2020-01-01")) == to_datetime("2019-12-31 00:01") - assert to_datetime(offset.cron_floor("2020-01-01")) == to_datetime("2019-12-31 00:01") - assert to_timestamp(offset.cron_floor("2020-01-01 10:00:00")) == to_timestamp( - "2020-01-01 00:01" + expressions = d.parse( + """ + MODEL ( + name test.grants_model, + kind FULL, + grants ( + 'select' = ['user1', 'user2'], + 'insert' = ['admin'], + 'roles/bigquery.dataViewer' = ['user:data_eng@mycompany.com'] + ), + grants_target_layer all, + ); + SELECT 1 as id + """ ) - assert to_timestamp(offset.cron_next("2020-01-01 10:00:00")) == to_timestamp("2020-01-02 00:01") - interval = offset.interval_unit - assert to_datetime(interval.cron_prev("2020-01-01")) == to_datetime("2019-12-31") - assert to_datetime(interval.cron_floor("2020-01-01")) == to_datetime("2020-01-01") - assert to_timestamp(interval.cron_floor("2020-01-01 10:00:00")) == to_timestamp("2020-01-01") - 
assert to_timestamp(interval.cron_next("2020-01-01 10:00:00")) == to_timestamp("2020-01-02") + model = load_sql_based_model(expressions) + assert model.grants_target_layer == GrantsTargetLayer.ALL + assert model.grants == { + "select": ["user1", "user2"], + "insert": ["admin"], + "roles/bigquery.dataViewer": ["user:data_eng@mycompany.com"], + } - hourly = _Node(name="x", cron="1 * * * *") - assert to_timestamp(hourly.cron_prev("2020-01-01 10:00:00")) == to_timestamp( + rendered = model.render_definition(include_defaults=True) + rendered_text = d.format_model_expressions(rendered) + assert "grants_target_layer 'all'" in rendered_text + assert re.search( + r"grants\s*\(" + r"\s*'select'\s*=\s*ARRAY\('user1',\s*'user2'\)," + r"\s*'insert'\s*=\s*ARRAY\('admin'\)," + r"\s*'roles/bigquery.dataViewer'\s*=\s*ARRAY\('user:data_eng@mycompany.com'\)" + r"\s*\)", + rendered_text, + ) + + model_with_grants = create_sql_model( + name="test_grants_programmatic", + query=d.parse_one("SELECT 1 as id"), + grants={"select": ["user1", "user2"], "insert": ["admin"]}, + grants_target_layer=GrantsTargetLayer.ALL, + ) + assert model_with_grants.grants == {"select": ["user1", "user2"], "insert": ["admin"]} + assert model_with_grants.grants_target_layer == GrantsTargetLayer.ALL + rendered_text = d.format_model_expressions( + model_with_grants.render_definition(include_defaults=True) + ) + assert "grants_target_layer 'all'" in rendered_text + assert re.search( + r"grants\s*\(" + r"\s*'select'\s*=\s*ARRAY\('user1',\s*'user2'\)," + r"\s*'insert'\s*=\s*ARRAY\('admin'\)" + r"\s*\)", + rendered_text, + ) + + virtual_expressions = d.parse( + """ + MODEL ( + name test.virtual_grants_model, + kind FULL, + grants_target_layer virtual + ); + SELECT 1 as id + """ + ) + virtual_model = load_sql_based_model(virtual_expressions) + assert virtual_model.grants_target_layer == GrantsTargetLayer.VIRTUAL + + default_expressions = d.parse( + """ + MODEL ( + name test.default_grants_model, + kind FULL + ); + 
SELECT 1 as id + """ + ) + default_model = load_sql_based_model(default_expressions) + assert default_model.grants_target_layer == GrantsTargetLayer.VIRTUAL # default value + + +def test_render_definition_partitioned_by(): + # no parenthesis in definition, no parenthesis when rendered + model = load_sql_based_model( + d.parse( + f""" + MODEL ( + name db.table, + kind FULL, + partitioned_by a + ); + + select 1 as a; + """ + ) + ) + + assert model.partitioned_by == [exp.column("a", quoted=True)] + assert ( + model.render_definition()[0].sql(pretty=True) + == """MODEL ( + name db.table, + kind FULL, + partitioned_by "a" +)""" + ) + + # single column wrapped in parenthesis in defintion, no parenthesis in rendered + model = load_sql_based_model( + d.parse( + f""" + MODEL ( + name db.table, + kind FULL, + partitioned_by (a) + ); + + select 1 as a; + """ + ) + ) + + assert model.partitioned_by == [exp.column("a", quoted=True)] + assert ( + model.render_definition()[0].sql(pretty=True) + == """MODEL ( + name db.table, + kind FULL, + partitioned_by "a" +)""" + ) + + # multiple columns wrapped in parenthesis in definition, parenthesis in rendered + model = load_sql_based_model( + d.parse( + f""" + MODEL ( + name db.table, + kind FULL, + partitioned_by (a, b) + ); + + select 1 as a, 2 as b; + """ + ) + ) + + assert model.partitioned_by == [exp.column("a", quoted=True), exp.column("b", quoted=True)] + assert ( + model.render_definition()[0].sql(pretty=True) + == """MODEL ( + name db.table, + kind FULL, + partitioned_by ("a", "b") +)""" + ) + + # multiple columns not wrapped in parenthesis in the definition is an error + with pytest.raises(ParseError, match=r"keyword: 'value' missing"): + load_sql_based_model( + d.parse( + f""" + MODEL ( + name db.table, + kind FULL, + partitioned_by a, b + ); + + select 1 as a, 2 as b; + """ + ) + ) + + # Iceberg transforms / functions + model = load_sql_based_model( + d.parse( + f""" + MODEL ( + name db.table, + kind FULL, + partitioned_by 
(day(a), truncate(b, 4), bucket(c, 3)) + ); + + select 1 as a, 2 as b, 3 as c; + """ + ), + dialect="trino", + ) + + assert model.partitioned_by == [ + exp.Day(this=exp.column("a", quoted=True)), + exp.PartitionByTruncate( + this=exp.column("b", quoted=True), expression=exp.Literal.number(4) + ), + exp.PartitionedByBucket( + this=exp.column("c", quoted=True), expression=exp.Literal.number(3) + ), + ] + assert ( + model.render_definition()[0].sql(pretty=True) + == """MODEL ( + name db.table, + dialect trino, + kind FULL, + partitioned_by (DAY("a"), TRUNCATE("b", 4), BUCKET("c", 3)) +)""" + ) + + +def test_render_definition_with_virtual_update_statements(): + # model has virtual update statements + model = load_sql_based_model( + d.parse( + f""" + MODEL ( + name db.table, + kind FULL + ); + + select 1 as a; + + ON_VIRTUAL_UPDATE_BEGIN; + GRANT SELECT ON VIEW @this_model TO ROLE role_name + ON_VIRTUAL_UPDATE_END; + """ + ) + ) + + assert model.on_virtual_update == [ + exp.Grant( + privileges=[exp.GrantPrivilege(this=exp.Var(this="SELECT"))], + kind="VIEW", + securable=exp.Table(this=d.MacroVar(this="this_model")), + principals=[ + exp.GrantPrincipal(this=exp.Identifier(this="role_name", quoted=False), kind="ROLE") + ], + ) + ] + assert ( + model.render_definition()[-1].sql(pretty=True) + == """ON_VIRTUAL_UPDATE_BEGIN; +GRANT SELECT ON VIEW @this_model TO ROLE role_name; +ON_VIRTUAL_UPDATE_END;""" + ) + + +def test_render_definition_dbt_node_info(): + node_info = DbtNodeInfo(unique_id="model.db.table", name="table", fqn="db.table") + model = load_sql_based_model( + d.parse( + f""" + MODEL ( + name db.table, + kind FULL + ); + + select 1 as a; + """ + ), + dbt_node_info=node_info, + ) + + assert model.dbt_node_info + assert ( + model.render_definition()[0].sql(pretty=True) + == """MODEL ( + name db.table, + dbt_node_info (fqn := 'db.table', name := 'table', unique_id := 'model.db.table'), + kind FULL +)""" + ) + + +def test_cron(): + daily = _Node(name="x", 
cron="@daily") + assert to_datetime(daily.cron_prev("2020-01-01")) == to_datetime("2019-12-31") + assert to_datetime(daily.cron_floor("2020-01-01")) == to_datetime("2020-01-01") + assert to_timestamp(daily.cron_floor("2020-01-01 10:00:00")) == to_timestamp("2020-01-01") + assert to_timestamp(daily.cron_next("2020-01-01 10:00:00")) == to_timestamp("2020-01-02") + interval = daily.interval_unit + assert to_datetime(interval.cron_prev("2020-01-01")) == to_datetime("2019-12-31") + assert to_datetime(interval.cron_floor("2020-01-01")) == to_datetime("2020-01-01") + assert to_timestamp(interval.cron_floor("2020-01-01 10:00:00")) == to_timestamp("2020-01-01") + assert to_timestamp(interval.cron_next("2020-01-01 10:00:00")) == to_timestamp("2020-01-02") + + offset = _Node(name="x", cron="1 0 * * *") + assert to_datetime(offset.cron_prev("2020-01-01")) == to_datetime("2019-12-31 00:01") + assert to_datetime(offset.cron_floor("2020-01-01")) == to_datetime("2019-12-31 00:01") + assert to_timestamp(offset.cron_floor("2020-01-01 10:00:00")) == to_timestamp( + "2020-01-01 00:01" + ) + assert to_timestamp(offset.cron_next("2020-01-01 10:00:00")) == to_timestamp("2020-01-02 00:01") + interval = offset.interval_unit + assert to_datetime(interval.cron_prev("2020-01-01")) == to_datetime("2019-12-31") + assert to_datetime(interval.cron_floor("2020-01-01")) == to_datetime("2020-01-01") + assert to_timestamp(interval.cron_floor("2020-01-01 10:00:00")) == to_timestamp("2020-01-01") + assert to_timestamp(interval.cron_next("2020-01-01 10:00:00")) == to_timestamp("2020-01-02") + + hourly = _Node(name="x", cron="1 * * * *") + assert to_timestamp(hourly.cron_prev("2020-01-01 10:00:00")) == to_timestamp( "2020-01-01 09:01:00" ) assert to_timestamp(hourly.cron_prev("2020-01-01 10:02:00")) == to_timestamp( @@ -1569,11 +2694,15 @@ def test_parse(assert_exp_eq): dialect '', ); + JINJA_QUERY_BEGIN; + SELECT id::INT AS id, ds FROM x - WHERE ds BETWEEN '{{ start_ds }}' AND @end_ds + WHERE ds BETWEEN 
'{{ start_ds }}' AND @end_ds; + + JINJA_END; """ ) model = load_sql_based_model(expressions, dialect="hive") @@ -1583,8 +2712,8 @@ def test_parse(assert_exp_eq): } assert not model.annotated assert model.dialect == "" - assert isinstance(model.query, exp.Select) - assert isinstance(SqlModel.parse_raw(model.json()).query, exp.Select) + assert isinstance(model.query, d.JinjaQuery) + assert isinstance(SqlModel.parse_raw(model.json()).query, d.JinjaQuery) assert_exp_eq( model.render_query(), """ @@ -1598,16 +2727,173 @@ def test_parse(assert_exp_eq): ) +def test_dialect_pattern(): + def make_test_sql(text: str) -> str: + return f""" + MODEL ( + name test_model, + kind INCREMENTAL_BY_TIME_RANGE( + time_column ds + ), + {text} + ); + + SELECT 1; + """ + + def assert_match(test_sql: str, expected_value: t.Optional[str] = "duckdb"): + match = d.DIALECT_PATTERN.search(test_sql) + + dialect_str: t.Optional[str] = None + if expected_value is not None: + assert match + dialect_str = match.group("dialect") + + assert dialect_str == expected_value + + # single-quoted dialect + assert_match( + make_test_sql( + """ + dialect 'duckdb', + description 'there's a dialect foo in here too!' + """ + ) + ) + + # bare dialect + assert_match( + make_test_sql( + """ + dialect duckdb, + description 'there's a dialect foo in here too!' + """ + ) + ) + + # double-quoted dialect (allowed in BQ) + assert_match( + make_test_sql( + """ + dialect "duckdb", + description 'there's a dialect foo in here too!' + """ + ) + ) + + # no dialect specified, "dialect" in description + test_sql = make_test_sql( + """ + description 'there's a dialect foo in here too!' 
+ """ + ) + + matches = list(d.DIALECT_PATTERN.finditer(test_sql)) + assert not matches + + # line comment between properties + assert_match( + make_test_sql( + """ + tag my_tag, -- comment + dialect duckdb + """ + ) + ) + + # block comment between properties + assert_match( + make_test_sql( + """ + tag my_tag, /* comment */ + dialect duckdb + """ + ) + ) + + # quoted empty dialect + assert_match( + make_test_sql( + """ + dialect '', + tag my_tag + """ + ), + None, + ) + + # double-quoted empty dialect + assert_match( + make_test_sql( + """ + dialect "", + tag my_tag + """ + ), + None, + ) + + # trailing comment after dialect value + assert_match( + make_test_sql( + """ + dialect duckdb -- trailing comment + """ + ) + ) + + # dialect value isn't terminated by ',' or ')' + test_sql = make_test_sql( + """ + dialect duckdb -- trailing comment + tag my_tag + """ + ) + + matches = list(d.DIALECT_PATTERN.finditer(test_sql)) + assert not matches + + # dialect first + assert_match( + """ + MODEL( + dialect duckdb, + name my_name + ); + """ + ) + + # full parse + sql = """ + MODEL ( + name test_model, + description 'this text mentions dialect foo but is not a property' + ); + + SELECT 1; + """ + expressions = d.parse(sql, default_dialect="duckdb") + model = load_sql_based_model(expressions) + assert model.dialect == "" + + CONST = "bar" def test_python_model(assert_exp_eq) -> None: from functools import reduce - @model(name="my_model", kind="full", columns={'"COL"': "int"}, enabled=True) + @model( + name="my_model", + kind="full", + columns={'"COL"': "int"}, + pre_statements=["CACHE TABLE x AS SELECT 1;"], + post_statements=["DROP TABLE x;"], + enabled=True, + ) def my_model(context, **kwargs): - context.table("foo") - context.table(model_name=CONST + ".baz") + context.resolve_table("foo") + context.resolve_table(model_name=CONST + ".baz") # This checks that built-in functions are serialized properly a = reduce(lambda x, y: x + y, [1, 2, 3, 4]) # noqa: F841 @@ -1615,18 
+2901,24 @@ def my_model(context, **kwargs): m = model.get_registry()["my_model"].model( module_path=Path("."), path=Path("."), - dialect="duckdb", + dialect="duckdb,normalization_strategy=LOWERCASE", ) + assert list(m.pre_statements) == [ + d.parse_one("CACHE TABLE x AS SELECT 1"), + ] + assert list(m.post_statements) == [ + d.parse_one("DROP TABLE x"), + ] assert m.enabled - assert m.dialect == "duckdb" + assert m.dialect == "duckdb,normalization_strategy=lowercase" assert m.depends_on == {'"foo"', '"bar"."baz"'} - assert m.columns_to_types == {"col": exp.DataType.build("int")} + assert m.columns_to_types == {"COL": exp.DataType.build("int")} assert_exp_eq( m.ctas_query(), """ SELECT - CAST(NULL AS INT) AS "col" + CAST(NULL AS INT) AS "COL" FROM (VALUES (1)) AS t(dummy) WHERE @@ -1644,28 +2936,54 @@ def test_python_model_depends_on() -> None: depends_on={"foo.bar"}, ) def my_model(context, **kwargs): - context.table("foo") - context.table(model_name=CONST + ".baz") + context.resolve_table("foo") + context.resolve_table(model_name=CONST + ".baz") m = model.get_registry()["model_with_depends_on"].model( module_path=Path("."), path=Path("."), ) - # We are not expecting the context.table() calls to be reflected in the model's depends_on since we - # explicitly specified the depends_on argument. + # We are not expecting the context.resolve_table() calls to be reflected in the + # model's depends_on since we explicitly specified the depends_on argument. 
assert m.depends_on == {'"foo"."bar"'} -def test_python_model_with_session_properties(): +def test_python_model_variable_dependencies() -> None: @model( - name="python_model_prop", + name="bla.test_model_var_dep", kind="full", - columns={"some_col": "int"}, - session_properties={"some_string": "string_prop", "some_bool": True, "some_float": 1.0}, + columns={'"col"': "int"}, + depends_on={"@schema_name.table_name"}, ) - def python_model_prop(context, **kwargs): - context.table("foo") + def my_model(context, **kwargs): + # Even though the argument is not statically resolvable, no error + # is raised, because the `depends_on` property is present + schema_name = context.var("schema_name") + table = context.resolve_table(f"{schema_name}.table_name") + + return context.fetchdf(exp.select("*").from_(table)) + + m = model.get_registry()["bla.test_model_var_dep"].model( + module_path=Path("."), + path=Path("."), + variables={"schema_name": "foo"}, + ) + + assert m.depends_on == {'"foo"."table_name"'} + + +def test_python_model_with_properties(make_snapshot): + @model( + name="python_model_prop", + kind="full", + columns={"some_col": "int"}, + session_properties={"some_string": "string_prop", "some_bool": True, "some_float": 1.0}, + physical_properties={"partition_expiration_days": 7}, + virtual_properties={"creatable_type": None}, + ) + def python_model_prop(context, **kwargs): + context.resolve_table("foo") m = model.get_registry()["python_model_prop"].model( module_path=Path("."), @@ -1675,7 +2993,15 @@ def python_model_prop(context, **kwargs): "session_properties": { "some_string": "default_string", "default_value": "default_value", - } + }, + "physical_properties": { + "partition_expiration_days": 13, + "creatable_type": "@IF(@model_kind_name != 'view', 'TRANSIENT', NULL)", + "conditional_prop": "@IF(@model_kind_name == 'view', 'view_prop', NULL)", + }, + "virtual_properties": { + "creatable_type": "SECURE", + }, }, ) assert m.session_properties == { @@ -1685,6 +3011,27 
@@ def python_model_prop(context, **kwargs): "default_value": "default_value", } + assert m.physical_properties == { + "partition_expiration_days": exp.convert(7), + "creatable_type": exp.maybe_parse( + "@IF(@model_kind_name != 'view', 'TRANSIENT', NULL)", dialect="duckdb" + ), + "conditional_prop": exp.maybe_parse( + "@IF(@model_kind_name == 'view', 'view_prop', NULL)", dialect="duckdb" + ), + } + + snapshot: Snapshot = make_snapshot(m) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + # Rendering the properties will result to a TRANSIENT creatable_type and the removal of the conditional prop + assert m.render_physical_properties(snapshots={m.fqn: snapshot}, python_env=m.python_env) == { + "partition_expiration_days": exp.convert(7), + "creatable_type": exp.convert("TRANSIENT"), + } + + assert not m.virtual_properties + def test_python_models_returning_sql(assert_exp_eq) -> None: config = Config(model_defaults=ModelDefaultsConfig(dialect="snowflake")) @@ -1752,9 +3099,7 @@ def model2_entrypoint(evaluator: MacroEvaluator) -> str: def test_python_model_decorator_kind() -> None: - logger = logging.getLogger("sqlmesh.core.model.decorator") - - # no kind specified -> default View kind + # no kind specified -> default Full kind @model("default_kind", columns={'"COL"': "int"}) def a_model(context): pass @@ -1764,10 +3109,10 @@ def a_model(context): path=Path("."), ) - assert isinstance(python_model.kind, ViewKind) + assert isinstance(python_model.kind, FullKind) # string kind name specified - @model("kind_string", kind="full", columns={'"COL"': "int"}) + @model("kind_string", kind="external", columns={'"COL"': "int"}) def b_model(context): pass @@ -1776,7 +3121,7 @@ def b_model(context): path=Path("."), ) - assert isinstance(python_model.kind, FullKind) + assert isinstance(python_model.kind, ExternalKind) @model("kind_empty_dict", kind=dict(), columns={'"COL"': "int"}) def my_model(context): @@ -1805,7 +3150,7 @@ def my_model_2(context): pass # warning if 
kind is ModelKind instance - with patch.object(logger, "warning") as mock_logger: + with patch.object(get_console(), "log_warning") as mock_logger: python_model = model.get_registry()["kind_instance"].model( module_path=Path("."), path=Path("."), @@ -1817,9 +3162,17 @@ def my_model_2(context): ) # no warning with valid kind dict - with patch.object(logger, "warning") as mock_logger: - - @model("kind_valid_dict", kind=dict(name=ModelKindName.FULL), columns={'"COL"': "int"}) + with patch.object(get_console(), "log_warning") as mock_logger: + + @model( + "kind_valid_dict", + kind=dict( + name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, + time_column="ds", + auto_restatement_cron="@hourly", + ), + columns={'"ds"': "date", '"COL"': "int"}, + ) def my_model(context): pass @@ -1828,11 +3181,33 @@ def my_model(context): path=Path("."), ) - assert isinstance(python_model.kind, FullKind) + assert isinstance(python_model.kind, IncrementalByTimeRangeKind) assert not mock_logger.call_args +def test_python_model_decorator_auto_restatement_cron() -> None: + @model( + "auto_restatement_model", + cron="@daily", + kind=dict( + name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, + time_column="ds", + auto_restatement_cron="@hourly", + ), + columns={'"ds"': "date", '"COL"': "int"}, + ) + def my_model(context): + pass + + python_model = model.get_registry()["auto_restatement_model"].model( + module_path=Path("."), + path=Path("."), + ) + + assert python_model.auto_restatement_cron == "@hourly" + + def test_python_model_decorator_col_descriptions() -> None: # `columns` and `column_descriptions` column names are different cases, but name normalization makes both lower @model("col_descriptions", columns={"col": "int"}, column_descriptions={"COL": "a column"}) @@ -1855,11 +3230,39 @@ def a_model(context): def b_model(context): pass - with pytest.raises(ConfigError, match="a description is provided for column 'COL'"): + with patch.object(get_console(), "log_warning") as mock_logger: py_model = 
model.get_registry()["col_descriptions_quoted"].model( module_path=Path("."), path=Path("."), ) + assert '"COL"' not in py_model.column_descriptions + assert ( + mock_logger.mock_calls[0].args[0] + == "In model 'col_descriptions_quoted', a description is provided for column 'COL' but it is not a column in the model." + ) + + +def test_python_model_unsupported_kind() -> None: + kinds = { + "seed": {"name": ModelKindName.SEED, "path": "."}, + "view": {"name": ModelKindName.VIEW}, + "managed": {"name": ModelKindName.MANAGED}, + "embedded": {"name": ModelKindName.EMBEDDED}, + } + + for kindname in kinds: + + @model(f"kind_{kindname}", kind=kinds[kindname], columns={'"COL"': "int"}) + def the_kind(context): + pass + + with pytest.raises( + SQLMeshError, match=r".*Cannot create Python model.*doesn't support Python models" + ): + model.get_registry()[f"kind_{kindname}"].model( + module_path=Path("."), + path=Path("."), + ).validate_definition() def test_star_expansion(assert_exp_eq) -> None: @@ -2083,7 +3486,7 @@ def test_model_cache(tmp_path: Path, mocker: MockerFixture): expressions = d.parse( """ MODEL ( - name db.seed, + name db.model_sql, ); SELECT 1, ds; """ @@ -2091,18 +3494,66 @@ def test_model_cache(tmp_path: Path, mocker: MockerFixture): model = load_sql_based_model([e for e in expressions if e]) - loader = mocker.Mock(return_value=model) + assert cache.put([model], "test_model", "test_entry_a") + assert cache.get("test_model", "test_entry_a")[0].dict() == model.dict() + + expressions = d.parse( + """ + MODEL ( + name db.model_seed, + kind SEED ( + path '../seeds/waiter_names.csv', + ), + ); + """ + ) + + seed_model = load_sql_based_model( + expressions, path=Path("./examples/sushi/models/test_model.sql") + ) + + assert not cache.put([seed_model], "test_model", "test_entry_b") + + +@pytest.mark.slow +def test_model_cache_gateway(tmp_path: Path, mocker: MockerFixture): + init_example_project(tmp_path, engine_type="duckdb") + + db_path = str(tmp_path / "db.db") + 
config = Config( + gateways={ + "main": GatewayConfig(connection=DuckDBConnectionConfig(database=db_path)), + "secondary": GatewayConfig(connection=DuckDBConnectionConfig()), + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + Context(paths=tmp_path, config=config) + + patched_cache_put = mocker.patch("sqlmesh.utils.cache.FileCache.put") + + Context(paths=tmp_path, config=config) + assert patched_cache_put.call_count == 0 + + Context(paths=tmp_path, config=config, gateway="secondary") + assert patched_cache_put.call_count == 2 + - assert cache.get_or_load("test_model", "test_entry_a", loader=loader).dict() == model.dict() - assert cache.get_or_load("test_model", "test_entry_a", loader=loader).dict() == model.dict() +@pytest.mark.slow +def test_model_cache_default_catalog(tmp_path: Path, mocker: MockerFixture): + init_example_project(tmp_path, engine_type="duckdb") + Context(paths=tmp_path) - assert cache.get_or_load("test_model", "test_entry_b", loader=loader).dict() == model.dict() - assert cache.get_or_load("test_model", "test_entry_b", loader=loader).dict() == model.dict() + patched_cache_put = mocker.patch("sqlmesh.utils.cache.FileCache.put") - assert cache.get_or_load("test_model", "test_entry_a", loader=loader).dict() == model.dict() - assert cache.get_or_load("test_model", "test_entry_a", loader=loader).dict() == model.dict() + Context(paths=tmp_path) + assert patched_cache_put.call_count == 0 - assert loader.call_count == 2 + with patch( + "sqlmesh.core.engine_adapter.base.EngineAdapter.default_catalog", + PropertyMock(return_value=None), + ): + Context(paths=tmp_path) + assert patched_cache_put.call_count == 2 def test_model_ctas_query(): @@ -2199,7 +3650,7 @@ def test_model_ctas_query(): assert ( load_sql_based_model(expressions, dialect="bigquery").ctas_query().sql() - == 'WITH RECURSIVE "a" AS (SELECT * FROM (SELECT * FROM (SELECT * FROM "x" AS "x" WHERE FALSE) AS "_q_0" WHERE FALSE) AS "_q_1" WHERE FALSE), "b" AS (SELECT * FROM "a" AS 
"a" WHERE FALSE UNION ALL SELECT * FROM "a" AS "a" WHERE FALSE) SELECT * FROM "b" AS "b" WHERE FALSE LIMIT 0' + == 'WITH RECURSIVE "a" AS (SELECT * FROM (SELECT * FROM (SELECT * FROM "x" AS "x" WHERE FALSE) AS "_0" WHERE FALSE) AS "_1" WHERE FALSE), "b" AS (SELECT * FROM "a" AS "a" WHERE FALSE UNION ALL SELECT * FROM "a" AS "a" WHERE FALSE) SELECT * FROM "b" AS "b" WHERE FALSE LIMIT 0' ) expressions = d.parse( @@ -2220,7 +3671,7 @@ def test_model_ctas_query(): assert ( load_sql_based_model(expressions, dialect="bigquery").ctas_query().sql() - == 'WITH RECURSIVE "a" AS (WITH "nested_a" AS (SELECT * FROM (SELECT * FROM (SELECT * FROM "x" AS "x" WHERE FALSE) AS "_q_0" WHERE FALSE) AS "_q_1" WHERE FALSE) SELECT * FROM "nested_a" AS "nested_a" WHERE FALSE), "b" AS (SELECT * FROM "a" AS "a" WHERE FALSE UNION ALL SELECT * FROM "a" AS "a" WHERE FALSE) SELECT * FROM "b" AS "b" WHERE FALSE LIMIT 0' + == 'WITH RECURSIVE "a" AS (WITH "nested_a" AS (SELECT * FROM (SELECT * FROM (SELECT * FROM "x" AS "x" WHERE FALSE) AS "_0" WHERE FALSE) AS "_1" WHERE FALSE) SELECT * FROM "nested_a" AS "nested_a" WHERE FALSE), "b" AS (SELECT * FROM "a" AS "a" WHERE FALSE UNION ALL SELECT * FROM "a" AS "a" WHERE FALSE) SELECT * FROM "b" AS "b" WHERE FALSE LIMIT 0' ) @@ -2245,12 +3696,12 @@ def test_parse_expression_list_with_jinja(): "JINJA_STATEMENT_BEGIN;\n{{ log('log message') }}\nJINJA_END;", "GRANT SELECT ON TABLE foo TO DEV", ] - assert input == [val.sql() for val in parse_expression(SqlModel, input, {})] + assert input == [val.sql() for val in parse_expression(SqlModel, input, None)] def test_no_depends_on_runtime_jinja_query(): @macro() - def runtime_macro(**kwargs) -> None: + def runtime_macro(evaluator, **kwargs) -> None: from sqlmesh.utils.errors import ParsetimeAdapterCallError raise ParsetimeAdapterCallError("") @@ -2268,12 +3719,13 @@ def runtime_macro(**kwargs) -> None: model = load_sql_based_model(expressions) with pytest.raises( ConfigError, - match=r"Dependencies must be 
provided explicitly for models that can be rendered only at runtime at.*", + match=r"Dependencies must be provided explicitly for models that can be rendered only at runtime", ): model.validate_definition() -def test_update_schema(): +@use_terminal_console +def test_update_schema(tmp_path: Path): expressions = d.parse( """ MODEL (name db.table); @@ -2290,11 +3742,15 @@ def test_update_schema(): model.update_schema(schema) assert model.mapping_schema == {'"table_a"': {"a": "INT"}} - logger = logging.getLogger("sqlmesh.core.renderer") - with patch.object(logger, "warning") as mock_logger: - model.render_query(optimize=True) - assert mock_logger.call_args[0][0] == missing_schema_warning_msg( - '"db"."table"', ('"table_b"',) + ctx = Context( + config=Config(linter=LinterConfig(enabled=True, warn_rules=["ALL"])), paths=tmp_path + ) + with patch.object(get_console(), "log_warning") as mock_logger: + ctx.upsert_model(model) + ctx.plan_builder("dev") + assert ( + missing_schema_warning_msg('"db"."table"', ('"table_b"',)) + in mock_logger.call_args[0][0] ) schema.add_table('"table_b"', {"b": exp.DataType.build("int")}) @@ -2303,12 +3759,11 @@ def test_update_schema(): '"table_a"': {"a": "INT"}, '"table_b"': {"b": "INT"}, } - model.render_query(optimize=True) - + model.render_query(needs_optimization=True) -def test_missing_schema_warnings(): - logger = logging.getLogger("sqlmesh.core.renderer") +@use_terminal_console +def test_missing_schema_warnings(tmp_path: Path): full_schema = MappingSchema( { "a": {"x": exp.DataType.build("int")}, @@ -2323,71 +3778,79 @@ def test_missing_schema_warnings(): }, ) + console = get_console() + + ctx = Context( + config=Config(linter=LinterConfig(enabled=True, warn_rules=["ALL"])), paths=tmp_path + ) + # star, no schema, no deps - with patch.object(logger, "warning") as mock_logger: + with patch.object(console, "log_warning") as mock_logger: model = load_sql_based_model(d.parse("MODEL (name test); SELECT * FROM (SELECT 1 a) x")) - 
model.render_query(optimize=True) + model.render_query(needs_optimization=True) mock_logger.assert_not_called() # star, full schema - with patch.object(logger, "warning") as mock_logger: + with patch.object(console, "log_warning") as mock_logger: model = load_sql_based_model(d.parse("MODEL (name test); SELECT * FROM a CROSS JOIN b")) model.update_schema(full_schema) - model.render_query(optimize=True) + model.render_query(needs_optimization=True) mock_logger.assert_not_called() # star, partial schema - with patch.object(logger, "warning") as mock_logger: + with patch.object(console, "log_warning") as mock_logger: model = load_sql_based_model(d.parse("MODEL (name test); SELECT * FROM a CROSS JOIN b")) model.update_schema(partial_schema) - model.render_query(optimize=True) - assert mock_logger.call_args[0][0] == missing_schema_warning_msg('"test"', ('"b"',)) + ctx.upsert_model(model) + ctx.plan_builder("dev") + assert missing_schema_warning_msg('"test"', ('"b"',)) in mock_logger.call_args[0][0] # star, no schema - with patch.object(logger, "warning") as mock_logger: + with patch.object(console, "log_warning") as mock_logger: model = load_sql_based_model(d.parse("MODEL (name test); SELECT * FROM b JOIN a")) - model.render_query(optimize=True) - assert mock_logger.call_args[0][0] == missing_schema_warning_msg('"test"', ('"a"', '"b"')) + ctx.upsert_model(model) + ctx.plan_builder("dev") + assert missing_schema_warning_msg('"test"', ('"a"', '"b"')) in mock_logger.call_args[0][0] # no star, full schema - with patch.object(logger, "warning") as mock_logger: + with patch.object(console, "log_warning") as mock_logger: model = load_sql_based_model( d.parse("MODEL (name test); SELECT x::INT FROM a CROSS JOIN b") ) model.update_schema(full_schema) - model.render_query(optimize=True) + model.render_query(needs_optimization=True) mock_logger.assert_not_called() # no star, partial schema - with patch.object(logger, "warning") as mock_logger: + with patch.object(console, 
"log_warning") as mock_logger: model = load_sql_based_model( d.parse("MODEL (name test); SELECT x::INT FROM a CROSS JOIN b") ) model.update_schema(partial_schema) - model.render_query(optimize=True) + model.render_query(needs_optimization=True) mock_logger.assert_not_called() # no star, empty schema - with patch.object(logger, "warning") as mock_logger: + with patch.object(console, "log_warning") as mock_logger: model = load_sql_based_model( d.parse("MODEL (name test); SELECT x::INT FROM a CROSS JOIN b") ) - model.render_query(optimize=True) + model.render_query(needs_optimization=True) mock_logger.assert_not_called() def test_user_provided_depends_on(): - expressions = d.parse( - """ - MODEL (name db.table, depends_on [table_b]); - - SELECT a FROM table_a - """ - ) + for l_delim, r_delim in (("(", ")"), ("[", "]")): + expressions = d.parse( + f""" + MODEL (name db.table, depends_on {l_delim}table_b{r_delim}); - model = load_sql_based_model(expressions) + SELECT a FROM table_a + """ + ) - assert model.depends_on == {'"table_a"', '"table_b"'} + model = load_sql_based_model(expressions) + assert model.depends_on == {'"table_a"', '"table_b"'}, f"Delimiters {l_delim}, {r_delim}" def test_check_schema_mapping_when_rendering_at_runtime(assert_exp_eq): @@ -2468,7 +3931,7 @@ def test_model_normalization(): assert model.partitioned_by[0].sql(dialect="snowflake") == '"A"' assert model.partitioned_by[1].sql(dialect="snowflake") == 'FOO("ds")' assert model.tags == ["pii", "fact"] - assert model.clustered_by == ["A"] + assert model.clustered_by == [exp.to_column('"A"')] assert model.depends_on == {'"BLA"'} # Check possible variations of unique_key definitions @@ -2530,7 +3993,7 @@ def test_model_normalization(): assert model.time_column.column == exp.column("A", quoted=True) assert model.columns_to_types["A"].sql(dialect="snowflake") == "INT" assert model.tags == ["pii", "fact"] - assert model.clustered_by == ["A"] + assert model.clustered_by == [exp.to_column('"A"')] assert 
model.depends_on == {'"BLA"'} model = create_sql_model( @@ -2542,7 +4005,7 @@ def test_model_normalization(): tags=["pii", "fact"], clustered_by=[exp.column("a"), exp.column("b")], ) - assert model.clustered_by == ["A", "B"] + assert model.clustered_by == [exp.to_column('"A"'), exp.to_column('"B"')] model = create_sql_model( "foo", @@ -2553,7 +4016,7 @@ def test_model_normalization(): tags=["pii", "fact"], clustered_by=["a", "b"], ) - assert model.clustered_by == ["A", "B"] + assert model.clustered_by == [exp.to_column('"A"'), exp.to_column('"B"')] def test_incremental_unmanaged_validation(): @@ -2573,6 +4036,42 @@ def test_incremental_unmanaged_validation(): model.validate_definition() +def test_incremental_unmanaged(): + expr = d.parse( + """ + MODEL ( + name foo, + kind INCREMENTAL_UNMANAGED + ); + + SELECT x.a AS a FROM test.x AS x + """ + ) + + model = load_sql_based_model(expressions=expr) + + assert isinstance(model.kind, IncrementalUnmanagedKind) + assert not model.kind.insert_overwrite + + expr = d.parse( + """ + MODEL ( + name foo, + kind INCREMENTAL_UNMANAGED ( + insert_overwrite true + ), + partitioned_by a + ); + + SELECT x.a AS a FROM test.x AS x + """ + ) + + model = load_sql_based_model(expressions=expr) + assert isinstance(model.kind, IncrementalUnmanagedKind) + assert model.kind.insert_overwrite + + def test_custom_interval_unit(): assert ( load_sql_based_model( @@ -2619,20 +4118,51 @@ def test_custom_interval_unit(): ) with pytest.raises( - ConfigError, match=r"Interval unit of '.*' is larger than cron period of '@daily'" + ConfigError, match=r"Cron '@daily' cannot be more frequent than interval unit 'month'." ): load_sql_based_model( d.parse("MODEL (name db.table, interval_unit month); SELECT a FROM tbl;") ) with pytest.raises( - ConfigError, match=r"Interval unit of '.*' is larger than cron period of '@hourly'" + ConfigError, + match=r"Cron '@hourly' cannot be more frequent than interval unit 'day'. 
If this is intentional, set allow_partials to True.", ): load_sql_based_model( d.parse("MODEL (name db.table, interval_unit Day, cron '@hourly'); SELECT a FROM tbl;") ) +def test_interval_unit_larger_than_cron_period(): + # The interval unit can be larger than the cron period if allow_partials is True + model = load_sql_based_model( + d.parse( + "MODEL (name db.table, interval_unit day, cron '@hourly', allow_partials TRUE); SELECT a FROM tbl;" + ) + ) + assert model.interval_unit == IntervalUnit.DAY + assert model.cron == "@hourly" + assert model.allow_partials + + with pytest.raises( + ConfigError, + match=r"Cron '@hourly' cannot be more frequent than interval unit 'day'. If this is intentional, set allow_partials to True.", + ): + load_sql_based_model( + d.parse("MODEL (name db.table, interval_unit day, cron '@hourly'); SELECT a FROM tbl;") + ) + + with pytest.raises( + ConfigError, + match=r"Cron '@hourly' cannot be more frequent than interval unit 'day'. If this is intentional, set allow_partials to True.", + ): + load_sql_based_model( + d.parse( + "MODEL (name db.table, interval_unit day, cron '@hourly', allow_partials FALSE); SELECT a FROM tbl;" + ) + ) + + def test_model_physical_properties() -> None: # Validate python model table properties @model( @@ -2779,7 +4309,10 @@ def my_model(context, **kwargs): """('key_a' = 'value_a', 'key_b' = 1, 'key_c' = TRUE, 'key_d' = 2.0)""" ) - with pytest.raises(ConfigError, match=r"Invalid property 'invalid'.*"): + with pytest.raises( + ConfigError, + match=r"Invalid property 'invalid'. Properties must be specified as key-value pairs = . 
", + ): load_sql_based_model( d.parse( """ @@ -2916,7 +4449,12 @@ def test_session_properties_on_model_and_project(sushi_context): "some_bool": False, "quoted_identifier": "value_you_wont_see", "project_level_property": "project_property", - } + }, + physical_properties={ + "warehouse": "small", + "target_lag": "10 minutes", + }, + virtual_properties={"creatable_type": "SECURE"}, ) model = load_sql_based_model( @@ -2931,7 +4469,10 @@ def test_session_properties_on_model_and_project(sushi_context): some_float = 0.1, quoted_identifier = "quoted identifier", unquoted_identifier = unquoted_identifier, - ) + ), + physical_properties ( + target_lag = '1 hour' + ), ); SELECT a FROM tbl; """, @@ -2950,14 +4491,33 @@ def test_session_properties_on_model_and_project(sushi_context): "project_level_property": "project_property", } + assert model.physical_properties == { + "warehouse": exp.convert("small"), + "target_lag": exp.convert("1 hour"), + } + + assert model.virtual_properties == { + "creatable_type": exp.convert("SECURE"), + } + -def test_project_level_session_properties(sushi_context): +def test_project_level_properties(sushi_context): model_defaults = ModelDefaultsConfig( session_properties={ "some_bool": False, "some_float": 0.1, "project_level_property": "project_property", - } + }, + physical_properties={ + "warehouse": "small", + "target_lag": "1 hour", + }, + virtual_properties={"creatable_type": "SECURE"}, + enabled=False, + allow_partials=True, + interval_unit="quarter_hour", + optimize_query=True, + cron="@hourly", ) model = load_sql_based_model( @@ -2965,6 +4525,10 @@ def test_project_level_session_properties(sushi_context): """ MODEL ( name test_schema.test_model, + kind FULL, + virtual_properties ( + creatable_type = None + ), ); SELECT a FROM tbl; """, @@ -2973,2332 +4537,7808 @@ def test_project_level_session_properties(sushi_context): defaults=model_defaults.dict(), ) + # Validate use of project wide defaults + assert not model.enabled + assert 
model.allow_partials + assert model.interval_unit == IntervalUnit.QUARTER_HOUR + assert model.optimize_query + assert model.cron == "@hourly" + assert model.session_properties == { "some_bool": False, "some_float": 0.1, "project_level_property": "project_property", } - -def test_model_session_properties(sushi_context): - assert sushi_context.models['"memory"."sushi"."items"'].session_properties == { - "string_prop": "some_value", - "int_prop": 1, - "float_prop": 1.0, - "bool_prop": True, + assert model.physical_properties == { + "warehouse": exp.convert("small"), + "target_lag": exp.convert("1 hour"), } - model = load_sql_based_model( + + # Validate disabling global property + assert not model.virtual_properties + + model_2 = load_sql_based_model( d.parse( """ MODEL ( - name test_schema.test_model, - session_properties ( - 'spark.executor.cores' = 2, - 'spark.executor.memory' = '1G', - some_bool = True, - some_float = 0.1, - quoted_identifier = "quoted identifier", - unquoted_identifier = unquoted_identifier, - ) + name test_schema.test_model_2, + kind FULL, + allow_partials False, + interval_unit hour, + cron '@daily' + ); - SELECT a FROM tbl; + SELECT a, b FROM tbl; """, default_dialect="snowflake", - ) + ), + defaults=model_defaults.dict(), ) - assert model.session_properties == { - "spark.executor.cores": 2, - "spark.executor.memory": "1G", - "some_bool": True, - "some_float": 0.1, - "quoted_identifier": exp.column("quoted identifier", quoted=True), - "unquoted_identifier": exp.column("unquoted_identifier", quoted=False), - } + # Validate overriding of project wide defaults + assert not model_2.allow_partials + assert model_2.interval_unit == IntervalUnit.HOUR + assert model_2.cron == "@daily" - model = load_sql_based_model( + +def test_conditional_physical_properties(make_snapshot): + model_defaults = ModelDefaultsConfig( + physical_properties={ + "creatable_type": "@IF(@model_kind_name != 'VIEW', 'TRANSIENT', NULL)" + }, + ) + + full_model = 
load_sql_based_model( d.parse( """ MODEL ( - name test_schema.test_model, - session_properties ( - 'warehouse' = 'test_warehouse' - ) + name test_schema.test_full_model, + kind FULL, ); SELECT a FROM tbl; """, default_dialect="snowflake", - ) + ), + defaults=model_defaults.dict(), ) - assert model.session_properties == { - "warehouse": "test_warehouse", - } - -def test_model_jinja_macro_rendering(): - expressions = d.parse( - """ + view_model = load_sql_based_model( + d.parse( + """ MODEL ( - name db.table, - dialect spark, - owner owner_name, + name test_schema.test_view_model_kind, + kind VIEW, ); - - JINJA_STATEMENT_BEGIN; - {{ test_package.macro_a() }} - {{ macro_b() }} - JINJA_END; - - SELECT 1 AS x; - """ + SELECT a FROM tbl; + """, + default_dialect="snowflake", + ), + defaults=model_defaults.dict(), ) - jinja_macros = JinjaMacroRegistry( - packages={ - "test_package": {"macro_a": MacroInfo(definition="macro_a_body", depends_on=[])}, - }, - root_macros={"macro_b": MacroInfo(definition="macro_b_body", depends_on=[])}, - global_objs={"test_int": 1, "test_str": "value"}, + # load time is a no-op + assert ( + view_model.physical_properties + == full_model.physical_properties + == { + "creatable_type": exp.maybe_parse( + "@IF(@model_kind_name != 'VIEW', 'TRANSIENT', NULL)", dialect="snowflake" + ) + } ) - model = load_sql_based_model(expressions, jinja_macros=jinja_macros) - definition = model.render_definition() + # substitution occurs at runtime + snapshot: Snapshot = make_snapshot(full_model) + snapshot_view: Snapshot = make_snapshot(view_model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - assert definition[1].sql() == "JINJA_STATEMENT_BEGIN;\nmacro_b_body\nJINJA_END;" - assert definition[2].sql() == "JINJA_STATEMENT_BEGIN;\nmacro_a_body\nJINJA_END;" + # Validate use of TRANSIENT type for FULL model + assert full_model.render_physical_properties( + snapshots={full_model.fqn: snapshot, view_model.fqn: snapshot_view} + ) == {"creatable_type": 
exp.Literal(this="TRANSIENT", is_string=True)} + # Validate disabling the creatable_type property for VIEW model + assert ( + view_model.render_physical_properties( + snapshots={full_model.fqn: snapshot, view_model.fqn: snapshot_view} + ) + == {} + ) -def test_view_model_data_hash(): - view_model_expressions = d.parse( - """ - MODEL ( - name db.table, - kind VIEW, - ); - SELECT 1; - """ + +def test_project_level_properties_python_model(): + model_defaults = { + "physical_properties": { + "partition_expiration_days": 13, + "creatable_type": "TRANSIENT", + }, + "description": "general model description", + "enabled": False, + "allow_partials": True, + "interval_unit": "quarter_hour", + "optimize_query": True, + } + + @model( + name="python_model_t_defaults", + kind="full", + columns={"some_col": "int"}, + physical_properties={"partition_expiration_days": 7}, ) - view_model_hash = load_sql_based_model(view_model_expressions).data_hash + def python_model_prop(context, **kwargs): + context.resolve_table("foo") - materialized_view_model_expressions = d.parse( - """ - MODEL ( - name db.table, - kind VIEW ( - materialized true - ), - ); - SELECT 1; - """ + m = model.get_registry()["python_model_t_defaults"].model( + module_path=Path("."), + path=Path("."), + dialect="duckdb", + defaults=model_defaults, ) - materialized_view_model_hash = load_sql_based_model( - materialized_view_model_expressions - ).data_hash - assert view_model_hash != materialized_view_model_hash + assert m.physical_properties == { + "partition_expiration_days": exp.convert(7), + "creatable_type": exp.convert("TRANSIENT"), + } + # Even if in the project wide defaults these are ignored for python models + assert not m.optimize_query -def test_seed_model_data_hash(): - expressions = d.parse( - """ - MODEL ( - name db.seed, - kind SEED ( - path '../seeds/waiter_names.csv', - ) - ); - """ - ) - seed_model = load_sql_based_model( - expressions, path=Path("./examples/sushi/models/test_model.sql") + assert not 
m.enabled + assert m.allow_partials + assert m.interval_unit == IntervalUnit.QUARTER_HOUR + + +def test_model_defaults_macros(make_snapshot): + model_defaults = ModelDefaultsConfig( + table_format="@IF(@gateway = 'dev', 'iceberg', NULL)", + cron="@cron_macro", + storage_format="@IF(@gateway = 'local', 'parquet', NULL)", + optimize_query="@IF(@gateway = 'dev', True, False)", + enabled="@IF(@gateway = 'dev', True, False)", + allow_partials="@IF(@gateway = 'local', True, False)", + interval_unit="@IF(@gateway = 'local', 'quarter_hour', 'day')", + start="@IF(@gateway = 'dev', '1 month ago', '2024-01-01')", + session_properties={ + "spark.executor.cores": "@IF(@gateway = 'dpev', 1, 2)", + "spark.executor.memory": "1G", + "unset_property": "@IF(@model_kind_name = 'FULL', 'NOTSET', NULL)", + }, + physical_properties={ + "partition_expiration_days": 13, + "creatable_type": "@IF(@model_kind_name = 'FULL', 'TRANSIENT', NULL)", + }, + virtual_properties={ + "creatable_type": "@create_type", + "unset_property": "@IF(@model_kind_name = 'FULL', 'NOTSET', NULL)", + }, ) - expressions = d.parse( - """ + model = load_sql_based_model( + d.parse( + """ MODEL ( - name db.seed, - kind SEED ( - path '../seeds/waiter_names.csv', - csv_settings ( - quotechar = '''', - ) - ) + name test_schema.test_model, + physical_properties ( + target_lag = '1 hour' + ), ); - """ - ) - new_seed_model = load_sql_based_model( - expressions, path=Path("./examples/sushi/models/test_model.sql") + SELECT a FROM tbl; + """, + default_dialect="snowflake", + ), + defaults=model_defaults.dict(), + variables={"gateway": "dev", "create_type": "SECURE", "cron_macro": "@daily"}, ) - assert seed_model.data_hash != new_seed_model.data_hash + snapshot: Snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + # Validate rendering of model defaults + assert model.optimize_query + assert model.enabled + assert model.start == "1 month ago" + assert not model.allow_partials + assert 
model.interval_unit == IntervalUnit.DAY + assert model.table_format == "iceberg" + assert model.cron == "@daily" -def test_interval_unit_validation(): - assert ( - create_sql_model( - "a", - d.parse_one("SELECT a, ds FROM table_a"), - interval_unit=IntervalUnit.HOUR, - ).interval_unit - == IntervalUnit.HOUR - ) + # Validate disabling of conditional model default + assert not model.storage_format - assert ( - create_sql_model( - "a", - d.parse_one("SELECT a, ds FROM table_a"), - interval_unit="HOUR", - ).interval_unit - == IntervalUnit.HOUR + # The model defaults properties won't be rendered at load time + assert model.session_properties == { + "spark.executor.cores": exp.maybe_parse( + "@IF(@gateway = 'dpev', 1, 2)", dialect="snowflake" + ), + "spark.executor.memory": "1G", + "unset_property": exp.maybe_parse( + "@IF(@model_kind_name = 'FULL', 'NOTSET', NULL)", dialect="snowflake" + ), + } + + assert model.physical_properties == { + "partition_expiration_days": exp.convert(13), + "creatable_type": exp.maybe_parse( + "@IF(@model_kind_name = 'FULL', 'TRANSIENT', NULL)", dialect="snowflake" + ), + "target_lag": exp.convert("1 hour"), + } + + assert model.virtual_properties == { + "creatable_type": d.MacroVar(this="create_type"), + "unset_property": exp.maybe_parse( + "@IF(@model_kind_name = 'FULL', 'NOTSET', NULL)", dialect="snowflake" + ), + } + + # Validate the correct rendering and removal of conditional properties + assert model.render_session_properties(snapshots={model.fqn: snapshot}) == { + "spark.executor.cores": exp.convert(2), + "spark.executor.memory": "1G", + } + + assert model.render_physical_properties(snapshots={model.fqn: snapshot}) == { + "partition_expiration_days": exp.convert(13), + "target_lag": exp.convert("1 hour"), + } + + assert model.render_virtual_properties(snapshots={model.fqn: snapshot}) == { + "creatable_type": exp.convert("SECURE"), + } + + +def test_model_defaults_macros_python_model(make_snapshot): + model_defaults = { + 
"physical_properties": { + "partition_expiration_days": 13, + "creatable_type": "@IF(@model_kind_name = 'FULL', 'TRANSIENT', NULL)", + }, + "cron": "@cron_macro_expr", + "table_format": "@IF(@gateway = 'local', 'iceberg', NULL)", + "storage_format": "@IF(@gateway = 'dev', 'parquet', NULL)", + "optimize_query": "@IF(@gateway = 'local', True, False)", + "enabled": "@IF(@gateway = 'local', True, False)", + "allow_partials": "@IF(@gateway = 'local', True, False)", + "interval_unit": "@IF(@gateway = 'local', 'quarter_hour', 'day')", + "start": "@IF(@gateway = 'dev', '1 month ago', '2024-01-01')", + "virtual_properties": {"creatable_type": "@create_type"}, + "session_properties": { + "spark.executor.cores": "@IF(@gateway = 'dev', 1, 2)", + "spark.executor.memory": "1G", + }, + } + + @model( + name="python_model_defaults_macro", + kind="full", + columns={"some_col": "int"}, + physical_properties={"partition_expiration_days": 7}, ) + def python_model_prop_macro(context, **kwargs): + context.resolve_table("foo") - assert ( - create_sql_model( - "a", - d.parse_one("SELECT a, ds FROM table_a"), - interval_unit=None, - ).interval_unit_ - is None + m = model.get_registry()["python_model_defaults_macro"].model( + module_path=Path("."), + path=Path("."), + dialect="duckdb", + defaults=model_defaults, + variables={"gateway": "local", "create_type": "SECURE", "cron_macro_expr": "0 */2 * * *"}, ) + # Even if in the project wide defaults this is ignored for python models + assert not m.optimize_query -def test_scd_type_2_by_time_defaults(): - model_def = d.parse( - """ - MODEL ( - name db.table, - kind SCD_TYPE_2 ( - unique_key (COALESCE("ID", '') || '|' || COALESCE("ds", ''), COALESCE("ds", '')), - ), - ); - SELECT - 1 as "ID", - '2020-01-01' as ds, - '2020-01-01' as test_updated_at, - '2020-01-01' as test_valid_from, - '2020-01-01' as test_valid_to - ; - """ - ) - scd_type_2_model = load_sql_based_model(model_def) - assert scd_type_2_model.unique_key == [ - 
parse_one("""COALESCE("ID", '') || '|' || COALESCE("ds", '')"""), - parse_one("""COALESCE("ds", '')"""), - ] - assert scd_type_2_model.columns_to_types == { - "ID": exp.DataType.build("int"), - "ds": exp.DataType.build("varchar"), - "test_updated_at": exp.DataType.build("varchar"), - "test_valid_from": exp.DataType.build("varchar"), - "test_valid_to": exp.DataType.build("varchar"), - "valid_from": exp.DataType.build("TIMESTAMP"), - "valid_to": exp.DataType.build("TIMESTAMP"), + # Validate rendering of model defaults + assert m.cron == "0 */2 * * *" + assert m.enabled + assert m.start == "2024-01-01" + assert m.allow_partials + assert m.interval_unit == IntervalUnit.QUARTER_HOUR + assert m.table_format == "iceberg" + + # Validate disabling attribute dynamically + assert not m.storage_format + + snapshot: Snapshot = make_snapshot(m) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + # Ensure properties are not rendered at load time + assert m.physical_properties == { + "partition_expiration_days": exp.convert(7), + "creatable_type": exp.maybe_parse( + "@IF(@model_kind_name = 'FULL', 'TRANSIENT', NULL)", dialect="duckdb" + ), } - assert scd_type_2_model.managed_columns == { - "valid_from": exp.DataType.build("TIMESTAMP"), - "valid_to": exp.DataType.build("TIMESTAMP"), + + # Substitution occurs at runtime for properties so here these will be unrendered + assert m.render_physical_properties( + snapshots={ + m.fqn: snapshot, + } + ) == { + "partition_expiration_days": exp.convert(7), + "creatable_type": exp.convert("TRANSIENT"), } - assert scd_type_2_model.kind.updated_at_name == exp.column("updated_at", quoted=True) - assert scd_type_2_model.kind.valid_from_name == exp.column("valid_from", quoted=True) - assert scd_type_2_model.kind.valid_to_name == exp.column("valid_to", quoted=True) - assert not scd_type_2_model.kind.updated_at_as_valid_from - assert scd_type_2_model.kind.is_scd_type_2_by_time - assert scd_type_2_model.kind.is_scd_type_2 - assert 
scd_type_2_model.kind.is_materialized - assert scd_type_2_model.kind.forward_only - assert scd_type_2_model.kind.disable_restatement + assert m.session_properties == { + "spark.executor.cores": exp.maybe_parse("@IF(@gateway = 'dev', 1, 2)", dialect="duckdb"), + "spark.executor.memory": "1G", + } -def test_scd_type_2_by_time_overrides(): - model_def = d.parse( - """ - MODEL ( - name db.table, - kind SCD_TYPE_2_BY_TIME ( - unique_key ["iD", ds], - updated_at_name test_updated_at, - valid_from_name test_valid_from, - valid_to_name test_valid_to, - time_data_type TIMESTAMPTZ, - updated_at_as_valid_from True, - forward_only False, - disable_restatement False, - invalidate_hard_deletes False, - ), - dialect snowflake - ); - SELECT - 1 as "iD", - '2020-01-01' as ds, - '2020-01-01' as test_updated_at, - '2020-01-01' as test_valid_from, - '2020-01-01' as test_valid_to - ; - """ - ) - scd_type_2_model = load_sql_based_model(model_def) - assert scd_type_2_model.unique_key == [ - exp.column("iD", quoted=True), - exp.column("DS", quoted=True), - ] - assert scd_type_2_model.managed_columns == { - "TEST_VALID_FROM": exp.DataType.build("TIMESTAMPTZ"), - "TEST_VALID_TO": exp.DataType.build("TIMESTAMPTZ"), + assert m.virtual_properties["creatable_type"] == d.MacroVar(this="create_type") + + # Validate rendering of properties + assert m.render_session_properties( + snapshots={ + m.fqn: snapshot, + }, + ) == { + "spark.executor.cores": exp.convert(2), + "spark.executor.memory": "1G", } - assert scd_type_2_model.kind.updated_at_name == exp.column("TEST_UPDATED_AT", quoted=True) - assert scd_type_2_model.kind.valid_from_name == exp.column("TEST_VALID_FROM", quoted=True) - assert scd_type_2_model.kind.valid_to_name == exp.column("TEST_VALID_TO", quoted=True) - assert scd_type_2_model.kind.updated_at_as_valid_from - assert scd_type_2_model.kind.is_scd_type_2_by_time - assert scd_type_2_model.kind.is_scd_type_2 - assert scd_type_2_model.kind.is_materialized - assert not 
scd_type_2_model.kind.invalidate_hard_deletes - assert not scd_type_2_model.kind.forward_only - assert not scd_type_2_model.kind.disable_restatement - model_kind_dict = scd_type_2_model.kind.dict() - assert scd_type_2_model.kind == _model_kind_validator(None, model_kind_dict, {}) + assert m.render_virtual_properties( + snapshots={ + m.fqn: snapshot, + } + ) == {"creatable_type": exp.convert("SECURE")} + assert m.render_physical_properties(snapshots={m.fqn: snapshot}) == { + "partition_expiration_days": exp.convert(7), + "creatable_type": exp.convert("TRANSIENT"), + } -def test_scd_type_2_by_column_defaults(): - model_def = d.parse( - """ - MODEL ( + +@pytest.mark.parametrize( + "optimize_query, enabled, allow_partials, interval_unit, expected_error", + [ + ("string", "string", "string", "string", r"^Expected boolean for*"), + (True, "string", "string", "string", r"^Expected boolean for*"), + (True, True, "string", "string", r"^Expected boolean for*"), + (True, True, True, "string", r"^Invalid interval unitr*"), + ], +) +def test_model_defaults_validations( + optimize_query, enabled, allow_partials, interval_unit, expected_error +): + model_defaults = ModelDefaultsConfig( + optimize_query=optimize_query, + enabled=enabled, + allow_partials=allow_partials, + interval_unit=interval_unit, + ) + + with pytest.raises(ConfigError, match=expected_error): + load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_model, + ); + SELECT a FROM tbl; + """, + ), + defaults=model_defaults.dict(), + ) + + +def test_model_session_properties(sushi_context): + assert sushi_context.models['"memory"."sushi"."items"'].session_properties == { + "string_prop": "some_value", + "int_prop": 1, + "float_prop": 1.0, + "bool_prop": True, + } + model = load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_model, + session_properties ( + 'spark.executor.cores' = 2, + 'spark.executor.memory' = '1G', + some_bool = True, + some_float = 0.1, + quoted_identifier = 
"quoted identifier", + unquoted_identifier = unquoted_identifier, + ) + ); + SELECT a FROM tbl; + """, + default_dialect="snowflake", + ) + ) + + assert model.session_properties == { + "spark.executor.cores": 2, + "spark.executor.memory": "1G", + "some_bool": True, + "some_float": 0.1, + "quoted_identifier": exp.column("quoted identifier", quoted=True), + "unquoted_identifier": exp.column("unquoted_identifier", quoted=False), + } + + model = load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_model, + session_properties ( + 'warehouse' = 'test_warehouse' + ) + ); + SELECT a FROM tbl; + """, + default_dialect="snowflake", + ) + ) + assert model.session_properties == { + "warehouse": "test_warehouse", + } + + model = load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_model, + session_properties ( + 'query_label' = [ + ('key1', 'value1'), + ('key2', 'value2') + ] + ) + ); + SELECT a FROM tbl; + """, + default_dialect="bigquery", + ) + ) + assert model.session_properties == { + "query_label": parse_one("[('key1', 'value1'), ('key2', 'value2')]", dialect="bigquery") + } + + model = load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_model, + session_properties ( + 'query_label' = ( + ('key1', 'value1') + ) + ) + ); + SELECT a FROM tbl; + """, + default_dialect="bigquery", + ) + ) + assert model.session_properties == {"query_label": parse_one("(('key1', 'value1'))")} + + with pytest.raises( + ConfigError, + match=r"Invalid value for `session_properties.query_label`. Must be an array or tuple.", + ): + load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_model, + session_properties ( + 'query_label' = 'invalid value' + ) + ); + SELECT a FROM tbl; + """, + default_dialect="bigquery", + ) + ) + + with pytest.raises( + ConfigError, + match=r"Invalid entry in `session_properties.query_label`. 
Must be tuples of string literals with length 2.", + ): + load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_model, + session_properties ( + 'query_label' = ( + ('key1', 'value1', 'another_value') + ) + ) + ); + SELECT a FROM tbl; + """, + default_dialect="bigquery", + ) + ) + + with pytest.raises( + ConfigError, + match=r"Invalid entry in `session_properties.query_label`. Must be tuples of string literals with length 2.", + ): + load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_model, + session_properties ( + 'query_label' = ( + 'some value', + 'another value', + 'yet another value', + ) + ) + ); + SELECT a FROM tbl; + """, + default_dialect="bigquery", + ) + ) + + +def test_session_properties_authorization_validation(): + model = load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_model, + session_properties ( + authorization = 'test_user' + ) + ); + SELECT a FROM tbl; + """, + default_dialect="trino", + ) + ) + assert model.session_properties == {"authorization": "test_user"} + + with pytest.raises( + ConfigError, + match=r"Invalid value for `session_properties.authorization`. Must be a string literal.", + ): + load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_model, + session_properties ( + authorization = 123 + ) + ); + SELECT a FROM tbl; + """, + default_dialect="trino", + ) + ) + + with pytest.raises( + ConfigError, + match=r"Invalid value for `session_properties.authorization`. Must be a string literal.", + ): + load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_model, + session_properties ( + authorization = some_column + ) + ); + SELECT a FROM tbl; + """, + default_dialect="trino", + ) + ) + + with pytest.raises( + ConfigError, + match=r"Invalid value for `session_properties.authorization`. 
Must be a string literal.", + ): + load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_model, + session_properties ( + authorization = CONCAT('user', '_suffix') + ) + ); + SELECT a FROM tbl; + """, + default_dialect="trino", + ) + ) + + +def test_model_jinja_macro_rendering(): + expressions = d.parse( + """ + MODEL ( + name db.table, + dialect spark, + owner owner_name, + ); + + JINJA_STATEMENT_BEGIN; + {{ test_package.macro_a() }} + {{ macro_b() }} + JINJA_END; + + SELECT 1 AS x; + """ + ) + + jinja_macros = JinjaMacroRegistry( + packages={ + "test_package": {"macro_a": MacroInfo(definition="macro_a_body", depends_on=[])}, + }, + root_macros={"macro_b": MacroInfo(definition="macro_b_body", depends_on=[])}, + global_objs={"test_int": 1, "test_str": "value"}, + ) + + model = load_sql_based_model(expressions, jinja_macros=jinja_macros) + definition = model.render_definition() + + assert definition[1].sql() == "JINJA_STATEMENT_BEGIN;\nmacro_b_body\nJINJA_END;" + assert definition[2].sql() == "JINJA_STATEMENT_BEGIN;\nmacro_a_body\nJINJA_END;" + + +def test_view_model_data_hash(): + view_model_expressions = d.parse( + """ + MODEL ( + name db.table, + kind VIEW, + ); + SELECT 1; + """ + ) + view_model_hash = load_sql_based_model(view_model_expressions).data_hash + + materialized_view_model_expressions = d.parse( + """ + MODEL ( + name db.table, + kind VIEW ( + materialized true + ), + ); + SELECT 1; + """ + ) + materialized_view_model_hash = load_sql_based_model( + materialized_view_model_expressions + ).data_hash + + assert view_model_hash != materialized_view_model_hash + + +def test_view_materialized_partition_by_clustered_by(): + materialized_view_model_expressions = d.parse( + """ + MODEL ( + name db.table, + kind VIEW ( + materialized true + ), + partitioned_by ds, + clustered_by a + ); + SELECT 1; + """ + ) + materialized_view_model = load_sql_based_model(materialized_view_model_expressions) + assert materialized_view_model.partitioned_by == 
[exp.column("ds", quoted=True)] + assert materialized_view_model.clustered_by == [exp.to_column('"a"')] + + +def test_view_non_materialized_partition_by(): + view_model_expressions = d.parse( + """ + MODEL ( + name db.table, + kind VIEW, + partitioned_by ds, + ); + SELECT 1; + """ + ) + with pytest.raises(ValidationError, match=r".*partitioned_by field cannot be set for VIEW.*"): + load_sql_based_model(view_model_expressions) + + +def test_view_non_materialized_clustered_by(): + view_model_expressions = d.parse( + """ + MODEL ( + name db.table, + kind VIEW, + clustered_by ds, + ); + SELECT 1; + """ + ) + with pytest.raises(ValidationError, match=r".*clustered_by field cannot be set for VIEW.*"): + load_sql_based_model(view_model_expressions) + + +def test_seed_model_data_hash(): + expressions = d.parse( + """ + MODEL ( + name db.seed, + kind SEED ( + path '../seeds/waiter_names.csv', + ) + ); + """ + ) + seed_model = load_sql_based_model( + expressions, path=Path("./examples/sushi/models/test_model.sql") + ) + + expressions = d.parse( + """ + MODEL ( + name db.seed, + kind SEED ( + path '../seeds/waiter_names.csv', + csv_settings ( + quotechar = '''', + ) + ) + ); + """ + ) + new_seed_model = load_sql_based_model( + expressions, path=Path("./examples/sushi/models/test_model.sql") + ) + + assert seed_model.data_hash != new_seed_model.data_hash + + +def test_interval_unit_validation(): + assert ( + create_sql_model( + "a", + d.parse_one("SELECT a, ds FROM table_a"), + interval_unit=IntervalUnit.HOUR, + ).interval_unit + == IntervalUnit.HOUR + ) + + assert ( + create_sql_model( + "a", + d.parse_one("SELECT a, ds FROM table_a"), + interval_unit="HOUR", + ).interval_unit + == IntervalUnit.HOUR + ) + + assert ( + create_sql_model( + "a", + d.parse_one("SELECT a, ds FROM table_a"), + interval_unit=None, + ).interval_unit_ + is None + ) + + +def test_scd_type_2_by_time_defaults(): + model_def = d.parse( + """ + MODEL ( + name db.table, + kind SCD_TYPE_2 ( + unique_key 
(COALESCE("ID", '') || '|' || COALESCE("ds", ''), COALESCE("ds", '')), + ), + ); + SELECT + 1 as "ID", + '2020-01-01' as ds, + '2020-01-01' as test_updated_at, + '2020-01-01' as test_valid_from, + '2020-01-01' as test_valid_to + ; + """ + ) + scd_type_2_model = load_sql_based_model(model_def) + assert scd_type_2_model.unique_key == [ + parse_one("""COALESCE("ID", '') || '|' || COALESCE("ds", '')"""), + parse_one("""COALESCE("ds", '')"""), + ] + assert scd_type_2_model.columns_to_types == { + "ID": exp.DataType.build("int"), + "ds": exp.DataType.build("varchar"), + "test_updated_at": exp.DataType.build("varchar"), + "test_valid_from": exp.DataType.build("varchar"), + "test_valid_to": exp.DataType.build("varchar"), + "valid_from": exp.DataType.build("TIMESTAMP"), + "valid_to": exp.DataType.build("TIMESTAMP"), + } + assert scd_type_2_model.managed_columns == { + "valid_from": exp.DataType.build("TIMESTAMP"), + "valid_to": exp.DataType.build("TIMESTAMP"), + } + assert scd_type_2_model.kind.updated_at_name == exp.column("updated_at", quoted=True) + assert scd_type_2_model.kind.valid_from_name == exp.column("valid_from", quoted=True) + assert scd_type_2_model.kind.valid_to_name == exp.column("valid_to", quoted=True) + assert not scd_type_2_model.kind.updated_at_as_valid_from + assert scd_type_2_model.kind.is_scd_type_2_by_time + assert scd_type_2_model.kind.is_scd_type_2 + assert scd_type_2_model.kind.is_materialized + assert scd_type_2_model.kind.forward_only + assert scd_type_2_model.kind.disable_restatement + + +def test_scd_type_2_by_time_overrides(): + model_def = d.parse( + """ + MODEL ( + name db.table, + kind SCD_TYPE_2_BY_TIME ( + unique_key ["iD", ds], + updated_at_name test_updated_at, + valid_from_name test_valid_from, + valid_to_name test_valid_to, + time_data_type TIMESTAMPTZ, + updated_at_as_valid_from True, + forward_only False, + disable_restatement False, + invalidate_hard_deletes False, + ), + dialect snowflake + ); + SELECT + 1 as "iD", + '2020-01-01' 
as ds, + '2020-01-01' as test_updated_at, + '2020-01-01' as test_valid_from, + '2020-01-01' as test_valid_to + ; + """ + ) + scd_type_2_model = load_sql_based_model(model_def) + assert scd_type_2_model.unique_key == [ + exp.column("iD", quoted=True), + exp.column("DS", quoted=True), + ] + assert scd_type_2_model.managed_columns == { + "TEST_VALID_FROM": exp.DataType.build("TIMESTAMPTZ"), + "TEST_VALID_TO": exp.DataType.build("TIMESTAMPTZ"), + } + assert scd_type_2_model.kind.updated_at_name == exp.column("TEST_UPDATED_AT", quoted=True) + assert scd_type_2_model.kind.valid_from_name == exp.column("TEST_VALID_FROM", quoted=True) + assert scd_type_2_model.kind.valid_to_name == exp.column("TEST_VALID_TO", quoted=True) + assert scd_type_2_model.kind.updated_at_as_valid_from + assert scd_type_2_model.kind.is_scd_type_2_by_time + assert scd_type_2_model.kind.is_scd_type_2 + assert scd_type_2_model.kind.is_materialized + assert not scd_type_2_model.kind.invalidate_hard_deletes + assert not scd_type_2_model.kind.forward_only + assert not scd_type_2_model.kind.disable_restatement + + model_kind_dict = scd_type_2_model.kind.dict() + assert scd_type_2_model.kind == _model_kind_validator(None, model_kind_dict, None) + + +def test_scd_type_2_by_column_defaults(): + model_def = d.parse( + """ + MODEL ( + name db.table, + kind SCD_TYPE_2_BY_COLUMN ( + unique_key "ID", + columns ["value_to_track"] + ), + ); + SELECT + 1 as "ID", + 2 as "value_to_track", + '2020-01-01' as ds, + ; + """ + ) + scd_type_2_model = load_sql_based_model(model_def) + assert scd_type_2_model.unique_key == [exp.to_column("ID", quoted=True)] + assert scd_type_2_model.kind.columns == [exp.to_column("value_to_track", quoted=True)] + assert scd_type_2_model.columns_to_types == { + "ID": exp.DataType.build("int"), + "value_to_track": exp.DataType.build("int"), + "ds": exp.DataType.build("varchar"), + "valid_from": exp.DataType.build("TIMESTAMP"), + "valid_to": exp.DataType.build("TIMESTAMP"), + } + assert 
scd_type_2_model.managed_columns == { + "valid_from": exp.DataType.build("TIMESTAMP"), + "valid_to": exp.DataType.build("TIMESTAMP"), + } + assert scd_type_2_model.kind.valid_from_name == exp.column("valid_from", quoted=True) + assert scd_type_2_model.kind.valid_to_name == exp.column("valid_to", quoted=True) + assert not scd_type_2_model.kind.execution_time_as_valid_from + assert scd_type_2_model.kind.is_scd_type_2_by_column + assert scd_type_2_model.kind.is_scd_type_2 + assert scd_type_2_model.kind.is_materialized + assert scd_type_2_model.kind.forward_only + assert scd_type_2_model.kind.disable_restatement + + +def test_scd_type_2_by_column_overrides(): + model_def = d.parse( + """ + MODEL ( + name db.table, + kind SCD_TYPE_2_BY_COLUMN ( + unique_key ["iD", ds], + columns "value_to_track", + valid_from_name test_valid_from, + valid_to_name test_valid_to, + execution_time_as_valid_from True, + time_data_type TIMESTAMPTZ, + forward_only False, + disable_restatement False, + invalidate_hard_deletes False, + batch_size 1 + ), + ); + SELECT + 1 as "ID", + 2 as "value_to_track", + '2020-01-01' as ds, + ; + """ + ) + scd_type_2_model = load_sql_based_model(model_def) + assert scd_type_2_model.unique_key == [ + exp.column("iD", quoted=True), + exp.column("ds", quoted=True), + ] + assert scd_type_2_model.managed_columns == { + "test_valid_from": exp.DataType.build("TIMESTAMPTZ"), + "test_valid_to": exp.DataType.build("TIMESTAMPTZ"), + } + assert scd_type_2_model.kind.valid_from_name == exp.column("test_valid_from", quoted=True) + assert scd_type_2_model.kind.valid_to_name == exp.column("test_valid_to", quoted=True) + assert scd_type_2_model.kind.execution_time_as_valid_from + assert scd_type_2_model.kind.is_scd_type_2_by_column + assert scd_type_2_model.kind.is_scd_type_2 + assert scd_type_2_model.kind.is_materialized + assert scd_type_2_model.kind.time_data_type == exp.DataType.build("TIMESTAMPTZ") + assert scd_type_2_model.kind.batch_size == 1 + assert not 
scd_type_2_model.kind.invalidate_hard_deletes + assert not scd_type_2_model.kind.forward_only + assert not scd_type_2_model.kind.disable_restatement + + model_kind_dict = scd_type_2_model.kind.dict() + assert scd_type_2_model.kind == _model_kind_validator(None, model_kind_dict, None) + + +def test_scd_type_2_python_model() -> None: + @model( + "test_scd_type_2_python_model", + kind=dict( + name=ModelKindName.SCD_TYPE_2_BY_TIME, + unique_key="a", + updated_at_name="b", + updated_at_as_valid_from=True, + ), + columns={"a": "string", "b": "string"}, + ) + def scd_type_2_model(context, **kwargs): + return pd.DataFrame( + [ + { + "a": "val1", + "b": "2024-01-01", + } + ] + ) + + python_model = model.get_registry()["test_scd_type_2_python_model"].model( + module_path=Path("."), + path=Path("."), + ) + + assert python_model.columns_to_types == { + "a": exp.DataType.build("string"), + "b": exp.DataType.build("string"), + "valid_from": exp.DataType.build("TIMESTAMP"), + "valid_to": exp.DataType.build("TIMESTAMP"), + } + + +@pytest.mark.parametrize( + "input_columns,expected_columns", + [ + ( + "col1", + [exp.to_column("col1", quoted=True)], + ), + ( + "[col1]", + [exp.to_column("col1", quoted=True)], + ), + ( + "[col1, col2]", + [exp.to_column("col1", quoted=True), exp.to_column("col2", quoted=True)], + ), + ( + '"col1"', + [exp.to_column("col1", quoted=True)], + ), + ( + '["col1"]', + [exp.to_column("col1", quoted=True)], + ), + ("*", [exp.Star()]), + ], +) +def test_check_column_variants(input_columns, expected_columns): + model_def = d.parse( + f""" + MODEL ( + name db.table, + kind SCD_TYPE_2_BY_COLUMN ( + unique_key "ID", + columns {input_columns} + ), + ); + SELECT 1 + ; + """ + ) + scd_type_2_model = load_sql_based_model(model_def) + assert scd_type_2_model.kind.columns == expected_columns + + +def test_model_dialect_name(): + expressions = d.parse( + """ + MODEL ( + name `project-1`.`db`.`tbl1`, + dialect bigquery + ); + SELECT 1; + """ + ) + + model = 
load_sql_based_model(expressions) + assert model.fqn == '"project-1"."db"."tbl1"' + + model = create_external_model( + "`project-1`.`db`.`tbl1`", columns={"x": "STRING"}, dialect="bigquery" + ) + assert "name `project-1`.`db`.`tbl1`" in model.render_definition()[0].sql(dialect="bigquery") + + # This used to fail due to the dialect regex picking up `DIALECT_TEST` as the model's dialect + expressions = d.parse("MODEL(name DIALECT_TEST.foo); SELECT 1") + + +def test_model_allow_partials(): + expressions = d.parse( + """ + MODEL ( + name db.table, + allow_partials true, + ); + SELECT 1; + """ + ) + + model = load_sql_based_model(expressions) + + assert model.allow_partials + + assert "allow_partials TRUE" in model.render_definition()[0].sql() + + +def test_signals(): + expressions = d.parse( + """ + MODEL ( + name db.table, + signals [ + (arg = 1), + ], + ); + SELECT 1; + """ + ) + + model = load_sql_based_model(expressions) + assert model.signals[0][1] == {"arg": exp.Literal.number(1)} + + @signal() + def my_signal(batch): + return True + + expressions = d.parse( + """ + MODEL ( + name db.table, + signals [ + my_signal(arg = 1), + ( + table_name := 'table_a', + ds := @end_ds, + ), + ( + table_name = 'table_b', + ds = @end_ds, + hour = @end_hour, + ), + ( + bool_key = True, + int_key = 1, + float_key = 1.0, + string_key = 'string', + ) + ], + ); + SELECT 1; + """ + ) + + model = load_sql_based_model( + expressions, + signal_definitions={"my_signal": signal.get_registry()["my_signal"]}, + ) + assert model.signals == [ + ( + "my_signal", + { + "arg": exp.Literal.number(1), + }, + ), + ( + "", + { + "table_name": exp.Literal.string("table_a"), + "ds": d.MacroVar(this="end_ds"), + }, + ), + ( + "", + { + "table_name": exp.Literal.string("table_b"), + "ds": d.MacroVar(this="end_ds"), + "hour": d.MacroVar(this="end_hour"), + }, + ), + ( + "", + { + "bool_key": exp.true(), + "int_key": exp.Literal.number(1), + "float_key": exp.Literal.number(1.0), + "string_key": 
exp.Literal.string("string"), + }, + ), + ] + + rendered_signals = model.render_signals(start="2023-01-01", end="2023-01-02 15:00:00") + assert rendered_signals == [ + {"table_name": "table_a", "ds": "2023-01-02"}, + {"table_name": "table_b", "ds": "2023-01-02", "hour": 14}, + {"bool_key": True, "int_key": 1, "float_key": 1.0, "string_key": "string"}, + ] + + assert ( + "signals (MY_SIGNAL(arg := 1), (table_name = 'table_a', ds = @end_ds), (table_name = 'table_b', ds = @end_ds, hour = @end_hour), (bool_key = TRUE, int_key = 1, float_key = 1.0, string_key = 'string'))" + in model.render_definition()[0].sql() + ) + + +def test_load_python_model_with_signals(): + @signal() + def always_true(batch): + return True + + @model( + name="model_with_signal", + kind="full", + columns={'"COL"': "int"}, + signals=[("always_true", {})], + ) + def model_with_signal(context, **kwargs): + return pd.DataFrame([{"COL": 1}]) + + models = model.get_registry()["model_with_signal"].models( + get_variables=lambda _: {}, + path=Path("."), + module_path=Path("."), + signal_definitions=signal.get_registry(), + ) + assert len(models) == 1 + assert models[0].signals == [("always_true", {})] + + +def test_null_column_type(): + expressions = d.parse( + """ + MODEL ( + name test_db.test_model, + columns ( + id INT, + ds NULL, + ) + ); + + SELECT + id::INT AS id, + ds + FROM x + """ + ) + model = load_sql_based_model(expressions, dialect="hive") + assert model.columns_to_types == { + "ds": exp.DataType.build("null"), + "id": exp.DataType.build("int"), + } + assert not model.annotated + + +def test_when_matched(): + expressions = d.parse( + """ + MODEL ( + name db.employees, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key name, + when_matched WHEN MATCHED THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary) + ) + ); + SELECT 'name' AS name, 1 AS salary; + """ + ) + + expected_when_matched = "(WHEN MATCHED THEN UPDATE SET `__MERGE_TARGET__`.`salary` = 
COALESCE(`__MERGE_SOURCE__`.`salary`, `__MERGE_TARGET__`.`salary`))" + + model = load_sql_based_model(expressions, dialect="hive") + assert model.kind.when_matched.sql(dialect="hive") == expected_when_matched + + model = SqlModel.parse_raw(model.json()) + assert model.kind.when_matched.sql(dialect="hive") == expected_when_matched + + expressions = d.parse( + """ + MODEL ( + name @{macro_val}.test, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key purchase_order_id, + when_matched ( + WHEN MATCHED AND source._operation = 1 THEN DELETE + WHEN MATCHED AND source._operation <> 1 THEN UPDATE SET target.purchase_order_id = 1 + ) + ) + ); + + SELECT + purchase_order_id + FROM @{macro_val}.upstream + """ + ) + + model = SqlModel.parse_raw(load_sql_based_model(expressions).json()) + assert d.format_model_expressions(model.render_definition()) == ( + """MODEL ( + name @{macro_val}.test, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key ("purchase_order_id"), + when_matched ( + WHEN MATCHED AND "__MERGE_SOURCE__"."_operation" = 1 THEN DELETE + WHEN MATCHED AND "__MERGE_SOURCE__"."_operation" <> 1 THEN UPDATE SET + "__MERGE_TARGET__"."purchase_order_id" = 1 + ), + batch_concurrency 1, + forward_only FALSE, + disable_restatement FALSE, + on_destructive_change 'ERROR', + on_additive_change 'ALLOW' + ) +); + +SELECT + purchase_order_id +FROM @{macro_val}.upstream""" + ) + + @macro() + def fingerprint_merge( + evaluator: MacroEvaluator, + fingerprint_column: exp.Column, + update_columns: list[exp.Column], + ) -> exp.Whens: + fingerprint_evaluation = f"source.{fingerprint_column} <> target.{fingerprint_column}" + column_update = [f"target.{column} = source.{column}" for column in update_columns] + return exp.maybe_parse( + f"WHEN MATCHED AND {fingerprint_evaluation} THEN UPDATE SET {column_update}", + into=exp.Whens, + ) + + expressions = d.parse( + """ + MODEL ( + name test, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key purchase_order_id, + when_matched (@fingerprint_merge(salary, 
[update_datetime, salary])) + ) + ); + + SELECT + 1 AS purchase_order_id, + 1 AS salary, + CAST('2020-01-01 12:05:01' AS DATETIME) AS update_datetime + """ + ) + + model = SqlModel.parse_raw(load_sql_based_model(expressions).json()) + assert d.format_model_expressions(model.render_definition()) == ( + """MODEL ( + name test, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key ("purchase_order_id"), + when_matched ( + WHEN MATCHED AND "__MERGE_SOURCE__"."salary" <> "__MERGE_TARGET__"."salary" THEN UPDATE SET + ARRAY('target.update_datetime = source.update_datetime', 'target.salary = source.salary') + ), + batch_concurrency 1, + forward_only FALSE, + disable_restatement FALSE, + on_destructive_change 'ERROR', + on_additive_change 'ALLOW' + ) +); + +SELECT + 1 AS purchase_order_id, + 1 AS salary, + '2020-01-01 12:05:01'::DATETIME AS update_datetime""" + ) + + +def test_when_matched_multiple(): + expressions = d.parse( + """ + MODEL ( + name @{schema}.employees, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key name, + when_matched WHEN MATCHED AND source.x = 1 THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary), + WHEN MATCHED THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary) + + ) + ); + SELECT 'name' AS name, 1 AS salary; + """ + ) + + expected_when_matched = [ + "WHEN MATCHED AND `__MERGE_SOURCE__`.`x` = 1 THEN UPDATE SET `__MERGE_TARGET__`.`salary` = COALESCE(`__MERGE_SOURCE__`.`salary`, `__MERGE_TARGET__`.`salary`)", + "WHEN MATCHED THEN UPDATE SET `__MERGE_TARGET__`.`salary` = COALESCE(`__MERGE_SOURCE__`.`salary`, `__MERGE_TARGET__`.`salary`)", + ] + + model = load_sql_based_model(expressions, dialect="hive", variables={"schema": "db"}) + whens = model.kind.when_matched + assert len(whens.expressions) == 2 + assert whens.expressions[0].sql(dialect="hive") == expected_when_matched[0] + assert whens.expressions[1].sql(dialect="hive") == expected_when_matched[1] + + model = SqlModel.parse_raw(model.json()) + whens = 
model.kind.when_matched + assert len(whens.expressions) == 2 + assert whens.expressions[0].sql(dialect="hive") == expected_when_matched[0] + assert whens.expressions[1].sql(dialect="hive") == expected_when_matched[1] + + +def test_when_matched_merge_filter_multi_part_columns(): + expressions = d.parse( + """ + MODEL ( + name @{schema}.records_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key name, + when_matched (WHEN MATCHED AND source.record.nested_record.field = 1 THEN UPDATE SET target.repeated_record.sub_repeated_record.sub_field = COALESCE(source.repeated_record.sub_repeated_record.sub_field, target.repeated_record.sub_repeated_record.sub_field), + WHEN MATCHED THEN UPDATE SET target.repeated_record.sub_repeated_record.sub_field = COALESCE(source.repeated_record.sub_repeated_record.sub_field, target.repeated_record.sub_repeated_record.sub_field)), + merge_filter source.record.nested_record.field < target.record.nested_record.field AND + target.repeated_record.sub_repeated_record.sub_field > source.repeated_record.sub_repeated_record.sub_field + ) + ); + SELECT + id, + [STRUCT([STRUCT(sub_field AS sub_field)] AS sub_repeated_record)] AS repeated_record, + STRUCT( + STRUCT([2, 3] AS array, field AS field) AS nested_record + ) AS record + FROM + @{schema}.seed_model; + """ + ) + + expected_when_matched = [ + "WHEN MATCHED AND `__MERGE_SOURCE__`.`record`.`nested_record`.`field` = 1 THEN UPDATE SET `__MERGE_TARGET__`.`repeated_record`.`sub_repeated_record`.`sub_field` = COALESCE(`__MERGE_SOURCE__`.`repeated_record`.`sub_repeated_record`.`sub_field`, `__MERGE_TARGET__`.`repeated_record`.`sub_repeated_record`.`sub_field`)", + "WHEN MATCHED THEN UPDATE SET `__MERGE_TARGET__`.`repeated_record`.`sub_repeated_record`.`sub_field` = COALESCE(`__MERGE_SOURCE__`.`repeated_record`.`sub_repeated_record`.`sub_field`, `__MERGE_TARGET__`.`repeated_record`.`sub_repeated_record`.`sub_field`)", + ] + + expected_merge_filter = ( + 
"`__MERGE_SOURCE__`.`record`.`nested_record`.`field` < `__MERGE_TARGET__`.`record`.`nested_record`.`field` AND " + "`__MERGE_TARGET__`.`repeated_record`.`sub_repeated_record`.`sub_field` > `__MERGE_SOURCE__`.`repeated_record`.`sub_repeated_record`.`sub_field`" + ) + + model = load_sql_based_model(expressions, dialect="bigquery", variables={"schema": "db"}) + whens = model.kind.when_matched + assert len(whens.expressions) == 2 + assert whens.expressions[0].sql(dialect="bigquery") == expected_when_matched[0] + assert whens.expressions[1].sql(dialect="bigquery") == expected_when_matched[1] + assert model.merge_filter.sql(dialect="bigquery") == expected_merge_filter + + model = SqlModel.parse_raw(model.json()) + whens = model.kind.when_matched + assert len(whens.expressions) == 2 + assert whens.expressions[0].sql(dialect="bigquery") == expected_when_matched[0] + assert whens.expressions[1].sql(dialect="bigquery") == expected_when_matched[1] + assert model.merge_filter.sql(dialect="bigquery") == expected_merge_filter + + +def test_when_matched_normalization() -> None: + # unquoted should be normalized and quoted + expressions = d.parse( + """ + MODEL ( + name test.employees, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key name, + when_matched ( + WHEN MATCHED THEN UPDATE SET + target.key_a = source.key_a, + target.key_b = source.key_b, + ) + ) + ); + SELECT 'name' AS name, 1 AS key_a, 2 AS key_b; + """ + ) + model = load_sql_based_model(expressions, dialect="snowflake") + + assert isinstance(model.kind, IncrementalByUniqueKeyKind) + assert isinstance(model.kind.when_matched, exp.Whens) + first_expression = model.kind.when_matched.expressions[0] + assert isinstance(first_expression, exp.Expression) + assert ( + first_expression.sql(dialect="snowflake") + == 'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."KEY_A" = "__MERGE_SOURCE__"."KEY_A", "__MERGE_TARGET__"."KEY_B" = "__MERGE_SOURCE__"."KEY_B"' + ) + + # quoted should be preserved + expressions = d.parse( + """ + 
MODEL ( + name test.employees, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key name, + when_matched ( + WHEN MATCHED THEN UPDATE SET + target."kEy_A" = source."kEy_A", + target."kEY_b" = source.key_b, + ) + ) + ); + SELECT 'name' AS name, 1 AS "kEy_A", 2 AS "kEY_b"; + """ + ) + model = load_sql_based_model(expressions, dialect="snowflake") + + assert isinstance(model.kind, IncrementalByUniqueKeyKind) + assert isinstance(model.kind.when_matched, exp.Whens) + first_expression = model.kind.when_matched.expressions[0] + assert isinstance(first_expression, exp.Expression) + assert ( + first_expression.sql(dialect="snowflake") + == 'WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."kEy_A" = "__MERGE_SOURCE__"."kEy_A", "__MERGE_TARGET__"."kEY_b" = "__MERGE_SOURCE__"."KEY_B"' + ) + + +def test_default_catalog_sql(assert_exp_eq): + """ + This test validates the hashing behavior of the system as it relates to the default catalog. + The system is not designed to actually support having an engine that doesn't support default catalog + to start supporting it or the reverse of that. If that did happen then bugs would occur. 
+ """ + HASH_WITH_CATALOG = "2768215345" + + # Test setting default catalog doesn't change hash if it matches existing logic + expressions = d.parse( + """ + MODEL ( + name catalog.db.table + ); + SELECT x + FROM catalog.db.source + """ + ) + + model = load_sql_based_model(expressions, default_catalog="catalog") + assert model.default_catalog == "catalog" + assert model.name == "catalog.db.table" + assert model.fqn == '"catalog"."db"."table"' + assert model.depends_on == {'"catalog"."db"."source"'} + + assert_exp_eq( + model.render_query(), + """ + SELECT + "x" AS "x" + FROM "catalog"."db"."source" AS "source" + """, + ) + + assert model.data_hash == HASH_WITH_CATALOG + + expressions = d.parse( + """ + MODEL ( + name catalog.db.table, + ); + SELECT x + FROM catalog.db.source + """ + ) + + model = load_sql_based_model(expressions) + assert model.default_catalog is None + assert model.name == "catalog.db.table" + assert model.fqn == '"catalog"."db"."table"' + assert model.depends_on == {'"catalog"."db"."source"'} + + assert_exp_eq( + model.render_query(), + """ + SELECT + "x" AS "x" + FROM "catalog"."db"."source" AS "source" + """, + ) + + assert model.data_hash == HASH_WITH_CATALOG + + # Test setting default catalog to a different catalog but everything if fully qualified then no hash change + expressions = d.parse( + """ + MODEL ( + name catalog.db.table + ); + SELECT x + FROM catalog.db.source + """ + ) + + model = load_sql_based_model(expressions, default_catalog="other_catalog") + assert model.default_catalog == "other_catalog" + assert model.name == "catalog.db.table" + assert model.fqn == '"catalog"."db"."table"' + assert model.depends_on == {'"catalog"."db"."source"'} + + assert_exp_eq( + model.render_query(), + """ + SELECT + "x" AS "x" + FROM "catalog"."db"."source" AS "source" + """, + ) + + assert model.data_hash == HASH_WITH_CATALOG + + # test that hash changes if model contains a non-fully-qualified reference + expressions = d.parse( + """ + MODEL ( + 
name catalog.db.table + ); + SELECT x + FROM db.source + """ + ) + + model = load_sql_based_model(expressions, default_catalog="other_catalog") + assert model.default_catalog == "other_catalog" + assert model.name == "catalog.db.table" + assert model.fqn == '"catalog"."db"."table"' + assert model.depends_on == {'"other_catalog"."db"."source"'} + + # The query changed so the hash should change + assert model.data_hash != HASH_WITH_CATALOG + + # test that hash is the same but the fqn is different so the snapshot is different so this is + # a new snapshot but with the same hash as before + expressions = d.parse( + """ + MODEL ( + name db.table, + ); + SELECT x + FROM catalog.db.source + """ + ) + + model = load_sql_based_model(expressions) + assert model.default_catalog is None + assert model.name == "db.table" + assert model.fqn == '"db"."table"' + assert model.depends_on == {'"catalog"."db"."source"'} + + assert model.data_hash == HASH_WITH_CATALOG + + # This will also have the same hash but the fqn is different so the snapshot is different so this is + # a new snapshot but with the same hash as before + expressions = d.parse( + """ + MODEL ( + name db.table + ); + SELECT x + FROM catalog.db.source + """ + ) + + model = load_sql_based_model(expressions, default_catalog="catalog") + assert model.default_catalog == "catalog" + assert model.name == "db.table" + assert model.fqn == '"catalog"."db"."table"' + assert model.depends_on == {'"catalog"."db"."source"'} + + assert model.data_hash == HASH_WITH_CATALOG + + # Query is different since default catalog does not apply and therefore the hash is different + expressions = d.parse( + """ + MODEL ( + name table + ); + SELECT x + FROM source + """ + ) + + model = load_sql_based_model(expressions, default_catalog="catalog") + assert model.default_catalog == "catalog" + assert model.name == "table" + assert model.fqn == '"table"' + assert model.depends_on == {'"source"'} + + assert model.data_hash != HASH_WITH_CATALOG + + 
+def test_default_catalog_python(): + HASH_WITH_CATALOG = "2728996410" + + @model(name="db.table", kind="full", columns={'"COL"': "int"}) + def my_model(context, **kwargs): + context.resolve_table("dependency.table") + + m = model.get_registry()["db.table"].model( + module_path=Path("."), + path=Path("."), + ) + + assert m.default_catalog is None + assert m.name == "db.table" + assert m.fqn == '"db"."table"' + assert m.depends_on == {'"dependency"."table"'} + + assert m.data_hash != HASH_WITH_CATALOG + + m = model.get_registry()["db.table"].model( + module_path=Path("."), + path=Path("."), + default_catalog="catalog", + ) + + assert m.default_catalog == "catalog" + assert m.name == "db.table" + assert m.fqn == '"catalog"."db"."table"' + assert m.depends_on == {'"catalog"."dependency"."table"'} + + # This ideally would be `m.data_hash == HASH_WITH_CATALOG`. The reason it is not is + # because when we hash the python function we make the hash out of the actual logic + # of the function which means `context.resolve_table("dependency.table")` is used + # when really is should be `context.resolve_table("catalog.dependency.table")`. 
+ assert m.data_hash != HASH_WITH_CATALOG + + @model(name="catalog.db.table", kind="full", columns={'"COL"': "int"}) + def my_model(context, **kwargs): + context.resolve_table("catalog.dependency.table") + + m = model.get_registry()["catalog.db.table"].model( + module_path=Path("."), + path=Path("."), + default_catalog="other_catalog", + ) + + assert m.default_catalog == "other_catalog" + assert m.name == "catalog.db.table" + assert m.fqn == '"catalog"."db"."table"' + assert m.depends_on == {'"catalog"."dependency"."table"'} + + assert m.data_hash == HASH_WITH_CATALOG + + @model(name="catalog.db.table2", kind="full", columns={'"COL"': "int"}) + def my_model(context, **kwargs): + context.resolve_table("dependency.table") + + m = model.get_registry()["catalog.db.table2"].model( + module_path=Path("."), + path=Path("."), + default_catalog="other_catalog", + ) + + assert m.default_catalog == "other_catalog" + assert m.name == "catalog.db.table2" + assert m.fqn == '"catalog"."db"."table2"' + assert m.depends_on == {'"other_catalog"."dependency"."table"'} + + assert m.data_hash != HASH_WITH_CATALOG + + @model(name="table", kind="full", columns={'"COL"': "int"}) + def my_model(context, **kwargs): + context.resolve_table("table2") + + m = model.get_registry()["table"].model( + module_path=Path("."), + path=Path("."), + default_catalog="catalog", + ) + + assert m.default_catalog == "catalog" + assert m.name == "table" + assert m.fqn == '"table"' + assert m.depends_on == {'"table2"'} + + assert m.data_hash != HASH_WITH_CATALOG + + +def test_default_catalog_external_model(): + """ + Since external models fqns are the only thing affected by default catalog, and when they change new snapshots + are made, the hash will be the same across different names. 
+ """ + EXPECTED_HASH = "763256265" + + model = create_external_model("db.table", columns={"a": "int", "limit": "int"}) + assert model.default_catalog is None + assert model.name == "db.table" + assert model.fqn == '"db"."table"' + + assert model.data_hash == EXPECTED_HASH + + model = create_external_model( + "db.table", columns={"a": "int", "limit": "int"}, default_catalog="catalog" + ) + assert model.default_catalog == "catalog" + assert model.name == "db.table" + assert model.fqn == '"catalog"."db"."table"' + + assert model.data_hash == EXPECTED_HASH + + model = create_external_model( + "catalog.db.table", columns={"a": "int", "limit": "int"}, default_catalog="other_catalog" + ) + assert model.default_catalog == "other_catalog" + assert model.name == "catalog.db.table" + assert model.fqn == '"catalog"."db"."table"' + + assert model.data_hash == EXPECTED_HASH + + # Since there is no schema defined, the default physical schema is used which changes the hash + model = create_external_model( + "table", columns={"a": "int", "limit": "int"}, default_catalog="catalog" + ) + + assert model.default_catalog == "catalog" + assert model.name == "table" + assert model.fqn == '"table"' + + assert model.data_hash != EXPECTED_HASH + + +def test_user_cannot_set_default_catalog(): + expressions = d.parse( + """ + MODEL ( + name db.table, + default_catalog some_catalog + ); + + SELECT 1::int AS a, 2::int AS b, 3 AS c, 4 as d; + """ + ) + + with pytest.raises(ConfigError, match="`default_catalog` cannot be set on a per-model basis"): + load_sql_based_model(expressions) + + with pytest.raises(ConfigError, match="`default_catalog` cannot be set on a per-model basis"): + + @model(name="db.table", kind="full", columns={'"COL"': "int"}, default_catalog="catalog") + def my_model(context, **kwargs): + context.resolve_table("dependency.table") + + +def test_depends_on_default_catalog_python(): + @model(name="some.table", kind="full", columns={'"COL"': "int"}, depends_on={"other.table"}) + 
def my_model(context, **kwargs): + context.resolve_table("dependency.table") + + m = model.get_registry()["some.table"].model( + module_path=Path("."), + path=Path("."), + default_catalog="catalog", + ) + + assert m.default_catalog == "catalog" + assert m.depends_on == {'"catalog"."other"."table"'} + + +def test_end_date(): + expressions = d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ts, + ), + start '2023-01-01', + end '2023-06-01' + ); + + SELECT 1::int AS a, 2::int AS b, now::timestamp as ts + """ + ) + model = load_sql_based_model(expressions) + + assert model.start == "2023-01-01" + assert model.end == "2023-06-01" + assert model.interval_unit == IntervalUnit.DAY + + with pytest.raises(ValidationError, match=".*Start date.+can't be greater than end date.*"): + load_sql_based_model( + d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ts, + ), + start '2024-01-01', + end '2023-06-01' + ); + + SELECT 1::int AS a, 2::int AS b, now::timestamp as ts + """ + ) + ) + + +def test_end_no_start(): + expressions = d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ts, + ), + end '2023-06-01' + ); + + SELECT 1::int AS a, 2::int AS b, now::timestamp as ts + """ + ) + with pytest.raises(ConfigError, match="Must define a start date if an end date is defined"): + load_sql_based_model(expressions) + load_sql_based_model(expressions, defaults={"start": "2023-01-01"}) + + +def test_variables(): + @macro() + def test_macro_var(evaluator) -> exp.Expression: + return exp.convert(evaluator.var("TEST_VAR_D") + 10) + + expressions = parse( + """ + MODEL( + name test_model, + kind FULL, + ); + + SELECT + @VAR('TEST_VAR_A') AS a, + @VAR('test_var_b', 'default_value') AS b, + @VAR('test_var_c') AS c, + @TEST_MACRO_VAR() AS d, + @'foo_@{test_var_e}' AS e, + @SQL(foo_@{test_var_f}) AS f, + 'foo_@{test_var_unused}' AS g + """, + default_dialect="bigquery", + ) + + 
model = load_sql_based_model( + expressions, + variables={ + "test_var_a": "test_value", + "test_var_d": 1, + "test_var_e": 4, + "test_var_f": 5, + "test_var_unused": 2, + }, + ) + assert model.python_env[c.SQLMESH_VARS] == Executable.value( + {"test_var_a": "test_value", "test_var_d": 1, "test_var_e": 4, "test_var_f": 5} + ) + assert ( + model.render_query().sql(dialect="bigquery") + == "SELECT 'test_value' AS `a`, 'default_value' AS `b`, NULL AS `c`, 11 AS `d`, 'foo_4' AS `e`, `foo_5` AS `f`, 'foo_@{test_var_unused}' AS `g`" + ) + + with pytest.raises(ConfigError, match=r"Macro VAR requires at least one argument.*"): + expressions = parse( + """ + MODEL( + name test_model, + ); + + SELECT @VAR() AS a; + """, + default_dialect="bigquery", + ) + load_sql_based_model(expressions) + + with pytest.raises( + ConfigError, match=r"The variable name must be a string literal, '123' was given instead.*" + ): + expressions = parse( + """ + MODEL( + name test_model, + ); + + SELECT @VAR(123) AS a; + """, + default_dialect="bigquery", + ) + load_sql_based_model(expressions) + + with pytest.raises( + ConfigError, + match=r"The variable name must be a string literal, '@VAR_NAME' was given instead.*", + ): + expressions = parse( + """ + MODEL( + name test_model, + ); + + @DEF(VAR_NAME, 'var_name'); + SELECT @VAR(@VAR_NAME) AS a; + """, + default_dialect="bigquery", + ) + load_sql_based_model(expressions) + + +def test_named_variable_macros() -> None: + model = load_sql_based_model( + parse( + """ + MODEL(name sushi.test_gateway_macro); + @DEF(overridden_var, 'overridden_value'); + SELECT @gateway AS gateway, @TEST_VAR_A AS test_var_a, @overridden_var AS overridden_var + """ + ), + variables={ + c.GATEWAY: "in_memory", + "test_var_a": "test_value", + "test_var_unused": "unused", + "overridden_var": "initial_value", + }, + ) + + assert model.python_env[c.SQLMESH_VARS] == Executable.value( + {c.GATEWAY: "in_memory", "test_var_a": "test_value", "overridden_var": "initial_value"}, + 
sort_root_dict=True, + ) + assert ( + model.render_query_or_raise().sql() + == "SELECT 'in_memory' AS \"gateway\", 'test_value' AS \"test_var_a\", 'overridden_value' AS \"overridden_var\"" + ) + + +def test_variables_in_templates() -> None: + model = load_sql_based_model( + parse( + """ + MODEL(name sushi.test_gateway_macro); + @DEF(overridden_var, overridden_value); + SELECT 'gateway' AS col_@gateway, 'test_var_a' AS @{test_var_a}_col, 'overridden_var' AS col_@{overridden_var}_col + """ + ), + variables={ + c.GATEWAY: "in_memory", + "test_var_a": "test_value", + "test_var_unused": "unused", + "overridden_var": "initial_value", + }, + ) + + assert model.python_env[c.SQLMESH_VARS] == Executable.value( + {c.GATEWAY: "in_memory", "test_var_a": "test_value", "overridden_var": "initial_value"}, + sort_root_dict=True, + ) + assert ( + model.render_query_or_raise().sql() + == "SELECT 'gateway' AS \"col_in_memory\", 'test_var_a' AS \"test_value_col\", 'overridden_var' AS \"col_overridden_value_col\"" + ) + + model = load_sql_based_model( + parse( + """ + MODEL(name sushi.test_gateway_macro); + @DEF(overridden_var, overridden_value); + SELECT 'combo' AS col_@{test_var_a}_@{overridden_var}_col_@gateway + """ + ), + variables={ + c.GATEWAY: "in_memory", + "test_var_a": "test_value", + "test_var_unused": "unused", + "overridden_var": "initial_value", + }, + ) + + assert model.python_env[c.SQLMESH_VARS] == Executable.value( + {c.GATEWAY: "in_memory", "test_var_a": "test_value", "overridden_var": "initial_value"}, + sort_root_dict=True, + ) + assert ( + model.render_query_or_raise().sql() + == "SELECT 'combo' AS \"col_test_value_overridden_value_col_in_memory\"" + ) + + model = load_sql_based_model( + parse( + """ + MODEL( + name @{some_var}.bar, + dialect snowflake + ); + + SELECT 1 AS c + """ + ), + variables={ + "some_var": "foo", + }, + ) + + assert model.name == "foo.bar" + + +def test_variables_jinja(): + expressions = parse( + """ + MODEL( + name test_model, + kind FULL, 
+ ); + + JINJA_QUERY_BEGIN; + SELECT '{{ var('TEST_VAR_A') }}' AS a, '{{ var('test_var_b', 'default_value') }}' AS b, '{{ var('test_var_c') }}' AS c, {{ test_macro_var() }} AS d; + JINJA_END; + """, + default_dialect="bigquery", + ) + + jinja_macros = JinjaMacroRegistry( + root_macros={ + "test_macro_var": MacroInfo( + definition="{% macro test_macro_var() %}{{ var('test_var_d') + 10 }}{% endmacro %}", + depends_on=[], + ) + }, + ) + + model = load_sql_based_model( + expressions, + variables={"test_var_a": "test_value", "test_var_d": 1, "test_var_unused": 2}, + jinja_macros=jinja_macros, + ) + assert model.python_env[c.SQLMESH_VARS] == Executable.value( + {"test_var_a": "test_value", "test_var_d": 1} + ) + assert ( + model.render_query().sql(dialect="bigquery") + == "SELECT 'test_value' AS `a`, 'default_value' AS `b`, 'None' AS `c`, 11 AS `d`" + ) + + +def test_variables_python_model(mocker: MockerFixture) -> None: + @model( + "foo_@{bar}", + kind="full", + columns={"a": "string", "b": "string", "c": "string"}, + ) + def model_with_variables(context, **kwargs): + return pd.DataFrame( + [ + { + "a": context.var("TEST_VAR_A"), + "b": context.var("test_var_b", "default_value"), + "c": context.var("test_var_c"), + } + ] + ) + + python_model = model.get_registry()["foo_@{bar}"].model( + module_path=Path("."), + path=Path("."), + variables={"test_var_a": "test_value", "test_var_unused": 2, "bar": "suffix"}, + ) + + assert python_model.name == "foo_suffix" + assert python_model.python_env[c.SQLMESH_VARS] == Executable.value({"test_var_a": "test_value"}) + + context = ExecutionContext(mocker.Mock(), {}, None, None) + df = list(python_model.render(context=context))[0] + assert df.to_dict(orient="records") == [{"a": "test_value", "b": "default_value", "c": None}] + + +def test_load_external_model_python(sushi_context) -> None: + @model( + "test_load_external_model_python", + columns={"customer_id": "int", "zip": "str"}, + kind={"name": ModelKindName.FULL}, + ) + def 
external_model_python(context, **kwargs): + demographics_table = context.resolve_table("memory.raw.demographics") + return context.fetchdf( + exp.select("customer_id", "zip").from_(demographics_table), + ) + + python_model = model.get_registry()["test_load_external_model_python"].model( + module_path=Path("."), + path=Path("."), + ) + + context = ExecutionContext(sushi_context.engine_adapter, sushi_context.snapshots, None, None) + df = list(python_model.render(context=context))[0] + + assert df.to_dict(orient="records") == [{"customer_id": 1, "zip": "00000"}] + + +def test_variables_python_sql_model(mocker: MockerFixture) -> None: + @model( + "test_variables_python_model_@{bar}", + is_sql=True, + kind="full", + columns={"a": "string", "b": "string", "c": "string"}, + ) + def model_with_variables(evaluator, **kwargs): + return exp.select( + exp.convert(evaluator.var("TEST_VAR_A")).as_("a"), + exp.convert(evaluator.var("test_var_b", "default_value")).as_("b"), + exp.convert(evaluator.var("test_var_c")).as_("c"), + ) + + python_sql_model = model.get_registry()["test_variables_python_model_@{bar}"].model( + module_path=Path("."), + path=Path("."), + variables={"test_var_a": "test_value", "test_var_unused": 2, "bar": "suffix"}, + ) + + assert python_sql_model.name == "test_variables_python_model_suffix" + assert python_sql_model.python_env[c.SQLMESH_VARS] == Executable.value( + {"test_var_a": "test_value"} + ) + + context = ExecutionContext(mocker.Mock(), {}, None, None) + query = list(python_sql_model.render(context=context))[0] + assert ( + query.sql() + == """SELECT 'test_value' AS "a", 'default_value' AS "b", NULL AS "c" """.strip() + ) + + +def test_macros_python_model(mocker: MockerFixture) -> None: + @model( + "foo_macro_model_@{bar}", + columns={"a": "string"}, + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="@{time_col}"), + stamp="@{stamp}", + cron="@some_cron_var", + owner="@IF(@gateway = 'dev', @{dev_owner}, @{prod_owner})", + 
enabled="@IF(@gateway = 'dev', True, False)", + start="@IF(@gateway = 'dev', '1 month ago', '2024-01-01')", + partitioned_by=[ + d.parse_one("DATETIME_TRUNC(@{time_col}, MONTH)"), + ], + ) + def model_with_macros(context, **kwargs): + return pd.DataFrame( + [ + { + "a": context.var("TEST_VAR_A"), + } + ] + ) + + python_model = model.get_registry()["foo_macro_model_@{bar}"].model( + module_path=Path("."), + path=Path("."), + variables={ + "test_var_a": "test_value", + "gateway": "prod", + "bar": "suffix", + "dev_owner": "dv_1", + "prod_owner": "pr_1", + "stamp": "bump", + "time_col": "a", + "some_cron_var": "@daily", + }, + ) + + assert python_model.name == "foo_macro_model_suffix" + assert python_model.python_env[c.SQLMESH_VARS] == Executable.value({"test_var_a": "test_value"}) + assert not python_model.enabled + assert python_model.start == "2024-01-01" + assert python_model.owner == "pr_1" + assert python_model.stamp == "bump" + assert python_model.time_column.column == exp.column("a", quoted=True) + assert python_model.partitioned_by[0].sql() == 'DATETIME_TRUNC("a", MONTH)' + assert python_model.cron == "@daily" + + context = ExecutionContext(mocker.Mock(), {}, None, None) + df = list(python_model.render(context=context))[0] + assert df.to_dict(orient="records") == [{"a": "test_value"}] + + +def test_macros_python_sql_model(mocker: MockerFixture) -> None: + @macro() + def end_date_macro(evaluator: MacroEvaluator, var: bool): + return f"@IF({var} = False, '1 day ago', '2025-01-01 12:00:00')" + + @model( + "test_macros_python_model_@{bar}", + is_sql=True, + kind="full", + cron="@daily", + columns={"a": "string"}, + enabled="@IF(@gateway = 'dev', True, False)", + start="@IF(@gateway = 'dev', '1 month ago', '2024-01-01')", + end="@end_date_macro(@{global_var})", + owner="@IF(@gateway = 'dev', @{dev_owner}, @{prod_owner})", + stamp="@{stamp}", + tags=["@{tag1}", "@{tag2}"], + description="'Model desc @{test_}'", + ) + def model_with_macros(evaluator, **kwargs): + 
return exp.select( + exp.convert(evaluator.var("TEST_VAR_A")).as_("a"), + ) + + python_sql_model = model.get_registry()["test_macros_python_model_@{bar}"].model( + module_path=Path("."), + path=Path("."), + macros=macro.get_registry(), + variables={ + "test_var_a": "test_value", + "test_var_unused": 2, + "bar": "suffix", + "gateway": "dev", + "global_var": False, + "dev_owner": "dv_1", + "prod_owner": "pr_1", + "stamp": "bump", + "time_col": "a", + "tag1": "tag__1", + "tag2": "tag__2", + }, + ) + + assert python_sql_model.name == "test_macros_python_model_suffix" + assert python_sql_model.python_env[c.SQLMESH_VARS] == Executable.value( + {"test_var_a": "test_value"} + ) + + assert python_sql_model.enabled + assert python_sql_model.start == "1 month ago" + assert python_sql_model.end == "1 day ago" + assert python_sql_model.owner == "dv_1" + assert python_sql_model.stamp == "bump" + assert python_sql_model.description == "Model desc @{test_}" + assert python_sql_model.tags == ["tag__1", "tag__2"] + + context = ExecutionContext(mocker.Mock(), {}, None, None) + query = list(python_sql_model.render(context=context))[0] + assert query.sql() == """SELECT 'test_value' AS "a" """.strip() + + +def test_unrendered_macros_sql_model(mocker: MockerFixture) -> None: + model = load_sql_based_model( + parse( + """ + MODEL ( + name db.employees, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key @{key}, + merge_filter source.id > 0 and target.updated_at < @end_ds and source.updated_at > @start_ds and @merge_filter_var + ), + cron '@daily', + allow_partials @IF(@gateway = 'dev', True, False), + physical_properties ( + location1 = @'s3://bucket/prefix/@{schema_name}/@{table_name}', + location2 = @IF(@gateway = 'dev', @'hdfs://@{catalog_name}/@{schema_name}/dev/@{table_name}', @'s3://prod/@{table_name}'), + foo = @physical_var + ), + virtual_properties ( + creatable_type = @{create_type}, + bar = @virtual_var, + ), + session_properties ( + 'spark.executor.cores' = @IF(@gateway = 'dev', 
1, 2), + 'spark.executor.memory' = '1G', + baz = @session_var + ), + ); + + SELECT * FROM src; + """ + ), + variables={ + "gateway": "dev", + "key": "a", # Not included in python_env because kind is rendered at load time + "create_type": "'SECURE'", + "merge_filter_var": True, + "physical_var": "bla", + "virtual_var": "blb", + "session_var": "blc", + }, + ) + + assert model.python_env[c.SQLMESH_VARS] == Executable.value( + { + "gateway": "dev", + "create_type": "'SECURE'", + "merge_filter_var": True, + "physical_var": "bla", + "virtual_var": "blb", + "session_var": "blc", + }, + sort_root_dict=True, + ) + + assert "location1" in model.physical_properties + assert "location2" in model.physical_properties + + # The properties will stay unrendered at load time + assert model.session_properties == { + "spark.executor.cores": exp.maybe_parse("@IF(@gateway = 'dev', 1, 2)"), + "spark.executor.memory": "1G", + "baz": exp.maybe_parse("@session_var"), + } + assert model.virtual_properties["creatable_type"] == exp.maybe_parse("@{create_type}") + + assert ( + model.physical_properties["location1"].sql() + == "@'s3://bucket/prefix/@{schema_name}/@{table_name}'" + ) + assert ( + model.physical_properties["location2"].sql() + == "@IF(@gateway = 'dev', @'hdfs://@{catalog_name}/@{schema_name}/dev/@{table_name}', @'s3://prod/@{table_name}')" + ) + + # merge_filter will stay unrendered as well + assert model.unique_key[0] == exp.column("a", quoted=True) + assert ( + t.cast(exp.Expression, model.merge_filter).sql() + == '"__MERGE_SOURCE__"."id" > 0 AND "__MERGE_TARGET__"."updated_at" < @end_ds AND "__MERGE_SOURCE__"."updated_at" > @start_ds AND @merge_filter_var' + ) + + +def test_unrendered_macros_python_model(mocker: MockerFixture) -> None: + @model( + "test_unrendered_macros_python_model_@{bar}", + is_sql=True, + kind=dict( + name=ModelKindName.INCREMENTAL_BY_UNIQUE_KEY, + unique_key="@{key}", + merge_filter="source.id > 0 and target.updated_at < @end_ds and source.updated_at > 
@start_ds and @merge_filter_var", + ), + cron="@daily", + columns={"a": "string"}, + allow_partials="@IF(@gateway = 'dev', True, False)", + physical_properties=dict( + location1="@'s3://bucket/prefix/@{schema_name}/@{table_name}'", + location2="@IF(@gateway = 'dev', @'hdfs://@{catalog_name}/@{schema_name}/dev/@{table_name}', @'s3://prod/@{table_name}')", + foo="@physical_var", + ), + virtual_properties={"creatable_type": "@{create_type}", "bar": "@virtual_var"}, + session_properties={ + "spark.executor.cores": "@IF(@gateway = 'dev', 1, 2)", + "spark.executor.memory": "1G", + "baz": "@session_var", + }, + ) + def model_with_macros(evaluator, **kwargs): + return exp.select( + exp.convert(evaluator.var("TEST_VAR_A")).as_("a"), + ) + + python_sql_model = model.get_registry()["test_unrendered_macros_python_model_@{bar}"].model( + module_path=Path("."), + path=Path("."), + macros=macro.get_registry(), + variables={ + "test_var_a": "test_value", + "bar": "suffix", + "gateway": "dev", + "key": "a", + "create_type": "'SECURE'", + "merge_filter_var": True, + "physical_var": "bla", + "virtual_var": "blb", + "session_var": "blc", + }, + ) + + assert python_sql_model.name == "test_unrendered_macros_python_model_suffix" + assert python_sql_model.python_env[c.SQLMESH_VARS] == Executable.value( + { + "test_var_a": "test_value", + "gateway": "dev", + "create_type": "'SECURE'", + "merge_filter_var": True, + "physical_var": "bla", + "virtual_var": "blb", + "session_var": "blc", + }, + sort_root_dict=True, + ) + assert python_sql_model.enabled + + context = ExecutionContext(mocker.Mock(), {}, None, None) + query = list(python_sql_model.render(context=context))[0] + assert query.sql() == """SELECT 'test_value' AS "a" """.strip() + assert python_sql_model.allow_partials + + assert "location1" in python_sql_model.physical_properties + assert "location2" in python_sql_model.physical_properties + + # The properties will stay unrendered at load time + assert 
python_sql_model.session_properties == { + "spark.executor.cores": exp.maybe_parse("@IF(@gateway = 'dev', 1, 2)"), + "spark.executor.memory": "1G", + "baz": exp.maybe_parse("@session_var"), + } + assert python_sql_model.virtual_properties["creatable_type"] == exp.maybe_parse( + "@{create_type}" + ) + + assert ( + python_sql_model.physical_properties["location1"].sql() + == "@'s3://bucket/prefix/@{schema_name}/@{table_name}'" + ) + assert ( + python_sql_model.physical_properties["location2"].sql() + == "@IF(@gateway = 'dev', @'hdfs://@{catalog_name}/@{schema_name}/dev/@{table_name}', @'s3://prod/@{table_name}')" + ) + + # merge_filter will stay unrendered as well + assert python_sql_model.unique_key[0] == exp.column("a", quoted=True) + assert ( + python_sql_model.merge_filter.sql() + == '"__MERGE_SOURCE__"."id" > 0 AND "__MERGE_TARGET__"."updated_at" < @end_ds AND "__MERGE_SOURCE__"."updated_at" > @start_ds AND @merge_filter_var' + ) + + +def test_columns_python_sql_model() -> None: + @model( + "test_columns_python_model", + is_sql=True, + kind="full", + columns={"d": "Date", "s": "String", "dt": "DateTime"}, + ) + def model_with_columns(evaluator, **kwargs): + return exp.select("*").from_("fake") + + python_sql_model = model.get_registry()["test_columns_python_model"].model( + module_path=Path("."), + path=Path("."), + ) + + columns_to_types = python_sql_model.columns_to_types + + assert columns_to_types is not None + assert isinstance(columns_to_types["d"], exp.DataType) + assert columns_to_types["d"].this == exp.DataType.Type.DATE + assert isinstance(columns_to_types["s"], exp.DataType) + assert columns_to_types["s"].this == exp.DataType.Type.TEXT + assert isinstance(columns_to_types["dt"], exp.DataType) + assert columns_to_types["dt"].this == exp.DataType.Type.DATETIME + + +def test_named_variables_python_model(mocker: MockerFixture) -> None: + mocker.patch("sqlmesh.core.model.decorator.model._registry", {}) + + @model( + "test_named_variables_python_model", + 
kind="full", + columns={"a": "string", "b": "string", "c": "string"}, + ) + def model_with_named_variables( + context, start: TimeLike, test_var_a: str, test_var_b: t.Optional[str] = None, **kwargs + ): + return pd.DataFrame( + [{"a": test_var_a, "b": test_var_b, "start": start.strftime("%Y-%m-%d")}] # type: ignore + ) + + python_model = model.get_registry()["test_named_variables_python_model"].model( + module_path=Path("."), + path=Path("."), + # Passing `start` in variables to make sure that built-in arguments can't be overridden. + variables={ + "test_var_a": "test_value", + "test_var_unused": 2, + "start": "2024-01-01", + }, + ) + + assert python_model.python_env[c.SQLMESH_VARS] == Executable.value( + {"test_var_a": "test_value", "start": "2024-01-01"}, sort_root_dict=True + ) + + context = ExecutionContext(mocker.Mock(), {}, None, None) + df = list(python_model.render(context=context))[0] + assert df.to_dict(orient="records") == [{"a": "test_value", "b": None, "start": to_ds(c.EPOCH)}] + + +def test_named_variables_kw_only_python_model(mocker: MockerFixture) -> None: + mocker.patch("sqlmesh.core.model.decorator.model._registry", {}) + + @model( + "test_named_variables_python_model", + kind="full", + columns={"a": "string"}, + ) + def model_with_named_kw_only_variables( + context, start: TimeLike, *, test_var_a: str = "", **kwargs: t.Any + ): + return pd.DataFrame([{"a": test_var_a}]) + + python_model = model.get_registry()["test_named_variables_python_model"].model( + module_path=Path("."), + path=Path("."), + variables={"test_var_a": "test_value"}, + ) + + assert python_model.python_env[c.SQLMESH_VARS] == Executable.value({"test_var_a": "test_value"}) + + context = ExecutionContext(mocker.Mock(), {}, None, None) + df = list(python_model.render(context=context))[0] + assert df.to_dict(orient="records") == [{"a": "test_value"}] + + +def test_gateway_macro() -> None: + model = load_sql_based_model( + parse( + """ + MODEL(name sushi.test_gateway_macro); + SELECT 
@gateway AS gateway + """ + ), + variables={c.GATEWAY: "in_memory"}, + ) + + assert model.python_env[c.SQLMESH_VARS] == Executable.value({c.GATEWAY: "in_memory"}) + assert model.render_query_or_raise().sql() == "SELECT 'in_memory' AS \"gateway\"" + + @macro() + def macro_uses_gateway(evaluator) -> exp.Expression: + return exp.convert(evaluator.gateway + "_from_macro") + + model = load_sql_based_model( + parse( + """ + MODEL(name sushi.test_gateway_macro); + SELECT @macro_uses_gateway() AS gateway_from_macro + """ + ), + variables={c.GATEWAY: "in_memory"}, + ) + + assert model.python_env[c.SQLMESH_VARS] == Executable.value({c.GATEWAY: "in_memory"}) + assert ( + model.render_query_or_raise().sql() + == "SELECT 'in_memory_from_macro' AS \"gateway_from_macro\"" + ) + + +def test_gateway_macro_jinja() -> None: + model = load_sql_based_model( + parse( + """ + MODEL(name sushi.test_gateway_macro_jinja); + JINJA_QUERY_BEGIN; + SELECT '{{ gateway() }}' AS gateway_jinja; + JINJA_END; + """ + ), + variables={c.GATEWAY: "in_memory"}, + ) + + assert model.python_env[c.SQLMESH_VARS] == Executable.value({c.GATEWAY: "in_memory"}) + assert model.render_query_or_raise().sql() == "SELECT 'in_memory' AS \"gateway_jinja\"" + + +def test_gateway_python_model(mocker: MockerFixture) -> None: + @model( + "test_gateway_python_model", + kind="full", + columns={"gateway_python": "string"}, + ) + def model_with_variables(context, **kwargs): + return pd.DataFrame([{"gateway_python": context.gateway + "_from_python"}]) + + python_model = model.get_registry()["test_gateway_python_model"].model( + module_path=Path("."), + path=Path("."), + variables={c.GATEWAY: "in_memory"}, + ) + + assert python_model.python_env[c.SQLMESH_VARS] == Executable.value({c.GATEWAY: "in_memory"}) + + context = ExecutionContext(mocker.Mock(), {}, None, None) + df = list(python_model.render(context=context))[0] + assert df.to_dict(orient="records") == [{"gateway_python": "in_memory_from_python"}] + + 
+@pytest.mark.parametrize("dialect", ["spark", "trino"]) +def test_view_render_no_quote_identifiers(dialect: str) -> None: + expressions = d.parse( + """ + MODEL ( + name db.table, + kind VIEW, + ); + SELECT a, b, c FROM source_table; + """ + ) + model = load_sql_based_model(expressions, dialect=dialect) + assert ( + model.render_query_or_raise().sql(dialect=dialect) + == "SELECT a AS a, b AS b, c AS c FROM source_table AS source_table" + ) + + +@pytest.mark.parametrize( + "dialect,kind", + [ + ("spark", "FULL"), + ("trino", "FULL"), + ("duckdb", "VIEW"), + ("duckdb", "FULL"), + ], +) +def test_render_quote_identifiers(dialect: str, kind: str) -> None: + expressions = d.parse( + f""" + MODEL ( + name db.table, + kind {kind}, + ); + SELECT a, b, c FROM source_table; + """ + ) + model = load_sql_based_model(expressions, dialect=dialect) + assert ( + model.render_query_or_raise().sql(dialect="duckdb") + == 'SELECT "a" AS "a", "b" AS "b", "c" AS "c" FROM "source_table" AS "source_table"' + ) + + +def test_this_model() -> None: + expressions = d.parse( + """ + MODEL ( + name `project-1.table`, + dialect bigquery, + ); + + JINJA_STATEMENT_BEGIN; + VACUUM {{ this_model }} TO 'a'; + JINJA_END; + + JINJA_QUERY_BEGIN; + SELECT '{{ this_model }}' as x + JINJA_END; + + JINJA_STATEMENT_BEGIN; + VACUUM {{ this_model }} TO 'b'; + JINJA_END; + """ + ) + model = load_sql_based_model(expressions) + + assert ( + model.render_query_or_raise().sql(dialect="bigquery") + == """SELECT '`project-1`.`table`' AS `x`""" + ) + + assert ( + model.render_pre_statements()[0].sql(dialect="bigquery") + == """VACUUM `project-1`.`table` TO 'a'""" + ) + assert ( + model.render_post_statements()[0].sql(dialect="bigquery") + == """VACUUM `project-1`.`table` TO 'b'""" + ) + + snapshot = Snapshot.from_node(model, nodes={}) + + assert ( + model.render_query_or_raise( + start="2020-01-01", + snapshots={snapshot.name: snapshot}, + ).sql(dialect="bigquery") + == """SELECT '`project-1`.`table`' AS `x`""" + ) + 
+ snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + assert ( + model.render_query_or_raise( + start="2020-01-01", + snapshots={snapshot.name: snapshot}, + ).sql(dialect="bigquery") + == f"SELECT '`sqlmesh__project-1`.`project_1__table__{snapshot.version}`' AS `x`" + ) + + @macro() + def this_model_resolves_to_quoted_table(evaluator): + this_model = evaluator.locals.get("this_model") + if not this_model: + return True + + this_snapshot = evaluator.get_snapshot("db.table") + if this_snapshot and this_snapshot.version: + # If the table wasn't quoted, we'd break the `sqlmesh_DB` reference by + # normalizing it twice + expected_name = f'"sqlmesh__DB"."DB__TABLE__{this_snapshot.version}"' + else: + expected_name = '"DB"."TABLE"' + + return not this_model or ( + isinstance(this_model, exp.Table) + and this_model.sql(dialect=evaluator.dialect, comments=False) == expected_name + and evaluator.this_model == expected_name + ) + + expressions = d.parse( + """ + MODEL (name db.table, dialect snowflake); + + SELECT + 1 AS col, + @this_model_resolves_to_quoted_table() AS this_model_resolves_to_quoted_table; + + CREATE TABLE db.other AS SELECT * FROM @this_model AS x; + """ + ) + model = load_sql_based_model(expressions) + + expected_post = d.parse('CREATE TABLE "DB"."OTHER" AS SELECT * FROM "DB"."TABLE" AS "X";') + assert model.render_post_statements() == expected_post + + snapshot = Snapshot.from_node(model, nodes={}) + assert ( + model.render_query_or_raise(snapshots={snapshot.name: snapshot}, start="2020-01-01").sql( + dialect="snowflake" + ) + == 'SELECT 1 AS "COL", TRUE AS "THIS_MODEL_RESOLVES_TO_QUOTED_TABLE"' + ) + + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + assert ( + model.render_query_or_raise(snapshots={snapshot.name: snapshot}, start="2021-01-01").sql( + dialect="snowflake" + ) + == 'SELECT 1 AS "COL", TRUE AS "THIS_MODEL_RESOLVES_TO_QUOTED_TABLE"' + ) + + +def test_macros_in_physical_properties(make_snapshot): + expressions = d.parse( + """ 
+ MODEL ( + name test.test_model, + kind FULL, + physical_properties ( + location1 = @resolve_template('s3://bucket/prefix/@{schema_name}/@{table_name}'), + location2 = @IF( + @gateway = 'dev', + @resolve_template('hdfs://@{catalog_name}/@{schema_name}/dev/@{table_name}'), + @resolve_template('s3://prod/@{table_name}') + ), + sort_order = @IF(@gateway = 'prod', 'desc', 'asc'), + conditional_prop = @IF(@gateway == 'prod', 'PROD_PROP', NULL) + ) + ); + + SELECT 1; + """ + ) + + model = load_sql_based_model( + expressions, variables={"gateway": "dev"}, default_catalog="unit_test" + ) + + assert model.name == "test.test_model" + assert "location1" in model.physical_properties + assert "location2" in model.physical_properties + assert "sort_order" in model.physical_properties + assert "conditional_prop" in model.physical_properties + + # load time is a no-op + assert isinstance(model.physical_properties["location1"], d.MacroFunc) + assert isinstance(model.physical_properties["location2"], d.MacroFunc) + assert isinstance(model.physical_properties["sort_order"], d.MacroFunc) + assert isinstance(model.physical_properties["conditional_prop"], d.MacroFunc) + + # substitution occurs at runtime + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + rendered_physical_properties = model.render_physical_properties( + snapshots={model.fqn: snapshot}, # to trigger @this_model generation + runtime_stage=RuntimeStage.CREATING, + python_env=model.python_env, + ) + + assert ( + rendered_physical_properties["location1"].text("this") + == f"s3://bucket/prefix/sqlmesh__test/test__test_model__{snapshot.version}" + ) + assert ( + rendered_physical_properties["location2"].text("this") + == f"hdfs://unit_test/sqlmesh__test/dev/test__test_model__{snapshot.version}" + ) + assert rendered_physical_properties["sort_order"].text("this") == "asc" + + # the conditional_prop will be disabled for "dev" gateway + assert "conditional_prop" not in 
rendered_physical_properties + + +def test_macros_in_model_statement(sushi_context, assert_exp_eq): + @macro() + def session_properties(evaluator, value): + return exp.Property( + this=exp.var("session_properties"), + value=exp.convert([exp.convert("foo").eq(exp.var(f"bar_{value}"))]), + ) + + expressions = d.parse( + """ + MODEL ( + name @{gateway}__@{gateway}.test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column @{time_column} + + ), + start @IF(@gateway = 'test_gateway', '2023-01-01', '2024-01-02'), + @session_properties(baz) + ); + + SELECT a, b UNION SELECT c, c + """ + ) + + model = load_sql_based_model( + expressions, variables={"gateway": "test_gateway", "time_column": "a"} + ) + assert model.name == "test_gateway__test_gateway.test_model" + assert model.time_column + assert model.time_column.column == exp.column("a", quoted=True) + assert model.start == "2023-01-01" + assert model.session_properties == {"foo": exp.column("bar_baz", quoted=False)} + + +def test_macro_references_in_audits(): + @macro() + def zero_value(evaluator: MacroEvaluator) -> int: + return 0 + + @macro() + def min_value(evaluator: MacroEvaluator) -> int: + return 1 + + @macro() + def not_loaded_macro(evaluator: MacroEvaluator) -> int: + return 10 + + @macro() + def max_value(evaluator: MacroEvaluator) -> int: + return 1000 + + audit_expression = parse( + """ + AUDIT ( + name assert_max_value, + ); + SELECT * + FROM @this_model + WHERE + id > @max_value; + """ + ) + + not_zero_audit = parse( + """ + AUDIT ( + name assert_not_zero, + ); + SELECT * + FROM @this_model + WHERE + id = @zero_value; + """ + ) + + model_expression = d.parse( + """ + MODEL ( + name db.audit_model, + audits (assert_max_value, assert_positive_ids), + ); + SELECT 1 as id; + + AUDIT ( + name assert_positive_ids, + ); + SELECT * + FROM @this_model + WHERE + id < @min_value; + """ + ) + + audits = { + "assert_max_value": load_audit(audit_expression, dialect="duckdb"), + "assert_not_zero": 
load_audit(not_zero_audit, dialect="duckdb"), + } + model_defaults = ModelDefaultsConfig(dialect="duckdb", audits=["assert_not_zero"]) + + model = load_sql_based_model( + model_expression, + defaults=model_defaults.dict(), + audit_definitions=audits, + ) + + assert len(model.audits) == 3 + audits_with_args = model.audits_with_args + assert len(audits_with_args) == 3 + assert len(model.python_env) == 3 + assert model.audits == [ + ("assert_not_zero", {}), + ("assert_max_value", {}), + ("assert_positive_ids", {}), + ] + assert isinstance(audits_with_args[0][0], ModelAudit) + assert isinstance(audits_with_args[1][0], ModelAudit) + assert isinstance(audits_with_args[2][0], ModelAudit) + assert isinstance(model.python_env["min_value"], Executable) + assert isinstance(model.python_env["max_value"], Executable) + assert isinstance(model.python_env["zero_value"], Executable) + assert "not_loaded_macro" not in model.python_env + + +def test_python_model_dialect(): + model._dialect = "snowflake" + + @model( + name="a", + kind=IncrementalByTimeRangeKind(time_column=TimeColumn(column="x", format="YYMMDD")), + columns={}, + ) + def test(context, **kwargs): + return None + + m = model.get_registry()["a"].model( + module_path=Path("."), + path=Path("."), + dialect="snowflake", + ) + + assert m.time_column.column.sql() == '"X"' + assert m.time_column.format == "%y%m%d" + + @model( + name="b", + kind=IncrementalByTimeRangeKind(time_column="y"), + columns={}, + ) + def test(context, **kwargs): + return None + + m = model.get_registry()["b"].model( + module_path=Path("."), + path=Path("."), + dialect="snowflake", + ) + + assert m.time_column.column.sql() == '"Y"' + assert m.time_column.format == "%Y-%m-%d" + + # column type parseable by default dialect: no error + model._dialect = "clickhouse" + + @model("good", columns={'"COL"': "DateTime64(9)"}) + def a_model(context): + pass + + # column type not parseable by default dialect and no explicit dialect: error + model._dialect = 
"snowflake" + + with pytest.raises(ParseError, match="No expression was parsed from 'DateTime64\\(9\\)'"): + + @model("bad", columns={'"COL"': "DateTime64(9)"}) + def a_model(context): + pass + + # column type not parseable by default dialect and explicit dialect specified: no error + @model("good", columns={'"COL"': "DateTime64(9)"}, dialect="clickhouse") + def a_model(context): + pass + + model._dialect = None + + +def test_jinja_runtime_stage(assert_exp_eq): + expressions = d.parse( + """ + MODEL ( + name test.jinja + ); + + JINJA_QUERY_BEGIN; + + SELECT '{{ runtime_stage }}' as a, {{ runtime_stage == 'loading' }} as b + + JINJA_END; + """ + ) + + model = load_sql_based_model(expressions) + assert_exp_eq(model.render_query(), '''SELECT 'loading' as "a", TRUE as "b"''') + + +def test_forward_only_on_destructive_change_config() -> None: + # global default to ALLOW for non-incremental models + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + context = Context(config=config) + + expressions = d.parse( + """ + MODEL ( + name memory.db.table, + kind FULL, + ); + SELECT a, b, c FROM source_table; + """ + ) + model = load_sql_based_model(expressions, defaults=config.model_defaults.dict()) + context.upsert_model(model) + context_model = context.get_model("memory.db.table") + assert context_model.on_destructive_change.is_allow + + # global default to ERROR for incremental models + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + context = Context(config=config) + + expressions = d.parse( + """ + MODEL ( + name memory.db.table, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column c, + forward_only True + ), + ); + SELECT a, b, c FROM source_table; + """ + ) + model = load_sql_based_model(expressions, defaults=config.model_defaults.dict()) + context.upsert_model(model) + context_model = context.get_model("memory.db.table") + assert context_model.on_destructive_change.is_error + + # WARN specified in model definition, overrides 
incremental model sqlmesh default ERROR + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + context = Context(config=config) + + expressions = d.parse( + """ + MODEL ( + name memory.db.table, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column c, + forward_only True, + on_destructive_change warn + ), + ); + SELECT a, b, c FROM source_table; + """ + ) + model = load_sql_based_model(expressions, defaults=config.model_defaults.dict()) + context.upsert_model(model) + context_model = context.get_model("memory.db.table") + assert context_model.on_destructive_change.is_warn + + # WARN specified as model default, overrides incremental model sqlmesh default ERROR + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb", on_destructive_change="warn") + ) + context = Context(config=config) + + expressions = d.parse( + """ + MODEL ( + name memory.db.table, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column c, + forward_only True + ), + ); + SELECT a, b, c FROM source_table; + """ + ) + model = load_sql_based_model(expressions, defaults=config.model_defaults.dict()) + context.upsert_model(model) + context_model = context.get_model("memory.db.table") + assert context_model.on_destructive_change.is_warn + + # WARN specified as model default, does not override non-incremental sqlmesh default ALLOW + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb", on_destructive_change="warn") + ) + context = Context(config=config) + + expressions = d.parse( + """ + MODEL ( + name memory.db.table, + kind FULL, + ); + SELECT a, b, c FROM source_table; + """ + ) + model = load_sql_based_model(expressions, defaults=config.model_defaults.dict()) + context.upsert_model(model) + context_model = context.get_model("memory.db.table") + assert context_model.on_destructive_change.is_allow + + +def test_batch_concurrency_config() -> None: + # No batch_concurrency default for incremental models + config = 
Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + context = Context(config=config) + + expressions = d.parse( + """ + MODEL ( + name memory.db.table, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column c + ), + ); + SELECT a, b, c FROM source_table; + """ + ) + model = load_sql_based_model(expressions, defaults=config.model_defaults.dict()) + context.upsert_model(model) + context_model = context.get_model("memory.db.table") + assert context_model.batch_concurrency is None + + # batch_concurrency specified in model defaults applies to incremental models + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb", batch_concurrency=5)) + context = Context(config=config) + + expressions = d.parse( + """ + MODEL ( + name memory.db.table, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column c + ), + ); + SELECT a, b, c FROM source_table; + """ + ) + model = load_sql_based_model(expressions, defaults=config.model_defaults.dict()) + context.upsert_model(model) + context_model = context.get_model("memory.db.table") + assert context_model.batch_concurrency == 5 + + # batch_concurrency specified in model definition overrides default + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb", batch_concurrency=5)) + context = Context(config=config) + + expressions = d.parse( + """ + MODEL ( + name memory.db.table, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column c, + batch_concurrency 10 + ), + ); + SELECT a, b, c FROM source_table; + """ + ) + model = load_sql_based_model(expressions, defaults=config.model_defaults.dict()) + context.upsert_model(model) + context_model = context.get_model("memory.db.table") + assert context_model.batch_concurrency == 10 + + # batch_concurrency default does not apply to non-incremental models + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb", batch_concurrency=5)) + context = Context(config=config) + + expressions = d.parse( + """ + MODEL ( + name memory.db.table, + kind FULL, + ); + SELECT a, b, 
c FROM source_table; + """ + ) + model = load_sql_based_model(expressions, defaults=config.model_defaults.dict()) + context.upsert_model(model) + context_model = context.get_model("memory.db.table") + assert context_model.batch_concurrency is None + + # batch_concurrency default does not apply to INCREMENTAL_BY_UNIQUE_KEY models + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb", batch_concurrency=5)) + context = Context(config=config) + + expressions = d.parse( + """ + MODEL ( + name memory.db.table, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key a + ), + ); + SELECT a, b, c FROM source_table; + """ + ) + model = load_sql_based_model(expressions, defaults=config.model_defaults.dict()) + context.upsert_model(model) + context_model = context.get_model("memory.db.table") + assert context_model.batch_concurrency == 1 + + +def test_model_meta_on_additive_change_property() -> None: + """Test that ModelMeta has on_additive_change property that works like on_destructive_change.""" + from sqlmesh.core.model.kind import IncrementalByTimeRangeKind, OnAdditiveChange + from sqlmesh.core.model.meta import ModelMeta + + # Test incremental model with on_additive_change=ERROR + incremental_kind = IncrementalByTimeRangeKind( + time_column="c", + forward_only=True, + on_additive_change=OnAdditiveChange.ERROR, + ) + model_meta = ModelMeta(name="test_model", kind=incremental_kind) + assert model_meta.on_additive_change == OnAdditiveChange.ERROR + + # Test incremental model with on_additive_change=WARN + incremental_kind = IncrementalByTimeRangeKind( + time_column="c", + forward_only=True, + on_additive_change=OnAdditiveChange.WARN, + ) + model_meta = ModelMeta(name="test_model", kind=incremental_kind) + assert model_meta.on_additive_change == OnAdditiveChange.WARN + + # Test incremental model with default on_additive_change (should be ALLOW) + incremental_kind = IncrementalByTimeRangeKind( + time_column="c", + forward_only=True, + ) + model_meta = 
ModelMeta(name="test_model", kind=incremental_kind) + assert model_meta.on_additive_change == OnAdditiveChange.ALLOW + + incremental_kind = IncrementalByTimeRangeKind( + time_column="c", + forward_only=True, + on_additive_change=OnAdditiveChange.IGNORE, + ) + model_meta = ModelMeta(name="test_model", kind=incremental_kind) + assert model_meta.on_additive_change == OnAdditiveChange.IGNORE + + +def test_incremental_by_partition(sushi_context, assert_exp_eq): + expressions = d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_PARTITION, + partitioned_by [a], + ); + + SELECT a, b + """ + ) + model = load_sql_based_model(expressions) + assert model.kind.is_incremental_by_partition + assert not model.kind.disable_restatement + + expressions = d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_PARTITION ( + disable_restatement false + ), + partitioned_by [a], + ); + + SELECT a, b + """ + ) + model = load_sql_based_model(expressions) + assert model.kind.is_incremental_by_partition + assert not model.kind.disable_restatement + + with pytest.raises( + ValidationError, + match=r".*partitioned_by field is required for INCREMENTAL_BY_PARTITION models.*", + ): + expressions = d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_PARTITION, + ); + + SELECT a, b + """ + ) + load_sql_based_model(expressions) + + with pytest.raises( + ConfigError, + match=r".*Do not specify the `forward_only` configuration key.*", + ): + expressions = d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_PARTITION ( + forward_only true + ), + ); + + SELECT a, b + """ + ) + load_sql_based_model(expressions) + + +@pytest.mark.parametrize( + ["model_def", "path", "expected_name"], + [ + [ + """dialect duckdb,""", + """models/test_schema/test_model.sql,""", + "test_schema.test_model", + ], + [ + """dialect duckdb,""", + """models/test_model.sql,""", + "test_model", + ], + [ + """dialect duckdb,""", + """models/inventory/db/test_schema/test_model.sql,""", + 
"db.test_schema.test_model", + ], + ["""name test_model,""", """models/schema/test_model.sql,""", "test_model"], + ], +) +def test_model_table_name_inference( + sushi_context: Context, model_def: str, path: str, expected_name: str +): + model = load_sql_based_model( + d.parse( + f""" + MODEL ( + {model_def} + ); + SELECT a FROM tbl; + """, + default_dialect="duckdb", + ), + path=Path(f"$root/{path}"), + infer_names=True, + ) + assert model.name == expected_name + + +@pytest.mark.parametrize( + ["path", "expected_name"], + [ + [ + """models/test_schema/test_model.py""", + "test_schema.test_model", + ], + [ + """models/inventory/db/test_schema/test_model.py""", + "db.test_schema.test_model", + ], + ], +) +def test_python_model_name_inference(tmp_path: Path, path: str, expected_name: str) -> None: + init_example_project(tmp_path, engine_type="duckdb") + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + model_naming=NameInferenceConfig(infer_names=True), + ) + + foo_py_file = tmp_path / path + foo_py_file.parent.mkdir(parents=True, exist_ok=True) + foo_py_file.write_text("""from sqlmesh import model +@model( + columns={'"COL"': "int"}, +) +def my_model(context, **kwargs): + pass""") + context = Context(paths=tmp_path, config=config) + assert context.get_model(expected_name).name == expected_name + assert isinstance(context.get_model(expected_name), PythonModel) + + +def test_python_model_name_inference_multiple_models(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb") + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + model_naming=NameInferenceConfig(infer_names=True), + ) + + path_a = tmp_path / "models/test_schema/test_model_a.py" + path_b = tmp_path / "models/test_schema/test_model_b.py" + + model_payload = """from sqlmesh import model +@model( + columns={'"COL"': "int"}, +) +def my_model(context, **kwargs): + pass""" + + path_a.parent.mkdir(parents=True, exist_ok=True) + 
path_a.write_text(model_payload) + path_b.write_text(model_payload) + + context = Context(paths=tmp_path, config=config) + assert context.get_model("test_schema.test_model_a").name == "test_schema.test_model_a" + assert context.get_model("test_schema.test_model_b").name == "test_schema.test_model_b" + + +def test_custom_kind(): + expressions = d.parse( + """ + MODEL ( + name db.table, + kind CUSTOM ( + materialization 'MyTestStrategy', + forward_only true, + disable_restatement true, + materialization_properties ( + 'key_a' = 'value_a', + key_b = 2, + 'key_c' = true, + 'key_d' = 1.23, + ), + batch_size 1, + batch_concurrency 2, + lookback 3, + ) + ); + + SELECT a, b + """ + ) + + with pytest.raises( + ConfigError, match=r"Materialization strategy with name 'MyTestStrategy' was not found.*" + ): + model = load_sql_based_model(expressions) + model.validate_definition() + + class MyTestStrategy(CustomMaterialization): + pass + + model = load_sql_based_model(expressions) + assert model.kind.is_custom + + kind = t.cast(CustomKind, model.kind) + assert kind.disable_restatement + assert kind.forward_only + assert kind.materialization == "MyTestStrategy" + assert kind.materialization_properties == { + "key_a": "value_a", + "key_b": 2, + "key_c": True, + "key_d": 1.23, + } + assert kind.batch_size == 1 + assert kind.batch_concurrency == 2 + assert kind.lookback == 3 + + assert ( + kind.to_expression().sql() + == """CUSTOM ( +materialization 'MyTestStrategy', +materialization_properties ('key_a' = 'value_a', key_b = 2, 'key_c' = TRUE, 'key_d' = 1.23), +forward_only TRUE, +disable_restatement TRUE, +batch_size 1, +batch_concurrency 2, +lookback 3 +)""" + ) + + +def test_custom_kind_lookback_property(): + """Test that CustomKind's lookback property is correctly accessed via ModelMeta.lookback. + + This test verifies the fix for issue #5268 where CustomKind models were not respecting + the lookback parameter because the isinstance check for _IncrementalBy failed. 
+ """ + + # Test 1: CustomKind with lookback = 3 + class MyTestStrategy(CustomMaterialization): + pass + + expressions = d.parse( + """ + MODEL ( + name db.custom_table, + kind CUSTOM ( + materialization 'MyTestStrategy', + lookback 3 + ) + ); + SELECT a, b FROM upstream + """ + ) + + model = load_sql_based_model(expressions) + assert model.kind.is_custom + + # Verify that the kind itself has lookback = 3 + kind = t.cast(CustomKind, model.kind) + assert kind.lookback == 3 + + # The bug: model.lookback should return 3, but with the old implementation + # using isinstance(self.kind, _IncrementalBy), it would return 0 + assert model.lookback == 3, "CustomKind lookback not accessible via model.lookback property" + + # Test 2: CustomKind without lookback (should default to 0) + expressions_no_lookback = d.parse( + """ + MODEL ( + name db.custom_table_no_lookback, + kind CUSTOM ( + materialization 'MyTestStrategy' + ) + ); + SELECT a, b FROM upstream + """ + ) + + model_no_lookback = load_sql_based_model(expressions_no_lookback) + assert model_no_lookback.lookback == 0 + + # Test 3: Ensure IncrementalByTimeRangeKind still works correctly + incremental_expressions = d.parse( + """ + MODEL ( + name db.incremental_table, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + lookback 5 + ) + ); + SELECT ds, a, b FROM upstream + """ + ) + + incremental_model = load_sql_based_model(incremental_expressions) + assert incremental_model.lookback == 5 + + +def test_time_column_format_in_custom_kind(): + class TimeColumnCustomKind(CustomKind): # type: ignore[no-untyped-def] + _time_column: TimeColumn + + @model_validator(mode="after") + def _validate(self): + self._time_column = TimeColumn.create( + self.materialization_properties.get("time_column"), self.dialect + ) + + @property + def time_column(self): + return self._time_column + + class TimeColumnMaterialization(CustomMaterialization[TimeColumnCustomKind]): + NAME = "time_column_custom_strategy" + + expressions = d.parse( + 
""" + MODEL ( + name db.table, + kind CUSTOM ( + materialization 'time_column_custom_strategy', + materialization_properties ( + time_column = ts + ), + ), + dialect duckdb + ); + + SELECT a, b, '2020-01-01' as ts + """ + ) + + model = load_sql_based_model(expressions, time_column_format="%d-%m-%Y") + assert isinstance(model.kind, TimeColumnCustomKind) + assert model.kind.time_column.column == exp.to_column("ts", quoted=True) + assert model.kind.time_column.format == "%d-%m-%Y" + assert model.kind.dialect == "duckdb" + assert "dialect" not in json.loads( + model.kind.json() + ) # dialect should not be serialized against the kind + + # explicit time_column format within the model + expressions = d.parse( + """ + MODEL ( + name db.table, + kind CUSTOM ( + materialization 'time_column_custom_strategy', + materialization_properties ( + time_column = (ts, '%Y-%m-%d') + ), + ) + ); + + SELECT a, b, '2020-01-01' as ts + """ + ) + + model = load_sql_based_model(expressions, time_column_format="%d-%m-%Y") + assert model.time_column.format == "%Y-%m-%d" + + +def test_model_kind_to_expression(): + assert ( + load_sql_based_model( + d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_TIME_RANGE( + time_column a, + ), + ); + SELECT a, b + """ + ) + ) + .kind.to_expression() + .sql() + == """INCREMENTAL_BY_TIME_RANGE ( +time_column ("a", '%Y-%m-%d'), +partition_by_time_column TRUE, +forward_only FALSE, +disable_restatement FALSE, +on_destructive_change 'ERROR', +on_additive_change 'ALLOW' +)""" + ) + + assert ( + load_sql_based_model( + d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_TIME_RANGE( + time_column a, + batch_size 1, + batch_concurrency 2, + lookback 3, + forward_only TRUE, + disable_restatement TRUE, + on_destructive_change WARN, + ), + ); + SELECT a, b + """ + ) + ) + .kind.to_expression() + .sql() + == """INCREMENTAL_BY_TIME_RANGE ( +time_column ("a", '%Y-%m-%d'), +partition_by_time_column TRUE, +batch_size 1, +batch_concurrency 2, 
+lookback 3, +forward_only TRUE, +disable_restatement TRUE, +on_destructive_change 'WARN', +on_additive_change 'ALLOW' +)""" + ) + + assert ( + load_sql_based_model( + d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_UNIQUE_KEY( + unique_key a, + ), + ); + SELECT a, b + """ + ) + ) + .kind.to_expression() + .sql() + == """INCREMENTAL_BY_UNIQUE_KEY ( +unique_key ("a"), +batch_concurrency 1, +forward_only FALSE, +disable_restatement FALSE, +on_destructive_change 'ERROR', +on_additive_change 'ALLOW' +)""" + ) + + assert ( + load_sql_based_model( + d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_UNIQUE_KEY( + unique_key a, + when_matched WHEN MATCHED THEN UPDATE SET target.b = COALESCE(source.b, target.b) + ), + ); + SELECT a, b + """ + ) + ) + .kind.to_expression() + .sql() + == """INCREMENTAL_BY_UNIQUE_KEY ( +unique_key ("a"), +when_matched (WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."b" = COALESCE("__MERGE_SOURCE__"."b", "__MERGE_TARGET__"."b")), +batch_concurrency 1, +forward_only FALSE, +disable_restatement FALSE, +on_destructive_change 'ERROR', +on_additive_change 'ALLOW' +)""" + ) + + assert ( + load_sql_based_model( + d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_UNIQUE_KEY( + unique_key a, + when_matched WHEN MATCHED AND source.x = 1 THEN UPDATE SET target.b = COALESCE(source.b, target.b), + WHEN MATCHED THEN UPDATE SET target.b = COALESCE(source.b, target.b) + ), + ); + SELECT a, b + """ + ) + ) + .kind.to_expression() + .sql() + == """INCREMENTAL_BY_UNIQUE_KEY ( +unique_key ("a"), +when_matched (WHEN MATCHED AND "__MERGE_SOURCE__"."x" = 1 THEN UPDATE SET "__MERGE_TARGET__"."b" = COALESCE("__MERGE_SOURCE__"."b", "__MERGE_TARGET__"."b") WHEN MATCHED THEN UPDATE SET "__MERGE_TARGET__"."b" = COALESCE("__MERGE_SOURCE__"."b", "__MERGE_TARGET__"."b")), +batch_concurrency 1, +forward_only FALSE, +disable_restatement FALSE, +on_destructive_change 'ERROR', +on_additive_change 'ALLOW' +)""" + ) + + assert ( + 
load_sql_based_model( + d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_PARTITION, + partitioned_by ["a"], + ); + SELECT a, b + """ + ) + ) + .kind.to_expression() + .sql() + == """INCREMENTAL_BY_PARTITION ( +forward_only TRUE, +disable_restatement FALSE, +on_destructive_change 'ERROR', +on_additive_change 'ALLOW' +)""" + ) + + assert ( + load_sql_based_model( + d.parse( + """ + MODEL ( + name db.seed, + kind SEED ( + path '../seeds/waiter_names.csv', + ) + ); + """ + ), + path=Path("./examples/sushi/models/test_model.sql"), + ) + .kind.to_expression() + .sql() + == """SEED ( +path '../seeds/waiter_names.csv', +batch_size 1000 +)""" + ) + + assert ( + load_sql_based_model( + d.parse( + """ + MODEL ( + name db.table, + kind SCD_TYPE_2_BY_TIME ( + unique_key [a, b] + ) + ); + SELECT a, b + """ + ) + ) + .kind.to_expression() + .sql() + == """SCD_TYPE_2_BY_TIME ( +updated_at_name "updated_at", +updated_at_as_valid_from FALSE, +unique_key ("a", "b"), +valid_from_name "valid_from", +valid_to_name "valid_to", +invalidate_hard_deletes FALSE, +time_data_type TIMESTAMP, +forward_only TRUE, +disable_restatement TRUE, +on_destructive_change 'ERROR', +on_additive_change 'ALLOW' +)""" + ) + + assert ( + load_sql_based_model( + d.parse( + """ + MODEL ( + name db.table, + kind SCD_TYPE_2_BY_COLUMN ( + unique_key [a, b], + columns [b] + ) + ); + SELECT a, b, c + """ + ) + ) + .kind.to_expression() + .sql() + == """SCD_TYPE_2_BY_COLUMN ( +columns ("b"), +execution_time_as_valid_from FALSE, +unique_key ("a", "b"), +valid_from_name "valid_from", +valid_to_name "valid_to", +invalidate_hard_deletes FALSE, +time_data_type TIMESTAMP, +forward_only TRUE, +disable_restatement TRUE, +on_destructive_change 'ERROR', +on_additive_change 'ALLOW' +)""" + ) + + assert ( + load_sql_based_model( + d.parse( + """ + MODEL ( + name db.table, + kind SCD_TYPE_2_BY_COLUMN ( + unique_key [a, b], + columns * + ) + ); + SELECT a, b, c + """ + ) + ) + .kind.to_expression() + .sql() + == 
"""SCD_TYPE_2_BY_COLUMN ( +columns (*), +execution_time_as_valid_from FALSE, +unique_key ("a", "b"), +valid_from_name "valid_from", +valid_to_name "valid_to", +invalidate_hard_deletes FALSE, +time_data_type TIMESTAMP, +forward_only TRUE, +disable_restatement TRUE, +on_destructive_change 'ERROR', +on_additive_change 'ALLOW' +)""" + ) + + assert ( + load_sql_based_model( + d.parse( + """ + MODEL ( + name db.table, + kind FULL + ); + SELECT a, b, c + """ + ) + ) + .kind.to_expression() + .sql() + == "FULL" + ) + + assert ( + load_sql_based_model( + d.parse( + """ + MODEL ( + name db.table, + kind VIEW + ); + SELECT a, b, c + """ + ) + ) + .kind.to_expression() + .sql() + == """VIEW ( +materialized FALSE +)""" + ) + + assert ( + load_sql_based_model( + d.parse( + """ + MODEL ( + name db.table, + kind VIEW (materialized true) + ); + SELECT a, b, c + """ + ) + ) + .kind.to_expression() + .sql() + == """VIEW ( +materialized TRUE +)""" + ) + + +def test_bad_model_kind(): + with pytest.raises( + SQLMeshError, + match=f"Model kind specified as 'BAD_KIND', but that is not a valid model kind.\n\nPlease specify one of {', '.join(ModelKindName)}.", + ): + d.parse( + """ + MODEL ( + name db.table, + kind BAD_KIND + ); + SELECT a, b + """ + ) + + +def test_merge_filter(): + expressions = d.parse( + """ + MODEL ( + name db.employees, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key name, + merge_filter source.salary > 0 + ) + ); + SELECT 'name' AS name, 1 AS salary; + """ + ) + + expected_incremental_predicate = f"`{MERGE_SOURCE_ALIAS}`.`salary` > 0" + + model = load_sql_based_model(expressions, dialect="hive") + assert model.kind.merge_filter.sql(dialect="hive") == expected_incremental_predicate + + model = SqlModel.parse_raw(model.json()) + assert model.kind.merge_filter.sql(dialect="hive") == expected_incremental_predicate + assert model.dialect == "hive" + + expressions = d.parse( + """ + MODEL ( + name db.test, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key purchase_order_id, 
+ when_matched ( + WHEN MATCHED AND source._operation = 1 THEN DELETE + WHEN MATCHED AND source._operation <> 1 THEN UPDATE SET target.purchase_order_id = 1 + ), + merge_filter ( + source.ds > (SELECT MAX(ds) FROM db.test) AND + source.ds > @start_ds AND + source._operation <> 1 AND + target.start_date > date_add(current_date, interval 7 day) + ) + ) + ); + + SELECT + purchase_order_id, + start_date + FROM db.upstream + """ + ) + + model = SqlModel.parse_raw(load_sql_based_model(expressions, dialect="duckdb").json()) + assert d.format_model_expressions(model.render_definition(), dialect=model.dialect) == ( + f"""MODEL ( + name db.test, + dialect duckdb, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key ("purchase_order_id"), + when_matched ( + WHEN MATCHED AND "{MERGE_SOURCE_ALIAS}"."_operation" = 1 THEN DELETE + WHEN MATCHED AND "{MERGE_SOURCE_ALIAS}"."_operation" <> 1 THEN UPDATE SET + "{MERGE_TARGET_ALIAS}"."purchase_order_id" = 1 + ), + merge_filter ( + "{MERGE_SOURCE_ALIAS}"."ds" > ( + SELECT + MAX("ds") + FROM "db"."test" + ) + AND "{MERGE_SOURCE_ALIAS}"."ds" > @start_ds + AND "{MERGE_SOURCE_ALIAS}"."_operation" <> 1 + AND "{MERGE_TARGET_ALIAS}"."start_date" > CURRENT_DATE + INTERVAL '7' DAY + ), + batch_concurrency 1, + forward_only FALSE, + disable_restatement FALSE, + on_destructive_change 'ERROR', + on_additive_change 'ALLOW' + ) +); + +SELECT + purchase_order_id, + start_date +FROM db.upstream""" + ) + + rendered_merge_filters = model.render_merge_filter(start="2023-01-01", end="2023-01-02") + assert ( + rendered_merge_filters.sql(dialect="hive") + == "(`__MERGE_SOURCE__`.`ds` > (SELECT MAX(`ds`) FROM `db`.`test`) AND `__MERGE_SOURCE__`.`ds` > '2023-01-01' AND `__MERGE_SOURCE__`.`_operation` <> 1 AND `__MERGE_TARGET__`.`start_date` > CURRENT_DATE + INTERVAL '7' DAY)" + ) + + +def test_merge_filter_normalization(): + # unquoted gets normalized and quoted + expressions = d.parse( + """ + MODEL ( + name db.employees, + kind INCREMENTAL_BY_UNIQUE_KEY ( + 
unique_key name, + merge_filter source.salary > 0 + ) + ); + SELECT 'name' AS name, 1 AS salary; + """ + ) + + model = load_sql_based_model(expressions, dialect="snowflake") + assert model.merge_filter.sql(dialect="snowflake") == '"__MERGE_SOURCE__"."SALARY" > 0' + + # quoted gets preserved + expressions = d.parse( + """ + MODEL ( + name db.employees, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key name, + merge_filter source."SaLArY" > 0 + ) + ); + SELECT 'name' AS name, 1 AS "SaLArY"; + """ + ) + + model = load_sql_based_model(expressions, dialect="snowflake") + assert model.merge_filter.sql(dialect="snowflake") == '"__MERGE_SOURCE__"."SaLArY" > 0' + + +def test_merge_filter_macro(): + @macro() + def predicate( + evaluator: MacroEvaluator, + cluster_column: exp.Column, + ) -> exp.Expression: + return parse_one(f"source.{cluster_column} > dateadd(day, -7, target.{cluster_column})") + + expressions = d.parse( + """ + MODEL ( + name db.incremental_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key id, + merge_filter @predicate(update_datetime) and target.update_datetime > @start_dt + ), + clustered_by update_datetime + ); + SELECT id, update_datetime FROM db.test_model; + """ + ) + + unrendered_merge_filter = f"""@predicate("UPDATE_DATETIME") AND "{MERGE_TARGET_ALIAS}"."UPDATE_DATETIME" > @start_dt""" + expected_merge_filter = ( + f"""\"{MERGE_SOURCE_ALIAS}"."UPDATE_DATETIME" > DATEADD(DAY, -7, "{MERGE_TARGET_ALIAS}"."UPDATE_DATETIME") """ + f"""AND "{MERGE_TARGET_ALIAS}"."UPDATE_DATETIME" > CAST('2023-01-01 15:00:00+00:00' AS TIMESTAMPTZ)""" + ) + + model = load_sql_based_model(expressions, dialect="snowflake") + assert model.kind.merge_filter.sql(dialect=model.dialect) == unrendered_merge_filter + + model = SqlModel.parse_raw(model.json()) + assert model.kind.merge_filter.sql(dialect=model.dialect) == unrendered_merge_filter + + rendered_merge_filters = model.render_merge_filter(start="2023-01-01 15:00:00") + assert 
rendered_merge_filters.sql(dialect=model.dialect) == expected_merge_filter + + +@pytest.mark.parametrize( + "metadata_only", + [True, False], +) +def test_macro_func_hash(mocker: MockerFixture, metadata_only: bool): + mocker.patch("sqlmesh.core.macros.macro._registry", {}) + + @macro(metadata_only=metadata_only) + def noop(evaluator) -> None: + return None + + expressions = d.parse( + """ + MODEL ( + name db.model, + ); + + SELECT 1; + """ + ) + model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) + + expressions = d.parse( + """ + MODEL ( + name db.model, + ); + + SELECT 1; + + @noop(); + """ + ) + new_model = load_sql_based_model( + expressions, path=Path("./examples/sushi/models/test_model.sql") + ) + assert model.metadata_hash != new_model.metadata_hash + assert model.data_hash != new_model.data_hash + assert new_model.is_metadata_only_change(model) == metadata_only + + @macro(metadata_only=metadata_only) # type: ignore + def noop(evaluator) -> None: + print("noop") + return None + + updated_model = load_sql_based_model( + expressions, path=Path("./examples/sushi/models/test_model.sql") + ) + if metadata_only: + assert "print" not in new_model._additional_metadata[0] + assert "print" in updated_model._additional_metadata[0] + assert new_model.data_hash == updated_model.data_hash + assert new_model.metadata_hash != updated_model.metadata_hash + else: + assert "print" not in new_model._data_hash_values[0] + assert "print" in updated_model._data_hash_values[0] + assert new_model.data_hash != updated_model.data_hash + assert new_model.metadata_hash == updated_model.metadata_hash + assert updated_model.is_metadata_only_change(new_model) == metadata_only + + +def test_managed_kind_sql(): + expressions = d.parse( + """ + MODEL ( + name db.table, + kind MANAGED, + physical_properties ( + warehouse = small, + target_lag = '20 minutes', + refresh_mode = auto + ) + ); + + SELECT a, b + """ + ) + + model = 
load_sql_based_model(expressions) + + assert model.kind.is_managed + + with pytest.raises(ConfigError, match=r".*must specify the 'target_lag' physical property.*"): + load_sql_based_model( + d.parse( + """ + MODEL ( + name db.table, + kind MANAGED, + dialect snowflake + ); + + SELECT a, b + """ + ) + ).validate_definition() + + +def test_managed_kind_python(): + @model("test_managed_python_model", kind="managed", columns={"a": "int"}) + def execute( + context: ExecutionContext, + start: datetime, + end: datetime, + execution_time: datetime, + **kwargs: t.Any, + ) -> pd.DataFrame: + return pd.DataFrame.from_dict(data={"a": 1}, orient="index") + + with pytest.raises( + SQLMeshError, + match=r".*Cannot create Python model.*the 'MANAGED' kind doesn't support Python models", + ): + model.get_registry()["test_managed_python_model"].model( + module_path=Path("."), + path=Path("."), + ).validate_definition() + + +def test_physical_version(): + expressions = d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_TIME_RANGE( + time_column a, + forward_only TRUE, + ), + physical_version '1234', + ); + + SELECT a, b + """ + ) + + model = load_sql_based_model(expressions) + assert model.physical_version == "1234" + + with pytest.raises( + ConfigError, + match=r"Pinning a physical version is only supported for forward only models( at.*)?", + ): + load_sql_based_model( + d.parse( + """ + MODEL ( + name db.table, + kind INCREMENTAL_BY_TIME_RANGE( + time_column a, + ), + physical_version '1234', + ); + + SELECT a, b + """ + ) + ).validate_definition() + + +def test_trailing_comments(): + expressions = d.parse( + """ + MODEL (name db.table); + + /* some comment A */ + + SELECT 1; + /* some comment B */ + """ + ) + model = load_sql_based_model(expressions) + assert not model.render_pre_statements() + assert not model.render_post_statements() + + expressions = d.parse( + """ + MODEL ( + name db.seed, + kind SEED ( + path '../seeds/waiter_names.csv', + batch_size 100, + ) + 
); + + /* some comment A */ + """ + ) + model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) + assert not model.render_pre_statements() + assert not model.render_post_statements() + + +def test_comments_in_jinja_query(): + expressions = d.parse( + """ + MODEL (name db.table); + + JINJA_QUERY_BEGIN; + /* some comment A */ + + SELECT 1; + /* some comment B */ + + JINJA_END; + """ + ) + model = load_sql_based_model(expressions) + assert model.render_query().sql() == '/* some comment A */ SELECT 1 AS "1"' + + expressions = d.parse( + """ + MODEL (name db.table); + + JINJA_QUERY_BEGIN; + /* some comment A */ + + SELECT 1; + SELECT 2; + /* some comment B */ + + JINJA_END; + """ + ) + model = load_sql_based_model(expressions) + with pytest.raises(ConfigError, match=r"Too many statements in query.*"): + model.render_query() + + +def test_jinja_render_parse_error(): + expressions = d.parse( + """ + MODEL (name db.test_model); + + JINJA_QUERY_BEGIN; + {{ unknown_macro() }} + JINJA_END; + """ + ) + + model = load_sql_based_model(expressions) + + with pytest.raises(ConfigError, match=r"Could not render jinja"): + model.render_query() + + +def test_jinja_render_debug_logging(caplog): + """Test that rendered Jinja expressions are logged for debugging.""" + import logging + + # Set log level to DEBUG to capture debug logs + caplog.set_level(logging.DEBUG, logger="sqlmesh.core.renderer") + + # Create a model with unparseable Jinja that will be rendered + expressions = d.parse( + """ + MODEL (name db.test_model); + + JINJA_QUERY_BEGIN; + {{ 'SELECT invalid syntax here!' 
}} + JINJA_END; + """ + ) + + model = load_sql_based_model(expressions) + + # Attempt to render - this should fail due to invalid SQL syntax + with pytest.raises(ConfigError, match=r"Could not parse the rendered jinja"): + model.render_query() + + # Check that the rendered Jinja was logged + assert any( + 'Rendered Jinja expression for model \'"db"."test_model"\'' in record.message + and "SELECT invalid syntax here!" in record.message + for record in caplog.records + ) + + +def test_staged_file_path(): + expressions = d.parse( + """ + MODEL (name test, dialect snowflake); + + SELECT * FROM @a.b/c/d.csv(FILE_FORMAT => 'b.ff') + """ + ) + model = load_sql_based_model(expressions) + query = model.render_query() + assert query.sql(dialect="snowflake") == "SELECT * FROM @a.b/c/d.csv (FILE_FORMAT => 'b.ff')" + + expressions = d.parse( + """ + MODEL (name test, dialect snowflake); + + SELECT + * + FROM @variable (FILE_FORMAT => 'foo'), @non_variable (FILE_FORMAT => 'bar') + LIMIT 100 + """ + ) + model = load_sql_based_model(expressions, variables={"variable": "some_path"}) + query = model.render_query() + assert ( + query.sql(dialect="snowflake") + == """SELECT * FROM 'some_path' (FILE_FORMAT => 'foo') AS "SOME_PATH", @non_variable (FILE_FORMAT => 'bar') LIMIT 100""" + ) + + +def test_cache(): + expressions = d.parse( + """ + MODEL (name test); + + SELECT 1 x + FROM y + + """ + ) + model = load_sql_based_model(expressions) + assert model.depends_on == {'"y"'} + assert model.copy(update={"depends_on_": {'"z"'}}).depends_on == {'"z"', '"y"'} + + +def test_snowflake_macro_func_as_table(tmp_path: Path): + init_example_project(tmp_path, engine_type="duckdb") + + custom_macro_file = tmp_path / "macros/custom_macros.py" + custom_macro_file.parent.mkdir(parents=True, exist_ok=True) + custom_macro_file.write_text(""" +from sqlmesh import macro + +@macro() +def custom_macro(evaluator, arg1, arg2): + return "SELECT 1 AS c" + +@macro() +def get_model_name(evaluator): + return 
f"sqlmesh_example.{evaluator._path.stem}" + """) + + new_snowflake_model_file = tmp_path / "models/new_model.sql" + new_snowflake_model_file.parent.mkdir(parents=True, exist_ok=True) + new_snowflake_model_file.write_text(""" +MODEL ( + name @get_model_name(), + dialect snowflake, +); + +SELECT * FROM (@custom_macro(@a, @b)) AS q + """) + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables={"a": "a", "b": "b"}, + ) + context = Context(paths=tmp_path, config=config) + + query = context.get_model("sqlmesh_example.new_model").render_query() + + assert ( + t.cast(exp.Query, query).sql("snowflake") + == 'SELECT "Q"."C" AS "C" FROM (SELECT 1 AS "C") AS "Q"' + ) + + context.plan(no_prompts=True, auto_apply=True) + + +def test_resolve_table(make_snapshot: t.Callable): + @macro() + def resolve_parent(evaluator, name): + return evaluator.resolve_table(name.name) + + for post_statement in ( + "JINJA_STATEMENT_BEGIN; {{ resolve_table('parent') }}; JINJA_END;", + "@resolve_parent('parent')", + ): + expressions = d.parse( + f""" + MODEL (name child); + + SELECT c FROM parent; + + {post_statement} + """ + ) + child = load_sql_based_model(expressions) + parent = load_sql_based_model(d.parse("MODEL (name parent); SELECT 1 AS c")) + + parent_snapshot = make_snapshot(parent) + parent_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + version = parent_snapshot.version + + post_statements = child.render_post_statements(snapshots={'"parent"': parent_snapshot}) + + assert len(post_statements) == 1 + assert post_statements[0].sql() == f'"sqlmesh__default"."parent__{version}"' + + # test with additional nesting level and default catalog + for post_statement in ( + "JINJA_STATEMENT_BEGIN; {{ resolve_table('schema.parent') }}; JINJA_END;", + "@resolve_parent('schema.parent')", + ): + expressions = d.parse( + f""" + MODEL (name schema.child); + + SELECT c FROM schema.parent; + + {post_statement} + """ + ) + child = load_sql_based_model(expressions, 
default_catalog="main") + parent = load_sql_based_model( + d.parse("MODEL (name schema.parent); SELECT 1 AS c"), default_catalog="main" + ) + + parent_snapshot = make_snapshot(parent) + parent_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + version = parent_snapshot.version + + post_statements = child.render_post_statements( + snapshots={'"main"."schema"."parent"': parent_snapshot} + ) + + assert len(post_statements) == 1 + assert post_statements[0].sql() == f'"main"."sqlmesh__schema"."schema__parent__{version}"' + + +def test_cluster_with_complex_expression(): + expressions = d.parse( + """ + MODEL ( + name test, + dialect snowflake, + kind full, + clustered_by (to_date(cluster_col)) + ); + + SELECT + 1 AS c, + CAST('2020-01-01 12:05:03' AS TIMESTAMPTZ) AS cluster_col + """ + ) + + model = load_sql_based_model(expressions) + assert [expr.sql("snowflake") for expr in model.clustered_by] == ['(TO_DATE("CLUSTER_COL"))'] + + +def test_parametric_model_kind(): + parsed_definition = d.parse( + """ + MODEL ( + name db.test_schema.test_model, + kind @IF(@gateway = 'main', VIEW, FULL) + ); + + SELECT + 1 AS c + """ + ) + + model = load_sql_based_model(parsed_definition, variables={c.GATEWAY: "main"}) + assert isinstance(model.kind, ViewKind) + + model = load_sql_based_model(parsed_definition, variables={c.GATEWAY: "other"}) + assert isinstance(model.kind, FullKind) + + +def test_fingerprint_signals(): + @signal() + def test_signal_hash(batch): + return True + + expressions = d.parse( + """ + MODEL ( name db.table, - kind SCD_TYPE_2_BY_COLUMN ( - unique_key "ID", - columns ["value_to_track"] - ), + signals [ + test_signal_hash(arg = 1), + ], + ); + SELECT 1; + """ + ) + + model = load_sql_based_model(expressions, signal_definitions=signal.get_registry()) + metadata_hash = model.metadata_hash + data_hash = model.data_hash + + def assert_metadata_only(): + model._metadata_hash = None + model._data_hash = None + assert model.metadata_hash != metadata_hash + assert 
model.data_hash == data_hash + + executable = model.python_env["test_signal_hash"] + model.python_env["test_signal_hash"].payload = executable.payload.replace("True", "False") + assert_metadata_only() + + model = load_sql_based_model(expressions, signal_definitions=signal.get_registry()) + model.signals.clear() + assert_metadata_only() + + +def test_model_optimize(tmp_path: Path, assert_exp_eq): + defaults = [ + ModelDefaultsConfig(optimize_query=True).dict(), + ModelDefaultsConfig(optimize_query=False).dict(), + ] + non_optimized_sql = 'SELECT 1 + 2 AS "new_col"' + optimized_sql = 'SELECT 3 AS "new_col"' + + # Model flag is False, overriding defaults + disabled_opt = d.parse( + """ + MODEL ( + name test, + optimize_query False, + ); + + SELECT 1 + 2 AS new_col + """ + ) + + for default in defaults: + model = load_sql_based_model(disabled_opt, defaults=default) + assert_exp_eq(model.render_query(), non_optimized_sql) + + # Model flag is True, overriding defaults + enabled_opt = d.parse( + """ + MODEL ( + name test, + optimize_query True, + ); + + SELECT 1 + 2 AS new_col + """ + ) + + for default in defaults: + model = load_sql_based_model(enabled_opt, defaults=default) + assert_exp_eq(model.render_query(), optimized_sql) + + # Model flag is not defined, behavior is set according to the defaults + none_opt = d.parse( + """ + MODEL ( + name test, + ); + + SELECT 1 + 2 AS new_col + """ + ) + + assert_exp_eq(load_sql_based_model(none_opt).render_query(), optimized_sql) + assert_exp_eq( + load_sql_based_model(none_opt, defaults=defaults[0]).render_query(), optimized_sql + ) + assert_exp_eq( + load_sql_based_model(none_opt, defaults=defaults[1]).render_query(), non_optimized_sql + ) + + # Ensure that plan works as expected (optimize_query flag affects the model's data hash) + for parsed_model in [enabled_opt, disabled_opt, none_opt]: + context = Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))) + 
context.upsert_model(load_sql_based_model(parsed_model)) + context.plan(auto_apply=True, no_prompts=True) + + # Ensure non-SQLModels raise if optimize_query is not None + seed_path = tmp_path / "seed.csv" + model_kind = SeedKind(path=str(seed_path.absolute())) + with open(seed_path, "w", encoding="utf-8") as fd: + fd.write( + """ +col_a,col_b,col_c +1,text_a,1.0 +2,text_b,2.0""" + ) + model = create_seed_model("test_db.test_seed_model", model_kind, optimize_query=True) + context = Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))) + + with pytest.raises( + ConfigError, + match=r"SQLMesh query optimizer can only be enabled for SQL models", + ): + context.upsert_model(model) + context.plan(auto_apply=True, no_prompts=True) + + model = create_seed_model("test_db.test_seed_model", model_kind, optimize_query=False) + context.upsert_model(model) + context.plan(auto_apply=True, no_prompts=True) + + +def test_column_description_metadata_change(): + context = Context(config=Config()) + + model = load_sql_based_model( + d.parse( + """ + MODEL ( + name db.test_model, + kind full ); + SELECT - 1 as "ID", - 2 as "value_to_track", - '2020-01-01' as ds, - ; + 1 AS id /* description */ """ + ), + default_catalog=context.default_catalog, ) - scd_type_2_model = load_sql_based_model(model_def) - assert scd_type_2_model.unique_key == [exp.to_column("ID", quoted=True)] - assert scd_type_2_model.kind.columns == [exp.to_column("value_to_track", quoted=True)] - assert scd_type_2_model.columns_to_types == { - "ID": exp.DataType.build("int"), - "value_to_track": exp.DataType.build("int"), - "ds": exp.DataType.build("varchar"), - "valid_from": exp.DataType.build("TIMESTAMP"), - "valid_to": exp.DataType.build("TIMESTAMP"), - } - assert scd_type_2_model.managed_columns == { - "valid_from": exp.DataType.build("TIMESTAMP"), - "valid_to": exp.DataType.build("TIMESTAMP"), - } - assert scd_type_2_model.kind.valid_from_name == exp.column("valid_from", quoted=True) - assert 
scd_type_2_model.kind.valid_to_name == exp.column("valid_to", quoted=True) - assert not scd_type_2_model.kind.execution_time_as_valid_from - assert scd_type_2_model.kind.is_scd_type_2_by_column - assert scd_type_2_model.kind.is_scd_type_2 - assert scd_type_2_model.kind.is_materialized - assert scd_type_2_model.kind.forward_only - assert scd_type_2_model.kind.disable_restatement + context.upsert_model(model) + context.plan(no_prompts=True, auto_apply=True) -def test_scd_type_2_by_column_overrides(): - model_def = d.parse( + context.upsert_model( + "db.test_model", query_=ParsableSql(sql="SELECT 1 AS id /* description 2 */") + ) + plan = context.plan(no_prompts=True, auto_apply=True) + + snapshots = list(plan.snapshots.values()) + assert len(snapshots) == 1 + + snapshot = snapshots[0] + assert len(snapshot.previous_versions) == 1 + assert snapshot.change_category == SnapshotChangeCategory.METADATA + + +def test_auto_restatement(): + parsed_definition = d.parse( """ MODEL ( - name db.table, - kind SCD_TYPE_2_BY_COLUMN ( - unique_key ["iD", ds], - columns "value_to_track", - valid_from_name test_valid_from, - valid_to_name test_valid_to, - execution_time_as_valid_from True, - time_data_type TIMESTAMPTZ, - forward_only False, - disable_restatement False, - invalidate_hard_deletes False, - ), + name test_schema.test_model, + kind INCREMENTAL_BY_TIME_RANGE( + time_column a, + auto_restatement_cron '@daily', + ) ); - SELECT - 1 as "ID", - 2 as "value_to_track", - '2020-01-01' as ds, - ; + SELECT 1 AS c """ ) - scd_type_2_model = load_sql_based_model(model_def) - assert scd_type_2_model.unique_key == [ - exp.column("iD", quoted=True), - exp.column("ds", quoted=True), - ] - assert scd_type_2_model.managed_columns == { - "test_valid_from": exp.DataType.build("TIMESTAMPTZ"), - "test_valid_to": exp.DataType.build("TIMESTAMPTZ"), - } - assert scd_type_2_model.kind.valid_from_name == exp.column("test_valid_from", quoted=True) - assert scd_type_2_model.kind.valid_to_name == 
exp.column("test_valid_to", quoted=True) - assert scd_type_2_model.kind.execution_time_as_valid_from - assert scd_type_2_model.kind.is_scd_type_2_by_column - assert scd_type_2_model.kind.is_scd_type_2 - assert scd_type_2_model.kind.is_materialized - assert scd_type_2_model.kind.time_data_type == exp.DataType.build("TIMESTAMPTZ") - assert not scd_type_2_model.kind.invalidate_hard_deletes - assert not scd_type_2_model.kind.forward_only - assert not scd_type_2_model.kind.disable_restatement - - model_kind_dict = scd_type_2_model.kind.dict() - assert scd_type_2_model.kind == _model_kind_validator(None, model_kind_dict, {}) + model = load_sql_based_model(parsed_definition) + assert model.auto_restatement_cron == "@daily" + assert ( + model.kind.to_expression().sql(pretty=True) + == """INCREMENTAL_BY_TIME_RANGE ( + time_column ("a", '%Y-%m-%d'), + partition_by_time_column TRUE, + forward_only FALSE, + disable_restatement FALSE, + on_destructive_change 'ERROR', + on_additive_change 'ALLOW', + auto_restatement_cron '@daily' +)""" + ) + parsed_definition = d.parse( + """ + MODEL ( + name test_schema.test_model, + kind INCREMENTAL_BY_TIME_RANGE( + time_column a, + auto_restatement_cron '@daily', + auto_restatement_intervals 1, + ) + ); + SELECT 1 AS c + """ + ) + model = load_sql_based_model(parsed_definition) + assert model.auto_restatement_cron == "@daily" + assert model.auto_restatement_intervals == 1 + assert ( + model.kind.to_expression().sql(pretty=True) + == """INCREMENTAL_BY_TIME_RANGE ( + time_column ("a", '%Y-%m-%d'), + partition_by_time_column TRUE, + auto_restatement_intervals 1, + forward_only FALSE, + disable_restatement FALSE, + on_destructive_change 'ERROR', + on_additive_change 'ALLOW', + auto_restatement_cron '@daily' +)""" + ) -@pytest.mark.parametrize( - "input_columns,expected_columns", - [ - ( - "col1", - [exp.to_column("col1", quoted=True)], - ), - ( - "[col1]", - [exp.to_column("col1", quoted=True)], - ), - ( - "[col1, col2]", - [exp.to_column("col1", 
quoted=True), exp.to_column("col2", quoted=True)], - ), - ( - '"col1"', - [exp.to_column("col1", quoted=True)], - ), - ( - '["col1"]', - [exp.to_column("col1", quoted=True)], - ), - ("*", exp.Star()), - ], -) -def test_check_column_variants(input_columns, expected_columns): - model_def = d.parse( - f""" + parsed_definition = d.parse( + """ MODEL ( - name db.table, - kind SCD_TYPE_2_BY_COLUMN ( - unique_key "ID", - columns {input_columns} - ), + name test_schema.test_model, + kind INCREMENTAL_BY_TIME_RANGE( + time_column a, + auto_restatement_cron '@invalid' + ) ); - SELECT 1 - ; + SELECT 1 AS c """ ) - scd_type_2_model = load_sql_based_model(model_def) - assert scd_type_2_model.kind.columns == expected_columns + with pytest.raises(ValueError, match="Invalid cron expression '@invalid'.*"): + load_sql_based_model(parsed_definition) -def test_model_dialect_name(): +def test_gateway_specific_render(assert_exp_eq) -> None: + gateways = { + "main": GatewayConfig(connection=DuckDBConnectionConfig()), + "duckdb": GatewayConfig(connection=DuckDBConnectionConfig()), + } + config = Config( + gateways=gateways, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_gateway="main", + ) + context = Context(config=config) + assert context.engine_adapter == context.engine_adapters["main"] + + @model( + name="dummy_model", + is_sql=True, + kind="full", + gateway="duckdb", + grain='"x"', + ) + def dummy_model_entry(evaluator: MacroEvaluator) -> exp.Select: + return exp.select("x").from_(exp.values([("1", 2)], "_v", ["x"])) + + dummy_model = model.get_registry()["dummy_model"].model(module_path=Path("."), path=Path(".")) + context.upsert_model(dummy_model) + assert isinstance(dummy_model, SqlModel) + assert dummy_model.gateway == "duckdb" + + assert_exp_eq( + context.render("dummy_model"), + """ + SELECT + "_v"."x" AS "x", + FROM (VALUES ('1', 2)) AS "_v"("x") + """, + ) + assert isinstance(context._get_engine_adapter("duckdb"), DuckDBEngineAdapter) + assert 
len(context.engine_adapters) == 2 + + +def test_model_on_virtual_update(make_snapshot: t.Callable): + # Macro to test resolution within virtual statement + @macro() + def resolve_parent_name(evaluator, name): + return evaluator.resolve_table(name.name) + + dialect = "postgres" + virtual_update_statements = """ + CREATE OR REPLACE VIEW test_view FROM demo_db.table; + GRANT SELECT ON VIEW @this_model TO ROLE owner_name; + JINJA_STATEMENT_BEGIN; + GRANT SELECT ON VIEW {{this_model}} TO ROLE admin; + JINJA_END; + GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE demo_db TO ROLE owner_name; + @resolve_parent_name('parent'); + GRANT SELECT ON VIEW demo_db.table /* sqlglot.meta replace=false */ TO ROLE admin; + """ + expressions = d.parse( + f""" + MODEL ( + name demo_db.table, + owner owner_name, + ); + + SELECT id from parent; + + on_virtual_update_begin; + + {virtual_update_statements} + + on_virtual_update_end; + + """, + default_dialect=dialect, + ) + + parent_expressions = d.parse( """ MODEL ( - name `project-1`.`db`.`tbl1`, - dialect bigquery + name parent, ); - SELECT 1; + + SELECT 1 from id; + + ON_VIRTUAL_UPDATE_BEGIN; + JINJA_STATEMENT_BEGIN; + GRANT SELECT ON VIEW {{this_model}} TO ROLE admin; + JINJA_END; + ON_VIRTUAL_UPDATE_END; + + """, + default_dialect=dialect, + ) + + model = load_sql_based_model(expressions, dialect=dialect) + parent = load_sql_based_model(parent_expressions, dialect=dialect) + + parent_snapshot = make_snapshot(parent) + parent_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + model_snapshot = make_snapshot(model) + model_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + assert model.on_virtual_update == d.parse(virtual_update_statements, default_dialect=dialect) + + assert parent.on_virtual_update == d.parse( + "JINJA_STATEMENT_BEGIN; GRANT SELECT ON VIEW {{this_model}} TO ROLE admin; JINJA_END;", + default_dialect=dialect, + ) + + environment_naming_info = EnvironmentNamingInfo(name="dev") + table_mapping = 
{model.fqn: "demo_db__dev.table", parent.fqn: "default__dev.parent"} + snapshots = { + parent_snapshot.name: parent_snapshot, + model_snapshot.name: model_snapshot, + } + + rendered_on_virtual_update = model.render_on_virtual_update( + snapshots=snapshots, + table_mapping=table_mapping, + this_model=model_snapshot.qualified_view_name.table_for_environment( + environment_naming_info, dialect=dialect + ), + ) + + assert len(rendered_on_virtual_update) == 6 + assert ( + rendered_on_virtual_update[0].sql() + == 'CREATE OR REPLACE VIEW "test_view" AS SELECT * FROM "demo_db__dev"."table" AS "table" /* demo_db.table */' + ) + + assert ( + rendered_on_virtual_update[1].sql() + == 'GRANT SELECT ON VIEW "demo_db__dev"."table" TO ROLE "owner_name"' + ) + assert ( + rendered_on_virtual_update[3].sql() + == "GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE demo_db TO ROLE owner_name" + ) + + assert rendered_on_virtual_update[4].sql() == '"default__dev"."parent"' + + # When replace=false the table should remain as is + assert ( + rendered_on_virtual_update[5].sql() + == 'GRANT SELECT ON VIEW "demo_db"."table" /* sqlglot.meta replace=false */ TO ROLE "admin"' + ) + + rendered_parent_on_virtual_update = parent.render_on_virtual_update( + snapshots=snapshots, + table_mapping=table_mapping, + this_model=parent_snapshot.qualified_view_name.table_for_environment( + environment_naming_info, dialect=dialect + ), + ) + assert len(rendered_parent_on_virtual_update) == 1 + assert ( + rendered_parent_on_virtual_update[0].sql() + == 'GRANT SELECT ON VIEW "default__dev"."parent" TO ROLE "admin"' + ) + + +def test_python_model_on_virtual_update(): + macros = """ + {% macro index_name(v) %}{{ v }}{% endmacro %} + """ + + jinja_macros = JinjaMacroRegistry() + jinja_macros.add_macros(MacroExtractor().extract(macros)) + + @model( + "db.test_model", + kind="full", + columns={"id": "string", "name": "string"}, + on_virtual_update=[ + "JINJA_STATEMENT_BEGIN;\nCREATE INDEX 
{{index_name('id_index')}} ON db.test_model(id);\nJINJA_END;", + parse_one("GRANT SELECT ON VIEW @this_model TO ROLE dev_role;"), + "GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE db TO ROLE dev_role;", + ], + ) + def model_with_virtual_statements(context, **kwargs): + return pd.DataFrame( + [ + { + "id": context.var("1"), + "name": context.var("var"), + } + ] + ) + + python_model = model.get_registry()["db.test_model"].model( + module_path=Path("."), path=Path("."), dialect="duckdb", jinja_macros=jinja_macros + ) + + assert len(jinja_macros.root_macros) == 1 + assert len(python_model.jinja_macros.root_macros) == 1 + assert "index_name" in python_model.jinja_macros.root_macros + assert len(python_model.on_virtual_update) == 3 + + rendered_statements = python_model._render_statements( + python_model.on_virtual_update, table_mapping={'"db"."test_model"': "db.test_model"} + ) + + assert ( + rendered_statements[0].sql() + == 'CREATE INDEX "id_index" ON "db"."test_model" /* db.test_model */("id" NULLS LAST)' + ) + assert ( + rendered_statements[1].sql() + == 'GRANT SELECT ON VIEW "db"."test_model" /* db.test_model */ TO ROLE "dev_role"' + ) + assert ( + rendered_statements[2].sql() + == "GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE db TO ROLE dev_role" + ) + + +def test_compile_time_checks(tmp_path: Path): + ctx = Context( + config=Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + linter=LinterConfig( + enabled=True, rules=["ambiguousorinvalidcolumn", "invalidselectstarexpansion"] + ), + ), + paths=tmp_path, + ) + + cfg_err = "Linter detected errors in the code. Please fix them before proceeding." 
+ + # Strict SELECT * expansion + strict_query = d.parse( """ + MODEL ( + name test, + ); + + SELECT * FROM tbl + """ ) - model = load_sql_based_model(expressions) - assert model.fqn == '"project-1"."db"."tbl1"' + with pytest.raises(LinterError, match=cfg_err): + ctx.upsert_model(load_sql_based_model(strict_query)) + ctx.plan_builder("dev") - model = create_external_model( - "`project-1`.`db`.`tbl1`", columns={"x": "STRING"}, dialect="bigquery" + # Strict column resolution + strict_query = d.parse( + """ + MODEL ( + name test, + ); + + SELECT foo + """ ) - assert "name `project-1`.`db`.`tbl1`" in model.render_definition()[0].sql(dialect="bigquery") + with pytest.raises(LinterError, match=cfg_err): + ctx.upsert_model(load_sql_based_model(strict_query)) + ctx.plan_builder("dev") -def test_model_allow_partials(): + +def test_partition_interval_unit(): expressions = d.parse( """ MODEL ( - name db.table, - allow_partials true, + name test, + kind INCREMENTAL_BY_TIME_RANGE( + time_column ds, + ), + cron '0 0 1 * *' ); - SELECT 1; + SELECT '2024-01-01' AS ds; """ ) - model = load_sql_based_model(expressions) + assert model.partition_interval_unit == IntervalUnit.MONTH - assert model.allow_partials - - assert "allow_partials TRUE" in model.render_definition()[0].sql() - - -def test_signals(): + # Partitioning was explicitly set by the user expressions = d.parse( """ MODEL ( - name db.table, - signals [ - ( - table_name = 'table_a', - ds = @end_ds, - ), - ( - table_name = 'table_b', - ds = @end_ds, - hour = @end_hour, - ), - ( - bool_key = True, - int_key = 1, - float_key = 1.0, - string_key = 'string', - ) - ], + name test, + kind INCREMENTAL_BY_TIME_RANGE( + time_column ds, + ), + cron '0 0 1 * *', + partitioned_by (ds) ); - SELECT 1; + SELECT '2024-01-01' AS ds; """ ) - model = load_sql_based_model(expressions) - assert model.signals == [ - exp.Tuple( - expressions=[ - exp.to_column("table_name").eq("table_a"), - exp.to_column("ds").eq(d.MacroVar(this="end_ds")), - ] - 
), - exp.Tuple( - expressions=[ - exp.to_column("table_name").eq("table_b"), - exp.to_column("ds").eq(d.MacroVar(this="end_ds")), - exp.to_column("hour").eq(d.MacroVar(this="end_hour")), - ] - ), - exp.Tuple( - expressions=[ - exp.to_column("bool_key").eq(True), - exp.to_column("int_key").eq(1), - exp.to_column("float_key").eq(1.0), - exp.to_column("string_key").eq("string"), - ] - ), - ] + assert model.partition_interval_unit is None - rendered_signals = model.render_signals(start="2023-01-01", end="2023-01-02 15:00:00") - assert rendered_signals == [ - {"table_name": "table_a", "ds": "2023-01-02"}, - {"table_name": "table_b", "ds": "2023-01-02", "hour": 14}, - {"bool_key": True, "int_key": 1, "float_key": 1.0, "string_key": "string"}, - ] - assert ( - "signals ((table_name = 'table_a', ds = @end_ds), (table_name = 'table_b', ds = @end_ds, hour = @end_hour), (bool_key = TRUE, int_key = 1, float_key = 1.0, string_key = 'string')" - in model.render_definition()[0].sql() +def test_model_blueprinting(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + db_path = str(tmp_path / "db.db") + db_connection = DuckDBConnectionConfig(database=db_path) + + gateways = { + "gw1": GatewayConfig(connection=db_connection, variables={"x": 1}), + "gw2": GatewayConfig(connection=db_connection, variables={"x": 2}), + } + config = Config( + gateways=gateways, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), ) + identity_macro = tmp_path / "macros" / "identity_macro.py" + identity_macro.parent.mkdir(parents=True, exist_ok=True) + identity_macro.write_text( + """from sqlmesh import macro -def test_null_column_type(): - expressions = d.parse( +@macro() +def identity(evaluator, value): + return value +""" + ) + blueprint_sql = tmp_path / "models" / "blueprint.sql" + blueprint_sql.parent.mkdir(parents=True, exist_ok=True) + blueprint_sql.write_text( """ MODEL ( - name test_db.test_model, - columns ( - id INT, - ds NULL, - 
) + name @{blueprint}.test_model_sql, + gateway @identity(@blueprint), + blueprints ((blueprint := gw1), (blueprint := gw2)), + kind FULL ); SELECT - id::INT AS id, - ds - FROM x - """ + @x AS x + """ ) - model = load_sql_based_model(expressions, dialect="hive") - assert model.columns_to_types == { - "ds": exp.DataType.build("null"), - "id": exp.DataType.build("int"), - } - assert not model.annotated + blueprint_pydf = tmp_path / "models" / "blueprint_df.py" + blueprint_pydf.parent.mkdir(parents=True, exist_ok=True) + blueprint_pydf.write_text( + """ +import pandas as pd # noqa: TID253 +from sqlmesh import model -def test_when_matched(): - expressions = d.parse( +@model( + "@{blueprint}.test_model_pydf", + gateway="@blueprint", + blueprints=[{"blueprint": "gw1"}, {"blueprint": "gw2"}], + kind="FULL", + columns={"x": "INT"}, +) +def entrypoint(context, *args, **kwargs): + x_var = context.var("x") + assert context.blueprint_var("blueprint").startswith("gw") + return pd.DataFrame({"x": [x_var]})""" + ) + blueprint_pysql = tmp_path / "models" / "blueprint_sql.py" + blueprint_pysql.parent.mkdir(parents=True, exist_ok=True) + blueprint_pysql.write_text( """ - MODEL ( - name db.employees, - kind INCREMENTAL_BY_UNIQUE_KEY ( - unique_key name, - when_matched WHEN MATCHED THEN UPDATE SET target.salary = COALESCE(source.salary, target.salary) - ) - ); - SELECT 'name' AS name, 1 AS salary; - """ +from sqlmesh import model + + +@model( + "@{blueprint}.test_model_pysql", + gateway="@blueprint", + blueprints=[{"blueprint": "gw1"}, {"blueprint": "gw2"}], + kind="FULL", + is_sql=True, +) +def entrypoint(evaluator): + x_var = evaluator.var("x") + assert evaluator.blueprint_var("blueprint", default="").startswith("gw") + return f'SELECT {x_var} AS x'""" ) - expected_when_matched = "WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.salary = COALESCE(__MERGE_SOURCE__.salary, __MERGE_TARGET__.salary)" + context = Context(paths=tmp_path, config=config) + models = context.models + + # Each of 
the three model files "expands" into two models + assert len(models) == 6 - model = load_sql_based_model(expressions, dialect="hive") - assert model.kind.when_matched.sql() == expected_when_matched + context.plan(no_prompts=True, auto_apply=True, no_diff=True) - model = SqlModel.parse_raw(model.json()) - assert model.kind.when_matched.sql() == expected_when_matched + for model_name in ("test_model_sql", "test_model_pydf", "test_model_pysql"): + for gateway_no in range(1, 3): + blueprint_value = f"gw{gateway_no}" + model = models.get(f'"db"."{blueprint_value}"."{model_name}"') + assert model is not None + assert "blueprints" not in model.all_fields() -def test_default_catalog_sql(assert_exp_eq): - """ - This test validates the hashing behavior of the system as it relates to the default catalog. - The system is not designed to actually support having an engine that doesn't support default catalog - to start supporting it or the reverse of that. If that did happen then bugs would occur. - """ - HASH_WITH_CATALOG = "3198762995" + python_env = model.python_env - # Test setting default catalog doesn't change hash if it matches existing logic - expressions = d.parse( + assert python_env.get(c.SQLMESH_VARS) == Executable.value({"x": gateway_no}) + + if model_name == "test_model_sql": + assert c.SQLMESH_BLUEPRINT_VARS not in python_env + else: + assert python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value( + {"blueprint": blueprint_value} + ) + + assert context.fetchdf(f"from {model.fqn}").to_dict() == {"x": {0: gateway_no}} + + multi_variable_blueprint_example = tmp_path / "models" / "multi_variable_blueprint_example.sql" + multi_variable_blueprint_example.parent.mkdir(parents=True, exist_ok=True) + multi_variable_blueprint_example.write_text( """ MODEL ( - name catalog.db.table + name @{customer}.my_table, + blueprints ( + (customer := customer1, Customer_Field := 'bar'), + (customer := customer2, Customer_Field := qux), + ), + kind FULL ); - SELECT x - FROM 
catalog.db.source + + SELECT + @customer_FIELD AS foo, + @{customer_field} AS foo2, + @BLUEPRINT_VAR('customer_field') AS foo3, + FROM @{customer}.my_source """ ) - model = load_sql_based_model(expressions, default_catalog="catalog") - assert model.default_catalog == "catalog" - assert model.name == "catalog.db.table" - assert model.fqn == '"catalog"."db"."table"' - assert model.depends_on == {'"catalog"."db"."source"'} + context = Context(paths=tmp_path, config=config) + models = context.models - assert_exp_eq( - model.render_query(), - """ - SELECT - "x" AS "x" - FROM "catalog"."db"."source" AS "source" - """, + # The new model expands into another 2 new models + assert len(models) == 8 + + customer1_model = models.get('"db"."customer1"."my_table"') + assert customer1_model is not None + + customer1_python_env = customer1_model.python_env + assert customer1_python_env.get(c.SQLMESH_VARS) is None + assert customer1_python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value( + {"customer": SqlValue(sql="customer1"), "customer_field": SqlValue(sql="'bar'")} ) - assert model.data_hash == HASH_WITH_CATALOG + assert t.cast(exp.Expression, customer1_model.render_query()).sql() == ( + """SELECT 'bar' AS "foo", "bar" AS "foo2", 'bar' AS "foo3" FROM "db"."customer1"."my_source" AS "my_source\"""" + ) - expressions = d.parse( + customer2_model = models.get('"db"."customer2"."my_table"') + assert customer2_model is not None + + customer2_python_env = customer2_model.python_env + assert customer2_python_env.get(c.SQLMESH_VARS) is None + assert customer2_python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value( + {"customer": SqlValue(sql="customer2"), "customer_field": SqlValue(sql="qux")} + ) + + assert t.cast(exp.Expression, customer2_model.render_query()).sql() == ( + '''SELECT "qux" AS "foo", "qux" AS "foo2", "qux" AS "foo3" FROM "db"."customer2"."my_source" AS "my_source"''' + ) + + +def test_dynamic_blueprinting_using_custom_macro(tmp_path: Path) -> None: + 
init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + dynamic_template_sql = tmp_path / "models/dynamic_template_custom_macro.sql" + dynamic_template_sql.parent.mkdir(parents=True, exist_ok=True) + dynamic_template_sql.write_text( """ MODEL ( - name catalog.db.table, + name @customer.some_table, + kind FULL, + blueprints @gen_blueprints(), ); - SELECT x - FROM catalog.db.source + + SELECT + @field_a, + @{field_b} AS field_b + FROM @customer.some_source + """ ) - model = load_sql_based_model(expressions) - assert model.default_catalog is None - assert model.name == "catalog.db.table" - assert model.fqn == '"catalog"."db"."table"' - assert model.depends_on == {'"catalog"."db"."source"'} - - assert_exp_eq( - model.render_query(), + dynamic_template_py = tmp_path / "models/dynamic_template_custom_macro.py" + dynamic_template_py.parent.mkdir(parents=True, exist_ok=True) + dynamic_template_py.write_text( """ - SELECT - "x" AS "x" - FROM "catalog"."db"."source" AS "source" - """, +from sqlmesh import model + +@model( + "@{customer}.some_other_table", + kind="FULL", + blueprints="@gen_blueprints()", + is_sql=True, +) +def entrypoint(evaluator): + field_a = evaluator.blueprint_var("field_a") + return f"SELECT {field_a}, @BLUEPRINT_VAR('field_b') AS field_b FROM @customer.some_source" +""" ) - assert model.data_hash == HASH_WITH_CATALOG + gen_blueprints = tmp_path / "macros/gen_blueprints.py" + gen_blueprints.parent.mkdir(parents=True, exist_ok=True) + gen_blueprints.write_text( + """from sqlmesh import macro - # Test setting default catalog to a different catalog but everything if fully qualified then no hash change - expressions = d.parse( +@macro() +def gen_blueprints(evaluator): + return ( + "((customer := customer1, field_a := x, field_b := y)," + " (customer := customer2, field_a := z, field_b := w))" + )""" + ) + + ctx = Context( + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), paths=tmp_path + ) + + assert 
len(ctx.models) == 4 + assert '"memory"."customer1"."some_table"' in ctx.models + assert '"memory"."customer2"."some_table"' in ctx.models + assert '"memory"."customer1"."some_other_table"' in ctx.models + assert '"memory"."customer2"."some_other_table"' in ctx.models + + +def test_dynamic_blueprinting_using_each(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + dynamic_template_sql = tmp_path / "models/dynamic_template_each.sql" + dynamic_template_sql.parent.mkdir(parents=True, exist_ok=True) + dynamic_template_sql.write_text( """ MODEL ( - name catalog.db.table + name @customer.some_table, + kind FULL, + blueprints @EACH(@values, x -> (customer := schema_@x)), ); - SELECT x - FROM catalog.db.source + + SELECT + 1 AS c """ ) - model = load_sql_based_model(expressions, default_catalog="other_catalog") - assert model.default_catalog == "other_catalog" - assert model.name == "catalog.db.table" - assert model.fqn == '"catalog"."db"."table"' - assert model.depends_on == {'"catalog"."db"."source"'} - - assert_exp_eq( - model.render_query(), + dynamic_template_py = tmp_path / "models/dynamic_template_each.py" + dynamic_template_py.parent.mkdir(parents=True, exist_ok=True) + dynamic_template_py.write_text( """ - SELECT - "x" AS "x" - FROM "catalog"."db"."source" AS "source" - """, +from sqlmesh import model + +@model( + "@{customer}.some_other_table", + kind="FULL", + blueprints="@EACH(@values, x -> (customer := schema_@x))", + is_sql=True, +) +def entrypoint(evaluator): + return "SELECT 1 AS c" +""" ) - assert model.data_hash == HASH_WITH_CATALOG + model_defaults = ModelDefaultsConfig(dialect="duckdb") + variables = {"values": ["customer1", "customer2"]} + config = Config(model_defaults=model_defaults, variables=variables) + ctx = Context(config=config, paths=tmp_path) - # test that hash changes if model contains a non-fully-qualified reference - expressions = d.parse( + assert len(ctx.models) == 4 + 
assert '"memory"."schema_customer1"."some_table"' in ctx.models + assert '"memory"."schema_customer2"."some_table"' in ctx.models + assert '"memory"."schema_customer1"."some_other_table"' in ctx.models + assert '"memory"."schema_customer2"."some_other_table"' in ctx.models + + +def test_single_blueprint(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + single_blueprint = tmp_path / "models/single_blueprint.sql" + single_blueprint.parent.mkdir(parents=True, exist_ok=True) + single_blueprint.write_text( """ MODEL ( - name catalog.db.table + name @single_blueprint.some_table, + kind FULL, + blueprints ((single_blueprint := bar)) ); - SELECT x - FROM db.source + + SELECT 1 AS c """ ) - model = load_sql_based_model(expressions, default_catalog="other_catalog") - assert model.default_catalog == "other_catalog" - assert model.name == "catalog.db.table" - assert model.fqn == '"catalog"."db"."table"' - assert model.depends_on == {'"other_catalog"."db"."source"'} + ctx = Context( + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), paths=tmp_path + ) - # The query changed so the hash should change - assert model.data_hash != HASH_WITH_CATALOG + assert len(ctx.models) == 1 + assert '"memory"."bar"."some_table"' in ctx.models - # test that hash is the same but the fqn is different so the snapshot is different so this is - # a new snapshot but with the same hash as before - expressions = d.parse( + +def test_blueprinting_with_quotes(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + template_with_quoted_vars = tmp_path / "models/template_with_quoted_vars.sql" + template_with_quoted_vars.parent.mkdir(parents=True, exist_ok=True) + template_with_quoted_vars.write_text( """ MODEL ( - name db.table, + name m.@{bp_var}, + blueprints ( + (bp_var := "a b"), + (bp_var := 'c d'), + ), ); - SELECT x - FROM catalog.db.source + + SELECT 
@bp_var AS c1, @{bp_var} AS c2 """ ) - model = load_sql_based_model(expressions) - assert model.default_catalog is None - assert model.name == "db.table" - assert model.fqn == '"db"."table"' - assert model.depends_on == {'"catalog"."db"."source"'} + ctx = Context( + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), paths=tmp_path + ) + assert len(ctx.models) == 2 - assert model.data_hash == HASH_WITH_CATALOG + m1 = ctx.get_model('m."a b"', raise_if_missing=True) + m2 = ctx.get_model('m."c d"', raise_if_missing=True) - # This will also have the same hash but the fqn is different so the snapshot is different so this is - # a new snapshot but with the same hash as before - expressions = d.parse( + assert m1.name == 'm."a b"' + assert m2.name == 'm."c d"' + assert t.cast(exp.Query, m1.render_query()).sql() == '''SELECT "a b" AS "c1", "a b" AS "c2"''' + assert t.cast(exp.Query, m2.render_query()).sql() == '''SELECT 'c d' AS "c1", "c d" AS "c2"''' + + +def test_blueprint_variable_precedence_sql(tmp_path: Path, assert_exp_eq: t.Callable) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + blueprint_variables = tmp_path / "models/blueprint_variables.sql" + blueprint_variables.parent.mkdir(parents=True, exist_ok=True) + blueprint_variables.write_text( """ MODEL ( - name db.table + name s.@{bp_name}, + blueprints ( + (bp_name := m1, var1 := 'v1', var2 := 'v2'), + (bp_name := m2, var1 := 'v3'), + ), ); - SELECT x - FROM catalog.db.source + + @DEF(bp_name, override); + + SELECT + @var1 AS var1_macro_var, + @{var1} AS var1_identifier, + @VAR('var1') AS var1_var_macro_func, + @BLUEPRINT_VAR('var1') AS var1_blueprint_var_macro_func, + + @var2 AS var2_macro_var, + @{var2} AS var2_identifier, + @VAR('var2') AS var2_var_macro_func, + @BLUEPRINT_VAR('var2') AS var2_blueprint_var_macro_func, + + @bp_name AS bp_name_macro_var, + @{bp_name} AS bp_name_identifier, + @VAR('bp_name') AS bp_name_var_macro_func, + 
@BLUEPRINT_VAR('bp_name') AS bp_name_blueprint_var_macro_func, """ ) - model = load_sql_based_model(expressions, default_catalog="catalog") - assert model.default_catalog == "catalog" - assert model.name == "db.table" - assert model.fqn == '"catalog"."db"."table"' - assert model.depends_on == {'"catalog"."db"."source"'} + ctx = Context( + config=Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables={"var2": "1"}, + ), + paths=tmp_path, + ) + assert len(ctx.models) == 2 - assert model.data_hash == HASH_WITH_CATALOG + m1 = ctx.get_model("s.m1", raise_if_missing=True) + m2 = ctx.get_model("s.m2", raise_if_missing=True) - # Query is different since default catalog does not apply and therefore the hash is different - expressions = d.parse( + assert_exp_eq( + m1.render_query(), """ - MODEL ( - name table - ); - SELECT x - FROM source + SELECT + 'v1' AS "var1_macro_var", + "v1" AS "var1_identifier", + NULL AS "var1_var_macro_func", + 'v1' AS "var1_blueprint_var_macro_func", + 'v2' AS "var2_macro_var", + "v2" AS "var2_identifier", + '1' AS "var2_var_macro_func", + 'v2' AS "var2_blueprint_var_macro_func", + "override" AS "bp_name_macro_var", + "override" AS "bp_name_identifier", + NULL AS "bp_name_var_macro_func", + "m1" AS "bp_name_blueprint_var_macro_func" + """, + ) + assert_exp_eq( + m2.render_query(), """ + SELECT + 'v3' AS "var1_macro_var", + "v3" AS "var1_identifier", + NULL AS "var1_var_macro_func", + 'v3' AS "var1_blueprint_var_macro_func", + '1' AS "var2_macro_var", + "1" AS "var2_identifier", + '1' AS "var2_var_macro_func", + NULL AS "var2_blueprint_var_macro_func", + "override" AS "bp_name_macro_var", + "override" AS "bp_name_identifier", + NULL AS "bp_name_var_macro_func", + "m2" AS "bp_name_blueprint_var_macro_func" + """, ) - model = load_sql_based_model(expressions, default_catalog="catalog") - assert model.default_catalog == "catalog" - assert model.name == "table" - assert model.fqn == '"table"' - assert model.depends_on == 
{'"source"'} - - assert model.data_hash != HASH_WITH_CATALOG +def test_blueprint_variable_jinja(tmp_path: Path, assert_exp_eq: t.Callable) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) -def test_default_catalog_python(): - HASH_WITH_CATALOG = "2928466080" + blueprint_variables = tmp_path / "models/blueprint_variables.sql" + blueprint_variables.parent.mkdir(parents=True, exist_ok=True) + blueprint_variables.write_text( + """ + MODEL ( + name s.@{bp_name}, + blueprints ( + (bp_name := m1, var1 := 'v1', var2 := v2), + (bp_name := m2, var1 := 'v3'), + ), + ); - @model(name="db.table", kind="full", columns={'"COL"': "int"}) - def my_model(context, **kwargs): - context.table("dependency.table") + @DEF(bp_name, override); - m = model.get_registry()["db.table"].model( - module_path=Path("."), - path=Path("."), + JINJA_QUERY_BEGIN; + SELECT + {{ blueprint_var('var1') }} AS var1, + '{{ blueprint_var('var2') }}' AS var2, + '{{ blueprint_var('var2', 'var2_default') }}' AS var2_default, + '{{ blueprint_var('bp_name') }}' AS bp_name + FROM s.{{ blueprint_var('bp_name') }}_source; + JINJA_END; + """ ) + ctx = Context( + config=Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables={"var2": "1"}, + ), + paths=tmp_path, + ) + assert len(ctx.models) == 2 - assert m.default_catalog is None - assert m.name == "db.table" - assert m.fqn == '"db"."table"' - assert m.depends_on == {'"dependency"."table"'} - - assert m.data_hash != HASH_WITH_CATALOG + m1 = ctx.get_model("s.m1", raise_if_missing=True) + m2 = ctx.get_model("s.m2", raise_if_missing=True) - m = model.get_registry()["db.table"].model( - module_path=Path("."), - path=Path("."), - default_catalog="catalog", + assert_exp_eq( + m1.render_query(), + """SELECT 'v1' AS "var1", 'v2' AS "var2", 'v2' AS "var2_default", 'm1' AS "bp_name" FROM "memory"."s"."m1_source" AS "m1_source" """, + ) + assert_exp_eq( + m2.render_query(), + """SELECT 'v3' AS "var1", 'None' AS 
"var2", 'var2_default' AS "var2_default", 'm2' AS "bp_name" FROM "memory"."s"."m2_source" AS "m2_source" """, ) - assert m.default_catalog == "catalog" - assert m.name == "db.table" - assert m.fqn == '"catalog"."db"."table"' - assert m.depends_on == {'"catalog"."dependency"."table"'} - # This ideally would be `m.data_hash == HASH_WITH_CATALOG`. The reason it is not is because when we hash - # the python function we make the hash out of the actual logic of the function which means `context.table("dependency.table")` - # is used when really is should be `context.table("catalog.dependency.table")`. - assert m.data_hash != HASH_WITH_CATALOG +def test_blueprint_variable_precedence_python(tmp_path: Path, mocker: MockerFixture) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) - @model(name="catalog.db.table", kind="full", columns={'"COL"': "int"}) - def my_model(context, **kwargs): - context.table("catalog.dependency.table") + blueprint_variables = tmp_path / "models/blueprint_variables.py" + blueprint_variables.parent.mkdir(parents=True, exist_ok=True) + blueprint_variables.write_text( + """ +import pandas as pd # noqa: TID253 +from sqlglot import exp +from sqlmesh import model - m = model.get_registry()["catalog.db.table"].model( - module_path=Path("."), - path=Path("."), - default_catalog="other_catalog", - ) - assert m.default_catalog == "other_catalog" - assert m.name == "catalog.db.table" - assert m.fqn == '"catalog"."db"."table"' - assert m.depends_on == {'"catalog"."dependency"."table"'} +@model( + "s.@{bp_name}", + blueprints=[{"bp_name": "m", "var1": exp.to_column("v1"), "var2": 1}], + kind="FULL", + columns={"x": "INT"}, +) +def entrypoint(context, *args, **kwargs): + assert "bp_name" not in kwargs + assert "var1" not in kwargs + assert kwargs.get("var2") == "1" - assert m.data_hash == HASH_WITH_CATALOG + assert context.var("bp_name") is None + assert context.var("var1") is None + assert context.var("var2") == "1" 
- @model(name="catalog.db.table2", kind="full", columns={'"COL"': "int"}) - def my_model(context, **kwargs): - context.table("dependency.table") + assert context.blueprint_var("bp_name") == "m" + assert context.blueprint_var("var1") == exp.to_column("v1") + assert context.blueprint_var("var2") == 1 - m = model.get_registry()["catalog.db.table2"].model( - module_path=Path("."), - path=Path("."), - default_catalog="other_catalog", + return pd.DataFrame({"x": [1]}) + """ ) - assert m.default_catalog == "other_catalog" - assert m.name == "catalog.db.table2" - assert m.fqn == '"catalog"."db"."table2"' - assert m.depends_on == {'"other_catalog"."dependency"."table"'} + ctx = Context( + config=Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables={"var2": "1"}, + ), + paths=tmp_path, + ) + assert len(ctx.models) == 1 - assert m.data_hash != HASH_WITH_CATALOG + m = ctx.get_model("s.m", raise_if_missing=True) + context = ExecutionContext(mocker.Mock(), {}, None, None) - @model(name="table", kind="full", columns={'"COL"': "int"}) - def my_model(context, **kwargs): - context.table("table2") + assert t.cast(pd.DataFrame, list(m.render(context=context))[0]).to_dict() == {"x": {0: 1}} - m = model.get_registry()["table"].model( - module_path=Path("."), - path=Path("."), - default_catalog="catalog", + +def test_python_model_depends_on_blueprints(tmp_path: Path) -> None: + sql_model = tmp_path / "models" / "base_blueprints.sql" + sql_model.parent.mkdir(parents=True, exist_ok=True) + sql_model.write_text( + """ + MODEL ( + name test_schema1.@{model_name}, + blueprints ((model_name := foo), (model_name := bar)), + kind FULL + ); + + SELECT 1 AS id + """ ) - assert m.default_catalog == "catalog" - assert m.name == "table" - assert m.fqn == '"table"' - assert m.depends_on == {'"table2"'} + py_model = tmp_path / "models" / "depends_on_with_blueprint_vars.py" + py_model.parent.mkdir(parents=True, exist_ok=True) + py_model.write_text( + """ +import pandas as pd # 
noqa: TID253 +from sqlmesh import model - assert m.data_hash != HASH_WITH_CATALOG +@model( + "test_schema2.@model_name", + columns={ + "id": "int", + }, + blueprints=[ + {"model_name": "foo"}, + {"model_name": "bar"}, + ], + depends_on=["test_schema1.@{model_name}"], +) +def entrypoint(context, *args, **kwargs): + table = context.resolve_table(f"test_schema1.{context.blueprint_var('model_name')}") + return context.fetchdf(f"SELECT * FROM {table}")""" + ) + ctx = Context( + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), + paths=tmp_path, + ) + assert len(ctx.models) == 4 -def test_default_catalog_external_model(): - """ - Since external models fqns are the only thing affected by default catalog, and when they change new snapshots - are made, the hash will be the same across different names. - """ - EXPECTED_HASH = "1837375494" + ctx.plan(no_prompts=True, auto_apply=True) + assert ctx.fetchdf("SELECT * FROM test_schema2.foo").to_dict() == {"id": {0: 1}} - model = create_external_model("db.table", columns={"a": "int", "limit": "int"}) - assert model.default_catalog is None - assert model.name == "db.table" - assert model.fqn == '"db"."table"' - assert model.data_hash == EXPECTED_HASH +@time_machine.travel("2020-01-01 00:00:00 UTC") +def test_dynamic_date_spine_model(assert_exp_eq): + @macro() + def get_current_date(evaluator): + from sqlmesh.utils.date import now - model = create_external_model( - "db.table", columns={"a": "int", "limit": "int"}, default_catalog="catalog" - ) - assert model.default_catalog == "catalog" - assert model.name == "db.table" - assert model.fqn == '"catalog"."db"."table"' + return f"'{now().date()}'" - assert model.data_hash == EXPECTED_HASH + expressions = d.parse( + """ + MODEL (name test_model, dialect duckdb); - model = create_external_model( - "catalog.db.table", columns={"a": "int", "limit": "int"}, default_catalog="other_catalog" - ) - assert model.default_catalog == "other_catalog" - assert model.name == 
"catalog.db.table" - assert model.fqn == '"catalog"."db"."table"' + @DEF(curr_date, @get_current_date()); - assert model.data_hash == EXPECTED_HASH + WITH discount_promotion_dates AS ( + @date_spine('day', @curr_date::date - 90, @curr_date::date) + ) - # Since there is no schema defined, the default physical schema is used which changes the hash - model = create_external_model( - "table", columns={"a": "int", "limit": "int"}, default_catalog="catalog" + SELECT * FROM discount_promotion_dates + """ + ) + model = load_sql_based_model(expressions) + assert_exp_eq( + model.render_query(), + """ + WITH "discount_promotion_dates" AS ( + SELECT + "_exploded"."date_day" AS "date_day" + FROM UNNEST(CAST(GENERATE_SERIES(CAST('2020-01-01' AS DATE) - 90, CAST('2020-01-01' AS DATE), INTERVAL '1' DAY) AS DATE[])) AS "_exploded"("date_day") + ) + SELECT + "discount_promotion_dates"."date_day" AS "date_day" + FROM "discount_promotion_dates" AS "discount_promotion_dates" + """, ) - assert model.default_catalog == "catalog" - assert model.name == "table" - assert model.fqn == '"table"' - assert model.data_hash != EXPECTED_HASH +def test_seed_dont_coerce_na_into_null(tmp_path): + model_csv_path = (tmp_path / "model.csv").absolute() + with open(model_csv_path, "w", encoding="utf-8") as fd: + fd.write("code\nNA") -def test_user_cannot_set_default_catalog(): expressions = d.parse( - """ + f""" MODEL ( - name db.table, - default_catalog some_catalog + name db.seed, + kind SEED ( + path '{str(model_csv_path)}', + csv_settings ( + -- override NaN handling, such that no value can be coerced into NaN + keep_default_na = false, + na_values = (), + ), + ), ); - - SELECT 1::int AS a, 2::int AS b, 3 AS c, 4 as d; """ ) - with pytest.raises(ConfigError, match="`default_catalog` cannot be set on a per-model basis"): - load_sql_based_model(expressions) + model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) - with pytest.raises(ConfigError, 
match="`default_catalog` cannot be set on a per-model basis"): + assert isinstance(model.kind, SeedKind) + assert model.seed is not None + assert len(model.seed.content) > 0 + assert next(model.render(context=None)).to_dict() == {"code": {0: "NA"}} - @model(name="db.table", kind="full", columns={'"COL"': "int"}, default_catalog="catalog") - def my_model(context, **kwargs): - context.table("dependency.table") +def test_seed_coerce_datetime(tmp_path): + model_csv_path = (tmp_path / "model.csv").absolute() -def test_depends_on_default_catalog_python(): - @model(name="some.table", kind="full", columns={'"COL"': "int"}, depends_on={"other.table"}) - def my_model(context, **kwargs): - context.table("dependency.table") + with open(model_csv_path, "w", encoding="utf-8") as fd: + fd.write("bad_datetime\n9999-12-31 23:59:59") - m = model.get_registry()["some.table"].model( - module_path=Path("."), - path=Path("."), - default_catalog="catalog", + expressions = d.parse( + f""" + MODEL ( + name db.seed, + kind SEED ( + path '{str(model_csv_path)}', + ), + columns ( + bad_datetime datetime, + ), + ); + """ ) - assert m.default_catalog == "catalog" - assert m.depends_on == {'"catalog"."other"."table"'} + model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) + df = next(model.render(context=None)) + assert df["bad_datetime"].iloc[0] == "9999-12-31 23:59:59" -def test_end_date(): +def test_seed_invalid_date_column(tmp_path): + model_csv_path = (tmp_path / "model.csv").absolute() + + with open(model_csv_path, "w", encoding="utf-8") as fd: + fd.write("bad_date\n9999-12-31\n2025-01-01\n1000-01-01") + expressions = d.parse( - """ + f""" MODEL ( - name db.table, - kind INCREMENTAL_BY_TIME_RANGE ( - time_column ts, + name db.seed, + kind SEED ( + path '{str(model_csv_path)}', + ), + columns ( + bad_date date, ), - start '2023-01-01', - end '2023-06-01' ); - - SELECT 1::int AS a, 2::int AS b, now::timestamp as ts - """ + """ ) - model = 
load_sql_based_model(expressions) - assert model.start == "2023-01-01" - assert model.end == "2023-06-01" - assert model.interval_unit == IntervalUnit.DAY - - with pytest.raises(ConfigError, match=".*Start date.+can't be greater than end date.*"): - load_sql_based_model( - d.parse( - """ - MODEL ( - name db.table, - kind INCREMENTAL_BY_TIME_RANGE ( - time_column ts, - ), - start '2024-01-01', - end '2023-06-01' - ); + model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) + df = next(model.render(context=None)) + # The conversion to date should not raise an error + assert df["bad_date"].to_list() == ["9999-12-31", "2025-01-01", "1000-01-01"] - SELECT 1::int AS a, 2::int AS b, now::timestamp as ts - """ - ) - ) +def test_seed_missing_columns(tmp_path): + model_csv_path = (tmp_path / "model.csv").absolute() + + with open(model_csv_path, "w", encoding="utf-8") as fd: + fd.write("key,value\n1,2\n3,4") -def test_end_no_start(): expressions = d.parse( - """ + f""" MODEL ( - name db.table, - kind INCREMENTAL_BY_TIME_RANGE ( - time_column ts, + name db.seed, + kind SEED ( + path '{str(model_csv_path)}', + ), + columns ( + key int, + value int, + missing_column int, ), - end '2023-06-01' ); + """ + ) - SELECT 1::int AS a, 2::int AS b, now::timestamp as ts + model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) + with pytest.raises( + ConfigError, match="Seed model 'db.seed' has missing columns: {'missing_column'}.*" + ): + next(model.render(context=None)) + + +def test_missing_column_data_in_columns_key(): + expressions = d.parse( """ + MODEL ( + name db.seed, + kind SEED ( + path '../seeds/waiter_names.csv', + ), + columns ( + culprit, other_column double, + ) + ); + """ ) - with pytest.raises(ConfigError, match="Must define a start date if an end date is defined"): - load_sql_based_model(expressions) - load_sql_based_model(expressions, defaults={"start": "2023-01-01"}) - + with 
pytest.raises(ConfigError, match="Missing data type for column 'culprit'."): + load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) -def test_variables(): - @macro() - def test_macro_var(evaluator) -> exp.Expression: - return exp.convert(evaluator.var("TEST_VAR_D") + 10) - expressions = parse( +def test_ignored_rules_serialization(): + expressions = d.parse( """ MODEL( name test_model, - kind FULL, + ignored_rules ['foo', 'bar'] ); - SELECT - @VAR('TEST_VAR_A') AS a, - @VAR('test_var_b', 'default_value') AS b, - @VAR('test_var_c') AS c, - @TEST_MACRO_VAR() AS d, - @'foo_@{test_var_e}' AS e, - @SQL(foo_@{test_var_f}) AS f, - 'foo_@{test_var_unused}' AS g + SELECT * FROM tbl; """, default_dialect="bigquery", ) - model = load_sql_based_model( - expressions, - variables={ - "test_var_a": "test_value", - "test_var_d": 1, - "test_var_e": 4, - "test_var_f": 5, - "test_var_unused": 2, - }, - ) - assert model.python_env[c.SQLMESH_VARS] == Executable.value( - {"test_var_a": "test_value", "test_var_d": 1, "test_var_e": 4, "test_var_f": 5} - ) - assert ( - model.render_query().sql(dialect="bigquery") - == "SELECT 'test_value' AS `a`, 'default_value' AS `b`, NULL AS `c`, 11 AS `d`, 'foo_4' AS `e`, `foo_5` AS `f`, 'foo_@{test_var_unused}' AS `g`" + model = load_sql_based_model(expressions) + + model_json = model.json() + model_json_parsed = json.loads(model_json) + + assert "ignored_rules" not in model_json_parsed + assert "ignored_rules_" not in model_json_parsed + + deserialized_model = SqlModel.parse_raw(model_json) + assert deserialized_model.dict() == model.dict() + + +def test_data_hash_unchanged_when_column_type_uses_default_dialect(): + model = create_sql_model( + "foo", + parse_one("SELECT * FROM bla"), + columns={"a": exp.DataType.build("int")}, + dialect="bigquery", ) - with pytest.raises(ConfigError, match=r"Macro VAR requires at least one argument.*"): - expressions = parse( - """ - MODEL( - name test_model, - ); + 
deserialized_model = SqlModel.parse_raw(model.json()) - SELECT @VAR() AS a; - """, - default_dialect="bigquery", - ) - load_sql_based_model(expressions) + assert model.columns_to_types_ == {"a": exp.DataType.build("int")} + assert deserialized_model.columns_to_types_ == {"a": exp.DataType.build("bigint")} - with pytest.raises( - ConfigError, match=r"The variable name must be a string literal, '123' was given instead.*" - ): - expressions = parse( - """ - MODEL( - name test_model, - ); + # int == int64 in bigquery + assert model.data_hash == deserialized_model.data_hash - SELECT @VAR(123) AS a; - """, - default_dialect="bigquery", - ) - load_sql_based_model(expressions) - with pytest.raises( - ConfigError, - match=r"The variable name must be a string literal, '@VAR_NAME' was given instead.*", - ): - expressions = parse( - """ - MODEL( - name test_model, - ); +def test_transitive_dependency_of_metadata_only_object_is_metadata_only(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) - @DEF(VAR_NAME, 'var_name'); - SELECT @VAR(@VAR_NAME) AS a; - """, - default_dialect="bigquery", - ) - load_sql_based_model(expressions) + test_model = tmp_path / "models/test_model.sql" + test_model.parent.mkdir(parents=True, exist_ok=True) + test_model.write_text("MODEL (name test_model, kind FULL); @metadata_macro(); SELECT 1 AS c") + metadata_macro_code = """ +from sqlglot import parse_one +from sqlmesh import macro -def test_named_variable_macros() -> None: - model = load_sql_based_model( - parse( - """ - MODEL(name sushi.test_gateway_macro); - @DEF(overridden_var, 'overridden_value'); - SELECT @gateway AS gateway, @TEST_VAR_A AS test_var_a, @overridden_var AS overridden_var - """ - ), - variables={ - c.GATEWAY: "in_memory", - "test_var_a": "test_value", - "test_var_unused": "unused", - "overridden_var": "initial_value", - }, - ) +def get_parsed_query(): + return parse_one("SELECT 1") - assert 
model.python_env[c.SQLMESH_VARS] == Executable.value( - {c.GATEWAY: "in_memory", "test_var_a": "test_value", "overridden_var": "initial_value"} - ) - assert ( - model.render_query_or_raise().sql() - == "SELECT 'in_memory' AS \"gateway\", 'test_value' AS \"test_var_a\", 'overridden_value' AS \"overridden_var\"" - ) +@macro(metadata_only=True) +def metadata_macro(evaluator): + if evaluator.runtime_stage == "evaluating": + evaluator.engine_adapter.execute({query})""" + metadata_macro = tmp_path / "macros/metadata_macro.py" + metadata_macro.parent.mkdir(parents=True, exist_ok=True) + metadata_macro.write_text(metadata_macro_code.format(query='"SELECT 1"')) -def test_variables_in_templates() -> None: - model = load_sql_based_model( - parse( - """ - MODEL(name sushi.test_gateway_macro); - @DEF(overridden_var, overridden_value); - SELECT 'gateway' AS col_@gateway, 'test_var_a' AS @{test_var_a}_col, 'overridden_var' AS col_@{overridden_var}_col - """ - ), - variables={ - c.GATEWAY: "in_memory", - "test_var_a": "test_value", - "test_var_unused": "unused", - "overridden_var": "initial_value", - }, + ctx = Context( + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), + paths=tmp_path, ) - assert model.python_env[c.SQLMESH_VARS] == Executable.value( - {c.GATEWAY: "in_memory", "test_var_a": "test_value", "overridden_var": "initial_value"} - ) - assert ( - model.render_query_or_raise().sql() - == "SELECT 'gateway' AS \"col_in_memory\", 'test_var_a' AS \"test_value_col\", 'overridden_var' AS \"col_overridden_value_col\"" - ) + model = ctx.get_model("test_model") + python_env = model.python_env - model = load_sql_based_model( - parse( - """ - MODEL(name sushi.test_gateway_macro); - @DEF(overridden_var, overridden_value); - SELECT 'combo' AS col_@{test_var_a}_@{overridden_var}_col_@gateway - """ - ), - variables={ - c.GATEWAY: "in_memory", - "test_var_a": "test_value", - "test_var_unused": "unused", - "overridden_var": "initial_value", - }, - ) + empty_executable 
= Executable(payload="") - assert model.python_env[c.SQLMESH_VARS] == Executable.value( - {c.GATEWAY: "in_memory", "test_var_a": "test_value", "overridden_var": "initial_value"} - ) - assert ( - model.render_query_or_raise().sql() - == "SELECT 'combo' AS \"col_test_value_overridden_value_col_in_memory\"" - ) + assert len(python_env) == 1 + assert (python_env.get("metadata_macro") or empty_executable).is_metadata - model = load_sql_based_model( - parse( - """ - MODEL( - name @{some_var}.bar, - dialect snowflake - ); + ctx.plan(no_prompts=True, auto_apply=True) - SELECT 1 AS c - """ - ), - variables={ - "some_var": "foo", - }, - ) + # This should make `parse_one` a dependency and so it'll be added in the python env + metadata_macro.write_text(metadata_macro_code.format(query='parse_one("SELECT 1")')) - assert model.name == "foo.bar" + ctx.load() + model = ctx.get_model("test_model") + python_env = model.python_env + assert len(python_env) == 2 + assert (python_env.get("metadata_macro") or empty_executable).is_metadata + assert (python_env.get("parse_one") or empty_executable).is_metadata -def test_variables_jinja(): - expressions = parse( - """ - MODEL( - name test_model, - kind FULL, - ); + plan = ctx.plan(no_prompts=True, auto_apply=True) + ctx_diff = plan.context_diff - JINJA_QUERY_BEGIN; - SELECT '{{ var('TEST_VAR_A') }}' AS a, '{{ var('test_var_b', 'default_value') }}' AS b, '{{ var('test_var_c') }}' AS c, {{ test_macro_var() }} AS d; - JINJA_END; - """, - default_dialect="bigquery", - ) + assert len(ctx_diff.modified_snapshots) == 1 - jinja_macros = JinjaMacroRegistry( - root_macros={ - "test_macro_var": MacroInfo( - definition="{% macro test_macro_var() %}{{ var('test_var_d') + 10 }}{% endmacro %}", - depends_on=[], - ) - }, - ) + new_snapshot, _ = ctx_diff.modified_snapshots['"test_model"'] + assert new_snapshot.change_category == SnapshotChangeCategory.METADATA - model = load_sql_based_model( - expressions, - variables={"test_var_a": "test_value", 
"test_var_d": 1, "test_var_unused": 2}, - jinja_macros=jinja_macros, - ) - assert model.python_env[c.SQLMESH_VARS] == Executable.value( - {"test_var_a": "test_value", "test_var_d": 1} - ) - assert ( - model.render_query().sql(dialect="bigquery") - == "SELECT 'test_value' AS `a`, 'default_value' AS `b`, 'None' AS `c`, 11 AS `d`" - ) + # This should make `get_parsed_query` a dependency and so it'll be added in + # the python env, carrying `parse_one` with it as another transitive dependency + metadata_macro.write_text(metadata_macro_code.format(query="get_parsed_query()")) + ctx.load() + model = ctx.get_model("test_model") + python_env = model.python_env + + assert len(python_env) == 3 + assert (python_env.get("metadata_macro") or empty_executable).is_metadata + assert (python_env.get("get_parsed_query") or empty_executable).is_metadata + assert (python_env.get("parse_one") or empty_executable).is_metadata + + plan = ctx.plan(no_prompts=True, auto_apply=True) + ctx_diff = plan.context_diff + + assert len(ctx_diff.modified_snapshots) == 1 + + new_snapshot, _ = ctx_diff.modified_snapshots['"test_model"'] + assert new_snapshot.change_category == SnapshotChangeCategory.METADATA -def test_variables_python_model(mocker: MockerFixture) -> None: - @model( - "test_variables_python_model", - kind="full", - columns={"a": "string", "b": "string", "c": "string"}, - ) - def model_with_variables(context, **kwargs): - return pd.DataFrame( - [ - { - "a": context.var("TEST_VAR_A"), - "b": context.var("test_var_b", "default_value"), - "c": context.var("test_var_c"), - } - ] - ) - python_model = model.get_registry()["test_variables_python_model"].model( - module_path=Path("."), - path=Path("."), - variables={"test_var_a": "test_value", "test_var_unused": 2}, +def test_vars_are_taken_into_account_when_propagating_metadata_status(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + test_model = tmp_path / 
"models/test_model.sql" + test_model.parent.mkdir(parents=True, exist_ok=True) + test_model.write_text( + "MODEL (name test_model, kind FULL, blueprints ((v4 := 4, v5 := 5)));" + "@m1_metadata_references_v1();" # metadata macro, references v1 internally => v1 metadata + "@m2_metadata_does_not_reference_var(@v2, @v3);" # metadata macro => v2 metadata, v3 metadata + "@m3_non_metadata_references_v4(@v3);" # non-metadata macro, references v4 => v3, v4 are not metadata + "SELECT 1 AS c;" + "@m2_metadata_does_not_reference_var(@v6);" # metadata macro => v6 is metadata + "@m4_non_metadata_references_v6();" # non-metadata macro, references v6 => v6 is not metadata + "ON_VIRTUAL_UPDATE_BEGIN;" + "@m3_non_metadata_references_v4(@v5);" # non-metadata macro, metadata expression => v5 metadata + "ON_VIRTUAL_UPDATE_END;" ) - assert python_model.python_env[c.SQLMESH_VARS] == Executable.value({"test_var_a": "test_value"}) + macro_code = """ +from sqlmesh import macro - context = ExecutionContext(mocker.Mock(), {}, None, None) - df = list(python_model.render(context=context))[0] - assert df.to_dict(orient="records") == [{"a": "test_value", "b": "default_value", "c": None}] +@macro(metadata_only=True) +def m1_metadata_references_v1(evaluator): + evaluator.var("v1") + return None +@macro(metadata_only=True) +def m2_metadata_does_not_reference_var(evaluator, *args): + return None -def test_load_external_model_python(sushi_context) -> None: - @model( - "test_load_external_model_python", - columns={"customer_id": "int", "zip": "str"}, - kind={"name": ModelKindName.FULL}, - ) - def external_model_python(context, **kwargs): - demographics_table = context.table("memory.raw.demographics") - return context.fetchdf( - exp.select("customer_id", "zip").from_(demographics_table), - ) +@macro() +def m3_non_metadata_references_v4(evaluator, *args): + evaluator.var("v4") + return None - python_model = model.get_registry()["test_load_external_model_python"].model( - module_path=Path("."), - 
path=Path("."), +@macro() +def m4_non_metadata_references_v6(evaluator): + evaluator.var("v6") + return None""" + + test_macros = tmp_path / "macros/test_macros.py" + test_macros.parent.mkdir(parents=True, exist_ok=True) + test_macros.write_text(macro_code) + + ctx = Context( + config=Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables={"v1": 1, "v2": 2, "v3": 3, "v6": 6}, + ), + paths=tmp_path, ) + model = ctx.get_model("test_model") - context = ExecutionContext(sushi_context.engine_adapter, sushi_context.snapshots, None, None) - df = list(python_model.render(context=context))[0] + python_env = model.python_env - assert df.to_dict(orient="records") == [{"customer_id": 1, "zip": "00000"}] + assert len(python_env) == 8 + assert "m1_metadata_references_v1" in python_env + assert "m2_metadata_does_not_reference_var" in python_env + assert "m3_non_metadata_references_v4" in python_env + assert "m4_non_metadata_references_v6" in python_env + variables = python_env.get(c.SQLMESH_VARS) + metadata_variables = python_env.get(c.SQLMESH_VARS_METADATA) -def test_variables_python_sql_model(mocker: MockerFixture) -> None: - @model( - "test_variables_python_model", - is_sql=True, - kind="full", - columns={"a": "string", "b": "string", "c": "string"}, - ) - def model_with_variables(evaluator, **kwargs): - return exp.select( - exp.convert(evaluator.var("TEST_VAR_A")).as_("a"), - exp.convert(evaluator.var("test_var_b", "default_value")).as_("b"), - exp.convert(evaluator.var("test_var_c")).as_("c"), - ) + assert variables == Executable.value({"v3": 3, "v6": 6}) + assert metadata_variables == Executable.value({"v1": 1, "v2": 2}, is_metadata=True) - python_sql_model = model.get_registry()["test_variables_python_model"].model( - module_path=Path("."), - path=Path("."), - variables={"test_var_a": "test_value", "test_var_unused": 2}, - ) + blueprint_variables = python_env.get(c.SQLMESH_BLUEPRINT_VARS) + blueprint_metadata_variables = 
python_env.get(c.SQLMESH_BLUEPRINT_VARS_METADATA) - assert python_sql_model.python_env[c.SQLMESH_VARS] == Executable.value( - {"test_var_a": "test_value"} + assert blueprint_variables == Executable.value({"v4": SqlValue(sql="4")}) + assert blueprint_metadata_variables == Executable.value( + {"v5": SqlValue(sql="5")}, is_metadata=True ) - context = ExecutionContext(mocker.Mock(), {}, None, None) - query = list(python_sql_model.render(context=context))[0] - assert ( - query.sql() - == """SELECT 'test_value' AS "a", 'default_value' AS "b", NULL AS "c" """.strip() - ) + macro_evaluator = MacroEvaluator(python_env=python_env) + assert macro_evaluator.locals == { + "runtime_stage": "loading", + "default_catalog": None, + c.SQLMESH_VARS: {"v3": 3, "v6": 6}, + c.SQLMESH_VARS_METADATA: {"v1": 1, "v2": 2}, + c.SQLMESH_BLUEPRINT_VARS: {"v4": exp.Literal.number("4")}, + c.SQLMESH_BLUEPRINT_VARS_METADATA: {"v5": exp.Literal.number("5")}, + } + assert macro_evaluator.var("v1") == 1 + assert macro_evaluator.var("v2") == 2 + assert macro_evaluator.var("v3") == 3 + assert macro_evaluator.var("v6") == 6 + assert macro_evaluator.blueprint_var("v4") == exp.Literal.number("4") + assert macro_evaluator.blueprint_var("v5") == exp.Literal.number("5") -def test_named_variables_python_model(mocker: MockerFixture) -> None: - @model( - "test_named_variables_python_model", - kind="full", - columns={"a": "string", "b": "string", "c": "string"}, + query_with_vars = macro_evaluator.transform( + parse_one("SELECT " + ", ".join(f"@v{var}, @VAR('v{var}')" for var in [1, 2, 3, 6])) ) - def model_with_named_variables( - context, start: TimeLike, test_var_a: str, test_var_b: t.Optional[str] = None, **kwargs - ): - return pd.DataFrame( - [{"a": test_var_a, "b": test_var_b, "start": start.strftime("%Y-%m-%d")}] # type: ignore - ) + assert t.cast(exp.Expression, query_with_vars).sql() == "SELECT 1, 1, 2, 2, 3, 3, 6, 6" - python_model = model.get_registry()["test_named_variables_python_model"].model( - 
module_path=Path("."), - path=Path("."), - # Passing `start` in variables to make sure that built-in arguments can't be overridden. - variables={ - "test_var_a": "test_value", - "test_var_unused": 2, - "start": "2024-01-01", - }, + query_with_blueprint_vars = macro_evaluator.transform( + parse_one("SELECT " + ", ".join(f"@v{var}, @BLUEPRINT_VAR('v{var}')" for var in [4, 5])) ) + assert t.cast(exp.Expression, query_with_blueprint_vars).sql() == "SELECT 4, 4, 5, 5" - assert python_model.python_env[c.SQLMESH_VARS] == Executable.value( - {"test_var_a": "test_value", "start": "2024-01-01"} + +def test_variable_mentioned_in_both_metadata_and_non_metadata_macro(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + test_model = tmp_path / "models/test_model.sql" + test_model.parent.mkdir(parents=True, exist_ok=True) + test_model.write_text( + "MODEL (name test_model, kind FULL); @m1_references_v_metadata(); SELECT @m2_references_v_non_metadata() AS c;" ) - context = ExecutionContext(mocker.Mock(), {}, None, None) - df = list(python_model.render(context=context))[0] - assert df.to_dict(orient="records") == [{"a": "test_value", "b": None, "start": to_ds(c.EPOCH)}] + macro_code = """ +from sqlmesh import macro +@macro(metadata_only=True) +def m1_references_v_metadata(evaluator): + evaluator.var("v") + return None -def test_gateway_macro() -> None: - model = load_sql_based_model( - parse( - """ - MODEL(name sushi.test_gateway_macro); - SELECT @gateway AS gateway - """ - ), - variables={c.GATEWAY: "in_memory"}, +@macro() +def m2_references_v_non_metadata(evaluator): + evaluator.var("v") + return None""" + + test_macros = tmp_path / "macros/test_macros.py" + test_macros.parent.mkdir(parents=True, exist_ok=True) + test_macros.write_text(macro_code) + + ctx = Context( + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"), variables={"v": 1}), + paths=tmp_path, ) + model = ctx.get_model("test_model") - 
assert model.python_env[c.SQLMESH_VARS] == Executable.value({c.GATEWAY: "in_memory"}) - assert model.render_query_or_raise().sql() == "SELECT 'in_memory' AS \"gateway\"" + python_env = model.python_env - @macro() - def macro_uses_gateway(evaluator) -> exp.Expression: - return exp.convert(evaluator.gateway + "_from_macro") + assert len(python_env) == 3 + assert set(python_env) > {"m1_references_v_metadata", "m2_references_v_non_metadata"} + assert python_env.get(c.SQLMESH_VARS) == Executable.value({"v": 1}) - model = load_sql_based_model( - parse( - """ - MODEL(name sushi.test_gateway_macro); - SELECT @macro_uses_gateway() AS gateway_from_macro - """ - ), - variables={c.GATEWAY: "in_memory"}, - ) - assert model.python_env[c.SQLMESH_VARS] == Executable.value({c.GATEWAY: "in_memory"}) - assert ( - model.render_query_or_raise().sql() - == "SELECT 'in_memory_from_macro' AS \"gateway_from_macro\"" +def test_only_top_level_macro_func_impacts_var_descendant_metadata_status(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + test_model = tmp_path / "models/test_model.sql" + test_model.parent.mkdir(parents=True, exist_ok=True) + test_model.write_text( + "MODEL (name test_model, kind FULL); @m1_metadata(@m2_non_metadata(@v)); SELECT 1 AS c;" ) + macro_code = """ +from sqlmesh import macro -def test_gateway_macro_jinja() -> None: - model = load_sql_based_model( - parse( - """ - MODEL(name sushi.test_gateway_macro_jinja); - JINJA_QUERY_BEGIN; - SELECT '{{ gateway() }}' AS gateway_jinja; - JINJA_END; - """ - ), - variables={c.GATEWAY: "in_memory"}, - ) +@macro(metadata_only=True) +def m1_metadata(evaluator, *args): + return None - assert model.python_env[c.SQLMESH_VARS] == Executable.value({c.GATEWAY: "in_memory"}) - assert model.render_query_or_raise().sql() == "SELECT 'in_memory' AS \"gateway_jinja\"" +@macro() +def m2_non_metadata(evaluator, *args): + return None""" + test_macros = tmp_path / 
"macros/test_macros.py" + test_macros.parent.mkdir(parents=True, exist_ok=True) + test_macros.write_text(macro_code) -def test_gateway_python_model(mocker: MockerFixture) -> None: - @model( - "test_gateway_python_model", - kind="full", - columns={"gateway_python": "string"}, + ctx = Context( + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"), variables={"v": 1}), + paths=tmp_path, ) - def model_with_variables(context, **kwargs): - return pd.DataFrame([{"gateway_python": context.gateway + "_from_python"}]) + model = ctx.get_model("test_model") - python_model = model.get_registry()["test_gateway_python_model"].model( - module_path=Path("."), - path=Path("."), - variables={c.GATEWAY: "in_memory"}, + python_env = model.python_env + + assert len(python_env) == 3 + assert set(python_env) > {"m1_metadata", "m2_non_metadata"} + assert python_env.get(c.SQLMESH_VARS_METADATA) == Executable.value({"v": 1}, is_metadata=True) + + +def test_non_metadata_object_takes_precedence_over_metadata_only_object(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + test_model = tmp_path / "models/test_model.sql" + test_model.parent.mkdir(parents=True, exist_ok=True) + test_model.write_text("MODEL (name test_model, kind FULL); @m1(); @m2(); SELECT 1 AS c") + + macro_code = """ +from sqlglot import parse_one +from sqlmesh import macro + +def common_dep(): + pass + +def m1_dep(): + pass + +@macro(metadata_only=True) +def m1(evaluator): + m1_dep() + common_dep() + if evaluator.runtime_stage == "evaluating": + evaluator.engine_adapter.execute(parse_one("SELECT 1")) + +@macro() +def m2(evaluator): + common_dep() + if evaluator.runtime_stage == "evaluating": + evaluator.engine_adapter.execute(parse_one("SELECT 1"))""" + + test_macros = tmp_path / "macros/test_macros.py" + test_macros.parent.mkdir(parents=True, exist_ok=True) + test_macros.write_text(macro_code) + + ctx = Context( + 
config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), + paths=tmp_path, ) + model = ctx.get_model("test_model") + empty_executable = Executable(payload="") - assert python_model.python_env[c.SQLMESH_VARS] == Executable.value({c.GATEWAY: "in_memory"}) + python_env = model.python_env - context = ExecutionContext(mocker.Mock(), {}, None, None) - df = list(python_model.render(context=context))[0] - assert df.to_dict(orient="records") == [{"gateway_python": "in_memory_from_python"}] + # Both `m1` and `m2` refer to `parse_one`, so for `m1` it would be a transitive metadata-only + # object, but since the python env is a flat namespace and `parse_one` is also a dependency + # of `m2`, it needs to be treated as non-metadata. + assert len(python_env) == 5 + assert (python_env.get("m1") or empty_executable).is_metadata + assert (python_env.get("m1_dep") or empty_executable).is_metadata + assert not (python_env.get("m2") or empty_executable).is_metadata + assert not (python_env.get("parse_one") or empty_executable).is_metadata + assert not (python_env.get("common_dep") or empty_executable).is_metadata -@pytest.mark.parametrize("dialect", ["spark", "trino"]) -def test_view_render_no_quote_identifiers(dialect: str) -> None: - expressions = d.parse( +def test_macros_referenced_in_metadata_statements_and_properties_are_metadata_only( + tmp_path: Path, +) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + test_model = tmp_path / "models/test_model.sql" + test_model.parent.mkdir(parents=True, exist_ok=True) + test_model.write_text( """ MODEL ( - name db.table, - kind VIEW, + name test_model, + kind FULL, + signals ( + test_signal_always_true(arg1 := @m1(), arg2 := @non_metadata_macro()) + ), + audits ( + non_zero_c1, + unique_values(columns := @m2()), + ), + physical_properties ( + random_prop = @random_prop_macro() + ), ); - SELECT a, b, c FROM source_table; + + SELECT + 1 AS c1, + @zero_alt() AS c2, + @zero_metadata() 
AS c3, + @non_metadata_macro() AS c4; + + + ON_VIRTUAL_UPDATE_BEGIN; + + @bla(); + @random_prop_macro(); + + ON_VIRTUAL_UPDATE_END; + """ ) - model = load_sql_based_model(expressions, dialect=dialect) - assert ( - model.render_query_or_raise().sql(dialect=dialect) - == "SELECT a AS a, b AS b, c AS c FROM source_table AS source_table" + + macro_code = """ +from sqlglot import exp +from sqlmesh import macro + +def baz(): + pass + +def bob(): + pass + +@macro() +def bla(): + bob() + +@macro() +def m1(evaluator): + baz() + return 1 + +@macro() +def m2(evaluator): + return exp.column("c") + +@macro() +def zero(evaluator): + return 0 + +@macro() +def zero_alt(evaluator): + return 0 + +@macro(metadata_only=True) +def zero_metadata(evaluator): + return 0 + +@macro() +def non_metadata_macro(evaluator): + return 1 + +@macro() +def random_prop_macro(evaluator): + return 1""" + + test_macros = tmp_path / "macros/test_macros.py" + test_macros.parent.mkdir(parents=True, exist_ok=True) + test_macros.write_text(macro_code) + + signal_code = """ +import typing as t + +from sqlmesh import signal + +def bar(): + pass + +@signal() +def test_signal_always_true(batch, arg1, arg2): + bar() + return True""" + + test_signals = tmp_path / "signals/test_signals.py" + test_signals.parent.mkdir(parents=True, exist_ok=True) + test_signals.write_text(signal_code) + + audit_code = """ + AUDIT ( + name non_zero_c1, + ); + + SELECT + * + FROM @this_model + WHERE + c = @zero() OR c = @zero_alt() OR c = @zero_metadata(); + """ + + test_audits = tmp_path / "audits/non_zero_c1.sql" + test_audits.parent.mkdir(parents=True, exist_ok=True) + test_audits.write_text(audit_code) + + ctx = Context( + config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")), + paths=tmp_path, ) + model = ctx.get_model("test_model") + empty_executable = Executable(payload="") + python_env = model.python_env -@pytest.mark.parametrize( - "dialect,kind", - [ - ("spark", "FULL"), - ("trino", "FULL"), - ("duckdb", 
"VIEW"), - ("duckdb", "FULL"), - ], -) -def test_render_quote_identifiers(dialect: str, kind: str) -> None: - expressions = d.parse( - f""" - MODEL ( - name db.table, - kind {kind}, - ); - SELECT a, b, c FROM source_table; - """ + assert len(python_env) == 13 + assert (python_env.get("test_signal_always_true") or empty_executable).is_metadata + assert (python_env.get("bar") or empty_executable).is_metadata + assert (python_env.get("m1") or empty_executable).is_metadata + assert (python_env.get("baz") or empty_executable).is_metadata + assert (python_env.get("m2") or empty_executable).is_metadata + assert (python_env.get("exp") or empty_executable).is_metadata + assert (python_env.get("zero") or empty_executable).is_metadata + assert (python_env.get("zero_metadata") or empty_executable).is_metadata + assert (python_env.get("bla") or empty_executable).is_metadata + assert (python_env.get("bob") or empty_executable).is_metadata + + # non_metadata_macro is referenced in the signal, which makes that reference "metadata only", + # but it's also referenced in the model's query and the macro itself is not "metadata only", + # so the corresponding executable needs to be included in the data hash calculation. The same + # is true for zero_alt, which is referenced in the non_zero_c1 audit. 
+ assert not (python_env.get("zero_alt") or empty_executable).is_metadata + assert not (python_env.get("non_metadata_macro") or empty_executable).is_metadata + assert not (python_env.get("random_prop_macro") or empty_executable).is_metadata + + +def test_scd_type_2_full_history_restatement(): + assert ModelKindName.SCD_TYPE_2.full_history_restatement_only is True + assert ModelKindName.SCD_TYPE_2_BY_TIME.full_history_restatement_only is True + assert ModelKindName.SCD_TYPE_2_BY_COLUMN.full_history_restatement_only is True + assert ModelKindName.INCREMENTAL_BY_TIME_RANGE.full_history_restatement_only is False + + +def test_python_model_boolean_values(): + @model( + "db.test_model", + kind=dict( + name=ModelKindName.SCD_TYPE_2_BY_TIME, + unique_key=["id"], + disable_restatement=False, + ), + columns={"id": "string", "name": "string"}, + optimize_query=False, ) - model = load_sql_based_model(expressions, dialect=dialect) - assert ( - model.render_query_or_raise().sql(dialect="duckdb") - == 'SELECT "a" AS "a", "b" AS "b", "c" AS "c" FROM "source_table" AS "source_table"' + def test_model(context, **kwargs): + return pd.DataFrame([{"id": context.var("1")}]) + + python_model = model.get_registry()["db.test_model"].model( + module_path=Path("."), path=Path("."), dialect="duckdb" ) + assert python_model.kind.disable_restatement is False + assert python_model.optimize_query is False -def test_this_model() -> None: + +def test_var_in_def(assert_exp_eq): expressions = d.parse( """ MODEL ( - name `project-1.table`, - dialect bigquery, + name db.table, + kind INCREMENTAL_BY_TIME_RANGE( + time_column ds + ), ); - JINJA_STATEMENT_BEGIN; - VACUUM {{ this_model }} TO 'a'; - JINJA_END; - - JINJA_QUERY_BEGIN; - SELECT '{{ this_model }}' as x - JINJA_END; + @DEF(var, @start_ds); - JINJA_STATEMENT_BEGIN; - VACUUM {{ this_model }} TO 'b'; - JINJA_END; - """ + SELECT @var AS ds + """ ) - model = load_sql_based_model(expressions) - assert ( - 
model.render_query_or_raise().sql(dialect="bigquery") - == """SELECT '`project-1`.`table`' AS `x`""" - ) + model = load_sql_based_model(expressions) - assert ( - model.render_pre_statements()[0].sql(dialect="bigquery") - == """VACUUM `project-1`.`table` TO 'a'""" - ) - assert ( - model.render_post_statements()[0].sql(dialect="bigquery") - == """VACUUM `project-1`.`table` TO 'b'""" + assert_exp_eq( + model.render_query(), + """ + SELECT '1970-01-01' AS "ds" + """, ) - snapshot = Snapshot.from_node(model, nodes={}) - assert ( - model.render_query_or_raise( - start="2020-01-01", - snapshots={snapshot.name: snapshot}, - ).sql(dialect="bigquery") - == """SELECT '`project-1`.`table`' AS `x`""" +def test_formatting_flag_serde(): + expressions = d.parse( + """ + MODEL( + name test_model, + formatting False, + ); + SELECT * FROM tbl; + """, + default_dialect="duckdb", ) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + model = load_sql_based_model(expressions) + assert model.render_definition()[0].sql() == "MODEL (\nname test_model,\nformatting False\n)" - assert ( - model.render_query_or_raise( - start="2020-01-01", - snapshots={snapshot.name: snapshot}, - ).sql(dialect="bigquery") - == f"SELECT '`sqlmesh__project-1`.`project_1__table__{snapshot.version}`' AS `x`" - ) + model_json = model.json() + assert "formatting" not in json.loads(model_json) + + deserialized_model = SqlModel.parse_raw(model_json) + assert deserialized_model.dict() == model.dict() -def test_macros_in_model_statement(sushi_context, assert_exp_eq): +def test_call_python_macro_from_jinja(): + def noop() -> None: + print("noop") + @macro() - def session_properties(evaluator, value): - return exp.Property( - this=exp.var("session_properties"), - value=exp.convert([exp.convert("foo").eq(exp.var(f"bar_{value}"))]), - ) + def test_runtime_stage(evaluator): + noop() + return evaluator.runtime_stage expressions = d.parse( """ MODEL ( - name @{gateway}__@{gateway}.test_model, - kind 
INCREMENTAL_BY_TIME_RANGE ( - time_column @{time_column} - - ), - start @IF(@gateway = 'test_gateway', '2023-01-01', '2024-01-02'), - @session_properties(baz) + name db.table, + dialect spark, + owner owner_name, ); - SELECT a, b UNION SELECT c, c - """ + JINJA_QUERY_BEGIN; + SELECT '{{ test_runtime_stage() }}' AS a, '{{ test_runtime_stage_jinja('bla') }}' AS b; + JINJA_END; + """ ) - model = load_sql_based_model( - expressions, variables={"gateway": "test_gateway", "time_column": "a"} + jinja_macros = JinjaMacroRegistry( + root_macros={ + "test_runtime_stage_jinja": MacroInfo( + definition="{% macro test_runtime_stage_jinja(value) %}{{ test_runtime_stage() }}_{{ value }}{% endmacro %}", + depends_on=[], + ) + } ) - assert model.name == "test_gateway__test_gateway.test_model" - assert model.time_column - assert model.time_column.column == exp.column("a", quoted=True) - assert model.start == "2023-01-01" - assert model.session_properties == {"foo": exp.column("bar_baz", quoted=False)} + model = load_sql_based_model(expressions, jinja_macros=jinja_macros) + assert model.render_query().sql() == "SELECT 'loading' AS a, 'loading_bla' AS b" + assert set(model.python_env) == {"noop", "test_runtime_stage"} -def test_python_model_dialect(): - model._dialect = "snowflake" - @model( - name="a", - kind=IncrementalByTimeRangeKind(time_column=TimeColumn(column="x", format="YYMMDD")), - columns={}, - ) - def test(context, **kwargs): - return None +def test_python_env_references_are_unequal_but_point_to_same_definition(tmp_path: Path) -> None: + # This tests for regressions against an edge case bug which was due to reloading modules + # in sqlmesh.utils.metaprogramming.import_python_file. Depending on the module loading + # order, we could get a "duplicate symbol in python env" error, even though the references + # essentially pointed to the same definition (e.g. function or class). 
+ init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) - m = model.get_registry()["a"].model( - module_path=Path("."), - path=Path("."), - dialect="snowflake", + db_path = str(tmp_path / "db.db") + db_connection = DuckDBConnectionConfig(database=db_path) + + config = Config( + gateways={"duckdb": GatewayConfig(connection=db_connection)}, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), ) - assert m.time_column.column.sql() == '"X"' - assert m.time_column.format == "%y%m%d" + file_a = tmp_path / "macros" / "a.py" + file_b = tmp_path / "macros" / "b.py" + file_c = tmp_path / "macros" / "c.py" - @model( - name="b", - kind=IncrementalByTimeRangeKind(time_column="y"), - columns={}, + file_a.write_text( + """from macros.c import target + +def f1(): + target() +""" ) - def test(context, **kwargs): - return None + file_b.write_text( + """from sqlmesh import macro - m = model.get_registry()["b"].model( - module_path=Path("."), - path=Path("."), - dialect="snowflake", +from macros.a import f1 +from macros.c import target + +@macro() +def first_macro(evaluator): + f1() + +@macro() +def second_macro(evaluator): + target() +""" + ) + file_c.write_text( + """def target(): + pass +""" ) - assert m.time_column.column.sql() == '"Y"' - assert m.time_column.format == "%Y-%m-%d" + model_file = tmp_path / "models" / "model.sql" + model_file.write_text("MODEL (name a); @first_macro(); @second_macro(); SELECT 1 AS c") - model._dialect = None + ctx = Context(paths=tmp_path, config=config, load=False) + loader = ctx._loaders[0] + original_glob_paths = loader._glob_paths -def test_jinja_runtime_stage(assert_exp_eq): - expressions = d.parse( - """ - MODEL ( - name test.jinja - ); + def _patched_glob_paths(path, *args, **kwargs): + if path == tmp_path / "macros": + yield from [file_a, file_c, file_b] + else: + yield from original_glob_paths(path, *args, **kwargs) - JINJA_QUERY_BEGIN; + # We force the import order to be a.py -> c.py -> b.py: + # + # 1. 
a.py is loaded, so "macros", "macros.a" and "macros.c" are loaded in sys.modules + # 2. c.py is loaded, so "macros" and "macros.c" are reloaded in sys.modules + # 3. b.py is loaded, so "macros" is reloaded and "macros.b" is loaded in sys.modules + # + # (1) => id(sys.modules["macros.a"].target) == id(sys.modules["macros.c"].target) == X + # (2) => id(sys.modules["macros.c"].target) == Y != X == id(sys.modules["macros.a"].target) + # (3) => affects neither sys.modules["macros.a"] nor sys.modules["macros.c"], just loads the macros + # + # At this point we have two different function instances, one in sys.modules["macros.a"] and one + # in sys.modules["macros.c"], which encapsulate the same definition (source code). This used to + # lead to a crash, because we prohibit unequal objects with the same name in the python env. + with patch.object(loader, "_glob_paths", side_effect=_patched_glob_paths): + ctx.load() - SELECT '{{ runtime_stage }}' as a, {{ runtime_stage == 'loading' }} as b + model = ctx.models['"a"'] + python_env = model.python_env - JINJA_END; - """ + assert len(python_env) == 4 + assert python_env.get("target") == Executable( + payload="def target():\n pass", name="target", path="macros/c.py" ) - model = load_sql_based_model(expressions) - assert_exp_eq(model.render_query(), '''SELECT 'loading' as "a", TRUE as "b"''') - -def test_forward_only_on_destructive_change_config() -> None: - # global default to ALLOW for non-incremental models - config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) - context = Context(config=config) +def test_unequal_duplicate_python_env_references_are_prohibited(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) - expressions = d.parse( - """ - MODEL ( - name memory.db.table, - kind FULL, - ); - SELECT a, b, c FROM source_table; - """ + db_path = str(tmp_path / "db.db") + db_connection = DuckDBConnectionConfig(database=db_path) + config = Config( + 
gateways={"duckdb": GatewayConfig(connection=db_connection)}, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), ) - model = load_sql_based_model(expressions, defaults=config.model_defaults.dict()) - context.upsert_model(model) - context_model = context.get_model("memory.db.table") - assert context_model.on_destructive_change.is_allow - # global default to ERROR for incremental models - config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) - context = Context(config=config) + file_a = tmp_path / "macros" / "unimportant_macro.py" + file_b = tmp_path / "macros" / "just_f.py" - expressions = d.parse( - """ - MODEL ( - name memory.db.table, - kind INCREMENTAL_BY_TIME_RANGE ( - time_column c, - forward_only True - ), - ); - SELECT a, b, c FROM source_table; - """ - ) - model = load_sql_based_model(expressions, defaults=config.model_defaults.dict()) - context.upsert_model(model) - context_model = context.get_model("memory.db.table") - assert context_model.on_destructive_change.is_error + file_a.write_text( + """from sqlmesh import macro +from macros.just_f import f - # WARN specified in model definition, overrides incremental model sqlmesh default ERROR - config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) - context = Context(config=config) +a = False - expressions = d.parse( - """ - MODEL ( - name memory.db.table, - kind INCREMENTAL_BY_TIME_RANGE ( - time_column c, - forward_only True, - on_destructive_change warn - ), - ); - SELECT a, b, c FROM source_table; - """ +@macro() +def unimportant_macro(evaluator): + print(a) + f() + return 1 +""" ) - model = load_sql_based_model(expressions, defaults=config.model_defaults.dict()) - context.upsert_model(model) - context_model = context.get_model("memory.db.table") - assert context_model.on_destructive_change.is_warn + file_b.write_text( + """a = 0 - # WARN specified as model default, overrides incremental model sqlmesh default ERROR +def f(): + print(a) +""" + ) + + model_file = tmp_path 
/ "models" / "model.sql" + model_file.write_text("MODEL (name m); SELECT @unimportant_macro() AS unimportant_macro") + + with pytest.raises(SQLMeshError, match=r"duplicate definitions found"): + Context(paths=tmp_path, config=config) + + +def test_semicolon_is_metadata_only_change(tmp_path, assert_exp_eq): + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + db_connection = DuckDBConnectionConfig(database=str(tmp_path / "db.db")) config = Config( - model_defaults=ModelDefaultsConfig(dialect="duckdb", on_destructive_change="warn") + gateways={"duckdb": GatewayConfig(connection=db_connection)}, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), ) - context = Context(config=config) - expressions = d.parse( + model_file = tmp_path / "models" / "model_with_semicolon.sql" + model_file.write_text( """ MODEL ( - name memory.db.table, - kind INCREMENTAL_BY_TIME_RANGE ( - time_column c, - forward_only True - ), + name sqlmesh_example.incremental_model_with_semicolon, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column event_date + ), + start '2020-01-01', + cron '@daily', + grain (id, event_date) ); - SELECT a, b, c FROM source_table; + + SELECT + 1 AS id, + 1 AS item_id, + CAST('2020-01-01' AS DATE) AS event_date + ; + + --Just a comment """ ) - model = load_sql_based_model(expressions, defaults=config.model_defaults.dict()) - context.upsert_model(model) - context_model = context.get_model("memory.db.table") - assert context_model.on_destructive_change.is_warn - # WARN specified as model default, does not override non-incremental sqlmesh default ALLOW - config = Config( - model_defaults=ModelDefaultsConfig(dialect="duckdb", on_destructive_change="warn") + ctx = Context(paths=tmp_path, config=config) + model = ctx.get_model("sqlmesh_example.incremental_model_with_semicolon") + + assert not model.pre_statements + assert not model.post_statements + + assert_exp_eq( + model.render_query(), + 'SELECT 1 AS "id", 1 AS "item_id", 
CAST(\'2020-01-01\' AS DATE) AS "event_date"', ) - context = Context(config=config) + ctx.format() - expressions = d.parse( + assert ( + model_file.read_text() + == """MODEL ( + name sqlmesh_example.incremental_model_with_semicolon, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column event_date + ), + start '2020-01-01', + cron '@daily', + grain (id, event_date) +); + +SELECT + 1 AS id, + 1 AS item_id, + '2020-01-01'::DATE AS event_date; + +/* Just a comment */""" + ) + + ctx.plan(no_prompts=True, auto_apply=True) + + model_file = tmp_path / "models" / "model_with_semicolon.sql" + model_file.write_text( """ MODEL ( - name memory.db.table, - kind FULL, + name sqlmesh_example.incremental_model_with_semicolon, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column event_date + ), + start '2020-01-01', + cron '@daily', + grain (id, event_date) ); - SELECT a, b, c FROM source_table; + + SELECT + 1 AS id, + 1 AS item_id, + CAST('2020-01-01' AS DATE) AS event_date """ ) - model = load_sql_based_model(expressions, defaults=config.model_defaults.dict()) - context.upsert_model(model) - context_model = context.get_model("memory.db.table") - assert context_model.on_destructive_change.is_allow + ctx.load() + plan = ctx.plan(no_prompts=True, auto_apply=True) -def test_incremental_by_partition(sushi_context, assert_exp_eq): - expressions = d.parse( - """ + assert len(plan.context_diff.modified_snapshots) == 1 + assert len(plan.new_snapshots) == 1 + assert plan.new_snapshots[0].is_metadata + + +def test_invalid_audit_reference(): + sql = """ + MODEL ( + name test, + audits (not_nulll (columns := (id))) + ); + + SELECT + 1 AS id + """ + expressions = d.parse(sql) + + with pytest.raises(ConfigError, match="Audit 'not_nulll' is undefined"): + load_sql_based_model(expressions) + + +def test_invalid_signal_reference(): + sql = """ + MODEL ( + name test, + signals (s()) + ); + + SELECT + 1 AS id + """ + expressions = d.parse(sql) + + with pytest.raises(ConfigError, match="Signal 's' is 
undefined"): + load_sql_based_model(expressions) + + +def test_scd_time_data_type_does_not_cause_diff_after_deserialization() -> None: + for dialect in ( + "athena", + "bigquery", + "clickhouse", + "databricks", + "duckdb", + "dune", + "hive", + "mysql", + "postgres", + "presto", + "redshift", + "snowflake", + "spark", + "trino", + "tsql", + ): + sql = f""" MODEL ( - name db.table, - kind INCREMENTAL_BY_PARTITION, - partitioned_by [a], + name test_schema.test_model, + kind SCD_TYPE_2_BY_COLUMN ( + unique_key ARRAY(col), + columns ARRAY(col), + invalidate_hard_deletes TRUE, + on_destructive_change error + ), + dialect {dialect} ); - SELECT a, b - """ + SELECT + 1 AS col + """ + + model = load_sql_based_model(d.parse(sql)) + deserialized_model = SqlModel.parse_raw(model.json()) + + assert model.data_hash == deserialized_model.data_hash + + +def test_python_env_includes_file_path_in_render_definition(): + @model( + "db.test_model_path", + kind=dict( + name=ModelKindName.SCD_TYPE_2_BY_TIME, + unique_key=["id"], + disable_restatement=False, + ), + columns={"id": "string", "name": "string"}, + optimize_query=False, + ) + def test_model(context, **kwargs): + return pd.DataFrame([{"id": context.var("1")}]) + + python_model = model.get_registry()["db.test_model_path"].model( + module_path=Path("."), path=Path("."), dialect="duckdb" ) - model = load_sql_based_model(expressions) - assert model.kind.is_incremental_by_partition - assert model.kind.disable_restatement - expressions = d.parse( - """ - MODEL ( - name db.table, - kind INCREMENTAL_BY_PARTITION ( - disable_restatement false - ), - partitioned_by [a], - ); + model_executable_str = python_model.render_definition()[1].sql() + # Make sure the file path is included in the render definition + assert "# tests/core/test_model.py" in model_executable_str - SELECT a, b - """ + +def test_resolve_interpolated_variables_when_parsing_python_deps(): + @model( + name="bla.test_interpolate_var_in_dep_py", + kind="full", + 
columns={'"col"': "int"}, ) - model = load_sql_based_model(expressions) - assert model.kind.is_incremental_by_partition - assert not model.kind.disable_restatement + def unimportant_testing_model(context, **kwargs): + table1 = context.resolve_table(f"{context.var('schema_name')}.table_name") + table2 = context.resolve_table(f"{context.blueprint_var('schema_name')}.table_name") - with pytest.raises( - ConfigError, - match=r".*partitioned_by field is required for INCREMENTAL_BY_PARTITION models.*", - ): - expressions = d.parse( + return context.fetchdf(exp.select("*").from_(table)) + + m = model.get_registry()["bla.test_interpolate_var_in_dep_py"].model( + module_path=Path("."), + path=Path("."), + variables={"schema_name": "foo"}, + blueprint_variables={"schema_name": "baz"}, + ) + + assert m.depends_on == {'"foo"."table_name"', '"baz"."table_name"'} + assert m.python_env.get(c.SQLMESH_VARS) == Executable.value({"schema_name": "foo"}) + assert m.python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value({"schema_name": "baz"}) + + @macro() + def unimportant_testing_macro(evaluator, *projections): + evaluator.var(f"{evaluator.var('selector')}_variable") + evaluator.var(f"{evaluator.blueprint_var('selector')}_variable") + + return exp.select(*[f'{p} AS "{p}"' for p in projections]) + + m = load_sql_based_model( + d.parse( """ MODEL ( - name db.table, - kind INCREMENTAL_BY_PARTITION, + name bla.test_interpolate_var_in_dep_sql ); - SELECT a, b - """ - ) - load_sql_based_model(expressions) - + @unimportant_testing_macro(); -@pytest.mark.parametrize( - ["model_def", "path", "expected_name"], - [ - [ - """dialect duckdb,""", - """models/test_schema/test_model.sql,""", - "test_schema.test_model", - ], - [ - """dialect duckdb,""", - """models/test_model.sql,""", - "test_model", - ], - [ - """dialect duckdb,""", - """models/inventory/db/test_schema/test_model.sql,""", - "db.test_schema.test_model", - ], - ["""name test_model,""", """models/schema/test_model.sql,""", 
"test_model"], - ], -) -def test_model_table_name_inference( - sushi_context: Context, model_def: str, path: str, expected_name: str -): - model = load_sql_based_model( - d.parse( - f""" - MODEL ( - {model_def} - ); - SELECT a FROM tbl; - """, - default_dialect="duckdb", + SELECT + 1 AS c + """, ), - path=Path(f"$root/{path}"), - infer_names=True, + variables={"selector": "bla", "bla_variable": 1, "baz_variable": 2}, + blueprint_variables={"selector": "baz"}, ) - assert model.name == expected_name - -@pytest.mark.parametrize( - ["path", "expected_name"], - [ - [ - """models/test_schema/test_model.py""", - "test_schema.test_model", - ], - [ - """models/inventory/db/test_schema/test_model.py""", - "db.test_schema.test_model", - ], - ], -) -def test_python_model_name_inference(tmp_path: Path, path: str, expected_name: str) -> None: - init_example_project(tmp_path, dialect="duckdb") - config = Config( - model_defaults=ModelDefaultsConfig(dialect="duckdb"), - model_naming=NameInferenceConfig(infer_names=True), + assert m.python_env.get(c.SQLMESH_VARS) == Executable.value( + {"selector": "bla", "bla_variable": 1, "baz_variable": 2}, + sort_root_dict=True, + ) + assert m.python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value( + {"selector": "baz"}, sort_root_dict=True ) - foo_py_file = tmp_path / path - foo_py_file.parent.mkdir(parents=True, exist_ok=True) - foo_py_file.write_text("""from sqlmesh import model -@model( - columns={'"COL"': "int"}, -) -def my_model(context, **kwargs): - pass""") - context = Context(paths=tmp_path, config=config) - assert context.get_model(expected_name).name == expected_name - assert isinstance(context.get_model(expected_name), PythonModel) +def test_extract_schema_in_post_statement(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) -def test_custom_kind(): - from sqlmesh import CustomMaterialization + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) - 
expressions = d.parse( - """ - MODEL ( - name db.table, - kind CUSTOM ( - materialization 'MyTestStrategy', - forward_only true, - disable_restatement true, - materialization_properties ( - 'key_a' = 'value_a', - key_b = 2, - 'key_c' = true, - 'key_d' = 1.23, - ), - batch_size 1, - batch_concurrency 2, - lookback 3, - ) - ); + model1 = tmp_path / "models" / "parent_model.sql" + model1.parent.mkdir(parents=True, exist_ok=True) + model1.write_text("MODEL (name x); SELECT 1 AS c") - SELECT a, b + model2 = tmp_path / "models" / "child_model.sql" + model2.parent.mkdir(parents=True, exist_ok=True) + model2.write_text( + """ + MODEL (name y); + SELECT c FROM x; + ON_VIRTUAL_UPDATE_BEGIN; + @check_schema('y'); + @check_self_schema(); + ON_VIRTUAL_UPDATE_END; """ ) - with pytest.raises( - ConfigError, match=r"Materialization strategy with name 'MyTestStrategy' was not found.*" - ): - load_sql_based_model(expressions) + check_schema = tmp_path / "macros/check_schema.py" + check_schema.parent.mkdir(parents=True, exist_ok=True) + check_schema.write_text(""" +from sqlglot import exp +from sqlmesh import macro - class MyTestStrategy(CustomMaterialization): - pass +@macro() +def check_schema(evaluator, model_name: str): + if evaluator.runtime_stage != 'loading': + assert evaluator.columns_to_types(model_name) == {"c": exp.DataType.build("INT")} - model = load_sql_based_model(expressions) - assert model.kind.is_custom +@macro() +def check_self_schema(evaluator): + if evaluator.runtime_stage != 'loading': + assert evaluator.columns_to_types(evaluator.this_model_fqn) == {"c": exp.DataType.build("INT")} +""") - kind = t.cast(CustomKind, model.kind) - assert kind.disable_restatement - assert kind.forward_only - assert kind.materialization == "MyTestStrategy" - assert kind.materialization_properties == { - "key_a": "value_a", - "key_b": 2, - "key_c": True, - "key_d": 1.23, - } - assert kind.batch_size == 1 - assert kind.batch_concurrency == 2 - assert kind.lookback == 3 + context = 
Context(paths=tmp_path, config=config) + context.plan(no_prompts=True, auto_apply=True) - assert ( - kind.to_expression().sql() - == """CUSTOM ( -materialization 'MyTestStrategy', -materialization_properties ('key_a' = 'value_a', key_b = 2, 'key_c' = TRUE, 'key_d' = 1.23), -forward_only TRUE, -disable_restatement TRUE, -batch_size 1, -batch_concurrency 2, -lookback 3 -)""" - ) +def test_model_relies_on_os_getenv(tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) -def test_model_kind_to_expression(): - assert ( - load_sql_based_model( - d.parse( - """ - MODEL ( - name db.table, - kind INCREMENTAL_BY_TIME_RANGE( - time_column a, - ), - ); - SELECT a, b + (tmp_path / "macros" / "getenv_macro.py").write_text( """ - ) - ) - .kind.to_expression() - .sql() - == """INCREMENTAL_BY_TIME_RANGE ( -time_column ("a", '%Y-%m-%d'), -forward_only FALSE, -disable_restatement FALSE, -on_destructive_change 'ERROR' -)""" +from os import getenv +from sqlmesh import macro + +@macro() +def getenv_macro(evaluator): + getenv("foo", None) + return 1""" + ) + (tmp_path / "models" / "model.sql").write_text( + "MODEL (name test); SELECT @getenv_macro() AS foo" ) - assert ( - load_sql_based_model( - d.parse( - """ + monkeypatch.chdir(tmp_path) + ctx = Context(paths=tmp_path) + + +def test_invalid_sql_model_query() -> None: + for kind in ("", ", KIND FULL"): + expressions = d.parse( + f""" + MODEL (name db.table{kind}); + + JINJA_STATEMENT_BEGIN; + SELECT 1 AS c; + JINJA_END; + """ + ) + + with pytest.raises( + ConfigError, + match=r"^A query is required and must be a SELECT statement, a UNION statement, or a JINJA_QUERY block.*", + ): + load_sql_based_model(expressions) + + +def test_query_label_and_authorization_macro() -> None: + @macro() + def test_query_label_macro(evaluator): + return "[('key', 'value')]" + + @macro() + def test_authorization_macro(evaluator): + return 
exp.Literal.string("test_authorization") + + expressions = d.parse( + """ MODEL ( - name db.table, - kind INCREMENTAL_BY_TIME_RANGE( - time_column a, - batch_size 1, - batch_concurrency 2, - lookback 3, - forward_only TRUE, - disable_restatement TRUE, - on_destructive_change WARN, - ), + name db.table, + session_properties ( + query_label = @test_query_label_macro(), + authorization = @test_authorization_macro() + ) ); - SELECT a, b + + SELECT 1 AS c; """ - ) - ) - .kind.to_expression() - .sql() - == """INCREMENTAL_BY_TIME_RANGE ( -time_column ("a", '%Y-%m-%d'), -batch_size 1, -batch_concurrency 2, -lookback 3, -forward_only TRUE, -disable_restatement TRUE, -on_destructive_change 'WARN' -)""" ) - assert ( - load_sql_based_model( - d.parse( - """ - MODEL ( - name db.table, - kind INCREMENTAL_BY_UNIQUE_KEY( - unique_key a, - ), + model = load_sql_based_model(expressions) + assert model.session_properties == { + "query_label": d.parse_one("@test_query_label_macro()"), + "authorization": d.parse_one("@test_authorization_macro()"), + } + + assert model.render_session_properties() == { + "query_label": d.parse_one("[('key', 'value')]"), + "authorization": d.parse_one("'test_authorization'"), + } + + +def test_boolean_property_validation() -> None: + expressions = d.parse( + """ + MODEL ( + name db.table, + enabled @IF(TRUE, TRUE, FALSE), + dialect tsql ); - SELECT a, b + + SELECT 1 AS c; """ - ) - ) - .kind.to_expression() - .sql() - == """INCREMENTAL_BY_UNIQUE_KEY ( -unique_key ("a"), -batch_concurrency 1, -forward_only FALSE, -disable_restatement FALSE, -on_destructive_change 'ERROR' -)""" ) + model = load_sql_based_model(expressions, dialect="tsql") + assert model.enabled - assert ( - load_sql_based_model( - d.parse( - """ + +def test_datetime_without_timezone_variable_redshift() -> None: + expressions = d.parse( + """ MODEL ( - name db.table, - kind INCREMENTAL_BY_UNIQUE_KEY( - unique_key a, - when_matched WHEN MATCHED THEN UPDATE SET target.b = COALESCE(source.b, 
target.b) + name test, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column test_time_col, + batch_size 1, + batch_concurrency 1 ), + start '2025-06-01', + dialect redshift ); - SELECT a, b + + SELECT @start_dtntz AS test_time_col """ - ) - ) - .kind.to_expression() - .sql() - == """INCREMENTAL_BY_UNIQUE_KEY ( -unique_key ("a"), -when_matched WHEN MATCHED THEN UPDATE SET __MERGE_TARGET__.b = COALESCE(__MERGE_SOURCE__.b, __MERGE_TARGET__.b), -batch_concurrency 1, -forward_only FALSE, -disable_restatement FALSE, -on_destructive_change 'ERROR' -)""" ) + model = load_sql_based_model(expressions, dialect="redshift") assert ( - load_sql_based_model( - d.parse( - """ - MODEL ( - name db.table, - kind INCREMENTAL_BY_PARTITION, - partitioned_by ["a"], - ); - SELECT a, b + model.render_query_or_raise().sql("redshift") + == '''SELECT CAST('1970-01-01 00:00:00' AS TIMESTAMP) AS "test_time_col"''' + ) + + +def test_python_model_cron_with_blueprints(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + cron_blueprint_model = tmp_path / "models" / "cron_blueprint.py" + cron_blueprint_model.parent.mkdir(parents=True, exist_ok=True) + cron_blueprint_model.write_text( """ - ) - ) - .kind.to_expression() - .sql() - == """INCREMENTAL_BY_PARTITION ( -forward_only TRUE, -disable_restatement TRUE, -on_destructive_change 'ERROR' -)""" +import typing as t +from datetime import datetime + +import pandas as pd +from sqlmesh import ExecutionContext, model + +@model( + "@{customer}.some_table", + kind="FULL", + cron="@'*/@{min} * * * *'", + blueprints=[ + {"customer": "customer1", "field_a": "x", "field_b": "y", "min": 5}, + {"customer": "customer2", "field_a": "z", "field_b": "w", "min": 10}, + ], + columns={ + "field_a": "text", + "field_b": "text", + "customer": "text", + }, + enabled=True +) +def entrypoint( + context: ExecutionContext, + start: datetime, + end: datetime, + execution_time: datetime, + **kwargs: t.Any, +) -> 
pd.DataFrame: + return pd.DataFrame( + { + "field_a": [context.blueprint_var("field_a")], + "field_b": [context.blueprint_var("field_b")], + "customer": [context.blueprint_var("customer")], + } + ) +""" ) - assert ( - load_sql_based_model( - d.parse( - """ - MODEL ( - name db.seed, - kind SEED ( - path '../seeds/waiter_names.csv', - ) - ); + context = Context( + paths=tmp_path, config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + ) + models = context.models + + # Test first blueprint + customer1_model = models.get('"memory"."customer1"."some_table"') + assert customer1_model is not None + assert customer1_model.cron == "*/5 * * * *" + assert customer1_model.enabled + assert "blueprints" not in customer1_model.all_fields() + assert customer1_model.python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value( + {"customer": "customer1", "field_a": "x", "field_b": "y"} + ) + + # Test second blueprint + customer2_model = models.get('"memory"."customer2"."some_table"') + assert customer2_model is not None + assert customer2_model.cron == "*/10 * * * *" + assert customer2_model.python_env.get(c.SQLMESH_BLUEPRINT_VARS) == Executable.value( + {"customer": "customer2", "field_a": "z", "field_b": "w"} + ) + + # Test that the models can be planned and applied + context.plan(no_prompts=True, auto_apply=True, no_diff=True) + + # Verify the data + assert context.fetchdf('from "memory"."customer1"."some_table"').to_dict() == { + "field_a": {0: "x"}, + "field_b": {0: "y"}, + "customer": {0: "customer1"}, + } + assert context.fetchdf('from "memory"."customer2"."some_table"').to_dict() == { + "field_a": {0: "z"}, + "field_b": {0: "w"}, + "customer": {0: "customer2"}, + } + + +def test_python_model_cron_macro_rendering(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + cron_macro_model = tmp_path / "models" / "cron_macro.py" + cron_macro_model.parent.mkdir(parents=True, exist_ok=True) + 
cron_macro_model.write_text( """ - ), - path=Path("./examples/sushi/models/test_model.sql"), - ) - .kind.to_expression() - .sql() - == """SEED ( -path '../seeds/waiter_names.csv', -batch_size 1000 -)""" +import pandas as pd +from sqlmesh import model + +@model( + "msc.test_cron_model", + kind="FULL", + cron="@{cron_schedule}", + columns={"a": "int"}, +) +def entrypoint(context, **kwargs): + return pd.DataFrame([{"a": 1}]) +""" ) - assert ( - load_sql_based_model( - d.parse( - """ - MODEL ( - name db.table, - kind SCD_TYPE_2_BY_TIME ( - unique_key [a, b] - ) - ); - SELECT a, b + # Test with cron alias + context_daily = Context( + paths=tmp_path, + config=Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables={"cron_schedule": "@daily"}, + ), + ) + model_daily = context_daily.models.get('"memory"."msc"."test_cron_model"') + + assert model_daily is not None + assert model_daily.cron == "@daily" + + # Test with cron expression + context_expr = Context( + paths=tmp_path, + config=Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables={"cron_schedule": "0 */2 * * *"}, + ), + ) + model_expr = context_expr.models.get('"memory"."msc"."test_cron_model"') + assert model_expr is not None + assert model_expr.cron == "0 */2 * * *" + + +def test_python_model_normal_cron(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + cron_macro_model = tmp_path / "models" / "cron_macro.py" + cron_macro_model.parent.mkdir(parents=True, exist_ok=True) + cron_macro_model.write_text( """ - ) +import pandas as pd +from sqlmesh import model + +@model( + "msc.normal_test_cron_model", + kind="FULL", + cron="@daily", + columns={"a": "int"}, +) +def entrypoint(context, **kwargs): + return pd.DataFrame([{"a": 1}]) +""" + ) + + # Test with cron alias + context_daily = Context( + paths=tmp_path, + config=Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables={"cron_schedule": 
"@daily"}, + ), + ) + model_daily = context_daily.models.get('"memory"."msc"."normal_test_cron_model"') + + assert model_daily is not None + assert model_daily.cron == "@daily" + + +def test_render_query_optimize_query_false(assert_exp_eq, sushi_context): + snapshots = sushi_context.snapshots + + model = sushi_context.get_model("sushi.top_waiters") + model = model.copy(update={"optimize_query": False}) + + upstream_model_version = sushi_context.get_snapshot("sushi.waiter_revenue_by_day").version + + assert_exp_eq( + model.render_query(snapshots=snapshots).sql(), + f""" + WITH "test_macros" AS ( + SELECT + 2 AS "lit_two", + "revenue" * 2.0 AS "sql_exp", + CAST("revenue" AS TEXT) AS "sql_lit" + FROM "memory"."sqlmesh__sushi"."sushi__waiter_revenue_by_day__{upstream_model_version}" AS "waiter_revenue_by_day" /* memory.sushi.waiter_revenue_by_day */ ) - .kind.to_expression() - .sql() - == """SCD_TYPE_2_BY_TIME ( -updated_at_name "updated_at", -updated_at_as_valid_from FALSE, -unique_key ("a", "b"), -valid_from_name "valid_from", -valid_to_name "valid_to", -invalidate_hard_deletes FALSE, -time_data_type TIMESTAMP, -forward_only TRUE, -disable_restatement TRUE, -on_destructive_change 'ERROR' -)""" + SELECT + CAST("waiter_id" AS INT) AS "waiter_id", + CAST("revenue" AS DOUBLE) AS "revenue" + FROM "memory"."sqlmesh__sushi"."sushi__waiter_revenue_by_day__{upstream_model_version}" AS "waiter_revenue_by_day" /* memory.sushi.waiter_revenue_by_day */ + WHERE + "event_date" = ( + SELECT + MAX("event_date") + FROM "memory"."sqlmesh__sushi"."sushi__waiter_revenue_by_day__{upstream_model_version}" AS "waiter_revenue_by_day" /* memory.sushi.waiter_revenue_by_day */ + ) + ORDER BY + "revenue" DESC + LIMIT 10 + """, ) - assert ( - load_sql_based_model( - d.parse( - """ + +def test_each_macro_with_paren_expression_arg(assert_exp_eq): + expressions = d.parse( + """ MODEL ( - name db.table, - kind SCD_TYPE_2_BY_COLUMN ( - unique_key [a, b], - columns [b] - ) + name dataset.@table_name, + 
kind VIEW, + blueprints ( + ( + table_name := model1, + event_columns := ( + 'value' AS property1, + 'value' AS property2 + ) + ), + ( + table_name := model2, + event_columns := ( + 'value' AS property1 + ) + ) + ), ); - SELECT a, b, c + + SELECT @EACH(@event_columns, x -> x) """ - ) - ) - .kind.to_expression() - .sql() - == """SCD_TYPE_2_BY_COLUMN ( -columns ("b"), -execution_time_as_valid_from FALSE, -unique_key ("a", "b"), -valid_from_name "valid_from", -valid_to_name "valid_to", -invalidate_hard_deletes FALSE, -time_data_type TIMESTAMP, -forward_only TRUE, -disable_restatement TRUE, -on_destructive_change 'ERROR' -)""" ) - assert ( - load_sql_based_model( - d.parse( - """ - MODEL ( - name db.table, - kind SCD_TYPE_2_BY_COLUMN ( - unique_key [a, b], - columns * - ) - ); - SELECT a, b, c + models = load_sql_based_models(expressions, lambda _: {}) + + # Should generate 2 models from the blueprints + assert len(models) == 2 + + # Get the models sorted by name for consistent testing + model1 = next(m for m in models if "model1" in m.name) + model2 = next(m for m in models if "model2" in m.name) + + # Verify model names + assert model1.name == "dataset.model1" + assert model2.name == "dataset.model2" + + assert_exp_eq( + model1.render_query(), """ - ) - ) - .kind.to_expression() - .sql() - == """SCD_TYPE_2_BY_COLUMN ( -columns *, -execution_time_as_valid_from FALSE, -unique_key ("a", "b"), -valid_from_name "valid_from", -valid_to_name "valid_to", -invalidate_hard_deletes FALSE, -time_data_type TIMESTAMP, -forward_only TRUE, -disable_restatement TRUE, -on_destructive_change 'ERROR' -)""" + SELECT + 'value' AS "property1", + 'value' AS "property2" + """, + ) + + assert_exp_eq( + model2.render_query(), + """ + SELECT + 'value' AS "property1" + """, + ) + + +@pytest.mark.parametrize( + "macro_func, variables", + [ + ("@M(@v1)", {"v1"}), + ("@M(@{v1})", {"v1"}), + ("@M(@SQL('@v1'))", {"v1"}), + ("@M(@'@{v1}_foo')", {"v1"}), + ("@M1(@VAR('v1'))", {"v1"}), + ("@M1(@v1, 
@M2(@v2), @BLUEPRINT_VAR('v3'))", {"v1", "v2", "v3"}), + ("@M1(@BLUEPRINT_VAR(@VAR('v1')))", {"v1"}), + ], +) +def test_extract_macro_func_variable_references(macro_func: str, variables: t.Set[str]) -> None: + from sqlmesh.core.model.common import _extract_macro_func_variable_references + + macro_func_ast = parse_one(macro_func) + assert _extract_macro_func_variable_references(macro_func_ast, True)[0] == variables + + +def test_text_diff_column_descriptions(): + """Test that column_descriptions changes are visible in text_diff.""" + # Create model without column descriptions + model1 = create_sql_model( + name="test.model", + query=parse("SELECT id, name FROM upstream")[0], + ) + + # Create model with column descriptions + model2 = create_sql_model( + name="test.model", + query=parse("SELECT id, name FROM upstream")[0], + column_descriptions={"id": "User identifier", "name": "User name"}, + ) + + # Verify the diff shows the column_descriptions + diff = model1.text_diff(model2) + assert diff, "Expected diff to show column_descriptions change" + assert "+ id = 'User identifier'," in diff + assert "+ name = 'User name'" in diff + + # Verify reverse diff also works + diff = model2.text_diff(model1) + assert diff, "Expected reverse diff to show column_descriptions removal" + assert "- id = 'User identifier'," in diff + assert "- name = 'User name'" in diff + + +def test_text_diff_optimize_query(): + """Test that optimize_query changes are visible in text_diff.""" + # Create model without optimize_query + model1 = create_sql_model( + name="test.model", + query=parse("SELECT id, name FROM upstream")[0], ) - assert ( - load_sql_based_model( - d.parse( - """ - MODEL ( - name db.table, - kind FULL - ); - SELECT a, b, c + # Create model with optimize_query enabled + model2 = create_sql_model( + name="test.model", + query=parse("SELECT id, name FROM upstream")[0], + optimize_query=True, + ) + + # Verify the diff shows the optimize_query change + diff = model1.text_diff(model2) 
+ assert diff, "Expected diff to show optimize_query change" + assert "+ optimize_query" in diff.lower() + + +def test_raw_jinja_raw_tag(): + expressions = d.parse( + """ + MODEL (name test); + + JINJA_QUERY_BEGIN; + SELECT {% raw %} '{{ foo }}' {% endraw %} AS col; + JINJA_END; """ - ) - ) - .kind.to_expression() - .sql() - == "FULL" ) - assert ( - load_sql_based_model( - d.parse( - """ - MODEL ( - name db.table, - kind VIEW + model = load_sql_based_model(expressions) + assert model.render_query().sql() == "SELECT '{{ foo }}' AS \"col\"" + + +def test_use_original_sql(): + expressions = d.parse( + """ + MODEL (name test); + + CREATE TABLE pre ( + a INT + ); + + SELECT + 1, + 2; + + CREATE TABLE post ( + b INT ); - SELECT a, b, c """ - ) - ) - .kind.to_expression() - .sql() - == """VIEW ( -materialized FALSE -)""" ) - assert ( - load_sql_based_model( - d.parse( - """ + model = load_sql_based_model(expressions) + assert model.query_.sql == "SELECT\n 1,\n 2" + assert model.pre_statements_[0].sql == "CREATE TABLE pre (\n a INT\n )" + assert model.post_statements_[0].sql == "CREATE TABLE post (\n b INT\n );" + + # Now manually create the model and make sure that the original SQL is not used + model_query = d.parse_one("SELECT 1 AS one") + assert model_query.meta["sql"] == "SELECT 1 AS one" + model_query = model_query.select("2 AS two") + + pre_statements = [d.parse_one("CREATE TABLE pre (\n a INT\n )")] + post_statements = [d.parse_one("CREATE TABLE post (\n b INT\n );")] + + model = create_sql_model( + "test", + model_query, + pre_statements=pre_statements, + post_statements=post_statements, + ) + assert model.query_.sql == "SELECT 1 AS one, 2 AS two" + assert model.pre_statements_[0].sql == "CREATE TABLE pre (a INT)" + assert model.post_statements_[0].sql == "CREATE TABLE post (b INT)" + + +def test_case_sensitive_macro_locals(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb", template=ProjectTemplate.EMPTY) + + db_path = str(tmp_path / 
"db.db") + db_connection = DuckDBConnectionConfig(database=db_path) + + config = Config( + gateways={"gw": GatewayConfig(connection=db_connection)}, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + + macro_file = tmp_path / "macros" / "some_macro_with_globals.py" + macro_file.parent.mkdir(parents=True, exist_ok=True) + macro_file.write_text( + """from sqlmesh import macro + +x = 1 +X = 2 + +@macro() +def my_macro(evaluator): + assert evaluator.locals.get("x") == 1 + assert evaluator.locals.get("X") == 2 + + return x + X +""" + ) + test_model = tmp_path / "models" / "test_model.sql" + test_model.parent.mkdir(parents=True, exist_ok=True) + test_model.write_text("MODEL (name test_model, kind FULL); SELECT @my_macro() AS c") + + context = Context(paths=tmp_path, config=config) + model = context.get_model("test_model", raise_if_missing=True) + + assert model.render_query_or_raise().sql() == 'SELECT 3 AS "c"' + + +def test_grants(): + expressions = d.parse(""" MODEL ( - name db.table, - kind VIEW (materialized true) - ); - SELECT a, b, c - """ + name test.table, + kind FULL, + grants ( + 'select' = ['user1', 123, admin_role, 'user2'], + 'insert' = 'admin', + 'roles/bigquery.dataViewer' = ["group:data_eng@company.com", 'user:someone@company.com'], + 'update' = 'admin' ) - ) - .kind.to_expression() - .sql() - == """VIEW ( -materialized TRUE -)""" + ); + SELECT 1 as id + """) + model = load_sql_based_model(expressions) + assert model.grants == { + "select": ["user1", "123", "admin_role", "user2"], + "insert": ["admin"], + "roles/bigquery.dataViewer": ["group:data_eng@company.com", "user:someone@company.com"], + "update": ["admin"], + } + + model = create_sql_model( + "db.table", + parse_one("SELECT 1 AS id"), + kind="FULL", + grants={ + "select": ["user1", "user2"], + "insert": ["admin"], + "roles/bigquery.dataViewer": "user:data_eng@company.com", + }, ) + assert model.grants == { + "select": ["user1", "user2"], + "insert": ["admin"], + 
"roles/bigquery.dataViewer": ["user:data_eng@company.com"], + } @pytest.mark.parametrize( - "metadata_only", - [True, False], + "kind", + [ + "FULL", + "VIEW", + SeedKind(path="test.csv"), + IncrementalByTimeRangeKind(time_column="ds"), + IncrementalByUniqueKeyKind(unique_key="id"), + ], ) -def test_macro_func_hash(metadata_only): - macro.set_registry({}) +def test_grants_valid_model_kinds(kind: t.Union[str, _ModelKind]): + model = create_sql_model( + "db.table", + parse_one("SELECT 1 AS id"), + kind=kind, + grants={"select": ["user1", "user2"], "insert": ["admin_user"]}, + ) + assert model.grants == {"select": ["user1", "user2"], "insert": ["admin_user"]} - @macro(metadata_only=metadata_only) - def noop(evaluator) -> None: - return None - expressions = d.parse( - """ +@pytest.mark.parametrize( + "kind", + [ + "EXTERNAL", + "EMBEDDED", + ], +) +def test_grants_invalid_model_kind_errors(kind: str): + with pytest.raises(ValidationError, match=rf".*grants cannot be set for {kind}.*"): + create_sql_model( + "db.table", + parse_one("SELECT 1 AS id"), + kind=kind, + grants={"select": ["user1"], "insert": ["admin_user"]}, + ) + + +def test_model_kind_supports_grants(): + assert FullKind().supports_grants is True + assert ViewKind().supports_grants is True + assert IncrementalByTimeRangeKind(time_column="ds").supports_grants is True + assert IncrementalByUniqueKeyKind(unique_key=["id"]).supports_grants is True + assert SCDType2ByTimeKind(unique_key=["id"]).supports_grants is True + + assert EmbeddedKind().supports_grants is False + assert ExternalKind().supports_grants is False + + +def test_grants_validation_no_grants(): + model = create_sql_model("db.table", parse_one("SELECT 1 AS id"), kind="FULL") + assert model.grants is None + + +def test_grants_validation_empty_grantees(): + model = create_sql_model( + "db.table", parse_one("SELECT 1 AS id"), kind="FULL", grants={"select": []} + ) + assert model.grants == {"select": []} + + +def 
test_grants_single_value_conversions(): + expressions = d.parse(f""" MODEL ( - name db.model, + name test.nested_arrays, + kind FULL, + grants ( + 'select' = "user1", update = user2 + ) ); + SELECT 1 as id + """) + model = load_sql_based_model(expressions) + assert model.grants == {"select": ["user1"], "update": ["user2"]} - SELECT 1; - """ + model = create_sql_model( + "db.table", + parse_one("SELECT 1 AS id"), + kind="FULL", + grants={"select": "user1", "insert": 123}, ) - model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) + assert model.grants == {"select": ["user1"], "insert": ["123"]} - expressions = d.parse( - """ + +@pytest.mark.parametrize( + "grantees", + [ + "('user1', ('user2', 'user3'), 'user4')", + "('user1', ['user2', 'user3'], user4)", + "['user1', ['user2', user3], 'user4']", + "[user1, ('user2', \"user3\"), 'user4']", + ], +) +def test_grants_array_flattening(grantees: str): + expressions = d.parse(f""" MODEL ( - name db.model, + name test.nested_arrays, + kind FULL, + grants ( + 'select' = {grantees} + ) ); + SELECT 1 as id + """) + model = load_sql_based_model(expressions) + assert model.grants == {"select": ["user1", "user2", "user3", "user4"]} - SELECT 1; - @noop(); - """ - ) - new_model = load_sql_based_model( - expressions, path=Path("./examples/sushi/models/test_model.sql") +def test_grants_macro_var_resolved(): + expressions = d.parse(""" + MODEL ( + name test.macro_grants, + kind FULL, + grants ( + 'select' = @VAR('readers'), + 'insert' = @VAR('writers') + ) + ); + SELECT 1 as id + """) + model = load_sql_based_model( + expressions, variables={"readers": ["user1", "user2"], "writers": "admin"} ) - if metadata_only: - assert "noop" not in new_model._data_hash_values[0] - assert "noop" in new_model._additional_metadata[0] - assert model.data_hash == new_model.data_hash - assert model.metadata_hash(audits={}) != new_model.metadata_hash(audits={}) - else: - assert "noop" in 
new_model._data_hash_values[0] - assert not new_model._additional_metadata - assert model.data_hash != new_model.data_hash - assert model.metadata_hash(audits={}) == new_model.metadata_hash(audits={}) + assert model.grants == { + "select": ["user1", "user2"], + "insert": ["admin"], + } - @macro(metadata_only=metadata_only) - def noop(evaluator) -> None: - print("noop") - return None - updated_model = load_sql_based_model( - expressions, path=Path("./examples/sushi/models/test_model.sql") - ) - if metadata_only: - assert "print" not in new_model._additional_metadata[0] - assert "print" in updated_model._additional_metadata[0] - assert new_model.data_hash == updated_model.data_hash - assert new_model.metadata_hash(audits={}) != updated_model.metadata_hash(audits={}) - else: - assert "print" not in new_model._data_hash_values[0] - assert "print" in updated_model._data_hash_values[0] - assert new_model.data_hash != updated_model.data_hash - assert new_model.metadata_hash(audits={}) == updated_model.metadata_hash(audits={}) +def test_grants_macro_var_in_array_flattening(): + expressions = d.parse(""" + MODEL ( + name test.macro_in_array, + kind FULL, + grants ( + 'select' = ['user1', @VAR('admins'), 'user3'] + ) + ); + SELECT 1 as id + """) + model = load_sql_based_model(expressions, variables={"admins": ["admin1", "admin2"]}) + assert model.grants == {"select": ["user1", "admin1", "admin2", "user3"]} -def test_managed_kind_sql(): - expressions = d.parse( - """ + model2 = load_sql_based_model(expressions, variables={"admins": "super_admin"}) + assert model2.grants == {"select": ["user1", "super_admin", "user3"]} + + +def test_grants_dynamic_permission_names(): + expressions = d.parse(""" MODEL ( - name db.table, - kind MANAGED, - physical_properties ( - warehouse = small, - target_lag = '20 minutes', - refresh_mode = auto + name test.dynamic_keys, + kind FULL, + grants ( + @VAR('read_perm') = ['user1', 'user2'], + @VAR('write_perm') = ['admin'] ) ); - - SELECT a, b - 
""" + SELECT 1 as id + """) + model = load_sql_based_model( + expressions, variables={"read_perm": "select", "write_perm": "insert"} ) + assert model.grants == {"select": ["user1", "user2"], "insert": ["admin"]} - model = load_sql_based_model(expressions) - assert model.kind.is_managed +def test_grants_unresolved_macro_errors(): + expressions1 = d.parse(""" + MODEL (name test.bad1, kind FULL, grants ('select' = @VAR('undefined'))); + SELECT 1 as id + """) + with pytest.raises(ConfigError, match=r"Invalid grants configuration for 'select': NULL value"): + load_sql_based_model(expressions1) - with pytest.raises(ConfigError, match=r".*must specify the 'target_lag' physical property.*"): - load_sql_based_model( - d.parse( - """ - MODEL ( - name db.table, - kind MANAGED, - dialect snowflake - ); + expressions2 = d.parse(""" + MODEL (name test.bad2, kind FULL, grants (@VAR('undefined') = ['user'])); + SELECT 1 as id + """) + with pytest.raises(ConfigError, match=r"Invalid grants configuration.*NULL value"): + load_sql_based_model(expressions2) - SELECT a, b - """ - ) - ).validate_definition() + expressions3 = d.parse(""" + MODEL (name test.bad3, kind FULL, grants ('select' = ['user', @VAR('undefined')])); + SELECT 1 as id + """) + with pytest.raises(ConfigError, match=r"Invalid grants configuration for 'select': NULL value"): + load_sql_based_model(expressions3) -def test_managed_kind_python(): - @model("test_managed_python_model", kind="managed", columns={"a": "int"}) - def execute( - context: ExecutionContext, - start: datetime, - end: datetime, - execution_time: datetime, - **kwargs: t.Any, - ) -> pd.DataFrame: - return pd.DataFrame.from_dict(data={"a": 1}, orient="index") +def test_grants_empty_values(): + model1 = create_sql_model( + "db.table", parse_one("SELECT 1 AS id"), kind="FULL", grants={"select": []} + ) + assert model1.grants == {"select": []} - with pytest.raises( - SQLMeshError, - match=r".*Cannot create Python model.*the 'MANAGED' kind doesnt support 
Python models", - ): - model.get_registry()["test_managed_python_model"].model( - module_path=Path("."), - path=Path("."), - ).validate_definition() + model2 = create_sql_model("db.table", parse_one("SELECT 1 AS id"), kind="FULL") + assert model2.grants is None -def test_trailing_comments(): +@pytest.mark.parametrize( + "kind, expected", + [ + ("VIEW", DataObjectType.VIEW), + ("FULL", DataObjectType.TABLE), + ("MANAGED", DataObjectType.MANAGED_TABLE), + (ViewKind(materialized=True), DataObjectType.MATERIALIZED_VIEW), + ], +) +def test_grants_table_type(kind: t.Union[str, _ModelKind], expected: DataObjectType): + model = create_sql_model("test_table", parse_one("SELECT 1 as id"), kind=kind) + assert model.grants_table_type == expected + + +def test_model_macro_using_locals_called_from_jinja(assert_exp_eq) -> None: + @macro() + def execution_date(evaluator): + return f"""'{evaluator.locals.get("execution_date")}'""" + expressions = d.parse( """ MODEL (name db.table); - /* some comment A */ - - SELECT 1; - /* some comment B */ + JINJA_QUERY_BEGIN; + SELECT {{ execution_date() }} AS col; + JINJA_END; """ ) model = load_sql_based_model(expressions) - assert not model.render_pre_statements() - assert not model.render_post_statements() + assert_exp_eq(model.render_query(), '''SELECT '1970-01-01' AS "col"''') - expressions = d.parse( + +def test_audits_in_embedded_model(): + expression = d.parse( """ MODEL ( - name db.seed, - kind SEED ( - path '../seeds/waiter_names.csv', - batch_size 100, - ) + name test.embedded_with_audits, + kind EMBEDDED, + audits (not_null (columns := (id))) ); - /* some comment A */ + SELECT 1 AS id, 'A' as value """ ) - model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) - assert not model.render_pre_statements() - assert not model.render_post_statements() + with pytest.raises(ConfigError, match="Audits are not supported for embedded models"): + load_sql_based_model(expression).validate_definition() diff 
--git a/tests/core/test_plan.py b/tests/core/test_plan.py index 554f5501b9..4b330c376f 100644 --- a/tests/core/test_plan.py +++ b/tests/core/test_plan.py @@ -4,24 +4,29 @@ from unittest.mock import patch import pytest -from freezegun import freeze_time + +from sqlmesh.core.console import TerminalConsole +from sqlmesh.utils.metaprogramming import Executable +from tests.core.test_table_diff import create_test_console +import time_machine from pytest_mock.plugin import MockerFixture -from sqlglot import parse_one +from sqlglot import parse_one, exp +from sqlmesh.core import dialect as d from sqlmesh.core.context import Context from sqlmesh.core.context_diff import ContextDiff -from sqlmesh.core.engine_adapter import DuckDBEngineAdapter -from sqlmesh.core.environment import EnvironmentNamingInfo +from sqlmesh.core.environment import EnvironmentNamingInfo, EnvironmentStatements from sqlmesh.core.model import ( ExternalModel, FullKind, IncrementalByTimeRangeKind, + IncrementalUnmanagedKind, SeedKind, SeedModel, SqlModel, ModelKindName, ) -from sqlmesh.core.model.kind import OnDestructiveChange +from sqlmesh.core.model.kind import OnDestructiveChange, OnAdditiveChange, ViewKind from sqlmesh.core.model.seed import Seed from sqlmesh.core.plan import Plan, PlanBuilder, SnapshotIntervals from sqlmesh.core.snapshot import ( @@ -34,13 +39,13 @@ from sqlmesh.utils.dag import DAG from sqlmesh.utils.date import ( now, - now_timestamp, to_date, to_datetime, to_timestamp, yesterday_ds, ) -from sqlmesh.utils.errors import PlanError +from sqlmesh.utils.errors import PlanError, NoChangesPlanError +from sqlmesh.utils.rich import strip_ansi_codes def test_forward_only_plan_sets_version(make_snapshot, mocker: MockerFixture): @@ -59,7 +64,8 @@ def test_forward_only_plan_sets_version(make_snapshot, mocker: MockerFixture): metadata_hash="test_metadata_hash", ), version="test_version", - change_category=SnapshotChangeCategory.FORWARD_ONLY, + 
change_category=SnapshotChangeCategory.NON_BREAKING, + dev_table_suffix="dev", ), ) assert not snapshot_b.version @@ -70,6 +76,7 @@ def test_forward_only_plan_sets_version(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={snapshot_b.name: (snapshot_b, snapshot_b)}, @@ -81,17 +88,16 @@ def test_forward_only_plan_sets_version(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - plan_builder = PlanBuilder(context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER, forward_only=True) + plan_builder = PlanBuilder(context_diff, forward_only=True) plan_builder.build() assert snapshot_b.version == "test_version" - # Make sure that the choice can't be set manually. 
- with pytest.raises(PlanError, match="Choice setting is not supported by a forward-only plan."): - plan_builder.set_choice(snapshot_b, SnapshotChangeCategory.BREAKING).build() - def test_forward_only_dev(make_snapshot, mocker: MockerFixture): snapshot = make_snapshot( @@ -123,6 +129,7 @@ def test_forward_only_dev(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={snapshot.name: (updated_snapshot, snapshot)}, @@ -131,6 +138,9 @@ def test_forward_only_dev(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) yesterday_ds_mock = mocker.patch("sqlmesh.core.plan.builder.yesterday_ds") @@ -142,9 +152,7 @@ def test_forward_only_dev(make_snapshot, mocker: MockerFixture): mocker.patch("sqlmesh.core.plan.builder.now").return_value = expected_end mocker.patch("sqlmesh.core.plan.definition.now").return_value = expected_end - plan = PlanBuilder( - context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER, forward_only=True, is_dev=True - ).build() + plan = PlanBuilder(context_diff, forward_only=True, is_dev=True).build() assert plan.restatements == { updated_snapshot.snapshot_id: (to_timestamp(expected_start), expected_interval_end) @@ -182,6 +190,7 @@ def test_forward_only_metadata_change_dev(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={updated_snapshot.name: (updated_snapshot, snapshot)}, @@ -190,6 +199,9 @@ def test_forward_only_metadata_change_dev(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), 
previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) yesterday_ds_mock = mocker.patch("sqlmesh.core.plan.builder.yesterday_ds") @@ -201,9 +213,7 @@ def test_forward_only_metadata_change_dev(make_snapshot, mocker: MockerFixture): mocker.patch("sqlmesh.core.plan.builder.now").return_value = expected_end mocker.patch("sqlmesh.core.plan.definition.now").return_value = expected_end - plan = PlanBuilder( - context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER, forward_only=True, is_dev=True - ).build() + plan = PlanBuilder(context_diff, forward_only=True, is_dev=True).build() assert not plan.restatements @@ -224,6 +234,7 @@ def test_forward_only_plan_added_models(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added={snapshot_b.snapshot_id}, removed_snapshots={}, modified_snapshots={snapshot_a.name: (snapshot_a, snapshot_a)}, @@ -238,11 +249,64 @@ def test_forward_only_plan_added_models(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - PlanBuilder(context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER, forward_only=True).build() - assert snapshot_a.change_category == SnapshotChangeCategory.FORWARD_ONLY + PlanBuilder(context_diff, forward_only=True).build() + assert snapshot_a.change_category == SnapshotChangeCategory.METADATA assert snapshot_b.change_category == SnapshotChangeCategory.BREAKING + assert snapshot_a.is_forward_only + assert snapshot_b.is_forward_only + + +def test_forward_only_plan_categorizes_change_model_kind_as_breaking( + make_snapshot, mocker: MockerFixture +): + snapshot_old = make_snapshot( + SqlModel( + name="a", + dialect="duckdb", + 
query=parse_one("select 1, ds"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="ds"), + ) + ) + + # Simulate a change in the model kind. + updated_snapshot = make_snapshot( + SqlModel( + **{ + **snapshot_old.model.dict(), + "kind": dict(name=ModelKindName.VIEW), + } + ) + ) + + context_diff = ContextDiff( + environment="prod", + is_new_environment=False, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={updated_snapshot.name: (updated_snapshot, snapshot_old)}, + snapshots={updated_snapshot.snapshot_id: updated_snapshot}, + new_snapshots={updated_snapshot.snapshot_id: updated_snapshot}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + PlanBuilder(context_diff, forward_only=True).build() + + assert updated_snapshot.change_category == SnapshotChangeCategory.BREAKING + assert not updated_snapshot.is_forward_only def test_paused_forward_only_parent(make_snapshot, mocker: MockerFixture): @@ -255,12 +319,13 @@ def test_paused_forward_only_parent(make_snapshot, mocker: MockerFixture): ), version="test_version", change_category=SnapshotChangeCategory.BREAKING, + dev_table_suffix="dev", ), ) - snapshot_a.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) snapshot_b_old = make_snapshot(SqlModel(name="b", query=parse_one("select 2, ds from a"))) - snapshot_b_old.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b_old.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=False) snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("select 3, ds from a"))) assert not snapshot_b.version @@ -271,6 +336,7 @@ def 
test_paused_forward_only_parent(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={snapshot_b.name: (snapshot_b, snapshot_b_old)}, @@ -282,17 +348,18 @@ def test_paused_forward_only_parent(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - PlanBuilder(context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER, forward_only=False).build() + PlanBuilder(context_diff, forward_only=False).build() assert snapshot_b.change_category == SnapshotChangeCategory.BREAKING def test_forward_only_plan_allow_destructive_models( make_snapshot, make_snapshot_on_destructive_change ): - schema_differ = DuckDBEngineAdapter.SCHEMA_DIFFER - # forward-only model, not forward-only plan snapshot_a_old, snapshot_a = make_snapshot_on_destructive_change() @@ -302,6 +369,7 @@ def test_forward_only_plan_allow_destructive_models( is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={snapshot_a.name: (snapshot_a, snapshot_a_old)}, @@ -310,15 +378,20 @@ def test_forward_only_plan_allow_destructive_models( previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - with pytest.raises(PlanError, match="Plan results in a destructive change to forward-only"): - PlanBuilder(context_diff_a, schema_differ, forward_only=False).build() + with pytest.raises( + PlanError, match="Plan requires a destructive change to a forward-only model" + ): + PlanBuilder(context_diff_a, 
forward_only=False).build() logger = logging.getLogger("sqlmesh.core.plan.builder") with patch.object(logger, "warning") as mock_logger: assert PlanBuilder( - context_diff_a, schema_differ, forward_only=False, allow_destructive_models=['"a"'] + context_diff_a, forward_only=False, allow_destructive_models=['"a"'] ).build() assert mock_logger.call_count == 0 @@ -365,6 +438,7 @@ def test_forward_only_plan_allow_destructive_models( is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={ @@ -381,38 +455,105 @@ def test_forward_only_plan_allow_destructive_models( previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) with pytest.raises( PlanError, - match="""Plan results in a destructive change to forward-only model '"b"'s schema.""", + match="""Plan requires a destructive change to a forward-only model.""", ): - PlanBuilder(context_diff_b, schema_differ, forward_only=True).build() + PlanBuilder(context_diff_b, forward_only=True).build() with pytest.raises( PlanError, - match="""Plan results in a destructive change to forward-only model '"c"'s schema.""", + match="""Plan requires a destructive change to a forward-only model.""", ): - PlanBuilder( - context_diff_b, schema_differ, forward_only=True, allow_destructive_models=['"b"'] - ).build() + PlanBuilder(context_diff_b, forward_only=True, allow_destructive_models=['"b"']).build() logger = logging.getLogger("sqlmesh.core.plan.builder") with patch.object(logger, "warning") as mock_logger: PlanBuilder( context_diff_b, - schema_differ, forward_only=True, allow_destructive_models=['"b"', '"c"'], ).build() assert mock_logger.call_count == 0 +def test_forward_only_plan_allow_additive_models( + mocker, make_snapshot, make_snapshot_on_additive_change +): 
+ # forward-only model, not forward-only plan + snapshot_a_old, snapshot_a = make_snapshot_on_additive_change() + + context_diff_a = ContextDiff( + environment="prod", + is_new_environment=False, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={snapshot_a.name: (snapshot_a, snapshot_a_old)}, + snapshots={snapshot_a.snapshot_id: snapshot_a, snapshot_a_old.snapshot_id: snapshot_a_old}, + new_snapshots={snapshot_a.snapshot_id: snapshot_a}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + with pytest.raises(PlanError, match="Plan requires an additive change to a forward-only model"): + PlanBuilder(context_diff_a, forward_only=False).build() + + console = TerminalConsole() + log_warning_spy = mocker.spy(console, "log_warning") + assert PlanBuilder( + context_diff_a, forward_only=False, allow_additive_models=['"a"'], console=console + ).build() + assert log_warning_spy.call_count == 0 + + snapshot_a_old, snapshot_a = make_snapshot_on_additive_change( + on_additive_change=OnAdditiveChange.WARN + ) + + context_diff_a = ContextDiff( + environment="prod", + is_new_environment=False, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={snapshot_a.name: (snapshot_a, snapshot_a_old)}, + snapshots={snapshot_a.snapshot_id: snapshot_a, snapshot_a_old.snapshot_id: snapshot_a_old}, + new_snapshots={snapshot_a.snapshot_id: snapshot_a}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + 
environment_statements=[], + ) + + log_warning_spy.reset_mock() + assert PlanBuilder(context_diff_a, forward_only=False, console=console).build() + log_warning_spy.assert_called_once_with(""" +Plan requires additive change to forward-only model '"a"'s schema that adds column 'three'. + +Schema changes: + ALTER TABLE "a" ADD COLUMN three TEXT""") + + def test_forward_only_model_on_destructive_change( make_snapshot, make_snapshot_on_destructive_change ): - schema_differ = DuckDBEngineAdapter.SCHEMA_DIFFER - # direct change to A snapshot_a_old, snapshot_a = make_snapshot_on_destructive_change() @@ -422,6 +563,7 @@ def test_forward_only_model_on_destructive_change( is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={ @@ -436,13 +578,16 @@ def test_forward_only_model_on_destructive_change( previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) with pytest.raises( PlanError, - match="""Plan results in a destructive change to forward-only model '"a"'s schema.""", + match="""Plan requires a destructive change to a forward-only model.""", ): - PlanBuilder(context_diff_1, schema_differ).build() + PlanBuilder(context_diff_1).build() # allow A, indirect change to B snapshot_a_old2, snapshot_a2 = make_snapshot_on_destructive_change( @@ -466,6 +611,7 @@ def test_forward_only_model_on_destructive_change( metadata_hash="test_metadata_hash", ), version="test_version", + dev_table_suffix="dev", ), ) @@ -475,6 +621,7 @@ def test_forward_only_model_on_destructive_change( is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={ @@ -492,9 +639,12 @@ def 
test_forward_only_model_on_destructive_change( previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - PlanBuilder(context_diff_2, schema_differ).build() + PlanBuilder(context_diff_2).build() # allow A and B, indirect change to C snapshot_a_old3, snapshot_a3 = make_snapshot_on_destructive_change( @@ -522,6 +672,7 @@ def test_forward_only_model_on_destructive_change( metadata_hash="test_metadata_hash", ), version="test_version", + dev_table_suffix="dev", ), ) @@ -544,6 +695,7 @@ def test_forward_only_model_on_destructive_change( metadata_hash="test_metadata_hash", ), version="test_version", + dev_table_suffix="dev", ), ) @@ -553,6 +705,7 @@ def test_forward_only_model_on_destructive_change( is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={ @@ -573,9 +726,12 @@ def test_forward_only_model_on_destructive_change( previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - PlanBuilder(context_diff_3, schema_differ).build() + PlanBuilder(context_diff_3).build() def test_forward_only_model_on_destructive_change_no_column_types( @@ -591,6 +747,7 @@ def test_forward_only_model_on_destructive_change_no_column_types( is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={ @@ -606,11 +763,14 @@ def test_forward_only_model_on_destructive_change_no_column_types( previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + 
gateway_managed_virtual_layer=False, + environment_statements=[], ) logger = logging.getLogger("sqlmesh.core.plan.builder") with patch.object(logger, "warning") as mock_logger: - PlanBuilder(context_diff_1, DuckDBEngineAdapter.SCHEMA_DIFFER).build() + PlanBuilder(context_diff_1).build() assert mock_logger.call_count == 0 @@ -631,6 +791,7 @@ def test_missing_intervals_lookback(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), modified_snapshots={}, removed_snapshots={}, @@ -641,6 +802,9 @@ def test_missing_intervals_lookback(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) plan = Plan( @@ -651,76 +815,118 @@ def test_missing_intervals_lookback(make_snapshot, mocker: MockerFixture): execution_time="2022-01-05 12:00", is_dev=True, skip_backfill=False, + empty_backfill=False, no_gaps=False, forward_only=False, allow_destructive_models=set(), + allow_additive_models=set(), include_unmodified=True, environment_naming_info=EnvironmentNamingInfo(), directly_modified={snapshot_a.snapshot_id}, indirectly_modified={}, - ignored=set(), deployability_index=DeployabilityIndex.all_deployable(), restatements={}, + restate_all_snapshots=False, end_bounded=False, ensure_finalized_snapshots=False, + start_override_per_model=None, + end_override_per_model=None, + explain=False, ) assert not plan.missing_intervals @pytest.mark.slow -@freeze_time() +@time_machine.travel(now(), tick=False) def test_restate_models(sushi_context_pre_scheduling: Context): plan = sushi_context_pre_scheduling.plan( restate_models=["sushi.waiter_revenue_by_day", "tag:expensive"], no_prompts=True ) + + start = to_timestamp(plan.start) + tomorrow = to_timestamp(to_date("tomorrow")) + assert 
plan.restatements == { sushi_context_pre_scheduling.get_snapshot( "sushi.waiter_revenue_by_day", raise_if_missing=True ).snapshot_id: ( - to_timestamp(plan.start), - to_timestamp(to_date("today")), + start, + tomorrow, ), sushi_context_pre_scheduling.get_snapshot( "sushi.top_waiters", raise_if_missing=True ).snapshot_id: ( - to_timestamp(plan.start), - to_timestamp(to_date("today")), + start, + tomorrow, ), sushi_context_pre_scheduling.get_snapshot( "sushi.customer_revenue_by_day", raise_if_missing=True ).snapshot_id: ( - to_timestamp(plan.start), - to_timestamp(to_date("today")), + start, + tomorrow, ), sushi_context_pre_scheduling.get_snapshot( "sushi.customer_revenue_lifetime", raise_if_missing=True ).snapshot_id: ( - to_timestamp(plan.start), - to_timestamp(to_date("today")), + start, + tomorrow, ), } assert plan.requires_backfill + assert plan.models_to_backfill == { + '"memory"."sushi"."customer_revenue_by_day"', + '"memory"."sushi"."customer_revenue_lifetime"', + '"memory"."sushi"."items"', + '"memory"."sushi"."order_items"', + '"memory"."sushi"."orders"', + '"memory"."sushi"."top_waiters"', + '"memory"."sushi"."waiter_revenue_by_day"', + } - plan = sushi_context_pre_scheduling.plan(restate_models=["unknown_model"], no_prompts=True) - assert not plan.has_changes - assert not plan.restatements + with pytest.raises( + PlanError, + match="Selector did not return any models. Please check your model selection and try again.", + ): + sushi_context_pre_scheduling.plan(restate_models=["unknown_model"], no_prompts=True) + + with pytest.raises( + PlanError, + match="Selector did not return any models. 
Please check your model selection and try again.", + ): + sushi_context_pre_scheduling.plan(restate_models=["tag:unknown_tag"], no_prompts=True) - plan = sushi_context_pre_scheduling.plan(restate_models=["tag:unknown_tag"], no_prompts=True) + plan = sushi_context_pre_scheduling.plan(restate_models=["raw.demographics"], no_prompts=True) assert not plan.has_changes - assert not plan.restatements + assert plan.restatements + assert plan.models_to_backfill == { + '"memory"."raw"."demographics"', + '"memory"."sushi"."active_customers"', + '"memory"."sushi"."customers"', + '"memory"."sushi"."marketing"', + '"memory"."sushi"."orders"', + '"memory"."sushi"."raw_marketing"', + '"memory"."sushi"."waiter_as_customer_by_day"', + '"memory"."sushi"."waiter_names"', + '"memory"."sushi"."waiters"', + '"memory"."sushi"."count_customers_active"', + '"memory"."sushi"."count_customers_inactive"', + } @pytest.mark.slow -@freeze_time() -def test_restate_models_with_existing_missing_intervals(sushi_context: Context): +@time_machine.travel(now(minute_floor=False), tick=False) +def test_restate_models_with_existing_missing_intervals(init_and_plan_context: t.Callable): + sushi_context, plan = init_and_plan_context("examples/sushi") + sushi_context.apply(plan) + yesterday_ts = to_timestamp(yesterday_ds()) assert not sushi_context.plan(no_prompts=True).requires_backfill waiter_revenue_by_day = sushi_context.snapshots['"memory"."sushi"."waiter_revenue_by_day"'] - waiter_revenue_by_day.intervals = [ - (waiter_revenue_by_day.intervals[0][0], yesterday_ts), - ] + sushi_context.state_sync.remove_intervals( + [(waiter_revenue_by_day, (yesterday_ts, waiter_revenue_by_day.intervals[0][1]))] + ) assert sushi_context.plan(no_prompts=True).requires_backfill plan = sushi_context.plan(restate_models=["sushi.waiter_revenue_by_day"], no_prompts=True) @@ -750,7 +956,7 @@ def test_restate_models_with_existing_missing_intervals(sushi_context: Context): ), top_waiters_snapshot_id: ( plan_start_ts, - today_ts, + 
to_timestamp(to_date("tomorrow")), ), } assert plan.missing_intervals == [ @@ -764,6 +970,13 @@ def test_restate_models_with_existing_missing_intervals(sushi_context: Context): ), ] assert plan.requires_backfill + assert plan.models_to_backfill == { + top_waiters_snapshot_id.name, + waiter_revenue_by_day_snapshot_id.name, + '"memory"."sushi"."items"', + '"memory"."sushi"."order_items"', + '"memory"."sushi"."orders"', + } def test_restate_symbolic_model(make_snapshot, mocker: MockerFixture): @@ -781,6 +994,7 @@ def test_restate_symbolic_model(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={}, @@ -789,12 +1003,13 @@ def test_restate_symbolic_model(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - plan = PlanBuilder( - context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER, restate_models=[snapshot_a.name] - ).build() - assert not plan.restatements + plan = PlanBuilder(context_diff, restate_models=[snapshot_a.name]).build() + assert plan.restatements def test_restate_seed_model(make_snapshot, mocker: MockerFixture): @@ -814,6 +1029,7 @@ def test_restate_seed_model(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={}, @@ -822,11 +1038,12 @@ def test_restate_seed_model(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - plan = PlanBuilder( - 
context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER, restate_models=[snapshot_a.name] - ).build() + plan = PlanBuilder(context_diff, restate_models=[snapshot_a.name]).build() assert not plan.restatements @@ -837,6 +1054,7 @@ def test_restate_missing_model(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={}, @@ -845,41 +1063,16 @@ def test_restate_missing_model(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) with pytest.raises( PlanError, match=r"Cannot restate model 'missing'. Model does not exist.", ): - PlanBuilder( - context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER, restate_models=["missing"] - ).build() - - -def test_new_snapshots_with_restatements(make_snapshot, mocker: MockerFixture): - snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("select 1, ds"))) - - context_diff = ContextDiff( - environment="test_environment", - is_new_environment=True, - is_unfinalized_environment=False, - normalize_environment_name=True, - create_from="prod", - added=set(), - removed_snapshots={}, - modified_snapshots={}, - snapshots={snapshot_a.snapshot_id: snapshot_a}, - new_snapshots={snapshot_a.snapshot_id: snapshot_a}, - previous_plan_id=None, - previously_promoted_snapshot_ids=set(), - previous_finalized_snapshots=None, - ) - - with pytest.raises( - PlanError, - match=r"Model changes and restatements can't be a part of the same plan.*", - ): - PlanBuilder(context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER, restate_models=["a"]).build() + PlanBuilder(context_diff, restate_models=["missing"]).build() def test_end_validation(make_snapshot, mocker: MockerFixture): @@ -897,6 +1090,7 @@ def 
test_end_validation(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added={snapshot_a.snapshot_id}, removed_snapshots={}, modified_snapshots={}, @@ -905,10 +1099,12 @@ def test_end_validation(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - schema_differ = DuckDBEngineAdapter.SCHEMA_DIFFER - dev_plan_builder = PlanBuilder(context_diff, schema_differ, end="2022-01-03", is_dev=True) + dev_plan_builder = PlanBuilder(context_diff, end="2022-01-03", is_dev=True) assert dev_plan_builder.build().end == "2022-01-03" dev_plan_builder.set_end("2022-01-04") assert dev_plan_builder.build().end == "2022-01-04" @@ -918,12 +1114,12 @@ def test_end_validation(make_snapshot, mocker: MockerFixture): ) with pytest.raises(PlanError, match=start_end_not_allowed_message): - PlanBuilder(context_diff, schema_differ, end="2022-01-03").build() + PlanBuilder(context_diff, end="2022-01-03").build() with pytest.raises(PlanError, match=start_end_not_allowed_message): - PlanBuilder(context_diff, schema_differ, start="2022-01-03").build() + PlanBuilder(context_diff, start="2022-01-03").build() - prod_plan_builder = PlanBuilder(context_diff, schema_differ) + prod_plan_builder = PlanBuilder(context_diff) with pytest.raises(PlanError, match=start_end_not_allowed_message): prod_plan_builder.set_end("2022-01-03").build() @@ -934,7 +1130,6 @@ def test_end_validation(make_snapshot, mocker: MockerFixture): context_diff.new_snapshots = {} restatement_prod_plan_builder = PlanBuilder( context_diff, - schema_differ, start="2022-01-01", end="2022-01-03", restate_models=['"a"'], @@ -944,16 +1139,29 @@ def test_end_validation(make_snapshot, mocker: MockerFixture): assert 
restatement_prod_plan_builder.build().end == "2022-01-04" -def test_forward_only_revert_not_allowed(make_snapshot, mocker: MockerFixture): - snapshot = make_snapshot(SqlModel(name="a", query=parse_one("select 1, ds"))) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - assert not snapshot.is_forward_only +def test_forward_only_plan_seed_models(make_snapshot, mocker: MockerFixture): + snapshot_a = make_snapshot( + SeedModel( + name="a", + kind=SeedKind(path="./path/to/seed"), + seed=Seed(content="content"), + column_hashes={"col": "hash1"}, + depends_on=set(), + ) + ) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) - forward_only_snapshot = make_snapshot(SqlModel(name="a", query=parse_one("select 2, ds"))) - forward_only_snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) - forward_only_snapshot.version = snapshot.version - forward_only_snapshot.unpaused_ts = now_timestamp() - assert forward_only_snapshot.is_forward_only + snapshot_a_updated = make_snapshot( + SeedModel( + name="a", + kind=SeedKind(path="./path/to/seed"), + seed=Seed(content="new_content"), + column_hashes={"col": "hash2"}, + depends_on=set(), + ) + ) + assert snapshot_a_updated.version is None + assert snapshot_a_updated.change_category is None context_diff = ContextDiff( environment="test_environment", @@ -961,34 +1169,29 @@ def test_forward_only_revert_not_allowed(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, - modified_snapshots={snapshot.name: (snapshot, forward_only_snapshot)}, - snapshots={snapshot.snapshot_id: snapshot}, - new_snapshots={}, + modified_snapshots={snapshot_a_updated.name: (snapshot_a_updated, snapshot_a)}, + snapshots={snapshot_a_updated.snapshot_id: snapshot_a_updated}, + new_snapshots={snapshot_a_updated.snapshot_id: snapshot_a}, previous_plan_id=None, previously_promoted_snapshot_ids=set(), 
previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - schema_differ = DuckDBEngineAdapter.SCHEMA_DIFFER - with pytest.raises( - PlanError, - match=r"Attempted to revert to an unrevertable version of model.*", - ): - PlanBuilder(context_diff, schema_differ, forward_only=True).build() - - # Make sure the plan can be created if a new snapshot version was enforced. - new_version_snapshot = make_snapshot( - SqlModel(name="a", query=parse_one("select 1, ds"), stamp="test_stamp") - ) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - context_diff.modified_snapshots = {snapshot.name: (new_version_snapshot, forward_only_snapshot)} - context_diff.new_snapshots = {new_version_snapshot.snapshot_id: new_version_snapshot} - PlanBuilder(context_diff, schema_differ, forward_only=True).build() + PlanBuilder(context_diff, forward_only=True).build() + assert snapshot_a_updated.version == snapshot_a_updated.fingerprint.to_version() + assert snapshot_a_updated.change_category == SnapshotChangeCategory.BREAKING + assert not snapshot_a_updated.is_forward_only -def test_forward_only_plan_seed_models(make_snapshot, mocker: MockerFixture): +def test_seed_model_metadata_change_no_missing_intervals( + make_snapshot: t.Callable[..., Snapshot], +): snapshot_a = make_snapshot( SeedModel( name="a", @@ -999,38 +1202,51 @@ def test_forward_only_plan_seed_models(make_snapshot, mocker: MockerFixture): ) ) snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_a.add_interval("2022-01-01", now()) - snapshot_a_updated = make_snapshot( + snapshot_a_metadata_updated = make_snapshot( SeedModel( name="a", kind=SeedKind(path="./path/to/seed"), - seed=Seed(content="new_content"), - column_hashes={"col": "hash2"}, + seed=Seed(content="content"), + column_hashes={"col": "hash1"}, depends_on=set(), + description="foo", ) ) - assert snapshot_a_updated.version is None - assert 
snapshot_a_updated.change_category is None + snapshot_a_metadata_updated.previous_versions = snapshot_a.all_versions + assert snapshot_a_metadata_updated.version is None + assert snapshot_a_metadata_updated.change_category is None context_diff = ContextDiff( - environment="test_environment", + environment="prod", is_new_environment=True, is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, - modified_snapshots={snapshot_a_updated.name: (snapshot_a_updated, snapshot_a)}, - snapshots={snapshot_a_updated.snapshot_id: snapshot_a_updated}, - new_snapshots={snapshot_a_updated.snapshot_id: snapshot_a}, + modified_snapshots={ + snapshot_a_metadata_updated.name: (snapshot_a_metadata_updated, snapshot_a) + }, + snapshots={snapshot_a_metadata_updated.snapshot_id: snapshot_a_metadata_updated}, + new_snapshots={snapshot_a_metadata_updated.snapshot_id: snapshot_a}, previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - PlanBuilder(context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER, forward_only=True).build() - assert snapshot_a_updated.version == snapshot_a_updated.fingerprint.to_version() - assert snapshot_a_updated.change_category == SnapshotChangeCategory.NON_BREAKING + plan = PlanBuilder(context_diff).build() + assert snapshot_a_metadata_updated.change_category == SnapshotChangeCategory.METADATA + assert not snapshot_a_metadata_updated.is_forward_only + assert not plan.missing_intervals # plan should have no missing intervals + assert ( + snapshot_a_metadata_updated.intervals == snapshot_a.intervals + ) # intervals should have been copied def test_start_inference(make_snapshot, mocker: MockerFixture): @@ -1048,6 +1264,7 @@ def test_start_inference(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, 
normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added={snapshot_b.snapshot_id}, removed_snapshots={}, modified_snapshots={}, @@ -1059,19 +1276,21 @@ def test_start_inference(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) snapshot_b.add_interval("2022-01-01", now()) - schema_differ = DuckDBEngineAdapter.SCHEMA_DIFFER - plan = PlanBuilder(context_diff, schema_differ).build() + plan = PlanBuilder(context_diff).build() assert len(plan.missing_intervals) == 1 assert plan.missing_intervals[0].snapshot_id == snapshot_a.snapshot_id assert plan.start == to_timestamp("2022-01-01") # Test inference from existing intervals context_diff.snapshots = {snapshot_b.snapshot_id: snapshot_b} - plan = PlanBuilder(context_diff, schema_differ).build() + plan = PlanBuilder(context_diff).build() assert not plan.missing_intervals assert plan.start == to_datetime("2022-01-01") @@ -1088,6 +1307,7 @@ def test_auto_categorization(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={updated_snapshot.name: (updated_snapshot, snapshot)}, @@ -1096,9 +1316,12 @@ def test_auto_categorization(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - PlanBuilder(context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER).build() + PlanBuilder(context_diff).build() assert updated_snapshot.version == updated_snapshot.fingerprint.to_version() assert updated_snapshot.change_category == SnapshotChangeCategory.BREAKING @@ 
-1127,6 +1350,7 @@ def test_auto_categorization_missing_schema_downstream(make_snapshot, mocker: Mo is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={ @@ -1141,9 +1365,12 @@ def test_auto_categorization_missing_schema_downstream(make_snapshot, mocker: Mo previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - PlanBuilder(context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER).build() + PlanBuilder(context_diff).build() assert updated_snapshot.version assert updated_snapshot.change_category == SnapshotChangeCategory.BREAKING @@ -1162,6 +1389,7 @@ def test_broken_references(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={snapshot_a.snapshot_id: snapshot_a.table_info}, modified_snapshots={snapshot_b.name: (snapshot_b, snapshot_b)}, @@ -1170,6 +1398,9 @@ def test_broken_references(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) # Make sure the downstream snapshot doesn't have any parents, @@ -1180,7 +1411,7 @@ def test_broken_references(make_snapshot, mocker: MockerFixture): PlanError, match=r"""Removed '"a"' are referenced in '"b"'.*""", ): - PlanBuilder(context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER).build() + PlanBuilder(context_diff).build() def test_broken_references_external_model(make_snapshot, mocker: MockerFixture): @@ -1196,6 +1427,7 @@ def test_broken_references_external_model(make_snapshot, mocker: MockerFixture): 
is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={snapshot_a.snapshot_id: snapshot_a.table_info}, modified_snapshots={snapshot_b.name: (snapshot_b, snapshot_b)}, @@ -1204,6 +1436,9 @@ def test_broken_references_external_model(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) # Make sure the downstream snapshot doesn't have any parents, @@ -1211,7 +1446,7 @@ def test_broken_references_external_model(make_snapshot, mocker: MockerFixture): assert not snapshot_b.parents # Shouldn't raise - PlanBuilder(context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER).build() + PlanBuilder(context_diff).build() def test_effective_from(make_snapshot, mocker: MockerFixture): @@ -1236,6 +1471,7 @@ def test_effective_from(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={updated_snapshot.name: (updated_snapshot, snapshot)}, @@ -1244,24 +1480,27 @@ def test_effective_from(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - schema_differ = DuckDBEngineAdapter.SCHEMA_DIFFER with pytest.raises( PlanError, match="Effective date can only be set for a forward-only plan.", ): - PlanBuilder(context_diff, schema_differ).set_effective_from("2023-02-01").build() + PlanBuilder(context_diff).set_effective_from("2023-02-01").build() # The snapshot gets categorized as breaking in previous step so we want to reset that back to None 
updated_snapshot.change_category = None plan_builder = PlanBuilder( context_diff, - schema_differ, forward_only=True, start="2023-01-01", end="2023-03-01", + execution_time="2023-03-02 00:01:00", is_dev=True, + end_bounded=True, ) updated_snapshot.add_interval("2023-01-01", "2023-03-01") @@ -1273,9 +1512,9 @@ def test_effective_from(make_snapshot, mocker: MockerFixture): assert plan_builder.set_effective_from(None).build().effective_from is None assert updated_snapshot.effective_from is None - assert not plan_builder.build().missing_intervals plan_builder.set_effective_from("2023-02-01") + plan_builder.set_start("2023-02-01") assert plan_builder.build().effective_from == "2023-02-01" assert updated_snapshot.effective_from == "2023-02-01" @@ -1289,6 +1528,62 @@ def test_effective_from(make_snapshot, mocker: MockerFixture): assert updated_snapshot.effective_from is None +def test_effective_from_non_evaluatble_model(make_snapshot, mocker: MockerFixture): + snapshot = make_snapshot( + SqlModel( + name="a", + kind="EMBEDDED", + query=parse_one("select 1, ds FROM b"), + start="2023-01-01", + dialect="duckdb", + ) + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + updated_snapshot = make_snapshot( + SqlModel( + name="a", + kind="EMBEDDED", + query=parse_one("select 2, ds FROM b"), + start="2023-01-01", + dialect="duckdb", + ) + ) + updated_snapshot.previous_versions = snapshot.all_versions + + context_diff = ContextDiff( + environment="test_environment", + is_new_environment=True, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={updated_snapshot.name: (updated_snapshot, snapshot)}, + snapshots={updated_snapshot.snapshot_id: updated_snapshot}, + new_snapshots={updated_snapshot.snapshot_id: updated_snapshot}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + 
previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + plan_builder = PlanBuilder( + context_diff, + forward_only=True, + start="2023-01-01", + end="2023-03-01", + is_dev=True, + ) + + plan_builder.set_effective_from("2023-02-01") + assert plan_builder.build().effective_from == "2023-02-01" + assert not updated_snapshot.effective_from + + def test_new_environment_no_changes(make_snapshot, mocker: MockerFixture): snapshot = make_snapshot(SqlModel(name="a", query=parse_one("select 1, ds"))) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) @@ -1299,6 +1594,7 @@ def test_new_environment_no_changes(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={}, @@ -1307,17 +1603,19 @@ def test_new_environment_no_changes(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - schema_differ = DuckDBEngineAdapter.SCHEMA_DIFFER - with pytest.raises(PlanError, match="No changes were detected.*"): - PlanBuilder(context_diff, schema_differ, is_dev=True).build() + with pytest.raises( + PlanError, match="Creating a new environment requires a change, but project files match.*" + ): + PlanBuilder(context_diff, is_dev=True).build() + assert PlanBuilder(context_diff).build().environment.promoted_snapshot_ids is None assert ( - PlanBuilder(context_diff, schema_differ).build().environment.promoted_snapshot_ids is None - ) - assert ( - PlanBuilder(context_diff, schema_differ, is_dev=True, include_unmodified=True) + PlanBuilder(context_diff, is_dev=True, include_unmodified=True) .build() .environment.promoted_snapshot_ids is None @@ -1338,6 +1636,7 @@ def 
test_new_environment_with_changes(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={updated_snapshot_a.name: (updated_snapshot_a, snapshot_a)}, @@ -1349,13 +1648,16 @@ def test_new_environment_with_changes(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) # Modified the existing model. - schema_differ = DuckDBEngineAdapter.SCHEMA_DIFFER - assert PlanBuilder( - context_diff, schema_differ, is_dev=True - ).build().environment.promoted_snapshot_ids == [updated_snapshot_a.snapshot_id] + + assert PlanBuilder(context_diff, is_dev=True).build().environment.promoted_snapshot_ids == [ + updated_snapshot_a.snapshot_id + ] # Updating the existing environment with a previously promoted snapshot. 
context_diff.previously_promoted_snapshot_ids = { @@ -1364,10 +1666,7 @@ def test_new_environment_with_changes(make_snapshot, mocker: MockerFixture): } context_diff.is_new_environment = False assert set( - PlanBuilder(context_diff, schema_differ, is_dev=True) - .build() - .environment.promoted_snapshot_ids - or [] + PlanBuilder(context_diff, is_dev=True).build().environment.promoted_snapshot_ids or [] ) == { updated_snapshot_a.snapshot_id, snapshot_b.snapshot_id, @@ -1386,10 +1685,7 @@ def test_new_environment_with_changes(make_snapshot, mocker: MockerFixture): context_diff.new_snapshots = {snapshot_c.snapshot_id: snapshot_c} assert set( - PlanBuilder(context_diff, schema_differ, is_dev=True) - .build() - .environment.promoted_snapshot_ids - or [] + PlanBuilder(context_diff, is_dev=True).build().environment.promoted_snapshot_ids or [] ) == { updated_snapshot_a.snapshot_id, snapshot_b.snapshot_id, @@ -1423,6 +1719,7 @@ def test_forward_only_models(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={updated_snapshot.name: (updated_snapshot, snapshot)}, @@ -1431,21 +1728,26 @@ def test_forward_only_models(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - schema_differ = DuckDBEngineAdapter.SCHEMA_DIFFER - PlanBuilder(context_diff, schema_differ, is_dev=True).build() - assert updated_snapshot.change_category == SnapshotChangeCategory.FORWARD_ONLY + PlanBuilder(context_diff, is_dev=True).build() + assert updated_snapshot.change_category == SnapshotChangeCategory.BREAKING + assert updated_snapshot.is_forward_only updated_snapshot.change_category = None updated_snapshot.version = None - PlanBuilder(context_diff, 
schema_differ, is_dev=True, forward_only=True).build() - assert updated_snapshot.change_category == SnapshotChangeCategory.FORWARD_ONLY + PlanBuilder(context_diff, is_dev=True, forward_only=True).build() + assert updated_snapshot.change_category == SnapshotChangeCategory.BREAKING + assert updated_snapshot.is_forward_only updated_snapshot.change_category = None updated_snapshot.version = None - PlanBuilder(context_diff, schema_differ, forward_only=True).build() - assert updated_snapshot.change_category == SnapshotChangeCategory.FORWARD_ONLY + PlanBuilder(context_diff, forward_only=True).build() + assert updated_snapshot.change_category == SnapshotChangeCategory.BREAKING + assert updated_snapshot.is_forward_only def test_forward_only_models_model_kind_changed(make_snapshot, mocker: MockerFixture): @@ -1466,6 +1768,61 @@ def test_forward_only_models_model_kind_changed(make_snapshot, mocker: MockerFix is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={updated_snapshot.name: (updated_snapshot, snapshot)}, + snapshots={updated_snapshot.snapshot_id: updated_snapshot}, + new_snapshots={updated_snapshot.snapshot_id: updated_snapshot}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + PlanBuilder(context_diff, is_dev=True).build() + assert updated_snapshot.change_category == SnapshotChangeCategory.BREAKING + + +@pytest.mark.parametrize( + "partitioned_by, expected_forward_only", + [ + ([], False), + ([d.parse_one("ds")], True), + ], +) +def test_forward_only_models_model_kind_changed_to_incremental_by_time_range( + make_snapshot, + partitioned_by: t.List[exp.Expression], + expected_forward_only: bool, +): + snapshot = make_snapshot( + SqlModel( + name="a", + 
query=parse_one("select 1, ds"), + kind=IncrementalUnmanagedKind(), + partitioned_by=partitioned_by, + ) + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + updated_snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 3, ds"), + kind=IncrementalByTimeRangeKind(time_column="ds", forward_only=True), + ) + ) + updated_snapshot.previous_versions = snapshot.all_versions + + context_diff = ContextDiff( + environment="test_environment", + is_new_environment=True, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={updated_snapshot.name: (updated_snapshot, snapshot)}, @@ -1474,10 +1831,14 @@ def test_forward_only_models_model_kind_changed(make_snapshot, mocker: MockerFix previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - PlanBuilder(context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER, is_dev=True).build() + PlanBuilder(context_diff, is_dev=True).build() assert updated_snapshot.change_category == SnapshotChangeCategory.BREAKING + assert updated_snapshot.is_forward_only == expected_forward_only def test_indirectly_modified_forward_only_model(make_snapshot, mocker: MockerFixture): @@ -1495,7 +1856,7 @@ def test_indirectly_modified_forward_only_model(make_snapshot, mocker: MockerFix ), nodes={'"a"': snapshot_a.model}, ) - snapshot_b.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) updated_snapshot_b = make_snapshot(snapshot_b.model, nodes={'"a"': updated_snapshot_a.model}) updated_snapshot_b.previous_versions = snapshot_b.all_versions @@ -1508,47 +1869,74 @@ def test_indirectly_modified_forward_only_model(make_snapshot, mocker: MockerFix ) updated_snapshot_c.previous_versions = 
snapshot_c.all_versions + snapshot_d = make_snapshot( + SqlModel(name="d", query=parse_one("select a.a from a, b")), + nodes={ + '"a"': snapshot_a.model, + '"b"': snapshot_b.model, + }, + ) + snapshot_d.categorize_as(SnapshotChangeCategory.BREAKING) + updated_snapshot_d = make_snapshot( + snapshot_d.model, nodes={'"b"': updated_snapshot_b.model, '"a"': updated_snapshot_a.model} + ) + updated_snapshot_d.previous_versions = snapshot_d.all_versions + context_diff = ContextDiff( environment="test_environment", is_new_environment=True, is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={ updated_snapshot_a.name: (updated_snapshot_a, snapshot_a), updated_snapshot_b.name: (updated_snapshot_b, snapshot_b), updated_snapshot_c.name: (updated_snapshot_c, snapshot_c), + updated_snapshot_d.name: (updated_snapshot_d, snapshot_d), }, snapshots={ updated_snapshot_a.snapshot_id: updated_snapshot_a, updated_snapshot_b.snapshot_id: updated_snapshot_b, updated_snapshot_c.snapshot_id: updated_snapshot_c, + updated_snapshot_d.snapshot_id: updated_snapshot_d, }, new_snapshots={ updated_snapshot_a.snapshot_id: updated_snapshot_a, updated_snapshot_b.snapshot_id: updated_snapshot_b, updated_snapshot_c.snapshot_id: updated_snapshot_c, + updated_snapshot_d.snapshot_id: updated_snapshot_d, }, previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - plan = PlanBuilder(context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER, is_dev=True).build() + plan = PlanBuilder(context_diff, is_dev=True).build() assert plan.indirectly_modified == { updated_snapshot_a.snapshot_id: { updated_snapshot_b.snapshot_id, updated_snapshot_c.snapshot_id, + updated_snapshot_d.snapshot_id, } } assert plan.directly_modified == 
{updated_snapshot_a.snapshot_id} assert updated_snapshot_a.change_category == SnapshotChangeCategory.BREAKING - assert updated_snapshot_b.change_category == SnapshotChangeCategory.FORWARD_ONLY + assert updated_snapshot_b.change_category == SnapshotChangeCategory.INDIRECT_BREAKING assert updated_snapshot_c.change_category == SnapshotChangeCategory.INDIRECT_BREAKING + assert updated_snapshot_d.change_category == SnapshotChangeCategory.INDIRECT_BREAKING + + assert not updated_snapshot_a.is_forward_only + assert updated_snapshot_b.is_forward_only + assert not updated_snapshot_c.is_forward_only + assert not updated_snapshot_d.is_forward_only deployability_index = DeployabilityIndex.create( { @@ -1564,7 +1952,7 @@ def test_indirectly_modified_forward_only_model(make_snapshot, mocker: MockerFix def test_added_model_with_forward_only_parent(make_snapshot, mocker: MockerFixture): snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("select 1 as a, ds"))) - snapshot_a.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("select a, ds from a"))) @@ -1574,6 +1962,7 @@ def test_added_model_with_forward_only_parent(make_snapshot, mocker: MockerFixtu is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added={snapshot_b.snapshot_id}, removed_snapshots={}, modified_snapshots={}, @@ -1585,10 +1974,14 @@ def test_added_model_with_forward_only_parent(make_snapshot, mocker: MockerFixtu previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - PlanBuilder(context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER, is_dev=True).build() + PlanBuilder(context_diff, is_dev=True).build() assert snapshot_b.change_category == 
SnapshotChangeCategory.BREAKING + assert not snapshot_b.is_forward_only def test_added_forward_only_model(make_snapshot, mocker: MockerFixture): @@ -1608,6 +2001,7 @@ def test_added_forward_only_model(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added={snapshot_a.snapshot_id, snapshot_b.snapshot_id}, removed_snapshots={}, modified_snapshots={}, @@ -1622,9 +2016,12 @@ def test_added_forward_only_model(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - PlanBuilder(context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER).build() + PlanBuilder(context_diff).build() assert snapshot_a.change_category == SnapshotChangeCategory.BREAKING assert snapshot_b.change_category == SnapshotChangeCategory.BREAKING @@ -1645,6 +2042,7 @@ def test_disable_restatement(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={}, @@ -1653,25 +2051,31 @@ def test_disable_restatement(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - schema_differ = DuckDBEngineAdapter.SCHEMA_DIFFER - plan = PlanBuilder(context_diff, schema_differ, restate_models=['"a"']).build() + plan = PlanBuilder(context_diff, restate_models=['"a"']).build() assert not plan.restatements # Effective from doesn't apply to snapshots for which restatements are disabled. 
- plan = PlanBuilder( - context_diff, schema_differ, forward_only=True, effective_from="2023-01-01" - ).build() + plan = PlanBuilder(context_diff, forward_only=True, effective_from="2023-01-01").build() assert plan.effective_from == "2023-01-01" assert snapshot.effective_from is None # Restatements should still be supported when in dev. - plan = PlanBuilder(context_diff, schema_differ, is_dev=True, restate_models=['"a"']).build() + plan = PlanBuilder(context_diff, is_dev=True, restate_models=['"a"']).build() assert plan.restatements == { - snapshot.snapshot_id: (to_timestamp(plan.start), to_timestamp(to_date("today"))) + snapshot.snapshot_id: (to_timestamp(plan.start), to_timestamp(to_date("tomorrow"))) } + # We don't want to restate a disable_restatement model if it is unpaused since that would be mean we are violating + # the model kind property + snapshot.unpaused_ts = 9999999999 + plan = PlanBuilder(context_diff, is_dev=True, restate_models=['"a"']).build() + assert plan.restatements == {} + def test_revert_to_previous_value(make_snapshot, mocker: MockerFixture): """ @@ -1690,7 +2094,7 @@ def test_revert_to_previous_value(make_snapshot, mocker: MockerFixture): snapshot_b = make_snapshot( SqlModel(name="b", query=parse_one("select 1, ds FROM a"), depends_on={"a"}) ) - snapshot_b.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) snapshot_b.add_interval("2022-01-01", now()) context_diff = ContextDiff( @@ -1699,6 +2103,7 @@ def test_revert_to_previous_value(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={ @@ -1713,13 +2118,17 @@ def test_revert_to_previous_value(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + 
previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - plan_builder = PlanBuilder(context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER) + plan_builder = PlanBuilder(context_diff) plan_builder.set_choice(snapshot_a, SnapshotChangeCategory.BREAKING) plan_builder.build() # Make sure it does not get assigned INDIRECT_BREAKING - assert snapshot_b.change_category == SnapshotChangeCategory.FORWARD_ONLY + assert snapshot_b.change_category == SnapshotChangeCategory.BREAKING + assert snapshot_b.is_forward_only test_add_restatement_fixtures = [ @@ -1916,6 +2325,7 @@ def test_add_restatements( is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={}, @@ -1924,11 +2334,13 @@ def test_add_restatements( previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) plan = PlanBuilder( context_diff, - DuckDBEngineAdapter.SCHEMA_DIFFER, start=to_date(start), end=to_date(end), execution_time=to_date(execution_time), @@ -1984,6 +2396,7 @@ def test_dev_plan_depends_past(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added={snapshot.snapshot_id, snapshot_child.snapshot_id, unrelated_snapshot.snapshot_id}, removed_snapshots={}, modified_snapshots={}, @@ -2000,11 +2413,13 @@ def test_dev_plan_depends_past(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - schema_differ = DuckDBEngineAdapter.SCHEMA_DIFFER dev_plan_start_aligned = PlanBuilder( - 
context_diff, schema_differ, start="2023-01-01", end="2023-01-10", is_dev=True + context_diff, start="2023-01-01", end="2023-01-10", is_dev=True ).build() assert len(dev_plan_start_aligned.new_snapshots) == 3 assert sorted([x.name for x in dev_plan_start_aligned.new_snapshots]) == [ @@ -2020,16 +2435,17 @@ def test_dev_plan_depends_past(make_snapshot, mocker: MockerFixture): assert dev_plan_start_aligned.indirectly_modified == {} dev_plan_start_ahead_of_model = PlanBuilder( - context_diff, schema_differ, start="2023-01-02", end="2023-01-10", is_dev=True + context_diff, start="2023-01-02", end="2023-01-10", is_dev=True ).build() - assert len(dev_plan_start_ahead_of_model.new_snapshots) == 1 - assert [x.name for x in dev_plan_start_ahead_of_model.new_snapshots] == ['"b"'] - assert len(dev_plan_start_ahead_of_model.ignored) == 2 - assert sorted(list(dev_plan_start_ahead_of_model.ignored)) == [ + assert len(dev_plan_start_ahead_of_model.new_snapshots) == 3 + assert not dev_plan_start_ahead_of_model.deployability_index.is_deployable(snapshot) + assert not dev_plan_start_ahead_of_model.deployability_index.is_deployable(snapshot_child) + assert dev_plan_start_ahead_of_model.deployability_index.is_deployable(unrelated_snapshot) + assert dev_plan_start_ahead_of_model.directly_modified == { snapshot.snapshot_id, snapshot_child.snapshot_id, - ] - assert dev_plan_start_ahead_of_model.directly_modified == {unrelated_snapshot.snapshot_id} + unrelated_snapshot.snapshot_id, + } assert dev_plan_start_ahead_of_model.indirectly_modified == {} @@ -2053,18 +2469,16 @@ def test_dev_plan_depends_past_non_deployable(make_snapshot, mocker: MockerFixtu } ), ) - updated_snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) snapshot_child = make_snapshot( SqlModel( name="a_child", query=parse_one("select 1, ds FROM a"), start="2023-01-01", - kind=IncrementalByTimeRangeKind(time_column="ds"), + kind=IncrementalByTimeRangeKind(time_column="ds", forward_only=True), ), nodes={'"a"': 
updated_snapshot.model}, ) - snapshot_child.categorize_as(SnapshotChangeCategory.BREAKING) unrelated_snapshot = make_snapshot( SqlModel( name="b", @@ -2073,7 +2487,6 @@ def test_dev_plan_depends_past_non_deployable(make_snapshot, mocker: MockerFixtu kind=IncrementalByTimeRangeKind(time_column="ds"), ), ) - unrelated_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) assert updated_snapshot.depends_on_self assert not snapshot_child.depends_on_self @@ -2088,6 +2501,7 @@ def test_dev_plan_depends_past_non_deployable(make_snapshot, mocker: MockerFixtu is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added={snapshot_child.snapshot_id, unrelated_snapshot.snapshot_id}, removed_snapshots={}, modified_snapshots={snapshot.name: (updated_snapshot, snapshot)}, @@ -2104,30 +2518,24 @@ def test_dev_plan_depends_past_non_deployable(make_snapshot, mocker: MockerFixtu previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - schema_differ = DuckDBEngineAdapter.SCHEMA_DIFFER - dev_plan_start_aligned = PlanBuilder( - context_diff, schema_differ, start="2023-01-01", end="2023-01-10", is_dev=True - ).build() - assert len(dev_plan_start_aligned.new_snapshots) == 3 - assert sorted([x.name for x in dev_plan_start_aligned.new_snapshots]) == [ - '"a"', - '"a_child"', - '"b"', - ] + def new_builder(start, end): + builder = PlanBuilder(context_diff, start=start, end=end, is_dev=True) + builder.set_choice(updated_snapshot, SnapshotChangeCategory.BREAKING) + builder.set_choice(snapshot_child, SnapshotChangeCategory.BREAKING) + builder.set_choice(unrelated_snapshot, SnapshotChangeCategory.BREAKING) + return builder - # There should be no ignored snapshots because all changes are non-deployable. 
- dev_plan_start_ahead_of_model = PlanBuilder( - context_diff, schema_differ, start="2023-01-02", end="2023-01-10", is_dev=True - ).build() - assert len(dev_plan_start_ahead_of_model.new_snapshots) == 3 + dev_plan_start_aligned = new_builder("2023-01-01", "2023-01-10").build() assert sorted([x.name for x in dev_plan_start_aligned.new_snapshots]) == [ '"a"', '"a_child"', '"b"', ] - assert not dev_plan_start_ahead_of_model.ignored def test_restatement_intervals_after_updating_start(sushi_context: Context): @@ -2165,6 +2573,7 @@ def test_models_selected_for_backfill(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added={snapshot_b.snapshot_id}, removed_snapshots={}, modified_snapshots={}, @@ -2176,16 +2585,12 @@ def test_models_selected_for_backfill(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - schema_differ = DuckDBEngineAdapter.SCHEMA_DIFFER - with pytest.raises( - PlanError, - match="Selecting models to backfill is only supported for development environments", - ): - PlanBuilder(context_diff, schema_differ, backfill_models={'"a"'}).build() - - plan = PlanBuilder(context_diff, schema_differ).build() + plan = PlanBuilder(context_diff).build() assert plan.is_selected_for_backfill('"a"') assert plan.is_selected_for_backfill('"b"') assert plan.models_to_backfill is None @@ -2194,14 +2599,14 @@ def test_models_selected_for_backfill(make_snapshot, mocker: MockerFixture): snapshot_b.snapshot_id, } - plan = PlanBuilder(context_diff, schema_differ, is_dev=True, backfill_models={'"a"'}).build() + plan = PlanBuilder(context_diff, is_dev=True, backfill_models={'"a"'}).build() assert plan.is_selected_for_backfill('"a"') assert not 
plan.is_selected_for_backfill('"b"') assert plan.models_to_backfill == {'"a"'} assert {i.snapshot_id for i in plan.missing_intervals} == {snapshot_a.snapshot_id} - assert not plan.environment.promoted_snapshot_ids + assert plan.environment.promoted_snapshot_ids == [snapshot_a.snapshot_id] - plan = PlanBuilder(context_diff, schema_differ, is_dev=True, backfill_models={'"b"'}).build() + plan = PlanBuilder(context_diff, is_dev=True, backfill_models={'"b"'}).build() assert plan.is_selected_for_backfill('"a"') assert plan.is_selected_for_backfill('"b"') assert plan.models_to_backfill == {'"a"', '"b"'} @@ -2224,6 +2629,7 @@ def test_categorized_uncategorized(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={new_snapshot.name: (new_snapshot, snapshot)}, @@ -2232,11 +2638,12 @@ def test_categorized_uncategorized(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - plan_builder = PlanBuilder( - context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER, auto_categorization_enabled=False - ) + plan_builder = PlanBuilder(context_diff, auto_categorization_enabled=False) plan = plan_builder.build() assert plan.uncategorized == [new_snapshot] @@ -2271,6 +2678,7 @@ def test_environment_previous_finalized_snapshots(make_snapshot, mocker: MockerF is_unfinalized_environment=True, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added={snapshot_b.snapshot_id}, removed_snapshots={snapshot_c.snapshot_id: snapshot_c.table_info}, modified_snapshots={snapshot_a.name: (updated_snapshot_a, snapshot_a)}, @@ -2286,10 +2694,12 @@ def test_environment_previous_finalized_snapshots(make_snapshot, mocker: MockerF 
previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=[snapshot_c.table_info, snapshot_d.table_info], + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - schema_differ = DuckDBEngineAdapter.SCHEMA_DIFFER - plan = PlanBuilder(context_diff, schema_differ).build() + plan = PlanBuilder(context_diff).build() assert set(plan.environment.previous_finalized_snapshots or []) == { snapshot_c.table_info, snapshot_d.table_info, @@ -2297,7 +2707,7 @@ def test_environment_previous_finalized_snapshots(make_snapshot, mocker: MockerF context_diff.is_unfinalized_environment = False - plan = PlanBuilder(context_diff, schema_differ).build() + plan = PlanBuilder(context_diff).build() assert set(plan.environment.previous_finalized_snapshots or []) == { snapshot_a.table_info, snapshot_c.table_info, @@ -2331,6 +2741,7 @@ def test_metadata_change(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added=set(), removed_snapshots={}, modified_snapshots={snapshot.name: (updated_snapshot, snapshot)}, @@ -2339,9 +2750,12 @@ def test_metadata_change(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) - plan = PlanBuilder(context_diff, DuckDBEngineAdapter.SCHEMA_DIFFER, is_dev=True).build() + plan = PlanBuilder(context_diff, is_dev=True).build() assert ( plan.snapshots[updated_snapshot.snapshot_id].change_category @@ -2368,6 +2782,7 @@ def test_plan_start_when_preview_enabled(make_snapshot, mocker: MockerFixture): is_unfinalized_environment=False, normalize_environment_name=True, create_from="prod", + create_from_env_exists=True, added={snapshot.snapshot_id}, removed_snapshots={}, 
modified_snapshots={}, @@ -2378,19 +2793,20 @@ def test_plan_start_when_preview_enabled(make_snapshot, mocker: MockerFixture): previous_plan_id=None, previously_promoted_snapshot_ids=set(), previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], ) default_start_for_preview = "2024-06-09" - # When a model is added SQLMesh should not consider the backfill to be a preview. plan_builder = PlanBuilder( context_diff, - DuckDBEngineAdapter.SCHEMA_DIFFER, default_start=default_start_for_preview, is_dev=True, enable_preview=True, ) - assert plan_builder.build().start == to_timestamp(model_start) + assert plan_builder.build().start == default_start_for_preview # When a model is modified then the backfill should be a preview. snapshot = make_snapshot(model) @@ -2399,9 +2815,1490 @@ def test_plan_start_when_preview_enabled(make_snapshot, mocker: MockerFixture): plan_builder = PlanBuilder( context_diff, - DuckDBEngineAdapter.SCHEMA_DIFFER, default_start=default_start_for_preview, is_dev=True, enable_preview=True, ) assert plan_builder.build().start == default_start_for_preview + + +def test_end_override_per_model(make_snapshot): + snapshot = make_snapshot(SqlModel(name="a", query=parse_one("select 1, ds"))) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + new_snapshot = make_snapshot(SqlModel(name="a", query=parse_one("select 2, ds"))) + + context_diff = ContextDiff( + environment="test_environment", + is_new_environment=True, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={new_snapshot.name: (new_snapshot, snapshot)}, + snapshots={new_snapshot.snapshot_id: new_snapshot}, + new_snapshots={new_snapshot.snapshot_id: new_snapshot}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + 
previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + plan_builder = PlanBuilder( + context_diff, + end_override_per_model={snapshot.name: to_datetime("2023-01-09")}, + ) + assert plan_builder.build().end_override_per_model == {snapshot.name: to_datetime("2023-01-09")} + + # User-provided end should take precedence. + plan_builder = PlanBuilder( + context_diff, + end_override_per_model={snapshot.name: to_datetime("2023-01-09")}, + end="2023-01-10", + is_dev=True, + ) + assert plan_builder.build().end_override_per_model is None + + +def test_unaligned_start_model_with_forward_only_preview(make_snapshot): + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, ds"), + kind=dict( + name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, forward_only=True, time_column="ds" + ), + ) + ) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + + new_snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 2, ds"), + kind=dict( + name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, forward_only=True, time_column="ds" + ), + ) + ) + new_snapshot_a.previous_versions = snapshot_a.all_versions + new_snapshot_a.unpaused_ts = 1 + + snapshot_b = make_snapshot( + SqlModel( + name="b", + query=parse_one("select 1 AS key"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_UNIQUE_KEY, unique_key="key"), + start="2024-01-01", + depends_on={"a"}, + ), + nodes={new_snapshot_a.name: new_snapshot_a.model}, + ) + + context_diff = ContextDiff( + environment="test_environment", + is_new_environment=True, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added={snapshot_b.snapshot_id}, + removed_snapshots={}, + snapshots={new_snapshot_a.snapshot_id: new_snapshot_a, snapshot_b.snapshot_id: snapshot_b}, + new_snapshots={ + new_snapshot_a.snapshot_id: new_snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + }, + 
modified_snapshots={snapshot_a.name: (new_snapshot_a, snapshot_a)}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + plan_builder = PlanBuilder( + context_diff, + enable_preview=True, + is_dev=True, + ) + plan = plan_builder.build() + + assert set(plan.restatements) == {new_snapshot_a.snapshot_id, snapshot_b.snapshot_id} + assert not plan.deployability_index.is_deployable(new_snapshot_a) + assert not plan.deployability_index.is_deployable(snapshot_b) + + +def test_restate_production_model_in_dev(make_snapshot, mocker: MockerFixture): + snapshot = make_snapshot( + SqlModel( + name="test_model_a", + dialect="duckdb", + query=parse_one("select 1, ds"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="ds"), + ) + ) + + prod_snapshot = make_snapshot( + SqlModel( + name="test_model_b", + dialect="duckdb", + query=parse_one("select 2, ds"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="ds"), + ) + ) + prod_snapshot.unpaused_ts = 1 + + context_diff = ContextDiff( + environment="test_environment", + is_new_environment=False, + is_unfinalized_environment=True, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={}, + snapshots={snapshot.snapshot_id: snapshot, prod_snapshot.snapshot_id: prod_snapshot}, + new_snapshots={}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + mock_console = mocker.Mock() + + plan = PlanBuilder( + context_diff, + is_dev=True, + restate_models={snapshot.name, prod_snapshot.name}, + console=mock_console, + ).build() + + assert len(plan.restatements) == 1 + 
assert prod_snapshot.snapshot_id not in plan.restatements + + mock_console.log_warning.assert_called_once_with( + "Cannot restate model '\"test_model_b\"' because the current version is used in production. " + "Run the restatement against the production environment instead to restate this model." + ) + + +@time_machine.travel("2025-02-23 15:00:00 UTC") +def test_restate_daily_to_monthly(make_snapshot, mocker: MockerFixture): + snapshot_a = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1 as one"), + cron="@daily", + start="2025-01-01", + ), + ) + + snapshot_b = make_snapshot( + SqlModel( + name="b", + query=parse_one("select one from a"), + cron="@monthly", + start="2025-01-01", + ), + nodes={'"a"': snapshot_a.model}, + ) + + snapshot_c = make_snapshot( + SqlModel( + name="c", + query=parse_one("select one from b"), + cron="@daily", + start="2025-01-01", + ), + nodes={ + '"a"': snapshot_a.model, + '"b"': snapshot_b.model, + }, + ) + + snapshot_d = make_snapshot( + SqlModel( + name="d", + query=parse_one("select one from b union all select one from a"), + cron="@daily", + start="2025-01-01", + ), + nodes={ + '"a"': snapshot_a.model, + '"b"': snapshot_b.model, + }, + ) + snapshot_e = make_snapshot( + SqlModel( + name="e", + query=parse_one("select one from b"), + cron="@daily", + start="2025-01-01", + ), + nodes={ + '"a"': snapshot_a.model, + '"b"': snapshot_b.model, + }, + ) + + context_diff = ContextDiff( + environment="prod", + is_new_environment=False, + is_unfinalized_environment=True, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={}, + snapshots={ + snapshot_a.snapshot_id: snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + snapshot_c.snapshot_id: snapshot_c, + snapshot_d.snapshot_id: snapshot_d, + snapshot_e.snapshot_id: snapshot_e, + }, + new_snapshots={}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + 
previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + plan = PlanBuilder( + context_diff, + restate_models=[snapshot_a.name, snapshot_e.name], + start="2025-02-15", + end="2025-02-20", + ).build() + + assert plan.restatements == { + snapshot_a.snapshot_id: (1739577600000, 1740355200000), + snapshot_b.snapshot_id: (1738368000000, 1740787200000), + snapshot_c.snapshot_id: (1739577600000, 1740355200000), + snapshot_d.snapshot_id: (1739577600000, 1740355200000), + snapshot_e.snapshot_id: (1739577600000, 1740355200000), + } + + +def test_plan_environment_statements_diff(make_snapshot): + snapshot = make_snapshot( + SqlModel( + name="test_model_a", + dialect="duckdb", + query=parse_one("select 1, ds"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="ds"), + ) + ) + + context_diff = ContextDiff( + environment="test_environment", + is_new_environment=False, + is_unfinalized_environment=True, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={}, + snapshots={snapshot.snapshot_id: snapshot}, + new_snapshots={}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + environment_statements=[ + EnvironmentStatements( + before_all=["CREATE OR REPLACE TABLE table_1 AS SELECT 1", "@test_macro()"], + after_all=["CREATE OR REPLACE TABLE table_2 AS SELECT 2"], + python_env={ + "test_macro": Executable( + payload="def test_macro(evaluator):\n return 'one'" + ), + }, + ) + ], + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + ) + + assert context_diff.has_changes + assert context_diff.has_environment_statements_changes + + console_output, terminal_console = create_test_console() + for _, diff in context_diff.environment_statements_diff(): + terminal_console._print(diff) + 
output = console_output.getvalue() + stripped = strip_ansi_codes(output) + + expected_output = ( + "before_all:\n" + " + CREATE OR REPLACE TABLE table_1 AS SELECT 1\n" + " + @test_macro()\n\n" + "after_all:\n" + " + CREATE OR REPLACE TABLE table_2 AS SELECT 2" + ) + assert stripped == expected_output + console_output.close() + + # Validate with python env included + console_output, terminal_console = create_test_console() + for _, diff in context_diff.environment_statements_diff(include_python_env=True): + terminal_console._print(diff) + output = console_output.getvalue() + stripped = strip_ansi_codes(output) + expected_output = ( + "before_all:\n" + " + CREATE OR REPLACE TABLE table_1 AS SELECT 1\n" + " + @test_macro()\n\n" + "after_all:\n" + " + CREATE OR REPLACE TABLE table_2 AS SELECT 2\n\n" + "dependencies:\n" + "@@ -0,0 +1,2 @@\n\n" + "+def test_macro(evaluator):\n" + "+ return 'one'" + ) + assert stripped == expected_output + console_output.close() + + +def test_set_choice_for_forward_only_model(make_snapshot): + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, ds"), + dialect="duckdb", + kind=IncrementalByTimeRangeKind(time_column="ds", forward_only=True), + ) + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + updated_snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 3, ds"), + kind=IncrementalByTimeRangeKind(time_column="ds", forward_only=True), + dialect="duckdb", + ) + ) + updated_snapshot.previous_versions = snapshot.all_versions + + context_diff = ContextDiff( + environment="test_environment", + is_new_environment=True, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={updated_snapshot.name: (updated_snapshot, snapshot)}, + snapshots={updated_snapshot.snapshot_id: updated_snapshot}, + new_snapshots={updated_snapshot.snapshot_id: updated_snapshot}, + 
previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + plan_builder = PlanBuilder(context_diff, is_dev=True) + plan_builder.set_choice(updated_snapshot, SnapshotChangeCategory.BREAKING) + + plan = plan_builder.build() + assert ( + plan.snapshots[updated_snapshot.snapshot_id].change_category + == SnapshotChangeCategory.BREAKING + ) + assert plan.snapshots[updated_snapshot.snapshot_id].is_forward_only + + +def test_user_provided_flags(sushi_context: Context): + expected_flags = { + "run": True, + "execution_time": "2025-01-01", + } + plan_a = sushi_context.plan(no_prompts=True, run=True, execution_time="2025-01-01") + assert plan_a.user_provided_flags == expected_flags + evaluatable_plan = plan_a.to_evaluatable() + assert evaluatable_plan.user_provided_flags == expected_flags + + plan_b = sushi_context.plan() + assert plan_b.user_provided_flags == {} + evaluatable_plan_b = plan_b.to_evaluatable() + assert evaluatable_plan_b.user_provided_flags == {} + + context_diff = ContextDiff( + environment="test_environment", + is_new_environment=True, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={}, + snapshots={}, + new_snapshots={}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + plan_builder = PlanBuilder( + context_diff, + forward_only=True, + user_provided_flags={"forward_only": True}, + ).build() + assert plan_builder.user_provided_flags == {"forward_only": True} + plan_builder = PlanBuilder( + context_diff, + ).build() + assert plan_builder.user_provided_flags == None + + 
+@time_machine.travel(now()) +@pytest.mark.parametrize( + "input,output", + [ + # execution_time, start, end + ( + # no execution time, start or end + (None, None, None), + # execution time defaults to now() + # start defaults to 1 day before execution time + # end defaults to execution_time + (now(), yesterday_ds(), now()), + ), + ( + # fixed execution time, no start, no end + ("2020-01-05", None, None), + # execution time set to 2020-01-05 + # start defaults to 1 day before execution time + # end defaults to execution time + ("2020-01-05", "2020-01-04", "2020-01-05"), + ), + ( + # fixed execution time, relative start, no end + ("2020-01-05", "2 days ago", None), + # execution time set to 2020-01-05 + # start relative to execution time + # end defaults to execution time + ("2020-01-05", "2020-01-03", "2020-01-05"), + ), + ( + # fixed execution time, relative start, relative end + ("2020-01-05", "2 days ago", "1 day ago"), + # execution time set to 2020-01-05 + # start relative to execution time + # end relative to execution time + ("2020-01-05", "2020-01-03", "2020-01-04"), + ), + ( + # fixed execution time, fixed start, fixed end + ("2020-01-05", "2020-01-01", "2020-01-05"), + # fixed dates are all in the valid range + ("2020-01-05", "2020-01-01", "2020-01-05"), + ), + ( + # fixed execution time, fixed start, fixed end + ("2020-01-05", "2020-01-05", "2020-01-01"), + # Error because start is after end + r"Plan end date.*must be after the plan start date", + ), + ( + # fixed execution time, relative start, fixed end beyond fixed execution time + ("2020-01-05", "2 days ago", "2021-01-01"), + # Error because end is set to 2021-01-01 which is after the execution time + r"Plan end date.*cannot be in the future", + ), + ], +) +def test_plan_dates_relative_to_execution_time( + input: t.Tuple[t.Optional[str], ...], + output: t.Union[str, t.Tuple[t.Optional[str], ...]], + make_snapshot: t.Callable, +): + snapshot_a = make_snapshot( + SqlModel(name="a", 
query=parse_one("select 1, ds"), dialect="duckdb") + ) + + context_diff = ContextDiff( + environment="test_environment", + is_new_environment=True, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added={snapshot_a.snapshot_id}, + removed_snapshots={}, + modified_snapshots={}, + snapshots={}, + new_snapshots={snapshot_a.snapshot_id: snapshot_a}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + input_execution_time, input_start, input_end = input + + def _build_plan() -> Plan: + return PlanBuilder( + context_diff, + start=input_start, + end=input_end, + execution_time=input_execution_time, + is_dev=True, + ).build() + + if isinstance(output, str): + with pytest.raises(PlanError, match=output): + _build_plan() + else: + output_execution_time, output_start, output_end = output + + plan = _build_plan() + assert to_datetime(plan.start) == to_datetime(output_start) + assert to_datetime(plan.end) == to_datetime(output_end) + assert to_datetime(plan.execution_time) == to_datetime(output_execution_time) + + +def test_plan_builder_additive_change_error_blocks_plan(make_snapshot): + """Test that additive changes block plan when on_additive_change=ERROR.""" + # Create models with actual schema differences + # Use explicit column schemas in CTE so columns_to_types can be determined + old_model = SqlModel( + name="test_model", + dialect="duckdb", + query=parse_one(""" + with source as ( + select 1::INT as id, 'test'::VARCHAR as name, '2022-01-01'::DATE as ds + ) + select id, name, ds from source + """), + kind=IncrementalByTimeRangeKind( + time_column="ds", + forward_only=True, + on_additive_change=OnAdditiveChange.ERROR, + ), + ) + + # New model with additional column (additive change) + new_model = SqlModel( + 
name="test_model", + dialect="duckdb", + query=parse_one(""" + with source as ( + select 1::INT as id, 'test'::VARCHAR as name, 'email@test.com'::VARCHAR as email, '2022-01-01'::DATE as ds + ) + select id, name, email, ds from source + """), + kind=IncrementalByTimeRangeKind( + time_column="ds", + forward_only=True, + on_additive_change=OnAdditiveChange.ERROR, + ), + ) + + old_snapshot = make_snapshot(old_model) + new_snapshot = make_snapshot(new_model) + + # Set previous versions to simulate a modification + new_snapshot.previous_versions = ( + SnapshotDataVersion( + fingerprint=SnapshotFingerprint( + data_hash="old_data_hash", + metadata_hash="old_metadata_hash", + ), + version="old_version", + change_category=SnapshotChangeCategory.FORWARD_ONLY, + dev_table_suffix="dev", + ), + ) + + context_diff = ContextDiff( + environment="prod", + is_new_environment=False, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={old_snapshot.name: (new_snapshot, old_snapshot)}, + snapshots={ + old_snapshot.snapshot_id: old_snapshot, + new_snapshot.snapshot_id: new_snapshot, + }, + new_snapshots={new_snapshot.snapshot_id: new_snapshot}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + builder = PlanBuilder(context_diff, forward_only=True) + + # Should raise PlanError for additive changes when on_additive_change=ERROR + with pytest.raises(PlanError, match="additive change"): + builder.build() + + +def test_plan_builder_additive_change_warn_allows_plan(make_snapshot): + """Test that additive changes allow plan with warning when on_additive_change=WARN.""" + old_model = SqlModel( + name="test_model", + dialect="duckdb", + query=parse_one(""" + with source as ( + select 1::INT 
as id, 'test'::VARCHAR as name, '2022-01-01'::DATE as ds + ) + select id, name, ds from source + """), + kind=IncrementalByTimeRangeKind( + time_column="ds", + forward_only=True, + on_additive_change=OnAdditiveChange.WARN, + ), + ) + + new_model = SqlModel( + name="test_model", + dialect="duckdb", + query=parse_one(""" + with source as ( + select 1::INT as id, 'test'::VARCHAR as name, 'email@test.com'::VARCHAR as email, '2022-01-01'::DATE as ds + ) + select id, name, email, ds from source + """), + kind=IncrementalByTimeRangeKind( + time_column="ds", + forward_only=True, + on_additive_change=OnAdditiveChange.WARN, + ), + ) + + old_snapshot = make_snapshot(old_model) + new_snapshot = make_snapshot(new_model) + + # Set previous versions to simulate a modification + new_snapshot.previous_versions = ( + SnapshotDataVersion( + fingerprint=SnapshotFingerprint( + data_hash="old_data_hash", + metadata_hash="old_metadata_hash", + ), + version="old_version", + change_category=SnapshotChangeCategory.FORWARD_ONLY, + dev_table_suffix="dev", + ), + ) + + context_diff = ContextDiff( + environment="prod", + is_new_environment=False, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={old_snapshot.name: (new_snapshot, old_snapshot)}, + snapshots={ + old_snapshot.snapshot_id: old_snapshot, + new_snapshot.snapshot_id: new_snapshot, + }, + new_snapshots={new_snapshot.snapshot_id: new_snapshot}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + builder = PlanBuilder(context_diff, forward_only=True) + + # Should log warning but not fail + with patch.object(builder._console, "log_additive_change") as mock_log_additive: + plan = builder.build() + assert plan is not None + 
mock_log_additive.assert_called() # Should have logged an additive change + + +def test_plan_builder_additive_change_allow_permits_plan(make_snapshot): + """Test that additive changes are permitted when on_additive_change=ALLOW.""" + old_model = SqlModel( + name="test_model", + dialect="duckdb", + query=parse_one(""" + with source as ( + select 1::INT as id, 'test'::VARCHAR as name, '2022-01-01'::DATE as ds + ) + select id, name, ds from source + """), + kind=IncrementalByTimeRangeKind( + time_column="ds", + forward_only=True, + on_additive_change=OnAdditiveChange.ALLOW, + ), + ) + + new_model = SqlModel( + name="test_model", + dialect="duckdb", + query=parse_one(""" + with source as ( + select 1::INT as id, 'test'::VARCHAR as name, 'email@test.com'::VARCHAR as email, '2022-01-01'::DATE as ds + ) + select id, name, email, ds from source + """), + kind=IncrementalByTimeRangeKind( + time_column="ds", + forward_only=True, + on_additive_change=OnAdditiveChange.ALLOW, + ), + ) + + old_snapshot = make_snapshot(old_model) + new_snapshot = make_snapshot(new_model) + + # Set previous versions to simulate a modification + new_snapshot.previous_versions = ( + SnapshotDataVersion( + fingerprint=SnapshotFingerprint( + data_hash="old_data_hash", + metadata_hash="old_metadata_hash", + ), + version="old_version", + change_category=SnapshotChangeCategory.FORWARD_ONLY, + dev_table_suffix="dev", + ), + ) + + context_diff = ContextDiff( + environment="prod", + is_new_environment=False, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={old_snapshot.name: (new_snapshot, old_snapshot)}, + snapshots={ + old_snapshot.snapshot_id: old_snapshot, + new_snapshot.snapshot_id: new_snapshot, + }, + new_snapshots={new_snapshot.snapshot_id: new_snapshot}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + 
previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + builder = PlanBuilder(context_diff, forward_only=True) + + # Should build plan without issues + plan = builder.build() + assert plan is not None + + +def test_plan_builder_additive_change_ignore_skips_validation(make_snapshot): + """Test that additive changes are ignored when on_additive_change=IGNORE.""" + old_model = SqlModel( + name="test_model", + dialect="duckdb", + query=parse_one(""" + with source as ( + select 1::INT as id, 'test'::VARCHAR as name, '2022-01-01'::DATE as ds + ) + select id, name, ds from source + """), + kind=IncrementalByTimeRangeKind( + time_column="ds", + forward_only=True, + on_additive_change=OnAdditiveChange.IGNORE, + ), + ) + + new_model = SqlModel( + name="test_model", + dialect="duckdb", + query=parse_one(""" + with source as ( + select 1::INT as id, 'test'::VARCHAR as name, 'email@test.com'::VARCHAR as email, '2022-01-01'::DATE as ds + ) + select id, name, email, ds from source + """), + kind=IncrementalByTimeRangeKind( + time_column="ds", + forward_only=True, + on_additive_change=OnAdditiveChange.IGNORE, + ), + ) + + old_snapshot = make_snapshot(old_model) + new_snapshot = make_snapshot(new_model) + + # Set previous versions to simulate a modification + new_snapshot.previous_versions = ( + SnapshotDataVersion( + fingerprint=SnapshotFingerprint( + data_hash="old_data_hash", + metadata_hash="old_metadata_hash", + ), + version="old_version", + change_category=SnapshotChangeCategory.FORWARD_ONLY, + dev_table_suffix="dev", + ), + ) + + context_diff = ContextDiff( + environment="prod", + is_new_environment=False, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={old_snapshot.name: (new_snapshot, old_snapshot)}, + snapshots={ + old_snapshot.snapshot_id: old_snapshot, + 
new_snapshot.snapshot_id: new_snapshot, + }, + new_snapshots={new_snapshot.snapshot_id: new_snapshot}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + builder = PlanBuilder(context_diff, forward_only=True) + + # Should build plan without any validation + with patch("sqlmesh.core.plan.builder.logger.warning") as mock_warning: + plan = builder.build() + assert plan is not None + mock_warning.assert_not_called() # Should not log any warnings + + +def test_plan_builder_mixed_destructive_and_additive_changes(make_snapshot): + """Test scenarios with both destructive and additive changes.""" + # Test case: on_destructive_change=IGNORE, on_additive_change=ERROR + # Should ignore destructive changes but error on additive changes + old_model = SqlModel( + name="test_model", + dialect="duckdb", + query=parse_one(""" + with source as ( + select 1::INT as id, 'test'::VARCHAR as name, 'old_value'::VARCHAR as old_col, '2022-01-01'::DATE as ds + ) + select id, name, old_col, ds from source + """), + kind=IncrementalByTimeRangeKind( + time_column="ds", + forward_only=True, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.ERROR, + ), + ) + + new_model = SqlModel( + name="test_model", + dialect="duckdb", + query=parse_one(""" + with source as ( + select 1::INT as id, 'test'::VARCHAR as name, 'new_value'::VARCHAR as new_col, '2022-01-01'::DATE as ds + ) + select id, name, new_col, ds from source + """), + kind=IncrementalByTimeRangeKind( + time_column="ds", + forward_only=True, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.ERROR, + ), + ) + + old_snapshot = make_snapshot(old_model) + new_snapshot = make_snapshot(new_model) + + # Set previous versions to simulate a modification + new_snapshot.previous_versions = ( + 
SnapshotDataVersion( + fingerprint=SnapshotFingerprint( + data_hash="old_data_hash", + metadata_hash="old_metadata_hash", + ), + version="old_version", + change_category=SnapshotChangeCategory.FORWARD_ONLY, + dev_table_suffix="dev", + ), + ) + + context_diff = ContextDiff( + environment="prod", + is_new_environment=False, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={old_snapshot.name: (new_snapshot, old_snapshot)}, + snapshots={ + old_snapshot.snapshot_id: old_snapshot, + new_snapshot.snapshot_id: new_snapshot, + }, + new_snapshots={new_snapshot.snapshot_id: new_snapshot}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + builder = PlanBuilder(context_diff, forward_only=True) + + # Should error on additive change (new_col), but ignore destructive change (old_col removal) + with pytest.raises(PlanError, match="additive change"): + builder.build() + + +def test_plan_builder_allow_additive_models_flag(make_snapshot): + """Test that --allow-additive-model flag overrides on_additive_change=ERROR.""" + old_model = SqlModel( + name="test_model", + dialect="duckdb", + query=parse_one(""" + with source as ( + select 1::INT as id, 'test'::VARCHAR as name, '2022-01-01'::DATE as ds + ) + select id, name, ds from source + """), + kind=IncrementalByTimeRangeKind( + time_column="ds", + forward_only=True, + on_additive_change=OnAdditiveChange.ERROR, + ), + ) + + # New model with additional column (additive change) + new_model = SqlModel( + name="test_model", + dialect="duckdb", + query=parse_one(""" + with source as ( + select 1::INT as id, 'test'::VARCHAR as name, 'email@test.com'::VARCHAR as email, '2022-01-01'::DATE as ds + ) + select id, name, email, ds from source 
+ """), + kind=IncrementalByTimeRangeKind( + time_column="ds", + forward_only=True, + on_additive_change=OnAdditiveChange.ERROR, + ), + ) + + old_snapshot = make_snapshot(old_model) + new_snapshot = make_snapshot(new_model) + + # Set previous versions to simulate a modification + new_snapshot.previous_versions = ( + SnapshotDataVersion( + fingerprint=SnapshotFingerprint( + data_hash="old_data_hash", + metadata_hash="old_metadata_hash", + ), + version="old_version", + change_category=SnapshotChangeCategory.FORWARD_ONLY, + dev_table_suffix="dev", + ), + ) + + context_diff = ContextDiff( + environment="prod", + is_new_environment=False, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={new_snapshot.name: (new_snapshot, old_snapshot)}, + snapshots={new_snapshot.snapshot_id: new_snapshot}, + new_snapshots={new_snapshot.snapshot_id: new_snapshot}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + # First, verify that without the flag, the plan fails with additive change error + builder = PlanBuilder(context_diff, forward_only=True) + with pytest.raises(PlanError, match="additive change"): + builder.build() + + # Now test that the --allow-additive-model flag allows the plan to succeed + builder_with_flag = PlanBuilder( + context_diff, + forward_only=True, + allow_additive_models={'"test_model"'}, + ) + + # Should succeed without raising an exception + plan = builder_with_flag.build() + assert plan is not None + + +def test_plan_builder_allow_additive_models_pattern_matching(make_snapshot): + """Test that --allow-additive-model flag supports pattern matching like destructive models.""" + # Create two models with additive changes + old_model_1 = SqlModel( + 
name="test.model_1", + dialect="duckdb", + query=parse_one(""" + with source as ( + select 1::INT as id, 'test'::VARCHAR as name, '2022-01-01'::DATE as ds + ) + select id, name, ds from source + """), + kind=IncrementalByTimeRangeKind( + time_column="ds", + forward_only=True, + on_additive_change=OnAdditiveChange.ERROR, + ), + ) + + new_model_1 = SqlModel( + name="test.model_1", + dialect="duckdb", + query=parse_one(""" + with source as ( + select 1::INT as id, 'test'::VARCHAR as name, 'email@test.com'::VARCHAR as email, '2022-01-01'::DATE as ds + ) + select id, name, email, ds from source + """), + kind=IncrementalByTimeRangeKind( + time_column="ds", + forward_only=True, + on_additive_change=OnAdditiveChange.ERROR, + ), + ) + + old_model_2 = SqlModel( + name="other.model_2", + dialect="duckdb", + query=parse_one(""" + with source as ( + select 1::INT as id, 'test'::VARCHAR as name, '2022-01-01'::DATE as ds + ) + select id, name, ds from source + """), + kind=IncrementalByTimeRangeKind( + time_column="ds", + forward_only=True, + on_additive_change=OnAdditiveChange.ERROR, + ), + ) + + new_model_2 = SqlModel( + name="other.model_2", + dialect="duckdb", + query=parse_one(""" + with source as ( + select 1::INT as id, 'test'::VARCHAR as name, 'phone'::VARCHAR as phone, '2022-01-01'::DATE as ds + ) + select id, name, phone, ds from source + """), + kind=IncrementalByTimeRangeKind( + time_column="ds", + forward_only=True, + on_additive_change=OnAdditiveChange.ERROR, + ), + ) + + old_snapshot_1 = make_snapshot(old_model_1) + new_snapshot_1 = make_snapshot(new_model_1) + old_snapshot_2 = make_snapshot(old_model_2) + new_snapshot_2 = make_snapshot(new_model_2) + + # Set previous versions to simulate modifications + for new_snapshot in [new_snapshot_1, new_snapshot_2]: + new_snapshot.previous_versions = ( + SnapshotDataVersion( + fingerprint=SnapshotFingerprint( + data_hash="old_data_hash", + metadata_hash="old_metadata_hash", + ), + version="old_version", + 
change_category=SnapshotChangeCategory.FORWARD_ONLY, + dev_table_suffix="dev", + ), + ) + + context_diff = ContextDiff( + environment="prod", + is_new_environment=False, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={ + new_snapshot_1.name: (new_snapshot_1, old_snapshot_1), + new_snapshot_2.name: (new_snapshot_2, old_snapshot_2), + }, + snapshots={ + new_snapshot_1.snapshot_id: new_snapshot_1, + new_snapshot_2.snapshot_id: new_snapshot_2, + }, + new_snapshots={ + new_snapshot_1.snapshot_id: new_snapshot_1, + new_snapshot_2.snapshot_id: new_snapshot_2, + }, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + # Test pattern matching: allow only models in "test" schema + # In real usage, patterns would be expanded by Context.expand_model_selections + # Here we simulate what the expansion would produce + builder_with_pattern = PlanBuilder( + context_diff, + forward_only=True, + allow_additive_models={'"test"."model_1"'}, # Only allow test.model_1, not other.model_2 + ) + + # Should still fail because other.model_2 is not allowed + with pytest.raises(PlanError, match="additive change"): + builder_with_pattern.build() + + # Test allowing both patterns + builder_with_both = PlanBuilder( + context_diff, + forward_only=True, + allow_additive_models={'"test"."model_1"', '"other"."model_2"'}, # Allow both models + ) + + # Should succeed + plan = builder_with_both.build() + assert plan is not None + + +def test_environment_statements_change_allows_dev_environment_creation(make_snapshot): + snapshot = make_snapshot( + SqlModel( + name="test_model", + dialect="duckdb", + query=parse_one("select 1, ds"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, 
time_column="ds"), + ) + ) + + # First context diff of a new 'dev' environment without environment statements + context_diff_no_statements = ContextDiff( + environment="dev", + is_new_environment=True, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={}, + snapshots={snapshot.snapshot_id: snapshot}, + new_snapshots={}, + previous_plan_id=None, + previously_promoted_snapshot_ids={snapshot.snapshot_id}, + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + previous_environment_statements=[], + ) + + # Should fail because no changes + plan_builder = PlanBuilder( + context_diff_no_statements, + is_dev=True, + ) + + with pytest.raises(NoChangesPlanError, match="Creating a new environment requires a change"): + plan_builder.build() + + # Now create context diff with environment statements + environment_statements = [ + EnvironmentStatements( + before_all=["CREATE TABLE IF NOT EXISTS test_table (id INT)"], + after_all=[], + python_env={}, + jinja_macros=None, + ) + ] + + context_diff_with_statements = ContextDiff( + environment="dev", + is_new_environment=True, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={}, + snapshots={snapshot.snapshot_id: snapshot}, + new_snapshots={}, + previous_plan_id=None, + previously_promoted_snapshot_ids={snapshot.snapshot_id}, + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=environment_statements, + previous_environment_statements=[], + ) + + # Should succeed because there are environment statements changes + plan_builder_with_statements = PlanBuilder( + 
context_diff_with_statements, + is_dev=True, + ) + + # Test that allows creating a dev environment without other changes + plan = plan_builder_with_statements.build() + assert plan is not None + assert plan.context_diff.has_environment_statements_changes + assert plan.context_diff.environment_statements == environment_statements + + +def test_plan_ignore_cron_flag(make_snapshot): + snapshot_a = make_snapshot( + SqlModel( + name="test_model", + kind=IncrementalByTimeRangeKind(time_column="ds"), + cron="@daily", # Daily cron schedule + start="2023-01-01", + query=parse_one("SELECT 1 as id, ds FROM VALUES ('2023-01-01') t(ds)"), + allow_partials=True, + ) + ) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=False) + + context_diff = ContextDiff( + environment="dev", + is_new_environment=True, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={}, + snapshots={snapshot_a.snapshot_id: snapshot_a}, + new_snapshots={snapshot_a.snapshot_id: snapshot_a}, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + plan_builder_ignore_cron = PlanBuilder( + context_diff, + start="2023-01-01", + execution_time="2023-01-05 12:00:00", + is_dev=True, + include_unmodified=True, + ignore_cron=True, + end_bounded=False, + ) + + plan = plan_builder_ignore_cron.build() + assert plan.ignore_cron is True + assert plan.to_evaluatable().ignore_cron is True + + assert plan.missing_intervals == [ + SnapshotIntervals( + snapshot_id=snapshot_a.snapshot_id, + intervals=[ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + 
(to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-05 12:00:00")), + ], + ) + ] + + +def test_indirect_change_to_materialized_view_is_breaking(make_snapshot): + snapshot_a_old = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1 as col_a"), + kind=ViewKind(materialized=True), + ) + ) + snapshot_a_old.categorize_as(SnapshotChangeCategory.BREAKING) + + snapshot_b_old = make_snapshot( + SqlModel( + name="b", + query=parse_one("select col_a from a"), + kind=ViewKind(materialized=True), + ), + nodes={'"a"': snapshot_a_old.model}, + ) + snapshot_b_old.categorize_as(SnapshotChangeCategory.BREAKING) + + snapshot_a_new = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1 as col_a, 2 as col_b"), + kind=ViewKind(materialized=True), + ) + ) + + snapshot_a_new.previous_versions = snapshot_a_old.all_versions + + snapshot_b_new = make_snapshot( + snapshot_b_old.model, + nodes={'"a"': snapshot_a_new.model}, + ) + snapshot_b_new.previous_versions = snapshot_b_old.all_versions + + context_diff = ContextDiff( + environment="test_environment", + is_new_environment=True, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={ + snapshot_a_new.name: (snapshot_a_new, snapshot_a_old), + snapshot_b_new.name: (snapshot_b_new, snapshot_b_old), + }, + snapshots={ + snapshot_a_new.snapshot_id: snapshot_a_new, + snapshot_b_new.snapshot_id: snapshot_b_new, + }, + new_snapshots={ + snapshot_a_new.snapshot_id: snapshot_a_new, + snapshot_b_new.snapshot_id: snapshot_b_new, + }, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + PlanBuilder(context_diff, forward_only=False).build() + + assert 
snapshot_b_new.change_category == SnapshotChangeCategory.INDIRECT_BREAKING + + +def test_forward_only_indirect_change_to_materialized_view(make_snapshot): + snapshot_a_old = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1 as col_a"), + ) + ) + snapshot_a_old.categorize_as(SnapshotChangeCategory.BREAKING) + + snapshot_b_old = make_snapshot( + SqlModel( + name="b", + query=parse_one("select col_a from a"), + kind=ViewKind(materialized=True), + ), + nodes={'"a"': snapshot_a_old.model}, + ) + snapshot_b_old.categorize_as(SnapshotChangeCategory.BREAKING) + + snapshot_a_new = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1 as col_a, 2 as col_b"), + ) + ) + + snapshot_a_new.previous_versions = snapshot_a_old.all_versions + + snapshot_b_new = make_snapshot( + snapshot_b_old.model, + nodes={'"a"': snapshot_a_new.model}, + ) + snapshot_b_new.previous_versions = snapshot_b_old.all_versions + + context_diff = ContextDiff( + environment="test_environment", + is_new_environment=True, + is_unfinalized_environment=False, + normalize_environment_name=True, + create_from="prod", + create_from_env_exists=True, + added=set(), + removed_snapshots={}, + modified_snapshots={ + snapshot_a_new.name: (snapshot_a_new, snapshot_a_old), + snapshot_b_new.name: (snapshot_b_new, snapshot_b_old), + }, + snapshots={ + snapshot_a_new.snapshot_id: snapshot_a_new, + snapshot_b_new.snapshot_id: snapshot_b_new, + }, + new_snapshots={ + snapshot_a_new.snapshot_id: snapshot_a_new, + snapshot_b_new.snapshot_id: snapshot_b_new, + }, + previous_plan_id=None, + previously_promoted_snapshot_ids=set(), + previous_finalized_snapshots=None, + previous_gateway_managed_virtual_layer=False, + gateway_managed_virtual_layer=False, + environment_statements=[], + ) + + PlanBuilder(context_diff, forward_only=True).build() + + # Forward-only indirect changes to MVs should not always be classified as indirect breaking. + # Instead, we want to preserve the standard categorization. 
+ assert snapshot_b_new.change_category == SnapshotChangeCategory.INDIRECT_NON_BREAKING diff --git a/tests/core/test_plan_evaluator.py b/tests/core/test_plan_evaluator.py index 6726139b8f..575f5ae742 100644 --- a/tests/core/test_plan_evaluator.py +++ b/tests/core/test_plan_evaluator.py @@ -5,16 +5,12 @@ from sqlmesh.core.context import Context from sqlmesh.core.model import FullKind, SqlModel, ViewKind from sqlmesh.core.plan import ( - AirflowPlanEvaluator, BuiltInPlanEvaluator, - MWAAPlanEvaluator, Plan, PlanBuilder, - update_intervals_for_new_snapshots, + stages as plan_stages, ) from sqlmesh.core.snapshot import SnapshotChangeCategory -from sqlmesh.utils.date import to_timestamp -from sqlmesh.utils.errors import SQLMeshError @pytest.fixture @@ -25,7 +21,6 @@ def sushi_plan(sushi_context: Context, mocker: MockerFixture) -> Plan: return PlanBuilder( sushi_context._context_diff("dev"), - sushi_context.engine_adapter.SCHEMA_DIFFER, is_dev=True, include_unmodified=True, ).build() @@ -60,17 +55,26 @@ def test_builtin_evaluator_push(sushi_context: Context, make_snapshot): new_model_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) new_view_model_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - plan = PlanBuilder( - sushi_context._context_diff("prod"), sushi_context.engine_adapter.SCHEMA_DIFFER - ).build() + plan = PlanBuilder(sushi_context._context_diff("prod")).build() evaluator = BuiltInPlanEvaluator( sushi_context.state_sync, sushi_context.snapshot_evaluator, + sushi_context.create_scheduler, sushi_context.default_catalog, console=sushi_context.console, ) - evaluator._push(plan) + + evaluatable_plan = plan.to_evaluatable() + stages = plan_stages.build_plan_stages( + evaluatable_plan, sushi_context.state_sync, sushi_context.default_catalog + ) + assert isinstance(stages[1], plan_stages.CreateSnapshotRecordsStage) + evaluator.visit_create_snapshot_records_stage(stages[1], evaluatable_plan) + assert isinstance(stages[2], 
plan_stages.PhysicalLayerSchemaCreationStage) + evaluator.visit_physical_layer_schema_creation_stage(stages[2], evaluatable_plan) + assert isinstance(stages[3], plan_stages.BackfillStage) + evaluator.visit_backfill_stage(stages[3], evaluatable_plan) assert ( len(sushi_context.state_sync.get_snapshots([new_model_snapshot, new_view_model_snapshot])) @@ -78,153 +82,3 @@ def test_builtin_evaluator_push(sushi_context: Context, make_snapshot): ) assert sushi_context.engine_adapter.table_exists(new_model_snapshot.table_name()) assert sushi_context.engine_adapter.table_exists(new_view_model_snapshot.table_name()) - - -def test_airflow_evaluator(sushi_plan: Plan, mocker: MockerFixture): - airflow_client_mock = mocker.Mock() - airflow_client_mock.wait_for_dag_run_completion.return_value = True - airflow_client_mock.wait_for_first_dag_run.return_value = "test_plan_application_dag_run_id" - - evaluator = AirflowPlanEvaluator(airflow_client_mock) - evaluator.evaluate(sushi_plan) - - airflow_client_mock.apply_plan.assert_called_once_with( - sushi_plan.new_snapshots, - sushi_plan.environment, - mocker.ANY, - no_gaps=False, - notification_targets=[], - restatements={}, - backfill_concurrent_tasks=1, - ddl_concurrent_tasks=1, - skip_backfill=False, - users=[], - is_dev=True, - allow_destructive_snapshots=set(), - forward_only=False, - models_to_backfill=None, - end_bounded=False, - ensure_finalized_snapshots=False, - directly_modified_snapshots=[], - indirectly_modified_snapshots={}, - removed_snapshots=[], - execution_time=None, - ) - - airflow_client_mock.wait_for_dag_run_completion.assert_called_once() - airflow_client_mock.wait_for_first_dag_run.assert_called_once() - - -def test_airflow_evaluator_plan_application_dag_fails(sushi_plan: Plan, mocker: MockerFixture): - airflow_client_mock = mocker.Mock() - airflow_client_mock.wait_for_dag_run_completion.return_value = False - airflow_client_mock.wait_for_first_dag_run.return_value = "test_plan_application_dag_run_id" - - 
evaluator = AirflowPlanEvaluator(airflow_client_mock) - - with pytest.raises(SQLMeshError): - evaluator.evaluate(sushi_plan) - - airflow_client_mock.apply_plan.assert_called_once() - airflow_client_mock.wait_for_dag_run_completion.assert_called_once() - airflow_client_mock.wait_for_first_dag_run.assert_called_once() - - -def test_mwaa_evaluator(sushi_plan: Plan, mocker: MockerFixture): - mwaa_client_mock = mocker.Mock() - mwaa_client_mock.wait_for_dag_run_completion.return_value = True - mwaa_client_mock.wait_for_first_dag_run.return_value = "test_plan_application_dag_run_id" - mwaa_client_mock.set_variable.return_value = "", "" - - state_sync_mock = mocker.Mock() - - plan_dag_spec_mock = mocker.Mock() - - create_plan_dag_spec_mock = mocker.patch("sqlmesh.schedulers.airflow.plan.create_plan_dag_spec") - create_plan_dag_spec_mock.return_value = plan_dag_spec_mock - - plan_dag_state_mock = mocker.Mock() - mocker.patch( - "sqlmesh.schedulers.airflow.plan.PlanDagState.from_state_sync", - return_value=plan_dag_state_mock, - ) - - evaluator = MWAAPlanEvaluator(mwaa_client_mock, state_sync_mock) - evaluator.evaluate(sushi_plan) - - plan_dag_state_mock.add_dag_spec.assert_called_once_with(plan_dag_spec_mock) - - mwaa_client_mock.wait_for_dag_run_completion.assert_called_once() - mwaa_client_mock.wait_for_first_dag_run.assert_called_once() - - -@pytest.mark.parametrize( - "change_category", [SnapshotChangeCategory.BREAKING, SnapshotChangeCategory.FORWARD_ONLY] -) -def test_update_intervals_for_new_snapshots( - sushi_context: Context, - mocker: MockerFixture, - change_category: SnapshotChangeCategory, - make_snapshot, -): - model = SqlModel( - name="sushi.new_test_model", - query=parse_one("SELECT 1::INT AS one"), - ) - snapshot = make_snapshot(model) - snapshot.change_category = change_category - - snapshot.add_interval("2023-01-01", "2023-01-01") - - state_sync_mock = mocker.Mock() - state_sync_mock.refresh_snapshot_intervals.return_value = [snapshot] - - 
update_intervals_for_new_snapshots([snapshot], state_sync_mock) - - state_sync_mock.refresh_snapshot_intervals.assert_called_once_with([snapshot]) - - if change_category == SnapshotChangeCategory.FORWARD_ONLY: - assert snapshot.dev_intervals == [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] - state_sync_mock.add_interval.assert_called_once_with( - snapshot, to_timestamp("2023-01-01"), to_timestamp("2023-01-02"), is_dev=True - ) - else: - assert not snapshot.dev_intervals - state_sync_mock.add_interval.assert_not_called() - - -def test_state_based_airflow_evaluator_with_restatements( - sushi_context: Context, mocker: MockerFixture -): - model_fqn = sushi_context.get_model("sushi.waiter_revenue_by_day").fqn - downstream_model_fqn = sushi_context.get_model("sushi.top_waiters").fqn - - plan = PlanBuilder( - sushi_context._context_diff("prod"), - sushi_context.engine_adapter.SCHEMA_DIFFER, - restate_models=[sushi_context.get_model("sushi.waiter_revenue_by_day").fqn], - ).build() - - mwaa_client_mock = mocker.Mock() - mwaa_client_mock.wait_for_dag_run_completion.return_value = True - mwaa_client_mock.wait_for_first_dag_run.return_value = "test_plan_application_dag_run_id" - mwaa_client_mock.set_variable.return_value = "", "" - - state_sync_mock = mocker.Mock() - - plan_dag_spec_mock = mocker.Mock() - - create_plan_dag_spec_mock = mocker.patch("sqlmesh.schedulers.airflow.plan.create_plan_dag_spec") - create_plan_dag_spec_mock.return_value = plan_dag_spec_mock - - plan_dag_state_mock = mocker.Mock() - mocker.patch( - "sqlmesh.schedulers.airflow.plan.PlanDagState.from_state_sync", - return_value=plan_dag_state_mock, - ) - - evaluator = MWAAPlanEvaluator(mwaa_client_mock, state_sync_mock) - evaluator.evaluate(plan) - - plan_application_request = create_plan_dag_spec_mock.call_args[0][0] - assert plan_application_request.restatements.keys() == {model_fqn, downstream_model_fqn} diff --git a/tests/core/test_plan_stages.py b/tests/core/test_plan_stages.py new file 
mode 100644 index 0000000000..f93a8a4780 --- /dev/null +++ b/tests/core/test_plan_stages.py @@ -0,0 +1,2190 @@ +import pytest +import typing as t +from sqlglot import parse_one +from pytest_mock.plugin import MockerFixture + +from sqlmesh.core.config import EnvironmentSuffixTarget +from sqlmesh.core.config.common import VirtualEnvironmentMode +from sqlmesh.core.model import SqlModel, ModelKindName +from sqlmesh.core.plan.common import SnapshotIntervalClearRequest +from sqlmesh.core.plan.definition import EvaluatablePlan +from sqlmesh.core.plan.stages import ( + build_plan_stages, + AfterAllStage, + AuditOnlyRunStage, + PhysicalLayerUpdateStage, + PhysicalLayerSchemaCreationStage, + CreateSnapshotRecordsStage, + BeforeAllStage, + BackfillStage, + EnvironmentRecordUpdateStage, + VirtualLayerUpdateStage, + RestatementStage, + MigrateSchemasStage, + FinalizeEnvironmentStage, + UnpauseStage, +) +from sqlmesh.core.plan.explainer import ExplainableRestatementStage +from sqlmesh.core.snapshot.definition import ( + SnapshotChangeCategory, + DeployabilityIndex, + Snapshot, + SnapshotId, + SnapshotIdLike, +) +from sqlmesh.core.state_sync import StateReader +from sqlmesh.core.environment import Environment, EnvironmentStatements +from sqlmesh.utils.date import to_timestamp + + +@pytest.fixture +def snapshot_a(make_snapshot) -> Snapshot: + snapshot = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, ds"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="ds"), + ) + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + return snapshot + + +@pytest.fixture +def snapshot_b(make_snapshot, snapshot_a: Snapshot) -> Snapshot: + snapshot = make_snapshot( + SqlModel( + name="b", + query=parse_one("select 2, ds from a"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="ds"), + ), + nodes={'"a"': snapshot_a.model}, + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + return snapshot + + +@pytest.fixture +def 
snapshot_c(make_snapshot, snapshot_a: Snapshot) -> Snapshot: + snapshot = make_snapshot( + SqlModel( + name="c", + query=parse_one("select * from a"), + kind=dict(name=ModelKindName.VIEW), + ), + nodes={'"a"': snapshot_a.model}, + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + return snapshot + + +def test_build_plan_stages_basic( + snapshot_a: Snapshot, snapshot_b: Snapshot, mocker: MockerFixture +) -> None: + # Mock state reader + state_reader = mocker.Mock(spec=StateReader) + state_reader.get_snapshots.return_value = {} + state_reader.get_environment.return_value = None + + # Create environment + environment = Environment( + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + ) + + # Create evaluatable plan + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[snapshot_a, snapshot_b], + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={}, + restate_all_snapshots=False, + is_dev=False, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + indirectly_modified_snapshots={}, + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + # Build plan stages + stages = build_plan_stages(plan, state_reader, None) + + # Verify stages + assert len(stages) == 7 + + # Verify CreateSnapshotRecordsStage + create_snapshot_records_stage = stages[0] + assert isinstance(create_snapshot_records_stage, CreateSnapshotRecordsStage) + assert 
len(create_snapshot_records_stage.snapshots) == 2 + assert {s.snapshot_id for s in create_snapshot_records_stage.snapshots} == { + snapshot_a.snapshot_id, + snapshot_b.snapshot_id, + } + # Verify PhysicalLayerSchemaCreationStage + physical_stage = stages[1] + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) + assert len(physical_stage.snapshots) == 2 + assert {s.snapshot_id for s in physical_stage.snapshots} == { + snapshot_a.snapshot_id, + snapshot_b.snapshot_id, + } + assert physical_stage.deployability_index == DeployabilityIndex.all_deployable() + + # Verify BackfillStage + backfill_stage = stages[2] + assert isinstance(backfill_stage, BackfillStage) + assert backfill_stage.deployability_index == DeployabilityIndex.all_deployable() + expected_interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) + assert len(backfill_stage.snapshot_to_intervals) == 2 + assert backfill_stage.snapshot_to_intervals[snapshot_a] == [expected_interval] + assert backfill_stage.snapshot_to_intervals[snapshot_b] == [expected_interval] + + # Verify EnvironmentRecordUpdateStage + assert isinstance(stages[3], EnvironmentRecordUpdateStage) + assert stages[3].no_gaps_snapshot_names == {snapshot_a.name, snapshot_b.name} + + # Verify UnpauseStage + assert isinstance(stages[4], UnpauseStage) + assert {s.name for s in stages[4].promoted_snapshots} == {snapshot_a.name, snapshot_b.name} + + # Verify VirtualLayerUpdateStage + virtual_stage = stages[5] + assert isinstance(virtual_stage, VirtualLayerUpdateStage) + assert len(virtual_stage.promoted_snapshots) == 2 + assert len(virtual_stage.demoted_snapshots) == 0 + assert {s.name for s in virtual_stage.promoted_snapshots} == {snapshot_a.name, snapshot_b.name} + + state_reader.refresh_snapshot_intervals.assert_called_once() + + assert isinstance(stages[6], FinalizeEnvironmentStage) + + +def test_build_plan_stages_with_before_all_and_after_all( + snapshot_a: Snapshot, snapshot_b: Snapshot, mocker: MockerFixture +) -> 
None: + # Mock state reader + state_reader = mocker.Mock(spec=StateReader) + state_reader.get_snapshots.return_value = {} + state_reader.get_environment.return_value = None + + # Create environment + environment = Environment( + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + ) + + environment_statements = [ + EnvironmentStatements( + before_all=["BEFORE ALL A", "BEFORE ALL B"], + after_all=["AFTER ALL A", "AFTER ALL B"], + python_env={}, + jinja_macros=None, + ) + ] + + # Create evaluatable plan + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[snapshot_a, snapshot_b], + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={}, + restate_all_snapshots=False, + is_dev=False, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + indirectly_modified_snapshots={}, + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=environment_statements, + user_provided_flags=None, + ) + + # Build plan stages + stages = build_plan_stages(plan, state_reader, None) + + # Verify stages + assert len(stages) == 9 + + # Verify BeforeAllStage + before_all_stage = stages[0] + assert isinstance(before_all_stage, BeforeAllStage) + assert before_all_stage.statements == environment_statements + + # Verify CreateSnapshotRecordsStage + create_snapshot_records_stage = stages[1] + assert isinstance(create_snapshot_records_stage, CreateSnapshotRecordsStage) + assert 
len(create_snapshot_records_stage.snapshots) == 2 + assert {s.snapshot_id for s in create_snapshot_records_stage.snapshots} == { + snapshot_a.snapshot_id, + snapshot_b.snapshot_id, + } + + # Verify PhysicalLayerSchemaCreationStage + physical_stage = stages[2] + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) + assert len(physical_stage.snapshots) == 2 + assert {s.snapshot_id for s in physical_stage.snapshots} == { + snapshot_a.snapshot_id, + snapshot_b.snapshot_id, + } + assert physical_stage.deployability_index == DeployabilityIndex.all_deployable() + + # Verify BackfillStage + backfill_stage = stages[3] + assert isinstance(backfill_stage, BackfillStage) + assert backfill_stage.deployability_index == DeployabilityIndex.all_deployable() + expected_interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) + assert len(backfill_stage.snapshot_to_intervals) == 2 + assert backfill_stage.snapshot_to_intervals[snapshot_a] == [expected_interval] + assert backfill_stage.snapshot_to_intervals[snapshot_b] == [expected_interval] + + # Verify EnvironmentRecordUpdateStage + assert isinstance(stages[4], EnvironmentRecordUpdateStage) + assert stages[4].no_gaps_snapshot_names == {snapshot_a.name, snapshot_b.name} + + # Verify UnpauseStage + assert isinstance(stages[5], UnpauseStage) + assert {s.name for s in stages[5].promoted_snapshots} == {snapshot_a.name, snapshot_b.name} + + # Verify VirtualLayerUpdateStage + virtual_stage = stages[6] + assert isinstance(virtual_stage, VirtualLayerUpdateStage) + assert len(virtual_stage.promoted_snapshots) == 2 + assert len(virtual_stage.demoted_snapshots) == 0 + assert {s.name for s in virtual_stage.promoted_snapshots} == {'"a"', '"b"'} + + # Verify FinalizeEnvironmentStage + assert isinstance(stages[7], FinalizeEnvironmentStage) + + # Verify AfterAllStage + after_all_stage = stages[8] + assert isinstance(after_all_stage, AfterAllStage) + assert after_all_stage.statements == environment_statements + + +def 
test_build_plan_stages_select_models( + snapshot_a: Snapshot, snapshot_b: Snapshot, mocker: MockerFixture +) -> None: + # Mock state reader + state_reader = mocker.Mock(spec=StateReader) + state_reader.get_snapshots.return_value = {} + state_reader.get_environment.return_value = None + + # Create environment + environment = Environment( + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a.snapshot_id], + ) + + # Create evaluatable plan + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[snapshot_a, snapshot_b], + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={}, + restate_all_snapshots=False, + is_dev=False, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + indirectly_modified_snapshots={}, + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill={snapshot_a.name}, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + # Build plan stages + stages = build_plan_stages(plan, state_reader, None) + + # Verify stages + assert len(stages) == 7 + + # Verify CreateSnapshotRecordsStage + create_snapshot_records_stage = stages[0] + assert isinstance(create_snapshot_records_stage, CreateSnapshotRecordsStage) + assert len(create_snapshot_records_stage.snapshots) == 2 + assert {s.snapshot_id for s in create_snapshot_records_stage.snapshots} == { + snapshot_a.snapshot_id, + snapshot_b.snapshot_id, + } + + # Verify PhysicalLayerSchemaCreationStage + physical_stage = stages[1] + assert isinstance(physical_stage, 
PhysicalLayerSchemaCreationStage) + assert len(physical_stage.snapshots) == 1 + assert {s.snapshot_id for s in physical_stage.snapshots} == {snapshot_a.snapshot_id} + assert physical_stage.deployability_index == DeployabilityIndex.all_deployable() + + # Verify BackfillStage + backfill_stage = stages[2] + assert isinstance(backfill_stage, BackfillStage) + assert backfill_stage.deployability_index == DeployabilityIndex.all_deployable() + expected_interval = (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) + assert len(backfill_stage.snapshot_to_intervals) == 1 + assert backfill_stage.snapshot_to_intervals[snapshot_a] == [expected_interval] + + # Verify EnvironmentRecordUpdateStage + assert isinstance(stages[3], EnvironmentRecordUpdateStage) + assert stages[3].no_gaps_snapshot_names == {snapshot_a.name} + + # Verify UnpauseStage + assert isinstance(stages[4], UnpauseStage) + assert {s.name for s in stages[4].promoted_snapshots} == {snapshot_a.name} + + # Verify VirtualLayerUpdateStage + virtual_stage = stages[5] + assert isinstance(virtual_stage, VirtualLayerUpdateStage) + assert len(virtual_stage.promoted_snapshots) == 1 + assert len(virtual_stage.demoted_snapshots) == 0 + assert {s.name for s in virtual_stage.promoted_snapshots} == {'"a"'} + + # Verify FinalizeEnvironmentStage + assert isinstance(stages[6], FinalizeEnvironmentStage) + + +@pytest.mark.parametrize("skip_backfill,empty_backfill", [(True, False), (False, True)]) +def test_build_plan_stages_basic_no_backfill( + snapshot_a: Snapshot, + snapshot_b: Snapshot, + mocker: MockerFixture, + skip_backfill: bool, + empty_backfill: bool, +) -> None: + # Mock state reader + state_reader = mocker.Mock(spec=StateReader) + state_reader.get_snapshots.return_value = {} + state_reader.get_environment.return_value = None + + # Create environment + environment = Environment( + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + 
previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + ) + + # Create evaluatable plan + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[snapshot_a, snapshot_b], + environment=environment, + no_gaps=False, + skip_backfill=skip_backfill, + empty_backfill=empty_backfill, + restatements={}, + restate_all_snapshots=False, + is_dev=False, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + indirectly_modified_snapshots={}, + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + # Build plan stages + stages = build_plan_stages(plan, state_reader, None) + + # Verify stages + assert len(stages) == 8 + + # Verify CreateSnapshotRecordsStage + create_snapshot_records_stage = stages[0] + assert isinstance(create_snapshot_records_stage, CreateSnapshotRecordsStage) + assert len(create_snapshot_records_stage.snapshots) == 2 + assert {s.snapshot_id for s in create_snapshot_records_stage.snapshots} == { + snapshot_a.snapshot_id, + snapshot_b.snapshot_id, + } + # Verify PhysicalLayerSchemaCreationStage + physical_stage = stages[1] + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) + assert len(physical_stage.snapshots) == 2 + assert {s.snapshot_id for s in physical_stage.snapshots} == { + snapshot_a.snapshot_id, + snapshot_b.snapshot_id, + } + + # Verify PhysicalLayerUpdateStage + physical_stage = stages[2] + assert isinstance(physical_stage, PhysicalLayerUpdateStage) + assert len(physical_stage.snapshots) == 2 + assert {s.snapshot_id for s in physical_stage.snapshots} == { + snapshot_a.snapshot_id, + 
snapshot_b.snapshot_id, + } + + # Verify BackfillStage + backfill_stage = stages[3] + assert isinstance(backfill_stage, BackfillStage) + assert backfill_stage.deployability_index == DeployabilityIndex.all_deployable() + assert backfill_stage.snapshot_to_intervals == {} + + # Verify EnvironmentRecordUpdateStage + assert isinstance(stages[4], EnvironmentRecordUpdateStage) + assert stages[4].no_gaps_snapshot_names == {snapshot_a.name, snapshot_b.name} + + # Verify UnpauseStage + assert isinstance(stages[5], UnpauseStage) + assert {s.name for s in stages[5].promoted_snapshots} == {snapshot_a.name, snapshot_b.name} + + # Verify VirtualLayerUpdateStage + virtual_stage = stages[6] + assert isinstance(virtual_stage, VirtualLayerUpdateStage) + assert len(virtual_stage.promoted_snapshots) == 2 + assert len(virtual_stage.demoted_snapshots) == 0 + assert {s.name for s in virtual_stage.promoted_snapshots} == {'"a"', '"b"'} + + # Verify FinalizeEnvironmentStage + assert isinstance(stages[7], FinalizeEnvironmentStage) + + +def test_build_plan_stages_restatement_prod_only( + snapshot_a: Snapshot, snapshot_b: Snapshot, mocker: MockerFixture +) -> None: + """ + Scenario: + - Prod restatement triggered in a project with no dev environments + + Expected Outcome: + - Plan still contains a RestatementStage in case a dev environment was + created during restatement + """ + + # Mock state reader to return existing snapshots and environment + state_reader = mocker.Mock(spec=StateReader) + state_reader.get_snapshots.return_value = { + snapshot_a.snapshot_id: snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + } + state_reader.get_snapshots_by_names.return_value = { + snapshot_a.id_and_version, + snapshot_b.id_and_version, + } + + existing_environment = Environment( + name="prod", + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a.snapshot_id, 
snapshot_b.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + + state_reader.get_environment.return_value = existing_environment + state_reader.get_environments_summary.return_value = [existing_environment.summary] + + environment = Environment( + name="prod", + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id="previous_plan", + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + ) + + # Create evaluatable plan with restatements + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[], # No new snapshots + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={ + '"a"': (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + '"b"': (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + }, + restate_all_snapshots=True, + is_dev=False, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[], # No changes + indirectly_modified_snapshots={}, # No changes + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + # Build plan stages + stages = build_plan_stages(plan, state_reader, None) + + # Verify stages + assert len(stages) == 5 + + # Verify PhysicalLayerSchemaCreationStage + physical_stage = stages[0] + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) + assert len(physical_stage.snapshots) == 2 + assert {s.snapshot_id for s in physical_stage.snapshots} == { + snapshot_a.snapshot_id, + snapshot_b.snapshot_id, + } + + # Verify BackfillStage + backfill_stage = stages[1] + assert isinstance(backfill_stage, 
BackfillStage) + assert len(backfill_stage.snapshot_to_intervals) == 2 + assert backfill_stage.deployability_index == DeployabilityIndex.all_deployable() + expected_backfill_interval = [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] + for intervals in backfill_stage.snapshot_to_intervals.values(): + assert intervals == expected_backfill_interval + + # Verify RestatementStage exists but is empty + restatement_stage = stages[2] + assert isinstance(restatement_stage, RestatementStage) + restatement_stage = ExplainableRestatementStage.from_restatement_stage( + restatement_stage, state_reader, plan + ) + assert not restatement_stage.snapshot_intervals_to_clear + + # Verify EnvironmentRecordUpdateStage + assert isinstance(stages[3], EnvironmentRecordUpdateStage) + + # Verify FinalizeEnvironmentStage + assert isinstance(stages[4], FinalizeEnvironmentStage) + + +def test_build_plan_stages_restatement_prod_identifies_dev_intervals( + snapshot_a: Snapshot, + snapshot_b: Snapshot, + make_snapshot: t.Callable[..., Snapshot], + mocker: MockerFixture, +) -> None: + """ + Scenario: + - Prod restatement triggered in a project with a dev environment + - The dev environment contains a different physical version of the affected model + + Expected Outcome: + - Plan contains a RestatementStage that highlights the affected dev version + """ + # Dev version of snapshot_a, same name but different version + snapshot_a_dev = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, changed, ds"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="ds"), + ) + ) + snapshot_a_dev.categorize_as(SnapshotChangeCategory.BREAKING) + assert snapshot_a_dev.snapshot_id != snapshot_a.snapshot_id + assert snapshot_a_dev.table_info != snapshot_a.table_info + + # Mock state reader to return existing snapshots and environment + state_reader = mocker.Mock(spec=StateReader) + snapshots_in_state = { + snapshot_a.snapshot_id: snapshot_a, + snapshot_b.snapshot_id: 
snapshot_b, + snapshot_a_dev.snapshot_id: snapshot_a_dev, + } + + def _get_snapshots(snapshot_ids: t.Iterable[SnapshotIdLike]): + return { + k: v + for k, v in snapshots_in_state.items() + if k in {s.snapshot_id for s in snapshot_ids} + } + + state_reader.get_snapshots.side_effect = _get_snapshots + state_reader.get_snapshots_by_names.return_value = set() + + existing_prod_environment = Environment( + name="prod", + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + + # dev has new version of snapshot_a but same version of snapshot_b + existing_dev_environment = Environment( + name="dev", + snapshots=[snapshot_a_dev.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a_dev.snapshot_id, snapshot_b.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + + state_reader.get_environment.side_effect = ( + lambda name: existing_dev_environment if name == "dev" else existing_prod_environment + ) + state_reader.get_environments_summary.return_value = [ + existing_prod_environment.summary, + existing_dev_environment.summary, + ] + + environment = Environment( + name="prod", + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id="previous_plan", + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + ) + + # Create evaluatable plan with restatements + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[], # No new snapshots + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={ + '"a"': (to_timestamp("2023-01-01"), 
to_timestamp("2023-01-02")), + '"b"': (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + }, + restate_all_snapshots=True, + is_dev=False, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[], # No changes + indirectly_modified_snapshots={}, # No changes + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + # Build plan stages + stages = build_plan_stages(plan, state_reader, None) + + # Verify stages + assert len(stages) == 5 + + # Verify PhysicalLayerSchemaCreationStage + physical_stage = stages[0] + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) + assert len(physical_stage.snapshots) == 2 + assert {s.snapshot_id for s in physical_stage.snapshots} == { + snapshot_a.snapshot_id, + snapshot_b.snapshot_id, + } + + # Verify BackfillStage + backfill_stage = stages[1] + assert isinstance(backfill_stage, BackfillStage) + assert len(backfill_stage.snapshot_to_intervals) == 2 + assert backfill_stage.deployability_index == DeployabilityIndex.all_deployable() + expected_backfill_interval = [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] + for intervals in backfill_stage.snapshot_to_intervals.values(): + assert intervals == expected_backfill_interval + + # Verify RestatementStage + restatement_stage = stages[2] + assert isinstance(restatement_stage, RestatementStage) + restatement_stage = ExplainableRestatementStage.from_restatement_stage( + restatement_stage, state_reader, plan + ) + + # note: we only clear the intervals from state for "a" in dev, we leave prod alone + assert restatement_stage.snapshot_intervals_to_clear + assert len(restatement_stage.snapshot_intervals_to_clear) == 1 + snapshot_name, 
clear_requests = list(restatement_stage.snapshot_intervals_to_clear.items())[0] + assert snapshot_name == '"a"' + assert len(clear_requests) == 1 + clear_request = clear_requests[0] + assert isinstance(clear_request, SnapshotIntervalClearRequest) + assert clear_request.snapshot_id == snapshot_a_dev.snapshot_id + assert clear_request.snapshot == snapshot_a_dev.id_and_version + assert clear_request.interval == (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) + + # Verify EnvironmentRecordUpdateStage + assert isinstance(stages[3], EnvironmentRecordUpdateStage) + + # Verify FinalizeEnvironmentStage + assert isinstance(stages[4], FinalizeEnvironmentStage) + + +def test_build_plan_stages_restatement_dev_does_not_clear_intervals( + snapshot_a: Snapshot, + snapshot_b: Snapshot, + make_snapshot: t.Callable[..., Snapshot], + mocker: MockerFixture, +) -> None: + """ + Scenario: + - Restatement triggered against the dev environment + + Expected Outcome: + - BackfillStage only touches models in that dev environment + - Plan does not contain a RestatementStage because making changes in dev doesnt mean we need + to clear intervals from other environments + """ + # Dev version of snapshot_a, same name but different version + snapshot_a_dev = make_snapshot( + SqlModel( + name="a", + query=parse_one("select 1, changed, ds"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="ds"), + ) + ) + snapshot_a_dev.categorize_as(SnapshotChangeCategory.BREAKING) + assert snapshot_a_dev.snapshot_id != snapshot_a.snapshot_id + assert snapshot_a_dev.table_info != snapshot_a.table_info + + # Mock state reader to return existing snapshots and environment + state_reader = mocker.Mock(spec=StateReader) + snapshots_in_state = { + snapshot_a.snapshot_id: snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + snapshot_a_dev.snapshot_id: snapshot_a_dev, + } + state_reader.get_snapshots.side_effect = lambda snapshot_info_like: { + k: v + for k, v in snapshots_in_state.items() + if 
k in [sil.snapshot_id for sil in snapshot_info_like] + } + + # prod has snapshot_a, snapshot_b + existing_prod_environment = Environment( + name="prod", + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_prod_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + + # dev has new version of snapshot_a + existing_dev_environment = Environment( + name="dev", + snapshots=[snapshot_a_dev.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_dev_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a_dev.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + + state_reader.get_environment.side_effect = ( + lambda name: existing_dev_environment if name == "dev" else existing_prod_environment + ) + state_reader.get_environments_summary.return_value = [ + existing_prod_environment.summary, + existing_dev_environment.summary, + ] + + environment = Environment( + name="dev", + snapshots=[snapshot_a_dev.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id="previous_dev_plan", + promoted_snapshot_ids=[snapshot_a_dev.snapshot_id], + ) + + # Create evaluatable plan with restatements + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[], # No new snapshots + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={ + '"a"': (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + }, + restate_all_snapshots=False, + is_dev=True, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[], # No changes + indirectly_modified_snapshots={}, # No changes + metadata_updated_snapshots=[], + 
removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + # Build plan stages + stages = build_plan_stages(plan, state_reader, None) + + # Verify stages + assert len(stages) == 5 + + # Verify no RestatementStage + assert not any(s for s in stages if isinstance(s, RestatementStage)) + + # Verify PhysicalLayerSchemaCreationStage + physical_stage = stages[0] + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) + assert len(physical_stage.snapshots) == 1 + assert {s.snapshot_id for s in physical_stage.snapshots} == { + snapshot_a_dev.snapshot_id, + } + + # Verify BackfillStage + backfill_stage = stages[1] + assert isinstance(backfill_stage, BackfillStage) + assert len(backfill_stage.snapshot_to_intervals) == 1 + assert backfill_stage.deployability_index == DeployabilityIndex.all_deployable() + backfill_snapshot, backfill_intervals = list(backfill_stage.snapshot_to_intervals.items())[0] + assert backfill_snapshot.snapshot_id == snapshot_a_dev.snapshot_id + assert backfill_intervals == [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] + + # Verify EnvironmentRecordUpdateStage + assert isinstance(stages[2], EnvironmentRecordUpdateStage) + + # Verify VirtualLayerUpdateStage (all non-prod plans get this regardless) + assert isinstance(stages[3], VirtualLayerUpdateStage) + + # Verify FinalizeEnvironmentStage + assert isinstance(stages[4], FinalizeEnvironmentStage) + + +def test_build_plan_stages_forward_only( + snapshot_a: Snapshot, snapshot_b: Snapshot, make_snapshot, mocker: MockerFixture +) -> None: + # Categorize snapshot_a as forward-only + new_snapshot_a = make_snapshot( + snapshot_a.model.copy(update={"stamp": "new_version"}), + ) + new_snapshot_a.previous_versions = snapshot_a.all_versions + new_snapshot_a.categorize_as(SnapshotChangeCategory.NON_BREAKING, forward_only=True) + + 
new_snapshot_b = make_snapshot( + snapshot_b.model.copy(), + nodes={'"a"': new_snapshot_a.model}, + ) + new_snapshot_b.previous_versions = snapshot_b.all_versions + new_snapshot_b.categorize_as(SnapshotChangeCategory.INDIRECT_NON_BREAKING, forward_only=True) + + state_reader = mocker.Mock(spec=StateReader) + state_reader.get_snapshots.return_value = {} + existing_environment = Environment( + name="prod", + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + state_reader.get_environment.return_value = existing_environment + + # Create environment + environment = Environment( + name="prod", + snapshots=[new_snapshot_a.table_info, new_snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id="previous_plan", + promoted_snapshot_ids=[new_snapshot_a.snapshot_id, new_snapshot_b.snapshot_id], + ) + + # Create evaluatable plan + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[new_snapshot_a, new_snapshot_b], + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={}, + restate_all_snapshots=False, + is_dev=False, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[new_snapshot_a.snapshot_id], + indirectly_modified_snapshots={ + new_snapshot_a.name: [new_snapshot_b.snapshot_id], + }, + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + # Build plan stages + stages = 
build_plan_stages(plan, state_reader, None) + + # Verify stages + assert len(stages) == 8 + + # Verify CreateSnapshotRecordsStage + create_snapshot_records_stage = stages[0] + assert isinstance(create_snapshot_records_stage, CreateSnapshotRecordsStage) + assert len(create_snapshot_records_stage.snapshots) == 2 + assert {s.snapshot_id for s in create_snapshot_records_stage.snapshots} == { + new_snapshot_a.snapshot_id, + new_snapshot_b.snapshot_id, + } + + # Verify PhysicalLayerSchemaCreationStage + physical_stage = stages[1] + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) + assert len(physical_stage.snapshots) == 2 + assert {s.snapshot_id for s in physical_stage.snapshots} == { + new_snapshot_a.snapshot_id, + new_snapshot_b.snapshot_id, + } + assert physical_stage.deployability_index == DeployabilityIndex.all_deployable() + + # Verify EnvironmentRecordUpdateStage + assert isinstance(stages[2], EnvironmentRecordUpdateStage) + assert stages[2].no_gaps_snapshot_names == set() + + # Verify MigrateSchemasStage + migrate_stage = stages[3] + assert isinstance(migrate_stage, MigrateSchemasStage) + assert len(migrate_stage.snapshots) == 2 + assert {s.snapshot_id for s in migrate_stage.snapshots} == { + new_snapshot_a.snapshot_id, + new_snapshot_b.snapshot_id, + } + + # Verify UnpauseStage + assert isinstance(stages[4], UnpauseStage) + assert {s.name for s in stages[4].promoted_snapshots} == { + new_snapshot_a.name, + new_snapshot_b.name, + } + + # Verify BackfillStage + backfill_stage = stages[5] + assert isinstance(backfill_stage, BackfillStage) + assert len(backfill_stage.snapshot_to_intervals) == 2 + expected_interval = [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] + for intervals in backfill_stage.snapshot_to_intervals.values(): + assert intervals == expected_interval + assert backfill_stage.deployability_index == DeployabilityIndex.all_deployable() + + # Verify VirtualLayerUpdateStage + virtual_stage = stages[6] + assert 
isinstance(virtual_stage, VirtualLayerUpdateStage) + assert len(virtual_stage.promoted_snapshots) == 2 + assert len(virtual_stage.demoted_snapshots) == 0 + assert {s.name for s in virtual_stage.promoted_snapshots} == {'"a"', '"b"'} + + # Verify FinalizeEnvironmentStage + assert isinstance(stages[7], FinalizeEnvironmentStage) + + +def test_build_plan_stages_forward_only_dev( + snapshot_a: Snapshot, snapshot_b: Snapshot, make_snapshot, mocker: MockerFixture +) -> None: + # Categorize snapshot_a as forward-only + new_snapshot_a = make_snapshot( + snapshot_a.model.copy(update={"stamp": "new_version"}), + ) + new_snapshot_a.previous_versions = snapshot_a.all_versions + new_snapshot_a.categorize_as(SnapshotChangeCategory.NON_BREAKING, forward_only=True) + + new_snapshot_b = make_snapshot( + snapshot_b.model.copy(), + nodes={'"a"': new_snapshot_a.model}, + ) + new_snapshot_b.previous_versions = snapshot_b.all_versions + new_snapshot_b.categorize_as(SnapshotChangeCategory.INDIRECT_NON_BREAKING, forward_only=True) + + state_reader = mocker.Mock(spec=StateReader) + state_reader.get_snapshots.return_value = {} + state_reader.get_environment.return_value = None + + # Create environment + environment = Environment( + name="dev", + snapshots=[new_snapshot_a.table_info, new_snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id=None, + promoted_snapshot_ids=[new_snapshot_a.snapshot_id, new_snapshot_b.snapshot_id], + ) + + # Create evaluatable plan + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[new_snapshot_a, new_snapshot_b], + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={}, + restate_all_snapshots=False, + is_dev=True, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + 
directly_modified_snapshots=[new_snapshot_a.snapshot_id], + indirectly_modified_snapshots={ + new_snapshot_a.name: [new_snapshot_b.snapshot_id], + }, + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + # Build plan stages + stages = build_plan_stages(plan, state_reader, None) + + # Verify stages + assert len(stages) == 6 + + # Verify CreateSnapshotRecordsStage + create_snapshot_records_stage = stages[0] + assert isinstance(create_snapshot_records_stage, CreateSnapshotRecordsStage) + assert len(create_snapshot_records_stage.snapshots) == 2 + assert {s.snapshot_id for s in create_snapshot_records_stage.snapshots} == { + new_snapshot_a.snapshot_id, + new_snapshot_b.snapshot_id, + } + + # Verify PhysicalLayerSchemaCreationStage + physical_stage = stages[1] + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) + assert len(physical_stage.snapshots) == 2 + assert {s.snapshot_id for s in physical_stage.snapshots} == { + new_snapshot_a.snapshot_id, + new_snapshot_b.snapshot_id, + } + assert physical_stage.deployability_index == DeployabilityIndex.create( + [new_snapshot_a, new_snapshot_b] + ) + + # Verify BackfillStage + backfill_stage = stages[2] + assert isinstance(backfill_stage, BackfillStage) + assert len(backfill_stage.snapshot_to_intervals) == 2 + expected_interval = [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] + for intervals in backfill_stage.snapshot_to_intervals.values(): + assert intervals == expected_interval + assert backfill_stage.deployability_index == DeployabilityIndex.create( + [new_snapshot_a, new_snapshot_b] + ) + + # Verify EnvironmentRecordUpdateStage + assert isinstance(stages[3], EnvironmentRecordUpdateStage) + + # Verify VirtualLayerUpdateStage + virtual_stage = stages[4] + assert isinstance(virtual_stage, 
VirtualLayerUpdateStage) + assert len(virtual_stage.promoted_snapshots) == 2 + assert len(virtual_stage.demoted_snapshots) == 0 + assert {s.name for s in virtual_stage.promoted_snapshots} == {'"a"', '"b"'} + + # Verify FinalizeEnvironmentStage + assert isinstance(stages[5], FinalizeEnvironmentStage) + + +def test_build_plan_stages_audit_only( + snapshot_a: Snapshot, snapshot_b: Snapshot, make_snapshot, mocker: MockerFixture +) -> None: + # Categorize snapshot_a as forward-only + new_snapshot_a = make_snapshot( + snapshot_a.model.copy(update={"audits": [("not_null", {})]}), + ) + new_snapshot_a.previous_versions = snapshot_a.all_versions + new_snapshot_a.categorize_as(SnapshotChangeCategory.METADATA) + new_snapshot_a.add_interval("2023-01-01", "2023-01-02") + + new_snapshot_b = make_snapshot( + snapshot_b.model.copy(), + nodes={'"a"': new_snapshot_a.model}, + ) + new_snapshot_b.previous_versions = snapshot_b.all_versions + new_snapshot_b.categorize_as(SnapshotChangeCategory.METADATA) + new_snapshot_b.add_interval("2023-01-01", "2023-01-02") + + def _get_snapshots(snapshot_ids: t.List[SnapshotId]) -> t.Dict[SnapshotId, Snapshot]: + if snapshot_a.snapshot_id in snapshot_ids and snapshot_b.snapshot_id in snapshot_ids: + return { + snapshot_a.snapshot_id: snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + } + return {} + + state_reader = mocker.Mock(spec=StateReader) + state_reader.get_snapshots.side_effect = _get_snapshots + state_reader.get_environment.return_value = None + + # Create environment + environment = Environment( + name="dev", + snapshots=[new_snapshot_a.table_info, new_snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id=None, + promoted_snapshot_ids=[new_snapshot_a.snapshot_id, new_snapshot_b.snapshot_id], + ) + + # Create evaluatable plan + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[new_snapshot_a, new_snapshot_b], + environment=environment, + 
no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={}, + restate_all_snapshots=False, + is_dev=True, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[new_snapshot_a.snapshot_id], + indirectly_modified_snapshots={ + new_snapshot_a.name: [new_snapshot_b.snapshot_id], + }, + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + # Build plan stages + stages = build_plan_stages(plan, state_reader, None) + + # Verify stages + assert len(stages) == 8 + + # Verify CreateSnapshotRecordsStage + create_snapshot_records_stage = stages[0] + assert isinstance(create_snapshot_records_stage, CreateSnapshotRecordsStage) + assert len(create_snapshot_records_stage.snapshots) == 2 + assert {s.snapshot_id for s in create_snapshot_records_stage.snapshots} == { + new_snapshot_a.snapshot_id, + new_snapshot_b.snapshot_id, + } + + # Verify PhysicalLayerSchemaCreationStage + physical_stage = stages[1] + assert isinstance(physical_stage, PhysicalLayerSchemaCreationStage) + assert len(physical_stage.snapshots) == 2 + assert {s.snapshot_id for s in physical_stage.snapshots} == { + new_snapshot_a.snapshot_id, + new_snapshot_b.snapshot_id, + } + assert physical_stage.deployability_index == DeployabilityIndex.create( + [new_snapshot_a, new_snapshot_b] + ) + + # Verify PhysicalLayerUpdateStage + physical_stage = stages[2] + assert isinstance(physical_stage, PhysicalLayerUpdateStage) + assert len(physical_stage.snapshots) == 2 + assert {s.snapshot_id for s in physical_stage.snapshots} == { + new_snapshot_a.snapshot_id, + new_snapshot_b.snapshot_id, + } + assert physical_stage.deployability_index == DeployabilityIndex.create( + [new_snapshot_a, 
new_snapshot_b] + ) + + # Verify AuditOnlyRunStage + audit_only_stage = stages[3] + assert isinstance(audit_only_stage, AuditOnlyRunStage) + assert len(audit_only_stage.snapshots) == 1 + assert audit_only_stage.snapshots[0].snapshot_id == new_snapshot_a.snapshot_id + + # Verify BackfillStage + backfill_stage = stages[4] + assert isinstance(backfill_stage, BackfillStage) + assert len(backfill_stage.snapshot_to_intervals) == 0 + + # Verify EnvironmentRecordUpdateStage + assert isinstance(stages[5], EnvironmentRecordUpdateStage) + + # Verify VirtualLayerUpdateStage + virtual_stage = stages[6] + assert isinstance(virtual_stage, VirtualLayerUpdateStage) + assert len(virtual_stage.promoted_snapshots) == 2 + assert len(virtual_stage.demoted_snapshots) == 0 + assert {s.name for s in virtual_stage.promoted_snapshots} == {'"a"', '"b"'} + + # Verify FinalizeEnvironmentStage + assert isinstance(stages[7], FinalizeEnvironmentStage) + + +def test_build_plan_stages_forward_only_ensure_finalized_snapshots( + snapshot_a: Snapshot, snapshot_b: Snapshot, make_snapshot, mocker: MockerFixture +) -> None: + # Categorize snapshot_a as forward-only + new_snapshot_a = make_snapshot( + snapshot_a.model.copy(update={"stamp": "new_version"}), + ) + new_snapshot_a.previous_versions = snapshot_a.all_versions + new_snapshot_a.categorize_as(SnapshotChangeCategory.NON_BREAKING, forward_only=True) + + new_snapshot_b = make_snapshot( + snapshot_b.model.copy(), + nodes={'"a"': new_snapshot_a.model}, + ) + new_snapshot_b.previous_versions = snapshot_b.all_versions + new_snapshot_b.categorize_as(SnapshotChangeCategory.INDIRECT_NON_BREAKING, forward_only=True) + + state_reader = mocker.Mock(spec=StateReader) + state_reader.get_snapshots.return_value = {} + existing_environment = Environment( + name="prod", + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_plan", + previous_plan_id=None, + 
promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + state_reader.get_environment.return_value = existing_environment + + # Create environment + environment = Environment( + name="prod", + snapshots=[new_snapshot_a.table_info, new_snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id="previous_plan", + promoted_snapshot_ids=[new_snapshot_a.snapshot_id, new_snapshot_b.snapshot_id], + ) + + # Create evaluatable plan + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[new_snapshot_a, new_snapshot_b], + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={}, + restate_all_snapshots=False, + is_dev=False, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=True, + ignore_cron=False, + directly_modified_snapshots=[new_snapshot_a.snapshot_id], + indirectly_modified_snapshots={ + new_snapshot_a.name: [new_snapshot_b.snapshot_id], + }, + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + # Build plan stages + stages = build_plan_stages(plan, state_reader, None) + + assert len(stages) == 8 + assert isinstance(stages[0], CreateSnapshotRecordsStage) + assert isinstance(stages[1], PhysicalLayerSchemaCreationStage) + assert isinstance(stages[2], EnvironmentRecordUpdateStage) + assert isinstance(stages[3], MigrateSchemasStage) + assert isinstance(stages[4], BackfillStage) + assert isinstance(stages[5], UnpauseStage) + assert isinstance(stages[6], VirtualLayerUpdateStage) + assert isinstance(stages[7], FinalizeEnvironmentStage) + + +def test_build_plan_stages_removed_model( + snapshot_a: Snapshot, 
snapshot_b: Snapshot, mocker: MockerFixture +) -> None: + # Mock state reader + state_reader = mocker.Mock(spec=StateReader) + state_reader.get_snapshots.return_value = { + snapshot_a.snapshot_id: snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + } + existing_environment = Environment( + name="prod", + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + state_reader.get_environment.return_value = existing_environment + + # Create environment + environment = Environment( + snapshots=[snapshot_a.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id="previous_plan", + promoted_snapshot_ids=[snapshot_a.snapshot_id], + ) + + # Create evaluatable plan + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[], + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={}, + restate_all_snapshots=False, + is_dev=False, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[], + indirectly_modified_snapshots={}, + metadata_updated_snapshots=[], + removed_snapshots=[snapshot_b.snapshot_id], + requires_backfill=False, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + # Build plan stages + stages = build_plan_stages(plan, state_reader, None) + + # Verify stages + assert len(stages) == 5 + + assert isinstance(stages[0], PhysicalLayerSchemaCreationStage) + assert isinstance(stages[1], BackfillStage) + assert isinstance(stages[2], EnvironmentRecordUpdateStage) + assert 
isinstance(stages[3], VirtualLayerUpdateStage) + assert isinstance(stages[4], FinalizeEnvironmentStage) + + virtual_layer_update_stage = stages[3] + assert virtual_layer_update_stage.promoted_snapshots == set() + assert virtual_layer_update_stage.demoted_snapshots == {snapshot_b.table_info} + assert ( + virtual_layer_update_stage.demoted_environment_naming_info + == existing_environment.naming_info + ) + + +def test_build_plan_stages_environment_suffix_target_changed( + snapshot_a: Snapshot, snapshot_b: Snapshot, mocker: MockerFixture +) -> None: + # Mock state reader + state_reader = mocker.Mock(spec=StateReader) + state_reader.get_snapshots.return_value = { + snapshot_a.snapshot_id: snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + } + existing_environment = Environment( + name="dev", + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + state_reader.get_environment.return_value = existing_environment + + # Create environment + environment = Environment( + name="dev", + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id="previous_plan", + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + suffix_target=EnvironmentSuffixTarget.TABLE, + ) + + # Create evaluatable plan + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[], + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={}, + restate_all_snapshots=False, + is_dev=True, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[], + 
indirectly_modified_snapshots={}, + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=False, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + # Build plan stages + stages = build_plan_stages(plan, state_reader, None) + + # Verify stages + assert len(stages) == 5 + + assert isinstance(stages[0], PhysicalLayerSchemaCreationStage) + assert isinstance(stages[1], BackfillStage) + assert isinstance(stages[2], EnvironmentRecordUpdateStage) + assert isinstance(stages[3], VirtualLayerUpdateStage) + assert isinstance(stages[4], FinalizeEnvironmentStage) + + virtual_layer_update_stage = stages[3] + assert virtual_layer_update_stage.promoted_snapshots == { + snapshot_a.table_info, + snapshot_b.table_info, + } + assert virtual_layer_update_stage.demoted_snapshots == { + snapshot_a.table_info, + snapshot_b.table_info, + } + assert ( + virtual_layer_update_stage.demoted_environment_naming_info + == existing_environment.naming_info + ) + + +def test_build_plan_stages_indirect_non_breaking_view_migration( + snapshot_a: Snapshot, snapshot_c: Snapshot, make_snapshot, mocker: MockerFixture +) -> None: + # Categorize snapshot_a as forward-only + new_snapshot_a = make_snapshot( + snapshot_a.model.copy(update={"stamp": "new_version"}), + ) + new_snapshot_a.previous_versions = snapshot_a.all_versions + new_snapshot_a.categorize_as(SnapshotChangeCategory.NON_BREAKING) + + new_snapshot_c = make_snapshot( + snapshot_c.model.copy(), + nodes={'"a"': new_snapshot_a.model}, + ) + new_snapshot_c.previous_versions = snapshot_c.all_versions + new_snapshot_c.change_category = SnapshotChangeCategory.INDIRECT_NON_BREAKING + new_snapshot_c.version = new_snapshot_c.previous_version.data_version.version + + state_reader = mocker.Mock(spec=StateReader) + state_reader.get_snapshots.return_value = {} + existing_environment = Environment( + name="prod", + 
snapshots=[snapshot_a.table_info, snapshot_c.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_c.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + state_reader.get_environment.return_value = existing_environment + + # Create environment + environment = Environment( + name="prod", + snapshots=[new_snapshot_a.table_info, new_snapshot_c.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id="previous_plan", + promoted_snapshot_ids=[new_snapshot_a.snapshot_id, new_snapshot_c.snapshot_id], + ) + + # Create evaluatable plan + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[new_snapshot_a, new_snapshot_c], + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={}, + restate_all_snapshots=False, + is_dev=False, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[new_snapshot_a.snapshot_id], + indirectly_modified_snapshots={ + new_snapshot_a.name: [new_snapshot_c.snapshot_id], + }, + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + # Build plan stages + stages = build_plan_stages(plan, state_reader, None) + + # Verify stages + assert len(stages) == 9 + + assert isinstance(stages[0], CreateSnapshotRecordsStage) + assert isinstance(stages[1], PhysicalLayerSchemaCreationStage) + assert isinstance(stages[2], BackfillStage) + assert isinstance(stages[3], EnvironmentRecordUpdateStage) + assert isinstance(stages[4], MigrateSchemasStage) + assert isinstance(stages[5], UnpauseStage) + 
assert isinstance(stages[6], BackfillStage) + assert isinstance(stages[7], VirtualLayerUpdateStage) + assert isinstance(stages[8], FinalizeEnvironmentStage) + + +def test_build_plan_stages_virtual_environment_mode_filtering( + make_snapshot, mocker: MockerFixture +) -> None: + # Create snapshots with different virtual environment modes + snapshot_full = make_snapshot( + SqlModel( + name="full_model", + query=parse_one("select 1, ds"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="ds"), + virtual_environment_mode=VirtualEnvironmentMode.FULL, + ) + ) + snapshot_full.categorize_as(SnapshotChangeCategory.BREAKING) + + snapshot_dev_only = make_snapshot( + SqlModel( + name="dev_only_model", + query=parse_one("select 2, ds"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="ds"), + virtual_environment_mode=VirtualEnvironmentMode.DEV_ONLY, + ) + ) + snapshot_dev_only.categorize_as(SnapshotChangeCategory.BREAKING) + + # Mock state reader + state_reader = mocker.Mock(spec=StateReader) + state_reader.get_snapshots.return_value = {} + state_reader.get_environment.return_value = None + + # Test 1: Dev environment - both snapshots should be included + environment_dev = Environment( + name="dev", + snapshots=[snapshot_full.table_info, snapshot_dev_only.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_full.snapshot_id, snapshot_dev_only.snapshot_id], + ) + + plan_dev = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[snapshot_full, snapshot_dev_only], + environment=environment_dev, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={}, + restate_all_snapshots=False, + is_dev=True, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + 
directly_modified_snapshots=[snapshot_full.snapshot_id, snapshot_dev_only.snapshot_id], + indirectly_modified_snapshots={}, + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + stages_dev = build_plan_stages(plan_dev, state_reader, None) + + # Find VirtualLayerUpdateStage + virtual_stage_dev = next( + stage for stage in stages_dev if isinstance(stage, VirtualLayerUpdateStage) + ) + + # In dev environment, both snapshots should be promoted regardless of virtual_environment_mode + assert {s.name for s in virtual_stage_dev.promoted_snapshots} == { + '"full_model"', + '"dev_only_model"', + } + assert len(virtual_stage_dev.demoted_snapshots) == 0 + + # Test 2: Production environment - only FULL mode snapshots should be included + environment_prod = Environment( + name="prod", + snapshots=[snapshot_full.table_info, snapshot_dev_only.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_full.snapshot_id, snapshot_dev_only.snapshot_id], + ) + + plan_prod = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[snapshot_full, snapshot_dev_only], + environment=environment_prod, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={}, + restate_all_snapshots=False, + is_dev=False, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[snapshot_full.snapshot_id, snapshot_dev_only.snapshot_id], + indirectly_modified_snapshots={}, + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + 
environment_statements=None, + user_provided_flags=None, + ) + + stages_prod = build_plan_stages(plan_prod, state_reader, None) + + # Find VirtualLayerUpdateStage + virtual_stage_prod = next( + stage for stage in stages_prod if isinstance(stage, VirtualLayerUpdateStage) + ) + + # In production environment, only FULL mode snapshots should be promoted + assert {s.name for s in virtual_stage_prod.promoted_snapshots} == {'"full_model"'} + assert len(virtual_stage_prod.demoted_snapshots) == 0 + + # Test 3: Production environment with demoted snapshots + existing_environment = Environment( + name="prod", + snapshots=[snapshot_full.table_info, snapshot_dev_only.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_full.snapshot_id, snapshot_dev_only.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + state_reader.get_environment.return_value = existing_environment + + # Remove both snapshots from the new environment + environment_prod_demote = Environment( + name="prod", + snapshots=[], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id="previous_plan", + promoted_snapshot_ids=[], + ) + + plan_prod_demote = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[], + environment=environment_prod_demote, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={}, + restate_all_snapshots=False, + is_dev=False, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[], + indirectly_modified_snapshots={}, + metadata_updated_snapshots=[], + removed_snapshots=[snapshot_full.snapshot_id, snapshot_dev_only.snapshot_id], + requires_backfill=False, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + 
environment_statements=None, + user_provided_flags=None, + ) + + stages_prod_demote = build_plan_stages(plan_prod_demote, state_reader, None) + + # Find VirtualLayerUpdateStage + virtual_stage_prod_demote = next( + stage for stage in stages_prod_demote if isinstance(stage, VirtualLayerUpdateStage) + ) + + # In production environment, only FULL mode snapshots should be demoted + assert len(virtual_stage_prod_demote.promoted_snapshots) == 0 + assert {s.name for s in virtual_stage_prod_demote.demoted_snapshots} == {'"full_model"'} + assert ( + virtual_stage_prod_demote.demoted_environment_naming_info + == existing_environment.naming_info + ) + + +def test_build_plan_stages_virtual_environment_mode_no_updates( + snapshot_a: Snapshot, make_snapshot, mocker: MockerFixture +) -> None: + # Create snapshot with DEV_ONLY mode + snapshot_dev_only = make_snapshot( + SqlModel( + name="dev_only_model", + query=parse_one("select 1, ds"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="ds"), + virtual_environment_mode=VirtualEnvironmentMode.DEV_ONLY, + ) + ) + snapshot_dev_only.categorize_as(SnapshotChangeCategory.BREAKING) + + # Mock state reader + state_reader = mocker.Mock(spec=StateReader) + state_reader.get_snapshots.return_value = {} + state_reader.get_environment.return_value = None + + # Production environment with only DEV_ONLY snapshots + environment = Environment( + name="prod", + snapshots=[snapshot_dev_only.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_dev_only.snapshot_id], + ) + + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[snapshot_dev_only], + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={}, + restate_all_snapshots=False, + is_dev=False, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + 
ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[snapshot_dev_only.snapshot_id], + indirectly_modified_snapshots={}, + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + stages = build_plan_stages(plan, state_reader, None) + + # No VirtualLayerUpdateStage should be created since all snapshots are filtered out + virtual_stages = [stage for stage in stages if isinstance(stage, VirtualLayerUpdateStage)] + assert len(virtual_stages) == 0 + + +def test_adjust_intervals_new_forward_only_dev_intervals( + make_snapshot, mocker: MockerFixture +) -> None: + forward_only_snapshot = make_snapshot( + SqlModel( + name="forward_only_model", + query=parse_one("select 1, ds"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="ds"), + ) + ) + forward_only_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + forward_only_snapshot.intervals = [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] + + forward_only_snapshot.dev_intervals = [] + + state_reader = mocker.Mock(spec=StateReader) + state_reader.refresh_snapshot_intervals = mocker.Mock() + state_reader.get_snapshots.return_value = {} + state_reader.get_environment.return_value = None + + environment = Environment( + snapshots=[forward_only_snapshot.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id=None, + promoted_snapshot_ids=[forward_only_snapshot.snapshot_id], + ) + + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[forward_only_snapshot], # This snapshot should have dev_intervals set + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={}, + restate_all_snapshots=False, + is_dev=True, # Dev environment + 
allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[], + indirectly_modified_snapshots={}, + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + assert forward_only_snapshot.dev_intervals == [] + + build_plan_stages(plan, state_reader, None) + + assert forward_only_snapshot.dev_intervals == [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) + ] + assert forward_only_snapshot.dev_intervals is not forward_only_snapshot.intervals + + state_reader.refresh_snapshot_intervals.assert_called_once() + + +def test_adjust_intervals_restatement_removal( + snapshot_a: Snapshot, snapshot_b: Snapshot, mocker: MockerFixture +) -> None: + snapshot_a.intervals = [(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))] + snapshot_b.intervals = [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] + + original_a_intervals = snapshot_a.intervals.copy() + original_b_intervals = snapshot_b.intervals.copy() + + state_reader = mocker.Mock(spec=StateReader) + state_reader.refresh_snapshot_intervals = mocker.Mock() + state_reader.get_snapshots.return_value = {} + state_reader.get_environment.return_value = None + state_reader.get_environments_summary.return_value = [] + + environment = Environment( + snapshots=[snapshot_a.table_info, snapshot_b.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id=None, + promoted_snapshot_ids=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + ) + + restatements = { + snapshot_a.name: (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + snapshot_b.name: (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + } + + plan = EvaluatablePlan( + 
start="2023-01-01", + end="2023-01-02", + new_snapshots=[snapshot_a, snapshot_b], + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements=restatements, + restate_all_snapshots=True, + is_dev=False, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[], + indirectly_modified_snapshots={}, + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + stages = build_plan_stages(plan, state_reader, None) + + assert snapshot_a.intervals != original_a_intervals + assert snapshot_b.intervals != original_b_intervals + + state_reader.refresh_snapshot_intervals.assert_called_once() + + restatement_stages = [stage for stage in stages if isinstance(stage, RestatementStage)] + assert len(restatement_stages) == 1 + + backfill_stages = [stage for stage in stages if isinstance(stage, BackfillStage)] + assert len(backfill_stages) == 1 + (snapshot, intervals) = next(iter(backfill_stages[0].snapshot_to_intervals.items())) + assert snapshot.intervals == [(to_timestamp("2023-01-02"), to_timestamp("2023-01-04"))] + assert intervals == [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] + + +def test_adjust_intervals_should_force_rebuild(make_snapshot, mocker: MockerFixture) -> None: + old_snapshot = make_snapshot( + SqlModel( + name="test_model", + query=parse_one("select 1, ds"), + kind=dict(name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, time_column="ds"), + ) + ) + old_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + old_snapshot.intervals = [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] + + new_snapshot = make_snapshot( + SqlModel( + name="test_model", + query=parse_one("select 1, ds"), 
+ kind=dict(name=ModelKindName.FULL), + ) + ) + new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + new_snapshot.version = old_snapshot.version + new_snapshot.intervals = [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] + + state_reader = mocker.Mock(spec=StateReader) + state_reader.refresh_snapshot_intervals = mocker.Mock() + state_reader.get_snapshots.side_effect = [{}, {old_snapshot.snapshot_id: old_snapshot}, {}, {}] + + existing_environment = Environment( + name="prod", + snapshots=[old_snapshot.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="previous_plan", + promoted_snapshot_ids=[old_snapshot.snapshot_id], + finalized_ts=to_timestamp("2023-01-02"), + ) + state_reader.get_environment.return_value = existing_environment + + environment = Environment( + snapshots=[new_snapshot.table_info], + start_at="2023-01-01", + end_at="2023-01-02", + plan_id="test_plan", + previous_plan_id="previous_plan", + promoted_snapshot_ids=[new_snapshot.snapshot_id], + ) + + plan = EvaluatablePlan( + start="2023-01-01", + end="2023-01-02", + new_snapshots=[new_snapshot], + environment=environment, + no_gaps=False, + skip_backfill=False, + empty_backfill=False, + restatements={}, + restate_all_snapshots=False, + is_dev=False, + allow_destructive_models=set(), + allow_additive_models=set(), + forward_only=False, + end_bounded=False, + ensure_finalized_snapshots=False, + ignore_cron=False, + directly_modified_snapshots=[new_snapshot.snapshot_id], + indirectly_modified_snapshots={}, + metadata_updated_snapshots=[], + removed_snapshots=[], + requires_backfill=True, + models_to_backfill=None, + execution_time="2023-01-02", + disabled_restatement_models=set(), + environment_statements=None, + user_provided_flags=None, + ) + + stages = build_plan_stages(plan, state_reader, None) + + state_reader.refresh_snapshot_intervals.assert_called_once() + state_reader.get_environment.assert_called() + + assert not new_snapshot.intervals + backfill_stages = 
[stage for stage in stages if isinstance(stage, BackfillStage)] + assert len(backfill_stages) == 1 + (snapshot, intervals) = next(iter(backfill_stages[0].snapshot_to_intervals.items())) + assert not snapshot.intervals + assert intervals == [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] diff --git a/tests/core/test_rule.py b/tests/core/test_rule.py new file mode 100644 index 0000000000..785988932d --- /dev/null +++ b/tests/core/test_rule.py @@ -0,0 +1,57 @@ +from __future__ import annotations + +import inspect +import typing as t +from unittest.mock import MagicMock + +import pytest +from sqlmesh.core.model import Model +from sqlmesh.core.linter.rule import Rule, RuleViolation + + +class TestRule(Rule): + """A test rule for testing the get_definition_location method.""" + + __test__ = False # prevent pytest warning since this isnt a class containing tests + + def check_model(self, model: Model) -> t.Optional[RuleViolation]: + """The evaluation function that'll check for a violation of this rule.""" + return None + + +def test_get_definition_location(): + """Test the get_definition_location method returns correct file and line information.""" + # Create a mock context + mock_context = MagicMock() + rule = TestRule(mock_context) + + # Get the expected location using the inspect module + expected_file = inspect.getfile(TestRule) + expected_source, expected_start_line = inspect.getsourcelines(TestRule) + expected_end_line = expected_start_line + len(expected_source) - 1 + + # Get the location using the Rule method + location = rule.get_definition_location() + + # Assert the file path matches + assert location.file_path == expected_file + + # Assert the line numbers match + assert location.start_line == expected_start_line + + # Test the fallback case for a class without source + with pytest.MonkeyPatch.context() as mp: + # Mock inspect.getsourcelines to raise an exception + def mock_getsourcelines(*args, **kwargs): + raise IOError("Mock error") + + 
mp.setattr(inspect, "getsourcelines", mock_getsourcelines) + + # Get the location with the mocked function + fallback_location = rule.get_definition_location() + + # It should still have the file path + assert fallback_location.file_path == expected_file + + # But not the line numbers + assert fallback_location.start_line is None diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index a0ef5a89ee..cd32d2451d 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -3,21 +3,39 @@ import pytest from pytest_mock.plugin import MockerFixture from sqlglot import parse_one, parse +from sqlglot.helper import first -from sqlmesh.core.context import Context +from sqlmesh.core.context import Context, ExecutionContext from sqlmesh.core.environment import EnvironmentNamingInfo +from sqlmesh.core.macros import RuntimeStage from sqlmesh.core.model import load_sql_based_model -from sqlmesh.core.model.definition import SqlModel +from sqlmesh.core.model.definition import AuditResult, SqlModel from sqlmesh.core.model.kind import ( IncrementalByTimeRangeKind, IncrementalByUniqueKeyKind, TimeColumn, + SCDType2ByColumnKind, ) from sqlmesh.core.node import IntervalUnit -from sqlmesh.core.scheduler import Scheduler, compute_interval_params -from sqlmesh.core.snapshot import Snapshot, SnapshotEvaluator, SnapshotChangeCategory -from sqlmesh.utils.date import to_datetime -from sqlmesh.utils.errors import CircuitBreakerError +from sqlmesh.core.scheduler import ( + Scheduler, + interval_diff, + compute_interval_params, + SnapshotToIntervals, + EvaluateNode, + SchedulingUnit, + DummyNode, +) +from sqlmesh.core.signal import signal +from sqlmesh.core.snapshot import ( + Snapshot, + SnapshotEvaluator, + SnapshotChangeCategory, + DeployabilityIndex, + snapshots_to_dag, +) +from sqlmesh.utils.date import to_datetime, to_timestamp, DatetimeRanges, TimeLike +from sqlmesh.utils.errors import CircuitBreakerError, NodeAuditsErrors @pytest.fixture @@ -30,6 
+48,11 @@ def orders(sushi_context_fixed_date: Context) -> Snapshot: return sushi_context_fixed_date.get_snapshot("sushi.orders", raise_if_missing=True) +@pytest.fixture +def waiter_names(sushi_context_fixed_date: Context) -> Snapshot: + return sushi_context_fixed_date.get_snapshot("sushi.waiter_names", raise_if_missing=True) + + @pytest.mark.slow def test_interval_params(scheduler: Scheduler, sushi_context_fixed_date: Context, orders: Snapshot): waiter_revenue = sushi_context_fixed_date.get_snapshot( @@ -40,18 +63,30 @@ def test_interval_params(scheduler: Scheduler, sushi_context_fixed_date: Context assert compute_interval_params([orders, waiter_revenue], start=start_ds, end=end_ds) == { orders: [ - (to_datetime(start_ds), to_datetime("2022-01-31")), - (to_datetime("2022-01-31"), to_datetime("2022-02-06")), + (to_timestamp(start_ds), to_timestamp("2022-02-06")), ], waiter_revenue: [ - (to_datetime(start_ds), to_datetime("2022-01-11")), - (to_datetime("2022-01-11"), to_datetime("2022-01-21")), - (to_datetime("2022-01-21"), to_datetime("2022-01-31")), - (to_datetime("2022-01-31"), to_datetime("2022-02-06")), + (to_timestamp(start_ds), to_timestamp("2022-02-06")), ], } +@pytest.fixture +def get_batched_missing_intervals( + mocker: MockerFixture, +) -> t.Callable[[Scheduler, TimeLike, TimeLike, t.Optional[TimeLike]], SnapshotToIntervals]: + def _get_batched_missing_intervals( + scheduler: Scheduler, + start: TimeLike, + end: TimeLike, + execution_time: t.Optional[TimeLike] = None, + ) -> SnapshotToIntervals: + merged_intervals = scheduler.merged_missing_intervals(start, end, execution_time) + return scheduler.batch_intervals(merged_intervals, mocker.Mock(), mocker.Mock()) + + return _get_batched_missing_intervals + + def test_interval_params_nonconsecutive(scheduler: Scheduler, orders: Snapshot): start_ds = "2022-01-01" end_ds = "2022-02-05" @@ -60,8 +95,8 @@ def test_interval_params_nonconsecutive(scheduler: Scheduler, orders: Snapshot): assert 
compute_interval_params([orders], start=start_ds, end=end_ds) == { orders: [ - (to_datetime(start_ds), to_datetime("2022-01-10")), - (to_datetime("2022-01-16"), to_datetime("2022-02-06")), + (to_timestamp(start_ds), to_timestamp("2022-01-10")), + (to_timestamp("2022-01-16"), to_timestamp("2022-02-06")), ] } @@ -77,7 +112,7 @@ def test_interval_params_missing(scheduler: Scheduler, sushi_context_fixed_date: assert compute_interval_params( sushi_context_fixed_date.snapshots.values(), start=start_ds, end=end_ds )[waiters] == [ - (to_datetime(start_ds), to_datetime("2022-03-02")), + (to_timestamp(start_ds), to_timestamp("2022-03-02")), ] @@ -99,7 +134,9 @@ def test_run(sushi_context_fixed_date: Context, scheduler: Scheduler): ) == (0, "Hotate", 5.99) -def test_incremental_by_unique_key_kind_dag(mocker: MockerFixture, make_snapshot): +def test_incremental_by_unique_key_kind_dag( + mocker: MockerFixture, make_snapshot, get_batched_missing_intervals +): """ Test that when given a week of data that it batches dates together. 
""" @@ -116,7 +153,7 @@ def test_incremental_by_unique_key_kind_dag(mocker: MockerFixture, make_snapshot query=parse_one("SELECT id FROM VALUES (1), (2) AS t(id)"), ), ) - snapshot_evaluator = SnapshotEvaluator(adapter=mocker.MagicMock(), ddl_concurrent_tasks=1) + snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) mock_state_sync = mocker.MagicMock() scheduler = Scheduler( snapshots=[unique_by_key_snapshot], @@ -125,18 +162,20 @@ def test_incremental_by_unique_key_kind_dag(mocker: MockerFixture, make_snapshot max_workers=2, default_catalog=None, ) - batches = scheduler.batches(start, end, end) + batches = get_batched_missing_intervals(scheduler, start, end, end) dag = scheduler._dag(batches) assert dag.graph == { - ( + EvaluateNode( unique_by_key_snapshot.name, - ((to_datetime("2023-01-01"), to_datetime("2023-01-07")), 0), + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-07")), + batch_index=0, ): set(), } - mock_state_sync.refresh_snapshot_intervals.assert_called_once() -def test_incremental_time_self_reference_dag(mocker: MockerFixture, make_snapshot): +def test_incremental_time_self_reference_dag( + mocker: MockerFixture, make_snapshot, get_batched_missing_intervals +): """ Test that we always process a day at a time and all future days rely on the previous day """ @@ -156,7 +195,7 @@ def test_incremental_time_self_reference_dag(mocker: MockerFixture, make_snapsho incremental_self_snapshot.add_interval("2023-01-02", "2023-01-02") incremental_self_snapshot.add_interval("2023-01-05", "2023-01-05") - snapshot_evaluator = SnapshotEvaluator(adapter=mocker.MagicMock(), ddl_concurrent_tasks=1) + snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) scheduler = Scheduler( snapshots=[incremental_self_snapshot], snapshot_evaluator=snapshot_evaluator, @@ -164,65 +203,71 @@ def test_incremental_time_self_reference_dag(mocker: MockerFixture, make_snapsho max_workers=2, 
default_catalog=None, ) - batches = scheduler.batches(start, end, end) + batches = get_batched_missing_intervals(scheduler, start, end, end) dag = scheduler._dag(batches) assert dag.graph == { # Only run one day at a time and each day relies on the previous days - ( + EvaluateNode( incremental_self_snapshot.name, - ((to_datetime("2023-01-01"), to_datetime("2023-01-02")), 0), + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, ): set(), - ( + EvaluateNode( incremental_self_snapshot.name, - ((to_datetime("2023-01-03"), to_datetime("2023-01-04")), 1), + interval=(to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + batch_index=1, ): { - ( - incremental_self_snapshot.name, - ((to_datetime("2023-01-01"), to_datetime("2023-01-02")), 0), - ) + EvaluateNode( + snapshot_name=incremental_self_snapshot.name, + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, + ), }, - ( + EvaluateNode( incremental_self_snapshot.name, - ((to_datetime("2023-01-04"), to_datetime("2023-01-05")), 2), + interval=(to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + batch_index=2, ): { - ( - incremental_self_snapshot.name, - ((to_datetime("2023-01-03"), to_datetime("2023-01-04")), 1), + EvaluateNode( + snapshot_name=incremental_self_snapshot.name, + interval=(to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + batch_index=1, ), }, - ( - incremental_self_snapshot.name, - ((to_datetime("2023-01-06"), to_datetime("2023-01-07")), 3), + EvaluateNode( + snapshot_name=incremental_self_snapshot.name, + interval=(to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + batch_index=3, ): { - ( - incremental_self_snapshot.name, - ((to_datetime("2023-01-04"), to_datetime("2023-01-05")), 2), + EvaluateNode( + snapshot_name=incremental_self_snapshot.name, + interval=(to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + batch_index=2, + ), + }, + DummyNode(snapshot_name=incremental_self_snapshot.name): { + EvaluateNode( 
+ snapshot_name=incremental_self_snapshot.name, + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, + ), + EvaluateNode( + snapshot_name=incremental_self_snapshot.name, + interval=(to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + batch_index=1, + ), + EvaluateNode( + snapshot_name=incremental_self_snapshot.name, + interval=(to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + batch_index=2, + ), + EvaluateNode( + snapshot_name=incremental_self_snapshot.name, + interval=(to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + batch_index=3, ), }, - ( - incremental_self_snapshot.name, - ((to_datetime(0), to_datetime(0)), -1), - ): set( - [ - ( - incremental_self_snapshot.name, - ((to_datetime("2023-01-01"), to_datetime("2023-01-02")), 0), - ), - ( - incremental_self_snapshot.name, - ((to_datetime("2023-01-03"), to_datetime("2023-01-04")), 1), - ), - ( - incremental_self_snapshot.name, - ((to_datetime("2023-01-04"), to_datetime("2023-01-05")), 2), - ), - ( - incremental_self_snapshot.name, - ((to_datetime("2023-01-06"), to_datetime("2023-01-07")), 3), - ), - ] - ), } @@ -233,16 +278,26 @@ def test_incremental_time_self_reference_dag(mocker: MockerFixture, make_snapsho 2, 2, { - ( - '"test_model"', - ((to_datetime("2023-01-01"), to_datetime("2023-01-03")), 0), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-03")), + batch_index=0, ): set(), - ( - '"test_model"', - ((to_datetime("2023-01-03"), to_datetime("2023-01-05")), 1), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-03"), to_timestamp("2023-01-05")), + batch_index=1, ): set(), - ('"test_model"', ((to_datetime("2023-01-05"), to_datetime("2023-01-07")), 2)): { - ('"test_model"', ((to_datetime("2023-01-01"), to_datetime("2023-01-03")), 0)), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-05"), to_timestamp("2023-01-07")), + 
batch_index=2, + ): { + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-03")), + batch_index=0, + ), }, }, ), @@ -250,26 +305,53 @@ def test_incremental_time_self_reference_dag(mocker: MockerFixture, make_snapsho 1, 3, { - ( - '"test_model"', - ((to_datetime("2023-01-01"), to_datetime("2023-01-02")), 0), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, ): set(), - ( - '"test_model"', - ((to_datetime("2023-01-02"), to_datetime("2023-01-03")), 1), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + batch_index=1, ): set(), - ( - '"test_model"', - ((to_datetime("2023-01-03"), to_datetime("2023-01-04")), 2), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + batch_index=2, ): set(), - ('"test_model"', ((to_datetime("2023-01-04"), to_datetime("2023-01-05")), 3)): { - ('"test_model"', ((to_datetime("2023-01-01"), to_datetime("2023-01-02")), 0)), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + batch_index=3, + ): { + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, + ), }, - ('"test_model"', ((to_datetime("2023-01-05"), to_datetime("2023-01-06")), 4)): { - ('"test_model"', ((to_datetime("2023-01-02"), to_datetime("2023-01-03")), 1)), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + batch_index=4, + ): { + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + batch_index=1, + ), }, - ('"test_model"', ((to_datetime("2023-01-06"), to_datetime("2023-01-07")), 5)): { - ('"test_model"', ((to_datetime("2023-01-03"), 
to_datetime("2023-01-04")), 2)), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + batch_index=5, + ): { + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + batch_index=2, + ), }, }, ), @@ -277,29 +359,35 @@ def test_incremental_time_self_reference_dag(mocker: MockerFixture, make_snapsho 1, 10, { - ( - '"test_model"', - ((to_datetime("2023-01-01"), to_datetime("2023-01-02")), 0), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, ): set(), - ( - '"test_model"', - ((to_datetime("2023-01-02"), to_datetime("2023-01-03")), 1), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + batch_index=1, ): set(), - ( - '"test_model"', - ((to_datetime("2023-01-03"), to_datetime("2023-01-04")), 2), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + batch_index=2, ): set(), - ( - '"test_model"', - ((to_datetime("2023-01-04"), to_datetime("2023-01-05")), 3), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + batch_index=3, ): set(), - ( - '"test_model"', - ((to_datetime("2023-01-05"), to_datetime("2023-01-06")), 4), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + batch_index=4, ): set(), - ( - '"test_model"', - ((to_datetime("2023-01-06"), to_datetime("2023-01-07")), 5), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + batch_index=5, ): set(), }, ), @@ -307,9 +395,10 @@ def test_incremental_time_self_reference_dag(mocker: MockerFixture, make_snapsho 10, 10, { - ( - '"test_model"', - ((to_datetime("2023-01-01"), 
to_datetime("2023-01-07")), 0), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-07")), + batch_index=0, ): set(), }, ), @@ -317,9 +406,10 @@ def test_incremental_time_self_reference_dag(mocker: MockerFixture, make_snapsho 10, 1, { - ( - '"test_model"', - ((to_datetime("2023-01-01"), to_datetime("2023-01-07")), 0), + EvaluateNode( + snapshot_name='"test_model"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-07")), + batch_index=0, ): set(), }, ), @@ -328,9 +418,10 @@ def test_incremental_time_self_reference_dag(mocker: MockerFixture, make_snapsho def test_incremental_batch_concurrency( mocker: MockerFixture, make_snapshot, + get_batched_missing_intervals, batch_size: int, batch_concurrency: int, - expected_graph: t.Dict[str, t.Any], + expected_graph: t.Dict[SchedulingUnit, t.Set[SchedulingUnit]], ): start = to_datetime("2023-01-01") end = to_datetime("2023-01-07") @@ -346,7 +437,7 @@ def test_incremental_batch_concurrency( ), ) - snapshot_evaluator = SnapshotEvaluator(adapter=mocker.MagicMock(), ddl_concurrent_tasks=1) + snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) mock_state_sync = mocker.MagicMock() scheduler = Scheduler( snapshots=[snapshot], @@ -356,9 +447,9 @@ def test_incremental_batch_concurrency( default_catalog=None, ) - batches = scheduler.batches(start, end, end) + batches = get_batched_missing_intervals(scheduler, start, end, end) dag = scheduler._dag(batches) - graph = {k: v for k, v in dag.graph.items() if k[1][1] != -1} # exclude the terminal node.} + graph = {k: v for k, v in dag.graph.items() if isinstance(k, EvaluateNode)} assert graph == expected_graph @@ -373,7 +464,9 @@ def test_circuit_breaker(scheduler: Scheduler): ) -def test_intervals_with_end_date_on_model(mocker: MockerFixture, make_snapshot): +def test_intervals_with_end_date_on_model( + mocker: MockerFixture, make_snapshot, get_batched_missing_intervals +): 
snapshot: Snapshot = make_snapshot( SqlModel( name="name", @@ -385,7 +478,7 @@ def test_intervals_with_end_date_on_model(mocker: MockerFixture, make_snapshot): ) ) - snapshot_evaluator = SnapshotEvaluator(adapter=mocker.MagicMock(), ddl_concurrent_tasks=1) + snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) scheduler = Scheduler( snapshots=[snapshot], snapshot_evaluator=snapshot_evaluator, @@ -396,27 +489,35 @@ def test_intervals_with_end_date_on_model(mocker: MockerFixture, make_snapshot): # generate for 1 year to show that the returned batches should only cover # the range defined on the model itself - batches = scheduler.batches(start="2023-01-01", end="2024-01-01")[snapshot] + batches = get_batched_missing_intervals(scheduler, start="2023-01-01", end="2024-01-01")[ + snapshot + ] assert len(batches) == 31 # days in Jan 2023 - assert batches[0] == (to_datetime("2023-01-01"), to_datetime("2023-01-02")) - assert batches[-1] == (to_datetime("2023-01-31"), to_datetime("2023-02-01")) + assert batches[0] == (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) + assert batches[-1] == (to_timestamp("2023-01-31"), to_timestamp("2023-02-01")) # generate for less than 1 month to ensure that the scheduler end date # takes precedence over the model end date - batches = scheduler.batches(start="2023-01-01", end="2023-01-10")[snapshot] + batches = get_batched_missing_intervals(scheduler, start="2023-01-01", end="2023-01-10")[ + snapshot + ] assert len(batches) == 10 - assert batches[0] == (to_datetime("2023-01-01"), to_datetime("2023-01-02")) - assert batches[-1] == (to_datetime("2023-01-10"), to_datetime("2023-01-11")) + assert batches[0] == (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")) + assert batches[-1] == (to_timestamp("2023-01-10"), to_timestamp("2023-01-11")) # generate for the last day of range - batches = scheduler.batches(start="2023-01-31", end="2023-01-31")[snapshot] + batches = 
get_batched_missing_intervals(scheduler, start="2023-01-31", end="2023-01-31")[ + snapshot + ] assert len(batches) == 1 - assert batches[0] == (to_datetime("2023-01-31"), to_datetime("2023-02-01")) + assert batches[0] == (to_timestamp("2023-01-31"), to_timestamp("2023-02-01")) # generate for future days to ensure no future batches are loaded - snapshot_to_batches = scheduler.batches(start="2023-02-01", end="2023-02-28") + snapshot_to_batches = get_batched_missing_intervals( + scheduler, start="2023-02-01", end="2023-02-28" + ) assert len(snapshot_to_batches) == 0 @@ -439,7 +540,7 @@ def test_external_model_audit(mocker, make_snapshot): snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - evaluator = SnapshotEvaluator(adapter=mocker.MagicMock()) + evaluator = SnapshotEvaluator(adapters=mocker.MagicMock()) spy = mocker.spy(evaluator, "_audit") scheduler = Scheduler( @@ -460,89 +561,655 @@ def test_external_model_audit(mocker, make_snapshot): spy.assert_called_once() -def test_contiguous_intervals(): - from sqlmesh.core.scheduler import _contiguous_intervals as ci +def test_audit_failure_notifications( + scheduler: Scheduler, waiter_names: Snapshot, mocker: MockerFixture +): + evaluator_evaluate_mock = mocker.Mock() + mocker.patch("sqlmesh.core.scheduler.SnapshotEvaluator.evaluate", evaluator_evaluate_mock) + evaluator_audit_mock = mocker.Mock() + mocker.patch("sqlmesh.core.scheduler.SnapshotEvaluator.audit", evaluator_audit_mock) + notify_user_mock = mocker.Mock() + mocker.patch( + "sqlmesh.core.notification_target.NotificationTargetManager.notify_user", notify_user_mock + ) + notify_mock = mocker.Mock() + mocker.patch("sqlmesh.core.notification_target.NotificationTargetManager.notify", notify_mock) + + audit = first(waiter_names.model.audit_definitions.values()) + query = waiter_names.model.render_query() + + def _evaluate(): + scheduler.evaluate( + waiter_names, + to_datetime("2022-01-01"), + to_datetime("2022-01-02"), + 
to_datetime("2022-01-03"), + DeployabilityIndex.all_deployable(), + 0, + ) + + evaluator_audit_mock.return_value = [ + AuditResult( + audit=audit, + audit_args={}, + model=waiter_names.model, + query=query, + count=0, + skipped=False, + ) + ] + _evaluate() + assert notify_user_mock.call_count == 0 + assert notify_mock.call_count == 0 + + evaluator_audit_mock.return_value = [ + AuditResult( + audit=audit, + audit_args={}, + model=waiter_names.model, + query=query, + count=None, + skipped=True, + ) + ] + _evaluate() + assert notify_user_mock.call_count == 0 + assert notify_mock.call_count == 0 + + evaluator_audit_mock.return_value = [ + AuditResult( + audit=audit, + audit_args={}, + model=waiter_names.model, + query=query, + count=1, + skipped=False, + blocking=False, + ) + ] + _evaluate() + assert notify_user_mock.call_count == 1 + assert notify_mock.call_count == 1 + notify_user_mock.reset_mock() + notify_mock.reset_mock() - assert ci([]) == [] - assert ci([(0, 1)]) == [[(0, 1)]] - assert ci([(0, 1), (1, 2), (2, 3)]) == [[(0, 1), (1, 2), (2, 3)]] - assert ci([(0, 1), (3, 4), (4, 5), (6, 7)]) == [ - [(0, 1)], - [(3, 4), (4, 5)], - [(6, 7)], + evaluator_audit_mock.return_value = [ + AuditResult( + audit=audit, + audit_args={}, + model=waiter_names.model, + query=query, + count=1, + skipped=False, + ) ] + with pytest.raises(NodeAuditsErrors): + _evaluate() + assert notify_user_mock.call_count == 1 + assert notify_mock.call_count == 1 + + +def test_interval_diff(): + assert interval_diff([(1, 2)], []) == [(1, 2)] + assert interval_diff([(1, 2)], [(1, 2)]) == [] + assert interval_diff([(1, 2)], [(0, 2)]) == [] + assert interval_diff([(1, 2)], [(2, 3)]) == [(1, 2)] + assert interval_diff([(1, 2)], [(0, 1)]) == [(1, 2)] + assert interval_diff([(1, 2), (2, 3), (3, 4)], [(1, 4)]) == [] + assert interval_diff([(1, 2), (2, 3), (3, 4)], [(1, 2)]) == [(2, 3), (3, 4)] + assert interval_diff([(4, 5)], [(1, 2), (2, 3)]) == [(4, 5)] + assert interval_diff( + [(1, 2), (2, 3), (3, 
4), (4, 5), (5, 6)], + [(2, 3), (4, 6)], + ) == [(1, 2), (3, 4)] + + assert interval_diff( + [(1, 2), (2, 3), (3, 4)], + [(1, 3)], + ) == [(3, 4)] + + assert interval_diff( + [(1, 3), (3, 4)], + [(1, 2), (2, 3)], + ) == [(3, 4)] + + assert interval_diff([(1, 2), (2, 3)], [(1, 2)], uninterrupted=True) == [] + assert interval_diff([(1, 2), (2, 3)], [(3, 4)], uninterrupted=True) == [(1, 2), (2, 3)] + assert interval_diff([(1, 2), (2, 3)], [(2, 3)], uninterrupted=True) == [(1, 2)] -def test_check_ready_intervals(mocker: MockerFixture): - from sqlmesh.core.scheduler import _check_ready_intervals, Interval - - def const_signal(const): - signal_mock = mocker.Mock() - signal_mock.check_intervals = mocker.MagicMock(return_value=const) - return signal_mock - - def assert_always_signal(intervals): - _check_ready_intervals(const_signal(True), intervals) == intervals - - assert_always_signal([]) - assert_always_signal([(0, 1)]) - assert_always_signal([(0, 1), (1, 2)]) - assert_always_signal([(0, 1), (2, 3)]) - - def assert_never_signal(intervals): - _check_ready_intervals(const_signal(False), intervals) == [] - - assert_never_signal([]) - assert_never_signal([(0, 1)]) - assert_never_signal([(0, 1), (1, 2)]) - assert_never_signal([(0, 1), (2, 3)]) - - def to_intervals(values: t.List[t.Tuple[int, int]]) -> t.List[Interval]: - return [(to_datetime(s), to_datetime(e)) for s, e in values] - - def assert_check_intervals( - intervals: t.List[t.Tuple[int, int]], - ready: t.List[t.List[t.Tuple[int, int]]], - expected: t.List[t.Tuple[int, int]], - ): - signal_mock = mocker.Mock() - signal_mock.check_intervals = mocker.MagicMock(side_effect=[to_intervals(r) for r in ready]) - _check_ready_intervals(signal_mock, intervals) == expected - - assert_check_intervals([], [], []) - assert_check_intervals([(0, 1)], [[]], []) - assert_check_intervals( - [(0, 1)], - [[(0, 1)]], - [(0, 1)], - ) - assert_check_intervals( - [(0, 1), (1, 2)], - [[(0, 1)]], - [(0, 1)], - ) - assert_check_intervals( - 
[(0, 1), (1, 2)], - [[(1, 2)]], - [(1, 2)], - ) - assert_check_intervals( - [(0, 1), (1, 2)], - [[(0, 1), (1, 2)]], - [(0, 1), (1, 2)], - ) - assert_check_intervals( - [(0, 1), (1, 2), (3, 4)], - [[], []], - [], - ) - assert_check_intervals( - [(0, 1), (1, 2), (3, 4)], - [[(0, 1)], []], - [(0, 1)], - ) - assert_check_intervals( - [(0, 1), (1, 2), (3, 4)], - [[(0, 1)], [(3, 4)]], - [(0, 1), (3, 4)], +def test_signal_intervals(mocker: MockerFixture, make_snapshot, get_batched_missing_intervals): + @signal() + def signal_a(batch: DatetimeRanges, context: ExecutionContext): + if not hasattr(context, "engine_adapter"): + raise + return [batch[0], batch[1]] + + @signal() + def signal_b(batch: DatetimeRanges): + return batch[-49:] + + signals = signal.get_registry() + + a = make_snapshot( + load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name a, + kind FULL, + start '2023-01-01', + signals SIGNAL_A(), + ); + + SELECT 1 x; + """ + ), + signal_definitions=signals, + ), ) + + b = make_snapshot( + load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name b, + kind FULL, + cron '@hourly', + start '2023-01-01', + signals SIGNAL_B(), + ); + + SELECT 2 x; + """ + ), + signal_definitions=signals, + ), + nodes={a.name: a.model}, + ) + + c = make_snapshot( + load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name c, + kind FULL, + start '2023-01-01', + ); + + SELECT * FROM a UNION SELECT * FROM b + """ + ), + signal_definitions=signals, + ), + nodes={a.name: a.model, b.name: b.model}, + ) + d = make_snapshot( + load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name d, + kind FULL, + start '2023-01-01', + ); + + SELECT * FROM c UNION SELECT * FROM d + """ + ), + signal_definitions=signals, + ), + nodes={a.name: a.model, b.name: b.model, c.name: c.model}, + ) + + snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) + scheduler = Scheduler( + snapshots=[a, b, c, d], + 
snapshot_evaluator=snapshot_evaluator, + state_sync=mocker.MagicMock(), + max_workers=2, + default_catalog=None, + ) + + batches = get_batched_missing_intervals(scheduler, "2023-01-01", "2023-01-03", None) + + assert batches == { + a: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-03"))], + b: [(to_timestamp("2023-01-01 23:00:00"), to_timestamp("2023-01-04"))], + # Full models and models that depend on past can't run for a discontinuous range + c: [], + d: [], + } + + +def test_signals_snapshots_out_of_order( + mocker: MockerFixture, make_snapshot, get_batched_missing_intervals +): + @signal() + def signal_base(batch: DatetimeRanges): + return [batch[0]] + + signals = signal.get_registry() + + snapshot_a = make_snapshot( + load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name a, + kind INCREMENTAL_BY_TIME_RANGE( + lookback 1, + time_column dt, + ), + start '2023-01-01', + signals SIGNAL_BASE(), + ); + SELECT @start_date AS dt; + """ + ), + signal_definitions=signals, + ), + ) + + snapshot_b = make_snapshot( + load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name b, + kind INCREMENTAL_BY_TIME_RANGE( + lookback 1, + time_column dt, + ), + start '2023-01-01' + ); + SELECT @start_date AS dt; + """ + ), + signal_definitions=signals, + ) + ) + + snapshot_c = make_snapshot( + load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name c, + kind INCREMENTAL_BY_TIME_RANGE( + lookback 1, + time_column dt, + ), + start '2023-01-01', + ); + SELECT * FROM a UNION SELECT * FROM b + """ + ), + signal_definitions=signals, + ), + nodes={snapshot_a.name: snapshot_a.model, snapshot_b.name: snapshot_b.model}, + ) + + snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) + scheduler = Scheduler( + snapshots=[snapshot_c, snapshot_b, snapshot_a], # reverse order + snapshot_evaluator=snapshot_evaluator, + state_sync=mocker.MagicMock(), + max_workers=2, + default_catalog=None, + ) + + batches = 
get_batched_missing_intervals(scheduler, "2023-01-01", "2023-01-03", None) + + assert batches == { + snapshot_a: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))], + snapshot_b: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-04"))], + snapshot_c: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))], + } + + +@pytest.mark.parametrize( + "batch_size, expected_batches", + [ + ( + 1, + [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + ], + ), + ( + None, + [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-04")), + ], + ), + ], +) +def test_scd_type_2_batch_size( + mocker: MockerFixture, + make_snapshot, + get_batched_missing_intervals, + batch_size: t.Optional[int], + expected_batches: t.List[t.Tuple[int, int]], +): + """ + Test that SCD_TYPE_2_BY_COLUMN models are batched correctly based on batch_size. + With batch_size=1, we expect 3 separate batches for 3 days. + Without a specified batch_size, we expect a single batch for the entire period. 
+ """ + start = to_datetime("2023-01-01") + end = to_datetime("2023-01-04") + + # Configure kind params + kind_params = {} + if batch_size is not None: + kind_params["batch_size"] = batch_size + + # Create the model and snapshot + model = SqlModel( + name="test_scd_model", + kind=SCDType2ByColumnKind(columns="valid_to", unique_key=["id"], **kind_params), + cron="@daily", + start=start, + query=parse_one("SELECT id, valid_from, valid_to FROM source"), + ) + snapshot = make_snapshot(model) + + # Setup scheduler + snapshot_evaluator = SnapshotEvaluator(adapters=mocker.MagicMock(), ddl_concurrent_tasks=1) + scheduler = Scheduler( + snapshots=[snapshot], + snapshot_evaluator=snapshot_evaluator, + state_sync=mocker.MagicMock(), + max_workers=2, + default_catalog=None, + ) + + # Get batches for the time period + batches = get_batched_missing_intervals(scheduler, start, end, end)[snapshot] + + # Verify batches match expectations + assert batches == expected_batches + + +def test_before_all_environment_statements_called_first(mocker: MockerFixture, make_snapshot): + model = SqlModel( + name="test.model_items", + query=parse_one("SELECT id, ds FROM raw.items"), + kind=IncrementalByTimeRangeKind(time_column=TimeColumn(column="ds")), + ) + snapshot = make_snapshot(model) + + # to track the order of calls + call_order = [] + + mock_state_sync = mocker.MagicMock() + mock_state_sync.get_environment_statements.return_value = [ + ("CREATE TABLE IF NOT EXISTS test_table (id INT)", RuntimeStage.BEFORE_ALL) + ] + + def record_get_environment_statements(*args, **kwargs): + call_order.append("get_environment_statements") + return mock_state_sync.get_environment_statements.return_value + + mock_state_sync.get_environment_statements.side_effect = record_get_environment_statements + + mock_snapshot_evaluator = mocker.MagicMock() + mock_adapter = mocker.MagicMock() + mock_snapshot_evaluator.adapter = mock_adapter + + def record_get_snapshots_to_create(*args, **kwargs): + 
call_order.append("get_snapshots_to_create") + return [] + + mock_snapshot_evaluator.get_snapshots_to_create.side_effect = record_get_snapshots_to_create + + mock_execute_env_statements = mocker.patch( + "sqlmesh.core.scheduler.execute_environment_statements" + ) + + def record_execute_environment_statements(*args, **kwargs): + call_order.append("execute_environment_statements") + + mock_execute_env_statements.side_effect = record_execute_environment_statements + + scheduler = Scheduler( + snapshots=[snapshot], + snapshot_evaluator=mock_snapshot_evaluator, + state_sync=mock_state_sync, + default_catalog=None, + ) + merged_intervals = { + snapshot: [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + ], + } + + deployability_index = DeployabilityIndex.create([snapshot]) + environment_naming_info = EnvironmentNamingInfo(name="test_env") + + scheduler.run_merged_intervals( + merged_intervals=merged_intervals, + deployability_index=deployability_index, + environment_naming_info=environment_naming_info, + run_environment_statements=True, + ) + + mock_state_sync.get_environment_statements.assert_called_once_with("test_env") + mock_snapshot_evaluator.get_snapshots_to_create.assert_called_once() + + # execute_environment_statements is called twice + assert mock_execute_env_statements.call_count == 2 + + # first for before all and second for after all + first_call = mock_execute_env_statements.call_args_list[0] + assert first_call.kwargs["runtime_stage"] == RuntimeStage.BEFORE_ALL + second_call = mock_execute_env_statements.call_args_list[1] + assert second_call.kwargs["runtime_stage"] == RuntimeStage.AFTER_ALL + + assert "get_environment_statements" in call_order + assert "execute_environment_statements" in call_order + assert "get_snapshots_to_create" in call_order + + # Verify the before all environment statements are called first before get_snapshots_to_create + env_statements_idx = call_order.index("get_environment_statements") + execute_env_idx = 
call_order.index("execute_environment_statements") + snapshots_to_create_idx = call_order.index("get_snapshots_to_create") + assert env_statements_idx < execute_env_idx < snapshots_to_create_idx + + +def test_dag_transitive_deps(mocker: MockerFixture, make_snapshot): + # Create a simple dependency chain: A <- B <- C + snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("SELECT 1 as id"))) + snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("SELECT * FROM a"))) + snapshot_c = make_snapshot(SqlModel(name="c", query=parse_one("SELECT * FROM b"))) + + snapshot_b = snapshot_b.model_copy(update={"parents": (snapshot_a.snapshot_id,)}) + snapshot_c = snapshot_c.model_copy(update={"parents": (snapshot_b.snapshot_id,)}) + + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING) + + scheduler = Scheduler( + snapshots=[snapshot_a, snapshot_b, snapshot_c], + snapshot_evaluator=mocker.Mock(), + state_sync=mocker.Mock(), + default_catalog=None, + ) + + # Test scenario: select only A and C (skip B) + merged_intervals = { + snapshot_a: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))], + snapshot_c: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))], + } + + deployability_index = DeployabilityIndex.create([snapshot_a, snapshot_b, snapshot_c]) + + full_dag = snapshots_to_dag([snapshot_a, snapshot_b, snapshot_c]) + + dag = scheduler._dag(merged_intervals, snapshot_dag=full_dag) + assert dag.graph == { + EvaluateNode( + snapshot_name='"a"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, + ): set(), + EvaluateNode( + snapshot_name='"c"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, + ): { + EvaluateNode( + snapshot_name='"a"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, + ) + }, + } + + +def 
test_dag_multiple_chain_transitive_deps(mocker: MockerFixture, make_snapshot): + # Create a more complex dependency graph: + # A <- B <- D <- E + # A <- C <- D <- E + # Select A and E only + snapshots = {} + for name in ["a", "b", "c", "d", "e"]: + snapshots[name] = make_snapshot(SqlModel(name=name, query=parse_one("SELECT 1 as id"))) + snapshots[name].categorize_as(SnapshotChangeCategory.BREAKING) + + # Set up dependencies + snapshots["b"] = snapshots["b"].model_copy(update={"parents": (snapshots["a"].snapshot_id,)}) + snapshots["c"] = snapshots["c"].model_copy(update={"parents": (snapshots["a"].snapshot_id,)}) + snapshots["d"] = snapshots["d"].model_copy( + update={"parents": (snapshots["b"].snapshot_id, snapshots["c"].snapshot_id)} + ) + snapshots["e"] = snapshots["e"].model_copy(update={"parents": (snapshots["d"].snapshot_id,)}) + + scheduler = Scheduler( + snapshots=list(snapshots.values()), + snapshot_evaluator=mocker.Mock(), + state_sync=mocker.Mock(), + default_catalog=None, + ) + + # Only provide intervals for A and E + batched_intervals = { + snapshots["a"]: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))], + snapshots["e"]: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))], + } + + # Create subdag including transitive dependencies + full_dag = snapshots_to_dag(snapshots.values()) + + dag = scheduler._dag(batched_intervals, snapshot_dag=full_dag) + assert dag.graph == { + EvaluateNode( + snapshot_name='"a"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, + ): set(), + EvaluateNode( + snapshot_name='"e"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, + ): { + EvaluateNode( + snapshot_name='"a"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, + ) + }, + } + + +def test_dag_upstream_dependency_caching_with_complex_diamond(mocker: MockerFixture, make_snapshot): + r""" + Test that the upstream dependency caching correctly 
handles a complex diamond dependency graph. + + Dependency graph: + A (has intervals) + / \ + B C (no intervals - transitive) + / \ / \ + D E F (no intervals - transitive) + \ / \ / + G H (has intervals - selected) + + This creates multiple paths from G and H to A. Without caching, A's dependencies would be + computed multiple times (once for each path). With caching, they should be computed once + and reused. + """ + snapshots = {} + + for name in ["a", "b", "c", "d", "e", "f", "g", "h"]: + snapshots[name] = make_snapshot(SqlModel(name=name, query=parse_one("SELECT 1 as id"))) + snapshots[name].categorize_as(SnapshotChangeCategory.BREAKING) + + # A is the root + snapshots["b"] = snapshots["b"].model_copy(update={"parents": (snapshots["a"].snapshot_id,)}) + snapshots["c"] = snapshots["c"].model_copy(update={"parents": (snapshots["a"].snapshot_id,)}) + + # Middle layer: D, E, F depend on B and/or C + snapshots["d"] = snapshots["d"].model_copy(update={"parents": (snapshots["b"].snapshot_id,)}) + snapshots["e"] = snapshots["e"].model_copy( + update={"parents": (snapshots["b"].snapshot_id, snapshots["c"].snapshot_id)} + ) + snapshots["f"] = snapshots["f"].model_copy(update={"parents": (snapshots["c"].snapshot_id,)}) + + # Bottom layer: G and H depend on D/E and E/F respectively + snapshots["g"] = snapshots["g"].model_copy( + update={"parents": (snapshots["d"].snapshot_id, snapshots["e"].snapshot_id)} + ) + snapshots["h"] = snapshots["h"].model_copy( + update={"parents": (snapshots["e"].snapshot_id, snapshots["f"].snapshot_id)} + ) + + scheduler = Scheduler( + snapshots=list(snapshots.values()), + snapshot_evaluator=mocker.Mock(), + state_sync=mocker.Mock(), + default_catalog=None, + ) + + batched_intervals = { + snapshots["a"]: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))], + snapshots["g"]: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))], + snapshots["h"]: [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))], + } + + full_dag = 
snapshots_to_dag(snapshots.values()) + dag = scheduler._dag(batched_intervals, snapshot_dag=full_dag) + + # Verify the DAG structure: + # 1. A should be evaluated first (no dependencies) + # 2. Both G and H should depend on A (through transitive dependencies) + # 3. Transitive nodes (B, C, D, E, F) should not appear as separate evaluation nodes + expected_a_node = EvaluateNode( + snapshot_name='"a"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, + ) + + expected_g_node = EvaluateNode( + snapshot_name='"g"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, + ) + + expected_h_node = EvaluateNode( + snapshot_name='"h"', + interval=(to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + batch_index=0, + ) + + assert dag.graph == { + expected_a_node: set(), + expected_g_node: {expected_a_node}, + expected_h_node: {expected_a_node}, + } diff --git a/tests/core/test_schema_diff.py b/tests/core/test_schema_diff.py index f5581c588e..52bd6bb606 100644 --- a/tests/core/test_schema_diff.py +++ b/tests/core/test_schema_diff.py @@ -9,21 +9,27 @@ TableAlterColumn, TableAlterColumnPosition, TableAlterOperation, + get_schema_differ, + TableAlterAddColumnOperation, + TableAlterDropColumnOperation, + TableAlterChangeColumnTypeOperation, + NestedSupport, ) +from sqlmesh.utils.errors import SQLMeshError def test_schema_diff_calculate(): - alter_expressions = SchemaDiffer( + alter_operations = SchemaDiffer( **{ "support_positional_add": False, - "support_nested_operations": False, + "nested_support": NestedSupport.NONE, "array_element_selector": "", "compatible_types": { exp.DataType.build("STRING"): {exp.DataType.build("INT")}, }, } ).compare_columns( - "apply_to_table", + exp.to_table("apply_to_table"), { "id": exp.DataType.build("INT"), "name": exp.DataType.build("STRING"), @@ -38,25 +44,51 @@ def test_schema_diff_calculate(): }, ) - assert [x.sql() for x in alter_expressions] == [ + assert 
[x.expression.sql() for x in alter_operations] == [ """ALTER TABLE apply_to_table DROP COLUMN price""", """ALTER TABLE apply_to_table ADD COLUMN new_column DOUBLE""", """ALTER TABLE apply_to_table ALTER COLUMN name SET DATA TYPE INT""", ] +def test_schema_diff_drop_cascade(): + alter_expressions = SchemaDiffer( + **{ + "support_positional_add": False, + "nested_support": NestedSupport.NONE, + "array_element_selector": "", + "drop_cascade": True, + } + ).compare_columns( + exp.to_table("apply_to_table"), + { + "id": exp.DataType.build("INT"), + "name": exp.DataType.build("STRING"), + "price": exp.DataType.build("DOUBLE"), + }, + { + "id": exp.DataType.build("INT"), + "name": exp.DataType.build("STRING"), + }, + ) + + assert [x.expression.sql() for x in alter_expressions] == [ + """ALTER TABLE apply_to_table DROP COLUMN price CASCADE""" + ] + + def test_schema_diff_calculate_type_transitions(): alter_expressions = SchemaDiffer( **{ "support_positional_add": False, - "support_nested_operations": False, + "nested_support": NestedSupport.NONE, "array_element_selector": "", "compatible_types": { exp.DataType.build("STRING"): {exp.DataType.build("INT")}, }, } ).compare_columns( - "apply_to_table", + exp.to_table("apply_to_table"), { "id": exp.DataType.build("INT"), "ds": exp.DataType.build("STRING"), @@ -67,7 +99,7 @@ def test_schema_diff_calculate_type_transitions(): }, ) - assert [x.sql() for x in alter_expressions] == [ + assert [x.expression.sql() for x in alter_expressions] == [ """ALTER TABLE apply_to_table DROP COLUMN id""", """ALTER TABLE apply_to_table ADD COLUMN id BIGINT""", """ALTER TABLE apply_to_table ALTER COLUMN ds SET DATA TYPE INT""", @@ -87,10 +119,14 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.add( - TableAlterColumn.primitive("address"), - "STRING", - "STRUCT", + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + 
column_type=exp.DataType.build("STRING"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", ) ], {}, @@ -100,10 +136,14 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT
", [ - TableAlterOperation.add( - TableAlterColumn.primitive("address"), - "STRING", - expected_table_struct="STRUCT
", + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("STRING"), + expected_table_struct=exp.DataType.build( + "STRUCT
" + ), + array_element_selector="", position=TableAlterColumnPosition.first(), ) ], @@ -114,10 +154,14 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.add( - TableAlterColumn.primitive("address"), - "STRING", - expected_table_struct="STRUCT", + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("STRING"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", position=TableAlterColumnPosition.middle(after="id"), ) ], @@ -128,22 +172,34 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT
", [ - TableAlterOperation.add( - TableAlterColumn.primitive("address"), - "STRING", - expected_table_struct="STRUCT
", + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("STRING"), + expected_table_struct=exp.DataType.build( + "STRUCT
" + ), + array_element_selector="", position=TableAlterColumnPosition.first(), ), - TableAlterOperation.add( - TableAlterColumn.primitive("address2"), - "STRING", - expected_table_struct="STRUCT
", + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address2")], + column_type=exp.DataType.build("STRING"), + expected_table_struct=exp.DataType.build( + "STRUCT
" + ), + array_element_selector="", position=TableAlterColumnPosition.middle(after="id"), ), - TableAlterOperation.add( - TableAlterColumn.primitive("address3"), - "STRING", - expected_table_struct="STRUCT
", + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address3")], + column_type=exp.DataType.build("STRING"), + expected_table_struct=exp.DataType.build( + "STRUCT
" + ), + array_element_selector="", position=TableAlterColumnPosition.last(after="age"), ), ], @@ -154,16 +210,24 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT
", [ - TableAlterOperation.add( - TableAlterColumn.primitive("address"), - "STRING", - expected_table_struct="STRUCT
", + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("STRING"), + expected_table_struct=exp.DataType.build( + "STRUCT
" + ), + array_element_selector="", position=TableAlterColumnPosition.first(), ), - TableAlterOperation.add( - TableAlterColumn.primitive("address2"), - "STRING", - expected_table_struct="STRUCT
", + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address2")], + column_type=exp.DataType.build("STRING"), + expected_table_struct=exp.DataType.build( + "STRUCT
" + ), + array_element_selector="", position=TableAlterColumnPosition.middle(after="address"), ), ], @@ -177,10 +241,11 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.drop( - TableAlterColumn.primitive("id"), - "STRUCT", - "INT", + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("id")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", ) ], {}, @@ -190,10 +255,11 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.drop( - TableAlterColumn.primitive("name"), - "STRUCT", - "STRING", + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("name")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", ) ], {}, @@ -203,10 +269,11 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.drop( - TableAlterColumn.primitive("age"), - "STRUCT", - "INT", + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("age")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", ) ], {}, @@ -216,20 +283,27 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.drop( - TableAlterColumn.primitive("id"), - "STRUCT", - "INT", + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("id")], + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", ), - TableAlterOperation.drop( - TableAlterColumn.primitive("middle"), - "STRUCT", - "STRING", + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("middle")], + expected_table_struct=exp.DataType.build( + 
"STRUCT" + ), + array_element_selector="", ), - TableAlterOperation.drop( - TableAlterColumn.primitive("age"), - "STRUCT", - "INT", + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("age")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", ), ], {}, @@ -239,15 +313,21 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT
", "STRUCT", [ - TableAlterOperation.drop( - TableAlterColumn.primitive("address"), - "STRUCT", - "STRING", + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", ), - TableAlterOperation.drop( - TableAlterColumn.primitive("address2"), - "STRUCT", - "STRING", + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address2")], + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", ), ], {}, @@ -269,11 +349,15 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.alter_type( - TableAlterColumn.primitive("id"), - "STRING", - current_type="INT", - expected_table_struct="STRUCT", + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("id")], + column_type=exp.DataType.build("STRING"), + current_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", ) ], dict( @@ -290,21 +374,30 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.drop( - TableAlterColumn.primitive("name"), - "STRUCT", - "STRING", - ), - TableAlterOperation.add( - TableAlterColumn.primitive("address"), - "STRING", - expected_table_struct="STRUCT", - ), - TableAlterOperation.alter_type( - TableAlterColumn.primitive("id"), - "STRING", - current_type="INT", - expected_table_struct="STRUCT", + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("name")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + 
column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("STRING"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + ), + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("id")], + column_type=exp.DataType.build("STRING"), + current_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", ), ], dict( @@ -321,169 +414,280 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT>", "STRUCT>", [ - TableAlterOperation.add( - [ + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.primitive("col_d"), ], - "INT", - expected_table_struct="STRUCT>", + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT>" + ), + array_element_selector="", position=TableAlterColumnPosition.first(), ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict(support_positional_add=True, nested_support=NestedSupport.ALL_BUT_DROP), ), # Add a column to the end of a struct ( "STRUCT>", "STRUCT>", [ - TableAlterOperation.add( - [ + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.primitive("col_d"), ], - "INT", - expected_table_struct="STRUCT>", + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT>" + ), + array_element_selector="", position=TableAlterColumnPosition.last(after="col_c"), ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict(support_positional_add=True, nested_support=NestedSupport.ALL_BUT_DROP), ), # Add a column to the middle of a struct ( "STRUCT>", "STRUCT>", [ - TableAlterOperation.add( - [ + TableAlterAddColumnOperation( + 
target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.primitive("col_d"), ], - "INT", - expected_table_struct="STRUCT>", + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT>" + ), position=TableAlterColumnPosition.middle(after="col_a"), + array_element_selector="", ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict(support_positional_add=True, nested_support=NestedSupport.ALL_BUT_DROP), ), # Add two columns at the start of a struct ( "STRUCT>", "STRUCT>", [ - TableAlterOperation.add( - [ + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.primitive("col_d"), ], - "INT", - expected_table_struct="STRUCT>", + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT>" + ), position=TableAlterColumnPosition.first(), + array_element_selector="", ), - TableAlterOperation.add( - [ + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.primitive("col_e"), ], - "INT", - expected_table_struct="STRUCT>", + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT>" + ), position=TableAlterColumnPosition.middle(after="col_d"), + array_element_selector="", ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict(support_positional_add=True, nested_support=NestedSupport.ALL_BUT_DROP), + ), + # Add columns in different levels of nesting of structs + ( + "STRUCT>", + "STRUCT, txt TEXT>", + [ + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ + TableAlterColumn.primitive("txt"), + ], + column_type=exp.DataType.build("TEXT"), + expected_table_struct=exp.DataType.build( + "STRUCT, txt TEXT>" + ), + array_element_selector="", + ), + 
TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ + TableAlterColumn.struct("info"), + TableAlterColumn.primitive("col_d"), + ], + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT, txt TEXT>" + ), + array_element_selector="", + ), + ], + dict(support_positional_add=False, nested_support=NestedSupport.ALL_BUT_DROP), ), # Remove a column from the start of a struct ( "STRUCT>", "STRUCT>", [ - TableAlterOperation.drop( - [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.primitive("col_a"), ], - "STRUCT>", - "INT", + expected_table_struct=exp.DataType.build( + "STRUCT>" + ), + array_element_selector="", ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict( + support_positional_add=True, + nested_support=NestedSupport.ALL, + ), ), # Remove a column from the end of a struct ( "STRUCT>", "STRUCT>", [ - TableAlterOperation.drop( - [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.primitive("col_c"), ], - "STRUCT>", - "INT", + expected_table_struct=exp.DataType.build( + "STRUCT>" + ), + array_element_selector="", ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict( + support_positional_add=True, + nested_support=NestedSupport.ALL, + ), ), # Remove a column from the middle of a struct ( "STRUCT>", "STRUCT>", [ - TableAlterOperation.drop( - [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.primitive("col_b"), ], - "STRUCT>", - "INT", + expected_table_struct=exp.DataType.build( + "STRUCT>" + ), + array_element_selector="", ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict( + support_positional_add=True, + 
nested_support=NestedSupport.ALL, + ), + ), + # Remove a column from a struct where nested drop is not supported + ( + "STRUCT>", + "STRUCT>", + [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ + TableAlterColumn.struct("info"), + ], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ + TableAlterColumn.struct("info"), + ], + expected_table_struct=exp.DataType.build( + "STRUCT>" + ), + column_type=exp.DataType.build("STRUCT"), + array_element_selector="", + is_part_of_destructive_change=True, + ), + ], + dict( + nested_support=NestedSupport.ALL_BUT_DROP, + ), ), # Remove two columns from the start of a struct ( "STRUCT>", "STRUCT>", [ - TableAlterOperation.drop( - [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.primitive("col_a"), ], - "STRUCT>", - "INT", + expected_table_struct=exp.DataType.build( + "STRUCT>" + ), + array_element_selector="", ), - TableAlterOperation.drop( - [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.primitive("col_b"), ], - "STRUCT>", - "INT", + expected_table_struct=exp.DataType.build( + "STRUCT>" + ), + array_element_selector="", ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict( + support_positional_add=True, + nested_support=NestedSupport.ALL, + ), ), # Change a column type in a struct ( "STRUCT>", "STRUCT>", [ - TableAlterOperation.alter_type( - [ + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.primitive("col_c"), ], - "TEXT", - expected_table_struct="STRUCT>", - position=TableAlterColumnPosition.last(after="col_b"), + 
column_type=exp.DataType.build("TEXT"), + expected_table_struct=exp.DataType.build( + "STRUCT>" + ), current_type=exp.DataType.build("INT"), + array_element_selector="", ), ], dict( support_positional_add=True, - support_nested_operations=True, + nested_support=NestedSupport.ALL_BUT_DROP, compatible_types={ exp.DataType.build("INT"): {exp.DataType.build("TEXT")}, }, @@ -494,46 +698,93 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT>", "STRUCT>", [ - TableAlterOperation.drop( - [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.primitive("col_a"), ], - "STRUCT>", - "INT", + expected_table_struct=exp.DataType.build( + "STRUCT>" + ), + array_element_selector="", ), - TableAlterOperation.add( - [ + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.primitive("col_d"), ], - "INT", - expected_table_struct="STRUCT>", + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT>" + ), position=TableAlterColumnPosition.first(), + array_element_selector="", ), - TableAlterOperation.add( - [ + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.primitive("col_e"), ], - "INT", - expected_table_struct="STRUCT>", + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT>" + ), position=TableAlterColumnPosition.middle(after="col_b"), + array_element_selector="", ), - TableAlterOperation.alter_type( - [ + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.primitive("col_c"), ], - "TEXT", - expected_table_struct="STRUCT>", - position=TableAlterColumnPosition.last(after="col_e"), + column_type=exp.DataType.build("TEXT"), + 
expected_table_struct=exp.DataType.build( + "STRUCT>" + ), current_type=exp.DataType.build("INT"), + array_element_selector="", ), ], dict( support_positional_add=True, - support_nested_operations=True, + nested_support=NestedSupport.ALL, + compatible_types={ + exp.DataType.build("INT"): {exp.DataType.build("TEXT")}, + }, + ), + ), + # Add, remove and change a column from a struct where nested drop is not supported + ( + "STRUCT>", + "STRUCT>", + [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ + TableAlterColumn.struct("info"), + ], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ + TableAlterColumn.struct("info"), + ], + expected_table_struct=exp.DataType.build( + "STRUCT>" + ), + column_type=exp.DataType.build("STRUCT"), + array_element_selector="", + is_part_of_destructive_change=True, + ), + ], + dict( + nested_support=NestedSupport.ALL_BUT_DROP, compatible_types={ exp.DataType.build("INT"): {exp.DataType.build("TEXT")}, }, @@ -544,44 +795,61 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT>>", "STRUCT, col_c INT>>", [ - TableAlterOperation.drop( - [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.primitive("col_b"), ], - "STRUCT>>", - "INT", + expected_table_struct=exp.DataType.build( + "STRUCT>>" + ), + array_element_selector="", ), - TableAlterOperation.add( - [ + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.primitive("col_c"), ], - "INT", - expected_table_struct="STRUCT, col_c INT>>", + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT, col_c INT>>" + ), position=TableAlterColumnPosition.last("nested_info"), + 
array_element_selector="", ), - TableAlterOperation.drop( - [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.struct("nested_info"), TableAlterColumn.primitive("nest_col_b"), ], - "STRUCT, col_c INT>>", - "INT", + expected_table_struct=exp.DataType.build( + "STRUCT, col_c INT>>" + ), + array_element_selector="", ), - TableAlterOperation.add( - [ + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.struct("info"), TableAlterColumn.struct("nested_info"), TableAlterColumn.primitive("nest_col_c"), ], - "INT", - expected_table_struct="STRUCT, col_c INT>>", + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT, col_c INT>>" + ), position=TableAlterColumnPosition.last("nest_col_a"), + array_element_selector="", ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict( + support_positional_add=True, + nested_support=NestedSupport.ALL, + ), ), # ##################### # # Array Struct Tests @@ -591,73 +859,119 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT>>", "STRUCT>>", [ - TableAlterOperation.add( - [ + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.array_of_struct("infos"), TableAlterColumn.primitive("col_d"), ], - "INT", - expected_table_struct="STRUCT>>", + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT>>" + ), position=TableAlterColumnPosition.middle("col_b"), + array_element_selector="", ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict(support_positional_add=True, nested_support=NestedSupport.ALL_BUT_DROP), ), # Remove column from array of structs ( "STRUCT>>", "STRUCT>>", [ - TableAlterOperation.drop( - [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + 
column_parts=[ TableAlterColumn.array_of_struct("infos"), TableAlterColumn.primitive("col_b"), ], - "STRUCT>>", - "INT", + expected_table_struct=exp.DataType.build( + "STRUCT>>" + ), + array_element_selector="", ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict( + support_positional_add=True, + nested_support=NestedSupport.ALL, + ), ), # Alter column type in array of structs ( "STRUCT>>", "STRUCT>>", [ - TableAlterOperation.alter_type( - [ + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.array_of_struct("infos"), TableAlterColumn.primitive("col_c"), ], - "TEXT", - expected_table_struct="STRUCT>>", - position=TableAlterColumnPosition.last("col_b"), - current_type="INT", + column_type=exp.DataType.build("TEXT"), + expected_table_struct=exp.DataType.build( + "STRUCT>>" + ), + current_type=exp.DataType.build("INT"), + array_element_selector="", ), ], dict( support_positional_add=True, - support_nested_operations=True, + nested_support=NestedSupport.ALL_BUT_DROP, compatible_types={ exp.DataType.build("INT"): {exp.DataType.build("TEXT")}, }, ), ), + # Add columns to struct of array within different nesting levels + ( + "STRUCT>>", + "STRUCT>, col_e INT>", + [ + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("col_e")], + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT>, col_e INT>" + ), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ + TableAlterColumn.array_of_struct("infos"), + TableAlterColumn.primitive("col_d"), + ], + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT>, col_e INT>" + ), + array_element_selector="", + ), + ], + dict(support_positional_add=False, nested_support=NestedSupport.ALL_BUT_DROP), + ), # Add an array of 
primitives ( "STRUCT>>", "STRUCT>, values ARRAY>", [ - TableAlterOperation.add( - [ + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.array_of_primitive("values"), ], - "ARRAY", - expected_table_struct="STRUCT>, values ARRAY>", + column_type=exp.DataType.build("ARRAY"), + expected_table_struct=exp.DataType.build( + "STRUCT>, values ARRAY>" + ), position=TableAlterColumnPosition.last("infos"), + array_element_selector="", ), ], - dict(support_positional_add=True, support_nested_operations=True), + dict(support_positional_add=True, nested_support=NestedSupport.ALL_BUT_DROP), ), # untyped array to support Snowflake ( @@ -671,19 +985,19 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.drop( - [ - TableAlterColumn.primitive("ids"), - ], - "STRUCT", - "INT", + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("ids")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", ), - TableAlterOperation.add( - [ - TableAlterColumn.primitive("ids"), - ], - "ARRAY", - expected_table_struct="STRUCT", + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("ids")], + column_type=exp.DataType.build("ARRAY"), + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + is_part_of_destructive_change=True, ), ], {}, @@ -693,19 +1007,23 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.drop( - [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.array_of_primitive("ids"), ], - "STRUCT", - "ARRAY", + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", ), - TableAlterOperation.add( - [ + TableAlterAddColumnOperation( + 
target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.array_of_primitive("ids"), ], - "INT", - expected_table_struct="STRUCT", + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + is_part_of_destructive_change=True, ), ], {}, @@ -725,35 +1043,91 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.alter_type( - TableAlterColumn.primitive("address"), - "VARCHAR(121)", - current_type="VARCHAR(120)", - expected_table_struct="STRUCT", + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("VARCHAR(121)"), + current_type=exp.DataType.build("VARCHAR(120)"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", ) ], {}, ), + # Increase the precision of a type is ALTER when the type is supported + ( + "STRUCT", + "STRUCT", + [ + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("VARCHAR(121)"), + current_type=exp.DataType.build("VARCHAR(120)"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + ) + ], + dict( + precision_increase_allowed_types={exp.DataType.build("VARCHAR").this}, + ), + ), + # Increase the precision of a type is DROP/ADD when the type is not supported + ( + "STRUCT", + "STRUCT", + [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("VARCHAR(121)"), + 
expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + is_part_of_destructive_change=True, + ), + ], + dict( + precision_increase_allowed_types={exp.DataType.build("DECIMAL").this}, + ), + ), # Decrease the precision of a type is DROP/ADD ( "STRUCT", "STRUCT", [ - TableAlterOperation.drop( - TableAlterColumn.primitive("address"), - "STRUCT", - "VARCHAR(120)", - ), - TableAlterOperation.add( - TableAlterColumn.primitive("address"), - "VARCHAR(100)", - expected_table_struct="STRUCT", + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("VARCHAR(100)"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), position=TableAlterColumnPosition.last("id"), + array_element_selector="", + is_part_of_destructive_change=True, ), ], dict( support_positional_add=True, - support_nested_operations=True, + nested_support=NestedSupport.ALL, ), ), # Type with precision to same type with no precision and no default is DROP/ADD @@ -761,16 +1135,20 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.drop( - TableAlterColumn.primitive("address"), - "STRUCT", - "VARCHAR(120)", - ), - TableAlterOperation.add( - TableAlterColumn.primitive("address"), - "VARCHAR", - expected_table_struct="STRUCT", + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + 
column_type=exp.DataType.build("VARCHAR"), + expected_table_struct=exp.DataType.build("STRUCT"), position=TableAlterColumnPosition.last("id"), + array_element_selector="", + is_part_of_destructive_change=True, ), ], dict( @@ -782,16 +1160,22 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.drop( - TableAlterColumn.primitive("address"), - "STRUCT", - "VARCHAR", - ), - TableAlterOperation.add( - TableAlterColumn.primitive("address"), - "VARCHAR(120)", - expected_table_struct="STRUCT", + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("VARCHAR(120)"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), position=TableAlterColumnPosition.last("id"), + array_element_selector="", + is_part_of_destructive_change=True, ), ], dict( @@ -803,11 +1187,13 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", # default of 1 --> VARCHAR(1) "STRUCT", [ - TableAlterOperation.alter_type( - TableAlterColumn.primitive("address"), - "VARCHAR(2)", - current_type="VARCHAR", - expected_table_struct="STRUCT", + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("VARCHAR(2)"), + current_type=exp.DataType.build("VARCHAR"), + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", ) ], dict( @@ -821,16 +1207,20 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", # default of 1 --> VARCHAR(1) [ - TableAlterOperation.drop( - TableAlterColumn.primitive("address"), - "STRUCT", - "VARCHAR(120)", - ), - TableAlterOperation.add( - 
TableAlterColumn.primitive("address"), - "VARCHAR", - expected_table_struct="STRUCT", + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("VARCHAR"), + expected_table_struct=exp.DataType.build("STRUCT"), position=TableAlterColumnPosition.last("id"), + array_element_selector="", + is_part_of_destructive_change=True, ), ], dict( @@ -845,11 +1235,15 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.alter_type( - TableAlterColumn.primitive("address"), - "VARCHAR(max)", - current_type="VARCHAR(120)", - expected_table_struct="STRUCT", + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("VARCHAR(max)"), + current_type=exp.DataType.build("VARCHAR(120)"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", ) ], dict( @@ -863,11 +1257,15 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.alter_type( - TableAlterColumn.primitive("address"), - "VARCHAR(max)", - current_type="VARCHAR(120)", - expected_table_struct="STRUCT", + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("VARCHAR(max)"), + current_type=exp.DataType.build("VARCHAR(120)"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", ) ], dict( @@ -881,16 +1279,22 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.drop( - 
TableAlterColumn.primitive("address"), - "STRUCT", - "VARCHAR(max)", - ), - TableAlterOperation.add( - TableAlterColumn.primitive("address"), - "VARCHAR(120)", - expected_table_struct="STRUCT", + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("VARCHAR(120)"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), position=TableAlterColumnPosition.last("id"), + array_element_selector="", + is_part_of_destructive_change=True, ), ], dict( @@ -905,11 +1309,13 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.alter_type( - TableAlterColumn.primitive("address"), - "VARCHAR", - current_type="VARCHAR(120)", - expected_table_struct="STRUCT", + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("VARCHAR"), + current_type=exp.DataType.build("VARCHAR(120)"), + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", ) ], dict( @@ -925,16 +1331,22 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.drop( - TableAlterColumn.primitive("address"), - "STRUCT", - "VARCHAR", - ), - TableAlterOperation.add( - TableAlterColumn.primitive("address"), - "VARCHAR(120)", - expected_table_struct="STRUCT", + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + 
column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("VARCHAR(120)"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), position=TableAlterColumnPosition.last("id"), + array_element_selector="", + is_part_of_destructive_change=True, ), ], dict( @@ -951,11 +1363,13 @@ def test_schema_diff_calculate_type_transitions(): "STRUCT", "STRUCT", [ - TableAlterOperation.alter_type( - TableAlterColumn.primitive("address"), - "TEXT", - current_type="VARCHAR(120)", - expected_table_struct="STRUCT", + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("TEXT"), + current_type=exp.DataType.build("VARCHAR(120)"), + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", ) ], dict( @@ -976,36 +1390,163 @@ def test_schema_diff_calculate_type_transitions(): [], dict( support_positional_add=True, - support_nested_operations=True, + nested_support=NestedSupport.ALL_BUT_DROP, support_coercing_compatible_types=True, compatible_types={ exp.DataType.build("INT"): {exp.DataType.build("FLOAT")}, }, ), ), + ( + "STRUCT", + "STRUCT", + [], + dict( + support_positional_add=True, + nested_support=NestedSupport.ALL_BUT_DROP, + coerceable_types={ + exp.DataType.build("FLOAT"): {exp.DataType.build("INT")}, + }, + ), + ), + ( + "STRUCT", + "STRUCT", + [], + dict( + support_positional_add=True, + nested_support=NestedSupport.ALL_BUT_DROP, + support_coercing_compatible_types=True, + compatible_types={ + exp.DataType.build("INT"): {exp.DataType.build("FLOAT")}, + }, + coerceable_types={ + exp.DataType.build("STRING"): {exp.DataType.build("INT")}, + }, + ), + ), # Coercion with an alter results in a single alter ( "STRUCT", "STRUCT", [ - TableAlterOperation.alter_type( - TableAlterColumn.primitive("total"), - "FLOAT", - current_type="INT", + TableAlterChangeColumnTypeOperation( + 
target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("total")], + column_type=exp.DataType.build("FLOAT"), + current_type=exp.DataType.build("INT"), # Note that the resulting table struct will not match what we defined as the desired # result since it could be coerced - expected_table_struct="STRUCT", + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", ) ], dict( support_positional_add=False, - support_nested_operations=True, + nested_support=NestedSupport.ALL_BUT_DROP, support_coercing_compatible_types=True, compatible_types={ exp.DataType.build("INT"): {exp.DataType.build("FLOAT")}, }, ), ), + # ################### + # Ignore Nested Tests + # ################### + # Remove nested col_c + ( + "STRUCT>", + "STRUCT>", + [], + dict(nested_support=NestedSupport.IGNORE), + ), + # Add nested col_d + ( + "STRUCT>", + "STRUCT>", + [], + dict(nested_support=NestedSupport.IGNORE), + ), + # Change nested col_c to incompatible type + ( + "STRUCT>", + "STRUCT>", + [], + dict(nested_support=NestedSupport.IGNORE), + ), + # Change nested col_c to compatible type + ( + "STRUCT>", + "STRUCT>", + [], + dict( + nested_support=NestedSupport.IGNORE, + compatible_types={ + exp.DataType.build("INT"): {exp.DataType.build("STRING")}, + }, + ), + ), + # Mix of ignored nested and non-nested changes + ( + "STRUCT, age INT>", + "STRUCT, age STRING, new_col INT>", + [ + # `col_c` change is ignored + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("new_col")], + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT, age INT, new_col INT>" + ), + position=TableAlterColumnPosition.last("age"), + array_element_selector="", + ), + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("age")], + column_type=exp.DataType.build("STRING"), + 
current_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT, age STRING, new_col INT>" + ), + array_element_selector="", + ), + ], + dict( + nested_support=NestedSupport.IGNORE, + compatible_types={ + exp.DataType.build("INT"): {exp.DataType.build("STRING")}, + }, + support_positional_add=True, + ), + ), + # ############################ + # Change Data Type Destructive + # ############################ + ( + "STRUCT", + "STRUCT", + [ + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("age")], + column_type=exp.DataType.build("STRING"), + current_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + is_part_of_destructive_change=True, + ), + ], + dict( + treat_alter_data_type_as_destructive=True, + compatible_types={ + exp.DataType.build("INT"): {exp.DataType.build("STRING")}, + }, + ), + ), ], ) def test_struct_diff( @@ -1016,7 +1557,9 @@ def test_struct_diff( ): resolver = SchemaDiffer(**config) operations = resolver._from_structs( - exp.DataType.build(current_struct), exp.DataType.build(new_struct) + exp.DataType.build(current_struct), + exp.DataType.build(new_struct), + "apply_to_table", ) assert operations == expected_diff @@ -1044,7 +1587,7 @@ def test_schema_diff_calculate_duckdb(duck_conn): }, ) - alter_expressions = engine_adapter.get_alter_expressions("apply_to_table", "schema_from_table") + alter_expressions = engine_adapter.get_alter_operations("apply_to_table", "schema_from_table") engine_adapter.alter_table(alter_expressions) assert engine_adapter.columns("apply_to_table") == { "id": exp.DataType.build("int"), @@ -1055,46 +1598,57 @@ def test_schema_diff_calculate_duckdb(duck_conn): def test_schema_diff_alter_op_column(): - nested = TableAlterOperation.add( - [ + nested = TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ 
TableAlterColumn.array_of_struct("nested"), TableAlterColumn.primitive("col_a"), ], - "INT", - expected_table_struct="STRUCT>>", + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build("STRUCT>>"), position=TableAlterColumnPosition.last("id"), + array_element_selector="", ) - assert nested.column("").sql() == "nested.col_a" - nested_complete_column = TableAlterOperation.add( - [ + assert nested.column.sql() == "nested.col_a" + nested_complete_column = TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.array_of_struct("nested_1", quoted=True), TableAlterColumn.struct("nested_2"), TableAlterColumn.array_of_struct("nested_3"), TableAlterColumn.primitive("col_a", quoted=True), ], - "INT", - expected_table_struct="""STRUCT>>>>>""", + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + """STRUCT>>>>>""" + ), position=TableAlterColumnPosition.last("id"), + array_element_selector="", ) - assert nested_complete_column.column("").sql() == '"nested_1".nested_2.nested_3."col_a"' - nested_one_more_complete_column = TableAlterOperation.add( - [ + assert nested_complete_column.column.sql() == '"nested_1".nested_2.nested_3."col_a"' + nested_one_more_complete_column = TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.array_of_struct("nested_1", quoted=True), TableAlterColumn.struct("nested_2"), TableAlterColumn.array_of_struct("nested_3"), TableAlterColumn.struct("nested_4"), TableAlterColumn.primitive("col_a", quoted=True), ], - "INT", - expected_table_struct="""STRUCT>>>>>>""", + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + """STRUCT>>>>>>""" + ), position=TableAlterColumnPosition.last("id"), + array_element_selector="", ) assert ( - nested_one_more_complete_column.column("").sql() + nested_one_more_complete_column.column.sql() == 
'"nested_1".nested_2.nested_3.nested_4."col_a"' ) - super_nested = TableAlterOperation.add( - [ + super_nested = TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ TableAlterColumn.array_of_struct("nested_1", quoted=True), TableAlterColumn.struct("nested_2"), TableAlterColumn.array_of_struct("nested_3"), @@ -1105,11 +1659,710 @@ def test_schema_diff_alter_op_column(): TableAlterColumn.array_of_struct("nested_8"), TableAlterColumn.primitive("col_a", quoted=True), ], - "INT", - expected_table_struct="""STRUCT>>>>>>>>>>>""", + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + """STRUCT>>>>>>>>>>>""" + ), position=TableAlterColumnPosition.last("id"), + array_element_selector="element", ) assert ( - super_nested.column("element").sql() + super_nested.column.sql() == '"nested_1".element.nested_2.nested_3.element.nested_4.nested_5."nested_6".nested_7.nested_8.element."col_a"' ) + + +@pytest.mark.parametrize( + "current_struct, new_struct, expected_diff_with_destructive, expected_diff_ignore_destructive, config", + [ + # Simple DROP operation - should be ignored when ignore_destructive=True + ( + "STRUCT", + "STRUCT", + [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("name")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + ) + ], + [], # No operations when ignoring destructive + {}, + ), + # DROP + ADD operation (incompatible type change) - should be ignored when ignore_destructive=True + ( + "STRUCT", + "STRUCT", + [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("name")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("name")], + 
column_type=exp.DataType.build("BIGINT"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + is_part_of_destructive_change=True, + ), + ], + [], # No operations when ignoring destructive + {}, + ), + # Pure ADD operation - should work same way regardless of ignore_destructive + ( + "STRUCT", + "STRUCT", + [ + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("new_col")], + column_type=exp.DataType.build("STRING"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + ), + ], + [ + # Same operation when ignoring destructive + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("new_col")], + column_type=exp.DataType.build("STRING"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + ), + ], + {}, + ), + # Mix of destructive and non-destructive operations + ( + "STRUCT", + "STRUCT", + [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("name")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("STRING"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + ), + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("id")], + column_type=exp.DataType.build("STRING"), + current_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + ), + ], + [ + # Only non-destructive operations remain + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + 
column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("STRING"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + ), + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("id")], + column_type=exp.DataType.build("STRING"), + current_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + ), + ], + dict( + compatible_types={ + exp.DataType.build("INT"): {exp.DataType.build("STRING")}, + } + ), + ), + ], +) +def test_ignore_destructive_operations( + current_struct, + new_struct, + expected_diff_with_destructive: t.List[TableAlterOperation], + expected_diff_ignore_destructive: t.List[TableAlterOperation], + config: t.Dict[str, t.Any], +): + resolver = SchemaDiffer(**config) + + # Test with destructive operations allowed (default behavior) + operations_with_destructive = resolver._from_structs( + exp.DataType.build(current_struct), + exp.DataType.build(new_struct), + "apply_to_table", + ignore_destructive=False, + ) + assert operations_with_destructive == expected_diff_with_destructive + + # Test with destructive operations ignored + operations_ignore_destructive = resolver._from_structs( + exp.DataType.build(current_struct), + exp.DataType.build(new_struct), + "apply_to_table", + ignore_destructive=True, + ) + assert operations_ignore_destructive == expected_diff_ignore_destructive + + +def test_ignore_destructive_compare_columns(): + """Test ignore_destructive behavior in compare_columns method.""" + schema_differ = SchemaDiffer( + support_positional_add=True, + nested_support=NestedSupport.NONE, + compatible_types={ + exp.DataType.build("INT"): {exp.DataType.build("STRING")}, + }, + ) + + current = { + "id": exp.DataType.build("INT"), + "name": exp.DataType.build("STRING"), + "to_drop": exp.DataType.build("DOUBLE"), + "age": exp.DataType.build("INT"), + 
} + + new = { + "id": exp.DataType.build("STRING"), # Compatible type change + "name": exp.DataType.build("STRING"), + "age": exp.DataType.build("INT"), + "new_col": exp.DataType.build("DOUBLE"), # New column + } + + # With destructive operations allowed + alter_expressions_with_destructive = schema_differ.compare_columns( + "test_table", current, new, ignore_destructive=False + ) + assert len(alter_expressions_with_destructive) == 3 # DROP + ADD + ALTER + + # With destructive operations ignored + alter_expressions_ignore_destructive = schema_differ.compare_columns( + "test_table", current, new, ignore_destructive=True + ) + assert len(alter_expressions_ignore_destructive) == 2 # Only ADD + ALTER + + # Verify the operations are correct + operations_sql = [expr.expression.sql() for expr in alter_expressions_ignore_destructive] + add_column_found = any("ADD COLUMN new_col DOUBLE" in op for op in operations_sql) + alter_column_found = any("ALTER COLUMN id SET DATA TYPE" in op for op in operations_sql) + drop_column_found = any("DROP COLUMN to_drop" in op for op in operations_sql) + + assert add_column_found, f"ADD COLUMN not found in: {operations_sql}" + assert alter_column_found, f"ALTER COLUMN not found in: {operations_sql}" + assert not drop_column_found, f"DROP COLUMN should not be present in: {operations_sql}" + + +def test_ignore_destructive_nested_struct_without_support(): + """Test ignore_destructive with nested structs when nested_drop is not supported.""" + schema_differ = SchemaDiffer( + nested_support=NestedSupport.ALL_BUT_DROP, # This forces DROP+ADD for nested changes + ) + + current_struct = "STRUCT>" + new_struct = "STRUCT>" # Removes col_b + + # With destructive operations allowed - should do DROP+ADD of entire struct + operations_with_destructive = schema_differ._from_structs( + exp.DataType.build(current_struct), + exp.DataType.build(new_struct), + "apply_to_table", + ignore_destructive=False, + ) + assert len(operations_with_destructive) == 2 # 
DROP struct + ADD struct + assert isinstance(operations_with_destructive[0], TableAlterDropColumnOperation) + assert isinstance(operations_with_destructive[1], TableAlterAddColumnOperation) + + # With destructive operations ignored - should do nothing + operations_ignore_destructive = schema_differ._from_structs( + exp.DataType.build(current_struct), + exp.DataType.build(new_struct), + "apply_to_table", + ignore_destructive=True, + ) + assert len(operations_ignore_destructive) == 0 + + +def test_get_schema_differ(): + # Test that known dialects return SchemaDiffer instances + for dialect in ["bigquery", "snowflake", "postgres", "databricks", "spark", "duckdb"]: + schema_differ = get_schema_differ(dialect) + assert isinstance(schema_differ, SchemaDiffer) + + # Test specific configurations + # Databricks should support positional add and nested operations + databricks_differ = get_schema_differ("databricks") + assert databricks_differ.support_positional_add is True + assert databricks_differ.nested_support == NestedSupport.ALL + + # BigQuery should have specific compatible types configured + bigquery_differ = get_schema_differ("bigquery") + assert len(bigquery_differ.compatible_types) > 0 + assert bigquery_differ.support_coercing_compatible_types is True + + # Snowflake should have parameterized type defaults + snowflake_differ = get_schema_differ("snowflake") + assert len(snowflake_differ.parameterized_type_defaults) > 0 + + # Postgres should support drop cascade + postgres_differ = get_schema_differ("postgres") + assert postgres_differ.drop_cascade is True + assert len(postgres_differ.types_with_unlimited_length) > 0 + + # Test dialect aliases work correctly + schema_differ_pg = get_schema_differ("postgresql") + schema_differ_postgres = get_schema_differ("postgres") + assert schema_differ_pg.drop_cascade == schema_differ_postgres.drop_cascade + + # Test unknown dialect returns default SchemaDiffer + schema_differ_unknown = get_schema_differ("unknown_dialect") + 
assert isinstance(schema_differ_unknown, SchemaDiffer) + assert schema_differ_unknown.support_positional_add is False + assert schema_differ_unknown.nested_support == NestedSupport.NONE + + # Test case insensitivity + schema_differ_upper = get_schema_differ("BIGQUERY") + schema_differ_lower = get_schema_differ("bigquery") + assert ( + schema_differ_upper.support_coercing_compatible_types + == schema_differ_lower.support_coercing_compatible_types + ) + + # Test override + schema_differ_with_override = get_schema_differ("postgres", {"drop_cascade": False}) + assert schema_differ_with_override.drop_cascade is False + + +def test_ignore_destructive_edge_cases(): + """Test edge cases for ignore_destructive behavior.""" + schema_differ = SchemaDiffer(support_positional_add=True) + + # Test when all operations are destructive - should result in empty list + current_struct = "STRUCT" + new_struct = "STRUCT<>" # Remove all columns + + operations_ignore_destructive = schema_differ._from_structs( + exp.DataType.build(current_struct), + exp.DataType.build(new_struct), + "apply_to_table", + ignore_destructive=True, + ) + assert len(operations_ignore_destructive) == 0 + + # Test when no operations are needed - should result in empty list regardless of ignore_destructive + same_struct = "STRUCT" + + operations_same_with_destructive = schema_differ._from_structs( + exp.DataType.build(same_struct), + exp.DataType.build(same_struct), + "apply_to_table", + ignore_destructive=False, + ) + operations_same_ignore_destructive = schema_differ._from_structs( + exp.DataType.build(same_struct), + exp.DataType.build(same_struct), + "apply_to_table", + ignore_destructive=True, + ) + assert len(operations_same_with_destructive) == 0 + assert len(operations_same_ignore_destructive) == 0 + + # Test when only ADD operations are needed - should be same regardless of ignore_destructive + current_struct = "STRUCT" + new_struct = "STRUCT" + + operations_add_with_destructive = 
schema_differ._from_structs( + exp.DataType.build(current_struct), + exp.DataType.build(new_struct), + "apply_to_table", + ignore_destructive=False, + ) + operations_add_ignore_destructive = schema_differ._from_structs( + exp.DataType.build(current_struct), + exp.DataType.build(new_struct), + "apply_to_table", + ignore_destructive=True, + ) + assert len(operations_add_with_destructive) == 2 # ADD name, ADD age + assert len(operations_add_ignore_destructive) == 2 # Same operations + assert operations_add_with_destructive == operations_add_ignore_destructive + + +@pytest.mark.parametrize( + "current_struct, new_struct, expected_diff_with_additive, expected_diff_ignore_additive, config", + [ + # Simple ADD operation - should be ignored when ignore_additive=True + ( + "STRUCT", + "STRUCT", + [ + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("age")], + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + ) + ], + [], # No operations when ignoring additive + {}, + ), + # Multiple ADD operations - should all be ignored when ignore_additive=True + ( + "STRUCT", + "STRUCT", + [ + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("name")], + column_type=exp.DataType.build("STRING"), + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("age")], + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("STRING"), + expected_table_struct=exp.DataType.build( + 
"STRUCT" + ), + array_element_selector="", + ), + ], + [], # No operations when ignoring additive + {}, + ), + # Pure DROP operation - should work same way regardless of ignore_additive + ( + "STRUCT", + "STRUCT", + [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("age")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + ), + ], + [ + # Same operation when ignoring additive + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("age")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + ), + ], + {}, + ), + # Mix of additive and non-additive operations + ( + "STRUCT", + "STRUCT", + [ + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("name")], + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("address")], + column_type=exp.DataType.build("STRING"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + ), + TableAlterChangeColumnTypeOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("id")], + column_type=exp.DataType.build("STRING"), + current_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + ), + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("something")], + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("something")], + 
column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + is_part_of_destructive_change=True, + ), + ], + [ + # Only non-additive operations remain (alter is considered additive since it was a compatible change) + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("name")], + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + ), + TableAlterDropColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("something")], + expected_table_struct=exp.DataType.build("STRUCT"), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("something")], + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT" + ), + array_element_selector="", + is_part_of_destructive_change=True, + ), + ], + dict( + compatible_types={ + exp.DataType.build("INT"): {exp.DataType.build("STRING")}, + } + ), + ), + # ADD operations with nested structs - should be ignored when ignore_additive=True + ( + "STRUCT>", + "STRUCT, new_field STRING>", + [ + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[TableAlterColumn.primitive("new_field")], + column_type=exp.DataType.build("STRING"), + expected_table_struct=exp.DataType.build( + "STRUCT, new_field STRING>" + ), + array_element_selector="", + ), + TableAlterAddColumnOperation( + target_table=exp.to_table("apply_to_table"), + column_parts=[ + TableAlterColumn.struct("info"), + TableAlterColumn.primitive("col_c"), + ], + column_type=exp.DataType.build("INT"), + expected_table_struct=exp.DataType.build( + "STRUCT, new_field STRING>" + ), + array_element_selector="", + ), + ], + [], # No operations when ignoring additive + 
dict(nested_support=NestedSupport.ALL_BUT_DROP), + ), + ], +) +def test_ignore_additive_operations( + current_struct, + new_struct, + expected_diff_with_additive: t.List[TableAlterOperation], + expected_diff_ignore_additive: t.List[TableAlterOperation], + config: t.Dict[str, t.Any], +): + resolver = SchemaDiffer(**config) + + # Test with additive operations allowed (default behavior) + operations_with_additive = resolver._from_structs( + exp.DataType.build(current_struct), + exp.DataType.build(new_struct), + "apply_to_table", + ignore_additive=False, + ) + assert operations_with_additive == expected_diff_with_additive + + # Test with additive operations ignored + operations_ignore_additive = resolver._from_structs( + exp.DataType.build(current_struct), + exp.DataType.build(new_struct), + "apply_to_table", + ignore_additive=True, + ) + assert operations_ignore_additive == expected_diff_ignore_additive + + +def test_ignore_additive_edge_cases(): + """Test edge cases for ignore_additive behavior.""" + schema_differ = SchemaDiffer(support_positional_add=True) + + # Test when all operations are additive - should result in empty list + current_struct = "STRUCT" + new_struct = "STRUCT" # Add all columns + + operations_ignore_additive = schema_differ._from_structs( + exp.DataType.build(current_struct), + exp.DataType.build(new_struct), + "apply_to_table", + ignore_additive=True, + ) + assert len(operations_ignore_additive) == 0 + + # Test when no operations are needed - should result in empty list regardless of ignore_additive + same_struct = "STRUCT" + + operations_same_with_additive = schema_differ._from_structs( + exp.DataType.build(same_struct), + exp.DataType.build(same_struct), + "apply_to_table", + ignore_additive=False, + ) + operations_same_ignore_additive = schema_differ._from_structs( + exp.DataType.build(same_struct), + exp.DataType.build(same_struct), + "apply_to_table", + ignore_additive=True, + ) + assert len(operations_same_with_additive) == 0 + assert 
len(operations_same_ignore_additive) == 0 + + # Test when only DROP operations are needed - should be same regardless of ignore_additive + current_struct = "STRUCT" + new_struct = "STRUCT" + + operations_drop_with_additive = schema_differ._from_structs( + exp.DataType.build(current_struct), + exp.DataType.build(new_struct), + "apply_to_table", + ignore_additive=False, + ) + operations_drop_ignore_additive = schema_differ._from_structs( + exp.DataType.build(current_struct), + exp.DataType.build(new_struct), + "apply_to_table", + ignore_additive=True, + ) + assert len(operations_drop_with_additive) == 2 # DROP name, DROP age + assert len(operations_drop_ignore_additive) == 2 # Same operations + assert operations_drop_with_additive == operations_drop_ignore_additive + + +def test_ignore_both_destructive_and_additive(): + """Test behavior when both ignore_destructive and ignore_additive are True.""" + schema_differ = SchemaDiffer( + support_positional_add=True, + compatible_types={ + exp.DataType.build("INT"): {exp.DataType.build("STRING")}, + }, + ) + + current_struct = "STRUCT" + new_struct = "STRUCT" # DROP name, ADD address, ALTER id + + operations_ignore_both = schema_differ._from_structs( + exp.DataType.build(current_struct), + exp.DataType.build(new_struct), + "apply_to_table", + ignore_destructive=True, + ignore_additive=True, + ) + assert len(operations_ignore_both) == 0 + + +def test_ignore_additive_array_operations(): + """Test ignore_additive with array of struct operations.""" + schema_differ = SchemaDiffer( + nested_support=NestedSupport.ALL, + support_positional_add=True, + ) + + current_struct = "STRUCT>>" + new_struct = "STRUCT>>" + + # With additive operations allowed - should add to array struct + operations_with_additive = schema_differ._from_structs( + exp.DataType.build(current_struct), + exp.DataType.build(new_struct), + "apply_to_table", + ignore_additive=False, + ) + assert len(operations_with_additive) == 1 # ADD to array struct + assert 
isinstance(operations_with_additive[0], TableAlterAddColumnOperation) + + # With additive operations ignored - should do nothing + operations_ignore_additive = schema_differ._from_structs( + exp.DataType.build(current_struct), + exp.DataType.build(new_struct), + "apply_to_table", + ignore_additive=True, + ) + assert len(operations_ignore_additive) == 0 + + +def test_drop_operation_missing_column_error(): + schema_differ = SchemaDiffer( + nested_support=NestedSupport.NONE, + support_positional_add=False, + ) + + # a struct that doesn't contain the column we're going to drop + current_struct = exp.DataType.build("STRUCT") + + with pytest.raises(SQLMeshError) as error_message: + schema_differ._drop_operation( + columns=[TableAlterColumn.primitive("missing_column")], + struct=current_struct, + pos=0, + root_struct=current_struct, + table_name="test_table", + ) + + assert ( + str(error_message.value) + == "Cannot drop column 'missing_column' from table 'test_table' - column not found. This may indicate a mismatch between the expected and actual table schemas." 
+ ) diff --git a/tests/core/test_schema_loader.py b/tests/core/test_schema_loader.py index ada27dfdc3..8b944793be 100644 --- a/tests/core/test_schema_loader.py +++ b/tests/core/test_schema_loader.py @@ -1,10 +1,9 @@ import pytest -import logging import typing as t from pathlib import Path from unittest.mock import patch -import pandas as pd +import pandas as pd # noqa: TID253 from pytest_mock.plugin import MockerFixture from sqlglot import exp, parse_one @@ -16,13 +15,13 @@ from sqlmesh.core.model.definition import ExternalModel from sqlmesh.core.schema_loader import create_external_models_file from sqlmesh.core.snapshot import SnapshotChangeCategory -from sqlmesh.utils.yaml import YAML +from sqlmesh.utils import yaml from sqlmesh.utils.errors import SQLMeshError -def test_create_external_models(tmpdir, assert_exp_eq): +def test_create_external_models(tmp_path, assert_exp_eq): config = Config(gateways=GatewayConfig(connection=DuckDBConnectionConfig())) - context = Context(paths=[tmpdir], config=config) + context = Context(paths=[tmp_path], config=config) # `fruits` is used by DuckDB in the upcoming select query fruits = pd.DataFrame( @@ -107,7 +106,7 @@ def test_create_external_models(tmpdir, assert_exp_eq): assert context.models['"memory"."sushi"."raw_fruits"'].gateway is None -def test_gateway_specific_external_models(tmpdir: Path): +def test_gateway_specific_external_models(tmp_path: Path): gateways = { "dev": GatewayConfig(connection=DuckDBConnectionConfig()), "prod": GatewayConfig(connection=DuckDBConnectionConfig()), @@ -115,12 +114,12 @@ def test_gateway_specific_external_models(tmpdir: Path): config = Config(gateways=gateways, default_gateway="dev") - dev_context = Context(paths=[tmpdir], config=config, gateway="dev") + dev_context = Context(paths=[tmp_path], config=config, gateway="dev") dev_context.engine_adapter.execute("create schema landing") dev_context.engine_adapter.execute("create table landing.dev_source as select 1") 
dev_context.engine_adapter.execute("create schema lake") - prod_context = Context(paths=[tmpdir], config=config, gateway="prod") + prod_context = Context(paths=[tmp_path], config=config, gateway="prod") prod_context.engine_adapter.execute("create schema landing") prod_context.engine_adapter.execute("create table landing.prod_source as select 1") prod_context.engine_adapter.execute("create schema lake") @@ -160,23 +159,28 @@ def _create_model(gateway: str): # each context can only see models for its own gateway # check that models from both gateways present in the file, to show that prod_context.create_external_models() didnt clobber the dev ones - external_models_filename = tmpdir / c.EXTERNAL_MODELS_YAML - with open(external_models_filename, "r", encoding="utf8") as fd: - contents = YAML().load(fd) + contents = yaml.load(tmp_path / c.EXTERNAL_MODELS_YAML) - assert len(contents) == 2 - assert len([c for c in contents if c["name"] == '"memory"."landing"."dev_source"']) == 1 - assert len([c for c in contents if c["name"] == '"memory"."landing"."prod_source"']) == 1 + assert len(contents) == 2 + assert len([c for c in contents if c["name"] == '"memory"."landing"."dev_source"']) == 1 + assert len([c for c in contents if c["name"] == '"memory"."landing"."prod_source"']) == 1 -def test_gateway_specific_external_models_mixed_with_others(tmpdir): +def test_gateway_specific_external_models_mixed_with_others(tmp_path: Path): + def _init_db(ctx: Context): + ctx.engine_adapter.execute("create schema landing") + ctx.engine_adapter.execute("create table landing.source_table as select 1") + ctx.engine_adapter.execute("create schema lake") + gateways = { "dev": GatewayConfig(connection=DuckDBConnectionConfig()), + "prod": GatewayConfig(connection=DuckDBConnectionConfig()), } config = Config(gateways=gateways, default_gateway="dev") - model_dir = (tmpdir / c.MODELS).mkdir() + model_dir = tmp_path / c.MODELS + model_dir.mkdir() with open(model_dir / "table.sql", "w", 
encoding="utf8") as fd: fd.write( @@ -190,12 +194,11 @@ def test_gateway_specific_external_models_mixed_with_others(tmpdir): """, ) - ctx = Context(paths=[tmpdir], config=config) # note: No explicitly defined gateway + ctx = Context(paths=[tmp_path], config=config) # note: No explicitly defined gateway assert ctx.gateway is None + assert ctx.selected_gateway == "dev" - ctx.engine_adapter.execute("create schema landing") - ctx.engine_adapter.execute("create table landing.source_table as select 1") - ctx.engine_adapter.execute("create schema lake") + _init_db(ctx) ctx.load() assert len(ctx.models) == 1 @@ -203,46 +206,50 @@ def test_gateway_specific_external_models_mixed_with_others(tmpdir): ctx.create_external_models() - # no gateway was specifically chosen; external models should be created without a gateway - external_models_filename = tmpdir / c.EXTERNAL_MODELS_YAML - with open(external_models_filename, "r", encoding="utf8") as fd: - contents = YAML().load(fd) - assert len(contents) == 1 - assert "gateway" not in contents[0] + # no gateway was specifically chosen; external models should be created against the default gateway + external_models_filename = tmp_path / c.EXTERNAL_MODELS_YAML + contents = yaml.load(external_models_filename) + assert len(contents) == 1 + assert contents[0]["gateway"] == "dev" ctx.load() assert len(ctx.models) == 2 assert '"memory"."landing"."source_table"' in ctx.models assert '"memory"."lake"."table"' in ctx.models - ctx.gateway = "dev" # explicitly set gateway=dev - ctx.create_external_models() + # explicitly set --gateway prod + prod_ctx = Context(paths=[tmp_path], config=config, gateway="prod") + assert prod_ctx.gateway == "prod" + assert prod_ctx.selected_gateway == "prod" + + _init_db(prod_ctx) - # there should now be 2 external models with the same name - one with a gateway and one without - with open(external_models_filename, "r", encoding="utf8") as fd: - contents = YAML().load(fd) - assert len(contents) == 2 - assert "gateway" 
not in contents[0] - assert "gateway" in contents[1] - assert contents[0]["name"] == contents[1]["name"] + prod_ctx.create_external_models() + + # there should now be 2 external models with the same name - one with a gateway=dev and one with gateway=prod + contents = yaml.load(external_models_filename) + assert len(contents) == 2 + assert sorted([contents[0]["gateway"], contents[1]["gateway"]]) == ["dev", "prod"] + assert contents[0]["name"] == contents[1]["name"] # check that this doesnt present a problem on load - ctx.load() + prod_ctx.load() - external_models = [m for _, m in ctx.models.items() if type(m) == ExternalModel] + external_models = [m for _, m in prod_ctx.models.items() if type(m) == ExternalModel] assert len(external_models) == 1 assert external_models[0].name == '"memory"."landing"."source_table"' + assert external_models[0].gateway == "prod" -def test_gateway_specific_external_models_default_gateway(tmpdir: Path): +def test_gateway_specific_external_models_default_gateway(tmp_path: Path): model_0 = {"name": "db.model0", "columns": {"a": "int"}} model_1 = {"name": "db.model1", "gateway": "dev", "columns": {"a": "int"}} model_2 = {"name": "db.model2", "gateway": "prod", "columns": {"a": "int"}} - with open(tmpdir / c.EXTERNAL_MODELS_YAML, "w", encoding="utf8") as fd: - YAML().dump([model_0, model_1, model_2], fd) + with open(tmp_path / c.EXTERNAL_MODELS_YAML, "w", encoding="utf8") as fd: + yaml.dump([model_0, model_1, model_2], fd) gateways = { "dev": GatewayConfig(connection=DuckDBConnectionConfig()), @@ -250,7 +257,7 @@ def test_gateway_specific_external_models_default_gateway(tmpdir: Path): } config = Config(gateways=gateways, default_gateway="prod") - ctx = Context(paths=[tmpdir], config=config) + ctx = Context(paths=[tmp_path], config=config) ctx.load() @@ -260,10 +267,11 @@ def test_gateway_specific_external_models_default_gateway(tmpdir: Path): assert '"memory"."db"."model2"' in model_names -def 
test_create_external_models_no_duplicates(tmpdir): +def test_create_external_models_no_duplicates(tmp_path: Path): config = Config(gateways={"": GatewayConfig(connection=DuckDBConnectionConfig())}) - model_dir = (tmpdir / c.MODELS).mkdir() + model_dir = tmp_path / c.MODELS + model_dir.mkdir() with open(model_dir / "table.sql", "w", encoding="utf8") as fd: fd.write( @@ -277,15 +285,14 @@ def test_create_external_models_no_duplicates(tmpdir): """, ) - ctx = Context(paths=[tmpdir], config=config) + ctx = Context(paths=[tmp_path], config=config) assert ctx.gateway is None ctx.engine_adapter.execute("create schema landing") ctx.engine_adapter.execute("create table landing.source_table as select 1") ctx.engine_adapter.execute("create schema lake") def _load_external_models(): - with open(tmpdir / c.EXTERNAL_MODELS_YAML, "r", encoding="utf8") as fd: - return YAML().load(fd) + return yaml.load(tmp_path / c.EXTERNAL_MODELS_YAML) ctx.create_external_models() @@ -298,7 +305,7 @@ def _load_external_models(): assert len(_load_external_models()) == 1 -def test_no_internal_model_conversion(tmp_path: Path, make_snapshot, mocker: MockerFixture): +def test_no_internal_model_conversion(tmp_path: Path, mocker: MockerFixture): engine_adapter_mock = mocker.Mock() engine_adapter_mock.columns.return_value = { "b": exp.DataType.build("text"), @@ -323,8 +330,7 @@ def test_no_internal_model_conversion(tmp_path: Path, make_snapshot, mocker: Moc "bigquery", ) - with open(filename, "r", encoding="utf8") as fd: - schema = YAML().load(fd) + schema = yaml.load(filename) assert len(schema) == 2 assert schema[0]["name"] == "`tbl-d`" @@ -342,8 +348,7 @@ def test_missing_table(tmp_path: Path): model = SqlModel(name="a", query=parse_one("select * FROM tbl_source")) filename = tmp_path / c.EXTERNAL_MODELS_YAML - logger = logging.getLogger("sqlmesh.core.schema_loader") - with patch.object(logger, "warning") as mock_logger: + with patch.object(context.console, "log_warning") as mock_logger: 
create_external_models_file( filename, {"a": model}, # type: ignore @@ -353,8 +358,7 @@ def test_missing_table(tmp_path: Path): ) assert """Unable to get schema for '"tbl_source"'""" in mock_logger.call_args[0][0] - with open(filename, "r", encoding="utf8") as fd: - schema = YAML().load(fd) + schema = yaml.load(filename) assert len(schema) == 0 with pytest.raises(SQLMeshError, match=r"""Unable to get schema for '"tbl_source"'.*"""): diff --git a/tests/core/test_seed.py b/tests/core/test_seed.py index cbe263c1bc..a22805cbd2 100644 --- a/tests/core/test_seed.py +++ b/tests/core/test_seed.py @@ -1,4 +1,4 @@ -import pandas as pd +import pandas as pd # noqa: TID253 import pytest from sqlglot import exp @@ -6,10 +6,10 @@ def test_read(): - content = """key,value,ds -1,one,2022-01-01 -2,two,2022-01-02 -3,three,2022-01-03 + content = """key,value,ds,bool +1,one,2022-01-01,true +2,two,2022-01-02,false +3,three,2022-01-03,true """ seed = Seed(content=content) # Since we provide "snowflake" as the dialect, all identifiers are expected to @@ -20,13 +20,14 @@ def test_read(): "KEY": exp.DataType.build("bigint"), "VALUE": exp.DataType.build("text"), "DS": exp.DataType.build("text"), + "BOOL": exp.DataType.build("boolean"), } - expected_df = pd.DataFrame( data={ "KEY": [1, 2, 3], "VALUE": ["one", "two", "three"], "DS": ["2022-01-01", "2022-01-02", "2022-01-03"], + "BOOL": [True, False, True], } ) dfs = seed_reader.read(batch_size=2) diff --git a/tests/core/test_selector.py b/tests/core/test_selector.py deleted file mode 100644 index f43e918b86..0000000000 --- a/tests/core/test_selector.py +++ /dev/null @@ -1,478 +0,0 @@ -from __future__ import annotations - -import typing as t -from pathlib import Path -from unittest.mock import call - -import pytest -from pytest_mock.plugin import MockerFixture - -from sqlmesh.core import dialect as d -from sqlmesh.core.audit import StandaloneAudit -from sqlmesh.core.environment import Environment -from sqlmesh.core.model import Model, SqlModel 
-from sqlmesh.core.selector import Selector -from sqlmesh.core.snapshot import SnapshotChangeCategory -from sqlmesh.utils import UniqueKeyDict - - -@pytest.mark.parametrize( - "default_catalog", - [ - None, - "test_catalog", - ], -) -def test_select_models(mocker: MockerFixture, make_snapshot, default_catalog: t.Optional[str]): - added_model = SqlModel( - name="db.added_model", - query=d.parse_one("SELECT 1 AS a"), - default_catalog=default_catalog, - tags=["tag1"], - ) - modified_model_v1 = SqlModel( - name="db.modified_model", - query=d.parse_one("SELECT a + 1 FROM db.added_model"), - default_catalog=default_catalog, - tags=["tag2"], - ) - modified_model_v2 = SqlModel( - name="db.modified_model", - query=d.parse_one("SELECT a + 2 FROM db.added_model"), - default_catalog=default_catalog, - tags=["tag2"], - ) - removed_model = SqlModel( - name="db.removed_model", - query=d.parse_one("SELECT a FROM db.added_model"), - default_catalog=default_catalog, - ) - standalone_audit = StandaloneAudit( - name="test_audit", query=d.parse_one("SELECT * FROM added_model WHERE a IS NULL") - ) - - modified_model_v1_snapshot = make_snapshot(modified_model_v1) - modified_model_v1_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - removed_model_snapshot = make_snapshot(removed_model) - removed_model_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - standalone_audit_snapshot = make_snapshot(standalone_audit) - standalone_audit_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - - env_name = "test_env" - - state_reader_mock = mocker.Mock() - state_reader_mock.get_environment.return_value = Environment( - name=env_name, - snapshots=[ - s.table_info - for s in (modified_model_v1_snapshot, removed_model_snapshot, standalone_audit_snapshot) - ], - start_at="2023-01-01", - end_at="2023-02-01", - plan_id="test_plan_id", - ) - state_reader_mock.get_snapshots.return_value = { - modified_model_v1_snapshot.snapshot_id: modified_model_v1_snapshot, - 
removed_model_snapshot.snapshot_id: removed_model_snapshot, - standalone_audit_snapshot.snapshot_id: standalone_audit_snapshot, - } - - added_model_schema = {'"db"': {'"added_model"': {"a": "INT"}}} - if default_catalog: - added_model_schema = {f'"{default_catalog}"': added_model_schema} # type: ignore - - local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") - local_models[added_model.fqn] = added_model - local_models[modified_model_v2.fqn] = modified_model_v2.copy( - update={"mapping_schema": added_model_schema} - ) - selector = Selector(state_reader_mock, local_models, default_catalog=default_catalog) - - _assert_models_equal( - selector.select_models(["db.added_model"], env_name), - { - added_model.fqn: added_model, - modified_model_v1.fqn: modified_model_v1.copy( - update={"mapping_schema": added_model_schema} - ), - removed_model.fqn: removed_model, - }, - ) - _assert_models_equal( - selector.select_models(["db.modified_model"], "missing_env", fallback_env_name=env_name), - { - modified_model_v2.fqn: modified_model_v2, - removed_model.fqn: removed_model, - }, - ) - _assert_models_equal( - selector.select_models(["db.removed_model"], env_name), - { - modified_model_v1.fqn: modified_model_v1, - }, - ) - _assert_models_equal( - selector.select_models( - ["db.added_model", "db.modified_model"], "missing_env", fallback_env_name=env_name - ), - { - added_model.fqn: added_model, - modified_model_v2.fqn: modified_model_v2.copy( - update={"mapping_schema": added_model_schema} - ), - removed_model.fqn: removed_model, - }, - ) - _assert_models_equal( - selector.select_models(["+db.modified_model"], env_name), - { - added_model.fqn: added_model, - modified_model_v2.fqn: modified_model_v2.copy( - update={"mapping_schema": added_model_schema} - ), - removed_model.fqn: removed_model, - }, - ) - _assert_models_equal( - selector.select_models(["db.added_model+"], env_name), - { - added_model.fqn: added_model, - modified_model_v2.fqn: modified_model_v2.copy( - 
update={"mapping_schema": added_model_schema} - ), - removed_model.fqn: removed_model, - }, - ) - _assert_models_equal( - selector.select_models( - ["db.added_model", "db.modified_model", "db.removed_model"], env_name - ), - local_models, - ) - _assert_models_equal( - selector.select_models(["*_model", "db.removed_model"], env_name), - local_models, - ) - _assert_models_equal( - selector.select_models(["tag:tag1", "tag:tag2"], "missing_env", fallback_env_name=env_name), - { - added_model.fqn: added_model, - modified_model_v2.fqn: modified_model_v2.copy( - update={"mapping_schema": added_model_schema} - ), - removed_model.fqn: removed_model, - }, - ) - _assert_models_equal( - selector.select_models(["tag:tag*"], "missing_env", fallback_env_name=env_name), - { - added_model.fqn: added_model, - modified_model_v2.fqn: modified_model_v2.copy( - update={"mapping_schema": added_model_schema} - ), - removed_model.fqn: removed_model, - }, - ) - _assert_models_equal( - selector.select_models(["tag:+tag2"], env_name), - { - added_model.fqn: added_model, - modified_model_v2.fqn: modified_model_v2.copy( - update={"mapping_schema": added_model_schema} - ), - removed_model.fqn: removed_model, - }, - ) - - -def test_select_models_missing_env(mocker: MockerFixture, make_snapshot): - model = SqlModel(name="test_model", query=d.parse_one("SELECT 1 AS a")) - - state_reader_mock = mocker.Mock() - state_reader_mock.get_environment.return_value = None - - local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") - local_models[model.fqn] = model - - selector = Selector(state_reader_mock, local_models) - - assert selector.select_models([model.name], "missing_env").keys() == {model.fqn} - assert not selector.select_models(["missing"], "missing_env") - - assert selector.select_models( - [model.name], "missing_env", fallback_env_name="another_missing_env" - ).keys() == {model.fqn} - - state_reader_mock.get_environment.assert_has_calls( - [ - call("missing_env"), - 
call("missing_env"), - call("missing_env"), - call("another_missing_env"), - ] - ) - - -@pytest.mark.parametrize( - "model_defs, selections, output", - [ - # Direct matching only - ( - [("model1", "tag1", None), ("model2", "tag2", None), ("model3", "tag3", None)], - ["tag:tag1", "tag:tag3"], - {'"model1"', '"model3"'}, - ), - # Wildcard works - ( - [("model1", "tag1", None), ("model2", "tag2", None), ("model3", "tag3", None)], - ["tag:tag*"], - {'"model1"', '"model2"', '"model3"'}, - ), - # Downstream models are included - ( - [("model1", "tag1", None), ("model2", "tag2", {"model1"}), ("model3", "tag3", None)], - ["tag:tag1+"], - {'"model1"', '"model2"'}, - ), - # Upstream models are included - ( - [("model1", "tag1", None), ("model2", "tag2", None), ("model3", "tag3", {"model2"})], - ["tag:+tag3"], - {'"model2"', '"model3"'}, - ), - # Upstream and downstream models are included - ( - [ - ("model1", "tag1", None), - ("model2", "tag2", {"model1"}), - ("model3", "tag3", {"model2"}), - ], - ["tag:+tag2+"], - {'"model1"', '"model2"', '"model3"'}, - ), - # Wildcard works with upstream and downstream models - ( - [ - ("model1", "tag1", None), - ("model2", "tag2", {"model1"}), - ("model3", "tag3", {"model2"}), - ("model4", "blah", {"model3"}), - ("model5", "tag4", None), - # Only excluded model since it doesn't match wildcard nor upstream/downstream - ("model6", "blah", None), - ], - ["tag:+tag*+"], - {'"model1"', '"model2"', '"model3"', '"model4"', '"model5"'}, - ), - # Multiple tags work - ( - [ - ("model1", "tag1", None), - ("model2", "tag2", None), - ("model3", "tag3", None), - ("model4", "tag4", None), - ], - ["tag:tag1", "tag:tag3"], - {'"model1"', '"model3"'}, - ), - # Multiple tags work with upstream and downstream models - ( - [ - ("model1", "tag1", None), - ("model2", "tag2", {"model1"}), - ("model3", "tag3", {"model2"}), - ("model4", "tag4", None), - ("model5", "tag5", {"model4"}), - ("model6", "tag6", {"model5"}), - ], - ["tag:+tag3", "tag:tag5"], - 
{'"model1"', '"model2"', '"model3"', '"model5"'}, - ), - # Case-insensitive matching - ( - [ - ("model1", "tag1", None), - ("model2", "tag2", None), - ("model3", "tag3", None), - ], - ["tag:TAG*"], - {'"model1"', '"model2"', '"model3"'}, - ), - # Wildcard returns everything - ( - [ - ("model1", "tag1", None), - ("model2", "tag2", None), - ("model3", "tag3", None), - ], - ["tag:*"], - {'"model1"', '"model2"', '"model3"'}, - ), - # Upstream that don't exist is fine - ( - [ - ("model1", "tag1", None), - ("model2", "tag2", None), - ], - ["tag:+tag2"], - {'"model2"'}, - ), - # No matches returns empty set - ( - [ - ("model1", "tag1", None), - ("model2", "tag2", None), - ], - ["tag:+tag3*+", "tag:+tag3+"], - set(), - ), - # Mix of models and tags - ( - [ - ("model1", "tag1", None), - ("model2", "tag2", None), - ("model3", "tag3", None), - ], - ["tag:tag1", "model2"], - {'"model1"', '"model2"'}, - ), - # Intersection of tags and model names - ( - [ - ("model1", "tag1", None), - ("model2", "tag1", {"model1"}), - ("model3", "tag2", {"model1"}), - ("model4", "tag1", None), - ], - ["tag:tag1 & model1+"], - {'"model1"', '"model2"'}, - ), - # Intersection of tags and model names (order doesn't matter) - ( - [ - ("model1", "tag1", None), - ("model2", "tag1", {"model1"}), - ("model3", "tag2", {"model1"}), - ("model4", "tag1", None), - ], - ["model1+ & tag:tag1"], - {'"model1"', '"model2"'}, - ), - ], -) -def test_expand_model_selections( - mocker: MockerFixture, make_snapshot, model_defs, selections, output -): - models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") - for model_name, tag, depends_on in model_defs: - model = SqlModel( - name=model_name, query=d.parse_one("SELECT 1 AS a"), depends_on=depends_on, tags=[tag] - ) - models[model.fqn] = model - - selector = Selector(mocker.Mock(), models) - assert selector.expand_model_selections(selections) == output - - -def test_model_selection_normalized(mocker: MockerFixture, make_snapshot): - models: UniqueKeyDict[str, 
Model] = UniqueKeyDict("models") - model = SqlModel( - name="`db.test_Model`", - query=d.parse_one("SELECT 1 AS a"), - tags=["tag1"], - dialect="bigquery", - ) - models[model.fqn] = model - selector = Selector(mocker.Mock(), models, dialect="bigquery") - assert selector.expand_model_selections(["db.test_Model"]) == {'"db"."test_Model"'} - - -@pytest.mark.parametrize( - "expressions, expected_fqns", - [ - (["git:main"], {'"test_model_a"', '"test_model_c"'}), - (["git:main & +*model_c"], {'"test_model_c"'}), - (["git:main+"], {'"test_model_a"', '"test_model_c"', '"test_model_d"'}), - (["+git:main"], {'"test_model_a"', '"test_model_c"', '"test_model_b"'}), - (["+git:main+"], {'"test_model_a"', '"test_model_c"', '"test_model_b"', '"test_model_d"'}), - ], -) -def test_expand_git_selection( - mocker: MockerFixture, expressions: t.List[str], expected_fqns: t.Set[str] -): - models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") - - model_a = SqlModel(name="test_model_a", query=d.parse_one("SELECT 1 AS a")) - model_a._path = Path("/path/to/test_model_a.sql") - models[model_a.fqn] = model_a - - model_b = SqlModel(name="test_model_b", query=d.parse_one("SELECT 2 AS b")) - model_b._path = Path("/path/to/test_model_b.sql") - models[model_b.fqn] = model_b - - model_c = SqlModel( - name="test_model_c", - query=d.parse_one("SELECT b AS c FROM test_model_b"), - depends_on={"test_model_b"}, - ) - model_c._path = Path("/path/to/test_model_c.sql") - models[model_c.fqn] = model_c - - model_d = SqlModel( - name="test_model_d", - query=d.parse_one("SELECT c FROM test_model_c"), - depends_on={"test_model_c"}, - ) - model_d._path = Path("/path/to/test_model_d.sql") - models[model_d.fqn] = model_d - - git_client_mock = mocker.Mock() - git_client_mock.list_untracked_files.return_value = [] - git_client_mock.list_uncommitted_changed_files.return_value = [] - git_client_mock.list_committed_changed_files.return_value = [model_a._path, model_c._path] - - selector = Selector(mocker.Mock(), 
models) - selector._git_client = git_client_mock - - assert selector.expand_model_selections(expressions) == expected_fqns - - git_client_mock.list_committed_changed_files.assert_called_once_with(target_branch="main") - git_client_mock.list_uncommitted_changed_files.assert_called_once() - git_client_mock.list_untracked_files.assert_called_once() - - -def test_select_models_with_external_parent(mocker: MockerFixture): - default_catalog = "test_catalog" - added_model = SqlModel( - name="db.added_model", - query=d.parse_one("SELECT 1 AS a FROM external"), - default_catalog=default_catalog, - tags=["tag1"], - ) - - env_name = "test_env" - - state_reader_mock = mocker.Mock() - state_reader_mock.get_environment.return_value = Environment( - name=env_name, - snapshots=[], - start_at="2023-01-01", - end_at="2023-02-01", - plan_id="test_plan_id", - ) - state_reader_mock.get_snapshots.return_value = {} - - local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") - local_models[added_model.fqn] = added_model - - selector = Selector(state_reader_mock, local_models, default_catalog=default_catalog) - - expanded_selections = selector.expand_model_selections(["+*added_model*"]) - assert expanded_selections == {added_model.fqn} - - -def _assert_models_equal(actual: t.Dict[str, Model], expected: t.Dict[str, Model]) -> None: - assert set(actual) == set(expected) - for name, model in actual.items(): - # Use dict() to make Pydantic V2 happy. 
- assert model.dict() == expected[name].dict() diff --git a/tests/core/test_selector_dbt.py b/tests/core/test_selector_dbt.py new file mode 100644 index 0000000000..112c5740ac --- /dev/null +++ b/tests/core/test_selector_dbt.py @@ -0,0 +1,63 @@ +import typing as t +import pytest +from pytest_mock import MockerFixture +from sqlglot import exp +from sqlmesh.core.model.kind import SeedKind, ExternalKind, FullKind +from sqlmesh.core.model.seed import Seed +from sqlmesh.core.model.definition import SqlModel, SeedModel, ExternalModel +from sqlmesh.core.audit.definition import StandaloneAudit +from sqlmesh.core.snapshot.definition import Node +from sqlmesh.core.selector import DbtSelector +from sqlmesh.core.selector import parse, ResourceType +from sqlmesh.utils.errors import SQLMeshError +import sqlmesh.core.dialect as d +from sqlmesh.utils import UniqueKeyDict + + +def test_parse_resource_type(): + assert parse("resource_type:foo") == ResourceType(this=exp.Var(this="foo")) + + +@pytest.mark.parametrize( + "resource_type,expected", + [ + ("model", {'"test"."normal_model"'}), + ("seed", {'"test"."seed_model"'}), + ("test", {'"test"."standalone_audit"'}), + ("source", {'"external"."model"'}), + ], +) +def test_expand_model_selections_resource_type( + mocker: MockerFixture, resource_type: str, expected: t.Set[str] +): + models: t.Dict[str, Node] = { + '"test"."normal_model"': SqlModel( + name="test.normal_model", + kind=FullKind(), + query=d.parse_one("SELECT 'normal_model' AS what"), + ), + '"test"."seed_model"': SeedModel( + name="test.seed_model", kind=SeedKind(path="/tmp/foo"), seed=Seed(content="id,name") + ), + '"test"."standalone_audit"': StandaloneAudit( + name="test.standalone_audit", query=d.parse_one("SELECT 'standalone_audit' AS what") + ), + '"external"."model"': ExternalModel(name="external.model", kind=ExternalKind()), + } + + selector = DbtSelector(state_reader=mocker.Mock(), models=UniqueKeyDict("models")) + + assert 
selector.expand_model_selections([f"resource_type:{resource_type}"], models) == expected + + +def test_unsupported_resource_type(mocker: MockerFixture): + selector = DbtSelector(state_reader=mocker.Mock(), models=UniqueKeyDict("models")) + + models: t.Dict[str, Node] = { + '"test"."normal_model"': SqlModel( + name="test.normal_model", query=d.parse_one("SELECT 'normal_model' AS what") + ), + } + + with pytest.raises(SQLMeshError, match="Unsupported"): + selector.expand_model_selections(["resource_type:analysis"], models) diff --git a/tests/core/test_selector_native.py b/tests/core/test_selector_native.py new file mode 100644 index 0000000000..5889efadda --- /dev/null +++ b/tests/core/test_selector_native.py @@ -0,0 +1,808 @@ +from __future__ import annotations + +import typing as t +from pathlib import Path +from unittest.mock import call + +import pytest +from pytest_mock.plugin import MockerFixture +import subprocess + +from sqlmesh.core import dialect as d +from sqlmesh.core.audit import StandaloneAudit +from sqlmesh.core.environment import Environment +from sqlmesh.core.model import Model, SqlModel +from sqlmesh.core.model.common import ParsableSql +from sqlmesh.core.selector import NativeSelector +from sqlmesh.core.snapshot import SnapshotChangeCategory +from sqlmesh.utils import UniqueKeyDict +from sqlmesh.utils.date import now_timestamp +from sqlmesh.utils.git import GitClient + + +@pytest.mark.parametrize( + "default_catalog", + [ + None, + "test_catalog", + ], +) +def test_select_models(mocker: MockerFixture, make_snapshot, default_catalog: t.Optional[str]): + added_model = SqlModel( + name="db.added_model", + query=d.parse_one("SELECT 1 AS a"), + default_catalog=default_catalog, + tags=["tag1"], + ) + modified_model_v1 = SqlModel( + name="db.modified_model", + query=d.parse_one("SELECT a + 1 FROM db.added_model"), + default_catalog=default_catalog, + tags=["tag2"], + ) + modified_model_v2 = SqlModel( + name="db.modified_model", + query=d.parse_one("SELECT 
a + 2 FROM db.added_model"), + default_catalog=default_catalog, + tags=["tag2"], + ) + removed_model = SqlModel( + name="db.removed_model", + query=d.parse_one("SELECT a FROM db.added_model"), + default_catalog=default_catalog, + ) + standalone_audit = StandaloneAudit( + name="test_audit", query=d.parse_one("SELECT * FROM added_model WHERE a IS NULL") + ) + + modified_model_v1_snapshot = make_snapshot(modified_model_v1) + modified_model_v1_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + removed_model_snapshot = make_snapshot(removed_model) + removed_model_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + standalone_audit_snapshot = make_snapshot(standalone_audit) + standalone_audit_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + env_name = "test_env" + + state_reader_mock = mocker.Mock() + state_reader_mock.get_environment.return_value = Environment( + name=env_name, + snapshots=[ + s.table_info + for s in (modified_model_v1_snapshot, removed_model_snapshot, standalone_audit_snapshot) + ], + start_at="2023-01-01", + end_at="2023-02-01", + plan_id="test_plan_id", + ) + state_reader_mock.get_snapshots.return_value = { + modified_model_v1_snapshot.snapshot_id: modified_model_v1_snapshot, + removed_model_snapshot.snapshot_id: removed_model_snapshot, + standalone_audit_snapshot.snapshot_id: standalone_audit_snapshot, + } + + added_model_schema = {'"db"': {'"added_model"': {"a": "INT"}}} + if default_catalog: + added_model_schema = {f'"{default_catalog}"': added_model_schema} # type: ignore + + local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") + local_models[added_model.fqn] = added_model + local_models[modified_model_v2.fqn] = modified_model_v2.copy( + update={"mapping_schema": added_model_schema} + ) + selector = NativeSelector(state_reader_mock, local_models, default_catalog=default_catalog) + + _assert_models_equal( + selector.select_models(["db.added_model"], env_name), + { + added_model.fqn: added_model, + 
modified_model_v1.fqn: modified_model_v1.copy( + update={"mapping_schema": added_model_schema} + ), + removed_model.fqn: removed_model, + }, + ) + _assert_models_equal( + selector.select_models(["db.modified_model"], "missing_env", fallback_env_name=env_name), + { + modified_model_v2.fqn: modified_model_v2, + removed_model.fqn: removed_model, + }, + ) + _assert_models_equal( + selector.select_models(["db.removed_model"], env_name), + { + modified_model_v1.fqn: modified_model_v1, + }, + ) + _assert_models_equal( + selector.select_models( + ["db.added_model", "db.modified_model"], "missing_env", fallback_env_name=env_name + ), + { + added_model.fqn: added_model, + modified_model_v2.fqn: modified_model_v2.copy( + update={"mapping_schema": added_model_schema} + ), + removed_model.fqn: removed_model, + }, + ) + _assert_models_equal( + selector.select_models(["+db.modified_model"], env_name), + { + added_model.fqn: added_model, + modified_model_v2.fqn: modified_model_v2.copy( + update={"mapping_schema": added_model_schema} + ), + removed_model.fqn: removed_model, + }, + ) + _assert_models_equal( + selector.select_models(["db.added_model+"], env_name), + { + added_model.fqn: added_model, + modified_model_v2.fqn: modified_model_v2.copy( + update={"mapping_schema": added_model_schema} + ), + removed_model.fqn: removed_model, + }, + ) + _assert_models_equal( + selector.select_models( + ["db.added_model", "db.modified_model", "db.removed_model"], env_name + ), + local_models, + ) + _assert_models_equal( + selector.select_models(["*_model", "db.removed_model"], env_name), + local_models, + ) + _assert_models_equal( + selector.select_models(["tag:tag1", "tag:tag2"], "missing_env", fallback_env_name=env_name), + { + added_model.fqn: added_model, + modified_model_v2.fqn: modified_model_v2.copy( + update={"mapping_schema": added_model_schema} + ), + removed_model.fqn: removed_model, + }, + ) + _assert_models_equal( + selector.select_models(["tag:tag*"], "missing_env", 
fallback_env_name=env_name), + { + added_model.fqn: added_model, + modified_model_v2.fqn: modified_model_v2.copy( + update={"mapping_schema": added_model_schema} + ), + removed_model.fqn: removed_model, + }, + ) + _assert_models_equal( + selector.select_models(["+tag:tag2"], env_name), + { + added_model.fqn: added_model, + modified_model_v2.fqn: modified_model_v2.copy( + update={"mapping_schema": added_model_schema} + ), + removed_model.fqn: removed_model, + }, + ) + + +def test_select_models_expired_environment(mocker: MockerFixture, make_snapshot): + modified_model_v1 = SqlModel( + name="db.modified_model", + query=d.parse_one("SELECT a + 1 FROM db.added_model"), + ) + modified_model_v2 = SqlModel( + name="db.modified_model", + query=d.parse_one("SELECT a + 2 FROM db.added_model"), + ) + removed_model = SqlModel( + name="db.removed_model", + query=d.parse_one("SELECT a FROM db.added_model"), + ) + standalone_audit = StandaloneAudit( + name="test_audit", query=d.parse_one("SELECT * FROM added_model WHERE a IS NULL") + ) + + modified_model_v1_snapshot = make_snapshot(modified_model_v1) + modified_model_v1_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + removed_model_snapshot = make_snapshot(removed_model) + removed_model_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + standalone_audit_snapshot = make_snapshot(standalone_audit) + standalone_audit_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + prod_env = Environment( + name="prod", + snapshots=[modified_model_v1_snapshot.table_info], + start_at="2023-01-01", + end_at="2023-02-01", + plan_id="test_plan_id", + ) + + env_name = "test_env" + dev_env = Environment( + name=env_name, + snapshots=[modified_model_v1_snapshot.table_info, removed_model_snapshot.table_info], + start_at="2023-01-01", + end_at="2023-02-01", + plan_id="test_plan_id", + ) + + state_reader_mock = mocker.Mock() + state_reader_mock.get_environment.side_effect = ( + lambda name: prod_env if name == "prod" else dev_env + 
) + + all_snapshots = { + modified_model_v1_snapshot.snapshot_id: modified_model_v1_snapshot, + removed_model_snapshot.snapshot_id: removed_model_snapshot, + } + state_reader_mock.get_snapshots.side_effect = lambda snapshots: { + s.snapshot_id: all_snapshots[s.snapshot_id] for s in snapshots + } + + local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") + local_models[modified_model_v2.fqn] = modified_model_v2 + selector = NativeSelector(state_reader_mock, local_models) + + _assert_models_equal( + selector.select_models(["*.modified_model"], env_name, fallback_env_name="prod"), + { + removed_model.fqn: removed_model, + modified_model_v2.fqn: modified_model_v2, + }, + ) + + dev_env.expiration_ts = now_timestamp() - 1 + _assert_models_equal( + selector.select_models(["*.modified_model"], env_name, fallback_env_name="prod"), + { + modified_model_v2.fqn: modified_model_v2, + }, + ) + + +def test_select_change_schema(mocker: MockerFixture, make_snapshot): + parent = SqlModel( + name="db.parent", + query=d.parse_one("SELECT 1 AS a"), + ) + child = SqlModel( + name="db.child", + query=d.parse_one("SELECT * FROM db.parent"), + mapping_schema={'"db"': {'"parent"': {"a": "INT"}}}, + ) + + parent_snapshot = make_snapshot(parent) + parent_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + child_snapshot = make_snapshot(child) + child_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + env_name = "test_env" + + state_reader_mock = mocker.Mock() + state_reader_mock.get_environment.return_value = Environment( + name=env_name, + snapshots=[s.table_info for s in (parent_snapshot, child_snapshot)], + start_at="2023-01-01", + end_at="2023-02-01", + plan_id="test_plan_id", + ) + state_reader_mock.get_snapshots.return_value = { + parent_snapshot.snapshot_id: parent_snapshot, + child_snapshot.snapshot_id: child_snapshot, + } + + local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") + local_parent = parent.copy( + update={ + "query_": ParsableSql( + 
sql=parent.query.select("2 as b", append=False).sql(dialect=parent.dialect) # type: ignore + ) + } + ) + local_models[local_parent.fqn] = local_parent + local_child = child.copy(update={"mapping_schema": {'"db"': {'"parent"': {"b": "INT"}}}}) + local_models[local_child.fqn] = local_child + + selector = NativeSelector(state_reader_mock, local_models) + + selected = selector.select_models(["db.parent"], env_name) + assert selected[local_child.fqn].render_query() != child.render_query() + + _assert_models_equal( + selected, + { + local_parent.fqn: local_parent, + local_child.fqn: local_child, + }, + ) + + selected = selector.select_models(["db.child"], env_name) + assert selected[local_child.fqn].data_hash == child.data_hash + + _assert_models_equal( + selected, + { + parent.fqn: parent, + child.fqn: child, + }, + ) + + +def test_select_models_missing_env(mocker: MockerFixture, make_snapshot): + model = SqlModel(name="test_model", query=d.parse_one("SELECT 1 AS a")) + + state_reader_mock = mocker.Mock() + state_reader_mock.get_environment.return_value = None + + local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") + local_models[model.fqn] = model + + selector = NativeSelector(state_reader_mock, local_models) + + assert selector.select_models([model.name], "missing_env").keys() == {model.fqn} + assert not selector.select_models(["missing"], "missing_env") + + assert selector.select_models( + [model.name], "missing_env", fallback_env_name="another_missing_env" + ).keys() == {model.fqn} + + state_reader_mock.get_environment.assert_has_calls( + [ + call("missing_env"), + call("missing_env"), + call("missing_env"), + call("another_missing_env"), + ] + ) + + +@pytest.mark.parametrize( + "model_defs, selections, output", + [ + # Direct matching only + ( + [("model1", "tag1", None), ("model2", "tag2", None), ("model3", "tag3", None)], + ["tag:tag1", "tag:tag3"], + {'"model1"', '"model3"'}, + ), + # Wildcard works + ( + [("model1", "tag1", None), ("model2", 
"tag2", None), ("model3", "tag3", None)], + ["tag:tag*"], + {'"model1"', '"model2"', '"model3"'}, + ), + # Downstream models are included + ( + [("model1", "tag1", None), ("model2", "tag2", {"model1"}), ("model3", "tag3", None)], + ["tag:tag1+"], + {'"model1"', '"model2"'}, + ), + # Upstream models are included + ( + [("model1", "tag1", None), ("model2", "tag2", None), ("model3", "tag3", {"model2"})], + ["+tag:tag3"], + {'"model2"', '"model3"'}, + ), + # Upstream and downstream models are included + ( + [ + ("model1", "tag1", None), + ("model2", "tag2", {"model1"}), + ("model3", "tag3", {"model2"}), + ], + ["+tag:tag2+"], + {'"model1"', '"model2"', '"model3"'}, + ), + # Wildcard works with upstream and downstream models + ( + [ + ("model1", "tag1", None), + ("model2", "tag2", {"model1"}), + ("model3", "tag3", {"model2"}), + ("model4", "blah", {"model3"}), + ("model5", "tag4", None), + # Only excluded model since it doesn't match wildcard nor upstream/downstream + ("model6", "blah", None), + ], + ["+tag:tag*+"], + {'"model1"', '"model2"', '"model3"', '"model4"', '"model5"'}, + ), + # Multiple tags work + ( + [ + ("model1", "tag1", None), + ("model2", "tag2", None), + ("model3", "tag3", None), + ("model4", "tag4", None), + ], + ["tag:tag1", "tag:tag3"], + {'"model1"', '"model3"'}, + ), + # Multiple tags work with upstream and downstream models + ( + [ + ("model1", "tag1", None), + ("model2", "tag2", {"model1"}), + ("model3", "tag3", {"model2"}), + ("model4", "tag4", None), + ("model5", "tag5", {"model4"}), + ("model6", "tag6", {"model5"}), + ], + ["+tag:tag3", "tag:tag5"], + {'"model1"', '"model2"', '"model3"', '"model5"'}, + ), + # Case-insensitive matching + ( + [ + ("model1", "tag1", None), + ("model2", "tag2", None), + ("model3", "tag3", None), + ], + ["tag:TAG*"], + {'"model1"', '"model2"', '"model3"'}, + ), + # Wildcard returns everything + ( + [ + ("model1", "tag1", None), + ("model2", "tag2", None), + ("model3", "tag3", None), + ], + ["tag:*"], + {'"model1"', 
'"model2"', '"model3"'}, + ), + # Upstream that don't exist is fine + ( + [ + ("model1", "tag1", None), + ("model2", "tag2", None), + ], + ["+tag:tag2"], + {'"model2"'}, + ), + # No matches returns empty set + ( + [ + ("model1", "tag1", None), + ("model2", "tag2", None), + ], + ["+tag:tag3*+", "+tag:tag3+"], + set(), + ), + # Mix of models and tags + ( + [ + ("model1", "tag1", None), + ("model2", "tag2", None), + ("model3", "tag3", None), + ], + ["tag:tag1", "model2"], + {'"model1"', '"model2"'}, + ), + # Intersection of tags and model names + ( + [ + ("model1", "tag1", None), + ("model2", "tag1", {"model1"}), + ("model3", "tag2", {"model1"}), + ("model4", "tag1", None), + ], + ["tag:tag1 & model1+"], + {'"model1"', '"model2"'}, + ), + # Intersection of tags and model names (order doesn't matter) + ( + [ + ("model1", "tag1", None), + ("model2", "tag1", {"model1"}), + ("model3", "tag2", {"model1"}), + ("model4", "tag1", None), + ], + ["model1+ & tag:tag1"], + {'"model1"', '"model2"'}, + ), + # negation + ( + [("model1", "tag1", None), ("model2", "tag2", None), ("model3", "tag3", None)], + ["^tag:tag1"], + {'"model2"', '"model3"'}, + ), + ( + [("model1", "tag1", None), ("model2", "tag2", None), ("model3", "tag3", None)], + ["^model1"], + {'"model2"', '"model3"'}, + ), + ( + [("model1", "tag1", None), ("model2", "tag2", None), ("model3", "tag3", None)], + ["model* & ^(tag:tag1 | tag:tag2)"], + {'"model3"'}, + ), + ( + [ + ("model1", "tag1", None), + ("model2", "tag2", {"model1"}), + ("model3", "tag3", {"model1"}), + ], + ["(model1*)+"], + {'"model1"', '"model2"', '"model3"'}, + ), + ( + [ + ("model1", "tag1", None), + ("model2", "tag2", {"model1"}), + ("model3", "tag3", {"model2"}), + ], + ["+(+model2*+)+"], + {'"model1"', '"model2"', '"model3"'}, + ), + ( + [ + ("model1", "tag1", None), + ("model2", "tag2", {"model1"}), + ("model3", "tag3", {"model1"}), + ], + ["(model* & ^*1)+"], + {'"model2"', '"model3"'}, + ), + ( + [("model2", "tag1", None), ("model2_1", "tag2", 
None), ("model2_2", "tag3", None)], + ["*2_*"], + {'"model2_1"', '"model2_2"'}, + ), + ], +) +def test_expand_model_selections( + mocker: MockerFixture, make_snapshot, model_defs, selections, output +): + models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") + for model_name, tag, depends_on in model_defs: + model = SqlModel( + name=model_name, query=d.parse_one("SELECT 1 AS a"), depends_on=depends_on, tags=[tag] + ) + models[model.fqn] = model + + selector = NativeSelector(mocker.Mock(), models) + assert selector.expand_model_selections(selections) == output + + +def test_model_selection_normalized(mocker: MockerFixture, make_snapshot): + models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") + model = SqlModel( + name="`db.test_Model`", + query=d.parse_one("SELECT 1 AS a"), + tags=["tag1"], + dialect="bigquery", + ) + models[model.fqn] = model + selector = NativeSelector(mocker.Mock(), models, dialect="bigquery") + assert selector.expand_model_selections(["db.test_Model"]) == {'"db"."test_Model"'} + + +@pytest.mark.parametrize( + "expressions, expected_fqns", + [ + (["git:main"], {'"test_model_a"', '"test_model_c"'}), + (["git:main & +*model_c"], {'"test_model_c"'}), + (["git:main+"], {'"test_model_a"', '"test_model_c"', '"test_model_d"'}), + (["+git:main"], {'"test_model_a"', '"test_model_c"', '"test_model_b"'}), + (["+git:main+"], {'"test_model_a"', '"test_model_c"', '"test_model_b"', '"test_model_d"'}), + ], +) +def test_expand_git_selection( + mocker: MockerFixture, expressions: t.List[str], expected_fqns: t.Set[str] +): + models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") + + model_a = SqlModel(name="test_model_a", query=d.parse_one("SELECT 1 AS a")) + model_a._path = Path("/path/to/test_model_a.sql") + models[model_a.fqn] = model_a + + model_b = SqlModel(name="test_model_b", query=d.parse_one("SELECT 2 AS b")) + model_b._path = Path("/path/to/test_model_b.sql") + models[model_b.fqn] = model_b + + model_c = SqlModel( + 
name="test_model_c", + query=d.parse_one("SELECT b AS c FROM test_model_b"), + depends_on={"test_model_b"}, + ) + model_c._path = Path("/path/to/test_model_c.sql") + models[model_c.fqn] = model_c + + model_d = SqlModel( + name="test_model_d", + query=d.parse_one("SELECT c FROM test_model_c"), + depends_on={"test_model_c"}, + ) + model_d._path = Path("/path/to/test_model_d.sql") + models[model_d.fqn] = model_d + + git_client_mock = mocker.Mock() + git_client_mock.list_untracked_files.return_value = [] + git_client_mock.list_uncommitted_changed_files.return_value = [] + git_client_mock.list_committed_changed_files.return_value = [model_a._path, model_c._path] + + selector = NativeSelector(mocker.Mock(), models) + selector._git_client = git_client_mock + + assert selector.expand_model_selections(expressions) == expected_fqns + + git_client_mock.list_committed_changed_files.assert_called_once_with(target_branch="main") + git_client_mock.list_uncommitted_changed_files.assert_called_once() + git_client_mock.list_untracked_files.assert_called_once() + + +def test_expand_git_selection_integration(tmp_path: Path, mocker: MockerFixture): + repo_path = tmp_path / "test_repo" + repo_path.mkdir() + subprocess.run(["git", "init", "-b", "main"], cwd=repo_path, check=True, capture_output=True) + + models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") + model_a_path = repo_path / "model_a.sql" + model_a_path.write_text("SELECT 1 AS a") + model_a = SqlModel(name="test_model_a", query=d.parse_one("SELECT 1 AS a")) + model_a._path = model_a_path + models[model_a.fqn] = model_a + + model_b_path = repo_path / "model_b.sql" + model_b_path.write_text("SELECT 2 AS b") + model_b = SqlModel(name="test_model_b", query=d.parse_one("SELECT 2 AS b")) + model_b._path = model_b_path + models[model_b.fqn] = model_b + + subprocess.run(["git", "add", "."], cwd=repo_path, check=True, capture_output=True) + subprocess.run( + [ + "git", + "-c", + "user.name=Max", + "-c", + "user.email=max@rb.com", 
+ "commit", + "-m", + "Initial commit", + ], + cwd=repo_path, + check=True, + capture_output=True, + ) + + # no changes should select nothing + git_client = GitClient(repo_path) + selector = NativeSelector(mocker.Mock(), models) + selector._git_client = git_client + assert selector.expand_model_selections([f"git:main"]) == set() + + # modify A but dont stage it, should be only selected + model_a_path.write_text("SELECT 10 AS a") + assert selector.expand_model_selections([f"git:main"]) == {'"test_model_a"'} + + # stage model A, should still select it + subprocess.run(["git", "add", "model_a.sql"], cwd=repo_path, check=True, capture_output=True) + assert selector.expand_model_selections([f"git:main"]) == {'"test_model_a"'} + + # now add unstaged change to B and both should be selected + model_b_path.write_text("SELECT 20 AS b") + assert selector.expand_model_selections([f"git:main"]) == { + '"test_model_a"', + '"test_model_b"', + } + + subprocess.run( + ["git", "checkout", "-b", "dev"], + cwd=repo_path, + check=True, + capture_output=True, + ) + + subprocess.run( + [ + "git", + "-c", + "user.name=Max", + "-c", + "user.email=max@rb.com", + "commit", + "-m", + "Update model_a", + ], + cwd=repo_path, + check=True, + capture_output=True, + ) + + # now A is committed in the dev branch and B unstaged but should both be selected + assert selector.expand_model_selections([f"git:main"]) == { + '"test_model_a"', + '"test_model_b"', + } + + +def test_select_models_with_external_parent(mocker: MockerFixture): + default_catalog = "test_catalog" + added_model = SqlModel( + name="db.added_model", + query=d.parse_one("SELECT 1 AS a FROM external"), + default_catalog=default_catalog, + tags=["tag1"], + ) + + env_name = "test_env" + + state_reader_mock = mocker.Mock() + state_reader_mock.get_environment.return_value = Environment( + name=env_name, + snapshots=[], + start_at="2023-01-01", + end_at="2023-02-01", + plan_id="test_plan_id", + ) + 
state_reader_mock.get_snapshots.return_value = {} + + local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") + local_models[added_model.fqn] = added_model + + selector = NativeSelector(state_reader_mock, local_models, default_catalog=default_catalog) + + expanded_selections = selector.expand_model_selections(["+*added_model*"]) + assert expanded_selections == {added_model.fqn} + + +def test_select_models_local_tags_take_precedence_over_remote( + mocker: MockerFixture, make_snapshot: t.Callable +) -> None: + existing_model = SqlModel( + name="db.existing", + query=d.parse_one("SELECT 1 AS a"), + ) + + existing_snapshot = make_snapshot(existing_model) + existing_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + env_name = "test_env" + + state_reader_mock = mocker.Mock() + state_reader_mock.get_environment.return_value = Environment( + name=env_name, + snapshots=[existing_snapshot.table_info], + start_at="2023-01-01", + end_at="2023-02-01", + plan_id="test_plan_id", + ) + state_reader_mock.get_snapshots.return_value = { + existing_snapshot.snapshot_id: existing_snapshot + } + + local_models: UniqueKeyDict[str, Model] = UniqueKeyDict("models") + local_new = SqlModel( + name="db.new", + tags=["a"], + query=d.parse_one("SELECT 1 as a"), + ) + local_existing = existing_model.copy(update={"tags": ["a"]}) # type: ignore + local_models[local_existing.fqn] = local_existing + local_models[local_new.fqn] = local_new + + selector = NativeSelector(state_reader_mock, local_models) + + selected = selector.select_models(["tag:a"], env_name) + + # both should get selected because they both now have the 'a' tag locally, even though one exists in remote state without the 'a' tag + _assert_models_equal( + selected, + { + local_existing.fqn: local_existing, + local_new.fqn: local_new, + }, + ) + + +def _assert_models_equal(actual: t.Dict[str, Model], expected: t.Dict[str, Model]) -> None: + assert set(actual) == set(expected) + for name, model in actual.items(): + # 
Use dict() to make Pydantic V2 happy. + assert model.dict() == expected[name].dict() diff --git a/tests/core/test_snapshot.py b/tests/core/test_snapshot.py index 050f36620b..1acc6cc265 100644 --- a/tests/core/test_snapshot.py +++ b/tests/core/test_snapshot.py @@ -1,3 +1,4 @@ +import pickle import json import typing as t from copy import deepcopy @@ -5,10 +6,12 @@ from pathlib import Path import pytest +import time_machine from _pytest.monkeypatch import MonkeyPatch from pytest_mock.plugin import MockerFixture from sqlglot import exp, to_column +from sqlmesh.core import constants as c from sqlmesh.core.audit import StandaloneAudit from sqlmesh.core.config import ( AutoCategorizationMode, @@ -18,9 +21,11 @@ from sqlmesh.core.context import Context from sqlmesh.core.dialect import parse, parse_one from sqlmesh.core.environment import EnvironmentNamingInfo +from sqlmesh.core.macros import SQL from sqlmesh.core.model import ( FullKind, IncrementalByTimeRangeKind, + IncrementalByUniqueKeyKind, IncrementalUnmanagedKind, Model, Seed, @@ -29,26 +34,44 @@ SqlModel, create_seed_model, load_sql_based_model, + CustomKind, ) from sqlmesh.core.model.kind import TimeColumn, ModelKindName +from sqlmesh.core.node import IntervalUnit +from sqlmesh.core.signal import signal from sqlmesh.core.snapshot import ( DeployabilityIndex, QualifiedViewName, Snapshot, + SnapshotId, + SnapshotIdAndVersion, SnapshotChangeCategory, SnapshotFingerprint, + SnapshotIntervals, SnapshotTableInfo, earliest_start_date, fingerprint_from_node, has_paused_forward_only, missing_intervals, ) +from sqlmesh.core.snapshot.cache import SnapshotCache from sqlmesh.core.snapshot.categorizer import categorize_change -from sqlmesh.core.snapshot.definition import display_name +from sqlmesh.core.snapshot.definition import ( + apply_auto_restatements, + display_name, + get_next_model_interval_start, + check_ready_intervals, + _contiguous_intervals, + table_name, + TableNamingConvention, +) +from sqlmesh.core.config.common 
import VirtualEnvironmentMode from sqlmesh.utils import AttributeDict -from sqlmesh.utils.date import to_date, to_datetime, to_timestamp -from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.utils.date import DatetimeRanges, to_date, to_datetime, to_timestamp +from sqlmesh.utils.errors import SQLMeshError, SignalEvalError from sqlmesh.utils.jinja import JinjaMacroRegistry, MacroInfo +from sqlmesh.utils.hashing import md5 +from sqlmesh.core.console import get_console @pytest.fixture @@ -57,7 +80,7 @@ def parent_model(): name="parent.tbl", kind=dict(time_column="ds", name=ModelKindName.INCREMENTAL_BY_TIME_RANGE), dialect="spark", - query=parse_one("SELECT 1, ds"), + query="SELECT 1, ds", ) @@ -70,7 +93,7 @@ def model(): dialect="spark", cron="1 0 * * *", start="2020-01-01", - query=parse_one("SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl"), + query="SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl", ) @@ -100,8 +123,11 @@ def test_json(snapshot: Snapshot): "fingerprint": snapshot.fingerprint.dict(), "intervals": [], "dev_intervals": [], + "dev_table_suffix": "dev", + "pending_restatement_intervals": [], "node": { "audits": [], + "audit_definitions": {}, "clustered_by": [], "cron": "1 0 * * *", "kind": { @@ -110,18 +136,22 @@ def test_json(snapshot: Snapshot): "batch_size": 30, "forward_only": False, "on_destructive_change": "ERROR", + "on_additive_change": "ALLOW", + "partition_by_time_column": True, "disable_restatement": False, "dialect": "spark", }, "mapping_schema": {}, - "inline_audits": {}, "start": "2020-01-01", "dialect": "spark", "name": "name", "partitioned_by": [], "project": "", + "python_env": {}, "owner": "owner", - "query": "SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl", + "query": { + "sql": "SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl", + }, "jinja_macros": { "create_builtins_module": "sqlmesh.utils.jinja", "global_objs": {}, @@ -136,18 +166,82 @@ def test_json(snapshot: Snapshot): "allow_partials": False, "signals": [], "enabled": 
True, + "extract_dependencies_from_query": True, + "virtual_environment_mode": "full", + "grants_target_layer": "virtual", }, - "audits": [], "name": '"name"', "parents": [{"name": '"parent"."tbl"', "identifier": snapshot.parents[0].identifier}], "previous_versions": [], + "table_naming_convention": "schema_and_table", "updated_ts": 1663891973000, "version": snapshot.fingerprint.to_version(), "migrated": False, "unrestorable": False, + "forward_only": False, } +def test_json_with_grants(make_snapshot: t.Callable): + from sqlmesh.core.model.meta import GrantsTargetLayer + + model = SqlModel( + name="name", + kind=dict(time_column="ds", batch_size=30, name=ModelKindName.INCREMENTAL_BY_TIME_RANGE), + owner="owner", + dialect="spark", + cron="1 0 * * *", + start="2020-01-01", + query=parse_one("SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl"), + grants={"SELECT": ["role1", "role2"], "INSERT": ["role3"]}, + grants_target_layer=GrantsTargetLayer.VIRTUAL, + ) + snapshot = make_snapshot(model) + + json_str = snapshot.json() + json_data = json.loads(json_str) + assert ( + json_data["node"]["grants"] + == "('SELECT' = ARRAY('role1', 'role2'), 'INSERT' = ARRAY('role3'))" + ) + assert json_data["node"]["grants_target_layer"] == "virtual" + + reparsed_snapshot = Snapshot.model_validate_json(json_str) + assert isinstance(reparsed_snapshot.node, SqlModel) + assert reparsed_snapshot.node.grants == {"SELECT": ["role1", "role2"], "INSERT": ["role3"]} + assert reparsed_snapshot.node.grants_target_layer == GrantsTargetLayer.VIRTUAL + + +def test_json_custom_materialization(make_snapshot: t.Callable): + model = SqlModel( + name="name", + kind=dict(name=ModelKindName.CUSTOM, materialization="non_existent_should_still_work"), + owner="owner", + dialect="spark", + cron="1 0 * * *", + start="2020-01-01", + query="SELECT @EACH([1, 2], x -> x), ds FROM parent.tbl", + ) + + snapshot = make_snapshot( + model, + nodes={model.fqn: model}, + ) + snapshot.version = 
snapshot.fingerprint.to_version() + + # this should not throw an error even though the 'non_existent_should_still_work' custom materialization doesnt exist + # this is so we can always deserialize a snapshot based on a custom materialization without the custom materialization class being loaded + new_snapshot = Snapshot.model_validate_json(snapshot.json()) + assert new_snapshot == snapshot + assert isinstance(new_snapshot.model.kind, CustomKind) + assert new_snapshot.model.kind.materialization == "non_existent_should_still_work" + assert new_snapshot.model.kind.materialization_properties == {} + + # this, however, should throw an error + with pytest.raises(SQLMeshError, match=r"Materialization strategy.*was not found"): + new_snapshot.model.validate_definition() + + def test_add_interval(snapshot: Snapshot, make_snapshot): with pytest.raises(ValueError): snapshot.add_interval("2020-01-02", "2020-01-01") @@ -207,7 +301,9 @@ def test_add_interval(snapshot: Snapshot, make_snapshot): def test_add_interval_dev(snapshot: Snapshot, make_snapshot): snapshot.version = "existing_version" - snapshot.change_category = SnapshotChangeCategory.FORWARD_ONLY + snapshot.dev_version_ = "existing_dev_version" + snapshot.change_category = SnapshotChangeCategory.BREAKING + snapshot.forward_only = True snapshot.add_interval("2020-01-01", "2020-01-01") assert snapshot.intervals == [(to_timestamp("2020-01-01"), to_timestamp("2020-01-02"))] @@ -222,8 +318,7 @@ def test_add_interval_dev(snapshot: Snapshot, make_snapshot): assert new_snapshot.dev_intervals == [] new_snapshot = make_snapshot(snapshot.model) - new_snapshot.previous_versions = snapshot.all_versions - new_snapshot.migrated = True + new_snapshot.dev_version_ = snapshot.dev_version new_snapshot.merge_intervals(snapshot) assert new_snapshot.intervals == [(to_timestamp("2020-01-01"), to_timestamp("2020-01-02"))] assert new_snapshot.dev_intervals == [(to_timestamp("2020-01-02"), to_timestamp("2020-01-03"))] @@ -253,6 +348,48 @@ def 
test_add_interval_partial(snapshot: Snapshot, make_snapshot): ] +@time_machine.travel("2023-01-01 01:00:00 UTC") +def test_get_next_model_interval_start(make_snapshot): + hourly_snapshot = make_snapshot( + SqlModel( + name="late", + kind=FullKind(), + query=parse_one("SELECT 1, ds FROM name"), + cron="@hourly", + interval_unit=IntervalUnit.HALF_HOUR, + ) + ) + + daily_snapshot = make_snapshot( + SqlModel( + name="early", kind=FullKind(), query=parse_one("SELECT 1, ds FROM name"), cron="@daily" + ) + ) + + seed_snapshot = make_snapshot( + SeedModel( + name="seed", + kind=SeedKind(path="./path/to/seed"), + seed=Seed(content="content"), + column_hashes={"col": "hash"}, + depends_on=set(), + interval_unit=IntervalUnit.FIVE_MINUTE, + ) + ) + + audit_snapshot = make_snapshot( + StandaloneAudit( + name="test_standalone_audit", + query=parse_one("SELECT 1"), + interval_unit=IntervalUnit.FIVE_MINUTE, + ) + ) + + assert get_next_model_interval_start( + [daily_snapshot, hourly_snapshot, seed_snapshot, audit_snapshot] + ) == to_datetime("2023-01-01 02:00:00 UTC") + + def test_missing_intervals(snapshot: Snapshot): snapshot.add_interval("2020-01-01", "2020-01-01") snapshot.add_interval("2020-01-03", "2020-01-05") @@ -278,6 +415,7 @@ def test_missing_intervals(snapshot: Snapshot): assert snapshot.missing_intervals("2020-01-03 00:00:01", "2020-01-05 00:00:02") == [] assert snapshot.missing_intervals("2020-01-03 00:00:01", "2020-01-07 00:00:02") == [ (to_timestamp("2020-01-06"), to_timestamp("2020-01-07")), + (to_timestamp("2020-01-07"), to_timestamp("2020-01-08")), ] @@ -296,17 +434,26 @@ def test_missing_intervals_partial(make_snapshot): start = "2023-01-01" end_ts = to_timestamp(start) + 1000 assert snapshot.missing_intervals(start, end_ts) == [ - (to_timestamp(start), to_timestamp("2023-01-02")), - ] - assert snapshot.missing_intervals(start, end_ts, execution_time=end_ts) == [ (to_timestamp(start), end_ts), ] + assert snapshot.missing_intervals(start, end_ts, 
execution_time=end_ts) == [] + assert snapshot.missing_intervals(start, end_ts, execution_time=end_ts, ignore_cron=True) == [ + (to_timestamp(start), end_ts) + ] + assert snapshot.missing_intervals(start, end_ts, execution_time="2023-01-02") == [ + (to_timestamp(start), end_ts) + ] assert snapshot.missing_intervals(start, start) == [ (to_timestamp(start), to_timestamp("2023-01-02")), ] assert snapshot.missing_intervals(start, start, execution_time=start, ignore_cron=True) == [] assert snapshot.missing_intervals(start, start, execution_time=end_ts, end_bounded=True) == [] + assert snapshot.missing_intervals(start, to_timestamp("2023-01-02 12:00:00")) == [ + (to_timestamp(start), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-02 12:00:00")), + ] + def test_missing_intervals_end_bounded_with_lookback(make_snapshot): snapshot = make_snapshot( @@ -325,13 +472,11 @@ def test_missing_intervals_end_bounded_with_lookback(make_snapshot): snapshot.intervals = [(to_timestamp(start), to_timestamp(end))] execution_ts = to_timestamp("2023-01-03") - assert snapshot.missing_intervals(start, start, execution_time=execution_ts) == [ - (to_timestamp(start), to_timestamp(end)), + assert snapshot.missing_intervals(start, start, execution_time=execution_ts) == [] + assert snapshot.missing_intervals(start, end, execution_time=execution_ts) == [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), ] - assert ( - snapshot.missing_intervals(start, start, execution_time=execution_ts, end_bounded=True) - == [] - ) def test_missing_intervals_end_bounded_with_ignore_cron(make_snapshot): @@ -370,6 +515,129 @@ def test_missing_intervals_end_bounded_with_ignore_cron(make_snapshot): ] +def test_missing_intervals_past_end_date_with_lookback(make_snapshot): + snapshot: Snapshot = make_snapshot( # type: ignore + SqlModel( + name="test_model", + 
kind=IncrementalByTimeRangeKind(time_column=TimeColumn(column="ds"), lookback=2), + owner="owner", + cron="@daily", + query=parse_one("SELECT 1, ds FROM name"), + start="2023-01-01", + end="2023-01-05", # inclusive, equivalent to to_timestamp('2023-01-05 23:59:59.999999') + ) + ) + + start_time = to_timestamp("2023-01-01") + end_time = to_timestamp( + "2023-01-06" + ) # exclusive because to_timestamp() returns a timestamp and not a date + assert snapshot.inclusive_exclusive(snapshot.node.start, snapshot.node.end) == ( + start_time, + end_time, + ) + + # baseline - all intervals missing + assert snapshot.missing_intervals(start_time, end_time, execution_time=end_time) == [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), + (to_timestamp("2023-01-02"), to_timestamp("2023-01-03")), + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + ] + + # fully backfill model - no intervals missing + snapshot.add_interval(start_time, end_time) + + # even though lookback=2, because every interval has been filled, + # there should be no missing intervals + assert snapshot.missing_intervals(start_time, end_time, execution_time=end_time) == [] + + # however, when running for a new interval, this triggers lookback + # in this case, we remove the most recent interval (the one for 2023-01-05) to simulate it being new + # since lookback=2 days, this triggers missing intervals for 2023-01-03, 2023-01-04, 2023-01-05 + snapshot.remove_interval(interval=(to_timestamp("2023-01-05"), to_timestamp("2023-01-06"))) + assert snapshot.missing_intervals(start_time, end_time, execution_time=end_time) == [ + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + ] + + # put the interval we just removed back to make the model fully backfilled 
again + snapshot.add_interval(to_timestamp("2023-01-05"), to_timestamp("2023-01-06")) + assert snapshot.missing_intervals(start_time, end_time, execution_time=end_time) == [] + + # running on the end date + 1 day (2023-01-07) + # 2023-01-06 "would" run and since lookback=2 this pulls in 2023-01-04 and 2023-01-05 as well + # however, only 2023-01-04 and 2023-01-05 are within the model end date + end_time = to_timestamp("2023-01-07") + assert snapshot.missing_intervals(start_time, end_time, execution_time=end_time) == [ + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + ] + + # running on the end date + 2 days (2023-01-08) + # 2023-01-07 "would" run and since lookback=2 this pulls in 2023-01-06 and 2023-01-05 as well + # however, only 2023-01-05 is within the model end date + end_time = to_timestamp("2023-01-08") + assert snapshot.missing_intervals(start_time, end_time, execution_time=end_time) == [ + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")) + ] + + # running on the end date + 3 days (2023-01-09) + # no missing intervals because subtracting 2 days for lookback exceeds the models end date + end_time = to_timestamp("2023-01-09") + assert snapshot.missing_intervals(start_time, end_time, execution_time=end_time) == [] + + # running way in the future, no missing intervals because subtracting 2 days for lookback still exceeds the models end date + end_time = to_timestamp("2024-01-01") + assert snapshot.missing_intervals(start_time, end_time, execution_time=end_time) == [] + + +def test_missing_intervals_start_override_per_model(make_snapshot: t.Callable[..., Snapshot]): + snapshot = make_snapshot( + load_sql_based_model( + parse(""" + MODEL ( + name a, + kind FULL, + start '2023-02-01', + cron '@daily' + ); + SELECT 1; + """) + ), + version="a", + ) + + # base case - no override + assert list( + missing_intervals(execution_time="2023-02-08 00:05:07", snapshots=[snapshot]).values() + )[0] 
== [ + (to_timestamp("2023-02-01"), to_timestamp("2023-02-02")), + (to_timestamp("2023-02-02"), to_timestamp("2023-02-03")), + (to_timestamp("2023-02-03"), to_timestamp("2023-02-04")), + (to_timestamp("2023-02-04"), to_timestamp("2023-02-05")), + (to_timestamp("2023-02-05"), to_timestamp("2023-02-06")), + (to_timestamp("2023-02-06"), to_timestamp("2023-02-07")), + (to_timestamp("2023-02-07"), to_timestamp("2023-02-08")), + ] + + # with override, should use the overridden start date when calculating missing intervals + assert list( + missing_intervals( + start="1 day ago", + execution_time="2023-02-08 00:05:07", + snapshots=[snapshot], + start_override_per_model={snapshot.name: to_datetime("2023-02-05 00:00:00")}, + ).values() + )[0] == [ + (to_timestamp("2023-02-05"), to_timestamp("2023-02-06")), + (to_timestamp("2023-02-06"), to_timestamp("2023-02-07")), + (to_timestamp("2023-02-07"), to_timestamp("2023-02-08")), + ] + + def test_incremental_time_self_reference(make_snapshot): snapshot = make_snapshot( SqlModel( @@ -395,7 +663,7 @@ def test_incremental_time_self_reference(make_snapshot): ] -def test_lookback(snapshot: Snapshot, make_snapshot): +def test_lookback(make_snapshot): snapshot = make_snapshot( SqlModel( name="name", @@ -411,59 +679,75 @@ def test_lookback(snapshot: Snapshot, make_snapshot): ] snapshot.add_interval("2023-01-01", "2023-01-04") - assert snapshot.missing_intervals("2023-01-01", "2023-01-02") == [] - - snapshot.add_interval("2023-01-06", "2023-01-07") - assert snapshot.missing_intervals("2023-01-03", "2023-01-03") == [ + assert snapshot.missing_intervals("2023-01-01", "2023-01-04") == [] + assert snapshot.missing_intervals("2023-01-01", "2023-01-05") == [ (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - ] - assert snapshot.missing_intervals("2023-01-04", "2023-01-04") == [ (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), ] + + snapshot.add_interval("2023-01-06", 
"2023-01-07") + assert snapshot.missing_intervals("2023-01-03", "2023-01-03") == [] assert snapshot.missing_intervals("2023-01-05", "2023-01-05") == [ (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), ] - assert snapshot.missing_intervals("2023-01-03", "2023-01-05") == [ - (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), - (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + assert snapshot.missing_intervals("2023-01-06", "2023-01-07") == [] + assert snapshot.missing_intervals("2023-01-05", "2023-01-08") == [ (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), - ] - snapshot.add_interval("2023-01-05", "2023-01-05") - assert snapshot.missing_intervals("2023-01-03", "2023-01-06") == [ (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), - ] - - assert snapshot.missing_intervals("2023-01-29", "2023-01-29") == [ - (to_timestamp("2023-01-29"), to_timestamp("2023-01-30")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + (to_timestamp("2023-01-08"), to_timestamp("2023-01-09")), ] snapshot.add_interval("2023-01-28", "2023-01-29") assert snapshot.missing_intervals("2023-01-27", "2023-01-27", "2023-01-30 05:00:00") == [ (to_timestamp("2023-01-27"), to_timestamp("2023-01-28")), ] - assert snapshot.missing_intervals("2023-01-28", "2023-01-28", "2023-01-30 05:00:00") == [ - (to_timestamp("2023-01-28"), to_timestamp("2023-01-29")), - ] - assert snapshot.missing_intervals("2023-01-29", "2023-01-29", "2023-01-30 05:00:00") == [ - (to_timestamp("2023-01-29"), to_timestamp("2023-01-30")), - ] - assert snapshot.missing_intervals("2023-01-27", "2023-01-29", "2023-01-30 05:00:00") == [ - (to_timestamp("2023-01-27"), to_timestamp("2023-01-28")), + assert snapshot.missing_intervals("2023-01-28", "2023-01-29", "2023-01-31 05:00:00") == [] + assert snapshot.missing_intervals("2023-01-28", "2023-01-30", "2023-01-31 05:00:00") == [ (to_timestamp("2023-01-28"), to_timestamp("2023-01-29")), (to_timestamp("2023-01-29"), 
to_timestamp("2023-01-30")), + (to_timestamp("2023-01-30"), to_timestamp("2023-01-31")), ] - snapshot.add_interval("2023-01-28", "2023-01-30") - assert snapshot.missing_intervals("2023-01-27", "2023-01-27", "2023-01-30 05:00:00") == [ - (to_timestamp("2023-01-27"), to_timestamp("2023-01-28")), - ] - assert snapshot.missing_intervals("2023-01-28", "2023-01-28", "2023-01-30 05:00:00") == [] - assert snapshot.missing_intervals("2023-01-29", "2023-01-29", "2023-01-30 05:00:00") == [] - assert snapshot.missing_intervals("2023-01-27", "2023-01-29", "2023-01-30 05:00:00") == [ - (to_timestamp("2023-01-27"), to_timestamp("2023-01-28")), + assert snapshot.missing_intervals("2023-01-28", "2023-01-30", "2023-01-31 04:00:00") == [] + + +def test_lookback_custom_materialization(make_snapshot): + from sqlmesh import CustomMaterialization + + class MyTestStrategy(CustomMaterialization): + pass + + expressions = parse( + """ + MODEL ( + name name, + kind CUSTOM ( + materialization 'MyTestStrategy', + lookback 2 + ), + start '2023-01-01', + cron '0 5 * * *', + ); + + SELECT ds FROM parent.tbl + """ + ) + + snapshot = make_snapshot(load_sql_based_model(expressions)) + + assert snapshot.missing_intervals("2023-01-01", "2023-01-01") == [ + (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), ] - assert snapshot.missing_intervals("2023-01-30", "2023-01-30", "2023-01-30") == [] + snapshot.add_interval("2023-01-01", "2023-01-04") + assert snapshot.missing_intervals("2023-01-01", "2023-01-04") == [] + assert snapshot.missing_intervals("2023-01-01", "2023-01-05") == [ + (to_timestamp("2023-01-03"), to_timestamp("2023-01-04")), + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + ] def test_seed_intervals(make_snapshot): @@ -549,14 +833,14 @@ def test_missing_interval_smaller_than_interval_unit(make_snapshot): ) assert snapshot_partial.missing_intervals("2020-01-01 00:00:05", "2020-01-01 23:59:59") == [ - 
(to_timestamp("2020-01-01"), to_timestamp("2020-01-02")) + (to_timestamp("2020-01-01"), to_timestamp("2020-01-01 23:59:59")) ] assert snapshot_partial.missing_intervals("2020-01-01 00:00:00", "2020-01-02 00:00:00") == [ (to_timestamp("2020-01-01"), to_timestamp("2020-01-02")) ] -def test_remove_intervals(snapshot: Snapshot): +def test_remove_intervals(snapshot): snapshot.add_interval("2020-01-01", "2020-01-01") snapshot.remove_interval(snapshot.get_removal_interval("2020-01-01", "2020-01-01")) assert snapshot.intervals == [] @@ -635,6 +919,63 @@ def test_get_removal_intervals_full_history_restatement_model(make_snapshot): assert interval == (to_timestamp("2023-01-01"), execution_time) +def test_get_removal_intervals_warns_when_requested_range_automatically_widened( + make_snapshot: t.Callable[..., Snapshot], mocker: MockerFixture +): + mock_logger = mocker.patch.object(get_console(), "log_warning") + + # INCREMENTAL_BY_UNIQUE_KEY should warn + snapshot = make_snapshot( + SqlModel( + name="name", + kind=IncrementalByUniqueKeyKind(unique_key=[exp.to_column("id")]), + query=parse_one("select id from src"), + ) + ) + + assert not snapshot.intervals + assert snapshot.full_history_restatement_only + + snapshot.add_interval("2020-01-01", "2020-01-10") + + # should warn if requested intervals are a subset of actual intervals and thus are automatically expanded + snapshot.get_removal_interval("2020-01-05", "2020-01-06") + + msg = mock_logger.call_args[0][0] + assert "does not support partial restatement" in msg + assert "Expanding the requested restatement intervals" in msg + + # should not warn if requested intervals are equal to actual intervals + mock_logger.reset_mock() + + snapshot.get_removal_interval("2020-01-01", "2020-01-10") + mock_logger.assert_not_called() + + # should not warn if requested intervals are a superset of actual intervals + mock_logger.reset_mock() + + snapshot.get_removal_interval("2019-12-30", "2020-01-15") + mock_logger.assert_not_called() + + # 
should not warn on models that support partial restatement, such as INCREMENTAL_BY_TIME_RANGE + mock_logger.reset_mock() + snapshot = make_snapshot( + SqlModel( + name="name", + kind=IncrementalByTimeRangeKind(time_column=TimeColumn(column="ds")), + query=parse_one("select ds from src"), + ) + ) + + assert not snapshot.intervals + assert not snapshot.full_history_restatement_only + + snapshot.add_interval("2020-01-01", "2020-01-10") + + snapshot.get_removal_interval("2020-01-05", "2020-01-06") + mock_logger.assert_not_called() + + each_macro = lambda: "test" # noqa: E731 @@ -643,12 +984,11 @@ def test_fingerprint(model: Model, parent_model: Model): fingerprint = fingerprint_from_node(model, nodes={}) original_fingerprint = SnapshotFingerprint( - data_hash="3582214120", - metadata_hash="2793463216", + data_hash="2406542604", + metadata_hash="1056339358", ) assert fingerprint == original_fingerprint - with_parent_fingerprint = fingerprint_from_node(model, nodes={'"parent"."tbl"': parent_model}) assert with_parent_fingerprint != fingerprint assert int(with_parent_fingerprint.parent_data_hash) > 0 @@ -667,12 +1007,12 @@ def test_fingerprint(model: Model, parent_model: Model): new_fingerprint = fingerprint_from_node(model, nodes={}) assert new_fingerprint != fingerprint assert new_fingerprint.data_hash != fingerprint.data_hash - assert new_fingerprint.metadata_hash == fingerprint.metadata_hash + assert new_fingerprint.metadata_hash != fingerprint.metadata_hash model = SqlModel(**{**model.dict(), "query": parse_one("select 1, ds -- annotation")}) fingerprint = fingerprint_from_node(model, nodes={}) assert new_fingerprint != fingerprint - assert new_fingerprint.data_hash == fingerprint.data_hash + assert new_fingerprint.data_hash != fingerprint.data_hash assert new_fingerprint.metadata_hash != fingerprint.metadata_hash model = SqlModel( @@ -681,13 +1021,15 @@ def test_fingerprint(model: Model, parent_model: Model): fingerprint = fingerprint_from_node(model, nodes={}) 
assert new_fingerprint != fingerprint assert new_fingerprint.data_hash != fingerprint.data_hash - assert new_fingerprint.metadata_hash == fingerprint.metadata_hash + assert new_fingerprint.metadata_hash != fingerprint.metadata_hash + assert fingerprint.metadata_hash != original_fingerprint.metadata_hash model = SqlModel(**{**original_model.dict(), "post_statements": [parse_one("DROP TABLE test")]}) fingerprint = fingerprint_from_node(model, nodes={}) assert new_fingerprint != fingerprint assert new_fingerprint.data_hash != fingerprint.data_hash - assert new_fingerprint.metadata_hash == fingerprint.metadata_hash + assert new_fingerprint.metadata_hash != fingerprint.metadata_hash + assert fingerprint.metadata_hash != original_fingerprint.metadata_hash def test_fingerprint_seed_model(): @@ -703,8 +1045,8 @@ def test_fingerprint_seed_model(): ) expected_fingerprint = SnapshotFingerprint( - data_hash="2156038176", - metadata_hash="3403817841", + data_hash="2112858704", + metadata_hash="2674364560", ) model = load_sql_based_model(expressions, path=Path("./examples/sushi/models/test_model.sql")) @@ -742,12 +1084,13 @@ def test_fingerprint_jinja_macros(model: Model): } ) original_fingerprint = SnapshotFingerprint( - data_hash="2973224250", - metadata_hash="2793463216", + data_hash="93332825", + metadata_hash="1056339358", ) fingerprint = fingerprint_from_node(model, nodes={}) assert fingerprint == original_fingerprint + model = model.copy() model.jinja_macros.root_macros["test_macro"] = MacroInfo( definition="{% macro test_macro() %}b{% endmacro %}", depends_on=[] @@ -766,8 +1109,10 @@ def test_fingerprint_jinja_macros_global_objs(model: Model, global_obj_key: str) } ) fingerprint = fingerprint_from_node(model, nodes={}) - - model.jinja_macros.global_objs[global_obj_key] = AttributeDict({"test": "test"}) + model = model.copy() + model.jinja_macros.global_objs[global_obj_key] = AttributeDict( + {"test": AttributeDict({"test": "test"})} + ) updated_fingerprint = 
fingerprint_from_node(model, nodes={}) assert updated_fingerprint.data_hash != fingerprint.data_hash assert updated_fingerprint.metadata_hash == fingerprint.metadata_hash @@ -817,6 +1162,82 @@ def test_fingerprint_virtual_properties(model: Model, parent_model: Model): assert updated_fingerprint.data_hash == fingerprint.data_hash +def test_fingerprint_grants(model: Model, parent_model: Model): + from sqlmesh.core.model.meta import GrantsTargetLayer + + original_model = deepcopy(model) + fingerprint = fingerprint_from_node(model, nodes={}) + + updated_model = SqlModel( + **original_model.dict(), + grants={"SELECT": ["role1", "role2"]}, + ) + updated_fingerprint = fingerprint_from_node(updated_model, nodes={}) + + assert updated_fingerprint != fingerprint + assert updated_fingerprint.metadata_hash != fingerprint.metadata_hash + assert updated_fingerprint.data_hash == fingerprint.data_hash + + different_grants_model = SqlModel( + **original_model.dict(), + grants={"SELECT": ["role3"], "INSERT": ["role4"]}, + ) + different_grants_fingerprint = fingerprint_from_node(different_grants_model, nodes={}) + + assert different_grants_fingerprint.metadata_hash != updated_fingerprint.metadata_hash + assert different_grants_fingerprint.metadata_hash != fingerprint.metadata_hash + + target_layer_model = SqlModel( + **{**original_model.dict(), "grants_target_layer": GrantsTargetLayer.PHYSICAL}, + grants={"SELECT": ["role1", "role2"]}, + ) + target_layer_fingerprint = fingerprint_from_node(target_layer_model, nodes={}) + + assert target_layer_fingerprint.metadata_hash != updated_fingerprint.metadata_hash + + +def test_tableinfo_equality(): + snapshot_a = SnapshotTableInfo( + name="test_schema.a", + fingerprint=SnapshotFingerprint(data_hash="1", metadata_hash="1"), + version="test_version", + physical_schema="test_physical_schema", + parents=[], + dev_table_suffix="dev", + ) + + snapshot_b = SnapshotTableInfo( + name="test_schema.b", + fingerprint=SnapshotFingerprint(data_hash="1", 
metadata_hash="1"), + version="test_version", + physical_schema="test_physical_schema", + parents=[], + dev_table_suffix="dev", + ) + + snapshot_c = SnapshotTableInfo( + name="test_schema.c", + fingerprint=SnapshotFingerprint(data_hash="1", metadata_hash="1"), + version="test_version", + physical_schema="test_physical_schema", + parents=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], + dev_table_suffix="dev", + ) + + # parents in different order than snapshot_c + snapshot_c2 = SnapshotTableInfo( + name="test_schema.c", + fingerprint=SnapshotFingerprint(data_hash="1", metadata_hash="1"), + version="test_version", + physical_schema="test_physical_schema", + parents=[snapshot_b.snapshot_id, snapshot_a.snapshot_id], + dev_table_suffix="dev", + ) + + assert snapshot_c is not snapshot_c2 + assert snapshot_c == snapshot_c2 + + def test_stamp(model: Model): original_fingerprint = fingerprint_from_node(model, nodes={}) @@ -826,15 +1247,20 @@ def test_stamp(model: Model): assert original_fingerprint != stamped_fingerprint -def test_table_name(snapshot: Snapshot, make_snapshot: t.Callable): +def test_snapshot_table_name(snapshot: Snapshot, make_snapshot: t.Callable): # Mimic a direct breaking change. snapshot.fingerprint = SnapshotFingerprint( data_hash="1", metadata_hash="1", parent_data_hash="1" ) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + assert snapshot.table_naming_convention == TableNamingConvention.SCHEMA_AND_TABLE + assert snapshot.data_version.table_naming_convention == TableNamingConvention.SCHEMA_AND_TABLE + snapshot.previous_versions = () assert snapshot.table_name(is_deployable=True) == "sqlmesh__default.name__3078928823" - assert snapshot.table_name(is_deployable=False) == "sqlmesh__default.name__3078928823__temp" + assert snapshot.table_name(is_deployable=False) == "sqlmesh__default.name__3078928823__dev" + + assert snapshot.dev_version == snapshot.fingerprint.to_version() # Mimic an indirect non-breaking change. 
previous_data_version = snapshot.data_version @@ -846,16 +1272,18 @@ def test_table_name(snapshot: Snapshot, make_snapshot: t.Callable): snapshot.categorize_as(SnapshotChangeCategory.INDIRECT_NON_BREAKING) assert snapshot.table_name(is_deployable=True) == "sqlmesh__default.name__3078928823" # Indirect non-breaking snapshots reuse the dev table as well. - assert snapshot.table_name(is_deployable=False) == "sqlmesh__default.name__3078928823__temp" + assert snapshot.table_name(is_deployable=False) == "sqlmesh__default.name__3078928823__dev" + assert snapshot.dev_version != snapshot.fingerprint.to_version() + assert snapshot.dev_version == previous_data_version.dev_version # Mimic a direct forward-only change. snapshot.fingerprint = SnapshotFingerprint( data_hash="2", metadata_hash="1", parent_data_hash="1" ) snapshot.previous_versions = (previous_data_version,) - snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) assert snapshot.table_name(is_deployable=True) == "sqlmesh__default.name__3078928823" - assert snapshot.table_name(is_deployable=False) == "sqlmesh__default.name__3049392110__temp" + assert snapshot.table_name(is_deployable=False) == "sqlmesh__default.name__3049392110__dev" fully_qualified_snapshot = make_snapshot( SqlModel(name='"my-catalog".db.table', query=parse_one("select 1, ds")) @@ -877,33 +1305,171 @@ def test_table_name(snapshot: Snapshot, make_snapshot: t.Callable): ) -def test_categorize_change_sql(make_snapshot): - old_snapshot = make_snapshot(SqlModel(name="a", query=parse_one("select 1, ds"))) - - config = CategorizerConfig(sql=AutoCategorizationMode.SEMI) - - # A projection has been added. 
- assert ( - categorize_change( - new=make_snapshot(SqlModel(name="a", query=parse_one("select 1, 2, ds"))), - old=old_snapshot, - config=config, - ) - == SnapshotChangeCategory.NON_BREAKING +def test_table_name_naming_convention_table_only(make_snapshot: t.Callable[..., Snapshot]): + # 3-part naming + snapshot = make_snapshot( + SqlModel(name='"foo"."bar"."baz"', query=parse_one("select 1")), + table_naming_convention=TableNamingConvention.TABLE_ONLY, ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + assert snapshot.table_naming_convention == TableNamingConvention.TABLE_ONLY + assert snapshot.data_version.table_naming_convention == TableNamingConvention.TABLE_ONLY - # A complex projection has been added. + assert snapshot.table_name(is_deployable=True) == f"foo.sqlmesh__bar.baz__{snapshot.version}" assert ( - categorize_change( - new=make_snapshot( - SqlModel( - name="a", query=parse_one("select 1, fun(another_fun(a + 1) * 2)::INT, ds") - ) - ), - old=old_snapshot, - config=config, - ) - == SnapshotChangeCategory.NON_BREAKING + snapshot.table_name(is_deployable=False) == f"foo.sqlmesh__bar.baz__{snapshot.version}__dev" + ) + + # 2-part naming + snapshot = make_snapshot( + SqlModel(name='"foo"."bar"', query=parse_one("select 1")), + table_naming_convention=TableNamingConvention.TABLE_ONLY, + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + assert snapshot.table_name(is_deployable=True) == f"sqlmesh__foo.bar__{snapshot.version}" + assert snapshot.table_name(is_deployable=False) == f"sqlmesh__foo.bar__{snapshot.version}__dev" + + +def test_table_name_naming_convention_hash_md5(make_snapshot: t.Callable[..., Snapshot]): + # 3-part naming + snapshot = make_snapshot( + SqlModel(name='"foo"."bar"."baz"', query=parse_one("select 1")), + table_naming_convention=TableNamingConvention.HASH_MD5, + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + assert snapshot.table_naming_convention == TableNamingConvention.HASH_MD5 + assert 
snapshot.data_version.table_naming_convention == TableNamingConvention.HASH_MD5 + + hash = md5(f"foo.sqlmesh__bar.bar__baz__{snapshot.version}") + assert snapshot.table_name(is_deployable=True) == f"foo.sqlmesh__bar.sqlmesh_md5__{hash}" + hash_dev = md5(f"foo.sqlmesh__bar.bar__baz__{snapshot.version}__dev") + assert ( + snapshot.table_name(is_deployable=False) == f"foo.sqlmesh__bar.sqlmesh_md5__{hash_dev}__dev" + ) + + # 2-part naming + snapshot = make_snapshot( + SqlModel(name='"foo"."bar"', query=parse_one("select 1")), + table_naming_convention=TableNamingConvention.HASH_MD5, + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + hash = md5(f"sqlmesh__foo.foo__bar__{snapshot.version}") + assert snapshot.table_name(is_deployable=True) == f"sqlmesh__foo.sqlmesh_md5__{hash}" + + hash_dev = md5(f"sqlmesh__foo.foo__bar__{snapshot.version}__dev") + assert snapshot.table_name(is_deployable=False) == f"sqlmesh__foo.sqlmesh_md5__{hash_dev}__dev" + + +def test_table_naming_convention_passed_around_correctly(make_snapshot: t.Callable[..., Snapshot]): + snapshot = make_snapshot( + SqlModel(name='"foo"."bar"."baz"', query=parse_one("select 1")), + table_naming_convention=TableNamingConvention.HASH_MD5, + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + assert snapshot.table_naming_convention == TableNamingConvention.HASH_MD5 + assert snapshot.data_version.table_naming_convention == TableNamingConvention.HASH_MD5 + assert snapshot.table_info.table_naming_convention == TableNamingConvention.HASH_MD5 + assert ( + snapshot.table_info.data_version.table_naming_convention == TableNamingConvention.HASH_MD5 + ) + assert snapshot.table_info.table_info.table_naming_convention == TableNamingConvention.HASH_MD5 + assert ( + snapshot.table_info.table_info.data_version.table_naming_convention + == TableNamingConvention.HASH_MD5 + ) + + +def test_table_name_view(make_snapshot: t.Callable): + # Mimic a direct breaking change. 
+ snapshot = make_snapshot(SqlModel(name="name", query=parse_one("select 1"), kind="VIEW")) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.previous_versions = () + assert snapshot.table_name(is_deployable=True) == f"sqlmesh__default.name__{snapshot.version}" + assert ( + snapshot.table_name(is_deployable=False) + == f"sqlmesh__default.name__{snapshot.dev_version}__dev" + ) + + assert snapshot.dev_version == snapshot.fingerprint.to_version() + + # Mimic an indirect non-breaking change. + new_snapshot = make_snapshot(SqlModel(name="name", query=parse_one("select 2"), kind="VIEW")) + previous_data_version = snapshot.data_version + new_snapshot.previous_versions = (previous_data_version,) + new_snapshot.categorize_as(SnapshotChangeCategory.INDIRECT_NON_BREAKING) + assert ( + new_snapshot.table_name(is_deployable=True) == f"sqlmesh__default.name__{snapshot.version}" + ) + # Indirect non-breaking view snapshots should not reuse the dev table. + assert ( + new_snapshot.table_name(is_deployable=False) + == f"sqlmesh__default.name__{new_snapshot.dev_version}__dev" + ) + assert new_snapshot.dev_version == new_snapshot.fingerprint.to_version() + assert new_snapshot.version == snapshot.version + assert new_snapshot.dev_version != snapshot.dev_version + + +def test_table_naming_convention_change_reuse_previous_version(make_snapshot): + # Ensure that snapshots that trigger "reuse previous version" inherit the naming convention of the previous snapshot + original_snapshot: Snapshot = make_snapshot( + SqlModel(name="a", query=parse_one("select 1, ds")), + table_naming_convention=TableNamingConvention.SCHEMA_AND_TABLE, + ) + original_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + assert original_snapshot.table_naming_convention == TableNamingConvention.SCHEMA_AND_TABLE + assert original_snapshot.table_name() == f"sqlmesh__default.a__{original_snapshot.version}" + + changed_snapshot: Snapshot = make_snapshot( + SqlModel(name="a", 
query=parse_one("select 1, 'forward_only' as a, ds")), + table_naming_convention=TableNamingConvention.HASH_MD5, + ) + changed_snapshot.previous_versions = original_snapshot.all_versions + + assert changed_snapshot.previous_version == original_snapshot.data_version + + changed_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + + # inherited from previous version even though changed_snapshot was created with TableNamingConvention.HASH_MD5 + assert changed_snapshot.table_naming_convention == TableNamingConvention.SCHEMA_AND_TABLE + assert ( + changed_snapshot.previous_version.table_naming_convention + == TableNamingConvention.SCHEMA_AND_TABLE + ) + assert changed_snapshot.table_name() == f"sqlmesh__default.a__{changed_snapshot.version}" + + +def test_categorize_change_sql(make_snapshot): + old_snapshot = make_snapshot(SqlModel(name="a", query=parse_one("select 1, ds"))) + + config = CategorizerConfig(sql=AutoCategorizationMode.SEMI) + + # A projection has been added. + assert ( + categorize_change( + new=make_snapshot(SqlModel(name="a", query=parse_one("select 1, 2, ds"))), + old=old_snapshot, + config=config, + ) + == SnapshotChangeCategory.NON_BREAKING + ) + + # A complex projection has been added. + assert ( + categorize_change( + new=make_snapshot( + SqlModel( + name="a", query=parse_one("select 1, fun(another_fun(a + 1) * 2)::INT, ds") + ) + ), + old=old_snapshot, + config=config, + ) + == SnapshotChangeCategory.NON_BREAKING ) # Multiple projections have been added. 
@@ -1271,7 +1837,7 @@ def test_physical_schema(snapshot: Snapshot): new_snapshot.previous_versions = (snapshot.data_version,) new_snapshot.physical_schema_ = None new_snapshot.version = None - new_snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) assert new_snapshot.physical_schema == "custom_schema" assert new_snapshot.data_version.physical_schema == "custom_schema" @@ -1281,10 +1847,10 @@ def test_physical_schema(snapshot: Snapshot): def test_has_paused_forward_only(snapshot: Snapshot): assert not has_paused_forward_only([snapshot], [snapshot]) - snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) assert has_paused_forward_only([snapshot], [snapshot]) - snapshot.set_unpaused_ts("2023-01-01") + snapshot.unpaused_ts = to_timestamp("2023-01-01") assert not has_paused_forward_only([snapshot], [snapshot]) @@ -1303,12 +1869,12 @@ def test_inclusive_exclusive_monthly(make_snapshot): assert snapshot.inclusive_exclusive("2023-01-01", "2023-07-01") == ( to_timestamp("2023-01-01"), - to_timestamp("2023-07-01"), + to_timestamp("2023-08-01"), ) assert snapshot.inclusive_exclusive("2023-01-01", "2023-07-06") == ( to_timestamp("2023-01-01"), - to_timestamp("2023-07-01"), + to_timestamp("2023-08-01"), ) assert snapshot.inclusive_exclusive("2023-01-01", "2023-07-31") == ( @@ -1484,13 +2050,46 @@ def test_is_valid_start(make_snapshot): EnvironmentNamingInfo(name="dev", catalog_name_override="g-h"), '"g-h".default__dev."e-f"', ), - (QualifiedViewName(table="e-f"), EnvironmentNamingInfo(name="dev"), 'default__dev."e-f"'), + ( + QualifiedViewName(table="e-f"), + EnvironmentNamingInfo(name="dev"), + 'default__dev."e-f"', + ), + # EnvironmentSuffixTarget.CATALOG + ( + QualifiedViewName( + catalog="default-foo", schema_name="sqlmesh_example", table="full_model" + ), + EnvironmentNamingInfo( + name="dev", + 
suffix_target=EnvironmentSuffixTarget.CATALOG, + ), + '"default-foo__dev".sqlmesh_example.full_model', + ), + ( + QualifiedViewName(catalog="default", schema_name="sqlmesh_example", table="full_model"), + EnvironmentNamingInfo( + name=c.PROD, + catalog_name_override=None, + suffix_target=EnvironmentSuffixTarget.CATALOG, + ), + "default.sqlmesh_example.full_model", + ), ), ) def test_qualified_view_name(qualified_view_name, environment_naming_info, expected): assert qualified_view_name.for_environment(environment_naming_info) == expected +def test_qualified_view_name_with_dialect(): + qualified_view_name = QualifiedViewName(catalog="catalog", schema_name="db", table="table") + environment_naming_info = EnvironmentNamingInfo(name="dev", catalog_name_override="override") + assert ( + qualified_view_name.for_environment(environment_naming_info, dialect="snowflake") + == "OVERRIDE.db__DEV.table" + ) + + def test_multi_interval_merge(make_snapshot): a = make_snapshot( SqlModel( @@ -1542,7 +2141,7 @@ def test_deployability_index(make_snapshot): snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("SELECT 1"))) - snapshot_b.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) snapshot_b.parents = (snapshot_a.snapshot_id,) snapshot_c = make_snapshot(SqlModel(name="c", query=parse_one("SELECT 1"))) @@ -1561,6 +2160,7 @@ def test_deployability_index(make_snapshot): snapshot_f.parents = (snapshot_e.snapshot_id, snapshot_a.snapshot_id) snapshot_g = make_snapshot(SqlModel(name="g", query=parse_one("SELECT 1"))) + snapshot_g.intervals = [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] snapshot_g.categorize_as(SnapshotChangeCategory.INDIRECT_NON_BREAKING) snapshot_g.parents = (snapshot_e.snapshot_id,) @@ -1606,7 +2206,7 @@ def test_deployability_index(make_snapshot): def 
test_deployability_index_unpaused_forward_only(make_snapshot): snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("SELECT 1"))) - snapshot_a.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) snapshot_a.unpaused_ts = 1 snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("SELECT 1"))) @@ -1624,6 +2224,43 @@ def test_deployability_index_unpaused_forward_only(make_snapshot): assert deplyability_index.is_representative(snapshot_b) +def test_deployability_index_unpaused_auto_restatement(make_snapshot): + model_a = SqlModel( + name="a", + query=parse_one("SELECT 1, ds"), + kind=IncrementalByTimeRangeKind( + time_column="ds", forward_only=True, auto_restatement_cron="@weekly" + ), + ) + snapshot_a = make_snapshot(model_a) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + snapshot_a.unpaused_ts = 1 + + # Snapshot B is a child of a model with auto restatement and is not paused, + # so it is not deployable but is representative + snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("SELECT 1"))) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_b.parents = (snapshot_a.snapshot_id,) + snapshot_b.unpaused_ts = 1 + + # Snapshot C is paused and hence is neither deployable nor representative + snapshot_c = make_snapshot(SqlModel(name="c", query=parse_one("SELECT 1"))) + snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_c.parents = (snapshot_b.snapshot_id,) + + deplyability_index = DeployabilityIndex.create( + {s.snapshot_id: s for s in [snapshot_a, snapshot_b, snapshot_c]} + ) + + assert not deplyability_index.is_deployable(snapshot_a) + assert not deplyability_index.is_deployable(snapshot_b) + assert not deplyability_index.is_deployable(snapshot_c) + + assert deplyability_index.is_representative(snapshot_a) + assert deplyability_index.is_representative(snapshot_b) + assert not 
deplyability_index.is_representative(snapshot_c) + + def test_deployability_index_uncategorized_forward_only_model(make_snapshot): model_a = SqlModel( name="a", @@ -1651,6 +2288,35 @@ def test_deployability_index_uncategorized_forward_only_model(make_snapshot): assert not deployability_index.is_representative(snapshot_b) +def test_deployability_index_auto_restatement_model(make_snapshot): + model_a = SqlModel( + name="a", + query=parse_one("SELECT 1, ds"), + kind=IncrementalByTimeRangeKind( + time_column="ds", forward_only=False, auto_restatement_cron="@weekly" + ), + ) + + snapshot_a_old = make_snapshot(model_a) + snapshot_a_old.categorize_as(SnapshotChangeCategory.BREAKING) + + snapshot_a = make_snapshot(model_a) + snapshot_a.previous_versions = snapshot_a_old.all_versions + + snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("SELECT 1"))) + snapshot_b.parents = (snapshot_a.snapshot_id,) + + deployability_index = DeployabilityIndex.create( + {s.snapshot_id: s for s in [snapshot_a, snapshot_b]} + ) + + assert not deployability_index.is_deployable(snapshot_a) + assert not deployability_index.is_deployable(snapshot_b) + + assert not deployability_index.is_representative(snapshot_a) + assert not deployability_index.is_representative(snapshot_b) + + def test_deployability_index_categorized_forward_only_model(make_snapshot): model_a = SqlModel( name="a", @@ -1663,23 +2329,21 @@ def test_deployability_index_categorized_forward_only_model(make_snapshot): snapshot_a = make_snapshot(model_a) snapshot_a.previous_versions = snapshot_a_old.all_versions - snapshot_a.categorize_as(SnapshotChangeCategory.METADATA) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("SELECT 1"))) snapshot_b.parents = (snapshot_a.snapshot_id,) - snapshot_b.categorize_as(SnapshotChangeCategory.METADATA) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) - # The fact that 
the model is forward only should be ignored if an actual category - # has been assigned. deployability_index = DeployabilityIndex.create( {s.snapshot_id: s for s in [snapshot_a, snapshot_b]} ) - assert deployability_index.is_deployable(snapshot_a) - assert deployability_index.is_deployable(snapshot_b) + assert not deployability_index.is_deployable(snapshot_a) + assert not deployability_index.is_deployable(snapshot_b) - assert deployability_index.is_representative(snapshot_a) - assert deployability_index.is_representative(snapshot_b) + assert not deployability_index.is_representative(snapshot_a) + assert not deployability_index.is_representative(snapshot_b) def test_deployability_index_missing_parent(make_snapshot): @@ -1687,7 +2351,7 @@ def test_deployability_index_missing_parent(make_snapshot): snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) snapshot_b = make_snapshot(SqlModel(name="b", query=parse_one("SELECT 1"))) - snapshot_b.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) snapshot_b.parents = (snapshot_a.snapshot_id,) deplyability_index = DeployabilityIndex.create({snapshot_b.snapshot_id: snapshot_b}) @@ -1697,30 +2361,205 @@ def test_deployability_index_missing_parent(make_snapshot): @pytest.mark.parametrize( - "model_name, environment_naming_info, default_catalog, expected", + "call_kwargs, expected", + [ + ######################################## + # TableNamingConvention.SCHEMA_AND_TABLE + ( + dict(physical_schema="sqlmesh__foo", name="bar", version="1234"), + "sqlmesh__foo.bar__1234", + ), + ( + dict(physical_schema="sqlmesh__foo", name="foo.bar", version="1234"), + "sqlmesh__foo.foo__bar__1234", + ), + ( + dict(physical_schema="sqlmesh__foo", name="bar", version="1234", catalog="foo"), + "foo.sqlmesh__foo.bar__1234", + ), + ( + dict(physical_schema="sqlmesh__foo", name="bar.baz", version="1234", catalog="foo"), + "foo.sqlmesh__foo.bar__baz__1234", + ), + ( + 
dict(physical_schema="sqlmesh__foo", name="bar.baz", version="1234", suffix="dev"), + "sqlmesh__foo.bar__baz__1234__dev", + ), + ( + dict( + physical_schema="sqlmesh__foo", + name="bar.baz", + version="1234", + catalog="foo", + suffix="dev", + ), + "foo.sqlmesh__foo.bar__baz__1234__dev", + ), + ################################## + # TableNamingConvention.TABLE_ONLY + ( + dict( + physical_schema="sqlmesh__foo", + name="bar", + version="1234", + naming_convention=TableNamingConvention.TABLE_ONLY, + ), + "sqlmesh__foo.bar__1234", + ), + ( + dict( + physical_schema="sqlmesh__foo", + name="foo.bar", + version="1234", + naming_convention=TableNamingConvention.TABLE_ONLY, + ), + "sqlmesh__foo.bar__1234", + ), + ( + dict( + physical_schema="sqlmesh__foo", + name="bar", + version="1234", + catalog="foo", + naming_convention=TableNamingConvention.TABLE_ONLY, + ), + "foo.sqlmesh__foo.bar__1234", + ), + ( + dict( + physical_schema="sqlmesh__bar", + name="bar.baz", + version="1234", + catalog="foo", + naming_convention=TableNamingConvention.TABLE_ONLY, + ), + "foo.sqlmesh__bar.baz__1234", + ), + ( + dict( + physical_schema="sqlmesh__bar", + name="bar.baz", + version="1234", + suffix="dev", + naming_convention=TableNamingConvention.TABLE_ONLY, + ), + "sqlmesh__bar.baz__1234__dev", + ), + ( + dict( + physical_schema="sqlmesh__bar", + name="bar.baz", + version="1234", + catalog="foo", + suffix="dev", + naming_convention=TableNamingConvention.TABLE_ONLY, + ), + "foo.sqlmesh__bar.baz__1234__dev", + ), + ################################# + # TableNamingConvention.HASH_MD5 + ( + dict( + physical_schema="sqlmesh__foo", + name="bar", + version="1234", + naming_convention=TableNamingConvention.HASH_MD5, + ), + f"sqlmesh__foo.sqlmesh_md5__{md5('sqlmesh__foo.bar__1234')}", + ), + ( + dict( + physical_schema="sqlmesh__foo", + name="foo.bar", + version="1234", + naming_convention=TableNamingConvention.HASH_MD5, + ), + f"sqlmesh__foo.sqlmesh_md5__{md5('sqlmesh__foo.foo__bar__1234')}", + ), + 
( + dict( + physical_schema="sqlmesh__foo", + name="bar", + version="1234", + catalog="foo", + naming_convention=TableNamingConvention.HASH_MD5, + ), + f"foo.sqlmesh__foo.sqlmesh_md5__{md5('foo.sqlmesh__foo.bar__1234')}", + ), + ( + dict( + physical_schema="sqlmesh__foo", + name="bar.baz", + version="1234", + catalog="foo", + naming_convention=TableNamingConvention.HASH_MD5, + ), + f"foo.sqlmesh__foo.sqlmesh_md5__{md5('foo.sqlmesh__foo.bar__baz__1234')}", + ), + ( + dict( + physical_schema="sqlmesh__foo", + name="bar.baz", + version="1234", + suffix="dev", + naming_convention=TableNamingConvention.HASH_MD5, + ), + f"sqlmesh__foo.sqlmesh_md5__{md5('sqlmesh__foo.bar__baz__1234__dev')}__dev", + ), + ( + dict( + physical_schema="sqlmesh__foo", + name="bar.baz", + version="1234", + catalog="foo", + suffix="dev", + naming_convention=TableNamingConvention.HASH_MD5, + ), + f"foo.sqlmesh__foo.sqlmesh_md5__{md5('foo.sqlmesh__foo.bar__baz__1234__dev')}__dev", + ), + ], +) +def test_table_name(call_kwargs: t.Dict[str, t.Any], expected: str): + """ + physical_schema: str + name: str + version: str + catalog: t.Optional[str] + suffix: t.Optional[str] + naming_convention: t.Optional[TableNamingConvention] + """ + assert table_name(**call_kwargs) == expected + + +@pytest.mark.parametrize( + "model_name, environment_naming_info, default_catalog, dialect, expected", ( ( "test_db.test_model", EnvironmentNamingInfo(), None, + "duckdb", "test_db.test_model", ), ( "test_db.test_model", EnvironmentNamingInfo(name="dev"), None, + "duckdb", "test_db__dev.test_model", ), ( "test_db.test_model", EnvironmentNamingInfo(name="dev", suffix_target=EnvironmentSuffixTarget.SCHEMA), None, + "duckdb", "test_db__dev.test_model", ), ( "test_db.test_model", EnvironmentNamingInfo(name="dev", suffix_target=EnvironmentSuffixTarget.TABLE), None, + "duckdb", "test_db.test_model__dev", ), ( @@ -1731,12 +2570,14 @@ def test_deployability_index_missing_parent(make_snapshot): 
catalog_name_override="catalog_override", ), None, + "duckdb", "catalog_override.test_db.test_model__dev", ), ( "original_catalog.test_db.test_model", EnvironmentNamingInfo(name="dev", suffix_target=EnvironmentSuffixTarget.TABLE), "default_catalog", + "duckdb", "original_catalog.test_db.test_model__dev", ), ( @@ -1747,6 +2588,7 @@ def test_deployability_index_missing_parent(make_snapshot): catalog_name_override="catalog_override", ), "default_catalog", + "duckdb", "catalog_override.test_db.test_model__dev", ), ( @@ -1757,18 +2599,46 @@ def test_deployability_index_missing_parent(make_snapshot): catalog_name_override="catalog_override", ), "default_catalog", + "duckdb", "catalog_override.test_db.test_model__dev", ), ( "test_db.test_model", EnvironmentNamingInfo(name="dev", suffix_target=EnvironmentSuffixTarget.TABLE), "default_catalog", + "duckdb", "test_db.test_model__dev", ), + ( + "test_db.test_model", + EnvironmentNamingInfo( + name="dev", + suffix_target=EnvironmentSuffixTarget.TABLE, + catalog_name_override="catalog_override", + ), + "default_catalog", + "snowflake", + "CATALOG_OVERRIDE.test_db.test_model__DEV", + ), + # EnvironmentSuffixTarget.CATALOG + ( + "test_db.test_model", + EnvironmentNamingInfo(name="dev", suffix_target=EnvironmentSuffixTarget.CATALOG), + "default_catalog", + "duckdb", + "default_catalog__dev.test_db.test_model", + ), + ( + "test_db.test_model", + EnvironmentNamingInfo(name="dev", suffix_target=EnvironmentSuffixTarget.CATALOG), + "default_catalog", + "snowflake", + "DEFAULT_CATALOG__DEV.test_db.test_model", + ), ), ) def test_display_name( - make_snapshot, model_name, environment_naming_info, default_catalog, expected + make_snapshot, model_name, environment_naming_info, default_catalog, dialect, expected ): input_model = SqlModel( name=model_name, @@ -1777,7 +2647,10 @@ def test_display_name( default_catalog=default_catalog, ) input_snapshot = make_snapshot(input_model) - assert display_name(input_snapshot, environment_naming_info, 
default_catalog) == expected + assert ( + display_name(input_snapshot, environment_naming_info, default_catalog, dialect=dialect) + == expected + ) def test_missing_intervals_node_start_end(make_snapshot): @@ -1815,7 +2688,8 @@ def test_missing_intervals_node_start_end(make_snapshot): def test_external_model_audits(sushi_context): snapshot = sushi_context.get_snapshot("raw.demographics") assert snapshot.evaluatable - assert len(snapshot.model.audits) == 2 + assert len(snapshot.model.audits) == 3 + assert len(snapshot.model.audit_definitions) == 1 assert snapshot.intervals @@ -1849,3 +2723,937 @@ class MyCustomStrategy(CustomMaterialization): parsed_table_info = SnapshotTableInfo.parse_raw(table_info.json()) assert parsed_table_info.custom_materialization == "MyCustomStrategy" + + +def test_ttl_ms(make_snapshot): + snapshot = make_snapshot( + SqlModel( + name="test_model_name", + query=parse_one("SELECT 1"), + ), + ttl="in 1 week", + ) + assert snapshot.ttl_ms == 604800000 + + +def test_snapshot_cache(make_snapshot, tmp_path): + cache_path = tmp_path / "snapshot_cache" + cache = SnapshotCache(cache_path) + + snapshot = make_snapshot( + SqlModel( + name="test_model_name", + query=parse_one("SELECT 1"), + ) + ) + snapshot.add_interval("2024-01-01", "2024-01-02") + snapshot.add_interval("2024-01-01", "2024-01-02", is_dev=True) + + loader_called_times = 0 + + def _loader(snapshot_ids: t.Set[SnapshotId]) -> t.Collection[Snapshot]: + nonlocal loader_called_times + loader_called_times += 1 + return [snapshot] + + assert cache.get_or_load([snapshot.snapshot_id], _loader) == ( + {snapshot.snapshot_id: snapshot}, + set(), + ) + assert cache.get_or_load([snapshot.snapshot_id], _loader) == ( + {snapshot.snapshot_id: snapshot}, + {snapshot.snapshot_id}, + ) + assert cache.get_or_load([snapshot.snapshot_id], _loader) == ( + {snapshot.snapshot_id: snapshot}, + {snapshot.snapshot_id}, + ) + assert loader_called_times == 1 + + cached_snapshot = 
cache.get_or_load([snapshot.snapshot_id], _loader)[0][snapshot.snapshot_id] + assert cached_snapshot.model._query_renderer._optimized_cache is not None + assert cached_snapshot.model._data_hash is not None + assert cached_snapshot.model._metadata_hash is not None + assert not cached_snapshot.intervals + assert not cached_snapshot.dev_intervals + + cache.clear() + assert cache.get_or_load([snapshot.snapshot_id], _loader) == ( + {snapshot.snapshot_id: snapshot}, + set(), + ) + assert loader_called_times == 2 + + +def test_snapshot_pickle_intervals(make_snapshot): + snapshot = make_snapshot( + SqlModel( + name="test_model_name", + query=parse_one("SELECT 1"), + ) + ) + snapshot.add_interval("2023-01-01", "2023-01-02") + snapshot.add_interval("2023-01-01", "2023-01-02", is_dev=True) + + assert len(snapshot.intervals) > 0 + assert len(snapshot.dev_intervals) > 0 + + loaded_snapshot = pickle.loads(pickle.dumps(snapshot)) + assert not loaded_snapshot.intervals + assert not loaded_snapshot.dev_intervals + assert len(snapshot.intervals) > 0 + assert len(snapshot.dev_intervals) > 0 + + +def test_missing_intervals_end_override_per_model(make_snapshot): + snapshot_a = make_snapshot( + SqlModel( + name="a", + start="2023-01-04", + query=parse_one("SELECT 1"), + ), + version="a", + ) + + snapshot_b = make_snapshot( + SqlModel( + name="b", + start="2023-01-04", + query=parse_one("SELECT 2"), + ), + version="b", + ) + + assert missing_intervals( + [snapshot_a, snapshot_b], + start="2023-01-04", + end="2023-01-10", + end_override_per_model={ + snapshot_a.name: to_datetime("2023-01-09"), + snapshot_b.name: to_datetime("2023-01-06"), + }, + ) == { + snapshot_a: [ + (to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + (to_timestamp("2023-01-06"), to_timestamp("2023-01-07")), + (to_timestamp("2023-01-07"), to_timestamp("2023-01-08")), + (to_timestamp("2023-01-08"), to_timestamp("2023-01-09")), + ], + snapshot_b: [ + 
(to_timestamp("2023-01-04"), to_timestamp("2023-01-05")), + (to_timestamp("2023-01-05"), to_timestamp("2023-01-06")), + ], + } + + assert missing_intervals( + [snapshot_a, snapshot_b], + start="2023-01-08", + end="2023-01-08", + end_override_per_model={ + snapshot_a.name: to_datetime("2023-01-09"), + snapshot_b.name: to_datetime( + "2023-01-06" + ), # The interval end is before the start. The snapshot will be skipped + }, + ) == { + snapshot_a: [(to_timestamp("2023-01-08"), to_timestamp("2023-01-09"))], + } + + +def test_physical_version_pin(make_snapshot): + snapshot = make_snapshot( + SqlModel( + name="a", + kind=dict( + time_column="ds", name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, forward_only=True + ), + query=parse_one("SELECT 1, ds"), + physical_version="1234", + ), + ) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + assert snapshot.version == "1234" + + +def test_physical_version_pin_for_new_forward_only_models(make_snapshot): + # A new forward-only model. + snapshot_a = make_snapshot( + SqlModel( + name="a", + kind=dict( + time_column="ds", name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, forward_only=True + ), + query=parse_one("SELECT 1, ds"), + ), + ) + snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) + + # Another version of the new forward-only model created independently. + snapshot_b = make_snapshot( + SqlModel( + name="a", + kind=dict( + time_column="ds", name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, forward_only=True + ), + query=parse_one("SELECT 2, ds"), + ), + ) + snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) + + assert snapshot_a.fingerprint != snapshot_b.fingerprint + assert snapshot_a.version == snapshot_b.version + + # A change to the forward-only model. 
+ snapshot_c = make_snapshot( + SqlModel( + name="a", + kind=dict( + time_column="ds", name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, forward_only=True + ), + query=parse_one("SELECT 3, ds"), + ), + ) + snapshot_c.previous_versions = snapshot_b.all_versions + snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + + assert snapshot_b.fingerprint != snapshot_c.fingerprint + assert snapshot_b.version == snapshot_c.version + + # Make model non-forward-only. + snapshot_d = make_snapshot( + SqlModel( + name="a", + kind=dict(time_column="ds", name=ModelKindName.INCREMENTAL_BY_TIME_RANGE), + query=parse_one("SELECT 4, ds"), + ), + ) + snapshot_d.previous_versions = snapshot_c.all_versions + snapshot_d.categorize_as(SnapshotChangeCategory.BREAKING) + + assert snapshot_c.fingerprint != snapshot_d.fingerprint + assert snapshot_c.version != snapshot_d.version + + # Make it forward-only again. + snapshot_e = make_snapshot( + SqlModel( + name="a", + kind=dict( + time_column="ds", name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, forward_only=True + ), + query=parse_one("SELECT 5, ds"), + ), + ) + snapshot_e.previous_versions = snapshot_d.all_versions + snapshot_e.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + + assert snapshot_d.fingerprint != snapshot_e.fingerprint + assert snapshot_d.version == snapshot_e.version + + # Pin the version explicitly. 
+ snapshot_f = make_snapshot( + SqlModel( + name="a", + kind=dict( + time_column="ds", name=ModelKindName.INCREMENTAL_BY_TIME_RANGE, forward_only=True + ), + query=parse_one("SELECT 5, ds"), + physical_version="1234", + ), + ) + snapshot_f.previous_versions = snapshot_e.all_versions + snapshot_f.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + + assert snapshot_f.version == "1234" + assert snapshot_f.fingerprint != snapshot_e.fingerprint + + +def test_contiguous_intervals(): + assert _contiguous_intervals([]) == [] + assert _contiguous_intervals([(0, 1)]) == [[(0, 1)]] + assert _contiguous_intervals([(0, 1), (1, 2), (2, 3)]) == [[(0, 1), (1, 2), (2, 3)]] + assert _contiguous_intervals([(0, 1), (3, 4), (4, 5), (6, 7)]) == [ + [(0, 1)], + [(3, 4), (4, 5)], + [(6, 7)], + ] + + +def test_check_ready_intervals(mocker: MockerFixture): + def assert_always_signal(intervals): + assert ( + check_ready_intervals(lambda _: True, intervals, mocker.Mock(), mocker.Mock()) + == intervals + ) + + assert_always_signal([]) + assert_always_signal([(0, 1)]) + assert_always_signal([(0, 1), (1, 2)]) + assert_always_signal([(0, 1), (2, 3)]) + + def assert_never_signal(intervals): + assert check_ready_intervals(lambda _: False, intervals, mocker.Mock(), mocker.Mock()) == [] + + assert_never_signal([]) + assert_never_signal([(0, 1)]) + assert_never_signal([(0, 1), (1, 2)]) + assert_never_signal([(0, 1), (2, 3)]) + + def assert_empty_signal(intervals): + assert check_ready_intervals(lambda _: [], intervals, mocker.Mock(), mocker.Mock()) == [] + + assert_empty_signal([]) + assert_empty_signal([(0, 1)]) + assert_empty_signal([(0, 1), (1, 2)]) + assert_empty_signal([(0, 1), (2, 3)]) + + def to_intervals(values: t.List[t.Tuple[int, int]]) -> DatetimeRanges: + return [(to_datetime(s), to_datetime(e)) for s, e in values] + + def assert_check_intervals( + intervals: t.List[t.Tuple[int, int]], + ready: t.List[t.List[t.Tuple[int, int]]], + expected: t.List[t.Tuple[int, int]], + ): 
+ mock = mocker.Mock() + mock.side_effect = [to_intervals(r) for r in ready] + check_ready_intervals(mock, intervals, mocker.Mock(), mocker.Mock()) == expected + + assert_check_intervals([], [], []) + assert_check_intervals([(0, 1)], [[]], []) + assert_check_intervals( + [(0, 1)], + [[(0, 1)]], + [(0, 1)], + ) + assert_check_intervals( + [(0, 1), (1, 2)], + [[(0, 1)]], + [(0, 1)], + ) + assert_check_intervals( + [(0, 1), (1, 2)], + [[(1, 2)]], + [(1, 2)], + ) + assert_check_intervals( + [(0, 1), (1, 2)], + [[(0, 1), (1, 2)]], + [(0, 1), (1, 2)], + ) + assert_check_intervals( + [(0, 1), (1, 2), (3, 4)], + [[], []], + [], + ) + assert_check_intervals( + [(0, 1), (1, 2), (3, 4)], + [[(0, 1)], []], + [(0, 1)], + ) + assert_check_intervals( + [(0, 1), (1, 2), (3, 4)], + [[(0, 1)], [(3, 4)]], + [(0, 1), (3, 4)], + ) + + with pytest.raises(SignalEvalError): + check_ready_intervals( + lambda _: (_ for _ in ()).throw(MemoryError("Some exception")), + [(0, 1), (1, 2)], + mocker.Mock(), + mocker.Mock(), + ) + + +@pytest.mark.parametrize( + "auto_restatement_intervals,expected_auto_restatement_start", + [ + (None, "2020-01-01"), + (2, "2020-01-04"), + ], +) +def test_get_next_auto_restatement_interval( + make_snapshot, + auto_restatement_intervals: t.Optional[int], + expected_auto_restatement_start: str, +): + snapshot = make_snapshot( + SqlModel( + name="test_model", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + auto_restatement_cron="0 10 * * *", + auto_restatement_intervals=auto_restatement_intervals, + ), + cron="@daily", + query=parse_one("SELECT 1, ds FROM name"), + ) + ) + snapshot.add_interval("2020-01-01", "2020-01-05") + snapshot.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00") + + assert snapshot.get_next_auto_restatement_interval(to_timestamp("2020-01-06 09:59:00")) is None + + assert snapshot.get_next_auto_restatement_interval(to_timestamp("2020-01-06 10:01:00")) == ( + to_timestamp(expected_auto_restatement_start), + 
to_timestamp("2020-01-06"), + ) + + +def test_apply_auto_restatements(make_snapshot): + # Hourly upstream model with auto restatement intervals set to 24 + model_a = SqlModel( + name="test_model_a", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + auto_restatement_cron="0 10 * * *", + auto_restatement_intervals=24, + ), + cron="@hourly", + query=parse_one("SELECT 1, ds FROM name"), + ) + snapshot_a = make_snapshot(model_a, version="1") + snapshot_a.add_interval("2020-01-01", "2020-01-06 09:00:00") + snapshot_a.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00") + + # Daily downstream model with no auto restatement + model_b = SqlModel( + name="test_model_b", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + ), + cron="@daily", + query=parse_one("SELECT ds FROM test_model_a"), + ) + snapshot_b = make_snapshot(model_b, nodes={model_a.fqn: model_a}, version="2") + snapshot_b.add_interval("2020-01-01", "2020-01-05") + assert snapshot_a.snapshot_id in snapshot_b.parents + + # Daily downstream model with auto restatement intervals of 2 + model_c = SqlModel( + name="test_model_c", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + auto_restatement_cron="0 10 * * *", + auto_restatement_intervals=2, + ), + cron="@daily", + query=parse_one("SELECT ds FROM test_model_a"), + ) + snapshot_c = make_snapshot(model_c, nodes={model_a.fqn: model_a}, version="3") + snapshot_c.add_interval("2020-01-01", "2020-01-05") + snapshot_c.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00") + assert snapshot_a.snapshot_id in snapshot_c.parents + + # Hourly downstream model with auto restatement intervals of 5 + model_d = SqlModel( + name="test_model_d", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + auto_restatement_cron="0 10 * * *", + auto_restatement_intervals=5, + ), + cron="@hourly", + query=parse_one("SELECT ds FROM test_model_a"), + ) + snapshot_d = 
make_snapshot(model_d, nodes={model_a.fqn: model_a}, version="4") + snapshot_d.add_interval("2020-01-01", "2020-01-06 09:00:00") + snapshot_d.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00") + assert snapshot_a.snapshot_id in snapshot_d.parents + + # Hourly upstream model with auto restatement intervals set to 5 + model_e = SqlModel( + name="test_model_e", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + auto_restatement_cron="0 10 * * *", + auto_restatement_intervals=5, + ), + cron="@hourly", + query=parse_one("SELECT 1, ds FROM name"), + ) + snapshot_e = make_snapshot(model_e, version="5") + snapshot_e.add_interval("2020-01-01", "2020-01-06 09:00:00") + snapshot_e.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00") + + # Daily model downstream from model_e without auto restatement that should not be impacted by auto restatement + # upstream. + model_f = SqlModel( + name="test_model_f", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + ), + cron="@daily", + query=parse_one("SELECT ds FROM test_model_e"), + ) + snapshot_f = make_snapshot(model_f, nodes={model_e.fqn: model_e}, version="6") + snapshot_f.add_interval("2020-01-01", "2020-01-05") + assert snapshot_e.snapshot_id in snapshot_f.parents + + assert snapshot_a.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06 09:00:00")), + ] + assert snapshot_b.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), + ] + assert snapshot_c.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), + ] + assert snapshot_d.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06 09:00:00")), + ] + assert snapshot_e.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06 09:00:00")), + ] + assert snapshot_f.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), + ] + + restated_intervals, _ = apply_auto_restatements( + { + snapshot_a.snapshot_id: 
snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + snapshot_c.snapshot_id: snapshot_c, + snapshot_d.snapshot_id: snapshot_d, + snapshot_e.snapshot_id: snapshot_e, + snapshot_f.snapshot_id: snapshot_f, + }, + "2020-01-06 10:01:00", + ) + assert sorted(restated_intervals, key=lambda x: x.name) == [ + SnapshotIntervals( + name=snapshot_a.name, + identifier=None, + version=snapshot_a.version, + dev_version=None, + intervals=[], + dev_intervals=[], + pending_restatement_intervals=[ + (to_timestamp("2020-01-05 10:00:00"), to_timestamp("2020-01-06 10:00:00")) + ], + ), + SnapshotIntervals( + name=snapshot_b.name, + identifier=None, + version=snapshot_b.version, + dev_version=None, + intervals=[], + dev_intervals=[], + pending_restatement_intervals=[ + (to_timestamp("2020-01-05"), to_timestamp("2020-01-07")) + ], + ), + SnapshotIntervals( + name=snapshot_c.name, + identifier=None, + version=snapshot_c.version, + dev_version=None, + intervals=[], + dev_intervals=[], + pending_restatement_intervals=[ + (to_timestamp("2020-01-04"), to_timestamp("2020-01-07")) + ], + ), + SnapshotIntervals( + name=snapshot_d.name, + identifier=None, + version=snapshot_d.version, + dev_version=None, + intervals=[], + dev_intervals=[], + pending_restatement_intervals=[ + (to_timestamp("2020-01-05 10:00:00"), to_timestamp("2020-01-06 10:00:00")) + ], + ), + SnapshotIntervals( + name=snapshot_e.name, + identifier=None, + version=snapshot_e.version, + dev_version=None, + intervals=[], + dev_intervals=[], + pending_restatement_intervals=[ + (to_timestamp("2020-01-06 05:00:00"), to_timestamp("2020-01-06 10:00:00")) + ], + ), + SnapshotIntervals( + name=snapshot_f.name, + identifier=None, + version=snapshot_f.version, + dev_version=None, + intervals=[], + dev_intervals=[], + pending_restatement_intervals=[ + (to_timestamp("2020-01-06"), to_timestamp("2020-01-07")) + ], + ), + ] + + assert snapshot_a.next_auto_restatement_ts == to_timestamp("2020-01-07 10:00:00") + assert 
snapshot_b.next_auto_restatement_ts is None + assert snapshot_c.next_auto_restatement_ts == to_timestamp("2020-01-07 10:00:00") + assert snapshot_d.next_auto_restatement_ts == to_timestamp("2020-01-07 10:00:00") + assert snapshot_e.next_auto_restatement_ts == to_timestamp("2020-01-07 10:00:00") + assert snapshot_f.next_auto_restatement_ts is None + + assert snapshot_a.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-05 10:00:00")), + ] + assert snapshot_b.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-05")), + ] + assert snapshot_c.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-04")), + ] + assert snapshot_d.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-05 10:00:00")), + ] + assert snapshot_e.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06 05:00:00")), + ] + assert snapshot_f.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), + ] + + +def test_apply_auto_restatements_disable_restatement_downstream(make_snapshot): + # Hourly upstream model with auto restatement intervals set to 24 + model_a = SqlModel( + name="test_model_a", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + auto_restatement_cron="0 10 * * *", + auto_restatement_intervals=24, + ), + cron="@hourly", + query=parse_one("SELECT 1, ds FROM name"), + ) + snapshot_a = make_snapshot(model_a, version="1") + snapshot_a.add_interval("2020-01-01", "2020-01-06 09:00:00") + snapshot_a.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00") + + # Daily downstream model with disable restatement + model_b = SqlModel( + name="test_model_b", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + disable_restatement=True, + ), + cron="@daily", + query=parse_one("SELECT ds FROM test_model_a"), + ) + snapshot_b = make_snapshot(model_b, nodes={model_a.fqn: model_a}, version="2") + snapshot_b.add_interval("2020-01-01", "2020-01-05") + assert 
snapshot_a.snapshot_id in snapshot_b.parents + + restated_intervals, _ = apply_auto_restatements( + { + snapshot_a.snapshot_id: snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + }, + "2020-01-06 10:01:00", + ) + assert sorted(restated_intervals, key=lambda x: x.name) == [ + SnapshotIntervals( + name=snapshot_a.name, + identifier=None, + version=snapshot_a.version, + dev_version=None, + intervals=[], + dev_intervals=[], + pending_restatement_intervals=[ + (to_timestamp("2020-01-05 10:00:00"), to_timestamp("2020-01-06 10:00:00")) + ], + ), + ] + + assert snapshot_a.next_auto_restatement_ts == to_timestamp("2020-01-07 10:00:00") + assert snapshot_b.next_auto_restatement_ts is None + + assert snapshot_a.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-05 10:00:00")), + ] + assert snapshot_b.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), + ] + + snapshot_b.pending_restatement_intervals = [ + (to_timestamp("2020-01-03"), to_timestamp("2020-01-06")) + ] + snapshot_b.apply_pending_restatement_intervals() + assert snapshot_b.intervals == [ + (to_timestamp("2020-01-01"), to_timestamp("2020-01-06")), + ] + + +def test_auto_restatement_triggers(make_snapshot): + # Auto restatements: + # a, c, d + # dag: + # a -> b + # a -> c + # [b, c, d] -> e + model_a = SqlModel( + name="test_model_a", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + auto_restatement_cron="0 10 * * *", + auto_restatement_intervals=24, + ), + start="2020-01-01", + cron="@daily", + query=parse_one("SELECT 1 as ds"), + ) + snapshot_a = make_snapshot(model_a, version="1") + snapshot_a.add_interval("2020-01-01", "2020-01-05") + snapshot_a.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00") + + model_b = SqlModel( + name="test_model_b", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + ), + start="2020-01-01", + cron="@daily", + query=parse_one("SELECT ds FROM test_model_a"), + ) + snapshot_b = 
make_snapshot(model_b, nodes={model_a.fqn: model_a}, version="1") + snapshot_b.add_interval("2020-01-01", "2020-01-05") + + model_c = SqlModel( + name="test_model_c", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + auto_restatement_cron="0 10 * * *", + auto_restatement_intervals=24, + ), + start="2020-01-01", + cron="@daily", + query=parse_one("SELECT ds FROM test_model_a"), + ) + snapshot_c = make_snapshot(model_c, nodes={model_a.fqn: model_a}, version="1") + snapshot_c.add_interval("2020-01-01", "2020-01-05") + snapshot_c.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00") + + model_d = SqlModel( + name="test_model_d", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + auto_restatement_cron="0 10 * * *", + auto_restatement_intervals=24, + ), + start="2020-01-01", + cron="@daily", + query=parse_one("SELECT 1 as ds"), + ) + snapshot_d = make_snapshot(model_d, version="1") + snapshot_d.add_interval("2020-01-01", "2020-01-05") + snapshot_d.next_auto_restatement_ts = to_timestamp("2020-01-06 10:00:00") + + model_e = SqlModel( + name="test_model_e", + kind=IncrementalByTimeRangeKind( + time_column=TimeColumn(column="ds"), + ), + start="2020-01-01", + cron="@daily", + query=parse_one( + "SELECT ds from test_model_b UNION ALL SELECT ds from test_model_c UNION ALL SELECT ds from test_model_d" + ), + ) + snapshot_e = make_snapshot( + model_e, + nodes={ + model_a.fqn: model_a, + model_b.fqn: model_b, + model_c.fqn: model_c, + model_d.fqn: model_d, + }, + version="1", + ) + snapshot_e.add_interval("2020-01-01", "2020-01-05") + + _, auto_restatement_triggers = apply_auto_restatements( + { + snapshot_a.snapshot_id: snapshot_a, + snapshot_b.snapshot_id: snapshot_b, + snapshot_c.snapshot_id: snapshot_c, + snapshot_d.snapshot_id: snapshot_d, + snapshot_e.snapshot_id: snapshot_e, + }, + "2020-01-06 10:01:00", + ) + + assert auto_restatement_triggers[snapshot_a.snapshot_id] == [snapshot_a.snapshot_id] + assert 
auto_restatement_triggers[snapshot_c.snapshot_id] == [snapshot_c.snapshot_id] + assert auto_restatement_triggers[snapshot_d.snapshot_id] == [snapshot_d.snapshot_id] + assert auto_restatement_triggers[snapshot_b.snapshot_id] == [snapshot_a.snapshot_id] + # a via b, c and d directly + assert sorted(auto_restatement_triggers[snapshot_e.snapshot_id]) == [ + snapshot_a.snapshot_id, + snapshot_c.snapshot_id, + snapshot_d.snapshot_id, + ] + + +def test_render_signal(make_snapshot, mocker): + @signal() + def check_types(batch, env: str, sql: list[SQL], table: exp.Table, default: int = 0): + if not ( + env == "in_memory" + and default == 0 + and isinstance(sql, list) + and isinstance(sql[0], str) + and isinstance(table, exp.Table) + ): + raise + return True + + sql_model = load_sql_based_model( + parse( + """ + MODEL ( + name test_schema.test_model, + signals check_types(env := @gateway, sql := [a.b], table := b.c) + ); + SELECT a FROM tbl; + """ + ), + variables={ + c.GATEWAY: "in_memory", + }, + signal_definitions=signal.get_registry(), + ) + snapshot_a = make_snapshot(sql_model) + assert snapshot_a.check_ready_intervals([(0, 1)], mocker.Mock()) == [(0, 1)] + + +def test_partitioned_by_roundtrip(make_snapshot: t.Callable): + sql_model = load_sql_based_model( + parse(""" + MODEL ( + name test_schema.test_model, + kind full, + partitioned_by (a, bucket(4, b), truncate(3, c), month(d)) + ); + SELECT a, b, c, d FROM tbl; + """) + ) + snapshot = make_snapshot(sql_model) + assert isinstance(snapshot, Snapshot) + assert isinstance(snapshot.node, SqlModel) + + assert snapshot.node.partitioned_by == [ + exp.column("a", quoted=True), + exp.PartitionedByBucket( + this=exp.column("b", quoted=True), expression=exp.Literal.number(4) + ), + exp.PartitionByTruncate( + this=exp.column("c", quoted=True), expression=exp.Literal.number(3) + ), + exp.Month(this=exp.column("d", quoted=True)), + ] + + # roundtrip through json and ensure we get correct AST nodes on the other end + serialized = 
snapshot.json() + deserialized = snapshot.parse_raw(serialized) + + assert isinstance(deserialized.node, SqlModel) + assert deserialized.node.partitioned_by == snapshot.node.partitioned_by + + +@pytest.mark.parametrize( + "virtual_env_mode,is_deployable,expected_uses_name_as_is", + [ + (VirtualEnvironmentMode.DEV_ONLY, True, True), + (VirtualEnvironmentMode.DEV_ONLY, False, False), + (VirtualEnvironmentMode.FULL, True, False), + (VirtualEnvironmentMode.FULL, False, False), + ], +) +def test_table_name_virtual_environment_mode( + make_snapshot, + virtual_env_mode: VirtualEnvironmentMode, + is_deployable: bool, + expected_uses_name_as_is: bool, +): + model = SqlModel( + name="my_schema.my_model", + kind=IncrementalByTimeRangeKind(time_column="ds"), + query=parse_one("SELECT 1, ds"), + virtual_environment_mode=virtual_env_mode, + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + table_name_result = snapshot.table_name(is_deployable=is_deployable) + + if expected_uses_name_as_is: + assert table_name_result == '"my_schema"."my_model"' + else: + # Should contain the versioned table name with schema prefix + assert "sqlmesh__my_schema" in table_name_result + assert "my_schema__my_model" in table_name_result + if is_deployable: + assert table_name_result.endswith(snapshot.version) + else: + assert table_name_result.endswith(f"{snapshot.dev_version}__dev") + + +def test_snapshot_id_and_version_fingerprint_lazy_init(): + snapshot = SnapshotIdAndVersion( + name="a", + identifier="1234", + version="2345", + dev_version=None, + fingerprint='{"data_hash":"1","metadata_hash":"2","parent_data_hash":"3","parent_metadata_hash":"4"}', + ) + + # starts off as a string in the private property + assert isinstance(snapshot.fingerprint_, str) + + # gets parsed into SnapshotFingerprint on first access of public property + fingerprint = snapshot.fingerprint + assert isinstance(fingerprint, SnapshotFingerprint) + assert 
isinstance(snapshot.fingerprint_, SnapshotFingerprint) + + assert fingerprint.data_hash == "1" + assert fingerprint.metadata_hash == "2" + assert fingerprint.parent_data_hash == "3" + assert fingerprint.parent_metadata_hash == "4" + assert snapshot.dev_version is not None # dev version uses fingerprint + + # can also be supplied as a SnapshotFingerprint to begin with instead of a str + snapshot = SnapshotIdAndVersion( + name="a", identifier="1234", version="2345", dev_version=None, fingerprint=fingerprint + ) + + assert isinstance(snapshot.fingerprint_, SnapshotFingerprint) + assert snapshot.fingerprint == fingerprint + + +def test_snapshot_id_and_version_optional_kind_name(): + snapshot = SnapshotIdAndVersion( + name="a", + identifier="1234", + version="2345", + dev_version=None, + fingerprint="", + ) + + assert snapshot.model_kind_name is None + + snapshot = SnapshotIdAndVersion( + name="a", + identifier="1234", + version="2345", + kind_name="INCREMENTAL_UNMANAGED", + dev_version=None, + fingerprint="", + ) + + assert snapshot.model_kind_name + assert snapshot.is_incremental_unmanaged + assert snapshot.full_history_restatement_only diff --git a/tests/core/test_snapshot_evaluator.py b/tests/core/test_snapshot_evaluator.py index 014effa093..1413ac81f1 100644 --- a/tests/core/test_snapshot_evaluator.py +++ b/tests/core/test_snapshot_evaluator.py @@ -1,9 +1,15 @@ from __future__ import annotations import typing as t -from unittest.mock import call, patch +from typing_extensions import Self +from unittest.mock import call, patch, Mock +import contextlib +import re import logging import pytest +import pandas as pd # noqa: TID253 +import json +from pydantic import model_validator from pathlib import Path from pytest_mock.plugin import MockerFixture from sqlglot import expressions as exp @@ -12,7 +18,7 @@ from sqlmesh.core.audit import ModelAudit, StandaloneAudit from sqlmesh.core import dialect as d from sqlmesh.core.dialect import schema_, to_schema -from 
sqlmesh.core.engine_adapter import EngineAdapter, create_engine_adapter +from sqlmesh.core.engine_adapter import EngineAdapter, create_engine_adapter, BigQueryEngineAdapter from sqlmesh.core.engine_adapter.base import MERGE_SOURCE_ALIAS, MERGE_TARGET_ALIAS from sqlmesh.core.engine_adapter.shared import ( DataObject, @@ -20,34 +26,61 @@ InsertOverwriteStrategy, ) from sqlmesh.core.environment import EnvironmentNamingInfo -from sqlmesh.core.macros import RuntimeStage, macro +from sqlmesh.core.macros import RuntimeStage, macro, MacroEvaluator, MacroFunc from sqlmesh.core.model import ( Model, FullKind, IncrementalByTimeRangeKind, IncrementalUnmanagedKind, IncrementalByPartitionKind, + IncrementalByUniqueKeyKind, PythonModel, SqlModel, TimeColumn, ViewKind, + CustomKind, load_sql_based_model, + ExternalModel, + model, + create_sql_model, ) -from sqlmesh.core.model.kind import OnDestructiveChange +from sqlmesh.core.model.kind import OnDestructiveChange, ExternalKind, OnAdditiveChange +from sqlmesh.core.model.meta import GrantsTargetLayer from sqlmesh.core.node import IntervalUnit from sqlmesh.core.snapshot import ( DeployabilityIndex, Intervals, Snapshot, + SnapshotDataVersion, + SnapshotFingerprint, SnapshotChangeCategory, SnapshotEvaluator, SnapshotTableCleanupTask, ) -from sqlmesh.core.snapshot.evaluator import CustomMaterialization +from sqlmesh.core.snapshot.definition import to_view_mapping +from sqlmesh.core.snapshot.evaluator import ( + CustomMaterialization, + EngineManagedStrategy, + FullRefreshStrategy, + IncrementalByPartitionStrategy, + IncrementalByTimeRangeStrategy, + IncrementalByUniqueKeyStrategy, + IncrementalUnmanagedStrategy, + MaterializableStrategy, + SCDType2Strategy, + SnapshotCreationFailedError, + ViewStrategy, +) from sqlmesh.utils.concurrency import NodeExecutionFailedError from sqlmesh.utils.date import to_timestamp -from sqlmesh.utils.errors import AuditError, ConfigError +from sqlmesh.utils.errors import ( + ConfigError, + SQLMeshError, + 
DestructiveChangeError, + AdditiveChangeError, +) from sqlmesh.utils.metaprogramming import Executable +from sqlmesh.utils.pydantic import list_of_fields_validator if t.TYPE_CHECKING: @@ -80,13 +113,16 @@ def date_kwargs() -> t.Dict[str, str]: @pytest.fixture def adapter_mock(mocker: MockerFixture): + def mock_exit(self, exc_type, exc_value, traceback): + pass + transaction_mock = mocker.Mock() transaction_mock.__enter__ = mocker.Mock() - transaction_mock.__exit__ = mocker.Mock() + transaction_mock.__exit__ = mock_exit session_mock = mocker.Mock() session_mock.__enter__ = mocker.Mock() - session_mock.__exit__ = mocker.Mock() + session_mock.__exit__ = mock_exit adapter_mock = mocker.Mock() adapter_mock.transaction.return_value = transaction_mock @@ -95,9 +131,34 @@ def adapter_mock(mocker: MockerFixture): adapter_mock.HAS_VIEW_BINDING = False adapter_mock.wap_supported.return_value = False adapter_mock.get_data_objects.return_value = [] + adapter_mock.with_settings.return_value = adapter_mock return adapter_mock +@pytest.fixture +def adapters(mocker: MockerFixture): + adapters = [] + for i in range(3): + transaction_mock = mocker.Mock() + transaction_mock.__enter__ = mocker.Mock() + transaction_mock.__exit__ = mocker.Mock() + + session_mock = mocker.Mock() + session_mock.__enter__ = mocker.Mock() + session_mock.__exit__ = mocker.Mock() + + adapter_mock = mocker.Mock() + adapter_mock.transaction.return_value = transaction_mock + adapter_mock.session.return_value = session_mock + adapter_mock.dialect = "duckdb" + adapter_mock.HAS_VIEW_BINDING = False + adapter_mock.wap_supported.return_value = False + adapter_mock.get_data_objects.return_value = [] + adapter_mock.with_settings.return_value = adapter_mock + adapters.append(adapter_mock) + return adapters + + def test_evaluate(mocker: MockerFixture, adapter_mock, make_snapshot): evaluator = SnapshotEvaluator(adapter_mock) @@ -159,14 +220,9 @@ def x(evaluator, y=None) -> None: execute_calls = [call([parse_one('CREATE 
TABLE "hook_called"')])] adapter_mock.execute.assert_has_calls(execute_calls) - adapter_mock.create_schema.assert_has_calls( - [ - call(to_schema("sqlmesh__test_schema")), - ] - ) - common_kwargs = dict( - columns_to_types={"a": exp.DataType.build("int")}, + target_columns_to_types={"a": exp.DataType.build("int")}, + table_format=None, storage_format="parquet", partitioned_by=[exp.to_column("a", quoted=True)], partition_interval_unit=IntervalUnit.DAY, @@ -175,19 +231,11 @@ def x(evaluator, y=None) -> None: table_description=None, ) - adapter_mock.create_table.assert_has_calls( - [ - call( - snapshot.table_name(is_deployable=False), - column_descriptions=None, - **common_kwargs, - ), - call( - snapshot.table_name(), - column_descriptions={}, - **common_kwargs, - ), - ] + # Create will be called once and only prod table will be created + adapter_mock.create_table.assert_called_once_with( + snapshot.table_name(), + column_descriptions={}, + **common_kwargs, ) @@ -220,6 +268,9 @@ def increment_stage_counter(evaluator) -> None: snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + snapshot.model.render_pre_statements() + assert f"RuntimeStage value: {RuntimeStage.LOADING.value}" in capsys.readouterr().out evaluator.create([snapshot], {}) @@ -263,6 +314,8 @@ def test_promote(mocker: MockerFixture, adapter_mock, make_snapshot): evaluator.promote([snapshot], EnvironmentNamingInfo(name="test_env")) + adapter_mock.transaction.assert_called() + adapter_mock.session.assert_called() adapter_mock.create_schema.assert_called_once_with(to_schema("test_schema__test_env")) adapter_mock.create_view.assert_called_once_with( "test_schema__test_env.test_model", @@ -289,6 +342,8 @@ def test_demote(mocker: MockerFixture, adapter_mock, make_snapshot): evaluator.demote([snapshot], EnvironmentNamingInfo(name="test_env")) + adapter_mock.transaction.assert_called() + adapter_mock.session.assert_called() adapter_mock.drop_view.assert_called_once_with( 
"test_schema__test_env.test_model", cascade=False, @@ -335,7 +390,7 @@ def test_promote_forward_only(mocker: MockerFixture, adapter_mock, make_snapshot ) snapshot = make_snapshot(model) - snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) snapshot.version = "test_version" evaluator.promote( @@ -362,7 +417,7 @@ def test_promote_forward_only(mocker: MockerFixture, adapter_mock, make_snapshot call( "test_schema__test_env.test_model", parse_one( - f"SELECT * FROM sqlmesh__test_schema.test_schema__test_model__{snapshot.fingerprint.to_version()}__temp" + f"SELECT * FROM sqlmesh__test_schema.test_schema__test_model__{snapshot.fingerprint.to_version()}__dev" ), table_description=None, column_descriptions=None, @@ -393,41 +448,221 @@ def create_and_cleanup(name: str, dev_table_only: bool): ) snapshot = make_snapshot(model) - snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) snapshot.version = "test_version" + on_cleanup_mock = mocker.Mock() + evaluator.promote([snapshot], EnvironmentNamingInfo(name="test_env")) evaluator.cleanup( - [SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=dev_table_only)] + [SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=dev_table_only)], + on_complete=on_cleanup_mock, ) + assert on_cleanup_mock.call_count == 1 if dev_table_only else 2 return snapshot snapshot = create_and_cleanup("catalog.test_schema.test_model", True) + adapter_mock.get_data_object.assert_not_called() adapter_mock.drop_table.assert_called_once_with( - f"catalog.sqlmesh__test_schema.test_schema__test_model__{snapshot.fingerprint.to_version()}__temp" + f"catalog.sqlmesh__test_schema.test_schema__test_model__{snapshot.fingerprint.to_version()}__dev", + cascade=True, ) adapter_mock.reset_mock() snapshot = create_and_cleanup("test_schema.test_model", False) + 
adapter_mock.get_data_object.assert_not_called() adapter_mock.drop_table.assert_has_calls( [ call( - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.fingerprint.to_version()}__temp" + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.fingerprint.to_version()}__dev", + cascade=True, ), - call(f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}"), + call(f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}", cascade=True), ] ) adapter_mock.reset_mock() snapshot = create_and_cleanup("test_model", False) + adapter_mock.get_data_object.assert_not_called() adapter_mock.drop_table.assert_has_calls( [ - call(f"sqlmesh__default.test_model__{snapshot.fingerprint.to_version()}__temp"), - call(f"sqlmesh__default.test_model__{snapshot.version}"), + call( + f"sqlmesh__default.test_model__{snapshot.fingerprint.to_version()}__dev", + cascade=True, + ), + call(f"sqlmesh__default.test_model__{snapshot.version}", cascade=True), + ] + ) + + +def test_cleanup_view(adapter_mock, make_snapshot): + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="catalog.test_schema.test_model", + kind=ViewKind(materialized=False), + query=parse_one("SELECT a FROM tbl"), + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + + evaluator.promote([snapshot], EnvironmentNamingInfo(name="test_env")) + evaluator.cleanup([SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True)]) + + adapter_mock.get_data_object.assert_not_called() + adapter_mock.drop_view.assert_called_once_with( + f"catalog.sqlmesh__test_schema.test_schema__test_model__{snapshot.fingerprint.to_version()}__dev", + cascade=True, + ignore_if_not_exists=False, + ) + + +def test_cleanup_materialized_view(adapter_mock, make_snapshot): + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="catalog.test_schema.test_model", + kind=ViewKind(materialized=True), + 
query=parse_one("SELECT a FROM tbl"), + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + + adapter_mock.drop_view.side_effect = [RuntimeError("failed to drop view"), None] + + evaluator.promote([snapshot], EnvironmentNamingInfo(name="test_env")) + evaluator.cleanup([SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True)]) + + adapter_mock.get_data_object.assert_not_called() + adapter_mock.drop_view.assert_has_calls( + [ + call( + f"catalog.sqlmesh__test_schema.test_schema__test_model__{snapshot.fingerprint.to_version()}__dev", + cascade=True, + ignore_if_not_exists=False, + ), + call( + f"catalog.sqlmesh__test_schema.test_schema__test_model__{snapshot.fingerprint.to_version()}__dev", + materialized=True, + cascade=True, + ignore_if_not_exists=True, + ), ] ) +def test_cleanup_fails(adapter_mock, make_snapshot): + adapter_mock.drop_table.side_effect = RuntimeError("test_error") + + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="catalog.test_schema.test_model", + kind=IncrementalByTimeRangeKind(time_column="a"), + storage_format="parquet", + query=parse_one("SELECT a FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"), + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + snapshot.version = "test_version" + + evaluator.promote([snapshot], EnvironmentNamingInfo(name="test_env")) + with pytest.raises(NodeExecutionFailedError) as exc_info: + evaluator.cleanup( + [SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True)] + ) + + assert str(exc_info.value.__cause__) == "test_error" + + +def test_cleanup_skip_missing_table(adapter_mock, make_snapshot): + adapter_mock.get_data_object.return_value = None + adapter_mock.drop_table.side_effect = RuntimeError("fail") + + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="catalog.test_schema.test_model", + 
kind=IncrementalByTimeRangeKind(time_column="a"), + storage_format="parquet", + query=parse_one("SELECT a FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"), + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + snapshot.version = "test_version" + + evaluator.promote([snapshot], EnvironmentNamingInfo(name="test_env")) + evaluator.cleanup([SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True)]) + + adapter_mock.get_data_object.assert_called_once_with( + f"catalog.sqlmesh__test_schema.test_schema__test_model__{snapshot.fingerprint.to_version()}__dev" + ) + adapter_mock.drop_table.assert_called_once_with( + f"catalog.sqlmesh__test_schema.test_schema__test_model__{snapshot.fingerprint.to_version()}__dev", + cascade=True, + ) + + +def test_cleanup_external_model(mocker: MockerFixture, adapter_mock, make_snapshot): + evaluator = SnapshotEvaluator(adapter_mock) + + def create_and_cleanup_external_model(name: str, dev_table_only: bool): + model = ExternalModel( + name=name, + kind=ExternalKind(), + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.version = "test_version" + + evaluator.promote([snapshot], EnvironmentNamingInfo(name="test_env")) + evaluator.cleanup( + [SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=dev_table_only)] + ) + return snapshot + + create_and_cleanup_external_model("catalog.test_schema.test_model", True) + adapter_mock.drop_table.assert_not_called() + + +def test_cleanup_symbolic_and_audit_snapshots_no_callback( + mocker: MockerFixture, adapter_mock, make_snapshot +): + evaluator = SnapshotEvaluator(adapter_mock) + on_complete_mock = mocker.Mock() + + # Test external model + external_model = ExternalModel( + name="test_schema.external_model", + kind=ExternalKind(), + ) + external_snapshot = make_snapshot(external_model) + external_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) 
+ + # Test standalone audit + audit = StandaloneAudit(name="test_audit", query=parse_one("SELECT NULL LIMIT 0")) + audit_snapshot = make_snapshot(audit) + audit_snapshot.categorize_as(SnapshotChangeCategory.NON_BREAKING) + + evaluator.cleanup( + [ + SnapshotTableCleanupTask(snapshot=external_snapshot.table_info, dev_table_only=False), + SnapshotTableCleanupTask(snapshot=audit_snapshot.table_info, dev_table_only=False), + ], + on_complete=on_complete_mock, + ) + + # Verify that no physical tables were attempted to be dropped + adapter_mock.drop_table.assert_not_called() + adapter_mock.get_data_object.assert_not_called() + on_complete_mock.assert_not_called() + + @pytest.mark.parametrize("view_exists", [True, False]) def test_evaluate_materialized_view( mocker: MockerFixture, adapter_mock, make_snapshot, view_exists: bool @@ -462,25 +697,49 @@ def test_evaluate_materialized_view( snapshots={}, ) - adapter_mock.table_exists.assert_called_once_with(snapshot.table_name()) + # Ensure that the materialized view is recreated even if it exists + assert adapter_mock.create_view.call_count == 1 - if view_exists: - # Evaluation shouldn't take place because the rendered query hasn't changed - # since the last view creation. - assert not adapter_mock.create_view.called - else: - # If the view doesn't exist, it should be created even if the rendered query - # hasn't changed since the last view creation. 
- adapter_mock.create_view.assert_called_once_with( - snapshot.table_name(), - model.render_query(), - model.columns_to_types, - replace=True, + +def test_evaluate_materialized_view_with_partitioned_by_cluster_by( + mocker: MockerFixture, adapter_mock, make_snapshot +): + execute_mock = mocker.Mock() + # Use an engine adapter that supports cluster by/partitioned by + adapter = BigQueryEngineAdapter(lambda: mocker.Mock()) + adapter.table_exists = lambda *args, **kwargs: False # type: ignore + adapter.get_data_objects = lambda *args, **kwargs: [] # type: ignore + adapter._execute = execute_mock # type: ignore + adapter.with_settings = lambda **kwargs: adapter # type: ignore + evaluator = SnapshotEvaluator(adapter) + + model = SqlModel( + name="test_schema.test_model", + kind=ViewKind( materialized=True, - view_properties={}, - table_description=None, - column_descriptions={}, - ) + ), + partitioned_by=[exp.to_column("a")], + clustered_by=[exp.to_column("b")], + query=parse_one("SELECT a, b FROM tbl"), + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.add_interval("2023-01-01", "2023-01-01") + + evaluator.create( + [snapshot], + snapshots={}, + ) + + execute_mock.assert_has_calls( + [ + call( + f"CREATE MATERIALIZED VIEW `sqlmesh__test_schema`.`test_schema__test_model__{snapshot.version}` PARTITION BY `a` CLUSTER BY `b` AS SELECT `a` AS `a`, `b` AS `b` FROM `tbl` AS `tbl`", + False, + ), + ] + ) def test_evaluate_materialized_view_with_execution_time_macro( @@ -540,6 +799,8 @@ def test_evaluate_incremental_unmanaged_with_intervals( snapshot.categorize_as(SnapshotChangeCategory.BREAKING) snapshot.intervals = [(to_timestamp("2020-01-01"), to_timestamp("2020-01-02"))] + adapter_mock.columns.return_value = model.columns_to_types + evaluator = SnapshotEvaluator(adapter_mock) evaluator.evaluate( snapshot, @@ -554,13 +815,15 @@ def test_evaluate_incremental_unmanaged_with_intervals( snapshot.table_name(), 
model.render_query(), [exp.to_column("ds", quoted=True)], - columns_to_types=model.columns_to_types, + target_columns_to_types=model.columns_to_types, + source_columns=None, ) else: adapter_mock.insert_append.assert_called_once_with( snapshot.table_name(), model.render_query(), - columns_to_types=model.columns_to_types, + target_columns_to_types=model.columns_to_types, + source_columns=None, ) @@ -570,13 +833,16 @@ def test_evaluate_incremental_unmanaged_no_intervals( ): model = SqlModel( name="test_schema.test_model", - query=parse_one("SELECT 1, ds FROM tbl_a"), + query=parse_one("SELECT 1 as one, ds FROM tbl_a"), kind=IncrementalUnmanagedKind(insert_overwrite=insert_overwrite), partitioned_by=["ds"], ) snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + table_columns = {"one": exp.DataType.build("int"), "ds": exp.DataType.build("timestamp")} + adapter_mock.columns.return_value = table_columns + evaluator = SnapshotEvaluator(adapter_mock) evaluator.evaluate( snapshot, @@ -591,16 +857,19 @@ def test_evaluate_incremental_unmanaged_no_intervals( model.render_query(), clustered_by=[], column_descriptions={}, - columns_to_types=None, - partition_interval_unit=model.interval_unit, + target_columns_to_types=table_columns, + partition_interval_unit=model.partition_interval_unit, partitioned_by=model.partitioned_by, + table_format=None, storage_format=None, table_description=None, table_properties={}, + source_columns=None, ) + adapter_mock.columns.assert_called_once_with(snapshot.table_name()) -def test_create_tables_exists(mocker: MockerFixture, adapter_mock, make_snapshot): +def test_create_prod_table_exists(mocker: MockerFixture, adapter_mock, make_snapshot): model = load_sql_based_model( parse( # type: ignore """ @@ -618,11 +887,6 @@ def test_create_tables_exists(mocker: MockerFixture, adapter_mock, make_snapshot snapshot.categorize_as(SnapshotChangeCategory.BREAKING) adapter_mock.get_data_objects.return_value = [ - DataObject( - 
name=f"test_schema__test_model__{snapshot.version}__temp", - schema="sqlmesh__test_schema", - type=DataObjectType.VIEW, - ), DataObject( name=f"test_schema__test_model__{snapshot.version}", schema="sqlmesh__test_schema", @@ -633,13 +897,63 @@ def test_create_tables_exists(mocker: MockerFixture, adapter_mock, make_snapshot evaluator.create([snapshot], {}) adapter_mock.create_view.assert_not_called() + adapter_mock.create_schema.assert_not_called() adapter_mock.get_data_objects.assert_called_once_with( schema_("sqlmesh__test_schema"), { - f"test_schema__test_model__{snapshot.version}__temp", f"test_schema__test_model__{snapshot.version}", }, + safe_to_cache=True, + ) + + +def test_pre_hook_forward_only_clone( + mocker: MockerFixture, make_mocked_engine_adapter, make_snapshot +): + """ + Verifies that pre-statements are executed when creating a clone of a forward-only model. + """ + pre_statement = """CREATE TEMPORARY FUNCTION "example_udf"("x" BIGINT) AS ("x" + 1)""" + model = load_sql_based_model( + parse( # type: ignore + f""" + MODEL ( + name test_schema.test_model, + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds + ) + ); + + {pre_statement}; + + SELECT a::int, ds::string FROM tbl; + """ + ), + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + snapshot.previous_versions = snapshot.all_versions + + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.with_settings = lambda **kwargs: adapter + adapter.table_exists = lambda _: True # type: ignore + adapter.SUPPORTS_CLONING = True + mocker.patch.object( + adapter, + "get_data_objects", + return_value=[], ) + mocker.patch.object( + adapter, + "get_alter_operations", + return_value=[], + ) + + evaluator = SnapshotEvaluator(adapter) + + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) + adapter.cursor.execute.assert_any_call(pre_statement) def test_create_only_dev_table_exists(mocker: MockerFixture, 
adapter_mock, make_snapshot): @@ -661,7 +975,7 @@ def test_create_only_dev_table_exists(mocker: MockerFixture, adapter_mock, make_ adapter_mock.get_data_objects.return_value = [ DataObject( - name=f"test_schema__test_model__{snapshot.version}__temp", + name=f"test_schema__test_model__{snapshot.version}__dev", schema="sqlmesh__test_schema", type=DataObjectType.VIEW, ), @@ -669,28 +983,30 @@ def test_create_only_dev_table_exists(mocker: MockerFixture, adapter_mock, make_ adapter_mock.table_exists.return_value = True evaluator = SnapshotEvaluator(adapter_mock) - evaluator.create([snapshot], {}) - - adapter_mock.create_view.assert_not_called + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) + adapter_mock.create_view.assert_not_called() adapter_mock.get_data_objects.assert_called_once_with( schema_("sqlmesh__test_schema"), { - f"test_schema__test_model__{snapshot.version}__temp", - f"test_schema__test_model__{snapshot.version}", + f"test_schema__test_model__{snapshot.version}__dev", }, + safe_to_cache=True, ) -def test_create_view_non_deployable_snapshot(mocker: MockerFixture, adapter_mock, make_snapshot): +def test_create_new_forward_only_model(mocker: MockerFixture, adapter_mock, make_snapshot): model = load_sql_based_model( parse( # type: ignore """ MODEL ( name test_schema.test_model, - kind VIEW + kind INCREMENTAL_BY_TIME_RANGE ( + time_column ds, + forward_only true, + ) ); - SELECT a::int FROM tbl; + SELECT a::int, '2024-01-01' as ds FROM tbl; """ ), ) @@ -702,29 +1018,208 @@ def test_create_view_non_deployable_snapshot(mocker: MockerFixture, adapter_mock adapter_mock.table_exists.return_value = False evaluator = SnapshotEvaluator(adapter_mock) - deployability_index = DeployabilityIndex.none_deployable() - evaluator.create([snapshot], {}, deployability_index=deployability_index) - - adapter_mock.create_view.assert_called_once_with( - snapshot.table_name(is_deployable=False), - model.render_query(), - 
column_descriptions=None, - view_properties={}, + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) + # Only non-deployable table should be created + adapter_mock.create_table.assert_called_once_with( + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.dev_version}__dev", + target_columns_to_types={ + "a": exp.DataType.build("int"), + "ds": exp.DataType.build("varchar"), + }, + table_format=None, + storage_format=None, + partitioned_by=model.partitioned_by, + partition_interval_unit=model.partition_interval_unit, + clustered_by=[], + table_properties={}, table_description=None, - materialized=False, - replace=False, + column_descriptions=None, + ) + adapter_mock.get_data_objects.assert_called_once_with( + schema_("sqlmesh__test_schema"), + { + f"test_schema__test_model__{snapshot.dev_version}__dev", + }, + safe_to_cache=True, ) -def test_create_materialized_view(mocker: MockerFixture, adapter_mock, make_snapshot): - adapter_mock.get_data_objects.return_value = [] - adapter_mock.table_exists.return_value = False +@pytest.mark.parametrize( + "deployability_index, snapshot_category, forward_only", + [ + (DeployabilityIndex.all_deployable(), SnapshotChangeCategory.BREAKING, False), + (DeployabilityIndex.all_deployable(), SnapshotChangeCategory.NON_BREAKING, False), + (DeployabilityIndex.all_deployable(), SnapshotChangeCategory.BREAKING, True), + ( + DeployabilityIndex.all_deployable(), + SnapshotChangeCategory.INDIRECT_BREAKING, + False, + ), + ( + DeployabilityIndex.all_deployable(), + SnapshotChangeCategory.INDIRECT_NON_BREAKING, + False, + ), + (DeployabilityIndex.all_deployable(), SnapshotChangeCategory.METADATA, False), + ( + DeployabilityIndex.none_deployable(), + SnapshotChangeCategory.BREAKING, + False, + ), + ( + DeployabilityIndex.none_deployable(), + SnapshotChangeCategory.NON_BREAKING, + False, + ), + ( + DeployabilityIndex.none_deployable(), + SnapshotChangeCategory.BREAKING, + True, + ), + ( + 
DeployabilityIndex.none_deployable(), + SnapshotChangeCategory.INDIRECT_BREAKING, + False, + ), + ( + DeployabilityIndex.none_deployable(), + SnapshotChangeCategory.INDIRECT_NON_BREAKING, + False, + ), + ( + DeployabilityIndex.none_deployable(), + SnapshotChangeCategory.METADATA, + False, + ), + ], +) +def test_create_tables_exist( + snapshot: Snapshot, + mocker: MockerFixture, + adapter_mock, + deployability_index: DeployabilityIndex, + snapshot_category: SnapshotChangeCategory, + forward_only: bool, +): + adapter_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter") + adapter_mock.dialect = "duckdb" + adapter_mock.with_settings.return_value = adapter_mock evaluator = SnapshotEvaluator(adapter_mock) + snapshot.categorize_as(category=snapshot_category, forward_only=forward_only) - model = load_sql_based_model( - parse( # type: ignore - """ + table_name = ( + f"db__model__{snapshot.version}" + if deployability_index.is_deployable(snapshot) + else f"db__model__{snapshot.version}__dev" + ) + + adapter_mock.get_data_objects.return_value = [ + DataObject( + name=table_name, + schema="sqlmesh__db", + type=DataObjectType.TABLE, + ), + ] + + evaluator.create( + target_snapshots=[snapshot], + snapshots={}, + deployability_index=deployability_index, + ) + + adapter_mock.get_data_objects.assert_called_once_with( + schema_("sqlmesh__db"), + {table_name}, + safe_to_cache=True, + ) + adapter_mock.create_schema.assert_not_called() + adapter_mock.create_table.assert_not_called() + + +def test_create_prod_table_exists_forward_only(mocker: MockerFixture, adapter_mock, make_snapshot): + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind FULL + ); + + SELECT a::int FROM tbl; + """ + ), + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + + adapter_mock.get_data_objects.return_value = [ + DataObject( + name=f"test_schema__test_model__{snapshot.version}", 
+ schema="sqlmesh__test_schema", + type=DataObjectType.TABLE, + ), + ] + evaluator = SnapshotEvaluator(adapter_mock) + evaluator.create([snapshot], {}) + + adapter_mock.get_data_objects.assert_called_once_with( + schema_("sqlmesh__test_schema"), + { + f"test_schema__test_model__{snapshot.version}", + }, + safe_to_cache=True, + ) + + adapter_mock.create_table.assert_not_called() + + +def test_create_view_non_deployable_snapshot(mocker: MockerFixture, adapter_mock, make_snapshot): + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind VIEW + ); + + SELECT a::int FROM tbl; + """ + ), + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + adapter_mock.get_data_objects.return_value = [] + adapter_mock.table_exists.return_value = False + evaluator = SnapshotEvaluator(adapter_mock) + + deployability_index = DeployabilityIndex.none_deployable() + evaluator.create([snapshot], {}, deployability_index=deployability_index) + + adapter_mock.create_view.assert_called_once_with( + snapshot.table_name(is_deployable=False), + model.render_query(), + column_descriptions=None, + view_properties={}, + table_description=None, + materialized=False, + replace=False, + materialized_properties=None, + ) + + +def test_create_materialized_view(mocker: MockerFixture, adapter_mock, make_snapshot): + adapter_mock.get_data_objects.return_value = [] + adapter_mock.table_exists.return_value = False + + evaluator = SnapshotEvaluator(adapter_mock) + + model = load_sql_based_model( + parse( # type: ignore + """ MODEL ( name test_schema.test_model, kind VIEW ( @@ -744,23 +1239,18 @@ def test_create_materialized_view(mocker: MockerFixture, adapter_mock, make_snap common_kwargs = dict( materialized=True, + materialized_properties={ + "clustered_by": [], + "partition_interval_unit": None, + "partitioned_by": [], + }, view_properties={}, table_description=None, replace=False, ) - 
adapter_mock.create_view.assert_has_calls( - [ - call( - snapshot.table_name(is_deployable=False), - model.render_query(), - column_descriptions=None, - **common_kwargs, - ), - call( - snapshot.table_name(), model.render_query(), column_descriptions={}, **common_kwargs - ), - ] + adapter_mock.create_view.assert_called_once_with( + snapshot.table_name(), model.render_query(), column_descriptions={}, **common_kwargs ) @@ -798,28 +1288,24 @@ def test_create_view_with_properties(mocker: MockerFixture, adapter_mock, make_s view_properties={ "key": exp.convert("value"), }, + materialized_properties={ + "clustered_by": [], + "partition_interval_unit": None, + "partitioned_by": [], + }, table_description=None, replace=False, ) - adapter_mock.create_view.assert_has_calls( - [ - call( - snapshot.table_name(is_deployable=False), - model.render_query(), - column_descriptions=None, - **common_kwargs, - ), - call( - snapshot.table_name(), model.render_query(), column_descriptions={}, **common_kwargs - ), - ] + adapter_mock.create_view.assert_called_once_with( + snapshot.table_name(), model.render_query(), column_descriptions={}, **common_kwargs ) def test_promote_model_info(mocker: MockerFixture, make_snapshot): adapter_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter") adapter_mock.dialect = "duckdb" + adapter_mock.with_settings.return_value = adapter_mock evaluator = SnapshotEvaluator(adapter_mock) @@ -845,13 +1331,74 @@ def test_promote_model_info(mocker: MockerFixture, make_snapshot): ) -def test_migrate(mocker: MockerFixture, make_snapshot): - connection_mock = mocker.NonCallableMock() - cursor_mock = mocker.Mock() - connection_mock.cursor.return_value = cursor_mock - adapter = EngineAdapter(lambda: connection_mock, "") +def test_promote_deployable(mocker: MockerFixture, make_snapshot): + adapter_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter") + adapter_mock.dialect = "duckdb" + adapter_mock.with_settings.return_value = adapter_mock + + 
evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_model", + kind=FullKind(), + query=parse_one("SELECT a FROM tbl"), + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + adapter_mock.get_data_objects.return_value = [ + DataObject( + name=f"test_schema__test_model__{snapshot.version}", + schema="sqlmesh__test_schema", + type=DataObjectType.TABLE, + ), + ] + + evaluator.create([snapshot], {}) + adapter_mock.get_data_objects.assert_called_once_with( + schema_("sqlmesh__test_schema"), + { + f"test_schema__test_model__{snapshot.version}", + }, + safe_to_cache=True, + ) + adapter_mock.create_table.assert_not_called() + + adapter_mock.get_data_objects.return_value = [] + evaluator.promote([snapshot], EnvironmentNamingInfo(name="test_env")) + + adapter_mock.create_schema.assert_called_once_with(to_schema("test_schema__test_env")) + adapter_mock.create_view.assert_called_once_with( + "test_schema__test_env.test_model", + parse_one( + f"SELECT * FROM sqlmesh__test_schema.test_schema__test_model__{snapshot.version}" + ), + table_description=None, + column_descriptions=None, + view_properties={}, + ) + + +def test_migrate(mocker: MockerFixture, make_snapshot, make_mocked_engine_adapter): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.with_settings = lambda **kwargs: adapter # type: ignore + session_spy = mocker.spy(adapter, "session") + + model = SqlModel( + name="test_schema.test_model", + kind=IncrementalByTimeRangeKind( + time_column="a", on_destructive_change=OnDestructiveChange.ALLOW + ), + storage_format="parquet", + query=parse_one("SELECT c, a FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"), + ) + snapshot = make_snapshot(model, version="1") + snapshot.change_category = SnapshotChangeCategory.BREAKING + snapshot.forward_only = True + snapshot.previous_versions = snapshot.all_versions - current_table = "sqlmesh__test_schema.test_schema__test_model__1" + 
current_table = snapshot.table_name() def columns(table_name): if table_name == current_table: @@ -859,13 +1406,48 @@ def columns(table_name): "c": exp.DataType.build("int"), "b": exp.DataType.build("int"), } - else: - return { - "c": exp.DataType.build("int"), - "a": exp.DataType.build("int"), - } + return { + "c": exp.DataType.build("int"), + "a": exp.DataType.build("int"), + } adapter.columns = columns # type: ignore + adapter.table_exists = lambda _: True # type: ignore + mocker.patch.object( + adapter, + "get_data_objects", + return_value=[ + DataObject( + schema="sqlmesh__test_schema", + name=f"test_schema__test_model__{snapshot.version}", + type="table", + ) + ], + ) + + evaluator = SnapshotEvaluator(adapter) + + evaluator.migrate([snapshot], {}) + + adapter.cursor.execute.assert_has_calls( + [ + call( + f'ALTER TABLE "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}" DROP COLUMN "b"' + ), + call( + f'ALTER TABLE "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}" ADD COLUMN "a" INT' + ), + ] + ) + + session_spy.assert_called_once() + + +def test_migrate_missing_table(mocker: MockerFixture, make_snapshot, make_mocked_engine_adapter): + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.table_exists = lambda _: False # type: ignore + adapter.with_settings = lambda **kwargs: adapter # type: ignore + mocker.patch.object(adapter, "get_data_object", return_value=None) evaluator = SnapshotEvaluator(adapter) @@ -876,36 +1458,33 @@ def columns(table_name): ), storage_format="parquet", query=parse_one("SELECT c, a FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"), + pre_statements=[parse_one("CREATE TABLE pre (a INT)")], + post_statements=[parse_one("DROP TABLE pre")], ) snapshot = make_snapshot(model, version="1") - snapshot.change_category = SnapshotChangeCategory.FORWARD_ONLY + snapshot.change_category = SnapshotChangeCategory.BREAKING + snapshot.forward_only = True + snapshot.previous_versions = snapshot.all_versions 
- evaluator.migrate([snapshot], {}) + evaluator.migrate([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) - cursor_mock.execute.assert_has_calls( - [ - call('ALTER TABLE "sqlmesh__test_schema"."test_schema__test_model__1" DROP COLUMN "b"'), - call( - 'ALTER TABLE "sqlmesh__test_schema"."test_schema__test_model__1" ADD COLUMN "a" INT' - ), - ] - ) + adapter.cursor.execute.assert_not_called() @pytest.mark.parametrize( - "change_category", - [SnapshotChangeCategory.FORWARD_ONLY, SnapshotChangeCategory.INDIRECT_NON_BREAKING], + "change_category, forward_only", + [ + (SnapshotChangeCategory.BREAKING, True), + (SnapshotChangeCategory.INDIRECT_NON_BREAKING, False), + ], ) def test_migrate_view( - mocker: MockerFixture, make_snapshot, change_category: SnapshotChangeCategory + mocker: MockerFixture, + make_snapshot, + make_mocked_engine_adapter, + change_category: SnapshotChangeCategory, + forward_only: bool, ): - connection_mock = mocker.NonCallableMock() - cursor_mock = mocker.Mock() - connection_mock.cursor.return_value = cursor_mock - adapter = EngineAdapter(lambda: connection_mock, "") - - evaluator = SnapshotEvaluator(adapter) - model = SqlModel( name="test_schema.test_model", kind=ViewKind(), @@ -914,14 +1493,71 @@ def test_migrate_view( ) snapshot = make_snapshot(model, version="1") snapshot.change_category = change_category + snapshot.forward_only = forward_only + + adapter = make_mocked_engine_adapter(EngineAdapter) + mocker.patch( + "sqlmesh.core.engine_adapter.base.EngineAdapter.get_data_objects", + return_value=[ + DataObject( + schema="sqlmesh__test_schema", + name=f"test_schema__test_model__{snapshot.version}", + type="view", + ) + ], + ) + evaluator = SnapshotEvaluator(adapter) evaluator.migrate([snapshot], {}) - cursor_mock.execute.assert_has_calls( + adapter.cursor.execute.assert_has_calls( [ call( - 'CREATE OR REPLACE VIEW "sqlmesh__test_schema"."test_schema__test_model__1" ("c", "a") AS SELECT "c" AS "c", "a" AS "a" FROM "tbl" AS 
"tbl"' + f'CREATE OR REPLACE VIEW "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}" ("c", "a") AS SELECT "c" AS "c", "a" AS "a" FROM "tbl" AS "tbl"' + ), + ] + ) + + +def test_migrate_snapshot_data_object_type_mismatch( + mocker: MockerFixture, + make_snapshot, + make_mocked_engine_adapter, +): + model = SqlModel( + name="test_schema.test_model", + kind=FullKind(), + storage_format="parquet", + query=parse_one("SELECT c, a FROM tbl"), + ) + snapshot = make_snapshot(model, version="1") + snapshot.change_category = SnapshotChangeCategory.BREAKING + snapshot.forward_only = True + snapshot.previous_versions = snapshot.all_versions + + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.with_settings = lambda **kwargs: adapter # type: ignore + mocker.patch( + "sqlmesh.core.engine_adapter.base.EngineAdapter.get_data_objects", + return_value=[ + DataObject( + schema="sqlmesh__test_schema", + name=f"test_schema__test_model__{snapshot.version}", + type="view", ) + ], + ) + mocker.patch.object(adapter, "table_exists", return_value=False) + + evaluator = SnapshotEvaluator(adapter) + + evaluator.migrate([snapshot], {}) + + adapter.cursor.execute.assert_has_calls( + [ + call( + f'DROP VIEW IF EXISTS "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}"' + ), ] ) @@ -932,6 +1568,7 @@ def test_evaluate_creation_duckdb( date_kwargs: t.Dict[str, str], ): evaluator = SnapshotEvaluator(create_engine_adapter(lambda: duck_conn, "duckdb")) + evaluator.create_physical_schemas([snapshot], DeployabilityIndex.all_deployable()) evaluator.create([snapshot], {}) version = snapshot.version @@ -940,7 +1577,6 @@ def assert_tables_exist() -> None: "SELECT table_schema, table_name, table_type FROM information_schema.tables" ).fetchall() == [ ("sqlmesh__db", f"db__model__{version}", "BASE TABLE"), - ("sqlmesh__db", f"db__model__{version}__temp", "BASE TABLE"), ("main", "tbl", "VIEW"), ] @@ -969,6 +1605,7 @@ def assert_tables_exist() -> None: def 
test_migrate_duckdb(snapshot: Snapshot, duck_conn, make_snapshot): evaluator = SnapshotEvaluator(create_engine_adapter(lambda: duck_conn, "duckdb")) + evaluator.create_physical_schemas([snapshot], DeployabilityIndex.all_deployable()) evaluator.create([snapshot], {}) updated_model_dict = snapshot.model.dict() @@ -976,7 +1613,7 @@ def test_migrate_duckdb(snapshot: Snapshot, duck_conn, make_snapshot): updated_model = SqlModel.parse_obj(updated_model_dict) new_snapshot = make_snapshot(updated_model) - new_snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) new_snapshot.version = snapshot.version evaluator.create([new_snapshot], {}) @@ -1082,7 +1719,7 @@ def test_snapshot_evaluator_yield_pd(adapter_mock, make_snapshot, input_dfs, out name="python_func", alias="python_func", path="test_snapshot_evaluator.py", - payload=f"""import pandas as pd + payload=f"""import pandas as pd # noqa: TID253 def python_func(**kwargs): for df in [ {input_dfs} @@ -1107,11 +1744,54 @@ def python_func(**kwargs): assert adapter_mock.insert_overwrite_by_time_partition.call_args[0][1].to_dict() == output_dict -def test_create_clone_in_dev(mocker: MockerFixture, adapter_mock, make_snapshot): - adapter_mock.SUPPORTS_CLONING = True - adapter_mock.get_alter_expressions.return_value = [] - evaluator = SnapshotEvaluator(adapter_mock) - +def test_snapshot_evaluator_yield_empty_pd(adapter_mock, make_snapshot): + adapter_mock.is_pyspark_df.return_value = False + adapter_mock.INSERT_OVERWRITE_STRATEGY = InsertOverwriteStrategy.INSERT_OVERWRITE + adapter_mock.try_get_df = lambda x: x + evaluator = SnapshotEvaluator(adapter_mock) + + snapshot = make_snapshot( + PythonModel( + name="db.model", + entrypoint="python_func", + kind=IncrementalByTimeRangeKind(time_column=TimeColumn(column="ds", format="%Y-%m-%d")), + columns={ + "a": "INT", + "ds": "STRING", + }, + python_env={ + "python_func": Executable( + 
name="python_func", + alias="python_func", + path="test_snapshot_evaluator.py", + payload="""def python_func(**kwargs): + yield from ()""", + ) + }, + ) + ) + + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + evaluator.create([snapshot], {}) + + # This should not raise a TypeError from reduce() with empty sequence + evaluator.evaluate( + snapshot, + start="2023-01-01", + end="2023-01-09", + execution_time="2023-01-09", + snapshots={}, + ) + + # When there are no dataframes to process, insert_overwrite_by_time_partition should not be called + adapter_mock.insert_overwrite_by_time_partition.assert_not_called() + + +def test_create_clone_in_dev(mocker: MockerFixture, adapter_mock, make_snapshot): + adapter_mock.SUPPORTS_CLONING = True + adapter_mock.get_alter_operations.return_value = [] + evaluator = SnapshotEvaluator(adapter_mock) + model = load_sql_based_model( parse( # type: ignore """ @@ -1128,14 +1808,15 @@ def test_create_clone_in_dev(mocker: MockerFixture, adapter_mock, make_snapshot) ) snapshot = make_snapshot(model) - snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) snapshot.previous_versions = snapshot.all_versions - evaluator.create([snapshot], {}) + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) adapter_mock.create_table.assert_called_once_with( - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__temp__schema_migration_source", - columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev_schema_tmp", + target_columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, + table_format=None, storage_format=None, partitioned_by=[exp.to_column("ds", quoted=True)], partition_interval_unit=IntervalUnit.DAY, @@ -1146,29 +1827,31 @@ def test_create_clone_in_dev(mocker: MockerFixture, 
adapter_mock, make_snapshot) ) adapter_mock.clone_table.assert_called_once_with( - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__temp", + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.dev_version}__dev", f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}", - replace=True, + rendered_physical_properties={}, ) - adapter_mock.get_alter_expressions.assert_called_once_with( - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__temp", - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__temp__schema_migration_source", + adapter_mock.get_alter_operations.assert_called_once_with( + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev", + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev_schema_tmp", + ignore_destructive=False, + ignore_additive=False, ) adapter_mock.alter_table.assert_called_once_with([]) adapter_mock.drop_table.assert_called_once_with( - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__temp__schema_migration_source" + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev_schema_tmp" ) def test_drop_clone_in_dev_when_migration_fails(mocker: MockerFixture, adapter_mock, make_snapshot): adapter_mock.SUPPORTS_CLONING = True - adapter_mock.get_alter_expressions.return_value = [] + adapter_mock.get_alter_operations.return_value = [] evaluator = SnapshotEvaluator(adapter_mock) - adapter_mock.alter_table.side_effect = Exception("Migration failed") + adapter_mock.alter_table.side_effect = DestructiveChangeError("Migration failed") model = load_sql_based_model( parse( # type: ignore @@ -1186,42 +1869,49 @@ def test_drop_clone_in_dev_when_migration_fails(mocker: MockerFixture, adapter_m ) snapshot = make_snapshot(model) - snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) snapshot.previous_versions = snapshot.all_versions - 
evaluator.create([snapshot], {}) + with pytest.raises(SnapshotCreationFailedError): + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) adapter_mock.clone_table.assert_called_once_with( - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__temp", + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev", f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}", - replace=True, + rendered_physical_properties={}, ) - adapter_mock.get_alter_expressions.assert_called_once_with( - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__temp", - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__temp__schema_migration_source", + adapter_mock.get_alter_operations.assert_called_once_with( + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev", + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev_schema_tmp", + ignore_destructive=False, + ignore_additive=False, ) adapter_mock.alter_table.assert_called_once_with([]) adapter_mock.drop_table.assert_has_calls( [ - call(f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__temp"), call( - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__temp__schema_migration_source" + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev_schema_tmp" ), + call(f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev"), ] ) -def test_create_clone_in_dev_self_referencing(mocker: MockerFixture, adapter_mock, make_snapshot): +@pytest.mark.parametrize("use_this_model", [True, False]) +def test_create_clone_in_dev_self_referencing( + mocker: MockerFixture, adapter_mock, make_snapshot, use_this_model: bool +): adapter_mock.SUPPORTS_CLONING = True - adapter_mock.get_alter_expressions.return_value = [] + adapter_mock.get_alter_operations.return_value = [] evaluator = SnapshotEvaluator(adapter_mock) + from_table = "test_schema.test_model" 
if not use_this_model else "@this_model" model = load_sql_based_model( parse( # type: ignore - """ + f""" MODEL ( name test_schema.test_model, kind INCREMENTAL_BY_TIME_RANGE ( @@ -1229,20 +1919,21 @@ def test_create_clone_in_dev_self_referencing(mocker: MockerFixture, adapter_moc ) ); - SELECT 1::INT as a, ds::DATE FROM test_schema.test_model; + SELECT 1::INT as a, ds::DATE FROM {from_table}; """ ), ) snapshot = make_snapshot(model) - snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) snapshot.previous_versions = snapshot.all_versions - evaluator.create([snapshot], {}) + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) adapter_mock.create_table.assert_called_once_with( - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__temp__schema_migration_source", - columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev_schema_tmp", + target_columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, + table_format=None, storage_format=None, partitioned_by=[exp.to_column("ds", quoted=True)], partition_interval_unit=IntervalUnit.DAY, @@ -1252,24 +1943,39 @@ def test_create_clone_in_dev_self_referencing(mocker: MockerFixture, adapter_moc column_descriptions=None, ) - # Make sure the dry run references the correct ("...__schema_migration_source") table. + # Make sure the dry run references the correct ("..._schema_tmp") table. 
+ table_alias = ( + "test_model" + if not use_this_model + else f"test_schema__test_model__{snapshot.version}__dev_schema_tmp" + ) dry_run_query = adapter_mock.fetchall.call_args[0][0].sql() assert ( dry_run_query - == f'SELECT CAST(1 AS INT) AS "a", CAST("ds" AS DATE) AS "ds" FROM "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}__temp__schema_migration_source" AS "test_model" /* test_schema.test_model */ WHERE FALSE LIMIT 0' + == f'SELECT CAST(1 AS INT) AS "a", CAST("ds" AS DATE) AS "ds" FROM "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}__dev_schema_tmp" AS "{table_alias}" /* test_schema.test_model */ WHERE FALSE LIMIT 0' ) def test_on_destructive_change_runtime_check( mocker: MockerFixture, make_snapshot, + make_mocked_engine_adapter, ): - connection_mock = mocker.NonCallableMock() - cursor_mock = mocker.Mock() - connection_mock.cursor.return_value = cursor_mock - adapter = EngineAdapter(lambda: connection_mock, "") + # SQLMesh default: ERROR + model = SqlModel( + name="test_schema.test_model", + kind=IncrementalByTimeRangeKind(time_column="a"), + query=parse_one("SELECT c, a FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"), + ) + snapshot = make_snapshot(model, version="1") + snapshot.change_category = SnapshotChangeCategory.BREAKING + snapshot.forward_only = True + snapshot.previous_versions = snapshot.all_versions + + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.with_settings = lambda **kwargs: adapter # type: ignore - current_table = "sqlmesh__test_schema.test_schema__test_model__1" + current_table = snapshot.table_name() def columns(table_name): if table_name == current_table: @@ -1277,58 +1983,205 @@ def columns(table_name): "c": exp.DataType.build("int"), "b": exp.DataType.build("int"), } - else: - return { - "c": exp.DataType.build("int"), - "a": exp.DataType.build("int"), - } + return { + "c": exp.DataType.build("int"), + "a": exp.DataType.build("int"), + } adapter.columns = columns # type: 
ignore + mocker.patch( + "sqlmesh.core.engine_adapter.base.EngineAdapter.get_data_objects", + return_value=[ + DataObject( + schema="sqlmesh__test_schema", + name=f"test_schema__test_model__{snapshot.version}", + type=DataObjectType.TABLE, + ) + ], + ) evaluator = SnapshotEvaluator(adapter) - # SQLMesh default: ERROR + with pytest.raises(NodeExecutionFailedError) as ex: + evaluator.migrate([snapshot], {}) + + destructive_change_err = ex.value.__cause__ + assert isinstance(destructive_change_err, DestructiveChangeError) + assert ( + str(destructive_change_err) + == "\nPlan requires destructive change to forward-only model '\"test_schema\".\"test_model\"'s schema that drops column 'b'.\n\nSchema changes:\n ALTER TABLE sqlmesh__test_schema.test_schema__test_model__1 DROP COLUMN b\n ALTER TABLE sqlmesh__test_schema.test_schema__test_model__1 ADD COLUMN a INT\n\nTo allow the destructive change, set the model's `on_destructive_change` setting to `warn`, `allow`, or `ignore` or include the model in the plan's `--allow-destructive-model` option.\n" + ) + + # WARN model = SqlModel( name="test_schema.test_model", - kind=IncrementalByTimeRangeKind(time_column="a"), + kind=IncrementalByTimeRangeKind( + time_column="a", on_destructive_change=OnDestructiveChange.WARN + ), query=parse_one("SELECT c, a FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"), ) snapshot = make_snapshot(model, version="1") - snapshot.change_category = SnapshotChangeCategory.FORWARD_ONLY + snapshot.change_category = SnapshotChangeCategory.BREAKING + snapshot.forward_only = True + snapshot.previous_versions = snapshot.all_versions - with pytest.raises( - NodeExecutionFailedError, - match="""Execution failed for node SnapshotId<"test_schema"."test_model""", - ): - with pytest.raises( - RuntimeError, - match="""Plan results in a destructive change to forward-only table '"test_schema"."test_model"'s schema.""", - ): - evaluator.migrate([snapshot], {}) + logger = 
logging.getLogger("sqlmesh.core.snapshot.evaluator") + with patch.object(logger, "warning") as mock_logger: + evaluator.migrate([snapshot], {}) + assert ( + mock_logger.call_args[0][0] + == "\nPlan requires destructive change to forward-only model '\"test_schema\".\"test_model\"'s schema that drops column 'b'.\n\nSchema changes:\n ALTER TABLE sqlmesh__test_schema.test_schema__test_model__1 DROP COLUMN b\n ALTER TABLE sqlmesh__test_schema.test_schema__test_model__1 ADD COLUMN a INT" + ) + + # allow destructive + with patch.object(logger, "warning") as mock_logger: + evaluator.migrate( + [snapshot], + {}, + {'"test_schema"."test_model"'}, + ) + assert mock_logger.call_count == 0 + + +def test_on_additive_change_runtime_check( + mocker: MockerFixture, + make_snapshot, + make_mocked_engine_adapter, +): + # SQLMesh default: ERROR + model = SqlModel( + name="test_schema.test_model", + kind=IncrementalByTimeRangeKind(time_column="a", on_additive_change=OnAdditiveChange.ERROR), + query=parse_one("SELECT c, a, b FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"), + ) + snapshot = make_snapshot(model, version="1") + snapshot.change_category = SnapshotChangeCategory.BREAKING + snapshot.forward_only = True + snapshot.previous_versions = snapshot.all_versions + + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.with_settings = lambda **kwargs: adapter # type: ignore + current_table = snapshot.table_name() + + def columns(table_name): + if table_name == current_table: + return { + "c": exp.DataType.build("int"), + "a": exp.DataType.build("int"), + } + return { + "c": exp.DataType.build("int"), + "a": exp.DataType.build("int"), + "b": exp.DataType.build("int"), + } + + adapter.columns = columns # type: ignore + mocker.patch( + "sqlmesh.core.engine_adapter.base.EngineAdapter.get_data_objects", + return_value=[ + DataObject( + schema="sqlmesh__test_schema", + name=f"test_schema__test_model__{snapshot.version}", + type=DataObjectType.TABLE, + ) + ], + ) + + evaluator = 
SnapshotEvaluator(adapter) + + with pytest.raises(NodeExecutionFailedError) as ex: + evaluator.migrate([snapshot], {}) + + additive_change_error = ex.value.__cause__ + assert isinstance(additive_change_error, AdditiveChangeError) + assert ( + str(additive_change_error) + == "\nPlan requires additive change to forward-only model '\"test_schema\".\"test_model\"'s schema that adds column 'b'.\n\nSchema changes:\n ALTER TABLE sqlmesh__test_schema.test_schema__test_model__1 ADD COLUMN b INT\n\nTo allow the additive change, set the model's `on_additive_change` setting to `warn`, `allow`, or `ignore` or include the model in the plan's `--allow-additive-model` option.\n" + ) # WARN model = SqlModel( name="test_schema.test_model", kind=IncrementalByTimeRangeKind( - time_column="a", on_destructive_change=OnDestructiveChange.WARN + time_column="a", on_additive_change=OnDestructiveChange.WARN ), query=parse_one("SELECT c, a FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"), ) snapshot = make_snapshot(model, version="1") - snapshot.change_category = SnapshotChangeCategory.FORWARD_ONLY + snapshot.change_category = SnapshotChangeCategory.BREAKING + snapshot.forward_only = True + snapshot.previous_versions = snapshot.all_versions logger = logging.getLogger("sqlmesh.core.snapshot.evaluator") with patch.object(logger, "warning") as mock_logger: evaluator.migrate([snapshot], {}) assert ( mock_logger.call_args[0][0] - == """Plan results in a destructive change to forward-only table '"test_schema"."test_model"'s schema.""" + == "\nPlan requires additive change to forward-only model '\"test_schema\".\"test_model\"'s schema that adds column 'b'.\n\nSchema changes:\n ALTER TABLE sqlmesh__test_schema.test_schema__test_model__1 ADD COLUMN b INT" ) - # allow destructive - with patch.object(logger, "warning") as mock_logger: - evaluator.migrate([snapshot], {}, {'"test_schema"."test_model"'}) - assert mock_logger.call_count == 0 + +def test_temp_table_includes_schema_for_ignore_changes( + 
mocker: MockerFixture, + make_snapshot, + make_mocked_engine_adapter, +): + """Test that temp table creation includes the physical schema when on_destructive_change or on_additive_change is IGNORE.""" + # Create a model with on_destructive_change=IGNORE to trigger temp table creation + model = SqlModel( + name="test_schema.test_model", + kind=IncrementalByTimeRangeKind( + time_column="ds", on_destructive_change=OnDestructiveChange.IGNORE + ), + query=parse_one("SELECT c, a FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"), + ) + snapshot = make_snapshot(model, version="1") + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + # Set up the mocked adapter + adapter = make_mocked_engine_adapter(EngineAdapter) + adapter.with_settings = lambda **kwargs: adapter # type: ignore + adapter.table_exists = lambda _: True # type: ignore + + # Mock columns method to return existing columns + def columns(table_name): + return { + "c": exp.DataType.build("int"), + "a": exp.DataType.build("int"), + "ds": exp.DataType.build("timestamp"), + } + + adapter.columns = columns # type: ignore + + # Create a mock for the temp_table context manager + temp_table_name_captured = None + + @contextlib.contextmanager + def mock_temp_table(query_or_df, name="diff", **kwargs): + nonlocal temp_table_name_captured + temp_table_name_captured = exp.to_table(name) + # Return a table that temp_table would normally return + yield exp.table_("__temp_diff_12345", db=temp_table_name_captured.db) + + adapter.temp_table = mock_temp_table # type: ignore + adapter.insert_append = lambda *args, **kwargs: None # type: ignore + + evaluator = SnapshotEvaluator(adapter) + + # Call the append method which will trigger _get_target_and_source_columns + evaluator.evaluate( + snapshot, + start="2020-01-01", + end="2020-01-02", + execution_time="2020-01-02", + snapshots={}, + ) + + # Verify that temp_table was called with a name that includes the schema + assert temp_table_name_captured is not None + assert 
temp_table_name_captured.name == "diff" + assert temp_table_name_captured.db == model.physical_schema + assert str(temp_table_name_captured.db) == "sqlmesh__test_schema" def test_forward_only_snapshot_for_added_model(mocker: MockerFixture, adapter_mock, make_snapshot): @@ -1351,12 +2204,13 @@ def test_forward_only_snapshot_for_added_model(mocker: MockerFixture, adapter_mo ) snapshot = make_snapshot(model) - snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) - evaluator.create([snapshot], {}) + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) common_create_args = dict( - columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, + target_columns_to_types={"a": exp.DataType.build("int"), "ds": exp.DataType.build("date")}, + table_format=None, storage_format=None, partitioned_by=[exp.to_column("ds", quoted=True)], partition_interval_unit=IntervalUnit.DAY, @@ -1367,7 +2221,7 @@ def test_forward_only_snapshot_for_added_model(mocker: MockerFixture, adapter_mo adapter_mock.create_table.assert_has_calls( [ call( - f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__temp", + f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}__dev", column_descriptions=None, **common_create_args, ), @@ -1396,10 +2250,10 @@ def test_create_scd_type_2_by_time(adapter_mock, make_snapshot): snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - evaluator.create([snapshot], {}) + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) common_kwargs = dict( - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("STRING"), "updated_at": exp.DataType.build("TIMESTAMPTZ"), @@ -1407,9 +2261,10 @@ def test_create_scd_type_2_by_time(adapter_mock, make_snapshot): "valid_from": 
exp.DataType.build("TIMESTAMPTZ"), "valid_to": exp.DataType.build("TIMESTAMPTZ"), }, + table_format=None, storage_format=None, partitioned_by=[], - partition_interval_unit=IntervalUnit.DAY, + partition_interval_unit=None, clustered_by=[], table_properties={}, table_description=None, @@ -1418,14 +2273,15 @@ def test_create_scd_type_2_by_time(adapter_mock, make_snapshot): adapter_mock.create_table.assert_has_calls( [ call( - snapshot.table_name(is_deployable=False), column_descriptions=None, **common_kwargs + snapshot.table_name(is_deployable=False), + column_descriptions=None, + **common_kwargs, ), - call(snapshot.table_name(), column_descriptions={}, **common_kwargs), ] ) -def test_create_ctas_scd_type_2_by_time(adapter_mock, make_snapshot): +def test_create_ctas_scd_type_2_by_time(adapter_mock, make_snapshot, mocker): evaluator = SnapshotEvaluator(adapter_mock) model = load_sql_based_model( parse( # type: ignore @@ -1436,7 +2292,8 @@ def test_create_ctas_scd_type_2_by_time(adapter_mock, make_snapshot): unique_key id, time_data_type TIMESTAMPTZ, invalidate_hard_deletes false - ) + ), + partitioned_by cola ); SELECT * FROM tbl; @@ -1447,17 +2304,21 @@ def test_create_ctas_scd_type_2_by_time(adapter_mock, make_snapshot): snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - evaluator.create([snapshot], {}) + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) + source_query = parse_one('SELECT * FROM "tbl" AS "tbl"') query = parse_one( """SELECT *, CAST(NULL AS TIMESTAMPTZ) AS valid_from, CAST(NULL AS TIMESTAMPTZ) AS valid_to FROM "tbl" AS "tbl" WHERE FALSE LIMIT 0""" ) # Verify that managed columns are included in CTAS with types common_kwargs = dict( + table_format=None, storage_format=None, - partitioned_by=[], - partition_interval_unit=IntervalUnit.DAY, + partitioned_by=[ + exp.to_column("cola", quoted=True), + ], + partition_interval_unit=None, clustered_by=[], table_properties={}, 
table_description=None, @@ -1472,7 +2333,38 @@ def test_create_ctas_scd_type_2_by_time(adapter_mock, make_snapshot): column_descriptions=None, **common_kwargs, ), - call(snapshot.table_name(), query, None, column_descriptions={}, **common_kwargs), + ] + ) + + adapter_mock.reset_mock() + + evaluator.evaluate( + snapshot, + start="2020-01-01", + end="2020-01-02", + execution_time="2020-01-02", + snapshots={}, + deployability_index=DeployabilityIndex.none_deployable(), + ) + + adapter_mock.scd_type_2_by_time.assert_has_calls( + [ + call( + column_descriptions={}, + execution_time="2020-01-02", + invalidate_hard_deletes=False, + source_columns=None, + source_table=source_query, + target_columns_to_types=mocker.ANY, + target_table=snapshot.table_name(is_deployable=False), + truncate=True, + unique_key=[exp.to_column("id", quoted=True)], + updated_at_as_valid_from=False, + updated_at_col=exp.column("updated_at", quoted=True), + valid_from_col=exp.column("valid_from", quoted=True), + valid_to_col=exp.column("valid_to", quoted=True), + **common_kwargs, + ), ] ) @@ -1507,6 +2399,16 @@ def test_insert_into_scd_type_2_by_time( snapshot.categorize_as(SnapshotChangeCategory.BREAKING) snapshot.intervals = intervals + table_columns = { + "id": exp.DataType.build("INT"), + "name": exp.DataType.build("STRING"), + "updated_at": exp.DataType.build("TIMESTAMP"), + # Make sure that the call includes these extra columns + "valid_from": exp.DataType.build("TIMESTAMP"), + "valid_to": exp.DataType.build("TIMESTAMP"), + } + adapter_mock.columns.return_value = table_columns + evaluator.evaluate( snapshot, start="2020-01-01", @@ -1518,14 +2420,8 @@ def test_insert_into_scd_type_2_by_time( adapter_mock.scd_type_2_by_time.assert_called_once_with( target_table=snapshot.table_name(), source_table=model.render_query(), - columns_to_types={ - "id": exp.DataType.build("INT"), - "name": exp.DataType.build("STRING"), - "updated_at": exp.DataType.build("TIMESTAMP"), - # Make sure that the call includes 
these extra columns - "valid_from": exp.DataType.build("TIMESTAMP"), - "valid_to": exp.DataType.build("TIMESTAMP"), - }, + target_columns_to_types=table_columns, + table_format=None, unique_key=[exp.to_column("id", quoted=True)], valid_from_col=exp.column("valid_from", quoted=True), valid_to_col=exp.column("valid_to", quoted=True), @@ -1536,7 +2432,14 @@ def test_insert_into_scd_type_2_by_time( column_descriptions={}, updated_at_as_valid_from=False, truncate=truncate, + source_columns=None, + clustered_by=[], + partition_interval_unit=None, + partitioned_by=[], + storage_format=None, + table_properties={}, ) + adapter_mock.columns.assert_called_once_with(snapshot.table_name()) def test_create_scd_type_2_by_column(adapter_mock, make_snapshot): @@ -1561,19 +2464,20 @@ def test_create_scd_type_2_by_column(adapter_mock, make_snapshot): snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - evaluator.create([snapshot], {}) + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) common_kwargs = dict( - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("STRING"), # Make sure that the call includes these extra columns "valid_from": exp.DataType.build("TIMESTAMP"), "valid_to": exp.DataType.build("TIMESTAMP"), }, + table_format=None, storage_format=None, partitioned_by=[], - partition_interval_unit=IntervalUnit.DAY, + partition_interval_unit=None, clustered_by=[], table_properties={}, table_description=None, @@ -1585,7 +2489,6 @@ def test_create_scd_type_2_by_column(adapter_mock, make_snapshot): snapshot.table_name(is_deployable=False), **{**common_kwargs, "column_descriptions": None}, ), - call(snapshot.table_name(), **{**common_kwargs, "column_descriptions": {}}), ] ) @@ -1611,7 +2514,7 @@ def test_create_ctas_scd_type_2_by_column(adapter_mock, make_snapshot): snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) 
- evaluator.create([snapshot], {}) + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) query = parse_one( """SELECT *, CAST(NULL AS TIMESTAMP) AS valid_from, CAST(NULL AS TIMESTAMP) AS valid_to FROM "tbl" AS "tbl" WHERE FALSE LIMIT 0""" @@ -1619,9 +2522,10 @@ def test_create_ctas_scd_type_2_by_column(adapter_mock, make_snapshot): # Verify that managed columns are included in CTAS with types common_kwargs = dict( + table_format=None, storage_format=None, partitioned_by=[], - partition_interval_unit=IntervalUnit.DAY, + partition_interval_unit=None, clustered_by=[], table_properties={}, table_description=None, @@ -1635,9 +2539,6 @@ def test_create_ctas_scd_type_2_by_column(adapter_mock, make_snapshot): None, **{**common_kwargs, "column_descriptions": None}, ), - call( - snapshot.table_name(), query, None, **{**common_kwargs, "column_descriptions": {}} - ), ] ) @@ -1673,6 +2574,15 @@ def test_insert_into_scd_type_2_by_column( snapshot.categorize_as(SnapshotChangeCategory.BREAKING) snapshot.intervals = intervals + table_columns = { + "id": exp.DataType.build("INT"), + "name": exp.DataType.build("STRING"), + # Make sure that the call includes these extra columns + "valid_from": exp.DataType.build("TIMESTAMP"), + "valid_to": exp.DataType.build("TIMESTAMP"), + } + adapter_mock.columns.return_value = table_columns + evaluator.evaluate( snapshot, start="2020-01-01", @@ -1684,15 +2594,10 @@ def test_insert_into_scd_type_2_by_column( adapter_mock.scd_type_2_by_column.assert_called_once_with( target_table=snapshot.table_name(), source_table=model.render_query(), - columns_to_types={ - "id": exp.DataType.build("INT"), - "name": exp.DataType.build("STRING"), - # Make sure that the call includes these extra columns - "valid_from": exp.DataType.build("TIMESTAMP"), - "valid_to": exp.DataType.build("TIMESTAMP"), - }, + target_columns_to_types=table_columns, + table_format=None, unique_key=[exp.to_column("id", quoted=True)], - 
check_columns=exp.Star(), + check_columns=[exp.Star()], valid_from_col=exp.column("valid_from", quoted=True), valid_to_col=exp.column("valid_to", quoted=True), execution_time="2020-01-02", @@ -1701,7 +2606,14 @@ def test_insert_into_scd_type_2_by_column( table_description=None, column_descriptions={}, truncate=truncate, + source_columns=None, + clustered_by=[], + partition_interval_unit=None, + partitioned_by=[], + storage_format=None, + table_properties={}, ) + adapter_mock.columns.assert_called_once_with(snapshot.table_name()) def test_create_incremental_by_unique_key_updated_at_exp(adapter_mock, make_snapshot): @@ -1737,33 +2649,42 @@ def test_create_incremental_by_unique_key_updated_at_exp(adapter_mock, make_snap adapter_mock.merge.assert_called_once_with( snapshot.table_name(), model.render_query(), - columns_to_types={ + target_columns_to_types={ "id": exp.DataType.build("INT"), "name": exp.DataType.build("STRING"), "updated_at": exp.DataType.build("TIMESTAMP"), }, unique_key=[exp.to_column("id", quoted=True)], - when_matched=exp.When( - matched=True, - source=False, - then=exp.Update( - expressions=[ - exp.column("name", MERGE_TARGET_ALIAS).eq( - exp.column("name", MERGE_SOURCE_ALIAS) + merge_filter=None, + when_matched=exp.Whens( + expressions=[ + exp.When( + matched=True, + source=False, + then=exp.Update( + expressions=[ + exp.column("name", MERGE_TARGET_ALIAS, quoted=True).eq( + exp.column("name", MERGE_SOURCE_ALIAS, quoted=True) + ), + exp.column("updated_at", MERGE_TARGET_ALIAS, quoted=True).eq( + exp.Coalesce( + this=exp.column("updated_at", MERGE_SOURCE_ALIAS, quoted=True), + expressions=[ + exp.column("updated_at", MERGE_TARGET_ALIAS, quoted=True) + ], + ) + ), + ], ), - exp.column("updated_at", MERGE_TARGET_ALIAS).eq( - exp.Coalesce( - this=exp.column("updated_at", MERGE_SOURCE_ALIAS), - expressions=[exp.column("updated_at", MERGE_TARGET_ALIAS)], - ) - ), - ], - ), + ) + ] ), + physical_properties={}, + source_columns=None, ) -def 
test_create_incremental_by_unique_no_intervals(adapter_mock, make_snapshot): +def test_create_incremental_by_unique_key_multiple_updated_at_exp(adapter_mock, make_snapshot): evaluator = SnapshotEvaluator(adapter_mock) model = load_sql_based_model( parse( # type: ignore @@ -1772,6 +2693,8 @@ def test_create_incremental_by_unique_no_intervals(adapter_mock, make_snapshot): name test_schema.test_model, kind INCREMENTAL_BY_UNIQUE_KEY ( unique_key [id], + when_matched WHEN MATCHED AND source.id = 1 THEN UPDATE SET target.name = source.name, target.updated_at = COALESCE(source.updated_at, target.updated_at), + WHEN MATCHED THEN UPDATE SET target.name = source.name, target.updated_at = COALESCE(source.updated_at, target.updated_at) ) ); @@ -1782,6 +2705,7 @@ def test_create_incremental_by_unique_no_intervals(adapter_mock, make_snapshot): snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.intervals = [(to_timestamp("2020-01-01"), to_timestamp("2020-01-02"))] evaluator.evaluate( snapshot, @@ -1791,28 +2715,223 @@ def test_create_incremental_by_unique_no_intervals(adapter_mock, make_snapshot): snapshots={}, ) - adapter_mock.replace_query.assert_called_once_with( + adapter_mock.merge.assert_called_once_with( snapshot.table_name(), model.render_query(), - clustered_by=[], - column_descriptions={}, - columns_to_types=model.columns_to_types, - partition_interval_unit=model.interval_unit, - partitioned_by=model.partitioned_by, - storage_format=None, - table_description=None, - table_properties={}, + target_columns_to_types={ + "id": exp.DataType.build("INT"), + "name": exp.DataType.build("STRING"), + "updated_at": exp.DataType.build("TIMESTAMP"), + }, + unique_key=[exp.to_column("id", quoted=True)], + merge_filter=None, + when_matched=exp.Whens( + expressions=[ + exp.When( + matched=True, + condition=exp.column("id", MERGE_SOURCE_ALIAS, quoted=True).eq( + exp.Literal.number(1) + ), + then=exp.Update( + expressions=[ + 
exp.column("name", MERGE_TARGET_ALIAS, quoted=True).eq( + exp.column("name", MERGE_SOURCE_ALIAS, quoted=True) + ), + exp.column("updated_at", MERGE_TARGET_ALIAS, quoted=True).eq( + exp.Coalesce( + this=exp.column("updated_at", MERGE_SOURCE_ALIAS, quoted=True), + expressions=[ + exp.column("updated_at", MERGE_TARGET_ALIAS, quoted=True) + ], + ) + ), + ], + ), + ), + exp.When( + matched=True, + source=False, + then=exp.Update( + expressions=[ + exp.column("name", MERGE_TARGET_ALIAS, quoted=True).eq( + exp.column("name", MERGE_SOURCE_ALIAS, quoted=True) + ), + exp.column("updated_at", MERGE_TARGET_ALIAS, quoted=True).eq( + exp.Coalesce( + this=exp.column("updated_at", MERGE_SOURCE_ALIAS, quoted=True), + expressions=[ + exp.column("updated_at", MERGE_TARGET_ALIAS, quoted=True) + ], + ) + ), + ], + ), + ), + ], + ), + physical_properties={}, + source_columns=None, ) -def test_create_seed(mocker: MockerFixture, adapter_mock, make_snapshot): - expressions = d.parse( - """ - MODEL ( - name db.seed, - kind SEED ( - path '../seeds/waiter_names.csv', - batch_size 5, +def test_create_incremental_by_unique_no_intervals(adapter_mock, make_snapshot): + evaluator = SnapshotEvaluator(adapter_mock) + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key [id], + ) + ); + + SELECT id::int, name::string, updated_at::timestamp FROM tbl; + """ + ) + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + table_columns = { + "id": exp.DataType.build("int"), + "name": exp.DataType.build("string"), + "updated_at": exp.DataType.build("timestamp"), + } + adapter_mock.columns.return_value = table_columns + + evaluator.evaluate( + snapshot, + start="2020-01-01", + end="2020-01-02", + execution_time="2020-01-02", + snapshots={}, + ) + + adapter_mock.replace_query.assert_called_once_with( + snapshot.table_name(), + model.render_query(), + clustered_by=[], + 
column_descriptions={}, + target_columns_to_types=table_columns, + partition_interval_unit=model.partition_interval_unit, + partitioned_by=model.partitioned_by, + table_format=None, + storage_format=None, + table_description=None, + table_properties={}, + source_columns=None, + ) + adapter_mock.columns.assert_called_once_with(snapshot.table_name()) + + +def test_create_incremental_by_unique_key_merge_filter(adapter_mock, make_snapshot): + evaluator = SnapshotEvaluator(adapter_mock) + model = load_sql_based_model( + d.parse( + """ + MODEL ( + name test_schema.test_model, + kind INCREMENTAL_BY_UNIQUE_KEY ( + unique_key [id], + merge_filter source.id > 0 and target.updated_at < @end_ds and source.updated_at > @start_ds, + when_matched WHEN MATCHED THEN UPDATE SET target.updated_at = COALESCE(source.updated_at, target.updated_at), + ) + ); + + SELECT id::int, updated_at::timestamp FROM tbl; + """ + ) + ) + + # At load time macros should remain unresolved + assert model.merge_filter == exp.And( + this=exp.And( + this=exp.GT( + this=exp.column("id", MERGE_SOURCE_ALIAS, quoted=True), + expression=exp.Literal(this="0", is_string=False), + ), + expression=exp.LT( + this=exp.column("updated_at", MERGE_TARGET_ALIAS, quoted=True), + expression=d.MacroVar(this="end_ds"), + ), + ), + expression=exp.GT( + this=exp.column("updated_at", MERGE_SOURCE_ALIAS, quoted=True), + expression=d.MacroVar(this="start_ds"), + ), + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot.intervals = [(to_timestamp("2020-01-01"), to_timestamp("2020-01-02"))] + + evaluator.evaluate( + snapshot, + start="2020-01-01", + end="2020-01-02", + execution_time="2020-01-02", + snapshots={}, + ) + + # The macros for merge_filter must now be rendered at evaluation time. 
+ adapter_mock.merge.assert_called_once_with( + snapshot.table_name(), + model.render_query(), + target_columns_to_types={ + "id": exp.DataType.build("INT"), + "updated_at": exp.DataType.build("TIMESTAMP"), + }, + unique_key=[exp.to_column("id", quoted=True)], + when_matched=exp.Whens( + expressions=[ + exp.When( + matched=True, + then=exp.Update( + expressions=[ + exp.column("updated_at", MERGE_TARGET_ALIAS, quoted=True).eq( + exp.Coalesce( + this=exp.column("updated_at", MERGE_SOURCE_ALIAS, quoted=True), + expressions=[ + exp.column("updated_at", MERGE_TARGET_ALIAS, quoted=True) + ], + ) + ), + ], + ), + ) + ] + ), + merge_filter=exp.And( + this=exp.And( + this=exp.GT( + this=exp.column("id", MERGE_SOURCE_ALIAS, quoted=True), + expression=exp.Literal(this="0", is_string=False), + ), + expression=exp.LT( + this=exp.column("updated_at", MERGE_TARGET_ALIAS, quoted=True), + expression=exp.Literal(this="2020-01-02", is_string=True), + ), + ), + expression=exp.GT( + this=exp.column("updated_at", MERGE_SOURCE_ALIAS, quoted=True), + expression=exp.Literal(this="2020-01-01", is_string=True), + ), + ), + physical_properties={}, + source_columns=None, + ) + + +def test_create_seed(mocker: MockerFixture, adapter_mock, make_snapshot): + expressions = d.parse( + """ + MODEL ( + name db.seed, + kind SEED ( + path '../seeds/waiter_names.csv', + batch_size 5, ) ); """ @@ -1827,10 +2946,14 @@ def test_create_seed(mocker: MockerFixture, adapter_mock, make_snapshot): evaluator.create([snapshot], {}) common_create_kwargs: t.Dict[str, t.Any] = dict( - columns_to_types={"id": exp.DataType.build("bigint"), "name": exp.DataType.build("text")}, + target_columns_to_types={ + "id": exp.DataType.build("bigint"), + "name": exp.DataType.build("text"), + }, + table_format=None, storage_format=None, partitioned_by=[], - partition_interval_unit=IntervalUnit.DAY, + partition_interval_unit=None, clustered_by=[], table_properties={}, table_description=None, @@ -1840,22 +2963,14 @@ def 
test_create_seed(mocker: MockerFixture, adapter_mock, make_snapshot): f"sqlmesh__db.db__seed__{snapshot.version}", mocker.ANY, column_descriptions={}, + source_columns=["id", "name"], **common_create_kwargs, ) - adapter_mock.create_table.assert_has_calls( - [ - call( - f"sqlmesh__db.db__seed__{snapshot.version}__temp", - column_descriptions=None, - **common_create_kwargs, - ), - call( - f"sqlmesh__db.db__seed__{snapshot.version}", - column_descriptions={}, - **common_create_kwargs, - ), - ] + adapter_mock.create_table.assert_called_once_with( + f"sqlmesh__db.db__seed__{snapshot.version}", + column_descriptions={}, + **common_create_kwargs, ) replace_query_calls = adapter_mock.replace_query.call_args_list @@ -1904,19 +3019,26 @@ def test_create_seed_on_error(mocker: MockerFixture, adapter_mock, make_snapshot snapshot.categorize_as(SnapshotChangeCategory.BREAKING) evaluator = SnapshotEvaluator(adapter_mock) - evaluator.create([snapshot], {}) + + with pytest.raises(SnapshotCreationFailedError): + evaluator.create([snapshot], {}) adapter_mock.replace_query.assert_called_once_with( f"sqlmesh__db.db__seed__{snapshot.version}", mocker.ANY, column_descriptions={}, - columns_to_types={"id": exp.DataType.build("bigint"), "name": exp.DataType.build("text")}, + target_columns_to_types={ + "id": exp.DataType.build("bigint"), + "name": exp.DataType.build("text"), + }, + table_format=None, storage_format=None, partitioned_by=[], - partition_interval_unit=IntervalUnit.DAY, + partition_interval_unit=None, clustered_by=[], table_properties={}, table_description=None, + source_columns=["id", "name"], ) adapter_mock.drop_table.assert_called_once_with(f"sqlmesh__db.db__seed__{snapshot.version}") @@ -1947,11 +3069,6 @@ def test_create_seed_no_intervals(mocker: MockerFixture, adapter_mock, make_snap schema="sqlmesh__db", type=DataObjectType.TABLE, ), - DataObject( - name=f"db__seed__{snapshot.version}__temp", - schema="sqlmesh__db", - type=DataObjectType.TABLE, - ), ] evaluator = 
SnapshotEvaluator(adapter_mock) @@ -1965,13 +3082,18 @@ def test_create_seed_no_intervals(mocker: MockerFixture, adapter_mock, make_snap f"sqlmesh__db.db__seed__{snapshot.version}", mocker.ANY, column_descriptions={}, - columns_to_types={"id": exp.DataType.build("bigint"), "name": exp.DataType.build("text")}, + target_columns_to_types={ + "id": exp.DataType.build("bigint"), + "name": exp.DataType.build("text"), + }, + table_format=None, storage_format=None, partitioned_by=[], - partition_interval_unit=IntervalUnit.DAY, + partition_interval_unit=None, clustered_by=[], table_properties={}, table_description=None, + source_columns=["id", "name"], ) @@ -2010,7 +3132,7 @@ def test_standalone_audit(mocker: MockerFixture, adapter_mock, make_snapshot): adapter_mock.fetchone.return_value = (0,) evaluator.audit(snapshot=snapshot, snapshots={}) - query = audit.render_query(snapshot) + query = audit.render_audit_query() adapter_mock.fetchone.assert_called_once_with( select("COUNT(*)").from_(query.subquery("audit")), quote_identifiers=True ) @@ -2048,7 +3170,7 @@ def test_standalone_audit(mocker: MockerFixture, adapter_mock, make_snapshot): adapter_mock.session.assert_not_called() -def test_audit_wap(adapter_mock, make_snapshot): +def test_audit_wap(adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot]) -> None: evaluator = SnapshotEvaluator(adapter_mock) custom_audit = ModelAudit( @@ -2064,8 +3186,9 @@ def test_audit_wap(adapter_mock, make_snapshot): ("not_null", {"columns": exp.to_column("a")}), ("test_audit", {}), ], + audit_definitions={custom_audit.name: custom_audit}, ) - snapshot = make_snapshot(model, audits={custom_audit.name: custom_audit}) + snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) wap_id = "test_wap_id" @@ -2081,7 +3204,7 @@ def test_audit_wap(adapter_mock, make_snapshot): not_null_query = call_args[0][0][0] assert ( not_null_query.sql(dialect="spark") - == "SELECT COUNT(*) FROM (SELECT * FROM (SELECT * FROM 
`spark_catalog`.`test_schema`.`test_table`.`branch_wap_test_wap_id` AS `branch_wap_test_wap_id`) AS `_q_0` WHERE `a` IS NULL AND TRUE) AS audit" + == "SELECT COUNT(*) FROM (SELECT * FROM `spark_catalog`.`test_schema`.`test_table`.`branch_wap_test_wap_id` AS `branch_wap_test_wap_id` WHERE `a` IS NULL AND TRUE) AS audit" ) custom_audit_query = call_args[1][0][0] @@ -2094,6 +3217,39 @@ def test_audit_wap(adapter_mock, make_snapshot): adapter_mock.wap_publish.assert_called_once_with(snapshot.table_name(), wap_id) +def test_audit_with_datetime_macros(adapter_mock, make_snapshot): + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_table", + kind=IncrementalByUniqueKeyKind(unique_key="a"), + query=parse_one("SELECT a, start_ds FROM tbl"), + audits=[ + ( + "unique_combination_of_columns", + { + "columns": exp.Array(expressions=[exp.to_column("a")]), + "condition": d.MacroVar(this="start_ds").neq("2020-01-01"), + }, + ), + ], + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + adapter_mock.fetchone.return_value = (0,) + evaluator.audit(snapshot, snapshots={}, start="2020-01-01") + + call_args = adapter_mock.fetchone.call_args_list + assert len(call_args) == 1 + + unique_combination_of_columns_query = call_args[0][0][0] + assert ( + unique_combination_of_columns_query.sql(dialect="duckdb") + == """SELECT COUNT(*) FROM (SELECT "a" AS "a" FROM "test_schema"."test_table" AS "test_table" WHERE '2020-01-01' <> '2020-01-01' GROUP BY "a" HAVING COUNT(*) > 1) AS audit""" + ) + + def test_audit_set_blocking_at_use_site(adapter_mock, make_snapshot): evaluator = SnapshotEvaluator(adapter_mock) @@ -2120,17 +3276,18 @@ def blocking_value(evaluator): SELECT a::int FROM tbl """ ), + audit_definitions={always_failing_audit.name: always_failing_audit}, ) - snapshot = make_snapshot(model, audits={always_failing_audit.name: always_failing_audit}) + snapshot = make_snapshot(model) 
snapshot.categorize_as(SnapshotChangeCategory.BREAKING) # Return a non-zero count to indicate audit failure adapter_mock.fetchone.return_value = (1,) - logger = logging.getLogger("sqlmesh.core.snapshot.evaluator") - with patch.object(logger, "warning") as mock_logger: - evaluator.audit(snapshot, snapshots={}) - assert "Audit is warn only so proceeding with execution." in mock_logger.call_args[0][0] + results = evaluator.audit(snapshot, snapshots={}) + assert len(results) == 1 + assert results[0].count == 1 + assert not results[0].blocking model = SqlModel( name="test_schema.test_table", @@ -2139,19 +3296,19 @@ def blocking_value(evaluator): audits=[ ("always_fail", {"blocking": exp.true()}), ], + audit_definitions={always_failing_audit.name: always_failing_audit}, ) - snapshot = make_snapshot(model, audits={always_failing_audit.name: always_failing_audit}) + snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) adapter_mock.fetchone.return_value = (1,) - with pytest.raises( - AuditError, - match="Audit 'always_fail' for model 'test_schema.test_table' failed.", - ): - evaluator.audit(snapshot, snapshots={}) + results = evaluator.audit(snapshot, snapshots={}) + assert len(results) == 1 + assert results[0].count == 1 + assert results[0].blocking -def test_create_post_statements_use_deployable_table( +def test_create_post_statements_use_non_deployable_table( mocker: MockerFixture, adapter_mock, make_snapshot ): evaluator = SnapshotEvaluator(adapter_mock) @@ -2177,118 +3334,485 @@ def test_create_post_statements_use_deployable_table( snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - expected_call = f'CREATE INDEX IF NOT EXISTS "test_idx" ON "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}" /* test_schema.test_model */("a" NULLS FIRST)' + expected_call = f'CREATE INDEX IF NOT EXISTS "test_idx" ON "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}__dev" /* 
test_schema.test_model */("a" NULLS FIRST)' evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable()) call_args = adapter_mock.execute.call_args_list - pre_calls = call_args[0][0][0] + pre_calls = call_args[1][0][0] assert len(pre_calls) == 1 assert pre_calls[0].sql(dialect="postgres") == expected_call - post_calls = call_args[1][0][0] + post_calls = call_args[2][0][0] assert len(post_calls) == 1 assert post_calls[0].sql(dialect="postgres") == expected_call -def test_evaluate_incremental_by_partition(mocker: MockerFixture, make_snapshot, adapter_mock): - model = SqlModel( - name="test_schema.test_model", - query=parse_one("SELECT 1, ds, b FROM tbl_a"), - kind=IncrementalByPartitionKind(), - partitioned_by=["ds", "b"], - ) - snapshot = make_snapshot(model) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - +def test_create_pre_post_statements_python_model( + mocker: MockerFixture, adapter_mock, make_snapshot +): evaluator = SnapshotEvaluator(adapter_mock) - evaluator.evaluate( - snapshot, - start="2020-01-01", - end="2020-01-02", - execution_time="2020-01-02", - snapshots={}, - ) - adapter_mock.insert_overwrite_by_partition.assert_called_once_with( - snapshot.table_name(), - model.render_query(), - partitioned_by=[ - exp.to_column("ds", quoted=True), - exp.to_column("b", quoted=True), - ], - columns_to_types=model.columns_to_types, + @macro() + def create_index( + evaluator: MacroEvaluator, + index_name: str, + model_name: str, + column: str, + ): + if evaluator.runtime_stage == "creating": + return f"CREATE INDEX IF NOT EXISTS {index_name} ON {model_name}({column});" + + @model( + "db.test_model", + kind="full", + columns={"id": "string", "name": "string"}, + pre_statements=["CREATE INDEX IF NOT EXISTS idx ON db.test_model(id);"], + post_statements=["@CREATE_INDEX('idx', 'db.test_model', id)"], + ) + def model_with_statements(context, **kwargs): + return pd.DataFrame( + [ + { + "id": context.var("1"), + "name": context.var("var"), + } + ] + ) 
+ + python_model = model.get_registry()["db.test_model"].model( + module_path=Path("."), + path=Path("."), + macros=macro.get_registry(), + dialect="postgres", ) + assert len(python_model.python_env) == 3 + assert len(python_model.pre_statements) == 1 + assert len(python_model.post_statements) == 1 + assert isinstance(python_model.python_env["create_index"], Executable) + assert isinstance(python_model.pre_statements[0], exp.Create) + assert isinstance(python_model.post_statements[0], MacroFunc) -def test_custom_materialization_strategy(adapter_mock, make_snapshot): - custom_insert_called = False + snapshot = make_snapshot(python_model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - class TestCustomMaterializationStrategy(CustomMaterialization): - NAME = "custom_materialization_test" + evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable()) + expected_call = f'CREATE INDEX IF NOT EXISTS "idx" ON "sqlmesh__db"."db__test_model__{snapshot.version}__dev" /* db.test_model */("id")' - def insert( - self, - table_name: str, - query_or_df: QueryOrDF, - model: Model, - is_first_insert: bool, - **kwargs: t.Any, - ) -> None: - nonlocal custom_insert_called - custom_insert_called = True + call_args = adapter_mock.execute.call_args_list + pre_calls = call_args[1][0][0] + assert len(pre_calls) == 1 + assert pre_calls[0].sql(dialect="postgres") == expected_call - assert model.custom_materialization_properties == {"test_property": "test_value"} + post_calls = call_args[2][0][0] + assert len(post_calls) == 1 + assert post_calls[0].sql(dialect="postgres") == expected_call - assert isinstance(query_or_df, exp.Query) - assert query_or_df.sql() == 'SELECT * FROM "tbl" AS "tbl"' +def test_on_virtual_update_statements(mocker: MockerFixture, adapter_mock, make_snapshot): evaluator = SnapshotEvaluator(adapter_mock) + + @macro() + def create_log_table(evaluator, view_name): + return f"CREATE OR REPLACE TABLE log_table AS SELECT '{view_name}' as fqn_this_model, 
'{evaluator.this_model}' as eval_this_model" + model = load_sql_based_model( - parse( # type: ignore + d.parse( """ MODEL ( name test_schema.test_model, - kind CUSTOM ( - materialization 'custom_materialization_test', - materialization_properties ( - 'test_property' = 'test_value' - ) - ) + kind FULL, + dialect postgres, ); - SELECT * FROM tbl; + SELECT a::int FROM tbl; + + CREATE INDEX IF NOT EXISTS test_idx ON test_schema.test_model(a); + + ON_VIRTUAL_UPDATE_BEGIN; + JINJA_STATEMENT_BEGIN; + GRANT SELECT ON VIEW test_schema.test_model TO ROLE admin; + JINJA_END; + GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE demo_db TO ROLE owner_name; + @create_log_table(@this_model); + ON_VIRTUAL_UPDATE_END; + """ - ) + ), ) snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - evaluator.evaluate( - snapshot, + evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable()) + + snapshots = {snapshot.name: snapshot} + environment_naming_info = EnvironmentNamingInfo(name="test_env") + evaluator.promote( + [snapshot], start="2020-01-01", - end="2020-01-02", - execution_time="2020-01-02", - snapshots={}, + end="2020-01-01", + execution_time="2020-01-01", + snapshots=snapshots, + environment_naming_info=environment_naming_info, + table_mapping=to_view_mapping( + snapshots.values(), + environment_naming_info, + ), + ) + + call_args = adapter_mock.execute.call_args_list + post_calls = call_args[2][0][0] + assert len(post_calls) == 1 + assert ( + post_calls[0].sql(dialect="postgres") + == f'CREATE INDEX IF NOT EXISTS "test_idx" ON "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}__dev" /* test_schema.test_model */("a")' + ) + + on_virtual_update_calls = call_args[4][0][0] + assert ( + on_virtual_update_calls[0].sql(dialect="postgres") + == 'GRANT SELECT ON VIEW "test_schema__test_env"."test_model" /* test_schema.test_model */ TO ROLE "admin"' + ) + assert ( + on_virtual_update_calls[1].sql(dialect="postgres") + == 
"GRANT REFERENCES, SELECT ON FUTURE VIEWS IN DATABASE demo_db TO ROLE owner_name" ) - assert custom_insert_called + # Validation that within the macro the environment specific view is used + assert ( + on_virtual_update_calls[2].sql(dialect="postgres") + == 'CREATE OR REPLACE TABLE "log_table" AS SELECT \'"test_schema__test_env"."test_model" /* test_schema.test_model */\' AS "fqn_this_model", \'"test_schema__test_env"."test_model"\' AS "eval_this_model"' + ) -def test_create_managed(adapter_mock, make_snapshot, mocker: MockerFixture): +def test_on_virtual_update_python_model_macro(mocker: MockerFixture, adapter_mock, make_snapshot): evaluator = SnapshotEvaluator(adapter_mock) - model = load_sql_based_model( - parse( # type: ignore - """ - MODEL ( - name test_schema.test_model, - kind MANAGED, - physical_properties ( - warehouse = 'small', - target_lag = '10 minutes' - ), + @macro() + def create_index_2( + evaluator: MacroEvaluator, + index_name: str, + model_name: str, + column: str, + ): + return f"CREATE INDEX IF NOT EXISTS {index_name} ON {model_name}({column});" + + @model( + "db.test_model_3", + kind="full", + columns={"id": "string", "name": "string"}, + on_virtual_update=["@CREATE_INDEX_2('idx', 'db.test_model_3', id)"], + ) + def model_with_statements(context, **kwargs): + return pd.DataFrame( + [ + { + "id": context.var("1"), + "name": context.var("var"), + } + ] + ) + + python_model = model.get_registry()["db.test_model_3"].model( + module_path=Path("."), + path=Path("."), + macros=macro.get_registry(), + dialect="postgres", + ) + + assert len(python_model.python_env) == 3 + assert len(python_model.on_virtual_update) == 1 + assert isinstance(python_model.python_env["create_index_2"], Executable) + assert isinstance(python_model.on_virtual_update[0], MacroFunc) + + snapshot = make_snapshot(python_model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.create([snapshot], {}, DeployabilityIndex.none_deployable()) + + snapshots = 
{snapshot.name: snapshot} + environment_naming_info = EnvironmentNamingInfo(name="prod") + evaluator.promote( + [snapshot], + start="2020-01-01", + end="2020-01-01", + execution_time="2020-01-01", + snapshots=snapshots, + environment_naming_info=environment_naming_info, + table_mapping=to_view_mapping( + snapshots.values(), + environment_naming_info, + ), + ) + + call_args = adapter_mock.execute.call_args_list + on_virtual_update_call = call_args[4][0][0][0] + assert ( + on_virtual_update_call.sql(dialect="postgres") + == 'CREATE INDEX IF NOT EXISTS "idx" ON "db"."test_model_3" /* db.test_model_3 */("id")' + ) + + +def test_evaluate_incremental_by_partition(mocker: MockerFixture, make_snapshot, adapter_mock): + model = SqlModel( + name="test_schema.test_model", + query=parse_one("SELECT 1, ds, b FROM tbl_a"), + kind=IncrementalByPartitionKind(), + partitioned_by=["ds", "b"], + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + adapter_mock.columns.return_value = model.columns_to_types + + evaluator = SnapshotEvaluator(adapter_mock) + evaluator.evaluate( + snapshot, + start="2020-01-01", + end="2020-01-02", + execution_time="2020-01-02", + snapshots={}, + ) + + # uses `replace_query` on first model execution + adapter_mock.replace_query.assert_called_once_with( + snapshot.table_name(), + model.render_query(), + partitioned_by=[ + exp.to_column("ds", quoted=True), + exp.to_column("b", quoted=True), + ], + target_columns_to_types=model.columns_to_types, + clustered_by=[], + table_properties={}, + column_descriptions={}, + partition_interval_unit=None, + storage_format=None, + table_description=None, + table_format=None, + source_columns=None, + ) + + adapter_mock.reset_mock() + snapshot.intervals = [(to_timestamp("2020-01-01"), to_timestamp("2020-01-02"))] + + evaluator.evaluate( + snapshot, + start="2020-01-02", + end="2020-01-03", + execution_time="2020-01-03", + snapshots={}, + ) + + # uses 
`insert_overwrite_by_partition` on all subsequent model executions + adapter_mock.insert_overwrite_by_partition.assert_called_once_with( + snapshot.table_name(), + model.render_query(), + partitioned_by=[ + exp.to_column("ds", quoted=True), + exp.to_column("b", quoted=True), + ], + target_columns_to_types=model.columns_to_types, + source_columns=None, + ) + + +def test_custom_materialization_strategy(adapter_mock, make_snapshot): + custom_insert_kind = None + custom_insert_query_or_df = None + + class TestCustomMaterializationStrategy(CustomMaterialization): + NAME = "custom_materialization_test" + + def insert( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + nonlocal custom_insert_kind + nonlocal custom_insert_query_or_df + + custom_insert_kind = model.kind + custom_insert_query_or_df = query_or_df + + evaluator = SnapshotEvaluator(adapter_mock) + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind CUSTOM ( + materialization 'custom_materialization_test', + materialization_properties ( + 'test_property' = 'test_value' + ) + ) + ); + + SELECT * FROM tbl; + """ + ) + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.evaluate( + snapshot, + start="2020-01-01", + end="2020-01-02", + execution_time="2020-01-02", + snapshots={}, + ) + + assert custom_insert_kind + assert isinstance(custom_insert_kind, CustomKind) + assert model.custom_materialization_properties == {"test_property": "test_value"} + + assert isinstance(custom_insert_query_or_df, exp.Query) + assert custom_insert_query_or_df.sql() == 'SELECT * FROM "tbl" AS "tbl"' + + +def test_custom_materialization_strategy_with_custom_properties(adapter_mock, make_snapshot): + custom_insert_kind = None + + class TestCustomKind(CustomKind): + _primary_key: t.List[exp.Expression] # type: 
ignore[no-untyped-def] + + @model_validator(mode="after") + def _validate_model(self) -> Self: + self._primary_key = list_of_fields_validator( + self.materialization_properties.get("primary_key"), {} + ) + if not self.primary_key: + raise ConfigError("primary_key must be specified") + return self + + @property + def primary_key(self) -> t.List[exp.Expression]: + return self._primary_key + + class TestCustomMaterializationStrategy(CustomMaterialization[TestCustomKind]): + NAME = "custom_materialization_test_1" + + def insert( + self, + table_name: str, + query_or_df: QueryOrDF, + model: Model, + is_first_insert: bool, + render_kwargs: t.Dict[str, t.Any], + **kwargs: t.Any, + ) -> None: + nonlocal custom_insert_kind + custom_insert_kind = model.kind + + evaluator = SnapshotEvaluator(adapter_mock) + + with pytest.raises(ConfigError, match=r"primary_key must be specified"): + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind CUSTOM ( + materialization 'custom_materialization_test_1', + ) + ); + + SELECT * FROM tbl; + """ + ) + ) + + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind CUSTOM ( + materialization 'custom_materialization_test_1', + materialization_properties ( + primary_key = id + ) + ) + ); + + SELECT * FROM tbl; + """ + ) + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.evaluate( + snapshot, + start="2020-01-01", + end="2020-01-02", + execution_time="2020-01-02", + snapshots={}, + ) + + assert custom_insert_kind + assert isinstance(custom_insert_kind, TestCustomKind) + assert custom_insert_kind.primary_key == [exp.column("id", quoted=True)] + assert model.custom_materialization_properties["primary_key"] + + # show that the _primary_key property is transient + as_json = json.loads(model.json()) + assert "primary_key" not in as_json["kind"] + assert "_primary_key" not in 
as_json["kind"] + + +def test_custom_materialization_strategy_with_custom_kind_must_be_correct_type(): + # note: deliberately doesnt extend CustomKind + class TestCustomKind: + pass + + class TestCustomMaterializationStrategy(CustomMaterialization[TestCustomKind]): # type: ignore + NAME = "custom_materialization_test_2" + + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind CUSTOM ( + materialization 'custom_materialization_test_2', + ) + ); + + SELECT * FROM tbl; + """ + ) + ) + + with pytest.raises( + SQLMeshError, match=r"kind 'TestCustomKind' must be a subclass of CustomKind" + ): + model.validate_definition() + + +def test_create_managed(adapter_mock, make_snapshot, mocker: MockerFixture): + evaluator = SnapshotEvaluator(adapter_mock) + + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind MANAGED, + physical_properties ( + warehouse = 'small', + target_lag = '10 minutes' + ), clustered_by a ); @@ -2300,34 +3824,61 @@ def test_create_managed(adapter_mock, make_snapshot, mocker: MockerFixture): snapshot = make_snapshot(model) snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - evaluator.create([snapshot], {}) + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.all_deployable()) + + adapter_mock.create_managed_table.assert_called_with( + table_name=snapshot.table_name(), + query=mocker.ANY, + target_columns_to_types=model.columns_to_types, + partitioned_by=model.partitioned_by, + clustered_by=model.clustered_by, + table_properties=model.physical_properties, + table_description=model.description, + column_descriptions=model.column_descriptions, + table_format=None, + ) + + +def test_create_managed_dev(adapter_mock, make_snapshot, mocker: MockerFixture): + evaluator = SnapshotEvaluator(adapter_mock) + + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind MANAGED, + 
physical_properties ( + warehouse = 'small', + target_lag = '10 minutes' + ), + clustered_by a + ); + + select a, b from foo; + """ + ) + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.create([snapshot], {}, deployability_index=DeployabilityIndex.none_deployable()) - # first call to evaluation_strategy.create(), is_table_deployable=False triggers a normal table adapter_mock.ctas.assert_called_once_with( - f"{snapshot.table_name()}__temp", + f"{snapshot.table_name()}__dev", mocker.ANY, model.columns_to_types, + table_format=model.table_format, storage_format=model.storage_format, partitioned_by=model.partitioned_by, - partition_interval_unit=model.interval_unit, + partition_interval_unit=model.partition_interval_unit, clustered_by=model.clustered_by, table_properties=model.physical_properties, table_description=None, column_descriptions=None, ) - # second call to evaluation_strategy.create(), is_table_deployable=True and is_snapshot_deployable=True triggers a managed table - adapter_mock.create_managed_table.assert_called_with( - table_name=snapshot.table_name(), - query=mocker.ANY, - columns_to_types=model.columns_to_types, - partitioned_by=model.partitioned_by, - clustered_by=model.clustered_by, - table_properties=model.physical_properties, - table_description=model.description, - column_descriptions=model.column_descriptions, - ) - def test_evaluate_managed(adapter_mock, make_snapshot, mocker: MockerFixture): evaluator = SnapshotEvaluator(adapter_mock) @@ -2372,6 +3923,12 @@ def test_evaluate_managed(adapter_mock, make_snapshot, mocker: MockerFixture): adapter_mock.reset_mock() adapter_mock.assert_not_called() + table_colmns = { + "a": exp.DataType.build("int"), + "b": exp.DataType.build("string"), + } + adapter_mock.columns.return_value = table_colmns + evaluator.evaluate( snapshot, start="2020-01-01", @@ -2383,17 +3940,20 @@ def test_evaluate_managed(adapter_mock, make_snapshot, mocker: 
MockerFixture): adapter_mock.create_managed_table.assert_not_called() adapter_mock.replace_query.assert_called_with( - f"{snapshot.table_name()}__temp", + snapshot.table_name(is_deployable=False), mocker.ANY, - columns_to_types=None, + target_columns_to_types=table_colmns, + table_format=model.table_format, storage_format=model.storage_format, partitioned_by=model.partitioned_by, - partition_interval_unit=model.interval_unit, + partition_interval_unit=model.partition_interval_unit, clustered_by=model.clustered_by, table_properties=model.physical_properties, table_description=model.description, column_descriptions=model.column_descriptions, + source_columns=None, ) + adapter_mock.columns.assert_called_once_with(snapshot.table_name(is_deployable=False)) def test_cleanup_managed(adapter_mock, make_snapshot, mocker: MockerFixture): @@ -2426,9 +3986,1518 @@ def test_cleanup_managed(adapter_mock, make_snapshot, mocker: MockerFixture): evaluator.cleanup(target_snapshots=[cleanup_task]) - adapter_mock.drop_table.assert_called_once_with( - "sqlmesh__test_schema.test_schema__test_model__1556851963__temp" - ) - adapter_mock.drop_managed_table.assert_called_once_with( - "sqlmesh__test_schema.test_schema__test_model__1556851963" - ) + physical_name = f"sqlmesh__test_schema.test_schema__test_model__{snapshot.version}" + adapter_mock.drop_table.assert_called_once_with(f"{physical_name}__dev") + adapter_mock.drop_managed_table.assert_called_once_with(f"{physical_name}") + + +def test_create_managed_forward_only_with_previous_version_doesnt_clone_for_dev_preview( + adapter_mock, make_snapshot +): + evaluator = SnapshotEvaluator(adapter_mock) + + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind MANAGED + ); + + select a, b from foo; + """ + ) + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + snapshot.previous_versions = ( + SnapshotDataVersion( + 
fingerprint=SnapshotFingerprint( + data_hash="test_data_hash", + metadata_hash="test_metadata_hash", + ), + version="test_version", + change_category=SnapshotChangeCategory.BREAKING, + dev_table_suffix="dev", + ), + ) + + evaluator.create( + target_snapshots=[snapshot], + snapshots={}, + deployability_index=DeployabilityIndex.none_deployable(), + ) + + # We dont clone managed tables to create dev previews, we use normal tables + adapter_mock.clone_table.assert_not_called() + adapter_mock.create_managed_table.assert_not_called() + adapter_mock.create_table.assert_not_called() + + # The table gets created using ctas() because the model column types arent known + adapter_mock.ctas.assert_called_once() + assert adapter_mock.ctas.call_args_list[0].args[0] == snapshot.table_name(is_deployable=False) + + +def test_migrate_snapshot(snapshot: Snapshot, mocker: MockerFixture, adapter_mock, make_snapshot): + adapter_mock = mocker.patch("sqlmesh.core.engine_adapter.EngineAdapter") + adapter_mock.dialect = "duckdb" + adapter_mock.with_settings.return_value = adapter_mock + + evaluator = SnapshotEvaluator(adapter_mock) + evaluator.create([snapshot], {}) + + updated_model_dict = snapshot.model.dict() + updated_model_dict["query"] = "SELECT a::int, b::int FROM tbl" + updated_model = SqlModel.parse_obj(updated_model_dict) + + new_snapshot = make_snapshot(updated_model) + new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + new_snapshot.previous_versions = snapshot.all_versions + new_snapshot.version = snapshot.version + + assert new_snapshot.table_name() == snapshot.table_name() + + adapter_mock.get_data_objects.return_value = [ + DataObject( + schema="sqlmesh__db", + name=f"db__model__{new_snapshot.version}", + type=DataObjectType.TABLE, + ) + ] + adapter_mock.drop_data_object_on_type_mismatch.return_value = False + + evaluator.migrate([new_snapshot], {}) + + common_kwargs: t.Dict[str, t.Any] = dict( + table_format=None, + storage_format=None, + 
partitioned_by=[], + partition_interval_unit=None, + clustered_by=[], + table_properties={}, + table_description=None, + ) + + adapter_mock.create_table.assert_has_calls( + [ + call( + new_snapshot.table_name(), + target_columns_to_types={"a": exp.DataType.build("int")}, + column_descriptions={}, + **common_kwargs, + ), + call( + f"{new_snapshot.table_name()}_schema_tmp", + target_columns_to_types={ + "a": exp.DataType.build("int"), + "b": exp.DataType.build("int"), + }, + column_descriptions=None, + **common_kwargs, + ), + ] + ) + + adapter_mock.fetchall.assert_has_calls( + [ + call( + parse_one('SELECT CAST("a" AS INT) AS "a" FROM "tbl" AS "tbl" WHERE FALSE LIMIT 0') + ), + call( + parse_one( + 'SELECT CAST("a" AS INT) AS "a", CAST("b" AS INT) AS "b" FROM "tbl" AS "tbl" WHERE FALSE LIMIT 0' + ) + ), + ] + ) + + adapter_mock.get_alter_operations.assert_called_once_with( + snapshot.table_name(), + f"{new_snapshot.table_name()}_schema_tmp", + ignore_destructive=False, + ignore_additive=False, + ) + + +def test_migrate_only_processes_target_snapshots( + mocker: MockerFixture, adapter_mock, make_snapshot +): + evaluator = SnapshotEvaluator(adapter_mock) + + target_model = SqlModel( + name="test_schema.target_model", + kind=FullKind(), + query=parse_one("SELECT 1 AS a"), + ) + extra_model = SqlModel( + name="test_schema.extra_model", + kind=FullKind(), + query=parse_one("SELECT 1 AS a"), + ) + + target_snapshot = make_snapshot(target_model) + extra_snapshot = make_snapshot(extra_model) + target_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + extra_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + target_snapshots = [target_snapshot] + snapshots = { + target_snapshot.snapshot_id: target_snapshot, + extra_snapshot.snapshot_id: extra_snapshot, + } + + mocker.patch.object( + evaluator, + "_get_data_objects", + return_value={target_snapshot.snapshot_id: mocker.Mock()}, + ) + migrate_mock = mocker.patch.object(evaluator, "_migrate_snapshot") + + def 
apply_side_effect(snapshot_iterable, fn, *_args, **_kwargs): + for snapshot in snapshot_iterable: + fn(snapshot) + return ([], []) + + apply_mock = mocker.patch( + "sqlmesh.core.snapshot.evaluator.concurrent_apply_to_snapshots", + side_effect=apply_side_effect, + ) + + evaluator.migrate(target_snapshots=target_snapshots, snapshots=snapshots) + + assert apply_mock.call_count == 1 + called_snapshots = list(apply_mock.call_args.args[0]) + assert called_snapshots == target_snapshots + + migrate_mock.assert_called_once() + called_snapshot, snapshots_by_name, *_ = migrate_mock.call_args.args + assert called_snapshot is target_snapshot + assert target_snapshot.name in snapshots_by_name + assert extra_snapshot.name in snapshots_by_name + + +def test_migrate_managed(adapter_mock, make_snapshot, mocker: MockerFixture): + evaluator = SnapshotEvaluator(adapter_mock) + + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind MANAGED + ); + + select a, b from foo; + """ + ) + ) + snapshot: Snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING, forward_only=True) + snapshot.previous_versions = snapshot.all_versions + + adapter_mock.get_data_objects.return_value = [ + DataObject( + schema="sqlmesh__test_schema", + name=f"test_schema__test_model__{snapshot.version}", + type=DataObjectType.MANAGED_TABLE, + ) + ] + adapter_mock.drop_data_object_on_type_mismatch.return_value = False + + # no schema changes - no-op + adapter_mock.get_alter_operations.return_value = [] + evaluator.migrate( + target_snapshots=[snapshot], + snapshots={}, + ) + + adapter_mock.create_table.assert_not_called() + adapter_mock.create_managed_table.assert_not_called() + adapter_mock.ctas.assert_called_once() + adapter_mock.reset_mock() + + # schema changes - exception thrown + adapter_mock.get_alter_operations.return_value = [exp.Alter()] + + with pytest.raises(NodeExecutionFailedError) as ex: + evaluator.migrate( + 
target_snapshots=[snapshot], + snapshots={}, + ) + + sqlmesh_err = ex.value.__cause__ + assert isinstance(sqlmesh_err, SQLMeshError) + assert re.match( + "The schema of the managed model '.*?' cannot be updated in a forward-only fashion", + str(sqlmesh_err), + ) + + adapter_mock.create_table.assert_not_called() + adapter_mock.ctas.assert_called_once() + adapter_mock.create_managed_table.assert_not_called() + + +def test_multiple_engine_creation(snapshot: Snapshot, adapters, make_snapshot): + engine_adapters = {"default": adapters[0], "secondary": adapters[1], "third": adapters[2]} + evaluator = SnapshotEvaluator(engine_adapters) + + assert len(evaluator.adapters) == 3 + assert evaluator.adapter == engine_adapters["default"] + assert evaluator.get_adapter() == engine_adapters["default"] + assert evaluator.get_adapter("third") == engine_adapters["third"] + assert evaluator.get_adapter("secondary") == engine_adapters["secondary"] + + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind FULL, + gateway secondary, + dialect postgres, + ); + CREATE INDEX IF NOT EXISTS test_idx ON test_schema.test_model(a); + SELECT a::int FROM tbl; + CREATE INDEX IF NOT EXISTS test_idx ON test_schema.test_model(a); + """ + ), + ) + + snapshot_2 = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_2.categorize_as(SnapshotChangeCategory.BREAKING) + expected_call = f'CREATE INDEX IF NOT EXISTS "test_idx" ON "sqlmesh__test_schema"."test_schema__test_model__{snapshot_2.version}" /* test_schema.test_model */("a" NULLS FIRST)' + evaluator.create([snapshot_2, snapshot], {}, DeployabilityIndex.all_deployable()) + + # Default gateway adapter + create_args = engine_adapters["default"].create_table.call_args_list + assert len(create_args) == 1 + assert create_args[0][0] == (f"sqlmesh__db.db__model__{snapshot.version}",) + + # Secondary gateway for gateway-specicied model + create_args_2 = 
engine_adapters["secondary"].create_table.call_args_list + assert len(create_args_2) == 1 + assert create_args_2[0][0] == ( + f"sqlmesh__test_schema.test_schema__test_model__{snapshot_2.version}", + ) + + engine_adapters["third"].create_table.assert_not_called() + evaluator.promote([snapshot, snapshot_2], EnvironmentNamingInfo(name="test_env")) + + # Virtual layer will use the default adapter + engine_adapters["secondary"].create_view.assert_not_called() + engine_adapters["third"].create_view.assert_not_called() + view_args = engine_adapters["default"].create_view.call_args_list + assert len(view_args) == 2 + assert view_args[0][0][0] == "db__test_env.model" + assert view_args[1][0][0] == "test_schema__test_env.test_model" + + call_args = engine_adapters["secondary"].execute.call_args_list + pre_calls = call_args[1][0][0] + assert len(pre_calls) == 1 + assert pre_calls[0].sql(dialect="postgres") == expected_call + + post_calls = call_args[2][0][0] + assert len(post_calls) == 1 + assert post_calls[0].sql(dialect="postgres") == expected_call + + evaluator.demote([snapshot_2], EnvironmentNamingInfo(name="test_env")) + engine_adapters["secondary"].drop_view.assert_not_called() + engine_adapters["default"].drop_view.assert_called_once_with( + "test_schema__test_env.test_model", + cascade=False, + ) + + +def test_multiple_engine_promotion(mocker: MockerFixture, adapter_mock, make_snapshot): + connection_mock = mocker.NonCallableMock() + cursor_mock = mocker.Mock() + connection_mock.cursor.return_value = cursor_mock + adapter = EngineAdapter(lambda: connection_mock, "") + adapter.with_settings = lambda **kwargs: adapter # type: ignore + adapter._get_data_objects = lambda *args, **kwargs: [] # type: ignore + engine_adapters = {"default": adapter_mock, "secondary": adapter} + + def columns(table_name): + return { + "a": exp.DataType.build("int"), + "ds": exp.DataType.build("timestamp"), + } + + adapter.columns = columns # type: ignore + + model = SqlModel( + 
name="test_schema.test_model", + kind=IncrementalByTimeRangeKind(time_column="ds"), + gateway="secondary", + query=parse_one("SELECT a FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"), + ) + + evaluator = SnapshotEvaluator(engine_adapters) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.evaluate( + snapshot, + start="2020-01-01", + end="2020-01-02", + execution_time="2020-01-02", + snapshots={}, + ) + + evaluator.promote([snapshot], EnvironmentNamingInfo(name="test_env")) + + # Verify that the model was evaluated using the gateway specific adapter "secondary" + cursor_mock.execute.assert_has_calls( + [ + call( + f'DELETE FROM "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}" WHERE "ds" BETWEEN CAST(\'2020-01-01 00:00:00\' AS TIMESTAMP) AND CAST(\'2020-01-02 23:59:59.999999\' AS TIMESTAMP)' + ), + call( + f'INSERT INTO "sqlmesh__test_schema"."test_schema__test_model__{snapshot.version}" ("a", "ds") SELECT "a", "ds" FROM (SELECT "a" AS "a" FROM "tbl" AS "tbl" WHERE "ds" BETWEEN \'2020-01-01\' AND \'2020-01-02\') AS "_subquery" WHERE "ds" BETWEEN CAST(\'2020-01-01 00:00:00\' AS TIMESTAMP) AND CAST(\'2020-01-02 23:59:59.999999\' AS TIMESTAMP)' + ), + ] + ) + + # Verify that the snapshot was promoted using the default adapter "default" (adapter_mock in this case) + adapter_mock.create_schema.assert_called_once_with(schema_("test_schema__test_env")) + adapter_mock.create_view.assert_called_once_with( + "test_schema__test_env.test_model", + parse_one( + f"SELECT * FROM sqlmesh__test_schema.test_schema__test_model__{snapshot.version}" + ), + table_description=None, + column_descriptions=None, + view_properties={}, + ) + + +def test_multiple_engine_migration( + mocker: MockerFixture, adapter_mock, make_snapshot, make_mocked_engine_adapter +): + adapter_one = make_mocked_engine_adapter(EngineAdapter) + adapter_one.with_settings = lambda **kwargs: adapter_one # type: ignore + adapter_two = 
adapter_mock + adapter_two.with_settings.return_value = adapter_two + engine_adapters = {"one": adapter_one, "two": adapter_two} + + model = SqlModel( + name="test_schema.test_model", + kind=IncrementalByTimeRangeKind( + time_column="a", on_destructive_change=OnDestructiveChange.ALLOW + ), + query=parse_one("SELECT c FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"), + ) + snapshot_1 = make_snapshot(model, version="1") + snapshot_1.change_category = SnapshotChangeCategory.BREAKING + snapshot_1.forward_only = True + snapshot_1.previous_versions = snapshot_1.all_versions + model_2 = SqlModel( + name="test_schema.test_model_2", + kind=IncrementalByTimeRangeKind( + time_column="a", on_destructive_change=OnDestructiveChange.ALLOW + ), + gateway="two", + query=parse_one("SELECT c FROM tbl WHERE ds BETWEEN @start_ds and @end_ds"), + ) + snapshot_2 = make_snapshot(model_2, version="1") + snapshot_2.change_category = SnapshotChangeCategory.BREAKING + snapshot_2.forward_only = True + snapshot_2.previous_versions = snapshot_2.all_versions + + def columns(table_name): + if table_name == snapshot_1.table_name(): + return { + "c": exp.DataType.build("int"), + "b": exp.DataType.build("int"), + } + return { + "c": exp.DataType.build("int"), + "a": exp.DataType.build("int"), + } + + adapter_two.columns.side_effect = columns + adapter_two.drop_data_object_on_type_mismatch.return_value = False + + mocker.patch.object(adapter_one, "columns", side_effect=columns) + mocker.patch( + "sqlmesh.core.engine_adapter.base.EngineAdapter.get_data_objects", + return_value=[ + DataObject( + schema="sqlmesh__test_schema", + name=f"test_schema__test_model__{snapshot_1.version}", + type=DataObjectType.TABLE, + ), + DataObject( + schema="sqlmesh__test_schema", + name=f"test_schema__test_model_2__{snapshot_2.version}", + type=DataObjectType.TABLE, + ), + ], + ) + + evaluator = SnapshotEvaluator(engine_adapters) + evaluator.migrate([snapshot_1, snapshot_2], {}) + + 
adapter_one.cursor.execute.assert_has_calls( + [ + call('ALTER TABLE "sqlmesh__test_schema"."test_schema__test_model__1" DROP COLUMN "b"'), + call( + 'ALTER TABLE "sqlmesh__test_schema"."test_schema__test_model__1" ADD COLUMN "a" INT' + ), + ] + ) + + # The second mock adapter has to be called only for the gateway-specific model + adapter_mock.get_alter_operations.assert_called_once_with( + snapshot_2.table_name(True), + f"{snapshot_2.table_name(True)}_schema_tmp", + ignore_destructive=False, + ignore_additive=False, + ) + + +def test_multiple_engine_cleanup(snapshot: Snapshot, adapters, make_snapshot): + engine_adapters = {"default": adapters[0], "secondary": adapters[1]} + evaluator = SnapshotEvaluator(engine_adapters) + + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind FULL, + gateway secondary, + ); + SELECT a::int FROM tbl; + """ + ), + ) + + snapshot_2 = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_2.categorize_as(SnapshotChangeCategory.BREAKING) + evaluator.create([snapshot_2, snapshot], {}, DeployabilityIndex.all_deployable()) + + assert engine_adapters["default"].create_table.call_args_list[0][0] == ( + f"sqlmesh__db.db__model__{snapshot.version}", + ) + assert engine_adapters["secondary"].create_table.call_args_list[0][0] == ( + f"sqlmesh__test_schema.test_schema__test_model__{snapshot_2.version}", + ) + + evaluator.cleanup( + [ + SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True), + SnapshotTableCleanupTask(snapshot=snapshot_2.table_info, dev_table_only=True), + ], + ) + + # The clean up will happen using the specific gateway the model was created with + engine_adapters["default"].drop_table.assert_called_once_with( + f"sqlmesh__db.db__model__{snapshot.version}__dev", cascade=True + ) + engine_adapters["secondary"].drop_table.assert_called_once_with( + f"sqlmesh__test_schema.test_schema__test_model__{snapshot_2.version}__dev", 
cascade=True + ) + + +def test_multi_engine_python_model_with_macros(adapters, make_snapshot): + engine_adapters = {"default": adapters[0], "secondary": adapters[1]} + evaluator = SnapshotEvaluator(engine_adapters) + + @macro() + def validate_engine_call( + evaluator: MacroEvaluator, + ): + if evaluator.runtime_stage == "creating": + # To validate the model-specified gateway is used for the macro evaluator + evaluator.engine_adapter.get_catalog_type() + return None + + @model( + "db.multi_engine_test_model", + kind="full", + gateway="secondary", + columns={"id": "string", "name": "string"}, + pre_statements=["@VALIDATE_ENGINE_CALL()"], + post_statements=["@VALIDATE_ENGINE_CALL()"], + ) + def model_with_statements(context, **kwargs): + return pd.DataFrame( + [ + { + "id": context.var("1"), + "name": context.var("var"), + } + ] + ) + + python_model = model.get_registry()["db.multi_engine_test_model"].model( + module_path=Path("."), + path=Path("."), + macros=macro.get_registry(), + dialect="postgres", + ) + + assert len(python_model.python_env) == 3 + assert isinstance(python_model.python_env["validate_engine_call"], Executable) + + snapshot = make_snapshot(python_model) + assert snapshot.model_gateway == "secondary" + assert evaluator.adapter == engine_adapters["default"] + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.create([snapshot], {}, DeployabilityIndex.all_deployable()) + + # Validate model-specific gateway usage during table creation + create_args = engine_adapters["secondary"].create_table.call_args_list + assert len(create_args) == 1 + assert create_args[0][0] == (f"sqlmesh__db.db__multi_engine_test_model__{snapshot.version}",) + + environment_naming_info = EnvironmentNamingInfo(name="test_env") + evaluator.promote([snapshot], environment_naming_info) + + # Verify that the default gateway creates the view for the virtual layer + engine_adapters["secondary"].create_view.assert_not_called() + view_args = 
engine_adapters["default"].create_view.call_args_list + assert len(view_args) == 1 + assert view_args[0][0][0] == "db__test_env.multi_engine_test_model" + + # For the pre/post statements verify the model-specific gateway was used + engine_adapters["default"].execute.assert_called_once() + assert len(engine_adapters["secondary"].execute.call_args_list) == 4 + + # Validate that the get_catalog_type method was called only on the secondary engine from the macro evaluator + engine_adapters["default"].get_catalog_type.assert_not_called() + assert len(engine_adapters["secondary"].get_catalog_type.call_args_list) == 2 + + evaluator.demote([snapshot], environment_naming_info) + engine_adapters["default"].drop_view.assert_called_once_with( + "db__test_env.multi_engine_test_model", + cascade=False, + ) + + environment_naming_info_gw = EnvironmentNamingInfo(name="test_env", gateway_managed=True) + # Validate that promoting with gateway_managed leads to this gateway being used for virtual layer + evaluator.promote([snapshot], environment_naming_info_gw) + view_args = engine_adapters["secondary"].create_view.call_args_list + assert len(view_args) == 1 + assert view_args[0][0][0] == "db__test_env.multi_engine_test_model" + + # Similarly for demotion + evaluator.demote([snapshot], environment_naming_info_gw) + engine_adapters["secondary"].drop_view.assert_called_once_with( + "db__test_env.multi_engine_test_model", + cascade=False, + ) + + +def test_multiple_engine_virtual_layer(snapshot: Snapshot, adapters, make_snapshot): + engine_adapters = {"default": adapters[0], "secondary": adapters[1], "third": adapters[2]} + evaluator = SnapshotEvaluator(engine_adapters) + + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind FULL, + gateway secondary, + dialect postgres, + ); + SELECT a::int FROM tbl; + CREATE INDEX IF NOT EXISTS test_idx ON test_schema.test_model(a); + """ + ), + ) + + snapshot_2 = make_snapshot(model) + 
snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + snapshot_2.categorize_as(SnapshotChangeCategory.BREAKING) + evaluator.create([snapshot_2, snapshot], {}, DeployabilityIndex.all_deployable()) + + # Default gateway adapter to create table without gateway + create_args = engine_adapters["default"].create_table.call_args_list + assert len(create_args) == 1 + assert create_args[0][0] == (f"sqlmesh__db.db__model__{snapshot.version}",) + + # Secondary gateway for gateway-specicied model + create_args_2 = engine_adapters["secondary"].create_table.call_args_list + assert len(create_args_2) == 1 + assert create_args_2[0][0] == ( + f"sqlmesh__test_schema.test_schema__test_model__{snapshot_2.version}", + ) + + environment_naming_info = EnvironmentNamingInfo(name="test_env", gateway_managed=True) + engine_adapters["third"].create_table.assert_not_called() + evaluator.promote([snapshot, snapshot_2], environment_naming_info) + + # Virtual layer will use the model-specified gateway adapter for the second model and default otherwise + view_args_default = engine_adapters["default"].create_view.call_args_list + engine_adapters["third"].create_view.assert_not_called() + view_args_secondary = engine_adapters["secondary"].create_view.call_args_list + + assert len(view_args_default) == 1 + assert view_args_default[0][0][0] == "db__test_env.model" + assert len(view_args_secondary) == 1 + assert view_args_secondary[0][0][0] == "test_schema__test_env.test_model" + + # Demotion will follow with the same pattern + evaluator.demote([snapshot_2, snapshot], environment_naming_info) + engine_adapters["default"].drop_view.assert_called_once_with( + "db__test_env.model", + cascade=False, + ) + engine_adapters["secondary"].drop_view.assert_called_once_with( + "test_schema__test_env.test_model", + cascade=False, + ) + + +def test_wap_basic( + adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot], mocker: MockerFixture +) -> None: + evaluator = SnapshotEvaluator(adapter_mock) + + model 
= SqlModel( + name="test_schema.test_table", + kind=FullKind(), + query=parse_one("SELECT a::int FROM tbl"), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + adapter_mock.wap_supported.return_value = True + + expected_wap_table = "test_schema.test_table.branch_wap_12345678" + adapter_mock.wap_prepare.return_value = expected_wap_table + + wap_id = evaluator.evaluate( + snapshot, + start="2020-01-01", + end="2020-01-01", + execution_time="2020-01-01", + snapshots={}, + target_table_exists=True, # Use parameter to control table existence + ) + + assert wap_id is not None + assert len(wap_id) == 8 + + expected_table_name = snapshot.table_name() + adapter_mock.wap_prepare.assert_called_once_with(expected_table_name, wap_id) + adapter_mock.replace_query.assert_called_once_with( + expected_wap_table, + mocker.ANY, + table_format=mocker.ANY, + storage_format=mocker.ANY, + partitioned_by=mocker.ANY, + partition_interval_unit=mocker.ANY, + clustered_by=mocker.ANY, + table_properties=mocker.ANY, + table_description=mocker.ANY, + column_descriptions=mocker.ANY, + target_columns_to_types=mocker.ANY, + source_columns=mocker.ANY, + ) + + +def test_wap_model_wap_supported( + adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot], mocker: MockerFixture +) -> None: + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_table", + kind=FullKind(), + query=parse_one("SELECT a::int FROM tbl"), + storage_format="iceberg", # Model supports WAP via iceberg format + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + adapter_mock.wap_supported.return_value = False + + expected_wap_table = "test_schema.test_table.branch_wap_12345678" + adapter_mock.wap_prepare.return_value = expected_wap_table + + wap_id = evaluator.evaluate( + snapshot, + start="2020-01-01", + end="2020-01-01", + execution_time="2020-01-01", + snapshots={}, + target_table_exists=True, 
# Use parameter to control table existence + ) + assert wap_id is not None + + expected_table_name = snapshot.table_name() + adapter_mock.wap_prepare.assert_called_once_with(expected_table_name, wap_id) + adapter_mock.replace_query.assert_called_once_with( + expected_wap_table, + mocker.ANY, + table_format=mocker.ANY, + storage_format=mocker.ANY, + partitioned_by=mocker.ANY, + partition_interval_unit=mocker.ANY, + clustered_by=mocker.ANY, + table_properties=mocker.ANY, + table_description=mocker.ANY, + column_descriptions=mocker.ANY, + target_columns_to_types=mocker.ANY, + source_columns=mocker.ANY, + ) + + +def test_wap_no_wap_support(adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot]) -> None: + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_table", + kind=FullKind(), + query=parse_one("SELECT a::int FROM tbl"), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + adapter_mock.wap_supported.return_value = False + + wap_id = evaluator.evaluate( + snapshot, + start="2020-01-01", + end="2020-01-01", + execution_time="2020-01-01", + snapshots={}, + target_table_exists=True, + ) + + assert wap_id is None + adapter_mock.wap_prepare.assert_not_called() + + +def test_wap_non_materialized_snapshot( + adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot] +) -> None: + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_table", + kind=ViewKind(), # View kind is not materialized + query=parse_one("SELECT a::int FROM tbl"), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + adapter_mock.wap_supported.return_value = True + + wap_id = evaluator.evaluate( + snapshot, start="2020-01-01", end="2020-01-01", execution_time="2020-01-01", snapshots={} + ) + + assert wap_id is None + adapter_mock.wap_prepare.assert_not_called() + + +def test_wap_publish_snapshot(adapter_mock: Mock, 
make_snapshot: t.Callable[..., Snapshot]) -> None: + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_table", + kind=FullKind(), + query=parse_one("SELECT a::int FROM tbl"), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + wap_id = "test_wap_id" + deployability_index = DeployabilityIndex.all_deployable() + + evaluator.wap_publish_snapshot(snapshot, wap_id, deployability_index) + + expected_table_name = snapshot.table_name(is_deployable=True) + adapter_mock.wap_publish.assert_called_once_with(expected_table_name, wap_id) + + +def test_wap_during_audit(adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot]) -> None: + evaluator = SnapshotEvaluator(adapter_mock) + + custom_audit = ModelAudit( + name="custom_audit", + query="SELECT * FROM test_schema.test_table WHERE invalid_condition", + ) + + model = SqlModel( + name="test_schema.test_table", + kind=FullKind(), + query=parse_one("SELECT a::int FROM tbl"), + audits=[ + ("not_null", {"columns": exp.to_column("a")}), + ("custom_audit", {}), + ], + audit_definitions={custom_audit.name: custom_audit}, + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + wap_id = "test_wap_id" + expected_wap_table_name = f"test_schema.test_table.branch_wap_{wap_id}" + adapter_mock.wap_table_name.return_value = expected_wap_table_name + adapter_mock.fetchone.return_value = (0,) + + results = evaluator.audit(snapshot, snapshots={}, wap_id=wap_id) + + assert len(results) == 2 + + adapter_mock.wap_table_name.assert_called_once_with(snapshot.table_name(), wap_id) + adapter_mock.wap_publish.assert_called_once_with(snapshot.table_name(), wap_id) + + +def test_wap_prepare_failure(adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot]) -> None: + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_table", + kind=FullKind(), + query=parse_one("SELECT a::int FROM 
tbl"), + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + adapter_mock.wap_supported.return_value = True + + adapter_mock.wap_prepare.side_effect = Exception("WAP prepare failed") + + with pytest.raises(Exception, match="WAP prepare failed"): + evaluator.evaluate( + snapshot, + start="2020-01-01", + end="2020-01-01", + execution_time="2020-01-01", + snapshots={}, + target_table_exists=True, + ) + + +def test_wap_publish_failure(adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot]) -> None: + """Test error handling when WAP publish fails.""" + evaluator = SnapshotEvaluator(adapter_mock) + + model = SqlModel( + name="test_schema.test_table", + kind=FullKind(), + query=parse_one("SELECT a::int FROM tbl"), + audits=[("not_null", {"columns": exp.to_column("a")})], + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + wap_id = "test_wap_id" + expected_wap_table_name = f"test_schema.test_table.branch_wap_{wap_id}" + adapter_mock.wap_table_name.return_value = expected_wap_table_name + adapter_mock.fetchone.return_value = (0,) + + # Mock WAP publish to raise an exception + adapter_mock.wap_publish.side_effect = Exception("WAP publish failed") + + # Execute audit with WAP ID and expect it to raise the exception + with pytest.raises(Exception, match="WAP publish failed"): + evaluator.audit(snapshot, snapshots={}, wap_id=wap_id) + + +def test_properties_are_preserved_in_both_create_statements( + adapter_mock: Mock, make_snapshot: t.Callable[..., Snapshot] +) -> None: + # the below mocks are needed to create a situation + # where we trigger two create statements during evaluation + transaction_mock = Mock() + transaction_mock.__enter__ = Mock() + transaction_mock.__exit__ = Mock() + session_mock = Mock() + session_mock.__enter__ = Mock() + session_mock.__exit__ = Mock() + adapter_mock = Mock() + adapter_mock.transaction.return_value = transaction_mock + 
adapter_mock.session.return_value = session_mock + adapter_mock.dialect = "trino" + adapter_mock.HAS_VIEW_BINDING = False + adapter_mock.wap_supported.return_value = False + adapter_mock.get_data_objects.return_value = [] + adapter_mock.with_settings.return_value = adapter_mock + adapter_mock.table_exists.return_value = False + + props = [] + + def mutate_view_properties(*args, **kwargs): + view_props = kwargs.get("view_properties") + if isinstance(view_props, dict): + props.append(view_props["creatable_type"].sql()) + # simulate view pop + view_props.pop("creatable_type") + return None + + adapter_mock.create_view.side_effect = mutate_view_properties + + evaluator = SnapshotEvaluator(adapter_mock) + + # create a view model with SECURITY INVOKER physical property + # AND self referenctial to trigger two create statements + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.security_view, + kind VIEW, + physical_properties ( + 'creatable_type' = 'SECURITY INVOKER' + ) + ); + + SELECT 1 as col from test_schema.security_view; + """ + ), + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + evaluator.evaluate( + snapshot, + start="2024-01-01", + end="2024-01-02", + execution_time="2024-01-02", + snapshots={}, + ) + + # Verify create_view was called twice + assert adapter_mock.create_view.call_count == 2 + first_call = adapter_mock.create_view.call_args_list[0] + second_call = adapter_mock.create_view.call_args_list[1] + + # First call should be CREATE VIEW (replace=False) second CREATE OR REPLACE VIEW (replace=True) + assert first_call.kwargs.get("replace") == False + assert second_call.kwargs.get("replace") == True + + # Both calls should have view_properties with security invoker + assert props == ["'SECURITY INVOKER'", "'SECURITY INVOKER'"] + + +def _create_grants_test_model( + grants=None, kind="FULL", grants_target_layer=None, virtual_environment_mode=None +): + if kind == "SEED": 
+ from sqlmesh.core.model.definition import create_seed_model + from sqlmesh.core.model.kind import SeedKind + import tempfile + import os + + # Create a temporary CSV file for the test + temp_csv = tempfile.NamedTemporaryFile(mode="w", suffix=".csv", delete=False) + temp_csv.write("id,name\n1,test\n2,test2\n") + temp_csv.flush() + temp_csv.close() + + seed_kind_config = {"name": "SEED", "path": temp_csv.name} + seed_kind = SeedKind(**seed_kind_config) + + kwargs = {} + if grants is not None: + kwargs["grants"] = grants + if grants_target_layer is not None: + kwargs["grants_target_layer"] = grants_target_layer + + model = create_seed_model("test_model", seed_kind, **kwargs) + + # Clean up the temporary file + os.unlink(temp_csv.name) + + return model + + # Handle regular SQL models + kwargs = { + "kind": kind, + } + if grants is not None: + kwargs["grants"] = grants + if grants_target_layer is not None: + kwargs["grants_target_layer"] = grants_target_layer + if virtual_environment_mode is not None: + kwargs["virtual_environment_mode"] = virtual_environment_mode + + # Add column annotations for non-SEED models to ensure table creation + if kind != "SEED": + kwargs["columns"] = { + "id": "INT", + "ds": "DATE", + "updated_at": "TIMESTAMP", + } + + # Add required fields for specific model kinds + if kind == "INCREMENTAL_BY_TIME_RANGE": + kwargs["kind"] = {"name": "INCREMENTAL_BY_TIME_RANGE", "time_column": "ds"} + elif kind == "INCREMENTAL_BY_PARTITION": + kwargs["kind"] = {"name": "INCREMENTAL_BY_PARTITION"} + kwargs["partitioned_by"] = ["ds"] # This goes on the model, not the kind + elif kind == "INCREMENTAL_BY_UNIQUE_KEY": + kwargs["kind"] = {"name": "INCREMENTAL_BY_UNIQUE_KEY", "unique_key": ["id"]} + elif kind == "INCREMENTAL_UNMANAGED": + kwargs["kind"] = {"name": "INCREMENTAL_UNMANAGED"} + elif kind == "SCD_TYPE_2": + kwargs["kind"] = { + "name": "SCD_TYPE_2", + "unique_key": ["id"], + "updated_at_name": "updated_at", + } + + return create_sql_model( + 
"test_model", + parse_one("SELECT 1 as id, CURRENT_DATE as ds, CURRENT_TIMESTAMP as updated_at"), + **kwargs, + ) + + +@pytest.mark.parametrize( + "target_layer,apply_layer,expected_call_count", + [ + (GrantsTargetLayer.ALL, GrantsTargetLayer.PHYSICAL, 1), + (GrantsTargetLayer.ALL, GrantsTargetLayer.VIRTUAL, 1), + (GrantsTargetLayer.PHYSICAL, GrantsTargetLayer.PHYSICAL, 1), + (GrantsTargetLayer.PHYSICAL, GrantsTargetLayer.VIRTUAL, 0), + (GrantsTargetLayer.VIRTUAL, GrantsTargetLayer.PHYSICAL, 0), + (GrantsTargetLayer.VIRTUAL, GrantsTargetLayer.VIRTUAL, 1), + ], +) +def test_apply_grants_target_layer( + target_layer: GrantsTargetLayer, + apply_layer: GrantsTargetLayer, + expected_call_count: int, + adapter_mock: Mock, + mocker: MockerFixture, +): + adapter_mock.SUPPORTS_GRANTS = True + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + strategy = ViewStrategy(adapter_mock) + + model = _create_grants_test_model( + grants={"select": ["user1"]}, grants_target_layer=target_layer + ) + + strategy._apply_grants(model, "test_table", apply_layer) + + if expected_call_count > 0: + assert sync_grants_mock.call_count == expected_call_count + else: + sync_grants_mock.assert_not_called() + + +@pytest.mark.parametrize( + "model_kind_name", + [ + "FULL", + "INCREMENTAL_BY_TIME_RANGE", + "SEED", + "MANAGED", + "SCD_TYPE_2", + "VIEW", + ], +) +def test_grants_create_model_kind( + model_kind_name: str, + adapter_mock: Mock, + mocker: MockerFixture, + make_snapshot: t.Callable[..., Snapshot], +): + adapter_mock.SUPPORTS_GRANTS = True + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + + grants = {"select": ["user1"]} + model = _create_grants_test_model( + grants=grants, kind=model_kind_name, grants_target_layer=GrantsTargetLayer.ALL + ) + snapshot = make_snapshot(model) + + evaluator = SnapshotEvaluator(adapter_mock) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + 
evaluator.create([snapshot], {}) + + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == grants + + +@pytest.mark.parametrize( + "target_layer", + [ + GrantsTargetLayer.PHYSICAL, + GrantsTargetLayer.VIRTUAL, + GrantsTargetLayer.ALL, + ], +) +def test_grants_target_layer( + target_layer: GrantsTargetLayer, + adapter_mock: Mock, + mocker: MockerFixture, + make_snapshot: t.Callable[..., Snapshot], +): + adapter_mock.SUPPORTS_GRANTS = True + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + evaluator = SnapshotEvaluator(adapter_mock) + + grants = {"select": ["user1"]} + model = create_sql_model( + "test_schema.test_model", + parse_one("SELECT 1 as id"), + kind="FULL", + grants=grants, + grants_target_layer=target_layer, + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.create([snapshot], {}) + if target_layer == GrantsTargetLayer.VIRTUAL: + assert sync_grants_mock.call_count == 0 + else: + assert sync_grants_mock.call_count == 1 + assert sync_grants_mock.call_args[0][1] == grants + sync_grants_mock.reset_mock() + evaluator.promote([snapshot], EnvironmentNamingInfo(name="prod")) + if target_layer == GrantsTargetLayer.VIRTUAL: + assert sync_grants_mock.call_count == 1 + elif target_layer == GrantsTargetLayer.PHYSICAL: + # Physical layer: no grants applied during promotion (already applied during create) + assert sync_grants_mock.call_count == 0 + else: # target_layer == GrantsTargetLayer.ALL + # All layers: only virtual grants applied during promotion (physical already done in create) + assert sync_grants_mock.call_count == 1 + + +def test_grants_update( + adapter_mock: Mock, mocker: MockerFixture, make_snapshot: t.Callable[..., Snapshot] +): + adapter_mock.SUPPORTS_GRANTS = True + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + + evaluator = SnapshotEvaluator(adapter_mock) + + model = create_sql_model( + 
"test_schema.test_model", + parse_one("SELECT 1 as id"), + kind="FULL", + grants={"select": ["user1"]}, + grants_target_layer=GrantsTargetLayer.ALL, + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + evaluator.create([snapshot], {}) + + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == {"select": ["user1"]} + + # Update model query AND change grants + updated_model_dict = model.dict() + updated_model_dict["query"] = parse_one("SELECT 1 as id, 2 as value") + updated_model_dict["grants"] = {"select": ["user2", "user3"], "insert": ["admin"]} + updated_model = SqlModel.parse_obj(updated_model_dict) + + new_snapshot = make_snapshot(updated_model) + new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + sync_grants_mock.reset_mock() + evaluator.create([new_snapshot], {}) + + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == {"select": ["user2", "user3"], "insert": ["admin"]} + + # Update model query AND remove grants + updated_model_dict = model.dict() + updated_model_dict["query"] = parse_one("SELECT 1 as id, 'updated' as status") + updated_model_dict["grants"] = {} + updated_model = SqlModel.parse_obj(updated_model_dict) + + new_snapshot = make_snapshot(updated_model) + new_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + sync_grants_mock.reset_mock() + evaluator.create([new_snapshot], {}) + + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == {} + + +def test_grants_create_and_evaluate( + adapter_mock: Mock, mocker: MockerFixture, make_snapshot: t.Callable[..., Snapshot] +): + adapter_mock.SUPPORTS_GRANTS = True + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + + evaluator = SnapshotEvaluator(adapter_mock) + + model = load_sql_based_model( + parse( # type: ignore + """ + MODEL ( + name test_schema.test_model, + kind INCREMENTAL_BY_TIME_RANGE (time_column ds), + grants ( + 
'select' = ['reader1', 'reader2'], + 'insert' = ['writer'] + ), + grants_target_layer 'all' + ); + SELECT ds::DATE, value::INT FROM source WHERE ds BETWEEN @start_ds AND @end_ds; + """ + ) + ) + + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.create([snapshot], {}) + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == { + "select": ["reader1", "reader2"], + "insert": ["writer"], + } + + sync_grants_mock.reset_mock() + evaluator.evaluate( + snapshot, start="2020-01-01", end="2020-01-02", execution_time="2020-01-02", snapshots={} + ) + # Evaluate should not reapply grants + sync_grants_mock.assert_not_called() + + +@pytest.mark.parametrize( + "strategy_class", + [ + EngineManagedStrategy, + FullRefreshStrategy, + IncrementalByTimeRangeStrategy, + IncrementalByPartitionStrategy, + IncrementalUnmanagedStrategy, + IncrementalByUniqueKeyStrategy, + SCDType2Strategy, + # SeedStrategy excluded because seeds do not support migrations + ], +) +def test_grants_materializable_strategy_migrate( + strategy_class: t.Type[MaterializableStrategy], + adapter_mock: Mock, + mocker: MockerFixture, + make_snapshot: t.Callable[..., Snapshot], +): + adapter_mock.SUPPORTS_GRANTS = True + adapter_mock.get_alter_operations.return_value = [] + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + strategy = strategy_class(adapter_mock) + grants = {"select": ["user1"]} + model = _create_grants_test_model(grants=grants, grants_target_layer=GrantsTargetLayer.ALL) + snapshot = make_snapshot(model) + + strategy.migrate( + "target_table", + "source_table", + snapshot, + ignore_destructive=False, + ignore_additive=False, + allow_destructive_snapshots=set(), + allow_additive_snapshots=set(), + ) + + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == grants + + +def test_grants_clone_snapshot_in_dev( + adapter_mock: Mock, mocker: MockerFixture, make_snapshot: 
t.Callable[..., Snapshot] +): + adapter_mock.SUPPORTS_CLONING = True + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + + evaluator = SnapshotEvaluator(adapter_mock) + grants = {"select": ["user1", "user2"]} + model = _create_grants_test_model(grants=grants, grants_target_layer=GrantsTargetLayer.ALL) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator._clone_snapshot_in_dev( + snapshot, {}, DeployabilityIndex.all_deployable(), {}, {}, set(), set() + ) + + sync_grants_mock.assert_called_once() + assert ( + sync_grants_mock.call_args[0][0].sql() + == f"sqlmesh__default.test_model__{snapshot.version}__dev" + ) + assert sync_grants_mock.call_args[0][1] == grants + + +@pytest.mark.parametrize( + "model_kind_name", + [ + "INCREMENTAL_BY_TIME_RANGE", + "SEED", + ], +) +def test_grants_evaluator_insert_without_replace_query_for_model( + model_kind_name: str, + adapter_mock: Mock, + mocker: MockerFixture, + make_snapshot: t.Callable[..., Snapshot], +): + adapter_mock.SUPPORTS_GRANTS = True + adapter_mock.table_exists.return_value = False # Table doesn't exist + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + + evaluator = SnapshotEvaluator(adapter_mock) + + grants = {"select": ["reader1", "reader2"]} + model = _create_grants_test_model( + grants=grants, kind=model_kind_name, grants_target_layer=GrantsTargetLayer.ALL + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + evaluator.evaluate( + snapshot, + start="2023-01-01", + end="2023-01-01", + execution_time="2023-01-01", + snapshots={}, + ) + + # Grants are applied during the table creation phase, not during insert + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == grants + + sync_grants_mock.reset_mock() + adapter_mock.table_exists.return_value = True + snapshot.add_interval("2023-01-01", "2023-01-01") + evaluator.evaluate( + 
snapshot, + start="2023-01-02", # Different date from existing interval + end="2023-01-02", + execution_time="2023-01-02", + snapshots={}, + ) + + # Should not apply grants since it's not the first insert + sync_grants_mock.assert_not_called() + + +@pytest.mark.parametrize( + "model_kind_name", + [ + "INCREMENTAL_BY_PARTITION", + "INCREMENTAL_BY_UNIQUE_KEY", + "INCREMENTAL_UNMANAGED", + "FULL", + "SCD_TYPE_2", + ], +) +def test_grants_evaluator_insert_with_replace_query_for_model( + model_kind_name: str, + adapter_mock: Mock, + mocker: MockerFixture, + make_snapshot: t.Callable[..., Snapshot], +): + adapter_mock.SUPPORTS_GRANTS = True + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + adapter_mock.table_exists.return_value = False # Table doesn't exist + adapter_mock.columns.return_value = { + "id": exp.DataType.build("int"), + "ds": exp.DataType.build("date"), + } + + evaluator = SnapshotEvaluator(adapter_mock) + + grants = {"select": ["user1"]} + model = _create_grants_test_model( + grants=grants, kind=model_kind_name, grants_target_layer=GrantsTargetLayer.ALL + ) + snapshot = make_snapshot(model) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + + # Now evaluate the snapshot (this should apply grants during first insert) + evaluator.evaluate( + snapshot, + start="2023-01-01", + end="2023-01-01", + execution_time="2023-01-01", + snapshots={}, + ) + + # Should be called twice more during evaluate: once creating table, + # once during first insert with _replace_query_for_model() + assert sync_grants_mock.call_count == 2 + assert sync_grants_mock.call_args[0][1] == grants + + sync_grants_mock.reset_mock() + adapter_mock.table_exists.return_value = True + snapshot.add_interval("2023-01-01", "2023-01-01") + evaluator.evaluate( + snapshot, + start="2023-01-02", # Different date from existing interval + end="2023-01-02", + execution_time="2023-01-02", + snapshots={}, + ) + + if model_kind_name in ("FULL", "SCD_TYPE_2"): + # Full 
refresh and SCD_TYPE_2 always recreate the table, so grants are always applied + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == grants + else: + # Should not apply grants since it's not the first insert + sync_grants_mock.assert_not_called() + + +@pytest.mark.parametrize( + "model_grants_target_layer", + [ + GrantsTargetLayer.ALL, + GrantsTargetLayer.VIRTUAL, + GrantsTargetLayer.PHYSICAL, + ], +) +def test_grants_in_production_with_dev_only_vde( + adapter_mock: Mock, + mocker: MockerFixture, + make_snapshot: t.Callable[..., Snapshot], + model_grants_target_layer: GrantsTargetLayer, +): + adapter_mock.SUPPORTS_GRANTS = True + sync_grants_mock = mocker.patch.object(adapter_mock, "sync_grants_config") + + from sqlmesh.core.model.meta import VirtualEnvironmentMode, GrantsTargetLayer + from sqlmesh.core.snapshot.definition import DeployabilityIndex + + model_virtual_grants = _create_grants_test_model( + grants={"select": ["user1"], "insert": ["role1"]}, + grants_target_layer=model_grants_target_layer, + virtual_environment_mode=VirtualEnvironmentMode.DEV_ONLY, + ) + + snapshot = make_snapshot(model_virtual_grants) + snapshot.categorize_as(SnapshotChangeCategory.BREAKING) + evaluator = SnapshotEvaluator(adapter_mock) + # create will apply grants to physical layer tables + deployability_index = DeployabilityIndex.all_deployable() + evaluator.create([snapshot], {}, deployability_index=deployability_index) + + sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == {"select": ["user1"], "insert": ["role1"]} + + # Non-deployable (dev) env + sync_grants_mock.reset_mock() + deployability_index = DeployabilityIndex.none_deployable() + evaluator.create([snapshot], {}, deployability_index=deployability_index) + if model_grants_target_layer == GrantsTargetLayer.VIRTUAL: + sync_grants_mock.assert_not_called() + else: + # Should still apply grants to physical table when target layer is ALL or PHYSICAL + 
sync_grants_mock.assert_called_once() + assert sync_grants_mock.call_args[0][1] == {"select": ["user1"], "insert": ["role1"]} diff --git a/tests/core/test_state_sync.py b/tests/core/test_state_sync.py deleted file mode 100644 index 7d0f18efae..0000000000 --- a/tests/core/test_state_sync.py +++ /dev/null @@ -1,2327 +0,0 @@ -import json -import re -import typing as t -from unittest.mock import call, patch - -import duckdb -import pandas as pd -import pytest -from freezegun import freeze_time -from pytest_mock.plugin import MockerFixture -from sqlglot import exp - -from sqlmesh.core import constants as c -from sqlmesh.core.config import EnvironmentSuffixTarget -from sqlmesh.core.dialect import parse_one, schema_ -from sqlmesh.core.engine_adapter import create_engine_adapter -from sqlmesh.core.environment import Environment -from sqlmesh.core.model import ( - FullKind, - IncrementalByTimeRangeKind, - ModelKindName, - Seed, - SeedKind, - SeedModel, - SqlModel, -) -from sqlmesh.core.model.definition import ExternalModel -from sqlmesh.core.snapshot import ( - Snapshot, - SnapshotChangeCategory, - SnapshotId, - SnapshotIntervals, - SnapshotTableCleanupTask, - missing_intervals, -) -from sqlmesh.core.state_sync import ( - CachingStateSync, - EngineAdapterStateSync, - cleanup_expired_views, -) -from sqlmesh.core.state_sync.base import ( - SCHEMA_VERSION, - SQLGLOT_VERSION, - PromotionResult, - Versions, -) -from sqlmesh.utils.date import now_timestamp, to_datetime, to_timestamp -from sqlmesh.utils.errors import SQLMeshError - -pytestmark = pytest.mark.slow - - -@pytest.fixture -def state_sync(duck_conn, tmp_path): - state_sync = EngineAdapterStateSync( - create_engine_adapter(lambda: duck_conn, "duckdb"), schema=c.SQLMESH, context_path=tmp_path - ) - state_sync.migrate(default_catalog=None) - return state_sync - - -@pytest.fixture -def snapshots(make_snapshot: t.Callable) -> t.List[Snapshot]: - return [ - make_snapshot( - SqlModel( - name="a", - query=parse_one("select 1, 
ds"), - ), - version="a", - ), - make_snapshot( - SqlModel( - name="b", - query=parse_one("select 2, ds"), - ), - version="b", - ), - ] - - -def promote_snapshots( - state_sync: EngineAdapterStateSync, - snapshots: t.List[Snapshot], - environment: str, - no_gaps: bool = False, - no_gaps_snapshot_names: t.Optional[t.Set[str]] = None, - environment_suffix_target: EnvironmentSuffixTarget = EnvironmentSuffixTarget.SCHEMA, - environment_catalog_mapping: t.Optional[t.Dict[re.Pattern, str]] = None, -) -> PromotionResult: - env = Environment.from_environment_catalog_mapping( - environment_catalog_mapping or {}, - name=environment, - suffix_target=environment_suffix_target, - snapshots=[snapshot.table_info for snapshot in snapshots], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="test_plan_id", - previous_plan_id="test_plan_id", - ) - return state_sync.promote( - env, no_gaps_snapshot_names=no_gaps_snapshot_names if no_gaps else set() - ) - - -def delete_versions(state_sync: EngineAdapterStateSync) -> None: - state_sync.engine_adapter.drop_table(state_sync.versions_table) - - -def test_push_snapshots( - state_sync: EngineAdapterStateSync, - make_snapshot: t.Callable, -) -> None: - snapshot_a = make_snapshot( - SqlModel( - name="a", - query=parse_one("select 1, ds"), - ) - ) - snapshot_b = make_snapshot( - SqlModel( - name="b", - query=parse_one("select 2, ds"), - ) - ) - - with pytest.raises( - SQLMeshError, - match=r".*has not been versioned.*", - ): - state_sync.push_snapshots([snapshot_a, snapshot_b]) - - snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot_b.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) - snapshot_b.version = "2" - - state_sync.push_snapshots([snapshot_a, snapshot_b]) - - assert state_sync.get_snapshots([snapshot_a.snapshot_id, snapshot_b.snapshot_id]) == { - snapshot_a.snapshot_id: snapshot_a, - snapshot_b.snapshot_id: snapshot_b, - } - - with pytest.raises( - SQLMeshError, - match=r".*already exists.*", - ): - 
state_sync.push_snapshots([snapshot_a]) - - with pytest.raises( - SQLMeshError, - match=r".*already exists.*", - ): - state_sync.push_snapshots([snapshot_a, snapshot_b]) - - # test serialization - state_sync.push_snapshots( - [ - make_snapshot( - SqlModel( - name="a", - kind=FullKind(), - query=parse_one( - """ - select 'x' + ' ' as y, - "z" + '\' as z, - """ - ), - ), - version="1", - ) - ] - ) - - -def test_duplicates(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable) -> None: - snapshot_a = make_snapshot( - SqlModel( - name="a", - query=parse_one("select 1, ds"), - ), - version="1", - ) - snapshot_b = make_snapshot( - SqlModel( - name="a", - query=parse_one("select 1, ds"), - ), - version="1", - ) - snapshot_c = make_snapshot( - SqlModel( - name="a", - query=parse_one("select 1, ds"), - ), - version="1", - ) - snapshot_b.updated_ts = snapshot_a.updated_ts + 1 - snapshot_c.updated_ts = 0 - state_sync.push_snapshots([snapshot_a]) - state_sync._push_snapshots([snapshot_a]) - state_sync._push_snapshots([snapshot_b]) - state_sync._push_snapshots([snapshot_c]) - assert ( - state_sync.get_snapshots([snapshot_a])[snapshot_a.snapshot_id].updated_ts - == snapshot_b.updated_ts - ) - - -def test_snapshots_exists(state_sync: EngineAdapterStateSync, snapshots: t.List[Snapshot]) -> None: - state_sync.push_snapshots(snapshots) - snapshot_ids = {snapshot.snapshot_id for snapshot in snapshots} - assert state_sync.snapshots_exist(snapshot_ids) == snapshot_ids - - -@pytest.fixture -def get_snapshot_intervals(state_sync) -> t.Callable[[Snapshot], t.Optional[SnapshotIntervals]]: - def _get_snapshot_intervals(snapshot: Snapshot) -> t.Optional[SnapshotIntervals]: - intervals = state_sync._get_snapshot_intervals([snapshot])[-1] - return intervals[0] if intervals else None - - return _get_snapshot_intervals - - -def test_add_interval( - state_sync: EngineAdapterStateSync, - make_snapshot: t.Callable, - get_snapshot_intervals: t.Callable, -) -> None: - snapshot = 
make_snapshot( - SqlModel( - name="a", - cron="@daily", - query=parse_one("select 1, ds"), - ), - version="a", - ) - - state_sync.push_snapshots([snapshot]) - - state_sync.add_interval(snapshot, "2020-01-01", "20200101") - assert get_snapshot_intervals(snapshot).intervals == [ - (to_timestamp("2020-01-01"), to_timestamp("2020-01-02")), - ] - - state_sync.add_interval(snapshot, "20200101", to_datetime("2020-01-04")) - assert get_snapshot_intervals(snapshot).intervals == [ - (to_timestamp("2020-01-01"), to_timestamp("2020-01-04")), - ] - - state_sync.add_interval(snapshot, to_datetime("2020-01-05"), "2020-01-10") - assert get_snapshot_intervals(snapshot).intervals == [ - (to_timestamp("2020-01-01"), to_timestamp("2020-01-04")), - (to_timestamp("2020-01-05"), to_timestamp("2020-01-11")), - ] - - snapshot.change_category = SnapshotChangeCategory.FORWARD_ONLY - state_sync.add_interval(snapshot, to_datetime("2020-01-16"), "2020-01-20", is_dev=True) - intervals = get_snapshot_intervals(snapshot) - assert intervals.intervals == [ - (to_timestamp("2020-01-01"), to_timestamp("2020-01-04")), - (to_timestamp("2020-01-05"), to_timestamp("2020-01-11")), - ] - assert intervals.dev_intervals == [ - (to_timestamp("2020-01-16"), to_timestamp("2020-01-21")), - ] - - -def test_add_interval_partial( - state_sync: EngineAdapterStateSync, - make_snapshot: t.Callable, - get_snapshot_intervals: t.Callable, -) -> None: - snapshot = make_snapshot( - SqlModel( - name="a", - cron="@daily", - query=parse_one("select 1, ds"), - ), - version="a", - ) - - state_sync.push_snapshots([snapshot]) - - state_sync.add_interval(snapshot, "2023-01-01", to_timestamp("2023-01-01") + 1000) - assert get_snapshot_intervals(snapshot) is None - - state_sync.add_interval(snapshot, "2023-01-01", to_timestamp("2023-01-02") + 1000) - assert get_snapshot_intervals(snapshot).intervals == [ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-02")), - ] - - -def test_remove_interval(state_sync: EngineAdapterStateSync, 
make_snapshot: t.Callable) -> None: - snapshot_a = make_snapshot( - SqlModel( - name="a", - cron="@daily", - query=parse_one("select 1, ds"), - ), - version="a", - ) - snapshot_b = make_snapshot( - SqlModel( - name="a", - cron="@daily", - query=parse_one("select 2::INT, '2022-01-01'::TEXT AS ds"), - ), - version="a", - ) - state_sync.push_snapshots([snapshot_a, snapshot_b]) - state_sync.add_interval(snapshot_a, "2020-01-01", "2020-01-10") - state_sync.add_interval(snapshot_b, "2020-01-11", "2020-01-30") - - state_sync.remove_interval( - [(snapshot_a, snapshot_a.inclusive_exclusive("2020-01-15", "2020-01-17"))], - remove_shared_versions=True, - ) - - snapshots = state_sync.get_snapshots([snapshot_a, snapshot_b]) - - assert snapshots[snapshot_a.snapshot_id].intervals == [ - (to_timestamp("2020-01-01"), to_timestamp("2020-01-15")), - (to_timestamp("2020-01-18"), to_timestamp("2020-01-31")), - ] - assert snapshots[snapshot_b.snapshot_id].intervals == [ - (to_timestamp("2020-01-01"), to_timestamp("2020-01-15")), - (to_timestamp("2020-01-18"), to_timestamp("2020-01-31")), - ] - - -def test_remove_interval_missing_snapshot( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -) -> None: - snapshot_a = make_snapshot( - SqlModel( - name="a", - cron="@daily", - query=parse_one("select 1, ds"), - ), - version="a", - ) - snapshot_b = make_snapshot( - SqlModel( - name="a", - cron="@daily", - query=parse_one("select 2::INT, '2022-01-01'::TEXT AS ds"), - ), - version="a", - ) - state_sync.push_snapshots([snapshot_a, snapshot_b]) - state_sync.add_interval(snapshot_a, "2020-01-01", "2020-01-10") - state_sync.add_interval(snapshot_b, "2020-01-11", "2020-01-30") - # Remove snapshot b in order to test the scenario where it is missing - state_sync.delete_snapshots([snapshot_b.snapshot_id]) - - snapshots = state_sync.get_snapshots([snapshot_a, snapshot_b]) - assert len(snapshots) == 1 - assert snapshots[snapshot_a.snapshot_id].intervals == [ - (to_timestamp("2020-01-01"), 
to_timestamp("2020-01-31")), - ] - - state_sync.remove_interval( - [(snapshot_a, snapshot_a.inclusive_exclusive("2020-01-15", "2020-01-17"))], - remove_shared_versions=True, - ) - - snapshots = state_sync.get_snapshots([snapshot_a, snapshot_b]) - assert len(snapshots) == 1 - assert snapshots[snapshot_a.snapshot_id].intervals == [ - (to_timestamp("2020-01-01"), to_timestamp("2020-01-15")), - (to_timestamp("2020-01-18"), to_timestamp("2020-01-31")), - ] - - -def test_refresh_snapshot_intervals( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -) -> None: - snapshot = make_snapshot( - SqlModel( - name="a", - cron="@daily", - query=parse_one("select 1, ds"), - ), - version="a", - ) - - state_sync.push_snapshots([snapshot]) - state_sync.add_interval(snapshot, "2023-01-01", "2023-01-01") - assert not snapshot.intervals - - state_sync.refresh_snapshot_intervals([snapshot]) - assert snapshot.intervals == [(to_timestamp("2023-01-01"), to_timestamp("2023-01-02"))] - - -def test_get_snapshot_intervals( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable, get_snapshot_intervals -) -> None: - state_sync.SNAPSHOT_BATCH_SIZE = 1 - - snapshot_a = make_snapshot( - SqlModel( - name="a", - cron="@daily", - query=parse_one("select 1, ds"), - ), - version="a", - ) - - state_sync.push_snapshots([snapshot_a]) - state_sync.add_interval(snapshot_a, "2020-01-01", "2020-01-01") - - snapshot_b = make_snapshot( - SqlModel( - name="a", - cron="@daily", - query=parse_one("select 2, ds"), - ), - version="a", - ) - state_sync.push_snapshots([snapshot_b]) - - snapshot_c = make_snapshot( - SqlModel( - name="c", - cron="@daily", - query=parse_one("select 3, ds"), - ), - version="c", - ) - state_sync.push_snapshots([snapshot_c]) - state_sync.add_interval(snapshot_c, "2020-01-03", "2020-01-03") - - a_intervals = get_snapshot_intervals(snapshot_a) - c_intervals = get_snapshot_intervals(snapshot_c) - assert a_intervals.intervals == [(to_timestamp("2020-01-01"), 
to_timestamp("2020-01-02"))] - assert c_intervals.intervals == [(to_timestamp("2020-01-03"), to_timestamp("2020-01-04"))] - - -def test_compact_intervals( - state_sync: EngineAdapterStateSync, - make_snapshot: t.Callable, - get_snapshot_intervals: t.Callable, -) -> None: - snapshot = make_snapshot( - SqlModel( - name="a", - cron="@daily", - query=parse_one("select 1, ds"), - ), - version="a", - ) - - state_sync.push_snapshots([snapshot]) - - state_sync.add_interval(snapshot, "2020-01-01", "2020-01-10") - state_sync.add_interval(snapshot, "2020-01-11", "2020-01-15") - state_sync.remove_interval( - [(snapshot, snapshot.inclusive_exclusive("2020-01-05", "2020-01-12"))] - ) - state_sync.add_interval(snapshot, "2020-01-12", "2020-01-16") - state_sync.remove_interval( - [(snapshot, snapshot.inclusive_exclusive("2020-01-14", "2020-01-16"))] - ) - - expected_intervals = [ - (to_timestamp("2020-01-01"), to_timestamp("2020-01-05")), - (to_timestamp("2020-01-12"), to_timestamp("2020-01-14")), - ] - - assert get_snapshot_intervals(snapshot).intervals == expected_intervals - - state_sync.compact_intervals() - assert get_snapshot_intervals(snapshot).intervals == expected_intervals - - # Make sure compaction is idempotent. 
- state_sync.compact_intervals() - assert get_snapshot_intervals(snapshot).intervals == expected_intervals - - -def test_compact_intervals_delete_batches( - state_sync: EngineAdapterStateSync, - make_snapshot: t.Callable, - mocker: MockerFixture, -) -> None: - snapshot = make_snapshot( - SqlModel( - name="a", - cron="@daily", - query=parse_one("select 1, ds"), - ), - version="a", - ) - - delete_from_mock = mocker.patch.object(state_sync.engine_adapter, "delete_from") - state_sync.INTERVAL_BATCH_SIZE = 2 - - state_sync.push_snapshots([snapshot]) - - state_sync.add_interval(snapshot, "2020-01-01", "2020-01-11") - state_sync.add_interval(snapshot, "2020-01-01", "2020-01-12") - state_sync.add_interval(snapshot, "2020-01-01", "2020-01-13") - state_sync.add_interval(snapshot, "2020-01-01", "2020-01-14") - state_sync.add_interval(snapshot, "2020-01-01", "2020-01-15") - - state_sync.compact_intervals() - - delete_from_mock.assert_has_calls([call(state_sync.intervals_table, mocker.ANY)] * 3) - - -def test_promote_snapshots(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): - snapshot_a = make_snapshot( - SqlModel( - name="a", - query=parse_one("select 1, ds"), - ), - ) - snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) - - snapshot_b_old = make_snapshot( - SqlModel( - name="b", - kind=FullKind(), - query=parse_one("select 2 from a"), - ), - nodes={"a": snapshot_a.model}, - ) - snapshot_b_old.categorize_as(SnapshotChangeCategory.BREAKING) - - snapshot_b = make_snapshot( - SqlModel( - name="b", - kind=FullKind(), - query=parse_one("select * from a"), - ), - nodes={"a": snapshot_a.model}, - ) - snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) - - snapshot_c = make_snapshot( - SqlModel( - name="c", - query=parse_one("select 3, ds"), - ), - ) - snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING) - - with pytest.raises( - SQLMeshError, - match=r"Missing snapshots.*", - ): - promote_snapshots(state_sync, [snapshot_a], "prod") - - 
state_sync.push_snapshots([snapshot_a, snapshot_b_old, snapshot_b, snapshot_c]) - - promotion_result = promote_snapshots(state_sync, [snapshot_a, snapshot_b_old], "prod") - - assert set(promotion_result.added) == set([snapshot_a.table_info, snapshot_b_old.table_info]) - assert not promotion_result.removed - assert not promotion_result.removed_environment_naming_info - promotion_result = promote_snapshots( - state_sync, - [snapshot_a, snapshot_b_old, snapshot_c], - "prod", - ) - assert set(promotion_result.added) == { - snapshot_a.table_info, - snapshot_b_old.table_info, - snapshot_c.table_info, - } - assert not promotion_result.removed - assert not promotion_result.removed_environment_naming_info - - prev_snapshot_b_old_updated_ts = snapshot_b_old.updated_ts - prev_snapshot_c_updated_ts = snapshot_c.updated_ts - - promotion_result = promote_snapshots( - state_sync, - [snapshot_a, snapshot_b], - "prod", - ) - assert set(promotion_result.added) == {snapshot_a.table_info, snapshot_b.table_info} - assert set(promotion_result.removed) == {snapshot_c.table_info} - assert promotion_result.removed_environment_naming_info - assert promotion_result.removed_environment_naming_info.suffix_target.is_schema - assert ( - state_sync.get_snapshots([snapshot_c])[snapshot_c.snapshot_id].updated_ts - > prev_snapshot_c_updated_ts - ) - assert ( - state_sync.get_snapshots([snapshot_b_old])[snapshot_b_old.snapshot_id].updated_ts - > prev_snapshot_b_old_updated_ts - ) - - snapshot_d = make_snapshot( - SqlModel( - name="a", - query=parse_one("select 2, ds"), - ), - ) - snapshot_d.categorize_as(SnapshotChangeCategory.BREAKING) - - state_sync.push_snapshots([snapshot_d]) - promotion_result = promote_snapshots(state_sync, [snapshot_d], "prod") - assert set(promotion_result.added) == {snapshot_d.table_info} - assert set(promotion_result.removed) == {snapshot_b.table_info} - assert promotion_result.removed_environment_naming_info - assert 
promotion_result.removed_environment_naming_info.suffix_target.is_schema - - -def test_promote_snapshots_suffix_change( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -): - snapshot_a = make_snapshot( - SqlModel( - name="a", - query=parse_one("select 1, ds"), - ), - ) - snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) - - snapshot_b = make_snapshot( - SqlModel( - name="b", - kind=FullKind(), - query=parse_one("select * from a"), - ), - nodes={"a": snapshot_a.model}, - ) - snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) - - state_sync.push_snapshots([snapshot_a, snapshot_b]) - - promotion_result = promote_snapshots( - state_sync, - [snapshot_a, snapshot_b], - "prod", - environment_suffix_target=EnvironmentSuffixTarget.TABLE, - ) - - assert set(promotion_result.added) == {snapshot_a.table_info, snapshot_b.table_info} - assert not promotion_result.removed - assert not promotion_result.removed_environment_naming_info - - snapshot_c = make_snapshot( - SqlModel( - name="c", - query=parse_one("select 3, ds"), - ), - ) - snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING) - - state_sync.push_snapshots([snapshot_c]) - - promotion_result = promote_snapshots( - state_sync, - [snapshot_b, snapshot_c], - "prod", - environment_suffix_target=EnvironmentSuffixTarget.SCHEMA, - ) - - # We still only add the snapshots that are included in the promotion - assert set(promotion_result.added) == {snapshot_b.table_info, snapshot_c.table_info} - # B does not get removed because the suffix target change doesn't affect it due to running in prod. 
- assert set(promotion_result.removed) == {snapshot_a.table_info} - # Make sure the removed suffix target is correctly seen as table - assert promotion_result.removed_environment_naming_info is not None - assert promotion_result.removed_environment_naming_info.suffix_target.is_table - - promotion_result = promote_snapshots( - state_sync, - [snapshot_b, snapshot_c], - "dev", - environment_suffix_target=EnvironmentSuffixTarget.SCHEMA, - ) - - # We still only add the snapshots that are included in the promotion - assert set(promotion_result.added) == {snapshot_b.table_info, snapshot_c.table_info} - assert len(promotion_result.removed) == 0 - assert promotion_result.removed_environment_naming_info is None - - promotion_result = promote_snapshots( - state_sync, - [snapshot_b, snapshot_c], - "dev", - environment_suffix_target=EnvironmentSuffixTarget.TABLE, - ) - - # All snapshots are promoted due to suffix target change - assert set(promotion_result.added) == { - snapshot_b.table_info, - snapshot_c.table_info, - } - # All snapshots are removed due to suffix target change - assert set(promotion_result.removed) == { - snapshot_b.table_info, - snapshot_c.table_info, - } - assert promotion_result.removed_environment_naming_info is not None - assert promotion_result.removed_environment_naming_info.suffix_target.is_schema - - -def test_promote_snapshots_catalog_name_override_change( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -): - snapshot_a = make_snapshot( - SqlModel( - name="catalog1.schema.a", - query=parse_one("select 1, ds"), - ), - ) - snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) - - snapshot_b = make_snapshot( - SqlModel( - name="catalog1.schema.b", - kind=FullKind(), - query=parse_one("select * from a"), - ), - nodes={"a": snapshot_a.model}, - ) - snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) - - snapshot_c = make_snapshot( - SqlModel( - name="catalog2.schema.c", - kind=FullKind(), - query=parse_one("select * from a"), - 
), - nodes={"a": snapshot_a.model}, - ) - snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING) - - state_sync.push_snapshots([snapshot_a, snapshot_b, snapshot_c]) - - promotion_result = promote_snapshots( - state_sync, - [snapshot_a, snapshot_b, snapshot_c], - "prod", - environment_catalog_mapping={}, - ) - - assert set(promotion_result.added) == { - snapshot_a.table_info, - snapshot_b.table_info, - snapshot_c.table_info, - } - assert not promotion_result.removed - assert not promotion_result.removed_environment_naming_info - - snapshot_d = make_snapshot( - SqlModel( - name="catalog1.schema.d", - query=parse_one("select 3, ds"), - ), - ) - snapshot_d.categorize_as(SnapshotChangeCategory.BREAKING) - - state_sync.push_snapshots([snapshot_d]) - - promotion_result = promote_snapshots( - state_sync, - [snapshot_b, snapshot_c, snapshot_d], - "prod", - environment_catalog_mapping={ - re.compile("^prod$"): "catalog1", - }, - ) - - # We still only add the snapshots that are included in the promotion which means removing A - assert set(promotion_result.added) == { - snapshot_b.table_info, - snapshot_c.table_info, - snapshot_d.table_info, - } - # C is removed because of the catalog change. The new one will be created in the new catalog. - # B is not removed because it's catalog did not change and therefore removing would actually result - # in dropping what we just added. - # A is removed because it was explicitly removed from the promotion. 
- assert set(promotion_result.removed) == {snapshot_a.table_info, snapshot_c.table_info} - # Make sure the removed suffix target correctly has the old catalog name set - assert promotion_result.removed_environment_naming_info - assert promotion_result.removed_environment_naming_info.catalog_name_override is None - - promotion_result = promote_snapshots( - state_sync, - [snapshot_b, snapshot_c, snapshot_d], - "prod", - environment_catalog_mapping={ - re.compile("^prod$"): "catalog2", - }, - ) - - # All are added since their catalog was changed - assert set(promotion_result.added) == { - snapshot_b.table_info, - snapshot_c.table_info, - snapshot_d.table_info, - } - # All are removed since there were moved from their old catalog location - # Note that C has a catalog set in the model definition of `catalog2` which is what we moved to so you might think - # it shouldn't be removed, but its actual catalog was `catalog1` because of the previous override so therefore - # it should be removed from `catalog1`. 
- assert set(promotion_result.removed) == { - snapshot_b.table_info, - snapshot_c.table_info, - snapshot_d.table_info, - } - # Make sure the removed suffix target correctly has the old catalog name set - assert promotion_result.removed_environment_naming_info - assert promotion_result.removed_environment_naming_info.catalog_name_override == "catalog1" - - -def test_promote_snapshots_parent_plan_id_mismatch( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -): - snapshot = make_snapshot( - SqlModel( - name="a", - query=parse_one("select 1, ds"), - ), - ) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - - state_sync.push_snapshots([snapshot]) - promote_snapshots(state_sync, [snapshot], "prod") - - new_environment = Environment( - name="prod", - snapshots=[snapshot.table_info], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="new_plan_id", - previous_plan_id="test_plan_id", - ) - - stale_new_environment = Environment( - name="prod", - snapshots=[snapshot.table_info], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="stale_new_plan_id", - previous_plan_id="test_plan_id", - ) - - state_sync.promote(new_environment) - - with pytest.raises( - SQLMeshError, - match=r".*is no longer valid.*", - ): - state_sync.promote(stale_new_environment) - - -def test_promote_snapshots_no_gaps(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): - model = SqlModel( - name="a", - query=parse_one("select 1, ds"), - kind=IncrementalByTimeRangeKind(time_column="ds"), - start="2022-01-01", - ) - - snapshot = make_snapshot(model, version="a") - snapshot.change_category = SnapshotChangeCategory.BREAKING - state_sync.push_snapshots([snapshot]) - state_sync.add_interval(snapshot, "2022-01-01", "2022-01-02") - promote_snapshots(state_sync, [snapshot], "prod", no_gaps=True) - - new_snapshot_same_version = make_snapshot(model, version="a") - new_snapshot_same_version.change_category = SnapshotChangeCategory.INDIRECT_NON_BREAKING - 
new_snapshot_same_version.fingerprint = snapshot.fingerprint.copy( - update={"data_hash": "new_snapshot_same_version"} - ) - state_sync.push_snapshots([new_snapshot_same_version]) - state_sync.add_interval(new_snapshot_same_version, "2022-01-03", "2022-01-03") - promote_snapshots(state_sync, [new_snapshot_same_version], "prod", no_gaps=True) - - new_snapshot_missing_interval = make_snapshot(model, version="b") - new_snapshot_missing_interval.change_category = SnapshotChangeCategory.BREAKING - new_snapshot_missing_interval.fingerprint = snapshot.fingerprint.copy( - update={"data_hash": "new_snapshot_missing_interval"} - ) - state_sync.push_snapshots([new_snapshot_missing_interval]) - state_sync.add_interval(new_snapshot_missing_interval, "2022-01-01", "2022-01-02") - with pytest.raises( - SQLMeshError, - match=r"Detected gaps in snapshot.*", - ): - promote_snapshots(state_sync, [new_snapshot_missing_interval], "prod", no_gaps=True) - - new_snapshot_same_interval = make_snapshot(model, version="c") - new_snapshot_same_interval.change_category = SnapshotChangeCategory.BREAKING - new_snapshot_same_interval.fingerprint = snapshot.fingerprint.copy( - update={"data_hash": "new_snapshot_same_interval"} - ) - state_sync.push_snapshots([new_snapshot_same_interval]) - state_sync.add_interval(new_snapshot_same_interval, "2022-01-01", "2022-01-03") - promote_snapshots(state_sync, [new_snapshot_same_interval], "prod", no_gaps=True) - - # We should skip the gaps check if the snapshot is not representative. 
- promote_snapshots( - state_sync, - [new_snapshot_missing_interval], - "prod", - no_gaps=True, - no_gaps_snapshot_names=set(), - ) - - -@freeze_time("2023-01-08 16:00:00") -def test_promote_snapshots_no_gaps_lookback( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -): - model = SqlModel( - name="a", - cron="@hourly", - query=parse_one("select 1, ds"), - kind=IncrementalByTimeRangeKind(time_column="ds", lookback=1), - start="2023-01-01", - ) - - snapshot = make_snapshot(model, version="a") - snapshot.change_category = SnapshotChangeCategory.BREAKING - state_sync.push_snapshots([snapshot]) - state_sync.add_interval(snapshot, "2023-01-01", "2023-01-08 15:00:00") - promote_snapshots(state_sync, [snapshot], "prod", no_gaps=True) - - assert now_timestamp() == to_timestamp("2023-01-08 16:00:00") - - new_snapshot_same_version = make_snapshot(model, version="b") - new_snapshot_same_version.change_category = SnapshotChangeCategory.BREAKING - new_snapshot_same_version.fingerprint = snapshot.fingerprint.copy( - update={"data_hash": "new_snapshot_same_version"} - ) - state_sync.push_snapshots([new_snapshot_same_version]) - state_sync.add_interval(new_snapshot_same_version, "2023-01-01", "2023-01-08 15:00:00") - promote_snapshots(state_sync, [new_snapshot_same_version], "prod", no_gaps=True) - - -def test_finalize(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): - snapshot_a = make_snapshot( - SqlModel( - name="a", - query=parse_one("select 1, ds"), - ), - ) - snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) - - state_sync.push_snapshots([snapshot_a]) - promote_snapshots(state_sync, [snapshot_a], "prod") - - env = state_sync.get_environment("prod") - assert env - state_sync.finalize(env) - - env = state_sync.get_environment("prod") - assert env - assert env.finalized_ts is not None - - env.plan_id = "different_plan_id" - with pytest.raises( - SQLMeshError, - match=r"Plan 'different_plan_id' is no longer valid for the target environment 
'prod'.*", - ): - state_sync.finalize(env) - - -def test_start_date_gap(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): - model = SqlModel( - name="a", - query=parse_one("select 1, ds"), - start="2022-01-01", - kind=IncrementalByTimeRangeKind(time_column="ds"), - cron="@daily", - ) - - snapshot = make_snapshot(model, version="a") - snapshot.change_category = SnapshotChangeCategory.BREAKING - state_sync.push_snapshots([snapshot]) - state_sync.add_interval(snapshot, "2022-01-01", "2022-01-03") - promote_snapshots(state_sync, [snapshot], "prod") - - model = SqlModel( - name="a", - query=parse_one("select 1, ds"), - start="2022-01-02", - kind=IncrementalByTimeRangeKind(time_column="ds"), - cron="@daily", - ) - - snapshot = make_snapshot(model, version="b") - snapshot.change_category = SnapshotChangeCategory.BREAKING - state_sync.push_snapshots([snapshot]) - state_sync.add_interval(snapshot, "2022-01-03", "2022-01-04") - with pytest.raises( - SQLMeshError, - match=r"Detected gaps in snapshot.*", - ): - promote_snapshots(state_sync, [snapshot], "prod", no_gaps=True) - - state_sync.add_interval(snapshot, "2022-01-02", "2022-01-03") - promote_snapshots(state_sync, [snapshot], "prod", no_gaps=True) - - -def test_delete_expired_environments(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): - snapshot = make_snapshot( - SqlModel( - name="a", - query=parse_one("select a, ds"), - ), - ) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - - state_sync.push_snapshots([snapshot]) - - now_ts = now_timestamp() - - env_a = Environment( - name="test_environment_a", - snapshots=[snapshot.table_info], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="test_plan_id", - previous_plan_id="test_plan_id", - expiration_ts=now_ts - 1000, - ) - state_sync.promote(env_a) - - env_b = env_a.copy(update={"name": "test_environment_b", "expiration_ts": now_ts + 1000}) - state_sync.promote(env_b) - - assert state_sync.get_environment(env_a.name) == env_a - 
assert state_sync.get_environment(env_b.name) == env_b - - deleted_environments = state_sync.delete_expired_environments() - assert deleted_environments == [env_a] - - assert state_sync.get_environment(env_a.name) is None - assert state_sync.get_environment(env_b.name) == env_b - - -def test_delete_expired_snapshots(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): - now_ts = now_timestamp() - - snapshot = make_snapshot( - SqlModel( - name="a", - query=parse_one("select a, ds"), - ), - ) - snapshot.ttl = "in 10 seconds" - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot.updated_ts = now_ts - 15000 - - new_snapshot = make_snapshot( - SqlModel( - name="a", - query=parse_one("select a, b, ds"), - ), - ) - new_snapshot.ttl = "in 10 seconds" - new_snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) - new_snapshot.version = snapshot.version - new_snapshot.updated_ts = now_ts - 11000 - - state_sync.push_snapshots([snapshot, new_snapshot]) - assert set(state_sync.get_snapshots(None)) == {snapshot.snapshot_id, new_snapshot.snapshot_id} - - assert state_sync.delete_expired_snapshots() == [ - SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True), - SnapshotTableCleanupTask(snapshot=new_snapshot.table_info, dev_table_only=False), - ] - - assert not state_sync.get_snapshots(None) - - -def test_delete_expired_snapshots_seed( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -): - now_ts = now_timestamp() - - snapshot = make_snapshot( - SeedModel( - name="a", - kind=SeedKind(path="./path/to/seed"), - seed=Seed(content="header\n1\n2"), - column_hashes={"header": "hash"}, - depends_on=set(), - ), - ) - snapshot.ttl = "in 10 seconds" - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot.updated_ts = now_ts - 15000 - - state_sync.push_snapshots([snapshot]) - assert set(state_sync.get_snapshots(None)) == {snapshot.snapshot_id} - - assert state_sync.delete_expired_snapshots() == [ - 
SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=False), - ] - - assert not state_sync.get_snapshots(None) - - -def test_delete_expired_snapshots_batching( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -): - state_sync.SNAPSHOT_BATCH_SIZE = 1 - now_ts = now_timestamp() - - snapshot_a = make_snapshot( - SqlModel( - name="a", - query=parse_one("select a, ds"), - ), - ) - snapshot_a.ttl = "in 10 seconds" - snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot_a.updated_ts = now_ts - 15000 - - snapshot_b = make_snapshot( - SqlModel( - name="b", - query=parse_one("select a, b, ds"), - ), - ) - snapshot_b.ttl = "in 10 seconds" - snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot_b.updated_ts = now_ts - 11000 - - state_sync.push_snapshots([snapshot_a, snapshot_b]) - assert set(state_sync.get_snapshots(None)) == {snapshot_a.snapshot_id, snapshot_b.snapshot_id} - - assert state_sync.delete_expired_snapshots() == [ - SnapshotTableCleanupTask(snapshot=snapshot_a.table_info, dev_table_only=False), - SnapshotTableCleanupTask(snapshot=snapshot_b.table_info, dev_table_only=False), - ] - - assert not state_sync.get_snapshots(None) - - -def test_delete_expired_snapshots_promoted( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable, mocker: MockerFixture -): - now_ts = now_timestamp() - - snapshot = make_snapshot( - SqlModel( - name="a", - query=parse_one("select a, ds"), - ), - ) - snapshot.ttl = "in 10 seconds" - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot.updated_ts = now_ts - 15000 - - state_sync.push_snapshots([snapshot]) - - env = Environment( - name="test_environment", - snapshots=[snapshot.table_info], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="test_plan_id", - previous_plan_id="test_plan_id", - ) - state_sync.promote(env) - - assert not state_sync.delete_expired_snapshots() - assert set(state_sync.get_snapshots(None)) == {snapshot.snapshot_id} - - 
env.snapshots = [] - state_sync.promote(env) - - now_timestamp_mock = mocker.patch("sqlmesh.core.state_sync.engine_adapter.now_timestamp") - now_timestamp_mock.return_value = now_timestamp() + 11000 - - assert state_sync.delete_expired_snapshots() == [ - SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=False) - ] - assert not state_sync.get_snapshots(None) - - -def test_delete_expired_snapshots_dev_table_cleanup_only( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -): - now_ts = now_timestamp() - - snapshot = make_snapshot( - SqlModel( - name="a", - query=parse_one("select a, ds"), - ), - ) - snapshot.ttl = "in 10 seconds" - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot.updated_ts = now_ts - 15000 - - new_snapshot = make_snapshot( - SqlModel( - name="a", - query=parse_one("select a, b, ds"), - ), - ) - new_snapshot.ttl = "in 10 seconds" - new_snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) - new_snapshot.version = snapshot.version - new_snapshot.updated_ts = now_ts - 5000 - - state_sync.push_snapshots([snapshot, new_snapshot]) - assert set(state_sync.get_snapshots(None)) == {snapshot.snapshot_id, new_snapshot.snapshot_id} - - assert state_sync.delete_expired_snapshots() == [ - SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=True) - ] - - assert set(state_sync.get_snapshots(None)) == {new_snapshot.snapshot_id} - - -def test_delete_expired_snapshots_shared_dev_table( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -): - now_ts = now_timestamp() - - snapshot = make_snapshot( - SqlModel( - name="a", - query=parse_one("select a, ds"), - ), - ) - snapshot.ttl = "in 10 seconds" - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot.updated_ts = now_ts - 15000 - - new_snapshot = make_snapshot( - SqlModel( - name="a", - query=parse_one("select a, b, ds"), - ), - ) - new_snapshot.ttl = "in 10 seconds" - 
new_snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) - new_snapshot.version = snapshot.version - new_snapshot.temp_version = snapshot.temp_version_get_or_generate() - new_snapshot.updated_ts = now_ts - 5000 - - state_sync.push_snapshots([snapshot, new_snapshot]) - assert set(state_sync.get_snapshots(None)) == {snapshot.snapshot_id, new_snapshot.snapshot_id} - - assert not state_sync.delete_expired_snapshots() # No dev table cleanup - assert set(state_sync.get_snapshots(None)) == {new_snapshot.snapshot_id} - - -def test_delete_expired_snapshots_ignore_ttl( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -): - snapshot_a = make_snapshot( - SqlModel( - name="a", - query=parse_one("select a, ds"), - ) - ) - snapshot_a.categorize_as(SnapshotChangeCategory.NON_BREAKING) - - snapshot_b = make_snapshot( - SqlModel( - name="a", - query=parse_one("select a, b, ds"), - ), - version="2", - ) - snapshot_b.categorize_as(SnapshotChangeCategory.NON_BREAKING) - - snapshot_c = make_snapshot( - SqlModel( - name="a", - query=parse_one("select a, b, c, ds"), - ), - ) - snapshot_c.categorize_as(SnapshotChangeCategory.NON_BREAKING) - - state_sync.push_snapshots([snapshot_a, snapshot_b, snapshot_c]) - - env = Environment( - name="test_environment", - snapshots=[snapshot_a.table_info, snapshot_b.table_info], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="test_plan_id", - previous_plan_id="test_plan_id", - ) - state_sync.promote(env) - - # default TTL = 1 week, nothing to clean up yet if we take TTL into account - assert not state_sync.delete_expired_snapshots() - - # If we ignore TTL, only snapshot_c should get cleaned up because snapshot_a and snapshot_b are part of an environment - assert snapshot_a.table_info != snapshot_b.table_info != snapshot_c.table_info - assert state_sync.delete_expired_snapshots(ignore_ttl=True) == [ - SnapshotTableCleanupTask(snapshot=snapshot_c.table_info, dev_table_only=False) - ] - - -def 
test_environment_start_as_timestamp( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -): - snapshot = make_snapshot( - SqlModel( - name="a", - query=parse_one("select a, ds"), - ), - ) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - - state_sync.push_snapshots([snapshot]) - - now_ts = now_timestamp() - - env = Environment( - name="test_environment_a", - snapshots=[snapshot.table_info], - start_at=now_ts, - end_at=None, - plan_id="test_plan_id", - previous_plan_id="test_plan_id", - expiration_ts=now_ts - 1000, - ) - state_sync.promote(env) - - stored_env = state_sync.get_environment(env.name) - assert stored_env - assert stored_env.start_at == to_datetime(now_ts).replace(tzinfo=None).isoformat(sep=" ") - - -def test_unpause_snapshots(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): - snapshot = make_snapshot( - SqlModel( - name="test_snapshot", - query=parse_one("select 1, ds"), - cron="@daily", - ), - ) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot.version = "a" - - assert not snapshot.unpaused_ts - state_sync.push_snapshots([snapshot]) - - unpaused_dt = "2022-01-01" - state_sync.unpause_snapshots([snapshot], unpaused_dt) - - actual_snapshot = state_sync.get_snapshots([snapshot])[snapshot.snapshot_id] - assert actual_snapshot.unpaused_ts - assert actual_snapshot.unpaused_ts == to_timestamp(unpaused_dt) - - new_snapshot = make_snapshot( - SqlModel(name="test_snapshot", query=parse_one("select 2, ds"), cron="@daily") - ) - new_snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) - new_snapshot.version = "a" - - assert not new_snapshot.unpaused_ts - state_sync.push_snapshots([new_snapshot]) - state_sync.unpause_snapshots([new_snapshot], unpaused_dt) - - actual_snapshots = state_sync.get_snapshots([snapshot, new_snapshot]) - assert not actual_snapshots[snapshot.snapshot_id].unpaused_ts - assert actual_snapshots[new_snapshot.snapshot_id].unpaused_ts == to_timestamp(unpaused_dt) - - assert 
actual_snapshots[snapshot.snapshot_id].unrestorable - assert not actual_snapshots[new_snapshot.snapshot_id].unrestorable - - -def test_unrestorable_snapshot(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): - snapshot = make_snapshot( - SqlModel( - name="test_snapshot", - query=parse_one("select 1, ds"), - cron="@daily", - ), - ) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot.version = "a" - - assert not snapshot.unpaused_ts - state_sync.push_snapshots([snapshot]) - - unpaused_dt = "2022-01-01" - state_sync.unpause_snapshots([snapshot], unpaused_dt) - - actual_snapshot = state_sync.get_snapshots([snapshot])[snapshot.snapshot_id] - assert actual_snapshot.unpaused_ts - assert actual_snapshot.unpaused_ts == to_timestamp(unpaused_dt) - - new_indirect_non_breaking_snapshot = make_snapshot( - SqlModel(name="test_snapshot", query=parse_one("select 2, ds"), cron="@daily") - ) - new_indirect_non_breaking_snapshot.categorize_as(SnapshotChangeCategory.INDIRECT_NON_BREAKING) - new_indirect_non_breaking_snapshot.version = "a" - - assert not new_indirect_non_breaking_snapshot.unpaused_ts - state_sync.push_snapshots([new_indirect_non_breaking_snapshot]) - state_sync.unpause_snapshots([new_indirect_non_breaking_snapshot], unpaused_dt) - - actual_snapshots = state_sync.get_snapshots([snapshot, new_indirect_non_breaking_snapshot]) - assert not actual_snapshots[snapshot.snapshot_id].unpaused_ts - assert actual_snapshots[ - new_indirect_non_breaking_snapshot.snapshot_id - ].unpaused_ts == to_timestamp(unpaused_dt) - - assert not actual_snapshots[snapshot.snapshot_id].unrestorable - assert not actual_snapshots[new_indirect_non_breaking_snapshot.snapshot_id].unrestorable - - new_forward_only_snapshot = make_snapshot( - SqlModel(name="test_snapshot", query=parse_one("select 3, ds"), cron="@daily") - ) - new_forward_only_snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) - new_forward_only_snapshot.version = "a" - - assert not 
new_forward_only_snapshot.unpaused_ts - state_sync.push_snapshots([new_forward_only_snapshot]) - state_sync.unpause_snapshots([new_forward_only_snapshot], unpaused_dt) - - actual_snapshots = state_sync.get_snapshots( - [snapshot, new_indirect_non_breaking_snapshot, new_forward_only_snapshot] - ) - assert not actual_snapshots[snapshot.snapshot_id].unpaused_ts - assert not actual_snapshots[new_indirect_non_breaking_snapshot.snapshot_id].unpaused_ts - assert actual_snapshots[new_forward_only_snapshot.snapshot_id].unpaused_ts == to_timestamp( - unpaused_dt - ) - - assert actual_snapshots[snapshot.snapshot_id].unrestorable - assert actual_snapshots[new_indirect_non_breaking_snapshot.snapshot_id].unrestorable - assert not actual_snapshots[new_forward_only_snapshot.snapshot_id].unrestorable - - -def test_unpause_snapshots_remove_intervals( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -): - snapshot = make_snapshot( - SqlModel( - name="test_snapshot", - query=parse_one("select 1, ds"), - cron="@daily", - ), - version="a", - ) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot.version = "a" - state_sync.push_snapshots([snapshot]) - state_sync.add_interval(snapshot, "2023-01-01", "2023-01-05") - - new_snapshot = make_snapshot( - SqlModel(name="test_snapshot", query=parse_one("select 2, ds"), cron="@daily"), - version="a", - ) - new_snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) - new_snapshot.version = "a" - new_snapshot.effective_from = "2023-01-03" - state_sync.push_snapshots([new_snapshot]) - state_sync.add_interval(snapshot, "2023-01-06", "2023-01-06") - state_sync.unpause_snapshots([new_snapshot], "2023-01-06") - - actual_snapshots = state_sync.get_snapshots([snapshot, new_snapshot]) - assert actual_snapshots[new_snapshot.snapshot_id].intervals == [ - (to_timestamp("2023-01-01"), to_timestamp("2023-01-03")), - ] - assert actual_snapshots[snapshot.snapshot_id].intervals == [ - (to_timestamp("2023-01-01"), 
to_timestamp("2023-01-03")), - ] - - -def test_version_schema(state_sync: EngineAdapterStateSync, tmp_path) -> None: - from sqlmesh import __version__ as SQLMESH_VERSION - - # fresh install should not raise - assert state_sync.get_versions() == Versions( - schema_version=SCHEMA_VERSION, - sqlglot_version=SQLGLOT_VERSION, - sqlmesh_version=SQLMESH_VERSION, - ) - - # Start with a clean slate. - state_sync = EngineAdapterStateSync( - create_engine_adapter(duckdb.connect, "duckdb"), schema=c.SQLMESH, context_path=tmp_path - ) - - with pytest.raises( - SQLMeshError, - match=rf"SQLMesh \(local\) is using version '{SCHEMA_VERSION}' which is ahead of '0'", - ): - state_sync.get_versions() - - state_sync.migrate(default_catalog=None) - - # migration version is behind, always raise - state_sync._update_versions(schema_version=SCHEMA_VERSION + 1) - error = ( - rf"SQLMesh \(local\) is using version '{SCHEMA_VERSION}' which is behind '{SCHEMA_VERSION + 1}' \(remote\). " - rf"""Please upgrade SQLMesh \('pip install --upgrade "sqlmesh=={SQLMESH_VERSION}"' command\).""" - ) - - with pytest.raises(SQLMeshError, match=error): - state_sync.get_versions() - - # should no longer raise - state_sync.get_versions(validate=False) - - # migration version is ahead, only raise when validate is true - state_sync._update_versions(schema_version=SCHEMA_VERSION - 1) - with pytest.raises( - SQLMeshError, - match=rf"SQLMesh \(local\) is using version '{SCHEMA_VERSION}' which is ahead of '{SCHEMA_VERSION - 1}'", - ): - state_sync.get_versions() - state_sync.get_versions(validate=False) - - -def test_version_sqlmesh(state_sync: EngineAdapterStateSync) -> None: - from sqlmesh import __version__ as SQLMESH_VERSION - from sqlmesh import __version_tuple__ as SQLMESH_VERSION_TUPLE - - # patch version sqlmesh doesn't matter - major, minor, patch, *_ = SQLMESH_VERSION_TUPLE - new_patch = ( - f"dev{int(patch[3:]) + 1}" # type: ignore - if isinstance(patch, str) and patch.startswith("dev") - else 
f"{int(patch) + 1}" - ) - sqlmesh_version_patch_bump = f"{major}.{minor}.{new_patch}" - state_sync._update_versions(sqlmesh_version=sqlmesh_version_patch_bump) - state_sync.get_versions(validate=False) - - # sqlmesh version is behind - sqlmesh_version_minor_bump = f"{major}.{int(minor) + 1}.{patch}" - error = ( - rf"SQLMesh \(local\) is using version '{SQLMESH_VERSION}' which is behind '{sqlmesh_version_minor_bump}' \(remote\). " - rf"""Please upgrade SQLMesh \('pip install --upgrade "sqlmesh=={sqlmesh_version_minor_bump}"' command\).""" - ) - state_sync._update_versions(sqlmesh_version=sqlmesh_version_minor_bump) - with pytest.raises(SQLMeshError, match=error): - state_sync.get_versions() - state_sync.get_versions(validate=False) - - # sqlmesh version is ahead - sqlmesh_version_minor_decrease = f"{major}.{int(minor) - 1}.{patch}" - error = rf"SQLMesh \(local\) is using version '{SQLMESH_VERSION}' which is ahead of '{sqlmesh_version_minor_decrease}'" - state_sync._update_versions(sqlmesh_version=sqlmesh_version_minor_decrease) - with pytest.raises(SQLMeshError, match=error): - state_sync.get_versions() - state_sync.get_versions(validate=False) - - -def test_version_sqlglot(state_sync: EngineAdapterStateSync) -> None: - # patch version sqlglot doesn't matter - major, minor, patch, *_ = SQLGLOT_VERSION.split(".") - sqlglot_version = f"{major}.{minor}.{int(patch) + 1}" - state_sync._update_versions(sqlglot_version=sqlglot_version) - state_sync.get_versions(validate=False) - - # sqlglot version is behind - sqlglot_version = f"{major}.{int(minor) + 1}.{patch}" - error = ( - rf"SQLGlot \(local\) is using version '{SQLGLOT_VERSION}' which is behind '{sqlglot_version}' \(remote\). 
" - rf"""Please upgrade SQLGlot \('pip install --upgrade "sqlglot=={sqlglot_version}"' command\).""" - ) - state_sync._update_versions(sqlglot_version=sqlglot_version) - with pytest.raises(SQLMeshError, match=error): - state_sync.get_versions() - state_sync.get_versions(validate=False) - - # sqlglot version is ahead - sqlglot_version = f"{major}.{int(minor) - 1}.{patch}" - error = rf"SQLGlot \(local\) is using version '{SQLGLOT_VERSION}' which is ahead of '{sqlglot_version}'" - state_sync._update_versions(sqlglot_version=sqlglot_version) - with pytest.raises(SQLMeshError, match=error): - state_sync.get_versions() - state_sync.get_versions(validate=False) - - -def test_empty_versions() -> None: - for empty_versions in ( - Versions(), - Versions(schema_version=None, sqlglot_version=None, sqlmesh_version=None), - ): - assert empty_versions.schema_version == 0 - assert empty_versions.sqlglot_version == "0.0.0" - assert empty_versions.sqlmesh_version == "0.0.0" - - -def test_migrate(state_sync: EngineAdapterStateSync, mocker: MockerFixture, tmp_path) -> None: - from sqlmesh import __version__ as SQLMESH_VERSION - - migrate_rows_mock = mocker.patch("sqlmesh.core.state_sync.EngineAdapterStateSync._migrate_rows") - backup_state_mock = mocker.patch("sqlmesh.core.state_sync.EngineAdapterStateSync._backup_state") - state_sync.migrate(default_catalog=None) - migrate_rows_mock.assert_not_called() - backup_state_mock.assert_not_called() - - # Start with a clean slate. 
- state_sync = EngineAdapterStateSync( - create_engine_adapter(duckdb.connect, "duckdb"), schema=c.SQLMESH, context_path=tmp_path - ) - - state_sync.migrate(default_catalog=None) - migrate_rows_mock.assert_called_once() - backup_state_mock.assert_called_once() - assert state_sync.get_versions() == Versions( - schema_version=SCHEMA_VERSION, - sqlglot_version=SQLGLOT_VERSION, - sqlmesh_version=SQLMESH_VERSION, - ) - - assert ( - state_sync.engine_adapter.fetchone( - "SELECT COUNT(*) FROM sqlmesh._snapshots WHERE expiration_ts IS NULL" - )[0] - == 0 - ) - - -def test_rollback(state_sync: EngineAdapterStateSync, mocker: MockerFixture) -> None: - with pytest.raises( - SQLMeshError, - match="There are no prior migrations to roll back to.", - ): - state_sync.rollback() - - restore_table_spy = mocker.spy(state_sync, "_restore_table") - state_sync._backup_state() - - state_sync.rollback() - calls = {(a.sql(), b.sql()) for (a, b), _ in restore_table_spy.call_args_list} - assert ( - f"{state_sync.schema}._snapshots", - f"{state_sync.schema}._snapshots_backup", - ) in calls - assert ( - f"{state_sync.schema}._environments", - f"{state_sync.schema}._environments_backup", - ) in calls - assert ( - f"{state_sync.schema}._versions", - f"{state_sync.schema}._versions_backup", - ) in calls - assert not state_sync.engine_adapter.table_exists(f"{state_sync.schema}._snapshots_backup") - assert not state_sync.engine_adapter.table_exists(f"{state_sync.schema}._environments_backup") - assert not state_sync.engine_adapter.table_exists(f"{state_sync.schema}._versions_backup") - - -def test_first_migration_failure(duck_conn, mocker: MockerFixture, tmp_path) -> None: - state_sync = EngineAdapterStateSync( - create_engine_adapter(lambda: duck_conn, "duckdb"), schema=c.SQLMESH, context_path=tmp_path - ) - mocker.patch.object(state_sync, "_migrate_rows", side_effect=Exception("mocked error")) - with pytest.raises( - SQLMeshError, - match="SQLMesh migration failed.", - ): - 
state_sync.migrate(default_catalog=None) - assert not state_sync.engine_adapter.table_exists(state_sync.snapshots_table) - assert not state_sync.engine_adapter.table_exists(state_sync.environments_table) - assert not state_sync.engine_adapter.table_exists(state_sync.versions_table) - assert not state_sync.engine_adapter.table_exists(state_sync.intervals_table) - - -def test_migrate_rows(state_sync: EngineAdapterStateSync, mocker: MockerFixture) -> None: - delete_versions(state_sync) - - state_sync.engine_adapter.replace_query( - "sqlmesh._snapshots", - pd.read_json("tests/fixtures/migrations/snapshots.json"), - columns_to_types={ - "name": exp.DataType.build("text"), - "identifier": exp.DataType.build("text"), - "version": exp.DataType.build("text"), - "snapshot": exp.DataType.build("text"), - }, - ) - - state_sync.engine_adapter.replace_query( - "sqlmesh._environments", - pd.read_json("tests/fixtures/migrations/environments.json"), - columns_to_types={ - "name": exp.DataType.build("text"), - "snapshots": exp.DataType.build("text"), - "start_at": exp.DataType.build("text"), - "end_at": exp.DataType.build("text"), - "plan_id": exp.DataType.build("text"), - "previous_plan_id": exp.DataType.build("text"), - "expiration_ts": exp.DataType.build("bigint"), - }, - ) - - state_sync.engine_adapter.drop_table("sqlmesh._seeds") - - old_snapshots = state_sync.engine_adapter.fetchdf("select * from sqlmesh._snapshots") - old_environments = state_sync.engine_adapter.fetchdf("select * from sqlmesh._environments") - - state_sync.migrate(default_catalog=None, skip_backup=True) - - new_snapshots = state_sync.engine_adapter.fetchdf("select * from sqlmesh._snapshots") - new_environments = state_sync.engine_adapter.fetchdf("select * from sqlmesh._environments") - - assert len(old_snapshots) * 2 == len(new_snapshots) - assert len(old_environments) == len(new_environments) - - start = "2023-01-01" - end = "2023-01-07" - - assert not missing_intervals( - state_sync.get_snapshots( - 
t.cast(Environment, state_sync.get_environment("staging")).snapshots - ).values(), - start=start, - end=end, - ) - - dev_snapshots = state_sync.get_snapshots( - t.cast(Environment, state_sync.get_environment("dev")).snapshots - ).values() - - assert all(s.migrated for s in dev_snapshots) - assert all(s.change_category is not None for s in dev_snapshots) - - assert not missing_intervals(dev_snapshots, start=start, end=end) - - assert not missing_intervals(dev_snapshots, start="2023-01-08", end="2023-01-10") == 8 - - for s in state_sync.get_snapshots(None).values(): - if not s.is_symbolic: - assert s.intervals - - customer_revenue_by_day = new_snapshots.loc[ - new_snapshots["name"] == '"sushi"."customer_revenue_by_day"' - ].iloc[0] - assert json.loads(customer_revenue_by_day["snapshot"])["node"]["query"].startswith( - "JINJA_QUERY_BEGIN" - ) - - -def test_backup_state(state_sync: EngineAdapterStateSync, mocker: MockerFixture) -> None: - state_sync.engine_adapter.replace_query( - "sqlmesh._snapshots", - pd.read_json("tests/fixtures/migrations/snapshots.json"), - columns_to_types={ - "name": exp.DataType.build("text"), - "identifier": exp.DataType.build("text"), - "version": exp.DataType.build("text"), - "snapshot": exp.DataType.build("text"), - }, - ) - - state_sync._backup_state() - pd.testing.assert_frame_equal( - state_sync.engine_adapter.fetchdf("select * from sqlmesh._snapshots"), - state_sync.engine_adapter.fetchdf("select * from sqlmesh._snapshots_backup"), - ) - - -def test_restore_snapshots_table(state_sync: EngineAdapterStateSync) -> None: - snapshot_columns_to_types = { - "name": exp.DataType.build("text"), - "identifier": exp.DataType.build("text"), - "version": exp.DataType.build("text"), - "snapshot": exp.DataType.build("text"), - } - state_sync.engine_adapter.replace_query( - "sqlmesh._snapshots", - pd.read_json("tests/fixtures/migrations/snapshots.json"), - columns_to_types=snapshot_columns_to_types, - ) - - old_snapshots = 
state_sync.engine_adapter.fetchdf("select * from sqlmesh._snapshots") - old_snapshots_count = state_sync.engine_adapter.fetchone( - "select count(*) from sqlmesh._snapshots" - ) - assert old_snapshots_count == (12,) - state_sync._backup_state() - - state_sync.engine_adapter.delete_from("sqlmesh._snapshots", "TRUE") - snapshots_count = state_sync.engine_adapter.fetchone("select count(*) from sqlmesh._snapshots") - assert snapshots_count == (0,) - state_sync._restore_table( - table_name="sqlmesh._snapshots", - backup_table_name="sqlmesh._snapshots_backup", - ) - - new_snapshots = state_sync.engine_adapter.fetchdf("select * from sqlmesh._snapshots") - pd.testing.assert_frame_equal( - old_snapshots, - new_snapshots, - ) - - -def test_seed_hydration( - state_sync: EngineAdapterStateSync, - make_snapshot: t.Callable, -): - snapshot = make_snapshot( - SeedModel( - name="a", - kind=SeedKind(path="./path/to/seed"), - seed=Seed(content="header\n1\n2"), - column_hashes={"header": "hash"}, - depends_on=set(), - ) - ) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - - state_sync.push_snapshots([snapshot]) - - assert snapshot.model.is_hydrated - assert snapshot.model.seed.content == "header\n1\n2" - - stored_snapshot = state_sync.get_snapshots([snapshot.snapshot_id])[snapshot.snapshot_id] - assert isinstance(stored_snapshot.model, SeedModel) - assert not stored_snapshot.model.is_hydrated - assert stored_snapshot.model.seed.content == "" - - -def test_nodes_exist(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): - snapshot = make_snapshot( - SqlModel( - name="a", - query=parse_one("select 1, ds"), - ) - ) - - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - - assert not state_sync.nodes_exist([snapshot.name]) - - state_sync.push_snapshots([snapshot]) - - assert state_sync.nodes_exist([snapshot.name]) == {snapshot.name} - - -def test_invalidate_environment(state_sync: EngineAdapterStateSync, make_snapshot: t.Callable): - snapshot = make_snapshot( 
- SqlModel( - name="a", - query=parse_one("select a, ds"), - ), - ) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - - state_sync.push_snapshots([snapshot]) - - original_expiration_ts = now_timestamp() + 100000 - - env = Environment( - name="test_environment", - snapshots=[snapshot.table_info], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="test_plan_id", - previous_plan_id="test_plan_id", - expiration_ts=original_expiration_ts, - ) - state_sync.promote(env) - - assert not state_sync.delete_expired_environments() - state_sync.invalidate_environment("test_environment") - - stored_env = state_sync.get_environment("test_environment") - assert stored_env - assert stored_env.expiration_ts and stored_env.expiration_ts < original_expiration_ts - - deleted_environments = state_sync.delete_expired_environments() - assert len(deleted_environments) == 1 - assert deleted_environments[0].name == "test_environment" - - with pytest.raises(SQLMeshError, match="Cannot invalidate the production environment."): - state_sync.invalidate_environment("prod") - - -def test_cache(state_sync, make_snapshot, mocker): - cache = CachingStateSync(state_sync, ttl=10) - - snapshot = make_snapshot( - SqlModel( - name="a", - query=parse_one("select 'a', 'ds'"), - ), - ) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - - now_timestamp = mocker.patch("sqlmesh.core.state_sync.cache.now_timestamp") - now_timestamp.return_value = to_timestamp("2023-01-01 00:00:00") - - # prime the cache with a cached missing snapshot - assert not cache.get_snapshots([snapshot.snapshot_id]) - - # item is cached and shouldn't hit state sync - with patch.object(state_sync, "get_snapshots") as mock: - assert not cache.get_snapshots([snapshot.snapshot_id]) - mock.assert_not_called() - - # prime the cache with a real snapshot - cache.push_snapshots([snapshot]) - assert cache.get_snapshots([snapshot.snapshot_id]) == {snapshot.snapshot_id: snapshot} - - # cache hit - with 
patch.object(state_sync, "get_snapshots") as mock: - assert cache.get_snapshots([snapshot.snapshot_id]) == {snapshot.snapshot_id: snapshot} - mock.assert_not_called() - - # clear the cache by adding intervals - cache.add_interval(snapshot, "2020-01-01", "2020-01-01") - with patch.object(state_sync, "get_snapshots") as mock: - assert not cache.get_snapshots([snapshot.snapshot_id]) - mock.assert_called() - - # clear the cache by removing intervals - cache.remove_interval([(snapshot, snapshot.inclusive_exclusive("2020-01-01", "2020-01-01"))]) - - # prime the cache - assert cache.get_snapshots([snapshot.snapshot_id]) == {snapshot.snapshot_id: snapshot} - - # cache hit half way - now_timestamp.return_value = to_timestamp("2023-01-01 00:00:05") - with patch.object(state_sync, "get_snapshots") as mock: - assert cache.get_snapshots([snapshot.snapshot_id]) - mock.assert_not_called() - - # no cache hit - now_timestamp.return_value = to_timestamp("2023-01-01 00:00:11") - with patch.object(state_sync, "get_snapshots") as mock: - assert not cache.get_snapshots([snapshot.snapshot_id]) - mock.assert_called() - - -def test_cleanup_expired_views( - mocker: MockerFixture, state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -): - adapter = mocker.MagicMock() - adapter.dialect = None - snapshot_a = make_snapshot(SqlModel(name="catalog.schema.a", query=parse_one("select 1, ds"))) - snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot_b = make_snapshot(SqlModel(name="catalog.schema.b", query=parse_one("select 1, ds"))) - snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) - # Make sure that we don't drop schemas from external models - snapshot_external_model = make_snapshot( - ExternalModel(name="catalog.external_schema.external_table", kind=ModelKindName.EXTERNAL) - ) - snapshot_external_model.categorize_as(SnapshotChangeCategory.BREAKING) - schema_environment = Environment( - name="test_environment", - suffix_target=EnvironmentSuffixTarget.SCHEMA, - 
snapshots=[ - snapshot_a.table_info, - snapshot_b.table_info, - snapshot_external_model.table_info, - ], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="test_plan_id", - previous_plan_id="test_plan_id", - catalog_name_override="catalog_override", - ) - snapshot_c = make_snapshot(SqlModel(name="catalog.schema.c", query=parse_one("select 1, ds"))) - snapshot_c.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot_d = make_snapshot(SqlModel(name="catalog.schema.d", query=parse_one("select 1, ds"))) - snapshot_d.categorize_as(SnapshotChangeCategory.BREAKING) - table_environment = Environment( - name="test_environment", - suffix_target=EnvironmentSuffixTarget.TABLE, - snapshots=[ - snapshot_c.table_info, - snapshot_d.table_info, - snapshot_external_model.table_info, - ], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="test_plan_id", - previous_plan_id="test_plan_id", - catalog_name_override="catalog_override", - ) - cleanup_expired_views(adapter, [schema_environment, table_environment]) - assert adapter.drop_schema.called - assert adapter.drop_view.called - assert adapter.drop_schema.call_args_list == [ - call( - schema_("schema__test_environment", "catalog_override"), - ignore_if_not_exists=True, - cascade=True, - ) - ] - assert sorted(adapter.drop_view.call_args_list) == [ - call("catalog_override.schema.c__test_environment", ignore_if_not_exists=True), - call("catalog_override.schema.d__test_environment", ignore_if_not_exists=True), - ] - - -def test_max_interval_end_for_environment( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -) -> None: - snapshot_a = make_snapshot( - SqlModel( - name="a", - cron="@daily", - query=parse_one("select 1, ds"), - ), - ) - snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) - - snapshot_b = make_snapshot( - SqlModel( - name="b", - cron="@daily", - query=parse_one("select 2, ds"), - ), - ) - snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) - - 
state_sync.push_snapshots([snapshot_a, snapshot_b]) - - state_sync.add_interval(snapshot_a, "2023-01-01", "2023-01-01") - state_sync.add_interval(snapshot_a, "2023-01-02", "2023-01-02") - state_sync.add_interval(snapshot_b, "2023-01-01", "2023-01-01") - - environment_name = "test_max_interval_end_for_environment" - - assert state_sync.max_interval_end_for_environment(environment_name) is None - - state_sync.promote( - Environment( - name=environment_name, - snapshots=[snapshot_a.table_info, snapshot_b.table_info], - start_at="2023-01-01", - end_at="2023-01-03", - plan_id="test_plan_id", - previous_finalized_snapshots=[snapshot_b.table_info], - ) - ) - - assert state_sync.max_interval_end_for_environment(environment_name) == to_timestamp( - "2023-01-03" - ) - - -def test_max_interval_end_for_environment_ensure_finalized_snapshots( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -) -> None: - snapshot_a = make_snapshot( - SqlModel( - name="a", - cron="@daily", - query=parse_one("select 1, ds"), - ), - ) - snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) - - snapshot_b = make_snapshot( - SqlModel( - name="b", - cron="@daily", - query=parse_one("select 2, ds"), - ), - ) - snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) - - state_sync.push_snapshots([snapshot_a, snapshot_b]) - - state_sync.add_interval(snapshot_a, "2023-01-01", "2023-01-01") - state_sync.add_interval(snapshot_a, "2023-01-02", "2023-01-02") - state_sync.add_interval(snapshot_b, "2023-01-01", "2023-01-01") - - environment_name = "test_max_interval_end_for_environment" - - assert ( - state_sync.max_interval_end_for_environment( - environment_name, ensure_finalized_snapshots=True - ) - is None - ) - - state_sync.promote( - Environment( - name=environment_name, - snapshots=[snapshot_a.table_info, snapshot_b.table_info], - start_at="2023-01-01", - end_at="2023-01-03", - plan_id="test_plan_id", - previous_finalized_snapshots=[snapshot_b.table_info], - ) - ) - - assert 
state_sync.max_interval_end_for_environment( - environment_name, ensure_finalized_snapshots=True - ) == to_timestamp("2023-01-02") - - -def test_greatest_common_interval_end( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -) -> None: - snapshot_a = make_snapshot( - SqlModel( - name="a", - cron="@daily", - query=parse_one("select 1, ds"), - ), - ) - snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) - - snapshot_b = make_snapshot( - SqlModel( - name="b", - cron="@daily", - query=parse_one("select 2, ds"), - ), - ) - snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) - - state_sync.push_snapshots([snapshot_a, snapshot_b]) - - state_sync.add_interval(snapshot_a, "2023-01-01", "2023-01-01") - state_sync.add_interval(snapshot_a, "2023-01-02", "2023-01-02") - state_sync.add_interval(snapshot_a, "2023-01-03", "2023-01-03") - state_sync.add_interval(snapshot_b, "2023-01-01", "2023-01-01") - state_sync.add_interval(snapshot_b, "2023-01-02", "2023-01-02") - - environment_name = "test_max_interval_end_for_environment" - - assert state_sync.greatest_common_interval_end(environment_name, {snapshot_a.name}) is None - - state_sync.promote( - Environment( - name=environment_name, - snapshots=[snapshot_a.table_info, snapshot_b.table_info], - start_at="2023-01-01", - end_at="2023-01-03", - plan_id="test_plan_id", - previous_finalized_snapshots=[snapshot_b.table_info], - ) - ) - - assert state_sync.greatest_common_interval_end( - environment_name, {snapshot_a.name} - ) == to_timestamp("2023-01-04") - - assert state_sync.greatest_common_interval_end( - environment_name, {snapshot_b.name} - ) == to_timestamp("2023-01-03") - - assert state_sync.greatest_common_interval_end( - environment_name, {snapshot_a.name, snapshot_b.name} - ) == to_timestamp("2023-01-03") - - assert state_sync.greatest_common_interval_end(environment_name, {"missing"}) == to_timestamp( - "2023-01-03" - ) - assert state_sync.greatest_common_interval_end(environment_name, set()) is 
None - - -def test_greatest_common_interval_end_ensure_finalized_snapshots( - state_sync: EngineAdapterStateSync, make_snapshot: t.Callable -) -> None: - snapshot_a = make_snapshot( - SqlModel( - name="a", - cron="@daily", - query=parse_one("select 1, ds"), - ), - ) - snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) - - snapshot_b = make_snapshot( - SqlModel( - name="b", - cron="@daily", - query=parse_one("select 2, ds"), - ), - ) - snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) - - state_sync.push_snapshots([snapshot_a, snapshot_b]) - - state_sync.add_interval(snapshot_a, "2023-01-01", "2023-01-01") - state_sync.add_interval(snapshot_a, "2023-01-02", "2023-01-02") - state_sync.add_interval(snapshot_a, "2023-01-03", "2023-01-03") - state_sync.add_interval(snapshot_b, "2023-01-01", "2023-01-01") - state_sync.add_interval(snapshot_b, "2023-01-02", "2023-01-02") - - environment_name = "test_max_interval_end_for_environment" - - assert state_sync.greatest_common_interval_end(environment_name, {snapshot_a.name}) is None - - state_sync.promote( - Environment( - name=environment_name, - snapshots=[snapshot_a.table_info, snapshot_b.table_info], - start_at="2023-01-01", - end_at="2023-01-03", - plan_id="test_plan_id", - previous_finalized_snapshots=[snapshot_b.table_info], - ) - ) - - assert state_sync.greatest_common_interval_end( - environment_name, {snapshot_a.name}, ensure_finalized_snapshots=True - ) == to_timestamp("2023-01-03") - - assert state_sync.greatest_common_interval_end( - environment_name, {snapshot_b.name}, ensure_finalized_snapshots=True - ) == to_timestamp("2023-01-03") - - assert state_sync.greatest_common_interval_end( - environment_name, {snapshot_a.name, snapshot_b.name}, ensure_finalized_snapshots=True - ) == to_timestamp("2023-01-03") - - assert state_sync.greatest_common_interval_end( - environment_name, {"missing"}, ensure_finalized_snapshots=True - ) == to_timestamp("2023-01-03") - assert 
state_sync.greatest_common_interval_end(environment_name, set()) is None - - -def test_get_snapshots(mocker): - mock = mocker.MagicMock() - cache = CachingStateSync(mock) - cache.get_snapshots([]) - mock.get_snapshots.assert_not_called() - - -def test_snapshot_batching(state_sync, mocker, make_snapshot): - mock = mocker.Mock() - - state_sync.SNAPSHOT_BATCH_SIZE = 2 - state_sync.engine_adapter = mock - - snapshot_a = make_snapshot(SqlModel(name="a", query=parse_one("select 1")), "1") - snapshot_b = make_snapshot(SqlModel(name="a", query=parse_one("select 2")), "2") - snapshot_c = make_snapshot(SqlModel(name="a", query=parse_one("select 3")), "3") - - state_sync.delete_snapshots( - ( - snapshot_a, - snapshot_b, - snapshot_c, - ) - ) - calls = mock.delete_from.call_args_list - assert mock.delete_from.call_args_list == [ - call( - exp.to_table("sqlmesh._snapshots"), - where=parse_one( - f"(name, identifier) in (('\"a\"', '{snapshot_b.identifier}'), ('\"a\"', '{snapshot_a.identifier}'))" - ), - ), - call( - exp.to_table("sqlmesh._snapshots"), - where=parse_one(f"(name, identifier) in (('\"a\"', '{snapshot_c.identifier}'))"), - ), - ] - - mock.fetchall.side_effect = [ - [ - [ - make_snapshot( - SqlModel(name="a", query=parse_one("select 1")), - ).json(), - "a", - "1", - "1", - ], - [ - make_snapshot( - SqlModel(name="a", query=parse_one("select 2")), - ).json(), - "a", - "2", - "2", - ], - ], - [ - [ - make_snapshot( - SqlModel(name="a", query=parse_one("select 3")), - ).json(), - "a", - "3", - "3", - ], - ], - ] - - snapshots = state_sync._get_snapshots( - ( - SnapshotId(name="a", identifier="1"), - SnapshotId(name="a", identifier="2"), - SnapshotId(name="a", identifier="3"), - ), - hydrate_intervals=False, - ) - assert len(snapshots) == 3 - calls = mock.fetchall.call_args_list - assert len(calls) == 2 - - -def test_seed_model_metadata_update( - state_sync: EngineAdapterStateSync, - make_snapshot: t.Callable, -): - model = SeedModel( - name="a", - 
kind=SeedKind(path="./path/to/seed"), - seed=Seed(content="header\n1\n2"), - column_hashes={"header": "hash"}, - depends_on=set(), - ) - snapshot = make_snapshot(model) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - - state_sync.push_snapshots([snapshot]) - - model = model.copy(update={"owner": "jen"}) - new_snapshot = make_snapshot(model) - new_snapshot.previous_versions = snapshot.all_versions - new_snapshot.categorize_as(SnapshotChangeCategory.FORWARD_ONLY) - - assert snapshot.fingerprint != new_snapshot.fingerprint - assert snapshot.version == new_snapshot.version - - state_sync.push_snapshots([new_snapshot]) - assert len(state_sync.get_snapshots([new_snapshot, snapshot])) == 2 diff --git a/tests/core/test_table_diff.py b/tests/core/test_table_diff.py index fa4175be86..c2e293e4c2 100644 --- a/tests/core/test_table_diff.py +++ b/tests/core/test_table_diff.py @@ -1,19 +1,62 @@ import pytest from pytest_mock.plugin import MockerFixture -import pandas as pd +import pandas as pd # noqa: TID253 from sqlglot import exp from sqlmesh.core import dialect as d +import typing as t +from io import StringIO +from rich.console import Console +from sqlmesh.core.console import TerminalConsole from sqlmesh.core.context import Context -from sqlmesh.core.config import AutoCategorizationMode, CategorizerConfig +from sqlmesh.core.config import AutoCategorizationMode, CategorizerConfig, DuckDBConnectionConfig from sqlmesh.core.model import SqlModel, load_sql_based_model +from sqlmesh.core.model.common import ParsableSql +from sqlmesh.core.table_diff import TableDiff, SchemaDiff +import numpy as np # noqa: TID253 +from sqlmesh.utils.errors import SQLMeshError +from sqlmesh.utils.rich import strip_ansi_codes +pytestmark = pytest.mark.slow -@pytest.mark.slow -def test_data_diff(sushi_context_fixed_date): + +def create_test_console() -> t.Tuple[StringIO, TerminalConsole]: + """Creates a console and buffer for validating console output.""" + console_output = StringIO() + 
console = Console(file=console_output, force_terminal=True) + terminal_console = TerminalConsole(console=console) + return console_output, terminal_console + + +def capture_console_output(method_name: str, **kwargs) -> str: + """Factory function to invoke and capture output a TerminalConsole method. + + Args: + method_name: Name of the TerminalConsole method to call + **kwargs: Arguments to pass to the method + + Returns: + The captured output as a string + """ + console_output, terminal_console = create_test_console() + try: + method = getattr(terminal_console, method_name) + method(**kwargs) + return console_output.getvalue() + finally: + console_output.close() + + +def test_data_diff(sushi_context_fixed_date, capsys, caplog): model = sushi_context_fixed_date.models['"memory"."sushi"."customer_revenue_by_day"'] - model.query.select(exp.cast("'1'", "VARCHAR").as_("modified_col"), "1 AS y", copy=False) - sushi_context_fixed_date.upsert_model(model) + sushi_context_fixed_date.upsert_model( + model, + query_=ParsableSql( + sql=model.query.select(exp.cast("'1'", "VARCHAR").as_("modified_col"), "1 AS y").sql( + model.dialect + ) + ), + ) sushi_context_fixed_date.plan( "source_dev", @@ -70,8 +113,13 @@ def test_data_diff(sushi_context_fixed_date): source="source_dev", target="target_dev", on=exp.condition("s.customer_id = t.customer_id AND s.event_date = t.event_date"), - model_or_snapshot="sushi.customer_revenue_by_day", - ) + select_models={"sushi.customer_revenue_by_day"}, + )[0] + + # verify queries were actually logged to the log file, this helps immensely with debugging + console_output = capsys.readouterr() + assert "__sqlmesh_join_key" not in console_output # they should not go to the console + assert "__sqlmesh_join_key" in caplog.text schema_diff = diff.schema_diff() assert schema_diff.added == [("z", exp.DataType.build("int"))] @@ -96,8 +144,7 @@ def test_data_diff(sushi_context_fixed_date): assert row_diff.t_sample.shape == (1, 6) -@pytest.mark.slow -def 
test_data_diff_decimals(sushi_context_fixed_date): +def test_data_diff_decimals_on_float(sushi_context_fixed_date): engine_adapter = sushi_context_fixed_date.engine_adapter engine_adapter.ctas( @@ -115,7 +162,7 @@ def test_data_diff_decimals(sushi_context_fixed_date): pd.DataFrame( { "key": [1, 2, 3], - "value": [1.0, 2.0, 3.1234], + "value": [1.0, 2.0, 3.1234321], } ), ) @@ -124,7 +171,7 @@ def test_data_diff_decimals(sushi_context_fixed_date): source="table_diff_source", target="table_diff_target", on=["key"], - ) + )[0] assert diff.row_diff().full_match_count == 3 assert diff.row_diff().partial_match_count == 0 @@ -133,12 +180,99 @@ def test_data_diff_decimals(sushi_context_fixed_date): target="table_diff_target", on=["key"], decimals=4, + )[0] + + row_diff = diff.row_diff() + joined_sample_columns = row_diff.joined_sample.columns + assert row_diff.full_match_count == 2 + assert row_diff.partial_match_count == 1 + assert "s__value" in joined_sample_columns + assert "t__value" in joined_sample_columns + + table_diff = TableDiff( + adapter=engine_adapter, + source="table_diff_source", + target="table_diff_target", + source_alias="dev", + target_alias="prod", + on=["key"], + decimals=4, ) - assert diff.row_diff().full_match_count == 2 - assert diff.row_diff().partial_match_count == 1 + + aliased_joined_sample = table_diff.row_diff().joined_sample.columns + assert "DEV__value" in aliased_joined_sample + assert "PROD__value" in aliased_joined_sample + + output = capture_console_output("show_row_diff", row_diff=table_diff.row_diff()) + + # Expected output with box-drawings + expected_output = r""" +Row Counts: +├── FULL MATCH: 2 rows (66.67%) +└── PARTIAL MATCH: 1 rows (33.33%) + +COMMON ROWS column comparison stats: + pct_match +value 66.666667 + + +COMMON ROWS sample data differences: +Column: value +┏━━━━━┳━━━━━━━━┳━━━━━━━━┓ +┃ key ┃ DEV ┃ PROD ┃ +┡━━━━━╇━━━━━━━━╇━━━━━━━━┩ +│ 3.0 │ 3.1233 │ 3.1234 │ +└─────┴────────┴────────┘ +""" + + stripped_output = 
strip_ansi_codes(output) + stripped_expected = expected_output.strip() + assert stripped_output == stripped_expected + + +def test_data_diff_decimals_on_numeric(): + engine_adapter = DuckDBConnectionConfig().create_engine_adapter() + + columns_to_types = { + "key": exp.DataType.build("int"), + "value": exp.DataType.build("decimal(10,5)"), + } + + engine_adapter.create_table("src", columns_to_types) + engine_adapter.create_table("target", columns_to_types) + + src_records = [ + (1, "25.12344"), + (2, "25.1234"), + (3, "25.124"), + (4, "25.14"), + (5, "25.4"), + ] + + target_records = [ + (1, "25.12343"), + (2, "25.1233"), + (3, "25.123"), + (4, "25.13"), + (5, "25.3"), + ] + + src_df = pd.DataFrame(data=src_records, columns=columns_to_types.keys()) + target_df = pd.DataFrame(data=target_records, columns=columns_to_types.keys()) + + engine_adapter.insert_append("src", src_df) + engine_adapter.insert_append("target", target_df) + + for decimals in range(5, 0, -1): + table_diff = TableDiff( + adapter=engine_adapter, source="src", target="target", on=["key"], decimals=decimals + ) + diff = table_diff.row_diff() + + assert diff.full_match_count == 5 - decimals + assert diff.partial_match_count + diff.full_match_count == 5 -@pytest.mark.slow def test_grain_check(sushi_context_fixed_date): expressions = d.parse( """ @@ -201,26 +335,33 @@ def test_grain_check(sushi_context_fixed_date): source="source_dev", target="target_dev", on=["'key_1'", "key_2"], - model_or_snapshot="SUSHI.GRAIN_ITEMS", + select_models={"memory.sushi*"}, skip_grain_check=False, - ) + )[0] row_diff = diff.row_diff() + assert row_diff.source_count == 7 + assert row_diff.target_count == 10 assert row_diff.full_match_count == 7 - assert row_diff.full_match_pct == 93.33 - assert row_diff.s_only_count == 2 - assert row_diff.t_only_count == 5 - assert row_diff.stats["join_count"] == 4 - assert row_diff.stats["null_grain_count"] == 4 - assert row_diff.stats["s_count"] != row_diff.stats["distinct_count_s"] + 
assert row_diff.partial_match_count == 0 + assert row_diff.s_only_count == 0 + assert row_diff.t_only_count == 3 + assert row_diff.full_match_pct == 82.35 + assert row_diff.partial_match_pct == 0 + assert row_diff.s_only_pct == 0 + assert row_diff.t_only_pct == 17.65 + assert row_diff.stats["join_count"] == 7 + assert ( + row_diff.stats["null_grain_count"] == 4 + ) # null grain currently (2025-07-24) means "any key column is null" as opposed to "all key columns are null" assert row_diff.stats["distinct_count_s"] == 7 - assert row_diff.stats["t_count"] != row_diff.stats["distinct_count_t"] assert row_diff.stats["distinct_count_t"] == 10 - assert row_diff.s_sample.shape == (0, 3) - assert row_diff.t_sample.shape == (3, 3) + assert row_diff.stats["s_count"] == row_diff.stats["distinct_count_s"] + assert row_diff.stats["t_count"] == row_diff.stats["distinct_count_t"] + assert row_diff.s_sample.shape == (row_diff.s_only_count, 3) + assert row_diff.t_sample.shape == (row_diff.t_only_count, 3) -@pytest.mark.slow def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture): engine_adapter = sushi_context_fixed_date.engine_adapter @@ -246,9 +387,17 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture) ), ) - query_sql = 'CREATE TABLE IF NOT EXISTS "sqlmesh_temp"."__temp_diff_abcdefgh" AS SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "t"."key" AS "t__key", "t"."value" AS "t__value", CASE WHEN NOT "s"."key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."key" = "t"."key" AND NOT "s"."key" IS NULL AND NOT "t"."key" IS NULL THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN 
("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN ROUND("s"."value", 3) = ROUND("t"."value", 3) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "table_diff_source" AS "s" FULL JOIN "table_diff_target" AS "t" ON ("s"."key" = "t"."key") OR (("s"."key" IS NULL) AND ("t"."key" IS NULL))) AS "stats"' - summary_query_sql = 'SELECT SUM("s_exists") AS "s_count", SUM("t_exists") AS "t_count", SUM("row_joined") AS "join_count", SUM("null_grain") AS "null_grain_count", SUM("row_full_match") AS "full_match_count", SUM("key_matches") AS "key_matches", SUM("value_matches") AS "value_matches", COUNT(DISTINCT ("s__key")) AS "distinct_count_s", COUNT(DISTINCT ("t__key")) AS "distinct_count_t" FROM "sqlmesh_temp"."__temp_diff_abcdefgh"' - sample_query_sql = 'SELECT "s_exists", "t_exists", "row_joined", "row_full_match", "s__key", "s__value", "t__key", "t__value" FROM "sqlmesh_temp"."__temp_diff_abcdefgh" WHERE "key_matches" = 0 OR "value_matches" = 0 ORDER BY "s__key" NULLS FIRST, "t__key" NULLS FIRST LIMIT 20' + query_sql = 'CREATE TABLE IF NOT EXISTS "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "s"."key", "s"."value", "s"."key" AS "__sqlmesh_join_key" FROM "table_diff_source" AS "s"), "__target" AS (SELECT "t"."key", "t"."value", "t"."key" AS "__sqlmesh_join_key" FROM "table_diff_target" AS "t"), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "s"."__sqlmesh_join_key" AS "s____sqlmesh_join_key", "t"."key" AS "t__key", "t"."value" AS "t__value", "t"."__sqlmesh_join_key" AS "t____sqlmesh_join_key", CASE WHEN NOT "s"."__sqlmesh_join_key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."__sqlmesh_join_key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key" THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS 
NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN CAST(CAST("s"."value" AS DOUBLE) AS DECIMAL(38, 3)) = CAST(CAST("t"."value" AS DOUBLE) AS DECIMAL(38, 3)) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key") SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"' + summary_query_sql = 'SELECT SUM("s_exists") AS "s_count", SUM("t_exists") AS "t_count", SUM("row_joined") AS "join_count", SUM("null_grain") AS "null_grain_count", SUM("row_full_match") AS "full_match_count", SUM("key_matches") AS "key_matches", SUM("value_matches") AS "value_matches", COUNT(DISTINCT ("s____sqlmesh_join_key")) AS "distinct_count_s", COUNT(DISTINCT ("t____sqlmesh_join_key")) AS "distinct_count_t" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh"' + compare_sql = 'SELECT ROUND(100 * (CAST(SUM("key_matches") AS DECIMAL) / COUNT("key_matches")), 9) AS "key_matches", ROUND(100 * (CAST(SUM("value_matches") AS DECIMAL) / COUNT("value_matches")), 9) AS "value_matches" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "row_joined" = 1' + sample_query_sql = 'WITH "source_only" AS (SELECT \'source_only\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "s_exists" = 1 AND "row_joined" = 0 ORDER BY "s__key" NULLS FIRST LIMIT 20), "target_only" AS (SELECT \'target_only\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", 
"t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "t_exists" = 1 AND "row_joined" = 0 ORDER BY "t__key" NULLS FIRST LIMIT 20), "common_rows" AS (SELECT \'common_rows\' AS "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh" WHERE "row_joined" = 1 AND "row_full_match" = 0 ORDER BY "s__key" NULLS FIRST, "t__key" NULLS FIRST LIMIT 20) SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "source_only" UNION ALL SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "target_only" UNION ALL SELECT "__sqlmesh_sample_type", "s__key", "s__value", "s____sqlmesh_join_key", "t__key", "t__value", "t____sqlmesh_join_key" FROM "common_rows"' + drop_sql = 'DROP TABLE IF EXISTS "memory"."sqlmesh_temp_test"."__temp_diff_abcdefgh"' + + # make with_settings() return the current instance of engine_adapter so we can still spy on _execute + mocker.patch.object( + engine_adapter, "with_settings", new_callable=lambda: lambda **kwargs: engine_adapter + ) + assert engine_adapter.with_settings() == engine_adapter spy_execute = mocker.spy(engine_adapter, "_execute") mocker.patch("sqlmesh.core.engine_adapter.base.random_id", return_value="abcdefgh") @@ -258,8 +407,842 @@ def test_generated_sql(sushi_context_fixed_date: Context, mocker: MockerFixture) target="table_diff_target", on=["key"], skip_columns=["ignored"], + temp_schema="sqlmesh_temp_test", + ) + + spy_execute.assert_any_call(query_sql, False) + spy_execute.assert_any_call(summary_query_sql, False) + spy_execute.assert_any_call(compare_sql, False) + spy_execute.assert_any_call(sample_query_sql, False) + spy_execute.assert_any_call(drop_sql, False) + + spy_execute.reset_mock() + + # Also check WHERE clause is propagated correctly + 
sushi_context_fixed_date.table_diff( + source="table_diff_source", + target="table_diff_target", + on=["key"], + skip_columns=["ignored"], + where="key = 2", + ) + + query_sql_where = 'CREATE TABLE IF NOT EXISTS "memory"."sqlmesh_temp"."__temp_diff_abcdefgh" AS WITH "__source" AS (SELECT "s"."key", "s"."value", "s"."key" AS "__sqlmesh_join_key" FROM "table_diff_source" AS "s" WHERE "s"."key" = 2), "__target" AS (SELECT "t"."key", "t"."value", "t"."key" AS "__sqlmesh_join_key" FROM "table_diff_target" AS "t" WHERE "t"."key" = 2), "__stats" AS (SELECT "s"."key" AS "s__key", "s"."value" AS "s__value", "s"."__sqlmesh_join_key" AS "s____sqlmesh_join_key", "t"."key" AS "t__key", "t"."value" AS "t__value", "t"."__sqlmesh_join_key" AS "t____sqlmesh_join_key", CASE WHEN NOT "s"."__sqlmesh_join_key" IS NULL THEN 1 ELSE 0 END AS "s_exists", CASE WHEN NOT "t"."__sqlmesh_join_key" IS NULL THEN 1 ELSE 0 END AS "t_exists", CASE WHEN "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key" THEN 1 ELSE 0 END AS "row_joined", CASE WHEN "s"."key" IS NULL AND "t"."key" IS NULL THEN 1 ELSE 0 END AS "null_grain", CASE WHEN "s"."key" = "t"."key" THEN 1 WHEN ("s"."key" IS NULL) AND ("t"."key" IS NULL) THEN 1 WHEN ("s"."key" IS NULL) OR ("t"."key" IS NULL) THEN 0 ELSE 0 END AS "key_matches", CASE WHEN CAST(CAST("s"."value" AS DOUBLE) AS DECIMAL(38, 3)) = CAST(CAST("t"."value" AS DOUBLE) AS DECIMAL(38, 3)) THEN 1 WHEN ("s"."value" IS NULL) AND ("t"."value" IS NULL) THEN 1 WHEN ("s"."value" IS NULL) OR ("t"."value" IS NULL) THEN 0 ELSE 0 END AS "value_matches" FROM "__source" AS "s" FULL JOIN "__target" AS "t" ON "s"."__sqlmesh_join_key" = "t"."__sqlmesh_join_key") SELECT *, CASE WHEN "key_matches" = 1 AND "value_matches" = 1 THEN 1 ELSE 0 END AS "row_full_match" FROM "__stats"' + spy_execute.assert_any_call(query_sql_where, False) + + +def test_tables_and_grain_inferred_from_model(sushi_context_fixed_date: Context): + (sushi_context_fixed_date.path / "models" / 
"waiter_revenue_by_day.sql").write_text(""" + MODEL ( + name sushi.waiter_revenue_by_day, + kind incremental_by_time_range ( + time_column event_date, + batch_size 10, + ), + owner jen, + cron '@daily', + audits ( + NUMBER_OF_ROWS(threshold := 0) + ), + grain (waiter_id, event_date) + ); + + SELECT + o.waiter_id::INT + 1 AS waiter_id, /* Waiter id */ + SUM(oi.quantity * i.price)::DOUBLE AS revenue, /* Revenue from orders taken by this waiter */ + o.event_date::DATE AS event_date /* Date */ + FROM sushi.orders AS o + LEFT JOIN sushi.order_items AS oi + ON o.id = oi.order_id AND o.event_date = oi.event_date + LEFT JOIN sushi.items AS i + ON oi.item_id = i.id AND oi.event_date = i.event_date + WHERE + o.event_date BETWEEN @start_date AND @end_date + GROUP BY + o.waiter_id, + o.event_date +""") + # this creates a dev preview of "sushi.waiter_revenue_by_day" + sushi_context_fixed_date.refresh() + sushi_context_fixed_date.auto_categorize_changes = CategorizerConfig( + sql=AutoCategorizationMode.FULL + ) + sushi_context_fixed_date.plan(environment="unit_test", auto_apply=True, include_unmodified=True) + + table_diff = sushi_context_fixed_date.table_diff( + source="unit_test", target="prod", select_models={"sushi.waiter_revenue_by_day"} + )[0] + assert isinstance(table_diff, TableDiff) + assert table_diff.source == "memory.sushi__unit_test.waiter_revenue_by_day" + assert table_diff.target == "memory.sushi.waiter_revenue_by_day" + + _, _, col_names = table_diff.key_columns + assert col_names == ["waiter_id", "event_date"] + + +def test_data_diff_array_dict(sushi_context_fixed_date): + engine_adapter = sushi_context_fixed_date.engine_adapter + + engine_adapter.ctas( + "table_diff_source", + pd.DataFrame( + { + "key": [1, 2, 3], + "value": [np.array([51.2, 4.5678]), np.array([2.31, 12.2]), np.array([5.0])], + "dict": [{"key1": 10, "key2": 20, "key3": 30}, {"key1": 10}, {}], + } + ), + ) + + engine_adapter.ctas( + "table_diff_target", + pd.DataFrame( + { + "key": [1, 2, 3], + 
"value": [ + np.array([51.2, 4.5679]), + np.array([2.31, 12.2, 3.6, 1.9]), + np.array([5.0]), + ], + "dict": [{"key1": 10, "key2": 13}, {"key1": 10}, {}], + } + ), + ) + + table_diff = TableDiff( + adapter=engine_adapter, + source="table_diff_source", + target="table_diff_target", + source_alias="dev", + target_alias="prod", + on=["key"], + decimals=4, ) - spy_execute.assert_any_call(query_sql) - spy_execute.assert_any_call(summary_query_sql) - spy_execute.assert_any_call(sample_query_sql) + diff = table_diff.row_diff() + aliased_joined_sample = diff.joined_sample.columns + + assert "DEV__value" in aliased_joined_sample + assert "PROD__value" in aliased_joined_sample + assert diff.full_match_count == 1 + assert diff.partial_match_count == 2 + + output = capture_console_output("show_row_diff", row_diff=diff) + + # Expected output with boxes + expected_output = r""" +Row Counts: +├── FULL MATCH: 1 rows (33.33%) +└── PARTIAL MATCH: 2 rows (66.67%) + +COMMON ROWS column comparison stats: + pct_match +value 33.333333 +dict 66.666667 + + +COMMON ROWS sample data differences: +Column: value +┏━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ key ┃ DEV ┃ PROD ┃ +┡━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━┩ +│ 1 │ [51.2, 4.5678] │ [51.2, 4.5679] │ +│ 2 │ [2.31, 12.2] │ [2.31, 12.2, 3.6, 1.9] │ +└─────┴────────────────┴────────────────────────┘ +Column: dict +┏━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓ +┃ key ┃ DEV ┃ PROD ┃ +┡━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━┩ +│ 1 │ {key1=10, key2=20, key3=30} │ {key1=10, key2=13} │ +└─────┴─────────────────────────────┴────────────────────┘ +""" + + stripped_output = strip_ansi_codes(output) + stripped_expected = expected_output.strip() + assert stripped_output == stripped_expected + + +def test_data_diff_array_struct_query(): + engine_adapter = DuckDBConnectionConfig().create_engine_adapter() + + columns_to_types = {"key": exp.DataType.build("int"), "value": exp.DataType.build("int")} + + 
engine_adapter.create_table("table_diff_source", columns_to_types) + engine_adapter.create_table("table_diff_target", columns_to_types) + + engine_adapter.execute( + "insert into table_diff_source (key, value) values (1, 1), (1, 2), (1, 3)" + ) + engine_adapter.execute( + "insert into table_diff_target (key, value) values (1, 1), (1, 3), (1, 2)" + ) + + engine_adapter.execute( + "create view src_view as select key, array_agg(value) as val_arr, map(['k','v'], [10,11]) as val_map from table_diff_source group by 1" + ) + engine_adapter.execute( + "create view target_view as select key, array_agg(value) as val_arr, map(['k','v'],[11,10]) as val_map from table_diff_target group by 1" + ) + + table_diff = TableDiff( + adapter=engine_adapter, + source="src_view", + target="target_view", + source_alias="dev", + target_alias="prod", + on=["key"], + ) + + diff = table_diff.row_diff() + + output = capture_console_output("show_row_diff", row_diff=diff) + + assert ( + strip_ansi_codes(output) + == """Row Counts: +└── PARTIAL MATCH: 1 rows (100.0%) + +COMMON ROWS column comparison stats: + pct_match +val_arr 0.0 +val_map 0.0 + + +COMMON ROWS sample data differences: +Column: val_arr +┏━━━━━┳━━━━━━━━━┳━━━━━━━━━┓ +┃ key ┃ DEV ┃ PROD ┃ +┡━━━━━╇━━━━━━━━━╇━━━━━━━━━┩ +│ 1 │ [1 2 3] │ [1 3 2] │ +└─────┴─────────┴─────────┘ +Column: val_map +┏━━━━━┳━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━┓ +┃ key ┃ DEV ┃ PROD ┃ +┡━━━━━╇━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━┩ +│ 1 │ {'k': 10, 'v': 11} │ {'k': 11, 'v': 10} │ +└─────┴────────────────────┴────────────────────┘ +""".strip() + ) + + +def test_data_diff_nullable_booleans(): + engine_adapter = DuckDBConnectionConfig().create_engine_adapter() + + columns_to_types = {"key": exp.DataType.build("int"), "value": exp.DataType.build("boolean")} + + engine_adapter.create_table("table_diff_source", columns_to_types) + engine_adapter.create_table("table_diff_target", columns_to_types) + + engine_adapter.execute( + "insert into table_diff_source (key, 
value) values (1, true), (2, false), (3, null)" + ) + engine_adapter.execute( + "insert into table_diff_target (key, value) values (1, false), (2, null), (3, true)" + ) + + table_diff = TableDiff( + adapter=engine_adapter, + source="table_diff_source", + target="table_diff_target", + source_alias="dev", + target_alias="prod", + on=["key"], + ) + + diff = table_diff.row_diff() + + output = capture_console_output("show_row_diff", row_diff=diff) + + expected_output = """ +Row Counts: +└── PARTIAL MATCH: 3 rows (100.0%) + +COMMON ROWS column comparison stats: + pct_match +value 0.0 + + +COMMON ROWS sample data differences: +Column: value +┏━━━━━┳━━━━━━━┳━━━━━━━┓ +┃ key ┃ DEV ┃ PROD ┃ +┡━━━━━╇━━━━━━━╇━━━━━━━┩ +│ 1 │ True │ False │ +│ 2 │ False │ │ +│ 3 │ │ True │ +└─────┴───────┴───────┘ +""" + + assert strip_ansi_codes(output) == expected_output.strip() + + +def test_data_diff_multiple_models(sushi_context_fixed_date, capsys, caplog): + # Create first analytics model + expressions = d.parse( + """ + MODEL (name memory.sushi.analytics_1, kind full, grain(key), tags (finance),); + SELECT + key, + value, + FROM + (VALUES + (1, 3), + (2, 4), + ) AS t (key, value) + """ + ) + model_s = load_sql_based_model(expressions, dialect="snowflake") + sushi_context_fixed_date.upsert_model(model_s) + + # Create second analytics model from analytics_1 + expressions_2 = d.parse( + """ + MODEL (name memory.sushi.analytics_2, kind full, grain(key), tags (finance),); + SELECT + key, + value as amount, + FROM + memory.sushi.analytics_1 + """ + ) + model_s2 = load_sql_based_model(expressions_2, dialect="snowflake") + sushi_context_fixed_date.upsert_model(model_s2) + + sushi_context_fixed_date.plan( + "source_dev", + no_prompts=True, + auto_apply=True, + skip_tests=True, + start="2023-01-31", + end="2023-01-31", + ) + + # Modify first model + model = sushi_context_fixed_date.models['"MEMORY"."SUSHI"."ANALYTICS_1"'] + modified_model = model.dict() + modified_model["query"] = ( + 
exp.select("*") + .from_(model.query.subquery()) + .union("SELECT key, value FROM (VALUES (1, 6),(2,3),) AS t (key, value)") + ) + modified_sqlmodel = SqlModel(**modified_model) + sushi_context_fixed_date.upsert_model(modified_sqlmodel) + + # Modify second model + model2 = sushi_context_fixed_date.models['"MEMORY"."SUSHI"."ANALYTICS_2"'] + modified_model2 = model2.dict() + modified_model2["query"] = ( + exp.select("*") + .from_(model2.query.subquery()) + .union("SELECT key, amount FROM (VALUES (5, 150.2),(6,250.2),) AS t (key, amount)") + ) + modified_sqlmodel2 = SqlModel(**modified_model2) + sushi_context_fixed_date.upsert_model(modified_sqlmodel2) + + sushi_context_fixed_date.auto_categorize_changes = CategorizerConfig( + sql=AutoCategorizationMode.FULL + ) + sushi_context_fixed_date.plan( + "target_dev", + create_from="source_dev", + no_prompts=True, + auto_apply=True, + skip_tests=True, + start="2023-01-31", + end="2023-01-31", + ) + + # Get diffs for both models + selector = {"tag:finance & memory.sushi.analytics*"} + diffs = sushi_context_fixed_date.table_diff( + source="source_dev", + target="target_dev", + on=["key"], + select_models=selector, + skip_grain_check=False, + ) + + assert len(diffs) == 2 + + # Check analytics_1 diff + diff1 = next(d for d in diffs if "ANALYTICS_1" in d.source) + row_diff1 = diff1.row_diff() + assert row_diff1.full_match_count == 2 + assert row_diff1.full_match_pct == 50.0 + assert row_diff1.s_only_count == 0 + assert row_diff1.t_only_count == 0 + assert row_diff1.stats["join_count"] == 4 + assert row_diff1.stats["null_grain_count"] == 0 + assert row_diff1.stats["s_count"] == 4 + assert row_diff1.stats["distinct_count_s"] == 2 + assert row_diff1.stats["t_count"] == 4 + assert row_diff1.stats["distinct_count_t"] == 2 + assert row_diff1.s_sample.shape == (0, 2) + assert row_diff1.t_sample.shape == (0, 2) + + # Check analytics_2 diff + diff2 = next(d for d in diffs if "ANALYTICS_2" in d.source) + row_diff2 = diff2.row_diff() + 
assert row_diff2.full_match_count == 2 + assert row_diff2.full_match_pct == 40.0 + assert row_diff2.s_only_count == 0 + assert row_diff2.t_only_count == 2 + assert row_diff2.stats["join_count"] == 4 + assert row_diff2.stats["null_grain_count"] == 0 + assert row_diff2.stats["s_count"] == 4 + assert row_diff2.stats["distinct_count_s"] == 2 + assert row_diff2.stats["t_count"] == 6 + assert row_diff2.stats["distinct_count_t"] == 4 + assert row_diff2.s_sample.shape == (0, 2) + assert row_diff2.t_sample.shape == (2, 2) + + # This selector shouldn't return any diffs since both models have this tag + selector = {"^tag:finance"} + diffs = sushi_context_fixed_date.table_diff( + source="source_dev", + target="target_dev", + on=["key"], + select_models=selector, + skip_grain_check=False, + ) + assert len(diffs) == 0 + + +def test_data_diff_forward_only(sushi_context_fixed_date, capsys, caplog): + expressions = d.parse( + """ + MODEL (name memory.sushi.full_1, kind full, grain(key),); + SELECT + key, + value, + FROM + (VALUES + (1, 3), + (2, 4), + ) AS t (key, value) + """ + ) + model_s = load_sql_based_model(expressions, dialect="snowflake") + sushi_context_fixed_date.upsert_model(model_s) + + # Create second analytics model sourcing from first + expressions_2 = d.parse( + """ + MODEL (name memory.sushi.full_2, kind full, grain(key),); + SELECT + key, + value as amount, + FROM + memory.sushi.full_1 + """ + ) + model_s2 = load_sql_based_model(expressions_2, dialect="snowflake") + sushi_context_fixed_date.upsert_model(model_s2) + + sushi_context_fixed_date.plan( + "target_dev", + no_prompts=True, + auto_apply=True, + skip_tests=True, + start="2023-01-31", + end="2023-01-31", + ) + + model = sushi_context_fixed_date.models['"MEMORY"."SUSHI"."FULL_1"'] + modified_model = model.dict() + modified_model["query"] = exp.select("*").from_("(VALUES (12, 6),(5,3),) AS t (key, value)") + modified_sqlmodel = SqlModel(**modified_model) + 
sushi_context_fixed_date.upsert_model(modified_sqlmodel) + + sushi_context_fixed_date.auto_categorize_changes = CategorizerConfig( + sql=AutoCategorizationMode.FULL + ) + + plan_builder = sushi_context_fixed_date.plan_builder( + "source_dev", skip_tests=True, forward_only=True + ) + plan = plan_builder.build() + + sushi_context_fixed_date.apply(plan) + + # Get diffs for both models + selector = {"*full*"} + diffs = sushi_context_fixed_date.table_diff( + source="source_dev", + target="target_dev", + on=["key"], + select_models=selector, + skip_grain_check=False, + ) + + # Both models should be diffed + assert len(diffs) == 2 + + # Check full_1 diff + diff1 = next(d for d in diffs if "FULL_1" in d.source) + row_diff1 = diff1.row_diff() + diff2 = next(d for d in diffs if "FULL_2" in d.source) + row_diff2 = diff2.row_diff() + + # Both diffs should show the same matches + for row_diff in [row_diff1, row_diff2]: + assert row_diff.full_match_count == 0 + assert row_diff.full_match_pct == 0.0 + assert row_diff.s_only_count == 2 + assert row_diff.t_only_count == 2 + assert row_diff.stats["join_count"] == 0 + assert row_diff.stats["null_grain_count"] == 0 + assert row_diff.stats["s_count"] == 2 + assert row_diff.stats["distinct_count_s"] == 2 + assert row_diff.stats["t_count"] == 2 + assert row_diff.stats["distinct_count_t"] == 2 + assert row_diff.s_sample.shape == (2, 2) + assert row_diff.t_sample.shape == (2, 2) + + +def test_data_diff_empty_tables(): + engine_adapter = DuckDBConnectionConfig().create_engine_adapter() + + columns_to_types_src = { + "key": exp.DataType.build("int"), + "value": exp.DataType.build("varchar"), + } + columns_to_types_target = { + "key": exp.DataType.build("int"), + "value2": exp.DataType.build("varchar"), + } + + engine_adapter.create_table("table_diff_source", columns_to_types_src) + engine_adapter.create_table("table_diff_target", columns_to_types_target) + + table_diff = TableDiff( + adapter=engine_adapter, + source="table_diff_source", + 
target="table_diff_target", + source_alias="dev", + target_alias="prod", + on=["key"], + ) + + # should show the schema diff + schema_diff = table_diff.schema_diff() + assert len(schema_diff.added) == 1 + assert schema_diff.added[0][0] == "value2" + assert len(schema_diff.removed) == 1 + assert schema_diff.removed[0][0] == "value" + + # should not error on the row diff + row_diff = table_diff.row_diff() + assert row_diff.empty + + output = capture_console_output("show_row_diff", row_diff=row_diff) + assert ( + strip_ansi_codes(output) == "Neither the source nor the target table contained any records" + ) + + +@pytest.mark.slow +def test_data_diff_multiple_models_lacking_grain(sushi_context_fixed_date, capsys, caplog): + # Create first model with grain + expressions = d.parse( + """ + MODEL (name memory.sushi.grain_model, kind full, grain(key),); + SELECT + key, + value, + FROM + (VALUES + (1, 3), + (2, 4), + ) AS t (key, value) + """ + ) + model_s = load_sql_based_model(expressions, dialect="snowflake") + sushi_context_fixed_date.upsert_model(model_s) + + # Create second model without grain + expressions_2 = d.parse( + """ + MODEL (name memory.sushi.no_grain_model, kind full,); + SELECT + key, + value as amount, + FROM + memory.sushi.grain_model + """ + ) + model_s2 = load_sql_based_model(expressions_2, dialect="snowflake") + sushi_context_fixed_date.upsert_model(model_s2) + + sushi_context_fixed_date.plan( + "source_dev", + no_prompts=True, + auto_apply=True, + skip_tests=True, + start="2023-01-31", + end="2023-01-31", + ) + + # Modify first model + model = sushi_context_fixed_date.models['"MEMORY"."SUSHI"."GRAIN_MODEL"'] + modified_model = model.dict() + modified_model["query"] = ( + exp.select("*") + .from_(model.query.subquery()) + .union("SELECT key, value FROM (VALUES (1, 6),(2,3),) AS t (key, value)") + ) + modified_sqlmodel = SqlModel(**modified_model) + sushi_context_fixed_date.upsert_model(modified_sqlmodel) + + # Modify second model + model2 = 
sushi_context_fixed_date.models['"MEMORY"."SUSHI"."NO_GRAIN_MODEL"'] + modified_model2 = model2.dict() + modified_model2["query"] = ( + exp.select("*") + .from_(model2.query.subquery()) + .union("SELECT key, amount FROM (VALUES (5, 150.2),(6,250.2),) AS t (key, amount)") + ) + modified_sqlmodel2 = SqlModel(**modified_model2) + sushi_context_fixed_date.upsert_model(modified_sqlmodel2) + + sushi_context_fixed_date.auto_categorize_changes = CategorizerConfig( + sql=AutoCategorizationMode.FULL + ) + sushi_context_fixed_date.plan( + "target_dev", + create_from="source_dev", + no_prompts=True, + auto_apply=True, + skip_tests=True, + start="2023-01-31", + end="2023-01-31", + ) + + # By default erroring out when even one model lacks a grain + with pytest.raises( + SQLMeshError, + match=r"SQLMesh doesn't know how to join the tables for the following models:*", + ): + sushi_context_fixed_date.table_diff( + source="source_dev", + target="target_dev", + select_models={"*"}, + skip_grain_check=False, + ) + + # With warn_grain_check flag the diff will go ahead by warning + diffs = sushi_context_fixed_date.table_diff( + source="source_dev", + target="target_dev", + select_models={"*"}, + skip_grain_check=False, + warn_grain_check=True, + ) + + # Check that the diff was performed only for the model with a grain + assert len(diffs) == 1 + diff1 = diffs[0] + + # Check the table diff corresponds to the grain model + row_diff1 = diff1.row_diff() + assert row_diff1.full_match_count == 2.0 + assert row_diff1.full_match_pct == 50.0 + assert row_diff1.s_only_count == 0.0 + assert row_diff1.t_only_count == 0.0 + assert row_diff1.stats["join_count"] == 4.0 + assert row_diff1.stats["null_grain_count"] == 0.0 + assert row_diff1.stats["s_count"] == 4.0 + assert row_diff1.stats["distinct_count_s"] == 2.0 + assert row_diff1.stats["t_count"] == 4.0 + assert row_diff1.stats["distinct_count_t"] == 2.0 + assert row_diff1.s_sample.shape == (0, 2) + assert row_diff1.t_sample.shape == (0, 2) + assert 
row_diff1.joined_sample.shape == (2, 3) + assert row_diff1.sample.shape == (2, 4) + + +def test_schema_diff_ignore_case(): + # no changes + table_a = {"COL_A": exp.DataType.build("varchar"), "cOl_b": exp.DataType.build("int")} + table_b = {"col_a": exp.DataType.build("varchar"), "COL_b": exp.DataType.build("int")} + + diff = SchemaDiff( + source="table_a", + source_schema=table_a, + target="table_b", + target_schema=table_b, + ignore_case=True, + ) + + assert not diff.has_changes + + # added in target + table_a = {"COL_A": exp.DataType.build("varchar"), "cOl_b": exp.DataType.build("int")} + table_b = { + "col_a": exp.DataType.build("varchar"), + "COL_b": exp.DataType.build("int"), + "cOL__C": exp.DataType.build("date"), + } + + diff = SchemaDiff( + source="table_a", + source_schema=table_a, + target="table_b", + target_schema=table_b, + ignore_case=True, + ) + + assert diff.has_changes + assert len(diff.added) == 1 + assert diff.added[0] == ( + "cOL__C", + exp.DataType.build("date"), + ) # notice: case preserved on output + assert not diff.removed + assert not diff.modified + + # removed from source + table_a = { + "cOL_fo0": exp.DataType.build("float"), + "COL_A": exp.DataType.build("varchar"), + "cOl_b": exp.DataType.build("int"), + } + table_b = {"col_a": exp.DataType.build("varchar"), "COL_b": exp.DataType.build("int")} + + diff = SchemaDiff( + source="table_a", + source_schema=table_a, + target="table_b", + target_schema=table_b, + ignore_case=True, + ) + + assert diff.has_changes + assert not diff.added + assert len(diff.removed) == 1 + assert diff.removed[0] == ( + "cOL_fo0", + exp.DataType.build("float"), + ) # notice: case preserved on output + assert not diff.modified + + # column type change + table_a = {"CoL_A": exp.DataType.build("varchar"), "cOl_b": exp.DataType.build("int")} + table_b = {"col_a": exp.DataType.build("date"), "COL_b": exp.DataType.build("int")} + + diff = SchemaDiff( + source="table_a", + source_schema=table_a, + target="table_b", + 
target_schema=table_b, + ignore_case=True, + ) + + assert diff.has_changes + assert not diff.added + assert not diff.removed + assert diff.modified == { + "CoL_A": ( + exp.DataType.build("varchar"), + exp.DataType.build("date"), + ) # notice: source casing used on output + } + + +def test_data_diff_sample_limit(): + engine_adapter = DuckDBConnectionConfig().create_engine_adapter() + + columns_to_types = {"id": exp.DataType.build("int"), "name": exp.DataType.build("varchar")} + + engine_adapter.create_table("src", columns_to_types) + engine_adapter.create_table("target", columns_to_types) + + common_records = {} + src_only_records = {} + target_only_records = {} + + for i in range(0, 10): + common_records[i] = f"common_{i}" + src_only_records[i + 20] = f"src_{i}" + target_only_records[i + 40] = f"target_{i}" + + src_records = {**common_records, **src_only_records} + target_records = {**common_records, **target_only_records} + + # changes + src_records[1] = "modified_source_1" + src_records[3] = "modified_source_3" + target_records[2] = "modified_target_2" + target_records[7] = "modified_target_7" + + src_df = pd.DataFrame.from_records([{"id": k, "name": v} for k, v in src_records.items()]) + target_df = pd.DataFrame.from_records([{"id": k, "name": v} for k, v in target_records.items()]) + + engine_adapter.insert_append("src", src_df) + engine_adapter.insert_append("target", target_df) + + table_diff = TableDiff( + adapter=engine_adapter, source="src", target="target", on=["id"], limit=3 + ) + + diff = table_diff.row_diff() + + assert diff.s_only_count == 10 + assert diff.t_only_count == 10 + assert diff.join_count == 10 + + # each sample should contain :limit records + assert len(diff.s_sample) == 3 + assert len(diff.t_sample) == 3 + assert len(diff.joined_sample) == 3 + + +def test_data_diff_nulls_in_some_grain_columns(): + engine_adapter = DuckDBConnectionConfig().create_engine_adapter() + + columns_to_types = { + "key1": exp.DataType.build("int"), + "key2": 
exp.DataType.build("varchar"), + "key3": exp.DataType.build("int"), + "value": exp.DataType.build("varchar"), + } + + engine_adapter.create_table("src", columns_to_types) + engine_adapter.create_table("target", columns_to_types) + + src_records = [ + (1, None, 1, "value"), # full match + (None, None, None, "null value"), # join, partial match + (2, None, None, "source only"), # source only + ] + + target_records = [ + (1, None, 1, "value"), # full match + (None, None, None, "null value modified"), # join, partial match + (None, "three", 2, "target only"), # target only + ] + + src_df = pd.DataFrame(data=src_records, columns=columns_to_types.keys()) + target_df = pd.DataFrame(data=target_records, columns=columns_to_types.keys()) + + engine_adapter.insert_append("src", src_df) + engine_adapter.insert_append("target", target_df) + + table_diff = TableDiff( + adapter=engine_adapter, source="src", target="target", on=["key1", "key2", "key3"] + ) + + diff = table_diff.row_diff() + + assert diff.join_count == 2 + assert diff.s_only_count == 1 + assert diff.t_only_count == 1 + assert diff.full_match_count == 1 + assert diff.partial_match_count == 1 + + assert diff.s_sample["value"].tolist() == ["source only"] + assert diff.t_sample["value"].tolist() == ["target only"] + assert diff.joined_sample[["s__value", "t__value"]].values.flatten().tolist() == [ + "null value", + "null value modified", + ] diff --git a/tests/core/test_test.py b/tests/core/test_test.py index 97163f4cf6..43d0f333c3 100644 --- a/tests/core/test_test.py +++ b/tests/core/test_test.py @@ -2,15 +2,19 @@ import datetime import typing as t +import io from pathlib import Path -from unittest.mock import call +import unittest +from unittest.mock import call, patch +from shutil import rmtree -import pandas as pd +import pandas as pd # noqa: TID253 import pytest from pytest_mock.plugin import MockerFixture from sqlglot import exp +from IPython.utils.capture import capture_output -from sqlmesh.cli.example_project 
import init_example_project +from sqlmesh.cli.project_init import init_example_project from sqlmesh.core import constants as c from sqlmesh.core.config import ( Config, @@ -19,16 +23,23 @@ GatewayConfig, ModelDefaultsConfig, ) -from sqlmesh.core.context import Context +from sqlmesh.core.context import Context, ExecutionContext +from sqlmesh.core.console import get_console from sqlmesh.core.dialect import parse from sqlmesh.core.engine_adapter import EngineAdapter from sqlmesh.core.macros import MacroEvaluator, macro from sqlmesh.core.model import Model, SqlModel, load_sql_based_model, model +from sqlmesh.core.model.common import ParsableSql from sqlmesh.core.test.definition import ModelTest, PythonModelTest, SqlModelTest -from sqlmesh.utils.errors import ConfigError, TestError +from sqlmesh.core.test.result import ModelTextTestResult +from sqlmesh.core.test.context import TestExecutionContext +from sqlmesh.utils import Verbosity +from sqlmesh.utils.errors import ConfigError, SQLMeshError, TestError from sqlmesh.utils.yaml import dump as dump_yaml from sqlmesh.utils.yaml import load as load_yaml +from tests.utils.test_helpers import use_terminal_console + if t.TYPE_CHECKING: from unittest import TestResult @@ -46,7 +57,7 @@ def _create_test( test_name=test_name, model=model, models=context._models, - engine_adapter=context._test_connection_config.create_engine_adapter( + engine_adapter=context.test_connection_config.create_engine_adapter( register_comments_override=False ), dialect=context.config.dialect, @@ -60,11 +71,14 @@ def _create_model( meta: str = SUSHI_FOO_META, dialect: t.Optional[str] = None, default_catalog: t.Optional[str] = None, + **kwargs: t.Any, ) -> SqlModel: parsed_definition = parse(f"{meta};{query}", default_dialect=dialect) return t.cast( SqlModel, - load_sql_based_model(parsed_definition, dialect=dialect, default_catalog=default_catalog), + load_sql_based_model( + parsed_definition, dialect=dialect, default_catalog=default_catalog, **kwargs + 
), ) @@ -74,6 +88,7 @@ def _check_successful_or_raise( assert result is not None if not result.wasSuccessful(): error_or_failure_traceback = (result.errors or result.failures)[0][1] + print(error_or_failure_traceback) if expected_msg: assert expected_msg in error_or_failure_traceback else: @@ -108,7 +123,7 @@ def full_model_with_two_ctes(request) -> SqlModel: renamed AS ( SELECT id AS fid FROM source ) - SELECT fid FROM renamed; + SELECT fid FROM RENAMED; """, dialect=getattr(request, "param", None), default_catalog="memory", @@ -345,6 +360,48 @@ def test_row_order(sushi_context: Context, full_model_without_ctes: SqlModel) -> ), ) + model_sql = """ +SELECT + ARRAY_AGG(DISTINCT id_contact_b ORDER BY id_contact_b) AS aggregated_duplicates +FROM + source +GROUP BY + id_contact_a +ORDER BY + id_contact_a + """ + + _check_successful_or_raise( + _create_test( + body=load_yaml( + """ +test_array_order: + model: test + inputs: + source: + - id_contact_a: a + id_contact_b: b + - id_contact_a: a + id_contact_b: c + outputs: + query: + - aggregated_duplicates: + - c + - b + """ + ), + test_name="test_array_order", + model=_create_model(model_sql), + context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), + ).run(), + expected_msg=( + """AssertionError: Data mismatch (exp: expected, act: actual)\n\n""" + " aggregated_duplicates \n" + " exp act\n" + "0 (c, b) (b, c)\n" + ), + ) + @pytest.mark.parametrize( "waiter_names_input", @@ -356,7 +413,7 @@ def test_row_order(sushi_context: Context, full_model_without_ctes: SqlModel) -> - id: 3 name: 'bob' """, - """sushi.waiter_names: + """sushi.waiter_names: format: csv rows: | id,name @@ -410,21 +467,21 @@ def test_partial_data(sushi_context: Context, waiter_names_input: str) -> None: [ """sushi.waiter_names: format: yaml - rows: + rows: - id: 1 name: alice - id: 2 name: 'bob' """, - """sushi.waiter_names: + """sushi.waiter_names: format: csv rows: | id,name 1,alice 2,bob""", - """sushi.waiter_names: + 
"""sushi.waiter_names: format: csv - csv_settings: + csv_settings: sep: "#" rows: | id#name @@ -713,6 +770,42 @@ def test_partial_data_column_order(sushi_context: Context) -> None: ).run() ) + # - output df must differ if sorted by (id, event_date) vs. (event_date, id) + # - output partial must be true + _check_successful_or_raise( + _create_test( + body=load_yaml( + """ +test_foo: + model: sushi.foo + inputs: + sushi.items: + - id: 9876 + event_date: 2020-01-01 + - id: 1234 + name: hello + event_date: 2020-01-02 + outputs: + partial: true + query: + - event_date: 2020-01-01 + id: 9876 + - event_date: 2020-01-02 + id: 1234 + name: hello + """ + ), + test_name="test_foo", + model=sushi_context.upsert_model( + _create_model( + "SELECT id, name, price, event_date FROM sushi.items", + default_catalog=sushi_context.default_catalog, + ) + ), + context=sushi_context, + ).run() + ) + def test_partial_data_missing_schemas(sushi_context: Context) -> None: _check_successful_or_raise( @@ -818,7 +911,8 @@ def test_partially_inferred_schemas(sushi_context: Context, mocker: MockerFixtur 'CAST("s" AS STRUCT("d" DATE)) AS "s", ' 'CAST("a" AS INT) AS "a", ' 'CAST("b" AS TEXT) AS "b" ' - """FROM (VALUES ({'d': CAST('2020-01-01' AS DATE)}, 1, 'bla')) AS "t"("s", "a", "b")""" + """FROM (VALUES ({'d': CAST('2020-01-01' AS DATE)}, 1, 'bla')) AS "t"("s", "a", "b")""", + False, ) @@ -971,6 +1065,97 @@ def test_row_difference_failure() -> None: ) +def test_index_preservation_with_later_rows() -> None: + # Test comparison with differences in later rows + _check_successful_or_raise( + _create_test( + body=load_yaml( + """ +test_foo: + model: sushi.foo + inputs: + raw: + - id: 1 + value: 100 + - id: 2 + value: 200 + - id: 3 + value: 300 + - id: 4 + value: 400 + outputs: + query: + - id: 1 + value: 100 + - id: 2 + value: 200 + - id: 3 + value: 999 + - id: 4 + value: 888 + """ + ), + test_name="test_foo", + model=_create_model("SELECT id, value FROM raw"), + 
context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), + ).run(), + expected_msg=( + "AssertionError: Data mismatch (exp: expected, act: actual)\n\n" + " value \n" + " exp act\n" + "2 999.0 300.0\n" + "3 888.0 400.0\n" + ), + ) + + # Test with null values in later rows + _check_successful_or_raise( + _create_test( + body=load_yaml( + """ +test_foo: + model: sushi.foo + inputs: + raw: + - id: 1 + value: 100 + - id: 2 + value: 200 + - id: 3 + value: null + - id: 4 + value: 400 + - id: 5 + value: null + outputs: + query: + - id: 1 + value: 100 + - id: 2 + value: 200 + - id: 3 + value: 300 + - id: 4 + value: null + - id: 5 + value: 500 + """ + ), + test_name="test_foo", + model=_create_model("SELECT id, value FROM raw"), + context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), + ).run(), + expected_msg=( + "AssertionError: Data mismatch (exp: expected, act: actual)\n\n" + " value \n" + " exp act\n" + "2 300.0 NaN\n" + "3 NaN 400.0\n" + "4 500.0 NaN\n" + ), + ) + + def test_unknown_column_error() -> None: _check_successful_or_raise( _create_test( @@ -992,13 +1177,35 @@ def test_unknown_column_error() -> None: context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), ).run(), expected_msg=( - "sqlmesh.utils.errors.TestError: Detected unknown column(s)\n\n" + "sqlmesh.utils.errors.TestError: Failed to run test:\n" + "Detected unknown column(s)\n\n" "Expected column(s): id, value\n" "Unknown column(s): foo\n" ), ) +def test_invalid_outputs_error() -> None: + with pytest.raises(TestError, match="Incomplete test, outputs must contain 'query' or 'ctes'"): + _create_test( + body=load_yaml( + """ +test_foo: + model: sushi.foo + inputs: + raw: + - id: 1 + outputs: + rows: + - id: 1 + """ + ), + test_name="test_foo", + model=_create_model("SELECT id FROM raw"), + context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), + ) + + def test_empty_rows(sushi_context: 
Context) -> None: _check_successful_or_raise( _create_test( @@ -1124,25 +1331,30 @@ def test_nested_data_types(sushi_context: Context) -> None: inputs: sushi.raw: columns: - array: "INT[]" + array1: "INT[]" + array2: "STRUCT(k VARCHAR, v STRUCT(v_str VARCHAR, v_int INT, v_int_arr INT[]))[]" struct: "STRUCT(x INT[], y VARCHAR, z INT, w STRUCT(a INT))" rows: - - array: [1, 2, 3] + - array1: [1, 2, 3] + array2: [{'k': 'hello', 'v': {'v_str': 'there', 'v_int': 10, 'v_int_arr': [1, 2]}}] struct: {'x': [1, 2, 3], 'y': 'foo', 'z': 1, 'w': {'a': 5}} - - array: + - array1: - 2 - 3 - - array: [0, 4, 1] + - array1: [0, 4, 1] outputs: query: - - array: [0, 4, 1] - - array: [1, 2, 3] + - array1: [0, 4, 1] + - array1: [1, 2, 3] + array2: [{'k': 'hello', 'v': {'v_str': 'there', 'v_int': 10, 'v_int_arr': [1, 2]}}] struct: {'x': [1, 2, 3], 'y': 'foo', 'z': 1, 'w': {'a': 5}} - - array: [2, 3] + - array1: [2, 3] """ ), test_name="test_foo", - model=_create_model("SELECT array, struct FROM sushi.raw", default_catalog="memory"), + model=_create_model( + "SELECT array1, array2, struct FROM sushi.raw", default_catalog="memory" + ), context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), ).run() ) @@ -1176,25 +1388,55 @@ def test_freeze_time(mocker: MockerFixture) -> None: spy_execute.assert_has_calls( [ - call('CREATE SCHEMA IF NOT EXISTS "memory"."sqlmesh_test_jzngz56a"'), + call('CREATE SCHEMA IF NOT EXISTS "memory"."sqlmesh_test_jzngz56a"', False), call( "SELECT " """CAST('2023-01-01 12:05:03+00:00' AS DATE) AS "cur_date", """ """CAST('2023-01-01 12:05:03+00:00' AS TIME) AS "cur_time", """ - '''CAST('2023-01-01 12:05:03+00:00' AS TIMESTAMP) AS "cur_timestamp"''' + '''CAST('2023-01-01 12:05:03+00:00' AS TIMESTAMP) AS "cur_timestamp"''', + False, ), - call('DROP SCHEMA IF EXISTS "memory"."sqlmesh_test_jzngz56a" CASCADE'), + call('DROP SCHEMA IF EXISTS "memory"."sqlmesh_test_jzngz56a" CASCADE', False), + ] + ) + + test = _create_test( + body=load_yaml( + """ 
+test_foo: + model: xyz + outputs: + query: + - cur_timestamp: "2023-01-01 12:05:03+00:00" + vars: + execution_time: "2023-01-01 12:05:03+00:00" + """ + ), + test_name="test_foo", + model=_create_model("SELECT CURRENT_TIMESTAMP AS cur_timestamp"), + context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="bigquery"))), + ) + + spy_execute = mocker.spy(test.engine_adapter, "_execute") + _check_successful_or_raise(test.run()) + + spy_execute.assert_has_calls( + [ + call( + '''SELECT CAST('2023-01-01 12:05:03+00:00' AS TIMESTAMPTZ) AS "cur_timestamp"''', + False, + ) ] ) @model("py_model", columns={"ts1": "timestamptz", "ts2": "timestamptz"}) def execute(context, start, end, execution_time, **kwargs): - datetime_now = datetime.datetime.now() + datetime_now_utc = datetime.datetime.now(tz=datetime.timezone.utc) context.engine_adapter.execute(exp.select("CURRENT_TIMESTAMP")) current_timestamp = context.engine_adapter.cursor.fetchone()[0] - return pd.DataFrame([{"ts1": datetime_now, "ts2": current_timestamp}]) + return pd.DataFrame([{"ts1": datetime_now_utc, "ts2": current_timestamp}]) _check_successful_or_raise( _create_test( @@ -1297,6 +1539,9 @@ def test_gateway(copy_to_temp_path: t.Callable, mocker: MockerFixture) -> None: with open(test_path, "w", encoding="utf-8") as file: dump_yaml(test_dict, file) + # Re-initialize context to pick up the modified test file + context = Context(paths=path, config=config) + spy_execute = mocker.spy(EngineAdapter, "_execute") mocker.patch("sqlmesh.core.test.definition.random_id", return_value="jzngz56a") @@ -1319,13 +1564,13 @@ def test_gateway(copy_to_temp_path: t.Callable, mocker: MockerFixture) -> None: 'AS "t"("id", "customer_id", "waiter_id", "start_ts", "end_ts", "event_date")' ) test_adapter = t.cast(ModelTest, result.successes[0]).engine_adapter - assert call(test_adapter, expected_view_sql) in spy_execute.mock_calls + assert call(test_adapter, expected_view_sql, False) in spy_execute.mock_calls 
_check_successful_or_raise(context.test()) def test_generate_input_data_using_sql(mocker: MockerFixture, tmp_path: Path) -> None: - init_example_project(tmp_path, dialect="duckdb") + init_example_project(tmp_path, engine_type="duckdb") config = Config( default_connection=DuckDBConnectionConfig(), model_defaults=ModelDefaultsConfig(dialect="duckdb"), @@ -1340,11 +1585,15 @@ def test_generate_input_data_using_sql(mocker: MockerFixture, tmp_path: Path) -> inputs: sqlmesh_example.incremental_model: query: | - SELECT 1 AS id, 1 AS item_id + SELECT 1 AS item_id, 1 AS id + UNION ALL + SELECT 1 AS item_id, 2 AS id UNION ALL - SELECT 2 AS id, 1 AS item_id + SELECT 2 AS item_id, 3 AS id UNION ALL - SELECT 3 AS id, 2 AS item_id + SELECT 3 AS item_id, 4 AS id + UNION ALL + SELECT 4 AS item_id, null AS id outputs: query: rows: @@ -1352,6 +1601,10 @@ def test_generate_input_data_using_sql(mocker: MockerFixture, tmp_path: Path) -> num_orders: 2 - item_id: 2 num_orders: 1 + - item_id: 3 + num_orders: 1 + - item_id: 4 + num_orders: 0 """ ), test_name="test_example_full_model_alt", @@ -1360,6 +1613,59 @@ def test_generate_input_data_using_sql(mocker: MockerFixture, tmp_path: Path) -> ).run() ) + _check_successful_or_raise( + _create_test( + body=load_yaml( + """ +test_example_full_model_partial: + model: sqlmesh_example.full_model + inputs: + sqlmesh_example.incremental_model: + query: | + SELECT 1 as id, + UNION ALL + SELECT 2 as id, + outputs: + query: + partial: true + rows: + - item_id: null + num_orders: 2 + """ + ), + test_name="test_example_full_model_partial", + model=context.get_model("sqlmesh_example.full_model"), + context=context, + ).run() + ) + + _check_successful_or_raise( + _create_test( + body=load_yaml( + """ +test_example_full_model_partial: + model: sqlmesh_example.full_model + inputs: + sqlmesh_example.incremental_model: + rows: + - id: 1 + item_id: 1 + - id: 2 + item_id: 1 + - id: 3 + item_id: 2 + outputs: + query: + partial: true + query: "SELECT 2 AS 
num_orders UNION ALL SELECT 1 AS num_orders" + """ + ), + test_name="test_example_full_model_partial", + model=context.get_model("sqlmesh_example.full_model"), + context=context, + ).run() + ) + mocker.patch("sqlmesh.core.test.definition.random_id", return_value="jzngz56a") test = _create_test( body=load_yaml( @@ -1383,7 +1689,8 @@ def test_generate_input_data_using_sql(mocker: MockerFixture, tmp_path: Path) -> spy_execute.assert_any_call( 'CREATE OR REPLACE VIEW "memory"."sqlmesh_test_jzngz56a"."foo" AS ' - '''SELECT {'x': 1, 'n': {'y': 2}} AS "struct_value"''' + '''SELECT {'x': 1, 'n': {'y': 2}} AS "struct_value"''', + False, ) with pytest.raises( @@ -1411,12 +1718,12 @@ def test_generate_input_data_using_sql(mocker: MockerFixture, tmp_path: Path) -> ) -def test_pyspark_python_model() -> None: +def test_pyspark_python_model(tmp_path: Path) -> None: spark_connection_config = SparkConnectionConfig( config={ "spark.master": "local", - "spark.sql.warehouse.dir": "/tmp/data_dir", - "spark.driver.extraJavaOptions": "-Dderby.system.home=/tmp/derby_dir", + "spark.sql.warehouse.dir": f"{tmp_path}/data_dir", + "spark.driver.extraJavaOptions": f"-Dderby.system.home={tmp_path}/derby_dir", }, ) config = Config( @@ -1450,91 +1757,329 @@ def execute(context, start, end, execution_time, **kwargs): def test_variable_usage(tmp_path: Path) -> None: - init_example_project(tmp_path, dialect="duckdb") + init_example_project(tmp_path, engine_type="duckdb") - config = Config( - default_connection=DuckDBConnectionConfig(), - model_defaults=ModelDefaultsConfig(dialect="duckdb"), - variables={"gold": "gold_db", "silver": "silver_db"}, - ) - context = Context(paths=tmp_path, config=config) + variables = {"gold": "gold_db", "silver": "silver_db"} + incorrect_variables = {"gold": "foo", "silver": "bar"} parent = _create_model( - "SELECT 1 AS id, '2022-01-01 01:00:00'::TIMESTAMP AS ts", - meta="MODEL (name silver_db.sch.b)", + "SELECT 1 AS id, '2022-01-02'::DATE AS ds, @start_ts AS start_ts", 
+ meta="MODEL (name silver_db.sch.b, kind INCREMENTAL_BY_TIME_RANGE(time_column ds))", ) - parent = t.cast(SqlModel, context.upsert_model(parent)) child = _create_model( - "SELECT @IF(@VAR('myvar'), id, id + 1) AS id FROM silver_db.sch.b", - meta="MODEL (name gold_db.sch.a)", + "SELECT ds, @IF(@VAR('myvar'), id, id + 1) AS id FROM silver_db.sch.b WHERE ds BETWEEN @start_ds and @end_ds", + meta="MODEL (name gold_db.sch.a, kind INCREMENTAL_BY_TIME_RANGE(time_column ds))", ) - child = t.cast(SqlModel, context.upsert_model(child)) - test_file = tmp_path / "tests" / "test_parameterized_model_names.yaml" - test_file.write_text( - """ + test_text = """ test_parameterized_model_names: - model: {{ var('gold') }}.sch.a + model: {{{{ var('gold') }}}}.sch.a {gateway} vars: myvar: True + start_ds: 2022-01-01 + end_ds: 2022-01-03 inputs: - {{ var('silver') }}.sch.b: - - id: 1 - - id: 2 + {{{{ var('silver') }}}}.sch.b: + - ds: 2022-01-01 + id: 1 + - ds: 2022-01-01 + id: 2 outputs: query: - - id: 1 - - id: 2 - """ - ) - - results = context.test() + - ds: 2022-01-01 + id: 1 + - ds: 2022-01-01 + id: 2""" - assert not results.failures - assert not results.errors + test_file = tmp_path / "tests" / "test_parameterized_model_names.yaml" - # The example project has one test and we added another one above - assert len(results.successes) == 2 + def init_context_and_validate_results(config: Config, **kwargs): + context = Context(paths=tmp_path, config=config, **kwargs) + context.upsert_model(parent) + context.upsert_model(child) + results = context.test() -def test_test_generation(tmp_path: Path) -> None: - init_example_project(tmp_path, dialect="duckdb") + assert not results.failures + assert not results.errors + assert len(results.successes) == 2 + # Case 1: Test root variables config = Config( default_connection=DuckDBConnectionConfig(), model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables=variables, ) - context = Context(paths=tmp_path, config=config) - query = 
context.get_model("sqlmesh_example.full_model").render_query() - assert isinstance(query, exp.Query) + test_file.write_text(test_text.format(gateway="")) - context.upsert_model( - "sqlmesh_example.full_model", - query=exp.select(*query.named_selects).from_("cte").with_("cte", as_=query), + init_context_and_validate_results(config) + + # Case 2: Test gateway variables + config = Config( + gateways={ + "main": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), ) + init_context_and_validate_results(config) - context.plan(auto_apply=True) + # Case 3: Test gateway variables overriding root variables + config = Config( + gateways={ + "main": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + variables=incorrect_variables, + ) + init_context_and_validate_results(config, gateway="main") - input_queries = { - "sqlmesh_example.incremental_model": "SELECT * FROM sqlmesh_example.incremental_model LIMIT 3" - } + # Case 4: Use variable from the defined gateway + config = Config( + gateways={ + "main": GatewayConfig( + connection=DuckDBConnectionConfig(), variables=incorrect_variables + ), + "secondary": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) - with pytest.raises(ConfigError) as ex: - context.create_test("sqlmesh_example.full_model", input_queries=input_queries) + test_file.write_text(test_text.format(gateway="\n gateway: secondary")) + init_context_and_validate_results(config, gateway="main") - assert ( - "tests/test_full_model.yaml' already exists, " - "make sure to set --overwrite if it can be safely overwritten." 
- ) in str(ex.value) + # Case 5: Use gateways with escaped characters + config = Config( + gateways={ + "main": GatewayConfig( + connection=DuckDBConnectionConfig(), variables=incorrect_variables + ), + "secon\tdary": GatewayConfig(connection=DuckDBConnectionConfig(), variables=variables), + }, + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) - test = load_yaml(context.path / c.TESTS / "test_full_model.yaml") + test_file.write_text(test_text.format(gateway='\n gateway: "secon\\tdary"')) + init_context_and_validate_results(config, gateway="main") - assert len(test) == 1 - assert "test_example_full_model" in test - assert "vars" not in test["test_example_full_model"] - assert "ctes" not in test["test_example_full_model"]["outputs"] + +def test_custom_testing_schema(mocker: MockerFixture) -> None: + test = _create_test( + body=load_yaml( + """ +test_foo: + model: xyz + schema: my_schema + outputs: + query: + - a: 1 + """ + ), + test_name="test_foo", + model=_create_model("SELECT 1 AS a"), + context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), + ) + + spy_execute = mocker.spy(test.engine_adapter, "_execute") + _check_successful_or_raise(test.run()) + + spy_execute.assert_has_calls( + [ + call('CREATE SCHEMA IF NOT EXISTS "memory"."my_schema"', False), + call('SELECT 1 AS "a"', False), + call('DROP SCHEMA IF EXISTS "memory"."my_schema" CASCADE', False), + ] + ) + + +def test_pretty_query(mocker: MockerFixture) -> None: + test = _create_test( + body=load_yaml( + """ +test_foo: + model: xyz + schema: my_schema + outputs: + query: + - a: 1 + """ + ), + test_name="test_foo", + model=_create_model("SELECT 1 AS a"), + context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), + ) + test.engine_adapter._pretty_sql = True + spy_execute = mocker.spy(test.engine_adapter, "_execute") + _check_successful_or_raise(test.run()) + spy_execute.assert_has_calls( + [ + call('CREATE SCHEMA IF NOT EXISTS 
"memory"."my_schema"', False), + call('SELECT\n 1 AS "a"', False), + call('DROP SCHEMA IF EXISTS "memory"."my_schema" CASCADE', False), + ] + ) + + +def test_complicated_recursive_cte() -> None: + model_sql = """ +WITH + RECURSIVE + chained_contacts AS ( + -- Start with the initial set of contacts and their immediate nodes + SELECT + id_contact_a, + id_contact_b + FROM + source + + UNION ALL + + -- Recursive step to find further connected nodes + SELECT + chained_contacts.id_contact_a, + unfactorized_duplicates.id_contact_b + FROM + chained_contacts + JOIN source AS unfactorized_duplicates + ON chained_contacts.id_contact_b = unfactorized_duplicates.id_contact_a + ), + id_contact_a_with_aggregated_id_contact_bs AS ( + SELECT + id_contact_a, + ARRAY_AGG(DISTINCT id_contact_b ORDER BY id_contact_b) AS aggregated_id_contact_bs + FROM + chained_contacts + GROUP BY + id_contact_a + ) +SELECT + ARRAY_CONCAT([id_contact_a], aggregated_id_contact_bs) AS aggregated_duplicates +FROM + id_contact_a_with_aggregated_id_contact_bs +WHERE + id_contact_a NOT IN ( + SELECT DISTINCT + id_contact_b + FROM + source + ) +ORDER BY + id_contact_a + """ + + _check_successful_or_raise( + _create_test( + body=load_yaml( + """ +test_recursive_ctes: + model: test + inputs: + source: + rows: + - id_contact_a: "a" + id_contact_b: "b" + - id_contact_a: "b" + id_contact_b: "c" + - id_contact_a: "c" + id_contact_b: "d" + - id_contact_a: "a" + id_contact_b: "g" + - id_contact_a: "b" + id_contact_b: "e" + - id_contact_a: "c" + id_contact_b: "f" + - id_contact_a: "x" + id_contact_b: "y" + outputs: + ctes: + id_contact_a_with_aggregated_id_contact_bs: + - id_contact_a: a + aggregated_id_contact_bs: [b, c, d, e, f, g] + - id_contact_a: x + aggregated_id_contact_bs: [y] + - id_contact_a: b + aggregated_id_contact_bs: [c, d, e, f] + - id_contact_a: c + aggregated_id_contact_bs: [d, f] + query: + rows: + - aggregated_duplicates: [a, b, c, d, e, f, g] + - aggregated_duplicates: [x, y] + """ + ), + 
test_name="test_recursive_ctes", + model=_create_model(model_sql), + context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), + ).run() + ) + + +def test_unknown_model_warns(mocker: MockerFixture) -> None: + body = load_yaml( + """ +model: unknown +outputs: + query: + - c: 1 + """ + ) + + with patch.object(get_console(), "log_warning") as mock_logger: + ModelTest.create_test( + body=body, + test_name="test_unknown_model", + models={}, # type: ignore + engine_adapter=mocker.Mock(), + dialect=None, + path=None, + ) + assert mock_logger.mock_calls == [call("Model '\"unknown\"' was not found")] + + +def test_test_generation(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb") + + config = Config( + default_connection=DuckDBConnectionConfig(), + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + context = Context(paths=tmp_path, config=config) + + model = context.get_model("sqlmesh_example.full_model") + query = model.render_query() + assert isinstance(query, exp.Query) + + context.upsert_model( + "sqlmesh_example.full_model", + query_=ParsableSql( + sql=exp.select(*query.named_selects) + .from_("cte") + .with_("cte", as_=query) + .sql(dialect=model.dialect) + ), + ) + + context.plan(auto_apply=True) + + input_queries = { + "sqlmesh_example.incremental_model": "SELECT * FROM sqlmesh_example.incremental_model LIMIT 3" + } + + with pytest.raises(ConfigError) as ex: + context.create_test("sqlmesh_example.full_model", input_queries=input_queries) + + assert ( + "tests/test_full_model.yaml' already exists, " + "make sure to set --overwrite if it can be safely overwritten." 
+ ) in str(ex.value) + + test = load_yaml(context.path / c.TESTS / "test_full_model.yaml") + + assert len(test) == 1 + assert "test_example_full_model" in test + assert "vars" not in test["test_example_full_model"] + assert "ctes" not in test["test_example_full_model"]["outputs"] context.create_test( "sqlmesh_example.full_model", @@ -1630,7 +2175,7 @@ def create_test(context: Context, query: str): ) return load_yaml(context.path / c.TESTS / "test_foo.yaml") - init_example_project(tmp_path, dialect="duckdb") + init_example_project(tmp_path, engine_type="duckdb") config = Config( default_connection=DuckDBConnectionConfig(), @@ -1644,12 +2189,12 @@ def create_test(context: Context, query: str): bar_sql_file.write_text("MODEL (name sqlmesh_example.bar); SELECT col FROM external_table;") test = create_test(Context(paths=tmp_path, config=config), f"SELECT {column} AS col") - assert test["test_foo"]["inputs"] == {"sqlmesh_example.bar": expected} + assert test["test_foo"]["inputs"] == {'"memory"."sqlmesh_example"."bar"': expected} assert test["test_foo"]["outputs"] == {"query": expected} def test_test_generation_with_timestamp(tmp_path: Path) -> None: - init_example_project(tmp_path, dialect="duckdb") + init_example_project(tmp_path, engine_type="duckdb") config = Config( default_connection=DuckDBConnectionConfig(), @@ -1667,20 +2212,1316 @@ def test_test_generation_with_timestamp(tmp_path: Path) -> None: input_queries = { "sqlmesh_example.bar": "SELECT TIMESTAMP '2024-09-20 11:30:00.123456789' AS ts_col" } - - context.create_test( - "sqlmesh_example.foo", - input_queries=input_queries, - overwrite=True, - ) + context.create_test("sqlmesh_example.foo", input_queries=input_queries, overwrite=True) test = load_yaml(context.path / c.TESTS / "test_foo.yaml") assert len(test) == 1 assert "test_foo" in test assert test["test_foo"]["inputs"] == { - "sqlmesh_example.bar": [{"ts_col": datetime.datetime(2024, 9, 20, 11, 30, 0, 123456)}] + '"memory"."sqlmesh_example"."bar"': [ + 
{"ts_col": datetime.datetime(2024, 9, 20, 11, 30, 0, 123456)} + ] } assert test["test_foo"]["outputs"] == { "query": [{"ts_col": datetime.datetime(2024, 9, 20, 11, 30, 0, 123456)}] } + + +def test_test_generation_with_recursive_ctes(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb") + + config = Config( + default_connection=DuckDBConnectionConfig(), + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + foo_sql_file = tmp_path / "models" / "foo.sql" + foo_sql_file.write_text( + "MODEL (name sqlmesh_example.foo);" + "WITH RECURSIVE t AS (SELECT 1 AS c UNION ALL SELECT c + 1 FROM t WHERE c < 3) SELECT c FROM t" + ) + + context = Context(paths=tmp_path, config=config) + context.plan(auto_apply=True) + + context.create_test("sqlmesh_example.foo", input_queries={}, overwrite=True, include_ctes=True) + + test = load_yaml(context.path / c.TESTS / "test_foo.yaml") + assert len(test) == 1 + assert "test_foo" in test + assert test["test_foo"]["inputs"] == {} + assert test["test_foo"]["outputs"] == { + "query": [{"c": 1}, {"c": 2}, {"c": 3}], + "ctes": { + "t": [{"c": 1}, {"c": 2}, {"c": 3}], + }, + } + + _check_successful_or_raise(context.test()) + + +def test_test_with_gateway_specific_model(tmp_path: Path, mocker: MockerFixture) -> None: + init_example_project(tmp_path, engine_type="duckdb") + + config = Config( + gateways={ + "main": GatewayConfig(connection=DuckDBConnectionConfig()), + "second": GatewayConfig(connection=DuckDBConnectionConfig()), + }, + default_gateway="main", + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + gw_model_sql_file = tmp_path / "models" / "gw_model.sql" + + # The model has a gateway specified which isn't the default + gw_model_sql_file.write_text( + "MODEL (name sqlmesh_example.gw_model, gateway second); SELECT c FROM sqlmesh_example.input_model;" + ) + input_model_sql_file = tmp_path / "models" / "input_model.sql" + input_model_sql_file.write_text( + "MODEL (name sqlmesh_example.input_model); 
SELECT c FROM external_table;" + ) + + context = Context(paths=tmp_path, config=config) + input_queries = {'"memory"."sqlmesh_example"."input_model"': "SELECT 5 AS c"} + + assert context.engine_adapter == context.engine_adapters["main"] + with pytest.raises( + SQLMeshError, match=r"Gateway 'wrong' not found in the available engine adapters." + ): + context._get_engine_adapter("wrong") + + # Create test should use the gateway specific engine adapter + context.create_test("sqlmesh_example.gw_model", input_queries=input_queries, overwrite=True) + assert context._get_engine_adapter("second") == context.engine_adapters["second"] + assert len(context.engine_adapters) == 2 + + test = load_yaml(context.path / c.TESTS / "test_gw_model.yaml") + + assert len(test) == 1 + assert "test_gw_model" in test + assert test["test_gw_model"]["inputs"] == { + '"memory"."sqlmesh_example"."input_model"': [{"c": 5}] + } + assert test["test_gw_model"]["outputs"] == {"query": [{"c": 5}]} + + +def test_test_with_resolve_template_macro(tmp_path: Path): + config = Config( + default_connection=DuckDBConnectionConfig(), + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + + models_dir = tmp_path / "models" + models_dir.mkdir() + (models_dir / "foo.sql").write_text( + """ + MODEL ( + name test.foo, + kind full, + physical_properties ( + location = @resolve_template('file:///tmp/@{table_name}') + ) + ); + + SELECT t.a + 1 as a + FROM @resolve_template('@{schema_name}.dev_@{table_name}', mode := 'table') as t + """ + ) + + tests_dir = tmp_path / "tests" + tests_dir.mkdir() + (tests_dir / "test_foo.yaml").write_text( + """ +test_resolve_template_macro: + model: test.foo + inputs: + test.dev_foo: + - a: 1 + outputs: + query: + - a: 2 + """ + ) + + context = Context(paths=tmp_path, config=config) + _check_successful_or_raise(context.test()) + + +@use_terminal_console +def test_test_output(tmp_path: Path) -> None: + def copy_test_file(test_file: Path, new_test_file: Path, index: int) -> None: + 
with open(test_file, "r") as file: + filedata = file.read() + + with open(new_test_file, "w") as file: + file.write(filedata.replace("test_example_full_model", f"test_{index}")) + + init_example_project(tmp_path, engine_type="duckdb") + + original_test_file = tmp_path / "tests" / "test_full_model.yaml" + + new_test_file = tmp_path / "tests" / "test_full_model_error.yaml" + new_test_file.write_text( + """ +test_example_full_model: + model: sqlmesh_example.full_model + description: This is a test + inputs: + sqlmesh_example.incremental_model: + rows: + - id: 1 + item_id: 1 + - id: 2 + item_id: 1 + - id: 3 + item_id: 2 + outputs: + query: + rows: + - item_id: 1 + num_orders: 2 + - item_id: 4 + num_orders: 3 + """ + ) + + config = Config( + default_connection=DuckDBConnectionConfig(), + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_test_connection=DuckDBConnectionConfig(concurrent_tasks=8), + ) + context = Context(paths=tmp_path, config=config) + + # Case 1: Ensure the log report is structured correctly + with capture_output() as captured_output: + context.test() + + output = captured_output.stdout + + # Order may change due to concurrent execution + assert "F." 
in output or ".F" in output + assert ( + f"""This is a test +---------------------------------------------------------------------- + Data mismatch +┏━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┓ +┃ ┃ item_id: ┃ ┃ num_orders: ┃ num_orders: ┃ +┃ Row ┃ Expected ┃ item_id: Actual ┃ Expected ┃ Actual ┃ +┡━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━┩ +│ 1 │ 4.0 │ 2.0 │ 3.0 │ 1.0 │ +└─────┴─────────────────┴─────────────────┴─────────────────┴──────────────────┘ + +----------------------------------------------------------------------""" + in output + ) + + assert "Ran 2 tests" in output + assert "Failed tests (1):" in output + + # Case 2: Ensure that the verbose log report is structured correctly + with capture_output() as captured_output: + context.test(verbosity=Verbosity.VERBOSE) + + output = captured_output.stdout + + assert ( + f"""This is a test +---------------------------------------------------------------------- + Column 'item_id' mismatch +┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ +┃ Row ┃ Expected ┃ Actual ┃ +┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ +│ 1 │ 4.0 │ 2.0 │ +└─────────────┴────────────────────────┴───────────────────┘ + + Column 'num_orders' mismatch +┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓ +┃ Row ┃ Expected ┃ Actual ┃ +┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩ +│ 1 │ 3.0 │ 1.0 │ +└─────────────┴────────────────────────┴───────────────────┘ + +----------------------------------------------------------------------""" + in output + ) + + # Case 3: Assert that concurrent execution is working properly + for i in range(50): + copy_test_file(original_test_file, tmp_path / "tests" / f"test_success_{i}.yaml", i) + copy_test_file(new_test_file, tmp_path / "tests" / f"test_failure_{i}.yaml", i) + + # Re-initialize context to pick up the new test files + context = Context(paths=tmp_path, config=config) + + with 
capture_output() as captured_output: + context.test() + + output = captured_output.stdout + + assert "Ran 102 tests" in output + assert "Failed tests (51):" in output + + # Case 4: Test that wide tables are split into even chunks for default verbosity + rmtree(tmp_path / "tests") + + wide_model_query = ( + "SELECT 1 AS col_1, 2 AS col_2, 3 AS col_3, 4 AS col_4, 5 AS col_5, 6 AS col_6, 7 AS col_7" + ) + + wide_model = _create_model( + meta="MODEL(name test.test_wide_model)", + query=wide_model_query, + default_catalog=context.default_catalog, + ) + context.upsert_model(wide_model) + + tests_dir = tmp_path / "tests" + tests_dir.mkdir() + + wide_test_file = tmp_path / "tests" / "test_wide_model.yaml" + wide_test_file_content = """ + test_wide_model: + model: test.test_wide_model + outputs: + query: + rows: + - col_1: 6 + col_2: 5 + col_3: 4 + col_4: 3 + col_5: 2 + col_6: 1 + col_7: 0 + + """ + + wide_test_file.write_text(wide_test_file_content) + + context.load() + context.upsert_model(wide_model) + + with capture_output() as captured_output: + context.test() + + assert ( + """Data mismatch +┏━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━┳━━━━━━━━━┳━━━━━━━━┓ +┃ ┃ col_1: ┃ col_1: ┃ col_2: ┃ col_2: ┃ col_3: ┃ col_3: ┃ col_4: ┃ col_4: ┃ +┃ Row ┃ Expec… ┃ Actual ┃ Expec… ┃ Actual ┃ Expec… ┃ Actual ┃ Expect… ┃ Actual ┃ +┡━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━╇━━━━━━━━━╇━━━━━━━━┩ +│ 0 │ 6 │ 1 │ 5 │ 2 │ 4 │ 3 │ 3 │ 4 │ +└─────┴────────┴────────┴────────┴────────┴────────┴────────┴─────────┴────────┘ + + Data mismatch +┏━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━┳━━━━━━━━━━━━┓ +┃ ┃ col_5: ┃ col_5: ┃ col_6: ┃ col_6: ┃ col_7: ┃ col_7: ┃ +┃ Row ┃ Expected ┃ Actual ┃ Expected ┃ Actual ┃ Expected ┃ Actual ┃ +┡━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━╇━━━━━━━━━━━━┩ +│ 0 │ 2 │ 5 │ 1 │ 6 │ 0 │ 7 │ +└─────┴───────────┴───────────┴───────────┴───────────┴───────────┴────────────┘""" + in 
captured_output.stdout + ) + + # Case 5: Test null value difference in the 3rd row (index 2) + rmtree(tmp_path / "tests") + tests_dir = tmp_path / "tests" + tests_dir.mkdir() + + null_test_file = tmp_path / "tests" / "test_null_in_third_row.yaml" + null_test_file.write_text( + """ +test_null_third_row: + model: sqlmesh_example.full_model + description: Test null value in third row + inputs: + sqlmesh_example.incremental_model: + rows: + - id: 1 + item_id: 1 + - id: 2 + item_id: 1 + - id: 3 + item_id: 2 + - id: 4 + item_id: 3 + outputs: + query: + rows: + - item_id: 1 + num_orders: 2 + - item_id: 2 + num_orders: 1 + - item_id: 3 + num_orders: null + """ + ) + + # Re-initialize context to pick up the modified test file + context = Context(paths=tmp_path, config=config) + + with capture_output() as captured_output: + context.test() + + output = captured_output.stdout + + # Check for null value difference in the 3rd row (index 2) + assert ( + """ +┏━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━┓ +┃ Row ┃ num_orders: Expected ┃ num_orders: Actual ┃ +┡━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━┩ +│ 2 │ nan │ 1.0 │ +└──────┴───────────────────────────┴───────────────────────┘""" + in output + ) + + +@use_terminal_console +def test_test_output_with_invalid_model_name(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb") + + wrong_test_file = tmp_path / "tests" / "test_incorrect_model_name.yaml" + wrong_test_file.write_text( + """ +test_example_full_model: + model: invalid_model + description: This is an invalid test + inputs: + sqlmesh_example.incremental_model: + rows: + - id: 1 + item_id: 1 + - id: 2 + item_id: 1 + - id: 3 + item_id: 2 + outputs: + query: + rows: + - item_id: 1 + num_orders: 2 + - item_id: 2 + num_orders: 2 + """ + ) + + config = Config( + default_connection=DuckDBConnectionConfig(), + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + context = Context(paths=tmp_path, config=config) + + with 
patch.object(get_console(), "log_warning") as mock_logger: + with capture_output() as output: + context.test() + + assert ( + f"""Model '"invalid_model"' was not found at {wrong_test_file}""" + in mock_logger.call_args[0][0] + ) + assert "Successfully Ran 1 test" in output.stdout + + +def test_number_of_tests_found(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb") + + # Example project contains 1 test and we add a new file with 2 tests + test_file = tmp_path / "tests" / "test_new.yaml" + test_file.write_text( + """ +test_example_full_model1: + model: sqlmesh_example.full_model + inputs: + sqlmesh_example.incremental_model: + rows: + - id: 1 + item_id: 1 + - id: 2 + item_id: 1 + - id: 3 + item_id: 2 + outputs: + query: + rows: + - item_id: 1 + num_orders: 2 + - item_id: 2 + num_orders: 1 + +test_example_full_model2: + model: sqlmesh_example.full_model + inputs: + sqlmesh_example.incremental_model: + rows: + - id: 1 + item_id: 1 + - id: 2 + item_id: 1 + - id: 3 + item_id: 2 + outputs: + query: + rows: + - item_id: 1 + num_orders: 2 + - item_id: 2 + num_orders: 1 + """ + ) + + context = Context(paths=tmp_path) + + # Case 1: All 3 tests should run without any tests specified + results = context.test() + assert len(results.successes) == 3 + + # Case 2: The "new_test.yaml" should amount to 2 subtests + results = context.test(tests=[f"{test_file}"]) + assert len(results.successes) == 2 + + # Case 3: The "new_test.yaml::test_example_full_model2" should amount to a single subtest + results = context.test(tests=[f"{test_file}::test_example_full_model2"]) + assert len(results.successes) == 1 + + +def test_freeze_time_concurrent(tmp_path: Path) -> None: + tests_dir = tmp_path / "tests" + tests_dir.mkdir() + + macros_dir = tmp_path / "macros" + macros_dir.mkdir() + + macro_file = macros_dir / "test_datetime_now.py" + macro_file.write_text( + """ +from sqlglot import exp +import datetime +from sqlmesh.core.macros import macro + +@macro() +def 
test_datetime_now(evaluator): + return exp.cast(exp.Literal.string(datetime.datetime.now(tz=datetime.timezone.utc)), exp.DataType.Type.DATE) + +@macro() +def test_sqlglot_expr(evaluator): + return exp.CurrentDate().sql(evaluator.dialect) + """ + ) + + models_dir = tmp_path / "models" + models_dir.mkdir() + sql_model1 = models_dir / "sql_model1.sql" + sql_model1.write_text( + """ + MODEL(NAME sql_model1); + SELECT @test_datetime_now() AS col_exec_ds_time, @test_sqlglot_expr() AS col_current_date; + """ + ) + + for model_name in ["sql_model1", "sql_model2", "py_model"]: + for i in range(5): + test_2019 = tmp_path / "tests" / f"test_2019_{model_name}_{i}.yaml" + test_2019.write_text( + f""" + test_2019_{model_name}_{i}: + model: {model_name} + vars: + execution_time: '2019-12-01' + outputs: + query: + rows: + - col_exec_ds_time: '2019-12-01' + col_current_date: '2019-12-01' + """ + ) + + test_2025 = tmp_path / "tests" / f"test_2025_{model_name}_{i}.yaml" + test_2025.write_text( + f""" + test_2025_{model_name}_{i}: + model: {model_name} + vars: + execution_time: '2025-12-01' + outputs: + query: + rows: + - col_exec_ds_time: '2025-12-01' + col_current_date: '2025-12-01' + """ + ) + + ctx = Context( + paths=tmp_path, + config=Config(default_test_connection=DuckDBConnectionConfig(concurrent_tasks=8)), + ) + + @model( + "py_model", + columns={"col_exec_ds_time": "timestamp_ntz", "col_current_date": "timestamp_ntz"}, + ) + def execute(context, start, end, execution_time, **kwargs): + datetime_now_utc = datetime.datetime.now(tz=datetime.timezone.utc) + + context.engine_adapter.execute(exp.select("CURRENT_DATE()")) + current_date = context.engine_adapter.cursor.fetchone()[0] + + return pd.DataFrame( + [{"col_exec_ds_time": datetime_now_utc, "col_current_date": current_date}] + ) + + python_model = model.get_registry()["py_model"].model(module_path=Path("."), path=Path(".")) + ctx.upsert_model(python_model) + + ctx.upsert_model( + _create_model( + meta="MODEL(NAME 
sql_model2)", + query="SELECT @execution_ds::timestamp_ntz AS col_exec_ds_time, current_date()::date AS col_current_date", + default_catalog=ctx.default_catalog, + ) + ) + + results = ctx.test() + assert len(results.successes) == 30 + + +def test_python_model_upstream_table(sushi_context) -> None: + @model( + "test_upstream_table_python", + columns={"customer_id": "int", "zip": "str"}, + ) + def upstream_table_python(context, **kwargs): + demographics_external_table = context.resolve_table("memory.raw.demographics") + return context.fetchdf( + exp.select("customer_id", "zip").from_(demographics_external_table), + ) + + python_model = model.get_registry()["test_upstream_table_python"].model( + module_path=Path("."), + path=Path("."), + ) + + context = ExecutionContext(sushi_context.engine_adapter, sushi_context.snapshots, None, None) + df = list(python_model.render(context=context))[0] + + # Verify the actual model output matches the expected actual external table's values + assert df.to_dict(orient="records") == [{"customer_id": 1, "zip": "00000"}] + + # Use different input values for the test and verify the outputs + _check_successful_or_raise( + _create_test( + body=load_yaml(""" +test_test_upstream_table_python: + model: test_upstream_table_python + inputs: + memory.raw.demographics: + - customer_id: 12 + zip: "S11HA" + - customer_id: 555 + zip: "94401" + outputs: + query: + - customer_id: 12 + zip: "S11HA" + - customer_id: 555 + zip: "94401" +"""), + test_name="test_test_upstream_table_python", + model=model.get_registry()["test_upstream_table_python"].model( + module_path=Path("."), path=Path(".") + ), + context=sushi_context, + ).run() + ) + + +@use_terminal_console +@pytest.mark.parametrize("is_error", [True, False]) +def test_model_test_text_result_reporting_no_traceback( + sushi_context: Context, full_model_with_two_ctes: SqlModel, is_error: bool +) -> None: + test = _create_test( + body=load_yaml( + """ +test_foo: + model: sushi.foo + inputs: + raw: + - 
id: 1 + outputs: + ctes: + source: + - id: 1 + renamed: + - fid: 1 + vars: + start: 2022-01-01 + end: 2022-01-01 + """ + ), + test_name="test_foo", + model=sushi_context.upsert_model(full_model_with_two_ctes), + context=sushi_context, + ) + stream = io.StringIO() + result = ModelTextTestResult( + stream=unittest.runner._WritelnDecorator(stream), # type: ignore + verbosity=1, + descriptions=True, + ) + + try: + raise Exception("failure") + except Exception as e: + assert e.__traceback__ is not None + if is_error: + result.addError(test, (e.__class__, e, e.__traceback__)) + else: + result.addFailure(test, (e.__class__, e, e.__traceback__)) + + # Since we're simulating an error/failure, this doesn't go through the + # test runner logic, so we need to manually set how many tests were ran + result.testsRun = 1 + + with capture_output() as captured_output: + get_console().log_test_results(result, "duckdb") + + output = captured_output.stdout + + # Make sure that the traceback is not printed + assert "Traceback" not in output + assert "File" not in output + assert "line" not in output + + prefix = "ERROR" if is_error else "FAIL" + assert f"{prefix}: test_foo (None)" in output + assert "Exception: failure" in output + + +def test_timestamp_normalization() -> None: + model = _create_model( + "SELECT id, array_agg(timestamp_col::timestamp) as agg_timestamp_col FROM temp_model_with_timestamp GROUP BY id", + meta="MODEL (name foo, kind FULL)", + ) + + _check_successful_or_raise( + _create_test( + body=load_yaml( + """ + test_foo: + model: temp_agg_model_with_timestamp + inputs: + temp_model_with_timestamp: + rows: + - id: "id1" + timestamp_col: "2024-01-02T15:00:00" + outputs: + query: + rows: + - id: id1 + agg_timestamp_col: ["2024-01-02T15:00:00.000000"] + """ + ), + test_name="test_foo", + model=model, + context=Context(config=Config(model_defaults=ModelDefaultsConfig(dialect="duckdb"))), + ).run() + ) + + +@use_terminal_console +def 
test_disable_test_logging_if_no_tests_found(mocker: MockerFixture, tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb") + + config = Config( + default_connection=DuckDBConnectionConfig(), + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + default_test_connection=DuckDBConnectionConfig(concurrent_tasks=8), + ) + + rmtree(tmp_path / "tests") + + with capture_output() as captured_output: + context = Context(paths=tmp_path, config=config) + context.plan(no_prompts=True, auto_apply=True) + + output = captured_output.stdout + assert "test" not in output.lower() + + +def test_test_generation_with_timestamp_nat(tmp_path: Path) -> None: + init_example_project(tmp_path, engine_type="duckdb") + + config = Config( + default_connection=DuckDBConnectionConfig(), + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + foo_sql_file = tmp_path / "models" / "foo.sql" + foo_sql_file.write_text( + "MODEL (name sqlmesh_example.foo); SELECT ts_col FROM sqlmesh_example.bar;" + ) + bar_sql_file = tmp_path / "models" / "bar.sql" + bar_sql_file.write_text("MODEL (name sqlmesh_example.bar); SELECT ts_col FROM external_table;") + + context = Context(paths=tmp_path, config=config) + + # This simulates the scenario where upstream models have NULL timestamp values + input_queries = { + "sqlmesh_example.bar": """ + SELECT ts_col FROM ( + VALUES + (TIMESTAMP '2024-09-20 11:30:00.123456789'), + (CAST(NULL AS TIMESTAMP)), + (TIMESTAMP '2024-09-21 15:45:00.987654321') + ) AS t(ts_col) + """ + } + + # This should not raise an exception even with NULL timestamp values + context.create_test("sqlmesh_example.foo", input_queries=input_queries, overwrite=True) + + test = load_yaml(context.path / c.TESTS / "test_foo.yaml") + assert len(test) == 1 + assert "test_foo" in test + + # Verify that the test was created with correct input and output data + inputs = test["test_foo"]["inputs"] + outputs = test["test_foo"]["outputs"] + + # Check that we have the expected input 
table + assert '"memory"."sqlmesh_example"."bar"' in inputs + bar_data = inputs['"memory"."sqlmesh_example"."bar"'] + + # Verify we have 3 rows (2 with timestamps, 1 with NULL) + assert len(bar_data) == 3 + + # Verify that non-NULL timestamps are preserved + assert bar_data[0]["ts_col"] == datetime.datetime(2024, 9, 20, 11, 30, 0, 123456) + assert bar_data[2]["ts_col"] == datetime.datetime(2024, 9, 21, 15, 45, 0, 987654) + + # Verify that NULL timestamp is represented as None (not NaT) + assert bar_data[1]["ts_col"] is None + + # Verify that the output matches the input (since the model just selects from bar) + query_output = outputs["query"] + assert len(query_output) == 3 + assert query_output[0]["ts_col"] == datetime.datetime(2024, 9, 20, 11, 30, 0, 123456) + assert query_output[1]["ts_col"] is None + assert query_output[2]["ts_col"] == datetime.datetime(2024, 9, 21, 15, 45, 0, 987654) + + +def test_parameterized_name_sql_model() -> None: + variables = {"table_catalog": "gold"} + model = _create_model( + "select 1 as id, 'foo' as name", + meta=""" + MODEL ( + name @{table_catalog}.sushi.foo, + kind FULL + ) + """, + dialect="snowflake", + variables=variables, + ) + assert model.fqn == '"GOLD"."SUSHI"."FOO"' + + test = _create_test( + body=load_yaml( + """ +test_foo: + model: {{ var('table_catalog' ) }}.sushi.foo + outputs: + query: + - id: 1 + name: foo + """, + variables=variables, + ), + test_name="test_foo", + model=model, + context=Context( + config=Config( + model_defaults=ModelDefaultsConfig(dialect="snowflake"), variables=variables + ) + ), + ) + + assert test.body["model"] == '"GOLD"."SUSHI"."FOO"' + + _check_successful_or_raise(test.run()) + + +def test_parameterized_name_python_model() -> None: + variables = {"table_catalog": "gold"} + + @model( + name="@{table_catalog}.sushi.foo", + columns={ + "id": "int", + "name": "varchar", + }, + dialect="snowflake", + ) + def execute( + context: ExecutionContext, + **kwargs: t.Any, + ) -> pd.DataFrame: + return 
pd.DataFrame([{"ID": 1, "NAME": "foo"}]) + + python_model = model.get_registry()["@{table_catalog}.sushi.foo"].model( + module_path=Path("."), path=Path("."), variables=variables + ) + + assert python_model.fqn == '"GOLD"."SUSHI"."FOO"' + + test = _create_test( + body=load_yaml( + """ +test_foo: + model: {{ var('table_catalog' ) }}.sushi.foo + outputs: + query: + - id: 1 + name: foo + """, + variables=variables, + ), + test_name="test_foo", + model=python_model, + context=Context( + config=Config( + model_defaults=ModelDefaultsConfig(dialect="snowflake"), variables=variables + ) + ), + ) + + assert test.body["model"] == '"GOLD"."SUSHI"."FOO"' + + _check_successful_or_raise(test.run()) + + +def test_parameterized_name_self_referential_model(): + variables = {"table_catalog": "gold"} + model = _create_model( + """ + with last_value as ( + select coalesce(max(v), 0) as v from @{table_catalog}.sushi.foo + ) + select v + 1 as v from last_value + """, + meta=""" + MODEL ( + name @{table_catalog}.sushi.foo, + kind FULL + ) + """, + dialect="snowflake", + variables=variables, + ) + assert model.fqn == '"GOLD"."SUSHI"."FOO"' + + test1 = _create_test( + body=load_yaml( + """ +test_foo_intial_state: + model: {{ var('table_catalog' ) }}.sushi.foo + inputs: + {{ var('table_catalog' ) }}.sushi.foo: + rows: [] + columns: + v: int + outputs: + query: + - v: 1 + """, + variables=variables, + ), + test_name="test_foo_intial_state", + model=model, + context=Context( + config=Config( + model_defaults=ModelDefaultsConfig(dialect="snowflake"), variables=variables + ) + ), + ) + assert isinstance(test1, SqlModelTest) + assert test1.body["model"] == '"GOLD"."SUSHI"."FOO"' + test1_model_query = test1._render_model_query().sql(dialect="snowflake") + assert '"GOLD"."SUSHI"."FOO"' not in test1_model_query + assert ( + test1._test_fixture_table('"GOLD"."SUSHI"."FOO"').sql(dialect="snowflake", identify=True) + in test1_model_query + ) + + test2 = _create_test( + body=load_yaml( + """ 
+test_foo_cumulative: + model: {{ var('table_catalog' ) }}.sushi.foo + inputs: + {{ var('table_catalog' ) }}.sushi.foo: + rows: + - v: 5 + outputs: + query: + - v: 6 + """, + variables=variables, + ), + test_name="test_foo_cumulative", + model=model, + context=Context( + config=Config( + model_defaults=ModelDefaultsConfig(dialect="snowflake"), variables=variables + ) + ), + ) + assert isinstance(test2, SqlModelTest) + assert test2.body["model"] == '"GOLD"."SUSHI"."FOO"' + test2_model_query = test2._render_model_query().sql(dialect="snowflake") + assert '"GOLD"."SUSHI"."FOO"' not in test2_model_query + assert ( + test2._test_fixture_table('"GOLD"."SUSHI"."FOO"').sql(dialect="snowflake", identify=True) + in test2_model_query + ) + + _check_successful_or_raise(test1.run()) + _check_successful_or_raise(test2.run()) + + +def test_parameterized_name_self_referential_python_model(): + variables = {"table_catalog": "gold"} + + @model( + name="@{table_catalog}.sushi.foo", + columns={ + "id": "int", + }, + depends_on=["@{table_catalog}.sushi.bar"], + dialect="snowflake", + ) + def execute( + context: ExecutionContext, + **kwargs: t.Any, + ) -> pd.DataFrame: + current_table = context.resolve_table(f"{context.var('table_catalog')}.sushi.foo") + current_df = context.fetchdf(f"select id from {current_table}") + upstream_table = context.resolve_table(f"{context.var('table_catalog')}.sushi.bar") + upstream_df = context.fetchdf(f"select id from {upstream_table}") + + return pd.DataFrame([{"ID": upstream_df["ID"].sum() + current_df["ID"].sum()}]) + + @model( + name="@{table_catalog}.sushi.bar", + columns={ + "id": "int", + }, + dialect="snowflake", + ) + def execute( + context: ExecutionContext, + **kwargs: t.Any, + ) -> pd.DataFrame: + return pd.DataFrame([{"ID": 1}]) + + model_foo = model.get_registry()["@{table_catalog}.sushi.foo"].model( + module_path=Path("."), path=Path("."), variables=variables + ) + model_bar = model.get_registry()["@{table_catalog}.sushi.bar"].model( + 
module_path=Path("."), path=Path("."), variables=variables + ) + + assert model_foo.fqn == '"GOLD"."SUSHI"."FOO"' + assert model_bar.fqn == '"GOLD"."SUSHI"."BAR"' + + ctx = Context( + config=Config(model_defaults=ModelDefaultsConfig(dialect="snowflake"), variables=variables) + ) + ctx.upsert_model(model_foo) + ctx.upsert_model(model_bar) + + test = _create_test( + body=load_yaml( + """ +test_foo: + model: {{ var('table_catalog') }}.sushi.foo + inputs: + {{ var('table_catalog') }}.sushi.foo: + rows: + - id: 3 + {{ var('table_catalog') }}.sushi.bar: + rows: + - id: 5 + outputs: + query: + - id: 8 + """, + variables=variables, + ), + test_name="test_foo", + model=model_foo, + context=ctx, + ) + + assert isinstance(test, PythonModelTest) + + assert test.body["model"] == '"GOLD"."SUSHI"."FOO"' + assert '"GOLD"."SUSHI"."BAR"' in test.body["inputs"] + + assert isinstance(test.context, TestExecutionContext) + assert '"GOLD"."SUSHI"."FOO"' in test.context._model_tables + assert '"GOLD"."SUSHI"."BAR"' in test.context._model_tables + + with pytest.raises(SQLMeshError, match=r"Unable to find a table mapping"): + test.context.resolve_table("silver.sushi.bar") + + _check_successful_or_raise(test.run()) + + +def test_python_model_test_variables_override(tmp_path: Path) -> None: + py_model = tmp_path / "models" / "test_var_model.py" + py_model.parent.mkdir(parents=True, exist_ok=True) + py_model.write_text( + """ +import pandas as pd # noqa: TID253 +from sqlmesh import model, ExecutionContext +import typing as t + +@model( + name="test_var_model", + columns={"id": "int", "flag_value": "boolean", "var_value": "varchar"}, +) +def execute(context: ExecutionContext, **kwargs: t.Any) -> pd.DataFrame: + my_flag = context.var("my_flag") + other_var = context.var("other_var") + + return pd.DataFrame([{ + "id": 1 if my_flag else 2, + "flag_value": my_flag, + "var_value": other_var, + }])""" + ) + + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + 
variables={"my_flag": False, "other_var": "default_value"}, + ) + context = Context(config=config, paths=tmp_path) + + python_model = context.models['"test_var_model"'] + + # Test when Flag is True + # Overriding the config default flag_value to True + # AND the var_value to use test one + test_flag_true = _create_test( + body=load_yaml(""" +test_flag_true: + model: test_var_model + vars: + my_flag: true + other_var: "test_value" + outputs: + query: + rows: + - id: 1 + flag_value: true + var_value: "test_value" + """), + test_name="test_flag_true", + model=python_model, + context=context, + ) + + _check_successful_or_raise(test_flag_true.run()) + + # Test when Flag is False + # Overriding the config default flag_value to False + # AND the var_value to use test one (since the above would be false for both) + test_flag_false = _create_test( + body=load_yaml(""" +test_flag_false: + model: test_var_model + vars: + my_flag: false + other_var: "another_test_value" + outputs: + query: + rows: + - id: 2 + flag_value: false + var_value: "another_test_value" + """), + test_name="test_flag_false", + model=python_model, + context=context, + ) + + _check_successful_or_raise(test_flag_false.run()) + + # Test with no vars specified + # (should use config defaults for both flag and var_value) + test_default_vars = _create_test( + body=load_yaml(""" +test_default_vars: + model: test_var_model + outputs: + query: + rows: + - id: 2 + flag_value: false + var_value: "default_value" + """), + test_name="test_default_vars", + model=python_model, + context=context, + ) + _check_successful_or_raise(test_default_vars.run()) + + +def test_python_model_sorting(tmp_path: Path) -> None: + py_model = tmp_path / "models" / "test_sort_model.py" + py_model.parent.mkdir(parents=True, exist_ok=True) + py_model.write_text( + """ +import pandas as pd # noqa: TID253 +from sqlmesh import model, ExecutionContext +import typing as t + +@model( + name="test_sort_model", + columns={"id": "int", "value": 
"varchar"}, +) +def execute(context: ExecutionContext, **kwargs: t.Any) -> pd.DataFrame: + # Return rows in a potentially non-deterministic order + # (simulating a model that doesn't guarantee order) + return pd.DataFrame([ + {"id": 3, "value": "c"}, + {"id": 1, "value": "a"}, + {"id": 2, "value": "b"}, + ])""" + ) + + config = Config(model_defaults=ModelDefaultsConfig(dialect="duckdb")) + context = Context(config=config, paths=tmp_path) + + python_model = context.models['"test_sort_model"'] + + _check_successful_or_raise( + _create_test( + body=load_yaml(""" + test_without_sort: + model: test_sort_model + outputs: + query: + rows: + - id: 1 + value: "a" + - id: 2 + value: "b" + - id: 3 + value: "c" + """), + test_name="test_without_sort", + model=python_model, + context=context, + ).run() + ) + + +@use_terminal_console +def test_cte_failure(tmp_path: Path) -> None: + models_dir = tmp_path / "models" + models_dir.mkdir() + (models_dir / "foo.sql").write_text( + """ + MODEL ( + name test.foo, + kind full + ); + + with model_cte as ( + SELECT 1 AS id + ) + SELECT id FROM model_cte + """ + ) + + config = Config( + default_connection=DuckDBConnectionConfig(), + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + ) + context = Context(paths=tmp_path, config=config) + + expected_cte_failure_output = """Data mismatch (CTE "model_cte") +┏━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓ +┃ Row ┃ id: Expected ┃ id: Actual ┃ +┡━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩ +│ 0 │ 2 │ 1 │ +└──────────┴─────────────────────────┴─────────────────────┘""" + + expected_query_failure_output = """Data mismatch +┏━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓ +┃ Row ┃ id: Expected ┃ id: Actual ┃ +┡━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩ +│ 0 │ 2 │ 1 │ +└──────────┴─────────────────────────┴─────────────────────┘""" + + # Case 1: Ensure that a single CTE failure is reported correctly + tests_dir = tmp_path / "tests" + 
tests_dir.mkdir() + (tests_dir / "test_foo.yaml").write_text( + """ +test_foo: + model: test.foo + outputs: + ctes: + model_cte: + rows: + - id: 2 + query: + - id: 1 + """ + ) + + # Re-initialize context to pick up the new test file + context = Context(paths=tmp_path, config=config) + + with capture_output() as captured_output: + context.test() + + output = captured_output.stdout + + assert expected_cte_failure_output in output + assert expected_query_failure_output not in output + + assert "Ran 1 tests" in output + assert "Failed tests (1)" in output + + # Case 2: Ensure that both CTE and query failures are reported correctly + (tests_dir / "test_foo.yaml").write_text( + """ +test_foo: + model: test.foo + outputs: + ctes: + model_cte: + rows: + - id: 2 + query: + - id: 2 + """ + ) + + # Re-initialize context to pick up the modified test file + context = Context(paths=tmp_path, config=config) + + with capture_output() as captured_output: + context.test() + + output = captured_output.stdout + + assert expected_cte_failure_output in output + assert expected_query_failure_output in output + + assert "Ran 1 tests" in output + assert "Failed tests (1)" in output diff --git a/tests/dbt/cli/conftest.py b/tests/dbt/cli/conftest.py new file mode 100644 index 0000000000..26757bf3ab --- /dev/null +++ b/tests/dbt/cli/conftest.py @@ -0,0 +1,11 @@ +import typing as t +import functools +from click.testing import CliRunner, Result +import pytest + + +@pytest.fixture +def invoke_cli() -> t.Callable[..., Result]: + from sqlmesh_dbt.cli import dbt + + return functools.partial(CliRunner().invoke, dbt) diff --git a/tests/dbt/cli/test_global_flags.py b/tests/dbt/cli/test_global_flags.py new file mode 100644 index 0000000000..7e2262bd80 --- /dev/null +++ b/tests/dbt/cli/test_global_flags.py @@ -0,0 +1,187 @@ +import typing as t +from pathlib import Path +import pytest +import logging +from pytest_mock import MockerFixture +from click.testing import Result +from sqlmesh.utils.errors 
import SQLMeshError +from sqlglot.errors import SqlglotError +from tests.dbt.conftest import EmptyProjectCreator + +pytestmark = pytest.mark.slow + + +def test_profile_and_target(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): + # profile doesnt exist - error + result = invoke_cli(["--profile", "nonexist"]) + assert result.exit_code == 1 + assert "Profile 'nonexist' not found in profiles" in result.output + + # profile exists - successful load with default target + result = invoke_cli(["--profile", "jaffle_shop"]) + assert result.exit_code == 0 + assert "No command specified" in result.output + + # profile exists but target doesnt - error + result = invoke_cli(["--profile", "jaffle_shop", "--target", "nonexist"]) + assert result.exit_code == 1 + assert "Target 'nonexist' not specified in profiles" in result.output + assert "valid target names for this profile are" in result.output + assert "- dev" in result.output + + # profile exists and so does target - successful load with specified target + result = invoke_cli(["--profile", "jaffle_shop", "--target", "dev"]) + assert result.exit_code == 0 + assert "No command specified" in result.output + + +def test_run_error_handler( + jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result], mocker: MockerFixture +) -> None: + mock_run = mocker.patch("sqlmesh_dbt.operations.DbtOperations.run") + mock_run.side_effect = SQLMeshError("Test error message") + + result = invoke_cli(["run"]) + assert result.exit_code == 1 + assert "Error: Test error message" in result.output + assert "Traceback" not in result.output + + # test SqlglotError in run command + mock_run = mocker.patch("sqlmesh_dbt.operations.DbtOperations.run") + mock_run.side_effect = SqlglotError("Invalid SQL syntax") + + result = invoke_cli(["run"]) + + assert result.exit_code == 1 + assert "Error: Invalid SQL syntax" in result.output + assert "Traceback" not in result.output + + # test ValueError in run command + mock_run = 
mocker.patch("sqlmesh_dbt.operations.DbtOperations.run") + mock_run.side_effect = ValueError("Invalid configuration value") + + result = invoke_cli(["run"]) + + assert result.exit_code == 1 + assert "Error: Invalid configuration value" in result.output + assert "Traceback" not in result.output + + # test SQLMeshError in list command + mock_list = mocker.patch("sqlmesh_dbt.operations.DbtOperations.list_") + mock_list.side_effect = SQLMeshError("List command error") + + result = invoke_cli(["list"]) + + assert result.exit_code == 1 + assert "Error: List command error" in result.output + assert "Traceback" not in result.output + + # test SQLMeshError in main command without subcommand + mock_create = mocker.patch("sqlmesh_dbt.cli.create") + mock_create.side_effect = SQLMeshError("Failed to load project") + result = invoke_cli(["--profile", "jaffle_shop"]) + + assert result.exit_code == 1 + assert "Error: Failed to load project" in result.output + assert "Traceback" not in result.output + mocker.stopall() + + # test error with select option + mock_run_select = mocker.patch("sqlmesh_dbt.operations.DbtOperations.run") + mock_run_select.side_effect = SQLMeshError("Error with selector") + + result = invoke_cli(["run", "--select", "model1"]) + + assert result.exit_code == 1 + assert "Error: Error with selector" in result.output + assert "Traceback" not in result.output + + +def test_log_level(invoke_cli: t.Callable[..., Result], create_empty_project: EmptyProjectCreator): + create_empty_project() + + result = invoke_cli(["--log-level", "info", "list"]) + assert result.exit_code == 0 + assert logging.getLogger("sqlmesh").getEffectiveLevel() == logging.INFO + + result = invoke_cli(["--log-level", "debug", "list"]) + assert result.exit_code == 0 + assert logging.getLogger("sqlmesh").getEffectiveLevel() == logging.DEBUG + + +def test_profiles_dir( + invoke_cli: t.Callable[..., Result], create_empty_project: EmptyProjectCreator, tmp_path: Path +): + project_dir, _ = 
create_empty_project(project_name="test_profiles_dir") + + orig_profiles_yml = project_dir / "profiles.yml" + assert orig_profiles_yml.exists() + + new_profiles_yml = tmp_path / "some_other_place" / "profiles.yml" + new_profiles_yml.parent.mkdir(parents=True) + + orig_profiles_yml.rename(new_profiles_yml) + assert not orig_profiles_yml.exists() + assert new_profiles_yml.exists() + + # should fail if we don't specify --profiles-dir + result = invoke_cli(["list"]) + assert result.exit_code > 0, result.output + + # alternative ~/.dbt/profiles.yml might exist but doesn't contain the profile + assert "profiles.yml not found" in result.output or "not found in profiles" in result.output + + # should pass if we specify --profiles-dir + result = invoke_cli(["--profiles-dir", str(new_profiles_yml.parent), "list"]) + assert result.exit_code == 0, result.output + assert "Models in project" in result.output + + +def test_project_dir( + invoke_cli: t.Callable[..., Result], create_empty_project: EmptyProjectCreator +): + orig_project_dir, _ = create_empty_project(project_name="test_project_dir") + + orig_project_yml = orig_project_dir / "dbt_project.yml" + assert orig_project_yml.exists() + + new_project_yml = orig_project_dir / "nested" / "dbt_project.yml" + new_project_yml.parent.mkdir(parents=True) + + orig_project_yml.rename(new_project_yml) + assert not orig_project_yml.exists() + assert new_project_yml.exists() + + # should fail if we don't specify --project-dir + result = invoke_cli(["list"]) + assert result.exit_code != 0, result.output + assert "Error:" in result.output + + # should fail if the profiles.yml also doesnt exist at that --project-dir + result = invoke_cli(["--project-dir", str(new_project_yml.parent), "list"]) + assert result.exit_code != 0, result.output + + # profiles.yml might exist but doesn't contain the profile + assert "profiles.yml not found" in result.output or "not found in profiles" in result.output + + # should pass if it can find both files, 
either because we specified --profiles-dir explicitly or the profiles.yml was found in --project-dir + result = invoke_cli( + [ + "--project-dir", + str(new_project_yml.parent), + "--profiles-dir", + str(orig_project_dir), + "list", + ] + ) + assert result.exit_code == 0, result.output + assert "Models in project" in result.output + + orig_profiles_yml = orig_project_dir / "profiles.yml" + new_profiles_yml = new_project_yml.parent / "profiles.yml" + assert orig_profiles_yml.exists() + orig_profiles_yml.rename(new_profiles_yml) + + result = invoke_cli(["--project-dir", str(new_project_yml.parent), "list"]) + assert result.exit_code == 0, result.output + assert "Models in project" in result.output diff --git a/tests/dbt/cli/test_list.py b/tests/dbt/cli/test_list.py new file mode 100644 index 0000000000..3e6a55125c --- /dev/null +++ b/tests/dbt/cli/test_list.py @@ -0,0 +1,104 @@ +import typing as t +import pytest +from pathlib import Path +from click.testing import Result + +pytestmark = pytest.mark.slow + + +def test_list(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): + result = invoke_cli(["list"]) + + assert result.exit_code == 0 + assert not result.exception + + assert "─ jaffle_shop.orders" in result.output + assert "─ jaffle_shop.customers" in result.output + assert "─ jaffle_shop.staging.stg_payments" in result.output + assert "─ jaffle_shop.raw_orders" in result.output + + +def test_list_select(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): + result = invoke_cli(["list", "--select", "raw_customers+"]) + + assert result.exit_code == 0 + assert not result.exception + + assert "─ jaffle_shop.customers" in result.output + assert "─ jaffle_shop.staging.stg_customers" in result.output + assert "─ jaffle_shop.raw_customers" in result.output + + assert "─ jaffle_shop.staging.stg_payments" not in result.output + assert "─ jaffle_shop.raw_orders" not in result.output + + +def test_list_select_exclude(jaffle_shop_duckdb: Path, 
invoke_cli: t.Callable[..., Result]): + # single exclude + result = invoke_cli(["list", "--select", "raw_customers+", "--exclude", "orders"]) + + assert result.exit_code == 0 + assert not result.exception + + assert "─ jaffle_shop.customers" in result.output + assert "─ jaffle_shop.staging.stg_customers" in result.output + assert "─ jaffle_shop.raw_customers" in result.output + + assert "─ jaffle_shop.orders" not in result.output + assert "─ jaffle_shop.staging.stg_payments" not in result.output + assert "─ jaffle_shop.raw_orders" not in result.output + + # multiple exclude + for args in ( + ["--select", "stg_orders+", "--exclude", "customers", "--exclude", "orders"], + ["--select", "stg_orders+", "--exclude", "customers orders"], + ): + result = invoke_cli(["list", *args]) + assert result.exit_code == 0 + assert not result.exception + + assert "─ jaffle_shop.staging.stg_orders" in result.output + + assert "─ jaffle_shop.customers" not in result.output + assert "─ jaffle_shop.orders" not in result.output + + +def test_list_with_vars(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): + ( + jaffle_shop_duckdb / "models" / "vars_model.sql" + ).write_text(""" + select * from {{ ref('custom' + var('foo')) }} + """) + + result = invoke_cli(["list", "--vars", "foo: ers"]) + + assert result.exit_code == 0 + assert not result.exception + + assert ( + """├── jaffle_shop.vars_model +│ └── depends_on: jaffle_shop.customers""" + in result.output + ) + + +def test_list_models_mutually_exclusive( + jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result] +): + result = invoke_cli(["list", "--select", "foo", "--models", "bar"]) + assert result.exit_code != 0 + assert '"models" and "select" are mutually exclusive arguments' in result.output + + result = invoke_cli(["list", "--resource-type", "test", "--models", "bar"]) + assert result.exit_code != 0 + assert '"models" and "resource_type" are mutually exclusive arguments' in result.output + + +def 
test_list_models(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): + result = invoke_cli(["list", "--models", "jaffle_shop"]) + assert result.exit_code == 0 + assert not result.exception + + assert "─ jaffle_shop.customers" in result.output + assert ( + "─ jaffle_shop.raw_customers" not in result.output + ) # should be excluded because dbt --models excludes seeds diff --git a/tests/dbt/cli/test_operations.py b/tests/dbt/cli/test_operations.py new file mode 100644 index 0000000000..4aa508e21f --- /dev/null +++ b/tests/dbt/cli/test_operations.py @@ -0,0 +1,376 @@ +import typing as t +from pathlib import Path +import pytest +from sqlmesh_dbt.operations import create +from sqlmesh_dbt.console import DbtCliConsole +from sqlmesh.utils import yaml +from sqlmesh.utils.errors import SQLMeshError +import time_machine +from sqlmesh.core.plan import PlanBuilder +from sqlmesh.core.config.common import VirtualEnvironmentMode +from tests.dbt.conftest import EmptyProjectCreator +import logging + +pytestmark = pytest.mark.slow + + +class PlanCapturingConsole(DbtCliConsole): + def plan( + self, + plan_builder: PlanBuilder, + auto_apply: bool, + default_catalog: t.Optional[str], + no_diff: bool = False, + no_prompts: bool = False, + ) -> None: + self.plan_builder = plan_builder + self.auto_apply = auto_apply + self.default_catalog = default_catalog + self.no_diff = no_diff + self.no_prompts = no_prompts + + # normal console starts applying the plan here; we dont because we just want to capture the parameters + # and check they were set correctly + + +def test_create_sets_and_persists_default_start_date(jaffle_shop_duckdb: Path): + with time_machine.travel("2020-01-02 00:00:00 UTC"): + from sqlmesh.utils.date import yesterday_ds, to_ds + + assert yesterday_ds() == "2020-01-01" + + operations = create() + + assert operations.context.config.model_defaults.start + assert to_ds(operations.context.config.model_defaults.start) == "2020-01-01" + assert all( + to_ds(model.start) 
if model.start else None == "2020-01-01" + for model in operations.context.models.values() + if not model.kind.is_seed + ) + + # check that the date set on the first invocation persists to future invocations + from sqlmesh.utils.date import yesterday_ds, to_ds + + assert yesterday_ds() != "2020-01-01" + + operations = create() + + assert operations.context.config.model_defaults.start + assert to_ds(operations.context.config.model_defaults.start) == "2020-01-01" + assert all( + to_ds(model.start) if model.start else None == "2020-01-01" + for model in operations.context.models.values() + if not model.kind.is_seed + ) + + +def test_create_uses_configured_start_date_if_supplied(jaffle_shop_duckdb: Path): + sqlmesh_yaml = jaffle_shop_duckdb / "sqlmesh.yml" + + with sqlmesh_yaml.open("w") as f: + yaml.dump({"model_defaults": {"start": "2023-12-12"}}, f) + + operations = create() + + assert operations.context.config.model_defaults.start == "2023-12-12" + assert all( + model.start == "2023-12-12" + for model in operations.context.models.values() + if not model.kind.is_seed + ) + + +def test_create_can_specify_profile_and_target(jaffle_shop_duckdb: Path): + with pytest.raises(SQLMeshError, match=r"Profile 'foo' not found"): + create(profile="foo") + + with pytest.raises( + SQLMeshError, match=r"Target 'prod' not specified in profiles for 'jaffle_shop'" + ): + create(profile="jaffle_shop", target="prod") + + dbt_project = create(profile="jaffle_shop", target="dev").project + + assert dbt_project.context.profile_name == "jaffle_shop" + assert dbt_project.context.target_name == "dev" + + +def test_default_options(jaffle_shop_duckdb: Path): + operations = create() + + config = operations.context.config + dbt_project = operations.project + + assert config.plan.always_recreate_environment is True + assert config.virtual_environment_mode == VirtualEnvironmentMode.DEV_ONLY + assert config.model_defaults.start is not None + assert config.model_defaults.dialect == 
dbt_project.context.target.dialect + + +def test_create_can_set_project_variables(jaffle_shop_duckdb: Path): + (jaffle_shop_duckdb / "models" / "test_model.sql").write_text(""" + select '{{ var('foo') }}' as a + """) + + dbt_project = create(vars={"foo": "bar"}) + assert dbt_project.context.config.variables["foo"] == "bar" + + test_model = dbt_project.context.models['"jaffle_shop"."main"."test_model"'] + query = test_model.render_query() + assert query is not None + assert query.sql() == "SELECT 'bar' AS \"a\"" + + +def test_run_option_mapping(jaffle_shop_duckdb: Path): + operations = create(project_dir=jaffle_shop_duckdb) + console = PlanCapturingConsole() + operations.context.console = console + + plan = operations.run() + standalone_audit_name = "relationships_orders_customer_id__customer_id__ref_customers_" + assert plan.environment.name == "prod" + assert console.no_prompts is True + assert console.no_diff is True + assert console.auto_apply is True + assert plan.end_bounded is False + assert plan.ignore_cron is True + assert plan.skip_backfill is False + assert plan.selected_models_to_backfill is None + assert {s.name for s in plan.snapshots} == {k for k in operations.context.snapshots} + + plan = operations.run(select=["stg_orders+"]) + assert plan.environment.name == "prod" + assert console.no_prompts is True + assert console.no_diff is True + assert console.auto_apply is True + assert plan.end_bounded is False + assert plan.ignore_cron is True + assert plan.skip_backfill is False + assert plan.selected_models_to_backfill == { + '"jaffle_shop"."main"."customers"', + '"jaffle_shop"."main"."orders"', + '"jaffle_shop"."main"."stg_orders"', + } + assert {s.name for s in plan.snapshots} == ( + plan.selected_models_to_backfill | {standalone_audit_name} + ) + + plan = operations.run(select=["stg_orders+"], exclude=["customers"]) + assert plan.environment.name == "prod" + assert console.no_prompts is True + assert console.no_diff is True + assert console.auto_apply 
is True + assert plan.end_bounded is False + assert plan.ignore_cron is True + assert plan.skip_backfill is False + assert plan.selected_models_to_backfill == { + '"jaffle_shop"."main"."orders"', + '"jaffle_shop"."main"."stg_orders"', + } + assert {s.name for s in plan.snapshots} == ( + plan.selected_models_to_backfill | {standalone_audit_name} + ) + + plan = operations.run(exclude=["customers"]) + assert plan.environment.name == "prod" + assert console.no_prompts is True + assert console.no_diff is True + assert console.auto_apply is True + assert plan.end_bounded is False + assert plan.ignore_cron is True + assert plan.skip_backfill is False + assert plan.selected_models_to_backfill == {k for k in operations.context.snapshots} - { + '"jaffle_shop"."main"."customers"' + } - {standalone_audit_name} + assert {s.name for s in plan.snapshots} == ( + plan.selected_models_to_backfill | {standalone_audit_name} + ) + + plan = operations.run(empty=True) + assert plan.environment.name == "prod" + assert console.no_prompts is True + assert console.no_diff is True + assert console.auto_apply is True + assert plan.end_bounded is False + assert plan.ignore_cron is True + assert plan.skip_backfill is True + assert plan.selected_models_to_backfill is None + assert {s.name for s in plan.snapshots} == {k for k in operations.context.snapshots} + + +def test_run_option_mapping_dev(jaffle_shop_duckdb: Path): + # create prod so that dev has something to compare against + operations = create(project_dir=jaffle_shop_duckdb) + operations.run() + + (jaffle_shop_duckdb / "models" / "new_model.sql").write_text("select 1") + + operations = create(project_dir=jaffle_shop_duckdb) + + console = PlanCapturingConsole() + operations.context.console = console + + plan = operations.run(environment="dev") + assert plan.environment.name == "dev" + assert console.no_prompts is True + assert console.no_diff is True + assert console.auto_apply is True + assert plan.include_unmodified is False + assert 
plan.context_diff.create_from == "prod" + assert plan.context_diff.is_new_environment is True + assert console.plan_builder._enable_preview is True + assert plan.end_bounded is True + assert plan.ignore_cron is False + assert plan.skip_backfill is False + assert plan.selected_models_to_backfill == {'"jaffle_shop"."main"."new_model"'} + + plan = operations.run(environment="dev", empty=True) + assert plan.environment.name == "dev" + assert console.no_prompts is True + assert console.no_diff is True + assert console.auto_apply is True + assert plan.include_unmodified is False + assert plan.context_diff.create_from == "prod" + assert plan.context_diff.is_new_environment is True + assert console.plan_builder._enable_preview is True + assert plan.end_bounded is True + assert plan.ignore_cron is False + assert plan.skip_backfill is True + assert plan.selected_models_to_backfill == {'"jaffle_shop"."main"."new_model"'} + + plan = operations.run(environment="dev", select=["stg_orders+"]) + assert plan.environment.name == "dev" + assert console.no_prompts is True + assert console.no_diff is True + assert console.auto_apply is True + assert plan.include_unmodified is False + assert plan.context_diff.create_from == "prod" + assert plan.context_diff.is_new_environment is True + assert console.plan_builder._enable_preview is True + # dev plans with --select have run=True, ignore_cron=True set + # as opposed to dev plans that dont have a specific selector + assert plan.end_bounded is False + assert plan.ignore_cron is True + assert plan.skip_backfill is False + # note: the new model in the dev environment is ignored in favour of the explicitly selected ones + assert plan.selected_models_to_backfill == { + '"jaffle_shop"."main"."customers"', + '"jaffle_shop"."main"."orders"', + '"jaffle_shop"."main"."stg_orders"', + } + + +@pytest.mark.parametrize( + "env_name,vde_mode", + [ + ("prod", VirtualEnvironmentMode.DEV_ONLY), + ("prod", VirtualEnvironmentMode.FULL), + ("dev", 
VirtualEnvironmentMode.DEV_ONLY), + ("dev", VirtualEnvironmentMode.FULL), + ], +) +def test_run_option_full_refresh( + create_empty_project: EmptyProjectCreator, env_name: str, vde_mode: VirtualEnvironmentMode +): + # create config file prior to load + project_path, models_path = create_empty_project(project_name="test") + + config_path = project_path / "sqlmesh.yaml" + config = yaml.load(config_path) + config["virtual_environment_mode"] = vde_mode.value + + with config_path.open("w") as f: + yaml.dump(config, f) + + (models_path / "model_a.sql").write_text("select 1") + (models_path / "model_b.sql").write_text("select 2") + + operations = create(project_dir=project_path) + + assert operations.context.config.virtual_environment_mode == vde_mode + + console = PlanCapturingConsole() + operations.context.console = console + + plan = operations.run(environment=env_name, full_refresh=True) + + # both models added as backfills + restatements regardless of env / vde mode setting + assert plan.environment.name == env_name + assert len(plan.restatements) == 2 + assert list(plan.restatements)[0].name == '"test"."main"."model_a"' + assert list(plan.restatements)[1].name == '"test"."main"."model_b"' + + assert plan.requires_backfill + assert not plan.empty_backfill + assert not plan.skip_backfill + assert plan.models_to_backfill == set(['"test"."main"."model_a"', '"test"."main"."model_b"']) + + if vde_mode == VirtualEnvironmentMode.DEV_ONLY: + # We do not clear intervals across all model versions in the default DEV_ONLY mode, even when targeting prod, + # because dev data is hardcoded to preview only so by definition and can never be deployed + assert not plan.restate_all_snapshots + else: + if env_name == "prod": + # in FULL mode, we do it for prod + assert plan.restate_all_snapshots + else: + # but not dev + assert not plan.restate_all_snapshots + + +def test_run_option_full_refresh_with_selector(jaffle_shop_duckdb: Path): + operations = 
create(project_dir=jaffle_shop_duckdb) + assert len(operations.context.models) > 5 + + console = PlanCapturingConsole() + operations.context.console = console + + plan = operations.run(select=["stg_customers"], full_refresh=True) + assert len(plan.restatements) == 1 + assert list(plan.restatements)[0].name == '"jaffle_shop"."main"."stg_customers"' + + assert plan.requires_backfill + assert not plan.empty_backfill + assert not plan.skip_backfill + assert plan.models_to_backfill == set(['"jaffle_shop"."main"."stg_customers"']) + + +def test_create_sets_concurrent_tasks_based_on_threads(create_empty_project: EmptyProjectCreator): + project_dir, _ = create_empty_project(project_name="test") + + # add a postgres target because duckdb overrides to concurrent_tasks=1 regardless of what gets specified + profiles_yml_file = project_dir / "profiles.yml" + profiles_yml = yaml.load(profiles_yml_file) + profiles_yml["test"]["outputs"]["postgres"] = { + "type": "postgres", + "host": "localhost", + "port": 5432, + "user": "postgres", + "password": "postgres", + "dbname": "test", + "schema": "test", + } + profiles_yml_file.write_text(yaml.dump(profiles_yml)) + + operations = create(project_dir=project_dir, target="postgres") + + assert operations.context.concurrent_tasks == 1 # 1 is the default + + operations = create(project_dir=project_dir, threads=16, target="postgres") + + assert operations.context.concurrent_tasks == 16 + assert all( + g.connection and g.connection.concurrent_tasks == 16 + for g in operations.context.config.gateways.values() + ) + + +def test_create_configures_log_level(create_empty_project: EmptyProjectCreator): + project_dir, _ = create_empty_project() + + create(project_dir=project_dir, log_level="info") + assert logging.getLogger("sqlmesh").getEffectiveLevel() == logging.INFO + + create(project_dir=project_dir, log_level="error") + assert logging.getLogger("sqlmesh").getEffectiveLevel() == logging.ERROR diff --git a/tests/dbt/cli/test_options.py 
b/tests/dbt/cli/test_options.py new file mode 100644 index 0000000000..962ff0beb3 --- /dev/null +++ b/tests/dbt/cli/test_options.py @@ -0,0 +1,23 @@ +import typing as t +import pytest +from sqlmesh_dbt.options import YamlParamType +from click.exceptions import BadParameter + + +@pytest.mark.parametrize( + "input,expected", + [ + (1, BadParameter("Input value '1' should be a string")), + ("", BadParameter("String '' is not valid YAML")), + ("['a', 'b']", BadParameter("String.*did not evaluate to a dict, got.*")), + ("foo: bar", {"foo": "bar"}), + ('{"key": "value", "date": 20180101}', {"key": "value", "date": 20180101}), + ("{key: value, date: 20180101}", {"key": "value", "date": 20180101}), + ], +) +def test_yaml_param_type(input: str, expected: t.Union[BadParameter, t.Dict[str, t.Any]]): + if isinstance(expected, BadParameter): + with pytest.raises(BadParameter, match=expected.message): + YamlParamType().convert(input, None, None) + else: + assert YamlParamType().convert(input, None, None) == expected diff --git a/tests/dbt/cli/test_run.py b/tests/dbt/cli/test_run.py new file mode 100644 index 0000000000..4fdb7a0cdb --- /dev/null +++ b/tests/dbt/cli/test_run.py @@ -0,0 +1,93 @@ +import typing as t +import pytest +from pathlib import Path +from click.testing import Result +import time_machine +from sqlmesh_dbt.operations import create +from tests.cli.test_cli import FREEZE_TIME +from tests.dbt.conftest import EmptyProjectCreator + +pytestmark = pytest.mark.slow + + +def test_run(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): + result = invoke_cli(["run"]) + + assert result.exit_code == 0 + assert not result.exception + + assert "Model batches executed" in result.output + + +def test_run_with_selectors(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): + with time_machine.travel(FREEZE_TIME): + # do an initial run to create the objects + # otherwise the selected subset may depend on something that hasnt been created + result = 
invoke_cli(["run"]) + assert result.exit_code == 0 + assert "main.orders" in result.output + + result = invoke_cli(["run", "--select", "raw_customers+", "--exclude", "orders"]) + + assert result.exit_code == 0 + assert not result.exception + + assert "main.stg_customers" in result.output + assert "main.stg_orders" in result.output + assert "main.stg_payments" in result.output + assert "main.customers" in result.output + + assert "main.orders" not in result.output + + assert "Model batches executed" in result.output + + +def test_run_with_changes_and_full_refresh( + create_empty_project: EmptyProjectCreator, invoke_cli: t.Callable[..., Result] +): + project_path, models_path = create_empty_project(project_name="test") + + engine_adapter = create(project_path).context.engine_adapter + engine_adapter.execute("create table external_table as select 'foo' as a, 'bar' as b") + + (models_path / "model_a.sql").write_text("select a, b from external_table") + (models_path / "model_b.sql").write_text("select a, b from {{ ref('model_a') }}") + + # populate initial env + result = invoke_cli(["run"]) + assert result.exit_code == 0 + assert not result.exception + + assert engine_adapter.fetchall("select a, b from model_b") == [("foo", "bar")] + + engine_adapter.execute("insert into external_table (a, b) values ('baz', 'bing')") + (project_path / "models" / "model_b.sql").write_text( + "select a, b, 'changed' as c from {{ ref('model_a') }}" + ) + + # Clear dbt's partial parse cache to ensure file changes are detected + # Without it dbt may use stale cached model definitions, causing flakiness + partial_parse_file = project_path / "target" / "sqlmesh_partial_parse.msgpack" + if partial_parse_file.exists(): + partial_parse_file.unlink() + + # run with --full-refresh. 
this should: + # - fully refresh model_a (pick up the new records from external_table) + # - deploy the local change to model_b (introducing the 'changed' column) + result = invoke_cli(["run", "--full-refresh"]) + assert result.exit_code == 0 + assert not result.exception + + assert engine_adapter.fetchall("select a, b from model_a") == [("foo", "bar"), ("baz", "bing")] + assert engine_adapter.fetchall("select a, b, c from model_b") == [ + ("foo", "bar", "changed"), + ("baz", "bing", "changed"), + ] + + +def test_run_with_threads(jaffle_shop_duckdb: Path, invoke_cli: t.Callable[..., Result]): + result = invoke_cli(["run", "--threads", "4"]) + assert result.exit_code == 0 + assert not result.exception + + assert "Model batches executed" in result.output diff --git a/tests/dbt/cli/test_selectors.py b/tests/dbt/cli/test_selectors.py new file mode 100644 index 0000000000..17f0195f58 --- /dev/null +++ b/tests/dbt/cli/test_selectors.py @@ -0,0 +1,331 @@ +import typing as t +import pytest +from sqlmesh_dbt import selectors +from sqlmesh.core.selector import DbtSelector +from sqlmesh.core.context import Context +from pathlib import Path + + +@pytest.mark.parametrize( + "dbt_select,expected", + [ + ([], None), + (["main.model_a"], "main.model_a"), + (["main.model_a main.model_b"], "main.model_a | main.model_b"), + (["main.model_a", "main.model_b"], "main.model_a | main.model_b"), + (["(main.model_a & ^main.model_b)"], "(main.model_a & ^main.model_b)"), + ( + ["(+main.model_a & ^main.model_b)", "main.model_c"], + "(+main.model_a & ^main.model_b) | main.model_c", + ), + ], +) +def test_selection(dbt_select: t.List[str], expected: t.Optional[str]): + assert selectors.to_sqlmesh(dbt_select=dbt_select, dbt_exclude=[]) == expected + + +@pytest.mark.parametrize( + "dbt_exclude,expected", + [ + ([], None), + (["main.model_a"], "^(main.model_a)"), + (["(main.model_a & main.model_b)"], "^(main.model_a & main.model_b)"), + (["main.model_a,main.model_b"], "^(main.model_a & 
main.model_b)"), + (["main.model_a +main.model_b"], "^(main.model_a | +main.model_b)"), + ( + ["(+main.model_a & ^main.model_b)", "main.model_c"], + "^((+main.model_a & ^main.model_b) | main.model_c)", + ), + ], +) +def test_exclusion(dbt_exclude: t.List[str], expected: t.Optional[str]): + assert selectors.to_sqlmesh(dbt_select=[], dbt_exclude=dbt_exclude) == expected + + +@pytest.mark.parametrize( + "dbt_select,dbt_exclude,expected", + [ + ([], [], None), + (["+main.model_a"], ["raw.src_data"], "+main.model_a & ^(raw.src_data)"), + ( + ["+main.model_a", "main.*b+"], + ["raw.src_data"], + "(+main.model_a | main.*b+) & ^(raw.src_data)", + ), + ( + ["+main.model_a", "main.*b+"], + ["raw.src_data", "tag:disabled"], + "(+main.model_a | main.*b+) & ^(raw.src_data | tag:disabled)", + ), + ], +) +def test_selection_and_exclusion( + dbt_select: t.List[str], dbt_exclude: t.List[str], expected: t.Optional[str] +): + assert selectors.to_sqlmesh(dbt_select=dbt_select, dbt_exclude=dbt_exclude) == expected + + +@pytest.mark.parametrize( + "expression,expected", + [ + ("", ([], [])), + ("model_a", (["model_a"], [])), + ("model_a model_b", (["model_a", "model_b"], [])), + ("model_a,model_b", ([], ["model_a", "model_b"])), + ("model_a model_b,model_c", (["model_a"], ["model_b", "model_c"])), + ("model_a,model_b model_c", (["model_c"], ["model_a", "model_b"])), + ], +) +def test_split_unions_and_intersections( + expression: str, expected: t.Tuple[t.List[str], t.List[str]] +): + assert selectors._split_unions_and_intersections(expression) == expected + + +@pytest.mark.parametrize( + "dbt_select,expected", + [ + (["aging"], set()), + ( + ["staging"], + { + '"jaffle_shop"."main"."stg_customers"', + '"jaffle_shop"."main"."stg_orders"', + '"jaffle_shop"."main"."stg_payments"', + }, + ), + (["staging.stg_customers"], {'"jaffle_shop"."main"."stg_customers"'}), + (["stg_customers.staging"], set()), + ( + ["+customers"], + { + '"jaffle_shop"."main"."customers"', + 
'"jaffle_shop"."main"."stg_customers"', + '"jaffle_shop"."main"."stg_orders"', + '"jaffle_shop"."main"."stg_payments"', + '"jaffle_shop"."main"."raw_customers"', + '"jaffle_shop"."main"."raw_orders"', + '"jaffle_shop"."main"."raw_payments"', + }, + ), + (["customers+"], {'"jaffle_shop"."main"."customers"'}), + ( + ["customers+", "stg_orders"], + {'"jaffle_shop"."main"."customers"', '"jaffle_shop"."main"."stg_orders"'}, + ), + (["*.staging.stg_c*"], {'"jaffle_shop"."main"."stg_customers"'}), + (["tag:agg"], {'"jaffle_shop"."main"."agg_orders"'}), + ( + ["staging.stg_customers", "tag:agg"], + { + '"jaffle_shop"."main"."stg_customers"', + '"jaffle_shop"."main"."agg_orders"', + }, + ), + ( + ["+tag:agg"], + { + '"jaffle_shop"."main"."agg_orders"', + '"jaffle_shop"."main"."orders"', + '"jaffle_shop"."main"."stg_orders"', + '"jaffle_shop"."main"."stg_payments"', + '"jaffle_shop"."main"."raw_orders"', + '"jaffle_shop"."main"."raw_payments"', + }, + ), + ( + ["tag:agg+"], + { + '"jaffle_shop"."main"."agg_orders"', + }, + ), + ( + ["tag:b*"], + set(), + ), + ( + ["tag:a*"], + { + '"jaffle_shop"."main"."agg_orders"', + }, + ), + ], +) +def test_select_by_dbt_names( + jaffle_shop_duckdb: Path, + jaffle_shop_duckdb_context: Context, + dbt_select: t.List[str], + expected: t.Set[str], +): + (jaffle_shop_duckdb / "models" / "agg_orders.sql").write_text(""" + {{ config(tags=["agg"]) }} + select order_date, count(*) as num_orders from {{ ref('orders') }} + """) + + ctx = jaffle_shop_duckdb_context + ctx.load() + assert '"jaffle_shop"."main"."agg_orders"' in ctx.models + assert ctx.get_model('"jaffle_shop"."main"."agg_orders"').tags == ["agg"] + + selector = ctx._new_selector() + assert isinstance(selector, DbtSelector) + + sqlmesh_selector = selectors.to_sqlmesh(dbt_select=dbt_select, dbt_exclude=[]) + assert sqlmesh_selector + + assert selector.expand_model_selections([sqlmesh_selector]) == expected + + +@pytest.mark.parametrize( + "dbt_exclude,expected", + [ + (["jaffle_shop"], 
set()), + ( + ["staging"], + { + '"jaffle_shop"."main"."agg_orders"', + '"jaffle_shop"."main"."customers"', + '"jaffle_shop"."main"."orders"', + '"jaffle_shop"."main"."raw_customers"', + '"jaffle_shop"."main"."raw_orders"', + '"jaffle_shop"."main"."raw_payments"', + }, + ), + (["+customers"], {'"jaffle_shop"."main"."orders"', '"jaffle_shop"."main"."agg_orders"'}), + ( + ["+tag:agg"], + { + '"jaffle_shop"."main"."customers"', + '"jaffle_shop"."main"."stg_customers"', + '"jaffle_shop"."main"."raw_customers"', + }, + ), + ], +) +def test_exclude_by_dbt_names( + jaffle_shop_duckdb: Path, + jaffle_shop_duckdb_context: Context, + dbt_exclude: t.List[str], + expected: t.Set[str], +): + (jaffle_shop_duckdb / "models" / "agg_orders.sql").write_text(""" + {{ config(tags=["agg"]) }} + select order_date, count(*) as num_orders from {{ ref('orders') }} + """) + + ctx = jaffle_shop_duckdb_context + ctx.load() + assert '"jaffle_shop"."main"."agg_orders"' in ctx.models + assert ctx.get_model('"jaffle_shop"."main"."agg_orders"').tags == ["agg"] + + selector = ctx._new_selector() + assert isinstance(selector, DbtSelector) + + sqlmesh_selector = selectors.to_sqlmesh(dbt_select=[], dbt_exclude=dbt_exclude) + assert sqlmesh_selector + + assert selector.expand_model_selections([sqlmesh_selector]) == expected + + +@pytest.mark.parametrize( + "dbt_select,dbt_exclude,expected", + [ + (["jaffle_shop"], ["jaffle_shop"], set()), + ( + ["staging"], + ["stg_customers"], + { + '"jaffle_shop"."main"."stg_orders"', + '"jaffle_shop"."main"."stg_payments"', + }, + ), + ( + ["staging.stg_customers", "tag:agg"], + ["tag:agg"], + { + '"jaffle_shop"."main"."stg_customers"', + }, + ), + ], +) +def test_selection_and_exclusion_by_dbt_names( + jaffle_shop_duckdb: Path, + jaffle_shop_duckdb_context: Context, + dbt_select: t.List[str], + dbt_exclude: t.List[str], + expected: t.Set[str], +): + (jaffle_shop_duckdb / "models" / "agg_orders.sql").write_text(""" + {{ config(tags=["agg"]) }} + select order_date, 
count(*) as num_orders from {{ ref('orders') }} + """) + + ctx = jaffle_shop_duckdb_context + ctx.load() + assert '"jaffle_shop"."main"."agg_orders"' in ctx.models + + selector = ctx._new_selector() + assert isinstance(selector, DbtSelector) + + sqlmesh_selector = selectors.to_sqlmesh(dbt_select=dbt_select, dbt_exclude=dbt_exclude) + assert sqlmesh_selector + + assert selector.expand_model_selections([sqlmesh_selector]) == expected + + +@pytest.mark.parametrize( + "input_args,expected", + [ + ( + dict(select=["jaffle_shop"], models=["jaffle_shop"]), + '"models" and "select" are mutually exclusive', + ), + ( + dict(models=["jaffle_shop"], resource_type="test"), + '"models" and "resource_type" are mutually exclusive', + ), + ( + dict(select=["jaffle_shop"], resource_type="test"), + (["resource_type:test,jaffle_shop"], []), + ), + (dict(resource_type="model"), (["resource_type:model"], [])), + (dict(models=["stg_customers"]), (["resource_type:model,stg_customers"], [])), + ( + dict(models=["stg_customers"], exclude=["orders"]), + (["resource_type:model,stg_customers"], ["orders"]), + ), + ], +) +def test_consolidate(input_args: t.Dict[str, t.Any], expected: t.Union[t.Tuple[str, str], str]): + all_input_args: t.Dict[str, t.Any] = dict(select=[], exclude=[], models=[], resource_type=None) + + all_input_args.update(input_args) + + def _do_assert(): + assert selectors.consolidate(**all_input_args) == expected + + if isinstance(expected, str): + with pytest.raises(ValueError, match=expected): + _do_assert() + else: + _do_assert() + + +def test_models_by_dbt_names(jaffle_shop_duckdb_context: Context): + ctx = jaffle_shop_duckdb_context + + selector = ctx._new_selector() + assert isinstance(selector, DbtSelector) + + selector_expr = selectors.to_sqlmesh( + *selectors.consolidate(select=[], exclude=[], models=["jaffle_shop"], resource_type=None) + ) + assert selector_expr + + assert selector.expand_model_selections([selector_expr]) == { + '"jaffle_shop"."main"."customers"', + 
'"jaffle_shop"."main"."orders"', + '"jaffle_shop"."main"."stg_customers"', + '"jaffle_shop"."main"."stg_orders"', + '"jaffle_shop"."main"."stg_payments"', + } diff --git a/tests/dbt/conftest.py b/tests/dbt/conftest.py index b71b3db8cb..5e6444c8e6 100644 --- a/tests/dbt/conftest.py +++ b/tests/dbt/conftest.py @@ -1,17 +1,106 @@ from __future__ import annotations import typing as t +import os +from pathlib import Path import pytest from sqlmesh.core.context import Context +from sqlmesh.core.selector import DbtSelector from sqlmesh.dbt.context import DbtContext from sqlmesh.dbt.project import Project +from sqlmesh.dbt.target import PostgresConfig +from sqlmesh_dbt.operations import init_project_if_required +import uuid + + +class EmptyProjectCreator(t.Protocol): + def __call__( + self, + project_name: t.Optional[str] = None, + target_name: t.Optional[str] = None, + start: t.Optional[str] = None, + ) -> t.Tuple[Path, Path]: ... @pytest.fixture() def sushi_test_project(sushi_test_dbt_context: Context) -> Project: - return sushi_test_dbt_context._loader._load_projects()[0] # type: ignore + return sushi_test_dbt_context._loaders[0]._load_projects()[0] # type: ignore + + +@pytest.fixture +def create_empty_project( + copy_to_temp_path: t.Callable[..., t.List[Path]], +) -> t.Iterable[EmptyProjectCreator]: + default_project_name = f"test_{str(uuid.uuid4())[:8]}" + default_target_name = "duckdb" + fixture_path = Path(__file__).parent.parent / "fixtures" / "dbt" / "empty_project" + assert fixture_path.exists() + + current_path = os.getcwd() + + def _create_empty_project( + project_name: t.Optional[str] = None, + target_name: t.Optional[str] = None, + start: t.Optional[str] = None, + ) -> t.Tuple[Path, Path]: + project_name = project_name or default_project_name + target_name = target_name or default_target_name + output_path = copy_to_temp_path(paths=fixture_path)[0] + + dbt_project_yml = output_path / "dbt_project.yml" + profiles_yml = output_path / "profiles.yml" + + assert 
dbt_project_yml.exists() + assert profiles_yml.exists() + + models_path = output_path / "models" + (models_path).mkdir() + (output_path / "seeds").mkdir() + + dbt_project_yml.write_text( + dbt_project_yml.read_text().replace("empty_project", project_name) + ) + profiles_yml.write_text( + profiles_yml.read_text() + .replace("empty_project", project_name) + .replace("__DEFAULT_TARGET__", target_name) + ) + + init_project_if_required(output_path, start) + + # so that we can invoke commands from the perspective of a user that is already in the correct directory + os.chdir(output_path) + + return output_path, models_path + + yield _create_empty_project + + # cleanup - switch cwd back to original + os.chdir(current_path) + + +@pytest.fixture +def jaffle_shop_duckdb(copy_to_temp_path: t.Callable[..., t.List[Path]]) -> t.Iterable[Path]: + fixture_path = Path(__file__).parent.parent / "fixtures" / "dbt" / "jaffle_shop_duckdb" + assert fixture_path.exists() + + current_path = os.getcwd() + output_path = copy_to_temp_path(paths=fixture_path)[0] + + # so that we can invoke commands from the perspective of a user that is alrady in the correct directory + os.chdir(output_path) + + yield output_path + + os.chdir(current_path) + + +@pytest.fixture +def jaffle_shop_duckdb_context(jaffle_shop_duckdb: Path) -> Context: + init_project_if_required(jaffle_shop_duckdb) + return Context(paths=[jaffle_shop_duckdb], selector=DbtSelector) @pytest.fixture() @@ -25,3 +114,39 @@ def render(value: str) -> str: return render return create_renderer + + +@pytest.fixture() +def dbt_dummy_postgres_config() -> PostgresConfig: + return PostgresConfig( # type: ignore + name="postgres", + host="host", + user="user", + password="password", + dbname="dbname", + port=5432, + schema="schema", + ) + + +@pytest.fixture(scope="function", autouse=True) +def reset_dbt_globals(): + # This fixture is used to clear the memoized cache for _get_package_with_retries + # in dbt.clients.registry. 
This is necessary because the cache is shared across + # tests and can cause unexpected behavior if not cleared as some tests depend on + # the deprecation warning that _get_package_with_retries fires + yield + # https://github.com/dbt-labs/dbt-core/blob/main/tests/functional/conftest.py#L9 + try: + from dbt.clients.registry import _get_cached + + _get_cached.cache = {} + except Exception: + pass + # https://github.com/dbt-labs/dbt-core/blob/main/core/dbt/tests/util.py#L82 + try: + from dbt_common.events.functions import reset_metadata_vars + + reset_metadata_vars() + except Exception: + pass diff --git a/tests/dbt/test_adapter.py b/tests/dbt/test_adapter.py index 320a64be05..5570212668 100644 --- a/tests/dbt/test_adapter.py +++ b/tests/dbt/test_adapter.py @@ -1,6 +1,5 @@ from __future__ import annotations -import os import typing as t from unittest import mock from unittest.mock import call @@ -11,17 +10,20 @@ from pytest_mock.plugin import MockerFixture from sqlglot import exp, parse_one -from sqlmesh import Context from sqlmesh.core.dialect import schema_ from sqlmesh.core.snapshot import SnapshotId +from sqlmesh.dbt.adapter import ParsetimeAdapter from sqlmesh.dbt.project import Project from sqlmesh.dbt.relation import Policy -from sqlmesh.dbt.target import SnowflakeConfig +from sqlmesh.dbt.target import BigQueryConfig, SnowflakeConfig from sqlmesh.utils.errors import ConfigError +from sqlmesh.utils.jinja import JinjaMacroRegistry +from sqlmesh.core.schema_diff import SchemaDiffer, TableAlterChangeColumnTypeOperation pytestmark = pytest.mark.dbt +@pytest.mark.slow def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Callable): context = sushi_test_project.context assert context.target @@ -31,24 +33,40 @@ def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Calla engine_adapter.create_schema("foo") engine_adapter.create_schema("ignored") engine_adapter.create_table( - table_name="foo.bar", columns_to_types={"baz": 
exp.DataType.build("int")} + table_name="foo.bar", target_columns_to_types={"baz": exp.DataType.build("int")} ) engine_adapter.create_table( - table_name="foo.another", columns_to_types={"col": exp.DataType.build("int")} + table_name="foo.another", target_columns_to_types={"col": exp.DataType.build("int")} + ) + engine_adapter.create_view( + view_name="foo.bar_view", query_or_df=t.cast(exp.Query, parse_one("select * from foo.bar")) ) engine_adapter.create_table( - table_name="ignored.ignore", columns_to_types={"col": exp.DataType.build("int")} + table_name="ignored.ignore", target_columns_to_types={"col": exp.DataType.build("int")} ) assert ( renderer("{{ adapter.get_relation(database=None, schema='foo', identifier='bar') }}") == '"memory"."foo"."bar"' ) + + assert ( + renderer("{{ adapter.get_relation(database=None, schema='foo', identifier='bar').type }}") + == "table" + ) + + assert ( + renderer( + "{{ adapter.get_relation(database=None, schema='foo', identifier='bar_view').type }}" + ) + == "view" + ) + assert renderer( "{%- set relation = adapter.get_relation(database=None, schema='foo', identifier='bar') -%} {{ adapter.get_columns_in_relation(relation) }}" ) == str([Column.from_description(name="baz", raw_data_type="INT")]) - assert renderer("{{ adapter.list_relations(database=None, schema='foo')|length }}") == "2" + assert renderer("{{ adapter.list_relations(database=None, schema='foo')|length }}") == "3" assert renderer( """ @@ -65,31 +83,88 @@ def test_adapter_relation(sushi_test_project: Project, runtime_renderer: t.Calla == "[]" ) + renderer(""" + {%- set old_relation = adapter.get_relation( + database=None, + schema='foo', + identifier='bar') -%} + + {%- set backup_relation = api.Relation.create(schema='foo', identifier='bar__backup') -%} + + {% do adapter.rename_relation(old_relation, backup_relation) %} + """) + assert not engine_adapter.table_exists("foo.bar") + assert engine_adapter.table_exists("foo.bar__backup") + + +@pytest.mark.slow +def 
test_bigquery_get_columns_in_relation( + sushi_test_project: Project, + runtime_renderer: t.Callable, + mocker: MockerFixture, +): + from dbt.adapters.bigquery import BigQueryColumn + from google.cloud.bigquery import SchemaField + + context = sushi_test_project.context + context.target = BigQueryConfig(name="test", schema="test", database="test") + + adapter_mock = mocker.MagicMock() + adapter_mock.default_catalog = "test" + adapter_mock.dialect = "bigquery" + table_schema = [ + SchemaField(name="id", field_type="STRING", mode="REQUIRED"), + SchemaField( + name="user_data", + field_type="RECORD", + mode="NULLABLE", + fields=[ + SchemaField(name="id", field_type="STRING", mode="REQUIRED"), + SchemaField(name="name", field_type="STRING", mode="REQUIRED"), + SchemaField(name="address", field_type="STRING", mode="NULLABLE"), + ], + ), + SchemaField(name="tags", field_type="STRING", mode="REPEATED"), + SchemaField(name="score", field_type="NUMERIC", mode="NULLABLE"), + SchemaField(name="created_at", field_type="TIMESTAMP", mode="NULLABLE"), + ] + adapter_mock.get_bq_schema.return_value = table_schema + renderer = runtime_renderer(context, engine_adapter=adapter_mock, dialect="bigquery") + assert renderer( + "{%- set relation = api.Relation.create(database='test', schema='test', identifier='test_table') -%}" + "{{ adapter.get_columns_in_relation(relation) }}" + ) == str([BigQueryColumn.create_from_field(field) for field in table_schema]) + @pytest.mark.cicdonly +@pytest.mark.slow def test_normalization( sushi_test_project: Project, runtime_renderer: t.Callable, mocker: MockerFixture ): + from sqlmesh.core.engine_adapter.base import DataObject, DataObjectType + context = sushi_test_project.context assert context.target + data_object = DataObject(catalog="test", schema="bla", name="bob", type=DataObjectType.TABLE) # bla and bob will be normalized to lowercase since the target is duckdb adapter_mock = mocker.MagicMock() adapter_mock.default_catalog = "test" 
adapter_mock.dialect = "duckdb" - + adapter_mock.get_data_object.return_value = data_object duckdb_renderer = runtime_renderer(context, engine_adapter=adapter_mock) schema_bla = schema_("bla", "test", quoted=True) relation_bla_bob = exp.table_("bob", db="bla", catalog="test", quoted=True) duckdb_renderer("{{ adapter.get_relation(database=None, schema='bla', identifier='bob') }}") - adapter_mock.table_exists.assert_has_calls([call(relation_bla_bob)]) + adapter_mock.get_data_object.assert_has_calls([call(relation_bla_bob)]) # bla and bob will be normalized to uppercase since the target is Snowflake, even though the default dialect is duckdb adapter_mock = mocker.MagicMock() adapter_mock.default_catalog = "test" adapter_mock.dialect = "snowflake" + adapter_mock.get_data_object.return_value = data_object context.target = SnowflakeConfig( account="test", user="test", @@ -104,10 +179,10 @@ def test_normalization( relation_bla_bob = exp.table_("bob", db="bla", catalog="test", quoted=True) renderer("{{ adapter.get_relation(database=None, schema='bla', identifier='bob') }}") - adapter_mock.table_exists.assert_has_calls([call(relation_bla_bob)]) + adapter_mock.get_data_object.assert_has_calls([call(relation_bla_bob)]) renderer("{{ adapter.get_relation(database='custom_db', schema='bla', identifier='bob') }}") - adapter_mock.table_exists.assert_has_calls( + adapter_mock.get_data_object.assert_has_calls( [call(exp.table_("bob", db="bla", catalog="custom_db", quoted=True))] ) @@ -160,18 +235,49 @@ def test_normalization( adapter_mock.drop_table.assert_has_calls([call(relation_bla_bob)]) +@pytest.mark.slow def test_adapter_dispatch(sushi_test_project: Project, runtime_renderer: t.Callable): context = sushi_test_project.context renderer = runtime_renderer(context) assert renderer("{{ adapter.dispatch('current_engine', 'customers')() }}") == "duckdb" assert renderer("{{ adapter.dispatch('current_timestamp')() }}") == "now()" assert renderer("{{ 
adapter.dispatch('current_timestamp', 'dbt')() }}") == "now()" + assert renderer("{{ adapter.dispatch('select_distinct', 'customers')() }}") == "distinct" + + # test with keyword arguments + assert ( + renderer( + "{{ adapter.dispatch(macro_name='current_engine', macro_namespace='customers')() }}" + ) + == "duckdb" + ) + assert renderer("{{ adapter.dispatch(macro_name='current_timestamp')() }}") == "now()" + assert ( + renderer("{{ adapter.dispatch(macro_name='current_timestamp', macro_namespace='dbt')() }}") + == "now()" + ) + + # mixing positional and keyword arguments + assert ( + renderer("{{ adapter.dispatch('current_engine', macro_namespace='customers')() }}") + == "duckdb" + ) + assert ( + renderer("{{ adapter.dispatch('current_timestamp', macro_namespace=None)() }}") == "now()" + ) + assert ( + renderer("{{ adapter.dispatch('current_timestamp', macro_namespace='dbt')() }}") == "now()" + ) + + with pytest.raises(ConfigError, match=r"Macro 'current_engine'.*was not found."): + renderer("{{ adapter.dispatch(macro_name='current_engine')() }}") with pytest.raises(ConfigError, match=r"Macro 'current_engine'.*was not found."): renderer("{{ adapter.dispatch('current_engine')() }}") @pytest.mark.parametrize("project_dialect", ["duckdb", "bigquery"]) +@pytest.mark.slow def test_adapter_map_snapshot_tables( sushi_test_project: Project, runtime_renderer: t.Callable, @@ -204,10 +310,10 @@ def test_adapter_map_snapshot_tables( engine_adapter.create_schema("sqlmesh") engine_adapter.create_table( table_name='"memory"."sqlmesh"."test_db__test_model"', - columns_to_types={"baz": exp.DataType.build("int")}, + target_columns_to_types={"baz": exp.DataType.build("int")}, ) engine_adapter.create_table( - table_name="foo.bar", columns_to_types={"col": exp.DataType.build("int")} + table_name="foo.bar", target_columns_to_types={"col": exp.DataType.build("int")} ) expected_test_model_table_name = parse_one('"memory"."sqlmesh"."test_db__test_model"').sql( @@ -237,23 +343,124 @@ def 
test_adapter_map_snapshot_tables( assert renderer("{{ adapter.resolve_identifier(foo_bar) }}") == "bar" -def test_feature_flag_scd_type_2(copy_to_temp_path, caplog): - project_root = "tests/fixtures/dbt/sushi_test" - sushi_context = Context(paths=copy_to_temp_path(project_root)) - assert '"memory"."snapshots"."items_snapshot"' in sushi_context.models - assert ( - "Skipping loading Snapshot (SCD Type 2) models due to the feature flag disabling this feature" - not in caplog.text - ) - with mock.patch.dict( - os.environ, - { - "SQLMESH__FEATURE_FLAGS__DBT__SCD_TYPE_2_SUPPORT": "false", - }, +def test_quote_as_configured(): + adapter = ParsetimeAdapter( + JinjaMacroRegistry(), + project_dialect="duckdb", + quote_policy=Policy(schema=False, identifier=True), + ) + adapter.quote_as_configured("foo", "identifier") == '"foo"' + adapter.quote_as_configured("foo", "schema") == "foo" + adapter.quote_as_configured("foo", "database") == "foo" + + +@pytest.mark.slow +def test_adapter_get_relation_normalization( + sushi_test_project: Project, runtime_renderer: t.Callable +): + # Simulate that the quote policy is set to quote everything to make + # sure that we normalize correctly even if quotes are applied + with mock.patch.object( + SnowflakeConfig, + "quote_policy", + Policy(identifier=True, schema=True, database=True), ): - sushi_context = Context(paths=copy_to_temp_path(project_root)) - assert '"memory"."snapshots"."items_snapshot"' not in sushi_context.models + context = sushi_test_project.context + assert context.target + engine_adapter = context.target.to_sqlmesh().create_engine_adapter() + engine_adapter._default_catalog = '"memory"' + renderer = runtime_renderer(context, engine_adapter=engine_adapter, dialect="snowflake") + + engine_adapter.create_schema('"FOO"') + engine_adapter.create_table( + table_name='"FOO"."BAR"', target_columns_to_types={"baz": exp.DataType.build("int")} + ) + assert ( - "Skipping loading Snapshot (SCD Type 2) models due to the feature flag 
disabling this feature" - in caplog.text + renderer("{{ adapter.get_relation(database=None, schema='foo', identifier='bar') }}") + == '"memory"."FOO"."BAR"' ) + + assert ( + renderer("{{ adapter.list_relations(database=None, schema='foo') }}") + == '[]' + ) + + +@pytest.mark.slow +def test_adapter_expand_target_column_types( + sushi_test_project: Project, runtime_renderer: t.Callable, mocker: MockerFixture +): + from sqlmesh.core.engine_adapter.base import DataObject, DataObjectType + + data_object_from = DataObject( + catalog="test", schema="foo", name="from_table", type=DataObjectType.TABLE + ) + data_object_to = DataObject( + catalog="test", schema="foo", name="to_table", type=DataObjectType.TABLE + ) + from_columns = { + "int_col": exp.DataType.build("int"), + "same_text_col": exp.DataType.build("varchar(1)"), # varchar(1) -> varchar(1) + "unexpandable_text_col": exp.DataType.build("varchar(2)"), # varchar(4) -> varchar(2) + "expandable_text_col1": exp.DataType.build("varchar(16)"), # varchar(8) -> varchar(16) + "expandable_text_col2": exp.DataType.build("varchar(64)"), # varchar(32) -> varchar(64) + } + to_columns = { + "int_col": exp.DataType.build("int"), + "same_text_col": exp.DataType.build("varchar(1)"), + "unexpandable_text_col": exp.DataType.build("varchar(4)"), + "expandable_text_col1": exp.DataType.build("varchar(8)"), + "expandable_text_col2": exp.DataType.build("varchar(32)"), + } + adapter_mock = mocker.MagicMock() + adapter_mock.default_catalog = "test" + adapter_mock.get_data_object.side_effect = [data_object_from, data_object_to] + # columns() is called 4 times, twice by adapter.get_columns_in_relation() and twice by the engine_adapter + adapter_mock.columns.side_effect = [ + from_columns, + to_columns, + from_columns, + to_columns, + ] + adapter_mock.schema_differ = SchemaDiffer() + + context = sushi_test_project.context + renderer = runtime_renderer(context, engine_adapter=adapter_mock) + + renderer(""" + {%- set from_relation = 
adapter.get_relation( + database=None, + schema='foo', + identifier='from_table') -%} + + {% set to_relation = adapter.get_relation( + database=None, + schema='foo', + identifier='to_table') -%} + + {% do adapter.expand_target_column_types(from_relation, to_relation) %} + """) + adapter_mock.get_data_object.assert_has_calls( + [ + call(exp.to_table('"test"."foo"."from_table"')), + call(exp.to_table('"test"."foo"."to_table"')), + ] + ) + assert len(adapter_mock.alter_table.call_args.args) == 1 + alter_expressions = adapter_mock.alter_table.call_args.args[0] + assert len(alter_expressions) == 2 + alter_operation1 = alter_expressions[0] + assert isinstance(alter_operation1, TableAlterChangeColumnTypeOperation) + assert alter_operation1.expression == parse_one( + """ALTER TABLE "test"."foo"."to_table" + ALTER COLUMN expandable_text_col1 + SET DATA TYPE VARCHAR(16)""" + ) + alter_operation2 = alter_expressions[1] + assert isinstance(alter_operation2, TableAlterChangeColumnTypeOperation) + assert alter_operation2.expression == parse_one( + """ALTER TABLE "test"."foo"."to_table" + ALTER COLUMN expandable_text_col2 + SET DATA TYPE VARCHAR(64)""" + ) diff --git a/tests/dbt/test_config.py b/tests/dbt/test_config.py index bbd4712b5b..5dccd90ed2 100644 --- a/tests/dbt/test_config.py +++ b/tests/dbt/test_config.py @@ -6,14 +6,27 @@ import pytest from dbt.adapters.base import BaseRelation, Column from pytest_mock import MockerFixture + +from sqlglot import exp + +from sqlmesh import Context +from sqlmesh.core.audit import StandaloneAudit from sqlmesh.core.config import Config, ModelDefaultsConfig from sqlmesh.core.dialect import jinja_query from sqlmesh.core.model import SqlModel -from sqlmesh.core.model.kind import OnDestructiveChange +from sqlmesh.core.model.kind import OnDestructiveChange, OnAdditiveChange +from sqlmesh.core.state_sync import CachingStateSync, EngineAdapterStateSync +from sqlmesh.dbt.builtin import Api +from sqlmesh.dbt.column import ColumnConfig from 
sqlmesh.dbt.common import Dependencies from sqlmesh.dbt.context import DbtContext from sqlmesh.dbt.loader import sqlmesh_config -from sqlmesh.dbt.model import IncrementalByUniqueKeyKind, Materialization, ModelConfig +from sqlmesh.dbt.model import ( + IncrementalByTimeRangeKind, + IncrementalByUniqueKeyKind, + Materialization, + ModelConfig, +) from sqlmesh.dbt.project import Project from sqlmesh.dbt.relation import Policy from sqlmesh.dbt.source import SourceConfig @@ -28,10 +41,14 @@ SnowflakeConfig, TargetConfig, TrinoConfig, + AthenaConfig, + ClickhouseConfig, + SCHEMA_DIFFER_OVERRIDES, ) from sqlmesh.dbt.test import TestConfig from sqlmesh.utils.errors import ConfigError -from sqlmesh.utils.yaml import load as yaml_load +from sqlmesh.utils.yaml import load as yaml_load, dump as yaml_dump +from tests.dbt.conftest import EmptyProjectCreator pytestmark = pytest.mark.dbt @@ -74,10 +91,12 @@ def test_update(current: t.Dict[str, t.Any], new: t.Dict[str, t.Any], expected: assert {k: v for k, v in config.dict().items() if k in expected} == expected -def test_model_to_sqlmesh_fields(): +def test_model_to_sqlmesh_fields(dbt_dummy_postgres_config: PostgresConfig): model_config = ModelConfig( + unique_id="model.package.name", name="name", package_name="package", + fqn=["package", "name"], alias="model", schema="custom", database="database", @@ -87,6 +106,10 @@ def test_model_to_sqlmesh_fields(): start="Jan 1 2023", partition_by=["a"], cluster_by=["a", '"b"'], + incremental_predicates=[ + "55 > DBT_INTERNAL_SOURCE.b", + "DBT_INTERNAL_DEST.session_start > date_add(current_date, interval 7 day)", + ], cron="@hourly", interval_unit="FIVE_MINUTE", batch_size=5, @@ -99,29 +122,37 @@ def test_model_to_sqlmesh_fields(): ) context = DbtContext() context.project_name = "Foo" - context.target = DuckDbConfig(name="target", schema="foo") + context.target = dbt_dummy_postgres_config model = model_config.to_sqlmesh(context) assert isinstance(model, SqlModel) assert model.name == 
"database.custom.model" + assert model.dbt_unique_id == "model.package.name" + assert model.dbt_fqn == "package.name" assert model.description == "test model" assert ( model.render_query_or_raise().sql() - == 'SELECT 1 AS "a" FROM "memory"."foo"."table" AS "table"' + == 'SELECT 1 AS "a" FROM "dbname"."foo"."table" AS "table"' ) assert model.start == "Jan 1 2023" assert [col.sql() for col in model.partitioned_by] == ['"a"'] - assert model.clustered_by == ["a", "b"] + assert [col.sql() for col in model.clustered_by] == ['"a"', '"b"'] assert model.cron == "@hourly" assert model.interval_unit.value == "five_minute" assert model.stamp == "bar" - assert model.dialect == "duckdb" + assert model.dialect == "postgres" assert model.owner == "Sally" assert model.tags == ["test", "incremental"] + assert model.allow_partials kind = t.cast(IncrementalByUniqueKeyKind, model.kind) assert kind.batch_size == 5 assert kind.lookback == 3 assert kind.on_destructive_change == OnDestructiveChange.ALLOW + assert kind.on_additive_change == OnAdditiveChange.ALLOW + assert ( + kind.merge_filter.sql(dialect=model.dialect) # type: ignore + == """55 > "__MERGE_SOURCE__"."b" AND "__MERGE_TARGET__"."session_start" > CURRENT_DATE + INTERVAL '7'""" + ) model = model_config.update_with({"dialect": "snowflake"}).to_sqlmesh(context) assert model.dialect == "snowflake" @@ -130,15 +161,39 @@ def test_model_to_sqlmesh_fields(): sqlmesh_config=Config(model_defaults=ModelDefaultsConfig(dialect="bigquery")) ) bq_default_context.project_name = "Foo" - bq_default_context.target = DuckDbConfig(name="target", schema="foo") + bq_default_context.target = dbt_dummy_postgres_config + model_config.cluster_by = ["a", "`b`"] model = model_config.to_sqlmesh(bq_default_context) assert model.dialect == "bigquery" + model_config = ModelConfig( + name="name", + package_name="package", + alias="model", + schema="custom", + database="database", + materialized=Materialization.INCREMENTAL, + sql="SELECT * FROM foo.table", + 
time_column="ds", + start="Jan 1 2023", + batch_size=5, + batch_concurrency=2, + on_schema_change="ignore", + ) + model = model_config.to_sqlmesh(context) + assert isinstance(model.kind, IncrementalByTimeRangeKind) + assert model.kind.batch_concurrency == 2 + assert model.kind.time_column.column.name == "ds" + assert model.kind.on_destructive_change == OnDestructiveChange.IGNORE + assert model.kind.on_additive_change == OnAdditiveChange.IGNORE + def test_test_to_sqlmesh_fields(): sql = "SELECT * FROM FOO WHERE cost > 100" test_config = TestConfig( + unique_id="test.test_package.foo_test", name="foo_test", + fqn=["test_package", "foo_test"], sql=sql, model_name="Foo", column_name="cost", @@ -152,6 +207,8 @@ def test_test_to_sqlmesh_fields(): audit = test_config.to_sqlmesh(context) assert audit.name == "foo_test" + assert audit.dbt_unique_id == "test.test_package.foo_test" + assert audit.dbt_fqn == "test_package.foo_test" assert audit.dialect == "duckdb" assert not audit.skip assert audit.blocking @@ -190,7 +247,32 @@ def test_test_to_sqlmesh_fields(): assert audit.dialect == "bigquery" -def test_singular_test_to_standalone_audit(): +def test_test_config_canonical_name(): + test_config_upper_case_package = TestConfig( + name="foo_test", + package_name="TEST_PACKAGE", + sql="SELECT 1", + ) + + assert test_config_upper_case_package.canonical_name == "test_package.foo_test" + + test_config_mixed_case_package = TestConfig( + name="Bar_Test", + package_name="MixedCase_Package", + sql="SELECT 1", + ) + + assert test_config_mixed_case_package.canonical_name == "mixedcase_package.bar_test" + + test_config_no_package = TestConfig( + name="foo_bar_test", + sql="SELECT 1", + ) + + assert test_config_no_package.canonical_name == "foo_bar_test" + + +def test_singular_test_to_standalone_audit(dbt_dummy_postgres_config: PostgresConfig): sql = "SELECT * FROM FOO.BAR WHERE cost > 100" test_config = TestConfig( name="bar_test", @@ -212,8 +294,8 @@ def 
test_singular_test_to_standalone_audit(): context = DbtContext() context.add_models({model.name: model}) context._project_name = "Foo" - context.target = DuckDbConfig(name="target", schema="foo") - standalone_audit = test_config.to_sqlmesh(context) + context.target = dbt_dummy_postgres_config + standalone_audit = t.cast(StandaloneAudit, test_config.to_sqlmesh(context)) assert standalone_audit.name == "bar_test" assert standalone_audit.description == "test description" @@ -221,51 +303,16 @@ def test_singular_test_to_standalone_audit(): assert standalone_audit.stamp == "bump" assert standalone_audit.cron == "@monthly" assert standalone_audit.interval_unit.value == "day" - assert standalone_audit.dialect == "duckdb" + assert standalone_audit.dialect == "postgres" assert standalone_audit.query == jinja_query(sql) - assert standalone_audit.depends_on == {'"memory"."foo"."bar"'} + assert standalone_audit.depends_on == {'"dbname"."foo"."bar"'} test_config.dialect_ = "bigquery" - standalone_audit = test_config.to_sqlmesh(context) + standalone_audit = t.cast(StandaloneAudit, test_config.to_sqlmesh(context)) assert standalone_audit.dialect == "bigquery" -def test_model_config_sql_no_config(): - assert ( - ModelConfig( - sql="""{{ - config( - materialized='table', - incremental_strategy='delete+"insert' - ) -}} -query""" - ).sql_no_config.strip() - == "query" - ) - - assert ( - ModelConfig( - sql="""{{ - config( - materialized='table', - incremental_strategy='delete+insert', - post_hook=" '{{ var('new') }}' " - ) -}} -query""" - ).sql_no_config.strip() - == "query" - ) - - assert ( - ModelConfig( - sql="""before {{config(materialized='table', post_hook=" {{ var('new') }} ")}} after""" - ).sql_no_config.strip() - == "before after" - ) - - +@pytest.mark.slow def test_variables(assert_exp_eq, sushi_test_project): # Case 1: using an undefined variable without a default value defined_variables = {} @@ -303,7 +350,6 @@ def test_variables(assert_exp_eq, sushi_test_project): # Case 
3: using a defined variable with a default value model_config.sql = "SELECT {{ var('foo', 5) }}" - model_config._sql_no_config = None assert_exp_eq(model_config.to_sqlmesh(**kwargs).render_query(), 'SELECT 6 AS "6"') @@ -315,34 +361,86 @@ def test_variables(assert_exp_eq, sushi_test_project): # Finally, check that variable scoping & overwriting (some_var) works as expected expected_sushi_variables = { - "start": "Jan 1 2022", "yet_another_var": 1, - "top_waiters:limit": 10, + "top_waiters:limit": "{{ get_top_waiters_limit() }}", "top_waiters:revenue": "revenue", "customers:boo": ["a", "b"], + "nested_vars": { + "some_nested_var": 2, + }, + "dynamic_test_var": 3, + "list_var": [ + {"name": "item1", "value": 1}, + {"name": "item2", "value": 2}, + ], + "customers": { + "customers:bla": False, + "customers:customer_id": "customer_id", + "some_var": ["foo", "bar"], + }, + "some_var": "should be overridden in customers package", + "invalid_var": "{{ ref('ref_without_closing_paren' }}", } expected_customer_variables = { - "some_var": ["foo", "bar"], + "some_var": ["foo", "bar"], # Takes precedence over the root project variable "some_other_var": 5, - "yet_another_var": 1, "customers:bla": False, "customers:customer_id": "customer_id", - "start": "Jan 1 2022", - "top_waiters:limit": 10, + "yet_another_var": 1, # Make sure that the project variable takes precedence + "top_waiters:limit": "{{ get_top_waiters_limit() }}", "top_waiters:revenue": "revenue", "customers:boo": ["a", "b"], + "nested_vars": { + "some_nested_var": 2, + }, + "dynamic_test_var": 3, + "list_var": [ + {"name": "item1", "value": 1}, + {"name": "item2", "value": 2}, + ], + "invalid_var": "{{ ref('ref_without_closing_paren' }}", } - assert sushi_test_project.packages["sushi"].variables == expected_sushi_variables assert sushi_test_project.packages["customers"].variables == expected_customer_variables +@pytest.mark.slow +def test_variables_override(init_and_plan_context: t.Callable): + context, _ = 
init_and_plan_context( + "tests/fixtures/dbt/sushi_test", config="test_config_with_var_override" + ) + dbt_project = context._loaders[0]._load_projects()[0] # type: ignore + assert dbt_project.packages["sushi"].variables["some_var"] == "overridden_from_config_py" + assert dbt_project.packages["customers"].variables["some_var"] == "overridden_from_config_py" + + +@pytest.mark.slow +def test_jinja_in_dbt_variables(sushi_test_dbt_context: Context): + assert sushi_test_dbt_context.render("sushi.top_waiters").sql().endswith("LIMIT 10") + + +@pytest.mark.slow +def test_nested_variables(sushi_test_project): + model_config = ModelConfig( + alias="sushi.test_nested", + sql="SELECT {{ var('nested_vars')['some_nested_var'] }}", + dependencies=Dependencies(variables=["nested_vars"]), + ) + context = sushi_test_project.context.copy() + context.set_and_render_variables(sushi_test_project.packages["sushi"].variables, "sushi") + sqlmesh_model = model_config.to_sqlmesh(context) + assert sqlmesh_model.jinja_macros.global_objs["vars"]["nested_vars"] == {"some_nested_var": 2} + + +@pytest.mark.slow def test_source_config(sushi_test_project: Project): source_configs = sushi_test_project.packages["sushi"].sources assert set(source_configs) == { - "streaming.items", "streaming.orders", + "parquet_file.items", "streaming.order_items", + "streaming.items", + "parquet_file.orders", } expected_config = { @@ -359,10 +457,16 @@ def test_source_config(sushi_test_project: Project): == "raw.order_items" ) + assert ( + source_configs["parquet_file.orders"].canonical_name(sushi_test_project.context) + == "read_parquet('path/to/external/orders.parquet')" + ) + +@pytest.mark.slow def test_seed_config(sushi_test_project: Project, mocker: MockerFixture): seed_configs = sushi_test_project.packages["sushi"].seeds - assert set(seed_configs) == {"waiter_names"} + assert set(seed_configs) == {"waiter_names", "waiter_revenue_semicolon"} raw_items_seed = seed_configs["waiter_names"] expected_config = { @@ 
-383,6 +487,25 @@ def test_seed_config(sushi_test_project: Project, mocker: MockerFixture): == '"MEMORY"."SUSHI"."WAITER_NAMES"' ) + waiter_revenue_semicolon_seed = seed_configs["waiter_revenue_semicolon"] + + expected_config_semicolon = { + "path": Path(sushi_test_project.context.project_root, "seeds/waiter_revenue_semicolon.csv"), + "schema_": "sushi", + "delimiter": ";", + } + actual_config_semicolon = { + k: getattr(waiter_revenue_semicolon_seed, k) for k, v in expected_config_semicolon.items() + } + assert actual_config_semicolon == expected_config_semicolon + + assert waiter_revenue_semicolon_seed.canonical_name(context) == "sushi.waiter_revenue_semicolon" + assert ( + waiter_revenue_semicolon_seed.to_sqlmesh(context).name == "sushi.waiter_revenue_semicolon" + ) + assert waiter_revenue_semicolon_seed.delimiter == ";" + assert set(waiter_revenue_semicolon_seed.columns.keys()) == {"waiter_id", "revenue", "quarter"} + def test_quoting(): model = ModelConfig(alias="bar", schema="foo") @@ -439,7 +562,7 @@ def test_duckdb_threads(tmp_path): def test_snowflake_config(): - _test_warehouse_config( + config = _test_warehouse_config( """ sushi: target: dev @@ -461,6 +584,11 @@ def test_snowflake_config(): "outputs", "dev", ) + sqlmesh_config = config.to_sqlmesh() + assert sqlmesh_config.application == "Tobiko_SQLMesh" + assert ( + sqlmesh_config.schema_differ_overrides == SCHEMA_DIFFER_OVERRIDES["schema_differ_overrides"] + ) def test_snowflake_config_private_key_path(): @@ -573,6 +701,28 @@ def test_postgres_config(): "outputs", "dev", ) + # 'pass' field instead of 'password' + _test_warehouse_config( + """ + dbt-postgres: + target: dev + outputs: + dev: + type: postgres + host: postgres + user: postgres + pass: postgres + port: 5432 + dbname: postgres + schema: demo + threads: 3 + keepalives_idle: 0 + """, + PostgresConfig, + "dbt-postgres", + "outputs", + "dev", + ) def test_redshift_config(): @@ -597,6 +747,28 @@ def test_redshift_config(): "outputs", "dev", ) + # 
'pass' field instead of 'password' + _test_warehouse_config( + """ + dbt-redshift: + target: dev + outputs: + dev: + type: redshift + host: hostname.region.redshift.amazonaws.com + user: username + pass: password1 + port: 5439 + dbname: analytics + schema: analytics + threads: 4 + ra3_node: false + """, + RedshiftConfig, + "dbt-redshift", + "outputs", + "dev", + ) def test_databricks_config(): @@ -620,6 +792,35 @@ def test_databricks_config(): ) +def test_databricks_config_oauth(): + config = _test_warehouse_config( + """ + dbt-databricks: + target: dev + outputs: + dev: + type: databricks + catalog: test_catalog + schema: analytics + host: yourorg.databrickshost.com + http_path: /sql/your/http/path + auth_type: oauth + client_id: client-id + client_secret: client-secret + """, + DatabricksConfig, + "dbt-databricks", + "outputs", + "dev", + ) + + as_sqlmesh = config.to_sqlmesh() + assert as_sqlmesh.auth_type == "databricks-oauth" + assert as_sqlmesh.oauth_client_id == "client-id" + assert as_sqlmesh.oauth_client_secret == "client-secret" + assert as_sqlmesh.schema_differ_overrides == SCHEMA_DIFFER_OVERRIDES["schema_differ_overrides"] + + def test_bigquery_config(): _test_warehouse_config( """ @@ -728,47 +929,131 @@ def test_trino_config(): ) +def test_athena_config(): + _test_warehouse_config( + """ + dbt-athena: + target: dev + outputs: + dev: + type: athena + s3_staging_dir: s3://athena-query-results/dbt/ + s3_data_dir: s3://your_s3_bucket/dbt/ + s3_data_naming: schema_table + s3_tmp_table_dir: s3://your_s3_bucket/temp/ + region_name: eu-west-1 + schema: dbt + database: awsdatacatalog + threads: 4 + aws_profile_name: my-profile + work_group: my-workgroup + spark_work_group: my-spark-workgroup + seed_s3_upload_args: + ACL: bucket-owner-full-control + """, + AthenaConfig, + "dbt-athena", + "outputs", + "dev", + ) + + +def test_clickhouse_config(): + _test_warehouse_config( + """ + dbt-clickhouse: + target: dev + outputs: + dev: + type: clickhouse + host: thehost + 
user: theuser + password: thepassword + port: 1234 + secure: true + cluster: thecluster + connect_timeout: 1 + send_receive_timeout: 2 + verify: false + compression: lz4 + custom_settings: + setting: value + + """, + ClickhouseConfig, + "dbt-clickhouse", + "outputs", + "dev", + ) + + def test_connection_args(tmp_path): dbt_project_dir = "tests/fixtures/dbt/sushi_test" config = sqlmesh_config(dbt_project_dir) + assert not config.gateways["in_memory"].connection.register_comments + + config = sqlmesh_config(dbt_project_dir, register_comments=True) assert config.gateways["in_memory"].connection.register_comments - config = sqlmesh_config(dbt_project_dir, register_comments=False) - assert not config.gateways["in_memory"].connection.register_comments + +def test_custom_dbt_loader(): + from sqlmesh.core.loader import SqlMeshLoader + from sqlmesh.dbt.loader import DbtLoader + + dbt_project_dir = "tests/fixtures/dbt/sushi_test" + with pytest.raises(ConfigError, match="The loader must be a DbtLoader."): + sqlmesh_config(dbt_project_dir, loader=SqlMeshLoader) + + class CustomDbtLoader(DbtLoader): + pass + + sqlmesh_config(dbt_project_dir, loader=CustomDbtLoader) @pytest.mark.cicdonly +@pytest.mark.slow def test_db_type_to_relation_class(): from dbt.adapters.bigquery.relation import BigQueryRelation from dbt.adapters.databricks.relation import DatabricksRelation from dbt.adapters.duckdb.relation import DuckDBRelation from dbt.adapters.redshift import RedshiftRelation from dbt.adapters.snowflake import SnowflakeRelation - from dbt.adapters.trino.relation import TrinoRelation assert (TARGET_TYPE_TO_CONFIG_CLASS["bigquery"].relation_class) == BigQueryRelation assert (TARGET_TYPE_TO_CONFIG_CLASS["databricks"].relation_class) == DatabricksRelation assert (TARGET_TYPE_TO_CONFIG_CLASS["duckdb"].relation_class) == DuckDBRelation assert (TARGET_TYPE_TO_CONFIG_CLASS["redshift"].relation_class) == RedshiftRelation assert (TARGET_TYPE_TO_CONFIG_CLASS["snowflake"].relation_class) == 
SnowflakeRelation + + from dbt.adapters.clickhouse.relation import ClickHouseRelation + from dbt.adapters.trino.relation import TrinoRelation + from dbt.adapters.athena.relation import AthenaRelation + + assert (TARGET_TYPE_TO_CONFIG_CLASS["clickhouse"].relation_class) == ClickHouseRelation assert (TARGET_TYPE_TO_CONFIG_CLASS["trino"].relation_class) == TrinoRelation + assert (TARGET_TYPE_TO_CONFIG_CLASS["athena"].relation_class) == AthenaRelation @pytest.mark.cicdonly +@pytest.mark.slow def test_db_type_to_column_class(): from dbt.adapters.bigquery import BigQueryColumn from dbt.adapters.databricks.column import DatabricksColumn from dbt.adapters.snowflake import SnowflakeColumn - from dbt.adapters.sqlserver.sql_server_column import SQLServerColumn - from dbt.adapters.trino.column import TrinoColumn assert (TARGET_TYPE_TO_CONFIG_CLASS["bigquery"].column_class) == BigQueryColumn assert (TARGET_TYPE_TO_CONFIG_CLASS["databricks"].column_class) == DatabricksColumn assert (TARGET_TYPE_TO_CONFIG_CLASS["duckdb"].column_class) == Column assert (TARGET_TYPE_TO_CONFIG_CLASS["snowflake"].column_class) == SnowflakeColumn - assert (TARGET_TYPE_TO_CONFIG_CLASS["sqlserver"].column_class) == SQLServerColumn + + from dbt.adapters.clickhouse.column import ClickHouseColumn + from dbt.adapters.trino.column import TrinoColumn + from dbt.adapters.athena.column import AthenaColumn + + assert (TARGET_TYPE_TO_CONFIG_CLASS["clickhouse"].column_class) == ClickHouseColumn assert (TARGET_TYPE_TO_CONFIG_CLASS["trino"].column_class) == TrinoColumn + assert (TARGET_TYPE_TO_CONFIG_CLASS["athena"].column_class) == AthenaColumn def test_db_type_to_quote_policy(): @@ -778,7 +1063,187 @@ def test_db_type_to_quote_policy(): def test_variable_override(): project_root = "tests/fixtures/dbt/sushi_test" project = Project.load( - DbtContext(project_root=Path(project_root)), - variables={"yet_another_var": 2, "start": "2021-01-01"}, + DbtContext( + project_root=Path(project_root), + 
sqlmesh_config=Config(model_defaults=ModelDefaultsConfig(start="2021-01-01")), + ), + variables={"yet_another_var": 2}, ) assert project.packages["sushi"].variables["yet_another_var"] == 2 + + +@pytest.mark.slow +def test_depends_on(assert_exp_eq, sushi_test_project): + # Case 1: using an undefined variable without a default value + context = sushi_test_project.context + + model_config = ModelConfig( + alias="sushi.test", + sql="SELECT * FROM {{ ref('waiter_revenue_by_day') }} JOIN other_table", + dependencies=Dependencies(refs=["waiter_revenue_by_day"]), + ) + + sqlmesh_model = model_config.to_sqlmesh(context) + assert sqlmesh_model.depends_on_ == {'"memory"."sushi"."waiter_revenue_by_day_v2"'} + assert sqlmesh_model.depends_on == {'"memory"."sushi"."waiter_revenue_by_day_v2"'} + assert sqlmesh_model.full_depends_on == {'"memory"."sushi"."waiter_revenue_by_day_v2"'} + + # Make sure the query wasn't rendered + assert not sqlmesh_model._query_renderer._cache + + +@pytest.mark.parametrize( + "on_schema_change, expected_additive, expected_destructive", + [ + ("ignore", OnAdditiveChange.IGNORE, OnDestructiveChange.IGNORE), + ("fail", OnAdditiveChange.ERROR, OnDestructiveChange.ERROR), + ("append_new_columns", OnAdditiveChange.ALLOW, OnDestructiveChange.IGNORE), + ("sync_all_columns", OnAdditiveChange.ALLOW, OnDestructiveChange.ALLOW), + ], +) +def test_on_schema_change_properties( + on_schema_change: str, + expected_additive: OnAdditiveChange, + expected_destructive: OnDestructiveChange, +): + model_config = ModelConfig( + name="name", + package_name="package", + alias="model", + schema="custom", + database="database", + materialized=Materialization.INCREMENTAL, + sql="SELECT * FROM foo.table", + time_column="ds", + start="Jan 1 2023", + batch_size=5, + batch_concurrency=2, + on_schema_change=on_schema_change, + ) + context = DbtContext() + context.project_name = "Foo" + context.target = DuckDbConfig(name="target", schema="foo") + model = 
model_config.to_sqlmesh(context) + + assert model.on_additive_change == expected_additive + assert model.on_destructive_change == expected_destructive + + +def test_sqlmesh_model_kwargs_columns_override(): + context = DbtContext() + context.project_name = "Foo" + context.target = DuckDbConfig(name="target", schema="foo") + + kwargs = ModelConfig(dialect="duckdb").sqlmesh_model_kwargs( + context, + {"c": ColumnConfig(name="c", data_type="uinteger")}, + ) + assert kwargs.get("columns") == {"c": exp.DataType.build(exp.DataType.Type.UINT)} + + +@pytest.mark.parametrize( + "dialect", + [ + "databricks", + "duckdb", + "postgres", + "redshift", + "snowflake", + "bigquery", + "trino", + "clickhouse", + ], +) +def test_api_class_loading(dialect: str): + Api(dialect) + + +def test_empty_vars_config(tmp_path): + """Test that a dbt project can be loaded with an empty vars config.""" + dbt_project_dir = tmp_path / "test_project" + dbt_project_dir.mkdir() + + # Create a minimal dbt_project.yml with empty vars + dbt_project_yml = dbt_project_dir / "dbt_project.yml" + dbt_project_yml.write_text(""" +name: test_empty_vars + +version: "1.0.0" +config-version: 2 + +profile: test_empty_vars + +models: + +start: Jan 1 2022 + +# Empty vars section - various ways to specify empty +vars: + """) + + # Create a minimal profiles.yml + profiles_yml = dbt_project_dir / "profiles.yml" + profiles_yml.write_text(""" +test_empty_vars: + outputs: + dev: + type: duckdb + schema: test + target: dev + """) + + # Create a simple model + model = dbt_project_dir / "models" / "some_model.sql" + model.parent.mkdir(parents=True, exist_ok=True) + model.write_text("SELECT 1 as id") + + # Load the project + from sqlmesh.dbt.context import DbtContext + from sqlmesh.dbt.project import Project + from sqlmesh.core.config import Config + + context = DbtContext(project_root=dbt_project_dir, sqlmesh_config=Config()) + + # This should not raise an error even with empty vars + project = Project.load(context) + + # 
Verify the project loaded successfully + assert project.packages["test_empty_vars"] is not None + assert project.packages["test_empty_vars"].name == "test_empty_vars" + + # Verify the variables are empty (not causing any issues) + assert project.packages["test_empty_vars"].variables == {} + assert project.context.variables == {} + + +def test_infer_state_schema_name(create_empty_project: EmptyProjectCreator): + project_dir, _ = create_empty_project("test_foo", "dev") + + # infer_state_schema_name defaults to False if omitted + config = sqlmesh_config(project_root=project_dir) + assert config.dbt + assert not config.dbt.infer_state_schema_name + assert config.get_state_schema() == "sqlmesh" + + # create_empty_project() uses the default dbt template for sqlmesh yaml config which + # sets infer_state_schema_name=True + ctx = Context(paths=[project_dir]) + assert ctx.config.dbt + assert ctx.config.dbt.infer_state_schema_name + assert ctx.config.get_state_schema() == "sqlmesh_state_test_foo_main" + assert isinstance(ctx.state_sync, CachingStateSync) + assert isinstance(ctx.state_sync.state_sync, EngineAdapterStateSync) + assert ctx.state_sync.state_sync.schema == "sqlmesh_state_test_foo_main" + + # If the user delberately overrides state_schema then we should respect this choice + config_file = project_dir / "sqlmesh.yaml" + config_yaml = yaml_load(config_file) + config_yaml["gateways"] = {"dev": {"state_schema": "state_override"}} + config_file.write_text(yaml_dump(config_yaml)) + + ctx = Context(paths=[project_dir]) + assert ctx.config.dbt + assert ctx.config.dbt.infer_state_schema_name + assert ctx.config.get_state_schema() == "state_override" + assert isinstance(ctx.state_sync, CachingStateSync) + assert isinstance(ctx.state_sync.state_sync, EngineAdapterStateSync) + assert ctx.state_sync.state_sync.schema == "state_override" diff --git a/tests/dbt/test_custom_materializations.py b/tests/dbt/test_custom_materializations.py new file mode 100644 index 
0000000000..c1625d0251 --- /dev/null +++ b/tests/dbt/test_custom_materializations.py @@ -0,0 +1,777 @@ +from __future__ import annotations + +import typing as t +from pathlib import Path + +import pytest + +from sqlmesh import Context +from sqlmesh.core.config import ModelDefaultsConfig +from sqlmesh.core.engine_adapter import DuckDBEngineAdapter +from sqlmesh.core.model.kind import DbtCustomKind +from sqlmesh.dbt.context import DbtContext +from sqlmesh.dbt.manifest import ManifestHelper +from sqlmesh.dbt.model import ModelConfig +from sqlmesh.dbt.profile import Profile +from sqlmesh.dbt.basemodel import Materialization + +pytestmark = pytest.mark.dbt + + +@pytest.mark.xdist_group("dbt_manifest") +def test_custom_materialization_manifest_loading(): + project_path = Path("tests/fixtures/dbt/sushi_test") + profile = Profile.load(DbtContext(project_path)) + + helper = ManifestHelper( + project_path, + project_path, + "sushi", + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + materializations = helper.materializations() + + # custom materialization should have loaded from the manifest + assert "custom_incremental_default" in materializations + custom_incremental = materializations["custom_incremental_default"] + assert custom_incremental.name == "custom_incremental" + assert custom_incremental.adapter == "default" + assert "make_temp_relation(new_relation)" in custom_incremental.definition + assert "run_hooks(pre_hooks)" in custom_incremental.definition + assert " {{ return({'relations': [new_relation]}) }}" in custom_incremental.definition + + +@pytest.mark.xdist_group("dbt_manifest") +def test_custom_materialization_model_config(): + project_path = Path("tests/fixtures/dbt/sushi_test") + profile = Profile.load(DbtContext(project_path)) + + helper = ManifestHelper( + project_path, + project_path, + "sushi", + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + + models = helper.models() + + custom_model = 
models["custom_incremental_model"] + assert isinstance(custom_model, ModelConfig) + assert custom_model.materialized == "custom_incremental" + assert custom_model.model_materialization == Materialization.CUSTOM + + # pre and post hooks should also be handled in custom materializations + assert len(custom_model.pre_hook) == 2 + assert ( + custom_model.pre_hook[1].sql + == "CREATE TABLE IF NOT EXISTS hook_table (id INTEGER, length_col TEXT, updated_at TIMESTAMP)" + ) + assert len(custom_model.post_hook) == 2 + assert "COALESCE(MAX(id), 0)" in custom_model.post_hook[1].sql + + custom_filter_model = models["custom_incremental_with_filter"] + assert isinstance(custom_filter_model, ModelConfig) + assert custom_filter_model.materialized == "custom_incremental" + assert custom_filter_model.model_materialization == Materialization.CUSTOM + assert custom_filter_model.interval == "2 day" + assert custom_filter_model.time_column == "created_at" + + # verify also that the global hooks are inherited in the model without + assert len(custom_filter_model.pre_hook) == 1 + assert len(custom_filter_model.post_hook) == 1 + + +@pytest.mark.xdist_group("dbt_manifest") +def test_custom_materialization_model_kind(): + project_path = Path("tests/fixtures/dbt/sushi_test") + context = DbtContext(project_path) + profile = Profile.load(DbtContext(project_path)) + + helper = ManifestHelper( + project_path, + project_path, + "sushi", + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + + context._target = profile.target + context._manifest = helper + models = helper.models() + + # custom materialization models get DbtCustomKind populated + custom_model = models["custom_incremental_model"] + kind = custom_model.model_kind(context) + assert isinstance(kind, DbtCustomKind) + assert kind.materialization == "custom_incremental" + assert kind.adapter == "default" + assert "create_table_as" in kind.definition + + custom_filter_model = 
models["custom_incremental_with_filter"] + kind = custom_filter_model.model_kind(context) + assert isinstance(kind, DbtCustomKind) + assert kind.materialization == "custom_incremental" + assert kind.adapter == "default" + assert "run_hooks" in kind.definition + + # the DbtCustomKind shouldnt be set for normal strategies + regular_model = models["simple_model_a"] + regular_kind = regular_model.model_kind(context) + assert not isinstance(regular_kind, DbtCustomKind) + + # verify in sqlmesh as well + sqlmesh_context = Context( + paths=["tests/fixtures/dbt/sushi_test"], + config=None, + ) + + custom_incremental = sqlmesh_context.get_model("sushi.custom_incremental_model") + assert isinstance(custom_incremental.kind, DbtCustomKind) + assert custom_incremental.kind.materialization == "custom_incremental" + + custom_with_filter = sqlmesh_context.get_model("sushi.custom_incremental_with_filter") + assert isinstance(custom_with_filter.kind, DbtCustomKind) + assert custom_with_filter.kind.materialization == "custom_incremental" + + +@pytest.mark.xdist_group("dbt_manifest") +def test_custom_materialization_dependencies(): + project_path = Path("tests/fixtures/dbt/sushi_test") + context = DbtContext(project_path) + profile = Profile.load(DbtContext(project_path)) + + helper = ManifestHelper( + project_path, + project_path, + "sushi", + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + + context._target = profile.target + context._manifest = helper + models = helper.models() + + # custom materialization uses macros that should appear in dependencies + for model_name in ["custom_incremental_model", "custom_incremental_with_filter"]: + materialization_deps = models[model_name]._get_custom_materialization(context) + assert materialization_deps is not None + assert len(materialization_deps.dependencies.macros) > 0 + macro_names = [macro.name for macro in materialization_deps.dependencies.macros] + expected_macros = [ + "build_incremental_filter_sql", + 
"Relation", + "create_table_as", + "make_temp_relation", + "run_hooks", + "statement", + ] + assert any(macro in macro_names for macro in expected_macros) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_adapter_specific_materialization_override(copy_to_temp_path: t.Callable): + path = copy_to_temp_path("tests/fixtures/dbt/sushi_test") + temp_project = path[0] + + macros_dir = temp_project / "macros" / "materializations" + macros_dir.mkdir(parents=True, exist_ok=True) + + adapter_mat_content = """ +{%- materialization custom_adapter_test, default -%} + {%- set new_relation = api.Relation.create(database=database, schema=schema, identifier=identifier) -%} + + {{ run_hooks(pre_hooks, inside_transaction=False) }} + + {%- call statement('main') -%} + CREATE TABLE {{ new_relation }} AS ( + SELECT 'default_adapter' as adapter_type, * FROM ({{ sql }}) AS subquery + ) + {%- endcall -%} + + {{ run_hooks(post_hooks, inside_transaction=False) }} + + {{ return({'relations': [new_relation]}) }} +{%- endmaterialization -%} + +{%- materialization custom_adapter_test, adapter='postgres' -%} + {%- set new_relation = api.Relation.create(database=database, schema=schema, identifier=identifier) -%} + + {{ run_hooks(pre_hooks, inside_transaction=False) }} + + {%- call statement('main') -%} + CREATE TABLE {{ new_relation }} AS ( + SELECT 'postgres_adapter'::text as adapter_type, * FROM ({{ sql }}) AS subquery + ) + {%- endcall -%} + + {{ run_hooks(post_hooks, inside_transaction=False) }} + + {{ return({'relations': [new_relation]}) }} +{%- endmaterialization -%} + +{%- materialization custom_adapter_test, adapter='duckdb' -%} + {%- set new_relation = api.Relation.create(database=database, schema=schema, identifier=identifier) -%} + + {{ run_hooks(pre_hooks, inside_transaction=False) }} + + {%- call statement('main') -%} + CREATE TABLE {{ new_relation }} AS ( + SELECT 'duckdb_adapter' as adapter_type, * FROM ({{ sql }}) AS subquery + ) + {%- endcall -%} + + {{ 
run_hooks(post_hooks, inside_transaction=False) }} + + {{ return({'relations': [new_relation]}) }} +{%- endmaterialization -%} +""".strip() + + (macros_dir / "custom_adapter_test.sql").write_text(adapter_mat_content) + + models_dir = temp_project / "models" + models_dir.mkdir(parents=True, exist_ok=True) + + test_model_content = """ +{{ config( + materialized='custom_adapter_test', +) }} + +SELECT + 1 as id, + 'test' as name +""".strip() + + (models_dir / "test_adapter_specific.sql").write_text(test_model_content) + + context = DbtContext(temp_project) + profile = Profile.load(context) + + helper = ManifestHelper( + temp_project, + temp_project, + "sushi", + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + + materializations = helper.materializations() + assert "custom_adapter_test_default" in materializations + assert "custom_adapter_test_duckdb" in materializations + assert "custom_adapter_test_postgres" in materializations + + default_mat = materializations["custom_adapter_test_default"] + assert "default_adapter" in default_mat.definition + assert default_mat.adapter == "default" + + duckdb_mat = materializations["custom_adapter_test_duckdb"] + assert "duckdb_adapter" in duckdb_mat.definition + assert duckdb_mat.adapter == "duckdb" + + postgres_mat = materializations["custom_adapter_test_postgres"] + assert "postgres_adapter" in postgres_mat.definition + assert postgres_mat.adapter == "postgres" + + # verify that the correct adapter is selected based on target + context._target = profile.target + context._manifest = helper + models = helper.models() + + test_model = models["test_adapter_specific"] + + kind = test_model.model_kind(context) + assert isinstance(kind, DbtCustomKind) + assert kind.materialization == "custom_adapter_test" + # Should use duckdb adapter since that's the default target + assert "duckdb_adapter" in kind.definition or "default_adapter" in kind.definition + + # test also that adapter-specific materializations 
execute with correct adapter + sushi_context = Context(paths=path) + + plan = sushi_context.plan(select_models=["sushi.test_adapter_specific"]) + sushi_context.apply(plan) + + # check that the table was created with the correct adapter type + result = sushi_context.engine_adapter.fetchdf("SELECT * FROM sushi.test_adapter_specific") + assert len(result) == 1 + assert "adapter_type" in result.columns + assert result["adapter_type"][0] == "duckdb_adapter" + assert result["id"][0] == 1 + assert result["name"][0] == "test" + + +@pytest.mark.xdist_group("dbt_manifest") +def test_missing_custom_materialization_error(): + from sqlmesh.utils.errors import ConfigError + + project_path = Path("tests/fixtures/dbt/sushi_test") + context = DbtContext(project_path) + profile = Profile.load(context) + + # the materialization is non-existent + fake_model_config = ModelConfig( + name="test_model", + path=project_path / "models" / "fake_model.sql", + raw_code="SELECT 1 as id", + materialized="non_existent_custom", + schema="test_schema", + ) + + context._target = profile.target + helper = ManifestHelper( + project_path, + project_path, + "sushi", + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + context._manifest = helper + + # Should raise ConfigError when trying to get the model kind + with pytest.raises(ConfigError) as e: + fake_model_config.model_kind(context) + + assert "Unknown materialization 'non_existent_custom'" in str(e.value) + assert "Custom materializations must be defined" in str(e.value) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_broken_jinja_materialization_error(copy_to_temp_path: t.Callable): + path = copy_to_temp_path("tests/fixtures/dbt/sushi_test") + temp_project = path[0] + + macros_dir = temp_project / "macros" / "materializations" + macros_dir.mkdir(parents=True, exist_ok=True) + + # Create broken Jinja materialization + broken_mat_content = """ +{%- materialization broken_jinja, default -%} + {%- set new_relation 
= api.Relation.create(database=database, schema=schema, identifier=identifier) -%} + + {{ run_hooks(pre_hooks, inside_transaction=False) }} + + {# An intentional undefined variable that will cause runtime error #} + {%- set broken_var = undefined_variable_that_does_not_exist + 10 -%} + + {%- call statement('main') -%} + CREATE TABLE {{ new_relation }} AS ( + SELECT * FROM ({{ sql }}) AS subquery + WHERE 1 = {{ broken_var }} + ) + {%- endcall -%} + + {{ run_hooks(post_hooks, inside_transaction=False) }} + + {{ return({'relations': [new_relation]}) }} +{%- endmaterialization -%} +""".strip() + + (macros_dir / "broken_jinja.sql").write_text(broken_mat_content) + + models_dir = temp_project / "models" + models_dir.mkdir(parents=True, exist_ok=True) + + test_model_content = """ +{{ config( + materialized='broken_jinja', +) }} + +SELECT + 1 as id, + 'This should fail with Jinja error' as error_msg +""".strip() + + (models_dir / "test_broken_jinja.sql").write_text(test_model_content) + + sushi_context = Context(paths=path) + + # The model will load fine jinja won't fail at parse time + model = sushi_context.get_model("sushi.test_broken_jinja") + assert isinstance(model.kind, DbtCustomKind) + assert model.kind.materialization == "broken_jinja" + + # but execution should fail + with pytest.raises(Exception) as e: + plan = sushi_context.plan(select_models=["sushi.test_broken_jinja"]) + sushi_context.apply(plan) + + assert "plan application failed" in str(e.value).lower() + + +@pytest.mark.xdist_group("dbt_manifest") +def test_failing_hooks_in_materialization(copy_to_temp_path: t.Callable): + path = copy_to_temp_path("tests/fixtures/dbt/sushi_test") + temp_project = path[0] + + models_dir = temp_project / "models" + models_dir.mkdir(parents=True, exist_ok=True) + + test_model_content = """ +{{ config( + materialized='custom_incremental', + pre_hook="CREATE TABLE will_fail_due_to_intentional_syntax_error (", + post_hook="DROP TABLE non_existent_table_that_will_fail", +) }} + 
+SELECT + 1 as id, + 'Testing hook failures' as test_msg +""".strip() + + (models_dir / "test_failing_hooks.sql").write_text(test_model_content) + + sushi_context = Context(paths=[str(temp_project)]) + + # in this case the pre_hook has invalid syntax + with pytest.raises(Exception) as e: + plan = sushi_context.plan(select_models=["sushi.test_failing_hooks"]) + sushi_context.apply(plan) + + assert "plan application failed" in str(e.value).lower() + + +@pytest.mark.xdist_group("dbt_manifest") +def test_custom_materialization_virtual_environments(copy_to_temp_path: t.Callable): + path = copy_to_temp_path("tests/fixtures/dbt/sushi_test") + temp_project = path[0] + + models_dir = temp_project / "models" + models_dir.mkdir(parents=True, exist_ok=True) + + test_model_content = """ +{{ config( + materialized='custom_incremental', + time_column='created_at', +) }} + +SELECT + CURRENT_TIMESTAMP as created_at, + 1 as id, + 'venv_test' as test_type +""".strip() + + (models_dir / "test_venv_model.sql").write_text(test_model_content) + + sushi_context = Context(paths=path) + prod_plan = sushi_context.plan(select_models=["sushi.test_venv_model"]) + sushi_context.apply(prod_plan) + prod_result = sushi_context.engine_adapter.fetchdf( + "SELECT * FROM sushi.test_venv_model ORDER BY id" + ) + assert len(prod_result) == 1 + assert prod_result["id"][0] == 1 + assert prod_result["test_type"][0] == "venv_test" + + # Create dev environment and check the dev table was created with proper naming + dev_plan = sushi_context.plan("dev", select_models=["sushi.test_venv_model"]) + sushi_context.apply(dev_plan) + dev_result = sushi_context.engine_adapter.fetchdf( + "SELECT * FROM sushi__dev.test_venv_model ORDER BY id" + ) + assert len(dev_result) == 1 + assert dev_result["id"][0] == 1 + assert dev_result["test_type"][0] == "venv_test" + + dev_tables = sushi_context.engine_adapter.fetchdf(""" + SELECT table_name, table_schema + FROM system.information_schema.tables + WHERE table_schema LIKE 
'sushi%dev%' + AND table_name LIKE '%test_venv_model%' + """) + + prod_tables = sushi_context.engine_adapter.fetchdf(""" + SELECT table_name, table_schema + FROM system.information_schema.tables + WHERE table_schema = 'sushi' + AND table_name LIKE '%test_venv_model%' + """) + + # Verify both environments have their own tables + assert len(dev_tables) >= 1 + assert len(prod_tables) >= 1 + + +@pytest.mark.xdist_group("dbt_manifest") +def test_virtual_environment_schema_names(copy_to_temp_path: t.Callable): + path = copy_to_temp_path("tests/fixtures/dbt/sushi_test") + temp_project = path[0] + + models_dir = temp_project / "models" + models_dir.mkdir(parents=True, exist_ok=True) + + test_model_content = """ +{{ config( + materialized='custom_incremental', + time_column='created_at', +) }} + +SELECT + CURRENT_TIMESTAMP as created_at, + 1 as id, + 'schema_naming_test' as test_type +""".strip() + + (models_dir / "test_schema_naming.sql").write_text(test_model_content) + + context = Context(paths=path) + prod_plan = context.plan(select_models=["sushi.test_schema_naming"]) + context.apply(prod_plan) + + dev_plan = context.plan("dev", select_models=["sushi.test_schema_naming"]) + context.apply(dev_plan) + + prod_result = context.engine_adapter.fetchdf( + "SELECT * FROM sushi.test_schema_naming ORDER BY id" + ) + assert len(prod_result) == 1 + assert prod_result["test_type"][0] == "schema_naming_test" + + dev_result = context.engine_adapter.fetchdf( + "SELECT * FROM sushi__dev.test_schema_naming ORDER BY id" + ) + assert len(dev_result) == 1 + assert dev_result["test_type"][0] == "schema_naming_test" + + # to examine the schema structure + all_schemas_query = """ + SELECT DISTINCT table_schema, COUNT(*) as table_count + FROM system.information_schema.tables + WHERE table_schema LIKE '%sushi%' + AND table_name LIKE '%test_schema_naming%' + GROUP BY table_schema + ORDER BY table_schema + """ + + schema_info = context.engine_adapter.fetchdf(all_schemas_query) + + schema_names = 
schema_info["table_schema"].tolist() + + # - virtual schemas: sushi, sushi__dev (for views) + view_schemas = [s for s in schema_names if not s.startswith("sqlmesh__")] + + # - physical schema: sqlmesh__sushi (for actual data tables) + physical_schemas = [s for s in schema_names if s.startswith("sqlmesh__")] + + # verify we got both of them + assert len(view_schemas) >= 2 + assert len(physical_schemas) >= 1 + assert "sushi" in view_schemas + assert "sushi__dev" in view_schemas + assert any("sqlmesh__sushi" in s for s in physical_schemas) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_custom_materialization_lineage_tracking(copy_to_temp_path: t.Callable): + path = copy_to_temp_path("tests/fixtures/dbt/sushi_test") + temp_project = path[0] + + models_dir = temp_project / "models" + models_dir.mkdir(parents=True, exist_ok=True) + + # create a custom materialization model that depends on simple_model_a and waiter_names seed + lineage_model_content = """ +{{ config( + materialized='custom_incremental', + time_column='created_at', +) }} + +SELECT + CURRENT_TIMESTAMP as created_at, + w.id as waiter_id, + w.name as waiter_name, + s.a as simple_value, + w.id * s.a as computed_value, + 'lineage_test' as model_type +FROM {{ ref('waiter_names') }} w +CROSS JOIN {{ ref('simple_model_a') }} s +""".strip() + + (models_dir / "enhanced_waiter_data.sql").write_text(lineage_model_content) + + # Create another custom materialization model that depends on the first one and simple_model_b + downstream_model_content = """ +{{ config( + materialized='custom_incremental', + time_column='analysis_date', +) }} + +SELECT + CURRENT_TIMESTAMP as analysis_date, + e.waiter_name, + e.simple_value, + e.computed_value, + b.a as model_b_value, + e.computed_value + b.a as final_computation, + CASE + WHEN e.computed_value >= 5 THEN 'High' + WHEN e.computed_value >= 2 THEN 'Medium' + ELSE 'Low' + END as category, + 'downstream_lineage_test' as model_type +FROM {{ ref('enhanced_waiter_data') }} e 
+CROSS JOIN {{ ref('simple_model_b') }} b +WHERE e.computed_value >= 0 +""".strip() + + (models_dir / "waiter_analytics_summary.sql").write_text(downstream_model_content) + + context = Context(paths=path) + enhanced_data_model = context.get_model("sushi.enhanced_waiter_data") + analytics_summary_model = context.get_model("sushi.waiter_analytics_summary") + + # Verify that custom materialization models have proper model kinds + assert isinstance(enhanced_data_model.kind, DbtCustomKind) + assert enhanced_data_model.kind.materialization == "custom_incremental" + + assert isinstance(analytics_summary_model.kind, DbtCustomKind) + assert analytics_summary_model.kind.materialization == "custom_incremental" + + # - enhanced_waiter_data should depend on waiter_names and simple_model_a + enhanced_data_deps = enhanced_data_model.depends_on + assert '"memory"."sushi"."simple_model_a"' in enhanced_data_deps + assert '"memory"."sushi"."waiter_names"' in enhanced_data_deps + + # - waiter_analytics_summary should depend on enhanced_waiter_data and simple_model_b + analytics_deps = analytics_summary_model.depends_on + assert '"memory"."sushi"."enhanced_waiter_data"' in analytics_deps + assert '"memory"."sushi"."simple_model_b"' in analytics_deps + + # build only the models that have dependences + plan = context.plan( + select_models=[ + "sushi.waiter_names", + "sushi.simple_model_a", + "sushi.simple_model_b", + "sushi.enhanced_waiter_data", + "sushi.waiter_analytics_summary", + ] + ) + context.apply(plan) + + # Verify that all δοwnstream models were built and contain expected data + waiter_names_result = context.engine_adapter.fetchdf( + "SELECT COUNT(*) as count FROM sushi.waiter_names" + ) + assert waiter_names_result["count"][0] > 0 + + simple_a_result = context.engine_adapter.fetchdf("SELECT a FROM sushi.simple_model_a") + assert len(simple_a_result) > 0 + assert simple_a_result["a"][0] == 1 + + simple_b_result = context.engine_adapter.fetchdf("SELECT a FROM 
sushi.simple_model_b") + assert len(simple_b_result) > 0 + assert simple_b_result["a"][0] == 1 + + # Check intermediate custom materialization model + enhanced_data_result = context.engine_adapter.fetchdf(""" + SELECT + waiter_name, + simple_value, + computed_value, + model_type + FROM sushi.enhanced_waiter_data + ORDER BY waiter_id + LIMIT 5 + """) + + assert len(enhanced_data_result) > 0 + assert enhanced_data_result["model_type"][0] == "lineage_test" + assert all(val == 1 for val in enhanced_data_result["simple_value"]) + assert all(val >= 0 for val in enhanced_data_result["computed_value"]) + assert any(val == "Ryan" for val in enhanced_data_result["waiter_name"]) + + # Check final downstream custom materialization model + analytics_summary_result = context.engine_adapter.fetchdf(""" + SELECT + waiter_name, + category, + model_type, + final_computation + FROM sushi.waiter_analytics_summary + ORDER BY waiter_name + LIMIT 5 + """) + + assert len(analytics_summary_result) > 0 + assert analytics_summary_result["model_type"][0] == "downstream_lineage_test" + assert all(cat in ["High", "Medium", "Low"] for cat in analytics_summary_result["category"]) + assert all(val >= 0 for val in analytics_summary_result["final_computation"]) + + # Test that lineage information is preserved in dev environments + dev_plan = context.plan("dev", select_models=["sushi.waiter_analytics_summary"]) + context.apply(dev_plan) + + dev_analytics_result = context.engine_adapter.fetchdf(""" + SELECT + COUNT(*) as count, + COUNT(DISTINCT waiter_name) as unique_waiters + FROM sushi__dev.waiter_analytics_summary + """) + + prod_analytics_result = context.engine_adapter.fetchdf(""" + SELECT + COUNT(*) as count, + COUNT(DISTINCT waiter_name) as unique_waiters + FROM sushi.waiter_analytics_summary + """) + + # Dev and prod should have the same data as they share physical data + assert dev_analytics_result["count"][0] == prod_analytics_result["count"][0] + assert 
dev_analytics_result["unique_waiters"][0] == prod_analytics_result["unique_waiters"][0] + + +@pytest.mark.xdist_group("dbt_manifest") +def test_custom_materialization_grants(copy_to_temp_path: t.Callable, mocker): + path = copy_to_temp_path("tests/fixtures/dbt/sushi_test") + temp_project = path[0] + + models_dir = temp_project / "models" + models_dir.mkdir(parents=True, exist_ok=True) + + grants_model_content = """ +{{ config( + materialized='custom_incremental', + grants={ + 'select': ['user1', 'user2'], + 'insert': ['writer'] + } +) }} + +SELECT + CURRENT_TIMESTAMP as created_at, + 1 as id, + 'grants_test' as test_type +""".strip() + + (models_dir / "test_grants_model.sql").write_text(grants_model_content) + + mocker.patch.object(DuckDBEngineAdapter, "SUPPORTS_GRANTS", True) + mocker.patch.object(DuckDBEngineAdapter, "_get_current_grants_config", return_value={}) + + sync_grants_calls = [] + + def mock_sync_grants(*args, **kwargs): + sync_grants_calls.append((args, kwargs)) + + mocker.patch.object(DuckDBEngineAdapter, "sync_grants_config", side_effect=mock_sync_grants) + + context = Context(paths=path) + + model = context.get_model("sushi.test_grants_model") + assert isinstance(model.kind, DbtCustomKind) + plan = context.plan(select_models=["sushi.test_grants_model"]) + context.apply(plan) + + assert len(sync_grants_calls) == 1 + args = sync_grants_calls[0][0] + assert args + + table = args[0] + grants_config = args[1] + assert table.sql(dialect="duckdb") == "memory.sushi.test_grants_model" + assert grants_config == { + "select": ["user1", "user2"], + "insert": ["writer"], + } diff --git a/tests/dbt/test_docs.py b/tests/dbt/test_docs.py new file mode 100644 index 0000000000..7c21edb970 --- /dev/null +++ b/tests/dbt/test_docs.py @@ -0,0 +1,28 @@ +from pathlib import Path +import pytest + +from sqlmesh.core.config.model import ModelDefaultsConfig +from sqlmesh.dbt.context import DbtContext +from sqlmesh.dbt.manifest import ManifestHelper +from sqlmesh.dbt.profile 
import Profile + + +pytestmark = pytest.mark.dbt + + +@pytest.mark.xdist_group("dbt_manifest") +def test_docs_inline(): + project_path = Path("tests/fixtures/dbt/sushi_test") + profile = Profile.load(DbtContext(project_path)) + + helper = ManifestHelper( + project_path, + project_path, + "sushi", + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + # Inline description in yaml + assert helper.models()["waiters"].description == "waiters docs block" + # Docs block from .md file + assert helper.models()["top_waiters"].description == "description of top waiters" diff --git a/tests/dbt/test_integration.py b/tests/dbt/test_integration.py index a7838897f5..ab22bf7826 100644 --- a/tests/dbt/test_integration.py +++ b/tests/dbt/test_integration.py @@ -5,16 +5,22 @@ from functools import partial from pathlib import Path -import pandas as pd +import pandas as pd # noqa: TID253 import pytest -from dbt.cli.main import dbtRunner -from freezegun import freeze_time + +from sqlmesh.dbt.util import DBT_VERSION + +if DBT_VERSION >= (1, 5, 0): + from dbt.cli.main import dbtRunner # type: ignore + +import time_machine from sqlmesh import Context from sqlmesh.core.config.connection import DuckDBConnectionConfig from sqlmesh.core.engine_adapter import DuckDBEngineAdapter from sqlmesh.utils.pandas import columns_to_types_from_df -from sqlmesh.utils.yaml import YAML +from sqlmesh.utils.yaml import YAML, load as yaml_load, dump as yaml_dump +from sqlmesh_dbt.operations import init_project_if_required from tests.utils.pandas import compare_dataframes, create_df # Some developers had issues with this test freezing locally so we mark it as cicdonly @@ -22,6 +28,8 @@ class TestType(str, Enum): + __test__ = False # prevent pytest trying to collect this as a test class + DBT_RUNTIME = "dbt_runtime" DBT_ADAPTER = "dbt_adapter" SQLMESH = "sqlmesh" @@ -48,6 +56,8 @@ def is_sqlmesh_runtime(self) -> bool: class TestStrategy(str, Enum): + __test__ = False # prevent pytest 
trying to collect this as a test class + CHECK = "check" TIMESTAMP = "timestamp" @@ -137,7 +147,6 @@ def _make_function( f.write( """from pathlib import Path -from sqlmesh.core.config import AirflowSchedulerConfig from sqlmesh.dbt.loader import sqlmesh_config config = sqlmesh_config(Path(__file__).parent) @@ -195,9 +204,11 @@ def _replace_source_table( columns_to_types = columns_to_types_from_df(df) if values: - adapter.replace_query("sushi.raw_marketing", df, columns_to_types=columns_to_types) + adapter.replace_query( + "sushi.raw_marketing", df, target_columns_to_types=columns_to_types + ) else: - adapter.create_table("sushi.raw_marketing", columns_to_types=columns_to_types) + adapter.create_table("sushi.raw_marketing", target_columns_to_types=columns_to_types) def _normalize_dbt_dataframe( self, @@ -242,7 +253,7 @@ def _get_current_df( return df def _get_duckdb_now(self, adapter: DuckDBEngineAdapter) -> datetime.datetime: - return adapter.fetchone("SELECT now()")[0] + return adapter.fetchone("SELECT now()")[0] # type: ignore def _init_test( self, @@ -289,7 +300,7 @@ def _init_test( adapter.create_schema("sushi") if test_type.is_sqlmesh_runtime: self._replace_source_table(adapter, []) - with freeze_time("2019-12-31 00:00:00"): + with time_machine.travel("2019-12-31 00:00:00 UTC"): context.plan("prod", auto_apply=True, no_prompts=True) # type: ignore return run, adapter, context @@ -302,6 +313,9 @@ def test_scd_type_2_by_time( test_type: TestType, invalidate_hard_deletes: bool, ): + if test_type.is_dbt_runtime and DBT_VERSION < (1, 5, 0): + pytest.skip("The dbt version being tested doesn't support the dbtRunner so skipping.") + run, adapter, context = self._init_test( create_scd_type_2_dbt_project, create_scd_type_2_sqlmesh_project, @@ -316,7 +330,7 @@ def test_scd_type_2_by_time( t.List[t.Tuple[int, str, str]], t.List[t.Tuple[int, str, str, str, t.Optional[str]]] ], ] = { - "2020-01-01 00:00:00": ( + "2020-01-01 00:00:00 UTC": ( [ (1, "a", "2020-01-01 00:00:00"), 
(2, "b", "2020-01-01 00:00:00"), @@ -328,7 +342,7 @@ def test_scd_type_2_by_time( (3, "c", "2020-01-01 00:00:00", "2020-01-01 00:00:00", None), ], ), - "2020-01-02 00:00:00": ( + "2020-01-02 00:00:00 UTC": ( [ # Update to "x" (1, "x", "2020-01-02 00:00:00"), @@ -353,7 +367,7 @@ def test_scd_type_2_by_time( (4, "d", "2020-01-02 00:00:00", "2020-01-02 00:00:00", None), ], ), - "2020-01-04 00:00:00": ( + "2020-01-04 00:00:00 UTC": ( [ # Update to "y" (1, "y", "2020-01-03 00:00:00"), @@ -399,7 +413,8 @@ def test_scd_type_2_by_time( time_start_end_mapping = {} for time, (starting_source_data, expected_table_data) in time_expected_mapping.items(): self._replace_source_table(adapter, starting_source_data) - with freeze_time(time): + # Tick when running dbt runtime because it hangs during execution for unknown reasons. + with time_machine.travel(time, tick=test_type.is_dbt_runtime): start_time = self._get_duckdb_now(adapter) run() end_time = self._get_duckdb_now(adapter) @@ -437,7 +452,7 @@ def test_scd_type_2_by_column( t.List[t.Tuple[int, str, str]], t.List[t.Tuple[int, str, str, str, t.Optional[str]]] ], ] = { - "2020-01-01 00:00:00": ( + "2020-01-01 00:00:00 UTC": ( [ (1, "a", "2020-01-01 00:00:00"), (2, "b", "2020-01-01 00:00:00"), @@ -449,7 +464,7 @@ def test_scd_type_2_by_column( (3, "c", "2020-01-01 00:00:00", "2020-01-01 00:00:00", None), ], ), - "2020-01-02 00:00:00": ( + "2020-01-02 00:00:00 UTC": ( [ # Update to "x" (1, "x", "2020-01-02 00:00:00"), @@ -474,7 +489,7 @@ def test_scd_type_2_by_column( (4, "d", "2020-01-02 00:00:00", "2020-01-02 00:00:00", None), ], ), - "2020-01-04 00:00:00": ( + "2020-01-04 00:00:00 UTC": ( [ # Update to "y" (1, "y", "2020-01-03 00:00:00"), @@ -516,7 +531,7 @@ def test_scd_type_2_by_column( time_start_end_mapping = {} for time, (starting_source_data, expected_table_data) in time_expected_mapping.items(): self._replace_source_table(adapter, starting_source_data) - with freeze_time(time): + with time_machine.travel(time, 
tick=False): start_time = self._get_duckdb_now(adapter) run() end_time = self._get_duckdb_now(adapter) @@ -526,3 +541,114 @@ def test_scd_type_2_by_column( ) df_expected = create_df(expected_table_data, self.target_schema) compare_dataframes(df_actual, df_expected, msg=f"Failed on time {time}") + + +def test_dbt_node_info(jaffle_shop_duckdb_context: Context): + ctx = jaffle_shop_duckdb_context + + customers = ctx.models['"jaffle_shop"."main"."customers"'] + assert customers.dbt_unique_id == "model.jaffle_shop.customers" + assert customers.dbt_fqn == "jaffle_shop.customers" + assert customers.dbt_node_info + assert customers.dbt_node_info.name == "customers" + + orders = ctx.models['"jaffle_shop"."main"."orders"'] + assert orders.dbt_unique_id == "model.jaffle_shop.orders" + assert orders.dbt_fqn == "jaffle_shop.orders" + assert orders.dbt_node_info + assert orders.dbt_node_info.name == "orders" + + stg_customers = ctx.models['"jaffle_shop"."main"."stg_customers"'] + assert stg_customers.dbt_unique_id == "model.jaffle_shop.stg_customers" + assert stg_customers.dbt_fqn == "jaffle_shop.staging.stg_customers" + assert stg_customers.dbt_node_info + assert stg_customers.dbt_node_info.name == "stg_customers" + + stg_orders = ctx.models['"jaffle_shop"."main"."stg_orders"'] + assert stg_orders.dbt_unique_id == "model.jaffle_shop.stg_orders" + assert stg_orders.dbt_fqn == "jaffle_shop.staging.stg_orders" + assert stg_orders.dbt_node_info + assert stg_orders.dbt_node_info.name == "stg_orders" + + raw_customers = ctx.models['"jaffle_shop"."main"."raw_customers"'] + assert raw_customers.dbt_unique_id == "seed.jaffle_shop.raw_customers" + assert raw_customers.dbt_fqn == "jaffle_shop.raw_customers" + assert raw_customers.dbt_node_info + assert raw_customers.dbt_node_info.name == "raw_customers" + + raw_orders = ctx.models['"jaffle_shop"."main"."raw_orders"'] + assert raw_orders.dbt_unique_id == "seed.jaffle_shop.raw_orders" + assert raw_orders.dbt_fqn == "jaffle_shop.raw_orders" 
+ assert raw_orders.dbt_node_info + assert raw_orders.dbt_node_info.name == "raw_orders" + + raw_payments = ctx.models['"jaffle_shop"."main"."raw_payments"'] + assert raw_payments.dbt_unique_id == "seed.jaffle_shop.raw_payments" + assert raw_payments.dbt_fqn == "jaffle_shop.raw_payments" + assert raw_payments.dbt_node_info + assert raw_payments.dbt_node_info.name == "raw_payments" + + relationship_audit = ctx.snapshots[ + "relationships_orders_customer_id__customer_id__ref_customers_" + ] + assert relationship_audit.node.is_audit + assert ( + relationship_audit.node.dbt_unique_id + == "test.jaffle_shop.relationships_orders_customer_id__customer_id__ref_customers_.c6ec7f58f2" + ) + assert ( + relationship_audit.node.dbt_fqn + == "jaffle_shop.relationships_orders_customer_id__customer_id__ref_customers_" + ) + assert relationship_audit.node.dbt_node_info + assert ( + relationship_audit.node.dbt_node_info.name + == "relationships_orders_customer_id__customer_id__ref_customers_" + ) + + +def test_state_schema_isolation_per_target(jaffle_shop_duckdb: Path): + profiles_file = jaffle_shop_duckdb / "profiles.yml" + + profiles_yml = yaml_load(profiles_file) + + # make prod / dev config identical with the exception of a different default schema to simulate using the same warehouse + profiles_yml["jaffle_shop"]["outputs"]["prod"] = { + **profiles_yml["jaffle_shop"]["outputs"]["dev"] + } + profiles_yml["jaffle_shop"]["outputs"]["prod"]["schema"] = "prod_schema" + profiles_yml["jaffle_shop"]["outputs"]["dev"]["schema"] = "dev_schema" + + profiles_file.write_text(yaml_dump(profiles_yml)) + + init_project_if_required(jaffle_shop_duckdb) + + # start off with the prod target + prod_ctx = Context(paths=[jaffle_shop_duckdb], config_loader_kwargs={"target": "prod"}) + assert prod_ctx.config.get_state_schema() == "sqlmesh_state_jaffle_shop_prod_schema" + assert all("prod_schema" in fqn for fqn in prod_ctx.models) + assert prod_ctx.plan(auto_apply=True).has_changes + assert not 
prod_ctx.plan(auto_apply=True).has_changes + + # dev target should have changes - new state separate from prod + dev_ctx = Context(paths=[jaffle_shop_duckdb], config_loader_kwargs={"target": "dev"}) + assert dev_ctx.config.get_state_schema() == "sqlmesh_state_jaffle_shop_dev_schema" + assert all("dev_schema" in fqn for fqn in dev_ctx.models) + assert dev_ctx.plan(auto_apply=True).has_changes + assert not dev_ctx.plan(auto_apply=True).has_changes + + # no explicitly specified target should use dev because that's what's set for the default in the profiles.yml + assert profiles_yml["jaffle_shop"]["target"] == "dev" + default_ctx = Context(paths=[jaffle_shop_duckdb]) + assert default_ctx.config.get_state_schema() == "sqlmesh_state_jaffle_shop_dev_schema" + assert all("dev_schema" in fqn for fqn in default_ctx.models) + assert not default_ctx.plan(auto_apply=True).has_changes + + # an explicit state schema override set in `sqlmesh.yaml` should use that + sqlmesh_yaml_file = jaffle_shop_duckdb / "sqlmesh.yaml" + sqlmesh_yaml = yaml_load(sqlmesh_yaml_file) + sqlmesh_yaml["gateways"] = {"dev": {"state_schema": "sqlmesh_dev_state_override"}} + sqlmesh_yaml_file.write_text(yaml_dump(sqlmesh_yaml)) + default_ctx = Context(paths=[jaffle_shop_duckdb]) + assert default_ctx.config.get_state_schema() == "sqlmesh_dev_state_override" + assert all("dev_schema" in fqn for fqn in default_ctx.models) diff --git a/tests/dbt/test_manifest.py b/tests/dbt/test_manifest.py index 59b4d0ffc8..2ecf8b8980 100644 --- a/tests/dbt/test_manifest.py +++ b/tests/dbt/test_manifest.py @@ -4,10 +4,14 @@ import pytest +from sqlmesh.core.config import ModelDefaultsConfig from sqlmesh.dbt.basemodel import Dependencies +from sqlmesh.dbt.common import ModelAttrs from sqlmesh.dbt.context import DbtContext -from sqlmesh.dbt.manifest import ManifestHelper +from sqlmesh.dbt.manifest import ManifestHelper, _convert_jinja_test_to_macro from sqlmesh.dbt.profile import Profile +from sqlmesh.dbt.builtin import Api, 
_relation_info_to_relation +from sqlmesh.dbt.util import DBT_VERSION from sqlmesh.utils.jinja import MacroReference pytestmark = pytest.mark.dbt @@ -22,7 +26,7 @@ def test_manifest_helper(caplog): project_path, "sushi", profile.target, - variable_overrides={"start": "2020-01-01"}, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), ) models = helper.models() @@ -30,7 +34,12 @@ def test_manifest_helper(caplog): assert models["top_waiters"].dependencies == Dependencies( refs={"sushi.waiter_revenue_by_day", "waiter_revenue_by_day"}, variables={"top_waiters:revenue", "top_waiters:limit"}, - macros=[MacroReference(name="ref"), MacroReference(name="var")], + model_attrs=ModelAttrs(attrs={"columns", "config"}), + macros=[ + MacroReference(name="get_top_waiters_limit"), + MacroReference(name="ref"), + MacroReference(name="var"), + ], ) assert models["top_waiters"].materialized == "view" assert models["top_waiters"].dialect_ == "postgres" @@ -59,9 +68,12 @@ def test_manifest_helper(caplog): assert models["items_no_hard_delete_snapshot"].invalidate_hard_deletes is False # Test versioned models - assert models["waiter_revenue_by_day_v1"].version == 1 - assert models["waiter_revenue_by_day_v2"].version == 2 - assert "waiter_revenue_by_day" not in models + if DBT_VERSION >= (1, 5, 0): + assert models["waiter_revenue_by_day_v1"].version == 1 + assert models["waiter_revenue_by_day_v2"].version == 2 + assert "waiter_revenue_by_day" not in models + else: + assert "waiter_revenue_by_day" in models waiter_as_customer_by_day_config = models["waiter_as_customer_by_day"] assert waiter_as_customer_by_day_config.dependencies == Dependencies( @@ -69,13 +81,17 @@ def test_manifest_helper(caplog): macros=[MacroReference(name="ref")], ) assert waiter_as_customer_by_day_config.materialized == "incremental" - assert waiter_as_customer_by_day_config.incremental_strategy == "delete+insert" + assert waiter_as_customer_by_day_config.incremental_strategy == "incremental_by_time_range" assert 
waiter_as_customer_by_day_config.cluster_by == ["ds"] assert waiter_as_customer_by_day_config.time_column == "ds" - waiter_revenue_by_day_config = models["waiter_revenue_by_day_v2"] + if DBT_VERSION >= (1, 5, 0): + waiter_revenue_by_day_config = models["waiter_revenue_by_day_v2"] + else: + waiter_revenue_by_day_config = models["waiter_revenue_by_day"] assert waiter_revenue_by_day_config.dependencies == Dependencies( macros={ + MacroReference(name="dynamic_var_name_dependency"), MacroReference(name="log_value"), MacroReference(name="test_dependencies"), MacroReference(package="customers", name="duckdb__current_engine"), @@ -83,10 +99,11 @@ def test_manifest_helper(caplog): MacroReference(name="source"), }, sources={"streaming.items", "streaming.orders", "streaming.order_items"}, - variables={"yet_another_var"}, + variables={"yet_another_var", "nested_vars"}, + has_dynamic_var_names=True, ) assert waiter_revenue_by_day_config.materialized == "incremental" - assert waiter_revenue_by_day_config.incremental_strategy == "delete+insert" + assert waiter_revenue_by_day_config.incremental_strategy == "incremental_by_time_range" assert waiter_revenue_by_day_config.cluster_by == ["ds"] assert waiter_revenue_by_day_config.time_column == "ds" assert waiter_revenue_by_day_config.dialect_ == "bigquery" @@ -112,6 +129,14 @@ def test_manifest_helper(caplog): assert sources["streaming.order_items"].table_name == "order_items" assert sources["streaming.order_items"].schema_ == "raw" + assert all(s.quoting["identifier"] is False for s in sources.values()) + + assert sources["streaming.order_items"].freshness == { + "warn_after": {"count": 10 if DBT_VERSION < (1, 9, 5) else 12, "period": "hour"}, + "error_after": {"count": 11 if DBT_VERSION < (1, 9, 5) else 13, "period": "hour"}, + "filter": None, + } + @pytest.mark.xdist_group("dbt_manifest") def test_tests_referencing_disabled_models(): @@ -122,13 +147,33 @@ def test_tests_referencing_disabled_models(): project_path, "sushi", 
profile.target, - variable_overrides={"start": "2020-01-01"}, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), ) assert "disabled_model" not in helper.models() assert "not_null_disabled_model_one" not in helper.tests() +@pytest.mark.xdist_group("dbt_manifest") +def test_call_cache(): + project_path = Path("tests/fixtures/dbt/sushi_test") + profile = Profile.load(DbtContext(project_path)) + helper = ManifestHelper( + project_path, + project_path, + "sushi", + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + + unused = "0000" + helper._call_cache.put("", value={unused: "unused"}) + helper._load_all() + calls = set(helper._call_cache.get("").keys()) + assert len(calls) >= 300 + assert unused not in calls + + @pytest.mark.xdist_group("dbt_manifest") def test_variable_override(): project_path = Path("tests/fixtures/dbt/sushi_test") @@ -139,15 +184,186 @@ def test_variable_override(): project_path, "sushi", profile.target, - variable_overrides={"start": "2020-01-01"}, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), ) - assert helper.models()["top_waiters"].limit_value == 10 + assert helper.models()["top_waiters"].limit_value.strip() == "10" helper = ManifestHelper( project_path, project_path, "sushi", profile.target, - variable_overrides={"top_waiters:limit": 1, "start": "2020-01-01"}, + variable_overrides={"top_waiters:limit": 1}, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), ) assert helper.models()["top_waiters"].limit_value == 1 + + +@pytest.mark.xdist_group("dbt_manifest") +def test_source_meta_external_location(): + project_path = Path("tests/fixtures/dbt/sushi_test") + profile = Profile.load(DbtContext(project_path)) + + helper = ManifestHelper( + project_path, + project_path, + "sushi", + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + + sources = helper.sources() + parquet_orders = sources["parquet_file.orders"] + assert parquet_orders.source_meta == { + 
"external_location": "read_parquet('path/to/external/{name}.parquet')" + } + assert ( + parquet_orders.relation_info.external == "read_parquet('path/to/external/orders.parquet')" + ) + + api = Api("duckdb") + relation_info = sources["parquet_file.items"].relation_info + assert relation_info.external == "read_parquet('path/to/external/items.parquet')" + + relation = _relation_info_to_relation( + sources["parquet_file.items"].relation_info, api.Relation, api.quote_policy + ) + assert relation.identifier == "items" + expected = ( + "read_parquet('path/to/external/items.parquet')" + if DBT_VERSION >= (1, 4, 0) + else '"memory"."parquet_file".items' + ) + assert relation.render() == expected + + +@pytest.mark.xdist_group("dbt_manifest") +def test_top_level_dbt_adapter_macros(): + project_path = Path("tests/fixtures/dbt/sushi_test") + profile = Profile.load(DbtContext(project_path)) + + helper = ManifestHelper( + project_path, + project_path, + "sushi", + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + + # Adapter macros must be marked as top-level + dbt_macros = helper.macros("dbt") + dbt_duckdb_macros = helper.macros("dbt_duckdb") + assert dbt_macros["default__dateadd"].info.is_top_level + assert dbt_macros["default__datediff"].info.is_top_level + assert dbt_duckdb_macros["duckdb__datediff"].info.is_top_level + assert dbt_duckdb_macros["duckdb__dateadd"].info.is_top_level + + # Project dispatch macros should not be marked as top-level + customers_macros = helper.macros("customers") + assert not customers_macros["default__current_engine"].info.is_top_level + assert not customers_macros["duckdb__current_engine"].info.is_top_level + + +def test_convert_jinja_test_to_macro(): + # Test block with whitespace trimming + test_input = """{%- test assert_positive(model, column_name) -%} + select * from {{ model }} where {{ column_name }} <= 0 +{%- endtest -%}""" + + expected_output = """{%- macro test_assert_positive(model, column_name) -%} + 
select * from {{ model }} where {{ column_name }} <= 0 +{%- endmacro -%}""" + + assert _convert_jinja_test_to_macro(test_input) == expected_output + + # Test block without whitespace trimming + test_input_no_ws = """{% test assert_positive(model, column_name) %} + select * from {{ model }} where {{ column_name }} <= 0 +{% endtest %}""" + + expected_output_no_ws = """{% macro test_assert_positive(model, column_name) %} + select * from {{ model }} where {{ column_name }} <= 0 +{% endmacro %}""" + + assert _convert_jinja_test_to_macro(test_input_no_ws) == expected_output_no_ws + + # Test block with mixed whitespace trimming + test_input_mixed = """{%- test complex_test(model, column_name='id') %} + select count(*) from {{ model }} where {{ column_name }} is null +{% endtest -%}""" + + expected_output_mixed = """{%- macro test_complex_test(model, column_name='id') %} + select count(*) from {{ model }} where {{ column_name }} is null +{% endmacro -%}""" + + assert _convert_jinja_test_to_macro(test_input_mixed) == expected_output_mixed + + # Test already converted macro (should return unchanged) + macro_input = """{%- macro test_already_converted(model) -%} + select * from {{ model }} +{%- endmacro -%}""" + + assert _convert_jinja_test_to_macro(macro_input) == macro_input + + +@pytest.mark.xdist_group("dbt_manifest") +def test_macro_depenency_none_str(): + project_path = Path("tests/fixtures/dbt/sushi_test") + profile = Profile.load(DbtContext(project_path)) + helper = ManifestHelper( + project_path, + project_path, + "sushi", + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + node = helper._manifest.nodes["model.customers.customer_revenue_by_day"] + node.depends_on.macros.append("None") + + from sqlmesh.dbt.manifest import _macro_references + + # "None" macro shouldn't raise a KeyError + _macro_references(helper._manifest, node) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_macro_assignment_shadowing(create_empty_project): + 
project_name = "local" + project_path, models_path = create_empty_project(project_name=project_name) + + macros_path = project_path / "macros" + macros_path.mkdir() + + (macros_path / "model_path_macro.sql").write_text(""" +{% macro model_path_macro() %} + {% if execute %} + {% set model = model.path.split('/')[-1].replace('.sql', '') %} + SELECT '{{ model }}' as model_name + {% else %} + SELECT 'placeholder' as placeholder + {% endif %} +{% endmacro %} +""") + + (models_path / "model_using_path_macro.sql").write_text(""" +{{ model_path_macro() }} +""") + + context = DbtContext(project_path) + profile = Profile.load(context) + + helper = ManifestHelper( + project_path, + project_path, + project_name, + profile.target, + model_defaults=ModelDefaultsConfig(start="2020-01-01"), + ) + + macros = helper.macros(project_name) + assert "model_path_macro" in macros + assert "path" in macros["model_path_macro"].dependencies.model_attrs.attrs + + models = helper.models() + assert "model_using_path_macro" in models + assert "path" in models["model_using_path_macro"].dependencies.model_attrs.attrs diff --git a/tests/dbt/test_model.py b/tests/dbt/test_model.py index cf88872fc7..a954f98f41 100644 --- a/tests/dbt/test_model.py +++ b/tests/dbt/test_model.py @@ -1,46 +1,1116 @@ +import datetime +import logging + import pytest +from pathlib import Path + +from sqlglot import exp +from sqlglot.errors import SchemaError +from sqlmesh import Context +from sqlmesh.core.console import NoopConsole, get_console +from sqlmesh.core.model import TimeColumn, IncrementalByTimeRangeKind +from sqlmesh.core.model.kind import OnDestructiveChange, OnAdditiveChange, SCDType2ByColumnKind +from sqlmesh.core.state_sync.db.snapshot import _snapshot_to_json +from sqlmesh.core.config.common import VirtualEnvironmentMode +from sqlmesh.core.model.meta import GrantsTargetLayer from sqlmesh.dbt.common import Dependencies from sqlmesh.dbt.context import DbtContext from sqlmesh.dbt.model import ModelConfig +from 
sqlmesh.dbt.target import BigQueryConfig, DuckDbConfig, PostgresConfig from sqlmesh.dbt.test import TestConfig -from sqlmesh.utils.errors import ConfigError +from sqlmesh.utils.yaml import YAML +from sqlmesh.utils.date import to_ds +import typing as t pytestmark = pytest.mark.dbt -def test_model_test_circular_references() -> None: - upstream_model = ModelConfig(name="upstream") - downstream_model = ModelConfig(name="downstream", dependencies=Dependencies(refs={"upstream"})) - context = DbtContext(_refs={"upstream": upstream_model, "downstream": downstream_model}) - - # Test and downstream model references - downstream_test = TestConfig( - name="downstream_with_upstream", - sql="", - dependencies=Dependencies(refs={"upstream", "downstream"}), - ) - upstream_test = TestConfig( - name="upstream_with_downstream", - sql="", - dependencies=Dependencies(refs={"upstream", "downstream"}), - ) - downstream_model.tests = [downstream_test] - downstream_model.check_for_circular_test_refs(context) - - downstream_model.tests = [] - upstream_model.tests = [upstream_test] - with pytest.raises(ConfigError, match="downstream model"): - upstream_model.check_for_circular_test_refs(context) - - downstream_model.tests = [downstream_test] - with pytest.raises(ConfigError, match="downstream model"): - upstream_model.check_for_circular_test_refs(context) - downstream_model.check_for_circular_test_refs(context) - - # Test only references - downstream_model.dependencies = Dependencies() - with pytest.raises(ConfigError, match="between tests"): - upstream_model.check_for_circular_test_refs(context) - with pytest.raises(ConfigError, match="between tests"): - downstream_model.check_for_circular_test_refs(context) +def test_test_config_is_standalone_behavior() -> None: + """Test that TestConfig.is_standalone correctly identifies tests with cross-model references""" + + # Test with no model_name (should be standalone) + standalone_test = TestConfig( + name="standalone_test", + sql="SELECT 1", + 
model_name=None, + dependencies=Dependencies(refs={"some_model"}), + ) + assert standalone_test.is_standalone is True + + # Test with only self-reference (should not be standalone) + self_ref_test = TestConfig( + name="self_ref_test", + sql="SELECT * FROM {{ this }}", + model_name="my_model", + dependencies=Dependencies(refs={"my_model"}), + ) + assert self_ref_test.is_standalone is False + + # Test with no references (should not be standalone) + no_ref_test = TestConfig( + name="no_ref_test", + sql="SELECT 1", + model_name="my_model", + dependencies=Dependencies(), + ) + assert no_ref_test.is_standalone is False + + # Test with references to other models (should be standalone) + cross_ref_test = TestConfig( + name="cross_ref_test", + sql="SELECT * FROM {{ ref('other_model') }}", + model_name="my_model", + dependencies=Dependencies(refs={"my_model", "other_model"}), + ) + assert cross_ref_test.is_standalone is True + + # Test with only references to other models, no self-reference (should be standalone) + other_only_test = TestConfig( + name="other_only_test", + sql="SELECT * FROM {{ ref('other_model') }}", + model_name="my_model", + dependencies=Dependencies(refs={"other_model"}), + ) + assert other_only_test.is_standalone is True + + +def test_test_to_sqlmesh_creates_correct_audit_type( + dbt_dummy_postgres_config: PostgresConfig, +) -> None: + """Test that TestConfig.to_sqlmesh creates the correct audit type based on is_standalone""" + from sqlmesh.core.audit.definition import StandaloneAudit, ModelAudit + + # Set up models in context + my_model = ModelConfig( + name="my_model", sql="SELECT 1", schema="test_schema", database="test_db", alias="my_model" + ) + other_model = ModelConfig( + name="other_model", + sql="SELECT 2", + schema="test_schema", + database="test_db", + alias="other_model", + ) + context = DbtContext( + _refs={"my_model": my_model, "other_model": other_model}, + _target=dbt_dummy_postgres_config, + ) + + # Test with only self-reference (should 
create ModelAudit) + self_ref_test = TestConfig( + name="self_ref_test", + sql="SELECT * FROM {{ this }}", + model_name="my_model", + dependencies=Dependencies(refs={"my_model"}), + ) + audit = self_ref_test.to_sqlmesh(context) + assert isinstance(audit, ModelAudit) + assert audit.name == "self_ref_test" + + # Test with references to other models (should create StandaloneAudit) + cross_ref_test = TestConfig( + name="cross_ref_test", + sql="SELECT * FROM {{ ref('other_model') }}", + model_name="my_model", + dependencies=Dependencies(refs={"my_model", "other_model"}), + ) + audit = cross_ref_test.to_sqlmesh(context) + assert isinstance(audit, StandaloneAudit) + assert audit.name == "cross_ref_test" + + # Test with no model_name (should create StandaloneAudit) + standalone_test = TestConfig( + name="standalone_test", + sql="SELECT 1", + model_name=None, + dependencies=Dependencies(), + ) + audit = standalone_test.to_sqlmesh(context) + assert isinstance(audit, StandaloneAudit) + assert audit.name == "standalone_test" + + +@pytest.mark.slow +def test_manifest_filters_standalone_tests_from_models( + tmp_path: Path, create_empty_project +) -> None: + """Integration test that verifies models only contain non-standalone tests after manifest loading.""" + yaml = YAML() + project_dir, model_dir = create_empty_project(project_name="local") + + # Create two models + model1_contents = "SELECT 1 as id" + model1_file = model_dir / "model1.sql" + with open(model1_file, "w", encoding="utf-8") as f: + f.write(model1_contents) + + model2_contents = "SELECT 2 as id" + model2_file = model_dir / "model2.sql" + with open(model2_file, "w", encoding="utf-8") as f: + f.write(model2_contents) + + # Create schema with both standalone and non-standalone tests + schema_yaml = { + "version": 2, + "models": [ + { + "name": "model1", + "columns": [ + { + "name": "id", + "tests": [ + "not_null", # Non-standalone test - only references model1 + { + "relationships": { # Standalone test - references 
model2 + "to": "ref('model2')", + "field": "id", + } + }, + ], + } + ], + }, + { + "name": "model2", + "columns": [ + {"name": "id", "tests": ["not_null"]} # Non-standalone test + ], + }, + ], + } + + schema_file = model_dir / "schema.yml" + with open(schema_file, "w", encoding="utf-8") as f: + yaml.dump(schema_yaml, f) + + # Load the project through SQLMesh Context + from sqlmesh.core.context import Context + + context = Context(paths=project_dir) + + model1_snapshot = context.snapshots['"local"."main"."model1"'] + model2_snapshot = context.snapshots['"local"."main"."model2"'] + + # Verify model1 only has non-standalone test in its audits + # Should only have "not_null" test, not the "relationships" test + model1_audit_names = [audit[0] for audit in model1_snapshot.model.audits] + assert len(model1_audit_names) == 1 + assert model1_audit_names[0] == "local.not_null_model1_id" + + # Verify model2 has its non-standalone test + model2_audit_names = [audit[0] for audit in model2_snapshot.model.audits] + assert len(model2_audit_names) == 1 + assert model2_audit_names[0] == "local.not_null_model2_id" + + # Verify the standalone test (relationships) exists as a StandaloneAudit + all_non_standalone_audits = [name for name in context._audits] + assert sorted(all_non_standalone_audits) == [ + "local.not_null_model1_id", + "local.not_null_model2_id", + ] + + standalone_audits = [name for name in context._standalone_audits] + assert len(standalone_audits) == 1 + assert standalone_audits[0] == "local.relationships_model1_id__id__ref_model2_" + + plan_builder = context.plan_builder() + dag = plan_builder._build_dag() + assert [x.name for x in dag.sorted] == [ + '"local"."main"."model1"', + '"local"."main"."model2"', + "relationships_model1_id__id__ref_model2_", + ] + + +@pytest.mark.slow +def test_load_invalid_ref_audit_constraints( + tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project +) -> None: + yaml = YAML() + project_dir, model_dir = 
create_empty_project(project_name="local") + # add `tests` to model config since this is loaded by dbt and ignored and we shouldn't error when loading it + full_model_contents = """{{ config(tags=["blah"], tests=[{"blah": {"to": "ref('completely_ignored')", "field": "blah2"} }]) }} SELECT 1 as cola""" + full_model_file = model_dir / "full_model.sql" + with open(full_model_file, "w", encoding="utf-8") as f: + f.write(full_model_contents) + model_schema = { + "version": 2, + "models": [ + { + "name": "full_model", + "description": "A full model bad ref for audit and constraints", + "columns": [ + { + "name": "cola", + "description": "A column that is used in a ref audit and constraints", + "constraints": [ + { + "type": "primary_key", + "columns": ["cola"], + "expression": "ref('not_real_model') (cola)", + } + ], + "tests": [ + { + # References a model that doesn't exist + "relationships": { + "to": "ref('not_real_model')", + "field": "cola", + }, + }, + { + # Reference a source that doesn't exist + "relationships": { + "to": "source('not_real_source', 'not_real_table')", + "field": "cola", + }, + }, + ], + } + ], + } + ], + } + model_schema_file = model_dir / "schema.yml" + with open(model_schema_file, "w", encoding="utf-8") as f: + yaml.dump(model_schema, f) + + assert isinstance(get_console(), NoopConsole) + with caplog.at_level(logging.DEBUG): + context = Context(paths=project_dir) + assert ( + "Skipping audit 'relationships_full_model_cola__cola__ref_not_real_model_' because model 'not_real_model' is not a valid ref" + in caplog.text + ) + assert ( + "Skipping audit 'relationships_full_model_cola__cola__source_not_real_source_not_real_table_' because source 'not_real_source.not_real_table' is not a valid ref" + in caplog.text + ) + fqn = '"local"."main"."full_model"' + assert fqn in context.snapshots + # The audit isn't loaded due to the invalid ref + assert context.snapshots[fqn].model.audits == [] + + +@pytest.mark.slow +def test_load_microbatch_all_defined( + 
tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project +) -> None: + project_dir, model_dir = create_empty_project(project_name="local") + # add `tests` to model config since this is loaded by dbt and ignored and we shouldn't error when loading it + microbatch_contents = """ + {{ + config( + materialized='incremental', + incremental_strategy='microbatch', + event_time='ds', + begin='2020-01-01', + batch_size='day', + lookback=2, + concurrent_batches=true + ) + }} + + SELECT 1 as cola, '2025-01-01' as ds + """ + microbatch_model_file = model_dir / "microbatch.sql" + with open(microbatch_model_file, "w", encoding="utf-8") as f: + f.write(microbatch_contents) + + snapshot_fqn = '"local"."main"."microbatch"' + context = Context(paths=project_dir) + model = context.snapshots[snapshot_fqn].model + # Validate model-level attributes + assert model.start == datetime.datetime(2020, 1, 1, 0, 0) + assert model.interval_unit.is_day + # Validate model kind attributes + assert isinstance(model.kind, IncrementalByTimeRangeKind) + assert model.kind.lookback == 2 + assert model.kind.time_column == TimeColumn( + column=exp.to_column("ds", quoted=True), format="%Y-%m-%d" + ) + assert model.kind.batch_size == 1 + assert model.depends_on_self is False + + +@pytest.mark.slow +def test_load_microbatch_all_defined_diff_values( + tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project +) -> None: + project_dir, model_dir = create_empty_project(project_name="local") + # add `tests` to model config since this is loaded by dbt and ignored and we shouldn't error when loading it + microbatch_contents = """ + {{ + config( + materialized='incremental', + incremental_strategy='microbatch', + cron='@yearly', + event_time='blah', + begin='2022-01-01', + batch_size='year', + lookback=20, + concurrent_batches=false + ) + }} + + SELECT 1 as cola, '2022-01-01' as blah + """ + microbatch_model_file = model_dir / "microbatch.sql" + with 
open(microbatch_model_file, "w", encoding="utf-8") as f: + f.write(microbatch_contents) + + snapshot_fqn = '"local"."main"."microbatch"' + context = Context(paths=project_dir) + model = context.snapshots[snapshot_fqn].model + # Validate model-level attributes + assert model.start == datetime.datetime(2022, 1, 1, 0, 0) + assert model.interval_unit.is_year + # Validate model kind attributes + assert isinstance(model.kind, IncrementalByTimeRangeKind) + assert model.kind.lookback == 20 + assert model.kind.time_column == TimeColumn( + column=exp.to_column("blah", quoted=True), format="%Y-%m-%d" + ) + assert model.kind.batch_size == 1 + assert model.depends_on_self is True + + +@pytest.mark.slow +def test_load_microbatch_required_only( + tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project +) -> None: + project_dir, model_dir = create_empty_project(project_name="local") + # add `tests` to model config since this is loaded by dbt and ignored and we shouldn't error when loading it + microbatch_contents = """ + {{ + config( + materialized='incremental', + incremental_strategy='microbatch', + begin='2021-01-01', + event_time='ds', + batch_size='hour', + ) + }} + + SELECT 1 as cola, '2021-01-01' as ds + """ + microbatch_model_file = model_dir / "microbatch.sql" + with open(microbatch_model_file, "w", encoding="utf-8") as f: + f.write(microbatch_contents) + + snapshot_fqn = '"local"."main"."microbatch"' + context = Context(paths=project_dir) + model = context.snapshots[snapshot_fqn].model + # Validate model-level attributes + assert model.start == datetime.datetime(2021, 1, 1, 0, 0) + assert model.interval_unit.is_hour + # Validate model kind attributes + assert isinstance(model.kind, IncrementalByTimeRangeKind) + assert model.kind.lookback == 1 + assert model.kind.time_column == TimeColumn( + column=exp.to_column("ds", quoted=True), format="%Y-%m-%d" + ) + assert model.kind.batch_size == 1 + assert model.depends_on_self is False + + 
+@pytest.mark.slow +def test_load_incremental_time_range_strategy_required_only( + tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project +) -> None: + project_dir, model_dir = create_empty_project(project_name="local", start="2025-01-01") + # add `tests` to model config since this is loaded by dbt and ignored and we shouldn't error when loading it + incremental_time_range_contents = """ + {{ + config( + materialized='incremental', + incremental_strategy='incremental_by_time_range', + time_column='ds', + ) + }} + + SELECT 1 as cola, '2021-01-01' as ds + """ + incremental_time_range_model_file = model_dir / "incremental_time_range.sql" + with open(incremental_time_range_model_file, "w", encoding="utf-8") as f: + f.write(incremental_time_range_contents) + + snapshot_fqn = '"local"."main"."incremental_time_range"' + context = Context(paths=project_dir) + snapshot = context.snapshots[snapshot_fqn] + model = snapshot.model + # Validate model-level attributes + assert to_ds(model.start or "") == "2025-01-01" + assert model.interval_unit.is_day + # Validate model kind attributes + assert isinstance(model.kind, IncrementalByTimeRangeKind) + assert model.kind.lookback == 1 + assert model.kind.time_column == TimeColumn( + column=exp.to_column("ds", quoted=True), format="%Y-%m-%d" + ) + assert model.kind.batch_size is None + assert model.depends_on_self is False + assert model.kind.auto_restatement_intervals is None + assert model.kind.partition_by_time_column is True + # make sure the snapshot can be serialized to json + assert isinstance(_snapshot_to_json(snapshot), str) + + +@pytest.mark.slow +def test_load_incremental_time_range_strategy_all_defined( + tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project +) -> None: + project_dir, model_dir = create_empty_project(project_name="local", start="2025-01-01") + # add `tests` to model config since this is loaded by dbt and ignored and we shouldn't error when 
loading it + incremental_time_range_contents = """ + {{ + config( + materialized='incremental', + incremental_strategy='incremental_by_time_range', + time_column={ + 'column': 'ds', + 'format': '%Y%m%d' + }, + auto_restatement_intervals=3, + partition_by_time_column=false, + lookback=5, + batch_size=3, + batch_concurrency=2, + forward_only=true, + disable_restatement=true, + on_destructive_change='allow', + on_additive_change='error', + auto_restatement_cron='@hourly', + on_schema_change='ignore' + ) + }} + + SELECT 1 as cola, '2021-01-01' as ds + """ + incremental_time_range_model_file = model_dir / "incremental_time_range.sql" + with open(incremental_time_range_model_file, "w", encoding="utf-8") as f: + f.write(incremental_time_range_contents) + + snapshot_fqn = '"local"."main"."incremental_time_range"' + context = Context(paths=project_dir) + snapshot = context.snapshots[snapshot_fqn] + model = snapshot.model + # Validate model-level attributes + assert to_ds(model.start or "") == "2025-01-01" + assert model.interval_unit.is_day + # Validate model kind attributes + assert isinstance(model.kind, IncrementalByTimeRangeKind) + # `on_schema_change` is ignored since the user explicitly overrode the values + assert model.kind.on_destructive_change == OnDestructiveChange.ALLOW + assert model.kind.on_additive_change == OnAdditiveChange.ERROR + assert model.kind.forward_only is True + assert model.kind.disable_restatement is True + assert model.kind.auto_restatement_cron == "@hourly" + assert model.kind.auto_restatement_intervals == 3 + assert model.kind.partition_by_time_column is False + assert model.kind.lookback == 5 + assert model.kind.time_column == TimeColumn( + column=exp.to_column("ds", quoted=True), format="%Y%m%d" + ) + assert model.kind.batch_size == 3 + assert model.kind.batch_concurrency == 2 + assert model.depends_on_self is False + # make sure the snapshot can be serialized to json + assert isinstance(_snapshot_to_json(snapshot), str) + + 
+@pytest.mark.slow +def test_load_deprecated_incremental_time_column( + tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project +) -> None: + project_dir, model_dir = create_empty_project(project_name="local", start="2025-01-01") + # add `tests` to model config since this is loaded by dbt and ignored and we shouldn't error when loading it + incremental_time_range_contents = """ + {{ + config( + materialized='incremental', + incremental_strategy='delete+insert', + time_column='ds' + ) + }} + + SELECT 1 as cola, '2021-01-01' as ds + """ + incremental_time_range_model_file = model_dir / "incremental_time_range.sql" + with open(incremental_time_range_model_file, "w", encoding="utf-8") as f: + f.write(incremental_time_range_contents) + + snapshot_fqn = '"local"."main"."incremental_time_range"' + assert isinstance(get_console(), NoopConsole) + with caplog.at_level(logging.DEBUG): + context = Context(paths=project_dir) + model = context.snapshots[snapshot_fqn].model + # Validate model-level attributes + assert to_ds(model.start or "") == "2025-01-01" + assert model.interval_unit.is_day + # Validate model-level attributes + assert to_ds(model.start or "") == "2025-01-01" + assert model.interval_unit.is_day + # Validate model kind attributes + assert isinstance(model.kind, IncrementalByTimeRangeKind) + assert model.kind.lookback == 1 + assert model.kind.time_column == TimeColumn( + column=exp.to_column("ds", quoted=True), format="%Y-%m-%d" + ) + assert model.kind.batch_size is None + assert model.depends_on_self is False + assert model.kind.auto_restatement_intervals is None + assert model.kind.partition_by_time_column is True + assert ( + "Using `time_column` on a model with incremental_strategy 'delete+insert' has been deprecated. Please use `incremental_by_time_range` instead in model 'main.incremental_time_range'." 
+ in caplog.text + ) + + +@pytest.mark.slow +def test_load_microbatch_with_ref( + tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project +) -> None: + yaml = YAML() + project_dir, model_dir = create_empty_project(project_name="local") + source_schema = { + "version": 2, + "sources": [ + { + "name": "my_source", + "tables": [{"name": "my_table", "config": {"event_time": "ds_source"}}], + } + ], + } + source_schema_file = model_dir / "source_schema.yml" + with open(source_schema_file, "w", encoding="utf-8") as f: + yaml.dump(source_schema, f) + # add `tests` to model config since this is loaded by dbt and ignored and we shouldn't error when loading it + microbatch_contents = """ + {{ + config( + materialized='incremental', + incremental_strategy='microbatch', + event_time='ds', + begin='2020-01-01', + batch_size='day' + ) + }} + + SELECT cola, ds_source as ds FROM {{ source('my_source', 'my_table') }} + """ + microbatch_model_file = model_dir / "microbatch.sql" + with open(microbatch_model_file, "w", encoding="utf-8") as f: + f.write(microbatch_contents) + + microbatch_two_contents = """ + {{ + config( + materialized='incremental', + incremental_strategy='microbatch', + event_time='ds', + begin='2020-01-05', + batch_size='day' + ) + }} + + SELECT cola, ds FROM {{ ref('microbatch') }} + """ + microbatch_two_model_file = model_dir / "microbatch_two.sql" + with open(microbatch_two_model_file, "w", encoding="utf-8") as f: + f.write(microbatch_two_contents) + + microbatch_snapshot_fqn = '"local"."main"."microbatch"' + microbatch_two_snapshot_fqn = '"local"."main"."microbatch_two"' + context = Context(paths=project_dir) + assert ( + context.render(microbatch_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql() + == 'SELECT "cola" AS "cola", "ds_source" AS "ds" FROM (SELECT * FROM "local"."my_source"."my_table" AS "my_table" WHERE "ds_source" >= \'2025-01-01 00:00:00+00:00\' AND "ds_source" < \'2025-01-11 00:00:00+00:00\') AS "_0"' + ) 
+ assert ( + context.render(microbatch_two_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql() + == 'SELECT "_0"."cola" AS "cola", "_0"."ds" AS "ds" FROM (SELECT "microbatch"."cola" AS "cola", "microbatch"."ds" AS "ds" FROM "local"."main"."microbatch" AS "microbatch" WHERE "microbatch"."ds" < \'2025-01-11 00:00:00+00:00\' AND "microbatch"."ds" >= \'2025-01-01 00:00:00+00:00\') AS "_0"' + ) + + +@pytest.mark.slow +def test_load_microbatch_with_ref_no_filter( + tmp_path: Path, caplog, dbt_dummy_postgres_config: PostgresConfig, create_empty_project +) -> None: + yaml = YAML() + project_dir, model_dir = create_empty_project(project_name="local") + source_schema = { + "version": 2, + "sources": [ + { + "name": "my_source", + "tables": [{"name": "my_table", "config": {"event_time": "ds"}}], + } + ], + } + source_schema_file = model_dir / "source_schema.yml" + with open(source_schema_file, "w", encoding="utf-8") as f: + yaml.dump(source_schema, f) + # add `tests` to model config since this is loaded by dbt and ignored and we shouldn't error when loading it + microbatch_contents = """ + {{ + config( + materialized='incremental', + incremental_strategy='microbatch', + event_time='ds', + begin='2020-01-01', + batch_size='day' + ) + }} + + SELECT cola, ds FROM {{ source('my_source', 'my_table').render() }} + """ + microbatch_model_file = model_dir / "microbatch.sql" + with open(microbatch_model_file, "w", encoding="utf-8") as f: + f.write(microbatch_contents) + + microbatch_two_contents = """ + {{ + config( + materialized='incremental', + incremental_strategy='microbatch', + event_time='ds', + begin='2020-01-01', + batch_size='day' + ) + }} + + SELECT cola, ds FROM {{ ref('microbatch').render() }} + """ + microbatch_two_model_file = model_dir / "microbatch_two.sql" + with open(microbatch_two_model_file, "w", encoding="utf-8") as f: + f.write(microbatch_two_contents) + + microbatch_snapshot_fqn = '"local"."main"."microbatch"' + microbatch_two_snapshot_fqn = 
'"local"."main"."microbatch_two"' + context = Context(paths=project_dir) + assert ( + context.render(microbatch_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql() + == 'SELECT "cola" AS "cola", "ds" AS "ds" FROM "local"."my_source"."my_table" AS "my_table"' + ) + assert ( + context.render(microbatch_two_snapshot_fqn, start="2025-01-01", end="2025-01-10").sql() + == 'SELECT "microbatch"."cola" AS "cola", "microbatch"."ds" AS "ds" FROM "local"."main"."microbatch" AS "microbatch"' + ) + + +@pytest.mark.slow +def test_load_multiple_snapshots_defined_in_same_file(sushi_test_dbt_context: Context) -> None: + context = sushi_test_dbt_context + assert context.get_model("snapshots.items_snapshot") + assert context.get_model("snapshots.items_check_snapshot") + + # Make sure cache works too + context.load() + assert context.get_model("snapshots.items_snapshot") + assert context.get_model("snapshots.items_check_snapshot") + + +@pytest.mark.slow +def test_dbt_snapshot_with_check_cols_expressions(sushi_test_dbt_context: Context) -> None: + context = sushi_test_dbt_context + model = context.get_model("snapshots.items_check_with_cast_snapshot") + assert model is not None + assert isinstance(model.kind, SCDType2ByColumnKind) + + columns = model.kind.columns + assert isinstance(columns, list) + assert len(columns) == 1 + + # expression in check_cols is: ds::DATE + assert isinstance(columns[0], exp.Cast) + assert columns[0].sql() == 'CAST("ds" AS DATE)' + + context.load() + cached_model = context.get_model("snapshots.items_check_with_cast_snapshot") + assert cached_model is not None + assert isinstance(cached_model.kind, SCDType2ByColumnKind) + assert isinstance(cached_model.kind.columns, list) + assert len(cached_model.kind.columns) == 1 + + +@pytest.mark.slow +def test_dbt_jinja_macro_undefined_variable_error(create_empty_project): + project_dir, model_dir = create_empty_project() + + macros_dir = project_dir / "macros" + macros_dir.mkdir() + + # the execute guard in the 
macro is so that dbt won't fail on the manifest loading earlier + macro_file = macros_dir / "my_macro.sql" + macro_file.write_text(""" +{%- macro select_columns(table_name) -%} + {% if execute %} + {%- if target.name == 'production' -%} + {%- set columns = run_query('SELECT column_name FROM information_schema.columns WHERE table_name = \'' ~ table_name ~ '\'') -%} + {%- endif -%} + SELECT {{ columns.rows[0][0] }} FROM {{ table_name }} + {%- endif -%} +{%- endmacro -%} +""") + + model_file = model_dir / "my_model.sql" + model_file.write_text(""" +{{ config( + materialized='table' +) }} + +{{ select_columns('users') }} +""") + + with pytest.raises(SchemaError) as exc_info: + Context(paths=project_dir) + + error_message = str(exc_info.value) + assert "Failed to update model schemas" in error_message + assert "Could not render jinja for" in error_message + assert "Undefined macro/variable: 'columns' in macro: 'select_columns'" in error_message + + +@pytest.mark.slow +def test_node_name_populated_for_dbt_models(dbt_dummy_postgres_config: PostgresConfig) -> None: + model_config = ModelConfig( + unique_id="model.test_package.test_model", + fqn=["test_package", "test_model"], + name="test_model", + package_name="test_package", + sql="SELECT 1 as id", + database="test_db", + schema_="test_schema", + alias="test_model", + ) + + context = DbtContext() + context.project_name = "test_project" + context.target = dbt_dummy_postgres_config + + # check after convert to SQLMesh model that node_name is populated correctly + sqlmesh_model = model_config.to_sqlmesh(context) + assert sqlmesh_model.dbt_unique_id == "model.test_package.test_model" + assert sqlmesh_model.dbt_fqn == "test_package.test_model" + + +@pytest.mark.slow +def test_load_model_dbt_node_name(tmp_path: Path) -> None: + yaml = YAML() + dbt_project_dir = tmp_path / "dbt" + dbt_project_dir.mkdir() + dbt_model_dir = dbt_project_dir / "models" + dbt_model_dir.mkdir() + + model_contents = "SELECT 1 as id, 'test' as name" + 
model_file = dbt_model_dir / "simple_model.sql" + with open(model_file, "w", encoding="utf-8") as f: + f.write(model_contents) + + dbt_project_config = { + "name": "test_project", + "version": "1.0.0", + "config-version": 2, + "profile": "test", + "model-paths": ["models"], + } + dbt_project_file = dbt_project_dir / "dbt_project.yml" + with open(dbt_project_file, "w", encoding="utf-8") as f: + yaml.dump(dbt_project_config, f) + + sqlmesh_config = { + "model_defaults": { + "start": "2025-01-01", + } + } + sqlmesh_config_file = dbt_project_dir / "sqlmesh.yaml" + with open(sqlmesh_config_file, "w", encoding="utf-8") as f: + yaml.dump(sqlmesh_config, f) + + dbt_data_dir = tmp_path / "dbt_data" + dbt_data_dir.mkdir() + dbt_data_file = dbt_data_dir / "local.db" + dbt_profile_config = { + "test": { + "outputs": {"duckdb": {"type": "duckdb", "path": str(dbt_data_file)}}, + "target": "duckdb", + } + } + db_profile_file = dbt_project_dir / "profiles.yml" + with open(db_profile_file, "w", encoding="utf-8") as f: + yaml.dump(dbt_profile_config, f) + + context = Context(paths=dbt_project_dir) + + # find the model by its sqlmesh fully qualified name + model_fqn = '"local"."main"."simple_model"' + assert model_fqn in context.snapshots + + # Verify that node_name is the equivalent dbt one + model = context.snapshots[model_fqn].model + assert model.dbt_unique_id == "model.test_project.simple_model" + assert model.dbt_fqn == "test_project.simple_model" + assert model.dbt_node_info + assert model.dbt_node_info.name == "simple_model" + + +@pytest.mark.slow +def test_jinja_config_no_query(create_empty_project): + project_dir, model_dir = create_empty_project(project_name="local") + + # model definition contains only a comment and non-SQL jinja + model_contents = "/* comment */ {{ config(materialized='table') }}" + model_file = model_dir / "comment_config_model.sql" + with open(model_file, "w", encoding="utf-8") as f: + f.write(model_contents) + + schema_yaml = {"version": 2, "models": 
[{"name": "comment_config_model"}]} + schema_file = model_dir / "schema.yml" + with open(schema_file, "w", encoding="utf-8") as f: + YAML().dump(schema_yaml, f) + + context = Context(paths=project_dir) + + # loads without error and contains empty query (which will error at runtime) + assert not context.snapshots['"local"."main"."comment_config_model"'].model.render_query() + + +@pytest.mark.slow +def test_load_custom_materialisations(sushi_test_dbt_context: Context) -> None: + context = sushi_test_dbt_context + assert context.get_model("sushi.custom_incremental_model") + assert context.get_model("sushi.custom_incremental_with_filter") + + context.load() + assert context.get_model("sushi.custom_incremental_model") + assert context.get_model("sushi.custom_incremental_with_filter") + + +def test_model_grants_to_sqlmesh_grants_config() -> None: + grants_config = { + "select": ["user1", "user2"], + "insert": ["admin_user"], + "update": ["power_user"], + } + model_config = ModelConfig( + name="test_model", + sql="SELECT 1 as id", + grants=grants_config, + path=Path("test_model.sql"), + ) + + context = DbtContext() + context.project_name = "test_project" + context.target = DuckDbConfig(name="target", schema="test_schema") + + sqlmesh_model = model_config.to_sqlmesh( + context, virtual_environment_mode=VirtualEnvironmentMode.FULL + ) + + model_grants = sqlmesh_model.grants + assert model_grants == grants_config + + assert sqlmesh_model.grants_target_layer == GrantsTargetLayer.default + + +def test_model_grants_empty_permissions() -> None: + model_config = ModelConfig( + name="test_model_empty", + sql="SELECT 1 as id", + grants={"select": [], "insert": ["admin_user"]}, + path=Path("test_model_empty.sql"), + ) + + context = DbtContext() + context.project_name = "test_project" + context.target = DuckDbConfig(name="target", schema="test_schema") + + sqlmesh_model = model_config.to_sqlmesh( + context, virtual_environment_mode=VirtualEnvironmentMode.FULL + ) + + model_grants = 
sqlmesh_model.grants + expected_grants = {"select": [], "insert": ["admin_user"]} + assert model_grants == expected_grants + + +def test_model_no_grants() -> None: + model_config = ModelConfig( + name="test_model_no_grants", + sql="SELECT 1 as id", + path=Path("test_model_no_grants.sql"), + ) + + context = DbtContext() + context.project_name = "test_project" + context.target = DuckDbConfig(name="target", schema="test_schema") + + sqlmesh_model = model_config.to_sqlmesh( + context, virtual_environment_mode=VirtualEnvironmentMode.FULL + ) + + grants_config = sqlmesh_model.grants + assert grants_config is None + + +def test_model_empty_grants() -> None: + model_config = ModelConfig( + name="test_model_empty_grants", + sql="SELECT 1 as id", + grants={}, + path=Path("test_model_empty_grants.sql"), + ) + + context = DbtContext() + context.project_name = "test_project" + context.target = DuckDbConfig(name="target", schema="test_schema") + + sqlmesh_model = model_config.to_sqlmesh( + context, virtual_environment_mode=VirtualEnvironmentMode.FULL + ) + + grants_config = sqlmesh_model.grants + assert grants_config is None + + +def test_model_grants_valid_special_characters() -> None: + valid_grantees = [ + "user@domain.com", + "service-account@project.iam.gserviceaccount.com", + "group:analysts", + '"quoted user"', + "`backtick user`", + "user_with_underscores", + "user.with.dots", + ] + + model_config = ModelConfig( + name="test_model_special_chars", + sql="SELECT 1 as id", + grants={"select": valid_grantees}, + path=Path("test_model.sql"), + ) + + context = DbtContext() + context.project_name = "test_project" + context.target = DuckDbConfig(name="target", schema="test_schema") + + sqlmesh_model = model_config.to_sqlmesh( + context, virtual_environment_mode=VirtualEnvironmentMode.FULL + ) + + grants_config = sqlmesh_model.grants + assert grants_config is not None + assert "select" in grants_config + assert grants_config["select"] == valid_grantees + + +def 
test_model_grants_engine_specific_bigquery() -> None: + model_config = ModelConfig( + name="test_model_bigquery", + sql="SELECT 1 as id", + grants={ + "bigquery.dataviewer": ["user@domain.com"], + "select": ["analyst@company.com"], + }, + path=Path("test_model.sql"), + ) + + context = DbtContext() + context.project_name = "test_project" + context.target = BigQueryConfig( + name="bigquery_target", + project="test-project", + dataset="test_dataset", + location="US", + database="test-project", + schema="test_dataset", + ) + + sqlmesh_model = model_config.to_sqlmesh( + context, virtual_environment_mode=VirtualEnvironmentMode.FULL + ) + + grants_config = sqlmesh_model.grants + assert grants_config is not None + assert grants_config["bigquery.dataviewer"] == ["user@domain.com"] + assert grants_config["select"] == ["analyst@company.com"] + + +def test_ephemeral_model_ignores_grants() -> None: + """Test that ephemeral models ignore grants configuration.""" + model_config = ModelConfig( + name="ephemeral_model", + sql="SELECT 1 as id", + materialized="ephemeral", + grants={"select": ["reporter", "analyst"]}, + path=Path("ephemeral_model.sql"), + ) + + context = DbtContext() + context.project_name = "test_project" + context.target = DuckDbConfig(name="target", schema="test_schema") + + sqlmesh_model = model_config.to_sqlmesh( + context, virtual_environment_mode=VirtualEnvironmentMode.FULL + ) + + assert sqlmesh_model.kind.is_embedded + assert sqlmesh_model.grants is None # grants config is skipped for ephemeral / embedded models + + +def test_conditional_ref_in_unexecuted_branch(copy_to_temp_path: t.Callable): + path = copy_to_temp_path("tests/fixtures/dbt/sushi_test") + temp_project = path[0] + + models_dir = temp_project / "models" + models_dir.mkdir(parents=True, exist_ok=True) + + test_model_content = """ +{{ config( + materialized='table', +) }} + +{% if true %} + WITH source AS ( + SELECT * + FROM {{ ref('simple_model_a') }} + ) +{% else %} + WITH source AS ( + SELECT 
* + FROM {{ ref('nonexistent_model') }} -- this doesn't exist but is in unexecuted branch + ) +{% endif %} + +SELECT * FROM source +""".strip() + + (models_dir / "conditional_ref_model.sql").write_text(test_model_content) + sushi_context = Context(paths=[str(temp_project)]) + + # the model should load successfully without raising MissingModelError + model = sushi_context.get_model("sushi.conditional_ref_model") + assert model is not None + + # Verify only the executed ref is in the dependencies + assert len(model.depends_on) == 1 + assert '"memory"."sushi"."simple_model_a"' in model.depends_on + + # Also the model can be rendered successfully with the executed ref + rendered = model.render_query() + assert rendered is not None + assert ( + rendered.sql() + == 'WITH "source" AS (SELECT "simple_model_a"."a" AS "a" FROM "memory"."sushi"."simple_model_a" AS "simple_model_a") SELECT "source"."a" AS "a" FROM "source" AS "source"' + ) + + # And run plan with this conditional model for good measure + plan = sushi_context.plan(select_models=["sushi.conditional_ref_model", "sushi.simple_model_a"]) + sushi_context.apply(plan) + upstream_ref = sushi_context.engine_adapter.fetchone("SELECT * FROM sushi.simple_model_a") + assert upstream_ref == (1,) + result = sushi_context.engine_adapter.fetchone("SELECT * FROM sushi.conditional_ref_model") + assert result == (1,) diff --git a/tests/dbt/test_test.py b/tests/dbt/test_test.py index 845c1d2fc0..fb33220c0c 100644 --- a/tests/dbt/test_test.py +++ b/tests/dbt/test_test.py @@ -1,3 +1,7 @@ +from pathlib import Path + +import pytest + from sqlmesh.dbt.test import TestConfig @@ -8,3 +12,131 @@ def test_multiline_test_kwarg() -> None: test_kwargs={"test_field": "foo\nbar\n"}, ) assert test._kwargs() == 'test_field="foo\nbar"' + + +@pytest.mark.xdist_group("dbt_manifest") +def test_tests_get_unique_names(tmp_path: Path, create_empty_project) -> None: + from sqlmesh.utils.yaml import YAML + from sqlmesh.core.context import Context + + yaml 
= YAML() + project_dir, model_dir = create_empty_project(project_name="local") + + model_file = model_dir / "my_model.sql" + with open(model_file, "w", encoding="utf-8") as f: + f.write("SELECT 1 as id, 'value1' as status") + + # Create schema.yml with: + # 1. Same test on model and source, both with/without custom test name + # 2. Same test on same model with different args, both with/without custom test name + # 3. Versioned model with tests (both built-in and custom named) + schema_yaml = { + "version": 2, + "sources": [ + { + "name": "raw", + "tables": [ + { + "name": "my_source", + "columns": [ + { + "name": "id", + "data_tests": [ + {"not_null": {"name": "custom_notnull_name"}}, + {"not_null": {}}, + ], + } + ], + } + ], + } + ], + "models": [ + { + "name": "my_model", + "columns": [ + { + "name": "id", + "data_tests": [ + {"not_null": {"name": "custom_notnull_name"}}, + {"not_null": {}}, + ], + }, + { + "name": "status", + "data_tests": [ + {"accepted_values": {"values": ["value1", "value2"]}}, + {"accepted_values": {"values": ["value1", "value2", "value3"]}}, + { + "accepted_values": { + "name": "custom_accepted_values_name", + "values": ["value1", "value2"], + } + }, + { + "accepted_values": { + "name": "custom_accepted_values_name", + "values": ["value1", "value2", "value3"], + } + }, + ], + }, + ], + }, + { + "name": "versioned_model", + "columns": [ + { + "name": "id", + "data_tests": [ + {"not_null": {}}, + {"not_null": {"name": "custom_versioned_notnull"}}, + ], + }, + { + "name": "amount", + "data_tests": [ + {"accepted_values": {"values": ["low", "high"]}}, + ], + }, + ], + "versions": [ + {"v": 1}, + {"v": 2}, + ], + }, + ], + } + + schema_file = model_dir / "schema.yml" + with open(schema_file, "w", encoding="utf-8") as f: + yaml.dump(schema_yaml, f) + + # Create versioned model files + versioned_model_v1_file = model_dir / "versioned_model_v1.sql" + with open(versioned_model_v1_file, "w", encoding="utf-8") as f: + f.write("SELECT 1 as id, 'low' 
as amount") + + versioned_model_v2_file = model_dir / "versioned_model_v2.sql" + with open(versioned_model_v2_file, "w", encoding="utf-8") as f: + f.write("SELECT 1 as id, 'low' as amount") + + context = Context(paths=project_dir) + + all_audit_names = list(context._audits.keys()) + list(context._standalone_audits.keys()) + assert sorted(all_audit_names) == [ + "local.accepted_values_my_model_status__value1__value2", + "local.accepted_values_my_model_status__value1__value2__value3", + "local.accepted_values_versioned_model_v1_amount__low__high", + "local.accepted_values_versioned_model_v2_amount__low__high", + "local.custom_accepted_values_name_my_model_status__value1__value2", + "local.custom_accepted_values_name_my_model_status__value1__value2__value3", + "local.custom_notnull_name_my_model_id", + "local.custom_versioned_notnull_versioned_model_v1_id", + "local.custom_versioned_notnull_versioned_model_v2_id", + "local.not_null_my_model_id", + "local.not_null_versioned_model_v1_id", + "local.not_null_versioned_model_v2_id", + "local.source_custom_notnull_name_raw_my_source_id", + "local.source_not_null_raw_my_source_id", + ] diff --git a/tests/dbt/test_transformation.py b/tests/dbt/test_transformation.py index 1c9e947c24..fe6073dfad 100644 --- a/tests/dbt/test_transformation.py +++ b/tests/dbt/test_transformation.py @@ -1,29 +1,50 @@ +import agate +from datetime import datetime, timedelta import json import logging import typing as t from pathlib import Path from unittest.mock import patch +from sqlmesh.dbt.util import DBT_VERSION + import pytest from dbt.adapters.base import BaseRelation -from dbt.exceptions import CompilationError -from freezegun import freeze_time +from jinja2 import Template + +if DBT_VERSION >= (1, 4, 0): + from dbt.exceptions import CompilationError +else: + from dbt.exceptions import CompilationException as CompilationError # type: ignore +import time_machine from pytest_mock.plugin import MockerFixture from sqlglot import exp, parse_one 
from sqlmesh.core import dialect as d +from sqlmesh.core.environment import EnvironmentNamingInfo +from sqlmesh.core.macros import RuntimeStage +from sqlmesh.core.renderer import render_statements from sqlmesh.core.audit import StandaloneAudit from sqlmesh.core.context import Context +from sqlmesh.core.console import get_console from sqlmesh.core.model import ( EmbeddedKind, FullKind, IncrementalByTimeRangeKind, IncrementalByUniqueKeyKind, IncrementalUnmanagedKind, + ManagedKind, SqlModel, ViewKind, ) -from sqlmesh.core.model.kind import SCDType2ByColumnKind, SCDType2ByTimeKind -from sqlmesh.core.state_sync.engine_adapter import _snapshot_to_json +from sqlmesh.core.model.kind import ( + SCDType2ByColumnKind, + SCDType2ByTimeKind, + OnDestructiveChange, + OnAdditiveChange, +) +from sqlmesh.core.state_sync.db.snapshot import _snapshot_to_json +from sqlmesh.dbt.builtin import _relation_info_to_relation, Config +from sqlmesh.dbt.common import Dependencies from sqlmesh.dbt.builtin import _relation_info_to_relation from sqlmesh.dbt.column import ( ColumnConfig, @@ -32,19 +53,27 @@ ) from sqlmesh.dbt.context import DbtContext from sqlmesh.dbt.model import Materialization, ModelConfig +from sqlmesh.dbt.source import SourceConfig from sqlmesh.dbt.project import Project from sqlmesh.dbt.relation import Policy from sqlmesh.dbt.seed import SeedConfig -from sqlmesh.dbt.target import BigQueryConfig, DuckDbConfig, SnowflakeConfig +from sqlmesh.dbt.target import ( + BigQueryConfig, + DuckDbConfig, + SnowflakeConfig, + ClickhouseConfig, + PostgresConfig, +) from sqlmesh.dbt.test import TestConfig -from sqlmesh.utils.errors import ConfigError, MacroEvalError, SQLMeshError +from sqlmesh.utils.errors import ConfigError, SQLMeshError +from sqlmesh.utils.jinja import MacroReference pytestmark = [pytest.mark.dbt, pytest.mark.slow] -def test_model_name(): +def test_model_name(dbt_dummy_postgres_config: PostgresConfig): context = DbtContext() - context._target = DuckDbConfig(name="duckdb", 
schema="foo") + context._target = dbt_dummy_postgres_config assert ModelConfig(schema="foo", path="models/bar.sql").canonical_name(context) == "foo.bar" assert ( ModelConfig(schema="foo", path="models/bar.sql", alias="baz").canonical_name(context) @@ -52,10 +81,9 @@ def test_model_name(): ) assert ( ModelConfig( - database="memory", schema="foo", path="models/bar.sql", alias="baz" + database="dbname", schema="foo", path="models/bar.sql", alias="baz" ).canonical_name(context) == "foo.baz" - == "foo.baz" ) assert ( ModelConfig( @@ -65,6 +93,150 @@ def test_model_name(): ) +def test_materialization(): + context = DbtContext() + context.project_name = "Test" + context.target = DuckDbConfig(name="target", schema="foo") + + with patch.object(get_console(), "log_warning") as mock_logger: + model_config = ModelConfig( + name="model", alias="model", schema="schema", materialized="materialized_view" + ) + + assert ( + "SQLMesh does not support the 'materialized_view' model materialization. Falling back to the 'view' materialization." 
+ in mock_logger.call_args[0][0] + ) + assert model_config.materialized == "view" + + # clickhouse "dictionary" materialization + with pytest.raises(ConfigError): + ModelConfig(name="model", alias="model", schema="schema", materialized="dictionary") + + +def test_dbt_custom_materialization(): + sushi_context = Context(paths=["tests/fixtures/dbt/sushi_test"]) + + plan_builder = sushi_context.plan_builder(select_models=["sushi.custom_incremental_model"]) + plan = plan_builder.build() + assert len(plan.selected_models) == 1 + selected_model = list(plan.selected_models)[0] + assert selected_model == "model.sushi.custom_incremental_model" + + query = "SELECT * FROM sushi.custom_incremental_model ORDER BY created_at" + hook_table = "SELECT * FROM hook_table ORDER BY id" + sushi_context.apply(plan) + result = sushi_context.engine_adapter.fetchdf(query) + assert len(result) == 1 + assert {"created_at", "id"}.issubset(result.columns) + + # assert the pre/post hooks executed as well as part of the custom materialization + hook_result = sushi_context.engine_adapter.fetchdf(hook_table) + assert len(hook_result) == 1 + assert {"length_col", "id", "updated_at"}.issubset(hook_result.columns) + assert int(hook_result["length_col"][0]) >= 519 + assert hook_result["id"][0] == 1 + + # running with execution time one day in the future to simulate an incremental insert + tomorrow = datetime.now() + timedelta(days=1) + sushi_context.run(select_models=["sushi.custom_incremental_model"], execution_time=tomorrow) + + result_after_run = sushi_context.engine_adapter.fetchdf(query) + assert {"created_at", "id"}.issubset(result_after_run.columns) + + # this should have added new unique values for the new row + assert len(result_after_run) == 2 + assert result_after_run["id"].is_unique + assert result_after_run["created_at"].is_unique + + # validate the hooks executed as part of the run as well + hook_result = sushi_context.engine_adapter.fetchdf(hook_table) + assert len(hook_result) == 2 + 
assert hook_result["id"][1] == 2 + assert int(hook_result["length_col"][1]) >= 519 + assert hook_result["id"].is_monotonic_increasing + assert hook_result["updated_at"].is_unique + assert not hook_result["length_col"].is_unique + + +def test_dbt_custom_materialization_with_time_filter_and_macro(): + sushi_context = Context(paths=["tests/fixtures/dbt/sushi_test"]) + today = datetime.now() + + # select both custom materialiasation models with the wildcard + selector = ["sushi.custom_incremental*"] + plan_builder = sushi_context.plan_builder(select_models=selector, execution_time=today) + plan = plan_builder.build() + + assert len(plan.selected_models) == 2 + assert { + "model.sushi.custom_incremental_model", + "model.sushi.custom_incremental_with_filter", + }.issubset(plan.selected_models) + + # the model that daily (default cron) populates with data + select_daily = "SELECT * FROM sushi.custom_incremental_model ORDER BY created_at" + + # this model uses `run_started_at` as a filter (which we populate with execution time) with 2 day interval + select_filter = "SELECT * FROM sushi.custom_incremental_with_filter ORDER BY created_at" + + sushi_context.apply(plan) + result = sushi_context.engine_adapter.fetchdf(select_daily) + assert len(result) == 1 + assert {"created_at", "id"}.issubset(result.columns) + + result = sushi_context.engine_adapter.fetchdf(select_filter) + assert len(result) == 1 + assert {"created_at", "id"}.issubset(result.columns) + + # - run ONE DAY LATER + a_day_later = today + timedelta(days=1) + sushi_context.run(select_models=selector, execution_time=a_day_later) + result_after_run = sushi_context.engine_adapter.fetchdf(select_daily) + + # the new row is inserted in the normal incremental model + assert len(result_after_run) == 2 + assert {"created_at", "id"}.issubset(result_after_run.columns) + assert result_after_run["id"].is_unique + assert result_after_run["created_at"].is_unique + + # this model due to the filter shouldn't populate with any new 
data + result_after_run_filter = sushi_context.engine_adapter.fetchdf(select_filter) + assert len(result_after_run_filter) == 1 + assert {"created_at", "id"}.issubset(result_after_run_filter.columns) + assert result.equals(result_after_run_filter) + assert result_after_run_filter["id"].is_unique + assert result_after_run_filter["created_at"][0].date() == today.date() + + # - run TWO DAYS LATER + two_days_later = a_day_later + timedelta(days=1) + sushi_context.run(select_models=selector, execution_time=two_days_later) + result_after_run = sushi_context.engine_adapter.fetchdf(select_daily) + + # again a new row is inserted in the normal model + assert len(result_after_run) == 3 + assert {"created_at", "id"}.issubset(result_after_run.columns) + assert result_after_run["id"].is_unique + assert result_after_run["created_at"].is_unique + + # the model with the filter now should populate as well + result_after_run_filter = sushi_context.engine_adapter.fetchdf(select_filter) + assert len(result_after_run_filter) == 2 + assert {"created_at", "id"}.issubset(result_after_run_filter.columns) + assert result_after_run_filter["id"].is_unique + assert result_after_run_filter["created_at"][0].date() == today.date() + assert result_after_run_filter["created_at"][1].date() == two_days_later.date() + + # assert hooks have executed for both plan and incremental runs + hook_result = sushi_context.engine_adapter.fetchdf("SELECT * FROM hook_table ORDER BY id") + assert len(hook_result) == 3 + hook_result["id"][0] == 1 + assert hook_result["id"].is_monotonic_increasing + assert hook_result["updated_at"].is_unique + assert int(hook_result["length_col"][1]) >= 519 + assert not hook_result["length_col"].is_unique + + def test_model_kind(): context = DbtContext() context.project_name = "Test" @@ -85,6 +257,8 @@ def test_model_kind(): updated_at_as_valid_from=True, updated_at_name="updated_at", dialect="duckdb", + on_destructive_change=OnDestructiveChange.IGNORE, + 
on_additive_change=OnAdditiveChange.ALLOW, ) assert ModelConfig( materialized=Materialization.SNAPSHOT, @@ -98,6 +272,8 @@ def test_model_kind(): columns=["foo"], execution_time_as_valid_from=True, dialect="duckdb", + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.ALLOW, ) assert ModelConfig( materialized=Materialization.SNAPSHOT, @@ -112,23 +288,66 @@ def test_model_kind(): columns=["foo"], execution_time_as_valid_from=True, dialect="bigquery", + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.ALLOW, ) + check_cols_with_cast = ModelConfig( + materialized=Materialization.SNAPSHOT, + unique_key=["id"], + strategy="check", + check_cols=["created_at::TIMESTAMPTZ"], + ).model_kind(context) + assert isinstance(check_cols_with_cast, SCDType2ByColumnKind) + assert check_cols_with_cast.execution_time_as_valid_from is True + assert len(check_cols_with_cast.columns) == 1 + assert isinstance(check_cols_with_cast.columns[0], exp.Cast) + assert check_cols_with_cast.columns[0].sql() == 'CAST("created_at" AS TIMESTAMPTZ)' + + check_cols_multiple_expr = ModelConfig( + materialized=Materialization.SNAPSHOT, + unique_key=["id"], + strategy="check", + check_cols=["created_at::TIMESTAMPTZ", "COALESCE(status, 'active')"], + ).model_kind(context) + assert isinstance(check_cols_multiple_expr, SCDType2ByColumnKind) + assert len(check_cols_multiple_expr.columns) == 2 + assert isinstance(check_cols_multiple_expr.columns[0], exp.Cast) + assert isinstance(check_cols_multiple_expr.columns[1], exp.Coalesce) + + assert check_cols_multiple_expr.columns[0].sql() == 'CAST("created_at" AS TIMESTAMPTZ)' + assert check_cols_multiple_expr.columns[1].sql() == "COALESCE(\"status\", 'active')" + assert ModelConfig(materialized=Materialization.INCREMENTAL, time_column="foo").model_kind( context - ) == IncrementalByTimeRangeKind(time_column="foo", dialect="duckdb", forward_only=True) + ) == IncrementalByTimeRangeKind( + 
time_column="foo", + dialect="duckdb", + forward_only=True, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, + ) assert ModelConfig( materialized=Materialization.INCREMENTAL, time_column="foo", incremental_strategy="delete+insert", forward_only=False, - ).model_kind(context) == IncrementalByTimeRangeKind(time_column="foo", dialect="duckdb") + ).model_kind(context) == IncrementalByTimeRangeKind( + time_column="foo", + dialect="duckdb", + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, + ) assert ModelConfig( materialized=Materialization.INCREMENTAL, time_column="foo", incremental_strategy="insert_overwrite", ).model_kind(context) == IncrementalByTimeRangeKind( - time_column="foo", dialect="duckdb", forward_only=True + time_column="foo", + dialect="duckdb", + forward_only=True, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, ) assert ModelConfig( materialized=Materialization.INCREMENTAL, @@ -136,37 +355,86 @@ def test_model_kind(): unique_key=["bar"], dialect="bigquery", ).model_kind(context) == IncrementalByTimeRangeKind( - time_column="foo", dialect="bigquery", forward_only=True + time_column="foo", + dialect="bigquery", + forward_only=True, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, ) assert ModelConfig( materialized=Materialization.INCREMENTAL, unique_key=["bar"], incremental_strategy="merge" ).model_kind(context) == IncrementalByUniqueKeyKind( - unique_key=["bar"], dialect="duckdb", forward_only=True, disable_restatement=False + unique_key=["bar"], + dialect="duckdb", + forward_only=True, + disable_restatement=False, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, + ) + + dbt_incremental_predicate = "DBT_INTERNAL_DEST.session_start > dateadd(day, -7, current_date)" + expected_sqlmesh_predicate = parse_one( + 
"__MERGE_TARGET__.session_start > DATEADD(day, -7, CURRENT_DATE)" + ) + assert ModelConfig( + materialized=Materialization.INCREMENTAL, + unique_key=["bar"], + incremental_strategy="merge", + dialect="postgres", + incremental_predicates=[dbt_incremental_predicate], + ).model_kind(context) == IncrementalByUniqueKeyKind( + unique_key=["bar"], + dialect="postgres", + forward_only=True, + disable_restatement=False, + merge_filter=expected_sqlmesh_predicate, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, ) assert ModelConfig(materialized=Materialization.INCREMENTAL, unique_key=["bar"]).model_kind( context ) == IncrementalByUniqueKeyKind( - unique_key=["bar"], dialect="duckdb", forward_only=True, disable_restatement=False + unique_key=["bar"], + dialect="duckdb", + forward_only=True, + disable_restatement=False, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, ) assert ModelConfig( materialized=Materialization.INCREMENTAL, unique_key=["bar"], full_refresh=False ).model_kind(context) == IncrementalByUniqueKeyKind( - unique_key=["bar"], dialect="duckdb", forward_only=True, disable_restatement=True + unique_key=["bar"], + dialect="duckdb", + forward_only=True, + disable_restatement=True, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, ) assert ModelConfig( materialized=Materialization.INCREMENTAL, unique_key=["bar"], full_refresh=True ).model_kind(context) == IncrementalByUniqueKeyKind( - unique_key=["bar"], dialect="duckdb", forward_only=True, disable_restatement=False + unique_key=["bar"], + dialect="duckdb", + forward_only=True, + disable_restatement=False, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, ) assert ModelConfig( materialized=Materialization.INCREMENTAL, unique_key=["bar"], disable_restatement=True ).model_kind(context) == IncrementalByUniqueKeyKind( - 
unique_key=["bar"], dialect="duckdb", forward_only=True, disable_restatement=True + unique_key=["bar"], + dialect="duckdb", + forward_only=True, + disable_restatement=True, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, ) assert ModelConfig( @@ -175,7 +443,12 @@ def test_model_kind(): disable_restatement=True, full_refresh=True, ).model_kind(context) == IncrementalByUniqueKeyKind( - unique_key=["bar"], dialect="duckdb", forward_only=True, disable_restatement=True + unique_key=["bar"], + dialect="duckdb", + forward_only=True, + disable_restatement=True, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, ) assert ModelConfig( @@ -183,14 +456,69 @@ def test_model_kind(): unique_key=["bar"], disable_restatement=True, full_refresh=False, + auto_restatement_cron="0 0 * * *", ).model_kind(context) == IncrementalByUniqueKeyKind( - unique_key=["bar"], dialect="duckdb", forward_only=True, disable_restatement=True + unique_key=["bar"], + dialect="duckdb", + forward_only=True, + disable_restatement=True, + auto_restatement_cron="0 0 * * *", + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, ) + # Test incompatibile incremental strategies + for incremental_strategy in ("delete+insert", "insert_overwrite", "append"): + assert ModelConfig( + materialized=Materialization.INCREMENTAL, + unique_key=["bar"], + incremental_strategy=incremental_strategy, + ).model_kind(context) == IncrementalByUniqueKeyKind( + unique_key=["bar"], + dialect="duckdb", + forward_only=True, + disable_restatement=False, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, + ) + assert ModelConfig( materialized=Materialization.INCREMENTAL, time_column="foo", incremental_strategy="merge" ).model_kind(context) == IncrementalByTimeRangeKind( - time_column="foo", dialect="duckdb", forward_only=True, disable_restatement=False 
+ time_column="foo", + dialect="duckdb", + forward_only=True, + disable_restatement=False, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, + ) + + assert ModelConfig( + materialized=Materialization.INCREMENTAL, + time_column="foo", + incremental_strategy="merge", + full_refresh=True, + ).model_kind(context) == IncrementalByTimeRangeKind( + time_column="foo", + dialect="duckdb", + forward_only=True, + disable_restatement=False, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, + ) + + assert ModelConfig( + materialized=Materialization.INCREMENTAL, + time_column="foo", + incremental_strategy="merge", + full_refresh=False, + ).model_kind(context) == IncrementalByTimeRangeKind( + time_column="foo", + dialect="duckdb", + forward_only=True, + disable_restatement=False, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, ) assert ModelConfig( @@ -199,7 +527,25 @@ def test_model_kind(): incremental_strategy="append", disable_restatement=True, ).model_kind(context) == IncrementalByTimeRangeKind( - time_column="foo", dialect="duckdb", forward_only=True, disable_restatement=True + time_column="foo", + dialect="duckdb", + forward_only=True, + disable_restatement=True, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, + ) + + assert ModelConfig( + materialized=Materialization.INCREMENTAL, + time_column="foo", + incremental_strategy="insert_overwrite", + partition_by={"field": "bar"}, + forward_only=False, + ).model_kind(context) == IncrementalByTimeRangeKind( + time_column="foo", + dialect="duckdb", + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, ) assert ModelConfig( @@ -208,30 +554,73 @@ def test_model_kind(): incremental_strategy="insert_overwrite", partition_by={"field": "bar"}, forward_only=False, - ).model_kind(context) == 
IncrementalByTimeRangeKind(time_column="foo", dialect="duckdb") + auto_restatement_cron="0 0 * * *", + auto_restatement_intervals=3, + ).model_kind(context) == IncrementalByTimeRangeKind( + time_column="foo", + dialect="duckdb", + forward_only=False, + auto_restatement_cron="0 0 * * *", + auto_restatement_intervals=3, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, + ) assert ModelConfig( materialized=Materialization.INCREMENTAL, incremental_strategy="insert_overwrite", partition_by={"field": "bar"}, ).model_kind(context) == IncrementalUnmanagedKind( - insert_overwrite=True, disable_restatement=False + insert_overwrite=True, + disable_restatement=False, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, ) assert ModelConfig(materialized=Materialization.INCREMENTAL).model_kind( context - ) == IncrementalUnmanagedKind(insert_overwrite=True, disable_restatement=False) + ) == IncrementalUnmanagedKind( + insert_overwrite=True, + disable_restatement=False, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, + ) + + assert ModelConfig(materialized=Materialization.INCREMENTAL, forward_only=False).model_kind( + context + ) == IncrementalUnmanagedKind( + insert_overwrite=True, + disable_restatement=False, + forward_only=False, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, + ) assert ModelConfig( materialized=Materialization.INCREMENTAL, incremental_strategy="append" - ).model_kind(context) == IncrementalUnmanagedKind(disable_restatement=False) + ).model_kind(context) == IncrementalUnmanagedKind( + disable_restatement=False, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, + ) + + assert ModelConfig( + materialized=Materialization.INCREMENTAL, incremental_strategy="append", full_refresh=None + ).model_kind(context) == 
IncrementalUnmanagedKind( + disable_restatement=False, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, + ) assert ModelConfig( materialized=Materialization.INCREMENTAL, incremental_strategy="insert_overwrite", partition_by={"field": "bar", "data_type": "int64"}, ).model_kind(context) == IncrementalUnmanagedKind( - insert_overwrite=True, disable_restatement=False + insert_overwrite=True, + disable_restatement=False, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, ) assert ModelConfig( @@ -240,7 +629,10 @@ def test_model_kind(): partition_by={"field": "bar", "data_type": "int64"}, full_refresh=False, ).model_kind(context) == IncrementalUnmanagedKind( - insert_overwrite=True, disable_restatement=True + insert_overwrite=True, + disable_restatement=True, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, ) assert ModelConfig( @@ -250,7 +642,10 @@ def test_model_kind(): disable_restatement=True, full_refresh=True, ).model_kind(context) == IncrementalUnmanagedKind( - insert_overwrite=True, disable_restatement=True + insert_overwrite=True, + disable_restatement=True, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, ) assert ModelConfig( @@ -259,27 +654,47 @@ def test_model_kind(): partition_by={"field": "bar", "data_type": "int64"}, disable_restatement=True, ).model_kind(context) == IncrementalUnmanagedKind( - insert_overwrite=True, disable_restatement=True + insert_overwrite=True, + disable_restatement=True, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, ) - with pytest.raises(ConfigError): - ModelConfig( - materialized=Materialization.INCREMENTAL, - unique_key=["bar"], - incremental_strategy="delete+insert", - ).model_kind(context) - with pytest.raises(ConfigError): - ModelConfig( - materialized=Materialization.INCREMENTAL, - 
unique_key=["bar"], - incremental_strategy="insert_overwrite", - ).model_kind(context) - with pytest.raises(ConfigError): - ModelConfig( - materialized=Materialization.INCREMENTAL, - unique_key=["bar"], - incremental_strategy="append", - ).model_kind(context) + assert ModelConfig( + materialized=Materialization.INCREMENTAL, + incremental_strategy="insert_overwrite", + auto_restatement_cron="0 0 * * *", + ).model_kind(context) == IncrementalUnmanagedKind( + insert_overwrite=True, + auto_restatement_cron="0 0 * * *", + disable_restatement=False, + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.IGNORE, + ) + + assert ( + ModelConfig(materialized=Materialization.DYNAMIC_TABLE, target_lag="1 hour").model_kind( + context + ) + == ManagedKind() + ) + + assert ModelConfig( + materialized=Materialization.SNAPSHOT, + unique_key=["id"], + updated_at="updated_at::timestamp", + strategy="timestamp", + dialect="redshift", + ).model_kind(context) == SCDType2ByTimeKind( + unique_key=["id"], + valid_from_name="dbt_valid_from", + valid_to_name="dbt_valid_to", + updated_at_as_valid_from=True, + updated_at_name="updated_at", + dialect="redshift", + on_destructive_change=OnDestructiveChange.IGNORE, + on_additive_change=OnAdditiveChange.ALLOW, + ) def test_model_kind_snapshot_bigquery(): @@ -300,6 +715,7 @@ def test_model_kind_snapshot_bigquery(): updated_at_name="updated_at", time_data_type=exp.DataType.build("TIMESTAMPTZ"), dialect="bigquery", + on_destructive_change=OnDestructiveChange.IGNORE, ) # time_data_type is bigquery version even though model dialect is DuckDB @@ -318,6 +734,7 @@ def test_model_kind_snapshot_bigquery(): updated_at_name="updated_at", time_data_type=exp.DataType.build("TIMESTAMPTZ"), # bigquery version dialect="duckdb", + on_destructive_change=OnDestructiveChange.IGNORE, ) @@ -341,9 +758,9 @@ def test_model_columns(): ) expected_column_types = { - "ADDRESS": parse_one("text", into=exp.DataType), - "ZIPCODE": 
parse_one("varchar(5)", into=exp.DataType), - "DATE": parse_one("timestamp_ntz", into=exp.DataType, dialect="snowflake"), + "ADDRESS": exp.DataType.build("text"), + "ZIPCODE": exp.DataType.build("varchar(5)"), + "DATE": exp.DataType.build("timestamp_ntz", dialect="snowflake"), } expected_column_descriptions = { "ADDRESS": "Business address", @@ -360,7 +777,10 @@ def test_model_columns(): name="target", schema="test", database="test", account="foo", user="bar", password="baz" ) sqlmesh_model = model.to_sqlmesh(context) - assert sqlmesh_model.columns_to_types == expected_column_types + + # Columns being present in a schema.yaml are not respected in DDLs, so SQLMesh doesn't + # set the corresponding columns_to_types_ attribute either to match dbt's behavior + assert sqlmesh_model.columns_to_types == None assert sqlmesh_model.column_descriptions == expected_column_descriptions @@ -370,31 +790,105 @@ def test_seed_columns(): package="package", path=Path("examples/sushi_dbt/seeds/waiter_names.csv"), columns={ - "address": ColumnConfig( - name="address", data_type="text", description="Business address" - ), - "zipcode": ColumnConfig( - name="zipcode", data_type="text", description="Business zipcode" - ), + "id": ColumnConfig(name="id", data_type="text", description="The ID"), + "name": ColumnConfig(name="name", data_type="text", description="The name"), + }, + ) + + # dbt doesn't respect the data_type field in the DDLs– instead, it optionally uses it to + # validate the actual data types at runtime through contracts or external plugins. Thus, + # the actual data type is int, because that is what is inferred from the seed file. 
+ expected_column_types = { + "id": exp.DataType.build("int"), + "name": exp.DataType.build("text"), + } + expected_column_descriptions = { + "id": "The ID", + "name": "The name", + } + + context = DbtContext() + context.project_name = "Foo" + context.target = DuckDbConfig(name="target", schema="test") + sqlmesh_seed = seed.to_sqlmesh(context) + assert sqlmesh_seed.columns_to_types == expected_column_types + assert sqlmesh_seed.column_descriptions == expected_column_descriptions + + +def test_seed_column_types(): + seed = SeedConfig( + name="foo", + package="package", + path=Path("examples/sushi_dbt/seeds/waiter_names.csv"), + column_types={ + "id": "text", + "name": "text", + }, + columns={ + "name": ColumnConfig(name="name", description="The name"), }, + quote_columns=True, ) expected_column_types = { - "address": parse_one("text", into=exp.DataType), - "zipcode": parse_one("text", into=exp.DataType), + "id": exp.DataType.build("text"), + "name": exp.DataType.build("text"), } expected_column_descriptions = { - "address": "Business address", - "zipcode": "Business zipcode", + "name": "The name", } context = DbtContext() context.project_name = "Foo" context.target = DuckDbConfig(name="target", schema="test") sqlmesh_seed = seed.to_sqlmesh(context) + assert sqlmesh_seed.columns_to_types == expected_column_types assert sqlmesh_seed.column_descriptions == expected_column_descriptions + seed = SeedConfig( + name="foo", + package="package", + path=Path("examples/sushi_dbt/seeds/waiter_names.csv"), + column_types={ + "name": "text", + }, + columns={ + # The `data_type` field does not affect the materialized seed's column type + "id": ColumnConfig(name="name", data_type="text"), + }, + quote_columns=True, + ) + + expected_column_types = { + "id": exp.DataType.build("int"), + "name": exp.DataType.build("text"), + } + sqlmesh_seed = seed.to_sqlmesh(context) + assert sqlmesh_seed.columns_to_types == expected_column_types + + seed = SeedConfig( + name="foo", + 
package="package", + path=Path("examples/sushi_dbt/seeds/waiter_names.csv"), + column_types={ + "id": "TEXT", + "name": "TEXT NOT NULL", + }, + quote_columns=True, + ) + + expected_column_types = { + "id": exp.DataType.build("text"), + "name": exp.DataType.build("text"), + } + + logger = logging.getLogger("sqlmesh.dbt.column") + with patch.object(logger, "warning") as mock_logger: + sqlmesh_seed = seed.to_sqlmesh(context) + assert "Ignoring unsupported constraints" in mock_logger.call_args[0][0] + assert sqlmesh_seed.columns_to_types == expected_column_types + def test_seed_column_inference(tmp_path): seed_csv = tmp_path / "seed.csv" @@ -402,6 +896,7 @@ def test_seed_column_inference(tmp_path): fd.write("int_col,double_col,datetime_col,date_col,boolean_col,text_col\n") fd.write("1,1.2,2021-01-01 00:00:00,2021-01-01,true,foo\n") fd.write("2,2.3,2021-01-02 00:00:00,2021-01-02,false,bar\n") + fd.write("null,,null,,,null\n") seed = SeedConfig( name="test_model", @@ -414,7 +909,9 @@ def test_seed_column_inference(tmp_path): context.target = DuckDbConfig(name="target", schema="test") sqlmesh_seed = seed.to_sqlmesh(context) assert sqlmesh_seed.columns_to_types == { - "int_col": exp.DataType.build("int"), + "int_col": exp.DataType.build("int") + if DBT_VERSION >= (1, 8, 0) + else exp.DataType.build("double"), "double_col": exp.DataType.build("double"), "datetime_col": exp.DataType.build("datetime"), "date_col": exp.DataType.build("date"), @@ -423,44 +920,222 @@ def test_seed_column_inference(tmp_path): } -@pytest.mark.xdist_group("dbt_manifest") -def test_model_dialect(sushi_test_project: Project, assert_exp_eq): - model_config = ModelConfig( - name="model", - package_name="package", - schema="sushi", - alias="table", - sql="SELECT 1 AS `one` FROM {{ schema }}", - ) - context = sushi_test_project.context - - # cannot parse model sql without specifying bigquery dialect - with pytest.raises(ConfigError): - model_config.to_sqlmesh(context).render_query_or_raise().sql() +def 
test_seed_single_whitespace_is_na(tmp_path): + seed_csv = tmp_path / "seed.csv" + with open(seed_csv, "w", encoding="utf-8") as fd: + fd.write("col_a, col_b\n") + fd.write(" ,1\n") + fd.write("2, \n") - model_config = ModelConfig( - name="model", - package_name="package", - schema="sushi", - alias="table", - sql="SELECT 1 AS `one` FROM {{ schema }}", - dialect="bigquery", - ) - assert_exp_eq( - model_config.to_sqlmesh(context).render_query_or_raise().sql(), - 'SELECT 1 AS "one" FROM "sushi" AS "sushi"', + seed = SeedConfig( + name="test_model", + package="foo", + path=Path(seed_csv), ) + context = DbtContext() + context.project_name = "foo" + context.target = DuckDbConfig(name="target", schema="test") + sqlmesh_seed = seed.to_sqlmesh(context) + assert sqlmesh_seed.columns_to_types == { + "col_a": exp.DataType.build("int"), + "col_b": exp.DataType.build("int"), + } -@pytest.mark.xdist_group("dbt_manifest") -@pytest.mark.parametrize( - "model_fqn", ['"memory"."sushi"."waiters"', '"memory"."sushi"."waiter_names"'] -) -def test_hooks(sushi_test_dbt_context: Context, model_fqn: str): - engine_adapter = sushi_test_dbt_context.engine_adapter - waiters = sushi_test_dbt_context.models[model_fqn] + df = next(sqlmesh_seed.render_seed()) + assert df["col_a"].to_list() == [None, 2] + assert df["col_b"].to_list() == [1, None] - logger = logging.getLogger("sqlmesh.dbt.builtin") + +def test_seed_partial_column_inference(tmp_path): + seed_csv = tmp_path / "seed.csv" + with open(seed_csv, "w", encoding="utf-8") as fd: + fd.write("int_col,double_col,datetime_col,boolean_col\n") + fd.write("1,1.2,2021-01-01 00:00:00,true\n") + fd.write("2,2.3,2021-01-02 00:00:00,false\n") + fd.write("null,,null,\n") + + seed = SeedConfig( + name="test_model", + package="package", + path=Path(seed_csv), + column_types={ + "double_col": "double", + }, + columns={ + "int_col": ColumnConfig( + name="int_col", data_type="int", description="Description with type." 
+ ), + "datetime_col": ColumnConfig( + name="datetime_col", description="Description without type." + ), + "boolean_col": ColumnConfig(name="boolean_col"), + }, + ) + + expected_column_types = { + "int_col": exp.DataType.build("int"), + "double_col": exp.DataType.build("double"), + "datetime_col": exp.DataType.build("datetime"), + "boolean_col": exp.DataType.build("boolean"), + } + + expected_column_descriptions = { + "int_col": "Description with type.", + "datetime_col": "Description without type.", + } + + context = DbtContext() + context.project_name = "Foo" + context.target = DuckDbConfig(name="target", schema="test") + sqlmesh_seed = seed.to_sqlmesh(context) + assert sqlmesh_seed.columns_to_types == expected_column_types + assert sqlmesh_seed.column_descriptions == expected_column_descriptions + + # Check that everything still lines up + seed_df = next(sqlmesh_seed.render_seed()) + assert list(seed_df.columns) == list(sqlmesh_seed.columns_to_types.keys()) + + +def test_seed_delimiter(tmp_path): + seed_csv = tmp_path / "seed_with_delimiter.csv" + + with open(seed_csv, "w", encoding="utf-8") as fd: + fd.writelines("\n".join(["id|name|city", "0|Ayrton|SP", "1|Max|MC", "2|Niki|VIE"])) + + seed = SeedConfig( + name="test_model_pipe", + package="package", + path=Path(seed_csv), + delimiter="|", + ) + + context = DbtContext() + context.project_name = "TestProject" + context.target = DuckDbConfig(name="target", schema="test") + sqlmesh_seed = seed.to_sqlmesh(context) + + # Verify columns are correct with the custom pipe (|) delimiter + expected_columns = {"id", "name", "city"} + assert set(sqlmesh_seed.columns_to_types.keys()) == expected_columns + + seed_df = next(sqlmesh_seed.render_seed()) + assert list(seed_df.columns) == list(sqlmesh_seed.columns_to_types.keys()) + assert len(seed_df) == 3 + + assert seed_df.iloc[0]["name"] == "Ayrton" + assert seed_df.iloc[0]["city"] == "SP" + assert seed_df.iloc[1]["name"] == "Max" + assert seed_df.iloc[1]["city"] == "MC" + + # 
test with semicolon delimiter + seed_csv_semicolon = tmp_path / "seed_with_semicolon.csv" + with open(seed_csv_semicolon, "w", encoding="utf-8") as fd: + fd.writelines("\n".join(["id;value;status", "1;100;active", "2;200;inactive"])) + + seed_semicolon = SeedConfig( + name="test_model_semicolon", + package="package", + path=Path(seed_csv_semicolon), + delimiter=";", + ) + + sqlmesh_seed_semicolon = seed_semicolon.to_sqlmesh(context) + expected_columns_semicolon = {"id", "value", "status"} + assert set(sqlmesh_seed_semicolon.columns_to_types.keys()) == expected_columns_semicolon + + seed_df_semicolon = next(sqlmesh_seed_semicolon.render_seed()) + assert seed_df_semicolon.iloc[0]["value"] == 100 + assert seed_df_semicolon.iloc[0]["status"] == "active" + + +def test_seed_column_order(tmp_path): + seed_csv = tmp_path / "seed.csv" + + with open(seed_csv, "w", encoding="utf-8") as fd: + fd.writelines("\n".join(["id,name", "0,Toby", "1,Tyson", "2,Ryan"])) + + seed = SeedConfig( + name="test_model", + package="package", + path=Path(seed_csv), + columns={ + "id": ColumnConfig(name="id"), + "name": ColumnConfig(name="name", data_type="varchar"), + }, + ) + + context = DbtContext() + context.project_name = "Foo" + context.target = DuckDbConfig(name="target", schema="test") + sqlmesh_seed = seed.to_sqlmesh(context) + + # Check that everything still lines up + seed_df = next(sqlmesh_seed.render_seed()) + assert list(seed_df.columns) == list(sqlmesh_seed.columns_to_types.keys()) + + +def test_agate_integer_cast(): + # Not all dbt versions have agate.Integer + if DBT_VERSION < (1, 7, 0): + pytest.skip("agate.Integer not available") + + from sqlmesh.dbt.seed import Integer + + agate_integer = Integer(null_values=("null", "")) + assert agate_integer.cast("1") == 1 + assert agate_integer.cast(1) == 1 + assert agate_integer.cast("null") is None + assert agate_integer.cast("") is None + + with pytest.raises(agate.exceptions.CastError): + agate_integer.cast("1.2") + + with 
pytest.raises(agate.exceptions.CastError): + agate_integer.cast(1.2) + + with pytest.raises(agate.exceptions.CastError): + agate_integer.cast(datetime.now()) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_model_dialect(sushi_test_project: Project, assert_exp_eq): + model_config = ModelConfig( + name="model", + package_name="package", + schema="sushi", + alias="table", + sql="SELECT 1 AS `one` FROM {{ schema }}", + ) + context = sushi_test_project.context + + # cannot parse model sql without specifying bigquery dialect + with pytest.raises(ConfigError): + model_config.to_sqlmesh(context).render_query_or_raise().sql() + + model_config = ModelConfig( + name="model", + package_name="package", + schema="sushi", + alias="table", + sql="SELECT 1 AS `one` FROM {{ schema }}", + dialect="bigquery", + ) + assert_exp_eq( + model_config.to_sqlmesh(context).render_query_or_raise().sql(), + 'SELECT 1 AS "one" FROM "sushi" AS "sushi"', + ) + + +@pytest.mark.xdist_group("dbt_manifest") +@pytest.mark.parametrize( + "model_fqn", ['"memory"."sushi"."waiters"', '"memory"."sushi"."waiter_names"'] +) +def test_hooks(sushi_test_dbt_context: Context, model_fqn: str): + engine_adapter = sushi_test_dbt_context.engine_adapter + waiters = sushi_test_dbt_context.models[model_fqn] + + logger = logging.getLogger("sqlmesh.dbt.builtin") with patch.object(logger, "debug") as mock_logger: engine_adapter.execute( waiters.render_pre_statements( @@ -478,6 +1153,45 @@ def test_hooks(sushi_test_dbt_context: Context, model_fqn: str): assert "post-hook" in mock_logger.call_args[0][0] +@pytest.mark.xdist_group("dbt_manifest") +def test_seed_delimiter_integration(sushi_test_dbt_context: Context): + seed_fqn = '"memory"."sushi"."waiter_revenue_semicolon"' + assert seed_fqn in sushi_test_dbt_context.models + + seed_model = sushi_test_dbt_context.models[seed_fqn] + assert seed_model.columns_to_types is not None + + # this should be loaded with semicolon delimiter otherwise it'd resylt in an one column 
table + assert set(seed_model.columns_to_types.keys()) == {"waiter_id", "revenue", "quarter"} + + # columns_to_types values are correct types as well + assert seed_model.columns_to_types == { + "waiter_id": exp.DataType.build("int"), + "revenue": exp.DataType.build("double"), + "quarter": exp.DataType.build("text"), + } + + df = sushi_test_dbt_context.fetchdf(f"SELECT * FROM {seed_fqn}") + + assert len(df) == 6 + waiter_ids = set(df["waiter_id"].tolist()) + quarters = set(df["quarter"].tolist()) + assert waiter_ids == {1, 2, 3} + assert quarters == {"Q1", "Q2"} + + q1_w1_rows = df[(df["waiter_id"] == 1) & (df["quarter"] == "Q1")] + assert len(q1_w1_rows) == 1 + assert float(q1_w1_rows.iloc[0]["revenue"]) == 100.50 + + q2_w2_rows = df[(df["waiter_id"] == 2) & (df["quarter"] == "Q2")] + assert len(q2_w2_rows) == 1 + assert float(q2_w2_rows.iloc[0]["revenue"]) == 225.50 + + q2_w3_rows = df[(df["waiter_id"] == 3) & (df["quarter"] == "Q2")] + assert len(q2_w3_rows) == 1 + assert float(q2_w3_rows.iloc[0]["revenue"]) == 175.75 + + @pytest.mark.xdist_group("dbt_manifest") def test_target_jinja(sushi_test_project: Project): context = sushi_test_project.context @@ -489,6 +1203,51 @@ def test_target_jinja(sushi_test_project: Project): assert context.render("{{ target.path }}") == "None" assert context.render("{{ target.profile_name }}") == "None" + context = DbtContext() + context._target = SnowflakeConfig( + name="target", + schema="test", + database="test", + account="account", + user="user", + password="password", + warehouse="warehouse", + role="role", + threads=1, + ) + assert context.render("{{ target.threads }}") == "1" + assert context.render("{{ target.database }}") == "test" + assert context.render("{{ target.warehouse }}") == "warehouse" + assert context.render("{{ target.user }}") == "user" + assert context.render("{{ target.role }}") == "role" + assert context.render("{{ target.account }}") == "account" + + context = DbtContext() + context._target = 
PostgresConfig( + name="target", + schema="test", + database="test", + dbname="test", + host="host", + port=5432, + user="user", + password="password", + ) + assert context.render("{{ target.dbname }}") == "test" + assert context.render("{{ target.host }}") == "host" + assert context.render("{{ target.port }}") == "5432" + + context = DbtContext() + context._target = BigQueryConfig( + name="target", + schema="test_value", + database="test_project", + ) + assert context.render("{{ target.project }}") == "test_project" + assert context.render("{{ target.database }}") == "test_project" + assert context.render("{{ target.schema }}") == "test_value" + assert context.render("{{ target.dataset }}") == "test_value" + @pytest.mark.xdist_group("dbt_manifest") def test_project_name_jinja(sushi_test_project: Project): @@ -514,7 +1273,7 @@ def test_schema_jinja(sushi_test_project: Project, assert_exp_eq): @pytest.mark.xdist_group("dbt_manifest") def test_config_jinja(sushi_test_project: Project): - hook = "{{ config(alias='bar') }} {{ config.alias }}" + hook = "{{ config(alias='bar') }} {{ config.get('alias') }}" model_config = ModelConfig( name="model", package_name="package", @@ -529,6 +1288,211 @@ def test_config_jinja(sushi_test_project: Project): assert model.render_pre_statements()[0].sql() == '"bar"' +@pytest.mark.xdist_group("dbt_manifest") +def test_config_dict_syntax(): + # Test dictionary syntax + config = Config({}) + result = config({"materialized": "table", "alias": "dict_table"}) + assert result == "" + assert config._config["materialized"] == "table" + assert config._config["alias"] == "dict_table" + + # Test kwargs syntax still works + config2 = Config({}) + result = config2(materialized="view", alias="kwargs_table") + assert result == "" + assert config2._config["materialized"] == "view" + assert config2._config["alias"] == "kwargs_table" + + # Test that mixing args and kwargs is rejected + config3 = Config({}) + try: + config3({"materialized": "table"}, 
alias="mixed") + assert False, "Should have raised ConfigError" + except Exception as e: + assert "cannot mix positional and keyword arguments" in str(e) + + # Test nested dicts + config4 = Config({}) + config4({"meta": {"owner": "data_team", "priority": 1}, "tags": ["daily", "critical"]}) + assert config4._config["meta"]["owner"] == "data_team" + assert config4._config["tags"] == ["daily", "critical"] + + # Test multiple positional arguments are rejected + config4 = Config({}) + try: + config4({"materialized": "table"}, {"alias": "test"}) + assert False + except Exception as e: + assert "expected a single dictionary, got 2 arguments" in str(e) + + +def test_config_dict_in_jinja(): + # Test dict syntax directly with Config class + config = Config({}) + template = Template("{{ config({'materialized': 'table', 'unique_key': 'id'}) }}done") + result = template.render(config=config) + assert result == "done" + assert config._config["materialized"] == "table" + assert config._config["unique_key"] == "id" + + # Test with nested dict and list values + config2 = Config({}) + complex_template = Template("""{{ config({ + 'tags': ['test', 'dict'], + 'meta': {'owner': 'data_team'} + }) }}result""") + result = complex_template.render(config=config2) + assert result == "result" + assert config2._config["tags"] == ["test", "dict"] + assert config2._config["meta"]["owner"] == "data_team" + + # Test that kwargs still work + config3 = Config({}) + kwargs_template = Template("{{ config(materialized='view', alias='my_view') }}done") + result = kwargs_template.render(config=config3) + assert result == "done" + assert config3._config["materialized"] == "view" + assert config3._config["alias"] == "my_view" + + +@pytest.mark.xdist_group("dbt_manifest") +def test_config_dict_syntax_in_sushi_project(sushi_test_project: Project): + assert sushi_test_project is not None + assert sushi_test_project.context is not None + + sushi_package = sushi_test_project.packages.get("sushi") + assert 
sushi_package is not None + + top_waiters_found = False + for model_config in sushi_package.models.values(): + if model_config.name == "top_waiters": + # top_waiters model now uses dict config syntax with: + # config({'materialized': 'view', 'limit_value': var('top_waiters:limit'), 'meta': {...}}) + top_waiters_found = True + assert model_config.materialized == "view" + assert model_config.meta is not None + assert model_config.meta.get("owner") == "analytics_team" + assert model_config.meta.get("priority") == "high" + break + + assert top_waiters_found + + +@pytest.mark.xdist_group("dbt_manifest") +def test_config_jinja_get_methods(sushi_test_project: Project): + model_config = ModelConfig( + name="model_conf", + package_name="package", + schema="sushi", + sql="""SELECT 1 AS one FROM foo""", + alias="model_alias", + **{ + "pre-hook": [ + "{{ config(materialized='incremental', unique_key='id') }}" + "{{ config.get('missed', 'a') + config.get('missed', default='b')}}", + "{{ config.set('alias', 'new_alias')}}", + "{{ config.get('package_name') + '_' + config.require('unique_key')}}", + "{{ config.get('alias') or 'default'}}", + ] + }, + **{"post-hook": "{{config.require('missing_key')}}"}, + ) + context = sushi_test_project.context + model = t.cast(SqlModel, model_config.to_sqlmesh(context)) + + assert model.render_pre_statements()[0].sql() == '"ab"' + assert model.render_pre_statements()[1].sql() == '"package_id"' + assert model.render_pre_statements()[2].sql() == '"new_alias"' + + with pytest.raises(ConfigError, match="Missing required config: missing_key"): + model.render_post_statements() + + # test get methods with operations + model_2_config = ModelConfig( + name="model_2", + package_name="package", + schema="sushi", + sql="""SELECT 1 AS one FROM foo""", + alias="mod", + materialized="table", + threads=8, + partition_by="date", + cluster_by=["user_id", "product_id"], + **{ + "pre-hook": [ + "{{ config.get('partition_by', default='none') }}", + "{{ 
config.get('cluster_by', default=[]) | length }}", + "{% if config.get('threads') > 4 %}high_threads{% else %}low_threads{% endif %}", + ] + }, + ) + model2 = t.cast(SqlModel, model_2_config.to_sqlmesh(context)) + + pre_statements2 = model2.render_pre_statements() + assert pre_statements2[0].sql() == "ARRAY('date')" + assert pre_statements2[1].sql() == "2" + assert pre_statements2[2].sql() == '"high_threads"' + + # test seting variable and conditional + model_invalid_timeout = ModelConfig( + name="invalid_timeout_test", + package_name="package", + schema="sushi", + sql="""SELECT 1 AS one FROM foo""", + alias="invalid_timeout_alias", + connection_timeout=44, + **{ + "pre-hook": [ + """ + {%- set value = config.require('connection_timeout') -%} + {%- set is_valid = value >= 10 and value <= 30 -%} + {%- if not is_valid -%} + {{ exceptions.raise_compiler_error("Validation failed for 'connection_timeout': Value must be between 10 and 30, got: " ~ value) }} + {%- endif -%} + {{ value }} + """, + ] + }, + ) + + model_invalid = t.cast(SqlModel, model_invalid_timeout.to_sqlmesh(context)) + with pytest.raises( + ConfigError, + match="Validation failed for 'connection_timeout': Value must be between 10 and 30, got: 44", + ): + model_invalid.render_pre_statements() + + # test persist_docs methods + model_config_persist = ModelConfig( + name="persist_docs_model", + package_name="package", + schema="sushi", + sql="""SELECT 1 AS one FROM foo""", + alias="persist_alias", + **{ + "pre-hook": [ + "{{ config(persist_docs={'relation': true, 'columns': true}) }}", + "{{ config.persist_relation_docs() }}", + "{{ config.persist_column_docs() }}", + "{{ config(persist_docs={'relation': false, 'columns': true}) }}", + "{{ config.persist_relation_docs() }}", + "{{ config.persist_column_docs() }}", + ] + }, + ) + model3 = t.cast(SqlModel, model_config_persist.to_sqlmesh(context)) + + pre_statements3 = model3.render_pre_statements() + + # it should filter out empty returns, so we get 4 
statements + assert len(pre_statements3) == 4 + assert pre_statements3[0].sql() == "TRUE" + assert pre_statements3[1].sql() == "TRUE" + assert pre_statements3[2].sql() == "FALSE" + assert pre_statements3[3].sql() == "TRUE" + + @pytest.mark.xdist_group("dbt_manifest") def test_model_this(assert_exp_eq, sushi_test_project: Project): model_config = ModelConfig( @@ -558,7 +1522,7 @@ def test_test_this(assert_exp_eq, sushi_test_project: Project): context = sushi_test_project.context audit = t.cast(StandaloneAudit, test_config.to_sqlmesh(context)) assert_exp_eq( - audit.render_query(audit).sql(), + audit.render_audit_query().sql(), 'SELECT 1 AS "one" FROM "test" AS "test"', ) @@ -578,12 +1542,12 @@ def test_test_dialect(assert_exp_eq, sushi_test_project: Project): # can't parse test sql without specifying bigquery as default dialect with pytest.raises(ConfigError): audit = t.cast(StandaloneAudit, test_config.to_sqlmesh(context)) - audit.render_query(audit).sql() + audit.render_audit_query().sql() test_config.dialect_ = "bigquery" audit = t.cast(StandaloneAudit, test_config.to_sqlmesh(context)) assert_exp_eq( - audit.render_query(audit).sql(), + audit.render_audit_query().sql(), 'SELECT 1 AS "one" FROM "test" AS "test"', ) @@ -627,13 +1591,40 @@ def test_logging(sushi_test_project: Project, runtime_renderer: t.Callable): renderer = runtime_renderer(context, engine_adapter=engine_adapter) logger = logging.getLogger("sqlmesh.dbt.builtin") - with patch.object(logger, "debug") as mock_logger: - assert renderer('{{ log("foo") }}') == "" - assert "foo" in mock_logger.call_args[0][0] - with patch.object(logger, "debug") as mock_logger: + # Test log with info=False (default), should only log to file with debug and not to console + with ( + patch.object(logger, "debug") as mock_debug, + patch.object(logger, "info") as mock_info, + patch.object(get_console(), "log_status_update") as mock_console, + ): + assert renderer('{{ log("foo") }}') == "" + mock_debug.assert_called_once() + 
assert "foo" in mock_debug.call_args[0][0] + mock_info.assert_not_called() + mock_console.assert_not_called() + + # Test log with info=True, should log to info and also call log_status_update + with ( + patch.object(logger, "debug") as mock_debug, + patch.object(logger, "info") as mock_info, + patch.object(get_console(), "log_status_update") as mock_console, + ): + assert renderer('{{ log("output to be logged with info", info=true) }}') == "" + mock_info.assert_called_once() + assert "output to be logged with info" in mock_info.call_args[0][0] + mock_debug.assert_not_called() + mock_console.assert_called_once() + assert "output to be logged with info" in mock_console.call_args[0][0] + + # Test print function as well, should use debug + with ( + patch.object(logger, "debug") as mock_logger, + patch.object(get_console(), "log_status_update") as mock_console, + ): assert renderer('{{ print("bar") }}') == "" - assert "bar" in mock_logger.call_args[0][0] + assert "bar" in mock_logger.call_args[0][0] + mock_console.assert_not_called() @pytest.mark.xdist_group("dbt_manifest") @@ -649,6 +1640,29 @@ def test_exceptions(sushi_test_project: Project): context.render('{{ exceptions.raise_compiler_error("Error") }}') +@pytest.mark.xdist_group("dbt_manifest") +def test_try_or_compiler_error(sushi_test_project: Project): + context = sushi_test_project.context + + result = context.render( + '{{ try_or_compiler_error("Error message", modules.datetime.datetime.strptime, "2023-01-15", "%Y-%m-%d") }}' + ) + assert "2023-01-15" in result + + with pytest.raises(CompilationError, match="Invalid date format"): + context.render( + '{{ try_or_compiler_error("Invalid date format", modules.datetime.datetime.strptime, "invalid", "%Y-%m-%d") }}' + ) + + # built-in macro calling try_or_compiler_error works + result = context.render( + '{{ dbt.dates_in_range("2023-01-01", "2023-01-03", "%Y-%m-%d", "%Y-%m-%d") }}' + ) + assert "2023-01-01" in result + assert "2023-01-02" in result + assert 
"2023-01-03" in result + + @pytest.mark.xdist_group("dbt_manifest") def test_modules(sushi_test_project: Project): context = sushi_test_project.context @@ -666,9 +1680,7 @@ def test_modules(sushi_test_project: Project): assert context.render("{{ modules.re.search('(?<=abc)def', 'abcdef').group(0) }}") == "def" # itertools - itertools_jinja = ( - "{% for num in modules.itertools.accumulate([5]) %}" "{{ num }}" "{% endfor %}" - ) + itertools_jinja = "{% for num in modules.itertools.accumulate([5]) %}{{ num }}{% endfor %}" assert context.render(itertools_jinja) == "5" @@ -681,6 +1693,21 @@ def test_flags(sushi_test_project: Project): assert context.render("{{ flags.WHICH }}") == "parse" +def test_invocation_args_dict(sushi_test_project: Project): + context = sushi_test_project.context + + assert context.render("{{ invocation_args_dict['full_refresh'] }}") == "None" + assert context.render("{{ invocation_args_dict['store_failures'] }}") == "None" + assert context.render("{{ invocation_args_dict['which'] }}") == "parse" + + +@pytest.mark.xdist_group("dbt_manifest") +def test_context_namespace(sushi_test_project: Project): + context = sushi_test_project.context + + assert context.render("{{ context.flags.FULL_REFRESH }}") == "None" + + @pytest.mark.xdist_group("dbt_manifest") def test_relation(sushi_test_project: Project): context = sushi_test_project.context @@ -705,7 +1732,7 @@ def test_column(sushi_test_project: Project): assert context.render("{{ api.Column }}") == "" jinja = ( - "{% set col = api.Column('foo', 'integer') %}" "{{ col.is_integer() }} {{ col.is_string()}}" + "{% set col = api.Column('foo', 'integer') %}{{ col.is_integer() }} {{ col.is_string()}}" ) assert context.render(jinja) == "True False" @@ -724,12 +1751,10 @@ def test_as_filters(sushi_test_project: Project): context = sushi_test_project.context assert context.render("{{ True | as_bool }}") == "True" - with pytest.raises(MacroEvalError, match="Failed to convert 'invalid' into boolean."): - 
context.render("{{ 'invalid' | as_bool }}") + assert context.render("{{ 'valid' | as_bool }}") == "valid" assert context.render("{{ 123 | as_number }}") == "123" - with pytest.raises(MacroEvalError, match="Failed to convert 'invalid' into number."): - context.render("{{ 'invalid' | as_number }}") + assert context.render("{{ 'valid' | as_number }}") == "valid" assert context.render("{{ None | as_text }}") == "" @@ -788,6 +1813,45 @@ def test_dbt_version(sushi_test_project: Project): assert context.render("{{ dbt_version }}").startswith("1.") +@pytest.mark.xdist_group("dbt_manifest") +def test_dbt_on_run_start_end(sushi_test_project: Project): + # Validate perservation of dbt's order of execution + assert sushi_test_project.packages["sushi"].on_run_start["sushi-on-run-start-0"].index == 0 + assert sushi_test_project.packages["sushi"].on_run_start["sushi-on-run-start-1"].index == 1 + assert sushi_test_project.packages["sushi"].on_run_end["sushi-on-run-end-0"].index == 0 + assert sushi_test_project.packages["sushi"].on_run_end["sushi-on-run-end-1"].index == 1 + assert ( + sushi_test_project.packages["customers"].on_run_start["customers-on-run-start-0"].index == 0 + ) + assert ( + sushi_test_project.packages["customers"].on_run_start["customers-on-run-start-1"].index == 1 + ) + assert sushi_test_project.packages["customers"].on_run_end["customers-on-run-end-0"].index == 0 + assert sushi_test_project.packages["customers"].on_run_end["customers-on-run-end-1"].index == 1 + + assert ( + sushi_test_project.packages["customers"].on_run_start["customers-on-run-start-0"].sql + == "CREATE TABLE IF NOT EXISTS to_be_executed_first (col VARCHAR);" + ) + assert ( + sushi_test_project.packages["customers"].on_run_start["customers-on-run-start-1"].sql + == "CREATE TABLE IF NOT EXISTS analytic_stats_packaged_project (physical_table VARCHAR, evaluation_time VARCHAR);" + ) + assert ( + sushi_test_project.packages["customers"].on_run_end["customers-on-run-end-1"].sql + == "{{ 
packaged_tables(schemas) }}" + ) + + assert ( + sushi_test_project.packages["sushi"].on_run_start["sushi-on-run-start-0"].sql + == "CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);" + ) + assert ( + sushi_test_project.packages["sushi"].on_run_end["sushi-on-run-end-0"].sql + == "{{ create_tables(schemas) }}" + ) + + @pytest.mark.xdist_group("dbt_manifest") def test_parsetime_adapter_call( assert_exp_eq, sushi_test_project: Project, sushi_test_dbt_context: Context @@ -819,7 +1883,7 @@ def test_parsetime_adapter_call( @pytest.mark.xdist_group("dbt_manifest") -def test_partition_by(sushi_test_project: Project): +def test_partition_by(sushi_test_project: Project, caplog): context = sushi_test_project.context context.target = BigQueryConfig(name="production", database="main", schema="sushi") model_config = ModelConfig( @@ -859,18 +1923,82 @@ def test_partition_by(sushi_test_project: Project): model_config.partition_by = {"field": "ds", "data_type": "date", "granularity": "day"} assert model_config.to_sqlmesh(context).partitioned_by == [exp.to_column("ds", quoted=True)] + context.target = DuckDbConfig(name="target", schema="foo") + assert model_config.to_sqlmesh(context).partitioned_by == [] -@pytest.mark.xdist_group("dbt_manifest") -def test_relation_info_to_relation(): - assert _relation_info_to_relation( - {"quote_policy": {}}, - BaseRelation, - Policy(database=True, schema=True, identifier=True), - ).quote_policy == Policy(database=True, schema=True, identifier=True) - - assert _relation_info_to_relation( - {"quote_policy": {"database": None, "schema": None, "identifier": None}}, - BaseRelation, + context.target = SnowflakeConfig( + name="target", schema="test", database="test", account="foo", user="bar", password="baz" + ) + assert model_config.to_sqlmesh(context).partitioned_by == [] + assert ( + "Ignoring partition_by config for model 'model' targeting snowflake. The partition_by config is not supported for Snowflake." 
+ in caplog.text + ) + + model_config = ModelConfig( + name="model", + alias="model", + schema="test", + package_name="package", + materialized=Materialization.VIEW.value, + unique_key="ds", + partition_by={"field": "ds", "granularity": "month"}, + sql="""SELECT 1 AS one, ds FROM foo""", + ) + assert model_config.to_sqlmesh(context).partitioned_by == [] + + model_config = ModelConfig( + name="model", + alias="model", + schema="test", + package_name="package", + materialized=Materialization.EPHEMERAL.value, + unique_key="ds", + partition_by={"field": "ds", "granularity": "month"}, + sql="""SELECT 1 AS one, ds FROM foo""", + ) + assert model_config.to_sqlmesh(context).partitioned_by == [] + + with pytest.raises(ConfigError, match="Unexpected data_type 'string' in partition_by"): + ModelConfig( + name="model", + alias="model", + schema="test", + package_name="package", + materialized="table", + partition_by={"field": "ds", "data_type": "string"}, + sql="""SELECT 1 AS one, ds FROM foo""", + ) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_partition_by_none(sushi_test_project: Project): + context = sushi_test_project.context + context.target = BigQueryConfig(name="production", database="main", schema="sushi") + model_config = ModelConfig( + name="model", + alias="model", + schema="test", + package_name="package", + materialized="table", + unique_key="ds", + partition_by=None, + sql="""SELECT 1 AS one, ds FROM foo""", + ) + assert model_config.partition_by is None + + +@pytest.mark.xdist_group("dbt_manifest") +def test_relation_info_to_relation(): + assert _relation_info_to_relation( + {"quote_policy": {}}, + BaseRelation, + Policy(database=True, schema=True, identifier=True), + ).quote_policy == Policy(database=True, schema=True, identifier=True) + + assert _relation_info_to_relation( + {"quote_policy": {"database": None, "schema": None, "identifier": None}}, + BaseRelation, Policy(database=True, schema=True, identifier=True), ).quote_policy == 
Policy(database=True, schema=True, identifier=True) @@ -916,12 +2044,62 @@ def test_is_incremental(sushi_test_project: Project, assert_exp_eq, mocker): snapshot = mocker.Mock() snapshot.intervals = [1] + snapshot.is_incremental = True assert_exp_eq( model_config.to_sqlmesh(context).render_query_or_raise(snapshot=snapshot).sql(), 'SELECT 1 AS "one" FROM "tbl_a" AS "tbl_a" WHERE "ds" > (SELECT MAX("ds") FROM "model" AS "model")', ) + # If the snapshot_table_exists flag was set to False, intervals should be ignored + assert_exp_eq( + model_config.to_sqlmesh(context) + .render_query_or_raise(snapshot=snapshot, snapshot_table_exists=False) + .sql(), + 'SELECT 1 AS "one" FROM "tbl_a" AS "tbl_a"', + ) + + # If the snapshot_table_exists flag was set to True, intervals should be taken into account + assert_exp_eq( + model_config.to_sqlmesh(context) + .render_query_or_raise(snapshot=snapshot, snapshot_table_exists=True) + .sql(), + 'SELECT 1 AS "one" FROM "tbl_a" AS "tbl_a" WHERE "ds" > (SELECT MAX("ds") FROM "model" AS "model")', + ) + snapshot.intervals = [] + assert_exp_eq( + model_config.to_sqlmesh(context) + .render_query_or_raise(snaspshot=snapshot, snapshot_table_exists=True) + .sql(), + 'SELECT 1 AS "one" FROM "tbl_a" AS "tbl_a"', + ) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_is_incremental_non_incremental_model(sushi_test_project: Project, assert_exp_eq, mocker): + model_config = ModelConfig( + name="model", + package_name="package", + schema="sushi", + alias="some_table", + sql=""" + SELECT 1 AS one FROM tbl_a + {% if is_incremental() %} + WHERE ds > (SELECT MAX(ds) FROM model) + {% endif %} + """, + ) + context = sushi_test_project.context + + snapshot = mocker.Mock() + snapshot.intervals = [1] + snapshot.is_incremental = False + + assert_exp_eq( + model_config.to_sqlmesh(context).render_query_or_raise(snapshot=snapshot).sql(), + 'SELECT 1 AS "one" FROM "tbl_a" AS "tbl_a"', + ) + @pytest.mark.xdist_group("dbt_manifest") def 
test_dbt_max_partition(sushi_test_project: Project, assert_exp_eq, mocker: MockerFixture): @@ -952,7 +2130,7 @@ def test_dbt_max_partition(sushi_test_project: Project, assert_exp_eq, mocker: M JINJA_STATEMENT_BEGIN; {% if is_incremental() %} DECLARE _dbt_max_partition DATETIME DEFAULT ( - COALESCE((SELECT MAX(PARSE_DATETIME('%Y%m', partition_id)) FROM `{{ target.database }}.{{ adapter.resolve_schema(this) }}.INFORMATION_SCHEMA.PARTITIONS` WHERE table_name = '{{ adapter.resolve_identifier(this) }}' AND NOT partition_id IS NULL AND partition_id <> '__NULL__'), CAST('1970-01-01' AS DATETIME)) + COALESCE((SELECT MAX(PARSE_DATETIME('%Y%m', partition_id)) FROM `{{ target.database }}`.`{{ adapter.resolve_schema(this) }}`.`INFORMATION_SCHEMA.PARTITIONS` AS PARTITIONS WHERE table_name = '{{ adapter.resolve_identifier(this) }}' AND NOT partition_id IS NULL AND partition_id <> '__NULL__'), CAST('1970-01-01' AS DATETIME)) ); {% endif %} JINJA_END;""".strip() @@ -1000,6 +2178,59 @@ def test_bigquery_physical_properties(sushi_test_project: Project, mocker: Mocke } +@pytest.mark.xdist_group("dbt_manifest") +def test_clickhouse_properties(mocker: MockerFixture): + context = DbtContext(target_name="production") + context._project_name = "Foo" + context._target = ClickhouseConfig(name="production") + model_config = ModelConfig( + name="model", + alias="model", + schema="test", + package_name="package", + materialized="incremental", + incremental_strategy="delete+insert", + incremental_predicates=["ds > (SELECT MAX(ds) FROM model)"], + query_settings={"QUERY_SETTING": "value"}, + sharding_key="rand()", + engine="MergeTree()", + partition_by=["toMonday(ds)", "partition_col"], + order_by=["toStartOfWeek(ds)", "order_col"], + primary_key=["ds", "primary_key_col"], + ttl="time + INTERVAL 1 WEEK", + settings={"SETTING": "value"}, + sql="""SELECT 1 AS one, ds FROM foo""", + ) + + with patch.object(get_console(), "log_warning") as mock_logger: + model_to_sqlmesh = 
model_config.to_sqlmesh(context) + + assert [call[0][0] for call in mock_logger.call_args_list] == [ + "The 'delete+insert' incremental strategy is not supported - SQLMesh will use the temp table/partition swap strategy.", + "SQLMesh does not support 'incremental_predicates' - they will not be applied.", + "SQLMesh does not support the 'query_settings' model configuration parameter. Specify the query settings directly in the model query.", + "SQLMesh does not support the 'sharding_key' model configuration parameter or distributed materializations.", + ] + + assert [e.sql("clickhouse") for e in model_to_sqlmesh.partitioned_by] == [ + "dateTrunc('WEEK', \"ds\")", + '"partition_col"', + ] + assert model_to_sqlmesh.storage_format == "MergeTree()" + + physical_properties = model_to_sqlmesh.physical_properties + assert [e.sql("clickhouse", identify=True) for e in physical_properties["order_by"]] == [ + 'toStartOfWeek("ds")', + '"order_col"', + ] + assert [e.sql("clickhouse", identify=True) for e in physical_properties["primary_key"]] == [ + '"ds"', + '"primary_key_col"', + ] + assert physical_properties["ttl"].sql("clickhouse") == "time + INTERVAL 1 WEEK" + assert physical_properties["SETTING"].sql("clickhouse") == "value" + + @pytest.mark.xdist_group("dbt_manifest") def test_snapshot_json_payload(): sushi_context = Context(paths=["tests/fixtures/dbt/sushi_test"]) @@ -1009,14 +2240,15 @@ def test_snapshot_json_payload(): assert snapshot_json["node"]["jinja_macros"]["global_objs"]["target"] == { "type": "duckdb", "name": "in_memory", - "schema": "sushi", "database": "memory", + "schema": "sushi", + "threads": 1, "target_name": "in_memory", } @pytest.mark.xdist_group("dbt_manifest") -@freeze_time("2023-01-08 00:00:00") +@time_machine.travel("2023-01-08 00:00:00 UTC") def test_dbt_package_macros(sushi_test_project: Project): context = sushi_test_project.context @@ -1029,6 +2261,9 @@ def test_dbt_package_macros(sushi_test_project: Project): 
@pytest.mark.xdist_group("dbt_manifest") def test_dbt_vars(sushi_test_project: Project): context = sushi_test_project.context + context.set_and_render_variables( + sushi_test_project.packages["customers"].variables, "customers" + ) assert context.render("{{ var('some_other_var') }}") == "5" assert context.render("{{ var('some_other_var', 0) }}") == "5" @@ -1088,7 +2323,21 @@ def test_model_cluster_by(): sql="SELECT * FROM baz", materialized=Materialization.TABLE.value, ) - assert model.to_sqlmesh(context).clustered_by == ["BAR"] + assert model.to_sqlmesh(context).clustered_by == [exp.to_column('"BAR"')] + + model = ModelConfig( + name="model", + alias="model", + package_name="package", + target_schema="test", + cluster_by=["Bar", "qux"], + sql="SELECT * FROM baz", + materialized=Materialization.TABLE.value, + ) + assert model.to_sqlmesh(context).clustered_by == [ + exp.to_column('"BAR"'), + exp.to_column('"QUX"'), + ] model = ModelConfig( name="model", @@ -1097,6 +2346,684 @@ def test_model_cluster_by(): target_schema="test", cluster_by=["Bar", "qux"], sql="SELECT * FROM baz", + materialized=Materialization.VIEW.value, + ) + assert model.to_sqlmesh(context).clustered_by == [] + + model = ModelConfig( + name="model", + alias="model", + package_name="package", + target_schema="test", + cluster_by=["Bar", "qux"], + sql="SELECT * FROM baz", + materialized=Materialization.EPHEMERAL.value, + ) + assert model.to_sqlmesh(context).clustered_by == [] + + model = ModelConfig( + name="model", + alias="model", + package_name="package", + target_schema="test", + cluster_by="Bar, qux", + sql="SELECT * FROM baz", + materialized=Materialization.TABLE.value, + ) + assert model.to_sqlmesh(context).clustered_by == [ + exp.to_column('"BAR"'), + exp.to_column('"QUX"'), + ] + + model = ModelConfig( + name="model", + alias="model", + package_name="package", + target_schema="test", + cluster_by=['"Bar,qux"'], + sql="SELECT * FROM baz", + materialized=Materialization.TABLE.value, + ) + 
assert model.to_sqlmesh(context).clustered_by == [ + exp.to_column('"Bar,qux"'), + ] + + model = ModelConfig( + name="model", + alias="model", + package_name="package", + target_schema="test", + cluster_by='"Bar,qux"', + sql="SELECT * FROM baz", + materialized=Materialization.TABLE.value, + ) + assert model.to_sqlmesh(context).clustered_by == [ + exp.to_column('"Bar,qux"'), + ] + + model = ModelConfig( + name="model", + alias="model", + package_name="package", + target_schema="test", + cluster_by=["to_date(Bar),qux"], + sql="SELECT * FROM baz", + materialized=Materialization.TABLE.value, + ) + assert model.to_sqlmesh(context).clustered_by == [ + exp.TsOrDsToDate(this=exp.to_column('"BAR"')), + exp.to_column('"QUX"'), + ] + + +def test_snowflake_dynamic_table(): + context = DbtContext() + context._target = SnowflakeConfig( + name="target", + schema="test", + database="test", + account="account", + user="user", + password="password", + ) + + model = ModelConfig( + name="model", + alias="model", + package_name="package", + target_schema="test", + sql="SELECT * FROM baz", + materialized=Materialization.DYNAMIC_TABLE.value, + target_lag="1 hour", + snowflake_warehouse="SMALL", + ) + + as_sqlmesh = model.to_sqlmesh(context) + assert as_sqlmesh.kind == ManagedKind() + assert as_sqlmesh.physical_properties == { + "target_lag": exp.Literal.string("1 hour"), + "warehouse": exp.Literal.string("SMALL"), + } + + # both target_lag and snowflake_warehouse are required properties + # https://docs.getdbt.com/reference/resource-configs/snowflake-configs#dynamic-tables + for required_property in ["target_lag", "snowflake_warehouse"]: + with pytest.raises(ConfigError, match=r".*must be set for dynamic tables"): + model.copy(update={required_property: None}).to_sqlmesh(context) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_refs_in_jinja_globals(sushi_test_project: Project, mocker: MockerFixture): + context = sushi_test_project.context + + sqlmesh_model = t.cast( + SqlModel, + 
sushi_test_project.packages["sushi"].models["simple_model_b"].to_sqlmesh(context), + ) + assert set(sqlmesh_model.jinja_macros.global_objs["refs"].keys()) == {"simple_model_a"} # type: ignore + + sqlmesh_model = t.cast( + SqlModel, + sushi_test_project.packages["sushi"].models["top_waiters"].to_sqlmesh(context), + ) + assert set(sqlmesh_model.jinja_macros.global_objs["refs"].keys()) == { # type: ignore + "waiter_revenue_by_day", + "sushi.waiter_revenue_by_day", + } + + +def test_allow_partials_by_default(): + context = DbtContext() + context._target = SnowflakeConfig( + name="target", + schema="test", + database="test", + account="account", + user="user", + password="password", + ) + + model = ModelConfig( + name="model", + alias="model", + package_name="package", + target_schema="test", + sql="SELECT * FROM baz", materialized=Materialization.TABLE.value, ) - assert model.to_sqlmesh(context).clustered_by == ["BAR", "QUX"] + assert model.allow_partials + assert model.to_sqlmesh(context).allow_partials + + model.materialized = Materialization.INCREMENTAL.value + assert model.allow_partials + assert model.to_sqlmesh(context).allow_partials + + model.allow_partials = True + assert model.to_sqlmesh(context).allow_partials + + model.allow_partials = False + assert not model.to_sqlmesh(context).allow_partials + + +def test_grain(): + context = DbtContext() + context._target = SnowflakeConfig( + name="target", + schema="test", + database="test", + account="account", + user="user", + password="password", + ) + + model = ModelConfig( + name="model", + alias="model", + package_name="package", + target_schema="test", + sql="SELECT * FROM baz", + materialized=Materialization.TABLE.value, + grain=["id_a", "id_b"], + ) + assert model.to_sqlmesh(context).grains == [exp.to_column("id_a"), exp.to_column("id_b")] + + model.grain = "id_a" + assert model.to_sqlmesh(context).grains == [exp.to_column("id_a")] + + +@pytest.mark.xdist_group("dbt_manifest") +def test_on_run_start_end(): + 
sushi_context = Context(paths=["tests/fixtures/dbt/sushi_test"]) + assert len(sushi_context._environment_statements) == 2 + + # Root project's on run start / on run end should be first by checking the macros + root_environment_statements = sushi_context._environment_statements[0] + assert "create_tables" in root_environment_statements.jinja_macros.root_macros + + # Validate order of execution to be correct + assert root_environment_statements.before_all == [ + "JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;", + "JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS to_be_executed_last (col VARCHAR);\nJINJA_END;", + """JINJA_STATEMENT_BEGIN;\nSELECT {{ var("yet_another_var") }} AS var, '{{ source("raw", "items").identifier }}' AS src, '{{ ref("waiters").identifier }}' AS ref;\nJINJA_END;""", + "JINJA_STATEMENT_BEGIN;\n{{ log_value('on-run-start') }}\nJINJA_END;", + ] + assert root_environment_statements.after_all == [ + "JINJA_STATEMENT_BEGIN;\n{{ create_tables(schemas) }}\nJINJA_END;", + "JINJA_STATEMENT_BEGIN;\nDROP TABLE to_be_executed_last;\nJINJA_END;", + "JINJA_STATEMENT_BEGIN;\n{{ graph_usage() }}\nJINJA_END;", + ] + + assert root_environment_statements.jinja_macros.root_package_name == "sushi" + + rendered_before_all = render_statements( + root_environment_statements.before_all, + dialect=sushi_context.default_dialect, + python_env=root_environment_statements.python_env, + jinja_macros=root_environment_statements.jinja_macros, + runtime_stage=RuntimeStage.BEFORE_ALL, + ) + + runtime_rendered_after_all = render_statements( + root_environment_statements.after_all, + dialect=sushi_context.default_dialect, + python_env=root_environment_statements.python_env, + jinja_macros=root_environment_statements.jinja_macros, + snapshots=sushi_context.snapshots, + runtime_stage=RuntimeStage.AFTER_ALL, + environment_naming_info=EnvironmentNamingInfo(name="dev"), + 
engine_adapter=sushi_context.engine_adapter, + ) + + # not passing engine adapter simulates "parse-time" rendering + parse_time_rendered_after_all = render_statements( + root_environment_statements.after_all, + dialect=sushi_context.default_dialect, + python_env=root_environment_statements.python_env, + jinja_macros=root_environment_statements.jinja_macros, + snapshots=sushi_context.snapshots, + runtime_stage=RuntimeStage.AFTER_ALL, + environment_naming_info=EnvironmentNamingInfo(name="dev"), + ) + + # validate that the graph_table statement is the same between parse-time and runtime rendering + assert sorted(parse_time_rendered_after_all) == sorted(runtime_rendered_after_all) + graph_table_stmt = runtime_rendered_after_all[-1] + assert graph_table_stmt == parse_time_rendered_after_all[-1] + + assert rendered_before_all == [ + "CREATE TABLE IF NOT EXISTS analytic_stats (physical_table TEXT, evaluation_time TEXT)", + "CREATE TABLE IF NOT EXISTS to_be_executed_last (col TEXT)", + "SELECT 1 AS var, 'items' AS src, 'waiters' AS ref", + ] + + # The jinja macro should have resolved the schemas for this environment and generated corresponding statements + expected_statements = [ + "CREATE OR REPLACE TABLE schema_table_snapshots__dev AS SELECT 'snapshots__dev' AS schema", + "CREATE OR REPLACE TABLE schema_table_sushi__dev AS SELECT 'sushi__dev' AS schema", + "DROP TABLE to_be_executed_last", + ] + assert sorted(runtime_rendered_after_all[:-1]) == sorted(expected_statements) + + # Assert the models with their materialisations are present in the rendered graph_table statement + assert "'model.sushi.simple_model_a' AS unique_id, 'table' AS materialized" in graph_table_stmt + assert "'model.sushi.waiters' AS unique_id, 'ephemeral' AS materialized" in graph_table_stmt + assert "'model.sushi.simple_model_b' AS unique_id, 'table' AS materialized" in graph_table_stmt + assert ( + "'model.sushi.waiter_as_customer_by_day' AS unique_id, 'incremental' AS materialized" + in 
graph_table_stmt + ) + assert "'model.sushi.top_waiters' AS unique_id, 'view' AS materialized" in graph_table_stmt + assert "'model.customers.customers' AS unique_id, 'view' AS materialized" in graph_table_stmt + assert ( + "'model.customers.customer_revenue_by_day' AS unique_id, 'incremental' AS materialized" + in graph_table_stmt + ) + assert ( + "'model.sushi.waiter_revenue_by_day.v1' AS unique_id, 'incremental' AS materialized" + in graph_table_stmt + ) + assert ( + "'model.sushi.waiter_revenue_by_day.v2' AS unique_id, 'incremental' AS materialized" + in graph_table_stmt + ) + + # Nested dbt_packages on run start / on run end + packaged_environment_statements = sushi_context._environment_statements[1] + + # Validate order of execution to be correct + assert packaged_environment_statements.before_all == [ + "JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS to_be_executed_first (col VARCHAR);\nJINJA_END;", + "JINJA_STATEMENT_BEGIN;\nCREATE TABLE IF NOT EXISTS analytic_stats_packaged_project (physical_table VARCHAR, evaluation_time VARCHAR);\nJINJA_END;", + ] + assert packaged_environment_statements.after_all == [ + "JINJA_STATEMENT_BEGIN;\nDROP TABLE to_be_executed_first\nJINJA_END;", + "JINJA_STATEMENT_BEGIN;\n{{ packaged_tables(schemas) }}\nJINJA_END;", + ] + + assert "packaged_tables" in packaged_environment_statements.jinja_macros.root_macros + assert packaged_environment_statements.jinja_macros.root_package_name == "sushi" + + rendered_before_all = render_statements( + packaged_environment_statements.before_all, + dialect=sushi_context.default_dialect, + python_env=packaged_environment_statements.python_env, + jinja_macros=packaged_environment_statements.jinja_macros, + runtime_stage=RuntimeStage.BEFORE_ALL, + ) + + rendered_after_all = render_statements( + packaged_environment_statements.after_all, + dialect=sushi_context.default_dialect, + python_env=packaged_environment_statements.python_env, + jinja_macros=packaged_environment_statements.jinja_macros, 
+ snapshots=sushi_context.snapshots, + runtime_stage=RuntimeStage.AFTER_ALL, + environment_naming_info=EnvironmentNamingInfo(name="dev"), + engine_adapter=sushi_context.engine_adapter, + ) + + # Validate order of execution to match dbt's + assert rendered_before_all == [ + "CREATE TABLE IF NOT EXISTS to_be_executed_first (col TEXT)", + "CREATE TABLE IF NOT EXISTS analytic_stats_packaged_project (physical_table TEXT, evaluation_time TEXT)", + ] + + # This on run end statement should be executed first + assert rendered_after_all[0] == "DROP TABLE to_be_executed_first" + + # The table names is an indication of the rendering of the dbt_packages statements + assert sorted(rendered_after_all) == sorted( + [ + "DROP TABLE to_be_executed_first", + "CREATE OR REPLACE TABLE schema_table_snapshots__dev_nested_package AS SELECT 'snapshots__dev' AS schema", + "CREATE OR REPLACE TABLE schema_table_sushi__dev_nested_package AS SELECT 'sushi__dev' AS schema", + ] + ) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_dynamic_var_names(sushi_test_project: Project, sushi_test_dbt_context: Context): + context = sushi_test_project.context + context.set_and_render_variables(sushi_test_project.packages["sushi"].variables, "sushi") + context.target = BigQueryConfig(name="production", database="main", schema="sushi") + model_config = ModelConfig( + name="model", + alias="model", + schema="test", + package_name="package", + materialized="table", + unique_key="ds", + partition_by={"field": "ds", "granularity": "month"}, + sql=""" + {% set var_name = "yet_" + "another_" + "var" %} + {% set results = run_query('select 1 as one') %} + {% if results %} + SELECT {{ results.columns[0].values()[0] }} AS one {{ var(var_name) }} AS var FROM {{ this.identifier }} + {% else %} + SELECT NULL AS one {{ var(var_name) }} AS var FROM {{ this.identifier }} + {% endif %} + """, + dependencies=Dependencies(has_dynamic_var_names=True), + ) + converted_model = model_config.to_sqlmesh(context) + assert 
"yet_another_var" in converted_model.jinja_macros.global_objs["vars"] # type: ignore + + # Test the existing model in the sushi project + assert ( + "dynamic_test_var" # type: ignore + in sushi_test_dbt_context.get_model( + "sushi.waiter_revenue_by_day_v2" + ).jinja_macros.global_objs["vars"] + ) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_dynamic_var_names_in_macro(sushi_test_project: Project): + context = sushi_test_project.context + context.set_and_render_variables(sushi_test_project.packages["sushi"].variables, "sushi") + context.target = BigQueryConfig(name="production", database="main", schema="sushi") + model_config = ModelConfig( + name="model", + alias="model", + schema="test", + package_name="package", + materialized="table", + unique_key="ds", + partition_by={"field": "ds", "granularity": "month"}, + sql=""" + {% set var_name = "dynamic_" + "test_" + "var" %} + SELECT {{ sushi.dynamic_var_name_dependency(var_name) }} AS var + """, + dependencies=Dependencies( + macros=[MacroReference(package="sushi", name="dynamic_var_name_dependency")], + has_dynamic_var_names=True, + ), + ) + converted_model = model_config.to_sqlmesh(context) + assert "dynamic_test_var" in converted_model.jinja_macros.global_objs["vars"] # type: ignore + + +def test_selected_resources_with_selectors(): + sushi_context = Context(paths=["tests/fixtures/dbt/sushi_test"]) + + # A plan with a specific model selection + plan_builder = sushi_context.plan_builder(select_models=["sushi.customers"]) + plan = plan_builder.build() + assert len(plan.selected_models) == 1 + selected_model = list(plan.selected_models)[0] + assert "customers" in selected_model + + # Plan without model selections should include all models + plan_builder = sushi_context.plan_builder() + plan = plan_builder.build() + assert plan.selected_models is not None + assert len(plan.selected_models) > 10 + + # with downstream models should select customers and at least one downstream model + plan_builder = 
sushi_context.plan_builder(select_models=["sushi.customers+"]) + plan = plan_builder.build() + assert plan.selected_models is not None + assert len(plan.selected_models) >= 2 + assert any("customers" in model for model in plan.selected_models) + + # Test wildcard selection + plan_builder = sushi_context.plan_builder(select_models=["sushi.waiter_*"]) + plan = plan_builder.build() + assert plan.selected_models is not None + assert len(plan.selected_models) >= 4 + assert all("waiter" in model for model in plan.selected_models) + + +@pytest.mark.xdist_group("dbt_manifest") +def test_selected_resources_context_variable( + sushi_test_project: Project, sushi_test_dbt_context: Context +): + context = sushi_test_project.context + + # empty selected resources + direct_access = context.render("{{ selected_resources }}") + assert direct_access == "[]" + + # selected_resources is iterable and count items + test_jinja = """ + {%- set resources = [] -%} + {%- for resource in selected_resources -%} + {%- do resources.append(resource) -%} + {%- endfor -%} + {{ resources | length }} + """ + result = context.render(test_jinja) + assert result.strip() == "0" + + # selected_resources in conditions + test_condition = """ + {%- if selected_resources -%} + has_resources + {%- else -%} + no_resources + {%- endif -%} + """ + result = context.render(test_condition) + assert result.strip() == "no_resources" + + # selected resources in dbt format + selected_resources = [ + "model.jaffle_shop.customers", + "model.jaffle_shop.items", + "model.jaffle_shop.orders", + ] + + # check the jinja macros rendering + result = context.render("{{ selected_resources }}", selected_resources=selected_resources) + assert result == selected_resources.__repr__() + + result = context.render(test_jinja, selected_resources=selected_resources) + assert result.strip() == "3" + + result = context.render(test_condition, selected_resources=selected_resources) + assert result.strip() == "has_resources" + + +def 
test_ignore_source_depends_on_when_also_model(dbt_dummy_postgres_config: PostgresConfig): + context = DbtContext() + context._target = dbt_dummy_postgres_config + + source_a = SourceConfig( + name="source_a", + fqn=["package", "schema", "model_a"], + ) + source_a._canonical_name = "schema.source_a" + source_b = SourceConfig( + name="source_b", + fqn=["package", "schema", "source_b"], + ) + source_b._canonical_name = "schema.source_b" + context.sources = {"source_a": source_a, "source_b": source_b} + + model = ModelConfig( + dependencies=Dependencies(sources={"source_a", "source_b"}), + fqn=["package", "schema", "test_model"], + ) + context.models = { + "test_model": model, + "model_a": ModelConfig(name="model_a", fqn=["package", "schema", "model_a"]), + } + + assert model.sqlmesh_model_kwargs(context)["depends_on"] == {"schema.source_b"} + + +@pytest.mark.xdist_group("dbt_manifest") +def test_dbt_hooks_with_transaction_flag(sushi_test_dbt_context: Context): + model_fqn = '"memory"."sushi"."model_with_transaction_hooks"' + assert model_fqn in sushi_test_dbt_context.models + + model = sushi_test_dbt_context.models[model_fqn] + + pre_statements = model.pre_statements_ + assert pre_statements is not None + assert len(pre_statements) >= 3 + + # we need to check the expected SQL but more importantly that the transaction flags are there + assert any( + s.sql == 'JINJA_STATEMENT_BEGIN;\n{{ log("pre-hook") }}\nJINJA_END;' + and s.transaction is True + for s in pre_statements + ) + assert any( + "CREATE TABLE IF NOT EXISTS hook_outside_pre_table" in s.sql and s.transaction is False + for s in pre_statements + ) + assert any( + "CREATE TABLE IF NOT EXISTS shared_hook_table" in s.sql and s.transaction is False + for s in pre_statements + ) + assert any( + "{{ insert_into_shared_hook_table('inside_pre') }}" in s.sql and s.transaction is True + for s in pre_statements + ) + + post_statements = model.post_statements_ + assert post_statements is not None + assert 
len(post_statements) >= 4 + assert any( + s.sql == 'JINJA_STATEMENT_BEGIN;\n{{ log("post-hook") }}\nJINJA_END;' + and s.transaction is True + for s in post_statements + ) + assert any( + "{{ insert_into_shared_hook_table('inside_post') }}" in s.sql and s.transaction is True + for s in post_statements + ) + assert any( + "CREATE TABLE IF NOT EXISTS hook_outside_post_table" in s.sql and s.transaction is False + for s in post_statements + ) + assert any( + "{{ insert_into_shared_hook_table('after_commit') }}" in s.sql and s.transaction is False + for s in post_statements + ) + + # render_pre_statements with inside_transaction=True should only return inserrt + inside_pre_statements = model.render_pre_statements(inside_transaction=True) + assert len(inside_pre_statements) == 1 + assert ( + inside_pre_statements[0].sql() + == """INSERT INTO "shared_hook_table" ("id", "hook_name", "execution_order", "created_at") VALUES ((SELECT COALESCE(MAX("id"), 0) + 1 FROM "shared_hook_table"), 'inside_pre', (SELECT COALESCE(MAX("id"), 0) + 1 FROM "shared_hook_table"), NOW())""" + ) + + # while for render_pre_statements with inside_transaction=False the create statements + outside_pre_statements = model.render_pre_statements(inside_transaction=False) + assert len(outside_pre_statements) == 2 + assert "CREATE" in outside_pre_statements[0].sql() + assert "hook_outside_pre_table" in outside_pre_statements[0].sql() + assert "CREATE" in outside_pre_statements[1].sql() + assert "shared_hook_table" in outside_pre_statements[1].sql() + + # similarly for post statements + inside_post_statements = model.render_post_statements(inside_transaction=True) + assert len(inside_post_statements) == 1 + assert ( + inside_post_statements[0].sql() + == """INSERT INTO "shared_hook_table" ("id", "hook_name", "execution_order", "created_at") VALUES ((SELECT COALESCE(MAX("id"), 0) + 1 FROM "shared_hook_table"), 'inside_post', (SELECT COALESCE(MAX("id"), 0) + 1 FROM "shared_hook_table"), NOW())""" + ) + + 
outside_post_statements = model.render_post_statements(inside_transaction=False) + assert len(outside_post_statements) == 2 + assert "CREATE" in outside_post_statements[0].sql() + assert "hook_outside_post_table" in outside_post_statements[0].sql() + assert "INSERT" in outside_post_statements[1].sql() + assert "shared_hook_table" in outside_post_statements[1].sql() + + +@pytest.mark.xdist_group("dbt_manifest") +def test_dbt_hooks_with_transaction_flag_execution(sushi_test_dbt_context: Context): + model_fqn = '"memory"."sushi"."model_with_transaction_hooks"' + assert model_fqn in sushi_test_dbt_context.models + + plan = sushi_test_dbt_context.plan(select_models=["sushi.model_with_transaction_hooks"]) + sushi_test_dbt_context.apply(plan) + + result = sushi_test_dbt_context.engine_adapter.fetchdf( + "SELECT * FROM sushi.model_with_transaction_hooks" + ) + assert len(result) == 1 + assert result["id"][0] == 1 + assert result["name"][0] == "test" + + # ensure the outside pre-hook and post-hook table were created + pre_outside = sushi_test_dbt_context.engine_adapter.fetchdf( + "SELECT * FROM hook_outside_pre_table" + ) + assert len(pre_outside) == 1 + assert pre_outside["id"][0] == 1 + assert pre_outside["location"][0] == "outside" + assert pre_outside["execution_order"][0] == 1 + + post_outside = sushi_test_dbt_context.engine_adapter.fetchdf( + "SELECT * FROM hook_outside_post_table" + ) + assert len(post_outside) == 1 + assert post_outside["id"][0] == 5 + assert post_outside["location"][0] == "outside" + assert post_outside["execution_order"][0] == 5 + + # verify the shared table that was created by before_begin and populated by all hooks + shared_table = sushi_test_dbt_context.engine_adapter.fetchdf( + "SELECT * FROM shared_hook_table ORDER BY execution_order" + ) + assert len(shared_table) == 3 + assert shared_table["execution_order"].is_monotonic_increasing + + # The order of creation and insertion will verify the following order of execution + # 1. 
before_begin (transaction=false) ran BEFORE the transaction started and created the table + # 2. inside_pre (transaction=true) ran INSIDE the transaction and could insert into the table + # 3. inside_post (transaction=true) ran INSIDE the transaction and could insert into the table (but after pre statement) + # 4. after_commit (transaction=false) ran AFTER the transaction committed + + assert shared_table["id"][0] == 1 + assert shared_table["hook_name"][0] == "inside_pre" + assert shared_table["execution_order"][0] == 1 + + assert shared_table["id"][1] == 2 + assert shared_table["hook_name"][1] == "inside_post" + assert shared_table["execution_order"][1] == 2 + + assert shared_table["id"][2] == 3 + assert shared_table["hook_name"][2] == "after_commit" + assert shared_table["execution_order"][2] == 3 + + # the timestamps also should be monotonically increasing for the same reason + for i in range(len(shared_table) - 1): + assert shared_table["created_at"][i] <= shared_table["created_at"][i + 1] + + # the tables using the alternate syntax should have correct order as well + assert pre_outside["created_at"][0] < shared_table["created_at"][0] + assert post_outside["created_at"][0] > shared_table["created_at"][1] + + # running with execution time one day in the future to simulate a run + tomorrow = datetime.now() + timedelta(days=1) + sushi_test_dbt_context.run( + select_models=["sushi.model_with_transaction_hooks"], execution_time=tomorrow + ) + + # to verify that the transaction information persists in state and is respected + shared_table = sushi_test_dbt_context.engine_adapter.fetchdf( + "SELECT * FROM shared_hook_table ORDER BY execution_order" + ) + + # and the execution order for run is similar + assert shared_table["execution_order"].is_monotonic_increasing + assert shared_table["id"][3] == 4 + assert shared_table["hook_name"][3] == "inside_pre" + assert shared_table["execution_order"][3] == 4 + + assert shared_table["id"][4] == 5 + assert 
shared_table["hook_name"][4] == "inside_post" + assert shared_table["execution_order"][4] == 5 + + assert shared_table["id"][5] == 6 + assert shared_table["hook_name"][5] == "after_commit" + assert shared_table["execution_order"][5] == 6 + + for i in range(len(shared_table) - 1): + assert shared_table["created_at"][i] <= shared_table["created_at"][i + 1] diff --git a/tests/dbt/test_util.py b/tests/dbt/test_util.py index 67c8ba8e89..ce98f48a82 100644 --- a/tests/dbt/test_util.py +++ b/tests/dbt/test_util.py @@ -1,6 +1,6 @@ from __future__ import annotations -import pandas as pd +import pandas as pd # noqa: TID253 from sqlmesh.dbt.util import pandas_to_agate diff --git a/tests/engines/spark/test_db_api.py b/tests/engines/spark/test_db_api.py index 8bbfe7e9ab..eab7a0c223 100644 --- a/tests/engines/spark/test_db_api.py +++ b/tests/engines/spark/test_db_api.py @@ -4,10 +4,10 @@ from sqlmesh.engines.spark.db_api import errors from sqlmesh.engines.spark.db_api import spark_session as spark_session_db -pytestmark = [ - pytest.mark.slow, - pytest.mark.spark_pyspark, -] +# note: this is deliberately not marked with 'spark' so that it +# can run separately from the spark integration tests. 
+# running them at the same time mutates some global state in the SparkSession which breaks these tests +pytestmark = [pytest.mark.slow, pytest.mark.pyspark] def test_spark_session_cursor(spark_session: SparkSession): diff --git a/tests/fixtures/dbt/empty_project/dbt_project.yml b/tests/fixtures/dbt/empty_project/dbt_project.yml new file mode 100644 index 0000000000..dab3d1e0e8 --- /dev/null +++ b/tests/fixtures/dbt/empty_project/dbt_project.yml @@ -0,0 +1,18 @@ +name: 'empty_project' + +version: '1.0.0' +config-version: 2 + +profile: 'empty_project' + +model-paths: ["models"] +seed-paths: ["seeds"] +test-paths: ["tests"] +analysis-paths: ["analysis"] +macro-paths: ["macros"] + +target-path: "target" +clean-targets: + - "target" + - "dbt_modules" + - "logs" \ No newline at end of file diff --git a/tests/fixtures/dbt/empty_project/profiles.yml b/tests/fixtures/dbt/empty_project/profiles.yml new file mode 100644 index 0000000000..712456bffe --- /dev/null +++ b/tests/fixtures/dbt/empty_project/profiles.yml @@ -0,0 +1,13 @@ +empty_project: + + target: __DEFAULT_TARGET__ + + outputs: + __DEFAULT_TARGET__: + type: duckdb + # database is required for dbt < 1.5 where our adapter deliberately doesnt infer the database from the path and + # defaults it to "main", which raises a "project catalog doesnt match context catalog" error + # ref: https://github.com/SQLMesh/sqlmesh/pull/1109 + database: empty_project + path: 'empty_project.duckdb' + threads: 4 diff --git a/tests/fixtures/dbt/jaffle_shop_duckdb/dbt_project.yml b/tests/fixtures/dbt/jaffle_shop_duckdb/dbt_project.yml new file mode 100644 index 0000000000..1b71726467 --- /dev/null +++ b/tests/fixtures/dbt/jaffle_shop_duckdb/dbt_project.yml @@ -0,0 +1,34 @@ +name: 'jaffle_shop' + +config-version: 2 +version: '0.1' + +profile: 'jaffle_shop' + +model-paths: ["models"] +seed-paths: ["seeds"] +test-paths: ["tests"] +analysis-paths: ["analysis"] +macro-paths: ["macros"] + +target-path: "target" +clean-targets: + - "target" + - 
"dbt_modules" + - "logs" + +require-dbt-version: [">=1.0.0", "<2.0.0"] + +seeds: + +docs: + node_color: '#cd7f32' + +models: + jaffle_shop: + +materialized: table + staging: + +materialized: view + +docs: + node_color: 'silver' + +docs: + node_color: 'gold' diff --git a/tests/fixtures/dbt/jaffle_shop_duckdb/models/customers.sql b/tests/fixtures/dbt/jaffle_shop_duckdb/models/customers.sql new file mode 100644 index 0000000000..016a004fe5 --- /dev/null +++ b/tests/fixtures/dbt/jaffle_shop_duckdb/models/customers.sql @@ -0,0 +1,69 @@ +with customers as ( + + select * from {{ ref('stg_customers') }} + +), + +orders as ( + + select * from {{ ref('stg_orders') }} + +), + +payments as ( + + select * from {{ ref('stg_payments') }} + +), + +customer_orders as ( + + select + customer_id, + + min(order_date) as first_order, + max(order_date) as most_recent_order, + count(order_id) as number_of_orders + from orders + + group by customer_id + +), + +customer_payments as ( + + select + orders.customer_id, + sum(amount) as total_amount + + from payments + + left join orders on + payments.order_id = orders.order_id + + group by orders.customer_id + +), + +final as ( + + select + customers.customer_id, + customers.first_name, + customers.last_name, + customer_orders.first_order, + customer_orders.most_recent_order, + customer_orders.number_of_orders, + customer_payments.total_amount as customer_lifetime_value + + from customers + + left join customer_orders + on customers.customer_id = customer_orders.customer_id + + left join customer_payments + on customers.customer_id = customer_payments.customer_id + +) + +select * from final diff --git a/tests/fixtures/dbt/jaffle_shop_duckdb/models/docs.md b/tests/fixtures/dbt/jaffle_shop_duckdb/models/docs.md new file mode 100644 index 0000000000..c6ae93be07 --- /dev/null +++ b/tests/fixtures/dbt/jaffle_shop_duckdb/models/docs.md @@ -0,0 +1,14 @@ +{% docs orders_status %} + +Orders can be one of the following statuses: + +| status | 
description | +|----------------|------------------------------------------------------------------------------------------------------------------------| +| placed | The order has been placed but has not yet left the warehouse | +| shipped | The order has ben shipped to the customer and is currently in transit | +| completed | The order has been received by the customer | +| return_pending | The customer has indicated that they would like to return the order, but it has not yet been received at the warehouse | +| returned | The order has been returned by the customer and received at the warehouse | + + +{% enddocs %} diff --git a/tests/fixtures/dbt/jaffle_shop_duckdb/models/orders.sql b/tests/fixtures/dbt/jaffle_shop_duckdb/models/orders.sql new file mode 100644 index 0000000000..cbb2934911 --- /dev/null +++ b/tests/fixtures/dbt/jaffle_shop_duckdb/models/orders.sql @@ -0,0 +1,56 @@ +{% set payment_methods = ['credit_card', 'coupon', 'bank_transfer', 'gift_card'] %} + +with orders as ( + + select * from {{ ref('stg_orders') }} + +), + +payments as ( + + select * from {{ ref('stg_payments') }} + +), + +order_payments as ( + + select + order_id, + + {% for payment_method in payment_methods -%} + sum(case when payment_method = '{{ payment_method }}' then amount else 0 end) as {{ payment_method }}_amount, + {% endfor -%} + + sum(amount) as total_amount + + from payments + + group by order_id + +), + +final as ( + + select + orders.order_id, + orders.customer_id, + orders.order_date, + orders.status, + + {% for payment_method in payment_methods -%} + + order_payments.{{ payment_method }}_amount, + + {% endfor -%} + + order_payments.total_amount as amount + + from orders + + + left join order_payments + on orders.order_id = order_payments.order_id + +) + +select * from final diff --git a/tests/fixtures/dbt/jaffle_shop_duckdb/models/overview.md b/tests/fixtures/dbt/jaffle_shop_duckdb/models/overview.md new file mode 100644 index 0000000000..0544c42b17 --- /dev/null +++ 
b/tests/fixtures/dbt/jaffle_shop_duckdb/models/overview.md @@ -0,0 +1,11 @@ +{% docs __overview__ %} + +## Data Documentation for Jaffle Shop + +`jaffle_shop` is a fictional ecommerce store. + +This [dbt](https://www.getdbt.com/) project is for testing out code. + +The source code can be found [here](https://github.com/clrcrl/jaffle_shop). + +{% enddocs %} diff --git a/tests/fixtures/dbt/jaffle_shop_duckdb/models/schema.yml b/tests/fixtures/dbt/jaffle_shop_duckdb/models/schema.yml new file mode 100644 index 0000000000..381349cfda --- /dev/null +++ b/tests/fixtures/dbt/jaffle_shop_duckdb/models/schema.yml @@ -0,0 +1,82 @@ +version: 2 + +models: + - name: customers + description: This table has basic information about a customer, as well as some derived facts based on a customer's orders + + columns: + - name: customer_id + description: This is a unique identifier for a customer + tests: + - unique + - not_null + + - name: first_name + description: Customer's first name. PII. + + - name: last_name + description: Customer's last name. PII. 
+ + - name: first_order + description: Date (UTC) of a customer's first order + + - name: most_recent_order + description: Date (UTC) of a customer's most recent order + + - name: number_of_orders + description: Count of the number of orders a customer has placed + + - name: total_order_amount + description: Total value (AUD) of a customer's orders + + - name: orders + description: This table has basic information about orders, as well as some derived facts based on payments + + columns: + - name: order_id + tests: + - unique + - not_null + description: This is a unique identifier for an order + + - name: customer_id + description: Foreign key to the customers table + tests: + - not_null + - relationships: + to: ref('customers') + field: customer_id + + - name: order_date + description: Date (UTC) that the order was placed + + - name: status + description: '{{ doc("orders_status") }}' + tests: + - accepted_values: + values: ['placed', 'shipped', 'completed', 'return_pending', 'returned'] + + - name: amount + description: Total amount (AUD) of the order + tests: + - not_null + + - name: credit_card_amount + description: Amount of the order (AUD) paid for by credit card + tests: + - not_null + + - name: coupon_amount + description: Amount of the order (AUD) paid for by coupon + tests: + - not_null + + - name: bank_transfer_amount + description: Amount of the order (AUD) paid for by bank transfer + tests: + - not_null + + - name: gift_card_amount + description: Amount of the order (AUD) paid for by gift card + tests: + - not_null diff --git a/tests/fixtures/dbt/jaffle_shop_duckdb/models/staging/schema.yml b/tests/fixtures/dbt/jaffle_shop_duckdb/models/staging/schema.yml new file mode 100644 index 0000000000..c207e4cf52 --- /dev/null +++ b/tests/fixtures/dbt/jaffle_shop_duckdb/models/staging/schema.yml @@ -0,0 +1,31 @@ +version: 2 + +models: + - name: stg_customers + columns: + - name: customer_id + tests: + - unique + - not_null + + - name: stg_orders + columns: + - 
name: order_id + tests: + - unique + - not_null + - name: status + tests: + - accepted_values: + values: ['placed', 'shipped', 'completed', 'return_pending', 'returned'] + + - name: stg_payments + columns: + - name: payment_id + tests: + - unique + - not_null + - name: payment_method + tests: + - accepted_values: + values: ['credit_card', 'coupon', 'bank_transfer', 'gift_card'] diff --git a/tests/fixtures/dbt/jaffle_shop_duckdb/models/staging/stg_customers.sql b/tests/fixtures/dbt/jaffle_shop_duckdb/models/staging/stg_customers.sql new file mode 100644 index 0000000000..cad0472695 --- /dev/null +++ b/tests/fixtures/dbt/jaffle_shop_duckdb/models/staging/stg_customers.sql @@ -0,0 +1,22 @@ +with source as ( + + {#- + Normally we would select from the table here, but we are using seeds to load + our data in this project + #} + select * from {{ ref('raw_customers') }} + +), + +renamed as ( + + select + id as customer_id, + first_name, + last_name + + from source + +) + +select * from renamed diff --git a/tests/fixtures/dbt/jaffle_shop_duckdb/models/staging/stg_orders.sql b/tests/fixtures/dbt/jaffle_shop_duckdb/models/staging/stg_orders.sql new file mode 100644 index 0000000000..a654dcb947 --- /dev/null +++ b/tests/fixtures/dbt/jaffle_shop_duckdb/models/staging/stg_orders.sql @@ -0,0 +1,23 @@ +with source as ( + + {#- + Normally we would select from the table here, but we are using seeds to load + our data in this project + #} + select * from {{ ref('raw_orders') }} + +), + +renamed as ( + + select + id as order_id, + user_id as customer_id, + order_date, + status + + from source + +) + +select * from renamed diff --git a/tests/fixtures/dbt/jaffle_shop_duckdb/models/staging/stg_payments.sql b/tests/fixtures/dbt/jaffle_shop_duckdb/models/staging/stg_payments.sql new file mode 100644 index 0000000000..700cf7f4f6 --- /dev/null +++ b/tests/fixtures/dbt/jaffle_shop_duckdb/models/staging/stg_payments.sql @@ -0,0 +1,25 @@ +with source as ( + + {#- + Normally we would select 
from the table here, but we are using seeds to load + our data in this project + #} + select * from {{ ref('raw_payments') }} + +), + +renamed as ( + + select + id as payment_id, + order_id, + payment_method, + + -- `amount` is currently stored in cents, so we convert it to dollars + amount / 100 as amount + + from source + +) + +select * from renamed diff --git a/tests/fixtures/dbt/jaffle_shop_duckdb/profiles.yml b/tests/fixtures/dbt/jaffle_shop_duckdb/profiles.yml new file mode 100644 index 0000000000..9008a2d62c --- /dev/null +++ b/tests/fixtures/dbt/jaffle_shop_duckdb/profiles.yml @@ -0,0 +1,8 @@ +jaffle_shop: + + target: dev + outputs: + dev: + type: duckdb + path: 'jaffle_shop.duckdb' + threads: 24 diff --git a/tests/schedulers/airflow/__init__.py b/tests/fixtures/dbt/jaffle_shop_duckdb/seeds/.gitkeep similarity index 100% rename from tests/schedulers/airflow/__init__.py rename to tests/fixtures/dbt/jaffle_shop_duckdb/seeds/.gitkeep diff --git a/tests/fixtures/dbt/jaffle_shop_duckdb/seeds/raw_customers.csv b/tests/fixtures/dbt/jaffle_shop_duckdb/seeds/raw_customers.csv new file mode 100644 index 0000000000..b3e6747d69 --- /dev/null +++ b/tests/fixtures/dbt/jaffle_shop_duckdb/seeds/raw_customers.csv @@ -0,0 +1,101 @@ +id,first_name,last_name +1,Michael,P. +2,Shawn,M. +3,Kathleen,P. +4,Jimmy,C. +5,Katherine,R. +6,Sarah,R. +7,Martin,M. +8,Frank,R. +9,Jennifer,F. +10,Henry,W. +11,Fred,S. +12,Amy,D. +13,Kathleen,M. +14,Steve,F. +15,Teresa,H. +16,Amanda,H. +17,Kimberly,R. +18,Johnny,K. +19,Virginia,F. +20,Anna,A. +21,Willie,H. +22,Sean,H. +23,Mildred,A. +24,David,G. +25,Victor,H. +26,Aaron,R. +27,Benjamin,B. +28,Lisa,W. +29,Benjamin,K. +30,Christina,W. +31,Jane,G. +32,Thomas,O. +33,Katherine,M. +34,Jennifer,S. +35,Sara,T. +36,Harold,O. +37,Shirley,J. +38,Dennis,J. +39,Louise,W. +40,Maria,A. +41,Gloria,C. +42,Diana,S. +43,Kelly,N. +44,Jane,R. +45,Scott,B. +46,Norma,C. +47,Marie,P. +48,Lillian,C. +49,Judy,N. +50,Billy,L. +51,Howard,R. +52,Laura,F. +53,Anne,B. 
+54,Rose,M. +55,Nicholas,R. +56,Joshua,K. +57,Paul,W. +58,Kathryn,K. +59,Adam,A. +60,Norma,W. +61,Timothy,R. +62,Elizabeth,P. +63,Edward,G. +64,David,C. +65,Brenda,W. +66,Adam,W. +67,Michael,H. +68,Jesse,E. +69,Janet,P. +70,Helen,F. +71,Gerald,C. +72,Kathryn,O. +73,Alan,B. +74,Harry,A. +75,Andrea,H. +76,Barbara,W. +77,Anne,W. +78,Harry,H. +79,Jack,R. +80,Phillip,H. +81,Shirley,H. +82,Arthur,D. +83,Virginia,R. +84,Christina,R. +85,Theresa,M. +86,Jason,C. +87,Phillip,B. +88,Adam,T. +89,Margaret,J. +90,Paul,P. +91,Todd,W. +92,Willie,O. +93,Frances,R. +94,Gregory,H. +95,Lisa,P. +96,Jacqueline,A. +97,Shirley,D. +98,Nicole,M. +99,Mary,G. +100,Jean,M. diff --git a/tests/fixtures/dbt/jaffle_shop_duckdb/seeds/raw_orders.csv b/tests/fixtures/dbt/jaffle_shop_duckdb/seeds/raw_orders.csv new file mode 100644 index 0000000000..7c2be07888 --- /dev/null +++ b/tests/fixtures/dbt/jaffle_shop_duckdb/seeds/raw_orders.csv @@ -0,0 +1,100 @@ +id,user_id,order_date,status +1,1,2018-01-01,returned +2,3,2018-01-02,completed +3,94,2018-01-04,completed +4,50,2018-01-05,completed +5,64,2018-01-05,completed +6,54,2018-01-07,completed +7,88,2018-01-09,completed +8,2,2018-01-11,returned +9,53,2018-01-12,completed +10,7,2018-01-14,completed +11,99,2018-01-14,completed +12,59,2018-01-15,completed +13,84,2018-01-17,completed +14,40,2018-01-17,returned +15,25,2018-01-17,completed +16,39,2018-01-18,completed +17,71,2018-01-18,completed +18,64,2018-01-20,returned +19,54,2018-01-22,completed +20,20,2018-01-23,completed +21,71,2018-01-23,completed +22,86,2018-01-24,completed +23,22,2018-01-26,return_pending +24,3,2018-01-27,completed +25,51,2018-01-28,completed +26,32,2018-01-28,completed +27,94,2018-01-29,completed +28,8,2018-01-29,completed +29,57,2018-01-31,completed +30,69,2018-02-02,completed +31,16,2018-02-02,completed +32,28,2018-02-04,completed +33,42,2018-02-04,completed +34,38,2018-02-06,completed +35,80,2018-02-08,completed +36,85,2018-02-10,completed +37,1,2018-02-10,completed 
+38,51,2018-02-10,completed +39,26,2018-02-11,completed +40,33,2018-02-13,completed +41,99,2018-02-14,completed +42,92,2018-02-16,completed +43,31,2018-02-17,completed +44,66,2018-02-17,completed +45,22,2018-02-17,completed +46,6,2018-02-19,completed +47,50,2018-02-20,completed +48,27,2018-02-21,completed +49,35,2018-02-21,completed +50,51,2018-02-23,completed +51,71,2018-02-24,completed +52,54,2018-02-25,return_pending +53,34,2018-02-26,completed +54,54,2018-02-26,completed +55,18,2018-02-27,completed +56,79,2018-02-28,completed +57,93,2018-03-01,completed +58,22,2018-03-01,completed +59,30,2018-03-02,completed +60,12,2018-03-03,completed +61,63,2018-03-03,completed +62,57,2018-03-05,completed +63,70,2018-03-06,completed +64,13,2018-03-07,completed +65,26,2018-03-08,completed +66,36,2018-03-10,completed +67,79,2018-03-11,completed +68,53,2018-03-11,completed +69,3,2018-03-11,completed +70,8,2018-03-12,completed +71,42,2018-03-12,shipped +72,30,2018-03-14,shipped +73,19,2018-03-16,completed +74,9,2018-03-17,shipped +75,69,2018-03-18,completed +76,25,2018-03-20,completed +77,35,2018-03-21,shipped +78,90,2018-03-23,shipped +79,52,2018-03-23,shipped +80,11,2018-03-23,shipped +81,76,2018-03-23,shipped +82,46,2018-03-24,shipped +83,54,2018-03-24,shipped +84,70,2018-03-26,placed +85,47,2018-03-26,shipped +86,68,2018-03-26,placed +87,46,2018-03-27,placed +88,91,2018-03-27,shipped +89,21,2018-03-28,placed +90,66,2018-03-30,shipped +91,47,2018-03-31,placed +92,84,2018-04-02,placed +93,66,2018-04-03,placed +94,63,2018-04-03,placed +95,27,2018-04-04,placed +96,90,2018-04-06,placed +97,89,2018-04-07,placed +98,41,2018-04-07,placed +99,85,2018-04-09,placed diff --git a/tests/fixtures/dbt/jaffle_shop_duckdb/seeds/raw_payments.csv b/tests/fixtures/dbt/jaffle_shop_duckdb/seeds/raw_payments.csv new file mode 100644 index 0000000000..a587baab59 --- /dev/null +++ b/tests/fixtures/dbt/jaffle_shop_duckdb/seeds/raw_payments.csv @@ -0,0 +1,114 @@ +id,order_id,payment_method,amount 
+1,1,credit_card,1000 +2,2,credit_card,2000 +3,3,coupon,100 +4,4,coupon,2500 +5,5,bank_transfer,1700 +6,6,credit_card,600 +7,7,credit_card,1600 +8,8,credit_card,2300 +9,9,gift_card,2300 +10,9,bank_transfer,0 +11,10,bank_transfer,2600 +12,11,credit_card,2700 +13,12,credit_card,100 +14,13,credit_card,500 +15,13,bank_transfer,1400 +16,14,bank_transfer,300 +17,15,coupon,2200 +18,16,credit_card,1000 +19,17,bank_transfer,200 +20,18,credit_card,500 +21,18,credit_card,800 +22,19,gift_card,600 +23,20,bank_transfer,1500 +24,21,credit_card,1200 +25,22,bank_transfer,800 +26,23,gift_card,2300 +27,24,coupon,2600 +28,25,bank_transfer,2000 +29,25,credit_card,2200 +30,25,coupon,1600 +31,26,credit_card,3000 +32,27,credit_card,2300 +33,28,bank_transfer,1900 +34,29,bank_transfer,1200 +35,30,credit_card,1300 +36,31,credit_card,1200 +37,32,credit_card,300 +38,33,credit_card,2200 +39,34,bank_transfer,1500 +40,35,credit_card,2900 +41,36,bank_transfer,900 +42,37,credit_card,2300 +43,38,credit_card,1500 +44,39,bank_transfer,800 +45,40,credit_card,1400 +46,41,credit_card,1700 +47,42,coupon,1700 +48,43,gift_card,1800 +49,44,gift_card,1100 +50,45,bank_transfer,500 +51,46,bank_transfer,800 +52,47,credit_card,2200 +53,48,bank_transfer,300 +54,49,credit_card,600 +55,49,credit_card,900 +56,50,credit_card,2600 +57,51,credit_card,2900 +58,51,credit_card,100 +59,52,bank_transfer,1500 +60,53,credit_card,300 +61,54,credit_card,1800 +62,54,bank_transfer,1100 +63,55,credit_card,2900 +64,56,credit_card,400 +65,57,bank_transfer,200 +66,58,coupon,1800 +67,58,gift_card,600 +68,59,gift_card,2800 +69,60,credit_card,400 +70,61,bank_transfer,1600 +71,62,gift_card,1400 +72,63,credit_card,2900 +73,64,bank_transfer,2600 +74,65,credit_card,0 +75,66,credit_card,2800 +76,67,bank_transfer,400 +77,67,credit_card,1900 +78,68,credit_card,1600 +79,69,credit_card,1900 +80,70,credit_card,2600 +81,71,credit_card,500 +82,72,credit_card,2900 +83,73,bank_transfer,300 +84,74,credit_card,3000 +85,75,credit_card,1900 
+86,76,coupon,200 +87,77,credit_card,0 +88,77,bank_transfer,1900 +89,78,bank_transfer,2600 +90,79,credit_card,1800 +91,79,credit_card,900 +92,80,gift_card,300 +93,81,coupon,200 +94,82,credit_card,800 +95,83,credit_card,100 +96,84,bank_transfer,2500 +97,85,bank_transfer,1700 +98,86,coupon,2300 +99,87,gift_card,3000 +100,87,credit_card,2600 +101,88,credit_card,2900 +102,89,bank_transfer,2200 +103,90,bank_transfer,200 +104,91,credit_card,1900 +105,92,bank_transfer,1500 +106,92,coupon,200 +107,93,gift_card,2600 +108,94,coupon,700 +109,95,coupon,2400 +110,96,gift_card,1700 +111,97,bank_transfer,1400 +112,98,bank_transfer,1000 +113,99,credit_card,2400 diff --git a/tests/fixtures/dbt/sushi_test/analyses/waiter_performance_analysis.sql b/tests/fixtures/dbt/sushi_test/analyses/waiter_performance_analysis.sql new file mode 100644 index 0000000000..06d8bdb8fd --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/analyses/waiter_performance_analysis.sql @@ -0,0 +1,10 @@ +-- Simple analysis: Top performing waiters by total revenue + +SELECT + waiter_id, + SUM(revenue) AS total_revenue, + COUNT(*) AS days_worked +FROM {{ ref('waiter_revenue_by_day') }} +GROUP BY waiter_id +ORDER BY total_revenue DESC +LIMIT 10 \ No newline at end of file diff --git a/tests/fixtures/dbt/sushi_test/config.py b/tests/fixtures/dbt/sushi_test/config.py index 6f7f94530f..a68b3e2333 100644 --- a/tests/fixtures/dbt/sushi_test/config.py +++ b/tests/fixtures/dbt/sushi_test/config.py @@ -1,21 +1,29 @@ from pathlib import Path -from sqlmesh.core.config import AirflowSchedulerConfig, ModelDefaultsConfig +from sqlmesh.core.config import ModelDefaultsConfig from sqlmesh.dbt.loader import sqlmesh_config -variables = {"start": "Jan 1 2022"} - config = sqlmesh_config( - Path(__file__).parent, variables=variables, model_defaults=ModelDefaultsConfig(dialect="duckdb") + Path(__file__).parent, model_defaults=ModelDefaultsConfig(dialect="duckdb", start="Jan 1 2022") ) test_config = config -airflow_config = sqlmesh_config( 
+test_config_with_var_override = sqlmesh_config( + Path(__file__).parent, + model_defaults=ModelDefaultsConfig(dialect="duckdb", start="Jan 1 2022"), + variables={ + "some_var": "overridden_from_config_py", + }, +) + + +test_config_with_normalization_strategy = sqlmesh_config( Path(__file__).parent, - default_scheduler=AirflowSchedulerConfig(), - variables=variables, + model_defaults=ModelDefaultsConfig( + dialect="duckdb,normalization_strategy=LOWERCASE", start="Jan 1 2022" + ), ) diff --git a/tests/fixtures/dbt/sushi_test/dbt_packages/my_macros/dbt_project.yml b/tests/fixtures/dbt/sushi_test/dbt_packages/my_macros/dbt_project.yml new file mode 100644 index 0000000000..f0386b4e57 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/dbt_packages/my_macros/dbt_project.yml @@ -0,0 +1,17 @@ +name: 'my_macros' +version: '1.0.0' +config-version: 2 + +profile: 'my_macros' + +model-paths: ["models"] +analysis-paths: ["analyses"] +test-paths: ["tests"] +seed-paths: ["seeds"] +macro-paths: ["macros"] +snapshot-paths: ["snapshots"] + +target-path: "target" +clean-targets: + - "target" + - "dbt_packages" diff --git a/tests/fixtures/dbt/sushi_test/dbt_packages/my_macros/macros/log_value_alt.sql b/tests/fixtures/dbt/sushi_test/dbt_packages/my_macros/macros/log_value_alt.sql new file mode 100644 index 0000000000..a88f316d3e --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/dbt_packages/my_macros/macros/log_value_alt.sql @@ -0,0 +1,3 @@ +{% macro log_value_alt(v) %} + {{ log("Entered value is: " ~ v) }} +{% endmacro %} diff --git a/tests/fixtures/dbt/sushi_test/dbt_packages/my_package/dbt_project.yml b/tests/fixtures/dbt/sushi_test/dbt_packages/my_package/dbt_project.yml new file mode 100644 index 0000000000..9c4797ebe0 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/dbt_packages/my_package/dbt_project.yml @@ -0,0 +1,17 @@ +name: 'my_project' +version: '1.0.0' +config-version: 2 + +profile: 'my_project' + +model-paths: ["models"] +analysis-paths: ["analyses"] +test-paths: 
["tests"] +seed-paths: ["seeds"] +macro-paths: ["macros"] +snapshot-paths: ["snapshots"] + +target-path: "target" +clean-targets: + - "target" + - "dbt_packages" diff --git a/tests/fixtures/dbt/sushi_test/dbt_packages/my_package/models/dummy_model.sql b/tests/fixtures/dbt/sushi_test/dbt_packages/my_package/models/dummy_model.sql new file mode 100644 index 0000000000..ff815fe55d --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/dbt_packages/my_package/models/dummy_model.sql @@ -0,0 +1,4 @@ +{{ log_value_alt(1) }} + +SELECT + 1 AS c diff --git a/tests/fixtures/dbt/sushi_test/dbt_project.yml b/tests/fixtures/dbt/sushi_test/dbt_project.yml index 1c72288c80..0b5f6b0f83 100644 --- a/tests/fixtures/dbt/sushi_test/dbt_project.yml +++ b/tests/fixtures/dbt/sushi_test/dbt_project.yml @@ -8,7 +8,10 @@ model-paths: ["models"] analysis-paths: ["analyses"] test-paths: ["tests"] seed-paths: ["seeds"] -macro-paths: ["macros"] +macro-paths: [ + "macros", + "dbt_packages/my_macros/macros", +] snapshot-paths: ["snapshots"] target-path: "target" # directory which will store compiled SQL files @@ -20,31 +23,59 @@ clean-targets: # directories to be removed by `dbt clean` # Full documentation: https://docs.getdbt.com/docs/configuring-models models: - +start: "{{ var('start') }}" sushi: +materialized: table +pre-hook: - '{{ log("pre-hook") }}' - +post-hook: + +post-hook: - '{{ log("post-hook") }}' seeds: sushi: +pre-hook: - '{{ log("pre-hook") }}' - +post-hook: + +post-hook: - '{{ log("post-hook") }}' +sources: + +quoting: + identifier: false + vars: - top_waiters:limit: 10 + top_waiters:limit: "{{ get_top_waiters_limit() }}" 'top_waiters:revenue': "revenue" # The following are only used for testing purposes customers:boo: ["a", "b"] yet_another_var: 1 + dynamic_test_var: 3 + some_var: 'should be overridden in customers package' customers: some_var: ["foo", "bar"] 'customers:bla': false 'customers:customer_id': "customer_id" + + nested_vars: + some_nested_var: 2 + + list_var: + - name: 
'item1' + value: 1 + - name: 'item2' + value: 2 + + # Despite this being an invalid variable definition, dbt doesn't mind if it's unused + invalid_var: "{{ ref('ref_without_closing_paren' }}" + + +on-run-start: + - 'CREATE TABLE IF NOT EXISTS analytic_stats (physical_table VARCHAR, evaluation_time VARCHAR);' + - 'CREATE TABLE IF NOT EXISTS to_be_executed_last (col VARCHAR);' + - SELECT {{ var("yet_another_var") }} AS var, '{{ source("raw", "items").identifier }}' AS src, '{{ ref("waiters").identifier }}' AS ref; + - "{{ log_value('on-run-start') }}" +on-run-end: + - '{{ create_tables(schemas) }}' + - 'DROP TABLE to_be_executed_last;' + - '{{ graph_usage() }}' diff --git a/tests/fixtures/dbt/sushi_test/macros/check_model_is_table.sql b/tests/fixtures/dbt/sushi_test/macros/check_model_is_table.sql new file mode 100644 index 0000000000..42dc5615e4 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/macros/check_model_is_table.sql @@ -0,0 +1,15 @@ +{%- macro check_model_is_table(model) -%} + {%- if model.config.materialized != 'table' -%} + {%- do exceptions.raise_compiler_error( + "Model must use the table materialization. Please check any model overrides." + ) -%} + {%- endif -%} +{%- endmacro -%} + +{%- macro check_model_is_table_alt(foo) -%} + {%- if foo.config.materialized != 'table' -%} + {%- do exceptions.raise_compiler_error( + "Model must use the table materialization. Please check any model overrides." 
+ ) -%} + {%- endif -%} +{%- endmacro -%} diff --git a/tests/fixtures/dbt/sushi_test/macros/create_tables.sql b/tests/fixtures/dbt/sushi_test/macros/create_tables.sql new file mode 100644 index 0000000000..57616b7389 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/macros/create_tables.sql @@ -0,0 +1,5 @@ +{% macro create_tables(schemas) %} + {% for schema in schemas %} + create or replace table schema_table_{{schema}} as select '{{schema}}' as schema; + {% endfor%} +{% endmacro %} \ No newline at end of file diff --git a/tests/fixtures/dbt/sushi_test/macros/distinct.sql b/tests/fixtures/dbt/sushi_test/macros/distinct.sql new file mode 100644 index 0000000000..1b339a9349 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/macros/distinct.sql @@ -0,0 +1 @@ +{% macro default__select_distinct() %}distinct{% endmacro %} diff --git a/tests/fixtures/dbt/sushi_test/macros/graph_usage.sql b/tests/fixtures/dbt/sushi_test/macros/graph_usage.sql new file mode 100644 index 0000000000..8b133ec280 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/macros/graph_usage.sql @@ -0,0 +1,18 @@ +{% macro graph_usage() %} +{% if execute %} + {% set model_nodes = graph.nodes.values() + | selectattr("resource_type", "equalto", "model") + | list %} + + {% set out = [] %} + {% for node in model_nodes %} + {% set line = "select '" ~ node.unique_id ~ "' as unique_id, '" ~ node.config.materialized ~ "' as materialized" %} + {% do out.append(line) %} + {% endfor %} + + {% if out %} + {% set sql_statement = "create or replace table graph_table as\n" ~ (out | join('\nunion all\n')) %} + {{ return(sql_statement) }} + {% endif %} +{% endif %} +{% endmacro %} diff --git a/tests/fixtures/dbt/sushi_test/macros/insert_hook.sql b/tests/fixtures/dbt/sushi_test/macros/insert_hook.sql new file mode 100644 index 0000000000..aa27a7fe6d --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/macros/insert_hook.sql @@ -0,0 +1,14 @@ +{% macro insert_into_shared_hook_table(hook_name) %} +INSERT INTO shared_hook_table ( + 
id, + hook_name, + execution_order, + created_at +) +VALUES ( + (SELECT COALESCE(MAX(id), 0) + 1 FROM shared_hook_table), + '{{ hook_name }}', + (SELECT COALESCE(MAX(id), 0) + 1 FROM shared_hook_table), + NOW() +) +{% endmacro %} diff --git a/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql b/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql new file mode 100644 index 0000000000..c61899c8ff --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/macros/materializations/custom_incremental.sql @@ -0,0 +1,61 @@ +{%- macro build_incremental_filter_sql(sql, time_column, existing_relation, interval_config) -%} + {# macro to build the filter and also test use of macro inside materialisation #} + WITH source_data AS ( + {{ sql }} + ) + SELECT * FROM source_data + WHERE {{ time_column }} >= ( + SELECT COALESCE(MAX({{ time_column }}), '1900-01-01') + {%- if interval_config %} + INTERVAL {{ interval_config }} {%- endif %} + FROM {{ existing_relation }} + ) +{%- endmacro -%} + +{%- materialization custom_incremental, default -%} + {%- set existing_relation = adapter.get_relation(database=database, schema=schema, identifier=identifier) -%} + {%- set new_relation = api.Relation.create(database=database, schema=schema, identifier=identifier) -%} + {%- set temp_relation = make_temp_relation(new_relation) -%} + + {%- set time_column = config.get('time_column') -%} + {%- set interval_config = config.get('interval') -%} + + {{ run_hooks(pre_hooks) }} + + {%- if existing_relation is none -%} + {# The first insert creates new table if it doesn't exist #} + {%- call statement('main') -%} + CREATE TABLE {{ new_relation }} + AS {{ sql }} + {%- endcall -%} + {%- else -%} + {# Incremental load, appending new data with optional time filtering #} + {%- if time_column is not none -%} + {%- set filtered_sql -%} + {{ build_incremental_filter_sql(sql, time_column, existing_relation, interval_config) }} + {%- endset -%} + {%- else -%} + {%- set 
filtered_sql = sql -%} + {%- endif -%} + + {{log(filtered_sql, info=true)}} + + {%- call statement('create_temp') -%} + {{ create_table_as(True, temp_relation, filtered_sql) }} + CREATE TABLE {{ temp_relation }} + AS {{ filtered_sql }} + {%- endcall -%} + + {%- call statement('insert') -%} + INSERT INTO {{ new_relation }} + SELECT * FROM {{ temp_relation }} + {%- endcall -%} + + {%- call statement('drop_temp') -%} + DROP TABLE {{ temp_relation }} + {%- endcall -%} + {%- endif -%} + + {{ run_hooks(post_hooks) }} + + {{ return({'relations': [new_relation]}) }} +{%- endmaterialization -%} diff --git a/tests/fixtures/dbt/sushi_test/macros/test_dependencies.sql b/tests/fixtures/dbt/sushi_test/macros/test_dependencies.sql index f3440b1007..88518df380 100644 --- a/tests/fixtures/dbt/sushi_test/macros/test_dependencies.sql +++ b/tests/fixtures/dbt/sushi_test/macros/test_dependencies.sql @@ -4,4 +4,11 @@ {% macro nested_test_dependencies() %} {{ log(var("yet_another_var", 2)) }} + {{ log(var("nested_vars")['some_nested_var']) }} +{% endmacro %} + + +{% macro dynamic_var_name_dependency(var_name) %} + {% set results = run_query('select 1 as one') %} + {{ return(var(var_name)) }} {% endmacro %} diff --git a/tests/fixtures/dbt/sushi_test/macros/top_waiters_limit.sql b/tests/fixtures/dbt/sushi_test/macros/top_waiters_limit.sql new file mode 100644 index 0000000000..5fb56335e9 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/macros/top_waiters_limit.sql @@ -0,0 +1,3 @@ +{% macro get_top_waiters_limit() %} +10 +{% endmacro %} diff --git a/tests/fixtures/dbt/sushi_test/models/custom_incremental_model.sql b/tests/fixtures/dbt/sushi_test/models/custom_incremental_model.sql new file mode 100644 index 0000000000..c7e9a8f7ea --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/models/custom_incremental_model.sql @@ -0,0 +1,20 @@ +{{ config( + materialized='custom_incremental', + pre_hook=[ + "CREATE TABLE IF NOT EXISTS hook_table (id INTEGER, length_col TEXT, updated_at TIMESTAMP)" + ], 
+ post_hook=[ + """ + INSERT INTO hook_table + SELECT + COALESCE(MAX(id), 0) + 1 AS id, + '{{ model.raw_code | length }}' AS length_col, + CURRENT_TIMESTAMP AS updated_at + FROM hook_table + """ + ] +) }} + +SELECT + current_timestamp as created_at, + hash(current_timestamp) as id, \ No newline at end of file diff --git a/tests/fixtures/dbt/sushi_test/models/custom_incremental_with_filter.sql b/tests/fixtures/dbt/sushi_test/models/custom_incremental_with_filter.sql new file mode 100644 index 0000000000..94cbdc9333 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/models/custom_incremental_with_filter.sql @@ -0,0 +1,9 @@ +{{ config( + materialized='custom_incremental', + time_column='created_at', + interval='2 day' +) }} + +SELECT + CAST('{{ run_started_at }}' AS TIMESTAMP) as created_at, + hash('{{ run_started_at }}') as id, \ No newline at end of file diff --git a/tests/fixtures/dbt/sushi_test/models/dynamic_graph_model.sql b/tests/fixtures/dbt/sushi_test/models/dynamic_graph_model.sql new file mode 100644 index 0000000000..18da9c3c7b --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/models/dynamic_graph_model.sql @@ -0,0 +1,8 @@ + +{% if execute %} + {% for dataset, nodes in graph.nodes.values() | selectattr("resource_type", "equalto", "model") | groupby('schema') %} + {% if loop.first %} + SELECT 1 AS c + {% endif %} + {% endfor %} +{% endif %} diff --git a/tests/schedulers/airflow/operators/__init__.py b/tests/fixtures/dbt/sushi_test/models/empty_model.sql similarity index 100% rename from tests/schedulers/airflow/operators/__init__.py rename to tests/fixtures/dbt/sushi_test/models/empty_model.sql diff --git a/tests/fixtures/dbt/sushi_test/models/model_with_raw_code.sql b/tests/fixtures/dbt/sushi_test/models/model_with_raw_code.sql new file mode 100644 index 0000000000..1424f6e970 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/models/model_with_raw_code.sql @@ -0,0 +1,11 @@ +{{ + config( + pre_hook=['CREATE TABLE IF NOT EXISTS t AS SELECT \'Length is {{ 
model.raw_code|length }}\' AS length_col'] + ) +}} + +{{ check_model_is_table(model) }} +{{ check_model_is_table_alt(model) }} + +SELECT + 1 AS c diff --git a/tests/fixtures/dbt/sushi_test/models/model_with_transaction_hooks.sql b/tests/fixtures/dbt/sushi_test/models/model_with_transaction_hooks.sql new file mode 100644 index 0000000000..49883f73df --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/models/model_with_transaction_hooks.sql @@ -0,0 +1,56 @@ +{{ + config( + materialized = 'table', + + pre_hook = [ + { + "sql": " + CREATE TABLE IF NOT EXISTS hook_outside_pre_table AS + SELECT + 1 AS id, + 'outside' AS location, + 1 AS execution_order, + NOW() AS created_at + ", + "transaction": false + }, + + before_begin(" + CREATE TABLE IF NOT EXISTS shared_hook_table ( + id INT, + hook_name VARCHAR, + execution_order INT, + created_at TIMESTAMPTZ + ) + "), + + { + "sql": "{{ insert_into_shared_hook_table('inside_pre') }}", + "transaction": true + } + ], + + post_hook = [ + { + "sql": "{{ insert_into_shared_hook_table('inside_post') }}", + "transaction": true + }, + + { + "sql": " + CREATE TABLE IF NOT EXISTS hook_outside_post_table AS + SELECT + 5 AS id, + 'outside' AS location, + 5 AS execution_order, + NOW() AS created_at + ", + "transaction": false + }, + + after_commit("{{ insert_into_shared_hook_table('after_commit') }}") + ] + ) +}} + +SELECT 1 AS id, 'test' AS name; diff --git a/tests/fixtures/dbt/sushi_test/models/non_validated_model.sql b/tests/fixtures/dbt/sushi_test/models/non_validated_model.sql new file mode 100644 index 0000000000..3140c5d723 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/models/non_validated_model.sql @@ -0,0 +1,5 @@ +{{ config(materialized='table') }} + +SELECT + 1 AS c, + 2 AS c, diff --git a/tests/fixtures/dbt/sushi_test/models/schema.yml b/tests/fixtures/dbt/sushi_test/models/schema.yml index 78d6c3e95f..24b1d4b3ee 100644 --- a/tests/fixtures/dbt/sushi_test/models/schema.yml +++ b/tests/fixtures/dbt/sushi_test/models/schema.yml 
@@ -1,10 +1,80 @@ version: 2 models: + - name: simple_model_a + description: A simple model for testing + columns: + - name: a + data_type: int + unit_tests: + - name: test_simple_model_a_outputs_one + description: Test that simple_model_a outputs 1 as column a + model: simple_model_a + given: [] # No input models needed + expect: + format: csv + rows: | + a + 1 + - name: simple_model_b + description: Model that references simple_model_a + columns: + - name: a + data_type: int + unit_tests: + - name: test_simple_model_b_with_mock_input + description: Test simple_model_b with mocked simple_model_a input + model: simple_model_b + given: + - input: ref('simple_model_a') + format: csv + rows: | + a + 10 + 20 + 30 + expect: + format: csv + rows: | + a + 10 + 20 + 30 + - name: test_simple_model_b_with_sql_input + description: Test simple_model_b with SQL-defined input data + model: simple_model_b + given: + - input: ref('simple_model_a') + format: sql + rows: SELECT 42 AS a + expect: + format: sql + rows: SELECT 42 AS a - name: top_waiters + description: description of top waiters + columns: + - name: waiter_id + data_type: int + - name: revenue + data_type: double + - name: model_columns + data_type: int + freshness: + warn_after: {count: 6, period: hour} + error_after: {count: 7, period: hour} config: dialect: postgres + freshness: + warn_after: {count: 8, period: hour} + error_after: {count: 9, period: hour} - name: waiters + description: '{{ doc("waiters") }}' + config: + # Exercise pre and post hooks + pre_hook: + - SELECT 1 + post_hook: + - SELECT 1 - name: waiter_as_customer_by_day - name: waiter_revenue_by_day versions: @@ -21,5 +91,45 @@ sources: schema: raw tables: - name: items + config: + meta: - name: orders + config: + meta: - name: order_items + config: + meta: + freshness: + warn_after: {count: 10, period: hour} + error_after: {count: 11, period: hour} + config: + freshness: + warn_after: {count: 12, period: hour} + error_after: {count: 13, period: hour} 
+ + - name: parquet_file + meta: + external_location: "read_parquet('path/to/external/{name}.parquet')" + tables: + - name: items + - name: orders + +semantic_models: + - name: top_waiters + description: Some description + model: ref('top_waiters') + measures: + - name: total_waiters + agg: sum + expr: waiter + dimensions: + - name: waiter + type: categorical + +metrics: + - name: some_waiter_thing + description: Something + type: simple + label: testing + type_params: + measure: total_waiters diff --git a/tests/fixtures/dbt/sushi_test/models/top_waiters.sql b/tests/fixtures/dbt/sushi_test/models/top_waiters.sql index c5b7948dfd..ce7e2154c5 100644 --- a/tests/fixtures/dbt/sushi_test/models/top_waiters.sql +++ b/tests/fixtures/dbt/sushi_test/models/top_waiters.sql @@ -1,13 +1,18 @@ {{ - config( - materialized='view', - limit_value=var('top_waiters:limit'), - ) + config({ + 'materialized': 'view', + 'limit_value': var('top_waiters:limit'), + 'meta': {'owner': 'analytics_team', 'priority': 'high'} + }) }} +{% set columns = model.columns %} +{% set config = model.config %} + SELECT waiter_id::INT AS waiter_id, - revenue::DOUBLE AS {{ var("top_waiters:revenue") }} + revenue::DOUBLE AS {{ var("top_waiters:revenue") }}, + {{ columns | length }} AS model_columns FROM {{ ref('sushi', 'waiter_revenue_by_day') }} WHERE ds = ( diff --git a/tests/fixtures/dbt/sushi_test/models/waiter_as_customer_by_day.sql b/tests/fixtures/dbt/sushi_test/models/waiter_as_customer_by_day.sql index 82237634b5..ed845b67cb 100644 --- a/tests/fixtures/dbt/sushi_test/models/waiter_as_customer_by_day.sql +++ b/tests/fixtures/dbt/sushi_test/models/waiter_as_customer_by_day.sql @@ -1,7 +1,7 @@ {{ config( materialized='incremental', - incremental_strategy='delete+insert', + incremental_strategy='incremental_by_time_range', cluster_by=['ds'], time_column='ds', ) diff --git a/tests/fixtures/dbt/sushi_test/models/waiter_revenue_by_day.sql b/tests/fixtures/dbt/sushi_test/models/waiter_revenue_by_day.sql 
index 335e7ab799..2731b07019 100644 --- a/tests/fixtures/dbt/sushi_test/models/waiter_revenue_by_day.sql +++ b/tests/fixtures/dbt/sushi_test/models/waiter_revenue_by_day.sql @@ -1,7 +1,7 @@ {{ config( materialized='incremental', - incremental_strategy='delete+insert', + incremental_strategy='incremental_by_time_range', cluster_by=['ds'], time_column='ds', dialect="bigquery" @@ -13,7 +13,8 @@ {{ test_dependencies() }} -{% set results = run_query('select 1 as constant') %} +{% set var_name = "dynamic_" + "test_" + "var" %} +{% set results = run_query('select ' ~ dynamic_var_name_dependency(var_name) ~ ' as constant') %} SELECT o.waiter_id::INT AS waiter_id, /* Waiter id */ @@ -29,11 +30,7 @@ LEFT JOIN {{ source('streaming', 'items') }} AS i ON oi.item_id = i.id AND oi.ds = i.ds {% if is_incremental() %} WHERE - o.ds > (select max(ds) from {{ this }}) -{% endif %} -{% if sqlmesh_incremental is defined %} - WHERE - o.ds BETWEEN '{{ start_ds }}' AND '{{ end_ds }}' + o.ds > (select CAST(max(ds) AS DATE) from {{ this }}) {% endif %} GROUP BY o.waiter_id, diff --git a/tests/fixtures/dbt/sushi_test/models/waiter_revenue_by_day_v1.sql b/tests/fixtures/dbt/sushi_test/models/waiter_revenue_by_day_v1.sql index 335e7ab799..e229dc8b91 100644 --- a/tests/fixtures/dbt/sushi_test/models/waiter_revenue_by_day_v1.sql +++ b/tests/fixtures/dbt/sushi_test/models/waiter_revenue_by_day_v1.sql @@ -1,7 +1,7 @@ {{ config( materialized='incremental', - incremental_strategy='delete+insert', + incremental_strategy='incremental_by_time_range', cluster_by=['ds'], time_column='ds', dialect="bigquery" diff --git a/tests/fixtures/dbt/sushi_test/models/waiters_doc_block.md b/tests/fixtures/dbt/sushi_test/models/waiters_doc_block.md new file mode 100644 index 0000000000..99d1582c91 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/models/waiters_doc_block.md @@ -0,0 +1,4 @@ +{% docs waiters %} +waiters docs block +{% enddocs %} + diff --git 
a/tests/fixtures/dbt/sushi_test/packages/customers/dbt_project.yml b/tests/fixtures/dbt/sushi_test/packages/customers/dbt_project.yml index 7b09f72a45..c7b89da8f0 100644 --- a/tests/fixtures/dbt/sushi_test/packages/customers/dbt_project.yml +++ b/tests/fixtures/dbt/sushi_test/packages/customers/dbt_project.yml @@ -30,3 +30,11 @@ vars: some_other_var: 5 yet_another_var: 5 'customers:customer_id': "bla" + + +on-run-start: + - 'CREATE TABLE IF NOT EXISTS to_be_executed_first (col VARCHAR);' + - 'CREATE TABLE IF NOT EXISTS analytic_stats_packaged_project (physical_table VARCHAR, evaluation_time VARCHAR);' +on-run-end: + - 'DROP TABLE to_be_executed_first' + - '{{ packaged_tables(schemas) }}' \ No newline at end of file diff --git a/tests/fixtures/dbt/sushi_test/packages/customers/macros/packaged_tables.sql b/tests/fixtures/dbt/sushi_test/packages/customers/macros/packaged_tables.sql new file mode 100644 index 0000000000..51ce04f06d --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/packages/customers/macros/packaged_tables.sql @@ -0,0 +1,5 @@ +{% macro packaged_tables(schemas) %} + {% for schema in schemas %} + create or replace table schema_table_{{schema}}_nested_package as select '{{schema}}' as schema; + {% endfor%} +{% endmacro %} \ No newline at end of file diff --git a/tests/fixtures/dbt/sushi_test/profiles.yml b/tests/fixtures/dbt/sushi_test/profiles.yml index 056c3c2b91..f49ad8ea0f 100644 --- a/tests/fixtures/dbt/sushi_test/profiles.yml +++ b/tests/fixtures/dbt/sushi_test/profiles.yml @@ -3,6 +3,7 @@ sushi: in_memory: type: duckdb schema: sushi + database: memory duckdb: type: duckdb path: 'local.duckdb' diff --git a/tests/fixtures/dbt/sushi_test/seeds/properties.yml b/tests/fixtures/dbt/sushi_test/seeds/properties.yml index f370c1962f..480447f1ef 100644 --- a/tests/fixtures/dbt/sushi_test/seeds/properties.yml +++ b/tests/fixtures/dbt/sushi_test/seeds/properties.yml @@ -2,3 +2,13 @@ version: 2 seeds: - name: waiter_names + - name: waiter_revenue_semicolon + 
config: + delimiter: ";" + columns: + - name: waiter_id + data_type: int + - name: revenue + data_type: decimal + - name: quarter + data_type: text diff --git a/tests/fixtures/dbt/sushi_test/seeds/waiter_revenue_semicolon.csv b/tests/fixtures/dbt/sushi_test/seeds/waiter_revenue_semicolon.csv new file mode 100644 index 0000000000..df477a3ed4 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/seeds/waiter_revenue_semicolon.csv @@ -0,0 +1,7 @@ +waiter_id;revenue;quarter +1;100.50;Q1 +2;200.75;Q1 +3;150.25;Q1 +1;125.00;Q2 +2;225.50;Q2 +3;175.75;Q2 \ No newline at end of file diff --git a/tests/fixtures/dbt/sushi_test/snapshots/items_check_snapshot.sql b/tests/fixtures/dbt/sushi_test/snapshots/items_check_snapshot.sql deleted file mode 100644 index fdda412e7f..0000000000 --- a/tests/fixtures/dbt/sushi_test/snapshots/items_check_snapshot.sql +++ /dev/null @@ -1,15 +0,0 @@ -{% snapshot items_check_snapshot %} - -{{ - config( - target_schema='snapshots', - unique_key='id', - strategy='check', - check_cols=['ds'], - invalidate_hard_deletes=True, - ) -}} - -select * from {{ source('streaming', 'items') }} - -{% endsnapshot %} diff --git a/tests/fixtures/dbt/sushi_test/snapshots/items_no_hard_delete_snapshot.sql b/tests/fixtures/dbt/sushi_test/snapshots/items_no_hard_delete_snapshot.sql index eb14c8d0a0..6d6292b951 100644 --- a/tests/fixtures/dbt/sushi_test/snapshots/items_no_hard_delete_snapshot.sql +++ b/tests/fixtures/dbt/sushi_test/snapshots/items_no_hard_delete_snapshot.sql @@ -6,6 +6,7 @@ unique_key='id', strategy='timestamp', updated_at='ds', + on_schema_change='sync_all_columns', ) }} diff --git a/tests/fixtures/dbt/sushi_test/snapshots/items_snapshot.sql b/tests/fixtures/dbt/sushi_test/snapshots/items_snapshot.sql deleted file mode 100644 index 6cce99d7de..0000000000 --- a/tests/fixtures/dbt/sushi_test/snapshots/items_snapshot.sql +++ /dev/null @@ -1,15 +0,0 @@ -{% snapshot items_snapshot %} - -{{ - config( - target_schema='snapshots', - unique_key='id', - 
strategy='timestamp', - updated_at='ds', - invalidate_hard_deletes=True, - ) -}} - -select * from {{ source('streaming', 'items') }} - -{% endsnapshot %} diff --git a/tests/fixtures/dbt/sushi_test/snapshots/items_snapshots.sql b/tests/fixtures/dbt/sushi_test/snapshots/items_snapshots.sql new file mode 100644 index 0000000000..fbce585edf --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/snapshots/items_snapshots.sql @@ -0,0 +1,48 @@ +{% snapshot items_snapshot %} + +{{ + config( + target_schema='snapshots', + unique_key='id', + strategy='timestamp', + updated_at='ds', + invalidate_hard_deletes=True, + on_schema_change='sync_all_columns', + ) +}} + +select * from {{ source('streaming', 'items') }} + +{% endsnapshot %} + +{% snapshot items_check_snapshot %} + +{{ + config( + target_schema='snapshots', + unique_key='id', + strategy='check', + check_cols=['ds'], + invalidate_hard_deletes=True, + ) +}} + +select * from {{ source('streaming', 'items') }} + +{% endsnapshot %} + +{% snapshot items_check_with_cast_snapshot %} + +{{ + config( + target_schema='snapshots', + unique_key='id', + strategy='check', + check_cols=['ds::DATE'], + invalidate_hard_deletes=True, + ) +}} + +select * from {{ source('streaming', 'items') }} + +{% endsnapshot %} diff --git a/tests/fixtures/dbt/sushi_test/tests/test_top_waiters.sql b/tests/fixtures/dbt/sushi_test/tests/test_top_waiters.sql new file mode 100644 index 0000000000..db6233db07 --- /dev/null +++ b/tests/fixtures/dbt/sushi_test/tests/test_top_waiters.sql @@ -0,0 +1,4 @@ +-- Check that revenue is positive +SELECT waiter_id +FROM {{ ref('top_waiters') }} +WHERE revenue < 0 diff --git a/tests/fixtures/migrations/environments.json b/tests/fixtures/migrations/environments.json index e841e38463..cbe4945863 100644 --- a/tests/fixtures/migrations/environments.json +++ b/tests/fixtures/migrations/environments.json @@ -1 +1 @@ -{"name":{"0":"staging","1":"dev"},"snapshots":{"0":"[{\"name\": \"sushi.waiter_as_customer_by_day\", 
\"fingerprint\": {\"data_hash\": \"486172035\", \"metadata_hash\": \"1992853678\", \"parent_data_hash\": \"2154574190\", \"parent_metadata_hash\": \"1349779748\"}, \"version\": \"1267397572\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"sushi.waiters\", \"identifier\": \"3386889721\"}, {\"name\": \"sushi.waiter_names\", \"identifier\": \"3233103305\"}, {\"name\": \"sushi.customers\", \"identifier\": \"3148897116\"}, {\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"previous_versions\": [], \"is_materialized\": true, \"is_embedded_kind\": false}, {\"name\": \"sushi.waiter_revenue_by_day\", \"fingerprint\": {\"data_hash\": \"2443934302\", \"metadata_hash\": \"2904050331\", \"parent_data_hash\": \"764310396\", \"parent_metadata_hash\": \"3147731239\"}, \"version\": \"2695875565\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"sushi.order_items\", \"identifier\": \"1806777563\"}, {\"name\": \"sushi.items\", \"identifier\": \"2957171338\"}, {\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"previous_versions\": [], \"is_materialized\": true, \"is_embedded_kind\": false}, {\"name\": \"sushi.top_waiters\", \"fingerprint\": {\"data_hash\": \"2891807529\", \"metadata_hash\": \"3392493998\", \"parent_data_hash\": \"1940707936\", \"parent_metadata_hash\": \"1276363398\"}, \"version\": \"3010914162\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"sushi.waiter_revenue_by_day\", \"identifier\": \"1609279380\"}], \"previous_versions\": [], \"is_materialized\": false, \"is_embedded_kind\": false}, {\"name\": \"sushi.waiters\", \"fingerprint\": {\"data_hash\": \"3501061139\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"777615193\", \"parent_metadata_hash\": \"2042613269\"}, \"version\": \"2059227798\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"previous_versions\": [], \"is_materialized\": false, \"is_embedded_kind\": true}, 
{\"name\": \"sushi.customers\", \"fingerprint\": {\"data_hash\": \"3553985282\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"777615193\", \"parent_metadata_hash\": \"2042613269\"}, \"version\": \"2359719298\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"previous_versions\": [], \"is_materialized\": true, \"is_embedded_kind\": false}, {\"name\": \"sushi.waiter_names\", \"fingerprint\": {\"data_hash\": \"1876476880\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"version\": \"2505706914\", \"physical_schema\": \"sqlmesh\", \"parents\": [], \"previous_versions\": [], \"is_materialized\": true, \"is_embedded_kind\": false}, {\"name\": \"sushi.customer_revenue_by_day\", \"fingerprint\": {\"data_hash\": \"2657552867\", \"metadata_hash\": \"129771006\", \"parent_data_hash\": \"764310396\", \"parent_metadata_hash\": \"3147731239\"}, \"version\": \"1291364031\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"sushi.order_items\", \"identifier\": \"1806777563\"}, {\"name\": \"sushi.items\", \"identifier\": \"2957171338\"}, {\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"previous_versions\": [], \"is_materialized\": true, \"is_embedded_kind\": false}, {\"name\": \"sushi.items\", \"fingerprint\": {\"data_hash\": \"1960378930\", \"metadata_hash\": \"2900807542\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"version\": \"312608270\", \"physical_schema\": \"sqlmesh\", \"parents\": [], \"previous_versions\": [], \"is_materialized\": true, \"is_embedded_kind\": false}, {\"name\": \"sushi.order_items\", \"fingerprint\": {\"data_hash\": \"653664599\", \"metadata_hash\": \"1960934702\", \"parent_data_hash\": \"3170724558\", \"parent_metadata_hash\": \"867324801\"}, \"version\": \"1015284155\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"sushi.items\", \"identifier\": \"2957171338\"}, 
{\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"previous_versions\": [], \"is_materialized\": true, \"is_embedded_kind\": false}, {\"name\": \"sushi.orders\", \"fingerprint\": {\"data_hash\": \"1628439771\", \"metadata_hash\": \"2745052130\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"version\": \"925846788\", \"physical_schema\": \"sqlmesh\", \"parents\": [], \"previous_versions\": [], \"is_materialized\": true, \"is_embedded_kind\": false}]","1":"[{\"name\": \"sushi.waiter_as_customer_by_day\", \"fingerprint\": {\"data_hash\": \"486172035\", \"metadata_hash\": \"1992853678\", \"parent_data_hash\": \"2824767713\", \"parent_metadata_hash\": \"1349779748\"}, \"version\": \"3668757715\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"sushi.waiters\", \"identifier\": \"3386889721\"}, {\"name\": \"sushi.waiter_names\", \"identifier\": \"1604207722\"}, {\"name\": \"sushi.customers\", \"identifier\": \"3148897116\"}, {\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"486172035\", \"metadata_hash\": \"1992853678\", \"parent_data_hash\": \"2154574190\", \"parent_metadata_hash\": \"1349779748\"}, \"version\": \"1267397572\"}], \"is_materialized\": true, \"is_embedded_kind\": false}, {\"name\": \"sushi.waiter_revenue_by_day\", \"fingerprint\": {\"data_hash\": \"2443934302\", \"metadata_hash\": \"2904050331\", \"parent_data_hash\": \"764310396\", \"parent_metadata_hash\": \"3147731239\"}, \"version\": \"2695875565\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"sushi.order_items\", \"identifier\": \"1806777563\"}, {\"name\": \"sushi.items\", \"identifier\": \"2957171338\"}, {\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"previous_versions\": [], \"is_materialized\": true, \"is_embedded_kind\": false}, {\"name\": \"sushi.top_waiters\", \"fingerprint\": {\"data_hash\": \"2891807529\", \"metadata_hash\": \"3392493998\", 
\"parent_data_hash\": \"1940707936\", \"parent_metadata_hash\": \"1276363398\"}, \"version\": \"3010914162\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"sushi.waiter_revenue_by_day\", \"identifier\": \"1609279380\"}], \"previous_versions\": [], \"is_materialized\": false, \"is_embedded_kind\": false}, {\"name\": \"sushi.waiters\", \"fingerprint\": {\"data_hash\": \"3501061139\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"777615193\", \"parent_metadata_hash\": \"2042613269\"}, \"version\": \"2059227798\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"previous_versions\": [], \"is_materialized\": false, \"is_embedded_kind\": true}, {\"name\": \"sushi.customers\", \"fingerprint\": {\"data_hash\": \"3553985282\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"777615193\", \"parent_metadata_hash\": \"2042613269\"}, \"version\": \"2359719298\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"previous_versions\": [], \"is_materialized\": true, \"is_embedded_kind\": false}, {\"name\": \"sushi.waiter_names\", \"fingerprint\": {\"data_hash\": \"4133862560\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"version\": \"1204702829\", \"physical_schema\": \"sqlmesh\", \"parents\": [], \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"1876476880\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"version\": \"2505706914\"}], \"change_category\": 1, \"is_materialized\": true, \"is_embedded_kind\": false}, {\"name\": \"sushi.customer_revenue_by_day\", \"fingerprint\": {\"data_hash\": \"2657552867\", \"metadata_hash\": \"129771006\", \"parent_data_hash\": \"764310396\", \"parent_metadata_hash\": \"3147731239\"}, \"version\": \"1291364031\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": 
\"sushi.order_items\", \"identifier\": \"1806777563\"}, {\"name\": \"sushi.items\", \"identifier\": \"2957171338\"}, {\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"previous_versions\": [], \"is_materialized\": true, \"is_embedded_kind\": false}, {\"name\": \"sushi.items\", \"fingerprint\": {\"data_hash\": \"1960378930\", \"metadata_hash\": \"2900807542\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"version\": \"312608270\", \"physical_schema\": \"sqlmesh\", \"parents\": [], \"previous_versions\": [], \"is_materialized\": true, \"is_embedded_kind\": false}, {\"name\": \"sushi.order_items\", \"fingerprint\": {\"data_hash\": \"653664599\", \"metadata_hash\": \"1960934702\", \"parent_data_hash\": \"3170724558\", \"parent_metadata_hash\": \"867324801\"}, \"version\": \"1015284155\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"sushi.items\", \"identifier\": \"2957171338\"}, {\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"previous_versions\": [], \"is_materialized\": true, \"is_embedded_kind\": false}, {\"name\": \"sushi.orders\", \"fingerprint\": {\"data_hash\": \"1628439771\", \"metadata_hash\": \"2745052130\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"version\": \"925846788\", \"physical_schema\": \"sqlmesh\", \"parents\": [], \"previous_versions\": [], \"is_materialized\": true, \"is_embedded_kind\": false}]"},"start_at":{"0":"2023-01-01","1":"2023-01-01"},"end_at":{"0":"2023-01-07","1":"2023-01-07"},"plan_id":{"0":"2b16ff4b77dc44789b628b4a8a4ed38a","1":"d5dcc7aafce742aab763331525196613"},"previous_plan_id":{"0":null,"1":"79f4bab2177b495ab877b674bc511f2b"},"expiration_ts":{"0":1681419197966,"1":1681419273635}} \ No newline at end of file +{"name":{"0":"staging","1":"dev"},"snapshots":{"0":"[{\"name\": \"\\\"sushi\\\".\\\"waiter_as_customer_by_day\\\"\", \"temp_version\": \"1267397572\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"849558693\", 
\"metadata_hash\": \"2088684978\", \"parent_data_hash\": \"2705906012\", \"parent_metadata_hash\": \"665080906\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"486172035\", \"metadata_hash\": \"1992853678\", \"parent_data_hash\": \"2154574190\", \"parent_metadata_hash\": \"1349779748\"}, \"version\": \"1267397572\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"1267397572\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"waiter_names\\\"\", \"identifier\": \"1609854746\"}, {\"name\": \"\\\"sushi\\\".\\\"waiters\\\"\", \"identifier\": \"4123940212\"}, {\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"1250207606\"}, {\"name\": \"\\\"sushi\\\".\\\"customers\\\"\", \"identifier\": \"1461038955\"}], \"kind_name\": \"INCREMENTAL_BY_TIME_RANGE\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"waiter_revenue_by_day\\\"\", \"temp_version\": \"2695875565\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"2224089837\", \"metadata_hash\": \"2504236462\", \"parent_data_hash\": \"2738168331\", \"parent_metadata_hash\": \"1795276494\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"2443934302\", \"metadata_hash\": \"2904050331\", \"parent_data_hash\": \"764310396\", \"parent_metadata_hash\": \"3147731239\"}, \"version\": \"2695875565\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"2695875565\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"items\\\"\", \"identifier\": \"3721860967\"}, {\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"1250207606\"}, {\"name\": \"\\\"sushi\\\".\\\"order_items\\\"\", \"identifier\": \"1422946820\"}], \"kind_name\": \"INCREMENTAL_BY_TIME_RANGE\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"top_waiters\\\"\", \"temp_version\": \"3010914162\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"4131026946\", \"metadata_hash\": 
\"154190563\", \"parent_data_hash\": \"929243525\", \"parent_metadata_hash\": \"2366450878\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"2891807529\", \"metadata_hash\": \"3392493998\", \"parent_data_hash\": \"1940707936\", \"parent_metadata_hash\": \"1276363398\"}, \"version\": \"3010914162\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"3010914162\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"waiter_revenue_by_day\\\"\", \"identifier\": \"2175947464\"}], \"kind_name\": \"VIEW\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"waiters\\\"\", \"temp_version\": \"2059227798\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"2037801255\", \"metadata_hash\": \"3063653103\", \"parent_data_hash\": \"458609840\", \"parent_metadata_hash\": \"2007040660\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"3501061139\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"777615193\", \"parent_metadata_hash\": \"2042613269\"}, \"version\": \"2059227798\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"2059227798\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"1250207606\"}], \"kind_name\": \"EMBEDDED\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"customers\\\"\", \"temp_version\": \"2359719298\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"2431070412\", \"metadata_hash\": \"3063653103\", \"parent_data_hash\": \"458609840\", \"parent_metadata_hash\": \"2007040660\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"3553985282\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"777615193\", \"parent_metadata_hash\": \"2042613269\"}, \"version\": \"2359719298\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"2359719298\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": 
\"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"1250207606\"}], \"kind_name\": \"FULL\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"waiter_names\\\"\", \"temp_version\": \"2505706914\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"3604872020\", \"metadata_hash\": \"3468846895\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"1876476880\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"version\": \"2505706914\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"2505706914\", \"physical_schema\": \"sqlmesh\", \"parents\": [], \"kind_name\": \"SEED\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"customer_revenue_by_day\\\"\", \"temp_version\": \"1291364031\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"131732542\", \"metadata_hash\": \"1368842087\", \"parent_data_hash\": \"2738168331\", \"parent_metadata_hash\": \"1795276494\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"2657552867\", \"metadata_hash\": \"129771006\", \"parent_data_hash\": \"764310396\", \"parent_metadata_hash\": \"3147731239\"}, \"version\": \"1291364031\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"1291364031\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"items\\\"\", \"identifier\": \"3721860967\"}, {\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"1250207606\"}, {\"name\": \"\\\"sushi\\\".\\\"order_items\\\"\", \"identifier\": \"1422946820\"}], \"kind_name\": \"INCREMENTAL_BY_TIME_RANGE\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"items\\\"\", \"temp_version\": \"312608270\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"1862622614\", \"metadata_hash\": \"3651173237\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"previous_versions\": 
[{\"fingerprint\": {\"data_hash\": \"1960378930\", \"metadata_hash\": \"2900807542\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"version\": \"312608270\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"312608270\", \"physical_schema\": \"sqlmesh\", \"parents\": [], \"kind_name\": \"INCREMENTAL_BY_TIME_RANGE\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"order_items\\\"\", \"temp_version\": \"1015284155\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"4010068827\", \"metadata_hash\": \"799196655\", \"parent_data_hash\": \"2342431947\", \"parent_metadata_hash\": \"1746080605\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"653664599\", \"metadata_hash\": \"1960934702\", \"parent_data_hash\": \"3170724558\", \"parent_metadata_hash\": \"867324801\"}, \"version\": \"1015284155\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"1015284155\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"items\\\"\", \"identifier\": \"3721860967\"}, {\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"1250207606\"}], \"kind_name\": \"INCREMENTAL_BY_TIME_RANGE\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"temp_version\": \"925846788\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"1588786367\", \"metadata_hash\": \"1674367104\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"1628439771\", \"metadata_hash\": \"2745052130\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"version\": \"925846788\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"925846788\", \"physical_schema\": \"sqlmesh\", \"parents\": [], \"kind_name\": \"INCREMENTAL_BY_TIME_RANGE\", \"node_type\": \"model\"}]","1":"[{\"name\": \"\\\"sushi\\\".\\\"waiter_as_customer_by_day\\\"\", \"temp_version\": 
\"3668757715\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"1936268024\", \"metadata_hash\": \"2088684978\", \"parent_data_hash\": \"3055854652\", \"parent_metadata_hash\": \"665080906\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"486172035\", \"metadata_hash\": \"1992853678\", \"parent_data_hash\": \"2154574190\", \"parent_metadata_hash\": \"1349779748\"}, \"version\": \"1267397572\"}, {\"fingerprint\": {\"data_hash\": \"486172035\", \"metadata_hash\": \"1992853678\", \"parent_data_hash\": \"2824767713\", \"parent_metadata_hash\": \"1349779748\"}, \"version\": \"3668757715\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"3668757715\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"waiter_names\\\"\", \"identifier\": \"2725136291\"}, {\"name\": \"\\\"sushi\\\".\\\"waiters\\\"\", \"identifier\": \"4123940212\"}, {\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"1250207606\"}, {\"name\": \"\\\"sushi\\\".\\\"customers\\\"\", \"identifier\": \"1461038955\"}], \"kind_name\": \"INCREMENTAL_BY_TIME_RANGE\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"waiter_revenue_by_day\\\"\", \"temp_version\": \"2695875565\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"2224089837\", \"metadata_hash\": \"2504236462\", \"parent_data_hash\": \"2738168331\", \"parent_metadata_hash\": \"1795276494\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"2443934302\", \"metadata_hash\": \"2904050331\", \"parent_data_hash\": \"764310396\", \"parent_metadata_hash\": \"3147731239\"}, \"version\": \"2695875565\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"2695875565\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"items\\\"\", \"identifier\": \"3721860967\"}, {\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"1250207606\"}, {\"name\": \"\\\"sushi\\\".\\\"order_items\\\"\", 
\"identifier\": \"1422946820\"}], \"kind_name\": \"INCREMENTAL_BY_TIME_RANGE\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"top_waiters\\\"\", \"temp_version\": \"3010914162\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"4131026946\", \"metadata_hash\": \"154190563\", \"parent_data_hash\": \"929243525\", \"parent_metadata_hash\": \"2366450878\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"2891807529\", \"metadata_hash\": \"3392493998\", \"parent_data_hash\": \"1940707936\", \"parent_metadata_hash\": \"1276363398\"}, \"version\": \"3010914162\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"3010914162\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"waiter_revenue_by_day\\\"\", \"identifier\": \"2175947464\"}], \"kind_name\": \"VIEW\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"waiters\\\"\", \"temp_version\": \"2059227798\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"2037801255\", \"metadata_hash\": \"3063653103\", \"parent_data_hash\": \"458609840\", \"parent_metadata_hash\": \"2007040660\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"3501061139\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"777615193\", \"parent_metadata_hash\": \"2042613269\"}, \"version\": \"2059227798\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"2059227798\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"1250207606\"}], \"kind_name\": \"EMBEDDED\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"customers\\\"\", \"temp_version\": \"2359719298\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"2431070412\", \"metadata_hash\": \"3063653103\", \"parent_data_hash\": \"458609840\", \"parent_metadata_hash\": \"2007040660\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"3553985282\", \"metadata_hash\": 
\"570478986\", \"parent_data_hash\": \"777615193\", \"parent_metadata_hash\": \"2042613269\"}, \"version\": \"2359719298\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"2359719298\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"1250207606\"}], \"kind_name\": \"FULL\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"waiter_names\\\"\", \"temp_version\": \"1204702829\", \"change_category\": 1, \"fingerprint\": {\"data_hash\": \"1437406487\", \"metadata_hash\": \"3468846895\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"1876476880\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"version\": \"2505706914\"}, {\"fingerprint\": {\"data_hash\": \"4133862560\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"version\": \"1204702829\", \"change_category\": 1, \"physical_schema\": \"sqlmesh\"}], \"version\": \"1204702829\", \"physical_schema\": \"sqlmesh\", \"parents\": [], \"kind_name\": \"SEED\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"customer_revenue_by_day\\\"\", \"temp_version\": \"1291364031\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"131732542\", \"metadata_hash\": \"1368842087\", \"parent_data_hash\": \"2738168331\", \"parent_metadata_hash\": \"1795276494\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"2657552867\", \"metadata_hash\": \"129771006\", \"parent_data_hash\": \"764310396\", \"parent_metadata_hash\": \"3147731239\"}, \"version\": \"1291364031\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"1291364031\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"items\\\"\", \"identifier\": \"3721860967\"}, {\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": 
\"1250207606\"}, {\"name\": \"\\\"sushi\\\".\\\"order_items\\\"\", \"identifier\": \"1422946820\"}], \"kind_name\": \"INCREMENTAL_BY_TIME_RANGE\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"items\\\"\", \"temp_version\": \"312608270\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"1862622614\", \"metadata_hash\": \"3651173237\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"1960378930\", \"metadata_hash\": \"2900807542\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"version\": \"312608270\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"312608270\", \"physical_schema\": \"sqlmesh\", \"parents\": [], \"kind_name\": \"INCREMENTAL_BY_TIME_RANGE\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"order_items\\\"\", \"temp_version\": \"1015284155\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"4010068827\", \"metadata_hash\": \"799196655\", \"parent_data_hash\": \"2342431947\", \"parent_metadata_hash\": \"1746080605\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"653664599\", \"metadata_hash\": \"1960934702\", \"parent_data_hash\": \"3170724558\", \"parent_metadata_hash\": \"867324801\"}, \"version\": \"1015284155\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"1015284155\", \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"items\\\"\", \"identifier\": \"3721860967\"}, {\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"1250207606\"}], \"kind_name\": \"INCREMENTAL_BY_TIME_RANGE\", \"node_type\": \"model\"}, {\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"temp_version\": \"925846788\", \"change_category\": 4, \"fingerprint\": {\"data_hash\": \"1588786367\", \"metadata_hash\": \"1674367104\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"previous_versions\": [{\"fingerprint\": {\"data_hash\": 
\"1628439771\", \"metadata_hash\": \"2745052130\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"version\": \"925846788\", \"change_category\": 4, \"physical_schema\": \"sqlmesh\"}], \"version\": \"925846788\", \"physical_schema\": \"sqlmesh\", \"parents\": [], \"kind_name\": \"INCREMENTAL_BY_TIME_RANGE\", \"node_type\": \"model\"}]"},"start_at":{"0":"2023-01-01 00:00:00","1":"2023-01-01 00:00:00"},"end_at":{"0":"2023-01-07 00:00:00","1":"2023-01-07 00:00:00"},"plan_id":{"0":"2b16ff4b77dc44789b628b4a8a4ed38a","1":"d5dcc7aafce742aab763331525196613"},"previous_plan_id":{"0":null,"1":"79f4bab2177b495ab877b674bc511f2b"},"expiration_ts":{"0":1681419197966,"1":1681419273635},"finalized_ts":{"0":null,"1":null},"promoted_snapshot_ids":{"0":null,"1":null},"suffix_target":{"0":"schema","1":"schema"},"catalog_name_override":{"0":null,"1":null},"previous_finalized_snapshots":{"0":null,"1":null},"normalize_name":{"0":false,"1":false},"requirements":{"0":"{}","1":"{}"}} \ No newline at end of file diff --git a/tests/fixtures/migrations/intervals.json b/tests/fixtures/migrations/intervals.json new file mode 100644 index 0000000000..276fdd60de --- /dev/null +++ b/tests/fixtures/migrations/intervals.json @@ -0,0 +1 @@ 
+{"id":{"0":"1a1121bc700040d8af4f78ad96e025f1","1":"b901107d2ede4f50be32090eb2559d1a","2":"b366e44fd5e541008cb987a503b5ed7a","3":"ccbcd24427ac432da53fa158313ad800","4":"4fd6bdae011c4978aac8eb5a47521753","5":"d8549fb5f3674b29b4aa2b9988a42052","6":"3f8120d2a2c74f3baca25172537a7788","7":"f417d94c20e44dc5b1a0c29478672ac4","8":"6fd67cfbfcc743c8a87c32a95431c079","9":"b5a8f45c901e4c97aa634eb3ee5f521e","10":"46c7fdaccfd84ba68d766021d7d76511"},"created_ts":{"0":1757115220259,"1":1757115220259,"2":1757115220259,"3":1757115220259,"4":1757115220259,"5":1757115220259,"6":1757115220259,"7":1757115220259,"8":1757115220259,"9":1757115220260,"10":1757115220260},"name":{"0":"\"sushi\".\"waiter_as_customer_by_day\"","1":"\"sushi\".\"waiter_revenue_by_day\"","2":"\"sushi\".\"top_waiters\"","3":"\"sushi\".\"customers\"","4":"\"sushi\".\"waiter_names\"","5":"\"sushi\".\"customer_revenue_by_day\"","6":"\"sushi\".\"items\"","7":"\"sushi\".\"order_items\"","8":"\"sushi\".\"orders\"","9":"\"sushi\".\"waiter_as_customer_by_day\"","10":"\"sushi\".\"waiter_names\""},"identifier":{"0":"1281222509","1":"1609279380","2":"599861134","3":"3148897116","4":"3233103305","5":"1308408370","6":"2957171338","7":"1806777563","8":"3564161223","9":"1084858582","10":"1604207722"},"version":{"0":"1267397572","1":"2695875565","2":"3010914162","3":"2359719298","4":"2505706914","5":"1291364031","6":"312608270","7":"1015284155","8":"925846788","9":"3668757715","10":"1204702829"},"start_ts":{"0":1672531200000,"1":1672531200000,"2":1672531200000,"3":1672531200000,"4":1672531200000,"5":1672531200000,"6":1672531200000,"7":1672531200000,"8":1672531200000,"9":1672531200000,"10":1672531200000},"end_ts":{"0":1673136000000,"1":1673136000000,"2":1673136000000,"3":1673136000000,"4":1673136000000,"5":1673136000000,"6":1673136000000,"7":1673136000000,"8":1673136000000,"9":1673136000000,"10":1673136000000},"is_dev":{"0":false,"1":false,"2":false,"3":false,"4":false,"5":false,"6":false,"7":false,"8":false,"9":false,"10":false},"i
s_removed":{"0":false,"1":false,"2":false,"3":false,"4":false,"5":false,"6":false,"7":false,"8":false,"9":false,"10":false},"is_compacted":{"0":true,"1":true,"2":true,"3":true,"4":true,"5":true,"6":true,"7":true,"8":true,"9":true,"10":true}} \ No newline at end of file diff --git a/tests/fixtures/migrations/snapshots.json b/tests/fixtures/migrations/snapshots.json index 3794db0f06..638009abf1 100644 --- a/tests/fixtures/migrations/snapshots.json +++ b/tests/fixtures/migrations/snapshots.json @@ -1 +1 @@ -{"name":{"0":"sushi.waiter_as_customer_by_day","1":"sushi.waiter_revenue_by_day","2":"sushi.top_waiters","3":"sushi.waiters","4":"sushi.customers","5":"sushi.waiter_names","6":"sushi.customer_revenue_by_day","7":"sushi.items","8":"sushi.order_items","9":"sushi.orders","10":"sushi.waiter_as_customer_by_day","11":"sushi.waiter_names"},"identifier":{"0":"1281222509","1":"1609279380","2":"599861134","3":"3386889721","4":"3148897116","5":"3233103305","6":"1308408370","7":"2957171338","8":"1806777563","9":"3564161223","10":"1084858582","11":"1604207722"},"version":{"0":"1267397572","1":"2695875565","2":"3010914162","3":"2059227798","4":"2359719298","5":"2505706914","6":"1291364031","7":"312608270","8":"1015284155","9":"925846788","10":"3668757715","11":"1204702829"},"snapshot":{"0":"{\"name\": \"sushi.waiter_as_customer_by_day\", \"fingerprint\": {\"data_hash\": \"486172035\", \"metadata_hash\": \"1992853678\", \"parent_data_hash\": \"2154574190\", \"parent_metadata_hash\": \"1349779748\"}, \"physical_schema\": \"sqlmesh\", \"model\": {\"name\": \"sushi.waiter_as_customer_by_day\", \"kind\": {\"name\": \"INCREMENTAL_BY_TIME_RANGE\", \"time_column\": {\"column\": \"ds\", \"format\": \"%Y-%m-%d\"}}, \"dialect\": \"duckdb\", \"cron\": \"@daily\", \"owner\": \"jen\", \"partitioned_by\": [], \"pre\": [], \"post\": [], \"audits\": [[\"not_null\", {\"columns\": \"ARRAY(waiter_id)\"}]], \"expressions\": [], \"python_env\": {}, \"jinja_macros\": {\"packages\": {}, 
\"root_macros\": {}, \"global_objs\": {}}, \"query\": \"SELECT w.ds AS ds, w.waiter_id AS waiter_id, wn.name AS waiter_name FROM sushi.waiters AS w JOIN sushi.customers AS c ON w.waiter_id = c.customer_id JOIN sushi.waiter_names AS wn ON w.waiter_id = wn.id\", \"source_type\": \"sql\"}, \"parents\": [{\"name\": \"sushi.waiters\", \"identifier\": \"3386889721\"}, {\"name\": \"sushi.waiter_names\", \"identifier\": \"3233103305\"}, {\"name\": \"sushi.customers\", \"identifier\": \"3148897116\"}, {\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"audits\": [], \"intervals\": [[1672531200000, 1673136000000]], \"dev_intervals\": [], \"created_ts\": 1680814376348, \"updated_ts\": 1680814376348, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"indirect_versions\": {}, \"version\": \"1267397572\"}","1":"{\"name\": \"sushi.waiter_revenue_by_day\", \"fingerprint\": {\"data_hash\": \"2443934302\", \"metadata_hash\": \"2904050331\", \"parent_data_hash\": \"764310396\", \"parent_metadata_hash\": \"3147731239\"}, \"physical_schema\": \"sqlmesh\", \"model\": {\"name\": \"sushi.waiter_revenue_by_day\", \"kind\": {\"name\": \"INCREMENTAL_BY_TIME_RANGE\", \"time_column\": {\"column\": \"ds\", \"format\": \"%Y-%m-%d\"}}, \"dialect\": \"duckdb\", \"cron\": \"@daily\", \"owner\": \"jen\", \"description\": \"Table of revenue generated by waiters by day.\", \"batch_size\": 10, \"partitioned_by\": [], \"pre\": [], \"post\": [], \"audits\": [[\"number_of_rows\", {\"threshold\": \"0\"}]], \"expressions\": [], \"python_env\": {}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"query\": \"SELECT CAST(o.waiter_id AS INT) AS waiter_id \/* Waiter id *\/, CAST(SUM(oi.quantity * i.price) AS DOUBLE) AS revenue \/* Revenue from orders taken by this waiter *\/, CAST(o.ds AS TEXT) AS ds \/* Date *\/ FROM sushi.orders AS o LEFT JOIN sushi.order_items AS oi ON o.id = oi.order_id AND o.ds = oi.ds LEFT JOIN sushi.items AS i ON oi.item_id = i.id AND oi.ds 
= i.ds WHERE o.ds BETWEEN @start_ds AND @end_ds GROUP BY o.waiter_id, o.ds\", \"source_type\": \"sql\"}, \"parents\": [{\"name\": \"sushi.order_items\", \"identifier\": \"1806777563\"}, {\"name\": \"sushi.items\", \"identifier\": \"2957171338\"}, {\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"audits\": [], \"intervals\": [[1672531200000, 1673136000000]], \"dev_intervals\": [], \"created_ts\": 1680814376361, \"updated_ts\": 1680814376361, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"indirect_versions\": {}, \"version\": \"2695875565\"}","2":"{\"name\": \"sushi.top_waiters\", \"fingerprint\": {\"data_hash\": \"2891807529\", \"metadata_hash\": \"3392493998\", \"parent_data_hash\": \"1940707936\", \"parent_metadata_hash\": \"1276363398\"}, \"physical_schema\": \"sqlmesh\", \"model\": {\"name\": \"sushi.top_waiters\", \"kind\": {\"name\": \"VIEW\"}, \"dialect\": \"duckdb\", \"cron\": \"@daily\", \"owner\": \"jen\", \"description\": \"View of top waiters.\", \"partitioned_by\": [], \"pre\": [], \"post\": [], \"audits\": [[\"unique_values\", {\"columns\": \"ARRAY(waiter_id)\"}]], \"expressions\": [], \"python_env\": {}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"query\": \"SELECT CAST(waiter_id AS INT) AS waiter_id, CAST(revenue AS DOUBLE) AS revenue FROM sushi.waiter_revenue_by_day WHERE ds = (SELECT MAX(ds) FROM sushi.waiter_revenue_by_day) ORDER BY revenue DESC LIMIT 10\", \"source_type\": \"sql\"}, \"parents\": [{\"name\": \"sushi.waiter_revenue_by_day\", \"identifier\": \"1609279380\"}], \"audits\": [], \"intervals\": [[1672531200000, 1673136000000]], \"dev_intervals\": [], \"created_ts\": 1680814376384, \"updated_ts\": 1680814376384, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"indirect_versions\": {}, \"version\": \"3010914162\"}","3":"{\"name\": \"sushi.waiters\", \"fingerprint\": {\"data_hash\": \"3501061139\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"777615193\", 
\"parent_metadata_hash\": \"2042613269\"}, \"physical_schema\": \"sqlmesh\", \"model\": {\"name\": \"sushi.waiters\", \"kind\": {\"name\": \"EMBEDDED\"}, \"dialect\": \"duckdb\", \"cron\": \"@daily\", \"owner\": \"jen\", \"partitioned_by\": [], \"pre\": [], \"post\": [], \"audits\": [], \"expressions\": [], \"python_env\": {\"incremental_by_ds\": {\"payload\": \"def incremental_by_ds(evaluator, column):\\n expression = evaluator.transform(exp.Between(this=column, low=MacroVar(\\n this='start_ds'), high=MacroVar(this='end_ds')))\\n if not isinstance(expression, exp.Expression):\\n raise MacroEvalError(\\n f'Return type is {type(expression)}, expected exp.Expression')\\n return expression\", \"kind\": \"definition\", \"name\": \"incremental_by_ds\", \"path\": \"macros\/macros.py\"}, \"exp\": {\"payload\": \"import sqlglot.expressions as exp\", \"kind\": \"import\"}, \"MacroVar\": {\"payload\": \"from sqlmesh.core.dialect import MacroVar\", \"kind\": \"import\"}, \"MacroEvalError\": {\"payload\": \"from sqlmesh.utils.errors import MacroEvalError\", \"kind\": \"import\"}}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"query\": \"SELECT DISTINCT CAST(waiter_id AS INT) AS waiter_id, CAST(ds AS TEXT) AS ds FROM sushi.orders AS o WHERE @incremental_by_ds(ds)\", \"source_type\": \"sql\"}, \"parents\": [{\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"audits\": [], \"intervals\": [], \"dev_intervals\": [], \"created_ts\": 1680814376387, \"updated_ts\": 1680814376387, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"indirect_versions\": {}, \"version\": \"2059227798\"}","4":"{\"name\": \"sushi.customers\", \"fingerprint\": {\"data_hash\": \"3553985282\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"777615193\", \"parent_metadata_hash\": \"2042613269\"}, \"physical_schema\": \"sqlmesh\", \"model\": {\"name\": \"sushi.customers\", \"kind\": {\"name\": \"FULL\"}, \"dialect\": \"duckdb\", \"cron\": \"@daily\", 
\"owner\": \"jen\", \"partitioned_by\": [], \"pre\": [[\"noop\", {\"x\": \"1\"}]], \"post\": [[\"noop\", {}], [\"noop\", {\"y\": \"ARRAY('a', 2)\"}]], \"audits\": [], \"expressions\": [], \"python_env\": {\"noop\": {\"payload\": \"def noop(context, start, end, latest, **kwargs):\\n pass\", \"kind\": \"definition\", \"name\": \"noop\", \"path\": \"hooks\/hooks.py\"}}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"query\": \"SELECT DISTINCT CAST(customer_id AS INT) AS customer_id FROM sushi.orders AS o\", \"source_type\": \"sql\"}, \"parents\": [{\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"audits\": [], \"intervals\": [[1672531200000, 1673136000000]], \"dev_intervals\": [], \"created_ts\": 1680814376388, \"updated_ts\": 1680814376388, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"indirect_versions\": {}, \"version\": \"2359719298\"}","5":"{\"name\": \"sushi.waiter_names\", \"fingerprint\": {\"data_hash\": \"1876476880\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"physical_schema\": \"sqlmesh\", \"model\": {\"name\": \"sushi.waiter_names\", \"kind\": {\"name\": \"SEED\", \"path\": \"..\/seeds\/waiter_names.csv\", \"batch_size\": 5}, \"dialect\": \"duckdb\", \"cron\": \"@daily\", \"owner\": \"jen\", \"partitioned_by\": [], \"pre\": [], \"post\": [], \"audits\": [], \"expressions\": [], \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"seed\": {\"content\": \"id,name\\n0,Toby\\n1,Tyson\\n2,Ryan\\n3,George\\n4,Chris\\n5,Max\\n6,Vincent\\n7,Iaroslav\\n8,Emma\\n9,Maia\\n\"}, \"source_type\": \"seed\"}, \"parents\": [], \"audits\": [], \"intervals\": [[1672531200000, 1673136000000]], \"dev_intervals\": [], \"created_ts\": 1680814376389, \"updated_ts\": 1680814376389, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"indirect_versions\": {}, \"version\": \"2505706914\"}","6":"{\"name\": \"sushi.customer_revenue_by_day\", 
\"fingerprint\": {\"data_hash\": \"2657552867\", \"metadata_hash\": \"129771006\", \"parent_data_hash\": \"764310396\", \"parent_metadata_hash\": \"3147731239\"}, \"physical_schema\": \"sqlmesh\", \"model\": {\"name\": \"sushi.customer_revenue_by_day\", \"kind\": {\"name\": \"INCREMENTAL_BY_TIME_RANGE\", \"time_column\": {\"column\": \"ds\", \"format\": \"%Y-%m-%d\"}}, \"dialect\": \"hive\", \"cron\": \"@daily\", \"owner\": \"jen\", \"description\": \"Table of revenue from customers by day.\", \"batch_size\": 10, \"partitioned_by\": [], \"pre\": [], \"post\": [], \"audits\": [], \"expressions\": [], \"python_env\": {}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"query\": \"WITH order_total AS (SELECT oi.order_id AS order_id, SUM(oi.quantity * i.price) AS total, oi.ds AS ds FROM sushi.order_items AS oi LEFT JOIN sushi.items AS i ON oi.item_id = i.id AND oi.ds = i.ds WHERE oi.ds BETWEEN '{{ start_ds }}' AND '{{ end_ds }}' GROUP BY oi.order_id, oi.ds) SELECT CAST(o.customer_id AS INT) AS customer_id \/* Customer id *\/, CAST(SUM(ot.total) AS DOUBLE) AS revenue \/* Revenue from orders made by this customer *\/, CAST(o.ds AS TEXT) AS ds \/* Date *\/ FROM sushi.orders AS o LEFT JOIN order_total AS ot ON o.id = ot.order_id AND o.ds = ot.ds WHERE o.ds BETWEEN '{{ start_ds }}' AND '{{ end_ds }}' GROUP BY o.customer_id, o.ds\", \"source_type\": \"sql\"}, \"parents\": [{\"name\": \"sushi.order_items\", \"identifier\": \"1806777563\"}, {\"name\": \"sushi.items\", \"identifier\": \"2957171338\"}, {\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"audits\": [], \"intervals\": [[1672531200000, 1673136000000]], \"dev_intervals\": [], \"created_ts\": 1680814376391, \"updated_ts\": 1680814376391, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"indirect_versions\": {}, \"version\": \"1291364031\"}","7":"{\"name\": \"sushi.items\", \"fingerprint\": {\"data_hash\": \"1960378930\", \"metadata_hash\": \"2900807542\", 
\"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"physical_schema\": \"sqlmesh\", \"model\": {\"name\": \"sushi.items\", \"kind\": {\"name\": \"INCREMENTAL_BY_TIME_RANGE\", \"time_column\": {\"column\": \"ds\", \"format\": \"%Y-%m-%d\"}}, \"dialect\": \"\", \"cron\": \"@daily\", \"start\": \"Jan 1 2022\", \"batch_size\": 30, \"partitioned_by\": [], \"pre\": [], \"post\": [], \"depends_on\": [], \"columns\": {\"id\": \"INT\", \"name\": \"TEXT\", \"price\": \"DOUBLE\", \"ds\": \"TEXT\"}, \"audits\": [[\"accepted_values\", {\"column\": \"name\", \"values\": \"ARRAY('Ahi', 'Aji', 'Amaebi', 'Anago', 'Aoyagi', 'Bincho', 'Katsuo', 'Ebi', 'Escolar', 'Hamachi', 'Hamachi Toro', 'Hirame', 'Hokigai', 'Hotate', 'Ika', 'Ikura', 'Iwashi', 'Kani', 'Kanpachi', 'Maguro', 'Saba', 'Sake', 'Sake Toro', 'Tai', 'Tako', 'Tamago', 'Tobiko', 'Toro', 'Tsubugai', 'Umi Masu', 'Unagi', 'Uni')\"}], [\"not_null\", {\"columns\": \"ARRAY(name, price)\"}], [\"assert_items_price_exceeds_threshold\", {\"price\": \"0\"}]], \"expressions\": [], \"python_env\": {\"execute\": {\"payload\": \"def execute(context, start, end, latest, **kwargs):\\n dfs = []\\n for dt in iter_dates(start, end):\\n num_items = random.randint(10, len(ITEMS))\\n dfs.append(pd.DataFrame({'name': random.sample(ITEMS, num_items),\\n 'price': np.random.uniform(3.0, 10.0, size=num_items).round(2),\\n 'ds': to_ds(dt)}).reset_index().rename(columns={'index': 'id'}))\\n return pd.concat(dfs)\", \"kind\": \"definition\", \"name\": \"execute\", \"path\": \"models\/items.py\"}, \"iter_dates\": {\"payload\": \"def iter_dates(start, end):\\n for i in range((end - start).days + 1):\\n dt = start + timedelta(days=i)\\n set_seed(dt)\\n yield dt\", \"kind\": \"definition\", \"name\": \"iter_dates\", \"path\": \"helper.py\"}, \"timedelta\": {\"payload\": \"from datetime import timedelta\", \"kind\": \"import\"}, \"set_seed\": {\"payload\": \"def set_seed(dt):\\n ts = int(dt.timestamp())\\n random.seed(ts)\\n np.random.seed(ts)\", 
\"kind\": \"definition\", \"name\": \"set_seed\", \"path\": \"helper.py\"}, \"random\": {\"payload\": \"import random\", \"kind\": \"import\"}, \"np\": {\"payload\": \"import numpy as np\", \"kind\": \"import\"}, \"ITEMS\": {\"payload\": \"['Ahi', 'Aji', 'Amaebi', 'Anago', 'Aoyagi', 'Bincho', 'Katsuo', 'Ebi', 'Escolar', 'Hamachi', 'Hamachi Toro', 'Hirame', 'Hokigai', 'Hotate', 'Ika', 'Ikura', 'Iwashi', 'Kani', 'Kanpachi', 'Maguro', 'Saba', 'Sake', 'Sake Toro', 'Tai', 'Tako', 'Tamago', 'Tobiko', 'Toro', 'Tsubugai', 'Umi Masu', 'Unagi', 'Uni']\", \"kind\": \"value\"}, \"pd\": {\"payload\": \"import pandas as pd\", \"kind\": \"import\"}, \"to_ds\": {\"payload\": \"from sqlmesh.utils.date import to_ds\", \"kind\": \"import\"}}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"entrypoint\": \"execute\", \"source_type\": \"python\"}, \"parents\": [], \"audits\": [{\"name\": \"assert_items_price_exceeds_threshold\", \"dialect\": \"\", \"skip\": false, \"blocking\": true, \"query\": \"SELECT * FROM @this_model WHERE price <= @price\", \"expressions\": []}], \"intervals\": [[1672531200000, 1673136000000]], \"dev_intervals\": [], \"created_ts\": 1680814376399, \"updated_ts\": 1680814376399, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"indirect_versions\": {}, \"version\": \"312608270\"}","8":"{\"name\": \"sushi.order_items\", \"fingerprint\": {\"data_hash\": \"653664599\", \"metadata_hash\": \"1960934702\", \"parent_data_hash\": \"3170724558\", \"parent_metadata_hash\": \"867324801\"}, \"physical_schema\": \"sqlmesh\", \"model\": {\"name\": \"sushi.order_items\", \"kind\": {\"name\": \"INCREMENTAL_BY_TIME_RANGE\", \"time_column\": {\"column\": \"ds\", \"format\": \"%Y-%m-%d\"}}, \"dialect\": \"\", \"cron\": \"@daily\", \"batch_size\": 30, \"partitioned_by\": [], \"pre\": [], \"post\": [], \"depends_on\": [\"sushi.items\", \"sushi.orders\"], \"columns\": {\"id\": \"INT\", \"order_id\": \"INT\", \"item_id\": \"INT\", \"quantity\": 
\"INT\", \"ds\": \"TEXT\"}, \"audits\": [[\"not_null\", {\"columns\": \"ARRAY(id, order_id, item_id, quantity)\"}], [\"assert_order_items_quantity_exceeds_threshold\", {\"quantity\": \"0\"}]], \"expressions\": [], \"python_env\": {\"execute\": {\"payload\": \"def execute(context, start, end, latest, **kwargs):\\n orders_table = context.table('sushi.orders')\\n items_table = context.table(ITEMS)\\n for dt in iter_dates(start, end):\\n orders = context.fetchdf(\\n f\\\"\\\"\\\"\\n SELECT *\\n FROM {orders_table}\\n WHERE ds = '{to_ds(dt)}'\\n \\\"\\\"\\\"\\n )\\n items = context.fetchdf(\\n f\\\"\\\"\\\"\\n SELECT *\\n FROM {items_table}\\n WHERE ds = '{to_ds(dt)}'\\n \\\"\\\"\\\"\\n )\\n for order_id in orders['id']:\\n n = random.randint(1, 5)\\n yield pd.DataFrame({'order_id': order_id, 'item_id': items.\\n sample(n=n)['id'], 'quantity': np.random.randint(1, 10, n),\\n 'ds': to_ds(dt)}).reset_index().rename(columns={'index': 'id'})\", \"kind\": \"definition\", \"name\": \"execute\", \"path\": \"models\/order_items.py\"}, \"ITEMS\": {\"payload\": \"'sushi.items'\", \"kind\": \"value\"}, \"iter_dates\": {\"payload\": \"def iter_dates(start, end):\\n for i in range((end - start).days + 1):\\n dt = start + timedelta(days=i)\\n set_seed(dt)\\n yield dt\", \"kind\": \"definition\", \"name\": \"iter_dates\", \"path\": \"helper.py\"}, \"timedelta\": {\"payload\": \"from datetime import timedelta\", \"kind\": \"import\"}, \"set_seed\": {\"payload\": \"def set_seed(dt):\\n ts = int(dt.timestamp())\\n random.seed(ts)\\n np.random.seed(ts)\", \"kind\": \"definition\", \"name\": \"set_seed\", \"path\": \"helper.py\"}, \"random\": {\"payload\": \"import random\", \"kind\": \"import\"}, \"np\": {\"payload\": \"import numpy as np\", \"kind\": \"import\"}, \"to_ds\": {\"payload\": \"from sqlmesh.utils.date import to_ds\", \"kind\": \"import\"}, \"pd\": {\"payload\": \"import pandas as pd\", \"kind\": \"import\"}}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, 
\"global_objs\": {}}, \"entrypoint\": \"execute\", \"source_type\": \"python\"}, \"parents\": [{\"name\": \"sushi.items\", \"identifier\": \"2957171338\"}, {\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"audits\": [{\"name\": \"assert_order_items_quantity_exceeds_threshold\", \"dialect\": \"\", \"skip\": false, \"blocking\": true, \"query\": \"SELECT * FROM @this_model WHERE quantity <= @quantity\", \"expressions\": []}], \"intervals\": [[1672531200000, 1673136000000]], \"dev_intervals\": [], \"created_ts\": 1680814376401, \"updated_ts\": 1680814376401, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"indirect_versions\": {}, \"version\": \"1015284155\"}","9":"{\"name\": \"sushi.orders\", \"fingerprint\": {\"data_hash\": \"1628439771\", \"metadata_hash\": \"2745052130\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"physical_schema\": \"sqlmesh\", \"model\": {\"name\": \"sushi.orders\", \"kind\": {\"name\": \"INCREMENTAL_BY_TIME_RANGE\", \"time_column\": {\"column\": \"ds\", \"format\": \"%Y-%m-%d\"}}, \"dialect\": \"\", \"cron\": \"@daily\", \"description\": \"Table of sushi orders.\", \"start\": \"2022-01-01\", \"batch_size\": 30, \"partitioned_by\": [], \"pre\": [], \"post\": [], \"depends_on\": [], \"columns\": {\"id\": \"INT\", \"customer_id\": \"INT\", \"waiter_id\": \"INT\", \"start_ts\": \"INT\", \"end_ts\": \"INT\", \"ds\": \"TEXT\"}, \"audits\": [], \"expressions\": [], \"python_env\": {\"execute\": {\"payload\": \"def execute(context, start, end, latest, **kwargs):\\n dfs = []\\n for dt in iter_dates(start, end):\\n num_orders = random.randint(10, 30)\\n start_ts = [int((dt + timedelta(seconds=random.randint(0, 80000))).\\n timestamp()) for _ in range(num_orders)]\\n end_ts = [int(s + random.randint(0, 60 * 60)) for s in start_ts]\\n dfs.append(pd.DataFrame({'customer_id': random.choices(CUSTOMERS, k\\n =num_orders), 'waiter_id': random.choices(WAITERS, k=num_orders\\n ), 'start_ts': start_ts, 'end_ts': end_ts, 
'ds': to_ds(dt)}).\\n reset_index().rename(columns={'index': 'id'}))\\n return pd.concat(dfs)\", \"kind\": \"definition\", \"name\": \"execute\", \"path\": \"models\/orders.py\"}, \"iter_dates\": {\"payload\": \"def iter_dates(start, end):\\n for i in range((end - start).days + 1):\\n dt = start + timedelta(days=i)\\n set_seed(dt)\\n yield dt\", \"kind\": \"definition\", \"name\": \"iter_dates\", \"path\": \"helper.py\"}, \"timedelta\": {\"payload\": \"from datetime import timedelta\", \"kind\": \"import\"}, \"set_seed\": {\"payload\": \"def set_seed(dt):\\n ts = int(dt.timestamp())\\n random.seed(ts)\\n np.random.seed(ts)\", \"kind\": \"definition\", \"name\": \"set_seed\", \"path\": \"helper.py\"}, \"random\": {\"payload\": \"import random\", \"kind\": \"import\"}, \"np\": {\"payload\": \"import numpy as np\", \"kind\": \"import\"}, \"pd\": {\"payload\": \"import pandas as pd\", \"kind\": \"import\"}, \"CUSTOMERS\": {\"payload\": \"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]\", \"kind\": \"value\"}, \"WAITERS\": {\"payload\": \"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\", \"kind\": \"value\"}, \"to_ds\": {\"payload\": \"from sqlmesh.utils.date import to_ds\", \"kind\": \"import\"}}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"entrypoint\": \"execute\", \"source_type\": \"python\"}, \"parents\": [], \"audits\": [], \"intervals\": [[1672531200000, 1673136000000]], \"dev_intervals\": [], \"created_ts\": 1680814376402, \"updated_ts\": 1680814376402, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"indirect_versions\": {}, \"version\": \"925846788\"}","10":"{\"name\": 
\"sushi.waiter_as_customer_by_day\", \"fingerprint\": {\"data_hash\": \"486172035\", \"metadata_hash\": \"1992853678\", \"parent_data_hash\": \"2824767713\", \"parent_metadata_hash\": \"1349779748\"}, \"physical_schema\": \"sqlmesh\", \"model\": {\"name\": \"sushi.waiter_as_customer_by_day\", \"kind\": {\"name\": \"INCREMENTAL_BY_TIME_RANGE\", \"time_column\": {\"column\": \"ds\", \"format\": \"%Y-%m-%d\"}}, \"dialect\": \"duckdb\", \"cron\": \"@daily\", \"owner\": \"jen\", \"partitioned_by\": [], \"pre\": [], \"post\": [], \"audits\": [[\"not_null\", {\"columns\": \"ARRAY(waiter_id)\"}]], \"expressions\": [], \"python_env\": {}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"table_properties\": {\"key\": \"'value'\"}, \"query\": \"SELECT w.ds AS ds, w.waiter_id AS waiter_id, wn.name AS waiter_name FROM sushi.waiters AS w JOIN sushi.customers AS c ON w.waiter_id = c.customer_id JOIN sushi.waiter_names AS wn ON w.waiter_id = wn.id\", \"source_type\": \"sql\"}, \"parents\": [{\"name\": \"sushi.waiters\", \"identifier\": \"3386889721\"}, {\"name\": \"sushi.waiter_names\", \"identifier\": \"1604207722\"}, {\"name\": \"sushi.customers\", \"identifier\": \"3148897116\"}, {\"name\": \"sushi.orders\", \"identifier\": \"3564161223\"}], \"audits\": [], \"intervals\": [[1672531200000, 1673136000000]], \"dev_intervals\": [], \"created_ts\": 1680814464891, \"updated_ts\": 1680814464891, \"ttl\": \"in 1 week\", \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"486172035\", \"metadata_hash\": \"1992853678\", \"parent_data_hash\": \"2154574190\", \"parent_metadata_hash\": \"1349779748\"}, \"version\": \"1267397572\"}], \"indirect_versions\": {}, \"version\": \"3668757715\"}","11":"{\"name\": \"sushi.waiter_names\", \"fingerprint\": {\"data_hash\": \"4133862560\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"physical_schema\": \"sqlmesh\", \"model\": {\"name\": \"sushi.waiter_names\", 
\"kind\": {\"name\": \"SEED\", \"path\": \"..\/seeds\/waiter_names.csv\", \"batch_size\": 5}, \"dialect\": \"duckdb\", \"cron\": \"@daily\", \"owner\": \"jen\", \"partitioned_by\": [], \"pre\": [], \"post\": [], \"audits\": [], \"expressions\": [], \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"seed\": {\"content\": \"id,name\\n0,Toby\\n1,Tyson\\n2,Ryan\\n3,George\\n4,Chris\\n5,Max\\n6,Vincent\\n7,Iaroslav\\n8,Emma\\n9,Maia\\n10,Jim\\n\"}, \"source_type\": \"seed\"}, \"parents\": [], \"audits\": [], \"intervals\": [[1672531200000, 1673136000000]], \"dev_intervals\": [], \"created_ts\": 1680814464932, \"updated_ts\": 1680814464932, \"ttl\": \"in 1 week\", \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"1876476880\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"version\": \"2505706914\"}], \"indirect_versions\": {\"sushi.waiter_as_customer_by_day\": [{\"fingerprint\": {\"data_hash\": \"486172035\", \"metadata_hash\": \"1992853678\", \"parent_data_hash\": \"2154574190\", \"parent_metadata_hash\": \"1349779748\"}, \"version\": \"1267397572\"}, {\"fingerprint\": {\"data_hash\": \"486172035\", \"metadata_hash\": \"1992853678\", \"parent_data_hash\": \"2824767713\", \"parent_metadata_hash\": \"1349779748\"}, \"version\": \"3668757715\"}]}, \"version\": \"1204702829\", \"change_category\": 1}"}} \ No newline at end of file 
+{"name":{"0":"\"sushi\".\"waiter_as_customer_by_day\"","1":"\"sushi\".\"waiter_revenue_by_day\"","2":"\"sushi\".\"top_waiters\"","3":"\"sushi\".\"waiters\"","4":"\"sushi\".\"customers\"","5":"\"sushi\".\"waiter_names\"","6":"\"sushi\".\"customer_revenue_by_day\"","7":"\"sushi\".\"items\"","8":"\"sushi\".\"order_items\"","9":"\"sushi\".\"orders\"","10":"\"sushi\".\"waiter_as_customer_by_day\"","11":"\"sushi\".\"waiter_names\"","12":"\"sushi\".\"waiter_as_customer_by_day\"","13":"\"sushi\".\"waiter_names\"","14":"\"sushi\".\"customer_revenue_by_day\"","15":"\"sushi\".\"top_waiters\"","16":"\"sushi\".\"waiter_revenue_by_day\"","17":"\"sushi\".\"order_items\"","18":"\"sushi\".\"items\"","19":"\"sushi\".\"waiter_as_customer_by_day\"","20":"\"sushi\".\"waiter_names\"","21":"\"sushi\".\"customers\"","22":"\"sushi\".\"waiters\"","23":"\"sushi\".\"orders\""},"identifier":{"0":"1281222509","1":"1609279380","2":"599861134","3":"3386889721","4":"3148897116","5":"3233103305","6":"1308408370","7":"2957171338","8":"1806777563","9":"3564161223","10":"1084858582","11":"1604207722","12":"3998224796","13":"2725136291","14":"3566886383","15":"129039563","16":"2175947464","17":"1422946820","18":"3721860967","19":"1341746752","20":"1609854746","21":"1461038955","22":"4123940212","23":"1250207606"},"version":{"0":"1267397572","1":"2695875565","2":"3010914162","3":"2059227798","4":"2359719298","5":"2505706914","6":"1291364031","7":"312608270","8":"1015284155","9":"925846788","10":"3668757715","11":"3668757715","12":"3668757715","13":"1204702829","14":"1291364031","15":"3010914162","16":"2695875565","17":"1015284155","18":"312608270","19":"1267397572","20":"2505706914","21":"2359719298","22":"2059227798","23":"925846788"},"snapshot":{"0":"{\"name\": \"\\\"sushi\\\".\\\"waiter_as_customer_by_day\\\"\", \"fingerprint\": {\"data_hash\": \"486172035\", \"metadata_hash\": \"1992853678\", \"parent_data_hash\": \"2154574190\", \"parent_metadata_hash\": \"1349779748\"}, \"physical_schema\": 
\"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"waiters\\\"\", \"identifier\": \"3386889721\"}, {\"name\": \"\\\"sushi\\\".\\\"waiter_names\\\"\", \"identifier\": \"3233103305\"}, {\"name\": \"\\\"sushi\\\".\\\"customers\\\"\", \"identifier\": \"3148897116\"}, {\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"3564161223\"}], \"created_ts\": 1680814376348, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"version\": \"1267397572\", \"node\": {\"name\": \"sushi.waiter_as_customer_by_day\", \"kind\": {\"name\": \"INCREMENTAL_BY_TIME_RANGE\", \"time_column\": {\"column\": \"ds\", \"format\": \"%Y-%m-%d\"}}, \"dialect\": \"duckdb\", \"cron\": \"@daily\", \"owner\": \"jen\", \"partitioned_by\": [], \"audits\": [[\"not_null\", {\"columns\": \"ARRAY(waiter_id)\"}]], \"python_env\": {}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"query\": \"SELECT w.ds AS ds, w.waiter_id AS waiter_id, wn.name AS waiter_name FROM sushi.waiters AS w JOIN sushi.customers AS c ON w.waiter_id = c.customer_id JOIN sushi.waiter_names AS wn ON w.waiter_id = wn.id\", \"source_type\": \"sql\", \"project\": \"\", \"default_catalog\": null}, \"change_category\": 4, \"base_table_name_override\": \"sushi.waiter_as_customer_by_day\"}","1":"{\"name\": \"\\\"sushi\\\".\\\"waiter_revenue_by_day\\\"\", \"fingerprint\": {\"data_hash\": \"2443934302\", \"metadata_hash\": \"2904050331\", \"parent_data_hash\": \"764310396\", \"parent_metadata_hash\": \"3147731239\"}, \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"order_items\\\"\", \"identifier\": \"1806777563\"}, {\"name\": \"\\\"sushi\\\".\\\"items\\\"\", \"identifier\": \"2957171338\"}, {\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"3564161223\"}], \"created_ts\": 1680814376361, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"version\": \"2695875565\", \"node\": {\"name\": \"sushi.waiter_revenue_by_day\", \"kind\": {\"name\": 
\"INCREMENTAL_BY_TIME_RANGE\", \"time_column\": {\"column\": \"ds\", \"format\": \"%Y-%m-%d\"}, \"batch_size\": 10}, \"dialect\": \"duckdb\", \"cron\": \"@daily\", \"owner\": \"jen\", \"description\": \"Table of revenue generated by waiters by day.\", \"partitioned_by\": [], \"audits\": [[\"number_of_rows\", {\"threshold\": \"0\"}]], \"python_env\": {}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"query\": \"SELECT CAST(o.waiter_id AS INT) AS waiter_id \/* Waiter id *\/, CAST(SUM(oi.quantity * i.price) AS DOUBLE) AS revenue \/* Revenue from orders taken by this waiter *\/, CAST(o.ds AS TEXT) AS ds \/* Date *\/ FROM sushi.orders AS o LEFT JOIN sushi.order_items AS oi ON o.id = oi.order_id AND o.ds = oi.ds LEFT JOIN sushi.items AS i ON oi.item_id = i.id AND oi.ds = i.ds WHERE o.ds BETWEEN @start_ds AND @end_ds GROUP BY o.waiter_id, o.ds\", \"source_type\": \"sql\", \"project\": \"\", \"default_catalog\": null}, \"change_category\": 4, \"base_table_name_override\": \"sushi.waiter_revenue_by_day\"}","2":"{\"name\": \"\\\"sushi\\\".\\\"top_waiters\\\"\", \"fingerprint\": {\"data_hash\": \"2891807529\", \"metadata_hash\": \"3392493998\", \"parent_data_hash\": \"1940707936\", \"parent_metadata_hash\": \"1276363398\"}, \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"waiter_revenue_by_day\\\"\", \"identifier\": \"1609279380\"}], \"created_ts\": 1680814376384, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"version\": \"3010914162\", \"node\": {\"name\": \"sushi.top_waiters\", \"kind\": {\"name\": \"VIEW\"}, \"dialect\": \"duckdb\", \"cron\": \"@daily\", \"owner\": \"jen\", \"description\": \"View of top waiters.\", \"partitioned_by\": [], \"audits\": [[\"unique_values\", {\"columns\": \"ARRAY(waiter_id)\"}]], \"python_env\": {}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"query\": \"SELECT CAST(waiter_id AS INT) AS waiter_id, CAST(revenue AS DOUBLE) AS revenue FROM 
sushi.waiter_revenue_by_day WHERE ds = (SELECT MAX(ds) FROM sushi.waiter_revenue_by_day) ORDER BY revenue DESC LIMIT 10\", \"source_type\": \"sql\", \"project\": \"\", \"default_catalog\": null}, \"change_category\": 4, \"base_table_name_override\": \"sushi.top_waiters\"}","3":"{\"name\": \"\\\"sushi\\\".\\\"waiters\\\"\", \"fingerprint\": {\"data_hash\": \"3501061139\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"777615193\", \"parent_metadata_hash\": \"2042613269\"}, \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"3564161223\"}], \"created_ts\": 1680814376387, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"version\": \"2059227798\", \"node\": {\"name\": \"sushi.waiters\", \"kind\": {\"name\": \"EMBEDDED\"}, \"dialect\": \"duckdb\", \"cron\": \"@daily\", \"owner\": \"jen\", \"partitioned_by\": [], \"audits\": [], \"python_env\": {\"incremental_by_ds\": {\"payload\": \"def incremental_by_ds(evaluator, column):\\n expression = evaluator.transform(exp.Between(this=column, low=MacroVar(\\n this='start_ds'), high=MacroVar(this='end_ds')))\\n if not isinstance(expression, exp.Expression):\\n raise MacroEvalError(\\n f'Return type is {type(expression)}, expected exp.Expression')\\n return expression\", \"kind\": \"definition\", \"name\": \"incremental_by_ds\", \"path\": \"macros\/macros.py\"}, \"exp\": {\"payload\": \"import sqlglot.expressions as exp\", \"kind\": \"import\"}, \"MacroVar\": {\"payload\": \"from sqlmesh.core.dialect import MacroVar\", \"kind\": \"import\"}, \"MacroEvalError\": {\"payload\": \"from sqlmesh.utils.errors import MacroEvalError\", \"kind\": \"import\"}}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"query\": \"SELECT DISTINCT CAST(waiter_id AS INT) AS waiter_id, CAST(ds AS TEXT) AS ds FROM sushi.orders AS o WHERE @incremental_by_ds(ds)\", \"source_type\": \"sql\", \"project\": \"\", \"default_catalog\": null}, 
\"change_category\": 4, \"base_table_name_override\": \"sushi.waiters\"}","4":"{\"name\": \"\\\"sushi\\\".\\\"customers\\\"\", \"fingerprint\": {\"data_hash\": \"3553985282\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"777615193\", \"parent_metadata_hash\": \"2042613269\"}, \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"3564161223\"}], \"created_ts\": 1680814376388, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"version\": \"2359719298\", \"node\": {\"name\": \"sushi.customers\", \"kind\": {\"name\": \"FULL\"}, \"dialect\": \"duckdb\", \"cron\": \"@daily\", \"owner\": \"jen\", \"partitioned_by\": [], \"audits\": [], \"python_env\": {\"noop\": {\"payload\": \"def noop(context, start, end, latest, **kwargs):\\n pass\", \"kind\": \"definition\", \"name\": \"noop\", \"path\": \"hooks\/hooks.py\"}}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"query\": \"SELECT DISTINCT CAST(customer_id AS INT) AS customer_id FROM sushi.orders AS o\", \"source_type\": \"sql\", \"project\": \"\", \"default_catalog\": null}, \"change_category\": 4, \"base_table_name_override\": \"sushi.customers\"}","5":"{\"name\": \"\\\"sushi\\\".\\\"waiter_names\\\"\", \"fingerprint\": {\"data_hash\": \"1876476880\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"physical_schema\": \"sqlmesh\", \"parents\": [], \"created_ts\": 1680814376389, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"version\": \"2505706914\", \"node\": {\"name\": \"sushi.waiter_names\", \"kind\": {\"name\": \"SEED\", \"path\": \"..\/seeds\/waiter_names.csv\", \"batch_size\": 5}, \"dialect\": \"duckdb\", \"cron\": \"@daily\", \"owner\": \"jen\", \"partitioned_by\": [], \"audits\": [], \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"seed\": {\"content\": 
\"id,name\\n0,Toby\\n1,Tyson\\n2,Ryan\\n3,George\\n4,Chris\\n5,Max\\n6,Vincent\\n7,Iaroslav\\n8,Emma\\n9,Maia\\n\"}, \"source_type\": \"seed\", \"project\": \"\", \"default_catalog\": null}, \"change_category\": 4, \"base_table_name_override\": \"sushi.waiter_names\"}","6":"{\"name\": \"\\\"sushi\\\".\\\"customer_revenue_by_day\\\"\", \"fingerprint\": {\"data_hash\": \"2657552867\", \"metadata_hash\": \"129771006\", \"parent_data_hash\": \"764310396\", \"parent_metadata_hash\": \"3147731239\"}, \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"order_items\\\"\", \"identifier\": \"1806777563\"}, {\"name\": \"\\\"sushi\\\".\\\"items\\\"\", \"identifier\": \"2957171338\"}, {\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"3564161223\"}], \"created_ts\": 1680814376391, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"version\": \"1291364031\", \"node\": {\"name\": \"sushi.customer_revenue_by_day\", \"kind\": {\"name\": \"INCREMENTAL_BY_TIME_RANGE\", \"time_column\": {\"column\": \"ds\", \"format\": \"%Y-%m-%d\"}, \"batch_size\": 10}, \"dialect\": \"hive\", \"cron\": \"@daily\", \"owner\": \"jen\", \"description\": \"Table of revenue from customers by day.\", \"partitioned_by\": [], \"audits\": [], \"python_env\": {}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"query\": \"JINJA_QUERY_BEGIN;\\nWITH order_total AS (SELECT oi.order_id AS order_id, SUM(oi.quantity * i.price) AS total, oi.ds AS ds FROM sushi.order_items AS oi LEFT JOIN sushi.items AS i ON oi.item_id = i.id AND oi.ds = i.ds WHERE oi.ds BETWEEN '{{ start_ds }}' AND '{{ end_ds }}' GROUP BY oi.order_id, oi.ds) SELECT CAST(o.customer_id AS INT) AS customer_id \/* Customer id *\/, CAST(SUM(ot.total) AS DOUBLE) AS revenue \/* Revenue from orders made by this customer *\/, CAST(o.ds AS TEXT) AS ds \/* Date *\/ FROM sushi.orders AS o LEFT JOIN order_total AS ot ON o.id = ot.order_id AND o.ds = ot.ds WHERE o.ds BETWEEN '{{ start_ds }}' 
AND '{{ end_ds }}' GROUP BY o.customer_id, o.ds\\nJINJA_END;\", \"source_type\": \"sql\", \"project\": \"\", \"default_catalog\": null}, \"change_category\": 4, \"base_table_name_override\": \"sushi.customer_revenue_by_day\"}","7":"{\"name\": \"\\\"sushi\\\".\\\"items\\\"\", \"fingerprint\": {\"data_hash\": \"1960378930\", \"metadata_hash\": \"2900807542\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"physical_schema\": \"sqlmesh\", \"parents\": [], \"created_ts\": 1680814376399, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"version\": \"312608270\", \"node\": {\"name\": \"sushi.items\", \"kind\": {\"name\": \"INCREMENTAL_BY_TIME_RANGE\", \"time_column\": {\"column\": \"ds\", \"format\": \"%Y-%m-%d\"}, \"batch_size\": 30}, \"dialect\": \"\", \"cron\": \"@daily\", \"start\": \"Jan 1 2022\", \"partitioned_by\": [], \"depends_on\": [], \"columns\": {\"id\": \"INT\", \"name\": \"TEXT\", \"price\": \"DOUBLE\", \"ds\": \"TEXT\"}, \"audits\": [[\"accepted_values\", {\"column\": \"name\", \"values\": \"ARRAY('Ahi', 'Aji', 'Amaebi', 'Anago', 'Aoyagi', 'Bincho', 'Katsuo', 'Ebi', 'Escolar', 'Hamachi', 'Hamachi Toro', 'Hirame', 'Hokigai', 'Hotate', 'Ika', 'Ikura', 'Iwashi', 'Kani', 'Kanpachi', 'Maguro', 'Saba', 'Sake', 'Sake Toro', 'Tai', 'Tako', 'Tamago', 'Tobiko', 'Toro', 'Tsubugai', 'Umi Masu', 'Unagi', 'Uni')\"}], [\"not_null\", {\"columns\": \"ARRAY(name, price)\"}], [\"assert_items_price_exceeds_threshold\", {\"price\": \"0\"}]], \"python_env\": {\"execute\": {\"payload\": \"def execute(context, start, end, latest, **kwargs):\\n dfs = []\\n for dt in iter_dates(start, end):\\n num_items = random.randint(10, len(ITEMS))\\n dfs.append(pd.DataFrame({'name': random.sample(ITEMS, num_items),\\n 'price': np.random.uniform(3.0, 10.0, size=num_items).round(2),\\n 'ds': to_ds(dt)}).reset_index().rename(columns={'index': 'id'}))\\n return pd.concat(dfs)\", \"kind\": \"definition\", \"name\": \"execute\", \"path\": \"models\/items.py\"}, \"iter_dates\": 
{\"payload\": \"def iter_dates(start, end):\\n for i in range((end - start).days + 1):\\n dt = start + timedelta(days=i)\\n set_seed(dt)\\n yield dt\", \"kind\": \"definition\", \"name\": \"iter_dates\", \"path\": \"helper.py\"}, \"timedelta\": {\"payload\": \"from datetime import timedelta\", \"kind\": \"import\"}, \"set_seed\": {\"payload\": \"def set_seed(dt):\\n ts = int(dt.timestamp())\\n random.seed(ts)\\n np.random.seed(ts)\", \"kind\": \"definition\", \"name\": \"set_seed\", \"path\": \"helper.py\"}, \"random\": {\"payload\": \"import random\", \"kind\": \"import\"}, \"np\": {\"payload\": \"import numpy as np\", \"kind\": \"import\"}, \"ITEMS\": {\"payload\": \"['Ahi', 'Aji', 'Amaebi', 'Anago', 'Aoyagi', 'Bincho', 'Katsuo', 'Ebi', 'Escolar', 'Hamachi', 'Hamachi Toro', 'Hirame', 'Hokigai', 'Hotate', 'Ika', 'Ikura', 'Iwashi', 'Kani', 'Kanpachi', 'Maguro', 'Saba', 'Sake', 'Sake Toro', 'Tai', 'Tako', 'Tamago', 'Tobiko', 'Toro', 'Tsubugai', 'Umi Masu', 'Unagi', 'Uni']\", \"kind\": \"value\"}, \"pd\": {\"payload\": \"import pandas as pd\", \"kind\": \"import\"}, \"to_ds\": {\"payload\": \"from sqlmesh.utils.date import to_ds\", \"kind\": \"import\"}}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"entrypoint\": \"execute\", \"source_type\": \"python\", \"project\": \"\", \"default_catalog\": null, \"audit_definitions\": {\"assert_items_price_exceeds_threshold\": {\"name\": \"assert_items_price_exceeds_threshold\", \"dialect\": \"\", \"skip\": false, \"blocking\": true, \"query\": \"SELECT * FROM @this_model WHERE price <= @price\", \"expressions\": []}}}, \"change_category\": 4, \"base_table_name_override\": \"sushi.items\"}","8":"{\"name\": \"\\\"sushi\\\".\\\"order_items\\\"\", \"fingerprint\": {\"data_hash\": \"653664599\", \"metadata_hash\": \"1960934702\", \"parent_data_hash\": \"3170724558\", \"parent_metadata_hash\": \"867324801\"}, \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": 
\"\\\"sushi\\\".\\\"items\\\"\", \"identifier\": \"2957171338\"}, {\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"3564161223\"}], \"created_ts\": 1680814376401, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"version\": \"1015284155\", \"node\": {\"name\": \"sushi.order_items\", \"kind\": {\"name\": \"INCREMENTAL_BY_TIME_RANGE\", \"time_column\": {\"column\": \"ds\", \"format\": \"%Y-%m-%d\"}, \"batch_size\": 30}, \"dialect\": \"\", \"cron\": \"@daily\", \"partitioned_by\": [], \"depends_on\": [\"\\\"sushi\\\".\\\"items\\\"\", \"\\\"sushi\\\".\\\"orders\\\"\"], \"columns\": {\"id\": \"INT\", \"order_id\": \"INT\", \"item_id\": \"INT\", \"quantity\": \"INT\", \"ds\": \"TEXT\"}, \"audits\": [[\"not_null\", {\"columns\": \"ARRAY(id, order_id, item_id, quantity)\"}], [\"assert_order_items_quantity_exceeds_threshold\", {\"quantity\": \"0\"}]], \"python_env\": {\"execute\": {\"payload\": \"def execute(context, start, end, latest, **kwargs):\\n orders_table = context.table('sushi.orders')\\n items_table = context.table(ITEMS)\\n for dt in iter_dates(start, end):\\n orders = context.fetchdf(\\n f\\\"\\\"\\\"\\n SELECT *\\n FROM {orders_table}\\n WHERE ds = '{to_ds(dt)}'\\n \\\"\\\"\\\"\\n )\\n items = context.fetchdf(\\n f\\\"\\\"\\\"\\n SELECT *\\n FROM {items_table}\\n WHERE ds = '{to_ds(dt)}'\\n \\\"\\\"\\\"\\n )\\n for order_id in orders['id']:\\n n = random.randint(1, 5)\\n yield pd.DataFrame({'order_id': order_id, 'item_id': items.\\n sample(n=n)['id'], 'quantity': np.random.randint(1, 10, n),\\n 'ds': to_ds(dt)}).reset_index().rename(columns={'index': 'id'})\", \"kind\": \"definition\", \"name\": \"execute\", \"path\": \"models\/order_items.py\"}, \"ITEMS\": {\"payload\": \"'sushi.items'\", \"kind\": \"value\"}, \"iter_dates\": {\"payload\": \"def iter_dates(start, end):\\n for i in range((end - start).days + 1):\\n dt = start + timedelta(days=i)\\n set_seed(dt)\\n yield dt\", \"kind\": \"definition\", \"name\": \"iter_dates\", \"path\": 
\"helper.py\"}, \"timedelta\": {\"payload\": \"from datetime import timedelta\", \"kind\": \"import\"}, \"set_seed\": {\"payload\": \"def set_seed(dt):\\n ts = int(dt.timestamp())\\n random.seed(ts)\\n np.random.seed(ts)\", \"kind\": \"definition\", \"name\": \"set_seed\", \"path\": \"helper.py\"}, \"random\": {\"payload\": \"import random\", \"kind\": \"import\"}, \"np\": {\"payload\": \"import numpy as np\", \"kind\": \"import\"}, \"to_ds\": {\"payload\": \"from sqlmesh.utils.date import to_ds\", \"kind\": \"import\"}, \"pd\": {\"payload\": \"import pandas as pd\", \"kind\": \"import\"}}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"entrypoint\": \"execute\", \"source_type\": \"python\", \"project\": \"\", \"default_catalog\": null, \"audit_definitions\": {\"assert_order_items_quantity_exceeds_threshold\": {\"name\": \"assert_order_items_quantity_exceeds_threshold\", \"dialect\": \"\", \"skip\": false, \"blocking\": true, \"query\": \"SELECT * FROM @this_model WHERE quantity <= @quantity\", \"expressions\": []}}}, \"change_category\": 4, \"base_table_name_override\": \"sushi.order_items\"}","9":"{\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"fingerprint\": {\"data_hash\": \"1628439771\", \"metadata_hash\": \"2745052130\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"physical_schema\": \"sqlmesh\", \"parents\": [], \"created_ts\": 1680814376402, \"ttl\": \"in 1 week\", \"previous_versions\": [], \"version\": \"925846788\", \"node\": {\"name\": \"sushi.orders\", \"kind\": {\"name\": \"INCREMENTAL_BY_TIME_RANGE\", \"time_column\": {\"column\": \"ds\", \"format\": \"%Y-%m-%d\"}, \"batch_size\": 30}, \"dialect\": \"\", \"cron\": \"@daily\", \"description\": \"Table of sushi orders.\", \"start\": \"2022-01-01\", \"partitioned_by\": [], \"depends_on\": [], \"columns\": {\"id\": \"INT\", \"customer_id\": \"INT\", \"waiter_id\": \"INT\", \"start_ts\": \"INT\", \"end_ts\": \"INT\", \"ds\": \"TEXT\"}, \"audits\": [], 
\"python_env\": {\"execute\": {\"payload\": \"def execute(context, start, end, latest, **kwargs):\\n dfs = []\\n for dt in iter_dates(start, end):\\n num_orders = random.randint(10, 30)\\n start_ts = [int((dt + timedelta(seconds=random.randint(0, 80000))).\\n timestamp()) for _ in range(num_orders)]\\n end_ts = [int(s + random.randint(0, 60 * 60)) for s in start_ts]\\n dfs.append(pd.DataFrame({'customer_id': random.choices(CUSTOMERS, k\\n =num_orders), 'waiter_id': random.choices(WAITERS, k=num_orders\\n ), 'start_ts': start_ts, 'end_ts': end_ts, 'ds': to_ds(dt)}).\\n reset_index().rename(columns={'index': 'id'}))\\n return pd.concat(dfs)\", \"kind\": \"definition\", \"name\": \"execute\", \"path\": \"models\/orders.py\"}, \"iter_dates\": {\"payload\": \"def iter_dates(start, end):\\n for i in range((end - start).days + 1):\\n dt = start + timedelta(days=i)\\n set_seed(dt)\\n yield dt\", \"kind\": \"definition\", \"name\": \"iter_dates\", \"path\": \"helper.py\"}, \"timedelta\": {\"payload\": \"from datetime import timedelta\", \"kind\": \"import\"}, \"set_seed\": {\"payload\": \"def set_seed(dt):\\n ts = int(dt.timestamp())\\n random.seed(ts)\\n np.random.seed(ts)\", \"kind\": \"definition\", \"name\": \"set_seed\", \"path\": \"helper.py\"}, \"random\": {\"payload\": \"import random\", \"kind\": \"import\"}, \"np\": {\"payload\": \"import numpy as np\", \"kind\": \"import\"}, \"pd\": {\"payload\": \"import pandas as pd\", \"kind\": \"import\"}, \"CUSTOMERS\": {\"payload\": \"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]\", \"kind\": \"value\"}, \"WAITERS\": {\"payload\": \"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\", \"kind\": 
\"value\"}, \"to_ds\": {\"payload\": \"from sqlmesh.utils.date import to_ds\", \"kind\": \"import\"}}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"entrypoint\": \"execute\", \"source_type\": \"python\", \"project\": \"\", \"default_catalog\": null}, \"change_category\": 4, \"base_table_name_override\": \"sushi.orders\"}","10":"{\"name\": \"\\\"sushi\\\".\\\"waiter_as_customer_by_day\\\"\", \"fingerprint\": {\"data_hash\": \"486172035\", \"metadata_hash\": \"1992853678\", \"parent_data_hash\": \"2824767713\", \"parent_metadata_hash\": \"1349779748\"}, \"physical_schema\": \"sqlmesh\", \"parents\": [{\"name\": \"\\\"sushi\\\".\\\"waiters\\\"\", \"identifier\": \"3386889721\"}, {\"name\": \"\\\"sushi\\\".\\\"waiter_names\\\"\", \"identifier\": \"1604207722\"}, {\"name\": \"\\\"sushi\\\".\\\"customers\\\"\", \"identifier\": \"3148897116\"}, {\"name\": \"\\\"sushi\\\".\\\"orders\\\"\", \"identifier\": \"3564161223\"}], \"created_ts\": 1680814464891, \"ttl\": \"in 1 week\", \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"486172035\", \"metadata_hash\": \"1992853678\", \"parent_data_hash\": \"2154574190\", \"parent_metadata_hash\": \"1349779748\"}, \"version\": \"1267397572\"}], \"version\": \"3668757715\", \"node\": {\"name\": \"sushi.waiter_as_customer_by_day\", \"kind\": {\"name\": \"INCREMENTAL_BY_TIME_RANGE\", \"time_column\": {\"column\": \"ds\", \"format\": \"%Y-%m-%d\"}}, \"dialect\": \"duckdb\", \"cron\": \"@daily\", \"owner\": \"jen\", \"partitioned_by\": [], \"audits\": [[\"not_null\", {\"columns\": \"ARRAY(waiter_id)\"}]], \"python_env\": {}, \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"table_properties\": \"('key' = 'value')\", \"query\": \"SELECT w.ds AS ds, w.waiter_id AS waiter_id, wn.name AS waiter_name FROM sushi.waiters AS w JOIN sushi.customers AS c ON w.waiter_id = c.customer_id JOIN sushi.waiter_names AS wn ON w.waiter_id = wn.id\", \"source_type\": \"sql\", 
\"project\": \"\", \"default_catalog\": null}, \"change_category\": 4, \"base_table_name_override\": \"sushi.waiter_as_customer_by_day\"}","11":"{\"name\": \"\\\"sushi\\\".\\\"waiter_names\\\"\", \"fingerprint\": {\"data_hash\": \"4133862560\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"physical_schema\": \"sqlmesh\", \"parents\": [], \"created_ts\": 1680814464932, \"ttl\": \"in 1 week\", \"previous_versions\": [{\"fingerprint\": {\"data_hash\": \"1876476880\", \"metadata_hash\": \"570478986\", \"parent_data_hash\": \"0\", \"parent_metadata_hash\": \"0\"}, \"version\": \"2505706914\"}], \"version\": \"1204702829\", \"change_category\": 1, \"node\": {\"name\": \"sushi.waiter_names\", \"kind\": {\"name\": \"SEED\", \"path\": \"..\/seeds\/waiter_names.csv\", \"batch_size\": 5}, \"dialect\": \"duckdb\", \"cron\": \"@daily\", \"owner\": \"jen\", \"partitioned_by\": [], \"audits\": [], \"jinja_macros\": {\"packages\": {}, \"root_macros\": {}, \"global_objs\": {}}, \"seed\": {\"content\": \"id,name\\n0,Toby\\n1,Tyson\\n2,Ryan\\n3,George\\n4,Chris\\n5,Max\\n6,Vincent\\n7,Iaroslav\\n8,Emma\\n9,Maia\\n10,Jim\\n\"}, \"source_type\": \"seed\", \"project\": \"\", \"default_catalog\": null}, \"base_table_name_override\": 
\"sushi.waiter_names\"}","12":"{\"name\":\"\\\"sushi\\\".\\\"waiter_as_customer_by_day\\\"\",\"temp_version\":\"3668757715\",\"change_category\":4,\"fingerprint\":{\"data_hash\":\"1936268024\",\"metadata_hash\":\"2088684978\",\"parent_data_hash\":\"3055854652\",\"parent_metadata_hash\":\"665080906\"},\"previous_versions\":[{\"fingerprint\":{\"data_hash\":\"486172035\",\"metadata_hash\":\"1992853678\",\"parent_data_hash\":\"2154574190\",\"parent_metadata_hash\":\"1349779748\"},\"version\":\"1267397572\"},{\"fingerprint\":{\"data_hash\":\"486172035\",\"metadata_hash\":\"1992853678\",\"parent_data_hash\":\"2824767713\",\"parent_metadata_hash\":\"1349779748\"},\"version\":\"3668757715\",\"change_category\":4,\"physical_schema\":\"sqlmesh\"}],\"base_table_name_override\":\"sushi.waiter_as_customer_by_day\",\"physical_schema\":\"sqlmesh\",\"node\":{\"name\":\"sushi.waiter_as_customer_by_day\",\"project\":\"\",\"owner\":\"jen\",\"cron\":\"@daily\",\"tags\":[],\"dialect\":\"duckdb\",\"kind\":{\"name\":\"INCREMENTAL_BY_TIME_RANGE\",\"on_destructive_change\":\"ERROR\",\"dialect\":\"duckdb\",\"forward_only\":false,\"disable_restatement\":false,\"time_column\":{\"column\":\"ds\",\"format\":\"%Y-%m-%d\"}},\"partitioned_by\":[],\"clustered_by\":[],\"audits\":[[\"not_null\",{\"columns\":\"ARRAY(waiter_id)\"}]],\"grains\":[],\"references\":[],\"physical_properties\":\"('key' = 'value')\",\"allow_partials\":false,\"signals\":[],\"enabled\":true,\"python_env\":{},\"jinja_macros\":{\"packages\":{},\"root_macros\":{},\"global_objs\":{},\"create_builtins_module\":\"sqlmesh.utils.jinja\",\"top_level_packages\":[]},\"audit_definitions\":{},\"mapping_schema\":{},\"extract_dependencies_from_query\":true,\"query\":\"SELECT w.ds AS ds, w.waiter_id AS waiter_id, wn.name AS waiter_name FROM sushi.waiters AS w JOIN sushi.customers AS c ON w.waiter_id = c.customer_id JOIN sushi.waiter_names AS wn ON w.waiter_id = 
wn.id\",\"source_type\":\"sql\"},\"parents\":[{\"name\":\"\\\"sushi\\\".\\\"waiter_names\\\"\",\"identifier\":\"2725136291\"},{\"name\":\"\\\"sushi\\\".\\\"waiters\\\"\",\"identifier\":\"4123940212\"},{\"name\":\"\\\"sushi\\\".\\\"orders\\\"\",\"identifier\":\"1250207606\"},{\"name\":\"\\\"sushi\\\".\\\"customers\\\"\",\"identifier\":\"1461038955\"}],\"created_ts\":1680814464891,\"ttl\":\"in 1 week\",\"version\":\"3668757715\",\"migrated\":true}","13":"{\"name\":\"\\\"sushi\\\".\\\"waiter_names\\\"\",\"temp_version\":\"1204702829\",\"change_category\":1,\"fingerprint\":{\"data_hash\":\"1437406487\",\"metadata_hash\":\"3468846895\",\"parent_data_hash\":\"0\",\"parent_metadata_hash\":\"0\"},\"previous_versions\":[{\"fingerprint\":{\"data_hash\":\"1876476880\",\"metadata_hash\":\"570478986\",\"parent_data_hash\":\"0\",\"parent_metadata_hash\":\"0\"},\"version\":\"2505706914\"},{\"fingerprint\":{\"data_hash\":\"4133862560\",\"metadata_hash\":\"570478986\",\"parent_data_hash\":\"0\",\"parent_metadata_hash\":\"0\"},\"version\":\"1204702829\",\"change_category\":1,\"physical_schema\":\"sqlmesh\"}],\"base_table_name_override\":\"sushi.waiter_names\",\"physical_schema\":\"sqlmesh\",\"node\":{\"name\":\"sushi.waiter_names\",\"project\":\"\",\"owner\":\"jen\",\"cron\":\"@daily\",\"tags\":[],\"dialect\":\"duckdb\",\"kind\":{\"name\":\"SEED\",\"path\":\"..\/seeds\/waiter_names.csv\",\"batch_size\":5},\"partitioned_by\":[],\"clustered_by\":[],\"audits\":[],\"grains\":[],\"references\":[],\"allow_partials\":false,\"signals\":[],\"enabled\":true,\"python_env\":{},\"jinja_macros\":{\"packages\":{},\"root_macros\":{},\"global_objs\":{},\"create_builtins_module\":\"sqlmesh.utils.jinja\",\"top_level_packages\":[]},\"audit_definitions\":{},\"mapping_schema\":{},\"extract_dependencies_from_query\":true,\"seed\":{\"content\":\"\"},\"column_hashes\":{\"id\":\"3061821109\",\"name\":\"2706736258\"},\"derived_columns_to_types\":{\"id\":\"BIGINT\",\"name\":\"TEXT\"},\"is_hydrated\":false,\"sou
rce_type\":\"seed\"},\"parents\":[],\"created_ts\":1680814464932,\"ttl\":\"in 1 week\",\"version\":\"1204702829\",\"migrated\":true}","14":"{\"name\":\"\\\"sushi\\\".\\\"customer_revenue_by_day\\\"\",\"temp_version\":\"1291364031\",\"change_category\":4,\"fingerprint\":{\"data_hash\":\"131732542\",\"metadata_hash\":\"1368842087\",\"parent_data_hash\":\"2738168331\",\"parent_metadata_hash\":\"1795276494\"},\"previous_versions\":[{\"fingerprint\":{\"data_hash\":\"2657552867\",\"metadata_hash\":\"129771006\",\"parent_data_hash\":\"764310396\",\"parent_metadata_hash\":\"3147731239\"},\"version\":\"1291364031\",\"change_category\":4,\"physical_schema\":\"sqlmesh\"}],\"base_table_name_override\":\"sushi.customer_revenue_by_day\",\"physical_schema\":\"sqlmesh\",\"node\":{\"name\":\"sushi.customer_revenue_by_day\",\"project\":\"\",\"description\":\"Table of revenue from customers by day.\",\"owner\":\"jen\",\"cron\":\"@daily\",\"tags\":[],\"dialect\":\"hive\",\"kind\":{\"name\":\"INCREMENTAL_BY_TIME_RANGE\",\"on_destructive_change\":\"ERROR\",\"dialect\":\"hive\",\"batch_size\":10,\"forward_only\":false,\"disable_restatement\":false,\"time_column\":{\"column\":\"ds\",\"format\":\"%Y-%m-%d\"}},\"partitioned_by\":[],\"clustered_by\":[],\"audits\":[],\"grains\":[],\"references\":[],\"allow_partials\":false,\"signals\":[],\"enabled\":true,\"python_env\":{},\"jinja_macros\":{\"packages\":{},\"root_macros\":{},\"global_objs\":{},\"create_builtins_module\":\"sqlmesh.utils.jinja\",\"top_level_packages\":[]},\"audit_definitions\":{},\"mapping_schema\":{},\"extract_dependencies_from_query\":true,\"query\":\"JINJA_QUERY_BEGIN;\\nWITH order_total AS (SELECT oi.order_id AS order_id, SUM(oi.quantity * i.price) AS total, oi.ds AS ds FROM sushi.order_items AS oi LEFT JOIN sushi.items AS i ON oi.item_id = i.id AND oi.ds = i.ds WHERE oi.ds BETWEEN '{{ start_ds }}' AND '{{ end_ds }}' GROUP BY oi.order_id, oi.ds) SELECT CAST(o.customer_id AS INT) AS customer_id \/* Customer id *\/, 
CAST(SUM(ot.total) AS DOUBLE) AS revenue \/* Revenue from orders made by this customer *\/, CAST(o.ds AS TEXT) AS ds \/* Date *\/ FROM sushi.orders AS o LEFT JOIN order_total AS ot ON o.id = ot.order_id AND o.ds = ot.ds WHERE o.ds BETWEEN '{{ start_ds }}' AND '{{ end_ds }}' GROUP BY o.customer_id, o.ds\\nJINJA_END\",\"source_type\":\"sql\"},\"parents\":[{\"name\":\"\\\"sushi\\\".\\\"items\\\"\",\"identifier\":\"3721860967\"},{\"name\":\"\\\"sushi\\\".\\\"orders\\\"\",\"identifier\":\"1250207606\"},{\"name\":\"\\\"sushi\\\".\\\"order_items\\\"\",\"identifier\":\"1422946820\"}],\"created_ts\":1680814376391,\"ttl\":\"in 1 week\",\"version\":\"1291364031\",\"migrated\":true}","15":"{\"name\":\"\\\"sushi\\\".\\\"top_waiters\\\"\",\"temp_version\":\"3010914162\",\"change_category\":4,\"fingerprint\":{\"data_hash\":\"4131026946\",\"metadata_hash\":\"154190563\",\"parent_data_hash\":\"929243525\",\"parent_metadata_hash\":\"2366450878\"},\"previous_versions\":[{\"fingerprint\":{\"data_hash\":\"2891807529\",\"metadata_hash\":\"3392493998\",\"parent_data_hash\":\"1940707936\",\"parent_metadata_hash\":\"1276363398\"},\"version\":\"3010914162\",\"change_category\":4,\"physical_schema\":\"sqlmesh\"}],\"base_table_name_override\":\"sushi.top_waiters\",\"physical_schema\":\"sqlmesh\",\"node\":{\"name\":\"sushi.top_waiters\",\"project\":\"\",\"description\":\"View of top waiters.\",\"owner\":\"jen\",\"cron\":\"@daily\",\"tags\":[],\"dialect\":\"duckdb\",\"kind\":{\"name\":\"VIEW\",\"materialized\":false},\"partitioned_by\":[],\"clustered_by\":[],\"audits\":[[\"unique_values\",{\"columns\":\"ARRAY(waiter_id)\"}]],\"grains\":[],\"references\":[],\"allow_partials\":false,\"signals\":[],\"enabled\":true,\"python_env\":{},\"jinja_macros\":{\"packages\":{},\"root_macros\":{},\"global_objs\":{},\"create_builtins_module\":\"sqlmesh.utils.jinja\",\"top_level_packages\":[]},\"audit_definitions\":{},\"mapping_schema\":{},\"extract_dependencies_from_query\":true,\"query\":\"SELECT 
CAST(waiter_id AS INT) AS waiter_id, CAST(revenue AS DOUBLE) AS revenue FROM sushi.waiter_revenue_by_day WHERE ds = (SELECT MAX(ds) FROM sushi.waiter_revenue_by_day) ORDER BY revenue DESC LIMIT 10\",\"source_type\":\"sql\"},\"parents\":[{\"name\":\"\\\"sushi\\\".\\\"waiter_revenue_by_day\\\"\",\"identifier\":\"2175947464\"}],\"created_ts\":1680814376384,\"ttl\":\"in 1 week\",\"version\":\"3010914162\",\"migrated\":true}","16":"{\"name\":\"\\\"sushi\\\".\\\"waiter_revenue_by_day\\\"\",\"temp_version\":\"2695875565\",\"change_category\":4,\"fingerprint\":{\"data_hash\":\"2224089837\",\"metadata_hash\":\"2504236462\",\"parent_data_hash\":\"2738168331\",\"parent_metadata_hash\":\"1795276494\"},\"previous_versions\":[{\"fingerprint\":{\"data_hash\":\"2443934302\",\"metadata_hash\":\"2904050331\",\"parent_data_hash\":\"764310396\",\"parent_metadata_hash\":\"3147731239\"},\"version\":\"2695875565\",\"change_category\":4,\"physical_schema\":\"sqlmesh\"}],\"base_table_name_override\":\"sushi.waiter_revenue_by_day\",\"physical_schema\":\"sqlmesh\",\"node\":{\"name\":\"sushi.waiter_revenue_by_day\",\"project\":\"\",\"description\":\"Table of revenue generated by waiters by day.\",\"owner\":\"jen\",\"cron\":\"@daily\",\"tags\":[],\"dialect\":\"duckdb\",\"kind\":{\"name\":\"INCREMENTAL_BY_TIME_RANGE\",\"on_destructive_change\":\"ERROR\",\"dialect\":\"duckdb\",\"batch_size\":10,\"forward_only\":false,\"disable_restatement\":false,\"time_column\":{\"column\":\"ds\",\"format\":\"%Y-%m-%d\"}},\"partitioned_by\":[],\"clustered_by\":[],\"audits\":[[\"number_of_rows\",{\"threshold\":\"0\"}]],\"grains\":[],\"references\":[],\"allow_partials\":false,\"signals\":[],\"enabled\":true,\"python_env\":{},\"jinja_macros\":{\"packages\":{},\"root_macros\":{},\"global_objs\":{},\"create_builtins_module\":\"sqlmesh.utils.jinja\",\"top_level_packages\":[]},\"audit_definitions\":{},\"mapping_schema\":{},\"extract_dependencies_from_query\":true,\"query\":\"SELECT CAST(o.waiter_id AS INT) AS 
waiter_id \/* Waiter id *\/, CAST(SUM(oi.quantity * i.price) AS DOUBLE) AS revenue \/* Revenue from orders taken by this waiter *\/, CAST(o.ds AS TEXT) AS ds \/* Date *\/ FROM sushi.orders AS o LEFT JOIN sushi.order_items AS oi ON o.id = oi.order_id AND o.ds = oi.ds LEFT JOIN sushi.items AS i ON oi.item_id = i.id AND oi.ds = i.ds WHERE o.ds BETWEEN @start_ds AND @end_ds GROUP BY o.waiter_id, o.ds\",\"source_type\":\"sql\"},\"parents\":[{\"name\":\"\\\"sushi\\\".\\\"items\\\"\",\"identifier\":\"3721860967\"},{\"name\":\"\\\"sushi\\\".\\\"orders\\\"\",\"identifier\":\"1250207606\"},{\"name\":\"\\\"sushi\\\".\\\"order_items\\\"\",\"identifier\":\"1422946820\"}],\"created_ts\":1680814376361,\"ttl\":\"in 1 week\",\"version\":\"2695875565\",\"migrated\":true}","17":"{\"name\":\"\\\"sushi\\\".\\\"order_items\\\"\",\"temp_version\":\"1015284155\",\"change_category\":4,\"fingerprint\":{\"data_hash\":\"4010068827\",\"metadata_hash\":\"799196655\",\"parent_data_hash\":\"2342431947\",\"parent_metadata_hash\":\"1746080605\"},\"previous_versions\":[{\"fingerprint\":{\"data_hash\":\"653664599\",\"metadata_hash\":\"1960934702\",\"parent_data_hash\":\"3170724558\",\"parent_metadata_hash\":\"867324801\"},\"version\":\"1015284155\",\"change_category\":4,\"physical_schema\":\"sqlmesh\"}],\"base_table_name_override\":\"sushi.order_items\",\"physical_schema\":\"sqlmesh\",\"node\":{\"name\":\"sushi.order_items\",\"project\":\"\",\"cron\":\"@daily\",\"tags\":[],\"dialect\":\"\",\"kind\":{\"name\":\"INCREMENTAL_BY_TIME_RANGE\",\"on_destructive_change\":\"ERROR\",\"dialect\":\"\",\"batch_size\":30,\"forward_only\":false,\"disable_restatement\":false,\"time_column\":{\"column\":\"ds\",\"format\":\"%Y-%m-%d\"}},\"partitioned_by\":[],\"clustered_by\":[],\"depends_on\":[\"\\\"sushi\\\".\\\"items\\\"\",\"\\\"sushi\\\".\\\"orders\\\"\"],\"columns\":{\"id\":\"INT\",\"order_id\":\"INT\",\"item_id\":\"INT\",\"quantity\":\"INT\",\"ds\":\"TEXT\"},\"audits\":[[\"not_null\",{\"columns\":\"ARRAY(id, 
order_id, item_id, quantity)\"}],[\"assert_order_items_quantity_exceeds_threshold\",{\"quantity\":\"0\"}]],\"grains\":[],\"references\":[],\"allow_partials\":false,\"signals\":[],\"enabled\":true,\"python_env\":{\"execute\":{\"payload\":\"def execute(context, start, end, latest, **kwargs):\\n orders_table = context.table('sushi.orders')\\n items_table = context.table(ITEMS)\\n for dt in iter_dates(start, end):\\n orders = context.fetchdf(\\n f\\\"\\\"\\\"\\n SELECT *\\n FROM {orders_table}\\n WHERE ds = '{to_ds(dt)}'\\n \\\"\\\"\\\"\\n )\\n items = context.fetchdf(\\n f\\\"\\\"\\\"\\n SELECT *\\n FROM {items_table}\\n WHERE ds = '{to_ds(dt)}'\\n \\\"\\\"\\\"\\n )\\n for order_id in orders['id']:\\n n = random.randint(1, 5)\\n yield pd.DataFrame({'order_id': order_id, 'item_id': items.\\n sample(n=n)['id'], 'quantity': np.random.randint(1, 10, n),\\n 'ds': to_ds(dt)}).reset_index().rename(columns={'index': 'id'})\",\"kind\":\"definition\",\"name\":\"execute\",\"path\":\"models\/order_items.py\"},\"ITEMS\":{\"payload\":\"'sushi.items'\",\"kind\":\"value\"},\"iter_dates\":{\"payload\":\"def iter_dates(start, end):\\n for i in range((end - start).days + 1):\\n dt = start + timedelta(days=i)\\n set_seed(dt)\\n yield dt\",\"kind\":\"definition\",\"name\":\"iter_dates\",\"path\":\"helper.py\"},\"timedelta\":{\"payload\":\"from datetime import timedelta\",\"kind\":\"import\"},\"set_seed\":{\"payload\":\"def set_seed(dt):\\n ts = int(dt.timestamp())\\n random.seed(ts)\\n np.random.seed(ts)\",\"kind\":\"definition\",\"name\":\"set_seed\",\"path\":\"helper.py\"},\"random\":{\"payload\":\"import random\",\"kind\":\"import\"},\"np\":{\"payload\":\"import numpy as np\",\"kind\":\"import\"},\"to_ds\":{\"payload\":\"from sqlmesh.utils.date import to_ds\",\"kind\":\"import\"},\"pd\":{\"payload\":\"import pandas as 
pd\",\"kind\":\"import\"}},\"jinja_macros\":{\"packages\":{},\"root_macros\":{},\"global_objs\":{},\"create_builtins_module\":\"sqlmesh.utils.jinja\",\"top_level_packages\":[]},\"audit_definitions\":{\"assert_order_items_quantity_exceeds_threshold\":{\"name\":\"assert_order_items_quantity_exceeds_threshold\",\"dialect\":\"\",\"skip\":false,\"blocking\":true,\"standalone\":false,\"query\":\"SELECT * FROM @this_model WHERE quantity <= @quantity\",\"defaults\":{},\"expressions\":[],\"jinja_macros\":{\"packages\":{},\"root_macros\":{},\"global_objs\":{},\"create_builtins_module\":\"sqlmesh.utils.jinja\",\"top_level_packages\":[]}}},\"mapping_schema\":{},\"extract_dependencies_from_query\":true,\"entrypoint\":\"execute\",\"source_type\":\"python\"},\"parents\":[{\"name\":\"\\\"sushi\\\".\\\"items\\\"\",\"identifier\":\"3721860967\"},{\"name\":\"\\\"sushi\\\".\\\"orders\\\"\",\"identifier\":\"1250207606\"}],\"created_ts\":1680814376401,\"ttl\":\"in 1 week\",\"version\":\"1015284155\",\"migrated\":true}","18":"{\"name\":\"\\\"sushi\\\".\\\"items\\\"\",\"temp_version\":\"312608270\",\"change_category\":4,\"fingerprint\":{\"data_hash\":\"1862622614\",\"metadata_hash\":\"3651173237\",\"parent_data_hash\":\"0\",\"parent_metadata_hash\":\"0\"},\"previous_versions\":[{\"fingerprint\":{\"data_hash\":\"1960378930\",\"metadata_hash\":\"2900807542\",\"parent_data_hash\":\"0\",\"parent_metadata_hash\":\"0\"},\"version\":\"312608270\",\"change_category\":4,\"physical_schema\":\"sqlmesh\"}],\"base_table_name_override\":\"sushi.items\",\"physical_schema\":\"sqlmesh\",\"node\":{\"name\":\"sushi.items\",\"project\":\"\",\"start\":\"Jan 1 
2022\",\"cron\":\"@daily\",\"tags\":[],\"dialect\":\"\",\"kind\":{\"name\":\"INCREMENTAL_BY_TIME_RANGE\",\"on_destructive_change\":\"ERROR\",\"dialect\":\"\",\"batch_size\":30,\"forward_only\":false,\"disable_restatement\":false,\"time_column\":{\"column\":\"ds\",\"format\":\"%Y-%m-%d\"}},\"partitioned_by\":[],\"clustered_by\":[],\"depends_on\":[],\"columns\":{\"id\":\"INT\",\"name\":\"TEXT\",\"price\":\"DOUBLE\",\"ds\":\"TEXT\"},\"audits\":[[\"accepted_values\",{\"column\":\"name\",\"values\":\"ARRAY('Ahi', 'Aji', 'Amaebi', 'Anago', 'Aoyagi', 'Bincho', 'Katsuo', 'Ebi', 'Escolar', 'Hamachi', 'Hamachi Toro', 'Hirame', 'Hokigai', 'Hotate', 'Ika', 'Ikura', 'Iwashi', 'Kani', 'Kanpachi', 'Maguro', 'Saba', 'Sake', 'Sake Toro', 'Tai', 'Tako', 'Tamago', 'Tobiko', 'Toro', 'Tsubugai', 'Umi Masu', 'Unagi', 'Uni')\"}],[\"not_null\",{\"columns\":\"ARRAY(name, price)\"}],[\"assert_items_price_exceeds_threshold\",{\"price\":\"0\"}]],\"grains\":[],\"references\":[],\"allow_partials\":false,\"signals\":[],\"enabled\":true,\"python_env\":{\"execute\":{\"payload\":\"def execute(context, start, end, latest, **kwargs):\\n dfs = []\\n for dt in iter_dates(start, end):\\n num_items = random.randint(10, len(ITEMS))\\n dfs.append(pd.DataFrame({'name': random.sample(ITEMS, num_items),\\n 'price': np.random.uniform(3.0, 10.0, size=num_items).round(2),\\n 'ds': to_ds(dt)}).reset_index().rename(columns={'index': 'id'}))\\n return pd.concat(dfs)\",\"kind\":\"definition\",\"name\":\"execute\",\"path\":\"models\/items.py\"},\"iter_dates\":{\"payload\":\"def iter_dates(start, end):\\n for i in range((end - start).days + 1):\\n dt = start + timedelta(days=i)\\n set_seed(dt)\\n yield dt\",\"kind\":\"definition\",\"name\":\"iter_dates\",\"path\":\"helper.py\"},\"timedelta\":{\"payload\":\"from datetime import timedelta\",\"kind\":\"import\"},\"set_seed\":{\"payload\":\"def set_seed(dt):\\n ts = int(dt.timestamp())\\n random.seed(ts)\\n 
np.random.seed(ts)\",\"kind\":\"definition\",\"name\":\"set_seed\",\"path\":\"helper.py\"},\"random\":{\"payload\":\"import random\",\"kind\":\"import\"},\"np\":{\"payload\":\"import numpy as np\",\"kind\":\"import\"},\"ITEMS\":{\"payload\":\"['Ahi', 'Aji', 'Amaebi', 'Anago', 'Aoyagi', 'Bincho', 'Katsuo', 'Ebi', 'Escolar', 'Hamachi', 'Hamachi Toro', 'Hirame', 'Hokigai', 'Hotate', 'Ika', 'Ikura', 'Iwashi', 'Kani', 'Kanpachi', 'Maguro', 'Saba', 'Sake', 'Sake Toro', 'Tai', 'Tako', 'Tamago', 'Tobiko', 'Toro', 'Tsubugai', 'Umi Masu', 'Unagi', 'Uni']\",\"kind\":\"value\"},\"pd\":{\"payload\":\"import pandas as pd\",\"kind\":\"import\"},\"to_ds\":{\"payload\":\"from sqlmesh.utils.date import to_ds\",\"kind\":\"import\"}},\"jinja_macros\":{\"packages\":{},\"root_macros\":{},\"global_objs\":{},\"create_builtins_module\":\"sqlmesh.utils.jinja\",\"top_level_packages\":[]},\"audit_definitions\":{\"assert_items_price_exceeds_threshold\":{\"name\":\"assert_items_price_exceeds_threshold\",\"dialect\":\"\",\"skip\":false,\"blocking\":true,\"standalone\":false,\"query\":\"SELECT * FROM @this_model WHERE price <= @price\",\"defaults\":{},\"expressions\":[],\"jinja_macros\":{\"packages\":{},\"root_macros\":{},\"global_objs\":{},\"create_builtins_module\":\"sqlmesh.utils.jinja\",\"top_level_packages\":[]}}},\"mapping_schema\":{},\"extract_dependencies_from_query\":true,\"entrypoint\":\"execute\",\"source_type\":\"python\"},\"parents\":[],\"created_ts\":1680814376399,\"ttl\":\"in 1 
week\",\"version\":\"312608270\",\"migrated\":true}","19":"{\"name\":\"\\\"sushi\\\".\\\"waiter_as_customer_by_day\\\"\",\"temp_version\":\"1267397572\",\"change_category\":4,\"fingerprint\":{\"data_hash\":\"849558693\",\"metadata_hash\":\"2088684978\",\"parent_data_hash\":\"2705906012\",\"parent_metadata_hash\":\"665080906\"},\"previous_versions\":[{\"fingerprint\":{\"data_hash\":\"486172035\",\"metadata_hash\":\"1992853678\",\"parent_data_hash\":\"2154574190\",\"parent_metadata_hash\":\"1349779748\"},\"version\":\"1267397572\",\"change_category\":4,\"physical_schema\":\"sqlmesh\"}],\"base_table_name_override\":\"sushi.waiter_as_customer_by_day\",\"physical_schema\":\"sqlmesh\",\"node\":{\"name\":\"sushi.waiter_as_customer_by_day\",\"project\":\"\",\"owner\":\"jen\",\"cron\":\"@daily\",\"tags\":[],\"dialect\":\"duckdb\",\"kind\":{\"name\":\"INCREMENTAL_BY_TIME_RANGE\",\"on_destructive_change\":\"ERROR\",\"dialect\":\"duckdb\",\"forward_only\":false,\"disable_restatement\":false,\"time_column\":{\"column\":\"ds\",\"format\":\"%Y-%m-%d\"}},\"partitioned_by\":[],\"clustered_by\":[],\"audits\":[[\"not_null\",{\"columns\":\"ARRAY(waiter_id)\"}]],\"grains\":[],\"references\":[],\"allow_partials\":false,\"signals\":[],\"enabled\":true,\"python_env\":{},\"jinja_macros\":{\"packages\":{},\"root_macros\":{},\"global_objs\":{},\"create_builtins_module\":\"sqlmesh.utils.jinja\",\"top_level_packages\":[]},\"audit_definitions\":{},\"mapping_schema\":{},\"extract_dependencies_from_query\":true,\"query\":\"SELECT w.ds AS ds, w.waiter_id AS waiter_id, wn.name AS waiter_name FROM sushi.waiters AS w JOIN sushi.customers AS c ON w.waiter_id = c.customer_id JOIN sushi.waiter_names AS wn ON w.waiter_id = 
wn.id\",\"source_type\":\"sql\"},\"parents\":[{\"name\":\"\\\"sushi\\\".\\\"waiter_names\\\"\",\"identifier\":\"1609854746\"},{\"name\":\"\\\"sushi\\\".\\\"waiters\\\"\",\"identifier\":\"4123940212\"},{\"name\":\"\\\"sushi\\\".\\\"orders\\\"\",\"identifier\":\"1250207606\"},{\"name\":\"\\\"sushi\\\".\\\"customers\\\"\",\"identifier\":\"1461038955\"}],\"created_ts\":1680814376348,\"ttl\":\"in 1 week\",\"version\":\"1267397572\",\"migrated\":true}","20":"{\"name\":\"\\\"sushi\\\".\\\"waiter_names\\\"\",\"temp_version\":\"2505706914\",\"change_category\":4,\"fingerprint\":{\"data_hash\":\"3604872020\",\"metadata_hash\":\"3468846895\",\"parent_data_hash\":\"0\",\"parent_metadata_hash\":\"0\"},\"previous_versions\":[{\"fingerprint\":{\"data_hash\":\"1876476880\",\"metadata_hash\":\"570478986\",\"parent_data_hash\":\"0\",\"parent_metadata_hash\":\"0\"},\"version\":\"2505706914\",\"change_category\":4,\"physical_schema\":\"sqlmesh\"}],\"base_table_name_override\":\"sushi.waiter_names\",\"physical_schema\":\"sqlmesh\",\"node\":{\"name\":\"sushi.waiter_names\",\"project\":\"\",\"owner\":\"jen\",\"cron\":\"@daily\",\"tags\":[],\"dialect\":\"duckdb\",\"kind\":{\"name\":\"SEED\",\"path\":\"..\/seeds\/waiter_names.csv\",\"batch_size\":5},\"partitioned_by\":[],\"clustered_by\":[],\"audits\":[],\"grains\":[],\"references\":[],\"allow_partials\":false,\"signals\":[],\"enabled\":true,\"python_env\":{},\"jinja_macros\":{\"packages\":{},\"root_macros\":{},\"global_objs\":{},\"create_builtins_module\":\"sqlmesh.utils.jinja\",\"top_level_packages\":[]},\"audit_definitions\":{},\"mapping_schema\":{},\"extract_dependencies_from_query\":true,\"seed\":{\"content\":\"\"},\"column_hashes\":{\"id\":\"3679804453\",\"name\":\"537745575\"},\"derived_columns_to_types\":{\"id\":\"BIGINT\",\"name\":\"TEXT\"},\"is_hydrated\":false,\"source_type\":\"seed\"},\"parents\":[],\"created_ts\":1680814376389,\"ttl\":\"in 1 
week\",\"version\":\"2505706914\",\"migrated\":true}","21":"{\"name\":\"\\\"sushi\\\".\\\"customers\\\"\",\"temp_version\":\"2359719298\",\"change_category\":4,\"fingerprint\":{\"data_hash\":\"2431070412\",\"metadata_hash\":\"3063653103\",\"parent_data_hash\":\"458609840\",\"parent_metadata_hash\":\"2007040660\"},\"previous_versions\":[{\"fingerprint\":{\"data_hash\":\"3553985282\",\"metadata_hash\":\"570478986\",\"parent_data_hash\":\"777615193\",\"parent_metadata_hash\":\"2042613269\"},\"version\":\"2359719298\",\"change_category\":4,\"physical_schema\":\"sqlmesh\"}],\"base_table_name_override\":\"sushi.customers\",\"physical_schema\":\"sqlmesh\",\"node\":{\"name\":\"sushi.customers\",\"project\":\"\",\"owner\":\"jen\",\"cron\":\"@daily\",\"tags\":[],\"dialect\":\"duckdb\",\"kind\":{\"name\":\"FULL\"},\"partitioned_by\":[],\"clustered_by\":[],\"audits\":[],\"grains\":[],\"references\":[],\"allow_partials\":false,\"signals\":[],\"enabled\":true,\"python_env\":{\"noop\":{\"payload\":\"def noop(context, start, end, latest, **kwargs):\\n pass\",\"kind\":\"definition\",\"name\":\"noop\",\"path\":\"hooks\/hooks.py\"}},\"jinja_macros\":{\"packages\":{},\"root_macros\":{},\"global_objs\":{},\"create_builtins_module\":\"sqlmesh.utils.jinja\",\"top_level_packages\":[]},\"audit_definitions\":{},\"mapping_schema\":{},\"extract_dependencies_from_query\":true,\"query\":\"SELECT DISTINCT CAST(customer_id AS INT) AS customer_id FROM sushi.orders AS o\",\"source_type\":\"sql\"},\"parents\":[{\"name\":\"\\\"sushi\\\".\\\"orders\\\"\",\"identifier\":\"1250207606\"}],\"created_ts\":1680814376388,\"ttl\":\"in 1 
week\",\"version\":\"2359719298\",\"migrated\":true}","22":"{\"name\":\"\\\"sushi\\\".\\\"waiters\\\"\",\"temp_version\":\"2059227798\",\"change_category\":4,\"fingerprint\":{\"data_hash\":\"2037801255\",\"metadata_hash\":\"3063653103\",\"parent_data_hash\":\"458609840\",\"parent_metadata_hash\":\"2007040660\"},\"previous_versions\":[{\"fingerprint\":{\"data_hash\":\"3501061139\",\"metadata_hash\":\"570478986\",\"parent_data_hash\":\"777615193\",\"parent_metadata_hash\":\"2042613269\"},\"version\":\"2059227798\",\"change_category\":4,\"physical_schema\":\"sqlmesh\"}],\"base_table_name_override\":\"sushi.waiters\",\"physical_schema\":\"sqlmesh\",\"node\":{\"name\":\"sushi.waiters\",\"project\":\"\",\"owner\":\"jen\",\"cron\":\"@daily\",\"tags\":[],\"dialect\":\"duckdb\",\"kind\":{\"name\":\"EMBEDDED\",\"disable_restatement\":true},\"partitioned_by\":[],\"clustered_by\":[],\"audits\":[],\"grains\":[],\"references\":[],\"allow_partials\":false,\"signals\":[],\"enabled\":true,\"python_env\":{\"incremental_by_ds\":{\"payload\":\"def incremental_by_ds(evaluator, column):\\n expression = evaluator.transform(exp.Between(this=column, low=MacroVar(\\n this='start_ds'), high=MacroVar(this='end_ds')))\\n if not isinstance(expression, exp.Expression):\\n raise MacroEvalError(\\n f'Return type is {type(expression)}, expected exp.Expression')\\n return expression\",\"kind\":\"definition\",\"name\":\"incremental_by_ds\",\"path\":\"macros\/macros.py\"},\"exp\":{\"payload\":\"import sqlglot.expressions as exp\",\"kind\":\"import\"},\"MacroVar\":{\"payload\":\"from sqlmesh.core.dialect import MacroVar\",\"kind\":\"import\"},\"MacroEvalError\":{\"payload\":\"from sqlmesh.utils.errors import MacroEvalError\",\"kind\":\"import\"}},\"jinja_macros\":{\"packages\":{},\"root_macros\":{},\"global_objs\":{},\"create_builtins_module\":\"sqlmesh.utils.jinja\",\"top_level_packages\":[]},\"audit_definitions\":{},\"mapping_schema\":{},\"extract_dependencies_from_query\":true,\"query\":\"SELECT 
DISTINCT CAST(waiter_id AS INT) AS waiter_id, CAST(ds AS TEXT) AS ds FROM sushi.orders AS o WHERE @incremental_by_ds(ds)\",\"source_type\":\"sql\"},\"parents\":[{\"name\":\"\\\"sushi\\\".\\\"orders\\\"\",\"identifier\":\"1250207606\"}],\"created_ts\":1680814376387,\"ttl\":\"in 1 week\",\"version\":\"2059227798\",\"migrated\":true}","23":"{\"name\":\"\\\"sushi\\\".\\\"orders\\\"\",\"temp_version\":\"925846788\",\"change_category\":4,\"fingerprint\":{\"data_hash\":\"1588786367\",\"metadata_hash\":\"1674367104\",\"parent_data_hash\":\"0\",\"parent_metadata_hash\":\"0\"},\"previous_versions\":[{\"fingerprint\":{\"data_hash\":\"1628439771\",\"metadata_hash\":\"2745052130\",\"parent_data_hash\":\"0\",\"parent_metadata_hash\":\"0\"},\"version\":\"925846788\",\"change_category\":4,\"physical_schema\":\"sqlmesh\"}],\"base_table_name_override\":\"sushi.orders\",\"physical_schema\":\"sqlmesh\",\"node\":{\"name\":\"sushi.orders\",\"project\":\"\",\"description\":\"Table of sushi orders.\",\"start\":\"2022-01-01\",\"cron\":\"@daily\",\"tags\":[],\"dialect\":\"\",\"kind\":{\"name\":\"INCREMENTAL_BY_TIME_RANGE\",\"on_destructive_change\":\"ERROR\",\"dialect\":\"\",\"batch_size\":30,\"forward_only\":false,\"disable_restatement\":false,\"time_column\":{\"column\":\"ds\",\"format\":\"%Y-%m-%d\"}},\"partitioned_by\":[],\"clustered_by\":[],\"depends_on\":[],\"columns\":{\"id\":\"INT\",\"customer_id\":\"INT\",\"waiter_id\":\"INT\",\"start_ts\":\"INT\",\"end_ts\":\"INT\",\"ds\":\"TEXT\"},\"audits\":[],\"grains\":[],\"references\":[],\"allow_partials\":false,\"signals\":[],\"enabled\":true,\"python_env\":{\"execute\":{\"payload\":\"def execute(context, start, end, latest, **kwargs):\\n dfs = []\\n for dt in iter_dates(start, end):\\n num_orders = random.randint(10, 30)\\n start_ts = [int((dt + timedelta(seconds=random.randint(0, 80000))).\\n timestamp()) for _ in range(num_orders)]\\n end_ts = [int(s + random.randint(0, 60 * 60)) for s in start_ts]\\n 
dfs.append(pd.DataFrame({'customer_id': random.choices(CUSTOMERS, k\\n =num_orders), 'waiter_id': random.choices(WAITERS, k=num_orders\\n ), 'start_ts': start_ts, 'end_ts': end_ts, 'ds': to_ds(dt)}).\\n reset_index().rename(columns={'index': 'id'}))\\n return pd.concat(dfs)\",\"kind\":\"definition\",\"name\":\"execute\",\"path\":\"models\/orders.py\"},\"iter_dates\":{\"payload\":\"def iter_dates(start, end):\\n for i in range((end - start).days + 1):\\n dt = start + timedelta(days=i)\\n set_seed(dt)\\n yield dt\",\"kind\":\"definition\",\"name\":\"iter_dates\",\"path\":\"helper.py\"},\"timedelta\":{\"payload\":\"from datetime import timedelta\",\"kind\":\"import\"},\"set_seed\":{\"payload\":\"def set_seed(dt):\\n ts = int(dt.timestamp())\\n random.seed(ts)\\n np.random.seed(ts)\",\"kind\":\"definition\",\"name\":\"set_seed\",\"path\":\"helper.py\"},\"random\":{\"payload\":\"import random\",\"kind\":\"import\"},\"np\":{\"payload\":\"import numpy as np\",\"kind\":\"import\"},\"pd\":{\"payload\":\"import pandas as pd\",\"kind\":\"import\"},\"CUSTOMERS\":{\"payload\":\"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]\",\"kind\":\"value\"},\"WAITERS\":{\"payload\":\"[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]\",\"kind\":\"value\"},\"to_ds\":{\"payload\":\"from sqlmesh.utils.date import 
to_ds\",\"kind\":\"import\"}},\"jinja_macros\":{\"packages\":{},\"root_macros\":{},\"global_objs\":{},\"create_builtins_module\":\"sqlmesh.utils.jinja\",\"top_level_packages\":[]},\"audit_definitions\":{},\"mapping_schema\":{},\"extract_dependencies_from_query\":true,\"entrypoint\":\"execute\",\"source_type\":\"python\"},\"parents\":[],\"created_ts\":1680814376402,\"ttl\":\"in 1 week\",\"version\":\"925846788\",\"migrated\":true}"},"kind_name":{"0":"INCREMENTAL_BY_TIME_RANGE","1":"INCREMENTAL_BY_TIME_RANGE","2":"VIEW","3":"EMBEDDED","4":"FULL","5":"SEED","6":"INCREMENTAL_BY_TIME_RANGE","7":"INCREMENTAL_BY_TIME_RANGE","8":"INCREMENTAL_BY_TIME_RANGE","9":"INCREMENTAL_BY_TIME_RANGE","10":"INCREMENTAL_BY_TIME_RANGE","11":"SEED","12":"INCREMENTAL_BY_TIME_RANGE","13":"SEED","14":"INCREMENTAL_BY_TIME_RANGE","15":"VIEW","16":"INCREMENTAL_BY_TIME_RANGE","17":"INCREMENTAL_BY_TIME_RANGE","18":"INCREMENTAL_BY_TIME_RANGE","19":"INCREMENTAL_BY_TIME_RANGE","20":"SEED","21":"FULL","22":"EMBEDDED","23":"INCREMENTAL_BY_TIME_RANGE"},"updated_ts":{"0":1680814376348,"1":1680814376361,"2":1680814376384,"3":1680814376387,"4":1680814376388,"5":1680814376389,"6":1680814376391,"7":1680814376399,"8":1680814376401,"9":1680814376402,"10":1680814464891,"11":1680814464932,"12":1680814464891,"13":1680814464932,"14":1680814376391,"15":1680814376384,"16":1680814376361,"17":1680814376401,"18":1680814376399,"19":1680814376348,"20":1680814376389,"21":1680814376388,"22":1680814376387,"23":1680814376402},"unpaused_ts":{"0":null,"1":null,"2":null,"3":null,"4":null,"5":null,"6":null,"7":null,"8":null,"9":null,"10":null,"11":null,"12":null,"13":null,"14":null,"15":null,"16":null,"17":null,"18":null,"19":null,"20":null,"21":null,"22":null,"23":null},"ttl_ms":{"0":604800000,"1":604800000,"2":604800000,"3":604800000,"4":604800000,"5":604800000,"6":604800000,"7":604800000,"8":604800000,"9":604800000,"10":604800000,"11":604800000,"12":604800000,"13":604800000,"14":604800000,"15":604800000,"16":604800000,"17":604
800000,"18":604800000,"19":604800000,"20":604800000,"21":604800000,"22":604800000,"23":604800000},"unrestorable":{"0":false,"1":false,"2":false,"3":false,"4":false,"5":false,"6":false,"7":false,"8":false,"9":false,"10":false,"11":false,"12":false,"13":false,"14":false,"15":false,"16":false,"17":false,"18":false,"19":false,"20":false,"21":false,"22":false,"23":false}} \ No newline at end of file diff --git a/tests/fixtures/migrations/versions.json b/tests/fixtures/migrations/versions.json new file mode 100644 index 0000000000..5eac7ed987 --- /dev/null +++ b/tests/fixtures/migrations/versions.json @@ -0,0 +1 @@ +{"schema_version":{"0":60},"sqlglot_version":{"0":"25.31.4"},"sqlmesh_version":{"0":"0.134.0"}} diff --git a/tests/fixtures/multi_virtual_layer/audits/.gitkeep b/tests/fixtures/multi_virtual_layer/audits/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fixtures/multi_virtual_layer/macros/.gitkeep b/tests/fixtures/multi_virtual_layer/macros/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fixtures/multi_virtual_layer/macros/__init__.py b/tests/fixtures/multi_virtual_layer/macros/__init__.py new file mode 100644 index 0000000000..1e9de0aa75 --- /dev/null +++ b/tests/fixtures/multi_virtual_layer/macros/__init__.py @@ -0,0 +1,6 @@ +from sqlmesh import macro + + +@macro() +def one(context): + return 1 diff --git a/tests/fixtures/multi_virtual_layer/models/.gitkeep b/tests/fixtures/multi_virtual_layer/models/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/fixtures/multi_virtual_layer/models/first_schema/model_one.sql b/tests/fixtures/multi_virtual_layer/models/first_schema/model_one.sql new file mode 100644 index 0000000000..1bb062a80b --- /dev/null +++ b/tests/fixtures/multi_virtual_layer/models/first_schema/model_one.sql @@ -0,0 +1,8 @@ +MODEL ( + kind FULL, +); + +SELECT + @overriden_var as item_id, + @global_one as global_one, + @one() AS macro_one \ No newline at end of 
file diff --git a/tests/fixtures/multi_virtual_layer/models/first_schema/model_two.sql b/tests/fixtures/multi_virtual_layer/models/first_schema/model_two.sql new file mode 100644 index 0000000000..c09794f02f --- /dev/null +++ b/tests/fixtures/multi_virtual_layer/models/first_schema/model_two.sql @@ -0,0 +1,9 @@ +MODEL ( + kind FULL, +); + +SELECT + item_id, + global_one +FROM + first_schema.model_one; \ No newline at end of file diff --git a/tests/fixtures/multi_virtual_layer/models/second_schema/model_one.sql b/tests/fixtures/multi_virtual_layer/models/second_schema/model_one.sql new file mode 100644 index 0000000000..b4b75d80bf --- /dev/null +++ b/tests/fixtures/multi_virtual_layer/models/second_schema/model_one.sql @@ -0,0 +1,9 @@ +MODEL ( + kind FULL, + gateway second +); + +SELECT + @overriden_var as item_id, + @global_one as global_one, + @one() AS macro_one diff --git a/tests/fixtures/multi_virtual_layer/models/second_schema/model_two.sql b/tests/fixtures/multi_virtual_layer/models/second_schema/model_two.sql new file mode 100644 index 0000000000..f7688d70de --- /dev/null +++ b/tests/fixtures/multi_virtual_layer/models/second_schema/model_two.sql @@ -0,0 +1,10 @@ +MODEL ( + kind FULL, + gateway second +); + +SELECT + item_id, + global_one +FROM + second_schema.model_one; \ No newline at end of file diff --git a/tests/fixtures/multi_virtual_layer/tests/.gitkeep b/tests/fixtures/multi_virtual_layer/tests/.gitkeep new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/integrations/github/cicd/conftest.py b/tests/integrations/github/cicd/conftest.py new file mode 100644 index 0000000000..f869dc41ad --- /dev/null +++ b/tests/integrations/github/cicd/conftest.py @@ -0,0 +1,212 @@ +import typing as t + +import pytest +from pytest_mock.plugin import MockerFixture +from pathlib import Path + +from sqlmesh.core.config import Config +from sqlmesh.core.console import set_console, get_console, MarkdownConsole +from sqlmesh.integrations.github.cicd.config 
import GithubCICDBotConfig +from sqlmesh.integrations.github.cicd.controller import ( + GithubController, + GithubEvent, + MergeStateStatus, + PullRequestInfo, +) +from sqlmesh.utils import AttributeDict +from sqlglot.helper import ensure_list + + +@pytest.fixture +def github_client(mocker: MockerFixture): + from github import Github + from github.Issue import Issue + from github.PullRequest import PullRequest + from github.PullRequestReview import PullRequestReview + from github.Repository import Repository + + client_mock = mocker.MagicMock(spec=Github) + mocker.patch("github.Github", client_mock) + + mock_repository = mocker.MagicMock(spec=Repository) + client_mock.get_repo.return_value = mock_repository + + mock_pull_request = mocker.MagicMock(spec=PullRequest) + mock_pull_request.base.ref = "main" + mock_pull_request.get_reviews.return_value = [mocker.MagicMock(spec=PullRequestReview)] + mock_repository.get_pull.return_value = mock_pull_request + mock_repository.get_issue.return_value = mocker.MagicMock(spec=Issue) + + return client_mock + + +@pytest.fixture +def make_pull_request_review(github_client) -> t.Callable: + from github.PullRequestReview import PullRequestReview + + def _make_function(username: str, state: str, **kwargs) -> PullRequestReview: + return PullRequestReview( + github_client.requester, + {}, + { + # Name is whatever they provide in their GitHub profile or login as a fallback. Always use login. 
+ "user": AttributeDict(name="Unrelated", login=username), + "state": state, + **kwargs, + }, + ) + + return _make_function + + +@pytest.fixture +def sqlmesh_repo_root_path() -> Path: + return next(p for p in Path(__file__).parents if str(p).endswith("tests")).parent + + +@pytest.fixture +def make_controller( + mocker: MockerFixture, + copy_to_temp_path: t.Callable, + monkeypatch: pytest.MonkeyPatch, + sqlmesh_repo_root_path: Path, +) -> t.Callable: + from github import Github + + def _make_function( + event_path: t.Union[str, Path, t.Dict], + client: Github, + *, + merge_state_status: MergeStateStatus = MergeStateStatus.CLEAN, + bot_config: t.Optional[GithubCICDBotConfig] = None, + mock_out_context: bool = True, + config: t.Optional[t.Union[Config, str]] = None, + paths: t.Optional[t.Union[Path, t.List[Path]]] = None, + ) -> GithubController: + if mock_out_context: + mocker.patch("sqlmesh.core.context.Context.apply", mocker.MagicMock()) + mocker.patch("sqlmesh.core.context.Context._run_plan_tests", mocker.MagicMock()) + mocker.patch("sqlmesh.core.context.Context._run_tests", mocker.MagicMock()) + mocker.patch( + "sqlmesh.integrations.github.cicd.controller.GithubController._get_merge_state_status", + mocker.MagicMock(side_effect=lambda: merge_state_status), + ) + if bot_config: + mocker.patch( + "sqlmesh.integrations.github.cicd.controller.GithubController.bot_config", + bot_config, + ) + + if paths is None: + paths = copy_to_temp_path(sqlmesh_repo_root_path / "examples" / "sushi") + + paths = ensure_list(paths) + + if isinstance(event_path, str): + # resolve relative event_path references to absolute so they dont get affected by chdir() below + as_path = Path(event_path) + if not as_path.is_absolute(): + event_path = sqlmesh_repo_root_path / as_path + + # set the current working directory to the temp path so that config references to eg duckdb "db.db" + # get created in the temp path and not in the SQLMesh repo root path that the tests are triggered from + 
monkeypatch.chdir(paths[0]) + + # make the tests think they are running in GitHub Actions + monkeypatch.setenv("GITHUB_ACTIONS", "true") + + orig_console = get_console() + try: + set_console(MarkdownConsole(warning_capture_only=True, error_capture_only=True)) + + return GithubController( + paths=paths, + token="abc", + event=( + GithubEvent.from_path(event_path) + if isinstance(event_path, (str, Path)) + else GithubEvent.from_obj(event_path) + ), + client=client, + config=config, + ) + + finally: + set_console(orig_console) + + return _make_function + + +@pytest.fixture +def make_event_issue_comment() -> t.Callable: + def _make_function(action: str, comment: str) -> t.Dict: + return { + "action": action, + "comment": {"body": comment}, + "issue": { + "pull_request": { + "url": "https://api.github.com/repos/Codertocat/Hello-World/pulls/2" + } + }, + } + + return _make_function + + +class MockIssueComment: + def __init__(self, body: str): + self.body = body + + def edit(self, body): + self.body = body + + +@pytest.fixture +def make_mock_issue_comment() -> t.Callable: + def _make_function( + comment: str, created_comments: t.Optional[t.List[MockIssueComment]] = None + ) -> MockIssueComment: + mock_issue_comment = MockIssueComment(body=comment) + if created_comments is not None: + created_comments.append(mock_issue_comment) + return mock_issue_comment + + return _make_function + + +class MockCheckRun: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.previous_kwargs = [] + + def edit(self, **kwargs): + self.previous_kwargs.append(self.kwargs.copy()) + self.kwargs = {**self.kwargs, **kwargs} + + @property + def all_kwargs(self) -> t.List[t.Dict]: + return self.previous_kwargs + [self.kwargs] + + +@pytest.fixture +def make_mock_check_run() -> t.Callable: + def _make_function(**kwargs) -> MockCheckRun: + return MockCheckRun(**kwargs) + + return _make_function + + +@pytest.fixture +def make_event_from_fixture() -> t.Callable: + def _make_function(fixture_path: 
str) -> GithubEvent: + return GithubEvent.from_path(fixture_path) + + return _make_function + + +@pytest.fixture +def make_pull_request_info() -> t.Callable: + def _make_function(event: GithubEvent) -> PullRequestInfo: + return PullRequestInfo.create_from_pull_request_url(event.pull_request_url) + + return _make_function diff --git a/tests/integrations/github/cicd/fixtures.py b/tests/integrations/github/cicd/fixtures.py deleted file mode 100644 index 056bbdbac0..0000000000 --- a/tests/integrations/github/cicd/fixtures.py +++ /dev/null @@ -1,169 +0,0 @@ -import typing as t - -import pytest -from pytest_mock.plugin import MockerFixture - -from sqlmesh.integrations.github.cicd.config import GithubCICDBotConfig -from sqlmesh.integrations.github.cicd.controller import ( - GithubController, - GithubEvent, - MergeStateStatus, - PullRequestInfo, -) -from sqlmesh.utils import AttributeDict - - -@pytest.fixture -def github_client(mocker: MockerFixture): - from github import Github - from github.Issue import Issue - from github.PullRequest import PullRequest - from github.PullRequestReview import PullRequestReview - from github.Repository import Repository - - client_mock = mocker.MagicMock(spec=Github) - mocker.patch("github.Github", client_mock) - mock_repository = mocker.MagicMock(spec=Repository) - mock_pull_request = mocker.MagicMock(spec=PullRequest) - mock_pull_request.get_reviews = mocker.MagicMock( - side_effect=[mocker.MagicMock(spec=PullRequestReview)] - ) - mock_repository.get_pull = mocker.MagicMock(side_effect=mock_pull_request) - mock_repository.get_issue = mocker.MagicMock(side_effect=mocker.MagicMock(spec=Issue)) - client_mock.get_repo = mocker.MagicMock(side_effect=mock_repository) - - return client_mock - - -@pytest.fixture -def make_pull_request_review() -> t.Callable: - from github.PullRequestReview import PullRequestReview - - def _make_function(username: str, state: str, **kwargs) -> PullRequestReview: - return PullRequestReview( - "test", # type: 
ignore - {}, - { - # Name is whatever they provide in their GitHub profile or login as fallback. Always use login. - "user": AttributeDict(name="Unrelated", login=username), - "state": state, - **kwargs, - }, - completed=False, - ) - - return _make_function - - -@pytest.fixture -def make_controller(mocker: MockerFixture) -> t.Callable: - from github import Github - - def _make_function( - event_path: t.Union[str, t.Dict], - client: Github, - *, - merge_state_status: MergeStateStatus = MergeStateStatus.CLEAN, - bot_config: t.Optional[GithubCICDBotConfig] = None, - mock_out_context: bool = True, - ) -> GithubController: - if mock_out_context: - mocker.patch("sqlmesh.core.context.Context.apply", mocker.MagicMock()) - mocker.patch("sqlmesh.core.context.Context._run_plan_tests", mocker.MagicMock()) - mocker.patch("sqlmesh.core.context.Context._run_tests", mocker.MagicMock()) - mocker.patch( - "sqlmesh.integrations.github.cicd.controller.GithubController._get_merge_state_status", - mocker.MagicMock(side_effect=lambda: merge_state_status), - ) - if bot_config: - mocker.patch( - "sqlmesh.integrations.github.cicd.controller.GithubController.bot_config", - bot_config, - ) - return GithubController( - paths=["examples/sushi"], - token="abc", - event=( - GithubEvent.from_path(event_path) - if isinstance(event_path, str) - else GithubEvent.from_obj(event_path) - ), - client=client, - ) - - return _make_function - - -@pytest.fixture -def make_event_issue_comment() -> t.Callable: - def _make_function(action: str, comment: str) -> t.Dict: - return { - "action": action, - "comment": {"body": comment}, - "issue": { - "pull_request": { - "url": "https://api.github.com/repos/Codertocat/Hello-World/pulls/2" - } - }, - } - - return _make_function - - -class MockIssueComment: - def __init__(self, body: str): - self.body = body - - def edit(self, body): - self.body = body - - -@pytest.fixture -def make_mock_issue_comment() -> t.Callable: - def _make_function( - comment: str, 
created_comments: t.Optional[t.List[MockIssueComment]] = None - ) -> MockIssueComment: - mock_issue_comment = MockIssueComment(body=comment) - if created_comments is not None: - created_comments.append(mock_issue_comment) - return mock_issue_comment - - return _make_function - - -class MockCheckRun: - def __init__(self, **kwargs): - self.kwargs = kwargs - self.previous_kwargs = [] - - def edit(self, **kwargs): - self.previous_kwargs.append(self.kwargs.copy()) - self.kwargs = {**self.kwargs, **kwargs} - - @property - def all_kwargs(self) -> t.List[t.Dict]: - return self.previous_kwargs + [self.kwargs] - - -@pytest.fixture -def make_mock_check_run() -> t.Callable: - def _make_function(**kwargs) -> MockCheckRun: - return MockCheckRun(**kwargs) - - return _make_function - - -@pytest.fixture -def make_event_from_fixture() -> t.Callable: - def _make_function(fixture_path: str) -> GithubEvent: - return GithubEvent.from_path(fixture_path) - - return _make_function - - -@pytest.fixture -def make_pull_request_info() -> t.Callable: - def _make_function(event: GithubEvent) -> PullRequestInfo: - return PullRequestInfo.create_from_pull_request_url(event.pull_request_url) - - return _make_function diff --git a/tests/integrations/github/cicd/test_config.py b/tests/integrations/github/cicd/test_config.py index 394015c705..e4424cf3ba 100644 --- a/tests/integrations/github/cicd/test_config.py +++ b/tests/integrations/github/cicd/test_config.py @@ -8,6 +8,7 @@ Config, load_config_from_paths, ) +from sqlmesh.utils.errors import ConfigError from sqlmesh.integrations.github.cicd.config import MergeMethod from tests.utils.test_filesystem import create_temp_file @@ -33,12 +34,14 @@ def test_load_yaml_config_default(tmp_path): assert config.cicd_bot.invalidate_environment_after_deploy assert config.cicd_bot.merge_method is None assert config.cicd_bot.command_namespace is None - assert config.cicd_bot.auto_categorize_changes == CategorizerConfig.all_off() + assert 
config.cicd_bot.auto_categorize_changes == config.plan.auto_categorize_changes assert config.cicd_bot.default_pr_start is None assert not config.cicd_bot.enable_deploy_command assert config.cicd_bot.skip_pr_backfill - assert config.cicd_bot.pr_include_unmodified is None + assert not config.cicd_bot.pr_include_unmodified assert config.cicd_bot.pr_environment_name is None + assert config.cicd_bot.prod_branch_names == ["main", "master"] + assert not config.cicd_bot.pr_min_intervals def test_load_yaml_config(tmp_path): @@ -61,6 +64,8 @@ def test_load_yaml_config(tmp_path): skip_pr_backfill: false pr_include_unmodified: true pr_environment_name: "MyOverride" + prod_branch_name: testing + pr_min_intervals: 1 model_defaults: dialect: duckdb """, @@ -84,6 +89,8 @@ def test_load_yaml_config(tmp_path): assert not config.cicd_bot.skip_pr_backfill assert config.cicd_bot.pr_include_unmodified assert config.cicd_bot.pr_environment_name == "MyOverride" + assert config.cicd_bot.prod_branch_names == ["testing"] + assert config.cicd_bot.pr_min_intervals == 1 def test_load_python_config_defaults(tmp_path): @@ -108,12 +115,14 @@ def test_load_python_config_defaults(tmp_path): assert config.cicd_bot.invalidate_environment_after_deploy assert config.cicd_bot.merge_method is None assert config.cicd_bot.command_namespace is None - assert config.cicd_bot.auto_categorize_changes == CategorizerConfig.all_off() + assert config.cicd_bot.auto_categorize_changes == config.plan.auto_categorize_changes assert config.cicd_bot.default_pr_start is None assert not config.cicd_bot.enable_deploy_command assert config.cicd_bot.skip_pr_backfill - assert config.cicd_bot.pr_include_unmodified is None + assert not config.cicd_bot.pr_include_unmodified assert config.cicd_bot.pr_environment_name is None + assert config.cicd_bot.prod_branch_names == ["main", "master"] + assert not config.cicd_bot.pr_min_intervals def test_load_python_config(tmp_path): @@ -136,10 +145,12 @@ def test_load_python_config(tmp_path): 
seed=AutoCategorizationMode.FULL, ), default_pr_start="1 week ago", + pr_min_intervals=1, enable_deploy_command=True, skip_pr_backfill=False, pr_include_unmodified=True, pr_environment_name="MyOverride", + prod_branch_name="testing", ), model_defaults=ModelDefaultsConfig(dialect="duckdb"), ) @@ -165,6 +176,8 @@ def test_load_python_config(tmp_path): assert not config.cicd_bot.skip_pr_backfill assert config.cicd_bot.pr_include_unmodified assert config.cicd_bot.pr_environment_name == "MyOverride" + assert config.cicd_bot.prod_branch_names == ["testing"] + assert config.cicd_bot.pr_min_intervals == 1 def test_validation(tmp_path): @@ -181,7 +194,7 @@ def test_validation(tmp_path): """, ) with pytest.raises( - ValueError, match="enable_deploy_command must be set if command_namespace is set" + ConfigError, match=r".*enable_deploy_command must be set if command_namespace is set.*" ): load_config_from_paths(Config, project_paths=[tmp_path / "config.yaml"]) @@ -197,6 +210,83 @@ def test_validation(tmp_path): """, ) with pytest.raises( - ValueError, match="merge_method must be set if enable_deploy_command is True" + ConfigError, match=r".*merge_method must be set if enable_deploy_command is True.*" ): load_config_from_paths(Config, project_paths=[tmp_path / "config.yaml"]) + + +def test_ttl_in_past(tmp_path): + create_temp_file( + tmp_path, + pathlib.Path("config.yaml"), + """ +environment_ttl: in 1 week +model_defaults: + dialect: duckdb +""", + ) + + config = load_config_from_paths(Config, project_paths=[tmp_path / "config.yaml"]) + assert config.environment_ttl == "in 1 week" + + create_temp_file( + tmp_path, + pathlib.Path("config.yaml"), + """ +environment_ttl: 1 week +model_defaults: + dialect: duckdb +""", + ) + with pytest.raises( + ConfigError, + match=r".*TTL '1 week' is in the past. Please specify a relative time in the future. 
Ex: `in 1 week` instead of `1 week`.*", + ): + load_config_from_paths(Config, project_paths=[tmp_path / "config.yaml"]) + + create_temp_file( + tmp_path, + pathlib.Path("config.yaml"), + """ +snapshot_ttl: 1 week +model_defaults: + dialect: duckdb + """, + ) + with pytest.raises( + ValueError, + match="TTL '1 week' is in the past. Please specify a relative time in the future. Ex: `in 1 week` instead of `1 week`.", + ): + load_config_from_paths(Config, project_paths=[tmp_path / "config.yaml"]) + + +def test_properties_inherit_from_project_config(tmp_path): + (tmp_path / "config.yaml").write_text(""" +plan: + auto_categorize_changes: + external: off + python: full + sql: off + seed: full + include_unmodified: true + +cicd_bot: + type: github + +model_defaults: + dialect: duckdb +""") + + config = load_config_from_paths(Config, [tmp_path / "config.yaml"]) + + assert ( + config.cicd_bot.auto_categorize_changes + == config.plan.auto_categorize_changes + == CategorizerConfig( + external=AutoCategorizationMode.OFF, + python=AutoCategorizationMode.FULL, + sql=AutoCategorizationMode.OFF, + seed=AutoCategorizationMode.FULL, + ) + ) + assert config.cicd_bot.pr_include_unmodified == config.plan.include_unmodified == True diff --git a/tests/integrations/github/cicd/test_github_commands.py b/tests/integrations/github/cicd/test_github_commands.py index 97684dbef4..01e4c9af31 100644 --- a/tests/integrations/github/cicd/test_github_commands.py +++ b/tests/integrations/github/cicd/test_github_commands.py @@ -1,24 +1,28 @@ # type: ignore +import typing as t import os import pathlib from unittest import TestCase, mock from unittest.result import TestResult +TestResult.__test__ = False # prevent pytest trying to collect this as a test class + import pytest from pytest_mock.plugin import MockerFixture from sqlmesh.core import constants as c from sqlmesh.core.plan import Plan +from sqlmesh.core.test.result import ModelTextTestResult from sqlmesh.core.user import User, UserRole from 
sqlmesh.integrations.github.cicd import command from sqlmesh.integrations.github.cicd.config import GithubCICDBotConfig, MergeMethod from sqlmesh.integrations.github.cicd.controller import ( + GithubController, GithubCheckConclusion, GithubCheckStatus, ) -from sqlmesh.utils.errors import PlanError, TestError +from sqlmesh.utils.errors import ConflictingPlanError, PlanError, TestError, CICDBotError -pytest_plugins = ["tests.integrations.github.cicd.fixtures"] pytestmark = [ pytest.mark.github, pytest.mark.slow, @@ -144,7 +148,8 @@ def test_run_all_success_with_approvers_approved(
:ship: Prod Plan Being Applied -**New environment `prod` will be created from `prod`**""" + +**`prod` environment will be initialized**""" ) with open(github_output_file, "r", encoding="utf-8") as f: output = f.read() @@ -271,7 +276,8 @@ def test_run_all_success_with_approvers_approved_merge_delete(
:ship: Prod Plan Being Applied -**New environment `prod` will be created from `prod`**""" + +**`prod` environment will be initialized**""" ) with open(github_output_file, "r", encoding="utf-8") as f: output = f.read() @@ -380,7 +386,7 @@ def test_run_all_missing_approval( assert GithubCheckStatus(approval_checks_runs[0]["status"]).is_queued assert GithubCheckStatus(approval_checks_runs[1]["status"]).is_in_progress assert GithubCheckStatus(approval_checks_runs[2]["status"]).is_completed - assert GithubCheckConclusion(approval_checks_runs[2]["conclusion"]).is_neutral + assert GithubCheckConclusion(approval_checks_runs[2]["conclusion"]).is_failure assert len(controller._context.apply.call_args_list) == 1 pr_plan = controller._context.apply.call_args_list[0][0] @@ -400,7 +406,7 @@ def test_run_all_missing_approval( output = f.read() assert ( output - == "run_unit_tests=success\nhas_required_approval=neutral\ncreated_pr_environment=true\npr_environment_name=hello_world_2\npr_environment_synced=success\nprod_plan_preview=success\nprod_environment_synced=skipped\n" + == "run_unit_tests=success\nhas_required_approval=failure\ncreated_pr_environment=true\npr_environment_name=hello_world_2\npr_environment_synced=success\nprod_plan_preview=success\nprod_environment_synced=skipped\n" ) @@ -447,11 +453,11 @@ def test_run_all_test_failed( github_client, bot_config=GithubCICDBotConfig(merge_method=MergeMethod.MERGE), ) - test_result = TestResult() + test_result = ModelTextTestResult(stream=None, descriptions=True, verbosity=0) test_result.testsRun += 1 - test_result.addFailure(TestCase(), (None, None, None)) + test_result.addFailure(TestCase(), (TestError, TestError("some error"), None)) controller._context._run_tests = mocker.MagicMock( - side_effect=lambda **kwargs: (test_result, "some error") + side_effect=lambda **kwargs: (test_result, "") ) controller._context.users = [ User(username="test", github_username="test_github", roles=[UserRole.REQUIRED_APPROVER]) @@ -461,7 +467,8 
@@ def test_run_all_test_failed( github_output_file = tmp_path / "github_output.txt" with mock.patch.dict(os.environ, {"GITHUB_OUTPUT": str(github_output_file)}): - command._run_all(controller) + with pytest.raises(CICDBotError): + command._run_all(controller) assert "SQLMesh - Run Unit Tests" in controller._check_run_mapping test_checks_runs = controller._check_run_mapping["SQLMesh - Run Unit Tests"].all_kwargs @@ -472,15 +479,9 @@ def test_run_all_test_failed( assert GithubCheckConclusion(test_checks_runs[2]["conclusion"]).is_failure assert test_checks_runs[2]["output"]["title"] == "Tests Failed" assert ( - test_checks_runs[2]["output"]["summary"] - == """**Num Successful Tests: 0** - - -```some error``` - - -""" + """sqlmesh.utils.errors.TestError: some error""" in test_checks_runs[2]["output"]["summary"] ) + assert """Failed tests""" in test_checks_runs[2]["output"]["summary"] assert "SQLMesh - Prod Plan Preview" in controller._check_run_mapping prod_plan_preview_checks_runs = controller._check_run_mapping[ @@ -496,7 +497,7 @@ def test_run_all_test_failed( ) assert ( prod_plan_preview_checks_runs[1]["output"]["summary"] - == "Unit Test(s) Failed so skipping creating prod plan" + == "Linter or Unit Test(s) failed so skipping creating prod plan" ) assert "SQLMesh - PR Environment Synced" in controller._check_run_mapping @@ -593,7 +594,8 @@ def test_run_all_test_exception( github_output_file = tmp_path / "github_output.txt" with mock.patch.dict(os.environ, {"GITHUB_OUTPUT": str(github_output_file)}): - command._run_all(controller) + with pytest.raises(CICDBotError): + command._run_all(controller) assert "SQLMesh - Run Unit Tests" in controller._check_run_mapping test_checks_runs = controller._check_run_mapping["SQLMesh - Run Unit Tests"].all_kwargs @@ -623,7 +625,7 @@ def test_run_all_test_exception( ) assert ( prod_plan_preview_checks_runs[1]["output"]["summary"] - == "Unit Test(s) Failed so skipping creating prod plan" + == "Linter or Unit Test(s) failed so 
skipping creating prod plan" ) assert "SQLMesh - PR Environment Synced" in controller._check_run_mapping @@ -727,7 +729,8 @@ def raise_on_pr_plan(plan: Plan): github_output_file = tmp_path / "github_output.txt" with mock.patch.dict(os.environ, {"GITHUB_OUTPUT": str(github_output_file)}): - command._run_all(controller) + with pytest.raises(CICDBotError): + command._run_all(controller) assert "SQLMesh - PR Environment Synced" in controller._check_run_mapping pr_checks_runs = controller._check_run_mapping["SQLMesh - PR Environment Synced"].all_kwargs @@ -788,7 +791,7 @@ def raise_on_pr_plan(plan: Plan): ) -def test_prod_update_failure( +def make_test_prod_update_failure_case( github_client, make_controller, make_mock_check_run, @@ -796,6 +799,8 @@ def test_prod_update_failure( make_pull_request_review, tmp_path: pathlib.Path, mocker: MockerFixture, + to_raise_on_prod_plan: Exception, + expect_prod_sync_conclusion: GithubCheckConclusion, ): """ Scenario: @@ -842,14 +847,15 @@ def test_prod_update_failure( def raise_on_prod_plan(plan: Plan): if plan.environment.name == c.PROD: - raise PlanError("Failed to update Prod environment") + raise to_raise_on_prod_plan controller._context.apply = mocker.MagicMock(side_effect=lambda plan: raise_on_prod_plan(plan)) github_output_file = tmp_path / "github_output.txt" with mock.patch.dict(os.environ, {"GITHUB_OUTPUT": str(github_output_file)}): - command._run_all(controller) + with pytest.raises(CICDBotError): + command._run_all(controller) assert "SQLMesh - Prod Plan Preview" in controller._check_run_mapping prod_plan_preview_checks_runs = controller._check_run_mapping[ @@ -875,7 +881,16 @@ def raise_on_prod_plan(plan: Plan): assert GithubCheckStatus(prod_checks_runs[0]["status"]).is_queued assert GithubCheckStatus(prod_checks_runs[1]["status"]).is_in_progress assert GithubCheckStatus(prod_checks_runs[2]["status"]).is_completed - assert GithubCheckConclusion(prod_checks_runs[2]["conclusion"]).is_action_required + assert 
GithubCheckConclusion(prod_checks_runs[2]["conclusion"]) == expect_prod_sync_conclusion + if expect_prod_sync_conclusion.is_action_required: + assert prod_checks_runs[2]["output"]["title"] == "Failed due to error applying plan" + assert ( + prod_checks_runs[2]["output"]["summary"] + == f"""**Plan error:** +``` +{to_raise_on_prod_plan} +```""" + ) assert "SQLMesh - Run Unit Tests" in controller._check_run_mapping test_checks_runs = controller._check_run_mapping["SQLMesh - Run Unit Tests"].all_kwargs @@ -912,17 +927,114 @@ def raise_on_prod_plan(plan: Plan):
:ship: Prod Plan Being Applied -**New environment `prod` will be created from `prod`**""" + +**`prod` environment will be initialized**""" ) with open(github_output_file, "r", encoding="utf-8") as f: output = f.read() assert ( output - == "run_unit_tests=success\nhas_required_approval=success\ncreated_pr_environment=true\npr_environment_name=hello_world_2\npr_environment_synced=success\nprod_plan_preview=success\nprod_environment_synced=action_required\n" + == f"run_unit_tests=success\nhas_required_approval=success\ncreated_pr_environment=true\npr_environment_name=hello_world_2\npr_environment_synced=success\nprod_plan_preview=success\nprod_environment_synced={expect_prod_sync_conclusion.value}\n" ) +def test_prod_update_failure( + github_client, + make_controller, + make_mock_check_run, + make_mock_issue_comment, + make_pull_request_review, + tmp_path: pathlib.Path, + mocker: MockerFixture, +): + """ + Scenario: + - PR is not merged + - PR has been approved by a required reviewer + - Tests passed + - PR Merge Method defined + - Delete environment is enabled + - Prod environment update failed + """ + + make_test_prod_update_failure_case( + github_client=github_client, + make_controller=make_controller, + make_mock_check_run=make_mock_check_run, + make_mock_issue_comment=make_mock_issue_comment, + make_pull_request_review=make_pull_request_review, + tmp_path=tmp_path, + mocker=mocker, + to_raise_on_prod_plan=PlanError("Failed to update Prod environment"), + expect_prod_sync_conclusion=GithubCheckConclusion.ACTION_REQUIRED, + ) + + +def test_prod_update_conflict( + github_client, + make_controller, + make_mock_check_run, + make_mock_issue_comment, + make_pull_request_review, + tmp_path: pathlib.Path, + mocker: MockerFixture, +): + """ + Scenario: + - PR is not merged + - PR has been approved by a required reviewer + - Tests passed + - PR Merge Method defined + - Delete environment is enabled + - Prod environment update conflicted + """ + + 
make_test_prod_update_failure_case( + github_client=github_client, + make_controller=make_controller, + make_mock_check_run=make_mock_check_run, + make_mock_issue_comment=make_mock_issue_comment, + make_pull_request_review=make_pull_request_review, + tmp_path=tmp_path, + mocker=mocker, + to_raise_on_prod_plan=ConflictingPlanError("Plan a conflicts with plan b"), + expect_prod_sync_conclusion=GithubCheckConclusion.SKIPPED, + ) + + +def test_prod_update_exception( + github_client, + make_controller, + make_mock_check_run, + make_mock_issue_comment, + make_pull_request_review, + tmp_path: pathlib.Path, + mocker: MockerFixture, +): + """ + Scenario: + - PR is not merged + - PR has been approved by a required reviewer + - Tests passed + - PR Merge Method defined + - Delete environment is enabled + - Prod environment update fails with an unknown exception + """ + + make_test_prod_update_failure_case( + github_client=github_client, + make_controller=make_controller, + make_mock_check_run=make_mock_check_run, + make_mock_issue_comment=make_mock_issue_comment, + make_pull_request_review=make_pull_request_review, + tmp_path=tmp_path, + mocker=mocker, + to_raise_on_prod_plan=RuntimeError("boom"), + expect_prod_sync_conclusion=GithubCheckConclusion.FAILURE, + ) + + def test_comment_command_invalid( github_client, make_controller, @@ -1044,6 +1156,7 @@ def test_comment_command_deploy_prod( User(username="test", github_username="test_github", roles=[UserRole.REQUIRED_APPROVER]) ] controller._context.invalidate_environment = mocker.MagicMock() + assert not controller.forward_only_plan github_output_file = tmp_path / "github_output.txt" @@ -1108,7 +1221,8 @@ def test_comment_command_deploy_prod(
:ship: Prod Plan Being Applied -**New environment `prod` will be created from `prod`**""" + +**`prod` environment will be initialized**""" ) with open(github_output_file, "r", encoding="utf-8") as f: @@ -1186,3 +1300,155 @@ def test_comment_command_deploy_prod_not_enabled( with open(github_output_file, "r", encoding="utf-8") as f: output = f.read() assert output == "" + + +def test_comment_command_deploy_prod_no_deploy_detected_yet( + github_client, + make_controller, + make_mock_check_run, + make_mock_issue_comment, + tmp_path: pathlib.Path, + mocker: MockerFixture, +): + """ + Scenario: + - PR is not merged + - No requred approvers defined + - Tests passed + - PR Merge Method defined + - Deploy command enabled but not yet triggered + + Outcome: + - "Prod Environment Synced" step should explain the reason why it was skipped is because /deploy has not yet been detected + """ + mock_repo = github_client.get_repo() + mock_repo.create_check_run = mocker.MagicMock( + side_effect=lambda **kwargs: make_mock_check_run(**kwargs) + ) + + created_comments = [] + mock_issue = mock_repo.get_issue() + mock_issue.create_comment = mocker.MagicMock( + side_effect=lambda comment: make_mock_issue_comment( + comment=comment, created_comments=created_comments + ) + ) + mock_issue.get_comments = mocker.MagicMock(side_effect=lambda: created_comments) + + mock_pull_request = mock_repo.get_pull() + mock_pull_request.get_reviews = mocker.MagicMock(lambda: []) + mock_pull_request.merged = False + mock_pull_request.merge = mocker.MagicMock() + + controller = make_controller( + "tests/fixtures/github/pull_request_synchronized.json", + github_client, + bot_config=GithubCICDBotConfig(merge_method=MergeMethod.REBASE, enable_deploy_command=True), + ) + controller._context._run_tests = mocker.MagicMock( + side_effect=lambda **kwargs: (TestResult(), "") + ) + + github_output_file = tmp_path / "github_output.txt" + + with mock.patch.dict(os.environ, {"GITHUB_OUTPUT": str(github_output_file)}): + 
command._run_all(controller) + + assert "SQLMesh - Prod Plan Preview" in controller._check_run_mapping + assert "SQLMesh - PR Environment Synced" in controller._check_run_mapping + assert "SQLMesh - Prod Environment Synced" in controller._check_run_mapping + assert "SQLMesh - Run Unit Tests" in controller._check_run_mapping + prod_checks_runs = controller._check_run_mapping["SQLMesh - Prod Environment Synced"].all_kwargs + assert len(prod_checks_runs) == 2 + assert GithubCheckStatus(prod_checks_runs[0]["status"]).is_queued + assert GithubCheckStatus(prod_checks_runs[1]["status"]).is_completed + assert prod_checks_runs[1]["output"]["title"] == "Skipped deployment" + assert ( + prod_checks_runs[1]["output"]["summary"] + == "Skipped Deploying to Production because a `/deploy` command has not been detected yet" + ) + assert GithubCheckConclusion(prod_checks_runs[1]["conclusion"]).is_skipped + + # required approvers are irrelevant because /deploy command is enabled + assert "SQLMesh - Has Required Approval" not in controller._check_run_mapping + + +def test_deploy_prod_forward_only( + github_client, + make_controller: t.Callable[..., GithubController], + make_mock_check_run, + make_mock_issue_comment, + tmp_path: pathlib.Path, + mocker: MockerFixture, +): + """ + Scenario: + - PR is created with a branch name indicating that plans should be forward-only + - PR is not merged + - Tests passed + - PR Merge Method defined + - Deploy command has been triggered + + Outcome: + - "Prod Environment Synced" step should show a tip explaining how to retroactively apply forward-only changes to old data + - Bot Comment should show the same tip + """ + mock_repo = github_client.get_repo() + mock_repo.create_check_run = mocker.MagicMock( + side_effect=lambda **kwargs: make_mock_check_run(**kwargs) + ) + + created_comments = [] + mock_issue = mock_repo.get_issue() + mock_issue.create_comment = mocker.MagicMock( + side_effect=lambda comment: make_mock_issue_comment( + comment=comment, 
created_comments=created_comments + ) + ) + mock_issue.get_comments = mocker.MagicMock(side_effect=lambda: created_comments) + + mock_pull_request = mock_repo.get_pull() + mock_pull_request.get_reviews = mocker.MagicMock(lambda: []) + mock_pull_request.merged = False + mock_pull_request.merge = mocker.MagicMock() + mock_pull_request.head.ref = "unit-test-forward-only" + + controller = make_controller( + "tests/fixtures/github/pull_request_synchronized.json", + github_client, + bot_config=GithubCICDBotConfig( + merge_method=MergeMethod.SQUASH, + enable_deploy_command=True, + forward_only_branch_suffix="-forward-only", + ), + mock_out_context=False, + ) + + # create existing prod to apply against + controller._context.plan(auto_apply=True) + + github_output_file = tmp_path / "github_output.txt" + + # then, run a deploy with forward_only set + assert controller.forward_only_plan + with mock.patch.dict(os.environ, {"GITHUB_OUTPUT": str(github_output_file)}): + command._deploy_production(controller) + + # Prod Environment Synced step should be successful + assert "SQLMesh - Prod Environment Synced" in controller._check_run_mapping + prod_checks_runs = controller._check_run_mapping["SQLMesh - Prod Environment Synced"].all_kwargs + assert len(prod_checks_runs) == 2 + assert GithubCheckStatus(prod_checks_runs[0]["status"]).is_in_progress + assert GithubCheckStatus(prod_checks_runs[1]["status"]).is_completed + assert prod_checks_runs[1]["output"]["title"] == "Deployed to Prod" + assert GithubCheckConclusion(prod_checks_runs[1]["conclusion"]).is_success + + # PR comment should be updated with forward-only tip + assert len(created_comments) == 1 + assert ( + """> [!TIP] +> In order to see this forward-only plan retroactively apply to historical intervals on the production model, run the below for date ranges in scope: +> +> `$ sqlmesh plan --restate-model sushi.customer_revenue_by_day --start YYYY-MM-DD --end YYYY-MM-DD`""" + in created_comments[0].body + ) diff --git 
a/tests/integrations/github/cicd/test_github_controller.py b/tests/integrations/github/cicd/test_github_controller.py index 1104ac28d2..e4fe10e321 100644 --- a/tests/integrations/github/cicd/test_github_controller.py +++ b/tests/integrations/github/cicd/test_github_controller.py @@ -1,29 +1,58 @@ # type: ignore +import typing as t import os import pathlib from unittest import mock from unittest.mock import PropertyMock, call import pytest +import time_machine from pytest_mock.plugin import MockerFixture from sqlmesh.core import constants as c from sqlmesh.core.config import CategorizerConfig from sqlmesh.core.dialect import parse_one from sqlmesh.core.model import SqlModel -from sqlmesh.core.snapshot import SnapshotChangeCategory from sqlmesh.core.user import User, UserRole +from sqlmesh.core.plan.definition import Plan +from sqlmesh.core.linter.rule import RuleViolation from sqlmesh.integrations.github.cicd.config import GithubCICDBotConfig, MergeMethod from sqlmesh.integrations.github.cicd.controller import ( BotCommand, - GithubCheckStatus, MergeStateStatus, + GithubCheckConclusion, ) -from tests.integrations.github.cicd.fixtures import MockIssueComment +from sqlmesh.integrations.github.cicd.controller import GithubController +from sqlmesh.integrations.github.cicd.command import _update_pr_environment +from sqlmesh.utils.date import to_datetime, now +from tests.integrations.github.cicd.conftest import MockIssueComment +from sqlmesh.utils.errors import SQLMeshError -pytest_plugins = ["tests.integrations.github.cicd.fixtures"] pytestmark = pytest.mark.github + +def add_linter_violations(controller: GithubController): + class _MockModel: + _path = "tests/linter_test.sql" + + class _MockLinterRule: + name = "mock_linter_rule" + + controller._console.show_linter_violations( + [ + RuleViolation( + rule=_MockLinterRule(), violation_msg="Linter warning", violation_range=None + ) + ], + _MockModel(), + ) + controller._console.show_linter_violations( + 
[RuleViolation(rule=_MockLinterRule(), violation_msg="Linter error", violation_range=None)], + _MockModel(), + is_error=True, + ) + + github_controller_approvers_params = [ ( "2 approvers, 1 required", @@ -237,6 +266,7 @@ def test_pr_plan(github_client, make_controller): def test_pr_plan_auto_categorization(github_client, make_controller): custom_categorizer_config = CategorizerConfig.all_semi() default_start = "1 week ago" + default_start_absolute = to_datetime(default_start, relative_base=now()) controller = make_controller( "tests/fixtures/github/pull_request_synchronized.json", github_client, @@ -250,7 +280,19 @@ def test_pr_plan_auto_categorization(github_client, make_controller): assert not controller._context.apply.called assert controller._context._run_plan_tests.call_args == call(skip_tests=True) assert controller._pr_plan_builder._categorizer_config == custom_categorizer_config - assert controller.pr_plan.start == default_start + assert controller.pr_plan.start == default_start_absolute + assert not controller.pr_plan.start_override_per_model + + +def test_pr_plan_min_intervals(github_client, make_controller): + controller = make_controller( + "tests/fixtures/github/pull_request_synchronized.json", + github_client, + bot_config=GithubCICDBotConfig(default_pr_start="1 day ago", pr_min_intervals=1), + ) + assert controller.pr_plan.environment.name == "hello_world_2" + assert isinstance(controller.pr_plan, Plan) + assert controller.pr_plan.start_override_per_model def test_prod_plan(github_client, make_controller): @@ -297,7 +339,8 @@ def test_prod_plan_with_gaps(github_client, make_controller): assert controller.prod_plan_with_gaps.environment.name == c.PROD assert not controller.prod_plan_with_gaps.skip_backfill - assert not controller._prod_plan_with_gaps_builder._auto_categorization_enabled + # auto_categorization should now be enabled to prevent uncategorized snapshot errors + assert controller._prod_plan_with_gaps_builder._auto_categorization_enabled 
assert not controller.prod_plan_with_gaps.no_gaps assert not controller._context.apply.called assert controller._context._run_plan_tests.call_args == call(skip_tests=True) @@ -418,6 +461,33 @@ def test_deploy_to_prod_merge_error(github_client, make_controller): controller.deploy_to_prod() +def test_deploy_to_prod_blocked_pr(github_client, make_controller): + mock_pull_request = github_client.get_repo().get_pull() + mock_pull_request.merged = False + controller = make_controller( + "tests/fixtures/github/pull_request_synchronized.json", + github_client, + merge_state_status=MergeStateStatus.BLOCKED, + ) + with pytest.raises( + Exception, + match=r"^Branch protection or ruleset requirement is likely not satisfied, e.g. missing CODEOWNERS approval.*", + ): + controller.deploy_to_prod() + + +def test_deploy_to_prod_not_blocked_pr_if_config_set(github_client, make_controller): + mock_pull_request = github_client.get_repo().get_pull() + mock_pull_request.merged = False + controller = make_controller( + "tests/fixtures/github/pull_request_synchronized.json", + github_client, + merge_state_status=MergeStateStatus.BLOCKED, + bot_config=GithubCICDBotConfig(check_if_blocked_on_deploy_to_prod=False), + ) + controller.deploy_to_prod() + + def test_deploy_to_prod_dirty_pr(github_client, make_controller): mock_pull_request = github_client.get_repo().get_pull() mock_pull_request.merged = False @@ -426,7 +496,10 @@ def test_deploy_to_prod_dirty_pr(github_client, make_controller): github_client, merge_state_status=MergeStateStatus.DIRTY, ) - with pytest.raises(Exception, match=r"^Merge commit cannot be cleanly created.*"): + with pytest.raises( + Exception, + match=r"^Merge commit cannot be cleanly created. 
Likely from a merge conflict.*", + ): controller.deploy_to_prod() @@ -544,16 +617,11 @@ def test_uncategorized( make_mock_issue_comment, tmp_path: pathlib.Path, ): - snapshot_categrozied = make_snapshot(SqlModel(name="a", query=parse_one("select 1, ds"))) - snapshot_categrozied.categorize_as(SnapshotChangeCategory.BREAKING) snapshot_uncategorized = make_snapshot(SqlModel(name="b", query=parse_one("select 1, ds"))) mocker.patch( - "sqlmesh.core.plan.Plan.modified_snapshots", + "sqlmesh.core.plan.Plan.uncategorized", PropertyMock( - return_value={ - snapshot_categrozied.snapshot_id: snapshot_categrozied, - snapshot_uncategorized.snapshot_id: snapshot_uncategorized, - }, + return_value=[snapshot_uncategorized], ), ) mock_repo = github_client.get_repo() @@ -569,21 +637,220 @@ def test_uncategorized( ) ) mock_issue.get_comments = mocker.MagicMock(side_effect=lambda: created_comments) + + # note: context is deliberately not mocked out so that context.apply() throws UncategorizedPlanError due to the uncategorized snapshot controller = make_controller( - "tests/fixtures/github/pull_request_synchronized.json", github_client + "tests/fixtures/github/pull_request_synchronized.json", + github_client, + mock_out_context=False, ) + assert controller.pr_plan.uncategorized github_output_file = tmp_path / "github_output.txt" with mock.patch.dict(os.environ, {"GITHUB_OUTPUT": str(github_output_file)}): - controller.update_pr_environment_check(GithubCheckStatus.COMPLETED) + _update_pr_environment(controller) assert "SQLMesh - PR Environment Synced" in controller._check_run_mapping pr_environment_check_run = controller._check_run_mapping[ "SQLMesh - PR Environment Synced" ].all_kwargs - assert len(pr_environment_check_run) == 1 + assert len(pr_environment_check_run) == 2 + assert pr_environment_check_run[0]["status"] == "in_progress" + assert pr_environment_check_run[1]["status"] == "completed" + assert pr_environment_check_run[1]["conclusion"] == "action_required" + summary = 
pr_environment_check_run[1]["output"]["summary"] + assert "Action Required to create or update PR Environment" in summary + assert "The following models could not be categorized automatically" in summary + assert '- "b"' in summary + assert "Run `sqlmesh plan hello_world_2` locally to apply these changes" in summary + + +@time_machine.travel("2025-07-07 00:00:00 UTC", tick=False) +def test_get_plan_summary_doesnt_truncate_backfill_list( + github_client, make_controller: t.Callable[..., GithubController] +): + controller = make_controller( + "tests/fixtures/github/pull_request_synchronized.json", + github_client, + mock_out_context=False, + ) + + summary = controller.get_plan_summary(controller.prod_plan) + + assert "more ...." not in summary + assert ( - pr_environment_check_run[0]["output"]["summary"] - == """
PR Environment Summary
ModelChange TypeDates Loaded
aBreakingN/A
bUncategorizedN/A
""" + """**Models needing backfill:** +* `memory.raw.demographics`: [full refresh] +* `memory.sushi.active_customers`: [full refresh] +* `memory.sushi.count_customers_active`: [full refresh] +* `memory.sushi.count_customers_inactive`: [full refresh] +* `memory.sushi.customer_revenue_by_day`: [2025-06-30 - 2025-07-06] +* `memory.sushi.customer_revenue_lifetime`: [2025-06-30 - 2025-07-06] +* `memory.sushi.customers`: [full refresh] +* `memory.sushi.items`: [2025-06-30 - 2025-07-06] +* `memory.sushi.latest_order`: [full refresh] +* `memory.sushi.marketing`: [2025-06-30 - 2025-07-06] +* `memory.sushi.order_items`: [2025-06-30 - 2025-07-06] +* `memory.sushi.orders`: [2025-06-30 - 2025-07-06] +* `memory.sushi.raw_marketing`: [full refresh] +* `memory.sushi.top_waiters`: [recreate view] +* `memory.sushi.waiter_as_customer_by_day`: [2025-06-30 - 2025-07-06] +* `memory.sushi.waiter_names`: [full refresh] +* `memory.sushi.waiter_revenue_by_day`: [2025-06-30 - 2025-07-06]""" + in summary ) + + +def test_get_plan_summary_includes_warnings_and_errors( + github_client, make_controller: t.Callable[..., GithubController] +): + controller = make_controller( + "tests/fixtures/github/pull_request_synchronized.json", + github_client, + mock_out_context=False, + ) + + controller._console.log_warning("Warning 1\nWith multiline") + controller._console.log_warning("Warning 2") + controller._console.log_error("Error 1") + add_linter_violations(controller) + + summary = controller.get_plan_summary(controller.prod_plan) + + assert ("> [!WARNING]\n>\n> - Warning 1\n> With multiline\n>\n> - Warning 2\n>\n>") in summary + assert ( + "> Linter warnings for `tests/linter_test.sql`:\n> - mock_linter_rule: Linter warning\n>" + ) in summary + assert ("> [!CAUTION]\n>\n> - Error 1\n>\n>") in summary + assert ( + "> Linter **errors** for `tests/linter_test.sql`:\n> - mock_linter_rule: Linter error\n>" + ) in summary + + +def test_get_pr_environment_summary_includes_warnings_and_errors( + 
github_client, make_controller: t.Callable[..., GithubController] +): + controller = make_controller( + "tests/fixtures/github/pull_request_synchronized.json", + github_client, + mock_out_context=False, + ) + + controller._console.log_warning("Warning 1") + controller._console.log_error("Error 1") + add_linter_violations(controller) + + # completed with no exception triggers a SUCCESS conclusion and only shows warnings + success_summary = controller.get_pr_environment_summary( + conclusion=GithubCheckConclusion.SUCCESS + ) + assert "> [!WARNING]\n>\n> - Warning 1\n" in success_summary + assert ( + "> Linter warnings for `tests/linter_test.sql`:\n> - mock_linter_rule: Linter warning\n" + in success_summary + ) + assert "Error 1" not in success_summary + assert "mock_linter_rule: Linter error" not in success_summary + + # since they got consumed in the previous call + controller._console.log_warning("Warning 1") + controller._console.log_error("Error 1") + add_linter_violations(controller) + + # completed with an exception triggers a FAILED conclusion and shows errors + error_summary = controller.get_pr_environment_summary( + conclusion=GithubCheckConclusion.FAILURE, exception=SQLMeshError("Something broke") + ) + assert "> [!WARNING]\n>\n> - Warning 1\n>\n" in error_summary + assert ( + "> Linter warnings for `tests/linter_test.sql`:\n> - mock_linter_rule: Linter warning\n" + in error_summary + ) + assert "[!CAUTION]\n>
\n>\n> - Error 1\n>\n" in error_summary + assert ( + "> Linter **errors** for `tests/linter_test.sql`:\n> - mock_linter_rule: Linter error\n" + in error_summary + ) + + +def test_pr_comment_deploy_indicator_includes_command_namespace( + mocker: MockerFixture, + github_client, + make_mock_issue_comment, + make_controller: t.Callable[..., GithubController], +): + mock_repo = github_client.get_repo() + + created_comments = [] + mock_issue = mock_repo.get_issue() + mock_issue.create_comment = mocker.MagicMock( + side_effect=lambda comment: make_mock_issue_comment( + comment=comment, created_comments=created_comments + ) + ) + mock_issue.get_comments = mocker.MagicMock(side_effect=lambda: created_comments) + + controller = make_controller( + "tests/fixtures/github/pull_request_synchronized.json", + github_client, + mock_out_context=False, + bot_config=GithubCICDBotConfig( + enable_deploy_command=True, + merge_method=MergeMethod.SQUASH, + command_namespace="#SQLMesh", + ), + ) + + _update_pr_environment(controller) + + assert len(created_comments) > 0 + + comment = created_comments[0].body + + assert "To **apply** this PR's plan to prod, comment:\n - `/deploy`" not in comment + assert "To **apply** this PR's plan to prod, comment:\n - `#SQLMesh/deploy`" in comment + + +def test_forward_only_config_falls_back_to_plan_config( + github_client, + make_controller: t.Callable[..., GithubController], + mocker: MockerFixture, +): + mock_repo = github_client.get_repo() + mock_repo.create_check_run = mocker.MagicMock( + side_effect=lambda **kwargs: make_mock_check_run(**kwargs) + ) + + created_comments = [] + mock_issue = mock_repo.get_issue() + mock_issue.create_comment = mocker.MagicMock( + side_effect=lambda comment: make_mock_issue_comment( + comment=comment, created_comments=created_comments + ) + ) + mock_issue.get_comments = mocker.MagicMock(side_effect=lambda: created_comments) + + mock_pull_request = mock_repo.get_pull() + mock_pull_request.get_reviews = 
mocker.MagicMock(lambda: []) + mock_pull_request.merged = False + mock_pull_request.merge = mocker.MagicMock() + mock_pull_request.head.ref = "unit-test-test-pr" + + controller = make_controller( + "tests/fixtures/github/pull_request_synchronized.json", + github_client, + bot_config=GithubCICDBotConfig( + merge_method=MergeMethod.SQUASH, + enable_deploy_command=True, + forward_only_branch_suffix="-forward-only", + ), + mock_out_context=False, + ) + + controller._context.config.plan.forward_only = True + assert controller.forward_only_plan + + controller._context.config.plan.forward_only = False + assert controller.forward_only_plan is False diff --git a/tests/integrations/github/cicd/test_github_event.py b/tests/integrations/github/cicd/test_github_event.py index 719b60fb3f..88979b05a8 100644 --- a/tests/integrations/github/cicd/test_github_event.py +++ b/tests/integrations/github/cicd/test_github_event.py @@ -1,6 +1,5 @@ import pytest -pytest_plugins = ["tests.integrations.github.cicd.fixtures"] pytestmark = pytest.mark.github diff --git a/tests/integrations/github/cicd/test_integration.py b/tests/integrations/github/cicd/test_integration.py index 51a0367b9f..ce357f6d36 100644 --- a/tests/integrations/github/cicd/test_integration.py +++ b/tests/integrations/github/cicd/test_integration.py @@ -9,13 +9,14 @@ from unittest import mock import pytest -from freezegun import freeze_time +import time_machine from pytest_mock.plugin import MockerFixture from sqlglot import exp -from sqlmesh.core.config import CategorizerConfig +from sqlmesh.core.config import CategorizerConfig, Config, ModelDefaultsConfig, LinterConfig from sqlmesh.core.engine_adapter.shared import DataObject from sqlmesh.core.user import User, UserRole +from sqlmesh.core.model.common import ParsableSql from sqlmesh.integrations.github.cicd import command from sqlmesh.integrations.github.cicd.config import GithubCICDBotConfig, MergeMethod from sqlmesh.integrations.github.cicd.controller import ( @@ -23,9 
+24,9 @@ GithubCheckStatus, GithubController, ) -from tests.integrations.github.cicd.fixtures import MockIssueComment +from sqlmesh.utils.errors import CICDBotError, SQLMeshError +from tests.integrations.github.cicd.conftest import MockIssueComment -pytest_plugins = ["tests.integrations.github.cicd.fixtures"] pytestmark = [ pytest.mark.slow, pytest.mark.github, @@ -37,11 +38,9 @@ def get_environment_objects(controller: GithubController, environment: str) -> t def get_num_days_loaded(controller: GithubController, environment: str, model: str) -> int: - return int( - controller._context.engine_adapter.fetchdf( - f"SELECT distinct event_date FROM sushi__{environment}.{model}" - ).count() - ) + return controller._context.engine_adapter.fetchdf( + f"SELECT distinct event_date FROM sushi__{environment}.{model}" + ).shape[0] def get_columns( @@ -51,7 +50,147 @@ def get_columns( return controller._context.engine_adapter.columns(table) -@freeze_time("2023-01-01 15:00:00") +@time_machine.travel("2023-01-01 15:00:00 UTC") +def test_linter( + github_client, + make_controller, + make_mock_check_run, + make_mock_issue_comment, + make_pull_request_review, + tmp_path: pathlib.Path, + mocker: MockerFixture, +): + """ + PR with the Linter enabled will contain a new check with the linter specific output. 
+ + Scenarios: + - PR with linter errors leads to job failures & skips + - PR with linter warnings leads to job successes + """ + mock_repo = github_client.get_repo() + mock_repo.create_check_run = mocker.MagicMock( + side_effect=lambda **kwargs: make_mock_check_run(**kwargs) + ) + + created_comments: t.List[MockIssueComment] = [] + mock_issue = mock_repo.get_issue() + mock_issue.create_comment = mocker.MagicMock( + side_effect=lambda comment: make_mock_issue_comment( + comment=comment, created_comments=created_comments + ) + ) + mock_issue.get_comments = mocker.MagicMock(side_effect=lambda: created_comments) + + mock_pull_request = mock_repo.get_pull() + mock_pull_request.get_reviews = mocker.MagicMock( + side_effect=lambda: [make_pull_request_review(username="test_github", state="APPROVED")] + ) + mock_pull_request.merged = False + mock_pull_request.merge = mocker.MagicMock() + + before_all = [ + "CREATE SCHEMA IF NOT EXISTS raw", + "DROP VIEW IF EXISTS raw.demographics", + "CREATE VIEW raw.demographics AS (SELECT 1 AS customer_id, '00000' AS zip)", + ] + + # Case 1: Test for linter errors + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + linter=LinterConfig(enabled=True, rules="ALL"), + before_all=before_all, + ) + + controller = make_controller( + "tests/fixtures/github/pull_request_synchronized.json", + github_client, + bot_config=GithubCICDBotConfig( + merge_method=MergeMethod.MERGE, + invalidate_environment_after_deploy=False, + auto_categorize_changes=CategorizerConfig.all_full(), + default_pr_start=None, + skip_pr_backfill=False, + ), + mock_out_context=False, + config=config, + ) + + github_output_file = tmp_path / "github_output.txt" + + with mock.patch.dict(os.environ, {"GITHUB_OUTPUT": str(github_output_file)}): + with pytest.raises(CICDBotError): + command._run_all(controller) + + assert "SQLMesh - Linter" in controller._check_run_mapping + linter_checks_runs = controller._check_run_mapping["SQLMesh - Linter"].all_kwargs + 
assert "Linter **errors** for" in linter_checks_runs[2]["output"]["summary"] + assert GithubCheckConclusion(linter_checks_runs[2]["conclusion"]).is_failure + + for check in ( + "SQLMesh - PR Environment Synced", + "SQLMesh - Prod Plan Preview", + ): + assert check in controller._check_run_mapping + check_runs = controller._check_run_mapping[check].all_kwargs + assert GithubCheckConclusion(check_runs[-1]["conclusion"]).is_skipped + + with open(github_output_file, "r", encoding="utf-8") as f: + output = f.read() + assert ( + output + == "linter=failure\nrun_unit_tests=success\npr_environment_name=hello_world_2\npr_environment_synced=skipped\nprod_plan_preview=skipped\n" + ) + + # empty github file for next case + open(github_output_file, "w").close() + + # Case 2: Test for linter warnings + config = Config( + model_defaults=ModelDefaultsConfig(dialect="duckdb"), + linter=LinterConfig(enabled=True, warn_rules="ALL"), + before_all=before_all, + ) + + controller = make_controller( + "tests/fixtures/github/pull_request_synchronized.json", + github_client, + bot_config=GithubCICDBotConfig( + merge_method=MergeMethod.MERGE, + invalidate_environment_after_deploy=False, + auto_categorize_changes=CategorizerConfig.all_full(), + default_pr_start=None, + skip_pr_backfill=False, + ), + mock_out_context=False, + config=config, + ) + + with mock.patch.dict(os.environ, {"GITHUB_OUTPUT": str(github_output_file)}): + command._run_all(controller) + + assert "SQLMesh - Linter" in controller._check_run_mapping + linter_checks_runs = controller._check_run_mapping["SQLMesh - Linter"].all_kwargs + assert "Linter warnings for" in linter_checks_runs[-1]["output"]["summary"] + assert GithubCheckConclusion(linter_checks_runs[-1]["conclusion"]).is_success + + for check in ( + "SQLMesh - Run Unit Tests", + "SQLMesh - PR Environment Synced", + "SQLMesh - Prod Plan Preview", + ): + assert check in controller._check_run_mapping + check_runs = controller._check_run_mapping[check].all_kwargs + assert 
GithubCheckConclusion(check_runs[-1]["conclusion"]).is_success + + with open(github_output_file, "r", encoding="utf-8") as f: + output = f.read() + assert ( + output + == "linter=success\nrun_unit_tests=success\ncreated_pr_environment=true\npr_environment_name=hello_world_2\npr_environment_synced=success\nprod_plan_preview=success\n" + ) + + +@time_machine.travel("2023-01-01 15:00:00 UTC") def test_merge_pr_has_non_breaking_change( github_client, make_controller, @@ -111,8 +250,10 @@ def test_merge_pr_has_non_breaking_change( ] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) github_output_file = tmp_path / "github_output.txt" @@ -140,9 +281,19 @@ def test_merge_pr_has_non_breaking_change( assert GithubCheckStatus(pr_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(pr_checks_runs[2]["conclusion"]).is_success assert pr_checks_runs[2]["output"]["title"] == "PR Virtual Data Environment: hello_world_2" + pr_env_summary = pr_checks_runs[2]["output"]["summary"] + assert ( + """### Directly Modified +- `memory.sushi.waiter_revenue_by_day` (Non-breaking) + **Kind:** INCREMENTAL_BY_TIME_RANGE + **Dates loaded in PR:** [2022-12-25 - 2022-12-31]""" + in pr_env_summary + ) assert ( - pr_checks_runs[2]["output"]["summary"] - == """
PR Environment Summary
ModelChange TypeDates Loaded
sushi.waiter_revenue_by_dayNon-breaking2022-12-25 - 2022-12-31
""" + """### Indirectly Modified +- `memory.sushi.top_waiters` (Indirect Non-breaking) + **Kind:** VIEW [recreate view]""" + in pr_env_summary ) assert "SQLMesh - Prod Plan Preview" in controller._check_run_mapping @@ -154,35 +305,42 @@ def test_merge_pr_has_non_breaking_change( assert GithubCheckStatus(prod_plan_preview_checks_runs[1]["status"]).is_in_progress assert GithubCheckStatus(prod_plan_preview_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(prod_plan_preview_checks_runs[2]["conclusion"]).is_success - expected_prod_plan_summary = """**Summary of differences against `prod`:** - - -**Directly Modified:** -- `sushi.waiter_revenue_by_day` -```diff ---- - -+++ - -@@ -15,7 +15,8 @@ - - SELECT - CAST(o.waiter_id AS INT) AS waiter_id, - CAST(SUM(oi.quantity * i.price) AS DOUBLE) AS revenue, -- CAST(o.event_date AS DATE) AS event_date -+ CAST(o.event_date AS DATE) AS event_date, -+ 1 AS new_col - FROM sushi.orders AS o - LEFT JOIN sushi.order_items AS oi - ON o.id = oi.order_id AND o.event_date = oi.event_date -``` - -**Indirectly Modified:** -- `sushi.top_waiters` + expected_prod_plan_directly_modified_summary = """**Directly Modified:** +* `memory.sushi.waiter_revenue_by_day` (Non-breaking) + + ```diff + --- + + +++ + + @@ -17,7 +17,8 @@ + + SELECT + CAST(o.waiter_id AS INT) AS waiter_id, + CAST(SUM(oi.quantity * i.price) AS DOUBLE) AS revenue, + - CAST(o.event_date AS DATE) AS event_date + + CAST(o.event_date AS DATE) AS event_date, + + 1 AS new_col + FROM sushi.orders AS o + LEFT JOIN sushi.order_items AS oi + ON o.id = oi.order_id AND o.event_date = oi.event_date + ``` + Indirectly Modified Children: + - `memory.sushi.top_waiters` (Indirect Non-breaking) """ + expected_prod_plan_indirectly_modified_summary = """**Indirectly Modified:** +- `memory.sushi.top_waiters` (Indirect Non-breaking) +""" + assert prod_plan_preview_checks_runs[2]["output"]["title"] == "Prod Plan Preview" - assert prod_plan_preview_checks_runs[2]["output"]["summary"] == 
expected_prod_plan_summary + prod_plan_preview_summary = prod_plan_preview_checks_runs[2]["output"]["summary"] + assert ( + "This is a preview that shows the differences between this PR environment `hello_world_2` and `prod`" + in prod_plan_preview_summary + ) + assert expected_prod_plan_directly_modified_summary in prod_plan_preview_summary + assert expected_prod_plan_indirectly_modified_summary in prod_plan_preview_summary assert "SQLMesh - Prod Environment Synced" in controller._check_run_mapping prod_checks_runs = controller._check_run_mapping["SQLMesh - Prod Environment Synced"].all_kwargs @@ -192,10 +350,10 @@ def test_merge_pr_has_non_breaking_change( assert GithubCheckStatus(prod_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(prod_checks_runs[2]["conclusion"]).is_success assert prod_checks_runs[2]["output"]["title"] == "Deployed to Prod" - assert ( - prod_checks_runs[2]["output"]["summary"] - == "**Generated Prod Plan**\n" + expected_prod_plan_summary - ) + prod_environment_synced_summary = prod_checks_runs[2]["output"]["summary"] + assert "**Generated Prod Plan**" in prod_environment_synced_summary + assert expected_prod_plan_directly_modified_summary in prod_environment_synced_summary + assert expected_prod_plan_indirectly_modified_summary in prod_environment_synced_summary assert "SQLMesh - Has Required Approval" in controller._check_run_mapping approval_checks_runs = controller._check_run_mapping[ @@ -225,19 +383,15 @@ def test_merge_pr_has_non_breaking_change( assert mock_pull_request.merge.called assert len(created_comments) == 1 + comment_body = created_comments[0].body assert ( - created_comments[0].body - == f""":robot: **SQLMesh Bot Info** :robot: + """:robot: **SQLMesh Bot Info** :robot: - :eyes: To **review** this PR's changes, use virtual data environment: - - `hello_world_2` -
- :ship: Prod Plan Being Applied - -{expected_prod_plan_summary} -
- -""" + - `hello_world_2`""" + in comment_body ) + assert expected_prod_plan_directly_modified_summary in comment_body + assert expected_prod_plan_indirectly_modified_summary in comment_body with open(github_output_file, "r", encoding="utf-8") as f: output = f.read() @@ -247,7 +401,7 @@ def test_merge_pr_has_non_breaking_change( ) -@freeze_time("2023-01-01 15:00:00") +@time_machine.travel("2023-01-01 15:00:00 UTC") def test_merge_pr_has_non_breaking_change_diff_start( github_client, make_controller, @@ -307,8 +461,10 @@ def test_merge_pr_has_non_breaking_change_diff_start( ] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) github_output_file = tmp_path / "github_output.txt" @@ -336,9 +492,20 @@ def test_merge_pr_has_non_breaking_change_diff_start( assert GithubCheckStatus(pr_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(pr_checks_runs[2]["conclusion"]).is_success assert pr_checks_runs[2]["output"]["title"] == "PR Virtual Data Environment: hello_world_2" + pr_env_summary = pr_checks_runs[2]["output"]["summary"] + assert ( + """### Directly Modified +- `memory.sushi.waiter_revenue_by_day` (Non-breaking) + **Kind:** INCREMENTAL_BY_TIME_RANGE + **Dates loaded in PR:** [2022-12-29 - 2022-12-31] + **Dates *not* loaded in PR:** [2022-12-25 - 2022-12-28]""" + in pr_env_summary + ) assert ( - pr_checks_runs[2]["output"]["summary"] - == """
PR Environment Summary
ModelChange TypeDates Loaded
sushi.waiter_revenue_by_dayNon-breaking2022-12-29 - 2022-12-31
""" + """### Indirectly Modified +- `memory.sushi.top_waiters` (Indirect Non-breaking) + **Kind:** VIEW [recreate view]""" + in pr_env_summary ) assert "SQLMesh - Prod Plan Preview" in controller._check_run_mapping @@ -351,37 +518,41 @@ def test_merge_pr_has_non_breaking_change_diff_start( assert GithubCheckStatus(prod_plan_preview_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(prod_plan_preview_checks_runs[2]["conclusion"]).is_success assert prod_plan_preview_checks_runs[2]["output"]["title"] == "Prod Plan Preview" - expected_prod_plan = """**Summary of differences against `prod`:** - - -**Directly Modified:** -- `sushi.waiter_revenue_by_day` -```diff ---- - -+++ - -@@ -15,7 +15,8 @@ - - SELECT - CAST(o.waiter_id AS INT) AS waiter_id, - CAST(SUM(oi.quantity * i.price) AS DOUBLE) AS revenue, -- CAST(o.event_date AS DATE) AS event_date -+ CAST(o.event_date AS DATE) AS event_date, -+ 1 AS new_col - FROM sushi.orders AS o - LEFT JOIN sushi.order_items AS oi - ON o.id = oi.order_id AND o.event_date = oi.event_date -``` -**Indirectly Modified:** -- `sushi.top_waiters` - - -**Models needing backfill (missing dates):** -* `sushi.waiter_revenue_by_day`: 2022-12-25 - 2022-12-28 + expected_prod_plan_directly_modified_summary = """**Directly Modified:** +* `memory.sushi.waiter_revenue_by_day` (Non-breaking) + + ```diff + --- + + +++ + + @@ -17,7 +17,8 @@ + + SELECT + CAST(o.waiter_id AS INT) AS waiter_id, + CAST(SUM(oi.quantity * i.price) AS DOUBLE) AS revenue, + - CAST(o.event_date AS DATE) AS event_date + + CAST(o.event_date AS DATE) AS event_date, + + 1 AS new_col + FROM sushi.orders AS o + LEFT JOIN sushi.order_items AS oi + ON o.id = oi.order_id AND o.event_date = oi.event_date + ``` + Indirectly Modified Children: + - `memory.sushi.top_waiters` (Indirect Non-breaking) """ - assert prod_plan_preview_checks_runs[2]["output"]["summary"] == expected_prod_plan + expected_prod_plan_indirectly_modified_summary = """**Indirectly Modified:** +- 
`memory.sushi.top_waiters` (Indirect Non-breaking) +""" + + prod_plan_preview_summary = prod_plan_preview_checks_runs[2]["output"]["summary"] + assert ( + "This is a preview that shows the differences between this PR environment `hello_world_2` and `prod`" + in prod_plan_preview_summary + ) + assert expected_prod_plan_directly_modified_summary in prod_plan_preview_summary + assert expected_prod_plan_indirectly_modified_summary in prod_plan_preview_summary assert "SQLMesh - Prod Environment Synced" in controller._check_run_mapping prod_checks_runs = controller._check_run_mapping["SQLMesh - Prod Environment Synced"].all_kwargs @@ -391,9 +562,10 @@ def test_merge_pr_has_non_breaking_change_diff_start( assert GithubCheckStatus(prod_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(prod_checks_runs[2]["conclusion"]).is_success assert prod_checks_runs[2]["output"]["title"] == "Deployed to Prod" - assert ( - prod_checks_runs[2]["output"]["summary"] == "**Generated Prod Plan**\n" + expected_prod_plan - ) + prod_environment_synced_summary = prod_checks_runs[2]["output"]["summary"] + assert "**Generated Prod Plan**" in prod_environment_synced_summary + assert expected_prod_plan_directly_modified_summary in prod_environment_synced_summary + assert expected_prod_plan_indirectly_modified_summary in prod_environment_synced_summary assert "SQLMesh - Has Required Approval" in controller._check_run_mapping approval_checks_runs = controller._check_run_mapping[ @@ -424,19 +596,15 @@ def test_merge_pr_has_non_breaking_change_diff_start( assert mock_pull_request.merge.called assert len(created_comments) == 1 + comment_body = created_comments[0].body assert ( - created_comments[0].body - == f""":robot: **SQLMesh Bot Info** :robot: + """:robot: **SQLMesh Bot Info** :robot: - :eyes: To **review** this PR's changes, use virtual data environment: - - `hello_world_2` -
- :ship: Prod Plan Being Applied - -{expected_prod_plan} -
- -""" + - `hello_world_2`""" + in comment_body ) + assert expected_prod_plan_directly_modified_summary in comment_body + assert expected_prod_plan_indirectly_modified_summary in comment_body with open(github_output_file, "r", encoding="utf-8") as f: output = f.read() @@ -446,7 +614,7 @@ def test_merge_pr_has_non_breaking_change_diff_start( ) -@freeze_time("2023-01-01 15:00:00") +@time_machine.travel("2023-01-01 15:00:00 UTC") def test_merge_pr_has_non_breaking_change_no_categorization( github_client, make_controller, @@ -503,13 +671,16 @@ def test_merge_pr_has_non_breaking_change_no_categorization( ] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) github_output_file = tmp_path / "github_output.txt" with mock.patch.dict(os.environ, {"GITHUB_OUTPUT": str(github_output_file)}): - command._run_all(controller) + with pytest.raises(CICDBotError): + command._run_all(controller) assert "SQLMesh - Run Unit Tests" in controller._check_run_mapping test_checks_runs = controller._check_run_mapping["SQLMesh - Run Unit Tests"].all_kwargs @@ -532,9 +703,15 @@ def test_merge_pr_has_non_breaking_change_no_categorization( assert GithubCheckStatus(pr_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(pr_checks_runs[2]["conclusion"]).is_action_required assert pr_checks_runs[2]["output"]["title"] == "PR Virtual Data Environment: hello_world_2" + pr_env_summary = pr_checks_runs[2]["output"]["summary"] assert ( - pr_checks_runs[2]["output"]["summary"] - == ":warning: Action Required to create or update PR Environment `hello_world_2`. There are likely uncateogrized changes. Run `plan` locally to apply these changes. 
If you want the bot to automatically categorize changes, then check documentation (https://sqlmesh.readthedocs.io/en/stable/integrations/github/) for more information." + """:warning: Action Required to create or update PR Environment `hello_world_2` :warning: + +The following models could not be categorized automatically: +- "memory"."sushi"."waiter_revenue_by_day" + +Run `sqlmesh plan hello_world_2` locally to apply these changes""" + in pr_env_summary ) assert "SQLMesh - Prod Plan Preview" in controller._check_run_mapping @@ -560,9 +737,11 @@ def test_merge_pr_has_non_breaking_change_no_categorization( assert GithubCheckStatus(prod_checks_runs[0]["status"]).is_queued assert GithubCheckStatus(prod_checks_runs[1]["status"]).is_completed assert GithubCheckConclusion(prod_checks_runs[1]["conclusion"]).is_skipped - skip_reason = "Skipped Deploying to Production because the PR environment was not updated" - assert prod_checks_runs[1]["output"]["title"] == skip_reason - assert prod_checks_runs[1]["output"]["summary"] == skip_reason + assert prod_checks_runs[1]["output"]["title"] == "Skipped deployment" + assert ( + prod_checks_runs[1]["output"]["summary"] + == "Skipped Deploying to Production because the PR environment was not updated" + ) assert "SQLMesh - Has Required Approval" in controller._check_run_mapping approval_checks_runs = controller._check_run_mapping[ @@ -680,8 +859,8 @@ def test_merge_pr_has_no_changes( assert GithubCheckConclusion(pr_checks_runs[2]["conclusion"]).is_skipped assert pr_checks_runs[2]["output"]["title"] == "PR Virtual Data Environment: hello_world_2" assert ( - pr_checks_runs[2]["output"]["summary"] - == ":next_track_button: Skipped creating or updating PR Environment `hello_world_2`. No changes were detected compared to the prod environment." + ":next_track_button: Skipped creating or updating PR Environment `hello_world_2` :next_track_button:\n\nNo changes were detected compared to the prod environment." 
+ in pr_checks_runs[2]["output"]["summary"] ) assert "SQLMesh - Prod Plan Preview" in controller._check_run_mapping @@ -693,9 +872,11 @@ def test_merge_pr_has_no_changes( assert GithubCheckStatus(prod_plan_preview_checks_runs[1]["status"]).is_in_progress assert GithubCheckStatus(prod_plan_preview_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(prod_plan_preview_checks_runs[2]["conclusion"]).is_success - expected_prod_plan_summary = "**No differences when compared to `prod`**\n\n\n" + expected_prod_plan_summary = ( + "**No changes to plan: project files match the `prod` environment**" + ) assert prod_plan_preview_checks_runs[2]["output"]["title"] == "Prod Plan Preview" - assert prod_plan_preview_checks_runs[2]["output"]["summary"] == expected_prod_plan_summary + assert expected_prod_plan_summary in prod_plan_preview_checks_runs[2]["output"]["summary"] assert "SQLMesh - Prod Environment Synced" in controller._check_run_mapping prod_checks_runs = controller._check_run_mapping["SQLMesh - Prod Environment Synced"].all_kwargs @@ -705,10 +886,9 @@ def test_merge_pr_has_no_changes( assert GithubCheckStatus(prod_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(prod_checks_runs[2]["conclusion"]).is_success assert prod_checks_runs[2]["output"]["title"] == "Deployed to Prod" - assert ( - prod_checks_runs[2]["output"]["summary"] - == "**Generated Prod Plan**\n" + expected_prod_plan_summary - ) + prod_environment_synced_summary = prod_checks_runs[2]["output"]["summary"] + assert "**Generated Prod Plan**" in prod_environment_synced_summary + assert expected_prod_plan_summary in prod_environment_synced_summary assert "SQLMesh - Has Required Approval" in controller._check_run_mapping approval_checks_runs = controller._check_run_mapping[ @@ -735,17 +915,14 @@ def test_merge_pr_has_no_changes( assert mock_pull_request.merge.called assert len(created_comments) == 1 + comment_body = created_comments[0].body assert ( - created_comments[0].body - == 
f""":robot: **SQLMesh Bot Info** :robot: + f""":robot: **SQLMesh Bot Info** :robot:
- :ship: Prod Plan Being Applied - -{expected_prod_plan_summary} -
- -""" + :ship: Prod Plan Being Applied""" + in comment_body ) + assert expected_prod_plan_summary in comment_body with open(github_output_file, "r", encoding="utf-8") as f: output = f.read() @@ -755,7 +932,7 @@ def test_merge_pr_has_no_changes( ) -@freeze_time("2023-01-01 15:00:00") +@time_machine.travel("2023-01-01 15:00:00 UTC") def test_no_merge_since_no_deploy_signal( github_client, make_controller, @@ -813,8 +990,10 @@ def test_no_merge_since_no_deploy_signal( ] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) github_output_file = tmp_path / "github_output.txt" @@ -842,9 +1021,19 @@ def test_no_merge_since_no_deploy_signal( assert GithubCheckStatus(pr_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(pr_checks_runs[2]["conclusion"]).is_success assert pr_checks_runs[2]["output"]["title"] == "PR Virtual Data Environment: hello_world_2" + pr_env_summary = pr_checks_runs[2]["output"]["summary"] assert ( - pr_checks_runs[2]["output"]["summary"] - == """
PR Environment Summary
ModelChange TypeDates Loaded
sushi.waiter_revenue_by_dayNon-breaking2022-12-25 - 2022-12-31
""" + """### Directly Modified +- `memory.sushi.waiter_revenue_by_day` (Non-breaking) + **Kind:** INCREMENTAL_BY_TIME_RANGE + **Dates loaded in PR:** [2022-12-25 - 2022-12-31]""" + in pr_env_summary + ) + assert ( + """### Indirectly Modified +- `memory.sushi.top_waiters` (Indirect Non-breaking) + **Kind:** VIEW [recreate view]""" + in pr_env_summary ) assert "SQLMesh - Prod Plan Preview" in controller._check_run_mapping @@ -856,35 +1045,42 @@ def test_no_merge_since_no_deploy_signal( assert GithubCheckStatus(prod_plan_preview_checks_runs[1]["status"]).is_in_progress assert GithubCheckStatus(prod_plan_preview_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(prod_plan_preview_checks_runs[2]["conclusion"]).is_success - expected_prod_plan = """**Summary of differences against `prod`:** - - -**Directly Modified:** -- `sushi.waiter_revenue_by_day` -```diff ---- - -+++ - -@@ -15,7 +15,8 @@ - - SELECT - CAST(o.waiter_id AS INT) AS waiter_id, - CAST(SUM(oi.quantity * i.price) AS DOUBLE) AS revenue, -- CAST(o.event_date AS DATE) AS event_date -+ CAST(o.event_date AS DATE) AS event_date, -+ 1 AS new_col - FROM sushi.orders AS o - LEFT JOIN sushi.order_items AS oi - ON o.id = oi.order_id AND o.event_date = oi.event_date -``` - -**Indirectly Modified:** -- `sushi.top_waiters` + expected_prod_plan_directly_modified_summary = """**Directly Modified:** +* `memory.sushi.waiter_revenue_by_day` (Non-breaking) + + ```diff + --- + + +++ + + @@ -17,7 +17,8 @@ + + SELECT + CAST(o.waiter_id AS INT) AS waiter_id, + CAST(SUM(oi.quantity * i.price) AS DOUBLE) AS revenue, + - CAST(o.event_date AS DATE) AS event_date + + CAST(o.event_date AS DATE) AS event_date, + + 1 AS new_col + FROM sushi.orders AS o + LEFT JOIN sushi.order_items AS oi + ON o.id = oi.order_id AND o.event_date = oi.event_date + ``` + Indirectly Modified Children: + - `memory.sushi.top_waiters` (Indirect Non-breaking)""" + + expected_prod_plan_indirectly_modified_summary = """**Indirectly Modified:** +- 
`memory.sushi.top_waiters` (Indirect Non-breaking) """ + assert prod_plan_preview_checks_runs[2]["output"]["title"] == "Prod Plan Preview" - assert prod_plan_preview_checks_runs[2]["output"]["summary"] == expected_prod_plan + prod_plan_preview_summary = prod_plan_preview_checks_runs[2]["output"]["summary"] + assert ( + "This is a preview that shows the differences between this PR environment `hello_world_2` and `prod`" + in prod_plan_preview_summary + ) + assert expected_prod_plan_directly_modified_summary in prod_plan_preview_summary + assert expected_prod_plan_indirectly_modified_summary in prod_plan_preview_summary assert "SQLMesh - Prod Environment Synced" in controller._check_run_mapping prod_checks_runs = controller._check_run_mapping["SQLMesh - Prod Environment Synced"].all_kwargs @@ -892,9 +1088,11 @@ def test_no_merge_since_no_deploy_signal( assert GithubCheckStatus(prod_checks_runs[0]["status"]).is_queued assert GithubCheckStatus(prod_checks_runs[1]["status"]).is_completed assert GithubCheckConclusion(prod_checks_runs[1]["conclusion"]).is_skipped - skip_reason = "Skipped Deploying to Production because a required approver has not approved" - assert prod_checks_runs[1]["output"]["title"] == skip_reason - assert prod_checks_runs[1]["output"]["summary"] == skip_reason + assert prod_checks_runs[1]["output"]["title"] == "Skipped deployment" + assert ( + prod_checks_runs[1]["output"]["summary"] + == "Skipped Deploying to Production because a required approver has not approved" + ) assert "SQLMesh - Has Required Approval" in controller._check_run_mapping approval_checks_runs = controller._check_run_mapping[ @@ -904,7 +1102,7 @@ def test_no_merge_since_no_deploy_signal( assert GithubCheckStatus(approval_checks_runs[0]["status"]).is_queued assert GithubCheckStatus(approval_checks_runs[1]["status"]).is_in_progress assert GithubCheckStatus(approval_checks_runs[2]["status"]).is_completed - assert GithubCheckConclusion(approval_checks_runs[2]["conclusion"]).is_neutral 
+ assert GithubCheckConclusion(approval_checks_runs[2]["conclusion"]).is_failure assert approval_checks_runs[2]["output"]["title"] == "Need a Required Approval" assert ( approval_checks_runs[2]["output"]["summary"] @@ -932,11 +1130,11 @@ def test_no_merge_since_no_deploy_signal( output = f.read() assert ( output - == "run_unit_tests=success\nhas_required_approval=neutral\ncreated_pr_environment=true\npr_environment_name=hello_world_2\npr_environment_synced=success\nprod_plan_preview=success\nprod_environment_synced=skipped\n" + == "run_unit_tests=success\nhas_required_approval=failure\ncreated_pr_environment=true\npr_environment_name=hello_world_2\npr_environment_synced=success\nprod_plan_preview=success\nprod_environment_synced=skipped\n" ) -@freeze_time("2023-01-01 15:00:00") +@time_machine.travel("2023-01-01 15:00:00 UTC") def test_no_merge_since_no_deploy_signal_no_approvers_defined( github_client, make_controller, @@ -994,8 +1192,10 @@ def test_no_merge_since_no_deploy_signal_no_approvers_defined( controller._context.users = [User(username="test", github_username="test_github", roles=[])] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) github_output_file = tmp_path / "github_output.txt" @@ -1023,9 +1223,20 @@ def test_no_merge_since_no_deploy_signal_no_approvers_defined( assert GithubCheckStatus(pr_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(pr_checks_runs[2]["conclusion"]).is_success assert pr_checks_runs[2]["output"]["title"] == "PR Virtual Data Environment: hello_world_2" + pr_env_summary = pr_checks_runs[2]["output"]["summary"] + assert ( + """### Directly Modified +- `memory.sushi.waiter_revenue_by_day` (Non-breaking) + **Kind:** 
INCREMENTAL_BY_TIME_RANGE + **Dates loaded in PR:** [2022-12-30 - 2022-12-31] + **Dates *not* loaded in PR:** [2022-12-25 - 2022-12-29]""" + in pr_env_summary + ) assert ( - pr_checks_runs[2]["output"]["summary"] - == """
PR Environment Summary
ModelChange TypeDates Loaded
sushi.waiter_revenue_by_dayNon-breaking2022-12-30 - 2022-12-31
""" + """### Indirectly Modified +- `memory.sushi.top_waiters` (Indirect Non-breaking) + **Kind:** VIEW [recreate view]""" + in pr_env_summary ) assert "SQLMesh - Prod Plan Preview" in controller._check_run_mapping @@ -1037,38 +1248,40 @@ def test_no_merge_since_no_deploy_signal_no_approvers_defined( assert GithubCheckStatus(prod_plan_preview_checks_runs[1]["status"]).is_in_progress assert GithubCheckStatus(prod_plan_preview_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(prod_plan_preview_checks_runs[2]["conclusion"]).is_success - expected_prod_plan = """**Summary of differences against `prod`:** - - -**Directly Modified:** -- `sushi.waiter_revenue_by_day` -```diff ---- - -+++ - -@@ -15,7 +15,8 @@ - - SELECT - CAST(o.waiter_id AS INT) AS waiter_id, - CAST(SUM(oi.quantity * i.price) AS DOUBLE) AS revenue, -- CAST(o.event_date AS DATE) AS event_date -+ CAST(o.event_date AS DATE) AS event_date, -+ 1 AS new_col - FROM sushi.orders AS o - LEFT JOIN sushi.order_items AS oi - ON o.id = oi.order_id AND o.event_date = oi.event_date -``` - -**Indirectly Modified:** -- `sushi.top_waiters` - - -**Models needing backfill (missing dates):** -* `sushi.waiter_revenue_by_day`: 2022-12-25 - 2022-12-29 + expected_prod_plan_directly_modified_summary = """**Directly Modified:** +* `memory.sushi.waiter_revenue_by_day` (Non-breaking) + + ```diff + --- + + +++ + + @@ -17,7 +17,8 @@ + + SELECT + CAST(o.waiter_id AS INT) AS waiter_id, + CAST(SUM(oi.quantity * i.price) AS DOUBLE) AS revenue, + - CAST(o.event_date AS DATE) AS event_date + + CAST(o.event_date AS DATE) AS event_date, + + 1 AS new_col + FROM sushi.orders AS o + LEFT JOIN sushi.order_items AS oi + ON o.id = oi.order_id AND o.event_date = oi.event_date + ``` + Indirectly Modified Children: + - `memory.sushi.top_waiters` (Indirect Non-breaking) +""" + expected_prod_plan_indirectly_modified_summary = """**Indirectly Modified:** +- `memory.sushi.top_waiters` (Indirect Non-breaking) """ assert 
prod_plan_preview_checks_runs[2]["output"]["title"] == "Prod Plan Preview" - assert prod_plan_preview_checks_runs[2]["output"]["summary"] == expected_prod_plan + prod_plan_preview_summary = prod_plan_preview_checks_runs[2]["output"]["summary"] + assert ( + "This is a preview that shows the differences between this PR environment `hello_world_2` and `prod`" + in prod_plan_preview_summary + ) + assert expected_prod_plan_directly_modified_summary in prod_plan_preview_summary + assert expected_prod_plan_indirectly_modified_summary in prod_plan_preview_summary assert "SQLMesh - Prod Environment Synced" not in controller._check_run_mapping assert "SQLMesh - Has Required Approval" not in controller._check_run_mapping @@ -1096,7 +1309,7 @@ def test_no_merge_since_no_deploy_signal_no_approvers_defined( ) -@freeze_time("2023-01-01 15:00:00") +@time_machine.travel("2023-01-01 15:00:00 UTC") def test_deploy_comment_pre_categorized( github_client, make_controller, @@ -1155,8 +1368,10 @@ def test_deploy_comment_pre_categorized( controller._context.users = [User(username="test", github_username="test_github", roles=[])] # Make a non-breaking change model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) # Manually categorize the change as non-breaking and don't backfill anything controller._context.plan( @@ -1193,9 +1408,19 @@ def test_deploy_comment_pre_categorized( assert GithubCheckStatus(pr_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(pr_checks_runs[2]["conclusion"]).is_success assert pr_checks_runs[2]["output"]["title"] == "PR Virtual Data Environment: hello_world_2" + pr_env_summary = pr_checks_runs[2]["output"]["summary"] + assert ( + """### Directly Modified +- 
`memory.sushi.waiter_revenue_by_day` (Non-breaking) + **Kind:** INCREMENTAL_BY_TIME_RANGE + **Dates loaded in PR:** [2022-12-25 - 2022-12-31]""" + in pr_env_summary + ) assert ( - pr_checks_runs[2]["output"]["summary"] - == """
PR Environment Summary
ModelChange TypeDates Loaded
sushi.waiter_revenue_by_dayNon-breaking2022-12-25 - 2022-12-31
""" + """### Indirectly Modified +- `memory.sushi.top_waiters` (Indirect Non-breaking) + **Kind:** VIEW [recreate view]""" + in pr_env_summary ) assert "SQLMesh - Prod Plan Preview" in controller._check_run_mapping @@ -1207,35 +1432,40 @@ def test_deploy_comment_pre_categorized( assert GithubCheckStatus(prod_plan_preview_checks_runs[1]["status"]).is_in_progress assert GithubCheckStatus(prod_plan_preview_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(prod_plan_preview_checks_runs[2]["conclusion"]).is_success - expected_prod_plan = """**Summary of differences against `prod`:** - - -**Directly Modified:** -- `sushi.waiter_revenue_by_day` -```diff ---- - -+++ - -@@ -15,7 +15,8 @@ - - SELECT - CAST(o.waiter_id AS INT) AS waiter_id, - CAST(SUM(oi.quantity * i.price) AS DOUBLE) AS revenue, -- CAST(o.event_date AS DATE) AS event_date -+ CAST(o.event_date AS DATE) AS event_date, -+ 1 AS new_col - FROM sushi.orders AS o - LEFT JOIN sushi.order_items AS oi - ON o.id = oi.order_id AND o.event_date = oi.event_date -``` - -**Indirectly Modified:** -- `sushi.top_waiters` - + expected_prod_plan_directly_modified_summary = """**Directly Modified:** +* `memory.sushi.waiter_revenue_by_day` (Non-breaking) + + ```diff + --- + + +++ + + @@ -17,7 +17,8 @@ + + SELECT + CAST(o.waiter_id AS INT) AS waiter_id, + CAST(SUM(oi.quantity * i.price) AS DOUBLE) AS revenue, + - CAST(o.event_date AS DATE) AS event_date + + CAST(o.event_date AS DATE) AS event_date, + + 1 AS new_col + FROM sushi.orders AS o + LEFT JOIN sushi.order_items AS oi + ON o.id = oi.order_id AND o.event_date = oi.event_date + ``` + Indirectly Modified Children: + - `memory.sushi.top_waiters` (Indirect Non-breaking) +""" + expected_prod_plan_indirectly_modified_summary = """**Indirectly Modified:** +- `memory.sushi.top_waiters` (Indirect Non-breaking) """ assert prod_plan_preview_checks_runs[2]["output"]["title"] == "Prod Plan Preview" - assert prod_plan_preview_checks_runs[2]["output"]["summary"] == 
expected_prod_plan + prod_plan_preview_summary = prod_plan_preview_checks_runs[2]["output"]["summary"] + assert ( + "This is a preview that shows the differences between this PR environment `hello_world_2` and `prod`" + in prod_plan_preview_summary + ) + assert expected_prod_plan_directly_modified_summary in prod_plan_preview_summary + assert expected_prod_plan_indirectly_modified_summary in prod_plan_preview_summary assert "SQLMesh - Prod Environment Synced" in controller._check_run_mapping prod_checks_runs = controller._check_run_mapping["SQLMesh - Prod Environment Synced"].all_kwargs @@ -1245,9 +1475,10 @@ def test_deploy_comment_pre_categorized( assert GithubCheckStatus(prod_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(prod_checks_runs[2]["conclusion"]).is_success assert prod_checks_runs[2]["output"]["title"] == "Deployed to Prod" - assert ( - prod_checks_runs[2]["output"]["summary"] == "**Generated Prod Plan**\n" + expected_prod_plan - ) + prod_environment_synced_summary = prod_checks_runs[2]["output"]["summary"] + assert "**Generated Prod Plan**" in prod_environment_synced_summary + assert expected_prod_plan_directly_modified_summary in prod_environment_synced_summary + assert expected_prod_plan_indirectly_modified_summary in prod_environment_synced_summary assert "SQLMesh - Has Required Approval" not in controller._check_run_mapping @@ -1259,21 +1490,19 @@ def test_deploy_comment_pre_categorized( assert mock_pull_request.merge.called assert len(created_comments) == 1 + comment_body = created_comments[0].body assert ( - created_comments[0].body - == f""":robot: **SQLMesh Bot Info** :robot: + """:robot: **SQLMesh Bot Info** :robot: - :eyes: To **review** this PR's changes, use virtual data environment: - `hello_world_2` - :arrow_forward: To **apply** this PR's plan to prod, comment: - `/deploy`
- :ship: Prod Plan Being Applied - -{expected_prod_plan} -
- -""" + :ship: Prod Plan Being Applied""" + in comment_body ) + assert expected_prod_plan_directly_modified_summary in comment_body + assert expected_prod_plan_indirectly_modified_summary in comment_body with open(github_output_file, "r", encoding="utf-8") as f: output = f.read() @@ -1283,7 +1512,7 @@ def test_deploy_comment_pre_categorized( ) -@freeze_time("2023-01-01 15:00:00") +@time_machine.travel("2023-01-01 15:00:00 UTC") def test_error_msg_when_applying_plan_with_bug( github_client, make_controller, @@ -1341,13 +1570,18 @@ def test_error_msg_when_applying_plan_with_bug( ] # Make an error by adding a column that doesn't exist model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() - model.query.select(exp.alias_("non_existing_col", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + query_=ParsableSql( + sql=model.query.select(exp.alias_("non_existing_col", "new_col")).sql(model.dialect) + ), + ) github_output_file = tmp_path / "github_output.txt" with mock.patch.dict(os.environ, {"GITHUB_OUTPUT": str(github_output_file)}): - command._run_all(controller) + with pytest.raises(CICDBotError): + command._run_all(controller) assert "SQLMesh - Run Unit Tests" in controller._check_run_mapping test_checks_runs = controller._check_run_mapping["SQLMesh - Run Unit Tests"].all_kwargs @@ -1370,10 +1604,10 @@ def test_error_msg_when_applying_plan_with_bug( assert GithubCheckStatus(pr_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(pr_checks_runs[2]["conclusion"]).is_failure assert pr_checks_runs[2]["output"]["title"] == "PR Virtual Data Environment: hello_world_2" - assert ( - 'Binder Error: Referenced column "non_existing_col" not found in FROM clause!' 
- in pr_checks_runs[2]["output"]["summary"] - ) + summary = pr_checks_runs[2]["output"]["summary"].replace("\n", "") + assert '**Skipped models*** `"memory"."sushi"."top_waiters"`' in summary + assert '**Failed models*** `"memory"."sushi"."waiter_revenue_by_day"`' in summary + assert 'Binder Error: Referenced column "non_existing_col" not found in FROM clause!' in summary assert "SQLMesh - Prod Plan Preview" in controller._check_run_mapping prod_plan_preview_checks_runs = controller._check_run_mapping[ @@ -1398,9 +1632,11 @@ def test_error_msg_when_applying_plan_with_bug( assert GithubCheckStatus(prod_checks_runs[0]["status"]).is_queued assert GithubCheckStatus(prod_checks_runs[1]["status"]).is_completed assert GithubCheckConclusion(prod_checks_runs[1]["conclusion"]).is_skipped - skip_reason = "Skipped Deploying to Production because the PR environment was not updated" - assert prod_checks_runs[1]["output"]["title"] == skip_reason - assert prod_checks_runs[1]["output"]["summary"] == skip_reason + assert prod_checks_runs[1]["output"]["title"] == "Skipped deployment" + assert ( + prod_checks_runs[1]["output"]["summary"] + == "Skipped Deploying to Production because the PR environment was not updated" + ) assert "SQLMesh - Has Required Approval" in controller._check_run_mapping approval_checks_runs = controller._check_run_mapping[ @@ -1435,7 +1671,7 @@ def test_error_msg_when_applying_plan_with_bug( ) -@freeze_time("2023-01-01 15:00:00") +@time_machine.travel("2023-01-01 15:00:00 UTC") def test_overlapping_changes_models( github_client, make_controller, @@ -1497,8 +1733,10 @@ def test_overlapping_changes_models( # These changes have shared children and this ensures we don't repeat the children in the output # Make a non-breaking change model = controller._context.get_model("sushi.customers").copy() - model.query.select(exp.alias_("1", "new_col"), copy=False) - controller._context.upsert_model(model) + controller._context.upsert_model( + model, + 
query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) # Make a breaking change model = controller._context.get_model("sushi.waiter_names").copy() @@ -1531,9 +1769,31 @@ def test_overlapping_changes_models( assert GithubCheckStatus(pr_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(pr_checks_runs[2]["conclusion"]).is_success assert pr_checks_runs[2]["output"]["title"] == "PR Virtual Data Environment: hello_world_2" + pr_env_summary = pr_checks_runs[2]["output"]["summary"] + assert ( + """### Directly Modified +- `memory.sushi.customers` (Non-breaking) + **Kind:** FULL [full refresh] + +- `memory.sushi.waiter_names` (Breaking) + **Kind:** SEED [full refresh]""" + in pr_env_summary + ) assert ( - pr_checks_runs[2]["output"]["summary"] - == """
PR Environment Summary
ModelChange TypeDates Loaded
sushi.customersNon-breaking2022-12-25 - 2022-12-31
sushi.waiter_namesBreaking2022-12-31 - 2022-12-31
sushi.waiter_as_customer_by_dayIndirect Breaking2022-12-25 - 2022-12-31
""" + """### Indirectly Modified +- `memory.sushi.active_customers` (Indirect Non-breaking) + **Kind:** CUSTOM [full refresh] + +- `memory.sushi.count_customers_active` (Indirect Non-breaking) + **Kind:** FULL [full refresh] + +- `memory.sushi.count_customers_inactive` (Indirect Non-breaking) + **Kind:** FULL [full refresh] + +- `memory.sushi.waiter_as_customer_by_day` (Indirect Breaking) + **Kind:** INCREMENTAL_BY_TIME_RANGE + **Dates loaded in PR:** [2022-12-25 - 2022-12-31]""" + in pr_env_summary ) assert "SQLMesh - Prod Plan Preview" in controller._check_run_mapping @@ -1545,40 +1805,54 @@ def test_overlapping_changes_models( assert GithubCheckStatus(prod_plan_preview_checks_runs[1]["status"]).is_in_progress assert GithubCheckStatus(prod_plan_preview_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(prod_plan_preview_checks_runs[2]["conclusion"]).is_success - expected_prod_plan_summary = """**Summary of differences against `prod`:** - - -**Directly Modified:** -- `sushi.customers` -```diff ---- - -+++ -@@ -25,7 +25,8 @@ + expected_prod_plan_directly_modified_summary = """**Directly Modified:** +* `memory.sushi.customers` (Non-breaking) + + ```diff + --- + + +++ + + @@ -32,7 +32,8 @@ + + SELECT DISTINCT + CAST(o.customer_id AS INT) AS customer_id, + m.status, + - d.zip + + d.zip, + + 1 AS new_col + FROM sushi.orders AS o + LEFT JOIN ( + WITH current_marketing AS ( + ``` + Indirectly Modified Children: + - `memory.sushi.active_customers` (Indirect Non-breaking) + - `memory.sushi.count_customers_active` (Indirect Non-breaking) + - `memory.sushi.count_customers_inactive` (Indirect Non-breaking) + - `memory.sushi.waiter_as_customer_by_day` (Indirect Breaking) + + +* `memory.sushi.waiter_names` (Breaking) + + + Indirectly Modified Children: + - `memory.sushi.waiter_as_customer_by_day` (Indirect Breaking)""" + + expected_prod_plan_indirectly_modified_summary = """**Indirectly Modified:** +- `memory.sushi.active_customers` (Indirect Non-breaking) +- 
`memory.sushi.count_customers_active` (Indirect Non-breaking) +- `memory.sushi.count_customers_inactive` (Indirect Non-breaking) +- `memory.sushi.waiter_as_customer_by_day` (Indirect Breaking)""" - SELECT DISTINCT - CAST(o.customer_id AS INT) AS customer_id, - m.status, -- d.zip -+ d.zip, -+ 1 AS new_col - FROM sushi.orders AS o - LEFT JOIN current_marketing AS m - ON o.customer_id = m.customer_id -``` -- `sushi.waiter_names` -```diff - -``` - -**Indirectly Modified:** -- `sushi.active_customers` -- `sushi.waiter_as_customer_by_day` - -""" assert prod_plan_preview_checks_runs[2]["output"]["title"] == "Prod Plan Preview" - assert prod_plan_preview_checks_runs[2]["output"]["summary"] == expected_prod_plan_summary + prod_plan_preview_summary = prod_plan_preview_checks_runs[2]["output"]["summary"] + assert ( + "This is a preview that shows the differences between this PR environment `hello_world_2` and `prod`" + in prod_plan_preview_summary + ) + assert expected_prod_plan_directly_modified_summary in prod_plan_preview_summary + assert expected_prod_plan_indirectly_modified_summary in prod_plan_preview_summary assert "SQLMesh - Prod Environment Synced" in controller._check_run_mapping prod_checks_runs = controller._check_run_mapping["SQLMesh - Prod Environment Synced"].all_kwargs @@ -1588,10 +1862,10 @@ def test_overlapping_changes_models( assert GithubCheckStatus(prod_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(prod_checks_runs[2]["conclusion"]).is_success assert prod_checks_runs[2]["output"]["title"] == "Deployed to Prod" - assert ( - prod_checks_runs[2]["output"]["summary"] - == "**Generated Prod Plan**\n" + expected_prod_plan_summary - ) + prod_environment_synced_summary = prod_checks_runs[2]["output"]["summary"] + assert "**Generated Prod Plan**" in prod_environment_synced_summary + assert expected_prod_plan_directly_modified_summary in prod_environment_synced_summary + assert expected_prod_plan_indirectly_modified_summary in 
prod_environment_synced_summary assert "SQLMesh - Has Required Approval" in controller._check_run_mapping approval_checks_runs = controller._check_run_mapping[ @@ -1613,25 +1887,23 @@ def test_overlapping_changes_models( """ ) - assert len(get_environment_objects(controller, "hello_world_2")) == 4 + assert len(get_environment_objects(controller, "hello_world_2")) == 6 assert "new_col" in get_columns(controller, "hello_world_2", "customers") assert mock_pull_request.merge.called assert len(created_comments) == 1 + comment_body = created_comments[0].body assert ( - created_comments[0].body - == f""":robot: **SQLMesh Bot Info** :robot: + f""":robot: **SQLMesh Bot Info** :robot: - :eyes: To **review** this PR's changes, use virtual data environment: - `hello_world_2`
- :ship: Prod Plan Being Applied - -{expected_prod_plan_summary} -
- -""" + :ship: Prod Plan Being Applied""" + in comment_body ) + assert expected_prod_plan_directly_modified_summary in comment_body + assert expected_prod_plan_indirectly_modified_summary in comment_body with open(github_output_file, "r", encoding="utf-8") as f: output = f.read() @@ -1641,8 +1913,8 @@ def test_overlapping_changes_models( ) -@freeze_time("2023-01-01 15:00:00") -def test_pr_delete_model( +@time_machine.travel("2023-01-01 15:00:00 UTC") +def test_pr_add_model( github_client, make_controller, make_mock_check_run, @@ -1652,16 +1924,16 @@ def test_pr_delete_model( mocker: MockerFixture, ): """ - PR with a removed model and auto-categorization will be backfilled, merged, and deployed to prod + PR with an added model and auto-categorization will be backfilled, merged, and deployed to prod Scenario: - PR is not merged - - PR has been approved by a required reviewer + - /deploy command has been issued - Tests passed - PR Merge Method defined - - Delete environment is disabled - Changes made in PR with auto-categorization """ + mock_repo = github_client.get_repo() mock_repo.create_check_run = mocker.MagicMock( side_effect=lambda **kwargs: make_mock_check_run(**kwargs) @@ -1684,32 +1956,34 @@ def test_pr_delete_model( mock_pull_request.merge = mocker.MagicMock() controller = make_controller( - "tests/fixtures/github/pull_request_synchronized.json", + "tests/fixtures/github/pull_request_command_deploy.json", github_client, bot_config=GithubCICDBotConfig( merge_method=MergeMethod.MERGE, - invalidate_environment_after_deploy=False, auto_categorize_changes=CategorizerConfig.all_full(), + enable_deploy_command=True, default_pr_start=None, skip_pr_backfill=False, ), mock_out_context=False, ) controller._context.plan("prod", no_prompts=True, auto_apply=True) - controller._context.users = [ - User(username="test", github_username="test_github", roles=[UserRole.REQUIRED_APPROVER]) - ] - # Remove a model - model = 
controller._context.get_model("sushi.top_waiters").copy() - del controller._context._models[model.fqn] - controller._context.dag = controller._context.dag.prune(*controller._context._models.keys()) - github_output_file = tmp_path / "github_output.txt" - - mocker.patch( - "sqlmesh.core.engine_adapter.base.EngineAdapter.insert_overwrite_by_time_partition", - side_effect=Exception("Test Exception"), + # Add a model + (controller._context.path / "models" / "cicd_test_model.sql").write_text( + """ + MODEL ( + name sushi.cicd_test_model, + kind FULL + ); + + select 1; + """ ) + controller._context.load() + assert '"memory"."sushi"."cicd_test_model"' in controller._context.models + + github_output_file = tmp_path / "github_output.txt" with mock.patch.dict(os.environ, {"GITHUB_OUTPUT": str(github_output_file)}): command._run_all(controller) @@ -1722,6 +1996,7 @@ def test_pr_delete_model( assert GithubCheckStatus(test_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(test_checks_runs[2]["conclusion"]).is_success assert test_checks_runs[2]["output"]["title"] == "Tests Passed" + print(test_checks_runs[2]["output"]["summary"]) assert ( test_checks_runs[2]["output"]["summary"].strip() == "**Successfully Ran `3` Tests Against `duckdb`**" @@ -1735,18 +2010,16 @@ def test_pr_delete_model( assert GithubCheckStatus(pr_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(pr_checks_runs[2]["conclusion"]).is_success assert pr_checks_runs[2]["output"]["title"] == "PR Virtual Data Environment: hello_world_2" + pr_env_summary = pr_checks_runs[2]["output"]["summary"] assert ( - pr_checks_runs[2]["output"]["summary"] - == """
PR Environment Summary
ModelChange TypeDates Loaded
"memory"."sushi"."top_waiters"BreakingREMOVED
""" + """### Added +- `memory.sushi.cicd_test_model` (Breaking) + **Kind:** FULL [full refresh]""" + in pr_env_summary ) - expected_prod_plan_summary = """**Summary of differences against `prod`:** - - -**Removed Models:** -- `sushi.top_waiters` - -""" + expected_prod_plan_summary = """**Added Models:** +- `memory.sushi.cicd_test_model` (Breaking)""" assert "SQLMesh - Prod Plan Preview" in controller._check_run_mapping prod_plan_preview_checks_runs = controller._check_run_mapping[ @@ -1758,7 +2031,7 @@ def test_pr_delete_model( assert GithubCheckStatus(prod_plan_preview_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(prod_plan_preview_checks_runs[2]["conclusion"]).is_success assert prod_plan_preview_checks_runs[2]["output"]["title"] == "Prod Plan Preview" - assert prod_plan_preview_checks_runs[2]["output"]["summary"] == expected_prod_plan_summary + assert expected_prod_plan_summary in prod_plan_preview_checks_runs[2]["output"]["summary"] assert "SQLMesh - Prod Environment Synced" in controller._check_run_mapping prod_checks_runs = controller._check_run_mapping["SQLMesh - Prod Environment Synced"].all_kwargs @@ -1768,10 +2041,160 @@ def test_pr_delete_model( assert GithubCheckStatus(prod_checks_runs[2]["status"]).is_completed assert GithubCheckConclusion(prod_checks_runs[2]["conclusion"]).is_success assert prod_checks_runs[2]["output"]["title"] == "Deployed to Prod" + prod_environment_synced_summary = prod_checks_runs[2]["output"]["summary"] + assert "**Generated Prod Plan**" in prod_environment_synced_summary + assert expected_prod_plan_summary in prod_environment_synced_summary + + assert mock_pull_request.merge.called + + assert len(created_comments) == 1 + comment_body = created_comments[0].body + assert ( + """:robot: **SQLMesh Bot Info** :robot: +- :eyes: To **review** this PR's changes, use virtual data environment: + - `hello_world_2` +- :arrow_forward: To **apply** this PR's plan to prod, comment: + - `/deploy` +
+ :ship: Prod Plan Being Applied""" + in comment_body + ) + assert expected_prod_plan_summary in comment_body + assert ( - prod_checks_runs[2]["output"]["summary"] - == "**Generated Prod Plan**\n" + expected_prod_plan_summary + github_output_file.read_text() + == "run_unit_tests=success\ncreated_pr_environment=true\npr_environment_name=hello_world_2\npr_environment_synced=success\nprod_plan_preview=success\nprod_environment_synced=success\n" + ) + + +@time_machine.travel("2023-01-01 15:00:00 UTC") +def test_pr_delete_model( + github_client, + make_controller, + make_mock_check_run, + make_mock_issue_comment, + make_pull_request_review, + tmp_path: pathlib.Path, + mocker: MockerFixture, +): + """ + PR with a removed model and auto-categorization will be backfilled, merged, and deployed to prod + + Scenario: + - PR is not merged + - PR has been approved by a required reviewer + - Tests passed + - PR Merge Method defined + - Delete environment is disabled + - Changes made in PR with auto-categorization + """ + mock_repo = github_client.get_repo() + mock_repo.create_check_run = mocker.MagicMock( + side_effect=lambda **kwargs: make_mock_check_run(**kwargs) + ) + + created_comments: t.List[MockIssueComment] = [] + mock_issue = mock_repo.get_issue() + mock_issue.create_comment = mocker.MagicMock( + side_effect=lambda comment: make_mock_issue_comment( + comment=comment, created_comments=created_comments + ) + ) + mock_issue.get_comments = mocker.MagicMock(side_effect=lambda: created_comments) + + mock_pull_request = mock_repo.get_pull() + mock_pull_request.get_reviews = mocker.MagicMock( + side_effect=lambda: [make_pull_request_review(username="test_github", state="APPROVED")] + ) + mock_pull_request.merged = False + mock_pull_request.merge = mocker.MagicMock() + + controller = make_controller( + "tests/fixtures/github/pull_request_synchronized.json", + github_client, + bot_config=GithubCICDBotConfig( + merge_method=MergeMethod.MERGE, + 
invalidate_environment_after_deploy=False, + auto_categorize_changes=CategorizerConfig.all_full(), + default_pr_start=None, + skip_pr_backfill=False, + ), + mock_out_context=False, ) + controller._context.plan("prod", no_prompts=True, auto_apply=True) + controller._context.users = [ + User(username="test", github_username="test_github", roles=[UserRole.REQUIRED_APPROVER]) + ] + # Remove a model + model = controller._context.get_model("sushi.top_waiters").copy() + del controller._context._models[model.fqn] + controller._context.dag = controller._context.dag.prune(*controller._context._models.keys()) + + github_output_file = tmp_path / "github_output.txt" + + mocker.patch( + "sqlmesh.core.engine_adapter.base.EngineAdapter.insert_overwrite_by_time_partition", + side_effect=Exception("Test Exception"), + ) + + with mock.patch.dict(os.environ, {"GITHUB_OUTPUT": str(github_output_file)}): + command._run_all(controller) + + assert "SQLMesh - Run Unit Tests" in controller._check_run_mapping + test_checks_runs = controller._check_run_mapping["SQLMesh - Run Unit Tests"].all_kwargs + assert len(test_checks_runs) == 3 + assert GithubCheckStatus(test_checks_runs[0]["status"]).is_queued + assert GithubCheckStatus(test_checks_runs[1]["status"]).is_in_progress + assert GithubCheckStatus(test_checks_runs[2]["status"]).is_completed + assert GithubCheckConclusion(test_checks_runs[2]["conclusion"]).is_success + assert test_checks_runs[2]["output"]["title"] == "Tests Passed" + print(test_checks_runs[2]["output"]["summary"]) + assert ( + test_checks_runs[2]["output"]["summary"].strip() + == "**Successfully Ran `3` Tests Against `duckdb`**" + ) + + assert "SQLMesh - PR Environment Synced" in controller._check_run_mapping + pr_checks_runs = controller._check_run_mapping["SQLMesh - PR Environment Synced"].all_kwargs + assert len(pr_checks_runs) == 3 + assert GithubCheckStatus(pr_checks_runs[0]["status"]).is_queued + assert GithubCheckStatus(pr_checks_runs[1]["status"]).is_in_progress + 
assert GithubCheckStatus(pr_checks_runs[2]["status"]).is_completed + assert GithubCheckConclusion(pr_checks_runs[2]["conclusion"]).is_success + assert pr_checks_runs[2]["output"]["title"] == "PR Virtual Data Environment: hello_world_2" + pr_env_summary = pr_checks_runs[2]["output"]["summary"] + assert ( + """### Removed +- `memory.sushi.top_waiters` (Breaking)""" + in pr_env_summary + ) + + expected_prod_plan_summary = """**Removed Models:** +- `memory.sushi.top_waiters` (Breaking)""" + + assert "SQLMesh - Prod Plan Preview" in controller._check_run_mapping + prod_plan_preview_checks_runs = controller._check_run_mapping[ + "SQLMesh - Prod Plan Preview" + ].all_kwargs + assert len(prod_plan_preview_checks_runs) == 3 + assert GithubCheckStatus(prod_plan_preview_checks_runs[0]["status"]).is_queued + assert GithubCheckStatus(prod_plan_preview_checks_runs[1]["status"]).is_in_progress + assert GithubCheckStatus(prod_plan_preview_checks_runs[2]["status"]).is_completed + assert GithubCheckConclusion(prod_plan_preview_checks_runs[2]["conclusion"]).is_success + assert prod_plan_preview_checks_runs[2]["output"]["title"] == "Prod Plan Preview" + assert expected_prod_plan_summary in prod_plan_preview_checks_runs[2]["output"]["summary"] + + assert "SQLMesh - Prod Environment Synced" in controller._check_run_mapping + prod_checks_runs = controller._check_run_mapping["SQLMesh - Prod Environment Synced"].all_kwargs + assert len(prod_checks_runs) == 3 + assert GithubCheckStatus(prod_checks_runs[0]["status"]).is_queued + assert GithubCheckStatus(prod_checks_runs[1]["status"]).is_in_progress + assert GithubCheckStatus(prod_checks_runs[2]["status"]).is_completed + assert GithubCheckConclusion(prod_checks_runs[2]["conclusion"]).is_success + assert prod_checks_runs[2]["output"]["title"] == "Deployed to Prod" + prod_environment_synced_summary = prod_checks_runs[2]["output"]["summary"] + assert "**Generated Prod Plan**" in prod_environment_synced_summary + assert expected_prod_plan_summary 
in prod_environment_synced_summary assert "SQLMesh - Has Required Approval" in controller._check_run_mapping approval_checks_runs = controller._check_run_mapping[ @@ -1798,23 +2221,270 @@ def test_pr_delete_model( assert mock_pull_request.merge.called assert len(created_comments) == 1 + comment_body = created_comments[0].body assert ( - created_comments[0].body - == f""":robot: **SQLMesh Bot Info** :robot: + """:robot: **SQLMesh Bot Info** :robot: - :eyes: To **review** this PR's changes, use virtual data environment: - `hello_world_2`
- :ship: Prod Plan Being Applied + :ship: Prod Plan Being Applied""" + in comment_body + ) + assert expected_prod_plan_summary in comment_body + + with open(github_output_file, "r", encoding="utf-8") as f: + output = f.read() + assert ( + output + == "run_unit_tests=success\nhas_required_approval=success\ncreated_pr_environment=true\npr_environment_name=hello_world_2\npr_environment_synced=success\nprod_plan_preview=success\nprod_environment_synced=success\n" + ) -{expected_prod_plan_summary} -
+@time_machine.travel("2023-01-01 15:00:00 UTC") +def test_has_required_approval_but_not_base_branch( + github_client, + make_controller, + make_mock_check_run, + make_mock_issue_comment, + make_pull_request_review, + tmp_path: pathlib.Path, + mocker: MockerFixture, +): + """ + PR with a non-breaking change and auto-categorization will be backfilled, but NOT automatically merged or deployed to production if it is branched from a non-production branch. + + Scenario: + - PR is not merged + - PR has been approved by a required reviewer + - Tests passed + - PR Merge Method defined + - Delete environment is disabled + - Changes made in PR with auto-categorization + - PR is not merged, despite having required approval, since the base branch is not prod + """ + mock_repo = github_client.get_repo() + mock_repo.create_check_run = mocker.MagicMock( + side_effect=lambda **kwargs: make_mock_check_run(**kwargs) + ) + + created_comments: t.List[MockIssueComment] = [] + mock_issue = mock_repo.get_issue() + mock_issue.create_comment = mocker.MagicMock( + side_effect=lambda comment: make_mock_issue_comment( + comment=comment, created_comments=created_comments + ) + ) + mock_issue.get_comments = mocker.MagicMock(side_effect=lambda: created_comments) + + mock_pull_request = mock_repo.get_pull() + mock_pull_request.base.ref = "feature/branch" + mock_pull_request.get_reviews = mocker.MagicMock( + side_effect=lambda: [make_pull_request_review(username="test_github", state="APPROVED")] + ) + mock_pull_request.merged = False + mock_pull_request.merge = mocker.MagicMock() + + controller = make_controller( + "tests/fixtures/github/pull_request_synchronized.json", + github_client, + bot_config=GithubCICDBotConfig( + merge_method=MergeMethod.MERGE, + invalidate_environment_after_deploy=False, + auto_categorize_changes=CategorizerConfig.all_full(), + default_pr_start=None, + skip_pr_backfill=False, + ), + mock_out_context=False, + ) + controller._context.plan("prod", no_prompts=True, 
auto_apply=True) + controller._context.users = [ + User(username="test", github_username="test_github", roles=[UserRole.REQUIRED_APPROVER]) + ] + # Make a non-breaking change + model = controller._context.get_model("sushi.waiter_revenue_by_day").copy() + controller._context.upsert_model( + model, + query_=ParsableSql(sql=model.query.select(exp.alias_("1", "new_col")).sql(model.dialect)), + ) + + github_output_file = tmp_path / "github_output.txt" + + with mock.patch.dict(os.environ, {"GITHUB_OUTPUT": str(github_output_file)}): + command._run_all(controller) + + assert "SQLMesh - Run Unit Tests" in controller._check_run_mapping + test_checks_runs = controller._check_run_mapping["SQLMesh - Run Unit Tests"].all_kwargs + assert len(test_checks_runs) == 3 + assert GithubCheckStatus(test_checks_runs[0]["status"]).is_queued + assert GithubCheckStatus(test_checks_runs[1]["status"]).is_in_progress + assert GithubCheckStatus(test_checks_runs[2]["status"]).is_completed + assert GithubCheckConclusion(test_checks_runs[2]["conclusion"]).is_success + assert test_checks_runs[2]["output"]["title"] == "Tests Passed" + assert ( + test_checks_runs[2]["output"]["summary"].strip() + == "**Successfully Ran `3` Tests Against `duckdb`**" + ) + + assert "SQLMesh - PR Environment Synced" in controller._check_run_mapping + pr_checks_runs = controller._check_run_mapping["SQLMesh - PR Environment Synced"].all_kwargs + assert len(pr_checks_runs) == 3 + assert GithubCheckStatus(pr_checks_runs[0]["status"]).is_queued + assert GithubCheckStatus(pr_checks_runs[1]["status"]).is_in_progress + assert GithubCheckStatus(pr_checks_runs[2]["status"]).is_completed + assert GithubCheckConclusion(pr_checks_runs[2]["conclusion"]).is_success + assert pr_checks_runs[2]["output"]["title"] == "PR Virtual Data Environment: hello_world_2" + pr_env_summary = pr_checks_runs[2]["output"]["summary"] + assert ( + """### Directly Modified +- `memory.sushi.waiter_revenue_by_day` (Non-breaking) + **Kind:** 
INCREMENTAL_BY_TIME_RANGE + **Dates loaded in PR:** [2022-12-25 - 2022-12-31]""" + in pr_env_summary + ) + assert ( + """### Indirectly Modified +- `memory.sushi.top_waiters` (Indirect Non-breaking) + **Kind:** VIEW [recreate view]""" + in pr_env_summary + ) + + assert "SQLMesh - Prod Plan Preview" in controller._check_run_mapping + prod_plan_preview_checks_runs = controller._check_run_mapping[ + "SQLMesh - Prod Plan Preview" + ].all_kwargs + assert len(prod_plan_preview_checks_runs) == 3 + assert GithubCheckStatus(prod_plan_preview_checks_runs[0]["status"]).is_queued + assert GithubCheckStatus(prod_plan_preview_checks_runs[1]["status"]).is_in_progress + assert GithubCheckStatus(prod_plan_preview_checks_runs[2]["status"]).is_completed + assert GithubCheckConclusion(prod_plan_preview_checks_runs[2]["conclusion"]).is_success + expected_prod_plan_directly_modified_summary = """**Directly Modified:** +* `memory.sushi.waiter_revenue_by_day` (Non-breaking) + + ```diff + --- + + +++ + + @@ -17,7 +17,8 @@ + + SELECT + CAST(o.waiter_id AS INT) AS waiter_id, + CAST(SUM(oi.quantity * i.price) AS DOUBLE) AS revenue, + - CAST(o.event_date AS DATE) AS event_date + + CAST(o.event_date AS DATE) AS event_date, + + 1 AS new_col + FROM sushi.orders AS o + LEFT JOIN sushi.order_items AS oi + ON o.id = oi.order_id AND o.event_date = oi.event_date + ``` + Indirectly Modified Children: + - `memory.sushi.top_waiters` (Indirect Non-breaking)""" + + expected_prod_plan_indirectly_modified_summary = """**Indirectly Modified:** +- `memory.sushi.top_waiters` (Indirect Non-breaking)""" + + assert prod_plan_preview_checks_runs[2]["output"]["title"] == "Prod Plan Preview" + prod_plan_preview_summary = prod_plan_preview_checks_runs[2]["output"]["summary"] + assert ( + "This is a preview that shows the differences between this PR environment `hello_world_2` and `prod`" + in prod_plan_preview_summary + ) + assert expected_prod_plan_directly_modified_summary in prod_plan_preview_summary + assert 
expected_prod_plan_indirectly_modified_summary in prod_plan_preview_summary + + assert "SQLMesh - Prod Environment Synced" not in controller._check_run_mapping + + assert "SQLMesh - Has Required Approval" in controller._check_run_mapping + approval_checks_runs = controller._check_run_mapping[ + "SQLMesh - Has Required Approval" + ].all_kwargs + assert len(approval_checks_runs) == 3 + assert GithubCheckStatus(approval_checks_runs[0]["status"]).is_queued + assert GithubCheckStatus(approval_checks_runs[1]["status"]).is_in_progress + assert GithubCheckStatus(approval_checks_runs[2]["status"]).is_completed + assert GithubCheckConclusion(approval_checks_runs[2]["conclusion"]).is_success + assert ( + approval_checks_runs[2]["output"]["title"] + == "Obtained approval from required approvers: test_github" + ) + assert ( + approval_checks_runs[2]["output"]["summary"] + == """**List of possible required approvers:** +- `test_github` """ ) + assert len(get_environment_objects(controller, "hello_world_2")) == 2 + assert get_num_days_loaded(controller, "hello_world_2", "waiter_revenue_by_day") == 7 + assert "new_col" in get_columns(controller, "hello_world_2", "waiter_revenue_by_day") + assert "new_col" not in get_columns(controller, None, "waiter_revenue_by_day") + + assert not mock_pull_request.merge.called + + assert len(created_comments) == 1 + assert ( + created_comments[0].body + == """:robot: **SQLMesh Bot Info** :robot: +- :eyes: To **review** this PR's changes, use virtual data environment: + - `hello_world_2`""" + ) + with open(github_output_file, "r", encoding="utf-8") as f: output = f.read() assert ( output - == "run_unit_tests=success\nhas_required_approval=success\ncreated_pr_environment=true\npr_environment_name=hello_world_2\npr_environment_synced=success\nprod_plan_preview=success\nprod_environment_synced=success\n" + == 
"run_unit_tests=success\nhas_required_approval=success\ncreated_pr_environment=true\npr_environment_name=hello_world_2\npr_environment_synced=success\nprod_plan_preview=success\n" ) + + +def test_unexpected_error_is_handled( + github_client, + make_controller, + make_mock_check_run, + make_mock_issue_comment, + mocker: MockerFixture, +): + """ + Scenario: + - Plan throws a SQLMeshError due to a migration version mismatch + - Outcome should be a nice error like the CLI gives and not a stack trace + """ + + mock_repo = github_client.get_repo() + mock_repo.create_check_run = mocker.MagicMock( + side_effect=lambda **kwargs: make_mock_check_run(**kwargs) + ) + + created_comments: t.List[MockIssueComment] = [] + mock_issue = mock_repo.get_issue() + mock_issue.create_comment = mocker.MagicMock( + side_effect=lambda comment: make_mock_issue_comment( + comment=comment, created_comments=created_comments + ) + ) + + controller = make_controller( + "tests/fixtures/github/pull_request_synchronized.json", + github_client, + bot_config=GithubCICDBotConfig(), + mock_out_context=True, + ) + assert isinstance(controller, GithubController) + + assert isinstance(controller._context.apply, mocker.MagicMock) + controller._context.apply.side_effect = SQLMeshError( + "SQLGlot (local) is using version 'X' which is ahead of 'Y' (remote). Please run a migration" + ) + + command._update_pr_environment(controller) + + assert "SQLMesh - PR Environment Synced" in controller._check_run_mapping + pr_checks_runs = controller._check_run_mapping["SQLMesh - PR Environment Synced"].all_kwargs # type: ignore + assert pr_checks_runs[1]["output"]["title"] == "PR Virtual Data Environment: hello_world_2" + summary = pr_checks_runs[1]["output"]["summary"] + assert ( + "**Error:** SQLGlot (local) is using version 'X' which is ahead of 'Y' (remote). 
Please run a migration" + in pr_checks_runs[1]["output"]["summary"] + ) + assert "SQLMeshError" not in summary + assert "Traceback (most recent call last)" not in summary diff --git a/tests/integrations/jupyter/example_outputs.ipynb b/tests/integrations/jupyter/example_outputs.ipynb index 59190d84e9..cfe3aa2457 100644 --- a/tests/integrations/jupyter/example_outputs.ipynb +++ b/tests/integrations/jupyter/example_outputs.ipynb @@ -12,7 +12,7 @@ "import shutil\n", "import pathlib\n", "\n", - "from freezegun import freeze_time\n", + "import time_machine\n", "\n", "# import to register magics\n", "import sqlmesh" @@ -25,7 +25,7 @@ "metadata": {}, "outputs": [], "source": [ - "freezer = freeze_time(\"2032-01-01 00:00:00\")\n", + "freezer = time_machine.travel(\"2032-01-01 00:00:00 UTC\")\n", "freezer.start()" ] }, @@ -51,7 +51,7 @@ " prev_parent_stem = parent.stem\n", "else:\n", " raise RuntimeError(\"Couldn't find root dir\")\n", - " \n", + "\n", "EXAMPLE_SUSHI_DIR = pathlib.Path(root_dir) / \"examples\" / \"sushi\"\n", "str(EXAMPLE_SUSHI_DIR)" ] @@ -336,7 +336,7 @@ "metadata": {}, "outputs": [], "source": [ - "freezer = freeze_time(\"2032-01-02 00:00:00\")\n", + "freezer = time_machine.travel(\"2032-01-02 00:00:00\")\n", "freezer.start()" ] }, diff --git a/tests/integrations/jupyter/test_magics.py b/tests/integrations/jupyter/test_magics.py index ee6b6c5519..991df8fc15 100644 --- a/tests/integrations/jupyter/test_magics.py +++ b/tests/integrations/jupyter/test_magics.py @@ -5,7 +5,7 @@ import pytest from bs4 import BeautifulSoup -from freezegun import freeze_time +import time_machine from hyperscript import h from IPython.core.error import UsageError from IPython.testing.globalipapp import start_ipython @@ -15,6 +15,7 @@ from sqlmesh import Context, RuntimeEnv from sqlmesh.magics import register_magics +from pathlib import Path logger = logging.getLogger(__name__) @@ -22,8 +23,11 @@ SUSHI_EXAMPLE_PATH = pathlib.Path("./examples/sushi") SUCCESS_STYLE = "color: #008000; 
text-decoration-color: #008000" NEUTRAL_STYLE = "color: #008080; text-decoration-color: #008080" +BOLD_ONLY = "font-weight: bold" +BOLD_NEUTRAL_STYLE = f"{NEUTRAL_STYLE}; {BOLD_ONLY}" +BOLD_SUCCESS_STYLE = f"{SUCCESS_STYLE}; {BOLD_ONLY}" RICH_PRE_STYLE = "white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace" -FREEZE_TIME = "2023-01-01 00:00:00" +FREEZE_TIME = "2023-01-01 00:00:00 UTC" pytestmark = pytest.mark.jupyter @@ -55,7 +59,7 @@ def sushi_context(copy_to_temp_path, notebook, tmp_path) -> Context: @pytest.fixture -@freeze_time(FREEZE_TIME) +@time_machine.travel(FREEZE_TIME) def loaded_sushi_context(sushi_context) -> Context: with capture_output(): sushi_context.plan(no_prompts=True, auto_apply=True) @@ -77,9 +81,9 @@ def convert_all_html_output_to_tags(): def _convert_html_to_tags(html: str) -> t.List[str]: # BS4 automatically adds html and body tags so we remove those since they are not actually part of the output return [ - tag.name + tag.name # type: ignore for tag in BeautifulSoup(html, "html").find_all() - if tag.name not in {"html", "body"} + if tag.name not in {"html", "body"} # type: ignore ] def _convert(output: CapturedIO) -> t.List[t.List[str]]: @@ -106,8 +110,9 @@ def test_context(notebook, convert_all_html_output_to_text, get_all_html_output, assert output.stdout == "" assert output.stderr == "" assert len(output.outputs) == 1 + sushi_path = str(Path("examples/sushi")) assert convert_all_html_output_to_text(output) == [ - "SQLMesh project context set to: examples/sushi" + f"SQLMesh project context set to: {sushi_path}" ] assert get_all_html_output(output) == [ str( @@ -117,7 +122,7 @@ def test_context(notebook, convert_all_html_output_to_text, get_all_html_output, h( "span", {"style": SUCCESS_STYLE}, - "SQLMesh project context set to: examples/sushi", + f"SQLMesh project context set to: {sushi_path}", autoescape=False, ), autoescape=False, @@ -129,7 +134,7 @@ def 
test_context(notebook, convert_all_html_output_to_text, get_all_html_output, def test_init(tmp_path, notebook, convert_all_html_output_to_text, get_all_html_output): with pytest.raises(UsageError, match="the following arguments are required: path"): notebook.run_line_magic(magic_name="init", line="") - with pytest.raises(UsageError, match="the following arguments are required: sql_dialect"): + with pytest.raises(UsageError, match="the following arguments are required: engine"): notebook.run_line_magic(magic_name="init", line="foo") with capture_output() as output: notebook.run_line_magic(magic_name="init", line=f"{tmp_path} duckdb") @@ -163,9 +168,9 @@ def test_render( assert output.stdout == "" assert output.stderr == "" - assert len(output.outputs) == 1 - assert len(convert_all_html_output_to_text(output)[0]) > 2200 - assert len(convert_all_html_output_to_tags(output)[0]) > 150 + assert len(output.outputs) == 2 + assert len(convert_all_html_output_to_text(output)[1]) > 2200 + assert len(convert_all_html_output_to_tags(output)[1]) > 150 @pytest.mark.slow @@ -177,13 +182,13 @@ def test_render_no_format( assert output.stdout == "" assert output.stderr == "" - assert len(output.outputs) == 1 - assert len(convert_all_html_output_to_text(output)[0]) >= 700 - assert len(convert_all_html_output_to_tags(output)[0]) >= 50 + assert len(output.outputs) == 2 + assert len(convert_all_html_output_to_text(output)[1]) >= 700 + assert len(convert_all_html_output_to_tags(output)[1]) >= 50 @pytest.mark.slow -@freeze_time(FREEZE_TIME) +@time_machine.travel(FREEZE_TIME) def test_evaluate(notebook, loaded_sushi_context): with capture_output() as output: notebook.run_line_magic(magic_name="evaluate", line="sushi.top_waiters") @@ -237,7 +242,7 @@ def test_diff(sushi_context, notebook, convert_all_html_output_to_text, get_all_ assert not output.stderr assert len(output.outputs) == 2 assert convert_all_html_output_to_text(output) == [ - "Summary of differences against `prod`:", + 
"Differences from the `prod` environment:", "Models:\n└── Directly Modified:\n └── sqlmesh_example.test", ] assert get_all_html_output(output) == [ @@ -248,7 +253,7 @@ def test_diff(sushi_context, notebook, convert_all_html_output_to_text, get_all_ h( "span", {"style": "font-weight: bold"}, - "Summary of differences against `prod`:", + "Differences from the `prod` environment:", autoescape=False, ), autoescape=False, @@ -289,33 +294,28 @@ def test_plan( with capture_output() as output: notebook.run_line_magic(magic_name="plan", line="--no-prompts --auto-apply") - # TODO: Should this be going to stdout? This is printing the status updates for when each batch finishes for - # the models and how long it took - assert len(output.stdout.strip().split("\n")) == 22 assert not output.stderr - assert len(output.outputs) == 5 + assert len(output.outputs) == 4 text_output = convert_all_html_output_to_text(output) # TODO: Is this what we expect? # This has minor differences between CI/CD and local. assert "[2K" in text_output[0] assert text_output[1].startswith( - "Virtually Updating 'prod' ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0%" + "Updating virtual layer ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0%" ) # TODO: Is this what we expect? 
assert text_output[2] == "" - assert text_output[3] == "" - assert text_output[4] == "The target environment has been updated successfully" + assert text_output[3] == "✔ Virtual layer updated" assert convert_all_html_output_to_tags(output) == [ ["pre", "span"], - ["pre"] + ["span"] * 4, - ["pre"], + ["pre"] + ["span"] * 5, ["pre"], ["pre", "span"], ] @pytest.mark.slow -@freeze_time("2023-01-03 00:00:00") +@time_machine.travel("2023-01-03 00:00:00 UTC") def test_run_dag( notebook, loaded_sushi_context, convert_all_html_output_to_text, get_all_html_output ): @@ -323,15 +323,220 @@ def test_run_dag( notebook.run_line_magic(magic_name="run_dag", line="") assert not output.stdout.startswith( - "'Evaluating models ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 18/18" + "'Executing model batches ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 100.0% • 18/18" ) assert not output.stderr - assert len(output.outputs) == 2 - assert convert_all_html_output_to_text(output) == [ - "All model batches have been executed successfully", - "Run finished for environment 'prod'", - ] - assert get_all_html_output(output) == [ + + # At least 4 outputs expected as the number of models in the particular batch might vary + assert len(output.outputs) >= 4 + + html_text_actual = convert_all_html_output_to_text(output) + + # Check for key elements in the output + assert any("[2K" in text for text in html_text_actual) + assert any("Executing model batches" in text for text in html_text_actual) + assert any("✔ Model batches executed" in text for text in html_text_actual) + assert any("Run finished for environment 'prod'" in text for text in html_text_actual) + + # Check the final messages + final_outputs = [text for text in html_text_actual if text.strip()] + assert final_outputs[-2] == "✔ Model batches executed" + assert final_outputs[-1] == "Run finished for environment 'prod'" + + actual_html_output = get_all_html_output(output) + # Replace dynamic elapsed time with 00 + for i, chunk in 
enumerate(actual_html_output): + pattern = r'font-weight: bold">0.\d{2}s ' + import re + + actual_html_output[i] = re.sub(pattern, 'font-weight: bold">0.00s ', chunk) + expected_html_output = [ + str( + h( + "pre", + {"style": RICH_PRE_STYLE}, + "\x1b", + h( + "span", + {"style": BOLD_ONLY}, + "[", + autoescape=False, + ), + "2K", + autoescape=False, + ) + ), + str( + h( + "pre", + {"style": RICH_PRE_STYLE}, + h( + "span", + {"style": "color: #000080; text-decoration-color: #000080; font-weight: bold"}, + "Executing model batches", + autoescape=False, + ), + " ", + h( + "span", + {"style": "color: #f92672; text-decoration-color: #f92672"}, + "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╸", + autoescape=False, + ), + h( + "span", + {"style": "color: #3a3a3a; text-decoration-color: #3a3a3a"}, + "━━", + autoescape=False, + ), + " ", + h( + "span", + {"style": "color: #800080; text-decoration-color: #800080"}, + "93.8%", + autoescape=False, + ), + " • ", + h( + "span", + {"style": SUCCESS_STYLE}, + "15/16", + autoescape=False, + ), + " • ", + h( + "span", + {"style": "color: #808000; text-decoration-color: #808000"}, + "0:00:00", + autoescape=False, + ), + "sushi.waiter_as_customer_by_day ", + h( + "span", + {"style": SUCCESS_STYLE}, + ".. 
", + autoescape=False, + ), + " ", + autoescape=False, + ) + ), + str( + h( + "pre", + {"style": RICH_PRE_STYLE}, + h( + "span", + {"style": BOLD_ONLY}, + "[", + autoescape=False, + ), + h( + "span", + {"style": BOLD_NEUTRAL_STYLE}, + "1", + autoescape=False, + ), + "/", + h( + "span", + {"style": BOLD_NEUTRAL_STYLE}, + "1", + autoescape=False, + ), + h( + "span", + {"style": BOLD_ONLY}, + "]", + autoescape=False, + ), + " sushi.waiter_as_customer_by_day ", + h( + "span", + {"style": BOLD_ONLY}, + "[", + autoescape=False, + ), + "insert ", + h( + "span", + {"style": BOLD_NEUTRAL_STYLE}, + "2023", + autoescape=False, + ), + "-", + h( + "span", + {"style": BOLD_NEUTRAL_STYLE}, + "01", + autoescape=False, + ), + "-", + h( + "span", + {"style": BOLD_NEUTRAL_STYLE}, + "01", + autoescape=False, + ), + " - ", + h( + "span", + {"style": BOLD_NEUTRAL_STYLE}, + "2023", + autoescape=False, + ), + "-", + h( + "span", + {"style": BOLD_NEUTRAL_STYLE}, + "01", + autoescape=False, + ), + "-", + h( + "span", + {"style": BOLD_NEUTRAL_STYLE}, + "02", + autoescape=False, + ), + ", audits ", + h( + "span", + {"style": SUCCESS_STYLE}, + "✔", + autoescape=False, + ), + h( + "span", + {"style": BOLD_NEUTRAL_STYLE}, + "2", + autoescape=False, + ), + h( + "span", + {"style": BOLD_ONLY}, + "]", + autoescape=False, + ), + " ", + h( + "span", + {"style": BOLD_NEUTRAL_STYLE}, + "0.", + autoescape=False, + ), + "00s ", + autoescape=False, + ) + ), + str( + h( + "pre", + {"style": RICH_PRE_STYLE}, + "", + autoescape=False, + ) + ), str( h( "pre", @@ -339,7 +544,7 @@ def test_run_dag( h( "span", {"style": SUCCESS_STYLE}, - "All model batches have been executed successfully", + "✔ Model batches executed", autoescape=False, ), autoescape=False, @@ -368,7 +573,7 @@ def test_run_dag( @pytest.mark.slow -@freeze_time(FREEZE_TIME) +@time_machine.travel(FREEZE_TIME) def test_invalidate( notebook, loaded_sushi_context, convert_all_html_output_to_text, get_all_html_output ): @@ -383,7 +588,7 @@ def 
test_invalidate( assert not output.stderr assert len(output.outputs) == 1 assert convert_all_html_output_to_text(output) == [ - "Environment 'dev' has been invalidated.", + "Environment 'dev' invalidated.", ] assert get_all_html_output(output) == [ str( @@ -410,7 +615,7 @@ def test_invalidate( h( "span", {"style": SUCCESS_STYLE}, - " has been invalidated.", + " invalidated.", autoescape=False, ) ), @@ -470,9 +675,9 @@ def test_create_test(notebook, sushi_context): assert ( test_file.read_text() == """test_top_waiters: - model: sushi.top_waiters + model: '"memory"."sushi"."top_waiters"' inputs: - sushi.waiter_revenue_by_day: + '"memory"."sushi"."waiter_revenue_by_day"': - waiter_id: 1 outputs: query: [] @@ -502,20 +707,6 @@ def test_test(notebook, sushi_context): assert test_file.read_text() == """test_customer_revenue_by_day: TESTING\n""" -def test_run_test(notebook, sushi_context): - with capture_output() as output: - notebook.run_line_magic( - magic_name="run_test", - line=f"{sushi_context.path / 'tests' / 'test_customer_revenue_by_day.yaml'}::test_customer_revenue_by_day", - ) - - assert not output.stdout - # TODO: Does it make sense for this to go to stderr? 
- assert "Ran 1 test" in output.stderr - assert "OK" in output.stderr - assert not output.outputs - - @pytest.mark.slow def test_audit(notebook, loaded_sushi_context, convert_all_html_output_to_text): with capture_output() as output: @@ -545,66 +736,26 @@ def test_fetchdf(notebook, sushi_context): def test_info(notebook, sushi_context, convert_all_html_output_to_text, get_all_html_output): with capture_output() as output: - notebook.run_line_magic(magic_name="info", line="") + notebook.run_line_magic(magic_name="info", line="--verbose") assert not output.stdout assert not output.stderr - assert len(output.outputs) == 3 + assert len(output.outputs) == 6 assert convert_all_html_output_to_text(output) == [ - "Models: 17", - "Macros: 5", + "Models: 20", + "Macros: 8", + "", + "Connection:\n type: duckdb\n concurrent_tasks: 1\n register_comments: true\n pre_ping: false\n pretty_sql: false\n extensions: []\n connector_config: {}\n secrets: None\n filesystems: []", + "Test Connection:\n type: duckdb\n concurrent_tasks: 1\n register_comments: true\n pre_ping: false\n pretty_sql: false\n extensions: []\n connector_config: {}\n secrets: None\n filesystems: []", "Data warehouse connection succeeded", ] assert get_all_html_output(output) == [ - str( - h( - "pre", - {"style": RICH_PRE_STYLE}, - "Models: " - + str( - h( - "span", - # "color: #008000; text-decoration-color: #008000" - {"style": f"{NEUTRAL_STYLE}; font-weight: bold"}, - "17", - autoescape=False, - ) - ), - autoescape=False, - ) - ), - str( - h( - "pre", - {"style": RICH_PRE_STYLE}, - "Macros: " - + str( - h( - "span", - {"style": f"{NEUTRAL_STYLE}; font-weight: bold"}, - "5", - autoescape=False, - ) - ), - autoescape=False, - ) - ), - str( - h( - "pre", - {"style": RICH_PRE_STYLE}, - "Data warehouse connection " - + str( - h( - "span", - {"style": SUCCESS_STYLE}, - "succeeded", - autoescape=False, - ) - ), - autoescape=False, - ) - ), + "
Models: 20
", + "
Macros: 8
", + "
",
+        '
Connection:  type: duckdb  concurrent_tasks: 1  register_comments: true  pre_ping: false  pretty_sql: false  extensions: []  connector_config: {}  secrets: None  filesystems: []
', + '
Test Connection:  type: duckdb  concurrent_tasks: 1  register_comments: true  pre_ping: false  pretty_sql: false  extensions: []  connector_config: {}  secrets: None  filesystems: []
', + "
Data warehouse connection succeeded
", ] @@ -646,7 +797,7 @@ def test_migrate( @pytest.mark.slow -def test_create_external_models(notebook, loaded_sushi_context): +def test_create_external_models(notebook, loaded_sushi_context, convert_all_html_output_to_text): external_model_file = loaded_sushi_context.path / "external_models.yaml" external_model_file.unlink() assert not external_model_file.exists() @@ -656,7 +807,11 @@ def test_create_external_models(notebook, loaded_sushi_context): assert not output.stdout assert not output.stderr - assert not output.outputs + assert len(output.outputs) == 2 + converted = sorted(convert_all_html_output_to_text(output)) + assert 'Unable to get schema for \'"memory"."raw"."model1"\'' in converted[0] + assert 'Unable to get schema for \'"memory"."raw"."model2"\'' in converted[1] + assert external_model_file.exists() assert ( external_model_file.read_text() @@ -664,12 +819,13 @@ def test_create_external_models(notebook, loaded_sushi_context): columns: customer_id: INT zip: TEXT + gateway: duckdb """ ) @pytest.mark.slow -@freeze_time(FREEZE_TIME) +@time_machine.travel(FREEZE_TIME) def test_table_diff(notebook, loaded_sushi_context, convert_all_html_output_to_text): with capture_output(): loaded_sushi_context.plan("dev", no_prompts=True, auto_apply=True, include_unmodified=True) @@ -678,20 +834,15 @@ def test_table_diff(notebook, loaded_sushi_context, convert_all_html_output_to_t assert not output.stdout assert not output.stderr - assert len(output.outputs) == 4 + + assert len(output.outputs) == 1 assert convert_all_html_output_to_text(output) == [ - """Schema Diff Between 'DEV' and 'PROD' environments for model 'sushi.top_waiters': -└── Schemas match""", - """Row Counts: -└── FULL MATCH: 8 rows (100.0%)""", - """COMMON ROWS column comparison stats:""", - """pct_match -revenue 100.0""", + "No models contain differences with the selection criteria: 'sushi.top_waiters'" ] @pytest.mark.slow -@freeze_time(FREEZE_TIME) +@time_machine.travel(FREEZE_TIME) def 
test_table_name(notebook, loaded_sushi_context, convert_all_html_output_to_text): with capture_output() as output: notebook.run_line_magic(magic_name="table_name", line="sushi.orders") @@ -702,3 +853,61 @@ def test_table_name(notebook, loaded_sushi_context, convert_all_html_output_to_t assert convert_all_html_output_to_text(output)[0].startswith( "memory.sqlmesh__sushi.sushi__orders__" ) + + +def test_lint(notebook, sushi_context): + from sqlmesh.core.config import LinterConfig + + sushi_context.config.linter = LinterConfig(enabled=True, warn_rules="ALL") + sushi_context.load() + + with capture_output() as output: + notebook.run_line_magic(magic_name="lint", line="") + + assert len(output.outputs) > 1 + assert "Linter warnings for" in output.outputs[0].data["text/plain"] + + with capture_output() as output: + notebook.run_line_magic(magic_name="lint", line="--models sushi.items") + + assert len(output.outputs) == 1 + assert "Linter warnings for" in output.outputs[0].data["text/plain"] + + with capture_output() as output: + notebook.run_line_magic(magic_name="lint", line="--models sushi.items sushi.raw_marketing") + + assert len(output.outputs) == 2 + assert "Linter warnings for" in output.outputs[0].data["text/plain"] + + +@pytest.mark.slow +def test_destroy( + notebook, + loaded_sushi_context, + convert_all_html_output_to_text, + get_all_html_output, + monkeypatch, +): + # Mock input to return 'y' for the confirmation prompt + monkeypatch.setattr("builtins.input", lambda: "y") + + with capture_output() as output: + notebook.run_line_magic(magic_name="destroy", line="") + + assert not output.stdout + assert not output.stderr + text_output = convert_all_html_output_to_text(output) + expected_messages = [ + "[WARNING] This will permanently delete all engine-managed objects, state tables and SQLMesh cache.\nThe operation may disrupt any currently running or scheduled plans.", + "Schemas to be deleted:", + "• memory.sushi", + "Snapshot tables to be deleted:", + "This 
action will DELETE ALL the above resources managed by SQLMesh AND\npotentially external resources created by other tools in these schemas.", + "Are you ABSOLUTELY SURE you want to proceed with deletion? [y/n]:", + "Environment 'prod' invalidated.", + "Deleted object memory.sushi", + "State tables removed.", + "Destroy completed successfully.", + ] + for message in expected_messages: + assert any(message in line for line in text_output) diff --git a/tests/lsp/conftest.py b/tests/lsp/conftest.py new file mode 100644 index 0000000000..6b3d3315aa --- /dev/null +++ b/tests/lsp/conftest.py @@ -0,0 +1,4 @@ +import pytest + +# Apply the 'fast' mark to all tests in this directory and subdirectories +pytestmark = pytest.mark.fast diff --git a/tests/lsp/test_code_actions.py b/tests/lsp/test_code_actions.py new file mode 100644 index 0000000000..509f49f9b1 --- /dev/null +++ b/tests/lsp/test_code_actions.py @@ -0,0 +1,182 @@ +import typing as t +import os +from lsprotocol import types +from sqlmesh.core.context import Context +from sqlmesh.lsp.context import LSPContext +from sqlmesh.lsp.uri import URI + + +def test_code_actions_with_linting(copy_to_temp_path: t.Callable): + """Test that code actions are generated for linting violations.""" + + # Copy sushi example to a temporary directory + sushi_paths = copy_to_temp_path("examples/sushi") + sushi_path = sushi_paths[0] + + # Override the config and turn the linter on + config_path = sushi_path / "config.py" + with config_path.open("r") as f: + lines = f.readlines() + lines = [ + line.replace("enabled=False,", "enabled=True,") if "enabled=False," in line else line + for line in lines + ] + with config_path.open("w") as f: + f.writelines(lines) + + # Override the latest_order.sql file to introduce a linter violation + model_content = """MODEL ( + name sushi.latest_order, + kind CUSTOM ( + materialization 'custom_full_with_custom_kind', + materialization_properties ( + custom_property = 'sushi!!!' 
+ ) + ), + cron '@daily' +); + +SELECT * +FROM sushi.orders +ORDER BY event_date DESC LIMIT 1 +""" + latest_order_path = sushi_path / "models" / "latest_order.sql" + with latest_order_path.open("w") as f: + f.write(model_content) + + # Create context with the mocked config + context = Context(paths=[str(sushi_path)]) + + # Create LSP context + lsp_context = LSPContext(context) + + # Get diagnostics (linting violations) + violations = lsp_context.lint_model(URI.from_path(sushi_path / "models" / "latest_order.sql")) + + uri = URI.from_path(sushi_path / "models" / "latest_order.sql") + + # First, convert violations to LSP diagnostics + diagnostics = [] + for violation in violations: + if violation.violation_range: + diagnostic = types.Diagnostic( + range=types.Range( + start=types.Position( + line=violation.violation_range.start.line, + character=violation.violation_range.start.character, + ), + end=types.Position( + line=violation.violation_range.end.line, + character=violation.violation_range.end.character, + ), + ), + message=violation.violation_msg, + severity=types.DiagnosticSeverity.Warning, + ) + diagnostics.append(diagnostic) + + # Create code action params with diagnostics + params = types.CodeActionParams( + text_document=types.TextDocumentIdentifier(uri=uri.value), + range=types.Range( + start=types.Position(line=0, character=0), + end=types.Position(line=100, character=0), + ), + context=types.CodeActionContext(diagnostics=diagnostics), + ) + + # Get code actions + code_actions = lsp_context.get_code_actions( + URI.from_path(sushi_path / "models" / "latest_order.sql"), params + ) + + # Verify we have code actions + assert code_actions is not None + assert len(code_actions) > 0 + + # Verify the code action properties + first_action = code_actions[0] + if not isinstance(first_action, types.CodeAction): + raise AssertionError("First action is not a CodeAction instance") + assert first_action.kind == types.CodeActionKind.QuickFix + assert first_action.edit is 
not None + assert first_action.edit.changes is not None + assert ( + URI.from_path(sushi_path / "models" / "latest_order.sql").value in first_action.edit.changes + ) + + # The fix should replace SELECT * with specific columns + text_edits = first_action.edit.changes[ + URI.from_path(sushi_path / "models" / "latest_order.sql").value + ] + assert len(text_edits) > 0 + + +def test_code_actions_create_file(copy_to_temp_path: t.Callable) -> None: + sushi_paths = copy_to_temp_path("examples/sushi") + sushi_path = sushi_paths[0] + + # Remove external models file and enable linter + os.remove(sushi_path / "external_models.yaml") + config_path = sushi_path / "config.py" + with config_path.open("r") as f: + content = f.read() + + before = """ linter=LinterConfig( + enabled=False, + rules=[ + "ambiguousorinvalidcolumn", + "invalidselectstarexpansion", + "noselectstar", + "nomissingaudits", + "nomissingowner", + "nomissingexternalmodels", + ], + ),""" + after = """linter=LinterConfig(enabled=True, rules=["nomissingexternalmodels"]),""" + content = content.replace(before, after) + with config_path.open("w") as f: + f.write(content) + + context = Context(paths=[str(sushi_path)]) + lsp_context = LSPContext(context) + + uri = URI.from_path(sushi_path / "models" / "customers.sql") + violations = lsp_context.lint_model(uri) + + diagnostics = [] + for violation in violations: + if violation.violation_range: + diagnostics.append( + types.Diagnostic( + range=types.Range( + start=types.Position( + line=violation.violation_range.start.line, + character=violation.violation_range.start.character, + ), + end=types.Position( + line=violation.violation_range.end.line, + character=violation.violation_range.end.character, + ), + ), + message=violation.violation_msg, + severity=types.DiagnosticSeverity.Warning, + ) + ) + + params = types.CodeActionParams( + text_document=types.TextDocumentIdentifier(uri=uri.value), + range=types.Range( + start=types.Position(line=0, character=0), 
end=types.Position(line=1, character=0) + ), + context=types.CodeActionContext(diagnostics=diagnostics), + ) + + actions = lsp_context.get_code_actions(uri, params) + assert actions is not None and len(actions) > 0 + action = next(a for a in actions if isinstance(a, types.CodeAction)) + assert action.edit is not None + assert action.edit.document_changes is not None + create_file = [c for c in action.edit.document_changes if isinstance(c, types.CreateFile)] + assert create_file, "Expected a CreateFile operation" + assert create_file[0].uri == URI.from_path(sushi_path / "external_models.yaml").value diff --git a/tests/lsp/test_completions.py b/tests/lsp/test_completions.py new file mode 100644 index 0000000000..e0772c1a96 --- /dev/null +++ b/tests/lsp/test_completions.py @@ -0,0 +1,197 @@ +from sqlglot import Tokenizer +from sqlmesh.core.context import Context +from sqlmesh.lsp.completions import ( + get_keywords_from_tokenizer, + get_sql_completions, + extract_keywords_from_content, +) +from sqlmesh.lsp.context import LSPContext +from sqlmesh.lsp.uri import URI + + +TOKENIZER_KEYWORDS = set(Tokenizer.KEYWORDS.keys()) + + +def test_get_keywords_from_tokenizer(): + assert len(get_keywords_from_tokenizer()) >= len(TOKENIZER_KEYWORDS) + + +def test_get_sql_completions_no_context(): + completions = get_sql_completions(None, None) + assert len(completions.keywords) >= len(TOKENIZER_KEYWORDS) + assert len(completions.models) == 0 + + +def test_get_macros(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + file_path = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql") + with open(file_path, "r", encoding="utf-8") as f: + file_content = f.read() + + file_uri = URI.from_path(file_path) + completions = LSPContext.get_completions(lsp_context, file_uri, file_content) + + each_macro = next((m for m in completions.macros if m.name == "each")) + assert each_macro.name == "each" + assert each_macro.description + 
add_one_macro = next((m for m in completions.macros if m.name == "add_one")) + assert add_one_macro.name == "add_one" + assert add_one_macro.description + + +def test_model_completions_include_descriptions(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + completions = LSPContext.get_completions(lsp_context, None) + + model_entry = next( + (m for m in completions.model_completions if m.name == "sushi.customers"), + None, + ) + assert model_entry is not None + assert model_entry.description + + +def test_get_sql_completions_with_context_no_file_uri(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + completions = LSPContext.get_completions(lsp_context, None) + assert len(completions.keywords) >= len(TOKENIZER_KEYWORDS) + assert "sushi.active_customers" in completions.models + assert "sushi.customers" in completions.models + + +def test_get_sql_completions_with_context_and_file_uri(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql") + completions = LSPContext.get_completions(lsp_context, URI.from_path(file_uri)) + assert len(completions.keywords) > len(TOKENIZER_KEYWORDS) + assert "sushi.active_customers" not in completions.models + + +def test_extract_keywords_from_content(): + # Test extracting keywords from SQL content + content = """ + SELECT customer_id, order_date, total_amount + FROM orders o + JOIN customers c ON o.customer_id = c.id + WHERE order_date > '2024-01-01' + """ + + keywords = extract_keywords_from_content(content) + + # Check that identifiers are extracted + assert "customer_id" in keywords + assert "order_date" in keywords + assert "total_amount" in keywords + assert "orders" in keywords + assert "customers" in keywords + assert "o" in keywords # alias + assert "c" in keywords # alias + assert "id" in keywords + + # Check that SQL keywords are 
NOT included + assert "SELECT" not in keywords + assert "FROM" not in keywords + assert "JOIN" not in keywords + assert "WHERE" not in keywords + assert "ON" not in keywords + + +def test_get_sql_completions_with_file_content(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # SQL content with custom identifiers + content = """ + SELECT my_custom_column, another_identifier + FROM my_custom_table mct + JOIN some_other_table sot ON mct.id = sot.table_id + WHERE my_custom_column > 100 + """ + + file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql") + completions = LSPContext.get_completions(lsp_context, URI.from_path(file_uri), content) + + # Check that SQL keywords are included + assert any(k in ["SELECT", "FROM", "WHERE", "JOIN"] for k in completions.keywords) + + # Check that file-specific identifiers are included at the end + keywords_list = completions.keywords + assert "my_custom_column" in keywords_list + assert "another_identifier" in keywords_list + assert "my_custom_table" in keywords_list + assert "some_other_table" in keywords_list + assert "mct" in keywords_list # alias + assert "sot" in keywords_list # alias + assert "table_id" in keywords_list + + # Check that file keywords come after SQL keywords + # SQL keywords should appear first in the list + sql_keyword_indices = [ + i for i, k in enumerate(keywords_list) if k in ["SELECT", "FROM", "WHERE", "JOIN"] + ] + file_keyword_indices = [ + i for i, k in enumerate(keywords_list) if k in ["my_custom_column", "my_custom_table"] + ] + + if sql_keyword_indices and file_keyword_indices: + assert max(sql_keyword_indices) < min(file_keyword_indices), ( + "SQL keywords should come before file keywords" + ) + + +def test_get_sql_completions_with_partial_cte_query(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # Partial SQL query with CTEs + content = """ + WITH _latest_complete_month AS ( + SELECT 
MAX(date_trunc('month', order_date)) as month + FROM orders + ), + _filtered AS ( + SELECT * FROM + """ + + file_uri = next(key for key in lsp_context.map.keys() if key.name == "active_customers.sql") + completions = LSPContext.get_completions(lsp_context, URI.from_path(file_uri), content) + + # Check that CTE names are included in the keywords + keywords_list = completions.keywords + assert "_latest_complete_month" in keywords_list + assert "_filtered" in keywords_list + + # Also check other identifiers from the partial query + assert "month" in keywords_list + assert "order_date" in keywords_list + assert "orders" in keywords_list + + +def test_extract_keywords_from_partial_query(): + # Test extracting keywords from an incomplete SQL query + content = """ + WITH cte1 AS ( + SELECT col1, col2 FROM table1 + ), + cte2 AS ( + SELECT * FROM cte1 WHERE + """ + + keywords = extract_keywords_from_content(content) + + # Check that CTEs are extracted + assert "cte1" in keywords + assert "cte2" in keywords + + # Check that columns and tables are extracted + assert "col1" in keywords + assert "col2" in keywords + assert "table1" in keywords diff --git a/tests/lsp/test_context.py b/tests/lsp/test_context.py new file mode 100644 index 0000000000..b463a17139 --- /dev/null +++ b/tests/lsp/test_context.py @@ -0,0 +1,63 @@ +from pathlib import Path + +from sqlmesh.core.context import Context +from sqlmesh.lsp.context import LSPContext, ModelTarget +from sqlmesh.lsp.uri import URI + + +def test_lsp_context(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + assert lsp_context is not None + assert lsp_context.context is not None + assert lsp_context.map is not None + + # find one model in the map + active_customers_key = next( + key for key in lsp_context.map.keys() if key.name == "active_customers.sql" + ) + + # Check that the value is a ModelInfo with the expected model name + assert isinstance(lsp_context.map[active_customers_key], 
ModelTarget) + assert "sushi.active_customers" in lsp_context.map[active_customers_key].names + + +def test_lsp_context_list_workspace_tests(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # List workspace tests + tests = lsp_context.list_workspace_tests() + + # Check that the tests are returned correctly + assert len(tests) == 3 + assert any(test.name == "test_order_items" for test in tests) + + +def test_lsp_context_get_document_tests(): + test_path = Path.cwd() / "examples/sushi/tests/test_order_items.yaml" + uri = URI.from_path(test_path) + + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + tests = lsp_context.get_document_tests(uri) + + assert len(tests) == 1 + assert tests[0].uri == uri.value + assert tests[0].name == "test_order_items" + + +def test_lsp_context_run_test(): + test_path = Path.cwd() / "examples/sushi/tests/test_order_items.yaml" + uri = URI.from_path(test_path) + + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # Run the test + result = lsp_context.run_test(uri, "test_order_items") + + # Check that the result is not None and has the expected properties + assert result is not None + assert result.success is True diff --git a/tests/lsp/test_diagnostics.py b/tests/lsp/test_diagnostics.py new file mode 100644 index 0000000000..96167d47e5 --- /dev/null +++ b/tests/lsp/test_diagnostics.py @@ -0,0 +1,55 @@ +from sqlmesh import Context +from sqlmesh.core.linter.helpers import read_range_from_file +from sqlmesh.lsp.context import LSPContext +from sqlmesh.lsp.uri import URI + + +def test_diagnostic_on_sushi(tmp_path, copy_to_temp_path) -> None: + # Copy sushi example to a temporary directory + sushi_paths = copy_to_temp_path("examples/sushi") + sushi_path = sushi_paths[0] + + # Override the active_customers.sql file to introduce a linter violation + active_customers_path = sushi_path / "models" / "active_customers.sql" + # Replace SELECT 
customer_id, zip with SELECT * to trigger a linter violation + with active_customers_path.open("r") as f: + lines = f.readlines() + lines = [ + line.replace("SELECT customer_id, zip", "SELECT *") + if "SELECT customer_id, zip" in line + else line + for line in lines + ] + with active_customers_path.open("w") as f: + f.writelines(lines) + + # Override the config and turn the linter on + config_path = sushi_path / "config.py" + with config_path.open("r") as f: + lines = f.readlines() + lines = [ + line.replace("enabled=False,", "enabled=True,") if "enabled=False," in line else line + for line in lines + ] + with config_path.open("w") as f: + f.writelines(lines) + + # Load the context with the temporary sushi path + context = Context(paths=[str(sushi_path)]) + lsp_context = LSPContext(context) + + # Diagnostics should be available + active_customers_uri = URI.from_path(active_customers_path) + lsp_diagnostics = lsp_context.lint_model(active_customers_uri) + + assert len(lsp_diagnostics) > 0 + + # Get the no select star diagnostic + select_star_diagnostic = [diag for diag in lsp_diagnostics if diag.rule.name == "noselectstar"] + assert len(select_star_diagnostic) == 1 + diagnostic = select_star_diagnostic[0] + + assert diagnostic.violation_range + + contents = read_range_from_file(active_customers_path, diagnostic.violation_range) + assert contents == "*" diff --git a/tests/lsp/test_document_highlight.py b/tests/lsp/test_document_highlight.py new file mode 100644 index 0000000000..e6ce0ae7ec --- /dev/null +++ b/tests/lsp/test_document_highlight.py @@ -0,0 +1,111 @@ +from lsprotocol.types import Position, DocumentHighlightKind + +from sqlmesh.core.context import Context +from sqlmesh.lsp.context import LSPContext, ModelTarget +from sqlmesh.lsp.rename import get_document_highlights +from sqlmesh.lsp.uri import URI +from tests.lsp.test_reference_cte import find_ranges_from_regex + + +def test_get_document_highlights_cte(): + context = Context(paths=["examples/sushi"]) + 
lsp_context = LSPContext(context) + + # Use the existing customers.sql model which has CTEs + sushi_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + with open(sushi_customers_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + test_uri = URI.from_path(sushi_customers_path) + + # Find the ranges for "current_marketing" CTE (not outer one) + ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)") + assert len(ranges) >= 2 # Should have definition + usage + + # Test highlighting CTE definition - position on "current_marketing" definition + position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 4) + highlights = get_document_highlights(lsp_context, test_uri, position) + + assert highlights is not None + assert len(highlights) >= 2 # Definition + at least 1 usage + + # Check that we have both definition (Write) and usage (Read) highlights + highlight_kinds = [h.kind for h in highlights] + assert DocumentHighlightKind.Write in highlight_kinds # CTE definition + assert DocumentHighlightKind.Read in highlight_kinds # CTE usage + + # Test highlighting CTE usage - position on "current_marketing" usage + position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4) + highlights = get_document_highlights(lsp_context, test_uri, position) + + assert highlights is not None + assert len(highlights) >= 2 # Should find the same references + + +def test_get_document_highlights_no_symbol(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # Use the existing customers.sql model + sushi_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + test_uri = URI.from_path(sushi_customers_path) + + # Test position not on any CTE symbol - just on a random keyword + 
position = Position(line=5, character=5) + highlights = get_document_highlights(lsp_context, test_uri, position) + + assert highlights is None + + +def test_get_document_highlights_multiple_ctes(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # Use the existing customers.sql model which has both outer and inner CTEs + sushi_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + with open(sushi_customers_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + test_uri = URI.from_path(sushi_customers_path) + + # Test the outer CTE - "current_marketing_outer" + outer_ranges = find_ranges_from_regex(read_file, r"current_marketing_outer") + assert len(outer_ranges) >= 2 # Should have definition + usage + + # Test highlighting outer CTE - should only highlight that CTE + position = Position( + line=outer_ranges[0].start.line, character=outer_ranges[0].start.character + 4 + ) + highlights = get_document_highlights(lsp_context, test_uri, position) + + assert highlights is not None + assert len(highlights) == len(outer_ranges) # Should match all occurrences of outer CTE + + # Test the inner CTE - "current_marketing" (not outer) + inner_ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)") + assert len(inner_ranges) >= 2 # Should have definition + usage + + # Test highlighting inner CTE - should only highlight that CTE, not the outer one + position = Position( + line=inner_ranges[0].start.line, character=inner_ranges[0].start.character + 4 + ) + highlights = get_document_highlights(lsp_context, test_uri, position) + + # This should return the column usages as well + assert highlights is not None + assert len(highlights) == 4 diff --git a/tests/lsp/test_hints.py b/tests/lsp/test_hints.py new file mode 100644 index 0000000000..99851a1361 --- /dev/null +++ b/tests/lsp/test_hints.py @@ -0,0 +1,203 @@ 
+"""Tests for type hinting SQLMesh models""" + +import pytest + +from sqlglot import exp, parse_one + +from sqlmesh.core.context import Context +from sqlmesh.lsp.context import LSPContext, ModelTarget +from sqlmesh.lsp.hints import get_hints, _get_type_hints_for_model_from_query +from sqlmesh.lsp.uri import URI + + +@pytest.mark.fast +def test_hints() -> None: + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # Find model URIs + active_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.active_customers" in info.names + ) + customer_revenue_lifetime_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customer_revenue_lifetime" in info.names + ) + customer_revenue_by_day_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customer_revenue_by_day" in info.names + ) + + active_customers_uri = URI.from_path(active_customers_path) + ac_hints = get_hints(lsp_context, active_customers_uri, start_line=0, end_line=9999) + assert len(ac_hints) == 2 + assert ac_hints[0].label == "::INT" + assert ac_hints[1].label == "::TEXT" + + customer_revenue_lifetime_uri = URI.from_path(customer_revenue_lifetime_path) + crl_hints = get_hints( + lsp_context=lsp_context, + document_uri=customer_revenue_lifetime_uri, + start_line=0, + end_line=9999, + ) + assert len(crl_hints) == 3 + assert crl_hints[0].label == "::INT" + assert crl_hints[1].label == "::DOUBLE" + assert crl_hints[2].label == "::DATE" + + customer_revenue_by_day_uri = URI.from_path(customer_revenue_by_day_path) + crbd_hints = get_hints( + lsp_context=lsp_context, + document_uri=customer_revenue_by_day_uri, + start_line=0, + end_line=9999, + ) + assert len(crbd_hints) == 1 + assert crbd_hints[0].label == "::INT" + + +@pytest.mark.fast +def test_union_hints() -> None: + query_str = """SELECT a FROM table_a UNION 
SELECT b FROM table_b UNION SELECT c FROM table_c""" + query = parse_one(query_str, dialect="postgres") + + result = _get_type_hints_for_model_from_query( + query=query, + dialect="postgres", + columns_to_types={ + "a": exp.DataType.build("TEXT"), + "b": exp.DataType.build("INT"), + "c": exp.DataType.build("DATE"), + }, + start_line=0, + end_line=1, + ) + + assert len(result) == 3 + assert result[0].label == "::DATE" + assert result[1].label == "::TEXT" + assert result[2].label == "::INT" + + +@pytest.mark.fast +def test_complex_hints() -> None: + query = parse_one("SELECT a, b FROM c", dialect="postgres") + + result = _get_type_hints_for_model_from_query( + query=query, + dialect="postgres", + columns_to_types={ + "a": exp.DataType.build("VARCHAR(100)"), + "b": exp.DataType.build("STRUCT>>"), + }, + start_line=0, + end_line=1, + ) + + assert len(result) == 2 + assert result[0].label == "::VARCHAR(100)" + assert result[1].label == "::STRUCT>" + + +@pytest.mark.fast +def test_simple_cast_hints() -> None: + """Don't add type hints if the expression is already a cast""" + query = parse_one("SELECT a::INT, CAST(b AS DATE), c FROM d", dialect="postgres") + + result = _get_type_hints_for_model_from_query( + query=query, + dialect="postgres", + columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("DATE"), + "c": exp.DataType.build("TEXT"), + }, + start_line=0, + end_line=1, + ) + + assert len(result) == 1 + assert result[0].label == "::TEXT" + + +@pytest.mark.fast +def test_alias_cast_hints() -> None: + """Don't add type hints if the expression is already a cast""" + query = parse_one( + "SELECT raw_a::INT AS a, CAST(raw_b AS DATE) AS b, c FROM d", dialect="postgres" + ) + + result = _get_type_hints_for_model_from_query( + query=query, + dialect="postgres", + columns_to_types={ + "a": exp.DataType.build("INT"), + "b": exp.DataType.build("DATE"), + "c": exp.DataType.build("TEXT"), + }, + start_line=0, + end_line=1, + ) + + assert len(result) == 1 
+ assert result[0].label == "::TEXT" + + +@pytest.mark.fast +def test_simple_cte_hints() -> None: + """Don't add type hints if the expression is already a cast""" + query = parse_one("WITH t AS (SELECT a FROM b) SELECT a AS c FROM t", dialect="postgres") + + result = _get_type_hints_for_model_from_query( + query=query, + dialect="postgres", + columns_to_types={ + "c": exp.DataType.build("INT"), + }, + start_line=0, + end_line=1, + ) + + assert len(result) == 1 + assert result[0].label == "::INT" + + +@pytest.mark.fast +def test_cte_with_union_hints() -> None: + """Don't add type hints if the expression is already a cast""" + query = parse_one( + """WITH x AS (SELECT a FROM t), + y AS (SELECT b FROM t), + z AS (SELECT c FROM t) + SELECT a AS d FROM x + UNION + SELECT b AS e FROM y + UNION + SELECT c AS f FROM z""", + dialect="postgres", + ) + + result = _get_type_hints_for_model_from_query( + query=query, + dialect="postgres", + columns_to_types={ + "a": exp.DataType.build("TEXT"), + "b": exp.DataType.build("DATE"), + "c": exp.DataType.build("INT"), + "d": exp.DataType.build("TEXT"), + "e": exp.DataType.build("DATE"), + "f": exp.DataType.build("INT"), + }, + start_line=0, + end_line=9999, + ) + + assert len(result) == 3 + assert result[0].label == "::INT" + assert result[1].label == "::TEXT" + assert result[2].label == "::DATE" diff --git a/tests/lsp/test_reference.py b/tests/lsp/test_reference.py new file mode 100644 index 0000000000..6aae4b869e --- /dev/null +++ b/tests/lsp/test_reference.py @@ -0,0 +1,198 @@ +from sqlmesh.core.context import Context +from sqlmesh.core.linter.rule import Position +from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget +from sqlmesh.lsp.reference import ModelReference, get_model_definitions_for_a_path, by_position +from sqlmesh.lsp.uri import URI + + +def test_reference() -> None: + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # Find model URIs + active_customers_path = next( + 
path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.active_customers" in info.names + ) + sushi_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + active_customers_uri = URI.from_path(active_customers_path) + references = get_model_definitions_for_a_path(lsp_context, active_customers_uri) + + assert len(references) == 1 + path = references[0].path + assert path is not None + assert path == sushi_customers_path + + # Check that the reference in the correct range is sushi.customers + path = active_customers_uri.to_path() + with open(path, "r") as file: + read_file = file.readlines() + + # Get the string range in the read file + referenced_text = get_string_from_range(read_file, references[0].range) + assert referenced_text == "sushi.customers" + + +def test_reference_with_alias() -> None: + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + waiter_revenue_by_day_path = next( + uri + for uri, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.waiter_revenue_by_day" in info.names + ) + + references = [ + ref + for ref in get_model_definitions_for_a_path( + lsp_context, URI.from_path(waiter_revenue_by_day_path) + ) + if isinstance(ref, ModelReference) + ] + assert len(references) == 3 + + with open(waiter_revenue_by_day_path, "r") as file: + read_file = file.readlines() + + assert str(references[0].path).endswith("orders.py") + assert get_string_from_range(read_file, references[0].range) == "sushi.orders" + assert ( + references[0].markdown_description + == """Table of sushi orders. 
+ +| Column | Type | Description | +|--------|------|-------------| +| id | INT | | +| customer_id | INT | | +| waiter_id | INT | | +| start_ts | INT | | +| end_ts | INT | | +| event_date | DATE | |""" + ) + assert str(references[1].path).endswith("order_items.py") + assert get_string_from_range(read_file, references[1].range) == "sushi.order_items" + assert str(references[2].path).endswith("items.py") + assert get_string_from_range(read_file, references[2].range) == "sushi.items" + + +def test_standalone_audit_reference() -> None: + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # Find the standalone audit URI + audit_path = next( + uri + for uri, info in lsp_context.map.items() + if isinstance(info, AuditTarget) and info.name == "assert_item_price_above_zero" + ) + # Find the items model URI + items_path = next( + uri + for uri, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.items" in info.names + ) + + references = get_model_definitions_for_a_path(lsp_context, URI.from_path(audit_path)) + + assert len(references) == 1 + assert references[0].path == items_path + + # Check that the reference in the correct range is sushi.items + with open(audit_path, "r") as file: + read_file = file.readlines() + referenced_text = get_string_from_range(read_file, references[0].range) + assert referenced_text == "sushi.items" + + +def get_string_from_range(file_lines, range_obj) -> str: + start_line = range_obj.start.line + end_line = range_obj.end.line + start_character = range_obj.start.character + end_character = range_obj.end.character + + # If the reference spans multiple lines, handle it accordingly + if start_line == end_line: + # Reference is on a single line + line_content = file_lines[start_line] + return line_content[start_character:end_character] + + # Reference spans multiple lines + result = file_lines[start_line][start_character:] # First line from start_character to end + for line_num in 
range(start_line + 1, end_line): # Middle lines (if any) + result += file_lines[line_num] + result += file_lines[end_line][:end_character] # Last line up to end_character + return result + + +def test_filter_references_by_position() -> None: + """Test that we can filter references correctly based on cursor position.""" + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # Use a file with multiple references (waiter_revenue_by_day) + waiter_revenue_by_day_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.waiter_revenue_by_day" in info.names + ) + + # Get all references in the file + all_references = get_model_definitions_for_a_path( + lsp_context, URI.from_path(waiter_revenue_by_day_path) + ) + assert len(all_references) == 3 + + # Get file contents to locate positions for testing + with open(waiter_revenue_by_day_path, "r") as file: + read_file = file.readlines() + + # Test positions for each reference + for i, reference in enumerate(all_references): + # Position inside the reference - should return exactly one reference + middle_line = (reference.range.start.line + reference.range.end.line) // 2 + middle_char = (reference.range.start.character + reference.range.end.character) // 2 + position_inside = Position(line=middle_line, character=middle_char) + filtered = list(filter(by_position(position_inside), all_references)) + assert len(filtered) == 1 + assert filtered[0].path == reference.path + assert filtered[0].range == reference.range + + # For testing outside position, use a position before the current reference + # or after the last reference for the last one + if i == 0: + outside_line = reference.range.start.line + outside_char = max(0, reference.range.start.character - 5) + else: + prev_ref = all_references[i - 1] + outside_line = prev_ref.range.end.line + outside_char = prev_ref.range.end.character + 5 + + position_outside = Position(line=outside_line, 
character=outside_char) + filtered_outside = list(filter(by_position(position_outside), all_references)) + assert reference not in filtered_outside, ( + f"Reference {i} should not match position outside its range" + ) + + # Test case: cursor at beginning of file - no references should match + position_start = Position(line=0, character=0) + filtered_start = list(filter(by_position(position_start), all_references)) + assert len(filtered_start) == 0 or all( + ref.range.start.line == 0 and ref.range.start.character <= 0 for ref in filtered_start + ) + + # Test case: cursor at end of file - no references should match (unless there's a reference at the end) + last_line = len(read_file) - 1 + last_char = len(read_file[last_line]) - 1 + position_end = Position(line=last_line, character=last_char) + filtered_end = list(filter(by_position(position_end), all_references)) + assert len(filtered_end) == 0 or all( + ref.range.end.line >= last_line and ref.range.end.character >= last_char + for ref in filtered_end + ) diff --git a/tests/lsp/test_reference_cte.py b/tests/lsp/test_reference_cte.py new file mode 100644 index 0000000000..9bc74bc990 --- /dev/null +++ b/tests/lsp/test_reference_cte.py @@ -0,0 +1,64 @@ +import re +from sqlmesh.core.context import Context +from sqlmesh.lsp.context import LSPContext, ModelTarget +from sqlmesh.lsp.reference import CTEReference, get_references +from sqlmesh.lsp.uri import URI +from lsprotocol.types import Range, Position +import typing as t + + +def test_cte_parsing(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # Find model URIs + sushi_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + with open(sushi_customers_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + # Find position of the cte reference + ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)") + 
assert len(ranges) == 2 + position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4) + references = get_references(lsp_context, URI.from_path(sushi_customers_path), position) + assert len(references) == 1 + assert references[0].path == sushi_customers_path + assert isinstance(references[0], CTEReference) + assert ( + references[0].range.start.line == ranges[1].start.line + ) # The reference location (where we clicked) + assert ( + references[0].target_range.start.line == ranges[0].start.line + ) # The CTE definition location + + # Find the position of the current_marketing_outer reference + ranges = find_ranges_from_regex(read_file, r"current_marketing_outer") + assert len(ranges) == 2 + position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4) + references = get_references(lsp_context, URI.from_path(sushi_customers_path), position) + assert len(references) == 1 + assert references[0].path == sushi_customers_path + assert isinstance(references[0], CTEReference) + assert ( + references[0].range.start.line == ranges[1].start.line + ) # The reference location (where we clicked) + assert ( + references[0].target_range.start.line == ranges[0].start.line + ) # The CTE definition location + + +def find_ranges_from_regex(read_file: t.List[str], regex: str) -> t.List[Range]: + """Find all ranges in the read file that match the regex.""" + return [ + Range( + start=Position(line=line_number, character=match.start()), + end=Position(line=line_number, character=match.end()), + ) + for line_number, line in enumerate(read_file) + for match in [m for m in [re.search(regex, line)] if m] + ] diff --git a/tests/lsp/test_reference_cte_find_all.py b/tests/lsp/test_reference_cte_find_all.py new file mode 100644 index 0000000000..dabe1589e2 --- /dev/null +++ b/tests/lsp/test_reference_cte_find_all.py @@ -0,0 +1,139 @@ +from lsprotocol.types import Position +from sqlmesh.core.context import Context +from sqlmesh.lsp.context 
import LSPContext, ModelTarget +from sqlmesh.lsp.reference import get_cte_references +from sqlmesh.lsp.uri import URI +from tests.lsp.test_reference_cte import find_ranges_from_regex + + +def test_cte_find_all_references(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + sushi_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + with open(sushi_customers_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + # Test finding all references of "current_marketing" + ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)") + assert len(ranges) == 2 # regex finds 2 occurrences (definition and FROM clause) + + # Click on the CTE definition + position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 4) + references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position) + # Should find the definition, FROM clause, and column prefix usages + assert len(references) == 4 # definition + FROM + 2 column prefix uses + assert all(ref.path == sushi_customers_path for ref in references) + + reference_ranges = [ref.range for ref in references] + for expected_range in ranges: + assert any( + ref_range.start.line == expected_range.start.line + and ref_range.start.character == expected_range.start.character + for ref_range in reference_ranges + ), ( + f"Expected to find reference at line {expected_range.start.line}, char {expected_range.start.character}" + ) + + # Click on the CTE usage + position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4) + references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position) + + # Should find the same references + assert len(references) == 4 # definition + FROM + 2 column prefix uses + assert all(ref.path == sushi_customers_path for ref in references) + + reference_ranges 
= [ref.range for ref in references] + for expected_range in ranges: + assert any( + ref_range.start.line == expected_range.start.line + and ref_range.start.character == expected_range.start.character + for ref_range in reference_ranges + ), ( + f"Expected to find reference at line {expected_range.start.line}, char {expected_range.start.character}" + ) + + +def test_cte_find_all_references_outer(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + sushi_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + with open(sushi_customers_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + # Test finding all references of "current_marketing_outer" + ranges = find_ranges_from_regex(read_file, r"current_marketing_outer") + assert len(ranges) == 2 + + # Click on the CTE definition + position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 4) + references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position) + + # Should find both the definition and the usage + assert len(references) == 2 + assert all(ref.path == sushi_customers_path for ref in references) + + # Verify that we found both occurrences + reference_ranges = [ref.range for ref in references] + for expected_range in ranges: + assert any( + ref_range.start.line == expected_range.start.line + and ref_range.start.character == expected_range.start.character + for ref_range in reference_ranges + ), ( + f"Expected to find reference at line {expected_range.start.line}, char {expected_range.start.character}" + ) + + # Click on the CTE usage + position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4) + references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position) + + # Should find the same references + assert len(references) == 2 + assert all(ref.path == 
sushi_customers_path for ref in references) + + reference_ranges = [ref.range for ref in references] + for expected_range in ranges: + assert any( + ref_range.start.line == expected_range.start.line + and ref_range.start.character == expected_range.start.character + for ref_range in reference_ranges + ), ( + f"Expected to find reference at line {expected_range.start.line}, char {expected_range.start.character}" + ) + + +def test_cte_no_references_on_non_cte(): + # Test that clicking on non-CTE elements returns nothing, once this is supported adapt this test accordingly + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + sushi_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + with open(sushi_customers_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + # Click on a regular table reference + ranges = find_ranges_from_regex(read_file, r"sushi\.orders") + assert len(ranges) >= 1 + + position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 4) + references = get_cte_references(lsp_context, URI.from_path(sushi_customers_path), position) + + # Should find no references since this is not a CTE + assert len(references) == 0 diff --git a/tests/lsp/test_reference_external_model.py b/tests/lsp/test_reference_external_model.py new file mode 100644 index 0000000000..25de22f10f --- /dev/null +++ b/tests/lsp/test_reference_external_model.py @@ -0,0 +1,122 @@ +import os +from pathlib import Path + +from sqlmesh import Config +from sqlmesh.core.context import Context +from sqlmesh.core.linter.helpers import read_range_from_file +from sqlmesh.core.linter.rule import Position +from sqlmesh.lsp.context import LSPContext, ModelTarget +from sqlmesh.lsp.reference import get_references +from sqlmesh.lsp.uri import URI +from sqlmesh.utils.lineage import ExternalModelReference +from tests.utils.test_filesystem 
import create_temp_file +import typing as t + + +def test_reference() -> None: + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # Find model URIs + customers = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + # Position of reference in file sushi.customers for sushi.raw_demographics + position = Position(line=42, character=20) + references = get_references(lsp_context, URI.from_path(customers), position) + + assert len(references) == 1 + reference = references[0] + assert isinstance(reference, ExternalModelReference) + path = reference.path + assert path is not None + assert str(path).endswith("external_models.yaml") + + source_range = read_range_from_file(customers, reference.range) + assert source_range == "raw.demographics" + + if reference.target_range is None: + raise AssertionError("Reference target range should not be None") + path = reference.path + assert path is not None + target_range = read_range_from_file(path, reference.target_range) + assert target_range == "raw.demographics" + + +def test_unregistered_external_model(tmp_path: Path): + model_path = tmp_path / "models" / "foo.sql" + contents = "MODEL (name test.foo, kind FULL); SELECT * FROM external_model" + create_temp_file(tmp_path, model_path, contents) + ctx = Context(paths=[tmp_path], config=Config()) + lsp_context = LSPContext(ctx) + + uri = URI.from_path(model_path) + references = get_references(lsp_context, uri, Position(line=0, character=len(contents) - 3)) + + assert len(references) == 1 + reference = references[0] + assert isinstance(reference, ExternalModelReference) + assert reference.path is None + assert reference.target_range is None + assert reference.markdown_description == "Unregistered external model" + assert read_range_from_file(model_path, reference.range) == "external_model" + + +def test_unregistered_external_model_with_schema( + copy_to_temp_path: 
t.Callable[[str], list[Path]], +) -> None: + """ + Tests that the linter correctly identifies unregistered external model dependencies. + + This test removes the `external_models.yaml` file from the sushi example project, + enables the linter, and verifies that the linter raises a violation for a model + that depends on unregistered external models. + """ + sushi_paths = copy_to_temp_path("examples/sushi") + sushi_path = sushi_paths[0] + + # Remove the external_models.yaml file + os.remove(sushi_path / "external_models.yaml") + + # Override the config.py to turn on lint + with open(sushi_path / "config.py", "r") as f: + read_file = f.read() + + before = """ linter=LinterConfig( + enabled=False, + rules=[ + "ambiguousorinvalidcolumn", + "invalidselectstarexpansion", + "noselectstar", + "nomissingaudits", + "nomissingowner", + "nomissingexternalmodels", + ], + ),""" + after = """linter=LinterConfig(enabled=True, rules=["nomissingexternalmodels"]),""" + read_file = read_file.replace(before, after) + assert after in read_file + with open(sushi_path / "config.py", "w") as f: + f.writelines(read_file) + + # Load the context with the temporary sushi path + context = Context(paths=[sushi_path]) + + model = context.get_model("sushi.customers") + if model is None: + raise AssertionError("Model 'sushi.customers' not found in context") + + lsp_context = LSPContext(context) + path = model._path + assert path is not None + uri = URI.from_path(path) + references = get_references(lsp_context, uri, Position(line=42, character=20)) + + assert len(references) == 1 + reference = references[0] + assert isinstance(reference, ExternalModelReference) + assert reference.path is None + assert read_range_from_file(path, reference.range) == "raw.demographics" diff --git a/tests/lsp/test_reference_macro.py b/tests/lsp/test_reference_macro.py new file mode 100644 index 0000000000..3ee7c48b3b --- /dev/null +++ b/tests/lsp/test_reference_macro.py @@ -0,0 +1,29 @@ +from sqlmesh.core.context import 
Context +from sqlmesh.lsp.context import LSPContext, ModelTarget +from sqlmesh.lsp.reference import MacroReference, get_macro_definitions_for_a_path +from sqlmesh.lsp.uri import URI + + +def test_macro_references() -> None: + """Test that macro references (e.g., @ADD_ONE, @MULTIPLY) have proper go-to-definition support.""" + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # Find the top_waiters model that uses macros + top_waiters_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.top_waiters" in info.names + ) + + top_waiters_uri = URI.from_path(top_waiters_path) + macro_references = get_macro_definitions_for_a_path(lsp_context, top_waiters_uri) + + # We expect 3 macro references: @ADD_ONE, @MULTIPLY, @SQL_LITERAL + assert len(macro_references) == 3 + + # Check that all references point to the utils.py file + for ref in macro_references: + assert isinstance(ref, MacroReference) + assert URI.from_path(ref.path).value.endswith("sushi/macros/utils.py") + assert ref.target_range is not None diff --git a/tests/lsp/test_reference_macro_find_all.py b/tests/lsp/test_reference_macro_find_all.py new file mode 100644 index 0000000000..328924599a --- /dev/null +++ b/tests/lsp/test_reference_macro_find_all.py @@ -0,0 +1,195 @@ +from lsprotocol.types import Position +from sqlmesh.core.context import Context +from sqlmesh.lsp.context import LSPContext, ModelTarget +from sqlmesh.lsp.reference import ( + get_macro_find_all_references, + get_macro_definitions_for_a_path, +) +from sqlmesh.lsp.uri import URI +from sqlmesh.core.linter.helpers import ( + read_range_from_file, + Range as SQLMeshRange, + Position as SQLMeshPosition, +) + + +def test_find_all_references_for_macro_add_one(): + """Test finding all references to the @ADD_ONE macro.""" + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # Find the top_waiters model that uses @ADD_ONE macro + 
top_waiters_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.top_waiters" in info.names + ) + + top_waiters_uri = URI.from_path(top_waiters_path) + macro_references = get_macro_definitions_for_a_path(lsp_context, top_waiters_uri) + + # Find the @ADD_ONE reference + add_one_ref = next((ref for ref in macro_references if ref.range.start.line == 12), None) + assert add_one_ref is not None, "Should find @ADD_ONE reference in top_waiters" + + # Click on the @ADD_ONE macro at line 13, character 5 (the @ symbol) + position = Position(line=12, character=5) + + all_references = get_macro_find_all_references(lsp_context, top_waiters_uri, position) + + # Should find at least 2 references: the definition and the usage in top_waiters + assert len(all_references) >= 2, f"Expected at least 2 references, found {len(all_references)}" + + # Verify the macro definition is included + definition_refs = [ref for ref in all_references if "utils.py" in str(ref.path)] + assert len(definition_refs) >= 1, "Should include the macro definition in utils.py" + + # Verify the usage in top_waiters is included + usage_refs = [ref for ref in all_references if "top_waiters" in str(ref.path)] + assert len(usage_refs) >= 1, "Should include the usage in top_waiters.sql" + + expected_files = { + "utils.py": {"pattern": r"def add_one", "expected_content": "def add_one"}, + "customers.sql": {"pattern": r"@ADD_ONE\s*\(", "expected_content": "ADD_ONE"}, + "top_waiters.sql": {"pattern": r"@ADD_ONE\s*\(", "expected_content": "ADD_ONE"}, + } + + for expected_file, expectations in expected_files.items(): + file_refs = [ref for ref in all_references if expected_file in str(ref.path)] + assert len(file_refs) >= 1, f"Should find at least one reference in {expected_file}" + + file_ref = file_refs[0] + file_path = file_ref.path + + sqlmesh_range = SQLMeshRange( + start=SQLMeshPosition( + line=file_ref.range.start.line, 
character=file_ref.range.start.character + ), + end=SQLMeshPosition( + line=file_ref.range.end.line, character=file_ref.range.end.character + ), + ) + + # Read the content at the reference location + content = read_range_from_file(file_path, sqlmesh_range) + assert content.startswith(expectations["expected_content"]), ( + f"Expected content to start with '{expectations['expected_content']}', got: {content}" + ) + + +def test_find_all_references_for_macro_multiply(): + """Test finding all references to the @MULTIPLY macro.""" + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # Find the top_waiters model that uses @MULTIPLY macro + top_waiters_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.top_waiters" in info.names + ) + + top_waiters_uri = URI.from_path(top_waiters_path) + macro_references = get_macro_definitions_for_a_path(lsp_context, top_waiters_uri) + + # Find the @MULTIPLY reference + multiply_ref = next((ref for ref in macro_references if ref.range.start.line == 13), None) + assert multiply_ref is not None, "Should find @MULTIPLY reference in top_waiters" + + # Click on the @MULTIPLY macro at line 14, character 5 (the @ symbol) + position = Position(line=13, character=5) + all_references = get_macro_find_all_references(lsp_context, top_waiters_uri, position) + + # Should find at least 2 references: the definition and the usage + assert len(all_references) >= 2, f"Expected at least 2 references, found {len(all_references)}" + + # Verify both definition and usage are included + assert any("utils.py" in str(ref.path) for ref in all_references), ( + "Should include macro definition" + ) + assert any("top_waiters" in str(ref.path) for ref in all_references), "Should include usage" + + +def test_find_all_references_for_sql_literal_macro(): + """Test finding references to @SQL_LITERAL macro .""" + context = Context(paths=["examples/sushi"]) + lsp_context = 
LSPContext(context) + + # Find the top_waiters model that uses @SQL_LITERAL macro + top_waiters_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.top_waiters" in info.names + ) + + top_waiters_uri = URI.from_path(top_waiters_path) + macro_references = get_macro_definitions_for_a_path(lsp_context, top_waiters_uri) + + # Find the @SQL_LITERAL reference + sql_literal_ref = next((ref for ref in macro_references if ref.range.start.line == 14), None) + assert sql_literal_ref is not None, "Should find @SQL_LITERAL reference in top_waiters" + + # Click on the @SQL_LITERAL macro + position = Position(line=14, character=5) + all_references = get_macro_find_all_references(lsp_context, top_waiters_uri, position) + + # For user-defined macros in utils.py, should find references + assert len(all_references) >= 2, f"Expected at least 2 references, found {len(all_references)}" + + +def test_find_references_from_outside_macro_position(): + """Test that clicking outside a macro doesn't return macro references.""" + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + top_waiters_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.top_waiters" in info.names + ) + + top_waiters_uri = URI.from_path(top_waiters_path) + + # Click on a position that is not on a macro + position = Position(line=0, character=0) # First line, which is a comment + all_references = get_macro_find_all_references(lsp_context, top_waiters_uri, position) + + # Should return empty list when not on a macro + assert len(all_references) == 0, "Should not find macro references when not on a macro" + + +def test_multi_repo_macro_references(): + """Test finding macro references across multiple repositories.""" + context = Context(paths=["examples/multi/repo_1", "examples/multi/repo_2"], gateway="memory") + lsp_context = LSPContext(context) + + # Find model 'd' which uses 
macros from repo_2 + d_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "silver.d" in info.names + ) + + d_uri = URI.from_path(d_path) + macro_references = get_macro_definitions_for_a_path(lsp_context, d_uri) + + if macro_references: + # Click on the second macro reference which appears under the same name in repo_1 ('dup') + first_ref = macro_references[1] + position = Position( + line=first_ref.range.start.line, character=first_ref.range.start.character + 1 + ) + all_references = get_macro_find_all_references(lsp_context, d_uri, position) + + # Should find the definition and usage + assert len(all_references) == 2, f"Expected 2 references, found {len(all_references)}" + + # Verify references from repo_2 + assert any("repo_2" in str(ref.path) for ref in all_references), ( + "Should find macro in repo_2" + ) + + # But not references in repo_1 since despite identical name they're different macros + assert not any("repo_1" in str(ref.path) for ref in all_references), ( + "Shouldn't find macro in repo_1" + ) diff --git a/tests/lsp/test_reference_macro_multi.py b/tests/lsp/test_reference_macro_multi.py new file mode 100644 index 0000000000..3902c0b275 --- /dev/null +++ b/tests/lsp/test_reference_macro_multi.py @@ -0,0 +1,24 @@ +from sqlmesh.core.context import Context +from sqlmesh.lsp.context import LSPContext, ModelTarget +from sqlmesh.lsp.reference import MacroReference, get_macro_definitions_for_a_path +from sqlmesh.lsp.uri import URI + + +def test_macro_references_multirepo() -> None: + context = Context(paths=["examples/multi/repo_1", "examples/multi/repo_2"], gateway="memory") + lsp_context = LSPContext(context) + + d_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "silver.d" in info.names + ) + + d = URI.from_path(d_path) + macro_references = get_macro_definitions_for_a_path(lsp_context, d) + + assert len(macro_references) == 2 + for ref in 
macro_references: + assert isinstance(ref, MacroReference) + assert str(URI.from_path(ref.path).value).endswith("multi/repo_2/macros/__init__.py") + assert ref.target_range is not None diff --git a/tests/lsp/test_reference_model_column_prefix.py b/tests/lsp/test_reference_model_column_prefix.py new file mode 100644 index 0000000000..082ee9c8e6 --- /dev/null +++ b/tests/lsp/test_reference_model_column_prefix.py @@ -0,0 +1,209 @@ +from pathlib import Path + +from sqlmesh.cli.project_init import init_example_project +from sqlmesh.core.context import Context +from sqlmesh.core.linter.rule import Position +from sqlmesh.lsp.context import LSPContext, ModelTarget +from sqlmesh.lsp.reference import get_all_references +from sqlmesh.lsp.uri import URI +from tests.lsp.test_reference_cte import find_ranges_from_regex + + +def test_model_reference_with_column_prefix(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + sushi_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + with open(sushi_customers_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + # Test finding references for "sushi.orders" + ranges = find_ranges_from_regex(read_file, r"sushi\.orders") + + # Click on the table reference in FROM clause (should be the second occurrence) + from_clause_range = None + for r in ranges: + line_content = read_file[r.start.line].strip() + if "FROM" in line_content: + from_clause_range = r + break + + assert from_clause_range is not None, "Should find FROM clause with sushi.orders" + + position = Position( + line=from_clause_range.start.line, character=from_clause_range.start.character + 6 + ) + + model_refs = get_all_references(lsp_context, URI.from_path(sushi_customers_path), position) + + assert len(model_refs) >= 6 + + # Verify that we have the FROM clause reference + assert any(ref.range.start.line == 
from_clause_range.start.line for ref in model_refs), ( + "Should find FROM clause reference" + ) + + +def test_column_prefix_references_are_found(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + sushi_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + with open(sushi_customers_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + # Find all occurrences of sushi.orders in the file + ranges = find_ranges_from_regex(read_file, r"sushi\.orders") + + # Should find exactly 1 in FROM clause with column prefix + assert len(ranges) == 1, f"Expected 1 occurrence of 'sushi.orders', found {len(ranges)}" + + # Verify we have the expected lines + line_contents = [read_file[r.start.line].strip() for r in ranges] + + # Should find FROM clause + assert any("FROM sushi.orders" in content for content in line_contents), ( + "Should find FROM clause with sushi.orders" + ) + + +def test_quoted_uppercase_table_and_column_references(tmp_path: Path): + # Initialize example project in temporary directory with case sensitive normalization + init_example_project( + tmp_path, engine_type="duckdb", dialect="duckdb,normalization_strategy=case_sensitive" + ) + + # Create a model with quoted uppercase schema and table names + models_dir = tmp_path / "models" + + # First, create the uppercase SUSHI.orders model that will be referenced + uppercase_orders_path = models_dir / "uppercase_orders.sql" + uppercase_orders_path.write_text("""MODEL ( + name "SUSHI".orders, + kind FULL +); + +SELECT + 1 as id, + 1 as customer_id, + 1 as item_id""") + + # Second, create the lowercase sushi.orders model that will be referenced + lowercase_orders_path = models_dir / "lowercase_orders.sql" + lowercase_orders_path.write_text("""MODEL ( + name sushi.orders, + kind FULL +); + +SELECT + 1 as id, + 1 as customer_id""") + + quoted_test_path = models_dir / 
"quoted_test.sql" + quoted_test_path.write_text("""MODEL ( + name "SUSHI".quoted_test, + kind FULL +); + +SELECT + o.id, + o.customer_id, + o.item_id, + c.item_id as c_item_id +FROM "SUSHI".orders AS o, sushi.orders as c +WHERE "SUSHI".orders.id > 0 + AND "SUSHI".orders.customer_id IS NOT NULL + AND sushi.orders.id > 0""") + + context = Context(paths=tmp_path) + lsp_context = LSPContext(context) + + # Find the quoted test model + quoted_test_model_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and '"SUSHI".quoted_test' in info.names + ) + + with open(quoted_test_model_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + # Test finding references for quoted "SUSHI".orders + ranges = find_ranges_from_regex(read_file, r'"SUSHI"\.orders') + + # Should find 3 occurrences: FROM clause and 2 in WHERE clause with column prefix + assert len(ranges) == 3, f"Expected 3 occurrences of '\"SUSHI\".orders', found {len(ranges)}" + + # Click on the table reference in FROM clause + from_clause_range = None + for r in ranges: + line_content = read_file[r.start.line].strip() + if "FROM" in line_content: + from_clause_range = r + break + + assert from_clause_range is not None, 'Should find FROM clause with "SUSHI".orders' + + position = Position( + line=from_clause_range.start.line, character=from_clause_range.start.character + 5 + ) + + model_refs = get_all_references(lsp_context, URI.from_path(quoted_test_model_path), position) + + # Should find only references to "SUSHI".orders (3 total: FROM clause and 2 column prefixes in WHERE) + # The lowercase sushi.orders should NOT be included if case sensitivity is working + assert len(model_refs) == 4, ( + f'Expected exactly 3 references for "SUSHI".orders, found {len(model_refs)}' + ) + + # Verify that we have all 3 references + ref_lines = [ref.range.start.line for ref in model_refs] + + # Count how many references are on each line + from_line = 
from_clause_range.start.line + where_lines = [r.start.line for r in ranges if r.start.line != from_line] + + assert from_line in ref_lines, "Should find FROM clause reference" + for where_line in where_lines: + assert where_line in ref_lines, f"Should find WHERE clause reference on line {where_line}" + + # Now test that lowercase sushi.orders references are separate + lowercase_ranges = find_ranges_from_regex(read_file, r"sushi\.orders") + + # Should find 2 occurrences: FROM clause and 1 in WHERE clause + assert len(lowercase_ranges) == 2, ( + f"Expected 2 occurrences of 'sushi.orders', found {len(lowercase_ranges)}" + ) + + # Click on the lowercase table reference + lowercase_from_range = None + for r in lowercase_ranges: + line_content = read_file[r.start.line].strip() + if "FROM" in line_content: + lowercase_from_range = r + break + + assert lowercase_from_range is not None, "Should find FROM clause with sushi.orders" + + lowercase_position = Position( + line=lowercase_from_range.start.line, character=lowercase_from_range.start.character + 5 + ) + + lowercase_refs = get_all_references( + lsp_context, URI.from_path(quoted_test_model_path), lowercase_position + ) + + # Should find only references to lowercase sushi.orders, NOT the uppercase ones + assert len(lowercase_refs) == 3, ( + f"Expected exactly 2 references for sushi.orders, found {len(lowercase_refs)}" + ) diff --git a/tests/lsp/test_reference_model_find_all.py b/tests/lsp/test_reference_model_find_all.py new file mode 100644 index 0000000000..cd9c0a3a1c --- /dev/null +++ b/tests/lsp/test_reference_model_find_all.py @@ -0,0 +1,312 @@ +from lsprotocol.types import Position +from sqlmesh.core.context import Context +from sqlmesh.lsp.context import LSPContext, ModelTarget, AuditTarget +from sqlmesh.lsp.reference import ( + get_model_find_all_references, + get_model_definitions_for_a_path, +) +from sqlmesh.lsp.uri import URI +from tests.lsp.test_reference_cte import find_ranges_from_regex + + +def 
test_find_references_for_model_usages(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # Find customers model which uses sushi.orders + customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + with open(customers_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + # Find sushi.orders reference + ranges = find_ranges_from_regex(read_file, r"sushi\.orders") + assert len(ranges) >= 1, "Should find at least one reference to sushi.orders" + + # Click on the model reference + position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 6) + references = get_model_find_all_references(lsp_context, URI.from_path(customers_path), position) + assert len(references) >= 6, ( + f"Expected at least 6 references to sushi.orders (including column prefix), found {len(references)}" + ) + + # Verify expected files are present + reference_files = {str(ref.path) for ref in references} + expected_patterns = [ + "orders", + "customers", + "customer_revenue_by_day", + "customer_revenue_lifetime", + "latest_order", + "waiter_revenue_by_day", + ] + for pattern in expected_patterns: + assert any(pattern in uri for uri in reference_files), ( + f"Missing reference in file containing '{pattern}'" + ) + + # Verify exact ranges for each reference pattern + # Note: customers file has multiple references due to column prefix support + expected_ranges = { + "orders": [(0, 0, 0, 0)], # the start for the model itself + "customers": [(30, 7, 30, 19)], # FROM clause + "waiter_revenue_by_day": [(19, 5, 19, 17)], + "customer_revenue_lifetime": [(38, 7, 38, 19)], + "customer_revenue_by_day": [(33, 5, 33, 17)], + "latest_order": [(12, 5, 12, 17)], + } + + # Group references by file pattern + refs_by_pattern = {} + for ref in references: + matched_pattern = None + for pattern in expected_patterns: + if pattern in str(ref.path): + 
matched_pattern = pattern + break + + if matched_pattern: + if matched_pattern not in refs_by_pattern: + refs_by_pattern[matched_pattern] = [] + refs_by_pattern[matched_pattern].append(ref) + + # Verify each pattern has the expected references + for pattern, expected_range_list in expected_ranges.items(): + assert pattern in refs_by_pattern, f"Missing references for pattern '{pattern}'" + + actual_refs = refs_by_pattern[pattern] + assert len(actual_refs) == len(expected_range_list), ( + f"Expected {len(expected_range_list)} references for {pattern}, found {len(actual_refs)}" + ) + + # Sort both actual and expected by line number for consistent comparison + actual_refs_sorted = sorted( + actual_refs, key=lambda r: (r.range.start.line, r.range.start.character) + ) + expected_sorted = sorted(expected_range_list, key=lambda r: (r[0], r[1])) + + for i, (ref, expected_range) in enumerate(zip(actual_refs_sorted, expected_sorted)): + expected_start_line, expected_start_char, expected_end_line, expected_end_char = ( + expected_range + ) + + assert ref.range.start.line == expected_start_line, ( + f"Expected {pattern} reference #{i + 1} start line {expected_start_line}, found {ref.range.start.line}" + ) + assert ref.range.start.character == expected_start_char, ( + f"Expected {pattern} reference #{i + 1} start character {expected_start_char}, found {ref.range.start.character}" + ) + assert ref.range.end.line == expected_end_line, ( + f"Expected {pattern} reference #{i + 1} end line {expected_end_line}, found {ref.range.end.line}" + ) + assert ref.range.end.character == expected_end_char, ( + f"Expected {pattern} reference #{i + 1} end character {expected_end_char}, found {ref.range.end.character}" + ) + + +def test_find_references_for_marketing_model(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in 
info.names + ) + + with open(customers_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + # Find sushi.marketing reference + marketing_ranges = find_ranges_from_regex(read_file, r"sushi\.marketing") + assert len(marketing_ranges) >= 1, "Should find at least one reference to sushi.marketing" + + position = Position( + line=marketing_ranges[0].start.line, character=marketing_ranges[0].start.character + 8 + ) + references = get_model_find_all_references(lsp_context, URI.from_path(customers_path), position) + + # sushi.marketing should have exactly 2 references: model itself + customers usage + assert len(references) == 2, ( + f"Expected exactly 2 references to sushi.marketing, found {len(references)}" + ) + + # Verify files are present + reference_files = {str(ref.path) for ref in references} + expected_patterns = ["marketing", "customers"] + for pattern in expected_patterns: + assert any(pattern in uri for uri in reference_files), ( + f"Missing reference in file containing '{pattern}'" + ) + + +def test_find_references_for_python_model(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + # Start from customer_revenue_by_day which references sushi.items + revenue_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customer_revenue_by_day" in info.names + ) + + with open(revenue_path, "r", encoding="utf-8") as file: + revenue_file = file.readlines() + + # Find sushi.items reference + items_ranges = find_ranges_from_regex(revenue_file, r"sushi\.items") + assert len(items_ranges) >= 1, "Should find at least one reference to sushi.items" + + position = Position( + line=items_ranges[0].start.line, character=items_ranges[0].start.character + 6 + ) + references = get_model_find_all_references(lsp_context, URI.from_path(revenue_path), position) + assert len(references) == 5 + + # Verify expected files + reference_files = {str(ref.path) for ref in references} + + 
# Models and also the Audit which references it: assert_item_price_above_zero + expected_patterns = [ + "items", + "customer_revenue_by_day", + "customer_revenue_lifetime", + "waiter_revenue_by_day", + "assert_item_price_above_zero", + ] + for pattern in expected_patterns: + assert any(pattern in uri for uri in reference_files), ( + f"Missing reference in file containing '{pattern}'" + ) + + +def test_waiter_revenue_by_day_multiple_references(): + # Test sushi.waiter_revenue_by_day which is referenced 3 times in top_waiters + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + top_waiters_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.top_waiters" in info.names + ) + + with open(top_waiters_path, "r", encoding="utf-8") as file: + top_waiters_file = file.readlines() + + # Find multiple references to sushi.waiter_revenue_by_day + waiter_revenue_ranges = find_ranges_from_regex( + top_waiters_file, r"sushi\.waiter_revenue_by_day" + ) + assert len(waiter_revenue_ranges) >= 2, ( + "Should find at least 2 references to sushi.waiter_revenue_by_day in top_waiters" + ) + + # Click on the first reference + position = Position( + line=waiter_revenue_ranges[0].start.line, + character=waiter_revenue_ranges[0].start.character + 10, + ) + references = get_model_find_all_references( + lsp_context, URI.from_path(top_waiters_path), position + ) + + # Should find model definition + 3 references in top_waiters = 4 total + assert len(references) == 4, ( + f"Expected exactly 4 references to sushi.waiter_revenue_by_day, found {len(references)}" + ) + + # Count references in top_waiters file + top_waiters_refs = [ref for ref in references if "top_waiters" in str(ref.path)] + assert len(top_waiters_refs) == 3, ( + f"Expected exactly 3 references in top_waiters, found {len(top_waiters_refs)}" + ) + + # Verify model definition is included + assert any("waiter_revenue_by_day" in str(ref.path) for ref 
in references), ( + "Should include model definition" + ) + + +def test_precise_character_positions(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + # Test clicking on different parts of "sushi.orders" reference at line 31 + + # Click on 's' in "sushi" - should work + position = Position(line=30, character=7) + references = get_model_find_all_references(lsp_context, URI.from_path(customers_path), position) + assert len(references) > 0, "Should find references when clicking on 's' in 'sushi'" + + # Click on '.' between sushi and orders - should work + position = Position(line=30, character=12) + references = get_model_find_all_references(lsp_context, URI.from_path(customers_path), position) + assert len(references) > 0, "Should find references when clicking on '.' separator" + + # Click on 'o' in "orders" - should work + position = Position(line=30, character=13) + references = get_model_find_all_references(lsp_context, URI.from_path(customers_path), position) + assert len(references) > 0, "Should find references when clicking on 'o' in 'orders'" + + # Click just before "sushi" - should not work + position = Position(line=30, character=6) + references = get_model_find_all_references(lsp_context, URI.from_path(customers_path), position) + assert len(references) == 0, "Should not find references when clicking just before 'sushi'" + + # Click just after "orders" - should not work + position = Position(line=30, character=21) + references = get_model_find_all_references(lsp_context, URI.from_path(customers_path), position) + assert len(references) == 0, "Should not find references when clicking just after 'orders'" + + +def test_audit_model_references(): + # Tests finding model references in audits + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + 
+ # Find audit files + audit_paths = [path for path, info in lsp_context.map.items() if isinstance(info, AuditTarget)] + + if audit_paths: + audit_path = audit_paths[0] + refs = get_model_definitions_for_a_path(lsp_context, URI.from_path(audit_path)) + + # Audits can reference models + if refs: + # Click on the first reference which is: sushi.items + first_ref = refs[0] + position = Position( + line=first_ref.range.start.line, character=first_ref.range.start.character + 1 + ) + references = get_model_find_all_references( + lsp_context, URI.from_path(audit_path), position + ) + + assert len(references) == 5, "Should find references from audit files as well" + + reference_files = {str(ref.path) for ref in references} + + # Models and also the Audit which references it: assert_item_price_above_zero + expected_patterns = [ + "items", + "customer_revenue_by_day", + "customer_revenue_lifetime", + "waiter_revenue_by_day", + "assert_item_price_above_zero", + ] + for pattern in expected_patterns: + assert any(pattern in uri for uri in reference_files), ( + f"Missing reference in file containing '{pattern}'" + ) diff --git a/tests/lsp/test_rename_cte.py b/tests/lsp/test_rename_cte.py new file mode 100644 index 0000000000..4ca1002c2e --- /dev/null +++ b/tests/lsp/test_rename_cte.py @@ -0,0 +1,212 @@ +from lsprotocol.types import Position +from sqlmesh.core.context import Context +from sqlmesh.lsp.context import LSPContext, ModelTarget +from sqlmesh.lsp.rename import prepare_rename, rename_symbol +from sqlmesh.lsp.uri import URI +from tests.lsp.test_reference_cte import find_ranges_from_regex + + +def test_prepare_rename_cte(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + sushi_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + with open(sushi_customers_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + # Test 
clicking on CTE definition for "current_marketing" + ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)") + assert len(ranges) == 2 + + # Click on the CTE definition + position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 4) + result = prepare_rename(lsp_context, URI.from_path(sushi_customers_path), position) + + assert result is not None + assert result.placeholder == "cte_name" + assert result.range == ranges[0] # Should return the definition range + + # Test clicking on CTE usage + position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4) + result = prepare_rename(lsp_context, URI.from_path(sushi_customers_path), position) + + assert result is not None + assert result.placeholder == "cte_name" + assert result.range == ranges[0] # Should still return the definition range + + +def test_prepare_rename_cte_outer(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + sushi_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + with open(sushi_customers_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + # Test clicking on CTE definition for "current_marketing_outer" + ranges = find_ranges_from_regex(read_file, r"current_marketing_outer") + assert len(ranges) == 2 + + # Click on the CTE definition + position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 4) + result = prepare_rename(lsp_context, URI.from_path(sushi_customers_path), position) + + assert result is not None + assert result.placeholder == "cte_name" + assert result.range == ranges[0] # Should return the definition range + + +def test_prepare_rename_non_cte(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + sushi_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, 
ModelTarget) and "sushi.customers" in info.names + ) + + with open(sushi_customers_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + # Click on a regular table reference (not a CTE) + ranges = find_ranges_from_regex(read_file, r"sushi\.orders") + assert len(ranges) >= 1 + + position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 4) + result = prepare_rename(lsp_context, URI.from_path(sushi_customers_path), position) + + assert result is None + + +def test_rename_cte(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + sushi_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + with open(sushi_customers_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + # Test renaming "current_marketing" to "new_marketing" + ranges = find_ranges_from_regex(read_file, r"current_marketing(?!_outer)") + assert len(ranges) == 2 + + # Click on the CTE definition + position = Position(line=ranges[0].start.line, character=ranges[0].start.character + 4) + workspace_edit = rename_symbol( + lsp_context, URI.from_path(sushi_customers_path), position, "new_marketing" + ) + + assert workspace_edit is not None + assert workspace_edit.changes is not None + + uri = URI.from_path(sushi_customers_path).value + assert uri in workspace_edit.changes + + edits = workspace_edit.changes[uri] + + # Should have edited four occurences including column usages + assert len(edits) == 4 + + # Verify that both ranges are being edited + edit_ranges = [edit.range for edit in edits] + for expected_range in ranges: + assert any( + edit_range.start.line == expected_range.start.line + and edit_range.start.character == expected_range.start.character + for edit_range in edit_ranges + ), ( + f"Expected to find edit at line {expected_range.start.line}, char {expected_range.start.character}" + ) + + # Verify that all 
edits have the new name + assert all(edit.new_text == "new_marketing" for edit in edits) + + # Apply the edits to verify the result + with open(sushi_customers_path, "r", encoding="utf-8") as file: + lines = file.readlines() + + # Apply edits in reverse order to avoid offset issues + sorted_edits = sorted( + edits, key=lambda e: (e.range.start.line, e.range.start.character), reverse=True + ) + for edit in sorted_edits: + line_idx = edit.range.start.line + start_char = edit.range.start.character + end_char = edit.range.end.character + + line = lines[line_idx] + new_line = line[:start_char] + edit.new_text + line[end_char:] + lines[line_idx] = new_line + + # Verify the edited content + edited_content = "".join(lines) + assert "new_marketing" in edited_content + assert "current_marketing" not in edited_content.replace("current_marketing_outer", "") + assert edited_content.count("new_marketing") == 4 + assert ( + " SELECT new_marketing.* FROM new_marketing WHERE new_marketing.customer_id != 100\n" + in lines + ) + + +def test_rename_cte_outer(): + context = Context(paths=["examples/sushi"]) + lsp_context = LSPContext(context) + + sushi_customers_path = next( + path + for path, info in lsp_context.map.items() + if isinstance(info, ModelTarget) and "sushi.customers" in info.names + ) + + with open(sushi_customers_path, "r", encoding="utf-8") as file: + read_file = file.readlines() + + # Test renaming "current_marketing_outer" to "new_marketing_outer" + ranges = find_ranges_from_regex(read_file, r"current_marketing_outer") + assert len(ranges) == 2 + + # Click on the CTE usage + position = Position(line=ranges[1].start.line, character=ranges[1].start.character + 4) + workspace_edit = rename_symbol( + lsp_context, URI.from_path(sushi_customers_path), position, "new_marketing_outer" + ) + + assert workspace_edit is not None + assert workspace_edit.changes is not None + + uri = URI.from_path(sushi_customers_path).value + assert uri in workspace_edit.changes + + edits = 
workspace_edit.changes[uri] + assert len(edits) == 2 # Should have 2 edits: definition + usage + + # Verify that both ranges are being edited + edit_ranges = [edit.range for edit in edits] + for expected_range in ranges: + assert any( + edit_range.start.line == expected_range.start.line + and edit_range.start.character == expected_range.start.character + for edit_range in edit_ranges + ), ( + f"Expected to find edit at line {expected_range.start.line}, char {expected_range.start.character}" + ) + + # Verify that all edits have the new name + assert all(edit.new_text == "new_marketing_outer" for edit in edits) diff --git a/tests/pyproject.toml b/tests/pyproject.toml new file mode 100644 index 0000000000..73f143bfde --- /dev/null +++ b/tests/pyproject.toml @@ -0,0 +1,30 @@ +[project] +name = "sqlmesh-tests" +dynamic = ["version", "dependencies"] +description = "Tests for SQLMesh" +authors = [{ name = "TobikoData Inc.", email = "engineering@tobikodata.com" }] +license = { text = "Apache License 2.0" } + +[project.urls] +Homepage = "https://sqlmesh.com/" +Documentation = "https://sqlmesh.readthedocs.io/en/stable/" +Repository = "https://github.com/SQLMesh/sqlmesh" +Issues = "https://github.com/SQLMesh/sqlmesh/issues" + +[build-system] +requires = ["setuptools", "setuptools_scm", "toml"] +build-backend = "setuptools.build_meta" + +[tool.setuptools.package-dir] +sqlmesh_tests = "" + +[tool.setuptools_scm] +root = "../" +version_file = "_version.py" +fallback_version = "0.0.0" +local_scheme = "no-local-version" + +[tool.setuptools.package-data] +"*" = ["fixtures/**", "*.toml"] + + diff --git a/tests/schedulers/airflow/conftest.py b/tests/schedulers/airflow/conftest.py deleted file mode 100644 index 789e543247..0000000000 --- a/tests/schedulers/airflow/conftest.py +++ /dev/null @@ -1,49 +0,0 @@ -import logging -import os - -import pytest -from tenacity import retry, stop_after_attempt, wait_fixed - -from sqlmesh.core.config import AirflowSchedulerConfig -from 
sqlmesh.schedulers.airflow.client import AirflowClient -from sqlmesh.utils import str_to_bool - -logger = logging.getLogger(__name__) - - -@pytest.fixture(scope="session") -def is_docker() -> bool: - return str_to_bool(os.environ.get("IS_DOCKER")) - - -@pytest.fixture(scope="session") -def airflow_host(is_docker: bool) -> str: - return "airflow-webserver" if is_docker else "localhost" - - -@pytest.fixture(scope="session") -def airflow_scheduler_backend(airflow_host: str) -> AirflowSchedulerConfig: - return _get_airflow_scheduler_backend(airflow_host) - - -@pytest.fixture(scope="session") -def airflow_client(airflow_scheduler_backend: AirflowSchedulerConfig) -> AirflowClient: - return airflow_scheduler_backend.get_client() - - -@retry(wait=wait_fixed(3), stop=stop_after_attempt(10), reraise=True) -def _get_airflow_scheduler_backend(airflow_host: str) -> AirflowSchedulerConfig: - backend = AirflowSchedulerConfig(airflow_url=f"http://{airflow_host}:8080/") - client = backend.get_client() - - try: - client.get_all_dags() - except Exception: - logger.info( - "Failed to fetch the list of DAGs from Airflow. 
Make sure the test Airflow cluster is running" - ) - raise - - logger.info("The Airflow Client is ready") - - return backend diff --git a/tests/schedulers/airflow/operators/fixtures.py b/tests/schedulers/airflow/operators/fixtures.py deleted file mode 100644 index d77ac00267..0000000000 --- a/tests/schedulers/airflow/operators/fixtures.py +++ /dev/null @@ -1,10 +0,0 @@ -import os -from unittest import mock - -import pytest - - -@pytest.fixture -def set_airflow_as_library(): - with mock.patch.dict(os.environ, {"_AIRFLOW__AS_LIBRARY": "1"}): - yield diff --git a/tests/schedulers/airflow/operators/test_sensor.py b/tests/schedulers/airflow/operators/test_sensor.py deleted file mode 100644 index 7bb4a2eb68..0000000000 --- a/tests/schedulers/airflow/operators/test_sensor.py +++ /dev/null @@ -1,165 +0,0 @@ -from unittest.mock import call - -import pytest -from airflow.utils.context import Context -from pytest_mock.plugin import MockerFixture - -from sqlmesh.core.dialect import parse_one -from sqlmesh.core.model import SqlModel -from sqlmesh.core.snapshot import SnapshotChangeCategory -from sqlmesh.schedulers.airflow.operators.sensor import ( - ExternalSensor, - HighWaterMarkSensor, -) -from sqlmesh.utils.date import to_datetime - -pytest_plugins = ["tests.schedulers.airflow.operators.fixtures"] -pytestmark = pytest.mark.airflow - - -def test_no_current_hwm(mocker: MockerFixture, make_snapshot, random_name, set_airflow_as_library): - this_snapshot = make_snapshot(SqlModel(name="this", query=parse_one("select 1, ds"))) - this_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - - target_snapshot = make_snapshot(SqlModel(name="target", query=parse_one("select 2, ds"))) - target_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - - task = HighWaterMarkSensor( - target_snapshot_info=target_snapshot.table_info, - this_snapshot=this_snapshot, - task_id="test_hwm_task", - ) - - get_snapshots_mock = mocker.patch( - 
"sqlmesh.core.state_sync.cache.CachingStateSync.get_snapshots" - ) - get_snapshots_mock.return_value = {target_snapshot.snapshot_id: target_snapshot} - - dag_run_mock = mocker.Mock() - dag_run_mock.data_interval_end = to_datetime("2022-01-01") - - context = Context(dag_run=dag_run_mock) # type: ignore - assert not task.poke(context) - - get_snapshots_mock.assert_called_once_with([target_snapshot.table_info]) - - -def test_current_hwm_below_target(mocker: MockerFixture, make_snapshot, set_airflow_as_library): - this_snapshot = make_snapshot( - SqlModel(name="this", query=parse_one("select 1, ds")), version="a" - ) - this_snapshot.change_category = SnapshotChangeCategory.BREAKING - - target_snapshot_v1 = make_snapshot( - SqlModel(name="that", query=parse_one("select 2, ds")), version="b" - ) - target_snapshot_v1.change_category = SnapshotChangeCategory.BREAKING - - target_snapshot_v2 = make_snapshot( - SqlModel(name="that", query=parse_one("select 3, ds")), version="b" - ) - target_snapshot_v2.change_category = SnapshotChangeCategory.FORWARD_ONLY - - target_snapshot_v2.add_interval("2022-01-01", "2022-01-01") - - task = HighWaterMarkSensor( - target_snapshot_info=target_snapshot_v1.table_info, - this_snapshot=this_snapshot, - task_id="test_hwm_task", - ) - - get_snapshots_mock = mocker.patch( - "sqlmesh.core.state_sync.cache.CachingStateSync.get_snapshots" - ) - get_snapshots_mock.return_value = {target_snapshot_v1.snapshot_id: target_snapshot_v1} - - dag_run_mock = mocker.Mock() - dag_run_mock.data_interval_end = to_datetime("2022-01-03") - - context = Context(dag_run=dag_run_mock) # type: ignore - - assert not task.poke(context) - - get_snapshots_mock.assert_called_once_with([target_snapshot_v1.table_info]) - - -def test_current_hwm_above_target(mocker: MockerFixture, make_snapshot, set_airflow_as_library): - this_snapshot = make_snapshot( - SqlModel(name="this", query=parse_one("select 1, ds")), version="a" - ) - this_snapshot.change_category = 
SnapshotChangeCategory.BREAKING - - target_snapshot_v1 = make_snapshot( - SqlModel(name="that", query=parse_one("select 2, ds")), version="b" - ) - target_snapshot_v1.change_category = SnapshotChangeCategory.BREAKING - target_snapshot_v1.add_interval("2022-01-01", "2022-01-02") - - task = HighWaterMarkSensor( - target_snapshot_info=target_snapshot_v1.table_info, - this_snapshot=this_snapshot, - task_id="test_hwm_task", - ) - - get_snapshots_mock = mocker.patch( - "sqlmesh.core.state_sync.cache.CachingStateSync.get_snapshots" - ) - get_snapshots_mock.return_value = {target_snapshot_v1.snapshot_id: target_snapshot_v1} - - dag_run_mock = mocker.Mock() - dag_run_mock.data_interval_end = to_datetime("2022-01-03") - - context = Context(dag_run=dag_run_mock) # type: ignore - - assert task.poke(context) - - get_snapshots_mock.assert_called_once_with([target_snapshot_v1.table_info]) - - -def test_external_sensor(mocker: MockerFixture, make_snapshot, set_airflow_as_library): - snapshot = make_snapshot( - SqlModel( - name="this", - query=parse_one("select 1"), - signals=[ - {"table_name": "test_table_name_a", "ds": parse_one("@end_ds")}, - { - "table_name": "test_table_name_b", - "ds": parse_one("@end_ds"), - "hour": parse_one("@end_hour"), - }, - ], - ) - ) - - external_sensor_mock = mocker.Mock() - external_sensor_mock.poke.return_value = True - - factory_mock = mocker.Mock() - factory_mock.return_value = external_sensor_mock - - dag_run_mock = mocker.Mock() - dag_run_mock.data_interval_start = to_datetime("2023-01-01") - dag_run_mock.data_interval_end = to_datetime("2023-01-02") - - context = Context(dag_run=dag_run_mock) # type: ignore - - task = ExternalSensor( - snapshot=snapshot, - external_table_sensor_factory=factory_mock, - task_id="test_hwm_task", - ) - assert task.poke(context) - - factory_mock.assert_has_calls( - [ - call({"table_name": "test_table_name_a", "ds": "2023-01-01"}), - call({"table_name": "test_table_name_b", "ds": "2023-01-01", "hour": 23}), - ] - ) 
- external_sensor_mock.poke.assert_has_calls( - [ - call(context), - call(context), - ] - ) diff --git a/tests/schedulers/airflow/operators/test_targets.py b/tests/schedulers/airflow/operators/test_targets.py deleted file mode 100644 index 2c5436159d..0000000000 --- a/tests/schedulers/airflow/operators/test_targets.py +++ /dev/null @@ -1,216 +0,0 @@ -import typing as t -from unittest.mock import call - -import pytest -from airflow.exceptions import AirflowSkipException -from airflow.utils.context import Context -from pytest_mock.plugin import MockerFixture -from sqlglot import parse_one - -from sqlmesh.core.environment import Environment -from sqlmesh.core.model import Model, Seed, SeedKind, SeedModel, SqlModel -from sqlmesh.core.snapshot import ( - DeployabilityIndex, - SnapshotChangeCategory, - SnapshotTableCleanupTask, -) -from sqlmesh.engines import commands -from sqlmesh.schedulers.airflow.operators import targets -from sqlmesh.utils.date import to_datetime - -pytest_plugins = ["tests.schedulers.airflow.operators.fixtures"] -pytestmark = pytest.mark.airflow - - -@pytest.fixture -def model() -> Model: - return SqlModel( - name="test_model", - query=parse_one("SELECT a, ds FROM tbl"), - ) - - -def test_evaluation_target_execute( - mocker: MockerFixture, make_snapshot: t.Callable, model: Model, set_airflow_as_library -): - interval_ds = to_datetime("2022-01-01") - logical_ds = to_datetime("2022-01-02") - - dag_run_mock = mocker.Mock() - dag_run_mock.data_interval_start = interval_ds - dag_run_mock.data_interval_end = interval_ds - dag_run_mock.logical_date = logical_ds - - context = Context(dag_run=dag_run_mock) # type: ignore - - evaluator_evaluate_mock = mocker.patch( - "sqlmesh.core.snapshot.evaluator.SnapshotEvaluator.evaluate" - ) - evaluator_evaluate_mock.return_value = None - - add_interval_mock = mocker.patch("sqlmesh.core.state_sync.cache.CachingStateSync.add_interval") - - variable_get_mock = 
mocker.patch("sqlmesh.schedulers.airflow.operators.targets.Variable.get") - - variable_get_mock.return_value = "default_catalog" - - snapshot = make_snapshot(model) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - parent_snapshots = {snapshot.name: snapshot} - - deployability_index = DeployabilityIndex.all_deployable() - - target = targets.SnapshotEvaluationTarget( - snapshot=snapshot, - parent_snapshots=parent_snapshots, - deployability_index=deployability_index, - ) - target.execute(context, lambda: mocker.Mock(), "spark") - - add_interval_mock.assert_called_once_with(snapshot, interval_ds, interval_ds, is_dev=False) - - evaluator_evaluate_mock.assert_called_once_with( - snapshot, - start=interval_ds, - end=interval_ds, - execution_time=logical_ds, - snapshots=parent_snapshots, - deployability_index=deployability_index, - batch_index=0, - ) - - -def test_evaluation_target_execute_seed_model( - mocker: MockerFixture, make_snapshot: t.Callable, set_airflow_as_library -): - interval_ds = to_datetime("2022-01-01") - logical_ds = to_datetime("2022-01-02") - - dag_run_mock = mocker.Mock() - dag_run_mock.data_interval_start = interval_ds - dag_run_mock.data_interval_end = interval_ds - dag_run_mock.logical_date = logical_ds - - variable_get_mock = mocker.patch("sqlmesh.schedulers.airflow.operators.targets.Variable.get") - - variable_get_mock.return_value = "default_catalog" - - context = Context(dag_run=dag_run_mock) # type: ignore - - snapshot = make_snapshot( - SeedModel( - name="a", - kind=SeedKind(path="./path/to/seed"), - seed=Seed(content="content"), - column_hashes={"col": "hash1"}, - depends_on=set(), - ).to_dehydrated() - ) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - - evaluator_evaluate_mock = mocker.patch( - "sqlmesh.core.snapshot.evaluator.SnapshotEvaluator.evaluate" - ) - evaluator_evaluate_mock.return_value = None - - add_interval_mock = mocker.patch("sqlmesh.core.state_sync.cache.CachingStateSync.add_interval") - - 
get_snapshots_mock = mocker.patch( - "sqlmesh.core.state_sync.cache.CachingStateSync.get_snapshots" - ) - get_snapshots_mock.return_value = {snapshot.snapshot_id: snapshot} - - deployability_index = DeployabilityIndex.all_deployable() - - target = targets.SnapshotEvaluationTarget( - snapshot=snapshot, parent_snapshots={}, deployability_index=deployability_index - ) - target.execute(context, lambda: mocker.Mock(), "spark") - - add_interval_mock.assert_called_once_with(snapshot, interval_ds, interval_ds, is_dev=False) - - evaluator_evaluate_mock.assert_called_once_with( - snapshot, - start=interval_ds, - end=interval_ds, - execution_time=logical_ds, - snapshots={snapshot.name: snapshot}, - deployability_index=deployability_index, - batch_index=0, - ) - - -def test_cleanup_target_execute( - mocker: MockerFixture, make_snapshot: t.Callable, model: Model, set_airflow_as_library -): - snapshot = make_snapshot(model) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - - environment = Environment( - name="test_env", snapshots=[snapshot.table_info], start_at="", plan_id="test_plan_id" - ) - - cleanup_task = SnapshotTableCleanupTask(snapshot=snapshot.table_info, dev_table_only=False) - - command = commands.CleanupCommandPayload( - environments=[environment], - tasks=[cleanup_task], - ) - - task_instance_mock = mocker.Mock() - task_instance_mock.xcom_pull.return_value = command.json() - - variable_get_mock = mocker.patch("sqlmesh.schedulers.airflow.operators.targets.Variable.get") - - variable_get_mock.return_value = "default_catalog" - - context = Context(ti=task_instance_mock) # type: ignore - - evaluator_cleanup_mock = mocker.patch( - "sqlmesh.core.snapshot.evaluator.SnapshotEvaluator.cleanup" - ) - - delete_xcom_mock = mocker.patch("sqlmesh.schedulers.airflow.operators.targets._delete_xcom") - - target = targets.SnapshotCleanupTarget() - - evaluator_adapter_mock = mocker.MagicMock() - target.execute(context, lambda: evaluator_adapter_mock, "spark") - - 
evaluator_adapter_mock.cursor().execute.assert_has_calls( - [call("DROP SCHEMA IF EXISTS `default__test_env` CASCADE")] - ) - evaluator_cleanup_mock.assert_called_once_with([cleanup_task]) - - task_instance_mock.xcom_pull.assert_called_once_with(key="snapshot_cleanup_command") - - delete_xcom_mock.assert_called_once() - - -def test_cleanup_target_skip_execution( - mocker: MockerFixture, make_snapshot: t.Callable, model: Model, set_airflow_as_library -): - snapshot = make_snapshot(model) - snapshot.version = "test_version" - - task_instance_mock = mocker.Mock() - task_instance_mock.xcom_pull.return_value = commands.CleanupCommandPayload( - tasks=[], environments=[] - ).json() - - context = Context(ti=task_instance_mock) # type: ignore - - evaluator_demote_mock = mocker.patch("sqlmesh.core.snapshot.evaluator.SnapshotEvaluator.demote") - evaluator_cleanup_mock = mocker.patch( - "sqlmesh.core.snapshot.evaluator.SnapshotEvaluator.cleanup" - ) - - delete_xcom_mock = mocker.patch("sqlmesh.schedulers.airflow.operators.targets._delete_xcom") - - target = targets.SnapshotCleanupTarget() - with pytest.raises(AirflowSkipException): - target.execute(context, lambda: mocker.Mock(), "spark") - - evaluator_demote_mock.assert_not_called() - evaluator_cleanup_mock.assert_not_called() - delete_xcom_mock.assert_called_once() diff --git a/tests/schedulers/airflow/test_client.py b/tests/schedulers/airflow/test_client.py deleted file mode 100644 index 14e025ea22..0000000000 --- a/tests/schedulers/airflow/test_client.py +++ /dev/null @@ -1,484 +0,0 @@ -import json -from unittest.mock import call -from urllib.parse import urlencode - -import pytest -import requests -from pytest_mock.plugin import MockerFixture -from sqlglot import parse_one - -from sqlmesh.core.config import EnvironmentSuffixTarget -from sqlmesh.core.environment import Environment -from sqlmesh.core.model import IncrementalByTimeRangeKind, SqlModel -from sqlmesh.core.node import NodeType -from sqlmesh.core.snapshot import 
Snapshot, SnapshotChangeCategory, SnapshotId -from sqlmesh.schedulers.airflow import common -from sqlmesh.schedulers.airflow.client import AirflowClient, _list_to_json -from sqlmesh.utils.date import to_timestamp - -pytestmark = pytest.mark.airflow - - -@pytest.fixture -def snapshot() -> Snapshot: - snapshot = Snapshot.from_node( - SqlModel( - name="test_model", - kind=IncrementalByTimeRangeKind(time_column="ds", dialect="spark"), - storage_format="parquet", - partitioned_by=["a"], - query=parse_one("SELECT a, ds FROM tbl"), - pre_statements=[ - parse_one("@DEF(key, 'value')"), - ], - dialect="spark", - ), - nodes={}, - ttl="in 1 week", - ) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot.updated_ts = 1665014400000 - snapshot.created_ts = 1665014400000 - return snapshot - - -def test_apply_plan(mocker: MockerFixture, snapshot: Snapshot): - apply_plan_response_mock = mocker.Mock() - apply_plan_response_mock.json.return_value = {"request_id": "test_request_id"} - apply_plan_response_mock.status_code = 200 - apply_plan_mock = mocker.patch("requests.Session.post") - apply_plan_mock.return_value = apply_plan_response_mock - - environment = Environment( - name="test_env", - snapshots=[snapshot.table_info], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="test_plan_id", - previous_plan_id="previous_plan_id", - promoted_snapshot_ids=[snapshot.snapshot_id], - ) - - request_id = "test_request_id" - - client = AirflowClient(airflow_url=common.AIRFLOW_LOCAL_URL, session=requests.Session()) - client.apply_plan( - [snapshot], - environment, - request_id, - models_to_backfill={'"test_model"'}, - directly_modified_snapshots=[snapshot.snapshot_id], - restatements={ - snapshot.snapshot_id: (to_timestamp("2024-01-01"), to_timestamp("2024-01-02")) - }, - ) - - apply_plan_mock.assert_called_once() - args, data = apply_plan_mock.call_args_list[0] - - assert args[0] == "http://localhost:8080/sqlmesh/api/v1/plans" - assert data["headers"] == {"Content-Type": 
"application/json"} - assert json.loads(data["data"]) == { - "new_snapshots": [ - { - "created_ts": 1665014400000, - "ttl": "in 1 week", - "fingerprint": snapshot.fingerprint.dict(), - "intervals": [], - "dev_intervals": [], - "node": { - "audits": [], - "clustered_by": [], - "cron": "@daily", - "dialect": "spark", - "pre_statements": ["@DEF(key, " "'value')"], - "kind": { - "name": "INCREMENTAL_BY_TIME_RANGE", - "time_column": {"column": "`ds`"}, - "forward_only": False, - "on_destructive_change": "ERROR", - "disable_restatement": False, - "dialect": "spark", - }, - "mapping_schema": {}, - "inline_audits": {}, - "name": "test_model", - "partitioned_by": ["`a`"], - "query": "SELECT a, ds FROM tbl", - "references": [], - "project": "", - "storage_format": "parquet", - "jinja_macros": { - "create_builtins_module": "sqlmesh.utils.jinja", - "global_objs": {}, - "packages": {}, - "root_macros": {}, - "top_level_packages": [], - }, - "source_type": "sql", - "tags": [], - "grains": [], - "allow_partials": False, - "signals": [], - "enabled": True, - }, - "audits": [], - "name": '"test_model"', - "parents": [], - "previous_versions": [], - "updated_ts": 1665014400000, - "version": snapshot.version, - "change_category": snapshot.change_category, - "migrated": False, - "unrestorable": False, - } - ], - "environment": { - "name": "test_env", - "snapshots": [ - { - "fingerprint": snapshot.fingerprint.dict(), - "name": '"test_model"', - "node_type": NodeType.MODEL, - "previous_versions": [], - "version": snapshot.version, - "physical_schema": "sqlmesh__default", - "change_category": snapshot.change_category, - "parents": [], - "kind_name": "INCREMENTAL_BY_TIME_RANGE", - } - ], - "start_at": "2022-01-01", - "end_at": "2022-01-01", - "plan_id": "test_plan_id", - "previous_plan_id": "previous_plan_id", - "promoted_snapshot_ids": [ - { - "name": '"test_model"', - "identifier": snapshot.identifier, - } - ], - "suffix_target": "schema", - "normalize_name": True, - }, - "no_gaps": 
False, - "skip_backfill": False, - "notification_targets": [], - "request_id": request_id, - "backfill_concurrent_tasks": 1, - "ddl_concurrent_tasks": 1, - "users": [], - "is_dev": False, - "forward_only": False, - "allow_destructive_snapshots": [], - "models_to_backfill": ['"test_model"'], - "end_bounded": False, - "ensure_finalized_snapshots": False, - "directly_modified_snapshots": [{"identifier": "844700562", "name": '"test_model"'}], - "indirectly_modified_snapshots": {}, - "removed_snapshots": [], - "restatements": {'"test_model"': [to_timestamp("2024-01-01"), to_timestamp("2024-01-02")]}, - } - - common.PlanApplicationRequest.parse_raw(data["data"]) - - -def snapshot_url(snapshot_ids, key="ids") -> str: - return urlencode({key: _list_to_json(snapshot_ids)[0]}) - - -def test_get_snapshots(mocker: MockerFixture, snapshot: Snapshot): - snapshots = common.SnapshotsResponse(snapshots=[snapshot]) - - get_snapshots_response_mock = mocker.Mock() - get_snapshots_response_mock.status_code = 200 - get_snapshots_response_mock.json.return_value = snapshots.dict() - get_snapshots_mock = mocker.patch("requests.Session.get") - get_snapshots_mock.return_value = get_snapshots_response_mock - - client = AirflowClient(airflow_url=common.AIRFLOW_LOCAL_URL, session=requests.Session()) - result = client.get_snapshots([snapshot.snapshot_id]) - - assert result == [snapshot] - - get_snapshots_mock.assert_called_once_with( - f"http://localhost:8080/sqlmesh/api/v1/snapshots?{snapshot_url([snapshot.snapshot_id])}" - ) - - -def test_get_snapshots_batching(mocker: MockerFixture, snapshot: Snapshot): - snapshots = common.SnapshotsResponse(snapshots=[snapshot]) - - get_snapshots_response_mock = mocker.Mock() - get_snapshots_response_mock.status_code = 200 - get_snapshots_response_mock.json.return_value = snapshots.dict() - get_snapshots_mock = mocker.patch("requests.Session.get") - get_snapshots_mock.return_value = get_snapshots_response_mock - - snapshot_ids_batch_size = 40 - 
first_batch_ids = [ - SnapshotId(name=snapshot.name, identifier=str(i)) for i in range(snapshot_ids_batch_size) - ] - - client = AirflowClient( - airflow_url=common.AIRFLOW_LOCAL_URL, - session=requests.Session(), - snapshot_ids_batch_size=snapshot_ids_batch_size, - ) - result = client.get_snapshots([*first_batch_ids, snapshot.snapshot_id]) - - assert result == [snapshot] * 2 - - get_snapshots_mock.assert_has_calls( - [ - call(f"http://localhost:8080/sqlmesh/api/v1/snapshots?{snapshot_url(first_batch_ids)}"), - call().json(), - call( - f"http://localhost:8080/sqlmesh/api/v1/snapshots?{snapshot_url([snapshot.snapshot_id])}" - ), - call().json(), - ] - ) - - -def test_snapshots_exist(mocker: MockerFixture, snapshot: Snapshot): - snapshot_ids = common.SnapshotIdsResponse(snapshot_ids=[snapshot.snapshot_id]) - - snapshots_exist_response_mock = mocker.Mock() - snapshots_exist_response_mock.status_code = 200 - snapshots_exist_response_mock.json.return_value = snapshot_ids.dict() - snapshots_exist_mock = mocker.patch("requests.Session.get") - snapshots_exist_mock.return_value = snapshots_exist_response_mock - - client = AirflowClient(airflow_url=common.AIRFLOW_LOCAL_URL, session=requests.Session()) - result = client.snapshots_exist([snapshot.snapshot_id]) - - assert result == {snapshot.snapshot_id} - - snapshots_exist_mock.assert_called_once_with( - f"http://localhost:8080/sqlmesh/api/v1/snapshots?check_existence&{snapshot_url([snapshot.snapshot_id])}" - ) - - -def test_snapshots_exist_batching(mocker: MockerFixture, snapshot: Snapshot): - snapshot_ids = common.SnapshotIdsResponse(snapshot_ids=[snapshot.snapshot_id]) - - snapshots_exist_response_mock = mocker.Mock() - snapshots_exist_response_mock.status_code = 200 - snapshots_exist_response_mock.json.return_value = snapshot_ids.dict() - snapshots_exist_mock = mocker.patch("requests.Session.get") - snapshots_exist_mock.return_value = snapshots_exist_response_mock - - snapshot_ids_batch_size = 40 - first_batch_ids = [ - 
SnapshotId(name=snapshot.name, identifier=str(i)) for i in range(snapshot_ids_batch_size) - ] - - client = AirflowClient( - airflow_url=common.AIRFLOW_LOCAL_URL, - session=requests.Session(), - snapshot_ids_batch_size=snapshot_ids_batch_size, - ) - result = client.snapshots_exist([*first_batch_ids, snapshot.snapshot_id]) - - assert result == {snapshot.snapshot_id} - - snapshots_exist_mock.assert_has_calls( - [ - call( - f"http://localhost:8080/sqlmesh/api/v1/snapshots?check_existence&{snapshot_url(first_batch_ids)}" - ), - call().json(), - call( - f"http://localhost:8080/sqlmesh/api/v1/snapshots?check_existence&{snapshot_url([snapshot.snapshot_id])}" - ), - call().json(), - ] - ) - - -def test_models_exist(mocker: MockerFixture, snapshot: Snapshot): - model_names = ["model_a", "model_b"] - - models_exist_response_mock = mocker.Mock() - models_exist_response_mock.status_code = 200 - models_exist_response_mock.json.return_value = common.ExistingModelsResponse( - names=model_names - ).dict() - models_exist_mock = mocker.patch("requests.Session.get") - models_exist_mock.return_value = models_exist_response_mock - - client = AirflowClient(airflow_url=common.AIRFLOW_LOCAL_URL, session=requests.Session()) - result = client.nodes_exist(model_names, exclude_external=True) - - assert result == set(model_names) - - models_exist_mock.assert_called_once_with( - "http://localhost:8080/sqlmesh/api/v1/models?exclude_external&names=model_a%2Cmodel_b" - ) - - -def test_get_environment(mocker: MockerFixture, snapshot: Snapshot): - environment = Environment( - name="test", - snapshots=[snapshot.table_info], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="test_plan_id", - previous_plan_id=None, - suffix_target=EnvironmentSuffixTarget.TABLE, - ) - - get_environment_response_mock = mocker.Mock() - get_environment_response_mock.status_code = 200 - get_environment_response_mock.json.return_value = environment.dict() - get_environment_mock = mocker.patch("requests.Session.get") - 
get_environment_mock.return_value = get_environment_response_mock - - client = AirflowClient(airflow_url=common.AIRFLOW_LOCAL_URL, session=requests.Session()) - result = client.get_environment("dev") - - assert result == environment - - get_environment_mock.assert_called_once_with( - "http://localhost:8080/sqlmesh/api/v1/environments/dev" - ) - - -def test_get_environments(mocker: MockerFixture, snapshot: Snapshot): - environment = Environment( - name="test", - snapshots=[snapshot.table_info], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="test_plan_id", - previous_plan_id=None, - ) - environments = common.EnvironmentsResponse(environments=[environment]) - - get_environments_response_mock = mocker.Mock() - get_environments_response_mock.status_code = 200 - get_environments_response_mock.json.return_value = environments.dict() - get_environments_mock = mocker.patch("requests.Session.get") - get_environments_mock.return_value = get_environments_response_mock - - client = AirflowClient(airflow_url=common.AIRFLOW_LOCAL_URL, session=requests.Session()) - result = client.get_environments() - - assert result == [environment] - - get_environments_mock.assert_called_once_with( - "http://localhost:8080/sqlmesh/api/v1/environments" - ) - - -@pytest.mark.parametrize("ensure_finalized_snapshots", [True, False]) -def test_max_interval_end_for_environment( - mocker: MockerFixture, snapshot: Snapshot, ensure_finalized_snapshots: bool -): - response = common.IntervalEndResponse( - environment="test_environment", max_interval_end=to_timestamp("2023-01-01") - ) - - max_interval_end_response_mock = mocker.Mock() - max_interval_end_response_mock.status_code = 200 - max_interval_end_response_mock.json.return_value = response.dict() - max_interval_end_mock = mocker.patch("requests.Session.get") - max_interval_end_mock.return_value = max_interval_end_response_mock - - client = AirflowClient(airflow_url=common.AIRFLOW_LOCAL_URL, session=requests.Session()) - result = 
client.max_interval_end_for_environment("test_environment", ensure_finalized_snapshots) - - assert result == response.max_interval_end - - flags = "?ensure_finalized_snapshots" if ensure_finalized_snapshots else "" - max_interval_end_mock.assert_called_once_with( - f"http://localhost:8080/sqlmesh/api/v1/environments/test_environment/max_interval_end{flags}" - ) - - -@pytest.mark.parametrize("ensure_finalized_snapshots", [True, False]) -def test_greatest_common_interval_end( - mocker: MockerFixture, snapshot: Snapshot, ensure_finalized_snapshots: bool -): - response = common.IntervalEndResponse( - environment="test_environment", max_interval_end=to_timestamp("2023-01-01") - ) - - max_interval_end_response_mock = mocker.Mock() - max_interval_end_response_mock.status_code = 200 - max_interval_end_response_mock.json.return_value = response.dict() - max_interval_end_mock = mocker.patch("requests.Session.get") - max_interval_end_mock.return_value = max_interval_end_response_mock - - client = AirflowClient(airflow_url=common.AIRFLOW_LOCAL_URL, session=requests.Session()) - result = client.greatest_common_interval_end( - "test_environment", {"a.b.c"}, ensure_finalized_snapshots - ) - - assert result == response.max_interval_end - - flags = "ensure_finalized_snapshots&" if ensure_finalized_snapshots else "" - max_interval_end_mock.assert_called_once_with( - f"http://localhost:8080/sqlmesh/api/v1/environments/test_environment/greatest_common_interval_end?{flags}models=%5B%22a.b.c%22%5D" - ) - - -def test_get_dag_run_state(mocker: MockerFixture): - get_dag_run_state_mock = mocker.Mock() - get_dag_run_state_mock.status_code = 200 - get_dag_run_state_mock.json.return_value = {"state": "failed"} - get_snapshot_mock = mocker.patch("requests.Session.get") - get_snapshot_mock.return_value = get_dag_run_state_mock - - client = AirflowClient(airflow_url=common.AIRFLOW_LOCAL_URL, session=requests.Session()) - result = client.get_dag_run_state("test_dag_id", "test_dag_run_id") - - 
assert result == "failed" - - get_snapshot_mock.assert_called_once_with( - "http://localhost:8080/api/v1/dags/test_dag_id/dagRuns/test_dag_run_id" - ) - - -def test_invalidat_environment(mocker: MockerFixture): - delete_environment_response_mock = mocker.Mock() - delete_environment_response_mock.status_code = 200 - delete_environment_response_mock.json.return_value = {"name": "test_environment"} - delete_environment_mock = mocker.patch("requests.Session.delete") - delete_environment_mock.return_value = delete_environment_response_mock - - client = AirflowClient(airflow_url=common.AIRFLOW_LOCAL_URL, session=requests.Session()) - client.invalidate_environment("test_environment") - - delete_environment_mock.assert_called_once_with( - "http://localhost:8080/sqlmesh/api/v1/environments/test_environment" - ) - - -def test_get_variable(mocker: MockerFixture): - get_variable_response_mock = mocker.Mock() - get_variable_response_mock.status_code = 200 - get_variable_response_mock.json.return_value = {"value": "test_value", "key": "test_key"} - get_variable_mock = mocker.patch("requests.Session.get") - get_variable_mock.return_value = get_variable_response_mock - - client = AirflowClient(airflow_url=common.AIRFLOW_LOCAL_URL, session=requests.Session()) - assert client.get_variable("test_key") == "test_value" - - get_variable_mock.assert_called_once_with("http://localhost:8080/api/v1/variables/test_key") - - -def test_url_no_trailing_slash(mocker: MockerFixture, snapshot: Snapshot): - get_variable_response_mock = mocker.Mock() - get_variable_response_mock.status_code = 200 - get_variable_response_mock.json.return_value = {"value": "test_value", "key": "test_key"} - get_variable_mock = mocker.patch("requests.Session.get") - get_variable_mock.return_value = get_variable_response_mock - - client = AirflowClient(airflow_url="http://localhost:8080/prefix", session=requests.Session()) - client.get_variable("test_key") - - get_variable_mock.assert_called_once_with( - 
"http://localhost:8080/prefix/api/v1/variables/test_key" - ) diff --git a/tests/schedulers/airflow/test_common.py b/tests/schedulers/airflow/test_common.py deleted file mode 100644 index f33bb500c5..0000000000 --- a/tests/schedulers/airflow/test_common.py +++ /dev/null @@ -1,14 +0,0 @@ -from __future__ import annotations - -import pytest - -from sqlmesh.schedulers.airflow import common - -pytestmark = pytest.mark.airflow - - -def test_snapshot_dag_id(): - assert ( - common.dag_id_for_name_version('"test_schema"."test_table"', "version") - == "sqlmesh_snapshot__test_schema___test_table__version_dag" - ) diff --git a/tests/schedulers/airflow/test_dag_generator.py b/tests/schedulers/airflow/test_dag_generator.py deleted file mode 100644 index c9a410117a..0000000000 --- a/tests/schedulers/airflow/test_dag_generator.py +++ /dev/null @@ -1,175 +0,0 @@ -import typing as t -from pytest_mock.plugin import MockerFixture - -from sqlglot import parse_one -from airflow.models import BaseOperator -from airflow.utils.context import Context - -from sqlmesh.core.config import EnvironmentSuffixTarget -from sqlmesh.core.environment import Environment -from sqlmesh.core.model import IncrementalByUniqueKeyKind, SqlModel -from sqlmesh.core.snapshot import ( - Snapshot, - SnapshotChangeCategory, -) -from sqlmesh.schedulers.airflow.dag_generator import SnapshotDagGenerator -from sqlmesh.schedulers.airflow import common -from sqlmesh.schedulers.airflow.operators.targets import BaseTarget, SnapshotEvaluationTarget -from sqlmesh.schedulers.airflow.operators.sensor import HighWaterMarkSensor -from sqlmesh.utils.date import to_datetime, to_timestamp - - -class TestSubmitOperator(BaseOperator): - __test__ = False # prevent pytest trying to collect this as a test class - - def __init__( - self, - *, - target: BaseTarget, - **kwargs: t.Any, - ) -> None: - super().__init__(**kwargs) - self.target = target - - -def test_generate_plan_application_dag__batch_index_populated(mocker: MockerFixture, 
make_snapshot): - model = SqlModel( - name="test_model", - kind=IncrementalByUniqueKeyKind(unique_key="item_id", batch_size=1), - cron="@daily", - start="2020-01-01", - end="2020-01-07", - storage_format="ICEBERG", - query=parse_one(""" - SELECT item_id::int AS item_id, event_date::date AS event_date - FROM ( - VALUES - (2, '2020-01-01'), - (1, '2020-01-01'), - (3, '2020-01-03'), - (1, '2020-01-04'), - (1, '2020-01-05'), - (1, '2020-01-06'), - (1, '2020-01-07') - ) AS t(item_id, event_date) - WHERE event_date BETWEEN @start_date AND @end_date - """), - ) - - snapshot: Snapshot = make_snapshot(model) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - - state_reader_mock = mocker.Mock() - state_reader_mock.get_snapshots.return_value = {} - - generator = SnapshotDagGenerator( - engine_operator=TestSubmitOperator, - engine_operator_args={}, - ddl_engine_operator=TestSubmitOperator, - ddl_engine_operator_args={}, - external_table_sensor_factory=None, - sensor_mode="reschedule", - state_reader=state_reader_mock, - external_sensor_args=None, - high_water_mark_sensor_args=None, - ) - - environment_name = "test_env" - new_environment = Environment( - name=environment_name, - snapshots=[], - start_at="2020-01-01", - end_at="2020-01-10", - plan_id="test_plan_id", - suffix_target=EnvironmentSuffixTarget.TABLE, - catalog_name_override="test_catalog", - ) - - dag_plan = common.PlanDagSpec( - request_id="test_request_id", - environment=new_environment, - new_snapshots=[snapshot], - backfill_intervals_per_snapshot=[ - common.BackfillIntervalsPerSnapshot( - snapshot_id=snapshot.snapshot_id, - intervals=[ - (to_datetime("2020-01-01"), to_datetime("2020-01-02")), - (to_datetime("2020-01-02"), to_datetime("2020-01-03")), - (to_datetime("2020-01-03"), to_datetime("2020-01-04")), - ], - ) - ], - demoted_snapshots=[], - no_gaps=True, - notification_targets=[], - backfill_concurrent_tasks=1, - ddl_concurrent_tasks=1, - users=[], - is_dev=False, - 
allow_destructive_snapshots=set(), - execution_time=to_datetime("2024-01-01"), - ) - - dag = generator.generate_plan_application_dag(dag_plan) - assert dag is not None - - backfill_tasks = [ - t - for t in dag.tasks - if "backfill__test_model" in t.task_id - and not t.task_id.endswith("__start") - and not t.task_id.endswith("__end") - ] - assert len(backfill_tasks) == 3 - - for batch_idx, task in enumerate(backfill_tasks): - target: SnapshotEvaluationTarget = task.target # type: ignore - assert target is not None - command = target._get_command_payload(context=t.cast(Context, None)) - assert command is not None - assert target.batch_index == batch_idx - assert command.batch_index == batch_idx - - -def test_sensor_mode_override(mocker: MockerFixture, make_snapshot): - snapshot_a = make_snapshot( - SqlModel(name="a", kind=dict(name="FULL"), query=parse_one("select 1 as a, ds")), - ) - snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot_a.unpaused_ts = to_timestamp("2024-01-01") - - snapshot_b = make_snapshot( - SqlModel(name="b", kind=dict(name="FULL"), query=parse_one("select a, ds from a")), - nodes={snapshot_a.name: snapshot_a.node}, - ) - snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) - snapshot_b.unpaused_ts = to_timestamp("2024-01-01") - - state_reader_mock = mocker.Mock() - state_reader_mock.get_snapshots.return_value = { - snapshot_a.snapshot_id: snapshot_a, - snapshot_b.snapshot_id: snapshot_b, - } - - generator = SnapshotDagGenerator( - engine_operator=TestSubmitOperator, - engine_operator_args={}, - ddl_engine_operator=TestSubmitOperator, - ddl_engine_operator_args={}, - external_table_sensor_factory=None, - sensor_mode="poke", - state_reader=state_reader_mock, - external_sensor_args=None, - high_water_mark_sensor_args=None, - ) - - dags = generator.generate_cadence_dags([snapshot_a, snapshot_b]) - assert len(dags) == 2 - - assert len(dags[0].tasks) == 1 - assert isinstance(dags[0].tasks[0], TestSubmitOperator) - - assert 
len(dags[1].tasks) == 2 - assert isinstance(dags[1].tasks[0], HighWaterMarkSensor) - assert isinstance(dags[1].tasks[1], TestSubmitOperator) - assert dags[1].tasks[0].mode == "poke" diff --git a/tests/schedulers/airflow/test_end_to_end.py b/tests/schedulers/airflow/test_end_to_end.py deleted file mode 100644 index fc33762535..0000000000 --- a/tests/schedulers/airflow/test_end_to_end.py +++ /dev/null @@ -1,59 +0,0 @@ -from datetime import timedelta - -import pytest -from pytest_mock.plugin import MockerFixture -from tenacity import retry, stop_after_attempt, wait_fixed - -from sqlmesh.core.context import Context -from sqlmesh.schedulers.airflow.client import AirflowClient -from sqlmesh.utils.date import now, to_date, yesterday -from tests.conftest import SushiDataValidator - -pytestmark = [ - pytest.mark.airflow, - pytest.mark.docker, -] - - -@pytest.fixture(autouse=True) -def wait_for_airflow(airflow_client: AirflowClient): - @retry(wait=wait_fixed(2), stop=stop_after_attempt(15), reraise=True) - def get_receiver_dag() -> None: - airflow_client.get_janitor_dag() - - get_receiver_dag() - - -def test_sushi(mocker: MockerFixture, is_docker: bool): - start = to_date(now() - timedelta(days=7)) - end = now() - - airflow_config = "airflow_config_docker" if is_docker else "airflow_config" - context = Context(paths="./examples/sushi", config=airflow_config) - assert context.default_catalog == "spark_catalog" - for fqn in context.models: - assert fqn.startswith('"spark_catalog"."') - data_validator = SushiDataValidator.from_context(context) - - context.plan( - environment="test_dev", - start=start, - end=end, - skip_tests=True, - no_prompts=True, - auto_apply=True, - ) - - data_validator.validate( - "sushi.customer_revenue_lifetime", start, yesterday(), env_name="test_dev" - ) - - # Ensure that the plan has been applied successfully. 
- no_change_plan = context.plan( - environment="test_dev_two", - start=start, - end=end, - skip_tests=True, - no_prompts=True, - ) - assert not no_change_plan.requires_backfill diff --git a/tests/schedulers/airflow/test_integration.py b/tests/schedulers/airflow/test_integration.py deleted file mode 100644 index ceeadab3cd..0000000000 --- a/tests/schedulers/airflow/test_integration.py +++ /dev/null @@ -1,127 +0,0 @@ -import typing as t -from datetime import timedelta - -import pytest -from sqlglot import parse_one -from tenacity import retry, stop_after_attempt, wait_fixed - -from sqlmesh.core import constants as c -from sqlmesh.core.environment import Environment -from sqlmesh.core.model import IncrementalByTimeRangeKind, Model, SqlModel -from sqlmesh.core.snapshot import Snapshot, SnapshotChangeCategory -from sqlmesh.schedulers.airflow import common -from sqlmesh.schedulers.airflow.client import AirflowClient -from sqlmesh.utils import random_id -from sqlmesh.utils.date import yesterday -from sqlmesh.utils.errors import SQLMeshError - -pytestmark = [ - pytest.mark.airflow, - pytest.mark.docker, -] - - -DAG_CREATION_WAIT_INTERVAL = 3 -DAG_CREATION_RETRY_ATTEMPTS = 5 -DAG_RUN_POLL_INTERVAL = 1 - - -def test_system_dags(airflow_client: AirflowClient): - @retry(wait=wait_fixed(2), stop=stop_after_attempt(15), reraise=True) - def get_system_dags() -> t.List[t.Dict[str, t.Any]]: - return [ - airflow_client.get_janitor_dag(), - ] - - system_dags = get_system_dags() - assert all(d["is_active"] for d in system_dags) - - -def test_apply_plan_create_backfill_promote( - airflow_client: AirflowClient, make_snapshot, random_name -): - model_name = random_name() - snapshot = make_snapshot(_create_model(model_name)) - snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - - environment_name = _random_environment_name() - environment = _create_environment(snapshot, name=environment_name) - environment.start_at = yesterday() - timedelta(days=1) - environment.end_at = None - - 
assert airflow_client.get_variable(common.DEFAULT_CATALOG_VARIABLE_NAME) == "spark_catalog" - - assert airflow_client.get_environment(environment_name) is None - - _apply_plan_and_block(airflow_client, [snapshot], environment, is_dev=False) - - assert airflow_client.get_environment(environment_name).snapshots == [ # type: ignore - snapshot.table_info - ] - - # Make sure that the same Snapshot can't be added again. - with pytest.raises(SQLMeshError, match=r"Snapshots.*already exist.*"): - airflow_client.apply_plan([snapshot], environment, random_name()) - - # Verify full environment demotion. - environment.snapshots = [] - environment.previous_plan_id = environment.plan_id - environment.plan_id = "new_plan_id" - _apply_plan_and_block(airflow_client, [], environment) - assert not airflow_client.get_environment(environment_name).snapshots # type: ignore - - -def _apply_plan_and_block( - airflow_client: AirflowClient, - new_snapshots: t.List[Snapshot], - environment: Environment, - is_dev: t.Optional[bool] = None, -) -> None: - if is_dev is None: - is_dev = environment.name != c.PROD - - plan_request_id = random_id() - airflow_client.apply_plan(new_snapshots, environment, plan_request_id, is_dev=is_dev) - - plan_application_dag_id = common.plan_application_dag_id(environment.name, plan_request_id) - plan_application_dag_run_id = airflow_client.wait_for_first_dag_run( - plan_application_dag_id, DAG_CREATION_WAIT_INTERVAL, DAG_CREATION_RETRY_ATTEMPTS - ) - assert airflow_client.wait_for_dag_run_completion( - plan_application_dag_id, plan_application_dag_run_id, DAG_RUN_POLL_INTERVAL - ) - - -@retry(wait=wait_fixed(3), stop=stop_after_attempt(5), reraise=True) -def _get_snapshot_dag( - airflow_client: AirflowClient, model_name: str, version: str -) -> t.Dict[str, t.Any]: - return airflow_client.get_snapshot_dag(model_name, version) - - -def _create_model(name: str) -> Model: - return SqlModel( - name=name, - kind=IncrementalByTimeRangeKind(time_column="ds", 
batch_size=30), - description="Dummy table", - owner="jen", - cron="@daily", - start="2020-01-01", - partitioned_by=["ds"], - query=parse_one("SELECT '2022-01-01'::TEXT AS ds, 1::INT AS one"), - ) - - -def _create_environment(snapshot: Snapshot, name: t.Optional[str] = None) -> Environment: - return Environment( - name=name or _random_environment_name(), - snapshots=[snapshot.table_info], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="test_plan_id", - previous_plan_id=None, - ) - - -def _random_environment_name() -> str: - return f"test_environment_{random_id()[-8:]}" diff --git a/tests/schedulers/airflow/test_mwaa_client.py b/tests/schedulers/airflow/test_mwaa_client.py deleted file mode 100644 index a551fe0a78..0000000000 --- a/tests/schedulers/airflow/test_mwaa_client.py +++ /dev/null @@ -1,159 +0,0 @@ -import base64 -import json - -import pytest -from pytest_mock.plugin import MockerFixture - -from sqlmesh.schedulers.airflow.mwaa_client import MWAAClient - -pytestmark = pytest.mark.airflow - - -def test_get_first_dag_run_id(mocker: MockerFixture): - list_runs_response_mock = mocker.Mock() - list_runs_response_mock.json.return_value = { - "stdout": _encode_output(json.dumps([{"run_id": "test_run_id", "state": "success"}])), - "stderr": "", - } - list_runs_response_mock.status_code = 200 - list_runs_mock = mocker.patch("requests.Session.post") - list_runs_mock.return_value = list_runs_response_mock - - url_and_auth_token_mock = mocker.patch( - "sqlmesh.schedulers.airflow.mwaa_client.url_and_auth_token_for_environment" - ) - url_and_auth_token_mock.return_value = ("https://test_airflow_host", "test_token") - - client = MWAAClient("test_environment") - - assert client.get_first_dag_run_id("test_dag_id") == "test_run_id" - - list_runs_mock.assert_called_once_with( - "https://test_airflow_host/aws_mwaa/cli", - data="dags list-runs -o json -d test_dag_id", - ) - url_and_auth_token_mock.assert_called_once_with("test_environment") - - -def 
test_get_dag_run_state(mocker: MockerFixture): - list_runs_response_mock = mocker.Mock() - list_runs_response_mock.json.return_value = { - "stdout": _encode_output( - json.dumps( - [ - {"run_id": "test_run_id_a", "state": "success"}, - {"run_id": "test_run_id_b", "state": "failed"}, - ] - ) - ), - "stderr": "", - } - list_runs_response_mock.status_code = 200 - list_runs_mock = mocker.patch("requests.Session.post") - list_runs_mock.return_value = list_runs_response_mock - - url_and_auth_token_mock = mocker.patch( - "sqlmesh.schedulers.airflow.mwaa_client.url_and_auth_token_for_environment" - ) - url_and_auth_token_mock.return_value = ("https://test_airflow_host", "test_token") - - client = MWAAClient("test_environment") - - assert client.get_dag_run_state("test_dag_id", "test_run_id_b") == "failed" - - list_runs_mock.assert_called_once_with( - "https://test_airflow_host/aws_mwaa/cli", - data="dags list-runs -o json -d test_dag_id", - ) - url_and_auth_token_mock.assert_called_once_with("test_environment") - - -def test_get_variable(mocker: MockerFixture): - get_variable_response_mock = mocker.Mock() - get_variable_response_mock.json.return_value = { - "stdout": _encode_output("test_value"), - "stderr": "", - } - get_variable_response_mock.status_code = 200 - get_variable_mock = mocker.patch("requests.Session.post") - get_variable_mock.return_value = get_variable_response_mock - - url_and_auth_token_mock = mocker.patch( - "sqlmesh.schedulers.airflow.mwaa_client.url_and_auth_token_for_environment" - ) - url_and_auth_token_mock.return_value = ("https://test_airflow_host", "test_token") - - client = MWAAClient("test_environment") - - assert client.get_variable("test_key") == "test_value" - - get_variable_mock.assert_called_once_with( - "https://test_airflow_host/aws_mwaa/cli", - data="variables get test_key", - ) - url_and_auth_token_mock.assert_called_once_with("test_environment") - - -def test_get_variable_not_found(mocker: MockerFixture): - get_variable_response_mock 
= mocker.Mock() - get_variable_response_mock.json.return_value = { - "stdout": "", - "stderr": _encode_output("Variable test_key does not exist"), - } - get_variable_response_mock.status_code = 200 - get_variable_mock = mocker.patch("requests.Session.post") - get_variable_mock.return_value = get_variable_response_mock - - url_and_auth_token_mock = mocker.patch( - "sqlmesh.schedulers.airflow.mwaa_client.url_and_auth_token_for_environment" - ) - url_and_auth_token_mock.return_value = ("https://test_airflow_host", "test_token") - - client = MWAAClient("test_environment") - - assert client.get_variable("test_key") is None - - get_variable_mock.assert_called_once_with( - "https://test_airflow_host/aws_mwaa/cli", - data="variables get test_key", - ) - - -def test_token_refresh(mocker: MockerFixture): - list_runs_response_mock = mocker.Mock() - list_runs_response_mock.json.return_value = { - "stdout": _encode_output(json.dumps([{"run_id": "test_run_id", "state": "success"}])), - "stderr": "", - } - list_runs_response_mock.status_code = 200 - list_runs_mock = mocker.patch("requests.Session.post") - list_runs_mock.return_value = list_runs_response_mock - - url_and_auth_token_mock = mocker.patch( - "sqlmesh.schedulers.airflow.mwaa_client.url_and_auth_token_for_environment" - ) - url_and_auth_token_mock.return_value = ("https://test_airflow_host", "test_token") - - now_mock = mocker.patch("sqlmesh.schedulers.airflow.mwaa_client.now_timestamp") - now_mock.return_value = 0 - - client = MWAAClient("test_environment") - client.get_first_dag_run_id("test_dag_id") - - now_mock.return_value = 15000 # 15 seconds later - client.get_first_dag_run_id("test_dag_id") - - now_mock.return_value = 31000 # 31 seconds later - client.get_first_dag_run_id("test_dag_id") - - now_mock.return_value = 45000 # 45 seconds later - client.get_first_dag_run_id("test_dag_id") - - now_mock.return_value = 63000 # 63 seconds later - client.get_first_dag_run_id("test_dag_id") - - assert 
url_and_auth_token_mock.call_count == 3 - - -def _encode_output(out: str) -> str: - return base64.b64encode(out.encode("utf-8")).decode("utf-8") diff --git a/tests/schedulers/airflow/test_plan.py b/tests/schedulers/airflow/test_plan.py deleted file mode 100644 index 0714d86e66..0000000000 --- a/tests/schedulers/airflow/test_plan.py +++ /dev/null @@ -1,559 +0,0 @@ -import typing as t -from datetime import datetime -from unittest import mock - -import pytest -from _pytest.fixtures import FixtureRequest -from _pytest.monkeypatch import MonkeyPatch -from pytest_mock.plugin import MockerFixture -from sqlglot import parse_one - -from sqlmesh.core.config import EnvironmentSuffixTarget -from sqlmesh.core.context import Context -from sqlmesh.core.environment import Environment -from sqlmesh.core.model import ( - IncrementalByTimeRangeKind, - ModelKindName, - create_sql_model, -) -from sqlmesh.core.node import NodeType -from sqlmesh.core.snapshot import ( - DeployabilityIndex, - Snapshot, - SnapshotChangeCategory, - SnapshotFingerprint, - SnapshotTableInfo, -) -from sqlmesh.schedulers.airflow import common -from sqlmesh.schedulers.airflow.plan import PlanDagState, create_plan_dag_spec -from sqlmesh.utils.date import to_datetime, to_timestamp -from sqlmesh.utils.errors import SQLMeshError - -pytestmark = pytest.mark.airflow - - -@pytest.fixture -def snapshot(make_snapshot, random_name) -> Snapshot: - result = make_snapshot( - create_sql_model( - random_name(), - parse_one("SELECT 1, ds"), - kind=IncrementalByTimeRangeKind(time_column="ds"), - start="2022-01-01", - ), - ) - result.categorize_as(SnapshotChangeCategory.BREAKING) - return result - - -@pytest.fixture -def depends_on_self_snapshot(make_snapshot, random_name) -> Snapshot: - name = random_name() - result = make_snapshot( - create_sql_model( - name, - parse_one(f"SELECT 1, ds FROM {name}"), - kind=IncrementalByTimeRangeKind(time_column="ds", batch_size=1), - start="2022-01-01", - ), - ) - 
result.categorize_as(SnapshotChangeCategory.BREAKING) - return result - - -@pytest.mark.parametrize( - "snapshot_fixture, expected_intervals, paused_forward_only", - [ - ("snapshot", [(to_datetime("2022-01-01"), to_datetime("2022-01-05"))], False), - ("snapshot", [(to_datetime("2022-01-01"), to_datetime("2022-01-05"))], True), - ( - "depends_on_self_snapshot", - [ - (to_datetime("2022-01-01"), to_datetime("2022-01-02")), - (to_datetime("2022-01-02"), to_datetime("2022-01-03")), - (to_datetime("2022-01-03"), to_datetime("2022-01-04")), - (to_datetime("2022-01-04"), to_datetime("2022-01-05")), - ], - False, - ), - ], -) -def test_create_plan_dag_spec( - mocker: MockerFixture, - snapshot_fixture: str, - expected_intervals: t.List[t.Tuple[datetime, datetime]], - paused_forward_only: bool, - random_name, - request: FixtureRequest, -): - the_snapshot = request.getfixturevalue(snapshot_fixture) - the_snapshot.categorize_as( - SnapshotChangeCategory.FORWARD_ONLY - if paused_forward_only - else SnapshotChangeCategory.BREAKING - ) - - environment_name = random_name() - new_environment = Environment( - name=environment_name, - snapshots=[the_snapshot.table_info], - start_at="2022-01-01", - end_at="2022-01-04", - plan_id="test_plan_id", - suffix_target=EnvironmentSuffixTarget.TABLE, - catalog_name_override="test_catalog", - ) - - plan_request = common.PlanApplicationRequest( - request_id="test_request_id", - new_snapshots=[the_snapshot], - environment=new_environment, - no_gaps=True, - skip_backfill=False, - restatements={}, - notification_targets=[], - backfill_concurrent_tasks=1, - ddl_concurrent_tasks=1, - users=[], - is_dev=False, - forward_only=True, - models_to_backfill=None, - end_bounded=False, - ensure_finalized_snapshots=False, - directly_modified_snapshots=[the_snapshot.snapshot_id], - indirectly_modified_snapshots={}, - removed_snapshots=[], - ) - - deleted_snapshot = SnapshotTableInfo( - name="test_schema.deleted_model", - 
fingerprint=SnapshotFingerprint(data_hash="1", metadata_hash="1"), - version="test_version", - physical_schema="test_physical_schema", - parents=[], - change_category=SnapshotChangeCategory.BREAKING, - kind_name=ModelKindName.FULL, - node_type=NodeType.MODEL, - ) - old_environment = Environment( - name=environment_name, - snapshots=[deleted_snapshot], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="test_plan_id", - suffix_target=EnvironmentSuffixTarget.SCHEMA, - ) - - state_sync_mock = mocker.Mock() - state_sync_mock.get_snapshots.return_value = {} - state_sync_mock.get_environment.return_value = old_environment - state_sync_mock.get_snapshot_intervals.return_value = [] - state_sync_mock.refresh_snapshot_intervals.return_value = [] - - expected_no_gaps_snapshot_names = {the_snapshot.name} if not paused_forward_only else set() - - with mock.patch( - "sqlmesh.schedulers.airflow.plan.now", - side_effect=lambda: to_datetime("2023-01-01"), - ): - plan_spec = create_plan_dag_spec(plan_request, state_sync_mock) - - assert plan_spec == common.PlanDagSpec( - request_id="test_request_id", - environment=new_environment, - new_snapshots=[the_snapshot], - backfill_intervals_per_snapshot=[ - common.BackfillIntervalsPerSnapshot( - snapshot_id=the_snapshot.snapshot_id, - intervals=expected_intervals, - before_promote=not paused_forward_only, - ) - ], - demoted_snapshots=[deleted_snapshot], - unpaused_dt="2022-01-04", - no_gaps=True, - notification_targets=[], - backfill_concurrent_tasks=1, - ddl_concurrent_tasks=1, - users=[], - is_dev=False, - allow_destructive_snapshots=set(), - forward_only=True, - dag_start_ts=to_timestamp("2023-01-01"), - no_gaps_snapshot_names=expected_no_gaps_snapshot_names, - deployability_index_for_creation=( - DeployabilityIndex.all_deployable() - if not paused_forward_only - else DeployabilityIndex.none_deployable() - ), - directly_modified_snapshots=[the_snapshot.snapshot_id], - indirectly_modified_snapshots={}, - removed_snapshots=[], - ) - 
- state_sync_mock.get_snapshots.assert_called_once() - state_sync_mock.get_environment.assert_called_once() - state_sync_mock.refresh_snapshot_intervals.assert_called_once() - list(state_sync_mock.refresh_snapshot_intervals.call_args_list[0][0][0]) == [the_snapshot] - - -@pytest.mark.parametrize( - "snapshot_fixture, expected_intervals", - [ - ( - "snapshot", - [(to_datetime("2022-01-02"), to_datetime("2022-01-04"))], - ), - ( - "depends_on_self_snapshot", - [ - (to_datetime("2022-01-02"), to_datetime("2022-01-03")), - (to_datetime("2022-01-03"), to_datetime("2022-01-04")), - ], - ), - ], -) -def test_restatement( - mocker: MockerFixture, - monkeypatch: MonkeyPatch, - snapshot_fixture: str, - expected_intervals: t.List[t.Tuple[datetime, datetime]], - random_name, - request: FixtureRequest, -): - the_snapshot = request.getfixturevalue(snapshot_fixture) - environment_name = random_name() - new_environment = Environment( - name=environment_name, - snapshots=[the_snapshot.table_info], - start_at="2022-01-01", - end_at="2022-01-07", - plan_id="test_plan_id", - ) - - the_snapshot.add_interval("2022-01-01", "2022-01-07") - - plan_request = common.PlanApplicationRequest( - request_id="test_request_id", - new_snapshots=[], - environment=new_environment, - no_gaps=True, - skip_backfill=False, - restatements={ - the_snapshot.name: ( - to_timestamp("2022-01-02"), - to_timestamp("2022-01-04"), - ) - }, - notification_targets=[], - backfill_concurrent_tasks=1, - ddl_concurrent_tasks=1, - users=[], - is_dev=False, - forward_only=True, - models_to_backfill=None, - end_bounded=False, - ensure_finalized_snapshots=False, - directly_modified_snapshots=[], - indirectly_modified_snapshots={}, - removed_snapshots=[], - ) - old_environment = Environment( - name=environment_name, - snapshots=[the_snapshot.table_info], - start_at="2022-01-01", - end_at="2022-01-07", - plan_id="test_plan_id", - ) - - state_sync_mock = mocker.Mock() - state_sync_mock.get_snapshots.return_value = 
{the_snapshot.snapshot_id: the_snapshot} - state_sync_mock.get_environment.return_value = old_environment - state_sync_mock.refresh_snapshot_intervals.return_value = [the_snapshot] - - now_value = "2022-01-09T23:59:59+00:00" - with mock.patch( - "sqlmesh.schedulers.airflow.plan.now", side_effect=lambda: to_datetime(now_value) - ): - plan_spec = create_plan_dag_spec(plan_request, state_sync_mock) - - assert plan_spec == common.PlanDagSpec( - request_id="test_request_id", - environment=new_environment, - new_snapshots=[], - backfill_intervals_per_snapshot=[ - common.BackfillIntervalsPerSnapshot( - snapshot_id=the_snapshot.snapshot_id, - intervals=expected_intervals, - ) - ], - demoted_snapshots=[], - unpaused_dt=None, - no_gaps=True, - notification_targets=[], - backfill_concurrent_tasks=1, - ddl_concurrent_tasks=1, - users=[], - is_dev=False, - allow_destructive_snapshots=set(), - forward_only=True, - dag_start_ts=to_timestamp(now_value), - no_gaps_snapshot_names={the_snapshot.name}, - directly_modified_snapshots=[], - indirectly_modified_snapshots={}, - removed_snapshots=[], - ) - - state_sync_mock.get_snapshots.assert_called_once() - state_sync_mock.get_environment.assert_called_once() - state_sync_mock.refresh_snapshot_intervals.assert_called_once() - - state_sync_mock.remove_interval.assert_called_once_with( - [(the_snapshot, (to_timestamp("2022-01-02"), to_timestamp("2022-01-04")))], - remove_shared_versions=True, - ) - - assert the_snapshot.intervals == [ - (to_timestamp("2022-01-01"), to_timestamp("2022-01-02")), - (to_timestamp("2022-01-04"), to_timestamp("2022-01-08")), - ] - - -def test_select_models_for_backfill(mocker: MockerFixture, random_name, make_snapshot): - snapshot_a = make_snapshot( - create_sql_model( - "a", - parse_one("SELECT 1, ds"), - kind=IncrementalByTimeRangeKind(time_column="ds"), - start="2022-01-01", - ), - ) - snapshot_a.categorize_as(SnapshotChangeCategory.BREAKING) - - snapshot_b = make_snapshot( - create_sql_model( - "b", - 
parse_one("SELECT 2, ds"), - kind=IncrementalByTimeRangeKind(time_column="ds"), - start="2022-01-01", - ), - ) - snapshot_b.categorize_as(SnapshotChangeCategory.BREAKING) - - environment_name = random_name() - new_environment = Environment( - name=environment_name, - snapshots=[snapshot_a.table_info, snapshot_b.table_info], - start_at="2022-01-01", - end_at="2022-01-04", - plan_id="test_plan_id", - suffix_target=EnvironmentSuffixTarget.TABLE, - ) - - plan_request = common.PlanApplicationRequest( - request_id="test_request_id", - new_snapshots=[snapshot_a, snapshot_b], - environment=new_environment, - no_gaps=True, - skip_backfill=False, - restatements={}, - notification_targets=[], - backfill_concurrent_tasks=1, - ddl_concurrent_tasks=1, - users=[], - is_dev=False, - forward_only=True, - models_to_backfill={snapshot_b.name}, - end_bounded=False, - ensure_finalized_snapshots=False, - directly_modified_snapshots=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], - indirectly_modified_snapshots={}, - removed_snapshots=[], - ) - - state_sync_mock = mocker.Mock() - state_sync_mock.get_snapshots.return_value = {} - state_sync_mock.get_environment.return_value = None - state_sync_mock.get_snapshot_intervals.return_value = [] - state_sync_mock.refresh_snapshot_intervals.return_value = [] - - with mock.patch( - "sqlmesh.schedulers.airflow.plan.now", - side_effect=lambda: to_datetime("2023-01-01"), - ): - plan_spec = create_plan_dag_spec(plan_request, state_sync_mock) - - assert plan_spec == common.PlanDagSpec( - request_id="test_request_id", - environment=new_environment, - new_snapshots=[snapshot_a, snapshot_b], - backfill_intervals_per_snapshot=[ - common.BackfillIntervalsPerSnapshot( - snapshot_id=snapshot_b.snapshot_id, - intervals=[(to_datetime("2022-01-01"), to_datetime("2022-01-05"))], - before_promote=True, - ) - ], - demoted_snapshots=[], - unpaused_dt="2022-01-04", - no_gaps=True, - notification_targets=[], - backfill_concurrent_tasks=1, - ddl_concurrent_tasks=1, - 
users=[], - is_dev=False, - forward_only=True, - allow_destructive_snapshots=set(), - dag_start_ts=to_timestamp("2023-01-01"), - deployability_index=DeployabilityIndex.all_deployable(), - no_gaps_snapshot_names={'"a"', '"b"'}, - models_to_backfill={snapshot_b.name}, - directly_modified_snapshots=[snapshot_a.snapshot_id, snapshot_b.snapshot_id], - indirectly_modified_snapshots={}, - removed_snapshots=[], - ) - - -def test_create_plan_dag_spec_duplicated_snapshot( - mocker: MockerFixture, snapshot: Snapshot, random_name -): - environment_name = random_name() - new_environment = Environment( - name=environment_name, - snapshots=[snapshot.table_info], - start_at="2022-01-01", - end_at="2022-01-01", - plan_id="test_plan_id", - ) - - plan_request = common.PlanApplicationRequest( - request_id="test_request_id", - new_snapshots=[snapshot], - environment=new_environment, - no_gaps=False, - skip_backfill=False, - restatements={}, - notification_targets=[], - backfill_concurrent_tasks=1, - ddl_concurrent_tasks=1, - users=[], - is_dev=False, - forward_only=False, - models_to_backfill=None, - end_bounded=False, - ensure_finalized_snapshots=False, - directly_modified_snapshots=[], - indirectly_modified_snapshots={}, - removed_snapshots=[], - ) - - dag_run_mock = mocker.Mock() - dag_run_mock.conf = plan_request.dict() - - state_sync_mock = mocker.Mock() - state_sync_mock.get_snapshots.return_value = {snapshot.snapshot_id: snapshot} - - with pytest.raises(SQLMeshError): - create_plan_dag_spec(plan_request, state_sync_mock) - - state_sync_mock.get_snapshots.assert_called_once() - - -@pytest.mark.parametrize("unbounded_end", [None, ""]) -def test_create_plan_dag_spec_unbounded_end( - mocker: MockerFixture, - snapshot: Snapshot, - make_snapshot, - random_name, - unbounded_end: t.Optional[str], -): - unrelated_snapshot = make_snapshot(create_sql_model(random_name(), parse_one("SELECT 2, ds"))) - unrelated_snapshot.categorize_as(SnapshotChangeCategory.BREAKING) - - environment_name = 
random_name() - new_environment = Environment( - name=environment_name, - snapshots=[snapshot.table_info], - start_at="2022-01-01", - end_at=unbounded_end, - plan_id="test_plan_id", - ) - - plan_request = common.PlanApplicationRequest( - request_id="test_request_id", - new_snapshots=[], - environment=new_environment, - no_gaps=True, - skip_backfill=False, - restatements={}, - notification_targets=[], - backfill_concurrent_tasks=1, - ddl_concurrent_tasks=1, - users=[], - is_dev=False, - forward_only=False, - models_to_backfill=None, - end_bounded=False, - ensure_finalized_snapshots=False, - directly_modified_snapshots=[], - indirectly_modified_snapshots={}, - removed_snapshots=[], - ) - - state_sync_mock = mocker.Mock() - state_sync_mock.get_snapshots.return_value = { - snapshot.snapshot_id: snapshot, - unrelated_snapshot.snapshot_id: unrelated_snapshot, - } - state_sync_mock.get_environment.return_value = None - state_sync_mock.get_snapshot_intervals.return_value = [] - state_sync_mock.refresh_snapshot_intervals.return_value = [] - - create_plan_dag_spec(plan_request, state_sync_mock) - - state_sync_mock.get_snapshots.assert_called_once() - state_sync_mock.get_environment.assert_called_once() - state_sync_mock.refresh_snapshot_intervals.assert_called_once() - - -def test_plan_dag_state(snapshot: Snapshot, sushi_context: Context, random_name): - environment_name = random_name() - environment = Environment( - name=environment_name, - snapshots=[snapshot.table_info], - start_at=to_timestamp("2022-01-01"), - end_at=None, - plan_id="test_plan_id", - ) - plan_dag_spec = common.PlanDagSpec( - request_id="test_request_id", - environment=environment, - new_snapshots=[], - backfill_intervals_per_snapshot=[], - demoted_snapshots=[], - unpaused_dt=None, - no_gaps=True, - notification_targets=[], - backfill_concurrent_tasks=1, - ddl_concurrent_tasks=1, - users=[], - is_dev=False, - allow_destructive_snapshots=set(), - forward_only=True, - 
dag_start_ts=to_timestamp("2023-01-01"), - ) - - plan_dag_state = PlanDagState.from_state_sync(sushi_context.state_sync) - - assert not plan_dag_state.get_dag_specs() - - plan_dag_state.add_dag_spec(plan_dag_spec) - assert plan_dag_state.get_dag_specs() == [plan_dag_spec] - - plan_dag_state.delete_dag_specs([]) - assert plan_dag_state.get_dag_specs() == [plan_dag_spec] - - plan_dag_state.delete_dag_specs( - [common.plan_application_dag_id(environment_name, "test_request_id")] - ) - assert not plan_dag_state.get_dag_specs() diff --git a/tests/setup.py b/tests/setup.py index 542d1cc3f6..ab48a3128f 100644 --- a/tests/setup.py +++ b/tests/setup.py @@ -1,27 +1,14 @@ -import os - import setuptools +from pathlib import Path +import toml # type: ignore -os.chdir(os.path.join(os.path.dirname(__file__), "..")) -sqlmesh_dist = setuptools.distutils.core.run_setup("setup.py", stop_after="init") -requirements = sqlmesh_dist.install_requires + sqlmesh_dist.extras_require["dev"] # type: ignore -os.chdir(os.path.dirname(__file__)) +# This relies on `make package-tests` copying the sqlmesh pyproject.toml into tests/ so we can reference it +# Otherwise, it's not available in the build environment +sqlmesh_pyproject = Path(__file__).parent / "sqlmesh_pyproject.toml" +parsed = toml.load(sqlmesh_pyproject)["project"] +install_requires = parsed["dependencies"] + parsed["optional-dependencies"]["dev"] +# remove dbt dependencies +install_requires = [req for req in install_requires if not req.startswith("dbt")] -setuptools.setup( - name="sqlmesh-tests", - description="Tests for SQLMesh", - url="https://github.com/TobikoData/sqlmesh", - author="TobikoData Inc.", - author_email="engineering@tobikodata.com", - license="Apache License 2.0", - package_dir={"sqlmesh_tests": ""}, - package_data={"": ["fixtures/**"]}, - use_scm_version={ - "root": "..", - "write_to": "_version.py", - "fallback_version": "0.0.0", - "local_scheme": "no-local-version", - }, - setup_requires=["setuptools_scm"], - 
install_requires=requirements, -) +# this is just so we can have a dynamic install_requires, everything else is defined in pyproject.toml +setuptools.setup(install_requires=install_requires) diff --git a/tests/test_forking.py b/tests/test_forking.py new file mode 100644 index 0000000000..d11379a158 --- /dev/null +++ b/tests/test_forking.py @@ -0,0 +1,88 @@ +import os +import pytest + +from sqlmesh import Context +from sqlmesh.core.model import schema +import concurrent.futures + + +pytestmark = pytest.mark.isolated + + +def test_parallel_load(assert_exp_eq, mocker): + mocker.patch("sqlmesh.core.constants.MAX_FORK_WORKERS", 2) + + spy_update_schemas = mocker.spy(schema, "_update_model_schemas") + process_pool_executor = mocker.spy(concurrent.futures.ProcessPoolExecutor, "__init__") + as_completed = mocker.spy(concurrent.futures, "as_completed") + + context = Context(paths="examples/sushi") + + if hasattr(os, "fork"): + process_pool_executor.assert_called() + as_completed.assert_called() + executor_args = process_pool_executor.call_args + assert executor_args[1]["max_workers"] == 2 + + assert len(context.models) == 20 + spy_update_schemas.assert_called() + assert_exp_eq( + context.render("sushi.customers"), + """ +WITH "current_marketing_outer" AS ( + SELECT + "marketing"."customer_id" AS "customer_id", + "marketing"."status" AS "status" + FROM "memory"."sushi"."marketing" AS "marketing" + WHERE + "marketing"."valid_to" IS NULL +) +SELECT DISTINCT + CAST("o"."customer_id" AS INT) AS "customer_id", /* this comment should not be registered */ + "m"."status" AS "status", + "d"."zip" AS "zip" +FROM "memory"."sushi"."orders" AS "o" +LEFT JOIN ( + WITH "current_marketing" AS ( + SELECT + "current_marketing_outer"."customer_id" AS "customer_id", + "current_marketing_outer"."status" AS "status", + 2 AS "another_column" + FROM "current_marketing_outer" AS "current_marketing_outer" + ) + SELECT + "current_marketing"."customer_id" AS "customer_id", + 
"current_marketing"."status" AS "status", + "current_marketing"."another_column" AS "another_column" + FROM "current_marketing" AS "current_marketing" + WHERE + "current_marketing"."customer_id" <> 100 +) AS "m" + ON "m"."customer_id" = "o"."customer_id" +LEFT JOIN "memory"."raw"."demographics" AS "d" + ON "d"."customer_id" = "o"."customer_id" + WHERE + "o"."customer_id" > 0 + """, + ) + + context.plan(no_prompts=True, auto_apply=True) + + +def test_parallel_load_multi_repo(assert_exp_eq, mocker): + mocker.patch("sqlmesh.core.constants.MAX_FORK_WORKERS", 2) + + process_pool_executor = mocker.spy(concurrent.futures.ProcessPoolExecutor, "__init__") + context = Context(paths=["examples/multi/repo_1", "examples/multi/repo_2"], gateway="memory") + + if hasattr(os, "fork"): + executor_args = process_pool_executor.call_args + assert executor_args[1]["max_workers"] == 2 + assert len(context.models) == 5 + + assert_exp_eq( + context.render("memory.bronze.a"), + 'SELECT 1 AS "col_a", \'b\' AS "col_b", 1 AS "one", \'repo_1\' AS "dup"', + ) + + context.plan(no_prompts=True, auto_apply=True) diff --git a/tests/utils/__init__.py b/tests/utils/__init__.py index e69de29bb2..744ad37757 100644 --- a/tests/utils/__init__.py +++ b/tests/utils/__init__.py @@ -0,0 +1,23 @@ +import pytest + +from sqlmesh.utils import sanitize_name + + +@pytest.mark.parametrize( + "raw,exclude_unicode,include_unicode", + [ + ("simple", "simple", "simple"), + ("snake_case", "snake_case", "snake_case"), + ("客户数据", "____", "客户数据"), + ("客户-数据 v2", "______v2", "客户_数据_v2"), + ("中文,逗号", "_____", "中文_逗号"), + ("a/b", "a_b", "a_b"), + ("spaces\tand\nnewlines", "spaces_and_newlines", "spaces_and_newlines"), + ("data📦2025", "data_2025", "data_2025"), + ("MiXeD123_名字", "MiXeD123___", "MiXeD123_名字"), + ("", "", ""), + ], +) +def test_sanitize_name_no_(raw, exclude_unicode, include_unicode): + assert sanitize_name(raw) == exclude_unicode + assert sanitize_name(raw, include_unicode=True) == include_unicode diff --git 
a/tests/utils/pandas.py b/tests/utils/pandas.py index 130e515be2..b9451f4545 100644 --- a/tests/utils/pandas.py +++ b/tests/utils/pandas.py @@ -2,8 +2,8 @@ import typing as t -import numpy as np -import pandas as pd +import numpy as np # noqa: TID253 +import pandas as pd # noqa: TID253 def create_df(data: t.Sequence[t.Tuple], schema: t.Dict[str, str]) -> pd.DataFrame: diff --git a/tests/utils/test_aws.py b/tests/utils/test_aws.py new file mode 100644 index 0000000000..905cc00dfe --- /dev/null +++ b/tests/utils/test_aws.py @@ -0,0 +1,34 @@ +import pytest +from sqlmesh.utils.errors import SQLMeshError, ConfigError +from sqlmesh.utils.aws import validate_s3_uri, parse_s3_uri + + +def test_validate_s3_uri(): + with pytest.raises(SQLMeshError, match=r".*must be a s3://.*"): + validate_s3_uri("hdfs://foo/bar") + + with pytest.raises(ConfigError, match=r".*must be a s3://.*"): + validate_s3_uri("hdfs://foo/bar", error_type=ConfigError) + + with pytest.raises(SQLMeshError, match=r".*must be a s3://.*"): + validate_s3_uri("/foo/bar") + + with pytest.raises(SQLMeshError, match=r".*cannot be more than 700 characters"): + long_path = "foo/bar/" * 100 + assert len(long_path) > 700 + validate_s3_uri(f"s3://{long_path}") + + assert validate_s3_uri("s3://foo/bar/") == "s3://foo/bar/" + assert validate_s3_uri("s3://foo/bar/baz") == "s3://foo/bar/baz" + assert validate_s3_uri("s3://foo/bar/baz", base=True) == "s3://foo/bar/baz/" + + +def test_parse_s3_uri(): + with pytest.raises(SQLMeshError, match=r".*must be a s3://.*"): + parse_s3_uri("hdfs://foo/bar") + + assert parse_s3_uri("s3://foo") == ("foo", "") + assert parse_s3_uri("s3://foo/") == ("foo", "") + assert parse_s3_uri("s3://foo/bar") == ("foo", "bar") + assert parse_s3_uri("s3://foo/bar/") == ("foo", "bar/") + assert parse_s3_uri("s3://foo/bar/baz/bing.txt") == ("foo", "bar/baz/bing.txt") diff --git a/tests/utils/test_cache.py b/tests/utils/test_cache.py index 870d0b4d19..ed19765b8a 100644 --- a/tests/utils/test_cache.py +++ 
b/tests/utils/test_cache.py @@ -1,9 +1,11 @@ +import typing as t from pathlib import Path from pytest_mock.plugin import MockerFixture from sqlglot import parse_one -from sqlmesh.core.model import SqlModel +from sqlmesh.core import dialect as d +from sqlmesh.core.model import SqlModel, load_sql_based_model from sqlmesh.core.model.cache import OptimizedQueryCache from sqlmesh.utils.cache import FileCache from sqlmesh.utils.pydantic import PydanticModel @@ -14,7 +16,7 @@ class _TestEntry(PydanticModel): def test_file_cache(tmp_path: Path, mocker: MockerFixture): - cache = FileCache(tmp_path, _TestEntry) + cache: FileCache[_TestEntry] = FileCache(tmp_path) test_entry_a = _TestEntry(value="value_a") test_entry_b = _TestEntry(value="value_b") @@ -37,6 +39,7 @@ def test_file_cache(tmp_path: Path, mocker: MockerFixture): loader.assert_called_once() assert "___test_model_" in cache._cache_entry_path('"test_model"').name + assert "客户数据" in cache._cache_entry_path("客户数据").name def test_optimized_query_cache(tmp_path: Path, mocker: MockerFixture): @@ -79,3 +82,52 @@ def test_optimized_query_cache_missing_rendered_query(tmp_path: Path, mocker: Mo assert model._query_renderer._cache == [None] assert model._query_renderer._optimized_cache is None + + +def test_optimized_query_cache_macro_def_change(tmp_path: Path, mocker: MockerFixture): + expressions = d.parse( + """ + MODEL (name db.table); + + @DEF(filter_, a = 1); + + SELECT a FROM (SELECT 1 AS a) WHERE @filter_; + """ + ) + model = t.cast(SqlModel, load_sql_based_model(expressions)) + + cache = OptimizedQueryCache(tmp_path) + + assert not cache.with_optimized_query(model) + + model._query_renderer._cache = [] + model._query_renderer._optimized_cache = None + + assert cache.with_optimized_query(model) + assert ( + model.render_query_or_raise().sql() + == 'SELECT "_0"."a" AS "a" FROM (SELECT 1 AS "a") AS "_0" WHERE "_0"."a" = 1' + ) + + # Change the filter_ definition + new_expressions = d.parse( + """ + MODEL (name 
db.table); + + @DEF(filter_, a = 2); + + SELECT a FROM (SELECT 1 AS a) WHERE @filter_; + """ + ) + new_model = t.cast(SqlModel, load_sql_based_model(new_expressions)) + + assert not cache.with_optimized_query(new_model) + + new_model._query_renderer._cache = [] + new_model._query_renderer._optimized_cache = None + + assert cache.with_optimized_query(new_model) + assert ( + new_model.render_query_or_raise().sql() + == 'SELECT "_0"."a" AS "a" FROM (SELECT 1 AS "a") AS "_0" WHERE "_0"."a" = 2' + ) diff --git a/tests/utils/test_concurrency.py b/tests/utils/test_concurrency.py index 892c6ef485..5e1e4326f7 100644 --- a/tests/utils/test_concurrency.py +++ b/tests/utils/test_concurrency.py @@ -139,7 +139,7 @@ def raise_(snapshot): assert len(errors) == 1 assert errors[0].node == failed_snapshot.snapshot_id - assert skipped == [snapshot_a.snapshot_id, snapshot_b.snapshot_id, snapshot_c.snapshot_id] + assert set(skipped) == {snapshot_a.snapshot_id, snapshot_b.snapshot_id, snapshot_c.snapshot_id} @pytest.mark.parametrize("tasks_num", [1, 3]) diff --git a/tests/utils/test_connection_pool.py b/tests/utils/test_connection_pool.py index 454308edce..c5926a3824 100644 --- a/tests/utils/test_connection_pool.py +++ b/tests/utils/test_connection_pool.py @@ -6,6 +6,7 @@ from sqlmesh.utils.connection_pool import ( SingletonConnectionPool, ThreadLocalConnectionPool, + ThreadLocalSharedConnectionPool, ) @@ -207,3 +208,118 @@ def thread(): assert cursor_mock_thread_one.rollback.call_count == 1 assert cursor_mock_thread_two.begin.call_count == 1 + + +def test_thread_local_connection_pool_attributes(mocker: MockerFixture): + pool = ThreadLocalConnectionPool(connection_factory=lambda: mocker.Mock()) + + pool.set_attribute("foo", "bar") + current_threadid = get_ident() + + def _in_thread(pool: ThreadLocalConnectionPool): + assert get_ident() != current_threadid + pool.set_attribute("foo", "baz") + + with ThreadPoolExecutor() as executor: + future = executor.submit(_in_thread, pool) + assert 
not future.exception() + + assert pool.get_all_attributes("foo") == ["bar", "baz"] + assert pool.get_attribute("foo") == "bar" + + pool.close_all() + + assert pool.get_all_attributes("foo") == [] + assert pool.get_attribute("foo") is None + + +def test_thread_local_shared_connection_pool(mocker: MockerFixture): + cursor_mock_thread_one = mocker.Mock() + cursor_mock_thread_two = mocker.Mock() + connection_mock = mocker.Mock() + connection_mock.cursor.side_effect = [ + cursor_mock_thread_one, + cursor_mock_thread_two, + cursor_mock_thread_one, + ] + + test_thread_id = get_ident() + + connection_factory_mock = mocker.Mock(return_value=connection_mock) + pool = ThreadLocalSharedConnectionPool(connection_factory_mock) + + assert pool.get_cursor() == cursor_mock_thread_one + assert pool.get_cursor() == cursor_mock_thread_one + assert pool.get() == connection_mock + assert pool.get() == connection_mock + + def thread(): + assert pool.get_cursor() == cursor_mock_thread_two + assert pool.get_cursor() == cursor_mock_thread_two + assert pool.get() == connection_mock + assert pool.get() == connection_mock + + with ThreadPoolExecutor(max_workers=1) as executor: + executor.submit(thread).result() + + assert pool._connection is not None + assert len(pool._thread_cursors) == 2 + + pool.close_all(exclude_calling_thread=True) + + assert pool._connection is not None + assert len(pool._thread_cursors) == 1 + assert test_thread_id in pool._thread_cursors + + pool.close_cursor() + pool.close() + + assert pool.get_cursor() == cursor_mock_thread_one + + pool.close_all() + + assert connection_factory_mock.call_count == 1 + + assert cursor_mock_thread_one.close.call_count == 2 + assert connection_mock.cursor.call_count == 3 + assert connection_mock.close.call_count == 1 + + +def test_thread_local_shared_connection_pool_close(mocker: MockerFixture): + connection_mock = mocker.Mock() + cursor_mock = mocker.Mock() + connection_mock.cursor.return_value = cursor_mock + + connection_factory_mock 
= mocker.Mock(return_value=connection_mock) + pool = ThreadLocalSharedConnectionPool(connection_factory_mock) + + # First time we get a connection + pool.get() + pool.get() + pool.get_cursor() + pool.get_cursor() + + # This shouldn't close the connection, only the cursor + pool.close() + pool.get() + pool.get() + pool.get_cursor() + + pool.get_cursor() + # This shouldn't close the connection either + pool.close_all(exclude_calling_thread=True) + + pool.get() + pool.get() + # Now this should close the connection + pool.close_all() + + # Re-open the connection + pool.get() + pool.get() + # Close it again + pool.close_all() + + assert cursor_mock.close.call_count == 2 + assert connection_factory_mock.call_count == 2 + assert connection_mock.close.call_count == 2 diff --git a/tests/utils/test_dag.py b/tests/utils/test_dag.py index 58b53c4cb0..7c142ee4a0 100644 --- a/tests/utils/test_dag.py +++ b/tests/utils/test_dag.py @@ -44,6 +44,11 @@ def test_sorted(): assert result[6] == "a" +def test_upstream(): + dag = DAG({"a": {"b", "c"}, "b": {"d", "e"}, "c": {"f", "g"}}) + assert dag.upstream("a") == {"b", "c", "d", "e", "f", "g"} + + def test_sorted_with_cycles(): dag = DAG({"a": {}, "b": {"a"}, "c": {"b"}, "d": {"b", "e"}, "e": {"b", "d"}}) @@ -52,8 +57,7 @@ def test_sorted_with_cycles(): expected_error_message = ( "Detected a cycle in the DAG. Please make sure there are no circular references between nodes.\n" - "Last nodes added to the DAG: c\n" - "Possible candidates to check for circular references: d, e" + "Cycle:\nd ->\ne ->\nd" ) assert expected_error_message == str(ex.value) @@ -65,7 +69,7 @@ def test_sorted_with_cycles(): expected_error_message = ( "Detected a cycle in the DAG. 
Please make sure there are no circular references between nodes.\n" - "Possible candidates to check for circular references: a, b, c" + "Cycle:\na ->\nb ->\nc ->\na" ) assert expected_error_message == str(ex.value) @@ -76,11 +80,11 @@ def test_sorted_with_cycles(): dag.sorted expected_error_message = ( - "Last nodes added to the DAG: c\n" - "Possible candidates to check for circular references: b, d" + "Detected a cycle in the DAG. Please make sure there are no circular references between nodes.\n" + + "Cycle:\nb ->\nd ->\nb" ) - assert expected_error_message in str(ex.value) + assert expected_error_message == str(ex.value) def test_reversed_graph(): diff --git a/tests/utils/test_date.py b/tests/utils/test_date.py index 0f1a9c4239..cb35a6973c 100644 --- a/tests/utils/test_date.py +++ b/tests/utils/test_date.py @@ -2,14 +2,17 @@ from datetime import date, datetime import pytest -from freezegun import freeze_time +import time_machine from sqlglot import exp +import pandas as pd # noqa: TID253 from sqlmesh.utils.date import ( UTC, TimeLike, date_dict, - is_catagorical_relative_expression, + format_tz_datetime, + is_categorical_relative_expression, + is_relative, make_inclusive, to_datetime, to_time_column, @@ -57,7 +60,7 @@ def test_to_datetime() -> None: ], ) def test_to_datetime_with_expressions(expression, result) -> None: - with freeze_time("2023-01-20 12:30:30"): + with time_machine.travel("2023-01-20 12:30:30 UTC", tick=False): assert to_datetime(expression) == result @@ -68,25 +71,25 @@ def test_to_timestamp() -> None: @pytest.mark.parametrize( "start_in, end_in, start_out, end_out", [ - ("2020-01-01", "2020-01-01", "2020-01-01", "2020-01-01 23:59:59.999999"), - ("2020-01-01", date(2020, 1, 1), "2020-01-01", "2020-01-01 23:59:59.999999"), + ("2020-01-01", "2020-01-01", "2020-01-01", "2020-01-01 23:59:59.999999+00:00"), + ("2020-01-01", date(2020, 1, 1), "2020-01-01", "2020-01-01 23:59:59.999999+00:00"), ( date(2020, 1, 1), date(2020, 1, 1), "2020-01-01", - 
"2020-01-01 23:59:59.999999", + "2020-01-01 23:59:59.999999+00:00", ), ( "2020-01-01", "2020-01-01 12:00:00", "2020-01-01", - "2020-01-01 11:59:59.999999", + "2020-01-01 11:59:59.999999+00:00", ), ( "2020-01-01", to_datetime("2020-01-02"), "2020-01-01", - "2020-01-01 23:59:59.999999", + "2020-01-01 23:59:59.999999+00:00", ), ], ) @@ -97,6 +100,47 @@ def test_make_inclusive(start_in, end_in, start_out, end_out) -> None: ) +@pytest.mark.parametrize( + "start_in, end_in, start_out, end_out, dialect", + [ + ("2020-01-01", "2020-01-01", "2020-01-01", "2020-01-01 23:59:59.999999999+00:00", "tsql"), + ( + "2020-01-01", + date(2020, 1, 1), + "2020-01-01", + "2020-01-01 23:59:59.999999999+00:00", + "tsql", + ), + ( + date(2020, 1, 1), + date(2020, 1, 1), + "2020-01-01", + "2020-01-01 23:59:59.999999999+00:00", + "tsql", + ), + ( + "2020-01-01", + "2020-01-01 12:00:00", + "2020-01-01", + "2020-01-01 11:59:59.999999999+00:00", + "tsql", + ), + ( + "2020-01-01", + to_datetime("2020-01-02"), + "2020-01-01", + "2020-01-01 23:59:59.999999999+00:00", + "tsql", + ), + ], +) +def test_make_inclusive_tsql(start_in, end_in, start_out, end_out, dialect) -> None: + assert make_inclusive(start_in, end_in, "tsql") == ( + to_datetime(start_out), + pd.Timestamp(end_out), + ) + + @pytest.mark.parametrize( "expression, result", [ @@ -120,7 +164,7 @@ def test_make_inclusive(start_in, end_in, start_out, end_out) -> None: ], ) def test_is_catagorical_relative_expression(expression, result): - assert is_catagorical_relative_expression(expression) == result + assert is_categorical_relative_expression(expression) == result def test_to_ts(): @@ -134,53 +178,77 @@ def test_to_tstz(): @pytest.mark.parametrize( - "time_column, time_column_type, time_column_format, result", + "time_column, time_column_type, dialect, time_column_format, result", [ ( exp.null(), exp.DataType.build("TIMESTAMP"), + "", None, "CAST(NULL AS TIMESTAMP)", ), ( "2020-01-01 00:00:00+00:00", exp.DataType.build("DATE"), + "", None, 
"CAST('2020-01-01' AS DATE)", ), ( "2020-01-01 00:00:00+00:00", exp.DataType.build("TIMESTAMPTZ"), + "", None, "CAST('2020-01-01 00:00:00+00:00' AS TIMESTAMPTZ)", ), ( "2020-01-01 00:00:00+00:00", exp.DataType.build("TIMESTAMP"), + "", None, "CAST('2020-01-01 00:00:00' AS TIMESTAMP)", ), ( "2020-01-01 00:00:00+00:00", exp.DataType.build("TEXT"), + "", "%Y-%m-%dT%H:%M:%S%z", "'2020-01-01T00:00:00+0000'", ), ( "2020-01-01 00:00:00+00:00", exp.DataType.build("INT"), + "", "%Y%m%d", "20200101", ), + ( + "2020-01-01 00:00:00+00:00", + exp.DataType.build("TIMESTAMPTZ"), + "tsql", + "%Y%m%d", + "CAST('2020-01-01 00:00:00+00:00' AS DATETIMEOFFSET)", + ), + ( + pd.Timestamp("2020-01-01 00:00:00.1234567+00:00"), + exp.DataType.build("DATETIME2", dialect="tsql"), + "tsql", + None, + "CAST('2020-01-01 00:00:00.123456700' AS DATETIME2)", + ), ], ) def test_to_time_column( time_column: t.Union[TimeLike, exp.Null], time_column_type: exp.DataType, + dialect: str, time_column_format: t.Optional[str], result: str, ): - assert to_time_column(time_column, time_column_type, time_column_format).sql() == result + assert ( + to_time_column(time_column, time_column_type, dialect, time_column_format).sql(dialect) + == result + ) def test_date_dict(): @@ -190,6 +258,10 @@ def test_date_dict(): "execution_dt": datetime(2020, 1, 2, 1, 0, 0, tzinfo=UTC), "start_dt": datetime(2020, 1, 1, 0, 0, 0, tzinfo=UTC), "end_dt": datetime(2020, 1, 2, 0, 0, 0, tzinfo=UTC), + "latest_dtntz": datetime(2020, 1, 2, 1, 0, 0, tzinfo=None), + "execution_dtntz": datetime(2020, 1, 2, 1, 0, 0, tzinfo=None), + "start_dtntz": datetime(2020, 1, 1, 0, 0, 0, tzinfo=None), + "end_dtntz": datetime(2020, 1, 2, 0, 0, 0, tzinfo=None), "latest_date": date(2020, 1, 2), "execution_date": date(2020, 1, 2), "start_date": date(2020, 1, 1), @@ -219,3 +291,55 @@ def test_date_dict(): "start_hour": 0, "end_hour": 0, } + + +@pytest.mark.parametrize( + "start, end, expected_start_dt, expected_end_dt", + [ + ( + "2020-01-01 
00:00:00.1234567", + "2020-01-02", + to_datetime("2020-01-01 00:00:00.1234567+00:00"), + pd.Timestamp("2020-01-02 23:59:59.999999999+00:00"), + ), + ( + "2020-01-01 00:00:00.1234567", + "2020-01-02 00:00:00.1234567", + to_datetime("2020-01-01 00:00:00.1234567+00:00"), + pd.Timestamp("2020-01-02 00:00:00.123455999+00:00"), + ), + ( + "2020-01-01 00:00:00.1234567", + "2020-01-02 00:00:00", + to_datetime("2020-01-01 00:00:00.1234567+00:00"), + pd.Timestamp("2020-01-01 23:59:59.999999999+00:00"), + ), + ], +) +def test_tsql_date_dict(start, end, expected_start_dt, expected_end_dt): + resp = date_dict( + "2020-01-02 01:00:00", + *make_inclusive(start, end, "tsql"), + ) + assert resp["start_dt"] == expected_start_dt + assert resp["end_dt"] == expected_end_dt + + +def test_format_tz_datetime(): + test_datetime = to_datetime("2020-01-01 00:00:00") + assert format_tz_datetime(test_datetime) == "2020-01-01 12:00AM UTC" + assert format_tz_datetime(test_datetime, format_string=None) == "2020-01-01 00:00:00+00:00" + + +def test_is_relative(): + assert is_relative("1 week ago") + assert is_relative("1 week") + assert is_relative("1 day ago") + assert is_relative("yesterday") + + assert not is_relative("2024-01-01") + assert not is_relative("2024-01-01 01:02:03") + assert not is_relative(to_datetime("2024-01-01 01:02:03")) + assert not is_relative(to_timestamp("2024-01-01 01:02:03")) + assert not is_relative(to_datetime("1 week ago")) + assert not is_relative(to_timestamp("1 day ago")) diff --git a/tests/utils/test_git_client.py b/tests/utils/test_git_client.py new file mode 100644 index 0000000000..13eecf294b --- /dev/null +++ b/tests/utils/test_git_client.py @@ -0,0 +1,173 @@ +import subprocess +from pathlib import Path +import pytest +from sqlmesh.utils.git import GitClient + + +@pytest.fixture +def git_repo(tmp_path: Path) -> Path: + repo_path = tmp_path / "test_repo" + repo_path.mkdir() + subprocess.run(["git", "init", "-b", "main"], cwd=repo_path, check=True, 
capture_output=True) + return repo_path + + +def test_git_uncommitted_changes(git_repo: Path): + git_client = GitClient(git_repo) + + test_file = git_repo / "model.sql" + test_file.write_text("SELECT 1 AS a") + subprocess.run(["git", "add", "model.sql"], cwd=git_repo, check=True, capture_output=True) + subprocess.run( + [ + "git", + "-c", + "user.name=Max", + "-c", + "user.email=max@rb.com", + "commit", + "-m", + "Initial commit", + ], + cwd=git_repo, + check=True, + capture_output=True, + ) + assert git_client.list_uncommitted_changed_files() == [] + + # make an unstaged change and see that it is listed + test_file.write_text("SELECT 2 AS a") + uncommitted = git_client.list_uncommitted_changed_files() + assert len(uncommitted) == 1 + assert uncommitted[0].name == "model.sql" + + # stage the change and test that it is still detected + subprocess.run(["git", "add", "model.sql"], cwd=git_repo, check=True, capture_output=True) + uncommitted = git_client.list_uncommitted_changed_files() + assert len(uncommitted) == 1 + assert uncommitted[0].name == "model.sql" + + +def test_git_both_staged_and_unstaged_changes(git_repo: Path): + git_client = GitClient(git_repo) + + file1 = git_repo / "model1.sql" + file2 = git_repo / "model2.sql" + file1.write_text("SELECT 1") + file2.write_text("SELECT 2") + subprocess.run(["git", "add", "."], cwd=git_repo, check=True, capture_output=True) + subprocess.run( + [ + "git", + "-c", + "user.name=Max", + "-c", + "user.email=max@rb.com", + "commit", + "-m", + "Initial commit", + ], + cwd=git_repo, + check=True, + capture_output=True, + ) + + # stage file1 + file1.write_text("SELECT 10") + subprocess.run(["git", "add", "model1.sql"], cwd=git_repo, check=True, capture_output=True) + + # modify file2 but don't stage it! 
+ file2.write_text("SELECT 20") + + # both should be detected + uncommitted = git_client.list_uncommitted_changed_files() + assert len(uncommitted) == 2 + file_names = {f.name for f in uncommitted} + assert file_names == {"model1.sql", "model2.sql"} + + +def test_git_untracked_files(git_repo: Path): + git_client = GitClient(git_repo) + initial_file = git_repo / "initial.sql" + initial_file.write_text("SELECT 0") + subprocess.run(["git", "add", "initial.sql"], cwd=git_repo, check=True, capture_output=True) + subprocess.run( + [ + "git", + "-c", + "user.name=Max", + "-c", + "user.email=max@rb.com", + "commit", + "-m", + "Initial commit", + ], + cwd=git_repo, + check=True, + capture_output=True, + ) + + new_file = git_repo / "new_model.sql" + new_file.write_text("SELECT 1") + + # untracked file should not appear in uncommitted changes + assert git_client.list_uncommitted_changed_files() == [] + + # but in untracked + untracked = git_client.list_untracked_files() + assert len(untracked) == 1 + assert untracked[0].name == "new_model.sql" + + +def test_git_committed_changes(git_repo: Path): + git_client = GitClient(git_repo) + + test_file = git_repo / "model.sql" + test_file.write_text("SELECT 1") + subprocess.run(["git", "add", "model.sql"], cwd=git_repo, check=True, capture_output=True) + subprocess.run( + [ + "git", + "-c", + "user.name=Max", + "-c", + "user.email=max@rb.com", + "commit", + "-m", + "Initial commit", + ], + cwd=git_repo, + check=True, + capture_output=True, + ) + + subprocess.run( + ["git", "checkout", "-b", "feature"], + cwd=git_repo, + check=True, + capture_output=True, + ) + + test_file.write_text("SELECT 2") + subprocess.run(["git", "add", "model.sql"], cwd=git_repo, check=True, capture_output=True) + subprocess.run( + [ + "git", + "-c", + "user.name=Max", + "-c", + "user.email=max@rb.com", + "commit", + "-m", + "Update on feature branch", + ], + cwd=git_repo, + check=True, + capture_output=True, + ) + + committed = 
git_client.list_committed_changed_files(target_branch="main") + assert len(committed) == 1 + assert committed[0].name == "model.sql" + + assert git_client.list_uncommitted_changed_files() == [] diff --git a/tests/utils/test_helpers.py b/tests/utils/test_helpers.py index 48008064a9..20a544512e 100644 --- a/tests/utils/test_helpers.py +++ b/tests/utils/test_helpers.py @@ -1,7 +1,10 @@ import pytest +from functools import wraps from sqlglot import expressions from sqlglot.optimizer.annotate_types import annotate_types +from sqlmesh.core.console import set_console, get_console, TerminalConsole + from sqlmesh.utils import columns_to_types_all_known @@ -72,3 +75,19 @@ ) def test_columns_to_types_all_known(columns_to_types, expected) -> None: assert columns_to_types_all_known(columns_to_types) == expected + + +def use_terminal_console(func): + @wraps(func) + def test_wrapper(*args, **kwargs): + orig_console = get_console() + try: + new_console = TerminalConsole() + new_console.console.width = 80 + new_console.console.no_color = True + set_console(new_console) + func(*args, **kwargs) + finally: + set_console(orig_console) + + return test_wrapper diff --git a/tests/utils/test_jinja.py b/tests/utils/test_jinja.py index 05c6d9e524..1cf7c1bf95 100644 --- a/tests/utils/test_jinja.py +++ b/tests/utils/test_jinja.py @@ -49,6 +49,10 @@ def test_macro_registry_render(): "macro_a_a", ] + assert ( + extractor.extract("""{% set foo = bar | replace("'", "\\"") %}""", dialect="bigquery") == {} + ) + def test_macro_registry_render_nested_self_package_references(): package_a = """ @@ -280,3 +284,48 @@ def test_find_call_names(): ("package", "package_macro"), ("'stringval'", "function"), ] + + +def test_dbt_adapter_macro_scope(): + package_a = """ +{% macro spark__macro_a() %} +macro_a +{% endmacro %}""" + + extractor = MacroExtractor() + registry = JinjaMacroRegistry() + + macros = extractor.extract(package_a) + macros["spark__macro_a"].is_top_level = True + + registry.add_macros(macros, 
package="package_a") + + rendered = registry.build_environment().from_string("{{ spark__macro_a() }}").render() + assert rendered.strip() == "macro_a" + + +def test_macro_registry_to_expressions_sorted(): + refs = AttributeDict( + { + "payments": { + "database": "jaffle_shop", + "schema": "main", + "nested": {"foo": "bar", "baz": "bing"}, + }, + "orders": {"schema": "main", "database": "jaffle_shop", "nested_list": ["b", "a", "c"]}, + } + ) + + registry = JinjaMacroRegistry() + registry.add_globals({"sources": {}, "refs": refs}) + + # Ensure that the AttributeDict string representation is sorted + # in order to prevent an unexpected *visual* diff in ModelDiff + # (note that the actual diff is based on the data hashes, so this is purely visual) + expressions = registry.to_expressions() + assert len(expressions) == 1 + assert ( + expressions[0].sql(dialect="duckdb") + == "refs = {'orders': {'database': 'jaffle_shop', 'nested_list': ['a', 'b', 'c'], 'schema': 'main'}, 'payments': {'database': 'jaffle_shop', 'nested': {'baz': 'bing', 'foo': 'bar'}, 'schema': 'main'}}\n" + "sources = {}" + ) diff --git a/tests/utils/test_lineage_description.py b/tests/utils/test_lineage_description.py new file mode 100644 index 0000000000..e7053e3bcc --- /dev/null +++ b/tests/utils/test_lineage_description.py @@ -0,0 +1,32 @@ +from sqlmesh.core.context import Context +from sqlmesh.utils.lineage import generate_markdown_description + + +def test_model_description() -> None: + context = Context(paths=["examples/sushi"]) + + model_no_description = context.get_model("sushi.order_items") + markdown = generate_markdown_description(model_no_description) + + assert markdown == ( + "| Column | Type | Description |\n" + "|--------|------|-------------|\n" + "| id | INT | |\n" + "| order_id | INT | |\n" + "| item_id | INT | |\n" + "| quantity | INT | |\n" + "| event_date | DATE | |" + ) + + model_with_description = context.get_model("sushi.customers") + markdown = 
generate_markdown_description(model_with_description) + + assert markdown == ( + "Sushi customer data\n" + "\n" + "| Column | Type | Description |\n" + "|--------|------|-------------|\n" + "| customer_id | INT | customer_id uniquely identifies customers |\n" + "| status | TEXT | |\n" + "| zip | TEXT | |" + ) diff --git a/tests/utils/test_metaprogramming.py b/tests/utils/test_metaprogramming.py index f236a75258..4e55ae490e 100644 --- a/tests/utils/test_metaprogramming.py +++ b/tests/utils/test_metaprogramming.py @@ -1,19 +1,28 @@ import typing as t +from contextlib import contextmanager from dataclasses import dataclass from pathlib import Path +from tenacity import retry, stop_after_attempt -import pandas as pd +import re +import pandas as pd # noqa: TID253 import pytest import sqlglot from pytest_mock.plugin import MockerFixture -from sqlglot.expressions import to_table +from sqlglot import exp +from sqlglot import exp as expressions +from sqlglot.expressions import SQLGLOT_META, to_table +from sqlglot.optimizer.pushdown_projections import SELECT_ALL import tests.utils.test_date as test_date +from sqlmesh.core.dialect import normalize_model_name from sqlmesh.core import constants as c +from sqlmesh.core.macros import RuntimeStage from sqlmesh.utils.errors import SQLMeshError from sqlmesh.utils.metaprogramming import ( Executable, ExecutableKind, + _dict_sort, build_env, func_globals, normalize_source, @@ -40,26 +49,27 @@ def test_print_exception(mocker: MockerFixture): except Exception as ex: print_exception(ex, test_env, out_mock) - expected_message = f"""Traceback (most recent call last): + expected_message = r""" File ".*?.tests.utils.test_metaprogramming\.py", line 48, in test_print_exception + eval\("test_fun\(\)", env\).* - File "{__file__}", line 39, in test_print_exception - eval("test_fun()", env) - - File "", line 1, in - - File '/test/path.py' (or imported file), line 2, in test_fun - def test_fun(): - raise RuntimeError("error") + File '/test/path.py' 
\(or imported file\), line 2, in test_fun + def test_fun\(\): + raise RuntimeError\("error"\) + RuntimeError: error +""" + actual_message = out_mock.write.call_args_list[0][0][0] + assert isinstance(actual_message, str) + expected_message = "".join(expected_message.split()) + actual_message = "".join(actual_message.split()) -RuntimeError: error -""" - out_mock.write.assert_called_once_with(expected_message) + assert re.match(expected_message, actual_message) X = 1 Y = 2 Z = 3 +W = 0 my_lambda = lambda: print("z") # noqa: E731 @@ -73,7 +83,18 @@ class DataClass: x: int +class ReferencedClass: + def __init__(self, value: int): + self.value = value + + def get_value(self) -> int: + return self.value + + class MyClass: + def __init__(self, x: int): + self.helper = ReferencedClass(x * 2) + @staticmethod def foo(): return KLASS_X @@ -85,6 +106,13 @@ def bar(cls): def baz(self): return KLASS_Z + def use_referenced(self, value: int) -> int: + ref = ReferencedClass(value) + return ref.get_value() + + def compute_with_reference(self) -> int: + return self.helper.get_value() + 10 + def other_func(a: int) -> int: import sqlglot @@ -93,41 +121,82 @@ def other_func(a: int) -> int: pd.DataFrame([{"x": 1}]) to_table("y") my_lambda() # type: ignore - return X + a + obj = MyClass(a) + return X + a + W + obj.compute_with_reference() + + +@contextmanager +def sample_context_manager(): + yield + + +@retry(stop=stop_after_attempt(3)) +def fetch_data(): + return "'test data'" -def noop_metadata() -> None: - return None +def custom_decorator(_func): + def wrapper(*args, **kwargs): + return _func(*args, **kwargs) + return wrapper -setattr(noop_metadata, c.SQLMESH_METADATA, True) +@custom_decorator +def function_with_custom_decorator(): + return -def main_func(y: int) -> int: + +def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2) -> int: """DOC STRING""" sqlglot.parse_one("1") - MyClass() + MyClass(47) DataClass(x=y) - noop_metadata() + 
normalize_model_name("test" + SQLGLOT_META) + fetch_data() + function_with_custom_decorator() def closure(z: int) -> int: return z + Z + with sample_context_manager(): + pass + return closure(y) + other_func(Y) +def macro1() -> str: + print("macro1 hello there") + print(RuntimeStage.CREATING) + return "1" + + +def macro2() -> str: + print("macro2 hello there") + print(RuntimeStage.LOADING) + return "2" + + def test_func_globals() -> None: assert func_globals(main_func) == { "Y": 2, "Z": 3, "DataClass": DataClass, "MyClass": MyClass, - "noop_metadata": noop_metadata, + "normalize_model_name": normalize_model_name, "other_func": other_func, "sqlglot": sqlglot, + "exp": exp, + "expressions": exp, + "fetch_data": fetch_data, + "sample_context_manager": sample_context_manager, + "function_with_custom_decorator": function_with_custom_decorator, + "SQLGLOT_META": SQLGLOT_META, } assert func_globals(other_func) == { "X": 1, + "W": 0, + "MyClass": MyClass, "my_lambda": my_lambda, "pd": pd, "to_table": to_table, @@ -150,14 +219,19 @@ def closure() -> int: def test_normalize_source() -> None: assert ( normalize_source(main_func) - == """def main_func(y: int): + == """def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2 + ): sqlglot.parse_one('1') - MyClass() + MyClass(47) DataClass(x=y) - noop_metadata() + normalize_model_name('test' + SQLGLOT_META) + fetch_data() + function_with_custom_decorator() def closure(z: int): return z + Z + with sample_context_manager(): + pass return closure(y) + other_func(Y)""" ) @@ -169,40 +243,58 @@ def closure(z: int): pd.DataFrame([{'x': 1}]) to_table('y') my_lambda() - return X + a""" + obj = MyClass(a) + return X + a + W + obj.compute_with_reference()""" ) def test_serialize_env_error() -> None: with pytest.raises(SQLMeshError): # pretend to be the module pandas - serialize_env({"test_date": test_date}, path=Path("tests/utils")) + serialize_env({"test_date": (test_date, None)}, path=Path("tests/utils")) + + with 
pytest.raises(SQLMeshError): + serialize_env({"select_all": (SELECT_ALL, None)}, path=Path("tests/utils")) def test_serialize_env() -> None: - env: t.Dict[str, t.Any] = {} path = Path("tests/utils") + env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {} + build_env(main_func, env=env, name="MAIN", path=path) - env = serialize_env(env, path=path) # type: ignore + serialized_env = serialize_env(env, path=path) # type: ignore + assert prepare_env(serialized_env) - assert env == { + expected_env = { "MAIN": Executable( name="main_func", alias="MAIN", path="test_metaprogramming.py", - payload="""def main_func(y: int): + payload="""def main_func(y: int, foo=exp.true(), *, bar=expressions.Literal.number(1) + 2 + ): sqlglot.parse_one('1') - MyClass() + MyClass(47) DataClass(x=y) - noop_metadata() + normalize_model_name('test' + SQLGLOT_META) + fetch_data() + function_with_custom_decorator() def closure(z: int): return z + Z + with sample_context_manager(): + pass return closure(y) + other_func(Y)""", ), "X": Executable(payload="1", kind=ExecutableKind.VALUE), "Y": Executable(payload="2", kind=ExecutableKind.VALUE), "Z": Executable(payload="3", kind=ExecutableKind.VALUE), + "W": Executable(payload="0", kind=ExecutableKind.VALUE), + "_GeneratorContextManager": Executable( + payload="from contextlib import _GeneratorContextManager", kind=ExecutableKind.IMPORT + ), + "contextmanager": Executable( + payload="from contextlib import contextmanager", kind=ExecutableKind.IMPORT + ), "KLASS_X": Executable(payload="1", kind=ExecutableKind.VALUE), "KLASS_Y": Executable(payload="2", kind=ExecutableKind.VALUE), "KLASS_Z": Executable(payload="3", kind=ExecutableKind.VALUE), @@ -224,6 +316,9 @@ class DataClass: path="test_metaprogramming.py", payload="""class MyClass: + def __init__(self, x: int): + self.helper = ReferencedClass(x * 2) + @staticmethod def foo(): return KLASS_X @@ -233,24 +328,52 @@ def bar(cls): return KLASS_Y def baz(self): - return KLASS_Z""", + return KLASS_Z + + 
def use_referenced(self, value: int): + ref = ReferencedClass(value) + return ref.get_value() + + def compute_with_reference(self): + return self.helper.get_value() + 10""", + ), + "ReferencedClass": Executable( + kind=ExecutableKind.DEFINITION, + name="ReferencedClass", + path="test_metaprogramming.py", + payload="""class ReferencedClass: + + def __init__(self, value: int): + self.value = value + + def get_value(self): + return self.value""", ), "dataclass": Executable( payload="from dataclasses import dataclass", kind=ExecutableKind.IMPORT ), "pd": Executable(payload="import pandas as pd", kind=ExecutableKind.IMPORT), "sqlglot": Executable(kind=ExecutableKind.IMPORT, payload="import sqlglot"), + "exp": Executable(kind=ExecutableKind.IMPORT, payload="import sqlglot.expressions as exp"), + "expressions": Executable( + kind=ExecutableKind.IMPORT, payload="import sqlglot.expressions as expressions" + ), + "func": Executable( + payload="""@contextmanager +def sample_context_manager(): + yield""", + name="sample_context_manager", + path="test_metaprogramming.py", + alias="func", + ), "my_lambda": Executable( name="my_lambda", path="test_metaprogramming.py", payload="my_lambda = lambda : print('z')", ), - "noop_metadata": Executable( - name="noop_metadata", - path="test_metaprogramming.py", - payload="""def noop_metadata(): - return None""", - is_metadata=True, + "normalize_model_name": Executable( + payload="from sqlmesh.core.dialect import normalize_model_name", + kind=ExecutableKind.IMPORT, ), "other_func": Executable( name="other_func", @@ -261,6 +384,257 @@ def baz(self): pd.DataFrame([{'x': 1}]) to_table('y') my_lambda() - return X + a""", + obj = MyClass(a) + return X + a + W + obj.compute_with_reference()""", + ), + "sample_context_manager": Executable( + payload="""@contextmanager +def sample_context_manager(): + yield""", + name="sample_context_manager", + path="test_metaprogramming.py", + ), + "wraps": Executable(payload="from functools import wraps", 
kind=ExecutableKind.IMPORT), + "functools": Executable(payload="import functools", kind=ExecutableKind.IMPORT), + "retry": Executable(payload="from tenacity import retry", kind=ExecutableKind.IMPORT), + "stop_after_attempt": Executable( + payload="from tenacity.stop import stop_after_attempt", kind=ExecutableKind.IMPORT + ), + "wrapped_f": Executable( + payload='''@retry(stop=stop_after_attempt(3)) +def fetch_data(): + return "'test data'"''', + name="fetch_data", + path="test_metaprogramming.py", + alias="wrapped_f", + ), + "fetch_data": Executable( + payload='''@retry(stop=stop_after_attempt(3)) +def fetch_data(): + return "'test data'"''', + name="fetch_data", + path="test_metaprogramming.py", + ), + "f": Executable( + payload='''@retry(stop=stop_after_attempt(3)) +def fetch_data(): + return "'test data'"''', + name="fetch_data", + path="test_metaprogramming.py", + alias="f", + ), + "function_with_custom_decorator": Executable( + name="wrapper", + path="test_metaprogramming.py", + payload="""def wrapper(*args, **kwargs): + return _func(*args, **kwargs)""", + alias="function_with_custom_decorator", + ), + "custom_decorator": Executable( + name="custom_decorator", + path="test_metaprogramming.py", + payload="""def custom_decorator(_func): + + def wrapper(*args, **kwargs): + return _func(*args, **kwargs) + return wrapper""", + ), + "_func": Executable( + name="function_with_custom_decorator", + path="test_metaprogramming.py", + payload="""@custom_decorator +def function_with_custom_decorator(): + return""", + alias="_func", ), + "SQLGLOT_META": Executable.value("sqlglot.meta"), } + + assert all(not is_metadata for (_, is_metadata) in env.values()) + assert serialized_env == expected_env + + # Annotate the entrypoint as "metadata only" to show how it propagates + setattr(main_func, c.SQLMESH_METADATA, True) + + env = {} + + build_env(main_func, env=env, name="MAIN", path=path) + serialized_env = serialize_env(env, path=path) # type: ignore + assert 
prepare_env(serialized_env) + + expected_env = {k: Executable(**v.dict(), is_metadata=True) for k, v in expected_env.items()} + + # Every object is treated as "metadata only", transitively + assert all(is_metadata for (_, is_metadata) in env.values()) + assert serialized_env == expected_env + + # Check that class references inside init are captured + init_globals = func_globals(MyClass.__init__) + assert "ReferencedClass" in init_globals + + env = {} + build_env(other_func, env=env, name="other_func_test", path=path) + serialized_env = serialize_env(env, path=path) + + assert "MyClass" in serialized_env + assert "ReferencedClass" in serialized_env + + prepared_env = prepare_env(serialized_env) + result = eval("other_func_test(2)", prepared_env) + assert result == 17 + + +def test_serialize_env_with_enum_import_appearing_in_two_functions() -> None: + path = Path("tests/utils") + env: t.Dict[str, t.Tuple[t.Any, t.Optional[bool]]] = {} + + build_env(macro1, env=env, name="macro1", path=path) + build_env(macro2, env=env, name="macro2", path=path) + + serialized_env = serialize_env(env, path=path) # type: ignore + assert prepare_env(serialized_env) + + expected_env = { + "RuntimeStage": Executable( + payload="from sqlmesh.core.macros import RuntimeStage", kind=ExecutableKind.IMPORT + ), + "macro1": Executable( + payload="""def macro1(): + print('macro1 hello there') + print(RuntimeStage.CREATING) + return '1'""", + name="macro1", + path="test_metaprogramming.py", + ), + "macro2": Executable( + payload="""def macro2(): + print('macro2 hello there') + print(RuntimeStage.LOADING) + return '2'""", + name="macro2", + path="test_metaprogramming.py", + ), + } + + assert serialized_env == expected_env + + +def test_dict_sort_basic_types(): + """Test dict_sort with basic Python types.""" + # Test basic types that should use standard repr + assert _dict_sort(42) == "42" + assert _dict_sort("hello") == "'hello'" + assert _dict_sort(True) == "True" + assert _dict_sort(None) == 
"None" + assert _dict_sort(3.14) == "3.14" + + +def test_dict_sort_dict_ordering(): + """Test that dict_sort produces consistent output for dicts with different key ordering.""" + # Same dict with different key ordering + dict1 = {"c": 3, "a": 1, "b": 2} + dict2 = {"a": 1, "b": 2, "c": 3} + dict3 = {"b": 2, "c": 3, "a": 1} + + repr1 = _dict_sort(dict1) + repr2 = _dict_sort(dict2) + repr3 = _dict_sort(dict3) + + # All should produce the same representation + assert repr1 == repr2 == repr3 + assert repr1 == "{'a': 1, 'b': 2, 'c': 3}" + + +def test_dict_sort_mixed_key_types(): + """Test dict_sort with mixed key types (strings and numbers).""" + dict1 = {42: "number", "string": "text", 1: "one"} + dict2 = {"string": "text", 1: "one", 42: "number"} + + repr1 = _dict_sort(dict1) + repr2 = _dict_sort(dict2) + + # Should produce consistent ordering despite mixed key types + assert repr1 == repr2 + # Numbers come before strings when sorting by string representation + assert repr1 == "{1: 'one', 42: 'number', 'string': 'text'}" + + +def test_dict_sort_nested_structures(): + """Test dict_sort with deeply nested dictionaries.""" + nested1 = {"outer": {"z": 26, "a": 1}, "list": [3, {"y": 2, "x": 1}], "simple": "value"} + + nested2 = {"simple": "value", "list": [3, {"x": 1, "y": 2}], "outer": {"a": 1, "z": 26}} + + repr1 = _dict_sort(nested1) + repr2 = _dict_sort(nested2) + + assert repr1 != repr2 + # Verify structure is maintained with sorted keys + expected1 = "{'list': [3, {'y': 2, 'x': 1}], 'outer': {'z': 26, 'a': 1}, 'simple': 'value'}" + expected2 = "{'list': [3, {'x': 1, 'y': 2}], 'outer': {'a': 1, 'z': 26}, 'simple': 'value'}" + assert repr1 == expected1 + assert repr2 == expected2 + + +def test_dict_sort_lists_and_tuples(): + """Test dict_sort preserves order for lists/tuples and doesn't sort nested dicts.""" + # Lists should be unchanged + list_with_dicts = [{"z": 26, "a": 1}, {"y": 25, "b": 2}] + list_repr = _dict_sort(list_with_dicts) + expected_list = "[{'z': 26, 
'a': 1}, {'y': 25, 'b': 2}]" + assert list_repr == expected_list + + # Tuples should be unchanged + tuple_with_dicts = ({"z": 26, "a": 1}, {"y": 25, "b": 2}) + tuple_repr = _dict_sort(tuple_with_dicts) + expected_tuple = "({'z': 26, 'a': 1}, {'y': 25, 'b': 2})" + assert tuple_repr == expected_tuple + + +def test_dict_sort_empty_containers(): + """Test dict_sort with empty containers.""" + assert _dict_sort({}) == "{}" + assert _dict_sort([]) == "[]" + assert _dict_sort(()) == "()" + + +def test_dict_sort_special_characters(): + """Test dict_sort handles special characters correctly.""" + special_dict = { + "quotes": "text with 'single' and \"double\" quotes", + "unicode": "unicode: ñáéíóú", + "newlines": "text\nwith\nnewlines", + "backslashes": "path\\to\\file", + } + + result = _dict_sort(special_dict) + + # Should be valid Python that can be evaluated + reconstructed = eval(result) + assert reconstructed == special_dict + + # Should be deterministic - same input produces same output + result2 = _dict_sort(special_dict) + assert result == result2 + + +def test_dict_sort_executable_integration(): + """Test that dict_sort works correctly with Executable.value().""" + # Test the integration with Executable.value which is the main use case + variables1 = {"env": "dev", "debug": True, "timeout": 30} + variables2 = {"timeout": 30, "debug": True, "env": "dev"} + + exec1 = Executable.value(variables1, sort_root_dict=True) + exec2 = Executable.value(variables2, sort_root_dict=True) + + # Should produce identical payloads despite different input ordering + assert exec1.payload == exec2.payload + assert exec1.payload == "{'debug': True, 'env': 'dev', 'timeout': 30}" + + # Should be valid Python + reconstructed = eval(exec1.payload) + assert reconstructed == variables1 + + # non-deterministic repr should not change the payload + exec3 = Executable.value(variables1) + assert exec3.payload == "{'env': 'dev', 'debug': True, 'timeout': 30}" diff --git 
a/tests/utils/test_pydantic.py b/tests/utils/test_pydantic.py index 9a7278c3ba..b07d45acb1 100644 --- a/tests/utils/test_pydantic.py +++ b/tests/utils/test_pydantic.py @@ -1,7 +1,9 @@ +import typing as t +import pytest from functools import cached_property from sqlmesh.utils.date import TimeLike, to_date, to_datetime -from sqlmesh.utils.pydantic import PYDANTIC_MAJOR_VERSION, PydanticModel +from sqlmesh.utils.pydantic import PydanticModel, get_concrete_types_from_typehint def test_datetime_date_serialization() -> None: @@ -12,12 +14,8 @@ class Test(PydanticModel): deserialized_date = Test.parse_raw(Test(ds=to_date(target_ds)).json()) deserialized_datetime = Test.parse_raw(Test(ds=to_datetime(target_ds)).json()) - if PYDANTIC_MAJOR_VERSION >= 2: - assert deserialized_date.ds == to_date(target_ds) - assert deserialized_datetime.ds == to_datetime("2022-01-01T00:00:00+00:00") - else: - assert deserialized_date.ds == target_ds - assert deserialized_datetime.ds == "2022-01-01T00:00:00+00:00" + assert deserialized_date.ds == to_date(target_ds) + assert deserialized_datetime.ds == to_datetime("2022-01-01T00:00:00+00:00") def test_pydantic_2_equality() -> None: @@ -59,3 +57,33 @@ def private(self) -> str: model_2_b = TestModel2(name="a") assert hash(model_2_a) == hash(model_2_b) assert hash(model_a) != hash(model_2_a) + + +def test_pydantic_dict_default_args_override() -> None: + class TestModel(PydanticModel): + name: str + + assert TestModel(name="foo").dict(by_alias=True) + + +@pytest.mark.parametrize( + "input,output", + [ + (t.Dict[str, t.Any], {dict}), + (dict, {dict}), + (t.List[str], {list}), + (list, {list}), + (t.Tuple[str, ...], {tuple}), + (tuple, {tuple}), + (t.Set[str], {set}), + (set, {set}), + (t.Optional[t.Dict[str, t.Any]], {dict, type(None)}), + (t.Optional[t.List[str]], {list, type(None)}), + ( + t.Union[str, t.List[str], t.Dict[str, t.Any], t.Optional[t.Set[str]]], + {str, list, dict, set, type(None)}, + ), + ], +) +def 
test_get_concrete_types_from_typehint(input: t.Any, output: set[type]) -> None: + assert get_concrete_types_from_typehint(input) == output diff --git a/tests/utils/test_windows.py b/tests/utils/test_windows.py new file mode 100644 index 0000000000..196589d9c2 --- /dev/null +++ b/tests/utils/test_windows.py @@ -0,0 +1,39 @@ +import pytest +from pathlib import Path +from sqlmesh.utils.windows import IS_WINDOWS, WINDOWS_LONGPATH_PREFIX, fix_windows_path + + +@pytest.mark.skipif( + not IS_WINDOWS, reason="pathlib.Path only produces WindowsPath objects on Windows" +) +def test_fix_windows_path(): + short_path = Path("c:\\foo") + short_path_prefixed = Path(WINDOWS_LONGPATH_PREFIX + "c:\\foo") + + segments = "\\".join(["bar", "baz", "bing"] * 50) + long_path = Path("c:\\" + segments) + long_path_prefixed = Path(WINDOWS_LONGPATH_PREFIX + "c:\\" + segments) + + assert len(str(short_path.absolute)) < 260 + assert len(str(long_path.absolute)) > 260 + + # paths less than 260 chars are still prefixed because they may be being used as a base path + assert fix_windows_path(short_path) == short_path_prefixed + + # paths greater than 260 characters don't work at all without the prefix + assert fix_windows_path(long_path) == long_path_prefixed + + # multiple calls dont keep appending the same prefix + assert ( + fix_windows_path(fix_windows_path(fix_windows_path(long_path_prefixed))) + == long_path_prefixed + ) + + # paths with relative sections need to have relative sections resolved before they can be used + # since the \\?\ prefix doesnt work for paths with relative sections + assert fix_windows_path(Path("c:\\foo\\..\\bar")) == Path(WINDOWS_LONGPATH_PREFIX + "c:\\bar") + + # also check that relative sections are still resolved if they are added to a previously prefixed path + base = fix_windows_path(Path("c:\\foo")) + assert base == Path(WINDOWS_LONGPATH_PREFIX + "c:\\foo") + assert fix_windows_path(base / ".." 
/ "bar") == Path(WINDOWS_LONGPATH_PREFIX + "c:\\bar") diff --git a/tests/utils/test_yaml.py b/tests/utils/test_yaml.py index 42adfc69b7..5a2e04e5be 100644 --- a/tests/utils/test_yaml.py +++ b/tests/utils/test_yaml.py @@ -1,6 +1,7 @@ import os import pytest +from decimal import Decimal import sqlmesh.utils.yaml as yaml from sqlmesh.utils.errors import SQLMeshError @@ -41,3 +42,40 @@ def test_yaml() -> None: with pytest.raises(SQLMeshError) as ex: yaml.load("") assert "YAML source can't be empty." in str(ex.value) + + decimal_value = Decimal(123.45) + assert yaml.load(yaml.dump(decimal_value)) == str(decimal_value) + + +def test_load_keep_last_duplicate_key() -> None: + input_str = """ +name: first_name +name: second_name +name: third_name + +foo: bar + +mapping: + key: first_value + key: second_value + key: third_value + +sequence: + - one + - two +""" + # Default behavior of ruamel is to keep the first key encountered + assert yaml.load(input_str, allow_duplicate_keys=True) == { + "name": "first_name", + "foo": "bar", + "mapping": {"key": "first_value"}, + "sequence": ["one", "two"], + } + + # Test keeping last key + assert yaml.load(input_str, allow_duplicate_keys=True, keep_last_duplicate_key=True) == { + "name": "third_name", + "foo": "bar", + "mapping": {"key": "third_value"}, + "sequence": ["one", "two"], + } diff --git a/tests/web/conftest.py b/tests/web/conftest.py index 55597f5089..6b6fcaad29 100644 --- a/tests/web/conftest.py +++ b/tests/web/conftest.py @@ -1,14 +1,24 @@ from pathlib import Path import pytest +from fastapi import FastAPI from sqlmesh.core.context import Context -from web.server.main import api_console, app +from sqlmesh.core.console import set_console + +from web.server.console import api_console from web.server.settings import Settings, get_loaded_context, get_settings @pytest.fixture -def project_tmp_path(tmp_path: Path) -> Path: +def web_app() -> FastAPI: + from web.server.main import create_app + + return create_app() + + 
+@pytest.fixture +def project_tmp_path(web_app: FastAPI, tmp_path: Path): def get_settings_override() -> Settings: return Settings(project_path=tmp_path) @@ -19,26 +29,30 @@ def get_settings_override() -> Settings: """ ) - app.dependency_overrides[get_settings] = get_settings_override - return tmp_path + web_app.dependency_overrides[get_settings] = get_settings_override + yield tmp_path + web_app.dependency_overrides = {} @pytest.fixture -def project_context(project_tmp_path: Path) -> Context: - context = Context(paths=project_tmp_path, console=api_console) +def project_context(web_app: FastAPI, project_tmp_path: Path): + set_console(api_console) + context = Context(paths=project_tmp_path) def get_loaded_context_override() -> Context: return context - app.dependency_overrides[get_loaded_context] = get_loaded_context_override - return context + web_app.dependency_overrides[get_loaded_context] = get_loaded_context_override + yield context + web_app.dependency_overrides = {} @pytest.fixture -def web_sushi_context(sushi_context: Context) -> Context: +def web_sushi_context(web_app: FastAPI, sushi_context: Context): def get_context_override() -> Context: sushi_context.console = api_console return sushi_context - app.dependency_overrides[get_loaded_context] = get_context_override - return sushi_context + web_app.dependency_overrides[get_loaded_context] = get_context_override + yield sushi_context + web_app.dependency_overrides = {} diff --git a/tests/web/test_lineage.py b/tests/web/test_lineage.py index 060130c49f..0cffd3ecc3 100644 --- a/tests/web/test_lineage.py +++ b/tests/web/test_lineage.py @@ -1,17 +1,20 @@ from __future__ import annotations import pytest +from fastapi import FastAPI from fastapi.testclient import TestClient from sqlmesh.core.context import Context -from web.server.main import app pytestmark = pytest.mark.web -client = TestClient(app) +@pytest.fixture +def client(web_app: FastAPI) -> TestClient: + return TestClient(web_app) -def 
test_get_lineage(web_sushi_context: Context) -> None: + +def test_get_lineage(client: TestClient, web_sushi_context: Context) -> None: response = client.get("/api/lineage/sushi.waiters/event_date") assert response.status_code == 200 @@ -44,7 +47,7 @@ def test_get_lineage(web_sushi_context: Context) -> None: "customer_id": { "expression": 'CAST("o"."customer_id" AS INT) AS "customer_id" /* this comment should not be registered */', "models": {'"memory"."sushi"."orders"': ["customer_id"]}, - "source": '''WITH "current_marketing" AS ( + "source": """WITH "current_marketing_outer" AS ( SELECT "marketing"."customer_id" AS "customer_id", "marketing"."status" AS "status" @@ -55,10 +58,27 @@ def test_get_lineage(web_sushi_context: Context) -> None: SELECT DISTINCT CAST("o"."customer_id" AS INT) AS "customer_id" /* this comment should not be registered */ FROM "memory"."sushi"."orders" AS "o" -LEFT JOIN "current_marketing" AS "m" +LEFT JOIN ( + WITH "current_marketing" AS ( + SELECT + "current_marketing_outer"."customer_id" AS "customer_id", + "current_marketing_outer"."status" AS "status", + 2 AS "another_column" + FROM "current_marketing_outer" AS "current_marketing_outer" + ) + SELECT + "current_marketing"."customer_id" AS "customer_id", + "current_marketing"."status" AS "status", + "current_marketing"."another_column" AS "another_column" + FROM "current_marketing" AS "current_marketing" + WHERE + "current_marketing"."customer_id" <> 100 +) AS "m" ON "m"."customer_id" = "o"."customer_id" LEFT JOIN "memory"."raw"."demographics" AS "d" - ON "d"."customer_id" = "o"."customer_id"''', + ON "d"."customer_id" = "o"."customer_id" +WHERE + "o"."customer_id" > 0""", } }, '"memory"."sushi"."orders"': { @@ -74,7 +94,7 @@ def test_get_lineage(web_sushi_context: Context) -> None: } -def test_get_lineage_managed_columns(web_sushi_context: Context) -> None: +def test_get_lineage_managed_columns(client: TestClient, web_sushi_context: Context) -> None: # Get lineage of managed column 
response = client.get("/api/lineage/sushi.marketing/valid_from") assert response.status_code == 200 @@ -91,7 +111,7 @@ def test_get_lineage_managed_columns(web_sushi_context: Context) -> None: } -def test_get_lineage_single_model(project_context: Context) -> None: +def test_get_lineage_single_model(client: TestClient, project_context: Context) -> None: project_tmp_path = project_context.path models_dir = project_tmp_path / "models" models_dir.mkdir() @@ -106,7 +126,7 @@ def test_get_lineage_single_model(project_context: Context) -> None: assert response_json['"bar"']["col"]["models"] == {} -def test_get_lineage_external_model(project_context: Context) -> None: +def test_get_lineage_external_model(client: TestClient, project_context: Context) -> None: project_tmp_path = project_context.path models_dir = project_tmp_path / "models" models_dir.mkdir() @@ -126,7 +146,7 @@ def test_get_lineage_external_model(project_context: Context) -> None: assert response_json['"baz"']["col"]["models"] == {'"external_table"': ["col"]} -def test_get_lineage_cte(project_context: Context) -> None: +def test_get_lineage_cte(client: TestClient, project_context: Context) -> None: project_tmp_path = project_context.path models_dir = project_tmp_path / "models" models_dir.mkdir() @@ -167,7 +187,7 @@ def test_get_lineage_cte(project_context: Context) -> None: assert response_json['"baz"']["col"]["models"] == {'"external_table"': ["col"]} -def test_get_lineage_cte_downstream(project_context: Context) -> None: +def test_get_lineage_cte_downstream(client: TestClient, project_context: Context) -> None: project_tmp_path = project_context.path models_dir = project_tmp_path / "models" models_dir.mkdir() @@ -208,7 +228,7 @@ def test_get_lineage_cte_downstream(project_context: Context) -> None: assert response_json['"baz"']["col"]["models"] == {'"external_table"': ["col"]} -def test_get_lineage_join(project_context: Context) -> None: +def test_get_lineage_join(client: TestClient, project_context: 
Context) -> None: project_tmp_path = project_context.path models_dir = project_tmp_path / "models" models_dir.mkdir() @@ -237,7 +257,7 @@ def test_get_lineage_join(project_context: Context) -> None: assert response_json['"baz"']["price"]["models"] == {'"external_baz"': ["price"]} -def test_get_lineage_multiple_columns(project_context: Context) -> None: +def test_get_lineage_multiple_columns(client: TestClient, project_context: Context) -> None: project_tmp_path = project_context.path models_dir = project_tmp_path / "models" models_dir.mkdir() @@ -262,7 +282,7 @@ def test_get_lineage_multiple_columns(project_context: Context) -> None: assert response_json['"bar"']["multiplier"]["models"] == {'"external_bar"': ["multiplier"]} -def test_get_lineage_union(project_context: Context) -> None: +def test_get_lineage_union(client: TestClient, project_context: Context) -> None: project_tmp_path = project_context.path models_dir = project_tmp_path / "models" models_dir.mkdir() @@ -297,7 +317,7 @@ def test_get_lineage_union(project_context: Context) -> None: assert response_json['"foo"']["col"]["models"] == {'"bar"': ["col"], '"baz"': ["col"]} -def test_get_lineage_union_downstream(project_context: Context) -> None: +def test_get_lineage_union_downstream(client: TestClient, project_context: Context) -> None: project_tmp_path = project_context.path models_dir = project_tmp_path / "models" models_dir.mkdir() @@ -343,7 +363,7 @@ def test_get_lineage_union_downstream(project_context: Context) -> None: assert response_json['"qwe"']["col"]["models"] == {'"external_qwe"': ["col"]} -def test_get_lineage_cte_union(project_context: Context) -> None: +def test_get_lineage_cte_union(client: TestClient, project_context: Context) -> None: project_tmp_path = project_context.path models_dir = project_tmp_path / "models" models_dir.mkdir() @@ -386,7 +406,7 @@ def test_get_lineage_cte_union(project_context: Context) -> None: assert response_json['"baz"']["col"]["models"] == {'"external_baz"': 
["col"]} -def test_get_lineage_cte_union_downstream(project_context: Context) -> None: +def test_get_lineage_cte_union_downstream(client: TestClient, project_context: Context) -> None: project_tmp_path = project_context.path models_dir = project_tmp_path / "models" models_dir.mkdir() @@ -436,7 +456,9 @@ def test_get_lineage_cte_union_downstream(project_context: Context) -> None: assert response_json['"qwe"']["col"]["models"] == {'"external_qwe"': ["col"]} -def test_get_lineage_cte_downstream_union_downstream(project_context: Context) -> None: +def test_get_lineage_cte_downstream_union_downstream( + client: TestClient, project_context: Context +) -> None: project_tmp_path = project_context.path models_dir = project_tmp_path / "models" models_dir.mkdir() @@ -485,7 +507,9 @@ def test_get_lineage_cte_downstream_union_downstream(project_context: Context) - } -def test_get_lineage_nested_cte_union_downstream(project_context: Context) -> None: +def test_get_lineage_nested_cte_union_downstream( + client: TestClient, project_context: Context +) -> None: project_tmp_path = project_context.path models_dir = project_tmp_path / "models" models_dir.mkdir() @@ -539,7 +563,7 @@ def test_get_lineage_nested_cte_union_downstream(project_context: Context) -> No } -def test_get_lineage_subquery(project_context: Context) -> None: +def test_get_lineage_subquery(client: TestClient, project_context: Context) -> None: project_tmp_path = project_context.path models_dir = project_tmp_path / "models" models_dir.mkdir() @@ -577,7 +601,7 @@ def test_get_lineage_subquery(project_context: Context) -> None: assert response_json['"baz"']["col"]["models"] == {'"external_table"': ["col"]} -def test_get_lineage_cte_name_collision(project_context: Context) -> None: +def test_get_lineage_cte_name_collision(client: TestClient, project_context: Context) -> None: project_tmp_path = project_context.path models_dir = project_tmp_path / "models" models_dir.mkdir() @@ -622,7 +646,9 @@ def 
test_get_lineage_cte_name_collision(project_context: Context) -> None: assert response_json['"baz"']["col"]["models"] == {'"external_table"': ["col"]} -def test_get_lineage_derived_table_alias_collision(project_context: Context) -> None: +def test_get_lineage_derived_table_alias_collision( + client: TestClient, project_context: Context +) -> None: project_tmp_path = project_context.path models_dir = project_tmp_path / "models" models_dir.mkdir() @@ -661,7 +687,7 @@ def test_get_lineage_derived_table_alias_collision(project_context: Context) -> assert response_json['"baz"']["col"]["models"] == {'"external_table"': ["col"]} -def test_get_lineage_constants(project_context: Context) -> None: +def test_get_lineage_constants(client: TestClient, project_context: Context) -> None: project_tmp_path = project_context.path models_dir = project_tmp_path / "models" models_dir.mkdir() @@ -699,7 +725,7 @@ def test_get_lineage_constants(project_context: Context) -> None: assert response_json['"bar"']["col"]["models"] == {'"external_table"': ["col"]} -def test_get_lineage_quoted_columns(project_context: Context) -> None: +def test_get_lineage_quoted_columns(client: TestClient, project_context: Context) -> None: project_tmp_path = project_context.path models_dir = project_tmp_path / "models" models_dir.mkdir() diff --git a/tests/web/test_main.py b/tests/web/test_main.py index 687341b3fb..cf2220ad6d 100644 --- a/tests/web/test_main.py +++ b/tests/web/test_main.py @@ -5,24 +5,26 @@ import pyarrow as pa # type: ignore import pytest +from fastapi import FastAPI from fastapi.testclient import TestClient -from httpx import AsyncClient +from httpx import ASGITransport, AsyncClient from pytest_mock.plugin import MockerFixture from sqlmesh.core.context import Context from sqlmesh.core.environment import Environment from sqlmesh.utils.errors import PlanError from web.server.api.endpoints.files import _get_file_with_content -from web.server.main import app from web.server.settings import 
get_settings pytestmark = pytest.mark.web -client = TestClient(app) +@pytest.fixture +def client(web_app: FastAPI) -> TestClient: + return TestClient(web_app) -def test_get_files(project_tmp_path: Path) -> None: +def test_get_files(client: TestClient, project_tmp_path: Path) -> None: models_dir = project_tmp_path / "models" models_dir.mkdir() sql_file = models_dir / "foo.sql" @@ -59,7 +61,7 @@ def test_get_files(project_tmp_path: Path) -> None: } -def test_get_file(project_tmp_path: Path) -> None: +def test_get_file(client: TestClient, project_tmp_path: Path) -> None: txt_file = project_tmp_path / "foo.txt" txt_file.write_text("bar") @@ -73,12 +75,12 @@ def test_get_file(project_tmp_path: Path) -> None: } -def test_get_file_not_found() -> None: +def test_get_file_not_found(client: TestClient) -> None: response = client.get("/api/files/not_found.txt") assert response.status_code == 404 -def test_get_file_invalid_path(project_tmp_path: Path) -> None: +def test_get_file_invalid_path(client: TestClient, project_tmp_path: Path) -> None: config = project_tmp_path / "config.py" config.write_text( """from sqlmesh.core.config import Config, ModelDefaultsConfig @@ -92,7 +94,7 @@ def test_get_file_invalid_path(project_tmp_path: Path) -> None: assert response.status_code == 404 -def test_write_file(project_tmp_path: Path) -> None: +def test_write_file(client: TestClient, project_tmp_path: Path) -> None: response = client.post("/api/files/foo.txt", json={"content": "bar"}) file = _get_file_with_content(project_tmp_path / "foo.txt", "foo.txt") assert response.status_code == 204 @@ -104,7 +106,19 @@ def test_write_file(project_tmp_path: Path) -> None: } -def test_update_file(project_tmp_path: Path) -> None: +def test_write_file_non_ascii(client: TestClient, project_tmp_path: Path) -> None: + response = client.post("/api/files/foo.txt", json={"content": "何か良いこと"}) + file = _get_file_with_content(project_tmp_path / "foo.txt", "foo.txt") + assert response.status_code == 204 + assert 
file.dict() == { + "name": "foo.txt", + "path": "foo.txt", + "extension": ".txt", + "content": "何か良いこと", + } + + +def test_update_file(client: TestClient, project_tmp_path: Path) -> None: txt_file = project_tmp_path / "foo.txt" txt_file.write_text("bar") @@ -119,7 +133,7 @@ def test_update_file(project_tmp_path: Path) -> None: } -def test_rename_file(project_tmp_path: Path) -> None: +def test_rename_file(client: TestClient, project_tmp_path: Path) -> None: txt_file = project_tmp_path / "foo.txt" txt_file.write_text("bar") @@ -135,7 +149,7 @@ def test_rename_file(project_tmp_path: Path) -> None: assert not txt_file.exists() -def test_rename_file_and_keep_content(project_tmp_path: Path) -> None: +def test_rename_file_and_keep_content(client: TestClient, project_tmp_path: Path) -> None: txt_file = project_tmp_path / "foo.txt" txt_file.write_text("bar") @@ -153,12 +167,12 @@ def test_rename_file_and_keep_content(project_tmp_path: Path) -> None: assert not txt_file.exists() -def test_rename_file_not_found(project_tmp_path: Path) -> None: +def test_rename_file_not_found(client: TestClient, project_tmp_path: Path) -> None: response = client.post("/api/files/foo.txt", json={"new_path": "baz.txt"}) assert response.status_code == 404 -def test_rename_file_already_exists(project_tmp_path: Path) -> None: +def test_rename_file_already_exists(client: TestClient, project_tmp_path: Path) -> None: foo_file = project_tmp_path / "foo.txt" foo_file.write_text("foo") bar_file = project_tmp_path / "bar.txt" @@ -176,7 +190,7 @@ def test_rename_file_already_exists(project_tmp_path: Path) -> None: assert not foo_file.exists() -def test_rename_file_to_existing_directory(project_tmp_path: Path) -> None: +def test_rename_file_to_existing_directory(client: TestClient, project_tmp_path: Path) -> None: foo_file = project_tmp_path / "foo.txt" foo_file.touch() existing_dir = project_tmp_path / "existing_dir" @@ -187,12 +201,12 @@ def test_rename_file_to_existing_directory(project_tmp_path: Path) 
-> None: assert foo_file.exists() -def test_write_file_empty_body(project_tmp_path: Path) -> None: +def test_write_file_empty_body(client: TestClient, project_tmp_path: Path) -> None: response = client.post("/api/files/foo.txt", json={}) assert response.status_code == 204 -def test_delete_file(project_tmp_path: Path) -> None: +def test_delete_file(client: TestClient, project_tmp_path: Path) -> None: txt_file = project_tmp_path / "foo.txt" txt_file.write_text("bar") @@ -201,19 +215,24 @@ def test_delete_file(project_tmp_path: Path) -> None: assert not txt_file.exists() -def test_delete_file_not_found(project_tmp_path: Path) -> None: +def test_delete_file_not_found(client: TestClient, project_tmp_path: Path) -> None: response = client.delete("/api/files/not_found.txt") assert response.status_code == 404 -def test_create_directory(project_tmp_path: Path) -> None: +def test_create_directory(client: TestClient, project_tmp_path: Path) -> None: response = client.post("/api/directories/new_dir") assert response.status_code == 200 assert (project_tmp_path / "new_dir").exists() - assert response.json() == {"directories": [], "files": [], "name": "new_dir", "path": "new_dir"} + assert response.json() == { + "directories": [], + "files": [], + "name": "new_dir", + "path": "new_dir", + } -def test_create_directory_already_exists(project_tmp_path: Path) -> None: +def test_create_directory_already_exists(client: TestClient, project_tmp_path: Path) -> None: new_dir = project_tmp_path / "new_dir" new_dir.mkdir() @@ -222,7 +241,7 @@ def test_create_directory_already_exists(project_tmp_path: Path) -> None: assert response.json()["message"] == "Directory already exists" -def test_rename_directory(project_tmp_path: Path) -> None: +def test_rename_directory(client: TestClient, project_tmp_path: Path) -> None: new_dir = project_tmp_path / "new_dir" new_dir.mkdir() @@ -238,7 +257,7 @@ def test_rename_directory(project_tmp_path: Path) -> None: } -def 
test_rename_directory_already_exists_empty(project_tmp_path: Path) -> None: +def test_rename_directory_already_exists_empty(client: TestClient, project_tmp_path: Path) -> None: new_dir = project_tmp_path / "new_dir" new_dir.mkdir() existing_dir = project_tmp_path / "renamed_dir" @@ -256,7 +275,9 @@ def test_rename_directory_already_exists_empty(project_tmp_path: Path) -> None: } -def test_rename_directory_already_exists_not_empty(project_tmp_path: Path) -> None: +def test_rename_directory_already_exists_not_empty( + client: TestClient, project_tmp_path: Path +) -> None: new_dir = project_tmp_path / "new_dir" new_dir.mkdir() existing_dir = project_tmp_path / "renamed_dir" @@ -270,7 +291,7 @@ def test_rename_directory_already_exists_not_empty(project_tmp_path: Path) -> No assert new_dir.exists() -def test_rename_directory_to_existing_file(project_tmp_path: Path) -> None: +def test_rename_directory_to_existing_file(client: TestClient, project_tmp_path: Path) -> None: new_dir = project_tmp_path / "new_dir" new_dir.mkdir() existing_file = project_tmp_path / "foo.txt" @@ -282,7 +303,7 @@ def test_rename_directory_to_existing_file(project_tmp_path: Path) -> None: assert new_dir.exists() -def test_delete_directory(project_tmp_path: Path) -> None: +def test_delete_directory(client: TestClient, project_tmp_path: Path) -> None: new_dir = project_tmp_path / "new_dir" new_dir.mkdir() @@ -291,12 +312,12 @@ def test_delete_directory(project_tmp_path: Path) -> None: assert not new_dir.exists() -def test_delete_directory_not_found(project_tmp_path: Path) -> None: +def test_delete_directory_not_found(client: TestClient, project_tmp_path: Path) -> None: response = client.delete("/api/directories/fake_dir") assert response.status_code == 404 -def test_delete_directory_not_a_directory(project_tmp_path: Path) -> None: +def test_delete_directory_not_a_directory(client: TestClient, project_tmp_path: Path) -> None: txt_file = project_tmp_path / "foo.txt" txt_file.touch() @@ -305,7 +326,7 
@@ def test_delete_directory_not_a_directory(project_tmp_path: Path) -> None: assert response.json()["message"] == "Not a directory" -def test_delete_directory_not_empty(project_tmp_path: Path) -> None: +def test_delete_directory_not_empty(client: TestClient, project_tmp_path: Path) -> None: new_dir = project_tmp_path / "new_dir" new_dir.mkdir() (new_dir / "foo.txt").touch() @@ -315,7 +336,7 @@ def test_delete_directory_not_empty(project_tmp_path: Path) -> None: assert not new_dir.exists() -def test_apply(project_tmp_path: Path) -> None: +def test_apply(client: TestClient, project_tmp_path: Path) -> None: models_dir = project_tmp_path / "models" models_dir.mkdir() sql_file = models_dir / "foo.sql" @@ -329,14 +350,16 @@ def test_apply(project_tmp_path: Path) -> None: @pytest.mark.skip( reason="needs to be fixed: plan tests are failing inside coroutine and won't throw 422" ) -def test_apply_test_failures(web_sushi_context: Context, mocker: MockerFixture) -> None: +def test_apply_test_failures( + client: TestClient, web_sushi_context: Context, mocker: MockerFixture +) -> None: mocker.patch.object(web_sushi_context, "_run_plan_tests", side_effect=PlanError()) response = client.post("/api/commands/apply", json={"environment": "dev"}) assert response.status_code == 422 assert response.json()["message"] == "Unable to run a plan" -def test_plan(web_sushi_context: Context) -> None: +def test_plan(client: TestClient, web_sushi_context: Context) -> None: client.app.state.circuit_breaker = threading.Event() # type: ignore response = client.post("/api/plan", json={"environment": "dev"}) assert response.status_code == 204 @@ -345,7 +368,9 @@ def test_plan(web_sushi_context: Context) -> None: @pytest.mark.skip( reason="needs to be fixed: plan tests are failing inside coroutine and won't throw 422" ) -def test_plan_test_failures(web_sushi_context: Context, mocker: MockerFixture) -> None: +def test_plan_test_failures( + client: TestClient, web_sushi_context: Context, mocker: 
MockerFixture +) -> None: mocker.patch.object(web_sushi_context, "_run_plan_tests", side_effect=PlanError()) response = client.post("/api/plan", json={"environment": "dev"}) assert response.status_code == 422 @@ -353,21 +378,22 @@ def test_plan_test_failures(web_sushi_context: Context, mocker: MockerFixture) - @pytest.mark.asyncio -async def test_cancel() -> None: +async def test_cancel(client: TestClient) -> None: client.app.state.circuit_breaker = threading.Event() # type: ignore - async with AsyncClient(app=app, base_url="http://testserver") as _client: + transport = ASGITransport(client.app) # type: ignore + async with AsyncClient(transport=transport, base_url="http://testserver") as _client: await _client.post("/api/plan", json={"environment": "dev"}) response = await _client.post("/api/plan/cancel") assert response.status_code == 204 - assert app.state.task.cancelled() + assert client.app.state.task.cancelled() # type: ignore -def test_cancel_no_task() -> None: +def test_cancel_no_task(client: TestClient) -> None: response = client.post("/api/plan/cancel") assert response.status_code == 204 -def test_evaluate(web_sushi_context: Context) -> None: +def test_evaluate(client: TestClient, web_sushi_context: Context) -> None: response = client.post( "/api/commands/evaluate", json={ @@ -384,7 +410,7 @@ def test_evaluate(web_sushi_context: Context) -> None: assert not df.empty -def test_meta() -> None: +def test_meta(client: TestClient) -> None: from sqlmesh.cli.main import _sqlmesh_version response = client.get("/api/meta") @@ -392,14 +418,14 @@ def test_meta() -> None: assert response.json() == {"version": _sqlmesh_version(), "has_running_task": False} -def test_modules() -> None: +def test_modules(client: TestClient) -> None: settings = get_settings() response = client.get("/api/modules") assert response.status_code == 200 assert response.json() == list(settings.modules) -def test_fetchdf(web_sushi_context: Context) -> None: +def test_fetchdf(client: TestClient, 
web_sushi_context: Context) -> None: response = client.post("/api/commands/fetchdf", json={"sql": "SELECT * from sushi.top_waiters"}) assert response.status_code == 200 with pa.ipc.open_stream(response.content) as reader: @@ -407,7 +433,7 @@ def test_fetchdf(web_sushi_context: Context) -> None: assert not df.empty -def test_get_model(web_sushi_context: Context) -> None: +def test_get_model(client: TestClient, web_sushi_context: Context) -> None: response = client.get("/api/models/sushi.customers") assert response.status_code == 200 @@ -429,7 +455,7 @@ def test_get_model(web_sushi_context: Context) -> None: # TODO: add better tests for this endpoint -def test_get_models(web_sushi_context: Context) -> None: +def test_get_models(client: TestClient, web_sushi_context: Context) -> None: response = client.get("/api/models") assert response.status_code == 200 @@ -442,19 +468,19 @@ def test_get_models(web_sushi_context: Context) -> None: assert test_model.get("columns") -def test_render(web_sushi_context: Context) -> None: +def test_render(client: TestClient, web_sushi_context: Context) -> None: response = client.post("/api/commands/render", json={"model": "sushi.items"}) assert response.status_code == 200 assert response.json()["sql"] -def test_render_invalid_model(web_sushi_context: Context) -> None: +def test_render_invalid_model(client: TestClient, web_sushi_context: Context) -> None: response = client.post("/api/commands/render", json={"model": "foo.bar"}) assert response.status_code == 422 assert response.json()["message"] == "Unable to find a model" -def test_get_environments(project_context: Context) -> None: +def test_get_environments(client: TestClient, project_context: Context) -> None: response = client.get("/api/environments") assert response.status_code == 200 response_json = response.json() @@ -462,7 +488,11 @@ def test_get_environments(project_context: Context) -> None: environment = Environment.parse_obj(response_json["environments"]["prod"]) assert 
environment == Environment( - name="prod", snapshots=[], start_at="1970-01-01", plan_id="", suffix_target="schema" + name="prod", + snapshots=[], + start_at="1970-01-01", + plan_id="", + suffix_target="schema", ) assert response_json["pinned_environments"] == list(project_context.config.pinned_environments) assert ( @@ -471,15 +501,19 @@ def test_get_environments(project_context: Context) -> None: ) -def test_delete_environment_success(web_sushi_context: Context): +def test_delete_environment_success(client: TestClient, web_sushi_context: Context): response = client.delete("/api/environments/test") assert response.status_code == 204 -def test_delete_environment_failure(web_sushi_context: Context, mocker: MockerFixture): +def test_delete_environment_failure( + client: TestClient, web_sushi_context: Context, mocker: MockerFixture +): mocker.patch.object( - web_sushi_context.state_sync, "invalidate_environment", side_effect=Exception("Some error") + web_sushi_context.state_sync, + "invalidate_environment", + side_effect=Exception("Some error"), ) response = client.delete("/api/environments/test") @@ -488,7 +522,7 @@ def test_delete_environment_failure(web_sushi_context: Context, mocker: MockerFi assert response.json()["message"] == "Unable to delete environments" -def test_table_diff(web_sushi_context: Context) -> None: +def test_table_diff(client: TestClient, web_sushi_context: Context) -> None: web_sushi_context.plan( "dev", no_prompts=True, @@ -505,11 +539,10 @@ def test_table_diff(web_sushi_context: Context) -> None: }, ) assert response.status_code == 200 - assert "schema_diff" in response.json() - assert "row_diff" in response.json() + assert response.json() == None -def test_test(web_sushi_context: Context) -> None: +def test_test(client: TestClient, web_sushi_context: Context) -> None: response = client.get("/api/commands/test") assert response.status_code == 200 response_json = response.json() @@ -524,7 +557,7 @@ def test_test(web_sushi_context: Context) -> 
None: assert response_json["failures"] == [] -def test_test_failure(project_context: Context) -> None: +def test_test_failure(client: TestClient, project_context: Context) -> None: models_dir = project_context.path / "models" models_dir.mkdir() sql_file = models_dir / "foo.sql" diff --git a/tooling/README.md b/tooling/README.md new file mode 100644 index 0000000000..fe6712dc2e --- /dev/null +++ b/tooling/README.md @@ -0,0 +1,11 @@ +# Tooling + +This directory contains the tooling for building the SQLMesh project. + +## Vscode + +The `vscode` directory contains sample configs for VSCode that can be copied into the user's VSCode settings in order to make the development of SQLMesh easier. The following command will copy the sample configs into the user's VSCode settings. + +```bash +make vscode_settings +``` \ No newline at end of file diff --git a/tooling/validating_migration_numbers.sh b/tooling/validating_migration_numbers.sh new file mode 100755 index 0000000000..6997d41fe1 --- /dev/null +++ b/tooling/validating_migration_numbers.sh @@ -0,0 +1,55 @@ +#!/bin/bash + +# Navigate to the migrations directory (modify the path if necessary) +cd "sqlmesh/migrations" || exit 1 + + +# Collect all migration files matching the pattern (e.g., v0001_initial.py) +migration_files=(v*.py) + +# Initialize an array to hold migration numbers +numbers=() + +# Extract migration numbers from filenames +for file in "${migration_files[@]}"; do + if [[ $file =~ ^v0*([0-9]+)_ ]]; then + num=${BASH_REMATCH[1]} + if [[ "$num" -gt 0 ]]; then + numbers+=("$num") + fi + fi +done + +# Check if any migration files were found +if [[ ${#numbers[@]} -eq 0 ]]; then + echo "No migration files found matching the pattern 'v_.py'." 
+ exit 1 +fi + +# Check for duplicate migration numbers +duplicates=$(printf "%s\n" "${numbers[@]}" | sort | uniq -d) +if [[ -n $duplicates ]]; then + echo "Error: Duplicate migration numbers found: $duplicates" + exit 1 +fi + +# Sort the migration numbers +sorted_numbers=($(printf "%s\n" "${numbers[@]}" | sort -n)) + +# Get the first and last migration numbers +first_number="${sorted_numbers[0]}" +last_index=$((${#sorted_numbers[@]} - 1)) +last_number="${sorted_numbers[$last_index]}" + +# Check for gaps in the migration sequence +expected_numbers=($(seq "$first_number" "$last_number")) + +if [[ "${sorted_numbers[*]}" != "${expected_numbers[*]}" ]]; then + echo "Error: Missing migration numbers in sequence." + echo "Expected sequence: ${expected_numbers[*]}" + echo "Found sequence: ${sorted_numbers[*]}" + exit 1 +fi + +echo "All migration numbers are sequential and without overlaps." +exit 0 diff --git a/tooling/vscode/extensions.json b/tooling/vscode/extensions.json new file mode 100644 index 0000000000..b9df8890ab --- /dev/null +++ b/tooling/vscode/extensions.json @@ -0,0 +1,12 @@ +{ + // See http://go.microsoft.com/fwlink/?LinkId=827846 + // for the documentation about the extensions.json format + "recommendations": [ + "dbaeumer.vscode-eslint", + "amodio.tsl-problem-matcher", + "ms-vscode.extension-test-runner", + "ms-playwright.playwright", + "esbenp.prettier-vscode", + "charliermarsh.ruff" + ] +} diff --git a/tooling/vscode/launch.json b/tooling/vscode/launch.json new file mode 100644 index 0000000000..76f55db912 --- /dev/null +++ b/tooling/vscode/launch.json @@ -0,0 +1,21 @@ +// A launch configuration that compiles the extension and then opens it inside a new window +// Use IntelliSense to learn about possible attributes. +// Hover to view descriptions of existing attributes. 
+// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 +{ + "$schema": "vscode://schemas/launch", + "version": "0.2.0", + "configurations": [ + { + "name": "Run Extension", + "type": "extensionHost", + "request": "launch", + "args": [ + "${workspaceFolder}/examples/sushi", + "--extensionDevelopmentPath=${workspaceFolder}/vscode/extension" + ], + "outFiles": ["${workspaceFolder}/vscode/extension/dist/**/*.js"], + "preLaunchTask": "${defaultBuildTask}" + } + ] +} diff --git a/tooling/vscode/settings.json b/tooling/vscode/settings.json new file mode 100644 index 0000000000..6cdfa704c7 --- /dev/null +++ b/tooling/vscode/settings.json @@ -0,0 +1,27 @@ +// Place your settings in this file to overwrite default and user settings. +{ + "files.exclude": { + "vscode/extension/out": false, // set this to true to hide the "out" folder with the compiled JS files + "vscode/extension/dist": false, // set this to true to hide the "dist" folder with the compiled JS files + "vscode/react/node_modules": false, + "vscode/react/dist": false + }, + "search.exclude": { + "vscode/extension/out": true, // set this to false to include "out" folder in search results + "vscode/extension/dist": true, // set this to false to include "dist" folder in search results + "vscode/react/node_modules": true, + "vscode/react/dist": true + }, + // Turn off tsc task auto detection since we have the necessary tasks as npm scripts + "typescript.tsc.autoDetect": "off", + // Playwright configuration + // Python configuration to ensure consistent environment + "python.defaultInterpreterPath": "${env:VIRTUAL_ENV}/bin/python", + "python.terminal.activateEnvironment": true, + "terminal.integrated.env.osx": { + "PATH": "${env:VIRTUAL_ENV}/bin:${env:PATH}" + }, + "terminal.integrated.env.linux": { + "PATH": "${env:VIRTUAL_ENV}/bin:${env:PATH}" + } +} diff --git a/tooling/vscode/tasks.json b/tooling/vscode/tasks.json new file mode 100644 index 0000000000..18c5042471 --- /dev/null +++ 
b/tooling/vscode/tasks.json @@ -0,0 +1,70 @@ +// See https://go.microsoft.com/fwlink/?LinkId=733558 +// for the documentation about the tasks.json format +{ + "$schema": "vscode://schemas/tasks", + "version": "2.0.0", + "tasks": [ + { + "label": "extension-watch", + "type": "shell", + "command": "pnpm run watch", + "problemMatcher": { + "base": "$ts-webpack-watch", + "background": { + "activeOnStart": true, + "beginsPattern": "build started", + "endsPattern": "build finished" + } + }, + "isBackground": true, + "presentation": { + "reveal": "never", + "group": "watchers" + }, + "group": { + "kind": "build" + }, + "options": { + "cwd": "${workspaceFolder}/vscode/extension" + }, + "dependsOrder": "parallel" + }, + { + "label": "react-dev", + "type": "shell", + "command": "pnpm run build:watch", + "options": { + "cwd": "${workspaceFolder}/vscode/react" + }, + "group": { + "kind": "build" + }, + "isBackground": true, + "problemMatcher": { + "owner": "webpack", + "pattern": { + "regexp": "." + }, + "background": { + "activeOnStart": true, + "beginsPattern": "Generating routes", + "endsPattern": "built in" + } + }, + "presentation": { + "reveal": "never", + "group": "watchers" + }, + "dependsOrder": "parallel" + }, + { + "label": "extension-watch-develop", + "group": { + "kind": "build", + "isDefault": true + }, + "dependsOn": ["react-dev", "extension-watch"], + "dependsOrder": "parallel" + } + ] +} diff --git a/vscode/bus/.gitignore b/vscode/bus/.gitignore new file mode 100644 index 0000000000..de4d1f007d --- /dev/null +++ b/vscode/bus/.gitignore @@ -0,0 +1,2 @@ +dist +node_modules diff --git a/vscode/bus/package.json b/vscode/bus/package.json new file mode 100644 index 0000000000..024942f6c4 --- /dev/null +++ b/vscode/bus/package.json @@ -0,0 +1,19 @@ +{ + "name": "sqlmesh-extension-bus", + "private": true, + "version": "0.0.1", + "scripts": { + "ci": "pnpm run lint", + "build": "tsc", + "dev": "tsc -w", + "lint": "tsc --noEmit" + }, + "files": [ + "/dist" + ], + 
"main": "dist/index.js", + "types": "dist/index.d.ts", + "devDependencies": { + "typescript": "^5.8.3" + } +} diff --git a/vscode/bus/src/brand.ts b/vscode/bus/src/brand.ts new file mode 100644 index 0000000000..2b9c3ca37a --- /dev/null +++ b/vscode/bus/src/brand.ts @@ -0,0 +1,19 @@ +declare const __brand: unique symbol +type Brand = { [__brand]: B } + +/** + * Branded is a type that adds a brand to a type. It is a type that is used to + * ensure that the type is unique and that it is not possible to mix up types + * with the same brand. + * + * @example + * + * type UserId = Branded + * type UserName = Branded + * + * const userId = '123' as UserId + * const userName = 'John Doe' as UserName + * + * userId == userName -> compile error + */ +export type Branded = T & Brand diff --git a/vscode/bus/src/callbacks.ts b/vscode/bus/src/callbacks.ts new file mode 100644 index 0000000000..0601fd892a --- /dev/null +++ b/vscode/bus/src/callbacks.ts @@ -0,0 +1,115 @@ +import type { Result } from './result' + +export type CallbackShape = Record + +export type Callback = { + openFile: { + uri: string + } + rpcResponse: RPCResponse +} & CallbackShape + +/** + * A tuple type representing a callback event with its associated payload. + * The first element is the callback key (e.g., 'openFile', 'formatProject'). + * The second element is the payload type associated with that key. + * + * Example: + * const openFileEvent: CallbackEvent<'openFile'> = ['openFile', { path: '/path/to/file' }]; + */ +export type CallbackEvent = { + [K in keyof Callback]: { key: K; payload: Callback[K] } +}[keyof Callback] + +export type VSCodeCallbackShape = Record + +/** + * A tuple type representing a VSCode event with its associated payload. 
+ */ +export type VSCodeCallback = { + changeFocusOnFile: { + path: string + } + savedFile: { + fileUri: string + } + rpcRequest: RPCRequest +} & VSCodeCallbackShape + +export type VSCodeEvent = { + [K in keyof VSCodeCallback]: { key: K; payload: VSCodeCallback[K] } +}[keyof VSCodeCallback] + +type RPCMethodsShape = Record + +export type RPCMethods = { + get_active_file: { + params: {} + result: { + fileUri?: string + } + } + api_query: { + params: { + url: string + method: string + params: any + body: any + } + result: any + } + get_selected_model: { + params: {} + result: { + selectedModel?: any + } + } + get_all_models: { + params: {} + result: { + ok: boolean + models?: any[] + error?: string + } + } + set_selected_model: { + params: { + model: any + } + result: { + ok: boolean + selectedModel?: any + } + } + get_environments: { + params: {} + result: { + ok: boolean + environments?: Record + error?: string + } + } + run_table_diff: { + params: { + sourceModel: string + sourceEnvironment: string + targetEnvironment: string + } + result: { + ok: boolean + data?: any + error?: string + } + } +} & RPCMethodsShape + +export type RPCRequest = { + requestId: string + method: keyof RPCMethods + params: RPCMethods[keyof RPCMethods]['params'] +} + +export type RPCResponse = { + requestId: string + result: Result +} diff --git a/vscode/bus/src/result.ts b/vscode/bus/src/result.ts new file mode 100644 index 0000000000..753b8c2e22 --- /dev/null +++ b/vscode/bus/src/result.ts @@ -0,0 +1,27 @@ +/** + * A result is a value that can be either an ok or an error + */ +export type Result = { ok: true; value: T } | { ok: false; error: E } + +/** + * returns true if the result is an error + */ +export const isErr = ( + result: Result, +): result is { ok: false; error: E } => { + return !result.ok +} + +/** + * returns an ok version `Result` from a value `T` + */ +export const ok = (value: T): { ok: true; value: T } => { + return { ok: true, value } +} + +/** + * returns an error 
version `Result` from an error `E` + */ +export const err = (error: E): { ok: false; error: E } => { + return { ok: false, error } +} diff --git a/vscode/bus/tsconfig.json b/vscode/bus/tsconfig.json new file mode 100644 index 0000000000..61a624b252 --- /dev/null +++ b/vscode/bus/tsconfig.json @@ -0,0 +1,19 @@ +{ + "compilerOptions": { + "target": "ESNext", + "module": "ESNext", + "declaration": true, + "outDir": "./dist", + "strict": true, + "strictNullChecks": true, + "noImplicitAny": true, + "noUnusedLocals": true, + "noUnusedParameters": true, + "noEmit": true, + "moduleResolution": "node", + "baseUrl": "./", + "skipLibCheck": true + }, + "include": ["src/**/*"], + "exclude": ["node_modules", "dist", "../../node_modules"] +} diff --git a/vscode/extension/.gitignore b/vscode/extension/.gitignore new file mode 100644 index 0000000000..f54728090f --- /dev/null +++ b/vscode/extension/.gitignore @@ -0,0 +1,10 @@ +node_modules +dist +out +.vscode-test +.test_setup +*.vsix +LICENSE +src_react +!src_react/.gitkeep +playwright-report \ No newline at end of file diff --git a/vscode/extension/.vscodeignore b/vscode/extension/.vscodeignore new file mode 100644 index 0000000000..47b7075c62 --- /dev/null +++ b/vscode/extension/.vscodeignore @@ -0,0 +1,29 @@ +.vscode/** +.vscode-test/** +out/** +node_modules/** +src/** +.gitignore +.yarnrc +webpack.config.js +vsc-extension-quickstart.md +**/tsconfig.json +**/eslint.config.mjs +**/*.map +**/*.ts +**/.vscode-test.* +assets/logo.svg +esbuild.js +openapi.json +test-results/** +E2E_TESTING.md +**/*.test.ts +**/*.test.js +.mocharc.json +tsconfig.test.json +tsconfig.build.json +src/test/** +tests/** +.claude +.idea +.test_setup \ No newline at end of file diff --git a/vscode/extension/CHANGELOG.md b/vscode/extension/CHANGELOG.md new file mode 100644 index 0000000000..fa2519edaf --- /dev/null +++ b/vscode/extension/CHANGELOG.md @@ -0,0 +1,9 @@ +# Change Log + +All notable changes to the "vscode" extension will be documented in this 
file. + +Check [Keep a Changelog](http://keepachangelog.com/) for recommendations on how to structure this file. + +## [Unreleased] + +- Initial release \ No newline at end of file diff --git a/vscode/extension/README.md b/vscode/extension/README.md new file mode 100644 index 0000000000..64f6c3e130 --- /dev/null +++ b/vscode/extension/README.md @@ -0,0 +1,96 @@ +# SQLMesh Visual Studio Code Extension + +**Transform your data engineering workflow with intelligent SQL development and powerful lineage visualization.** + +## Overview + +The SQLMesh VSCode extension brings the power of SQLMesh directly into your editor with intelligent code assistance, interactive lineage visualization, and seamless integration with Tobiko Cloud. + +## 🚀 Quick Start + +1. **Install the extension** from the [Visual Studio Marketplace](https://marketplace.visualstudio.com/items?itemName=tobikodata.sqlmesh) +2. **Set up your Python environment** with SQLMesh +3. **Configure your Python interpreter** in VSCode +4. **Start building data pipelines!** + +For a more detailed guide, see the [VSCode Extension Guide](https://sqlmesh.readthedocs.io/en/stable/guides/vscode/). 
+ +## ✨ Features + +### 🔗 Interactive Lineage Visualization +- **Real-time lineage graphs** showing data flow between models +- **Interactive exploration** with clickable nodes +- **Model dependency tracking** across your entire project + +### 🧠 Intelligent Code Assistance +- **Smart auto-completion** for model names and SQLMesh keywords +- **Hover tooltips** with model descriptions and metadata +- **Go-to-definition** navigation for model references +- **Real-time error detection** with inline diagnostics + +### 🎨 Code Formatting & Quality +- **Automatic formatting** for SQLMesh models +- **Integrated linter** with built-in and custom rules +- **Format on save** support +- **Project-wide formatting** commands + +### ☁️ Tobiko Cloud Integration +- **Seamless authentication** within VSCode +- **Cloud project management** +- **Secure credential handling** + +## 📖 Usage + +Here's an overview of the extension's features: + +### Viewing Model Lineage +1. Open any SQLMesh model file +2. Navigate to the "Lineage" tab in the panel +3. Explore your data pipeline visually + +### Using Auto-completion +- Start typing model names or SQLMesh keywords +- Press `Ctrl+Space` to trigger suggestions +- Navigate with arrow keys and press `Enter` to accept + +### Formatting Code +- **Single file**: Right-click → "Format Document" +- **Entire project**: Command Palette → "Format SQLMesh project" +- **Auto-format**: Enable "Format on Save" in VSCode settings + +### Managing Tobiko Cloud Authentication +- **Sign in**: Command Palette → "Sign in to Tobiko Cloud" +- **Sign out**: Command Palette → "Sign out of Tobiko Cloud" +- **View status**: Check the bottom-left status bar + +## 🐛 Troubleshooting + +If you encounter issues, please refer to the [VSCode Extension Guide](https://sqlmesh.readthedocs.io/en/stable/guides/vscode/) for troubleshooting steps. 
+ +## 📚 Documentation + +- [Full SQLMesh Documentation](https://sqlmesh.readthedocs.io/) +- [VSCode Extension Guide](https://sqlmesh.readthedocs.io/en/stable/guides/vscode/) +- [Tobiko Cloud Documentation](https://docs.tobiko.cloud/) + +## 🤝 Contributing + +We welcome contributions! Please: + +1. [Report bugs](https://github.com/tobikodata/sqlmesh/issues) you encounter +2. [Request features](https://github.com/tobikodata/sqlmesh/issues) you'd like to see +3. Share feedback on your experience + +## 📄 License + +This extension is licensed under the Apache License 2.0. See [LICENSE](LICENSE) for details. + +## 🔗 Links + +- [SQLMesh GitHub Repository](https://github.com/tobikodata/sqlmesh) +- [Tobiko Data Website](https://tobikodata.com) +- [Extension Marketplace Page](https://marketplace.visualstudio.com/items?itemName=tobikodata.sqlmesh) + +--- + +**Happy data engineering!** 🚀 \ No newline at end of file diff --git a/vscode/extension/assets/images/diff.svg b/vscode/extension/assets/images/diff.svg new file mode 100644 index 0000000000..fec20deaa1 --- /dev/null +++ b/vscode/extension/assets/images/diff.svg @@ -0,0 +1,3 @@ + + + \ No newline at end of file diff --git a/vscode/extension/assets/logo.png b/vscode/extension/assets/logo.png new file mode 100644 index 0000000000..95108a963d Binary files /dev/null and b/vscode/extension/assets/logo.png differ diff --git a/vscode/extension/assets/logo.svg b/vscode/extension/assets/logo.svg new file mode 100644 index 0000000000..66cc27fc25 --- /dev/null +++ b/vscode/extension/assets/logo.svg @@ -0,0 +1,9 @@ + + + + + + + + + \ No newline at end of file diff --git a/vscode/extension/esbuild.js b/vscode/extension/esbuild.js new file mode 100644 index 0000000000..f24f83a687 --- /dev/null +++ b/vscode/extension/esbuild.js @@ -0,0 +1,65 @@ +const esbuild = require('esbuild') + +const production = process.argv.includes('--production') +const watch = process.argv.includes('--watch') + +async function main() { + const ctx = await 
esbuild.context({ + entryPoints: ['src/extension.ts'], + bundle: true, + format: 'cjs', + minify: production, + sourcemap: !production, + sourcesContent: false, + platform: 'node', + outfile: 'dist/extension.js', + external: ['vscode'], + logLevel: 'warning', + plugins: [ + /* add to the end of plugins array */ + esbuildProblemMatcherPlugin, + { + name: 'exclude-tests', + setup(build) { + build.onResolve({ filter: /\.test\.ts$/ }, args => { + return { external: true } + }) + }, + }, + ], + }) + if (watch) { + await ctx.watch() + } else { + await ctx.rebuild() + await ctx.dispose() + } +} + +/** + * @type {import('esbuild').Plugin} + */ +const esbuildProblemMatcherPlugin = { + name: 'esbuild-problem-matcher', + + setup(build) { + build.onStart(() => { + console.log('[watch] build started') + }) + build.onEnd(result => { + result.errors.forEach(({ text, location }) => { + console.error(`✘ [ERROR] ${text}`) + if (location == null) return + console.error( + ` ${location.file}:${location.line}:${location.column}:`, + ) + }) + console.log('[watch] build finished') + }) + }, +} + +main().catch(e => { + console.error(e) + process.exit(1) +}) diff --git a/vscode/extension/eslint.config.mjs b/vscode/extension/eslint.config.mjs new file mode 100644 index 0000000000..8713558998 --- /dev/null +++ b/vscode/extension/eslint.config.mjs @@ -0,0 +1,72 @@ +import eslint from '@eslint/js' +import tseslint from 'typescript-eslint' + +export default tseslint.config( + eslint.configs.recommended, + tseslint.configs.strict, + tseslint.configs.stylistic, + tseslint.configs.recommendedTypeChecked, + { + languageOptions: { + parserOptions: { + projectService: true, + tsconfigRootDir: import.meta.dirname, + }, + }, + }, + { + rules: { + 'no-fallthrough': 'error', + '@typescript-eslint/switch-exhaustiveness-check': 'error', + '@typescript-eslint/no-unsafe-assignment': 'off', + '@typescript-eslint/no-explicit-any': 'off', + '@typescript-eslint/no-non-null-assertion': 'off', + 
'@typescript-eslint/restrict-template-expressions': 'off', + '@typescript-eslint/no-unsafe-argument': 'off', + '@typescript-eslint/no-unsafe-member-access': 'off', + }, + }, + { + files: ['**/*.ts'], + ignores: ['**/*.test.ts'], + rules: { + 'no-restricted-imports': [ + 'error', + { + patterns: ['*.test', '*.test.ts', '**/test/**'], + }, + ], + }, + }, + { + files: ['**/*.test.ts'], + languageOptions: { + parserOptions: { + projectService: false, + project: './tsconfig.test.json', + tsconfigRootDir: import.meta.dirname, + }, + }, + rules: { + '@typescript-eslint/no-unsafe-call': 'off', + '@typescript-eslint/no-unsafe-member-access': 'off', + }, + }, + { + files: ['tests/**/*.spec.ts'], + rules: { + 'no-restricted-imports': [ + 'error', + { + patterns: [ + { + group: ['@playwright/test'], + message: + 'Import { test, expect, Page } from "./fixtures" instead of directly from @playwright/test', + }, + ], + }, + ], + }, + }, +) diff --git a/vscode/extension/package.json b/vscode/extension/package.json new file mode 100644 index 0000000000..35499ad68f --- /dev/null +++ b/vscode/extension/package.json @@ -0,0 +1,185 @@ +{ + "name": "sqlmesh", + "displayName": "SQLMesh", + "description": "Official SQLMesh extension for VSCode", + "publisher": "tobikodata", + "version": "0.0.7", + "repository": { + "type": "git", + "url": "https://github.com/tobikodata/sqlmesh" + }, + "main": "./dist/extension.js", + "icon": "assets/logo.png", + "engines": { + "vscode": "^1.96.0" + }, + "categories": [ + "Other" + ], + "activationEvents": [ + "onLanguage:sql", + "onLanguage:python", + "onLanguage:yaml" + ], + "extensionKind": [ + "workspace" + ], + "extensionDependencies": [ + "ms-python.python" + ], + "contributes": { + "configuration": { + "type": "object", + "title": "SQLMesh", + "properties": { + "sqlmesh.projectPaths": { + "type": "array", + "items": { + "type": "string" + }, + "default": [], + "description": "The path to the SQLMesh project. 
If not set, the extension will try to find the project root automatically. If set, the extension will use the project root as the workspace path, e.g. it will run `sqlmesh` and `sqlmesh_lsp` in the project root. The path can be absolute `/Users/sqlmesh_user/sqlmesh_project/sushi` or relative `./project_folder/sushi` to the workspace root. Multiple paths can be used for multi-project setups." + }, + "sqlmesh.lspEntrypoint": { + "type": "string", + "default": "", + "markdownDescription": "The entry point for the SQLMesh LSP server. If not set the extension looks for the default lsp. If set, the extension will use the entry point as the LSP path, The path can be absolute `/Users/sqlmesh_user/sqlmesh_project/sushi/sqlmesh_lsp` or relative `./project_folder/sushi/sqlmesh_lsp` to the workspace root. It can also have arguments, e.g. `./project_folder/sushi/sqlmesh_lsp --port 5000`." + } + } + }, + "viewsContainers": { + "panel": [ + { + "id": "lineage_view", + "title": "Lineage", + "icon": "./assets/images/dag.svg" + } + ] + }, + "views": { + "lineage_view": [ + { + "id": "sqlmesh.lineage", + "name": "", + "type": "webview", + "icon": "./assets/images/dag.svg" + } + ] + }, + "authentication": [ + { + "id": "tobikodata", + "label": "Tobiko" + } + ], + "commands": [ + { + "command": "sqlmesh.format", + "title": "SQLMesh: Format Project", + "description": "SQLMesh" + }, + { + "command": "sqlmesh.restart", + "title": "SQLMesh: Restart Servers", + "description": "SQLMesh" + }, + { + "command": "sqlmesh.printEnvironment", + "title": "SQLMesh: Print Environment Variables", + "description": "SQLMesh" + }, + { + "command": "sqlmesh.signin", + "title": "SQLMesh: Sign in to Tobiko Cloud", + "description": "SQLMesh" + }, + { + "command": "sqlmesh.signinSpecifyFlow", + "title": "SQLMesh: Sign in to Tobiko Cloud (Specify Auth Flow)", + "description": "SQLMesh" + }, + { + "command": "sqlmesh.signout", + "title": "SQLMesh: Sign out from Tobiko Cloud", + "description": "SQLMesh" + }, + { 
+ "command": "sqlmesh.renderModel", + "title": "SQLMesh: Render Model", + "description": "SQLMesh", + "icon": "$(open-preview)" + }, + { + "command": "sqlmesh.stop", + "title": "SQLMesh: Stop Server", + "description": "SQLMesh" + }, + { + "command": "sqlmesh.showTableDiff", + "title": "SQLMesh: Show Table Diff", + "description": "SQLMesh", + "icon": "$(diff)" + } + ], + "menus": { + "editor/title": [ + { + "command": "sqlmesh.renderModel", + "when": "resourceExtname == .sql", + "group": "navigation" + }, + { + "command": "sqlmesh.showTableDiff", + "when": "resourceExtname == .sql", + "group": "navigation" + } + ] + } + }, + "scripts": { + "ci": "pnpm run lint && pnpm run compile && pnpm run test:unit", + "lint": "eslint src tests", + "lint:fix": "eslint src tests --fix", + "test:unit": "vitest run", + "code-server": "code-server", + "test:e2e": "pnpm run vscode:package && playwright test", + "test:e2e:ui": "pnpm run vscode:package && playwright test --ui", + "test:e2e:headed": "pnpm run vscode:package && playwright test --headed", + "compile": "pnpm run check-types && node esbuild.js", + "check-types": "tsc --noEmit -p ./tsconfig.build.json", + "watch": "node esbuild.js --watch", + "watch:tsc": "tsc --noEmit --watch --project tsconfig.json", + "vscode:package": "vsce package --no-dependencies", + "vscode:prepublish": "cp ../../LICENSE . 
&& pnpm run package", + "package": "rm -rf ./src_react && mkdir -p ./src_react && cd ../react && pnpm run build && cd ../extension && cp -r ../react/dist/* ./src_react && pnpm run check-types && node esbuild.js --production" + }, + "dependencies": { + "@duckdb/node-api": "1.3.2-alpha.25", + "@types/fs-extra": "^11.0.4", + "@types/shell-quote": "^1.7.5", + "@vscode/python-extension": "^1.0.5", + "fs-extra": "^11.3.0", + "shell-quote": "^1.8.3", + "vscode-jsonrpc": "^8.2.1", + "vscode-languageclient": "^9.0.1", + "zod": "^3.25.76" + }, + "devDependencies": { + "@eslint/js": "^9.31.0", + "@playwright/test": "^1.54.1", + "@types/mocha": "^10.0.10", + "@types/node": "20.11.25", + "@types/vscode": "1.96.0", + "@vitest/ui": "^3.2.4", + "@vscode/test-cli": "^0.0.10", + "@vscode/test-electron": "^2.5.2", + "@vscode/vsce": "^3.6.0", + "esbuild": "^0.25.8", + "eslint": "^9.31.0", + "ts-loader": "^9.5.2", + "typescript": "^5.8.3", + "typescript-eslint": "^8.38.0", + "vitest": "^3.2.4", + "yaml": "^2.8.0" + } +} diff --git a/vscode/extension/playwright.config.ts b/vscode/extension/playwright.config.ts new file mode 100644 index 0000000000..95d3bda589 --- /dev/null +++ b/vscode/extension/playwright.config.ts @@ -0,0 +1,35 @@ +import { defineConfig } from '@playwright/test' + +export default defineConfig({ + testDir: 'tests', + timeout: 60_000, + // TODO: When stable, allow retries in CI + retries: process.env.CI ? 2 : 0, + workers: process.env.CI ? 2 : 4, + reporter: [['html', { outputFolder: 'playwright-report' }], ['list']], + projects: [ + { + name: 'setup', + testMatch: 'tests/extension.setup.ts', + teardown: 'cleanup', + }, + { + name: 'cleanup', + testMatch: 'tests/extension.teardown.ts', + }, + { + name: 'electron-vscode', + use: { + browserName: 'chromium', + headless: true, + launchOptions: { + slowMo: process.env.CI ? 
0 : 100, + }, + viewport: { width: 1512, height: 944 }, + video: 'retain-on-failure', + trace: 'retain-on-first-failure', + }, + dependencies: ['setup'], + }, + ], +}) diff --git a/vscode/extension/src/auth/auth.ts b/vscode/extension/src/auth/auth.ts new file mode 100644 index 0000000000..8d7908f06b --- /dev/null +++ b/vscode/extension/src/auth/auth.ts @@ -0,0 +1,381 @@ +import { + env, + Uri, + AuthenticationProvider, + AuthenticationProviderAuthenticationSessionsChangeEvent, + AuthenticationSession, + Event, + EventEmitter, + window, +} from 'vscode' +import { getTcloudBin } from '../utilities/sqlmesh/sqlmesh' +import { err, isErr, ok, Result } from '@bus/result' +import { execAsync } from '../utilities/exec' +import { getProjectRoot } from '../utilities/common/utilities' +import z from 'zod' +import { traceError } from '../utilities/common/log' +import { ErrorType } from '../utilities/errors' + +export const AUTH_TYPE = 'tobikodata' +export const AUTH_NAME = 'Tobiko' + +const tokenSchema = z.object({ + iss: z.string(), + aud: z.string(), + sub: z.string(), + scope: z.string(), + iat: z.number(), + exp: z.number(), + email: z.string(), +}) + +const statusResponseSchema = z.discriminatedUnion('is_logged_in', [ + z.object({ + is_logged_in: z.literal(true), + id_token: tokenSchema, + }), + z.object({ + is_logged_in: z.literal(false), + id_token: z.object({}), + }), +]) + +type StatusResponse = z.infer + +const loginUrlResponseSchema = z.object({ + url: z.string(), + verifier_code: z.string(), +}) + +const deviceCodeResponseSchema = z.object({ + device_code: z.string(), + user_code: z.string(), + verification_uri: z.string(), + verification_uri_complete: z.string(), + expires_in: z.number(), +}) + +export class AuthenticationProviderTobikoCloud + implements AuthenticationProvider +{ + static id = AUTH_TYPE + static name = AUTH_NAME + + private _sessionChangeEmitter = + new EventEmitter() + + onDidChangeSessions: Event = + this._sessionChangeEmitter.event + + /** + * 
Get the status of the authentication provider from the cli + * @returns true if the user is logged in with the id token, false otherwise + */ + private async get_status(): Promise> { + const workspacePath = await getProjectRoot() + const tcloudBin = await getTcloudBin() + if (isErr(tcloudBin)) { + return tcloudBin + } + const tcloudBinPath = tcloudBin.value + const result = await execAsync( + tcloudBinPath.bin, + ['auth', 'vscode', 'status'], + { + cwd: workspacePath.uri.fsPath, + env: tcloudBinPath.env, + }, + ) + if (result.exitCode !== 0) { + return err({ + type: 'generic', + message: 'Failed to get tcloud auth status', + }) + } + const status = result.stdout + const statusToJson: any = JSON.parse(status) + const statusResponse = statusResponseSchema.parse(statusToJson) + return ok(statusResponse) + } + + async getSessions(): Promise { + const status = await this.get_status() + if (isErr(status)) { + return [] + } + const statusResponse = status.value + if (!statusResponse.is_logged_in) { + return [] + } + const token = statusResponse.id_token + if (!token) { + throw new Error('Invalid state from tcloud, failed to get token.') + } + const session = { + id: token.email, + account: { + id: token.sub, + label: token.email, + }, + scopes: token.scope.split(' '), + accessToken: '', + } + return [session] + } + + async createSession(): Promise { + await this.sign_in_oauth_flow() + const status = await this.get_status() + if (isErr(status)) { + throw new Error('Failed to get tcloud auth status') + } + const statusResponse = status.value + if (!statusResponse.is_logged_in) { + throw new Error('Failed to login to tcloud') + } + const token = statusResponse.id_token + if (!token) { + throw new Error('Failed to get tcloud token') + } + const session: AuthenticationSession = { + id: token.email, + account: { + id: token.email, + label: 'Tobiko', + }, + scopes: token.scope.split(' '), + accessToken: '', + } + this._sessionChangeEmitter.fire({ + added: [session], + removed: 
[], + changed: [], + }) + return session + } + + async removeSession(): Promise { + // Get current sessions before logging out + const currentSessions = await this.getSessions() + const tcloudBin = await getTcloudBin() + const workspacePath = await getProjectRoot() + if (isErr(tcloudBin)) { + throw new Error('Failed to get tcloud bin') + } + const tcloudBinPath = tcloudBin.value + const result = await execAsync(tcloudBinPath.bin, ['auth', 'logout'], { + cwd: workspacePath.uri.fsPath, + env: tcloudBinPath.env, + }) + if (result.exitCode !== 0) { + throw new Error('Failed to logout from tcloud') + } + + // Emit event with the actual sessions that were removed + if (currentSessions.length > 0) { + this._sessionChangeEmitter.fire({ + added: [], + removed: currentSessions, + changed: [], + }) + } + } + + async sign_in_oauth_flow(): Promise { + const workspacePath = await getProjectRoot() + const tcloudBin = await getTcloudBin() + if (isErr(tcloudBin)) { + throw new Error('Failed to get tcloud bin') + } + const tcloudBinPath = tcloudBin.value + const result = await execAsync( + tcloudBinPath.bin, + ['auth', 'vscode', 'login-url'], + { + cwd: workspacePath.uri.fsPath, + }, + ) + if (result.exitCode !== 0) { + throw new Error('Failed to get tcloud login url') + } + + try { + const resultToJson: any = JSON.parse(result.stdout) + const urlCode = loginUrlResponseSchema.parse(resultToJson) + const url = urlCode.url + + if (!url) { + throw new Error('Invalid login URL received') + } + + const ac = new AbortController() + const timeout = setTimeout( + () => { + ac.abort() + }, + 1000 * 60 * 5, + ) + const backgroundServerForLogin = execAsync( + tcloudBinPath.bin, + ['auth', 'vscode', 'start-server', urlCode.verifier_code], + { + cwd: workspacePath.uri.fsPath, + signal: ac.signal, + env: tcloudBinPath.env, + }, + ) + + const messageResult = await window.showInformationMessage( + 'Please login to Tobiko Cloud', + { + modal: true, + }, + 'Sign in with browser', + 'Cancel', + ) + + 
if (messageResult === 'Sign in with browser') { + await env.openExternal(Uri.parse(url)) + } else { + // Always abort the server if not proceeding with sign in + ac.abort() + clearTimeout(timeout) + if (messageResult === 'Cancel') { + throw new Error('Login cancelled') + } + return + } + + try { + const output = await backgroundServerForLogin + if (output.exitCode !== 0) { + throw new Error(`Failed to complete authentication: ${output.stderr}`) + } + // Get updated session and notify about the change + const sessions = await this.getSessions() + if (sessions.length > 0) { + this._sessionChangeEmitter.fire({ + added: sessions, + removed: [], + changed: [], + }) + } + } catch (error) { + if (error instanceof Error && error.name === 'AbortError') { + throw new Error('Authentication timeout or aborted') + } + traceError(`Server error: ${error}`) + throw error + } finally { + clearTimeout(timeout) + } + } catch (error) { + if (error instanceof Error && error.message === 'Login cancelled') { + throw error + } + traceError(`Authentication flow error: ${error}`) + throw new Error('Failed to complete authentication flow') + } + } + + async sign_in_device_flow(): Promise { + const workspacePath = await getProjectRoot() + const tcloudBin = await getTcloudBin() + if (isErr(tcloudBin)) { + throw new Error('Failed to get tcloud bin') + } + const tcloudBinPath = tcloudBin.value + const result = await execAsync( + tcloudBinPath.bin, + ['auth', 'vscode', 'device'], + { + cwd: workspacePath.uri.fsPath, + env: tcloudBinPath.env, + }, + ) + if (result.exitCode !== 0) { + throw new Error('Failed to get device code') + } + + try { + const resultToJson: any = JSON.parse(result.stdout) + const deviceCodeResponse = deviceCodeResponseSchema.parse(resultToJson) + + const ac = new AbortController() + const timeout = setTimeout( + () => { + ac.abort() + }, + 1000 * 60 * 5, + ) + const waiting = execAsync( + tcloudBinPath.bin, + ['auth', 'vscode', 'poll_device', deviceCodeResponse.device_code], 
+ { + cwd: workspacePath.uri.fsPath, + signal: ac.signal, + env: tcloudBinPath.env, + }, + ) + + const messageResult = await window.showInformationMessage( + `Confirm the code ${deviceCodeResponse.user_code} at ${deviceCodeResponse.verification_uri}`, + { + modal: true, + }, + 'Open browser', + 'Cancel', + ) + + if (messageResult === 'Open browser') { + await env.openExternal( + Uri.parse(deviceCodeResponse.verification_uri_complete), + ) + } + if (messageResult === 'Cancel') { + ac.abort() + throw new Error('Login cancelled') + } + + try { + const output = await waiting + if (output.exitCode !== 0) { + throw new Error(`Failed to authenticate: ${output.stderr}`) + } + + // Get updated session and notify about the change + const sessions = await this.getSessions() + if (sessions.length > 0) { + this._sessionChangeEmitter.fire({ + added: sessions, + removed: [], + changed: [], + }) + } + } catch (error) { + traceError(`Authentication error: ${error}`) + throw error + } finally { + clearTimeout(timeout) + } + } catch (error) { + traceError(`JSON parsing error: ${error}`) + throw new Error('Failed to parse device code response') + } + } +} + +/** + * Checks if the user is currently signed into Tobiko Cloud. + * @returns A promise that resolves to true if the user is signed in, false otherwise. 
+ */ +export async function isSignedIntoTobikoCloud(): Promise { + try { + const authProvider = new AuthenticationProviderTobikoCloud() + const sessions = await authProvider.getSessions() + return sessions.length > 0 + } catch (error) { + traceError(`Error checking authentication status: ${error}`) + return false + } +} diff --git a/vscode/extension/src/commands/commands.test.ts b/vscode/extension/src/commands/commands.test.ts new file mode 100644 index 0000000000..7f531a2ce1 --- /dev/null +++ b/vscode/extension/src/commands/commands.test.ts @@ -0,0 +1,19 @@ +import { assert, describe, it } from 'vitest' +import * as fs from 'fs' +import * as path from 'path' + +describe('Commands', () => { + it('all commands should start with "SQLMesh: " prefix', () => { + const packageJsonPath = path.join(__dirname, '..', '..', 'package.json') + const packageJson = JSON.parse(fs.readFileSync(packageJsonPath, 'utf8')) + + const commands = packageJson.contributes?.commands || [] + + commands.forEach((command: any) => { + assert( + command.title.startsWith('SQLMesh: '), + `Command "${command.command}" title "${command.title}" should start with "SQLMesh: "`, + ) + }) + }) +}) diff --git a/vscode/extension/src/commands/format.ts b/vscode/extension/src/commands/format.ts new file mode 100644 index 0000000000..f01435523a --- /dev/null +++ b/vscode/extension/src/commands/format.ts @@ -0,0 +1,42 @@ +import { traceLog } from '../utilities/common/log' +import { err, isErr, ok, Result } from '@bus/result' +import * as vscode from 'vscode' +import { ErrorType, handleError } from '../utilities/errors' +import { AuthenticationProviderTobikoCloud } from '../auth/auth' +import { LSPClient } from '../lsp/lsp' + +export const format = + ( + authProvider: AuthenticationProviderTobikoCloud, + lsp: LSPClient | undefined, + restartLSP: () => Promise, + ) => + async (): Promise => { + traceLog('Calling format') + const out = await internalFormat(lsp) + if (isErr(out)) { + return handleError( + 
authProvider, + restartLSP, + out.error, + 'Project format failed', + ) + } + vscode.window.showInformationMessage('Project formatted successfully') + } + +const internalFormat = async ( + lsp: LSPClient | undefined, +): Promise> => { + if (lsp === undefined) { + return err({ + type: 'generic', + message: 'LSP is not available', + }) + } + const response = await lsp.call_custom_method('sqlmesh/format_project', {}) + if (isErr(response)) { + return response + } + return ok(undefined) +} diff --git a/vscode/extension/src/commands/printEnvironment.ts b/vscode/extension/src/commands/printEnvironment.ts new file mode 100644 index 0000000000..41e74e40d3 --- /dev/null +++ b/vscode/extension/src/commands/printEnvironment.ts @@ -0,0 +1,40 @@ +import * as vscode from 'vscode' +import { getSqlmeshEnvironment } from '../utilities/sqlmesh/sqlmesh' +import { isErr } from '@bus/result' +import { IS_WINDOWS } from '../utilities/isWindows' + +export function printEnvironment() { + return async () => { + const envResult = await getSqlmeshEnvironment() + + if (isErr(envResult)) { + await vscode.window.showErrorMessage(envResult.error) + return + } + + const env = envResult.value + + // Create a new terminal with the SQLMesh environment + const terminal = vscode.window.createTerminal({ + name: 'SQLMesh Environment', + env: env, + }) + + // Show the terminal + terminal.show() + + // Run the appropriate command to display environment variables + if (IS_WINDOWS) { + // On Windows, use 'set' command + terminal.sendText('set') + } else { + // On Unix-like systems, use 'env' command + terminal.sendText('env | sort') + } + + // Show a notification + vscode.window.showInformationMessage( + 'SQLMesh environment variables displayed in terminal', + ) + } +} diff --git a/vscode/extension/src/commands/renderModel.ts b/vscode/extension/src/commands/renderModel.ts new file mode 100644 index 0000000000..24225c3e45 --- /dev/null +++ b/vscode/extension/src/commands/renderModel.ts @@ -0,0 +1,199 @@ 
+import * as vscode from 'vscode' +import { LSPClient } from '../lsp/lsp' +import { isErr } from '@bus/result' +import { RenderModelEntry } from '../lsp/custom' +import { RenderedModelProvider } from '../providers/renderedModelProvider' + +export async function reRenderModelForSourceFile( + sourceUri: string, + lspClient: LSPClient | undefined, + renderedModelProvider: RenderedModelProvider, +): Promise { + const renderedUri = renderedModelProvider.getRenderedUriForSource(sourceUri) + if (!renderedUri) { + return // No rendered model exists for this source file + } + if (!lspClient) { + return + } + + // Call the render model API + const result = await lspClient.call_custom_method('sqlmesh/render_model', { + textDocumentUri: sourceUri, + }) + + if (isErr(result)) { + // Silently fail on auto-rerender errors to avoid spamming user + return + } + + // Check if we got any models + if (!result.value.models || result.value.models.length === 0) { + return + } + + // Get the originally rendered model information + const originalModelInfo = + renderedModelProvider.getModelInfoForRendered(renderedUri) + + // Find the specific model that was originally rendered, or fall back to the first model + const selectedModel = originalModelInfo + ? 
result.value.models.find( + model => + model.name === originalModelInfo.name && + model.fqn === originalModelInfo.fqn, + ) || result.value.models[0] + : result.value.models[0] + + // Update the existing rendered model content + renderedModelProvider.updateRenderedModel( + renderedUri, + selectedModel.rendered_query, + ) +} + +export function renderModel( + lspClient?: LSPClient, + renderedModelProvider?: RenderedModelProvider, +) { + return async () => { + if (!lspClient) { + vscode.window.showErrorMessage('LSP client not available') + return + } + + // Get the current active editor + const activeEditor = vscode.window.activeTextEditor + + let documentUri: string + + if (!activeEditor) { + // No active editor, show a list of all models + const allModelsResult = await lspClient.call_custom_method( + 'sqlmesh/all_models_for_render', + {}, + ) + + if (isErr(allModelsResult)) { + vscode.window.showErrorMessage( + `Failed to get models: ${allModelsResult.error.message}`, + ) + return + } + + if ( + !allModelsResult.value.models || + allModelsResult.value.models.length === 0 + ) { + vscode.window.showInformationMessage('No models found in the project') + return + } + + // Let user choose from all models + const items = allModelsResult.value.models.map(model => ({ + label: model.name, + description: model.fqn, + detail: model.description ? 
model.description : undefined, + model: model, + })) + + const selected = await vscode.window.showQuickPick(items, { + placeHolder: 'Select a model to render', + }) + + if (!selected) { + return + } + + // Use the selected model's URI + documentUri = selected.model.uri + } else { + // Get the current document URI + documentUri = activeEditor.document.uri.toString(true) + } + + // Call the render model API + const result = await lspClient.call_custom_method('sqlmesh/render_model', { + textDocumentUri: documentUri, + }) + + if (isErr(result)) { + vscode.window.showErrorMessage( + `Failed to render model: ${result.error.message}`, + ) + return + } + + // Check if we got any models + if (!result.value.models || result.value.models.length === 0) { + vscode.window.showInformationMessage( + 'No models found in the current file', + ) + return + } + + // If multiple models, let user choose + let selectedModel: RenderModelEntry + if (result.value.models.length > 1) { + const items = result.value.models.map(model => ({ + label: model.name, + description: model.fqn, + detail: model.description ? 
model.description : undefined, + model: model, + })) + + const selected = await vscode.window.showQuickPick(items, { + placeHolder: 'Select a model to render', + }) + + if (!selected) { + return + } + + selectedModel = selected.model + } else { + selectedModel = result.value.models[0] + } + + if (!renderedModelProvider) { + vscode.window.showErrorMessage('Rendered model provider not available') + return + } + + // Store the rendered content and get a virtual URI + const uri = renderedModelProvider.storeRenderedModel( + selectedModel.name, + selectedModel.rendered_query, + documentUri, + selectedModel, + ) + + // Open the virtual document + const document = await vscode.workspace.openTextDocument(uri) + + // Determine the view column for side-by-side display + // Find the rightmost column with an editor + let maxColumn = vscode.ViewColumn.One + for (const editor of vscode.window.visibleTextEditors) { + if (editor.viewColumn && editor.viewColumn > maxColumn) { + maxColumn = editor.viewColumn + } + } + + // Open in the next column after the rightmost editor + const viewColumn = maxColumn + 1 + + // Open the document in the editor as a preview (preview: true is default) + await vscode.window.showTextDocument(document, { + viewColumn: viewColumn, + preview: true, + preserveFocus: false, + }) + + // Execute "Keep Open" command to convert preview tab to permanent tab + await vscode.commands.executeCommand('workbench.action.keepEditor') + + // Explicitly set the language mode to SQL for syntax highlighting + await vscode.languages.setTextDocumentLanguage(document, 'sql') + } +} diff --git a/vscode/extension/src/commands/signin.ts b/vscode/extension/src/commands/signin.ts new file mode 100644 index 0000000000..77131a8253 --- /dev/null +++ b/vscode/extension/src/commands/signin.ts @@ -0,0 +1,30 @@ +import { AuthenticationProviderTobikoCloud } from '../auth/auth' +import * as vscode from 'vscode' +import { isCodespaces } from '../utilities/isCodespaces' +import { traceInfo } 
from '../utilities/common/log' + +export const signIn = + ( + authenticationProvider: AuthenticationProviderTobikoCloud, + onSignInSuccess: () => Promise, + ) => + async () => { + if (isCodespaces()) { + await authenticationProvider.sign_in_device_flow() + } else { + await authenticationProvider.createSession() + } + + // Do not await this, as this will block the thread, you just need to show the message, but not block + vscode.window.showInformationMessage('Signed in successfully') + + // Execute callback after successful sign-in + if (onSignInSuccess) { + traceInfo('Executing post sign-in callback') + try { + await onSignInSuccess() + } catch (error) { + traceInfo(`Error in post sign-in callback: ${error}`) + } + } + } diff --git a/vscode/extension/src/commands/signinSpecifyFlow.ts b/vscode/extension/src/commands/signinSpecifyFlow.ts new file mode 100644 index 0000000000..2e0c0cfe15 --- /dev/null +++ b/vscode/extension/src/commands/signinSpecifyFlow.ts @@ -0,0 +1,61 @@ +import { AuthenticationProviderTobikoCloud } from '../auth/auth' +import { traceInfo } from '../utilities/common/log' +import { window } from 'vscode' + +export const signInSpecifyFlow = + ( + authenticationProvider: AuthenticationProviderTobikoCloud, + onSignInSuccess?: () => Promise, + ) => + async () => { + traceInfo('Sign in specify flow') + const flowOptions = [ + { + label: 'OAuth Flow', + description: 'Sign in using OAuth flow in your browser', + }, + { label: 'Device Flow', description: 'Sign in using a device code' }, + ] + const selectedFlow = await window.showQuickPick(flowOptions, { + placeHolder: 'Select authentication flow method', + ignoreFocusOut: true, + }) + if (!selectedFlow) { + traceInfo('Sign in cancelled by user') + return + } + if (selectedFlow.label === 'OAuth Flow') { + await authenticationProvider.sign_in_oauth_flow() + await authenticationProvider.getSessions() + await window.showInformationMessage('Sign in success') + + // Execute callback after successful sign-in + if 
(onSignInSuccess) { + traceInfo('Executing post sign-in callback') + try { + await onSignInSuccess() + } catch (error) { + traceInfo(`Error in post sign-in callback: ${error}`) + } + } + return + } else if (selectedFlow.label === 'Device Flow') { + await authenticationProvider.sign_in_device_flow() + await authenticationProvider.getSessions() + await window.showInformationMessage('Sign in success') + + // Execute callback after successful sign-in + if (onSignInSuccess) { + traceInfo('Executing post sign-in callback') + try { + await onSignInSuccess() + } catch (error) { + traceInfo(`Error in post sign-in callback: ${error}`) + } + } + return + } else { + traceInfo('Invalid flow selected') + return + } + } diff --git a/vscode/extension/src/commands/signout.ts b/vscode/extension/src/commands/signout.ts new file mode 100644 index 0000000000..614723a70f --- /dev/null +++ b/vscode/extension/src/commands/signout.ts @@ -0,0 +1,8 @@ +import { AuthenticationProviderTobikoCloud } from '../auth/auth' +import * as vscode from 'vscode' + +export const signOut = + (authenticationProvider: AuthenticationProviderTobikoCloud) => async () => { + await authenticationProvider.removeSession() + await vscode.window.showInformationMessage('Signed out successfully') + } diff --git a/vscode/extension/src/commands/stop.ts b/vscode/extension/src/commands/stop.ts new file mode 100644 index 0000000000..429d6fa7b6 --- /dev/null +++ b/vscode/extension/src/commands/stop.ts @@ -0,0 +1,18 @@ +import { window } from 'vscode' +import { LSPClient } from '../lsp/lsp' +import { traceInfo } from '../utilities/common/log' + +export const stop = (lspClient: LSPClient | undefined) => { + return async () => { + traceInfo('Stopping LSP server') + + if (!lspClient) { + await window.showInformationMessage('LSP server is not running') + return + } + + await lspClient.stop(true) + await window.showInformationMessage('LSP server stopped') + traceInfo('LSP server stopped successfully') + } +} diff --git 
a/vscode/extension/src/commands/tableDiff.ts b/vscode/extension/src/commands/tableDiff.ts new file mode 100644 index 0000000000..d9587d261b --- /dev/null +++ b/vscode/extension/src/commands/tableDiff.ts @@ -0,0 +1,589 @@ +import * as vscode from 'vscode' +import { LSPClient } from '../lsp/lsp' +import { isErr } from '@bus/result' +import { CallbackEvent, RPCRequest } from '@bus/callbacks' +import { getWorkspaceFolders } from '../utilities/common/vscodeapi' + +interface ModelInfo { + name: string + fqn: string + description?: string | null +} + +export function showTableDiff( + lspClient?: LSPClient, + extensionUri?: vscode.Uri, +) { + return async () => { + if (!lspClient) { + vscode.window.showErrorMessage('LSP client not available') + return + } + + if (!extensionUri) { + vscode.window.showErrorMessage('Extension URI not available') + return + } + + // Get the current active editor + const activeEditor = vscode.window.activeTextEditor + let selectedModelInfo: ModelInfo | null = null + + if (!activeEditor) { + // No active editor, show a list of all models + const allModelsResult = await lspClient.call_custom_method( + 'sqlmesh/get_models', + {}, + ) + + if (isErr(allModelsResult)) { + vscode.window.showErrorMessage( + `Failed to get models: ${allModelsResult.error.message}`, + ) + return + } + + if ( + !allModelsResult.value.models || + allModelsResult.value.models.length === 0 + ) { + vscode.window.showInformationMessage('No models found in the project') + return + } + + // Let user choose from all models + const items = (allModelsResult.value.models as ModelInfo[]).map( + (model: ModelInfo) => ({ + label: model.name, + description: model.fqn, + detail: model.description ? 
model.description : undefined, + model: { + name: model.name, + fqn: model.fqn, + description: model.description, + }, + }), + ) + + const selected = await vscode.window.showQuickPick(items, { + placeHolder: 'Select a model for table diff', + }) + + if (!selected) { + return + } + + selectedModelInfo = selected.model + } else { + // Get the current document URI and check if it contains models + const documentUri = activeEditor.document.uri.toString(true) + + // Call the render model API to get models in the current file + const result = await lspClient.call_custom_method( + 'sqlmesh/render_model', + { + textDocumentUri: documentUri, + }, + ) + + if (isErr(result)) { + vscode.window.showErrorMessage( + `Failed to get models from current file: ${result.error.message}`, + ) + return + } + + // Check if we got any models + if (!result.value.models || result.value.models.length === 0) { + vscode.window.showInformationMessage( + 'No models found in the current file', + ) + return + } + + // If multiple models, let user choose + if (result.value.models.length > 1) { + const items = result.value.models.map(model => ({ + label: model.name, + description: model.fqn, + detail: model.description ? 
model.description : undefined, + model: model, + })) + + const selected = await vscode.window.showQuickPick(items, { + placeHolder: 'Select a model for table diff', + }) + + if (!selected) { + return + } + + selectedModelInfo = selected.model + } else { + selectedModelInfo = result.value.models[0] + } + } + + // Ensure we have a selected model + if (!selectedModelInfo) { + vscode.window.showErrorMessage('No model selected') + return + } + + // Get environments for selection + const environmentsResult = await lspClient.call_custom_method( + 'sqlmesh/get_environments', + {}, + ) + + if (isErr(environmentsResult)) { + vscode.window.showErrorMessage( + `Failed to get environments: ${environmentsResult.error.message}`, + ) + return + } + + const environments = environmentsResult.value.environments || {} + const environmentNames = Object.keys(environments) + + if (environmentNames.length === 0) { + vscode.window.showErrorMessage('No environments found') + return + } + + // Let user select source environment + const sourceEnvironmentItems = environmentNames.map(env => ({ + label: env, + description: `Source environment: ${env}`, + })) + + const selectedSourceEnv = await vscode.window.showQuickPick( + sourceEnvironmentItems, + { + placeHolder: 'Select source environment', + }, + ) + + if (!selectedSourceEnv) { + return + } + + // Let user select target environment (excluding source) + const targetEnvironmentItems = environmentNames + .filter(env => env !== selectedSourceEnv.label) + .map(env => ({ + label: env, + description: `Target environment: ${env}`, + })) + + if (targetEnvironmentItems.length === 0) { + vscode.window.showErrorMessage( + 'Need at least two environments for comparison', + ) + return + } + + const selectedTargetEnv = await vscode.window.showQuickPick( + targetEnvironmentItems, + { + placeHolder: 'Select target environment', + }, + ) + + if (!selectedTargetEnv) { + return + } + + // Run table diff immediately with selected parameters + const 
tableDiffResult = await vscode.window.withProgress( + { + location: vscode.ProgressLocation.Notification, + title: 'SQLMesh', + cancellable: false, + }, + async progress => { + progress.report({ message: 'Calculating table differences...' }) + + return await lspClient.call_custom_method('sqlmesh/api', { + method: 'GET', + url: '/api/table_diff', + params: { + source: selectedSourceEnv.label, + target: selectedTargetEnv.label, + model_or_snapshot: selectedModelInfo.name, + }, + body: {}, + }) + }, + ) + + if (isErr(tableDiffResult)) { + vscode.window.showErrorMessage( + `Failed to run table diff: ${tableDiffResult.error.message}`, + ) + return + } + + // Determine the view column for side-by-side display + // Find the rightmost column with an editor + let maxColumn = vscode.ViewColumn.One + for (const editor of vscode.window.visibleTextEditors) { + if (editor.viewColumn && editor.viewColumn > maxColumn) { + maxColumn = editor.viewColumn + } + } + + // Open in the next column after the rightmost editor + const viewColumn = maxColumn + 1 + + // Create a webview panel for the table diff + const panel = vscode.window.createWebviewPanel( + 'sqlmesh.tableDiff', + `SQLMesh Table Diff - ${selectedModelInfo.name} (${selectedSourceEnv.label} → ${selectedTargetEnv.label})`, + viewColumn, + { + enableScripts: true, + retainContextWhenHidden: true, + localResourceRoots: [extensionUri], + }, + ) + + // Store the initial data for the webview + // eslint-disable-next-line prefer-const + let initialData = { + selectedModel: selectedModelInfo, + sourceEnvironment: selectedSourceEnv.label, + targetEnvironment: selectedTargetEnv.label, + tableDiffData: tableDiffResult.value, + environments: environments, + } + + // Set up message listener for events from the webview + panel.webview.onDidReceiveMessage( + async request => { + if (!request || !request.key) { + return + } + const message: CallbackEvent = request + switch (message.key) { + case 'openFile': { + const workspaceFolders = 
getWorkspaceFolders() + if (workspaceFolders.length != 1) { + throw new Error('Only one workspace folder is supported') + } + const fullPath = vscode.Uri.parse(message.payload.uri) + const document = await vscode.workspace.openTextDocument(fullPath) + await vscode.window.showTextDocument(document) + break + } + case 'rpcRequest': { + const payload: RPCRequest = message.payload + const requestId = payload.requestId + switch (payload.method) { + case 'api_query': { + const response = await lspClient.call_custom_method( + 'sqlmesh/api', + payload.params, + ) + let responseCallback: CallbackEvent + if (isErr(response)) { + let errorMessage: string + switch (response.error.type) { + case 'generic': + errorMessage = response.error.message + break + case 'invalid_state': + errorMessage = `Invalid state: ${response.error.message}` + break + case 'sqlmesh_outdated': + errorMessage = `SQLMesh version issue: ${response.error.message}` + break + default: + errorMessage = 'Unknown error' + } + responseCallback = { + key: 'rpcResponse', + payload: { + requestId, + result: { + ok: false, + error: errorMessage, + }, + }, + } + } else { + responseCallback = { + key: 'rpcResponse', + payload: { + requestId, + result: response, + }, + } + } + await panel.webview.postMessage(responseCallback) + break + } + case 'get_active_file': { + const active_file = + vscode.window.activeTextEditor?.document.uri.fsPath + const responseCallback: CallbackEvent = { + key: 'rpcResponse', + payload: { + requestId, + result: { + fileUri: active_file, + }, + }, + } + await panel.webview.postMessage(responseCallback) + break + } + case 'get_selected_model': { + const responseCallback: CallbackEvent = { + key: 'rpcResponse', + payload: { + requestId, + result: { + ok: true, + value: { + selectedModel: initialData.selectedModel, + }, + }, + }, + } + await panel.webview.postMessage(responseCallback) + break + } + case 'get_initial_data': { + const responseCallback: CallbackEvent = { + key: 'rpcResponse', + 
payload: { + requestId, + result: { + ok: true, + value: { + selectedModel: initialData.selectedModel, + sourceEnvironment: initialData.sourceEnvironment, + targetEnvironment: initialData.targetEnvironment, + tableDiffData: initialData.tableDiffData, + environments: initialData.environments, + }, + }, + }, + } + await panel.webview.postMessage(responseCallback) + break + } + case 'get_all_models': { + const allModelsResult = await lspClient.call_custom_method( + 'sqlmesh/get_models', + {}, + ) + + let responseCallback: CallbackEvent + if (isErr(allModelsResult)) { + responseCallback = { + key: 'rpcResponse', + payload: { + requestId, + result: { + ok: false, + error: `Failed to get models: ${allModelsResult.error.message}`, + }, + }, + } + } else { + responseCallback = { + key: 'rpcResponse', + payload: { + requestId, + result: { + ok: true, + value: { + ok: true, + models: allModelsResult.value.models || [], + }, + }, + }, + } + } + await panel.webview.postMessage(responseCallback) + break + } + case 'set_selected_model': { + const modelInfo = payload.params?.model + if (modelInfo) { + initialData.selectedModel = modelInfo + // Update the panel title to reflect the new selection + panel.title = `SQLMesh Table Diff - ${modelInfo.name} (${initialData.sourceEnvironment} → ${initialData.targetEnvironment})` + } + + const responseCallback: CallbackEvent = { + key: 'rpcResponse', + payload: { + requestId, + result: { + ok: true, + value: { + ok: true, + selectedModel: initialData.selectedModel, + }, + }, + }, + } + await panel.webview.postMessage(responseCallback) + break + } + case 'get_environments': { + const responseCallback: CallbackEvent = { + key: 'rpcResponse', + payload: { + requestId, + result: { + ok: true, + value: { + ok: true, + environments: initialData.environments, + }, + }, + }, + } + await panel.webview.postMessage(responseCallback) + break + } + case 'run_table_diff': { + const { sourceModel, sourceEnvironment, targetEnvironment } = + payload.params 
|| {} + + if (!sourceModel || !sourceEnvironment || !targetEnvironment) { + const responseCallback: CallbackEvent = { + key: 'rpcResponse', + payload: { + requestId, + result: { + ok: false, + error: + 'Missing required parameters: sourceModel, sourceEnvironment, or targetEnvironment', + }, + }, + } + await panel.webview.postMessage(responseCallback) + break + } + + const tableDiffResult = await vscode.window.withProgress( + { + location: vscode.ProgressLocation.Notification, + title: 'SQLMesh', + cancellable: false, + }, + async progress => { + progress.report({ + message: 'Calculating table differences...', + }) + + return await lspClient.call_custom_method('sqlmesh/api', { + method: 'GET', + url: '/api/table_diff', + params: { + source: sourceEnvironment, + target: targetEnvironment, + model_or_snapshot: sourceModel, + }, + body: {}, + }) + }, + ) + + let responseCallback: CallbackEvent + if (isErr(tableDiffResult)) { + responseCallback = { + key: 'rpcResponse', + payload: { + requestId, + result: { + ok: false, + error: `Failed to run table diff: ${tableDiffResult.error.message}`, + }, + }, + } + } else { + responseCallback = { + key: 'rpcResponse', + payload: { + requestId, + result: { + ok: true, + value: { + ok: true, + data: tableDiffResult.value, + }, + }, + }, + } + } + await panel.webview.postMessage(responseCallback) + break + } + default: { + throw new Error(`Unhandled RPC method: ${payload.method}`) + } + } + break + } + default: + console.error( + 'Unhandled message type under queryRequest: ', + message, + ) + } + }, + undefined, + [], + ) + + // Set the HTML content + panel.webview.html = getHTML(panel.webview, extensionUri) + } +} + +function getHTML(webview: vscode.Webview, extensionUri: vscode.Uri): string { + const cssUri = webview.asWebviewUri( + vscode.Uri.joinPath(extensionUri, 'src_react', 'assets', 'index.css'), + ) + const jsUri = webview.asWebviewUri( + vscode.Uri.joinPath(extensionUri, 'src_react', 'assets', 'index.js'), + ) + const 
faviconUri = webview.asWebviewUri( + vscode.Uri.joinPath(extensionUri, 'src_react', 'favicon.ico'), + ) + const logoUri = webview.asWebviewUri( + vscode.Uri.joinPath(extensionUri, 'src_react', 'logo192.png'), + ) + + return ` + + + + + + + + + + SQLMesh Table Diff + + + + + +
+ + +` +} diff --git a/vscode/extension/src/completion/completion.ts b/vscode/extension/src/completion/completion.ts new file mode 100644 index 0000000000..8e8a101c50 --- /dev/null +++ b/vscode/extension/src/completion/completion.ts @@ -0,0 +1,36 @@ +import * as vscode from 'vscode' +import { LSPClient } from '../lsp/lsp' +import { isErr } from '@bus/result' + +export const selector: vscode.DocumentSelector = { + pattern: '**/*.sql', +} + +export const completionProvider = ( + lsp: LSPClient, +): vscode.CompletionItemProvider => { + return { + async provideCompletionItems(document) { + const result = await lsp.call_custom_method('sqlmesh/all_models', { + textDocument: { + uri: document.uri.fsPath, + }, + }) + if (isErr(result)) { + return [] + } + const modelCompletions = result.value.models.map( + model => + new vscode.CompletionItem(model, vscode.CompletionItemKind.Reference), + ) + const keywordCompletions = result.value.keywords.map( + keyword => + new vscode.CompletionItem(keyword, vscode.CompletionItemKind.Keyword), + ) + return new vscode.CompletionList([ + ...modelCompletions, + ...keywordCompletions, + ]) + }, + } +} diff --git a/vscode/extension/src/extension.ts b/vscode/extension/src/extension.ts new file mode 100644 index 0000000000..cfea8c2228 --- /dev/null +++ b/vscode/extension/src/extension.ts @@ -0,0 +1,221 @@ +/********************************************************************** + * Extension entry point * + *********************************************************************/ + +import * as vscode from 'vscode' + +import { format } from './commands/format' +import { signOut } from './commands/signout' +import { signIn } from './commands/signin' +import { signInSpecifyFlow } from './commands/signinSpecifyFlow' +import { renderModel, reRenderModelForSourceFile } from './commands/renderModel' +import { stop } from './commands/stop' +import { printEnvironment } from './commands/printEnvironment' + +import { + createOutputChannel, + 
onDidChangeConfiguration, + registerCommand, +} from './utilities/common/vscodeapi' +import { + registerLogger, + traceInfo, + traceVerbose, + traceError, +} from './utilities/common/log' +import { onDidChangePythonInterpreter } from './utilities/common/python' +import { sleep } from './utilities/sleep' +import { handleError } from './utilities/errors' + +import { selector, completionProvider } from './completion/completion' +import { LineagePanel } from './webviews/lineagePanel' +import { RenderedModelProvider } from './providers/renderedModelProvider' +import { showTableDiff } from './commands/tableDiff' + +import { + controller as testController, + setupTestController, +} from './tests/tests' + +import { isErr } from '@bus/result' +import { AuthenticationProviderTobikoCloud } from './auth/auth' +import { LSPClient } from './lsp/lsp' + +/** Singleton LSP client for the extension. */ +let lspClient: LSPClient | undefined + +/** Handle to the (single) test controller disposable so we can replace it on restart. 
*/ +let testControllerDisposable: vscode.Disposable | undefined + +export async function activate(context: vscode.ExtensionContext) { + const extensionOutputChannel = createOutputChannel('sqlmesh') + context.subscriptions.push( + extensionOutputChannel, + registerLogger(extensionOutputChannel), + ) + traceInfo('Activating SQLMesh extension') + + const authProvider = new AuthenticationProviderTobikoCloud() + context.subscriptions.push( + vscode.authentication.registerAuthenticationProvider( + AuthenticationProviderTobikoCloud.id, + AuthenticationProviderTobikoCloud.name, + authProvider, + { supportsMultipleAccounts: false }, + ), + ) + + const restartLsp = async (invokedByUser = false): Promise => { + if (!lspClient) { + lspClient = new LSPClient() + } + + traceVerbose('Restarting SQLMesh LSP client') + const result = await lspClient.restart(invokedByUser) + if (isErr(result)) { + await handleError( + authProvider, + restartLsp, + result.error, + 'LSP restart failed', + ) + return + } + + // push once to avoid duplicate disposables on multiple restarts + if (!context.subscriptions.includes(lspClient)) { + context.subscriptions.push(lspClient) + } + + /* Replace the test controller each time we restart the client */ + if (testControllerDisposable) { + testControllerDisposable.dispose() + } + testControllerDisposable = setupTestController(lspClient) + context.subscriptions.push(testControllerDisposable) + } + + // commands needing the restart helper + context.subscriptions.push( + vscode.commands.registerCommand( + 'sqlmesh.signin', + signIn(authProvider, () => restartLsp()), + ), + vscode.commands.registerCommand( + 'sqlmesh.signinSpecifyFlow', + signInSpecifyFlow(authProvider, () => restartLsp()), + ), + vscode.commands.registerCommand('sqlmesh.signout', signOut(authProvider)), + ) + + // Instantiate the LSP client (once) + lspClient = new LSPClient() + const startResult = await lspClient.start() + if (isErr(startResult)) { + await handleError( + authProvider, + 
restartLsp, + startResult.error, + 'Failed to start LSP', + ) + return // abort activation – nothing else to do + } + + context.subscriptions.push(lspClient) + + // Initialize the test controller + testControllerDisposable = setupTestController(lspClient) + context.subscriptions.push(testControllerDisposable, testController) + + // Register the rendered model provider + const renderedModelProvider = new RenderedModelProvider() + context.subscriptions.push( + vscode.workspace.registerTextDocumentContentProvider( + RenderedModelProvider.getScheme(), + renderedModelProvider, + ), + renderedModelProvider, + ) + + context.subscriptions.push( + vscode.commands.registerCommand( + 'sqlmesh.renderModel', + renderModel(lspClient, renderedModelProvider), + ), + ) + + const lineagePanel = new LineagePanel(context.extensionUri, lspClient) + context.subscriptions.push( + vscode.window.registerWebviewViewProvider( + LineagePanel.viewType, + lineagePanel, + ), + ) + + // Register the table diff command + context.subscriptions.push( + vscode.commands.registerCommand( + 'sqlmesh.showTableDiff', + showTableDiff(lspClient, context.extensionUri), + ), + ) + + // Re‑render model automatically when its source file is saved + context.subscriptions.push( + vscode.workspace.onDidSaveTextDocument(async document => { + if ( + renderedModelProvider.hasRenderedModelForSource( + document.uri.toString(true), + ) + ) { + await sleep(100) + await reRenderModelForSourceFile( + document.uri.toString(true), + lspClient, + renderedModelProvider, + ) + } + }), + ) + + // miscellaneous commands + context.subscriptions.push( + vscode.commands.registerCommand( + 'sqlmesh.format', + format(authProvider, lspClient, restartLsp), + ), + registerCommand('sqlmesh.restart', () => restartLsp(true)), + registerCommand('sqlmesh.stop', stop(lspClient)), + registerCommand('sqlmesh.printEnvironment', printEnvironment()), + ) + + context.subscriptions.push( + onDidChangePythonInterpreter(() => restartLsp()), + 
onDidChangeConfiguration(() => restartLsp()), + ) + + if (!lspClient.hasCompletionCapability()) { + context.subscriptions.push( + vscode.languages.registerCompletionItemProvider( + selector, + completionProvider(lspClient), + ), + ) + } + + traceInfo('SQLMesh extension activated') +} + +// This method is called when your extension is deactivated +export async function deactivate() { + try { + if (testControllerDisposable) { + testControllerDisposable.dispose() + } + if (lspClient) { + await lspClient.dispose() + } + } catch (e) { + traceError(`Error during deactivate: ${e}`) + } +} diff --git a/vscode/extension/src/lsp/custom.ts b/vscode/extension/src/lsp/custom.ts new file mode 100644 index 0000000000..c8999d5b00 --- /dev/null +++ b/vscode/extension/src/lsp/custom.ts @@ -0,0 +1,221 @@ +export interface AllModelsMethod { + method: 'sqlmesh/all_models' + request: AllModelsRequest + response: AllModelsResponse +} + +export interface RenderModelMethod { + method: 'sqlmesh/render_model' + request: RenderModelRequest + response: RenderModelResponse +} + +interface RenderModelRequest { + textDocumentUri: string +} + +interface RenderModelResponse extends BaseResponse { + models: RenderModelEntry[] +} + +export interface RenderModelEntry { + name: string + fqn: string + description: string | null | undefined + rendered_query: string +} + +export type CustomLSPMethods = + | AllModelsMethod + | AbstractAPICall + | RenderModelMethod + | AllModelsForRenderMethod + | SupportedMethodsMethod + | FormatProjectMethod + | ListWorkspaceTests + | ListDocumentTests + | RunTest + | GetEnvironmentsMethod + | GetTableDiffModelsMethod + +interface AllModelsRequest { + textDocument: { + uri: string + } +} + +interface AllModelsResponse extends BaseResponse { + models: string[] + keywords: string[] +} + +export interface AbstractAPICallRequest { + url: string + method: string + params: Record + body: Record +} + +export interface AbstractAPICall { + method: 'sqlmesh/api' + request: 
AbstractAPICallRequest + response: AbstractAPICallResponse +} + +type AbstractAPICallResponse = object & BaseResponse + +export interface AllModelsForRenderMethod { + method: 'sqlmesh/all_models_for_render' + request: AllModelsForRenderRequest + response: AllModelsForRenderResponse +} + +// eslint-disable-next-line @typescript-eslint/no-empty-object-type +interface AllModelsForRenderRequest {} + +interface AllModelsForRenderResponse extends BaseResponse { + models: ModelForRendering[] +} + +export interface ModelForRendering { + name: string + fqn: string + description: string | null | undefined + uri: string +} + +export interface SupportedMethodsMethod { + method: 'sqlmesh/supported_methods' + request: SupportedMethodsRequest + response: SupportedMethodsResponse +} + +// eslint-disable-next-line @typescript-eslint/no-empty-object-type +interface SupportedMethodsRequest {} + +interface SupportedMethodsResponse extends BaseResponse { + methods: CustomMethod[] +} + +interface CustomMethod { + name: string +} + +export interface FormatProjectMethod { + method: 'sqlmesh/format_project' + request: FormatProjectRequest + response: FormatProjectResponse +} + +// eslint-disable-next-line @typescript-eslint/no-empty-object-type +interface FormatProjectRequest {} + +// eslint-disable-next-line @typescript-eslint/no-empty-object-type +interface FormatProjectResponse extends BaseResponse {} + +interface BaseResponse { + response_error?: string +} + +export interface ListWorkspaceTests { + method: 'sqlmesh/list_workspace_tests' + request: ListWorkspaceTestsRequest + response: ListWorkspaceTestsResponse +} + +type ListWorkspaceTestsRequest = object + +interface Position { + line: number + character: number +} + +interface Range { + start: Position + end: Position +} + +interface TestEntry { + name: string + uri: string + range: Range +} + +interface ListWorkspaceTestsResponse extends BaseResponse { + tests: TestEntry[] +} + +export interface ListDocumentTests { + method: 
'sqlmesh/list_document_tests' + request: ListDocumentTestsRequest + response: ListDocumentTestsResponse +} + +export interface DocumentIdentifier { + uri: string +} + +export interface ListDocumentTestsRequest { + textDocument: DocumentIdentifier +} + +export interface ListDocumentTestsResponse extends BaseResponse { + tests: TestEntry[] +} + +export interface RunTest { + method: 'sqlmesh/run_test' + request: RunTestRequest + response: RunTestResponse +} + +export interface RunTestRequest { + textDocument: DocumentIdentifier + testName: string +} + +export interface RunTestResponse extends BaseResponse { + success: boolean + error_message?: string +} + +export interface GetEnvironmentsMethod { + method: 'sqlmesh/get_environments' + request: GetEnvironmentsRequest + response: GetEnvironmentsResponse +} + +// eslint-disable-next-line @typescript-eslint/no-empty-object-type +interface GetEnvironmentsRequest {} + +interface GetEnvironmentsResponse extends BaseResponse { + environments: Record + pinned_environments: string[] + default_target_environment: string +} + +interface EnvironmentInfo { + name: string + snapshots: string[] + start_at: string + plan_id: string +} + +export interface GetTableDiffModelsMethod { + method: 'sqlmesh/get_models' + request: GetModelsRequest + response: GetModelsResponse +} + +// eslint-disable-next-line @typescript-eslint/no-empty-object-type +interface GetModelsRequest {} + +interface GetModelsResponse extends BaseResponse { + models: ModelInfo[] +} + +interface ModelInfo { + name: string + fqn: string + description: string | null | undefined +} diff --git a/vscode/extension/src/lsp/lsp.ts b/vscode/extension/src/lsp/lsp.ts new file mode 100644 index 0000000000..1a11249853 --- /dev/null +++ b/vscode/extension/src/lsp/lsp.ts @@ -0,0 +1,269 @@ +import { window, OutputChannel, Disposable } from 'vscode' +import { + ServerOptions, + LanguageClientOptions, + LanguageClient, + TransportKind, +} from 'vscode-languageclient/node' +import { 
sqlmeshLspExec } from '../utilities/sqlmesh/sqlmesh' +import { err, isErr, ok, Result } from '@bus/result' +import { getWorkspaceFolders } from '../utilities/common/vscodeapi' +import { traceError, traceInfo } from '../utilities/common/log' +import { + ErrorType, + ErrorTypeGeneric, + ErrorTypeInvalidState, + ErrorTypeSQLMeshOutdated, +} from '../utilities/errors' +import { CustomLSPMethods } from './custom' +import { resolveProjectPath } from '../utilities/config' + +type SupportedMethodsState = + | { type: 'not-fetched' } + | { type: 'fetched'; methods: Set } + | { type: 'endpoint-not-supported' } // fallback for very old servers + +let outputChannel: OutputChannel | undefined + +export class LSPClient implements Disposable { + private client: LanguageClient | undefined + + /** Caches which custom methods the server supports */ + private supportedMethodsState: SupportedMethodsState = { type: 'not-fetched' } + + /** + * Remember whether the user explicitly stopped the client so that we do not + * auto‑start again until they ask for it. + */ + private explicitlyStopped = false + + /** True when a LanguageClient instance is alive. */ + private get isRunning(): boolean { + return this.client !== undefined + } + + /** + * Query whether the connected server advertises completion capability. + * (Transient helper kept for backwards‑compat reasons.) + */ + public hasCompletionCapability(): boolean { + if (!this.client) { + traceError('LSP client is not initialized') + return false + } + return ( + this.client.initializeResult?.capabilities?.completionProvider !== + undefined + ) + } + + /** Start the Language Client unless it is already running. 
*/ + public async start( + overrideStoppedByUser = false, + ): Promise> { + if (this.explicitlyStopped && !overrideStoppedByUser) { + traceInfo( + 'LSP client has been explicitly stopped by user, not starting again.', + ) + return ok(undefined) + } + + // Guard against duplicate initialisation + if (this.isRunning) { + traceInfo('LSP client already running – start() is a no‑op.') + return ok(undefined) + } + + // Ensure we have an output channel + if (!outputChannel) { + outputChannel = window.createOutputChannel('sqlmesh-lsp') + } + + // Resolve sqlmesh executable + const sqlmesh = await sqlmeshLspExec() + if (isErr(sqlmesh)) { + traceError( + `Failed to get sqlmesh_lsp_exec, ${JSON.stringify(sqlmesh.error)}`, + ) + return sqlmesh + } + + // We need at least one workspace + if (getWorkspaceFolders().length === 0) { + const msg = 'No workspace folders found' + traceError(msg) + return err({ type: 'generic', message: msg }) + } + + const workspacePath = sqlmesh.value.workspacePath + const serverOptions: ServerOptions = { + run: { + command: sqlmesh.value.bin, + transport: TransportKind.stdio, + options: { cwd: workspacePath, env: sqlmesh.value.env }, + args: sqlmesh.value.args, + }, + debug: { + command: sqlmesh.value.bin, + transport: TransportKind.stdio, + options: { cwd: workspacePath, env: sqlmesh.value.env }, + args: sqlmesh.value.args, + }, + } + const paths = resolveProjectPath(getWorkspaceFolders()[0]) + if (isErr(paths)) { + traceError(`Failed to resolve project paths: ${paths.error}`) + return err({ type: 'generic', message: paths.error }) + } + const clientOptions: LanguageClientOptions = { + documentSelector: [ + { scheme: 'file', pattern: '**/*.sql' }, + { scheme: 'file', pattern: '**/external_models.yaml' }, + { scheme: 'file', pattern: '**/external_models.yml' }, + ], + diagnosticCollectionName: 'sqlmesh', + outputChannel, + initializationOptions: paths.value.projectPaths + ? 
{ + project_paths: paths.value.projectPaths, + } + : null, + } + + traceInfo( + `Starting SQLMesh LSP (cwd=${workspacePath})\n` + + ` serverOptions=${JSON.stringify(serverOptions)}\n` + + ` clientOptions=${JSON.stringify(clientOptions)}`, + ) + + this.client = new LanguageClient( + 'sqlmesh-lsp', + 'SQLMesh Language Server', + serverOptions, + clientOptions, + ) + this.explicitlyStopped = false // user wanted it running again + await this.client.start() + return ok(undefined) + } + + /** Restart = stop + start. */ + public async restart( + overrideStoppedByUser = false, + ): Promise> { + await this.stop() // this also disposes + return this.start(overrideStoppedByUser) + } + + /** + * Stop the client (if running) and clean up all VS Code resources so that a + * future `start()` registers its commands without collisions. + */ + public async stop(stoppedByUser = false): Promise { + if (this.client) { + // Shut down the JSON‑RPC connection + await this.client + .stop() + .catch(err => traceError(`Error while stopping LSP: ${err}`)) + + // Unregister commands, code lenses, etc. 
+ await this.client.dispose() + + this.client = undefined + this.supportedMethodsState = { type: 'not-fetched' } + traceInfo('SQLMesh LSP client disposed.') + } + + if (stoppedByUser) { + this.explicitlyStopped = true + traceInfo('SQLMesh LSP client stopped by user.') + } + } + + public async dispose(): Promise { + await this.stop() + } + + private async fetchSupportedMethods(): Promise { + if (!this.client || this.supportedMethodsState.type !== 'not-fetched') + return + + const result = await this.internal_call_custom_method( + 'sqlmesh/supported_methods', + {}, + ) + if (isErr(result)) { + traceError(`Failed to fetch supported methods: ${result.error}`) + this.supportedMethodsState = { type: 'endpoint-not-supported' } + return + } + + const methodNames = new Set(result.value.methods.map(m => m.name)) + this.supportedMethodsState = { type: 'fetched', methods: methodNames } + traceInfo(`Fetched supported methods: ${[...methodNames].join(', ')}`) + } + + public async call_custom_method< + Method extends Exclude< + CustomLSPMethods['method'], + 'sqlmesh/supported_methods' + >, + Request extends Extract['request'], + Response extends Extract['response'], + >( + method: Method, + request: Request, + ): Promise< + Result< + Response, + ErrorTypeGeneric | ErrorTypeInvalidState | ErrorTypeSQLMeshOutdated + > + > { + if (!this.client) { + return err({ type: 'generic', message: 'LSP client not ready.' 
}) + } + + await this.fetchSupportedMethods() + + const supportedState = this.supportedMethodsState + if ( + supportedState.type === 'fetched' && + !supportedState.methods.has(method) + ) { + return err({ + type: 'sqlmesh_outdated', + message: `Method '${method}' is not supported by this LSP server.`, + }) + } + + const response = await this.internal_call_custom_method( + method, + request as any, + ) + if (isErr(response)) { + return err({ type: 'generic', message: response.error }) + } + return ok(response.value as Response) + } + + /** + * Low‑level helper that sends a raw JSON‑RPC request without any feature checks. + */ + public async internal_call_custom_method< + Method extends CustomLSPMethods['method'], + Request extends Extract['request'], + Response extends Extract['response'], + >(method: Method, request: Request): Promise> { + if (!this.client) return err('lsp client not ready') + + try { + const result = await this.client.sendRequest(method, request) + if ((result as any).response_error) + return err((result as any).response_error) + return ok(result) + } catch (error) { + traceError(`LSP '${method}' request failed: ${JSON.stringify(error)}`) + return err(JSON.stringify(error)) + } + } +} diff --git a/vscode/extension/src/providers/renderedModelProvider.ts b/vscode/extension/src/providers/renderedModelProvider.ts new file mode 100644 index 0000000000..017e071db6 --- /dev/null +++ b/vscode/extension/src/providers/renderedModelProvider.ts @@ -0,0 +1,141 @@ +import * as vscode from 'vscode' +import { RenderModelEntry } from '../lsp/custom' + +interface RenderedModelInfo { + content: string + sourceUri?: string + modelInfo?: RenderModelEntry +} + +/** + * Content provider for read-only rendered SQL models + */ +export class RenderedModelProvider + implements vscode.TextDocumentContentProvider +{ + private static readonly scheme = 'sqlmesh-rendered' + + // Single map containing all rendered model information + private renderedModels = new Map() + // Track 
which source file URIs are associated with rendered models + private sourceToRenderedUri = new Map() + + // Event emitter for content changes + private _onDidChange = new vscode.EventEmitter() + readonly onDidChange = this._onDidChange.event + + /** + * Provide text content for a given URI + */ + provideTextDocumentContent(uri: vscode.Uri): string { + const key = uri.toString() + const modelInfo = this.renderedModels.get(key) + return modelInfo?.content || '' + } + + /** + * Store rendered model content and create a URI for it + */ + storeRenderedModel( + modelName: string, + content: string, + sourceUri?: string, + modelInfo?: RenderModelEntry, + ): vscode.Uri { + const fileName = `${modelName} (rendered)` + // Add a timestamp to make the URI unique for each render + const timestamp = Date.now() + // Use vscode.Uri.from for proper URI construction + const uri = vscode.Uri.from({ + scheme: RenderedModelProvider.scheme, + path: fileName, + fragment: timestamp.toString(), + }) + + const uriString = uri.toString() + + // Store all information in single map + this.renderedModels.set(uriString, { + content, + sourceUri, + modelInfo, + }) + + // Track the association between a source file and the rendered model + if (sourceUri) { + // Remove any existing mapping for this source file + const existingRenderedUri = this.sourceToRenderedUri.get(sourceUri) + if (existingRenderedUri) { + this.renderedModels.delete(existingRenderedUri.toString()) + } + + this.sourceToRenderedUri.set(sourceUri, uri) + } + + this._onDidChange.fire(uri) + return uri + } + + /** + * Update an existing rendered model with new content + */ + updateRenderedModel(uri: vscode.Uri, content: string): void { + const uriString = uri.toString() + const existingInfo = this.renderedModels.get(uriString) + if (existingInfo) { + this.renderedModels.set(uriString, { + ...existingInfo, + content, + }) + } + this._onDidChange.fire(uri) + } + + /** + * Get the rendered URI for a given source file URI + */ + 
getRenderedUriForSource(sourceUri: string): vscode.Uri | undefined { + return this.sourceToRenderedUri.get(sourceUri) + } + + /** + * Get the source URI for a given rendered model URI + */ + getSourceUriForRendered(renderedUri: string): string | undefined { + const modelInfo = this.renderedModels.get(renderedUri) + return modelInfo?.sourceUri + } + + /** + * Get the model information for a given rendered model URI + */ + getModelInfoForRendered( + renderedUri: vscode.Uri, + ): RenderModelEntry | undefined { + const modelInfo = this.renderedModels.get(renderedUri.toString()) + return modelInfo?.modelInfo + } + + /** + * Check if a source file has an associated rendered model + */ + hasRenderedModelForSource(sourceUri: string): boolean { + return this.sourceToRenderedUri.has(sourceUri) + } + + /** + * Get the URI scheme for rendered models + */ + static getScheme(): string { + return this.scheme + } + + /** + * Clean up old rendered models to prevent memory leaks + */ + dispose() { + this.renderedModels.clear() + this.sourceToRenderedUri.clear() + this._onDidChange.dispose() + } +} diff --git a/vscode/extension/src/tests/tests.ts b/vscode/extension/src/tests/tests.ts new file mode 100644 index 0000000000..dd3503165c --- /dev/null +++ b/vscode/extension/src/tests/tests.ts @@ -0,0 +1,155 @@ +import * as vscode from 'vscode' +import path from 'path' +import { LSPClient } from '../lsp/lsp' +import { isErr } from '@bus/result' +import { Disposable } from 'vscode' + +export const controller = vscode.tests.createTestController( + 'sqlmeshTests', + 'SQLMesh Tests', +) + +export const setupTestController = (lsp: LSPClient): Disposable => { + controller.resolveHandler = async test => { + console.log('Resolving test:', test?.id) + const uri = test?.uri + if (uri) { + await discoverDocumentTests(uri.toString()) + } else { + await discoverWorkspaceTests() + } + } + + // Discover tests immediately when the controller is set up + // This is useful for the initial load of tests in 
the workspace + // eslint-disable-next-line @typescript-eslint/no-floating-promises + discoverWorkspaceTests() + + controller.createRunProfile( + 'Run', + vscode.TestRunProfileKind.Run, + request => runTests(request), + true, + ) + + async function discoverDocumentTests(uri: string) { + const result = await lsp.call_custom_method('sqlmesh/list_document_tests', { + textDocument: { uri }, + }) + if (isErr(result)) { + vscode.window.showErrorMessage( + `Failed to list SQLMesh tests: ${result.error.message}`, + ) + return + } + const fileItem = controller.items.get(uri) + if (!fileItem) { + vscode.window.showErrorMessage(`No test item found for document: ${uri}`) + return + } + fileItem.children.replace([]) + for (const test of result.value.tests) { + const testItem = controller.createTestItem( + test.name, + test.name, + vscode.Uri.parse(test.uri), + ) + const range = test.range + testItem.range = new vscode.Range( + new vscode.Position(range.start.line, range.start.character), + new vscode.Position(range.end.line, range.end.character), + ) + fileItem.children.add(testItem) + } + } + + async function discoverWorkspaceTests() { + const result = await lsp.call_custom_method( + 'sqlmesh/list_workspace_tests', + {}, + ) + if (isErr(result)) { + vscode.window.showErrorMessage( + `Failed to list SQLMesh tests: ${result.error.message}`, + ) + return + } + controller.items.replace([]) + const files = new Map() + for (const entry of result.value.tests) { + const uri = vscode.Uri.parse(entry.uri) + let fileItem = files.get(uri.toString()) + if (!fileItem) { + fileItem = controller.createTestItem( + uri.toString(), + path.basename(uri.fsPath), + uri, + ) + // THIS IS WHERE YOU RESOLVE THE RANGE + fileItem.canResolveChildren = true + files.set(uri.toString(), fileItem) + controller.items.add(fileItem) + } + const testId = `${uri.toString()}::${entry.name}` + const testItem = controller.createTestItem(testId, entry.name, uri) + fileItem.children.add(testItem) + } + } + + async 
function runTests(request: vscode.TestRunRequest) { + const run = controller.createTestRun(request) + + const tests: vscode.TestItem[] = [] + const collect = (item: vscode.TestItem) => { + if (item.children.size === 0) tests.push(item) + item.children.forEach(collect) + } + + if (request.include) request.include.forEach(collect) + else controller.items.forEach(collect) + + for (const test of tests) { + run.started(test) + const startTime = Date.now() + const uri = test.uri + if (uri === undefined) { + run.failed(test, new vscode.TestMessage('Test item has no URI')) + continue + } + const response = await lsp.call_custom_method('sqlmesh/run_test', { + textDocument: { uri: uri.toString() }, + testName: test.id, + }) + if (isErr(response)) { + run.failed(test, new vscode.TestMessage(response.error.message)) + continue + } else { + const result = response.value + const duration = Date.now() - startTime + if (result.success) { + run.passed(test, duration) + } else { + run.failed( + test, + new vscode.TestMessage(result.error_message ?? 'Test failed'), + duration, + ) + } + } + } + run.end() + } + + // onChangeFile of yaml file reload the tests + return vscode.workspace.onDidChangeTextDocument(async event => { + if (event.document.languageId === 'yaml') { + const uri = event.document.uri.toString() + const testItem = controller.items.get(uri) + if (testItem) { + await discoverDocumentTests(uri) + } else { + await discoverWorkspaceTests() + } + } + }) +} diff --git a/vscode/extension/src/utilities/common/constants.ts b/vscode/extension/src/utilities/common/constants.ts new file mode 100644 index 0000000000..274b60bc9a --- /dev/null +++ b/vscode/extension/src/utilities/common/constants.ts @@ -0,0 +1,24 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import * as path from 'path' + +const folderName = path.basename(__dirname) +export const EXTENSION_ROOT_DIR = + folderName === 'common' + ? 
path.dirname(path.dirname(__dirname)) + : path.dirname(__dirname) +export const BUNDLED_PYTHON_SCRIPTS_DIR = path.join( + EXTENSION_ROOT_DIR, + 'bundled', +) +export const SERVER_SCRIPT_PATH = path.join( + BUNDLED_PYTHON_SCRIPTS_DIR, + 'tool', + `lsp_server.py`, +) +export const DEBUG_SERVER_SCRIPT_PATH = path.join( + BUNDLED_PYTHON_SCRIPTS_DIR, + 'tool', + `_debug_server.py`, +) diff --git a/vscode/extension/src/utilities/common/log.ts b/vscode/extension/src/utilities/common/log.ts new file mode 100644 index 0000000000..9498be3432 --- /dev/null +++ b/vscode/extension/src/utilities/common/log.ts @@ -0,0 +1,59 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +import * as util from 'util' +import { Disposable, LogOutputChannel } from 'vscode' + +type Arguments = unknown[] +class OutputChannelLogger { + constructor(private readonly channel: LogOutputChannel) {} + + public traceLog(...data: Arguments): void { + this.channel.appendLine(util.format(...data)) + } + + public traceError(...data: Arguments): void { + this.channel.error(util.format(...data)) + } + + public traceWarn(...data: Arguments): void { + this.channel.warn(util.format(...data)) + } + + public traceInfo(...data: Arguments): void { + this.channel.info(util.format(...data)) + } + + public traceVerbose(...data: Arguments): void { + this.channel.debug(util.format(...data)) + } +} + +let channel: OutputChannelLogger | undefined +export function registerLogger(logChannel: LogOutputChannel): Disposable { + channel = new OutputChannelLogger(logChannel) + return { + dispose: () => { + channel = undefined + }, + } +} + +export function traceLog(...args: Arguments): void { + channel?.traceLog(...args) +} + +export function traceError(...args: Arguments): void { + channel?.traceError(...args) +} + +export function traceWarn(...args: Arguments): void { + channel?.traceWarn(...args) +} + +export function traceInfo(...args: Arguments): void { + 
channel?.traceInfo(...args) +} + +export function traceVerbose(...args: Arguments): void { + channel?.traceVerbose(...args) +} diff --git a/vscode/extension/src/utilities/common/python.ts b/vscode/extension/src/utilities/common/python.ts new file mode 100644 index 0000000000..b30e2e91f7 --- /dev/null +++ b/vscode/extension/src/utilities/common/python.ts @@ -0,0 +1,155 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import { commands, Disposable, Event, EventEmitter, Uri } from 'vscode' +import { traceError, traceLog } from './log' +import { PythonExtension, ResolvedEnvironment } from '@vscode/python-extension' +import path from 'path' +import { err, ok, Result } from '@bus/result' +import * as vscode from 'vscode' + +export interface IInterpreterDetails { + path?: string[] + resource?: Uri + isVirtualEnvironment?: boolean + binPath?: string +} + +const onDidChangePythonInterpreterEvent = + new EventEmitter() + +export const onDidChangePythonInterpreter: Event = + onDidChangePythonInterpreterEvent.event + +let _api: PythonExtension | undefined + +async function getPythonExtensionAPI(): Promise { + if (_api) { + return _api + } + _api = await PythonExtension.api() + return _api +} + +export async function initializePython( + disposables: Disposable[], +): Promise { + try { + const api = await getPythonExtensionAPI() + + if (api) { + disposables.push( + api.environments.onDidChangeActiveEnvironmentPath(async e => { + const environment = await api.environments.resolveEnvironment(e.path) + const isVirtualEnv = environment?.environment !== undefined + // Get the directory of the Python executable for virtual environments + const pythonDir = environment?.executable.uri + ? path.dirname(environment.executable.uri.fsPath) + : undefined + + onDidChangePythonInterpreterEvent.fire({ + path: [e.path], + resource: e.resource?.uri, + isVirtualEnvironment: isVirtualEnv, + binPath: isVirtualEnv ? 
pythonDir : undefined, + }) + }), + ) + + traceLog('Waiting for interpreter from python extension.') + onDidChangePythonInterpreterEvent.fire(await getInterpreterDetails()) + } + } catch (error) { + traceError('Error initializing python: ', error) + } +} + +export async function resolveInterpreter( + interpreter: string[], +): Promise { + const api = await getPythonExtensionAPI() + return api?.environments.resolveEnvironment(interpreter[0]) +} + +export async function getInterpreterDetails( + resource?: Uri, +): Promise { + const api = await getPythonExtensionAPI() + const environment = await api?.environments.resolveEnvironment( + api?.environments.getActiveEnvironmentPath(resource), + ) + if (environment?.executable.uri && checkVersion(environment)) { + const isVirtualEnv = environment.environment !== undefined + // Get the directory of the Python executable + const pythonDir = path.dirname(environment?.executable.uri.fsPath) + + return { + path: [environment?.executable.uri.fsPath], + resource, + isVirtualEnvironment: isVirtualEnv, + // For virtual environments, we need to point directly to the bin directory + // rather than constructing it from the environment folder + binPath: isVirtualEnv ? 
pythonDir : undefined, + } + } + return { path: undefined, resource } +} + +export async function getDebuggerPath(): Promise { + const api = await getPythonExtensionAPI() + return api?.debug.getDebuggerPackagePath() +} + +export async function runPythonExtensionCommand( + command: string, + ...rest: any[] +) { + await getPythonExtensionAPI() + return await commands.executeCommand(command, ...rest) +} + +export function checkVersion( + resolved: ResolvedEnvironment | undefined, +): boolean { + const version = resolved?.version + if (version?.major === 3 && version?.minor >= 8) { + return true + } + traceError( + `Python version ${version?.major}.${version?.minor} is not supported.`, + ) + traceError(`Selected python path: ${resolved?.executable.uri?.fsPath}`) + traceError('Supported versions are 3.8 and above.') + return false +} + +/** + * getPythonEnvVariables returns the environment variables for the current python interpreter. + * + * @returns The environment variables for the current python interpreter. + */ +export async function getPythonEnvVariables(): Promise< + Result, string> +> { + const api = await getPythonExtensionAPI() + if (!api) { + return err('Python extension API not found') + } + + const workspaces = vscode.workspace.workspaceFolders + if (!workspaces) { + return ok({}) + } + const out: Record = {} + for (const workspace of workspaces) { + const envVariables = api.environments.getEnvironmentVariables(workspace.uri) + if (envVariables) { + for (const [key, value] of Object.entries(envVariables)) { + if (value) { + out[key] = value + } + } + } + } + return ok(out) +} diff --git a/vscode/extension/src/utilities/common/settings.ts b/vscode/extension/src/utilities/common/settings.ts new file mode 100644 index 0000000000..000dfc5534 --- /dev/null +++ b/vscode/extension/src/utilities/common/settings.ts @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +import { + ConfigurationChangeEvent, + ConfigurationScope, + WorkspaceConfiguration, + WorkspaceFolder, +} from 'vscode' +import { getInterpreterDetails } from './python' +import { getConfiguration, getWorkspaceFolders } from './vscodeapi' + +export interface ISettings { + cwd: string + workspace: string + args: string[] + path: string[] + interpreter: string[] + importStrategy: string + showNotifications: string +} + +export function getExtensionSettings( + namespace: string, + includeInterpreter?: boolean, +): Promise { + return Promise.all( + getWorkspaceFolders().map(w => + getWorkspaceSettings(namespace, w, includeInterpreter), + ), + ) +} + +function resolveVariables( + value: string[], + workspace?: WorkspaceFolder, +): string[] { + const substitutions = new Map() + const home = process.env.HOME || process.env.USERPROFILE + if (home) { + substitutions.set('${userHome}', home) + } + if (workspace) { + substitutions.set('${workspaceFolder}', workspace.uri.fsPath) + } + substitutions.set('${cwd}', process.cwd()) + getWorkspaceFolders().forEach(w => { + substitutions.set('${workspaceFolder:' + w.name + '}', w.uri.fsPath) + }) + + return value.map(s => { + for (const [k, v] of substitutions) { + s = s.replace(k, v) + } + return s + }) +} + +export function getInterpreterFromSetting( + namespace: string, + scope?: ConfigurationScope, +) { + const config = getConfiguration(namespace, scope) + return config.get('interpreter') +} + +export async function getWorkspaceSettings( + namespace: string, + workspace: WorkspaceFolder, + includeInterpreter?: boolean, +): Promise { + const config = getConfiguration(namespace, workspace.uri) + + let interpreter: string[] = [] + if (includeInterpreter) { + interpreter = getInterpreterFromSetting(namespace, workspace) ?? [] + if (interpreter.length === 0) { + interpreter = (await getInterpreterDetails(workspace.uri)).path ?? 
[] + } + } + + const workspaceSetting = { + cwd: workspace.uri.fsPath, + workspace: workspace.uri.toString(), + args: resolveVariables(config.get(`args`) ?? [], workspace), + path: resolveVariables(config.get(`path`) ?? [], workspace), + interpreter: resolveVariables(interpreter, workspace), + importStrategy: config.get(`importStrategy`) ?? 'useBundled', + showNotifications: config.get(`showNotifications`) ?? 'off', + } + return workspaceSetting +} + +function getGlobalValue( + config: WorkspaceConfiguration, + key: string, + defaultValue: T, +): T { + const inspect = config.inspect(key) + return inspect?.globalValue ?? inspect?.defaultValue ?? defaultValue +} + +export async function getGlobalSettings( + namespace: string, + includeInterpreter?: boolean, +): Promise { + const config = getConfiguration(namespace) + + let interpreter: string[] = [] + if (includeInterpreter) { + interpreter = getGlobalValue(config, 'interpreter', []) + if (interpreter === undefined || interpreter.length === 0) { + interpreter = (await getInterpreterDetails()).path ?? 
[] + } + } + + const setting = { + cwd: process.cwd(), + workspace: process.cwd(), + args: getGlobalValue(config, 'args', []), + path: getGlobalValue(config, 'path', []), + interpreter: interpreter, + importStrategy: getGlobalValue( + config, + 'importStrategy', + 'useBundled', + ), + showNotifications: getGlobalValue( + config, + 'showNotifications', + 'off', + ), + } + return setting +} + +export function checkIfConfigurationChanged( + e: ConfigurationChangeEvent, + namespace: string, +): boolean { + const settings = [ + `${namespace}.args`, + `${namespace}.path`, + `${namespace}.interpreter`, + `${namespace}.importStrategy`, + `${namespace}.showNotifications`, + ] + const changed = settings.map(s => e.affectsConfiguration(s)) + return changed.includes(true) +} diff --git a/vscode/extension/src/utilities/common/utilities.ts b/vscode/extension/src/utilities/common/utilities.ts new file mode 100644 index 0000000000..fc7088d49e --- /dev/null +++ b/vscode/extension/src/utilities/common/utilities.ts @@ -0,0 +1,75 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +import * as fs from 'fs-extra' +import * as path from 'path' +import { LogLevel, Uri, WorkspaceFolder } from 'vscode' +import { Trace } from 'vscode-jsonrpc/node' +import { getWorkspaceFolders } from './vscodeapi' + +function logLevelToTrace(logLevel: LogLevel): Trace { + switch (logLevel) { + case LogLevel.Error: + case LogLevel.Warning: + case LogLevel.Info: + return Trace.Messages + + case LogLevel.Debug: + case LogLevel.Trace: + return Trace.Verbose + + case LogLevel.Off: + default: + return Trace.Off + } +} + +export function getLSClientTraceLevel( + channelLogLevel: LogLevel, + globalLogLevel: LogLevel, +): Trace { + if (channelLogLevel === LogLevel.Off) { + return logLevelToTrace(globalLogLevel) + } + if (globalLogLevel === LogLevel.Off) { + return logLevelToTrace(channelLogLevel) + } + return logLevelToTrace( + channelLogLevel <= globalLogLevel ? 
channelLogLevel : globalLogLevel, + ) +} + +export async function getProjectRoot(): Promise { + const workspaces: readonly WorkspaceFolder[] = getWorkspaceFolders() + if (workspaces.length === 0) { + return { + uri: Uri.file(process.cwd()), + name: path.basename(process.cwd()), + index: 0, + } + } else if (workspaces.length === 1) { + return workspaces[0] + } else { + let rootWorkspace = workspaces[0] + let root = undefined + for (const w of workspaces) { + if (await fs.pathExists(w.uri.fsPath)) { + root = w.uri.fsPath + rootWorkspace = w + break + } + } + + for (const w of workspaces) { + if ( + root && + root.length > w.uri.fsPath.length && + (await fs.pathExists(w.uri.fsPath)) + ) { + root = w.uri.fsPath + rootWorkspace = w + } + } + return rootWorkspace + } +} diff --git a/vscode/extension/src/utilities/common/vscodeapi.ts b/vscode/extension/src/utilities/common/vscodeapi.ts new file mode 100644 index 0000000000..6687d60933 --- /dev/null +++ b/vscode/extension/src/utilities/common/vscodeapi.ts @@ -0,0 +1,50 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. 
+ +import { + commands, + ConfigurationScope, + Disposable, + LogOutputChannel, + Uri, + window, + workspace, + WorkspaceConfiguration, + WorkspaceFolder, +} from 'vscode' + +export function createOutputChannel(name: string): LogOutputChannel { + return window.createOutputChannel(name, { log: true }) +} + +export function getConfiguration( + config: string, + scope?: ConfigurationScope, +): WorkspaceConfiguration { + return workspace.getConfiguration(config, scope) +} + +export function registerCommand( + command: string, + callback: (...args: any[]) => any, + thisArg?: any, +): Disposable { + return commands.registerCommand(command, callback, thisArg) +} + +export const { onDidChangeConfiguration } = workspace + +export function isVirtualWorkspace(): boolean { + const isVirtual = workspace.workspaceFolders?.every( + f => f.uri.scheme !== 'file', + ) + return !!isVirtual +} + +export function getWorkspaceFolders(): readonly WorkspaceFolder[] { + return workspace.workspaceFolders ?? [] +} + +export function getWorkspaceFolder(uri: Uri): WorkspaceFolder | undefined { + return workspace.getWorkspaceFolder(uri) +} diff --git a/vscode/extension/src/utilities/config.ts b/vscode/extension/src/utilities/config.ts new file mode 100644 index 0000000000..53c2662612 --- /dev/null +++ b/vscode/extension/src/utilities/config.ts @@ -0,0 +1,146 @@ +import { workspace, WorkspaceFolder } from 'vscode' +import path from 'path' +import fs from 'fs' +import { Result, err, isErr, ok } from '@bus/result' +import { traceVerbose, traceInfo } from './common/log' +import { parse } from 'shell-quote' +import { z } from 'zod' + +const sqlmeshConfigurationSchema = z.object({ + projectPaths: z.array(z.string()), + lspEntryPoint: z.string(), +}) + +export type SqlmeshConfiguration = z.infer + +/** + * Get the SQLMesh configuration from VS Code settings. 
+ * + * @returns The SQLMesh configuration + */ +function getSqlmeshConfiguration(): SqlmeshConfiguration { + const config = workspace.getConfiguration('sqlmesh') + const projectPaths = config.get('projectPaths', []) + const lspEntryPoint = config.get('lspEntrypoint', '') + const parsed = sqlmeshConfigurationSchema.safeParse({ + projectPaths, + lspEntryPoint, + }) + if (!parsed.success) { + throw new Error( + `Invalid SQLMesh configuration: ${JSON.stringify(parsed.error)}`, + ) + } + return parsed.data +} + +const stringsArray = z.array(z.string()) + +/** + * Get the SQLMesh LSP entry point from VS Code settings. undefined if not set + * it's expected to be a string in the format "command arg1 arg2 ...". + */ +export function getSqlmeshLspEntryPoint(): + | { + entrypoint: string + args: string[] + } + | undefined { + const config = getSqlmeshConfiguration() + if (config.lspEntryPoint === '') { + return undefined + } + // Split the entry point into command and arguments + const parts = parse(config.lspEntryPoint) + const parsed = stringsArray.safeParse(parts) + if (!parsed.success) { + throw new Error( + `Invalid lspEntrypoint configuration: ${config.lspEntryPoint}. Expected a + string in the format "command arg1 arg2 ...".`, + ) + } + const entrypoint = parsed.data[0] + const args = parsed.data.slice(1) + return { entrypoint, args } +} + +/** + * Validate and resolve the project paths from configuration. + * If no project path is configured, use the workspace folder. + * If the project path is configured, it must be a directory that contains a SQLMesh project. 
+ * + * @param workspaceFolder The current workspace folder + * @returns A Result containing the resolved project paths or an error + */ +export function resolveProjectPath(workspaceFolder: WorkspaceFolder): Result< + { + projectPaths: string[] | undefined + workspaceFolder: string + }, + string +> { + const config = getSqlmeshConfiguration() + + if (config.projectPaths.length === 0) { + // If no project path is configured, use the workspace folder + traceVerbose('No project path configured, using workspace folder') + return ok({ + workspaceFolder: workspaceFolder.uri.fsPath, + projectPaths: undefined, + }) + } + + const resolvedPaths: string[] = [] + for (const projectPath of config.projectPaths) { + const result = resolveSingleProjectPath(workspaceFolder, projectPath) + if (isErr(result)) { + return result + } + resolvedPaths.push(result.value) + } + return ok({ + projectPaths: resolvedPaths, + workspaceFolder: workspaceFolder.uri.fsPath, + }) +} + +function resolveSingleProjectPath( + workspaceFolder: WorkspaceFolder, + projectPath: string, +): Result { + let resolvedPath: string + + // Check if the path is absolute + if (path.isAbsolute(projectPath)) { + resolvedPath = projectPath + } else { + // Resolve relative path from workspace root + resolvedPath = path.join(workspaceFolder.uri.fsPath, projectPath) + } + + // Normalize the path + resolvedPath = path.normalize(resolvedPath) + + // Validate that the path exists + if (!fs.existsSync(resolvedPath)) { + return err(`Configured project path does not exist: ${resolvedPath}`) + } + + // Validate that it's a directory + const stats = fs.statSync(resolvedPath) + if (!stats.isDirectory()) { + return err(`Configured project path is not a directory: ${resolvedPath}`) + } + + // Check if it contains SQLMesh project files (config.yaml, config.yml, or config.py) + const configFiles = ['config.yaml', 'config.yml', 'config.py'] + const hasConfigFile = configFiles.some(file => + fs.existsSync(path.join(resolvedPath, file)), 
+ ) + if (!hasConfigFile) { + traceInfo(`Warning: No SQLMesh configuration file found in ${resolvedPath}`) + } + + traceVerbose(`Using project path: ${resolvedPath}`) + return ok(resolvedPath) +} diff --git a/vscode/extension/src/utilities/errors.ts b/vscode/extension/src/utilities/errors.ts new file mode 100644 index 0000000000..f062d89369 --- /dev/null +++ b/vscode/extension/src/utilities/errors.ts @@ -0,0 +1,187 @@ +import { window } from 'vscode' +import { AuthenticationProviderTobikoCloud } from '../auth/auth' +import { signIn } from '../commands/signin' +import { traceInfo } from './common/log' + +/** + * Represents different types of errors that can occur in the application. + */ +export type ErrorType = + | ErrorTypeGeneric + | { type: 'not_signed_in' } + | { type: 'sqlmesh_not_found' } + | { type: 'sqlmesh_lsp_not_found' } + // tcloud_bin_not_found is used when the tcloud executable is not found. This is likely to happen if the user + // opens a project that has a `tcloud.yaml` file but doesn't have tcloud installed. + | { type: 'tcloud_bin_not_found' } + | SqlmeshLspDependenciesMissingError + | ErrorTypeInvalidState + | ErrorTypeSQLMeshOutdated + +/** + * ErrorTypeSQLMeshOutdated is used when the SQLMesh version is outdated. The + * message should explain the problem, but the suggestion to update SQLMesh is + * handled at the place where the error is shown. + */ +export interface ErrorTypeSQLMeshOutdated { + type: 'sqlmesh_outdated' + /** + * A message that describes the outdated SQLMesh version, it should not talk about + * updating SQLMesh. This is done at the place where the error is handled. + */ + message: string +} + +/** + * ErrorTypeInvalidState is used when the state of the application is invalid state. + * They should never be thrown by the application unless there is a bug in the code. + * The shown message should be generic and not contain any sensitive information but + * asks the user to report the issue to the developers. 
+ */ +export interface ErrorTypeInvalidState { + type: 'invalid_state' + /** + * A message that describes the invalid state, it should not talk about reporting + * the issue to the developers. This is done at the place where the error is + * handled. + */ + message: string +} + +/** + * ErrorTypeGeneric is a generic error type that can be used to represent any error with a message. + */ +export interface ErrorTypeGeneric { + type: 'generic' + message: string +} + +/** + * SqlmeshLspDependenciesMissingError is used when the sqlmesh_lsp is found but + * the lsp extras are missing. + */ +interface SqlmeshLspDependenciesMissingError { + type: 'sqlmesh_lsp_dependencies_missing' + is_missing_pygls: boolean + is_missing_lsprotocol: boolean + is_tobiko_cloud: boolean +} + +export async function handleError( + authProvider: AuthenticationProviderTobikoCloud, + restartLsp: () => Promise, + error: ErrorType, + genericErrorPrefix?: string, +): Promise { + traceInfo('handleError', error) + switch (error.type) { + case 'invalid_state': + await window.showErrorMessage( + `Invalid state: ${error.message}. Please report this issue to the developers.`, + ) + return + case 'sqlmesh_outdated': + await window.showErrorMessage( + `SQLMesh itself is outdated. Please update SQLMesh to the latest version to use this feature. 
${error.message}`, + ) + return + case 'not_signed_in': + return handleNotSignedInError(authProvider, restartLsp) + case 'sqlmesh_not_found': + return handleSqlmeshNotFoundError() + case 'sqlmesh_lsp_not_found': + return handleSqlmeshLspNotFoundError() + case 'sqlmesh_lsp_dependencies_missing': + return handleSqlmeshLspDependenciesMissingError(error) + case 'tcloud_bin_not_found': + return handleTcloudBinNotFoundError() + case 'generic': + if (genericErrorPrefix) { + await window.showErrorMessage(`${genericErrorPrefix}: ${error.message}`) + } else { + await window.showErrorMessage(`An error occurred: ${error.message}`) + } + return + } +} + +/** + * Handles the case where the user is not signed in to Tobiko Cloud. + * @param authProvider - The authentication provider to use for signing in. + */ +const handleNotSignedInError = async ( + authProvider: AuthenticationProviderTobikoCloud, + restartLsp: () => Promise, +): Promise => { + traceInfo('handleNotSginedInError') + const result = await window.showInformationMessage( + 'Please sign in to Tobiko Cloud to use SQLMesh', + 'Sign In', + ) + if (result === 'Sign In') { + await signIn(authProvider, restartLsp)() + } +} + +/** + * Handles the case where the sqlmesh executable is not found. + */ +const handleSqlmeshNotFoundError = async (): Promise => { + traceInfo('handleSqlmeshNotFoundError') + await window.showErrorMessage('SQLMesh not found, please check installation') +} + +/** + * Handles the case where the sqlmesh_lsp is not found. + */ +const handleSqlmeshLspNotFoundError = async (): Promise => { + traceInfo('handleSqlmeshLspNotFoundError') + await window.showErrorMessage( + 'SQLMesh LSP not found, please check installation', + ) +} + +/** + * Handles the case where the sqlmesh_lsp is found but the lsp extras are missing. 
+ */ +const handleSqlmeshLspDependenciesMissingError = async ( + error: SqlmeshLspDependenciesMissingError, +): Promise => { + traceInfo('handleSqlmeshLspDependenciesMissingError') + if (error.is_tobiko_cloud) { + await window.showErrorMessage( + 'LSP dependencies missing, make sure to include `lsp` in the `extras` section of your `tcloud.yaml` file.', + ) + } else { + const install = await window.showErrorMessage( + 'LSP dependencies missing, make sure to install `sqlmesh[lsp]`.', + 'Install', + ) + if (install === 'Install') { + const terminal = window.createTerminal({ + name: 'SQLMesh LSP Install', + hideFromUser: false, + }) + terminal.show() + terminal.sendText("pip install 'sqlmesh[lsp]'", false) + } + } +} + +/** + * Handles the case where the tcloud executable is not found. + */ +const handleTcloudBinNotFoundError = async (): Promise => { + const result = await window.showErrorMessage( + 'tcloud executable not found, please check installation', + 'Install', + ) + if (result === 'Install') { + const terminal = window.createTerminal({ + name: 'Tcloud Install', + hideFromUser: false, + }) + terminal.show() + terminal.sendText('pip install tcloud', false) + } +} diff --git a/vscode/extension/src/utilities/exec.ts b/vscode/extension/src/utilities/exec.ts new file mode 100644 index 0000000000..4748785b2b --- /dev/null +++ b/vscode/extension/src/utilities/exec.ts @@ -0,0 +1,69 @@ +import { exec, ExecOptions } from 'node:child_process' +import { traceInfo } from './common/log' + +export interface ExecResult { + exitCode: number + stdout: string + stderr: string +} + +export async function execAsync( + command: string, + args: string[] = [], + options: ExecOptions & { signal?: AbortSignal } = {}, +): Promise { + const fullCmd = `${command} ${args.join(' ')}` + traceInfo(`Executing command: ${fullCmd} in ${options.cwd}`) + + try { + const result = await execAsyncCore(command, args, options) + traceInfo( + `Command ${fullCmd} exited with code ${result.exitCode}; 
stdout: ${result.stdout}; stderr: ${result.stderr}`, + ) + return result + } catch (err) { + if ((err as any)?.name === 'AbortError') { + traceInfo(`Command ${fullCmd} was cancelled by AbortController`) + } else { + traceInfo(`Command ${fullCmd} failed: ${(err as Error).message}`) + } + throw err // keep original error semantics + } +} + +function execAsyncCore( + command: string, + args: string[], + options: ExecOptions & { signal?: AbortSignal } = {}, +): Promise { + return new Promise((resolve, reject) => { + const child = exec( + `${command} ${args.join(' ')}`, + options, + (error, stdout, stderr) => { + if (error) { + // Forward AbortError unchanged so callers can detect cancellation + if ((error as NodeJS.ErrnoException).name === 'AbortError') { + reject(error) + } else { + resolve({ + exitCode: typeof error.code === 'number' ? error.code : 1, + stdout, + stderr, + }) + } + return + } + + resolve({ + exitCode: child.exitCode ?? 0, + stdout, + stderr, + }) + }, + ) + + // surface “spawn failed” errors that occur before the callback + child.once('error', reject) + }) +} diff --git a/vscode/extension/src/utilities/isCodespaces.ts b/vscode/extension/src/utilities/isCodespaces.ts new file mode 100644 index 0000000000..16d9d441b8 --- /dev/null +++ b/vscode/extension/src/utilities/isCodespaces.ts @@ -0,0 +1,10 @@ +/** + * isCodespaces checks if the current environment is a Codespaces + * + * @returns true if the current environment is a Codespaces, false otherwise + */ +export const isCodespaces = () => { + return ( + process.env.CODESPACES === 'true' || !!process.env.GITHUB_CODESPACE_TOKEN + ) +} diff --git a/vscode/extension/src/utilities/isWindows.ts b/vscode/extension/src/utilities/isWindows.ts new file mode 100644 index 0000000000..03c0e31eba --- /dev/null +++ b/vscode/extension/src/utilities/isWindows.ts @@ -0,0 +1,4 @@ +/** + * Whether the current platform is Windows. 
+ */ +export const IS_WINDOWS = process.platform === 'win32' diff --git a/vscode/extension/src/utilities/python.ts b/vscode/extension/src/utilities/python.ts new file mode 100644 index 0000000000..c056a6c89b --- /dev/null +++ b/vscode/extension/src/utilities/python.ts @@ -0,0 +1,43 @@ +import { getInterpreterDetails } from './common/python' +import { err, ok, Result } from '@bus/result' +import { traceInfo } from './common/log' +import { promisify } from 'util' +import { execFile } from 'child_process' + +/** isPythonModuleInstallled returns true if the given python module is installed. + * + * @param moduleName - The name of the python module to check. + * @returns True if the module is installed, false otherwise. + */ +export const isPythonModuleInstalled = async ( + moduleName: string, +): Promise> => { + const interpreterDetails = await getInterpreterDetails() + if (!interpreterDetails.path) { + return err('No Python interpreter found') + } + const pythonPath = interpreterDetails.path[0] + const checkScript = ` +import sys +if sys.version_info >= (3, 12): + from importlib import metadata +else: + import importlib_metadata as metadata + +try: + metadata.version('${moduleName}') + print("true") +except metadata.PackageNotFoundError: + print("false") +` + try { + const execFileAsync = promisify(execFile) + const { stdout } = await execFileAsync(pythonPath, ['-c', checkScript]) + const isInstalled = stdout.trim() === 'true' + traceInfo(`${moduleName} is installed: ${isInstalled}`) + + return ok(stdout.trim() === 'true') + } catch (error) { + return err(`Failed to check tcloud installation: ${error}`) + } +} diff --git a/vscode/extension/src/utilities/semver.test.ts b/vscode/extension/src/utilities/semver.test.ts new file mode 100644 index 0000000000..951bf03c04 --- /dev/null +++ b/vscode/extension/src/utilities/semver.test.ts @@ -0,0 +1,37 @@ +import { describe, it, expect } from 'vitest' +import { isSemVerGreaterThanOrEqual } from './semver' + 
+describe('isSemVerGreaterThanOrEqual', () => { + it('should return true when major version is greater', () => { + expect(isSemVerGreaterThanOrEqual([2, 0, 0], [1, 0, 0])).toBe(true) + expect(isSemVerGreaterThanOrEqual([3, 0, 0], [2, 5, 10])).toBe(true) + }) + + it('should return false when major version is less', () => { + expect(isSemVerGreaterThanOrEqual([1, 0, 0], [2, 0, 0])).toBe(false) + expect(isSemVerGreaterThanOrEqual([0, 10, 20], [1, 0, 0])).toBe(false) + }) + + it('should compare minor version when major versions are equal', () => { + expect(isSemVerGreaterThanOrEqual([1, 2, 0], [1, 1, 0])).toBe(true) + expect(isSemVerGreaterThanOrEqual([1, 1, 0], [1, 2, 0])).toBe(false) + expect(isSemVerGreaterThanOrEqual([2, 5, 0], [2, 3, 10])).toBe(true) + }) + + it('should compare patch version when major and minor versions are equal', () => { + expect(isSemVerGreaterThanOrEqual([1, 1, 2], [1, 1, 1])).toBe(true) + expect(isSemVerGreaterThanOrEqual([1, 1, 1], [1, 1, 2])).toBe(false) + expect(isSemVerGreaterThanOrEqual([2, 3, 10], [2, 3, 5])).toBe(true) + }) + + it('should return true when versions are equal', () => { + expect(isSemVerGreaterThanOrEqual([1, 0, 0], [1, 0, 0])).toBe(true) + expect(isSemVerGreaterThanOrEqual([2, 5, 10], [2, 5, 10])).toBe(true) + }) + + it('should handle zero versions correctly', () => { + expect(isSemVerGreaterThanOrEqual([0, 0, 1], [0, 0, 0])).toBe(true) + expect(isSemVerGreaterThanOrEqual([0, 1, 0], [0, 0, 10])).toBe(true) + expect(isSemVerGreaterThanOrEqual([0, 0, 0], [0, 0, 0])).toBe(true) + }) +}) diff --git a/vscode/extension/src/utilities/semver.ts b/vscode/extension/src/utilities/semver.ts new file mode 100644 index 0000000000..fed83af4a4 --- /dev/null +++ b/vscode/extension/src/utilities/semver.ts @@ -0,0 +1,24 @@ +type SemVer = [number, number, number] + +/** + * Check if a is greater than or equal to b. + * + * @param a - The first version. + * @param b - The second version. 
+ * @returns True if a is greater than b, false otherwise. + */ +export function isSemVerGreaterThanOrEqual(a: SemVer, b: SemVer): boolean { + if (a[0] > b[0]) { + return true + } + if (a[0] < b[0]) { + return false + } + if (a[1] > b[1]) { + return true + } + if (a[1] < b[1]) { + return false + } + return a[2] >= b[2] +} diff --git a/vscode/extension/src/utilities/sleep.ts b/vscode/extension/src/utilities/sleep.ts new file mode 100644 index 0000000000..e7e70f3133 --- /dev/null +++ b/vscode/extension/src/utilities/sleep.ts @@ -0,0 +1,8 @@ +/** + * Utility function that creates a promise which resolves after the specified time. + * @param ms The time to sleep in milliseconds + * @returns A promise that resolves after the specified time + */ +export async function sleep(ms: number): Promise { + return new Promise(resolve => setTimeout(resolve, ms)) +} diff --git a/vscode/extension/src/utilities/sqlmesh/sqlmesh.ts b/vscode/extension/src/utilities/sqlmesh/sqlmesh.ts new file mode 100644 index 0000000000..c9e181fc06 --- /dev/null +++ b/vscode/extension/src/utilities/sqlmesh/sqlmesh.ts @@ -0,0 +1,479 @@ +import path from 'path' +import { traceInfo, traceLog, traceVerbose } from '../common/log' +import { getInterpreterDetails, getPythonEnvVariables } from '../common/python' +import { Result, err, isErr, ok } from '@bus/result' +import { getProjectRoot } from '../common/utilities' +import { isPythonModuleInstalled } from '../python' +import fs from 'fs' +import { ErrorType } from '../errors' +import { isSignedIntoTobikoCloud } from '../../auth/auth' +import { execAsync } from '../exec' +import z from 'zod' +import { ProgressLocation, window } from 'vscode' +import { IS_WINDOWS } from '../isWindows' +import { getSqlmeshLspEntryPoint, resolveProjectPath } from '../config' +import { isSemVerGreaterThanOrEqual } from '../semver' + +export interface SqlmeshExecInfo { + workspacePath: string + bin: string + env: Record + args: string[] +} + +/** + * Gets the current SQLMesh 
environment variables that would be used for execution. + * This is useful for debugging and understanding the environment configuration. + * + * @returns A Result containing the environment variables or an error + */ +export async function getSqlmeshEnvironment(): Promise, string>> { + const interpreterDetails = await getInterpreterDetails() + const envVariables = await getPythonEnvVariables() + if (isErr(envVariables)) { + return err(envVariables.error) + } + + const binPath = interpreterDetails.binPath + const virtualEnvPath = binPath && interpreterDetails.isVirtualEnvironment + ? path.dirname(path.dirname(binPath)) // binPath points to bin dir in venv + : binPath ? path.dirname(binPath) : undefined + + const env: Record = { + ...process.env, + ...envVariables.value, + PYTHONPATH: interpreterDetails.path?.[0] ?? '', + } + + if (virtualEnvPath) { + env['VIRTUAL_ENV'] = virtualEnvPath + } + + if (binPath) { + env['PATH'] = `${binPath}${path.delimiter}${process.env.PATH || ''}` + } + + return ok(env) +} + +/** + * Returns true if the current project is a Tcloud project. To detect this we, + * 1. Check if the project has a tcloud.yaml file in the project root. If it does, we assume it's a Tcloud project. + * 2. Check if the project has tcloud installed in the Python environment. + * + * @returns A Result indicating whether tcloud is installed. 
+ */ +export const isTcloudProject = async (): Promise> => { + const projectRoot = await getProjectRoot() + const resolvedPath = resolveProjectPath(projectRoot) + if (isErr(resolvedPath)) { + return err(resolvedPath.error) + } + const tcloudYamlPath = path.join(resolvedPath.value.workspaceFolder, 'tcloud.yaml') + const tcloudYmlPath = path.join(resolvedPath.value.workspaceFolder, 'tcloud.yml') + const isTcloudYamlFilePresent = fs.existsSync(tcloudYamlPath) + const isTcloudYmlFilePresent = fs.existsSync(tcloudYmlPath) + if (isTcloudYamlFilePresent || isTcloudYmlFilePresent) { + traceVerbose(`tcloud yaml or yml file present at : ${tcloudYamlPath}`) + return ok(true) + } + const isTcloudInstalled = await isPythonModuleInstalled('tcloud') + if (isErr(isTcloudInstalled)) { + return isTcloudInstalled + } + traceVerbose(`tcloud is installed: ${isTcloudInstalled.value}`) + return ok(isTcloudInstalled.value) +} + +/** + * Get the tcloud executable for the current Python environment. + * + * @returns The tcloud executable for the current Python environment. + */ +export const getTcloudBin = async (): Promise> => { + const tcloud = IS_WINDOWS ? 'tcloud.exe' : 'tcloud' + const interpreterDetails = await getInterpreterDetails() + if (!interpreterDetails.path) { + return err({ + type: 'tcloud_bin_not_found', + }) + } + const pythonPath = interpreterDetails.path[0] + const binPath = path.join(path.dirname(pythonPath), tcloud) + if (!fs.existsSync(binPath)) { + return err({type: 'tcloud_bin_not_found'}) + } + const env = await getSqlmeshEnvironment() + if (isErr(env)) { + return err({ + type: 'generic', + message: env.error, + }) + } + return ok({ + bin: binPath, + workspacePath: interpreterDetails.resource?.fsPath ?? '', + env: env.value, + args: [], + }) +} + +const isSqlmeshInstalledSchema = z.object({ + is_installed: z.boolean(), +}) + +/** + * Returns true if the current project is a sqlmesh enterprise project is installed and updated. 
+ * + * @returns A Result indicating whether sqlmesh enterprise is installed and updated. + */ +export const isSqlmeshEnterpriseInstalled = async (): Promise< + Result +> => { + traceInfo('Checking if sqlmesh enterprise is installed') + const tcloudBin = await getTcloudBin() + if (isErr(tcloudBin)) { + return tcloudBin + } + const projectRoot = await getProjectRoot() + const resolvedPath = resolveProjectPath(projectRoot) + if (isErr(resolvedPath)) { + return err({ + type: 'generic', + message: resolvedPath.error, + }) + } + const called = await execAsync(tcloudBin.value.bin, ['is_sqlmesh_installed'], { + cwd: resolvedPath.value.workspaceFolder, + env: tcloudBin.value.env, + }) + if (called.exitCode !== 0) { + return err({ + type: 'generic', + message: `Failed to check if sqlmesh enterprise is installed: ${called.stderr}`, + }) + } + const parsed = isSqlmeshInstalledSchema.safeParse(JSON.parse(called.stdout)) + if (!parsed.success) { + return err({ + type: 'generic', + message: `Failed to parse sqlmesh enterprise installation status: ${parsed.error.message}`, + }) + } + return ok(parsed.data.is_installed) +} + +/** + * Install sqlmesh enterprise. + * + * @returns A Result indicating whether sqlmesh enterprise was installed. 
+ */ +export const installSqlmeshEnterprise = async ( + abortController: AbortController, +): Promise> => { + const tcloudBin = await getTcloudBin() + if (isErr(tcloudBin)) { + return tcloudBin + } + const projectRoot = await getProjectRoot() + const resolvedPath = resolveProjectPath(projectRoot) + if (isErr(resolvedPath)) { + return err({ + type: 'generic', + message: resolvedPath.error, + }) + } + const called = await execAsync(tcloudBin.value.bin, ['install_sqlmesh'], { + signal: abortController.signal, + cwd: resolvedPath.value.workspaceFolder, + env: tcloudBin.value.env, + }) + if (called.exitCode !== 0) { + return err({ + type: 'generic', + message: `Failed to install sqlmesh enterprise: ${called.stderr}`, + }) + } + return ok(true) +} + +let installationLock: Promise> | undefined = undefined + +/** + * Checks if sqlmesh enterprise is installed and updated. If not, it will install it. + * This will also create a progress message in vscode in order to inform the user that sqlmesh enterprise is being installed. + * Uses a lock mechanism to prevent parallel executions. + * + * @returns A Result indicating whether sqlmesh enterprise was installed in the call. 
+ */ +export const ensureSqlmeshEnterpriseInstalled = async (): Promise< + Result +> => { + // If there's an ongoing installation, wait for it to complete + if (installationLock) { + return installationLock + } + + // Create a new lock + installationLock = (async () => { + try { + traceInfo('Ensuring sqlmesh enterprise is installed') + const isInstalled = await isSqlmeshEnterpriseInstalled() + if (isErr(isInstalled)) { + return isInstalled + } + if (isInstalled.value) { + traceInfo('Sqlmesh enterprise is installed') + return ok(false) + } + traceInfo('Sqlmesh enterprise is not installed, installing...') + const abortController = new AbortController() + const installResult = await window.withProgress( + { + location: ProgressLocation.Notification, + title: 'SQLMesh', + cancellable: true, + }, + async (progress, token) => { + // Connect the cancellation token to our abort controller + token.onCancellationRequested(() => { + abortController.abort() + traceInfo('Sqlmesh enterprise installation cancelled') + window.showInformationMessage('Installation cancelled') + }) + progress.report({ message: 'Installing enterprise python package...' }) + const result = await installSqlmeshEnterprise(abortController) + if (isErr(result)) { + return result + } + return ok(true) + }, + ) + if (isErr(installResult)) { + return installResult + } + return ok(true) + } finally { + // Clear the lock when done + installationLock = undefined + } + })() + + return installationLock +} + +/** + * Ensure that the sqlmesh_lsp dependencies are installed. + * + * @returns A Result indicating whether the sqlmesh_lsp dependencies were installed. 
+ */ +export const ensureSqlmeshLspDependenciesInstalled = async (): Promise< + Result +> => { + const isPyglsInstalled = await isPythonModuleInstalled('pygls') + if (isErr(isPyglsInstalled)) { + return err({ + type: 'generic', + message: isPyglsInstalled.error, + }) + } + const isLsprotocolInstalled = await isPythonModuleInstalled('lsprotocol') + if (isErr(isLsprotocolInstalled)) { + return err({ + type: 'generic', + message: isLsprotocolInstalled.error, + }) + } + const isTobikoCloudInstalled = await isTcloudProject() + if (isErr(isTobikoCloudInstalled)) { + return err({ + type: 'generic', + message: isTobikoCloudInstalled.error, + }) + } + if (!isPyglsInstalled.value || !isLsprotocolInstalled.value) { + return err({ + type: 'sqlmesh_lsp_dependencies_missing', + is_missing_pygls: !isPyglsInstalled.value, + is_missing_lsprotocol: !isLsprotocolInstalled.value, + is_tobiko_cloud: isTobikoCloudInstalled.value, + }) + } + return ok(undefined) +} + +/** + * Get the sqlmesh_lsp executable for the current workspace. + * + * @returns The sqlmesh_lsp executable for the current workspace. + */ +export const sqlmeshLspExec = async (): Promise< + Result +> => { + const projectRoot = await getProjectRoot() + const resolvedPath = resolveProjectPath(projectRoot) + if (isErr(resolvedPath)) { + return err({ + type: 'generic', + message: resolvedPath.error, + }) + } + const workspacePath = resolvedPath.value.workspaceFolder + + const configuredLSPExec = getSqlmeshLspEntryPoint() + if (configuredLSPExec) { + traceLog(`Using configured SQLMesh LSP entry point: ${configuredLSPExec.entrypoint} ${configuredLSPExec.args.join(' ')}`) + return ok({ + bin: configuredLSPExec.entrypoint, + workspacePath: workspacePath, + env: process.env, + args: configuredLSPExec.args, + }) + } + const sqlmeshLSP = IS_WINDOWS ? 
'sqlmesh_lsp.exe' : 'sqlmesh_lsp' + const envVariables = await getPythonEnvVariables() + if (isErr(envVariables)) { + return err({ + type: 'generic', + message: envVariables.error, + }) + } + + const interpreterDetails = await getInterpreterDetails() + traceLog(`Interpreter details: ${JSON.stringify(interpreterDetails)}`) + if (interpreterDetails.path) { + traceVerbose( + `Using interpreter from Python extension: ${interpreterDetails.path.join( + ' ', + )}`, + ) + } + if (interpreterDetails.isVirtualEnvironment) { + traceLog('Using virtual environment') + const tcloudInstalled = await isTcloudProject() + if (isErr(tcloudInstalled)) { + return err({ + type: 'generic', + message: tcloudInstalled.error, + }) + } + if (tcloudInstalled.value) { + traceLog('Tcloud installed, installing sqlmesh') + const tcloudBin = await getTcloudBin() + if (isErr(tcloudBin)) { + return tcloudBin + } + const isSignedIn = await isSignedIntoTobikoCloud() + if (!isSignedIn) { + return err({ + type: 'not_signed_in', + }) + } + const ensured = await ensureSqlmeshEnterpriseInstalled() + if (isErr(ensured)) { + return ensured + } + const tcloudBinVersion = await getTcloudBinVersion() + if (isErr(tcloudBinVersion)) { + return tcloudBinVersion + } + // TODO: Remove this once we have a stable version of tcloud that supports sqlmesh_lsp. 
+ if (isSemVerGreaterThanOrEqual(tcloudBinVersion.value, [2, 10, 1])) { + return ok ({ + bin: tcloudBin.value.bin, + workspacePath: workspacePath, + env: tcloudBin.value.env, + args: ['sqlmesh_lsp'], + }) + } + } + const binPath = path.join(interpreterDetails.binPath!, sqlmeshLSP) + traceLog(`Bin path: ${binPath}`) + if (!fs.existsSync(binPath)) { + return err({ + type: 'sqlmesh_lsp_not_found', + }) + } + const ensuredDependencies = await ensureSqlmeshLspDependenciesInstalled() + if (isErr(ensuredDependencies)) { + return ensuredDependencies + } + const env = await getSqlmeshEnvironment() + if (isErr(env)) { + return err({ + type: 'generic', + message: env.error, + }) + } + return ok({ + bin: binPath, + workspacePath: workspacePath, + env: env.value, + args: [], + }) + } else { + const env = await getSqlmeshEnvironment() + if (isErr(env)) { + return err({ + type: 'generic', + message: env.error, + }) + } + const exists = await doesExecutableExist(sqlmeshLSP) + if (!exists) { + return err({ + type: 'sqlmesh_lsp_not_found', + }) + } + return ok({ + bin: sqlmeshLSP, + workspacePath: workspacePath, + env: env.value, + args: [], + }) + } +} + +async function doesExecutableExist(executable: string): Promise { + const command = process.platform === 'win32' ? 'where.exe' : 'which' + traceLog(`Checking if ${executable} exists with ${command}`) + try { + const result = await execAsync(command, [executable]) + traceLog(`Checked if ${executable} exists with ${command}, with result ${result.exitCode}`) + const exists = result.exitCode === 0 + traceLog(`Checked if ${executable} exists with ${command}, with result ${exists}`) + return exists + } catch { + traceLog(`Checked if ${executable} exists with ${command}, errored, returning false`) + return false + } +} + +/** + * Get the version of the tcloud bin. + * + * @returns The version of the tcloud bin. 
+ */ +async function getTcloudBinVersion(): Promise> { + const tcloudBin = await getTcloudBin() + if (isErr(tcloudBin)) { + return tcloudBin + } + const called = await execAsync(tcloudBin.value.bin, ['--version'], { + env: tcloudBin.value.env, + }) + if (called.exitCode !== 0) { + return err({ + type: 'generic', + message: `Failed to get tcloud bin version: ${called.stderr}`, + }) + } + const version = called.stdout.split('.').map(Number) + if (version.length !== 3) { + return err({ + type: 'generic', + message: `Failed to get tcloud bin version: ${called.stdout}`, + }) + } + return ok(version as [number, number, number]) +} \ No newline at end of file diff --git a/vscode/extension/src/webviews/lineagePanel.ts b/vscode/extension/src/webviews/lineagePanel.ts new file mode 100644 index 0000000000..0fd0be9c2a --- /dev/null +++ b/vscode/extension/src/webviews/lineagePanel.ts @@ -0,0 +1,225 @@ +import { CallbackEvent, RPCRequest } from '@bus/callbacks' +import { + Disposable, + Uri, + Webview, + WebviewView, + WebviewViewProvider, + window, + workspace, +} from 'vscode' +import { getWorkspaceFolders } from '../utilities/common/vscodeapi' +import { LSPClient } from '../lsp/lsp' +import { isErr } from '@bus/result' + +export class LineagePanel implements WebviewViewProvider, Disposable { + public static readonly viewType = 'sqlmesh.lineage' + + private panel: WebviewView | undefined + private lsp: LSPClient + private readonly extensionUri: Uri + + private disposables: Disposable[] = [] + + public constructor(extensionUri: Uri, lsp: LSPClient) { + this.extensionUri = extensionUri + this.lsp = lsp + + if (this.panel) { + this.panel.webview.html = this.getHTML(this.panel.webview) + } + + this.disposables.push( + workspace.onDidSaveTextDocument(document => { + this.panel?.webview.postMessage({ + key: 'vscode_send', + payload: { + key: 'savedFile', + payload: { fileUri: document.uri.toString() }, + }, + }) + }), + ) + this.disposables.push( + 
window.onDidChangeActiveTextEditor(editor => { + if (editor) { + this.panel?.webview.postMessage({ + key: 'vscode_send', + payload: { + key: 'changeFocusOnFile', + payload: { path: editor.document.uri.toString() }, + }, + }) + } + }), + ) + } + + public resolveWebviewView(webviewView: WebviewView) { + if (this.panel) { + webviewView = this.panel + } + this.panel = webviewView + + webviewView.webview.options = { + // Allow scripts in the webview + enableScripts: true, + localResourceRoots: [this.extensionUri], + } + + // Set content options for external URL access + // Set up message listener for events from the iframe + const disposable = webviewView.webview.onDidReceiveMessage( + async request => { + if (!request) { + return + } + if (!request.key) { + return + } + const message: CallbackEvent = request + switch (message.key) { + case 'openFile': { + const workspaceFolders = getWorkspaceFolders() + if (workspaceFolders.length != 1) { + throw new Error('Only one workspace folder is supported') + } + const fullPath = Uri.parse(message.payload.uri) + const document = await workspace.openTextDocument(fullPath) + await window.showTextDocument(document) + break + } + case 'rpcRequest': { + const payload: RPCRequest = message.payload + const requestId = payload.requestId + switch (payload.method) { + case 'api_query': { + const response = await this.lsp.call_custom_method( + 'sqlmesh/api', + payload.params, + ) + let responseCallback: CallbackEvent + if (isErr(response)) { + let errorMessage: string + switch (response.error.type) { + case 'generic': + errorMessage = response.error.message + break + case 'invalid_state': + errorMessage = `Invalid state: ${response.error.message}` + break + case 'sqlmesh_outdated': + errorMessage = `SQLMesh version issue: ${response.error.message}` + break + default: + errorMessage = 'Unknown error' + } + responseCallback = { + key: 'rpcResponse', + payload: { + requestId, + result: { + ok: false, + error: errorMessage, + }, + }, + } + } 
else { + responseCallback = { + key: 'rpcResponse', + payload: { + requestId, + result: response, + }, + } + } + await webviewView.webview.postMessage(responseCallback) + break + } + case 'get_active_file': { + const active_file = window.activeTextEditor?.document.uri.fsPath + const responseCallback: CallbackEvent = { + key: 'rpcResponse', + payload: { + requestId, + result: { + fileUri: active_file, + }, + }, + } + await webviewView.webview.postMessage(responseCallback) + break + } + default: { + throw new Error(`Unhandled RPC method: ${payload.method}`) + } + } + break + } + default: + console.error( + 'Unhandled message type under queryRequest: ', + message, + ) + } + }, + undefined, + [], + ) + this.disposables.push(disposable) + webviewView.webview.html = this.getHTML(webviewView.webview) + } + + private getHTML(panel: Webview) { + const cssUri = panel.asWebviewUri( + Uri.joinPath(this.extensionUri, 'src_react', 'assets', 'index.css'), + ) + const jsUri = panel.asWebviewUri( + Uri.joinPath(this.extensionUri, 'src_react', 'assets', 'index.js'), + ) + const faviconUri = panel.asWebviewUri( + Uri.joinPath(this.extensionUri, 'src_react', 'favicon.ico'), + ) + const logoUri = panel.asWebviewUri( + Uri.joinPath(this.extensionUri, 'src_react', 'logo192.png'), + ) + + // Handle query requests from the React app + + return ` + + + + + + + + + + Create TanStack App - react + + + + + +
+ + +` + } + + dispose() { + // WebviewView doesn't have a dispose method + // We can clear references + this.panel = undefined + this.disposables.forEach(disposable => { + disposable.dispose() + }) + this.disposables = [] + } +} diff --git a/vscode/extension/tests/bad_setup.spec.ts b/vscode/extension/tests/bad_setup.spec.ts new file mode 100644 index 0000000000..b76eee2b3d --- /dev/null +++ b/vscode/extension/tests/bad_setup.spec.ts @@ -0,0 +1,134 @@ +import { expect, test } from './fixtures' +import fs from 'fs-extra' +import os from 'os' +import path from 'path' +import { + createVirtualEnvironment, + openFile, + openLineageView, + openServerPage, + pipInstall, + REPO_ROOT, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' + +test('missing LSP dependencies shows install prompt', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + const pythonEnvDir = path.join(tempDir, '.venv') + const pythonDetails = await createVirtualEnvironment(pythonEnvDir) + const custom_materializations = path.join( + REPO_ROOT, + 'examples', + 'custom_materializations', + ) + const sqlmeshWithExtras = `${REPO_ROOT}[bigquery]` + await pipInstall(pythonDetails, [sqlmeshWithExtras, custom_materializations]) + + // Copy sushi project + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + // Configure VS Code settings to use our Python environment + const settings = { + 'python.defaultInterpreterPath': pythonDetails.pythonPath, + 'sqlmesh.environmentPath': pythonEnvDir, + } + await fs.ensureDir(path.join(tempDir, '.vscode')) + await fs.writeJson(path.join(tempDir, '.vscode', 'settings.json'), settings, { + spaces: 2, + }) + + await openServerPage(page, tempDir, sharedCodeServer) + + // Open a top_waiters model to trigger SQLMesh activation + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + // Wait for the message to show that 
LSP extras need to be installed + await page.waitForSelector('text=LSP dependencies missing') + expect(await page.locator('text=Install').count()).toBeGreaterThanOrEqual(1) +}) + +test('lineage, no sqlmesh found', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + const pythonEnvDir = path.join(tempDir, '.venv') + const pythonDetails = await createVirtualEnvironment(pythonEnvDir) + + // Copy sushi project + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + // Configure VS Code settings to use our Python environment + const settings = { + 'python.defaultInterpreterPath': pythonDetails.pythonPath, + 'sqlmesh.environmentPath': pythonEnvDir, + } + await fs.ensureDir(path.join(tempDir, '.vscode')) + await fs.writeJson(path.join(tempDir, '.vscode', 'settings.json'), settings, { + spaces: 2, + }) + + // navigate to code-server instance + await openServerPage(page, tempDir, sharedCodeServer) + + await openLineageView(page) + + // Assert shows that sqlmesh is not installed + await page.waitForSelector('text=SQLMesh LSP not found') +}) + +// Checks that if you have another file open like somewhere else, it still checks the workspace first for a successful context +// it's very flaky but runs when debugging +// - the typing in of the file name is very flaky +test('check that the LSP runs correctly by opening lineage when looking at another file before not in workspace', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + const pythonEnvDir = path.join(tempDir, '.venv') + const pythonDetails = await createVirtualEnvironment(pythonEnvDir) + const sqlmeshWithExtras = `${REPO_ROOT}[lsp, bigquery]` + const custom_materializations = path.join( + REPO_ROOT, + 'examples', + 'custom_materializations', + ) + await pipInstall(pythonDetails, [sqlmeshWithExtras, custom_materializations]) + + // Configure VS Code settings to use our Python environment + const settings = { + 'python.defaultInterpreterPath': pythonDetails.pythonPath, + 
'sqlmesh.environmentPath': tempDir, + } + await fs.ensureDir(path.join(tempDir, '.vscode')) + await fs.writeJson(path.join(tempDir, '.vscode', 'settings.json'), settings, { + spaces: 2, + }) + + // Write a sql file in another folder + const tempDir2 = await fs.mkdtemp( + path.join(os.tmpdir(), 'vscode-test-tcloud-2-'), + ) + const sqlFile = path.join(tempDir2, 'models', 'customers.sql') + await fs.ensureDir(path.dirname(sqlFile)) + await fs.writeFile(sqlFile, 'SELECT 1') + + await openServerPage(page, tempDir, sharedCodeServer) + + // Open the SQL file from the other directory + await openFile(page, sqlFile) + + await waitForLoadedSQLMesh(page) +}) diff --git a/vscode/extension/tests/broken_project.spec.ts b/vscode/extension/tests/broken_project.spec.ts new file mode 100644 index 0000000000..f32a39a86d --- /dev/null +++ b/vscode/extension/tests/broken_project.spec.ts @@ -0,0 +1,292 @@ +import { test, expect } from './fixtures' +import fs from 'fs-extra' +import path from 'path' +import { + openLineageView, + openServerPage, + openProblemsView, + saveFile, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import { createPythonInterpreterSettingsSpecifier } from './utils_code_server' + +test('bad project, double model', async ({ + tempDir, + page, + sharedCodeServer, +}) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + // Read the customers.sql file + const customersSql = await fs.readFile( + path.join(tempDir, 'models', 'customers.sql'), + 'utf8', + ) + + // Write the customers.sql file with a double model + await fs.writeFile( + path.join(tempDir, 'models', 'customers_duplicated.sql'), + customersSql, + ) + + await createPythonInterpreterSettingsSpecifier(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + await page.waitForSelector('text=models') + + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + 
.locator('a') + .click() + + await page.waitForSelector('text=Error creating context') + + await page.waitForTimeout(500) +}) + +test('working project, then broken through adding double model, then refixed', async ({ + page, + tempDir, + sharedCodeServer, +}) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + await createPythonInterpreterSettingsSpecifier(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + await page.waitForLoadState('networkidle') + + // Open the lineage view to confirm it loads properly + await openLineageView(page) + await waitForLoadedSQLMesh(page) + + // Read the customers.sql file + const customersSql = await fs.readFile( + path.join(tempDir, 'models', 'customers.sql'), + 'utf8', + ) + + // Add a duplicate model to break the project + await fs.writeFile( + path.join(tempDir, 'models', 'customers_duplicated.sql'), + customersSql, + ) + + // Open the customers model to trigger the error + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + // Save to refresh the context + await saveFile(page) + + // Wait for the error to appear + const iframes = page.locator('iframe') + const iframeCount = await iframes.count() + let errorCount = 0 + + for (let i = 0; i < iframeCount; i++) { + const iframe = iframes.nth(i) + const contentFrame = iframe.contentFrame() + if (contentFrame) { + const activeFrame = contentFrame.locator('#active-frame').contentFrame() + if (activeFrame) { + try { + await activeFrame + .getByText('Error: Failed to load model') + .waitFor({ timeout: 1000 }) + errorCount++ + } catch { + // Continue to next iframe if this one doesn't have the error + continue + } + } + } + } + expect(errorCount).toBeGreaterThan(0) + + // Remove the duplicated model to fix the project + await fs.remove(path.join(tempDir, 'models', 'customers_duplicated.sql')) + + // Save again to refresh 
the context + await saveFile(page) + + const iframes2 = page.locator('iframe') + const iframeCount2 = await iframes2.count() + let raw_demographicsCount = 0 + + for (let i = 0; i < iframeCount2; i++) { + const iframe = iframes2.nth(i) + const contentFrame = iframe.contentFrame() + if (contentFrame) { + const activeFrame = contentFrame.locator('#active-frame').contentFrame() + if (activeFrame) { + try { + await activeFrame + .getByText('sushi.customers') + .waitFor({ timeout: 1000 }) + raw_demographicsCount++ + } catch { + // Continue to next iframe if this one doesn't have the error + continue + } + } + } + } + expect(raw_demographicsCount).toBeGreaterThan(0) +}) + +test('bad project, double model, then fixed', async ({ + page, + tempDir, + sharedCodeServer, +}) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + // Read the customers.sql file + const customersSql = await fs.readFile( + path.join(tempDir, 'models', 'customers.sql'), + 'utf8', + ) + + // Write the customers.sql file with a double model + await fs.writeFile( + path.join(tempDir, 'models', 'customers_duplicated.sql'), + customersSql, + ) + + await openServerPage(page, tempDir, sharedCodeServer) + + await page.waitForSelector('text=models') + + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + await page.waitForSelector('text=Error creating context') + + // Remove the duplicated model + await fs.remove(path.join(tempDir, 'models', 'customers_duplicated.sql')) + + // Open the linage view + await openLineageView(page) + + // Wait for the error to go away + const iframes = page.locator('iframe') + const iframeCount = await iframes.count() + let raw_demographicsCount = 0 + + for (let i = 0; i < iframeCount; i++) { + const iframe = iframes.nth(i) + const contentFrame = iframe.contentFrame() + if (contentFrame) { + const activeFrame = 
contentFrame.locator('#active-frame').contentFrame() + if (activeFrame) { + try { + await activeFrame + .getByText('sushi.customers') + .waitFor({ timeout: 1000 }) + raw_demographicsCount++ + } catch { + continue + } + } + } + } + expect(raw_demographicsCount).toBeGreaterThan(0) +}) + +test('bad project, double model, check lineage', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + // Read the customers.sql file + const customersSql = await fs.readFile( + path.join(tempDir, 'models', 'customers.sql'), + 'utf8', + ) + + // Write the customers.sql file with a double model + await fs.writeFile( + path.join(tempDir, 'models', 'customers_duplicated.sql'), + customersSql, + ) + + await createPythonInterpreterSettingsSpecifier(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + // Open the lineage view + await openLineageView(page) + + await page.waitForSelector('text=Error creating context') + await page.waitForSelector('text=Error:') + + await page.waitForTimeout(500) +}) + +test('bad model block, then fixed', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + // Copy over the sushi project + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + + // Add a model with a bad model block + const badModelPath = path.join(tempDir, 'models', 'bad_model.sql') + const contents = + 'MODEL ( name sushi.bad_block, test); SELECT * FROM sushi.customers' + await fs.writeFile(badModelPath, contents) + + await openServerPage(page, tempDir, sharedCodeServer) + await page.waitForLoadState('networkidle') + + // Open the customers.sql model + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + // Wait for the error to appear + await page.waitForSelector('text=Error creating context') + + await 
openProblemsView(page) + + // Assert error is present in the problems view + const errorElement = page + .getByText("Required keyword: 'value' missing for") + .first() + await expect(errorElement).toBeVisible({ timeout: 5000 }) + + // Remove the bad model file + await fs.remove(badModelPath) + + // Click on the grain part of the model and save + await page.getByText('grain').click() + await saveFile(page) + + await waitForLoadedSQLMesh(page) +}) diff --git a/vscode/extension/tests/commands.spec.ts b/vscode/extension/tests/commands.spec.ts new file mode 100644 index 0000000000..afd926310c --- /dev/null +++ b/vscode/extension/tests/commands.spec.ts @@ -0,0 +1,96 @@ +import { test, expect } from './fixtures' +import path from 'path' +import fs from 'fs-extra' +import os from 'os' +import { + openServerPage, + saveFile, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import { createPythonInterpreterSettingsSpecifier } from './utils_code_server' +import { DuckDBInstance } from '@duckdb/node-api' + +test.describe('Update external models columns', () => { + test('New external model', async ({ page, sharedCodeServer }) => { + // Normal setting up + const tempDir = await fs.mkdtemp( + path.join(os.tmpdir(), 'vscode-test-sushi-'), + ) + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + + // Changing the config to set the default gateway to use the fixed one. 
+ const configPath = path.join(tempDir, 'config.py') + const configContent = await fs.readFile(configPath, 'utf8') + const original = `default_gateway="duckdb",` + expect(configContent).toContain(original) + const target = `default_gateway="duckdb_persistent",` + const updatedConfigContent = configContent.replace(original, target) + expect(updatedConfigContent).toContain(target) + await fs.writeFile(configPath, updatedConfigContent) + + // Create an additional table in the database + const table = 'raw.test_table' + const databasePath = path.join(tempDir, 'data', 'duckdb.db') + const instance = await DuckDBInstance.create(databasePath) + const connection = await instance.connect() + await connection.run(`CREATE SCHEMA IF NOT EXISTS raw`) + await connection.run( + `CREATE TABLE IF NOT EXISTS ${table}( + id INTEGER, + value VARCHAR + )`, + ) + connection.closeSync() + instance.closeSync() + expect(fs.existsSync(databasePath)).toBe(true) + + // Update the external_models in the config to include the new table but + // not the columns by appending '- name: ${table}' to the external_models.yaml file + const externalModelsPath = path.join(tempDir, 'external_models.yaml') + const externalModelsContent = await fs.readFile(externalModelsPath, 'utf8') + const newExternalModel = `- name: ${table}` + const updatedExternalModelsContent = `${externalModelsContent}\n${newExternalModel}` + await fs.writeFile(externalModelsPath, updatedExternalModelsContent) + + // Open the server page + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder, excluding external_models + await page + .getByRole('treeitem', { name: 'external_models.yaml', exact: true }) + .locator('a') + .click() + + await waitForLoadedSQLMesh(page) + + // Click the update columns button + await page.waitForSelector('text=Update Columns') + const updateColumnButtons = page.getByRole('button', { 
+ name: 'Update Columns', + exact: true, + }) + // Click each one of them + for (const button of await updateColumnButtons.all()) { + await button.click() + await page.waitForTimeout(1_000) // Wait for the action to complete + } + + await page.waitForTimeout(1_000) + await saveFile(page) + await page.waitForTimeout(1_000) + + // Check the file contains the columns + const updatedExternalModelsContentAfterUpdate = await fs.readFile( + externalModelsPath, + 'utf8', + ) + expect(updatedExternalModelsContentAfterUpdate).toContain( + `- name: ${table}\n columns:\n id: INT\n value: TEXT`, + ) + }) +}) diff --git a/vscode/extension/tests/completions.spec.ts b/vscode/extension/tests/completions.spec.ts new file mode 100644 index 0000000000..32ec7d96e3 --- /dev/null +++ b/vscode/extension/tests/completions.spec.ts @@ -0,0 +1,164 @@ +import { test, expect } from './fixtures' +import path from 'path' +import fs from 'fs-extra' +import os from 'os' +import { + openServerPage, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import { createPythonInterpreterSettingsSpecifier } from './utils_code_server' + +test('Autocomplete for model names', async ({ page, sharedCodeServer }) => { + const tempDir = await fs.mkdtemp(path.join(os.tmpdir(), 'vscode-test-sushi-')) + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + await createPythonInterpreterSettingsSpecifier(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the top_waiters model + await page + .getByRole('treeitem', { name: 'top_waiters.sql', exact: true }) + .locator('a') + .click() + + await page.waitForSelector('text=grain') + await waitForLoadedSQLMesh(page) + + await page.locator('text=grain').first().click() + + // Move to the end of the file + for (let i = 0; i < 
100; i++) { + await page.keyboard.press('ArrowDown') + } + + // Add a new line + await page.keyboard.press('Enter') + + // Type the beginning of sushi.customers to trigger autocomplete + await page.keyboard.type('sushi.waiter_as_customer') + + // Wait a moment for autocomplete to appear + await page.waitForTimeout(500) + + // Check if the autocomplete suggestion for sushi.customers is visible + expect( + await page.locator('text=sushi.waiter_as_customer_by_day').count(), + ).toBeGreaterThanOrEqual(1) + expect( + await page.locator('text=SQLMesh Model').count(), + ).toBeGreaterThanOrEqual(1) +}) + +// Skip the macro completions test as regular checks because they are flaky and +// covered in other non-integration tests. +test.describe('Macro Completions', () => { + test('Completion for inbuilt macros', async ({ page, sharedCodeServer }) => { + const tempDir = await fs.mkdtemp( + path.join(os.tmpdir(), 'vscode-test-sushi-'), + ) + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + await createPythonInterpreterSettingsSpecifier(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the top_waiters model + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + await page.waitForSelector('text=grain') + await waitForLoadedSQLMesh(page) + + await page.locator('text=grain').first().click() + + // Move to the end of the file + for (let i = 0; i < 100; i++) { + await page.keyboard.press('ArrowDown') + } + + // Add a new line + await page.keyboard.press('Enter') + + await page.waitForTimeout(500) + + // Hit the '@' key to trigger autocomplete for inbuilt macros + await page.keyboard.press('@') + await page.keyboard.type('eac') + + // Wait a moment for autocomplete to appear + await 
page.waitForTimeout(500) + + // Check if the autocomplete suggestion for inbuilt macros is visible + expect(await page.locator('text=@each').count()).toBeGreaterThanOrEqual(1) + }) + + test('Completion for custom macros', async ({ page, sharedCodeServer }) => { + const tempDir = await fs.mkdtemp( + path.join(os.tmpdir(), 'vscode-test-sushi-'), + ) + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + await createPythonInterpreterSettingsSpecifier(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the top_waiters model + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + await page.waitForSelector('text=grain') + await waitForLoadedSQLMesh(page) + + await page.locator('text=grain').first().click() + + // Move to the end of the file + for (let i = 0; i < 100; i++) { + await page.keyboard.press('ArrowDown') + } + + // Add a new line + await page.keyboard.press('Enter') + + // Type the beginning of a macro to trigger autocomplete + await page.keyboard.press('@') + await page.keyboard.type('add_o') + + // Wait a moment for autocomplete to appear + await page.waitForTimeout(500) + + // Check if the autocomplete suggestion for custom macros is visible + expect(await page.locator('text=@add_one').count()).toBeGreaterThanOrEqual( + 1, + ) + }) +}) diff --git a/vscode/extension/tests/configuration.spec.ts b/vscode/extension/tests/configuration.spec.ts new file mode 100644 index 0000000000..6f187d5274 --- /dev/null +++ b/vscode/extension/tests/configuration.spec.ts @@ -0,0 +1,213 @@ +import { test, expect } from './fixtures' +import { + createVirtualEnvironment, + openServerPage, + pipInstall, + REPO_ROOT, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import path 
from 'path' +import fs from 'fs-extra' + +async function setupPythonEnvironment(tempDir: string): Promise { + // Create a temporary directory for the virtual environment + const venvDir = path.join(tempDir, '.venv') + fs.mkdirSync(venvDir, { recursive: true }) + + // Create virtual environment + const pythonDetails = await createVirtualEnvironment(venvDir) + + // Install sqlmesh from the local repository with LSP support + const customMaterializations = path.join( + REPO_ROOT, + 'examples', + 'custom_materializations', + ) + const sqlmeshWithExtras = `${REPO_ROOT}[lsp,bigquery]` + await pipInstall(pythonDetails, [sqlmeshWithExtras, customMaterializations]) +} + +/** + * Creates an entrypoint file used to test the LSP configuration. + * + * The entrypoint file is a bash script that simply calls out to the + */ +const createEntrypointFile = ( + tempDir: string, + entrypointFileName: string, + bitToStripFromArgs = '', +): { + entrypointFile: string + fileWhereStoredInputs: string +} => { + const entrypointFile = path.join(tempDir, entrypointFileName) + const fileWhereStoredInputs = path.join(tempDir, 'inputs.txt') + const sqlmeshLSPFile = path.join(tempDir, '.venv/bin/sqlmesh_lsp') + + // Create the entrypoint file + fs.writeFileSync( + entrypointFile, + `#!/bin/bash +echo "$@" > ${fileWhereStoredInputs} +# Strip bitToStripFromArgs from the beginning of the args if it matches +if [[ "$1" == "${bitToStripFromArgs}" ]]; then + shift +fi +# Call the sqlmesh_lsp with the remaining arguments +${sqlmeshLSPFile} "$@"`, + { mode: 0o755 }, // Make it executable + ) + + return { + entrypointFile, + fileWhereStoredInputs, + } +} + +test.describe('Test LSP Entrypoint configuration', () => { + test('specify single entrypoint relative path', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + await setupPythonEnvironment(tempDir) + + const { fileWhereStoredInputs } = createEntrypointFile( + tempDir, + 'entrypoint.sh', + ) + + 
const settings = { + 'sqlmesh.lspEntrypoint': './entrypoint.sh', + } + // Write the settings to the settings.json file + const settingsPath = path.join(tempDir, '.vscode', 'settings.json') + fs.mkdirSync(path.dirname(settingsPath), { recursive: true }) + fs.writeFileSync(settingsPath, JSON.stringify(settings, null, 2)) + + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder, excluding external_models + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the customer_revenue_lifetime model + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + await waitForLoadedSQLMesh(page) + + // Check that the output file exists and contains the entrypoint script arguments + expect(fs.existsSync(fileWhereStoredInputs)).toBe(true) + expect(fs.readFileSync(fileWhereStoredInputs, 'utf8')).toBe(`--stdio +`) + }) + + test('specify one entrypoint absolute path', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + await setupPythonEnvironment(tempDir) + + const { entrypointFile, fileWhereStoredInputs } = createEntrypointFile( + tempDir, + 'entrypoint.sh', + ) + // Assert that the entrypoint file is an absolute path + expect(path.isAbsolute(entrypointFile)).toBe(true) + + const settings = { + 'sqlmesh.lspEntrypoint': `${entrypointFile}`, + } + // Write the settings to the settings.json file + const settingsPath = path.join(tempDir, '.vscode', 'settings.json') + fs.mkdirSync(path.dirname(settingsPath), { recursive: true }) + fs.writeFileSync(settingsPath, JSON.stringify(settings, null, 2)) + + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder, excluding external_models + 
await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the customer_revenue_lifetime model + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + await waitForLoadedSQLMesh(page) + + // Check that the output file exists and contains the entrypoint script arguments + expect(fs.existsSync(fileWhereStoredInputs)).toBe(true) + expect(fs.readFileSync(fileWhereStoredInputs, 'utf8')).toBe(`--stdio +`) + }) + + test('specify entrypoint with arguments', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + await setupPythonEnvironment(tempDir) + + const { fileWhereStoredInputs } = createEntrypointFile( + tempDir, + 'entrypoint.sh', + '--argToIgnore', + ) + + const settings = { + 'sqlmesh.lspEntrypoint': './entrypoint.sh --argToIgnore', + } + // Write the settings to the settings.json file + const settingsPath = path.join(tempDir, '.vscode', 'settings.json') + fs.mkdirSync(path.dirname(settingsPath), { recursive: true }) + fs.writeFileSync(settingsPath, JSON.stringify(settings, null, 2)) + + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder, excluding external_models + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the customer_revenue_lifetime model + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + await waitForLoadedSQLMesh(page) + + // Check that the output file exists and contains the entrypoint script arguments + expect(fs.existsSync(fileWhereStoredInputs)).toBe(true) + expect(fs.readFileSync(fileWhereStoredInputs, 'utf8')) + .toBe(`--argToIgnore --stdio +`) + }) +}) diff --git a/vscode/extension/tests/diagnostics.spec.ts b/vscode/extension/tests/diagnostics.spec.ts 
new file mode 100644 index 0000000000..1c0e471e82 --- /dev/null +++ b/vscode/extension/tests/diagnostics.spec.ts @@ -0,0 +1,326 @@ +import { expect, test } from './fixtures' +import path from 'path' +import fs from 'fs-extra' +import os from 'os' +import { openProblemsView, openServerPage, SUSHI_SOURCE_PATH } from './utils' +import { createPythonInterpreterSettingsSpecifier } from './utils_code_server' +import { execAsync } from '../src/utilities/exec' +import yaml from 'yaml' + +test('Workspace diagnostics show up in the diagnostics panel', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + + const configPath = path.join(tempDir, 'config.py') + const configContent = await fs.readFile(configPath, 'utf8') + const updatedContent = configContent.replace('enabled=False', 'enabled=True') + await fs.writeFile(configPath, updatedContent) + + await openServerPage(page, tempDir, sharedCodeServer) + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder, excluding external_models + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the customer_revenue_lifetime model + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + await openProblemsView(page) + + await page.waitForSelector('text=problems') + await page.waitForSelector('text=All models should have an owner') +}) + +test.describe('Bad config.py/config.yaml file issues', () => { + const setup = async (tempDir: string) => { + // Run the sqlmesh CLI from the root of the repo using the local path + const sqlmeshCliPath = path.resolve(__dirname, '../../../.venv/bin/sqlmesh') + const result = await execAsync(sqlmeshCliPath, ['init', 'duckdb'], { + cwd: tempDir, + }) + expect(result.exitCode).toBe(0) + } + + test('sqlmesh init, then corrupted 
config.yaml, bad yaml', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await setup(tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + + const configYamlPath = path.join(tempDir, 'config.yaml') + // Write an invalid YAML to config.yaml + await fs.writeFile(configYamlPath, 'invalid_yaml; asdfasudfy') + + await openServerPage(page, tempDir, sharedCodeServer) + + // Open full_model.sql model + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + await page + .getByRole('treeitem', { name: 'full_model.sql', exact: true }) + .locator('a') + .click() + + // Wait for the error to appear + await page.waitForSelector('text=Error creating context') + + await openProblemsView(page) + + // Asser that the error is present in the problems view + await page + .getByText('Invalid YAML configuration:') + .first() + .isVisible({ timeout: 5_000 }) + }) + + test('sqlmesh init, then corrupted config.yaml, bad parameters', async ({ + page, + sharedCodeServer, + }) => { + const tempDir = await fs.mkdtemp( + path.join(os.tmpdir(), 'vscode-test-tcloud-'), + ) + await setup(tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + + const configYamlPath = path.join(tempDir, 'config.yaml') + // Write an invalid YAML to config.yaml + const config = { + gateway: 'test', + } + // Write config to the yaml file + await fs.writeFile(configYamlPath, yaml.stringify(config)) + + await openServerPage(page, tempDir, sharedCodeServer) + + // Open full_model.sql model + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + await page + .getByRole('treeitem', { name: 'full_model.sql', exact: true }) + .locator('a') + .click() + + // Wait for the error to appear + await page.waitForSelector('text=Error creating context') + + await openProblemsView(page) + + // Asser that the error is present in the problems view + await page + .getByText('Invalid project config:', { exact: true 
}) + .first() + .isVisible({ timeout: 5_000 }) + }) + + test('sushi example, correct python, bad config', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + + const configPyPath = path.join(tempDir, 'config.py') + // Write an invalid Python to config.py + await fs.writeFile(configPyPath, 'config = {}') + + await openServerPage(page, tempDir, sharedCodeServer) + + // Open customers.sql model + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + // Expect the error to appear + await page.waitForSelector('text=Error creating context') + + await openProblemsView(page) + + // Assert that the error is present in the problems view + const errorElement = page + .getByText('Config needs to be a valid object of type') + .first() + await expect(errorElement).toBeVisible({ timeout: 5000 }) + }) + + test('sushi example, bad config.py', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + + const configPyPath = path.join(tempDir, 'config.py') + // Write an invalid Python to config.py + await fs.writeFile(configPyPath, 'invalid_python_code = [1, 2, 3') + + await openServerPage(page, tempDir, sharedCodeServer) + + // Open customers.sql model + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + // Expect the error to appear + await page.waitForSelector('text=Error creating context') + + await openProblemsView(page) + + // Assert that the error is present in the problems view + const errorElement = page.getByText('Failed to load config file:').first() + await 
expect(errorElement).toBeVisible({ timeout: 5000 }) + }) +}) + +test.describe('Diagnostics for bad SQLMesh models', () => { + test('duplicate model names', async ({ page, sharedCodeServer, tempDir }) => { + // Copy over the sushi project + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + + // Duplicate the customers.sql model + const customersSqlPath = path.join(tempDir, 'models', 'customers.sql') + const duplicatedCustomersSqlPath = path.join( + tempDir, + 'models', + 'customers_duplicated.sql', + ) + await fs.copy(customersSqlPath, duplicatedCustomersSqlPath) + + await openServerPage(page, tempDir, sharedCodeServer) + + // Open full_model.sql model + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + // Wait for the error to appear + await page.waitForSelector('text=Error creating context') + + await openProblemsView(page) + + // Asser that the error is present in the problems view + await page + .getByText('Duplicate SQLMesh model name') + .first() + .isVisible({ timeout: 5_000 }) + }) + + test('bad model block', async ({ page, sharedCodeServer, tempDir }) => { + // Copy over the sushi project + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + + // Add a model with a bad model block + const customersSqlPath = path.join(tempDir, 'models', 'bad_model.sql') + const contents = + 'MODEL ( name sushi.bad_block, test); SELECT * FROM sushi.customers' + await fs.writeFile(customersSqlPath, contents) + + await page.goto( + `http://127.0.0.1:${sharedCodeServer.codeServerPort}/?folder=${tempDir}`, + ) + await page.waitForLoadState('networkidle') + + // Open the customers.sql model + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + await page + .getByRole('treeitem', { 
name: 'customers.sql', exact: true }) + .locator('a') + .click() + + // Wait for the error to appear + await page.waitForSelector('text=Error creating context') + + await openProblemsView(page) + + // Assert error is present in the problems view + const errorElement = page + .getByText("Required keyword: 'value' missing for") + .first() + await expect(errorElement).toBeVisible({ timeout: 5000 }) + }) +}) + +test.describe('Diagnostics for bad audits', () => { + test('bad audit block in audit', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + // Copy over the sushi project + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + + // Make an existing audit file a bad audit + const auditFilePath = path.join( + tempDir, + 'audits', + 'assert_item_price_above_zero.sql', + ) + const readFile = await fs.readFile(auditFilePath, 'utf8') + const updatedContent = readFile.replace('AUDIT (', 'AUDIT ( rubbish value,') + await fs.writeFile(auditFilePath, updatedContent) + + // Navigate to the code-server instance + await openServerPage(page, tempDir, sharedCodeServer) + + // Open a the customers.sql model + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + // Wait for the error to appear + await page.waitForSelector('text=Error creating context') + + await openProblemsView(page) + + // Assert that the error is present in the problems view + const errorElement = page + .getByText("Invalid extra fields {'rubbish'} in the audit definition") + .first() + await expect(errorElement).toBeVisible({ timeout: 5000 }) + }) +}) diff --git a/vscode/extension/tests/extension.setup.ts b/vscode/extension/tests/extension.setup.ts new file mode 100644 index 0000000000..7447e53704 --- /dev/null +++ b/vscode/extension/tests/extension.setup.ts @@ -0,0 +1,80 @@ +import { test as setup } from 
'@playwright/test' +import { execSync } from 'child_process' +import path from 'path' +import fs from 'fs-extra' +import { createHash } from 'crypto' +import { tmpdir } from 'os' + +setup('prepare extension', async () => { + console.log('Setting up extension for Playwright tests...') + + const extensionDir = path.join(__dirname, '..') + const testSetupDir = path.join(extensionDir, '.test_setup') + const extensionsDir = path.join(testSetupDir, 'extensions') + + // Clean up any existing test setup directory + + // Get the extension version from package.json + const packageJson = JSON.parse( + fs.readFileSync(path.join(extensionDir, 'package.json'), 'utf-8'), + ) + const version = packageJson.version + const extensionName = packageJson.name || 'sqlmesh' + + // Look for the specific version .vsix file + const vsixFileName = `${extensionName}-${version}.vsix` + const vsixPath = path.join(extensionDir, vsixFileName) + + if (!fs.existsSync(vsixPath)) { + throw new Error( + `Extension file ${vsixFileName} not found. 
Run "pnpm run vscode:package" first.`, + ) + } + + // Create a temporary user data directory for the installation + const tempUserDataDir = await fs.mkdtemp( + path.join(tmpdir(), 'vscode-test-install-user-data-'), + ) + + try { + // Check if in .test_setup there is a extension hash file which contains the hash of the extension + // If it does, check if the hash is the same as the hash of the extension in the vsix file + // If it is, skip the installation + // If it is not, remove the extension hash file and install the extension + const extensionHashFile = path.join(testSetupDir, 'extension-hash.txt') + console.log('extensionHashFile', extensionHashFile) + if (fs.existsSync(extensionHashFile)) { + const extensionHash = fs.readFileSync(extensionHashFile, 'utf-8') + const vsixHash = await hashFile(vsixPath) + if (extensionHash === vsixHash) { + console.log('Extension already installed') + return + } + } + + await fs.remove(testSetupDir) + await fs.ensureDir(testSetupDir) + await fs.ensureDir(extensionsDir) + + console.log(`Installing extension: ${vsixFileName}`) + execSync( + `pnpm run code-server --user-data-dir "${tempUserDataDir}" --extensions-dir "${extensionsDir}" --install-extension "${vsixPath}"`, + { + stdio: 'inherit', + cwd: extensionDir, + }, + ) + + // Write the hash of the extension to the extension hash file + const extensionHash = await hashFile(vsixPath) + await fs.writeFile(extensionHashFile, extensionHash) + } finally { + // Clean up temporary user data directory + await fs.remove(tempUserDataDir) + } +}) + +async function hashFile(filePath: string): Promise { + const fileBuffer = await fs.readFile(filePath) + return createHash('sha256').update(fileBuffer).digest('hex') +} diff --git a/vscode/extension/tests/extension.teardown.ts b/vscode/extension/tests/extension.teardown.ts new file mode 100644 index 0000000000..587ce695a6 --- /dev/null +++ b/vscode/extension/tests/extension.teardown.ts @@ -0,0 +1,16 @@ +import { test as teardown } from 
'@playwright/test' +import path from 'path' +import fs from 'fs-extra' + +teardown('cleanup extension', async () => { + console.log('Cleaning up extension test setup...') + + const extensionDir = path.join(__dirname, '..') + const testSetupDir = path.join(extensionDir, '.test_setup') + + // Clean up test setup directory + if (fs.existsSync(testSetupDir)) { + await fs.remove(testSetupDir) + console.log('Test setup directory cleaned up') + } +}) diff --git a/vscode/extension/tests/external_models.spec.ts b/vscode/extension/tests/external_models.spec.ts new file mode 100644 index 0000000000..4fdd19fa61 --- /dev/null +++ b/vscode/extension/tests/external_models.spec.ts @@ -0,0 +1,65 @@ +import { + openServerPage, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import { createPythonInterpreterSettingsSpecifier } from './utils_code_server' +import { test, expect } from './fixtures' +import fs from 'fs-extra' +import path from 'path' + +test.describe('External model files trigger lsp', () => { + test('external_models.yaml', async ({ page, sharedCodeServer, tempDir }) => { + const file = 'external_models.yaml' + + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + // Assert external_models.yaml exists + const externalModelsYamlPath = path.join(tempDir, file) + expect(await fs.pathExists(externalModelsYamlPath)).toBe(true) + + await createPythonInterpreterSettingsSpecifier(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the external_models file (e.g., external_models.yaml or external_models.yml) + await page + .getByRole('treeitem', { name: file, exact: true }) + .locator('a') + .click() + + await page.waitForSelector('text=raw.demographics') + await waitForLoadedSQLMesh(page) + }) + + test('external_models.yml', async ({ page, sharedCodeServer, tempDir }) => { + const file = 'external_models.yml' + + await fs.copy(SUSHI_SOURCE_PATH, 
tempDir) + + // Move external_models.yaml to external_models.yml + const externalModelsYamlPath = path.join(tempDir, 'external_models.yaml') + const externalModelsYmlPath = path.join(tempDir, file) + await fs.rename(externalModelsYamlPath, externalModelsYmlPath) + + // Assert external_models.yml exists + expect(await fs.pathExists(externalModelsYmlPath)).toBe(true) + + await createPythonInterpreterSettingsSpecifier(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the external_models.yml file + await page + .getByRole('treeitem', { name: file, exact: true }) + .locator('a') + .click() + + await page.waitForSelector('text=raw.demographics') + await waitForLoadedSQLMesh(page) + }) +}) diff --git a/vscode/extension/tests/find_references.spec.ts b/vscode/extension/tests/find_references.spec.ts new file mode 100644 index 0000000000..ccc5eaf916 --- /dev/null +++ b/vscode/extension/tests/find_references.spec.ts @@ -0,0 +1,637 @@ +import { test, expect, Page } from './fixtures' +import fs from 'fs-extra' +import { + findAllReferences, + goToReferences, + openServerPage, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import { createPythonInterpreterSettingsSpecifier } from './utils_code_server' + +// Helper function to set up a test environment for model references +async function setupModelTestEnvironment(tempDir: string): Promise { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) +} + +// Helper function to navigate to models folder +async function navigateToModels(page: Page) { + await page.waitForSelector('text=models') + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() +} + +// Helper function to navigate to audits folder +async function navigateToAudits(page: Page) { + await page.waitForSelector('text=audits') + await page + 
.getByRole('treeitem', { name: 'audits', exact: true }) + .locator('a') + .click() +} + +// Helper function to open customers.sql and wait for SQLMesh context +async function openCustomersFile(page: Page) { + await navigateToModels(page) + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + await page.waitForSelector('text=grain') + await waitForLoadedSQLMesh(page) +} + +// Helper function to open top_waiters.sql and wait for SQLMesh context +async function openTopWaitersFile(page: Page) { + await navigateToModels(page) + await page + .getByRole('treeitem', { name: 'top_waiters.sql', exact: true }) + .locator('a') + .click() + await page.waitForSelector('text=grain') + await waitForLoadedSQLMesh(page) +} + +test.describe('Model References', () => { + test('Go to References (Shift+F12) for Model usage', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await setupModelTestEnvironment(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + // Open customers.sql which contains references to other models + await openCustomersFile(page) + + // Step 4: Position cursor on the sushi.orders model reference in the SQL query + await page.locator('text=sushi.orders').first().click() + + // Step 5: Trigger "Go to References" command using Shift+F12 keyboard shortcut + await goToReferences(page) + + // Step 6: Wait for VSCode references panel to appear at the bottom + await page.waitForSelector('text=References') + + // Step 7: Ensure references panel has populated with all usages of sushi.orders model + await page.waitForFunction( + () => { + const referenceElements = document.querySelectorAll( + '.reference-item, .monaco-list-row, .references-view .tree-row', + ) + return referenceElements.length >= 6 + }, + { timeout: 10000 }, + ) + + // Step 8: Verify the references panel shows both SQL and Python files containing references + const hasReferences = await page.evaluate(() => { + const body = 
document.body.textContent || '' + return ( + body.includes('References') && + (body.includes('.sql') || body.includes('.py')) + ) + }) + + expect(hasReferences).toBe(true) + + // Step 9: Find and click on the orders.py reference to navigate to the model definition + let clickedReference = false + + const referenceItems = page.locator( + '.monaco-list-row, .reference-item, .monaco-tl-row', + ) + const count = await referenceItems.count() + + for (let i = 0; i < count; i++) { + const item = referenceItems.nth(i) + const text = await item.textContent() + + // Search for the orders.py reference which contains the Python model definition + if (text && text.includes('orders.py')) { + await item.click() + clickedReference = true + break + } + } + + expect(clickedReference).toBe(true) + + // Step 10: Verify successful navigation to orders.py by checking for unique Python code + await expect(page.locator('text=list(range(0, 100))')).toBeVisible() + }) + + test('Find All References (Alt+Shift+F12) for Model', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await setupModelTestEnvironment(tempDir) + + await openServerPage(page, tempDir, sharedCodeServer) + + // Open customers.sql which contains multiple model references + await openCustomersFile(page) + + // Step 4: Click on sushi.orders model reference to position cursor + await page.locator('text=sushi.orders').first().click() + + // Step 5: Trigger "Find All References" command using Alt+Shift+F12 (or +Shift+F12 on Windows/Linux) + await findAllReferences(page) + + let clickedReference = false + const referenceItems = page.locator( + '.monaco-list-row, .reference-item, .monaco-tl-row', + ) + const count = await referenceItems.count() + + // Step 6: Iterate through references to find and click on orders.py + for (let i = 0; i < count; i++) { + const item = referenceItems.nth(i) + const text = await item.textContent() + + // Find the orders.py reference which contains the model implementation + if (text && 
text.includes('orders.py')) { + await item.click() + + clickedReference = true + break + } + } + + expect(clickedReference).toBe(true) + + // Step 7: Verify navigation to orders.py by checking for Python import statement + await expect(page.locator('text=import random')).toBeVisible() + + // Step 8: Click on the import statement to ensure file is fully loaded and interactive + await page.locator('text=import random').first().click() + + // Step 9: Final verification that we're viewing the correct Python model file + await expect(page.locator('text=list(range(0, 100))')).toBeVisible() + }) + + test('Go to References for Model from Audit', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await setupModelTestEnvironment(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + // Open assert_item_price_above_zero.sql audit file which references sushi.items model + await navigateToAudits(page) + await page + .getByRole('treeitem', { + name: 'assert_item_price_above_zero.sql', + exact: true, + }) + .locator('a') + .click() + + // Wait for audit file to load and SQLMesh context to initialize + await page.waitForSelector('text=standalone') + await waitForLoadedSQLMesh(page) + + // Step 4: Click on sushi.items model reference in the audit query + await page.locator('text=sushi.items').first().click() + + // Step 5: Trigger "Go to References" to find all places where sushi.items is used + await goToReferences(page) + + // Step 6: Wait for VSCode references panel to appear + await page.waitForSelector('text=References') + + // Step 7: Ensure references panel shows multiple files that reference sushi.items + await page.waitForFunction( + () => { + const referenceElements = document.querySelectorAll( + '.reference-item, .monaco-list-row, .references-view .tree-row', + ) + return referenceElements.length >= 4 + }, + { timeout: 10000 }, + ) + + // Step 8: Verify references panel contains both audit and model files + const hasReferences = await 
page.evaluate(() => { + const body = document.body.textContent || '' + return ( + body.includes('References') && + (body.includes('.sql') || body.includes('.py')) + ) + }) + + expect(hasReferences).toBe(true) + + // 9. Click on one of the references to navigate to it + let clickedReference = false + + const referenceItems = page.locator( + '.monaco-list-row, .reference-item, .monaco-tl-row', + ) + const count = await referenceItems.count() + + for (let i = 0; i < count; i++) { + const item = referenceItems.nth(i) + const text = await item.textContent() + + // Search for the customer_revenue_by_day.sql file which joins with sushi.items + if (text && text.includes('customer_revenue_by_day.sql')) { + await item.click() + clickedReference = true + break + } + } + + expect(clickedReference).toBe(true) + + // Step 10: Verify navigation to customer_revenue_by_day.sql by checking for SQL JOIN syntax + await expect(page.locator('text=LEFT JOIN')).toBeVisible() + + // Step 11: Click on LEFT JOIN to ensure file is interactive and verify content + await page.locator('text=LEFT JOIN').first().click() + await expect( + page.locator('text=FROM sushi.order_items AS oi'), + ).toBeVisible() + }) + + test('Find All Model References from Audit', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await setupModelTestEnvironment(tempDir) + + await openServerPage(page, tempDir, sharedCodeServer) + + // Open the audit file that validates item prices + await navigateToAudits(page) + await page + .getByRole('treeitem', { + name: 'assert_item_price_above_zero.sql', + exact: true, + }) + .locator('a') + .click() + + // Ensure audit file and SQLMesh context are fully loaded + await page.waitForSelector('text=standalone') + await waitForLoadedSQLMesh(page) + + // Step 4: Position cursor on sushi.items model reference + await page.locator('text=sushi.items').first().click() + + // Step 5: Use Find All References to see all occurrences across the project + await findAllReferences(page) + 
+ // Assert that the references panel shows the correct files + await page.waitForSelector('text=References') + await page.waitForSelector('text=customer_revenue_by_day.sql') + await page.waitForSelector('text=items.py') + }) +}) + +test.describe('CTE References', () => { + test('Go to references from definition of CTE', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await setupModelTestEnvironment(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + await openCustomersFile(page) + + // Click on the CTE definition "current_marketing_outer" at line 20 to position cursor + await page.locator('text=current_marketing_outer').first().click() + + // Use keyboard shortcut to find all references + await goToReferences(page) + + // Wait for the references to appear + await page.waitForSelector('text=References') + + // Wait for reference panel to populate + await page.waitForFunction( + () => { + const referenceElements = document.querySelectorAll( + '.reference-item, .monaco-list-row, .references-view .tree-row', + ) + return referenceElements.length >= 2 + }, + { timeout: 5000 }, + ) + + // Verify that the customers.sql file is shown in results + await expect(page.locator('text=customers.sql').first()).toBeVisible() + + // Check that both CTE definition and usage are listed in references + await page.waitForSelector('text=References') + await page.waitForSelector('text=WITH current_marketing_outer AS') + await page.waitForSelector('text=FROM current_marketing_outer') + }) + + test('Go to references from usage of CTE', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await setupModelTestEnvironment(tempDir) + + await openServerPage(page, tempDir, sharedCodeServer) + + await openCustomersFile(page) + + // Click on the CTE usage this time for "current_marketing_outer" + await page.locator('text=FROM current_marketing_outer').click({ + position: { x: 80, y: 5 }, // Clicks on the usage rather than first which was definition + }) + + // 
Use keyboard shortcut to go to references + await goToReferences(page) + + // Wait for the references to appear + await page.waitForSelector('text=References') + + // Better assertions: wait for reference panel to populate + await page.waitForFunction( + () => { + const referenceElements = document.querySelectorAll( + '.reference-item, .monaco-list-row, .references-view .tree-row', + ) + return referenceElements.length >= 2 + }, + { timeout: 5000 }, + ) + + await page.waitForSelector('text=References') + await page.waitForSelector('text=WITH current_marketing_outer AS') + await page.waitForSelector('text=FROM current_marketing_outer') + + // Verify that the customers.sql file is shown in results + await expect(page.locator('text=customers.sql').first()).toBeVisible() + }) + + test('Go to references for nested CTE', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await setupModelTestEnvironment(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + await openCustomersFile(page) + + // Click on the nested CTE "current_marketing" + await page.locator('text=WITH current_marketing AS').click({ + position: { x: 100, y: 5 }, // Click on the CTE name part + }) + + // Use keyboard shortcut to find all references + await goToReferences(page) + + // Wait for the references to appear + await page.waitForSelector('text=References') + + // Wait for reference panel to populate + await page.waitForFunction( + () => { + const referenceElements = document.querySelectorAll( + '.reference-item, .monaco-list-row, .references-view .tree-row', + ) + return referenceElements.length >= 2 + }, + { timeout: 5000 }, + ) + + // Verify that the customers.sql file is shown in results + await expect(page.locator('text=customers.sql').first()).toBeVisible() + + // Check that both CTE definition and usage are listed in references + await page.waitForSelector('text=References') + await page.waitForSelector('text=WITH current_marketing AS') + await 
page.waitForSelector('text=FROM current_marketing') + }) + + test('Find all references for CTE', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await setupModelTestEnvironment(tempDir) + + await openServerPage(page, tempDir, sharedCodeServer) + await openCustomersFile(page) + + // Click on the CTE definition "current_marketing_outer" + await page.locator('text=current_marketing_outer').first().click() + + // Use keyboard shortcut to find all references + await findAllReferences(page) + + // Verify references contains expected content + await page.waitForSelector('text=References') + await page.waitForSelector('text=WITH current_marketing_outer AS') + await page.waitForSelector('text=FROM current_marketing_outer') + + // Verify that the customers.sql file is shown in results + await expect(page.locator('text=customers.sql').first()).toBeVisible() + }) + + test('Find all references from usage for CTE', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await setupModelTestEnvironment(tempDir) + + await openServerPage(page, tempDir, sharedCodeServer) + + await openCustomersFile(page) + + // Click on the CTE usage of "current_marketing_outer" using last + await page.locator('text=current_marketing_outer').last().click() + + // Use keyboard shortcut to find all references + await findAllReferences(page) + + // Verify references contains expected content + await page.waitForSelector('text=References') + await page.waitForSelector('text=WITH current_marketing_outer AS') + await page.waitForSelector('text=FROM current_marketing_outer') + + // Verify that the customers.sql file is shown in results + await expect(page.locator('text=customers.sql').first()).toBeVisible() + }) + + test('Find all references for nested CTE', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await setupModelTestEnvironment(tempDir) + + await openServerPage(page, tempDir, sharedCodeServer) + await openCustomersFile(page) + + // Click on the nested CTE 
"current_marketing" at line 33 + // We need to be more specific to get the inner one + await page.locator('text=WITH current_marketing AS').click({ + position: { x: 100, y: 5 }, // Click on the CTE name part + }) + + // Use keyboard shortcut to find all references + await findAllReferences(page) + + // Verify references contains expected content + await page.waitForSelector('text=References') + await page.waitForSelector('text=WITH current_marketing AS') + await page.waitForSelector('text=FROM current_marketing') + + // Verify that the customers.sql file is shown in results + await expect(page.locator('text=customers.sql').first()).toBeVisible() + }) +}) + +test.describe('Macro References', () => { + test('Go to References for @ADD_ONE macro', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await setupModelTestEnvironment(tempDir) + + await openServerPage(page, tempDir, sharedCodeServer) + await openTopWaitersFile(page) + + // Click on the @ADD_ONE macro usage + await page.locator('text=@ADD_ONE').first().click() + + // Use keyboard shortcut to find all references + await goToReferences(page) + + // Wait for the references to appear + await page.waitForSelector('text=References') + + // Wait for reference panel to populate + await page.waitForFunction( + () => { + const referenceElements = document.querySelectorAll( + '.reference-item, .monaco-list-row, .references-view .tree-row', + ) + return referenceElements.length >= 2 + }, + { timeout: 5000 }, + ) + + // Verify that both the definition and two usages are shown + await expect(page.locator('text=utils.py').first()).toBeVisible() + await expect(page.locator('text=top_waiters.sql').first()).toBeVisible() + await expect(page.locator('text=customers.sql').first()).toBeVisible() + }) + + test('Find All References for @MULTIPLY macro', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await setupModelTestEnvironment(tempDir) + + await openServerPage(page, tempDir, sharedCodeServer) + + await 
openTopWaitersFile(page) + + // Click on the @MULTIPLY macro usage and then navigate to it + await page.locator('text=@MULTIPLY').first().click() + + // Use keyboard shortcut to find all references + await findAllReferences(page) + + // Verify references contains expected content + await page.waitForSelector('text=References') + + // Verify that both utils.py (definition) and top_waiters.sql (usage) are shown + await expect(page.locator('text=utils.py').first()).toBeVisible() + await expect(page.locator('text=top_waiters.sql').first()).toBeVisible() + + // Click on the utils.py reference to navigate to the macro definition + let clickedReference = false + const referenceItems = page.locator( + '.monaco-list-row, .reference-item, .monaco-tl-row', + ) + const count = await referenceItems.count() + + for (let i = 0; i < count; i++) { + const item = referenceItems.nth(i) + const text = await item.textContent() + + // Find the utils.py reference which contains the macro definition + if (text && text.includes('utils.py')) { + await item.click() + clickedReference = true + break + } + } + + expect(clickedReference).toBe(true) + + // Verify it appeared and click on it + await expect(page.locator('text=def multiply')).toBeVisible() + await page.locator('text=def multiply').first().click() + + // Verify navigation to utils.py by checking the import that appears there + await expect( + page.locator('text=from sqlmesh import SQL, macro'), + ).toBeVisible() + }) + + test('Go to References for @SQL_LITERAL macro', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await setupModelTestEnvironment(tempDir) + + await openServerPage(page, tempDir, sharedCodeServer) + await openTopWaitersFile(page) + + // Click on the @SQL_LITERAL macro usage + await page.locator('text=@SQL_LITERAL').first().click() + + // Use keyboard shortcut to find references + await goToReferences(page) + + // Wait for the references to appear + await page.waitForSelector('text=References') + + // Wait 
for reference panel to populate + await page.waitForFunction( + () => { + const referenceElements = document.querySelectorAll( + '.reference-item, .monaco-list-row, .references-view .tree-row', + ) + return referenceElements.length >= 2 + }, + { timeout: 5000 }, + ) + + // Verify that references include both definition and usage + const hasReferences = await page.evaluate(() => { + const body = document.body.textContent || '' + return ( + body.includes('References') && + body.includes('.py') && + body.includes('.sql') + ) + }) + + expect(hasReferences).toBe(true) + + await expect(page.locator('text=utils.py').first()).toBeVisible() + await expect(page.locator('text=top_waiters.sql').first()).toBeVisible() + }) +}) diff --git a/vscode/extension/tests/fixtures.ts b/vscode/extension/tests/fixtures.ts new file mode 100644 index 0000000000..7a10d811fa --- /dev/null +++ b/vscode/extension/tests/fixtures.ts @@ -0,0 +1,64 @@ +import { test as base } from '@playwright/test' +import path from 'path' +import fs from 'fs-extra' +import os from 'os' +import { + startCodeServer, + stopCodeServer, + CodeServerContext, +} from './utils_code_server' + +// Worker-scoped fixture to start/stop VS Code server once per worker +export const test = base.extend< + // eslint-disable-next-line @typescript-eslint/no-empty-object-type + {}, + { sharedCodeServer: CodeServerContext; tempDir: string } +>({ + sharedCodeServer: [ + // eslint-disable-next-line no-empty-pattern + async ({}, use) => { + // Create a temporary directory for the shared server + const tempDir = await fs.mkdtemp( + path.join(os.tmpdir(), 'vscode-test-shared-server-'), + ) + + // Start the code server once per worker + const context = await startCodeServer({ + tempDir, + }) + + console.log( + `Started shared VS Code server for worker ${test.info().workerIndex} on port ${context.codeServerPort}`, + ) + + // Provide the context to all tests in this worker + await use(context) + + // Clean up after all tests in this worker are 
done + console.log(`Stopping shared VS Code server`) + await stopCodeServer(context) + }, + { scope: 'worker', auto: true }, + ], + tempDir: [ + // eslint-disable-next-line no-empty-pattern + async ({}, use) => { + // Create a temporary directory for each test + const tempDir = await fs.mkdtemp( + path.join(os.tmpdir(), 'vscode-test-temp-'), + ) + console.log(`Created temporary directory: ${tempDir}`) + await use(tempDir) + + // Clean up after each test + console.log(`Cleaning up temporary directory: ${tempDir}`) + await fs.remove(tempDir) + }, + // eslint-disable-next-line @typescript-eslint/ban-ts-comment + // @ts-expect-error + { auto: true }, + ], +}) + +// Export expect and Page from Playwright for convenience +export { expect, Page } from '@playwright/test' diff --git a/vscode/extension/tests/format.spec.ts b/vscode/extension/tests/format.spec.ts new file mode 100644 index 0000000000..c8a98a066c --- /dev/null +++ b/vscode/extension/tests/format.spec.ts @@ -0,0 +1,46 @@ +import { test, expect } from './fixtures' +import fs from 'fs-extra' +import { + openServerPage, + runCommand, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import { createPythonInterpreterSettingsSpecifier } from './utils_code_server' + +test('Format project works correctly', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + await createPythonInterpreterSettingsSpecifier(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder, excluding external_models + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the customer_revenue_lifetime model + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + await page.waitForSelector('text=grain') + await waitForLoadedSQLMesh(page) + + // 
Format the project + await runCommand(page, 'SQLMesh: Format Project') + + // Check that the notification appears saying 'Project formatted successfully' + await expect( + page.getByText('Project formatted successfully', { exact: true }), + ).toBeVisible() +}) diff --git a/vscode/extension/tests/go_to_definition.spec.ts b/vscode/extension/tests/go_to_definition.spec.ts new file mode 100644 index 0000000000..3b85c73f27 --- /dev/null +++ b/vscode/extension/tests/go_to_definition.spec.ts @@ -0,0 +1,77 @@ +import { test, expect } from './fixtures' +import fs from 'fs-extra' +import { + goToDefinition, + openServerPage, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import { createPythonInterpreterSettingsSpecifier } from './utils_code_server' + +test('Stop server works', async ({ page, sharedCodeServer, tempDir }) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the customer_revenue_lifetime model + await page + .getByRole('treeitem', { name: 'top_waiters.sql', exact: true }) + .locator('a') + .click() + + await page.waitForSelector('text=grain') + await waitForLoadedSQLMesh(page) + + // Render the model + await page.locator('text=@MULTIPLY').click() + await goToDefinition(page) + + // Check if the model is rendered by check if "`oi`.`order_id` AS `order_id`," is in the window + await expect(page.locator('text=def multiply(')).toBeVisible() +}) + +test('Go to definition for model', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + + // Navigate to code-server instance + await openServerPage(page, 
tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the top_waiters model + await page + .getByRole('treeitem', { name: 'top_waiters.sql', exact: true }) + .locator('a') + .click() + + await page.waitForSelector('text=grain') + await waitForLoadedSQLMesh(page) + + // Go to definition for the model + await page.locator('text=sushi.waiter_revenue_by_day').first().click() + await goToDefinition(page) + await expect( + page.locator('text=SUM(oi.quantity * i.price)::DOUBLE AS revenue'), + ).toBeVisible() +}) diff --git a/vscode/extension/tests/hints.spec.ts b/vscode/extension/tests/hints.spec.ts new file mode 100644 index 0000000000..a74f8e184b --- /dev/null +++ b/vscode/extension/tests/hints.spec.ts @@ -0,0 +1,41 @@ +import { test, expect } from './fixtures' +import fs from 'fs-extra' +import { + openServerPage, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import { createPythonInterpreterSettingsSpecifier } from './utils_code_server' + +test('Model type hinting', async ({ page, sharedCodeServer, tempDir }) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the customers_revenue_by_day model + await page + .getByRole('treeitem', { + name: 'customer_revenue_by_day.sql', + exact: true, + }) + .locator('a') + .click() + + await page.waitForSelector('text=grain') + await waitForLoadedSQLMesh(page) + + // Wait a moment for hints to appear + await page.waitForTimeout(500) + + // Check if the hint is visible + 
expect(await page.locator('text="country code"::INT').count()).toBe(1) +}) diff --git a/vscode/extension/tests/lineage.spec.ts b/vscode/extension/tests/lineage.spec.ts new file mode 100644 index 0000000000..66e3048246 --- /dev/null +++ b/vscode/extension/tests/lineage.spec.ts @@ -0,0 +1,233 @@ +import { test, Page } from './fixtures' +import path from 'path' +import fs from 'fs-extra' +import os from 'os' +import { + openLineageView, + openServerPage, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import { writeFileSync } from 'fs' +import { createPythonInterpreterSettingsSpecifier } from './utils_code_server' + +/** + * Helper function to launch VS Code and test lineage with given project path config + */ +async function testLineageWithProjectPath(page: Page): Promise { + await openLineageView(page) + await waitForLoadedSQLMesh(page) +} + +test('Lineage panel renders correctly - no project path config (default)', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + + await openServerPage(page, tempDir, sharedCodeServer) + await testLineageWithProjectPath(page) +}) + +test('Lineage panel renders correctly - relative project path', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + const projectDir = path.join(tempDir, 'projects', 'sushi') + await fs.copy(SUSHI_SOURCE_PATH, projectDir) + + const settings = { + 'sqlmesh.projectPaths': ['./projects/sushi'], + 'python.defaultInterpreterPath': sharedCodeServer.defaultPythonInterpreter, + } + await fs.ensureDir(path.join(tempDir, '.vscode')) + await fs.writeJson(path.join(tempDir, '.vscode', 'settings.json'), settings, { + spaces: 2, + }) + + try { + await openServerPage(page, tempDir, sharedCodeServer) + await testLineageWithProjectPath(page) + } finally { + await fs.remove(tempDir) + } +}) + +test('Lineage panel renders correctly - absolute project path', async ({ + page, + tempDir, + 
sharedCodeServer, +}) => { + // Copy the sushi project to temporary directory + const projectDir = path.join(tempDir, 'projects', 'sushi') + await fs.ensureDir(path.join(tempDir, '.vscode')) + await fs.copy(SUSHI_SOURCE_PATH, projectDir) + + const settings = { + 'sqlmesh.projectPaths': [projectDir], + 'python.defaultInterpreterPath': sharedCodeServer.defaultPythonInterpreter, + } + await fs.writeJson(path.join(tempDir, '.vscode', 'settings.json'), settings, { + spaces: 2, + }) + + await openServerPage(page, tempDir, sharedCodeServer) + await testLineageWithProjectPath(page) +}) + +test('Lineage panel renders correctly - relative project outside of workspace', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + const projectDir = path.join(tempDir, 'projects', 'sushi') + await fs.copy(SUSHI_SOURCE_PATH, projectDir) + + const workspaceDir = path.join(tempDir, 'workspace') + await fs.ensureDir(workspaceDir) + + const settings = { + 'sqlmesh.projectPaths': ['./../projects/sushi'], + 'python.defaultInterpreterPath': sharedCodeServer.defaultPythonInterpreter, + } + await fs.ensureDir(path.join(workspaceDir, '.vscode')) + await fs.writeJson( + path.join(workspaceDir, '.vscode', 'settings.json'), + settings, + { spaces: 2 }, + ) + await openServerPage(page, workspaceDir, sharedCodeServer) + await testLineageWithProjectPath(page) +}) + +test('Lineage panel renders correctly - absolute path project outside of workspace', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + const projectDir = path.join(tempDir, 'projects', 'sushi') + await fs.copy(SUSHI_SOURCE_PATH, projectDir) + + const workspaceDir = path.join(tempDir, 'workspace') + await fs.ensureDir(workspaceDir) + + const settings = { + 'sqlmesh.projectPaths': [projectDir], + 'python.defaultInterpreterPath': sharedCodeServer.defaultPythonInterpreter, + } + await fs.ensureDir(path.join(workspaceDir, '.vscode')) + await fs.writeJson( + path.join(workspaceDir, '.vscode', 'settings.json'), + settings, + { 
spaces: 2 }, + ) + + await openServerPage(page, workspaceDir, sharedCodeServer) + await testLineageWithProjectPath(page) +}) + +// These work on local machine when debuggin but not on CI, so skipping for now +test('Lineage panel renders correctly - multiworkspace setup', async ({ + page, + sharedCodeServer, +}) => { + const workspaceDir = await fs.mkdtemp( + path.join(os.tmpdir(), 'vscode-test-workspace-'), + ) + const projectDir1 = path.join(workspaceDir, 'projects', 'sushi1') + const projectDir2 = path.join(workspaceDir, 'projects', 'sushi2') + await fs.copy(SUSHI_SOURCE_PATH, projectDir1) + await fs.ensureDir(projectDir2) + + // Add a .code-workspace file with multiple projects + const workspaceFilePath = path.join( + workspaceDir, + 'multi-workspace.code-workspace', + ) + writeFileSync( + workspaceFilePath, + JSON.stringify({ + folders: [ + { + name: 'sushi1', + path: 'projects/sushi1', + }, + { + name: 'sushi2', + path: 'projects/sushi2', + }, + ], + }), + ) + + const settings = { + 'python.defaultInterpreterPath': sharedCodeServer.defaultPythonInterpreter, + } + await fs.ensureDir(path.join(projectDir1, '.vscode')) + await fs.writeJson( + path.join(projectDir1, '.vscode', 'settings.json'), + settings, + { spaces: 2 }, + ) + + await openServerPage(page, workspaceFilePath, sharedCodeServer) + await page.reload() + await testLineageWithProjectPath(page) +}) + +test('Lineage panel renders correctly - multiworkspace setup reversed', async ({ + page, + sharedCodeServer, +}) => { + const workspaceDir = await fs.mkdtemp( + path.join(os.tmpdir(), 'vscode-test-workspace-'), + ) + const projectDir1 = path.join(workspaceDir, 'projects', 'sushi1') + const projectDir2 = path.join(workspaceDir, 'projects', 'sushi2') + await fs.copy(SUSHI_SOURCE_PATH, projectDir2) + await fs.ensureDir(projectDir1) + + // Add a .code-workspace file with multiple projects + const workspaceFilePath = path.join( + workspaceDir, + 'multi-workspace.code-workspace', + ) + writeFileSync( + 
workspaceFilePath, + JSON.stringify({ + folders: [ + { + name: 'sushi1', + path: 'projects/sushi1', + }, + { + name: 'sushi2', + path: 'projects/sushi2', + }, + ], + }), + ) + + const settings = { + 'python.defaultInterpreterPath': sharedCodeServer.defaultPythonInterpreter, + } + await fs.ensureDir(path.join(projectDir1, '.vscode')) + await fs.writeJson( + path.join(projectDir1, '.vscode', 'settings.json'), + settings, + { spaces: 2 }, + ) + await fs.ensureDir(path.join(projectDir2, '.vscode')) + await fs.writeJson( + path.join(projectDir2, '.vscode', 'settings.json'), + settings, + { spaces: 2 }, + ) + + await openServerPage(page, workspaceFilePath, sharedCodeServer) + await page.reload() + await testLineageWithProjectPath(page) +}) diff --git a/vscode/extension/tests/lineage_settings.spec.ts b/vscode/extension/tests/lineage_settings.spec.ts new file mode 100644 index 0000000000..c3237f13dc --- /dev/null +++ b/vscode/extension/tests/lineage_settings.spec.ts @@ -0,0 +1,64 @@ +import { test, expect } from './fixtures' +import fs from 'fs-extra' +import { + openLineageView, + openServerPage, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import { createPythonInterpreterSettingsSpecifier } from './utils_code_server' + +test('Settings button is visible in the lineage view', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + + await openServerPage(page, tempDir, sharedCodeServer) + + await page.waitForSelector('text=models') + + // Click on the models folder, excluding external_models + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + // Open the waiters.py model + await page + .getByRole('treeitem', { name: 'waiters.py', exact: true }) + .locator('a') + .click() + await waitForLoadedSQLMesh(page) + + // Open lineage + await openLineageView(page) + + const iframes = page.locator('iframe') + const 
iframeCount = await iframes.count() + let settingsCount = 0 + + for (let i = 0; i < iframeCount; i++) { + const iframe = iframes.nth(i) + const contentFrame = iframe.contentFrame() + if (contentFrame) { + const activeFrame = contentFrame.locator('#active-frame').contentFrame() + if (activeFrame) { + try { + await activeFrame + .getByRole('button', { + name: 'Settings', + }) + .waitFor({ timeout: 1000 }) + settingsCount++ + } catch { + // Continue to next iframe if this one doesn't have the error + continue + } + } + } + } + + expect(settingsCount).toBeGreaterThan(0) +}) diff --git a/vscode/extension/tests/multi_project.spec.ts b/vscode/extension/tests/multi_project.spec.ts new file mode 100644 index 0000000000..987c014537 --- /dev/null +++ b/vscode/extension/tests/multi_project.spec.ts @@ -0,0 +1,35 @@ +import { test } from './fixtures' +import { + MULTI_SOURCE_PATH, + openServerPage, + waitForLoadedSQLMesh, +} from './utils' +import fs from 'fs-extra' + +test('should work with multi-project setups', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + await fs.copy(MULTI_SOURCE_PATH, tempDir) + + // Open the server + await openServerPage(page, tempDir, sharedCodeServer) + + // Open a model + await page + .getByRole('treeitem', { name: 'repo_1', exact: true }) + .locator('a') + .click() + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + await page + .getByRole('treeitem', { name: 'a.sql', exact: true }) + .locator('a') + .click() + + // Wait for for the project to be loaded + await waitForLoadedSQLMesh(page) +}) diff --git a/vscode/extension/tests/python_env.spec.ts b/vscode/extension/tests/python_env.spec.ts new file mode 100644 index 0000000000..cfbdc7efa6 --- /dev/null +++ b/vscode/extension/tests/python_env.spec.ts @@ -0,0 +1,128 @@ +import { test, Page } from './fixtures' +import fs from 'fs-extra' +import { + createVirtualEnvironment, + openLineageView, + openServerPage, + pipInstall, + PythonEnvironment, 
+ REPO_ROOT, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import path from 'path' +import { setTcloudVersion, setupAuthenticatedState } from './tcloud_utils' +import { CodeServerContext } from './utils_code_server' + +function writeEnvironmentConfig(sushiPath: string) { + const configPath = path.join(sushiPath, 'config.py') + const originalConfig = fs.readFileSync(configPath, 'utf8') + + const newConfig = + ` +import os + +test_var = os.getenv("TEST_VAR") +if test_var is None or test_var == "": + raise Exception("TEST_VAR is not set") +` + originalConfig + + fs.writeFileSync(configPath, newConfig) +} + +async function runTest( + page: Page, + context: CodeServerContext, + tempDir: string, +): Promise { + await openServerPage(page, tempDir, context) + await page.waitForSelector('text=models') + await openLineageView(page) +} + +async function setupEnvironment(tempDir: string): Promise<{ + pythonDetails: PythonEnvironment +}> { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + const pythonEnvDir = path.join(tempDir, '.venv') + const pythonDetails = await createVirtualEnvironment(pythonEnvDir) + const custom_materializations = path.join( + REPO_ROOT, + 'examples', + 'custom_materializations', + ) + const sqlmeshWithExtras = `${REPO_ROOT}[bigquery,lsp]` + await pipInstall(pythonDetails, [sqlmeshWithExtras, custom_materializations]) + + const settings = { + 'python.defaultInterpreterPath': pythonDetails.pythonPath, + 'sqlmesh.environmentPath': pythonEnvDir, + } + await fs.ensureDir(path.join(tempDir, '.vscode')) + await fs.writeJson(path.join(tempDir, '.vscode', 'settings.json'), settings, { + spaces: 2, + }) + return { pythonDetails } +} + +test.describe('python environment variable injection on sqlmesh_lsp', () => { + test('normal setup - error ', async ({ page, sharedCodeServer, tempDir }) => { + await setupEnvironment(tempDir) + writeEnvironmentConfig(tempDir) + await runTest(page, sharedCodeServer, tempDir) + await page.waitForSelector('text=Error 
creating context') + }) + + test('normal setup - set', async ({ page, sharedCodeServer, tempDir }) => { + await setupEnvironment(tempDir) + writeEnvironmentConfig(tempDir) + const env_file = path.join(tempDir, '.env') + fs.writeFileSync(env_file, 'TEST_VAR=test_value') + await runTest(page, sharedCodeServer, tempDir) + await waitForLoadedSQLMesh(page) + }) +}) + +async function setupTcloudProject( + tempDir: string, + pythonDetails: PythonEnvironment, +) { + // Install the mock tcloud package + const mockTcloudPath = path.join(__dirname, 'tcloud') + await pipInstall(pythonDetails, [mockTcloudPath]) + + // Create a tcloud.yaml to mark this as a tcloud project + const tcloudConfig = { + url: 'https://mock.tobikodata.com', + org: 'test-org', + project: 'test-project', + } + await fs.writeFile( + path.join(tempDir, 'tcloud.yaml'), + `url: ${tcloudConfig.url}\norg: ${tcloudConfig.org}\nproject: ${tcloudConfig.project}\n`, + ) + // Write mock ".tcloud_auth_state.json" file + await setupAuthenticatedState(tempDir) + // Set tcloud version to 2.10.1 + await setTcloudVersion(tempDir, '2.10.1') +} + +test.describe('tcloud version', () => { + test('normal setup - error ', async ({ page, sharedCodeServer, tempDir }) => { + const { pythonDetails } = await setupEnvironment(tempDir) + await setupTcloudProject(tempDir, pythonDetails) + writeEnvironmentConfig(tempDir) + await runTest(page, sharedCodeServer, tempDir) + await page.waitForSelector('text=Error creating context') + }) + + test('normal setup - set', async ({ page, sharedCodeServer, tempDir }) => { + const { pythonDetails } = await setupEnvironment(tempDir) + await setupTcloudProject(tempDir, pythonDetails) + writeEnvironmentConfig(tempDir) + const env_file = path.join(tempDir, '.env') + fs.writeFileSync(env_file, 'TEST_VAR=test_value') + await runTest(page, sharedCodeServer, tempDir) + await waitForLoadedSQLMesh(page) + }) +}) diff --git a/vscode/extension/tests/quickfix.spec.ts b/vscode/extension/tests/quickfix.spec.ts 
new file mode 100644 index 0000000000..c3f37a2acc --- /dev/null +++ b/vscode/extension/tests/quickfix.spec.ts @@ -0,0 +1,96 @@ +import fs from 'fs-extra' +import path from 'path' +import { + openProblemsView, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import { test, expect } from './fixtures' +import { createPythonInterpreterSettingsSpecifier } from './utils_code_server' + +test('noselectstar quickfix', async ({ page, sharedCodeServer, tempDir }) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + + // Override the settings for the linter + const configPath = path.join(tempDir, 'config.py') + const read = await fs.readFile(configPath, 'utf8') + // Replace linter to be on + const target = 'enabled=True' + const replaced = read.replace('enabled=False', 'enabled=True') + // Assert replaced correctly + expect(replaced).toContain(target) + + // Replace the rules to only have noselectstar + const targetRules = `rules=[ + "noselectstar", + ],` + const replacedTheOtherRules = replaced.replace( + `rules=[ + "ambiguousorinvalidcolumn", + "invalidselectstarexpansion", + "noselectstar", + "nomissingaudits", + "nomissingowner", + "nomissingexternalmodels", + ],`, + targetRules, + ) + expect(replacedTheOtherRules).toContain(targetRules) + + await fs.writeFile(configPath, replacedTheOtherRules) + // Replace the file to cause the error + const modelPath = path.join(tempDir, 'models', 'latest_order.sql') + const readModel = await fs.readFile(modelPath, 'utf8') + // Replace the specific select with the select star + const modelReplaced = readModel.replace( + 'SELECT id, customer_id, start_ts, end_ts, event_date', + 'SELECT *', + ) + await fs.writeFile(modelPath, modelReplaced) + + // Open the code server with the specified directory + await page.goto( + `http://127.0.0.1:${sharedCodeServer.codeServerPort}/?folder=${tempDir}`, + ) + await page.waitForLoadState('networkidle') + + // Open the file with the 
linter issue + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + await page + .getByRole('treeitem', { name: 'latest_order.sql', exact: true }) + .locator('a') + .click() + + await waitForLoadedSQLMesh(page) + + await openProblemsView(page) + + await page.getByRole('button', { name: 'Show fixes' }).click() + // Wait for the quick fix menu to appear and click the specific action within it. + const quickFixMenu = page.getByRole('menu') + await quickFixMenu.waitFor({ state: 'visible' }) + const replaceSelectStar = quickFixMenu.getByRole('menuitem', { + name: /Replace SELECT \* with/i, + }) + await replaceSelectStar.first().waitFor({ state: 'visible' }) + await replaceSelectStar.first().click() + + // Wait for the quick fix to be applied by polling the file content + await expect + .poll(async () => { + const content = (await fs.readFile(modelPath)).toString('utf8') + return content.includes('SELECT *') + }) + .toBeFalsy() + + // Assert that the model no longer contains SELECT * but SELECT id, customer_id, waiter_id, start_ts, end_ts, event_date + const readUpdatedFile = (await fs.readFile(modelPath)).toString('utf8') + expect(readUpdatedFile).not.toContain('SELECT *') + expect(readUpdatedFile).toContain( + 'SELECT id, customer_id, waiter_id, start_ts, end_ts, event_date', + ) +}) diff --git a/vscode/extension/tests/rename_cte.spec.ts b/vscode/extension/tests/rename_cte.spec.ts new file mode 100644 index 0000000000..579bda06dd --- /dev/null +++ b/vscode/extension/tests/rename_cte.spec.ts @@ -0,0 +1,187 @@ +import { test, expect, Page } from './fixtures' +import fs from 'fs-extra' +import { + findAllReferences, + openServerPage, + renameSymbol, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import { createPythonInterpreterSettingsSpecifier } from './utils_code_server' + +async function setupTestEnvironment({ + page, + sharedCodeServer, + tempDir, +}: { + page: Page + sharedCodeServer: any + tempDir: string 
+}) { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + + await openServerPage(page, tempDir, sharedCodeServer) + + // Navigate to customers.sql which contains CTEs + await page.waitForSelector('text=models') + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + await page.waitForSelector('text=grain') + await waitForLoadedSQLMesh(page) +} + +test.describe('CTE Rename', () => { + test('Rename CTE from definition', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await setupTestEnvironment({ page, sharedCodeServer, tempDir }) + // Click on the inner CTE definition "current_marketing" (not the outer one) + await page.locator('text=WITH current_marketing AS').click({ + position: { x: 100, y: 5 }, + }) + + // Open rename + await renameSymbol(page) + await page.waitForSelector('text=Rename') + await page.waitForSelector('input:focus') + + // Type new name and confirm + await page.keyboard.type('new_marketing') + await page.keyboard.press('Enter') + + // Verify the rename was applied + await page.waitForSelector('text=WITH new_marketing AS') + }) + + test('Rename CTE from usage', async ({ page, sharedCodeServer, tempDir }) => { + await setupTestEnvironment({ page, sharedCodeServer, tempDir }) + // Click on CTE usage in FROM clause + await page.locator('text=FROM current_marketing_outer').click({ + position: { x: 80, y: 5 }, + }) + + // Open rename + await renameSymbol(page) + await page.waitForSelector('text=Rename') + await page.waitForSelector('input:focus') + + // Type new name + await page.keyboard.type('updated_marketing_out') + + // Confirm rename + await page.keyboard.press('Enter') + + await page.waitForSelector('text=WITH updated_marketing_out AS') + await page.waitForSelector('text=FROM updated_marketing_out') + }) + + test('Cancel CTE rename', 
async ({ page, sharedCodeServer, tempDir }) => { + await setupTestEnvironment({ page, sharedCodeServer, tempDir }) + // Click on the CTE to rename + await page.locator('text=current_marketing_outer').first().click() + + // Open rename + await renameSymbol(page) + await page.waitForSelector('text=Rename') + await page.waitForSelector('input:focus') + + // Type new name but cancel + await page.keyboard.type('cancelled_name') + await page.keyboard.press('Escape') + + // Wait for UI to update + await page.waitForTimeout(500) + + // Verify CTE name was NOT changed + await expect( + page.locator('text=current_marketing_outer').first(), + ).toBeVisible() + await expect(page.locator('text=cancelled_name')).not.toBeVisible() + }) + + test('Rename CTE updates all references', async ({ + page, + tempDir, + sharedCodeServer, + }) => { + await setupTestEnvironment({ page, sharedCodeServer, tempDir }) + // Click on the CTE definition + await page.locator('text=WITH current_marketing AS').click({ + position: { x: 100, y: 5 }, + }) + + // Open rename + await renameSymbol(page) + await page.waitForSelector('text=Rename') + await page.waitForSelector('input:focus') + + // Type new name and confirm + await page.keyboard.type('renamed_cte') + await page.keyboard.press('Enter') + + // Click on the renamed CTE + await page.locator('text=WITH renamed_cte AS').click({ + position: { x: 100, y: 5 }, + }) + + // Find all references using keyboard shortcut + await findAllReferences(page) + + // Verify references panel shows all occurrences + await page.waitForSelector('text=References') + await expect(page.locator('text=customers.sql').first()).toBeVisible() + await page.waitForSelector('text=WITH renamed_cte AS') + await page.waitForSelector('text=renamed_cte.*') + await page.waitForSelector('text=FROM renamed_cte') + await page.waitForSelector('text=renamed_cte.customer_id != 100') + }) + + test('Rename CTE with preview', async ({ + page, + sharedCodeServer, + tempDir, + }) => { + await 
setupTestEnvironment({ page, sharedCodeServer, tempDir }) + // Click on the CTE to rename + await page.locator('text=WITH current_marketing AS').click({ + position: { x: 100, y: 5 }, + }) + + // Open rename + await renameSymbol(page) + await page.waitForSelector('text=Rename') + await page.waitForSelector('input:focus') + + // Type new name + await page.keyboard.type('preview_marketing') + + // Press Cmd+Enter (Meta+Enter) to preview changes + await page.keyboard.press( + process.platform === 'darwin' ? 'Meta+Enter' : 'Control+Enter', + ) + + // Verify preview UI is showing + await expect(page.locator('text=Refactor Preview').first()).toBeVisible() + await expect(page.locator('text=Apply').first()).toBeVisible() + await expect(page.locator('text=Discard').first()).toBeVisible() + + // Verify the preview shows both old and new names + await expect(page.locator('text=current_marketing').first()).toBeVisible() + await expect(page.locator('text=preview_marketing').first()).toBeVisible() + + // Apply the changes + await page.locator('text=Apply').click() + + // Verify the rename was applied + await expect(page.locator('text=WITH preview_marketing AS')).toBeVisible() + }) +}) diff --git a/vscode/extension/tests/render.spec.ts b/vscode/extension/tests/render.spec.ts new file mode 100644 index 0000000000..db660daae1 --- /dev/null +++ b/vscode/extension/tests/render.spec.ts @@ -0,0 +1,152 @@ +import { test, expect } from './fixtures' +import fs from 'fs-extra' +import { + openLineageView, + openServerPage, + runCommand, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import { createPythonInterpreterSettingsSpecifier } from './utils_code_server' + +test('Render works correctly', async ({ page, sharedCodeServer, tempDir }) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await 
page.waitForSelector('text=models') + + // Click on the models folder, excluding external_models + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the customer_revenue_lifetime model + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + await page.waitForSelector('text=grain') + await waitForLoadedSQLMesh(page) + + // Render the model + await runCommand(page, 'Render Model') + + // Check if the model is rendered by check if "`oi`.`order_id` AS `order_id`," is in the window + await expect(page.locator('text="marketing"."customer_id" AS')).toBeVisible() + await expect(page.locator('text=sushi.customers (rendered)')).toBeVisible() +}) + +test('Render works correctly with model without a description', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + await createPythonInterpreterSettingsSpecifier(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder, excluding external_models + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the latest_order model + await page + .getByRole('treeitem', { name: 'latest_order.sql', exact: true }) + .locator('a') + .click() + + await page.waitForSelector('text=custom_full_with_custom_kind') + await waitForLoadedSQLMesh(page) + + // Render the model + await runCommand(page, 'Render Model') + + // Check if the model is rendered correctly + await expect(page.locator('text="orders"."id" AS "id",')).toBeVisible() + await expect(page.locator('text=sushi.latest_order (rendered)')).toBeVisible() +}) + +test('Render works correctly with every rendered model opening a new tab', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + await 
createPythonInterpreterSettingsSpecifier(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + await page + .getByRole('treeitem', { name: 'latest_order.sql', exact: true }) + .locator('a') + .click() + await page.waitForSelector('text=custom_full_with_custom_kind') + await waitForLoadedSQLMesh(page) + + // Render the model + await runCommand(page, 'Render Model') + + // Check if the model is rendered correctly + await expect(page.locator('text=sushi.latest_order (rendered)')).toBeVisible() + + // Open the customers model + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + await page.waitForSelector('text=grain') + + // Render the customers model + await runCommand(page, 'Render Model') + + // Assert both tabs exist + await expect(page.locator('text=sushi.latest_order (rendered)')).toBeVisible() + await expect(page.locator('text=sushi.customers (rendered)')).toBeVisible() +}) + +test('Render shows model picker when no active editor is open', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + await createPythonInterpreterSettingsSpecifier(tempDir) + + // Navigate to code-server instance + await openServerPage(page, tempDir, sharedCodeServer) + + // Load the lineage view to initialize SQLMesh context (like lineage.spec.ts does) + await openLineageView(page) + + await waitForLoadedSQLMesh(page) + + // Run the render command without any active editor + await runCommand(page, 'Render Model') + + // Type to filter for customers model and select it + await page.keyboard.type('customers') + await page.waitForSelector('text=sushi.customers', { timeout: 2_000 }) + await page.locator('text=sushi.customers').click() + + // Verify the rendered model is shown + await 
expect(page.locator('text=sushi.customers (rendered)')).toBeVisible({ + timeout: 2_000, + }) +}) diff --git a/vscode/extension/tests/stop.spec.ts b/vscode/extension/tests/stop.spec.ts new file mode 100644 index 0000000000..64a12b2e46 --- /dev/null +++ b/vscode/extension/tests/stop.spec.ts @@ -0,0 +1,131 @@ +import { + openServerPage, + runCommand, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import { test } from './fixtures' +import fs from 'fs-extra' +import { createPythonInterpreterSettingsSpecifier } from './utils_code_server' + +test('Stop server works', async ({ page, sharedCodeServer, tempDir }) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + await createPythonInterpreterSettingsSpecifier(tempDir) + + // Navigate to code-server instance + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible in the file explorer + await page.waitForSelector('text=models') + + // Click on the models folder, excluding external_models + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the customers.sql model + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + await page.waitForSelector('text=grain') + await waitForLoadedSQLMesh(page) + + // Stop the server + await runCommand(page, 'SQLMesh: Stop Server') + + // Await LSP server stopped message + await page.waitForSelector('text=LSP server stopped') + + // Render the model + await runCommand(page, 'SQLMesh: Render Model') + + // Await error message + await page.waitForSelector( + 'text="Failed to render model: LSP client not ready."', + ) +}) + +test('Stopped server only restarts when explicitly requested', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + await createPythonInterpreterSettingsSpecifier(tempDir) + + // Navigate to code-server instance + await openServerPage(page, tempDir, 
sharedCodeServer) + + // Wait for the models folder to be visible in the file explorer + await page.waitForSelector('text=models') + + // Click on the models folder, excluding external_models + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the customers.sql model + await page + .getByRole('treeitem', { name: 'marketing.sql', exact: true }) + .locator('a') + .click() + await page.waitForSelector('text=grain') + await waitForLoadedSQLMesh(page) + + // Click on sushi.raw_marketing + await page.getByText('sushi.raw_marketing;').click() + + // Open the preview hover + await runCommand(page, 'Show Definition Preview Hover') + + // Assert that the hover is visible with text "Table of marketing status." + await page.waitForSelector('text=Table of marketing status.', { + timeout: 5_000, + state: 'visible', + }) + + // Hit Esc to close the hover + await page.keyboard.press('Escape') + + // Assert that the hover is no longer visible + await page.waitForSelector('text=Table of marketing status.', { + timeout: 5_000, + state: 'hidden', + }) + + // Stop the server + await runCommand(page, 'SQLMesh: Stop Server') + + // Await LSP server stopped message + await page.waitForSelector('text=LSP server stopped') + + // Open the preview hover again + await runCommand(page, 'Show Definition Preview Hover') + + // Assert that the hover is not visible + await page.waitForSelector('text=Table of marketing status.', { + timeout: 5_000, + state: 'hidden', + }) + + // Restart the server explicitly + await runCommand(page, 'SQLMesh: Restart Server') + + // Await LSP server started message + await waitForLoadedSQLMesh(page) + + // Open the preview hover again + await runCommand(page, 'Show Definition Preview Hover') + + // Assert that the hover is visible with text "Table of marketing status." 
+ await page.waitForSelector('text=Table of marketing status.', { + timeout: 5_000, + state: 'visible', + }) +}) diff --git a/vscode/extension/tests/tcloud.spec.ts b/vscode/extension/tests/tcloud.spec.ts new file mode 100644 index 0000000000..1229696e02 --- /dev/null +++ b/vscode/extension/tests/tcloud.spec.ts @@ -0,0 +1,416 @@ +import { expect, test } from './fixtures' +import path from 'path' +import fs from 'fs-extra' +import { + createVirtualEnvironment, + openServerPage, + pipInstall, + REPO_ROOT, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import { setTcloudVersion, setupAuthenticatedState } from './tcloud_utils' +import { + createPythonInterpreterSettingsSpecifier, + startCodeServer, + stopCodeServer, +} from './utils_code_server' + +/** + * Helper function to create and set up a Python virtual environment + */ +async function setupPythonEnvironment(envDir: string): Promise { + // Create virtual environment + const pythonDetails = await createVirtualEnvironment(envDir) + + // Install the mock tcloud package + const mockTcloudPath = path.join(__dirname, 'tcloud') + await pipInstall(pythonDetails, [mockTcloudPath]) + + // Install sqlmesh from the local repository with LSP support + const customMaterializations = path.join( + REPO_ROOT, + 'examples', + 'custom_materializations', + ) + const sqlmeshWithExtras = `${REPO_ROOT}[lsp,bigquery]` + await pipInstall(pythonDetails, [sqlmeshWithExtras, customMaterializations]) + + return pythonDetails.pythonPath +} + +test('not signed in, shows sign in window', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + const pythonEnvDir = path.join(tempDir, '.venv') + + const context = await startCodeServer({ tempDir }) + + try { + // Copy sushi project + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + // Create a tcloud.yaml to mark this as a tcloud project + const tcloudConfig = { + url: 'https://mock.tobikodata.com', + org: 'test-org', + project: 'test-project', + } + await fs.writeFile( + 
path.join(tempDir, 'tcloud.yaml'), + `url: ${tcloudConfig.url}\norg: ${tcloudConfig.org}\nproject: ${tcloudConfig.project}\n`, + ) + + // Set tcloud version to 2.10.0 + await setTcloudVersion(tempDir, '2.10.0') + + // Set up Python environment with mock tcloud and sqlmesh + const pythonPath = await setupPythonEnvironment(pythonEnvDir) + + // Configure VS Code settings to use our Python environment + const settings = { + 'python.defaultInterpreterPath': pythonPath, + 'sqlmesh.environmentPath': pythonEnvDir, + } + await fs.ensureDir(path.join(tempDir, '.vscode')) + await fs.writeJson( + path.join(tempDir, '.vscode', 'settings.json'), + settings, + { spaces: 2 }, + ) + + await openServerPage(page, tempDir, sharedCodeServer) + + // Open a SQL file to trigger SQLMesh activation + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the top_waiters model + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + // Wait for the file to open + await page.waitForLoadState('networkidle') + + await page.waitForSelector( + 'text=Please sign in to Tobiko Cloud to use SQLMesh', + ) + } finally { + await stopCodeServer(context) + } +}) + +test('signed in and not installed shows installation window', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + const pythonEnvDir = path.join(tempDir, '.venv') + + const context = await startCodeServer({ tempDir }) + + try { + // Copy sushi project + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + // Create a tcloud.yaml to mark this as a tcloud project + const tcloudConfig = { + url: 'https://mock.tobikodata.com', + org: 'test-org', + project: 'test-project', + } + await fs.writeFile( + path.join(tempDir, 'tcloud.yaml'), + `url: ${tcloudConfig.url}\norg: ${tcloudConfig.org}\nproject: ${tcloudConfig.project}\n`, 
+ ) + + // Write mock ".tcloud_auth_state.json" file + await setupAuthenticatedState(tempDir) + + // Set tcloud version to 2.10.0 + await setTcloudVersion(tempDir, '2.10.0') + + // Set up Python environment with mock tcloud and sqlmesh + const pythonPath = await setupPythonEnvironment(pythonEnvDir) + + // Configure VS Code settings to use our Python environment + const settings = { + 'python.defaultInterpreterPath': pythonPath, + 'sqlmesh.environmentPath': pythonEnvDir, + } + await fs.ensureDir(path.join(tempDir, '.vscode')) + await fs.writeJson( + path.join(tempDir, '.vscode', 'settings.json'), + settings, + { spaces: 2 }, + ) + + await openServerPage(page, tempDir, sharedCodeServer) + + // Open a SQL file to trigger SQLMesh activation + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the top_waiters model + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + await page.waitForSelector('text=Installing enterprise python package') + await expect( + page.locator('text=Installing enterprise python package'), + ).toHaveCount(2) + + await waitForLoadedSQLMesh(page) + } finally { + await stopCodeServer(context) + } +}) + +test('tcloud sqlmesh_lsp command starts the sqlmesh_lsp in old version when ready', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + const pythonEnvDir = path.join(tempDir, '.venv') + + try { + // Copy sushi project + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + // Create a tcloud.yaml to mark this as a tcloud project + const tcloudConfig = { + url: 'https://mock.tobikodata.com', + org: 'test-org', + project: 'test-project', + } + await fs.writeFile( + path.join(tempDir, 'tcloud.yaml'), + `url: ${tcloudConfig.url}\norg: ${tcloudConfig.org}\nproject: ${tcloudConfig.project}\n`, + ) + + // Write mock 
".tcloud_auth_state.json" file + await setupAuthenticatedState(tempDir) + + // Set tcloud version to 2.10.0 + await setTcloudVersion(tempDir, '2.10.0') + + // Set up Python environment with mock tcloud and sqlmesh + const pythonPath = await setupPythonEnvironment(pythonEnvDir) + + // Mark sqlmesh as installed + const binDir = path.dirname(pythonPath) + const installStateFile = path.join(binDir, '.sqlmesh_installed') + await fs.writeFile(installStateFile, '') + + // Configure VS Code settings to use our Python environment + const settings = { + 'python.defaultInterpreterPath': pythonPath, + 'sqlmesh.environmentPath': pythonEnvDir, + } + await fs.ensureDir(path.join(tempDir, '.vscode')) + await fs.writeJson( + path.join(tempDir, '.vscode', 'settings.json'), + settings, + { spaces: 2 }, + ) + + // Start VS Code + await openServerPage(page, tempDir, sharedCodeServer) + + // Open a SQL file to trigger SQLMesh activation + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the top_waiters model + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + // Verify the context loads successfully + await waitForLoadedSQLMesh(page) + } finally { + // Clean up + await fs.remove(tempDir) + } +}) + +test('tcloud sqlmesh_lsp command starts the sqlmesh_lsp in new version when ready', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + const pythonEnvDir = path.join(tempDir, '.venv') + + try { + // Copy sushi project + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + // Create a tcloud.yaml to mark this as a tcloud project + const tcloudConfig = { + url: 'https://mock.tobikodata.com', + org: 'test-org', + project: 'test-project', + } + await fs.writeFile( + path.join(tempDir, 'tcloud.yaml'), + `url: ${tcloudConfig.url}\norg: 
${tcloudConfig.org}\nproject: ${tcloudConfig.project}\n`, + ) + + // Write mock ".tcloud_auth_state.json" file + await setupAuthenticatedState(tempDir) + + // Set tcloud version to 2.10.0 + await setTcloudVersion(tempDir, '2.10.1') + + // Set up Python environment with mock tcloud and sqlmesh + const pythonPath = await setupPythonEnvironment(pythonEnvDir) + + // Mark sqlmesh as installed + const binDir = path.dirname(pythonPath) + const installStateFile = path.join(binDir, '.sqlmesh_installed') + await fs.writeFile(installStateFile, '') + + // Configure VS Code settings to use our Python environment + const settings = { + 'python.defaultInterpreterPath': pythonPath, + 'sqlmesh.environmentPath': pythonEnvDir, + } + await fs.ensureDir(path.join(tempDir, '.vscode')) + await fs.writeJson( + path.join(tempDir, '.vscode', 'settings.json'), + settings, + { spaces: 2 }, + ) + + await openServerPage(page, tempDir, sharedCodeServer) + + // Open a SQL file to trigger SQLMesh activation + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the top_waiters model + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + // Verify the context loads successfully + await waitForLoadedSQLMesh(page) + } finally { + // Clean up + await fs.remove(tempDir) + } +}) + +// This test is skipped becuase of the way the sign in window is shown is not useable by playwright. It's not solvable +// but the test is still useful when running it manually. 
+test.skip('tcloud not signed in and not installed, shows sign in window and then fact that loaded', async ({ + page, + tempDir, +}) => { + const pythonEnvDir = path.join(tempDir, '.venv') + + // Create a tcloud.yaml to mark this as a tcloud project + const tcloudConfig = { + url: 'https://mock.tobikodata.com', + org: 'test-org', + project: 'test-project', + } + await fs.writeFile( + path.join(tempDir, 'tcloud.yaml'), + `url: ${tcloudConfig.url}\norg: ${tcloudConfig.org}\nproject: ${tcloudConfig.project}\n`, + ) + + // Set up Python environment with mock tcloud and sqlmesh + const pythonPath = await setupPythonEnvironment(pythonEnvDir) + + // Configure VS Code settings to use our Python environment + const settings = { + 'python.defaultInterpreterPath': pythonPath, + 'sqlmesh.environmentPath': pythonEnvDir, + } + await fs.ensureDir(path.join(tempDir, '.vscode')) + await fs.writeJson(path.join(tempDir, '.vscode', 'settings.json'), settings, { + spaces: 2, + }) + + // Set tcloud version to 2.10.0 + await setTcloudVersion(tempDir, '2.10.1') + + // Start VS Code + const context = await startCodeServer({ + tempDir, + }) + await createPythonInterpreterSettingsSpecifier(tempDir) + await page.goto(`http://127.0.0.1:${context.codeServerPort}`) + + try { + // Copy sushi project + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + // Open a SQL file to trigger SQLMesh activation + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the top_waiters model + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + // Verify the sign in window is shown + await page.waitForSelector( + 'text=Please sign in to Tobiko Cloud to use SQLMesh', + ) + + // Click on the sign in button + await page + .getByRole('button', { name: 'Sign in' }) + .filter({ hasText: 'Sign 
in' }) + .click() + await page.waitForSelector('text="Signed in successfully"') + + await page.waitForSelector('text=Installing enterprise python package') + + await waitForLoadedSQLMesh(page) + } finally { + await stopCodeServer(context) + } +}) diff --git a/vscode/extension/tests/tcloud/README.md b/vscode/extension/tests/tcloud/README.md new file mode 100644 index 0000000000..c3c723be37 --- /dev/null +++ b/vscode/extension/tests/tcloud/README.md @@ -0,0 +1,53 @@ +# Mock tcloud CLI for Testing + +This directory contains a mock implementation of the tcloud CLI for testing the VSCode extension. + +## Implemented Commands + +The mock implements only the commands used by the VSCode extension: + +### Authentication Commands +- `tcloud auth vscode status` - Returns authentication status +- `tcloud auth vscode login-url` - Returns mock OAuth login URL +- `tcloud auth vscode start-server ` - Simulates OAuth callback +- `tcloud auth vscode device` - Returns mock device flow info +- `tcloud auth vscode poll_device ` - Simulates device flow success +- `tcloud auth logout` - Clears authentication state + +### SQLMesh Commands +- `tcloud is_sqlmesh_installed` - Checks installation status +- `tcloud install_sqlmesh` - Marks SQLMesh as installed +- `tcloud sqlmesh ` - Echoes sqlmesh commands + +## State Management + +The mock maintains state in two files: +- `.tcloud_auth_state.json` - Authentication state (logged in/out, ID token) +- `.sqlmesh_installed` - SQLMesh installation marker + +## Usage in Tests + +To use this mock in tests: + +1. Ensure the mock is in PATH or reference it directly +2. The mock will simulate successful authentication flows +3. 
State persists between calls for realistic testing + +## Example + +```bash +# Check auth status +./tcloud auth vscode status +# Output: {"is_logged_in": false, "id_token": null} + +# Simulate login +./tcloud auth vscode login-url +# Output: {"url": "https://mock-auth.example.com/auth?client_id=mock&redirect_uri=http://localhost:7890", "verifier_code": "mock_verifier_12345"} + +./tcloud auth vscode start-server mock_verifier_12345 +# Output: Mock server started successfully + +# Check status again +./tcloud auth vscode status +# Output: {"is_logged_in": true, "id_token": {"email": "test@example.com", "name": "Test User", "exp": 1736790123}} +``` \ No newline at end of file diff --git a/vscode/extension/tests/tcloud/mock_tcloud/__init__.py b/vscode/extension/tests/tcloud/mock_tcloud/__init__.py new file mode 100644 index 0000000000..98ad152e0e --- /dev/null +++ b/vscode/extension/tests/tcloud/mock_tcloud/__init__.py @@ -0,0 +1 @@ +# Mock tcloud package \ No newline at end of file diff --git a/vscode/extension/tests/tcloud/mock_tcloud/cli.py b/vscode/extension/tests/tcloud/mock_tcloud/cli.py new file mode 100755 index 0000000000..55de42ca81 --- /dev/null +++ b/vscode/extension/tests/tcloud/mock_tcloud/cli.py @@ -0,0 +1,311 @@ +""" +Mock tcloud CLI for testing VSCode extension. +Implements only the commands used by the extension. 
+""" + +import json +import os +import subprocess +import sys +import time +from pathlib import Path + +import click + +def get_auth_state_file(): + """Get the path to the auth state file in the current working directory""" + return Path.cwd() / ".tcloud_auth_state.json" + + +def get_version_state_file(): + """Get the path to the version state file in the current working directory""" + return Path.cwd() / ".tcloud_version_state.json" + + +def load_auth_state(): + """Load authentication state from file""" + auth_file = get_auth_state_file() + if auth_file.exists(): + with open(auth_file, "r") as f: + return json.load(f) + return {"is_logged_in": False, "id_token": None} + + +def save_auth_state(state): + """Save authentication state to file""" + auth_file = get_auth_state_file() + with open(auth_file, "w") as f: + json.dump(state, f) + + +def load_version_state(): + """Load version state from file""" + version_file = get_version_state_file() + if version_file.exists(): + with open(version_file, "r") as f: + return json.load(f) + # Default to version 2.10.0 if no state file exists + return {"version": "2.10.0"} + + +@click.group(no_args_is_help=True, invoke_without_command=True) +@click.option( + "--project", + type=str, + help="The name of the project.", +) +@click.option( + "--version", + is_flag=True, + help="Show version", +) +@click.pass_context +def cli(ctx: click.Context, project: str, version: bool) -> None: + """Mock Tobiko Cloud CLI""" + if version: + version_state = load_version_state() + print(version_state["version"]) + ctx.exit(0) + + if ctx.invoked_subcommand is None: + click.echo(ctx.get_help()) + ctx.exit(0) + + ctx.ensure_object(dict) + ctx.obj["project"] = project + + +@cli.command("is_sqlmesh_installed", hidden=True) +@click.pass_context +def is_sqlmesh_installed(ctx: click.Context) -> None: + """Check if SQLMesh Enterprise is installed""" + # For testing, we'll track installation state in a file in the current bin directory + # This matches where 
the test expects it to be + bin_dir = Path(sys.executable).parent + install_state_file = bin_dir / ".sqlmesh_installed" + is_installed = install_state_file.exists() + + print( + json.dumps( + { + "is_installed": is_installed, + } + ) + ) + + +@cli.command("install_sqlmesh") +@click.pass_context +def install_sqlmesh(ctx: click.Context) -> None: + """Install the correct version of SQLMesh Enterprise""" + + # For 3 seconds output to stdout + for i in range(3): + print(f"Installing SQLMesh Enterprise logs {i + 1}/3", flush=True) + time.sleep(1) + + # Simulate installation by creating a marker file in the bin directory + bin_dir = Path(sys.executable).parent + install_state_file = bin_dir / ".sqlmesh_installed" + install_state_file.touch() + + print("Mock SQLMesh Enterprise installed successfully") + + +@cli.command("sqlmesh") +@click.argument("args", nargs=-1) +@click.pass_context +def sqlmesh(ctx: click.Context, args) -> None: + """Run SQLMesh Enterprise commands""" + # Pass through to the real sqlmesh command + + # Get the path to sqlmesh in the same environment as this script + bin_dir = os.path.dirname(sys.executable) + sqlmesh_path = os.path.join(bin_dir, "sqlmesh") + + if not os.path.exists(sqlmesh_path): + # Try with .exe extension on Windows + sqlmesh_path = os.path.join(bin_dir, "sqlmesh.exe") + + if not os.path.exists(sqlmesh_path): + # Fall back to using sqlmesh from PATH + sqlmesh_path = "sqlmesh" + + # Execute the real sqlmesh with the provided arguments + result = subprocess.run([sqlmesh_path] + list(args), capture_output=False) + sys.exit(result.returncode) + + +@cli.command("sqlmesh_lsp") +@click.argument("args", nargs=-1) +@click.pass_context +def sqlmesh_lsp(ctx: click.Context, args) -> None: + """Run SQLMesh LSP server""" + # For testing purposes, we'll simulate the LSP server starting + print("Starting SQLMesh LSP server...", flush=True) + + # Get the path to sqlmesh in the same environment as this script + bin_dir = os.path.dirname(sys.executable) 
+ sqlmesh_path = os.path.join(bin_dir, "sqlmesh") + + if not os.path.exists(sqlmesh_path): + # Try with .exe extension on Windows + sqlmesh_path = os.path.join(bin_dir, "sqlmesh.exe") + + if not os.path.exists(sqlmesh_path): + # Fall back to using sqlmesh from PATH + sqlmesh_path = "sqlmesh" + + # Execute the real sqlmesh with lsp command and provided arguments + result = subprocess.run([sqlmesh_path, "lsp"] + list(args), capture_output=False) + sys.exit(result.returncode) + + +@click.group() +def auth() -> None: + """ + Tobiko Cloud Authentication + """ + + +@auth.command() +def logout() -> None: + """Logout of any current session""" + save_auth_state({"is_logged_in": False, "id_token": None}) + print("Logged out successfully") + + +### Methods for VSCode +@auth.group(hidden=True) +def vscode() -> None: + """Commands for VSCode integration""" + pass + + +@vscode.command("login-url") +def login_url() -> None: + """ + Login to Tobiko Cloud. + + This returns a JSON object with the following fields: + - url: The URL to login open + """ + # Return mock OAuth URL and verifier + print( + json.dumps( + { + "url": "https://mock-auth.example.com/auth?client_id=mock&redirect_uri=http://localhost:7890", + "verifier_code": "mock_verifier_12345", + } + ) + ) + + +@vscode.command("start-server") +@click.argument("code_verifier", type=str, required=True) +def start_server(code_verifier: str) -> None: + """ + Start the server to catch the redirect from the browser. 
+ """ + # Simulate successful authentication after a short delay + time.sleep(0.5) + + # Update auth state to logged in + save_auth_state( + { + "is_logged_in": True, + "id_token": { + "iss": "https://mock.tobikodata.com", + "aud": "mock-audience", + "sub": "user-123", + "scope": "openid email profile", + "iat": int(time.time()), + "exp": int(time.time()) + 3600, # Token expires in 1 hour + "email": "test@example.com", + "name": "Test User", + }, + } + ) + + # The real command would start a server, but for testing we just simulate success + print("Mock server started successfully") + + +@vscode.command("status") +def vscode_status() -> None: + """ + Auth status for logged in + """ + state = load_auth_state() + print( + json.dumps( + {"is_logged_in": state["is_logged_in"], "id_token": state["id_token"]} + ) + ) + + +@vscode.command("device") +def vscode_device() -> None: + """ + Initiate device flow for VSCode integration + """ + print( + json.dumps( + { + "device_code": "MOCK-DEVICE-CODE", + "user_code": "ABCD-1234", + "verification_uri": "https://mock-auth.example.com/device", + "verification_uri_complete": "https://mock-auth.example.com/device?user_code=ABCD-1234", + "expires_in": 600, + "interval": 5, + } + ) + ) + + +@vscode.command("poll_device") +@click.argument("device_code", type=str, required=True) +@click.option( + "-i", + "--interval", + type=int, + default=5, + help="The interval between polling attempts in seconds", +) +@click.option( + "-t", + "--timeout", + type=int, + default=300, + help="The timeout for the device flow in seconds", +) +def vscode_poll_device(device_code: str, interval: int, timeout: int) -> None: + """ + Poll the device flow for VSCode integration + """ + # For testing, we'll just succeed immediately + save_auth_state( + { + "is_logged_in": True, + "id_token": { + "iss": "https://mock.tobikodata.com", + "aud": "mock-audience", + "sub": "device-user-123", + "scope": "openid email profile", + "iat": int(time.time()), + "exp": 
int(time.time()) + 3600, + "email": "device@example.com", + "name": "Device User", + }, + } + ) + + print(json.dumps({"success": True})) + + +# Add auth group to main CLI +cli.add_command(auth) + + +if __name__ == "__main__": + cli() diff --git a/vscode/extension/tests/tcloud/pyproject.toml b/vscode/extension/tests/tcloud/pyproject.toml new file mode 100644 index 0000000000..505af5c783 --- /dev/null +++ b/vscode/extension/tests/tcloud/pyproject.toml @@ -0,0 +1,18 @@ +[build-system] +requires = ["setuptools>=61.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "mock-tcloud" +version = "0.1.0" +description = "Mock tcloud CLI for testing VSCode extension" +requires-python = ">=3.8" +dependencies = [ + "click>=8.0", +] + +[project.scripts] +tcloud = "mock_tcloud.cli:cli" + +[tool.setuptools] +packages = ["mock_tcloud"] \ No newline at end of file diff --git a/vscode/extension/tests/tcloud_utils.ts b/vscode/extension/tests/tcloud_utils.ts new file mode 100644 index 0000000000..5334b69ef6 --- /dev/null +++ b/vscode/extension/tests/tcloud_utils.ts @@ -0,0 +1,34 @@ +import path from 'path' +import fs from 'fs-extra' + +/** + * Helper function to set up a pre-authenticated tcloud state + */ +export async function setupAuthenticatedState(tempDir: string): Promise { + const authStateFile = path.join(tempDir, '.tcloud_auth_state.json') + const authState = { + is_logged_in: true, + id_token: { + iss: 'https://mock.tobikodata.com', + aud: 'mock-audience', + sub: 'user-123', + scope: 'openid email profile', + iat: Math.floor(Date.now() / 1000), + exp: Math.floor(Date.now() / 1000) + 3600, // Valid for 1 hour + email: 'test@example.com', + name: 'Test User', + }, + } + await fs.writeJson(authStateFile, authState) +} + +/** + * Helper function to set the tcloud version for testing + */ +export async function setTcloudVersion( + tempDir: string, + version: string, +): Promise { + const versionStateFile = path.join(tempDir, '.tcloud_version_state.json') + 
await fs.writeJson(versionStateFile, { version }) +} diff --git a/vscode/extension/tests/tests.spec.ts b/vscode/extension/tests/tests.spec.ts new file mode 100644 index 0000000000..bea8776447 --- /dev/null +++ b/vscode/extension/tests/tests.spec.ts @@ -0,0 +1,43 @@ +import { test } from './fixtures' +import fs from 'fs-extra' +import { + openServerPage, + runCommand, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from './utils' +import { createPythonInterpreterSettingsSpecifier } from './utils_code_server' + +test('Format project works correctly', async ({ + page, + sharedCodeServer, + tempDir, +}) => { + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + await createPythonInterpreterSettingsSpecifier(tempDir) + await openServerPage(page, tempDir, sharedCodeServer) + + // Wait for the models folder to be visible + await page.waitForSelector('text=models') + + // Click on the models folder, excluding external_models + await page + .getByRole('treeitem', { name: 'models', exact: true }) + .locator('a') + .click() + + // Open the customer_revenue_lifetime model + await page + .getByRole('treeitem', { name: 'customers.sql', exact: true }) + .locator('a') + .click() + + await page.waitForSelector('text=grain') + await waitForLoadedSQLMesh(page) + + // Format the project + await runCommand(page, 'Test: Run All Tests') + + await page.waitForSelector('text=test_order_items') +}) diff --git a/vscode/extension/tests/utils.ts b/vscode/extension/tests/utils.ts new file mode 100644 index 0000000000..effdc3c062 --- /dev/null +++ b/vscode/extension/tests/utils.ts @@ -0,0 +1,236 @@ +import path from 'path' +import { Page } from '@playwright/test' +import { execAsync } from '../src/utilities/exec' +import { CodeServerContext } from './utils_code_server' + +// Where your extension lives on disk +export const EXT_PATH = path.resolve(__dirname, '..') +// Where the sushi project lives which we copy from +export const SUSHI_SOURCE_PATH = path.join( + __dirname, + '..', + '..', + '..', + 
'examples', + 'sushi', +) +export const MULTI_SOURCE_PATH = path.join( + __dirname, + '..', + '..', + '..', + 'examples', + 'multi', +) +export const REPO_ROOT = path.join(__dirname, '..', '..', '..') + +/** + * Click on the Explorer tab in the VS Code activity bar if the Explorer tab is not already active. + * This is necessary because the Explorer tab may not be visible if the user has not opened it yet. + */ +export const clickExplorerTab = async (page: Page): Promise => { + const isExplorerActive = await page.locator("text='Explorer'").isVisible() + if (!isExplorerActive) { + // Wait for the activity bar to be loaded + await page.waitForSelector('.actions-container[role="tablist"]') + + // Click on the Explorer tab using the codicon class + await page.click('.codicon-explorer-view-icon') + + // Wait a bit for the explorer view to activate + await page.locator("text='Explorer'").waitFor({ state: 'visible' }) + } +} + +export interface PythonEnvironment { + pythonPath: string + pipPath: string +} + +/** + * Create a virtual environment in the given directory using uv. + * @param venvDir The directory to create the virtual environment in. + */ +export const createVirtualEnvironment = async ( + venvDir: string, +): Promise => { + // Try to use uv first, fallback to python -m venv + const { exitCode, stderr } = await execAsync(`uv venv "${venvDir}"`) + if (exitCode !== 0) { + throw new Error(`Failed to create venv with uv: ${stderr}`) + } + + // Get paths + const isWindows = process.platform === 'win32' + const binDir = path.join(venvDir, isWindows ? 'Scripts' : 'bin') + const pythonPath = path.join(binDir, isWindows ? 'python.exe' : 'python') + const pipPath = path.join(binDir, isWindows ? 'pip.exe' : 'pip') + + return { + pythonPath, + pipPath, + } +} + +/** + * Install packages in the given virtual environment using uv. + * @param pythonDetails The Python environment to use. + * @param packagePaths The paths to the packages to install (string[]). 
+ */ +export const pipInstall = async ( + pythonDetails: PythonEnvironment, + packagePaths: string[], +): Promise => { + const packages = packagePaths.map(pkg => `-e "${pkg}"`).join(' ') + const execString = `uv pip install --python "${pythonDetails.pythonPath}" ${packages}` + const { stderr, exitCode } = await execAsync(execString) + if (exitCode !== 0) { + throw new Error(`Failed to install package with uv: ${stderr}`) + } +} + +/** + * Open the lineage view in the given window. + */ +export const openLineageView = async (page: Page) => + await runCommand(page, 'Lineage: Focus On View') + +/** + * Open the problems/diagnostics view in the given window. + */ +export const openProblemsView = async (page: Page) => + await runCommand(page, 'View: Focus Problems') + +/** + * Restart the SQLMesh servers + */ +export const restartSqlmeshServers = async (page: Page) => + runCommand(page, 'SQLMesh: Restart Servers') + +/** + * Open the vscode command palette and run the given command. + * @param page The window to run the command in. + * @param command The command to run. + */ +export const runCommand = async ( + page: Page, + command: string, +): Promise => { + const maxRetries = 3 + const retryDelay = 3000 + + for (let attempt = 0; attempt < maxRetries; attempt++) { + try { + await page.keyboard.press( + process.platform === 'darwin' ? 
'Meta+Shift+P' : 'Control+Shift+P', + ) + await page.waitForSelector( + 'input[aria-label="Type the name of a command to run."]', + { timeout: 5000 }, + ) + await page.keyboard.type(command) + const commandElement = await page.waitForSelector( + `a:has-text("${command}")`, + { timeout: 5000 }, + ) + await commandElement.click() + return // Success, exit the retry loop + } catch (error) { + if (attempt === maxRetries - 1) { + throw error // Last attempt failed, throw the error + } + + // Close any open command palette before retrying + await page.keyboard.press('Escape') + await page.waitForTimeout(retryDelay) + } + } +} + +/** + * Go to definition. Assumes the location is clicked. + */ +export const goToDefinition = async (page: Page) => + runCommand(page, 'Go to Definition') + +/** + * Save file + */ +export const saveFile = async (page: Page) => runCommand(page, 'File: Save') + +/** + * Rename Symbol opens the rename symbol dialog in VS Code. + */ +export const renameSymbol = async (page: Page) => + runCommand(page, 'Rename Symbol') + +/** + * Find all references to the symbol under the cursor. + */ +export const findAllReferences = async (page: Page): Promise => + runCommand(page, 'References: Find All References') + +/** + * Go to references. Assumes the location is clicked. + */ +export const goToReferences = async (page: Page): Promise => + runCommand(page, 'Go to References') + +/** + * Open the vscode code file picker and select the given file. + */ +export const openFile = async (page: Page, file: string): Promise => { + const maxRetries = 3 + const retryDelay = 3000 + + const fileName = path.basename(file) + + for (let attempt = 0; attempt < maxRetries; attempt++) { + try { + await page.keyboard.press( + process.platform === 'darwin' ? 
'Meta+P' : 'Control+P', + ) + await page + .getByRole('textbox', { name: 'Search files by name' }) + .waitFor({ state: 'visible', timeout: 5000 }) + await page.keyboard.type(file) + const commandElement = await page.waitForSelector( + `a:has-text("${fileName}")`, + { timeout: 5000 }, + ) + await commandElement.click() + return // Success, exit the retry loop + } catch (error) { + if (attempt === maxRetries - 1) { + throw error // Last attempt failed, throw the error + } + + // Close any open command palette before retrying + await page.keyboard.press('Escape') + await page.waitForTimeout(retryDelay) + } + } +} + +/** + * Wait for SQLMesh context to be loaded. + */ +export const waitForLoadedSQLMesh = (page: Page) => + page.waitForSelector('text=Loaded SQLMesh Context') + +/** + * Go to VSCode page + */ +export const openServerPage = async ( + page: Page, + targetPath: string, + context: CodeServerContext, +) => { + const isWorkspace = targetPath.endsWith('.code-workspace') + const param = isWorkspace ? 
'workspace' : 'folder' + await page.goto( + `http://127.0.0.1:${context.codeServerPort}/?${param}=${targetPath}`, + ) + await page.waitForLoadState('networkidle') + await page.waitForSelector('[role="application"]', { timeout: 10000 }) +} diff --git a/vscode/extension/tests/utils_code_server.ts b/vscode/extension/tests/utils_code_server.ts new file mode 100644 index 0000000000..68bf2ed597 --- /dev/null +++ b/vscode/extension/tests/utils_code_server.ts @@ -0,0 +1,178 @@ +import { spawn, ChildProcess } from 'child_process' +import path from 'path' +import fs from 'fs-extra' +import os from 'os' +import { clearTimeout } from 'node:timers' + +export interface CodeServerContext { + codeServerProcess: ChildProcess + codeServerPort: number + tempDir: string + defaultPythonInterpreter: string +} + +/** + * Get the path to the extensions directory set up by global setup + * @returns The extensions directory path + */ +function getExtensionsDir(): string { + const extensionDir = path.join(__dirname, '..') + const extensionsDir = path.join(extensionDir, '.test_setup', 'extensions') + + if (!fs.existsSync(extensionsDir)) { + throw new Error( + `Extensions directory not found at ${extensionsDir}. 
Make sure global setup has run.`, + ) + } + + return extensionsDir +} + +/** + * Creates a .vscode/settings.json specifier for the Python interpreter + */ +export const createPythonInterpreterSettingsSpecifier = async ( + directory: string, +): Promise => { + const defaultPythonInterpreter = path.join( + __dirname, + '..', + '..', + '..', + '.venv', + 'bin', + 'python', + ) + const vscodeDir = path.join(directory, '.vscode') + await fs.ensureDir(vscodeDir) + const settingsFilePath = path.join(vscodeDir, 'settings.json') + await fs.writeJson(settingsFilePath, { + 'python.defaultInterpreterPath': defaultPythonInterpreter, + }) + return settingsFilePath +} + +/** + * @param tempDir - The temporary directory to use for the code-server instance + * @param placeFileWithPythonInterpreter - Whether to place a vscode/settings.json file in the temp directory that points to the python interpreter of the environmen the test is running in. + * @returns The code-server context + */ +export async function startCodeServer({ + tempDir, +}: { + tempDir: string +}): Promise { + // Get the extensions directory set up by global setup + const extensionsDir = getExtensionsDir() + + // Find an available port + const codeServerPort = Math.floor(Math.random() * 10000) + 50000 + const defaultPythonInterpreter = path.join( + __dirname, + '..', + '..', + '..', + '.venv', + 'bin', + 'python', + ) + + const userDataDir = await fs.mkdtemp( + path.join(os.tmpdir(), 'vscode-test-sushi-user-data-dir-'), + ) + + // Start code-server instance using the shared extensions directory + const codeServerProcess = spawn( + 'pnpm', + [ + 'run', + 'code-server', + '--bind-addr', + `127.0.0.1:${codeServerPort}`, + '--auth', + 'none', + '--disable-telemetry', + '--disable-update-check', + '--disable-workspace-trust', + '--user-data-dir', + userDataDir, + '--extensions-dir', + extensionsDir, + tempDir, + ], + { + stdio: 'pipe', + cwd: path.join(__dirname, '..'), + }, + ) + + // Wait for code-server to be ready + 
await new Promise((resolve, reject) => { + let output = '' + const timeout = setTimeout(() => { + reject(new Error('Code-server failed to start within timeout')) + }, 30000) + + codeServerProcess.stdout?.on('data', (data: Buffer) => { + output += data.toString() + if (output.includes('HTTP server listening on')) { + clearTimeout(timeout) + resolve() + } + }) + + codeServerProcess.stderr?.on('data', (data: Buffer) => { + console.error('Code-server stderr:', data.toString()) + }) + + codeServerProcess.on('error', error => { + clearTimeout(timeout) + reject(error) + }) + + codeServerProcess.on('exit', code => { + if (code !== 0) { + clearTimeout(timeout) + console.error('Code-server exited with code:', code) + } + }) + }) + + return { + codeServerProcess, + codeServerPort, + tempDir, + defaultPythonInterpreter, + } +} + +export async function stopCodeServer( + context: CodeServerContext, +): Promise { + const { codeServerProcess, tempDir } = context + + // Clean up code-server process + codeServerProcess.kill('SIGTERM') + + // Wait for process to exit + await new Promise(resolve => { + codeServerProcess.on('exit', () => { + resolve() + }) + // Force kill after 5 seconds + setTimeout(() => { + if (!codeServerProcess.killed) { + codeServerProcess.kill('SIGKILL') + } + resolve() + }, 5000) + }) + + // Clean up temporary directory + try { + await fs.remove(tempDir) + } catch (error) { + // Ignore errors when removing temp directory + console.warn(`Failed to remove temp directory ${tempDir}:`, error) + } +} diff --git a/vscode/extension/tests/venv_naming.spec.ts b/vscode/extension/tests/venv_naming.spec.ts new file mode 100644 index 0000000000..5cb1730a18 --- /dev/null +++ b/vscode/extension/tests/venv_naming.spec.ts @@ -0,0 +1,38 @@ +import { test } from './fixtures' +import fs from 'fs-extra' +import path from 'path' +import { + createVirtualEnvironment, + openLineageView, + openServerPage, + pipInstall, + REPO_ROOT, + SUSHI_SOURCE_PATH, + waitForLoadedSQLMesh, +} from 
'./utils' + +test('venv being named .env', async ({ page, sharedCodeServer, tempDir }) => { + const pythonEnvDir = path.join(tempDir, '.env') + const pythonDetails = await createVirtualEnvironment(pythonEnvDir) + const custom_materializations = path.join( + REPO_ROOT, + 'examples', + 'custom_materializations', + ) + const sqlmeshWithExtras = `${REPO_ROOT}[bigquery,lsp]` + await pipInstall(pythonDetails, [sqlmeshWithExtras, custom_materializations]) + await fs.copy(SUSHI_SOURCE_PATH, tempDir) + + const settings = { + 'python.defaultInterpreterPath': pythonDetails.pythonPath, + } + await fs.ensureDir(path.join(tempDir, '.vscode')) + await fs.writeJson(path.join(tempDir, '.vscode', 'settings.json'), settings, { + spaces: 2, + }) + + await openServerPage(page, tempDir, sharedCodeServer) + await page.waitForSelector('text=models') + await openLineageView(page) + await waitForLoadedSQLMesh(page) +}) diff --git a/vscode/extension/tsconfig.build.json b/vscode/extension/tsconfig.build.json new file mode 100644 index 0000000000..2ea6c120ab --- /dev/null +++ b/vscode/extension/tsconfig.build.json @@ -0,0 +1,11 @@ +{ + "extends": "./tsconfig.json", + "exclude": [ + "node_modules", + "../node_modules", + "../../node_modules", + "src/**/*.test.ts", + "tests/**/*", + "scripts/**/*" + ] +} diff --git a/vscode/extension/tsconfig.json b/vscode/extension/tsconfig.json new file mode 100644 index 0000000000..a20a2f08d3 --- /dev/null +++ b/vscode/extension/tsconfig.json @@ -0,0 +1,26 @@ +{ + "compilerOptions": { + "module": "Node16", + "target": "ES2022", + "lib": ["ES2022", "DOM"], + "sourceMap": true, + "strict": true /* enable all strict type-checking options */, + /* Additional Checks */ + "noImplicitReturns": true, + "noFallthroughCasesInSwitch": true, + "noUnusedParameters": true, + "noUnusedLocals": true, + "types": ["mocha"], + "paths": { + "@bus/*": ["../bus/src/*"] + } + }, + "include": ["tests/**/*", "src/**/*", "../bus/src/**/*"], + "exclude": [ + "node_modules", + 
"../node_modules", + "../../node_modules", + "tests/**/*.test.ts", + "src/**/*.test.ts" + ] +} diff --git a/vscode/extension/tsconfig.test.json b/vscode/extension/tsconfig.test.json new file mode 100644 index 0000000000..752c0f0952 --- /dev/null +++ b/vscode/extension/tsconfig.test.json @@ -0,0 +1,8 @@ +{ + "extends": "./tsconfig.json", + "compilerOptions": { + "types": ["node"] + }, + "include": ["src/**/*.test.ts", "tests/**/*"], + "exclude": ["node_modules", "../node_modules", "../../node_modules"] +} diff --git a/vscode/extension/vitest.config.ts b/vscode/extension/vitest.config.ts new file mode 100644 index 0000000000..49fbea3be3 --- /dev/null +++ b/vscode/extension/vitest.config.ts @@ -0,0 +1,9 @@ +import { defineConfig } from 'vitest/config' + +export default defineConfig({ + test: { + globals: true, + include: ['src/**/*.test.ts'], + exclude: ['**/node_modules/**', '**/dist/**', '**/out/**'], + }, +}) diff --git a/vscode/openapi.json b/vscode/openapi.json new file mode 100644 index 0000000000..32a7445e32 --- /dev/null +++ b/vscode/openapi.json @@ -0,0 +1,2240 @@ +{ + "openapi": "3.1.0", + "info": { "title": "FastAPI", "version": "0.1.0" }, + "paths": { + "/api/commands/apply": { + "post": { + "summary": "Initiate Apply", + "description": "Apply a plan", + "operationId": "initiate_apply_api_commands_apply_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Body_initiate_apply_api_commands_apply_post" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanApplyStageTracker" }, + { "type": "null" } + ], + "title": "Response Initiate Apply Api Commands Apply Post" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + 
"/api/commands/evaluate": { + "post": { + "summary": "Evaluate", + "description": "Evaluate a model with a default limit of 1000", + "operationId": "evaluate_api_commands_evaluate_post", + "requestBody": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/EvaluateInput" } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { "application/json": { "schema": {} } } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/commands/fetchdf": { + "post": { + "summary": "Fetchdf", + "description": "Fetches a dataframe given a sql string", + "operationId": "fetchdf_api_commands_fetchdf_post", + "requestBody": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/FetchdfInput" } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { "application/json": { "schema": {} } } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/commands/render": { + "post": { + "summary": "Render", + "description": "Renders a model's query, optionally expanding referenced models", + "operationId": "render_api_commands_render_post", + "requestBody": { + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/RenderInput" } + } + }, + "required": true + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/Query" } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + 
"/api/commands/test": { + "get": { + "summary": "Test", + "description": "Run one or all model tests", + "operationId": "test_api_commands_test_get", + "parameters": [ + { + "name": "test", + "in": "query", + "required": false, + "schema": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Test" + } + }, + { + "name": "verbosity", + "in": "query", + "required": false, + "schema": { "$ref": "#/components/schemas/Verbosity", "default": 0 } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/TestResult" } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/files": { + "get": { + "summary": "Get Files", + "description": "Get all project files.", + "operationId": "get_files_api_files_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/Directory" } + } + } + } + } + } + }, + "/api/files/{path}": { + "get": { + "summary": "Get File", + "description": "Get a file, including its contents.", + "operationId": "get_file_api_files__path__get", + "parameters": [ + { + "name": "path", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Path" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/File" } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + }, + "post": { + "summary": "Write File", + "description": "Create, update, or rename a file.", + "operationId": "write_file_api_files__path__post", + "parameters": [ + { + "name": "path", + "in": 
"path", + "required": true, + "schema": { "type": "string", "title": "Path" } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Body_write_file_api_files__path__post" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "anyOf": [ + { "$ref": "#/components/schemas/File" }, + { "type": "null" } + ], + "title": "Response Write File Api Files Path Post" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + }, + "delete": { + "summary": "Delete File", + "description": "Delete a file.", + "operationId": "delete_file_api_files__path__delete", + "parameters": [ + { + "name": "path", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Path" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "null", + "title": "Response Delete File Api Files Path Delete" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/directories/{path}": { + "post": { + "summary": "Write Directory", + "description": "Create or rename a directory.", + "operationId": "write_directory_api_directories__path__post", + "parameters": [ + { + "name": "path", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Path" } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Body_write_directory_api_directories__path__post" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { "$ref": 
"#/components/schemas/Directory" } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + }, + "delete": { + "summary": "Delete Directory", + "description": "Delete a directory.", + "operationId": "delete_directory_api_directories__path__delete", + "parameters": [ + { + "name": "path", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Path" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { "application/json": { "schema": {} } } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/plan": { + "post": { + "summary": "Initiate Plan", + "description": "Get a plan for an environment.", + "operationId": "initiate_plan_api_plan_post", + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Body_initiate_plan_api_plan_post" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanOverviewStageTracker" }, + { "type": "null" } + ], + "title": "Response Initiate Plan Api Plan Post" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/plan/cancel": { + "post": { + "summary": "Cancel Plan", + "description": "Cancel a plan application", + "operationId": "cancel_plan_api_plan_cancel_post", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanCancelStageTracker" }, + { "type": "null" } + ], + "title": 
"Response Cancel Plan Api Plan Cancel Post" + } + } + } + } + } + } + }, + "/api/environments": { + "get": { + "summary": "Get Environments", + "description": "Get the environments", + "operationId": "get_environments_api_environments_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/Environments" } + } + } + } + } + } + }, + "/api/environments/{environment}": { + "delete": { + "summary": "Delete Environment", + "description": "Invalidate and delete an environment", + "operationId": "delete_environment_api_environments__environment__delete", + "parameters": [ + { + "name": "environment", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Environment" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { "application/json": { "schema": {} } } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/events": { + "get": { + "summary": "Events", + "description": "SQLMesh console server sent events", + "operationId": "events_api_events_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { "application/json": { "schema": {} } } + } + } + } + }, + "/api/lineage/{model_name}/{column_name}": { + "get": { + "summary": "Column Lineage", + "description": "Get a column's lineage", + "operationId": "column_lineage_api_lineage__model_name___column_name__get", + "parameters": [ + { + "name": "model_name", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Model Name" } + }, + { + "name": "column_name", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Column Name" } + }, + { + "name": "models_only", + "in": "query", + "required": false, + "schema": { + "type": "boolean", + "default": 
false, + "title": "Models Only" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "object", + "additionalProperties": { + "type": "object", + "additionalProperties": { + "$ref": "#/components/schemas/LineageColumn" + } + }, + "title": "Response Column Lineage Api Lineage Model Name Column Name Get" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/lineage/{model_name}": { + "get": { + "summary": "Model Lineage", + "description": "Get a model's lineage", + "operationId": "model_lineage_api_lineage__model_name__get", + "parameters": [ + { + "name": "model_name", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Model Name" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "object", + "additionalProperties": { + "type": "array", + "uniqueItems": true, + "items": { "type": "string" } + }, + "title": "Response Model Lineage Api Lineage Model Name Get" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/models": { + "get": { + "summary": "Get Models", + "description": "Get a list of models", + "operationId": "get_models_api_models_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "anyOf": [ + { + "items": { "$ref": "#/components/schemas/Model" }, + "type": "array" + }, + { "$ref": "#/components/schemas/ApiExceptionPayload" } + ], + "title": "Response Get Models Api Models Get" + } + } + } + } + } + } + }, + "/api/models/{name}": { + "get": { + "summary": "Get Model", + "description": 
"Get a single model", + "operationId": "get_model_api_models__name__get", + "parameters": [ + { + "name": "name", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Name" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/Model" } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/api/meta": { + "get": { + "summary": "Get Api Meta", + "description": "Get the metadata", + "operationId": "get_api_meta_api_meta_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/Meta" } + } + } + } + } + } + }, + "/api/modules": { + "get": { + "summary": "Get Api Modules", + "description": "Get the modules", + "operationId": "get_api_modules_api_modules_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "items": { "$ref": "#/components/schemas/Modules" }, + "type": "array", + "title": "Response Get Api Modules Api Modules Get" + } + } + } + } + } + } + }, + "/api/table_diff": { + "get": { + "summary": "Get Table Diff", + "description": "Calculate differences between tables, taking into account schema and row level differences.", + "operationId": "get_table_diff_api_table_diff_get", + "parameters": [ + { + "name": "source", + "in": "query", + "required": true, + "schema": { "type": "string", "title": "Source" } + }, + { + "name": "target", + "in": "query", + "required": true, + "schema": { "type": "string", "title": "Target" } + }, + { + "name": "on", + "in": "query", + "required": false, + "schema": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "On" + } + }, + { + "name": "model_or_snapshot", + "in": 
"query", + "required": false, + "schema": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Model Or Snapshot" + } + }, + { + "name": "where", + "in": "query", + "required": false, + "schema": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Where" + } + }, + { + "name": "temp_schema", + "in": "query", + "required": false, + "schema": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Temp Schema" + } + }, + { + "name": "limit", + "in": "query", + "required": false, + "schema": { "type": "integer", "default": 20, "title": "Limit" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "anyOf": [ + { "$ref": "#/components/schemas/TableDiff" }, + { "type": "null" } + ], + "title": "Response Get Table Diff Api Table Diff Get" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + }, + "/health": { + "get": { + "summary": "Health", + "operationId": "health_health_get", + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "string", + "title": "Response Health Health Get" + } + } + } + } + } + } + }, + "/{full_path}": { + "get": { + "summary": "Index", + "operationId": "index__full_path__get", + "parameters": [ + { + "name": "full_path", + "in": "path", + "required": true, + "schema": { "type": "string", "title": "Full Path" } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { "application/json": { "schema": {} } } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { "$ref": "#/components/schemas/HTTPValidationError" } + } + } + } + } + } + } + }, + "components": { + "schemas": { + "ApiExceptionPayload": { + "properties": { + 
"timestamp": { "type": "integer", "title": "Timestamp" }, + "message": { "type": "string", "title": "Message" }, + "origin": { "type": "string", "title": "Origin" }, + "status": { + "anyOf": [{ "type": "integer" }, { "type": "null" }], + "title": "Status" + }, + "trigger": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Trigger" + }, + "type": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Type" + }, + "description": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Description" + }, + "traceback": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Traceback" + }, + "stack": { + "anyOf": [ + { "items": { "type": "string" }, "type": "array" }, + { "type": "null" } + ], + "title": "Stack" + } + }, + "additionalProperties": false, + "type": "object", + "required": ["timestamp", "message", "origin"], + "title": "ApiExceptionPayload" + }, + "BackfillDetails": { + "properties": { + "name": { "type": "string", "title": "Name" }, + "view_name": { "type": "string", "title": "View Name" }, + "node_type": { + "$ref": "#/components/schemas/NodeType", + "default": "model" + }, + "parents": { + "items": { "type": "string" }, + "type": "array", + "uniqueItems": true, + "title": "Parents", + "default": [] + }, + "interval": { + "items": { "type": "string" }, + "type": "array", + "title": "Interval" + }, + "batches": { "type": "integer", "title": "Batches" } + }, + "additionalProperties": false, + "type": "object", + "required": ["name", "view_name", "interval", "batches"], + "title": "BackfillDetails" + }, + "BackfillTask": { + "properties": { + "name": { "type": "string", "title": "Name" }, + "view_name": { "type": "string", "title": "View Name" }, + "node_type": { + "$ref": "#/components/schemas/NodeType", + "default": "model" + }, + "parents": { + "items": { "type": "string" }, + "type": "array", + "uniqueItems": true, + "title": "Parents", + "default": [] + }, + "completed": { "type": "integer", 
"title": "Completed" }, + "total": { "type": "integer", "title": "Total" }, + "start": { "type": "integer", "title": "Start" }, + "end": { + "anyOf": [{ "type": "integer" }, { "type": "null" }], + "title": "End" + }, + "interval": { + "anyOf": [ + { "items": { "type": "string" }, "type": "array" }, + { "type": "null" } + ], + "title": "Interval" + } + }, + "additionalProperties": false, + "type": "object", + "required": ["name", "view_name", "completed", "total", "start"], + "title": "BackfillTask" + }, + "Body_initiate_apply_api_commands_apply_post": { + "properties": { + "environment": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Environment" + }, + "plan_dates": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanDates" }, + { "type": "null" } + ] + }, + "plan_options": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanOptions" }, + { "type": "null" } + ] + }, + "categories": { + "anyOf": [ + { + "additionalProperties": { + "$ref": "#/components/schemas/SnapshotChangeCategory" + }, + "type": "object" + }, + { "type": "null" } + ], + "title": "Categories" + } + }, + "type": "object", + "title": "Body_initiate_apply_api_commands_apply_post" + }, + "Body_initiate_plan_api_plan_post": { + "properties": { + "environment": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Environment" + }, + "plan_dates": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanDates" }, + { "type": "null" } + ] + }, + "plan_options": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanOptions" }, + { "type": "null" } + ] + }, + "categories": { + "anyOf": [ + { + "additionalProperties": { + "$ref": "#/components/schemas/SnapshotChangeCategory" + }, + "type": "object" + }, + { "type": "null" } + ], + "title": "Categories" + } + }, + "type": "object", + "title": "Body_initiate_plan_api_plan_post" + }, + "Body_write_directory_api_directories__path__post": { + "properties": { + "new_path": { + "anyOf": [{ "type": "string" }, { "type": "null" }], 
+ "title": "New Path" + } + }, + "type": "object", + "title": "Body_write_directory_api_directories__path__post" + }, + "Body_write_file_api_files__path__post": { + "properties": { + "content": { "type": "string", "title": "Content", "default": "" }, + "new_path": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "New Path" + } + }, + "type": "object", + "title": "Body_write_file_api_files__path__post" + }, + "ChangeDirect": { + "properties": { + "name": { "type": "string", "title": "Name" }, + "view_name": { "type": "string", "title": "View Name" }, + "node_type": { + "$ref": "#/components/schemas/NodeType", + "default": "model" + }, + "parents": { + "items": { "type": "string" }, + "type": "array", + "uniqueItems": true, + "title": "Parents", + "default": [] + }, + "diff": { "type": "string", "title": "Diff" }, + "indirect": { + "items": { "$ref": "#/components/schemas/ChangeDisplay" }, + "type": "array", + "title": "Indirect", + "default": [] + }, + "direct": { + "items": { "$ref": "#/components/schemas/ChangeDisplay" }, + "type": "array", + "title": "Direct", + "default": [] + }, + "change_category": { + "anyOf": [ + { "$ref": "#/components/schemas/SnapshotChangeCategory" }, + { "type": "null" } + ] + } + }, + "additionalProperties": false, + "type": "object", + "required": ["name", "view_name", "diff"], + "title": "ChangeDirect" + }, + "ChangeDisplay": { + "properties": { + "name": { "type": "string", "title": "Name" }, + "view_name": { "type": "string", "title": "View Name" }, + "node_type": { + "$ref": "#/components/schemas/NodeType", + "default": "model" + }, + "parents": { + "items": { "type": "string" }, + "type": "array", + "uniqueItems": true, + "title": "Parents", + "default": [] + } + }, + "additionalProperties": false, + "type": "object", + "required": ["name", "view_name"], + "title": "ChangeDisplay" + }, + "ChangeIndirect": { + "properties": { + "name": { "type": "string", "title": "Name" }, + "view_name": { "type": "string", 
"title": "View Name" }, + "node_type": { + "$ref": "#/components/schemas/NodeType", + "default": "model" + }, + "parents": { + "items": { "type": "string" }, + "type": "array", + "uniqueItems": true, + "title": "Parents", + "default": [] + } + }, + "additionalProperties": false, + "type": "object", + "required": ["name", "view_name"], + "title": "ChangeIndirect" + }, + "Column": { + "properties": { + "name": { "type": "string", "title": "Name" }, + "type": { "type": "string", "title": "Type" }, + "description": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Description" + } + }, + "additionalProperties": false, + "type": "object", + "required": ["name", "type"], + "title": "Column" + }, + "Directory": { + "properties": { + "name": { "type": "string", "title": "Name" }, + "path": { "type": "string", "title": "Path" }, + "directories": { + "items": { "$ref": "#/components/schemas/Directory" }, + "type": "array", + "title": "Directories", + "default": [] + }, + "files": { + "items": { "$ref": "#/components/schemas/File" }, + "type": "array", + "title": "Files", + "default": [] + } + }, + "additionalProperties": false, + "type": "object", + "required": ["name", "path"], + "title": "Directory" + }, + "Environment": { + "properties": { + "name": { "type": "string", "title": "Name", "default": "prod" }, + "start_at": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" } + ], + "title": "Start At" + }, + "end_at": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" }, + { "type": "null" } + ], + "title": "End At" + }, + "plan_id": { "type": "string", "title": "Plan Id" }, + "previous_plan_id": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Previous Plan Id" + }, + "expiration_ts": { + 
"anyOf": [{ "type": "integer" }, { "type": "null" }], + "title": "Expiration Ts" + }, + "finalized_ts": { + "anyOf": [{ "type": "integer" }, { "type": "null" }], + "title": "Finalized Ts" + }, + "suffix_target": { + "$ref": "#/components/schemas/EnvironmentSuffixTarget", + "default": "schema" + }, + "catalog_name_override": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Catalog Name Override" + }, + "normalize_name": { + "type": "boolean", + "title": "Normalize Name", + "default": true + }, + "gateway_managed": { + "type": "boolean", + "title": "Gateway Managed", + "default": false + }, + "snapshots": { "items": {}, "type": "array", "title": "Snapshots" }, + "promoted_snapshot_ids": { + "anyOf": [{ "items": {}, "type": "array" }, { "type": "null" }], + "title": "Promoted Snapshot Ids" + }, + "previous_finalized_snapshots": { + "anyOf": [{ "items": {}, "type": "array" }, { "type": "null" }], + "title": "Previous Finalized Snapshots" + }, + "requirements": { + "additionalProperties": { "type": "string" }, + "type": "object", + "title": "Requirements", + "default": {} + } + }, + "additionalProperties": false, + "type": "object", + "required": ["start_at", "plan_id", "snapshots"], + "title": "Environment", + "description": "Represents an isolated environment.\n\nEnvironments are isolated workspaces that hold pointers to physical tables.\n\nArgs:\n snapshots: The snapshots that are part of this environment.\n promoted_snapshot_ids: The IDs of the snapshots that are promoted in this environment\n (i.e. for which the views are created). If not specified, all snapshots are promoted.\n previous_finalized_snapshots: Snapshots that were part of this environment last time it was finalized.\n requirements: A mapping of library versions for all the snapshots in this environment." 
+ }, + "EnvironmentSuffixTarget": { + "type": "string", + "enum": ["schema", "table"], + "title": "EnvironmentSuffixTarget" + }, + "Environments": { + "properties": { + "environments": { + "additionalProperties": { + "$ref": "#/components/schemas/Environment" + }, + "type": "object", + "title": "Environments", + "default": {} + }, + "pinned_environments": { + "items": { "type": "string" }, + "type": "array", + "uniqueItems": true, + "title": "Pinned Environments", + "default": [] + }, + "default_target_environment": { + "type": "string", + "title": "Default Target Environment", + "default": "" + } + }, + "additionalProperties": false, + "type": "object", + "title": "Environments" + }, + "EvaluateInput": { + "properties": { + "model": { "type": "string", "title": "Model" }, + "start": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" } + ], + "title": "Start" + }, + "end": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" } + ], + "title": "End" + }, + "execution_time": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" } + ], + "title": "Execution Time" + }, + "limit": { "type": "integer", "title": "Limit", "default": 1000 } + }, + "additionalProperties": false, + "type": "object", + "required": ["model", "start", "end", "execution_time"], + "title": "EvaluateInput" + }, + "FetchdfInput": { + "properties": { + "sql": { "type": "string", "title": "Sql" }, + "limit": { "type": "integer", "title": "Limit", "default": 1000 } + }, + "additionalProperties": false, + "type": "object", + "required": ["sql"], + "title": "FetchdfInput" + }, + "File": { + "properties": { + "name": { "type": "string", "title": 
"Name" }, + "path": { "type": "string", "title": "Path" }, + "extension": { + "type": "string", + "title": "Extension", + "default": "" + }, + "content": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Content" + } + }, + "additionalProperties": false, + "type": "object", + "required": ["name", "path"], + "title": "File" + }, + "HTTPValidationError": { + "properties": { + "detail": { + "items": { "$ref": "#/components/schemas/ValidationError" }, + "type": "array", + "title": "Detail" + } + }, + "type": "object", + "title": "HTTPValidationError" + }, + "IntervalUnit": { + "type": "string", + "enum": [ + "year", + "month", + "day", + "hour", + "half_hour", + "quarter_hour", + "five_minute" + ], + "title": "IntervalUnit", + "description": "IntervalUnit is the inferred granularity of an incremental node.\n\nIntervalUnit can be one of 5 types, YEAR, MONTH, DAY, HOUR, MINUTE. The unit is inferred\nbased on the cron schedule of a node. The minimum time delta between a sample set of dates\nis used to determine which unit a node's schedule is." 
+ }, + "LineageColumn": { + "properties": { + "source": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Source" + }, + "expression": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Expression" + }, + "models": { + "additionalProperties": { + "items": { "type": "string" }, + "type": "array", + "uniqueItems": true + }, + "type": "object", + "title": "Models" + } + }, + "additionalProperties": false, + "type": "object", + "required": ["models"], + "title": "LineageColumn" + }, + "Meta": { + "properties": { + "version": { "type": "string", "title": "Version" }, + "has_running_task": { + "type": "boolean", + "title": "Has Running Task", + "default": false + } + }, + "additionalProperties": false, + "type": "object", + "required": ["version"], + "title": "Meta" + }, + "Model": { + "properties": { + "name": { "type": "string", "title": "Name" }, + "fqn": { "type": "string", "title": "Fqn" }, + "path": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Path" + }, + "full_path": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Full Path" + }, + "dialect": { "type": "string", "title": "Dialect" }, + "type": { "$ref": "#/components/schemas/ModelType" }, + "columns": { + "items": { "$ref": "#/components/schemas/Column" }, + "type": "array", + "title": "Columns" + }, + "description": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Description" + }, + "details": { + "anyOf": [ + { "$ref": "#/components/schemas/ModelDetails" }, + { "type": "null" } + ] + }, + "sql": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Sql" + }, + "definition": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Definition" + }, + "default_catalog": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Default Catalog" + }, + "hash": { "type": "string", "title": "Hash" } + }, + "additionalProperties": false, + "type": "object", + "required": ["name", 
"fqn", "dialect", "type", "columns", "hash"], + "title": "Model" + }, + "ModelDetails": { + "properties": { + "owner": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Owner" + }, + "kind": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Kind" + }, + "batch_size": { + "anyOf": [{ "type": "integer" }, { "type": "null" }], + "title": "Batch Size" + }, + "cron": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Cron" + }, + "stamp": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" }, + { "type": "null" } + ], + "title": "Stamp" + }, + "start": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" }, + { "type": "null" } + ], + "title": "Start" + }, + "retention": { + "anyOf": [{ "type": "integer" }, { "type": "null" }], + "title": "Retention" + }, + "table_format": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Table Format" + }, + "storage_format": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Storage Format" + }, + "time_column": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Time Column" + }, + "tags": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Tags" + }, + "references": { + "items": { "$ref": "#/components/schemas/Reference" }, + "type": "array", + "title": "References", + "default": [] + }, + "partitioned_by": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Partitioned By" + }, + "clustered_by": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Clustered By" + }, + "lookback": { + "anyOf": [{ "type": "integer" }, { "type": "null" }], + "title": "Lookback" + }, + "cron_prev": { + "anyOf": [ + { "type": "string", "format": "date" }, + 
{ "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" }, + { "type": "null" } + ], + "title": "Cron Prev" + }, + "cron_next": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" }, + { "type": "null" } + ], + "title": "Cron Next" + }, + "interval_unit": { + "anyOf": [ + { "$ref": "#/components/schemas/IntervalUnit" }, + { "type": "null" } + ] + }, + "annotated": { + "anyOf": [{ "type": "boolean" }, { "type": "null" }], + "title": "Annotated" + } + }, + "additionalProperties": false, + "type": "object", + "title": "ModelDetails" + }, + "ModelType": { + "type": "string", + "enum": ["python", "sql", "seed", "external", "source"], + "title": "ModelType" + }, + "ModelsDiff": { + "properties": { + "direct": { + "items": { "$ref": "#/components/schemas/ChangeDirect" }, + "type": "array", + "title": "Direct", + "default": [] + }, + "indirect": { + "items": { "$ref": "#/components/schemas/ChangeIndirect" }, + "type": "array", + "title": "Indirect", + "default": [] + }, + "metadata": { + "items": { "$ref": "#/components/schemas/ChangeDisplay" }, + "type": "array", + "title": "Metadata", + "default": [] + } + }, + "additionalProperties": false, + "type": "object", + "title": "ModelsDiff" + }, + "Modules": { + "type": "string", + "enum": [ + "editor", + "files", + "data-catalog", + "plans", + "tests", + "audits", + "errors", + "data", + "lineage" + ], + "title": "Modules" + }, + "NodeType": { + "type": "string", + "enum": ["model", "audit"], + "title": "NodeType" + }, + "PlanApplyStageTracker": { + "properties": { + "start": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" }, + { "type": "null" } + ], + "title": "Start" + }, + "end": { + "anyOf": [ + { "type": "string", 
"format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" }, + { "type": "null" } + ], + "title": "End" + }, + "meta": { "$ref": "#/components/schemas/TrackableMeta" }, + "environment": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Environment" + }, + "plan_options": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanOptions" }, + { "type": "null" } + ] + }, + "creation": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanStageCreation" }, + { "type": "null" } + ] + }, + "restate": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanStageRestate" }, + { "type": "null" } + ] + }, + "backfill": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanStageBackfill" }, + { "type": "null" } + ] + }, + "promote": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanStagePromote" }, + { "type": "null" } + ] + } + }, + "additionalProperties": false, + "type": "object", + "title": "PlanApplyStageTracker" + }, + "PlanCancelStageTracker": { + "properties": { + "start": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" }, + { "type": "null" } + ], + "title": "Start" + }, + "end": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" }, + { "type": "null" } + ], + "title": "End" + }, + "meta": { "$ref": "#/components/schemas/TrackableMeta" }, + "environment": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Environment" + }, + "plan_options": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanOptions" }, + { "type": "null" } + ] + }, + "cancel": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanStageCancel" }, + { "type": "null" } + ] + } + }, + "additionalProperties": false, + "type": "object", + "title": 
"PlanCancelStageTracker" + }, + "PlanDates": { + "properties": { + "start": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" }, + { "type": "null" } + ], + "title": "Start" + }, + "end": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" }, + { "type": "null" } + ], + "title": "End" + } + }, + "additionalProperties": false, + "type": "object", + "title": "PlanDates" + }, + "PlanOptions": { + "properties": { + "skip_tests": { + "type": "boolean", + "title": "Skip Tests", + "default": false + }, + "skip_backfill": { + "type": "boolean", + "title": "Skip Backfill", + "default": false + }, + "no_gaps": { + "type": "boolean", + "title": "No Gaps", + "default": false + }, + "forward_only": { + "type": "boolean", + "title": "Forward Only", + "default": false + }, + "no_auto_categorization": { + "type": "boolean", + "title": "No Auto Categorization", + "default": false + }, + "include_unmodified": { + "type": "boolean", + "title": "Include Unmodified", + "default": false + }, + "create_from": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Create From" + }, + "restate_models": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Restate Models" + }, + "auto_apply": { + "type": "boolean", + "title": "Auto Apply", + "default": false + } + }, + "additionalProperties": false, + "type": "object", + "title": "PlanOptions" + }, + "PlanOverviewStageTracker": { + "properties": { + "start": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" }, + { "type": "null" } + ], + "title": "Start" + }, + "end": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", 
"format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" }, + { "type": "null" } + ], + "title": "End" + }, + "meta": { "$ref": "#/components/schemas/TrackableMeta" }, + "environment": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Environment" + }, + "plan_options": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanOptions" }, + { "type": "null" } + ] + }, + "validation": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanStageValidation" }, + { "type": "null" } + ] + }, + "changes": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanStageChanges" }, + { "type": "null" } + ] + }, + "backfills": { + "anyOf": [ + { "$ref": "#/components/schemas/PlanStageBackfills" }, + { "type": "null" } + ] + } + }, + "additionalProperties": false, + "type": "object", + "title": "PlanOverviewStageTracker" + }, + "PlanStageBackfill": { + "properties": { + "meta": { "$ref": "#/components/schemas/TrackableMeta" }, + "queue": { + "items": { "type": "string" }, + "type": "array", + "uniqueItems": true, + "title": "Queue", + "default": [] + }, + "tasks": { + "additionalProperties": { + "$ref": "#/components/schemas/BackfillTask" + }, + "type": "object", + "title": "Tasks", + "default": {} + } + }, + "additionalProperties": false, + "type": "object", + "title": "PlanStageBackfill" + }, + "PlanStageBackfills": { + "properties": { + "meta": { "$ref": "#/components/schemas/TrackableMeta" }, + "models": { + "anyOf": [ + { + "items": { "$ref": "#/components/schemas/BackfillDetails" }, + "type": "array" + }, + { "type": "null" } + ], + "title": "Models" + } + }, + "additionalProperties": false, + "type": "object", + "title": "PlanStageBackfills" + }, + "PlanStageCancel": { + "properties": { + "meta": { "$ref": "#/components/schemas/TrackableMeta" } + }, + "additionalProperties": false, + "type": "object", + "title": "PlanStageCancel" + }, + "PlanStageChanges": { + "properties": { + "added": { + "anyOf": [ + { + "items": { 
"$ref": "#/components/schemas/ChangeDisplay" }, + "type": "array" + }, + { "type": "null" } + ], + "title": "Added" + }, + "removed": { + "anyOf": [ + { + "items": { "$ref": "#/components/schemas/ChangeDisplay" }, + "type": "array" + }, + { "type": "null" } + ], + "title": "Removed" + }, + "modified": { + "anyOf": [ + { "$ref": "#/components/schemas/ModelsDiff" }, + { "type": "null" } + ] + }, + "meta": { "$ref": "#/components/schemas/TrackableMeta" } + }, + "additionalProperties": false, + "type": "object", + "title": "PlanStageChanges" + }, + "PlanStageCreation": { + "properties": { + "meta": { "$ref": "#/components/schemas/TrackableMeta" }, + "total_tasks": { "type": "integer", "title": "Total Tasks" }, + "num_tasks": { "type": "integer", "title": "Num Tasks" } + }, + "additionalProperties": false, + "type": "object", + "required": ["total_tasks", "num_tasks"], + "title": "PlanStageCreation" + }, + "PlanStagePromote": { + "properties": { + "meta": { "$ref": "#/components/schemas/TrackableMeta" }, + "total_tasks": { "type": "integer", "title": "Total Tasks" }, + "num_tasks": { "type": "integer", "title": "Num Tasks" }, + "target_environment": { + "type": "string", + "title": "Target Environment" + } + }, + "additionalProperties": false, + "type": "object", + "required": ["total_tasks", "num_tasks", "target_environment"], + "title": "PlanStagePromote" + }, + "PlanStageRestate": { + "properties": { + "meta": { "$ref": "#/components/schemas/TrackableMeta" } + }, + "additionalProperties": false, + "type": "object", + "title": "PlanStageRestate" + }, + "PlanStageValidation": { + "properties": { + "meta": { "$ref": "#/components/schemas/TrackableMeta" } + }, + "additionalProperties": false, + "type": "object", + "title": "PlanStageValidation" + }, + "Query": { + "properties": { "sql": { "type": "string", "title": "Sql" } }, + "additionalProperties": false, + "type": "object", + "required": ["sql"], + "title": "Query" + }, + "Reference": { + "properties": { + "name": { 
"type": "string", "title": "Name" }, + "expression": { "type": "string", "title": "Expression" }, + "unique": { "type": "boolean", "title": "Unique" } + }, + "additionalProperties": false, + "type": "object", + "required": ["name", "expression", "unique"], + "title": "Reference" + }, + "RenderInput": { + "properties": { + "model": { "type": "string", "title": "Model" }, + "start": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" }, + { "type": "null" } + ], + "title": "Start" + }, + "end": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" }, + { "type": "null" } + ], + "title": "End" + }, + "execution_time": { + "anyOf": [ + { "type": "string", "format": "date" }, + { "type": "string", "format": "date-time" }, + { "type": "string" }, + { "type": "integer" }, + { "type": "number" }, + { "type": "null" } + ], + "title": "Execution Time" + }, + "expand": { + "anyOf": [ + { "type": "boolean" }, + { "items": { "type": "string" }, "type": "array" } + ], + "title": "Expand", + "default": false + }, + "pretty": { "type": "boolean", "title": "Pretty", "default": true }, + "dialect": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Dialect" + } + }, + "additionalProperties": false, + "type": "object", + "required": ["model"], + "title": "RenderInput" + }, + "RowDiff": { + "properties": { + "source": { "type": "string", "title": "Source" }, + "target": { "type": "string", "title": "Target" }, + "stats": { + "additionalProperties": { "type": "number" }, + "type": "object", + "title": "Stats" + }, + "sample": { + "additionalProperties": true, + "type": "object", + "title": "Sample" + }, + "source_count": { "type": "integer", "title": "Source Count" }, + "target_count": { "type": "integer", "title": "Target Count" 
}, + "count_pct_change": { "type": "number", "title": "Count Pct Change" } + }, + "additionalProperties": false, + "type": "object", + "required": [ + "source", + "target", + "stats", + "sample", + "source_count", + "target_count", + "count_pct_change" + ], + "title": "RowDiff" + }, + "SchemaDiff": { + "properties": { + "source": { "type": "string", "title": "Source" }, + "target": { "type": "string", "title": "Target" }, + "source_schema": { + "additionalProperties": { "type": "string" }, + "type": "object", + "title": "Source Schema" + }, + "target_schema": { + "additionalProperties": { "type": "string" }, + "type": "object", + "title": "Target Schema" + }, + "added": { + "additionalProperties": { "type": "string" }, + "type": "object", + "title": "Added" + }, + "removed": { + "additionalProperties": { "type": "string" }, + "type": "object", + "title": "Removed" + }, + "modified": { + "additionalProperties": { "type": "string" }, + "type": "object", + "title": "Modified" + } + }, + "additionalProperties": false, + "type": "object", + "required": [ + "source", + "target", + "source_schema", + "target_schema", + "added", + "removed", + "modified" + ], + "title": "SchemaDiff" + }, + "SnapshotChangeCategory": { + "type": "integer", + "enum": [1, 2, 3, 4, 5, 6], + "title": "SnapshotChangeCategory", + "description": "Values are ordered by decreasing severity and that ordering is required.\n\nBREAKING: The change requires that snapshot modified and downstream dependencies be rebuilt\nNON_BREAKING: The change requires that only the snapshot modified be rebuilt\nFORWARD_ONLY: The change requires no rebuilding\nINDIRECT_BREAKING: The change was caused indirectly and is breaking.\nINDIRECT_NON_BREAKING: The change was caused indirectly by a non-breaking change.\nMETADATA: The change was caused by a metadata update." + }, + "Status": { + "type": "string", + "enum": ["init", "success", "fail"], + "title": "Status", + "description": "An enumeration of statuses." 
+ }, + "TableDiff": { + "properties": { + "schema_diff": { "$ref": "#/components/schemas/SchemaDiff" }, + "row_diff": { "$ref": "#/components/schemas/RowDiff" }, + "on": { + "items": { "items": { "type": "string" }, "type": "array" }, + "type": "array", + "title": "On" + } + }, + "additionalProperties": false, + "type": "object", + "required": ["schema_diff", "row_diff", "on"], + "title": "TableDiff" + }, + "TestCase": { + "properties": { + "name": { "type": "string", "title": "Name" }, + "path": { "type": "string", "title": "Path" } + }, + "additionalProperties": false, + "type": "object", + "required": ["name", "path"], + "title": "TestCase" + }, + "TestErrorOrFailure": { + "properties": { + "name": { "type": "string", "title": "Name" }, + "path": { "type": "string", "title": "Path" }, + "tb": { "type": "string", "title": "Tb" } + }, + "additionalProperties": false, + "type": "object", + "required": ["name", "path", "tb"], + "title": "TestErrorOrFailure" + }, + "TestResult": { + "properties": { + "tests_run": { "type": "integer", "title": "Tests Run" }, + "failures": { + "items": { "$ref": "#/components/schemas/TestErrorOrFailure" }, + "type": "array", + "title": "Failures" + }, + "errors": { + "items": { "$ref": "#/components/schemas/TestErrorOrFailure" }, + "type": "array", + "title": "Errors" + }, + "skipped": { + "items": { "$ref": "#/components/schemas/TestSkipped" }, + "type": "array", + "title": "Skipped" + }, + "successes": { + "items": { "$ref": "#/components/schemas/TestCase" }, + "type": "array", + "title": "Successes" + } + }, + "additionalProperties": false, + "type": "object", + "required": ["tests_run", "failures", "errors", "skipped", "successes"], + "title": "TestResult" + }, + "TestSkipped": { + "properties": { + "name": { "type": "string", "title": "Name" }, + "path": { "type": "string", "title": "Path" }, + "reason": { "type": "string", "title": "Reason" } + }, + "additionalProperties": false, + "type": "object", + "required": ["name", "path", 
"reason"], + "title": "TestSkipped" + }, + "TrackableMeta": { + "properties": { + "status": { + "$ref": "#/components/schemas/Status", + "default": "init" + }, + "start": { "type": "integer", "title": "Start" }, + "end": { + "anyOf": [{ "type": "integer" }, { "type": "null" }], + "title": "End" + }, + "done": { "type": "boolean", "title": "Done", "default": false } + }, + "additionalProperties": false, + "type": "object", + "title": "TrackableMeta" + }, + "ValidationError": { + "properties": { + "loc": { + "items": { "anyOf": [{ "type": "string" }, { "type": "integer" }] }, + "type": "array", + "title": "Location" + }, + "msg": { "type": "string", "title": "Message" }, + "type": { "type": "string", "title": "Error Type" } + }, + "type": "object", + "required": ["loc", "msg", "type"], + "title": "ValidationError" + }, + "Verbosity": { + "type": "integer", + "enum": [0, 1, 2], + "title": "Verbosity", + "description": "Verbosity levels for SQLMesh output." + } + } + } +} diff --git a/vscode/react/.cta.json b/vscode/react/.cta.json new file mode 100644 index 0000000000..c70f3cc92c --- /dev/null +++ b/vscode/react/.cta.json @@ -0,0 +1,13 @@ +{ + "framework": "react", + "projectName": "react", + "typescript": true, + "tailwind": false, + "packageManager": "npm", + "toolchain": "none", + "mode": "file-router", + "git": true, + "variableValues": {}, + "version": 1, + "existingAddOns": [] +} diff --git a/vscode/react/.gitignore b/vscode/react/.gitignore new file mode 100644 index 0000000000..85125ac8c1 --- /dev/null +++ b/vscode/react/.gitignore @@ -0,0 +1,8 @@ +node_modules +.DS_Store +dist +dist-ssr +*.local + +*storybook.log +storybook-static diff --git a/vscode/react/.storybook/main.ts b/vscode/react/.storybook/main.ts new file mode 100644 index 0000000000..7601de4a00 --- /dev/null +++ b/vscode/react/.storybook/main.ts @@ -0,0 +1,17 @@ +import type { StorybookConfig } from '@storybook/react-vite' + +const config: StorybookConfig = { + stories: ['../src/**/*.mdx', 
'../src/**/*.stories.@(js|jsx|mjs|ts|tsx)'], + addons: [ + '@chromatic-com/storybook', + '@storybook/addon-docs', + '@storybook/addon-onboarding', + '@storybook/addon-a11y', + '@storybook/addon-vitest', + ], + framework: { + name: '@storybook/react-vite', + options: {}, + }, +} +export default config diff --git a/vscode/react/.storybook/preview.ts b/vscode/react/.storybook/preview.ts new file mode 100644 index 0000000000..1deb4d5908 --- /dev/null +++ b/vscode/react/.storybook/preview.ts @@ -0,0 +1,22 @@ +import type { Preview } from '@storybook/react-vite' +import './storybook.css' + +const preview: Preview = { + parameters: { + controls: { + matchers: { + color: /(background|color)$/i, + date: /Date$/i, + }, + }, + + a11y: { + // 'todo' - show a11y violations in the test UI only + // 'error' - fail CI on a11y violations + // 'off' - skip a11y checks entirely + test: 'todo', + }, + }, +} + +export default preview diff --git a/vscode/react/.storybook/storybook.css b/vscode/react/.storybook/storybook.css new file mode 100644 index 0000000000..562df7eadf --- /dev/null +++ b/vscode/react/.storybook/storybook.css @@ -0,0 +1,917 @@ +/* Storybook CSS Variables */ +html { + --color-graph-edge-secondary: var(--vscode-disabledForeground); + --color-graph-edge-main: var(--vscode-disabledForeground); + --color-graph-edge-selected: var(--vscode-textLink-foreground); + --color-graph-edge-direct: var(--vscode-disabledForeground); + --vscode-font-family: -apple-system, BlinkMacSystemFont, sans-serif; + --vscode-font-weight: normal; + --vscode-font-size: 13px; + --vscode-editor-font-family: Menlo, Monaco, 'Courier New', monospace; + --vscode-editor-font-weight: normal; + --vscode-editor-font-size: 12px; + --text-link-decoration: none; + --vscode-foreground: rgba(204, 204, 204, 0.87); + --vscode-disabledForeground: rgba(204, 204, 204, 0.5); + --vscode-errorForeground: #bf616a; + --vscode-descriptionForeground: rgba(204, 204, 204, 0.61); + --vscode-icon-foreground: #c5c5c5; + 
--vscode-focusBorder: #30373a; + --vscode-selection-background: rgba(255, 255, 255, 0.2); + --vscode-textLink-foreground: #81a1c1; + --vscode-textLink-activeForeground: #81a1c1; + --vscode-textSeparator-foreground: #88c0d0; + --vscode-textPreformat-foreground: #88c0d0; + --vscode-textPreformat-background: rgba(255, 255, 255, 0.1); + --vscode-textBlockQuote-background: #222222; + --vscode-textBlockQuote-border: rgba(0, 122, 204, 0.5); + --vscode-textCodeBlock-background: rgba(10, 10, 10, 0.4); + --vscode-sash-hoverBorder: #30373a; + --vscode-badge-background: #88c0d0; + --vscode-badge-foreground: #141414; + --vscode-activityWarningBadge-foreground: #000000; + --vscode-activityWarningBadge-background: #cca700; + --vscode-activityErrorBadge-foreground: #000000; + --vscode-activityErrorBadge-background: #f14c4c; + --vscode-scrollbar-shadow: rgba(0, 0, 0, 0); + --vscode-scrollbarSlider-background: rgba(64, 64, 64, 0.33); + --vscode-scrollbarSlider-hoverBackground: rgba(64, 64, 64, 0.33); + --vscode-scrollbarSlider-activeBackground: rgba(96, 96, 96, 0.33); + --vscode-progressBar-background: #a3be8c; + --vscode-chart-line: #236b8e; + --vscode-chart-axis: rgba(191, 191, 191, 0.4); + --vscode-chart-guide: rgba(191, 191, 191, 0.2); + --vscode-editor-background: #1a1a1a; + --vscode-editor-foreground: #d8dee9; + --vscode-editorStickyScroll-background: #1a1a1a; + --vscode-editorStickyScrollHover-background: #2a2d2e; + --vscode-editorStickyScroll-shadow: rgba(0, 0, 0, 0); + --vscode-editorWidget-background: #141414; + --vscode-editorWidget-foreground: rgba(204, 204, 204, 0.87); + --vscode-editorWidget-border: #454545; + --vscode-editorWidget-resizeBorder: #ffffff; + --vscode-editorError-foreground: #bf616a; + --vscode-editorError-border: rgba(191, 97, 106, 0); + --vscode-editorWarning-foreground: #ebcb8b; + --vscode-editorWarning-border: rgba(204, 204, 204, 0); + --vscode-editorInfo-foreground: #3794ff; + --vscode-editorHint-foreground: rgba(238, 238, 238, 0.7); + 
--vscode-editorLink-activeForeground: #ffffff; + --vscode-editor-selectionBackground: rgba(64, 64, 64, 0.6); + --vscode-editor-inactiveSelectionBackground: rgba(64, 64, 64, 0.47); + --vscode-editor-selectionHighlightBackground: rgba(64, 64, 64, 0.8); + --vscode-editor-compositionBorder: #ffffff; + --vscode-editor-findMatchBackground: rgba(136, 192, 208, 0.4); + --vscode-editor-findMatchHighlightBackground: rgba(136, 192, 208, 0.27); + --vscode-editor-findRangeHighlightBackground: rgba(255, 255, 255, 0.2); + --vscode-editor-hoverHighlightBackground: #292929; + --vscode-editorHoverWidget-background: #1a1a1a; + --vscode-editorHoverWidget-foreground: rgba(204, 204, 204, 0.87); + --vscode-editorHoverWidget-border: #2a2a2a; + --vscode-editorHoverWidget-statusBarBackground: #1f1f1f; + --vscode-editorInlayHint-foreground: #505050; + --vscode-editorInlayHint-background: rgba(0, 0, 0, 0); + --vscode-editorInlayHint-typeForeground: #505050; + --vscode-editorInlayHint-typeBackground: rgba(0, 0, 0, 0); + --vscode-editorInlayHint-parameterForeground: #505050; + --vscode-editorInlayHint-parameterBackground: rgba(0, 0, 0, 0); + --vscode-editorLightBulb-foreground: #ffcc00; + --vscode-editorLightBulbAutoFix-foreground: #75beff; + --vscode-editorLightBulbAi-foreground: #ffcc00; + --vscode-editor-snippetTabstopHighlightBackground: rgba(204, 204, 204, 0.33); + --vscode-editor-snippetFinalTabstopHighlightBorder: #cccccc; + --vscode-diffEditor-insertedTextBackground: rgba(163, 190, 140, 0.13); + --vscode-diffEditor-removedTextBackground: rgba(191, 97, 106, 0.13); + --vscode-diffEditor-insertedLineBackground: rgba(155, 185, 85, 0.2); + --vscode-diffEditor-removedLineBackground: rgba(255, 0, 0, 0.2); + --vscode-diffEditor-diagonalFill: rgba(204, 204, 204, 0.2); + --vscode-diffEditor-unchangedRegionBackground: #141414; + --vscode-diffEditor-unchangedRegionForeground: rgba(204, 204, 204, 0.87); + --vscode-diffEditor-unchangedCodeBackground: rgba(116, 116, 116, 0.16); + 
--vscode-widget-shadow: rgba(0, 0, 0, 0.4); + --vscode-toolbar-hoverBackground: rgba(90, 93, 94, 0.31); + --vscode-toolbar-activeBackground: rgba(99, 102, 103, 0.31); + --vscode-breadcrumb-foreground: rgba(204, 204, 204, 0.6); + --vscode-breadcrumb-background: #1a1a1a; + --vscode-breadcrumb-focusForeground: rgba(224, 224, 224, 0.87); + --vscode-breadcrumb-activeSelectionForeground: #ffffff; + --vscode-breadcrumbPicker-background: #141414; + --vscode-merge-currentHeaderBackground: rgba(136, 192, 208, 0.4); + --vscode-merge-currentContentBackground: rgba(136, 192, 208, 0.3); + --vscode-merge-incomingHeaderBackground: rgba(163, 190, 140, 0.4); + --vscode-merge-incomingContentBackground: rgba(163, 190, 140, 0.3); + --vscode-merge-commonHeaderBackground: rgba(96, 96, 96, 0.4); + --vscode-merge-commonContentBackground: rgba(96, 96, 96, 0.16); + --vscode-merge-border: rgba(42, 42, 42, 0); + --vscode-editorOverviewRuler-currentContentForeground: rgba( + 136, + 192, + 208, + 0.4 + ); + --vscode-editorOverviewRuler-incomingContentForeground: rgba( + 163, + 190, + 140, + 0.4 + ); + --vscode-editorOverviewRuler-commonContentForeground: rgba(96, 96, 96, 0.4); + --vscode-editorOverviewRuler-findMatchForeground: rgba(209, 134, 22, 0.49); + --vscode-editorOverviewRuler-selectionHighlightForeground: rgba( + 160, + 160, + 160, + 0.8 + ); + --vscode-problemsErrorIcon-foreground: #bf616a; + --vscode-problemsWarningIcon-foreground: #ebcb8b; + --vscode-problemsInfoIcon-foreground: #3794ff; + --vscode-minimap-findMatchHighlight: rgba(21, 172, 145, 0.44); + --vscode-minimap-selectionOccurrenceHighlight: #676767; + --vscode-minimap-selectionHighlight: #363636; + --vscode-minimap-infoHighlight: #3794ff; + --vscode-minimap-warningHighlight: #ea7620; + --vscode-minimap-errorHighlight: #f14c4c; + --vscode-minimap-background: #181818; + --vscode-minimap-foregroundOpacity: #000000; + --vscode-minimapSlider-background: rgba(64, 64, 64, 0.17); + --vscode-minimapSlider-hoverBackground: rgba(64, 64, 
64, 0.17); + --vscode-minimapSlider-activeBackground: rgba(96, 96, 96, 0.17); + --vscode-charts-foreground: rgba(204, 204, 204, 0.87); + --vscode-charts-lines: rgba(204, 204, 204, 0.43); + --vscode-charts-red: #bf616a; + --vscode-charts-blue: #3794ff; + --vscode-charts-yellow: #ebcb8b; + --vscode-charts-orange: rgba(21, 172, 145, 0.44); + --vscode-charts-green: #89d185; + --vscode-charts-purple: #b180d7; + --vscode-input-background: rgba(42, 42, 42, 0.33); + --vscode-input-foreground: #ffffff; + --vscode-input-border: #2a2a2a; + --vscode-inputOption-activeBorder: #ffffff; + --vscode-inputOption-hoverBackground: rgba(90, 93, 94, 0.5); + --vscode-inputOption-activeBackground: rgba(48, 55, 58, 0.4); + --vscode-inputOption-activeForeground: #ffffff; + --vscode-input-placeholderForeground: rgba(255, 255, 255, 0.6); + --vscode-inputValidation-infoBackground: #88c0d0; + --vscode-inputValidation-infoForeground: #141414; + --vscode-inputValidation-infoBorder: #88c0d0; + --vscode-inputValidation-warningBackground: #ebcb8b; + --vscode-inputValidation-warningBorder: #ebcb8b; + --vscode-inputValidation-errorBackground: #bf616a; + --vscode-inputValidation-errorBorder: #bf616a; + --vscode-dropdown-background: #1a1a1a; + --vscode-dropdown-foreground: #ffffff; + --vscode-dropdown-border: #2a2a2a; + --vscode-button-foreground: #191c22; + --vscode-button-separator: rgba(25, 28, 34, 0.4); + --vscode-button-background: #81a1c1; + --vscode-button-hoverBackground: #87a6c4; + --vscode-button-secondaryForeground: #ececec; + --vscode-button-secondaryBackground: #565656; + --vscode-button-secondaryHoverBackground: #767676; + --vscode-radio-activeForeground: #ffffff; + --vscode-radio-activeBackground: rgba(48, 55, 58, 0.4); + --vscode-radio-activeBorder: #ffffff; + --vscode-radio-inactiveBorder: rgba(255, 255, 255, 0.2); + --vscode-radio-inactiveHoverBackground: rgba(90, 93, 94, 0.5); + --vscode-checkbox-background: #1a1a1a; + --vscode-checkbox-selectBackground: #141414; + 
--vscode-checkbox-foreground: #ffffff; + --vscode-checkbox-border: #2a2a2a; + --vscode-checkbox-selectBorder: #c5c5c5; + --vscode-keybindingLabel-background: rgba(128, 128, 128, 0.17); + --vscode-keybindingLabel-foreground: #cccccc; + --vscode-keybindingLabel-border: rgba(51, 51, 51, 0.6); + --vscode-keybindingLabel-bottomBorder: rgba(68, 68, 68, 0.6); + --vscode-list-focusBackground: #434c5e; + --vscode-list-focusForeground: #eceff4; + --vscode-list-focusOutline: #30373a; + --vscode-list-activeSelectionBackground: rgba(255, 255, 255, 0.11); + --vscode-list-activeSelectionForeground: #ffffff; + --vscode-list-inactiveSelectionBackground: rgba(255, 255, 255, 0.06); + --vscode-list-inactiveSelectionForeground: rgba(255, 255, 255, 0.84); + --vscode-list-hoverBackground: rgba(42, 42, 42, 0.6); + --vscode-list-hoverForeground: #ffffff; + --vscode-list-dropBackground: rgba(255, 255, 255, 0.6); + --vscode-list-dropBetweenBackground: #c5c5c5; + --vscode-list-highlightForeground: #88c0d0; + --vscode-list-focusHighlightForeground: #88c0d0; + --vscode-list-invalidItemForeground: #cccccc; + --vscode-list-errorForeground: #bf616a; + --vscode-list-warningForeground: #ebcb8b; + --vscode-listFilterWidget-background: #141414; + --vscode-listFilterWidget-outline: rgba(0, 0, 0, 0); + --vscode-listFilterWidget-noMatchesOutline: #be1100; + --vscode-listFilterWidget-shadow: rgba(0, 0, 0, 0.4); + --vscode-list-filterMatchBackground: rgba(136, 192, 208, 0.27); + --vscode-list-deemphasizedForeground: #cccccc; + --vscode-tree-indentGuidesStroke: rgba(204, 204, 204, 0.33); + --vscode-tree-inactiveIndentGuidesStroke: rgba(204, 204, 204, 0.13); + --vscode-tree-tableColumnsBorder: rgba(204, 204, 204, 0.13); + --vscode-tree-tableOddRowsBackground: rgba(204, 204, 204, 0.04); + --vscode-editorActionList-background: #141414; + --vscode-editorActionList-foreground: rgba(204, 204, 204, 0.87); + --vscode-editorActionList-focusForeground: #ffffff; + --vscode-editorActionList-focusBackground: rgba(255, 
255, 255, 0.11); + --vscode-menu-foreground: #cccccc; + --vscode-menu-background: #141414; + --vscode-menu-selectionForeground: #ffffff; + --vscode-menu-selectionBackground: rgba(255, 255, 255, 0.11); + --vscode-menu-separatorBackground: #cccccc; + --vscode-quickInput-background: #141414; + --vscode-quickInput-foreground: rgba(204, 204, 204, 0.87); + --vscode-quickInputTitle-background: rgba(255, 255, 255, 0.1); + --vscode-pickerGroup-foreground: #ffffff; + --vscode-pickerGroup-border: rgba(42, 42, 42, 0); + --vscode-quickInputList-focusForeground: #ffffff; + --vscode-quickInputList-focusBackground: rgba(255, 255, 255, 0.11); + --vscode-search-resultsInfoForeground: rgba(204, 204, 204, 0.56); + --vscode-searchEditor-findMatchBackground: rgba(136, 192, 208, 0.18); + --vscode-editor-lineHighlightBackground: #292929; + --vscode-editor-lineHighlightBorder: #292929; + --vscode-editor-rangeHighlightBackground: rgba(64, 64, 64, 0.32); + --vscode-editor-symbolHighlightBackground: rgba(136, 192, 208, 0.27); + --vscode-editorCursor-foreground: #ffffff; + --vscode-editorMultiCursor-primary\.foreground: #ffffff; + --vscode-editorMultiCursor-secondary\.foreground: #ffffff; + --vscode-editorWhitespace-foreground: rgba(80, 80, 80, 0.7); + --vscode-editorLineNumber-foreground: #505050; + --vscode-editorIndentGuide-background: rgba(64, 64, 64, 0.7); + --vscode-editorIndentGuide-activeBackground: #505050; + --vscode-editorIndentGuide-background1: rgba(64, 64, 64, 0.7); + --vscode-editorIndentGuide-background2: rgba(0, 0, 0, 0); + --vscode-editorIndentGuide-background3: rgba(0, 0, 0, 0); + --vscode-editorIndentGuide-background4: rgba(0, 0, 0, 0); + --vscode-editorIndentGuide-background5: rgba(0, 0, 0, 0); + --vscode-editorIndentGuide-background6: rgba(0, 0, 0, 0); + --vscode-editorIndentGuide-activeBackground1: #505050; + --vscode-editorIndentGuide-activeBackground2: rgba(0, 0, 0, 0); + --vscode-editorIndentGuide-activeBackground3: rgba(0, 0, 0, 0); + 
--vscode-editorIndentGuide-activeBackground4: rgba(0, 0, 0, 0); + --vscode-editorIndentGuide-activeBackground5: rgba(0, 0, 0, 0); + --vscode-editorIndentGuide-activeBackground6: rgba(0, 0, 0, 0); + --vscode-editorActiveLineNumber-foreground: #c6c6c6; + --vscode-editorLineNumber-activeForeground: #ffffff; + --vscode-editorRuler-foreground: #494949; + --vscode-editorCodeLens-foreground: #505050; + --vscode-editorBracketMatch-background: rgba(20, 20, 20, 0); + --vscode-editorBracketMatch-border: rgba(255, 255, 255, 0.33); + --vscode-editorOverviewRuler-border: rgba(0, 0, 0, 0); + --vscode-editorGutter-background: #1a1a1a; + --vscode-editorUnnecessaryCode-opacity: rgba(0, 0, 0, 0.67); + --vscode-editorGhostText-foreground: rgba(255, 255, 255, 0.34); + --vscode-editorOverviewRuler-rangeHighlightForeground: rgba(0, 122, 204, 0.6); + --vscode-editorOverviewRuler-errorForeground: rgba(255, 18, 18, 0.7); + --vscode-editorOverviewRuler-warningForeground: #ebcb8b; + --vscode-editorOverviewRuler-infoForeground: #3794ff; + --vscode-editorBracketHighlight-foreground1: #ffd700; + --vscode-editorBracketHighlight-foreground2: #da70d6; + --vscode-editorBracketHighlight-foreground3: #179fff; + --vscode-editorBracketHighlight-foreground4: rgba(0, 0, 0, 0); + --vscode-editorBracketHighlight-foreground5: rgba(0, 0, 0, 0); + --vscode-editorBracketHighlight-foreground6: rgba(0, 0, 0, 0); + --vscode-editorBracketHighlight-unexpectedBracket\.foreground: rgba( + 255, + 18, + 18, + 0.8 + ); + --vscode-editorBracketPairGuide-background1: rgba(0, 0, 0, 0); + --vscode-editorBracketPairGuide-background2: rgba(0, 0, 0, 0); + --vscode-editorBracketPairGuide-background3: rgba(0, 0, 0, 0); + --vscode-editorBracketPairGuide-background4: rgba(0, 0, 0, 0); + --vscode-editorBracketPairGuide-background5: rgba(0, 0, 0, 0); + --vscode-editorBracketPairGuide-background6: rgba(0, 0, 0, 0); + --vscode-editorBracketPairGuide-activeBackground1: rgba(0, 0, 0, 0); + 
--vscode-editorBracketPairGuide-activeBackground2: rgba(0, 0, 0, 0); + --vscode-editorBracketPairGuide-activeBackground3: rgba(0, 0, 0, 0); + --vscode-editorBracketPairGuide-activeBackground4: rgba(0, 0, 0, 0); + --vscode-editorBracketPairGuide-activeBackground5: rgba(0, 0, 0, 0); + --vscode-editorBracketPairGuide-activeBackground6: rgba(0, 0, 0, 0); + --vscode-editorUnicodeHighlight-border: #ebcb8b; + --vscode-diffEditor-move\.border: rgba(139, 139, 139, 0.61); + --vscode-diffEditor-moveActive\.border: #ffa500; + --vscode-diffEditor-unchangedRegionShadow: #000000; + --vscode-editorOverviewRuler-bracketMatchForeground: #a0a0a0; + --vscode-actionBar-toggledBackground: rgba(48, 55, 58, 0.4); + --vscode-symbolIcon-arrayForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-booleanForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-classForeground: #ee9d28; + --vscode-symbolIcon-colorForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-constantForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-constructorForeground: #b180d7; + --vscode-symbolIcon-enumeratorForeground: #ee9d28; + --vscode-symbolIcon-enumeratorMemberForeground: #75beff; + --vscode-symbolIcon-eventForeground: #ee9d28; + --vscode-symbolIcon-fieldForeground: #75beff; + --vscode-symbolIcon-fileForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-folderForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-functionForeground: #b180d7; + --vscode-symbolIcon-interfaceForeground: #75beff; + --vscode-symbolIcon-keyForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-keywordForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-methodForeground: #b180d7; + --vscode-symbolIcon-moduleForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-namespaceForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-nullForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-numberForeground: rgba(204, 204, 204, 0.87); + 
--vscode-symbolIcon-objectForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-operatorForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-packageForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-propertyForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-referenceForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-snippetForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-stringForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-structForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-textForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-typeParameterForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-unitForeground: rgba(204, 204, 204, 0.87); + --vscode-symbolIcon-variableForeground: #75beff; + --vscode-peekViewTitle-background: #2a2a2a; + --vscode-peekViewTitleLabel-foreground: #ffffff; + --vscode-peekViewTitleDescription-foreground: #ffffff; + --vscode-peekView-border: #505050; + --vscode-peekViewResult-background: #141414; + --vscode-peekViewResult-lineForeground: rgba(255, 255, 255, 0.4); + --vscode-peekViewResult-fileForeground: #ffffff; + --vscode-peekViewResult-selectionBackground: #404040; + --vscode-peekViewResult-selectionForeground: #ffffff; + --vscode-peekViewEditor-background: #141414; + --vscode-peekViewEditorGutter-background: #141414; + --vscode-peekViewEditorStickyScroll-background: #141414; + --vscode-peekViewResult-matchHighlightBackground: rgba(255, 255, 255, 0.4); + --vscode-peekViewEditor-matchHighlightBackground: rgba(255, 255, 255, 0.4); + --vscode-editor-foldBackground: rgba(64, 64, 64, 0.18); + --vscode-editor-foldPlaceholderForeground: #808080; + --vscode-editorGutter-foldingControlForeground: #c5c5c5; + --vscode-editorSuggestWidget-background: #141414; + --vscode-editorSuggestWidget-border: #2a2a2a; + --vscode-editorSuggestWidget-foreground: #ffffff; + --vscode-editorSuggestWidget-selectedForeground: #ffffff; + 
--vscode-editorSuggestWidget-selectedBackground: #404040; + --vscode-editorSuggestWidget-highlightForeground: #ffffff; + --vscode-editorSuggestWidget-focusHighlightForeground: #88c0d0; + --vscode-editorSuggestWidgetStatus-foreground: rgba(255, 255, 255, 0.5); + --vscode-inlineEdit-indicator\.foreground: #191c22; + --vscode-inlineEdit-indicator\.background: #81a1c1; + --vscode-inlineEdit-indicator\.border: rgba(25, 28, 34, 0.4); + --vscode-inlineEdit-originalBackground: rgba(191, 97, 106, 0.05); + --vscode-inlineEdit-modifiedBackground: rgba(163, 190, 140, 0.05); + --vscode-inlineEdit-originalChangedLineBackground: rgba(0, 0, 0, 0); + --vscode-inlineEdit-originalChangedTextBackground: rgba(191, 97, 106, 0.13); + --vscode-inlineEdit-modifiedChangedLineBackground: rgba(0, 0, 0, 0); + --vscode-inlineEdit-modifiedChangedTextBackground: rgba(163, 190, 140, 0.13); + --vscode-inlineEdit-border: #3e3e3e; + --vscode-inlineChat-foreground: rgba(204, 204, 204, 0.87); + --vscode-inlineChat-background: #141414; + --vscode-inlineChat-border: #454545; + --vscode-inlineChat-shadow: rgba(0, 0, 0, 0.4); + --vscode-inlineChatInput-border: #454545; + --vscode-inlineChatInput-focusBorder: #30373a; + --vscode-inlineChatInput-placeholderForeground: rgba(255, 255, 255, 0.6); + --vscode-inlineChatInput-background: rgba(42, 42, 42, 0.33); + --vscode-inlineChatDiff-inserted: rgba(163, 190, 140, 0.07); + --vscode-editorOverviewRuler-inlineChatInserted: rgba(163, 190, 140, 0.08); + --vscode-inlineChatDiff-removed: rgba(191, 97, 106, 0.07); + --vscode-editorOverviewRuler-inlineChatRemoved: rgba(191, 97, 106, 0.08); + --vscode-editor-wordHighlightBackground: rgba(255, 255, 255, 0.13); + --vscode-editor-wordHighlightStrongBackground: rgba(255, 255, 255, 0.18); + --vscode-editor-wordHighlightTextBackground: rgba(255, 255, 255, 0.13); + --vscode-editorOverviewRuler-wordHighlightForeground: rgba( + 160, + 160, + 160, + 0.8 + ); + --vscode-editorOverviewRuler-wordHighlightStrongForeground: rgba( + 
192, + 160, + 192, + 0.8 + ); + --vscode-editorOverviewRuler-wordHighlightTextForeground: rgba( + 160, + 160, + 160, + 0.8 + ); + --vscode-tab-activeBackground: #1a1a1a; + --vscode-tab-unfocusedActiveBackground: #1a1a1a; + --vscode-tab-inactiveBackground: #141414; + --vscode-tab-unfocusedInactiveBackground: #141414; + --vscode-tab-activeForeground: #ffffff; + --vscode-tab-inactiveForeground: #505050; + --vscode-tab-unfocusedActiveForeground: rgba(255, 255, 255, 0.6); + --vscode-tab-unfocusedInactiveForeground: rgba(255, 255, 255, 0.4); + --vscode-tab-hoverBackground: rgba(255, 255, 255, 0); + --vscode-tab-unfocusedHoverBackground: rgba(42, 42, 42, 0.7); + --vscode-tab-border: rgba(255, 255, 255, 0.05); + --vscode-tab-lastPinnedBorder: rgba(204, 204, 204, 0.33); + --vscode-tab-activeBorder: #1a1a1a; + --vscode-tab-unfocusedActiveBorder: rgba(136, 192, 208, 0); + --vscode-tab-activeBorderTop: rgba(255, 255, 255, 0); + --vscode-tab-unfocusedActiveBorderTop: rgba(255, 255, 255, 0); + --vscode-tab-selectedBorderTop: rgba(255, 255, 255, 0); + --vscode-tab-selectedBackground: #1a1a1a; + --vscode-tab-selectedForeground: #ffffff; + --vscode-tab-unfocusedHoverBorder: rgba(136, 192, 208, 0); + --vscode-tab-dragAndDropBorder: #ffffff; + --vscode-tab-activeModifiedBorder: #3399cc; + --vscode-tab-inactiveModifiedBorder: rgba(51, 153, 204, 0.5); + --vscode-tab-unfocusedActiveModifiedBorder: rgba(51, 153, 204, 0.5); + --vscode-tab-unfocusedInactiveModifiedBorder: rgba(51, 153, 204, 0.25); + --vscode-editorPane-background: #1a1a1a; + --vscode-editorGroup-emptyBackground: #141414; + --vscode-editorGroupHeader-tabsBackground: #141414; + --vscode-editorGroupHeader-tabsBorder: rgba(255, 255, 255, 0.05); + --vscode-editorGroupHeader-noTabsBackground: #141414; + --vscode-editorGroup-border: rgba(255, 255, 255, 0.05); + --vscode-editorGroup-dropBackground: rgba(42, 42, 42, 0.6); + --vscode-editorGroup-dropIntoPromptForeground: rgba(204, 204, 204, 0.87); + 
--vscode-editorGroup-dropIntoPromptBackground: #141414; + --vscode-sideBySideEditor-horizontalBorder: rgba(255, 255, 255, 0.05); + --vscode-sideBySideEditor-verticalBorder: rgba(255, 255, 255, 0.05); + --vscode-panel-background: #141414; + --vscode-panel-border: rgba(255, 255, 255, 0.05); + --vscode-panelTitle-activeForeground: #ffffff; + --vscode-panelTitle-inactiveForeground: rgba(204, 204, 204, 0.6); + --vscode-panelTitle-activeBorder: rgba(255, 255, 255, 0); + --vscode-panelInput-border: #2a2a2a; + --vscode-panel-dropBorder: #ffffff; + --vscode-panelSection-dropBackground: rgba(42, 42, 42, 0.6); + --vscode-panelSectionHeader-background: rgba(128, 128, 128, 0.2); + --vscode-panelSection-border: rgba(255, 255, 255, 0.05); + --vscode-panelStickyScroll-background: #141414; + --vscode-panelStickyScroll-shadow: rgba(0, 0, 0, 0); + --vscode-banner-background: rgba(255, 255, 255, 0.11); + --vscode-banner-foreground: #ffffff; + --vscode-banner-iconForeground: #3794ff; + --vscode-statusBar-foreground: rgba(204, 204, 204, 0.51); + --vscode-statusBar-noFolderForeground: #ffffff; + --vscode-statusBar-background: #141414; + --vscode-statusBar-noFolderBackground: #141414; + --vscode-statusBar-border: rgba(255, 255, 255, 0.05); + --vscode-statusBar-focusBorder: rgba(204, 204, 204, 0.51); + --vscode-statusBar-noFolderBorder: rgba(255, 255, 255, 0.05); + --vscode-statusBarItem-activeBackground: #505050; + --vscode-statusBarItem-focusBorder: rgba(204, 204, 204, 0.51); + --vscode-statusBarItem-hoverBackground: #404040; + --vscode-statusBarItem-hoverForeground: rgba(204, 204, 204, 0.51); + --vscode-statusBarItem-compactHoverBackground: rgba(255, 255, 255, 0.2); + --vscode-statusBarItem-prominentForeground: rgba(204, 204, 204, 0.51); + --vscode-statusBarItem-prominentBackground: #2a2a2a; + --vscode-statusBarItem-prominentHoverForeground: rgba(204, 204, 204, 0.51); + --vscode-statusBarItem-prominentHoverBackground: #404040; + --vscode-statusBarItem-errorBackground: #7b3239; + 
--vscode-statusBarItem-errorForeground: #ffffff; + --vscode-statusBarItem-errorHoverForeground: rgba(204, 204, 204, 0.51); + --vscode-statusBarItem-errorHoverBackground: #404040; + --vscode-statusBarItem-warningBackground: #bf8b21; + --vscode-statusBarItem-warningForeground: #ffffff; + --vscode-statusBarItem-warningHoverForeground: rgba(204, 204, 204, 0.51); + --vscode-statusBarItem-warningHoverBackground: #404040; + --vscode-activityBar-background: #141414; + --vscode-activityBar-foreground: rgba(204, 204, 204, 0.6); + --vscode-activityBar-inactiveForeground: rgba(204, 204, 204, 0.24); + --vscode-activityBar-activeBorder: rgba(204, 204, 204, 0.6); + --vscode-activityBar-dropBorder: rgba(204, 204, 204, 0.6); + --vscode-activityBarBadge-background: #88c0d0; + --vscode-activityBarBadge-foreground: #000000; + --vscode-activityBarTop-foreground: #e7e7e7; + --vscode-activityBarTop-activeBorder: #e7e7e7; + --vscode-activityBarTop-inactiveForeground: rgba(231, 231, 231, 0.6); + --vscode-activityBarTop-dropBorder: #e7e7e7; + --vscode-profileBadge-background: #4d4d4d; + --vscode-profileBadge-foreground: #ffffff; + --vscode-statusBarItem-remoteBackground: #88c0d0; + --vscode-statusBarItem-remoteForeground: #000000; + --vscode-statusBarItem-remoteHoverForeground: rgba(204, 204, 204, 0.51); + --vscode-statusBarItem-remoteHoverBackground: #404040; + --vscode-statusBarItem-offlineBackground: #6c1717; + --vscode-statusBarItem-offlineForeground: #000000; + --vscode-statusBarItem-offlineHoverForeground: rgba(204, 204, 204, 0.51); + --vscode-statusBarItem-offlineHoverBackground: #404040; + --vscode-extensionBadge-remoteBackground: #88c0d0; + --vscode-extensionBadge-remoteForeground: #000000; + --vscode-sideBar-background: #141414; + --vscode-sideBar-foreground: rgba(204, 204, 204, 0.6); + --vscode-sideBar-border: rgba(255, 255, 255, 0.05); + --vscode-sideBarTitle-background: #141414; + --vscode-sideBarTitle-foreground: #cccccc; + --vscode-sideBar-dropBackground: rgba(42, 42, 42, 
0.6); + --vscode-sideBarSectionHeader-background: #141414; + --vscode-sideBarSectionHeader-foreground: #505050; + --vscode-sideBarStickyScroll-background: #141414; + --vscode-sideBarStickyScroll-shadow: rgba(0, 0, 0, 0); + --vscode-titleBar-activeForeground: rgba(204, 204, 204, 0.51); + --vscode-titleBar-inactiveForeground: rgba(204, 204, 204, 0.38); + --vscode-titleBar-activeBackground: #141414; + --vscode-titleBar-inactiveBackground: #141414; + --vscode-titleBar-border: rgba(255, 255, 255, 0.05); + --vscode-menubar-selectionForeground: rgba(204, 204, 204, 0.51); + --vscode-menubar-selectionBackground: rgba(204, 204, 204, 0.2); + --vscode-commandCenter-foreground: rgba(204, 204, 204, 0.51); + --vscode-commandCenter-activeForeground: rgba(204, 204, 204, 0.51); + --vscode-commandCenter-inactiveForeground: rgba(204, 204, 204, 0.38); + --vscode-commandCenter-background: rgba(255, 255, 255, 0.05); + --vscode-commandCenter-activeBackground: rgba(255, 255, 255, 0.08); + --vscode-commandCenter-border: rgba(204, 204, 204, 0.1); + --vscode-commandCenter-activeBorder: rgba(204, 204, 204, 0.15); + --vscode-commandCenter-inactiveBorder: rgba(204, 204, 204, 0.09); + --vscode-notifications-foreground: #ffffff; + --vscode-notifications-background: #141414; + --vscode-notificationLink-foreground: #88c0d0; + --vscode-notificationCenterHeader-background: #1a1a1a; + --vscode-notifications-border: #1a1a1a; + --vscode-notificationsErrorIcon-foreground: #bf616a; + --vscode-notificationsWarningIcon-foreground: #ebcb8b; + --vscode-notificationsInfoIcon-foreground: #3794ff; + --vscode-extensionButton-background: #81a1c1; + --vscode-extensionButton-foreground: #191c22; + --vscode-extensionButton-hoverBackground: #87a6c4; + --vscode-extensionButton-separator: rgba(25, 28, 34, 0.4); + --vscode-extensionButton-prominentBackground: #565656; + --vscode-extensionButton-prominentForeground: #ffffff; + --vscode-extensionButton-prominentHoverBackground: #767676; + 
--vscode-editorGutter-modifiedBackground: #ebcb8b; + --vscode-editorGutter-addedBackground: #a3be8c; + --vscode-editorGutter-deletedBackground: #bf616a; + --vscode-minimapGutter-modifiedBackground: #e5b95c; + --vscode-minimapGutter-addedBackground: #15ac91; + --vscode-minimapGutter-deletedBackground: #f14c4c; + --vscode-editorOverviewRuler-modifiedForeground: rgba(235, 203, 139, 0.6); + --vscode-editorOverviewRuler-addedForeground: rgba(163, 190, 140, 0.6); + --vscode-editorOverviewRuler-deletedForeground: rgba(191, 97, 106, 0.6); + --vscode-chat-requestBorder: rgba(255, 255, 255, 0.1); + --vscode-chat-requestBackground: rgba(26, 26, 26, 0.62); + --vscode-chat-slashCommandBackground: rgba(52, 65, 75, 0.56); + --vscode-chat-slashCommandForeground: #40a6ff; + --vscode-chat-avatarBackground: #1f1f1f; + --vscode-chat-avatarForeground: rgba(204, 204, 204, 0.87); + --vscode-chat-editedFileForeground: #e2c08d; + --vscode-terminal-background: #141414; + --vscode-terminal-foreground: rgba(255, 255, 255, 0.8); + --vscode-terminalCursor-foreground: #ffffff; + --vscode-terminalCursor-background: rgba(255, 255, 255, 0.13); + --vscode-terminal-selectionBackground: rgba(99, 98, 98, 0.87); + --vscode-terminal-inactiveSelectionBackground: rgba(99, 98, 98, 0.43); + --vscode-terminalCommandDecoration-defaultBackground: rgba( + 255, + 255, + 255, + 0.25 + ); + --vscode-terminalCommandDecoration-successBackground: #1b81a8; + --vscode-terminalCommandDecoration-errorBackground: #f14c4c; + --vscode-terminalOverviewRuler-cursorForeground: rgba(160, 160, 160, 0.8); + --vscode-terminal-border: rgba(255, 255, 255, 0.05); + --vscode-terminalOverviewRuler-border: rgba(0, 0, 0, 0); + --vscode-terminal-findMatchBackground: rgba(136, 192, 208, 0.4); + --vscode-terminal-hoverHighlightBackground: rgba(41, 41, 41, 0.5); + --vscode-terminal-findMatchHighlightBackground: rgba(136, 192, 208, 0.27); + --vscode-terminalOverviewRuler-findMatchForeground: rgba(209, 134, 22, 0.49); + 
--vscode-terminal-dropBackground: rgba(42, 42, 42, 0.6); + --vscode-terminal-tab\.activeBorder: #1a1a1a; + --vscode-terminal-initialHintForeground: rgba(255, 255, 255, 0.34); + --vscode-editorMarkerNavigationError-background: rgba(191, 97, 106, 0.75); + --vscode-editorMarkerNavigationError-headerBackground: rgba( + 191, + 97, + 106, + 0.07 + ); + --vscode-editorMarkerNavigationWarning-background: #cccccc; + --vscode-editorMarkerNavigationWarning-headerBackground: rgba( + 204, + 204, + 204, + 0.1 + ); + --vscode-editorMarkerNavigationInfo-background: #3794ff; + --vscode-editorMarkerNavigationInfo-headerBackground: rgba(55, 148, 255, 0.1); + --vscode-editorMarkerNavigation-background: rgba(255, 255, 255, 0.44); + --vscode-editor-linkedEditingBackground: rgba(255, 0, 0, 0.3); + --vscode-editorHoverWidget-highlightForeground: #88c0d0; + --vscode-editor-placeholder\.foreground: rgba(255, 255, 255, 0.34); + --vscode-scmGraph-historyItemRefColor: #3794ff; + --vscode-scmGraph-historyItemRemoteRefColor: #b180d7; + --vscode-scmGraph-historyItemBaseRefColor: #ea5c00; + --vscode-scmGraph-historyItemHoverDefaultLabelForeground: rgba( + 204, + 204, + 204, + 0.87 + ); + --vscode-scmGraph-historyItemHoverDefaultLabelBackground: #88c0d0; + --vscode-scmGraph-historyItemHoverLabelForeground: #191c22; + --vscode-scmGraph-historyItemHoverAdditionsForeground: #81b88b; + --vscode-scmGraph-historyItemHoverDeletionsForeground: #c74e39; + --vscode-scmGraph-foreground1: #ffb000; + --vscode-scmGraph-foreground2: #dc267f; + --vscode-scmGraph-foreground3: #994f00; + --vscode-scmGraph-foreground4: #40b0a6; + --vscode-scmGraph-foreground5: #b66dff; + --vscode-commentsView-resolvedIcon: rgba(204, 204, 204, 0.5); + --vscode-commentsView-unresolvedIcon: #30373a; + --vscode-editorCommentsWidget-replyInputBackground: #2a2a2a; + --vscode-editorCommentsWidget-resolvedBorder: rgba(204, 204, 204, 0.5); + --vscode-editorCommentsWidget-unresolvedBorder: #30373a; + 
--vscode-editorCommentsWidget-rangeBackground: rgba(48, 55, 58, 0.1); + --vscode-editorCommentsWidget-rangeActiveBackground: rgba(48, 55, 58, 0.1); + --vscode-editorGutter-commentRangeForeground: #282828; + --vscode-editorOverviewRuler-commentForeground: #282828; + --vscode-editorOverviewRuler-commentUnresolvedForeground: #282828; + --vscode-editorGutter-commentGlyphForeground: #d8dee9; + --vscode-editorGutter-commentUnresolvedGlyphForeground: #d8dee9; + --vscode-ports-iconRunningProcessForeground: #88c0d0; + --vscode-simpleFindWidget-sashBorder: #454545; + --vscode-settings-headerForeground: #e7e7e7; + --vscode-settings-settingsHeaderHoverForeground: rgba(231, 231, 231, 0.7); + --vscode-settings-modifiedItemIndicator: #0c7d9d; + --vscode-settings-headerBorder: rgba(255, 255, 255, 0.05); + --vscode-settings-sashBorder: rgba(255, 255, 255, 0.05); + --vscode-settings-dropdownBackground: #1a1a1a; + --vscode-settings-dropdownForeground: #ffffff; + --vscode-settings-dropdownBorder: #2a2a2a; + --vscode-settings-dropdownListBorder: #454545; + --vscode-settings-checkboxBackground: #1a1a1a; + --vscode-settings-checkboxForeground: #ffffff; + --vscode-settings-checkboxBorder: #2a2a2a; + --vscode-settings-textInputBackground: rgba(42, 42, 42, 0.33); + --vscode-settings-textInputForeground: #ffffff; + --vscode-settings-textInputBorder: #2a2a2a; + --vscode-settings-numberInputBackground: rgba(42, 42, 42, 0.33); + --vscode-settings-numberInputForeground: #ffffff; + --vscode-settings-numberInputBorder: #2a2a2a; + --vscode-settings-focusedRowBackground: rgba(42, 42, 42, 0.36); + --vscode-settings-rowHoverBackground: rgba(42, 42, 42, 0.18); + --vscode-settings-focusedRowBorder: #30373a; + --vscode-keybindingTable-headerBackground: rgba(204, 204, 204, 0.04); + --vscode-keybindingTable-rowsBackground: rgba(204, 204, 204, 0.04); + --vscode-debugToolBar-background: #1a1a1a; + --vscode-debugIcon-startForeground: #89d185; + --vscode-notebook-cellBorderColor: rgba(255, 255, 255, 0.06); + 
--vscode-notebook-focusedEditorBorder: #30373a; + --vscode-notebookStatusSuccessIcon-foreground: #89d185; + --vscode-notebookEditorOverviewRuler-runningCellForeground: #89d185; + --vscode-notebookStatusErrorIcon-foreground: #bf616a; + --vscode-notebookStatusRunningIcon-foreground: rgba(204, 204, 204, 0.87); + --vscode-notebook-cellToolbarSeparator: rgba(128, 128, 128, 0.35); + --vscode-notebook-selectedCellBackground: rgba(255, 255, 255, 0.06); + --vscode-notebook-selectedCellBorder: rgba(255, 255, 255, 0.06); + --vscode-notebook-focusedCellBorder: #30373a; + --vscode-notebook-inactiveFocusedCellBorder: rgba(255, 255, 255, 0.06); + --vscode-notebook-cellStatusBarItemHoverBackground: rgba(255, 255, 255, 0.15); + --vscode-notebook-cellInsertionIndicator: #30373a; + --vscode-notebookScrollbarSlider-background: rgba(64, 64, 64, 0.33); + --vscode-notebookScrollbarSlider-hoverBackground: rgba(64, 64, 64, 0.33); + --vscode-notebookScrollbarSlider-activeBackground: rgba(96, 96, 96, 0.33); + --vscode-notebook-symbolHighlightBackground: rgba(255, 255, 255, 0.04); + --vscode-notebook-cellEditorBackground: #141414; + --vscode-notebook-editorBackground: #1a1a1a; + --vscode-debugIcon-breakpointForeground: #e51400; + --vscode-debugIcon-breakpointDisabledForeground: #848484; + --vscode-debugIcon-breakpointUnverifiedForeground: #848484; + --vscode-debugIcon-breakpointCurrentStackframeForeground: #ffcc00; + --vscode-debugIcon-breakpointStackframeForeground: #89d185; + --vscode-editor-stackFrameHighlightBackground: rgba(255, 255, 0, 0.2); + --vscode-editor-focusedStackFrameHighlightBackground: rgba( + 122, + 189, + 122, + 0.3 + ); + --vscode-multiDiffEditor-headerBackground: #262626; + --vscode-multiDiffEditor-background: #1a1a1a; + --vscode-interactive-activeCodeBorder: #505050; + --vscode-interactive-inactiveCodeBorder: rgba(255, 255, 255, 0.06); + --vscode-testing-iconFailed: #f14c4c; + --vscode-testing-iconErrored: #f14c4c; + --vscode-testing-iconPassed: #73c991; + 
--vscode-testing-runAction: #73c991; + --vscode-testing-iconQueued: #cca700; + --vscode-testing-iconUnset: #848484; + --vscode-testing-iconSkipped: #848484; + --vscode-testing-peekBorder: #bf616a; + --vscode-testing-messagePeekBorder: #3794ff; + --vscode-testing-peekHeaderBackground: rgba(191, 97, 106, 0.1); + --vscode-testing-messagePeekHeaderBackground: rgba(55, 148, 255, 0.1); + --vscode-testing-coveredBackground: rgba(163, 190, 140, 0.13); + --vscode-testing-coveredBorder: rgba(163, 190, 140, 0.1); + --vscode-testing-coveredGutterBackground: rgba(163, 190, 140, 0.08); + --vscode-testing-uncoveredBranchBackground: #452c2f; + --vscode-testing-uncoveredBackground: rgba(191, 97, 106, 0.13); + --vscode-testing-uncoveredBorder: rgba(191, 97, 106, 0.1); + --vscode-testing-uncoveredGutterBackground: rgba(191, 97, 106, 0.2); + --vscode-testing-coverCountBadgeBackground: #88c0d0; + --vscode-testing-coverCountBadgeForeground: #141414; + --vscode-testing-message\.error\.badgeBackground: #f14c4c; + --vscode-testing-message\.error\.badgeBorder: #f14c4c; + --vscode-testing-message\.error\.badgeForeground: #000000; + --vscode-testing-message\.info\.decorationForeground: rgba( + 216, + 222, + 233, + 0.5 + ); + --vscode-testing-iconErrored\.retired: rgba(241, 76, 76, 0.7); + --vscode-testing-iconFailed\.retired: rgba(241, 76, 76, 0.7); + --vscode-testing-iconPassed\.retired: rgba(115, 201, 145, 0.7); + --vscode-testing-iconQueued\.retired: rgba(204, 167, 0, 0.7); + --vscode-testing-iconUnset\.retired: rgba(132, 132, 132, 0.7); + --vscode-testing-iconSkipped\.retired: rgba(132, 132, 132, 0.7); + --vscode-searchEditor-textInputBorder: #2a2a2a; + --vscode-debugExceptionWidget-border: #141414; + --vscode-debugExceptionWidget-background: #505050; + --vscode-editor-inlineValuesForeground: rgba(255, 255, 255, 0.5); + --vscode-editor-inlineValuesBackground: rgba(255, 200, 0, 0.2); + --vscode-statusBar-debuggingBackground: #434c5e; + --vscode-statusBar-debuggingForeground: #d8dee9; + 
--vscode-statusBar-debuggingBorder: rgba(255, 255, 255, 0.05); + --vscode-commandCenter-debuggingBackground: rgba(67, 76, 94, 0.26); + --vscode-debugTokenExpression-name: #c586c0; + --vscode-debugTokenExpression-type: #4a90e2; + --vscode-debugTokenExpression-value: rgba(204, 204, 204, 0.6); + --vscode-debugTokenExpression-string: #ce9178; + --vscode-debugTokenExpression-boolean: #4e94ce; + --vscode-debugTokenExpression-number: #b5cea8; + --vscode-debugTokenExpression-error: #f48771; + --vscode-debugView-exceptionLabelForeground: rgba(204, 204, 204, 0.87); + --vscode-debugView-exceptionLabelBackground: #6c2022; + --vscode-debugView-stateLabelForeground: rgba(204, 204, 204, 0.87); + --vscode-debugView-stateLabelBackground: rgba(136, 136, 136, 0.27); + --vscode-debugView-valueChangedHighlight: #569cd6; + --vscode-debugConsole-infoForeground: #3794ff; + --vscode-debugConsole-warningForeground: #ebcb8b; + --vscode-debugConsole-errorForeground: #bf616a; + --vscode-debugConsole-sourceForeground: rgba(204, 204, 204, 0.87); + --vscode-debugConsoleInputIcon-foreground: rgba(204, 204, 204, 0.87); + --vscode-debugIcon-pauseForeground: #75beff; + --vscode-debugIcon-stopForeground: #f48771; + --vscode-debugIcon-disconnectForeground: #f48771; + --vscode-debugIcon-restartForeground: #89d185; + --vscode-debugIcon-stepOverForeground: #75beff; + --vscode-debugIcon-stepIntoForeground: #75beff; + --vscode-debugIcon-stepOutForeground: #75beff; + --vscode-debugIcon-continueForeground: #75beff; + --vscode-debugIcon-stepBackForeground: #75beff; + --vscode-mergeEditor-change\.background: rgba(155, 185, 85, 0.2); + --vscode-mergeEditor-change\.word\.background: rgba(156, 204, 44, 0.2); + --vscode-mergeEditor-changeBase\.background: #4b1818; + --vscode-mergeEditor-changeBase\.word\.background: #6f1313; + --vscode-mergeEditor-conflict\.unhandledUnfocused\.border: rgba( + 255, + 166, + 0, + 0.48 + ); + --vscode-mergeEditor-conflict\.unhandledFocused\.border: #ffa600; + 
--vscode-mergeEditor-conflict\.handledUnfocused\.border: rgba( + 134, + 134, + 134, + 0.29 + ); + --vscode-mergeEditor-conflict\.handledFocused\.border: rgba( + 193, + 193, + 193, + 0.8 + ); + --vscode-mergeEditor-conflict\.handled\.minimapOverViewRuler: rgba( + 173, + 172, + 168, + 0.93 + ); + --vscode-mergeEditor-conflict\.unhandled\.minimapOverViewRuler: #fcba03; + --vscode-mergeEditor-conflictingLines\.background: rgba(255, 234, 0, 0.28); + --vscode-mergeEditor-conflict\.input1\.background: rgba(136, 192, 208, 0.16); + --vscode-mergeEditor-conflict\.input2\.background: rgba(163, 190, 140, 0.16); + --vscode-extensionIcon-starForeground: #ff8e00; + --vscode-extensionIcon-verifiedForeground: #81a1c1; + --vscode-extensionIcon-preReleaseForeground: #1d9271; + --vscode-extensionIcon-sponsorForeground: #d758b3; + --vscode-terminal-ansiBlack: #2a2a2a; + --vscode-terminal-ansiRed: #bf616a; + --vscode-terminal-ansiGreen: #a3be8c; + --vscode-terminal-ansiYellow: #ebcb8b; + --vscode-terminal-ansiBlue: #81a1c1; + --vscode-terminal-ansiMagenta: #b48ead; + --vscode-terminal-ansiCyan: #88c0d0; + --vscode-terminal-ansiWhite: #ffffff; + --vscode-terminal-ansiBrightBlack: #505050; + --vscode-terminal-ansiBrightRed: #bf616a; + --vscode-terminal-ansiBrightGreen: #a3be8c; + --vscode-terminal-ansiBrightYellow: #ebcb8b; + --vscode-terminal-ansiBrightBlue: #81a1c1; + --vscode-terminal-ansiBrightMagenta: #b48ead; + --vscode-terminal-ansiBrightCyan: #88c0d0; + --vscode-terminal-ansiBrightWhite: #ffffff; + --vscode-terminalStickyScrollHover-background: #2a2d2e; + --vscode-terminalCommandGuide-foreground: rgba(255, 255, 255, 0.06); + --vscode-walkThrough-embeddedEditorBackground: #141414; + --vscode-profiles-sashBorder: rgba(255, 255, 255, 0.05); + --vscode-gitDecoration-addedResourceForeground: #a3be8c; + --vscode-gitDecoration-modifiedResourceForeground: #ebcb8b; + --vscode-gitDecoration-deletedResourceForeground: #bf616a; + --vscode-gitDecoration-renamedResourceForeground: #73c991; + 
--vscode-gitDecoration-untrackedResourceForeground: #88c0d0; + --vscode-gitDecoration-ignoredResourceForeground: #505050; + --vscode-gitDecoration-stageModifiedResourceForeground: #e2c08d; + --vscode-gitDecoration-stageDeletedResourceForeground: #c74e39; + --vscode-gitDecoration-conflictingResourceForeground: #e4676b; + --vscode-gitDecoration-submoduleResourceForeground: #8db9e2; + --vscode-git-blame\.editorDecorationForeground: #505050; + --vscode-gitlens-gutterBackgroundColor: rgba(255, 255, 255, 0.07); + --vscode-gitlens-gutterForegroundColor: #bebebe; + --vscode-gitlens-gutterUncommittedForegroundColor: rgba(0, 188, 242, 0.6); + --vscode-gitlens-trailingLineBackgroundColor: rgba(0, 0, 0, 0); + --vscode-gitlens-trailingLineForegroundColor: rgba(204, 204, 204, 0.6); + --vscode-gitlens-lineHighlightBackgroundColor: rgba(0, 188, 242, 0.2); + --vscode-gitlens-lineHighlightOverviewRulerColor: rgba(0, 188, 242, 0.6); + --vscode-gitlens-openAutolinkedIssueIconColor: #3fb950; + --vscode-gitlens-closedAutolinkedIssueIconColor: #a371f7; + --vscode-gitlens-closedPullRequestIconColor: #f85149; + --vscode-gitlens-openPullRequestIconColor: #3fb950; + --vscode-gitlens-mergedPullRequestIconColor: #a371f7; + --vscode-gitlens-unpublishedChangesIconColor: #35b15e; + --vscode-gitlens-unpublishedCommitIconColor: #35b15e; + --vscode-gitlens-unpulledChangesIconColor: #b15e35; + --vscode-gitlens-decorations\.addedForegroundColor: #a3be8c; + --vscode-gitlens-decorations\.copiedForegroundColor: #73c991; + --vscode-gitlens-decorations\.deletedForegroundColor: #bf616a; + --vscode-gitlens-decorations\.ignoredForegroundColor: #505050; + --vscode-gitlens-decorations\.modifiedForegroundColor: #ebcb8b; + --vscode-gitlens-decorations\.untrackedForegroundColor: #88c0d0; + --vscode-gitlens-decorations\.renamedForegroundColor: #73c991; + --vscode-gitlens-decorations\.branchAheadForegroundColor: #35b15e; + --vscode-gitlens-decorations\.branchBehindForegroundColor: #b15e35; + 
--vscode-gitlens-decorations\.branchDivergedForegroundColor: #d8af1b; + --vscode-gitlens-decorations\.branchUpToDateForegroundColor: rgba( + 204, + 204, + 204, + 0.6 + ); + --vscode-gitlens-decorations\.branchUnpublishedForegroundColor: rgba( + 204, + 204, + 204, + 0.6 + ); + --vscode-gitlens-decorations\.branchMissingUpstreamForegroundColor: #c74e39; + --vscode-gitlens-decorations\.statusMergingOrRebasingConflictForegroundColor: #c74e39; + --vscode-gitlens-decorations\.statusMergingOrRebasingForegroundColor: #d8af1b; + --vscode-gitlens-decorations\.workspaceRepoMissingForegroundColor: #909090; + --vscode-gitlens-decorations\.workspaceCurrentForegroundColor: #35b15e; + --vscode-gitlens-decorations\.workspaceRepoOpenForegroundColor: #35b15e; + --vscode-gitlens-decorations\.worktreeHasUncommittedChangesForegroundColor: #e2c08d; + --vscode-gitlens-decorations\.worktreeMissingForegroundColor: #c74e39; + --vscode-gitlens-graphLane1Color: #15a0bf; + --vscode-gitlens-graphLane2Color: #0669f7; + --vscode-gitlens-graphLane3Color: #8e00c2; + --vscode-gitlens-graphLane4Color: #c517b6; + --vscode-gitlens-graphLane5Color: #d90171; + --vscode-gitlens-graphLane6Color: #cd0101; + --vscode-gitlens-graphLane7Color: #f25d2e; + --vscode-gitlens-graphLane8Color: #f2ca33; + --vscode-gitlens-graphLane9Color: #7bd938; + --vscode-gitlens-graphLane10Color: #2ece9d; + --vscode-gitlens-graphChangesColumnAddedColor: #347d39; + --vscode-gitlens-graphChangesColumnDeletedColor: #c93c37; + --vscode-gitlens-graphMinimapMarkerHeadColor: #05e617; + --vscode-gitlens-graphScrollMarkerHeadColor: #05e617; + --vscode-gitlens-graphMinimapMarkerUpstreamColor: #09ae17; + --vscode-gitlens-graphScrollMarkerUpstreamColor: #09ae17; + --vscode-gitlens-graphMinimapMarkerHighlightsColor: #fbff0a; + --vscode-gitlens-graphScrollMarkerHighlightsColor: #fbff0a; + --vscode-gitlens-graphMinimapMarkerLocalBranchesColor: #3087cf; + --vscode-gitlens-graphScrollMarkerLocalBranchesColor: #3087cf; + 
--vscode-gitlens-graphMinimapMarkerPullRequestsColor: #c76801; + --vscode-gitlens-graphScrollMarkerPullRequestsColor: #c76801; + --vscode-gitlens-graphMinimapMarkerRemoteBranchesColor: #2b5e88; + --vscode-gitlens-graphScrollMarkerRemoteBranchesColor: #2b5e88; + --vscode-gitlens-graphMinimapMarkerStashesColor: #b34db3; + --vscode-gitlens-graphScrollMarkerStashesColor: #b34db3; + --vscode-gitlens-graphMinimapMarkerTagsColor: #6b562e; + --vscode-gitlens-graphScrollMarkerTagsColor: #6b562e; + --vscode-gitlens-launchpadIndicatorMergeableColor: #3fb950; + --vscode-gitlens-launchpadIndicatorMergeableHoverColor: #3fb950; + --vscode-gitlens-launchpadIndicatorBlockedColor: #c74e39; + --vscode-gitlens-launchpadIndicatorBlockedHoverColor: #c74e39; + --vscode-gitlens-launchpadIndicatorAttentionColor: #d8af1b; + --vscode-gitlens-launchpadIndicatorAttentionHoverColor: #d8af1b; + --vscode-issues-newIssueDecoration: rgba(255, 255, 255, 0.28); + --vscode-issues-open: #3fb950; + --vscode-issues-closed: #cb2431; + --vscode-pullRequests-merged: #8957e5; + --vscode-pullRequests-draft: #6e7681; + --vscode-pullRequests-open: #3fb950; + --vscode-pullRequests-closed: #cb2431; + --vscode-pullRequests-notification: #3794ff; + --vscode-rainbowtrack1: #ff0000; + --vscode-rainbowtrack2: #00ff00; + --vscode-rainbowtrack3: #0000ff; +} diff --git a/vscode/react/.storybook/vitest.setup.ts b/vscode/react/.storybook/vitest.setup.ts new file mode 100644 index 0000000000..06a5566d6d --- /dev/null +++ b/vscode/react/.storybook/vitest.setup.ts @@ -0,0 +1,7 @@ +import * as a11yAddonAnnotations from '@storybook/addon-a11y/preview' +import { setProjectAnnotations } from '@storybook/react-vite' +import * as projectAnnotations from './preview' + +// This is an important step to apply the right configuration when testing your stories. 
+// More info at: https://storybook.js.org/docs/api/portable-stories/portable-stories-vitest#setprojectannotations +setProjectAnnotations([a11yAddonAnnotations, projectAnnotations]) diff --git a/vscode/react/index.html b/vscode/react/index.html new file mode 100644 index 0000000000..575e79754d --- /dev/null +++ b/vscode/react/index.html @@ -0,0 +1,38 @@ + + + + + + + + + + + Create TanStack App - react + + +
+ + + diff --git a/vscode/react/orval.config.ts b/vscode/react/orval.config.ts new file mode 100644 index 0000000000..6f8c33d298 --- /dev/null +++ b/vscode/react/orval.config.ts @@ -0,0 +1,17 @@ +import { defineConfig } from 'orval' + +export default defineConfig({ + 'sqlmesh-api': { + input: '../openapi.json', + output: { + prettier: true, + target: './src/api/client.ts', + override: { + mutator: { + path: './src/api/instance.ts', + name: 'fetchAPI', + }, + }, + }, + }, +}) diff --git a/vscode/react/package.json b/vscode/react/package.json new file mode 100644 index 0000000000..e12dd12179 --- /dev/null +++ b/vscode/react/package.json @@ -0,0 +1,61 @@ +{ + "name": "react", + "private": true, + "type": "module", + "scripts": { + "start": "vite --port 3000", + "dev": "vite", + "build": "pnpm run lint && vite build", + "build:watch": "vite build --watch --outDir ./../extension/src_react --emptyOutDir", + "serve": "vite preview", + "test": "vitest run", + "generate:api": "orval --config ./orval.config.ts", + "lint": "tsc --noEmit", + "storybook": "storybook dev -p 6006", + "build-storybook": "storybook build" + }, + "dependencies": { + "@headlessui/react": "^2.2.5", + "@heroicons/react": "^2.2.0", + "@radix-ui/react-select": "^2.2.5", + "@tailwindcss/postcss": "^4.1.11", + "@tailwindcss/vite": "^4.1.11", + "@tanstack/react-query": "^5.83.0", + "@tanstack/react-router": "^1.129.8", + "@tanstack/react-router-devtools": "^1.131.26", + "@tanstack/react-virtual": "^3.13.12", + "@tanstack/router-plugin": "^1.129.8", + "apache-arrow": "^19.0.1", + "clsx": "^2.1.1", + "elkjs": "^0.8.2", + "orval": "^7.10.0", + "react": "^18.3.1", + "react-dom": "^18.3.1", + "react-router": "^7.7.0", + "reactflow": "^11.11.4", + "tailwindcss": "^4.1.11", + "vscode-uri": "^3.1.0" + }, + "devDependencies": { + "@chromatic-com/storybook": "^4.0.1", + "@storybook/addon-a11y": "^9.0.18", + "@storybook/addon-docs": "^9.0.18", + "@storybook/addon-onboarding": "^9.0.18", + "@storybook/addon-vitest": 
"^9.0.18", + "@storybook/react-vite": "^9.0.18", + "@testing-library/dom": "^10.4.1", + "@testing-library/react": "^16.3.0", + "@types/react": "^18.3.23", + "@types/react-dom": "^18.3.7", + "@vitejs/plugin-react": "^4.7.0", + "@vitest/browser": "3.2.3", + "@vitest/coverage-v8": "3.2.3", + "jsdom": "^26.1.0", + "playwright": "^1.54.1", + "storybook": "^9.0.18", + "typescript": "^5.8.3", + "vite": "^6.3.5", + "vitest": "^3.2.4", + "web-vitals": "^4.2.4" + } +} diff --git a/vscode/react/src/App.css b/vscode/react/src/App.css new file mode 100644 index 0000000000..c49ccd115b --- /dev/null +++ b/vscode/react/src/App.css @@ -0,0 +1,74 @@ +@import 'tailwindcss'; +@config "../tailwind.config.cjs"; + +@tailwind base; +@tailwind components; +@tailwind utilities; + +.App { + text-align: center; +} + +.App-logo { + height: 40vmin; + pointer-events: none; +} + +@media (prefers-reduced-motion: no-preference) { + .App-logo { + animation: App-logo-spin infinite 20s linear; + } +} + +.App-header { + background-color: #282c34; + min-height: 100vh; + display: flex; + flex-direction: column; + align-items: center; + justify-content: center; + font-size: calc(10px + 2vmin); + color: white; +} + +.App-link { + color: #61dafb; +} + +@keyframes App-logo-spin { + from { + transform: rotate(0deg); + } + to { + transform: rotate(360deg); + } +} + +@layer components { + .scrollbar--horizontal::-webkit-scrollbar { + height: var(--scrollbar-size); + } + .scrollbar--vertical::-webkit-scrollbar { + width: var(--scrollbar-size); + } + .scrollbar::-webkit-scrollbar-track { + background: transparent; + } + .scrollbar::-webkit-scrollbar-thumb { + background: var(--scrollbar-backgroud); + border-radius: var(--scrollbar-radius); + } + .input-ring { + @apply ring-accent-200 ring-offset-accent-500; + } + .input-ring:focus { + @apply outline-none ring-offset-2 ring-4; + } +} + +:root { + --color-graph-edge-secondary: var(--vscode-disabledForeground); + --color-graph-edge-main: 
var(--vscode-disabledForeground); + --color-graph-edge-selected: var(--vscode-textLink-foreground); + --color-graph-edge-direct: var(--vscode-disabledForeground); +} diff --git a/vscode/react/src/api/client.ts b/vscode/react/src/api/client.ts new file mode 100644 index 0000000000..028b2d1912 --- /dev/null +++ b/vscode/react/src/api/client.ts @@ -0,0 +1,1264 @@ +/** + * Generated by orval v7.9.0 🍺 + * Do not edit manually. + * FastAPI + * OpenAPI spec version: 0.1.0 + */ +import { fetchAPI } from './instance' +export type ApiExceptionPayloadStatus = number | null + +export type ApiExceptionPayloadTrigger = string | null + +export type ApiExceptionPayloadType = string | null + +export type ApiExceptionPayloadDescription = string | null + +export type ApiExceptionPayloadTraceback = string | null + +export type ApiExceptionPayloadStack = string[] | null + +export interface ApiExceptionPayload { + timestamp: number + message: string + origin: string + status?: ApiExceptionPayloadStatus + trigger?: ApiExceptionPayloadTrigger + type?: ApiExceptionPayloadType + description?: ApiExceptionPayloadDescription + traceback?: ApiExceptionPayloadTraceback + stack?: ApiExceptionPayloadStack +} + +export interface BackfillDetails { + name: string + view_name: string + node_type?: NodeType + parents?: string[] + interval: string[] + batches: number +} + +export type BackfillTaskEnd = number | null + +export type BackfillTaskInterval = string[] | null + +export interface BackfillTask { + name: string + view_name: string + node_type?: NodeType + parents?: string[] + completed: number + total: number + start: number + end?: BackfillTaskEnd + interval?: BackfillTaskInterval +} + +export type BodyInitiateApplyApiCommandsApplyPostEnvironment = string | null + +export type BodyInitiateApplyApiCommandsApplyPostPlanDates = PlanDates | null + +export type BodyInitiateApplyApiCommandsApplyPostPlanOptions = + PlanOptions | null + +export type BodyInitiateApplyApiCommandsApplyPostCategoriesAnyOf 
= { + [key: string]: SnapshotChangeCategory +} + +export type BodyInitiateApplyApiCommandsApplyPostCategories = + BodyInitiateApplyApiCommandsApplyPostCategoriesAnyOf | null + +export interface BodyInitiateApplyApiCommandsApplyPost { + environment?: BodyInitiateApplyApiCommandsApplyPostEnvironment + plan_dates?: BodyInitiateApplyApiCommandsApplyPostPlanDates + plan_options?: BodyInitiateApplyApiCommandsApplyPostPlanOptions + categories?: BodyInitiateApplyApiCommandsApplyPostCategories +} + +export type BodyInitiatePlanApiPlanPostEnvironment = string | null + +export type BodyInitiatePlanApiPlanPostPlanDates = PlanDates | null + +export type BodyInitiatePlanApiPlanPostPlanOptions = PlanOptions | null + +export type BodyInitiatePlanApiPlanPostCategoriesAnyOf = { + [key: string]: SnapshotChangeCategory +} + +export type BodyInitiatePlanApiPlanPostCategories = + BodyInitiatePlanApiPlanPostCategoriesAnyOf | null + +export interface BodyInitiatePlanApiPlanPost { + environment?: BodyInitiatePlanApiPlanPostEnvironment + plan_dates?: BodyInitiatePlanApiPlanPostPlanDates + plan_options?: BodyInitiatePlanApiPlanPostPlanOptions + categories?: BodyInitiatePlanApiPlanPostCategories +} + +export type BodyWriteDirectoryApiDirectoriesPathPostNewPath = string | null + +export interface BodyWriteDirectoryApiDirectoriesPathPost { + new_path?: BodyWriteDirectoryApiDirectoriesPathPostNewPath +} + +export type BodyWriteFileApiFilesPathPostNewPath = string | null + +export interface BodyWriteFileApiFilesPathPost { + content?: string + new_path?: BodyWriteFileApiFilesPathPostNewPath +} + +export type ChangeDirectChangeCategory = SnapshotChangeCategory | null + +export interface ChangeDirect { + name: string + view_name: string + node_type?: NodeType + parents?: string[] + diff: string + indirect?: ChangeDisplay[] + direct?: ChangeDisplay[] + change_category?: ChangeDirectChangeCategory +} + +export interface ChangeDisplay { + name: string + view_name: string + node_type?: NodeType + 
parents?: string[] +} + +export interface ChangeIndirect { + name: string + view_name: string + node_type?: NodeType + parents?: string[] +} + +export type ColumnDescription = string | null + +export interface Column { + name: string + type: string + description?: ColumnDescription +} + +export interface Directory { + name: string + path: string + directories?: Directory[] + files?: File[] +} + +export type EnvironmentStartAt = string | string | string | number | number + +export type EnvironmentEndAt = string | string | string | number | number | null + +export type EnvironmentPreviousPlanId = string | null + +export type EnvironmentExpirationTs = number | null + +export type EnvironmentFinalizedTs = number | null + +export type EnvironmentCatalogNameOverride = string | null + +export type EnvironmentPromotedSnapshotIds = unknown[] | null + +export type EnvironmentPreviousFinalizedSnapshots = unknown[] | null + +export type EnvironmentRequirements = { [key: string]: string } + +/** + * Represents an isolated environment. + +Environments are isolated workspaces that hold pointers to physical tables. + +Args: + snapshots: The snapshots that are part of this environment. + promoted_snapshot_ids: The IDs of the snapshots that are promoted in this environment + (i.e. for which the views are created). If not specified, all snapshots are promoted. + previous_finalized_snapshots: Snapshots that were part of this environment last time it was finalized. + requirements: A mapping of library versions for all the snapshots in this environment. 
+ */ +export interface Environment { + name?: string + start_at: EnvironmentStartAt + end_at?: EnvironmentEndAt + plan_id: string + previous_plan_id?: EnvironmentPreviousPlanId + expiration_ts?: EnvironmentExpirationTs + finalized_ts?: EnvironmentFinalizedTs + suffix_target?: EnvironmentSuffixTarget + catalog_name_override?: EnvironmentCatalogNameOverride + normalize_name?: boolean + gateway_managed?: boolean + snapshots: unknown[] + promoted_snapshot_ids?: EnvironmentPromotedSnapshotIds + previous_finalized_snapshots?: EnvironmentPreviousFinalizedSnapshots + requirements?: EnvironmentRequirements +} + +export type EnvironmentSuffixTarget = + (typeof EnvironmentSuffixTarget)[keyof typeof EnvironmentSuffixTarget] + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const EnvironmentSuffixTarget = { + schema: 'schema', + table: 'table', +} as const + +export type EnvironmentsEnvironments = { [key: string]: Environment } + +export interface Environments { + environments?: EnvironmentsEnvironments + pinned_environments?: string[] + default_target_environment?: string +} + +export type EvaluateInputStart = string | string | string | number | number + +export type EvaluateInputEnd = string | string | string | number | number + +export type EvaluateInputExecutionTime = + | string + | string + | string + | number + | number + +export interface EvaluateInput { + model: string + start: EvaluateInputStart + end: EvaluateInputEnd + execution_time: EvaluateInputExecutionTime + limit?: number +} + +export interface FetchdfInput { + sql: string + limit?: number +} + +export type FileContent = string | null + +export interface File { + name: string + path: string + extension?: string + content?: FileContent +} + +export interface HTTPValidationError { + detail?: ValidationError[] +} + +/** + * IntervalUnit is the inferred granularity of an incremental node. + +IntervalUnit can be one of 5 types, YEAR, MONTH, DAY, HOUR, MINUTE. 
The unit is inferred +based on the cron schedule of a node. The minimum time delta between a sample set of dates +is used to determine which unit a node's schedule is. + */ +export type IntervalUnit = (typeof IntervalUnit)[keyof typeof IntervalUnit] + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const IntervalUnit = { + year: 'year', + month: 'month', + day: 'day', + hour: 'hour', + half_hour: 'half_hour', + quarter_hour: 'quarter_hour', + five_minute: 'five_minute', +} as const + +export type LineageColumnSource = string | null + +export type LineageColumnExpression = string | null + +export type LineageColumnModels = { [key: string]: string[] } + +export interface LineageColumn { + source?: LineageColumnSource + expression?: LineageColumnExpression + models: LineageColumnModels +} + +export interface Meta { + version: string + has_running_task?: boolean +} + +export type ModelPath = string | null + +export type ModelFullPath = string | null + +export type ModelDescription = string | null + +export type ModelDetailsProperty = ModelDetails | null + +export type ModelSql = string | null + +export type ModelDefinition = string | null + +export type ModelDefaultCatalog = string | null + +export interface Model { + name: string + fqn: string + path?: ModelPath + full_path?: ModelFullPath + dialect: string + type: ModelType + columns: Column[] + description?: ModelDescription + details?: ModelDetailsProperty + sql?: ModelSql + definition?: ModelDefinition + default_catalog?: ModelDefaultCatalog + hash: string +} + +export type ModelDetailsOwner = string | null + +export type ModelDetailsKind = string | null + +export type ModelDetailsBatchSize = number | null + +export type ModelDetailsCron = string | null + +export type ModelDetailsStamp = + | string + | string + | string + | number + | number + | null + +export type ModelDetailsStart = + | string + | string + | string + | number + | number + | null + +export type ModelDetailsRetention = number 
| null + +export type ModelDetailsTableFormat = string | null + +export type ModelDetailsStorageFormat = string | null + +export type ModelDetailsTimeColumn = string | null + +export type ModelDetailsTags = string | null + +export type ModelDetailsPartitionedBy = string | null + +export type ModelDetailsClusteredBy = string | null + +export type ModelDetailsLookback = number | null + +export type ModelDetailsCronPrev = + | string + | string + | string + | number + | number + | null + +export type ModelDetailsCronNext = + | string + | string + | string + | number + | number + | null + +export type ModelDetailsIntervalUnit = IntervalUnit | null + +export type ModelDetailsAnnotated = boolean | null + +export interface ModelDetails { + owner?: ModelDetailsOwner + kind?: ModelDetailsKind + batch_size?: ModelDetailsBatchSize + cron?: ModelDetailsCron + stamp?: ModelDetailsStamp + start?: ModelDetailsStart + retention?: ModelDetailsRetention + table_format?: ModelDetailsTableFormat + storage_format?: ModelDetailsStorageFormat + time_column?: ModelDetailsTimeColumn + tags?: ModelDetailsTags + references?: Reference[] + partitioned_by?: ModelDetailsPartitionedBy + clustered_by?: ModelDetailsClusteredBy + lookback?: ModelDetailsLookback + cron_prev?: ModelDetailsCronPrev + cron_next?: ModelDetailsCronNext + interval_unit?: ModelDetailsIntervalUnit + annotated?: ModelDetailsAnnotated +} + +export type ModelType = (typeof ModelType)[keyof typeof ModelType] + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const ModelType = { + python: 'python', + sql: 'sql', + seed: 'seed', + external: 'external', + source: 'source', +} as const + +export interface ModelsDiff { + direct?: ChangeDirect[] + indirect?: ChangeIndirect[] + metadata?: ChangeDisplay[] +} + +export type Modules = (typeof Modules)[keyof typeof Modules] + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const Modules = { + editor: 'editor', + files: 'files', + 'data-catalog': 
'data-catalog', + plans: 'plans', + tests: 'tests', + audits: 'audits', + errors: 'errors', + data: 'data', + lineage: 'lineage', +} as const + +export type NodeType = (typeof NodeType)[keyof typeof NodeType] + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const NodeType = { + model: 'model', + audit: 'audit', +} as const + +export type PlanApplyStageTrackerStart = + | string + | string + | string + | number + | number + | null + +export type PlanApplyStageTrackerEnd = + | string + | string + | string + | number + | number + | null + +export type PlanApplyStageTrackerEnvironment = string | null + +export type PlanApplyStageTrackerPlanOptions = PlanOptions | null + +export type PlanApplyStageTrackerCreation = PlanStageCreation | null + +export type PlanApplyStageTrackerRestate = PlanStageRestate | null + +export type PlanApplyStageTrackerBackfill = PlanStageBackfill | null + +export type PlanApplyStageTrackerPromote = PlanStagePromote | null + +export interface PlanApplyStageTracker { + start?: PlanApplyStageTrackerStart + end?: PlanApplyStageTrackerEnd + meta?: TrackableMeta + environment?: PlanApplyStageTrackerEnvironment + plan_options?: PlanApplyStageTrackerPlanOptions + creation?: PlanApplyStageTrackerCreation + restate?: PlanApplyStageTrackerRestate + backfill?: PlanApplyStageTrackerBackfill + promote?: PlanApplyStageTrackerPromote +} + +export type PlanCancelStageTrackerStart = + | string + | string + | string + | number + | number + | null + +export type PlanCancelStageTrackerEnd = + | string + | string + | string + | number + | number + | null + +export type PlanCancelStageTrackerEnvironment = string | null + +export type PlanCancelStageTrackerPlanOptions = PlanOptions | null + +export type PlanCancelStageTrackerCancel = PlanStageCancel | null + +export interface PlanCancelStageTracker { + start?: PlanCancelStageTrackerStart + end?: PlanCancelStageTrackerEnd + meta?: TrackableMeta + environment?: PlanCancelStageTrackerEnvironment + 
plan_options?: PlanCancelStageTrackerPlanOptions + cancel?: PlanCancelStageTrackerCancel +} + +export type PlanDatesStart = string | string | string | number | number | null + +export type PlanDatesEnd = string | string | string | number | number | null + +export interface PlanDates { + start?: PlanDatesStart + end?: PlanDatesEnd +} + +export type PlanOptionsCreateFrom = string | null + +export type PlanOptionsRestateModels = string | null + +export interface PlanOptions { + skip_tests?: boolean + skip_backfill?: boolean + no_gaps?: boolean + forward_only?: boolean + no_auto_categorization?: boolean + include_unmodified?: boolean + create_from?: PlanOptionsCreateFrom + restate_models?: PlanOptionsRestateModels + auto_apply?: boolean +} + +export type PlanOverviewStageTrackerStart = + | string + | string + | string + | number + | number + | null + +export type PlanOverviewStageTrackerEnd = + | string + | string + | string + | number + | number + | null + +export type PlanOverviewStageTrackerEnvironment = string | null + +export type PlanOverviewStageTrackerPlanOptions = PlanOptions | null + +export type PlanOverviewStageTrackerValidation = PlanStageValidation | null + +export type PlanOverviewStageTrackerChanges = PlanStageChanges | null + +export type PlanOverviewStageTrackerBackfills = PlanStageBackfills | null + +export interface PlanOverviewStageTracker { + start?: PlanOverviewStageTrackerStart + end?: PlanOverviewStageTrackerEnd + meta?: TrackableMeta + environment?: PlanOverviewStageTrackerEnvironment + plan_options?: PlanOverviewStageTrackerPlanOptions + validation?: PlanOverviewStageTrackerValidation + changes?: PlanOverviewStageTrackerChanges + backfills?: PlanOverviewStageTrackerBackfills +} + +export type PlanStageBackfillTasks = { [key: string]: BackfillTask } + +export interface PlanStageBackfill { + meta?: TrackableMeta + queue?: string[] + tasks?: PlanStageBackfillTasks +} + +export type PlanStageBackfillsModels = BackfillDetails[] | null + +export 
interface PlanStageBackfills { + meta?: TrackableMeta + models?: PlanStageBackfillsModels +} + +export interface PlanStageCancel { + meta?: TrackableMeta +} + +export type PlanStageChangesAdded = ChangeDisplay[] | null + +export type PlanStageChangesRemoved = ChangeDisplay[] | null + +export type PlanStageChangesModified = ModelsDiff | null + +export interface PlanStageChanges { + added?: PlanStageChangesAdded + removed?: PlanStageChangesRemoved + modified?: PlanStageChangesModified + meta?: TrackableMeta +} + +export interface PlanStageCreation { + meta?: TrackableMeta + total_tasks: number + num_tasks: number +} + +export interface PlanStagePromote { + meta?: TrackableMeta + total_tasks: number + num_tasks: number + target_environment: string +} + +export interface PlanStageRestate { + meta?: TrackableMeta +} + +export interface PlanStageValidation { + meta?: TrackableMeta +} + +export interface Query { + sql: string +} + +export interface Reference { + name: string + expression: string + unique: boolean +} + +export type RenderInputStart = string | string | string | number | number | null + +export type RenderInputEnd = string | string | string | number | number | null + +export type RenderInputExecutionTime = + | string + | string + | string + | number + | number + | null + +export type RenderInputExpand = boolean | string[] + +export type RenderInputDialect = string | null + +export interface RenderInput { + model: string + start?: RenderInputStart + end?: RenderInputEnd + execution_time?: RenderInputExecutionTime + expand?: RenderInputExpand + pretty?: boolean + dialect?: RenderInputDialect +} + +export type RowDiffStats = { [key: string]: number } + +export type RowDiffSample = { [key: string]: unknown } + +export interface RowDiff { + source: string + target: string + stats: RowDiffStats + sample: RowDiffSample + source_count: number + target_count: number + count_pct_change: number +} + +export type SchemaDiffSourceSchema = { [key: string]: string } + 
+export type SchemaDiffTargetSchema = { [key: string]: string } + +export type SchemaDiffAdded = { [key: string]: string } + +export type SchemaDiffRemoved = { [key: string]: string } + +export type SchemaDiffModified = { [key: string]: string } + +export interface SchemaDiff { + source: string + target: string + source_schema: SchemaDiffSourceSchema + target_schema: SchemaDiffTargetSchema + added: SchemaDiffAdded + removed: SchemaDiffRemoved + modified: SchemaDiffModified +} + +/** + * Values are ordered by decreasing severity and that ordering is required. + +BREAKING: The change requires that snapshot modified and downstream dependencies be rebuilt +NON_BREAKING: The change requires that only the snapshot modified be rebuilt +FORWARD_ONLY: The change requires no rebuilding +INDIRECT_BREAKING: The change was caused indirectly and is breaking. +INDIRECT_NON_BREAKING: The change was caused indirectly by a non-breaking change. +METADATA: The change was caused by a metadata update. + */ +export type SnapshotChangeCategory = + (typeof SnapshotChangeCategory)[keyof typeof SnapshotChangeCategory] + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const SnapshotChangeCategory = { + NUMBER_1: 1, + NUMBER_2: 2, + NUMBER_3: 3, + NUMBER_4: 4, + NUMBER_5: 5, + NUMBER_6: 6, +} as const + +/** + * An enumeration of statuses. 
+ */ +export type Status = (typeof Status)[keyof typeof Status] + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const Status = { + init: 'init', + success: 'success', + fail: 'fail', +} as const + +export interface TableDiff { + schema_diff: SchemaDiff + row_diff: RowDiff + on: string[][] +} + +export interface TestCase { + name: string + path: string +} + +export interface TestErrorOrFailure { + name: string + path: string + tb: string +} + +export interface TestResult { + tests_run: number + failures: TestErrorOrFailure[] + errors: TestErrorOrFailure[] + skipped: TestSkipped[] + successes: TestCase[] +} + +export interface TestSkipped { + name: string + path: string + reason: string +} + +export type TrackableMetaEnd = number | null + +export interface TrackableMeta { + status?: Status + start?: number + end?: TrackableMetaEnd + done?: boolean +} + +export type ValidationErrorLocItem = string | number + +export interface ValidationError { + loc: ValidationErrorLocItem[] + msg: string + type: string +} + +/** + * Verbosity levels for SQLMesh output. 
+ */ +export type Verbosity = (typeof Verbosity)[keyof typeof Verbosity] + +// eslint-disable-next-line @typescript-eslint/no-redeclare +export const Verbosity = { + NUMBER_0: 0, + NUMBER_1: 1, + NUMBER_2: 2, +} as const + +export type InitiateApplyApiCommandsApplyPost200 = PlanApplyStageTracker | null + +export type TestApiCommandsTestGetParams = { + test?: string | null + verbosity?: Verbosity +} + +export type WriteFileApiFilesPathPost200 = File | null + +export type InitiatePlanApiPlanPost200 = PlanOverviewStageTracker | null + +export type CancelPlanApiPlanCancelPost200 = PlanCancelStageTracker | null + +export type ColumnLineageApiLineageModelNameColumnNameGetParams = { + models_only?: boolean +} + +export type ColumnLineageApiLineageModelNameColumnNameGet200 = { + [key: string]: { [key: string]: LineageColumn } +} + +export type ModelLineageApiLineageModelNameGet200 = { [key: string]: string[] } + +export type GetModelsApiModelsGet200 = Model[] | ApiExceptionPayload + +export type GetTableDiffApiTableDiffGetParams = { + source: string + target: string + on?: string | null + model_or_snapshot?: string | null + where?: string | null + temp_schema?: string | null + limit?: number +} + +export type GetTableDiffApiTableDiffGet200 = TableDiff | null + +type SecondParameter unknown> = Parameters[1] + +/** + * Apply a plan + * @summary Initiate Apply + */ +export const initiateApplyApiCommandsApplyPost = ( + bodyInitiateApplyApiCommandsApplyPost: BodyInitiateApplyApiCommandsApplyPost, + options?: SecondParameter, +) => { + return fetchAPI( + { + url: `/api/commands/apply`, + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + data: bodyInitiateApplyApiCommandsApplyPost, + }, + options, + ) +} + +/** + * Evaluate a model with a default limit of 1000 + * @summary Evaluate + */ +export const evaluateApiCommandsEvaluatePost = ( + evaluateInput: EvaluateInput, + options?: SecondParameter, +) => { + return fetchAPI( + { + url: `/api/commands/evaluate`, + 
method: 'POST', + headers: { 'Content-Type': 'application/json' }, + data: evaluateInput, + }, + options, + ) +} + +/** + * Fetches a dataframe given a sql string + * @summary Fetchdf + */ +export const fetchdfApiCommandsFetchdfPost = ( + fetchdfInput: FetchdfInput, + options?: SecondParameter, +) => { + return fetchAPI( + { + url: `/api/commands/fetchdf`, + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + data: fetchdfInput, + }, + options, + ) +} + +/** + * Renders a model's query, optionally expanding referenced models + * @summary Render + */ +export const renderApiCommandsRenderPost = ( + renderInput: RenderInput, + options?: SecondParameter, +) => { + return fetchAPI( + { + url: `/api/commands/render`, + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + data: renderInput, + }, + options, + ) +} + +/** + * Run one or all model tests + * @summary Test + */ +export const testApiCommandsTestGet = ( + params?: TestApiCommandsTestGetParams, + options?: SecondParameter, +) => { + return fetchAPI( + { url: `/api/commands/test`, method: 'GET', params }, + options, + ) +} + +/** + * Get all project files. + * @summary Get Files + */ +export const getFilesApiFilesGet = ( + options?: SecondParameter, +) => { + return fetchAPI({ url: `/api/files`, method: 'GET' }, options) +} + +/** + * Get a file, including its contents. + * @summary Get File + */ +export const getFileApiFilesPathGet = ( + path: string, + options?: SecondParameter, +) => { + return fetchAPI({ url: `/api/files/${path}`, method: 'GET' }, options) +} + +/** + * Create, update, or rename a file. 
+ * @summary Write File + */ +export const writeFileApiFilesPathPost = ( + path: string, + bodyWriteFileApiFilesPathPost: BodyWriteFileApiFilesPathPost, + options?: SecondParameter, +) => { + return fetchAPI( + { + url: `/api/files/${path}`, + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + data: bodyWriteFileApiFilesPathPost, + }, + options, + ) +} + +/** + * Delete a file. + * @summary Delete File + */ +export const deleteFileApiFilesPathDelete = ( + path: string, + options?: SecondParameter, +) => { + return fetchAPI( + { url: `/api/files/${path}`, method: 'DELETE' }, + options, + ) +} + +/** + * Create or rename a directory. + * @summary Write Directory + */ +export const writeDirectoryApiDirectoriesPathPost = ( + path: string, + bodyWriteDirectoryApiDirectoriesPathPost: BodyWriteDirectoryApiDirectoriesPathPost, + options?: SecondParameter, +) => { + return fetchAPI( + { + url: `/api/directories/${path}`, + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + data: bodyWriteDirectoryApiDirectoriesPathPost, + }, + options, + ) +} + +/** + * Delete a directory. + * @summary Delete Directory + */ +export const deleteDirectoryApiDirectoriesPathDelete = ( + path: string, + options?: SecondParameter, +) => { + return fetchAPI( + { url: `/api/directories/${path}`, method: 'DELETE' }, + options, + ) +} + +/** + * Get a plan for an environment. 
+ * @summary Initiate Plan + */ +export const initiatePlanApiPlanPost = ( + bodyInitiatePlanApiPlanPost: BodyInitiatePlanApiPlanPost, + options?: SecondParameter, +) => { + return fetchAPI( + { + url: `/api/plan`, + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + data: bodyInitiatePlanApiPlanPost, + }, + options, + ) +} + +/** + * Cancel a plan application + * @summary Cancel Plan + */ +export const cancelPlanApiPlanCancelPost = ( + options?: SecondParameter, +) => { + return fetchAPI( + { url: `/api/plan/cancel`, method: 'POST' }, + options, + ) +} + +/** + * Get the environments + * @summary Get Environments + */ +export const getEnvironmentsApiEnvironmentsGet = ( + options?: SecondParameter, +) => { + return fetchAPI( + { url: `/api/environments`, method: 'GET' }, + options, + ) +} + +/** + * Invalidate and delete an environment + * @summary Delete Environment + */ +export const deleteEnvironmentApiEnvironmentsEnvironmentDelete = ( + environment: string, + options?: SecondParameter, +) => { + return fetchAPI( + { url: `/api/environments/${environment}`, method: 'DELETE' }, + options, + ) +} + +/** + * SQLMesh console server sent events + * @summary Events + */ +export const eventsApiEventsGet = ( + options?: SecondParameter, +) => { + return fetchAPI({ url: `/api/events`, method: 'GET' }, options) +} + +/** + * Get a column's lineage + * @summary Column Lineage + */ +export const columnLineageApiLineageModelNameColumnNameGet = ( + modelName: string, + columnName: string, + params?: ColumnLineageApiLineageModelNameColumnNameGetParams, + options?: SecondParameter, +) => { + return fetchAPI( + { url: `/api/lineage/${modelName}/${columnName}`, method: 'GET', params }, + options, + ) +} + +/** + * Get a model's lineage + * @summary Model Lineage + */ +export const modelLineageApiLineageModelNameGet = ( + modelName: string, + options?: SecondParameter, +) => { + return fetchAPI( + { url: `/api/lineage/${modelName}`, method: 'GET' }, + options, + ) 
+} + +/** + * Get a list of models + * @summary Get Models + */ +export const getModelsApiModelsGet = ( + options?: SecondParameter, +) => { + return fetchAPI( + { url: `/api/models`, method: 'GET' }, + options, + ) +} + +/** + * Get a single model + * @summary Get Model + */ +export const getModelApiModelsNameGet = ( + name: string, + options?: SecondParameter, +) => { + return fetchAPI({ url: `/api/models/${name}`, method: 'GET' }, options) +} + +/** + * Get the metadata + * @summary Get Api Meta + */ +export const getApiMetaApiMetaGet = ( + options?: SecondParameter, +) => { + return fetchAPI({ url: `/api/meta`, method: 'GET' }, options) +} + +/** + * Get the modules + * @summary Get Api Modules + */ +export const getApiModulesApiModulesGet = ( + options?: SecondParameter, +) => { + return fetchAPI({ url: `/api/modules`, method: 'GET' }, options) +} + +/** + * Calculate differences between tables, taking into account schema and row level differences. + * @summary Get Table Diff + */ +export const getTableDiffApiTableDiffGet = ( + params: GetTableDiffApiTableDiffGetParams, + options?: SecondParameter, +) => { + return fetchAPI( + { url: `/api/table_diff`, method: 'GET', params }, + options, + ) +} + +/** + * @summary Health + */ +export const healthHealthGet = (options?: SecondParameter) => { + return fetchAPI({ url: `/health`, method: 'GET' }, options) +} + +/** + * @summary Index + */ +export const indexFullPathGet = ( + fullPath: string, + options?: SecondParameter, +) => { + return fetchAPI({ url: `/${fullPath}`, method: 'GET' }, options) +} + +export type InitiateApplyApiCommandsApplyPostResult = NonNullable< + Awaited> +> +export type EvaluateApiCommandsEvaluatePostResult = NonNullable< + Awaited> +> +export type FetchdfApiCommandsFetchdfPostResult = NonNullable< + Awaited> +> +export type RenderApiCommandsRenderPostResult = NonNullable< + Awaited> +> +export type TestApiCommandsTestGetResult = NonNullable< + Awaited> +> +export type 
GetFilesApiFilesGetResult = NonNullable< + Awaited> +> +export type GetFileApiFilesPathGetResult = NonNullable< + Awaited> +> +export type WriteFileApiFilesPathPostResult = NonNullable< + Awaited> +> +export type DeleteFileApiFilesPathDeleteResult = NonNullable< + Awaited> +> +export type WriteDirectoryApiDirectoriesPathPostResult = NonNullable< + Awaited> +> +export type DeleteDirectoryApiDirectoriesPathDeleteResult = NonNullable< + Awaited> +> +export type InitiatePlanApiPlanPostResult = NonNullable< + Awaited> +> +export type CancelPlanApiPlanCancelPostResult = NonNullable< + Awaited> +> +export type GetEnvironmentsApiEnvironmentsGetResult = NonNullable< + Awaited> +> +export type DeleteEnvironmentApiEnvironmentsEnvironmentDeleteResult = + NonNullable< + Awaited< + ReturnType + > + > +export type EventsApiEventsGetResult = NonNullable< + Awaited> +> +export type ColumnLineageApiLineageModelNameColumnNameGetResult = NonNullable< + Awaited> +> +export type ModelLineageApiLineageModelNameGetResult = NonNullable< + Awaited> +> +export type GetModelsApiModelsGetResult = NonNullable< + Awaited> +> +export type GetModelApiModelsNameGetResult = NonNullable< + Awaited> +> +export type GetApiMetaApiMetaGetResult = NonNullable< + Awaited> +> +export type GetApiModulesApiModulesGetResult = NonNullable< + Awaited> +> +export type GetTableDiffApiTableDiffGetResult = NonNullable< + Awaited> +> +export type HealthHealthGetResult = NonNullable< + Awaited> +> +export type IndexFullPathGetResult = NonNullable< + Awaited> +> diff --git a/vscode/react/src/api/index.ts b/vscode/react/src/api/index.ts new file mode 100644 index 0000000000..0ba0314668 --- /dev/null +++ b/vscode/react/src/api/index.ts @@ -0,0 +1,103 @@ +import { + type UseQueryResult, + useQuery, + type QueryMeta, +} from '@tanstack/react-query' +import { + getModelsApiModelsGet, + type ModelLineageApiLineageModelNameGet200, + modelLineageApiLineageModelNameGet, + type ColumnLineageApiLineageModelNameColumnNameGet200, + 
columnLineageApiLineageModelNameColumnNameGet, + type Meta, + getApiMetaApiMetaGet, + type GetModelsApiModelsGet200, + type ApiExceptionPayload, + type Model, + getModelApiModelsNameGet, + type ColumnLineageApiLineageModelNameColumnNameGetParams, +} from './client' + +export interface ApiOptions { + delay?: number + trigger?: string + removeTimeoutErrorAfter?: number +} + +export interface ApiQueryOptions { + enabled?: boolean +} + +export interface ApiQueryMeta extends QueryMeta { + onError: (error: ApiExceptionPayload) => void + onSuccess: () => void +} + +export type UseQueryWithTimeoutOptions< + TData = any, + TError extends ApiExceptionPayload = ApiExceptionPayload, +> = UseQueryResult & { + cancel: () => void + isTimeout: boolean +} + +export function useApiMeta(): UseQueryResult { + return useQuery({ + queryKey: ['/api/meta'], + queryFn: getApiMetaApiMetaGet, + enabled: true, + }) +} + +export function useApiModels(): UseQueryResult { + return useQuery({ + queryKey: ['/api/models'], + queryFn: getModelsApiModelsGet, + }) +} + +export function useApiModel(modelName: string): UseQueryResult { + return useQuery({ + queryKey: ['/api/models', modelName], + queryFn: async ({ signal }) => + await getModelApiModelsNameGet(modelName, { signal }), + }) +} + +export function useApiModelLineage( + modelName: string, +): UseQueryResult { + return useQuery({ + queryKey: ['/api/lineage', modelName], + queryFn: async ({ signal }) => { + try { + const response = await modelLineageApiLineageModelNameGet(modelName, { + signal, + }) + return response + } catch (error) { + console.error('error fetching lineage', error) + throw error + } + }, + }) +} + +export function useApiColumnLineage( + model: string, + column: string, + params?: ColumnLineageApiLineageModelNameColumnNameGetParams, +): UseQueryResult { + return useQuery({ + queryKey: ['/api/lineage', model, column], + queryFn: async ({ signal }) => + await columnLineageApiLineageModelNameColumnNameGet( + model, + column, + 
params, + { + signal, + }, + ), + }) +} diff --git a/vscode/react/src/api/instance.ts b/vscode/react/src/api/instance.ts new file mode 100644 index 0000000000..3627b273de --- /dev/null +++ b/vscode/react/src/api/instance.ts @@ -0,0 +1,53 @@ +import { callRpc } from '@/utils/rpc' +import { isErr } from '@bus/result' + +declare global { + interface Window { + __BASE_URL__?: string + } +} + +interface ResponseWithDetail { + ok: boolean + detail?: string +} + +interface FetchOptionsWithSignal { + signal?: AbortSignal +} + +interface FetchOptions { + url: string + method: 'GET' | 'POST' | 'PUT' | 'DELETE' | 'PATCH' + data?: B + responseType?: string + headers?: Record + credentials?: 'omit' | 'same-origin' | 'include' + mode?: 'cors' | 'no-cors' | 'same-origin' + cache?: + | 'default' + | 'no-store' + | 'reload' + | 'no-cache' + | 'force-cache' + | 'only-if-cached' + params?: Record +} + +export async function fetchAPI( + config: FetchOptions, + _options?: Partial, +): Promise { + const request = { + url: config.url, + method: config.method, + params: config.params, + body: config.data, + } + const result = await callRpc('api_query', request) + if (isErr(result)) { + throw new Error(result.error) + } + const response = result.value.data + return response +} diff --git a/vscode/react/src/components/button/Button.tsx b/vscode/react/src/components/button/Button.tsx new file mode 100644 index 0000000000..b84bd93a33 --- /dev/null +++ b/vscode/react/src/components/button/Button.tsx @@ -0,0 +1,222 @@ +import React from 'react' +import clsx from 'clsx' +import { + EnumSize, + type Size, + type Variant, + type EnumVariant, +} from '@/style/variants' + +/** + * Utility type that restricts a type T to only include keys from type K + * This allows for creating a subset of a union type + */ +export type Subset = T extends K ? 
T : never + +export type ButtonVariant = Subset< + Variant, + | typeof EnumVariant.Primary + | typeof EnumVariant.Secondary + | typeof EnumVariant.Success + | typeof EnumVariant.Danger + | typeof EnumVariant.Warning + | typeof EnumVariant.Alternative + | typeof EnumVariant.Neutral + | typeof EnumVariant.Info +> + +export type ButtonSize = Subset< + Size, + | typeof EnumSize.xs + | typeof EnumSize.sm + | typeof EnumSize.md + | typeof EnumSize.lg +> + +export const EnumButtonShape = { + Square: 'square', + Rounded: 'rounded', + Circle: 'circle', + Pill: 'pill', +} as const + +export const EnumButtonFormat = { + Solid: 'solid', + Outline: 'outline', + Ghost: 'ghost', + Link: 'link', +} as const + +export type ButtonShape = (typeof EnumButtonShape)[keyof typeof EnumButtonShape] +export type ButtonFormat = + (typeof EnumButtonFormat)[keyof typeof EnumButtonFormat] + +export interface PropsButton + extends React.ButtonHTMLAttributes { + variant?: ButtonVariant + size?: ButtonSize + shape?: ButtonShape + format?: ButtonFormat + value?: string + form?: string + onClick?: (e: React.MouseEvent) => void +} + +const VARIANT = new Map([ + [ + 'primary', + 'border-primary-500 bg-primary-500 hover:bg-primary-400 active:bg-primary-400 text-light', + ], + [ + 'alternative', + 'border-neutral-300 bg-neutral-5 hover:bg-neutral-20 active:bg-neutral-200 text-neutral-600', + ], + [ + 'secondary', + 'border-secondary-500 bg-secondary-500 hover:bg-secondary-600 active:bg-secondary-400 text-neutral-100', + ], + [ + 'success', + 'border-success-500 bg-success-500 hover:bg-success-600 active:bg-success-400 text-neutral-100', + ], + [ + 'danger', + 'border-danger-500 bg-danger-500 hover:bg-danger-600 active:bg-danger-400 text-neutral-100', + ], + [ + 'warning', + 'border-warning-500 bg-warning-500 hover:bg-warning-600 active:bg-warning-400 text-neutral-100', + ], + [ + 'neutral', + 'border-neutral-200 bg-neutral-200 hover:bg-neutral-300 active:bg-neutral-300 text-primary-900', + ], + [ + 
'info', + 'border-transparent bg-neutral-10 dark:bg-neutral-20 active:bg-neutral-10 text-neutral-700 dark:text-neutral-200', + ], +]) + +const SHAPE = new Map([ + ['rounded', `rounded-md`], + ['square', `rounded-none`], + ['circle', `rounded-full`], +]) + +const SIZE = new Map([ + [EnumSize.xs, `text-xs leading-2 border`], + [EnumSize.sm, `px-2 py-[0.125rem] text-xs leading-4 border-2`], + [EnumSize.md, `px-3 py-2 text-base leading-6 border-2`], + [EnumSize.lg, `px-4 py-3 text-lg border-4`], +]) + +const Button = makeButton( + React.forwardRef(ButtonPlain), +) + +const ButtonLink = makeButton( + React.forwardRef(ButtonLinkPlain), +) + +export { VARIANT, SHAPE, SIZE, Button, ButtonLink, makeButton } + +function ButtonPlain( + { + type = 'button', + disabled = false, + children = [], + form, + autoFocus, + tabIndex, + onClick, + className, + }: PropsButton, + ref?: React.ForwardedRef, +): JSX.Element { + return ( + + ) +} + +function ButtonLinkPlain( + { children = [], autoFocus, tabIndex, className }: PropsButton, + ref?: React.ForwardedRef, +): JSX.Element { + return ( +
+ {children} +
+ ) +} + +function makeButton( + Component: React.ElementType, +): React.ForwardRefExoticComponent< + PropsButton & React.RefAttributes +> { + return React.forwardRef(function Wrapper( + { + type = 'button', + disabled = false, + variant = 'primary', + shape = 'rounded', + size = EnumSize.md, + children = [], + className, + form, + autoFocus, + tabIndex, + onClick, + }: PropsButton, + ref?: React.ForwardedRef, + ): JSX.Element { + return ( + + {children} + + ) + }) +} diff --git a/vscode/react/src/components/graph/CogIcon.tsx b/vscode/react/src/components/graph/CogIcon.tsx new file mode 100644 index 0000000000..75d3e952bd --- /dev/null +++ b/vscode/react/src/components/graph/CogIcon.tsx @@ -0,0 +1,27 @@ +import * as React from 'react' + +/** + * CogIcon as taken from https://heroicons.com/. Slightly modified to remove fill color. + * + * @param props - SVG props + * @returns SVG element + */ +export function CogIcon(props: React.SVGProps): JSX.Element { + return ( + + ) +} diff --git a/vscode/react/src/components/graph/Graph.css b/vscode/react/src/components/graph/Graph.css new file mode 100644 index 0000000000..bef07d33b8 --- /dev/null +++ b/vscode/react/src/components/graph/Graph.css @@ -0,0 +1,49 @@ +.react-flow__node { + z-index: 10 !important; + background-color: var(--vscode-editor-background, #ffffff); +} +.react-flow__handle { + background-color: currentColor; +} +.react-flow__node.react-flow__node-model:hover, +.react-flow__node.react-flow__node-model:active { + z-index: 20 !important; +} +react-flow__attribution { + background: transparent; +} +.lineage__column-source b { + font-weight: 900; + /* color: var(--color-primary); */ + color: black; +} +.react-flow__edge { + pointer-events: none !important; + z-index: -1 !important; +} +.react-flow__background path { + stroke: inherit; +} + +.react-flow__controls-button { + box-shadow: none; + border: var(--vscode-button-border); + background: var(--vscode-button-background); + color: var(--vscode-foreground); 
+} +.react-flow__controls-button:hover { + background: var(--vscode-button-hoverBackground); + border: var(--vscode-button-hoverBorder); +} +.react-flow__controls-button:active { + background: var(--vscode-button-activeBackground); + border: var(--vscode-button-activeBorder); +} +.react-flow__controls { + box-shadow: none; +} +.react-flow__panel { + box-shadow: none !important; + border: none !important; + padding: 0 !important; +} diff --git a/vscode/react/src/components/graph/ModelColumns.tsx b/vscode/react/src/components/graph/ModelColumns.tsx new file mode 100644 index 0000000000..e0e180de51 --- /dev/null +++ b/vscode/react/src/components/graph/ModelColumns.tsx @@ -0,0 +1,585 @@ +import React, { useEffect, useMemo, useCallback } from 'react' +import { Handle, Position, useUpdateNodeInternals } from 'reactflow' +import 'reactflow/dist/base.css' +import { mergeLineageWithColumns, mergeConnections } from './help' +import { + debounceSync, + isArrayNotEmpty, + isFalse, + isNil, + isNotNil, + truncate, +} from '@/utils/index' +import { toID, type PartialColumnHandleId, type Side } from './types' +import { NoSymbolIcon } from '@heroicons/react/24/solid' +import { ClockIcon, ExclamationCircleIcon } from '@heroicons/react/24/outline' +import clsx from 'clsx' +import { + type ColumnDescription, + type ColumnLineageApiLineageModelNameColumnNameGet200, + type LineageColumn, +} from '@/api/client' +import Loading from '@/components/loading/Loading' +import Spinner from '@/components/logo/Spinner' +import './Graph.css' +import { + type InitialSQLMeshModel, + type ModelSQLMeshModel, +} from '@/domain/sqlmesh-model' +import { useLineageFlow } from './context' +import { useApiColumnLineage } from '@/api/index' +import SourceList from '@/components/sourceList/SourceList' +import type { Lineage } from '@/domain/lineage' +import type { Column, ColumnName } from '@/domain/column' +import type { ModelEncodedFQN } from '@/domain/models' + +export function ModelColumns({ + nodeId, + 
columns, + disabled, + className, + limit, + withHandles = false, + withDescription = true, + maxHeight = '50vh', +}: { + nodeId: ModelEncodedFQN + columns: Column[] + disabled?: boolean + className?: string + limit: number + withHandles?: boolean + withDescription?: boolean + maxHeight?: string +}): JSX.Element { + const { + mainNode, + connections, + isActiveColumn, + setConnections, + manuallySelectedColumn, + setManuallySelectedColumn, + setLineage, + removeActiveEdges, + addActiveEdges, + lineage, + lineageCache, + setLineageCache, + } = useLineageFlow() + + const [columnsSelected = [], columnsRest = []] = useMemo(() => { + const active: Column[] = [] + const rest: Column[] = [] + + columns.forEach(column => { + if (isActiveColumn(nodeId, column.name)) { + active.push(column) + } else { + rest.push(column) + } + }) + + return [active, rest] + }, [nodeId, columns, isActiveColumn]) + + function updateColumnLineage( + columnLineage: Record> = {}, + ): void { + let newLineageCache = lineageCache + let currentConnections + let currentLineage + + if (isNil(lineageCache)) { + const mainNodeLineage = isNil(mainNode) + ? undefined + : (lineage[mainNode] ?? lineageCache?.[mainNode]) + + newLineageCache = lineage + currentConnections = new Map() + currentLineage = + isNil(mainNode) || isNil(mainNodeLineage) + ? 
{} + : { [mainNode]: { models: [] } } + } else { + currentConnections = connections + currentLineage = structuredClone(lineage) + } + + const { connections: newConnections, activeEdges } = mergeConnections( + currentConnections, + columnLineage, + ) + + if (newConnections.size === 0 && activeEdges.length === 0) { + currentLineage = structuredClone(lineage) + newLineageCache = undefined + } else { + setConnections(newConnections) + addActiveEdges(activeEdges) + } + + const mergedLineage = mergeLineageWithColumns(currentLineage, columnLineage) + + setLineageCache(newLineageCache) + setLineage(mergedLineage) + } + + const isSelectManually = useCallback( + function isSelectManually(columnName: ColumnName): boolean { + if (isNil(manuallySelectedColumn)) return false + + const [selectedModel, selectedColumn] = manuallySelectedColumn + + if (isNil(selectedModel) || isNil(selectedColumn)) return false + + return selectedModel.fqn === nodeId && selectedColumn.name === columnName + }, + [nodeId, manuallySelectedColumn], + ) + + const removeEdges = useCallback( + function removeEdges(columnId: PartialColumnHandleId): void { + const visited = new Set() + + removeActiveEdges(walk(columnId, 'left').concat(walk(columnId, 'right'))) + + if (connections.size === 0 && isNotNil(lineageCache)) { + setLineage(lineageCache) + setLineageCache(undefined) + } + + setConnections(connections) + + function walk(id: string, side: Side): Array<[string, string]> { + if (visited.has(id)) return [] + + const edges = connections.get(id)?.[side] ?? [] + + connections.delete(id) + + visited.add(id) + + return edges + .map(edge => + [ + side === 'left' + ? [toID('left', id), toID('right', edge)] + : [toID('left', edge), toID('right', id)], + ].concat(walk(edge, side)), + ) + .flat() as Array<[PartialColumnHandleId, PartialColumnHandleId]> + } + }, + [removeActiveEdges, connections], + ) + + return ( +
+ {isArrayNotEmpty(columnsSelected) && ( +
+ {columnsSelected.map(column => ( + + ))} +
+ )} + {columnsRest.length <= limit && ( +
+ {columnsRest.map(column => ( + + ))} +
+ )} + {columnsRest.length > limit && ( +
+ + keyId="name" + keyName="name" + items={columnsRest} + withCounter={false} + withFilter={columnsRest.length > limit} + disabled={disabled} + listItem={({ disabled, item }) => ( + + )} + /> +
+ )} +
+ ) +} + +function ModelColumn({ + id, + nodeId, + column, + className, + disabled = false, + isActive = false, + hasLeft = false, + hasRight = false, + isEmpty = false, + updateColumnLineage, + removeEdges, + selectManually, + withHandles = false, + withDescription = true, +}: { + id: PartialColumnHandleId + nodeId: ModelEncodedFQN + column: Column + disabled?: boolean + isActive?: boolean + hasLeft?: boolean + hasRight?: boolean + isEmpty?: boolean + withHandles?: boolean + withDescription?: boolean + updateColumnLineage: ( + lineage: ColumnLineageApiLineageModelNameColumnNameGet200, + ) => void + removeEdges: (columnId: PartialColumnHandleId) => void + selectManually?: React.Dispatch< + React.SetStateAction< + [ModelSQLMeshModel, Column] | undefined + > + > + className?: string +}): JSX.Element { + const { + refetch: getColumnLineage, + isFetching, + isError, + } = useApiColumnLineage(nodeId, column.name, { models_only: true }) + + useEffect(() => { + if (isNil(selectManually)) return + + toggleColumnLineage() + selectManually(undefined) + }, [selectManually]) + + function toggleColumnLineage(): void { + if (disabled) return + + if (isActive) { + removeEdges(id) + } else { + void getColumnLineage().then(({ data }) => + updateColumnLineage(data ?? {}), + ) + } + } + + const showHandles = withHandles && (hasLeft || hasRight) + + return ( +
+
+ {showHandles ? ( + + + + + ) : ( + <> + + + + )} +
+
+ ) +} + +function ColumnHandles({ + nodeId, + id, + hasLeft = false, + hasRight = false, + disabled = false, + children, + className, +}: { + nodeId: ModelEncodedFQN + id: PartialColumnHandleId + children: React.ReactNode + className?: string + hasLeft?: boolean + hasRight?: boolean + disabled?: boolean +}): JSX.Element { + const updateNodeInternals = useUpdateNodeInternals() + + useEffect(() => { + // TODO: This is a hack to fix the issue where the handles are not rendered yet + setTimeout(() => { + updateNodeInternals(nodeId) + }, 100) + }, [hasLeft, hasRight]) + + return ( +
+ {hasLeft && ( + + )} + {children} + {hasRight && ( + + )} +
+ ) +} + +function ColumnDisplay({ + columnName, + columnType, + columnDescription, + className, + disabled = false, + withDescription = true, +}: { + columnName: ColumnName + columnType: string + columnDescription?: ColumnDescription + disabled?: boolean + withDescription?: boolean + className?: string +}): JSX.Element { + return ( +
+
+ + {disabled && ( + + )} + {truncate(columnName, 50, 20)} + + + {truncate(columnType, 20, 10)} + +
+ {isNotNil(columnDescription) && withDescription && ( +

{columnDescription}

+ )} +
+ ) +} + +function ColumnStatus({ + isFetching = false, + isError = false, + isTimeout = false, +}: { + isFetching: boolean + isError: boolean + isTimeout: boolean +}): JSX.Element { + return ( + <> + {isFetching && ( + + + + )} + {isTimeout && isFalse(isFetching) && ( + + )} + {isError && isFalse(isFetching) && ( + + )} + + ) +} + +function getColumnFromLineage( + lineage: Record, + nodeId: string, + columnName: string, +): LineageColumn | undefined { + return lineage?.[nodeId]?.columns?.[columnName as ColumnName] +} diff --git a/vscode/react/src/components/graph/ModelLineage.tsx b/vscode/react/src/components/graph/ModelLineage.tsx new file mode 100644 index 0000000000..3d157d3869 --- /dev/null +++ b/vscode/react/src/components/graph/ModelLineage.tsx @@ -0,0 +1,413 @@ +import { useApiModelLineage, useApiModels } from '@/api/index' +import { useEffect, useMemo, useState } from 'react' +import { type ModelSQLMeshModel } from '@/domain/sqlmesh-model' +import { type HighlightedNodes, useLineageFlow } from './context' +import ReactFlow, { + Controls, + Background, + BackgroundVariant, + type EdgeChange, + applyEdgeChanges, + applyNodeChanges, + type NodeChange, + useReactFlow, + type Edge, + type Node, + ReactFlowProvider, +} from 'reactflow' +import Loading from '@/components/loading/Loading' +import Spinner from '@/components/logo/Spinner' +import { createLineageWorker } from '@/components/graph/workers/index' +import { isArrayEmpty, isFalse, isNil, isNotNil } from '@/utils/index' +import clsx from 'clsx' +import ModelNode from './ModelNode' +import { + getEdges, + getLineageIndex, + getActiveNodes, + getUpdatedNodes, + getUpdatedEdges, + createGraphLayout, +} from './help' +import { SettingsControl } from '@/components/graph/SettingsControl' +import { + toModelLineage, + type ModelLineage as ModelLineageType, +} from '@/domain/lineage' +import './Graph.css' +import { + toKeys, + type LineageWorkerMessage, + type LineageWorkerRequestMessage, + type 
LineageWorkerResponseMessage, + type LineageWorkerErrorMessage, +} from './types' +import { encode } from '@/domain/models' + +const WITH_COLUMNS_LIMIT = 30 + +export function ModelLineage({ + model, + highlightedNodes, +}: { + model: ModelSQLMeshModel + highlightedNodes?: HighlightedNodes +}): JSX.Element { + const { + setActiveNodes, + setActiveEdges, + setConnections, + setLineage, + handleError, + setSelectedNodes, + setMainNode, + setWithColumns, + setHighlightedNodes, + setNodeConnections, + setLineageCache, + setUnknownModels, + models, + unknownModels, + setWithSecondary, + setWithConnected, + setWithImpacted, + } = useLineageFlow() + + useEffect(() => { + setWithColumns(true) + setWithSecondary(true) + setWithConnected(true) + setWithImpacted(true) + }, [setWithSecondary]) + + const { refetch: getModelLineage, isFetching: isFetchingModelLineage } = + useApiModelLineage(model.name) + const { isFetching: isFetchingModels } = useApiModels() + + const [isMergingModels, setIsMergingModels] = useState(false) + const [modelLineage, setModelLineage] = useState< + ModelLineageType | undefined + >(undefined) + + useEffect(() => { + const lineageWorker = new createLineageWorker() + + lineageWorker.addEventListener('message', handleLineageWorkerMessage) + + getModelLineage() + .then(({ data }) => { + setModelLineage(data ? 
toModelLineage(data) : undefined) + if (isNil(data)) return + + setIsMergingModels(true) + + const message: LineageWorkerRequestMessage = { + topic: 'lineage', + payload: { + currentLineage: {}, + newLineage: data, + mainNode: model.fqn, + }, + } + lineageWorker.postMessage(message) + }) + .catch(error => { + handleError?.(error) + }) + .finally(() => { + setActiveNodes(new Set()) + setActiveEdges(new Map()) + setConnections(new Map()) + setSelectedNodes(new Set()) + setLineageCache(undefined) + setMainNode(model.fqn) + }) + + return () => { + lineageWorker.removeEventListener('message', handleLineageWorkerMessage) + lineageWorker.terminate() + + setLineage({}) + setNodeConnections({}) + setMainNode(undefined) + setHighlightedNodes({}) + } + }, [model.name, model.hash]) + + useEffect(() => { + const modelNames = toKeys(modelLineage ?? {}) + for (const modelName of modelNames) { + const encodedModelName = encode(modelName) + if ( + isFalse(encodedModelName in models) && + isFalse(encodedModelName in unknownModels) + ) { + unknownModels.add(encodedModelName) + } + } + + setUnknownModels(new Set(unknownModels)) + }, [modelLineage, models]) + + useEffect(() => { + setHighlightedNodes(highlightedNodes ?? {}) + }, [highlightedNodes]) + + function handleLineageWorkerMessage( + e: MessageEvent, + ): void { + if (e.data.topic === 'lineage') { + const message = e.data as LineageWorkerResponseMessage + setIsMergingModels(false) + setNodeConnections(message.payload.nodesConnections) + setLineage(message.payload.lineage) + + if ( + Object.values(message.payload.lineage ?? {}).length > WITH_COLUMNS_LIMIT + ) { + setWithColumns(false) + } + } + + if (e.data.topic === 'error') { + const message = e.data as LineageWorkerErrorMessage + handleError?.(message.error) + setIsMergingModels(false) + } + } + + const isFetching = + isFetchingModelLineage || isFetchingModels || isMergingModels + + return ( +
+ {isFetching && ( +
+ + + +

+ {isFetching ? "Loading Model's Lineage..." : "Merging Model's..."} +

+
+
+ )} + + + +
+ ) +} + +function ModelColumnLineage(): JSX.Element { + const { + withColumns, + lineage, + mainNode, + selectedEdges, + selectedNodes, + withConnected, + withImpacted, + withSecondary, + hasBackground, + activeEdges, + connectedNodes, + connections, + nodesMap, + handleError, + setActiveNodes, + setWithColumns, + } = useLineageFlow() + + const { setCenter } = useReactFlow() + + const [isBuildingLayout, setIsBuildingLayout] = useState(false) + + const nodeTypes = useMemo(() => ({ model: ModelNode }), []) + + const allEdges = useMemo(() => getEdges(lineage), [lineage]) + const lineageIndex = useMemo(() => getLineageIndex(lineage), [lineage]) + + const [nodes, setNodes] = useState([]) + const [edges, setEdges] = useState([]) + + useEffect(() => { + if (isArrayEmpty(allEdges) || isNil(mainNode)) return + + setIsBuildingLayout(true) + + const newActiveNodes = getActiveNodes( + allEdges, + activeEdges, + selectedEdges, + nodesMap, + ) + const newNodes = getUpdatedNodes( + Object.values(nodesMap), + newActiveNodes, + mainNode, + connectedNodes, + selectedNodes, + connections, + withConnected, + withImpacted, + withSecondary, + ) + const newEdges = getUpdatedEdges( + allEdges, + connections, + activeEdges, + newActiveNodes, + selectedEdges, + selectedNodes, + connectedNodes, + withConnected, + withImpacted, + withSecondary, + ) + const createLayout = createGraphLayout({ + nodesMap, + nodes: newNodes, + edges: newEdges, + }) + + createLayout + .create() + .then(layout => { + setEdges(layout.edges) + setNodes(layout.nodes) + }) + .catch(error => { + handleError?.(error) + setEdges([]) + setNodes([]) + }) + .finally(() => { + const node = isNil(mainNode) ? 
undefined : nodesMap[mainNode] + + if (isNotNil(node)) { + setCenter(node.position.x, node.position.y, { + zoom: 0.5, + duration: 0, + }) + } + + setTimeout(() => { + setIsBuildingLayout(false) + }, 100) + }) + + return () => { + createLayout.terminate() + + setEdges([]) + setNodes([]) + } + }, [activeEdges, nodesMap, lineageIndex]) + + useEffect(() => { + if (isNil(mainNode) || isArrayEmpty(nodes)) return + + const newActiveNodes = getActiveNodes( + allEdges, + activeEdges, + selectedEdges, + nodesMap, + ) + const newNodes = getUpdatedNodes( + nodes, + newActiveNodes, + mainNode, + connectedNodes, + selectedNodes, + connections, + withConnected, + withImpacted, + withSecondary, + ) + + const newEdges = getUpdatedEdges( + allEdges, + connections, + activeEdges, + newActiveNodes, + selectedEdges, + selectedNodes, + connectedNodes, + withConnected, + withImpacted, + withSecondary, + ) + + setEdges(newEdges) + setNodes(newNodes) + setActiveNodes(newActiveNodes) + }, [ + connections, + nodesMap, + allEdges, + activeEdges, + selectedNodes, + selectedEdges, + connectedNodes, + withConnected, + withImpacted, + withSecondary, + withColumns, + mainNode, + ]) + + function onNodesChange(changes: NodeChange[]): void { + setNodes(applyNodeChanges(changes, nodes)) + } + + function onEdgesChange(changes: EdgeChange[]): void { + setEdges(applyEdgeChanges(changes, edges)) + } + + return ( + <> + {isBuildingLayout && ( +
+ + + +

Building Lineage...

+
+
+ )} + + + + + + + + ) +} diff --git a/vscode/react/src/components/graph/ModelNode.tsx b/vscode/react/src/components/graph/ModelNode.tsx new file mode 100644 index 0000000000..864b1437fa --- /dev/null +++ b/vscode/react/src/components/graph/ModelNode.tsx @@ -0,0 +1,234 @@ +import { isNil, isArrayNotEmpty, isNotNil, isFalse } from '@/utils/index' +import clsx from 'clsx' +import { useMemo, useCallback, useState } from 'react' +import { ModelType, type Model } from '@/api/client' +import { useLineageFlow } from './context' +import { type GraphNodeData } from './help' +import { Position, type NodeProps } from 'reactflow' +import { ModelNodeHeaderHandles } from './ModelNodeHeaderHandles' +import { ModelColumns } from './ModelColumns' +import { fromAPIColumn, type Column } from '@/domain/column' +import { decode, type ModelEncodedFQN } from '@/domain/models' +import { toKeys } from './types' +import { MAX_VISIBLE_COLUMNS } from './constants' + +export const EnumLineageNodeModelType = { + ...ModelType, + cte: 'cte', + unknown: 'unknown', +} as const + +export const EnumColumnType = { + UNKNOWN: 'UNKNOWN', + STRUCT: 'STRUCT', +} as const + +export type LineageNodeModelType = keyof typeof EnumLineageNodeModelType +export type ColumnType = keyof typeof EnumColumnType + +export default function ModelNode({ + id: idProp, + data, + sourcePosition, + targetPosition, +}: NodeProps): JSX.Element { + const id = idProp as ModelEncodedFQN + const nodeData: GraphNodeData = data ?? 
{ + label: '', + type: EnumLineageNodeModelType.unknown, + withColumns: false, + } + const { + // connections, + models, + handleClickModel, + lineage, + lineageCache, + selectedNodes, + setSelectedNodes, + mainNode, + withConnected, + connectedNodes, + highlightedNodes, + activeNodes, + } = useLineageFlow() + + const columns: Column[] = useMemo(() => { + const modelsArray = Object.values(models) + const decodedId = decode(id) + const model = modelsArray.find((m: Model) => m.fqn === decodedId) + const modelColumns = model?.columns?.map(fromAPIColumn) ?? [] + + toKeys(lineage[decodedId]?.columns ?? {}).forEach(column => { + const found = modelColumns.find(({ name }) => name === column) + if (isNil(found)) { + modelColumns.push( + fromAPIColumn({ name: column, type: EnumColumnType.UNKNOWN }), + ) + } + }) + return modelColumns.map(column => { + let columnType = column.type ?? EnumColumnType.UNKNOWN + if (columnType.startsWith(EnumColumnType.STRUCT)) { + columnType = EnumColumnType.STRUCT + } + return { + ...column, + type: columnType, + } + }) + }, [id, models, lineage]) + + const highlightedNodeModels = useMemo( + () => Object.values(highlightedNodes).flat(), + [highlightedNodes], + ) + + const [isMouseOver, setIsMouseOver] = useState(false) + + const handleClick = useCallback( + (e: React.MouseEvent) => { + e.stopPropagation() + if (handleClickModel) { + handleClickModel(id) + } + }, + [handleClickModel, id, data.isInteractive], + ) + + const handleSelect = useCallback( + (e: React.MouseEvent) => { + e.stopPropagation() + + if (highlightedNodeModels.includes(id) || mainNode === id) return + + setSelectedNodes(current => { + if (current.has(id)) { + current.delete(id) + } else { + current.add(id) + } + + return new Set(current) + }) + }, + [setSelectedNodes, highlightedNodeModels], + ) + + const splat = highlightedNodes['*'] + // const hasSelectedColumns = columns.some(({ name }) => + // connections.get(toID(id, name)), + // ) + const hasHighlightedNodes = 
Object.keys(highlightedNodes).length > 0 + const highlighted = Object.keys(highlightedNodes).find(key => + highlightedNodes[key]!.includes(id), + ) + const isMainNode = mainNode === id + const isHighlightedNode = highlightedNodeModels.includes(id) + const isSelected = selectedNodes.has(id) + // Ensure nodeData.type is a valid LineageNodeModelType + const nodeType: LineageNodeModelType = Object.values( + EnumLineageNodeModelType, + ).includes(nodeData.type) + ? (nodeData.type as LineageNodeModelType) + : EnumLineageNodeModelType.unknown + + const isModelSQL = nodeType === EnumLineageNodeModelType.sql + const isCTE = nodeType === EnumLineageNodeModelType.cte + const isModelExternal = nodeType === EnumLineageNodeModelType.external + const isModelSeed = nodeType === EnumLineageNodeModelType.seed + const isModelUnknown = nodeType === EnumLineageNodeModelType.unknown + const showColumns = + nodeData.withColumns && + isArrayNotEmpty(columns) && + isFalse(hasHighlightedNodes) + const isActiveNode = + selectedNodes.size > 0 || activeNodes.size > 0 || withConnected + ? isSelected || + activeNodes.has(id as ModelEncodedFQN) || + (withConnected && connectedNodes.has(id)) + : connectedNodes.has(id) + const isInteractive = true + // mainNode !== id && + // isNotNil(handleClickModel) && + // isFalse(isCTE) && + // isFalse(isModelUnknown) + const shouldDisableColumns = isFalse(isModelSQL) + + return ( +
setIsMouseOver(true)} + onMouseLeave={() => setIsMouseOver(false)} + className={clsx( + 'text-xs font-semibold border-4', + isMouseOver ? 'z-50' : 'z-1', + showColumns ? 'rounded-xl' : 'rounded-2xl', + (hasHighlightedNodes ? isHighlightedNode : isActiveNode) || isMainNode + ? 'opacity-100' + : 'opacity-40 hover:opacity-100', + isNil(highlighted) + ? hasHighlightedNodes + ? splat + : [ + isCTE + ? 'border-accent-500 bg-accent-500 text-accent-500 dark:border-accent-300 dark:bg-accent-300 dark:text-accent-300' + : isModelUnknown + ? 'border-neutral-500 bg-neutral-500 text-neutral-500 dark:border-neutral-300 dark:bg-neutral-300 dark:text-neutral-300' + : 'border-secondary-500 bg-secondary-500 text-secondary-500 dark:bg-primary-500 dark:border-primary-500 dark:text-primary-500', + isMainNode + ? 'ring-8 ring-brand-50' + : isModelExternal || isModelSeed + ? 'ring-8 ring-accent-50' + : '', + ] + : highlighted, + isSelected && isCTE + ? 'ring-8 ring-accent-50' + : isSelected && isModelUnknown + ? 'ring-8 ring-neutral-50' + : isSelected && 'ring-8 ring-secondary-50 dark:ring-primary-50', + )} + style={{ + maxWidth: isNil(nodeData.width) + ? 'auto' + : `${nodeData.width as number}px`, + }} + > + + {showColumns && ( + + )} +
+ ) +} diff --git a/vscode/react/src/components/graph/ModelNodeHeaderHandles.tsx b/vscode/react/src/components/graph/ModelNodeHeaderHandles.tsx new file mode 100644 index 0000000000..a23d6af5c4 --- /dev/null +++ b/vscode/react/src/components/graph/ModelNodeHeaderHandles.tsx @@ -0,0 +1,111 @@ +import { type MouseEvent } from 'react' +import { Handle, Position } from 'reactflow' +import 'reactflow/dist/base.css' +import { getModelNodeTypeTitle } from './help' +import { isNotNil, truncate } from '@/utils/index' +import { toID } from './types' +import { ArrowRightCircleIcon } from '@heroicons/react/24/solid' +import clsx from 'clsx' +import { type LineageNodeModelType } from './ModelNode' +import type { ModelEncodedFQN } from '@/domain/models' + +export function ModelNodeHeaderHandles({ + id, + className, + hasLeft = false, + hasRight = false, + isSelected = false, + isDraggable = false, + label, + type, + numberOfColumns, + handleClick, + handleSelect, +}: { + id: ModelEncodedFQN + label: string + type?: LineageNodeModelType + hasLeft?: boolean + hasRight?: boolean + numberOfColumns?: number + className?: string + isSelected?: boolean + isDraggable?: boolean + handleClick?: (e: MouseEvent) => void + handleSelect?: (e: MouseEvent) => void +}): JSX.Element { + return ( +
+ {hasLeft && ( + + + + )} +
+ {isNotNil(handleSelect) && ( + + + + )} + + {isNotNil(type) && ( + + {getModelNodeTypeTitle(type)} + + )} + + {truncate(decodeURI(label), 50, 20)} + + {isNotNil(numberOfColumns) && ( + + {numberOfColumns} + + )} + +
+ {hasRight && ( + + + + )} +
+ ) +} diff --git a/vscode/react/src/components/graph/SettingsControl.tsx b/vscode/react/src/components/graph/SettingsControl.tsx new file mode 100644 index 0000000000..3016a96ee7 --- /dev/null +++ b/vscode/react/src/components/graph/SettingsControl.tsx @@ -0,0 +1,50 @@ +import { Menu, MenuButton, MenuItem, MenuItems } from '@headlessui/react' +import { CheckIcon } from '@heroicons/react/24/outline' +import { CogIcon } from '@/components/graph/CogIcon' +import clsx from 'clsx' + +interface SettingsControlProps { + showColumns: boolean + onWithColumnsChange: (value: boolean) => void +} + +export function SettingsControl({ + showColumns, + onWithColumnsChange, +}: SettingsControlProps): JSX.Element { + return ( + + + + + onWithColumnsChange(!showColumns)} + > + Show Columns + {showColumns && ( + + + + ) +} diff --git a/vscode/react/src/components/graph/constants.ts b/vscode/react/src/components/graph/constants.ts new file mode 100644 index 0000000000..0927ac027e --- /dev/null +++ b/vscode/react/src/components/graph/constants.ts @@ -0,0 +1,16 @@ +/** + * Space between nodes. + */ +export const NODE_BALANCE_SPACE = 64 +/** + * Height of a column line. + */ +export const COLUMN_LINE_HEIGHT = 24 +/** + * Assumed width of a character. + */ +export const CHAR_WIDTH = 8 +/** + * Maximum number of columns that can be visible in a node. 
+ */ +export const MAX_VISIBLE_COLUMNS = 5 diff --git a/vscode/react/src/components/graph/context.tsx b/vscode/react/src/components/graph/context.tsx new file mode 100644 index 0000000000..9ab4f0722e --- /dev/null +++ b/vscode/react/src/components/graph/context.tsx @@ -0,0 +1,330 @@ +import { + createContext, + useState, + useContext, + useCallback, + useMemo, +} from 'react' +import { getNodeMap, hasActiveEdge, hasActiveEdgeConnector } from './help' +import { type Node } from 'reactflow' +import type { Lineage } from '@/domain/lineage' +import type { ModelSQLMeshModel } from '@/domain/sqlmesh-model' +import type { Column } from '@/domain/column' +import type { ModelEncodedFQN, ModelName } from '@/domain/models' +import type { ColumnName } from '@/domain/column' +import type { Model } from '@/api/client' +import { toID, toKeys } from './types' +import type { ConnectedNode } from '@/components/graph/types' + +export interface Connections { + left: string[] + right: string[] +} +export type ActiveColumns = Map +export type ActiveEdges = Map> +export type ActiveNodes = Set +export type SelectedNodes = Set +export type HighlightedNodes = Record + +interface LineageFlow { + lineage: Record + lineageCache?: Record + mainNode?: ModelEncodedFQN + connectedNodes: Set + activeEdges: ActiveEdges + activeNodes: ActiveNodes + selectedNodes: SelectedNodes + selectedEdges: ConnectedNode[] + models: Record + unknownModels: Set + connections: Map + withConnected: boolean + withColumns: boolean + hasBackground: boolean + withImpacted: boolean + withSecondary: boolean + manuallySelectedColumn?: [ModelSQLMeshModel, Column] + highlightedNodes: HighlightedNodes + nodesMap: Record + setHighlightedNodes: React.Dispatch> + setActiveNodes: React.Dispatch> + setWithConnected: React.Dispatch> + setMainNode: React.Dispatch> + setSelectedNodes: React.Dispatch> + setWithColumns: React.Dispatch> + setHasBackground: React.Dispatch> + setWithImpacted: React.Dispatch> + setWithSecondary: 
React.Dispatch> + setConnections: React.Dispatch>> + hasActiveEdge: (edge: [string | undefined, string | undefined]) => boolean + addActiveEdges: (edges: Array<[string, string]>) => void + removeActiveEdges: (edges: Array<[string, string]>) => void + setActiveEdges: React.Dispatch> + setUnknownModels: React.Dispatch>> + setLineage: React.Dispatch>> + setLineageCache: React.Dispatch< + React.SetStateAction | undefined> + > + handleClickModel?: (modelName: ModelEncodedFQN) => void + handleError?: (error: any) => void + setManuallySelectedColumn: React.Dispatch< + React.SetStateAction<[ModelSQLMeshModel, Column] | undefined> + > + setNodeConnections: React.Dispatch> + isActiveColumn: ( + modelName: ModelEncodedFQN, + columnName: ColumnName, + ) => boolean +} + +export const LineageFlowContext = createContext({ + selectedEdges: [], + lineage: {}, + lineageCache: undefined, + withColumns: false, + withConnected: false, + withImpacted: true, + withSecondary: false, + hasBackground: true, + mainNode: undefined, + activeEdges: new Map(), + activeNodes: new Set(), + models: {}, + unknownModels: new Set(), + manuallySelectedColumn: undefined, + connections: new Map(), + selectedNodes: new Set(), + connectedNodes: new Set(), + highlightedNodes: {}, + nodesMap: {}, + setHighlightedNodes: () => {}, + setWithColumns: () => false, + setHasBackground: () => false, + setWithImpacted: () => false, + setWithSecondary: () => false, + setWithConnected: () => false, + hasActiveEdge: () => false, + addActiveEdges: () => {}, + removeActiveEdges: () => {}, + setActiveEdges: () => {}, + handleClickModel: () => {}, + setManuallySelectedColumn: () => {}, + handleError: error => console.error(error), + setLineage: () => {}, + setLineageCache: () => {}, + isActiveColumn: () => false, + setConnections: () => {}, + setSelectedNodes: () => {}, + setMainNode: () => {}, + setActiveNodes: () => {}, + setNodeConnections: () => {}, + setUnknownModels: () => {}, +}) + +export default function 
LineageFlowProvider({ + handleError, + handleClickModel, + children, + showColumns = true, + showConnected = false, + showControls = true, + models, +}: { + children: React.ReactNode + handleClickModel?: (modelName: ModelEncodedFQN) => void + handleError?: (error: any) => void + showColumns?: boolean + showConnected?: boolean + showControls?: boolean + models: Record +}): JSX.Element { + const [lineage, setLineage] = useState>({}) + const [unknownModels, setUnknownModels] = useState(new Set()) + const [lineageCache, setLineageCache] = useState< + Record | undefined + >(undefined) + const [nodesConnections, setNodeConnections] = useState< + Record + >({}) + const [withColumns, setWithColumns] = useState(showColumns) + const [mainNode, setMainNode] = useState() + const [manuallySelectedColumn, setManuallySelectedColumn] = + useState<[ModelSQLMeshModel, Column]>() + const [activeEdges, setActiveEdges] = useState(new Map()) + const [connections, setConnections] = useState>( + new Map(), + ) + const [withConnected, setWithConnected] = useState(showConnected) + const [selectedNodes, setSelectedNodes] = useState(new Set()) + const [activeNodes, setActiveNodes] = useState(new Set()) + const [highlightedNodes, setHighlightedNodes] = useState({}) + const [hasBackground, setHasBackground] = useState(true) + const [withImpacted, setWithImpacted] = useState(true) + const [withSecondary, setWithSecondary] = useState(false) + + const nodesMap = useMemo( + () => + getNodeMap({ + lineage, + // @ts-expect-error TODO: fix this, should move to internal representation + models, + unknownModels, + withColumns, + }), + [lineage, models, withColumns, unknownModels], + ) + + const checkActiveEdge = useCallback( + function checkActiveEdge( + edge: [string | undefined, string | undefined], + ): boolean { + return hasActiveEdge(activeEdges, edge) + }, + [activeEdges], + ) + + const addActiveEdges = useCallback( + function addActiveEdges(edges: Array<[string, string]>): void { + 
setActiveEdges(activeEdges => { + edges.forEach(([leftConnect, rightConnect]) => { + const left = activeEdges.get(leftConnect) ?? [] + const right = activeEdges.get(rightConnect) ?? [] + const hasDuplicateLeft = left.some( + ([left, right]) => left === leftConnect && right === rightConnect, + ) + const hasDuplicateRight = right.some( + ([left, right]) => left === leftConnect && right === rightConnect, + ) + + if (!hasDuplicateLeft) { + left.push([leftConnect, rightConnect]) + } + + if (!hasDuplicateRight) { + right.push([leftConnect, rightConnect]) + } + + activeEdges.set(leftConnect, left) + activeEdges.set(rightConnect, right) + }) + + return new Map(activeEdges) + }) + }, + [setActiveEdges], + ) + + const removeActiveEdges = useCallback( + function removeActiveEdges(edges: Array<[string, string]>): void { + setActiveEdges(activeEdges => { + edges.forEach(([left, right]) => { + const edgesLeft = (activeEdges.get(left) ?? []).filter( + e => e[0] !== left && e[1] !== right, + ) + const edgesRight = (activeEdges.get(right) ?? 
[]).filter( + e => e[0] !== left && e[1] !== right, + ) + + activeEdges.set(left, edgesLeft) + activeEdges.set(right, edgesRight) + }) + + return new Map(activeEdges) + }) + + setConnections(connections => { + edges.forEach(([left, right]) => { + connections.delete(left) + connections.delete(right) + }) + + return new Map(connections) + }) + }, + [setActiveEdges, setConnections], + ) + + const isActiveColumn = useCallback( + function isActive( + modelName: ModelEncodedFQN, + columnName: ColumnName, + ): boolean { + const leftConnector = toID('left', modelName, columnName) + const rightConnector = toID('right', modelName, columnName) + return ( + hasActiveEdgeConnector(activeEdges, leftConnector) || + hasActiveEdgeConnector(activeEdges, rightConnector) + ) + }, + [checkActiveEdge, activeEdges], + ) + + const connectedNodes = useMemo( + () => new Set(toKeys(nodesConnections)), + [nodesConnections], + ) + + const selectedEdges = useMemo( + () => + Array.from(selectedNodes) + .flatMap(id => nodesConnections[id]) + .filter(Boolean) as any[], + [nodesConnections, selectedNodes], + ) + + return ( + + {children} + + ) +} + +export function useLineageFlow(): LineageFlow { + return useContext(LineageFlowContext) +} diff --git a/vscode/react/src/components/graph/help.ts b/vscode/react/src/components/graph/help.ts new file mode 100644 index 0000000000..93e5c4db45 --- /dev/null +++ b/vscode/react/src/components/graph/help.ts @@ -0,0 +1,739 @@ +import ELK, { type ElkNode } from 'elkjs/lib/elk.bundled.js' +import { + isArrayNotEmpty, + isFalse, + isNil, + isNotNil, + isObjectEmpty, +} from '@/utils/index' +import { type LineageColumn } from '@/api/client' +import { Position, type Edge, type Node, type XYPosition } from 'reactflow' +import { type ActiveEdges, type Connections } from './context' +import { toID, toKeys } from './types' +import { + EnumLineageNodeModelType, + type LineageNodeModelType, +} from './ModelNode' +import type { Lineage } from '@/domain/lineage' +import 
type { ConnectedNode } from '@/components/graph/types' +import { encode, type ModelEncodedFQN, type ModelURI } from '@/domain/models' +import type { Column, ColumnName } from '@/domain/column' +import type { ModelSQLMeshModel } from '@/domain/sqlmesh-model' +import { + CHAR_WIDTH, + COLUMN_LINE_HEIGHT, + MAX_VISIBLE_COLUMNS, + NODE_BALANCE_SPACE, +} from './constants' + +export interface GraphNodeData { + label: string + type: LineageNodeModelType + withColumns: boolean + width?: number + height?: number + [key: string]: any +} + +export function createGraphLayout({ + nodesMap, + nodes = [], + edges = [], +}: { + nodesMap: Record + nodes: Node[] + edges: Edge[] +}): { + create: () => Promise<{ nodes: Node[]; edges: Edge[] }> + terminate: () => void +} { + // https://eclipse.dev/elk/reference/options.html + const elk: any = new ELK() + + return { + terminate: () => elk.worker.terminate(), + create: async () => + new Promise((resolve, reject) => { + elk + .layout({ + id: 'root', + layoutOptions: { + 'elk.algorithm': 'layered', + 'elk.layered.layering.strategy': 'NETWORK_SIMPLEX', + 'elk.layered.crossingMinimization.strategy': 'INTERACTIVE', + 'elk.direction': 'RIGHT', + // https://eclipse.dev/elk/reference/options/org-eclipse-elk-layered-considerModelOrder-strategy.html + 'elk.layered.considerModelOrder.strategy': 'PREFER_NODES', + 'elk.layered.nodePlacement.strategy': 'SIMPLE', + }, + children: nodes.map(node => ({ + id: node.id, + width: node.data.width, + height: node.data.height, + })), + edges: edges.map(edge => ({ + id: edge.id, + sources: [edge.source], + targets: [edge.target], + })), + }) + .then((layout: any) => + resolve({ + edges, + nodes: repositionNodes(layout.children, nodesMap), + }), + ) + .catch(reject) + }), + } +} + +export function getEdges( + lineage: Record = {}, +): Edge[] { + const modelNames = toKeys(lineage) + const outputEdges: Edge[] = [] + + for (const targetModelName of modelNames) { + const targetModel = lineage[targetModelName]! 
+ + targetModel.models.forEach(sourceModelName => { + outputEdges.push(createGraphEdge(sourceModelName, targetModelName)) + }) + + const targetColumnNames = toKeys(targetModel.columns ?? {}) + for (const targetColumnName of targetColumnNames) { + const sourceModel = targetModel.columns?.[targetColumnName] + + if (isNil(sourceModel) || isNil(sourceModel.models)) continue + + const sourceModelNames = toKeys(sourceModel.models) + for (const sourceModelName of sourceModelNames) { + const sourceColumns = sourceModel.models[sourceModelName] + + if (isNil(sourceColumns)) continue + + for (const sourceColumnName of sourceColumns) { + const sourceHandler = toID('right', sourceModelName, sourceColumnName) + const targetHandler = toID('left', targetModelName, targetColumnName) + outputEdges.push( + createGraphEdge( + sourceModelName, + targetModelName, + sourceHandler, + targetHandler, + true, + { + columnSource: sourceColumnName, + columnTarget: targetColumnName, + }, + ), + ) + } + } + } + } + + return outputEdges +} + +export function getNodeMap({ + lineage, + models, + unknownModels, + withColumns, +}: { + models: Record + withColumns: boolean + unknownModels: Set + lineage?: Record +}): Record { + if (isNil(lineage)) return {} + + const sources = new Set(Object.values(lineage).flatMap(l => l.models)) + const modelNames = Object.keys(lineage) + + return modelNames.reduce((acc: Record, modelName: string) => { + const decodedModelName = modelName.includes('%') + ? decodeURI(modelName) + : modelName + const model = Object.values(models).find(m => m.fqn === decodedModelName) + const nodeType: LineageNodeModelType = isNotNil(model) + ? (model.type as LineageNodeModelType) + : // If model name present in lineage but not in global models + // it means either this is a CTE or model is UNKNOWN + // CTEs only have connections between columns + // where UNKNOWN model has connection only from another model + unknownModels.has(modelName) + ? 
EnumLineageNodeModelType.unknown + : EnumLineageNodeModelType.cte + + const node = createGraphNode(modelName, { + label: model?.name ?? modelName, + withColumns, + type: nodeType, + }) + const columnsCount = withColumns + ? (models[modelName]?.columns?.length ?? 0) + : 0 + + const maxWidth = Math.min( + getNodeMaxWidth(modelName, columnsCount === 0, models), + 320, + ) + const maxHeight = getNodeMaxHeight(columnsCount) + + node.data.width = maxWidth + NODE_BALANCE_SPACE * 3 + node.data.height = withColumns + ? maxHeight + NODE_BALANCE_SPACE * 2 + : NODE_BALANCE_SPACE + + if (isArrayNotEmpty(lineage[node.id]?.models)) { + node.targetPosition = Position.Left + } + + if (sources.has(node.id as ModelEncodedFQN)) { + node.sourcePosition = Position.Right + } + + acc[modelName] = node + + return acc + }, {}) +} + +function getNodeMaxWidth( + label: string, + hasColumns: boolean = false, + models: Record = {}, +): number { + const defaultWidth = label.length * CHAR_WIDTH + const columns = models[label]?.columns ?? [] + + return hasColumns + ? 
Math.max(...columns.map(getColumnWidth), defaultWidth) + : defaultWidth +} + +function getColumnWidth(column: Column): number { + return ( + (column.name.length + column.type.length) * CHAR_WIDTH + NODE_BALANCE_SPACE + ) +} + +function getNodeMaxHeight(columnsCount: number): number { + return ( + COLUMN_LINE_HEIGHT * Math.min(columnsCount, MAX_VISIBLE_COLUMNS) + + NODE_BALANCE_SPACE + ) +} + +function repositionNodes( + elkNodes: ElkNode[] = [], + nodesMap: Record, +): Node[] { + const nodes: Node[] = [] + + elkNodes.forEach(node => { + const output = nodesMap[node.id] + + if (isNil(output)) return + + if (isNotNil(node.x) && output.position.x === 0) { + output.position.x = node.x + } + + if (isNotNil(node.y) && output.position.y === 0) { + output.position.y = node.y + } + + nodes.push(output) + }) + + return nodes +} + +function createGraphNode( + id: string, + data: GraphNodeData, + position: XYPosition = { x: 0, y: 0 }, + hidden: boolean = false, +): Node { + return { + id, + dragHandle: '.drag-handle', + type: 'model', + position, + hidden, + data, + connectable: false, + selectable: false, + deletable: false, + focusable: false, + zIndex: -1, + } +} + +function createGraphEdge( + source: string, + target: string, + sourceHandle?: string, + targetHandle?: string, + hidden: boolean = false, + data?: Data, +): Edge { + const output: Edge = { + id: toID(source, target, sourceHandle, targetHandle), + source, + target, + hidden, + data, + type: 'smoothstep', + style: { + strokeWidth: isNil(sourceHandle) || isNil(targetHandle) ? 
2 : 4, + }, + } + + if (sourceHandle != null) { + output.sourceHandle = sourceHandle + } + + if (targetHandle != null) { + output.targetHandle = targetHandle + } + + return output +} + +export function mergeLineageWithColumns( + currentLineage: Record = {}, + newLineage: Record> = {}, +): Record { + for (const targetModelName in newLineage) { + const targetModelNameEncoded = encodeURI(targetModelName) + + if (isNil(currentLineage[targetModelNameEncoded])) { + currentLineage[targetModelNameEncoded] = { columns: {}, models: [] } + } + + const currentLineageModel = currentLineage[targetModelNameEncoded]! + const newLineageModel = newLineage[targetModelName]! + + for (const targetColumnName in newLineageModel) { + const targetColumnNameEncoded = encodeURI(targetColumnName) + const newLineageModelColumn = newLineageModel[targetColumnName]! + + if (isNil(currentLineageModel.columns)) { + currentLineageModel.columns = {} + } + + // New Column Lineage delivers fresh data, so we can just assign it + currentLineageModel.columns[targetColumnNameEncoded as ColumnName] = { + expression: newLineageModelColumn.expression ?? undefined, + source: newLineageModelColumn.source ?? undefined, + models: {}, + } + + // If there are no models in new lineage, skip + if (isObjectEmpty(newLineageModelColumn.models)) continue + + const currentLineageModelColumn = + currentLineageModel.columns[targetColumnNameEncoded as ColumnName]! + const currentLineageModelColumnModels = currentLineageModelColumn.models + + for (const sourceColumnName in newLineageModelColumn.models) { + const sourceColumnNameEncoded = encodeURI(sourceColumnName) + const currentLineageModelColumnModel = + currentLineageModelColumnModels[ + sourceColumnNameEncoded as ModelEncodedFQN + ]! + const newLineageModelColumnModel = + newLineageModelColumn.models[sourceColumnName]! 
+ + // @ts-expect-error TODO: fix this + currentLineageModelColumnModels[ + sourceColumnNameEncoded as ModelEncodedFQN + ] = Array.from( + new Set( + isNil(currentLineageModelColumnModel) + ? newLineageModelColumnModel + : currentLineageModelColumnModel.concat( + newLineageModelColumnModel as ColumnName[], + ), + ), + ).map(uri => encode(uri as ModelURI)) + } + } + } + + return currentLineage +} + +export function mergeConnections( + connections: Map, + lineage: Record> = {}, +): { + connections: Map + activeEdges: Array<[string, string]> +} { + const activeEdges: Array<[string, string]> = [] + + // We are getting lineage in format of target -> source + for (const targetModelName in lineage) { + const targetModelNameEncoded = encodeURI(targetModelName) + const model = lineage[targetModelName]! + + for (const targetColumnName in model) { + const targetColumnNameEncoded = encodeURI(targetColumnName) + const column = model[targetColumnName] + + // We don't have any connectins so we skip + if (isNil(column?.models)) continue + + // At this point our Node is model -> {modelName} and column -> {columnName} + // It is a target (left handler) + // but it can also be a source (right handler) for other connections + const modelColumnIdTarget = toID( + targetModelNameEncoded, + targetColumnNameEncoded, + ) + + // We need to check if {modelColumnIdTarget} is already a source/target for other connections + // Left and Right coresponds to node's handlers for {columnName} column + const connectionsModelTarget = connections.get(modelColumnIdTarget) ?? 
{ + left: [], + right: [], + } + + Object.entries(column.models).forEach(([sourceModelName, columns]) => { + const sourceModelNameEncoded = encodeURI(sourceModelName) + columns.forEach(sourceColumnName => { + const sourceColumnNameEncoded = encodeURI(sourceColumnName) + // It is a source (right handler) + // but it can also be a target (left handler) for other connections + const modelColumnIdSource = toID( + sourceModelNameEncoded, + sourceColumnNameEncoded, + ) + + // We need to check if {modelColumnIdSource} is already a source/target for other connections + // Left and Right coresponds to node's handlers for {column} column + const connectionsModelSource = connections.get( + modelColumnIdSource, + ) ?? { left: [], right: [] } + + // we need to add {modelColumnIdTarget} to {connectionsModelSource}'s right handlers + connectionsModelSource.right = Array.from( + new Set(connectionsModelSource.right.concat(modelColumnIdTarget)), + ) + + // We need to add {modelColumnIdSource} to {connectionsModelTarget}'s right handlers + connectionsModelTarget.left = Array.from( + new Set(connectionsModelTarget.left.concat(modelColumnIdSource)), + ) + + connections.set(modelColumnIdSource, connectionsModelSource) + connections.set(modelColumnIdTarget, connectionsModelTarget) + + // Now we need to update active edges from connections + // Left bucket contains references to all sources (right handlers) + // And right bucket contains references to all targets (left handlers) + connectionsModelSource.left.forEach(id => { + activeEdges.push([ + toID('left', modelColumnIdSource), + toID('right', id), + ]) + }) + connectionsModelSource.right.forEach(id => { + activeEdges.push([ + toID('left', id), + toID('right', modelColumnIdSource), + ]) + }) + }) + }) + } + } + + return { + connections, + activeEdges, + } +} + +export function getLineageIndex(lineage: Record = {}): string { + return Object.keys(lineage) + .reduce((acc: string[], key) => { + const { models = [], columns = {} } = 
lineage[key]! + const allModels = new Set() + + models.forEach(m => allModels.add(m)) + + if (isNotNil(columns)) { + toKeys(columns).forEach(columnName => { + const column = columns[columnName] + if (isNotNil(column) && isNotNil(column.models)) { + toKeys(column.models).forEach(m => allModels.add(m)) + } + }) + } + + return acc.concat(Array.from(allModels)) + }, []) + .sort() + .join('') +} + +export function getModelAncestors( + lineage: Record = {}, + name: string, + output = new Set(), +): Set { + const model = lineage[name] + const models = model?.models ?? [] + + for (const modelName of models) { + if (output.has(modelName)) continue + + getModelAncestors(lineage, modelName, output).add(modelName) + } + + return output +} + +export function getActiveNodes( + edges: Edge[] = [], + activeEdges: ActiveEdges, + selectedEdges: ConnectedNode[], + nodesMap: Record, +): Set { + return new Set( + edges.reduce((acc: ModelEncodedFQN[], edge) => { + const sourceNode = isNil(edge.sourceHandle) + ? undefined + : nodesMap[edge.sourceHandle] + const targetNode = isNil(edge.targetHandle) + ? 
undefined + : nodesMap[edge.targetHandle] + + if ( + isNotNil(sourceNode) && + isNotNil(edge.sourceHandle) && + sourceNode.data.type === EnumLineageNodeModelType.external && + hasActiveEdgeConnector(activeEdges, edge.sourceHandle) + ) { + acc.push(edge.source as ModelEncodedFQN) + } else if ( + isNotNil(targetNode) && + isNotNil(edge.targetHandle) && + targetNode.data.type === EnumLineageNodeModelType.external && + hasActiveEdgeConnector(activeEdges, edge.targetHandle) + ) { + acc.push(edge.target as ModelEncodedFQN) + } else { + const isActiveEdge = hasActiveEdge(activeEdges, [ + edge.targetHandle, + edge.sourceHandle, + ]) + + if (isActiveEdge || hasEdge(selectedEdges, edge.id)) { + if (isNotNil(edge.source)) { + acc.push(edge.source as ModelEncodedFQN) + } + + if (isNotNil(edge.target)) { + acc.push(edge.target as ModelEncodedFQN) + } + } + } + return acc + }, []), + ) +} + +export function getUpdatedEdges( + edges: Edge[] = [], + connections: Map, + activeEdges: ActiveEdges, + activeNodes: Set, + selectedEdges: ConnectedNode[], + selectedNodes: Set, + connectedNodes: Set, + withConnected: boolean = false, + withImpacted: boolean = false, + withSecondary: boolean = false, +): Edge[] { + const tempEdges = edges.map(edge => { + const isActiveEdge = hasActiveEdge(activeEdges, [ + edge.targetHandle, + edge.sourceHandle, + ]) + + edge.hidden = true + + if (isNil(edge.sourceHandle) && isNil(edge.targetHandle)) { + // Edge between models + const hasSelections = + selectedNodes.size > 0 || connections.size > 0 || activeNodes.size > 0 + const isImpactedEdge = + connectedNodes.has(edge.source) || connectedNodes.has(edge.target) + const isSecondaryEdge = + isFalse(connectedNodes.has(edge.source)) || + isFalse(connectedNodes.has(edge.target)) + const withoutImpactedNodes = + isFalse(withImpacted) && + isFalse(withConnected) && + isFalse(hasSelections) + const withoutSecondaryNodes = + isFalse(withSecondary) && isFalse(hasSelections) + const shouldHideSecondary = 
isSecondaryEdge && withoutSecondaryNodes + const shouldHideImpacted = isImpactedEdge && withoutImpactedNodes + const isVisibleEdge = + selectedNodes.size > 0 && + hasEdge(selectedEdges, edge.id) && + activeNodes.has(edge.source) && + activeNodes.has(edge.target) + + if ( + isFalse(shouldHideImpacted) && + isFalse(shouldHideSecondary) && + (isFalse(hasSelections) || isVisibleEdge) + ) { + edge.hidden = false + } + } else { + // Edge between columns + if (connections.size > 0 && isActiveEdge) { + edge.hidden = false + } + } + + let stroke = 'var(--color-graph-edge-main)' + let strokeWidth = 2 + + const isConnectedSource = connectedNodes.has(edge.source) + const isConnectedTarget = connectedNodes.has(edge.target) + + if ( + hasEdge(selectedEdges, edge.id) || + (withConnected && isConnectedSource && isConnectedTarget) + ) { + strokeWidth = 4 + stroke = 'var(--color-graph-edge-selected)' + edge.zIndex = 10 + } else { + if (isActiveEdge) { + stroke = 'var(--color-graph-edge-secondary)' + } else if (isConnectedSource && isConnectedTarget) { + strokeWidth = 4 + stroke = 'var(--color-graph-edge-direct)' + edge.zIndex = 10 + } + } + + edge.style = { + ...edge.style, + stroke, + strokeWidth, + } + + return edge + }) + + return tempEdges +} + +export function getUpdatedNodes( + nodes: Node[] = [], + activeNodes: Set, + mainNode: ModelEncodedFQN, + connectedNodes: Set, + selectedNodes: Set, + connections: Map, + withConnected: boolean, + withImpacted: boolean, + withSecondary: boolean, +): Node[] { + return nodes.map(node => { + node.hidden = true + + const hasSelections = selectedNodes.size > 0 || connections.size > 0 + const isActiveNode = activeNodes.size === 0 || activeNodes.has(node.id) + const isImpactedNode = connectedNodes.has(node.id) + const isSecondaryNode = isFalse(connectedNodes.has(node.id)) + const withoutImpactedNodes = + isFalse(withImpacted) && isFalse(withConnected) && isFalse(hasSelections) + const withoutSecondaryNodes = + isFalse(withSecondary) && 
isFalse(hasSelections) + const shouldHideSecondary = isSecondaryNode && withoutSecondaryNodes + const shouldHideImpacted = isImpactedNode && withoutImpactedNodes + + if (isFalse(shouldHideImpacted) && isFalse(shouldHideSecondary)) { + node.hidden = isFalse(isActiveNode) + } + + if (node.data.type === EnumLineageNodeModelType.cte) { + node.hidden = isFalse(activeNodes.has(node.id)) + } + + if (mainNode === node.id) { + node.hidden = false + } + + return node + }) +} + +export function hasActiveEdge( + activeEdges: ActiveEdges = new Map(), + [leftConnect, rightConnect]: [ + string | undefined | null, + string | undefined | null, + ], +): boolean { + if (isNil(leftConnect) && isNil(rightConnect)) return false + + const left = isNil(leftConnect) ? undefined : activeEdges.get(leftConnect) + const right = isNil(rightConnect) ? undefined : activeEdges.get(rightConnect) + + if (isNil(left) && isNil(right)) return false + + const inLeft = Boolean( + left?.some(([l, r]) => l === leftConnect && r === rightConnect), + ) + const inRight = Boolean( + right?.some(([l, r]) => l === leftConnect && r === rightConnect), + ) + + return inLeft || inRight +} + +export function hasActiveEdgeConnector( + activeEdges: ActiveEdges = new Map(), + connector: string, +): boolean { + return (activeEdges.get(connector) ?? 
[]).length > 0 +} + +export function getModelNodeTypeTitle(type: LineageNodeModelType): string { + switch (type) { + case EnumLineageNodeModelType.python: + return 'PYTHON' + case EnumLineageNodeModelType.sql: + return 'SQL' + case EnumLineageNodeModelType.seed: + return 'SEED' + case EnumLineageNodeModelType.cte: + return 'CTE' + case EnumLineageNodeModelType.external: + return 'EXTERNAL' + case EnumLineageNodeModelType.source: + return 'SOURCE' + default: + return 'UNKNOWN' + } +} + +function hasEdge(nodes: ConnectedNode[], edge: string): boolean { + return nodes.some(node => node.id === edge || hasEdge(node.edges, edge)) +} diff --git a/vscode/react/src/components/graph/types.ts b/vscode/react/src/components/graph/types.ts new file mode 100644 index 0000000000..6e188b31c8 --- /dev/null +++ b/vscode/react/src/components/graph/types.ts @@ -0,0 +1,101 @@ +import type { ColumnName } from '@/domain/column' +import type { ModelEncodedFQN } from '@/domain/models' +import type { Branded } from '@bus/brand' +import type { Lineage } from '@/domain/lineage' + +export type Side = 'left' | 'right' + +export type Direction = 'upstream' | 'downstream' + +export type NodeId = string + +export type EdgeId = string + +/** + * Partial column handle id that isn't complete yet as it's missing the left/right side + * definition. + */ +export type PartialColumnHandleId = Branded +export type ColumnHandleId = Branded +export type ModelHandleId = Branded + +/** + * Converts a list of strings to a single string with a double underscore + * Outlines with types, the type of ids that can be created. 
+ * @param args + * @returns + */ +export function toID( + leftOrRight: Side, + modelName: ModelEncodedFQN, + columnName: ColumnName, +): NodeId +export function toID( + modelName: ModelEncodedFQN, + columnName: ColumnName, +): PartialColumnHandleId +export function toID( + leftOrRight: Side, + partialColumnHandleId: PartialColumnHandleId, +): ColumnHandleId +export function toID( + leftOrRight: Side, + modelName: ModelEncodedFQN, +): ModelHandleId +export function toID(source: NodeId, target: NodeId): NodeId +export function toID( + source: NodeId, + target: NodeId, + sourceHandle: string | undefined, + targetHandle: string | undefined, +): EdgeId +export function toID(...args: Array): string { + return args.filter(Boolean).join('__') +} + +export function toKeys(obj: Record): K[] { + return Object.keys(obj) as K[] +} + +export type ModelLineage = Record + +// Worker Message Types +export interface ConnectedNode { + id?: string + edges: ConnectedNode[] +} + +export interface LineageWorkerRequestPayload { + currentLineage: Record + newLineage: Record + mainNode: string +} + +export interface LineageWorkerResponsePayload { + lineage: Record + nodesConnections: Record +} + +export interface LineageWorkerErrorPayload { + error: Error +} + +export interface LineageWorkerRequestMessage { + topic: 'lineage' + payload: LineageWorkerRequestPayload +} + +export interface LineageWorkerResponseMessage { + topic: 'lineage' + payload: LineageWorkerResponsePayload +} + +export interface LineageWorkerErrorMessage { + topic: 'error' + error: Error +} + +export type LineageWorkerMessage = + | LineageWorkerRequestMessage + | LineageWorkerResponseMessage + | LineageWorkerErrorMessage diff --git a/vscode/react/src/components/graph/workers/index.ts b/vscode/react/src/components/graph/workers/index.ts new file mode 100644 index 0000000000..9eb76d5287 --- /dev/null +++ b/vscode/react/src/components/graph/workers/index.ts @@ -0,0 +1,3 @@ +import createLineageWorker from 
'./lineage.ts?worker&inline' + +export { createLineageWorker } diff --git a/vscode/react/src/components/graph/workers/lineage.ts b/vscode/react/src/components/graph/workers/lineage.ts new file mode 100644 index 0000000000..fe8337b72d --- /dev/null +++ b/vscode/react/src/components/graph/workers/lineage.ts @@ -0,0 +1,129 @@ +import { isFalse, isNil } from '@/utils/index' +import { type Lineage } from '@/domain/lineage' +import type { ModelEncodedFQN } from '@/domain/models' +import { + toID, + type NodeId, + type LineageWorkerMessage, + type LineageWorkerRequestMessage, + type LineageWorkerResponseMessage, + type LineageWorkerErrorMessage, + type ConnectedNode, +} from '@/components/graph/types' +import type { Direction } from '../types' + +interface WorkerScope { + onmessage: ((e: MessageEvent) => void) | null + postMessage: (message: LineageWorkerMessage) => void +} + +const scope = self as unknown as WorkerScope + +scope.onmessage = async (e: MessageEvent) => { + if (e.data.topic === 'lineage') { + try { + const message = e.data as LineageWorkerRequestMessage + const { currentLineage, newLineage, mainNode } = message.payload + const lineage = await mergeLineageWithModels(currentLineage, newLineage) + const nodesConnections = await getNodesConnections(mainNode, lineage) + + const responseMessage: LineageWorkerResponseMessage = { + topic: 'lineage', + payload: { + lineage, + nodesConnections, + }, + } + scope.postMessage(responseMessage) + } catch (error) { + const errorMessage: LineageWorkerErrorMessage = { + topic: 'error', + error: error as Error, + } + scope.postMessage(errorMessage) + } + } +} + +async function mergeLineageWithModels( + currentLineage: Record = {}, + data: Record = {}, +): Promise> { + return Object.entries(data).reduce( + (acc: Record, [key, models = []]) => { + key = encodeURI(key) + + acc[key] = { + models: models.map(encodeURI) as ModelEncodedFQN[], + columns: currentLineage?.[key]?.columns ?? 
undefined, + } + + return acc + }, + {}, + ) +} + +async function getNodesConnections( + mainNode: string, + lineage: Record = {}, +): Promise> { + return new Promise((resolve, reject) => { + if (isNil(lineage) || isNil(mainNode)) return {} + + const distances: Record = {} + + try { + getConnectedNodes('upstream', mainNode, lineage, distances) + getConnectedNodes('downstream', mainNode, lineage, distances) + } catch (error) { + reject(error) + } + + resolve(distances) + }) +} + +function getConnectedNodes( + direction: Direction = 'downstream', + node: string, + lineage: Record = {}, + result: Record = {}, +): void { + const isDownstream = direction === 'downstream' + let models: string[] = [] + + if (isDownstream) { + models = Object.keys(lineage).filter(key => + lineage[key]!.models.includes(node as ModelEncodedFQN), + ) + } else { + models = lineage[node]?.models ?? [] + } + + if (isFalse(node in result)) { + result[node] = { edges: [] } + } + + for (const model of models) { + const connectedNode = isDownstream + ? 
createConnectedNode(node, model, [result[node]!]) + : createConnectedNode(model, node, [result[node]!]) + + if (model in result) { + result[model]!.edges.push(connectedNode) + } else { + result[model] = connectedNode + getConnectedNodes(direction, model, lineage, result) + } + } +} + +function createConnectedNode( + source: NodeId, + target: NodeId, + edges: ConnectedNode[] = [], +): ConnectedNode { + const id = toID(source, target) + return { id, edges } +} diff --git a/vscode/react/src/components/input/Input.tsx b/vscode/react/src/components/input/Input.tsx new file mode 100644 index 0000000000..409db4b639 --- /dev/null +++ b/vscode/react/src/components/input/Input.tsx @@ -0,0 +1,121 @@ +import clsx from 'clsx' +import React from 'react' +import { isNotNil } from '@/utils/index' +import { EnumSize, type Size } from '@/style/variants' +import Textfield from './Textfield' +import Selector from './Selector' + +export interface PropsInput { + label?: string + info?: string + size?: Size + disabled?: boolean + required?: boolean + autoFocus?: boolean + className?: string + children?: ({ + disabled, + required, + autoFocus, + size, + className, + }: { + className: string + disabled: boolean + required: boolean + autoFocus: boolean + size: Size + }) => React.ReactNode | React.ReactNode +} + +function Input({ + label, + info, + size = EnumSize.md, + className, + children, + disabled = false, + required = false, + autoFocus = false, +}: PropsInput): JSX.Element { + const cn = clsx( + 'text-left relative block bg-theme-lighter border-neutral-200 dark:border-neutral-700', + 'focus:outline-none focus:border-secondary-500', + 'ring-secondary-300 ring-opacity-60 ring-offset ring-offset-secondary-100', + size === EnumSize.sm && + 'px-2 py-0.5 text-xs leading-4 border-2 focus:ring-2 rounded-[0.25rem] min-w-[7rem]', + size === EnumSize.md && + 'px-3 py-2 text-sm leading-4 border-2 focus:ring-4 rounded-md min-w-[10rem]', + size === EnumSize.lg && + 'px-3 py-2 text-sm leading-6 
border-2 focus:ring-4 rounded-md min-w-[10rem]', + ) + + return ( +
+ {isNotNil(label) && {label}} + {typeof children === 'function' + ? children({ disabled, required, autoFocus, size, className: cn }) + : children} + {isNotNil(info) && {info}} +
+ ) +} + +function InputLabel({ + htmlFor, + className, + children, + size = EnumSize.md, +}: { + htmlFor?: string + className?: string + children: React.ReactNode + size?: Size +}): JSX.Element { + return ( + + ) +} + +function InputInfo({ + className, + children, +}: { + className?: string + children: React.ReactNode +}): JSX.Element { + return ( + + {children} + + ) +} + +Input.Label = InputLabel +Input.Info = InputInfo +Input.Textfield = Textfield +Input.Selector = Selector + +export default Input diff --git a/vscode/react/src/components/input/InputToggle.tsx b/vscode/react/src/components/input/InputToggle.tsx new file mode 100644 index 0000000000..a7e8fce7e0 --- /dev/null +++ b/vscode/react/src/components/input/InputToggle.tsx @@ -0,0 +1,34 @@ +// import { EnumSize } from '@/style/variants' +// import Toggle from '@/components/toggle/Toggle' +import clsx from 'clsx' + +export default function InputToggle({ + label, + info, + // enabled, + // disabled = false, + // setEnabled, + className, +}: { + label: string + // enabled: boolean + // setEnabled: (enabled: boolean) => void + info?: string + // disabled?: boolean + className?: string +}): JSX.Element { + return ( +
+ + {/* */} +
+ ) +} diff --git a/vscode/react/src/components/input/Selector.tsx b/vscode/react/src/components/input/Selector.tsx new file mode 100644 index 0000000000..ea5712be6e --- /dev/null +++ b/vscode/react/src/components/input/Selector.tsx @@ -0,0 +1,132 @@ +import React from 'react' +import { + ChevronUpDownIcon, + ChevronUpIcon, + ChevronDownIcon, + CheckIcon, +} from '@heroicons/react/24/solid' +import clsx from 'clsx' +import * as Select from '@radix-ui/react-select' +import { EnumSize, type Size } from '@/style/variants' + +export interface PropsSelector { + list: Array<{ text: string; value: string }> + onChange: (value: string) => void + size?: Size + name?: string + value?: string + disabled?: boolean + required?: boolean + autoFocus?: boolean + className?: string +} + +export default React.forwardRef( + function Selector( + { + list = [], + required = false, + disabled = false, + autoFocus = false, + size = EnumSize.md, + name, + value = 'default', + className, + onChange, + }: PropsSelector, + ref?: React.Ref, + ): JSX.Element { + const item = list.find(i => i.value === value) ?? + list[0] ?? 
{ text: '', value } + + disabled = disabled || list.length < 2 + + return ( + + + + + + + + + + + + {list.map(({ text, value }) => ( + + {text} + + ))} + + + + + + + ) + }, +) + +function SelectItem({ + disabled = false, + value, + children, + className, +}: { + value: string + children: React.ReactNode + disabled?: boolean + className?: string +}): JSX.Element { + return ( + + {children} + + + + + ) +} diff --git a/vscode/react/src/components/input/Textfield.tsx b/vscode/react/src/components/input/Textfield.tsx new file mode 100644 index 0000000000..81c08e4a5c --- /dev/null +++ b/vscode/react/src/components/input/Textfield.tsx @@ -0,0 +1,50 @@ +import clsx from 'clsx' +import React from 'react' +import { isFalse } from '@/utils/index' + +export interface PropsTextfield { + value?: string | number | undefined + type?: string + placeholder?: string + disabled?: boolean + autoFocus?: boolean + className?: string + onInput?: (e: React.ChangeEvent) => void + onKeyDown?: (e: React.KeyboardEvent) => void +} + +export default React.forwardRef( + function Input( + { + type = 'text', + value, + placeholder, + className, + disabled = false, + autoFocus = false, + onInput, + onKeyDown, + }: PropsTextfield, + ref?: React.Ref, + ): JSX.Element { + value = value ?? 
'' + return ( + + ) + }, +) diff --git a/vscode/react/src/components/loading/Loading.tsx b/vscode/react/src/components/loading/Loading.tsx new file mode 100644 index 0000000000..509b0d1724 --- /dev/null +++ b/vscode/react/src/components/loading/Loading.tsx @@ -0,0 +1,71 @@ +import React from 'react' +import Spinner from '@/components/logo/Spinner' +import clsx from 'clsx' +import { + EnumSize, + EnumVariant, + type Variant, + type Size, +} from '@/style/variants' +import { isNotNil } from '@/utils/index' +import Title from '@/components/title/Title' + +export default function Loading({ + hasSpinner = false, + size = EnumSize.sm, + variant = EnumVariant.Info, + text, + children, + className, +}: { + children?: React.ReactNode + text?: string + size?: Size + variant?: Variant + hasSpinner?: boolean + className?: string +}): JSX.Element { + return ( + + + {hasSpinner && ( + + )} + {isNotNil(text) ? ( + + ) : ( + children + )} + </span> + </span> + ) +} diff --git a/vscode/react/src/components/loading/LoadingSegment.tsx b/vscode/react/src/components/loading/LoadingSegment.tsx new file mode 100644 index 0000000000..3aebb7ede1 --- /dev/null +++ b/vscode/react/src/components/loading/LoadingSegment.tsx @@ -0,0 +1,26 @@ +import React from 'react' +import Spinner from '@/components/logo/Spinner' +import clsx from 'clsx' +import Loading from './Loading' + +export default function LoadingSegment({ + children, + className, +}: { + className?: string + children?: React.ReactNode +}): JSX.Element { + return ( + <div + className={clsx( + 'flex justify-center items-center w-full h-full', + className, + )} + > + <Loading className="inline-block"> + <Spinner className="w-3 h-3 border border-neutral-10 mr-4" /> + <h3 className="text-md">{children}</h3> + </Loading> + </div> + ) +} diff --git a/vscode/react/src/components/loading/LoadingStatus.tsx b/vscode/react/src/components/loading/LoadingStatus.tsx new file mode 100644 index 0000000000..97b9e55e8a --- /dev/null +++ 
b/vscode/react/src/components/loading/LoadingStatus.tsx @@ -0,0 +1,15 @@ +import Spinner from '@/components/logo/Spinner' +import Loading from './Loading' + +export default function LoadingStatus({ + children, +}: { + children: React.ReactNode +}): JSX.Element { + return ( + <Loading className="inline-block"> + <Spinner className="w-3 h-3 border border-neutral-10 mr-2" /> + <span className="inline-block text-xs whitespace-nowrap">{children}</span> + </Loading> + ) +} diff --git a/vscode/react/src/components/logo/Spinner.tsx b/vscode/react/src/components/logo/Spinner.tsx new file mode 100644 index 0000000000..6c6e191ac2 --- /dev/null +++ b/vscode/react/src/components/logo/Spinner.tsx @@ -0,0 +1,53 @@ +import clsx from 'clsx' +import React from 'react' +import { EnumVariant, type Variant } from '@/style/variants' + +interface PropsSpinner extends React.SVGAttributes<SVGAElement> { + variant?: Variant +} + +export default function Spinner({ + style, + className, + variant = EnumVariant.Info, +}: PropsSpinner): JSX.Element { + return ( + <svg + style={style} + className={clsx('animate-spin bg-transparent rounded-full', className)} + viewBox="0 0 64 64" + xmlns="http://www.w3.org/2000/svg" + aria-label="Loading" + role="img" + > + <path + fillRule="evenodd" + clipRule="evenodd" + d="M16 59.7128C31.3054 68.5494 50.8763 63.3054 59.7128 48C68.5494 32.6946 63.3054 13.1237 48 4.28719C32.6946 -4.54937 13.1237 0.694636 4.28719 16C-4.54937 31.3054 0.694637 50.8763 16 59.7128ZM23 47.5885C31.6093 52.559 42.6179 49.6093 47.5885 41C52.559 32.3907 49.6093 21.3821 41 16.4115C32.3907 11.441 21.3821 14.3907 16.4115 23C11.441 31.6093 14.3907 42.6179 23 47.5885Z" + className="fill-theme-lighter" + /> + <path + fillRule="evenodd" + clipRule="evenodd" + d="M50.5827 26.5161C49.2259 21.9157 46.1691 17.8082 41.6875 15.2208C37.4263 12.7606 32.6191 12.103 28.1488 13.0114L25.1365 1.76921C32.4854 0.0988895 40.4586 1.08788 47.5 5.15321C54.7617 9.34574 59.6854 16.0326 61.8138 23.5067L50.5827 
26.5161Z" + className={clsx( + variant === EnumVariant.Primary && 'fill-primary-500', + variant === EnumVariant.Secondary && 'fill-secondary-500', + variant === EnumVariant.Success && 'fill-success-500', + variant === EnumVariant.Warning && 'fill-warning-500', + variant === EnumVariant.Danger && 'fill-danger-500', + variant === EnumVariant.Info && 'fill-neutral-500', + )} + stroke={clsx( + variant === EnumVariant.Primary && 'var(--color-primary-500)', + variant === EnumVariant.Secondary && 'var(--color-secondary-500)', + variant === EnumVariant.Success && 'var(--color-success-500)', + variant === EnumVariant.Warning && 'var(--color-warning-500)', + variant === EnumVariant.Danger && 'var(--color-danger-500)', + variant === EnumVariant.Info && 'var(--color-neutral-500)', + )} + strokeWidth="4" + /> + </svg> + ) +} diff --git a/vscode/react/src/components/sourceList/SourceList.tsx b/vscode/react/src/components/sourceList/SourceList.tsx new file mode 100644 index 0000000000..25d3db4903 --- /dev/null +++ b/vscode/react/src/components/sourceList/SourceList.tsx @@ -0,0 +1,253 @@ +import Input from '@/components/input/Input' +import { type Virtualizer, useVirtualizer } from '@tanstack/react-virtual' +import { + isArrayEmpty, + isNil, + isNotNil, + isStringEmptyOrNil, +} from '@/utils/index' +import clsx from 'clsx' +import { useEffect, useMemo, useRef, useState } from 'react' +import { EnumSize, EnumVariant } from '@/style/variants' +import { Button } from '../button/Button' + +interface ListItem< + TListItem extends Record<string, any> = Record<string, any>, +> { + id: string + name: string + item: TListItem + to: string + description?: string + text?: string + disabled?: boolean +} + +export default function SourceList< + TItem extends Record<string, any> = Record<string, string>, + TType extends Record<string, string> = Record<string, string>, +>({ + items = [], + keyId = 'id', + keyName = '', + keyDescription = '', + to = '', + disabled = false, + withCounter = true, + 
withFilter = true, + types, + className, + isActive, + listItem, +}: { + keyId: string + withCounter?: boolean + withFilter?: boolean + to?: string + items?: TItem[] + types?: TType + keyName?: string + keyDescription?: string + disabled?: boolean + className?: string + isActive?: (id: string) => boolean + listItem: (listItem: ListItem<TItem>) => React.ReactNode +}): JSX.Element { + const elSourceList = useRef<HTMLDivElement>(null) + + const [filter, setFilter] = useState('') + + const scrollableAreaRef = useRef<HTMLDivElement>(null) + + const [activeItemIndex, filtered] = useMemo(() => { + let activeIndex = -1 + const filteredList: TItem[] = [] + + items.forEach((item, index) => { + const id = ensureString(item[keyId]) + const description = ensureString(item[keyDescription]) + const name = ensureString(item[keyName]) + const type = ensureString(types?.[id]) + + if ( + name.includes(filter) || + description.includes(filter) || + type.includes(filter) + ) { + filteredList.push(item) + } + + if (isNotNil(isActive) && isActive(item[keyId])) { + activeIndex = index + } + }) + + return [activeIndex, filteredList] + }, [items, filter, isActive]) + + const rowVirtualizer = useVirtualizer({ + count: filtered.length, + getScrollElement: () => scrollableAreaRef.current, + estimateSize: () => 32 + (keyDescription.length > 0 ? 16 : 0), + }) + + const scrollToItem = ({ + itemIndex, + isSmoothScroll = true, + }: { + itemIndex: number + isSmoothScroll?: boolean + }): void => { + rowVirtualizer.scrollToIndex(itemIndex, { + align: 'center', + behavior: isSmoothScroll ? 
'smooth' : 'auto', + }) + } + + const isOutsideVisibleRange = ({ + itemIndex, + range, + }: { + itemIndex: number + range: Virtualizer<HTMLDivElement, Element>['range'] + }): boolean => + isNotNil(range) && + (range.startIndex > itemIndex || range?.endIndex < itemIndex) + + /** + * The return button should appear when the + * active item is available in the list (not + * filtered out) and it is not in the visible + * range of the virtualized list + */ + const shouldShowReturnButton = + isStringEmptyOrNil(filter) && + activeItemIndex > -1 && + isOutsideVisibleRange({ + range: rowVirtualizer.range, + itemIndex: activeItemIndex, + }) + + // scroll to the active item when the activeItemIndex changes + useEffect(() => { + if ( + activeItemIndex > -1 && + isOutsideVisibleRange({ + range: rowVirtualizer.range, + itemIndex: activeItemIndex, + }) + ) { + scrollToItem({ itemIndex: activeItemIndex, isSmoothScroll: false }) + } + }, [activeItemIndex]) + + const rows = rowVirtualizer.getVirtualItems() + const totalSize = rowVirtualizer.getTotalSize() + + return ( + <div + ref={elSourceList} + className={clsx( + 'flex flex-col w-full h-full text-sm text-neutral-600 dark:text-neutral-300', + className, + )} + style={{ contain: 'strict' }} + > + {withFilter && ( + <div className="p-1 w-full flex justify-between"> + <Input + className="w-full !m-0" + size={EnumSize.sm} + > + {({ className }) => ( + <Input.Textfield + className={clsx(className, 'w-full')} + value={filter} + placeholder="Filter items" + type="search" + onInput={(e: React.ChangeEvent<HTMLInputElement>) => { + setFilter(e.target.value) + }} + /> + )} + </Input> + {withCounter && ( + <div className="ml-1 px-3 bg-primary-10 text-primary-500 rounded-full text-xs flex items-center"> + {filtered.length} + </div> + )} + </div> + )} + <div className="w-full h-full relative p-1"> + {shouldShowReturnButton && ( + <Button + className="absolute left-[50%] translate-x-[-50%] -top-2 z-10 text-ellipsis !block overflow-hidden no-wrap 
max-w-[90%] !border-neutral-20 shadow-md !bg-theme !hover:bg-theme text-neutral-500 dark:text-neutral-300 !focus:ring-2 !focus:ring-theme-500 !focus:ring-offset-2 !focus:ring-offset-theme-50 !focus:ring-opacity-50 !focus:outline-none !focus:ring-offset-transparent !focus:ring-offset-0 !focus:ring" + onClick={() => scrollToItem({ itemIndex: activeItemIndex })} + size={EnumSize.sm} + variant={EnumVariant.Secondary} + > + Scroll to selected + </Button> + )} + <div + ref={scrollableAreaRef} + className="w-full h-full relative overflow-hidden overflow-y-auto hover:scrollbar scrollbar--horizontal scrollbar--vertical" + style={{ contain: 'strict' }} + > + <div + className="relative w-full" + style={{ height: totalSize > 0 ? `${totalSize}px` : '100%' }} + > + <ul + className="w-full absolute top-0 left-0" + style={{ transform: `translateY(${rows[0]?.start ?? 0}px)` }} + > + {isArrayEmpty(filtered) && ( + <li + key="not-found" + className="px-2 py-0.5 text-center whitespace-nowrap overflow-ellipsis overflow-hidden" + > + {filter.length > 0 ? 'No Results Found' : 'Empty List'} + </li> + )} + {rows.map(virtualItem => { + const item = filtered[virtualItem.index]! + const id = ensureString(item[keyId]) + const description = ensureString(item[keyDescription]) + const name = ensureString(item[keyName]) + const text = ensureString(types?.[id]) + + return ( + <li + key={virtualItem.key} + data-index={virtualItem.index} + ref={rowVirtualizer.measureElement} + className={clsx( + 'font-normal w-full', + disabled && 'cursor-not-allowed', + )} + tabIndex={id === filter ? -1 : 0} + > + {listItem?.({ + id, + to: `${to}/${id}`, + name, + description, + text, + disabled, + item: filtered[virtualItem.index]!, + })} + </li> + ) + })} + </ul> + </div> + </div> + </div> + </div> + ) +} + +function ensureString(value?: string | number): string { + return isNil(value) ? 
'' : String(value) +} diff --git a/vscode/react/src/components/sourceList/SourceListItem.tsx b/vscode/react/src/components/sourceList/SourceListItem.tsx new file mode 100644 index 0000000000..ad62221c3b --- /dev/null +++ b/vscode/react/src/components/sourceList/SourceListItem.tsx @@ -0,0 +1,67 @@ +import { isNotNil } from '@/utils/index' +import clsx from 'clsx' +import { NavLink } from 'react-router' +import { EnumVariant, type Variant } from '@/style/variants' + +export default function SourceListItem({ + name, + description, + to, + text, + variant, + disabled = false, + handleDelete, +}: { + name: string + description?: string + to: string + variant?: Variant + disabled?: boolean + text?: string + handleDelete?: () => void +}): JSX.Element { + function handleKeyUp(e: React.KeyboardEvent<HTMLAnchorElement>): void { + if (e.key === 'Delete' || e.key === 'Backspace') { + e.preventDefault() + e.stopPropagation() + + handleDelete?.() + } + } + + return ( + <NavLink + onKeyUp={handleKeyUp} + to={to} + className={({ isActive }: { isActive: boolean }) => + clsx( + 'block overflow-hidden px-2 py-1.5 rounded-md w-full font-semibold', + disabled && 'opacity-50 pointer-events-none', + isActive + ? variant === EnumVariant.Primary + ? 'text-primary-500 bg-primary-10' + : variant === EnumVariant.Danger + ? 
'text-danger-500 bg-danger-5' + : 'text-neutral-600 dark:text-neutral-100 bg-neutral-10' + : 'hover:bg-neutral-5 text-neutral-500 dark:text-neutral-400', + ) + } + > + <div className="flex items-center"> + <span className="whitespace-nowrap overflow-ellipsis overflow-hidden min-w-10"> + {name} + </span> + {isNotNil(text) && ( + <span className=" ml-2 px-2 rounded-md leading-0 text-[0.5rem] bg-neutral-10 text-neutral-700 dark:text-neutral-200"> + {text} + </span> + )} + </div> + {isNotNil(description) && ( + <p className="text-xs overflow-hidden whitespace-nowrap overflow-ellipsis text-neutral-300 dark:text-neutral-500"> + {description} + </p> + )} + </NavLink> + ) +} diff --git a/vscode/react/src/components/tablediff/Card.tsx b/vscode/react/src/components/tablediff/Card.tsx new file mode 100644 index 0000000000..d2f4d833c2 --- /dev/null +++ b/vscode/react/src/components/tablediff/Card.tsx @@ -0,0 +1,51 @@ +import { type ReactNode } from 'react' +import { twColors, twMerge } from './tailwind-utils' + +interface CardProps { + children: ReactNode + className?: string +} + +export function Card({ children, className }: CardProps) { + return ( + <div + className={twMerge( + 'rounded-xl shadow-sm border overflow-hidden', + twColors.bgEditor, + twColors.borderNeutral100, + className, + )} + > + {children} + </div> + ) +} + +interface CardHeaderProps { + children: ReactNode + className?: string +} + +export function CardHeader({ children, className }: CardHeaderProps) { + return ( + <div + className={twMerge( + 'px-6 py-4 border-b', + twColors.bgNeutral10, + twColors.borderNeutral100, + className, + )} + > + {children} + </div> + ) +} + +interface CardContentProps { + children: ReactNode + className?: string +} + +export function CardContent({ children, className }: CardContentProps) { + return <div className={twMerge('px-6 py-4', className)}>{children}</div> +} diff --git a/vscode/react/src/components/tablediff/ColumnStatsSection.tsx 
b/vscode/react/src/components/tablediff/ColumnStatsSection.tsx new file mode 100644 index 0000000000..6b65318864 --- /dev/null +++ b/vscode/react/src/components/tablediff/ColumnStatsSection.tsx @@ -0,0 +1,381 @@ +import { useState } from 'react' +import { type TableDiffData, type SampleValue } from './types' +import { twColors, twMerge } from './tailwind-utils' +import { Card } from './Card' +import { + ArrowsUpDownIcon, + ArrowsRightLeftIcon, +} from '@heroicons/react/24/outline' + +interface ColumnStatsSectionProps { + columnStats: TableDiffData['row_diff']['column_stats'] +} + +interface StatHeaderProps { + stat: string +} + +const StatHeader = ({ stat }: StatHeaderProps) => ( + <th + key={stat} + className={twMerge( + 'text-left py-3 px-4 font-medium text-sm', + twColors.textForeground, + )} + title={stat} + > + {stat} + </th> +) + +interface StatCellProps { + value: SampleValue +} + +const StatCell = ({ value }: StatCellProps) => ( + <td + className={twMerge('py-3 px-4 font-mono text-sm', twColors.textMuted)} + title={String(value)} + > + {typeof value === 'number' ? value.toFixed(1) : String(value)} + </td> +) + +interface ColumnStatRowProps { + columnName: string + statsValue: TableDiffData['row_diff']['column_stats'][string] +} + +const ColumnStatRow = ({ columnName, statsValue }: ColumnStatRowProps) => ( + <tr + className={twMerge( + 'transition-colors border-b', + twColors.borderNeutral100, + twColors.bgHover, + )} + > + <td + className={twMerge( + 'py-3 px-4 font-medium text-sm', + twColors.textForeground, + )} + title={columnName} + > + {columnName} + </td> + {statsValue && typeof statsValue === 'object' + ? 
Object.values(statsValue as Record<string, SampleValue>).map( + (value, idx) => ( + <StatCell + key={idx} + value={value} + /> + ), + ) + : [ + <StatCell + key="single-value" + value={statsValue} + />, + ]} + </tr> +) + +export function ColumnStatsSection({ columnStats }: ColumnStatsSectionProps) { + const [isVertical, setIsVertical] = useState(false) + + if (Object.keys(columnStats || {}).length === 0) { + return null + } + + // Get the first stats object to determine the column headers + const firstStatsValue = Object.values(columnStats)[0] + const statKeys = + firstStatsValue && typeof firstStatsValue === 'object' + ? Object.keys(firstStatsValue as Record<string, SampleValue>) + : [] + + return ( + <div className="grid grid-cols-1 gap-4"> + {/* Statistics Table Card */} + <Card className="overflow-hidden"> + {/* Toggle Button */} + <div + className={twMerge( + 'flex justify-end p-2 border-b', + twColors.borderNeutral100, + )} + > + <button + onClick={() => setIsVertical(!isVertical)} + className={twMerge( + 'flex items-center gap-1 px-2 py-1 text-xs rounded transition-colors', + twColors.bgHover, + twColors.textMuted, + )} + title={`Switch to ${isVertical ? 'horizontal' : 'vertical'} layout`} + > + {isVertical ? ( + <> + <ArrowsRightLeftIcon className="w-3 h-3" /> + Horizontal + </> + ) : ( + <> + <ArrowsUpDownIcon className="w-3 h-3" /> + Vertical + </> + )} + </button> + </div> + + <div className="overflow-auto max-h-96"> + {isVertical ? 
( + // Vertical layout: Each stat as a separate row + <table className="w-full"> + <thead + className={twMerge('sticky top-0 z-10', twColors.bgNeutral10)} + > + <tr className={twMerge('border-b', twColors.borderNeutral100)}> + <th + className={twMerge( + 'text-left py-3 px-4 font-medium text-sm', + twColors.textForeground, + )} + > + Column + </th> + {Object.keys(columnStats).map(col => ( + <th + key={col} + className={twMerge( + 'text-left py-3 px-4 font-medium text-sm', + twColors.textForeground, + )} + > + {col} + </th> + ))} + </tr> + </thead> + <tbody> + {statKeys.map(stat => ( + <tr + key={stat} + className={twMerge( + 'transition-colors border-b', + twColors.borderNeutral100, + twColors.bgHover, + )} + > + <td + className={twMerge( + 'py-3 px-4 font-medium text-sm', + twColors.textForeground, + )} + > + {stat} + </td> + {Object.entries(columnStats).map(([col, statsValue]) => ( + <StatCell + key={col} + value={ + statsValue && typeof statsValue === 'object' + ? (statsValue as Record<string, SampleValue>)[stat] + : statsValue + } + /> + ))} + </tr> + ))} + </tbody> + </table> + ) : ( + // Horizontal layout: Original layout + <table className="w-full"> + <thead + className={twMerge('sticky top-0 z-10', twColors.bgNeutral10)} + > + <tr className={twMerge('border-b', twColors.borderNeutral100)}> + <th + className={twMerge( + 'text-left py-3 px-4 font-medium text-sm', + twColors.textForeground, + )} + > + Column + </th> + {statKeys.map(stat => ( + <StatHeader + key={stat} + stat={stat} + /> + ))} + </tr> + </thead> + <tbody> + {Object.entries(columnStats).map(([col, statsValue]) => ( + <ColumnStatRow + key={col} + columnName={col} + statsValue={statsValue} + /> + ))} + </tbody> + </table> + )} + </div> + </Card> + + {/* Summary Cards */} + <div className="grid grid-cols-1 md:grid-cols-3 gap-4"> + {(() => { + let percentages: { column: string; percentage: number }[] = [] + + if (columnStats && typeof columnStats === 'object') { + if ( + 'pct_match' in columnStats 
&& + typeof columnStats.pct_match === 'object' && + columnStats.pct_match !== null + ) { + const pctMatchData = columnStats.pct_match as Record< + string, + number + > + percentages = Object.entries(pctMatchData) + .map(([col, value]) => ({ + column: col, + percentage: Number(value) || 0, + })) + .filter(item => !isNaN(item.percentage)) + } else { + percentages = Object.entries(columnStats) + .map(([col, stats]) => { + if (!stats || typeof stats !== 'object') return null + + const statsObj = stats as Record<string, number> + const pctMatch = + statsObj.pct_match || + statsObj.match_pct || + statsObj.percentage || + 0 + + return { column: col, percentage: Number(pctMatch) } + }) + .filter( + (item): item is { column: string; percentage: number } => + item !== null && + !isNaN(item.percentage) && + item.column !== 'pct_match', + ) + } + } + + const validPercentages = percentages.map(p => p.percentage) + const highest = + percentages.length > 0 + ? percentages.find( + p => p.percentage === Math.max(...validPercentages), + ) + : null + const lowest = + percentages.length > 0 + ? percentages.find( + p => p.percentage === Math.min(...validPercentages), + ) + : null + const average = + validPercentages.length > 0 + ? validPercentages.reduce((a, b) => a + b, 0) / + validPercentages.length + : 0 + + return ( + <> + <Card className="overflow-hidden"> + <div className={twMerge('h-1', twColors.bgSuccess)} /> + <div className="text-center py-2 px-2"> + <div + className={twMerge( + 'text-lg font-light mb-0.5', + twColors.textSuccess500, + )} + > + {highest ? `${highest.percentage.toFixed(1)}%` : 'N/A'} + </div> + <div + className={twMerge( + 'text-xs font-medium', + twColors.textMuted, + )} + > + Highest Match + </div> + <div + className={twMerge('text-xs mt-0.5', twColors.textMuted)} + > + {highest ? 
highest.column : 'No data'} + </div> + </div> + </Card> + + <Card className="overflow-hidden"> + <div className={twMerge('h-1', twColors.bgPrimary)} /> + <div className="text-center py-2 px-2"> + <div + className={twMerge( + 'text-lg font-light mb-0.5', + twColors.textPrimary, + )} + > + {average > 0 ? `${average.toFixed(1)}%` : 'N/A'} + </div> + <div + className={twMerge( + 'text-xs font-medium', + twColors.textMuted, + )} + > + Average Match + </div> + <div + className={twMerge('text-xs mt-0.5', twColors.textMuted)} + > + Across {validPercentages.length} columns + </div> + </div> + </Card> + + <Card className="overflow-hidden"> + <div className={twMerge('h-1', twColors.bgDanger)} /> + <div className="text-center py-2 px-2"> + <div + className={twMerge( + 'text-lg font-light mb-0.5', + twColors.textDanger500, + )} + > + {lowest ? `${lowest.percentage.toFixed(1)}%` : 'N/A'} + </div> + <div + className={twMerge( + 'text-xs font-medium', + twColors.textMuted, + )} + > + Lowest Match + </div> + <div + className={twMerge('text-xs mt-0.5', twColors.textMuted)} + > + {lowest ? 
lowest.column : 'No data'} + </div> + </div> + </Card> + </> + ) + })()} + </div> + </div> + ) +} diff --git a/vscode/react/src/components/tablediff/ContentSections.tsx b/vscode/react/src/components/tablediff/ContentSections.tsx new file mode 100644 index 0000000000..dee8019e0c --- /dev/null +++ b/vscode/react/src/components/tablediff/ContentSections.tsx @@ -0,0 +1,78 @@ +import { SectionCard } from './SectionCard' +import { SchemaDiffSection } from './SchemaDiffSection' +import { RowStatsSection } from './RowStatsSection' +import { ColumnStatsSection } from './ColumnStatsSection' +import { SampleDataSection } from './SampleDataSection' +import { usePersistedState } from './hooks' +import type { TableDiffData, ExpandedSections } from './types' + +interface ContentSectionsProps { + data: TableDiffData +} + +export function ContentSections({ data }: ContentSectionsProps) { + const [expanded, setExpanded] = usePersistedState<ExpandedSections>( + 'tableDiffExpanded', + { + schema: true, + rows: true, + columnStats: false, + sampleData: false, + }, + ) + + const toggle = (section: keyof ExpandedSections) => { + setExpanded(prev => ({ + ...prev, + [section]: !prev[section], + })) + } + + const { schema_diff, row_diff } = data + + return ( + <div className="overflow-y-auto h-[calc(100%-200px)]"> + {/* Schema Changes */} + <SectionCard + id="schema" + title="Schema Changes" + expanded={expanded.schema} + onToggle={() => toggle('schema')} + > + <SchemaDiffSection schemaDiff={schema_diff} /> + </SectionCard> + + {/* Row Statistics */} + <SectionCard + id="rows" + title="Row Statistics" + expanded={expanded.rows} + onToggle={() => toggle('rows')} + > + <RowStatsSection rowDiff={row_diff} /> + </SectionCard> + + {/* Column Statistics */} + <SectionCard + id="columnStats" + title="Column Statistics" + expanded={expanded.columnStats} + onToggle={() => toggle('columnStats')} + > + <ColumnStatsSection columnStats={row_diff.column_stats} /> + </SectionCard> + + {/* Sample Data */} 
+ {row_diff.processed_sample_data && ( + <SectionCard + id="sampleData" + title="Data Differences" + expanded={expanded.sampleData} + onToggle={() => toggle('sampleData')} + > + <SampleDataSection rowDiff={row_diff} /> + </SectionCard> + )} + </div> + ) +} diff --git a/vscode/react/src/components/tablediff/HeaderCard.tsx b/vscode/react/src/components/tablediff/HeaderCard.tsx new file mode 100644 index 0000000000..5c3b3872a6 --- /dev/null +++ b/vscode/react/src/components/tablediff/HeaderCard.tsx @@ -0,0 +1,201 @@ +import { Card, CardContent } from './Card' +import { DiffConfig } from './Legend' +import { twColors, twMerge } from './tailwind-utils' +import type { TableDiffData } from './types' + +interface HeaderCardProps { + schemaDiff: TableDiffData['schema_diff'] + rowDiff: TableDiffData['row_diff'] + limit: number + whereClause: string + onColumns: string + on: string[][] | undefined + where: string | undefined + isRerunning: boolean + onLimitChange: (limit: number) => void + onWhereClauseChange: (where: string) => void + onOnColumnsChange: (on: string) => void + onRerun: () => void + hasChanges: boolean +} + +export function HeaderCard({ + schemaDiff, + rowDiff, + limit, + whereClause, + onColumns, + on, + where, + isRerunning, + onLimitChange, + onWhereClauseChange, + onOnColumnsChange, + onRerun, + hasChanges, +}: HeaderCardProps) { + const formatPercentage = (v: number) => `${v.toFixed(1)}%` + const formatCount = (v: number) => v.toLocaleString() + + return ( + <Card className="mb-4"> + <CardContent className="py-4"> + <div className="flex items-center gap-3 flex-wrap"> + <span className={twMerge('text-sm font-medium', twColors.textSource)}> + Source: + </span> + <code + className={twMerge( + 'px-3 py-1.5 rounded-lg text-sm whitespace-nowrap', + twColors.bgSource, + 'text-white', + )} + > + {schemaDiff.source} + </code> + <span + className={twMerge('text-sm font-medium ml-4', twColors.textTarget)} + > + Target: + </span> + <code + className={twMerge( + 'px-3 
py-1.5 rounded-lg text-sm whitespace-nowrap', + twColors.bgTarget, + 'text-white', + )} + > + {schemaDiff.target} + </code> + </div> + <div className="flex items-center gap-6 text-sm flex-wrap mt-3"> + <div className="flex items-center gap-2"> + <span className={twColors.textMuted}>Source rows:</span> + <span className={twMerge('font-semibold', twColors.textSource)}> + {formatCount(rowDiff.source_count)} + </span> + </div> + <div className="flex items-center gap-2"> + <span className={twColors.textMuted}>Target rows:</span> + <span className={twMerge('font-semibold', twColors.textTarget)}> + {formatCount(rowDiff.target_count)} + </span> + </div> + <div className="flex items-center gap-2"> + <span className={twColors.textMuted}>Change:</span> + <span + className={twMerge( + 'font-semibold', + rowDiff.count_pct_change > 0 + ? twColors.textSuccess500 + : rowDiff.count_pct_change < 0 + ? twColors.textDanger500 + : twColors.textMuted, + )} + > + {formatPercentage(rowDiff.count_pct_change)} + </span> + </div> + </div> + <div className="mt-4 space-y-3"> + <div className="flex flex-wrap gap-3 items-end"> + <div className="flex flex-col gap-1"> + <label + className={twMerge('text-xs font-medium', twColors.textMuted)} + > + Limit: + </label> + <input + type="number" + value={limit} + onChange={e => + onLimitChange(Math.max(1, parseInt(e.target.value) || 1)) + } + className={twMerge( + 'w-20 px-2 py-1 text-sm rounded border', + 'bg-[var(--vscode-input-background)]', + 'border-[var(--vscode-input-border)]', + 'text-[var(--vscode-input-foreground)]', + 'focus:outline-none focus:ring-1 focus:ring-[var(--vscode-focusBorder)]', + )} + min="1" + max="10000" + disabled={isRerunning} + /> + </div> + <div className="flex flex-col gap-1 flex-1 max-w-sm"> + <label + className={twMerge('text-xs font-medium', twColors.textMuted)} + > + Where: + </label> + <input + type="text" + value={whereClause} + onChange={e => onWhereClauseChange(e.target.value)} + placeholder="e.g. 
created_at > '2024-01-01'" + className={twMerge( + 'px-2 py-1 text-sm rounded border', + 'bg-[var(--vscode-input-background)]', + 'border-[var(--vscode-input-border)]', + 'text-[var(--vscode-input-foreground)]', + 'placeholder:text-[var(--vscode-input-placeholderForeground)]', + 'focus:outline-none focus:ring-1 focus:ring-[var(--vscode-focusBorder)]', + )} + disabled={isRerunning} + /> + </div> + <div className="flex flex-col gap-1 flex-1 max-w-xs"> + <label + className={twMerge('text-xs font-medium', twColors.textMuted)} + > + On (grain): + </label> + <input + type="text" + value={onColumns} + onChange={e => onOnColumnsChange(e.target.value)} + placeholder="e.g. s.id = t.id AND s.date = t.date" + className={twMerge( + 'px-2 py-1 text-sm rounded border', + 'bg-[var(--vscode-input-background)]', + 'border-[var(--vscode-input-border)]', + 'text-[var(--vscode-input-foreground)]', + 'placeholder:text-[var(--vscode-input-placeholderForeground)]', + 'focus:outline-none focus:ring-1 focus:ring-[var(--vscode-focusBorder)]', + )} + disabled={isRerunning} + /> + </div> + <button + onClick={onRerun} + disabled={isRerunning || !hasChanges} + className={twMerge( + 'px-4 py-1.5 text-sm rounded font-medium transition-colors', + 'bg-[var(--vscode-button-background)]', + 'text-[var(--vscode-button-foreground)]', + 'hover:bg-[var(--vscode-button-hoverBackground)]', + 'disabled:opacity-50 disabled:cursor-not-allowed', + 'focus:outline-none focus:ring-1 focus:ring-[var(--vscode-focusBorder)]', + hasChanges && + !isRerunning && + 'bg-[var(--vscode-button-secondaryBackground)] ring-1 ring-[var(--vscode-button-secondaryForeground)]', + )} + > + {isRerunning ? 'Running...' 
: 'Rerun'} + </button> + </div> + <div className="flex justify-end"> + {on && ( + <DiffConfig + on={on} + limit={limit} + where={where} + /> + )} + </div> + </div> + </CardContent> + </Card> + ) +} diff --git a/vscode/react/src/components/tablediff/Legend.tsx b/vscode/react/src/components/tablediff/Legend.tsx new file mode 100644 index 0000000000..274db60625 --- /dev/null +++ b/vscode/react/src/components/tablediff/Legend.tsx @@ -0,0 +1,59 @@ +import { twColors, twMerge } from './tailwind-utils' + +interface DiffConfigProps { + on: string[] | string[][] + limit?: number + where?: string +} + +interface ConfigItemProps { + label: string + value: string | number +} + +function ConfigItem({ label, value }: ConfigItemProps) { + return ( + <div className="flex items-center gap-1"> + <span className={twMerge('text-xs font-medium', twColors.textMuted)}> + {label}: + </span> + <code + className={twMerge( + 'text-xs px-1 py-0.5 rounded', + twColors.bgNeutral10, + twColors.textForeground, + )} + > + {value} + </code> + </div> + ) +} + +export function DiffConfig({ on, limit, where }: DiffConfigProps) { + // Handle the grain (join keys) + const grainColumns = Array.isArray(on[0]) + ? 
on.flat().filter((col, index, arr) => arr.indexOf(col) === index) // Remove duplicates from nested array + : (on as string[]) + + return ( + <div className="flex items-center gap-4 flex-wrap text-xs"> + <ConfigItem + label="Grain" + value={grainColumns.join(', ')} + /> + {limit && ( + <ConfigItem + label="Limit" + value={limit} + /> + )} + {where && ( + <ConfigItem + label="Where" + value={where} + /> + )} + </div> + ) +} diff --git a/vscode/react/src/components/tablediff/RerunController.tsx b/vscode/react/src/components/tablediff/RerunController.tsx new file mode 100644 index 0000000000..4d0e7885b2 --- /dev/null +++ b/vscode/react/src/components/tablediff/RerunController.tsx @@ -0,0 +1,168 @@ +import { useState, useEffect } from 'react' +import { callRpc } from '../../utils/rpc' +import type { TableDiffData, TableDiffParams } from './types' + +interface RerunControllerProps { + data: TableDiffData + onDataUpdate?: (data: TableDiffData) => void + children: (props: { + limit: number + whereClause: string + onColumns: string + isRerunning: boolean + hasChanges: boolean + setLimit: (limit: number) => void + setWhereClause: (where: string) => void + setOnColumns: (on: string) => void + handleRerun: () => void + }) => React.ReactNode +} + +export function RerunController({ + data, + onDataUpdate, + children, +}: RerunControllerProps) { + const [isRerunning, setIsRerunning] = useState(false) + const [limit, setLimit] = useState(data.limit || 20) + const [whereClause, setWhereClause] = useState(data.where || '') + const [onColumns, setOnColumns] = useState( + data.on?.map(([sCol, tCol]) => `s.${sCol} = t.${tCol}`).join(' AND ') || '', + ) + + // Update state when data changes + useEffect(() => { + setLimit(data.limit || 20) + setWhereClause(data.where || '') + setOnColumns( + data.on?.map(([sCol, tCol]) => `s.${sCol} = t.${tCol}`).join(' AND ') || + '', + ) + }, [data.limit, data.where, data.on]) + + // Helper function to parse on columns back to array format + const 
parseOnColumns = (onString: string): string[][] => { + if (!onString.trim()) return [] + + // Parse "s.id = t.id AND s.date = t.date" back to [["id", "id"], ["date", "date"]] + const conditions = onString.split(' AND ') + return conditions.map(condition => { + const match = condition.trim().match(/^s\.(\w+)\s*=\s*t\.(\w+)$/) + if (match) { + return [match[1], match[2]] + } + // Fallback for simple format + return [condition.trim(), condition.trim()] + }) + } + + const hasChanges = + limit !== (data.limit || 20) || + whereClause !== (data.where || '') || + onColumns !== + (data.on?.map(([sCol, tCol]) => `s.${sCol} = t.${tCol}`).join(' AND ') || + '') + + const handleRerun = async () => { + if (isRerunning || !hasChanges) return + + setIsRerunning(true) + try { + // Get the initial data to extract the model name and environment names + const initialDataResult = await callRpc('get_initial_data', {}) + if (!initialDataResult.ok || !initialDataResult.value?.selectedModel) { + console.error('Failed to get initial data for rerun') + return + } + + const params: TableDiffParams = { + source: initialDataResult.value.sourceEnvironment || 'prod', + target: initialDataResult.value.targetEnvironment || 'dev', + model_or_snapshot: initialDataResult.value.selectedModel.name, + limit: Math.min(Math.max(1, limit), 10000), // Ensure limit is within bounds + ...(whereClause.trim() && { where: whereClause.trim() }), + ...(onColumns.trim() && { on: onColumns.trim() }), + } + + console.log('Rerunning table diff with params:', params) + + try { + const timeoutPromise = new Promise((_, reject) => { + setTimeout(() => reject(new Error('Request timeout')), 30000) // 30 second timeout + }) + + const apiPromise = callRpc('api_query', { + method: 'GET', + url: '/api/table_diff', + params: params, + body: {}, + }) + + const result = (await Promise.race([apiPromise, timeoutPromise])) as any + + console.log('Table diff result:', result) + + if (result.ok && result.value) { + let newData: 
TableDiffData + if (result.value.data) { + newData = { + ...result.value.data, + limit, + where: whereClause, + on: parseOnColumns(onColumns), + } + } else { + newData = { + ...result.value, + limit, + where: whereClause, + on: parseOnColumns(onColumns), + } + } + + console.log('Updating table diff data:', newData) + onDataUpdate?.(newData) + } else { + console.error('API call failed:', result.error) + // Try to extract meaningful error message + let errorMessage = 'Unknown error' + if (typeof result.error === 'string') { + try { + const parsed = JSON.parse(result.error) + errorMessage = parsed.message || parsed.code || result.error + } catch { + errorMessage = result.error + } + } + console.error('Processed error message:', errorMessage) + setIsRerunning(false) + return + } + } catch (apiError) { + console.error('API call threw exception:', apiError) + setIsRerunning(false) + return + } + } catch (error) { + console.error('Error rerunning table diff:', error) + } finally { + setIsRerunning(false) + } + } + + return ( + <> + {children({ + limit, + whereClause, + onColumns, + isRerunning, + hasChanges, + setLimit, + setWhereClause, + setOnColumns, + handleRerun, + })} + </> + ) +} diff --git a/vscode/react/src/components/tablediff/RowStatsSection.tsx b/vscode/react/src/components/tablediff/RowStatsSection.tsx new file mode 100644 index 0000000000..076f7d95fe --- /dev/null +++ b/vscode/react/src/components/tablediff/RowStatsSection.tsx @@ -0,0 +1,102 @@ +import { type TableDiffData } from './types' +import { twColors, twMerge } from './tailwind-utils' +import { Card, CardContent } from './Card' + +interface RowStatsSectionProps { + rowDiff: TableDiffData['row_diff'] +} + +export function RowStatsSection({ rowDiff }: RowStatsSectionProps) { + const formatPercentage = (v: number) => `${(v * 100).toFixed(1)}%` + const formatCount = (v: number) => v.toLocaleString() + + const fullMatchCount = Math.round(rowDiff.stats.full_match_count || 0) + const joinCount = 
Math.round(rowDiff.stats.join_count || 0) + const partialMatchCount = joinCount - fullMatchCount + const sOnlyCount = Math.round(rowDiff.stats.s_only_count || 0) + const tOnlyCount = Math.round(rowDiff.stats.t_only_count || 0) + const totalRows = rowDiff.source_count + rowDiff.target_count + const fullMatchPct = totalRows > 0 ? (2 * fullMatchCount) / totalRows : 0 + + return ( + <div className="grid grid-cols-2 md:grid-cols-4 gap-3"> + {/* Full Match Card */} + <Card className="overflow-hidden"> + <div className={twMerge('h-1', twColors.bgSuccess)} /> + <CardContent className="text-center py-3 px-3"> + <div + className={twMerge( + 'text-2xl font-light mb-1', + twColors.textSuccess500, + )} + > + {formatCount(fullMatchCount)} + </div> + <div className={twMerge('text-xs font-medium', twColors.textMuted)}> + Full Matches + </div> + <div className={twMerge('text-xs mt-0.5', twColors.textMuted)}> + {formatPercentage(fullMatchPct)} + </div> + </CardContent> + </Card> + + {/* Partial Match Card */} + <Card className="overflow-hidden"> + <div className={twMerge('h-1', twColors.bgPrimary)} /> + <CardContent className="text-center py-3 px-3"> + <div + className={twMerge( + 'text-2xl font-light mb-1', + twColors.textPrimary, + )} + > + {formatCount(partialMatchCount)} + </div> + <div className={twMerge('text-xs font-medium', twColors.textMuted)}> + Partial Matches + </div> + <div className={twMerge('text-xs mt-0.5', twColors.textMuted)}> + {formatPercentage(partialMatchCount / totalRows)} + </div> + </CardContent> + </Card> + + {/* Source Only Card */} + <Card className="overflow-hidden"> + <div className={twMerge('h-1', twColors.bgSource)} /> + <CardContent className="text-center py-3 px-3"> + <div + className={twMerge('text-2xl font-light mb-1', twColors.textSource)} + > + {formatCount(sOnlyCount)} + </div> + <div className={twMerge('text-xs font-medium', twColors.textMuted)}> + Source Only + </div> + <div className={twMerge('text-xs mt-0.5', twColors.textMuted)}> + 
{formatPercentage(sOnlyCount / totalRows)} + </div> + </CardContent> + </Card> + + {/* Target Only Card */} + <Card className="overflow-hidden"> + <div className={twMerge('h-1', twColors.bgTarget)} /> + <CardContent className="text-center py-3 px-3"> + <div + className={twMerge('text-2xl font-light mb-1', twColors.textTarget)} + > + {formatCount(tOnlyCount)} + </div> + <div className={twMerge('text-xs font-medium', twColors.textMuted)}> + Target Only + </div> + <div className={twMerge('text-xs mt-0.5', twColors.textMuted)}> + {formatPercentage(tOnlyCount / totalRows)} + </div> + </CardContent> + </Card> + </div> + ) +} diff --git a/vscode/react/src/components/tablediff/SampleDataSection.tsx b/vscode/react/src/components/tablediff/SampleDataSection.tsx new file mode 100644 index 0000000000..00a7f07269 --- /dev/null +++ b/vscode/react/src/components/tablediff/SampleDataSection.tsx @@ -0,0 +1,440 @@ +import { useMemo } from 'react' +import { + type TableDiffData, + type SampleRow, + type SampleValue, + formatCellValue, +} from './types' +import { twColors, twMerge } from './tailwind-utils' + +interface SampleDataSectionProps { + rowDiff: TableDiffData['row_diff'] +} + +interface TableHeaderCellProps { + columnKey: string + sourceName?: SampleValue + targetName?: SampleValue +} + +const TableHeaderCell = ({ + columnKey, + sourceName, + targetName, +}: TableHeaderCellProps) => { + const isSource = columnKey === sourceName + const isTarget = columnKey === targetName + + return ( + <th + className={twMerge( + 'text-left py-2 px-2 font-medium whitespace-nowrap', + isSource && twColors.textSource, + isTarget && twColors.textTarget, + !isSource && !isTarget && twColors.textMuted, + )} + > + {columnKey} + </th> + ) +} + +interface DiffTableCellProps { + columnKey: string + value: SampleValue + sourceName?: SampleValue + targetName?: SampleValue + decimals?: number +} + +const DiffTableCell = ({ + columnKey, + value, + sourceName, + targetName, + decimals = 3, +}: 
DiffTableCellProps) => { + const isSource = columnKey === sourceName + const isTarget = columnKey === targetName + + return ( + <td + className={twMerge( + 'py-2 px-2 font-mono whitespace-nowrap', + isSource && twColors.textSource + ' bg-blue-500/10', + isTarget && twColors.textTarget + ' bg-green-500/10', + !isSource && !isTarget && twColors.textForeground, + )} + > + {formatCellValue(value, decimals)} + </td> + ) +} + +interface DiffTableRowProps { + row: SampleRow + sourceName?: SampleValue + targetName?: SampleValue + decimals?: number +} + +const DiffTableRow = ({ + row, + sourceName, + targetName, + decimals, +}: DiffTableRowProps) => ( + <tr + className={twMerge( + 'transition-colors', + twColors.borderPanel, + 'border-b', + twColors.bgHover, + )} + > + {Object.entries(row) + .filter(([key]) => !key.startsWith('__')) + .map(([key, cell]) => ( + <DiffTableCell + key={key} + columnKey={key} + value={cell} + sourceName={sourceName} + targetName={targetName} + decimals={decimals} + /> + ))} + </tr> +) + +interface SimpleTableCellProps { + value: SampleValue + colorClass: string + decimals?: number +} + +const SimpleTableCell = ({ + value, + colorClass, + decimals = 3, +}: SimpleTableCellProps) => ( + <td + className={twMerge( + 'py-3 px-4 font-mono whitespace-nowrap text-sm', + colorClass, + )} + > + {formatCellValue(value, decimals)} + </td> +) + +interface SimpleTableRowProps { + row: SampleRow + colorClass: string + borderColorClass: string + decimals?: number +} + +const SimpleTableRow = ({ + row, + colorClass, + borderColorClass, + decimals, +}: SimpleTableRowProps) => ( + <tr + className={twMerge( + 'transition-colors border-b', + borderColorClass, + twColors.bgHover, + )} + > + {Object.values(row).map((cell, cellIdx) => ( + <SimpleTableCell + key={cellIdx} + value={cell} + colorClass={colorClass} + decimals={decimals} + /> + ))} + </tr> +) + +interface ColumnDifferenceGroupProps { + columnName: string + rows: SampleRow[] + decimals: number +} + +const 
ColumnDifferenceGroup = ({ + columnName, + rows, + decimals, +}: ColumnDifferenceGroupProps) => { + if (!rows || rows.length === 0) return null + + const sourceName = rows[0].__source_name__ + const targetName = rows[0].__target_name__ + + return ( + <div className="mb-4"> + <div + className={twMerge( + 'flex items-center gap-2 text-sm font-medium mb-3', + twColors.textForeground, + )} + > + <span className="font-semibold">Column: {columnName}</span> + <span + className={twMerge( + 'text-xs px-2 py-0.5 rounded-full', + twColors.bgNeutral10, + )} + > + {rows.length} difference{rows.length > 1 ? 's' : ''} + </span> + </div> + <div + className={twMerge( + 'border rounded-lg overflow-hidden', + twColors.borderNeutral100, + )} + > + <div className="overflow-auto max-h-60"> + <table className="w-full"> + <thead + className={twMerge( + 'sticky top-0 z-10', + twColors.bgNeutral10, + 'border-b', + twColors.borderNeutral100, + )} + > + <tr> + {Object.keys(rows[0] || {}) + .filter(key => !key.startsWith('__')) + .map(key => ( + <TableHeaderCell + key={key} + columnKey={key} + sourceName={sourceName} + targetName={targetName} + /> + ))} + </tr> + </thead> + <tbody> + {rows.slice(0, 10).map((row, rowIdx) => ( + <DiffTableRow + key={rowIdx} + row={row} + sourceName={sourceName} + targetName={targetName} + decimals={decimals} + /> + ))} + </tbody> + </table> + </div> + {rows.length > 10 && ( + <p className={twMerge('text-xs mt-2', twColors.textMuted)}> + Showing first 10 of {rows.length} differing rows + </p> + )} + </div> + </div> + ) +} + +export function SampleDataSection({ rowDiff }: SampleDataSectionProps) { + const { processed_sample_data, decimals = 3 } = rowDiff + + if (!processed_sample_data) { + return ( + <div className="px-8 py-3"> + <p className={twMerge('text-sm', twColors.textMuted)}> + No processed sample data available + </p> + </div> + ) + } + + const { column_differences, source_only, target_only } = processed_sample_data + + // Group column differences by 
column name + const groupedDifferences = useMemo(() => { + const groups: Record<string, SampleRow[]> = {} + + column_differences.forEach((row: SampleRow) => { + const columnName = String(row.__column_name__ || 'unknown') + if (!groups[columnName]) { + groups[columnName] = [] + } + groups[columnName].push(row) + }) + + return groups + }, [column_differences]) + + return ( + <div className="px-8 py-3 space-y-6"> + {/* COMMON ROWS diff */} + <div> + <h4 + className={twMerge( + 'text-base font-semibold mb-4', + twColors.textPrimary, + )} + > + Common Rows + </h4> + {Object.keys(groupedDifferences).length > 0 ? ( + <div className="space-y-4"> + {Object.entries(groupedDifferences).map(([columnName, rows]) => ( + <ColumnDifferenceGroup + key={columnName} + columnName={columnName} + rows={rows} + decimals={decimals} + /> + ))} + </div> + ) : ( + <p className={twMerge('text-sm', twColors.textSuccess)}> + ✓ All joined rows match + </p> + )} + </div> + + {/* SOURCE ONLY & TARGET ONLY tables */} + {source_only && source_only.length > 0 && ( + <div> + <h4 + className={twMerge( + 'text-base font-semibold mb-4', + twColors.textSource, + )} + > + Source Only Rows + </h4> + <div + className={twMerge( + 'border-2 rounded-lg overflow-hidden', + twColors.borderSource, + )} + > + <div className="overflow-auto max-h-80"> + <table className="w-full"> + <thead + className={twMerge('sticky top-0 z-10', twColors.bgNeutral10)} + > + <tr + className={twMerge('border-b', twColors.borderNeutral100)} + > + {Object.keys(source_only[0] || {}).map(col => ( + <th + key={col} + className={twMerge( + 'text-left py-3 px-4 font-medium whitespace-nowrap text-sm', + twColors.textForeground, + )} + > + {col} + </th> + ))} + </tr> + </thead> + <tbody> + {source_only.slice(0, 10).map((row, rowIdx) => ( + <SimpleTableRow + key={rowIdx} + row={row} + colorClass={twColors.textSource} + borderColorClass={twColors.borderNeutral100} + decimals={decimals} + /> + ))} + </tbody> + </table> + </div> + 
{source_only.length > 10 && ( + <div + className={twMerge( + 'px-4 py-2 text-xs', + twColors.bgNeutral5, + twColors.textMuted, + )} + > + Showing first 10 of {source_only.length} rows + </div> + )} + </div> + </div> + )} + + {target_only && target_only.length > 0 && ( + <div> + <h4 + className={twMerge( + 'text-base font-semibold mb-4', + twColors.textTarget, + )} + > + Target Only Rows + </h4> + <div + className={twMerge( + 'border-2 rounded-lg overflow-hidden', + twColors.borderTarget, + )} + > + <div className="overflow-auto max-h-80"> + <table className="w-full"> + <thead + className={twMerge('sticky top-0 z-10', twColors.bgNeutral10)} + > + <tr + className={twMerge('border-b', twColors.borderNeutral100)} + > + {Object.keys(target_only[0] || {}).map(col => ( + <th + key={col} + className={twMerge( + 'text-left py-3 px-4 font-medium whitespace-nowrap text-sm', + twColors.textForeground, + )} + > + {col} + </th> + ))} + </tr> + </thead> + <tbody> + {target_only.slice(0, 10).map((row, rowIdx) => ( + <SimpleTableRow + key={rowIdx} + row={row} + colorClass={twColors.textTarget} + borderColorClass={twColors.borderNeutral100} + decimals={decimals} + /> + ))} + </tbody> + </table> + </div> + {target_only.length > 10 && ( + <div + className={twMerge( + 'px-4 py-2 text-xs', + twColors.bgNeutral5, + twColors.textMuted, + )} + > + Showing first 10 of {target_only.length} rows + </div> + )} + </div> + </div> + )} + </div> + ) +} diff --git a/vscode/react/src/components/tablediff/SchemaDiffSection.tsx b/vscode/react/src/components/tablediff/SchemaDiffSection.tsx new file mode 100644 index 0000000000..274ac1979c --- /dev/null +++ b/vscode/react/src/components/tablediff/SchemaDiffSection.tsx @@ -0,0 +1,122 @@ +import { useMemo } from 'react' +import { type TableDiffData } from './types' +import { twColors, twMerge } from './tailwind-utils' + +interface SchemaDiffSectionProps { + schemaDiff: TableDiffData['schema_diff'] +} + +interface SchemaChangeItemProps { + column: string + 
type: string + changeType: 'added' | 'removed' | 'modified' +} + +const SchemaChangeItem = ({ + column, + type, + changeType, +}: SchemaChangeItemProps) => { + const styleMap = { + added: { + bgClass: twColors.bgSuccess10, + borderClass: 'border-l-4 ' + twColors.borderSuccess500, + textClass: twColors.textSuccess500, + symbol: '+', + }, + removed: { + bgClass: twColors.bgDanger10, + borderClass: 'border-l-4 ' + twColors.borderDanger500, + textClass: twColors.textDanger500, + symbol: '-', + }, + modified: { + bgClass: twColors.bgPrimary10, + borderClass: 'border-l-4 ' + twColors.borderPrimary, + textClass: twColors.textPrimary, + symbol: '~', + }, + } + + const { bgClass, borderClass, textClass, symbol } = styleMap[changeType] + + return ( + <div + className={twMerge( + 'flex items-center gap-3 text-sm px-4 py-3 rounded-lg mb-2', + bgClass, + borderClass, + )} + > + <span className={twMerge('font-mono font-bold', textClass)}> + {symbol} + </span> + <span + className={twMerge('font-mono truncate', textClass)} + title={column} + > + {column} + </span> + <span className={twColors.textMuted}>:</span> + <span + className={twMerge('truncate', textClass)} + title={type} + > + {type} + </span> + </div> + ) +} + +export function SchemaDiffSection({ schemaDiff }: SchemaDiffSectionProps) { + const schemaHasChanges = useMemo(() => { + return ( + Object.keys(schemaDiff.added || {}).length > 0 || + Object.keys(schemaDiff.removed || {}).length > 0 || + Object.keys(schemaDiff.modified || {}).length > 0 + ) + }, [schemaDiff]) + + return ( + <div> + {!schemaHasChanges ? 
( + <div + className={twMerge( + 'text-sm px-4 py-3 rounded-lg', + twColors.bgSuccess10, + twColors.textSuccess500, + )} + > + ✓ Schemas are identical + </div> + ) : ( + <> + {Object.entries(schemaDiff.added).map(([col, type]) => ( + <SchemaChangeItem + key={col} + column={col} + type={type} + changeType="added" + /> + ))} + {Object.entries(schemaDiff.removed).map(([col, type]) => ( + <SchemaChangeItem + key={col} + column={col} + type={type} + changeType="removed" + /> + ))} + {Object.entries(schemaDiff.modified).map(([col, type]) => ( + <SchemaChangeItem + key={col} + column={col} + type={type} + changeType="modified" + /> + ))} + </> + )} + </div> + ) +} diff --git a/vscode/react/src/components/tablediff/SectionCard.tsx b/vscode/react/src/components/tablediff/SectionCard.tsx new file mode 100644 index 0000000000..af9d42ecd2 --- /dev/null +++ b/vscode/react/src/components/tablediff/SectionCard.tsx @@ -0,0 +1,63 @@ +import { type ReactNode } from 'react' +import { ChevronDownIcon, ChevronRightIcon } from '@heroicons/react/24/outline' +import { Card, CardHeader, CardContent } from './Card' +import { twColors, twMerge } from './tailwind-utils' + +interface Props { + id: string + title: string + children: ReactNode + expanded: boolean + onToggle: () => void + badge?: { text: string; color?: string } +} + +export function SectionCard({ + title, + children, + expanded, + onToggle, + badge, +}: Props) { + return ( + <Card className="mb-4"> + <CardHeader className="p-0"> + <button + onClick={onToggle} + className={twMerge( + 'w-full flex items-center justify-between px-3 py-1 text-left transition-colors', + twColors.bgHover, + twColors.textForeground, + )} + > + <div className="flex items-center gap-1.5"> + <span className="font-medium text-xs">{title}</span> + {badge && ( + <span + className={twMerge( + 'inline-block px-2 py-0.5 text-xs font-medium rounded-full', + badge.color || twColors.bgNeutral10, + )} + > + {badge.text} + </span> + )} + </div> + {expanded ? 
( + <ChevronDownIcon className="w-3 h-3" /> + ) : ( + <ChevronRightIcon className="w-3 h-3" /> + )} + </button> + </CardHeader> + <div + className={twMerge( + 'transition-all duration-200 ease-in-out overflow-hidden', + expanded ? 'max-h-[2000px]' : 'max-h-0', + )} + > + <CardContent>{children}</CardContent> + </div> + </Card> + ) +} diff --git a/vscode/react/src/components/tablediff/SectionToggle.tsx b/vscode/react/src/components/tablediff/SectionToggle.tsx new file mode 100644 index 0000000000..4066db22b7 --- /dev/null +++ b/vscode/react/src/components/tablediff/SectionToggle.tsx @@ -0,0 +1,47 @@ +import { type ReactNode } from 'react' +import { ChevronDownIcon, ChevronRightIcon } from '@heroicons/react/24/outline' +import { type ExpandedSections } from './types' +import { twColors, twMerge } from './tailwind-utils' + +interface SectionToggleProps { + id: keyof ExpandedSections + title: string + expanded: boolean + onToggle(): void + children: ReactNode +} + +export function SectionToggle({ + title, + expanded, + onToggle, + children, +}: SectionToggleProps) { + return ( + <div className={twMerge('border-b', twColors.borderPanel)}> + <button + onClick={onToggle} + className={twMerge( + 'w-full px-4 py-2 flex items-center text-left select-none transition-colors', + twColors.textForeground, + twColors.bgHover, + )} + > + {expanded ? ( + <ChevronDownIcon className="w-4 h-4 mr-2 shrink-0 transition-transform" /> + ) : ( + <ChevronRightIcon className="w-4 h-4 mr-2 shrink-0 transition-transform" /> + )} + <span className="font-medium flex-1">{title}</span> + </button> + <div + className={twMerge( + 'overflow-hidden transition-all duration-200', + expanded ? 
'max-h-screen' : 'max-h-0', + )} + > + {children} + </div> + </div> + ) +} diff --git a/vscode/react/src/components/tablediff/TableDiff.tsx b/vscode/react/src/components/tablediff/TableDiff.tsx new file mode 100644 index 0000000000..8374f1bc73 --- /dev/null +++ b/vscode/react/src/components/tablediff/TableDiff.tsx @@ -0,0 +1,148 @@ +import { useState, useEffect } from 'react' +import LoadingStatus from '../loading/LoadingStatus' +import { TableDiffResults } from './TableDiffResults' +import { callRpc } from '../../utils/rpc' +import { type TableDiffData } from './types' + +interface ModelInfo { + name: string + fqn: string + description?: string +} + +export function TableDiff() { + const [selectedModel, setSelectedModel] = useState<ModelInfo | null>(null) + const [sourceEnvironment, setSourceEnvironment] = useState<string>('prod') + const [targetEnvironment, setTargetEnvironment] = useState<string>('dev') + const [tableDiffData, setTableDiffData] = useState<TableDiffData | null>(null) + const [isLoadingDiff] = useState(false) + const [diffError] = useState<string | null>(null) + const [hasInitialData, setHasInitialData] = useState(false) + + const handleDataUpdate = (newData: TableDiffData) => { + setTableDiffData(newData) + } + + // Load initial data on mount + useEffect(() => { + const loadInitialData = async () => { + try { + // Try to get initial data first (pre-selected from VSCode) + const initialDataResult = await callRpc('get_initial_data', {}) + if (initialDataResult.ok && initialDataResult.value) { + const data = initialDataResult.value + + // Set all initial state from pre-selected data + if (data.selectedModel) { + setSelectedModel(data.selectedModel) + } + if (data.sourceEnvironment) { + setSourceEnvironment(data.sourceEnvironment) + } + if (data.targetEnvironment) { + setTargetEnvironment(data.targetEnvironment) + } + + // Always mark as having initial data if we got a response from VSCode + setHasInitialData(true) + + if (data.tableDiffData) { + // 
Handle different response structures + let diffData: TableDiffData | null = null + + if (data.tableDiffData.data !== undefined) { + // Response has a nested data field + diffData = data.tableDiffData.data + } else if ( + data.tableDiffData && + typeof data.tableDiffData === 'object' && + 'schema_diff' in data.tableDiffData && + 'row_diff' in data.tableDiffData + ) { + // Response is the data directly + diffData = data.tableDiffData as TableDiffData + } + + setTableDiffData(diffData) + } + } + } catch (error) { + console.error('Error loading initial data:', error) + } + } + + loadInitialData() + }, []) + + // If we're still loading, show loading state + if (isLoadingDiff) { + return ( + <div className="h-[100vh] w-[100vw]"> + <LoadingStatus>Running table diff...</LoadingStatus> + </div> + ) + } + + // If we have initial data, handle all possible states + if (hasInitialData) { + // Show results if we have them + if (tableDiffData) { + return ( + <div className="h-[100vh] w-[100vw]"> + <TableDiffResults + data={tableDiffData} + onDataUpdate={handleDataUpdate} + /> + </div> + ) + } + + // Show error if there was one + if (diffError) { + return ( + <div className="h-[100vh] w-[100vw] flex items-center justify-center"> + <div className="text-red-400 text-center"> + <div className="text-lg font-semibold mb-2"> + Error running table diff + </div> + <div>{diffError}</div> + </div> + </div> + ) + } + + // If we have initial data but no results and no error, show appropriate message + return ( + <div className="h-[100vh] w-[100vw] flex items-center justify-center"> + <div className="text-neutral-400 text-center"> + <div className="text-lg font-semibold mb-2">No differences found</div> + <div> + The selected model "{selectedModel?.name}" has no differences + between <span className="text-blue-400">{sourceEnvironment}</span>{' '} + and <span className="text-green-400">{targetEnvironment}</span>{' '} + environments. 
+ </div> + </div> + </div> + ) + } + + // If we don't have initial data yet, show loading + if (!hasInitialData) { + return ( + <div className="h-[100vh] w-[100vw]"> + <LoadingStatus>Loading...</LoadingStatus> + </div> + ) + } + + // This should never happen with the new flow + return ( + <div className="h-[100vh] w-[100vw] flex items-center justify-center"> + <div className="text-neutral-400 text-center"> + <div className="text-lg font-semibold mb-2">Unexpected state</div> + <div>Please try running the table diff command again.</div> + </div> + </div> + ) +} diff --git a/vscode/react/src/components/tablediff/TableDiffResults.tsx b/vscode/react/src/components/tablediff/TableDiffResults.tsx new file mode 100644 index 0000000000..45e92b65d1 --- /dev/null +++ b/vscode/react/src/components/tablediff/TableDiffResults.tsx @@ -0,0 +1,64 @@ +import { HeaderCard } from './HeaderCard' +import { ContentSections } from './ContentSections' +import { RerunController } from './RerunController' +import { type TableDiffData } from './types' +import { twColors, twMerge } from './tailwind-utils' + +interface Props { + data: TableDiffData + onDataUpdate?: (data: TableDiffData) => void +} + +export function TableDiffResults({ data, onDataUpdate }: Props) { + if (!data) + return ( + <div className={twMerge('p-4', twColors.textForeground)}> + No data available + </div> + ) + + return ( + <RerunController + data={data} + onDataUpdate={onDataUpdate} + > + {({ + limit, + whereClause, + onColumns, + isRerunning, + hasChanges, + setLimit, + setWhereClause, + setOnColumns, + handleRerun, + }) => ( + <div + className={twMerge( + 'h-full w-full text-[13px] font-sans p-4', + 'bg-[var(--vscode-editor-background)]', + twColors.textForeground, + isRerunning && 'pointer-events-none opacity-75', + )} + > + <HeaderCard + schemaDiff={data.schema_diff} + rowDiff={data.row_diff} + limit={limit} + whereClause={whereClause} + onColumns={onColumns} + on={data.on} + where={data.where} + 
isRerunning={isRerunning} + onLimitChange={setLimit} + onWhereClauseChange={setWhereClause} + onOnColumnsChange={setOnColumns} + onRerun={handleRerun} + hasChanges={hasChanges} + /> + <ContentSections data={data} /> + </div> + )} + </RerunController> + ) +} diff --git a/vscode/react/src/components/tablediff/hooks.ts b/vscode/react/src/components/tablediff/hooks.ts new file mode 100644 index 0000000000..803b0b8a16 --- /dev/null +++ b/vscode/react/src/components/tablediff/hooks.ts @@ -0,0 +1,29 @@ +import { useState, useEffect } from 'react' + +/** + * Persist state in localStorage so the user's expand / collapse choices + * survive reloads and navigation in VS Code's WebView. + */ +export function usePersistedState<T>( + key: string, + initial: T, +): [T, React.Dispatch<React.SetStateAction<T>>] { + const [state, setState] = useState<T>(() => { + try { + const stored = localStorage.getItem(key) + return stored ? (JSON.parse(stored) as T) : initial + } catch { + return initial + } + }) + + useEffect(() => { + try { + localStorage.setItem(key, JSON.stringify(state)) + } catch { + /* noop */ + } + }, [key, state]) + + return [state, setState] +} diff --git a/vscode/react/src/components/tablediff/index.ts b/vscode/react/src/components/tablediff/index.ts new file mode 100644 index 0000000000..a5b7ea2776 --- /dev/null +++ b/vscode/react/src/components/tablediff/index.ts @@ -0,0 +1,15 @@ +// Main components +export { TableDiff } from './TableDiff' +export { TableDiffResults } from './TableDiffResults' + +// Section components +export { SectionToggle } from './SectionToggle' +export { SchemaDiffSection } from './SchemaDiffSection' +export { RowStatsSection } from './RowStatsSection' +export { ColumnStatsSection } from './ColumnStatsSection' +export { SampleDataSection } from './SampleDataSection' + +// Utilities +export { usePersistedState } from './hooks' +export { twColors, twMerge } from './tailwind-utils' +export * from './types' diff --git 
a/vscode/react/src/components/tablediff/tailwind-utils.ts b/vscode/react/src/components/tablediff/tailwind-utils.ts new file mode 100644 index 0000000000..182dc69c28 --- /dev/null +++ b/vscode/react/src/components/tablediff/tailwind-utils.ts @@ -0,0 +1,86 @@ +// Tailwind utility classes with CSS variables +export const twColors = { + // Text colors + textForeground: 'text-[var(--vscode-editor-foreground)]', + textInfo: 'text-[var(--vscode-testing-iconUnset)]', + textSuccess: 'text-[var(--vscode-testing-iconPassed)]', + textError: 'text-[var(--vscode-testing-iconFailed)]', + textWarning: 'text-[var(--vscode-testing-iconQueued)]', + textMuted: 'text-[var(--vscode-descriptionForeground)]', + textAccent: 'text-[var(--vscode-textLink-foreground)]', + textAdded: 'text-[var(--vscode-diffEditor-insertedTextForeground)]', + textRemoved: 'text-[var(--vscode-diffEditor-removedTextForeground)]', + textModified: 'text-[var(--vscode-diffEditor-modifiedTextForeground)]', + + // Source and target environment colors + textSource: 'text-[var(--vscode-debugIcon-continueForeground)]', + textTarget: 'text-[var(--vscode-debugIcon-startForeground)]', + textClass: 'text-[var(--vscode-symbolIcon-classForeground)]', + bgSource: 'bg-[var(--vscode-debugIcon-continueForeground)]', + bgTarget: 'bg-[var(--vscode-debugIcon-startForeground)]', + bgClass: 'bg-[var(--vscode-symbolIcon-classForeground)]', + borderSource: 'border-[var(--vscode-debugIcon-continueForeground)]', + borderTarget: 'border-[var(--vscode-debugIcon-startForeground)]', + borderClass: 'border-[var(--vscode-symbolIcon-classForeground)]', + + // Background colors + bgEditor: 'bg-[var(--vscode-editor-background)]', + bgInput: 'bg-[var(--vscode-input-background)]', + bgHover: 'hover:bg-[var(--vscode-list-hoverBackground)]', + bgInactiveSelection: 'bg-[var(--vscode-editor-inactiveSelectionBackground)]', + bgAdded: 'bg-[var(--vscode-diffEditor-insertedTextBackground)]', + bgRemoved: 
'bg-[var(--vscode-diffEditor-removedTextBackground)]', + bgModified: 'bg-[var(--vscode-diffEditor-modifiedTextBackground)]', + bgTestSuccess: 'bg-[var(--vscode-testing-iconPassed)]', + bgError: 'bg-[var(--vscode-testing-iconFailed)]', + bgWarning: 'bg-[var(--vscode-testing-iconQueued)]', + bgInfo: 'bg-[var(--vscode-testing-iconUnset)]', + + // Border colors + borderPanel: 'border-[var(--vscode-panel-border)]', + borderInfo: 'border-[var(--vscode-testing-iconUnset)]', + borderSuccess: 'border-[var(--vscode-testing-iconPassed)]', + borderError: 'border-[var(--vscode-diffEditor-removedTextForeground)]', + borderWarning: 'border-[var(--vscode-diffEditor-modifiedTextForeground)]', + borderAdded: 'border-[var(--vscode-diffEditor-insertedTextForeground)]', + borderRemoved: 'border-[var(--vscode-diffEditor-removedTextForeground)]', + borderModified: 'border-[var(--vscode-diffEditor-modifiedTextForeground)]', + + //These colors are similar to web UI + // Primary (blue) + textPrimary: 'text-[#3b82f6]', + bgPrimary10: 'bg-[#3b82f6]/10', + bgPrimary: 'bg-[#3b82f6]', + borderPrimary: 'border-[#3b82f6]', + + // Success (green) + textSuccess500: 'text-[#10b981]', + bgSuccess10: 'bg-[#10b981]/10', + bgSuccess: 'bg-[#10b981]', + borderSuccess500: 'border-[#10b981]', + + // Danger (red) + textDanger500: 'text-[#ef4444]', + bgDanger5: 'bg-[#ef4444]/5', + bgDanger10: 'bg-[#ef4444]/10', + bgDanger: 'bg-[#ef4444]', + borderDanger500: 'border-[#ef4444]', + + // Brand (purple) + textBrand: 'text-[#8b5cf6]', + bgBrand10: 'bg-[#8b5cf6]/10', + bgBrand: 'bg-[#8b5cf6]', + borderBrand500: 'border-[#8b5cf6]', + + // Neutral + bgNeutral5: 'bg-[var(--vscode-editor-inactiveSelectionBackground)]', + bgNeutral10: 'bg-[var(--vscode-list-hoverBackground)]', + textNeutral500: 'text-[var(--vscode-descriptionForeground)]', + textNeutral600: 'text-[var(--vscode-editor-foreground)]', + borderNeutral100: 'border-[var(--vscode-panel-border)]', +} + +// Helper function to combine conditional classes +export 
function twMerge(...classes: (string | false | undefined | null)[]) { + return classes.filter(Boolean).join(' ') +} diff --git a/vscode/react/src/components/tablediff/types.ts b/vscode/react/src/components/tablediff/types.ts new file mode 100644 index 0000000000..271828476b --- /dev/null +++ b/vscode/react/src/components/tablediff/types.ts @@ -0,0 +1,84 @@ +// Type for data values in samples - can be strings, numbers, booleans, or null +export type SampleValue = string | number | boolean | null + +// Type for row data in samples +export type SampleRow = Record<string, SampleValue> + +// Type for column statistics +export type ColumnStats = Record<string, number | string | null> + +export interface TableDiffData { + schema_diff: { + source: string + target: string + source_schema: Record<string, string> + target_schema: Record<string, string> + added: Record<string, string> + removed: Record<string, string> + modified: Record<string, string> + } + row_diff: { + source: string + target: string + stats: Record<string, number> + sample: Record<string, SampleValue[]> + joined_sample: Record<string, SampleValue[]> + s_sample: Record<string, SampleValue[]> + t_sample: Record<string, SampleValue[]> + column_stats: ColumnStats + source_count: number + target_count: number + count_pct_change: number + decimals: number + processed_sample_data?: { + column_differences: SampleRow[] + source_only: SampleRow[] + target_only: SampleRow[] + } + } + on: string[][] + limit?: number + where?: string +} + +export interface TableDiffParams { + source: string + target: string + model_or_snapshot: string + on?: string + where?: string + temp_schema?: string + limit?: number +} + +export interface ExpandedSections { + schema: boolean + rows: boolean + columnStats: boolean + sampleData: boolean +} + +export const themeColors = { + success: 'var(--vscode-testing-iconPassed, #22c55e)', + warning: 'var(--vscode-testing-iconQueued, #f59e0b)', + error: 'var(--vscode-testing-iconFailed, 
#ef4444)', + info: 'var(--vscode-testing-iconUnset, #3b82f6)', + addedText: 'var(--vscode-diffEditor-insertedTextForeground, #22c55e)', + removedText: 'var(--vscode-diffEditor-removedTextForeground, #ef4444)', + modifiedText: 'var(--vscode-diffEditor-modifiedTextForeground, #f59e0b)', + muted: 'var(--vscode-descriptionForeground)', + accent: 'var(--vscode-textLink-foreground)', + border: 'var(--vscode-panel-border)', +} + +// Helper utilities +export function cn(...classes: (string | false | undefined)[]) { + return classes.filter(Boolean).join(' ') +} + +export const formatCellValue = (cell: SampleValue, decimals = 3): string => { + if (cell == null) return 'null' + if (typeof cell === 'number') + return cell % 1 === 0 ? cell.toString() : cell.toFixed(decimals) + return String(cell) +} diff --git a/vscode/react/src/components/title/Title.tsx b/vscode/react/src/components/title/Title.tsx new file mode 100644 index 0000000000..29e4edd2c0 --- /dev/null +++ b/vscode/react/src/components/title/Title.tsx @@ -0,0 +1,50 @@ +import clsx from 'clsx' +import { + EnumSize, + EnumVariant, + type Variant, + type Size, +} from '@/style/variants' + +export default function Title({ + as = 'p', + size = EnumSize.sm, + variant = EnumVariant.Info, + text, + className, +}: { + text: string + as?: 'p' | 'span' | 'small' | 'h1' | 'h2' | 'h3' | 'h4' | 'h5' | 'h6' + size?: Size + variant?: Variant + className?: string +}): JSX.Element { + const Tag = as + return ( + <Tag + className={clsx( + 'font-bold whitespace-nowrap', + variant === EnumVariant.Primary && + 'text-primary-600 dark:text-primary-400', + variant === EnumVariant.Secondary && + 'text-secondary-600 dark:text-secondary-400', + variant === EnumVariant.Success && + 'text-success-600 dark:text-success-400', + variant === EnumVariant.Warning && + 'text-warning-600 dark:text-warning-400', + variant === EnumVariant.Danger && + 'text-danger-600 dark:text-danger-400', + variant === EnumVariant.Info && + 'text-neutral-600 
dark:text-neutral-400', + size === EnumSize.xs && 'text-xs', + size === EnumSize.sm && 'text-md', + size === EnumSize.md && 'text-lg', + size === EnumSize.lg && 'text-2xl', + size === EnumSize.xl && 'text-4xl', + className, + )} + > + {text} + </Tag> + ) +} diff --git a/vscode/react/src/domain/column.ts b/vscode/react/src/domain/column.ts new file mode 100644 index 0000000000..bd3f7dd9ed --- /dev/null +++ b/vscode/react/src/domain/column.ts @@ -0,0 +1,18 @@ +import { type Column as APIColumn } from '@/api/client' +import { type Branded } from '@bus/brand' + +export type ColumnName = Branded<string, 'ColumnName'> + +export type Column = { + name: ColumnName + type: string + description?: string +} + +export function fromAPIColumn(column: APIColumn): Column { + return { + name: column.name as ColumnName, + type: column.type, + description: column.description ?? undefined, + } +} diff --git a/vscode/react/src/domain/initial.ts b/vscode/react/src/domain/initial.ts new file mode 100644 index 0000000000..68b98cc471 --- /dev/null +++ b/vscode/react/src/domain/initial.ts @@ -0,0 +1,39 @@ +import { isNil, isNotNil, uid } from '@/utils/index' + +type Initial<T extends object> = T & { id?: string } +type InitialWithId<T extends object> = T & { id: string } + +export class ModelInitial<T extends object = any> { + private readonly _initial: InitialWithId<T> + + isModel = true + + constructor(initial?: Initial<T> | InitialWithId<T>) { + if (isNil(initial)) { + this._initial = Object.assign({ + id: uid(), + }) as InitialWithId<T> + } else { + this._initial = isNotNil(initial?.id) + ? (initial as InitialWithId<T>) + : new Proxy<InitialWithId<T>>( + Object.assign(initial ?? 
{}, { + id: uid(), + }), + { + set() { + throw new Error('Cannot change initial file') + }, + }, + ) + } + } + + get initial(): InitialWithId<T> { + return this._initial + } + + get id(): string { + return this.initial.id + } +} diff --git a/vscode/react/src/domain/lineage.ts b/vscode/react/src/domain/lineage.ts new file mode 100644 index 0000000000..499ff709a7 --- /dev/null +++ b/vscode/react/src/domain/lineage.ts @@ -0,0 +1,37 @@ +import { + type LineageColumn as ApiLineageColumn, + type ModelLineageApiLineageModelNameGet200, +} from '@/api/client' +import type { ModelEncodedFQN, ModelFQN } from '@/domain/models' +import type { ColumnName } from './column' + +export interface Lineage { + models: ModelEncodedFQN[] + columns?: Record<ColumnName, LineageColumn> +} + +export interface LineageColumn { + source?: string + expression?: string + models: { + [key: ModelEncodedFQN]: ColumnName[] + } +} + +export const toLineageColumn = (column: ApiLineageColumn): LineageColumn => { + return { + source: column.source ?? undefined, + expression: column.expression ?? undefined, + models: column.models as Record<ModelEncodedFQN, ColumnName[]>, + } +} + +export interface ModelLineage { + [key: ModelFQN]: ModelFQN[] +} + +export const toModelLineage = ( + lineage: ModelLineageApiLineageModelNameGet200, +): ModelLineage => { + return lineage as Record<ModelFQN, ModelFQN[]> +} diff --git a/vscode/react/src/domain/models.ts b/vscode/react/src/domain/models.ts new file mode 100644 index 0000000000..a54aaf1246 --- /dev/null +++ b/vscode/react/src/domain/models.ts @@ -0,0 +1,57 @@ +import type { Branded } from '@bus/brand' + +/** + * ModelName is a type that represents the name of a model. + */ +export type ModelName = Branded<string, 'ModelName'> + +/** + * ModelEncodedName is a type that represents the encoded name of a model. + */ +export type ModelEncodedName = Branded<string, 'ModelEncodedName'> + +/** + * ModelFQN is a type that represents the fully qualified name of a model. 
+ */ +export type ModelFQN = Branded<string, 'ModelFQN'> + +/** + * ModelEncodedFQN is a type that represents the encoded fully qualified name of a model. + */ +export type ModelEncodedFQN = Branded<string, 'ModelEncodedFQN'> + +/** + * ModelURI is a type that represents the URI of a model. + */ +export type ModelURI = Branded<string, 'ModelURI'> + +/** + * ModelEncodedURI is a type that represents the encoded URI of a model. + */ +export type ModelEncodedURI = Branded<string, 'ModelEncodedURI'> + +export function encode(fqn: ModelName): ModelEncodedName +export function encode(fqn: ModelURI): ModelEncodedURI +export function encode(fqn: ModelFQN): ModelEncodedFQN +export function encode(s: string): string { + return encodeURI(s) +} + +export function decode(fqn: ModelEncodedName): ModelName +export function decode(fqn: ModelEncodedURI): ModelURI +export function decode(fqn: ModelEncodedFQN): ModelFQN +export function decode(s: string): string { + return decodeURI(s) +} + +/** + * ModelPath is a type that represents the path of a model. + * A model path is relative to the project root. + */ +export type ModelPath = Branded<string, 'ModelPath'> + +/** + * ModelFullPath is a type that represents the full path of a model. + * A model full path is a fully qualified path to a model. 
+ */ +export type ModelFullPath = Branded<string, 'ModelFullPath'> diff --git a/vscode/react/src/domain/sqlmesh-model.ts b/vscode/react/src/domain/sqlmesh-model.ts new file mode 100644 index 0000000000..273da62a4d --- /dev/null +++ b/vscode/react/src/domain/sqlmesh-model.ts @@ -0,0 +1,162 @@ +import { + type ModelDetails, + type Model, + type ModelDescription, + type ModelSql, + ModelType, + type ModelDefaultCatalog, + type ModelDefinition, +} from '@/api/client' +import type { + ModelEncodedFQN, + ModelName, + ModelPath, + ModelEncodedName, + ModelFullPath, +} from '@/domain/models' +import { isArrayNotEmpty } from '@/utils/index' +import { ModelInitial } from './initial' +import type { Lineage } from './lineage' +import { fromAPIColumn, type Column } from '@/domain/column' + +export interface InitialSQLMeshModel + extends Omit<Model, 'name' | 'fqn' | 'path' | 'full_path'> { + name: ModelName + fqn: ModelEncodedFQN + path: ModelPath + full_path: ModelFullPath + lineage?: Record<ModelName, Lineage> +} + +export class ModelSQLMeshModel< + T extends InitialSQLMeshModel = InitialSQLMeshModel, +> extends ModelInitial<T> { + _details: ModelDetails = {} + _detailsIndex: string = '' + + name: ModelEncodedName + fqn: ModelEncodedFQN + path: ModelPath + full_path: ModelFullPath + dialect: string + type: ModelType + columns: Column[] + default_catalog?: ModelDefaultCatalog + description?: ModelDescription + sql?: ModelSql + definition?: ModelDefinition + hash: string + + constructor(initial?: T | ModelSQLMeshModel) { + super( + (initial as ModelSQLMeshModel<T>)?.isModel + ? (initial as ModelSQLMeshModel<T>).initial + : { + ...(initial as T), + dialect: initial?.dialect ?? 'Default', + columns: initial?.columns ?? [], + details: initial?.details ?? 
{}, + }, + ) + + this.name = encodeURI(this.initial.name) as ModelEncodedName + this.fqn = encodeURI(this.initial.fqn) as ModelEncodedFQN + this.default_catalog = this.initial.default_catalog + this.path = this.initial.path as ModelPath + this.full_path = this.initial.full_path as ModelFullPath + this.dialect = this.initial.dialect + this.description = this.initial.description + this.sql = this.initial.sql + this.definition = this.initial.definition + this.columns = this.initial.columns?.map(fromAPIColumn) ?? [] + this.type = this.initial.type + this.hash = this.initial.hash + this.details = this.initial.details ?? {} + } + + get defaultCatalog(): ModelDefaultCatalog | undefined { + return this.default_catalog + } + + get details(): ModelDetails { + return this._details + } + + set details(details: ModelDetails) { + const output = [] + + for (const value of Object.values(details)) { + if (isArrayNotEmpty(value)) { + value.forEach(v => { + output.push(...Object.values(v)) + }) + } else { + output.push(value) + } + } + + this._details = details + this._detailsIndex = output.join(' ') + } + + get index(): string { + return [ + this.displayName, + this.path, + this.type, + ...this.columns.map(column => Object.values(column)).flat(), + this._detailsIndex, + this.dialect, + this.description, + ] + .filter(Boolean) + .join(' ') + .toLowerCase() + } + + get isModelPython(): boolean { + return this.type === ModelType.python + } + + get isModelSQL(): boolean { + return this.type === ModelType.sql + } + + get isModelSeed(): boolean { + return this.type === ModelType.seed + } + + get isModelExternal(): boolean { + return this.type === ModelType.external + } + + get displayName(): string { + return decodeURI(this.name) + } + + update(initial: Partial<InitialSQLMeshModel> = {}): void { + for (const [key, value] of Object.entries(initial)) { + if (key === 'columns') { + this.columns = value as Column[] + } else if (key === 'details') { + this.details = value as ModelDetails + } 
else if (key === 'name') { + this.name = encodeURI(value as string) as ModelEncodedName + } else if (key === 'fqn') { + this.fqn = encodeURI(value as string) as ModelEncodedFQN + } else if (key === 'type') { + this.type = value as ModelType + } else if (key === 'default_catalog') { + this.default_catalog = value as ModelDefaultCatalog + } else if (key === 'description') { + this.description = value as ModelDescription + } else if (key === 'full_path') { + this.full_path = value as ModelFullPath + } else if (key === 'path') { + this.path = value as ModelPath + } else if (key in this) { + this[key as 'dialect' | 'sql'] = value as string + } + } + } +} diff --git a/vscode/react/src/hooks/eventBus.tsx b/vscode/react/src/hooks/eventBus.tsx new file mode 100644 index 0000000000..308886b6fc --- /dev/null +++ b/vscode/react/src/hooks/eventBus.tsx @@ -0,0 +1,86 @@ +// event-bus.tsx +import { + createContext, + useContext, + useMemo, + useRef, + useEffect, + type PropsWithChildren, +} from 'react' + +/** + * 1️⃣ List every event & its payload shape here + */ +export type EventMap = { + changeFocusedFile: { + fileUri: string + } + savedFile: { + fileUri: string + } +} + +/** + * 2️⃣ Generic bus API — strongly–typed by EventMap + */ +export type EventBus<M extends Record<string, any>> = { + emit<K extends keyof M>(type: K, payload: M[K]): void + on<K extends keyof M>(type: K, handler: (payload: M[K]) => void): () => void // returns “unsubscribe” +} + +function createEventBus<M extends Record<string, any>>(): EventBus<M> { + const listeners = new Map<keyof M, Set<(p: any) => void>>() + + return { + emit(type, payload) { + listeners.get(type)?.forEach(fn => fn(payload)) + }, + on(type, handler) { + let set = listeners.get(type) + if (!set) { + set = new Set() + listeners.set(type, set) + } + set.add(handler) + // remove listener on demand + return () => set!.delete(handler) + }, + } +} + +/** + * 3️⃣ React Context wrapper + */ +const EventBusContext = 
createContext<EventBus<EventMap> | null>(null) + +export const EventBusProvider = ({ children }: PropsWithChildren<{}>) => { + const bus = useMemo(() => createEventBus<EventMap>(), []) + return ( + <EventBusContext.Provider value={bus}>{children}</EventBusContext.Provider> + ) +} + +/** + * 4️⃣ Convenience hooks + */ +export function useEventBus() { + const bus = useContext(EventBusContext) + if (!bus) throw new Error('useEventBus must be inside <EventBusProvider>') + return bus +} + +export function useEvent<K extends keyof EventMap>( + type: K, + handler: (payload: EventMap[K]) => void, +) { + const bus = useEventBus() + // keep latest handler ref without resubscribing each render + const saved = useRef<typeof handler>(handler) + useEffect(() => { + saved.current = handler + }) + useEffect(() => { + const unsub = bus.on(type, payload => saved.current(payload)) + return unsub // unsubscribe on unmount + }, [bus, type]) +} diff --git a/vscode/react/src/hooks/vscode.ts b/vscode/react/src/hooks/vscode.ts new file mode 100644 index 0000000000..62365019a3 --- /dev/null +++ b/vscode/react/src/hooks/vscode.ts @@ -0,0 +1,9 @@ +import { sendVSCodeMessage } from '@/utils/vscodeapi' + +/** + * use this hook to send messages to the vscode extension + * + * when deving the extension, we use an iframe to load the react app + * so we need to send messages to the parent window + */ +export const useVSCode = () => sendVSCodeMessage diff --git a/vscode/react/src/main.tsx b/vscode/react/src/main.tsx new file mode 100644 index 0000000000..5e24fc648f --- /dev/null +++ b/vscode/react/src/main.tsx @@ -0,0 +1,43 @@ +import { StrictMode } from 'react' +import ReactDOM from 'react-dom/client' +import reportWebVitals from './reportWebVitals.ts' +import { EventBusProvider } from './hooks/eventBus.tsx' +import { TableDiffPage } from './pages/tablediff.tsx' +import { LineagePage } from './pages/lineage.tsx' + +// Detect panel type +declare global { + interface Window { + 
__SQLMESH_PANEL_TYPE__?: string + } +} + +const panelType = window.__SQLMESH_PANEL_TYPE__ || 'lineage' + +// component selector +function App() { + if (panelType === 'tablediff') { + return <TableDiffPage /> + } + + return <LineagePage /> +} + +// Render the app +const rootElement = document.getElementById('app') +if (rootElement && !rootElement.innerHTML) { + const root = ReactDOM.createRoot(rootElement) + + root.render( + <StrictMode> + <EventBusProvider> + <App /> + </EventBusProvider> + </StrictMode>, + ) +} + +// If you want to start measuring performance in your app, pass a function +// to log results (for example: reportWebVitals(console.log)) +// or send to an analytics endpoint. Learn more: https://bit.ly/CRA-vitals +reportWebVitals() diff --git a/vscode/react/src/pages/lineage.tsx b/vscode/react/src/pages/lineage.tsx new file mode 100644 index 0000000000..18925f28da --- /dev/null +++ b/vscode/react/src/pages/lineage.tsx @@ -0,0 +1,238 @@ +import '../App.css' +import { + QueryCache, + QueryClient, + QueryClientProvider, + useQueryClient, +} from '@tanstack/react-query' +import { useApiModels } from '@/api' +import LineageFlowProvider from '@/components/graph/context' +import { ModelLineage } from '@/components/graph/ModelLineage' +import { useVSCode } from '@/hooks/vscode' +import React, { useState } from 'react' +import { ModelSQLMeshModel } from '@/domain/sqlmesh-model' +import { useEventBus } from '@/hooks/eventBus' +import type { VSCodeEvent } from '@bus/callbacks' +import { URI } from 'vscode-uri' +import type { Model } from '@/api/client' +import { useRpc } from '@/utils/rpc' +import { + type ModelPath, + type ModelFullPath, + type ModelName, + type ModelEncodedFQN, +} from '@/domain/models' + +export function LineagePage() { + const { emit } = useEventBus() + + // Handle messages from VSCode extension + React.useEffect(() => { + const handleMessage = (event: MessageEvent) => { + // Ensure the message is from VSCode + if (event.data && event.data.key 
=== 'vscode_send') { + const payload: VSCodeEvent = event.data.payload + switch (payload.key) { + case 'changeFocusOnFile': + emit('changeFocusedFile', { fileUri: payload.payload.path }) + break + case 'savedFile': + emit('savedFile', { fileUri: payload.payload.fileUri }) + break + default: + console.error( + 'Unhandled message type in lineage page:', + payload.key, + ) + } + } + } + window.addEventListener('message', handleMessage) + return () => { + window.removeEventListener('message', handleMessage) + } + }, []) + + const client = new QueryClient({ + queryCache: new QueryCache({}), + defaultOptions: { + queries: { + networkMode: 'always', + refetchOnWindowFocus: false, + retry: false, + staleTime: Infinity, + }, + }, + }) + + return ( + <QueryClientProvider client={client}> + <Lineage /> + </QueryClientProvider> + ) +} + +function Lineage() { + const [selectedModel, setSelectedModel] = useState<string | undefined>( + undefined, + ) + const { on } = useEventBus() + const queryClient = useQueryClient() + + const { + data: models, + isLoading: isLoadingModels, + error: modelsError, + } = useApiModels() + const rpc = useRpc() + React.useEffect(() => { + const fetchFirstTimeModelIfNotSet = async ( + models: Model[], + ): Promise<string | undefined> => { + if (!Array.isArray(models)) { + return undefined + } + const activeFile = await rpc('get_active_file', {}) + // @ts-ignore + if (!activeFile.fileUri) { + return models[0].name + } + // @ts-ignore + const fileUri: string = activeFile.fileUri + const filePath = URI.file(fileUri).path + const model = models.find((m: Model) => { + if (!m.full_path) { + return false + } + return URI.file(m.full_path).path === filePath + }) + if (model) { + return model.name + } + return undefined + } + if (selectedModel === undefined && Array.isArray(models)) { + fetchFirstTimeModelIfNotSet(models).then(modelName => { + if (modelName && selectedModel === undefined) { + setSelectedModel(modelName) + } else { + 
setSelectedModel(models[0].name) + } + }) + } + }, [models, selectedModel]) + + const modelsRecord = + Array.isArray(models) && + models.reduce( + (acc, model) => { + acc[model.name] = model + return acc + }, + {} as Record<string, Model>, + ) + + React.useEffect(() => { + const handleChangeFocusedFile = (fileUri: { fileUri: string }) => { + const full_path = URI.parse(fileUri.fileUri).path + const model = Object.values(modelsRecord).find( + m => URI.file(m.full_path).path === full_path, + ) + if (model) { + setSelectedModel(model.name) + } + } + + const handleSavedFile = () => { + queryClient.invalidateQueries() + } + + const offChangeFocusedFile = on( + 'changeFocusedFile', + handleChangeFocusedFile, + ) + const offSavedFile = on('savedFile', handleSavedFile) + + // If your event bus returns an "off" function, call it on cleanup + return () => { + if (offChangeFocusedFile) offChangeFocusedFile() + if (offSavedFile) offSavedFile() + } + }, [on, queryClient, modelsRecord]) + + if (modelsError) { + return <div>Error: {modelsError.message}</div> + } + + if ( + isLoadingModels || + models === undefined || + modelsRecord === false || + selectedModel === undefined + ) { + return <div>Loading models...</div> + } + if (!Array.isArray(models)) { + return <div>Error: Models data is not in the expected format</div> + } + + return ( + <LineageComponentFromWeb + selectedModel={selectedModel} + models={modelsRecord} + /> + ) +} + +export function LineageComponentFromWeb({ + selectedModel, + models, +}: { + selectedModel: string + models: Record<string, Model> +}): JSX.Element { + const vscode = useVSCode() + function handleClickModel(id: string): void { + const decodedId = decodeURIComponent(id) + const model = Object.values(models).find(m => m.fqn === decodedId) + if (!model) { + throw new Error('Model not found') + } + if (!model.full_path) { + return + } + vscode('openFile', { uri: URI.file(model.full_path).toString() }) + } + + function handleError(error: any): void { + 
console.error(error) + } + + const model = models[selectedModel] + if (!model) { + return <div>Error: Model not found</div> + } + + const sqlmModel = new ModelSQLMeshModel() + sqlmModel.update({ + ...model, + name: model.name as ModelName, + fqn: model.fqn as ModelEncodedFQN, + path: model.path as ModelPath, + full_path: model.full_path as ModelFullPath, + }) + + return ( + <div className="h-[100vh] w-[100vw]"> + <LineageFlowProvider + showColumns={true} + handleClickModel={handleClickModel} + handleError={handleError} + models={models} + showControls={false} + > + <ModelLineage model={sqlmModel} /> + </LineageFlowProvider> + </div> + ) +} diff --git a/vscode/react/src/pages/tablediff.tsx b/vscode/react/src/pages/tablediff.tsx new file mode 100644 index 0000000000..47e3b4ed58 --- /dev/null +++ b/vscode/react/src/pages/tablediff.tsx @@ -0,0 +1,22 @@ +import '../App.css' +import { QueryClient, QueryClientProvider } from '@tanstack/react-query' +import { TableDiff } from '../components/tablediff/TableDiff' + +export function TableDiffPage() { + const client = new QueryClient({ + defaultOptions: { + queries: { + networkMode: 'always', + refetchOnWindowFocus: false, + retry: false, + staleTime: Infinity, + }, + }, + }) + + return ( + <QueryClientProvider client={client}> + <TableDiff /> + </QueryClientProvider> + ) +} diff --git a/vscode/react/src/reportWebVitals.ts b/vscode/react/src/reportWebVitals.ts new file mode 100644 index 0000000000..16b66b5f6c --- /dev/null +++ b/vscode/react/src/reportWebVitals.ts @@ -0,0 +1,13 @@ +const reportWebVitals = (onPerfEntry?: () => void) => { + if (onPerfEntry && onPerfEntry instanceof Function) { + import('web-vitals').then(({ onCLS, onINP, onFCP, onLCP, onTTFB }) => { + onCLS(onPerfEntry) + onINP(onPerfEntry) + onFCP(onPerfEntry) + onLCP(onPerfEntry) + onTTFB(onPerfEntry) + }) + } +} + +export default reportWebVitals diff --git a/vscode/react/src/routeTree.gen.ts b/vscode/react/src/routeTree.gen.ts new file mode 100644 index 
0000000000..f18a46802a --- /dev/null +++ b/vscode/react/src/routeTree.gen.ts @@ -0,0 +1,95 @@ +/* eslint-disable */ + +// @ts-nocheck + +// noinspection JSUnusedGlobalSymbols + +// This file was automatically generated by TanStack Router. +// You should NOT make any changes in this file as it will be overwritten. +// Additionally, you should also exclude this file from your linter and/or formatter to prevent it from being checked or modified. + +import { Route as rootRouteImport } from './routes/__root' +import { Route as TablediffRouteImport } from './routes/tablediff' +import { Route as LineageRouteImport } from './routes/lineage' +import { Route as IndexRouteImport } from './routes/index' + +const TablediffRoute = TablediffRouteImport.update({ + id: '/tablediff', + path: '/tablediff', + getParentRoute: () => rootRouteImport, +} as any) +const LineageRoute = LineageRouteImport.update({ + id: '/lineage', + path: '/lineage', + getParentRoute: () => rootRouteImport, +} as any) +const IndexRoute = IndexRouteImport.update({ + id: '/', + path: '/', + getParentRoute: () => rootRouteImport, +} as any) + +export interface FileRoutesByFullPath { + '/': typeof IndexRoute + '/lineage': typeof LineageRoute + '/tablediff': typeof TablediffRoute +} +export interface FileRoutesByTo { + '/': typeof IndexRoute + '/lineage': typeof LineageRoute + '/tablediff': typeof TablediffRoute +} +export interface FileRoutesById { + __root__: typeof rootRouteImport + '/': typeof IndexRoute + '/lineage': typeof LineageRoute + '/tablediff': typeof TablediffRoute +} +export interface FileRouteTypes { + fileRoutesByFullPath: FileRoutesByFullPath + fullPaths: '/' | '/lineage' | '/tablediff' + fileRoutesByTo: FileRoutesByTo + to: '/' | '/lineage' | '/tablediff' + id: '__root__' | '/' | '/lineage' | '/tablediff' + fileRoutesById: FileRoutesById +} +export interface RootRouteChildren { + IndexRoute: typeof IndexRoute + LineageRoute: typeof LineageRoute + TablediffRoute: typeof TablediffRoute +} + 
+declare module '@tanstack/react-router' { + interface FileRoutesByPath { + '/tablediff': { + id: '/tablediff' + path: '/tablediff' + fullPath: '/tablediff' + preLoaderRoute: typeof TablediffRouteImport + parentRoute: typeof rootRouteImport + } + '/lineage': { + id: '/lineage' + path: '/lineage' + fullPath: '/lineage' + preLoaderRoute: typeof LineageRouteImport + parentRoute: typeof rootRouteImport + } + '/': { + id: '/' + path: '/' + fullPath: '/' + preLoaderRoute: typeof IndexRouteImport + parentRoute: typeof rootRouteImport + } + } +} + +const rootRouteChildren: RootRouteChildren = { + IndexRoute: IndexRoute, + LineageRoute: LineageRoute, + TablediffRoute: TablediffRoute, +} +export const routeTree = rootRouteImport + ._addFileChildren(rootRouteChildren) + ._addFileTypes<FileRouteTypes>() diff --git a/vscode/react/src/routes/__root.tsx b/vscode/react/src/routes/__root.tsx new file mode 100644 index 0000000000..a600d0f849 --- /dev/null +++ b/vscode/react/src/routes/__root.tsx @@ -0,0 +1,21 @@ +import { Outlet, createRootRoute } from '@tanstack/react-router' +import { TanStackRouterDevtools } from '@tanstack/react-router-devtools' +import '../App.css' +import { LineagePage } from '@/pages/lineage' +import { TableDiffPage } from '@/pages/tablediff' + +export const Route = createRootRoute({ + component: () => { + return ( + <> + <Outlet /> + <TanStackRouterDevtools /> + </> + ) + }, + notFoundComponent: () => { + // switch to lineage or table diff based on panel type + const panelType = (window as any).__SQLMESH_PANEL_TYPE__ || 'lineage' + return panelType === 'tablediff' ? 
<TableDiffPage /> : <LineagePage /> + }, +}) diff --git a/vscode/react/src/routes/index.tsx b/vscode/react/src/routes/index.tsx new file mode 100644 index 0000000000..e7bcf5ddfd --- /dev/null +++ b/vscode/react/src/routes/index.tsx @@ -0,0 +1,34 @@ +import { createFileRoute } from '@tanstack/react-router' +import '../App.css' + +export const Route = createFileRoute('/')({ + component: App, +}) + +function App() { + return ( + <div className="App"> + <header className="App-header"> + <p> + Edit <code>src/routes/index.tsx</code> and save to reload. + </p> + <a + className="App-link" + href="https://reactjs.org" + target="_blank" + rel="noopener noreferrer" + > + Learn React + </a> + <a + className="App-link" + href="https://tanstack.com" + target="_blank" + rel="noopener noreferrer" + > + Learn TanStack + </a> + </header> + </div> + ) +} diff --git a/vscode/react/src/routes/lineage.tsx b/vscode/react/src/routes/lineage.tsx new file mode 100644 index 0000000000..3228549b2e --- /dev/null +++ b/vscode/react/src/routes/lineage.tsx @@ -0,0 +1,6 @@ +import { createFileRoute } from '@tanstack/react-router' +import { LineagePage } from '@/pages/lineage' + +export const Route = createFileRoute('/lineage')({ + component: LineagePage, +}) diff --git a/vscode/react/src/routes/tablediff.tsx b/vscode/react/src/routes/tablediff.tsx new file mode 100644 index 0000000000..c9776048cd --- /dev/null +++ b/vscode/react/src/routes/tablediff.tsx @@ -0,0 +1,6 @@ +import { createFileRoute } from '@tanstack/react-router' +import { TableDiffPage } from '../pages/tablediff' + +export const Route = createFileRoute('/tablediff')({ + component: TableDiffPage, +}) diff --git a/vscode/react/src/style/variants.ts b/vscode/react/src/style/variants.ts new file mode 100644 index 0000000000..57b8e0c398 --- /dev/null +++ b/vscode/react/src/style/variants.ts @@ -0,0 +1,23 @@ +export const EnumSize = { + xs: 'xs', + sm: 'sm', + md: 'md', + lg: 'lg', + xl: 'xl', +} as const + +export type Size = (typeof 
EnumSize)[keyof typeof EnumSize] + +export const EnumVariant = { + Brand: 'brand', + Primary: 'primary', + Alternative: 'alternative', + Secondary: 'secondary', + Success: 'success', + Danger: 'danger', + Warning: 'warning', + Info: 'info', + Neutral: 'neutral', +} as const + +export type Variant = (typeof EnumVariant)[keyof typeof EnumVariant] diff --git a/vscode/react/src/styles.css b/vscode/react/src/styles.css new file mode 100644 index 0000000000..84640987c3 --- /dev/null +++ b/vscode/react/src/styles.css @@ -0,0 +1,13 @@ +body { + margin: 0; + font-family: + -apple-system, BlinkMacSystemFont, 'Segoe UI', 'Roboto', 'Oxygen', 'Ubuntu', + 'Cantarell', 'Fira Sans', 'Droid Sans', 'Helvetica Neue', sans-serif; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; +} + +code { + font-family: + source-code-pro, Menlo, Monaco, Consolas, 'Courier New', monospace; +} diff --git a/vscode/react/src/utils/index.spec.ts b/vscode/react/src/utils/index.spec.ts new file mode 100644 index 0000000000..d1d815e48b --- /dev/null +++ b/vscode/react/src/utils/index.spec.ts @@ -0,0 +1,211 @@ +import { describe, it, test, expect } from 'vitest' +import { + isArrayNotEmpty, + isArrayEmpty, + isObjectEmpty, + isObject, + isNil, + isNotNil, + isDate, + toDate, + toDateFormat, + isStringEmpty, + ensureString, +} from './index' + +describe('isArrayNotEmpty', () => { + test('returns true for non-empty arrays', () => { + expect(isArrayNotEmpty([1, 2, 3])).toBe(true) + }) + + test('returns false for empty arrays', () => { + expect(isArrayNotEmpty([])).toBe(false) + }) + + test('returns false for non-array values', () => { + expect(isArrayNotEmpty(123)).toBe(false) + expect(isArrayNotEmpty('abc')).toBe(false) + expect(isArrayNotEmpty({})).toBe(false) + expect(isArrayNotEmpty(null)).toBe(false) + expect(isArrayNotEmpty(undefined)).toBe(false) + }) +}) + +describe('isArrayEmpty', () => { + test('returns false for non-empty arrays', () => { + expect(isArrayEmpty([1, 2, 
3])).toBe(false) + }) + + test('returns true for empty arrays', () => { + expect(isArrayEmpty([])).toBe(true) + }) + + test('returns false for non-array values', () => { + expect(isArrayEmpty(123)).toBe(false) + expect(isArrayEmpty('abc')).toBe(false) + expect(isArrayEmpty({})).toBe(false) + expect(isArrayEmpty(null)).toBe(false) + expect(isArrayEmpty(undefined)).toBe(false) + }) +}) + +describe('isObjectEmpty', () => { + test('returns true for empty objects', () => { + expect(isObjectEmpty({})).toBe(true) + }) + + test('returns false for non-empty objects', () => { + expect(isObjectEmpty({ a: 1, b: 2 })).toBe(false) + }) + + test('returns false for non-object values', () => { + expect(isObjectEmpty(123)).toBe(false) + expect(isObjectEmpty('abc')).toBe(false) + expect(isObjectEmpty([])).toBe(false) + expect(isObjectEmpty(null)).toBe(false) + expect(isObjectEmpty(undefined)).toBe(false) + }) +}) + +describe('isObject', () => { + test('returns true for objects', () => { + expect(isObject({})).toBe(true) + expect(isObject({ a: 1, b: 2 })).toBe(true) + }) + + test('returns false for non-object values', () => { + expect(isObject(123)).toBe(false) + expect(isObject('abc')).toBe(false) + expect(isObject([])).toBe(false) + expect(isObject(null)).toBe(false) + expect(isObject(undefined)).toBe(false) + }) +}) + +describe('isNil', () => { + test('returns true for null and undefined', () => { + expect(isNil(null)).toBe(true) + expect(isNil(undefined)).toBe(true) + }) + + test('returns false for other values', () => { + expect(isNil(123)).toBe(false) + expect(isNil('abc')).toBe(false) + expect(isNil({})).toBe(false) + expect(isNil([])).toBe(false) + }) +}) + +describe('isNotNil', () => { + it('should return true for a non-nil value', () => { + expect(isNotNil('foo')).toBe(true) + }) + + it('should return false for a nil value', () => { + expect(isNotNil(null)).toBe(false) + }) +}) + +describe('isDate', () => { + it('returns true for a Date object', () => { + expect(isDate(new 
Date())).toBe(true) + }) + + it('returns false for a string', () => { + expect(isDate('2023-02-07')).toBe(false) + }) + + it('returns false for a number', () => { + expect(isDate(123456789)).toBe(false) + }) +}) + +describe('toDate', () => { + it('returns a Date object for a valid string date', () => { + const date = toDate('2023-02-07') + expect(isDate(date)).toBe(true) + }) + + it('returns a Date object for a valid numeric timestamp', () => { + const date = toDate(1612738400000) + expect(isDate(date)).toBe(true) + }) + + it('returns undefined for an invalid date string', () => { + expect(toDate('not a date')).toBe(undefined) + }) + + it('returns undefined for an invalid numeric value', () => { + expect(toDate('not a number')).toBe(undefined) + }) +}) + +describe('toDateFormat', () => { + it('returns an empty string for a null date', () => { + expect(toDateFormat(undefined)).toBe('') + }) + + it('returns a formatted date string for a valid date and default format', () => { + expect(toDateFormat(new Date('2023-02-07 00:00:00'))).toBe('2023-02-07') + }) + + it('returns a default formatted date string for a unsupported custom format', () => { + expect(toDateFormat(new Date('2023-02-07 00:00:00'), 'dd/mm/yyyy')).toBe( + 'Tue Feb 07 2023', + ) + }) +}) + +describe('isStringEmpty', () => { + it('returns true for an empty string', () => { + expect(isStringEmpty('')).toBe(true) + }) + + it('returns false for a non-empty string', () => { + expect(isStringEmpty('hello')).toBe(false) + }) + + it('returns false for a string with only spaces', () => { + expect(isStringEmpty(' ')).toBe(false) + }) + + it('returns false for null', () => { + expect(isStringEmpty(null)).toBe(false) + }) + + it('returns false for undefined', () => { + expect(isStringEmpty(undefined)).toBe(false) + }) + + it('returns false for non-string values', () => { + expect(isStringEmpty(123)).toBe(false) + expect(isStringEmpty({})).toBe(false) + expect(isStringEmpty([])).toBe(false) + 
expect(isStringEmpty(true)).toBe(false) + }) +}) + +describe('ensureString', () => { + it('returns the same string for string input', () => { + expect(ensureString('hello')).toBe('hello') + }) + + it('returns empty string for null', () => { + expect(ensureString(null)).toBe('') + }) + + it('returns empty string for undefined', () => { + expect(ensureString(undefined)).toBe('') + }) + + it('returns empty string for non-string values', () => { + expect(ensureString(123)).toBe('') + expect(ensureString({})).toBe('') + expect(ensureString([])).toBe('') + expect(ensureString(true)).toBe('') + }) + + it('returns empty string for empty string input', () => { + expect(ensureString('')).toBe('') + }) +}) diff --git a/vscode/react/src/utils/index.ts b/vscode/react/src/utils/index.ts new file mode 100644 index 0000000000..ed0faab1db --- /dev/null +++ b/vscode/react/src/utils/index.ts @@ -0,0 +1,211 @@ +export function isTrue(value: unknown): boolean { + return value === true +} + +export function isFalse(value: unknown): boolean { + return value === false +} + +export function isFalseOrNil(value: unknown): boolean { + return isNil(value) || isFalse(value) +} + +export function isString(value: unknown): value is string { + return typeof value === 'string' +} + +export function isNumber(value: unknown): value is number { + return typeof value === 'number' +} + +export function isPrimitive( + value: unknown, +): value is string | number | boolean { + return isString(value) || isNumber(value) || typeof value === 'boolean' +} + +export function isStringEmpty(value: unknown): value is '' { + return value === '' +} + +export function isStringEmptyOrNil( + value: unknown, +): value is undefined | null | '' { + return isNil(value) || isStringEmpty(value) +} + +export function ensureString(value: unknown): string { + return isString(value) ? 
value : '' +} + +export function isStringNotEmpty(value: unknown): value is string { + return isString(value) && value.trim() !== '' +} + +export function isArrayNotEmpty<T = any>(value: unknown): value is T[] { + return Array.isArray(value) && value.length > 0 +} + +export function isArrayEmpty(value: unknown): boolean { + return Array.isArray(value) && value.length === 0 +} + +export function isObjectEmpty(value: unknown): boolean { + return isObject(value) && isArrayEmpty(Object.keys(value as object)) +} + +export function isObjectNotEmpty<TValue>(value: unknown): value is TValue { + return isObject(value) && isArrayNotEmpty(Object.keys(value as object)) +} + +export function isObject(value: unknown): boolean { + return ( + typeof value === 'object' && isNotNil(value) && value.constructor === Object + ) +} + +export function isNil(value: unknown): value is undefined | null { + return value == null +} + +export function isNotNil<T>(value: T | null | undefined): value is T { + return value != null +} + +export function isDate(value: unknown): boolean { + return value instanceof Date +} + +export function toDate(value?: string | number): Date | undefined { + if (isNil(value)) return undefined + + try { + const date = new Date(value) + + return isNaN(date.getTime()) ? undefined : date + } catch { + return undefined + } +} + +export function toDateFormat( + date?: Date, + format: string = 'yyyy-mm-dd', + isUTC: boolean = true, +): string { + if (isNil(date)) return '' + + const year = isUTC ? date.getUTCFullYear() : date.getFullYear() + const month = toFormatted( + isUTC ? date.getUTCMonth() + 1 : date.getMonth() + 1, + ) + const day = toFormatted(isUTC ? date.getUTCDate() : date.getDate()) + const hour = toFormatted(isUTC ? date.getUTCHours() : date.getHours()) + const minute = toFormatted(isUTC ? date.getUTCMinutes() : date.getMinutes()) + const second = toFormatted(isUTC ? 
date.getUTCSeconds() : date.getSeconds()) + + const formats: Record<string, string> = { + 'mm/dd/yyyy': `${month}/${day}/${year}`, + 'yyyy-mm-dd': `${year}-${month}-${day}`, + 'yyyy-mm-dd hh-mm-ss': `${year}-${month}-${day} ${hour}:${minute}:${second}`, + } + + return formats[format] ?? date.toDateString() + + function toFormatted(n: number): string { + return n.toString().padStart(2, '0') + } +} + +export function includes<T>(array: T[], value: T): boolean { + return array.includes(value) +} + +export function toRatio( + top?: number, + bottom?: number, + multiplier = 100, +): number { + if (isNil(top) || isNil(bottom) || bottom === 0) return 0 + if (isNaN(top) || isNaN(bottom) || isNaN(multiplier)) return 0 + + return (top / bottom) * multiplier +} + +export function parseJSON<T>(value: string | null): T | undefined { + if (isNil(value)) return undefined + + try { + return value === 'undefined' ? undefined : JSON.parse(value ?? '') + } catch { + return undefined + } +} + +export function debounceSync( + fn: (...args: any) => void, + delay: number = 500, + immediate: boolean = false, +): (...args: any) => void { + let timeoutID: ReturnType<typeof setTimeout> | undefined + + return function callback(...args: any) { + const callNow = immediate && isNil(timeoutID) + + clearTimeout(timeoutID) + + timeoutID = setTimeout(() => { + timeoutID = undefined + + if (isFalse(immediate)) { + fn(...args) + } + }, delay) + + if (callNow) { + fn(...args) + } + } +} + +export function uid(): string { + const time = new Date().getTime().toString(36) + const random = Math.random().toString(36).substring(2, 8) + + return time + random +} + +export function toUniqueName(prefix?: string, suffix?: string): string { + // Should be enough for now + const hex = (Date.now() % 100000).toString(16) + + return `${isNil(prefix) ? '' : `${prefix}_`}${hex}${ + suffix ?? 
'' + }`.toLowerCase() +} + +export function truncate( + text: string, + maxChars = 0, + displayBefore = 5, + delimiter = '...', + displayAfter?: number, +): string { + const textLength = text.length + displayBefore = Math.abs(displayBefore) + displayAfter = isNil(displayAfter) ? displayBefore : Math.abs(displayAfter) + + if (maxChars > textLength || displayBefore + displayAfter >= textLength) { + return text + } + + if (displayAfter === 0) { + return text.substring(0, displayBefore) + delimiter + } + + return ( + text.substring(0, displayBefore) + + delimiter + + text.substring(textLength - displayAfter) + ) +} diff --git a/vscode/react/src/utils/receiveEvents.ts b/vscode/react/src/utils/receiveEvents.ts new file mode 100644 index 0000000000..f612e1c239 --- /dev/null +++ b/vscode/react/src/utils/receiveEvents.ts @@ -0,0 +1,22 @@ +import type { VSCodeCallback } from '@bus/callbacks' + +/** + * add event listener to the window. + * + * returns a function to remove the event listener. + */ +export const addEventListener = <K extends keyof VSCodeCallback>( + handlerKey: K, + handler: (payload: VSCodeCallback[K]) => void, +): (() => void) => { + const handleMessage = (event: MessageEvent) => { + const { key, payload } = event.data + if (key === handlerKey) { + handler(payload) + } + } + window.addEventListener('message', handleMessage) + return () => { + window.removeEventListener('message', handleMessage) + } +} diff --git a/vscode/react/src/utils/rpc.ts b/vscode/react/src/utils/rpc.ts new file mode 100644 index 0000000000..ccb92382e2 --- /dev/null +++ b/vscode/react/src/utils/rpc.ts @@ -0,0 +1,50 @@ +import { + type RPCRequest, + type RPCMethods, + type CallbackEvent, +} from '@bus/callbacks' +import { type Result } from '@bus/result' +import { sendVSCodeMessage } from './vscodeapi' + +export const useRpc = () => { + return callRpc +} + +export const callRpc = async <T extends keyof RPCMethods>( + method: T, + params: RPCMethods[T]['params'], +): 
Promise<Result<RPCMethods[T]['result'], string>> => { + return new Promise((resolve, reject) => { + const requestId = `query_${Date.now()}_${Math.random().toString(36).substr(2, 9)}` + const messageHandler = (event: MessageEvent) => { + if (event.data) { + const eventData = event.data as CallbackEvent + if (eventData.key !== 'rpcResponse') { + return + } + if (eventData.payload.requestId !== requestId) { + return + } + const payload = eventData.payload.result + window.removeEventListener('message', messageHandler) + return resolve(payload) + } + } + + // Add the listener + window.addEventListener('message', messageHandler) + + const request: RPCRequest = { + requestId, + method, + params, + } + sendVSCodeMessage('rpcRequest', request) + + // Set a timeout to prevent hanging promises + setTimeout(() => { + window.removeEventListener('message', messageHandler) + reject(new Error('Query request timed out')) + }, 30000) // 30 second timeout + }) +} diff --git a/vscode/react/src/utils/vscodeapi.ts b/vscode/react/src/utils/vscodeapi.ts new file mode 100644 index 0000000000..f06b71ac41 --- /dev/null +++ b/vscode/react/src/utils/vscodeapi.ts @@ -0,0 +1,35 @@ +import type { Callback } from '@bus/callbacks' + +/** + * send a message to the vscode extension. + * + * This should generally not be used directly, but rather through the useVSCode hook. 
+ */ +export const sendVSCodeMessage = <K extends keyof Callback>( + callbackName: K, + payload: Callback[K], +): void => { + const eventPayload = { + key: callbackName, + payload: payload, + } + getVSCodeAPI().postMessage(eventPayload) +} + +let VSCODE_API: VSCodeAPI | undefined + +interface VSCodeAPI { + postMessage: (message: any) => void +} + +declare function acquireVsCodeApi(): VSCodeAPI + +function getVSCodeAPI(): VSCodeAPI { + if (!VSCODE_API) { + VSCODE_API = acquireVsCodeApi() + } + if (!VSCODE_API) { + throw new Error('VSCode API not initialized') + } + return VSCODE_API +} diff --git a/vscode/react/tailwind.config.cjs b/vscode/react/tailwind.config.cjs new file mode 100644 index 0000000000..c31bd9f9ec --- /dev/null +++ b/vscode/react/tailwind.config.cjs @@ -0,0 +1,187 @@ +/** @type {import('tailwindcss').Config} */ +module.exports = { + content: ['./index.html', './src/**/*.{js,ts,jsx,tsx}'], + darkMode: ['class', '[mode="dark"]'], + theme: { + colors: { + current: 'currentColor', + inherit: 'inherit', + transparent: 'transparent', + prose: { + DEFAULT: 'var(--color-text)', + darker: 'var(--color-text-darker)', + lighter: 'var(--color-text-lighter)', + }, + dark: { + DEFAULT: 'var(--color-dark)', + darker: 'var(--color-dark-darker)', + lighter: 'var(--color-dark-lighter)', + }, + light: { + DEFAULT: 'var(--color-light)', + darker: 'var(--color-light-darker)', + lighter: 'var(--color-light-lighter)', + }, + overlay: { + DEFAULT: 'var(--color-overlay)', + darker: 'var(--color-overlay-darker)', + lighter: 'var(--color-overlay-lighter)', + }, + editor: { + DEFAULT: 'var(--color-editor)', + darker: 'var(--color-editor-darker)', + lighter: 'var(--color-editor-lighter)', + }, + logo: { + DEFAULT: 'var(--color-logo)', + darker: 'var(--color-logo-darker)', + lighter: 'var(--color-logo-lighter)', + }, + theme: { + DEFAULT: 'var(--color-theme)', + darker: 'var(--color-theme-darker)', + lighter: 'var(--color-theme-lighter)', + }, + divider: { + DEFAULT: 
'var(--color-divider)', + }, + brand: { + 5: 'var(--color-brand-5)', + 10: 'var(--color-brand-10)', + 20: 'var(--color-brand-20)', + 75: 'var(--color-brand-75)', + 50: 'var(--color-brand-50)', + 90: 'var(--color-brand-90)', + 100: 'var(--color-brand-100)', + 200: 'var(--color-brand-200)', + 300: 'var(--color-brand-300)', + 400: 'var(--color-brand-400)', + 500: 'var(--color-brand-500)', + 600: 'var(--color-brand-600)', + 700: 'var(--color-brand-700)', + 800: 'var(--color-brand-800)', + 900: 'var(--color-brand-900)', + }, + neutral: { + 5: 'var(--color-neutral-5)', + 10: 'var(--color-neutral-10)', + 20: 'var(--color-neutral-20)', + 30: 'var(--color-neutral-30)', + 40: 'var(--color-neutral-40)', + 50: 'var(--color-neutral-50)', + 60: 'var(--color-neutral-60)', + 70: 'var(--color-neutral-70)', + 80: 'var(--color-neutral-80)', + 90: 'var(--color-neutral-90)', + 100: 'var(--color-neutral-100)', + 200: 'var(--color-neutral-200)', + 300: 'var(--color-neutral-300)', + 400: 'var(--color-neutral-400)', + 500: 'var(--color-neutral-500)', + 600: 'var(--color-neutral-600)', + 700: 'var(--color-neutral-700)', + 800: 'var(--color-neutral-800)', + 900: 'var(--color-neutral-900)', + }, + primary: { + 5: 'var(--color-primary-5)', + 10: 'var(--color-primary-10)', + 20: 'var(--color-primary-20)', + 30: 'var(--color-primary-30)', + 40: 'var(--color-primary-40)', + 50: 'var(--color-primary-50)', + 60: 'var(--color-primary-60)', + 70: 'var(--color-primary-70)', + 80: 'var(--color-primary-80)', + 90: 'var(--color-primary-90)', + 100: 'var(--color-primary-100)', + 200: 'var(--color-primary-200)', + 300: 'var(--color-primary-300)', + 400: 'var(--color-primary-400)', + 500: 'var(--color-primary-500)', + 600: 'var(--color-primary-600)', + 700: 'var(--color-primary-700)', + 800: 'var(--color-primary-800)', + 900: 'var(--color-primary-900)', + }, + secondary: { + 5: 'var(--color-secondary-5)', + 10: 'var(--color-secondary-10)', + 20: 'var(--color-secondary-20)', + 30: 
'var(--color-secondary-30)', + 100: 'var(--color-secondary-100)', + 200: 'var(--color-secondary-200)', + 300: 'var(--color-secondary-300)', + 400: 'var(--color-secondary-400)', + 500: 'var(--color-secondary-500)', + 600: 'var(--color-secondary-600)', + 700: 'var(--color-secondary-700)', + 800: 'var(--color-secondary-800)', + 900: 'var(--color-secondary-900)', + }, + accent: { + 5: 'var(--color-accent-5)', + 50: 'var(--color-accent-50)', + 100: 'var(--color-accent-100)', + 200: 'var(--color-accent-200)', + 300: 'var(--color-accent-300)', + 400: 'var(--color-accent-400)', + 500: 'var(--color-accent-500)', + 600: 'var(--color-accent-600)', + 700: 'var(--color-accent-700)', + 800: 'var(--color-accent-800)', + 900: 'var(--color-accent-900)', + }, + success: { + 5: 'var(--color-success-5)', + 10: 'var(--color-success-10)', + 20: 'var(--color-success-20)', + 30: 'var(--color-success-30)', + 40: 'var(--color-success-40)', + 100: 'var(--color-success-100)', + 200: 'var(--color-success-200)', + 300: 'var(--color-success-300)', + 400: 'var(--color-success-400)', + 500: 'var(--color-success-500)', + 600: 'var(--color-success-600)', + 700: 'var(--color-success-700)', + 800: 'var(--color-success-800)', + 900: 'var(--color-success-900)', + }, + danger: { + 5: 'var(--color-danger-5)', + 10: 'var(--color-danger-10)', + 20: 'var(--color-danger-20)', + 30: 'var(--color-danger-30)', + 40: 'var(--color-danger-40)', + 100: 'var(--color-danger-100)', + 200: 'var(--color-danger-200)', + 300: 'var(--color-danger-300)', + 400: 'var(--color-danger-400)', + 500: 'var(--color-danger-500)', + 600: 'var(--color-danger-600)', + 700: 'var(--color-danger-700)', + 800: 'var(--color-danger-800)', + 900: 'var(--color-danger-900)', + }, + warning: { + 5: 'var(--color-warning-5)', + 10: 'var(--color-warning-10)', + 100: 'var(--color-warning-100)', + 200: 'var(--color-warning-200)', + 300: 'var(--color-warning-300)', + 400: 'var(--color-warning-400)', + 500: 'var(--color-warning-500)', + 600: 
'var(--color-warning-600)', + 700: 'var(--color-warning-700)', + 800: 'var(--color-warning-800)', + 900: 'var(--color-warning-900)', + }, + }, + fontFamily: { + mono: ['JetBrains Mono', 'monospace'], + sans: ['Inter', 'sans-serif'], + serif: ['Publico', 'serif'], + }, + }, +} diff --git a/vscode/react/tsconfig.json b/vscode/react/tsconfig.json new file mode 100644 index 0000000000..b57d3316b0 --- /dev/null +++ b/vscode/react/tsconfig.json @@ -0,0 +1,29 @@ +{ + "include": ["**/*.ts", "**/*.tsx"], + "compilerOptions": { + "target": "ES2022", + "jsx": "react-jsx", + "module": "ESNext", + "lib": ["ES2022", "DOM", "DOM.Iterable"], + "types": ["vite/client", "react", "react-dom"], + + /* Bundler mode */ + "moduleResolution": "bundler", + "allowImportingTsExtensions": true, + "verbatimModuleSyntax": true, + "noEmit": true, + + /* Linting */ + "skipLibCheck": true, + "strict": true, + "noUnusedLocals": true, + "noUnusedParameters": true, + "noFallthroughCasesInSwitch": true, + "noUncheckedSideEffectImports": true, + "baseUrl": ".", + "paths": { + "@/*": ["./src/*"], + "@bus/*": ["../bus/src/*"] + } + } +} diff --git a/vscode/react/vite.config.js b/vscode/react/vite.config.js new file mode 100644 index 0000000000..84568f6cdd --- /dev/null +++ b/vscode/react/vite.config.js @@ -0,0 +1,56 @@ +import { defineConfig } from 'vite' +import viteReact from '@vitejs/plugin-react' +import { TanStackRouterVite } from '@tanstack/router-plugin/vite' +import { resolve } from 'node:path' +import tailwindcss from '@tailwindcss/vite' + +// https://vitejs.dev/config/ +export default defineConfig({ + plugins: [ + TanStackRouterVite({ autoCodeSplitting: false }), + viteReact(), + tailwindcss(), + ], + test: { + globals: true, + environment: 'jsdom', + }, + + // This is to ensure we can import the bus module from the bus folder and have nice import paths + resolve: { + alias: { + '@': resolve(__dirname, './src'), + '@bus': resolve(__dirname, '../bus/src'), + }, + }, + + // This is to ensure that 
the assets are in the assets folder are all named assets/[name].[extension] rather + // than having a hash. + build: { + // Everything below is pure Rollup syntax + rollupOptions: { + output: { + // ── JavaScript ────────────────────────────────────── + entryFileNames: 'assets/[name].js', // main-entry + chunkFileNames: 'assets/[name].js', // code-splits + // ── CSS & other assets ───────────────────────────── + assetFileNames: ({ name }) => { + // name = original file name with extension + const ext = name?.substring(name.lastIndexOf('.')) + return `assets/[name]${ext}` // e.g. style.css + }, + }, + }, + }, + + // This ensures that the API calls to the server are proxied to the server. + server: { + proxy: { + '/api': { + target: 'http://localhost:5174', + changeOrigin: true, + secure: false, + }, + }, + }, +}) diff --git a/vscode/react/vitest.shims.d.ts b/vscode/react/vitest.shims.d.ts new file mode 100644 index 0000000000..a1d31e5a7b --- /dev/null +++ b/vscode/react/vitest.shims.d.ts @@ -0,0 +1 @@ +/// <reference types="@vitest/browser/providers/playwright" /> diff --git a/web/Dockerfile.api b/web/Dockerfile.api index 0a9fb85c1b..da9438ec87 100644 --- a/web/Dockerfile.api +++ b/web/Dockerfile.api @@ -2,7 +2,9 @@ FROM python:3.11 WORKDIR /sqlmesh -COPY setup.py setup.py +COPY pyproject.toml pyproject.toml +COPY Makefile Makefile +COPY examples/custom_materializations/ examples/custom_materializations/ COPY sqlmesh/_version.py sqlmesh/_version.py -RUN pip install -e .[dev,web] +RUN make install-dev diff --git a/web/Dockerfile.app b/web/Dockerfile.app index 90e7b21448..b234f904b6 100644 --- a/web/Dockerfile.app +++ b/web/Dockerfile.app @@ -1,12 +1,18 @@ -FROM mcr.microsoft.com/playwright:v1.42.0-jammy +FROM mcr.microsoft.com/playwright:v1.49.0-jammy WORKDIR /app -ENV PATH /app/node_modules/.bin:$PATH - RUN apt-get update && apt-get -y install libnss3 libatk-bridge2.0-0 libdrm-dev libxkbcommon-dev libgbm-dev libasound-dev libatspi2.0-0 libxshmfence-dev -COPY 
web/client/package*.json . +# Install pnpm globally +RUN npm install -g pnpm@latest + +# Copy package files for workspaces +COPY package.json pnpm-lock.yaml pnpm-workspace.yaml ./ +COPY web/client/package.json ./web/client/ + +# Install dependencies +RUN pnpm install --frozen-lockfile -RUN npm install -g npm@latest && \ - npm install --no-audit --no-fund --no-package-lock +# Copy source files (excluding node_modules which were installed above) +COPY web/client/ ./web/client/ diff --git a/web/client/.eslintrc.js b/web/client/.eslintrc.js deleted file mode 100644 index 3066ab026a..0000000000 --- a/web/client/.eslintrc.js +++ /dev/null @@ -1,56 +0,0 @@ -const OFF = 0 -const ERROR = 2 - -module.exports = { - root: true, - env: { - browser: true, - es2021: true, - }, - extends: ['plugin:react/recommended', 'standard-with-typescript', 'prettier'], - parser: '@typescript-eslint/parser', - parserOptions: { - tsconfigRootDir: __dirname, - project: './tsconfig.json', - }, - plugins: ['react', '@typescript-eslint'], - rules: { - 'react/jsx-uses-react': OFF, - 'react/react-in-jsx-scope': OFF, - 'no-use-before-define': OFF, - '@typescript-eslint/promise-function-async': OFF, - '@typescript-eslint/no-non-null-assertion': OFF, - 'no-return-await': OFF, - '@typescript-eslint/return-await': OFF, - '@typescript-eslint/no-use-before-define': [ - ERROR, - { - variables: true, - functions: false, - classes: false, - allowNamedExports: true, - }, - ], - '@typescript-eslint/no-dynamic-delete': OFF, - '@typescript-eslint/naming-convention': [ - ERROR, - { - selector: 'variable', - format: ['camelCase', 'PascalCase', 'UPPER_CASE', 'snake_case'], - }, - ], - '@typescript-eslint/no-confusing-void-expression': OFF, - }, - ignorePatterns: [ - 'src/api/client.ts', - 'test-results', - 'playwright', - 'playwright-report', - 'dist', - ], - settings: { - react: { - version: '18.2', - }, - }, -} diff --git a/web/client/.prettierignore b/web/client/.prettierignore deleted file mode 100644 index 
5f6d0b843c..0000000000 --- a/web/client/.prettierignore +++ /dev/null @@ -1,9 +0,0 @@ -**/*.py -.prettierignore -.gitignore -node_modules/ -/test-results/ -/playwright-report/ -/playwright/.cache/ -dist -tsconfig.tsbuildinfo \ No newline at end of file diff --git a/web/client/eslint.config.mjs b/web/client/eslint.config.mjs new file mode 100644 index 0000000000..dbf0ec076b --- /dev/null +++ b/web/client/eslint.config.mjs @@ -0,0 +1,26 @@ +import eslint from '@eslint/js' +import tseslint from 'typescript-eslint' + +export default tseslint.config( + { + ignores: [ + 'dist/**/*', + 'node_modules/**', + 'src/utils/tbk-components.js', + '**/*.cjs', + '**/*.mjs', + 'src/api/client.ts', + ], + }, + eslint.configs.recommended, + tseslint.configs.recommended, + { + rules: { + '@typescript-eslint/no-explicit-any': 'off', + '@typescript-eslint/no-unused-vars': 'off', + '@typescript-eslint/no-empty-object-type': 'off', + '@typescript-eslint/no-unused-expressions': 'off', + 'no-empty': 'off', + }, + }, +) diff --git a/web/client/index.html b/web/client/index.html index 5fb52d376a..e1e71899ef 100644 --- a/web/client/index.html +++ b/web/client/index.html @@ -6,24 +6,230 @@ name="viewport" content="width=device-width, initial-scale=1.0" /> + <title>SQLMesh by Tobiko - SQLMesh by Tobiko + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/web/client/openapi.json b/web/client/openapi.json index 48d0c0f93f..bf1cef0809 100644 --- a/web/client/openapi.json +++ b/web/client/openapi.json @@ -11,12 +11,7 @@ "content": { "application/json": { "schema": { - "allOf": [ - { - "$ref": "#/components/schemas/Body_initiate_apply_api_commands_apply_post" - } - ], - "title": "Body" + "$ref": "#/components/schemas/Body_initiate_apply_api_commands_apply_post" } } } @@ -154,14 +149,10 @@ } }, { - "name": "verbose", + "name": "verbosity", "in": "query", "required": false, - "schema": { - "type": "boolean", - "default": false, - "title": "Verbose" - } + "schema": { 
"$ref": "#/components/schemas/Verbosity", "default": 0 } } ], "responses": { @@ -249,12 +240,7 @@ "content": { "application/json": { "schema": { - "allOf": [ - { - "$ref": "#/components/schemas/Body_write_file_api_files__path__post" - } - ], - "title": "Body" + "$ref": "#/components/schemas/Body_write_file_api_files__path__post" } } } @@ -336,12 +322,7 @@ "content": { "application/json": { "schema": { - "allOf": [ - { - "$ref": "#/components/schemas/Body_write_directory_api_directories__path__post" - } - ], - "title": "Body" + "$ref": "#/components/schemas/Body_write_directory_api_directories__path__post" } } } @@ -402,12 +383,7 @@ "content": { "application/json": { "schema": { - "allOf": [ - { - "$ref": "#/components/schemas/Body_initiate_plan_api_plan_post" - } - ], - "title": "Body" + "$ref": "#/components/schemas/Body_initiate_plan_api_plan_post" } } } @@ -697,9 +673,9 @@ }, "/api/modules": { "get": { - "summary": "Get Api Meta", + "summary": "Get Api Modules", "description": "Get the modules", - "operationId": "get_api_meta_api_modules_get", + "operationId": "get_api_modules_api_modules_get", "responses": { "200": { "description": "Successful Response", @@ -708,7 +684,7 @@ "schema": { "items": { "$ref": "#/components/schemas/Modules" }, "type": "array", - "title": "Response Get Api Meta Api Modules Get" + "title": "Response Get Api Modules Api Modules Get" } } } @@ -761,6 +737,15 @@ "title": "Where" } }, + { + "name": "temp_schema", + "in": "query", + "required": false, + "schema": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Temp Schema" + } + }, { "name": "limit", "in": "query", @@ -773,7 +758,13 @@ "description": "Successful Response", "content": { "application/json": { - "schema": { "$ref": "#/components/schemas/TableDiff" } + "schema": { + "anyOf": [ + { "$ref": "#/components/schemas/TableDiff" }, + { "type": "null" } + ], + "title": "Response Get Table Diff Api Table Diff Get" + } } } }, @@ -881,7 +872,7 @@ "name": { "type": 
"string", "title": "Name" }, "view_name": { "type": "string", "title": "View Name" }, "node_type": { - "allOf": [{ "$ref": "#/components/schemas/NodeType" }], + "$ref": "#/components/schemas/NodeType", "default": "model" }, "parents": { @@ -908,7 +899,7 @@ "name": { "type": "string", "title": "Name" }, "view_name": { "type": "string", "title": "View Name" }, "node_type": { - "allOf": [{ "$ref": "#/components/schemas/NodeType" }], + "$ref": "#/components/schemas/NodeType", "default": "model" }, "parents": { @@ -1032,7 +1023,7 @@ "name": { "type": "string", "title": "Name" }, "view_name": { "type": "string", "title": "View Name" }, "node_type": { - "allOf": [{ "$ref": "#/components/schemas/NodeType" }], + "$ref": "#/components/schemas/NodeType", "default": "model" }, "parents": { @@ -1072,7 +1063,7 @@ "name": { "type": "string", "title": "Name" }, "view_name": { "type": "string", "title": "View Name" }, "node_type": { - "allOf": [{ "$ref": "#/components/schemas/NodeType" }], + "$ref": "#/components/schemas/NodeType", "default": "model" }, "parents": { @@ -1093,7 +1084,7 @@ "name": { "type": "string", "title": "Name" }, "view_name": { "type": "string", "title": "View Name" }, "node_type": { - "allOf": [{ "$ref": "#/components/schemas/NodeType" }], + "$ref": "#/components/schemas/NodeType", "default": "model" }, "parents": { @@ -1148,26 +1139,6 @@ "Environment": { "properties": { "name": { "type": "string", "title": "Name", "default": "prod" }, - "suffix_target": { - "allOf": [ - { "$ref": "#/components/schemas/EnvironmentSuffixTarget" } - ], - "default": "schema" - }, - "catalog_name_override": { - "anyOf": [{ "type": "string" }, { "type": "null" }], - "title": "Catalog Name Override" - }, - "normalize_name": { - "type": "boolean", - "title": "Normalize Name", - "default": true - }, - "snapshots": { - "items": { "$ref": "#/components/schemas/SnapshotTableInfo" }, - "type": "array", - "title": "Snapshots" - }, "start_at": { "anyOf": [ { "type": "string", "format": 
"date" }, @@ -1202,32 +1173,45 @@ "anyOf": [{ "type": "integer" }, { "type": "null" }], "title": "Finalized Ts" }, + "suffix_target": { + "$ref": "#/components/schemas/EnvironmentSuffixTarget", + "default": "schema" + }, + "catalog_name_override": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Catalog Name Override" + }, + "normalize_name": { + "type": "boolean", + "title": "Normalize Name", + "default": true + }, + "gateway_managed": { + "type": "boolean", + "title": "Gateway Managed", + "default": false + }, + "snapshots": { "items": {}, "type": "array", "title": "Snapshots" }, "promoted_snapshot_ids": { - "anyOf": [ - { - "items": { "$ref": "#/components/schemas/SnapshotId" }, - "type": "array" - }, - { "type": "null" } - ], + "anyOf": [{ "items": {}, "type": "array" }, { "type": "null" }], "title": "Promoted Snapshot Ids" }, "previous_finalized_snapshots": { - "anyOf": [ - { - "items": { "$ref": "#/components/schemas/SnapshotTableInfo" }, - "type": "array" - }, - { "type": "null" } - ], + "anyOf": [{ "items": {}, "type": "array" }, { "type": "null" }], "title": "Previous Finalized Snapshots" + }, + "requirements": { + "additionalProperties": { "type": "string" }, + "type": "object", + "title": "Requirements", + "default": {} } }, "additionalProperties": false, "type": "object", - "required": ["snapshots", "start_at", "plan_id"], + "required": ["start_at", "plan_id", "snapshots"], "title": "Environment", - "description": "Represents an isolated environment.\n\nEnvironments are isolated workspaces that hold pointers to physical tables.\n\nArgs:\n snapshots: The snapshots that are part of this environment.\n start_at: The start time of the environment.\n end_at: The end time of the environment.\n plan_id: The ID of the plan that last updated this environment.\n previous_plan_id: The ID of the previous plan that updated this environment.\n expiration_ts: The timestamp when this environment will expire.\n finalized_ts: The timestamp when this 
environment was finalized.\n promoted_snapshot_ids: The IDs of the snapshots that are promoted in this environment\n (i.e. for which the views are created). If not specified, all snapshots are promoted.\n previous_finalized_snapshots: Snapshots that were part of this environment last time it was finalized." + "description": "Represents an isolated environment.\n\nEnvironments are isolated workspaces that hold pointers to physical tables.\n\nArgs:\n snapshots: The snapshots that are part of this environment.\n promoted_snapshot_ids: The IDs of the snapshots that are promoted in this environment\n (i.e. for which the views are created). If not specified, all snapshots are promoted.\n previous_finalized_snapshots: Snapshots that were part of this environment last time it was finalized.\n requirements: A mapping of library versions for all the snapshots in this environment." }, "EnvironmentSuffixTarget": { "type": "string", @@ -1399,6 +1383,7 @@ "name": { "type": "string", "title": "Name" }, "fqn": { "type": "string", "title": "Fqn" }, "path": { "type": "string", "title": "Path" }, + "full_path": { "type": "string", "title": "Full Path" }, "dialect": { "type": "string", "title": "Dialect" }, "type": { "$ref": "#/components/schemas/ModelType" }, "columns": { @@ -1420,6 +1405,10 @@ "anyOf": [{ "type": "string" }, { "type": "null" }], "title": "Sql" }, + "definition": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Definition" + }, "default_catalog": { "anyOf": [{ "type": "string" }, { "type": "null" }], "title": "Default Catalog" @@ -1432,6 +1421,7 @@ "name", "fqn", "path", + "full_path", "dialect", "type", "columns", @@ -1483,6 +1473,10 @@ "anyOf": [{ "type": "integer" }, { "type": "null" }], "title": "Retention" }, + "table_format": { + "anyOf": [{ "type": "string" }, { "type": "null" }], + "title": "Table Format" + }, "storage_format": { "anyOf": [{ "type": "string" }, { "type": "null" }], "title": "Storage Format" @@ -1550,28 +1544,9 @@ "type": 
"object", "title": "ModelDetails" }, - "ModelKindName": { - "type": "string", - "enum": [ - "INCREMENTAL_BY_TIME_RANGE", - "INCREMENTAL_BY_UNIQUE_KEY", - "INCREMENTAL_BY_PARTITION", - "INCREMENTAL_UNMANAGED", - "FULL", - "SCD_TYPE_2", - "SCD_TYPE_2_BY_TIME", - "SCD_TYPE_2_BY_COLUMN", - "VIEW", - "EMBEDDED", - "SEED", - "EXTERNAL" - ], - "title": "ModelKindName", - "description": "The kind of model, determining how this data is computed and stored in the warehouse." - }, "ModelType": { "type": "string", - "enum": ["python", "sql", "seed", "external"], + "enum": ["python", "sql", "seed", "external", "source"], "title": "ModelType" }, "ModelsDiff": { @@ -1604,7 +1579,7 @@ "enum": [ "editor", "files", - "docs", + "data-catalog", "plans", "tests", "audits", @@ -2073,7 +2048,11 @@ "type": "object", "title": "Stats" }, - "sample": { "type": "object", "title": "Sample" }, + "sample": { + "additionalProperties": true, + "type": "object", + "title": "Sample" + }, "source_count": { "type": "integer", "title": "Source Count" }, "target_count": { "type": "integer", "title": "Target Count" }, "count_pct_change": { "type": "number", "title": "Count Pct Change" } @@ -2140,113 +2119,6 @@ "title": "SnapshotChangeCategory", "description": "Values are ordered by decreasing severity and that ordering is required.\n\nBREAKING: The change requires that snapshot modified and downstream dependencies be rebuilt\nNON_BREAKING: The change requires that only the snapshot modified be rebuilt\nFORWARD_ONLY: The change requires no rebuilding\nINDIRECT_BREAKING: The change was caused indirectly and is breaking.\nINDIRECT_NON_BREAKING: The change was caused indirectly by a non-breaking change.\nMETADATA: The change was caused by a metadata update." 
}, - "SnapshotDataVersion": { - "properties": { - "fingerprint": { "$ref": "#/components/schemas/SnapshotFingerprint" }, - "version": { "type": "string", "title": "Version" }, - "temp_version": { - "anyOf": [{ "type": "string" }, { "type": "null" }], - "title": "Temp Version" - }, - "change_category": { - "anyOf": [ - { "$ref": "#/components/schemas/SnapshotChangeCategory" }, - { "type": "null" } - ] - }, - "physical_schema": { - "anyOf": [{ "type": "string" }, { "type": "null" }], - "title": "Physical Schema" - } - }, - "additionalProperties": false, - "type": "object", - "required": ["fingerprint", "version"], - "title": "SnapshotDataVersion" - }, - "SnapshotFingerprint": { - "properties": { - "data_hash": { "type": "string", "title": "Data Hash" }, - "metadata_hash": { "type": "string", "title": "Metadata Hash" }, - "parent_data_hash": { - "type": "string", - "title": "Parent Data Hash", - "default": "0" - }, - "parent_metadata_hash": { - "type": "string", - "title": "Parent Metadata Hash", - "default": "0" - } - }, - "additionalProperties": false, - "type": "object", - "required": ["data_hash", "metadata_hash"], - "title": "SnapshotFingerprint" - }, - "SnapshotId": { - "properties": { - "name": { "type": "string", "title": "Name" }, - "identifier": { "type": "string", "title": "Identifier" } - }, - "additionalProperties": false, - "type": "object", - "required": ["name", "identifier"], - "title": "SnapshotId" - }, - "SnapshotTableInfo": { - "properties": { - "name": { "type": "string", "title": "Name" }, - "temp_version": { - "anyOf": [{ "type": "string" }, { "type": "null" }], - "title": "Temp Version" - }, - "change_category": { - "anyOf": [ - { "$ref": "#/components/schemas/SnapshotChangeCategory" }, - { "type": "null" } - ] - }, - "fingerprint": { "$ref": "#/components/schemas/SnapshotFingerprint" }, - "previous_versions": { - "items": { "$ref": "#/components/schemas/SnapshotDataVersion" }, - "type": "array", - "title": "Previous Versions", - "default": [] 
- }, - "base_table_name_override": { - "anyOf": [{ "type": "string" }, { "type": "null" }], - "title": "Base Table Name Override" - }, - "version": { "type": "string", "title": "Version" }, - "physical_schema": { "type": "string", "title": "Physical Schema" }, - "parents": { - "items": { "$ref": "#/components/schemas/SnapshotId" }, - "type": "array", - "title": "Parents" - }, - "kind_name": { - "anyOf": [ - { "$ref": "#/components/schemas/ModelKindName" }, - { "type": "null" } - ] - }, - "node_type": { - "allOf": [{ "$ref": "#/components/schemas/NodeType" }], - "default": "model" - } - }, - "additionalProperties": false, - "type": "object", - "required": [ - "name", - "fingerprint", - "version", - "physical_schema", - "parents" - ], - "title": "SnapshotTableInfo" - }, "Status": { "type": "string", "enum": ["init", "success", "fail"], @@ -2332,7 +2204,7 @@ "TrackableMeta": { "properties": { "status": { - "allOf": [{ "$ref": "#/components/schemas/Status" }], + "$ref": "#/components/schemas/Status", "default": "init" }, "start": { "type": "integer", "title": "Start" }, @@ -2359,6 +2231,12 @@ "type": "object", "required": ["loc", "msg", "type"], "title": "ValidationError" + }, + "Verbosity": { + "type": "integer", + "enum": [0, 1, 2], + "title": "Verbosity", + "description": "Verbosity levels for SQLMesh output." 
} } } diff --git a/web/client/package-lock.json b/web/client/package-lock.json deleted file mode 100644 index 7bd0f0c60e..0000000000 --- a/web/client/package-lock.json +++ /dev/null @@ -1,9062 +0,0 @@ -{ - "name": "tobiko", - "version": "0.0.0", - "lockfileVersion": 3, - "requires": true, - "packages": { - "": { - "name": "tobiko", - "version": "0.0.0", - "dependencies": { - "@codemirror/autocomplete": "^6.16.2", - "@codemirror/commands": "^6.6.0", - "@codemirror/lang-python": "^6.1.6", - "@codemirror/lang-sql": "^6.6.4", - "@codemirror/language": "^6.10.2", - "@codemirror/legacy-modes": "^6.4.0", - "@codemirror/state": "^6.4.1", - "@codemirror/view": "^6.28.1", - "@headlessui/react": "^1.7.17", - "@heroicons/react": "^2.0.18", - "@radix-ui/react-context-menu": "^2.1.4", - "@radix-ui/react-select": "^1.2.2", - "@tailwindcss/container-queries": "^0.1.1", - "@tanstack/react-query": "^4.33.0", - "@tanstack/react-table": "^8.9.2", - "@tanstack/react-virtual": "^3.0.0-beta.56", - "@uidotdev/usehooks": "^2.2.0", - "@uiw/react-codemirror": "^4.21.12", - "apache-arrow": "^13.0.0", - "clsx": "^2.0.0", - "diff": "^5.2.0", - "elkjs": "^0.8.2", - "pluralize": "^8.0.0", - "react": "^18.2.0", - "react-dnd": "^16.0.1", - "react-dnd-html5-backend": "^16.0.1", - "react-dom": "^18.2.0", - "react-router-dom": "^6.15.0", - "react-split": "^2.0.14", - "reactflow": "^11.8.3", - "thememirror": "^2.0.1", - "zustand": "^4.4.1" - }, - "devDependencies": { - "@playwright/test": "^1.37.1", - "@testing-library/jest-dom": "^6.1.2", - "@testing-library/react": "^14.0.0", - "@testing-library/user-event": "^14.4.3", - "@types/diff": "^5.2.1", - "@types/pluralize": "^0.0.30", - "@types/react": "^18.2.21", - "@types/react-dom": "^18.2.7", - "@typescript-eslint/eslint-plugin": "^6.5.0", - "@vitejs/plugin-react-swc": "^3.3.2", - "autoprefixer": "^10.4.15", - "eslint": "^8.48.0", - "eslint-config-prettier": "^9.0.0", - "eslint-config-standard-with-typescript": "^39.0.0", - "eslint-plugin-import": 
"^2.28.1", - "eslint-plugin-n": "^16.0.2", - "eslint-plugin-promise": "^6.1.1", - "eslint-plugin-react": "^7.33.2", - "jsdom": "^22.1.0", - "orval": "^6.22.1", - "postcss": "^8.4.29", - "prettier": "^3.0.3", - "tailwindcss": "^3.3.3", - "typescript": "^5.2.2", - "vite": "^4.4.9", - "vitest": "^0.34.3" - } - }, - "node_modules/@75lb/deep-merge": { - "version": "1.1.1", - "license": "MIT", - "dependencies": { - "lodash.assignwith": "^4.2.0", - "typical": "^7.1.1" - }, - "engines": { - "node": ">=12.17" - } - }, - "node_modules/@75lb/deep-merge/node_modules/typical": { - "version": "7.1.1", - "license": "MIT", - "engines": { - "node": ">=12.17" - } - }, - "node_modules/@aashutoshrathi/word-wrap": { - "version": "1.2.6", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/@adobe/css-tools": { - "version": "4.3.2", - "dev": true, - "license": "MIT" - }, - "node_modules/@alloc/quick-lru": { - "version": "5.2.0", - "license": "MIT", - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/@apidevtools/json-schema-ref-parser": { - "version": "9.0.6", - "dev": true, - "license": "MIT", - "dependencies": { - "@jsdevtools/ono": "^7.1.3", - "call-me-maybe": "^1.0.1", - "js-yaml": "^3.13.1" - } - }, - "node_modules/@apidevtools/json-schema-ref-parser/node_modules/argparse": { - "version": "1.0.10", - "dev": true, - "license": "MIT", - "dependencies": { - "sprintf-js": "~1.0.2" - } - }, - "node_modules/@apidevtools/json-schema-ref-parser/node_modules/js-yaml": { - "version": "3.14.1", - "dev": true, - "license": "MIT", - "dependencies": { - "argparse": "^1.0.7", - "esprima": "^4.0.0" - }, - "bin": { - "js-yaml": "bin/js-yaml.js" - } - }, - "node_modules/@apidevtools/openapi-schemas": { - "version": "2.1.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=10" - } - }, - "node_modules/@apidevtools/swagger-methods": { - "version": "3.0.2", - "dev": 
true, - "license": "MIT" - }, - "node_modules/@apidevtools/swagger-parser": { - "version": "10.1.0", - "dev": true, - "license": "MIT", - "dependencies": { - "@apidevtools/json-schema-ref-parser": "9.0.6", - "@apidevtools/openapi-schemas": "^2.1.0", - "@apidevtools/swagger-methods": "^3.0.2", - "@jsdevtools/ono": "^7.1.3", - "ajv": "^8.6.3", - "ajv-draft-04": "^1.0.0", - "call-me-maybe": "^1.0.1" - }, - "peerDependencies": { - "openapi-types": ">=7" - } - }, - "node_modules/@apidevtools/swagger-parser/node_modules/ajv": { - "version": "8.12.0", - "dev": true, - "license": "MIT", - "dependencies": { - "fast-deep-equal": "^3.1.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - "node_modules/@apidevtools/swagger-parser/node_modules/ajv-draft-04": { - "version": "1.0.0", - "dev": true, - "license": "MIT", - "peerDependencies": { - "ajv": "^8.5.0" - }, - "peerDependenciesMeta": { - "ajv": { - "optional": true - } - } - }, - "node_modules/@apidevtools/swagger-parser/node_modules/json-schema-traverse": { - "version": "1.0.0", - "dev": true, - "license": "MIT" - }, - "node_modules/@asyncapi/specs": { - "version": "4.3.1", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@types/json-schema": "^7.0.11" - } - }, - "node_modules/@babel/code-frame": { - "version": "7.23.5", - "dev": true, - "license": "MIT", - "dependencies": { - "@babel/highlight": "^7.23.4", - "chalk": "^2.4.2" - }, - "engines": { - "node": ">=6.9.0" - } - }, - "node_modules/@babel/code-frame/node_modules/ansi-styles": { - "version": "3.2.1", - "dev": true, - "license": "MIT", - "dependencies": { - "color-convert": "^1.9.0" - }, - "engines": { - "node": ">=4" - } - }, - "node_modules/@babel/code-frame/node_modules/chalk": { - "version": "2.4.2", - "dev": true, - "license": "MIT", - "dependencies": { - "ansi-styles": "^3.2.1", - 
"escape-string-regexp": "^1.0.5", - "supports-color": "^5.3.0" - }, - "engines": { - "node": ">=4" - } - }, - "node_modules/@babel/code-frame/node_modules/color-convert": { - "version": "1.9.3", - "dev": true, - "license": "MIT", - "dependencies": { - "color-name": "1.1.3" - } - }, - "node_modules/@babel/code-frame/node_modules/color-name": { - "version": "1.1.3", - "dev": true, - "license": "MIT" - }, - "node_modules/@babel/code-frame/node_modules/escape-string-regexp": { - "version": "1.0.5", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=0.8.0" - } - }, - "node_modules/@babel/code-frame/node_modules/has-flag": { - "version": "3.0.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=4" - } - }, - "node_modules/@babel/code-frame/node_modules/supports-color": { - "version": "5.5.0", - "dev": true, - "license": "MIT", - "dependencies": { - "has-flag": "^3.0.0" - }, - "engines": { - "node": ">=4" - } - }, - "node_modules/@babel/helper-validator-identifier": { - "version": "7.22.20", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=6.9.0" - } - }, - "node_modules/@babel/highlight": { - "version": "7.23.4", - "dev": true, - "license": "MIT", - "dependencies": { - "@babel/helper-validator-identifier": "^7.22.20", - "chalk": "^2.4.2", - "js-tokens": "^4.0.0" - }, - "engines": { - "node": ">=6.9.0" - } - }, - "node_modules/@babel/highlight/node_modules/ansi-styles": { - "version": "3.2.1", - "dev": true, - "license": "MIT", - "dependencies": { - "color-convert": "^1.9.0" - }, - "engines": { - "node": ">=4" - } - }, - "node_modules/@babel/highlight/node_modules/chalk": { - "version": "2.4.2", - "dev": true, - "license": "MIT", - "dependencies": { - "ansi-styles": "^3.2.1", - "escape-string-regexp": "^1.0.5", - "supports-color": "^5.3.0" - }, - "engines": { - "node": ">=4" - } - }, - "node_modules/@babel/highlight/node_modules/color-convert": { - "version": "1.9.3", - "dev": true, - "license": "MIT", - "dependencies": { - 
"color-name": "1.1.3" - } - }, - "node_modules/@babel/highlight/node_modules/color-name": { - "version": "1.1.3", - "dev": true, - "license": "MIT" - }, - "node_modules/@babel/highlight/node_modules/escape-string-regexp": { - "version": "1.0.5", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=0.8.0" - } - }, - "node_modules/@babel/highlight/node_modules/has-flag": { - "version": "3.0.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=4" - } - }, - "node_modules/@babel/highlight/node_modules/supports-color": { - "version": "5.5.0", - "dev": true, - "license": "MIT", - "dependencies": { - "has-flag": "^3.0.0" - }, - "engines": { - "node": ">=4" - } - }, - "node_modules/@babel/runtime": { - "version": "7.23.5", - "license": "MIT", - "dependencies": { - "regenerator-runtime": "^0.14.0" - }, - "engines": { - "node": ">=6.9.0" - } - }, - "node_modules/@codemirror/autocomplete": { - "version": "6.16.2", - "resolved": "https://registry.npmjs.org/@codemirror/autocomplete/-/autocomplete-6.16.2.tgz", - "integrity": "sha512-MjfDrHy0gHKlPWsvSsikhO1+BOh+eBHNgfH1OXs1+DAf30IonQldgMM3kxLDTG9ktE7kDLaA1j/l7KMPA4KNfw==", - "dependencies": { - "@codemirror/language": "^6.0.0", - "@codemirror/state": "^6.0.0", - "@codemirror/view": "^6.17.0", - "@lezer/common": "^1.0.0" - }, - "peerDependencies": { - "@codemirror/language": "^6.0.0", - "@codemirror/state": "^6.0.0", - "@codemirror/view": "^6.0.0", - "@lezer/common": "^1.0.0" - } - }, - "node_modules/@codemirror/commands": { - "version": "6.6.0", - "resolved": "https://registry.npmjs.org/@codemirror/commands/-/commands-6.6.0.tgz", - "integrity": "sha512-qnY+b7j1UNcTS31Eenuc/5YJB6gQOzkUoNmJQc0rznwqSRpeaWWpjkWy2C/MPTcePpsKJEM26hXrOXl1+nceXg==", - "dependencies": { - "@codemirror/language": "^6.0.0", - "@codemirror/state": "^6.4.0", - "@codemirror/view": "^6.27.0", - "@lezer/common": "^1.1.0" - } - }, - "node_modules/@codemirror/lang-python": { - "version": "6.1.6", - "resolved": 
"https://registry.npmjs.org/@codemirror/lang-python/-/lang-python-6.1.6.tgz", - "integrity": "sha512-ai+01WfZhWqM92UqjnvorkxosZ2aq2u28kHvr+N3gu012XqY2CThD67JPMHnGceRfXPDBmn1HnyqowdpF57bNg==", - "dependencies": { - "@codemirror/autocomplete": "^6.3.2", - "@codemirror/language": "^6.8.0", - "@codemirror/state": "^6.0.0", - "@lezer/common": "^1.2.1", - "@lezer/python": "^1.1.4" - } - }, - "node_modules/@codemirror/lang-sql": { - "version": "6.6.4", - "resolved": "https://registry.npmjs.org/@codemirror/lang-sql/-/lang-sql-6.6.4.tgz", - "integrity": "sha512-n+FVfKGut+frOvor9dU5pFUalcP614WBNQ9IT1kOUj1t6LFLjWHi2I9DdxXnJuxqFV9jTyYF79coDV3ilSJqCw==", - "dependencies": { - "@codemirror/autocomplete": "^6.0.0", - "@codemirror/language": "^6.0.0", - "@codemirror/state": "^6.0.0", - "@lezer/common": "^1.2.0", - "@lezer/highlight": "^1.0.0", - "@lezer/lr": "^1.0.0" - } - }, - "node_modules/@codemirror/language": { - "version": "6.10.2", - "resolved": "https://registry.npmjs.org/@codemirror/language/-/language-6.10.2.tgz", - "integrity": "sha512-kgbTYTo0Au6dCSc/TFy7fK3fpJmgHDv1sG1KNQKJXVi+xBTEeBPY/M30YXiU6mMXeH+YIDLsbrT4ZwNRdtF+SA==", - "dependencies": { - "@codemirror/state": "^6.0.0", - "@codemirror/view": "^6.23.0", - "@lezer/common": "^1.1.0", - "@lezer/highlight": "^1.0.0", - "@lezer/lr": "^1.0.0", - "style-mod": "^4.0.0" - } - }, - "node_modules/@codemirror/legacy-modes": { - "version": "6.4.0", - "resolved": "https://registry.npmjs.org/@codemirror/legacy-modes/-/legacy-modes-6.4.0.tgz", - "integrity": "sha512-5m/K+1A6gYR0e+h/dEde7LoGimMjRtWXZFg4Lo70cc8HzjSdHe3fLwjWMR0VRl5KFT1SxalSap7uMgPKF28wBA==", - "dependencies": { - "@codemirror/language": "^6.0.0" - } - }, - "node_modules/@codemirror/lint": { - "version": "6.4.2", - "license": "MIT", - "dependencies": { - "@codemirror/state": "^6.0.0", - "@codemirror/view": "^6.0.0", - "crelt": "^1.0.5" - } - }, - "node_modules/@codemirror/search": { - "version": "6.5.5", - "license": "MIT", - "dependencies": { - "@codemirror/state": 
"^6.0.0", - "@codemirror/view": "^6.0.0", - "crelt": "^1.0.5" - } - }, - "node_modules/@codemirror/state": { - "version": "6.4.1", - "resolved": "https://registry.npmjs.org/@codemirror/state/-/state-6.4.1.tgz", - "integrity": "sha512-QkEyUiLhsJoZkbumGZlswmAhA7CBU02Wrz7zvH4SrcifbsqwlXShVXg65f3v/ts57W3dqyamEriMhij1Z3Zz4A==" - }, - "node_modules/@codemirror/theme-one-dark": { - "version": "6.1.2", - "license": "MIT", - "dependencies": { - "@codemirror/language": "^6.0.0", - "@codemirror/state": "^6.0.0", - "@codemirror/view": "^6.0.0", - "@lezer/highlight": "^1.0.0" - } - }, - "node_modules/@codemirror/view": { - "version": "6.28.1", - "resolved": "https://registry.npmjs.org/@codemirror/view/-/view-6.28.1.tgz", - "integrity": "sha512-BUWr+zCJpMkA/u69HlJmR+YkV4yPpM81HeMkOMZuwFa8iM5uJdEPKAs1icIRZKkKmy0Ub1x9/G3PQLTXdpBxrQ==", - "dependencies": { - "@codemirror/state": "^6.4.0", - "style-mod": "^4.1.0", - "w3c-keyname": "^2.2.4" - } - }, - "node_modules/@esbuild/linux-arm64": { - "version": "0.19.8", - "cpu": [ - "arm64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=12" - } - }, - "node_modules/@eslint-community/eslint-utils": { - "version": "4.4.0", - "dev": true, - "license": "MIT", - "dependencies": { - "eslint-visitor-keys": "^3.3.0" - }, - "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" - }, - "peerDependencies": { - "eslint": "^6.0.0 || ^7.0.0 || >=8.0.0" - } - }, - "node_modules/@eslint-community/regexpp": { - "version": "4.10.0", - "dev": true, - "license": "MIT", - "engines": { - "node": "^12.0.0 || ^14.0.0 || >=16.0.0" - } - }, - "node_modules/@eslint/eslintrc": { - "version": "2.1.4", - "dev": true, - "license": "MIT", - "dependencies": { - "ajv": "^6.12.4", - "debug": "^4.3.2", - "espree": "^9.6.0", - "globals": "^13.19.0", - "ignore": "^5.2.0", - "import-fresh": "^3.2.1", - "js-yaml": "^4.1.0", - "minimatch": "^3.1.2", - "strip-json-comments": "^3.1.1" - }, - "engines": { - "node": 
"^12.22.0 || ^14.17.0 || >=16.0.0" - }, - "funding": { - "url": "https://opencollective.com/eslint" - } - }, - "node_modules/@eslint/js": { - "version": "8.55.0", - "dev": true, - "license": "MIT", - "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" - } - }, - "node_modules/@exodus/schemasafe": { - "version": "1.3.0", - "dev": true, - "license": "MIT" - }, - "node_modules/@floating-ui/core": { - "version": "1.5.2", - "license": "MIT", - "dependencies": { - "@floating-ui/utils": "^0.1.3" - } - }, - "node_modules/@floating-ui/dom": { - "version": "1.5.3", - "license": "MIT", - "dependencies": { - "@floating-ui/core": "^1.4.2", - "@floating-ui/utils": "^0.1.3" - } - }, - "node_modules/@floating-ui/react-dom": { - "version": "2.0.4", - "license": "MIT", - "dependencies": { - "@floating-ui/dom": "^1.5.1" - }, - "peerDependencies": { - "react": ">=16.8.0", - "react-dom": ">=16.8.0" - } - }, - "node_modules/@floating-ui/utils": { - "version": "0.1.6", - "license": "MIT" - }, - "node_modules/@headlessui/react": { - "version": "1.7.17", - "license": "MIT", - "dependencies": { - "client-only": "^0.0.1" - }, - "engines": { - "node": ">=10" - }, - "peerDependencies": { - "react": "^16 || ^17 || ^18", - "react-dom": "^16 || ^17 || ^18" - } - }, - "node_modules/@heroicons/react": { - "version": "2.0.18", - "license": "MIT", - "peerDependencies": { - "react": ">= 16" - } - }, - "node_modules/@humanwhocodes/config-array": { - "version": "0.11.13", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@humanwhocodes/object-schema": "^2.0.1", - "debug": "^4.1.1", - "minimatch": "^3.0.5" - }, - "engines": { - "node": ">=10.10.0" - } - }, - "node_modules/@humanwhocodes/module-importer": { - "version": "1.0.1", - "dev": true, - "license": "Apache-2.0", - "engines": { - "node": ">=12.22" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/nzakas" - } - }, - "node_modules/@humanwhocodes/object-schema": { - "version": "2.0.1", - "dev": true, - 
"license": "BSD-3-Clause" - }, - "node_modules/@ibm-cloud/openapi-ruleset": { - "version": "1.14.2", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@ibm-cloud/openapi-ruleset-utilities": "1.3.0", - "@stoplight/spectral-formats": "^1.5.0", - "@stoplight/spectral-functions": "^1.7.2", - "@stoplight/spectral-rulesets": "^1.16.0", - "chalk": "^4.1.1", - "lodash": "^4.17.21", - "loglevel": "^1.8.1", - "loglevel-plugin-prefix": "0.8.4", - "minimatch": "^6.1.6", - "validator": "^13.7.0" - }, - "engines": { - "node": ">=16.0.0" - } - }, - "node_modules/@ibm-cloud/openapi-ruleset-utilities": { - "version": "1.3.0", - "dev": true, - "license": "Apache-2.0", - "engines": { - "node": ">=16.0.0" - } - }, - "node_modules/@ibm-cloud/openapi-ruleset/node_modules/brace-expansion": { - "version": "2.0.1", - "dev": true, - "license": "MIT", - "dependencies": { - "balanced-match": "^1.0.0" - } - }, - "node_modules/@ibm-cloud/openapi-ruleset/node_modules/minimatch": { - "version": "6.2.0", - "dev": true, - "license": "ISC", - "dependencies": { - "brace-expansion": "^2.0.1" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, - "node_modules/@jest/schemas": { - "version": "29.6.3", - "dev": true, - "license": "MIT", - "dependencies": { - "@sinclair/typebox": "^0.27.8" - }, - "engines": { - "node": "^14.15.0 || ^16.10.0 || >=18.0.0" - } - }, - "node_modules/@jridgewell/gen-mapping": { - "version": "0.3.3", - "license": "MIT", - "dependencies": { - "@jridgewell/set-array": "^1.0.1", - "@jridgewell/sourcemap-codec": "^1.4.10", - "@jridgewell/trace-mapping": "^0.3.9" - }, - "engines": { - "node": ">=6.0.0" - } - }, - "node_modules/@jridgewell/resolve-uri": { - "version": "3.1.1", - "license": "MIT", - "engines": { - "node": ">=6.0.0" - } - }, - "node_modules/@jridgewell/set-array": { - "version": "1.1.2", - "license": "MIT", - "engines": { - "node": ">=6.0.0" - } - }, - "node_modules/@jridgewell/sourcemap-codec": { 
- "version": "1.4.15", - "license": "MIT" - }, - "node_modules/@jridgewell/trace-mapping": { - "version": "0.3.20", - "license": "MIT", - "dependencies": { - "@jridgewell/resolve-uri": "^3.1.0", - "@jridgewell/sourcemap-codec": "^1.4.14" - } - }, - "node_modules/@jsdevtools/ono": { - "version": "7.1.3", - "dev": true, - "license": "MIT" - }, - "node_modules/@jsep-plugin/regex": { - "version": "1.0.3", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 10.16.0" - }, - "peerDependencies": { - "jsep": "^0.4.0||^1.0.0" - } - }, - "node_modules/@jsep-plugin/ternary": { - "version": "1.1.3", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 10.16.0" - }, - "peerDependencies": { - "jsep": "^0.4.0||^1.0.0" - } - }, - "node_modules/@lezer/common": { - "version": "1.2.1", - "resolved": "https://registry.npmjs.org/@lezer/common/-/common-1.2.1.tgz", - "integrity": "sha512-yemX0ZD2xS/73llMZIK6KplkjIjf2EvAHcinDi/TfJ9hS25G0388+ClHt6/3but0oOxinTcQHJLDXh6w1crzFQ==" - }, - "node_modules/@lezer/highlight": { - "version": "1.2.0", - "license": "MIT", - "dependencies": { - "@lezer/common": "^1.0.0" - } - }, - "node_modules/@lezer/lr": { - "version": "1.3.14", - "license": "MIT", - "dependencies": { - "@lezer/common": "^1.0.0" - } - }, - "node_modules/@lezer/python": { - "version": "1.1.9", - "license": "MIT", - "dependencies": { - "@lezer/highlight": "^1.0.0", - "@lezer/lr": "^1.0.0" - } - }, - "node_modules/@nodelib/fs.scandir": { - "version": "2.1.5", - "license": "MIT", - "dependencies": { - "@nodelib/fs.stat": "2.0.5", - "run-parallel": "^1.1.9" - }, - "engines": { - "node": ">= 8" - } - }, - "node_modules/@nodelib/fs.stat": { - "version": "2.0.5", - "license": "MIT", - "engines": { - "node": ">= 8" - } - }, - "node_modules/@nodelib/fs.walk": { - "version": "1.2.8", - "license": "MIT", - "dependencies": { - "@nodelib/fs.scandir": "2.1.5", - "fastq": "^1.6.0" - }, - "engines": { - "node": ">= 8" - } - }, - "node_modules/@orval/angular": { - "version": 
"6.22.1", - "dev": true, - "license": "MIT", - "dependencies": { - "@orval/core": "6.22.1" - } - }, - "node_modules/@orval/axios": { - "version": "6.22.1", - "dev": true, - "license": "MIT", - "dependencies": { - "@orval/core": "6.22.1" - } - }, - "node_modules/@orval/core": { - "version": "6.22.1", - "dev": true, - "license": "MIT", - "dependencies": { - "@apidevtools/swagger-parser": "^10.1.0", - "@ibm-cloud/openapi-ruleset": "^1.14.2", - "acorn": "^8.11.2", - "ajv": "^8.12.0", - "chalk": "^4.1.2", - "compare-versions": "^6.1.0", - "debug": "^4.3.4", - "esbuild": "^0.19.5", - "esutils": "2.0.3", - "fs-extra": "^11.2.0", - "globby": "11.1.0", - "lodash.get": "^4.4.2", - "lodash.isempty": "^4.4.0", - "lodash.omit": "^4.5.0", - "lodash.uniq": "^4.5.0", - "lodash.uniqby": "^4.7.0", - "lodash.uniqwith": "^4.5.0", - "micromatch": "^4.0.5", - "openapi-types": "^12.1.3", - "openapi3-ts": "^3.2.0", - "swagger2openapi": "^7.0.8" - } - }, - "node_modules/@orval/core/node_modules/ajv": { - "version": "8.12.0", - "dev": true, - "license": "MIT", - "dependencies": { - "fast-deep-equal": "^3.1.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - "node_modules/@orval/core/node_modules/json-schema-traverse": { - "version": "1.0.0", - "dev": true, - "license": "MIT" - }, - "node_modules/@orval/mock": { - "version": "6.22.1", - "dev": true, - "license": "MIT", - "dependencies": { - "@orval/core": "6.22.1", - "lodash.get": "^4.4.2", - "lodash.omit": "^4.5.0", - "openapi3-ts": "^3.0.0" - } - }, - "node_modules/@orval/query": { - "version": "6.22.1", - "dev": true, - "license": "MIT", - "dependencies": { - "@orval/core": "6.22.1", - "lodash.omitby": "^4.6.0", - "vitest": "^0.34.6" - } - }, - "node_modules/@orval/swr": { - "version": "6.22.1", - "dev": true, - "license": "MIT", - "dependencies": { - "@orval/core": "6.22.1" - } - }, - 
"node_modules/@orval/zod": { - "version": "6.22.1", - "dev": true, - "license": "MIT", - "dependencies": { - "@orval/core": "6.22.1", - "lodash.uniq": "^4.5.0" - } - }, - "node_modules/@playwright/test": { - "version": "1.40.1", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "playwright": "1.40.1" - }, - "bin": { - "playwright": "cli.js" - }, - "engines": { - "node": ">=16" - } - }, - "node_modules/@radix-ui/number": { - "version": "1.0.1", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10" - } - }, - "node_modules/@radix-ui/primitive": { - "version": "1.0.1", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10" - } - }, - "node_modules/@radix-ui/react-arrow": { - "version": "1.0.3", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-primitive": "1.0.3" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-collection": { - "version": "1.0.3", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-compose-refs": "1.0.1", - "@radix-ui/react-context": "1.0.1", - "@radix-ui/react-primitive": "1.0.3", - "@radix-ui/react-slot": "1.0.2" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-compose-refs": { - "version": "1.0.1", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": 
{ - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-context": { - "version": "1.0.1", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-context-menu": { - "version": "2.1.5", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/primitive": "1.0.1", - "@radix-ui/react-context": "1.0.1", - "@radix-ui/react-menu": "2.0.6", - "@radix-ui/react-primitive": "1.0.3", - "@radix-ui/react-use-callback-ref": "1.0.1", - "@radix-ui/react-use-controllable-state": "1.0.1" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-direction": { - "version": "1.0.1", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-dismissable-layer": { - "version": "1.0.5", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/primitive": "1.0.1", - "@radix-ui/react-compose-refs": "1.0.1", - "@radix-ui/react-primitive": "1.0.3", - "@radix-ui/react-use-callback-ref": "1.0.1", - "@radix-ui/react-use-escape-keydown": "1.0.3" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } 
- } - }, - "node_modules/@radix-ui/react-focus-guards": { - "version": "1.0.1", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-focus-scope": { - "version": "1.0.4", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-compose-refs": "1.0.1", - "@radix-ui/react-primitive": "1.0.3", - "@radix-ui/react-use-callback-ref": "1.0.1" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-id": { - "version": "1.0.1", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-use-layout-effect": "1.0.1" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-menu": { - "version": "2.0.6", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/primitive": "1.0.1", - "@radix-ui/react-collection": "1.0.3", - "@radix-ui/react-compose-refs": "1.0.1", - "@radix-ui/react-context": "1.0.1", - "@radix-ui/react-direction": "1.0.1", - "@radix-ui/react-dismissable-layer": "1.0.5", - "@radix-ui/react-focus-guards": "1.0.1", - "@radix-ui/react-focus-scope": "1.0.4", - "@radix-ui/react-id": "1.0.1", - "@radix-ui/react-popper": "1.1.3", - "@radix-ui/react-portal": "1.0.4", - "@radix-ui/react-presence": "1.0.1", - "@radix-ui/react-primitive": "1.0.3", - "@radix-ui/react-roving-focus": "1.0.4", - "@radix-ui/react-slot": "1.0.2", - 
"@radix-ui/react-use-callback-ref": "1.0.1", - "aria-hidden": "^1.1.1", - "react-remove-scroll": "2.5.5" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-popper": { - "version": "1.1.3", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@floating-ui/react-dom": "^2.0.0", - "@radix-ui/react-arrow": "1.0.3", - "@radix-ui/react-compose-refs": "1.0.1", - "@radix-ui/react-context": "1.0.1", - "@radix-ui/react-primitive": "1.0.3", - "@radix-ui/react-use-callback-ref": "1.0.1", - "@radix-ui/react-use-layout-effect": "1.0.1", - "@radix-ui/react-use-rect": "1.0.1", - "@radix-ui/react-use-size": "1.0.1", - "@radix-ui/rect": "1.0.1" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-portal": { - "version": "1.0.4", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-primitive": "1.0.3" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-presence": { - "version": "1.0.1", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-compose-refs": "1.0.1", - "@radix-ui/react-use-layout-effect": "1.0.1" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - 
"react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-primitive": { - "version": "1.0.3", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-slot": "1.0.2" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-roving-focus": { - "version": "1.0.4", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/primitive": "1.0.1", - "@radix-ui/react-collection": "1.0.3", - "@radix-ui/react-compose-refs": "1.0.1", - "@radix-ui/react-context": "1.0.1", - "@radix-ui/react-direction": "1.0.1", - "@radix-ui/react-id": "1.0.1", - "@radix-ui/react-primitive": "1.0.3", - "@radix-ui/react-use-callback-ref": "1.0.1", - "@radix-ui/react-use-controllable-state": "1.0.1" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-select": { - "version": "1.2.2", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/number": "1.0.1", - "@radix-ui/primitive": "1.0.1", - "@radix-ui/react-collection": "1.0.3", - "@radix-ui/react-compose-refs": "1.0.1", - "@radix-ui/react-context": "1.0.1", - "@radix-ui/react-direction": "1.0.1", - "@radix-ui/react-dismissable-layer": "1.0.4", - "@radix-ui/react-focus-guards": "1.0.1", - "@radix-ui/react-focus-scope": "1.0.3", - "@radix-ui/react-id": 
"1.0.1", - "@radix-ui/react-popper": "1.1.2", - "@radix-ui/react-portal": "1.0.3", - "@radix-ui/react-primitive": "1.0.3", - "@radix-ui/react-slot": "1.0.2", - "@radix-ui/react-use-callback-ref": "1.0.1", - "@radix-ui/react-use-controllable-state": "1.0.1", - "@radix-ui/react-use-layout-effect": "1.0.1", - "@radix-ui/react-use-previous": "1.0.1", - "@radix-ui/react-visually-hidden": "1.0.3", - "aria-hidden": "^1.1.1", - "react-remove-scroll": "2.5.5" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-select/node_modules/@radix-ui/react-dismissable-layer": { - "version": "1.0.4", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/primitive": "1.0.1", - "@radix-ui/react-compose-refs": "1.0.1", - "@radix-ui/react-primitive": "1.0.3", - "@radix-ui/react-use-callback-ref": "1.0.1", - "@radix-ui/react-use-escape-keydown": "1.0.3" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-select/node_modules/@radix-ui/react-focus-scope": { - "version": "1.0.3", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-compose-refs": "1.0.1", - "@radix-ui/react-primitive": "1.0.3", - "@radix-ui/react-use-callback-ref": "1.0.1" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { 
- "optional": true - } - } - }, - "node_modules/@radix-ui/react-select/node_modules/@radix-ui/react-popper": { - "version": "1.1.2", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@floating-ui/react-dom": "^2.0.0", - "@radix-ui/react-arrow": "1.0.3", - "@radix-ui/react-compose-refs": "1.0.1", - "@radix-ui/react-context": "1.0.1", - "@radix-ui/react-primitive": "1.0.3", - "@radix-ui/react-use-callback-ref": "1.0.1", - "@radix-ui/react-use-layout-effect": "1.0.1", - "@radix-ui/react-use-rect": "1.0.1", - "@radix-ui/react-use-size": "1.0.1", - "@radix-ui/rect": "1.0.1" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-select/node_modules/@radix-ui/react-portal": { - "version": "1.0.3", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-primitive": "1.0.3" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-slot": { - "version": "1.0.2", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-compose-refs": "1.0.1" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-use-callback-ref": { - "version": "1.0.1", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - 
"peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-use-controllable-state": { - "version": "1.0.1", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-use-callback-ref": "1.0.1" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-use-escape-keydown": { - "version": "1.0.3", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-use-callback-ref": "1.0.1" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-use-layout-effect": { - "version": "1.0.1", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-use-previous": { - "version": "1.0.1", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-use-rect": { - "version": "1.0.1", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/rect": "1.0.1" - }, - "peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-use-size": { - "version": "1.0.1", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-use-layout-effect": "1.0.1" - }, - 
"peerDependencies": { - "@types/react": "*", - "react": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/@radix-ui/react-visually-hidden": { - "version": "1.0.3", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10", - "@radix-ui/react-primitive": "1.0.3" - }, - "peerDependencies": { - "@types/react": "*", - "@types/react-dom": "*", - "react": "^16.8 || ^17.0 || ^18.0", - "react-dom": "^16.8 || ^17.0 || ^18.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "@types/react-dom": { - "optional": true - } - } - }, - "node_modules/@radix-ui/rect": { - "version": "1.0.1", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.13.10" - } - }, - "node_modules/@react-dnd/asap": { - "version": "5.0.2", - "license": "MIT" - }, - "node_modules/@react-dnd/invariant": { - "version": "4.0.2", - "license": "MIT" - }, - "node_modules/@react-dnd/shallowequal": { - "version": "4.0.2", - "license": "MIT" - }, - "node_modules/@reactflow/background": { - "version": "11.3.6", - "license": "MIT", - "dependencies": { - "@reactflow/core": "11.10.1", - "classcat": "^5.0.3", - "zustand": "^4.4.1" - }, - "peerDependencies": { - "react": ">=17", - "react-dom": ">=17" - } - }, - "node_modules/@reactflow/controls": { - "version": "11.2.6", - "license": "MIT", - "dependencies": { - "@reactflow/core": "11.10.1", - "classcat": "^5.0.3", - "zustand": "^4.4.1" - }, - "peerDependencies": { - "react": ">=17", - "react-dom": ">=17" - } - }, - "node_modules/@reactflow/core": { - "version": "11.10.1", - "license": "MIT", - "dependencies": { - "@types/d3": "^7.4.0", - "@types/d3-drag": "^3.0.1", - "@types/d3-selection": "^3.0.3", - "@types/d3-zoom": "^3.0.1", - "classcat": "^5.0.3", - "d3-drag": "^3.0.0", - "d3-selection": "^3.0.0", - "d3-zoom": "^3.0.0", - "zustand": "^4.4.1" - }, - "peerDependencies": { - "react": ">=17", - "react-dom": ">=17" - } - }, - 
"node_modules/@reactflow/minimap": { - "version": "11.7.6", - "license": "MIT", - "dependencies": { - "@reactflow/core": "11.10.1", - "@types/d3-selection": "^3.0.3", - "@types/d3-zoom": "^3.0.1", - "classcat": "^5.0.3", - "d3-selection": "^3.0.0", - "d3-zoom": "^3.0.0", - "zustand": "^4.4.1" - }, - "peerDependencies": { - "react": ">=17", - "react-dom": ">=17" - } - }, - "node_modules/@reactflow/node-resizer": { - "version": "2.2.6", - "license": "MIT", - "dependencies": { - "@reactflow/core": "11.10.1", - "classcat": "^5.0.4", - "d3-drag": "^3.0.0", - "d3-selection": "^3.0.0", - "zustand": "^4.4.1" - }, - "peerDependencies": { - "react": ">=17", - "react-dom": ">=17" - } - }, - "node_modules/@reactflow/node-toolbar": { - "version": "1.3.6", - "license": "MIT", - "dependencies": { - "@reactflow/core": "11.10.1", - "classcat": "^5.0.3", - "zustand": "^4.4.1" - }, - "peerDependencies": { - "react": ">=17", - "react-dom": ">=17" - } - }, - "node_modules/@remix-run/router": { - "version": "1.13.1", - "license": "MIT", - "engines": { - "node": ">=14.0.0" - } - }, - "node_modules/@sinclair/typebox": { - "version": "0.27.8", - "dev": true, - "license": "MIT" - }, - "node_modules/@stoplight/json": { - "version": "3.21.0", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@stoplight/ordered-object-literal": "^1.0.3", - "@stoplight/path": "^1.3.2", - "@stoplight/types": "^13.6.0", - "jsonc-parser": "~2.2.1", - "lodash": "^4.17.21", - "safe-stable-stringify": "^1.1" - }, - "engines": { - "node": ">=8.3.0" - } - }, - "node_modules/@stoplight/json-ref-readers": { - "version": "1.2.2", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "node-fetch": "^2.6.0", - "tslib": "^1.14.1" - }, - "engines": { - "node": ">=8.3.0" - } - }, - "node_modules/@stoplight/json-ref-readers/node_modules/tslib": { - "version": "1.14.1", - "dev": true, - "license": "0BSD" - }, - "node_modules/@stoplight/json-ref-resolver": { - "version": "3.1.6", - "dev": true, - 
"license": "Apache-2.0", - "dependencies": { - "@stoplight/json": "^3.21.0", - "@stoplight/path": "^1.3.2", - "@stoplight/types": "^12.3.0 || ^13.0.0", - "@types/urijs": "^1.19.19", - "dependency-graph": "~0.11.0", - "fast-memoize": "^2.5.2", - "immer": "^9.0.6", - "lodash": "^4.17.21", - "tslib": "^2.6.0", - "urijs": "^1.19.11" - }, - "engines": { - "node": ">=8.3.0" - } - }, - "node_modules/@stoplight/ordered-object-literal": { - "version": "1.0.5", - "dev": true, - "license": "Apache-2.0", - "engines": { - "node": ">=8" - } - }, - "node_modules/@stoplight/path": { - "version": "1.3.2", - "dev": true, - "license": "Apache-2.0", - "engines": { - "node": ">=8" - } - }, - "node_modules/@stoplight/spectral-core": { - "version": "1.18.3", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@stoplight/better-ajv-errors": "1.0.3", - "@stoplight/json": "~3.21.0", - "@stoplight/path": "1.3.2", - "@stoplight/spectral-parsers": "^1.0.0", - "@stoplight/spectral-ref-resolver": "^1.0.0", - "@stoplight/spectral-runtime": "^1.0.0", - "@stoplight/types": "~13.6.0", - "@types/es-aggregate-error": "^1.0.2", - "@types/json-schema": "^7.0.11", - "ajv": "^8.6.0", - "ajv-errors": "~3.0.0", - "ajv-formats": "~2.1.0", - "es-aggregate-error": "^1.0.7", - "jsonpath-plus": "7.1.0", - "lodash": "~4.17.21", - "lodash.topath": "^4.5.2", - "minimatch": "3.1.2", - "nimma": "0.2.2", - "pony-cause": "^1.0.0", - "simple-eval": "1.0.0", - "tslib": "^2.3.0" - }, - "engines": { - "node": "^12.20 || >= 14.13" - } - }, - "node_modules/@stoplight/spectral-core/node_modules/@stoplight/better-ajv-errors": { - "version": "1.0.3", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "jsonpointer": "^5.0.0", - "leven": "^3.1.0" - }, - "engines": { - "node": "^12.20 || >= 14.13" - }, - "peerDependencies": { - "ajv": ">=8" - } - }, - "node_modules/@stoplight/spectral-core/node_modules/@stoplight/types": { - "version": "13.6.0", - "dev": true, - "license": "Apache-2.0", - "dependencies": 
{ - "@types/json-schema": "^7.0.4", - "utility-types": "^3.10.0" - }, - "engines": { - "node": "^12.20 || >=14.13" - } - }, - "node_modules/@stoplight/spectral-core/node_modules/ajv": { - "version": "8.12.0", - "dev": true, - "license": "MIT", - "dependencies": { - "fast-deep-equal": "^3.1.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - "node_modules/@stoplight/spectral-core/node_modules/ajv-errors": { - "version": "3.0.0", - "dev": true, - "license": "MIT", - "peerDependencies": { - "ajv": "^8.0.1" - } - }, - "node_modules/@stoplight/spectral-core/node_modules/json-schema-traverse": { - "version": "1.0.0", - "dev": true, - "license": "MIT" - }, - "node_modules/@stoplight/spectral-formats": { - "version": "1.6.0", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@stoplight/json": "^3.17.0", - "@stoplight/spectral-core": "^1.8.0", - "@types/json-schema": "^7.0.7", - "tslib": "^2.3.1" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/@stoplight/spectral-functions": { - "version": "1.7.2", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@stoplight/better-ajv-errors": "1.0.3", - "@stoplight/json": "^3.17.1", - "@stoplight/spectral-core": "^1.7.0", - "@stoplight/spectral-formats": "^1.0.0", - "@stoplight/spectral-runtime": "^1.1.0", - "ajv": "^8.6.3", - "ajv-draft-04": "~1.0.0", - "ajv-errors": "~3.0.0", - "ajv-formats": "~2.1.0", - "lodash": "~4.17.21", - "tslib": "^2.3.0" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/@stoplight/spectral-functions/node_modules/@stoplight/better-ajv-errors": { - "version": "1.0.3", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "jsonpointer": "^5.0.0", - "leven": "^3.1.0" - }, - "engines": { - "node": "^12.20 || >= 14.13" - }, - "peerDependencies": { - "ajv": ">=8" - } - }, - 
"node_modules/@stoplight/spectral-functions/node_modules/ajv": { - "version": "8.12.0", - "dev": true, - "license": "MIT", - "dependencies": { - "fast-deep-equal": "^3.1.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - "node_modules/@stoplight/spectral-functions/node_modules/ajv-draft-04": { - "version": "1.0.0", - "dev": true, - "license": "MIT", - "peerDependencies": { - "ajv": "^8.5.0" - }, - "peerDependenciesMeta": { - "ajv": { - "optional": true - } - } - }, - "node_modules/@stoplight/spectral-functions/node_modules/ajv-errors": { - "version": "3.0.0", - "dev": true, - "license": "MIT", - "peerDependencies": { - "ajv": "^8.0.1" - } - }, - "node_modules/@stoplight/spectral-functions/node_modules/json-schema-traverse": { - "version": "1.0.0", - "dev": true, - "license": "MIT" - }, - "node_modules/@stoplight/spectral-parsers": { - "version": "1.0.3", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@stoplight/json": "~3.21.0", - "@stoplight/types": "^13.6.0", - "@stoplight/yaml": "~4.2.3", - "tslib": "^2.3.1" - }, - "engines": { - "node": "^12.20 || >=14.13" - } - }, - "node_modules/@stoplight/spectral-ref-resolver": { - "version": "1.0.4", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@stoplight/json-ref-readers": "1.2.2", - "@stoplight/json-ref-resolver": "~3.1.6", - "@stoplight/spectral-runtime": "^1.1.2", - "dependency-graph": "0.11.0", - "tslib": "^2.3.1" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/@stoplight/spectral-rulesets": { - "version": "1.18.0", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@asyncapi/specs": "^4.1.0", - "@stoplight/better-ajv-errors": "1.0.3", - "@stoplight/json": "^3.17.0", - "@stoplight/spectral-core": "^1.8.1", - "@stoplight/spectral-formats": "^1.5.0", - "@stoplight/spectral-functions": "^1.5.1", - 
"@stoplight/spectral-runtime": "^1.1.1", - "@stoplight/types": "^13.6.0", - "@types/json-schema": "^7.0.7", - "ajv": "^8.8.2", - "ajv-formats": "~2.1.0", - "json-schema-traverse": "^1.0.0", - "lodash": "~4.17.21", - "tslib": "^2.3.0" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/@stoplight/spectral-rulesets/node_modules/@stoplight/better-ajv-errors": { - "version": "1.0.3", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "jsonpointer": "^5.0.0", - "leven": "^3.1.0" - }, - "engines": { - "node": "^12.20 || >= 14.13" - }, - "peerDependencies": { - "ajv": ">=8" - } - }, - "node_modules/@stoplight/spectral-rulesets/node_modules/ajv": { - "version": "8.12.0", - "dev": true, - "license": "MIT", - "dependencies": { - "fast-deep-equal": "^3.1.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - "node_modules/@stoplight/spectral-rulesets/node_modules/json-schema-traverse": { - "version": "1.0.0", - "dev": true, - "license": "MIT" - }, - "node_modules/@stoplight/spectral-runtime": { - "version": "1.1.2", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@stoplight/json": "^3.17.0", - "@stoplight/path": "^1.3.2", - "@stoplight/types": "^12.3.0", - "abort-controller": "^3.0.0", - "lodash": "^4.17.21", - "node-fetch": "^2.6.7", - "tslib": "^2.3.1" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/@stoplight/spectral-runtime/node_modules/@stoplight/types": { - "version": "12.5.0", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@types/json-schema": "^7.0.4", - "utility-types": "^3.10.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/@stoplight/types": { - "version": "13.20.0", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@types/json-schema": "^7.0.4", - "utility-types": "^3.10.0" - }, - "engines": { - "node": "^12.20 || >=14.13" 
- } - }, - "node_modules/@stoplight/yaml": { - "version": "4.2.3", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@stoplight/ordered-object-literal": "^1.0.1", - "@stoplight/types": "^13.0.0", - "@stoplight/yaml-ast-parser": "0.0.48", - "tslib": "^2.2.0" - }, - "engines": { - "node": ">=10.8" - } - }, - "node_modules/@stoplight/yaml-ast-parser": { - "version": "0.0.48", - "dev": true, - "license": "Apache-2.0" - }, - "node_modules/@swc/core": { - "version": "1.3.100", - "dev": true, - "hasInstallScript": true, - "license": "Apache-2.0", - "dependencies": { - "@swc/counter": "^0.1.1", - "@swc/types": "^0.1.5" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/swc" - }, - "optionalDependencies": { - "@swc/core-darwin-arm64": "1.3.100", - "@swc/core-darwin-x64": "1.3.100", - "@swc/core-linux-arm64-gnu": "1.3.100", - "@swc/core-linux-arm64-musl": "1.3.100", - "@swc/core-linux-x64-gnu": "1.3.100", - "@swc/core-linux-x64-musl": "1.3.100", - "@swc/core-win32-arm64-msvc": "1.3.100", - "@swc/core-win32-ia32-msvc": "1.3.100", - "@swc/core-win32-x64-msvc": "1.3.100" - }, - "peerDependencies": { - "@swc/helpers": "^0.5.0" - }, - "peerDependenciesMeta": { - "@swc/helpers": { - "optional": true - } - } - }, - "node_modules/@swc/core-linux-arm64-gnu": { - "version": "1.3.100", - "cpu": [ - "arm64" - ], - "dev": true, - "license": "Apache-2.0 AND MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=10" - } - }, - "node_modules/@swc/core-linux-arm64-musl": { - "version": "1.3.100", - "cpu": [ - "arm64" - ], - "dev": true, - "license": "Apache-2.0 AND MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=10" - } - }, - "node_modules/@swc/counter": { - "version": "0.1.2", - "dev": true, - "license": "Apache-2.0" - }, - "node_modules/@swc/types": { - "version": "0.1.5", - "dev": true, - "license": "Apache-2.0" - }, - 
"node_modules/@tailwindcss/container-queries": { - "version": "0.1.1", - "resolved": "https://registry.npmjs.org/@tailwindcss/container-queries/-/container-queries-0.1.1.tgz", - "integrity": "sha512-p18dswChx6WnTSaJCSGx6lTmrGzNNvm2FtXmiO6AuA1V4U5REyoqwmT6kgAsIMdjo07QdAfYXHJ4hnMtfHzWgA==", - "peerDependencies": { - "tailwindcss": ">=3.2.0" - } - }, - "node_modules/@tanstack/query-core": { - "version": "4.36.1", - "license": "MIT", - "funding": { - "type": "github", - "url": "https://github.com/sponsors/tannerlinsley" - } - }, - "node_modules/@tanstack/react-query": { - "version": "4.36.1", - "license": "MIT", - "dependencies": { - "@tanstack/query-core": "4.36.1", - "use-sync-external-store": "^1.2.0" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/tannerlinsley" - }, - "peerDependencies": { - "react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react-native": "*" - }, - "peerDependenciesMeta": { - "react-dom": { - "optional": true - }, - "react-native": { - "optional": true - } - } - }, - "node_modules/@tanstack/react-table": { - "version": "8.10.7", - "license": "MIT", - "dependencies": { - "@tanstack/table-core": "8.10.7" - }, - "engines": { - "node": ">=12" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/tannerlinsley" - }, - "peerDependencies": { - "react": ">=16", - "react-dom": ">=16" - } - }, - "node_modules/@tanstack/react-virtual": { - "version": "3.0.1", - "license": "MIT", - "dependencies": { - "@tanstack/virtual-core": "3.0.0" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/tannerlinsley" - }, - "peerDependencies": { - "react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0" - } - }, - "node_modules/@tanstack/table-core": { - "version": "8.10.7", - "license": "MIT", - "engines": { - "node": ">=12" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/tannerlinsley" - } - 
}, - "node_modules/@tanstack/virtual-core": { - "version": "3.0.0", - "license": "MIT", - "funding": { - "type": "github", - "url": "https://github.com/sponsors/tannerlinsley" - } - }, - "node_modules/@testing-library/dom": { - "version": "9.3.3", - "dev": true, - "license": "MIT", - "dependencies": { - "@babel/code-frame": "^7.10.4", - "@babel/runtime": "^7.12.5", - "@types/aria-query": "^5.0.1", - "aria-query": "5.1.3", - "chalk": "^4.1.0", - "dom-accessibility-api": "^0.5.9", - "lz-string": "^1.5.0", - "pretty-format": "^27.0.2" - }, - "engines": { - "node": ">=14" - } - }, - "node_modules/@testing-library/jest-dom": { - "version": "6.1.5", - "dev": true, - "license": "MIT", - "dependencies": { - "@adobe/css-tools": "^4.3.1", - "@babel/runtime": "^7.9.2", - "aria-query": "^5.0.0", - "chalk": "^3.0.0", - "css.escape": "^1.5.1", - "dom-accessibility-api": "^0.5.6", - "lodash": "^4.17.15", - "redent": "^3.0.0" - }, - "engines": { - "node": ">=14", - "npm": ">=6", - "yarn": ">=1" - }, - "peerDependencies": { - "@jest/globals": ">= 28", - "@types/jest": ">= 28", - "jest": ">= 28", - "vitest": ">= 0.32" - }, - "peerDependenciesMeta": { - "@jest/globals": { - "optional": true - }, - "@types/jest": { - "optional": true - }, - "jest": { - "optional": true - }, - "vitest": { - "optional": true - } - } - }, - "node_modules/@testing-library/jest-dom/node_modules/chalk": { - "version": "3.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/@testing-library/react": { - "version": "14.1.2", - "dev": true, - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.12.5", - "@testing-library/dom": "^9.0.0", - "@types/react-dom": "^18.0.0" - }, - "engines": { - "node": ">=14" - }, - "peerDependencies": { - "react": "^18.0.0", - "react-dom": "^18.0.0" - } - }, - "node_modules/@testing-library/user-event": { - "version": "14.5.1", - "dev": true, - 
"license": "MIT", - "engines": { - "node": ">=12", - "npm": ">=6" - }, - "peerDependencies": { - "@testing-library/dom": ">=7.21.4" - } - }, - "node_modules/@tootallnate/once": { - "version": "2.0.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 10" - } - }, - "node_modules/@types/aria-query": { - "version": "5.0.4", - "dev": true, - "license": "MIT" - }, - "node_modules/@types/chai": { - "version": "4.3.11", - "dev": true, - "license": "MIT" - }, - "node_modules/@types/chai-subset": { - "version": "1.3.5", - "dev": true, - "license": "MIT", - "dependencies": { - "@types/chai": "*" - } - }, - "node_modules/@types/command-line-args": { - "version": "5.2.0", - "license": "MIT" - }, - "node_modules/@types/command-line-usage": { - "version": "5.0.2", - "license": "MIT" - }, - "node_modules/@types/d3": { - "version": "7.4.3", - "license": "MIT", - "dependencies": { - "@types/d3-array": "*", - "@types/d3-axis": "*", - "@types/d3-brush": "*", - "@types/d3-chord": "*", - "@types/d3-color": "*", - "@types/d3-contour": "*", - "@types/d3-delaunay": "*", - "@types/d3-dispatch": "*", - "@types/d3-drag": "*", - "@types/d3-dsv": "*", - "@types/d3-ease": "*", - "@types/d3-fetch": "*", - "@types/d3-force": "*", - "@types/d3-format": "*", - "@types/d3-geo": "*", - "@types/d3-hierarchy": "*", - "@types/d3-interpolate": "*", - "@types/d3-path": "*", - "@types/d3-polygon": "*", - "@types/d3-quadtree": "*", - "@types/d3-random": "*", - "@types/d3-scale": "*", - "@types/d3-scale-chromatic": "*", - "@types/d3-selection": "*", - "@types/d3-shape": "*", - "@types/d3-time": "*", - "@types/d3-time-format": "*", - "@types/d3-timer": "*", - "@types/d3-transition": "*", - "@types/d3-zoom": "*" - } - }, - "node_modules/@types/d3-array": { - "version": "3.2.1", - "license": "MIT" - }, - "node_modules/@types/d3-axis": { - "version": "3.0.6", - "license": "MIT", - "dependencies": { - "@types/d3-selection": "*" - } - }, - "node_modules/@types/d3-brush": { - "version": "3.0.6", - 
"license": "MIT", - "dependencies": { - "@types/d3-selection": "*" - } - }, - "node_modules/@types/d3-chord": { - "version": "3.0.6", - "license": "MIT" - }, - "node_modules/@types/d3-color": { - "version": "3.1.3", - "license": "MIT" - }, - "node_modules/@types/d3-contour": { - "version": "3.0.6", - "license": "MIT", - "dependencies": { - "@types/d3-array": "*", - "@types/geojson": "*" - } - }, - "node_modules/@types/d3-delaunay": { - "version": "6.0.4", - "license": "MIT" - }, - "node_modules/@types/d3-dispatch": { - "version": "3.0.6", - "license": "MIT" - }, - "node_modules/@types/d3-drag": { - "version": "3.0.7", - "license": "MIT", - "dependencies": { - "@types/d3-selection": "*" - } - }, - "node_modules/@types/d3-dsv": { - "version": "3.0.7", - "license": "MIT" - }, - "node_modules/@types/d3-ease": { - "version": "3.0.2", - "license": "MIT" - }, - "node_modules/@types/d3-fetch": { - "version": "3.0.7", - "license": "MIT", - "dependencies": { - "@types/d3-dsv": "*" - } - }, - "node_modules/@types/d3-force": { - "version": "3.0.9", - "license": "MIT" - }, - "node_modules/@types/d3-format": { - "version": "3.0.4", - "license": "MIT" - }, - "node_modules/@types/d3-geo": { - "version": "3.1.0", - "license": "MIT", - "dependencies": { - "@types/geojson": "*" - } - }, - "node_modules/@types/d3-hierarchy": { - "version": "3.1.6", - "license": "MIT" - }, - "node_modules/@types/d3-interpolate": { - "version": "3.0.4", - "license": "MIT", - "dependencies": { - "@types/d3-color": "*" - } - }, - "node_modules/@types/d3-path": { - "version": "3.0.2", - "license": "MIT" - }, - "node_modules/@types/d3-polygon": { - "version": "3.0.2", - "license": "MIT" - }, - "node_modules/@types/d3-quadtree": { - "version": "3.0.6", - "license": "MIT" - }, - "node_modules/@types/d3-random": { - "version": "3.0.3", - "license": "MIT" - }, - "node_modules/@types/d3-scale": { - "version": "4.0.8", - "license": "MIT", - "dependencies": { - "@types/d3-time": "*" - } - }, - 
"node_modules/@types/d3-scale-chromatic": { - "version": "3.0.3", - "license": "MIT" - }, - "node_modules/@types/d3-selection": { - "version": "3.0.10", - "license": "MIT" - }, - "node_modules/@types/d3-shape": { - "version": "3.1.6", - "license": "MIT", - "dependencies": { - "@types/d3-path": "*" - } - }, - "node_modules/@types/d3-time": { - "version": "3.0.3", - "license": "MIT" - }, - "node_modules/@types/d3-time-format": { - "version": "4.0.3", - "license": "MIT" - }, - "node_modules/@types/d3-timer": { - "version": "3.0.2", - "license": "MIT" - }, - "node_modules/@types/d3-transition": { - "version": "3.0.8", - "license": "MIT", - "dependencies": { - "@types/d3-selection": "*" - } - }, - "node_modules/@types/d3-zoom": { - "version": "3.0.8", - "license": "MIT", - "dependencies": { - "@types/d3-interpolate": "*", - "@types/d3-selection": "*" - } - }, - "node_modules/@types/diff": { - "version": "5.2.1", - "resolved": "https://registry.npmjs.org/@types/diff/-/diff-5.2.1.tgz", - "integrity": "sha512-uxpcuwWJGhe2AR1g8hD9F5OYGCqjqWnBUQFD8gMZsDbv8oPHzxJF6iMO6n8Tk0AdzlxoaaoQhOYlIg/PukVU8g==", - "dev": true, - "license": "MIT" - }, - "node_modules/@types/es-aggregate-error": { - "version": "1.0.6", - "dev": true, - "license": "MIT", - "dependencies": { - "@types/node": "*" - } - }, - "node_modules/@types/geojson": { - "version": "7946.0.13", - "license": "MIT" - }, - "node_modules/@types/json-schema": { - "version": "7.0.15", - "dev": true, - "license": "MIT" - }, - "node_modules/@types/json5": { - "version": "0.0.29", - "dev": true, - "license": "MIT" - }, - "node_modules/@types/node": { - "version": "20.3.0", - "license": "MIT" - }, - "node_modules/@types/pad-left": { - "version": "2.1.1", - "license": "MIT" - }, - "node_modules/@types/pluralize": { - "version": "0.0.30", - "dev": true, - "license": "MIT" - }, - "node_modules/@types/prop-types": { - "version": "15.7.11", - "devOptional": true, - "license": "MIT" - }, - "node_modules/@types/react": { - "version": 
"18.2.42", - "devOptional": true, - "license": "MIT", - "dependencies": { - "@types/prop-types": "*", - "@types/scheduler": "*", - "csstype": "^3.0.2" - } - }, - "node_modules/@types/react-dom": { - "version": "18.2.17", - "devOptional": true, - "license": "MIT", - "dependencies": { - "@types/react": "*" - } - }, - "node_modules/@types/scheduler": { - "version": "0.16.8", - "devOptional": true, - "license": "MIT" - }, - "node_modules/@types/semver": { - "version": "7.5.6", - "dev": true, - "license": "MIT" - }, - "node_modules/@types/urijs": { - "version": "1.19.25", - "dev": true, - "license": "MIT" - }, - "node_modules/@typescript-eslint/eslint-plugin": { - "version": "6.13.2", - "dev": true, - "license": "MIT", - "dependencies": { - "@eslint-community/regexpp": "^4.5.1", - "@typescript-eslint/scope-manager": "6.13.2", - "@typescript-eslint/type-utils": "6.13.2", - "@typescript-eslint/utils": "6.13.2", - "@typescript-eslint/visitor-keys": "6.13.2", - "debug": "^4.3.4", - "graphemer": "^1.4.0", - "ignore": "^5.2.4", - "natural-compare": "^1.4.0", - "semver": "^7.5.4", - "ts-api-utils": "^1.0.1" - }, - "engines": { - "node": "^16.0.0 || >=18.0.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/typescript-eslint" - }, - "peerDependencies": { - "@typescript-eslint/parser": "^6.0.0 || ^6.0.0-alpha", - "eslint": "^7.0.0 || ^8.0.0" - }, - "peerDependenciesMeta": { - "typescript": { - "optional": true - } - } - }, - "node_modules/@typescript-eslint/parser": { - "version": "6.13.2", - "dev": true, - "license": "BSD-2-Clause", - "dependencies": { - "@typescript-eslint/scope-manager": "6.13.2", - "@typescript-eslint/types": "6.13.2", - "@typescript-eslint/typescript-estree": "6.13.2", - "@typescript-eslint/visitor-keys": "6.13.2", - "debug": "^4.3.4" - }, - "engines": { - "node": "^16.0.0 || >=18.0.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/typescript-eslint" - }, - "peerDependencies": { - 
"eslint": "^7.0.0 || ^8.0.0" - }, - "peerDependenciesMeta": { - "typescript": { - "optional": true - } - } - }, - "node_modules/@typescript-eslint/scope-manager": { - "version": "6.13.2", - "dev": true, - "license": "MIT", - "dependencies": { - "@typescript-eslint/types": "6.13.2", - "@typescript-eslint/visitor-keys": "6.13.2" - }, - "engines": { - "node": "^16.0.0 || >=18.0.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/typescript-eslint" - } - }, - "node_modules/@typescript-eslint/type-utils": { - "version": "6.13.2", - "dev": true, - "license": "MIT", - "dependencies": { - "@typescript-eslint/typescript-estree": "6.13.2", - "@typescript-eslint/utils": "6.13.2", - "debug": "^4.3.4", - "ts-api-utils": "^1.0.1" - }, - "engines": { - "node": "^16.0.0 || >=18.0.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/typescript-eslint" - }, - "peerDependencies": { - "eslint": "^7.0.0 || ^8.0.0" - }, - "peerDependenciesMeta": { - "typescript": { - "optional": true - } - } - }, - "node_modules/@typescript-eslint/types": { - "version": "6.13.2", - "dev": true, - "license": "MIT", - "engines": { - "node": "^16.0.0 || >=18.0.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/typescript-eslint" - } - }, - "node_modules/@typescript-eslint/typescript-estree": { - "version": "6.13.2", - "dev": true, - "license": "BSD-2-Clause", - "dependencies": { - "@typescript-eslint/types": "6.13.2", - "@typescript-eslint/visitor-keys": "6.13.2", - "debug": "^4.3.4", - "globby": "^11.1.0", - "is-glob": "^4.0.3", - "semver": "^7.5.4", - "ts-api-utils": "^1.0.1" - }, - "engines": { - "node": "^16.0.0 || >=18.0.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/typescript-eslint" - }, - "peerDependenciesMeta": { - "typescript": { - "optional": true - } - } - }, - "node_modules/@typescript-eslint/utils": { - "version": "6.13.2", - "dev": true, - 
"license": "MIT", - "dependencies": { - "@eslint-community/eslint-utils": "^4.4.0", - "@types/json-schema": "^7.0.12", - "@types/semver": "^7.5.0", - "@typescript-eslint/scope-manager": "6.13.2", - "@typescript-eslint/types": "6.13.2", - "@typescript-eslint/typescript-estree": "6.13.2", - "semver": "^7.5.4" - }, - "engines": { - "node": "^16.0.0 || >=18.0.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/typescript-eslint" - }, - "peerDependencies": { - "eslint": "^7.0.0 || ^8.0.0" - } - }, - "node_modules/@typescript-eslint/visitor-keys": { - "version": "6.13.2", - "dev": true, - "license": "MIT", - "dependencies": { - "@typescript-eslint/types": "6.13.2", - "eslint-visitor-keys": "^3.4.1" - }, - "engines": { - "node": "^16.0.0 || >=18.0.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/typescript-eslint" - } - }, - "node_modules/@uidotdev/usehooks": { - "version": "2.4.1", - "license": "MIT", - "engines": { - "node": ">=16" - }, - "peerDependencies": { - "react": ">=18.0.0", - "react-dom": ">=18.0.0" - } - }, - "node_modules/@uiw/codemirror-extensions-basic-setup": { - "version": "4.21.21", - "license": "MIT", - "dependencies": { - "@codemirror/autocomplete": "^6.0.0", - "@codemirror/commands": "^6.0.0", - "@codemirror/language": "^6.0.0", - "@codemirror/lint": "^6.0.0", - "@codemirror/search": "^6.0.0", - "@codemirror/state": "^6.0.0", - "@codemirror/view": "^6.0.0" - }, - "funding": { - "url": "https://jaywcjlove.github.io/#/sponsor" - }, - "peerDependencies": { - "@codemirror/autocomplete": ">=6.0.0", - "@codemirror/commands": ">=6.0.0", - "@codemirror/language": ">=6.0.0", - "@codemirror/lint": ">=6.0.0", - "@codemirror/search": ">=6.0.0", - "@codemirror/state": ">=6.0.0", - "@codemirror/view": ">=6.0.0" - } - }, - "node_modules/@uiw/react-codemirror": { - "version": "4.21.21", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.18.6", - "@codemirror/commands": "^6.1.0", - 
"@codemirror/state": "^6.1.1", - "@codemirror/theme-one-dark": "^6.0.0", - "@uiw/codemirror-extensions-basic-setup": "4.21.21", - "codemirror": "^6.0.0" - }, - "funding": { - "url": "https://jaywcjlove.github.io/#/sponsor" - }, - "peerDependencies": { - "@babel/runtime": ">=7.11.0", - "@codemirror/state": ">=6.0.0", - "@codemirror/theme-one-dark": ">=6.0.0", - "@codemirror/view": ">=6.0.0", - "codemirror": ">=6.0.0", - "react": ">=16.8.0", - "react-dom": ">=16.8.0" - } - }, - "node_modules/@ungap/structured-clone": { - "version": "1.2.0", - "dev": true, - "license": "ISC" - }, - "node_modules/@vitejs/plugin-react-swc": { - "version": "3.5.0", - "dev": true, - "license": "MIT", - "dependencies": { - "@swc/core": "^1.3.96" - }, - "peerDependencies": { - "vite": "^4 || ^5" - } - }, - "node_modules/@vitest/expect": { - "version": "0.34.6", - "dev": true, - "license": "MIT", - "dependencies": { - "@vitest/spy": "0.34.6", - "@vitest/utils": "0.34.6", - "chai": "^4.3.10" - }, - "funding": { - "url": "https://opencollective.com/vitest" - } - }, - "node_modules/@vitest/runner": { - "version": "0.34.6", - "dev": true, - "license": "MIT", - "dependencies": { - "@vitest/utils": "0.34.6", - "p-limit": "^4.0.0", - "pathe": "^1.1.1" - }, - "funding": { - "url": "https://opencollective.com/vitest" - } - }, - "node_modules/@vitest/runner/node_modules/p-limit": { - "version": "4.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "yocto-queue": "^1.0.0" - }, - "engines": { - "node": "^12.20.0 || ^14.13.1 || >=16.0.0" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/@vitest/runner/node_modules/yocto-queue": { - "version": "1.0.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=12.20" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/@vitest/snapshot": { - "version": "0.34.6", - "dev": true, - "license": "MIT", - "dependencies": { - "magic-string": "^0.30.1", - 
"pathe": "^1.1.1", - "pretty-format": "^29.5.0" - }, - "funding": { - "url": "https://opencollective.com/vitest" - } - }, - "node_modules/@vitest/snapshot/node_modules/ansi-styles": { - "version": "5.2.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - "node_modules/@vitest/snapshot/node_modules/pretty-format": { - "version": "29.7.0", - "dev": true, - "license": "MIT", - "dependencies": { - "@jest/schemas": "^29.6.3", - "ansi-styles": "^5.0.0", - "react-is": "^18.0.0" - }, - "engines": { - "node": "^14.15.0 || ^16.10.0 || >=18.0.0" - } - }, - "node_modules/@vitest/snapshot/node_modules/react-is": { - "version": "18.2.0", - "dev": true, - "license": "MIT" - }, - "node_modules/@vitest/spy": { - "version": "0.34.6", - "dev": true, - "license": "MIT", - "dependencies": { - "tinyspy": "^2.1.1" - }, - "funding": { - "url": "https://opencollective.com/vitest" - } - }, - "node_modules/@vitest/utils": { - "version": "0.34.6", - "dev": true, - "license": "MIT", - "dependencies": { - "diff-sequences": "^29.4.3", - "loupe": "^2.3.6", - "pretty-format": "^29.5.0" - }, - "funding": { - "url": "https://opencollective.com/vitest" - } - }, - "node_modules/@vitest/utils/node_modules/ansi-styles": { - "version": "5.2.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - "node_modules/@vitest/utils/node_modules/pretty-format": { - "version": "29.7.0", - "dev": true, - "license": "MIT", - "dependencies": { - "@jest/schemas": "^29.6.3", - "ansi-styles": "^5.0.0", - "react-is": "^18.0.0" - }, - "engines": { - "node": "^14.15.0 || ^16.10.0 || >=18.0.0" - } - }, - "node_modules/@vitest/utils/node_modules/react-is": { - "version": "18.2.0", - "dev": true, - "license": "MIT" - }, - "node_modules/abab": { - "version": "2.0.6", - "dev": true, - "license": "BSD-3-Clause" - }, - 
"node_modules/abort-controller": { - "version": "3.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "event-target-shim": "^5.0.0" - }, - "engines": { - "node": ">=6.5" - } - }, - "node_modules/acorn": { - "version": "8.11.2", - "dev": true, - "license": "MIT", - "bin": { - "acorn": "bin/acorn" - }, - "engines": { - "node": ">=0.4.0" - } - }, - "node_modules/acorn-jsx": { - "version": "5.3.2", - "dev": true, - "license": "MIT", - "peerDependencies": { - "acorn": "^6.0.0 || ^7.0.0 || ^8.0.0" - } - }, - "node_modules/acorn-walk": { - "version": "8.3.1", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=0.4.0" - } - }, - "node_modules/agent-base": { - "version": "6.0.2", - "dev": true, - "license": "MIT", - "dependencies": { - "debug": "4" - }, - "engines": { - "node": ">= 6.0.0" - } - }, - "node_modules/ajv": { - "version": "6.12.6", - "dev": true, - "license": "MIT", - "dependencies": { - "fast-deep-equal": "^3.1.1", - "fast-json-stable-stringify": "^2.0.0", - "json-schema-traverse": "^0.4.1", - "uri-js": "^4.2.2" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - "node_modules/ajv-formats": { - "version": "2.1.1", - "dev": true, - "license": "MIT", - "dependencies": { - "ajv": "^8.0.0" - }, - "peerDependencies": { - "ajv": "^8.0.0" - }, - "peerDependenciesMeta": { - "ajv": { - "optional": true - } - } - }, - "node_modules/ajv-formats/node_modules/ajv": { - "version": "8.12.0", - "dev": true, - "license": "MIT", - "dependencies": { - "fast-deep-equal": "^3.1.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - "node_modules/ajv-formats/node_modules/json-schema-traverse": { - "version": "1.0.0", - "dev": true, - "license": "MIT" - }, - "node_modules/ansi-colors": { - "version": "4.1.3", - "dev": true, - "license": "MIT", - "engines": { - "node": 
">=6" - } - }, - "node_modules/ansi-regex": { - "version": "5.0.1", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/ansi-styles": { - "version": "4.3.0", - "license": "MIT", - "dependencies": { - "color-convert": "^2.0.1" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - "node_modules/any-promise": { - "version": "1.3.0", - "license": "MIT" - }, - "node_modules/anymatch": { - "version": "3.1.3", - "license": "ISC", - "dependencies": { - "normalize-path": "^3.0.0", - "picomatch": "^2.0.4" - }, - "engines": { - "node": ">= 8" - } - }, - "node_modules/apache-arrow": { - "version": "13.0.0", - "license": "Apache-2.0", - "dependencies": { - "@types/command-line-args": "5.2.0", - "@types/command-line-usage": "5.0.2", - "@types/node": "20.3.0", - "@types/pad-left": "2.1.1", - "command-line-args": "5.2.1", - "command-line-usage": "7.0.1", - "flatbuffers": "23.5.26", - "json-bignum": "^0.0.3", - "pad-left": "^2.1.0", - "tslib": "^2.5.3" - }, - "bin": { - "arrow2csv": "bin/arrow2csv.js" - } - }, - "node_modules/arg": { - "version": "5.0.2", - "license": "MIT" - }, - "node_modules/argparse": { - "version": "2.0.1", - "dev": true, - "license": "Python-2.0" - }, - "node_modules/aria-hidden": { - "version": "1.2.3", - "license": "MIT", - "dependencies": { - "tslib": "^2.0.0" - }, - "engines": { - "node": ">=10" - } - }, - "node_modules/aria-query": { - "version": "5.1.3", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "deep-equal": "^2.0.5" - } - }, - "node_modules/array-back": { - "version": "3.1.0", - "license": "MIT", - "engines": { - "node": ">=6" - } - }, - "node_modules/array-buffer-byte-length": { - "version": "1.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "is-array-buffer": "^3.0.1" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/array-includes": { - 
"version": "3.1.7", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "define-properties": "^1.2.0", - "es-abstract": "^1.22.1", - "get-intrinsic": "^1.2.1", - "is-string": "^1.0.7" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/array-union": { - "version": "2.1.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/array.prototype.findlastindex": { - "version": "1.2.3", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "define-properties": "^1.2.0", - "es-abstract": "^1.22.1", - "es-shim-unscopables": "^1.0.0", - "get-intrinsic": "^1.2.1" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/array.prototype.flat": { - "version": "1.3.2", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "define-properties": "^1.2.0", - "es-abstract": "^1.22.1", - "es-shim-unscopables": "^1.0.0" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/array.prototype.flatmap": { - "version": "1.3.2", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "define-properties": "^1.2.0", - "es-abstract": "^1.22.1", - "es-shim-unscopables": "^1.0.0" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/array.prototype.tosorted": { - "version": "1.1.2", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "define-properties": "^1.2.0", - "es-abstract": "^1.22.1", - "es-shim-unscopables": "^1.0.0", - "get-intrinsic": "^1.2.1" - } - }, - "node_modules/arraybuffer.prototype.slice": { - "version": "1.0.2", - "dev": true, - "license": "MIT", - "dependencies": { - "array-buffer-byte-length": "^1.0.0", 
- "call-bind": "^1.0.2", - "define-properties": "^1.2.0", - "es-abstract": "^1.22.1", - "get-intrinsic": "^1.2.1", - "is-array-buffer": "^3.0.2", - "is-shared-array-buffer": "^1.0.2" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/assertion-error": { - "version": "1.1.0", - "dev": true, - "license": "MIT", - "engines": { - "node": "*" - } - }, - "node_modules/astring": { - "version": "1.8.6", - "dev": true, - "license": "MIT", - "bin": { - "astring": "bin/astring" - } - }, - "node_modules/asynciterator.prototype": { - "version": "1.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "has-symbols": "^1.0.3" - } - }, - "node_modules/asynckit": { - "version": "0.4.0", - "dev": true, - "license": "MIT" - }, - "node_modules/autoprefixer": { - "version": "10.4.16", - "dev": true, - "funding": [ - { - "type": "opencollective", - "url": "https://opencollective.com/postcss/" - }, - { - "type": "tidelift", - "url": "https://tidelift.com/funding/github/npm/autoprefixer" - }, - { - "type": "github", - "url": "https://github.com/sponsors/ai" - } - ], - "license": "MIT", - "dependencies": { - "browserslist": "^4.21.10", - "caniuse-lite": "^1.0.30001538", - "fraction.js": "^4.3.6", - "normalize-range": "^0.1.2", - "picocolors": "^1.0.0", - "postcss-value-parser": "^4.2.0" - }, - "bin": { - "autoprefixer": "bin/autoprefixer" - }, - "engines": { - "node": "^10 || ^12 || >=14" - }, - "peerDependencies": { - "postcss": "^8.1.0" - } - }, - "node_modules/available-typed-arrays": { - "version": "1.0.5", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/balanced-match": { - "version": "1.0.2", - "license": "MIT" - }, - "node_modules/binary-extensions": { - "version": "2.2.0", - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/brace-expansion": { - "version": 
"1.1.11", - "license": "MIT", - "dependencies": { - "balanced-match": "^1.0.0", - "concat-map": "0.0.1" - } - }, - "node_modules/braces": { - "version": "3.0.2", - "license": "MIT", - "dependencies": { - "fill-range": "^7.0.1" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/browserslist": { - "version": "4.22.2", - "dev": true, - "funding": [ - { - "type": "opencollective", - "url": "https://opencollective.com/browserslist" - }, - { - "type": "tidelift", - "url": "https://tidelift.com/funding/github/npm/browserslist" - }, - { - "type": "github", - "url": "https://github.com/sponsors/ai" - } - ], - "license": "MIT", - "dependencies": { - "caniuse-lite": "^1.0.30001565", - "electron-to-chromium": "^1.4.601", - "node-releases": "^2.0.14", - "update-browserslist-db": "^1.0.13" - }, - "bin": { - "browserslist": "cli.js" - }, - "engines": { - "node": "^6 || ^7 || ^8 || ^9 || ^10 || ^11 || ^12 || >=13.7" - } - }, - "node_modules/builtin-modules": { - "version": "3.3.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=6" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/builtins": { - "version": "5.0.1", - "dev": true, - "license": "MIT", - "dependencies": { - "semver": "^7.0.0" - } - }, - "node_modules/cac": { - "version": "6.7.14", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/call-bind": { - "version": "1.0.5", - "dev": true, - "license": "MIT", - "dependencies": { - "function-bind": "^1.1.2", - "get-intrinsic": "^1.2.1", - "set-function-length": "^1.1.1" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/call-me-maybe": { - "version": "1.0.2", - "dev": true, - "license": "MIT" - }, - "node_modules/callsites": { - "version": "3.1.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=6" - } - }, - "node_modules/camelcase-css": { - "version": "2.0.1", - "license": "MIT", - "engines": { - "node": ">= 6" 
- } - }, - "node_modules/caniuse-lite": { - "version": "1.0.30001636", - "resolved": "https://registry.npmjs.org/caniuse-lite/-/caniuse-lite-1.0.30001636.tgz", - "integrity": "sha512-bMg2vmr8XBsbL6Lr0UHXy/21m84FTxDLWn2FSqMd5PrlbMxwJlQnC2YWYxVgp66PZE+BBNF2jYQUBKCo1FDeZg==", - "dev": true, - "funding": [ - { - "type": "opencollective", - "url": "https://opencollective.com/browserslist" - }, - { - "type": "tidelift", - "url": "https://tidelift.com/funding/github/npm/caniuse-lite" - }, - { - "type": "github", - "url": "https://github.com/sponsors/ai" - } - ], - "license": "CC-BY-4.0" - }, - "node_modules/chai": { - "version": "4.3.10", - "dev": true, - "license": "MIT", - "dependencies": { - "assertion-error": "^1.1.0", - "check-error": "^1.0.3", - "deep-eql": "^4.1.3", - "get-func-name": "^2.0.2", - "loupe": "^2.3.6", - "pathval": "^1.1.1", - "type-detect": "^4.0.8" - }, - "engines": { - "node": ">=4" - } - }, - "node_modules/chalk": { - "version": "4.1.2", - "license": "MIT", - "dependencies": { - "ansi-styles": "^4.1.0", - "supports-color": "^7.1.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/chalk?sponsor=1" - } - }, - "node_modules/chalk-template": { - "version": "0.4.0", - "license": "MIT", - "dependencies": { - "chalk": "^4.1.2" - }, - "engines": { - "node": ">=12" - }, - "funding": { - "url": "https://github.com/chalk/chalk-template?sponsor=1" - } - }, - "node_modules/check-error": { - "version": "1.0.3", - "dev": true, - "license": "MIT", - "dependencies": { - "get-func-name": "^2.0.2" - }, - "engines": { - "node": "*" - } - }, - "node_modules/chokidar": { - "version": "3.5.3", - "funding": [ - { - "type": "individual", - "url": "https://paulmillr.com/funding/" - } - ], - "license": "MIT", - "dependencies": { - "anymatch": "~3.1.2", - "braces": "~3.0.2", - "glob-parent": "~5.1.2", - "is-binary-path": "~2.1.0", - "is-glob": "~4.0.1", - "normalize-path": "~3.0.0", - "readdirp": "~3.6.0" - }, - "engines": { - 
"node": ">= 8.10.0" - }, - "optionalDependencies": { - "fsevents": "~2.3.2" - } - }, - "node_modules/chokidar/node_modules/glob-parent": { - "version": "5.1.2", - "license": "ISC", - "dependencies": { - "is-glob": "^4.0.1" - }, - "engines": { - "node": ">= 6" - } - }, - "node_modules/classcat": { - "version": "5.0.4", - "license": "MIT" - }, - "node_modules/client-only": { - "version": "0.0.1", - "license": "MIT" - }, - "node_modules/cliui": { - "version": "8.0.1", - "dev": true, - "license": "ISC", - "dependencies": { - "string-width": "^4.2.0", - "strip-ansi": "^6.0.1", - "wrap-ansi": "^7.0.0" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/clsx": { - "version": "2.0.0", - "license": "MIT", - "engines": { - "node": ">=6" - } - }, - "node_modules/codemirror": { - "version": "6.0.1", - "license": "MIT", - "dependencies": { - "@codemirror/autocomplete": "^6.0.0", - "@codemirror/commands": "^6.0.0", - "@codemirror/language": "^6.0.0", - "@codemirror/lint": "^6.0.0", - "@codemirror/search": "^6.0.0", - "@codemirror/state": "^6.0.0", - "@codemirror/view": "^6.0.0" - } - }, - "node_modules/color-convert": { - "version": "2.0.1", - "license": "MIT", - "dependencies": { - "color-name": "~1.1.4" - }, - "engines": { - "node": ">=7.0.0" - } - }, - "node_modules/color-name": { - "version": "1.1.4", - "license": "MIT" - }, - "node_modules/combined-stream": { - "version": "1.0.8", - "dev": true, - "license": "MIT", - "dependencies": { - "delayed-stream": "~1.0.0" - }, - "engines": { - "node": ">= 0.8" - } - }, - "node_modules/command-line-args": { - "version": "5.2.1", - "license": "MIT", - "dependencies": { - "array-back": "^3.1.0", - "find-replace": "^3.0.0", - "lodash.camelcase": "^4.3.0", - "typical": "^4.0.0" - }, - "engines": { - "node": ">=4.0.0" - } - }, - "node_modules/command-line-usage": { - "version": "7.0.1", - "license": "MIT", - "dependencies": { - "array-back": "^6.2.2", - "chalk-template": "^0.4.0", - "table-layout": "^3.0.0", - "typical": 
"^7.1.1" - }, - "engines": { - "node": ">=12.20.0" - } - }, - "node_modules/command-line-usage/node_modules/array-back": { - "version": "6.2.2", - "license": "MIT", - "engines": { - "node": ">=12.17" - } - }, - "node_modules/command-line-usage/node_modules/typical": { - "version": "7.1.1", - "license": "MIT", - "engines": { - "node": ">=12.17" - } - }, - "node_modules/commander": { - "version": "4.1.1", - "license": "MIT", - "engines": { - "node": ">= 6" - } - }, - "node_modules/compare-versions": { - "version": "6.1.0", - "dev": true, - "license": "MIT" - }, - "node_modules/concat-map": { - "version": "0.0.1", - "license": "MIT" - }, - "node_modules/crelt": { - "version": "1.0.6", - "license": "MIT" - }, - "node_modules/cross-spawn": { - "version": "7.0.3", - "dev": true, - "license": "MIT", - "dependencies": { - "path-key": "^3.1.0", - "shebang-command": "^2.0.0", - "which": "^2.0.1" - }, - "engines": { - "node": ">= 8" - } - }, - "node_modules/css.escape": { - "version": "1.5.1", - "dev": true, - "license": "MIT" - }, - "node_modules/cssesc": { - "version": "3.0.0", - "license": "MIT", - "bin": { - "cssesc": "bin/cssesc" - }, - "engines": { - "node": ">=4" - } - }, - "node_modules/cssstyle": { - "version": "3.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "rrweb-cssom": "^0.6.0" - }, - "engines": { - "node": ">=14" - } - }, - "node_modules/csstype": { - "version": "3.1.3", - "devOptional": true, - "license": "MIT" - }, - "node_modules/d3-color": { - "version": "3.1.0", - "license": "ISC", - "engines": { - "node": ">=12" - } - }, - "node_modules/d3-dispatch": { - "version": "3.0.1", - "license": "ISC", - "engines": { - "node": ">=12" - } - }, - "node_modules/d3-drag": { - "version": "3.0.0", - "license": "ISC", - "dependencies": { - "d3-dispatch": "1 - 3", - "d3-selection": "3" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/d3-ease": { - "version": "3.0.1", - "license": "BSD-3-Clause", - "engines": { - "node": ">=12" - } - }, - 
"node_modules/d3-interpolate": { - "version": "3.0.1", - "license": "ISC", - "dependencies": { - "d3-color": "1 - 3" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/d3-selection": { - "version": "3.0.0", - "license": "ISC", - "engines": { - "node": ">=12" - } - }, - "node_modules/d3-timer": { - "version": "3.0.1", - "license": "ISC", - "engines": { - "node": ">=12" - } - }, - "node_modules/d3-transition": { - "version": "3.0.1", - "license": "ISC", - "dependencies": { - "d3-color": "1 - 3", - "d3-dispatch": "1 - 3", - "d3-ease": "1 - 3", - "d3-interpolate": "1 - 3", - "d3-timer": "1 - 3" - }, - "engines": { - "node": ">=12" - }, - "peerDependencies": { - "d3-selection": "2 - 3" - } - }, - "node_modules/d3-zoom": { - "version": "3.0.0", - "license": "ISC", - "dependencies": { - "d3-dispatch": "1 - 3", - "d3-drag": "2 - 3", - "d3-interpolate": "1 - 3", - "d3-selection": "2 - 3", - "d3-transition": "2 - 3" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/data-urls": { - "version": "4.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "abab": "^2.0.6", - "whatwg-mimetype": "^3.0.0", - "whatwg-url": "^12.0.0" - }, - "engines": { - "node": ">=14" - } - }, - "node_modules/debug": { - "version": "4.3.4", - "dev": true, - "license": "MIT", - "dependencies": { - "ms": "2.1.2" - }, - "engines": { - "node": ">=6.0" - }, - "peerDependenciesMeta": { - "supports-color": { - "optional": true - } - } - }, - "node_modules/decimal.js": { - "version": "10.4.3", - "dev": true, - "license": "MIT" - }, - "node_modules/deep-eql": { - "version": "4.1.3", - "dev": true, - "license": "MIT", - "dependencies": { - "type-detect": "^4.0.0" - }, - "engines": { - "node": ">=6" - } - }, - "node_modules/deep-equal": { - "version": "2.2.3", - "dev": true, - "license": "MIT", - "dependencies": { - "array-buffer-byte-length": "^1.0.0", - "call-bind": "^1.0.5", - "es-get-iterator": "^1.1.3", - "get-intrinsic": "^1.2.2", - "is-arguments": "^1.1.1", - 
"is-array-buffer": "^3.0.2", - "is-date-object": "^1.0.5", - "is-regex": "^1.1.4", - "is-shared-array-buffer": "^1.0.2", - "isarray": "^2.0.5", - "object-is": "^1.1.5", - "object-keys": "^1.1.1", - "object.assign": "^4.1.4", - "regexp.prototype.flags": "^1.5.1", - "side-channel": "^1.0.4", - "which-boxed-primitive": "^1.0.2", - "which-collection": "^1.0.1", - "which-typed-array": "^1.1.13" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/deep-is": { - "version": "0.1.4", - "dev": true, - "license": "MIT" - }, - "node_modules/define-data-property": { - "version": "1.1.1", - "dev": true, - "license": "MIT", - "dependencies": { - "get-intrinsic": "^1.2.1", - "gopd": "^1.0.1", - "has-property-descriptors": "^1.0.0" - }, - "engines": { - "node": ">= 0.4" - } - }, - "node_modules/define-properties": { - "version": "1.2.1", - "dev": true, - "license": "MIT", - "dependencies": { - "define-data-property": "^1.0.1", - "has-property-descriptors": "^1.0.0", - "object-keys": "^1.1.1" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/delayed-stream": { - "version": "1.0.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=0.4.0" - } - }, - "node_modules/dependency-graph": { - "version": "0.11.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 0.6.0" - } - }, - "node_modules/detect-node-es": { - "version": "1.1.0", - "license": "MIT" - }, - "node_modules/didyoumean": { - "version": "1.2.2", - "license": "Apache-2.0" - }, - "node_modules/diff": { - "version": "5.2.0", - "resolved": "https://registry.npmjs.org/diff/-/diff-5.2.0.tgz", - "integrity": "sha512-uIFDxqpRZGZ6ThOk84hEfqWoHx2devRFvpTZcTHur85vImfaxUbTW9Ryh4CpCuDnToOP1CEtXKIgytHBPVff5A==", - "engines": { - "node": ">=0.3.1" - } - }, - "node_modules/diff-sequences": { - "version": "29.6.3", - "dev": true, - "license": "MIT", - "engines": 
{ - "node": "^14.15.0 || ^16.10.0 || >=18.0.0" - } - }, - "node_modules/dir-glob": { - "version": "3.0.1", - "dev": true, - "license": "MIT", - "dependencies": { - "path-type": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/dlv": { - "version": "1.1.3", - "license": "MIT" - }, - "node_modules/dnd-core": { - "version": "16.0.1", - "license": "MIT", - "dependencies": { - "@react-dnd/asap": "^5.0.1", - "@react-dnd/invariant": "^4.0.1", - "redux": "^4.2.0" - } - }, - "node_modules/doctrine": { - "version": "3.0.0", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "esutils": "^2.0.2" - }, - "engines": { - "node": ">=6.0.0" - } - }, - "node_modules/dom-accessibility-api": { - "version": "0.5.16", - "dev": true, - "license": "MIT" - }, - "node_modules/domexception": { - "version": "4.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "webidl-conversions": "^7.0.0" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/electron-to-chromium": { - "version": "1.4.608", - "dev": true, - "license": "ISC" - }, - "node_modules/elkjs": { - "version": "0.8.2", - "license": "EPL-2.0" - }, - "node_modules/emoji-regex": { - "version": "8.0.0", - "dev": true, - "license": "MIT" - }, - "node_modules/enquirer": { - "version": "2.4.1", - "dev": true, - "license": "MIT", - "dependencies": { - "ansi-colors": "^4.1.1", - "strip-ansi": "^6.0.1" - }, - "engines": { - "node": ">=8.6" - } - }, - "node_modules/entities": { - "version": "4.5.0", - "dev": true, - "license": "BSD-2-Clause", - "engines": { - "node": ">=0.12" - }, - "funding": { - "url": "https://github.com/fb55/entities?sponsor=1" - } - }, - "node_modules/es-abstract": { - "version": "1.22.3", - "dev": true, - "license": "MIT", - "dependencies": { - "array-buffer-byte-length": "^1.0.0", - "arraybuffer.prototype.slice": "^1.0.2", - "available-typed-arrays": "^1.0.5", - "call-bind": "^1.0.5", - "es-set-tostringtag": "^2.0.1", - "es-to-primitive": "^1.2.1", - 
"function.prototype.name": "^1.1.6", - "get-intrinsic": "^1.2.2", - "get-symbol-description": "^1.0.0", - "globalthis": "^1.0.3", - "gopd": "^1.0.1", - "has-property-descriptors": "^1.0.0", - "has-proto": "^1.0.1", - "has-symbols": "^1.0.3", - "hasown": "^2.0.0", - "internal-slot": "^1.0.5", - "is-array-buffer": "^3.0.2", - "is-callable": "^1.2.7", - "is-negative-zero": "^2.0.2", - "is-regex": "^1.1.4", - "is-shared-array-buffer": "^1.0.2", - "is-string": "^1.0.7", - "is-typed-array": "^1.1.12", - "is-weakref": "^1.0.2", - "object-inspect": "^1.13.1", - "object-keys": "^1.1.1", - "object.assign": "^4.1.4", - "regexp.prototype.flags": "^1.5.1", - "safe-array-concat": "^1.0.1", - "safe-regex-test": "^1.0.0", - "string.prototype.trim": "^1.2.8", - "string.prototype.trimend": "^1.0.7", - "string.prototype.trimstart": "^1.0.7", - "typed-array-buffer": "^1.0.0", - "typed-array-byte-length": "^1.0.0", - "typed-array-byte-offset": "^1.0.0", - "typed-array-length": "^1.0.4", - "unbox-primitive": "^1.0.2", - "which-typed-array": "^1.1.13" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/es-aggregate-error": { - "version": "1.0.11", - "dev": true, - "license": "MIT", - "dependencies": { - "define-data-property": "^1.1.0", - "define-properties": "^1.2.1", - "es-abstract": "^1.22.1", - "function-bind": "^1.1.1", - "get-intrinsic": "^1.2.1", - "globalthis": "^1.0.3", - "has-property-descriptors": "^1.0.0", - "set-function-name": "^2.0.1" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/es-get-iterator": { - "version": "1.1.3", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "get-intrinsic": "^1.1.3", - "has-symbols": "^1.0.3", - "is-arguments": "^1.1.1", - "is-map": "^2.0.2", - "is-set": "^2.0.2", - "is-string": "^1.0.7", - "isarray": "^2.0.5", - "stop-iteration-iterator": "^1.0.0" - }, 
- "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/es-iterator-helpers": { - "version": "1.0.15", - "dev": true, - "license": "MIT", - "dependencies": { - "asynciterator.prototype": "^1.0.0", - "call-bind": "^1.0.2", - "define-properties": "^1.2.1", - "es-abstract": "^1.22.1", - "es-set-tostringtag": "^2.0.1", - "function-bind": "^1.1.1", - "get-intrinsic": "^1.2.1", - "globalthis": "^1.0.3", - "has-property-descriptors": "^1.0.0", - "has-proto": "^1.0.1", - "has-symbols": "^1.0.3", - "internal-slot": "^1.0.5", - "iterator.prototype": "^1.1.2", - "safe-array-concat": "^1.0.1" - } - }, - "node_modules/es-set-tostringtag": { - "version": "2.0.2", - "dev": true, - "license": "MIT", - "dependencies": { - "get-intrinsic": "^1.2.2", - "has-tostringtag": "^1.0.0", - "hasown": "^2.0.0" - }, - "engines": { - "node": ">= 0.4" - } - }, - "node_modules/es-shim-unscopables": { - "version": "1.0.2", - "dev": true, - "license": "MIT", - "dependencies": { - "hasown": "^2.0.0" - } - }, - "node_modules/es-to-primitive": { - "version": "1.2.1", - "dev": true, - "license": "MIT", - "dependencies": { - "is-callable": "^1.1.4", - "is-date-object": "^1.0.1", - "is-symbol": "^1.0.2" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/es6-promise": { - "version": "3.3.1", - "dev": true, - "license": "MIT" - }, - "node_modules/esbuild": { - "version": "0.19.8", - "dev": true, - "hasInstallScript": true, - "license": "MIT", - "bin": { - "esbuild": "bin/esbuild" - }, - "engines": { - "node": ">=12" - }, - "optionalDependencies": { - "@esbuild/android-arm": "0.19.8", - "@esbuild/android-arm64": "0.19.8", - "@esbuild/android-x64": "0.19.8", - "@esbuild/darwin-arm64": "0.19.8", - "@esbuild/darwin-x64": "0.19.8", - "@esbuild/freebsd-arm64": "0.19.8", - "@esbuild/freebsd-x64": "0.19.8", - "@esbuild/linux-arm": "0.19.8", - "@esbuild/linux-arm64": "0.19.8", - "@esbuild/linux-ia32": 
"0.19.8", - "@esbuild/linux-loong64": "0.19.8", - "@esbuild/linux-mips64el": "0.19.8", - "@esbuild/linux-ppc64": "0.19.8", - "@esbuild/linux-riscv64": "0.19.8", - "@esbuild/linux-s390x": "0.19.8", - "@esbuild/linux-x64": "0.19.8", - "@esbuild/netbsd-x64": "0.19.8", - "@esbuild/openbsd-x64": "0.19.8", - "@esbuild/sunos-x64": "0.19.8", - "@esbuild/win32-arm64": "0.19.8", - "@esbuild/win32-ia32": "0.19.8", - "@esbuild/win32-x64": "0.19.8" - } - }, - "node_modules/escalade": { - "version": "3.1.1", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=6" - } - }, - "node_modules/escape-string-regexp": { - "version": "4.0.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/eslint": { - "version": "8.55.0", - "dev": true, - "license": "MIT", - "dependencies": { - "@eslint-community/eslint-utils": "^4.2.0", - "@eslint-community/regexpp": "^4.6.1", - "@eslint/eslintrc": "^2.1.4", - "@eslint/js": "8.55.0", - "@humanwhocodes/config-array": "^0.11.13", - "@humanwhocodes/module-importer": "^1.0.1", - "@nodelib/fs.walk": "^1.2.8", - "@ungap/structured-clone": "^1.2.0", - "ajv": "^6.12.4", - "chalk": "^4.0.0", - "cross-spawn": "^7.0.2", - "debug": "^4.3.2", - "doctrine": "^3.0.0", - "escape-string-regexp": "^4.0.0", - "eslint-scope": "^7.2.2", - "eslint-visitor-keys": "^3.4.3", - "espree": "^9.6.1", - "esquery": "^1.4.2", - "esutils": "^2.0.2", - "fast-deep-equal": "^3.1.3", - "file-entry-cache": "^6.0.1", - "find-up": "^5.0.0", - "glob-parent": "^6.0.2", - "globals": "^13.19.0", - "graphemer": "^1.4.0", - "ignore": "^5.2.0", - "imurmurhash": "^0.1.4", - "is-glob": "^4.0.0", - "is-path-inside": "^3.0.3", - "js-yaml": "^4.1.0", - "json-stable-stringify-without-jsonify": "^1.0.1", - "levn": "^0.4.1", - "lodash.merge": "^4.6.2", - "minimatch": "^3.1.2", - "natural-compare": "^1.4.0", - "optionator": "^0.9.3", - "strip-ansi": "^6.0.1", - "text-table": "^0.2.0" 
- }, - "bin": { - "eslint": "bin/eslint.js" - }, - "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" - }, - "funding": { - "url": "https://opencollective.com/eslint" - } - }, - "node_modules/eslint-compat-utils": { - "version": "0.1.2", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=12" - }, - "peerDependencies": { - "eslint": ">=6.0.0" - } - }, - "node_modules/eslint-config-prettier": { - "version": "9.1.0", - "dev": true, - "license": "MIT", - "bin": { - "eslint-config-prettier": "bin/cli.js" - }, - "peerDependencies": { - "eslint": ">=7.0.0" - } - }, - "node_modules/eslint-config-standard": { - "version": "17.1.0", - "dev": true, - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/feross" - }, - { - "type": "patreon", - "url": "https://www.patreon.com/feross" - }, - { - "type": "consulting", - "url": "https://feross.org/support" - } - ], - "license": "MIT", - "engines": { - "node": ">=12.0.0" - }, - "peerDependencies": { - "eslint": "^8.0.1", - "eslint-plugin-import": "^2.25.2", - "eslint-plugin-n": "^15.0.0 || ^16.0.0 ", - "eslint-plugin-promise": "^6.0.0" - } - }, - "node_modules/eslint-config-standard-with-typescript": { - "version": "39.1.1", - "dev": true, - "license": "MIT", - "dependencies": { - "@typescript-eslint/parser": "^6.4.0", - "eslint-config-standard": "17.1.0" - }, - "peerDependencies": { - "@typescript-eslint/eslint-plugin": "^6.4.0", - "eslint": "^8.0.1", - "eslint-plugin-import": "^2.25.2", - "eslint-plugin-n": "^15.0.0 || ^16.0.0 ", - "eslint-plugin-promise": "^6.0.0", - "typescript": "*" - } - }, - "node_modules/eslint-import-resolver-node": { - "version": "0.3.9", - "dev": true, - "license": "MIT", - "dependencies": { - "debug": "^3.2.7", - "is-core-module": "^2.13.0", - "resolve": "^1.22.4" - } - }, - "node_modules/eslint-import-resolver-node/node_modules/debug": { - "version": "3.2.7", - "dev": true, - "license": "MIT", - "dependencies": { - "ms": "^2.1.1" - } - }, - 
"node_modules/eslint-module-utils": { - "version": "2.8.0", - "dev": true, - "license": "MIT", - "dependencies": { - "debug": "^3.2.7" - }, - "engines": { - "node": ">=4" - }, - "peerDependenciesMeta": { - "eslint": { - "optional": true - } - } - }, - "node_modules/eslint-module-utils/node_modules/debug": { - "version": "3.2.7", - "dev": true, - "license": "MIT", - "dependencies": { - "ms": "^2.1.1" - } - }, - "node_modules/eslint-plugin-es-x": { - "version": "7.5.0", - "dev": true, - "license": "MIT", - "dependencies": { - "@eslint-community/eslint-utils": "^4.1.2", - "@eslint-community/regexpp": "^4.6.0", - "eslint-compat-utils": "^0.1.2" - }, - "engines": { - "node": "^14.18.0 || >=16.0.0" - }, - "funding": { - "url": "https://github.com/sponsors/ota-meshi" - }, - "peerDependencies": { - "eslint": ">=8" - } - }, - "node_modules/eslint-plugin-import": { - "version": "2.29.0", - "dev": true, - "license": "MIT", - "dependencies": { - "array-includes": "^3.1.7", - "array.prototype.findlastindex": "^1.2.3", - "array.prototype.flat": "^1.3.2", - "array.prototype.flatmap": "^1.3.2", - "debug": "^3.2.7", - "doctrine": "^2.1.0", - "eslint-import-resolver-node": "^0.3.9", - "eslint-module-utils": "^2.8.0", - "hasown": "^2.0.0", - "is-core-module": "^2.13.1", - "is-glob": "^4.0.3", - "minimatch": "^3.1.2", - "object.fromentries": "^2.0.7", - "object.groupby": "^1.0.1", - "object.values": "^1.1.7", - "semver": "^6.3.1", - "tsconfig-paths": "^3.14.2" - }, - "engines": { - "node": ">=4" - }, - "peerDependencies": { - "eslint": "^2 || ^3 || ^4 || ^5 || ^6 || ^7.2.0 || ^8" - } - }, - "node_modules/eslint-plugin-import/node_modules/debug": { - "version": "3.2.7", - "dev": true, - "license": "MIT", - "dependencies": { - "ms": "^2.1.1" - } - }, - "node_modules/eslint-plugin-import/node_modules/doctrine": { - "version": "2.1.0", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "esutils": "^2.0.2" - }, - "engines": { - "node": ">=0.10.0" - } - }, - 
"node_modules/eslint-plugin-import/node_modules/semver": { - "version": "6.3.1", - "dev": true, - "license": "ISC", - "bin": { - "semver": "bin/semver.js" - } - }, - "node_modules/eslint-plugin-n": { - "version": "16.3.1", - "dev": true, - "license": "MIT", - "dependencies": { - "@eslint-community/eslint-utils": "^4.4.0", - "builtins": "^5.0.1", - "eslint-plugin-es-x": "^7.1.0", - "get-tsconfig": "^4.7.0", - "ignore": "^5.2.4", - "is-builtin-module": "^3.2.1", - "is-core-module": "^2.12.1", - "minimatch": "^3.1.2", - "resolve": "^1.22.2", - "semver": "^7.5.3" - }, - "engines": { - "node": ">=16.0.0" - }, - "funding": { - "url": "https://github.com/sponsors/mysticatea" - }, - "peerDependencies": { - "eslint": ">=7.0.0" - } - }, - "node_modules/eslint-plugin-promise": { - "version": "6.1.1", - "dev": true, - "license": "ISC", - "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" - }, - "peerDependencies": { - "eslint": "^7.0.0 || ^8.0.0" - } - }, - "node_modules/eslint-plugin-react": { - "version": "7.33.2", - "dev": true, - "license": "MIT", - "dependencies": { - "array-includes": "^3.1.6", - "array.prototype.flatmap": "^1.3.1", - "array.prototype.tosorted": "^1.1.1", - "doctrine": "^2.1.0", - "es-iterator-helpers": "^1.0.12", - "estraverse": "^5.3.0", - "jsx-ast-utils": "^2.4.1 || ^3.0.0", - "minimatch": "^3.1.2", - "object.entries": "^1.1.6", - "object.fromentries": "^2.0.6", - "object.hasown": "^1.1.2", - "object.values": "^1.1.6", - "prop-types": "^15.8.1", - "resolve": "^2.0.0-next.4", - "semver": "^6.3.1", - "string.prototype.matchall": "^4.0.8" - }, - "engines": { - "node": ">=4" - }, - "peerDependencies": { - "eslint": "^3 || ^4 || ^5 || ^6 || ^7 || ^8" - } - }, - "node_modules/eslint-plugin-react/node_modules/doctrine": { - "version": "2.1.0", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "esutils": "^2.0.2" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/eslint-plugin-react/node_modules/resolve": { - "version": 
"2.0.0-next.5", - "dev": true, - "license": "MIT", - "dependencies": { - "is-core-module": "^2.13.0", - "path-parse": "^1.0.7", - "supports-preserve-symlinks-flag": "^1.0.0" - }, - "bin": { - "resolve": "bin/resolve" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/eslint-plugin-react/node_modules/semver": { - "version": "6.3.1", - "dev": true, - "license": "ISC", - "bin": { - "semver": "bin/semver.js" - } - }, - "node_modules/eslint-scope": { - "version": "7.2.2", - "dev": true, - "license": "BSD-2-Clause", - "dependencies": { - "esrecurse": "^4.3.0", - "estraverse": "^5.2.0" - }, - "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" - }, - "funding": { - "url": "https://opencollective.com/eslint" - } - }, - "node_modules/eslint-visitor-keys": { - "version": "3.4.3", - "dev": true, - "license": "Apache-2.0", - "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" - }, - "funding": { - "url": "https://opencollective.com/eslint" - } - }, - "node_modules/espree": { - "version": "9.6.1", - "dev": true, - "license": "BSD-2-Clause", - "dependencies": { - "acorn": "^8.9.0", - "acorn-jsx": "^5.3.2", - "eslint-visitor-keys": "^3.4.1" - }, - "engines": { - "node": "^12.22.0 || ^14.17.0 || >=16.0.0" - }, - "funding": { - "url": "https://opencollective.com/eslint" - } - }, - "node_modules/esprima": { - "version": "4.0.1", - "dev": true, - "license": "BSD-2-Clause", - "bin": { - "esparse": "bin/esparse.js", - "esvalidate": "bin/esvalidate.js" - }, - "engines": { - "node": ">=4" - } - }, - "node_modules/esquery": { - "version": "1.5.0", - "dev": true, - "license": "BSD-3-Clause", - "dependencies": { - "estraverse": "^5.1.0" - }, - "engines": { - "node": ">=0.10" - } - }, - "node_modules/esrecurse": { - "version": "4.3.0", - "dev": true, - "license": "BSD-2-Clause", - "dependencies": { - "estraverse": "^5.2.0" - }, - "engines": { - "node": ">=4.0" - } - }, - "node_modules/estraverse": { - "version": "5.3.0", - "dev": true, - 
"license": "BSD-2-Clause", - "engines": { - "node": ">=4.0" - } - }, - "node_modules/esutils": { - "version": "2.0.3", - "dev": true, - "license": "BSD-2-Clause", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/event-target-shim": { - "version": "5.0.1", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=6" - } - }, - "node_modules/execa": { - "version": "5.1.1", - "dev": true, - "license": "MIT", - "dependencies": { - "cross-spawn": "^7.0.3", - "get-stream": "^6.0.0", - "human-signals": "^2.1.0", - "is-stream": "^2.0.0", - "merge-stream": "^2.0.0", - "npm-run-path": "^4.0.1", - "onetime": "^5.1.2", - "signal-exit": "^3.0.3", - "strip-final-newline": "^2.0.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sindresorhus/execa?sponsor=1" - } - }, - "node_modules/fast-deep-equal": { - "version": "3.1.3", - "license": "MIT" - }, - "node_modules/fast-glob": { - "version": "3.3.2", - "license": "MIT", - "dependencies": { - "@nodelib/fs.stat": "^2.0.2", - "@nodelib/fs.walk": "^1.2.3", - "glob-parent": "^5.1.2", - "merge2": "^1.3.0", - "micromatch": "^4.0.4" - }, - "engines": { - "node": ">=8.6.0" - } - }, - "node_modules/fast-glob/node_modules/glob-parent": { - "version": "5.1.2", - "license": "ISC", - "dependencies": { - "is-glob": "^4.0.1" - }, - "engines": { - "node": ">= 6" - } - }, - "node_modules/fast-json-stable-stringify": { - "version": "2.1.0", - "dev": true, - "license": "MIT" - }, - "node_modules/fast-levenshtein": { - "version": "2.0.6", - "dev": true, - "license": "MIT" - }, - "node_modules/fast-memoize": { - "version": "2.5.2", - "dev": true, - "license": "MIT" - }, - "node_modules/fast-safe-stringify": { - "version": "2.1.1", - "dev": true, - "license": "MIT" - }, - "node_modules/fastq": { - "version": "1.15.0", - "license": "ISC", - "dependencies": { - "reusify": "^1.0.4" - } - }, - "node_modules/file-entry-cache": { - "version": "6.0.1", - "dev": true, - "license": "MIT", - "dependencies": { 
- "flat-cache": "^3.0.4" - }, - "engines": { - "node": "^10.12.0 || >=12.0.0" - } - }, - "node_modules/fill-range": { - "version": "7.0.1", - "license": "MIT", - "dependencies": { - "to-regex-range": "^5.0.1" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/find-replace": { - "version": "3.0.0", - "license": "MIT", - "dependencies": { - "array-back": "^3.0.1" - }, - "engines": { - "node": ">=4.0.0" - } - }, - "node_modules/find-up": { - "version": "5.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "locate-path": "^6.0.0", - "path-exists": "^4.0.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/flat-cache": { - "version": "3.2.0", - "dev": true, - "license": "MIT", - "dependencies": { - "flatted": "^3.2.9", - "keyv": "^4.5.3", - "rimraf": "^3.0.2" - }, - "engines": { - "node": "^10.12.0 || >=12.0.0" - } - }, - "node_modules/flatbuffers": { - "version": "23.5.26", - "license": "SEE LICENSE IN LICENSE" - }, - "node_modules/flatted": { - "version": "3.2.9", - "dev": true, - "license": "ISC" - }, - "node_modules/for-each": { - "version": "0.3.3", - "dev": true, - "license": "MIT", - "dependencies": { - "is-callable": "^1.1.3" - } - }, - "node_modules/form-data": { - "version": "4.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "asynckit": "^0.4.0", - "combined-stream": "^1.0.8", - "mime-types": "^2.1.12" - }, - "engines": { - "node": ">= 6" - } - }, - "node_modules/fraction.js": { - "version": "4.3.7", - "dev": true, - "license": "MIT", - "engines": { - "node": "*" - }, - "funding": { - "type": "patreon", - "url": "https://github.com/sponsors/rawify" - } - }, - "node_modules/fs-extra": { - "version": "11.2.0", - "dev": true, - "license": "MIT", - "dependencies": { - "graceful-fs": "^4.2.0", - "jsonfile": "^6.0.1", - "universalify": "^2.0.0" - }, - "engines": { - "node": ">=14.14" - } - }, - "node_modules/fs.realpath": { - "version": "1.0.0", - 
"license": "ISC" - }, - "node_modules/function-bind": { - "version": "1.1.2", - "license": "MIT", - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/function.prototype.name": { - "version": "1.1.6", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "define-properties": "^1.2.0", - "es-abstract": "^1.22.1", - "functions-have-names": "^1.2.3" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/functions-have-names": { - "version": "1.2.3", - "dev": true, - "license": "MIT", - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/get-caller-file": { - "version": "2.0.5", - "dev": true, - "license": "ISC", - "engines": { - "node": "6.* || 8.* || >= 10.*" - } - }, - "node_modules/get-func-name": { - "version": "2.0.2", - "dev": true, - "license": "MIT", - "engines": { - "node": "*" - } - }, - "node_modules/get-intrinsic": { - "version": "1.2.2", - "dev": true, - "license": "MIT", - "dependencies": { - "function-bind": "^1.1.2", - "has-proto": "^1.0.1", - "has-symbols": "^1.0.3", - "hasown": "^2.0.0" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/get-nonce": { - "version": "1.0.1", - "license": "MIT", - "engines": { - "node": ">=6" - } - }, - "node_modules/get-stream": { - "version": "6.0.1", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/get-symbol-description": { - "version": "1.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "get-intrinsic": "^1.1.1" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/get-tsconfig": { - "version": "4.7.2", - "dev": true, - "license": "MIT", - "dependencies": { - "resolve-pkg-maps": 
"^1.0.0" - }, - "funding": { - "url": "https://github.com/privatenumber/get-tsconfig?sponsor=1" - } - }, - "node_modules/glob": { - "version": "7.2.3", - "dev": true, - "license": "ISC", - "dependencies": { - "fs.realpath": "^1.0.0", - "inflight": "^1.0.4", - "inherits": "2", - "minimatch": "^3.1.1", - "once": "^1.3.0", - "path-is-absolute": "^1.0.0" - }, - "engines": { - "node": "*" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, - "node_modules/glob-parent": { - "version": "6.0.2", - "license": "ISC", - "dependencies": { - "is-glob": "^4.0.3" - }, - "engines": { - "node": ">=10.13.0" - } - }, - "node_modules/globals": { - "version": "13.23.0", - "dev": true, - "license": "MIT", - "dependencies": { - "type-fest": "^0.20.2" - }, - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/globalthis": { - "version": "1.0.3", - "dev": true, - "license": "MIT", - "dependencies": { - "define-properties": "^1.1.3" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/globby": { - "version": "11.1.0", - "dev": true, - "license": "MIT", - "dependencies": { - "array-union": "^2.1.0", - "dir-glob": "^3.0.1", - "fast-glob": "^3.2.9", - "ignore": "^5.2.0", - "merge2": "^1.4.1", - "slash": "^3.0.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/gopd": { - "version": "1.0.1", - "dev": true, - "license": "MIT", - "dependencies": { - "get-intrinsic": "^1.1.3" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/graceful-fs": { - "version": "4.2.11", - "dev": true, - "license": "ISC" - }, - "node_modules/graphemer": { - "version": "1.4.0", - "dev": true, - "license": "MIT" - }, - "node_modules/has-bigints": { - "version": "1.0.2", - "dev": true, - "license": "MIT", - "funding": { - "url": 
"https://github.com/sponsors/ljharb" - } - }, - "node_modules/has-flag": { - "version": "4.0.0", - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/has-property-descriptors": { - "version": "1.0.1", - "dev": true, - "license": "MIT", - "dependencies": { - "get-intrinsic": "^1.2.2" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/has-proto": { - "version": "1.0.1", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/has-symbols": { - "version": "1.0.3", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/has-tostringtag": { - "version": "1.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "has-symbols": "^1.0.2" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/hasown": { - "version": "2.0.0", - "license": "MIT", - "dependencies": { - "function-bind": "^1.1.2" - }, - "engines": { - "node": ">= 0.4" - } - }, - "node_modules/hoist-non-react-statics": { - "version": "3.3.2", - "license": "BSD-3-Clause", - "dependencies": { - "react-is": "^16.7.0" - } - }, - "node_modules/hoist-non-react-statics/node_modules/react-is": { - "version": "16.13.1", - "license": "MIT" - }, - "node_modules/html-encoding-sniffer": { - "version": "3.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "whatwg-encoding": "^2.0.0" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/http-proxy-agent": { - "version": "5.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "@tootallnate/once": "2", - "agent-base": "6", - "debug": "4" - }, - "engines": { - "node": ">= 6" - } - }, - "node_modules/http2-client": { - "version": "1.3.5", - "dev": true, - "license": "MIT" - }, - 
"node_modules/https-proxy-agent": { - "version": "5.0.1", - "dev": true, - "license": "MIT", - "dependencies": { - "agent-base": "6", - "debug": "4" - }, - "engines": { - "node": ">= 6" - } - }, - "node_modules/human-signals": { - "version": "2.1.0", - "dev": true, - "license": "Apache-2.0", - "engines": { - "node": ">=10.17.0" - } - }, - "node_modules/iconv-lite": { - "version": "0.6.3", - "dev": true, - "license": "MIT", - "dependencies": { - "safer-buffer": ">= 2.1.2 < 3.0.0" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/ignore": { - "version": "5.3.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 4" - } - }, - "node_modules/immer": { - "version": "9.0.21", - "devOptional": true, - "license": "MIT", - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/immer" - } - }, - "node_modules/import-fresh": { - "version": "3.3.0", - "dev": true, - "license": "MIT", - "dependencies": { - "parent-module": "^1.0.0", - "resolve-from": "^4.0.0" - }, - "engines": { - "node": ">=6" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/imurmurhash": { - "version": "0.1.4", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=0.8.19" - } - }, - "node_modules/indent-string": { - "version": "4.0.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/inflight": { - "version": "1.0.6", - "license": "ISC", - "dependencies": { - "once": "^1.3.0", - "wrappy": "1" - } - }, - "node_modules/inherits": { - "version": "2.0.4", - "license": "ISC" - }, - "node_modules/internal-slot": { - "version": "1.0.6", - "dev": true, - "license": "MIT", - "dependencies": { - "get-intrinsic": "^1.2.2", - "hasown": "^2.0.0", - "side-channel": "^1.0.4" - }, - "engines": { - "node": ">= 0.4" - } - }, - "node_modules/invariant": { - "version": "2.2.4", - "license": "MIT", - "dependencies": { - "loose-envify": "^1.0.0" - } - }, - 
"node_modules/is-arguments": { - "version": "1.1.1", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "has-tostringtag": "^1.0.0" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-array-buffer": { - "version": "3.0.2", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "get-intrinsic": "^1.2.0", - "is-typed-array": "^1.1.10" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-async-function": { - "version": "2.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "has-tostringtag": "^1.0.0" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-bigint": { - "version": "1.0.4", - "dev": true, - "license": "MIT", - "dependencies": { - "has-bigints": "^1.0.1" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-binary-path": { - "version": "2.1.0", - "license": "MIT", - "dependencies": { - "binary-extensions": "^2.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/is-boolean-object": { - "version": "1.1.2", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "has-tostringtag": "^1.0.0" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-builtin-module": { - "version": "3.2.1", - "dev": true, - "license": "MIT", - "dependencies": { - "builtin-modules": "^3.3.0" - }, - "engines": { - "node": ">=6" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/is-callable": { - "version": "1.2.7", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-core-module": { - "version": "2.13.1", - 
"license": "MIT", - "dependencies": { - "hasown": "^2.0.0" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-date-object": { - "version": "1.0.5", - "dev": true, - "license": "MIT", - "dependencies": { - "has-tostringtag": "^1.0.0" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-extglob": { - "version": "2.1.1", - "license": "MIT", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/is-finalizationregistry": { - "version": "1.0.2", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-fullwidth-code-point": { - "version": "3.0.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/is-generator-function": { - "version": "1.0.10", - "dev": true, - "license": "MIT", - "dependencies": { - "has-tostringtag": "^1.0.0" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-glob": { - "version": "4.0.3", - "license": "MIT", - "dependencies": { - "is-extglob": "^2.1.1" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/is-map": { - "version": "2.0.2", - "dev": true, - "license": "MIT", - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-negative-zero": { - "version": "2.0.2", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-number": { - "version": "7.0.0", - "license": "MIT", - "engines": { - "node": ">=0.12.0" - } - }, - "node_modules/is-number-object": { - "version": "1.0.7", - "dev": true, - "license": "MIT", - "dependencies": { - "has-tostringtag": "^1.0.0" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": 
"https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-path-inside": { - "version": "3.0.3", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/is-potential-custom-element-name": { - "version": "1.0.1", - "dev": true, - "license": "MIT" - }, - "node_modules/is-regex": { - "version": "1.1.4", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "has-tostringtag": "^1.0.0" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-set": { - "version": "2.0.2", - "dev": true, - "license": "MIT", - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-shared-array-buffer": { - "version": "1.0.2", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-stream": { - "version": "2.0.1", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/is-string": { - "version": "1.0.7", - "dev": true, - "license": "MIT", - "dependencies": { - "has-tostringtag": "^1.0.0" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-symbol": { - "version": "1.0.4", - "dev": true, - "license": "MIT", - "dependencies": { - "has-symbols": "^1.0.2" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-typed-array": { - "version": "1.1.12", - "dev": true, - "license": "MIT", - "dependencies": { - "which-typed-array": "^1.1.11" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-weakmap": { - "version": "2.0.1", - "dev": true, - "license": "MIT", - "funding": { 
- "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-weakref": { - "version": "1.0.2", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/is-weakset": { - "version": "2.0.2", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "get-intrinsic": "^1.1.1" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/isarray": { - "version": "2.0.5", - "dev": true, - "license": "MIT" - }, - "node_modules/isexe": { - "version": "2.0.0", - "dev": true, - "license": "ISC" - }, - "node_modules/iterator.prototype": { - "version": "1.1.2", - "dev": true, - "license": "MIT", - "dependencies": { - "define-properties": "^1.2.1", - "get-intrinsic": "^1.2.1", - "has-symbols": "^1.0.3", - "reflect.getprototypeof": "^1.0.4", - "set-function-name": "^2.0.1" - } - }, - "node_modules/jiti": { - "version": "1.21.0", - "license": "MIT", - "bin": { - "jiti": "bin/jiti.js" - } - }, - "node_modules/js-tokens": { - "version": "4.0.0", - "license": "MIT" - }, - "node_modules/js-yaml": { - "version": "4.1.0", - "dev": true, - "license": "MIT", - "dependencies": { - "argparse": "^2.0.1" - }, - "bin": { - "js-yaml": "bin/js-yaml.js" - } - }, - "node_modules/jsdom": { - "version": "22.1.0", - "dev": true, - "license": "MIT", - "dependencies": { - "abab": "^2.0.6", - "cssstyle": "^3.0.0", - "data-urls": "^4.0.0", - "decimal.js": "^10.4.3", - "domexception": "^4.0.0", - "form-data": "^4.0.0", - "html-encoding-sniffer": "^3.0.0", - "http-proxy-agent": "^5.0.0", - "https-proxy-agent": "^5.0.1", - "is-potential-custom-element-name": "^1.0.1", - "nwsapi": "^2.2.4", - "parse5": "^7.1.2", - "rrweb-cssom": "^0.6.0", - "saxes": "^6.0.0", - "symbol-tree": "^3.2.4", - "tough-cookie": "^4.1.2", - "w3c-xmlserializer": "^4.0.0", - "webidl-conversions": "^7.0.0", - "whatwg-encoding": "^2.0.0", - "whatwg-mimetype": 
"^3.0.0", - "whatwg-url": "^12.0.1", - "ws": "^8.13.0", - "xml-name-validator": "^4.0.0" - }, - "engines": { - "node": ">=16" - }, - "peerDependencies": { - "canvas": "^2.5.0" - }, - "peerDependenciesMeta": { - "canvas": { - "optional": true - } - } - }, - "node_modules/jsep": { - "version": "1.3.8", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 10.16.0" - } - }, - "node_modules/json-bignum": { - "version": "0.0.3", - "engines": { - "node": ">=0.8" - } - }, - "node_modules/json-buffer": { - "version": "3.0.1", - "dev": true, - "license": "MIT" - }, - "node_modules/json-schema-traverse": { - "version": "0.4.1", - "dev": true, - "license": "MIT" - }, - "node_modules/json-stable-stringify-without-jsonify": { - "version": "1.0.1", - "dev": true, - "license": "MIT" - }, - "node_modules/json5": { - "version": "1.0.2", - "dev": true, - "license": "MIT", - "dependencies": { - "minimist": "^1.2.0" - }, - "bin": { - "json5": "lib/cli.js" - } - }, - "node_modules/jsonc-parser": { - "version": "2.2.1", - "dev": true, - "license": "MIT" - }, - "node_modules/jsonfile": { - "version": "6.1.0", - "dev": true, - "license": "MIT", - "dependencies": { - "universalify": "^2.0.0" - }, - "optionalDependencies": { - "graceful-fs": "^4.1.6" - } - }, - "node_modules/jsonpath-plus": { - "version": "7.1.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=12.0.0" - } - }, - "node_modules/jsonpointer": { - "version": "5.0.1", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/jsx-ast-utils": { - "version": "3.3.5", - "dev": true, - "license": "MIT", - "dependencies": { - "array-includes": "^3.1.6", - "array.prototype.flat": "^1.3.1", - "object.assign": "^4.1.4", - "object.values": "^1.1.6" - }, - "engines": { - "node": ">=4.0" - } - }, - "node_modules/keyv": { - "version": "4.5.4", - "dev": true, - "license": "MIT", - "dependencies": { - "json-buffer": "3.0.1" - } - }, - "node_modules/leven": { - "version": "3.1.0", 
- "dev": true, - "license": "MIT", - "engines": { - "node": ">=6" - } - }, - "node_modules/levn": { - "version": "0.4.1", - "dev": true, - "license": "MIT", - "dependencies": { - "prelude-ls": "^1.2.1", - "type-check": "~0.4.0" - }, - "engines": { - "node": ">= 0.8.0" - } - }, - "node_modules/lilconfig": { - "version": "2.1.0", - "license": "MIT", - "engines": { - "node": ">=10" - } - }, - "node_modules/lines-and-columns": { - "version": "1.2.4", - "license": "MIT" - }, - "node_modules/local-pkg": { - "version": "0.4.3", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=14" - }, - "funding": { - "url": "https://github.com/sponsors/antfu" - } - }, - "node_modules/locate-path": { - "version": "6.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "p-locate": "^5.0.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/lodash": { - "version": "4.17.21", - "dev": true, - "license": "MIT" - }, - "node_modules/lodash.assignwith": { - "version": "4.2.0", - "license": "MIT" - }, - "node_modules/lodash.camelcase": { - "version": "4.3.0", - "license": "MIT" - }, - "node_modules/lodash.get": { - "version": "4.4.2", - "dev": true, - "license": "MIT" - }, - "node_modules/lodash.isempty": { - "version": "4.4.0", - "dev": true, - "license": "MIT" - }, - "node_modules/lodash.merge": { - "version": "4.6.2", - "dev": true, - "license": "MIT" - }, - "node_modules/lodash.omit": { - "version": "4.5.0", - "dev": true, - "license": "MIT" - }, - "node_modules/lodash.omitby": { - "version": "4.6.0", - "dev": true, - "license": "MIT" - }, - "node_modules/lodash.topath": { - "version": "4.5.2", - "dev": true, - "license": "MIT" - }, - "node_modules/lodash.uniq": { - "version": "4.5.0", - "dev": true, - "license": "MIT" - }, - "node_modules/lodash.uniqby": { - "version": "4.7.0", - "dev": true, - "license": "MIT" - }, - "node_modules/lodash.uniqwith": { - "version": "4.5.0", - "dev": 
true, - "license": "MIT" - }, - "node_modules/loglevel": { - "version": "1.8.1", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 0.6.0" - }, - "funding": { - "type": "tidelift", - "url": "https://tidelift.com/funding/github/npm/loglevel" - } - }, - "node_modules/loglevel-plugin-prefix": { - "version": "0.8.4", - "dev": true, - "license": "MIT" - }, - "node_modules/loose-envify": { - "version": "1.4.0", - "license": "MIT", - "dependencies": { - "js-tokens": "^3.0.0 || ^4.0.0" - }, - "bin": { - "loose-envify": "cli.js" - } - }, - "node_modules/loupe": { - "version": "2.3.7", - "dev": true, - "license": "MIT", - "dependencies": { - "get-func-name": "^2.0.1" - } - }, - "node_modules/lru-cache": { - "version": "6.0.0", - "dev": true, - "license": "ISC", - "dependencies": { - "yallist": "^4.0.0" - }, - "engines": { - "node": ">=10" - } - }, - "node_modules/lz-string": { - "version": "1.5.0", - "dev": true, - "license": "MIT", - "bin": { - "lz-string": "bin/bin.js" - } - }, - "node_modules/magic-string": { - "version": "0.30.5", - "dev": true, - "license": "MIT", - "dependencies": { - "@jridgewell/sourcemap-codec": "^1.4.15" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/merge-stream": { - "version": "2.0.0", - "dev": true, - "license": "MIT" - }, - "node_modules/merge2": { - "version": "1.4.1", - "license": "MIT", - "engines": { - "node": ">= 8" - } - }, - "node_modules/micromatch": { - "version": "4.0.5", - "license": "MIT", - "dependencies": { - "braces": "^3.0.2", - "picomatch": "^2.3.1" - }, - "engines": { - "node": ">=8.6" - } - }, - "node_modules/mime-db": { - "version": "1.52.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 0.6" - } - }, - "node_modules/mime-types": { - "version": "2.1.35", - "dev": true, - "license": "MIT", - "dependencies": { - "mime-db": "1.52.0" - }, - "engines": { - "node": ">= 0.6" - } - }, - "node_modules/mimic-fn": { - "version": "2.1.0", - "dev": true, - "license": "MIT", - "engines": { - 
"node": ">=6" - } - }, - "node_modules/min-indent": { - "version": "1.0.1", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=4" - } - }, - "node_modules/minimatch": { - "version": "3.1.2", - "license": "ISC", - "dependencies": { - "brace-expansion": "^1.1.7" - }, - "engines": { - "node": "*" - } - }, - "node_modules/minimist": { - "version": "1.2.8", - "dev": true, - "license": "MIT", - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/mlly": { - "version": "1.4.2", - "dev": true, - "license": "MIT", - "dependencies": { - "acorn": "^8.10.0", - "pathe": "^1.1.1", - "pkg-types": "^1.0.3", - "ufo": "^1.3.0" - } - }, - "node_modules/ms": { - "version": "2.1.2", - "dev": true, - "license": "MIT" - }, - "node_modules/mz": { - "version": "2.7.0", - "license": "MIT", - "dependencies": { - "any-promise": "^1.0.0", - "object-assign": "^4.0.1", - "thenify-all": "^1.0.0" - } - }, - "node_modules/nanoid": { - "version": "3.3.7", - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/ai" - } - ], - "license": "MIT", - "bin": { - "nanoid": "bin/nanoid.cjs" - }, - "engines": { - "node": "^10 || ^12 || ^13.7 || ^14 || >=15.0.1" - } - }, - "node_modules/natural-compare": { - "version": "1.4.0", - "dev": true, - "license": "MIT" - }, - "node_modules/nimma": { - "version": "0.2.2", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "@jsep-plugin/regex": "^1.0.1", - "@jsep-plugin/ternary": "^1.0.2", - "astring": "^1.8.1", - "jsep": "^1.2.0" - }, - "engines": { - "node": "^12.20 || >=14.13" - }, - "optionalDependencies": { - "jsonpath-plus": "^6.0.1", - "lodash.topath": "^4.5.2" - } - }, - "node_modules/nimma/node_modules/jsonpath-plus": { - "version": "6.0.1", - "dev": true, - "license": "MIT", - "optional": true, - "engines": { - "node": ">=10.0.0" - } - }, - "node_modules/node-fetch": { - "version": "2.7.0", - "dev": true, - "license": "MIT", - "dependencies": { - "whatwg-url": "^5.0.0" - }, - 
"engines": { - "node": "4.x || >=6.0.0" - }, - "peerDependencies": { - "encoding": "^0.1.0" - }, - "peerDependenciesMeta": { - "encoding": { - "optional": true - } - } - }, - "node_modules/node-fetch-h2": { - "version": "2.3.0", - "dev": true, - "license": "MIT", - "dependencies": { - "http2-client": "^1.2.5" - }, - "engines": { - "node": "4.x || >=6.0.0" - } - }, - "node_modules/node-fetch/node_modules/tr46": { - "version": "0.0.3", - "dev": true, - "license": "MIT" - }, - "node_modules/node-fetch/node_modules/webidl-conversions": { - "version": "3.0.1", - "dev": true, - "license": "BSD-2-Clause" - }, - "node_modules/node-fetch/node_modules/whatwg-url": { - "version": "5.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "tr46": "~0.0.3", - "webidl-conversions": "^3.0.0" - } - }, - "node_modules/node-readfiles": { - "version": "0.2.0", - "dev": true, - "license": "MIT", - "dependencies": { - "es6-promise": "^3.2.1" - } - }, - "node_modules/node-releases": { - "version": "2.0.14", - "dev": true, - "license": "MIT" - }, - "node_modules/normalize-path": { - "version": "3.0.0", - "license": "MIT", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/normalize-range": { - "version": "0.1.2", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/npm-run-path": { - "version": "4.0.1", - "dev": true, - "license": "MIT", - "dependencies": { - "path-key": "^3.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/nwsapi": { - "version": "2.2.7", - "dev": true, - "license": "MIT" - }, - "node_modules/oas-kit-common": { - "version": "1.0.8", - "dev": true, - "license": "BSD-3-Clause", - "dependencies": { - "fast-safe-stringify": "^2.0.7" - } - }, - "node_modules/oas-linter": { - "version": "3.2.2", - "dev": true, - "license": "BSD-3-Clause", - "dependencies": { - "@exodus/schemasafe": "^1.0.0-rc.2", - "should": "^13.2.1", - "yaml": "^1.10.0" - }, - "funding": { - "url": 
"https://github.com/Mermade/oas-kit?sponsor=1" - } - }, - "node_modules/oas-linter/node_modules/yaml": { - "version": "1.10.2", - "dev": true, - "license": "ISC", - "engines": { - "node": ">= 6" - } - }, - "node_modules/oas-resolver": { - "version": "2.5.6", - "dev": true, - "license": "BSD-3-Clause", - "dependencies": { - "node-fetch-h2": "^2.3.0", - "oas-kit-common": "^1.0.8", - "reftools": "^1.1.9", - "yaml": "^1.10.0", - "yargs": "^17.0.1" - }, - "bin": { - "resolve": "resolve.js" - }, - "funding": { - "url": "https://github.com/Mermade/oas-kit?sponsor=1" - } - }, - "node_modules/oas-resolver/node_modules/yaml": { - "version": "1.10.2", - "dev": true, - "license": "ISC", - "engines": { - "node": ">= 6" - } - }, - "node_modules/oas-schema-walker": { - "version": "1.1.5", - "dev": true, - "license": "BSD-3-Clause", - "funding": { - "url": "https://github.com/Mermade/oas-kit?sponsor=1" - } - }, - "node_modules/oas-validator": { - "version": "5.0.8", - "dev": true, - "license": "BSD-3-Clause", - "dependencies": { - "call-me-maybe": "^1.0.1", - "oas-kit-common": "^1.0.8", - "oas-linter": "^3.2.2", - "oas-resolver": "^2.5.6", - "oas-schema-walker": "^1.1.5", - "reftools": "^1.1.9", - "should": "^13.2.1", - "yaml": "^1.10.0" - }, - "funding": { - "url": "https://github.com/Mermade/oas-kit?sponsor=1" - } - }, - "node_modules/oas-validator/node_modules/yaml": { - "version": "1.10.2", - "dev": true, - "license": "ISC", - "engines": { - "node": ">= 6" - } - }, - "node_modules/object-assign": { - "version": "4.1.1", - "license": "MIT", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/object-hash": { - "version": "3.0.0", - "license": "MIT", - "engines": { - "node": ">= 6" - } - }, - "node_modules/object-inspect": { - "version": "1.13.1", - "dev": true, - "license": "MIT", - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/object-is": { - "version": "1.1.5", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": 
"^1.0.2", - "define-properties": "^1.1.3" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/object-keys": { - "version": "1.1.1", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 0.4" - } - }, - "node_modules/object.assign": { - "version": "4.1.5", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.5", - "define-properties": "^1.2.1", - "has-symbols": "^1.0.3", - "object-keys": "^1.1.1" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/object.entries": { - "version": "1.1.7", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "define-properties": "^1.2.0", - "es-abstract": "^1.22.1" - }, - "engines": { - "node": ">= 0.4" - } - }, - "node_modules/object.fromentries": { - "version": "2.0.7", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "define-properties": "^1.2.0", - "es-abstract": "^1.22.1" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/object.groupby": { - "version": "1.0.1", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "define-properties": "^1.2.0", - "es-abstract": "^1.22.1", - "get-intrinsic": "^1.2.1" - } - }, - "node_modules/object.hasown": { - "version": "1.1.3", - "dev": true, - "license": "MIT", - "dependencies": { - "define-properties": "^1.2.0", - "es-abstract": "^1.22.1" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/object.values": { - "version": "1.1.7", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "define-properties": "^1.2.0", - "es-abstract": "^1.22.1" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/once": { 
- "version": "1.4.0", - "license": "ISC", - "dependencies": { - "wrappy": "1" - } - }, - "node_modules/onetime": { - "version": "5.1.2", - "dev": true, - "license": "MIT", - "dependencies": { - "mimic-fn": "^2.1.0" - }, - "engines": { - "node": ">=6" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/openapi-types": { - "version": "12.1.3", - "dev": true, - "license": "MIT" - }, - "node_modules/openapi3-ts": { - "version": "3.2.0", - "dev": true, - "license": "MIT", - "dependencies": { - "yaml": "^2.2.1" - } - }, - "node_modules/optionator": { - "version": "0.9.3", - "dev": true, - "license": "MIT", - "dependencies": { - "@aashutoshrathi/word-wrap": "^1.2.3", - "deep-is": "^0.1.3", - "fast-levenshtein": "^2.0.6", - "levn": "^0.4.1", - "prelude-ls": "^1.2.1", - "type-check": "^0.4.0" - }, - "engines": { - "node": ">= 0.8.0" - } - }, - "node_modules/orval": { - "version": "6.22.1", - "dev": true, - "license": "MIT", - "dependencies": { - "@apidevtools/swagger-parser": "^10.1.0", - "@orval/angular": "6.22.1", - "@orval/axios": "6.22.1", - "@orval/core": "6.22.1", - "@orval/mock": "6.22.1", - "@orval/query": "6.22.1", - "@orval/swr": "6.22.1", - "@orval/zod": "6.22.1", - "ajv": "^8.12.0", - "cac": "^6.7.14", - "chalk": "^4.1.2", - "chokidar": "^3.5.3", - "enquirer": "^2.4.1", - "execa": "^5.1.1", - "find-up": "5.0.0", - "fs-extra": "^11.2.0", - "lodash.uniq": "^4.5.0", - "openapi-types": "^12.1.3", - "openapi3-ts": "^3.2.0", - "string-argv": "^0.3.2", - "tsconfck": "^2.0.1" - }, - "bin": { - "orval": "dist/bin/orval.js" - } - }, - "node_modules/orval/node_modules/ajv": { - "version": "8.12.0", - "dev": true, - "license": "MIT", - "dependencies": { - "fast-deep-equal": "^3.1.1", - "json-schema-traverse": "^1.0.0", - "require-from-string": "^2.0.2", - "uri-js": "^4.2.2" - }, - "funding": { - "type": "github", - "url": "https://github.com/sponsors/epoberezkin" - } - }, - "node_modules/orval/node_modules/json-schema-traverse": 
{ - "version": "1.0.0", - "dev": true, - "license": "MIT" - }, - "node_modules/p-limit": { - "version": "3.1.0", - "dev": true, - "license": "MIT", - "dependencies": { - "yocto-queue": "^0.1.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/p-locate": { - "version": "5.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "p-limit": "^3.0.2" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/pad-left": { - "version": "2.1.0", - "license": "MIT", - "dependencies": { - "repeat-string": "^1.5.4" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/parent-module": { - "version": "1.0.1", - "dev": true, - "license": "MIT", - "dependencies": { - "callsites": "^3.0.0" - }, - "engines": { - "node": ">=6" - } - }, - "node_modules/parse5": { - "version": "7.1.2", - "dev": true, - "license": "MIT", - "dependencies": { - "entities": "^4.4.0" - }, - "funding": { - "url": "https://github.com/inikulin/parse5?sponsor=1" - } - }, - "node_modules/path-exists": { - "version": "4.0.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/path-is-absolute": { - "version": "1.0.1", - "license": "MIT", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/path-key": { - "version": "3.1.1", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/path-parse": { - "version": "1.0.7", - "license": "MIT" - }, - "node_modules/path-type": { - "version": "4.0.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/pathe": { - "version": "1.1.1", - "dev": true, - "license": "MIT" - }, - "node_modules/pathval": { - "version": "1.1.1", - "dev": true, - "license": "MIT", - "engines": { - "node": "*" - } - }, - "node_modules/picocolors": { - "version": "1.0.0", - "license": "ISC" - }, - 
"node_modules/picomatch": { - "version": "2.3.1", - "license": "MIT", - "engines": { - "node": ">=8.6" - }, - "funding": { - "url": "https://github.com/sponsors/jonschlinkert" - } - }, - "node_modules/pify": { - "version": "2.3.0", - "license": "MIT", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/pirates": { - "version": "4.0.6", - "license": "MIT", - "engines": { - "node": ">= 6" - } - }, - "node_modules/pkg-types": { - "version": "1.0.3", - "dev": true, - "license": "MIT", - "dependencies": { - "jsonc-parser": "^3.2.0", - "mlly": "^1.2.0", - "pathe": "^1.1.0" - } - }, - "node_modules/pkg-types/node_modules/jsonc-parser": { - "version": "3.2.0", - "dev": true, - "license": "MIT" - }, - "node_modules/playwright": { - "version": "1.40.1", - "dev": true, - "license": "Apache-2.0", - "dependencies": { - "playwright-core": "1.40.1" - }, - "bin": { - "playwright": "cli.js" - }, - "engines": { - "node": ">=16" - }, - "optionalDependencies": { - "fsevents": "2.3.2" - } - }, - "node_modules/playwright-core": { - "version": "1.40.1", - "dev": true, - "license": "Apache-2.0", - "bin": { - "playwright-core": "cli.js" - }, - "engines": { - "node": ">=16" - } - }, - "node_modules/pluralize": { - "version": "8.0.0", - "license": "MIT", - "engines": { - "node": ">=4" - } - }, - "node_modules/pony-cause": { - "version": "1.1.1", - "dev": true, - "license": "0BSD", - "engines": { - "node": ">=12.0.0" - } - }, - "node_modules/postcss": { - "version": "8.4.32", - "funding": [ - { - "type": "opencollective", - "url": "https://opencollective.com/postcss/" - }, - { - "type": "tidelift", - "url": "https://tidelift.com/funding/github/npm/postcss" - }, - { - "type": "github", - "url": "https://github.com/sponsors/ai" - } - ], - "license": "MIT", - "dependencies": { - "nanoid": "^3.3.7", - "picocolors": "^1.0.0", - "source-map-js": "^1.0.2" - }, - "engines": { - "node": "^10 || ^12 || >=14" - } - }, - "node_modules/postcss-import": { - "version": "15.1.0", - "license": "MIT", 
- "dependencies": { - "postcss-value-parser": "^4.0.0", - "read-cache": "^1.0.0", - "resolve": "^1.1.7" - }, - "engines": { - "node": ">=14.0.0" - }, - "peerDependencies": { - "postcss": "^8.0.0" - } - }, - "node_modules/postcss-js": { - "version": "4.0.1", - "license": "MIT", - "dependencies": { - "camelcase-css": "^2.0.1" - }, - "engines": { - "node": "^12 || ^14 || >= 16" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/postcss/" - }, - "peerDependencies": { - "postcss": "^8.4.21" - } - }, - "node_modules/postcss-load-config": { - "version": "4.0.2", - "funding": [ - { - "type": "opencollective", - "url": "https://opencollective.com/postcss/" - }, - { - "type": "github", - "url": "https://github.com/sponsors/ai" - } - ], - "license": "MIT", - "dependencies": { - "lilconfig": "^3.0.0", - "yaml": "^2.3.4" - }, - "engines": { - "node": ">= 14" - }, - "peerDependencies": { - "postcss": ">=8.0.9", - "ts-node": ">=9.0.0" - }, - "peerDependenciesMeta": { - "postcss": { - "optional": true - }, - "ts-node": { - "optional": true - } - } - }, - "node_modules/postcss-load-config/node_modules/lilconfig": { - "version": "3.0.0", - "license": "MIT", - "engines": { - "node": ">=14" - } - }, - "node_modules/postcss-nested": { - "version": "6.0.1", - "license": "MIT", - "dependencies": { - "postcss-selector-parser": "^6.0.11" - }, - "engines": { - "node": ">=12.0" - }, - "funding": { - "type": "opencollective", - "url": "https://opencollective.com/postcss/" - }, - "peerDependencies": { - "postcss": "^8.2.14" - } - }, - "node_modules/postcss-selector-parser": { - "version": "6.0.13", - "license": "MIT", - "dependencies": { - "cssesc": "^3.0.0", - "util-deprecate": "^1.0.2" - }, - "engines": { - "node": ">=4" - } - }, - "node_modules/postcss-value-parser": { - "version": "4.2.0", - "license": "MIT" - }, - "node_modules/prelude-ls": { - "version": "1.2.1", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 0.8.0" - } - }, - 
"node_modules/prettier": { - "version": "3.1.0", - "dev": true, - "license": "MIT", - "bin": { - "prettier": "bin/prettier.cjs" - }, - "engines": { - "node": ">=14" - }, - "funding": { - "url": "https://github.com/prettier/prettier?sponsor=1" - } - }, - "node_modules/pretty-format": { - "version": "27.5.1", - "dev": true, - "license": "MIT", - "dependencies": { - "ansi-regex": "^5.0.1", - "ansi-styles": "^5.0.0", - "react-is": "^17.0.1" - }, - "engines": { - "node": "^10.13.0 || ^12.13.0 || ^14.15.0 || >=15.0.0" - } - }, - "node_modules/pretty-format/node_modules/ansi-styles": { - "version": "5.2.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/ansi-styles?sponsor=1" - } - }, - "node_modules/prop-types": { - "version": "15.8.1", - "license": "MIT", - "dependencies": { - "loose-envify": "^1.4.0", - "object-assign": "^4.1.1", - "react-is": "^16.13.1" - } - }, - "node_modules/prop-types/node_modules/react-is": { - "version": "16.13.1", - "license": "MIT" - }, - "node_modules/psl": { - "version": "1.9.0", - "dev": true, - "license": "MIT" - }, - "node_modules/punycode": { - "version": "2.3.1", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=6" - } - }, - "node_modules/querystringify": { - "version": "2.2.0", - "dev": true, - "license": "MIT" - }, - "node_modules/queue-microtask": { - "version": "1.2.3", - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/feross" - }, - { - "type": "patreon", - "url": "https://www.patreon.com/feross" - }, - { - "type": "consulting", - "url": "https://feross.org/support" - } - ], - "license": "MIT" - }, - "node_modules/react": { - "version": "18.2.0", - "license": "MIT", - "dependencies": { - "loose-envify": "^1.1.0" - }, - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/react-dnd": { - "version": "16.0.1", - "license": "MIT", - "dependencies": { - "@react-dnd/invariant": "^4.0.1", - "@react-dnd/shallowequal": 
"^4.0.1", - "dnd-core": "^16.0.1", - "fast-deep-equal": "^3.1.3", - "hoist-non-react-statics": "^3.3.2" - }, - "peerDependencies": { - "@types/hoist-non-react-statics": ">= 3.3.1", - "@types/node": ">= 12", - "@types/react": ">= 16", - "react": ">= 16.14" - }, - "peerDependenciesMeta": { - "@types/hoist-non-react-statics": { - "optional": true - }, - "@types/node": { - "optional": true - }, - "@types/react": { - "optional": true - } - } - }, - "node_modules/react-dnd-html5-backend": { - "version": "16.0.1", - "license": "MIT", - "dependencies": { - "dnd-core": "^16.0.1" - } - }, - "node_modules/react-dom": { - "version": "18.2.0", - "license": "MIT", - "dependencies": { - "loose-envify": "^1.1.0", - "scheduler": "^0.23.0" - }, - "peerDependencies": { - "react": "^18.2.0" - } - }, - "node_modules/react-is": { - "version": "17.0.2", - "dev": true, - "license": "MIT" - }, - "node_modules/react-remove-scroll": { - "version": "2.5.5", - "license": "MIT", - "dependencies": { - "react-remove-scroll-bar": "^2.3.3", - "react-style-singleton": "^2.2.1", - "tslib": "^2.1.0", - "use-callback-ref": "^1.3.0", - "use-sidecar": "^1.1.2" - }, - "engines": { - "node": ">=10" - }, - "peerDependencies": { - "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/react-remove-scroll-bar": { - "version": "2.3.4", - "license": "MIT", - "dependencies": { - "react-style-singleton": "^2.2.1", - "tslib": "^2.0.0" - }, - "engines": { - "node": ">=10" - }, - "peerDependencies": { - "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/react-router": { - "version": "6.20.1", - "license": "MIT", - "dependencies": { - "@remix-run/router": "1.13.1" - }, - "engines": { - "node": ">=14.0.0" - }, - "peerDependencies": { - 
"react": ">=16.8" - } - }, - "node_modules/react-router-dom": { - "version": "6.20.1", - "license": "MIT", - "dependencies": { - "@remix-run/router": "1.13.1", - "react-router": "6.20.1" - }, - "engines": { - "node": ">=14.0.0" - }, - "peerDependencies": { - "react": ">=16.8", - "react-dom": ">=16.8" - } - }, - "node_modules/react-split": { - "version": "2.0.14", - "license": "MIT", - "dependencies": { - "prop-types": "^15.5.7", - "split.js": "^1.6.0" - }, - "peerDependencies": { - "react": "*" - } - }, - "node_modules/react-style-singleton": { - "version": "2.2.1", - "license": "MIT", - "dependencies": { - "get-nonce": "^1.0.0", - "invariant": "^2.2.4", - "tslib": "^2.0.0" - }, - "engines": { - "node": ">=10" - }, - "peerDependencies": { - "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/reactflow": { - "version": "11.10.1", - "license": "MIT", - "dependencies": { - "@reactflow/background": "11.3.6", - "@reactflow/controls": "11.2.6", - "@reactflow/core": "11.10.1", - "@reactflow/minimap": "11.7.6", - "@reactflow/node-resizer": "2.2.6", - "@reactflow/node-toolbar": "1.3.6" - }, - "peerDependencies": { - "react": ">=17", - "react-dom": ">=17" - } - }, - "node_modules/read-cache": { - "version": "1.0.0", - "license": "MIT", - "dependencies": { - "pify": "^2.3.0" - } - }, - "node_modules/readdirp": { - "version": "3.6.0", - "license": "MIT", - "dependencies": { - "picomatch": "^2.2.1" - }, - "engines": { - "node": ">=8.10.0" - } - }, - "node_modules/redent": { - "version": "3.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "indent-string": "^4.0.0", - "strip-indent": "^3.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/redux": { - "version": "4.2.1", - "license": "MIT", - "dependencies": { - "@babel/runtime": "^7.9.2" - } - }, - "node_modules/reflect.getprototypeof": { - "version": "1.0.4", - 
"dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "define-properties": "^1.2.0", - "es-abstract": "^1.22.1", - "get-intrinsic": "^1.2.1", - "globalthis": "^1.0.3", - "which-builtin-type": "^1.1.3" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/reftools": { - "version": "1.1.9", - "dev": true, - "license": "BSD-3-Clause", - "funding": { - "url": "https://github.com/Mermade/oas-kit?sponsor=1" - } - }, - "node_modules/regenerator-runtime": { - "version": "0.14.0", - "license": "MIT" - }, - "node_modules/regexp.prototype.flags": { - "version": "1.5.1", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "define-properties": "^1.2.0", - "set-function-name": "^2.0.0" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/repeat-string": { - "version": "1.6.1", - "license": "MIT", - "engines": { - "node": ">=0.10" - } - }, - "node_modules/require-directory": { - "version": "2.1.1", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/require-from-string": { - "version": "2.0.2", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/requires-port": { - "version": "1.0.0", - "dev": true, - "license": "MIT" - }, - "node_modules/resolve": { - "version": "1.22.8", - "license": "MIT", - "dependencies": { - "is-core-module": "^2.13.0", - "path-parse": "^1.0.7", - "supports-preserve-symlinks-flag": "^1.0.0" - }, - "bin": { - "resolve": "bin/resolve" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/resolve-from": { - "version": "4.0.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=4" - } - }, - "node_modules/resolve-pkg-maps": { - "version": "1.0.0", - "dev": true, - "license": "MIT", - "funding": { - "url": 
"https://github.com/privatenumber/resolve-pkg-maps?sponsor=1" - } - }, - "node_modules/reusify": { - "version": "1.0.4", - "license": "MIT", - "engines": { - "iojs": ">=1.0.0", - "node": ">=0.10.0" - } - }, - "node_modules/rimraf": { - "version": "3.0.2", - "dev": true, - "license": "ISC", - "dependencies": { - "glob": "^7.1.3" - }, - "bin": { - "rimraf": "bin.js" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, - "node_modules/rollup": { - "version": "3.29.4", - "dev": true, - "license": "MIT", - "bin": { - "rollup": "dist/bin/rollup" - }, - "engines": { - "node": ">=14.18.0", - "npm": ">=8.0.0" - }, - "optionalDependencies": { - "fsevents": "~2.3.2" - } - }, - "node_modules/rrweb-cssom": { - "version": "0.6.0", - "dev": true, - "license": "MIT" - }, - "node_modules/run-parallel": { - "version": "1.2.0", - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/feross" - }, - { - "type": "patreon", - "url": "https://www.patreon.com/feross" - }, - { - "type": "consulting", - "url": "https://feross.org/support" - } - ], - "license": "MIT", - "dependencies": { - "queue-microtask": "^1.2.2" - } - }, - "node_modules/safe-array-concat": { - "version": "1.0.1", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "get-intrinsic": "^1.2.1", - "has-symbols": "^1.0.3", - "isarray": "^2.0.5" - }, - "engines": { - "node": ">=0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/safe-regex-test": { - "version": "1.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "get-intrinsic": "^1.1.3", - "is-regex": "^1.1.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/safe-stable-stringify": { - "version": "1.1.1", - "dev": true, - "license": "MIT" - }, - "node_modules/safer-buffer": { - "version": "2.1.2", - "dev": true, - "license": "MIT" - }, - "node_modules/saxes": { - "version": "6.0.0", 
- "dev": true, - "license": "ISC", - "dependencies": { - "xmlchars": "^2.2.0" - }, - "engines": { - "node": ">=v12.22.7" - } - }, - "node_modules/scheduler": { - "version": "0.23.0", - "license": "MIT", - "dependencies": { - "loose-envify": "^1.1.0" - } - }, - "node_modules/semver": { - "version": "7.5.4", - "dev": true, - "license": "ISC", - "dependencies": { - "lru-cache": "^6.0.0" - }, - "bin": { - "semver": "bin/semver.js" - }, - "engines": { - "node": ">=10" - } - }, - "node_modules/set-function-length": { - "version": "1.1.1", - "dev": true, - "license": "MIT", - "dependencies": { - "define-data-property": "^1.1.1", - "get-intrinsic": "^1.2.1", - "gopd": "^1.0.1", - "has-property-descriptors": "^1.0.0" - }, - "engines": { - "node": ">= 0.4" - } - }, - "node_modules/set-function-name": { - "version": "2.0.1", - "dev": true, - "license": "MIT", - "dependencies": { - "define-data-property": "^1.0.1", - "functions-have-names": "^1.2.3", - "has-property-descriptors": "^1.0.0" - }, - "engines": { - "node": ">= 0.4" - } - }, - "node_modules/shebang-command": { - "version": "2.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "shebang-regex": "^3.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/shebang-regex": { - "version": "3.0.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/should": { - "version": "13.2.3", - "dev": true, - "license": "MIT", - "dependencies": { - "should-equal": "^2.0.0", - "should-format": "^3.0.3", - "should-type": "^1.4.0", - "should-type-adaptors": "^1.0.1", - "should-util": "^1.0.0" - } - }, - "node_modules/should-equal": { - "version": "2.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "should-type": "^1.4.0" - } - }, - "node_modules/should-format": { - "version": "3.0.3", - "dev": true, - "license": "MIT", - "dependencies": { - "should-type": "^1.3.0", - "should-type-adaptors": "^1.0.1" - } - }, - "node_modules/should-type": { - "version": "1.4.0", - 
"dev": true, - "license": "MIT" - }, - "node_modules/should-type-adaptors": { - "version": "1.1.0", - "dev": true, - "license": "MIT", - "dependencies": { - "should-type": "^1.3.0", - "should-util": "^1.0.0" - } - }, - "node_modules/should-util": { - "version": "1.0.1", - "dev": true, - "license": "MIT" - }, - "node_modules/side-channel": { - "version": "1.0.4", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.0", - "get-intrinsic": "^1.0.2", - "object-inspect": "^1.9.0" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/siginfo": { - "version": "2.0.0", - "dev": true, - "license": "ISC" - }, - "node_modules/signal-exit": { - "version": "3.0.7", - "dev": true, - "license": "ISC" - }, - "node_modules/simple-eval": { - "version": "1.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "jsep": "^1.1.2" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/slash": { - "version": "3.0.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/source-map-js": { - "version": "1.0.2", - "license": "BSD-3-Clause", - "engines": { - "node": ">=0.10.0" - } - }, - "node_modules/split.js": { - "version": "1.6.5", - "license": "MIT" - }, - "node_modules/sprintf-js": { - "version": "1.0.3", - "dev": true, - "license": "BSD-3-Clause" - }, - "node_modules/stackback": { - "version": "0.0.2", - "dev": true, - "license": "MIT" - }, - "node_modules/std-env": { - "version": "3.6.0", - "dev": true, - "license": "MIT" - }, - "node_modules/stop-iteration-iterator": { - "version": "1.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "internal-slot": "^1.0.4" - }, - "engines": { - "node": ">= 0.4" - } - }, - "node_modules/stream-read-all": { - "version": "3.0.1", - "license": "MIT", - "engines": { - "node": ">=10" - } - }, - "node_modules/string-argv": { - "version": "0.3.2", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=0.6.19" - } - }, - 
"node_modules/string-width": { - "version": "4.2.3", - "dev": true, - "license": "MIT", - "dependencies": { - "emoji-regex": "^8.0.0", - "is-fullwidth-code-point": "^3.0.0", - "strip-ansi": "^6.0.1" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/string.prototype.matchall": { - "version": "4.0.10", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "define-properties": "^1.2.0", - "es-abstract": "^1.22.1", - "get-intrinsic": "^1.2.1", - "has-symbols": "^1.0.3", - "internal-slot": "^1.0.5", - "regexp.prototype.flags": "^1.5.0", - "set-function-name": "^2.0.0", - "side-channel": "^1.0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/string.prototype.trim": { - "version": "1.2.8", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "define-properties": "^1.2.0", - "es-abstract": "^1.22.1" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/string.prototype.trimend": { - "version": "1.0.7", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "define-properties": "^1.2.0", - "es-abstract": "^1.22.1" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/string.prototype.trimstart": { - "version": "1.0.7", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "define-properties": "^1.2.0", - "es-abstract": "^1.22.1" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/strip-ansi": { - "version": "6.0.1", - "dev": true, - "license": "MIT", - "dependencies": { - "ansi-regex": "^5.0.1" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/strip-bom": { - "version": "3.0.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=4" - } - }, - "node_modules/strip-final-newline": { - "version": "2.0.0", - "dev": true, - "license": "MIT", - 
"engines": { - "node": ">=6" - } - }, - "node_modules/strip-indent": { - "version": "3.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "min-indent": "^1.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/strip-json-comments": { - "version": "3.1.1", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=8" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/strip-literal": { - "version": "1.3.0", - "dev": true, - "license": "MIT", - "dependencies": { - "acorn": "^8.10.0" - }, - "funding": { - "url": "https://github.com/sponsors/antfu" - } - }, - "node_modules/style-mod": { - "version": "4.1.0", - "license": "MIT" - }, - "node_modules/sucrase": { - "version": "3.34.0", - "license": "MIT", - "dependencies": { - "@jridgewell/gen-mapping": "^0.3.2", - "commander": "^4.0.0", - "glob": "7.1.6", - "lines-and-columns": "^1.1.6", - "mz": "^2.7.0", - "pirates": "^4.0.1", - "ts-interface-checker": "^0.1.9" - }, - "bin": { - "sucrase": "bin/sucrase", - "sucrase-node": "bin/sucrase-node" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/sucrase/node_modules/glob": { - "version": "7.1.6", - "license": "ISC", - "dependencies": { - "fs.realpath": "^1.0.0", - "inflight": "^1.0.4", - "inherits": "2", - "minimatch": "^3.0.4", - "once": "^1.3.0", - "path-is-absolute": "^1.0.0" - }, - "engines": { - "node": "*" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, - "node_modules/supports-color": { - "version": "7.2.0", - "license": "MIT", - "dependencies": { - "has-flag": "^4.0.0" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/supports-preserve-symlinks-flag": { - "version": "1.0.0", - "license": "MIT", - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/swagger2openapi": { - "version": "7.0.8", - "dev": true, - "license": "BSD-3-Clause", - "dependencies": { - "call-me-maybe": "^1.0.1", 
- "node-fetch": "^2.6.1", - "node-fetch-h2": "^2.3.0", - "node-readfiles": "^0.2.0", - "oas-kit-common": "^1.0.8", - "oas-resolver": "^2.5.6", - "oas-schema-walker": "^1.1.5", - "oas-validator": "^5.0.8", - "reftools": "^1.1.9", - "yaml": "^1.10.0", - "yargs": "^17.0.1" - }, - "bin": { - "boast": "boast.js", - "oas-validate": "oas-validate.js", - "swagger2openapi": "swagger2openapi.js" - }, - "funding": { - "url": "https://github.com/Mermade/oas-kit?sponsor=1" - } - }, - "node_modules/swagger2openapi/node_modules/yaml": { - "version": "1.10.2", - "dev": true, - "license": "ISC", - "engines": { - "node": ">= 6" - } - }, - "node_modules/symbol-tree": { - "version": "3.2.4", - "dev": true, - "license": "MIT" - }, - "node_modules/table-layout": { - "version": "3.0.2", - "license": "MIT", - "dependencies": { - "@75lb/deep-merge": "^1.1.1", - "array-back": "^6.2.2", - "command-line-args": "^5.2.1", - "command-line-usage": "^7.0.0", - "stream-read-all": "^3.0.1", - "typical": "^7.1.1", - "wordwrapjs": "^5.1.0" - }, - "bin": { - "table-layout": "bin/cli.js" - }, - "engines": { - "node": ">=12.17" - } - }, - "node_modules/table-layout/node_modules/array-back": { - "version": "6.2.2", - "license": "MIT", - "engines": { - "node": ">=12.17" - } - }, - "node_modules/table-layout/node_modules/typical": { - "version": "7.1.1", - "license": "MIT", - "engines": { - "node": ">=12.17" - } - }, - "node_modules/tailwindcss": { - "version": "3.3.6", - "license": "MIT", - "dependencies": { - "@alloc/quick-lru": "^5.2.0", - "arg": "^5.0.2", - "chokidar": "^3.5.3", - "didyoumean": "^1.2.2", - "dlv": "^1.1.3", - "fast-glob": "^3.3.0", - "glob-parent": "^6.0.2", - "is-glob": "^4.0.3", - "jiti": "^1.19.1", - "lilconfig": "^2.1.0", - "micromatch": "^4.0.5", - "normalize-path": "^3.0.0", - "object-hash": "^3.0.0", - "picocolors": "^1.0.0", - "postcss": "^8.4.23", - "postcss-import": "^15.1.0", - "postcss-js": "^4.0.1", - "postcss-load-config": "^4.0.1", - "postcss-nested": "^6.0.1", - 
"postcss-selector-parser": "^6.0.11", - "resolve": "^1.22.2", - "sucrase": "^3.32.0" - }, - "bin": { - "tailwind": "lib/cli.js", - "tailwindcss": "lib/cli.js" - }, - "engines": { - "node": ">=14.0.0" - } - }, - "node_modules/text-table": { - "version": "0.2.0", - "dev": true, - "license": "MIT" - }, - "node_modules/thememirror": { - "version": "2.0.1", - "license": "MIT", - "peerDependencies": { - "@codemirror/language": "^6.0.0", - "@codemirror/state": "^6.0.0", - "@codemirror/view": "^6.0.0" - } - }, - "node_modules/thenify": { - "version": "3.3.1", - "license": "MIT", - "dependencies": { - "any-promise": "^1.0.0" - } - }, - "node_modules/thenify-all": { - "version": "1.6.0", - "license": "MIT", - "dependencies": { - "thenify": ">= 3.1.0 < 4" - }, - "engines": { - "node": ">=0.8" - } - }, - "node_modules/tinybench": { - "version": "2.5.1", - "dev": true, - "license": "MIT" - }, - "node_modules/tinypool": { - "version": "0.7.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=14.0.0" - } - }, - "node_modules/tinyspy": { - "version": "2.2.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=14.0.0" - } - }, - "node_modules/to-regex-range": { - "version": "5.0.1", - "license": "MIT", - "dependencies": { - "is-number": "^7.0.0" - }, - "engines": { - "node": ">=8.0" - } - }, - "node_modules/tough-cookie": { - "version": "4.1.3", - "dev": true, - "license": "BSD-3-Clause", - "dependencies": { - "psl": "^1.1.33", - "punycode": "^2.1.1", - "universalify": "^0.2.0", - "url-parse": "^1.5.3" - }, - "engines": { - "node": ">=6" - } - }, - "node_modules/tough-cookie/node_modules/universalify": { - "version": "0.2.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 4.0.0" - } - }, - "node_modules/tr46": { - "version": "4.1.1", - "dev": true, - "license": "MIT", - "dependencies": { - "punycode": "^2.3.0" - }, - "engines": { - "node": ">=14" - } - }, - "node_modules/ts-api-utils": { - "version": "1.0.3", - "dev": true, - "license": 
"MIT", - "engines": { - "node": ">=16.13.0" - }, - "peerDependencies": { - "typescript": ">=4.2.0" - } - }, - "node_modules/ts-interface-checker": { - "version": "0.1.13", - "license": "Apache-2.0" - }, - "node_modules/tsconfck": { - "version": "2.1.2", - "dev": true, - "license": "MIT", - "bin": { - "tsconfck": "bin/tsconfck.js" - }, - "engines": { - "node": "^14.13.1 || ^16 || >=18" - }, - "peerDependencies": { - "typescript": "^4.3.5 || ^5.0.0" - }, - "peerDependenciesMeta": { - "typescript": { - "optional": true - } - } - }, - "node_modules/tsconfig-paths": { - "version": "3.14.2", - "dev": true, - "license": "MIT", - "dependencies": { - "@types/json5": "^0.0.29", - "json5": "^1.0.2", - "minimist": "^1.2.6", - "strip-bom": "^3.0.0" - } - }, - "node_modules/tslib": { - "version": "2.6.2", - "license": "0BSD" - }, - "node_modules/type-check": { - "version": "0.4.0", - "dev": true, - "license": "MIT", - "dependencies": { - "prelude-ls": "^1.2.1" - }, - "engines": { - "node": ">= 0.8.0" - } - }, - "node_modules/type-detect": { - "version": "4.0.8", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=4" - } - }, - "node_modules/type-fest": { - "version": "0.20.2", - "dev": true, - "license": "(MIT OR CC0-1.0)", - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/typed-array-buffer": { - "version": "1.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "get-intrinsic": "^1.2.1", - "is-typed-array": "^1.1.10" - }, - "engines": { - "node": ">= 0.4" - } - }, - "node_modules/typed-array-byte-length": { - "version": "1.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "for-each": "^0.3.3", - "has-proto": "^1.0.1", - "is-typed-array": "^1.1.10" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/typed-array-byte-offset": { - "version": 
"1.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "available-typed-arrays": "^1.0.5", - "call-bind": "^1.0.2", - "for-each": "^0.3.3", - "has-proto": "^1.0.1", - "is-typed-array": "^1.1.10" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/typed-array-length": { - "version": "1.0.4", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "for-each": "^0.3.3", - "is-typed-array": "^1.1.9" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/typescript": { - "version": "5.3.3", - "dev": true, - "license": "Apache-2.0", - "bin": { - "tsc": "bin/tsc", - "tsserver": "bin/tsserver" - }, - "engines": { - "node": ">=14.17" - } - }, - "node_modules/typical": { - "version": "4.0.0", - "license": "MIT", - "engines": { - "node": ">=8" - } - }, - "node_modules/ufo": { - "version": "1.3.2", - "dev": true, - "license": "MIT" - }, - "node_modules/unbox-primitive": { - "version": "1.0.2", - "dev": true, - "license": "MIT", - "dependencies": { - "call-bind": "^1.0.2", - "has-bigints": "^1.0.2", - "has-symbols": "^1.0.3", - "which-boxed-primitive": "^1.0.2" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/universalify": { - "version": "2.0.1", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 10.0.0" - } - }, - "node_modules/update-browserslist-db": { - "version": "1.0.13", - "dev": true, - "funding": [ - { - "type": "opencollective", - "url": "https://opencollective.com/browserslist" - }, - { - "type": "tidelift", - "url": "https://tidelift.com/funding/github/npm/browserslist" - }, - { - "type": "github", - "url": "https://github.com/sponsors/ai" - } - ], - "license": "MIT", - "dependencies": { - "escalade": "^3.1.1", - "picocolors": "^1.0.0" - }, - "bin": { - "update-browserslist-db": "cli.js" - }, - "peerDependencies": { - "browserslist": ">= 4.21.0" - } - }, - 
"node_modules/uri-js": { - "version": "4.4.1", - "dev": true, - "license": "BSD-2-Clause", - "dependencies": { - "punycode": "^2.1.0" - } - }, - "node_modules/urijs": { - "version": "1.19.11", - "dev": true, - "license": "MIT" - }, - "node_modules/url-parse": { - "version": "1.5.10", - "dev": true, - "license": "MIT", - "dependencies": { - "querystringify": "^2.1.1", - "requires-port": "^1.0.0" - } - }, - "node_modules/use-callback-ref": { - "version": "1.3.0", - "license": "MIT", - "dependencies": { - "tslib": "^2.0.0" - }, - "engines": { - "node": ">=10" - }, - "peerDependencies": { - "@types/react": "^16.8.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/use-sidecar": { - "version": "1.1.2", - "license": "MIT", - "dependencies": { - "detect-node-es": "^1.1.0", - "tslib": "^2.0.0" - }, - "engines": { - "node": ">=10" - }, - "peerDependencies": { - "@types/react": "^16.9.0 || ^17.0.0 || ^18.0.0", - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - } - } - }, - "node_modules/use-sync-external-store": { - "version": "1.2.0", - "license": "MIT", - "peerDependencies": { - "react": "^16.8.0 || ^17.0.0 || ^18.0.0" - } - }, - "node_modules/util-deprecate": { - "version": "1.0.2", - "license": "MIT" - }, - "node_modules/utility-types": { - "version": "3.10.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 4" - } - }, - "node_modules/validator": { - "version": "13.11.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 0.10" - } - }, - "node_modules/vite": { - "version": "4.5.1", - "dev": true, - "license": "MIT", - "dependencies": { - "esbuild": "^0.18.10", - "postcss": "^8.4.27", - "rollup": "^3.27.1" - }, - "bin": { - "vite": "bin/vite.js" - }, - "engines": { - "node": "^14.18.0 || >=16.0.0" - }, - "funding": { - "url": 
"https://github.com/vitejs/vite?sponsor=1" - }, - "optionalDependencies": { - "fsevents": "~2.3.2" - }, - "peerDependencies": { - "@types/node": ">= 14", - "less": "*", - "lightningcss": "^1.21.0", - "sass": "*", - "stylus": "*", - "sugarss": "*", - "terser": "^5.4.0" - }, - "peerDependenciesMeta": { - "@types/node": { - "optional": true - }, - "less": { - "optional": true - }, - "lightningcss": { - "optional": true - }, - "sass": { - "optional": true - }, - "stylus": { - "optional": true - }, - "sugarss": { - "optional": true - }, - "terser": { - "optional": true - } - } - }, - "node_modules/vite-node": { - "version": "0.34.6", - "dev": true, - "license": "MIT", - "dependencies": { - "cac": "^6.7.14", - "debug": "^4.3.4", - "mlly": "^1.4.0", - "pathe": "^1.1.1", - "picocolors": "^1.0.0", - "vite": "^3.0.0 || ^4.0.0 || ^5.0.0-0" - }, - "bin": { - "vite-node": "vite-node.mjs" - }, - "engines": { - "node": ">=v14.18.0" - }, - "funding": { - "url": "https://opencollective.com/vitest" - } - }, - "node_modules/vite/node_modules/@esbuild/linux-arm64": { - "version": "0.18.20", - "cpu": [ - "arm64" - ], - "dev": true, - "license": "MIT", - "optional": true, - "os": [ - "linux" - ], - "engines": { - "node": ">=12" - } - }, - "node_modules/vite/node_modules/esbuild": { - "version": "0.18.20", - "dev": true, - "hasInstallScript": true, - "license": "MIT", - "bin": { - "esbuild": "bin/esbuild" - }, - "engines": { - "node": ">=12" - }, - "optionalDependencies": { - "@esbuild/android-arm": "0.18.20", - "@esbuild/android-arm64": "0.18.20", - "@esbuild/android-x64": "0.18.20", - "@esbuild/darwin-arm64": "0.18.20", - "@esbuild/darwin-x64": "0.18.20", - "@esbuild/freebsd-arm64": "0.18.20", - "@esbuild/freebsd-x64": "0.18.20", - "@esbuild/linux-arm": "0.18.20", - "@esbuild/linux-arm64": "0.18.20", - "@esbuild/linux-ia32": "0.18.20", - "@esbuild/linux-loong64": "0.18.20", - "@esbuild/linux-mips64el": "0.18.20", - "@esbuild/linux-ppc64": "0.18.20", - "@esbuild/linux-riscv64": 
"0.18.20", - "@esbuild/linux-s390x": "0.18.20", - "@esbuild/linux-x64": "0.18.20", - "@esbuild/netbsd-x64": "0.18.20", - "@esbuild/openbsd-x64": "0.18.20", - "@esbuild/sunos-x64": "0.18.20", - "@esbuild/win32-arm64": "0.18.20", - "@esbuild/win32-ia32": "0.18.20", - "@esbuild/win32-x64": "0.18.20" - } - }, - "node_modules/vitest": { - "version": "0.34.6", - "dev": true, - "license": "MIT", - "dependencies": { - "@types/chai": "^4.3.5", - "@types/chai-subset": "^1.3.3", - "@types/node": "*", - "@vitest/expect": "0.34.6", - "@vitest/runner": "0.34.6", - "@vitest/snapshot": "0.34.6", - "@vitest/spy": "0.34.6", - "@vitest/utils": "0.34.6", - "acorn": "^8.9.0", - "acorn-walk": "^8.2.0", - "cac": "^6.7.14", - "chai": "^4.3.10", - "debug": "^4.3.4", - "local-pkg": "^0.4.3", - "magic-string": "^0.30.1", - "pathe": "^1.1.1", - "picocolors": "^1.0.0", - "std-env": "^3.3.3", - "strip-literal": "^1.0.1", - "tinybench": "^2.5.0", - "tinypool": "^0.7.0", - "vite": "^3.1.0 || ^4.0.0 || ^5.0.0-0", - "vite-node": "0.34.6", - "why-is-node-running": "^2.2.2" - }, - "bin": { - "vitest": "vitest.mjs" - }, - "engines": { - "node": ">=v14.18.0" - }, - "funding": { - "url": "https://opencollective.com/vitest" - }, - "peerDependencies": { - "@edge-runtime/vm": "*", - "@vitest/browser": "*", - "@vitest/ui": "*", - "happy-dom": "*", - "jsdom": "*", - "playwright": "*", - "safaridriver": "*", - "webdriverio": "*" - }, - "peerDependenciesMeta": { - "@edge-runtime/vm": { - "optional": true - }, - "@vitest/browser": { - "optional": true - }, - "@vitest/ui": { - "optional": true - }, - "happy-dom": { - "optional": true - }, - "jsdom": { - "optional": true - }, - "playwright": { - "optional": true - }, - "safaridriver": { - "optional": true - }, - "webdriverio": { - "optional": true - } - } - }, - "node_modules/w3c-keyname": { - "version": "2.2.8", - "license": "MIT" - }, - "node_modules/w3c-xmlserializer": { - "version": "4.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - 
"xml-name-validator": "^4.0.0" - }, - "engines": { - "node": ">=14" - } - }, - "node_modules/webidl-conversions": { - "version": "7.0.0", - "dev": true, - "license": "BSD-2-Clause", - "engines": { - "node": ">=12" - } - }, - "node_modules/whatwg-encoding": { - "version": "2.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "iconv-lite": "0.6.3" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/whatwg-mimetype": { - "version": "3.0.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=12" - } - }, - "node_modules/whatwg-url": { - "version": "12.0.1", - "dev": true, - "license": "MIT", - "dependencies": { - "tr46": "^4.1.1", - "webidl-conversions": "^7.0.0" - }, - "engines": { - "node": ">=14" - } - }, - "node_modules/which": { - "version": "2.0.2", - "dev": true, - "license": "ISC", - "dependencies": { - "isexe": "^2.0.0" - }, - "bin": { - "node-which": "bin/node-which" - }, - "engines": { - "node": ">= 8" - } - }, - "node_modules/which-boxed-primitive": { - "version": "1.0.2", - "dev": true, - "license": "MIT", - "dependencies": { - "is-bigint": "^1.0.1", - "is-boolean-object": "^1.1.0", - "is-number-object": "^1.0.4", - "is-string": "^1.0.5", - "is-symbol": "^1.0.3" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/which-builtin-type": { - "version": "1.1.3", - "dev": true, - "license": "MIT", - "dependencies": { - "function.prototype.name": "^1.1.5", - "has-tostringtag": "^1.0.0", - "is-async-function": "^2.0.0", - "is-date-object": "^1.0.5", - "is-finalizationregistry": "^1.0.2", - "is-generator-function": "^1.0.10", - "is-regex": "^1.1.4", - "is-weakref": "^1.0.2", - "isarray": "^2.0.5", - "which-boxed-primitive": "^1.0.2", - "which-collection": "^1.0.1", - "which-typed-array": "^1.1.9" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/which-collection": { - "version": "1.0.1", - "dev": true, - "license": 
"MIT", - "dependencies": { - "is-map": "^2.0.1", - "is-set": "^2.0.1", - "is-weakmap": "^2.0.1", - "is-weakset": "^2.0.1" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/which-typed-array": { - "version": "1.1.13", - "dev": true, - "license": "MIT", - "dependencies": { - "available-typed-arrays": "^1.0.5", - "call-bind": "^1.0.4", - "for-each": "^0.3.3", - "gopd": "^1.0.1", - "has-tostringtag": "^1.0.0" - }, - "engines": { - "node": ">= 0.4" - }, - "funding": { - "url": "https://github.com/sponsors/ljharb" - } - }, - "node_modules/why-is-node-running": { - "version": "2.2.2", - "dev": true, - "license": "MIT", - "dependencies": { - "siginfo": "^2.0.0", - "stackback": "0.0.2" - }, - "bin": { - "why-is-node-running": "cli.js" - }, - "engines": { - "node": ">=8" - } - }, - "node_modules/wordwrapjs": { - "version": "5.1.0", - "license": "MIT", - "engines": { - "node": ">=12.17" - } - }, - "node_modules/wrap-ansi": { - "version": "7.0.0", - "dev": true, - "license": "MIT", - "dependencies": { - "ansi-styles": "^4.0.0", - "string-width": "^4.1.0", - "strip-ansi": "^6.0.0" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/chalk/wrap-ansi?sponsor=1" - } - }, - "node_modules/wrappy": { - "version": "1.0.2", - "license": "ISC" - }, - "node_modules/ws": { - "version": "8.14.2", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=10.0.0" - }, - "peerDependencies": { - "bufferutil": "^4.0.1", - "utf-8-validate": ">=5.0.2" - }, - "peerDependenciesMeta": { - "bufferutil": { - "optional": true - }, - "utf-8-validate": { - "optional": true - } - } - }, - "node_modules/xml-name-validator": { - "version": "4.0.0", - "dev": true, - "license": "Apache-2.0", - "engines": { - "node": ">=12" - } - }, - "node_modules/xmlchars": { - "version": "2.2.0", - "dev": true, - "license": "MIT" - }, - "node_modules/y18n": { - "version": "5.0.8", - "dev": true, - "license": "ISC", - "engines": { - "node": ">=10" 
- } - }, - "node_modules/yallist": { - "version": "4.0.0", - "dev": true, - "license": "ISC" - }, - "node_modules/yaml": { - "version": "2.3.4", - "license": "ISC", - "engines": { - "node": ">= 14" - } - }, - "node_modules/yargs": { - "version": "17.7.2", - "dev": true, - "license": "MIT", - "dependencies": { - "cliui": "^8.0.1", - "escalade": "^3.1.1", - "get-caller-file": "^2.0.5", - "require-directory": "^2.1.1", - "string-width": "^4.2.3", - "y18n": "^5.0.5", - "yargs-parser": "^21.1.1" - }, - "engines": { - "node": ">=12" - } - }, - "node_modules/yargs-parser": { - "version": "21.1.1", - "dev": true, - "license": "ISC", - "engines": { - "node": ">=12" - } - }, - "node_modules/yocto-queue": { - "version": "0.1.0", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/sindresorhus" - } - }, - "node_modules/zustand": { - "version": "4.4.7", - "license": "MIT", - "dependencies": { - "use-sync-external-store": "1.2.0" - }, - "engines": { - "node": ">=12.7.0" - }, - "peerDependencies": { - "@types/react": ">=16.8", - "immer": ">=9.0", - "react": ">=16.8" - }, - "peerDependenciesMeta": { - "@types/react": { - "optional": true - }, - "immer": { - "optional": true - }, - "react": { - "optional": true - } - } - } - } -} diff --git a/web/client/package.json b/web/client/package.json index b56ec9fc07..b6c99153f9 100644 --- a/web/client/package.json +++ b/web/client/package.json @@ -2,13 +2,12 @@ "name": "tobiko", "version": "0.0.0", "scripts": { - "dev": "npm run generate:api && vite", - "build": "npm run generate:api && tsc && vite build", + "dev": "pnpm run generate:api && vite", + "build": "pnpm run generate:api && tsc && vite build", "preview": "vite preview --no-open", - "prettier": "prettier --write .", - "lint:fix": "eslint --fix .", - "format": "npm run prettier && npm run lint:fix", - "test": "npm run generate:api npm run test:unit && npm run test:e2e", + "lint": "eslint", + "lint:fix": 
"eslint --fix", + "test": "pnpm run generate:api && pnpm run test:unit && pnpm run test:e2e", "test:unit:watch": "NODE_ENV=development vitest --watch=true", "test:unit": "NODE_ENV=testing vitest --watch=false", "test:e2e": "NODE_ENV=testing playwright test", @@ -16,65 +15,66 @@ "generate:api": "orval --config ./orval.config.ts" }, "dependencies": { - "@codemirror/autocomplete": "^6.16.2", - "@codemirror/commands": "^6.6.0", - "@codemirror/lang-python": "^6.1.6", - "@codemirror/lang-sql": "^6.6.4", - "@codemirror/language": "^6.10.2", - "@codemirror/legacy-modes": "^6.4.0", - "@codemirror/state": "^6.4.1", - "@codemirror/view": "^6.28.1", - "@headlessui/react": "^1.7.17", - "@heroicons/react": "^2.0.18", - "@radix-ui/react-context-menu": "^2.1.4", - "@radix-ui/react-select": "^1.2.2", + "@codemirror/autocomplete": "^6.18.6", + "@codemirror/commands": "^6.8.1", + "@codemirror/lang-python": "^6.2.1", + "@codemirror/lang-sql": "^6.9.0", + "@codemirror/language": "^6.11.2", + "@codemirror/legacy-modes": "^6.5.1", + "@codemirror/state": "^6.5.2", + "@codemirror/view": "^6.38.1", + "@headlessui/react": "^2.2.5", + "@heroicons/react": "^2.2.0", + "@lit/react": "^1.0.8", + "@radix-ui/react-context-menu": "^2.2.15", + "@radix-ui/react-select": "^2.2.5", "@tailwindcss/container-queries": "^0.1.1", - "@tanstack/react-query": "^4.33.0", - "@tanstack/react-table": "^8.9.2", - "@tanstack/react-virtual": "^3.0.0-beta.56", - "@uidotdev/usehooks": "^2.2.0", - "@uiw/react-codemirror": "^4.21.12", - "apache-arrow": "^13.0.0", - "clsx": "^2.0.0", - "diff": "^5.2.0", + "@tanstack/react-query": "^5.83.0", + "@tanstack/react-table": "^8.21.3", + "@tanstack/react-virtual": "^3.13.12", + "@uidotdev/usehooks": "^2.4.1", + "@uiw/react-codemirror": "^4.24.1", + "apache-arrow": "^19.0.1", + "clsx": "^2.1.1", + "diff": "^8.0.2", "elkjs": "^0.8.2", "pluralize": "^8.0.0", - "react": "^18.2.0", + "react": "^18.3.1", "react-dnd": "^16.0.1", "react-dnd-html5-backend": "^16.0.1", - "react-dom": 
"^18.2.0", - "react-router-dom": "^6.15.0", + "react-dom": "^18.3.1", + "react-markdown": "^10.1.0", + "react-router": "^7.7.0", "react-split": "^2.0.14", - "reactflow": "^11.8.3", + "reactflow": "^11.11.4", "thememirror": "^2.0.1", - "zustand": "^4.4.1" + "zustand": "^5.0.6" }, "devDependencies": { - "@playwright/test": "^1.37.1", - "@testing-library/jest-dom": "^6.1.2", - "@testing-library/react": "^14.0.0", - "@testing-library/user-event": "^14.4.3", - "@types/diff": "^5.2.1", - "@types/pluralize": "^0.0.30", - "@types/react": "^18.2.21", - "@types/react-dom": "^18.2.7", - "@typescript-eslint/eslint-plugin": "^6.5.0", - "@vitejs/plugin-react-swc": "^3.3.2", - "autoprefixer": "^10.4.15", - "eslint": "^8.48.0", - "eslint-config-prettier": "^9.0.0", - "eslint-config-standard-with-typescript": "^39.0.0", - "eslint-plugin-import": "^2.28.1", - "eslint-plugin-n": "^16.0.2", - "eslint-plugin-promise": "^6.1.1", - "eslint-plugin-react": "^7.33.2", - "jsdom": "^22.1.0", - "orval": "^6.22.1", - "postcss": "^8.4.29", - "prettier": "^3.0.3", - "tailwindcss": "^3.3.3", - "typescript": "^5.2.2", - "vite": "^4.4.9", - "vitest": "^0.34.3" + "@eslint/js": "^9.31.0", + "@playwright/test": "^1.54.1", + "@swc/core": "^1.13.2", + "@testing-library/jest-dom": "^6.6.3", + "@testing-library/react": "^16.3.0", + "@testing-library/user-event": "^14.6.1", + "@types/pluralize": "^0.0.33", + "@types/react": "^18.3.23", + "@types/react-dom": "^18.3.7", + "@vitejs/plugin-react-swc": "^3.11.0", + "ajv": "^8.17.1", + "autoprefixer": "^10.4.21", + "eslint": "^9.31.0", + "jsdom": "^26.1.0", + "orval": "^7.10.0", + "postcss": "^8.5.6", + "tailwindcss": "^3.4.17", + "typescript": "^5.8.3", + "typescript-eslint": "^8.38.0", + "vite": "^6.3.5", + "vite-plugin-css-injected-by-js": "^3.5.2", + "vitest": "^3.2.4" + }, + "optionalDependencies": { + "@swc/core-linux-x64-gnu": "^1.13.2" } } diff --git a/web/client/postcss.config.js b/web/client/postcss.config.cjs similarity index 100% rename from 
web/client/postcss.config.js rename to web/client/postcss.config.cjs diff --git a/web/client/public/css/base.css b/web/client/public/css/base.css new file mode 100644 index 0000000000..a8c5a688fa --- /dev/null +++ b/web/client/public/css/base.css @@ -0,0 +1,241 @@ +:root { + font-synthesis: none; + text-rendering: optimizeLegibility; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; + -webkit-text-size-adjust: 100%; + + --color-brand: hsla(24, 100%, 60%, 1); + --color-primary: hsla(198, 100%, 63%, 1); + --color-secondary: hsla(216, 100%, 50%, 1); + --color-accent: hsla(264, 100%, 60%, 1); + --color-success: hsla(111, 66%, 55%, 1); + --color-warning: hsla(30, 94%, 62%, 1); + --color-danger: hsla(1, 100%, 64%, 1); + --color-neutral: hsla(0, 0%, 50%, 1); + + /* Brand */ + --color-brand-5: hsla(24, 100%, 60%, 0.05); + --color-brand-10: hsla(24, 100%, 60%, 0.1); + --color-brand-20: hsla(24, 100%, 60%, 0.2); + --color-brand-30: hsla(24, 100%, 60%, 0.3); + --color-brand-40: hsla(24, 100%, 60%, 0.4); + --color-brand-50: hsla(24, 100%, 60%, 0.5); + --color-brand-75: hsla(24, 100%, 60%, 0.75); + --color-brand-90: hsla(24, 100%, 60%, 0.9); + --color-brand-100: hsla(37, 100%, 92%, 1); + --color-brand-200: hsla(34, 100%, 84%, 1); + --color-brand-300: hsla(31, 100%, 76%, 1); + --color-brand-400: hsla(27, 100%, 70%, 1); + --color-brand-500: var(--color-brand); + --color-brand-600: hsla(20, 72%, 50%, 1); + --color-brand-700: hsla(16, 76%, 41%, 1); + --color-brand-800: hsla(13, 80%, 32%, 1); + --color-brand-900: hsla(10, 86%, 26%, 1); + + /* Primary */ + --color-primary-5: hsla(198, 100%, 63%, 0.05); + --color-primary-10: hsla(198, 100%, 63%, 0.1); + --color-primary-20: hsla(198, 100%, 63%, 0.2); + --color-primary-30: hsla(198, 100%, 63%, 0.3); + --color-primary-40: hsla(198, 100%, 63%, 0.4); + --color-primary-50: hsla(198, 100%, 63%, 0.5); + --color-primary-60: hsla(198, 100%, 63%, 0.6); + --color-primary-70: hsla(198, 100%, 63%, 0.7); + 
--color-primary-80: hsla(198, 100%, 63%, 0.8); + --color-primary-90: hsla(198, 100%, 63%, 0.9); + --color-primary-100: hsla(183, 100%, 93%, 1); + --color-primary-200: hsla(186, 100%, 85%, 1); + --color-primary-300: hsla(191, 100%, 78%, 1); + --color-primary-400: hsla(194, 100%, 72%, 1); + --color-primary-500: var(--color-primary); + --color-primary-600: hsla(202, 70%, 52%, 1); + --color-primary-700: hsla(206, 69%, 42%, 1); + --color-primary-800: hsla(210, 75%, 33%, 1); + --color-primary-900: hsla(214, 82%, 26%, 1); + + /* Secondary */ + --color-secondary-5: hsla(216, 100%, 50%, 0.05); + --color-secondary-10: hsla(216, 100%, 50%, 0.1); + --color-secondary-20: hsla(216, 100%, 50%, 0.2); + --color-secondary-30: hsla(216, 100%, 50%, 0.3); + --color-secondary-40: hsla(216, 100%, 50%, 0.4); + --color-secondary-50: hsla(216, 100%, 50%, 0.5); + --color-secondary-60: hsla(216, 100%, 50%, 0.6); + --color-secondary-70: hsla(216, 100%, 50%, 0.7); + --color-secondary-80: hsla(216, 100%, 50%, 0.8); + --color-secondary-90: hsla(216, 100%, 50%, 0.9); + --color-secondary-100: hsla(207, 100%, 90%, 1); + --color-secondary-200: hsla(209, 100%, 80%, 1); + --color-secondary-300: hsla(211, 100%, 70%, 1); + --color-secondary-400: hsla(213, 100%, 62%, 1); + --color-secondary-500: var(--color-secondary); + --color-secondary-600: hsla(219, 100%, 43%, 1); + --color-secondary-700: hsla(221, 100%, 36%, 1); + --color-secondary-800: hsla(223, 100%, 29%, 1); + --color-secondary-900: hsla(226, 100%, 24%, 1); + + /* Accent */ + --color-accent-5: hsla(264, 100%, 60%, 0.05); + --color-accent-50: hsla(264, 100%, 60%, 0.5); + --color-accent-100: hsla(264, 100%, 98%, 1); + --color-accent-200: hsla(260, 100%, 90%, 1); + --color-accent-300: hsla(260, 100%, 80%, 1); + --color-accent-400: hsla(260, 100%, 70%, 1); + --color-accent-500: var(--color-accent); + --color-accent-600: hsla(264, 100%, 50%, 1); + --color-accent-700: hsla(264, 100%, 40%, 1); + --color-accent-800: hsla(264, 100%, 20%, 1); + 
--color-accent-900: hsla(264, 100%, 8%, 1); + + /* Success */ + --color-success-5: hsla(111, 66%, 55%, 0.05); + --color-success-10: hsla(111, 66%, 55%, 0.1); + --color-success-20: hsla(111, 66%, 55%, 0.2); + --color-success-100: hsla(92, 90%, 92%, 1); + --color-success-200: hsla(98, 90%, 84%, 1); + --color-success-300: hsla(102, 81%, 75%, 1); + --color-success-400: hsla(106, 72%, 66%, 1); + --color-success-500: var(--color-success); + --color-success-600: hsla(116, 61%, 45%, 1); + --color-success-700: hsla(120, 67%, 36%, 1); + --color-success-800: hsla(125, 74%, 28%, 1); + --color-success-900: hsla(130, 81%, 22%, 1); + + /* Warning */ + --color-warning-5: hsla(30, 94%, 62%, 0.05); + --color-warning-10: hsla(30, 94%, 62%, 0.1); + --color-warning-100: hsla(42, 95%, 92%, 1); + --color-warning-200: hsla(39, 97%, 85%, 1); + --color-warning-300: hsla(36, 97%, 77%, 1); + --color-warning-400: hsla(33, 95%, 71%, 1); + --color-warning-500: var(--color-warning); + --color-warning-600: hsla(27, 67%, 51%, 1); + --color-warning-700: hsla(24, 69%, 42%, 1); + --color-warning-800: hsla(21, 75%, 32%, 1); + --color-warning-900: hsla(18, 82%, 26%, 1); + + /* Danger */ + --color-danger-5: hsla(21, 100%, 50%, 0.05); + --color-danger-10: hsla(21, 100%, 50%, 0.1); + --color-danger-20: hsla(21, 100%, 50%, 0.2); + --color-danger-100: hsla(21, 100%, 93%, 1); + --color-danger-200: hsla(16, 100%, 86%, 1); + --color-danger-300: hsla(11, 100%, 78%, 1); + --color-danger-400: hsla(6, 100%, 73%, 1); + --color-danger-500: var(--color-danger); + --color-danger-600: hsla(356, 70%, 53%, 1); + --color-danger-700: hsla(351, 67%, 43%, 1); + --color-danger-800: hsla(345, 73%, 33%, 1); + --color-danger-900: hsla(341, 79%, 27%, 1); + + /* Neutral */ + --color-neutral-5: hsla(190, 8%, 50%, 0.05); + --color-neutral-10: hsla(190, 8%, 50%, 0.1); + --color-neutral-20: hsla(190, 8%, 50%, 0.2); + --color-neutral-30: hsla(190, 8%, 50%, 0.3); + --color-neutral-40: hsla(190, 8%, 50%, 0.4); + --color-neutral-50: 
hsla(190, 8%, 50%, 0.5); + --color-neutral-60: hsla(190, 8%, 50%, 0.6); + --color-neutral-70: hsla(190, 8%, 50%, 0.7); + --color-neutral-80: hsla(190, 8%, 50%, 0.8); + --color-neutral-90: hsla(190, 8%, 50%, 0.9); + --color-neutral-100: hsl(190, 8%, 98%); + --color-neutral-200: hsl(190, 8%, 86%); + --color-neutral-300: hsl(190, 8%, 74%); + --color-neutral-400: hsl(190, 8%, 66%); + --color-neutral-500: var(--color-neutral); + --color-neutral-600: hsl(190, 8%, 38%); + --color-neutral-700: hsl(202, 8%, 26%); + --color-neutral-800: hsl(214, 8%, 14%); + --color-neutral-900: hsl(226, 8%, 4%); + + --color-dark-darker: hsl(226, 16%, 4%); + --color-dark: hsl(226, 24%, 8%); + --color-dark-lighter: hsl(226, 32%, 16%); + + --color-dark-transparent-20: hsla(226, 24%, 8%, 0.2); + + --color-light-darker: hsl(0, 0%, 92%); + --color-light: hsl(0, 0%, 100%); + --color-light-lighter: hsl(0, 0%, 98%); + + --color-light-transparent-20: hsla(0, 0%, 100%, 0.2); + + /* General */ + --unit: 16px; + --leading: 1.5; + + --scrollbar-size: 6px; + --scrollbar-radius: 1rem; + --scrollbar-backgroud: var(--color-brand); +} + +html[mode='dark'] { + --color-theme-darker: var(--color-dark-darker); + --color-theme: var(--color-dark); + --color-theme-lighter: var(--color-dark-lighter); + + --color-text-darker: var(--color-neutral-400); + --color-text: var(--color-neutral-300); + --color-text-lighter: var(--color-neutral-200); + + --color-divider-darker: var(--color-theme-darker); + --color-divider: var(--color-theme-lighter); + --color-divider-lighter: var(--color-neutral-100); + + --color-logo-darker: var(--color-neutral-300); + --color-logo: var(--color-neutral-200); + --color-logo-lighter: var(--color-neutral-100); + + --color-editor-text: var(--color-neutral-300); + --color-editor-active-line-text: var(--color-neutral-200); + + --color-graph-edge-secondary: var(--color-primary-500); + --color-graph-edge-main: var(--color-primary-20); + --color-graph-edge-selected: var(--color-primary-500); + 
--color-graph-edge-direct: var(--color-primary-20); + + --color-transparent-20: var(--color-dark-transparent-20); +} + +html[mode='light'] { + --color-theme-darker: var(--color-light-darker); + --color-theme: var(--color-light); + --color-theme-lighter: var(--color-light-lighter); + + --color-text-darker: var(--color-neutral-700); + --color-text: var(--color-neutral-600); + --color-text-lighter: var(--color-neutral-500); + + --color-divider-darker: var(--color-neutral-500); + --color-divider: var(--color-neutral-200); + --color-divider-lighter: var(--color-neutral-300); + + --color-logo-darker: var(--color-neutral-900); + --color-logo: var(--color-neutral-800); + --color-logo-lighter: var(--color-neutral-300); + + --color-editor-text: var(--color-neutral-600); + --color-editor-active-line-text: var(--color-neutral-700); + + --color-graph-edge-secondary: var(--color-secondary-500); + --color-graph-edge-main: var(--color-secondary-20); + --color-graph-edge-selected: var(--color-secondary-500); + --color-graph-edge-direct: var(--color-secondary-20); + + --color-transparent-20: var(--color-light-transparent-20); +} + +html { + width: 100%; + height: 100%; + font-size: var(--unit); + line-height: var(--leading); + font-weight: 600; + background: var(--color-theme); + color: var(--color-text); + + --color-editor: var(--color-theme); + --color-overlay: var(--color-dark-darker); +} diff --git a/web/client/public/css/design.css b/web/client/public/css/design.css new file mode 100644 index 0000000000..c3e771ed02 --- /dev/null +++ b/web/client/public/css/design.css @@ -0,0 +1,456 @@ +:root { + /* Colors: Brand */ + --color-tobiko: hsla(16, 94%, 50%, 1); + --color-sqlmesh: hsla(230, 94%, 50%, 1); + --color-sqlglot: hsla(148, 100%, 38%, 1); + + /* Colors: Brand Palette*/ + --color-pacific: hsla(194, 100%, 50%, 1); + + --color-pacific-5: hsla(194, 100%, 50%, 0.05); + --color-pacific-10: hsla(194, 100%, 50%, 0.1); + + --color-wasabi: hsla(112, 67%, 68%, 1); + --color-yuzu: 
hsla(58, 100%, 56%, 1); + + --color-uni-5: hsla(38, 96%, 51%, 0.05); + --color-uni-10: hsla(38, 96%, 51%, 0.1); + --color-uni-15: hsla(38, 96%, 51%, 0.15); + --color-uni-20: hsla(38, 96%, 51%, 0.2); + --color-uni-50: hsla(38, 96%, 51%, 0.5); + --color-uni: hsla(38, 96%, 51%, 1); + + --color-salmon-5: hsla(345, 92%, 81%, 0.05); + --color-salmon-10: hsla(345, 92%, 81%, 0.1); + --color-salmon-15: hsla(345, 92%, 81%, 0.15); + --color-salmon-20: hsla(345, 92%, 81%, 0.2); + --color-salmon-50: hsla(345, 92%, 81%, 0.5); + --color-salmon: hsla(345, 92%, 81%, 1); + + /* Colors: Base */ + --color-white: hsla(0, 0%, 100%, 1); + --color-black: hsla(0, 0%, 5%, 1); + --color-cyan: hsla(198, 100%, 63%, 1); + --color-deep-blue: hsla(210, 100%, 50%, 1); + --color-purple: hsla(264, 100%, 60%, 1); + --color-emerald: hsla(130, 65%, 50%, 1); + --color-mandarin: hsla(30, 100%, 50%, 1); + --color-scarlet: hsla(350, 85%, 60%, 1); + --color-sunflower: hsla(48, 100%, 50%, 1); + --color-peach: hsla(24, 100%, 70%, 1); + + --color-turquoise-5: hsla(150, 80%, 45%, 0.05); + --color-turquoise-10: hsla(150, 80%, 45%, 0.1); + --color-turquoise-20: hsla(150, 80%, 45%, 0.2); + --color-turquoise-50: hsla(150, 80%, 45%, 0.5); + --color-turquoise: hsla(150, 80%, 45%, 1); + + --color-fuchsia-5: hsla(320, 100%, 70%, 0.05); + --color-fuchsia-10: hsla(320, 100%, 70%, 0.1); + --color-fuchsia-20: hsla(320, 100%, 70%, 0.2); + --color-fuchsia-50: hsla(320, 100%, 70%, 0.5); + --color-fuchsia: hsla(320, 100%, 70%, 1); + + --color-gray: hsla(0, 0%, 50%, 1); + + /* Colors: Tokens */ + --color-cyan-5: hsla(198, 100%, 63%, 0.05); + --color-cyan-10: hsla(198, 100%, 63%, 0.1); + --color-cyan-15: hsla(198, 100%, 63%, 0.15); + --color-cyan-20: hsla(198, 100%, 63%, 0.2); + --color-cyan-100: hsla(183, 100%, 93%, 1); + --color-cyan-200: hsla(186, 100%, 85%, 1); + --color-cyan-300: hsla(191, 100%, 78%, 1); + --color-cyan-400: hsla(194, 100%, 72%, 1); + --color-cyan-500: var(--color-cyan); + --color-cyan-525: hsla(198, 100%, 
70%, 1); + --color-cyan-550: hsla(198, 100%, 68%, 1); + --color-cyan-600: hsla(202, 70%, 52%, 1); + --color-cyan-700: hsla(206, 69%, 42%, 1); + --color-cyan-800: hsla(210, 75%, 33%, 1); + --color-cyan-900: hsla(214, 82%, 26%, 1); + + --color-deep-blue-5: hsla(216, 100%, 50%, 0.05); + --color-deep-blue-10: hsla(216, 100%, 50%, 0.1); + --color-deep-blue-15: hsla(216, 100%, 50%, 0.15); + --color-deep-blue-20: hsla(216, 100%, 50%, 0.2); + --color-deep-blue-60: hsla(216, 100%, 50%, 0.6); + --color-deep-blue-100: hsla(207, 100%, 95%, 1); + --color-deep-blue-125: hsla(207, 100%, 92%, 1); + --color-deep-blue-150: hsla(207, 100%, 88%, 1); + --color-deep-blue-200: hsla(209, 100%, 80%, 1); + --color-deep-blue-300: hsla(211, 100%, 70%, 1); + --color-deep-blue-400: hsla(213, 100%, 62%, 1); + --color-deep-blue-500: var(--color-deep-blue); + --color-deep-blue-525: hsla(216, 100%, 60%, 1); + --color-deep-blue-550: hsla(216, 100%, 54%, 1); + --color-deep-blue-600: hsla(219, 100%, 43%, 1); + --color-deep-blue-700: hsla(221, 100%, 36%, 1); + --color-deep-blue-725: hsla(223, 100%, 33%, 1); + --color-deep-blue-750: hsla(223, 100%, 31%, 1); + --color-deep-blue-800: hsla(223, 100%, 29%, 1); + --color-deep-blue-900: hsla(226, 100%, 24%, 1); + + --color-pacific-5: hsla(194, 100%, 50%, 0.05); + --color-pacific-10: hsla(194, 100%, 50%, 0.1); + --color-pacific-15: hsla(194, 100%, 50%, 0.15); + --color-pacific-20: hsla(194, 100%, 50%, 0.2); + --color-pacific-50: hsla(194, 100%, 50%, 0.5); + --color-pacific-100: hsla(194, 100%, 98%, 1); + --color-pacific-125: hsla(194, 100%, 94%, 1); + --color-pacific-150: hsla(194, 100%, 92%, 1); + --color-pacific-200: hsla(194, 100%, 86%, 1); + --color-pacific-300: hsla(194, 100%, 74%, 1); + --color-pacific-400: hsla(194, 100%, 66%, 1); + --color-pacific-500: var(--color-pacific); + --color-pacific-525: hsla(194, 100%, 46%, 1); + --color-pacific-550: hsla(194, 100%, 42%, 1); + --color-pacific-600: hsla(194, 100%, 38%, 1); + + --color-purple-5: hsla(264, 100%, 
60%, 0.05); + --color-purple-10: hsla(264, 100%, 60%, 0.1); + --color-purple-15: hsla(264, 100%, 60%, 0.15); + --color-purple-20: hsla(264, 100%, 60%, 0.2); + --color-purple-100: hsla(264, 100%, 98%, 1); + --color-purple-125: hsla(264, 100%, 94%, 1); + --color-purple-150: hsla(264, 100%, 92%, 1); + --color-purple-200: hsla(260, 100%, 90%, 1); + --color-purple-300: hsla(260, 100%, 80%, 1); + --color-purple-400: hsla(260, 100%, 70%, 1); + --color-purple-500: var(--color-purple); + --color-purple-600: hsla(264, 100%, 50%, 1); + --color-purple-700: hsla(264, 100%, 40%, 1); + --color-purple-800: hsla(264, 100%, 20%, 1); + --color-purple-900: hsla(264, 100%, 8%, 1); + + --color-emerald-5: hsla(111, 66%, 55%, 0.05); + --color-emerald-10: hsla(111, 66%, 55%, 0.1); + --color-emerald-15: hsla(111, 66%, 55%, 0.15); + --color-emerald-100: hsla(92, 90%, 92%, 1); + --color-emerald-125: hsla(92, 90%, 88%, 1); + --color-emerald-150: hsla(92, 90%, 86%, 1); + --color-emerald-200: hsla(98, 90%, 84%, 1); + --color-emerald-300: hsla(102, 81%, 75%, 1); + --color-emerald-400: hsla(106, 72%, 66%, 1); + --color-emerald-500: var(--color-emerald); + --color-emerald-525: hsla(111, 66%, 63%, 1); + --color-emerald-550: hsla(111, 66%, 61%, 1); + --color-emerald-600: hsla(116, 61%, 45%, 1); + --color-emerald-700: hsla(120, 67%, 36%, 1); + --color-emerald-800: hsla(125, 74%, 28%, 1); + --color-emerald-900: hsla(130, 81%, 22%, 1); + + --color-mandarin-5: hsla(30, 94%, 62%, 0.05); + --color-mandarin-10: hsla(30, 94%, 62%, 0.1); + --color-mandarin-15: hsla(30, 94%, 62%, 0.15); + --color-mandarin-100: hsla(42, 95%, 92%, 1); + --color-mandarin-125: hsla(42, 95%, 88%, 1); + --color-mandarin-150: hsla(42, 95%, 86%, 1); + --color-mandarin-200: hsla(39, 97%, 85%, 1); + --color-mandarin-300: hsla(36, 97%, 77%, 1); + --color-mandarin-400: hsla(33, 95%, 71%, 1); + --color-mandarin-500: var(--color-mandarin); + --color-mandarin-525: hsla(30, 94%, 68%, 1); + --color-mandarin-550: hsla(30, 94%, 66%, 1); + 
--color-mandarin-600: hsla(27, 67%, 51%, 1); + --color-mandarin-700: hsla(24, 69%, 42%, 1); + --color-mandarin-800: hsla(21, 75%, 32%, 1); + --color-mandarin-900: hsla(18, 82%, 26%, 1); + + --color-scarlet-5: hsla(21, 100%, 50%, 0.05); + --color-scarlet-10: hsla(21, 100%, 50%, 0.1); + --color-scarlet-15: hsla(21, 100%, 50%, 0.15); + --color-scarlet-100: hsla(21, 100%, 93%, 1); + --color-scarlet-125: hsla(21, 100%, 89%, 1); + --color-scarlet-150: hsla(21, 100%, 87%, 1); + --color-scarlet-200: hsla(16, 100%, 86%, 1); + --color-scarlet-300: hsla(11, 100%, 78%, 1); + --color-scarlet-400: hsla(6, 100%, 73%, 1); + --color-scarlet-500: var(--color-scarlet); + --color-scarlet-525: hsla(356, 100%, 70%, 1); + --color-scarlet-550: hsla(356, 100%, 68%, 1); + --color-scarlet-600: hsla(356, 70%, 53%, 1); + --color-scarlet-700: hsla(351, 67%, 43%, 1); + --color-scarlet-725: hsla(346, 67%, 38%, 1); + --color-scarlet-750: hsla(346, 67%, 38%, 1); + --color-scarlet-800: hsla(345, 73%, 33%, 1); + --color-scarlet-900: hsla(341, 79%, 27%, 1); + + --color-gray-3: hsla(190, 8%, 50%, 0.03); + --color-gray-5: hsla(190, 8%, 50%, 0.05); + --color-gray-10: hsla(190, 8%, 50%, 0.1); + --color-gray-15: hsla(190, 8%, 50%, 0.15); + --color-gray-25: hsla(190, 8%, 50%, 0.25); + --color-gray-50: hsla(202, 8%, 26%, 0.5); + --color-gray-100: hsl(190, 8%, 98%); + --color-gray-125: hsl(190, 8%, 94%); + --color-gray-150: hsl(190, 8%, 92%); + --color-gray-200: hsl(190, 8%, 86%); + --color-gray-300: hsl(190, 8%, 74%); + --color-gray-400: hsl(190, 8%, 66%); + --color-gray-500: var(--color-gray); + --color-gray-525: hsl(190, 8%, 46%); + --color-gray-550: hsl(190, 8%, 42%); + --color-gray-600: hsl(190, 8%, 38%); + --color-gray-700: hsl(202, 8%, 26%); + --color-gray-800: hsl(214, 8%, 14%); + --color-gray-900: hsl(226, 8%, 4%); + + /* Colors: Semantic */ + --color-background: hsl(30, 12%, 97%); + --color-light: hsl(30, 12%, 100%); + --color-dark: hsl(226, 24%, 8%); + --color-outline: var(--color-tobiko); + + 
--color-neutral-5: var(--color-gray-5); + --color-neutral-10: var(--color-gray-10); + --color-neutral-15: var(--color-gray-15); + --color-neutral-25: var(--color-gray-25); + --color-neutral-100: var(--color-gray-100); + --color-neutral: var(--color-gray-500); + --color-neutral-500: var(--color-gray-500); + --color-neutral-600: var(--color-gray-600); + --color-neutral-700: var(--color-gray-700); + + --color-border: var(--color-gray-10); + + --color-brand-lighter: var(--color-tobiko-300); + --color-brand: var(--color-tobiko-500); + --color-brand-darker: var(--color-tobiko-700); + + --color-primary-light: var(--color-cyan-100); + --color-primary-hover: var(--color-cyan-525); + --color-primary-active: var(--color-cyan-550); + --color-primary: var(--color-cyan-500); + --color-primary-darker: var(--color-cyan-600); + + --color-secondary-light: var(--color-deep-blue-100); + --color-secondary-hover: var(--color-deep-blue-125); + --color-secondary-active: var(--color-deep-blue-150); + --color-secondary: var(--color-deep-blue-500); + --color-secondary-darker: var(--color-deep-blue-600); + + --color-success-light: var(--color-emerald-100); + --color-success: var(--color-emerald-500); + + --color-danger-light: var(--color-scarlet-100); + --color-danger: var(--color-scarlet-500); + + --color-warning-light: var(--color-mandarin-100); + --color-warning: var(--color-mandarin-500); + + --color-stroke: var(--color-gray); + --color-link: var(--color-secondary); + --color-underline: var(--color-secondary); + --color-underline-hover: var(--color-deep-blue-200); + --color-divider: var(--color-gray-10); + + --color-subtitle: var(--color-gray-300); + --color-tagline: var(--color-gray-400); + --color-text: var(--color-gray-600); + --color-title: var(--color-gray-700); + --color-header: var(--color-gray-800); + --color-headline: var(--color-gray-900); + + --color-input-bg: var(--color-gray-100); + --color-input-text: var(--color-gray-600); + --color-input-placeholder: var(--color-gray-300); 
+ --color-input-border: var(--color-gray-200); + + --color-code: hsl(226, 32%, 16%); + --color-scrollbar: var(--color-gray-300); + + --color-shadow-5: hsla(226, 24%, 65%, 0.05); + --color-shadow-10: hsla(226, 24%, 65%, 0.1); + --color-shadow-15: hsla(226, 24%, 65%, 0.15); + --color-shadow-25: hsla(226, 24%, 65%, 0.25); + --color-shadow-30: hsla(226, 24%, 65%, 0.3); + --color-shadow-50: hsla(226, 24%, 65%, 0.5); + --color-shadow-75: hsla(226, 24%, 65%, 0.75); + --color-shadow-95: hsla(226, 24%, 65%, 0.95); + + --color-change-add: var(--color-emerald-550); + --color-change-remove: var(--color-scarlet-550); + --color-change-directly-modified: var(--color-deep-blue-550); + --color-change-indirectly-modified: var(--color-mandarin-550); + --color-change-metadata: var(--color-gray-300); + --color-change-backfill: var(--color-gray-600); + --color-run: var(--color-cyan-500); + --color-plan: var(--color-purple-500); + --color-model: var(--color-deep-blue-500); + --color-environment: var(--color-mandarin-400); + + --color-complete-5: hsla(130, 65%, 50%, 0.05); + --color-complete-10: hsla(130, 65%, 50%, 0.1); + --color-complete-100: hsla(130, 65%, 95%, 1); + --color-complete: hsla(130, 65%, 50%, 1); + --color-complete-600: hsla(130, 65%, 40%, 1); + + --color-pending-5: hsla(30, 100%, 50%, 0.05); + --color-pending-10: hsla(30, 100%, 50%, 0.1); + --color-pending-100: hsla(30, 100%, 95%, 1); + --color-pending: hsla(30, 100%, 50%, 1); + --color-pending-600: hsla(30, 100%, 40%, 1); + + --color-failed-5: hsla(350, 85%, 60%, 0.05); + --color-failed-10: hsla(350, 85%, 60%, 0.1); + --color-failed-100: hsla(0, 90%, 95%, 1); + --color-failed: hsla(350, 85%, 60%, 1); + --color-failed-600: hsla(0, 90%, 40%, 1); + + --color-behind-5: var(--color-failed-5); + --color-behind-10: var(--color-failed-10); + --color-behind-100: var(--color-failed-100); + --color-behind: var(--color-failed); + --color-behind-600: var(--color-failed-600); + + --color-progress-5: hsla(190, 8%, 50%, 0.05); + 
--color-progress-10: hsla(190, 8%, 50%, 0.1); + --color-progress-100: hsla(190, 8%, 95%, 1); + --color-progress: hsla(190, 8%, 50%, 1); + --color-progress-600: hsla(190, 8%, 40%, 1); + + --color-status-failed: var(--color-failed); + --color-status-behind: var(--color-failed); + --color-status-progress: var(--color-progress); + + --color-active: var(--color-brand); + --color-accent: var(--color-pacific); + + --color-chart-data: var(--color-deep-blue-600); + --color-chart-overlay: var(--color-purple-600); + + /* Space: Base */ + --one: 1px; + --base: 4px; + --half: calc(var(--base) / 2); + --step: var(--base); + --step-2: calc(var(--base) * 2); + --step-3: calc(var(--base) * 3); + --step-4: calc(var(--base) * 4); + --step-5: calc(var(--base) * 5); + --step-6: calc(var(--base) * 6); + --step-7: calc(var(--base) * 7); + --step-8: calc(var(--base) * 8); + --step-9: calc(var(--base) * 9); + --step-10: calc(var(--base) * 10); + --step-11: calc(var(--base) * 11); + --step-12: calc(var(--base) * 12); + --step-15: calc(var(--base) * 15); + --step-16: calc(var(--base) * 16); + --step-20: calc(var(--base) * 20); + --step-24: calc(var(--base) * 24); + --step-30: calc(var(--base) * 30); + --step-32: calc(var(--base) * 32); + + /* Typography: Base */ + --font-size: 16px; + --leading: 1.5; + --font-weight: 500; + --font-sans: 'Inter', sans-serif; + --font-accent: 'Sohne', sans-serif; + --font-serif: 'Martina Plantijn', serif; + --font-mono: 'JetBrains Mono', monospace; + + --text-2xs: 10px; + --text-xs: calc(var(--font-size) * 0.75); + --text-s: calc(var(--font-size) * 0.875); + --text-m: var(--font-size); + --text-l: calc(var(--font-size) * 1.125); + --text-xl: calc(var(--font-size) * 1.25); + --text-2xl: calc(var(--font-size) * 1.5); + + --text-headline: calc(var(--font-size) * 4); + --text-display: 45px; + --text-header: calc(var(--font-size) * 2); + --text-tagline: 23px; + --text-title: var(--font-size); + --text-subtitle: calc(var(--font-size) * 0.75); + + --text-leading-xs: 
1; + --text-leading-s: 1.25; + --text-leading-m: var(--leading); + --text-leading-l: 1.75; + --text-leading-xl: 2; + + --text-thin: 100; + --text-extra-light: 200; + --text-light: 300; + --text-normal: 400; + --text-medium: var(--font-weight); + --text-semibold: 600; + --text-bold: 700; + --text-extra-bold: 800; + --text-black: 900; + + --size-xs-svg: var(--step-4); + --size-s-svg: var(--step-5); + --size-m-svg: var(--step-6); + --size-l-svg: var(--step-7); + --size-xl-svg: var(--step-8); + + --size-scrollbar: calc(var(--step) + var(--one)); + --size-sidebar: calc(var(--media-3xs) + var(--step-8)); + --size-half-screen-height: 50vh; + + --margin: var(--half); + --padding: var(--step-4); + + --radius-2xs: var(--step); + --radius-xs: var(--step-2); + --radius-s: var(--step-3); + --radius-m: var(--step-4); + --radius-l: var(--step-5); + --radius-xl: var(--step-6); + --radius-2xl: var(--step-8); + + --outline-offset: var(--one); + --outline-ring: calc(var(--half) + var(--one)); + --outline-ring-color: var(--color-tobiko-500); + --outline-color: var(--color-tobiko-100); + + --header-height: var(--step-10); + --footer-height: var(--step-6); + + --media-3xs: calc(var(--base) * 40); /* 160px */ + --media-2xs: calc(var(--base) * 60); /* 240px */ + --media-xs: calc(var(--base) * 80); /* 320px */ + --media-s: calc(var(--base) * 120); /* 480px */ + --media-m: calc(var(--base) * 160); /* 640px */ + --media-l: calc(var(--base) * 192); /* 768px */ + --media-xl: calc(var(--base) * 256); /* 1024px */ + --media-2xl: calc(var(--base) * 320); /* 1280px */ + --media-3xl: calc(var(--base) * 360); /* 1440px */ + --media-4xl: calc(var(--base) * 480); /* 1920px */ + + /* Utility */ + --underline: inset 0 calc(-1 * var(--one)) 0; + --layer-floor: 0; + --layer-lowest: 9; + --layer-low: 99; + --layer-middle: 999; + --layer-high: 9999; + --layer-highest: 99999; + --layer-ceil: 999999; + + --shadow-xs: 0 0 var(--half) var(--color-shadow-10); + --shadow-s: 0 var(--half) calc(var(--half) * 2) 
var(--color-shadow-10); + --shadow-m: 0 var(--step) calc(var(--half) * 3) var(--color-shadow-10); + --shadow-l: 0 var(--step) calc(var(--half) * 4) var(--color-shadow-10); + --shadow-xl: 0 var(--step) calc(var(--half) * 5) var(--color-shadow-10); + + --tooltip-shadow: + 0 var(--step-2) calc(var(--step) * 4) -1px var(--color-shadow-30), + inset 0 0 0 1px var(--color-gray-300); + --tooltip-background: var(--color-light); + --tooltip-text: var(--color-gray-600); + --tooltip-border-radius: var(--radius-s); + + --ease-in-out-exponential: cubic-bezier(0.4, 0, 0.2, 1); + + font-synthesis: none; + text-rendering: optimizeLegibility; + -webkit-font-smoothing: antialiased; + -moz-osx-font-smoothing: grayscale; + -webkit-text-size-adjust: 100%; +} diff --git a/web/client/public/favicons/favicon.ico b/web/client/public/favicons/favicon.ico deleted file mode 100644 index 45937721f6..0000000000 Binary files a/web/client/public/favicons/favicon.ico and /dev/null differ diff --git a/web/client/public/favicons/favicon.svg b/web/client/public/favicons/favicon.svg new file mode 100644 index 0000000000..cbf6e39228 --- /dev/null +++ b/web/client/public/favicons/favicon.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/web/client/public/fonts/Inter/Inter-VariableFont_slnt,wght.ttf b/web/client/public/fonts/Inter/Inter-VariableFont_slnt,wght.ttf new file mode 100644 index 0000000000..ec3164efa8 Binary files /dev/null and b/web/client/public/fonts/Inter/Inter-VariableFont_slnt,wght.ttf differ diff --git a/web/client/public/fonts/Inter/OFL.txt b/web/client/public/fonts/Inter/OFL.txt new file mode 100644 index 0000000000..ad214842c4 --- /dev/null +++ b/web/client/public/fonts/Inter/OFL.txt @@ -0,0 +1,93 @@ +Copyright 2020 The Inter Project Authors (https://github.com/rsms/inter) + +This Font Software is licensed under the SIL Open Font License, Version 1.1. 
+This license is copied below, and is also available with a FAQ at: +http://scripts.sil.org/OFL + + +----------------------------------------------------------- +SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007 +----------------------------------------------------------- + +PREAMBLE +The goals of the Open Font License (OFL) are to stimulate worldwide +development of collaborative font projects, to support the font creation +efforts of academic and linguistic communities, and to provide a free and +open framework in which fonts may be shared and improved in partnership +with others. + +The OFL allows the licensed fonts to be used, studied, modified and +redistributed freely as long as they are not sold by themselves. The +fonts, including any derivative works, can be bundled, embedded, +redistributed and/or sold with any software provided that any reserved +names are not used by derivative works. The fonts and derivatives, +however, cannot be released under any other type of license. The +requirement for fonts to remain under this license does not apply +to any document created using the fonts or their derivatives. + +DEFINITIONS +"Font Software" refers to the set of files released by the Copyright +Holder(s) under this license and clearly marked as such. This may +include source files, build scripts and documentation. + +"Reserved Font Name" refers to any names specified as such after the +copyright statement(s). + +"Original Version" refers to the collection of Font Software components as +distributed by the Copyright Holder(s). + +"Modified Version" refers to any derivative made by adding to, deleting, +or substituting -- in part or in whole -- any of the components of the +Original Version, by changing formats or by porting the Font Software to a +new environment. + +"Author" refers to any designer, engineer, programmer, technical +writer or other person who contributed to the Font Software. 
+ +PERMISSION & CONDITIONS +Permission is hereby granted, free of charge, to any person obtaining +a copy of the Font Software, to use, study, copy, merge, embed, modify, +redistribute, and sell modified and unmodified copies of the Font +Software, subject to the following conditions: + +1) Neither the Font Software nor any of its individual components, +in Original or Modified Versions, may be sold by itself. + +2) Original or Modified Versions of the Font Software may be bundled, +redistributed and/or sold with any software, provided that each copy +contains the above copyright notice and this license. These can be +included either as stand-alone text files, human-readable headers or +in the appropriate machine-readable metadata fields within text or +binary files as long as those fields can be easily viewed by the user. + +3) No Modified Version of the Font Software may use the Reserved Font +Name(s) unless explicit written permission is granted by the corresponding +Copyright Holder. This restriction only applies to the primary font name as +presented to the users. + +4) The name(s) of the Copyright Holder(s) or the Author(s) of the Font +Software shall not be used to promote, endorse or advertise any +Modified Version, except to acknowledge the contribution(s) of the +Copyright Holder(s) and the Author(s) or with their explicit written +permission. + +5) The Font Software, modified or unmodified, in part or in whole, +must be distributed entirely under this license, and must not be +distributed under any other license. The requirement for fonts to +remain under this license does not apply to any document created +using the Font Software. + +TERMINATION +This license becomes null and void if any of the above conditions are +not met. 
+ +DISCLAIMER +THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT +OF COPYRIGHT, PATENT, TRADEMARK, OR OTHER RIGHT. IN NO EVENT SHALL THE +COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL +DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM +OTHER DEALINGS IN THE FONT SOFTWARE. diff --git a/web/client/public/fonts/Inter/README.txt b/web/client/public/fonts/Inter/README.txt new file mode 100644 index 0000000000..3078f199cc --- /dev/null +++ b/web/client/public/fonts/Inter/README.txt @@ -0,0 +1,72 @@ +Inter Variable Font +=================== + +This download contains Inter as both a variable font and static fonts. + +Inter is a variable font with these axes: + slnt + wght + +This means all the styles are contained in a single file: + Inter-VariableFont_slnt,wght.ttf + +If your app fully supports variable fonts, you can now pick intermediate styles +that aren’t available as static fonts. Not all apps support variable fonts, and +in those cases you can use the static font files for Inter: + static/Inter-Thin.ttf + static/Inter-ExtraLight.ttf + static/Inter-Light.ttf + static/Inter-Regular.ttf + static/Inter-Medium.ttf + static/Inter-SemiBold.ttf + static/Inter-Bold.ttf + static/Inter-ExtraBold.ttf + static/Inter-Black.ttf + +Get started +----------- + +1. Install the font files you want to use + +2. 
Use your app's font picker to view the font family and all the +available styles + +Learn more about variable fonts +------------------------------- + + https://developers.google.com/web/fundamentals/design-and-ux/typography/variable-fonts + https://variablefonts.typenetwork.com + https://medium.com/variable-fonts + +In desktop apps + + https://theblog.adobe.com/can-variable-fonts-illustrator-cc + https://helpx.adobe.com/nz/photoshop/using/fonts.html#variable_fonts + +Online + + https://developers.google.com/fonts/docs/getting_started + https://developer.mozilla.org/en-US/docs/Web/CSS/CSS_Fonts/Variable_Fonts_Guide + https://developer.microsoft.com/en-us/microsoft-edge/testdrive/demos/variable-fonts + +Installing fonts + + MacOS: https://support.apple.com/en-us/HT201749 + Linux: https://www.google.com/search?q=how+to+install+a+font+on+gnu%2Blinux + Windows: https://support.microsoft.com/en-us/help/314960/how-to-install-or-remove-a-font-in-windows + +Android Apps + + https://developers.google.com/fonts/docs/android + https://developer.android.com/guide/topics/ui/look-and-feel/downloadable-fonts + +License +------- +Please read the full license text (OFL.txt) to understand the permissions, +restrictions and requirements for usage, redistribution, and modification. + +You can use them in your products & projects – print or digital, +commercial or otherwise. + +This isn't legal advice, please consider consulting a lawyer and see the full +license for all details. 
diff --git a/web/client/public/fonts/Inter/static/Inter-Black.ttf b/web/client/public/fonts/Inter/static/Inter-Black.ttf new file mode 100644 index 0000000000..5aecf7dc41 Binary files /dev/null and b/web/client/public/fonts/Inter/static/Inter-Black.ttf differ diff --git a/web/client/public/fonts/Inter/static/Inter-Bold.ttf b/web/client/public/fonts/Inter/static/Inter-Bold.ttf new file mode 100644 index 0000000000..8e82c70d10 Binary files /dev/null and b/web/client/public/fonts/Inter/static/Inter-Bold.ttf differ diff --git a/web/client/public/fonts/Inter/static/Inter-ExtraBold.ttf b/web/client/public/fonts/Inter/static/Inter-ExtraBold.ttf new file mode 100644 index 0000000000..cb4b8217fc Binary files /dev/null and b/web/client/public/fonts/Inter/static/Inter-ExtraBold.ttf differ diff --git a/web/client/public/fonts/Inter/static/Inter-ExtraLight.ttf b/web/client/public/fonts/Inter/static/Inter-ExtraLight.ttf new file mode 100644 index 0000000000..64aee30a4e Binary files /dev/null and b/web/client/public/fonts/Inter/static/Inter-ExtraLight.ttf differ diff --git a/web/client/public/fonts/Inter/static/Inter-Light.ttf b/web/client/public/fonts/Inter/static/Inter-Light.ttf new file mode 100644 index 0000000000..9e265d8905 Binary files /dev/null and b/web/client/public/fonts/Inter/static/Inter-Light.ttf differ diff --git a/web/client/public/fonts/Inter/static/Inter-Medium.ttf b/web/client/public/fonts/Inter/static/Inter-Medium.ttf new file mode 100644 index 0000000000..b53fb1c4ac Binary files /dev/null and b/web/client/public/fonts/Inter/static/Inter-Medium.ttf differ diff --git a/web/client/public/fonts/Inter/static/Inter-Regular.ttf b/web/client/public/fonts/Inter/static/Inter-Regular.ttf new file mode 100644 index 0000000000..8d4eebf206 Binary files /dev/null and b/web/client/public/fonts/Inter/static/Inter-Regular.ttf differ diff --git a/web/client/public/fonts/Inter/static/Inter-SemiBold.ttf b/web/client/public/fonts/Inter/static/Inter-SemiBold.ttf new file mode 
100644 index 0000000000..c6aeeb16a6 Binary files /dev/null and b/web/client/public/fonts/Inter/static/Inter-SemiBold.ttf differ diff --git a/web/client/public/fonts/Inter/static/Inter-Thin.ttf b/web/client/public/fonts/Inter/static/Inter-Thin.ttf new file mode 100644 index 0000000000..7aed55d560 Binary files /dev/null and b/web/client/public/fonts/Inter/static/Inter-Thin.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/JetBrainsMono-Italic-VariableFont_wght.ttf b/web/client/public/fonts/JetBrains_Mono/JetBrainsMono-Italic-VariableFont_wght.ttf new file mode 100644 index 0000000000..914e323363 Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/JetBrainsMono-Italic-VariableFont_wght.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/JetBrainsMono-VariableFont_wght.ttf b/web/client/public/fonts/JetBrains_Mono/JetBrainsMono-VariableFont_wght.ttf new file mode 100644 index 0000000000..d73994ad49 Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/JetBrainsMono-VariableFont_wght.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/OFL.txt b/web/client/public/fonts/JetBrains_Mono/OFL.txt new file mode 100644 index 0000000000..201e940f3a --- /dev/null +++ b/web/client/public/fonts/JetBrains_Mono/OFL.txt @@ -0,0 +1,93 @@ +Copyright 2020 The JetBrains Mono Project Authors (https://github.com/JetBrains/JetBrainsMono) + +This Font Software is licensed under the SIL Open Font License, Version 1.1. 
+This license is copied below, and is also available with a FAQ at: +http://scripts.sil.org/OFL + + +----------------------------------------------------------- +SIL OPEN FONT LICENSE Version 1.1 - 26 February 2007 +----------------------------------------------------------- + +PREAMBLE +The goals of the Open Font License (OFL) are to stimulate worldwide +development of collaborative font projects, to support the font creation +efforts of academic and linguistic communities, and to provide a free and +open framework in which fonts may be shared and improved in partnership +with others. + +The OFL allows the licensed fonts to be used, studied, modified and +redistributed freely as long as they are not sold by themselves. The +fonts, including any derivative works, can be bundled, embedded, +redistributed and/or sold with any software provided that any reserved +names are not used by derivative works. The fonts and derivatives, +however, cannot be released under any other type of license. The +requirement for fonts to remain under this license does not apply +to any document created using the fonts or their derivatives. + +DEFINITIONS +"Font Software" refers to the set of files released by the Copyright +Holder(s) under this license and clearly marked as such. This may +include source files, build scripts and documentation. + +"Reserved Font Name" refers to any names specified as such after the +copyright statement(s). + +"Original Version" refers to the collection of Font Software components as +distributed by the Copyright Holder(s). + +"Modified Version" refers to any derivative made by adding to, deleting, +or substituting -- in part or in whole -- any of the components of the +Original Version, by changing formats or by porting the Font Software to a +new environment. + +"Author" refers to any designer, engineer, programmer, technical +writer or other person who contributed to the Font Software. 
+ +PERMISSION & CONDITIONS +Permission is hereby granted, free of charge, to any person obtaining +a copy of the Font Software, to use, study, copy, merge, embed, modify, +redistribute, and sell modified and unmodified copies of the Font +Software, subject to the following conditions: + +1) Neither the Font Software nor any of its individual components, +in Original or Modified Versions, may be sold by itself. + +2) Original or Modified Versions of the Font Software may be bundled, +redistributed and/or sold with any software, provided that each copy +contains the above copyright notice and this license. These can be +included either as stand-alone text files, human-readable headers or +in the appropriate machine-readable metadata fields within text or +binary files as long as those fields can be easily viewed by the user. + +3) No Modified Version of the Font Software may use the Reserved Font +Name(s) unless explicit written permission is granted by the corresponding +Copyright Holder. This restriction only applies to the primary font name as +presented to the users. + +4) The name(s) of the Copyright Holder(s) or the Author(s) of the Font +Software shall not be used to promote, endorse or advertise any +Modified Version, except to acknowledge the contribution(s) of the +Copyright Holder(s) and the Author(s) or with their explicit written +permission. + +5) The Font Software, modified or unmodified, in part or in whole, +must be distributed entirely under this license, and must not be +distributed under any other license. The requirement for fonts to +remain under this license does not apply to any document created +using the Font Software. + +TERMINATION +This license becomes null and void if any of the above conditions are +not met. 
+ +DISCLAIMER +THE FONT SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO ANY WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT +OF COPYRIGHT, PATENT, TRADEMARK, OR OTHER RIGHT. IN NO EVENT SHALL THE +COPYRIGHT HOLDER BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, +INCLUDING ANY GENERAL, SPECIAL, INDIRECT, INCIDENTAL, OR CONSEQUENTIAL +DAMAGES, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF THE USE OR INABILITY TO USE THE FONT SOFTWARE OR FROM +OTHER DEALINGS IN THE FONT SOFTWARE. diff --git a/web/client/public/fonts/JetBrains_Mono/README.txt b/web/client/public/fonts/JetBrains_Mono/README.txt new file mode 100644 index 0000000000..0a8510da61 --- /dev/null +++ b/web/client/public/fonts/JetBrains_Mono/README.txt @@ -0,0 +1,79 @@ +JetBrains Mono Variable Font +============================ + +This download contains JetBrains Mono as both variable fonts and static fonts. + +JetBrains Mono is a variable font with this axis: + wght + +This means all the styles are contained in these files: + JetBrainsMono-VariableFont_wght.ttf + JetBrainsMono-Italic-VariableFont_wght.ttf + +If your app fully supports variable fonts, you can now pick intermediate styles +that aren’t available as static fonts. 
Not all apps support variable fonts, and +in those cases you can use the static font files for JetBrains Mono: + static/JetBrainsMono-Thin.ttf + static/JetBrainsMono-ExtraLight.ttf + static/JetBrainsMono-Light.ttf + static/JetBrainsMono-Regular.ttf + static/JetBrainsMono-Medium.ttf + static/JetBrainsMono-SemiBold.ttf + static/JetBrainsMono-Bold.ttf + static/JetBrainsMono-ExtraBold.ttf + static/JetBrainsMono-ThinItalic.ttf + static/JetBrainsMono-ExtraLightItalic.ttf + static/JetBrainsMono-LightItalic.ttf + static/JetBrainsMono-Italic.ttf + static/JetBrainsMono-MediumItalic.ttf + static/JetBrainsMono-SemiBoldItalic.ttf + static/JetBrainsMono-BoldItalic.ttf + static/JetBrainsMono-ExtraBoldItalic.ttf + +Get started +----------- + +1. Install the font files you want to use + +2. Use your app's font picker to view the font family and all the +available styles + +Learn more about variable fonts +------------------------------- + + https://developers.google.com/web/fundamentals/design-and-ux/typography/variable-fonts + https://variablefonts.typenetwork.com + https://medium.com/variable-fonts + +In desktop apps + + https://theblog.adobe.com/can-variable-fonts-illustrator-cc + https://helpx.adobe.com/nz/photoshop/using/fonts.html#variable_fonts + +Online + + https://developers.google.com/fonts/docs/getting_started + https://developer.mozilla.org/en-US/docs/Web/CSS/CSS_Fonts/Variable_Fonts_Guide + https://developer.microsoft.com/en-us/microsoft-edge/testdrive/demos/variable-fonts + +Installing fonts + + MacOS: https://support.apple.com/en-us/HT201749 + Linux: https://www.google.com/search?q=how+to+install+a+font+on+gnu%2Blinux + Windows: https://support.microsoft.com/en-us/help/314960/how-to-install-or-remove-a-font-in-windows + +Android Apps + + https://developers.google.com/fonts/docs/android + https://developer.android.com/guide/topics/ui/look-and-feel/downloadable-fonts + +License +------- +Please read the full license text (OFL.txt) to understand the permissions, 
+restrictions and requirements for usage, redistribution, and modification. + +You can use them in your products & projects – print or digital, +commercial or otherwise. + +This isn't legal advice, please consider consulting a lawyer and see the full +license for all details. diff --git a/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Bold.ttf b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Bold.ttf new file mode 100644 index 0000000000..b7484374e7 Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Bold.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-BoldItalic.ttf b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-BoldItalic.ttf new file mode 100644 index 0000000000..0091142300 Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-BoldItalic.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-ExtraBold.ttf b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-ExtraBold.ttf new file mode 100644 index 0000000000..88eab2f7ba Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-ExtraBold.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-ExtraBoldItalic.ttf b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-ExtraBoldItalic.ttf new file mode 100644 index 0000000000..85e67db403 Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-ExtraBoldItalic.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-ExtraLight.ttf b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-ExtraLight.ttf new file mode 100644 index 0000000000..1f73714431 Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-ExtraLight.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-ExtraLightItalic.ttf 
b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-ExtraLightItalic.ttf new file mode 100644 index 0000000000..745b58eeaa Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-ExtraLightItalic.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Italic.ttf b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Italic.ttf new file mode 100644 index 0000000000..5b484dd610 Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Italic.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Light.ttf b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Light.ttf new file mode 100644 index 0000000000..296186f1db Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Light.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-LightItalic.ttf b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-LightItalic.ttf new file mode 100644 index 0000000000..399ede7440 Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-LightItalic.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Medium.ttf b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Medium.ttf new file mode 100644 index 0000000000..ad31fbd7f0 Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Medium.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-MediumItalic.ttf b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-MediumItalic.ttf new file mode 100644 index 0000000000..4f499f281d Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-MediumItalic.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Regular.ttf 
b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Regular.ttf new file mode 100644 index 0000000000..02bc07ea08 Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Regular.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-SemiBold.ttf b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-SemiBold.ttf new file mode 100644 index 0000000000..c3adfd3151 Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-SemiBold.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-SemiBoldItalic.ttf b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-SemiBoldItalic.ttf new file mode 100644 index 0000000000..62d58add47 Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-SemiBoldItalic.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Thin.ttf b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Thin.ttf new file mode 100644 index 0000000000..6a6a556f10 Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-Thin.ttf differ diff --git a/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-ThinItalic.ttf b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-ThinItalic.ttf new file mode 100644 index 0000000000..33a23d7ca8 Binary files /dev/null and b/web/client/public/fonts/JetBrains_Mono/static/JetBrainsMono-ThinItalic.ttf differ diff --git a/web/client/public/fonts/Martina_Plantijn/martina-plantijn-regular.otf b/web/client/public/fonts/Martina_Plantijn/martina-plantijn-regular.otf new file mode 100644 index 0000000000..12a17f7869 Binary files /dev/null and b/web/client/public/fonts/Martina_Plantijn/martina-plantijn-regular.otf differ diff --git a/web/client/public/fonts/Martina_Plantijn/martina-plantijn-regular.woff2 b/web/client/public/fonts/Martina_Plantijn/martina-plantijn-regular.woff2 new 
file mode 100644 index 0000000000..d0f050784b Binary files /dev/null and b/web/client/public/fonts/Martina_Plantijn/martina-plantijn-regular.woff2 differ diff --git a/web/client/public/fonts/Sohne/sohne-buch-kursiv.otf b/web/client/public/fonts/Sohne/sohne-buch-kursiv.otf new file mode 100644 index 0000000000..b2983733c5 Binary files /dev/null and b/web/client/public/fonts/Sohne/sohne-buch-kursiv.otf differ diff --git a/web/client/public/fonts/Sohne/sohne-buch-kursiv.woff2 b/web/client/public/fonts/Sohne/sohne-buch-kursiv.woff2 new file mode 100644 index 0000000000..98ea2ba6e4 Binary files /dev/null and b/web/client/public/fonts/Sohne/sohne-buch-kursiv.woff2 differ diff --git a/web/client/public/fonts/Sohne/sohne-buch.otf b/web/client/public/fonts/Sohne/sohne-buch.otf new file mode 100644 index 0000000000..66fd3b455f Binary files /dev/null and b/web/client/public/fonts/Sohne/sohne-buch.otf differ diff --git a/web/client/public/fonts/Sohne/sohne-buch.woff2 b/web/client/public/fonts/Sohne/sohne-buch.woff2 new file mode 100644 index 0000000000..6d62a8d55b Binary files /dev/null and b/web/client/public/fonts/Sohne/sohne-buch.woff2 differ diff --git a/web/client/public/fonts/Sohne/sohne-kraftig.otf b/web/client/public/fonts/Sohne/sohne-kraftig.otf new file mode 100644 index 0000000000..2f5e15c83b Binary files /dev/null and b/web/client/public/fonts/Sohne/sohne-kraftig.otf differ diff --git a/web/client/public/fonts/Sohne/sohne-kraftig.woff2 b/web/client/public/fonts/Sohne/sohne-kraftig.woff2 new file mode 100644 index 0000000000..f274f1444d Binary files /dev/null and b/web/client/public/fonts/Sohne/sohne-kraftig.woff2 differ diff --git a/web/client/public/fonts/Sohne/sohne-mono-buch.otf b/web/client/public/fonts/Sohne/sohne-mono-buch.otf new file mode 100644 index 0000000000..8b87ed1207 Binary files /dev/null and b/web/client/public/fonts/Sohne/sohne-mono-buch.otf differ diff --git a/web/client/public/fonts/Sohne/sohne-mono-buch.woff2 
b/web/client/public/fonts/Sohne/sohne-mono-buch.woff2 new file mode 100644 index 0000000000..bdffed9187 Binary files /dev/null and b/web/client/public/fonts/Sohne/sohne-mono-buch.woff2 differ diff --git a/web/client/public/fonts/Sohne/sohne-mono-kraftig.otf b/web/client/public/fonts/Sohne/sohne-mono-kraftig.otf new file mode 100644 index 0000000000..ac65eb6d2a Binary files /dev/null and b/web/client/public/fonts/Sohne/sohne-mono-kraftig.otf differ diff --git a/web/client/public/fonts/Sohne/sohne-mono-kraftig.woff2 b/web/client/public/fonts/Sohne/sohne-mono-kraftig.woff2 new file mode 100644 index 0000000000..27fa1e659f Binary files /dev/null and b/web/client/public/fonts/Sohne/sohne-mono-kraftig.woff2 differ diff --git a/web/client/src/App.tsx b/web/client/src/App.tsx index 8be9ce44de..4ee9ae24fe 100644 --- a/web/client/src/App.tsx +++ b/web/client/src/App.tsx @@ -1,8 +1,6 @@ -import { useEffect, Suspense } from 'react' -import { RouterProvider } from 'react-router-dom' +import { useEffect, Suspense, lazy } from 'react' +import { RouterProvider } from 'react-router' import { Divider } from '@components/divider/Divider' -import Header from './library/pages/root/Header' -import Footer from './library/pages/root/Footer' import { getBrowserRouter } from './routes' import { useApiModules } from './api' import { useStoreContext } from '@context/context' @@ -11,6 +9,14 @@ import { EnumErrorKey, useNotificationCenter, } from './library/pages/root/context/notificationCenter' +import { isArrayNotEmpty, isNotNil } from './utils' +import NotFound from './library/pages/root/NotFound' + +const IS_HEADLESS: boolean = Boolean((window as any).__IS_HEADLESS__ ?? false) +const Header: Optional JSX.Element>> = + IS_HEADLESS ? undefined : lazy(() => import('./library/pages/root/Header')) +const Footer: Optional JSX.Element>> = + IS_HEADLESS ? 
undefined : lazy(() => import('./library/pages/root/Footer')) export default function App(): JSX.Element { const { addError } = useNotificationCenter() @@ -34,24 +40,37 @@ export default function App(): JSX.Element { } }, []) - const router = getBrowserRouter(modules.list) + const router = getBrowserRouter(modules) return ( <> -
- + {isNotNil(Header) && ( + <> +
+ + + )}
- {isFetching && ( + {isFetching ? ( Building Modules... + ) : ( + Loading Page...}> + {isArrayNotEmpty(modules.list) ? ( + + ) : ( + + )} + )} - Loading Page...}> - -
- -